Skip to content

Commit ad68889

Browse files
mp911dechristophstrobl
authored andcommitted
Add AOT support for dynamic projections, streaming/scroll queries and Meta annotation.
Closes: #4970
1 parent 96ffade commit ad68889

File tree

9 files changed

+208
-32
lines changed

9 files changed

+208
-32
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java

+45-17
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.bson.Document;
2424
import org.jspecify.annotations.NullUnmarked;
2525
import org.jspecify.annotations.Nullable;
26+
2627
import org.springframework.core.annotation.MergedAnnotation;
2728
import org.springframework.data.domain.SliceImpl;
2829
import org.springframework.data.domain.Sort.Order;
@@ -40,6 +41,7 @@
4041
import org.springframework.data.mongodb.core.query.BasicUpdate;
4142
import org.springframework.data.mongodb.core.query.Collation;
4243
import org.springframework.data.mongodb.repository.Hint;
44+
import org.springframework.data.mongodb.repository.Meta;
4345
import org.springframework.data.mongodb.repository.ReadPreference;
4446
import org.springframework.data.mongodb.repository.query.MongoQueryExecution.DeleteExecution;
4547
import org.springframework.data.mongodb.repository.query.MongoQueryExecution.PagedExecution;
@@ -256,15 +258,13 @@ CodeBlock build() {
256258
updateReference);
257259
} else if (ClassUtils.isAssignable(Long.class, returnType)) {
258260
builder.addStatement("return $L.matching($L).apply($L).all().getModifiedCount()",
259-
context.localVariable("updater"), queryVariableName,
260-
updateReference);
261+
context.localVariable("updater"), queryVariableName, updateReference);
261262
} else {
262263
builder.addStatement("$T $L = $L.matching($L).apply($L).all().getModifiedCount()", Long.class,
263-
context.localVariable("modifiedCount"), context.localVariable("updater"),
264-
queryVariableName, updateReference);
264+
context.localVariable("modifiedCount"), context.localVariable("updater"), queryVariableName,
265+
updateReference);
265266
builder.addStatement("return $T.convertNumberToTargetClass($L, $T.class)", NumberUtils.class,
266-
context.localVariable("modifiedCount"),
267-
returnType);
267+
context.localVariable("modifiedCount"), returnType);
268268
}
269269

270270
return builder.build();
@@ -319,11 +319,9 @@ CodeBlock build() {
319319
Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType());
320320

