Skip to content

Commit

Permalink
Improve download experience (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
mephenor authored Dec 16, 2024
1 parent 94b9452 commit dc4040e
Show file tree
Hide file tree
Showing 17 changed files with 1,264 additions and 1,289 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ repos:
- id: no-commit-to-branch
args: [--branch, dev, --branch, int, --branch, main]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.3
rev: v0.8.2
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
rev: v1.13.0
hooks:
- id: mypy
args: [--no-warn-unused-ignores]
1,698 changes: 820 additions & 878 deletions lock/requirements-dev.txt

Large diffs are not rendered by default.

362 changes: 186 additions & 176 deletions lock/requirements.txt

Large diffs are not rendered by default.

33 changes: 18 additions & 15 deletions src/ghga_connector/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from ghga_connector.config import CONFIG
from ghga_connector.core import (
AbstractMessageDisplay,
HttpxClientConfigurator,
MessageColors,
WorkPackageAccessor,
async_client,
Expand All @@ -42,7 +41,7 @@
from ghga_connector.core.downloading.batch_processing import FileStager
from ghga_connector.core.main import (
decrypt_file,
download_files,
download_file,
get_wps_token,
upload_file,
)
Expand Down Expand Up @@ -112,7 +111,7 @@ def exception_hook(
)

if value.args:
message = value.args[0]
message += f"\n{value.args[0]}"

message_display.failure(message)

Expand Down Expand Up @@ -215,11 +214,6 @@ async def async_upload(
):
"""Upload a file asynchronously"""
message_display = init_message_display(debug=debug)
HttpxClientConfigurator.configure(
exponential_backoff_max=CONFIG.exponential_backoff_max,
max_retries=CONFIG.max_retries,
retry_status_codes=CONFIG.retry_status_codes,
)
async with async_client() as client:
parameters = await retrieve_upload_parameters(client)
await upload_file(
Expand Down Expand Up @@ -260,18 +254,30 @@ def download(
debug: bool = typer.Option(
False, help="Set this option in order to view traceback for errors."
),
overwrite: bool = typer.Option(
False,
help="Set to true to overwrite already existing files in the output directory.",
),
):
"""Wrapper for the async download function"""
asyncio.run(
async_download(output_dir, my_public_key_path, my_private_key_path, debug)
async_download(
output_dir=output_dir,
my_public_key_path=my_public_key_path,
my_private_key_path=my_private_key_path,
debug=debug,
overwrite=overwrite,
)
)


async def async_download(
*,
output_dir: Path,
my_public_key_path: Path,
my_private_key_path: Path,
debug: bool = False,
overwrite: bool = False,
):
"""Download files asynchronously"""
if not my_public_key_path.is_file():
Expand All @@ -286,11 +292,6 @@ async def async_download(
)

message_display = init_message_display(debug=debug)
HttpxClientConfigurator.configure(
exponential_backoff_max=CONFIG.exponential_backoff_max,
max_retries=CONFIG.max_retries,
retry_status_codes=CONFIG.retry_status_codes,
)
message_display.display("\nFetching work package token...")
work_package_information = get_work_package_information(
my_private_key=my_private_key, message_display=message_display
Expand All @@ -305,6 +306,7 @@ async def async_download(
work_package_information=work_package_information,
)

message_display.display("Preparing files for download...")
stager = FileStager(
wanted_file_ids=list(parameters.file_ids_with_extension),
dcs_api_url=parameters.dcs_api_url,
Expand All @@ -318,7 +320,7 @@ async def async_download(
staged_files = await stager.get_staged_files()
for file_id in staged_files:
message_display.display(f"Downloading file with id '{file_id}'...")
await download_files(
await download_file(
api_url=parameters.dcs_api_url,
client=client,
file_id=file_id,
Expand All @@ -329,6 +331,7 @@ async def async_download(
part_size=CONFIG.part_size,
message_display=message_display,
work_package_accessor=parameters.work_package_accessor,
overwrite=overwrite,
)
staged_files.clear()

Expand Down
2 changes: 1 addition & 1 deletion src/ghga_connector/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
It should not contain any service API-related code.
"""

from .client import HttpxClientConfigurator, async_client, httpx_client # noqa: F401
from .client import async_client, httpx_client, retry_handler # noqa: F401
from .file_operations import ( # noqa: F401
calc_part_ranges,
get_segments,
Expand Down
40 changes: 25 additions & 15 deletions src/ghga_connector/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Handling session initialization for httpx"""

from contextlib import asynccontextmanager, contextmanager
from functools import cached_property

import httpx
from tenacity import (
Expand All @@ -25,23 +26,17 @@
wait_exponential_jitter,
)

from ghga_connector.config import CONFIG
from ghga_connector.constants import TIMEOUT


class HttpxClientConfigurator:
"""Helper class to make max_retries user configurable"""

retry_handler: AsyncRetrying

@classmethod
def configure(
cls,
exponential_backoff_max: int,
max_retries: int,
retry_status_codes: list[int],
):
@cached_property
def retry_handler(self):
"""Configure client retry handler with exponential backoff"""
cls.retry_handler = AsyncRetrying(
return AsyncRetrying(
reraise=True,
retry=(
retry_if_exception_type(
Expand All @@ -52,23 +47,38 @@ def configure(
)
)
| retry_if_result(
lambda response: response.status_code in retry_status_codes
lambda response: response.status_code in CONFIG.retry_status_codes
)
),
stop=stop_after_attempt(max_retries),
wait=wait_exponential_jitter(max=exponential_backoff_max),
stop=stop_after_attempt(CONFIG.max_retries),
wait=wait_exponential_jitter(max=CONFIG.exponential_backoff_max),
)


retry_handler = HttpxClientConfigurator().retry_handler


@contextmanager
def httpx_client():
    """Yield a synchronous httpx client and close it afterward.

    Connection limits are derived from the user configuration so that the
    pool size matches the allowed number of concurrent downloads.
    """
    # NOTE(review): reconstructed post-change version of this function; the
    # source span was a rendered diff containing both the old and the new
    # `with` statement interleaved.
    limits = httpx.Limits(
        max_connections=CONFIG.max_concurrent_downloads,
        max_keepalive_connections=CONFIG.max_concurrent_downloads,
    )
    with httpx.Client(timeout=TIMEOUT, limits=limits) as client:
        yield client


@asynccontextmanager
async def async_client():
    """Yield an async httpx client and close it afterward.

    Connection limits are derived from the user configuration so that the
    pool size matches the allowed number of concurrent downloads.
    """
    # NOTE(review): reconstructed post-change version of this function; the
    # source span was a rendered diff containing both the old and the new
    # `async with` statement interleaved.
    limits = httpx.Limits(
        max_connections=CONFIG.max_concurrent_downloads,
        max_keepalive_connections=CONFIG.max_concurrent_downloads,
    )
    async with httpx.AsyncClient(timeout=TIMEOUT, limits=limits) as client:
        yield client
17 changes: 10 additions & 7 deletions src/ghga_connector/core/downloading/abstract_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Any

from ghga_connector.core import PartRange
from ghga_connector.core.downloading.progress_bar import ProgressBar

from .structs import URLResponse

Expand All @@ -34,17 +35,13 @@ def download_file(self, *, output_path: Path, part_size: int):
"""Download file to the specified location and manage lower level details."""

@abstractmethod
def await_download_url(self) -> Coroutine[URLResponse, Any, Any]:
def fetch_download_url(self) -> Coroutine[URLResponse, Any, Any]:
"""Wait until download URL can be generated.
Returns a URLResponse containing two elements:
1. the download url
2. the file size in bytes
"""

@abstractmethod
def get_download_url(self) -> Coroutine[URLResponse, Any, Any]:
"""Fetch a presigned URL from which file data can be downloaded."""

@abstractmethod
def get_file_header_envelope(self) -> Coroutine[bytes, Any, Any]:
"""
Expand All @@ -54,7 +51,7 @@ def get_file_header_envelope(self) -> Coroutine[bytes, Any, Any]:
"""

@abstractmethod
async def download_to_queue(self, *, part_range: PartRange) -> None:
async def download_to_queue(self, *, url: str, part_range: PartRange) -> None:
"""
Start downloading file parts in parallel into a queue.
This should be wrapped into asyncio.task and is guarded by a semaphore to limit
Expand All @@ -65,14 +62,20 @@ async def download_to_queue(self, *, part_range: PartRange) -> None:
async def download_content_range(
self,
*,
url: str,
start: int,
end: int,
) -> None:
"""Download a specific range of a file's content using a presigned url."""

@abstractmethod
async def drain_queue_to_file(
self, *, file_name: str, file: BufferedWriter, file_size: int, offset: int
self,
*,
file: BufferedWriter,
file_size: int,
offset: int,
progress_bar: ProgressBar,
) -> None:
"""Write downloaded file bytes from queue.
This should be started as asyncio.Task and awaited after the download_to_queue
Expand Down
25 changes: 16 additions & 9 deletions src/ghga_connector/core/downloading/batch_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from time import sleep, time
from time import perf_counter, sleep

import httpx

Expand Down Expand Up @@ -85,7 +85,7 @@ def get_input(self, *, message: str) -> str:

def handle_response(self, *, response: str):
    """Handle response from get_input.

    Accepts "yes" or "y" (case-insensitive) as confirmation; any other
    answer aborts batch processing.

    Raises:
        exceptions.AbortBatchProcessError: if the user did not confirm.
    """
    # Membership test replaces the chained `== "yes" or == "y"` comparison;
    # behavior is identical but the accepted answers are stated in one place.
    if response.lower() not in ("yes", "y"):
        raise exceptions.AbortBatchProcessError()


Expand Down Expand Up @@ -158,7 +158,7 @@ def __init__( # noqa: PLR0913
self.work_package_accessor = work_package_accessor
self.max_wait_time = config.max_wait_time
self.client = client
self.time_started = now = time()
self.started_waiting = now = perf_counter()

# Successfully staged files with their download URLs and sizes
# in the beginning, consider all files as staged with a retry time of 0
Expand All @@ -180,10 +180,14 @@ async def get_staged_files(self) -> dict[str, URLResponse]:
These values contain the download URLs and file sizes.
The dict should cleared after these files have been downloaded.
"""
self.message_display.display("Updating list of staged files...")
staging_items = list(self.unstaged_retry_times.items())
for file_id, retry_time in staging_items:
if time() >= retry_time:
if perf_counter() >= retry_time:
await self._check_file(file_id=file_id)
if len(self.staged_urls.items()) > 0:
self.started_waiting = perf_counter() # reset wait timer
break
if not self.staged_urls and not self._handle_failures():
sleep(1)
self._check_timeout()
Expand Down Expand Up @@ -217,8 +221,10 @@ async def _check_file(self, file_id: str) -> None:
if isinstance(response, URLResponse):
del self.unstaged_retry_times[file_id]
self.staged_urls[file_id] = response
self.message_display.display(f"File {file_id} is ready for download.")
elif isinstance(response, RetryResponse):
self.unstaged_retry_times[file_id] = time() + response.retry_after
self.unstaged_retry_times[file_id] = perf_counter() + response.retry_after
self.message_display.display(f"File {file_id} is (still) being staged.")
else:
self.missing_files.append(file_id)

Expand All @@ -227,7 +233,7 @@ def _check_timeout(self):
In that cases, a MaxWaitTimeExceededError is raised.
"""
if time() - self.time_started >= self.max_wait_time:
if perf_counter() - self.started_waiting >= self.max_wait_time:
raise exceptions.MaxWaitTimeExceededError(max_wait_time=self.max_wait_time)

def _handle_failures(self) -> bool:
Expand All @@ -238,8 +244,8 @@ def _handle_failures(self) -> bool:
"""
if not self.missing_files or self.ignore_failed:
return False
failed = ", ".join(self.missing_files)
message = f"No download exists for the following file IDs: {failed}"
missing = ", ".join(self.missing_files)
message = f"No download exists for the following file IDs: {missing}"
self.message_display.failure(message)
if self.finished:
return False
Expand All @@ -250,5 +256,6 @@ def _handle_failures(self) -> bool:
response = self.io_handler.get_input(message=unknown_ids_present)
self.io_handler.handle_response(response=response)
self.message_display.display("Downloading remaining files")
self.time_started = time() # reset the timer
self.started_waiting = perf_counter() # reset the timer
self.missing_files = [] # reset list of missing files
return True
Loading

0 comments on commit dc4040e

Please sign in to comment.