Skip to content

Commit

Permalink
add optional prefix to redis keys (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
Graeme22 authored Feb 25, 2025
1 parent e30ed08 commit 090a00a
Showing 1 changed file with 42 additions and 18 deletions.
60 changes: 42 additions & 18 deletions taskiq_redis/redis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
result_px_time: Optional[int] = None,
max_connection_pool_size: Optional[int] = None,
serializer: Optional[TaskiqSerializer] = None,
prefix_str: Optional[str] = None,
**connection_kwargs: Any,
) -> None:
"""
Expand All @@ -82,6 +83,7 @@ def __init__(
self.keep_results = keep_results
self.result_ex_time = result_ex_time
self.result_px_time = result_px_time
self.prefix_str = prefix_str

unavailable_conditions = any(
(
Expand All @@ -99,6 +101,11 @@ def __init__(
"Choose either result_ex_time or result_px_time.",
)

def _task_name(self, task_id: str) -> str:
if self.prefix_str is None:
return task_id
return f"{self.prefix_str}:{task_id}"

async def shutdown(self) -> None:
"""Closes redis connection."""
await self.redis_pool.disconnect()
Expand All @@ -119,7 +126,7 @@ async def set_result(
:param result: TaskiqResult instance.
"""
redis_set_params: Dict[str, Union[str, int, bytes]] = {
"name": task_id,
"name": self._task_name(task_id),
"value": self.serializer.dumpb(model_dump(result)),
}
if self.result_ex_time:
Expand All @@ -139,7 +146,7 @@ async def is_result_ready(self, task_id: str) -> bool:
:returns: True if the result is ready else False.
"""
async with Redis(connection_pool=self.redis_pool) as redis:
return bool(await redis.exists(task_id))
return bool(await redis.exists(self._task_name(task_id)))

async def get_result(
self,
Expand All @@ -154,14 +161,15 @@ async def get_result(
:raises ResultIsMissingError: if there is no result when trying to get it.
:return: task's return value.
"""
task_name = self._task_name(task_id)
async with Redis(connection_pool=self.redis_pool) as redis:
if self.keep_results:
result_value = await redis.get(
name=task_id,
name=task_name,
)
else:
result_value = await redis.getdel(
name=task_id,
name=task_name,
)

