-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: zhichao-aws <[email protected]>
- Loading branch information
1 parent
b084838
commit f1a1765
Showing
10 changed files
with
440 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
94 changes: 94 additions & 0 deletions
94
src/main/java/org/opensearch/neuralsearch/analysis/DJLUtils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.analysis;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.util.Utils;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Callable;

/**
 * Utility methods for working with DJL (Deep Java Library) inside the plugin:
 * configuring the DJL cache location, running actions under the DJL context
 * (privileged block + DJL class loader), and fetching tokenizer artifacts
 * from the Hugging Face hub.
 */
public class DJLUtils {
    // Set once at plugin bootstrap via buildDJLCachePath(); read by withDJLContext().
    private static Path ML_CACHE_PATH;
    private static final String ML_CACHE_DIR_NAME = "ml_cache";
    private static final String HUGGING_FACE_BASE_URL = "https://huggingface.co/";
    private static final String HUGGING_FACE_RESOLVE_PATH = "resolve/main/";

    private DJLUtils() {
        // utility class — no instances
    }

    /**
     * Derives the DJL cache directory from the OpenSearch data folder.
     *
     * @param opensearchDataFolder the node's data path
     */
    public static void buildDJLCachePath(Path opensearchDataFolder) {
        // the logic to build cache path is consistent with ml-commons plugin
        // see
        // https://github.com/opensearch-project/ml-commons/blob/14b971214c488aa3f4ab150d1a6cc379df1758be/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java#L53
        ML_CACHE_PATH = opensearchDataFolder.resolve(ML_CACHE_DIR_NAME);
    }

    /**
     * Runs {@code action} in a privileged block with DJL-specific system
     * properties set and the thread context class loader switched to DJL's,
     * restoring the original context class loader afterwards.
     *
     * @param action the work to perform under the DJL context
     * @return the action's result
     * @throws PrivilegedActionException if the action throws a checked exception
     */
    public static <T> T withDJLContext(Callable<T> action) throws PrivilegedActionException {
        return AccessController.doPrivileged((PrivilegedExceptionAction<T>) () -> {
            ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
            try {
                System.setProperty("java.library.path", ML_CACHE_PATH.toAbsolutePath().toString());
                System.setProperty("DJL_CACHE_DIR", ML_CACHE_PATH.toAbsolutePath().toString());
                Thread.currentThread().setContextClassLoader(ai.djl.Model.class.getClassLoader());

                return action.call();
            } finally {
                Thread.currentThread().setContextClassLoader(contextClassLoader);
            }
        });
    }

    /**
     * Creates a {@link HuggingFaceTokenizer} for the given hub id under the DJL context.
     *
     * @param tokenizerId Hugging Face tokenizer id, e.g. {@code org/model-name}
     * @throws RuntimeException if tokenizer initialization fails (original cause preserved)
     */
    public static HuggingFaceTokenizer buildHuggingFaceTokenizer(String tokenizerId) {
        try {
            return withDJLContext(() -> HuggingFaceTokenizer.newInstance(tokenizerId));
        } catch (PrivilegedActionException e) {
            // Keep the underlying exception as the cause rather than flattening it into the message.
            throw new RuntimeException("Failed to initialize Hugging Face tokenizer: " + tokenizerId, e);
        }
    }

