From 734f757f4d7d7681d10eda5a7e7a75ceb3ebba71 Mon Sep 17 00:00:00 2001
From: jonny <32085184+jonnyjohnson1@users.noreply.github.com>
Date: Fri, 23 Aug 2024 15:54:32 -0500
Subject: [PATCH] updates to controller and pydantic BaseModels

---
 topos/api/api_routes.py                 | 31 +++++++++++--------
 topos/api/websocket_handlers.py         | 16 +++-----
 .../debatesim_experimental_think.py     |  4 +--
 topos/generations/chat_gens.py          |  2 +-
 topos/generations/llm_client.py         |  2 +-
 .../ontology_service/mermaid_chart.py   |  8 ++---
 6 files changed, 35 insertions(+), 28 deletions(-)

diff --git a/topos/api/api_routes.py b/topos/api/api_routes.py
index 2fb7a8c..90bf159 100644
--- a/topos/api/api_routes.py
+++ b/topos/api/api_routes.py
@@ -13,7 +13,7 @@
 from collections import Counter, OrderedDict, defaultdict
 from pydantic import BaseModel
 
-from ..generations.chat_gens import LLMChatGens
+from ..generations.chat_gens import LLMController
 from ..utilities.utils import create_conversation_string
 from ..services.ontology_service.mermaid_chart import MermaidCreator
@@ -136,7 +136,7 @@ async def conv_to_image(request: ConversationIDRequest):
     provider = 'ollama' # defaults to ollama right now
     api_key = 'ollama'
 
-    llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
+    llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)
 
     context = create_conversation_string(conv_data, 6)
     print(context)
@@ -180,6 +180,8 @@ async def conv_to_image(request: ConversationIDRequest):
 class GenNextMessageOptions(BaseModel):
     conversation_id: str
     query: str
+    provider: str
+    api_key: str
     model: str
     voice_settings: dict
 
@@ -187,14 +189,14 @@ class GenNextMessageOptions(BaseModel):
 async def create_next_messages(request: GenNextMessageOptions):
     conversation_id = request.conversation_id
     query = request.query
-
+    print(request.provider, "/", request.model)
+    print(request.api_key)
     # model specifications
-    # TODO UPDATE SO ITS NOT HARDCODED
     model = request.model if request.model != None else "dolphin-llama3"
-    provider = 'ollama' # defaults to ollama right now
-    api_key = 'ollama'
+    provider = request.provider if request.provider != None else 'ollama' # defaults to ollama right now
+    api_key = request.api_key if request.api_key != None else 'ollama'
 
-    llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
+    llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)
 
     voice_settings = request.voice_settings if request.voice_settings != None else """{
                                     "tone": "analytical",
@@ -246,7 +248,7 @@ async def create_next_messages(request: ConversationTopicsRequest):
     provider = 'ollama' # defaults to ollama right now
     api_key = 'ollama'
 
-    llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
+    llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)
 
     # load conversation
     conv_data = cache_manager.load_from_cache(conversation_id)
@@ -341,6 +343,8 @@ class MermaidChartPayload(BaseModel):
     conversation_id: str
     full_conversation: bool = False
     model: str = "dolphin-llama3"
+    provider: str = "ollama"
+    api_key: str = "ollama"
     temperature: float = 0.04
 
 @router.post("/generate_mermaid_chart")
@@ -350,11 +354,11 @@ async def generate_mermaid_chart(payload: MermaidChartPayload):
     full_conversation = payload.full_conversation
     # model specifications
     model = payload.model
-    provider = payload.get('provider', 'ollama') # defaults to ollama right now
-    api_key = payload.get('api_key', 'ollama')
+    provider = payload.provider# defaults to ollama right now
+    api_key = payload.api_key
     temperature = payload.temperature
-
-    llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
+
+    llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)
 
     mermaid_generator = MermaidCreator(llm_client)
 
@@ -365,7 +369,7 @@ async def generate_mermaid_chart(payload: MermaidChartPayload):
         conv_data = cache_manager.load_from_cache(conversation_id)
         if conv_data is None:
             raise HTTPException(status_code=404, detail="Conversation not found in cache")
-        print(f"\t[ generating mermaid chart :: using model {model} :: full conversation ]")
+        print(f"\t[ generating mermaid chart :: {provider}/{model} :: full conversation ]")
         return {"status": "generating", "response": "generating mermaid chart", 'completed': False}
         # TODO: Complete this branch if needed
 
@@ -375,6 +379,7 @@ async def generate_mermaid_chart(payload: MermaidChartPayload):
         print(f"\t[ generating mermaid chart :: using model {model} ]")
         try:
             mermaid_string = await mermaid_generator.get_mermaid_chart(message)
+            print(mermaid_string)
             if mermaid_string == "Failed to generate mermaid":
                 return {"status": "error", "response": mermaid_string, 'completed': True}
             else:
diff --git a/topos/api/websocket_handlers.py b/topos/api/websocket_handlers.py
index a2ca06e..63796cd 100644
--- a/topos/api/websocket_handlers.py
+++ b/topos/api/websocket_handlers.py
@@ -4,7 +4,7 @@
 import traceback
 import pprint
 
-from ..generations.chat_gens import LLMChatGens
+from ..generations.chat_gens import LLMController
 # from topos.FC.semantic_compression import SemanticCompression
 # from ..config import get_openai_api_key
 from ..models.llm_classes import vision_models
@@ -49,6 +49,7 @@ async def chat(websocket: WebSocket):
         while True:
             data = await websocket.receive_text()
             payload = json.loads(data)
+            print(payload)
             conversation_id = payload["conversation_id"]
             message_id = payload["message_id"]
             chatbot_msg_id = payload["chatbot_msg_id"]
@@ -72,8 +73,8 @@ async def chat(websocket: WebSocket):
             model = payload.get("model", "solar")
             provider = payload.get('provider', 'ollama') # defaults to ollama right now
             api_key = payload.get('api_key', 'ollama')
-
-            llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
+            print("inputs", provider, api_key)
+            llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)
 
 
             # Update default_config with provided processing_config, if any
@@ -289,13 +290,14 @@ async def meta_chat(websocket: WebSocket):
 
             temperature = float(payload.get("temperature", 0.04))
             current_topic = payload.get("topic", "Unknown")
-
             # model specifications
             model = payload.get("model", "solar")
             provider = payload.get('provider', 'ollama') # defaults to ollama right now
             api_key = payload.get('api_key', 'ollama')
+            print(provider,"/",model)
+            print(api_key)
 
-            llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
+            llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)
 
             # Set system prompt
             system_prompt = f"""You are a highly skilled conversationalist, adept at communicating strategies and tactics. Help the user navigate their current conversation to determine what to say next. 
