From 43b5a5070b58c8b82f2342c3cc11003483778355 Mon Sep 17 00:00:00 2001
From: jonny <32085184+jonnyjohnson1@users.noreply.github.com>
Date: Tue, 13 Aug 2024 17:29:51 -0500
Subject: [PATCH] + default model on llmchatgens init

---
 topos/generations/chat_gens.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/topos/generations/chat_gens.py b/topos/generations/chat_gens.py
index 9944bc5..c353021 100644
--- a/topos/generations/chat_gens.py
+++ b/topos/generations/chat_gens.py
@@ -4,12 +4,26 @@
 
 # Assuming OpenAI is a pre-defined client for API interactions
 
+# Fallback model for each supported provider, used when the caller passes
+# an empty model name to LLMChatGens.
+default_models = {
+    "groq": "llama-3.1-70b-versatile",
+    "openai": "gpt-4o",
+    "ollama": "dolphin-llama3",
+}
+
 class LLMChatGens:
     def __init__(self, model_name: str, provider: str, api_key: str):
         self.provier = provider
         self.api_key = api_key
         self.client = LLMClient(provider, api_key).get_client()
-        self.model_name = model_name
+        self.model_name = self._init_model(model_name, provider)
+
+    def _init_model(self, model_name: str, provider: str) -> str:
+        """Return model_name if non-empty, else the provider's default model."""
+        if model_name:
+            return model_name
+        return default_models[provider]
 
     def stream_chat(self, message_history: List[Dict[str, str]], temperature: float = 0) -> Generator[str, None, None]:
         try: