Support analyzer-based neural sparse query & build BERT tokenizer as pre-defined tokenizer #1088

Draft · wants to merge 5 commits into base: main

Changes from all commits
2 changes: 2 additions & 0 deletions build.gradle
@@ -259,6 +259,8 @@ dependencies {
api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}"
implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0'
implementation group: 'ai.djl', name: 'api', version: '0.28.0'
implementation group: 'ai.djl.huggingface', name: 'tokenizers', version: '0.28.0'
// ml-common excluded reflection for runtime so we need to add it by ourselves.
// https://github.com/opensearch-project/ml-commons/commit/464bfe34c66d7a729a00dd457f03587ea4e504d9
// TODO: Remove following three lines of dependencies if ml-common include them in their jar
94 changes: 94 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/analysis/DJLUtils.java
@@ -0,0 +1,94 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.analysis;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.util.Utils;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Callable;

public class DJLUtils {
private static Path ML_CACHE_PATH;
private static final String ML_CACHE_DIR_NAME = "ml_cache";
private static final String HUGGING_FACE_BASE_URL = "https://huggingface.co/";
private static final String HUGGING_FACE_RESOLVE_PATH = "resolve/main/";

public static void buildDJLCachePath(Path opensearchDataFolder) {
// the logic to build cache path is consistent with ml-commons plugin
// see
// https://github.com/opensearch-project/ml-commons/blob/14b971214c488aa3f4ab150d1a6cc379df1758be/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java#L53
ML_CACHE_PATH = opensearchDataFolder.resolve(ML_CACHE_DIR_NAME);
}

public static <T> T withDJLContext(Callable<T> action) throws PrivilegedActionException {
return AccessController.doPrivileged((PrivilegedExceptionAction<T>) () -> {
ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
try {
System.setProperty("java.library.path", ML_CACHE_PATH.toAbsolutePath().toString());
System.setProperty("DJL_CACHE_DIR", ML_CACHE_PATH.toAbsolutePath().toString());
Thread.currentThread().setContextClassLoader(ai.djl.Model.class.getClassLoader());

return action.call();
} finally {
Thread.currentThread().setContextClassLoader(contextClassLoader);
}
});
}

public static HuggingFaceTokenizer buildHuggingFaceTokenizer(String tokenizerId) {
try {
return withDJLContext(() -> HuggingFaceTokenizer.newInstance(tokenizerId));
} catch (PrivilegedActionException e) {
throw new RuntimeException("Failed to initialize Hugging Face tokenizer. " + e);
}
}

public static Map<String, Float> parseInputStreamToTokenWeights(InputStream inputStream) {
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
Map<String, Float> tokenWeights = new HashMap<>();
String line;
while ((line = reader.readLine()) != null) {
if (line.trim().isEmpty()) {
continue;
}
String[] parts = line.split("\t");
if (parts.length != 2) {
throw new IllegalArgumentException("Invalid line in token weights file: " + line);
}
String token = parts[0];
float weight = Float.parseFloat(parts[1]);
tokenWeights.put(token, weight);
}
return tokenWeights;
} catch (IOException e) {
throw new RuntimeException("Failed to parse token weights file. " + e);
}
}

public static Map<String, Float> fetchTokenWeights(String tokenizerId, String fileName) {
String url = HUGGING_FACE_BASE_URL + tokenizerId + "/" + HUGGING_FACE_RESOLVE_PATH + fileName;

InputStream inputStream = null;
try {
inputStream = withDJLContext(() -> Utils.openUrl(url));
} catch (PrivilegedActionException e) {
throw new RuntimeException("Failed to download file from " + url, e);
}

return parseInputStreamToTokenWeights(inputStream);
}
}
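For reference, a minimal sketch of the tab-separated token/weight format that parseInputStreamToTokenWeights expects; the class name and values below are hypothetical and not part of this change:

import org.opensearch.neuralsearch.analysis.DJLUtils;

import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.util.Map;

public class TokenWeightsFormatExample {
    public static void main(String[] args) {
        // Each line is "<token>\t<weight>"; blank lines are skipped by the parser.
        String fileContent = "hello\t0.75\nworld\t1.25\n";
        Map<String, Float> weights = DJLUtils.parseInputStreamToTokenWeights(
            new ByteArrayInputStream(fileContent.getBytes(StandardCharsets.UTF_8))
        );
        // Prints both token/weight pairs (HashMap iteration order is unspecified).
        System.out.println(weights);
    }
}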
29 changes: 29 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/analysis/HFModelAnalyzer.java
@@ -0,0 +1,29 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.analysis;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.Tokenizer;

import java.util.function.Supplier;

public class HFModelAnalyzer extends Analyzer {
public static final String NAME = "hf_model_tokenizer";
Supplier<Tokenizer> tokenizerSupplier;

public HFModelAnalyzer() {
this.tokenizerSupplier = HFModelTokenizerFactory::createDefault;
}

HFModelAnalyzer(Supplier<Tokenizer> tokenizerSupplier) {
this.tokenizerSupplier = tokenizerSupplier;
}

@Override
protected TokenStreamComponents createComponents(String fieldName) {
final Tokenizer src = tokenizerSupplier.get();
return new TokenStreamComponents(src, src);
}
}
25 changes: 25 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/analysis/HFModelAnalyzerProvider.java
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.analysis;

import org.opensearch.common.settings.Settings;
import org.opensearch.env.Environment;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.analysis.AbstractIndexAnalyzerProvider;

public class HFModelAnalyzerProvider extends AbstractIndexAnalyzerProvider<HFModelAnalyzer> {
private final HFModelAnalyzer analyzer;

public HFModelAnalyzerProvider(IndexSettings indexSettings, Environment environment, String name, Settings settings) {
super(indexSettings, name, settings);
HFModelTokenizerFactory tokenizerFactory = new HFModelTokenizerFactory(indexSettings, environment, name, settings);
analyzer = new HFModelAnalyzer(tokenizerFactory::create);
}

@Override
public HFModelAnalyzer get() {
return analyzer;
}
}
107 changes: 107 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/analysis/HFModelTokenizer.java
@@ -0,0 +1,107 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.analysis;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.Objects;

import com.google.common.io.CharStreams;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import org.apache.lucene.util.BytesRef;

public class HFModelTokenizer extends Tokenizer {
public static final String NAME = "hf_model_tokenizer";
private static final Float DEFAULT_TOKEN_WEIGHT = 1.0f;

private final CharTermAttribute termAtt;
private final PayloadAttribute payloadAtt;
private final OffsetAttribute offsetAtt;
private final HuggingFaceTokenizer tokenizer;
private final Map<String, Float> tokenWeights;

private Encoding encoding;
private int tokenIdx = 0;
private int overflowingIdx = 0;

public HFModelTokenizer(HuggingFaceTokenizer huggingFaceTokenizer) {
this(huggingFaceTokenizer, null);
}

public HFModelTokenizer(HuggingFaceTokenizer huggingFaceTokenizer, Map<String, Float> weights) {
termAtt = addAttribute(CharTermAttribute.class);
offsetAtt = addAttribute(OffsetAttribute.class);
if (Objects.nonNull(weights)) {
payloadAtt = addAttribute(PayloadAttribute.class);
} else {
payloadAtt = null;
}
tokenizer = huggingFaceTokenizer;
tokenWeights = weights;
}

@Override
public void reset() throws IOException {
super.reset();
tokenIdx = 0;
overflowingIdx = -1;
String inputStr = CharStreams.toString(input);
encoding = tokenizer.encode(inputStr, false, true);
}

private static boolean isLastTokenInEncodingSegment(int idx, Encoding encodingSegment) {
return idx >= encodingSegment.getTokens().length || encodingSegment.getAttentionMask()[idx] == 0;
}

public static byte[] floatToBytes(float value) {
return ByteBuffer.allocate(4).putFloat(value).array();
}

public static float bytesToFloat(byte[] bytes) {
return ByteBuffer.wrap(bytes).getFloat();
}

@Override
public final boolean incrementToken() throws IOException {
clearAttributes();
Encoding curEncoding = overflowingIdx == -1 ? encoding : encoding.getOverflowing()[overflowingIdx];

while (!isLastTokenInEncodingSegment(tokenIdx, curEncoding) || overflowingIdx < encoding.getOverflowing().length) {
if (isLastTokenInEncodingSegment(tokenIdx, curEncoding)) {
// reset cur segment, go to the next segment
// until overflowingIdx = encoding.getOverflowing().length
tokenIdx = 0;
overflowingIdx++;
if (overflowingIdx >= encoding.getOverflowing().length) {
return false;
}
curEncoding = encoding.getOverflowing()[overflowingIdx];
} else {
termAtt.append(curEncoding.getTokens()[tokenIdx]);
offsetAtt.setOffset(
curEncoding.getCharTokenSpans()[tokenIdx].getStart(),
curEncoding.getCharTokenSpans()[tokenIdx].getEnd()
);
if (Objects.nonNull(tokenWeights)) {
// for neural sparse query, write the token weight to payload field
payloadAtt.setPayload(
new BytesRef(floatToBytes(tokenWeights.getOrDefault(curEncoding.getTokens()[tokenIdx], DEFAULT_TOKEN_WEIGHT)))
);
}
tokenIdx++;
return true;
}
}

return false;
}
}
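For illustration, a minimal sketch of driving HFModelTokenizer directly and reading each term together with the weight carried in its payload; the tokenizer id ("bert-base-uncased"), the weights map, and the class name are assumptions for this example, and DJL downloads the tokenizer files on first use:

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
import org.opensearch.neuralsearch.analysis.HFModelTokenizer;

import java.io.StringReader;
import java.util.Map;

public class HFModelTokenizerUsageExample {
    public static void main(String[] args) throws Exception {
        HuggingFaceTokenizer hf = HuggingFaceTokenizer.newInstance("bert-base-uncased");
        Map<String, Float> weights = Map.of("hello", 2.0f);
        try (HFModelTokenizer tokenizer = new HFModelTokenizer(hf, weights)) {
            tokenizer.setReader(new StringReader("hello world"));
            tokenizer.reset();
            CharTermAttribute term = tokenizer.getAttribute(CharTermAttribute.class);
            PayloadAttribute payload = tokenizer.getAttribute(PayloadAttribute.class);
            while (tokenizer.incrementToken()) {
                // The payload holds the token weight as a 4-byte float; unmapped tokens default to 1.0.
                float weight = HFModelTokenizer.bytesToFloat(payload.getPayload().bytes);
                System.out.println(term + " -> " + weight);
            }
            tokenizer.end();
        }
    }
}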
65 changes: 65 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/analysis/HFModelTokenizerFactory.java
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.analysis;

import org.apache.lucene.analysis.Tokenizer;
import org.opensearch.common.settings.Settings;
import org.opensearch.env.Environment;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.analysis.AbstractTokenizerFactory;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

import java.util.Map;
import java.util.Objects;

public class HFModelTokenizerFactory extends AbstractTokenizerFactory {
private final HuggingFaceTokenizer tokenizer;
private final Map<String, Float> tokenWeights;

/**
* Lazily and atomically initializes the default HF tokenizer and token weights the first time the holder class is accessed.
*/
private static class DefaultTokenizerHolder {
static final HuggingFaceTokenizer TOKENIZER;
static final Map<String, Float> TOKEN_WEIGHTS;
private static final String DEFAULT_TOKENIZER_ID = "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill";
private static final String DEFAULT_TOKEN_WEIGHTS_FILE = "query_token_weights.txt";

static {
try {
TOKENIZER = DJLUtils.buildHuggingFaceTokenizer(DEFAULT_TOKENIZER_ID);
TOKEN_WEIGHTS = DJLUtils.fetchTokenWeights(DEFAULT_TOKENIZER_ID, DEFAULT_TOKEN_WEIGHTS_FILE);
} catch (Exception e) {
throw new RuntimeException("Failed to initialize default hf_model_tokenizer", e);
}
}
}

public static Tokenizer createDefault() {
return new HFModelTokenizer(DefaultTokenizerHolder.TOKENIZER, DefaultTokenizerHolder.TOKEN_WEIGHTS);
}

public HFModelTokenizerFactory(IndexSettings indexSettings, Environment environment, String name, Settings settings) {
// For custom tokenizer, the factory is created during IndexModule.newIndexService
// And can be accessed via indexService.getIndexAnalyzers()
super(indexSettings, settings, name);
String tokenizerId = settings.get("tokenizer_id", null);
Objects.requireNonNull(tokenizerId, "tokenizer_id is required");
String tokenWeightsFileName = settings.get("token_weights_file", null);
tokenizer = DJLUtils.buildHuggingFaceTokenizer(tokenizerId);
if (tokenWeightsFileName != null) {
tokenWeights = DJLUtils.fetchTokenWeights(tokenizerId, tokenWeightsFileName);
} else {
tokenWeights = null;
}
}

@Override
public Tokenizer create() {
// the create method will be called for every single analyze request
return new HFModelTokenizer(tokenizer, tokenWeights);
}
}
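A hedged sketch of how an index could wire the custom tokenizer into an analyzer. The settings keys assume the factory is registered under the name hf_model_tokenizer and reads the tokenizer_id and token_weights_file settings shown above; that registration is not part of this file:

import org.opensearch.common.settings.Settings;

public class HFModelTokenizerIndexSettingsExample {
    public static Settings indexSettings() {
        return Settings.builder()
            // Custom tokenizer instance backed by HFModelTokenizerFactory.
            .put("index.analysis.tokenizer.my_hf_tokenizer.type", "hf_model_tokenizer")
            .put(
                "index.analysis.tokenizer.my_hf_tokenizer.tokenizer_id",
                "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill"
            )
            .put("index.analysis.tokenizer.my_hf_tokenizer.token_weights_file", "query_token_weights.txt")
            // Custom analyzer that uses the tokenizer defined above.
            .put("index.analysis.analyzer.my_hf_analyzer.type", "custom")
            .put("index.analysis.analyzer.my_hf_analyzer.tokenizer", "my_hf_tokenizer")
            .build();
    }
}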