diff --git a/topos/api/websocket_handlers.py b/topos/api/websocket_handlers.py index a10beb4..305999e 100644 --- a/topos/api/websocket_handlers.py +++ b/topos/api/websocket_handlers.py @@ -44,7 +44,7 @@ async def end_ws_process(websocket, websocket_process, process_logger, send_json async def chat(websocket: WebSocket): await websocket.accept() process_logger = ProcessLogger(verbose=False, run_logger=False) - websocket_process = "Processing /websocket_chat" + websocket_process = "/websocket_chat" await process_logger.start(websocket_process) try: while True: @@ -102,9 +102,10 @@ async def chat(websocket: WebSocket): last_message = simp_msg_history[-1]['content'] role = simp_msg_history[-1]['role'] + num_user_toks = len(last_message.split()) # Fetch base, per-message token classifiers if config['calculateInMessageNER']: - await process_logger.start("calculateInMessageNER-user") + await process_logger.start("calculateInMessageNER-user", num_toks=num_user_toks) start_time = time.time() base_analysis = base_token_classifier(last_message) # this is only an ner dict atm duration = time.time() - start_time @@ -114,7 +115,7 @@ async def chat(websocket: WebSocket): # Fetch base, per-message text classifiers # Start timer for base_text_classifier if config['calculateModerationTags']: - await process_logger.start("calculateModerationTags-user") + await process_logger.start("calculateModerationTags-user", num_toks=num_user_toks) start_time = time.time() text_classifiers = {} try: @@ -169,13 +170,14 @@ async def chat(websocket: WebSocket): output_combined = "" is_first_token = True total_tokens = 0 # Initialize token counter + ttfs = 0 # init time to first token value + await process_logger.start("llm_generation_stream_chat", provider="ollama", model=model, len_msg_hist=len(simp_msg_history)) start_time = time.time() # Track the start time for the whole process - await process_logger.start("Retrieving LLM Generation", provider="ollama", model=model, len_msg_hist=len(simp_msg_history)) - await process_logger.start("Time To First Token") for chunk in stream_chat(simp_msg_history, model=model, temperature=temperature): if len(chunk) > 0: if is_first_token: - await process_logger.end("Time To First Token") + ttfs_end_time = time.time() + ttfs = ttfs_end_time - start_time is_first_token = False output_combined += chunk total_tokens += len(chunk.split()) @@ -185,14 +187,18 @@ async def chat(websocket: WebSocket): # Calculate tokens per second if elapsed_time > 0: tokens_per_second = total_tokens / elapsed_time - await process_logger.end("Retrieving LLM Generation", toks_per_sec=f"{tokens_per_second:.1f}") + ttl_num_toks = 0 + for i in simp_msg_history: + ttl_num_toks += len(i['content'].split()) + await process_logger.end("llm_generation_stream_chat", toks_per_sec=f"{tokens_per_second:.1f}", ttfs=f"{ttfs}", num_toks=num_user_toks, ttl_num_toks=ttl_num_toks) # Fetch semantic category from the output # semantic_compression = SemanticCompression(model=f"ollama:{model}", api_key=get_openai_api_key()) # semantic_category = semantic_compression.fetch_semantic_category(output_combined) + num_response_toks=len(output_combined.split()) # Start timer for base_token_classifier if config['calculateInMessageNER']: - await process_logger.start("calculateInMessageNER-ChatBot") + await process_logger.start("calculateInMessageNER-ChatBot", num_toks=num_response_toks) start_time = time.time() base_analysis = base_token_classifier(output_combined) duration = time.time() - start_time @@ -201,7 +207,7 @@ async def chat(websocket: WebSocket): # Start timer for base_text_classifier if config['calculateModerationTags']: - await process_logger.start("calculateModerationTags-ChatBot") + await process_logger.start("calculateModerationTags-ChatBot", num_toks=num_response_toks) start_time = time.time() text_classifiers = base_text_classifier(output_combined) duration = time.time() - start_time