-
Notifications
You must be signed in to change notification settings - Fork 775
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add LiteLLM as a representation model
- Loading branch information
Showing
27 changed files
with
241 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
import time | ||
from litellm import completion | ||
import pandas as pd | ||
from scipy.sparse import csr_matrix | ||
from typing import Mapping, List, Tuple, Any | ||
from bertopic.representation._base import BaseRepresentation | ||
from bertopic.representation._utils import retry_with_exponential_backoff | ||
|
||
|
||
DEFAULT_PROMPT = """ | ||
I have a topic that contains the following documents: | ||
[DOCUMENTS] | ||
The topic is described by the following keywords: [KEYWORDS] | ||
Based on the information above, extract a short topic label in the following format: | ||
topic: <topic label> | ||
""" | ||
|
||
|
||
class LiteLLM(BaseRepresentation): | ||
"""Using the LiteLLM API to generate topic labels. | ||
For an overview of models see: | ||
https://docs.litellm.ai/docs/providers | ||
Arguments: | ||
model: Model to use. Defaults to OpenAI's "gpt-3.5-turbo". | ||
generator_kwargs: Kwargs passed to `litellm.completion`. | ||
prompt: The prompt to be used in the model. If no prompt is given, | ||
`self.default_prompt_` is used instead. | ||
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt | ||
to decide where the keywords and documents need to be | ||
inserted. | ||
delay_in_seconds: The delay in seconds between consecutive prompts | ||
in order to prevent RateLimitErrors. | ||
exponential_backoff: Retry requests with a random exponential backoff. | ||
A short sleep is used when a rate limit error is hit, | ||
then the requests is retried. Increase the sleep length | ||
if errors are hit until 10 unsuccesfull requests. | ||
If True, overrides `delay_in_seconds`. | ||
nr_docs: The number of documents to pass to LiteLLM if a prompt | ||
with the `["DOCUMENTS"]` tag is used. | ||
diversity: The diversity of documents to pass to LiteLLM. | ||
Accepts values between 0 and 1. A higher | ||
values results in passing more diverse documents | ||
whereas lower values passes more similar documents. | ||
Usage: | ||
To use this, you will need to install the openai package first: | ||
`pip install litellm` | ||
Then, get yourself an API key of any provider (for instance OpenAI) and use it as follows: | ||
```python | ||
import os | ||
from bertopic.representation import LiteLLM | ||
from bertopic import BERTopic | ||
# set ENV variables | ||
os.environ["OPENAI_API_KEY"] = "your-openai-key" | ||
# Create your representation model | ||
representation_model = LiteLLM(model="gpt-3.5-turbo") | ||
# Use the representation model in BERTopic on top of the default pipeline | ||
topic_model = BERTopic(representation_model=representation_model) | ||
``` | ||
You can also use a custom prompt: | ||
```python | ||
prompt = "I have the following documents: [DOCUMENTS] \nThese documents are about the following topic: '" | ||
representation_model = LiteLLM(model="gpt", prompt=prompt) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
model: str = "gpt-3.5-turbo", | ||
prompt: str = None, | ||
generator_kwargs: Mapping[str, Any] = {}, | ||
delay_in_seconds: float = None, | ||
exponential_backoff: bool = False, | ||
nr_docs: int = 4, | ||
diversity: float = None, | ||
): | ||
self.model = model | ||
self.prompt = prompt if prompt else DEFAULT_PROMPT | ||
self.default_prompt_ = DEFAULT_PROMPT | ||
self.delay_in_seconds = delay_in_seconds | ||
self.exponential_backoff = exponential_backoff | ||
self.nr_docs = nr_docs | ||
self.diversity = diversity | ||
|
||
self.generator_kwargs = generator_kwargs | ||
if self.generator_kwargs.get("model"): | ||
self.model = generator_kwargs.get("model") | ||
if self.generator_kwargs.get("prompt"): | ||
del self.generator_kwargs["prompt"] | ||
|
||
def extract_topics( | ||
self, topic_model, documents: pd.DataFrame, c_tf_idf: csr_matrix, topics: Mapping[str, List[Tuple[str, float]]] | ||
) -> Mapping[str, List[Tuple[str, float]]]: | ||
"""Extract topics. | ||
Arguments: | ||
topic_model: A BERTopic model | ||
documents: All input documents | ||
c_tf_idf: The topic c-TF-IDF representation | ||
topics: The candidate topics as calculated with c-TF-IDF | ||
Returns: | ||
updated_topics: Updated topic representations | ||
""" | ||
# Extract the top n representative documents per topic | ||
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs( | ||
c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity | ||
) | ||
|
||
# Generate using a (Large) Language Model | ||
updated_topics = {} | ||
for topic, docs in repr_docs_mappings.items(): | ||
prompt = self._create_prompt(docs, topic, topics) | ||
|
||
# Delay | ||
if self.delay_in_seconds: | ||
time.sleep(self.delay_in_seconds) | ||
|
||
messages = [ | ||
{"role": "system", "content": "You are a helpful assistant."}, | ||
{"role": "user", "content": prompt}, | ||
] | ||
kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs} | ||
if self.exponential_backoff: | ||
response = chat_completions_with_backoff(**kwargs) | ||
else: | ||
response = completion(**kwargs) | ||
label = response["choices"][0]["message"]["content"].strip().replace("topic: ", "") | ||
|
||
updated_topics[topic] = [(label, 1)] | ||
|
||
return updated_topics | ||
|
||
def _create_prompt(self, docs, topic, topics): | ||
keywords = list(zip(*topics[topic]))[0] | ||
|
||
# Use the Default Chat Prompt | ||
if self.prompt == DEFAULT_PROMPT: | ||
prompt = self.prompt.replace("[KEYWORDS]", " ".join(keywords)) | ||
prompt = self._replace_documents(prompt, docs) | ||
|
||
# Use a custom prompt that leverages keywords, documents or both using | ||
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively | ||
else: | ||
prompt = self.prompt | ||
if "[KEYWORDS]" in prompt: | ||
prompt = prompt.replace("[KEYWORDS]", " ".join(keywords)) | ||
if "[DOCUMENTS]" in prompt: | ||
prompt = self._replace_documents(prompt, docs) | ||
|
||
return prompt | ||
|
||
@staticmethod | ||
def _replace_documents(prompt, docs): | ||
to_replace = "" | ||
for doc in docs: | ||
to_replace += f"- {doc[:255]}\n" | ||
prompt = prompt.replace("[DOCUMENTS]", to_replace) | ||
return prompt | ||
|
||
|
||
def chat_completions_with_backoff(**kwargs): | ||
return retry_with_exponential_backoff( | ||
completion, | ||
)(**kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# `Backends` | ||
|
||
::: bertopic.backend |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# `BaseCluster` | ||
|
||
::: bertopic.cluster._base.BaseCluster |
File renamed without changes.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# `Plotting` | ||
|
||
::: bertopic.plotting |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# `Representations` | ||
|
||
::: bertopic.representation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# `Vectorizers` | ||
|
||
::: bertopic.vectorizers._online_cv.OnlineCountVectorizer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters