FEAT: Support gte-Qwen2-7B-instruct and multi gpu deploy (#1994)
Co-authored-by: wuzhaoxin <[email protected]>
amumu96 and wuzhaoxin authored Aug 2, 2024
1 parent be149a8 commit dd85cfe
Showing 4 changed files with 97 additions and 14 deletions.
89 changes: 78 additions & 11 deletions xinference/model/embedding/core.py
@@ -122,13 +122,15 @@ def __init__(
self,
model_uid: str,
model_path: str,
model_spec: EmbeddingModelSpec,
device: Optional[str] = None,
):
self._model_uid = model_uid
self._model_path = model_path
self._device = device
self._model = None
self._counter = 0
self._model_spec = model_spec

def load(self):
try:
@@ -139,12 +141,26 @@ def load(self):
"Please make sure 'sentence-transformers' is installed. ",
"You can install it by `pip install sentence-transformers`\n",
]

raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

class XSentenceTransformer(SentenceTransformer):
def to(self, *args, **kwargs):
pass

from ..utils import patch_trust_remote_code

patch_trust_remote_code()
self._model = SentenceTransformer(self._model_path, device=self._device)
if (
"gte-Qwen2" in self._model_spec.model_id
or "gte-Qwen2" in self._model_spec.model_name
):
self._model = XSentenceTransformer(
self._model_path,
device=self._device,
model_kwargs={"device_map": "auto"},
)
else:
self._model = SentenceTransformer(self._model_path, device=self._device)

def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
self._counter += 1
@@ -161,6 +177,8 @@ def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
def encode(
model: SentenceTransformer,
sentences: Union[str, List[str]],
prompt_name: Optional[str] = None,
prompt: Optional[str] = None,
batch_size: int = 32,
show_progress_bar: bool = None,
output_value: str = "sentence_embedding",
@@ -209,10 +227,43 @@ def encode(
sentences = [sentences]
input_was_string = True

if prompt is None:
if prompt_name is not None:
try:
prompt = model.prompts[prompt_name]
except KeyError:
raise ValueError(
f"Prompt name '{prompt_name}' not found in the configured prompts dictionary with keys {list(model.prompts.keys())!r}."
)
elif model.default_prompt_name is not None:
prompt = model.prompts.get(model.default_prompt_name, None)
else:
if prompt_name is not None:
logger.warning(
"Encode with either a `prompt`, a `prompt_name`, or neither, but not both. "
"Ignoring the `prompt_name` in favor of `prompt`."
)

extra_features = {}
if prompt is not None:
sentences = [prompt + sentence for sentence in sentences]

# Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
# Tracking the prompt length allows us to remove the prompt during pooling
tokenized_prompt = model.tokenize([prompt])
if "input_ids" in tokenized_prompt:
extra_features["prompt_length"] = (
tokenized_prompt["input_ids"].shape[-1] - 1
)

if device is None:
device = model._target_device

model.to(device)
if (
"gte-Qwen2" not in self._model_spec.model_id
and "gte-Qwen2" not in self._model_spec.model_name
):
model.to(device)

all_embeddings = []
all_token_nums = 0
@@ -233,6 +284,7 @@ def encode(
]
features = model.tokenize(sentences_batch)
features = batch_to_device(features, device)
features.update(extra_features)
all_token_nums += sum([len(f) for f in features])

with torch.no_grad():
@@ -277,7 +329,10 @@ def encode(
]

if convert_to_tensor:
all_embeddings = torch.stack(all_embeddings)
if len(all_embeddings):
all_embeddings = torch.stack(all_embeddings)
else:
all_embeddings = torch.Tensor()
elif convert_to_numpy:
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])

@@ -286,12 +341,24 @@ def encode(

return all_embeddings, all_token_nums

all_embeddings, all_token_nums = encode(
self._model,
sentences,
convert_to_numpy=False,
**kwargs,
)
if (
"gte-Qwen2" in self._model_spec.model_id
or "gte-Qwen2" in self._model_spec.model_name
):
all_embeddings, all_token_nums = encode(
self._model,
sentences,
prompt_name="query",
convert_to_numpy=False,
**kwargs,
)
else:
all_embeddings, all_token_nums = encode(
self._model,
sentences,
convert_to_numpy=False,
**kwargs,
)
if isinstance(sentences, str):
all_embeddings = [all_embeddings]
embedding_list = []
@@ -356,7 +423,7 @@ def create_embedding_model_instance(
if model_path is None:
model_path = cache(model_spec)

model = EmbeddingModel(model_uid, model_path, **kwargs)
model = EmbeddingModel(model_uid, model_path, model_spec, **kwargs)
model_description = EmbeddingModelDescription(
subpool_addr, devices, model_spec, model_path=model_path
)
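The core of the multi-GPU change is the no-op to() override: loading with model_kwargs={"device_map": "auto"} lets accelerate shard the 7B weights across all visible GPUs, and the overridden to() keeps later single-device moves (including the model.to(device) call inside encode) from undoing that placement. A minimal standalone sketch of the pattern, assuming a recent sentence-transformers with accelerate installed:

from sentence_transformers import SentenceTransformer

class XSentenceTransformer(SentenceTransformer):
    def to(self, *args, **kwargs):
        # No-op: with device_map="auto", accelerate has already placed the
        # weights across GPUs; moving the whole model to a single device
        # would undo the sharding.
        pass

# device_map="auto" is forwarded to transformers/accelerate, which splits
# the model across every visible GPU.
model = XSentenceTransformer(
    "Alibaba-NLP/gte-Qwen2-7B-instruct",
    model_kwargs={"device_map": "auto"},
)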
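The encode() additions port sentence-transformers' prompt handling: prompt_name selects an instruction template from model.prompts, the prompt is prepended to every sentence, and prompt_length is recorded so pooling can exclude the prompt tokens. For gte-Qwen2, create_embedding passes prompt_name="query". A hedged sketch of the equivalent direct call, assuming a sentence-transformers version with prompt support and a checkpoint whose configuration defines a "query" prompt:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("Alibaba-NLP/gte-Qwen2-7B-instruct")
# The checkpoint's configuration ships a "query" prompt; encode() prepends
# it to each input before tokenization.
embeddings = model.encode(
    ["what is the capital of China?"],
    prompt_name="query",
)
print(embeddings.shape)  # (1, 3584) per the spec below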
8 changes: 8 additions & 0 deletions xinference/model/embedding/model_spec.json
@@ -230,5 +230,13 @@
"language": ["zh", "en"],
"model_id": "moka-ai/m3e-large",
"model_revision": "12900375086c37ba5d83d1e417b21dc7d1d1f388"
},
{
"model_name": "gte-Qwen2",
"dimensions": 3584,
"max_tokens": 32000,
"language": ["zh", "en"],
"model_id": "Alibaba-NLP/gte-Qwen2-7B-instruct",
"model_revision": "e26182b2122f4435e8b3ebecbf363990f409b45b"
}
]
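With this spec entry registered, the model can be launched by name. A usage sketch against a running xinference server; the endpoint URL is an assumption for a default local deployment:

from xinference.client import Client

client = Client("http://localhost:9997")  # assumed default local endpoint
model_uid = client.launch_model(model_name="gte-Qwen2", model_type="embedding")
model = client.get_model(model_uid)

result = model.create_embedding("what is the capital of China?")
print(len(result["data"][0]["embedding"]))  # 3584, per the spec above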
8 changes: 8 additions & 0 deletions xinference/model/embedding/model_spec_modelscope.json
@@ -232,5 +232,13 @@
"language": ["zh", "en"],
"model_id": "AI-ModelScope/m3e-large",
"model_hub": "modelscope"
},
{
"model_name": "gte-Qwen2",
"dimensions": 4096,
"max_tokens": 32000,
"language": ["zh", "en"],
"model_id": "iic/gte_Qwen2-7B-instruct",
"model_hub": "modelscope"
}
]
6 changes: 3 additions & 3 deletions xinference/model/embedding/tests/test_embedding_models.py
@@ -54,7 +54,7 @@ def test_model():
model_path = None
try:
model_path = cache(TEST_MODEL_SPEC)
model = EmbeddingModel("mock", model_path)
model = EmbeddingModel("mock", model_path, TEST_MODEL_SPEC)
# input is a string
input_text = "what is the capital of China?"
model.load()
@@ -83,7 +83,7 @@ def test_model():

def test_model_from_modelscope():
model_path = cache(TEST_MODEL_SPEC_FROM_MODELSCOPE)
model = EmbeddingModel("mock", model_path)
model = EmbeddingModel("mock", model_path, TEST_MODEL_SPEC_FROM_MODELSCOPE)
# input is a string
input_text = "乱条犹未变初黄,倚得东风势便狂。解把飞花蒙日月,不知天地有清霜。"
model.load()
@@ -108,7 +108,7 @@ def test_meta_file():
assert valid_model_revision(meta_path, TEST_MODEL_SPEC2.model_revision)

# test functionality of the new version model
model = EmbeddingModel("mock", cache_dir)
model = EmbeddingModel("mock", cache_dir, TEST_MODEL_SPEC2)
input_text = "I can do this all day."
model.load()
r = model.create_embedding(input_text)
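The test updates reflect the new constructor signature: EmbeddingModel now takes the model spec as a third positional argument so load() can branch on gte-Qwen2. A sketch of the updated call pattern; the spec fields mirror the JSON entry above, the construction is not verified against the class definition, and the path is illustrative:

from xinference.model.embedding.core import EmbeddingModel, EmbeddingModelSpec

spec = EmbeddingModelSpec(
    model_name="gte-Qwen2",
    dimensions=3584,
    max_tokens=32000,
    language=["zh", "en"],
    model_id="Alibaba-NLP/gte-Qwen2-7B-instruct",
    model_revision="e26182b2122f4435e8b3ebecbf363990f409b45b",
)
model = EmbeddingModel("my-uid", "/path/to/gte-Qwen2-7B-instruct", spec)
model.load()
result = model.create_embedding("I can do this all day.")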
