forked from microsoft/semantic-kernel
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Python/Local Hugging Face Inference for Completions and Embeddings (microsoft#658)

### Motivation and Context

This PR introduces native Python support for Hugging Face models that can complete text, generate new text, summarize, and generate embeddings. Currently it only supports downloading the models locally from the HF model hub. Future plans include supporting the HF inference API as well.

### Description

- Added 2 services: `hf_text_completion` and `hf_text_embedding`
- `hf_text_completion` supports the following tasks: _text-generation_, _text2text-generation_, and _summarization_
- `hf_text_embedding` supports any model supported by the sentence-transformers pip package
- Added dependencies: pytorch, transformers, sentence-transformers to `requirements.txt` and `poetry.lock`
- Fixed typo: `get_embedding_service_service_id` -> `get_embedding_service_id`
- Added a number of integration tests for supported HF models
- Loading branch information
1 parent
5b1ed2f
commit f6059cd
Showing
26 changed files
with
1,673 additions
and
127 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
openai==0.27.* | ||
numpy==1.24.* | ||
aiofiles>=23.1.0 | ||
aiofiles>=23.1.0 | ||
transformers>=4.28.0 | ||
sentence-transformers>=2.2.2 | ||
torch>=2.0.0 |
10 changes: 10 additions & 0 deletions
10
python/semantic_kernel/connectors/ai/hugging_face/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Copyright (c) Microsoft. All rights reserved. | ||
|
||
from semantic_kernel.connectors.ai.hugging_face.services.hf_text_completion import ( | ||
HuggingFaceTextCompletion, | ||
) | ||
from semantic_kernel.connectors.ai.hugging_face.services.hf_text_embedding import ( | ||
HuggingFaceTextEmbedding, | ||
) | ||
|
||
__all__ = ["HuggingFaceTextCompletion", "HuggingFaceTextEmbedding"] |
95 changes: 95 additions & 0 deletions
95
python/semantic_kernel/connectors/ai/hugging_face/services/hf_text_completion.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# Copyright (c) Microsoft. All rights reserved. | ||
|
||
from logging import Logger | ||
from typing import Optional | ||
|
||
import torch | ||
from transformers import pipeline | ||
|
||
from semantic_kernel.connectors.ai.ai_exception import AIException | ||
from semantic_kernel.connectors.ai.complete_request_settings import ( | ||
CompleteRequestSettings, | ||
) | ||
from semantic_kernel.connectors.ai.text_completion_client_base import ( | ||
TextCompletionClientBase, | ||
) | ||
from semantic_kernel.utils.null_logger import NullLogger | ||
|
||
|
||
class HuggingFaceTextCompletion(TextCompletionClientBase):
    """
    Text completion service that runs a Hugging Face `transformers` pipeline
    locally. The pipeline is created once in the constructor and reused by
    every call to `complete_async`.
    """

    # Supported pipeline tasks, mapped to the key under which each task's
    # pipeline result stores the produced text.
    _RESULT_KEYS = {
        "text-generation": "generated_text",
        "text2text-generation": "generated_text",
        "summarization": "summary_text",
    }

    _model_id: str
    _task: str
    _device: int
    _log: Logger

    def __init__(
        self,
        model_id: str,
        device: Optional[int] = -1,
        task: Optional[str] = None,
        log: Optional[Logger] = None,
    ) -> None:
        """
        Initializes a new instance of the HuggingFaceTextCompletion class.

        Arguments:
            model_id {str} -- Hugging Face model card string, see
                https://huggingface.co/models
            device {Optional[int]} -- Device to run the model on, -1 for CPU, 0+ for GPU.
                None is treated like -1 (CPU). A GPU index is only honoured when
                CUDA is available; otherwise the model runs on CPU.
            task {Optional[str]} -- Model completion task type, options are:
                - summarization: takes a long text and returns a shorter summary.
                - text-generation: takes incomplete text and returns a set of completion candidates.
                - text2text-generation (default): takes an input prompt and returns a completion.
                text2text-generation is the default as it behaves more like GPT-3+.
            log {Optional[Logger]} -- Logger instance.

        Note that this model will be downloaded from the Hugging Face model hub.
        """
        self._model_id = model_id
        self._task = "text2text-generation" if task is None else task
        self._log = log if log is not None else NullLogger()

        # Build the device string with an f-string: concatenating "cuda:" with
        # the integer index raises TypeError.
        use_gpu = device is not None and device >= 0 and torch.cuda.is_available()
        self.device = f"cuda:{device}" if use_gpu else "cpu"

        self.generator = pipeline(
            task=self._task, model=self._model_id, device=self.device
        )

    async def complete_async(
        self, prompt: str, request_settings: CompleteRequestSettings
    ) -> str:
        """
        Completes a prompt using the Hugging Face model.

        Arguments:
            prompt {str} -- Prompt to complete.
            request_settings {CompleteRequestSettings} -- Request settings;
                temperature, top_p and max_tokens (passed as max_length) are
                forwarded to the pipeline.

        Returns:
            str -- Completion result.

        Raises:
            AIException -- If the configured task is not supported, or if the
                pipeline call fails.
        """
        # Validate the task before running inference, and outside the try
        # block so the configuration error is not re-wrapped below.
        result_key = self._RESULT_KEYS.get(self._task)
        if result_key is None:
            raise AIException(
                AIException.ErrorCodes.InvalidConfiguration,
                "Unsupported hugging face pipeline task: only "
                "text-generation, text2text-generation, and summarization are supported.",
            )

        try:
            result = self.generator(
                prompt,
                num_return_sequences=1,
                temperature=request_settings.temperature,
                top_p=request_settings.top_p,
                max_length=request_settings.max_tokens,
                # NOTE(review): 50256 looks like the GPT-2 EOS token id; confirm
                # it is appropriate for non-GPT-2 models.
                pad_token_id=50256,  # EOS token
            )
            return result[0][result_key]
        except Exception as e:
            # Error code first, then message, then the inner exception.
            raise AIException(
                AIException.ErrorCodes.ServiceError,
                "Hugging Face completion failed",
                e,
            ) from e
63 changes: 63 additions & 0 deletions
63
python/semantic_kernel/connectors/ai/hugging_face/services/hf_text_embedding.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright (c) Microsoft. All rights reserved. | ||
|
||
from logging import Logger | ||
from typing import List, Optional | ||
|
||
import torch | ||
from numpy import array, ndarray | ||
from sentence_transformers import SentenceTransformer | ||
|
||
from semantic_kernel.connectors.ai.ai_exception import AIException | ||
from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import ( | ||
EmbeddingGeneratorBase, | ||
) | ||
from semantic_kernel.utils.null_logger import NullLogger | ||
|
||
|
||
class HuggingFaceTextEmbedding(EmbeddingGeneratorBase):
    """
    Embedding generator that runs a `sentence_transformers.SentenceTransformer`
    model locally. The model is loaded once in the constructor and reused by
    every call to `generate_embeddings_async`.
    """

    _model_id: str
    _device: int
    _log: Logger

    def __init__(
        self,
        model_id: str,
        device: Optional[int] = -1,
        log: Optional[Logger] = None,
    ) -> None:
        """
        Initializes a new instance of the HuggingFaceTextEmbedding class.

        Arguments:
            model_id {str} -- Hugging Face model card string, see
                https://huggingface.co/sentence-transformers
            device {Optional[int]} -- Device to run the model on, -1 for CPU, 0+ for GPU.
                None is treated like -1 (CPU). A GPU index is only honoured when
                CUDA is available; otherwise the model runs on CPU.
            log {Optional[Logger]} -- Logger instance.

        Note that this model will be downloaded from the Hugging Face model hub.
        """
        self._model_id = model_id
        self._log = log if log is not None else NullLogger()

        # Build the device string with an f-string: concatenating "cuda:" with
        # the integer index raises TypeError.
        use_gpu = device is not None and device >= 0 and torch.cuda.is_available()
        self.device = f"cuda:{device}" if use_gpu else "cpu"

        self.generator = SentenceTransformer(
            model_name_or_path=self._model_id, device=self.device
        )

    async def generate_embeddings_async(self, texts: List[str]) -> ndarray:
        """
        Generates embeddings for a list of texts.

        Arguments:
            texts {List[str]} -- Texts to generate embeddings for.

        Returns:
            ndarray -- Embeddings for the texts.

        Raises:
            AIException -- If encoding the texts fails.
        """
        try:
            self._log.info(f"Generating embeddings for {len(texts)} texts")
            embeddings = self.generator.encode(texts)
            return array(embeddings)
        except Exception as e:
            # Error code first, then message, then the inner exception.
            raise AIException(
                AIException.ErrorCodes.ServiceError,
                "Hugging Face embeddings failed",
                e,
            ) from e
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
18 changes: 18 additions & 0 deletions
18
python/tests/end-to-end/basics_with_hf_local_text2text_service.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright (c) Microsoft. All rights reserved. | ||
|
||
import asyncio | ||
|
||
from utils import e2e_text_completion | ||
|
||
import semantic_kernel as sk | ||
import semantic_kernel.connectors.ai.hugging_face as sk_hf | ||
|
||
# Create the kernel and register a local Hugging Face text2text-generation
# model as its text service, keyed by the model id.
kernel = sk.Kernel()

text2text_service = sk_hf.HuggingFaceTextCompletion(
    "google/flan-t5-base", task="text2text-generation"
)
kernel.config.add_text_service("google/flan-t5-base", text2text_service)

# Drive the simple_completion coroutine from utils.e2e_text_completion.
asyncio.run(e2e_text_completion.simple_completion(kernel))
17 changes: 17 additions & 0 deletions
17
python/tests/end-to-end/basics_with_hf_local_text_generation_service.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright (c) Microsoft. All rights reserved. | ||
|
||
import asyncio | ||
|
||
from utils import e2e_text_completion | ||
|
||
import semantic_kernel as sk | ||
import semantic_kernel.connectors.ai.hugging_face as sk_hf | ||
|
||
# Create the kernel and register a local Hugging Face text-generation model
# as its text service, keyed by the model id.
kernel = sk.Kernel()

text_generation_service = sk_hf.HuggingFaceTextCompletion(
    "gpt2", task="text-generation"
)
kernel.config.add_text_service("gpt2", text_generation_service)

# Drive the simple_completion coroutine from utils.e2e_text_completion.
asyncio.run(e2e_text_completion.simple_completion(kernel))
18 changes: 18 additions & 0 deletions
18
python/tests/end-to-end/basics_with_hf_local_text_summarization.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright (c) Microsoft. All rights reserved. | ||
|
||
import asyncio | ||
|
||
from utils import e2e_text_completion | ||
|
||
import semantic_kernel as sk | ||
import semantic_kernel.connectors.ai.hugging_face as sk_hf | ||
|
||
# Create the kernel and register a local Hugging Face summarization model as
# its text service, keyed by the model id.
kernel = sk.Kernel()

summarization_service = sk_hf.HuggingFaceTextCompletion(
    "facebook/bart-large-cnn", task="summarization"
)
kernel.config.add_text_service("facebook/bart-large-cnn", summarization_service)

# Drive the simple_summarization coroutine from utils.e2e_text_completion.
asyncio.run(e2e_text_completion.simple_summarization(kernel))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.