
server : separate the notion of position and KV tokens, remove prompt truncation #13576


Open
wants to merge 6 commits into master

Conversation

@ngxson (Collaborator) commented May 15, 2025

The motivation for this change is a bug we currently have: the context usage of Qwen VL models (which use M-RoPE) is tracked incorrectly.

This is because for Qwen VL, an image can occupy multiple KV entries (let's avoid the term "tokens" here so it's less confusing), but the image occupies only one position. For example:

# number of KV entries: 11
# but n_past is 9

token ID: 0 26 73 235 466 151 <embd> <embd> <embd> 4 463
position: 0 1  2  3   4   5   6      6      6      7 8
index:    0 1  2  3   4   5   6      7      8      9 10
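
To make the table concrete, here is a minimal sketch with hypothetical types (not the actual server code) that reproduces the two counts: an image chunk occupies several KV cells but advances the position by only one.

#include <cstdint>
#include <vector>

struct chunk {
    bool    is_image;
    int32_t n_kv; // KV cells this chunk occupies (1 for a text token)
};

int main() {
    // 6 text tokens, one image expanded to 3 <embd> entries, 2 text tokens
    std::vector<chunk> chunks = {
        {false, 1}, {false, 1}, {false, 1}, {false, 1}, {false, 1}, {false, 1},
        {true,  3}, // the image: 3 KV entries, 1 position
        {false, 1}, {false, 1},
    };

    int32_t n_kv_tokens = 0; // total KV cells used
    int32_t n_past      = 0; // total positions consumed
    for (const auto & c : chunks) {
        n_kv_tokens += c.n_kv;
        n_past      += 1;
    }
    // n_kv_tokens == 11, n_past == 9, matching the table above
    return 0;
}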

My idea was to simply separate the notions of n_past and n_kv_tokens. But it turns out to be a bit more complicated than I initially thought.

To make it a bit easier, I ended up removing the prompt-truncation logic. I remember someone added this feature with the argument that processing the prompt only to have the ctx shift kick in and remove processed tokens is wasteful. However, I don't agree with that:

  1. The computation is not wasteful: when the ctx shift kicks in, the knowledge about the discarded tokens is not gone; it has already been "dissolved" into the KV cache by the attention mechanism. This is basically the idea behind sliding window attention in Gemma or Phi-3.
  2. The prompt-truncation implementation is quite ugly, as it requires two separate variables, n_prompt_tokens and n_prompt_tokens_processed.
  3. The prompt-truncation logic makes for a bad UX overall. At least for me, when using the server I have no idea whether it will discard the beginning of my prompt or not (especially when I process a long input).

Other than that, I'm 90% sure the other things are "migrated" correctly to n_kv_tokens; but just in case, @ggerganov please take your time to review this. Thanks!!


Note: if this PR gets merged, we must also add a breaking-change entry in #9291

@ngxson requested a review from ggerganov May 15, 2025 22:08
@github-actions bot added the examples, python (python script changes), and server labels May 15, 2025
@ngxson added the breaking change (Changes that break ABIs, APIs, file formats, or other forms of backwards compatibility) label May 15, 2025
@steampunque

Why not just keep track of the KV image delta in one variable? It looks incredibly inefficient to scan through the whole prompt to find the KV length when it could easily be computed as the length of the prompt cache plus the KV image delta, which gets updated as new images are found during prompt processing. Since there will be a small number of images, it's also easy to maintain a companion std::vector of image prompt information that stores the position and KV usage of each image in the prompt, and to scan through that when dumping parts of the KV cache during a context shift; that array will be very short and can be efficiently scanned and easily maintained, without needing to walk through the text tokens of the prompt at all. Some applications might have very long text prompts, such as scanning in a PDF with text and images in it; PDFs can easily run to 50k-100k text tokens, and it's not good to force walking through that whole long prompt every time just to find out how much KV is used.

@ngxson (Collaborator, Author) commented May 16, 2025

@steampunque I don't understand what you're saying about the KV delta. Can you give a code example?

The scan only happens on slot start, not on every decode. In the normal case the variable server_tokens::n_kv is used, so no scan is needed.
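
For illustration, a hedged sketch of that idea with hypothetical names (not the actual server_tokens class): the KV count is maintained incrementally on append, so the common read path is O(1), and a walk is only needed to count KV cells up to an arbitrary chunk index, as at slot start.

#include <cstddef>
#include <cstdint>
#include <vector>

struct server_tokens_sketch {
    struct chunk { bool is_image; int32_t n_kv; };

    std::vector<chunk> chunks;
    int32_t            n_kv_total = 0;

    // maintained incrementally, so reads cost nothing
    void push(const chunk & c) { chunks.push_back(c); n_kv_total += c.n_kv; }

    // O(1): the normal case during decoding
    int32_t n_kv() const { return n_kv_total; }

    // O(n): only needed occasionally, e.g. to count KV cells covered by a common prefix
    int32_t n_kv_up_to(size_t n_chunks) const {
        int32_t n = 0;
        for (size_t i = 0; i < n_chunks && i < chunks.size(); i++) {
            n += chunks[i].n_kv;
        }
        return n;
    }
};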

@steampunque commented May 16, 2025

> @steampunque I don't understand what you're saying about the KV delta. Can you give a code example?
>
> The scan only happens on slot start, not on every decode. In the normal case the variable server_tokens::n_kv is used, so no scan is needed.

Just to illustrate the high-level concept:

int32_t tokenized = mtmd_tokenize(ctx_server.mctx,
                                  chunks.ptr.get(),
                                  &inp_txt,
                                  &n_prompt_tokens_image,
                                  bitmaps_c_ptr.data(),
                                  bitmaps_c_ptr.size());
if (tokenized != 0) {
    // ...
}

// ...

GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size() + slot.n_prompt_tokens_image);

// ...

This can be extended if needed by passing a std::vector to the mtmd_tokenize routine, which would return an array recording where the images are positioned and how much KV they consume, in case that's needed by other processing later. There will never be thousands of images (in most cases only one to a few tens), but there can be tens to hundreds of thousands of text tokens wrapped around a small number of images, so this image-information array, if even needed, stays short.

As an aside, I never understood the whole idea of the built-in "context shift" functionality when the KV cache gets full. I believe this functionality should be completely removed from the server, not just auto-disabled when images are present. If the KV cache hits a capacity limit, generation should just stop there and a "full" status be returned, so the caller can do whatever they want with it, instead of the server executing some arbitrary token-removal heuristic. For a multiturn conversation, context shifts can easily corrupt the prompt by removing tokens in the middle of a prompt template. This problem shouldn't even exist, because whatever process is calling the server should be responsible for handling a full KV cache, not the server itself. In an extended multiturn situation the calling process can delete earlier parts of the conversation to keep going, and there is never any worry about a context shift inside the server invalidating the prompt. If the input prompt is too big, the server should likewise not truncate it; it should return no tokens with a "KV full" status.
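
A hedged sketch of that caller-side policy (hypothetical client helpers, nothing from llama.cpp): on a "KV full" status, drop the oldest turn at message granularity, so the prompt template is never cut mid-message, and retry.

#include <deque>
#include <string>

struct gen_result {
    bool        kv_full; // hypothetical status: context capacity was hit
    std::string text;
};

// hypothetical transport: send the turns to the server and collect the result
static gen_result generate(const std::deque<std::string> & turns) {
    (void) turns;
    return {false, ""}; // stub for the sketch
}

static std::string chat_turn(std::deque<std::string> & turns, const std::string & user_msg) {
    turns.push_back(user_msg);
    for (;;) {
        gen_result res = generate(turns);
        if (!res.kv_full || turns.size() <= 1) {
            return res.text; // done, or even a single turn does not fit
        }
        turns.pop_front();   // drop the oldest turn, at message granularity
    }
}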

@ngxson (Collaborator, Author) commented May 16, 2025

Then how does your idea work in the case where we want to discard the last N positions (not N tokens) from cache_tokens? How can you isolate the slot.n_prompt_tokens_image count for a removed image versus the token count of images still kept in the cache?

> This can be extended if needed by passing a std::vector to the mtmd_tokenize routine, which would return an array recording where the images are positioned and how much KV they consume, in case that's needed by other processing later.

That's server_tokens::map_pos_to_image; are you sure you read the entire code?

Comment on lines 3097 to 3098
slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);
slot.n_kv_tokens = slot.cache_tokens.n_kv_tokens(slot.n_past);
@ngxson (Collaborator, Author) commented May 16, 2025

@steampunque The "scan" loop for counting tokens is only needed in this particular logic and nowhere else in the codebase, so currently there is no risk of performance degradation. We can simply merge these two calls into one single API so we end up with a single loop, roughly as sketched below.
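
A hedged sketch of what such a merged call could look like (hypothetical signature and types; the actual change landed in 678d7b1): one walk computes both the common-prefix length and the KV cells it covers.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// one entry per position; an image entry carries more than one KV cell
struct chunk { int32_t id; int32_t n_kv; };

struct prefix_info { int32_t n_past; int32_t n_kv_tokens; };

prefix_info get_common_prefix_info(const std::vector<chunk> & cache,
                                   const std::vector<chunk> & prompt) {
    prefix_info res = {0, 0};
    const size_t n = std::min(cache.size(), prompt.size());
    for (size_t i = 0; i < n; i++) {
        if (cache[i].id != prompt[i].id) {
            break; // first mismatch ends the common prefix
        }
        res.n_past      += 1;             // positions
        res.n_kv_tokens += cache[i].n_kv; // KV cells
    }
    return res;
}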

@ngxson (Collaborator, Author)

Done in 678d7b1

@steampunque commented May 16, 2025

> Then how does your idea work in the case where we want to discard the last N positions (not N tokens) from cache_tokens? How can you isolate the slot.n_prompt_tokens_image count for a removed image versus the token count of images still kept in the cache?
>
> > This can be extended if needed by passing a std::vector to the mtmd_tokenize routine, which would return an array recording where the images are positioned and how much KV they consume, in case that's needed by other processing later.
>
> That's server_tokens::map_pos_to_image; are you sure you read the entire code?

What I showed was just a very rough, high-level concept: have mtmd_tokenize return information about the images, then use that information later in downstream prompt processing, to avoid recomputing things that have already been figured out in the tokenize routine. More specifically:

struct image_info_rec {
    int pos;    /* where the image is located in the tokenizer output */
    int kv_use; /* how much kv is used up by the image */
};

/* mtmd_tokenize fills this in for all images in the prompt; it can later be used in
   prompt-processing logic to avoid recomputing positions and kv sizes. This vector
   will be as long as the number of images in the prompt. */
std::vector<image_info_rec> image_info;

@ngxson (Collaborator, Author) commented May 16, 2025

I think you still don't get it. We already have the API mtmd_image_tokens_get_n_tokens, which returns the number of KV tokens an image takes. This API costs nothing; read its source code to understand.

There is no need to do a manual image --> number-of-tokens mapping like you said.
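
For reference, the accessor is roughly along these lines (paraphrased from memory, not verbatim from mtmd.cpp): the count is a constant-time read of stored grid dimensions, not a recomputation.

#include <cstddef>
#include <cstdint>

// paraphrased sketch: the real struct lives in mtmd.cpp and has more fields
struct mtmd_image_tokens {
    uint32_t nx; // image token grid width
    uint32_t ny; // image token grid height
};

size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) {
    return image_tokens->nx * image_tokens->ny; // one KV entry per patch embedding
}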

@steampunque commented May 16, 2025

> I think you still don't get it. We already have the API mtmd_image_tokens_get_n_tokens, which returns the number of KV tokens an image takes. This API costs nothing; read its source code to understand.
>
> There is no need to do a manual image --> number-of-tokens mapping like you said.

You made many changes to this PR after my comment, so some of my concerns are already addressed. My comments were mainly about the original PR, where you were scanning through the entire prompt seeking out images at every position to compute the KV size, which should not be necessary if you already know where the images are and how much space they take via a short companion array computed when tokenizing the prompt. If you have another way to do that efficiently without the companion array, while still avoiding walking through the whole prompt looking for images, that is fine.

@ggerganov (Member) left a comment

I think we can improve the naming by avoiding the notion of "KV cache". We have two types of counts to consider: tokens and positions. So we can name them just that: n_tok/n_tokens and n_pos/n_positions.

Introducing n_kv/n_kv_tokens seems unnecessary and a bit confusing. The "KV cache" is an implementation detail, and we should aim for the examples and tools not to refer to it. They should only know about the "cache" or the "memory" of the context.
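
A minimal sketch of the suggested naming (hypothetical struct, just to show the distinction):

#include <cstdint>

struct slot_counts_sketch {
    int32_t n_tokens = 0; // entries stored in the context's memory
    int32_t n_pos    = 0; // positions consumed; under M-RoPE an image adds
                          // several tokens but only one position
};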

@ngxson (Collaborator, Author) commented May 17, 2025

@ggerganov I made an important change in de8956a:

slot.n_past now tracks the absolute position, not the relative position. For example, when we have 10 cached tokens and ctx shift kicks in, say removing 4 tokens, then traditionally we would do n_past -= 4. But now n_past stays at 10. The old notion of n_past now becomes cache_tokens.n_pos().

I know it may sound quite cumbersome, but I think we can potentially add the notion of a "view" for server_tokens. Instead of "shifting" the list of tokens, we would simply get a view from a certain offset, as sketched below. That's the high-level idea, but I don't yet know how to implement it efficiently.
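
A hedged sketch of that view idea (hypothetical, not implemented in this PR): keep the full token list, and on a context shift advance an offset instead of erasing the front.

#include <cstddef>
#include <cstdint>
#include <vector>

struct server_tokens_view {
    const std::vector<int32_t> * tokens; // the full, unshifted list
    size_t                       offset; // advanced on each context shift

    size_t  size() const { return tokens->size() - offset; }
    int32_t operator[](size_t i) const { return (*tokens)[offset + i]; }
};

This would let n_past stay absolute while the view reflects what actually remains in the context's memory.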

Also, while working on this, I think it's now time to start thinking about how to clean up / break the update_slots() function into smaller sub-functions. Based on my state-machine idea from #9283, I think that can be a good reason to consider each state as a function. I'll draft an issue with more details on the idea.
