From b53da5bcb2ecbd89120182dcdbc9ba4091122da1 Mon Sep 17 00:00:00 2001 From: "Thomas J. Zajac" Date: Thu, 12 Dec 2024 13:44:04 +0000 Subject: [PATCH] Correct task cancellation --- .../core/downloading/downloader.py | 51 ++++++++++++++----- tests/integration/test_file_operations.py | 12 +++-- 2 files changed, 46 insertions(+), 17 deletions(-) diff --git a/src/ghga_connector/core/downloading/downloader.py b/src/ghga_connector/core/downloading/downloader.py index 9bb43bd..1ac1c6f 100644 --- a/src/ghga_connector/core/downloading/downloader.py +++ b/src/ghga_connector/core/downloading/downloader.py @@ -15,12 +15,13 @@ # """Contains a concrete implementation of the abstract downloader""" +import asyncio import base64 -from asyncio import Queue, Semaphore, Task, create_task +from asyncio import PriorityQueue, Queue, Semaphore, Task, create_task from collections.abc import Coroutine from io import BufferedWriter from pathlib import Path -from typing import Any, Union +from typing import Any import httpx from tenacity import RetryError @@ -51,11 +52,37 @@ class TaskHandler: def __init__(self): self._tasks: set[Task] = set() - async def schedule(self, fn: Coroutine[Any, Any, None]): + def schedule(self, fn: Coroutine[Any, Any, None]): """Create a task and register its callback.""" task = create_task(fn) self._tasks.add(task) - task.add_done_callback(self._tasks.discard) + task.add_done_callback(self.finalize) + + def cancel_tasks(self): + """Cancel all running tasks.""" + for task in self._tasks: + if not task.done(): + task.cancel() + + def finalize(self, task: Task): + """Deal with potential errors when a task is done. + + This is called as done callback, so there are three possibilities here: + 1. A task encountered an exception: Cancel all remaining tasks and reraise + 2. A task was cancelled: There's nothing to do, we are already propagating + the exception causing the cancellation + 3. 
A task finished normally: Remove its handle + """ + if not task.cancelled(): + exception = task.exception() + if exception: + self.cancel_tasks() + raise exception + self._tasks.discard(task) + + async def gather(self): + """Await all remaining tasks.""" + await asyncio.gather(*self._tasks) class Downloader(DownloaderBase): @@ -78,7 +105,7 @@ def __init__( # noqa: PLR0913 self._max_wait_time = max_wait_time self._message_display = message_display self._work_package_accessor = work_package_accessor - self._queue: Queue[Union[tuple[int, bytes], BaseException]] = Queue() + self._queue: Queue[tuple[int, bytes]] = PriorityQueue() self._semaphore = Semaphore(value=max_concurrent_downloads) async def download_file(self, *, output_path: Path, part_size: int): @@ -93,7 +120,7 @@ async def download_file(self, *, output_path: Path, part_size: int): # start async part download to intermediate queue for part_range in part_ranges: - await task_handler.schedule( + task_handler.schedule( self.download_to_queue( url=url_response.download_url, part_range=part_range ) @@ -107,6 +134,8 @@ async def download_file(self, *, output_path: Path, part_size: int): exceptions.EnvelopeNotFoundError, exceptions.ExternalApiError, ) as error: + # Cancel running tasks before raising + task_handler.cancel_tasks() raise exceptions.GetEnvelopeError() from error # Write the downloaded parts to a file @@ -123,6 +152,7 @@ async def download_file(self, *, output_path: Path, part_size: int): ), name="Write queue to file", ) + await task_handler.gather() await write_to_file async def fetch_download_url(self) -> URLResponse: @@ -219,8 +249,8 @@ async def download_to_queue(self, *, url: str, part_range: PartRange) -> None: await self.download_content_range( url=url, start=part_range.start, end=part_range.stop ) - except BaseException as exception: - await self._queue.put(exception) + except Exception as exception: + raise exceptions.DownloadError(reason=str(exception)) from exception async def 
download_content_range( self, @@ -273,10 +303,7 @@ async def drain_queue_to_file( with ProgressBar(file_name=file_name, file_size=file_size) as progress: while downloaded_size < file_size: result = await self._queue.get() - if isinstance(result, BaseException): - raise exceptions.DownloadError(reason=str(result)) - else: - start, part = result + start, part = result file.seek(offset + start) file.write(part) # update tracking information diff --git a/tests/integration/test_file_operations.py b/tests/integration/test_file_operations.py index e572f85..5063e11 100644 --- a/tests/integration/test_file_operations.py +++ b/tests/integration/test_file_operations.py @@ -124,7 +124,7 @@ async def test_download_file_parts( task_handler = TaskHandler() for part_range in part_ranges: - await task_handler.schedule( + task_handler.schedule( downloader.download_to_queue(url=download_url, part_range=part_range) ) @@ -154,12 +154,12 @@ async def test_download_file_parts( part_size=part_size, total_file_size=total_file_size ) - await task_handler.schedule( + task_handler.schedule( downloader.download_to_queue( url=download_url, part_range=PartRange(-10000, -1) ) ) - await task_handler.schedule( + task_handler.schedule( downloader.download_to_queue(url=download_url, part_range=next(part_ranges)) ) @@ -171,6 +171,7 @@ async def test_download_file_parts( ) ) with pytest.raises(DownloadError): + await task_handler.gather() await dl_task # test exception at the end @@ -189,13 +190,13 @@ async def test_download_file_parts( part_ranges = list(part_ranges) # type: ignore for idx, part_range in enumerate(part_ranges): if idx == len(part_ranges) - 1: # type: ignore - await task_handler.schedule( + task_handler.schedule( downloader.download_to_queue( url=download_url, part_range=PartRange(-10000, -1) ) ) else: - await task_handler.schedule( + task_handler.schedule( downloader.download_to_queue( url=download_url, part_range=part_range ) @@ -209,4 +210,5 @@ async def test_download_file_parts( ) 
) with pytest.raises(DownloadError): + await task_handler.gather() await dl_task