    /**
     * Parses a UTF-8, tab-separated {@code token\tweight} stream into a map.
     * Blank lines are skipped; any other malformed line is rejected.
     *
     * @param inputStream the stream to consume; closed by this method
     * @return map from token to weight
     * @throws IllegalArgumentException on a line that is not exactly two tab-separated fields
     * @throws RuntimeException if reading the stream fails (original cause preserved)
     */
    public static Map<String, Float> parseInputStreamToTokenWeights(InputStream inputStream) {
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            Map<String, Float> tokenWeights = new HashMap<>();
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] parts = line.split("\t");
                if (parts.length != 2) {
                    throw new IllegalArgumentException("Invalid line in token weights file: " + line);
                }
                tokenWeights.put(parts[0], Float.parseFloat(parts[1]));
            }
            return tokenWeights;
        } catch (IOException e) {
            // Keep the underlying exception as the cause rather than flattening it into the message.
            throw new RuntimeException("Failed to parse token weights file.", e);
        }
    }

    /**
     * Downloads a token-weights file from the Hugging Face hub and parses it.
     *
     * @param tokenizerId Hugging Face repository id
     * @param fileName    name of the weights file within the repository
     * @return map from token to weight
     * @throws RuntimeException if the download or parse fails
     */
    public static Map<String, Float> fetchTokenWeights(String tokenizerId, String fileName) {
        String url = HUGGING_FACE_BASE_URL + tokenizerId + "/" + HUGGING_FACE_RESOLVE_PATH + fileName;

        InputStream inputStream;
        try {
            inputStream = withDJLContext(() -> Utils.openUrl(url));
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Failed to download file from " + url, e);
        }

        // parseInputStreamToTokenWeights closes the stream via its reader.
        return parseInputStreamToTokenWeights(inputStream);
    }
}
29 changes: 29 additions & 0 deletions
29
src/main/java/org/opensearch/neuralsearch/analysis/HFModelAnalyzer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.analysis;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.Tokenizer;

import java.util.function.Supplier;

/**
 * Lucene {@link Analyzer} whose token stream is produced by a Hugging Face
 * model tokenizer. A {@link Supplier} is held (rather than a single
 * {@link Tokenizer}) so each analysis chain gets its own tokenizer instance.
 */
public class HFModelAnalyzer extends Analyzer {
    public static final String NAME = "hf_model_tokenizer";
    // final: assigned only in constructors, never reassigned.
    final Supplier<Tokenizer> tokenizerSupplier;

    /** Creates an analyzer backed by the plugin's default model tokenizer. */
    public HFModelAnalyzer() {
        this.tokenizerSupplier = HFModelTokenizerFactory::createDefault;
    }

    /** Package-private: used by the analyzer provider to supply a custom tokenizer. */
    HFModelAnalyzer(Supplier<Tokenizer> tokenizerSupplier) {
        this.tokenizerSupplier = tokenizerSupplier;
    }

    @Override
    protected TokenStreamComponents createComponents(String fieldName) {
        final Tokenizer src = tokenizerSupplier.get();
        // The tokenizer is both the source and the full stream: no extra filters.
        return new TokenStreamComponents(src, src);
    }
}
25 changes: 25 additions & 0 deletions
25
src/main/java/org/opensearch/neuralsearch/analysis/HFModelAnalyzerProvider.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.analysis; | ||
|
||
import org.opensearch.common.settings.Settings; | ||
import org.opensearch.env.Environment; | ||
import org.opensearch.index.IndexSettings; | ||
import org.opensearch.index.analysis.AbstractIndexAnalyzerProvider; | ||
|
||
public class HFModelAnalyzerProvider extends AbstractIndexAnalyzerProvider<HFModelAnalyzer> { | ||
private final HFModelAnalyzer analyzer; | ||
|
||
public HFModelAnalyzerProvider(IndexSettings indexSettings, Environment environment, String name, Settings settings) { | ||
super(indexSettings, name, settings); | ||
HFModelTokenizerFactory tokenizerFactory = new HFModelTokenizerFactory(indexSettings, environment, name, settings); | ||
analyzer = new HFModelAnalyzer(tokenizerFactory::create); | ||
} | ||
|
||
@Override | ||
public HFModelAnalyzer get() { | ||
return analyzer; | ||
} | ||
} |
107 changes: 107 additions & 0 deletions
107
src/main/java/org/opensearch/neuralsearch/analysis/HFModelTokenizer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.analysis;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.Objects;

import com.google.common.io.CharStreams;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import org.apache.lucene.util.BytesRef;

/**
 * Lucene {@link Tokenizer} that delegates to a DJL {@link HuggingFaceTokenizer}.
 * The whole input is encoded once in {@link #reset()}, then tokens are emitted
 * one per {@link #incrementToken()} call, walking the primary {@link Encoding}
 * followed by each of its overflowing segments. When a token-weight map is
 * provided, each token's weight is written to the payload attribute as a
 * 4-byte big-endian float (used by neural sparse queries).
 */
