From 694dcd98a348110aef32e69253d16b92278206c7 Mon Sep 17 00:00:00 2001 From: MasterYi Date: Tue, 26 Mar 2024 09:46:55 +0800 Subject: [PATCH] server: fix system_tokens being erased in kv_cache; --- examples/server/server.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c4c545c3e0ac4..06afafdc79fa6 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1695,7 +1695,7 @@ struct server_context { if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) { // Shift context const int n_keep = slot.params.n_keep + add_bos_token; - const int n_left = (int) system_tokens.size() + slot.n_past - n_keep; + const int n_left = slot.n_past - n_keep; const int n_discard = n_left / 2; LOG_INFO("slot context shift", { @@ -1710,8 +1710,8 @@ struct server_context { {"n_cache_tokens", slot.cache_tokens.size()} }); - llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); + llama_kv_cache_seq_rm (ctx, slot.id + 1, system_tokens.size() + n_keep , system_tokens.size() + n_keep + n_discard); + llama_kv_cache_seq_add(ctx, slot.id + 1, system_tokens.size() + n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); if (slot.params.cache_prompt) { for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { @@ -1853,10 +1853,10 @@ struct server_context { // if input prompt is too big, truncate it (if group attention self-extend is disabled) if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) { - const int n_left = slot.n_ctx - slot.params.n_keep; + const int n_left = slot.n_ctx - slot.params.n_keep - system_tokens.size(); const int n_block_size = n_left / 2; - const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; + const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - system_tokens.size() - n_block_size) / n_block_size; std::vector new_tokens( prompt_tokens.begin(), @@ -1864,7 +1864,7 @@ struct server_context { new_tokens.insert( new_tokens.end(), - prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, + prompt_tokens.begin() + slot.params.n_keep + system_tokens.size() + erased_blocks * n_block_size, prompt_tokens.end()); prompt_tokens = std::move(new_tokens);