Skip to content

Commit 6657508

Browse files
committed
Polishing.
1 parent b6d9efa commit 6657508

File tree

5 files changed

+164
-56
lines changed

5 files changed

+164
-56
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.mongodb.repository.query;
17+
18+
import static org.assertj.core.api.Assertions.*;
19+
import static org.mockito.Mockito.*;
20+
21+
import java.lang.reflect.Method;
22+
23+
import org.jetbrains.annotations.NotNull;
24+
import org.junit.jupiter.api.Test;
25+
26+
import org.springframework.data.domain.Limit;
27+
import org.springframework.data.domain.Score;
28+
import org.springframework.data.domain.SearchResults;
29+
import org.springframework.data.domain.Vector;
30+
import org.springframework.data.mapping.model.ValueExpressionEvaluator;
31+
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation;
32+
import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
33+
import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver;
34+
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
35+
import org.springframework.data.mongodb.repository.VectorSearch;
36+
import org.springframework.data.mongodb.util.json.ParameterBindingContext;
37+
import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec;
38+
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
39+
import org.springframework.data.repository.Repository;
40+
import org.springframework.data.repository.core.RepositoryMetadata;
41+
import org.springframework.data.repository.core.support.AnnotationRepositoryMetadata;
42+
import org.springframework.data.repository.query.ValueExpressionDelegate;
43+
44+
/**
45+
* Unit tests for {@link VectorSearchDelegate}.
46+
*
47+
* @author Mark Paluch
48+
*/
49+
class VectorSearchDelegateUnitTests {
50+
51+
MappingMongoConverter converter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, new MongoMappingContext());
52+
53+
@Test
54+
void shouldConsiderDerivedLimit() throws ReflectiveOperationException {
55+
56+
Method method = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNear", Vector.class, Score.class);
57+
58+
MongoQueryMethod queryMethod = getMongoQueryMethod(method);
59+
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1));
60+
61+
VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor);
62+
63+
assertThat(query.query().getLimit()).isEqualTo(10);
64+
assertThat(query.numCandidates()).isEqualTo(10 * 20);
65+
}
66+
67+
@Test
68+
void shouldNotSetNumCandidates() throws ReflectiveOperationException {
69+
70+
Method method = VectorSearchRepository.class.getMethod("searchTop10EnnByEmbeddingNear", Vector.class, Score.class);
71+
72+
MongoQueryMethod queryMethod = getMongoQueryMethod(method);
73+
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1));
74+
75+
VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor);
76+
77+
assertThat(query.query().getLimit()).isEqualTo(10);
78+
assertThat(query.numCandidates()).isNull();
79+
}
80+
81+
@Test
82+
void shouldConsiderProvidedLimit() throws ReflectiveOperationException {
83+
84+
Method method = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNear", Vector.class, Score.class,
85+
Limit.class);
86+
87+
MongoQueryMethod queryMethod = getMongoQueryMethod(method);
88+
MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(11));
89+
90+
VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor);
91+
92+
assertThat(query.query().getLimit()).isEqualTo(11);
93+
assertThat(query.numCandidates()).isEqualTo(11 * 20);
94+
}
95+
96+
private VectorSearchDelegate.QueryMetadata createQueryMetadata(MongoQueryMethod queryMethod,
97+
MongoParametersParameterAccessor accessor) {
98+
99+
VectorSearchDelegate delegate = new VectorSearchDelegate(queryMethod, converter, ValueExpressionDelegate.create());
100+
101+
return delegate.createQuery(mock(ValueExpressionEvaluator.class), queryMethod.getResultProcessor(), accessor,
102+
Object.class, new ParameterBindingDocumentCodec(), mock(ParameterBindingContext.class));
103+
}
104+
105+
private MongoQueryMethod getMongoQueryMethod(Method method) {
106+
RepositoryMetadata metadata = AnnotationRepositoryMetadata.getMetadata(method.getDeclaringClass());
107+
return new MongoQueryMethod(method, metadata, new SpelAwareProxyProjectionFactory(), converter.getMappingContext());
108+
}
109+
110+
@NotNull
111+
private static MongoParametersParameterAccessor getAccessor(MongoQueryMethod queryMethod, Object... values) {
112+
return new MongoParametersParameterAccessor(queryMethod, values);
113+
}
114+
115+
interface VectorSearchRepository extends Repository<WithVector, String> {
116+
117+
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN)
118+
SearchResults<WithVector> searchTop10ByEmbeddingNear(Vector vector, Score similarity);
119+
120+
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ENN)
121+
SearchResults<WithVector> searchTop10EnnByEmbeddingNear(Vector vector, Score similarity);
122+
123+
@VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN)
124+
SearchResults<WithVector> searchTop10ByEmbeddingNear(Vector vector, Score similarity, Limit limit);
125+
126+
}
127+
128+
static class WithVector {
129+
130+
Vector embedding;
131+
}
132+
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoSimpleTypes.java

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.bson.types.Decimal128;
3030
import org.bson.types.ObjectId;
3131
import org.bson.types.Symbol;
32+
3233
import org.springframework.data.mapping.model.SimpleTypeHolder;
3334