@@ -363,7 +365,7 @@ async def meta_chat(websocket: WebSocket):
             provider = payload.get('provider', 'ollama') # defaults to ollama right now
             api_key = payload.get('api_key', 'ollama')
 
-            llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
+            llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)
 
 
             # load conversation
@@ -427,7 +429,7 @@ async def meta_chat(websocket: WebSocket):
             api_key = payload.get('api_key', 'ollama')
             temperature = float(payload.get("temperature", 0.04))
 
-            llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
+            llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)
             mermaid_generator = MermaidCreator(llm_client)
 
             # load conversation
diff --git a/topos/channel/experimental/debatesim_experimental_think.py b/topos/channel/experimental/debatesim_experimental_think.py
index c32eea6..3189547 100644
--- a/topos/channel/experimental/debatesim_experimental_think.py
+++ b/topos/channel/experimental/debatesim_experimental_think.py
@@ -29,7 +29,7 @@
 from ..FC.argument_detection import ArgumentDetection
 from ..config import get_openai_api_key
 from ..models.llm_classes import vision_models
-from ..generations.chat_gens import LLMChatGens
+from ..generations.chat_gens import LLMController
 from ..services.database.app_state import AppState
 from ..utilities.utils import create_conversation_string
 from ..services.classification_service.base_analysis import base_text_classifier, base_token_classifier
@@ -344,7 +344,7 @@ async def debate_step(self, websocket: WebSocket, data, app_state):
         provider = payload.get('provider', 'ollama') # defaults to ollama right now
         api_key = payload.get('api_key', 'ollama')
 
-        llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
+        llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)
 
         temperature = float(payload.get("temperature", 0.04))
         current_topic = payload.get("topic", "Unknown")
diff --git a/topos/generations/chat_gens.py b/topos/generations/chat_gens.py
index c353021..4f7456f 100644
--- a/topos/generations/chat_gens.py
+++ b/topos/generations/chat_gens.py
@@ -10,7 +10,7 @@
     "ollama": "dolphin-llama3"
 }
 
-class LLMChatGens:
+class LLMController:
     def __init__(self, model_name: str, provider: str, api_key: str):
         self.provier = provider
         self.api_key = api_key
diff --git a/topos/generations/llm_client.py b/topos/generations/llm_client.py
index dad88dd..ef1282f 100644
--- a/topos/generations/llm_client.py
+++ b/topos/generations/llm_client.py
@@ -13,7 +13,7 @@ def __init__(self, provider: str, api_key: str):
         self.provider = provider.lower()
         self.api_key = api_key
         self.client = self._init_client()
-        print(f"Init client:: {self.provider}")
+        print(f"Init client :: {self.provider}")
 
     def _init_client(self):
         if self.provider == "openai":
diff --git a/topos/services/ontology_service/mermaid_chart.py b/topos/services/ontology_service/mermaid_chart.py
index 629aa06..e576f3d 100644
--- a/topos/services/ontology_service/mermaid_chart.py
+++ b/topos/services/ontology_service/mermaid_chart.py
@@ -2,11 +2,11 @@
 import re
 
 from topos.FC.ontological_feature_detection import OntologicalFeatureDetection
-from topos.generations.chat_gens import LLMChatGens
+from topos.generations.chat_gens import LLMController
 
 class MermaidCreator:
-    def __init__(self, llmChatGens: LLMChatGens):
-        self.client = llmChatGens
+    def __init__(self, LLMController: LLMController):
+        self.client = LLMController
 
     def get_ontology_old_method(self, message):
         user_id = "jonny"
@@ -163,7 +163,7 @@ async def get_mermaid_chart(self, message, websocket = None):
                 if websocket:
                     await websocket.send_json({"status": "generating", "response": f"generating mermaid_chart_from_triples :: try {attempt + 1}", 'completed': False})
                 response = self.client.generate_response_messages(message_history)
-                mermaid_chart = self.client.extract_mermaid_chart(response)
+                mermaid_chart = self.extract_mermaid_chart(response)
                 if mermaid_chart:
                     # refined_mermaid_chart = refine_mermaid_lines(mermaid_chart)
                     return mermaid_chart