diff --git a/topos/generations/chat_gens.py b/topos/generations/chat_gens.py
index 9944bc5..c353021 100644
--- a/topos/generations/chat_gens.py
+++ b/topos/generations/chat_gens.py
@@ -4,12 +4,27 @@
 
 # Assuming OpenAI is a pre-defined client for API interactions
 
+# Fallback model per provider, used when the caller supplies no model name.
+default_models = {
+    "groq": "llama-3.1-70b-versatile",
+    "openai": "gpt-4o",
+    "ollama": "dolphin-llama3",
+}
+
 class LLMChatGens:
     def __init__(self, model_name: str, provider: str, api_key: str):
         self.provier = provider
         self.api_key = api_key
         self.client = LLMClient(provider, api_key).get_client()
-        self.model_name = model_name
+        self.model_name = self._init_model(model_name, provider)
+
+    def _init_model(self, model_name: str, provider: str) -> str:
+        # Prefer the explicitly requested model; otherwise fall back to the
+        # provider's default. (Previously the 'ollama' branch returned the
+        # empty string, leaving default_models["ollama"] unused.)
+        if model_name:
+            return model_name
+        return default_models[provider]
 
     def stream_chat(self, message_history: List[Dict[str, str]], temperature: float = 0) -> Generator[str, None, None]:
         try: