diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9093d5e3..bbc4e1f1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -37,7 +37,7 @@ jobs: strategy: matrix: py_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] - pydantic_ver: ["<2", ">=2,<3"] + pydantic_ver: ["<2", ">=2.5,<3"] os: [ubuntu-latest, windows-latest] runs-on: "${{ matrix.os }}" steps: diff --git a/docs/guide/cli.md b/docs/guide/cli.md index af40aa7d..a8c853f8 100644 --- a/docs/guide/cli.md +++ b/docs/guide/cli.md @@ -84,7 +84,7 @@ To enable this option simply pass the `--reload` or `-r` option to worker taskiq Also this option supports `.gitignore` files. If you have such file in your directory, it won't reload worker when you modify ignored files. To disable this functionality pass `--do-not-use-gitignore` option. -### Graceful reload +### Graceful reload (available only on Unix systems) To perform graceful reload, send `SIGHUP` signal to the main worker process. This action will reload all workers with new code. It's useful for deployment that requires zero downtime, but don't use orchestration tools like Kubernetes. diff --git a/taskiq/abc/result_backend.py b/taskiq/abc/result_backend.py index 7e0ebb65..257d0b04 100644 --- a/taskiq/abc/result_backend.py +++ b/taskiq/abc/result_backend.py @@ -1,8 +1,12 @@ from abc import ABC, abstractmethod -from typing import Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from taskiq.result import TaskiqResult +if TYPE_CHECKING: # pragma: no cover + from taskiq.depends.progress_tracker import TaskProgress + + _ReturnType = TypeVar("_ReturnType") @@ -50,3 +54,25 @@ async def get_result( :param with_logs: if True it will download task's logs. :return: task's return value. """ + + async def set_progress( + self, + task_id: str, + progress: "TaskProgress[Any]", + ) -> None: + """ + Saves progress. + + :param task_id: task's id. + :param progress: progress of execution. + """ + + async def get_progress( + self, + task_id: str, + ) -> "Optional[TaskProgress[Any]]": + """ + Gets progress. + + :param task_id: task's id. + """ diff --git a/taskiq/brokers/inmemory_broker.py b/taskiq/brokers/inmemory_broker.py index 6c6eed86..544289ff 100644 --- a/taskiq/brokers/inmemory_broker.py +++ b/taskiq/brokers/inmemory_broker.py @@ -1,10 +1,11 @@ import asyncio from collections import OrderedDict from concurrent.futures import ThreadPoolExecutor -from typing import Any, AsyncGenerator, Set, TypeVar +from typing import Any, AsyncGenerator, Optional, Set, TypeVar from taskiq.abc.broker import AsyncBroker from taskiq.abc.result_backend import AsyncResultBackend, TaskiqResult +from taskiq.depends.progress_tracker import TaskProgress from taskiq.events import TaskiqEvents from taskiq.exceptions import TaskiqError from taskiq.message import BrokerMessage @@ -27,6 +28,7 @@ class InmemoryResultBackend(AsyncResultBackend[_ReturnType]): def __init__(self, max_stored_results: int = 100) -> None: self.max_stored_results = max_stored_results self.results: OrderedDict[str, TaskiqResult[_ReturnType]] = OrderedDict() + self.progress: OrderedDict[str, TaskProgress[Any]] = OrderedDict() async def set_result(self, task_id: str, result: TaskiqResult[_ReturnType]) -> None: """ @@ -79,6 +81,37 @@ async def get_result( """ return self.results[task_id] + async def set_progress( + self, + task_id: str, + progress: TaskProgress[Any], + ) -> None: + """ + Set progress of task exection. + + :param task_id: task id + :param progress: task execution progress + """ + if ( + self.max_stored_results != -1 + and len(self.progress) >= self.max_stored_results + ): + self.progress.popitem(last=False) + + self.progress[task_id] = progress + + async def get_progress( + self, + task_id: str, + ) -> Optional[TaskProgress[Any]]: + """ + Get progress of task execution. + + :param task_id: task id + :return: progress or None + """ + return self.progress.get(task_id) + class InMemoryBroker(AsyncBroker): """ diff --git a/taskiq/cli/worker/process_manager.py b/taskiq/cli/worker/process_manager.py index 24f01551..acc3da3b 100644 --- a/taskiq/cli/worker/process_manager.py +++ b/taskiq/cli/worker/process_manager.py @@ -1,5 +1,6 @@ import logging import signal +import sys from contextlib import suppress from dataclasses import dataclass from multiprocessing import Event, Process, Queue, current_process @@ -174,10 +175,11 @@ def __init__( shutdown_handler = get_signal_handler(self.action_queue, ShutdownAction()) signal.signal(signal.SIGINT, shutdown_handler) signal.signal(signal.SIGTERM, shutdown_handler) - signal.signal( - signal.SIGHUP, - get_signal_handler(self.action_queue, ReloadAllAction()), - ) + if sys.platform != "win32": + signal.signal( + signal.SIGHUP, + get_signal_handler(self.action_queue, ReloadAllAction()), + ) self.workers: List[Process] = [] diff --git a/taskiq/depends/progress_tracker.py b/taskiq/depends/progress_tracker.py new file mode 100644 index 00000000..9f0161ae --- /dev/null +++ b/taskiq/depends/progress_tracker.py @@ -0,0 +1,72 @@ +import enum +from typing import Generic, Optional, Union + +from taskiq_dependencies import Depends +from typing_extensions import TypeVar + +from taskiq.compat import IS_PYDANTIC2 +from taskiq.context import Context + +if IS_PYDANTIC2: + from pydantic import BaseModel as GenericModel +else: + from pydantic.generics import GenericModel # type: ignore[no-redef] + + +_ProgressType = TypeVar("_ProgressType") + + +class TaskState(str, enum.Enum): + """State of task execution.""" + + STARTED = "STARTED" + FAILURE = "FAILURE" + SUCCESS = "SUCCESS" + RETRY = "RETRY" + + +class TaskProgress(GenericModel, Generic[_ProgressType]): + """Progress of task execution.""" + + state: Union[TaskState, str] + meta: Optional[_ProgressType] + + +class ProgressTracker(Generic[_ProgressType]): + """Task's dependency to set progress.""" + + def __init__( + self, + context: Context = Depends(), + ) -> None: + self.context = context + + async def set_progress( + self, + state: Union[TaskState, str], + meta: Optional[_ProgressType] = None, + ) -> None: + """Set progress. + + :param state: TaskState or str + :param meta: progress data + """ + if meta is None: + progress = await self.get_progress() + meta = progress.meta if progress else None + + progress = TaskProgress( + state=state, + meta=meta, + ) + + await self.context.broker.result_backend.set_progress( + self.context.message.task_id, + progress, + ) + + async def get_progress(self) -> Optional[TaskProgress[_ProgressType]]: + """Get progress.""" + return await self.context.broker.result_backend.get_progress( + self.context.message.task_id, + ) diff --git a/taskiq/task.py b/taskiq/task.py index 691046b5..b4d2be61 100644 --- a/taskiq/task.py +++ b/taskiq/task.py @@ -1,7 +1,9 @@ import asyncio from abc import ABC, abstractmethod from time import time -from typing import TYPE_CHECKING, Any, Coroutine, Generic, TypeVar, Union +from typing import TYPE_CHECKING, Any, Coroutine, Generic, Optional, Union + +from typing_extensions import TypeVar from taskiq.exceptions import ( ResultGetError, @@ -11,6 +13,7 @@ if TYPE_CHECKING: # pragma: no cover from taskiq.abc.result_backend import AsyncResultBackend + from taskiq.depends.progress_tracker import TaskProgress from taskiq.result import TaskiqResult _ReturnType = TypeVar("_ReturnType") @@ -65,6 +68,19 @@ def wait_result( :return: TaskiqResult. """ + @abstractmethod + def get_progress( + self, + ) -> Union[ + "Optional[TaskProgress[Any]]", + Coroutine[Any, Any, "Optional[TaskProgress[Any]]"], + ]: + """ + Get task progress. + + :return: task's progress. + """ + class AsyncTaskiqTask(_Task[_ReturnType]): """AsyncTask for AsyncResultBackend.""" @@ -137,3 +153,11 @@ async def wait_result( if 0 < timeout < time() - start_time: raise TaskiqResultTimeoutError return await self.get_result(with_logs=with_logs) + + async def get_progress(self) -> "Optional[TaskProgress[Any]]": + """ + Get task progress. + + :return: task's progress. + """ + return await self.result_backend.get_progress(self.task_id) diff --git a/tests/depends/test_progress_tracker.py b/tests/depends/test_progress_tracker.py new file mode 100644 index 00000000..040381b0 --- /dev/null +++ b/tests/depends/test_progress_tracker.py @@ -0,0 +1,121 @@ +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, Optional + +import pytest +from pydantic import ValidationError + +from taskiq import ( + AsyncTaskiqDecoratedTask, + InMemoryBroker, + TaskiqDepends, + TaskiqMessage, +) +from taskiq.abc import AsyncBroker +from taskiq.depends.progress_tracker import ProgressTracker, TaskState +from taskiq.receiver import Receiver + + +def get_receiver( + broker: Optional[AsyncBroker] = None, + no_parse: bool = False, + max_async_tasks: Optional[int] = None, +) -> Receiver: + """ + Returns receiver with custom broker and args. + + :param broker: broker, defaults to None + :param no_parse: parameter to taskiq_args, defaults to False + :param cli_args: Taskiq worker CLI arguments. + :return: new receiver. + """ + if broker is None: + broker = InMemoryBroker() + return Receiver( + broker, + executor=ThreadPoolExecutor(max_workers=10), + validate_params=not no_parse, + max_async_tasks=max_async_tasks, + ) + + +def get_message( + task: AsyncTaskiqDecoratedTask[Any, Any], + task_id: Optional[str] = None, + *args: Any, + labels: Optional[Dict[str, str]] = None, + **kwargs: Dict[str, Any], +) -> TaskiqMessage: + if labels is None: + labels = {} + return TaskiqMessage( + task_id=task_id or task.broker.id_generator(), + task_name=task.task_name, + labels=labels, + args=list(args), + kwargs=kwargs, + ) + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "state,meta", + [ + (TaskState.STARTED, "hello world!"), + ("retry", "retry error!"), + ("custom state", {"Complex": "Value"}), + ], +) +async def test_progress_tracker_ctx_raw(state: Any, meta: Any) -> None: + broker = InMemoryBroker() + + @broker.task + async def test_func(tes_val: ProgressTracker[Any] = TaskiqDepends()) -> None: + await tes_val.set_progress(state, meta) + + kicker = await test_func.kiq() + result = await kicker.wait_result() + + assert not result.is_err + progress = await broker.result_backend.get_progress(kicker.task_id) + assert progress is not None + assert progress.meta == meta + assert progress.state == state + + +@pytest.mark.anyio +async def test_progress_tracker_ctx_none() -> None: + broker = InMemoryBroker() + + @broker.task + async def test_func() -> None: + pass + + kicker = await test_func.kiq() + result = await kicker.wait_result() + + assert not result.is_err + progress = await broker.result_backend.get_progress(kicker.task_id) + assert progress is None + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "state,meta", + [ + (("state", "error"), 1), + ], +) +async def test_progress_tracker_validation_error(state: Any, meta: Any) -> None: + broker = InMemoryBroker() + + @broker.task + async def test_func(progress: ProgressTracker[int] = TaskiqDepends()) -> None: + await progress.set_progress(state, meta) # type: ignore + + kicker = await test_func.kiq() + result = await kicker.wait_result() + with pytest.raises(ValidationError): + result.raise_for_error() + + progress = await broker.result_backend.get_progress(kicker.task_id) + assert progress is None