if result_value is None:
Expand Down Expand Up @@ -192,7 +200,7 @@ async def set_progress(
:param result: task's TaskProgress instance.
"""
redis_set_params: Dict[str, Union[str, int, bytes]] = {
"name": task_id + PROGRESS_KEY_SUFFIX,
"name": self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
"value": self.serializer.dumpb(model_dump(progress)),
}
if self.result_ex_time:
Expand All @@ -215,7 +223,7 @@ async def get_progress(
"""
async with Redis(connection_pool=self.redis_pool) as redis:
result_value = await redis.get(
name=task_id + PROGRESS_KEY_SUFFIX,
name=self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
)

if result_value is None:
Expand All @@ -237,6 +245,7 @@ def __init__(
result_ex_time: Optional[int] = None,
result_px_time: Optional[int] = None,
serializer: Optional[TaskiqSerializer] = None,
prefix_str: Optional[str] = None,
**connection_kwargs: Any,
) -> None:
"""
Expand All @@ -261,6 +270,7 @@ def __init__(
self.keep_results = keep_results
self.result_ex_time = result_ex_time
self.result_px_time = result_px_time
self.prefix_str = prefix_str

unavailable_conditions = any(
(
Expand All @@ -278,6 +288,11 @@ def __init__(
"Choose either result_ex_time or result_px_time.",
)

def _task_name(self, task_id: str) -> str:
if self.prefix_str is None:
return task_id
return f"{self.prefix_str}:{task_id}"

async def shutdown(self) -> None:
"""Closes redis connection."""
await self.redis.aclose() # type: ignore[attr-defined]
Expand All @@ -298,7 +313,7 @@ async def set_result(
:param result: TaskiqResult instance.
"""
redis_set_params: Dict[str, Union[str, bytes, int]] = {
"name": task_id,
"name": self._task_name(task_id),
"value": self.serializer.dumpb(model_dump(result)),
}
if self.result_ex_time:
Expand All @@ -316,7 +331,7 @@ async def is_result_ready(self, task_id: str) -> bool:
:returns: True if the result is ready else False.
"""
return bool(await self.redis.exists(task_id)) # type: ignore[attr-defined]
return bool(await self.redis.exists(self._task_name(task_id))) # type: ignore[attr-defined]

async def get_result(
self,
Expand All @@ -331,13 +346,14 @@ async def get_result(
:raises ResultIsMissingError: if there is no result when trying to get it.
:return: task's return value.
"""
task_name = self._task_name(task_id)
if self.keep_results:
result_value = await self.redis.get( # type: ignore[attr-defined]
name=task_id,
name=task_name,
)
else:
result_value = await self.redis.getdel( # type: ignore[attr-defined]
name=task_id,
name=task_name,
)

if result_value is None:
Expand Down Expand Up @@ -368,7 +384,7 @@ async def set_progress(
:param result: task's TaskProgress instance.
"""
redis_set_params: Dict[str, Union[str, int, bytes]] = {
"name": task_id + PROGRESS_KEY_SUFFIX,
"name": self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
"value": self.serializer.dumpb(model_dump(progress)),
}
if self.result_ex_time:
Expand All @@ -389,7 +405,7 @@ async def get_progress(
:return: task's TaskProgress instance.
"""
result_value = await self.redis.get( # type: ignore[attr-defined]
name=task_id + PROGRESS_KEY_SUFFIX,
name=self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
)

if result_value is None:
Expand All @@ -414,6 +430,7 @@ def __init__(
min_other_sentinels: int = 0,
sentinel_kwargs: Optional[Any] = None,
serializer: Optional[TaskiqSerializer] = None,
prefix_str: Optional[str] = None,
**connection_kwargs: Any,
) -> None:
"""
Expand Down Expand Up @@ -443,6 +460,7 @@ def __init__(
self.keep_results = keep_results
self.result_ex_time = result_ex_time
self.result_px_time = result_px_time
self.prefix_str = prefix_str

unavailable_conditions = any(
(
Expand All @@ -460,6 +478,11 @@ def __init__(
"Choose either result_ex_time or result_px_time.",
)

def _task_name(self, task_id: str) -> str:
if self.prefix_str is None:
return task_id
return f"{self.prefix_str}:{task_id}"

@asynccontextmanager
async def _acquire_master_conn(self) -> AsyncIterator[_Redis]:
async with self.sentinel.master_for(self.master_name) as redis_conn:
Expand All @@ -480,7 +503,7 @@ async def set_result(
:param result: TaskiqResult instance.
"""
redis_set_params: Dict[str, Union[str, bytes, int]] = {
"name": task_id,
"name": self._task_name(task_id),
"value": self.serializer.dumpb(model_dump(result)),
}
if self.result_ex_time:
Expand All @@ -500,7 +523,7 @@ async def is_result_ready(self, task_id: str) -> bool:
:returns: True if the result is ready else False.
"""
async with self._acquire_master_conn() as redis:
return bool(await redis.exists(task_id))
return bool(await redis.exists(self._task_name(task_id)))

async def get_result(
self,
Expand All @@ -515,14 +538,15 @@ async def get_result(
:raises ResultIsMissingError: if there is no result when trying to get it.
:return: task's return value.
"""
task_name = self._task_name(task_id)
async with self._acquire_master_conn() as redis:
if self.keep_results:
result_value = await redis.get(
name=task_id,
name=task_name,
)
else:
result_value = await redis.getdel(
name=task_id,
name=task_name,
)

if result_value is None:
Expand Down Expand Up @@ -553,7 +577,7 @@ async def set_progress(
:param result: task's TaskProgress instance.
"""
redis_set_params: Dict[str, Union[str, int, bytes]] = {
"name": task_id + PROGRESS_KEY_SUFFIX,
"name": self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
"value": self.serializer.dumpb(model_dump(progress)),
}
if self.result_ex_time:
Expand All @@ -576,7 +600,7 @@ async def get_progress(
"""
async with self._acquire_master_conn() as redis:
result_value = await redis.get(
name=task_id + PROGRESS_KEY_SUFFIX,
name=self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
)

if result_value is None:
Expand Down

0 comments on commit 090a00a

Please sign in to comment.