Skip to content

Commit

Permalink
Convert the TokenRateLimit to a TokenManager and centralize some func…
Browse files Browse the repository at this point in the history
…tionality
  • Loading branch information
TrishGillett committed Jul 30, 2024
1 parent ef782ce commit 15891a4
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 54 deletions.
114 changes: 60 additions & 54 deletions tap_github/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,27 @@
from datetime import datetime
from os import environ
from random import choice, shuffle
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set

import jwt
import requests
from singer_sdk.authenticators import APIAuthenticatorBase
from singer_sdk.streams import RESTStream


class TokenRateLimit:
"""A class to store token rate limiting information."""
class TokenManager:
"""A class to store a token's attributes and state."""

DEFAULT_RATE_LIMIT = 5000
# The DEFAULT_RATE_LIMIT_BUFFER buffer serves two purposes:
# - keep some leeway and rotate tokens before erroring out on rate limit.
# - not consume all available calls when we are using an org or user token.
DEFAULT_RATE_LIMIT_BUFFER = 1000

def __init__(self, token: str, rate_limit_buffer: Optional[int] = None):
"""Init TokenRateLimit info."""
def __init__(self, token: str, rate_limit_buffer: Optional[int] = None, logger: Optional[Any] = None):
"""Init TokenManager info."""
self.token = token
self.logger = logger
self.rate_limit = self.DEFAULT_RATE_LIMIT
self.rate_limit_remaining = self.DEFAULT_RATE_LIMIT
self.rate_limit_reset: Optional[int] = None
Expand All @@ -41,7 +42,28 @@ def update_rate_limit(self, response_headers: Any) -> None:
self.rate_limit_reset = int(response_headers["X-RateLimit-Reset"])
self.rate_limit_used = int(response_headers["X-RateLimit-Used"])

def is_valid(self) -> bool:
def is_valid_token(self) -> bool:
    """Check whether GitHub accepts this token.

    Sends a GET request to the GitHub rate-limit endpoint, authenticated
    with ``self.token``.

    Returns:
        True when the response carries no HTTP error status. False when
        ``raise_for_status`` raises ``HTTPError``; in that case a warning
        describing the response is emitted if a logger was provided.
    """
    auth_headers = {
        "Authorization": f"token {self.token}",
    }
    try:
        response = requests.get(
            url="https://api.github.com/rate_limit",
            headers=auth_headers,
        )
        response.raise_for_status()
    except requests.exceptions.HTTPError:
        # Build the message unconditionally; only emit it when a logger exists.
        dismissal_details = (
            f"A token was dismissed. "
            f"{response.status_code} Client Error: "
            f"{str(response.content)} (Reason: {response.reason})"
        )
        if self.logger is not None:
            self.logger.warning(dismissal_details)
        return False
    else:
        return True

