Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retry all methods on 429 #518

Merged
merged 7 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 37 additions & 42 deletions auth0/rest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import base64
import json
import platform
import sys
from json import dumps, loads
from random import randint
from time import sleep
from typing import TYPE_CHECKING, Any, Mapping
Expand Down Expand Up @@ -95,7 +95,7 @@ def __init__(
py_version = platform.python_version()
version = sys.modules["auth0"].__version__

auth0_client = json.dumps(
auth0_client = dumps(
{
"name": "auth0-python",
"version": version,
Expand Down Expand Up @@ -136,32 +136,41 @@ def MAX_REQUEST_RETRY_DELAY(self) -> int:
def MIN_REQUEST_RETRY_DELAY(self) -> int:
return 100

def get(
def _request(
self,
method: str,
url: str,
params: dict[str, Any] | None = None,
data: RequestData | None = None,
json: RequestData | None = None,
headers: dict[str, str] | None = None,
files: dict[str, Any] | None = None,
) -> Any:
request_headers = self.base_headers.copy()
request_headers.update(headers or {})

# Track the API request attempt number
attempt = 0

# Reset the metrics tracker
self._metrics = {"retries": 0, "backoff": []}

kwargs = {
k: v
for k, v in {
"params": params,
"json": json,
"data": data,
"headers": headers,
"files": files,
"timeout": self.options.timeout,
}.items()
if v is not None
}

while True:
# Increment attempt number
attempt += 1

# Issue the request
response = requests.get(
url,
params=params,
headers=request_headers,
timeout=self.options.timeout,
)
response = requests.request(method, url, **kwargs)

# If the response did not have a 429 header, or the attempt number is greater than the configured retries, break
if response.status_code != 429 or attempt > self._retries:
Expand All @@ -177,6 +186,16 @@ def get(
# Return the final Response
return self._process_response(response)

def get(
self,
url: str,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
request_headers = self.base_headers.copy()
request_headers.update(headers or {})
return self._request("GET", url, params=params, headers=request_headers)

def post(
self,
url: str,
Expand All @@ -185,11 +204,7 @@ def post(
) -> Any:
request_headers = self.base_headers.copy()
request_headers.update(headers or {})

response = requests.post(
url, json=data, headers=request_headers, timeout=self.options.timeout
)
return self._process_response(response)
return self._request("POST", url, json=data, headers=request_headers)

def file_post(
self,
Expand All @@ -199,27 +214,15 @@ def file_post(
) -> Any:
headers = self.base_headers.copy()
headers.pop("Content-Type", None)

response = requests.post(
url, data=data, files=files, headers=headers, timeout=self.options.timeout
)
return self._process_response(response)
return self._request("POST", url, data=data, files=files, headers=headers)

def patch(self, url: str, data: RequestData | None = None) -> Any:
headers = self.base_headers.copy()

response = requests.patch(
url, json=data, headers=headers, timeout=self.options.timeout
)
return self._process_response(response)
return self._request("PATCH", url, json=data, headers=headers)

def put(self, url: str, data: RequestData | None = None) -> Any:
headers = self.base_headers.copy()

response = requests.put(
url, json=data, headers=headers, timeout=self.options.timeout
)
return self._process_response(response)
return self._request("PUT", url, json=data, headers=headers)

def delete(
self,
Expand All @@ -228,15 +231,7 @@ def delete(
data: RequestData | None = None,
) -> Any:
headers = self.base_headers.copy()

response = requests.delete(
url,
headers=headers,
params=params or {},
json=data,
timeout=self.options.timeout,
)
return self._process_response(response)
return self._request("DELETE", url, params=params, json=data, headers=headers)

def _calculate_wait(self, attempt: int) -> int:
# Retry the request. Apply a exponential backoff for subsequent attempts, using this formula:
Expand Down Expand Up @@ -317,7 +312,7 @@ def _error_message(self):

class JsonResponse(Response):
def __init__(self, response: requests.Response | RequestsResponse) -> None:
content = json.loads(response.text)
content = loads(response.text)
super().__init__(response.status_code, content, response.headers)

def _error_code(self) -> str:
Expand Down
56 changes: 30 additions & 26 deletions auth0/rest_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,43 +52,23 @@ def set_session(self, session: aiohttp.ClientSession) -> None:
"""
self._session = session

async def _request(self, *args: Any, **kwargs: Any) -> Any:
kwargs["headers"] = kwargs.get("headers", self.base_headers)
kwargs["timeout"] = self.timeout
if self._session is not None:
# Request with re-usable session
async with self._session.request(*args, **kwargs) as response:
return await self._process_response(response)
else:
# Request without re-usable session
async with aiohttp.ClientSession() as session:
async with session.request(*args, **kwargs) as response:
return await self._process_response(response)

async def get(
self,
url: str,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
async def _request_with_session(
self, session: aiohttp.ClientSession, *args: Any, **kwargs: Any
) -> Any:
request_headers = self.base_headers.copy()
request_headers.update(headers or {})
# Track the API request attempt number
attempt = 0

# Reset the metrics tracker
self._metrics = {"retries": 0, "backoff": []}

params = _clean_params(params)
while True:
# Increment attempt number
attempt += 1

try:
response = await self._request(
"get", url, params=params, headers=request_headers
)
return response
async with session.request(*args, **kwargs) as response:
return await self._process_response(response)

except RateLimitError as e:
# If the attempt number is greater than the configured retries, raise RateLimitError
if attempt > self._retries:
Expand All @@ -101,6 +81,30 @@ async def get(
# sleep() functions in seconds, so convert the milliseconds formula above accordingly
await asyncio.sleep(wait / 1000)

async def _request(self, *args: Any, **kwargs: Any) -> Any:
kwargs["headers"] = kwargs.get("headers", self.base_headers)
kwargs["timeout"] = self.timeout
if self._session is not None:
# Request with re-usable session
return self._request_with_session(self.session, *args, **kwargs)
else:
# Request without re-usable session
async with aiohttp.ClientSession() as session:
return self._request_with_session(session, *args, **kwargs)

async def get(
self,
url: str,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
request_headers = self.base_headers.copy()
request_headers.update(headers or {})

return await self._request(
"get", url, params=_clean_params(params), headers=request_headers
)

async def post(
self,
url: str,
Expand All @@ -118,7 +122,7 @@ async def file_post(
files: dict[str, Any],
) -> Any:
headers = self.base_headers.copy()
headers.pop("Content-Type", None)
headers.pop("Content-Type")
return await self._request("post", url, data={**data, **files}, headers=headers)

async def patch(self, url: str, data: RequestData | None = None) -> Any:
Expand Down
Loading