
Merge pull request #9 from jonnyjohnson1/jonny/llmclient_class
Jonny/llmclient class
jonnyjohnson1 authored Aug 15, 2024
2 parents b3f6ef6 + 43b5a50 commit e13dd43
Showing 9 changed files with 372 additions and 226 deletions.
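This commit replaces the free functions from topos/generations/ollama_chat.py (generate_response, stream_chat) with an LLMChatGens client class imported from topos/generations/chat_gens.py, and wraps mermaid-chart generation in a MermaidCreator class that takes that client. Judging only from how the class is constructed and called throughout the diff below, the interface it needs to expose can be sketched as follows; the method bodies and type hints here are assumptions, not the actual contents of chat_gens.py.

# Hypothetical sketch of the LLMChatGens interface, inferred from the call sites
# in this diff; the real topos/generations/chat_gens.py may differ.
from typing import Dict, Generator, List

class LLMChatGens:
    def __init__(self, model_name: str, provider: str = "ollama", api_key: str = "ollama"):
        # Constructor arguments mirror every instantiation in this commit.
        self.model_name = model_name
        self.provider = provider
        self.api_key = api_key

    def generate_response(self, system_prompt: str, query: str, temperature: float = 0) -> str:
        # Called as llm_client.generate_response(context, prompt, temperature=0);
        # returns the complete response as a string.
        raise NotImplementedError

    def stream_chat(self, message_history: List[Dict[str, str]], temperature: float = 0.04) -> Generator[str, None, None]:
        # Called as llm_client.stream_chat(simp_msg_history, temperature=...);
        # yields response chunks as strings.
        raise NotImplementedError

Every route and websocket handler below builds one of these per request, e.g. llm_client = LLMChatGens(model_name="dolphin-llama3", provider="ollama", api_key="ollama").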
79 changes: 66 additions & 13 deletions topos/api/api_routes.py
@@ -13,9 +13,9 @@
from collections import Counter, OrderedDict, defaultdict
from pydantic import BaseModel

from ..generations.ollama_chat import generate_response
from ..generations.chat_gens import LLMChatGens
from ..utilities.utils import create_conversation_string
from ..services.ontology_service.mermaid_chart import get_mermaid_chart
from ..services.ontology_service.mermaid_chart import MermaidCreator

cache_manager = ConversationCacheManager()
class ConversationIDRequest(BaseModel):
@@ -129,16 +129,23 @@ async def conv_to_image(request: ConversationIDRequest):
if conv_data is None:
raise HTTPException(status_code=404, detail="Conversation not found in cache")



# model specifications
# TODO UPDATE SO ITS NOT HARDCODED
model = "dolphin-llama3"
provider = 'ollama' # defaults to ollama right now
api_key = 'ollama'

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)

context = create_conversation_string(conv_data, 6)
print(context)
print(f"\t[ converting conversation to image to text prompt: using model {model}]")
conv_to_text_img_prompt = "Create an interesting, and compelling image-to-text prompt that can be used in a diffussor model. Be concise and convey more with the use of metaphor. Steer the image style towards Slavador Dali's fantastic, atmospheric, heroesque paintings that appeal to everyman themes."
txt_to_img_prompt = generate_response(context, conv_to_text_img_prompt, model=model, temperature=0)
txt_to_img_prompt = llm_client.generate_response(context, conv_to_text_img_prompt, temperature=0)
# print(txt_to_img_prompt)
print(f"\t[ generating a file name {model} ]")
txt_to_img_filename = generate_response(txt_to_img_prompt, "Based on the context create an appropriate, and BRIEF, filename with no spaces. Do not use any file extensions in your name, that will be added in a later step.", model=model, temperature=0)
txt_to_img_filename = llm_client.generate_response(txt_to_img_prompt, "Based on the context create an appropriate, and BRIEF, filename with no spaces. Do not use any file extensions in your name, that will be added in a later step.", temperature=0)

# run huggingface comic diffusion
pipeline = DiffusionPipeline.from_pretrained("ogkalu/Comic-Diffusion")
@@ -180,7 +187,15 @@ class GenNextMessageOptions(BaseModel):
async def create_next_messages(request: GenNextMessageOptions):
conversation_id = request.conversation_id
query = request.query

# model specifications
# TODO UPDATE SO ITS NOT HARDCODED
model = request.model if request.model is not None else "dolphin-llama3"
provider = 'ollama' # defaults to ollama right now
api_key = 'ollama'

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)

voice_settings = request.voice_settings if request.voice_settings is not None else """{
"tone": "analytical",
"distance": "distant",
@@ -211,7 +226,7 @@ async def create_next_messages(request: GenNextMessageOptions):
system_prompt += conv_json


next_message_options = generate_response(system_prompt, query, model=model, temperature=0)
next_message_options = llm_client.generate_response(system_prompt, query, temperature=0)
print(next_message_options)

# return the options
@@ -225,36 +240,65 @@ class ConversationTopicsRequest(BaseModel):
@router.post("/gen_conversation_topics")
async def create_next_messages(request: ConversationTopicsRequest):
conversation_id = request.conversation_id
# model specifications
# TODO UPDATE SO ITS NOT HARDCODED
model = request.model if request.model is not None else "dolphin-llama3"
provider = 'ollama' # defaults to ollama right now
api_key = 'ollama'

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)

# load conversation
conv_data = cache_manager.load_from_cache(conversation_id)
if conv_data is None:
raise HTTPException(status_code=404, detail="Conversation not found in cache")

context = create_conversation_string(conv_data, 12)
print(f"\t[ generating summary :: model {model} :: subject {subject}]")
# print(f"\t[ generating summary :: model {model} :: subject {subject}]")

query = f""
# topic list first pass
system_prompt = "PRESENT CONVERSATION:\n-------<context>" + context + "\n-------\n"
query += """List the topics and those closely related to what this conversation traverses."""
topic_list = generate_response(system_prompt, query, model=model, temperature=0)
topic_list = llm_client.generate_response(system_prompt, query, temperature=0)
print(topic_list)

# return the image
return {"response" : topic_list}


@router.post("/list_models")
async def list_models():
url = "http://localhost:11434/api/tags"
async def list_models(provider: str = 'ollama', api_key: str = 'ollama'):
# Define the URLs for different providers

list_models_urls = {
'ollama': "http://localhost:11434/api/tags",
'openai': "https://api.openai.com/v1/models",
'groq': "https://api.groq.com/openai/v1/models"
}

if provider not in list_models_urls:
raise HTTPException(status_code=400, detail="Unsupported provider")

# Get the appropriate URL based on the provider
url = list_models_urls.get(provider.lower())

if provider.lower() == 'ollama':
# No need for headers with Ollama
headers = {}
else:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}

try:
result = requests.get(url)
# Make the request with the appropriate headers
result = requests.get(url, headers=headers)
if result.status_code == 200:
return {"result": result.json()}
else:
raise HTTPException(status_code=404, detail="Models not found")
raise HTTPException(status_code=result.status_code, detail="Models not found")
except requests.ConnectionError:
raise HTTPException(status_code=500, detail="Server connection error")
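The reworked /list_models route now selects the provider's model-listing URL and, for openai and groq, forwards the api_key as a bearer token. A minimal sketch of calling it from Python is below; the http://localhost:8000 base URL is an assumption about where the topos API is served, not something stated in this diff.

import requests

# Hypothetical client call for the updated /list_models route. FastAPI treats the
# plain string parameters of this POST endpoint as query parameters.
resp = requests.post(
    "http://localhost:8000/list_models",  # assumed base URL
    params={"provider": "ollama", "api_key": "ollama"},  # or provider="openai" / "groq" with a real key
)
resp.raise_for_status()
print(resp.json()["result"])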

@@ -304,9 +348,18 @@ async def generate_mermaid_chart(payload: MermaidChartPayload):
try:
conversation_id = payload.conversation_id
full_conversation = payload.full_conversation
# model specifications
model = payload.model
provider = getattr(payload, 'provider', 'ollama') # defaults to ollama right now
api_key = getattr(payload, 'api_key', 'ollama')
temperature = payload.temperature

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)

mermaid_generator = MermaidCreator(llm_client)



if full_conversation:
cache_manager = ConversationCacheManager()
conv_data = cache_manager.load_from_cache(conversation_id)
@@ -321,7 +374,7 @@ async def generate_mermaid_chart(payload: MermaidChartPayload):
if message:
print(f"\t[ generating mermaid chart :: using model {model} ]")
try:
mermaid_string = await get_mermaid_chart(message)
mermaid_string = await mermaid_generator.get_mermaid_chart(message)
if mermaid_string == "Failed to generate mermaid":
return {"status": "error", "response": mermaid_string, 'completed': True}
else:
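get_mermaid_chart likewise moves from a module-level function to a method on a MermaidCreator instance constructed with the LLMChatGens client. Only the constructor and the async get_mermaid_chart(message, websocket=...) method are visible from the call sites, so the skeleton below is purely illustrative of that usage rather than the real mermaid_chart.py.

# Illustrative skeleton only; the actual MermaidCreator in
# topos/services/ontology_service/mermaid_chart.py may differ.
class MermaidCreator:
    def __init__(self, llm_client):
        self.llm_client = llm_client  # LLMChatGens instance used for generation

    async def get_mermaid_chart(self, message: str, websocket=None) -> str:
        # Builds a mermaid diagram string from the message via self.llm_client,
        # optionally pushing progress updates over the websocket when one is given.
        # Returns the mermaid string, or "Failed to generate mermaid" on failure.
        ...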
53 changes: 41 additions & 12 deletions topos/api/websocket_handlers.py
@@ -1,10 +1,10 @@
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
from datetime import datetime
import time
import traceback
import pprint

from ..generations.ollama_chat import stream_chat
from ..generations.chat_gens import LLMChatGens
# from topos.FC.semantic_compression import SemanticCompression
# from ..config import get_openai_api_key
from ..models.llm_classes import vision_models
@@ -13,7 +13,7 @@
from ..utilities.utils import create_conversation_string
from ..services.classification_service.base_analysis import base_text_classifier, base_token_classifier
from ..services.loggers.process_logger import ProcessLogger
from ..services.ontology_service.mermaid_chart import get_mermaid_chart
from ..services.ontology_service.mermaid_chart import MermaidCreator

# cache database
from topos.FC.conversation_cache_manager import ConversationCacheManager
@@ -55,11 +55,11 @@ async def chat(websocket: WebSocket):
chatbot_msg_id = payload["chatbot_msg_id"]
message = payload["message"]
message_history = payload["message_history"]
model = payload.get("model", "solar")
temperature = float(payload.get("temperature", 0.04))
current_topic = payload.get("topic", "Unknown")
processing_config = payload.get("processing_config", {})



# Set default values if any key is missing or if processing_config is None
default_config = {
"showInMessageNER": True,
@@ -68,6 +68,14 @@
"calculateModerationTags": True,
"showSidebarBaseAnalytics": True
}

# model specifications
model = payload.get("model", "solar")
provider = payload.get('provider', 'ollama') # defaults to ollama right now
api_key = payload.get('api_key', 'ollama')

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)


# Update default_config with provided processing_config, if any
config = {**default_config, **processing_config}
@@ -171,9 +179,9 @@ async def chat(websocket: WebSocket):
is_first_token = True
total_tokens = 0 # Initialize token counter
ttfs = 0 # init time to first token value
await process_logger.start("llm_generation_stream_chat", provider="ollama", model=model, len_msg_hist=len(simp_msg_history))
await process_logger.start("llm_generation_stream_chat", provider=provider, model=model, len_msg_hist=len(simp_msg_history))
start_time = time.time() # Track the start time for the whole process
for chunk in stream_chat(simp_msg_history, model=model, temperature=temperature):
for chunk in llm_client.stream_chat(simp_msg_history, temperature=temperature):
if len(chunk) > 0:
if is_first_token:
ttfs_end_time = time.time()
@@ -283,6 +291,14 @@ async def meta_chat(websocket: WebSocket):
temperature = float(payload.get("temperature", 0.04))
current_topic = payload.get("topic", "Unknown")


# model specifications
model = payload.get("model", "solar")
provider = payload.get('provider', 'ollama') # defaults to ollama right now
api_key = payload.get('api_key', 'ollama')

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)

# Set system prompt
system_prompt = f"""You are a highly skilled conversationalist, adept at communicating strategies and tactics. Help the user navigate their current conversation to determine what to say next.
You possess a private, unmentioned expertise: PhDs in CBT and DBT, an elegant, smart, provocative speech style, extensive world travel, and deep literary theory knowledge à la Terry Eagleton. Demonstrate your expertise through your guidance, without directly stating it."""
@@ -308,7 +324,7 @@ async def meta_chat(websocket: WebSocket):

# Processing the chat
output_combined = ""
for chunk in stream_chat(simp_msg_history, model=model, temperature=temperature):
for chunk in llm_client.stream_chat(simp_msg_history, temperature=temperature):
try:
output_combined += chunk
await websocket.send_json({"status": "generating", "response": output_combined, 'completed': False})
@@ -342,8 +358,15 @@ async def meta_chat(websocket: WebSocket):

conversation_id = payload["conversation_id"]
subject = payload.get("subject", "knowledge")
model = payload.get("model", "solar")
temperature = float(payload.get("temperature", 0.04))

# model specifications
model = payload.get("model", "solar")
provider = payload.get('provider', 'ollama') # defaults to ollama right now
api_key = payload.get('api_key', 'ollama')

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)


# load conversation
cache_manager = ConversationCacheManager()
@@ -367,7 +390,7 @@ async def meta_chat(websocket: WebSocket):

# Processing the chat
output_combined = ""
for chunk in stream_chat(msg_history, model=model, temperature=temperature):
for chunk in llm_client.stream_chat(msg_history, temperature=temperature):
try:
output_combined += chunk
await websocket.send_json({"status": "generating", "response": output_combined, 'completed': False})
@@ -400,9 +423,15 @@ async def meta_chat(websocket: WebSocket):
message = payload.get("message", None)
conversation_id = payload["conversation_id"]
full_conversation = payload.get("full_conversation", False)
# model specifications
model = payload.get("model", "dolphin-llama3")
provider = payload.get('provider', 'ollama') # defaults to ollama right now
api_key = payload.get('api_key', 'ollama')
temperature = float(payload.get("temperature", 0.04))

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)

mermaid_generator = MermaidCreator(llm_client)
# load conversation
if full_conversation:
cache_manager = ConversationCacheManager()
@@ -412,13 +441,13 @@ async def meta_chat(websocket: WebSocket):
print(f"\t[ generating mermaid chart :: using model {model} :: full conversation ]")
await websocket.send_json({"status": "generating", "response": "generating mermaid chart", 'completed': False})
context = create_conversation_string(conv_data, 12)
# TODO Coomplete this branch
# TODO Complete this branch
else:
if message:
print(f"\t[ generating mermaid chart :: using model {model} ]")
await websocket.send_json({"status": "generating", "response": "generating mermaid chart", 'completed': False})
try:
mermaid_string = await get_mermaid_chart(message, websocket = websocket)
mermaid_string = await mermaid_generator.get_mermaid_chart(message, websocket = websocket)
if mermaid_string == "Failed to generate mermaid":
await websocket.send_json({"status": "error", "response": mermaid_string, 'completed': True})
else:
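Each websocket handler in this file now reads optional provider and api_key fields from the incoming payload (both defaulting to ollama) and builds its own LLMChatGens instance before streaming. A chat payload would therefore look roughly like the dictionary below; the key names are taken from the handler's payload.get(...) calls, while the socket route itself is not shown in this diff.

# Example payload for the chat websocket handler; keys mirror the handler's
# payload.get(...) calls. The websocket URL/route is not part of this diff.
payload = {
    "chatbot_msg_id": "msg-123",
    "message": "Hello!",
    "message_history": [],
    "temperature": 0.04,
    "topic": "Unknown",
    "processing_config": {"showInMessageNER": True},
    "model": "dolphin-llama3",
    "provider": "ollama",  # or "openai" / "groq"
    "api_key": "ollama",   # a real key is required for hosted providers
}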
1 change: 0 additions & 1 deletion topos/channel/debatesim.py
@@ -33,7 +33,6 @@
from ..FC.argument_detection import ArgumentDetection
from ..config import get_openai_api_key
from ..models.llm_classes import vision_models
from ..generations.ollama_chat import stream_chat
from ..services.database.app_state import AppState
from ..utilities.utils import create_conversation_string
from ..services.classification_service.base_analysis import base_text_classifier, base_token_classifier
11 changes: 9 additions & 2 deletions topos/channel/experimental/debatesim_experimental_think.py
@@ -29,7 +29,7 @@
from ..FC.argument_detection import ArgumentDetection
from ..config import get_openai_api_key
from ..models.llm_classes import vision_models
from ..generations.ollama_chat import stream_chat
from ..generations.chat_gens import LLMChatGens
from ..services.database.app_state import AppState
from ..utilities.utils import create_conversation_string
from ..services.classification_service.base_analysis import base_text_classifier, base_token_classifier
@@ -338,7 +338,14 @@ async def debate_step(self, websocket: WebSocket, data, app_state):
user_id = app_state.get_value("user_id", "")

message_history = payload["message_history"]

# model specifications
model = payload.get("model", "solar")
provider = payload.get('provider', 'ollama') # defaults to ollama right now
api_key = payload.get('api_key', 'ollama')

llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)

temperature = float(payload.get("temperature", 0.04))
current_topic = payload.get("topic", "Unknown")

@@ -456,7 +463,7 @@ async def debate_step(self, websocket: WebSocket, data, app_state):

# Processing the chat
output_combined = ""
for chunk in stream_chat(simp_msg_history, model=model, temperature=temperature):
for chunk in llm_client.stream_chat(simp_msg_history, model=model, temperature=temperature):
output_combined += chunk
await websocket.send_json({"status": "generating", "response": output_combined, 'completed': False})

