Skip to content

Commit

Permalink
Fixed Issue: #1977 (#2181)
Browse files Browse the repository at this point in the history
  • Loading branch information
SSivakumar12 authored Dec 9, 2024
1 parent c3ec85d commit 50d9a49
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 5 deletions.
4 changes: 3 additions & 1 deletion bertopic/representation/_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple, Union, Callable
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters


DEFAULT_PROMPT = """
Expand Down Expand Up @@ -124,6 +124,8 @@ def __init__(
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)

self.prompts_ = []

def extract_topics(
Expand Down
3 changes: 2 additions & 1 deletion bertopic/representation/_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, Mapping, List, Tuple, Union

from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters

DEFAULT_PROMPT = "What are these documents about? Please give a single label."

Expand Down Expand Up @@ -148,6 +148,7 @@ def __init__(
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)

def extract_topics(
self,
Expand Down
3 changes: 2 additions & 1 deletion bertopic/representation/_llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from llama_cpp import Llama
from typing import Mapping, List, Tuple, Any, Union, Callable
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters


DEFAULT_PROMPT = """
Expand Down Expand Up @@ -116,6 +116,7 @@ def __init__(
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)

self.prompts_ = []

Expand Down
3 changes: 3 additions & 0 deletions bertopic/representation/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from bertopic.representation._utils import (
retry_with_exponential_backoff,
truncate_document,
validate_truncate_document_parameters,
)


Expand Down Expand Up @@ -169,6 +170,8 @@ def __init__(
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)

self.prompts_ = []

self.generator_kwargs = generator_kwargs
Expand Down
3 changes: 2 additions & 1 deletion bertopic/representation/_textgeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from transformers.pipelines.base import Pipeline
from typing import Mapping, List, Tuple, Any, Union, Callable
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters


DEFAULT_PROMPT = """
Expand Down Expand Up @@ -112,6 +112,7 @@ def __init__(
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)

self.prompts_ = []

Expand Down
15 changes: 14 additions & 1 deletion bertopic/representation/_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import random
import time
from typing import Union


def truncate_document(topic_model, doc_length, tokenizer, document: str):
def truncate_document(topic_model, doc_length: Union[int, None], tokenizer: Union[str, callable], document: str) -> str:
"""Truncate a document to a certain length.
If you want to add a custom tokenizer, then it will need to have a `decode` and
Expand Down Expand Up @@ -58,6 +59,18 @@ def decode(self, doc_chunks):
return document


def validate_truncate_document_parameters(tokenizer, doc_length) -> None:
    """Validate the parameters that are used in the function `truncate_document`.

    `tokenizer` and `doc_length` have to be configured together: either both
    are left as `None`, or both are given. Any other combination is rejected
    here so that the misconfiguration surfaces when this check is called
    rather than later on.

    Arguments:
        tokenizer: The tokenizer setting to check. Only its presence is
                   validated (whether it is `None` or not); the value itself
                   is not inspected by this function.
        doc_length: The document length setting to check. Only its presence
                    is validated (whether it is `None` or not).

    Raises:
        ValueError: If `doc_length` is given without a `tokenizer`, or if a
                    `tokenizer` is given without a `doc_length`.
    """
    # A length limit without a tokenizer: there is no way to know in which
    # unit (characters, words, tokens) the limit should be applied.
    if tokenizer is None and doc_length is not None:
        raise ValueError(
            "Please select from one of the valid options for the `tokenizer` parameter: \n"
            "{'char', 'whitespace', 'vectorizer'} \n"
            "If `tokenizer` is of type callable ensure it has methods to encode and decode a document \n"
        )

    # A tokenizer without a length limit: there is nothing to truncate to.
    if tokenizer is not None and doc_length is None:
        raise ValueError("If `tokenizer` is provided, `doc_length` of type int must be provided as well.")


def retry_with_exponential_backoff(
func,
initial_delay: float = 1,
Expand Down

0 comments on commit 50d9a49

Please sign in to comment.