Skip to content

Commit

Permalink
[dagster-airbyte] Support client_id and client_secret in AirbyteC…
Browse files Browse the repository at this point in the history
…loudResource (#23451)

## Summary & Motivation

This PR updates the `AirbyteCloudResource` to support `client_id` and
`client_secret` for authentication. The `api_key` can no longer be used
for authentication because Airbyte is [deprecating
portal.airbyte.com](https://reference.airbyte.com/reference/portalairbytecom-deprecation).

Tests and docs are updated to reflect the changes.

Two main questions for reviewers:
- this PR implements a pattern where the access token is refreshed
before making a call to the API, if the token was fetched more than 2.5
minutes ago. Should we avoid this pattern and let users manage the
resource lifecycle?
- [According to
Airbyte](https://reference.airbyte.com/reference/portalairbytecom-deprecation),
the access token expires after 3 minutes.
- The access token is initially fetched in `setup_for_execution`, then
refreshed if needed. I'm concerned that for jobs including other assets,
it might take more than 3 minutes before the Airbyte assets are
materialized.
- [portal.airbyte.com will be deprecated next
week](https://reference.airbyte.com/reference/portalairbytecom-deprecation),
on August 15th, so I removed the previous `api_key` attribute without
deprecation warning. Are we comfortable doing so? Considering that this
approach will fail next week.

## How I Tested These Changes

BK with updated tests

Fully tested on a live cloud instance with this code: 

```python
from dagster import Definitions, EnvVar
from dagster_airbyte import AirbyteCloudResource, build_airbyte_assets

airbyte_instance = AirbyteCloudResource(
    client_id=EnvVar("AIRBYTE_CLIENT_ID"),
    client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)

airbyte_assets = build_airbyte_assets(
    # Test connection - Sample Data (Faker) to Google Sheets 
    connection_id="0bb7a00c-0b85-4fac-b8ff-67dc380f1c29",
    destination_tables=["products", "purchases", "users"],
)

defs = Definitions(assets=airbyte_assets, resources={"airbyte": airbyte_instance})
```

Job successful:
<img width="1037" alt="Screenshot 2024-08-13 at 4 30 12 PM"
src="https://github.com/user-attachments/assets/46c91c95-b597-48de-bcbd-8bbb38f5f6d4">

Asset graph:
<img width="1037" alt="Screenshot 2024-08-13 at 4 30 38 PM"
src="https://github.com/user-attachments/assets/eae8ab12-0da9-4956-8ae6-10932cd0c199">
  • Loading branch information
maximearmstrong authored and clairelin135 committed Aug 13, 2024
1 parent 50ac858 commit 49e1373
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 31 deletions.
19 changes: 12 additions & 7 deletions docs/content/integrations/airbyte-cloud.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ To get started, you will need to install the `dagster` and `dagster-airbyte` Pyt
pip install dagster dagster-airbyte
```

You'll also need to have an Airbyte Cloud account, and have created an Airbyte API Key. For more information, see the [Airbyte API docs](https://reference.airbyte.com/reference/start).
You'll also need to have an Airbyte Cloud account, and have created an Airbyte client ID and client secret. For more information, see the [Airbyte API docs](https://reference.airbyte.com/reference/getting-started) and [Airbyte authentication guide](https://reference.airbyte.com/reference/authentication).

---

Expand All @@ -58,11 +58,12 @@ from dagster import EnvVar
from dagster_airbyte import AirbyteCloudResource

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
```

Here, the API key is provided using an <PyObject object="EnvVar" />. For more information on setting environment variables in a production setting, see [Using environment variables and secrets](/guides/dagster/using-environment-variables-and-secrets).
Here, the client ID and client secret are provided using an <PyObject object="EnvVar" />. For more information on setting environment variables in a production setting, see [Using environment variables and secrets](/guides/dagster/using-environment-variables-and-secrets).

---

Expand Down Expand Up @@ -104,7 +105,8 @@ from dagster_airbyte import build_airbyte_assets, AirbyteCloudResource
from dagster import Definitions, EnvVar

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets(
connection_id="43908042-8399-4a58-82f1-71a45099fff7",
Expand Down Expand Up @@ -153,7 +155,8 @@ from dagster_snowflake_pandas import SnowflakePandasIOManager
import pandas as pd

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets(
connection_id="43908042-8399-4a58-82f1-71a45099fff7",
Expand Down Expand Up @@ -207,7 +210,8 @@ from dagster_airbyte import (
from dagster_snowflake import SnowflakeResource

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets(
connection_id="43908042-8399-4a58-82f1-71a45099fff7",
Expand Down Expand Up @@ -261,7 +265,8 @@ from dagster import (
)

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets(
connection_id="43908042-8399-4a58-82f1-71a45099fff7",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def scope_define_cloud_instance() -> None:
from dagster_airbyte import AirbyteCloudResource

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
# end_define_cloud_instance

Expand Down Expand Up @@ -126,7 +127,8 @@ def scope_airbyte_cloud_manual_config():
from dagster import Definitions, EnvVar

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets(
connection_id="43908042-8399-4a58-82f1-71a45099fff7",
Expand Down Expand Up @@ -257,7 +259,8 @@ def scope_add_downstream_assets_cloud():
import pandas as pd

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets(
connection_id="43908042-8399-4a58-82f1-71a45099fff7",
Expand Down Expand Up @@ -310,7 +313,8 @@ def scope_add_downstream_assets_cloud_with_deps():
from dagster_snowflake import SnowflakeResource

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets(
connection_id="43908042-8399-4a58-82f1-71a45099fff7",
Expand Down Expand Up @@ -400,7 +404,8 @@ def scope_schedule_assets_cloud():
)

airbyte_instance = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets(
connection_id="43908042-8399-4a58-82f1-71a45099fff7",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import time
from abc import abstractmethod
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Any, Dict, List, Mapping, Optional, cast

import requests
from dagster import (
ConfigurableResource,
Failure,
InitResourceContext,
_check as check,
get_dagster_logger,
resource,
Expand All @@ -19,13 +21,17 @@
from dagster._core.definitions.resource_definition import dagster_maintained_resource
from dagster._utils.cached_method import cached_method
from dagster._utils.merger import deep_merge_dicts
from pydantic import Field
from pydantic import Field, PrivateAttr
from requests.exceptions import RequestException

from dagster_airbyte.types import AirbyteOutput

DEFAULT_POLL_INTERVAL_SECONDS = 10

# The access token expire every 3 minutes in Airbyte Cloud.
# Refresh is needed after 2.5 minutes to avoid the "token expired" error message.
AIRBYTE_CLOUD_REFRESH_TIMEDELTA_SECONDS = 150


class AirbyteState:
RUNNING = "running"
Expand Down Expand Up @@ -94,7 +100,11 @@ def all_additional_request_params(self) -> Mapping[str, Any]:
raise NotImplementedError()

def make_request(
self, endpoint: str, data: Optional[Mapping[str, object]] = None, method: str = "POST"
self,
endpoint: str,
data: Optional[Mapping[str, object]] = None,
method: str = "POST",
include_additional_request_params: bool = True,
) -> Optional[Mapping[str, object]]:
"""Creates and sends a request to the desired Airbyte REST API endpoint.
Expand All @@ -120,10 +130,11 @@ def make_request(
if data:
request_args["json"] = data

request_args = deep_merge_dicts(
request_args,
self.all_additional_request_params,
)
if include_additional_request_params:
request_args = deep_merge_dicts(
request_args,
self.all_additional_request_params,
)

response = requests.request(
**request_args,
Expand Down Expand Up @@ -244,7 +255,7 @@ def sync_and_poll(


class AirbyteCloudResource(BaseAirbyteResource):
"""This resource allows users to programatically interface with the Airbyte Cloud API to launch
"""This resource allows users to programmatically interface with the Airbyte Cloud API to launch
syncs and monitor their progress.
**Examples:**
Expand All @@ -255,7 +266,8 @@ class AirbyteCloudResource(BaseAirbyteResource):
from dagster_airbyte import AirbyteResource
my_airbyte_resource = AirbyteCloudResource(
api_key=EnvVar("AIRBYTE_API_KEY"),
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets(
Expand All @@ -269,15 +281,48 @@ class AirbyteCloudResource(BaseAirbyteResource):
)
"""

api_key: str = Field(..., description="The Airbyte Cloud API key.")
client_id: str = Field(..., description="The Airbyte Cloud client ID.")
client_secret: str = Field(..., description="The Airbyte Cloud client secret.")

_access_token_value: Optional[str] = PrivateAttr(default=None)
_access_token_timestamp: Optional[float] = PrivateAttr(default=None)

def setup_for_execution(self, context: InitResourceContext) -> None:
# Refresh access token when the resource is initialized
self._refresh_access_token()

@property
def api_base_url(self) -> str:
return "https://api.airbyte.com/v1"

@property
def all_additional_request_params(self) -> Mapping[str, Any]:
return {"headers": {"Authorization": f"Bearer {self.api_key}", "User-Agent": "dagster"}}
# Make sure the access token is refreshed before using it when calling the API.
if self._needs_refreshed_access_token():
self._refresh_access_token()
return {
"headers": {
"Authorization": f"Bearer {self._access_token_value}",
"User-Agent": "dagster",
}
}

def make_request(
self,
endpoint: str,
data: Optional[Mapping[str, object]] = None,
method: str = "POST",
include_additional_request_params: bool = True,
) -> Optional[Mapping[str, object]]:
# Make sure the access token is refreshed before using it when calling the API.
if include_additional_request_params and self._needs_refreshed_access_token():
self._refresh_access_token()
return super().make_request(
endpoint=endpoint,
data=data,
method=method,
include_additional_request_params=include_additional_request_params,
)

def start_sync(self, connection_id: str) -> Mapping[str, object]:
job_sync = check.not_none(
Expand Down Expand Up @@ -306,6 +351,31 @@ def _should_forward_logs(self) -> bool:
# Airbyte Cloud does not support streaming logs yet
return False

def _refresh_access_token(self) -> None:
response = check.not_none(
self.make_request(
endpoint="/applications/token",
data={
"client_id": self.client_id,
"client_secret": self.client_secret,
},
# Must not pass the bearer access token when refreshing it.
include_additional_request_params=False,
)
)
self._access_token_value = str(response["access_token"])
self._access_token_timestamp = datetime.now().timestamp()

def _needs_refreshed_access_token(self) -> bool:
return (
not self._access_token_value
or not self._access_token_timestamp
or self._access_token_timestamp
<= datetime.timestamp(
datetime.now() - timedelta(seconds=AIRBYTE_CLOUD_REFRESH_TIMEDELTA_SECONDS)
)
)


class AirbyteResource(BaseAirbyteResource):
"""This resource allows users to programatically interface with the Airbyte REST API to launch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ def test_assets_with_normalization(


def test_assets_cloud() -> None:
ab_resource = AirbyteCloudResource(api_key="some_key", poll_interval=0)
ab_resource = AirbyteCloudResource(
client_id="some_client_id", client_secret="some_client_secret", poll_interval=0
)
ab_url = ab_resource.api_base_url

ab_assets = build_airbyte_assets(
Expand All @@ -220,6 +222,11 @@ def test_assets_cloud() -> None:
)

with responses.RequestsMock() as rsps:
rsps.add(
rsps.POST,
f"{ab_url}/applications/token",
json={"access_token": "some_access_token"},
)
rsps.add(
rsps.POST,
f"{ab_url}/jobs",
Expand Down
Loading

0 comments on commit 49e1373

Please sign in to comment.