public class HFModelTokenizer extends Tokenizer {
    public static final String NAME = "hf_model_tokenizer";
    // Weight used when a token is missing from the tokenWeights map.
    private static final Float DEFAULT_TOKEN_WEIGHT = 1.0f;

    private final CharTermAttribute termAtt;
    private final PayloadAttribute payloadAtt; // null when no weights were supplied
    private final OffsetAttribute offsetAtt;
    private final HuggingFaceTokenizer tokenizer;
    private final Map<String, Float> tokenWeights; // nullable; null disables payloads

    // State rebuilt by reset(): the current encoding and cursor positions.
    private Encoding encoding;
    // Index of the next token within the current encoding segment.
    private int tokenIdx = 0;
    // Index into encoding.getOverflowing(); reset() sets it to -1,
    // meaning "still on the primary encoding, not an overflow segment".
    private int overflowingIdx = 0;

    /** Tokenizer without payloads: no token-weight map. */
    public HFModelTokenizer(HuggingFaceTokenizer huggingFaceTokenizer) {
        this(huggingFaceTokenizer, null);
    }

    /**
     * @param huggingFaceTokenizer the DJL tokenizer to delegate to
     * @param weights              optional token-to-weight map; when non-null,
     *                             a payload attribute is registered and filled
     */
    public HFModelTokenizer(HuggingFaceTokenizer huggingFaceTokenizer, Map<String, Float> weights) {
        termAtt = addAttribute(CharTermAttribute.class);
        offsetAtt = addAttribute(OffsetAttribute.class);
        if (Objects.nonNull(weights)) {
            payloadAtt = addAttribute(PayloadAttribute.class);
        } else {
            payloadAtt = null;
        }
        tokenizer = huggingFaceTokenizer;
        tokenWeights = weights;
    }

    @Override
    public void reset() throws IOException {
        super.reset();
        tokenIdx = 0;
        overflowingIdx = -1; // -1 == iterating the primary encoding
        // Read the entire input reader up front; the HF tokenizer needs the full text.
        String inputStr = CharStreams.toString(input);
        encoding = tokenizer.encode(inputStr, false, true);
    }

    // True when idx is past the segment's tokens, or the token at idx is
    // masked out (attention mask 0, e.g. padding).
    private static boolean isLastTokenInEncodingSegment(int idx, Encoding encodingSegment) {
        return idx >= encodingSegment.getTokens().length || encodingSegment.getAttentionMask()[idx] == 0;
    }

    /** Serializes a float to the 4-byte big-endian form stored in payloads. */
    public static byte[] floatToBytes(float value) {
        return ByteBuffer.allocate(4).putFloat(value).array();
    }

    /** Inverse of {@link #floatToBytes(float)}. */
    public static float bytesToFloat(byte[] bytes) {
        return ByteBuffer.wrap(bytes).getFloat();
    }

    @Override
    final public boolean incrementToken() throws IOException {
        clearAttributes();
        // overflowingIdx == -1 means we are still in the primary encoding;
        // otherwise we are in overflow segment #overflowingIdx.
        Encoding curEncoding = overflowingIdx == -1 ? encoding : encoding.getOverflowing()[overflowingIdx];

        while (!isLastTokenInEncodingSegment(tokenIdx, curEncoding) || overflowingIdx < encoding.getOverflowing().length) {
            if (isLastTokenInEncodingSegment(tokenIdx, curEncoding)) {
                // reset cur segment, go to the next segment
                // until overflowingIdx = encoding.getOverflowing().length
                tokenIdx = 0;
                overflowingIdx++;
                if (overflowingIdx >= encoding.getOverflowing().length) {
                    return false;
                }
                curEncoding = encoding.getOverflowing()[overflowingIdx];
            } else {
                // Emit the token at tokenIdx: term text, char offsets, and
                // (when weights are configured) the weight as a payload.
                termAtt.append(curEncoding.getTokens()[tokenIdx]);
                offsetAtt.setOffset(
                    curEncoding.getCharTokenSpans()[tokenIdx].getStart(),
                    curEncoding.getCharTokenSpans()[tokenIdx].getEnd()
                );
                if (Objects.nonNull(tokenWeights)) {
                    // for neural sparse query, write the token weight to payload field
                    payloadAtt.setPayload(
                        new BytesRef(floatToBytes(tokenWeights.getOrDefault(curEncoding.getTokens()[tokenIdx], DEFAULT_TOKEN_WEIGHT)))
                    );
                }
                tokenIdx++;
                return true;
            }
        }

        return false;
    }
}
65 changes: 65 additions & 0 deletions
65
src/main/java/org/opensearch/neuralsearch/analysis/HFModelTokenizerFactory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.analysis;

