Skip to content

Commit 7fb9b2a

Browse files
hacking
1 parent cc7449d commit 7fb9b2a

12 files changed

+362
-206
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java

+12-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
import java.util.function.Predicate;
2323

2424
import org.bson.Document;
25+
import org.jspecify.annotations.Nullable;
2526
import org.springframework.lang.Contract;
2627
import org.springframework.util.Assert;
28+
import org.springframework.util.CollectionUtils;
2729

2830
/**
2931
* The {@link AggregationPipeline} holds the collection of {@link AggregationOperation aggregation stages}.
@@ -82,6 +84,14 @@ public List<AggregationOperation> getOperations() {
8284
return Collections.unmodifiableList(pipeline);
8385
}
8486

87+
public @Nullable AggregationOperation firstOperation() {
88+
return CollectionUtils.firstElement(pipeline);
89+
}
90+
91+
public @Nullable AggregationOperation lastOperation() {
92+
return CollectionUtils.lastElement(pipeline);
93+
}
94+
8595
List<Document> toDocuments(AggregationOperationContext context) {
8696

8797
verify();
@@ -97,8 +107,8 @@ public boolean isOutOrMerge() {
97107
return false;
98108
}
99109

100-
AggregationOperation operation = pipeline.get(pipeline.size() - 1);
101-
return isOut(operation) || isMerge(operation);
110+
AggregationOperation operation = lastOperation();
111+
return operation != null && (isOut(operation) || isMerge(operation));
102112
}
103113

104114
void verify() {

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,12 @@ public Sort getSort() {
104104
}
105105

106106
@Override
107-
public @org.jspecify.annotations.Nullable Score getScore() {
107+
public @Nullable Score getScore() {
108108
return delegate.getScore();
109109
}
110110

111111
@Override
112-
public @org.jspecify.annotations.Nullable Range<Score> getScoreRange() {
112+
public @Nullable Range<Score> getScoreRange() {
113113
return delegate.getScoreRange();
114114
}
115115

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java

+4-5
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,13 @@ public MongoParametersParameterAccessor(MongoQueryMethod method, Object[] values
6161
public Range<Score> getScoreRange() {
6262

6363
MongoParameters mongoParameters = method.getParameters();
64-
int rangeIndex = mongoParameters.getScoreRangeIndex();
6564

66-
if (rangeIndex != -1) {
67-
return getValue(rangeIndex);
65+
if (mongoParameters.hasScoreRangeParameter()) {
66+
return getValue(mongoParameters.getScoreRangeIndex());
6867
}
6968

70-
int scoreIndex = mongoParameters.getScoreIndex();
71-
Bound<Score> maxDistance = scoreIndex == -1 ? Bound.unbounded() : Bound.inclusive((Score) getScore());
69+
Score score = getScore();
70+
Bound<Score> maxDistance = score != null ? Bound.inclusive(score) : Bound.unbounded();
7271

7372
return Range.of(Bound.unbounded(), maxDistance);
7473
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java

+27-22
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
*/
1616
package org.springframework.data.mongodb.repository.query;
1717

18-
import static org.springframework.data.mongodb.core.query.Criteria.*;
18+
import static org.springframework.data.mongodb.core.query.Criteria.Placeholder;
19+
import static org.springframework.data.mongodb.core.query.Criteria.where;
1920

2021
import java.util.Arrays;
2122
import java.util.Collection;
@@ -27,7 +28,6 @@
2728
import org.apache.commons.logging.LogFactory;
2829
import org.bson.BsonRegularExpression;
2930
import org.jspecify.annotations.Nullable;
30-
3131
import org.springframework.data.domain.Range;
3232
import org.springframework.data.domain.Range.Bound;
3333
import org.springframework.data.domain.Sort;
@@ -118,8 +118,9 @@ protected Criteria create(Part part, Iterator<Object> iterator) {
118118
return new Criteria();
119119
}
120120

