Skip to content

Commit e13043e

Browse files
mp911dechristophstrobl
authored andcommitted
Fix aggregation streams, count result conversion.
See: #4939 Original Pull Request: #4970
1 parent ad68889 commit e13043e

File tree

7 files changed

+139
-43
lines changed

7 files changed

+139
-43
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java

+4
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ protected List<Object> convertSimpleRawResults(Class<?> targetType, List<Documen
103103
return list;
104104
}
105105

106+
protected Object convertSimpleRawResult(Class<?> targetType, Document rawResult) {
107+
return extractSimpleTypeResult(rawResult, targetType, mongoConverter);
108+
}
109+
106110
private static <T> @Nullable T extractSimpleTypeResult(@Nullable Document source, Class<T> targetType,
107111
MongoConverter converter) {
108112

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java

+55-22
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.List;
2020
import java.util.Optional;
2121
import java.util.regex.Pattern;
22+
import java.util.stream.Stream;
2223

2324
import org.bson.Document;
2425
import org.jspecify.annotations.NullUnmarked;
@@ -49,7 +50,6 @@
4950
import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
5051
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
5152
import org.springframework.data.util.ReflectionUtils;
52-
import org.springframework.javapoet.ClassName;
5353
import org.springframework.javapoet.CodeBlock;
5454
import org.springframework.javapoet.CodeBlock.Builder;
5555
import org.springframework.javapoet.TypeName;
@@ -182,17 +182,15 @@ CodeBlock build() {
182182
String mongoOpsRef = context.fieldNameOf(MongoOperations.class);
183183
Builder builder = CodeBlock.builder();
184184

185+
Class<?> domainType = context.getRepositoryInformation().getDomainType();
185186
boolean isProjecting = context.getActualReturnType() != null
186-
&& !ObjectUtils.nullSafeEquals(TypeName.get(context.getRepositoryInformation().getDomainType()),
187-
context.getActualReturnType());
187+
&& !ObjectUtils.nullSafeEquals(TypeName.get(domainType), context.getActualReturnType());
188188

189-
Object actualReturnType = isProjecting ? context.getActualReturnType().getType()
190-
: context.getRepositoryInformation().getDomainType();
189+
Object actualReturnType = isProjecting ? context.getActualReturnType().getType() : domainType;
191190

192191
builder.add("\n");
193-
builder.addStatement("$T<$T> remover = $L.remove($T.class)", ExecutableRemove.class,
194-
context.getRepositoryInformation().getDomainType(), mongoOpsRef,
195-
context.getRepositoryInformation().getDomainType());
192+
builder.addStatement("$T<$T> $L = $L.remove($T.class)", ExecutableRemove.class, domainType,
193+
context.localVariable("remover"), mongoOpsRef, domainType);
196194

197195
DeleteExecution.Type type = DeleteExecution.Type.FIND_AND_REMOVE_ALL;
198196
if (!queryMethod.isCollectionQuery()) {
@@ -204,11 +202,20 @@ CodeBlock build() {
204202
}
205203

206204
actualReturnType = ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType())
207-
? ClassName.get(context.getMethod().getReturnType())
205+
? TypeName.get(context.getMethod().getReturnType())
208206
: queryMethod.isCollectionQuery() ? context.getReturnTypeName() : actualReturnType;
209207

210-
builder.addStatement("return ($T) new $T(remover, $T.$L).execute($L)", actualReturnType, DeleteExecution.class,
211-
DeleteExecution.Type.class, type.name(), queryVariableName);
208+
if (ClassUtils.isVoidType(context.getMethod().getReturnType())) {
209+
builder.addStatement("new $T($L, $T.$L).execute($L)", DeleteExecution.class, context.localVariable("remover"),
210+
DeleteExecution.Type.class, type.name(), queryVariableName);
211+
} else if (context.getMethod().getReturnType() == Optional.class) {
212+
builder.addStatement("return $T.ofNullable(($T) new $T($L, $T.$L).execute($L))", Optional.class,
213+
actualReturnType, DeleteExecution.class, context.localVariable("remover"), DeleteExecution.Type.class,
214+
type.name(), queryVariableName);
215+
} else {
216+
builder.addStatement("return ($T) new $T($L, $T.$L).execute($L)", actualReturnType, DeleteExecution.class,
217+
context.localVariable("remover"), DeleteExecution.Type.class, type.name(), queryVariableName);
218+
}
212219

213220
return builder.build();
214221
}
@@ -318,14 +325,25 @@ CodeBlock build() {
318325

319326
Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType());
320327

321-
builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class,
322-
context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
323-
if (!queryMethod.isCollectionQuery()) {
324-
builder.addStatement("return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))",
325-
CollectionUtils.class, returnType, returnType, context.localVariable("results"));
328+
if (queryMethod.isStreamQuery()) {
329+
330+
builder.addStatement("$T<$T> $L = $L.aggregateStream($L, $T.class)", Stream.class, Document.class,
331+
context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
332+
333+
builder.addStatement("return $L.map(it -> ($T) convertSimpleRawResult($T.class, it))",
334+
context.localVariable("results"), returnType, returnType);
326335
} else {
327-
builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType,
328-
context.localVariable("results"));
336+
337+
builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class,
338+
context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
339+
340+
if (!queryMethod.isCollectionQuery()) {
341+
builder.addStatement("return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))",
342+
CollectionUtils.class, returnType, returnType, context.localVariable("results"));
343+
} else {
344+
builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType,
345+
context.localVariable("results"));
346+
}
329347
}
330348
} else {
331349
if (queryMethod.isSliceQuery()) {
@@ -339,8 +357,15 @@ CodeBlock build() {
339357
context.getPageableParameterName(), context.localVariable("results"), context.getPageableParameterName(),
340358
context.localVariable("hasNext"));
341359
} else {
342-
builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
343-
aggregationVariableName, outputType);
360+
361+
if (queryMethod.isStreamQuery()) {
362+
builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName,
363+
outputType);
364+
} else {
365+
366+
builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
367+
aggregationVariableName, outputType);
368+
}
344369
}
345370
}
346371

