diff --git a/docs/changelog/129659.yaml b/docs/changelog/129659.yaml new file mode 100644 index 0000000000000..60fce08d58398 --- /dev/null +++ b/docs/changelog/129659.yaml @@ -0,0 +1,5 @@ +pr: 129659 +summary: Simplified RRF Retriever +area: Search +type: enhancement +issues: [] diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java index 0f11df321300b..f76c22fe1344e 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java @@ -10,6 +10,7 @@ import org.elasticsearch.features.FeatureSpecification; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder; +import org.elasticsearch.xpack.rank.rrf.RRFRetrieverBuilder; import java.util.Set; @@ -34,7 +35,8 @@ public Set getTestFeatures() { LINEAR_RETRIEVER_MINMAX_SINGLE_DOC_FIX, LINEAR_RETRIEVER_L2_NORM, LINEAR_RETRIEVER_MINSCORE_FIX, - LinearRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT + LinearRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT, + RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT ); } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index 81b401597a667..9fb7fdea21bb9 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -8,19 +8,27 @@ package org.elasticsearch.xpack.rank.rrf; import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ResolvedIndices; import org.elasticsearch.common.util.Maps; +import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.search.retriever.StandardRetrieverBuilder; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.XPackPlugin; +import org.elasticsearch.xpack.rank.MultiFieldsInnerRetrieverUtils; import java.io.IOException; import java.util.ArrayList; @@ -29,7 +37,6 @@ import java.util.Map; import java.util.Objects; -import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; /** @@ -40,11 +47,14 @@ * formula. */ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder { + public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature("rrf_retriever.multi_fields_query_format_support"); public static final String NAME = "rrf"; public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers"); public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant"); + public static final ParseField FIELDS_FIELD = new ParseField("fields"); + public static final ParseField QUERY_FIELD = new ParseField("query"); public static final int DEFAULT_RANK_CONSTANT = 60; @SuppressWarnings("unchecked") @@ -53,15 +63,20 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder { List childRetrievers = (List) args[0]; - List innerRetrievers = childRetrievers.stream().map(RetrieverSource::from).toList(); - int rankWindowSize = args[1] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1]; - int rankConstant = args[2] == null ? DEFAULT_RANK_CONSTANT : (int) args[2]; - return new RRFRetrieverBuilder(innerRetrievers, rankWindowSize, rankConstant); + List fields = (List) args[1]; + String query = (String) args[2]; + int rankWindowSize = args[3] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[3]; + int rankConstant = args[4] == null ? DEFAULT_RANK_CONSTANT : (int) args[4]; + + List innerRetrievers = childRetrievers != null + ? childRetrievers.stream().map(RetrieverSource::from).toList() + : List.of(); + return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant); } ); static { - PARSER.declareObjectArray(constructorArg(), (p, c) -> { + PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> { p.nextToken(); String name = p.currentName(); RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c); @@ -69,6 +84,8 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder fields; + private final String query; private final int rankConstant; - public RRFRetrieverBuilder(int rankWindowSize, int rankConstant) { - this(new ArrayList<>(), rankWindowSize, rankConstant); + public RRFRetrieverBuilder(List childRetrievers, int rankWindowSize, int rankConstant) { + this(childRetrievers, null, null, rankWindowSize, rankConstant); } - RRFRetrieverBuilder(List childRetrievers, int rankWindowSize, int rankConstant) { - super(childRetrievers, rankWindowSize); + public RRFRetrieverBuilder( + List childRetrievers, + List fields, + String query, + int rankWindowSize, + int rankConstant + ) { + // Use a mutable list for childRetrievers so that we can use addChild + super(childRetrievers == null ? new ArrayList<>() : new ArrayList<>(childRetrievers), rankWindowSize); + this.fields = fields == null ? List.of() : List.copyOf(fields); + this.query = query; this.rankConstant = rankConstant; } + public int rankConstant() { + return rankConstant; + } + @Override public String getName() { return NAME; } + @Override + public ActionRequestValidationException validate( + SearchSourceBuilder source, + ActionRequestValidationException validationException, + boolean isScroll, + boolean allowPartialSearchResults + ) { + validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults); + return MultiFieldsInnerRetrieverUtils.validateParams( + innerRetrievers, + fields, + query, + getName(), + RETRIEVERS_FIELD.getPreferredName(), + FIELDS_FIELD.getPreferredName(), + QUERY_FIELD.getPreferredName(), + validationException + ); + } + @Override protected RRFRetrieverBuilder clone(List newRetrievers, List newPreFilterQueryBuilders) { - RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant); + RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.fields, this.query, this.rankWindowSize, this.rankConstant); clone.preFilterQueryBuilders = newPreFilterQueryBuilders; clone.retrieverName = retrieverName; return clone; @@ -162,17 +214,72 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List rankResults return topResults; } + @Override + protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) { + RetrieverBuilder rewritten = this; + + ResolvedIndices resolvedIndices = ctx.getResolvedIndices(); + if (resolvedIndices != null && query != null) { + // TODO: Refactor duplicate code + // Using the multi-fields query format + var localIndicesMetadata = resolvedIndices.getConcreteLocalIndicesMetadata(); + if (localIndicesMetadata.size() > 1) { + throw new IllegalArgumentException( + "[" + NAME + "] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying multiple indices" + ); + } else if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) { + throw new IllegalArgumentException( + "[" + NAME + "] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying remote indices" + ); + } + + List fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils.generateInnerRetrievers( + fields, + query, + localIndicesMetadata.values(), + r -> { + List retrievers = r.stream() + .map(MultiFieldsInnerRetrieverUtils.WeightedRetrieverSource::retrieverSource) + .toList(); + return new RRFRetrieverBuilder(retrievers, rankWindowSize, rankConstant); + }, + w -> { + if (w != 1.0f) { + throw new IllegalArgumentException( + "[" + NAME + "] does not support per-field weights in [" + FIELDS_FIELD.getPreferredName() + "]" + ); + } + } + ).stream().map(RetrieverSource::from).toList(); + + if (fieldsInnerRetrievers.isEmpty() == false) { + // TODO: This is a incomplete solution as it does not address other incomplete copy issues + // (such as dropping the retriever name and min score) + rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, rankWindowSize, rankConstant); + rewritten.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders); + } else { + // Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices + rewritten = new StandardRetrieverBuilder(new MatchNoneQueryBuilder()); + } + } + + return rewritten; + } + // ---- FOR TESTING XCONTENT PARSING ---- @Override public boolean doEquals(Object o) { RRFRetrieverBuilder that = (RRFRetrieverBuilder) o; - return super.doEquals(o) && rankConstant == that.rankConstant; + return super.doEquals(o) + && Objects.equals(fields, that.fields) + && Objects.equals(query, that.query) + && rankConstant == that.rankConstant; } @Override public int doHashCode() { - return Objects.hash(super.doHashCode(), rankConstant); + return Objects.hash(super.doHashCode(), fields, query, rankConstant); } @Override @@ -186,6 +293,17 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept builder.endArray(); } + if (fields.isEmpty() == false) { + builder.startArray(FIELDS_FIELD.getPreferredName()); + for (String field : fields) { + builder.value(field); + } + builder.endArray(); + } + if (query != null) { + builder.field(QUERY_FIELD.getPreferredName(), query); + } + builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); builder.field(RANK_CONSTANT_FIELD.getPreferredName(), rankConstant); } diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java index cae758457a2ac..add6f271b06ba 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.Strings; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.retriever.TestRetrieverBuilder; @@ -45,13 +46,22 @@ public static RRFRetrieverBuilder createRandomRRFRetrieverBuilder() { if (randomBoolean()) { rankConstant = randomIntBetween(1, 1000000); } - var ret = new RRFRetrieverBuilder(rankWindowSize, rankConstant); + + List fields = null; + String query = null; + if (randomBoolean()) { + fields = randomList(1, 10, () -> randomAlphaOfLengthBetween(1, 10)); + query = randomAlphaOfLengthBetween(1, 10); + } + int retrieverCount = randomIntBetween(2, 50); + List innerRetrievers = new ArrayList<>(retrieverCount); while (retrieverCount > 0) { - ret.addChild(TestRetrieverBuilder.createRandomTestRetrieverBuilder()); + innerRetrievers.add(CompoundRetrieverBuilder.RetrieverSource.from(TestRetrieverBuilder.createRandomTestRetrieverBuilder())); --retrieverCount; } - return ret; + + return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant); } @Override @@ -94,28 +104,32 @@ protected NamedXContentRegistry xContentRegistry() { } public void testRRFRetrieverParsing() throws IOException { - String restContent = "{" - + " \"retriever\": {" - + " \"rrf\": {" - + " \"retrievers\": [" - + " {" - + " \"test\": {" - + " \"value\": \"foo\"" - + " }" - + " }," - + " {" - + " \"test\": {" - + " \"value\": \"bar\"" - + " }" - + " }" - + " ]," - + " \"rank_window_size\": 100," - + " \"rank_constant\": 10," - + " \"min_score\": 20.0," - + " \"_name\": \"foo_rrf\"" - + " }" - + " }" - + "}"; + String restContent = """ + { + "retriever": { + "rrf": { + "retrievers": [ + { + "test": { + "value": "foo" + } + }, + { + "test": { + "value": "bar" + } + } + ], + "fields": ["field1", "field2"], + "query": "baz", + "rank_window_size": 100, + "rank_constant": 10, + "min_score": 20.0, + "_name": "foo_rrf" + } + } + } + """; SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder(); try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) { SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true); diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java index 3a77b733d6129..5e8d46cb5b27a 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java @@ -7,14 +7,26 @@ package org.elasticsearch.xpack.rank.rrf; +import org.elasticsearch.action.MockResolvedIndices; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.ResolvedIndices; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.query.MatchQueryBuilder; +import org.elasticsearch.index.query.MultiMatchQueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.search.retriever.StandardRetrieverBuilder; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; @@ -22,7 +34,13 @@ import org.elasticsearch.xcontent.json.JsonXContent; import java.io.IOException; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE; /** Tests for the rrf retriever. */ public class RRFRetrieverBuilderTests extends ESTestCase { @@ -66,13 +84,121 @@ public void testRetrieverExtractionErrors() throws IOException { } } + public void testMultiFieldsParamsRewrite() { + final String indexName = "test-index"; + final List testInferenceFields = List.of("semantic_field_1", "semantic_field_2"); + final ResolvedIndices resolvedIndices = createMockResolvedIndices(indexName, testInferenceFields, null); + final QueryRewriteContext queryRewriteContext = new QueryRewriteContext( + parserConfig(), + null, + null, + resolvedIndices, + new PointInTimeBuilder(new BytesArray("pitid")), + null + ); + + // No wildcards + RRFRetrieverBuilder rrfRetrieverBuilder = new RRFRetrieverBuilder( + null, + List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"), + "foo", + DEFAULT_RANK_WINDOW_SIZE, + RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT + ); + assertMultiFieldsParamsRewrite( + rrfRetrieverBuilder, + queryRewriteContext, + Map.of("field_1", 1.0f, "field_2", 1.0f), + Map.of("semantic_field_1", 1.0f, "semantic_field_2", 1.0f), + "foo" + ); + + // Non-default rank window size and rank constant + rrfRetrieverBuilder = new RRFRetrieverBuilder( + null, + List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"), + "foo2", + DEFAULT_RANK_WINDOW_SIZE * 2, + RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT / 2 + ); + assertMultiFieldsParamsRewrite( + rrfRetrieverBuilder, + queryRewriteContext, + Map.of("field_1", 1.0f, "field_2", 1.0f), + Map.of("semantic_field_1", 1.0f, "semantic_field_2", 1.0f), + "foo2" + ); + + // Glob matching on inference and non-inference fields + rrfRetrieverBuilder = new RRFRetrieverBuilder( + null, + List.of("field_*", "*_field_1"), + "bar", + DEFAULT_RANK_WINDOW_SIZE, + RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT + ); + assertMultiFieldsParamsRewrite( + rrfRetrieverBuilder, + queryRewriteContext, + Map.of("field_*", 1.0f, "*_field_1", 1.0f), + Map.of("semantic_field_1", 1.0f), + "bar" + ); + + // All-fields wildcard + rrfRetrieverBuilder = new RRFRetrieverBuilder( + null, + List.of("*"), + "baz", + DEFAULT_RANK_WINDOW_SIZE, + RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT + ); + assertMultiFieldsParamsRewrite( + rrfRetrieverBuilder, + queryRewriteContext, + Map.of("*", 1.0f), + Map.of("semantic_field_1", 1.0f, "semantic_field_2", 1.0f), + "baz" + ); + } + + public void testSearchRemoteIndex() { + final ResolvedIndices resolvedIndices = createMockResolvedIndices( + "local-index", + List.of(), + Map.of("remote-cluster", "remote-index") + ); + final QueryRewriteContext queryRewriteContext = new QueryRewriteContext( + parserConfig(), + null, + null, + resolvedIndices, + new PointInTimeBuilder(new BytesArray("pitid")), + null + ); + + RRFRetrieverBuilder rrfRetrieverBuilder = new RRFRetrieverBuilder( + null, + null, + "foo", + DEFAULT_RANK_WINDOW_SIZE, + RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT + ); + + IllegalArgumentException iae = expectThrows( + IllegalArgumentException.class, + () -> rrfRetrieverBuilder.doRewrite(queryRewriteContext) + ); + assertEquals("[rrf] cannot specify [query] when querying remote indices", iae.getMessage()); + } + @Override protected NamedXContentRegistry xContentRegistry() { List entries = new SearchModule(Settings.EMPTY, List.of()).getNamedXContents(); entries.add( new NamedXContentRegistry.Entry( RetrieverBuilder.class, - new ParseField(RRFRankPlugin.NAME), + new ParseField(RRFRetrieverBuilder.NAME), (p, c) -> RRFRetrieverBuilder.fromXContent(p, (RetrieverParserContext) c) ) ); @@ -80,10 +206,94 @@ protected NamedXContentRegistry xContentRegistry() { entries.add( new NamedXContentRegistry.Entry( RetrieverBuilder.class, - new ParseField(RRFRankPlugin.NAME + "_nl"), + new ParseField(RRFRetrieverBuilder.NAME + "_nl"), (p, c) -> RRFRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c) ) ); return new NamedXContentRegistry(entries); } + + private static ResolvedIndices createMockResolvedIndices( + String localIndexName, + List inferenceFields, + Map remoteIndexNames + ) { + Index index = new Index(localIndexName, randomAlphaOfLength(10)); + IndexMetadata.Builder indexMetadataBuilder = IndexMetadata.builder(index.getName()) + .settings( + Settings.builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()) + .put(IndexMetadata.SETTING_INDEX_UUID, index.getUUID()) + ) + .numberOfShards(1) + .numberOfReplicas(0); + + for (String inferenceField : inferenceFields) { + indexMetadataBuilder.putInferenceField( + new InferenceFieldMetadata(inferenceField, randomAlphaOfLengthBetween(3, 5), new String[] { inferenceField }, null) + ); + } + + Map remoteIndices = new HashMap<>(); + if (remoteIndexNames != null) { + for (Map.Entry entry : remoteIndexNames.entrySet()) { + remoteIndices.put(entry.getKey(), new OriginalIndices(new String[] { entry.getValue() }, IndicesOptions.DEFAULT)); + } + } + + return new MockResolvedIndices( + remoteIndices, + new OriginalIndices(new String[] { localIndexName }, IndicesOptions.DEFAULT), + Map.of(index, indexMetadataBuilder.build()) + ); + } + + private static void assertMultiFieldsParamsRewrite( + RRFRetrieverBuilder retriever, + QueryRewriteContext ctx, + Map expectedNonInferenceFields, + Map expectedInferenceFields, + String expectedQuery + ) { + Set expectedInnerRetrievers = Set.of( + CompoundRetrieverBuilder.RetrieverSource.from( + new StandardRetrieverBuilder( + new MultiMatchQueryBuilder(expectedQuery).type(MultiMatchQueryBuilder.Type.MOST_FIELDS) + .fields(expectedNonInferenceFields) + ) + ), + Set.of(expectedInferenceFields.entrySet().stream().map(e -> { + if (e.getValue() != 1.0f) { + throw new IllegalArgumentException("Cannot apply per-field weights in RRF"); + } + return CompoundRetrieverBuilder.RetrieverSource.from( + new StandardRetrieverBuilder(new MatchQueryBuilder(e.getKey(), expectedQuery)) + ); + }).toArray()) + ); + + RetrieverBuilder rewritten = retriever.doRewrite(ctx); + assertNotSame(retriever, rewritten); + assertTrue(rewritten instanceof RRFRetrieverBuilder); + + RRFRetrieverBuilder rewrittenRrf = (RRFRetrieverBuilder) rewritten; + assertEquals(retriever.rankWindowSize(), rewrittenRrf.rankWindowSize()); + assertEquals(retriever.rankConstant(), rewrittenRrf.rankConstant()); + assertEquals(expectedInnerRetrievers, getInnerRetrieversAsSet(rewrittenRrf)); + } + + private static Set getInnerRetrieversAsSet(RRFRetrieverBuilder retriever) { + Set innerRetrieversSet = new HashSet<>(); + for (CompoundRetrieverBuilder.RetrieverSource innerRetriever : retriever.innerRetrievers()) { + if (innerRetriever.retriever() instanceof RRFRetrieverBuilder innerRrfRetriever) { + assertEquals(retriever.rankWindowSize(), innerRrfRetriever.rankWindowSize()); + assertEquals(retriever.rankConstant(), innerRrfRetriever.rankConstant()); + innerRetrieversSet.add(getInnerRetrieversAsSet(innerRrfRetriever)); + } else { + innerRetrieversSet.add(innerRetriever); + } + } + + return innerRetrieversSet; + } } diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java b/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java index 1a22f8738a26a..f2a8f7f38bb06 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java @@ -11,6 +11,7 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; import org.junit.ClassRule; @@ -25,8 +26,12 @@ public class RRFRankClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase { .module("rank-rrf") .module("lang-painless") .module("x-pack-inference") + .systemProperty("tests.seed", System.getProperty("tests.seed")) + .setting("xpack.security.enabled", "false") + .setting("xpack.security.http.ssl.enabled", "false") .setting("xpack.license.self_generated.type", "trial") .plugin("inference-service-test") + .distribution(DistributionType.DEFAULT) .build(); public RRFRankClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/20_linear_retriever_simplified.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/20_linear_retriever_simplified.yml index 01cfa218c918d..dea4608c13dd1 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/20_linear_retriever_simplified.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/20_linear_retriever_simplified.yml @@ -268,6 +268,8 @@ setup: --- "Can query sparse vector fields": - do: + headers: + Content-Type: application/json search: index: test-index body: diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/310_rrf_retriever_simplified.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/310_rrf_retriever_simplified.yml new file mode 100644 index 0000000000000..a4b36be470481 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/310_rrf_retriever_simplified.yml @@ -0,0 +1,336 @@ +setup: + - requires: + cluster_features: [ "rrf_retriever.multi_fields_query_format_support" ] + reason: "RRF retriever multi-fields query format support" + test_runner_features: [ "contains" ] + + - do: + inference.put: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 128, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: test-index + body: + mappings: + properties: + keyword: + type: keyword + dense_inference: + type: semantic_text + inference_id: dense-inference-id + sparse_inference: + type: semantic_text + inference_id: sparse-inference-id + text_1: + type: text + text_2: + type: text + timestamp: + type: date + dense_vector: + type: dense_vector + dims: 1 + index: true + similarity: l2_norm + index_options: + type: flat + sparse_vector: + type: sparse_vector + + - do: + bulk: + index: test-index + refresh: true + body: | + {"index": {"_id": "1"}} + { + "keyword": "keyword match 1", + "dense_inference": "you know", + "sparse_inference": "for testing", + "text_1": "foo match 1", + "text_2": "x match 2", + "timestamp": "2000-03-30", + "dense_vector": [1], + "sparse_vector": { + "foo": 1.0 + } + } + {"index": {"_id": "2"}} + { + "keyword": "keyword match 2", + "dense_inference": "ElasticSearch is an open source", + "sparse_inference": "distributed, RESTful, search engine", + "text_1": "bar match 3", + "text_2": "y match 4", + "timestamp": "2010-02-08", + "dense_vector": [2], + "sparse_vector": { + "bar": 1.0 + } + } + {"index": {"_id": "3"}} + { + "keyword": "keyword match 3", + "dense_inference": "which is built on top of Lucene internally", + "sparse_inference": "and enjoys all the features it provides", + "text_1": "baz match 5", + "text_2": "z match 6", + "timestamp": "2024-08-08", + "dense_vector": [3], + "sparse_vector": { + "baz": 1.0 + } + } + +--- +"Query all fields using the simplified format": + - do: + search: + index: test-index + body: + retriever: + rrf: + query: "match" + + - match: { hits.total.value: 3 } + - length: { hits.hits: 3 } + - match: { hits.hits.0._id: "2" } + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.2._id: "3" } + +--- +"Per-field boosting is not supported": + - do: + catch: bad_request + search: + index: test-index + body: + retriever: + rrf: + fields: [ "text_1", "text_2^3" ] + query: "foo" + + - match: { error.root_cause.0.reason: "[rrf] does not support per-field weights in [fields]" } + +--- +"Can query keyword fields": + - do: + search: + index: test-index + body: + retriever: + rrf: + fields: [ "keyword" ] + query: "keyword match 1" + + - match: { hits.total.value: 1 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._id: "1" } + +--- +"Can query date fields": + - do: + search: + index: test-index + body: + retriever: + rrf: + fields: [ "timestamp" ] + query: "2010-02-08" + + - match: { hits.total.value: 1 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._id: "2" } + +--- +"Can query sparse vector fields": + - do: + search: + index: test-index + body: + retriever: + rrf: + fields: [ "sparse_vector" ] + query: "foo" + + - match: { hits.total.value: 1 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._id: "1" } + +--- +"Cannot query dense vector fields": + - do: + catch: bad_request + search: + index: test-index + body: + retriever: + rrf: + fields: [ "dense_vector" ] + query: "foo" + + - contains: { error.root_cause.0.reason: "[rrf] search failed - retrievers '[standard]' returned errors" } + - contains: { error.root_cause.0.suppressed.0.failed_shards.0.reason.reason: "Field [dense_vector] of type [dense_vector] does not support match queries" } + +--- +"Filters are propagated": + - do: + search: + index: test-index + body: + retriever: + rrf: + query: "match" + filter: + - term: + keyword: "keyword match 1" + + - match: { hits.total.value: 1 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._id: "1" } + +--- +"Wildcard index patterns that do not resolve to any index are handled gracefully": + - do: + search: + index: wildcard-* + body: + retriever: + rrf: + query: "match" + + - match: { hits.total.value: 0 } + - length: { hits.hits: 0 } + +--- +"Multi-index searches are not allowed": + - do: + indices.create: + index: test-index-2 + + - do: + catch: bad_request + search: + index: [ test-index, test-index-2 ] + body: + retriever: + rrf: + query: "match" + + - match: { error.root_cause.0.reason: "[rrf] cannot specify [query] when querying multiple indices" } + + - do: + indices.put_alias: + index: test-index + name: test-alias + - do: + indices.put_alias: + index: test-index-2 + name: test-alias + + - do: + catch: bad_request + search: + index: test-alias + body: + retriever: + rrf: + query: "match" + + - match: { error.root_cause.0.reason: "[rrf] cannot specify [query] when querying multiple indices" } + +--- +"Wildcard field patterns that do not resolve to any field are handled gracefully": + - do: + search: + index: test-index + body: + retriever: + rrf: + fields: [ "wildcard-*" ] + query: "match" + + - match: { hits.total.value: 0 } + - length: { hits.hits: 0 } + +--- +"Cannot mix simplified query format with custom sub-retrievers": + - do: + catch: bad_request + search: + index: test-index + body: + retriever: + rrf: + query: "foo" + retrievers: + - standard: + query: + match: + keyword: "bar" + + - contains: { error.root_cause.0.reason: "[rrf] cannot combine [retrievers] and [query]" } + +--- +"Missing required params": + - do: + catch: bad_request + search: + index: test-index + body: + retriever: + rrf: + fields: [ "text_1", "text_2" ] + + - contains: { error.root_cause.0.reason: "[rrf] [query] must be provided when [fields] is specified" } + + - do: + catch: bad_request + search: + index: test-index + body: + retriever: + rrf: + fields: [ "text_1", "text_2" ] + query: "" + + - contains: { error.root_cause.0.reason: "[rrf] [query] cannot be empty" } + + - do: + catch: bad_request + search: + index: test-index + body: + retriever: + rrf: {} + + - contains: { error.root_cause.0.reason: "[rrf] must provide [retrievers] or [query]" }