Skip to content

Commit

Permalink
add lock for multi-thread
Browse files (browse the repository at this point in the history)
  • Loading branch information
yuyanpeng-google committed Feb 25, 2025
1 parent fc0efd2 commit 0b4f412
Showing 1 changed file with 29 additions and 22 deletions.
51 changes: 29 additions & 22 deletions MaxText/prefix_cache.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -67,6 +67,7 @@
import jax
import jax.numpy as jnp
import logging
import threading

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -339,6 +340,7 @@ def __init__(self, hbm_bytes: int):
hbm_bytes: Total amount of HBM to use for cache.
"""
self._hbm_bytes = hbm_bytes
self._lock = threading.Lock()
# init in clear()
self._hbm_cache: HBMCache = None
self._trie: PrefixCacheTrie = None
Expand All @@ -348,36 +350,41 @@ def __init__(self, hbm_bytes: int):
def fetch_longest_common_prefix_key(self, key: Key) -> Optional[Key]:
"""Returns key with longest common prefix matched or None if not found."""
logger.debug("fetch_longest_common_prefix_key, key=%r", key)
matched_key = self._trie.get_longest_common_prefix_key(key)
logger.debug("matched_key=%r", matched_key)
return matched_key
with self._lock:
matched_key = self._trie.get_longest_common_prefix_key(key)
logger.debug("matched_key=%r", matched_key)
return matched_key

def save(self, key: Key, value: Value) -> bool:
"""Save key/value to the cache."""
logger.debug("save key=%r", key)
while not self._hbm_cache.has_enough_space(value):
if self._hbm_bytes < value.prefix_size_bytes:
logger.debug("hbm_bytes=%r < value.prefix_size_bytes=%r", self._hbm_bytes, value.prefix_size_bytes)
break
if self._evict_cache() is None:
logger.debug("cannot evict cache")
break
if not self._hbm_cache.add_to_cache(key, value):
logger.debug("cannot add to cache even after evict")
return False
self._trie.insert(key)
self._cache_strategy.use(key)
return True
with self._lock:
while not self._hbm_cache.has_enough_space(value):
if self._hbm_bytes < value.prefix_size_bytes:
logger.debug("hbm_bytes=%r < value.prefix_size_bytes=%r", self._hbm_bytes, value.prefix_size_bytes)
break
if self._evict_cache() is None:
logger.debug("cannot evict cache")
break
if not self._hbm_cache.add_to_cache(key, value):
logger.debug("cannot add to cache even after evict")
return False
self._trie.insert(key)
self._cache_strategy.use(key)
return True

def load(self, key: Key) -> Optional[Value]:
"""Returns Value stored with key or None if not found."""
logger.debug("load key=%r", key)
value = self._hbm_cache.retrieve_from_cache(key)
if value is None:
logger.warning("The key should fetched by fetch_longest_common_prefix_key, load key=%r should be valid but not.", key)
return None
self._cache_strategy.use(key)
return value
with self._lock:
value = self._hbm_cache.retrieve_from_cache(key)
if value is None:
logger.warning(
"The key should fetched by fetch_longest_common_prefix_key, load key=%r should be valid but not.", key
)
return None
self._cache_strategy.use(key)
return value

def clear(self):
"""Clear entire cache."""
Expand Down

0 comments on commit 0b4f412

Please sign in to comment.