@@ -420,8 +445,16 @@ CodeBlock build() {
420445
builder.addStatement("return $L.matching($L).scroll($L)", context.localVariable("finder"), query.name(),
421446
scrollPositionParameterName);
422447
} else {
423-
builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(),
424-
terminatingMethod);
448+
if (query.isCount() && !ClassUtils.isAssignable(Long.class, context.getActualReturnType().getRawClass())) {
449+
450+
Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType());
451+
builder.addStatement("return $T.convertNumberToTargetClass($L.matching($L).$L, $T.class)", NumberUtils.class,
452+
context.localVariable("finder"), query.name(), terminatingMethod, returnType);
453+
454+
} else {
455+
builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(),
456+
terminatingMethod);
457+
}
425458
}
426459

427460
return builder.build();

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java

+30-12
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.*;
1919

2020
import java.lang.reflect.Method;
21+
import java.util.Locale;
2122
import java.util.regex.Pattern;
2223

2324
import org.apache.commons.logging.Log;
@@ -119,14 +120,23 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB
119120

120121
if (queryMethod.isModifyingQuery()) {
121122

122-
Update updateSource = queryMethod.getUpdateSource();
123-
if (StringUtils.hasText(updateSource.value())) {
124-
UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value()));
123+
int updateIndex = queryMethod.getParameters().getUpdateIndex();
124+
if (updateIndex != -1) {
125+
126+
UpdateInteraction update = new UpdateInteraction(query, null, updateIndex);
125127
return updateMethodContributor(queryMethod, update);
126-
}
127-
if (!ObjectUtils.isEmpty(updateSource.pipeline())) {
128-
AggregationUpdateInteraction update = new AggregationUpdateInteraction(query, updateSource.pipeline());
129-
return aggregationUpdateMethodContributor(queryMethod, update);
128+
129+
} else {
130+
Update updateSource = queryMethod.getUpdateSource();
131+
if (StringUtils.hasText(updateSource.value())) {
132+
UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value()), null);
133+
return updateMethodContributor(queryMethod, update);
134+
}
135+
136+
if (!ObjectUtils.isEmpty(updateSource.pipeline())) {
137+
AggregationUpdateInteraction update = new AggregationUpdateInteraction(query, updateSource.pipeline());
138+
return aggregationUpdateMethodContributor(queryMethod, update);
139+
}
130140
}
131141
}
132142

