From 45402e4405518ad9914e3db0ba3ff53d32146791 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez-Mondrag=C3=B3n?= Date: Thu, 4 Jul 2024 17:50:08 -0600 Subject: [PATCH 1/6] Lint with Ruff --- .pre-commit-config.yaml | 23 +-- pyproject.toml | 30 +++- tap_github/authenticator.py | 83 +++++----- tap_github/client.py | 162 +++++++++++-------- tap_github/organization_streams.py | 45 +++--- tap_github/repository_streams.py | 246 ++++++++++++++--------------- tap_github/scraping.py | 73 +++++---- tap_github/streams.py | 13 +- tap_github/tap.py | 25 +-- tap_github/tests/__init__.py | 1 - tap_github/tests/fixtures.py | 55 +++---- tap_github/tests/test_core.py | 26 ++- tap_github/tests/test_tap.py | 80 ++++++---- tap_github/user_streams.py | 88 +++++------ tap_github/utils/filter_stdout.py | 19 +-- tox.ini | 9 -- 16 files changed, 512 insertions(+), 466 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0546aac2..fa3ac809 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: check-json exclude: "\\.vscode/.*.json" @@ -14,24 +14,15 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace -- repo: https://github.com/asottile/pyupgrade - rev: v3.15.1 +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.0 hooks: - - id: pyupgrade - args: [--py37-plus] - -- repo: https://github.com/psf/black - rev: 24.2.0 - hooks: - - id: black - -- repo: https://github.com/pycqa/isort - rev: 5.13.2 - hooks: - - id: isort + - id: ruff + args: [ --fix ] + - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.8.0 + rev: v1.10.1 hooks: - id: mypy pass_filenames: true diff --git a/pyproject.toml b/pyproject.toml index fdeb2b41..67f087b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,13 +50,33 @@ isort = ">=5.10.1" module = ["backoff"] ignore_missing_imports = true -[tool.black] +[tool.ruff] line-length = 88 +target-version = "py38" -[tool.isort] -profile = "black" -multi_line_output = 3 # Vertical Hanging Indent -src_paths = "tap_github" +[tool.ruff.lint] +ignore = [] +select = [ + "A", + "B", + "C901", + "E", + "F", + "FA", + "I", + "N", + "PERF", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "SIM", + "UP", +] + +[tool.ruff.lint.mccabe] +max-complexity = 10 [build-system] requires = ["poetry-core>=1.0.8"] diff --git a/tap_github/authenticator.py b/tap_github/authenticator.py index 7c9528d6..780750eb 100644 --- a/tap_github/authenticator.py +++ b/tap_github/authenticator.py @@ -1,11 +1,14 @@ """Classes to assist in authenticating to the GitHub API.""" +from __future__ import annotations + +import http import logging import time from datetime import datetime from os import environ from random import choice, shuffle -from typing import Any, Dict, List, Optional +from typing import Any import jwt import requests @@ -22,12 +25,12 @@ class TokenRateLimit: # - not consume all available calls when we rare using an org or user token. 
DEFAULT_RATE_LIMIT_BUFFER = 1000 - def __init__(self, token: str, rate_limit_buffer: Optional[int] = None): + def __init__(self, token: str, rate_limit_buffer: int | None = None): """Init TokenRateLimit info.""" self.token = token self.rate_limit = self.DEFAULT_RATE_LIMIT self.rate_limit_remaining = self.DEFAULT_RATE_LIMIT - self.rate_limit_reset: Optional[int] = None + self.rate_limit_reset: int | None = None self.rate_limit_used = 0 self.rate_limit_buffer = ( rate_limit_buffer @@ -44,17 +47,17 @@ def update_rate_limit(self, response_headers: Any) -> None: def is_valid(self) -> bool: """Check if token is valid. - Returns: + Returns + ------- True if the token is valid and has enough api calls remaining. + """ if self.rate_limit_reset is None: return True - if ( - self.rate_limit_used > (self.rate_limit - self.rate_limit_buffer) - and self.rate_limit_reset > datetime.now().timestamp() - ): - return False - return True + return ( + self.rate_limit_used <= self.rate_limit - self.rate_limit_buffer + or self.rate_limit_reset <= datetime.now().timestamp() + ) def generate_jwt_token( @@ -81,7 +84,7 @@ def generate_jwt_token( def generate_app_access_token( github_app_id: str, github_private_key: str, - github_installation_id: Optional[str] = None, + github_installation_id: str | None = None, ) -> str: jwt_token = generate_jwt_token(github_app_id, github_private_key) @@ -89,7 +92,8 @@ def generate_app_access_token( if github_installation_id is None: list_installations_resp = requests.get( - url="https://api.github.com/app/installations", headers=headers + url="https://api.github.com/app/installations", + headers=headers, ) list_installations_resp.raise_for_status() list_installations = list_installations_resp.json() @@ -99,12 +103,10 @@ def generate_app_access_token( github_installation_id = choice(list_installations)["id"] - url = "https://api.github.com/app/installations/{}/access_tokens".format( - github_installation_id - ) + url = f"https://api.github.com/app/installations/{github_installation_id}/access_tokens" resp = requests.post(url, headers=headers) - if resp.status_code != 201: + if resp.status_code != http.HTTPStatus.CREATED: resp.raise_for_status() return resp.json()["token"] @@ -113,13 +115,13 @@ def generate_app_access_token( class GitHubTokenAuthenticator(APIAuthenticatorBase): """Base class for offloading API auth.""" - def prepare_tokens(self) -> Dict[str, TokenRateLimit]: + def prepare_tokens(self) -> dict[str, TokenRateLimit]: # Save GitHub tokens - available_tokens: List[str] = [] + available_tokens: list[str] = [] if "auth_token" in self._config: - available_tokens = available_tokens + [self._config["auth_token"]] + available_tokens += [self._config["auth_token"]] if "additional_auth_tokens" in self._config: - available_tokens = available_tokens + self._config["additional_auth_tokens"] + available_tokens += self._config["additional_auth_tokens"] else: # Accept multiple tokens using environment variables GITHUB_TOKEN* env_tokens = [ @@ -127,16 +129,16 @@ def prepare_tokens(self) -> Dict[str, TokenRateLimit]: for key, value in environ.items() if key.startswith("GITHUB_TOKEN") ] - if len(env_tokens) > 0: + if env_tokens: self.logger.info( - f"Found {len(env_tokens)} 'GITHUB_TOKEN' environment variables for authentication." 
+ f"Found {len(env_tokens)} 'GITHUB_TOKEN' environment variables for authentication.", # noqa: E501 ) available_tokens = env_tokens # Parse App level private key and generate a token - if "GITHUB_APP_PRIVATE_KEY" in environ.keys(): + if "GITHUB_APP_PRIVATE_KEY" in environ: # To simplify settings, we use a single env-key formatted as follows: - # "{app_id};;{-----BEGIN RSA PRIVATE KEY-----\n_YOUR_PRIVATE_KEY_\n-----END RSA PRIVATE KEY-----}" + # "{app_id};;{-----BEGIN RSA PRIVATE KEY-----\n_YOUR_PRIVATE_KEY_\n-----END RSA PRIVATE KEY-----}" # noqa: E501 parts = environ["GITHUB_APP_PRIVATE_KEY"].split(";;") github_app_id = parts[0] github_private_key = (parts[1:2] or [""])[0].replace("\\n", "\n") @@ -144,13 +146,15 @@ def prepare_tokens(self) -> Dict[str, TokenRateLimit]: if not (github_private_key): self.logger.warning( - "GITHUB_APP_PRIVATE_KEY could not be parsed. The expected format is " - '":app_id:;;-----BEGIN RSA PRIVATE KEY-----\n_YOUR_P_KEY_\n-----END RSA PRIVATE KEY-----"' + "GITHUB_APP_PRIVATE_KEY could not be parsed. The expected format is " # noqa: E501 + '":app_id:;;-----BEGIN RSA PRIVATE KEY-----\n_YOUR_P_KEY_\n-----END RSA PRIVATE KEY-----"', # noqa: E501 ) else: app_token = generate_app_access_token( - github_app_id, github_private_key, github_installation_id or None + github_app_id, + github_private_key, + github_installation_id or None, ) available_tokens = available_tokens + [app_token] @@ -169,11 +173,11 @@ def prepare_tokens(self) -> Dict[str, TokenRateLimit]: ) response.raise_for_status() filtered_tokens.append(token) - except requests.exceptions.HTTPError: + except requests.exceptions.HTTPError: # noqa: PERF203 msg = ( f"A token was dismissed. " f"{response.status_code} Client Error: " - f"{str(response.content)} (Reason: {response.reason})" + f"{response.content!s} (Reason: {response.reason})" ) self.logger.warning(msg) @@ -190,14 +194,16 @@ def __init__(self, stream: RESTStream) -> None: """Init authenticator. Args: + ---- stream: A stream for a RESTful endpoint. + """ super().__init__(stream=stream) self.logger: logging.Logger = stream.logger self.tap_name: str = stream.tap_name - self._config: Dict[str, Any] = dict(stream.config) + self._config: dict[str, Any] = dict(stream.config) self.tokens_map = self.prepare_tokens() - self.active_token: Optional[TokenRateLimit] = ( + self.active_token: TokenRateLimit | None = ( choice(list(self.tokens_map.values())) if len(self.tokens_map) else None ) @@ -208,15 +214,16 @@ def get_next_auth_token(self) -> None: for _, token_rate_limit in tokens_list: if token_rate_limit.is_valid() and current_token != token_rate_limit.token: self.active_token = token_rate_limit - self.logger.info(f"Switching to fresh auth token") + self.logger.info("Switching to fresh auth token") return raise RuntimeError( - "All GitHub tokens have hit their rate limit. Stopping here." + "All GitHub tokens have hit their rate limit. Stopping here.", ) def update_rate_limit( - self, response_headers: requests.models.CaseInsensitiveDict + self, + response_headers: requests.models.CaseInsensitiveDict, ) -> None: # If no token or only one token is available, return early. if len(self.tokens_map) <= 1 or self.active_token is None: @@ -225,13 +232,15 @@ def update_rate_limit( self.active_token.update_rate_limit(response_headers) @property - def auth_headers(self) -> Dict[str, str]: + def auth_headers(self) -> dict[str, str]: """Return a dictionary of auth headers to be applied. These will be merged with any `http_headers` specified in the stream. 
- Returns: + Returns + ------- HTTP headers for authentication. + """ result = super().auth_headers if self.active_token: @@ -242,6 +251,6 @@ def auth_headers(self) -> Dict[str, str]: else: self.logger.info( "No auth token detected. " - "For higher rate limits, please specify `auth_token` in config." + "For higher rate limits, please specify `auth_token` in config.", ) return result diff --git a/tap_github/client.py b/tap_github/client.py index 8b6f4591..61067cbe 100644 --- a/tap_github/client.py +++ b/tap_github/client.py @@ -1,17 +1,16 @@ """REST client handling, including GitHubStream base class.""" -import collections +from __future__ import annotations + import email.utils +import http import inspect import random -import re import time from types import FrameType -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, cast +from typing import TYPE_CHECKING, Any, Iterable, cast from urllib.parse import parse_qs, urlparse -import requests -from backoff.types import Details from dateutil.parser import parse from nested_lookup import nested_lookup from singer_sdk.exceptions import FatalAPIError, RetriableAPIError @@ -20,6 +19,10 @@ from tap_github.authenticator import GitHubTokenAuthenticator +if TYPE_CHECKING: + import requests + from backoff.types import Details + EMPTY_REPO_ERROR_STATUS = 409 @@ -27,20 +30,20 @@ class GitHubRestStream(RESTStream): """GitHub Rest stream class.""" MAX_PER_PAGE = 100 # GitHub's limit is 100. - MAX_RESULTS_LIMIT: Optional[int] = None + MAX_RESULTS_LIMIT: int | None = None DEFAULT_API_BASE_URL = "https://api.github.com" LOG_REQUEST_METRIC_URLS = True # GitHub is missing the "since" parameter on a few endpoints - # set this parameter to True if your stream needs to navigate data in descending order - # and try to exit early on its own. + # set this parameter to True if your stream needs to navigate data in descending + # order and try to exit early on its own. # This only has effect on streams whose `replication_key` is `updated_at`. 
use_fake_since_parameter = False - _authenticator: Optional[GitHubTokenAuthenticator] = None + _authenticator: GitHubTokenAuthenticator | None = None @property - def authenticator(self) -> GitHubTokenAuthenticator: + def authenticator(self) -> GitHubTokenAuthenticator: # noqa: D102 if self._authenticator is None: self._authenticator = GitHubTokenAuthenticator(stream=self) return self._authenticator @@ -50,19 +53,22 @@ def url_base(self) -> str: return self.config.get("api_url_base", self.DEFAULT_API_BASE_URL) primary_keys = ["id"] - replication_key: Optional[str] = None - tolerated_http_errors: List[int] = [] + replication_key: str | None = None + tolerated_http_errors: list[int] = [] @property - def http_headers(self) -> Dict[str, str]: + def http_headers(self) -> dict[str, str]: """Return the http headers needed.""" - headers = {"Accept": "application/vnd.github.v3+json"} - headers["User-Agent"] = cast(str, self.config.get("user_agent", "tap-github")) - return headers + return { + "Accept": "application/vnd.github.v3+json", + "User-Agent": cast(str, self.config.get("user_agent", "tap-github")), + } - def get_next_page_token( - self, response: requests.Response, previous_token: Optional[Any] - ) -> Optional[Any]: + def get_next_page_token( # noqa: PLR0911 + self, + response: requests.Response, + previous_token: Any | None, # noqa: ANN401 + ) -> Any | None: # noqa: ANN401 """Return a token for identifying next page or None if no more pages.""" if ( previous_token @@ -74,23 +80,22 @@ def get_next_page_token( return None # Leverage header links returned by the GitHub API. - if "next" not in response.links.keys(): + if "next" not in response.links: return None resp_json = response.json() - if isinstance(resp_json, list): - results = resp_json - else: - results = resp_json.get("items") + results = resp_json if isinstance(resp_json, list) else resp_json.get("items") - # Exit early if the response has no items. ? Maybe duplicative the "next" link check. + # Exit early if the response has no items. + # Maybe duplicate the "next" link check. if not results: return None - # Unfortunately endpoints such as /starred, /stargazers, /events and /pulls do not support - # the "since" parameter out of the box. So we use a workaround here to exit early. - # For such streams, we sort by descending dates (most recent first), and paginate - # "back in time" until we reach records before our "fake_since" parameter. + # Unfortunately endpoints such as /starred, /stargazers, /events and /pulls do + # not support the "since" parameter out of the box. So we use a workaround here + # to exit early. For such streams, we sort by descending dates + # (most recent first), and paginate "back in time" until we reach records before + # our "fake_since" parameter. 
if self.replication_key and self.use_fake_since_parameter: request_parameters = parse_qs(str(urlparse(response.request.url).query)) # parse_qs interprets "+" as a space, revert this to keep an aware datetime @@ -109,7 +114,8 @@ def get_next_page_token( else None ) - # commit_timestamp is a constructed key which does not exist in the raw response + # commit_timestamp is a constructed key which does not exist in the raw + # response replication_date = ( results[-1][self.replication_key] if self.replication_key != "commit_timestamp" @@ -135,8 +141,10 @@ def get_next_page_token( return (previous_token or 1) + 1 def get_url_params( - self, context: Optional[Dict], next_page_token: Optional[Any] - ) -> Dict[str, Any]: + self, + context: dict | None, + next_page_token: Any | None, # noqa: ANN401 + ) -> dict[str, Any]: """Return a dictionary of values to be used in URL parameterization.""" params: dict = {"per_page": self.MAX_PER_PAGE} if next_page_token: @@ -146,23 +154,25 @@ def get_url_params( params["sort"] = "updated" params["direction"] = "desc" if self.use_fake_since_parameter else "asc" - # Unfortunately the /starred, /stargazers (starred_at) and /events (created_at) endpoints do not support - # the "since" parameter out of the box. But we use a workaround in 'get_next_page_token'. + # Unfortunately the /starred, /stargazers (starred_at) and /events (created_at) + # endpoints do not support the "since" parameter out of the box. But we use a + # workaround in 'get_next_page_token'. elif self.replication_key in ["starred_at", "created_at"]: params["sort"] = "created" params["direction"] = "desc" - # Warning: /commits endpoint accept "since" but results are ordered by descending commit_timestamp + # Warning: /commits endpoint accept "since" but results are ordered by + # descending commit_timestamp elif self.replication_key == "commit_timestamp": params["direction"] = "desc" elif self.replication_key: self.logger.warning( - f"The replication key '{self.replication_key}' is not fully supported by this client yet." + f"The replication key '{self.replication_key}' is not fully supported by this client yet.", # noqa: E501 ) since = self.get_starting_timestamp(context) - since_key = "since" if not self.use_fake_since_parameter else "fake_since" + since_key = "fake_since" if self.use_fake_since_parameter else "since" if self.replication_key and since: params[since_key] = since # Leverage conditional requests to save API quotas @@ -179,14 +189,17 @@ def validate_response(self, response: requests.Response) -> None: method should raise an :class:`singer_sdk.exceptions.RetriableAPIError`. Args: + ---- response: A `requests.Response`_ object. Raises: + ------ FatalAPIError: If the request is not retriable. RetriableAPIError: If the request is retriable. .. 
_requests.Response: https://docs.python-requests.org/en/latest/api/#requests.Response + """ full_path = urlparse(response.url).path if response.status_code in ( @@ -199,14 +212,18 @@ def validate_response(self, response: requests.Response) -> None: self.logger.info(msg) return - if 400 <= response.status_code < 500: + if ( + http.HTTPStatus.BAD_REQUEST + <= response.status_code + < http.HTTPStatus.INTERNAL_SERVER_ERROR + ): # noqa: E501 msg = ( f"{response.status_code} Client Error: " - f"{str(response.content)} (Reason: {response.reason}) for path: {full_path}" + f"{response.content!s} (Reason: {response.reason}) for path: {full_path}" # noqa: E501 ) # Retry on rate limiting if ( - response.status_code == 403 + response.status_code == http.HTTPStatus.FORBIDDEN and "rate limit exceeded" in str(response.content).lower() ): # Update token @@ -216,7 +233,7 @@ def validate_response(self, response: requests.Response) -> None: # Retry on secondary rate limit if ( - response.status_code == 403 + response.status_code == http.HTTPStatus.FORBIDDEN and "secondary rate limit" in str(response.content).lower() ): # Wait about a minute and retry @@ -225,31 +242,34 @@ def validate_response(self, response: requests.Response) -> None: # The GitHub API randomly returns 401 Unauthorized errors, so we try again. if ( - response.status_code == 401 + response.status_code == http.HTTPStatus.UNAUTHORIZED # if the token is invalid, we are also told about it - and not "bad credentials" in str(response.content).lower() + and "bad credentials" not in str(response.content).lower() ): raise RetriableAPIError(msg, response) # all other errors are fatal # Note: The API returns a 404 "Not Found" if trying to read a repo # for which the token is not allowed access. - raise FatalAPIError(msg) + raise FatalAPIError(msg) # noqa: PLR2004 - elif 500 <= response.status_code < 600: + if http.HTTPStatus.INTERNAL_SERVER_ERROR <= response.status_code < 600: # noqa: PLR2004 msg = ( f"{response.status_code} Server Error: " - f"{str(response.content)} (Reason: {response.reason}) for path: {full_path}" + f"{response.content!s} (Reason: {response.reason}) for path: {full_path}" # noqa: E501 ) raise RetriableAPIError(msg, response) def parse_response(self, response: requests.Response) -> Iterable[dict]: """Parse the response and return an iterator of result rows.""" - # TODO - Split into handle_reponse and parse_response. + # TODO - Split into handle_response and parse_response. if response.status_code in ( - self.tolerated_http_errors + [EMPTY_REPO_ERROR_STATUS] + { + *self.tolerated_http_errors, + EMPTY_REPO_ERROR_STATUS, + } ): - return [] + return [] # noqa: B901 # Update token rate limit info and loop through tokens if needed. 
self.authenticator.update_rate_limit(response.headers) @@ -265,7 +285,7 @@ def parse_response(self, response: requests.Response) -> Iterable[dict]: yield from results - def post_process(self, row: dict, context: Optional[Dict[str, str]] = None) -> dict: + def post_process(self, row: dict, context: dict[str, str] | None = None) -> dict: """Add `repo_id` by default to all streams.""" if context is not None and "repo_id" in context: row["repo_id"] = context["repo_id"] @@ -283,7 +303,7 @@ def backoff_handler(self, details: Details) -> None: ).f_locals["e"] if ( exc.response is not None - and exc.response.status_code == 403 + and exc.response.status_code == http.HTTPStatus.FORBIDDEN and "rate limit exceeded" in str(exc.response.content) ): # we hit a rate limit, rotate token @@ -295,8 +315,8 @@ def calculate_sync_cost( self, request: requests.PreparedRequest, response: requests.Response, - context: Optional[dict], - ) -> Dict[str, int]: + context: dict | None, + ) -> dict[str, int]: """Return the cost of the last REST API call.""" return {"rest": 1, "graphql": 0, "search": 0} @@ -315,31 +335,34 @@ def parse_response(self, response: requests.Response) -> Iterable[dict]: """Parse the response and return an iterator of result rows. Args: + ---- response: A raw `requests.Response`_ object. Yields: + ------ One item for every item found in the response. .. _requests.Response: https://docs.python-requests.org/en/latest/api/#requests.Response + """ resp_json = response.json() yield from extract_jsonpath(self.query_jsonpath, input=resp_json) def get_next_page_token( - self, response: requests.Response, previous_token: Optional[Any] - ) -> Optional[Any]: - """ - Return a dict of cursors for identifying next page or None if no more pages. + self, + response: requests.Response, + previous_token: Any | None, + ) -> Any | None: + """Return a dict of cursors for identifying next page or None if no more pages. Note - pagination requires the Graphql query to have nextPageCursor_X parameters - with the assosciated hasNextPage_X, startCursor_X and endCursor_X. + with the associated hasNextPage_X, startCursor_X and endCursor_X. X should be an integer between 0 and 9, increasing with query depth. Warning - we recommend to avoid using deep (nested) pagination. """ - resp_json = response.json() # Find if results contains "hasNextPage_X" flags and if any are True. @@ -352,7 +375,7 @@ def get_next_page_token( with_keys=True, ) - has_next_page_indices: List[int] = [] + has_next_page_indices: list[int] = [] # Iterate over all the items and filter items with hasNextPage = True. for key, value in next_page_results.items(): # Check if key is even then add pair to new dictionary @@ -361,7 +384,7 @@ def get_next_page_token( has_next_page_indices.append(pagination_index) # Check if any "hasNextPage" is True. Otherwise, exit early. - if not len(has_next_page_indices) > 0: + if not has_next_page_indices: return None # Get deepest pagination item @@ -369,7 +392,7 @@ def get_next_page_token( # We leverage previous_token to remember the pagination cursors # for indices below max_pagination_index. - next_page_cursors: Dict[str, str] = dict() + next_page_cursors: dict[str, str] = {} for key, value in (previous_token or {}).items(): # Only keep pagination info for indices below max_pagination_index. 
pagination_index = int(str(key).split("_")[1]) @@ -391,10 +414,12 @@ def get_next_page_token( return next_page_cursors def get_url_params( - self, context: Optional[Dict], next_page_token: Optional[Any] - ) -> Dict[str, Any]: + self, + context: dict | None, + next_page_token: Any | None, + ) -> dict[str, Any]: """Return a dictionary of values to be used in URL parameterization.""" - params = context.copy() if context else dict() + params = context.copy() if context else {} params["per_page"] = self.MAX_PER_PAGE if next_page_token: params.update(next_page_token) @@ -409,8 +434,8 @@ def calculate_sync_cost( self, request: requests.PreparedRequest, response: requests.Response, - context: Optional[dict], - ) -> Dict[str, int]: + context: dict | None, + ) -> dict[str, int]: """Return the cost of the last graphql API call.""" costgen = extract_jsonpath("$.data.rateLimit.cost", input=response.json()) # calculate_sync_cost is called before the main response parsing. @@ -431,11 +456,14 @@ def validate_response(self, response: requests.Response) -> None: at the very minimum. Args: + ---- response: A `requests.Response`_ object. Raises: + ------ FatalAPIError: If the request is not retriable. RetriableAPIError: If the request is retriable. + """ super().validate_response(response) rj = response.json() diff --git a/tap_github/organization_streams.py b/tap_github/organization_streams.py index b4222172..512fbc1d 100644 --- a/tap_github/organization_streams.py +++ b/tap_github/organization_streams.py @@ -1,6 +1,8 @@ """User Stream types classes for tap-github.""" -from typing import Any, Dict, Iterable, List, Optional +from __future__ import annotations + +from typing import Any, Iterable from singer_sdk import typing as th # JSON Schema typing helpers @@ -16,17 +18,16 @@ class OrganizationStream(GitHubRestStream): path = "/orgs/{org}" @property - def partitions(self) -> Optional[List[Dict]]: + def partitions(self) -> list[dict] | None: return [{"org": org} for org in self.config["organizations"]] - def get_child_context(self, record: Dict, context: Optional[Dict]) -> dict: + def get_child_context(self, record: dict, context: dict | None) -> dict: return { "org": record["login"], } - def get_records(self, context: Optional[Dict]) -> Iterable[Dict[str, Any]]: - """ - Override the parent method to allow skipping API calls + def get_records(self, context: dict | None) -> Iterable[dict[str, Any]]: + """Override the parent method to allow skipping API calls if the stream is deselected and skip_parent_streams is True in config. This allows running the tap with fewer API calls and preserving quota when only syncing a child stream. 
Without this, @@ -63,9 +64,7 @@ def get_records(self, context: Optional[Dict]) -> Iterable[Dict[str, Any]]: class TeamsStream(GitHubRestStream): - """ - API Reference: https://docs.github.com/en/rest/reference/teams#list-teams - """ + """API Reference: https://docs.github.com/en/rest/reference/teams#list-teams""" name = "teams" primary_keys = ["id"] @@ -74,14 +73,16 @@ class TeamsStream(GitHubRestStream): parent_stream_type = OrganizationStream state_partitioning_keys = ["org"] - def get_child_context(self, record: Dict, context: Optional[Dict]) -> dict: + def get_child_context(self, record: dict, context: dict | None) -> dict: new_context = {"team_slug": record["slug"]} - if context: - return { + return ( + { **context, **new_context, } - return new_context + if context + else new_context + ) schema = th.PropertiesList( # Parent Keys @@ -118,9 +119,7 @@ def get_child_context(self, record: Dict, context: Optional[Dict]) -> dict: class TeamMembersStream(GitHubRestStream): - """ - API Reference: https://docs.github.com/en/rest/reference/teams#list-team-members - """ + """API Reference: https://docs.github.com/en/rest/reference/teams#list-team-members""" name = "team_members" primary_keys = ["id", "team_slug"] @@ -129,14 +128,16 @@ class TeamMembersStream(GitHubRestStream): parent_stream_type = TeamsStream state_partitioning_keys = ["team_slug", "org"] - def get_child_context(self, record: Dict, context: Optional[Dict]) -> dict: + def get_child_context(self, record: dict, context: dict | None) -> dict: new_context = {"username": record["login"]} - if context: - return { + return ( + { **context, **new_context, } - return new_context + if context + else new_context + ) schema = th.PropertiesList( # Parent keys @@ -156,9 +157,7 @@ def get_child_context(self, record: Dict, context: Optional[Dict]) -> dict: class TeamRolesStream(GitHubRestStream): - """ - API Reference: https://docs.github.com/en/rest/reference/teams#get-team-membership-for-a-user - """ + """API Reference: https://docs.github.com/en/rest/reference/teams#get-team-membership-for-a-user""" name = "team_roles" path = "/orgs/{org}/teams/{team_slug}/memberships/{username}" diff --git a/tap_github/repository_streams.py b/tap_github/repository_streams.py index de8371e8..bbe2de68 100644 --- a/tap_github/repository_streams.py +++ b/tap_github/repository_streams.py @@ -1,6 +1,9 @@ """Repository Stream types classes for tap-github.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from __future__ import annotations + +import http +from typing import Any, Iterable from urllib.parse import parse_qs, urlparse import requests @@ -29,8 +32,10 @@ class RepositoryStream(GitHubRestStream): replication_key = "updated_at" def get_url_params( - self, context: Optional[Dict], next_page_token: Optional[Any] - ) -> Dict[str, Any]: + self, + context: dict | None, + next_page_token: Any | None, + ) -> dict[str, Any]: """Return a dictionary of values to be used in URL parameterization.""" assert context is not None, f"Context cannot be empty for '{self.name}' stream." params = super().get_url_params(context, next_page_token) @@ -43,7 +48,6 @@ def get_url_params( @property def path(self) -> str: # type: ignore """Return the API endpoint path. Path options are mutually exclusive.""" - if "searches" in self.config: # Search API max: 1,000 total. 
self.MAX_RESULTS_LIMIT = 1000 @@ -56,17 +60,14 @@ def path(self) -> str: # type: ignore @property def records_jsonpath(self) -> str: # type: ignore - if "searches" in self.config: - return "$.items[*]" - else: - return "$[*]" + return "$.items[*]" if "searches" in self.config else "$[*]" - def get_repo_ids(self, repo_list: List[Tuple[str]]) -> List[Dict[str, str]]: + def get_repo_ids(self, repo_list: list[tuple[str]]) -> list[dict[str, str]]: # noqa: C901 """Enrich the list of repos with their numeric ID from github. This helps maintain a stable id for context and bookmarks. It uses the github graphql api to fetch the databaseId. - It also removes non-existant repos and corrects casing to ensure + It also removes non-existent repos and corrects casing to ensure data is correct downstream. """ @@ -88,7 +89,7 @@ def query(self) -> str: for i, repo in enumerate(self.repo_list): chunks.append( f'repo{i}: repository(name: "{repo[1]}", owner: "{repo[0]}") ' - "{ nameWithOwner databaseId }" + "{ nameWithOwner databaseId }", ) return "query {" + " ".join(chunks) + " rateLimit { cost } }" @@ -116,7 +117,7 @@ def validate_response(self, response: requests.Response) -> None: # Also remove repos which do not exist to avoid crashing further down # the line. for record in temp_stream.request_records({}): - for item in record.keys(): + for item in record: if item == "rateLimit": continue try: @@ -129,7 +130,7 @@ def validate_response(self, response: requests.Response) -> None: repo_full_name = "/".join(repo_list[int(item[4:])]) self.logger.info( f"Repository not found: {repo_full_name} \t" - "Removing it from list" + "Removing it from list", ) continue # check if repo has moved or been renamed @@ -137,17 +138,17 @@ def validate_response(self, response: requests.Response) -> None: # the repo name has changed, log some details, and move on. self.logger.info( f"Repository name changed: {repo_full_name} \t" - f"New name: {name_with_owner}" + f"New name: {name_with_owner}", ) repos_with_ids.append( - {"org": org, "repo": repo, "repo_id": record[item]["databaseId"]} + {"org": org, "repo": repo, "repo_id": record[item]["databaseId"]}, ) self.logger.info(f"Running the tap on {len(repos_with_ids)} repositories") return repos_with_ids @property - def partitions(self) -> Optional[List[Dict[str, str]]]: + def partitions(self) -> list[dict[str, str]] | None: """Return a list of partitions. This is called before syncing records, we use it to fetch some additional @@ -161,9 +162,9 @@ def partitions(self) -> Optional[List[Dict[str, str]]]: if "repositories" in self.config: split_repo_names = list( - map(lambda s: s.split("/"), self.config["repositories"]) + map(lambda s: s.split("/"), self.config["repositories"]), ) - augmented_repo_list = list() + augmented_repo_list = [] # chunk requests to the graphql endpoint to avoid timeouts and other # obscure errors that the api doesn't say much about. The actual limit # seems closer to 1000, use half that to stay safe. 
@@ -172,10 +173,10 @@ def partitions(self) -> Optional[List[Dict[str, str]]]: self.logger.info(f"Filtering repository list of {list_length} repositories") for ndx in range(0, list_length, chunk_size): augmented_repo_list += self.get_repo_ids( - split_repo_names[ndx : ndx + chunk_size] + split_repo_names[ndx : ndx + chunk_size], ) self.logger.info( - f"Running the tap on {len(augmented_repo_list)} repositories" + f"Running the tap on {len(augmented_repo_list)} repositories", ) return augmented_repo_list @@ -183,7 +184,7 @@ def partitions(self) -> Optional[List[Dict[str, str]]]: return [{"org": org} for org in self.config["organizations"]] return None - def get_child_context(self, record: Dict, context: Optional[Dict]) -> dict: + def get_child_context(self, record: dict, context: dict | None) -> dict: """Return a child context object from the record and optional provided context. By default, will return context if provided and otherwise the record dict. @@ -196,9 +197,8 @@ def get_child_context(self, record: Dict, context: Optional[Dict]) -> dict: "repo_id": record["id"], } - def get_records(self, context: Optional[Dict]) -> Iterable[Dict[str, Any]]: - """ - Override the parent method to allow skipping API calls + def get_records(self, context: dict | None) -> Iterable[dict[str, Any]]: + """Override the parent method to allow skipping API calls if the stream is deselected and skip_parent_streams is True in config. This allows running the tap with fewer API calls and preserving quota when only syncing a child stream. Without this, @@ -285,10 +285,9 @@ def get_records(self, context: Optional[Dict]) -> Iterable[Dict[str, Any]]: class ReadmeStream(GitHubRestStream): - """ - A stream dedicated to fetching the object version of a README.md. + """A stream dedicated to fetching the object version of a README.md. - Inclduding its content, base64 encoded of the readme in GitHub flavored Markdown. + Including its content, base64 encoded of the readme in GitHub flavored Markdown. For html, see ReadmeHtmlStream. """ @@ -329,8 +328,7 @@ class ReadmeStream(GitHubRestStream): class ReadmeHtmlStream(GitHubRestStream): - """ - A stream dedicated to fetching the HTML version of README.md. + """A stream dedicated to fetching the HTML version of README.md. For the object details, such as path and size, see ReadmeStream. """ @@ -456,8 +454,7 @@ class CommunityProfileStream(GitHubRestStream): class EventsStream(GitHubRestStream): - """ - Defines 'Events' stream. + """Defines 'Events' stream. Issue events are fetched from the repository level (as opposed to per issue) to optimize for API quota usage. """ @@ -472,7 +469,7 @@ class EventsStream(GitHubRestStream): # GitHub is missing the "since" parameter on this endpoint. use_fake_since_parameter = True - def get_records(self, context: Optional[Dict] = None) -> Iterable[Dict[str, Any]]: + def get_records(self, context: dict | None = None) -> Iterable[dict[str, Any]]: """Return a generator of row-type dictionary objects. Each row emitted should be a dictionary of property names to their values. """ @@ -482,12 +479,12 @@ def get_records(self, context: Optional[Dict] = None) -> Iterable[Dict[str, Any] return super().get_records(context) - def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: + def post_process(self, row: dict, context: dict | None = None) -> dict: row = super().post_process(row, context) - # TODO - We should think about the best approach to handle this. 
An alternative would be to - # do a 'dumb' tap that just keeps the same schemas as GitHub without renaming these - # objects to "target_". They are worth keeping, however, as they can be different from - # the parent stream, e.g. for fork/parent PR events. + # TODO - We should think about the best approach to handle this. An alternative + # would be to do a 'dumb' tap that just keeps the same schemas as GitHub without + # renaming these objects to "target_". They are worth keeping, however, as they + # can be different from the parent stream, e.g. for fork/parent PR events. row["target_repo"] = row.pop("repo", None) row["target_org"] = row.pop("org", None) return row @@ -709,7 +706,7 @@ class ReleasesStream(GitHubRestStream): th.Property("created_at", th.DateTimeType), th.Property("updated_at", th.DateTimeType), th.Property("uploader", user_object), - ) + ), ), ), ).to_dict() @@ -724,7 +721,7 @@ class LanguagesStream(GitHubRestStream): state_partitioning_keys = ["repo", "org"] def parse_response(self, response: requests.Response) -> Iterable[dict]: - """Parse the language response and reformat to return as an iterator of [{language_name: Python, bytes: 23}].""" + """Parse the language response and reformat to return as an iterator of [{language_name: Python, bytes: 23}].""" # noqa: E501 if response.status_code in self.tolerated_http_errors: return [] @@ -782,7 +779,7 @@ class CollaboratorsStream(GitHubRestStream): class AssigneesStream(GitHubRestStream): - """Defines 'Assignees' stream which returns possible assignees for issues/prs following GitHub's API convention.""" + """Defines 'Assignees' stream which returns possible assignees for issues/prs following GitHub's API convention.""" # noqa: E501 name = "assignees" path = "/repos/{org}/{repo}/assignees" @@ -810,7 +807,7 @@ class AssigneesStream(GitHubRestStream): class IssuesStream(GitHubRestStream): - """Defines 'Issues' stream which returns Issues and PRs following GitHub's API convention.""" + """Defines 'Issues' stream which returns Issues and PRs following GitHub's API convention.""" # noqa: E501 name = "issues" path = "/repos/{org}/{repo}/issues" @@ -821,13 +818,16 @@ class IssuesStream(GitHubRestStream): state_partitioning_keys = ["repo", "org"] def get_url_params( - self, context: Optional[Dict], next_page_token: Optional[Any] - ) -> Dict[str, Any]: + self, + context: dict | None, + next_page_token: Any | None, + ) -> dict[str, Any]: """Return a dictionary of values to be used in URL parameterization.""" assert context is not None, f"Context cannot be empty for '{self.name}' stream." params = super().get_url_params(context, next_page_token) # Fetch all issues and PRs, regardless of state (OPEN, CLOSED, MERGED). - # To exclude PRs from the issues stream, you can use the Stream Maps in the config. + # To exclude PRs from the issues stream, you can use the Stream Maps in the + # config. # { # // .. # "stream_maps": { @@ -850,7 +850,7 @@ def http_headers(self) -> dict: headers["Accept"] = "application/vnd.github.squirrel-girl-preview" return headers - def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: + def post_process(self, row: dict, context: dict | None = None) -> dict: row = super().post_process(row, context) row["type"] = "pull_request" if "pull_request" in row else "issue" if row["body"] is not None: @@ -912,8 +912,7 @@ def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: class IssueCommentsStream(GitHubRestStream): - """ - Defines 'IssueComments' stream. 
+ """Defines 'IssueComments' stream. Issue comments are fetched from the repository level (as opposed to per issue) to optimize for API quota usage. """ @@ -933,7 +932,7 @@ class IssueCommentsStream(GitHubRestStream): # But it is too expensive on large repos and results in a lot of server errors. use_fake_since_parameter = True - def get_records(self, context: Optional[Dict] = None) -> Iterable[Dict[str, Any]]: + def get_records(self, context: dict | None = None) -> Iterable[dict[str, Any]]: """Return a generator of row-type dictionary objects. Each row emitted should be a dictionary of property names to their values. @@ -944,7 +943,7 @@ def get_records(self, context: Optional[Dict] = None) -> Iterable[Dict[str, Any] return super().get_records(context) - def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: + def post_process(self, row: dict, context: dict | None = None) -> dict: row = super().post_process(row, context) row["issue_number"] = int(row["issue_url"].split("/")[-1]) if row["body"] is not None: @@ -976,8 +975,7 @@ def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: class IssueEventsStream(GitHubRestStream): - """ - Defines 'IssueEvents' stream. + """Defines 'IssueEvents' stream. Issue events are fetched from the repository level (as opposed to per issue) to optimize for API quota usage. """ @@ -992,7 +990,7 @@ class IssueEventsStream(GitHubRestStream): # GitHub is missing the "since" parameter on this endpoint. use_fake_since_parameter = True - def get_records(self, context: Optional[Dict] = None) -> Iterable[Dict[str, Any]]: + def get_records(self, context: dict | None = None) -> Iterable[dict[str, Any]]: """Return a generator of row-type dictionary objects. Each row emitted should be a dictionary of property names to their values. @@ -1003,14 +1001,14 @@ def get_records(self, context: Optional[Dict] = None) -> Iterable[Dict[str, Any] return super().get_records(context) - def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: + def post_process(self, row: dict, context: dict | None = None) -> dict: row = super().post_process(row, context) - if "issue" in row.keys(): + if "issue" in row: row["issue_number"] = int(row["issue"].pop("number")) row["issue_url"] = row["issue"].pop("url") else: self.logger.debug( - f"No issue assosciated with event {row['id']} - {row['event']}." + f"No issue associated with event {row['id']} - {row['event']}.", ) return row @@ -1032,8 +1030,7 @@ def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: class CommitsStream(GitHubRestStream): - """ - Defines the 'Commits' stream. + """Defines the 'Commits' stream. The stream is fetched per repository to optimize for API quota usage. """ @@ -1045,9 +1042,8 @@ class CommitsStream(GitHubRestStream): state_partitioning_keys = ["repo", "org"] ignore_parent_replication_key = True - def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: - """ - Add a timestamp top-level field to be used as state replication key. + def post_process(self, row: dict, context: dict | None = None) -> dict: + """Add a timestamp top-level field to be used as state replication key. It's not clear from github's API docs which time (author or committer) is used to compare to the `since` argument that the endpoint supports. 
""" @@ -1179,8 +1175,10 @@ class PullRequestsStream(GitHubRestStream): use_fake_since_parameter = True def get_url_params( - self, context: Optional[Dict], next_page_token: Optional[Any] - ) -> Dict[str, Any]: + self, + context: dict | None, + next_page_token: Any | None, + ) -> dict[str, Any]: """Return a dictionary of values to be used in URL parameterization.""" assert context is not None, f"Context cannot be empty for '{self.name}' stream." params = super().get_url_params(context, next_page_token) @@ -1199,7 +1197,7 @@ def http_headers(self) -> dict: headers["Accept"] = "application/vnd.github.squirrel-girl-preview" return headers - def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: + def post_process(self, row: dict, context: dict | None = None) -> dict: row = super().post_process(row, context) if row["body"] is not None: # some pr bodies include control characters such as \x00 @@ -1216,7 +1214,7 @@ def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: row["reactions"]["minus_one"] = row["reactions"].pop("-1", None) return row - def get_child_context(self, record: Dict, context: Optional[Dict]) -> dict: + def get_child_context(self, record: dict, context: dict | None) -> dict: if context: return { "org": context["org"], @@ -1394,8 +1392,9 @@ class PullRequestCommits(GitHubRestStream): "parents", th.ArrayType( th.ObjectType( - th.Property("url", th.StringType), th.Property("sha", th.StringType) - ) + th.Property("url", th.StringType), + th.Property("sha", th.StringType), + ), ), ), th.Property("files", th.ArrayType(files_object)), @@ -1409,7 +1408,7 @@ class PullRequestCommits(GitHubRestStream): ), ).to_dict() - def post_process(self, row: dict, context: Optional[Dict[str, str]] = None) -> dict: + def post_process(self, row: dict, context: dict[str, str] | None = None) -> dict: row = super().post_process(row, context) if context is not None and "pull_number" in context: row["pull_number"] = context["pull_number"] @@ -1443,7 +1442,8 @@ class ReviewsStream(GitHubRestStream): th.ObjectType( th.Property("html", th.ObjectType(th.Property("href", th.StringType))), th.Property( - "pull_request", th.ObjectType(th.Property("href", th.StringType)) + "pull_request", + th.ObjectType(th.Property("href", th.StringType)), ), ), ), @@ -1491,7 +1491,8 @@ class ReviewCommentsStream(GitHubRestStream): th.Property("self", th.ObjectType(th.Property("href", th.StringType))), th.Property("html", th.ObjectType(th.Property("href", th.StringType))), th.Property( - "pull_request", th.ObjectType(th.Property("href", th.StringType)) + "pull_request", + th.ObjectType(th.Property("href", th.StringType)), ), ), ), @@ -1536,17 +1537,17 @@ class ContributorsStream(GitHubRestStream): def parse_response(self, response: requests.Response) -> Iterable[dict]: # TODO: update this and validate_response when # https://github.com/meltano/sdk/pull/1754 is merged - if response.status_code != 200: + if response.status_code != http.HTTPStatus.OK: return [] yield from super().parse_response(response) def validate_response(self, response: requests.Response) -> None: """Allow some specific errors.""" - if response.status_code == 403: + if response.status_code == http.HTTPStatus.FORBIDDEN: contents = response.json() if ( contents["message"] - == "The history or contributor list is too large to list contributors for this repository via the API." + == "The history or contributor list is too large to list contributors for this repository via the API." 
# noqa: E501 ): self.logger.info( "Skipping repo '%s'. The list of contributors is too large.", @@ -1568,8 +1569,10 @@ class AnonymousContributorsStream(GitHubRestStream): tolerated_http_errors = [204] def get_url_params( - self, context: Optional[Dict], next_page_token: Optional[Any] - ) -> Dict[str, Any]: + self, + context: dict | None, + next_page_token: Any | None, + ) -> dict[str, Any]: """Return a dictionary of values to be used in URL parameterization.""" assert context is not None, f"Context cannot be empty for '{self.name}' stream." params = super().get_url_params(context, next_page_token) @@ -1596,7 +1599,7 @@ def parse_response(self, response: requests.Response) -> Iterable[dict]: class StargazersStream(GitHubRestStream): - """Defines 'Stargazers' stream. Warning: this stream does NOT track star deletions.""" + """Defines 'Stargazers' stream. Warning: this stream does NOT track star deletions.""" # noqa: E501 name = "stargazers_rest" path = "/repos/{org}/{repo}/stargazers" @@ -1611,7 +1614,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO - remove warning with next release. self.logger.warning( - "The stream 'stargazers_rest' is deprecated. Please use the Graphql version instead: 'stargazers'." + "The stream 'stargazers_rest' is deprecated. Please use the Graphql version instead: 'stargazers'.", # noqa: E501 ) @property @@ -1625,10 +1628,8 @@ def http_headers(self) -> dict: headers["Accept"] = "application/vnd.github.v3.star+json" return headers - def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: - """ - Add a user_id top-level field to be used as state replication key. - """ + def post_process(self, row: dict, context: dict | None = None) -> dict: + """Add a user_id top-level field to be used as state replication key.""" row = super().post_process(row, context) row["user_id"] = row["user"]["id"] return row @@ -1646,7 +1647,7 @@ def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: class StargazersGraphqlStream(GitHubGraphqlStream): - """Defines 'UserContributedToStream' stream. Warning: this stream 'only' gets the first 100 projects (by stars).""" + """Defines 'UserContributedToStream' stream. Warning: this stream 'only' gets the first 100 projects (by stars).""" # noqa: E501 name = "stargazers" query_jsonpath = "$.data.repository.stargazers.edges.[*]" @@ -1662,23 +1663,21 @@ def __init__(self, *args, **kwargs): # TODO - remove warning with next release. self.logger.warning( "The stream 'stargazers' might conflict with previous implementation. " - "Looking for the older version? Use 'stargazers_rest'." + "Looking for the older version? Use 'stargazers_rest'.", ) - def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: - """ - Add a user_id top-level field to be used as state replication key. - """ + def post_process(self, row: dict, context: dict | None = None) -> dict: + """Add a user_id top-level field to be used as state replication key.""" row = super().post_process(row, context) row["user_id"] = row["user"]["id"] return row def get_next_page_token( - self, response: requests.Response, previous_token: Optional[Any] - ) -> Optional[Any]: - """ - Exit early if a since parameter is provided. 
- """ + self, + response: requests.Response, + previous_token: Any | None, + ) -> Any | None: + """Exit early if a since parameter is provided.""" request_parameters = parse_qs(str(urlparse(response.request.url).query)) # parse_qs interprets "+" as a space, revert this to keep an aware datetime @@ -1691,12 +1690,12 @@ def get_next_page_token( except IndexError: since = "" - # If since parameter is present, try to exit early by looking at the last "starred_at". + # If since parameter is present, try to exit early by looking at the last "starred_at". # noqa: E501 # Noting that we are traversing in DESCENDING order by STARRED_AT. if since: results = list(extract_jsonpath(self.query_jsonpath, input=response.json())) # If no results, return None to exit early. - if len(results) == 0: + if not results: return None last = results[-1] if parse(last["starred_at"]) < parse(since): @@ -1706,7 +1705,7 @@ def get_next_page_token( @property def query(self) -> str: """Return dynamic GraphQL query.""" - # Graphql id is equivalent to REST node_id. To keep the tap consistent, we rename "id" to "node_id". + # Graphql id is equivalent to REST node_id. To keep the tap consistent, we rename "id" to "node_id". # noqa: E501 return """ query repositoryStargazers($repo: String! $org: String! $nextPageCursor_0: String) { repository(name: $repo owner: $org) { @@ -1734,7 +1733,7 @@ def query(self) -> str: cost } } - """ + """ # noqa: E501 schema = th.PropertiesList( # Parent Keys @@ -1749,8 +1748,7 @@ def query(self) -> str: class StatsContributorsStream(GitHubRestStream): - """ - Defines 'StatsContributors' stream. Fetching contributors activity. + """Defines 'StatsContributors' stream. Fetching contributors activity. https://docs.github.com/en/rest/reference/metrics#get-all-contributor-commit-activity """ @@ -1760,12 +1758,12 @@ class StatsContributorsStream(GitHubRestStream): parent_stream_type = RepositoryStream ignore_parent_replication_key = True state_partitioning_keys = ["repo", "org"] - # Note - these queries are expensive and the API might return an HTTP 202 if the response + # Note - these queries are expensive and the API might return an HTTP 202 if the response # noqa: E501 # has not been cached recently. https://docs.github.com/en/rest/reference/metrics#a-word-about-caching tolerated_http_errors = [202, 204] def parse_response(self, response: requests.Response) -> Iterable[dict]: - """Parse the response and return an iterator of flattened contributor activity.""" + """Parse the response and return an iterator of flattened contributor activity.""" # noqa: E501 replacement_keys = { "a": "additions", "c": "commits", @@ -1777,7 +1775,7 @@ def parse_response(self, response: requests.Response) -> Iterable[dict]: weekly_data = contributor_activity["weeks"] for week in weekly_data: # no need to save weeks with no contributions or author. - # if a user has deleted their account, GitHub may surprisingly return author: None. + # if a user has deleted their account, GitHub may surprisingly return author: None. 
# noqa: E501 author = contributor_activity["author"] if (sum(week[key] for key in ["a", "c", "d"]) == 0) or (author is None): continue @@ -1818,7 +1816,7 @@ class ProjectsStream(GitHubRestStream): parent_stream_type = RepositoryStream state_partitioning_keys = ["repo", "org"] - def get_child_context(self, record: Dict, context: Optional[Dict]) -> dict: + def get_child_context(self, record: dict, context: dict | None) -> dict: return { "project_id": record["id"], "repo_id": context["repo_id"] if context else None, @@ -1857,7 +1855,7 @@ class ProjectColumnsStream(GitHubRestStream): parent_stream_type = ProjectsStream state_partitioning_keys = ["project_id", "repo", "org"] - def get_child_context(self, record: Dict, context: Optional[Dict]) -> dict: + def get_child_context(self, record: dict, context: dict | None) -> dict: return { "column_id": record["id"], "repo_id": context["repo_id"] if context else None, @@ -1987,7 +1985,7 @@ class WorkflowRunsStream(GitHubRestStream): th.ObjectType( th.Property("id", th.IntegerType), th.Property("number", th.IntegerType), - ) + ), ), ), th.Property("created_at", th.DateTimeType), @@ -2005,7 +2003,7 @@ def parse_response(self, response: requests.Response) -> Iterable[dict]: """Parse the response and return an iterator of result rows.""" yield from extract_jsonpath(self.records_jsonpath, input=response.json()) - def get_child_context(self, record: dict, context: Optional[dict]) -> dict: + def get_child_context(self, record: dict, context: dict | None) -> dict: """Return a child context object from the record and optional provided context. By default, will return context if provided and otherwise the record dict. Developers may override this behavior to send specific information to child @@ -2060,7 +2058,7 @@ class WorkflowRunJobsStream(GitHubRestStream): th.Property("number", th.IntegerType), th.Property("started_at", th.DateTimeType), th.Property("completed_at", th.DateTimeType), - ) + ), ), ), th.Property("check_run_url", th.StringType), @@ -2080,8 +2078,10 @@ def parse_response(self, response: requests.Response) -> Iterable[dict]: yield from extract_jsonpath(self.records_jsonpath, input=response.json()) def get_url_params( - self, context: Optional[dict], next_page_token: Optional[Any] - ) -> Dict[str, Any]: + self, + context: dict | None, + next_page_token: Any | None, + ) -> dict[str, Any]: params = super().get_url_params(context, next_page_token) params["filter"] = "all" return params @@ -2106,14 +2106,15 @@ class ExtraMetricsStream(GitHubRestStream): @property def url_base(self) -> str: return self.config.get("api_url_base", self.DEFAULT_API_BASE_URL).replace( - "api.", "" + "api.", + "", ) def parse_response(self, response: requests.Response) -> Iterable[dict]: """Parse the repository main page to extract extra metrics.""" yield from scrape_metrics(response, self.logger) - def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: + def post_process(self, row: dict, context: dict | None = None) -> dict: row = super().post_process(row, context) if context is not None: row["repo"] = context["repo"] @@ -2158,17 +2159,18 @@ class DependentsStream(GitHubRestStream): @property def url_base(self) -> str: return self.config.get("api_url_base", self.DEFAULT_API_BASE_URL).replace( - "api.", "" + "api.", + "", ) def parse_response(self, response: requests.Response) -> Iterable[dict]: - """Get the response for the first page and scrape results, potentially iterating through pages.""" + """Get the response for the first page and scrape results, 
potentially iterating through pages.""" # noqa: E501 yield from scrape_dependents(response, self.logger) - def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: + def post_process(self, row: dict, context: dict | None = None) -> dict: new_row = {"dependent": row} new_row = super().post_process(new_row, context) - # we extract dependent_name_with_owner to be able to use it safely as a primary key, + # we extract dependent_name_with_owner to be able to use it safely as a primary key, # noqa: E501 # regardless of the target used. new_row["dependent_name_with_owner"] = row["name_with_owner"] return new_row @@ -2225,10 +2227,8 @@ def http_headers(self) -> dict: headers["Accept"] = "application/vnd.github.hawkgirl-preview+json" return headers - def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: - """ - Add a dependency_repo_id top-level field to be used as primary key. - """ + def post_process(self, row: dict, context: dict | None = None) -> dict: + """Add a dependency_repo_id top-level field to be used as primary key.""" row = super().post_process(row, context) row["dependency_repo_id"] = ( row["dependency"]["id"] if row["dependency"] else None @@ -2241,8 +2241,8 @@ def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: @property def query(self) -> str: """Return dynamic GraphQL query.""" - # Graphql id is equivalent to REST node_id. To keep the tap consistent, we rename "id" to "node_id". - # Due to GrapQl nested-pagination limitations, we loop through the top level dependencyGraphManifests one by one. + # Graphql id is equivalent to REST node_id. To keep the tap consistent, we rename "id" to "node_id". # noqa: E501 + # Due to GraphQl nested-pagination limitations, we loop through the top level dependencyGraphManifests one by one. # noqa: E501 return """ query repositoryDependencies($repo: String! $org: String! $nextPageCursor_0: String $nextPageCursor_1: String) { repository(name: $repo owner: $org) { @@ -2287,7 +2287,7 @@ def query(self) -> str: } } - """ + """ # noqa: E501 schema = th.PropertiesList( # Parent Keys @@ -2323,7 +2323,7 @@ class TrafficRestStream(GitHubRestStream): """Base class for Traffic Streams""" def parse_response(self, response: requests.Response) -> Iterable[dict]: - if response.status_code != 200: + if response.status_code != http.HTTPStatus.OK: return [] """Parse the response and return an iterator of result rows.""" @@ -2333,8 +2333,8 @@ def validate_response(self, response: requests.Response) -> None: """Allow some specific errors. Do not raise exceptions if the error says "Must have push access to repository" as we actually expect these in this stream when we don't have write permissions into it. 
- """ - if response.status_code == 403: + """ # noqa: E501 + if response.status_code == http.HTTPStatus.FORBIDDEN: contents = response.json() if contents["message"] == "Resource not accessible by integration": self.logger.info("Permissions missing to sync stream '%s'", self.name) diff --git a/tap_github/scraping.py b/tap_github/scraping.py index b3cb0d46..966a6264 100644 --- a/tap_github/scraping.py +++ b/tap_github/scraping.py @@ -3,11 +3,13 @@ Inspired by https://github.com/dogsheep/github-to-sqlite/pull/70 """ +from __future__ import annotations + import logging import re import time from datetime import datetime, timezone -from typing import Any, Dict, Iterable, Optional, Union, cast +from typing import Any, Iterable, cast from urllib.parse import urlparse import requests @@ -18,8 +20,9 @@ def scrape_dependents( - response: requests.Response, logger: Optional[logging.Logger] = None -) -> Iterable[Dict[str, Any]]: + response: requests.Response, + logger: logging.Logger | None = None, +) -> Iterable[dict[str, Any]]: from bs4 import BeautifulSoup logger = logger or logging.getLogger("scraping") @@ -30,8 +33,7 @@ def scrape_dependents( options = soup.find_all("a", class_="select-menu-item") links = [] if len(options) > 0: - for link in options: - links.append(link["href"]) + links.extend(link["href"] for link in options) else: links.append(response.url) @@ -41,7 +43,7 @@ def scrape_dependents( yield from _scrape_dependents(f"https://{base_url}/{link}", logger) -def _scrape_dependents(url: str, logger: logging.Logger) -> Iterable[Dict[str, Any]]: +def _scrape_dependents(url: str, logger: logging.Logger) -> Iterable[dict[str, Any]]: # Optional dependency: from bs4 import BeautifulSoup @@ -53,7 +55,7 @@ def _scrape_dependents(url: str, logger: logging.Logger) -> Iterable[Dict[str, A soup = BeautifulSoup(response.content, "html.parser") repo_names = [ - (a["href"] if not isinstance(a["href"], list) else a["href"][0]).lstrip("/") + (a["href"][0] if isinstance(a["href"], list) else a["href"]).lstrip("/") for a in soup.select("a[data-hovercard-type=repository]") ] stars = [ @@ -67,7 +69,7 @@ def _scrape_dependents(url: str, logger: logging.Logger) -> Iterable[Dict[str, A if not len(repo_names) == len(stars) == len(forks): raise IndexError( - "Could not find star and fork info. Maybe the GitHub page format has changed?" + "Could not find star and fork info. Maybe the GitHub page format has changed?", # noqa: E501 ) repos = [ @@ -82,21 +84,21 @@ def _scrape_dependents(url: str, logger: logging.Logger) -> Iterable[Dict[str, A # next page? try: next_link: Tag = soup.select(".paginate-container")[0].find_all( - "a", text="Next" + "a", + text="Next", )[0] except IndexError: break if next_link is not None: href = next_link["href"] - url = str(href if not isinstance(href, list) else href[0]) + url = str(href[0] if isinstance(href, list) else href) time.sleep(1) else: url = "" -def parse_counter(tag: Union[Tag, NavigableString, None]) -> int: - """ - Extract a count of [issues|PR|contributors...] from an HTML tag. +def parse_counter(tag: Tag | NavigableString | None) -> int: + """Extract a count of [issues|PR|contributors...] from an HTML tag. For very high numbers, we only get an approximate value as github does not provide the actual number. 
""" @@ -111,15 +113,16 @@ def parse_counter(tag: Union[Tag, NavigableString, None]) -> int: else: title_string = cast(str, title[0]) return int(title_string.strip().replace(",", "").replace("+", "")) - except (KeyError, ValueError): - raise IndexError( - f"Could not parse counter {tag}. Maybe the GitHub page format has changed?" - ) + except (KeyError, ValueError) as e: + raise IndexError( # noqa: B904 + f"Could not parse counter {tag}. Maybe the GitHub page format has changed?", + ) from e def scrape_metrics( - response: requests.Response, logger: Optional[logging.Logger] = None -) -> Iterable[Dict[str, Any]]: + response: requests.Response, + logger: logging.Logger | None = None, +) -> Iterable[dict[str, Any]]: from bs4 import BeautifulSoup logger = logger or logging.getLogger("scraping") @@ -129,30 +132,36 @@ def scrape_metrics( try: issues = parse_counter(soup.find("span", id="issues-repo-tab-count")) prs = parse_counter(soup.find("span", id="pull-requests-repo-tab-count")) - except IndexError: + except IndexError as e: # These two items should exist. We raise an error if we could not find them. - raise IndexError( - "Could not find issues or prs info. Maybe the GitHub page format has changed?" - ) + raise IndexError( # noqa: B904 + "Could not find issues or prs info. Maybe the GitHub page format has changed?", # noqa: E501 + ) from e dependents_node = soup.find(string=used_by_regex) # verify that we didn't hit some random text in the page. # sometimes the dependents section isn't shown on the page either dependents_node_parent = getattr(dependents_node, "parent", None) dependents: int = 0 - if dependents_node_parent is not None and "href" in dependents_node_parent: - if dependents_node_parent["href"].endswith("/network/dependents"): - dependents = parse_counter(getattr(dependents_node, "next_element", None)) + if ( + dependents_node_parent is not None + and "href" in dependents_node_parent + and dependents_node_parent["href"].endswith("/network/dependents") + ): + dependents = parse_counter(getattr(dependents_node, "next_element", None)) # likewise, handle edge cases with contributors contributors_node = soup.find(string=contributors_regex) contributors_node_parent = getattr(contributors_node, "parent", None) contributors: int = 0 - if contributors_node_parent is not None and "href" in contributors_node_parent: - if contributors_node_parent["href"].endswith("/graphs/contributors"): - contributors = parse_counter( - getattr(contributors_node, "next_element", None), - ) + if ( + contributors_node_parent is not None + and "href" in contributors_node_parent + and contributors_node_parent["href"].endswith("/graphs/contributors") + ): + contributors = parse_counter( + getattr(contributors_node, "next_element", None), + ) fetched_at = datetime.now(tz=timezone.utc) @@ -163,7 +172,7 @@ def scrape_metrics( "dependents": dependents, "contributors": contributors, "fetched_at": fetched_at, - } + }, ] logger.debug(metrics) diff --git a/tap_github/streams.py b/tap_github/streams.py index e1b05e58..bd29f949 100644 --- a/tap_github/streams.py +++ b/tap_github/streams.py @@ -1,5 +1,6 @@ +from __future__ import annotations + from enum import Enum -from typing import List, Set, Type from singer_sdk.streams.core import Stream @@ -53,14 +54,12 @@ class Streams(Enum): - """ - Represents all streams our tap supports, and which queries (by username, by organization, etc.) you can use. - """ + """Represents all streams our tap supports, and which queries (by username, by organization, etc.) 
you can use.""" # noqa: E501 - valid_queries: Set[str] - streams: List[Type[Stream]] + valid_queries: set[str] + streams: list[type[Stream]] - def __init__(self, valid_queries: Set[str], streams: List[Type[Stream]]): + def __init__(self, valid_queries: set[str], streams: list[type[Stream]]): self.valid_queries = valid_queries self.streams = streams diff --git a/tap_github/tap.py b/tap_github/tap.py index 5dc028c5..4142748d 100644 --- a/tap_github/tap.py +++ b/tap_github/tap.py @@ -1,8 +1,9 @@ """GitHub tap class.""" +from __future__ import annotations + import logging import os -from typing import List from singer_sdk import Stream, Tap from singer_sdk import typing as th # JSON schema typing helpers @@ -17,14 +18,15 @@ class TapGitHub(Tap): name = "tap-github" @classproperty - def logger(cls) -> logging.Logger: + def logger(cls) -> logging.Logger: # noqa: N805 """Get logger. - Returns: + Returns + ------- Logger with local LOGLEVEL. LOGLEVEL from env takes priority. - """ - LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper() + """ + LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper() # noqa: N806 assert ( LOGLEVEL in logging._levelToName.values() ), f"Invalid LOGLEVEL configuration: {LOGLEVEL}" @@ -44,12 +46,12 @@ def logger(cls) -> logging.Logger: th.Property( "additional_auth_tokens", th.ArrayType(th.StringType), - description="List of GitHub tokens to authenticate with. Streams will loop through them when hitting rate limits.", + description="List of GitHub tokens to authenticate with. Streams will loop through them when hitting rate limits.", # noqa: E501 ), th.Property( "rate_limit_buffer", th.IntegerType, - description="Add a buffer to avoid consuming all query points for the token at hand. Defaults to 1000.", + description="Add a buffer to avoid consuming all query points for the token at hand. Defaults to 1000.", # noqa: E501 ), th.Property( "searches", @@ -57,7 +59,7 @@ def logger(cls) -> logging.Logger: th.ObjectType( th.Property("name", th.StringType, required=True), th.Property("query", th.StringType, required=True), - ) + ), ), ), th.Property("organizations", th.ArrayType(th.StringType)), @@ -77,9 +79,8 @@ def logger(cls) -> logging.Logger: ), ).to_dict() - def discover_streams(self) -> List[Stream]: + def discover_streams(self) -> list[Stream]: """Return a list of discovered streams for each query.""" - # If the config is empty, assume we are running --help or --capabilities. if ( self.config @@ -87,12 +88,12 @@ def discover_streams(self) -> List[Stream]: ): raise ValueError( "This tap requires one and only one of the following path options: " - f"{Streams.all_valid_queries()}." 
+ f"{Streams.all_valid_queries()}.", ) streams = [] for stream_type in Streams: if (not self.config) or len( - stream_type.valid_queries.intersection(self.config) + stream_type.valid_queries.intersection(self.config), ) > 0: streams += [ StreamClass(tap=self) for StreamClass in stream_type.streams diff --git a/tap_github/tests/__init__.py b/tap_github/tests/__init__.py index cf07a069..76c0d4ff 100644 --- a/tap_github/tests/__init__.py +++ b/tap_github/tests/__init__.py @@ -1,6 +1,5 @@ """Test suite for tap-github.""" -import requests import requests_cache # Setup caching for all api calls done through `requests` in order to limit diff --git a/tap_github/tests/fixtures.py b/tap_github/tests/fixtures.py index 169a6d47..2f3eafbb 100644 --- a/tap_github/tests/fixtures.py +++ b/tap_github/tests/fixtures.py @@ -11,7 +11,7 @@ sys.stdout = FilterStdOutput(sys.stdout, r'{"type": ') # type: ignore -@pytest.fixture +@pytest.fixture() def search_config(): return { "metrics_log_level": "warning", @@ -20,15 +20,14 @@ def search_config(): { "name": "tap_something", "query": "tap-+language:Python", - } + }, ], } -@pytest.fixture +@pytest.fixture() def repo_list_config(request): - """ - Get a default list of repos or pass your own by decorating your test with + """Get a default list of repos or pass your own by decorating your test with @pytest.mark.repo_list(['org1/repo1', 'org2/repo2']) """ marker = request.node.get_closest_marker("repo_list") @@ -45,17 +44,13 @@ def repo_list_config(request): } -@pytest.fixture +@pytest.fixture() def username_list_config(request): - """ - Get a default list of usernames or pass your own by decorating your test with + """Get a default list of usernames or pass your own by decorating your test with @pytest.mark.username_list(['ericboucher', 'aaronsteers']) """ marker = request.node.get_closest_marker("username_list") - if marker is None: - username_list = ["ericboucher", "aaronsteers"] - else: - username_list = marker.args[0] + username_list = ["ericboucher", "aaronsteers"] if marker is None else marker.args[0] return { "metrics_log_level": "warning", @@ -65,18 +60,13 @@ def username_list_config(request): } -@pytest.fixture +@pytest.fixture() def user_id_list_config(request): - """ - Get a default list of usernames or pass your own by decorating your test with + """Get a default list of usernames or pass your own by decorating your test with @pytest.mark.user_id_list(['ericboucher', 'aaronsteers']) """ marker = request.node.get_closest_marker("user_id_list") - if marker is None: - user_id_list = [1, 2] - else: - user_id_list = marker.args[0] - + user_id_list = [1, 2] if marker is None else marker.args[0] return { "metrics_log_level": "warning", "start_date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d"), @@ -85,10 +75,9 @@ def user_id_list_config(request): } -@pytest.fixture +@pytest.fixture() def organization_list_config(request): - """ - Get a default list of organizations or pass your own by decorating your test with + """Get a default list of organizations or pass your own by decorating your test with @pytest.mark.organization_list(['MeltanoLabs', 'oviohub']) """ marker = request.node.get_closest_marker("organization_list") @@ -103,26 +92,30 @@ def organization_list_config(request): } -def alternative_sync_chidren(self, child_context: dict, no_sync: bool = True) -> None: - """ - Override for Stream._sync_children. +def alternative_sync_children(self, child_context: dict, no_sync: bool = True) -> None: + """Override for Stream._sync_children. 
Enabling us to use an ORG_LEVEL_TOKEN for the collaborators stream. """ for child_stream in self.child_streams: # Use org:write access level credentials for collaborators stream - if child_stream.name in ["collaborators"]: - ORG_LEVEL_TOKEN = os.environ.get("ORG_LEVEL_TOKEN") + if child_stream.name == "collaborators": + """ + The `ORG_LEVEL_TOKEN` variable is used to store an organization-level GitHub API token. This token is used when syncing the "collaborators" stream, as it requires a higher level of access than the standard user token. + + If the `ORG_LEVEL_TOKEN` is not found in the environment, a warning is logged and the collaborators stream sync is skipped. + """ # noqa: E501 + ORG_LEVEL_TOKEN = os.environ.get("ORG_LEVEL_TOKEN") # noqa: N806 # TODO - Fix collaborators tests, likely by mocking API responses directly. # Currently we have to bypass them as they are failing frequently. if not ORG_LEVEL_TOKEN or no_sync: logging.warning( - 'No "ORG_LEVEL_TOKEN" found. Skipping collaborators stream sync.' + 'No "ORG_LEVEL_TOKEN" found. Skipping collaborators stream sync.', ) continue - SAVED_GTHUB_TOKEN = os.environ.get("GITHUB_TOKEN") + SAVED_GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN") # noqa: N806 os.environ["GITHUB_TOKEN"] = ORG_LEVEL_TOKEN child_stream.sync(context=child_context) - os.environ["GITHUB_TOKEN"] = SAVED_GTHUB_TOKEN or "" + os.environ["GITHUB_TOKEN"] = SAVED_GITHUB_TOKEN or "" continue # default behavior: diff --git a/tap_github/tests/test_core.py b/tap_github/tests/test_core.py index 66500e11..7697a724 100644 --- a/tap_github/tests/test_core.py +++ b/tap_github/tests/test_core.py @@ -11,11 +11,7 @@ from tap_github.utils.filter_stdout import nostdout from .fixtures import ( - alternative_sync_chidren, - organization_list_config, - repo_list_config, - search_config, - username_list_config, + alternative_sync_children, ) @@ -24,22 +20,22 @@ def test_standard_tap_tests_for_search_mode(search_config): """Run standard tap tests from the SDK.""" tests = get_standard_tap_tests(TapGitHub, config=search_config) with patch( - "singer_sdk.streams.core.Stream._sync_children", alternative_sync_chidren - ): - with nostdout(): - for test in tests: - test() + "singer_sdk.streams.core.Stream._sync_children", + alternative_sync_children, + ), nostdout(): + for test in tests: + test() def test_standard_tap_tests_for_repo_list_mode(repo_list_config): """Run standard tap tests from the SDK.""" tests = get_standard_tap_tests(TapGitHub, config=repo_list_config) with patch( - "singer_sdk.streams.core.Stream._sync_children", alternative_sync_chidren - ): - with nostdout(): - for test in tests: - test() + "singer_sdk.streams.core.Stream._sync_children", + alternative_sync_children, + ), nostdout(): + for test in tests: + test() def test_standard_tap_tests_for_username_list_mode(username_list_config): diff --git a/tap_github/tests/test_tap.py b/tap_github/tests/test_tap.py index f8fcb096..d5301cbb 100644 --- a/tap_github/tests/test_tap.py +++ b/tap_github/tests/test_tap.py @@ -1,7 +1,7 @@ +from __future__ import annotations + import json -import os import re -from typing import Optional from unittest.mock import patch import pytest @@ -13,7 +13,7 @@ from tap_github.scraping import parse_counter from tap_github.tap import TapGitHub -from .fixtures import alternative_sync_chidren, repo_list_config, username_list_config +from .fixtures import alternative_sync_children repo_list_2 = [ "MeltanoLabs/tap-github", @@ -58,10 +58,12 @@ def test_validate_repo_list_config(repo_list_config): def 
run_tap_with_config( - capsys, config_obj: dict, skip_stream: Optional[str], single_stream: Optional[str] + caps, + config_obj: dict, + skip_stream: str | None, + single_stream: str | None, ) -> str: - """ - Run the tap with the given config and capture stdout, optionally + """Run the tap with the given config and capture stdout, optionally skipping a stream (this is meant to be the top level stream), or running a single one. """ @@ -79,27 +81,31 @@ def run_tap_with_config( cat_helpers.deselect_all_streams(catalog) cat_helpers.set_catalog_stream_selected(catalog, "repositories", selected=True) cat_helpers.set_catalog_stream_selected( - catalog, stream_name=single_stream, selected=True + catalog, + stream_name=single_stream, + selected=True, ) # discard previous output to stdout (potentially from other tests) - capsys.readouterr() + caps.readouterr() with patch( - "singer_sdk.streams.core.Stream._sync_children", alternative_sync_chidren + "singer_sdk.streams.core.Stream._sync_children", + alternative_sync_children, ): tap2 = TapGitHub(config=config_obj, catalog=catalog.to_dict()) tap2.sync_all() - captured = capsys.readouterr() + captured = caps.readouterr() return captured.out @pytest.mark.parametrize("skip_parent_streams", [False, True]) @pytest.mark.repo_list(repo_list_2) def test_get_a_repository_in_repo_list_mode( - capsys, repo_list_config, skip_parent_streams + capsys, + repo_list_config, + skip_parent_streams, ): - """ - Discover the catalog, and request 2 repository records. + """Discover the catalog, and request 2 repository records. The test is parametrized to run twice, with and without syncing the top level `repositories` stream. """ @@ -113,7 +119,7 @@ def test_get_a_repository_in_repo_list_mode( # Verify we got the right number of records # one per repo in the list only if we sync the "repositories" stream, 0 if not assert captured_out.count('{"type": "RECORD", "stream": "repositories"') == len( - repo_list_2_ids * (not skip_parent_streams) + repo_list_2_ids * (not skip_parent_streams), ) # check that the tap corrects invalid case in config input assert '"repo": "Tap-GitLab"' not in captured_out @@ -122,19 +128,22 @@ def test_get_a_repository_in_repo_list_mode( @pytest.mark.repo_list(["MeltanoLabs/tap-github"]) def test_last_state_message_is_valid(capsys, repo_list_config): - """ - Validate that the last state message is not a temporary one and contains the + """Validate that the last state message is not a temporary one and contains the expected values for a stream with overridden state partitioning keys. Run this on a single repo to avoid having to filter messages too much. 
""" repo_list_config["skip_parent_streams"] = True captured_out = run_tap_with_config( - capsys, repo_list_config, "repositories", single_stream=None + capsys, + repo_list_config, + "repositories", + single_stream=None, ) # capture the messages we're interested in state_messages = re.findall(r'{"type": "STATE", "value":.*}', captured_out) issue_comments_records = re.findall( - r'{"type": "RECORD", "stream": "issue_comments",.*}', captured_out + r'{"type": "RECORD", "stream": "issue_comments",.*}', + captured_out, ) assert state_messages is not None last_state_msg = state_messages[-1] @@ -146,13 +155,13 @@ def test_last_state_message_is_valid(capsys, repo_list_config): last_state_updated_at = isoparse( last_state["value"]["bookmarks"]["issue_comments"]["partitions"][0][ "replication_key_value" - ] + ], ) latest_updated_at = max( map( lambda record: isoparse(json.loads(record)["record"]["updated_at"]), issue_comments_records, - ) + ), ) assert last_state_updated_at == latest_updated_at @@ -162,11 +171,11 @@ def test_last_state_message_is_valid(capsys, repo_list_config): @pytest.mark.parametrize("skip_parent_streams", [False, True]) @pytest.mark.username_list(["EricBoucher", "aaRONsTeeRS"]) def test_get_a_user_in_user_usernames_mode( - capsys, username_list_config, skip_parent_streams + capsys, + username_list_config, + skip_parent_streams, ): - """ - Discover the catalog, and request 2 repository records - """ + """Discover the catalog, and request 2 repository records""" username_list_config["skip_parent_streams"] = skip_parent_streams captured_out = run_tap_with_config( capsys, @@ -177,11 +186,11 @@ def test_get_a_user_in_user_usernames_mode( # Verify we got the right number of records: # one per user in the list if we sync the root stream, 0 otherwise assert captured_out.count('{"type": "RECORD", "stream": "users"') == len( - username_list_config["user_usernames"] * (not skip_parent_streams) + username_list_config["user_usernames"] * (not skip_parent_streams), ) # these 2 are inequalities as number will keep changing :) - assert captured_out.count('{"type": "RECORD", "stream": "starred"') > 150 - assert captured_out.count('{"type": "RECORD", "stream": "user_contributed_to"') > 25 + assert captured_out.count('{"type": "RECORD", "stream": "starred"') > 150 # noqa: PLR2004 + assert captured_out.count('{"type": "RECORD", "stream": "user_contributed_to"') > 25 # noqa: PLR2004 assert '{"username": "aaronsteers"' in captured_out assert '{"username": "aaRONsTeeRS"' not in captured_out assert '{"username": "EricBoucher"' not in captured_out @@ -189,19 +198,20 @@ def test_get_a_user_in_user_usernames_mode( @pytest.mark.repo_list(["torvalds/linux"]) def test_large_list_of_contributors(capsys, repo_list_config): - """ - Check that the github error message for very large lists of contributors + """Check that the github error message for very large lists of contributors is handled properly (does not return any records). """ captured_out = run_tap_with_config( - capsys, repo_list_config, skip_stream=None, single_stream="contributors" + capsys, + repo_list_config, + skip_stream=None, + single_stream="contributors", ) assert captured_out.count('{"type": "RECORD", "stream": "contributors"') == 0 def test_web_tag_parse_counter(): - """ - Check that the parser runs ok on various forms of counters. + """Check that the parser runs ok on various forms of counters. Used in extra_metrics stream. 
""" # regular int @@ -209,18 +219,18 @@ def test_web_tag_parse_counter(): '57', "html.parser", ).span - assert parse_counter(tag) == 57 + assert parse_counter(tag) == 57 # noqa: PLR2004 # 2k tag = BeautifulSoup( '2k', "html.parser", ).span - assert parse_counter(tag) == 2028 + assert parse_counter(tag) == 2028 # noqa: PLR2004 # 5k+. The real number is not available in the page, use this approx value tag = BeautifulSoup( '5k+', "html.parser", ).span - assert parse_counter(tag) == 5_000 + assert parse_counter(tag) == 5_000 # noqa: PLR2004 diff --git a/tap_github/user_streams.py b/tap_github/user_streams.py index 0cd1fb40..a9c75630 100644 --- a/tap_github/user_streams.py +++ b/tap_github/user_streams.py @@ -1,7 +1,9 @@ """User Stream types classes for tap-github.""" +from __future__ import annotations + import re -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Iterable from singer_sdk import typing as th # JSON Schema typing helpers from singer_sdk.exceptions import FatalAPIError @@ -21,45 +23,46 @@ def path(self) -> str: # type: ignore """Return the API endpoint path.""" if "user_usernames" in self.config: return "/users/{username}" - elif "user_ids" in self.config: + if "user_ids" in self.config: return "/user/{id}" @property - def partitions(self) -> Optional[List[Dict]]: + def partitions(self) -> list[dict] | None: """Return a list of partitions.""" - if "user_usernames" in self.config: - input_user_list = self.config["user_usernames"] - - augmented_user_list = list() - # chunk requests to the graphql endpoint to avoid timeouts and other - # obscure errors that the api doesn't say much about. The actual limit - # seems closer to 1000, use half that to stay safe. - chunk_size = 500 - list_length = len(input_user_list) - self.logger.info(f"Filtering user list of {list_length} users") - for ndx in range(0, list_length, chunk_size): - augmented_user_list += self.get_user_ids( - input_user_list[ndx : ndx + chunk_size] - ) - self.logger.info(f"Running the tap on {len(augmented_user_list)} users") - return augmented_user_list - - elif "user_ids" in self.config: - return [{"id": id} for id in self.config["user_ids"]] - return None - - def get_child_context(self, record: Dict, context: Optional[Dict]) -> dict: + if "user_usernames" not in self.config: + return ( + [{"id": uid} for uid in self.config["user_ids"]] + if "user_ids" in self.config + else None + ) + input_user_list = self.config["user_usernames"] + + augmented_user_list = [] + # chunk requests to the graphql endpoint to avoid timeouts and other + # obscure errors that the api doesn't say much about. The actual limit + # seems closer to 1000, use half that to stay safe. + chunk_size = 500 + list_length = len(input_user_list) + self.logger.info(f"Filtering user list of {list_length} users") + for ndx in range(0, list_length, chunk_size): + augmented_user_list += self.get_user_ids( + input_user_list[ndx : ndx + chunk_size], + ) + self.logger.info(f"Running the tap on {len(augmented_user_list)} users") + return augmented_user_list + + def get_child_context(self, record: dict, context: dict | None) -> dict: return { "username": record["login"], "user_id": record["id"], } - def get_user_ids(self, user_list: List[str]) -> List[Dict[str, str]]: - """Enrich the list of userse with their numeric ID from github. + def get_user_ids(self, user_list: list[str]) -> list[dict[str, str]]: + """Enrich the list of users with their numeric ID from github. This helps maintain a stable id for context and bookmarks. 
It uses the github graphql api to fetch the databaseId. - It also removes non-existant repos and corrects casing to ensure + It also removes non-existent repos and corrects casing to ensure data is correct downstream. """ @@ -84,7 +87,7 @@ def query(self) -> str: # and the /user endpoint works for all types. chunks.append( f'user{i}: repositoryOwner(login: "{user}") ' - "{ login avatarUrl}" + "{ login avatarUrl}", ) return "query {" + " ".join(chunks) + " rateLimit { cost } }" @@ -94,8 +97,8 @@ def query(self) -> str: users_with_ids: list = list() temp_stream = TempStream(self._tap, list(user_list)) - databaseIdPattern: re.Pattern = re.compile( - r"https://avatars.githubusercontent.com/u/(\d+)?.*" + database_id_pattern: re.Pattern = re.compile( + r"https://avatars.githubusercontent.com/u/(\d+)?.*", ) # replace manually provided org/repo values by the ones obtained # from github api. This guarantees that case is correct in the output data. @@ -103,7 +106,7 @@ def query(self) -> str: # Also remove repos which do not exist to avoid crashing further down # the line. for record in temp_stream.request_records({}): - for item in record.keys(): + for item in record: if item == "rateLimit": continue try: @@ -114,15 +117,15 @@ def query(self) -> str: invalid_username = user_list[int(item[4:])] self.logger.info( f"Username not found: {invalid_username} \t" - "Removing it from list" + "Removing it from list", ) continue # the databaseId (in graphql language) is not available on # repositoryOwner, so we parse the avatarUrl to get it :/ - m = databaseIdPattern.match(record[item]["avatarUrl"]) + m = database_id_pattern.match(record[item]["avatarUrl"]) if m is not None: - dbId = m.group(1) - users_with_ids.append({"username": username, "user_id": dbId}) + db_id = m.group(1) + users_with_ids.append({"username": username, "user_id": db_id}) else: # If we get here, github's API is not returning what # we expected, so it's most likely a breaking change on @@ -132,9 +135,8 @@ def query(self) -> str: self.logger.info(f"Running the tap on {len(users_with_ids)} users") return users_with_ids - def get_records(self, context: Optional[Dict]) -> Iterable[Dict[str, Any]]: - """ - Override the parent method to allow skipping API calls + def get_records(self, context: dict | None) -> Iterable[dict[str, Any]]: + """Override the parent method to allow skipping API calls if the stream is deselected and skip_parent_streams is True in config. This allows running the tap with fewer API calls and preserving quota when only syncing a child stream. Without this, @@ -218,10 +220,8 @@ def http_headers(self) -> dict: headers["Accept"] = "application/vnd.github.v3.star+json" return headers - def post_process(self, row: dict, context: Optional[Dict] = None) -> dict: - """ - Add a repo_id top-level field to be used as state replication key. 
- """ + def post_process(self, row: dict, context: dict | None = None) -> dict: + """Add a repo_id top-level field to be used as state replication key.""" row["repo_id"] = row["repo"]["id"] if context is not None: row["user_id"] = context["user_id"] @@ -316,7 +316,7 @@ def query(self) -> str: cost } } - """ + """ # noqa: E501 schema = th.PropertiesList( th.Property("node_id", th.StringType), diff --git a/tap_github/utils/filter_stdout.py b/tap_github/utils/filter_stdout.py index bfddbe9b..34930ccd 100644 --- a/tap_github/utils/filter_stdout.py +++ b/tap_github/utils/filter_stdout.py @@ -1,14 +1,16 @@ +from __future__ import annotations + import contextlib import io import re import sys -from typing import Pattern, TextIO, Union +from typing import Pattern, TextIO class FilterStdOutput: - """Filter out stdout/sterr given a regex pattern.""" + """Filter out stdout/stderr given a regex pattern.""" - def __init__(self, stream: TextIO, re_pattern: Union[str, Pattern]): + def __init__(self, stream: TextIO, re_pattern: str | Pattern): # noqa: FA100 self.stream = stream self.pattern = ( re.compile(re_pattern) if isinstance(re_pattern, str) else re_pattern @@ -21,13 +23,12 @@ def __getattr__(self, attr_name: str): def write(self, data): if data == "\n" and self.triggered: self.triggered = False + elif self.pattern.search(data) is None: + self.stream.write(data) + self.stream.flush() else: - if self.pattern.search(data) is None: - self.stream.write(data) - self.stream.flush() - else: - # caught bad pattern - self.triggered = True + # caught bad pattern + self.triggered = True def flush(self): self.stream.flush() diff --git a/tox.ini b/tox.ini index 6b48fe08..89f63a94 100644 --- a/tox.ini +++ b/tox.ini @@ -8,12 +8,3 @@ commands = poetry install -v poetry run pytest -[flake8] -; ignore = E226,E302,E41 -ignore = W503 -max-line-length = 88 -exclude = cookiecutter -max-complexity = 10 - -[pydocstyle] -ignore = D105,D203,D213 From dbf96f3e265d8ac9430cda7d41e721e25bf6ea5d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 4 Jul 2024 23:50:56 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tap_github/tests/fixtures.py | 2 +- tox.ini | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tap_github/tests/fixtures.py b/tap_github/tests/fixtures.py index 2f3eafbb..ec6cd1c1 100644 --- a/tap_github/tests/fixtures.py +++ b/tap_github/tests/fixtures.py @@ -101,7 +101,7 @@ def alternative_sync_children(self, child_context: dict, no_sync: bool = True) - if child_stream.name == "collaborators": """ The `ORG_LEVEL_TOKEN` variable is used to store an organization-level GitHub API token. This token is used when syncing the "collaborators" stream, as it requires a higher level of access than the standard user token. - + If the `ORG_LEVEL_TOKEN` is not found in the environment, a warning is logged and the collaborators stream sync is skipped. 
""" # noqa: E501 ORG_LEVEL_TOKEN = os.environ.get("ORG_LEVEL_TOKEN") # noqa: N806 diff --git a/tox.ini b/tox.ini index 89f63a94..a60284ba 100644 --- a/tox.ini +++ b/tox.ini @@ -7,4 +7,3 @@ whitelist_externals = poetry commands = poetry install -v poetry run pytest - From f76adc6533f3116920251cc4d1903aeb134f5ce3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez-Mondrag=C3=B3n?= Date: Thu, 4 Jul 2024 17:54:12 -0600 Subject: [PATCH 3/6] Fix mypy issues --- tap_github/client.py | 2 +- tap_github/repository_streams.py | 8 ++++---- tap_github/scraping.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tap_github/client.py b/tap_github/client.py index 61067cbe..465808f3 100644 --- a/tap_github/client.py +++ b/tap_github/client.py @@ -269,7 +269,7 @@ def parse_response(self, response: requests.Response) -> Iterable[dict]: EMPTY_REPO_ERROR_STATUS, } ): - return [] # noqa: B901 + return [] # type: ignore[return-value] # noqa: B901 # Update token rate limit info and loop through tokens if needed. self.authenticator.update_rate_limit(response.headers) diff --git a/tap_github/repository_streams.py b/tap_github/repository_streams.py index bbe2de68..623a2049 100644 --- a/tap_github/repository_streams.py +++ b/tap_github/repository_streams.py @@ -354,7 +354,7 @@ def http_headers(self) -> dict: def parse_response(self, response: requests.Response) -> Iterable[dict]: """Parse the README to yield the html response instead of an object.""" if response.status_code in self.tolerated_http_errors: - return [] + return [] # type: ignore[return-value] yield {"raw_html": response.text} @@ -723,7 +723,7 @@ class LanguagesStream(GitHubRestStream): def parse_response(self, response: requests.Response) -> Iterable[dict]: """Parse the language response and reformat to return as an iterator of [{language_name: Python, bytes: 23}].""" # noqa: E501 if response.status_code in self.tolerated_http_errors: - return [] + return [] # type: ignore[return-value] languages_json = response.json() for key, value in languages_json.items(): @@ -1538,7 +1538,7 @@ def parse_response(self, response: requests.Response) -> Iterable[dict]: # TODO: update this and validate_response when # https://github.com/meltano/sdk/pull/1754 is merged if response.status_code != http.HTTPStatus.OK: - return [] + return [] # type: ignore[return-value] yield from super().parse_response(response) def validate_response(self, response: requests.Response) -> None: @@ -2324,7 +2324,7 @@ class TrafficRestStream(GitHubRestStream): def parse_response(self, response: requests.Response) -> Iterable[dict]: if response.status_code != http.HTTPStatus.OK: - return [] + return [] # type: ignore[return-value] """Parse the response and return an iterator of result rows.""" yield from extract_jsonpath(self.records_jsonpath, input=response.json()) diff --git a/tap_github/scraping.py b/tap_github/scraping.py index 966a6264..772f35e7 100644 --- a/tap_github/scraping.py +++ b/tap_github/scraping.py @@ -31,7 +31,7 @@ def scrape_dependents( # Navigate through Package toggle if present base_url = urlparse(response.url).hostname or "github.com" options = soup.find_all("a", class_="select-menu-item") - links = [] + links: list[str] = [] if len(options) > 0: links.extend(link["href"] for link in options) else: From a09c24030f4de9b5219c797ecab84ab8cb1b5d17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez-Mondrag=C3=B3n?= Date: Thu, 4 Jul 2024 17:57:48 -0600 Subject: [PATCH 4/6] Fix fixtures --- tap_github/tests/conftest.py | 40 
+++++++++++++++++++++++++++++++++++ tap_github/tests/fixtures.py | 35 ------------------------------ tap_github/tests/test_core.py | 4 +--- 3 files changed, 41 insertions(+), 38 deletions(-) create mode 100644 tap_github/tests/conftest.py diff --git a/tap_github/tests/conftest.py b/tap_github/tests/conftest.py new file mode 100644 index 00000000..c2b58ec1 --- /dev/null +++ b/tap_github/tests/conftest.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import datetime + +import pytest + + +@pytest.fixture() +def repo_list_config(request): + """Get a default list of repos or pass your own by decorating your test with + @pytest.mark.repo_list(['org1/repo1', 'org2/repo2']) + """ + marker = request.node.get_closest_marker("repo_list") + if marker is None: + repo_list = ["MeltanoLabs/tap-github", "mapswipe/mapswipe"] + else: + repo_list = marker.args[0] + + return { + "metrics_log_level": "warning", + "start_date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d"), + "repositories": repo_list, + "rate_limit_buffer": 100, + } + + +@pytest.fixture() +def username_list_config(request): + """Get a default list of usernames or pass your own by decorating your test with + @pytest.mark.username_list(['ericboucher', 'aaronsteers']) + """ + marker = request.node.get_closest_marker("username_list") + username_list = ["ericboucher", "aaronsteers"] if marker is None else marker.args[0] + + return { + "metrics_log_level": "warning", + "start_date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d"), + "user_usernames": username_list, + "rate_limit_buffer": 100, + } diff --git a/tap_github/tests/fixtures.py b/tap_github/tests/fixtures.py index ec6cd1c1..5d60c94f 100644 --- a/tap_github/tests/fixtures.py +++ b/tap_github/tests/fixtures.py @@ -25,41 +25,6 @@ def search_config(): } -@pytest.fixture() -def repo_list_config(request): - """Get a default list of repos or pass your own by decorating your test with - @pytest.mark.repo_list(['org1/repo1', 'org2/repo2']) - """ - marker = request.node.get_closest_marker("repo_list") - if marker is None: - repo_list = ["MeltanoLabs/tap-github", "mapswipe/mapswipe"] - else: - repo_list = marker.args[0] - - return { - "metrics_log_level": "warning", - "start_date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d"), - "repositories": repo_list, - "rate_limit_buffer": 100, - } - - -@pytest.fixture() -def username_list_config(request): - """Get a default list of usernames or pass your own by decorating your test with - @pytest.mark.username_list(['ericboucher', 'aaronsteers']) - """ - marker = request.node.get_closest_marker("username_list") - username_list = ["ericboucher", "aaronsteers"] if marker is None else marker.args[0] - - return { - "metrics_log_level": "warning", - "start_date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d"), - "user_usernames": username_list, - "rate_limit_buffer": 100, - } - - @pytest.fixture() def user_id_list_config(request): """Get a default list of usernames or pass your own by decorating your test with diff --git a/tap_github/tests/test_core.py b/tap_github/tests/test_core.py index 7697a724..00d1be6d 100644 --- a/tap_github/tests/test_core.py +++ b/tap_github/tests/test_core.py @@ -10,9 +10,7 @@ from tap_github.tap import TapGitHub from tap_github.utils.filter_stdout import nostdout -from .fixtures import ( - alternative_sync_children, -) +from .fixtures import alternative_sync_children # Run standard built-in tap tests from the SDK: From 
1d433a0ca96c2ed914865b3418e7fb8550f35f83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez-Mondrag=C3=B3n?= Date: Thu, 4 Jul 2024 18:01:19 -0600 Subject: [PATCH 5/6] Fix more fixtures --- tap_github/tests/conftest.py | 31 +++++++++++++++++++++++++++++++ tap_github/tests/fixtures.py | 31 ------------------------------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/tap_github/tests/conftest.py b/tap_github/tests/conftest.py index c2b58ec1..6419a334 100644 --- a/tap_github/tests/conftest.py +++ b/tap_github/tests/conftest.py @@ -5,6 +5,37 @@ import pytest +@pytest.fixture() +def organization_list_config(request): + """Get a default list of organizations or pass your own by decorating your test with + @pytest.mark.organization_list(['MeltanoLabs', 'oviohub']) + """ + marker = request.node.get_closest_marker("organization_list") + + organization_list = ["MeltanoLabs"] if marker is None else marker.args[0] + + return { + "metrics_log_level": "warning", + "start_date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d"), + "organizations": organization_list, + "rate_limit_buffer": 100, + } + + +@pytest.fixture() +def search_config(): + return { + "metrics_log_level": "warning", + "start_date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d"), + "searches": [ + { + "name": "tap_something", + "query": "tap-+language:Python", + }, + ], + } + + @pytest.fixture() def repo_list_config(request): """Get a default list of repos or pass your own by decorating your test with diff --git a/tap_github/tests/fixtures.py b/tap_github/tests/fixtures.py index 5d60c94f..aedefa75 100644 --- a/tap_github/tests/fixtures.py +++ b/tap_github/tests/fixtures.py @@ -11,20 +11,6 @@ sys.stdout = FilterStdOutput(sys.stdout, r'{"type": ') # type: ignore -@pytest.fixture() -def search_config(): - return { - "metrics_log_level": "warning", - "start_date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d"), - "searches": [ - { - "name": "tap_something", - "query": "tap-+language:Python", - }, - ], - } - - @pytest.fixture() def user_id_list_config(request): """Get a default list of usernames or pass your own by decorating your test with @@ -40,23 +26,6 @@ def user_id_list_config(request): } -@pytest.fixture() -def organization_list_config(request): - """Get a default list of organizations or pass your own by decorating your test with - @pytest.mark.organization_list(['MeltanoLabs', 'oviohub']) - """ - marker = request.node.get_closest_marker("organization_list") - - organization_list = ["MeltanoLabs"] if marker is None else marker.args[0] - - return { - "metrics_log_level": "warning", - "start_date": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d"), - "organizations": organization_list, - "rate_limit_buffer": 100, - } - - def alternative_sync_children(self, child_context: dict, no_sync: bool = True) -> None: """Override for Stream._sync_children. Enabling us to use an ORG_LEVEL_TOKEN for the collaborators stream. 
From 835041176e5d791dff12d9a044fe0ff956cbe8e1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Edgar=20Ram=C3=ADrez-Mondrag=C3=B3n?=
Date: Thu, 4 Jul 2024 18:05:41 -0600
Subject: [PATCH 6/6] Fix fixture name

---
 tap_github/tests/test_tap.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tap_github/tests/test_tap.py b/tap_github/tests/test_tap.py
index d5301cbb..0c9f17bd 100644
--- a/tap_github/tests/test_tap.py
+++ b/tap_github/tests/test_tap.py
@@ -58,7 +58,7 @@ def test_validate_repo_list_config(repo_list_config):
 
 
 def run_tap_with_config(
-    caps,
+    capsys: pytest.CaptureFixture[str],
     config_obj: dict,
     skip_stream: str | None,
     single_stream: str | None,
@@ -87,14 +87,14 @@ def run_tap_with_config(
     )
 
     # discard previous output to stdout (potentially from other tests)
-    caps.readouterr()
+    capsys.readouterr()
     with patch(
         "singer_sdk.streams.core.Stream._sync_children",
         alternative_sync_children,
     ):
         tap2 = TapGitHub(config=config_obj, catalog=catalog.to_dict())
         tap2.sync_all()
-    captured = caps.readouterr()
+    captured = capsys.readouterr()
     return captured.out