Skip to content

Commit

Permalink
Add factory, cached method and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
maximearmstrong committed Dec 18, 2024
1 parent bc6e943 commit aca5ae6
Show file tree
Hide file tree
Showing 7 changed files with 290 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
except ImportError:
pass

from dagster_airbyte.asset_decorator import airbyte_assets as airbyte_assets
from dagster_airbyte.asset_defs import (
build_airbyte_assets as build_airbyte_assets,
build_airbyte_assets_definitions as build_airbyte_assets_definitions,
load_assets_from_airbyte_instance as load_assets_from_airbyte_instance,
)
from dagster_airbyte.ops import airbyte_sync_op as airbyte_sync_op
Expand All @@ -28,6 +30,7 @@
load_airbyte_cloud_asset_specs as load_airbyte_cloud_asset_specs,
)
from dagster_airbyte.translator import (
AirbyteConnectionTableProps as AirbyteConnectionTableProps,
AirbyteJobStatusType as AirbyteJobStatusType,
AirbyteState as AirbyteState,
DagsterAirbyteTranslator as DagsterAirbyteTranslator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from dagster import AssetsDefinition, multi_asset
from dagster._annotations import experimental

from dagster_airbyte.resources import AirbyteCloudResource
from dagster_airbyte.resources import AirbyteCloudWorkspace
from dagster_airbyte.translator import AirbyteMetadataSet, DagsterAirbyteTranslator


@experimental
def airbyte_assets(
*,
connection_id: str,
workspace: AirbyteCloudResource,
workspace: AirbyteCloudWorkspace,
name: Optional[str] = None,
group_name: Optional[str] = None,
dagster_airbyte_translator: Optional[DagsterAirbyteTranslator] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import yaml
from dagster import (
AssetExecutionContext,
AssetKey,
AssetOut,
AutoMaterializePolicy,
Expand All @@ -33,6 +34,7 @@
SourceAsset,
_check as check,
)
from dagster._annotations import experimental
from dagster._core.definitions import AssetsDefinition, multi_asset
from dagster._core.definitions.cacheable_assets import (
AssetsDefinitionCacheableData,
Expand All @@ -45,7 +47,14 @@
from dagster._core.execution.context.init import build_init_resource_context
from dagster._utils.merger import merge_dicts

from dagster_airbyte.resources import AirbyteCloudResource, AirbyteResource, BaseAirbyteResource
from dagster_airbyte.asset_decorator import airbyte_assets
from dagster_airbyte.resources import (
AirbyteCloudResource,
AirbyteCloudWorkspace,
AirbyteResource,
BaseAirbyteResource,
)
from dagster_airbyte.translator import AirbyteMetadataSet, DagsterAirbyteTranslator
from dagster_airbyte.types import AirbyteTableMetadata
from dagster_airbyte.utils import (
generate_materializations,
Expand Down Expand Up @@ -1032,3 +1041,112 @@ def load_assets_from_airbyte_instance(
connection_to_freshness_policy_fn=connection_to_freshness_policy_fn,
connection_to_auto_materialize_policy_fn=connection_to_auto_materialize_policy_fn,
)


# -----------------------
# Reworked assets factory
# -----------------------


@experimental
def build_airbyte_assets_definitions(
*,
workspace: AirbyteCloudWorkspace,
dagster_airbyte_translator: Optional[DagsterAirbyteTranslator] = None,
) -> Sequence[AssetsDefinition]:
"""The list of AssetsDefinition for all connections in the Airbyte workspace.
Args:
workspace (AirbyteCloudWorkspace): The Airbyte workspace to fetch assets from.
dagster_airbyte_translator (Optional[DagsterAirbyteTranslator], optional): The translator to use
to convert Airbyte content into :py:class:`dagster.AssetSpec`.
Defaults to :py:class:`DagsterAirbyteTranslator`.
Returns:
List[AssetsDefinition]: The list of AssetsDefinition for all connections in the Airbyte workspace.
Examples:
Sync the tables of a Airbyte connection:
.. code-block:: python
from dagster_airbyte import AirbyteCloudWorkspace, build_airbyte_assets_definitions
import dagster as dg
airbyte_workspace = AirbyteCloudWorkspace(
workspace_id=dg.EnvVar("AIRBYTE_CLOUD_WORKSPACE_ID"),
client_id=dg.EnvVar("AIRBYTE_CLOUD_CLIENT_ID"),
client_secret=dg.EnvVar("AIRBYTE_CLOUD_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets_definitions(workspace=workspace)
defs = dg.Definitions(
assets=airbyte_assets,
resources={"airbyte": airbyte_workspace},
)
Sync the tables of a Airbyte connection with a custom translator:
.. code-block:: python
from dagster_airbyte import (
DagsterAirbyteTranslator,
AirbyteConnectionTableProps,
AirbyteCloudWorkspace,
build_airbyte_assets_definitions
)
import dagster as dg
class CustomDagsterAirbyteTranslator(DagsterAirbyteTranslator):
def get_asset_spec(self, props: AirbyteConnectionTableProps) -> dg.AssetSpec:
default_spec = super().get_asset_spec(props)
return default_spec.replace_attributes(
key=asset_spec.key.with_prefix("my_prefix"),
)
airbyte_workspace = AirbyteCloudWorkspace(
workspace_id=dg.EnvVar("AIRBYTE_CLOUD_WORKSPACE_ID"),
client_id=dg.EnvVar("AIRBYTE_CLOUD_CLIENT_ID"),
client_secret=dg.EnvVar("AIRBYTE_CLOUD_CLIENT_SECRET"),
)
airbyte_assets = build_airbyte_assets_definitions(
workspace=workspace,
dagster_airbyte_translator=CustomDagsterAirbyteTranslator()
)
defs = dg.Definitions(
assets=airbyte_assets,
resources={"airbyte": airbyte_workspace},
)
"""
dagster_airbyte_translator = dagster_airbyte_translator or DagsterAirbyteTranslator()

all_asset_specs = workspace.load_asset_specs(
dagster_airbyte_translator=dagster_airbyte_translator
)

connection_ids = {
check.not_none(AirbyteMetadataSet.extract(spec.metadata).connection_id)
for spec in all_asset_specs
}

_asset_fns = []
for connection_id in connection_ids:

@airbyte_assets(
connection_id=connection_id,
workspace=workspace,
name=_clean_name(connection_id),
group_name=_clean_name(connection_id),
dagster_airbyte_translator=dagster_airbyte_translator,
)
def _asset_fn(context: AssetExecutionContext, airbyte: AirbyteCloudWorkspace):
yield from airbyte.sync_and_poll(context=context)

_asset_fns.append(_asset_fn)

return _asset_fns
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
from abc import abstractmethod
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Any, Dict, List, Mapping, Optional, Sequence, cast
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union, cast

import requests
from dagster import (
AssetExecutionContext,
ConfigurableResource,
Definitions,
Failure,
InitResourceContext,
OpExecutionContext,
_check as check,
get_dagster_logger,
resource,
Expand Down Expand Up @@ -1172,6 +1174,49 @@ def fetch_airbyte_workspace_data(
destinations_by_id=destinations_by_id,
)

@cached_method
def load_asset_specs(
self,
dagster_airbyte_translator: Optional[DagsterAirbyteTranslator] = None,
) -> Sequence[AssetSpec]:
"""Returns a list of AssetSpecs representing the Airbyte content in the workspace.
Args:
dagster_airbyte_translator (Optional[DagsterAirbyteTranslator], optional): The translator to use
to convert Airbyte content into :py:class:`dagster.AssetSpec`.
Defaults to :py:class:`DagsterAirbyteTranslator`.
Returns:
List[AssetSpec]: The set of assets representing the Airbyte content in the workspace.
Examples:
Loading the asset specs for a given Airbyte workspace:
.. code-block:: python
from dagster_airbyte import AirbyteCloudWorkspace
import dagster as dg
airbyte_workspace = AirbyteCloudWorkspace(
workspace_id=dg.EnvVar("AIRBYTE_CLOUD_WORKSPACE_ID"),
client_id=dg.EnvVar("AIRBYTE_CLOUD_CLIENT_ID"),
client_secret=dg.EnvVar("AIRBYTE_CLOUD_CLIENT_SECRET"),
)
airbyte_specs = airbyte_workspace.load_asset_specs()
defs = dg.Definitions(assets=airbyte_specs, resources={"airbyte": airbyte_workspace}
"""
dagster_airbyte_translator = dagster_airbyte_translator or DagsterAirbyteTranslator()

return load_airbyte_cloud_asset_specs(
workspace=self, dagster_airbyte_translator=dagster_airbyte_translator
)

def sync_and_poll(
self, context: Optional[Union[OpExecutionContext, AssetExecutionContext]] = None
):
raise NotImplementedError()


@experimental
def load_airbyte_cloud_asset_specs(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
TEST_CLIENT_ID = "some_client_id"
TEST_CLIENT_SECRET = "some_client_secret"

TEST_ANOTHER_WORKSPACE_ID = "some_other_workspace_id"

TEST_ACCESS_TOKEN = "some_access_token"

# Taken from the examples in the Airbyte REST API documentation
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
import responses
from dagster._config.field_utils import EnvVar
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.tags import has_kind
from dagster._core.test_utils import environ
from dagster_airbyte import AirbyteCloudWorkspace, load_airbyte_cloud_asset_specs
from dagster_airbyte import (
AirbyteCloudWorkspace,
build_airbyte_assets_definitions,
load_airbyte_cloud_asset_specs,
)
from dagster_airbyte.translator import (
AirbyteConnectionTableProps,
AirbyteMetadataSet,
DagsterAirbyteTranslator,
)

from dagster_airbyte_tests.experimental.conftest import (
TEST_ANOTHER_WORKSPACE_ID,
TEST_CLIENT_ID,
TEST_CLIENT_SECRET,
TEST_CONNECTION_ID,
TEST_DESTINATION_TYPE,
TEST_WORKSPACE_ID,
)

Expand Down Expand Up @@ -46,3 +60,105 @@ def test_translator_spec(
# Test the asset key for the connection table
the_asset_key = next(iter(all_assets_keys))
assert the_asset_key.path == ["test_prefix_test_stream"]

first_asset_metadata = next(asset.metadata for asset in all_assets)
assert AirbyteMetadataSet.extract(first_asset_metadata).connection_id == TEST_CONNECTION_ID


def test_cached_load_spec_single_resource(
fetch_workspace_data_api_mocks: responses.RequestsMock,
) -> None:
with environ(
{"AIRBYTE_CLIENT_ID": TEST_CLIENT_ID, "AIRBYTE_CLIENT_SECRET": TEST_CLIENT_SECRET}
):
workspace = AirbyteCloudWorkspace(
workspace_id=TEST_WORKSPACE_ID,
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)

# load asset specs a first time
workspace.load_asset_specs()
assert len(fetch_workspace_data_api_mocks.calls) == 4

# load asset specs a first time, no additional calls are made
workspace.load_asset_specs()
assert len(fetch_workspace_data_api_mocks.calls) == 4


def test_cached_load_spec_multiple_resources(
fetch_workspace_data_api_mocks: responses.RequestsMock,
) -> None:
with environ(
{"AIRBYTE_CLIENT_ID": TEST_CLIENT_ID, "AIRBYTE_CLIENT_SECRET": TEST_CLIENT_SECRET}
):
workspace = AirbyteCloudWorkspace(
workspace_id=TEST_WORKSPACE_ID,
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)

another_workspace = AirbyteCloudWorkspace(
workspace_id=TEST_ANOTHER_WORKSPACE_ID,
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)

# load asset specs with a resource
workspace.load_asset_specs()
assert len(fetch_workspace_data_api_mocks.calls) == 4

# load asset specs with another resource,
# additional calls are made to load its specs
another_workspace.load_asset_specs()
assert len(fetch_workspace_data_api_mocks.calls) == 4 + 4


def test_cached_load_spec_with_asset_factory(
fetch_workspace_data_api_mocks: responses.RequestsMock,
) -> None:
with environ(
{"AIRBYTE_CLIENT_ID": TEST_CLIENT_ID, "AIRBYTE_CLIENT_SECRET": TEST_CLIENT_SECRET}
):
workspace = AirbyteCloudWorkspace(
workspace_id=TEST_WORKSPACE_ID,
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)

# build_airbyte_assets_definitions calls workspace.load_asset_specs to get the connection IDs,
# then workspace.load_asset_specs is called once per connection ID in airbyte_assets,
# but the four calls to the API are only made once.
build_airbyte_assets_definitions(workspace=workspace)
assert len(fetch_workspace_data_api_mocks.calls) == 4


class MyCustomTranslator(DagsterAirbyteTranslator):
def get_asset_spec(self, data: AirbyteConnectionTableProps) -> AssetSpec:
default_spec = super().get_asset_spec(data)
return default_spec.replace_attributes(
key=default_spec.key.with_prefix("test_connection"),
).merge_attributes(metadata={"custom": "metadata"})


def test_translator_custom_metadata(
fetch_workspace_data_api_mocks: responses.RequestsMock,
) -> None:
with environ(
{"AIRBYTE_CLIENT_ID": TEST_CLIENT_ID, "AIRBYTE_CLIENT_SECRET": TEST_CLIENT_SECRET}
):
workspace = AirbyteCloudWorkspace(
workspace_id=TEST_WORKSPACE_ID,
client_id=EnvVar("AIRBYTE_CLIENT_ID"),
client_secret=EnvVar("AIRBYTE_CLIENT_SECRET"),
)
all_asset_specs = workspace.load_asset_specs(
dagster_airbyte_translator=MyCustomTranslator()
)
asset_spec = next(spec for spec in all_asset_specs)

assert "custom" in asset_spec.metadata
assert asset_spec.metadata["custom"] == "metadata"
assert asset_spec.key.path == ["test_connection", "test_prefix_test_stream"]
assert has_kind(asset_spec.tags, "airbyte")
assert has_kind(asset_spec.tags, TEST_DESTINATION_TYPE)
Loading

0 comments on commit aca5ae6

Please sign in to comment.