
Prefix Caching with HBM and latency test #1278

Open · wants to merge 1 commit into main from yuyan-prefix-cache-dev
Conversation

@yuyanpeng-google (Collaborator) commented on Feb 17, 2025

Description

Implements prefix caching in HBM and a latency test in inference_microbenchmark.

Prefix tokens are stored in a trie that acts as a fast lookup index into the values held by the PrefixCache.
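As a rough illustration of that index (a minimal sketch in plain Python, not the code in this PR; `TrieNode`, `PrefixCacheTrie`, and their methods are hypothetical names):

```python
from typing import Optional


class TrieNode:
    """One node per token; `key` is set on the node where a stored key ends."""

    def __init__(self) -> None:
        self.children: dict[int, "TrieNode"] = {}
        self.key: Optional[tuple[int, ...]] = None


class PrefixCacheTrie:
    def __init__(self) -> None:
        self.root = TrieNode()

    def insert(self, tokens: tuple[int, ...]) -> None:
        """Insert a key; any shorter key along the path is unmarked."""
        node = self.root
        for tok in tokens:
            node.key = None  # the longer key supersedes this shorter prefix key
            node = node.children.setdefault(tok, TrieNode())
        node.key = tokens

    def longest_common_prefix_key(self, tokens: tuple[int, ...]) -> Optional[tuple[int, ...]]:
        """Return a stored key sharing the longest prefix with `tokens`, if any."""
        node = self.root
        for tok in tokens:
            if tok not in node.children:
                break
            node = node.children[tok]
        # The nearest stored key may extend past the matched span, so the
        # caller must check the actual common prefix length.
        while node.key is None:
            if not node.children:
                return None
            node = next(iter(node.children.values()))
        return node.key
```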

Inserting a longer key replaces any shorter key that is a prefix of it, so the trie keeps the longest common prefix key.
The shorter key will never be returned again, even if the longer key is later erased, and its value should get evicted in the future.

A key is assumed to have the same length as its token sequence, so it can be used to slice both the prompt and the cached value.
The caller should check the common prefix length of the returned key, since the key may be longer than the portion that actually matches the prompt.
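Continuing the sketch above, the caller-side check might look like this (illustrative only):

```python
trie = PrefixCacheTrie()
trie.insert((1, 2, 3, 4, 5, 6))  # replaces any shorter prefix key on this path

prompt = (1, 2, 3, 4, 9, 9)
key = trie.longest_common_prefix_key(prompt)  # -> (1, 2, 3, 4, 5, 6)
assert key is not None

# The key is as long as its token sequence, so the caller measures how much
# of it really matches the prompt before slicing prompt and cached value.
common = 0
while common < min(len(key), len(prompt)) and key[common] == prompt[common]:
    common += 1
cached_part, remainder = prompt[:common], prompt[common:]  # common == 4
```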

Erasing a key that does not end at a leaf does nothing.
If the erased key matches a leaf, that node is deleted, and its nearest surviving ancestor becomes the new leaf after the deletion.
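A matching erase sketch (iterative, continuing the same hypothetical class; not the PR's implementation):

```python
def erase(trie: PrefixCacheTrie, tokens: tuple[int, ...]) -> None:
    """Erase `tokens` only if it ends at a leaf; prune now-empty ancestors."""
    path, node = [trie.root], trie.root
    for tok in tokens:
        if tok not in node.children:
            return  # key not present: nothing happens
        node = node.children[tok]
        path.append(node)
    if node.children or node.key != tokens:
        return  # not a leaf key: nothing happens
    node.key = None
    # Walk back up iteratively (no recursion), deleting empty nodes so the
    # nearest surviving ancestor becomes the new leaf.
    for parent, tok in zip(reversed(path[:-1]), reversed(tokens)):
        child = parent.children[tok]
        if child.children or child.key is not None:
            break
        del parent.children[tok]
```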

A value is moved into the cache, which means the same value reference cannot be used after add_to_cache.

JAX may modify the value even when it is stored under another Python reference.
If the value is needed after add_to_cache, make sure to copy it before calling add_to_cache.

Returned values are always copied out of the cache before being handed back, so callers cannot modify the value held in the cache.
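A sketch of that ownership contract (the `add_to_cache`/`load` helpers and `_store` here are stand-ins, not the PR's API):

```python
import jax
import jax.numpy as jnp

_store: dict[tuple[int, ...], dict] = {}  # stand-in for the HBM value store


def add_to_cache(key, value):
    _store[key] = value  # takes ownership: caller must not reuse `value`


def load(key):
    # Always copy on the way out so callers cannot mutate the cached entry.
    return jax.tree_util.tree_map(lambda x: x.copy(), _store[key])


key = (1, 2, 3)
value = {"k": jnp.zeros((1, 8, 64)), "v": jnp.zeros((1, 8, 64))}

mine = jax.tree_util.tree_map(lambda x: x.copy(), value)  # copy *before* insert
add_to_cache(key, value)  # `value` is moved; use `mine` from here on
loaded = load(key)        # an independent copy, safe to modify
```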

Adds a PrefixCaching benchmark test to inference_microbenchmark.

The benchmark uses half of the prefill_length as the common prefix and saves 100 prefixes in the cache.
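The workload construction can be pictured like this (illustrative numbers matching the description; the vocabulary size of 32,000 is an assumption, and this is not the benchmark's actual code):

```python
import numpy as np

prefill_length = 1024
num_prefixes = 100
common_len = prefill_length // 2  # half the prefill is the shared prefix

rng = np.random.default_rng(0)
common = rng.integers(0, 32_000, size=common_len)  # 32_000 ~ assumed vocab size

# 100 prompts sharing the same first half; the benchmark saves each prompt's
# prefix, then averages save / fetch-longest-prefix / load latencies.
prompts = [
    np.concatenate(
        [common, rng.integers(0, 32_000, size=prefill_length - common_len)]
    )
    for _ in range(num_prefixes)
]
```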

Loading the cache (including jax.array.copy) appears to be independent of the prefill_length (tested with 128 and 1024), even though the saved cache sizes differ.

Using jax.profiler shows that the copy operation consumes a similar amount of time on TPU. This might be because the sizes aren't large or different enough to see a significant impact.
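As an aside on methodology, JAX dispatches device work asynchronously, so wall-clock measurements of operations like the copy need to block on the result; a minimal timing helper (illustrative, not the benchmark's code; `timed_ms` is a hypothetical name):

```python
import time

import jax
import jax.numpy as jnp


def timed_ms(fn, *args):
    # JAX dispatches asynchronously; block until the result is ready so the
    # measurement covers the device work, not just the dispatch.
    t0 = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, (time.perf_counter() - t0) * 1e3


x = jnp.ones((1024, 1024))
_, ms = timed_ms(lambda a: a.copy(), x)
print(f"copy took {ms:.3f} ms")
```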

Part of the results is shown below:

Prefix Cache benchmark results for prefill length 128:

PrefixCaching results:
	Per prefix size bytes: 0.124 GB
	Average save cache time: 12.142 ms
	Average fetch longest prefix time: 0.029 ms
	Average load cache time: 5.589 ms


Prefix Cache benchmark results for prefill length 1024:

PrefixCaching results:
	Per prefix size bytes: 0.220 GB
	Average save cache time: 12.987 ms
	Average fetch longest prefix time: 0.218 ms
	Average load cache time: 5.143 ms

FIXES: b/389788256
TESTED: unittest

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@yuyanpeng-google marked this pull request as ready for review on February 19, 2025
@yuyanpeng-google changed the title from [WIP] Prefix Caching to Prefix Caching with HBM and latency test on Feb 19, 2025

@vipannalla (Collaborator) left a comment


Looks good, I left a few comments.

@yuyanpeng-google (Collaborator, Author) commented:

Modified the recursive erase so that long contexts do not hit the recursion limit.
Fixed the comments.
Rebased to main.


@mailvijayasingh (Collaborator) left a comment


LGTM!

@liurupeng (Collaborator) commented:

Synced with Yuyan offline; I think the current approach looks good.

@yuyanpeng-google (Collaborator, Author) commented:

Rebase and squash


@vipannalla (Collaborator) left a comment


looks good

@yuyanpeng-google force-pushed the yuyan-prefix-cache-dev branch 2 times, most recently from 93e757b to 0bee4bc on February 27, 2025