Skip to content

Commit 4d037c3

Browse files
committed
Update SpringBootContextLoader to support AOT
Update `SpringBootContextLoader` so that it now implements the `AotContextLoader` interface. The `ContextLoaderHook` will abandon at `contextLoaded` if the test class is being AOT processed. This commit also introduces a new `AotApplicationContextInitializer` which allows us to plug-in an alternative AOT application context listener when the `SpringApplication` is running in test mode. Closes gh-31965
1 parent d1e7c9b commit 4d037c3

File tree

7 files changed

+469
-11
lines changed

7 files changed

+469
-11
lines changed

spring-boot-project/spring-boot-test/build.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ dependencies {
5555
testImplementation("org.slf4j:slf4j-api")
5656
testImplementation("org.spockframework:spock-core")
5757
testImplementation("org.springframework:spring-webmvc")
58+
testImplementation("org.springframework:spring-core-test")
5859
testImplementation("org.testng:testng")
5960

6061
testRuntimeOnly("org.junit.vintage:junit-vintage-engine")

spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/context/SpringBootContextLoader.java

+65-3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.List;
2424

2525
import org.springframework.beans.BeanUtils;
26+
import org.springframework.boot.AotApplicationContextInitializer;
2627
import org.springframework.boot.ApplicationContextFactory;
2728
import org.springframework.boot.ConfigurableBootstrapContext;
2829
import org.springframework.boot.SpringApplication;
@@ -55,6 +56,8 @@
5556
import org.springframework.test.context.ContextCustomizer;
5657
import org.springframework.test.context.ContextLoader;
5758
import org.springframework.test.context.MergedContextConfiguration;
59+
import org.springframework.test.context.SmartContextLoader;
60+
import org.springframework.test.context.aot.AotContextLoader;
5861
import org.springframework.test.context.support.AbstractContextLoader;
5962
import org.springframework.test.context.support.AnnotationConfigContextLoaderUtils;
6063
import org.springframework.test.context.support.TestPropertySourceUtils;
@@ -90,15 +93,31 @@
9093
* @since 1.4.0
9194
* @see SpringBootTest
9295
*/
93-
public class SpringBootContextLoader extends AbstractContextLoader {
96+
public class SpringBootContextLoader extends AbstractContextLoader implements AotContextLoader {
9497

9598
@Override
9699
public ApplicationContext loadContext(MergedContextConfiguration mergedConfig) throws Exception {
100+
return loadContext(mergedConfig, Mode.STANDARD, null);
101+
}
102+
103+
@Override
104+
public ApplicationContext loadContextForAotProcessing(MergedContextConfiguration mergedConfig) throws Exception {
105+
return loadContext(mergedConfig, Mode.AOT_PROCESSING, null);
106+
}
107+
108+
@Override
109+
public ApplicationContext loadContextForAotRuntime(MergedContextConfiguration mergedConfig,
110+
ApplicationContextInitializer<ConfigurableApplicationContext> initializer) throws Exception {
111+
return loadContext(mergedConfig, Mode.AOT_RUNTIME, initializer);
112+
}
113+
114+
private ApplicationContext loadContext(MergedContextConfiguration mergedConfig, Mode mode,
115+
ApplicationContextInitializer<ConfigurableApplicationContext> initializer) {
97116
assertHasClassesOrLocations(mergedConfig);
98117
SpringBootTestAnnotation annotation = SpringBootTestAnnotation.get(mergedConfig);
99118
String[] args = annotation.getArgs();
100119
UseMainMethod useMainMethod = annotation.getUseMainMethod();
101-
ContextLoaderHook hook = new ContextLoaderHook(mergedConfig);
120+
ContextLoaderHook hook = new ContextLoaderHook(mergedConfig, mode, initializer);
102121
if (useMainMethod != UseMainMethod.NEVER) {
103122
Method mainMethod = getMainMethod(mergedConfig, useMainMethod);
104123
if (mainMethod != null) {
@@ -297,6 +316,31 @@ protected String getResourceSuffix() {
297316
throw new IllegalStateException();
298317
}
299318

319+
/**
320+
* Modes that the {@link SpringBootContextLoader} can operate.
321+
*/
322+
private enum Mode {
323+
324+
/**
325+
* Load for regular usage.
326+
* @see SmartContextLoader#loadContext
327+
*/
328+
STANDARD,
329+
330+
/**
331+
* Load for AOT processing.
332+
* @see AotContextLoader#loadContextForAotProcessing
333+
*/
334+
AOT_PROCESSING,
335+
336+
/**
337+
* Load for AOT runtime.
338+
* @see AotContextLoader#loadContextForAotRuntime
339+
*/
340+
AOT_RUNTIME
341+
342+
}
343+
300344
/**
301345
* Inner class to configure {@link WebMergedContextConfiguration}.
302346
*/
@@ -417,8 +461,15 @@ private class ContextLoaderHook implements SpringApplicationHook {
417461

418462
private final MergedContextConfiguration mergedConfig;
419463

420-
ContextLoaderHook(MergedContextConfiguration mergedConfig) {
464+
private final Mode mode;
465+
466+
private final ApplicationContextInitializer<ConfigurableApplicationContext> initializer;
467+
468+
ContextLoaderHook(MergedContextConfiguration mergedConfig, Mode mode,
469+
ApplicationContextInitializer<ConfigurableApplicationContext> initializer) {
421470
this.mergedConfig = mergedConfig;
471+
this.mode = mode;
472+
this.initializer = initializer;
422473
}
423474

424475
@Override
@@ -428,6 +479,17 @@ public SpringApplicationRunListener getRunListener(SpringApplication application
428479
@Override
429480
public void starting(ConfigurableBootstrapContext bootstrapContext) {
430481
SpringBootContextLoader.this.configure(ContextLoaderHook.this.mergedConfig, application);
482+
if (ContextLoaderHook.this.initializer != null) {
483+
application.addInitializers(
484+
AotApplicationContextInitializer.of(ContextLoaderHook.this.initializer));
485+
}
486+
}
487+
488+
@Override
489+
public void contextLoaded(ConfigurableApplicationContext context) {
490+
if (ContextLoaderHook.this.mode == Mode.AOT_PROCESSING) {
491+
throw new AbandonedRunException(context);
492+
}
431493
}
432494

433495
@Override

spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/context/SpringBootTestContextBootstrapper.java

+24-2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import org.springframework.test.context.TestContextAnnotationUtils;
5050
import org.springframework.test.context.TestContextBootstrapper;
5151
import org.springframework.test.context.TestExecutionListener;
52+
import org.springframework.test.context.aot.AotTestAttributes;
5253
import org.springframework.test.context.support.DefaultTestContextBootstrapper;
5354
import org.springframework.test.context.support.TestPropertySourceUtils;
5455
import org.springframework.test.context.web.WebAppConfiguration;
@@ -97,6 +98,16 @@ public class SpringBootTestContextBootstrapper extends DefaultTestContextBootstr
9798

9899
private static final Log logger = LogFactory.getLog(SpringBootTestContextBootstrapper.class);
99100

101+
private final AotTestAttributes aotTestAttributes;
102+
103+
public SpringBootTestContextBootstrapper() {
104+
this(AotTestAttributes.getInstance());
105+
}
106+
107+
SpringBootTestContextBootstrapper(AotTestAttributes aotTestAttributes) {
108+
this.aotTestAttributes = aotTestAttributes;
109+
}
110+
100111
@Override
101112
public TestContext buildTestContext() {
102113
TestContext context = super.buildTestContext();
@@ -231,14 +242,25 @@ protected Class<?>[] getOrFindConfigurationClasses(MergedContextConfiguration me
231242
if (containsNonTestComponent(classes) || mergedConfig.hasLocations()) {
232243
return classes;
233244
}
234-
Class<?> found = new AnnotatedClassFinder(SpringBootConfiguration.class)
235-
.findFromClass(mergedConfig.getTestClass());
245+
Class<?> found = findConfigurationClass(mergedConfig.getTestClass());
236246
Assert.state(found != null, "Unable to find a @SpringBootConfiguration, you need to use "
237247
+ "@ContextConfiguration or @SpringBootTest(classes=...) with your test");
238248
logger.info("Found @SpringBootConfiguration " + found.getName() + " for test " + mergedConfig.getTestClass());
239249
return merge(found, classes);
240250
}
241251

252+
private Class<?> findConfigurationClass(Class<?> testClass) {
253+
String propertyName = "%s.SpringBootConfiguration.%s"
254+
.formatted(SpringBootTestContextBootstrapper.class.getName(), testClass.getName());
255+
String foundClassName = this.aotTestAttributes.getString(propertyName);
256+
if (foundClassName != null) {
257+
return ClassUtils.resolveClassName(foundClassName, testClass.getClassLoader());
258+
}
259+
Class<?> found = new AnnotatedClassFinder(SpringBootConfiguration.class).findFromClass(testClass);
260+
this.aotTestAttributes.setAttribute(propertyName, found.getName());
261+
return found;
262+
}
263+
242264
private boolean containsNonTestComponent(Class<?>[] classes) {
243265
for (Class<?> candidate : classes) {
244266
if (!MergedAnnotations.from(candidate, SearchStrategy.INHERITED_ANNOTATIONS)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* Copyright 2012-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.boot.test.context;
18+
19+
import java.util.stream.Stream;
20+
21+
import org.junit.jupiter.api.Test;
22+
23+
import org.springframework.aot.AotDetector;
24+
import org.springframework.aot.generate.InMemoryGeneratedFiles;
25+
import org.springframework.aot.test.generate.compile.CompileWithTargetClassAccess;
26+
import org.springframework.aot.test.generate.compile.TestCompiler;
27+
import org.springframework.beans.factory.config.BeanDefinition;
28+
import org.springframework.beans.factory.support.GenericBeanDefinition;
29+
import org.springframework.boot.SpringBootConfiguration;
30+
import org.springframework.boot.test.context.SpringBootTest.WebEnvironment;
31+
import org.springframework.context.ApplicationContextInitializer;
32+
import org.springframework.context.ConfigurableApplicationContext;
33+
import org.springframework.context.annotation.Import;
34+
import org.springframework.context.support.GenericApplicationContext;
35+
import org.springframework.test.context.BootstrapUtils;
36+
import org.springframework.test.context.MergedContextConfiguration;
37+
import org.springframework.test.context.TestContextBootstrapper;
38+
import org.springframework.test.context.aot.AotContextLoader;
39+
import org.springframework.test.context.aot.AotTestContextInitializers;
40+
import org.springframework.test.context.aot.TestContextAotGenerator;
41+
import org.springframework.test.util.ReflectionTestUtils;
42+
import org.springframework.util.ClassUtils;
43+
import org.springframework.util.function.ThrowingConsumer;
44+
45+
import static org.assertj.core.api.Assertions.assertThat;
46+
47+
/**
48+
* Tests for {@link SpringBootContextLoader} when used in AOT mode.
49+
*
50+
* @author Phillip Webb
51+
*/
52+
@CompileWithTargetClassAccess
53+
class SpringBootContextLoaderAotTests {
54+
55+
@Test
56+
void loadContextForAotProcessingAndAotRuntime() {
57+
InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles();
58+
TestContextAotGenerator generator = new TestContextAotGenerator(generatedFiles);
59+
Class<?> testClass = ExampleTest.class;
60+
generator.processAheadOfTime(Stream.of(testClass));
61+
TestCompiler.forSystem().withFiles(generatedFiles).printFiles(System.out)
62+
.compile(ThrowingConsumer.of((compiled) -> assertCompiledTest(testClass)));
63+
}
64+
65+
private void assertCompiledTest(Class<?> testClass) throws Exception {
66+
try {
67+
System.setProperty(AotDetector.AOT_ENABLED, "true");
68+
resetAotClasses();
69+
AotTestContextInitializers aotContextInitializers = new AotTestContextInitializers();
70+
TestContextBootstrapper testContextBootstrapper = BootstrapUtils.resolveTestContextBootstrapper(testClass);
71+
MergedContextConfiguration mergedConfig = testContextBootstrapper.buildMergedContextConfiguration();
72+
ApplicationContextInitializer<ConfigurableApplicationContext> contextInitializer = aotContextInitializers
73+
.getContextInitializer(testClass);
74+
ConfigurableApplicationContext context = (ConfigurableApplicationContext) ((AotContextLoader) mergedConfig
75+
.getContextLoader()).loadContextForAotRuntime(mergedConfig, contextInitializer);
76+
assertThat(context).isExactlyInstanceOf(GenericApplicationContext.class);
77+
String[] beanNames = context.getBeanNamesForType(ExampleBean.class);
78+
BeanDefinition beanDefinition = context.getBeanFactory().getBeanDefinition(beanNames[0]);
79+
assertThat(beanDefinition).isNotExactlyInstanceOf(GenericBeanDefinition.class);
80+
}
81+
finally {
82+
System.clearProperty(AotDetector.AOT_ENABLED);
83+
resetAotClasses();
84+
}
85+
}
86+
87+
private void resetAotClasses() {
88+
reset("org.springframework.test.context.aot.AotTestAttributesFactory");
89+
reset("org.springframework.test.context.aot.AotTestContextInitializersFactory");
90+
}
91+
92+
private void reset(String className) {
93+
Class<?> targetClass = ClassUtils.resolveClassName(className, null);
94+
ReflectionTestUtils.invokeMethod(targetClass, "reset");
95+
}
96+
97+
@SpringBootTest(classes = ExampleConfig.class, webEnvironment = WebEnvironment.NONE)
98+
static class ExampleTest {
99+
100+
}
101+
102+
@SpringBootConfiguration
103+
@Import(ExampleBean.class)
104+
static class ExampleConfig {
105+
106+
}
107+
108+
static class ExampleBean {
109+
110+
}
111+
112+
}

0 commit comments

Comments
 (0)