From c76f867069f3c1e79aa586da2bc1ebff63e7af43 Mon Sep 17 00:00:00 2001 From: Adam Mcgrath Date: Fri, 25 Aug 2023 13:50:28 +0100 Subject: [PATCH 1/4] Retry all methods on 429 --- auth0/rest.py | 85 +++---- auth0/test/authentication/test_base.py | 148 ++++++------ auth0/test/management/test_branding.py | 3 +- auth0/test/management/test_rest.py | 315 +++++++++++++------------ auth0/test_async/test_async_auth0.py | 2 +- 5 files changed, 287 insertions(+), 266 deletions(-) diff --git a/auth0/rest.py b/auth0/rest.py index 41282b74..c6eb4e3f 100644 --- a/auth0/rest.py +++ b/auth0/rest.py @@ -1,9 +1,9 @@ from __future__ import annotations import base64 -import json import platform import sys +from json import dumps, loads from random import randint from time import sleep from typing import TYPE_CHECKING, Any, Mapping @@ -95,7 +95,7 @@ def __init__( py_version = platform.python_version() version = sys.modules["auth0"].__version__ - auth0_client = json.dumps( + auth0_client = dumps( { "name": "auth0-python", "version": version, @@ -136,14 +136,20 @@ def MAX_REQUEST_RETRY_DELAY(self) -> int: def MIN_REQUEST_RETRY_DELAY(self) -> int: return 100 - def get( + def _request( self, + method: str, url: str, params: dict[str, Any] | None = None, + data: RequestData | None = None, + json: RequestData | None = None, headers: dict[str, str] | None = None, + files: dict[str, Any] | None = None, ) -> Any: request_headers = self.base_headers.copy() request_headers.update(headers or {}) + if files: + request_headers.pop("Content-Type") # Track the API request attempt number attempt = 0 @@ -151,17 +157,25 @@ def get( # Reset the metrics tracker self._metrics = {"retries": 0, "backoff": []} + kwargs = { + k: v + for k, v in { + "params": params, + "json": json, + "data": data, + "headers": request_headers, + "files": files, + "timeout": self.options.timeout, + }.items() + if v is not None + } + while True: # Increment attempt number attempt += 1 # Issue the request - response = requests.get( - url, - params=params, - headers=request_headers, - timeout=self.options.timeout, - ) + response = requests.request(method, url, **kwargs) # If the response did not have a 429 header, or the attempt number is greater than the configured retries, break if response.status_code != 429 or attempt > self._retries: @@ -177,19 +191,21 @@ def get( # Return the final Response return self._process_response(response) + def get( + self, + url: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> Any: + return self._request("GET", url, params=params, headers=headers) + def post( self, url: str, data: RequestData | None = None, headers: dict[str, str] | None = None, ) -> Any: - request_headers = self.base_headers.copy() - request_headers.update(headers or {}) - - response = requests.post( - url, json=data, headers=request_headers, timeout=self.options.timeout - ) - return self._process_response(response) + return self._request("POST", url, json=data, headers=headers) def file_post( self, @@ -197,29 +213,18 @@ def file_post( data: RequestData | None = None, files: dict[str, Any] | None = None, ) -> Any: - headers = self.base_headers.copy() - headers.pop("Content-Type", None) - - response = requests.post( - url, data=data, files=files, headers=headers, timeout=self.options.timeout + return self._request( + "POST", + url, + data=data, + files=files, ) - return self._process_response(response) def patch(self, url: str, data: RequestData | None = None) -> Any: - headers = self.base_headers.copy() - - response = requests.patch( - url, json=data, headers=headers, timeout=self.options.timeout - ) - return self._process_response(response) + return self._request("PATCH", url, json=data) def put(self, url: str, data: RequestData | None = None) -> Any: - headers = self.base_headers.copy() - - response = requests.put( - url, json=data, headers=headers, timeout=self.options.timeout - ) - return self._process_response(response) + return self._request("PUT", url, json=data) def delete( self, @@ -227,16 +232,12 @@ def delete( params: dict[str, Any] | None = None, data: RequestData | None = None, ) -> Any: - headers = self.base_headers.copy() - - response = requests.delete( + return self._request( + "DELETE", url, - headers=headers, - params=params or {}, + params=params, json=data, - timeout=self.options.timeout, ) - return self._process_response(response) def _calculate_wait(self, attempt: int) -> int: # Retry the request. Apply a exponential backoff for subsequent attempts, using this formula: @@ -317,7 +318,7 @@ def _error_message(self): class JsonResponse(Response): def __init__(self, response: requests.Response | RequestsResponse) -> None: - content = json.loads(response.text) + content = loads(response.text) super().__init__(response.status_code, content, response.headers) def _error_code(self) -> str: diff --git a/auth0/test/authentication/test_base.py b/auth0/test/authentication/test_base.py index 21a42d8f..eed9d040 100644 --- a/auth0/test/authentication/test_base.py +++ b/auth0/test/authentication/test_base.py @@ -42,16 +42,17 @@ def test_telemetry_disabled(self): self.assertEqual(ab.client.base_headers, {"Content-Type": "application/json"}) - @mock.patch("requests.post") - def test_post(self, mock_post): + @mock.patch("requests.request") + def test_post(self, mock_request): ab = AuthenticationBase("auth0.com", "cid", telemetry=False, timeout=(10, 2)) - mock_post.return_value.status_code = 200 - mock_post.return_value.text = '{"x": "y"}' + mock_request.return_value.status_code = 200 + mock_request.return_value.text = '{"x": "y"}' data = ab.post("the-url", data={"a": "b"}, headers={"c": "d"}) - mock_post.assert_called_with( + mock_request.assert_called_with( + "POST", "the-url", json={"a": "b"}, headers={"c": "d", "Content-Type": "application/json"}, @@ -60,37 +61,38 @@ def test_post(self, mock_post): self.assertEqual(data, {"x": "y"}) - @mock.patch("requests.post") - def test_post_with_defaults(self, mock_post): + @mock.patch("requests.request") + def test_post_with_defaults(self, mock_request): ab = AuthenticationBase("auth0.com", "cid", telemetry=False) - mock_post.return_value.status_code = 200 - mock_post.return_value.text = '{"x": "y"}' + mock_request.return_value.status_code = 200 + mock_request.return_value.text = '{"x": "y"}' # Only required params are passed data = ab.post("the-url") - mock_post.assert_called_with( + mock_request.assert_called_with( + "POST", "the-url", - json=None, headers={"Content-Type": "application/json"}, timeout=5.0, ) self.assertEqual(data, {"x": "y"}) - @mock.patch("requests.post") - def test_post_includes_telemetry(self, mock_post): + @mock.patch("requests.request") + def test_post_includes_telemetry(self, mock_request): ab = AuthenticationBase("auth0.com", "cid") - mock_post.return_value.status_code = 200 - mock_post.return_value.text = '{"x": "y"}' + mock_request.return_value.status_code = 200 + mock_request.return_value.text = '{"x": "y"}' data = ab.post("the-url", data={"a": "b"}, headers={"c": "d"}) - self.assertEqual(mock_post.call_count, 1) - call_args, call_kwargs = mock_post.call_args - self.assertEqual(call_args[0], "the-url") + self.assertEqual(mock_request.call_count, 1) + call_args, call_kwargs = mock_request.call_args + self.assertEqual(call_args[0], "POST") + self.assertEqual(call_args[1], "the-url") self.assertEqual(call_kwargs["json"], {"a": "b"}) headers = call_kwargs["headers"] self.assertEqual(headers["c"], "d") @@ -100,13 +102,15 @@ def test_post_includes_telemetry(self, mock_post): self.assertEqual(data, {"x": "y"}) - @mock.patch("requests.post") - def test_post_error(self, mock_post): + @mock.patch("requests.request") + def test_post_error(self, mock_request): ab = AuthenticationBase("auth0.com", "cid", telemetry=False) for error_status in [400, 500, None]: - mock_post.return_value.status_code = error_status - mock_post.return_value.text = '{"error": "e0","error_description": "desc"}' + mock_request.return_value.status_code = error_status + mock_request.return_value.text = ( + '{"error": "e0","error_description": "desc"}' + ) with self.assertRaises(Auth0Error) as context: ab.post("the-url", data={"a": "b"}, headers={"c": "d"}) @@ -115,12 +119,12 @@ def test_post_error(self, mock_post): self.assertEqual(context.exception.error_code, "e0") self.assertEqual(context.exception.message, "desc") - @mock.patch("requests.post") - def test_post_error_mfa_required(self, mock_post): + @mock.patch("requests.request") + def test_post_error_mfa_required(self, mock_request): ab = AuthenticationBase("auth0.com", "cid", telemetry=False) - mock_post.return_value.status_code = 403 - mock_post.return_value.text = '{"error": "mfa_required", "error_description": "Multifactor authentication required", "mfa_token": "Fe26...Ha"}' + mock_request.return_value.status_code = 403 + mock_request.return_value.text = '{"error": "mfa_required", "error_description": "Multifactor authentication required", "mfa_token": "Fe26...Ha"}' with self.assertRaises(Auth0Error) as context: ab.post("the-url", data={"a": "b"}, headers={"c": "d"}) @@ -132,15 +136,15 @@ def test_post_error_mfa_required(self, mock_post): ) self.assertEqual(context.exception.content.get("mfa_token"), "Fe26...Ha") - @mock.patch("requests.post") - def test_post_rate_limit_error(self, mock_post): + @mock.patch("requests.request") + def test_post_rate_limit_error(self, mock_request): ab = AuthenticationBase("auth0.com", "cid", telemetry=False) - mock_post.return_value.text = ( + mock_request.return_value.text = ( '{"statusCode": 429, "error": "e0", "error_description": "desc"}' ) - mock_post.return_value.status_code = 429 - mock_post.return_value.headers = { + mock_request.return_value.status_code = 429 + mock_request.return_value.headers = { "x-ratelimit-limit": "3", "x-ratelimit-remaining": "6", "x-ratelimit-reset": "9", @@ -155,15 +159,15 @@ def test_post_rate_limit_error(self, mock_post): self.assertIsInstance(context.exception, RateLimitError) self.assertEqual(context.exception.reset_at, 9) - @mock.patch("requests.post") - def test_post_rate_limit_error_without_headers(self, mock_post): + @mock.patch("requests.request") + def test_post_rate_limit_error_without_headers(self, mock_request): ab = AuthenticationBase("auth0.com", "cid", telemetry=False) - mock_post.return_value.text = ( + mock_request.return_value.text = ( '{"statusCode": 429, "error": "e0", "error_description": "desc"}' ) - mock_post.return_value.status_code = 429 - mock_post.return_value.headers = {} + mock_request.return_value.status_code = 429 + mock_request.return_value.headers = {} with self.assertRaises(Auth0Error) as context: ab.post("the-url", data={"a": "b"}, headers={"c": "d"}) @@ -174,13 +178,15 @@ def test_post_rate_limit_error_without_headers(self, mock_post): self.assertIsInstance(context.exception, RateLimitError) self.assertEqual(context.exception.reset_at, -1) - @mock.patch("requests.post") - def test_post_error_with_code_property(self, mock_post): + @mock.patch("requests.request") + def test_post_error_with_code_property(self, mock_request): ab = AuthenticationBase("auth0.com", "cid", telemetry=False) for error_status in [400, 500, None]: - mock_post.return_value.status_code = error_status - mock_post.return_value.text = '{"code": "e0","error_description": "desc"}' + mock_request.return_value.status_code = error_status + mock_request.return_value.text = ( + '{"code": "e0","error_description": "desc"}' + ) with self.assertRaises(Auth0Error) as context: ab.post("the-url", data={"a": "b"}, headers={"c": "d"}) @@ -189,13 +195,13 @@ def test_post_error_with_code_property(self, mock_post): self.assertEqual(context.exception.error_code, "e0") self.assertEqual(context.exception.message, "desc") - @mock.patch("requests.post") - def test_post_error_with_no_error_code(self, mock_post): + @mock.patch("requests.request") + def test_post_error_with_no_error_code(self, mock_request): ab = AuthenticationBase("auth0.com", "cid", telemetry=False) for error_status in [400, 500, None]: - mock_post.return_value.status_code = error_status - mock_post.return_value.text = '{"error_description": "desc"}' + mock_request.return_value.status_code = error_status + mock_request.return_value.text = '{"error_description": "desc"}' with self.assertRaises(Auth0Error) as context: ab.post("the-url", data={"a": "b"}, headers={"c": "d"}) @@ -204,13 +210,13 @@ def test_post_error_with_no_error_code(self, mock_post): self.assertEqual(context.exception.error_code, "a0.sdk.internal.unknown") self.assertEqual(context.exception.message, "desc") - @mock.patch("requests.post") - def test_post_error_with_text_response(self, mock_post): + @mock.patch("requests.request") + def test_post_error_with_text_response(self, mock_request): ab = AuthenticationBase("auth0.com", "cid", telemetry=False) for error_status in [400, 500, None]: - mock_post.return_value.status_code = error_status - mock_post.return_value.text = "there has been a terrible error" + mock_request.return_value.status_code = error_status + mock_request.return_value.text = "there has been a terrible error" with self.assertRaises(Auth0Error) as context: ab.post("the-url", data={"a": "b"}, headers={"c": "d"}) @@ -221,13 +227,13 @@ def test_post_error_with_text_response(self, mock_post): context.exception.message, "there has been a terrible error" ) - @mock.patch("requests.post") - def test_post_error_with_no_response_text(self, mock_post): + @mock.patch("requests.request") + def test_post_error_with_no_response_text(self, mock_request): ab = AuthenticationBase("auth0.com", "cid", telemetry=False) for error_status in [400, 500, None]: - mock_post.return_value.status_code = error_status - mock_post.return_value.text = None + mock_request.return_value.status_code = error_status + mock_request.return_value.text = None with self.assertRaises(Auth0Error) as context: ab.post("the-url", data={"a": "b"}, headers={"c": "d"}) @@ -236,16 +242,17 @@ def test_post_error_with_no_response_text(self, mock_post): self.assertEqual(context.exception.error_code, "a0.sdk.internal.unknown") self.assertEqual(context.exception.message, "") - @mock.patch("requests.get") - def test_get(self, mock_get): + @mock.patch("requests.request") + def test_get(self, mock_request): ab = AuthenticationBase("auth0.com", "cid", telemetry=False, timeout=(10, 2)) - mock_get.return_value.status_code = 200 - mock_get.return_value.text = '{"x": "y"}' + mock_request.return_value.status_code = 200 + mock_request.return_value.text = '{"x": "y"}' data = ab.get("the-url", params={"a": "b"}, headers={"c": "d"}) - mock_get.assert_called_with( + mock_request.assert_called_with( + "GET", "the-url", params={"a": "b"}, headers={"c": "d", "Content-Type": "application/json"}, @@ -254,37 +261,38 @@ def test_get(self, mock_get): self.assertEqual(data, {"x": "y"}) - @mock.patch("requests.get") - def test_get_with_defaults(self, mock_get): + @mock.patch("requests.request") + def test_get_with_defaults(self, mock_request): ab = AuthenticationBase("auth0.com", "cid", telemetry=False) - mock_get.return_value.status_code = 200 - mock_get.return_value.text = '{"x": "y"}' + mock_request.return_value.status_code = 200 + mock_request.return_value.text = '{"x": "y"}' # Only required params are passed data = ab.get("the-url") - mock_get.assert_called_with( + mock_request.assert_called_with( + "GET", "the-url", - params=None, headers={"Content-Type": "application/json"}, timeout=5.0, ) self.assertEqual(data, {"x": "y"}) - @mock.patch("requests.get") - def test_get_includes_telemetry(self, mock_get): + @mock.patch("requests.request") + def test_get_includes_telemetry(self, mock_request): ab = AuthenticationBase("auth0.com", "cid") - mock_get.return_value.status_code = 200 - mock_get.return_value.text = '{"x": "y"}' + mock_request.return_value.status_code = 200 + mock_request.return_value.text = '{"x": "y"}' data = ab.get("the-url", params={"a": "b"}, headers={"c": "d"}) - self.assertEqual(mock_get.call_count, 1) - call_args, call_kwargs = mock_get.call_args - self.assertEqual(call_args[0], "the-url") + self.assertEqual(mock_request.call_count, 1) + call_args, call_kwargs = mock_request.call_args + self.assertEqual(call_args[0], "GET") + self.assertEqual(call_args[1], "the-url") self.assertEqual(call_kwargs["params"], {"a": "b"}) headers = call_kwargs["headers"] self.assertEqual(headers["c"], "d") diff --git a/auth0/test/management/test_branding.py b/auth0/test/management/test_branding.py index fd2c8584..daf2ac75 100644 --- a/auth0/test/management/test_branding.py +++ b/auth0/test/management/test_branding.py @@ -59,7 +59,7 @@ def test_delete_template_universal_login(self, mock_rc): "https://domain/api/v2/branding/templates/universal-login", ) - @mock.patch("auth0.rest.requests.put") + @mock.patch("auth0.rest.requests.request") def test_update_template_universal_login(self, mock_rc): mock_rc.return_value.status_code = 200 mock_rc.return_value.text = "{}" @@ -68,6 +68,7 @@ def test_update_template_universal_login(self, mock_rc): branding.update_template_universal_login({"a": "b", "c": "d"}) mock_rc.assert_called_with( + "PUT", "https://domain/api/v2/branding/templates/universal-login", json={"template": {"a": "b", "c": "d"}}, headers=mock.ANY, diff --git a/auth0/test/management/test_rest.py b/auth0/test/management/test_rest.py index 125ea92b..3495c8b4 100644 --- a/auth0/test/management/test_rest.py +++ b/auth0/test/management/test_rest.py @@ -135,117 +135,119 @@ def test_default_options_are_used(self): # with self.assertRaises(requests.exceptions.Timeout): # rc.delete("https://google.com") - @mock.patch("requests.get") - def test_get_custom_timeout(self, mock_get): + @mock.patch("requests.request") + def test_get_custom_timeout(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False, timeout=(10, 2)) headers = { "Authorization": "Bearer a-token", "Content-Type": "application/json", } - mock_get.return_value.text = '["a", "b"]' - mock_get.return_value.status_code = 200 + mock_request.return_value.text = '["a", "b"]' + mock_request.return_value.status_code = 200 rc.get("the-url") - mock_get.assert_called_with( - "the-url", params=None, headers=headers, timeout=(10, 2) + mock_request.assert_called_with( + "GET", "the-url", headers=headers, timeout=(10, 2) ) - @mock.patch("requests.post") - def test_post_custom_timeout(self, mock_post): + @mock.patch("requests.request") + def test_post_custom_timeout(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False, timeout=(10, 2)) headers = { "Authorization": "Bearer a-token", "Content-Type": "application/json", } - mock_post.return_value.text = '["a", "b"]' - mock_post.return_value.status_code = 200 + mock_request.return_value.text = '["a", "b"]' + mock_request.return_value.status_code = 200 rc.post("the-url") - mock_post.assert_called_with( - "the-url", json=None, headers=headers, timeout=(10, 2) + mock_request.assert_called_with( + "POST", "the-url", headers=headers, timeout=(10, 2) ) - @mock.patch("requests.put") - def test_put_custom_timeout(self, mock_put): + @mock.patch("requests.request") + def test_put_custom_timeout(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False, timeout=(10, 2)) headers = { "Authorization": "Bearer a-token", "Content-Type": "application/json", } - mock_put.return_value.text = '["a", "b"]' - mock_put.return_value.status_code = 200 + mock_request.return_value.text = '["a", "b"]' + mock_request.return_value.status_code = 200 rc.put("the-url") - mock_put.assert_called_with( - "the-url", json=None, headers=headers, timeout=(10, 2) + mock_request.assert_called_with( + "PUT", "the-url", headers=headers, timeout=(10, 2) ) - @mock.patch("requests.patch") - def test_patch_custom_timeout(self, mock_patch): + @mock.patch("requests.request") + def test_patch_custom_timeout(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False, timeout=(10, 2)) headers = { "Authorization": "Bearer a-token", "Content-Type": "application/json", } - mock_patch.return_value.text = '["a", "b"]' - mock_patch.return_value.status_code = 200 + mock_request.return_value.text = '["a", "b"]' + mock_request.return_value.status_code = 200 rc.patch("the-url") - mock_patch.assert_called_with( - "the-url", json=None, headers=headers, timeout=(10, 2) + mock_request.assert_called_with( + "PATCH", "the-url", headers=headers, timeout=(10, 2) ) - @mock.patch("requests.delete") - def test_delete_custom_timeout(self, mock_delete): + @mock.patch("requests.request") + def test_delete_custom_timeout(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False, timeout=(10, 2)) headers = { "Authorization": "Bearer a-token", "Content-Type": "application/json", } - mock_delete.return_value.text = '["a", "b"]' - mock_delete.return_value.status_code = 200 + mock_request.return_value.text = '["a", "b"]' + mock_request.return_value.status_code = 200 rc.delete("the-url") - mock_delete.assert_called_with( - "the-url", params={}, json=None, headers=headers, timeout=(10, 2) + mock_request.assert_called_with( + "DELETE", "the-url", headers=headers, timeout=(10, 2) ) - @mock.patch("requests.get") - def test_get(self, mock_get): + @mock.patch("requests.request") + def test_get(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) headers = { "Authorization": "Bearer a-token", "Content-Type": "application/json", } - mock_get.return_value.text = '["a", "b"]' - mock_get.return_value.status_code = 200 + mock_request.return_value.text = '["a", "b"]' + mock_request.return_value.status_code = 200 response = rc.get("the-url") - mock_get.assert_called_with( - "the-url", params=None, headers=headers, timeout=5.0 - ) + mock_request.assert_called_with("GET", "the-url", headers=headers, timeout=5.0) self.assertEqual(response, ["a", "b"]) response = rc.get(url="the/url", params={"A": "param", "B": "param"}) - mock_get.assert_called_with( - "the/url", params={"A": "param", "B": "param"}, headers=headers, timeout=5.0 + mock_request.assert_called_with( + "GET", + "the/url", + params={"A": "param", "B": "param"}, + headers=headers, + timeout=5.0, ) self.assertEqual(response, ["a", "b"]) - mock_get.return_value.text = "" + mock_request.return_value.text = "" response = rc.get("the/url") self.assertEqual(response, "") - @mock.patch("requests.get") - def test_get_errors(self, mock_get): + @mock.patch("requests.request") + def test_get_errors(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) - mock_get.return_value.text = ( + mock_request.return_value.text = ( '{"statusCode": 999, "errorCode": "code", "message": "message"}' ) - mock_get.return_value.status_code = 999 + mock_request.return_value.status_code = 999 with self.assertRaises(Auth0Error) as context: rc.get("the/url") @@ -254,17 +256,17 @@ def test_get_errors(self, mock_get): self.assertEqual(context.exception.error_code, "code") self.assertEqual(context.exception.message, "message") - @mock.patch("requests.get") - def test_get_rate_limit_error(self, mock_get): + @mock.patch("requests.request") + def test_get_rate_limit_error(self, mock_request): options = RestClientOptions(telemetry=False, retries=0) rc = RestClient(jwt="a-token", options=options) rc._skip_sleep = True - mock_get.return_value.text = ( + mock_request.return_value.text = ( '{"statusCode": 429, "errorCode": "code", "message": "message"}' ) - mock_get.return_value.status_code = 429 - mock_get.return_value.headers = { + mock_request.return_value.status_code = 429 + mock_request.return_value.headers = { "x-ratelimit-limit": "3", "x-ratelimit-remaining": "6", "x-ratelimit-reset": "9", @@ -281,17 +283,17 @@ def test_get_rate_limit_error(self, mock_get): self.assertEqual(rc._metrics["retries"], 0) - @mock.patch("requests.get") - def test_get_rate_limit_error_without_headers(self, mock_get): + @mock.patch("requests.request") + def test_get_rate_limit_error_without_headers(self, mock_request): options = RestClientOptions(telemetry=False, retries=1) rc = RestClient(jwt="a-token", options=options) - mock_get.return_value.text = ( + mock_request.return_value.text = ( '{"statusCode": 429, "errorCode": "code", "message": "message"}' ) - mock_get.return_value.status_code = 429 + mock_request.return_value.status_code = 429 - mock_get.return_value.headers = {} + mock_request.return_value.headers = {} with self.assertRaises(Auth0Error) as context: rc.get("the/url") @@ -303,17 +305,17 @@ def test_get_rate_limit_error_without_headers(self, mock_get): self.assertEqual(rc._metrics["retries"], 1) - @mock.patch("requests.get") - def test_get_rate_limit_custom_retries(self, mock_get): + @mock.patch("requests.request") + def test_get_rate_limit_custom_retries(self, mock_request): options = RestClientOptions(telemetry=False, retries=5) rc = RestClient(jwt="a-token", options=options) rc._skip_sleep = True - mock_get.return_value.text = ( + mock_request.return_value.text = ( '{"statusCode": 429, "errorCode": "code", "message": "message"}' ) - mock_get.return_value.status_code = 429 - mock_get.return_value.headers = { + mock_request.return_value.status_code = 429 + mock_request.return_value.headers = { "x-ratelimit-limit": "3", "x-ratelimit-remaining": "6", "x-ratelimit-reset": "9", @@ -331,17 +333,17 @@ def test_get_rate_limit_custom_retries(self, mock_get): self.assertEqual(rc._metrics["retries"], 5) self.assertEqual(rc._metrics["retries"], len(rc._metrics["backoff"])) - @mock.patch("requests.get") - def test_get_rate_limit_invalid_retries_below_min(self, mock_get): + @mock.patch("requests.request") + def test_get_rate_limit_invalid_retries_below_min(self, mock_request): options = RestClientOptions(telemetry=False, retries=-1) rc = RestClient(jwt="a-token", options=options) rc._skip_sleep = True - mock_get.return_value.text = ( + mock_request.return_value.text = ( '{"statusCode": 429, "errorCode": "code", "message": "message"}' ) - mock_get.return_value.status_code = 429 - mock_get.return_value.headers = { + mock_request.return_value.status_code = 429 + mock_request.return_value.headers = { "x-ratelimit-limit": "3", "x-ratelimit-remaining": "6", "x-ratelimit-reset": "9", @@ -358,17 +360,17 @@ def test_get_rate_limit_invalid_retries_below_min(self, mock_get): self.assertEqual(rc._metrics["retries"], 0) - @mock.patch("requests.get") - def test_get_rate_limit_invalid_retries_above_max(self, mock_get): + @mock.patch("requests.request") + def test_get_rate_limit_invalid_retries_above_max(self, mock_request): options = RestClientOptions(telemetry=False, retries=11) rc = RestClient(jwt="a-token", options=options) rc._skip_sleep = True - mock_get.return_value.text = ( + mock_request.return_value.text = ( '{"statusCode": 429, "errorCode": "code", "message": "message"}' ) - mock_get.return_value.status_code = 429 - mock_get.return_value.headers = { + mock_request.return_value.status_code = 429 + mock_request.return_value.headers = { "x-ratelimit-limit": "3", "x-ratelimit-remaining": "6", "x-ratelimit-reset": "9", @@ -385,17 +387,17 @@ def test_get_rate_limit_invalid_retries_above_max(self, mock_get): self.assertEqual(rc._metrics["retries"], rc.MAX_REQUEST_RETRIES()) - @mock.patch("requests.get") - def test_get_rate_limit_retries_use_exponential_backoff(self, mock_get): + @mock.patch("requests.request") + def test_get_rate_limit_retries_use_exponential_backoff(self, mock_request): options = RestClientOptions(telemetry=False, retries=10) rc = RestClient(jwt="a-token", options=options) rc._skip_sleep = True - mock_get.return_value.text = ( + mock_request.return_value.text = ( '{"statusCode": 429, "errorCode": "code", "message": "message"}' ) - mock_get.return_value.status_code = 429 - mock_get.return_value.headers = { + mock_request.return_value.status_code = 429 + mock_request.return_value.headers = { "x-ratelimit-limit": "3", "x-ratelimit-remaining": "6", "x-ratelimit-reset": "9", @@ -473,32 +475,34 @@ def test_get_rate_limit_retries_use_exponential_backoff(self, mock_get): # Ensure total delay sum is never more than 10s. self.assertLessEqual(finalBackoff, 10000) - @mock.patch("requests.post") - def test_post(self, mock_post): + @mock.patch("requests.request") + def test_post(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) headers = { "Authorization": "Bearer a-token", "Content-Type": "application/json", } - mock_post.return_value.text = '{"a": "b"}' + mock_request.return_value.text = '{"a": "b"}' data = {"some": "data"} - mock_post.return_value.status_code = 200 + mock_request.return_value.status_code = 200 response = rc.post("the/url", data=data) - mock_post.assert_called_with("the/url", json=data, headers=headers, timeout=5.0) + mock_request.assert_called_with( + "POST", "the/url", json=data, headers=headers, timeout=5.0 + ) self.assertEqual(response, {"a": "b"}) - @mock.patch("requests.post") - def test_post_errors(self, mock_post): + @mock.patch("requests.request") + def test_post_errors(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) - mock_post.return_value.text = ( + mock_request.return_value.text = ( '{"statusCode": 999, "errorCode": "code", "message": "message"}' ) - mock_post.return_value.status_code = 999 + mock_request.return_value.status_code = 999 with self.assertRaises(Auth0Error) as context: rc.post("the-url") @@ -507,14 +511,14 @@ def test_post_errors(self, mock_post): self.assertEqual(context.exception.error_code, "code") self.assertEqual(context.exception.message, "message") - @mock.patch("requests.post") - def test_post_errors_with_no_message_property(self, mock_post): + @mock.patch("requests.request") + def test_post_errors_with_no_message_property(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) - mock_post.return_value.text = json.dumps( + mock_request.return_value.text = json.dumps( {"statusCode": 999, "errorCode": "code", "error": "error"} ) - mock_post.return_value.status_code = 999 + mock_request.return_value.status_code = 999 with self.assertRaises(Auth0Error) as context: rc.post("the-url") @@ -523,14 +527,14 @@ def test_post_errors_with_no_message_property(self, mock_post): self.assertEqual(context.exception.error_code, "code") self.assertEqual(context.exception.message, "error") - @mock.patch("requests.post") - def test_post_errors_with_no_message_or_error_property(self, mock_post): + @mock.patch("requests.request") + def test_post_errors_with_no_message_or_error_property(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) - mock_post.return_value.text = json.dumps( + mock_request.return_value.text = json.dumps( {"statusCode": 999, "errorCode": "code"} ) - mock_post.return_value.status_code = 999 + mock_request.return_value.status_code = 999 with self.assertRaises(Auth0Error) as context: rc.post("the-url") @@ -539,11 +543,11 @@ def test_post_errors_with_no_message_or_error_property(self, mock_post): self.assertEqual(context.exception.error_code, "code") self.assertEqual(context.exception.message, "") - @mock.patch("requests.post") - def test_post_errors_with_message_and_error_property(self, mock_post): + @mock.patch("requests.request") + def test_post_errors_with_message_and_error_property(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) - mock_post.return_value.text = json.dumps( + mock_request.return_value.text = json.dumps( { "statusCode": 999, "errorCode": "code", @@ -551,7 +555,7 @@ def test_post_errors_with_message_and_error_property(self, mock_post): "message": "message", } ) - mock_post.return_value.status_code = 999 + mock_request.return_value.status_code = 999 with self.assertRaises(Auth0Error) as context: rc.post("the-url") @@ -560,13 +564,13 @@ def test_post_errors_with_message_and_error_property(self, mock_post): self.assertEqual(context.exception.error_code, "code") self.assertEqual(context.exception.message, "message") - @mock.patch("requests.post") - def test_post_error_with_code_property(self, mock_post): + @mock.patch("requests.request") + def test_post_error_with_code_property(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) for error_status in [400, 500, None]: - mock_post.return_value.status_code = error_status - mock_post.return_value.text = '{"errorCode": "e0","message": "desc"}' + mock_request.return_value.status_code = error_status + mock_request.return_value.text = '{"errorCode": "e0","message": "desc"}' with self.assertRaises(Auth0Error) as context: rc.post("the-url") @@ -575,13 +579,13 @@ def test_post_error_with_code_property(self, mock_post): self.assertEqual(context.exception.error_code, "e0") self.assertEqual(context.exception.message, "desc") - @mock.patch("requests.post") - def test_post_error_with_no_error_code(self, mock_post): + @mock.patch("requests.request") + def test_post_error_with_no_error_code(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) for error_status in [400, 500, None]: - mock_post.return_value.status_code = error_status - mock_post.return_value.text = '{"message": "desc"}' + mock_request.return_value.status_code = error_status + mock_request.return_value.text = '{"message": "desc"}' with self.assertRaises(Auth0Error) as context: rc.post("the-url") @@ -590,13 +594,13 @@ def test_post_error_with_no_error_code(self, mock_post): self.assertEqual(context.exception.error_code, "a0.sdk.internal.unknown") self.assertEqual(context.exception.message, "desc") - @mock.patch("requests.post") - def test_post_error_with_text_response(self, mock_post): + @mock.patch("requests.request") + def test_post_error_with_text_response(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) for error_status in [400, 500, None]: - mock_post.return_value.status_code = error_status - mock_post.return_value.text = "there has been a terrible error" + mock_request.return_value.status_code = error_status + mock_request.return_value.text = "there has been a terrible error" with self.assertRaises(Auth0Error) as context: rc.post("the-url") @@ -607,13 +611,13 @@ def test_post_error_with_text_response(self, mock_post): context.exception.message, "there has been a terrible error" ) - @mock.patch("requests.post") - def test_post_error_with_no_response_text(self, mock_post): + @mock.patch("requests.request") + def test_post_error_with_no_response_text(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) for error_status in [400, 500, None]: - mock_post.return_value.status_code = error_status - mock_post.return_value.text = None + mock_request.return_value.status_code = error_status + mock_request.return_value.text = None with self.assertRaises(Auth0Error) as context: rc.post("the-url") @@ -622,48 +626,50 @@ def test_post_error_with_no_response_text(self, mock_post): self.assertEqual(context.exception.error_code, "a0.sdk.internal.unknown") self.assertEqual(context.exception.message, "") - @mock.patch("requests.post") - def test_file_post_content_type_is_none(self, mock_post): + @mock.patch("requests.request") + def test_file_post_content_type_is_none(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) headers = {"Authorization": "Bearer a-token"} - mock_post.return_value.status_code = 200 - mock_post.return_value.text = "Success" + mock_request.return_value.status_code = 200 + mock_request.return_value.text = "Success" data = {"some": "data"} files = [mock.Mock()] rc.file_post("the-url", data=data, files=files) - mock_post.assert_called_once_with( - "the-url", data=data, files=files, headers=headers, timeout=5.0 + mock_request.assert_called_once_with( + "POST", "the-url", data=data, files=files, headers=headers, timeout=5.0 ) - @mock.patch("requests.put") - def test_put(self, mock_put): + @mock.patch("requests.request") + def test_put(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) headers = { "Authorization": "Bearer a-token", "Content-Type": "application/json", } - mock_put.return_value.text = '["a", "b"]' - mock_put.return_value.status_code = 200 + mock_request.return_value.text = '["a", "b"]' + mock_request.return_value.status_code = 200 data = {"some": "data"} response = rc.put(url="the-url", data=data) - mock_put.assert_called_with("the-url", json=data, headers=headers, timeout=5.0) + mock_request.assert_called_with( + "PUT", "the-url", json=data, headers=headers, timeout=5.0 + ) self.assertEqual(response, ["a", "b"]) - @mock.patch("requests.put") - def test_put_errors(self, mock_put): + @mock.patch("requests.request") + def test_put_errors(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) - mock_put.return_value.text = ( + mock_request.return_value.text = ( '{"statusCode": 999, "errorCode": "code", "message": "message"}' ) - mock_put.return_value.status_code = 999 + mock_request.return_value.status_code = 999 with self.assertRaises(Auth0Error) as context: rc.put(url="the/url") @@ -672,34 +678,34 @@ def test_put_errors(self, mock_put): self.assertEqual(context.exception.error_code, "code") self.assertEqual(context.exception.message, "message") - @mock.patch("requests.patch") - def test_patch(self, mock_patch): + @mock.patch("requests.request") + def test_patch(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) headers = { "Authorization": "Bearer a-token", "Content-Type": "application/json", } - mock_patch.return_value.text = '["a", "b"]' - mock_patch.return_value.status_code = 200 + mock_request.return_value.text = '["a", "b"]' + mock_request.return_value.status_code = 200 data = {"some": "data"} response = rc.patch(url="the-url", data=data) - mock_patch.assert_called_with( - "the-url", json=data, headers=headers, timeout=5.0 + mock_request.assert_called_with( + "PATCH", "the-url", json=data, headers=headers, timeout=5.0 ) self.assertEqual(response, ["a", "b"]) - @mock.patch("requests.patch") - def test_patch_errors(self, mock_patch): + @mock.patch("requests.request") + def test_patch_errors(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) - mock_patch.return_value.text = ( + mock_request.return_value.text = ( '{"statusCode": 999, "errorCode": "code", "message": "message"}' ) - mock_patch.return_value.status_code = 999 + mock_request.return_value.status_code = 999 with self.assertRaises(Auth0Error) as context: rc.patch(url="the/url") @@ -708,53 +714,58 @@ def test_patch_errors(self, mock_patch): self.assertEqual(context.exception.error_code, "code") self.assertEqual(context.exception.message, "message") - @mock.patch("requests.delete") - def test_delete(self, mock_delete): + @mock.patch("requests.request") + def test_delete(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) headers = { "Authorization": "Bearer a-token", "Content-Type": "application/json", } - mock_delete.return_value.text = '["a", "b"]' - mock_delete.return_value.status_code = 200 + mock_request.return_value.text = '["a", "b"]' + mock_request.return_value.status_code = 200 response = rc.delete(url="the-url/ID") - mock_delete.assert_called_with( - "the-url/ID", headers=headers, params={}, json=None, timeout=5.0 + mock_request.assert_called_with( + "DELETE", "the-url/ID", headers=headers, timeout=5.0 ) self.assertEqual(response, ["a", "b"]) - @mock.patch("requests.delete") - def test_delete_with_body_and_params(self, mock_delete): + @mock.patch("requests.request") + def test_delete_with_body_and_params(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) headers = { "Authorization": "Bearer a-token", "Content-Type": "application/json", } - mock_delete.return_value.text = '["a", "b"]' - mock_delete.return_value.status_code = 200 + mock_request.return_value.text = '["a", "b"]' + mock_request.return_value.status_code = 200 data = {"some": "data"} params = {"A": "param", "B": "param"} response = rc.delete(url="the-url/ID", params=params, data=data) - mock_delete.assert_called_with( - "the-url/ID", headers=headers, params=params, json=data, timeout=5.0 + mock_request.assert_called_with( + "DELETE", + "the-url/ID", + headers=headers, + params=params, + json=data, + timeout=5.0, ) self.assertEqual(response, ["a", "b"]) - @mock.patch("requests.delete") - def test_delete_errors(self, mock_delete): + @mock.patch("requests.request") + def test_delete_errors(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) - mock_delete.return_value.text = ( + mock_request.return_value.text = ( '{"statusCode": 999, "errorCode": "code", "message": "message"}' ) - mock_delete.return_value.status_code = 999 + mock_request.return_value.status_code = 999 with self.assertRaises(Auth0Error) as context: rc.delete(url="the-url") diff --git a/auth0/test_async/test_async_auth0.py b/auth0/test_async/test_async_auth0.py index 46a6a765..c92af99a 100644 --- a/auth0/test_async/test_async_auth0.py +++ b/auth0/test_async/test_async_auth0.py @@ -28,7 +28,7 @@ class TestAuth0(unittest.TestCase): async def test_get(self, mocked): callback, mock = get_callback() - await mocked.get(clients, callback=callback) + mocked.get(clients, callback=callback) auth0 = Auth0(domain="example.com", token="jwt") From 4b9d3e15c9920fa1077f986b4069dc07af8227f5 Mon Sep 17 00:00:00 2001 From: Adam Mcgrath Date: Fri, 25 Aug 2023 14:22:36 +0100 Subject: [PATCH 2/4] Retry all methods for async --- auth0/rest.py | 38 +++++++++++++----------------- auth0/rest_async.py | 56 ++++++++++++++++++++++++--------------------- 2 files changed, 46 insertions(+), 48 deletions(-) diff --git a/auth0/rest.py b/auth0/rest.py index c6eb4e3f..0b91323d 100644 --- a/auth0/rest.py +++ b/auth0/rest.py @@ -146,11 +146,6 @@ def _request( headers: dict[str, str] | None = None, files: dict[str, Any] | None = None, ) -> Any: - request_headers = self.base_headers.copy() - request_headers.update(headers or {}) - if files: - request_headers.pop("Content-Type") - # Track the API request attempt number attempt = 0 @@ -163,7 +158,7 @@ def _request( "params": params, "json": json, "data": data, - "headers": request_headers, + "headers": headers, "files": files, "timeout": self.options.timeout, }.items() @@ -197,7 +192,9 @@ def get( params: dict[str, Any] | None = None, headers: dict[str, str] | None = None, ) -> Any: - return self._request("GET", url, params=params, headers=headers) + request_headers = self.base_headers.copy() + request_headers.update(headers or {}) + return self._request("GET", url, params=params, headers=request_headers) def post( self, @@ -205,7 +202,9 @@ def post( data: RequestData | None = None, headers: dict[str, str] | None = None, ) -> Any: - return self._request("POST", url, json=data, headers=headers) + request_headers = self.base_headers.copy() + request_headers.update(headers or {}) + return self._request("POST", url, json=data, headers=request_headers) def file_post( self, @@ -213,18 +212,17 @@ def file_post( data: RequestData | None = None, files: dict[str, Any] | None = None, ) -> Any: - return self._request( - "POST", - url, - data=data, - files=files, - ) + headers = self.base_headers.copy() + headers.pop("Content-Type", None) + return self._request("POST", url, data=data, files=files, headers=headers) def patch(self, url: str, data: RequestData | None = None) -> Any: - return self._request("PATCH", url, json=data) + headers = self.base_headers.copy() + return self._request("PATCH", url, json=data, headers=headers) def put(self, url: str, data: RequestData | None = None) -> Any: - return self._request("PUT", url, json=data) + headers = self.base_headers.copy() + return self._request("PUT", url, json=data, headers=headers) def delete( self, @@ -232,12 +230,8 @@ def delete( params: dict[str, Any] | None = None, data: RequestData | None = None, ) -> Any: - return self._request( - "DELETE", - url, - params=params, - json=data, - ) + headers = self.base_headers.copy() + return self._request("DELETE", url, params=params, json=data, headers=headers) def _calculate_wait(self, attempt: int) -> int: # Retry the request. Apply a exponential backoff for subsequent attempts, using this formula: diff --git a/auth0/rest_async.py b/auth0/rest_async.py index 5ac4e6bf..0581b812 100644 --- a/auth0/rest_async.py +++ b/auth0/rest_async.py @@ -52,43 +52,23 @@ def set_session(self, session: aiohttp.ClientSession) -> None: """ self._session = session - async def _request(self, *args: Any, **kwargs: Any) -> Any: - kwargs["headers"] = kwargs.get("headers", self.base_headers) - kwargs["timeout"] = self.timeout - if self._session is not None: - # Request with re-usable session - async with self._session.request(*args, **kwargs) as response: - return await self._process_response(response) - else: - # Request without re-usable session - async with aiohttp.ClientSession() as session: - async with session.request(*args, **kwargs) as response: - return await self._process_response(response) - - async def get( - self, - url: str, - params: dict[str, Any] | None = None, - headers: dict[str, str] | None = None, + async def _request_with_session( + self, session: aiohttp.ClientSession, *args: Any, **kwargs: Any ) -> Any: - request_headers = self.base_headers.copy() - request_headers.update(headers or {}) # Track the API request attempt number attempt = 0 # Reset the metrics tracker self._metrics = {"retries": 0, "backoff": []} - params = _clean_params(params) while True: # Increment attempt number attempt += 1 try: - response = await self._request( - "get", url, params=params, headers=request_headers - ) - return response + async with session.request(*args, **kwargs) as response: + return await self._process_response(response) + except RateLimitError as e: # If the attempt number is greater than the configured retries, raise RateLimitError if attempt > self._retries: @@ -101,6 +81,30 @@ async def get( # sleep() functions in seconds, so convert the milliseconds formula above accordingly await asyncio.sleep(wait / 1000) + async def _request(self, *args: Any, **kwargs: Any) -> Any: + kwargs["headers"] = kwargs.get("headers", self.base_headers) + kwargs["timeout"] = self.timeout + if self._session is not None: + # Request with re-usable session + return self._request_with_session(self.session, *args, **kwargs) + else: + # Request without re-usable session + async with aiohttp.ClientSession() as session: + return self._request_with_session(session, *args, **kwargs) + + async def get( + self, + url: str, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> Any: + request_headers = self.base_headers.copy() + request_headers.update(headers or {}) + + return await self._request( + "get", url, params=_clean_params(params), headers=request_headers + ) + async def post( self, url: str, @@ -118,7 +122,7 @@ async def file_post( files: dict[str, Any], ) -> Any: headers = self.base_headers.copy() - headers.pop("Content-Type", None) + headers.pop("Content-Type") return await self._request("post", url, data={**data, **files}, headers=headers) async def patch(self, url: str, data: RequestData | None = None) -> Any: From 68d71c0843397d102b9c421bbb58c12aee6ac32d Mon Sep 17 00:00:00 2001 From: Adam Mcgrath Date: Fri, 25 Aug 2023 14:47:03 +0100 Subject: [PATCH 3/4] Add tests --- auth0/test/management/test_rest.py | 20 ++++++++++++++++++-- auth0/test_async/test_asyncify.py | 14 ++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/auth0/test/management/test_rest.py b/auth0/test/management/test_rest.py index 3495c8b4..7113c446 100644 --- a/auth0/test/management/test_rest.py +++ b/auth0/test/management/test_rest.py @@ -4,8 +4,6 @@ import unittest from unittest import mock -import requests - from auth0.rest import RestClient, RestClientOptions from ...exceptions import Auth0Error, RateLimitError @@ -475,6 +473,24 @@ def test_get_rate_limit_retries_use_exponential_backoff(self, mock_request): # Ensure total delay sum is never more than 10s. self.assertLessEqual(finalBackoff, 10000) + @mock.patch("requests.request") + def test_post_rate_limit_retries(self, mock_request): + options = RestClientOptions(telemetry=False, retries=10) + rc = RestClient(jwt="a-token", options=options) + rc._skip_sleep = True + + mock_request.return_value.text = ( + '{"statusCode": 429, "errorCode": "code", "message": "message"}' + ) + mock_request.return_value.status_code = 429 + + with self.assertRaises(Auth0Error) as context: + rc.post("the/url") + + self.assertEqual(context.exception.status_code, 429) + + self.assertEqual(len(rc._metrics["backoff"]), 10) + @mock.patch("requests.request") def test_post(self, mock_request): rc = RestClient(jwt="a-token", telemetry=False) diff --git a/auth0/test_async/test_asyncify.py b/auth0/test_async/test_asyncify.py index c133e3c9..2c0317e6 100644 --- a/auth0/test_async/test_asyncify.py +++ b/auth0/test_async/test_asyncify.py @@ -233,6 +233,20 @@ async def test_rate_limit(self, mocked): (a, b, c) = rest_client._metrics["backoff"] self.assertTrue(100 <= a < b < c <= 1000) + @pytest.mark.asyncio + @aioresponses() + async def test_rate_limit_post(self, mocked): + callback, mock = get_callback(status=429) + await mocked.post(clients, callback=callback) + await mocked.post(clients, callback=callback) + await mocked.post(clients, callback=callback) + await mocked.post(clients, payload=payload) + c = asyncify(Clients)(domain="example.com", token="jwt") + rest_client = c._async_client.client + rest_client._skip_sleep = True + self.assertEqual(await c.all_async(), payload) + self.assertEqual(3, mock.call_count) + @pytest.mark.asyncio @aioresponses() async def test_timeout(self, mocked): From 0e946e2e557c45210222407b9140dff029fcc8e0 Mon Sep 17 00:00:00 2001 From: Adam Mcgrath Date: Fri, 25 Aug 2023 15:30:15 +0100 Subject: [PATCH 4/4] revert await --- auth0/test_async/test_async_auth0.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auth0/test_async/test_async_auth0.py b/auth0/test_async/test_async_auth0.py index c92af99a..46a6a765 100644 --- a/auth0/test_async/test_async_auth0.py +++ b/auth0/test_async/test_async_auth0.py @@ -28,7 +28,7 @@ class TestAuth0(unittest.TestCase): async def test_get(self, mocked): callback, mock = get_callback() - mocked.get(clients, callback=callback) + await mocked.get(clients, callback=callback) auth0 = Auth0(domain="example.com", token="jwt")