diff --git a/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/BytesRefFieldComparatorSource.java b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/BytesRefFieldComparatorSource.java index 1a12e4c7733a9..d085c78cfb6f2 100644 --- a/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/BytesRefFieldComparatorSource.java +++ b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/BytesRefFieldComparatorSource.java @@ -15,7 +15,6 @@ import org.apache.lucene.index.SortedSetDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.FieldComparator; -import org.apache.lucene.search.LeafFieldComparator; import org.apache.lucene.search.Pruning; import org.apache.lucene.search.Scorable; import org.apache.lucene.search.SortField; @@ -70,6 +69,20 @@ protected SortedBinaryDocValues getValues(LeafReaderContext context) throws IOEx protected void setScorer(LeafReaderContext context, Scorable scorer) {} + protected BinaryDocValues getBinaryDocValues(LeafReaderContext context, BytesRef missingBytes, SortedBinaryDocValues values) + throws IOException { + final BinaryDocValues selectedValues; + if (nested == null) { + selectedValues = sortMode.select(values, missingBytes); + } else { + final BitSet rootDocs = nested.rootDocs(context); + final DocIdSetIterator innerDocs = nested.innerDocs(context); + final int maxChildren = nested.getNestedSort() != null ? nested.getNestedSort().getMaxChildren() : Integer.MAX_VALUE; + selectedValues = sortMode.select(values, missingBytes, rootDocs, innerDocs, maxChildren); + } + return selectedValues; + } + @Override public FieldComparator newComparator(String fieldname, int numHits, Pruning enableSkipping, boolean reversed) { assert indexFieldData == null || fieldname.equals(indexFieldData.getFieldName()); @@ -102,61 +115,22 @@ protected SortedDocValues getSortedDocValues(LeafReaderContext context, String f }; } + return newComparatorWithoutOrdinal(fieldname, numHits, enableSkipping, reversed, missingBytes, sortMissingLast); + } + protected FieldComparator newComparatorWithoutOrdinal( + String fieldname, + int numHits, + Pruning enableSkipping, + boolean reversed, + BytesRef missingBytes, + boolean sortMissingLast + ) { return new FieldComparator.TermValComparator(numHits, null, sortMissingLast) { @Override protected BinaryDocValues getBinaryDocValues(LeafReaderContext context, String field) throws IOException { - final SortedBinaryDocValues values = getValues(context); - final BinaryDocValues selectedValues; - if (nested == null) { - selectedValues = sortMode.select(values, missingBytes); - } else { - final BitSet rootDocs = nested.rootDocs(context); - final DocIdSetIterator innerDocs = nested.innerDocs(context); - final int maxChildren = nested.getNestedSort() != null ? nested.getNestedSort().getMaxChildren() : Integer.MAX_VALUE; - selectedValues = sortMode.select(values, missingBytes, rootDocs, innerDocs, maxChildren); - } - return selectedValues; - } - - @Override - public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException { - LeafFieldComparator leafComparator = super.getLeafComparator(context); - // TopFieldCollector interacts with inter-segment concurrency by creating a FieldValueHitQueue per slice, each one with a - // specific instance of the FieldComparator. This ensures sequential execution across LeafFieldComparators returned by - // the same parent FieldComparator. That allows for effectively sharing the same instance of leaf comparator, like in this - // case in the Lucene code. That's fine dealing with sorting by field, but not when using script sorting, because we then - // need to set to Scorer to the specific leaf comparator, to make the _score variable available in sort scripts. The - // setScorer call happens concurrently across slices and needs to target the specific leaf context that is being searched. - return new LeafFieldComparator() { - @Override - public void setBottom(int slot) throws IOException { - leafComparator.setBottom(slot); - } - - @Override - public int compareBottom(int doc) throws IOException { - return leafComparator.compareBottom(doc); - } - - @Override - public int compareTop(int doc) throws IOException { - return leafComparator.compareTop(doc); - } - - @Override - public void copy(int slot, int doc) throws IOException { - leafComparator.copy(slot, doc); - } - - @Override - public void setScorer(Scorable scorer) { - // this ensures that the scorer is set for the specific leaf comparator - // corresponding to the leaf context we are scoring - BytesRefFieldComparatorSource.this.setScorer(context, scorer); - } - }; + return BytesRefFieldComparatorSource.this.getBinaryDocValues(context, missingBytes, getValues(context)); } }; } diff --git a/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/DoubleValuesComparatorSource.java b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/DoubleValuesComparatorSource.java index c5fcb0207ce4d..f047d494d625f 100644 --- a/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/DoubleValuesComparatorSource.java +++ b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/DoubleValuesComparatorSource.java @@ -60,7 +60,11 @@ protected SortedNumericDoubleValues getValues(LeafReaderContext context) throws } private NumericDoubleValues getNumericDocValues(LeafReaderContext context, double missingValue) throws IOException { - final SortedNumericDoubleValues values = getValues(context); + return getNumericDocValues(context, missingValue, getValues(context)); + } + + protected NumericDoubleValues getNumericDocValues(LeafReaderContext context, double missingValue, SortedNumericDoubleValues values) + throws IOException { if (nested == null) { return FieldData.replaceMissing(sortMode.select(values), missingValue); } else { @@ -78,6 +82,10 @@ public FieldComparator newComparator(String fieldname, int numHits, Pruning e assert indexFieldData == null || fieldname.equals(indexFieldData.getFieldName()); final double dMissingValue = (Double) missingObject(missingValue, reversed); + return newComparator(numHits, enableSkipping, reversed, dMissingValue); + } + + protected FieldComparator newComparator(int numHits, Pruning enableSkipping, boolean reversed, double dMissingValue) { // NOTE: it's important to pass null as a missing value in the constructor so that // the comparator doesn't check docsWithField since we replace missing values in select() return new DoubleComparator(numHits, null, null, reversed, Pruning.NONE) { diff --git a/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/ScriptDoubleValuesComparatorSource.java b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/ScriptDoubleValuesComparatorSource.java new file mode 100644 index 0000000000000..404175c1484d3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/ScriptDoubleValuesComparatorSource.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.fielddata.fieldcomparator; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.LeafFieldComparator; +import org.apache.lucene.search.Pruning; +import org.apache.lucene.search.Scorable; +import org.apache.lucene.search.comparators.DoubleComparator; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.core.CheckedFunction; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.fielddata.FieldData; +import org.elasticsearch.index.fielddata.IndexNumericFieldData; +import org.elasticsearch.index.fielddata.NumericDoubleValues; +import org.elasticsearch.index.fielddata.SortedNumericDoubleValues; +import org.elasticsearch.script.NumberSortScript; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.MultiValueMode; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.io.IOException; + +/** + * Script comparator source for double values. + */ +public class ScriptDoubleValuesComparatorSource extends DoubleValuesComparatorSource { + + private final CheckedFunction scriptSupplier; + + public ScriptDoubleValuesComparatorSource( + CheckedFunction scriptSupplier, + IndexNumericFieldData indexFieldData, + @Nullable Object missingValue, + MultiValueMode sortMode, + Nested nested + ) { + super(indexFieldData, missingValue, sortMode, nested); + this.scriptSupplier = scriptSupplier; + } + + private SortedNumericDoubleValues getValues(NumberSortScript leafScript) throws IOException { + final NumericDoubleValues values = new NumericDoubleValues() { + @Override + public boolean advanceExact(int doc) { + leafScript.setDocument(doc); + return true; + } + + @Override + public double doubleValue() { + return leafScript.execute(); + } + }; + return FieldData.singleton(values); + } + + private NumericDoubleValues getNumericDocValues(LeafReaderContext context, double missingValue, NumberSortScript leafScript) + throws IOException { + return getNumericDocValues(context, missingValue, getValues(leafScript)); + } + + @Override + protected FieldComparator newComparator(int numHits, Pruning enableSkipping, boolean reversed, double dMissingValue) { + // NOTE: it's important to pass null as a missing value in the constructor so that + // the comparator doesn't check docsWithField since we replace missing values in select() + return new DoubleComparator(numHits, null, null, reversed, Pruning.NONE) { + @Override + public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException { + NumberSortScript leafScript = scriptSupplier.apply(context); + return new DoubleLeafComparator(context) { + @Override + protected NumericDocValues getNumericDocValues(LeafReaderContext context, String field) throws IOException { + return ScriptDoubleValuesComparatorSource.this.getNumericDocValues(context, dMissingValue, leafScript) + .getRawDoubleValues(); + } + + @Override + public void setScorer(Scorable scorer) { + leafScript.setScorer(scorer); + } + }; + } + }; + } + + @Override + public BucketedSort newBucketedSort( + BigArrays bigArrays, + SortOrder sortOrder, + DocValueFormat format, + int bucketSize, + BucketedSort.ExtraData extra + ) { + return new BucketedSort.ForDoubles(bigArrays, sortOrder, format, bucketSize, extra) { + private final double dMissingValue = (Double) missingObject(missingValue, sortOrder == SortOrder.DESC); + + @Override + public Leaf forLeaf(LeafReaderContext ctx) throws IOException { + NumberSortScript leafScript = scriptSupplier.apply(ctx); + return new Leaf(ctx) { + private final NumericDoubleValues docValues = getNumericDocValues(ctx, dMissingValue, leafScript); + private double docValue; + + @Override + protected boolean advanceExact(int doc) throws IOException { + if (docValues.advanceExact(doc)) { + docValue = docValues.doubleValue(); + return true; + } + return false; + } + + @Override + protected double docValue() { + return docValue; + } + }; + } + }; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/ScriptStringFieldComparatorSource.java b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/ScriptStringFieldComparatorSource.java new file mode 100644 index 0000000000000..35f2096ea4cda --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/ScriptStringFieldComparatorSource.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.fielddata.fieldcomparator; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.Pruning; +import org.apache.lucene.search.Scorable; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefBuilder; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.core.CheckedFunction; +import org.elasticsearch.index.fielddata.AbstractBinaryDocValues; +import org.elasticsearch.index.fielddata.FieldData; +import org.elasticsearch.index.fielddata.IndexFieldData; +import org.elasticsearch.index.fielddata.SortedBinaryDocValues; +import org.elasticsearch.script.StringSortScript; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.MultiValueMode; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.ScriptSortBuilder; +import org.elasticsearch.search.sort.SortOrder; + +import java.io.IOException; + +/** + * Script comparator source for string/binary values. + */ +public class ScriptStringFieldComparatorSource extends BytesRefFieldComparatorSource { + + final CheckedFunction scriptSupplier; + + public ScriptStringFieldComparatorSource( + CheckedFunction scriptSupplier, + IndexFieldData indexFieldData, + Object missingValue, + MultiValueMode sortMode, + Nested nested + ) { + super(indexFieldData, missingValue, sortMode, nested); + this.scriptSupplier = scriptSupplier; + } + + private SortedBinaryDocValues getValues(StringSortScript leafScript) throws IOException { + final BinaryDocValues values = new AbstractBinaryDocValues() { + final BytesRefBuilder spare = new BytesRefBuilder(); + + @Override + public boolean advanceExact(int doc) { + leafScript.setDocument(doc); + return true; + } + + @Override + public BytesRef binaryValue() { + spare.copyChars(leafScript.execute()); + return spare.get(); + } + }; + return FieldData.singleton(values); + } + + @Override + protected FieldComparator newComparatorWithoutOrdinal( + String fieldname, + int numHits, + Pruning enableSkipping, + boolean reversed, + BytesRef missingBytes, + boolean sortMissingLast + ) { + return new FieldComparator.TermValComparator(numHits, null, sortMissingLast) { + + StringSortScript leafScript; + + @Override + protected BinaryDocValues getBinaryDocValues(LeafReaderContext context, String field) throws IOException { + leafScript = scriptSupplier.apply(context); + return ScriptStringFieldComparatorSource.this.getBinaryDocValues(context, missingBytes, getValues(leafScript)); + } + + @Override + public void setScorer(Scorable scorer) { + leafScript.setScorer(scorer); + } + }; + } + + @Override + public BucketedSort newBucketedSort( + BigArrays bigArrays, + SortOrder sortOrder, + DocValueFormat format, + int bucketSize, + BucketedSort.ExtraData extra + ) { + throw new IllegalArgumentException( + "error building sort for [_script]: " + + "script sorting only supported on [numeric] scripts but was [" + + ScriptSortBuilder.ScriptSortType.STRING + + "]" + ); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/ScriptVersionFieldComparatorSource.java b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/ScriptVersionFieldComparatorSource.java new file mode 100644 index 0000000000000..11826fcc35faf --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/fielddata/fieldcomparator/ScriptVersionFieldComparatorSource.java @@ -0,0 +1,123 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.fielddata.fieldcomparator; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.Pruning; +import org.apache.lucene.search.Scorable; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.core.CheckedFunction; +import org.elasticsearch.index.fielddata.AbstractBinaryDocValues; +import org.elasticsearch.index.fielddata.FieldData; +import org.elasticsearch.index.fielddata.IndexFieldData; +import org.elasticsearch.index.fielddata.SortedBinaryDocValues; +import org.elasticsearch.script.BytesRefProducer; +import org.elasticsearch.script.BytesRefSortScript; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.MultiValueMode; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.ScriptSortBuilder; +import org.elasticsearch.search.sort.SortOrder; + +import java.io.IOException; + +/** + * Script comparator source for version values. + */ +public class ScriptVersionFieldComparatorSource extends BytesRefFieldComparatorSource { + + private final CheckedFunction scriptSupplier; + private final DocValueFormat scriptResultValueFormat; + + public ScriptVersionFieldComparatorSource( + CheckedFunction scriptSupplier, + DocValueFormat scriptResultValueFormat, + IndexFieldData indexFieldData, + Object missingValue, + MultiValueMode sortMode, + Nested nested + ) { + super(indexFieldData, missingValue, sortMode, nested); + this.scriptResultValueFormat = scriptResultValueFormat; + this.scriptSupplier = scriptSupplier; + } + + private SortedBinaryDocValues getValues(BytesRefSortScript leafScript) throws IOException { + final BinaryDocValues values = new AbstractBinaryDocValues() { + + @Override + public boolean advanceExact(int doc) { + leafScript.setDocument(doc); + return true; + } + + @Override + public BytesRef binaryValue() { + Object result = leafScript.execute(); + if (result == null) { + return null; + } + if (result instanceof BytesRefProducer) { + return ((BytesRefProducer) result).toBytesRef(); + } + + if (scriptResultValueFormat == null) { + throw new IllegalArgumentException("Invalid sort type: version"); + } + return scriptResultValueFormat.parseBytesRef(result); + } + }; + return FieldData.singleton(values); + } + + @Override + protected FieldComparator newComparatorWithoutOrdinal( + String fieldname, + int numHits, + Pruning enableSkipping, + boolean reversed, + BytesRef missingBytes, + boolean sortMissingLast + ) { + return new FieldComparator.TermValComparator(numHits, null, sortMissingLast) { + BytesRefSortScript leafScript; + + @Override + protected BinaryDocValues getBinaryDocValues(LeafReaderContext context, String field) throws IOException { + leafScript = scriptSupplier.apply(context); + return ScriptVersionFieldComparatorSource.this.getBinaryDocValues(context, missingBytes, getValues(leafScript)); + } + + @Override + public void setScorer(Scorable scorer) { + leafScript.setScorer(scorer); + } + }; + } + + @Override + public BucketedSort newBucketedSort( + BigArrays bigArrays, + SortOrder sortOrder, + DocValueFormat format, + int bucketSize, + BucketedSort.ExtraData extra + ) { + throw new IllegalArgumentException( + "error building sort for [_script]: " + + "script sorting only supported on [numeric] scripts but was [" + + ScriptSortBuilder.ScriptSortType.VERSION + + "]" + ); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/sort/ScriptSortBuilder.java b/server/src/main/java/org/elasticsearch/search/sort/ScriptSortBuilder.java index b3c88be60c179..b497dbae5de7b 100644 --- a/server/src/main/java/org/elasticsearch/search/sort/ScriptSortBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/sort/ScriptSortBuilder.java @@ -9,34 +9,23 @@ package org.elasticsearch.search.sort; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.Scorable; import org.apache.lucene.search.SortField; -import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.BytesRefBuilder; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.concurrent.ConcurrentCollections; -import org.elasticsearch.index.fielddata.AbstractBinaryDocValues; -import org.elasticsearch.index.fielddata.FieldData; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.fielddata.IndexFieldData.XFieldComparatorSource.Nested; -import org.elasticsearch.index.fielddata.NumericDoubleValues; -import org.elasticsearch.index.fielddata.SortedBinaryDocValues; -import org.elasticsearch.index.fielddata.SortedNumericDoubleValues; -import org.elasticsearch.index.fielddata.fieldcomparator.BytesRefFieldComparatorSource; -import org.elasticsearch.index.fielddata.fieldcomparator.DoubleValuesComparatorSource; +import org.elasticsearch.index.fielddata.fieldcomparator.ScriptDoubleValuesComparatorSource; +import org.elasticsearch.index.fielddata.fieldcomparator.ScriptStringFieldComparatorSource; +import org.elasticsearch.index.fielddata.fieldcomparator.ScriptVersionFieldComparatorSource; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.QueryShardException; import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.script.BytesRefProducer; import org.elasticsearch.script.BytesRefSortScript; import org.elasticsearch.script.DocValuesDocReader; import org.elasticsearch.script.NumberSortScript; @@ -53,7 +42,6 @@ import java.io.IOException; import java.util.Locale; -import java.util.Map; import java.util.Objects; import static org.elasticsearch.search.sort.FieldSortBuilder.validateMaxChildrenExistOnlyInTopLevelNestedSort; @@ -279,146 +267,37 @@ private IndexFieldData.XFieldComparatorSource fieldComparatorSource(SearchExecut case STRING -> { final StringSortScript.Factory factory = context.compile(script, StringSortScript.CONTEXT); final StringSortScript.LeafFactory searchScript = factory.newFactory(script.getParams()); - return new BytesRefFieldComparatorSource(null, null, valueMode, nested) { - final Map leafScripts = ConcurrentCollections.newConcurrentMap(); - - @Override - protected SortedBinaryDocValues getValues(LeafReaderContext context) throws IOException { - // we may see the same leaf context multiple times, and each time we need to refresh the doc values doc reader - StringSortScript leafScript = searchScript.newInstance(new DocValuesDocReader(searchLookup, context)); - leafScripts.put(context.id(), leafScript); - final BinaryDocValues values = new AbstractBinaryDocValues() { - final BytesRefBuilder spare = new BytesRefBuilder(); - - @Override - public boolean advanceExact(int doc) { - leafScript.setDocument(doc); - return true; - } - - @Override - public BytesRef binaryValue() { - spare.copyChars(leafScript.execute()); - return spare.get(); - } - }; - return FieldData.singleton(values); - } - - @Override - protected void setScorer(LeafReaderContext context, Scorable scorer) { - leafScripts.get(context.id()).setScorer(scorer); - } - - @Override - public BucketedSort newBucketedSort( - BigArrays bigArrays, - SortOrder sortOrder, - DocValueFormat format, - int bucketSize, - BucketedSort.ExtraData extra - ) { - throw new IllegalArgumentException( - "error building sort for [_script]: " - + "script sorting only supported on [numeric] scripts but was [" - + type - + "]" - ); - } - }; + return new ScriptStringFieldComparatorSource( + leafReaderContext -> searchScript.newInstance(new DocValuesDocReader(searchLookup, leafReaderContext)), + null, + null, + valueMode, + nested + ); } case NUMBER -> { final NumberSortScript.Factory numberSortFactory = context.compile(script, NumberSortScript.CONTEXT); // searchLookup is unnecessary here, as it's just used for expressions - final NumberSortScript.LeafFactory numberSortScriptFactory = numberSortFactory.newFactory(script.getParams(), searchLookup); - return new DoubleValuesComparatorSource(null, Double.MAX_VALUE, valueMode, nested) { - final Map leafScripts = ConcurrentCollections.newConcurrentMap(); - - @Override - protected SortedNumericDoubleValues getValues(LeafReaderContext context) throws IOException { - // we may see the same leaf context multiple times, and each time we need to refresh the doc values doc reader - NumberSortScript leafScript = numberSortScriptFactory.newInstance(new DocValuesDocReader(searchLookup, context)); - leafScripts.put(context.id(), leafScript); - final NumericDoubleValues values = new NumericDoubleValues() { - @Override - public boolean advanceExact(int doc) { - leafScript.setDocument(doc); - return true; - } - - @Override - public double doubleValue() { - return leafScript.execute(); - } - }; - return FieldData.singleton(values); - } - - @Override - protected void setScorer(LeafReaderContext context, Scorable scorer) { - leafScripts.get(context.id()).setScorer(scorer); - } - }; + final NumberSortScript.LeafFactory numberSortScript = numberSortFactory.newFactory(script.getParams(), searchLookup); + return new ScriptDoubleValuesComparatorSource( + leafReaderContext -> numberSortScript.newInstance(new DocValuesDocReader(searchLookup, leafReaderContext)), + null, + Double.MAX_VALUE, + valueMode, + nested + ); } case VERSION -> { final BytesRefSortScript.Factory factory = context.compile(script, BytesRefSortScript.CONTEXT); final BytesRefSortScript.LeafFactory searchScript = factory.newFactory(script.getParams()); - return new BytesRefFieldComparatorSource(null, null, valueMode, nested) { - final Map leafScripts = ConcurrentCollections.newConcurrentMap(); - - @Override - protected SortedBinaryDocValues getValues(LeafReaderContext context) throws IOException { - // we may see the same leaf context multiple times, and each time we need to refresh the doc values doc reader - BytesRefSortScript leafScript = searchScript.newInstance(new DocValuesDocReader(searchLookup, context)); - leafScripts.put(context.id(), leafScript); - final BinaryDocValues values = new AbstractBinaryDocValues() { - - @Override - public boolean advanceExact(int doc) { - leafScript.setDocument(doc); - return true; - } - - @Override - public BytesRef binaryValue() { - Object result = leafScript.execute(); - if (result == null) { - return null; - } - if (result instanceof BytesRefProducer) { - return ((BytesRefProducer) result).toBytesRef(); - } - - if (scriptResultValueFormat == null) { - throw new IllegalArgumentException("Invalid sort type: version"); - } - return scriptResultValueFormat.parseBytesRef(result); - } - }; - return FieldData.singleton(values); - } - - @Override - protected void setScorer(LeafReaderContext context, Scorable scorer) { - leafScripts.get(context.id()).setScorer(scorer); - } - - @Override - public BucketedSort newBucketedSort( - BigArrays bigArrays, - SortOrder sortOrder, - DocValueFormat format, - int bucketSize, - BucketedSort.ExtraData extra - ) { - throw new IllegalArgumentException( - "error building sort for [_script]: " - + "script sorting only supported on [numeric] scripts but was [" - + type - + "]" - ); - } - }; + return new ScriptVersionFieldComparatorSource( + leafReaderContext -> searchScript.newInstance(new DocValuesDocReader(searchLookup, leafReaderContext)), + scriptResultValueFormat, + null, + null, + valueMode, + nested + ); } default -> throw new QueryShardException(context, "custom script sort type [" + type + "] not supported"); } diff --git a/server/src/test/java/org/elasticsearch/search/sort/ScriptSortBuilderTests.java b/server/src/test/java/org/elasticsearch/search/sort/ScriptSortBuilderTests.java index 872775e18c7d1..7ac0db300f979 100644 --- a/server/src/test/java/org/elasticsearch/search/sort/ScriptSortBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/sort/ScriptSortBuilderTests.java @@ -15,8 +15,9 @@ import org.apache.lucene.search.TermQuery; import org.elasticsearch.index.fielddata.IndexFieldData.XFieldComparatorSource; import org.elasticsearch.index.fielddata.IndexFieldData.XFieldComparatorSource.Nested; -import org.elasticsearch.index.fielddata.fieldcomparator.BytesRefFieldComparatorSource; -import org.elasticsearch.index.fielddata.fieldcomparator.DoubleValuesComparatorSource; +import org.elasticsearch.index.fielddata.fieldcomparator.ScriptDoubleValuesComparatorSource; +import org.elasticsearch.index.fielddata.fieldcomparator.ScriptStringFieldComparatorSource; +import org.elasticsearch.index.fielddata.fieldcomparator.ScriptVersionFieldComparatorSource; import org.elasticsearch.index.mapper.NestedPathFieldMapper; import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; @@ -301,15 +302,15 @@ public void testMultiValueMode() throws IOException { public void testBuildCorrectComparatorType() throws IOException { ScriptSortBuilder sortBuilder = new ScriptSortBuilder(mockScript(MOCK_SCRIPT_NAME), ScriptSortType.STRING); SortField sortField = sortBuilder.build(createMockSearchExecutionContext()).field(); - assertThat(sortField.getComparatorSource(), instanceOf(BytesRefFieldComparatorSource.class)); + assertThat(sortField.getComparatorSource(), instanceOf(ScriptStringFieldComparatorSource.class)); sortBuilder = new ScriptSortBuilder(mockScript(MOCK_SCRIPT_NAME), ScriptSortType.NUMBER); sortField = sortBuilder.build(createMockSearchExecutionContext()).field(); - assertThat(sortField.getComparatorSource(), instanceOf(DoubleValuesComparatorSource.class)); + assertThat(sortField.getComparatorSource(), instanceOf(ScriptDoubleValuesComparatorSource.class)); sortBuilder = new ScriptSortBuilder(mockScript(MOCK_SCRIPT_NAME), ScriptSortType.VERSION); sortField = sortBuilder.build(createMockSearchExecutionContext()).field(); - assertThat(sortField.getComparatorSource(), instanceOf(BytesRefFieldComparatorSource.class)); + assertThat(sortField.getComparatorSource(), instanceOf(ScriptVersionFieldComparatorSource.class)); } /**