Skip to content

Commit

Permalink
Add get_connection_details
Browse files Browse the repository at this point in the history
  • Loading branch information
maximearmstrong committed Dec 3, 2024
1 parent 604587b commit af60e39
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@
from dagster_airbyte.translator import AirbyteWorkspaceData
from dagster_airbyte.types import AirbyteOutput

AIRBYTE_API_BASE = "https://api.airbyte.com"
AIRBYTE_API_VERSION = "v1"
AIRBYTE_REST_API_BASE = "https://api.airbyte.com"
AIRBYTE_REST_API_VERSION = "v1"

AIRBYTE_SERVER_API_BASE = "https://cloud.airbyte.com/api"
AIRBYTE_SERVER_API_VERSION = "v1"

DEFAULT_POLL_INTERVAL_SECONDS = 10

Expand Down Expand Up @@ -836,8 +839,12 @@ def _log(self) -> logging.Logger:
return get_dagster_logger()

@property
def api_base_url(self) -> str:
return f"{AIRBYTE_API_BASE}/{AIRBYTE_API_VERSION}"
def rest_api_base_url(self) -> str:
return f"{AIRBYTE_REST_API_BASE}/{AIRBYTE_REST_API_VERSION}"

@property
def server_api_base_url(self) -> str:
return f"{AIRBYTE_SERVER_API_BASE}/{AIRBYTE_SERVER_API_VERSION}"