@@ -160,10 +170,12 @@ private QueryInteraction createStringQuery(RepositoryInformation repositoryInfor
160170

161171
private static boolean backoff(MongoQueryMethod method) {
162172

163-
boolean skip = method.isGeoNearQuery() || method.isSearchQuery();
173+
// TODO: namedQuery, Regex queries, queries accepting Shapes (e.g. within) or returning arrays.
174+
boolean skip = method.isGeoNearQuery() || method.isSearchQuery()
175+
|| method.getName().toLowerCase(Locale.ROOT).contains("regex") || method.getReturnType().getType().isArray();
164176

165177
if (skip && logger.isDebugEnabled()) {
166-
logger.debug("Skipping AOT generation for [%s]. Method is either geo-near, streaming, search or scrolling query"
178+
logger.debug("Skipping AOT generation for [%s]. Method is either returning an array or a geo-near, regex query"
167179
.formatted(method.getName()));
168180
}
169181
return skip;
@@ -197,9 +209,15 @@ private static MethodContributor<MongoQueryMethod> updateMethodContributor(Mongo
197209
.usingQueryVariableName(filterVariableName).build());
198210

199211
// update definition
200-
String updateVariableName = context.localVariable("updateDefinition");
201-
builder.add(
202-
updateBlockBuilder(context, queryMethod).update(update).usingUpdateVariableName(updateVariableName).build());
212+
String updateVariableName;
213+
214+
if (update.hasUpdateDefinitionParameter()) {
215+
updateVariableName = context.getParameterName(update.getRequiredUpdateDefinitionParameter());
216+
} else {
217+
updateVariableName = context.localVariable("updateDefinition");
218+
builder.add(updateBlockBuilder(context, queryMethod).update(update).usingUpdateVariableName(updateVariableName)
219+
.build());
220+
}
203221

204222
builder.add(updateExecutionBlockBuilder(context, queryMethod).withFilter(filterVariableName)
205223
.referencingUpdate(updateVariableName).build());

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateInteraction.java

+29-5
Original file line numberDiff line numberDiff line change
@@ -17,42 +17,66 @@
1717

1818
import java.util.Map;
1919

20+
import org.jspecify.annotations.Nullable;
21+
2022
import org.springframework.data.repository.aot.generate.QueryMetadata;
23+
import org.springframework.util.Assert;
2124

2225
/**
2326
* An {@link MongoInteraction} to execute an update.
2427
*
2528
* @author Christoph Strobl
29+
* @author Mark Paluch
2630
* @since 5.0
2731
*/
2832
class UpdateInteraction extends MongoInteraction implements QueryMetadata {
2933

3034
private final QueryInteraction filter;
31-
private final StringUpdate update;
35+
private final @Nullable StringUpdate update;
36+
private final @Nullable Integer updateDefinitionParameter;
3237

33-
UpdateInteraction(QueryInteraction filter, StringUpdate update) {
38+
UpdateInteraction(QueryInteraction filter, @Nullable StringUpdate update,
39+
@Nullable Integer updateDefinitionParameter) {
3440
this.filter = filter;
3541
this.update = update;
42+
this.updateDefinitionParameter = updateDefinitionParameter;
3643
}
3744

38-
QueryInteraction getFilter() {
45+
public QueryInteraction getFilter() {
3946
return filter;
4047
}
4148

42-
StringUpdate getUpdate() {
49+
public @Nullable StringUpdate getUpdate() {
4350
return update;
4451
}
4552

53+
public int getRequiredUpdateDefinitionParameter() {
54+
55+
Assert.notNull(updateDefinitionParameter, "UpdateDefinitionParameter must not be null!");
56+
57+
return updateDefinitionParameter;
58+
}
59+
60+
public boolean hasUpdateDefinitionParameter() {
61+
return updateDefinitionParameter != null;
62+
}
63+
4664
@Override
4765
public Map<String, Object> serialize() {
4866

4967
Map<String, Object> serialized = filter.serialize();
50-
serialized.put("update", update.getUpdateString());
68+
69+
if (update != null) {
70+
serialized.put("filter", filter.getQuery().getQueryString());
71+
serialized.put("update", update.getUpdateString());
72+
}
73+
5174
return serialized;
5275
}
5376

5477
@Override
5578
InteractionType getExecutionType() {
5679
return InteractionType.UPDATE;
5780
}
81+
5882
}

spring-data-mongodb/src/test/java/example/aot/UserRepository.java

+7
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ public interface UserRepository extends CrudRepository<User, String> {
5555

5656
Long countUsersByLastname(String lastname);
5757

58+
int countUsersAsIntByLastname(String lastname);
59+
5860
Boolean existsUserByLastname(String lastname);
5961

6062
List<User> findByLastnameStartingWith(String lastname);
@@ -216,6 +218,11 @@ public interface UserRepository extends CrudRepository<User, String> {
216218
"{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" })
217219
AggregationResults<UserAggregate> groupByLastnameAndAsAggregationResults(String property);
218220

221+
@Aggregation(pipeline = { //
222+
"{ '$match' : { 'last_name' : { '$ne' : null } } }", //
223+
"{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" })
224+
Stream<UserAggregate> streamGroupByLastnameAndAsAggregationResults(String property);
225+
219226
@Aggregation(pipeline = { //
220227
"{ '$match' : { 'posts' : { '$ne' : null } } }", //
221228
"{ '$project': { 'nrPosts' : {'$size': '$posts' } } }", //

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java

+12-2
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ void testFindDerivedFinderOptionalEntity() {
107107
@Test
108108
void testDerivedCount() {
109109

110-
Long value = fragment.countUsersByLastname("Skywalker");
111-
assertThat(value).isEqualTo(2L);
110+
assertThat(fragment.countUsersByLastname("Skywalker")).isEqualTo(2L);
111+
assertThat(fragment.countUsersAsIntByLastname("Skywalker")).isEqualTo(2);
112112
}
113113

114114
@Test
@@ -559,6 +559,16 @@ void testAggregationWithProjectedResultsWrappedInAggregationResults() {
559559
new UserAggregate("Solo", List.of("Han", "Ben")));
560560
}
561561

562+
@Test
563+
void testAggregationStreamWithProjectedResultsWrappedInAggregationResults() {
564+
565+
List<UserAggregate> allLastnames = fragment.streamGroupByLastnameAndAsAggregationResults("first_name").toList();
566+
assertThat(allLastnames).containsExactlyInAnyOrder(//
567+
new UserAggregate("Skywalker", List.of("Anakin", "Luke")), //
568+
new UserAggregate("Organa", List.of("Leia")), //
569+
new UserAggregate("Solo", List.of("Han", "Ben")));
570+
}
571+
562572
@Test
563573
void testAggregationWithSingleResultExtraction() {
564574
assertThat(fragment.sumPosts()).isEqualTo(5);

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,12 @@
3737
/**
3838
* @author Christoph Strobl
3939
*/
40-
class TestMongoAotRepositoryContext implements AotRepositoryContext {
40+
public class TestMongoAotRepositoryContext implements AotRepositoryContext {
4141

4242
private final StubRepositoryInformation repositoryInformation;
4343
private final Environment environment = new StandardEnvironment();
4444

45-
TestMongoAotRepositoryContext(Class<?> repositoryInterface, @Nullable RepositoryComposition composition) {
45+
public TestMongoAotRepositoryContext(Class<?> repositoryInterface, @Nullable RepositoryComposition composition) {
4646
this.repositoryInformation = new StubRepositoryInformation(repositoryInterface, composition);
4747
}
4848

0 commit comments

Comments
 (0)