def has_calls_remaining(self) -> bool:
"""Check if token is valid.
Returns:
Expand Down Expand Up @@ -113,25 +135,33 @@ def generate_app_access_token(
class GitHubTokenAuthenticator(APIAuthenticatorBase):
"""Base class for offloading API auth."""

def prepare_tokens(self) -> Dict[str, TokenRateLimit]:
def prepare_tokens(self) -> List[TokenManager]:
# Save GitHub tokens
available_tokens: List[str] = []
rate_limit_buffer = self._config.get("rate_limit_buffer", None)

personal_tokens: Set[str] = set()
if "auth_token" in self._config:
available_tokens = available_tokens + [self._config["auth_token"]]
personal_tokens.add(self._config["auth_token"])
if "additional_auth_tokens" in self._config:
available_tokens = available_tokens + self._config["additional_auth_tokens"]
personal_tokens = personal_tokens.union(self._config["additional_auth_tokens"])
else:
# Accept multiple tokens using environment variables GITHUB_TOKEN*
env_tokens = [
env_tokens = {
value
for key, value in environ.items()
if key.startswith("GITHUB_TOKEN")
]
}
if len(env_tokens) > 0:
self.logger.info(
f"Found {len(env_tokens)} 'GITHUB_TOKEN' environment variables for authentication."
)
available_tokens = env_tokens
personal_tokens = env_tokens

token_managers: List[TokenManager] = []
for token in personal_tokens:
token_manager = TokenManager(token, rate_limit_buffer=rate_limit_buffer, logger=self.logger)
if token_manager.is_valid_token():
token_managers.append(token_manager)

# Parse App level private key and generate a token
if "GITHUB_APP_PRIVATE_KEY" in environ.keys():
Expand All @@ -152,39 +182,15 @@ def prepare_tokens(self) -> Dict[str, TokenRateLimit]:
app_token = generate_app_access_token(
github_app_id, github_private_key, github_installation_id or None
)
available_tokens = available_tokens + [app_token]

# Get rate_limit_buffer
rate_limit_buffer = self._config.get("rate_limit_buffer", None)

# Dedup tokens and test them
filtered_tokens = []
for token in list(set(available_tokens)):
try:
response = requests.get(
url="https://api.github.com/rate_limit",
headers={
"Authorization": f"token {token}",
},
)
response.raise_for_status()
filtered_tokens.append(token)
except requests.exceptions.HTTPError:
msg = (
f"A token was dismissed. "
f"{response.status_code} Client Error: "
f"{str(response.content)} (Reason: {response.reason})"
)
self.logger.warning(msg)
token_manager = TokenManager(app_token, rate_limit_buffer=rate_limit_buffer, logger=self.logger)
if token_manager.is_valid_token():
token_managers.append(token_manager)

self.logger.info(f"Tap will run with {len(filtered_tokens)} auth tokens")
self.logger.info(f"Tap will run with {len(token_managers)} auth tokens")

# Create a dict of TokenRateLimit
# TODO - separate app_token and add logic to refresh the token
# using generate_app_access_token.
return {
token: TokenRateLimit(token, rate_limit_buffer) for token in filtered_tokens
}
# Return the list of TokenManager objects
# TODO - separate app_token and add logic to refresh the token using generate_app_access_token.
return token_managers

def __init__(self, stream: RESTStream) -> None:
"""Init authenticator.
Expand All @@ -196,18 +202,18 @@ def __init__(self, stream: RESTStream) -> None:
self.logger: logging.Logger = stream.logger
self.tap_name: str = stream.tap_name
self._config: Dict[str, Any] = dict(stream.config)
self.tokens_map = self.prepare_tokens()
self.active_token: Optional[TokenRateLimit] = (
choice(list(self.tokens_map.values())) if len(self.tokens_map) else None
self.token_managers = self.prepare_tokens()
self.active_token: Optional[TokenManager] = (
choice(self.token_managers) if len(self.token_managers) else None
)

def get_next_auth_token(self) -> None:
tokens_list = list(self.tokens_map.items())
token_managers = self.token_managers
current_token = self.active_token.token if self.active_token else ""
shuffle(tokens_list)
for _, token_rate_limit in tokens_list:
if token_rate_limit.is_valid() and current_token != token_rate_limit.token:
self.active_token = token_rate_limit
shuffle(token_managers)
for token_manager in token_managers:
if token_manager.has_calls_remaining() and current_token != token_manager.token:
self.active_token = token_manager
self.logger.info(f"Switching to fresh auth token")
return

Expand All @@ -219,7 +225,7 @@ def update_rate_limit(
self, response_headers: requests.models.CaseInsensitiveDict
) -> None:
# If no token or only one token is available, return early.
if len(self.tokens_map) <= 1 or self.active_token is None:
if len(self.token_managers) <= 1 or self.active_token is None:
return

self.active_token.update_rate_limit(response_headers)
Expand All @@ -236,7 +242,7 @@ def auth_headers(self) -> Dict[str, str]:
result = super().auth_headers
if self.active_token:
# Make sure that our token is still valid or update it.
if not self.active_token.is_valid():
if not self.active_token.has_calls_remaining():
self.get_next_auth_token()
result["Authorization"] = f"token {self.active_token.token}"
else:
Expand Down
112 changes: 112 additions & 0 deletions tap_github/tests/test_authenticator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from datetime import datetime, timedelta
import pytest
import requests
from unittest.mock import patch, MagicMock
from tap_github.authenticator import TokenManager


class TestTokenManager:
    """Tests for TokenManager: defaults, header parsing, token checks and quota."""

    @staticmethod
    def _rate_limit_headers(remaining, used, reset):
        """Build a header mapping with the X-RateLimit-* keys used in these tests."""
        return {
            "X-RateLimit-Limit": "5000",
            "X-RateLimit-Remaining": remaining,
            "X-RateLimit-Reset": reset,
            "X-RateLimit-Used": used,
        }

    def test_default_rate_limits(self):
        """A new manager exposes the default quota values and the given buffer."""
        custom_buffer_manager = TokenManager("mytoken", rate_limit_buffer=700)

        assert custom_buffer_manager.rate_limit == 5000
        assert custom_buffer_manager.rate_limit_remaining == 5000
        assert custom_buffer_manager.rate_limit_reset is None
        assert custom_buffer_manager.rate_limit_used == 0
        assert custom_buffer_manager.rate_limit_buffer == 700

        # Without an explicit buffer the class default applies.
        default_buffer_manager = TokenManager("mytoken")
        assert default_buffer_manager.rate_limit_buffer == 1000

    def test_update_rate_limit(self):
        """Header values are parsed into the manager's integer attributes."""
        headers = self._rate_limit_headers(
            remaining="4999", used="1", reset="1372700873"
        )
        manager = TokenManager("mytoken")

        manager.update_rate_limit(headers)

        assert manager.rate_limit == 5000
        assert manager.rate_limit_remaining == 4999
        assert manager.rate_limit_reset == 1372700873
        assert manager.rate_limit_used == 1

    def test_is_valid_token_successful(self):
        """A response whose raise_for_status does not raise marks the token valid."""
        with patch('requests.get') as fake_get:
            fake_get.return_value.raise_for_status.return_value = None

            manager = TokenManager("validtoken")

            assert manager.is_valid_token()
            fake_get.assert_called_once_with(
                url="https://api.github.com/rate_limit",
                headers={"Authorization": "token validtoken"},
            )

    def test_is_valid_token_failure(self):
        """An HTTPError from raise_for_status rejects the token and logs a warning."""
        with patch('requests.get') as fake_get:
            # Simulate a 401 response whose raise_for_status() raises.
            failed_response = fake_get.return_value
            failed_response.raise_for_status.side_effect = (
                requests.exceptions.HTTPError()
            )
            failed_response.status_code = 401
            failed_response.content = b'Unauthorized Access'
            failed_response.reason = 'Unauthorized'

            manager = TokenManager("invalidtoken")
            manager.logger = MagicMock()

            assert not manager.is_valid_token()
            manager.logger.warning.assert_called_once()
            positional_args = manager.logger.warning.call_args[0]
            assert "401" in positional_args[0]

    def test_has_calls_remaining_succeeds_if_token_never_used(self):
        """A manager that has never seen response headers reports calls remaining."""
        manager = TokenManager("mytoken")
        assert manager.has_calls_remaining()

    def test_has_calls_remaining_succeeds_if_lots_remaining(self):
        """Plenty of remaining calls means the token can still be used."""
        manager = TokenManager("mytoken")
        manager.update_rate_limit(
            self._rate_limit_headers(
                remaining="4999", used="1", reset="1372700873"
            )
        )

        assert manager.has_calls_remaining()

    def test_has_calls_remaining_succeeds_if_reset_time_reached(self):
        """Few calls remain, but the reset timestamp (mid-2013) is already past."""
        manager = TokenManager("mytoken", rate_limit_buffer=1000)
        manager.update_rate_limit(
            self._rate_limit_headers(
                remaining="1", used="4999", reset="1372700873"
            )
        )

        assert manager.has_calls_remaining()

    def test_has_calls_remaining_fails_if_few_calls_remaining_and_reset_time_not_reached(self):
        """Few calls remain and the reset timestamp lies 100 days in the future."""
        future_reset = datetime.now() + timedelta(days=100)
        manager = TokenManager("mytoken", rate_limit_buffer=1000)
        manager.update_rate_limit(
            self._rate_limit_headers(
                remaining="1",
                used="4999",
                reset=str(int(future_reset.timestamp())),
            )
        )

        assert not manager.has_calls_remaining()




0 comments on commit 15891a4

Please sign in to comment.