Skip to content

Commit 1ac2c9f

Browse files
are we done?
1 parent 7fb9b2a commit 1ac2c9f

File tree

6 files changed

+50
-21
lines changed

6 files changed

+50
-21
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java

+1
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ public SortArray sort(Sort sort) {
356356
* @return new instance of {@link SortArray}.
357357
* @since 4.5
358358
*/
359+
@SuppressWarnings("NullAway")
359360
public SortArray sort(Direction direction) {
360361

361362
if (usesFieldRef()) {

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,10 @@ public Document toDocument(AggregationOperationContext context) {
237237
}
238238

239239
$vectorSearch.append("index", indexName);
240-
$vectorSearch.append("limit", limit.max());
240+
241+
if(limit.isLimited()) { // TODO: exception or pass it on?
242+
$vectorSearch.append("limit", limit.max());
243+
}
241244

242245
if (numCandidates != null) {
243246
$vectorSearch.append("numCandidates", numCandidates);

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ public PotentiallyConvertingIterator iterator() {
7777
}
7878

7979
@Override
80-
public Vector getVector() {
80+
public @Nullable Vector getVector() {
8181
return delegate.getVector();
8282
}
8383

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java

-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ public VectorSearchAggregation(MongoQueryMethod method, MongoOperations mongoOpe
6363
this.delegate = new VectorSearchDelegate(method, mongoOperations.getConverter(), delegate);
6464
}
6565

66-
@SuppressWarnings("unchecked")
6766
@Override
6867
protected Object doExecute(MongoQueryMethod method, ResultProcessor processor, ConvertingParameterAccessor accessor,
6968
@Nullable Class<?> typeToRead) {

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java

+32-16
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
import org.springframework.data.repository.query.ValueExpressionDelegate;
4848
import org.springframework.data.repository.query.parser.Part;
4949
import org.springframework.data.repository.query.parser.PartTree;
50+
import org.springframework.util.NumberUtils;
5051
import org.springframework.util.StringUtils;
5152

5253
/**
@@ -136,17 +137,14 @@ QueryContainer createQuery(ValueExpressionEvaluator evaluator, ResultProcessor p
136137
outputType, getSimilarityFunction(accessor), indexName);
137138
}
138139

139-
public AggregationPipeline createVectorSearchPipeline(VectorSearchInput input, String scoreField, Class<?> outputType,
140+
@SuppressWarnings("NullAway")
141+
AggregationPipeline createVectorSearchPipeline(VectorSearchInput input, String scoreField, Class<?> outputType,
140142
MongoParameterAccessor accessor, ValueExpressionEvaluator evaluator) {
141143

142144
Vector vector = accessor.getVector();
143145
Score score = accessor.getScore();
144146
Range<Score> distance = accessor.getScoreRange();
145-
Limit limit = Limit.unlimited();
146-
147-
if (input.query().isLimited()) {
148-
limit = Limit.of(input.query().getLimit());
149-
}
147+
Limit limit = Limit.of(input.query().getLimit());
150148

151149
List<AggregationOperation> stages = new ArrayList<>();
152150
VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(indexName).path(input.path()).vector(vector)
@@ -223,21 +221,38 @@ public AggregationPipeline createVectorSearchPipeline(VectorSearchInput input, S
223221
private VectorSearchInput createSearchInput(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor,
224222
ParameterBindingDocumentCodec codec, ParameterBindingContext context) {
225223

226-
VectorSearchInput query = queryFactory.createQuery(accessor, codec, context);
227-
Limit limit;
224+
VectorSearchInput input = queryFactory.createQuery(accessor, codec, context);
225+
Limit limit = getLimit(evaluator, accessor);
226+
if(!input.query.isLimited() || (input.query.isLimited() && !limit.isUnlimited())) {
227+
input.query().limit(limit);
228+
}
229+
return input;
230+
}
231+
232+
private Limit getLimit(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor) {
233+
228234
if (this.limitExpression != null) {
235+
229236
Object value = evaluator.evaluate(this.limitExpression);
230-
limit = value instanceof Limit l ? l : Limit.of(((Number) value).intValue());
231-
} else if (this.limit.isLimited()) {
232-
limit = this.limit;
233-
} else {
234-
limit = accessor.getLimit();
237+
if (value != null) {
238+
if (value instanceof Limit l) {
239+
return l;
240+
}
241+
if (value instanceof Number n) {
242+
return Limit.of(n.intValue());
243+
}
244+
if (value instanceof String s) {
245+
return Limit.of(NumberUtils.parseNumber(s, Integer.class));
246+
}
247+
throw new IllegalArgumentException("Invalid type for Limit. Found [%s], expected Limit or Number");
248+
}
235249
}
236250

237-
if (limit.isLimited()) {
238-
query.query().limit(limit);
251+
if (this.limit.isLimited()) {
252+
return this.limit;
239253
}
240-
return query;
254+
255+
return accessor.getLimit();
241256
}
242257

243258
public String getQueryString() {
@@ -378,6 +393,7 @@ private class PartTreeQueryFactory implements VectorSearchQueryFactory {
378393
this.tree = tree;
379394
}
380395

396+
@SuppressWarnings("NullAway")
381397
public VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec,
382398
ParameterBindingContext context) {
383399

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java

+12-2
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
*/
1616
package org.springframework.data.mongodb.core.aggregation;
1717

18-
import static org.assertj.core.api.Assertions.*;
18+
import static org.springframework.data.mongodb.test.util.Assertions.assertThat;
1919

2020
import java.util.List;
2121

2222
import org.bson.Document;
2323
import org.junit.jupiter.api.Test;
24-
2524
import org.springframework.data.annotation.Id;
25+
import org.springframework.data.domain.Limit;
2626
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType;
2727
import org.springframework.data.mongodb.core.mapping.Field;
2828
import org.springframework.data.mongodb.core.query.Criteria;
@@ -103,6 +103,16 @@ void mapsCriteriaToDomainType() {
103103
.containsExactly(new Document("$vectorSearch", new Document($VECTOR_SEARCH).append("filter", filter)));
104104
}
105105

106+
@Test
107+
void withInvalidLimit() {
108+
109+
VectorSearchOperation $search = VectorSearchOperation.search("vector_index").path("plot_embedding")
110+
.vector(-0.0016261312, -0.028070757, -0.011342932).limit(Limit.unlimited());
111+
112+
List<Document> stages = $search.toPipelineStages(TestAggregationContext.contextFor(Movie.class));
113+
assertThat(stages.get(0)).doesNotContainKey("$vectorSearch.limit");
114+
}
115+
106116
static class Movie {
107117

108118
@Id String id;

0 commit comments

Comments
 (0)