Skip to content

Commit

Permalink
Retry all methods on 429 (#518)
Browse files Browse the repository at this point in the history
### Changes

Currently the SDK only retries on 429s in a GET, add retry to the other
methods

### References

fixes #513
  • Loading branch information
adamjmcgrath authored Oct 24, 2023
2 parents 730b9f5 + 62ffd9a commit 82b3e1c
Show file tree
Hide file tree
Showing 6 changed files with 342 additions and 293 deletions.
79 changes: 37 additions & 42 deletions auth0/rest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import base64
import json
import platform
import sys
from json import dumps, loads
from random import randint
from time import sleep
from typing import TYPE_CHECKING, Any, Mapping
Expand Down Expand Up @@ -95,7 +95,7 @@ def __init__(
py_version = platform.python_version()
version = sys.modules["auth0"].__version__

auth0_client = json.dumps(
auth0_client = dumps(
{
"name": "auth0-python",
"version": version,
Expand Down Expand Up @@ -136,32 +136,41 @@ def MAX_REQUEST_RETRY_DELAY(self) -> int:
def MIN_REQUEST_RETRY_DELAY(self) -> int:
return 100

def get(
def _request(
self,
method: str,
url: str,
params: dict[str, Any] | None = None,
data: RequestData | None = None,
json: RequestData | None = None,
headers: dict[str, str] | None = None,
files: dict[str, Any] | None = None,
) -> Any:
request_headers = self.base_headers.copy()
request_headers.update(headers or {})

# Track the API request attempt number
attempt = 0

# Reset the metrics tracker
self._metrics = {"retries": 0, "backoff": []}

kwargs = {
k: v
for k, v in {
"params": params,
"json": json,
"data": data,
"headers": headers,
"files": files,
"timeout": self.options.timeout,
}.items()
if v is not None
}

while True:
# Increment attempt number
attempt += 1

# Issue the request
response = requests.get(
url,
params=params,
headers=request_headers,
timeout=self.options.timeout,
)
response = requests.request(method, url, **kwargs)

# If the response did not have a 429 header, or the attempt number is greater than the configured retries, break
if response.status_code != 429 or attempt > self._retries:
Expand All @@ -177,6 +186,16 @@ def get(
# Return the final Response
return self._process_response(response)

def get(
self,
url: str,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
request_headers = self.base_headers.copy()
request_headers.update(headers or {})
return self._request("GET", url, params=params, headers=request_headers)

def post(
self,
url: str,
Expand All @@ -185,11 +204,7 @@ def post(
) -> Any:
request_headers = self.base_headers.copy()
request_headers.update(headers or {})

response = requests.post(
url, json=data, headers=request_headers, timeout=self.options.timeout
)
return self._process_response(response)
return self._request("POST", url, json=data, headers=request_headers)

def file_post(
self,
Expand All @@ -199,27 +214,15 @@ def file_post(
) -> Any:
headers = self.base_headers.copy()
headers.pop("Content-Type", None)

response = requests.post(
url, data=data, files=files, headers=headers, timeout=self.options.timeout
)
return self._process_response(response)
return self._request("POST", url, data=data, files=files, headers=headers)

def patch(self, url: str, data: RequestData | None = None) -> Any:
headers = self.base_headers.copy()

response = requests.patch(
url, json=data, headers=headers, timeout=self.options.timeout
)
return self._process_response(response)
return self._request("PATCH", url, json=data, headers=headers)

def put(self, url: str, data: RequestData | None = None) -> Any:
headers = self.base_headers.copy()

response = requests.put(
url, json=data, headers=headers, timeout=self.options.timeout
)
return self._process_response(response)
return self._request("PUT", url, json=data, headers=headers)

def delete(
self,
Expand All @@ -228,15 +231,7 @@ def delete(
data: RequestData | None = None,
) -> Any:
headers = self.base_headers.copy()

response = requests.delete(
url,
headers=headers,
params=params or {},
json=data,
timeout=self.options.timeout,
)
return self._process_response(response)
return self._request("DELETE", url, params=params, json=data, headers=headers)

def _calculate_wait(self, attempt: int) -> int:
# Retry the request. Apply a exponential backoff for subsequent attempts, using this formula:
Expand Down Expand Up @@ -317,7 +312,7 @@ def _error_message(self):

class JsonResponse(Response):
def __init__(self, response: requests.Response | RequestsResponse) -> None:
content = json.loads(response.text)
content = loads(response.text)
super().__init__(response.status_code, content, response.headers)

def _error_code(self) -> str:
Expand Down
56 changes: 30 additions & 26 deletions auth0/rest_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,43 +52,23 @@ def set_session(self, session: aiohttp.ClientSession) -> None:
"""
self._session = session

async def _request(self, *args: Any, **kwargs: Any) -> Any:
kwargs["headers"] = kwargs.get("headers", self.base_headers)
kwargs["timeout"] = self.timeout
if self._session is not None:
# Request with re-usable session
async with self._session.request(*args, **kwargs) as response:
return await self._process_response(response)
else:
# Request without re-usable session
async with aiohttp.ClientSession() as session:
async with session.request(*args, **kwargs) as response:
return await self._process_response(response)

async def get(
self,
url: str,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
async def _request_with_session(
self, session: aiohttp.ClientSession, *args: Any, **kwargs: Any
) -> Any:
request_headers = self.base_headers.copy()
request_headers.update(headers or {})
# Track the API request attempt number
attempt = 0

# Reset the metrics tracker
self._metrics = {"retries": 0, "backoff": []}

params = _clean_params(params)
while True:
# Increment attempt number
attempt += 1

try:
response = await self._request(
"get", url, params=params, headers=request_headers
)
return response
async with session.request(*args, **kwargs) as response:
return await self._process_response(response)

except RateLimitError as e:
# If the attempt number is greater than the configured retries, raise RateLimitError
if attempt > self._retries:
Expand All @@ -101,6 +81,30 @@ async def get(
# sleep() functions in seconds, so convert the milliseconds formula above accordingly
await asyncio.sleep(wait / 1000)

async def _request(self, *args: Any, **kwargs: Any) -> Any:
kwargs["headers"] = kwargs.get("headers", self.base_headers)
kwargs["timeout"] = self.timeout
if self._session is not None:
# Request with re-usable session
return self._request_with_session(self.session, *args, **kwargs)
else:
# Request without re-usable session
async with aiohttp.ClientSession() as session:
return self._request_with_session(session, *args, **kwargs)

async def get(
self,
url: str,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
request_headers = self.base_headers.copy()
request_headers.update(headers or {})

return await self._request(
"get", url, params=_clean_params(params), headers=request_headers
)

async def post(
self,
url: str,
Expand All @@ -118,7 +122,7 @@ async def file_post(
files: dict[str, Any],
) -> Any:
headers = self.base_headers.copy()
headers.pop("Content-Type", None)
headers.pop("Content-Type")
return await self._request("post", url, data={**data, **files}, headers=headers)

async def patch(self, url: str, data: RequestData | None = None) -> Any:
Expand Down
Loading

0 comments on commit 82b3e1c

Please sign in to comment.