From 5b863673b4edc8cfd2d93ab76f5a3c2cfaafab3a Mon Sep 17 00:00:00 2001 From: Alex Goodman Date: Tue, 26 Nov 2024 13:56:47 -0500 Subject: [PATCH] limit nvd to default backoff Signed-off-by: Alex Goodman --- src/vunnel/providers/nvd/api.py | 1 - src/vunnel/utils/http.py | 14 +++++++++++++- tests/unit/utils/test_http.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/src/vunnel/providers/nvd/api.py b/src/vunnel/providers/nvd/api.py index 44a1ed9a..1d7681b9 100644 --- a/src/vunnel/providers/nvd/api.py +++ b/src/vunnel/providers/nvd/api.py @@ -164,7 +164,6 @@ def _request(self, url: str, parameters: dict[str, str], headers: dict[str, str] response = http.get( url, self.logger, - backoff_in_seconds=30, params=payload_str, headers=headers, timeout=self.timeout, diff --git a/src/vunnel/utils/http.py b/src/vunnel/utils/http.py index 9cf8b431..b2ad4f56 100644 --- a/src/vunnel/utils/http.py +++ b/src/vunnel/utils/http.py @@ -20,6 +20,7 @@ def get( # noqa: PLR0913 backoff_in_seconds: int = 3, timeout: int = DEFAULT_TIMEOUT, status_handler: Optional[Callable[[requests.Response], None]] = None, # noqa: UP007 - python 3.9 + max_interval: int = 600, **kwargs: Any, ) -> requests.Response: """ @@ -48,7 +49,7 @@ def get( # noqa: PLR0913 last_exception: Exception | None = None for attempt in range(retries + 1): if last_exception: - sleep_interval = backoff_in_seconds * 2 ** (attempt - 1) + random.uniform(0, 1) # noqa: S311 + sleep_interval = backoff_sleep_interval(backoff_in_seconds, attempt - 1, max_value=max_interval) logger.warning(f"will retry in {int(sleep_interval)} seconds...") time.sleep(sleep_interval) @@ -73,3 +74,14 @@ def get( # noqa: PLR0913 logger.error(f"last retry of GET {url} failed with {last_exception}") raise last_exception raise Exception("unreachable") + + +def backoff_sleep_interval(interval: int, attempt: int, max_value: None | int = None, jitter: bool = True) -> float: + # this is an exponential backoff + val = interval * 2**attempt + if max_value and val > max_value: + val = max_value + if jitter: + val += random.uniform(0, 1) # noqa: S311 + # explanation of S311 disable: rng is not used cryptographically + return val diff --git a/tests/unit/utils/test_http.py b/tests/unit/utils/test_http.py index 51ae27b4..5b89ed5c 100644 --- a/tests/unit/utils/test_http.py +++ b/tests/unit/utils/test_http.py @@ -133,3 +133,32 @@ def test_it_retries_when_status_handler_raises( # custom status handler raised the first time it was called, # so we expect the second mock response to be returned overall assert result == error_response + + +@pytest.mark.parametrize( + "interval, jitter, max_value, expected", + [ + ( + 30, # interval + False, # jitter + None, # max_value + [30, 60, 120, 240, 480, 960, 1920, 3840, 7680, 15360, 30720, 61440, 122880, 245760, 491520], # expected + ), + ( + 3, # interval + False, # jitter + 1000, # max_value + [3, 6, 12, 24, 48, 96, 192, 384, 768, 1000, 1000, 1000, 1000, 1000, 1000], # expected + ), + ], +) +def test_backoff_sleep_interval(interval, jitter, max_value, expected): + actual = [ + http.backoff_sleep_interval(interval, attempt, jitter=jitter, max_value=max_value) for attempt in range(len(expected)) + ] + + if not jitter: + assert actual == expected + else: + for i, (a, e) in enumerate(zip(actual, expected)): + assert a >= e and a <= e + 1, f"Jittered value out of bounds at attempt {i}: {a} (expected ~{e})"