3435
import com.mongodb.DBRef;

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2016-2025 the original author or authors.
2+
* Copyright 2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -41,7 +41,7 @@
4141
*
4242
* @author Mark Paluch
4343
* @since 5.0
44-
* @see org.springframework.data.geo.Distance
44+
* @see org.springframework.data.domain.Score
4545
* @see org.springframework.data.domain.Vector
4646
* @see org.springframework.data.domain.SearchResults
4747
*/

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java

+7-40
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
import java.util.List;
2222

2323
import org.jspecify.annotations.Nullable;
24+
2425
import org.springframework.core.MethodParameter;
25-
import org.springframework.core.ResolvableType;
2626
import org.springframework.data.domain.Range;
2727
import org.springframework.data.domain.Vector;
28+
import org.springframework.data.geo.Distance;
2829
import org.springframework.data.geo.GeoPage;
2930
import org.springframework.data.geo.GeoResult;
3031
import org.springframework.data.geo.GeoResults;
31-
import org.springframework.data.geo.Distance;
3232
import org.springframework.data.geo.Point;
3333
import org.springframework.data.mongodb.core.query.Collation;
3434
import org.springframework.data.mongodb.core.query.TextCriteria;
@@ -78,7 +78,7 @@ public MongoParameters(ParametersSource parametersSource) {
7878
* @param isGeoNearMethod indicate if this is a geo-spatial query method
7979
*/
8080
public MongoParameters(ParametersSource parametersSource, boolean isGeoNearMethod) {
81-
this(parametersSource, new NearIndex(parametersSource, isGeoNearMethod), new DistanceRangeIndex(parametersSource));
81+
this(parametersSource, new NearIndex(parametersSource, isGeoNearMethod));
8282
}
8383

8484
/**
@@ -87,11 +87,10 @@ public MongoParameters(ParametersSource parametersSource, boolean isGeoNearMetho
8787
* @param parametersSource must not be {@literal null}.
8888
* @param nearIndex the near parameter index.
8989
*/
90-
private MongoParameters(ParametersSource parametersSource, NearIndex nearIndex,
91-
DistanceRangeIndex distanceRangeIndex) {
90+
private MongoParameters(ParametersSource parametersSource, NearIndex nearIndex) {
9291

9392
super(parametersSource, methodParameter -> new MongoParameter(methodParameter,
94-
parametersSource.getDomainTypeInformation(), nearIndex.nearIndex, distanceRangeIndex.distanceRangeIndex));
93+
parametersSource.getDomainTypeInformation(), nearIndex.nearIndex));
9594

9695
Method method = parametersSource.getMethod();
9796
List<Class<?>> parameterTypes = Arrays.asList(method.getParameterTypes());
@@ -156,15 +155,6 @@ public NearIndex(ParametersSource parametersSource, boolean isGeoNearMethod) {
156155
}
157156
}
158157

