diff --git a/bertopic/representation/_cohere.py b/bertopic/representation/_cohere.py
index 8ca31c8f..a0c74434 100644
--- a/bertopic/representation/_cohere.py
+++ b/bertopic/representation/_cohere.py
@@ -4,7 +4,7 @@
 from scipy.sparse import csr_matrix
 from typing import Mapping, List, Tuple, Union, Callable
 from bertopic.representation._base import BaseRepresentation
-from bertopic.representation._utils import truncate_document
+from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters
 
 
 DEFAULT_PROMPT = """
@@ -124,6 +124,8 @@ def __init__(
         self.diversity = diversity
         self.doc_length = doc_length
         self.tokenizer = tokenizer
+        validate_truncate_document_parameters(self.tokenizer, self.doc_length)
+
         self.prompts_ = []
 
     def extract_topics(
diff --git a/bertopic/representation/_langchain.py b/bertopic/representation/_langchain.py
index df5c4839..e7588df4 100644
--- a/bertopic/representation/_langchain.py
+++ b/bertopic/representation/_langchain.py
@@ -4,7 +4,7 @@
 from typing import Callable, Mapping, List, Tuple, Union
 
 from bertopic.representation._base import BaseRepresentation
-from bertopic.representation._utils import truncate_document
+from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters
 
 DEFAULT_PROMPT = "What are these documents about? Please give a single label."
 
@@ -148,6 +148,7 @@ def __init__(
         self.diversity = diversity
         self.doc_length = doc_length
         self.tokenizer = tokenizer
+        validate_truncate_document_parameters(self.tokenizer, self.doc_length)
 
     def extract_topics(
         self,
diff --git a/bertopic/representation/_llamacpp.py b/bertopic/representation/_llamacpp.py
index 83b18952..3fd3541b 100644
--- a/bertopic/representation/_llamacpp.py
+++ b/bertopic/representation/_llamacpp.py
@@ -4,7 +4,7 @@
 from llama_cpp import Llama
 from typing import Mapping, List, Tuple, Any, Union, Callable
 from bertopic.representation._base import BaseRepresentation
-from bertopic.representation._utils import truncate_document
+from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters
 
 
 DEFAULT_PROMPT = """
@@ -116,6 +116,7 @@ def __init__(
         self.diversity = diversity
         self.doc_length = doc_length
         self.tokenizer = tokenizer
+        validate_truncate_document_parameters(self.tokenizer, self.doc_length)
 
         self.prompts_ = []
 
diff --git a/bertopic/representation/_openai.py b/bertopic/representation/_openai.py
index 8fd25a1b..e05a9c66 100644
--- a/bertopic/representation/_openai.py
+++ b/bertopic/representation/_openai.py
@@ -8,6 +8,7 @@
 from bertopic.representation._utils import (
     retry_with_exponential_backoff,
     truncate_document,
+    validate_truncate_document_parameters,
 )
 
 
@@ -169,6 +170,8 @@ def __init__(
         self.diversity = diversity
         self.doc_length = doc_length
         self.tokenizer = tokenizer
+        validate_truncate_document_parameters(self.tokenizer, self.doc_length)
+
         self.prompts_ = []
 
         self.generator_kwargs = generator_kwargs
diff --git a/bertopic/representation/_textgeneration.py b/bertopic/representation/_textgeneration.py
index b028e575..ada27d38 100644
--- a/bertopic/representation/_textgeneration.py
+++ b/bertopic/representation/_textgeneration.py
@@ -5,7 +5,7 @@
 from transformers.pipelines.base import Pipeline
 from typing import Mapping, List, Tuple, Any, Union, Callable
 from bertopic.representation._base import BaseRepresentation
-from bertopic.representation._utils import truncate_document
+from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters
 
 
 DEFAULT_PROMPT = """
@@ -112,6 +112,7 @@ def __init__(
         self.diversity = diversity
         self.doc_length = doc_length
         self.tokenizer = tokenizer
+        validate_truncate_document_parameters(self.tokenizer, self.doc_length)
 
         self.prompts_ = []
 
diff --git a/bertopic/representation/_utils.py b/bertopic/representation/_utils.py
index 2a99fd1f..255c8fbe 100644
--- a/bertopic/representation/_utils.py
+++ b/bertopic/representation/_utils.py
@@ -1,8 +1,9 @@
 import random
 import time
+from typing import Union
 
 
-def truncate_document(topic_model, doc_length, tokenizer, document: str):
+def truncate_document(topic_model, doc_length: Union[int, None], tokenizer: Union[str, callable], document: str) -> str:
     """Truncate a document to a certain length.
 
     If you want to add a custom tokenizer, then it will need to have a `decode` and
@@ -58,6 +59,18 @@ def decode(self, doc_chunks):
     return document
 
 
+def validate_truncate_document_parameters(tokenizer, doc_length) -> Union[None, ValueError]:
+    """Validates parameters that are used in the function `truncate_document`."""
+    if tokenizer is None and doc_length is not None:
+        raise ValueError(
+            "Please select from one of the valid options for the `tokenizer` parameter: \n"
+            "{'char', 'whitespace', 'vectorizer'} \n"
+            "If `tokenizer` is of type callable ensure it has methods to encode and decode a document \n"
+        )
+    elif tokenizer is not None and doc_length is None:
+        raise ValueError("If `tokenizer` is provided, `doc_length` of type int must be provided as well.")
+
+
 def retry_with_exponential_backoff(
     func,
     initial_delay: float = 1,
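A minimal sketch of how the validation introduced in `_utils.py` is expected to behave once this patch is applied. It only exercises `validate_truncate_document_parameters` as defined in the diff; the concrete `tokenizer` and `doc_length` values are illustrative, not taken from the change itself.

```python
from bertopic.representation._utils import validate_truncate_document_parameters

# Both parameters set: validation passes and truncation can proceed downstream.
validate_truncate_document_parameters(tokenizer="whitespace", doc_length=100)

# Neither parameter set: validation passes and documents are left untruncated.
validate_truncate_document_parameters(tokenizer=None, doc_length=None)

# doc_length without a tokenizer: raises ValueError listing the valid tokenizer options.
try:
    validate_truncate_document_parameters(tokenizer=None, doc_length=100)
except ValueError as error:
    print(error)

# tokenizer without doc_length: raises ValueError asking for an integer doc_length.
try:
    validate_truncate_document_parameters(tokenizer="char", doc_length=None)
except ValueError as error:
    print(error)
```

Because each representation's `__init__` now calls this helper right after storing `self.tokenizer`, an inconsistent `tokenizer`/`doc_length` pair fails at construction time rather than later inside `truncate_document`.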