Skip to content

Commit

Permalink
+ data to process logger for /websocket_chat
Browse files Browse the repository at this point in the history
  • Loading branch information
jonnyjohnson1 committed Aug 12, 2024
1 parent 397ec99 commit 9a19c36
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions topos/api/websocket_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def end_ws_process(websocket, websocket_process, process_logger, send_json
async def chat(websocket: WebSocket):
await websocket.accept()
process_logger = ProcessLogger(verbose=False, run_logger=False)
websocket_process = "Processing /websocket_chat"
websocket_process = "/websocket_chat"
await process_logger.start(websocket_process)
try:
while True:
Expand Down Expand Up @@ -102,9 +102,10 @@ async def chat(websocket: WebSocket):

last_message = simp_msg_history[-1]['content']
role = simp_msg_history[-1]['role']
num_user_toks = len(last_message.split())
# Fetch base, per-message token classifiers
if config['calculateInMessageNER']:
await process_logger.start("calculateInMessageNER-user")
await process_logger.start("calculateInMessageNER-user", num_toks=num_user_toks)
start_time = time.time()
base_analysis = base_token_classifier(last_message) # this is only an ner dict atm
duration = time.time() - start_time
Expand All @@ -114,7 +115,7 @@ async def chat(websocket: WebSocket):
# Fetch base, per-message text classifiers
# Start timer for base_text_classifier
if config['calculateModerationTags']:
await process_logger.start("calculateModerationTags-user")
await process_logger.start("calculateModerationTags-user", num_toks=num_user_toks)
start_time = time.time()
text_classifiers = {}
try:
Expand Down Expand Up @@ -169,13 +170,14 @@ async def chat(websocket: WebSocket):
output_combined = ""
is_first_token = True
total_tokens = 0 # Initialize token counter
ttfs = 0 # init time to first token value
await process_logger.start("llm_generation_stream_chat", provider="ollama", model=model, len_msg_hist=len(simp_msg_history))
start_time = time.time() # Track the start time for the whole process
await process_logger.start("Retrieving LLM Generation", provider="ollama", model=model, len_msg_hist=len(simp_msg_history))
await process_logger.start("Time To First Token")
for chunk in stream_chat(simp_msg_history, model=model, temperature=temperature):
if len(chunk) > 0:
if is_first_token:
await process_logger.end("Time To First Token")
ttfs_end_time = time.time()
ttfs = ttfs_end_time - start_time
is_first_token = False
output_combined += chunk
total_tokens += len(chunk.split())
Expand All @@ -185,14 +187,18 @@ async def chat(websocket: WebSocket):
# Calculate tokens per second
if elapsed_time > 0:
tokens_per_second = total_tokens / elapsed_time
await process_logger.end("Retrieving LLM Generation", toks_per_sec=f"{tokens_per_second:.1f}")
ttl_num_toks = 0
for i in simp_msg_history:
ttl_num_toks += len(i['content'].split())
await process_logger.end("llm_generation_stream_chat", toks_per_sec=f"{tokens_per_second:.1f}", ttfs=f"{ttfs}", num_toks=num_user_toks, ttl_num_toks=ttl_num_toks)
# Fetch semantic category from the output
# semantic_compression = SemanticCompression(model=f"ollama:{model}", api_key=get_openai_api_key())
# semantic_category = semantic_compression.fetch_semantic_category(output_combined)

num_response_toks=len(output_combined.split())
# Start timer for base_token_classifier
if config['calculateInMessageNER']:
await process_logger.start("calculateInMessageNER-ChatBot")
await process_logger.start("calculateInMessageNER-ChatBot", num_toks=num_response_toks)
start_time = time.time()
base_analysis = base_token_classifier(output_combined)
duration = time.time() - start_time
Expand All @@ -201,7 +207,7 @@ async def chat(websocket: WebSocket):

# Start timer for base_text_classifier
if config['calculateModerationTags']:
await process_logger.start("calculateModerationTags-ChatBot")
await process_logger.start("calculateModerationTags-ChatBot", num_toks=num_response_toks)
start_time = time.time()
text_classifiers = base_text_classifier(output_combined)
duration = time.time() - start_time
Expand Down

0 comments on commit 9a19c36

Please sign in to comment.