Skip to content

Commit

Permalink
Merge branch 'release/0.4.3'
Browse files Browse the repository at this point in the history
  • Loading branch information
s3rius committed May 9, 2023
2 parents f54d1ef + aca7228 commit 1123d00
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 30 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "taskiq"
version = "0.4.2"
version = "0.4.3"
description = "Distributed task queue with full async support"
authors = ["Pavel Kirilin <[email protected]>"]
maintainers = ["Pavel Kirilin <[email protected]>"]
Expand Down
16 changes: 5 additions & 11 deletions taskiq/brokers/inmemory_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import inspect
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncGenerator, Callable, Optional, Set, TypeVar, get_type_hints
from typing import Any, AsyncGenerator, Set, TypeVar, get_type_hints

from taskiq_dependencies import DependencyGraph

Expand Down Expand Up @@ -88,22 +88,16 @@ class InMemoryBroker(AsyncBroker):
It's useful for local development, if you don't want to setup real broker.
"""

def __init__( # noqa: WPS211
def __init__(
self,
sync_tasks_pool_size: int = 4,
max_stored_results: int = 100,
cast_types: bool = True,
result_backend: Optional[AsyncResultBackend[Any]] = None,
task_id_generator: Optional[Callable[[], str]] = None,
max_async_tasks: int = 30,
) -> None:
if result_backend is None:
result_backend = InmemoryResultBackend(
max_stored_results=max_stored_results,
)
super().__init__(
result_backend=result_backend,
task_id_generator=task_id_generator,
super().__init__()
self.result_backend = InmemoryResultBackend(
max_stored_results=max_stored_results,
)
self.executor = ThreadPoolExecutor(sync_tasks_pool_size)
self.receiver = Receiver(
Expand Down
40 changes: 38 additions & 2 deletions taskiq/cli/worker/args.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from dataclasses import dataclass
from typing import List, Optional, Sequence
from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Tuple

from taskiq.cli.common_args import LogLevel


def receiver_arg_type(string: str) -> Tuple[str, str]:
"""
Parse cli --receiver_arg argument value.
:param string: cli argument value in format key=value.
:raises ValueError: if value not in format.
:return: (key, value) pair.
"""
args = string.split("=", 1)
if len(args) != 2:
raise ValueError(f"Invalid value: {string}")
return args[0], args[1]


@dataclass
class WorkerArgs:
"""Taskiq worker CLI arguments."""
Expand All @@ -24,6 +38,8 @@ class WorkerArgs:
reload: bool = False
no_gitignore: bool = False
max_async_tasks: int = 100
receiver: str = "taskiq.receiver:Receiver"
receiver_arg: List[Tuple[str, str]] = field(default_factory=list)

@classmethod
def from_cli( # noqa: WPS213
Expand All @@ -45,6 +61,26 @@ def from_cli( # noqa: WPS213
"'module.module:variable' format."
),
)
parser.add_argument(
"--receiver",
default="taskiq.receiver:Receiver",
help=(
"Where to search for receiver. "
"This string must be specified in "
"'module.module:variable' format."
),
)
parser.add_argument(
"--receiver_arg",
action="append",
type=receiver_arg_type,
default=[],
help=(
"List of args fot receiver. "
"This string must be specified in "
"`key=value` format."
),
)
parser.add_argument(
"--tasks-pattern",
"-tp",
Expand Down
26 changes: 23 additions & 3 deletions taskiq/cli/worker/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import signal
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from typing import Any, Type

from taskiq.abc.broker import AsyncBroker
from taskiq.cli.utils import import_object, import_tasks
Expand Down Expand Up @@ -51,7 +51,21 @@ async def shutdown_broker(broker: AsyncBroker, timeout: float) -> None:
)


def start_listen(args: WorkerArgs) -> None: # noqa: WPS213
def get_receiver_type(args: WorkerArgs) -> Type[Receiver]:
"""
Import Receiver from args.
:param args: CLI arguments.
:raises ValueError: if receiver is not a Receiver type.
:return: Receiver type.
"""
receiver_type = import_object(args.receiver)
if not (isinstance(receiver_type, type) and issubclass(receiver_type, Receiver)):
raise ValueError("Unknown receiver type. Please use Receiver class.")
return receiver_type


def start_listen(args: WorkerArgs) -> None: # noqa: WPS210, WPS213
"""
This function starts actual listening process.
Expand All @@ -63,6 +77,7 @@ def start_listen(args: WorkerArgs) -> None: # noqa: WPS213
:param args: CLI arguments.
:raises ValueError: if broker is not an AsyncBroker instance.
:raises ValueError: if receiver is not a Receiver type.
"""
if uvloop is not None:
logger.debug("UVLOOP found. Installing policy.")
Expand All @@ -77,6 +92,9 @@ def start_listen(args: WorkerArgs) -> None: # noqa: WPS213
if not isinstance(broker, AsyncBroker):
raise ValueError("Unknown broker type. Please use AsyncBroker instance.")

receiver_type = get_receiver_type(args)
receiver_args = dict(args.receiver_arg)

# Here how we manage interruptions.
# We have to remember shutting_down state,
# because KeyboardInterrupt can be send multiple
Expand Down Expand Up @@ -105,14 +123,16 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
signal.signal(signal.SIGTERM, interrupt_handler)

loop = asyncio.get_event_loop()

try:
logger.debug("Initialize receiver.")
with ThreadPoolExecutor(args.max_threadpool_threads) as pool:
receiver = Receiver(
receiver = receiver_type(
broker=broker,
executor=pool,
validate_params=not args.no_parse,
max_async_tasks=args.max_async_tasks,
**receiver_args,
)
loop.run_until_complete(receiver.listen())
except KeyboardInterrupt:
Expand Down
17 changes: 4 additions & 13 deletions tests/cli/worker/test_receiver.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncGenerator, Callable, List, Optional, TypeVar
from typing import Any, AsyncGenerator, List, Optional, TypeVar

import pytest
from taskiq_dependencies import Depends

from taskiq.abc.broker import AsyncBroker
from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.abc.result_backend import AsyncResultBackend
from taskiq.brokers.inmemory_broker import InMemoryBroker
from taskiq.message import TaskiqMessage
from taskiq.receiver import Receiver
Expand All @@ -17,15 +16,8 @@


class BrokerForTests(InMemoryBroker):
def __init__(
self,
result_backend: "Optional[AsyncResultBackend[_T]]" = None,
task_id_generator: Optional[Callable[[], str]] = None,
) -> None:
super().__init__(
result_backend=result_backend,
task_id_generator=task_id_generator,
)
def __init__(self) -> None:
super().__init__()
self.to_send: "List[TaskiqMessage]" = []

async def listen(self) -> AsyncGenerator[bytes, None]:
Expand Down Expand Up @@ -142,8 +134,7 @@ def on_error(
def test_func() -> None:
raise ValueError()

broker = InMemoryBroker()
broker.add_middlewares(_TestMiddleware())
broker = InMemoryBroker().with_middlewares(_TestMiddleware())
receiver = get_receiver(broker)

result = await receiver.run_task(
Expand Down

0 comments on commit 1123d00

Please sign in to comment.