Add Prefix Caching with HBM and latency test
Implements Prefix Caching in HBM and adds a latency test to inference_microbenchmark.

Stores prefix tokens in a trie that serves as a fast lookup index for the entries stored in the PrefixCache.

Inserting a longer key replaces a shorter key as the longest-common-prefix key.
The shorter key will never be returned, even if the longer key is later erased, and it should get evicted in the future.

A key is assumed to have the same length as its tokens, so it can be used to slice the prompt and the cached Value.
The caller should check the common-prefix length of the returned key.
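
For instance, a caller might slice the prompt like this (illustrative only; fetch_longest_common_prefix_key and load are the methods exercised by the benchmark below, the surrounding variable names are assumptions):

    matched_key = prefix_cache_inst.fetch_longest_common_prefix_key(tuple(tokens))
    if matched_key is not None:
      cached_len = len(matched_key)  # common-prefix length, checked by the caller
      cached_value = prefix_cache_inst.load(matched_key)
      remaining_tokens = tokens[cached_len:]  # only these still need prefill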

Erasing a key that does not end at a leaf does nothing.
Erasing a key that matches a leaf deletes that node, and its ancestor becomes the leaf after the deletion.

A Value is moved into the cache, which means the same value reference cannot be used after add_to_cache.
A Value retrieved from the cache is returned by reference and must not be modified either. These trie semantics are sketched below.
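
A minimal sketch of the trie semantics described above, assuming hypothetical class and method names (the actual prefix_cache module may differ):

    from typing import Optional


    class _Node:
      """One trie node per token."""

      def __init__(self):
        self.children: dict[int, "_Node"] = {}
        self.is_key_end = False


    class PrefixTrie:
      """Token-tuple keys with longest-common-prefix lookup."""

      def __init__(self):
        self._root = _Node()

      def insert(self, key: tuple[int, ...]) -> None:
        node = self._root
        for tok in key:
          # A longer key supersedes any shorter key along its path; the shorter
          # key will never be returned again, even if the longer key is erased.
          node.is_key_end = False
          node = node.children.setdefault(tok, _Node())
        node.is_key_end = True

      def longest_common_prefix_key(self, key: tuple[int, ...]) -> Optional[tuple[int, ...]]:
        # Walk down while tokens match, remembering the deepest stored key.
        node, matched = self._root, None
        for i, tok in enumerate(key):
          if tok not in node.children:
            break
          node = node.children[tok]
          if node.is_key_end:
            matched = key[: i + 1]
        return matched

      def erase(self, key: tuple[int, ...]) -> None:
        # Erasing a key that does not end at a leaf does nothing; erasing a
        # leaf removes that node, so its ancestor becomes the new leaf.
        if not key:
          return
        path, node = [self._root], self._root
        for tok in key:
          if tok not in node.children:
            return  # key not present: nothing happens
          node = node.children[tok]
          path.append(node)
        if node.children or not node.is_key_end:
          return  # not a stored leaf: nothing happens
        del path[-2].children[key[-1]]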

Adds a PrefixCaching benchmark to inference_microbenchmark.

Uses the proportion of prefill_length set in the config as the common prefix, and saves the configured number of entries into the cache; see the sizing walk-through below.
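
A worked example of that sizing, using the default values added to base.yml below (illustrative arithmetic only, mirroring the benchmark code in this commit):

    prefill_length = 1024            # largest default prefill length
    common_prefix_proportion = 0.5   # inference_microbenchmark_prefix_cache_common_prefix_proportion
    prefix_cache_entries_num = 100   # inference_microbenchmark_prefix_cache_entries_num

    common_len = int(prefill_length * common_prefix_proportion)  # 512 shared tokens per key
    remain_len = prefill_length - common_len                     # 512 distinct tokens per key
    # The cache is sized to hold exactly prefix_cache_entries_num values:
    #   capacity_bytes = prefix_cache_entries_num * value.prefix_size_bytes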
yuyanpeng-google committed Feb 27, 2025
1 parent d50683d commit c8fc38c
Showing 4 changed files with 990 additions and 0 deletions.
2 changes: 2 additions & 0 deletions MaxText/configs/base.yml
@@ -525,6 +525,8 @@ vertex_tensorboard_region: ""
max_checkify: False

# Inference
inference_microbenchmark_prefix_cache_entries_num: 100
inference_microbenchmark_prefix_cache_common_prefix_proportion: 0.5
inference_microbenchmark_prefill_lengths: "64,128,256,512,1024"
inference_microbenchmark_stages: "prefill,generate"
inference_microbenchmark_loop_iters: 10
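To run the new stage, "prefix_cache" must appear in inference_microbenchmark_stages (see the stage check in run_benchmarks below). A hypothetical invocation using MaxText's usual config-override syntax (the exact entry point and override form are assumptions, verify against your checkout):

    python MaxText/inference_microbenchmark.py MaxText/configs/base.yml \
      inference_microbenchmark_stages=prefix_cache \
      inference_microbenchmark_prefix_cache_entries_num=100 \
      inference_microbenchmark_prefix_cache_common_prefix_proportion=0.5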
112 changes: 112 additions & 0 deletions MaxText/inference_microbenchmark.py
@@ -27,6 +27,7 @@
import max_utils
import maxengine
import maxtext_utils
import prefix_cache
import profiler
import pyconfig

@@ -39,6 +40,101 @@
# pylint: disable=too-many-positional-arguments


def prefix_cache_benchmark(
    prefix, prefill_length: int, true_length: int, common_prefix_proportion: float, prefix_cache_entries_num: int, iters: int
):
  """Runs the prefix cache benchmark and prints the results.

  Creates distinct keys sharing a common prefix (a common_prefix_proportion share
  of prefill_length) and inserts them into the cache. The value contents are not
  relevant to the cache for now; the same prefix is simply cloned for every entry.
  1. Fill the prefix cache to full capacity.
  2. Benchmark saving to the prefix cache (with eviction), averaged over iters.
  3. Benchmark fetch_longest_common_prefix_key, averaged over iters.
  4. Benchmark loading from the prefix cache, averaged over iters.
  Args:
    prefix: prefix returned from the prefill function
    prefill_length: prefill token length after padding
    true_length: true prefill token length
    common_prefix_proportion: proportion in [0., 1.] of prefill_length used as the common prefix
    prefix_cache_entries_num: number of prefix cache entries inserted into PrefixCache
    iters: number of repetitions for the save, fetch_longest_common_prefix_key, and load measurements
  """

print(f"Prefix Cache benchmark results for prefill length {prefill_length}:\n")

value = prefix_cache.Value(
prefix=prefix,
true_length=true_length,
padded_length=prefill_length,
tokens=tuple(i for i in range(prefill_length)),
)
prefix_size_bytes_gb = value.prefix_size_bytes / 1024 / 1024 / 1024
prefix_cache_inst = prefix_cache.PrefixCache(prefix_cache_entries_num * value.prefix_size_bytes)
common_len = int(prefill_length * common_prefix_proportion)
remain_len = prefill_length - common_len
common_prefix_key = tuple(i for i in range(common_len))

# Fill the prefix caching
new_value_list = []
for c_idx in range(prefix_cache_entries_num):
# Add 100 to make sure filled prefix caching will not share the common_prefix_key.
# The later save prefix part will evict all of them.
key = tuple(100 + i + c_idx * prefill_length for i in range(prefill_length))
new_value = value.clone()
prefix_cache_inst.save(key, new_value)
new_value_list.append(new_value)
jax.block_until_ready(new_value_list)
del new_value_list

  # Save prefix
  new_value = None
  save_sec = 0
  for c_idx in range(iters):
    key = common_prefix_key + tuple(i + c_idx * remain_len for i in range(remain_len))
    # Values are not relevant for caching now; just clone the same tokens and value for the test.
    new_value = value.clone()
    jax.block_until_ready(new_value)
    start = datetime.datetime.now()
    prefix_cache_inst.save(key, new_value)
    end = datetime.datetime.now()
    save_sec += (end - start).total_seconds()
  del new_value
  save_avg_ms = save_sec * 1000 / iters

  # Fetch longest prefix key
  key_load = common_prefix_key + tuple(i + prefix_cache_entries_num * remain_len for i in range(remain_len))
  matched_key = None
  fetch_sec = 0
  for _ in range(iters):
    start = datetime.datetime.now()
    matched_key = prefix_cache_inst.fetch_longest_common_prefix_key(key_load)
    end = datetime.datetime.now()
    fetch_sec += (end - start).total_seconds()
  fetch_avg_ms = fetch_sec * 1000 / iters

  # Load prefix
  load_sec = 0
  value_load = None
  for _ in range(iters):
    start = datetime.datetime.now()
    value_load = prefix_cache_inst.load(matched_key)
    jax.block_until_ready(value_load)
    end = datetime.datetime.now()
    load_sec += (end - start).total_seconds()
  del value_load
  load_avg_ms = load_sec * 1000 / iters

  print(
      f"PrefixCaching results:\n"
      f"\tPer-prefix size: {prefix_size_bytes_gb:.3f} GB\n"
      f"\tAverage save cache time: {save_avg_ms:.3f} ms\n"
      f"\tAverage fetch longest prefix time: {fetch_avg_ms:.3f} ms\n"
      f"\tAverage load cache time: {load_avg_ms:.3f} ms\n\n\n"
  )
  del prefix_cache_inst


def prefill_benchmark_loop(engine_prefill, params, tokens, true_length, iters):
  """Inner loop for benchmarking prefill step."""
  start = datetime.datetime.now()
@@ -305,6 +401,22 @@ def run_benchmarks(config):
        prefill_executable[prefill_length], params, prefill_tokens[prefill_length], prefill_true_lengths[prefill_length]
    )

if "prefix_cache" in stages_to_benchmark:
for prefill_length in prefill_lengths:
rng_cache = jax.random.PRNGKey(1234)
prefill_result, _ = prefill_executable[prefill_length](
params, prefill_tokens[prefill_length], prefill_true_lengths[prefill_length], rng_cache
)
prefix_cache_benchmark(
prefill_result,
prefill_length,
prefill_true_lengths[prefill_length],
config.inference_microbenchmark_prefix_cache_common_prefix_proportion,
config.inference_microbenchmark_prefix_cache_entries_num,
benchmark_loop_iters,
)
del prefill_result

  for prefill_length in prefill_lengths:
    benchmark_results["prefill"][prefill_length] = prefill_benchmark(
        config,