Skip to content

Commit

Permalink
Added utilities method for inmemory broker. (#388)
Browse files Browse the repository at this point in the history
* Added utilities method for inmemory broker.

* Added docs for new utilities.
  • Loading branch information
s3rius authored Dec 12, 2024
1 parent e5c6d2b commit ae6b214
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 7 deletions.
69 changes: 63 additions & 6 deletions docs/guide/testing-taskiq.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ the same interface as a real broker, but it doesn't send tasks actually.
Let's define a task.

```python
from your_project.taskiq import broker
from your_project.tkq import broker

@broker.task
async def parse_int(val: str) -> int:
Expand All @@ -107,7 +107,7 @@ And that's it. Test should pass.
What if you want to test a function that uses task. Let's define such function.

```python
from your_project.taskiq import broker
from your_project.tkq import broker

@broker.task
async def parse_int(val: str) -> int:
Expand All @@ -129,6 +129,63 @@ async def test_add_one():
assert await parse_and_add_one("11") == 12
```

### Unawaitable tasks

When a function calls an asynchronous task but doesn't await its result,
it can be challenging to test.

In such cases, the `InMemoryBroker` provides two convenient ways to help you:
the `await_inplace` constructor parameter and the `wait_all` method.

Consider the following example where we define a task and a function that calls it:

```python
from your_project.tkq import broker

@broker.task
async def parse_int(val: str) -> int:
return int(val)


async def parse_int_later(val: str) -> int:
await parse_int.kiq(val)
return 1
```

To test this function, we can do two things:

1. By setting the `await_inplace=True` parameter when creating the broker.
In that case all tasks will be automatically awaited as soon as they are called.
In such a way you don't need to manually call the `wait_result` in your code.

To set it up, define the broker as the following:

```python
...
broker = InMemoryBroker(await_inplace=True)
...

```

With this setup all `await function.kiq()` calls will behave similarly to `await function()`, but
with dependency injection and all taskiq-related functionality.

2. Alternatively, you can manually await all tasks after invoking the
target function by using the `wait_all` method.
This gives you more control over when to wait for tasks to complete.

```python
from your_project.tkq import broker

@pytest.mark.anyio
async def test_add_one():
# Call the function that triggers the async task
assert await parse_int_later("11") == 1
await broker.wait_all() # Waits for all tasks to complete
# At that time we can guarantee that all sent tasks
# have been completed and do all the assertions.
```

## Dependency injection

If you use dependencies in your tasks, you may think that this can become a problem. But it's not.
Expand All @@ -146,7 +203,7 @@ from typing import Annotated
from pathlib import Path
from taskiq import TaskiqDepends

from your_project.taskiq import broker
from your_project.tkq import broker


@broker.task
Expand All @@ -161,7 +218,7 @@ async def modify_path(some_path: Annotated[Path, TaskiqDepends()]):
from pathlib import Path
from taskiq import TaskiqDepends

from your_project.taskiq import broker
from your_project.tkq import broker


@broker.task
Expand All @@ -177,7 +234,7 @@ expected dependencies manually as function's arguments or key-word arguments.

```python
import pytest
from your_project.taskiq import broker
from your_project.tkq import broker

from pathlib import Path

Expand All @@ -193,7 +250,7 @@ must mutate dependency_context before calling a task. We suggest to do it in fix

```python
import pytest
from your_project.taskiq import broker
from your_project.tkq import broker
from pathlib import Path


Expand Down
20 changes: 19 additions & 1 deletion taskiq/brokers/inmemory_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(
cast_types: bool = True,
max_async_tasks: int = 30,
propagate_exceptions: bool = True,
await_inplace: bool = False,
) -> None:
super().__init__()
self.result_backend = InmemoryResultBackend(
Expand All @@ -140,6 +141,7 @@ def __init__(
max_async_tasks=max_async_tasks,
propagate_exceptions=propagate_exceptions,
)
self.await_inplace = await_inplace
self._running_tasks: "Set[asyncio.Task[Any]]" = set()

async def kick(self, message: BrokerMessage) -> None:
Expand All @@ -156,7 +158,12 @@ async def kick(self, message: BrokerMessage) -> None:
if target_task is None:
raise TaskiqError("Unknown task.")

task = asyncio.create_task(self.receiver.callback(message=message.message))
receiver_cb = self.receiver.callback(message=message.message)
if self.await_inplace:
await receiver_cb
return

task = asyncio.create_task(receiver_cb)
self._running_tasks.add(task)
task.add_done_callback(self._running_tasks.discard)

Expand All @@ -171,6 +178,17 @@ def listen(self) -> AsyncGenerator[bytes, None]:
"""
raise RuntimeError("Inmemory brokers cannot listen.")

async def wait_all(self) -> None:
"""
Wait for all currently running tasks to complete.
Useful when used in testing and you need to await all sent tasks
before asserting results.
"""
to_await = list(self._running_tasks)
for task in to_await:
await task

async def startup(self) -> None:
"""Runs startup events for client and worker side."""
for event in (TaskiqEvents.CLIENT_STARTUP, TaskiqEvents.WORKER_STARTUP):
Expand Down
36 changes: 36 additions & 0 deletions tests/brokers/test_inmemory.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,39 @@ async def test_task() -> str:

result = await task.wait_result()
assert result.return_value == test_value


@pytest.mark.anyio
async def test_inline_awaits() -> None:
broker = InMemoryBroker(await_inplace=True)
slept = False

@broker.task
async def test_task() -> None:
nonlocal slept
await asyncio.sleep(0.2)
slept = True

task = await test_task.kiq()
assert slept
assert await task.is_ready()
assert not broker._running_tasks


@pytest.mark.anyio
async def test_wait_all() -> None:
broker = InMemoryBroker()
slept = False

@broker.task
async def test_task() -> None:
nonlocal slept
await asyncio.sleep(0.2)
slept = True

task = await test_task.kiq()
assert not slept
await broker.wait_all()
assert slept
assert await task.is_ready()
assert not broker._running_tasks

0 comments on commit ae6b214

Please sign in to comment.