Skip to content

Commit

Permalink
[3/n][dagster-airbyte] Implement base request method in AirbyteCloudC…
Browse files Browse the repository at this point in the history
…lient (#26241)

## Summary & Motivation

This PR implements the base `_make_request` method for
AirbyteCloudClient and the methods that are required to handle the
creation of the access token.

About the access token, the logic was taken and reworked from this
[previous
implementation](#23451).
TL;DR, Airbyte now requires a client ID and secret to generate an access
token - this access token expires every 3 minutes. See more
[here](https://reference.airbyte.com/reference/portalairbytecom-deprecation).

## How I Tested These Changes

Additional tests
  • Loading branch information
maximearmstrong authored Dec 5, 2024
1 parent 0acea44 commit e24374c
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from dagster_airbyte.ops import airbyte_sync_op as airbyte_sync_op
from dagster_airbyte.resources import (
AirbyteCloudResource as AirbyteCloudResource,
AirbyteCloudWorkspace as AirbyteCloudWorkspace,
AirbyteResource as AirbyteResource,
AirbyteState as AirbyteState,
airbyte_cloud_resource as airbyte_cloud_resource,
Expand Down
156 changes: 142 additions & 14 deletions python_modules/libraries/dagster-airbyte/dagster_airbyte/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from dagster._annotations import experimental
from dagster._config.pythonic_config import infer_schema_from_config_class
from dagster._core.definitions.resource_definition import dagster_maintained_resource
from dagster._model import DagsterModel
from dagster._utils.cached_method import cached_method
from dagster._utils.merger import deep_merge_dicts
from pydantic import Field, PrivateAttr
Expand All @@ -28,6 +29,9 @@
from dagster_airbyte.translator import AirbyteWorkspaceData
from dagster_airbyte.types import AirbyteOutput

AIRBYTE_API_BASE = "https://api.airbyte.com"
AIRBYTE_API_VERSION = "v1"

DEFAULT_POLL_INTERVAL_SECONDS = 10

# The access token expire every 3 minutes in Airbyte Cloud.
Expand Down Expand Up @@ -801,35 +805,141 @@ def airbyte_cloud_resource(context) -> AirbyteCloudResource:


@experimental
class AirbyteCloudClient:
class AirbyteCloudClient(DagsterModel):
"""This class exposes methods on top of the Airbyte APIs for Airbyte Cloud."""

workspace_id: str = Field(..., description="The Airbyte workspace ID")
client_id: str = Field(..., description="The Airbyte client ID.")
client_secret: str = Field(..., description="The Airbyte client secret.")
request_max_retries: int = Field(
...,
description=(
"The maximum number of times requests to the Airbyte API should be retried "
"before failing."
),
)
request_retry_delay: float = Field(
...,
description="Time (in seconds) to wait between each request retry.",
)
request_timeout: int = Field(
...,
description="Time (in seconds) after which the requests to Airbyte are declared timed out.",
)

_access_token_value: Optional[str] = PrivateAttr(default=None)
_access_token_timestamp: Optional[float] = PrivateAttr(default=None)

def __init__(
self,
workspace_id: str,
client_id: str,
client_secret: str,
):
self.workspace_id = workspace_id
self.client_id = client_id
self.client_secret = client_secret

@property
@cached_method
def _log(self) -> logging.Logger:
return get_dagster_logger()

@property
def api_base_url(self) -> str:
raise NotImplementedError()
return f"{AIRBYTE_API_BASE}/{AIRBYTE_API_VERSION}"

@property
def all_additional_request_params(self) -> Mapping[str, Any]:
return {**self.authorization_request_params, **self.user_agent_request_params}

@property
def authorization_request_params(self) -> Mapping[str, Any]:
# Make sure the access token is refreshed before using it when calling the API.
if self._needs_refreshed_access_token():
self._refresh_access_token()
return {
"Authorization": f"Bearer {self._access_token_value}",
}

@property
def user_agent_request_params(self) -> Mapping[str, Any]:
return {
"User-Agent": "dagster",
}

def _refresh_access_token(self) -> None:
response = check.not_none(
self._make_request(
method="POST",
endpoint="/applications/token",
base_url=self.api_base_url,
data={
"client_id": self.client_id,
"client_secret": self.client_secret,
},
# Must not pass the bearer access token when refreshing it.
include_additional_request_params=False,
)
)
self._access_token_value = str(response["access_token"])
self._access_token_timestamp = datetime.now().timestamp()

def _needs_refreshed_access_token(self) -> bool:
return (
not self._access_token_value
or not self._access_token_timestamp
or self._access_token_timestamp
<= (
datetime.now() - timedelta(seconds=AIRBYTE_CLOUD_REFRESH_TIMEDELTA_SECONDS)
).timestamp()
)

def _get_session(self, include_additional_request_params: bool) -> requests.Session:
headers = {"accept": "application/json"}
if include_additional_request_params:
headers = {
**headers,
**self.all_additional_request_params,
}
session = requests.Session()
session.headers.update(headers)
return session

def _make_request(
self, method: str, endpoint: str, data: Optional[str] = None
self,
method: str,
endpoint: str,
base_url: str,
data: Optional[Mapping[str, Any]] = None,
include_additional_request_params: bool = True,
) -> Mapping[str, Any]:
raise NotImplementedError()
"""Creates and sends a request to the desired Airbyte REST API endpoint.
Args:
method (str): The http method to use for this request (e.g. "POST", "GET", "PATCH").
endpoint (str): The Airbyte API endpoint to send this request to.
base_url (str): The base url to the Airbyte API to use.
data (Optional[Dict[str, Any]]): JSON-formatted data string to be included in the request.
include_additional_request_params (bool): Whether to include authorization and user-agent headers
to the request parameters. Defaults to True.
Returns:
Dict[str, Any]: Parsed json data from the response to this request
"""
url = base_url + endpoint

num_retries = 0
while True:
try:
session = self._get_session(
include_additional_request_params=include_additional_request_params
)
response = session.request(
method=method, url=url, json=data, timeout=self.request_timeout
)
response.raise_for_status()
return response.json()
except RequestException as e:
self._log.error(
f"Request to Airbyte API failed for url {url} with method {method} : {e}"
)
if num_retries == self.request_max_retries:
break
num_retries += 1
time.sleep(self.request_retry_delay)

raise Failure(f"Max retries ({self.request_max_retries}) exceeded with url: {url}.")

def get_connections(self) -> Mapping[str, Any]:
"""Fetches all connections of an Airbyte workspace from the Airbyte API."""
Expand All @@ -849,6 +959,21 @@ class AirbyteCloudWorkspace(ConfigurableResource):
workspace_id: str = Field(..., description="The Airbyte Cloud workspace ID")
client_id: str = Field(..., description="The Airbyte Cloud client ID.")
client_secret: str = Field(..., description="The Airbyte Cloud client secret.")
request_max_retries: int = Field(
default=3,
description=(
"The maximum number of times requests to the Airbyte API should be retried "
"before failing."
),
)
request_retry_delay: float = Field(
default=0.25,
description="Time (in seconds) to wait between each request retry.",
)
request_timeout: int = Field(
default=15,
description="Time (in seconds) after which the requests to Airbyte are declared timed out.",
)

_client: AirbyteCloudClient = PrivateAttr(default=None)

Expand All @@ -858,6 +983,9 @@ def get_client(self) -> AirbyteCloudClient:
workspace_id=self.workspace_id,
client_id=self.client_id,
client_secret=self.client_secret,
request_max_retries=self.request_max_retries,
request_retry_delay=self.request_retry_delay,
request_timeout=self.request_timeout,
)

def fetch_airbyte_workspace_data(
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Iterator

import pytest
import responses
from dagster_airbyte.resources import AIRBYTE_API_BASE, AIRBYTE_API_VERSION

TEST_WORKSPACE_ID = "some_workspace_id"
TEST_CLIENT_ID = "some_client_id"
TEST_CLIENT_SECRET = "some_client_secret"

TEST_ACCESS_TOKEN = "some_access_token"

# Taken from Airbyte API documentation
# https://reference.airbyte.com/reference/createaccesstoken
SAMPLE_ACCESS_TOKEN = {"access_token": TEST_ACCESS_TOKEN}


@pytest.fixture(
name="base_api_mocks",
)
def base_api_mocks_fixture() -> Iterator[responses.RequestsMock]:
with responses.RequestsMock() as response:
response.add(
method=responses.POST,
url=f"{AIRBYTE_API_BASE}/{AIRBYTE_API_VERSION}/applications/token",
json=SAMPLE_ACCESS_TOKEN,
status=201,
)
yield response
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import json
from datetime import datetime
from unittest import mock

import responses
from dagster_airbyte import AirbyteCloudWorkspace

from dagster_airbyte_tests.experimental.conftest import (
TEST_ACCESS_TOKEN,
TEST_CLIENT_ID,
TEST_CLIENT_SECRET,
TEST_WORKSPACE_ID,
)


def test_refresh_access_token(base_api_mocks: responses.RequestsMock) -> None:
"""Tests the `AirbyteCloudClient._make_request` method and how the API access token is refreshed.
Args:
base_api_mocks (responses.RequestsMock): The mock responses for the base API requests,
i.e. generating the access token.
"""
resource = AirbyteCloudWorkspace(
workspace_id=TEST_WORKSPACE_ID,
client_id=TEST_CLIENT_ID,
client_secret=TEST_CLIENT_SECRET,
)
client = resource.get_client()

base_api_mocks.add(
method=responses.GET,
url=f"{client.api_base_url}/test",
json={},
status=200,
)

test_time_first_call = datetime(2024, 1, 1, 0, 0, 0)
test_time_before_expiration = datetime(2024, 1, 1, 0, 2, 0)
test_time_after_expiration = datetime(2024, 1, 1, 0, 3, 0)
with mock.patch("dagster_airbyte.resources.datetime", wraps=datetime) as dt:
# Test first call, must get the access token before calling the jobs api
dt.now.return_value = test_time_first_call
client._make_request(method="GET", endpoint="/test", base_url=client.api_base_url) # noqa

assert len(base_api_mocks.calls) == 2
access_token_call = base_api_mocks.calls[0]
jobs_api_call = base_api_mocks.calls[1]

assert "Authorization" not in access_token_call.request.headers
access_token_call_body = json.loads(access_token_call.request.body.decode("utf-8"))
assert access_token_call_body["client_id"] == TEST_CLIENT_ID
assert access_token_call_body["client_secret"] == TEST_CLIENT_SECRET
assert jobs_api_call.request.headers["Authorization"] == f"Bearer {TEST_ACCESS_TOKEN}"

base_api_mocks.calls.reset()

# Test second call, occurs before the access token expiration, only the jobs api is called
dt.now.return_value = test_time_before_expiration
client._make_request(method="GET", endpoint="/test", base_url=client.api_base_url) # noqa

assert len(base_api_mocks.calls) == 1
jobs_api_call = base_api_mocks.calls[0]

assert jobs_api_call.request.headers["Authorization"] == f"Bearer {TEST_ACCESS_TOKEN}"

base_api_mocks.calls.reset()

# Test third call, occurs after the token expiration,
# must refresh the access token before calling the jobs api
dt.now.return_value = test_time_after_expiration
client._make_request(method="GET", endpoint="/test", base_url=client.api_base_url) # noqa

assert len(base_api_mocks.calls) == 2
access_token_call = base_api_mocks.calls[0]
jobs_api_call = base_api_mocks.calls[1]

assert "Authorization" not in access_token_call.request.headers
access_token_call_body = json.loads(access_token_call.request.body.decode("utf-8"))
assert access_token_call_body["client_id"] == TEST_CLIENT_ID
assert access_token_call_body["client_secret"] == TEST_CLIENT_SECRET
assert jobs_api_call.request.headers["Authorization"] == f"Bearer {TEST_ACCESS_TOKEN}"

0 comments on commit e24374c

Please sign in to comment.