import org.apache.lucene.analysis.Tokenizer;
import org.opensearch.common.settings.Settings;
import org.opensearch.env.Environment;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.analysis.AbstractTokenizerFactory;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

import java.util.Map;
import java.util.Objects;

/**
 * Factory for {@link HFModelTokenizer}. Loads the Hugging Face tokenizer (and
 * an optional token-weight file) once at construction; {@link #create()} then
 * cheaply wraps them in a new tokenizer instance per analyze request.
 */
public class HFModelTokenizerFactory extends AbstractTokenizerFactory {
    private final HuggingFaceTokenizer tokenizer;
    // Nullable: only set when the "token_weights_file" setting is present.
    private final Map<String, Float> tokenWeights;

    /**
     * Initialization-on-demand holder: the default tokenizer and its token
     * weights are loaded atomically and lazily, the first time
     * {@link #createDefault()} touches this class.
     */
    private static class DefaultTokenizerHolder {
        private static final String DEFAULT_TOKENIZER_ID = "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill";
        private static final String DEFAULT_TOKEN_WEIGHTS_FILE = "query_token_weights.txt";

        static final HuggingFaceTokenizer TOKENIZER;
        static final Map<String, Float> TOKEN_WEIGHTS;

        static {
            try {
                TOKENIZER = DJLUtils.buildHuggingFaceTokenizer(DEFAULT_TOKENIZER_ID);
                TOKEN_WEIGHTS = DJLUtils.fetchTokenWeights(DEFAULT_TOKENIZER_ID, DEFAULT_TOKEN_WEIGHTS_FILE);
            } catch (Exception e) {
                throw new RuntimeException("Failed to initialize default hf_model_tokenizer", e);
            }
        }
    }

    /** Returns a new tokenizer backed by the shared default model and weights. */
    public static Tokenizer createDefault() {
        return new HFModelTokenizer(DefaultTokenizerHolder.TOKENIZER, DefaultTokenizerHolder.TOKEN_WEIGHTS);
    }

    /**
     * Builds a factory from index settings. Requires the {@code tokenizer_id}
     * setting; {@code token_weights_file} is optional.
     *
     * @throws NullPointerException if {@code tokenizer_id} is not configured
     */
    public HFModelTokenizerFactory(IndexSettings indexSettings, Environment environment, String name, Settings settings) {
        // For custom tokenizer, the factory is created during IndexModule.newIndexService
        // And can be accessed via indexService.getIndexAnalyzers()
        super(indexSettings, settings, name);
        String tokenizerId = settings.get("tokenizer_id", null);
        Objects.requireNonNull(tokenizerId, "tokenizer_id is required");
        String tokenWeightsFileName = settings.get("token_weights_file", null);
        tokenizer = DJLUtils.buildHuggingFaceTokenizer(tokenizerId);
        tokenWeights = tokenWeightsFileName != null ? DJLUtils.fetchTokenWeights(tokenizerId, tokenWeightsFileName) : null;
    }

    @Override
    public Tokenizer create() {
        // the create method will be called for every single analyze request
        return new HFModelTokenizer(tokenizer, tokenWeights);
    }
}
Oops, something went wrong.