321321
builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class,
322-
context.localVariable("results"), mongoOpsRef,
323-
aggregationVariableName, outputType);
322+
context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
324323
if (!queryMethod.isCollectionQuery()) {
325-
builder.addStatement(
326-
"return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))",
324+
builder.addStatement("return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))",
327325
CollectionUtils.class, returnType, returnType, context.localVariable("results"));
328326
} else {
329327
builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType,
@@ -332,8 +330,7 @@ CodeBlock build() {
332330
} else {
333331
if (queryMethod.isSliceQuery()) {
334332
builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class,
335-
context.localVariable("results"), mongoOpsRef,
336-
aggregationVariableName, outputType);
333+
context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
337334
builder.addStatement("boolean $L = $L.getMappedResults().size() > $L.getPageSize()",
338335
context.localVariable("hasNext"), context.localVariable("results"), context.getPageableParameterName());
339336
builder.addStatement(
@@ -378,12 +375,16 @@ CodeBlock build() {
378375

379376
boolean isProjecting = context.getReturnedType().isProjecting();
380377
Class<?> domainType = context.getRepositoryInformation().getDomainType();
381-
Object actualReturnType = isProjecting ? context.getActualReturnType().getType()
378+
Object actualReturnType = queryMethod.getParameters().hasDynamicProjection() || isProjecting
379+
? TypeName.get(context.getActualReturnType().getType())
382380
: domainType;
383381

384382
builder.add("\n");
385383

386-
if (isProjecting) {
384+
if (queryMethod.getParameters().hasDynamicProjection()) {
385+
builder.addStatement("$T<$T> $L = $L.query($T.class).as($L)", FindWithQuery.class, actualReturnType,
386+
context.localVariable("finder"), mongoOpsRef, domainType, context.getDynamicProjectionParameterName());
387+
} else if (isProjecting) {
387388
builder.addStatement("$T<$T> $L = $L.query($T.class).as($T.class)", FindWithQuery.class, actualReturnType,
388389
context.localVariable("finder"), mongoOpsRef, domainType, actualReturnType);
389390
} else {
@@ -400,6 +401,8 @@ CodeBlock build() {
400401
terminatingMethod = "count()";
401402
} else if (query.isExists()) {
402403
terminatingMethod = "exists()";
404+
} else if (queryMethod.isStreamQuery()) {
405+
terminatingMethod = "stream()";
403406
} else {
404407
terminatingMethod = Optional.class.isAssignableFrom(context.getReturnType().toClass()) ? "one()" : "oneValue()";
405408
}
@@ -410,6 +413,12 @@ CodeBlock build() {
410413
} else if (queryMethod.isSliceQuery()) {
411414
builder.addStatement("return new $T($L, $L).execute($L)", SlicedExecution.class,
412415
context.localVariable("finder"), context.getPageableParameterName(), query.name());
416+
} else if (queryMethod.isScrollQuery()) {
417+
418+
String scrollPositionParameterName = context.getScrollPositionParameterName();
419+
420+
builder.addStatement("return $L.matching($L).scroll($L)", context.localVariable("finder"), query.name(),
421+
scrollPositionParameterName);
413422
} else {
414423
builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(),
415424
terminatingMethod);
@@ -544,8 +553,7 @@ private CodeBlock aggregationOptions(String aggregationVariableName) {
544553

545554
Builder optionsBuilder = CodeBlock.builder();
546555
optionsBuilder.add("$T $L = $T.builder()\n", AggregationOptions.class,
547-
context.localVariable("aggregationOptions"),
548-
AggregationOptions.class);
556+
context.localVariable("aggregationOptions"), AggregationOptions.class);
549557
optionsBuilder.indent();
550558
for (CodeBlock optionBlock : options) {
551559
optionsBuilder.add(optionBlock);
@@ -709,7 +717,27 @@ CodeBlock build() {
709717
com.mongodb.ReadPreference.class, readPreference);
710718
}
711719

712-
// TODO: Meta annotation
720+
MergedAnnotation<Meta> metaAnnotation = context.getAnnotation(Meta.class);
721+
722+
if (metaAnnotation.isPresent()) {
723+
724+
long maxExecutionTimeMs = metaAnnotation.getLong("maxExecutionTimeMs");
725+
if (maxExecutionTimeMs != -1) {
726+
builder.addStatement("$L.maxTimeMsec($L)", queryVariableName, maxExecutionTimeMs);
727+
}
728+
729+
int cursorBatchSize = metaAnnotation.getInt("cursorBatchSize");
730+
if (cursorBatchSize != 0) {
731+
builder.addStatement("$L.cursorBatchSize($L)", queryVariableName, cursorBatchSize);
732+
}
733+
734+
String comment = metaAnnotation.getString("comment");
735+
if (StringUtils.hasText("comment")) {
736+
builder.addStatement("$L.comment($S)", queryVariableName, comment);
737+
}
738+
}
739+
740+
// TODO: Meta annotation: Disk usage
713741

714742
return builder.build();
715743
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java

+3-4
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
* MongoDB specific {@link RepositoryContributor}.
4949
*
5050
* @author Christoph Strobl
51+
* @author Mark Paluch
5152
* @since 5.0
5253
*/
5354
public class MongoRepositoryContributor extends RepositoryContributor {
@@ -159,8 +160,7 @@ private QueryInteraction createStringQuery(RepositoryInformation repositoryInfor
159160

160161
private static boolean backoff(MongoQueryMethod method) {
161162

162-
boolean skip = method.isGeoNearQuery() || method.isScrollQuery() || method.isStreamQuery()
163-
|| method.isSearchQuery();
163+
boolean skip = method.isGeoNearQuery() || method.isSearchQuery();
164164

165165
if (skip && logger.isDebugEnabled()) {
166166
logger.debug("Skipping AOT generation for [%s]. Method is either geo-near, streaming, search or scrolling query"
@@ -225,8 +225,7 @@ private static MethodContributor<MongoQueryMethod> aggregationUpdateMethodContri
225225
.usingAggregationVariableName(updateVariableName).pipelineOnly(true).build());
226226

227227
builder.addStatement("$T $L = $T.from($L.getOperations())", AggregationUpdate.class,
228-
context.localVariable("aggregationUpdate"),
229-
AggregationUpdate.class, updateVariableName);
228+
context.localVariable("aggregationUpdate"), AggregationUpdate.class, updateVariableName);
230229

231230
builder.add(updateExecutionBlockBuilder(context, queryMethod).withFilter(filterVariableName)
232231
.referencingUpdate(context.localVariable("aggregationUpdate")).build());

spring-data-mongodb/src/test/java/example/aot/UserRepository.java

+9-4
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,16 @@
2222
import java.util.Objects;
2323
import java.util.Optional;
2424
import java.util.Set;
25+
import java.util.stream.Stream;
2526

2627
import org.springframework.data.annotation.Id;
2728
import org.springframework.data.domain.Limit;
2829
import org.springframework.data.domain.Page;
2930
import org.springframework.data.domain.Pageable;
31+
import org.springframework.data.domain.ScrollPosition;
3032
import org.springframework.data.domain.Slice;
3133
import org.springframework.data.domain.Sort;
34+
import org.springframework.data.domain.Window;
3235
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
3336
import org.springframework.data.mongodb.repository.Aggregation;
3437
import org.springframework.data.mongodb.repository.Hint;
@@ -94,8 +97,10 @@ public interface UserRepository extends CrudRepository<User, String> {
9497

9598
Slice<User> findSliceOfUserByLastnameStartingWith(String lastname, Pageable page);
9699

97-
// TODO: Streaming
98-
// TODO: Scrolling
100+
Stream<User> streamByLastnameStartingWith(String lastname, Sort sort, Limit limit);
101+
102+
Window<User> findTop2WindowByLastnameStartingWithOrderByUsername(String lastname, ScrollPosition scrollPosition);
103+
99104
// TODO: GeoQueries
100105
// TODO: TextSearch
101106

@@ -176,14 +181,14 @@ public interface UserRepository extends CrudRepository<User, String> {
176181
@ReadPreference("no-such-read-preference")
177182
User findWithReadPreferenceByUsername(String username);
178183

179-
// TODO: hints
180-
181184
/* Projecting Queries */
182185

183186
List<UserProjection> findUserProjectionByLastnameStartingWith(String lastname);
184187

185188
Page<UserProjection> findUserProjectionByLastnameStartingWith(String lastname, Pageable page);
186189

190+
<T> Page<T> findUserProjectionByLastnameStartingWith(String lastname, Pageable page, Class<T> projectionType);
191+
187192
/* Aggregations */
188193

189194
@Aggregation(pipeline = { //

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/AotFragmentTestConfigurationSupport.java

+5-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
* This configuration generates the AOT repository, compiles sources and configures a BeanFactory to contain the AOT
4141
* fragment. Additionally, the fragment is exposed through a {@code repositoryInterface} JDK proxy forwarding method
4242
* invocations to the backing AOT fragment. Note that {@code repositoryInterface} is not a repository proxy.
43-
*
43+
*
4444
* @author Christoph Strobl
4545
*/
4646
public class AotFragmentTestConfigurationSupport implements BeanFactoryPostProcessor {
@@ -62,7 +62,8 @@ public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory)
6262
new MongoRepositoryContributor(repositoryContext).contribute(generationContext);
6363

6464
AbstractBeanDefinition aotGeneratedRepository = BeanDefinitionBuilder
65-
.genericBeanDefinition(repositoryInterface.getName() + "Impl__Aot") //
65+
.genericBeanDefinition(
66+
repositoryInterface.getPackageName() + "." + repositoryInterface.getSimpleName() + "Impl__Aot") //
6667
.addConstructorArgReference("mongoOperations") //
6768
.addConstructorArgValue(getCreationContext(repositoryContext)).getBeanDefinition();
6869

@@ -80,6 +81,8 @@ public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory)
8081
}).getBeanDefinition();
8182

8283
((BeanDefinitionRegistry) beanFactory).registerBeanDefinition("fragmentFacade", fragmentFacade);
84+
85+
beanFactory.registerSingleton("generationContext", generationContext);
8386
}
8487

8588
private Object getFragmentFacadeProxy(Object fragment) {

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java

+43
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,15 @@
3434
import org.springframework.beans.factory.annotation.Autowired;
3535
import org.springframework.context.annotation.Bean;
3636
import org.springframework.context.annotation.Configuration;
37+
import org.springframework.data.domain.KeysetScrollPosition;
3738
import org.springframework.data.domain.Limit;
39+
import org.springframework.data.domain.OffsetScrollPosition;
3840
import org.springframework.data.domain.Page;
3941
import org.springframework.data.domain.PageRequest;
42+
import org.springframework.data.domain.ScrollPosition;
4043
import org.springframework.data.domain.Slice;
4144
import org.springframework.data.domain.Sort;
45+
import org.springframework.data.domain.Window;
4246
import org.springframework.data.mongodb.core.MongoOperations;
4347
import org.springframework.data.mongodb.core.MongoTemplate;
4448
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
@@ -271,6 +275,37 @@ void testDerivedFinderReturningSlice() {
271275
assertThat(slice.getContent()).extracting(User::getUsername).containsExactly("han", "kylo");
272276
}
273277

278+
@Test
279+
void testDerivedQueryReturningStream() {
280+
281+
List<User> results = fragment.streamByLastnameStartingWith("S", Sort.by("username"), Limit.of(2)).toList();
282+
283+
assertThat(results).hasSize(2);
284+
assertThat(results).extracting(User::getUsername).containsExactly("han", "kylo");
285+
}
286+
287+
@Test
288+
void testDerivedQueryReturningWindowByOffset() {
289+
290+
Window<User> window1 = fragment.findTop2WindowByLastnameStartingWithOrderByUsername("S", ScrollPosition.offset());
291+
assertThat(window1).extracting(User::getUsername).containsExactly("han", "kylo");
292+
assertThat(window1.positionAt(1)).isInstanceOf(OffsetScrollPosition.class);
293+
294+
Window<User> window2 = fragment.findTop2WindowByLastnameStartingWithOrderByUsername("S", window1.positionAt(1));
295+
assertThat(window2).extracting(User::getUsername).containsExactly("luke", "vader");
296+
}
297+
298+
@Test
299+
void testDerivedQueryReturningWindowByKeyset() {
300+
301+
Window<User> window1 = fragment.findTop2WindowByLastnameStartingWithOrderByUsername("S", ScrollPosition.keyset());
302+
assertThat(window1).extracting(User::getUsername).containsExactly("han", "kylo");
303+
assertThat(window1.positionAt(1)).isInstanceOf(KeysetScrollPosition.class);
304+
305+
Window<User> window2 = fragment.findTop2WindowByLastnameStartingWithOrderByUsername("S", window1.positionAt(1));
306+
assertThat(window2).extracting(User::getUsername).containsExactly("luke", "vader");
307+
}
308+
274309
@Test
275310
void testAnnotatedFinderReturningSingleValueWithQuery() {
276311

@@ -439,6 +474,14 @@ void testDerivedFinderReturningPageOfProjections() {
439474
assertThat(users).extracting(UserProjection::getUsername).containsExactly("han", "kylo");
440475
}
441476

477+
@Test
478+
void testDerivedFinderReturningPageOfDynamicProjections() {
479+
480+
Page<UserProjection> users = fragment.findUserProjectionByLastnameStartingWith("S",
481+
PageRequest.of(0, 2, Sort.by("username")), UserProjection.class);
482+
assertThat(users).extracting(UserProjection::getUsername).containsExactly("han", "kylo");
483+
}
484+
442485
@Test
443486
void testUpdateWithDerivedQuery() {
444487

0 commit comments

Comments
 (0)