@property
def all_additional_request_params(self) -> Mapping[str, Any]:
Expand Down Expand Up @@ -867,7 +874,7 @@ def _refresh_access_token(self) -> None:
self._make_request(
method="POST",
endpoint="applications/token",
base_url=self.api_base_url,
base_url=self.rest_api_base_url,
data={
"client_id": self.client_id,
"client_secret": self.client_secret,
Expand Down Expand Up @@ -895,6 +902,7 @@ def _make_request(
endpoint: str,
base_url: str,
data: Optional[Mapping[str, Any]] = None,
params: Optional[Mapping[str, Any]] = None,
include_additional_request_params: bool = True,
) -> Mapping[str, Any]:
"""Creates and sends a request to the desired Airbyte REST API endpoint.
Expand All @@ -904,6 +912,7 @@ def _make_request(
endpoint (str): The Airbyte API endpoint to send this request to.
base_url (str): The base url to the Airbyte API to use.
data (Optional[Dict[str, Any]]): JSON-formatted data string to be included in the request.
params (Optional[Dict[str, Any]]): JSON-formatted query params to be included in the request.
include_additional_request_params (bool): Whether to include authorization and user-agent headers
to the request parameters. Defaults to True.
Expand All @@ -925,6 +934,9 @@ def _make_request(
if data:
request_args["json"] = data

if params:
request_args["params"] = params

if include_additional_request_params:
request_args = deep_merge_dicts(
request_args,
Expand All @@ -946,20 +958,31 @@ def _make_request(
raise Failure(f"Max retries ({self.request_max_retries}) exceeded with url: {url}.")

def get_connections(self) -> Mapping[str, Any]:
"""Fetches all connections of an Airbyte workspace from the Airbyte API."""
"""Fetches all connections of an Airbyte workspace from the Airbyte REST API."""
return self._make_request(
method="GET",
endpoint="connections",
base_url=self.api_base_url,
data={"workspaceIds": [self.workspace_id]},
base_url=self.rest_api_base_url,
params={"workspaceIds": self.workspace_id},
)

def get_connection_details(self, connection_id) -> Mapping[str, Any]:
"""Fetches all connections of an Airbyte workspace from the Airbyte Server API."""
# Using the Server API to get the connection details, including streams and their configs.
# https://airbyte-public-api-docs.s3.us-east-2.amazonaws.com/rapidoc-api-docs.html#post-/v1/connections/get
return self._make_request(
method="POST",
endpoint="connections/get",
base_url=self.server_api_base_url,
data={"connectionId": connection_id},
)

def get_destination_details(self, destination_id: str) -> Mapping[str, Any]:
"""Fetches details about a given destination from the Airbyte API."""
"""Fetches details about a given destination from the Airbyte REST API."""
return self._make_request(
method="GET",
endpoint=f"destinations/{destination_id}",
base_url=self.api_base_url,
base_url=self.rest_api_base_url,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,34 @@

import pytest
import responses
from dagster_airbyte.resources import AIRBYTE_API_BASE, AIRBYTE_API_VERSION
from dagster_airbyte.resources import AIRBYTE_REST_API_BASE, AIRBYTE_REST_API_VERSION, AIRBYTE_SERVER_API_BASE, AIRBYTE_SERVER_API_VERSION

TEST_WORKSPACE_ID = "some_workspace_id"
TEST_CLIENT_ID = "some_client_id"
TEST_CLIENT_SECRET = "some_client_secret"


# Taken from Airbyte API documentation
TEST_DESTINATION_ID = "18dccc91-0ab1-4f72-9ed7-0b8fc27c5826"
TEST_CONNECTION_ID = "9924bcd0-99be-453d-ba47-c2c9766f7da5"


# Taken from Airbyte REST API documentation
# https://reference.airbyte.com/reference/createaccesstoken
SAMPLE_ACCESS_TOKEN = {"access_token": "some_access_token"}


# Taken from Airbyte API documentation
# Taken from Airbyte REST API documentation
# https://reference.airbyte.com/reference/listconnections
SAMPLE_CONNECTIONS = {
"next": "https://api.airbyte.com/v1/connections?limit=5&offset=10",
"previous": "https://api.airbyte.com/v1/connections?limit=5&offset=0",
"data": [
{
"connectionId": "9924bcd0-99be-453d-ba47-c2c9766f7da5",
"connectionId": TEST_CONNECTION_ID,
"workspaceId": "744cc0ed-7f05-4949-9e60-2a814f90c035",
"name": "Postgres To Snowflake",
"sourceId": "0c31738c-0b2d-4887-b506-e2cd1c39cc35",
"destinationId": "18dccc91-0ab1-4f72-9ed7-0b8fc27c5826",
"destinationId": TEST_DESTINATION_ID,
"status": "active",
"schedule": {
"schedule_type": "cron",
Expand All @@ -35,10 +39,123 @@
}


# Taken from Airbyte Server API documentation
# https://airbyte-public-api-docs.s3.us-east-2.amazonaws.com/rapidoc-api-docs.html#post-/v1/connections/get
SAMPLE_CONNECTION_DETAILS = {
"connectionId": TEST_CONNECTION_ID,
"name": "string",
"namespaceDefinition": "source",
"namespaceFormat": "${SOURCE_NAMESPACE}",
"prefix": "string",
"sourceId": "0c31738c-0b2d-4887-b506-e2cd1c39cc35",
"destinationId": TEST_DESTINATION_ID,
"operationIds": [
"1938d12e-b540-4000-8c46-1be33f00ab01"
],
"syncCatalog": {
"streams": [
{
"stream": {
"name": "string",
"jsonSchema": {},
"supportedSyncModes": [
"full_refresh"
],
"sourceDefinedCursor": False,
"defaultCursorField": [
"string"
],
"sourceDefinedPrimaryKey": [
[
"string"
]
],
"namespace": "string",
"isResumable": False
},
"config": {
"syncMode": "full_refresh",
"cursorField": [
"string"
],
"destinationSyncMode": "append",
"primaryKey": [
[
"string"
]
],
"aliasName": "string",
"selected": False,
"suggested": False,
"fieldSelectionEnabled": False,
"selectedFields": [
{
"fieldPath": [
"string"
]
}
],
"hashedFields": [
{
"fieldPath": [
"string"
]
}
],
"mappers": [
{
"id": "1938d12e-b540-4000-8ff0-46231e18f301",
"type": "hashing",
"mapperConfiguration": {}
}
],
"minimumGenerationId": 0,
"generationId": 0,
"syncId": 0
}
}
]
},
"schedule": {
"units": 0,
"timeUnit": "minutes"
},
"scheduleType": "manual",
"scheduleData": {
"basicSchedule": {
"timeUnit": "minutes",
"units": 0
},
"cron": {
"cronExpression": "string",
"cronTimeZone": "string"
}
},
"status": "active",
"resourceRequirements": {
"cpu_request": "string",
"cpu_limit": "string",
"memory_request": "string",
"memory_limit": "string",
"ephemeral_storage_request": "string",
"ephemeral_storage_limit": "string"
},
"sourceCatalogId": "1938d12e-b540-4000-85a4-7ecc2445a901",
"geography": "auto",
"breakingChange": False,
"notifySchemaChanges": False,
"notifySchemaChangesByEmail": False,
"nonBreakingChangesPreference": "ignore",
"created_at": 0,
"backfillPreference": "enabled",
"workspaceId": "744cc0ed-7f05-4949-9e60-2a814f90c035"
}


# Taken from Airbyte API documentation
# https://reference.airbyte.com/reference/getdestination
SAMPLE_DESTINATION_DETAILS = {
"destinationId": "18dccc91-0ab1-4f72-9ed7-0b8fc27c5826",
"destinationId": TEST_DESTINATION_ID,
"name": "My Destination",
"sourceType": "postgres",
"workspaceId": "744cc0ed-7f05-4949-9e60-2a814f90c035",
Expand All @@ -51,19 +168,14 @@
}


@pytest.fixture(name="destination_id")
def destination_id_fixture() -> str:
return "18dccc91-0ab1-4f72-9ed7-0b8fc27c5826"


@pytest.fixture(
name="base_api_mocks",
)
def base_api_mocks_fixture() -> Iterator[responses.RequestsMock]:
with responses.RequestsMock() as response:
response.add(
method=responses.POST,
url=f"{AIRBYTE_API_BASE}/{AIRBYTE_API_VERSION}/applications/token",
url=f"{AIRBYTE_REST_API_BASE}/{AIRBYTE_REST_API_VERSION}/applications/token",
json=SAMPLE_ACCESS_TOKEN,
status=201,
)
Expand All @@ -74,18 +186,23 @@ def base_api_mocks_fixture() -> Iterator[responses.RequestsMock]:
name="fetch_workspace_data_api_mocks",
)
def fetch_workspace_data_api_mocks_fixture(
destination_id: str,
base_api_mocks: responses.RequestsMock,
) -> Iterator[responses.RequestsMock]:
base_api_mocks.add(
method=responses.GET,
url=f"{AIRBYTE_API_BASE}/{AIRBYTE_API_VERSION}/connections",
url=f"{AIRBYTE_REST_API_BASE}/{AIRBYTE_REST_API_VERSION}/connections",
json=SAMPLE_CONNECTIONS,
status=200,
)
base_api_mocks.add(
method=responses.POST,
url=f"{AIRBYTE_SERVER_API_BASE}/{AIRBYTE_SERVER_API_VERSION}/connections/get",
json=SAMPLE_CONNECTION_DETAILS,
status=200,
)
base_api_mocks.add(
method=responses.GET,
url=f"{AIRBYTE_API_BASE}/{AIRBYTE_API_VERSION}/destinations/{destination_id}",
url=f"{AIRBYTE_REST_API_BASE}/{AIRBYTE_REST_API_VERSION}/destinations/{TEST_DESTINATION_ID}",
json=SAMPLE_DESTINATION_DETAILS,
status=200,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
TEST_CLIENT_ID,
TEST_CLIENT_SECRET,
TEST_WORKSPACE_ID,
TEST_DESTINATION_ID,
TEST_CONNECTION_ID,
)


Expand All @@ -22,7 +24,7 @@ def test_refresh_access_token(base_api_mocks: responses.RequestsMock) -> None:

base_api_mocks.add(
method=responses.GET,
url=f"{client.api_base_url}/test",
url=f"{client.rest_api_base_url}/test",
json={},
status=200,
)
Expand All @@ -33,7 +35,7 @@ def test_refresh_access_token(base_api_mocks: responses.RequestsMock) -> None:
with mock.patch("dagster_airbyte.resources.datetime", wraps=datetime) as dt:
# Test first call, must get the access token before calling the jobs api
dt.now.return_value = test_time_first_call
client._make_request(method="GET", endpoint="test", base_url=client.api_base_url) # noqa
client._make_request(method="GET", endpoint="test", base_url=client.rest_api_base_url) # noqa

assert len(base_api_mocks.calls) == 2
access_token_call = base_api_mocks.calls[0]
Expand All @@ -49,7 +51,7 @@ def test_refresh_access_token(base_api_mocks: responses.RequestsMock) -> None:

# Test second call, occurs before the access token expiration, only the jobs api is called
dt.now.return_value = test_time_before_expiration
client._make_request(method="GET", endpoint="test", base_url=client.api_base_url) # noqa
client._make_request(method="GET", endpoint="test", base_url=client.rest_api_base_url) # noqa

assert len(base_api_mocks.calls) == 1
jobs_api_call = base_api_mocks.calls[0]
Expand All @@ -61,7 +63,7 @@ def test_refresh_access_token(base_api_mocks: responses.RequestsMock) -> None:
# Test third call, occurs after the token expiration,
# must refresh the access token before calling the jobs api
dt.now.return_value = test_time_after_expiration
client._make_request(method="GET", endpoint="test", base_url=client.api_base_url) # noqa
client._make_request(method="GET", endpoint="test", base_url=client.rest_api_base_url) # noqa

assert len(base_api_mocks.calls) == 2
access_token_call = base_api_mocks.calls[0]
Expand All @@ -75,7 +77,6 @@ def test_refresh_access_token(base_api_mocks: responses.RequestsMock) -> None:


def test_basic_resource_request(
destination_id: str,
fetch_workspace_data_api_mocks: responses.RequestsMock,
) -> None:
resource = AirbyteCloudWorkspace(
Expand All @@ -87,11 +88,14 @@ def test_basic_resource_request(

# fetch workspace data calls
client.get_connections()
client.get_destination_details(destination_id=destination_id)
client.get_connection_details(connection_id=TEST_CONNECTION_ID)
client.get_destination_details(destination_id=TEST_DESTINATION_ID)

assert len(fetch_workspace_data_api_mocks.calls) == 3
assert len(fetch_workspace_data_api_mocks.calls) == 4
# The first call is to create the access token
assert "Authorization" not in fetch_workspace_data_api_mocks.calls[0].request.headers
# The two next calls are actual API calls
assert "connections" in fetch_workspace_data_api_mocks.calls[1].request.url
assert f"destinations/{destination_id}" in fetch_workspace_data_api_mocks.calls[2].request.url
assert "connections/get" in fetch_workspace_data_api_mocks.calls[2].request.url
assert TEST_CONNECTION_ID in fetch_workspace_data_api_mocks.calls[2].request.body.decode()
assert f"destinations/{TEST_DESTINATION_ID}" in fetch_workspace_data_api_mocks.calls[3].request.url

0 comments on commit af60e39

Please sign in to comment.