diff --git a/python_modules/libraries/dagster-airbyte/dagster_airbyte/__init__.py b/python_modules/libraries/dagster-airbyte/dagster_airbyte/__init__.py index 2a4161961ea24..980ed7931aa5c 100644 --- a/python_modules/libraries/dagster-airbyte/dagster_airbyte/__init__.py +++ b/python_modules/libraries/dagster-airbyte/dagster_airbyte/__init__.py @@ -21,6 +21,7 @@ from dagster_airbyte.ops import airbyte_sync_op as airbyte_sync_op from dagster_airbyte.resources import ( AirbyteCloudResource as AirbyteCloudResource, + AirbyteCloudWorkspace as AirbyteCloudWorkspace, AirbyteResource as AirbyteResource, AirbyteState as AirbyteState, airbyte_cloud_resource as airbyte_cloud_resource, diff --git a/python_modules/libraries/dagster-airbyte/dagster_airbyte/resources.py b/python_modules/libraries/dagster-airbyte/dagster_airbyte/resources.py index e26531440c93f..62bf44ab2f061 100644 --- a/python_modules/libraries/dagster-airbyte/dagster_airbyte/resources.py +++ b/python_modules/libraries/dagster-airbyte/dagster_airbyte/resources.py @@ -20,6 +20,7 @@ from dagster._annotations import experimental from dagster._config.pythonic_config import infer_schema_from_config_class from dagster._core.definitions.resource_definition import dagster_maintained_resource +from dagster._model import DagsterModel from dagster._utils.cached_method import cached_method from dagster._utils.merger import deep_merge_dicts from pydantic import Field, PrivateAttr @@ -28,6 +29,9 @@ from dagster_airbyte.translator import AirbyteWorkspaceData from dagster_airbyte.types import AirbyteOutput +AIRBYTE_API_BASE = "https://api.airbyte.com" +AIRBYTE_API_VERSION = "v1" + DEFAULT_POLL_INTERVAL_SECONDS = 10 # The access token expire every 3 minutes in Airbyte Cloud. @@ -801,22 +805,31 @@ def airbyte_cloud_resource(context) -> AirbyteCloudResource: @experimental -class AirbyteCloudClient: +class AirbyteCloudClient(DagsterModel): """This class exposes methods on top of the Airbyte APIs for Airbyte Cloud.""" + workspace_id: str = Field(..., description="The Airbyte workspace ID") + client_id: str = Field(..., description="The Airbyte client ID.") + client_secret: str = Field(..., description="The Airbyte client secret.") + request_max_retries: int = Field( + ..., + description=( + "The maximum number of times requests to the Airbyte API should be retried " + "before failing." + ), + ) + request_retry_delay: float = Field( + ..., + description="Time (in seconds) to wait between each request retry.", + ) + request_timeout: int = Field( + ..., + description="Time (in seconds) after which the requests to Airbyte are declared timed out.", + ) + _access_token_value: Optional[str] = PrivateAttr(default=None) _access_token_timestamp: Optional[float] = PrivateAttr(default=None) - def __init__( - self, - workspace_id: str, - client_id: str, - client_secret: str, - ): - self.workspace_id = workspace_id - self.client_id = client_id - self.client_secret = client_secret - @property @cached_method def _log(self) -> logging.Logger: @@ -824,12 +837,109 @@ def _log(self) -> logging.Logger: @property def api_base_url(self) -> str: - raise NotImplementedError() + return f"{AIRBYTE_API_BASE}/{AIRBYTE_API_VERSION}" + + @property + def all_additional_request_params(self) -> Mapping[str, Any]: + return {**self.authorization_request_params, **self.user_agent_request_params} + + @property + def authorization_request_params(self) -> Mapping[str, Any]: + # Make sure the access token is refreshed before using it when calling the API. + if self._needs_refreshed_access_token(): + self._refresh_access_token() + return { + "Authorization": f"Bearer {self._access_token_value}", + } + + @property + def user_agent_request_params(self) -> Mapping[str, Any]: + return { + "User-Agent": "dagster", + } + + def _refresh_access_token(self) -> None: + response = check.not_none( + self._make_request( + method="POST", + endpoint="/applications/token", + base_url=self.api_base_url, + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + }, + # Must not pass the bearer access token when refreshing it. + include_additional_request_params=False, + ) + ) + self._access_token_value = str(response["access_token"]) + self._access_token_timestamp = datetime.now().timestamp() + + def _needs_refreshed_access_token(self) -> bool: + return ( + not self._access_token_value + or not self._access_token_timestamp + or self._access_token_timestamp + <= ( + datetime.now() - timedelta(seconds=AIRBYTE_CLOUD_REFRESH_TIMEDELTA_SECONDS) + ).timestamp() + ) + + def _get_session(self, include_additional_request_params: bool) -> requests.Session: + headers = {"accept": "application/json"} + if include_additional_request_params: + headers = { + **headers, + **self.all_additional_request_params, + } + session = requests.Session() + session.headers.update(headers) + return session def _make_request( - self, method: str, endpoint: str, data: Optional[str] = None + self, + method: str, + endpoint: str, + base_url: str, + data: Optional[Mapping[str, Any]] = None, + include_additional_request_params: bool = True, ) -> Mapping[str, Any]: - raise NotImplementedError() + """Creates and sends a request to the desired Airbyte REST API endpoint. + + Args: + method (str): The http method to use for this request (e.g. "POST", "GET", "PATCH"). + endpoint (str): The Airbyte API endpoint to send this request to. + base_url (str): The base url to the Airbyte API to use. + data (Optional[Dict[str, Any]]): JSON-formatted data string to be included in the request. + include_additional_request_params (bool): Whether to include authorization and user-agent headers + to the request parameters. Defaults to True. + + Returns: + Dict[str, Any]: Parsed json data from the response to this request + """ + url = base_url + endpoint + + num_retries = 0 + while True: + try: + session = self._get_session( + include_additional_request_params=include_additional_request_params + ) + response = session.request( + method=method, url=url, json=data, timeout=self.request_timeout + ) + response.raise_for_status() + return response.json() + except RequestException as e: + self._log.error( + f"Request to Airbyte API failed for url {url} with method {method} : {e}" + ) + if num_retries == self.request_max_retries: + break + num_retries += 1 + time.sleep(self.request_retry_delay) + + raise Failure(f"Max retries ({self.request_max_retries}) exceeded with url: {url}.") def get_connections(self) -> Mapping[str, Any]: """Fetches all connections of an Airbyte workspace from the Airbyte API.""" @@ -849,6 +959,21 @@ class AirbyteCloudWorkspace(ConfigurableResource): workspace_id: str = Field(..., description="The Airbyte Cloud workspace ID") client_id: str = Field(..., description="The Airbyte Cloud client ID.") client_secret: str = Field(..., description="The Airbyte Cloud client secret.") + request_max_retries: int = Field( + default=3, + description=( + "The maximum number of times requests to the Airbyte API should be retried " + "before failing." + ), + ) + request_retry_delay: float = Field( + default=0.25, + description="Time (in seconds) to wait between each request retry.", + ) + request_timeout: int = Field( + default=15, + description="Time (in seconds) after which the requests to Airbyte are declared timed out.", + ) _client: AirbyteCloudClient = PrivateAttr(default=None) @@ -858,6 +983,9 @@ def get_client(self) -> AirbyteCloudClient: workspace_id=self.workspace_id, client_id=self.client_id, client_secret=self.client_secret, + request_max_retries=self.request_max_retries, + request_retry_delay=self.request_retry_delay, + request_timeout=self.request_timeout, ) def fetch_airbyte_workspace_data( diff --git a/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/experimental/__init__.py b/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/experimental/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/experimental/conftest.py b/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/experimental/conftest.py new file mode 100644 index 0000000000000..ad92bc215deee --- /dev/null +++ b/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/experimental/conftest.py @@ -0,0 +1,29 @@ +from typing import Iterator + +import pytest +import responses +from dagster_airbyte.resources import AIRBYTE_API_BASE, AIRBYTE_API_VERSION + +TEST_WORKSPACE_ID = "some_workspace_id" +TEST_CLIENT_ID = "some_client_id" +TEST_CLIENT_SECRET = "some_client_secret" + +TEST_ACCESS_TOKEN = "some_access_token" + +# Taken from Airbyte API documentation +# https://reference.airbyte.com/reference/createaccesstoken +SAMPLE_ACCESS_TOKEN = {"access_token": TEST_ACCESS_TOKEN} + + +@pytest.fixture( + name="base_api_mocks", +) +def base_api_mocks_fixture() -> Iterator[responses.RequestsMock]: + with responses.RequestsMock() as response: + response.add( + method=responses.POST, + url=f"{AIRBYTE_API_BASE}/{AIRBYTE_API_VERSION}/applications/token", + json=SAMPLE_ACCESS_TOKEN, + status=201, + ) + yield response diff --git a/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/experimental/test_resources.py b/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/experimental/test_resources.py new file mode 100644 index 0000000000000..d9706bc563a0b --- /dev/null +++ b/python_modules/libraries/dagster-airbyte/dagster_airbyte_tests/experimental/test_resources.py @@ -0,0 +1,81 @@ +import json +from datetime import datetime +from unittest import mock + +import responses +from dagster_airbyte import AirbyteCloudWorkspace + +from dagster_airbyte_tests.experimental.conftest import ( + TEST_ACCESS_TOKEN, + TEST_CLIENT_ID, + TEST_CLIENT_SECRET, + TEST_WORKSPACE_ID, +) + + +def test_refresh_access_token(base_api_mocks: responses.RequestsMock) -> None: + """Tests the `AirbyteCloudClient._make_request` method and how the API access token is refreshed. + + Args: + base_api_mocks (responses.RequestsMock): The mock responses for the base API requests, + i.e. generating the access token. + """ + resource = AirbyteCloudWorkspace( + workspace_id=TEST_WORKSPACE_ID, + client_id=TEST_CLIENT_ID, + client_secret=TEST_CLIENT_SECRET, + ) + client = resource.get_client() + + base_api_mocks.add( + method=responses.GET, + url=f"{client.api_base_url}/test", + json={}, + status=200, + ) + + test_time_first_call = datetime(2024, 1, 1, 0, 0, 0) + test_time_before_expiration = datetime(2024, 1, 1, 0, 2, 0) + test_time_after_expiration = datetime(2024, 1, 1, 0, 3, 0) + with mock.patch("dagster_airbyte.resources.datetime", wraps=datetime) as dt: + # Test first call, must get the access token before calling the jobs api + dt.now.return_value = test_time_first_call + client._make_request(method="GET", endpoint="/test", base_url=client.api_base_url) # noqa + + assert len(base_api_mocks.calls) == 2 + access_token_call = base_api_mocks.calls[0] + jobs_api_call = base_api_mocks.calls[1] + + assert "Authorization" not in access_token_call.request.headers + access_token_call_body = json.loads(access_token_call.request.body.decode("utf-8")) + assert access_token_call_body["client_id"] == TEST_CLIENT_ID + assert access_token_call_body["client_secret"] == TEST_CLIENT_SECRET + assert jobs_api_call.request.headers["Authorization"] == f"Bearer {TEST_ACCESS_TOKEN}" + + base_api_mocks.calls.reset() + + # Test second call, occurs before the access token expiration, only the jobs api is called + dt.now.return_value = test_time_before_expiration + client._make_request(method="GET", endpoint="/test", base_url=client.api_base_url) # noqa + + assert len(base_api_mocks.calls) == 1 + jobs_api_call = base_api_mocks.calls[0] + + assert jobs_api_call.request.headers["Authorization"] == f"Bearer {TEST_ACCESS_TOKEN}" + + base_api_mocks.calls.reset() + + # Test third call, occurs after the token expiration, + # must refresh the access token before calling the jobs api + dt.now.return_value = test_time_after_expiration + client._make_request(method="GET", endpoint="/test", base_url=client.api_base_url) # noqa + + assert len(base_api_mocks.calls) == 2 + access_token_call = base_api_mocks.calls[0] + jobs_api_call = base_api_mocks.calls[1] + + assert "Authorization" not in access_token_call.request.headers + access_token_call_body = json.loads(access_token_call.request.body.decode("utf-8")) + assert access_token_call_body["client_id"] == TEST_CLIENT_ID + assert access_token_call_body["client_secret"] == TEST_CLIENT_SECRET + assert jobs_api_call.request.headers["Authorization"] == f"Bearer {TEST_ACCESS_TOKEN}"