Commit 69b479c (1 parent: 289d16c). Showing 16 changed files with 1,176 additions and 13 deletions.
Two binary files are not shown.
@@ -0,0 +1,130 @@
import os
import json

from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect

from topos.FC.conversation_cache_manager import ConversationCacheManager

from ....generations.chat_gens import LLMController
from ....utilities.utils import create_conversation_string
from ....services.ontology_service.mermaid_chart import MermaidCreator
from ....models.models import MermaidChartPayload

import logging

router = APIRouter()

db_config = {
    "dbname": os.getenv("POSTGRES_DB"),
    "user": os.getenv("POSTGRES_USER"),
    "password": os.getenv("POSTGRES_PASSWORD"),
    "host": os.getenv("POSTGRES_HOST"),
    "port": os.getenv("POSTGRES_PORT")
}

logging.info(f"Database configuration: {db_config}")

use_postgres = True
if use_postgres:
    cache_manager = ConversationCacheManager(use_postgres=True, db_config=db_config)
else:
    cache_manager = ConversationCacheManager()


@router.post("/generate_mermaid_chart")
async def generate_mermaid_chart(payload: MermaidChartPayload):
    try:
        conversation_id = payload.conversation_id
        full_conversation = payload.full_conversation
        # model specifications
        model = payload.model
        provider = payload.provider  # defaults to ollama right now
        api_key = payload.api_key
        temperature = payload.temperature

        llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)

        mermaid_generator = MermaidCreator(llm_client)

        if full_conversation:
            # use the module-level cache manager (re-assigning it locally would raise UnboundLocalError)
            conv_data = cache_manager.load_from_cache(conversation_id)
            if conv_data is None:
                raise HTTPException(status_code=404, detail="Conversation not found in cache")
            print(f"\t[ generating mermaid chart :: {provider}/{model} :: full conversation ]")
            return {"status": "generating", "response": "generating mermaid chart", 'completed': False}
            # TODO: Complete this branch if needed

        else:
            message = payload.message
            if message:
                print(f"\t[ generating mermaid chart :: using model {model} ]")
                try:
                    mermaid_string = await mermaid_generator.get_mermaid_chart(message)
                    print(mermaid_string)
                    if mermaid_string == "Failed to generate mermaid":
                        return {"status": "error", "response": mermaid_string, 'completed': True}
                    else:
                        return {"status": "completed", "response": mermaid_string, 'completed': True}
                except Exception as e:
                    return {"status": "error", "response": f"Error: {e}", 'completed': True}

    except Exception as e:
        return {"status": "error", "message": str(e)}


@router.websocket("/websocket_mermaid_chart")
async def meta_chat(websocket: WebSocket):
    """
    Generates a mermaid chart from a list of messages.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            payload = json.loads(data)
            message = payload.get("message", None)
            conversation_id = payload["conversation_id"]
            full_conversation = payload.get("full_conversation", False)
            # model specifications
            model = payload.get("model", "dolphin-llama3")
            provider = payload.get('provider', 'ollama')  # defaults to ollama right now
            api_key = payload.get('api_key', 'ollama')
            temperature = float(payload.get("temperature", 0.04))

            llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)

            mermaid_generator = MermaidCreator(llm_client)
            # load conversation
            if full_conversation:
                # use the module-level cache manager (re-assigning it locally would raise UnboundLocalError)
                conv_data = cache_manager.load_from_cache(conversation_id)
                if conv_data is None:
                    raise HTTPException(status_code=404, detail="Conversation not found in cache")
                print(f"\t[ generating mermaid chart :: using model {model} :: full conversation ]")
                await websocket.send_json({"status": "generating", "response": "generating mermaid chart", 'completed': False})
                context = create_conversation_string(conv_data, 12)
                # TODO: Complete this branch
            else:
                if message:
                    print(f"\t[ generating mermaid chart :: using model {model} ]")
                    await websocket.send_json({"status": "generating", "response": "generating mermaid chart", 'completed': False})
                    try:
                        mermaid_string = await mermaid_generator.get_mermaid_chart(message, websocket=websocket)
                        if mermaid_string == "Failed to generate mermaid":
                            await websocket.send_json({"status": "error", "response": mermaid_string, 'completed': True})
                        else:
                            await websocket.send_json({"status": "completed", "response": mermaid_string, 'completed': True})
                    except Exception as e:
                        await websocket.send_json({"status": "error", "response": f"Error: {e}", 'completed': True})
    except WebSocketDisconnect:
        print("WebSocket disconnected")
    except Exception as e:
        await websocket.send_json({"status": "error", "message": str(e)})
    finally:
        # close once here; closing in the except block as well would double-close the socket
        await websocket.close()
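For reference, a minimal client sketch for the two routes above. It is a sketch only: it assumes the router is mounted at the application root on localhost:8000 and that MermaidChartPayload accepts the fields read in the handler (conversation_id, message, full_conversation, model, provider, api_key, temperature); the conversation ID and message text are made up.

# Hypothetical usage sketch; host, port, and route prefix are assumptions.
import asyncio
import json

import requests
import websockets

BASE = "http://localhost:8000"  # assumed mount point

# POST route: single-message chart generation
resp = requests.post(f"{BASE}/generate_mermaid_chart", json={
    "conversation_id": "conv-123",
    "message": "Alice asks Bob about the deployment pipeline.",
    "full_conversation": False,
    "model": "dolphin-llama3",
    "provider": "ollama",
    "api_key": "ollama",
    "temperature": 0.04,
})
print(resp.json())

# WebSocket route: the handler streams status frames as JSON
async def ws_demo():
    async with websockets.connect("ws://localhost:8000/websocket_mermaid_chart") as ws:
        await ws.send(json.dumps({
            "conversation_id": "conv-123",
            "message": "Alice asks Bob about the deployment pipeline.",
        }))
        while True:
            frame = json.loads(await ws.recv())
            print(frame)
            if frame.get("completed"):
                break

asyncio.run(ws_demo())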
@@ -0,0 +1,94 @@
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
import json
import os

from ....generations.chat_gens import LLMController
from ....utilities.utils import create_conversation_string

# cache database
from topos.FC.conversation_cache_manager import ConversationCacheManager

import logging

db_config = {
    "dbname": os.getenv("POSTGRES_DB"),
    "user": os.getenv("POSTGRES_USER"),
    "password": os.getenv("POSTGRES_PASSWORD"),
    "host": os.getenv("POSTGRES_HOST"),
    "port": os.getenv("POSTGRES_PORT")
}

logging.info(f"Database configuration: {db_config}")

use_postgres = True
if use_postgres:
    cache_manager = ConversationCacheManager(use_postgres=True, db_config=db_config)
else:
    cache_manager = ConversationCacheManager()

router = APIRouter()

@router.websocket("/websocket_chat_summary")
async def meta_chat(websocket: WebSocket):
    """
    Generates a summary of the conversation oriented around a given focal point.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            payload = json.loads(data)

            conversation_id = payload["conversation_id"]
            subject = payload.get("subject", "knowledge")
            temperature = float(payload.get("temperature", 0.04))

            # model specifications
            model = payload.get("model", "solar")
            provider = payload.get('provider', 'ollama')  # defaults to ollama right now
            api_key = payload.get('api_key', 'ollama')

            llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)

            # load conversation from the module-level cache manager
            # (re-assigning it locally would raise UnboundLocalError)
            conv_data = cache_manager.load_from_cache(conversation_id)
            if conv_data is None:
                raise HTTPException(status_code=404, detail="Conversation not found in cache")

            context = create_conversation_string(conv_data, 12)

            print(f"\t[ generating summary :: model {model} :: subject {subject} ]")

            # Set system prompt
            system_prompt = "PRESENT CONVERSATION:\n-------<context>" + context + "\n-------\n"
            query = f"""Summarize this conversation. Frame your response around the subject of {subject}"""

            msg_history = [{'role': 'system', 'content': system_prompt}]

            # Append the present message to the message history
            simplified_message = {'role': "user", 'content': query}
            msg_history.append(simplified_message)

            # Processing the chat
            output_combined = ""
            for chunk in llm_client.stream_chat(msg_history, temperature=temperature):
                try:
                    output_combined += chunk
                    await websocket.send_json({"status": "generating", "response": output_combined, 'completed': False})
                except Exception as e:
                    print(e)
                    await websocket.send_json({"status": "error", "message": str(e)})
                    await websocket.close()
            # Send the final completed message
            await websocket.send_json(
                {"status": "completed", "response": output_combined, "completed": True})

    except WebSocketDisconnect:
        print("WebSocket disconnected")
    except Exception as e:
        await websocket.send_json({"status": "error", "message": str(e)})
        await websocket.close()
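A minimal sketch of driving the summary websocket, assuming the server runs on localhost:8000 with the router mounted at the root and that a conversation with the given ID already exists in the cache; the field names mirror the payload.get(...) calls above, and the conversation ID is made up.

# Hypothetical usage sketch; host, port, and conversation_id are assumptions.
import asyncio
import json

import websockets

async def summarize(conversation_id: str, subject: str = "knowledge") -> str:
    async with websockets.connect("ws://localhost:8000/websocket_chat_summary") as ws:
        await ws.send(json.dumps({
            "conversation_id": conversation_id,
            "subject": subject,
            "model": "solar",
            "provider": "ollama",
            "api_key": "ollama",
            "temperature": 0.04,
        }))
        # the endpoint streams partial summaries, then a final frame with completed=True
        while True:
            frame = json.loads(await ws.recv())
            if frame.get("completed"):
                return frame["response"]

print(asyncio.run(summarize("conv-123")))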
@@ -0,0 +1,58 @@
import os
from fastapi import APIRouter, HTTPException
from topos.FC.conversation_cache_manager import ConversationCacheManager

from ....generations.chat_gens import LLMController
from ....utilities.utils import create_conversation_string
from ....models.models import ConversationTopicsRequest

import logging

db_config = {
    "dbname": os.getenv("POSTGRES_DB"),
    "user": os.getenv("POSTGRES_USER"),
    "password": os.getenv("POSTGRES_PASSWORD"),
    "host": os.getenv("POSTGRES_HOST"),
    "port": os.getenv("POSTGRES_PORT")
}

logging.info(f"Database configuration: {db_config}")

use_postgres = True
if use_postgres:
    cache_manager = ConversationCacheManager(use_postgres=True, db_config=db_config)
else:
    cache_manager = ConversationCacheManager()

router = APIRouter()

@router.post("/get_files")
async def create_next_messages(request: ConversationTopicsRequest):
    conversation_id = request.conversation_id
    # model specifications
    # TODO: update so it's not hardcoded
    model = request.model if request.model is not None else "dolphin-llama3"
    provider = 'ollama'  # defaults to ollama right now
    api_key = 'ollama'

    llm_client = LLMController(model_name=model, provider=provider, api_key=api_key)

    # load conversation
    conv_data = cache_manager.load_from_cache(conversation_id)
    if conv_data is None:
        raise HTTPException(status_code=404, detail="Conversation not found in cache")

    context = create_conversation_string(conv_data, 12)
    # print(f"\t[ generating summary :: model {model} :: subject {subject} ]")

    # topic list first pass
    system_prompt = "PRESENT CONVERSATION:\n-------<context>" + context + "\n-------\n"
    query = "List the topics and those closely related to what this conversation traverses."
    topic_list = llm_client.generate_response(system_prompt, query, temperature=0)
    print(topic_list)

    # return the topic list
    return {"response": topic_list}
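A minimal request sketch for the route above, assuming the app is served on localhost:8000 with the router mounted at the root and that ConversationTopicsRequest exposes at least conversation_id and an optional model field, as read in the handler; the conversation ID is made up.

# Hypothetical usage sketch; host, port, and conversation_id are assumptions.
import requests

resp = requests.post(
    "http://localhost:8000/get_files",
    json={"conversation_id": "conv-123", "model": "dolphin-llama3"},
)
print(resp.json()["response"])  # free-form topic list generated by the model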