Skip to content

Commit

Permalink
Retry handler refactoring + adapting tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mephenor committed Dec 11, 2024
1 parent ffd2c20 commit d19e3d6
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 108 deletions.
13 changes: 1 addition & 12 deletions src/ghga_connector/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from ghga_connector.config import CONFIG
from ghga_connector.core import (
AbstractMessageDisplay,
HttpxClientConfigurator,
MessageColors,
WorkPackageAccessor,
async_client,
Expand Down Expand Up @@ -112,7 +111,7 @@ def exception_hook(
)

if value.args:
message = value.args[0]
message += f"\n{value.args[0]}"

message_display.failure(message)

Expand Down Expand Up @@ -215,11 +214,6 @@ async def async_upload(
):
"""Upload a file asynchronously"""
message_display = init_message_display(debug=debug)
HttpxClientConfigurator.configure(
exponential_backoff_max=CONFIG.exponential_backoff_max,
max_retries=CONFIG.max_retries,
retry_status_codes=CONFIG.retry_status_codes,
)
async with async_client() as client:
parameters = await retrieve_upload_parameters(client)
await upload_file(
Expand Down Expand Up @@ -293,11 +287,6 @@ async def async_download(
)

message_display = init_message_display(debug=debug)
HttpxClientConfigurator.configure(
exponential_backoff_max=CONFIG.exponential_backoff_max,
max_retries=CONFIG.max_retries,
retry_status_codes=CONFIG.retry_status_codes,
)
message_display.display("\nFetching work package token...")
work_package_information = get_work_package_information(
my_private_key=my_private_key, message_display=message_display
Expand Down
2 changes: 1 addition & 1 deletion src/ghga_connector/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
It should not contain any service API-related code.
"""

from .client import HttpxClientConfigurator, async_client, httpx_client # noqa: F401
from .client import async_client, httpx_client, retry_handler # noqa: F401
from .file_operations import ( # noqa: F401
calc_part_ranges,
get_segments,
Expand Down
40 changes: 25 additions & 15 deletions src/ghga_connector/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Handling session initialization for httpx"""

from contextlib import asynccontextmanager, contextmanager
from functools import cached_property

import httpx
from tenacity import (
Expand All @@ -25,23 +26,17 @@
wait_exponential_jitter,
)

from ghga_connector.config import CONFIG
from ghga_connector.constants import TIMEOUT


class HttpxClientConfigurator:
"""Helper class to make max_retries user configurable"""

retry_handler: AsyncRetrying

@classmethod
def configure(
cls,
exponential_backoff_max: int,
max_retries: int,
retry_status_codes: list[int],
):
@cached_property
def retry_handler(self):
"""Configure client retry handler with exponential backoff"""
cls.retry_handler = AsyncRetrying(
return AsyncRetrying(
reraise=True,
retry=(
retry_if_exception_type(
Expand All @@ -52,23 +47,38 @@ def configure(
)
)
| retry_if_result(
lambda response: response.status_code in retry_status_codes
lambda response: response.status_code in CONFIG.retry_status_codes
)
),
stop=stop_after_attempt(max_retries),
wait=wait_exponential_jitter(max=exponential_backoff_max),
stop=stop_after_attempt(CONFIG.max_retries),
wait=wait_exponential_jitter(max=CONFIG.exponential_backoff_max),
)


# Module-level shared retry handler: accessing the `retry_handler`
# cached_property on a single configurator instance builds the AsyncRetrying
# object once and reuses it for every request in the process.
retry_handler = HttpxClientConfigurator().retry_handler


@contextmanager
def httpx_client():
    """Yield a configured synchronous httpx client and close it afterward.

    The connection pool is capped at ``CONFIG.max_concurrent_downloads`` for
    both total and keep-alive connections, so the pool size matches the
    configured download concurrency.
    """
    # NOTE: the diff span contained the superseded one-liner
    # `with httpx.Client(timeout=TIMEOUT) as client:` interleaved with its
    # replacement; only the replacement (with explicit limits) is kept here.
    with httpx.Client(
        timeout=TIMEOUT,
        limits=httpx.Limits(
            max_connections=CONFIG.max_concurrent_downloads,
            max_keepalive_connections=CONFIG.max_concurrent_downloads,
        ),
    ) as client:
        yield client


@asynccontextmanager
async def async_client():
    """Yield a configured async httpx client and close it afterward.

    Mirrors ``httpx_client``: the pool is capped at
    ``CONFIG.max_concurrent_downloads`` for both total and keep-alive
    connections, matching the configured download concurrency.
    """
    # NOTE: the diff span contained the superseded one-liner
    # `async with httpx.AsyncClient(timeout=TIMEOUT) as client:` interleaved
    # with its replacement; only the replacement (with explicit limits) is
    # kept here.
    async with httpx.AsyncClient(
        timeout=TIMEOUT,
        limits=httpx.Limits(
            max_connections=CONFIG.max_concurrent_downloads,
            max_keepalive_connections=CONFIG.max_concurrent_downloads,
        ),
    ) as client:
        yield client
5 changes: 2 additions & 3 deletions src/ghga_connector/core/downloading/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@

from ghga_connector.core import (
AbstractMessageDisplay,
HttpxClientConfigurator,
PartRange,
ResponseExceptionTranslator,
WorkPackageAccessor,
calc_part_ranges,
exceptions,
retry_handler,
)

from .abstract_downloader import DownloaderBase
Expand Down Expand Up @@ -80,7 +80,6 @@ def __init__( # noqa: PLR0913
self._work_package_accessor = work_package_accessor
self._queue: Queue[Union[tuple[int, bytes], BaseException]] = Queue()
self._semaphore = Semaphore(value=max_concurrent_downloads)
self._retry_handler = HttpxClientConfigurator.retry_handler

async def download_file(self, *, output_path: Path, part_size: int):
"""Download file to the specified location and manage lower level details."""
Expand Down Expand Up @@ -234,7 +233,7 @@ async def download_content_range(
headers = httpx.Headers({"Range": f"bytes={start}-{end}"})

try:
response: httpx.Response = await self._retry_handler(
response: httpx.Response = await retry_handler(
fn=self._client.get, url=url, headers=headers
)
except RetryError as retry_error:
Expand Down
5 changes: 2 additions & 3 deletions src/ghga_connector/core/work_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ghga_service_commons.utils.crypt import decrypt
from tenacity import RetryError

from . import HttpxClientConfigurator, exceptions
from . import exceptions, retry_handler


class WorkPackageAccessor:
Expand All @@ -45,14 +45,13 @@ def __init__( # noqa: PLR0913
self.package_id = package_id
self.my_private_key = my_private_key
self.my_public_key = my_public_key
self.retry_handler = HttpxClientConfigurator.retry_handler

async def _call_url(
self, *, fn: Callable, headers: httpx.Headers, url: str
) -> httpx.Response:
"""Call url with provided headers and client method passed as callable."""
try:
response: httpx.Response = await self.retry_handler(
response: httpx.Response = await retry_handler(
fn=fn,
headers=headers,
url=url,
Expand Down
30 changes: 0 additions & 30 deletions tests/conftest.py

This file was deleted.

20 changes: 9 additions & 11 deletions tests/integration/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@

unintercepted_hosts: list[str] = []

# Marks applied to every test in this module:
#   - run all tests under asyncio,
#   - configure pytest-httpx: do not require every registered response to be
#     requested, allow a registered response to match more than once, and
#     only mock requests whose host is NOT in ``unintercepted_hosts``
#     (so e.g. localstack/S3 traffic passes through to the network).
pytestmark = [
    pytest.mark.asyncio,
    pytest.mark.httpx_mock(
        assert_all_responses_were_requested=False,
        can_send_already_matched_responses=True,
        should_mock=lambda request: request.url.host not in unintercepted_hosts,
    ),
]


def wkvs_method_mock(value: str):
"""Dummy to patch WVKS method"""
Expand All @@ -78,13 +87,6 @@ async def inner(self):
return inner


@pytest.fixture
def non_mocked_hosts() -> list:
    """Hosts that shall not be mocked by httpx."""
    # Hand back the module-level list itself (not a copy) so that hosts
    # appended later — e.g. localstack/S3 endpoints — are still honored,
    # letting those requests go out to the real network.
    hosts = unintercepted_hosts
    return hosts


@pytest.mark.parametrize(
"file_size, part_size",
[
Expand All @@ -105,7 +107,6 @@ def non_mocked_hosts() -> list:
(20 * 1024 * 1024, DEFAULT_PART_SIZE),
],
)
@pytest.mark.asyncio
async def test_multipart_download(
httpx_mock: HTTPXMock, # noqa: F811
file_size: int,
Expand Down Expand Up @@ -199,7 +200,6 @@ async def test_multipart_download(
),
],
)
@pytest.mark.asyncio
async def test_download(
httpx_mock: HTTPXMock, # noqa: F811
bad_url: bool,
Expand Down Expand Up @@ -334,7 +334,6 @@ async def test_download(
(False, "encrypted_file", exceptions.FileAlreadyEncryptedError),
],
)
@pytest.mark.asyncio
async def test_upload(
httpx_mock: HTTPXMock, # noqa: F811
bad_url: bool,
Expand Down Expand Up @@ -432,7 +431,6 @@ async def test_upload(
(20 * 1024 * 1024, 16),
],
)
@pytest.mark.asyncio
async def test_multipart_upload(
httpx_mock: HTTPXMock, # noqa: F811
file_size: int,
Expand Down
51 changes: 18 additions & 33 deletions tests/integration/test_file_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
async_client,
calc_part_ranges,
)
from ghga_connector.core.downloading import URLResponse
from ghga_connector.core.downloading.downloader import Downloader, TaskHandler
from ghga_connector.core.exceptions import DownloadError
from tests.fixtures.s3 import ( # noqa: F401
Expand Down Expand Up @@ -57,24 +56,14 @@ async def test_download_content_range(
end: int,
file_size: int,
s3_fixture: S3Fixture, # noqa: F811
monkeypatch,
):
"""Test the `download_content_range` function."""
# prepare state and the expected result:
big_object = await get_big_s3_object(s3_fixture, object_size=file_size)

async def download_url(self):
"""Drop in for monkeypatching"""
download_url = await s3_fixture.storage.get_object_download_url(
object_id=big_object.object_id, bucket_id=big_object.bucket_id
)
return URLResponse(download_url=download_url, file_size=0)

monkeypatch.setattr(
"ghga_connector.core.downloading.downloader.Downloader.get_download_url",
download_url,
)
expected_bytes = big_object.content[start : end + 1]
download_url = await s3_fixture.storage.get_object_download_url(
object_id=big_object.object_id, bucket_id=big_object.bucket_id
)

message_display = CLIMessageDisplay()
# download content range with dedicated function:
Expand All @@ -89,7 +78,7 @@ async def download_url(self):
work_package_accessor=dummy_accessor,
message_display=message_display,
)
await downloader.download_content_range(start=start, end=end)
await downloader.download_content_range(url=download_url, start=start, end=end)

result = await downloader._queue.get()
assert not isinstance(result, BaseException)
Expand All @@ -108,27 +97,17 @@ async def download_url(self):
async def test_download_file_parts(
part_size: int,
s3_fixture: S3Fixture, # noqa: F811
monkeypatch,
tmp_path,
):
"""Test the `download_file_parts` function."""
# prepare state and the expected result:
big_object = await get_big_s3_object(s3_fixture)
total_file_size = len(big_object.content)
expected_bytes = big_object.content

async def download_url(self):
"""Drop in for monkeypatching"""
download_url = await s3_fixture.storage.get_object_download_url(
object_id=big_object.object_id, bucket_id=big_object.bucket_id
)
return URLResponse(download_url=download_url, file_size=0)

monkeypatch.setattr(
"ghga_connector.core.downloading.downloader.Downloader.get_download_url",
download_url,
)
part_ranges = calc_part_ranges(part_size=part_size, total_file_size=total_file_size)
download_url = await s3_fixture.storage.get_object_download_url(
object_id=big_object.object_id, bucket_id=big_object.bucket_id
)

async with async_client() as client:
# no work package accessor calls in download_file_parts, just mock for correct type
Expand All @@ -146,7 +125,7 @@ async def download_url(self):

for part_range in part_ranges:
await task_handler.schedule(
downloader.download_to_queue(part_range=part_range)
downloader.download_to_queue(url=download_url, part_range=part_range)
)

file_path = tmp_path / "test.file"
Expand Down Expand Up @@ -176,10 +155,12 @@ async def download_url(self):
)

await task_handler.schedule(
downloader.download_to_queue(part_range=PartRange(-10000, -1))
downloader.download_to_queue(
url=download_url, part_range=PartRange(-10000, -1)
)
)
await task_handler.schedule(
downloader.download_to_queue(part_range=next(part_ranges))
downloader.download_to_queue(url=download_url, part_range=next(part_ranges))
)

file_path = tmp_path / "test2.file"
Expand Down Expand Up @@ -209,11 +190,15 @@ async def download_url(self):
for idx, part_range in enumerate(part_ranges):
if idx == len(part_ranges) - 1: # type: ignore
await task_handler.schedule(
downloader.download_to_queue(part_range=PartRange(-10000, -1))
downloader.download_to_queue(
url=download_url, part_range=PartRange(-10000, -1)
)
)
else:
await task_handler.schedule(
downloader.download_to_queue(part_range=part_range)
downloader.download_to_queue(
url=download_url, part_range=part_range
)
)

file_path = tmp_path / "test3.file"
Expand Down
Loading

0 comments on commit d19e3d6

Please sign in to comment.