-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: handling client disconnect gracefully by propagating cancellatio…
…n exception up to the chat completion handler (#185)
- Loading branch information
Showing
7 changed files
with
265 additions
and
77 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import asyncio | ||
from asyncio import exceptions | ||
from typing import Optional, Set | ||
|
||
|
||
class CancelScope: | ||
""" | ||
Async context manager that enforces cancellation of all tasks created within its scope when either: | ||
1. the parent task has been cancelled or has thrown an exception or | ||
2. any of the tasks created within the scope has thrown an exception. | ||
""" | ||
|
||
def __init__(self): | ||
self._tasks: Set[asyncio.Task] = set() | ||
self._on_completed_fut: Optional[asyncio.Future] = None | ||
self._cancelling: bool = False | ||
|
||
async def __aenter__(self): | ||
return self | ||
|
||
async def __aexit__(self, exc_type, exc, tb): | ||
|
||
cancelled_error = ( | ||
exc if isinstance(exc, exceptions.CancelledError) else None | ||
) | ||
|
||
# If the parent task has thrown an exception, cancel all the tasks | ||
if exc_type is not None: | ||
self._cancel_tasks() | ||
|
||
while self._tasks: | ||
if self._on_completed_fut is None: | ||
self._on_completed_fut = asyncio.Future() | ||
|
||
# If the parent task was cancelled, cancel all the tasks | ||
try: | ||
await self._on_completed_fut | ||
except exceptions.CancelledError as ex: | ||
cancelled_error = ex | ||
self._cancel_tasks() | ||
|
||
self._on_completed_fut = None | ||
|
||
if cancelled_error: | ||
raise cancelled_error | ||
|
||
def create_task(self, coro): | ||
task = asyncio.create_task(coro) | ||
task.add_done_callback(self._on_task_done) | ||
self._tasks.add(task) | ||
return task | ||
|
||
def _cancel_tasks(self): | ||
if not self._cancelling: | ||
self._cancelling = True | ||
for t in self._tasks: | ||
if not t.done(): | ||
t.cancel() | ||
|
||
def _on_task_done(self, task): | ||
self._tasks.discard(task) | ||
|
||
if ( | ||
self._on_completed_fut is not None | ||
and not self._on_completed_fut.done() | ||
and not self._tasks | ||
): | ||
self._on_completed_fut.set_result(True) | ||
|
||
# If any of the tasks was cancelled, cancel all the tasks | ||
if task.exception() is not None: | ||
self._cancel_tasks() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
import asyncio | ||
from typing import Optional | ||
|
||
import pytest | ||
|
||
from aidial_sdk.chat_completion.request import Request as ChatCompletionRequest | ||
from aidial_sdk.chat_completion.response import ( | ||
Response as ChatCompletionResponse, | ||
) | ||
from aidial_sdk.pydantic_v1 import SecretStr | ||
from aidial_sdk.utils.streaming import add_heartbeat | ||
from tests.utils.constants import DUMMY_FASTAPI_REQUEST | ||
|
||
|
||
class Counter: | ||
done: int = 0 | ||
cancelled: int = 0 | ||
_lock: asyncio.Lock | ||
|
||
def __init__(self) -> None: | ||
self._lock = asyncio.Lock() | ||
|
||
async def inc_done(self): | ||
async with self._lock: | ||
self.done += 1 | ||
|
||
async def inc_cancelled(self): | ||
async with self._lock: | ||
self.cancelled += 1 | ||
|
||
|
||
async def _wait_forever(): | ||
await asyncio.Event().wait() | ||
|
||
|
||
async def _wait(counter: Counter, secs: Optional[int] = None): | ||
try: | ||
if secs is None: | ||
# wait forever | ||
await _wait_forever() | ||
else: | ||
for _ in range(secs): | ||
await asyncio.sleep(1) | ||
except asyncio.CancelledError: | ||
await counter.inc_cancelled() | ||
raise | ||
else: | ||
await counter.inc_done() | ||
|
||
|
||
def chat_completion_wait_forever(counter: Counter): | ||
|
||
async def _chat_completion(*args, **kwargs): | ||
await _wait(counter) | ||
|
||
return _chat_completion | ||
|
||
|
||
def chat_completion_gather(counter: Counter): | ||
|
||
async def _chat_completion(*args, **kwargs): | ||
tasks = (asyncio.create_task(_wait(counter)) for _ in range(10)) | ||
await asyncio.gather(*tasks) | ||
|
||
return _chat_completion | ||
|
||
|
||
def chat_completion_create_task(counter: Counter): | ||
|
||
async def _chat_completion(*args, **kwargs): | ||
for _ in range(10): | ||
asyncio.create_task(_wait(counter, 3)) | ||
await _wait_forever() | ||
|
||
return _chat_completion | ||
|
||
|
||
@pytest.mark.parametrize("with_heartbeat", [True, False]) | ||
@pytest.mark.parametrize( | ||
"chat_completion, expected_cancelled, expected_done", | ||
[ | ||
(chat_completion_wait_forever, 1, 0), | ||
(chat_completion_gather, 10, 0), | ||
(chat_completion_create_task, 0, 10), | ||
], | ||
) | ||
async def test_cancellation( | ||
with_heartbeat: bool, chat_completion, expected_cancelled, expected_done | ||
): | ||
|
||
request = ChatCompletionRequest( | ||
original_request=DUMMY_FASTAPI_REQUEST, | ||
messages=[], | ||
api_key_secret=SecretStr("api-key"), | ||
deployment_id="test-app", | ||
headers={}, | ||
) | ||
|
||
response = ChatCompletionResponse(request) | ||
|
||
counter = Counter() | ||
chat_completion = chat_completion(counter) | ||
|
||
async def _exhaust_stream(stream): | ||
async for _ in stream: | ||
pass | ||
|
||
try: | ||
stream = response._generate_stream(chat_completion) | ||
if with_heartbeat: | ||
stream = add_heartbeat( | ||
stream, | ||
heartbeat_interval=0.2, | ||
heartbeat_object=": heartbeat\n\n", | ||
) | ||
|
||
await asyncio.wait_for(_exhaust_stream(stream), timeout=2) | ||
except asyncio.TimeoutError: | ||
pass | ||
else: | ||
assert False, "Stream should have timed out" | ||
|
||
await asyncio.sleep(2) | ||
|
||
assert ( | ||
counter.cancelled == expected_cancelled | ||
and counter.done == expected_done | ||
), "Stream should have been cancelled" |
Oops, something went wrong.