19
19
import java .util .List ;
20
20
import java .util .Optional ;
21
21
import java .util .regex .Pattern ;
22
+ import java .util .stream .Stream ;
22
23
23
24
import org .bson .Document ;
24
25
import org .jspecify .annotations .NullUnmarked ;
49
50
import org .springframework .data .mongodb .repository .query .MongoQueryMethod ;
50
51
import org .springframework .data .repository .aot .generate .AotQueryMethodGenerationContext ;
51
52
import org .springframework .data .util .ReflectionUtils ;
52
- import org .springframework .javapoet .ClassName ;
53
53
import org .springframework .javapoet .CodeBlock ;
54
54
import org .springframework .javapoet .CodeBlock .Builder ;
55
55
import org .springframework .javapoet .TypeName ;
@@ -182,17 +182,15 @@ CodeBlock build() {
182
182
String mongoOpsRef = context .fieldNameOf (MongoOperations .class );
183
183
Builder builder = CodeBlock .builder ();
184
184
185
+ Class <?> domainType = context .getRepositoryInformation ().getDomainType ();
185
186
boolean isProjecting = context .getActualReturnType () != null
186
- && !ObjectUtils .nullSafeEquals (TypeName .get (context .getRepositoryInformation ().getDomainType ()),
187
- context .getActualReturnType ());
187
+ && !ObjectUtils .nullSafeEquals (TypeName .get (domainType ), context .getActualReturnType ());
188
188
189
- Object actualReturnType = isProjecting ? context .getActualReturnType ().getType ()
190
- : context .getRepositoryInformation ().getDomainType ();
189
+ Object actualReturnType = isProjecting ? context .getActualReturnType ().getType () : domainType ;
191
190
192
191
builder .add ("\n " );
193
- builder .addStatement ("$T<$T> remover = $L.remove($T.class)" , ExecutableRemove .class ,
194
- context .getRepositoryInformation ().getDomainType (), mongoOpsRef ,
195
- context .getRepositoryInformation ().getDomainType ());
192
+ builder .addStatement ("$T<$T> $L = $L.remove($T.class)" , ExecutableRemove .class , domainType ,
193
+ context .localVariable ("remover" ), mongoOpsRef , domainType );
196
194
197
195
DeleteExecution .Type type = DeleteExecution .Type .FIND_AND_REMOVE_ALL ;
198
196
if (!queryMethod .isCollectionQuery ()) {
@@ -204,11 +202,20 @@ CodeBlock build() {
204
202
}
205
203
206
204
actualReturnType = ClassUtils .isPrimitiveOrWrapper (context .getMethod ().getReturnType ())
207
- ? ClassName .get (context .getMethod ().getReturnType ())
205
+ ? TypeName .get (context .getMethod ().getReturnType ())
208
206
: queryMethod .isCollectionQuery () ? context .getReturnTypeName () : actualReturnType ;
209
207
210
- builder .addStatement ("return ($T) new $T(remover, $T.$L).execute($L)" , actualReturnType , DeleteExecution .class ,
211
- DeleteExecution .Type .class , type .name (), queryVariableName );
208
+ if (ClassUtils .isVoidType (context .getMethod ().getReturnType ())) {
209
+ builder .addStatement ("new $T($L, $T.$L).execute($L)" , DeleteExecution .class , context .localVariable ("remover" ),
210
+ DeleteExecution .Type .class , type .name (), queryVariableName );
211
+ } else if (context .getMethod ().getReturnType () == Optional .class ) {
212
+ builder .addStatement ("return $T.ofNullable(($T) new $T($L, $T.$L).execute($L))" , Optional .class ,
213
+ actualReturnType , DeleteExecution .class , context .localVariable ("remover" ), DeleteExecution .Type .class ,
214
+ type .name (), queryVariableName );
215
+ } else {
216
+ builder .addStatement ("return ($T) new $T($L, $T.$L).execute($L)" , actualReturnType , DeleteExecution .class ,
217
+ context .localVariable ("remover" ), DeleteExecution .Type .class , type .name (), queryVariableName );
218
+ }
212
219
213
220
return builder .build ();
214
221
}
@@ -318,14 +325,25 @@ CodeBlock build() {
318
325
319
326
Class <?> returnType = ClassUtils .resolvePrimitiveIfNecessary (queryMethod .getReturnedObjectType ());
320
327
321
- builder .addStatement ("$T $L = $L.aggregate($L, $T.class)" , AggregationResults .class ,
322
- context .localVariable ("results" ), mongoOpsRef , aggregationVariableName , outputType );
323
- if (!queryMethod .isCollectionQuery ()) {
324
- builder .addStatement ("return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))" ,
325
- CollectionUtils .class , returnType , returnType , context .localVariable ("results" ));
328
+ if (queryMethod .isStreamQuery ()) {
329
+
330
+ builder .addStatement ("$T<$T> $L = $L.aggregateStream($L, $T.class)" , Stream .class , Document .class ,
331
+ context .localVariable ("results" ), mongoOpsRef , aggregationVariableName , outputType );
332
+
333
+ builder .addStatement ("return $L.map(it -> ($T) convertSimpleRawResult($T.class, it))" ,
334
+ context .localVariable ("results" ), returnType , returnType );
326
335
} else {
327
- builder .addStatement ("return convertSimpleRawResults($T.class, $L.getMappedResults())" , returnType ,
328
- context .localVariable ("results" ));
336
+
337
+ builder .addStatement ("$T $L = $L.aggregate($L, $T.class)" , AggregationResults .class ,
338
+ context .localVariable ("results" ), mongoOpsRef , aggregationVariableName , outputType );
339
+
340
+ if (!queryMethod .isCollectionQuery ()) {
341
+ builder .addStatement ("return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))" ,
342
+ CollectionUtils .class , returnType , returnType , context .localVariable ("results" ));
343
+ } else {
344
+ builder .addStatement ("return convertSimpleRawResults($T.class, $L.getMappedResults())" , returnType ,
345
+ context .localVariable ("results" ));
346
+ }
329
347
}
330
348
} else {
331
349
if (queryMethod .isSliceQuery ()) {
@@ -339,8 +357,15 @@ CodeBlock build() {
339
357
context .getPageableParameterName (), context .localVariable ("results" ), context .getPageableParameterName (),
340
358
context .localVariable ("hasNext" ));
341
359
} else {
342
- builder .addStatement ("return $L.aggregate($L, $T.class).getMappedResults()" , mongoOpsRef ,
343
- aggregationVariableName , outputType );
360
+
361
+ if (queryMethod .isStreamQuery ()) {
362
+ builder .addStatement ("return $L.aggregateStream($L, $T.class)" , mongoOpsRef , aggregationVariableName ,
363
+ outputType );
364
+ } else {
365
+
366
+ builder .addStatement ("return $L.aggregate($L, $T.class).getMappedResults()" , mongoOpsRef ,
367
+ aggregationVariableName , outputType );
368
+ }
344
369
}
345
370
}
346
371
@@ -420,8 +445,16 @@ CodeBlock build() {
420
445
builder .addStatement ("return $L.matching($L).scroll($L)" , context .localVariable ("finder" ), query .name (),
421
446
scrollPositionParameterName );
422
447
} else {
423
- builder .addStatement ("return $L.matching($L).$L" , context .localVariable ("finder" ), query .name (),
424
- terminatingMethod );
448
+ if (query .isCount () && !ClassUtils .isAssignable (Long .class , context .getActualReturnType ().getRawClass ())) {
449
+
450
+ Class <?> returnType = ClassUtils .resolvePrimitiveIfNecessary (queryMethod .getReturnedObjectType ());
451
+ builder .addStatement ("return $T.convertNumberToTargetClass($L.matching($L).$L, $T.class)" , NumberUtils .class ,
452
+ context .localVariable ("finder" ), query .name (), terminatingMethod , returnType );
453
+
454
+ } else {
455
+ builder .addStatement ("return $L.matching($L).$L" , context .localVariable ("finder" ), query .name (),
456
+ terminatingMethod );
457
+ }
425
458
}
426
459
427
460
return builder .build ();
0 commit comments