|
23 | 23 |
|
24 | 24 | import org.bson.Document;
|
25 | 25 | import org.jspecify.annotations.Nullable;
|
26 |
| - |
27 | 26 | import org.springframework.data.domain.Page;
|
28 | 27 | import org.springframework.data.domain.Pageable;
|
29 | 28 | import org.springframework.data.domain.Range;
|
| 29 | +import org.springframework.data.domain.ScoringFunction; |
30 | 30 | import org.springframework.data.domain.SearchResult;
|
31 | 31 | import org.springframework.data.domain.SearchResults;
|
32 | 32 | import org.springframework.data.domain.Similarity;
|
|
45 | 45 | import org.springframework.data.mongodb.core.ExecutableRemoveOperation.TerminatingRemove;
|
46 | 46 | import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate;
|
47 | 47 | import org.springframework.data.mongodb.core.MongoOperations;
|
48 |
| -import org.springframework.data.mongodb.core.aggregation.AggregationOperation; |
| 48 | +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; |
49 | 49 | import org.springframework.data.mongodb.core.aggregation.AggregationResults;
|
50 | 50 | import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
|
51 | 51 | import org.springframework.data.mongodb.core.query.NearQuery;
|
52 | 52 | import org.springframework.data.mongodb.core.query.Query;
|
53 | 53 | import org.springframework.data.mongodb.core.query.UpdateDefinition;
|
| 54 | +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; |
54 | 55 | import org.springframework.data.mongodb.repository.util.SliceUtils;
|
55 | 56 | import org.springframework.data.repository.query.QueryMethod;
|
56 | 57 | import org.springframework.data.support.PageableExecutionUtils;
|
@@ -186,7 +187,7 @@ public Object execute(Query query) {
|
186 | 187 | return isListOfGeoResult(method.getReturnType()) ? results.getContent() : results;
|
187 | 188 | }
|
188 | 189 |
|
189 |
| - @SuppressWarnings({"unchecked","NullAway"}) |
| 190 | + @SuppressWarnings({ "unchecked", "NullAway" }) |
190 | 191 | GeoResults<Object> doExecuteQuery(Query query) {
|
191 | 192 |
|
192 | 193 | Point nearLocation = accessor.getGeoNearLocation();
|
@@ -225,52 +226,60 @@ private static boolean isListOfGeoResult(TypeInformation<?> returnType) {
|
225 | 226 | * {@link MongoQueryExecution} to execute vector search.
|
226 | 227 | *
|
227 | 228 | * @author Mark Paluch
|
| 229 | + * @author Chistoph Strobl |
228 | 230 | * @since 5.0
|
229 | 231 | */
|
230 | 232 | class VectorSearchExecution implements MongoQueryExecution {
|
231 | 233 |
|
232 | 234 | private final MongoOperations operations;
|
233 |
| - private final MongoQueryMethod method; |
| 235 | + private final TypeInformation<?> returnType; |
234 | 236 | private final String collectionName;
|
235 |
| - private final VectorSearchDelegate.QueryMetadata queryMetadata; |
236 |
| - private final List<AggregationOperation> pipeline; |
| 237 | + private final Class<?> targetType; |
| 238 | + private final ScoringFunction scoringFunction; |
| 239 | + private final AggregationPipeline pipeline; |
| 240 | + |
| 241 | + VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName, |
| 242 | + QueryContainer queryContainer) { |
| 243 | + this(operations, queryContainer.outputType(), collectionName, method.getReturnType(), queryContainer.pipeline(), |
| 244 | + queryContainer.scoringFunction()); |
| 245 | + } |
237 | 246 |
|
238 |
| - public VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName, |
239 |
| - VectorSearchDelegate.QueryMetadata queryMetadata, MongoParameterAccessor accessor) { |
| 247 | + public VectorSearchExecution(MongoOperations operations, Class<?> targetType, String collectionName, |
| 248 | + TypeInformation<?> returnType, AggregationPipeline pipeline, ScoringFunction scoringFunction) { |
240 | 249 |
|
241 | 250 | this.operations = operations;
|
| 251 | + this.returnType = returnType; |
242 | 252 | this.collectionName = collectionName;
|
243 |
| - this.queryMetadata = queryMetadata; |
244 |
| - this.method = method; |
245 |
| - this.pipeline = queryMetadata.getAggregationPipeline(method, accessor); |
| 253 | + this.targetType = targetType; |
| 254 | + this.scoringFunction = scoringFunction; |
| 255 | + this.pipeline = pipeline; |
246 | 256 | }
|
247 | 257 |
|
248 | 258 | @Override
|
249 | 259 | public Object execute(Query query) {
|
250 | 260 |
|
251 |
| - AggregationResults<?> aggregated = operations.aggregate( |
252 |
| - TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline), collectionName, |
253 |
| - queryMetadata.outputType()); |
| 261 | + AggregationResults<?> aggregated = operations |
| 262 | + .aggregate(TypedAggregation.newAggregation(targetType, pipeline.getOperations()), collectionName, targetType); |
254 | 263 |
|
255 | 264 | List<?> mappedResults = aggregated.getMappedResults();
|
256 | 265 |
|
257 |
| - if (isSearchResult(method.getReturnType())) { |
| 266 | + if (!isSearchResult(returnType)) { |
| 267 | + return mappedResults; |
| 268 | + } |
258 | 269 |
|
259 |
| - List<org.bson.Document> rawResults = aggregated.getRawResults().getList("results", org.bson.Document.class); |
260 |
| - List<SearchResult<Object>> result = new ArrayList<>(mappedResults.size()); |
| 270 | + List<org.bson.Document> rawResults = aggregated.getRawResults().getList("results", org.bson.Document.class); |
| 271 | + List<SearchResult<Object>> result = new ArrayList<>(mappedResults.size()); |
261 | 272 |
|
262 |
| - for (int i = 0; i < mappedResults.size(); i++) { |
263 |
| - Document document = rawResults.get(i); |
264 |
| - SearchResult<Object> searchResult = new SearchResult<>(mappedResults.get(i), |
265 |
| - Similarity.raw(document.getDouble("__score__"), queryMetadata.scoringFunction())); |
| 273 | + for (int i = 0; i < mappedResults.size(); i++) { |
266 | 274 |
|
267 |
| - result.add(searchResult); |
268 |
| - } |
| 275 | + Document document = rawResults.get(i); |
| 276 | + SearchResult<Object> searchResult = new SearchResult<>(mappedResults.get(i), |
| 277 | + Similarity.raw(document.getDouble("__score__"), scoringFunction)); |
269 | 278 |
|
270 |
| - return isListOfSearchResult(method.getReturnType()) ? result : new SearchResults<>(result); |
| 279 | + result.add(searchResult); |
271 | 280 | }
|
272 | 281 |
|
273 |
| - return mappedResults; |
| 282 | + return isListOfSearchResult(returnType) ? result : new SearchResults<>(result); |
274 | 283 | }
|
275 | 284 |
|
276 | 285 | private static boolean isListOfSearchResult(TypeInformation<?> returnType) {
|
|
0 commit comments