Commit
updates to controller and pydantic BaseModels
jonnyjohnson1 committed Aug 23, 2024
1 parent d97f98c commit 734f757
Showing 6 changed files with 35 additions and 28 deletions.
31 changes: 18 additions & 13 deletions topos/api/api_routes.py
@@ -13,7 +13,7 @@
from collections import Counter, OrderedDict, defaultdict
from pydantic import BaseModel

from ..generations.chat_gens import LLMChatGens
from ..generations.chat_gens import LLMController
from ..utilities.utils import create_conversation_string
from ..services.ontology_service.mermaid_chart import MermaidCreator

@@ -136,7 +136,7 @@ async def conv_to_image(request: ConversationIDRequest):
provider = 'ollama' # defaults to ollama right now
api_key = 'ollama'

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)

context = create_conversation_string(conv_data, 6)
print(context)
@@ -180,21 +180,23 @@ async def conv_to_image(request: ConversationIDRequest):
class GenNextMessageOptions(BaseModel):
conversation_id: str
query: str
provider: str
api_key: str
model: str
voice_settings: dict

@router.post("/gen_next_message_options")
async def create_next_messages(request: GenNextMessageOptions):
conversation_id = request.conversation_id
query = request.query

print(request.provider, "/", request.model)
print(request.api_key)
# model specifications
# TODO UPDATE SO ITS NOT HARDCODED
model = request.model if request.model != None else "dolphin-llama3"
provider = 'ollama' # defaults to ollama right now
api_key = 'ollama'
provider = request.provider if request.provider != None else 'ollama' # defaults to ollama right now
api_key = request.api_key if request.api_key != None else 'ollama'

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)

voice_settings = request.voice_settings if request.voice_settings != None else """{
"tone": "analytical",
@@ -246,7 +248,7 @@ async def create_next_messages(request: ConversationTopicsRequest):
provider = 'ollama' # defaults to ollama right now
api_key = 'ollama'

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)

# load conversation
conv_data = cache_manager.load_from_cache(conversation_id)
@@ -341,6 +343,8 @@ class MermaidChartPayload(BaseModel):
conversation_id: str
full_conversation: bool = False
model: str = "dolphin-llama3"
provider: str = "ollama"
api_key: str = "ollama"
temperature: float = 0.04

@router.post("/generate_mermaid_chart")
@@ -350,11 +354,11 @@ async def generate_mermaid_chart(payload: MermaidChartPayload):
full_conversation = payload.full_conversation
# model specifications
model = payload.model
provider = payload.get('provider', 'ollama') # defaults to ollama right now
api_key = payload.get('api_key', 'ollama')
provider = payload.provider  # defaults to ollama right now
api_key = payload.api_key
temperature = payload.temperature

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)

mermaid_generator = MermaidCreator(llm_client)

@@ -365,7 +369,7 @@ async def generate_mermaid_chart(payload: MermaidChartPayload):
conv_data = cache_manager.load_from_cache(conversation_id)
if conv_data is None:
raise HTTPException(status_code=404, detail="Conversation not found in cache")
print(f"\t[ generating mermaid chart :: using model {model} :: full conversation ]")
print(f"\t[ generating mermaid chart :: {provider}/{model} :: full conversation ]")
return {"status": "generating", "response": "generating mermaid chart", 'completed': False}
# TODO: Complete this branch if needed

@@ -375,6 +379,7 @@ async def generate_mermaid_chart(payload: MermaidChartPayload):
print(f"\t[ generating mermaid chart :: using model {model} ]")
try:
mermaid_string = await mermaid_generator.get_mermaid_chart(message)
print(mermaid_string)
if mermaid_string == "Failed to generate mermaid":
return {"status": "error", "response": mermaid_string, 'completed': True}
else:
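The main behavioral fix in this file: provider and api_key are now declared as fields on the MermaidChartPayload BaseModel with "ollama" defaults, and the handler reads them as attributes. The old payload.get('provider', 'ollama') calls would fail, since Pydantic models are not dicts. A minimal sketch of the pattern (standalone FastAPI app; the route body and return value here are illustrative only):

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class MermaidChartPayload(BaseModel):
    # Fields mirror the ones added in this commit; defaults replace the old .get() fallbacks.
    conversation_id: str
    full_conversation: bool = False
    model: str = "dolphin-llama3"
    provider: str = "ollama"
    api_key: str = "ollama"
    temperature: float = 0.04

@app.post("/generate_mermaid_chart")
async def generate_mermaid_chart(payload: MermaidChartPayload):
    # Attribute access works on a BaseModel; payload.get(...) raises AttributeError.
    provider = payload.provider
    api_key = payload.api_key
    return {"status": "ok", "provider": provider, "model": payload.model, "api_key": api_key}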
16 changes: 9 additions & 7 deletions topos/api/websocket_handlers.py
@@ -4,7 +4,7 @@
import traceback
import pprint

from ..generations.chat_gens import LLMChatGens
from ..generations.chat_gens import LLMController
# from topos.FC.semantic_compression import SemanticCompression
# from ..config import get_openai_api_key
from ..models.llm_classes import vision_models
@@ -49,6 +49,7 @@ async def chat(websocket: WebSocket):
while True:
data = await websocket.receive_text()
payload = json.loads(data)
print(payload)
conversation_id = payload["conversation_id"]
message_id = payload["message_id"]
chatbot_msg_id = payload["chatbot_msg_id"]
@@ -72,8 +73,8 @@
model = payload.get("model", "solar")
provider = payload.get('provider', 'ollama') # defaults to ollama right now
api_key = payload.get('api_key', 'ollama')

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
print("inputs", provider, api_key)
llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)


# Update default_config with provided processing_config, if any
@@ -289,13 +290,14 @@ async def meta_chat(websocket: WebSocket):
temperature = float(payload.get("temperature", 0.04))
current_topic = payload.get("topic", "Unknown")


# model specifications
model = payload.get("model", "solar")
provider = payload.get('provider', 'ollama') # defaults to ollama right now
api_key = payload.get('api_key', 'ollama')
print(provider,"/",model)
print(api_key)

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)

# Set system prompt
system_prompt = f"""You are a highly skilled conversationalist, adept at communicating strategies and tactics. Help the user navigate their current conversation to determine what to say next.
@@ -363,7 +365,7 @@ async def meta_chat(websocket: WebSocket):
provider = payload.get('provider', 'ollama') # defaults to ollama right now
api_key = payload.get('api_key', 'ollama')

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)


# load conversation
@@ -427,7 +429,7 @@ async def meta_chat(websocket: WebSocket):
api_key = payload.get('api_key', 'ollama')
temperature = float(payload.get("temperature", 0.04))

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)

mermaid_generator = MermaidCreator(llm_client)
# load conversation
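In the websocket handlers the model settings still arrive inside the JSON message and are read with payload.get(...) defaults before the LLMController is constructed. A hedged example of that read pattern with an illustrative payload (only keys visible in this diff are shown; real messages carry additional fields such as the message text):

import json

# Illustrative client message; values are made up, keys are the ones read above.
data = json.dumps({
    "conversation_id": "conv-123",
    "message_id": "msg-1",
    "chatbot_msg_id": "bot-1",
    "model": "dolphin-llama3",
    "provider": "ollama",
    "api_key": "ollama",
    "temperature": 0.04,
})

payload = json.loads(data)
model = payload.get("model", "solar")          # falls back to "solar" if omitted
provider = payload.get("provider", "ollama")   # defaults to ollama right now
api_key = payload.get("api_key", "ollama")
temperature = float(payload.get("temperature", 0.04))
print("inputs", provider, api_key)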
4 changes: 2 additions & 2 deletions topos/channel/experimental/debatesim_experimental_think.py
@@ -29,7 +29,7 @@
from ..FC.argument_detection import ArgumentDetection
from ..config import get_openai_api_key
from ..models.llm_classes import vision_models
from ..generations.chat_gens import LLMChatGens
from ..generations.chat_gens import LLMController
from ..services.database.app_state import AppState
from ..utilities.utils import create_conversation_string
from ..services.classification_service.base_analysis import base_text_classifier, base_token_classifier
@@ -344,7 +344,7 @@ async def debate_step(self, websocket: WebSocket, data, app_state):
provider = payload.get('provider', 'ollama') # defaults to ollama right now
api_key = payload.get('api_key', 'ollama')

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)

temperature = float(payload.get("temperature", 0.04))
current_topic = payload.get("topic", "Unknown")
2 changes: 1 addition & 1 deletion topos/generations/chat_gens.py
@@ -10,7 +10,7 @@
"ollama": "dolphin-llama3"
}

class LLMChatGens:
class LLMController:
def __init__(self, model_name: str, provider: str, api_key: str):
self.provier = provider
self.api_key = api_key
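Only the class rename (LLMChatGens to LLMController) and a few context lines are visible in this hunk, so the rest of the class body is not shown (the existing context line keeps the original "provier" spelling). A rough sketch of what such a controller could look like, assuming it wraps the client from llm_client.py and falls back to a per-provider default model; the LLMClient class name and the fallback behavior are assumptions, not the repo's confirmed API:

from topos.generations.llm_client import LLMClient  # assumed class name in llm_client.py

DEFAULT_MODELS = {
    "ollama": "dolphin-llama3",  # default shown in the hunk above; other providers assumed
}

class LLMController:
    def __init__(self, model_name: str, provider: str, api_key: str):
        self.provider = provider
        self.api_key = api_key
        # Fall back to the provider's default model when none is supplied (assumption).
        self.model_name = model_name or DEFAULT_MODELS.get(provider, "dolphin-llama3")
        self.client = LLMClient(provider=provider, api_key=api_key)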
2 changes: 1 addition & 1 deletion topos/generations/llm_client.py
@@ -13,7 +13,7 @@ def __init__(self, provider: str, api_key: str):
self.provider = provider.lower()
self.api_key = api_key
self.client = self._init_client()
print(f"Init client:: {self.provider}")
print(f"Init client :: {self.provider}")

def _init_client(self):
if self.provider == "openai":
8 changes: 4 additions & 4 deletions topos/services/ontology_service/mermaid_chart.py
@@ -2,11 +2,11 @@
import re

from topos.FC.ontological_feature_detection import OntologicalFeatureDetection
from topos.generations.chat_gens import LLMChatGens
from topos.generations.chat_gens import LLMController

class MermaidCreator:
def __init__(self, llmChatGens: LLMChatGens):
self.client = llmChatGens
def __init__(self, LLMController: LLMController):
self.client = LLMController

def get_ontology_old_method(self, message):
user_id = "jonny"
@@ -163,7 +163,7 @@ async def get_mermaid_chart(self, message, websocket = None):
if websocket:
await websocket.send_json({"status": "generating", "response": f"generating mermaid_chart_from_triples :: try {attempt + 1}", 'completed': False})
response = self.client.generate_response_messages(message_history)
mermaid_chart = self.client.extract_mermaid_chart(response)
mermaid_chart = self.extract_mermaid_chart(response)
if mermaid_chart:
# refined_mermaid_chart = refine_mermaid_lines(mermaid_chart)
return mermaid_chart
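The functional fix here is that extract_mermaid_chart is now called on the MermaidCreator itself instead of on the LLM client. The helper's body is outside this diff; a plausible sketch, assuming the LLM returns the diagram inside a fenced mermaid code block (the regex and the return convention are illustrative, not the repo's implementation):

import re

# Sketch of a helper that could live on MermaidCreator (illustrative only).
def extract_mermaid_chart(response: str):
    # Pull the contents of a fenced ```mermaid ... ``` block out of the model response (assumed format).
    match = re.search(r"```mermaid\s*(.*?)```", response, re.DOTALL)
    if match:
        return match.group(1).strip()
    return None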
