Skip to content

Commit

Permalink
Correct task cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
mephenor committed Dec 12, 2024
1 parent 4028d6a commit b53da5b
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 17 deletions.
51 changes: 39 additions & 12 deletions src/ghga_connector/core/downloading/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
#
"""Contains a concrete implementation of the abstract downloader"""

import asyncio
import base64
from asyncio import Queue, Semaphore, Task, create_task
from asyncio import PriorityQueue, Queue, Semaphore, Task, create_task
from collections.abc import Coroutine
from io import BufferedWriter
from pathlib import Path
from typing import Any, Union
from typing import Any

import httpx
from tenacity import RetryError
Expand Down Expand Up @@ -51,11 +52,37 @@ class TaskHandler:
def __init__(self):
self._tasks: set[Task] = set()

async def schedule(self, fn: Coroutine[Any, Any, None]):
def schedule(self, fn: Coroutine[Any, Any, None]):
"""Create a task and register its callback."""
task = create_task(fn)
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)
task.add_done_callback(self.finalize)

def cancel_tasks(self):
    """Request cancellation of every registered task that has not finished yet."""
    # Tasks that are already done are skipped; only unfinished ones are cancelled.
    unfinished = (handle for handle in self._tasks if not handle.done())
    for handle in unfinished:
        handle.cancel()

def finalize(self, task: Task):
    """Deal with the outcome of a finished task.

    This is meant to be registered as a done callback, so there are three
    possibilities here:
    1. The task encountered an exception: cancel all remaining tasks and keep
       the failed task registered, so that awaiting the registered tasks
       (see `gather`) re-raises the original exception to the caller.
    2. The task was cancelled: there is nothing to handle, only its handle is
       removed.
    3. The task finished normally: remove its handle.

    The exception is deliberately not re-raised from here. The event loop
    invokes done callbacks itself and hands anything they raise to the loop's
    exception handler, so a raise would never reach the code awaiting the
    tasks and would only produce a spurious "Exception in callback" report.

    `task` is the finished task this callback was registered on.
    """
    # `exception()` may only be called on tasks that were not cancelled,
    # otherwise it raises CancelledError itself.
    if not task.cancelled() and task.exception() is not None:
        self.cancel_tasks()
        # Keep the failed task in the registry so the error stays observable.
        return
    self._tasks.discard(task)

async def gather(self):
    """Wait until every task that is still registered has completed."""
    registered = tuple(self._tasks)
    await asyncio.gather(*registered)


class Downloader(DownloaderBase):
Expand All @@ -78,7 +105,7 @@ def __init__( # noqa: PLR0913
self._max_wait_time = max_wait_time
self._message_display = message_display
self._work_package_accessor = work_package_accessor
self._queue: Queue[Union[tuple[int, bytes], BaseException]] = Queue()
self._queue: Queue[tuple[int, bytes]] = PriorityQueue()
self._semaphore = Semaphore(value=max_concurrent_downloads)

async def download_file(self, *, output_path: Path, part_size: int):
Expand All @@ -93,7 +120,7 @@ async def download_file(self, *, output_path: Path, part_size: int):

# start async part download to intermediate queue
for part_range in part_ranges:
await task_handler.schedule(
task_handler.schedule(
self.download_to_queue(
url=url_response.download_url, part_range=part_range
)
Expand All @@ -107,6 +134,8 @@ async def download_file(self, *, output_path: Path, part_size: int):
exceptions.EnvelopeNotFoundError,
exceptions.ExternalApiError,
) as error:
# Cancel running tasks before raising
task_handler.cancel_tasks()
raise exceptions.GetEnvelopeError() from error

# Write the downloaded parts to a file
Expand All @@ -123,6 +152,7 @@ async def download_file(self, *, output_path: Path, part_size: int):
),
name="Write queue to file",
)
await task_handler.gather()
await write_to_file

async def fetch_download_url(self) -> URLResponse:
Expand Down Expand Up @@ -219,8 +249,8 @@ async def download_to_queue(self, *, url: str, part_range: PartRange) -> None:
await self.download_content_range(
url=url, start=part_range.start, end=part_range.stop
)
except BaseException as exception:
await self._queue.put(exception)
except Exception as exception:
raise exceptions.DownloadError(reason=str(exception)) from exception

async def download_content_range(
self,
Expand Down Expand Up @@ -273,10 +303,7 @@ async def drain_queue_to_file(
with ProgressBar(file_name=file_name, file_size=file_size) as progress:
while downloaded_size < file_size:
result = await self._queue.get()
if isinstance(result, BaseException):
raise exceptions.DownloadError(reason=str(result))
else:
start, part = result
start, part = result
file.seek(offset + start)
file.write(part)
# update tracking information
Expand Down
12 changes: 7 additions & 5 deletions tests/integration/test_file_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async def test_download_file_parts(
task_handler = TaskHandler()

for part_range in part_ranges:
await task_handler.schedule(
task_handler.schedule(
downloader.download_to_queue(url=download_url, part_range=part_range)
)

Expand Down Expand Up @@ -154,12 +154,12 @@ async def test_download_file_parts(
part_size=part_size, total_file_size=total_file_size
)

await task_handler.schedule(
task_handler.schedule(
downloader.download_to_queue(
url=download_url, part_range=PartRange(-10000, -1)
)
)
await task_handler.schedule(
task_handler.schedule(
downloader.download_to_queue(url=download_url, part_range=next(part_ranges))
)

Expand All @@ -171,6 +171,7 @@ async def test_download_file_parts(
)
)
with pytest.raises(DownloadError):
await task_handler.gather()
await dl_task

# test exception at the end
Expand All @@ -189,13 +190,13 @@ async def test_download_file_parts(
part_ranges = list(part_ranges) # type: ignore
for idx, part_range in enumerate(part_ranges):
if idx == len(part_ranges) - 1: # type: ignore
await task_handler.schedule(
task_handler.schedule(
downloader.download_to_queue(
url=download_url, part_range=PartRange(-10000, -1)
)
)
else:
await task_handler.schedule(
task_handler.schedule(
downloader.download_to_queue(
url=download_url, part_range=part_range
)
Expand All @@ -209,4 +210,5 @@ async def test_download_file_parts(
)
)
with pytest.raises(DownloadError):
await task_handler.gather()
await dl_task

0 comments on commit b53da5b

Please sign in to comment.