121-
if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) {
122-
return null;
121+
if (isPartOfSearchQuery(part)) {
122+
skip(part, iterator);
123+
return new Criteria();
123124
}
124125

125126
PersistentPropertyPath<MongoPersistentProperty> path = context.getPersistentPropertyPath(part.getProperty());
@@ -135,7 +136,8 @@ protected Criteria and(Part part, Criteria base, Iterator<Object> iterator) {
135136
return create(part, iterator);
136137
}
137138

138-
if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) {
139+
if (isPartOfSearchQuery(part)) {
140+
skip(part, iterator);
139141
return base;
140142
}
141143

@@ -176,15 +178,6 @@ protected Query complete(@Nullable Criteria criteria, Sort sort) {
176178
@SuppressWarnings("NullAway")
177179
private Criteria from(Part part, MongoPersistentProperty property, Criteria criteria, Iterator<Object> parameters) {
178180

179-
if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) {
180-
181-
int numberOfArguments = part.getType().getNumberOfArguments();
182-
for (int i = 0; i < numberOfArguments; i++) {
183-
parameters.next();
184-
}
185-
return null;
186-
}
187-
188181
Type type = part.getType();
189182

190183
switch (type) {
@@ -206,13 +199,13 @@ private Criteria from(Part part, MongoPersistentProperty property, Criteria crit
206199
return criteria.is(null);
207200
case NOT_IN:
208201
Object ninValue = parameters.next();
209-
if(ninValue instanceof Placeholder) {
202+
if (ninValue instanceof Placeholder) {
210203
return criteria.raw("$nin", ninValue);
211204
}
212205
return criteria.nin(valueAsList(ninValue, part));
213206
case IN:
214207
Object inValue = parameters.next();
215-
if(inValue instanceof Placeholder) {
208+
if (inValue instanceof Placeholder) {
216209
return criteria.raw("$in", inValue);
217210
}
218211
return criteria.in(valueAsList(inValue, part));
@@ -231,7 +224,7 @@ private Criteria from(Part part, MongoPersistentProperty property, Criteria crit
231224
return param instanceof Pattern pattern ? criteria.regex(pattern) : criteria.regex(param.toString());
232225
case EXISTS:
233226
Object next = parameters.next();
234-
if(next instanceof Placeholder placeholder) {
227+
if (next instanceof Placeholder placeholder) {
235228
return criteria.raw("$exists", placeholder);
236229
} else {
237230
return criteria.exists((Boolean) next);
@@ -355,7 +348,7 @@ private Criteria createContainingCriteria(Part part, MongoPersistentProperty pro
355348

356349
if (property.isCollectionLike()) {
357350
Object next = parameters.next();
358-
if(next instanceof Placeholder) {
351+
if (next instanceof Placeholder) {
359352
return criteria.raw("$in", next);
360353
}
361354
return criteria.in(valueAsList(next, part));
@@ -433,8 +426,7 @@ private java.util.List<?> valueAsList(Object value, Part part) {
433426
streamable = streamable.map(it -> {
434427
if (it instanceof String sv) {
435428

436-
return new BsonRegularExpression(MongoRegexCreator.INSTANCE.toRegularExpression(sv, matchMode),
437-
regexOptions);
429+
return new BsonRegularExpression(MongoRegexCreator.INSTANCE.toRegularExpression(sv, matchMode), regexOptions);
438430
}
439431
return it;
440432
});
@@ -468,10 +460,23 @@ private boolean isSpherical(MongoPersistentProperty property) {
468460
return false;
469461
}
470462

463+
private boolean isPartOfSearchQuery(Part part) {
464+
return isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN));
465+
}
466+
467+
private static void skip(Part part, Iterator<?> parameters) {
468+
469+
int total = part.getNumberOfArguments();
470+
int i = 0;
471+
while (parameters.hasNext() && i < total) {
472+
parameters.next();
473+
i++;
474+
}
475+
}
476+
471477
/**
472478
* Compute a {@link Type#BETWEEN} typed {@link Part} using {@link Criteria#gt(Object) $gt},
473-
* {@link Criteria#gte(Object) $gte}, {@link Criteria#lt(Object) $lt} and {@link Criteria#lte(Object) $lte}.
474-
* <br />
479+
* {@link Criteria#gte(Object) $gte}, {@link Criteria#lt(Object) $lt} and {@link Criteria#lte(Object) $lte}. <br />
475480
* In case the first {@literal value} is actually a {@link Range} the lower and upper bounds of the {@link Range} are
476481
* used according to their {@link Bound#isInclusive() inclusion} definition. Otherwise the {@literal value} is used
477482
* for {@literal $gt} and {@link Iterator#next() parameters.next()} as {@literal $lt}.

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java

+34-25
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323

2424
import org.bson.Document;
2525
import org.jspecify.annotations.Nullable;
26-
2726
import org.springframework.data.domain.Page;
2827
import org.springframework.data.domain.Pageable;
2928
import org.springframework.data.domain.Range;
29+
import org.springframework.data.domain.ScoringFunction;
3030
import org.springframework.data.domain.SearchResult;
3131
import org.springframework.data.domain.SearchResults;
3232
import org.springframework.data.domain.Similarity;
@@ -45,12 +45,13 @@
4545
import org.springframework.data.mongodb.core.ExecutableRemoveOperation.TerminatingRemove;
4646
import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate;
4747
import org.springframework.data.mongodb.core.MongoOperations;
48-
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
48+
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
4949
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
5050
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
5151
import org.springframework.data.mongodb.core.query.NearQuery;
5252
import org.springframework.data.mongodb.core.query.Query;
5353
import org.springframework.data.mongodb.core.query.UpdateDefinition;
54+
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
5455
import org.springframework.data.mongodb.repository.util.SliceUtils;
5556
import org.springframework.data.repository.query.QueryMethod;
5657
import org.springframework.data.support.PageableExecutionUtils;
@@ -186,7 +187,7 @@ public Object execute(Query query) {
186187
return isListOfGeoResult(method.getReturnType()) ? results.getContent() : results;
187188
}
188189

189-
@SuppressWarnings({"unchecked","NullAway"})
190+
@SuppressWarnings({ "unchecked", "NullAway" })
190191
GeoResults<Object> doExecuteQuery(Query query) {
191192

192193
Point nearLocation = accessor.getGeoNearLocation();
@@ -225,52 +226,60 @@ private static boolean isListOfGeoResult(TypeInformation<?> returnType) {
225226
* {@link MongoQueryExecution} to execute vector search.
226227
*
227228
* @author Mark Paluch
229+
* @author Chistoph Strobl
228230
* @since 5.0
229231
*/
230232
class VectorSearchExecution implements MongoQueryExecution {
231233

232234
private final MongoOperations operations;
233-
private final MongoQueryMethod method;
235+
private final TypeInformation<?> returnType;
234236
private final String collectionName;
235-
private final VectorSearchDelegate.QueryMetadata queryMetadata;
236-
private final List<AggregationOperation> pipeline;
237+
private final Class<?> targetType;
238+
private final ScoringFunction scoringFunction;
239+
private final AggregationPipeline pipeline;
240+
241+
VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName,
242+
QueryContainer queryContainer) {
243+
this(operations, queryContainer.outputType(), collectionName, method.getReturnType(), queryContainer.pipeline(),
244+
queryContainer.scoringFunction());
245+
}
237246

238-
public VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName,
239-
VectorSearchDelegate.QueryMetadata queryMetadata, MongoParameterAccessor accessor) {
247+
public VectorSearchExecution(MongoOperations operations, Class<?> targetType, String collectionName,
248+
TypeInformation<?> returnType, AggregationPipeline pipeline, ScoringFunction scoringFunction) {
240249

241250
this.operations = operations;
251+
this.returnType = returnType;
242252
this.collectionName = collectionName;
243-
this.queryMetadata = queryMetadata;
244-
this.method = method;
245-
this.pipeline = queryMetadata.getAggregationPipeline(method, accessor);
253+
this.targetType = targetType;
254+
this.scoringFunction = scoringFunction;
255+
this.pipeline = pipeline;
246256
}
247257

248258
@Override
249259
public Object execute(Query query) {
250260

251-
AggregationResults<?> aggregated = operations.aggregate(
252-
TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline), collectionName,
253-
queryMetadata.outputType());
261+
AggregationResults<?> aggregated = operations
262+
.aggregate(TypedAggregation.newAggregation(targetType, pipeline.getOperations()), collectionName, targetType);
254263

255264
List<?> mappedResults = aggregated.getMappedResults();
256265

257-
if (isSearchResult(method.getReturnType())) {
266+
if (!isSearchResult(returnType)) {
267+
return mappedResults;
268+
}
258269

259-
List<org.bson.Document> rawResults = aggregated.getRawResults().getList("results", org.bson.Document.class);
260-
List<SearchResult<Object>> result = new ArrayList<>(mappedResults.size());
270+
List<org.bson.Document> rawResults = aggregated.getRawResults().getList("results", org.bson.Document.class);
271+
List<SearchResult<Object>> result = new ArrayList<>(mappedResults.size());
261272

262-
for (int i = 0; i < mappedResults.size(); i++) {
263-
Document document = rawResults.get(i);
264-
SearchResult<Object> searchResult = new SearchResult<>(mappedResults.get(i),
265-
Similarity.raw(document.getDouble("__score__"), queryMetadata.scoringFunction()));
273+
for (int i = 0; i < mappedResults.size(); i++) {
266274

267-
result.add(searchResult);
268-
}
275+
Document document = rawResults.get(i);
276+
SearchResult<Object> searchResult = new SearchResult<>(mappedResults.get(i),
277+
Similarity.raw(document.getDouble("__score__"), scoringFunction));
269278

270-
return isListOfSearchResult(method.getReturnType()) ? result : new SearchResults<>(result);
279+
result.add(searchResult);
271280
}
272281

273-
return mappedResults;
282+
return isListOfSearchResult(returnType) ? result : new SearchResults<>(result);
274283
}
275284

276285
private static boolean isListOfSearchResult(TypeInformation<?> returnType) {

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java

+9-11
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,9 @@
1818
import reactor.core.publisher.Flux;
1919
import reactor.core.publisher.Mono;
2020

21-
import java.util.List;
22-
2321
import org.bson.Document;
2422
import org.jspecify.annotations.Nullable;
2523
import org.reactivestreams.Publisher;
26-
2724
import org.springframework.core.convert.converter.Converter;
2825
import org.springframework.data.convert.DtoInstantiatingConverter;
2926
import org.springframework.data.domain.Pageable;
@@ -36,11 +33,12 @@
3633
import org.springframework.data.mapping.model.EntityInstantiators;
3734
import org.springframework.data.mongodb.core.ReactiveMongoOperations;
3835
import org.springframework.data.mongodb.core.ReactiveUpdateOperation.ReactiveUpdate;
39-
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
36+
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
4037
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
4138
import org.springframework.data.mongodb.core.query.NearQuery;
4239
import org.springframework.data.mongodb.core.query.Query;
4340
import org.springframework.data.mongodb.core.query.UpdateDefinition;
41+
import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer;
4442
import org.springframework.data.repository.query.ResultProcessor;
4543
import org.springframework.data.repository.query.ReturnedType;
4644
import org.springframework.data.util.ReactiveWrappers;
@@ -134,24 +132,24 @@ private boolean isStreamOfGeoResult() {
134132
class VectorSearchExecution implements ReactiveMongoQueryExecution {
135133

136134
private final ReactiveMongoOperations operations;
137-
private final VectorSearchDelegate.QueryMetadata queryMetadata;
138-
private final List<AggregationOperation> pipeline;
135+
private final QueryContainer queryMetadata;
136+
private final AggregationPipeline pipeline;
139137
private final boolean returnSearchResult;
140138

141-
public VectorSearchExecution(ReactiveMongoOperations operations, MongoQueryMethod method,
142-
VectorSearchDelegate.QueryMetadata queryMetadata, MongoParameterAccessor accessor) {
139+
VectorSearchExecution(ReactiveMongoOperations operations, MongoQueryMethod method, QueryContainer queryMetadata) {
143140

144141
this.operations = operations;
145142
this.queryMetadata = queryMetadata;
146-
this.pipeline = queryMetadata.getAggregationPipeline(method, accessor);
143+
this.pipeline = queryMetadata.pipeline();
147144
this.returnSearchResult = isSearchResult(method.getReturnType());
148145
}
149146

150147
@Override
151148
public Publisher<? extends Object> execute(Query query, Class<?> type, String collection) {
152149

153-
Flux<Document> aggregate = operations
154-
.aggregate(TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline), collection, Document.class);
150+
Flux<Document> aggregate = operations.aggregate(
151+
TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline.getOperations()), collection,
152+
Document.class);
155153

156154
return aggregate.map(document -> {
157155

0 commit comments

Comments
 (0)