From 977b7ea2baf32b586da8c2fa62abfe256962c702 Mon Sep 17 00:00:00 2001 From: ManasviGoyal Date: Mon, 18 Aug 2025 20:59:14 -0700 Subject: [PATCH] Make flatVectorsFormat injectable for custom format and scorers Signed-off-by: ManasviGoyal --- .../lucene99/Lucene99HnswVectorsFormat.java | 64 ++++++++++++++++++- 1 file changed, 61 insertions(+), 3 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java index baf15174704e..837ec9837a85 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java @@ -18,6 +18,7 @@ package org.apache.lucene.codecs.lucene99; import java.io.IOException; +import java.util.Objects; import java.util.concurrent.ExecutorService; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; @@ -132,8 +133,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { private final int beamWidth; /** The format for storing, reading, and merging vectors on disk. */ - private static final FlatVectorsFormat flatVectorsFormat = - new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + private final FlatVectorsFormat flatVectorsFormat; private final int numMergeWorkers; private final TaskExecutor mergeExec; @@ -168,7 +168,13 @@ public Lucene99HnswVectorsFormat(int maxConn, int beamWidth) { */ public Lucene99HnswVectorsFormat( int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { - this(maxConn, beamWidth, numMergeWorkers, mergeExec, VERSION_CURRENT); + this( + maxConn, + beamWidth, + numMergeWorkers, + mergeExec, + VERSION_CURRENT, + new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer())); } /** @@ -190,6 +196,56 @@ public Lucene99HnswVectorsFormat( int numMergeWorkers, ExecutorService mergeExec, int writeVersion) { + this( + maxConn, + beamWidth, + numMergeWorkers, + mergeExec, + writeVersion, + new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer())); + } + + /** + * Constructs a format using the given graph construction parameters. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If + * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec + * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are + * generated by this format to do the merge. If null, the configured {@link + * MergeScheduler#getIntraMergeExecutor(MergePolicy.OneMerge)} is used. + * @param flatVectorsFormat the format used to store vectors on disk + */ + public Lucene99HnswVectorsFormat( + int maxConn, + int beamWidth, + int numMergeWorkers, + ExecutorService mergeExec, + FlatVectorsFormat flatVectorsFormat) { + this(maxConn, beamWidth, numMergeWorkers, mergeExec, VERSION_CURRENT, flatVectorsFormat); + } + + /** + * Constructs a format using the given graph construction parameters. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If + * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec + * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are + * generated by this format to do the merge. If null, the configured {@link + * MergeScheduler#getIntraMergeExecutor(MergePolicy.OneMerge)} is used. + * @param writeVersion the version used for the writer to encode docID's (VarInt=0, GroupVarInt=1) + * @param flatVectorsFormat the format used to store vectors on disk + */ + Lucene99HnswVectorsFormat( + int maxConn, + int beamWidth, + int numMergeWorkers, + ExecutorService mergeExec, + int writeVersion, + FlatVectorsFormat flatVectorsFormat) { super("Lucene99HnswVectorsFormat"); if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { throw new IllegalArgumentException( @@ -218,6 +274,8 @@ public Lucene99HnswVectorsFormat( } else { this.mergeExec = null; } + // fail fast if caller forgot to supply a FlatVectorsFormat + this.flatVectorsFormat = Objects.requireNonNull(flatVectorsFormat, "flatVectorsFormat"); } @Override