Skip to content

Commit

Permalink
Correctly exit progress bar on error
Browse files Browse the repository at this point in the history
  • Loading branch information
mephenor committed Dec 12, 2024
1 parent b53da5b commit 4d7b0fd
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 21 deletions.
8 changes: 7 additions & 1 deletion src/ghga_connector/core/downloading/abstract_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Any

from ghga_connector.core import PartRange
from ghga_connector.core.downloading.progress_bar import ProgressBar

from .structs import URLResponse

Expand Down Expand Up @@ -69,7 +70,12 @@ async def download_content_range(

@abstractmethod
async def drain_queue_to_file(
self, *, file_name: str, file: BufferedWriter, file_size: int, offset: int
self,
*,
file: BufferedWriter,
file_size: int,
offset: int,
progress_bar: ProgressBar,
) -> None:
"""Write downloaded file bytes from queue.
This should be started as asyncio.Task and awaited after the download_to_queue
Expand Down
37 changes: 23 additions & 14 deletions src/ghga_connector/core/downloading/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,21 @@ async def download_file(self, *, output_path: Path, part_size: int):
raise exceptions.GetEnvelopeError() from error

# Write the downloaded parts to a file
with output_path.open("wb") as file:
with (
output_path.open("wb") as file,
ProgressBar(
file_name=file.name, file_size=url_response.file_size
) as progress_bar,
):
# put envelope in file
file.write(envelope)
# start download task
write_to_file = Task(
self.drain_queue_to_file(
file_name=file.name,
file=file,
file_size=url_response.file_size,
offset=len(envelope),
progress_bar=progress_bar,
),
name="Write queue to file",
)
Expand Down Expand Up @@ -291,7 +296,12 @@ async def download_content_range(
raise exceptions.BadResponseCodeError(url=url, response_code=status_code)

async def drain_queue_to_file(
self, *, file_name: str, file: BufferedWriter, file_size: int, offset: int
self,
*,
file: BufferedWriter,
file_size: int,
offset: int,
progress_bar: ProgressBar,
) -> None:
"""Write downloaded file bytes from queue.
This should be started as asyncio.Task and awaited after the download_to_queue
Expand All @@ -300,14 +310,13 @@ async def drain_queue_to_file(
# track and display actually written bytes
downloaded_size = 0

with ProgressBar(file_name=file_name, file_size=file_size) as progress:
while downloaded_size < file_size:
result = await self._queue.get()
start, part = result
file.seek(offset + start)
file.write(part)
# update tracking information
chunk_size = len(part)
downloaded_size += chunk_size
self._queue.task_done()
progress.advance(chunk_size)
while downloaded_size < file_size:
result = await self._queue.get()
start, part = result
file.seek(offset + start)
file.write(part)
# update tracking information
chunk_size = len(part)
downloaded_size += chunk_size
self._queue.task_done()
progress_bar.advance(chunk_size)
32 changes: 26 additions & 6 deletions tests/integration/test_file_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
calc_part_ranges,
)
from ghga_connector.core.downloading.downloader import Downloader, TaskHandler
from ghga_connector.core.downloading.progress_bar import ProgressBar
from ghga_connector.core.exceptions import DownloadError
from tests.fixtures.s3 import ( # noqa: F401
S3Fixture,
Expand Down Expand Up @@ -129,12 +130,19 @@ async def test_download_file_parts(
)

file_path = tmp_path / "test.file"
with file_path.open("wb") as file:
with (
file_path.open("wb") as file,
ProgressBar(file_name=file.name, file_size=total_file_size) as progress_bar,
):
dl_task = create_task(
downloader.drain_queue_to_file(
file_name=file.name, file=file, file_size=total_file_size, offset=0
file=file,
file_size=total_file_size,
offset=0,
progress_bar=progress_bar,
)
)
await task_handler.gather()
await dl_task

num_bytes_obtained = file_path.stat().st_size
Expand Down Expand Up @@ -164,10 +172,16 @@ async def test_download_file_parts(
)

file_path = tmp_path / "test2.file"
with file_path.open("wb") as file:
with (
file_path.open("wb") as file,
ProgressBar(file_name=file.name, file_size=total_file_size) as progress_bar,
):
dl_task = create_task(
downloader.drain_queue_to_file(
file_name=file.name, file=file, file_size=total_file_size, offset=0
file=file,
file_size=total_file_size,
offset=0,
progress_bar=progress_bar,
)
)
with pytest.raises(DownloadError):
Expand Down Expand Up @@ -203,10 +217,16 @@ async def test_download_file_parts(
)

file_path = tmp_path / "test3.file"
with file_path.open("wb") as file:
with (
file_path.open("wb") as file,
ProgressBar(file_name=file.name, file_size=total_file_size) as progress_bar,
):
dl_task = create_task(
downloader.drain_queue_to_file(
file_name=file.name, file=file, file_size=total_file_size, offset=0
file=file,
file_size=total_file_size,
offset=0,
progress_bar=progress_bar,
)
)
with pytest.raises(DownloadError):
Expand Down

0 comments on commit 4d7b0fd

Please sign in to comment.