Skip to content

Commit

Permalink
fix: handling client disconnect gracefully by propagating cancellatio…
Browse files Browse the repository at this point in the history
…n exception up to the chat completion handler (#185)
  • Loading branch information
adubovik authored Nov 25, 2024
1 parent 424b740 commit 45681f3
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 77 deletions.
66 changes: 31 additions & 35 deletions aidial_sdk/chat_completion/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
from aidial_sdk.chat_completion.request import Request
from aidial_sdk.exceptions import HTTPException as DIALException
from aidial_sdk.exceptions import RequestValidationError, RuntimeServerError
from aidial_sdk.utils._cancel_scope import CancelScope
from aidial_sdk.utils.errors import RUNTIME_ERROR_MESSAGE, runtime_error
from aidial_sdk.utils.logging import log_error, log_exception
from aidial_sdk.utils.merge_chunks import merge
from aidial_sdk.utils.streaming import ResponseStream

_Producer = Callable[[Request, "Response"], Coroutine[Any, Any, Any]]


class Response:
request: Request
Expand Down Expand Up @@ -68,49 +71,38 @@ def n(self) -> int:
def stream(self) -> int:
return self.request.stream

async def _generate_stream(
self,
producer: Callable[[Request, "Response"], Coroutine[Any, Any, Any]],
) -> ResponseStream:
async def _run_producer(self, producer: _Producer):
try:
await producer(self.request, self)
except Exception as e:
if isinstance(e, DIALException):
dial_exception = e
else:
log_exception(RUNTIME_ERROR_MESSAGE)
dial_exception = RuntimeServerError(RUNTIME_ERROR_MESSAGE)

self._queue.put_nowait(ExceptionChunk(dial_exception))
else:
self._queue.put_nowait(EndChunk())

async def _generate_stream(self, producer: _Producer) -> ResponseStream:
async with CancelScope() as cs:
cs.create_task(self._run_producer(producer))

async for chunk in self._generate_chunk_stream():
yield chunk

def _create_chunk(chunk):
async def _generate_chunk_stream(self) -> ResponseStream:
def _create_chunk(chunk: BaseChunk):
return BaseChunkWithDefaults(
chunk=chunk, defaults=self._default_chunk
)

user_task = asyncio.create_task(producer(self.request, self))
user_task_is_done = False

# A list of chunks whose emitting is delayed up until the very last moment
delayed_chunks: List[BaseChunk] = []

while True:
get_chunk_task = asyncio.create_task(self._queue.get())
done = (
await asyncio.wait(
[get_chunk_task, user_task],
return_when=asyncio.FIRST_COMPLETED,
)
)[0]

if user_task in done and not user_task_is_done:
user_task_is_done = True
try:
user_task.result()
except Exception as e:
if isinstance(e, DIALException):
dial_exception = e
else:
log_exception(RUNTIME_ERROR_MESSAGE)
dial_exception = RuntimeServerError(
RUNTIME_ERROR_MESSAGE
)

self._queue.put_nowait(ExceptionChunk(dial_exception))
else:
self._queue.put_nowait(EndChunk())

chunk = await get_chunk_task
chunk = await self._queue.get()
self._queue.task_done()

if isinstance(chunk, BaseChunk):
Expand All @@ -122,7 +114,11 @@ def _create_chunk(chunk):

is_top_level_chunk = isinstance(
chunk,
(UsageChunk, UsagePerModelChunk, DiscardedMessagesChunk),
(
UsageChunk,
UsagePerModelChunk,
DiscardedMessagesChunk,
),
)

if is_last_end_choice_chunk or is_top_level_chunk:
Expand Down
72 changes: 72 additions & 0 deletions aidial_sdk/utils/_cancel_scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import asyncio
from asyncio import exceptions
from typing import Optional, Set


class CancelScope:
"""
Async context manager that enforces cancellation of all tasks created within its scope when either:
1. the parent task has been cancelled or has thrown an exception or
2. any of the tasks created within the scope has thrown an exception.
"""

def __init__(self):
self._tasks: Set[asyncio.Task] = set()
self._on_completed_fut: Optional[asyncio.Future] = None
self._cancelling: bool = False

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):

cancelled_error = (
exc if isinstance(exc, exceptions.CancelledError) else None
)

# If the parent task has thrown an exception, cancel all the tasks
if exc_type is not None:
self._cancel_tasks()

while self._tasks:
if self._on_completed_fut is None:
self._on_completed_fut = asyncio.Future()

# If the parent task was cancelled, cancel all the tasks
try:
await self._on_completed_fut
except exceptions.CancelledError as ex:
cancelled_error = ex
self._cancel_tasks()

self._on_completed_fut = None

if cancelled_error:
raise cancelled_error

def create_task(self, coro):
task = asyncio.create_task(coro)
task.add_done_callback(self._on_task_done)
self._tasks.add(task)
return task

def _cancel_tasks(self):
if not self._cancelling:
self._cancelling = True
for t in self._tasks:
if not t.done():
t.cancel()

def _on_task_done(self, task):
self._tasks.discard(task)

if (
self._on_completed_fut is not None
and not self._on_completed_fut.done()
and not self._tasks
):
self._on_completed_fut.set_result(True)

