From ea14743acf55b062421b36dc239d7c93fe98dd9b Mon Sep 17 00:00:00 2001
From: Kirill Starkov
Date: Tue, 18 Feb 2025 18:56:40 +0800
Subject: [PATCH] add compression for plain messages

---
 .../scratchpads/chat_utils_limit_history.rs   | 76 +++++++++++++++++--
 1 file changed, 69 insertions(+), 7 deletions(-)

diff --git a/refact-agent/engine/src/scratchpads/chat_utils_limit_history.rs b/refact-agent/engine/src/scratchpads/chat_utils_limit_history.rs
index 8265f4146..4fb55e9b6 100644
--- a/refact-agent/engine/src/scratchpads/chat_utils_limit_history.rs
+++ b/refact-agent/engine/src/scratchpads/chat_utils_limit_history.rs
@@ -1,7 +1,44 @@
+use std::cmp::min;
 use crate::scratchpad_abstract::HasTokenizerAndEot;
-use crate::call_validation::ChatMessage;
+use crate::call_validation::{ChatContent, ChatMessage};
 use std::collections::HashSet;
+use std::sync::{Arc, RwLock};
+use tokenizers::Tokenizer;
+use crate::scratchpads::multimodality::MultimodalElement;
 
+const MESSAGE_TOKEN_LIMIT: i32 = 12_000;
+
+fn compress_string(text: &String, tokenizer: Arc<RwLock<Tokenizer>>) -> Result<String, String> {
+    let tokenizer_lock = tokenizer.read().unwrap();
+    let tokens = tokenizer_lock.encode(&**text, false).map_err(|e| e.to_string())?;
+    let first_tokens = &tokens.get_ids()[0 .. (MESSAGE_TOKEN_LIMIT / 2) as usize];
+    let last_tokens = &tokens.get_ids()[tokens.len() - (MESSAGE_TOKEN_LIMIT / 2) as usize ..];
+    let mut text = tokenizer_lock.decode(first_tokens, false).map_err(|e| e.to_string())?;
+    text.push_str("\n...\n");
+    text.push_str(&tokenizer_lock.decode(last_tokens, false).map_err(|e| e.to_string())?);
+    Ok(text)
+}
+
+fn compress_message(msg: &ChatMessage, tokenizer: Arc<RwLock<Tokenizer>>) -> Result<ChatMessage, String> {
+    let mut message = msg.clone();
+    match message.content.clone() {
+        ChatContent::SimpleText(simple_text) => {
+            message.content = ChatContent::SimpleText(compress_string(&simple_text, tokenizer.clone())?);
+        }
+        ChatContent::Multimodal(elements) => {
+            let mut new_elements: Vec<MultimodalElement> = vec![];
+            for element in elements {
+                if element.is_text() {
+                    new_elements.push(MultimodalElement::new("text".to_string(), compress_string(&element.m_content, tokenizer.clone())?)?);
+                } else {
+                    new_elements.push(element.clone());
+                }
+            }
+            message.content = ChatContent::Multimodal(new_elements);
+        }
+    };
+    Ok(message)
+}
 
 pub fn limit_messages_history(
     t: &HasTokenizerAndEot,
@@ -16,6 +53,9 @@
     let mut tokens_used: i32 = 0;
     let mut message_token_count: Vec<i32> = vec![0; messages.len()];
     let mut message_take: Vec<bool> = vec![false; messages.len()];
+    let mut message_can_be_compressed: Vec<bool> = vec![false; messages.len()];
+    let message_roles: Vec<String> = messages.iter().map(|x| x.role.clone()).collect();
+
     for (i, msg) in messages.iter().enumerate() {
         let tcnt = 3 + msg.content.count_tokens(t.tokenizer.clone(), &None)?;
         message_token_count[i] = tcnt;
@@ -23,21 +63,33 @@
             message_take[i] = true;
             tokens_used += tcnt;
         } else if i==1 && msg.role == "user" {
-            // we cannot drop the user message which comes right after the system message according to Antropic API
+            // we cannot drop the user message which comes right after the system message according to Anthropic API
             message_take[i] = true;
-            tokens_used += tcnt;
+            tokens_used += min(tcnt, MESSAGE_TOKEN_LIMIT + 3);
         } else if i >= last_user_msg_starts {
             message_take[i] = true;
-            tokens_used += tcnt;
+            tokens_used += min(tcnt, MESSAGE_TOKEN_LIMIT + 3);
         }
     }
 
+    // Keep the latest assistant message, its tool calls, and the user messages between them uncompressed:
+    // they may contain patch tool calls.
+    for i in (0..message_roles.len()).rev() {
+        if message_roles[i] == "user" {
+            message_can_be_compressed[i] = true;
+        }
+    }
+
     let mut log_buffer = Vec::new();
     let mut dropped = false;
     for i in (0..messages.len()).rev() {
         let tcnt = 3 + message_token_count[i];
         if !message_take[i] {
-            if tokens_used + tcnt < tokens_limit {
+            if message_can_be_compressed[i] && tcnt > MESSAGE_TOKEN_LIMIT + 3 && tokens_used + MESSAGE_TOKEN_LIMIT + 3 < tokens_limit {
+                message_take[i] = true;
+                tokens_used += MESSAGE_TOKEN_LIMIT + 3;
+                log_buffer.push(format!("take compressed {:?}, tokens_used={} < {}", crate::nicer_logs::first_n_chars(&messages[i].content.content_text_only(), 30), tokens_used, tokens_limit));
+            } else if tokens_used + tcnt < tokens_limit {
                 message_take[i] = true;
                 tokens_used += tcnt;
                 log_buffer.push(format!("take {:?}, tokens_used={} < {}", crate::nicer_logs::first_n_chars(&messages[i].content.content_text_only(), 30), tokens_used, tokens_limit));
@@ -46,6 +98,7 @@
                 dropped = true;
                 break;
             }
+
         } else {
             message_take[i] = true;
             log_buffer.push(format!("not allowed to drop {:?}, tokens_used={} < {}", crate::nicer_logs::first_n_chars(&messages[i].content.content_text_only(), 30), tokens_used, tokens_limit));
@@ -77,7 +130,16 @@
             tracing::info!("drop {:?} because of drop tool result rule", crate::nicer_logs::first_n_chars(&messages[i].content.content_text_only(), 30));
         }
     }
-
-    let messages_out: Vec<ChatMessage> = messages.iter().enumerate().filter(|(i, _)| message_take[*i]).map(|(_, x)| x.clone()).collect();
+    let mut messages_out: Vec<ChatMessage> = Vec::new();
+    for i in 0..messages.len() {
+        if message_take[i] {
+            if message_can_be_compressed[i] && message_token_count[i] > MESSAGE_TOKEN_LIMIT {
+                messages_out.push(compress_message(&messages[i], t.tokenizer.clone())?);
+            } else {
+                messages_out.push(messages[i].clone());
+            }
+        }
+    }
+
     Ok(messages_out)
 }
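Reviewer note: the standalone sketch below isolates the head/tail truncation
idea behind compress_string above. To stay compilable without the tokenizers
crate, it truncates on a character budget instead of token ids; the name
truncate_middle and the limit of 20 are illustrative only, not part of the
patch.

    // Sketch of the middle-truncation performed by compress_string:
    // keep the first and last `limit / 2` units and join them with "\n...\n".
    fn truncate_middle(text: &str, limit: usize) -> String {
        let chars: Vec<char> = text.chars().collect();
        if chars.len() <= limit {
            return text.to_string(); // already within budget, nothing to cut
        }
        let half = limit / 2;
        let head: String = chars[..half].iter().collect();
        let tail: String = chars[chars.len() - half..].iter().collect();
        // same "\n...\n" separator as compress_string in the patch
        format!("{}\n...\n{}", head, tail)
    }

    fn main() {
        let long = format!("{}MIDDLE{}", "x".repeat(50), "y".repeat(50));
        // prints ten x's, the "..." separator, then ten y's
        println!("{}", truncate_middle(&long, 20));
    }

Running it drops the middle of the string and keeps half the budget from each
end, which mirrors how compress_string retains MESSAGE_TOKEN_LIMIT / 2 tokens
from the head and tail of an oversized message.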