diff --git a/topos/api/api_routes.py b/topos/api/api_routes.py
index a155d4c..2fb7a8c 100644
--- a/topos/api/api_routes.py
+++ b/topos/api/api_routes.py
@@ -13,9 +13,9 @@ from collections import Counter, OrderedDict, defaultdict
 from pydantic import BaseModel
 
-from ..generations.ollama_chat import generate_response
+from ..generations.chat_gens import LLMChatGens
 from ..utilities.utils import create_conversation_string
-from ..services.ontology_service.mermaid_chart import get_mermaid_chart
+from ..services.ontology_service.mermaid_chart import MermaidCreator
 
 cache_manager = ConversationCacheManager()
 class ConversationIDRequest(BaseModel):
@@ -129,16 +129,23 @@ async def conv_to_image(request: ConversationIDRequest):
     if conv_data is None:
         raise HTTPException(status_code=404, detail="Conversation not found in cache")
-    
+
+    # model specifications
+    # TODO UPDATE SO ITS NOT HARDCODED
     model = "dolphin-llama3"
+    provider = 'ollama' # defaults to ollama right now
+    api_key = 'ollama'
+
+    llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
+
     context = create_conversation_string(conv_data, 6)
     print(context)
     print(f"\t[ converting conversation to image to text prompt: using model {model}]")
     conv_to_text_img_prompt = "Create an interesting, and compelling image-to-text prompt that can be used in a diffussor model. Be concise and convey more with the use of metaphor. Steer the image style towards Slavador Dali's fantastic, atmospheric, heroesque paintings that appeal to everyman themes."
-    txt_to_img_prompt = generate_response(context, conv_to_text_img_prompt, model=model, temperature=0)
+    txt_to_img_prompt = llm_client.generate_response(context, conv_to_text_img_prompt, temperature=0)
     # print(txt_to_img_prompt)
     print(f"\t[ generating a file name {model} ]")
-    txt_to_img_filename = generate_response(txt_to_img_prompt, "Based on the context create an appropriate, and BRIEF, filename with no spaces. Do not use any file extensions in your name, that will be added in a later step.", model=model, temperature=0)
+    txt_to_img_filename = llm_client.generate_response(txt_to_img_prompt, "Based on the context create an appropriate, and BRIEF, filename with no spaces. Do not use any file extensions in your name, that will be added in a later step.", temperature=0)
 
     # run huggingface comic diffusion
     pipeline = DiffusionPipeline.from_pretrained("ogkalu/Comic-Diffusion")
@@ -180,7 +187,15 @@ class GenNextMessageOptions(BaseModel):
 async def create_next_messages(request: GenNextMessageOptions):
     conversation_id = request.conversation_id
     query = request.query
+
+    # model specifications
+    # TODO UPDATE SO ITS NOT HARDCODED
     model = request.model if request.model != None else "dolphin-llama3"
+    provider = 'ollama' # defaults to ollama right now
+    api_key = 'ollama'
+
+    llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
+
     voice_settings = request.voice_settings if request.voice_settings != None else """{
     "tone": "analytical",
     "distance": "distant",
@@ -211,7 +226,7 @@ async def create_next_messages(request: GenNextMessageOptions):
 
     system_prompt += conv_json
 
-    next_message_options = generate_response(system_prompt, query, model=model, temperature=0)
+    next_message_options = llm_client.generate_response(system_prompt, query, temperature=0)
     print(next_message_options)
 
     # return the options
@@ -225,7 +240,13 @@ class ConversationTopicsRequest(BaseModel):
 @router.post("/gen_conversation_topics")
 async def create_next_messages(request: ConversationTopicsRequest):
     conversation_id = request.conversation_id
+    # model specifications
+    # TODO UPDATE SO ITS NOT HARDCODED
     model = request.model if request.model != None else "dolphin-llama3"
+    provider = 'ollama' # defaults to ollama right now
+    api_key = 'ollama'
+
+    llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
 
     # load conversation
     conv_data = cache_manager.load_from_cache(conversation_id)
@@ -233,13 +254,13 @@ async def create_next_messages(request: ConversationTopicsRequest):
         raise HTTPException(status_code=404, detail="Conversation not found in cache")
 
     context = create_conversation_string(conv_data, 12)
-    print(f"\t[ generating summary :: model {model} :: subject {subject}]")
+    # print(f"\t[ generating summary :: model {model} :: subject {subject}]")
 
     query = f""
     # topic list first pass
     system_prompt = "PRESENT CONVERSATION:\n-------" + context + "\n-------\n"
     query += """List the topics and those closely related to what this conversation traverses."""
-    topic_list = generate_response(system_prompt, query, model=model, temperature=0)
+    topic_list = llm_client.generate_response(system_prompt, query, temperature=0)
     print(topic_list)
 
     # return the image
@@ -247,14 +268,37 @@ async def create_next_messages(request: ConversationTopicsRequest):
 
 
 @router.post("/list_models")
-async def list_models():
-    url = "http://localhost:11434/api/tags"
+async def list_models(provider: str = 'ollama', api_key: str = 'ollama'):
+    # Define the URLs for different providers
+
+    list_models_urls = {
+        'ollama': "http://localhost:11434/api/tags",
+        'openai': "https://api.openai.com/v1/models",
+        'groq': "https://api.groq.com/openai/v1/models"
+    }
+
+    if provider not in list_models_urls:
+        raise HTTPException(status_code=400, detail="Unsupported provider")
+
+    # Get the appropriate URL based on the provider
+    url = list_models_urls.get(provider.lower())
+
+    if provider.lower() == 'ollama':
+        # No need for headers with Ollama
+        headers = {}
+    else:
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json"
+        }
+
     try:
-        result = requests.get(url)
+        # Make the request with the appropriate headers
+        result = requests.get(url, headers=headers)
         if result.status_code == 200:
             return {"result": result.json()}
         else:
-            raise HTTPException(status_code=404, detail="Models not found")
+            raise HTTPException(status_code=result.status_code, detail="Models not found")
     except requests.ConnectionError:
         raise HTTPException(status_code=500, detail="Server connection error")
@@ -304,9 +348,18 @@ async def generate_mermaid_chart(payload: MermaidChartPayload):
     try:
         conversation_id = payload.conversation_id
         full_conversation = payload.full_conversation
+        # model specifications
         model = payload.model
+        provider = getattr(payload, 'provider', 'ollama') # defaults to ollama right now
+        api_key = getattr(payload, 'api_key', 'ollama')
         temperature = payload.temperature
 
+        llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
+
+        mermaid_generator = MermaidCreator(llm_client)
+
+
         if full_conversation:
             cache_manager = ConversationCacheManager()
             conv_data = cache_manager.load_from_cache(conversation_id)
@@ -321,7 +374,7 @@ async def generate_mermaid_chart(payload: MermaidChartPayload):
             if message:
                 print(f"\t[ generating mermaid chart :: using model {model} ]")
                 try:
-                    mermaid_string = await get_mermaid_chart(message)
+                    mermaid_string = await mermaid_generator.get_mermaid_chart(message)
                     if mermaid_string == "Failed to generate mermaid":
                         return {"status": "error", "response": mermaid_string, 'completed': True}
                     else:
diff --git a/topos/api/websocket_handlers.py b/topos/api/websocket_handlers.py
index 30cc778..4c714ae 100644
--- a/topos/api/websocket_handlers.py
+++ b/topos/api/websocket_handlers.py
@@ -1,10 +1,10 @@
-from fastapi import APIRouter, WebSocket, WebSocketDisconnect
+from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
 from datetime import datetime
 import time
 import traceback
 import pprint
 
-from ..generations.ollama_chat import stream_chat
+from ..generations.chat_gens import LLMChatGens
 # from topos.FC.semantic_compression import SemanticCompression
 # from ..config import get_openai_api_key
 from ..models.llm_classes import vision_models
@@ -13,7 +13,7 @@ from ..utilities.utils import create_conversation_string
 from ..services.classification_service.base_analysis import base_text_classifier, base_token_classifier
 from ..services.loggers.process_logger import ProcessLogger
 
-from ..services.ontology_service.mermaid_chart import get_mermaid_chart
+from ..services.ontology_service.mermaid_chart import MermaidCreator
 
 # cache database
 from topos.FC.conversation_cache_manager import ConversationCacheManager
@@ -55,11 +55,11 @@ async def chat(websocket: WebSocket):
             chatbot_msg_id = payload["chatbot_msg_id"]
             message = payload["message"]
             message_history = payload["message_history"]
-            model = payload.get("model", "solar")
             temperature = float(payload.get("temperature", 0.04))
             current_topic = payload.get("topic", "Unknown")
             processing_config = payload.get("processing_config", {})
-            
+
+
             # Set default values if any key is missing or if processing_config is None
             default_config = {
                 "showInMessageNER": True,
@@ -68,6 +68,14 @@
                 "calculateModerationTags": True,
                 "showSidebarBaseAnalytics": True
             }
+
+            # model specifications
+            model = payload.get("model", "solar")
+            provider = payload.get('provider', 'ollama') # defaults to ollama right now
+            api_key = payload.get('api_key', 'ollama')
+
+            llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
+
             # Update default_config with provided processing_config, if any
             config = {**default_config, **processing_config}
@@ -171,9 +179,9 @@ async def chat(websocket: WebSocket):
             is_first_token = True
             total_tokens = 0 # Initialize token counter
             ttfs = 0 # init time to first token value
-            await process_logger.start("llm_generation_stream_chat", provider="ollama", model=model, len_msg_hist=len(simp_msg_history))
+            await process_logger.start("llm_generation_stream_chat", provider=provider, model=model, len_msg_hist=len(simp_msg_history))
             start_time = time.time() # Track the start time for the whole process
-            for chunk in stream_chat(simp_msg_history, model=model, temperature=temperature):
+            for chunk in llm_client.stream_chat(simp_msg_history, temperature=temperature):
                 if len(chunk) > 0:
                     if is_first_token:
                         ttfs_end_time = time.time()
@@ -283,6 +291,14 @@ async def meta_chat(websocket: WebSocket):
             temperature = float(payload.get("temperature", 0.04))
             current_topic = payload.get("topic", "Unknown")
 
+
+            # model specifications
+            model = payload.get("model", "solar")
+            provider = payload.get('provider', 'ollama') # defaults to ollama right now
+            api_key = payload.get('api_key', 'ollama')
+
+            llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
+
             # Set system prompt
             system_prompt = f"""You are a highly skilled conversationalist, adept at communicating strategies and tactics. Help the user navigate their current conversation to determine what to say next. You possess a private, unmentioned expertise: PhDs in CBT and DBT, an elegant, smart, provocative speech style, extensive world travel, and deep literary theory knowledge à la Terry Eagleton. Demonstrate your expertise through your guidance, without directly stating it."""
@@ -308,7 +324,7 @@ async def meta_chat(websocket: WebSocket):
 
             # Processing the chat
             output_combined = ""
-            for chunk in stream_chat(simp_msg_history, model=model, temperature=temperature):
+            for chunk in llm_client.stream_chat(simp_msg_history, temperature=temperature):
                 try:
                     output_combined += chunk
                     await websocket.send_json({"status": "generating", "response": output_combined, 'completed': False})
@@ -342,8 +358,15 @@ async def meta_chat(websocket: WebSocket):
             conversation_id = payload["conversation_id"]
             subject = payload.get("subject", "knowledge")
-            model = payload.get("model", "solar")
             temperature = float(payload.get("temperature", 0.04))
+
+            # model specifications
+            model = payload.get("model", "solar")
+            provider = payload.get('provider', 'ollama') # defaults to ollama right now
+            api_key = payload.get('api_key', 'ollama')
+
+            llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
+
             # load conversation
             cache_manager = ConversationCacheManager()
@@ -367,7 +390,7 @@ async def meta_chat(websocket: WebSocket):
 
             # Processing the chat
             output_combined = ""
-            for chunk in stream_chat(msg_history, model=model, temperature=temperature):
+            for chunk in llm_client.stream_chat(msg_history, temperature=temperature):
                 try:
                     output_combined += chunk
                     await websocket.send_json({"status": "generating", "response": output_combined, 'completed': False})
@@ -400,9 +423,15 @@ async def meta_chat(websocket: WebSocket):
             message = payload.get("message", None)
             conversation_id = payload["conversation_id"]
             full_conversation = payload.get("full_conversation", False)
+            # model specifications
             model = payload.get("model", "dolphin-llama3")
+            provider = payload.get('provider', 'ollama') # defaults to ollama right now
+            api_key = payload.get('api_key', 'ollama')
             temperature = float(payload.get("temperature", 0.04))
 
+            llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
+
+            mermaid_generator = MermaidCreator(llm_client)
             # load conversation
             if full_conversation:
                 cache_manager = ConversationCacheManager()
@@ -412,13 +441,13 @@ async def meta_chat(websocket: WebSocket):
                 print(f"\t[ generating mermaid chart :: using model {model} :: full conversation ]")
                 await websocket.send_json({"status": "generating", "response": "generating mermaid chart", 'completed': False})
                 context = create_conversation_string(conv_data, 12)
-                # TODO Coomplete this branch
+                # TODO Complete this branch
             else:
                 if message:
                     print(f"\t[ generating mermaid chart :: using model {model} ]")
                     await websocket.send_json({"status": "generating", "response": "generating mermaid chart", 'completed': False})
                     try:
-                        mermaid_string = await get_mermaid_chart(message, websocket = websocket)
+                        mermaid_string = await mermaid_generator.get_mermaid_chart(message, websocket = websocket)
                         if mermaid_string == "Failed to generate mermaid":
                             await websocket.send_json({"status": "error", "response": mermaid_string, 'completed': True})
                         else:
diff --git a/topos/channel/debatesim.py b/topos/channel/debatesim.py
index 0a70d0e..9cdb0cb 100644
--- a/topos/channel/debatesim.py
+++ b/topos/channel/debatesim.py
@@ -33,7 +33,6 @@ from ..FC.argument_detection import ArgumentDetection
 from ..config import get_openai_api_key
 from ..models.llm_classes import vision_models
 
-from ..generations.ollama_chat import stream_chat
 from ..services.database.app_state import AppState
 from ..utilities.utils import create_conversation_string
 from ..services.classification_service.base_analysis import base_text_classifier, base_token_classifier
diff --git a/topos/channel/experimental/debatesim_experimental_think.py b/topos/channel/experimental/debatesim_experimental_think.py
index 87a5c45..c32eea6 100644
--- a/topos/channel/experimental/debatesim_experimental_think.py
+++ b/topos/channel/experimental/debatesim_experimental_think.py
@@ -29,7 +29,7 @@ from ..FC.argument_detection import ArgumentDetection
 from ..config import get_openai_api_key
 from ..models.llm_classes import vision_models
 
-from ..generations.ollama_chat import stream_chat
+from ..generations.chat_gens import LLMChatGens
 from ..services.database.app_state import AppState
 from ..utilities.utils import create_conversation_string
 from ..services.classification_service.base_analysis import base_text_classifier, base_token_classifier
@@ -338,7 +338,14 @@ async def debate_step(self, websocket: WebSocket, data, app_state):
         user_id = app_state.get_value("user_id", "")
 
         message_history = payload["message_history"]
+
+        # model specifications
         model = payload.get("model", "solar")
+        provider = payload.get('provider', 'ollama') # defaults to ollama right now
+        api_key = payload.get('api_key', 'ollama')
+
+        llm_client = LLMChatGens(model_name=model, provider=provider, api_key=api_key)
+
         temperature = float(payload.get("temperature", 0.04))
         current_topic = payload.get("topic", "Unknown")
@@ -456,7 +463,7 @@ async def debate_step(self, websocket: WebSocket, data, app_state):
 
         # Processing the chat
        output_combined = ""
-        for chunk in stream_chat(simp_msg_history, model=model, temperature=temperature):
+        for chunk in llm_client.stream_chat(simp_msg_history, temperature=temperature):
            output_combined += chunk
            await websocket.send_json({"status": "generating", "response": output_combined, 'completed': False})
diff --git a/topos/generations/chat_gens.py b/topos/generations/chat_gens.py
new file mode 100644
index 0000000..c353021
--- /dev/null
+++ b/topos/generations/chat_gens.py
@@ -0,0 +1,69 @@
+from typing import List, Dict, Generator
+
+from .llm_client import LLMClient
+
+# Assuming OpenAI is a pre-defined client for API interactions
+
+default_models = {
+    "groq": "llama-3.1-70b-versatile",
+    "openai": "gpt-4o",
+    "ollama": "dolphin-llama3"
+    }
+
+class LLMChatGens:
+    def __init__(self, model_name: str, provider: str, api_key: str):
+        self.provider = provider
+        self.api_key = api_key
+        self.client = LLMClient(provider, api_key).get_client()
+        self.model_name = self._init_model(model_name, provider)
+
+    def _init_model(self, model_name: str, provider: str):
+        if len(model_name) > 0:
+            return model_name
+        else:
+            if provider == 'ollama':
+                return default_models['ollama']
+            else:
+                return default_models[provider]
+
+    def stream_chat(self, message_history: List[Dict[str, str]], temperature: float = 0) -> Generator[str, None, None]:
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=message_history,
+                temperature=temperature,
+                stream=True
+            )
+            for chunk in response:
+                if chunk.choices[0].delta.content is not None:
+                    yield chunk.choices[0].delta.content
+        except Exception as e:
+            yield f"Error: {str(e)}"
+
+    def generate_response(self, context: str, prompt: str, temperature: float = 0) -> str:
+        try:
+            messages = [
+                {"role": "system", "content": context},
+                {"role": "user", "content": prompt}
+            ]
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=messages,
+                temperature=temperature,
+                stream=False
+            )
+            return response.choices[0].message.content
+        except Exception as e:
+            return f"Error: {str(e)}"
+
+    def generate_response_messages(self, message_history: List[Dict[str, str]], temperature: float = 0) -> str:
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=message_history,
+                temperature=temperature,
+                stream=False
+            )
+            return response.choices[0].message.content
+        except Exception as e:
+            return f"Error: {str(e)}"
\ No newline at end of file
diff --git a/topos/generations/groq_chat.py b/topos/generations/groq_chat.py
deleted file mode 100644
index 8dd8fa8..0000000
--- a/topos/generations/groq_chat.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# STARTER SETUP FOR GROQ API INTEGRATION
-
-# from openai import OpenAI
-
-# groq_client = OpenAI(api_key=groq_api_key, base_url="https://api.groq.com/openai/v1")
-# groq_model = "llama-3.1-70b-versatile"
\ No newline at end of file
diff --git a/topos/generations/llm_client.py b/topos/generations/llm_client.py
new file mode 100644
index 0000000..dad88dd
--- /dev/null
+++ b/topos/generations/llm_client.py
@@ -0,0 +1,29 @@
+from openai import OpenAI
+
+api_url_dict = {
+    'ollama': 'http://localhost:11434/v1',
+    'openai': None,
+    'groq': 'https://api.groq.com/openai/v1'
+}
+
+class LLMClient:
+    def __init__(self, provider: str, api_key: str):
+        if provider not in api_url_dict:
+            print(f"Unsupported provider: {provider}")
+        self.provider = provider.lower()
+        self.api_key = api_key
+        self.client = self._init_client()
+        print(f"Init client:: {self.provider}")
+
+    def _init_client(self):
+        if self.provider == "openai":
+            return OpenAI(api_key=self.api_key)
+        else:
+            url = api_url_dict[self.provider]
+            return OpenAI(api_key=self.api_key, base_url=url)
+
+    def get_client(self):
+        return self.client
+
+    def get_provider(self):
+        return self.provider
\ No newline at end of file
diff --git a/topos/generations/ollama_chat.py b/topos/generations/ollama_chat.py
deleted file mode 100644
index 41cc3d2..0000000
--- a/topos/generations/ollama_chat.py
+++ /dev/null
@@ -1,37 +0,0 @@
-from openai import OpenAI
-
-
-client = OpenAI(
-    base_url='http://localhost:11434/v1',
-    api_key='ollama', # required, but unused
-)
-
-def stream_chat(message_history, model="solar", temperature=0):
-    response = client.chat.completions.create(
-        model=model,
-        messages=message_history,
-        temperature=temperature,
-        stream=True
-    )
-    for chunk in response:
-        if chunk.choices[0].delta.content is not None:
-            yield chunk.choices[0].delta.content
-
-def generate_response(context, prompt, model="solar", temperature=0):
-    response = client.chat.completions.create(
-        model=model,
-        messages=[
-            {"role": "system", "content": context},
-            {"role": "user", "content": prompt}
-        ],
-        stream=False
-    )
-    return response.choices[0].message.content
-
-def generate_response_messages(message_history, model="solar", temperature=0):
-    response = client.chat.completions.create(
-        model=model,
-        messages=message_history,
-        stream=False
-    )
-    return response.choices[0].message.content
\ No newline at end of file
diff --git a/topos/services/ontology_service/mermaid_chart.py b/topos/services/ontology_service/mermaid_chart.py
index c6da5b6..629aa06 100644
--- a/topos/services/ontology_service/mermaid_chart.py
+++ b/topos/services/ontology_service/mermaid_chart.py
@@ -2,172 +2,175 @@ import re
 
 from topos.FC.ontological_feature_detection import OntologicalFeatureDetection
-from topos.generations.ollama_chat import generate_response, generate_response_messages
+from topos.generations.chat_gens import LLMChatGens
 
+class MermaidCreator:
+    def __init__(self, llmChatGens: LLMChatGens):
+        self.client = llmChatGens
+
+    def get_ontology_old_method(self, message):
+        user_id = "jonny"
+        session_id = "127348901"
+        message_id = "1531ijasda8"
+
+        composable_string = f"for user {user_id}, of {session_id}, the message is: {message}"
-def get_ontology_old_method(message):
-    user_id = "jonny"
-    session_id = "127348901"
-    message_id = "1531ijasda8"
-
-    composable_string = f"for user {user_id}, of {session_id}, the message is: {message}"
 
+        ontological_feature_detection = OntologicalFeatureDetection("neo4j_uri", "neo4j_user", "neo4j_password",
+                                                                    "showroom_db_name", False)
+
+        entities, pos_tags, dependencies, relations, srl_results, timestamp, context_entities = ontological_feature_detection.build_ontology_from_paragraph(
+            user_id, session_id, message_id, composable_string)
-    ontological_feature_detection = OntologicalFeatureDetection("neo4j_uri", "neo4j_user", "neo4j_password",
-                                                                "showroom_db_name", False)
-
-    entities, pos_tags, dependencies, relations, srl_results, timestamp, context_entities = ontological_feature_detection.build_ontology_from_paragraph(
-        user_id, session_id, message_id, composable_string)
 
+        input_components = message, entities, dependencies, relations, srl_results, timestamp, context_entities
-    input_components = message, entities, dependencies, relations, srl_results, timestamp, context_entities
 
+        mermaid_syntax = ontological_feature_detection.extract_mermaid_syntax(input_components, input_type="components")
+        mermaid_to_ascii = ontological_feature_detection.mermaid_to_ascii(mermaid_syntax)
+        return mermaid_to_ascii
-    mermaid_syntax = ontological_feature_detection.extract_mermaid_syntax(input_components, input_type="components")
-    mermaid_to_ascii = ontological_feature_detection.mermaid_to_ascii(mermaid_syntax)
-    return mermaid_to_ascii
 
+    def extract_mermaid_chart(self, response):
+        mermaid_code = re.search(r'```mermaid\n(.*?)```', response, re.DOTALL)
+        if mermaid_code:
+            return mermaid_code.group(0)
+
+        # Check for the variation with an extra newline character
+        mermaid_code_variation = re.search(r'```\nmermaid\n(.*?)```', response, re.DOTALL)
+        if mermaid_code_variation:
+            print("\t[ reformatting mermaid chart ]")
+            # Fix the variation by placing the mermaid text right after the ```
+            fixed_mermaid_code = "```mermaid\n" + mermaid_code_variation.group(1) + "\n```"
+            return fixed_mermaid_code
+        return None
-def extract_mermaid_chart(response):
-    mermaid_code = re.search(r'```mermaid\n(.*?)```', response, re.DOTALL)
-    if mermaid_code:
-        return mermaid_code.group(0)
-
-    # Check for the variation with an extra newline character
-    mermaid_code_variation = re.search(r'```\nmermaid\n(.*?)```', response, re.DOTALL)
-    if mermaid_code_variation:
-        print("\t[ reformatting mermaid chart ]")
-        # Fix the variation by placing the mermaid text right after the ```
-        fixed_mermaid_code = "```mermaid\n" + mermaid_code_variation.group(1) + "\n```"
-        return fixed_mermaid_code
-    return None
 
+    def refine_mermaid_lines(self, mermaid_chart):
+        lines = mermaid_chart.split('\n')
+        refined_lines = []
+        for line in lines:
+            if '-->' in line:
+                parts = line.split('-->')
+                parts = [part.strip().replace(' ', '_') for part in parts]
+                refined_line = ' --> '.join(parts)
+                refined_lines.append(" " + refined_line) # add the indent to the start of the line
+            else:
+                refined_lines.append(line)
+        return '\n'.join(refined_lines)
-def refine_mermaid_lines(mermaid_chart):
-    lines = mermaid_chart.split('\n')
-    refined_lines = []
-    for line in lines:
-        if '-->' in line:
-            parts = line.split('-->')
-            parts = [part.strip().replace(' ', '_') for part in parts]
-            refined_line = ' --> '.join(parts)
-            refined_lines.append(" " + refined_line) # add the indent to the start of the line
-        else:
-            refined_lines.append(line)
-    return '\n'.join(refined_lines)
 
+    async def get_mermaid_chart(self, message, websocket = None):
+        """
+        Input: String Message
+        Output: mermaid chart
+        ``` mermaid
+        graph TD
+            Texas -->|is| hot
+            hot -->|is| uncomfortable
+            hot -->|is| unwanted
+            Texas -->|actions| options
+            options -->|best| Go_INSIDE
+            options -->|second| Go_to_Canada
+            options -->|last| try_not_to_die
+        ```"""
+
+        system_role = "Our goal is to help a visual learner better comprehend a sentence, by illustrating the text in a graph form. Your job is to create a list of graph triples from the speaker's sentence.\n"
+        system_directive = """RULES:
+        1. Extract graph triples from the sentence.
+        2. Use very simple synonyms to decrease the nuance in the statement.
+        3. Stay true to the sentence, make inferences about the sentiment, intent, if it is reasonable to do so.
+        4. Use natural language to create the triples.
+        5. Write only the comma separated triples format that follow node, relationship, node pattern
+        6. If the statement is an opinion, create a relationship that assigns the speaker has_preference
+        6. DO NOT HAVE ANY ISLAND RELATIONSHIPS. ALL EDGES MUST CONNECT."""
+        system_examples = """```
+        INPUT SENTENCE: The Texas heat is OPPRESSIVE
+        OUTPUT:
+        Texas, is, hot
+        hot, is, uncomfortable
+        hot, is, unwanted
+        ---
+        SENTENCE: "Isn't Italy a better country than Spain?"
+        OUTPUT:
+        Italy, is_a, country
+        Spain, is_a, country
+        Italy, better, Spain
+        better, property, comparison
+        speaker, has_preference, Italy
+        ```"""
+        prompt = f"For the sake of illumination, represent this speaker's sentence in triples: {message}"
+        system_ctx = system_role + system_directive + system_examples
+        print("\t[ generating sentence_abstractive_graph_triples ]")
+        if websocket:
+            await websocket.send_json({"status": "generating", "response": "generating sentence_abstractive_graph_triples", 'completed': False})
+        sentence_abstractive_graph_triples = self.client.generate_response(system_ctx, prompt)
+        # print(sentence_abstractive_graph_triples)
+
+        prompt = f"We were just given us the above triples to represent this message: '{message}'. Improve and correct their triples in a plaintext codeblock."
+        print("\t[ generating refined_abstractive_graph_triples ]")
+        if websocket:
+            await websocket.send_json({"status": "generating", "response": "generating refined_abstractive_graph_triples", 'completed': False})
+        refined_abstractive_graph_triples = self.client.generate_response(sentence_abstractive_graph_triples, prompt) # a second pass to refine the first generation's responses
+        # what is being said,
+
+        # add relations to this existing graph that offer actions that can be taken, be humorous and absurd
+
+        # output these graph relations into a mermaid chart we can use in markdown. Follow this form
+        system_ctx = f"""Generate a mermaid block based off the triples.
+        It should look like this:
+        Example 1:
+        ```mermaid
+        graph TD;
+            Italy--> |is_a| country;
+            Spain--> |is_a| country;
+            Italy--> better-->Spain;
+            better-->property-->comparison;
+            speaker-->has_preference-->Italy;
+        ```
+        Example 2:
+        ```mermaid
+        graph TD;
+            High_School-->duration_of_study-->10_Years;
+            High_School-->compared_to-->4_Year_Program;
+            10_Year_Program-->more_time-->4_Years;
+            Speaker-->seeks_change-->High_School_Length;
+        ```
+        Rules:
+        1. No spaces between entities!
+        """
+        prompt =f"""Create a mermaid chart from these triples: {refined_abstractive_graph_triples}. Reduce the noise and combine elements if they are referencing the same thing.
+        Since United_States and The_United_States are the same thing, make the output just use: United_States.
+        Example:
+        Input triples
+        ```
+        United_States, has_spending, too_much
+        The_United_States, could_do_with, less_spending;
+        too_much, is, undesirable;
+        ```
-async def get_mermaid_chart(message, websocket = None):
-    """
-    Input: String Message
-    Output: mermaid chart
-    ``` mermaid
-    graph TD
-        Texas -->|is| hot
-        hot -->|is| uncomfortable
-        hot -->|is| unwanted
-        Texas -->|actions| options
-        options -->|best| Go_INSIDE
-        options -->|second| Go_to_Canada
-        options -->|last| try_not_to_die
-    ```"""
-
-    system_role = "Our goal is to help a visual learner better comprehend a sentence, by illustrating the text in a graph form. Your job is to create a list of graph triples from the speaker's sentence.\n"
-    system_directive = """RULES:
-    1. Extract graph triples from the sentence.
-    2. Use very simple synonyms to decrease the nuance in the statement.
-    3. Stay true to the sentence, make inferences about the sentiment, intent, if it is reasonable to do so.
-    4. Use natural language to create the triples.
-    5. Write only the comma separated triples format that follow node, relationship, node pattern
-    6. If the statement is an opinion, create a relationship that assigns the speaker has_preference
-    6. DO NOT HAVE ANY ISLAND RELATIONSHIPS. ALL EDGES MUST CONNECT."""
-    system_examples = """```
-    INPUT SENTENCE: The Texas heat is OPPRESSIVE
-    OUTPUT:
-    Texas, is, hot
-    hot, is, uncomfortable
-    hot, is, unwanted
-    ---
-    SENTENCE: "Isn't Italy a better country than Spain?"
-    OUTPUT:
-    Italy, is_a, country
-    Spain, is_a, country
-    Italy, better, Spain
-    better, property, comparison
-    speaker, has_preference, Italy
-    ```"""
-    prompt = f"For the sake of illumination, represent this speaker's sentence in triples: {message}"
-    system_ctx = system_role + system_directive + system_examples
-    print("\t[ generating sentence_abstractive_graph_triples ]")
-    if websocket:
-        await websocket.send_json({"status": "generating", "response": "generating sentence_abstractive_graph_triples", 'completed': False})
-    sentence_abstractive_graph_triples = generate_response(system_ctx, prompt, model='dolphin-llama3')
-    # print(sentence_abstractive_graph_triples)
-
-    prompt = f"We were just given us the above triples to represent this message: '{message}'. Improve and correct their triples in a plaintext codeblock."
-    print("\t[ generating refined_abstractive_graph_triples ]")
-    if websocket:
-        await websocket.send_json({"status": "generating", "response": "generating refined_abstractive_graph_triples", 'completed': False})
-    refined_abstractive_graph_triples = generate_response(sentence_abstractive_graph_triples, prompt, model='dolphin-llama3') # a second pass to refine the first generation's responses
-    # what is being said,
-
-    # add relations to this existing graph that offer actions that can be taken, be humorous and absurd
-
-    # output these graph relations into a mermaid chart we can use in markdown. Follow this form
-    system_ctx = f"""Generate a mermaid block based off the triples.
-    It should look like this:
-Example 1:
+        Output Mermaid Where you Substitute The_United_States, with United_States.
 ```mermaid
-graph TD;
-    Italy--> |is_a| country;
-    Spain--> |is_a| country;
-    Italy--> better-->Spain;
-    better-->property-->comparison;
-    speaker-->has_preference-->Italy;
-```
-Example 2:
-```mermaid
-graph TD;
-    High_School-->duration_of_study-->10_Years;
-    High_School-->compared_to-->4_Year_Program;
-    10_Year_Program-->more_time-->4_Years;
-    Speaker-->seeks_change-->High_School_Length;
-```
-Rules:
-1. No spaces between entities!
+    graph TD;
+    United_States --> |has_spending| too_much;
+    United_States --> |could_do_with| less_spending;
+    too_much --> |is| undesirable;
+    ```
 """
-    prompt =f"""Create a mermaid chart from these triples: {refined_abstractive_graph_triples}. Reduce the noise and combine elements if they are referencing the same thing.
-Since United_States and The_United_States are the same thing, make the output just use: United_States.
-Example:
-Input triples
-```
-United_States, has_spending, too_much
-The_United_States, could_do_with, less_spending;
-too_much, is, undesirable;
-```
-
-Output Mermaid Where you Substitute The_United_States, with United_States.
-```mermaid
-graph TD;
-    United_States --> |has_spending| too_much;
-    United_States --> |could_do_with| less_spending;
-    too_much --> |is| undesirable;
-```
-"""
-    attempt = 0
-    message_history = [{"role": "system", "content": system_ctx}, {"role": "user", "content": prompt}]
-    while attempt < 3:
-        if attempt == 0:
-            print("\t\t[ generating mermaid chart ]")
-            if websocket:
-                await websocket.send_json({"status": "generating", "response": "generating mermaid_chart_from_triples", 'completed': False})
-        else:
-            print(f"\t\t[ generating mermaid chart :: try {attempt + 1}]")
+        attempt = 0
+        message_history = [{"role": "system", "content": system_ctx}, {"role": "user", "content": prompt}]
+        while attempt < 3:
+            if attempt == 0:
+                print("\t\t[ generating mermaid chart ]")
                 if websocket:
-                await websocket.send_json({"status": "generating", "response": f"generating mermaid_chart_from_triples :: try {attempt + 1}", 'completed': False})
-        response = generate_response_messages(message_history, model='dolphin-llama3')
-        mermaid_chart = extract_mermaid_chart(response)
-        if mermaid_chart:
-            # refined_mermaid_chart = refine_mermaid_lines(mermaid_chart)
-            return mermaid_chart
-        # print("FAILED:\n", response)
-        message_history.append({"role": "user", "content": "That wasn't correct. State why and do it better."})
-        attempt += 1
-    return "Failed to generate mermaid"
+                    await websocket.send_json({"status": "generating", "response": "generating mermaid_chart_from_triples", 'completed': False})
+            else:
+                print(f"\t\t[ generating mermaid chart :: try {attempt + 1}]")
+                if websocket:
+                    await websocket.send_json({"status": "generating", "response": f"generating mermaid_chart_from_triples :: try {attempt + 1}", 'completed': False})
+            response = self.client.generate_response_messages(message_history)
+            mermaid_chart = self.extract_mermaid_chart(response)
+            if mermaid_chart:
+                # refined_mermaid_chart = refine_mermaid_lines(mermaid_chart)
+                return mermaid_chart
+            # print("FAILED:\n", response)
+            message_history.append({"role": "user", "content": "That wasn't correct. State why and do it better."})
+            attempt += 1
+        return "Failed to generate mermaid"
 
 # message = "Why can't we go to High School for 10 years instead of 4!!!"
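
A minimal usage sketch of the classes this diff introduces (`LLMChatGens`, `LLMClient`, `MermaidCreator`). It assumes a local Ollama server at `http://localhost:11434` with the `dolphin-llama3` model pulled; the provider and api_key values mirror the hard-coded defaults in the handlers above and are illustrative, not part of the change itself.

```python
import asyncio

from topos.generations.chat_gens import LLMChatGens
from topos.services.ontology_service.mermaid_chart import MermaidCreator

# Provider-aware client; 'ollama'/'openai'/'groq' are the providers wired up in llm_client.py.
llm_client = LLMChatGens(model_name="dolphin-llama3", provider="ollama", api_key="ollama")

# One-shot completion: system context + user prompt.
print(llm_client.generate_response("You are a terse assistant.", "Say hello in five words.", temperature=0))

# Streaming completion over an OpenAI-style message history.
history = [{"role": "user", "content": "Name three graph databases."}]
for chunk in llm_client.stream_chat(history, temperature=0):
    print(chunk, end="", flush=True)

# Mermaid generation now goes through MermaidCreator, which wraps the same client.
mermaid_generator = MermaidCreator(llm_client)
mermaid_string = asyncio.run(mermaid_generator.get_mermaid_chart("The Texas heat is oppressive."))
print(mermaid_string)
```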