|
47 | 47 | import org.springframework.data.repository.query.ValueExpressionDelegate;
|
48 | 48 | import org.springframework.data.repository.query.parser.Part;
|
49 | 49 | import org.springframework.data.repository.query.parser.PartTree;
|
| 50 | +import org.springframework.util.NumberUtils; |
50 | 51 | import org.springframework.util.StringUtils;
|
51 | 52 |
|
52 | 53 | /**
|
@@ -136,17 +137,14 @@ QueryContainer createQuery(ValueExpressionEvaluator evaluator, ResultProcessor p
|
136 | 137 | outputType, getSimilarityFunction(accessor), indexName);
|
137 | 138 | }
|
138 | 139 |
|
139 |
| - public AggregationPipeline createVectorSearchPipeline(VectorSearchInput input, String scoreField, Class<?> outputType, |
| 140 | + @SuppressWarnings("NullAway") |
| 141 | + AggregationPipeline createVectorSearchPipeline(VectorSearchInput input, String scoreField, Class<?> outputType, |
140 | 142 | MongoParameterAccessor accessor, ValueExpressionEvaluator evaluator) {
|
141 | 143 |
|
142 | 144 | Vector vector = accessor.getVector();
|
143 | 145 | Score score = accessor.getScore();
|
144 | 146 | Range<Score> distance = accessor.getScoreRange();
|
145 |
| - Limit limit = Limit.unlimited(); |
146 |
| - |
147 |
| - if (input.query().isLimited()) { |
148 |
| - limit = Limit.of(input.query().getLimit()); |
149 |
| - } |
| 147 | + Limit limit = Limit.of(input.query().getLimit()); |
150 | 148 |
|
151 | 149 | List<AggregationOperation> stages = new ArrayList<>();
|
152 | 150 | VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(indexName).path(input.path()).vector(vector)
|
@@ -223,21 +221,38 @@ public AggregationPipeline createVectorSearchPipeline(VectorSearchInput input, S
|
223 | 221 | private VectorSearchInput createSearchInput(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor,
|
224 | 222 | ParameterBindingDocumentCodec codec, ParameterBindingContext context) {
|
225 | 223 |
|
226 |
| - VectorSearchInput query = queryFactory.createQuery(accessor, codec, context); |
227 |
| - Limit limit; |
| 224 | + VectorSearchInput input = queryFactory.createQuery(accessor, codec, context); |
| 225 | + Limit limit = getLimit(evaluator, accessor); |
| 226 | + if(!input.query.isLimited() || (input.query.isLimited() && !limit.isUnlimited())) { |
| 227 | + input.query().limit(limit); |
| 228 | + } |
| 229 | + return input; |
| 230 | + } |
| 231 | + |
| 232 | + private Limit getLimit(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor) { |
| 233 | + |
228 | 234 | if (this.limitExpression != null) {
|
| 235 | + |
229 | 236 | Object value = evaluator.evaluate(this.limitExpression);
|
230 |
| - limit = value instanceof Limit l ? l : Limit.of(((Number) value).intValue()); |
231 |
| - } else if (this.limit.isLimited()) { |
232 |
| - limit = this.limit; |
233 |
| - } else { |
234 |
| - limit = accessor.getLimit(); |
| 237 | + if (value != null) { |
| 238 | + if (value instanceof Limit l) { |
| 239 | + return l; |
| 240 | + } |
| 241 | + if (value instanceof Number n) { |
| 242 | + return Limit.of(n.intValue()); |
| 243 | + } |
| 244 | + if (value instanceof String s) { |
| 245 | + return Limit.of(NumberUtils.parseNumber(s, Integer.class)); |
| 246 | + } |
| 247 | + throw new IllegalArgumentException("Invalid type for Limit. Found [%s], expected Limit or Number"); |
| 248 | + } |
235 | 249 | }
|
236 | 250 |
|
237 |
| - if (limit.isLimited()) { |
238 |
| - query.query().limit(limit); |
| 251 | + if (this.limit.isLimited()) { |
| 252 | + return this.limit; |
239 | 253 | }
|
240 |
| - return query; |
| 254 | + |
| 255 | + return accessor.getLimit(); |
241 | 256 | }
|
242 | 257 |
|
243 | 258 | public String getQueryString() {
|
@@ -378,6 +393,7 @@ private class PartTreeQueryFactory implements VectorSearchQueryFactory {
|
378 | 393 | this.tree = tree;
|
379 | 394 | }
|
380 | 395 |
|
| 396 | + @SuppressWarnings("NullAway") |
381 | 397 | public VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec,
|
382 | 398 | ParameterBindingContext context) {
|
383 | 399 |
|
|
0 commit comments