Prefix Caching with HBM and latency test #1278
base: main
Conversation
Looks good, I left a few comments.
Modify the recursive erase to avoid hitting the recursion limit with long contexts.
LGTM!
Synced with Yuyan offline, I think the current approach looks good.
Rebase and squash.
looks good
Description
Implements prefix caching in HBM and a latency test in inference_microbenchmark.
Prefix tokens are stored in a trie that serves as a fast lookup index for PrefixCache entries.
Inserting a longer key replaces a shorter key as the longest-common-prefix match.
The shorter key will never be returned again, even if the longer key is later erased, and is expected to be evicted in the future.
A key is assumed to have the same length as its tokens, so it can be used to slice the prompt and the cached value.
The caller should check the common-prefix length of the returned key.
Erasing a key that does not end at a leaf does nothing.
If the erased key ends at a leaf, that node is deleted, and its ancestors may become leaves afterwards. A minimal sketch of these trie semantics is shown below.
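A minimal sketch of the lookup, insert, and erase semantics described above, assuming keys are tuples of token ids; `PrefixTrie` and its method names are illustrative placeholders, not the classes added by this PR:

```python
from typing import Optional


class _Node:
  """One trie node; `key` is set only where a stored key ends."""

  def __init__(self):
    self.children = {}  # token id -> _Node
    self.key = None     # full key (tuple of token ids) ending here, if any


class PrefixTrie:
  """Indexes stored keys (tuples of token ids) by their prefixes."""

  def __init__(self):
    self._root = _Node()

  def insert(self, key):
    """Inserts key; any shorter key along its path is no longer returned."""
    node = self._root
    for tok in key:
      node.key = None  # a longer key replaces a shorter key on its path
      node = node.children.setdefault(tok, _Node())
    node.key = key

  def longest_common_prefix_key(self, tokens) -> Optional[tuple]:
    """Returns a stored key sharing the longest common prefix with tokens.

    The returned key can be longer than the shared prefix, so the caller
    must check the common-prefix length before slicing prompt and value.
    """
    node, best = self._root, None
    for tok in tokens:
      nxt = node.children.get(tok)
      if nxt is None:
        break
      node = nxt
      if node.key is not None:
        best = node.key
    if best is None and node is not self._root:
      # No key ends on the matched path; fall back to a longer key below it.
      while node.key is None and node.children:
        node = next(iter(node.children.values()))
      best = node.key
    return best

  def erase(self, key) -> bool:
    """Deletes key only if it ends at a leaf; its parent may become a leaf."""
    parent, node = None, self._root
    for tok in key:
      nxt = node.children.get(tok)
      if nxt is None:
        return False
      parent, node = node, nxt
    if parent is None or node.children or node.key != key:
      return False  # erasing a key that is not a leaf does nothing
    del parent.children[key[-1]]
    return True
```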
The value is moved into the cache, so the same value reference must not be used after add_to_cache; JAX may modify the value even when another Python reference to it is held.
If a value is needed after add_to_cache, copy it before calling add_to_cache.
Values retrieved from the cache are always copied before being returned, so callers cannot modify the cached entry. The sketch below illustrates this copy discipline.
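A hedged sketch of the copy discipline, using a plain dict as a stand-in for the HBM-resident prefix cache; `add_to_cache` and `fetch` are assumed names, not the PR's exact API:

```python
import jax
import jax.numpy as jnp

_cache = {}  # stand-in for the HBM-resident prefix cache


def add_to_cache(key, value):
  # The cache takes ownership of `value`; callers must not reuse the reference.
  _cache[key] = value


def fetch(key):
  # Always return a copy so callers cannot modify the cached arrays.
  return jax.tree_util.tree_map(lambda x: x.copy(), _cache[key])


tokens = (1, 2, 3, 4)
kv_value = {"decoder_layer_0": jnp.ones((1, 4, 8, 128))}  # toy KV cache entry

# Copy first if the value is still needed after insertion.
kv_for_later_use = jax.tree_util.tree_map(lambda x: x.copy(), kv_value)
add_to_cache(tokens, kv_value)

restored = fetch(tokens)  # a copy; mutating it leaves the cached entry intact
```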
Adds a PrefixCaching benchmark test to inference_microbenchmark.
It uses half of the prefill_length as the common prefix and saves 100 prefixes in the cache.
Loading the cache (including jax.Array.copy) appears to be independent of the prefill_length (tested with 128 and 1024), even though the saved cache sizes differ.
jax.profiler shows the copy operation consumes a similar amount of time on TPU; this may be because the sizes are not large or different enough to show a significant impact. A hedged sketch of how such prompts could be constructed is shown after this paragraph.
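A sketch of how prompts sharing a common prefix could be generated for a benchmark like this; the names and vocabulary size are illustrative, and the actual test lives in inference_microbenchmark:

```python
import numpy as np

prefill_length = 1024
num_cached_prefixes = 100
common_prefix_length = prefill_length // 2  # half of prefill_length is shared
vocab_size = 32000  # illustrative vocabulary size

rng = np.random.default_rng(0)
common_prefix = rng.integers(0, vocab_size, size=common_prefix_length)

# Each prompt shares the common prefix and fills the rest with unique tokens.
prompts = [
    np.concatenate(
        [common_prefix,
         rng.integers(0, vocab_size, size=prefill_length - common_prefix_length)])
    for _ in range(num_cached_prefixes)
]
```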
Part of the results are shown below.
FIXES: b/389788256
TESTED: unittest
Checklist
Before submitting this PR, please make sure (put X in square brackets):