diff --git a/src/ghga_connector/core/downloading/downloader.py b/src/ghga_connector/core/downloading/downloader.py index 4c2bd40..ae246d4 100644 --- a/src/ghga_connector/core/downloading/downloader.py +++ b/src/ghga_connector/core/downloading/downloader.py @@ -157,8 +157,13 @@ async def download_file(self, *, output_path: Path, part_size: int): ), name="Write queue to file", ) - await task_handler.gather() - await write_to_file + try: + await task_handler.gather() + except: + write_to_file.cancel() + raise + else: + await write_to_file async def fetch_download_url(self) -> URLResponse: """Fetch a work order token and retrieve the download url. diff --git a/tests/integration/test_file_operations.py b/tests/integration/test_file_operations.py index b79f071..c89169c 100644 --- a/tests/integration/test_file_operations.py +++ b/tests/integration/test_file_operations.py @@ -111,6 +111,7 @@ async def test_download_file_parts( object_id=big_object.object_id, bucket_id=big_object.bucket_id ) url_response = URLResponse(download_url, total_file_size) + mock_fetch = AsyncMock(return_value=url_response) async with async_client() as client: # no work package accessor calls in download_file_parts, just mock for correct type @@ -124,7 +125,7 @@ async def test_download_file_parts( work_package_accessor=dummy_accessor, message_display=message_display, ) - downloader.fetch_download_url = AsyncMock(return_value=url_response) + downloader.fetch_download_url = mock_fetch # type: ignore task_handler = TaskHandler() for part_range in part_ranges: @@ -158,7 +159,7 @@ async def test_download_file_parts( work_package_accessor=dummy_accessor, message_display=message_display, ) - downloader.fetch_download_url = AsyncMock(return_value=url_response) + downloader.fetch_download_url = mock_fetch # type: ignore task_handler = TaskHandler() part_ranges = calc_part_ranges( part_size=part_size, total_file_size=total_file_size @@ -185,8 +186,13 @@ async def test_download_file_parts( ) ) with pytest.raises(DownloadError): - await task_handler.gather() - await dl_task + try: + await task_handler.gather() + except: + dl_task.cancel() + raise + else: + await dl_task # test exception at the end downloader = Downloader( @@ -197,7 +203,7 @@ async def test_download_file_parts( work_package_accessor=dummy_accessor, message_display=message_display, ) - downloader.fetch_download_url = AsyncMock(return_value=url_response) + downloader.fetch_download_url = mock_fetch # type: ignore task_handler = TaskHandler() part_ranges = calc_part_ranges( part_size=part_size, total_file_size=total_file_size @@ -227,5 +233,10 @@ async def test_download_file_parts( ) ) with pytest.raises(DownloadError): - await task_handler.gather() - await dl_task + try: + await task_handler.gather() + except: + dl_task.cancel() + raise + else: + await dl_task