From 7b8ba464baa0a6d4ef451eea469017e8fe5286dc Mon Sep 17 00:00:00 2001 From: jonny <32085184+jonnyjohnson1@users.noreply.github.com> Date: Tue, 13 Aug 2024 17:21:00 -0500 Subject: [PATCH] + models list api for alt llm providers --- topos/api/api_routes.py | 36 +++++++++++++++++++++++----------- topos/generations/groq_chat.py | 6 ------ 2 files changed, 25 insertions(+), 17 deletions(-) delete mode 100644 topos/generations/groq_chat.py diff --git a/topos/api/api_routes.py b/topos/api/api_routes.py index 9b3b576..2fb7a8c 100644 --- a/topos/api/api_routes.py +++ b/topos/api/api_routes.py @@ -268,23 +268,37 @@ async def create_next_messages(request: ConversationTopicsRequest): @router.post("/list_models") -async def list_models(): - # model specifications - # TODO UPDATE SO ITS NOT HARDCODED - # TODO UPDATE THIS ONE!!!! - # model = request.model if request.model != None else "dolphin-llama3" - # provider = 'ollama' # defaults to ollama right now - # api_key = 'ollama' +async def list_models(provider: str = 'ollama', api_key: str = 'ollama'): + # Define the URLs for different providers + + list_models_urls = { + 'ollama': "http://localhost:11434/api/tags", + 'openai': "https://api.openai.com/v1/models", + 'groq': "https://api.groq.com/openai/v1/models" + } + + if provider not in list_models_urls: + raise HTTPException(status_code=400, detail="Unsupported provider") + + # Get the appropriate URL based on the provider + url = list_models_urls.get(provider.lower()) - # llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key) + if provider.lower() == 'ollama': + # No need for headers with Ollama + headers = {} + else: + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } - url = "http://localhost:11434/api/tags" try: - result = requests.get(url) + # Make the request with the appropriate headers + result = requests.get(url, headers=headers) if result.status_code == 200: return {"result": result.json()} else: - raise HTTPException(status_code=404, detail="Models not found") + raise HTTPException(status_code=result.status_code, detail="Models not found") except requests.ConnectionError: raise HTTPException(status_code=500, detail="Server connection error") diff --git a/topos/generations/groq_chat.py b/topos/generations/groq_chat.py deleted file mode 100644 index 8dd8fa8..0000000 --- a/topos/generations/groq_chat.py +++ /dev/null @@ -1,6 +0,0 @@ -# STARTER SETUP FOR GROQ API INTEGRATION - -# from openai import OpenAI - -# groq_client = OpenAI(api_key=groq_api_key, base_url="https://api.groq.com/openai/v1") -# groq_model = "llama-3.1-70b-versatile" \ No newline at end of file