159-
static class DistanceRangeIndex {
160-
161-
private final int distanceRangeIndex;
162-
163-
public DistanceRangeIndex(ParametersSource parametersSource) {
164-
this.distanceRangeIndex = findDistanceRangeIndexInParameters(parametersSource.getMethod());
165-
}
166-
}
167-
168158
private static int getNearIndex(List<Class<?>> parameterTypes) {
169159

170160
for (Class<?> reference : Arrays.asList(Point.class, double[].class)) {
@@ -207,21 +197,6 @@ static int findNearIndexInParameters(Method method) {
207197
return index;
208198
}
209199

210-
static int findDistanceRangeIndexInParameters(Method method) {
211-
212-
int index = -1;
213-
for (java.lang.reflect.Parameter p : method.getParameters()) {
214-
215-
MethodParameter methodParameter = MethodParameter.forParameter(p);
216-
217-
if (Range.class.isAssignableFrom(methodParameter.getParameterType())
218-
&& ResolvableType.forMethodParameter(methodParameter).getGeneric(0).isAssignableFrom(Distance.class)) {
219-
index = methodParameter.getParameterIndex();
220-
}
221-
}
222-
return index;
223-
}
224-
225200
/**
226201
* Returns the index of the {@link Distance} parameter to be used for max distance in geo queries.
227202
*
@@ -321,21 +296,17 @@ static class MongoParameter extends Parameter {
321296

322297
private final MethodParameter parameter;
323298
private final @Nullable Integer nearIndex;
324-
private final @Nullable Integer distanceRangeIndex;
325299

326300
/**
327301
* Creates a new {@link MongoParameter}.
328302
*
329303
* @param parameter must not be {@literal null}.
330304
* @param domainType must not be {@literal null}.
331-
* @param distanceRangeIndex
332305
*/
333-
MongoParameter(MethodParameter parameter, TypeInformation<?> domainType, @Nullable Integer nearIndex,
334-
@Nullable Integer distanceRangeIndex) {
306+
MongoParameter(MethodParameter parameter, TypeInformation<?> domainType, @Nullable Integer nearIndex) {
335307
super(parameter, domainType);
336308
this.parameter = parameter;
337309
this.nearIndex = nearIndex;
338-
this.distanceRangeIndex = distanceRangeIndex;
339310

340311
if (!isPoint() && hasNearAnnotation()) {
341312
throw new IllegalArgumentException("Near annotation is only allowed at Point parameter");
@@ -345,18 +316,14 @@ static class MongoParameter extends Parameter {
345316
@Override
346317
public boolean isSpecialParameter() {
347318
return super.isSpecialParameter() || Distance.class.isAssignableFrom(getType())
348-
|| Vector.class.isAssignableFrom(getType()) || isNearParameter() || isDistanceRangeParameter()
319+
|| Vector.class.isAssignableFrom(getType()) || isNearParameter()
349320
|| TextCriteria.class.isAssignableFrom(getType()) || Collation.class.isAssignableFrom(getType());
350321
}
351322

352323
private boolean isNearParameter() {
353324
return nearIndex != null && nearIndex.equals(getIndex());
354325
}
355326

356-
private boolean isDistanceRangeParameter() {
357-
return distanceRangeIndex != null && distanceRangeIndex.equals(getIndex());
358-
}
359-
360327
private boolean isManuallyAnnotatedNearParameter() {
361328
return isPoint() && hasNearAnnotation();
362329
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java

+22-14
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,9 @@ public QueryMetadata createQuery(ValueExpressionEvaluator evaluator, ResultProce
126126
Integer numCandidates = null;
127127
Limit limit;
128128
Class<?> outputType = typeToRead != null ? typeToRead : processor.getReturnedType().getReturnedType();
129-
130-
if (this.numCandidatesExpression != null) {
131-
numCandidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue();
132-
} else if (this.numCandidates != null) {
133-
numCandidates = this.numCandidates;
134-
}
129+
VectorSearchInput query = queryFactory.createQuery(accessor, codec, context);
135130

136131
if (this.limitExpression != null) {
137-
138132
Object value = evaluator.evaluate(this.limitExpression);
139133
limit = value instanceof Limit l ? l : Limit.of(((Number) value).intValue());
140134
} else if (this.limit.isLimited()) {
@@ -143,12 +137,22 @@ public QueryMetadata createQuery(ValueExpressionEvaluator evaluator, ResultProce
143137
limit = accessor.getLimit();
144138
}
145139

146-
VectorSearchInput query = queryFactory.createQuery(accessor, codec, context);
147-
148140
if (limit.isLimited()) {
149141
query.query().limit(limit);
150142
}
151143

144+
if (this.numCandidatesExpression != null) {
145+
numCandidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue();
146+
} else if (this.numCandidates != null) {
147+
numCandidates = this.numCandidates;
148+
} else if (query.query().isLimited() && searchType == VectorSearchOperation.SearchType.ANN) {
149+
150+
/*
151+
MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return (limit) to increase accuracy.
152+
*/
153+
numCandidates = query.query().getLimit() * 20;
154+
}
155+
152156
return new QueryMetadata(query.path, "__score__", query.query, searchType, outputType, numCandidates,
153157
getSimilarityFunction(accessor));
154158
}
@@ -335,13 +339,13 @@ public String getQueryString() {
335339
private class PartTreeQueryFactory implements VectorSearchQueryFactory {
336340

337341
private final String path;
338-
private final PartTree partTree;
342+
private final PartTree tree;
339343

340344
@SuppressWarnings("NullableProblems")
341-
PartTreeQueryFactory(PartTree partTree, MappingContext<?, MongoPersistentProperty> context) {
345+
PartTreeQueryFactory(PartTree tree, MappingContext<?, MongoPersistentProperty> context) {
342346

343347
String path = null;
344-
for (PartTree.OrPart part : partTree) {
348+
for (PartTree.OrPart part : tree) {
345349
for (Part p : part) {
346350
if (p.getType() == Part.Type.SIMPLE_PROPERTY || p.getType() == Part.Type.NEAR
347351
|| p.getType() == Part.Type.WITHIN || p.getType() == Part.Type.BETWEEN) {
@@ -362,17 +366,21 @@ private class PartTreeQueryFactory implements VectorSearchQueryFactory {
362366
}
363367

364368
this.path = path;
365-
this.partTree = partTree;
369+
this.tree = tree;
366370
}
367371

368372
public VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec,
369373
ParameterBindingContext context) {
370374

371-
MongoQueryCreator creator = new MongoQueryCreator(partTree, parameterAccessor, converter.getMappingContext(),
375+
MongoQueryCreator creator = new MongoQueryCreator(tree, parameterAccessor, converter.getMappingContext(),
372376
false, true);
373377

374378
Query query = creator.createQuery(parameterAccessor.getSort());
375379

380+
if (tree.isLimiting()) {
381+
query.limit(tree.getMaxResults());
382+
}
383+
376384
return new VectorSearchInput(path, query);
377385
}
378386

0 commit comments

Comments
 (0)