# If any of the tasks was cancelled, cancel all the tasks
if task.exception() is not None:
self._cancel_tasks()
56 changes: 28 additions & 28 deletions aidial_sdk/utils/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from aidial_sdk.chat_completion.chunks import BaseChunkWithDefaults
from aidial_sdk.exceptions import HTTPException as DIALException
from aidial_sdk.utils._cancel_scope import CancelScope
from aidial_sdk.utils.logging import log_debug
from aidial_sdk.utils.merge_chunks import cleanup_indices, merge

Expand Down Expand Up @@ -136,31 +137,30 @@ async def add_heartbeat(
heartbeat_object: Optional[_HeartbeatObject] = None,
heartbeat_callback: Optional[_HeartbeatCallback] = None,
) -> AsyncGenerator[_T, None]:
chunk_task: Optional[asyncio.Task[_T]] = None

while True:
if chunk_task is None:
chunk_task = asyncio.create_task(stream.__anext__())

done = (
await asyncio.wait(
[chunk_task],
timeout=heartbeat_interval,
return_when=asyncio.FIRST_COMPLETED,
)
)[0]

if chunk_task in done:
try:
chunk, chunk_task = chunk_task.result(), None
yield chunk
except StopAsyncIteration:
break
except Exception as e:
raise e
else:
if heartbeat_object is not None:
yield await _eval_heartbeat_object(heartbeat_object)

if heartbeat_callback is not None:
await _call_heartbeat_callback(heartbeat_callback)
async with CancelScope() as cs:
chunk_task: Optional[asyncio.Task[_T]] = None

while True:
if chunk_task is None:
chunk_task = cs.create_task(stream.__anext__())

done = (
await asyncio.wait(
[chunk_task],
timeout=heartbeat_interval,
return_when=asyncio.FIRST_COMPLETED,
)
)[0]

if chunk_task in done:
try:
chunk, chunk_task = chunk_task.result(), None
yield chunk
except StopAsyncIteration:
break
else:
if heartbeat_object is not None:
yield await _eval_heartbeat_object(heartbeat_object)

if heartbeat_callback is not None:
await _call_heartbeat_callback(heartbeat_callback)
9 changes: 0 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,6 @@ exclude = [

[tool.black]
line-length = 80
exclude = '''
/(
\.git
| \.venv
| \.nox
| \.pytest_cache
| \.__pycache__
)/
'''

[tool.isort]
line_length = 80
Expand Down
128 changes: 128 additions & 0 deletions tests/test_cancellation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import asyncio
from typing import Optional

import pytest

from aidial_sdk.chat_completion.request import Request as ChatCompletionRequest
from aidial_sdk.chat_completion.response import (
Response as ChatCompletionResponse,
)
from aidial_sdk.pydantic_v1 import SecretStr
from aidial_sdk.utils.streaming import add_heartbeat
from tests.utils.constants import DUMMY_FASTAPI_REQUEST


class Counter:
done: int = 0
cancelled: int = 0
_lock: asyncio.Lock

def __init__(self) -> None:
self._lock = asyncio.Lock()

async def inc_done(self):
async with self._lock:
self.done += 1

async def inc_cancelled(self):
async with self._lock:
self.cancelled += 1


async def _wait_forever():
await asyncio.Event().wait()


async def _wait(counter: Counter, secs: Optional[int] = None):
try:
if secs is None:
# wait forever
await _wait_forever()
else:
for _ in range(secs):
await asyncio.sleep(1)
except asyncio.CancelledError:
await counter.inc_cancelled()
raise
else:
await counter.inc_done()


def chat_completion_wait_forever(counter: Counter):

async def _chat_completion(*args, **kwargs):
await _wait(counter)

return _chat_completion


def chat_completion_gather(counter: Counter):

async def _chat_completion(*args, **kwargs):
tasks = (asyncio.create_task(_wait(counter)) for _ in range(10))
await asyncio.gather(*tasks)

return _chat_completion


def chat_completion_create_task(counter: Counter):

async def _chat_completion(*args, **kwargs):
for _ in range(10):
asyncio.create_task(_wait(counter, 3))
await _wait_forever()

return _chat_completion


@pytest.mark.parametrize("with_heartbeat", [True, False])
@pytest.mark.parametrize(
"chat_completion, expected_cancelled, expected_done",
[
(chat_completion_wait_forever, 1, 0),
(chat_completion_gather, 10, 0),
(chat_completion_create_task, 0, 10),
],
)
async def test_cancellation(
with_heartbeat: bool, chat_completion, expected_cancelled, expected_done
):

request = ChatCompletionRequest(
original_request=DUMMY_FASTAPI_REQUEST,
messages=[],
api_key_secret=SecretStr("api-key"),
deployment_id="test-app",
headers={},
)

response = ChatCompletionResponse(request)

counter = Counter()
chat_completion = chat_completion(counter)

async def _exhaust_stream(stream):
async for _ in stream:
pass

try:
stream = response._generate_stream(chat_completion)
if with_heartbeat:
stream = add_heartbeat(
stream,
heartbeat_interval=0.2,
heartbeat_object=": heartbeat\n\n",
)

await asyncio.wait_for(_exhaust_stream(stream), timeout=2)
except asyncio.TimeoutError:
pass
else:
assert False, "Stream should have timed out"

await asyncio.sleep(2)

assert (
counter.cancelled == expected_cancelled
and counter.done == expected_done
), "Stream should have been cancelled"
Loading

0 comments on commit 45681f3

Please sign in to comment.