Skip to content

Commit

Permalink
[14/n][dagster-airbyte] Implement AirbyteCloudWorkspace.sync_and_poll
Browse files Browse the repository at this point in the history
  • Loading branch information
maximearmstrong committed Dec 18, 2024
1 parent 71c99c0 commit 7d77d65
Show file tree
Hide file tree
Showing 10 changed files with 345 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from dagster_airbyte.resources import AirbyteCloudWorkspace
from dagster_airbyte.translator import AirbyteMetadataSet, DagsterAirbyteTranslator
from dagster_airbyte.utils import DAGSTER_AIRBYTE_TRANSLATOR_METADATA_KEY


@experimental
Expand Down Expand Up @@ -101,14 +102,18 @@ def airbyte_connection_assets(context: dg.AssetExecutionContext, airbyte: Airbyt
resources={"airbyte": airbyte_workspace},
)
"""
dagster_airbyte_translator = dagster_airbyte_translator or DagsterAirbyteTranslator()

return multi_asset(
name=name,
group_name=group_name,
can_subset=False,
can_subset=True,
specs=[
spec
spec.merge_attributes(
metadata={DAGSTER_AIRBYTE_TRANSLATOR_METADATA_KEY: dagster_airbyte_translator}
)
for spec in workspace.load_asset_specs(
dagster_airbyte_translator=dagster_airbyte_translator or DagsterAirbyteTranslator()
dagster_airbyte_translator=dagster_airbyte_translator
)
if AirbyteMetadataSet.extract(spec.metadata).connection_id == connection_id
],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import hashlib
import inspect
import os
import re
from abc import abstractmethod
from functools import partial
from itertools import chain
Expand Down Expand Up @@ -57,6 +56,7 @@
from dagster_airbyte.translator import AirbyteMetadataSet, DagsterAirbyteTranslator
from dagster_airbyte.types import AirbyteTableMetadata
from dagster_airbyte.utils import (
clean_name,
generate_materializations,
generate_table_schema,
is_basic_normalization_operation,
Expand Down Expand Up @@ -471,11 +471,6 @@ def _get_normalization_tables_for_schema(
return out


def _clean_name(name: str) -> str:
    """Normalize ``name`` so it can be used as a Dagster asset name.

    The input is lowercased, then every run of characters outside ``a-z`` and
    ``0-9`` is collapsed into a single underscore.
    """
    lowered = name.lower()
    return re.sub(r"[^a-z0-9]+", "_", lowered)


class AirbyteConnectionMetadata(
NamedTuple(
"_AirbyteConnectionMetadata",
Expand Down Expand Up @@ -917,7 +912,7 @@ def load_assets_from_airbyte_instance(
workspace_id: Optional[str] = None,
key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
create_assets_for_normalization_tables: bool = True,
connection_to_group_fn: Optional[Callable[[str], Optional[str]]] = _clean_name,
connection_to_group_fn: Optional[Callable[[str], Optional[str]]] = clean_name,
connection_meta_to_group_fn: Optional[
Callable[[AirbyteConnectionMetadata], Optional[str]]
] = None,
Expand Down Expand Up @@ -1022,7 +1017,7 @@ def load_assets_from_airbyte_instance(
check.invariant(
not connection_meta_to_group_fn
or not connection_to_group_fn
or connection_to_group_fn == _clean_name,
or connection_to_group_fn == clean_name,
"Cannot specify both connection_meta_to_group_fn and connection_to_group_fn",
)

Expand Down Expand Up @@ -1140,8 +1135,8 @@ def get_asset_spec(self, props: AirbyteConnectionTableProps) -> dg.AssetSpec:
@airbyte_assets(
connection_id=connection_id,
workspace=workspace,
name=_clean_name(connection_id),
group_name=_clean_name(connection_id),
name=clean_name(connection_id),
group_name=clean_name(connection_id),
dagster_airbyte_translator=dagster_airbyte_translator,
)
def _asset_fn(context: AssetExecutionContext, airbyte: AirbyteCloudWorkspace):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from dagster_airbyte.asset_defs import (
AirbyteConnectionMetadata,
AirbyteInstanceCacheableAssetsDefinition,
_clean_name,
)
from dagster_airbyte.managed.types import (
MANAGED_ELEMENTS_DEPRECATION_MSG,
Expand All @@ -50,7 +49,7 @@
InitializedAirbyteSource,
)
from dagster_airbyte.resources import AirbyteResource
from dagster_airbyte.utils import is_basic_normalization_operation
from dagster_airbyte.utils import is_basic_normalization_operation, clean_name


def gen_configured_stream_json(
Expand Down Expand Up @@ -746,7 +745,7 @@ def load_assets_from_connections(
connections: Iterable[AirbyteConnection],
key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
create_assets_for_normalization_tables: bool = True,
connection_to_group_fn: Optional[Callable[[str], Optional[str]]] = _clean_name,
connection_to_group_fn: Optional[Callable[[str], Optional[str]]] = clean_name,
connection_meta_to_group_fn: Optional[
Callable[[AirbyteConnectionMetadata], Optional[str]]
] = None,
Expand Down Expand Up @@ -821,7 +820,7 @@ def load_assets_from_connections(
check.invariant(
not connection_meta_to_group_fn
or not connection_to_group_fn
or connection_to_group_fn == _clean_name,
or connection_to_group_fn == clean_name,
"Cannot specify both connection_meta_to_group_fn and connection_to_group_fn",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
import requests
from dagster import (
AssetExecutionContext,
AssetMaterialization,
ConfigurableResource,
Definitions,
Failure,
InitResourceContext,
MaterializeResult,
OpExecutionContext,
_check as check,
get_dagster_logger,
Expand All @@ -34,13 +36,19 @@

from dagster_airbyte.translator import (
AirbyteConnection,
AirbyteConnectionTableProps,
AirbyteDestination,
AirbyteJob,
AirbyteJobStatusType,
AirbyteMetadataSet,
AirbyteWorkspaceData,
DagsterAirbyteTranslator,
)
from dagster_airbyte.types import AirbyteOutput
from dagster_airbyte.utils import (
get_airbyte_connection_table_name,
get_translator_from_airbyte_assets,
)

AIRBYTE_REST_API_BASE = "https://api.airbyte.com"
AIRBYTE_REST_API_VERSION = "v1"
Expand Down Expand Up @@ -1211,10 +1219,91 @@ def load_asset_specs(
workspace=self, dagster_airbyte_translator=dagster_airbyte_translator
)

def _generate_materialization(
    self,
    airbyte_output: AirbyteOutput,
    dagster_airbyte_translator: DagsterAirbyteTranslator,
):
    """Yield one ``AssetMaterialization`` per selected stream of a synced connection.

    The connection is rebuilt from ``airbyte_output.connection_details``. For each
    stream flagged as selected, the translator is asked for an asset spec and a
    materialization carrying that spec's key and metadata is yielded.

    Args:
        airbyte_output: Output of the sync; only its ``connection_details`` are read.
        dagster_airbyte_translator: Translator used to build the asset spec of each stream.

    Yields:
        An ``AssetMaterialization`` for every selected stream.
    """
    synced_connection = AirbyteConnection.from_connection_details(
        connection_details=airbyte_output.connection_details
    )

    for stream in synced_connection.streams.values():
        # Streams that are not selected produce no materialization.
        if not stream.selected:
            continue

        table_name = get_airbyte_connection_table_name(
            stream_prefix=synced_connection.stream_prefix,
            stream_name=stream.name,
        )
        # Destination details are not supplied here, so they are passed as None.
        table_props = AirbyteConnectionTableProps(
            table_name=table_name,
            stream_prefix=synced_connection.stream_prefix,
            stream_name=stream.name,
            json_schema=stream.json_schema,
            connection_id=synced_connection.id,
            connection_name=synced_connection.name,
            destination_type=None,
            database=None,
            schema=None,
        )
        spec = dagster_airbyte_translator.get_asset_spec(props=table_props)

        yield AssetMaterialization(
            asset_key=spec.key,
            description=(
                f"Table generated via Airbyte Cloud sync "
                f"for connection {synced_connection.name}: {table_name}"
            ),
            metadata=spec.metadata,
        )

def sync_and_poll(self, context: Union[OpExecutionContext, AssetExecutionContext]):
    """Executes a sync and poll process to materialize Airbyte Cloud assets.

    The translator and the connection to sync are both read from the assets
    definition attached to ``context``; the connection id is taken from the
    Airbyte metadata of the first spec. This is a generator, so the sync only
    starts once the returned iterator is consumed.

    Args:
        context (Union[OpExecutionContext, AssetExecutionContext]): The execution context
            from within `@airbyte_assets`.

    Returns:
        Iterator[Union[AssetMaterialization, MaterializeResult]]: An iterator of MaterializeResult
            or AssetMaterialization.
    """
    assets_def = context.assets_def
    dagster_airbyte_translator = get_translator_from_airbyte_assets(assets_def)

    # Use a default for `next` so that an assets definition without specs fails the
    # check below with a clear message; a bare StopIteration raised inside a
    # generator would surface as an opaque RuntimeError.
    first_connection_id = next(
        (AirbyteMetadataSet.extract(spec.metadata).connection_id for spec in assets_def.specs),
        None,
    )
    connection_id = check.not_none(
        first_connection_id,
        "Expected the assets definition to have a spec with an Airbyte connection id.",
    )

    client = self.get_client()
    airbyte_output = client.sync_and_poll(
        connection_id=connection_id,
    )

    selected_asset_keys = context.selected_asset_keys
    materialized_asset_keys = set()
    for materialization in self._generate_materialization(
        airbyte_output=airbyte_output, dagster_airbyte_translator=dagster_airbyte_translator
    ):
        # Scan through all tables actually created, if it was expected then emit a MaterializeResult.
        # Otherwise, emit a runtime AssetMaterialization.
        if materialization.asset_key in selected_asset_keys:
            yield MaterializeResult(
                asset_key=materialization.asset_key, metadata=materialization.metadata
            )
            materialized_asset_keys.add(materialization.asset_key)
        else:
            context.log.warning(
                f"An unexpected asset was materialized: {materialization.asset_key}. "
                f"Yielding a materialization event."
            )
            yield materialization

    # Report selected assets for which the sync produced no materialization.
    unmaterialized_asset_keys = selected_asset_keys - materialized_asset_keys
    if unmaterialized_asset_keys:
        context.log.warning(f"Assets were not materialized: {unmaterialized_asset_keys}")


@experimental
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class AirbyteConnectionTableProps:
json_schema: Mapping[str, Any]
connection_id: str
connection_name: str
destination_type: str
destination_type: Optional[str]
database: Optional[str]
schema: Optional[str]

Expand Down Expand Up @@ -231,5 +231,5 @@ def get_asset_spec(self, props: AirbyteConnectionTableProps) -> AssetSpec:
return AssetSpec(
key=AssetKey(props.table_name),
metadata=metadata,
kinds={"airbyte", props.destination_type},
kinds={"airbyte", *({props.destination_type} if props.destination_type else set())},
)
35 changes: 33 additions & 2 deletions python_modules/libraries/dagster-airbyte/dagster_airbyte/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
from typing import Any, Iterator, Mapping, Optional, Sequence
import re
from typing import TYPE_CHECKING, Any, Iterator, Mapping, Optional, Sequence

from dagster import AssetMaterialization, MetadataValue
from dagster import (
AssetMaterialization,
AssetsDefinition,
DagsterInvariantViolationError,
MetadataValue,
)
from dagster._core.definitions.metadata.table import TableColumn, TableSchema

from dagster_airbyte.types import AirbyteOutput

if TYPE_CHECKING:
from dagster_airbyte import DagsterAirbyteTranslator

DAGSTER_AIRBYTE_TRANSLATOR_METADATA_KEY = "dagster-airbyte/dagster_airbyte_translator"


def clean_name(name: str) -> str:
    """Return ``name`` cleaned up to be a valid Dagster asset name.

    The input is lowercased and each run of characters other than ``a-z`` and
    ``0-9`` is replaced by a single underscore.
    """
    normalized = name.lower()
    return re.sub(r"[^a-z0-9]+", "_", normalized)


def get_airbyte_connection_table_name(stream_prefix: Optional[str], stream_name: str) -> str:
    """Build a table name by prepending the stream prefix to the stream name.

    A missing or empty ``stream_prefix`` contributes nothing, so the result is
    then just ``stream_name``.
    """
    prefix = stream_prefix or ""
    return f"{prefix}{stream_name}"
Expand Down Expand Up @@ -78,3 +94,18 @@ def generate_materializations(
all_stream_stats.get(stream_name, {}),
asset_key_prefix=asset_key_prefix,
)


def get_translator_from_airbyte_assets(
    airbyte_assets: AssetsDefinition,
) -> "DagsterAirbyteTranslator":
    """Return the translator stored in the metadata of an assets definition.

    The translator is looked up under ``DAGSTER_AIRBYTE_TRANSLATOR_METADATA_KEY`` in
    the metadata of the first asset key of ``airbyte_assets.metadata_by_key``.

    Args:
        airbyte_assets: The assets definition to inspect.

    Returns:
        The translator found in the first asset's metadata.

    Raises:
        DagsterInvariantViolationError: If the definition has no per-asset metadata,
            or if the first asset's metadata does not carry a translator.
    """
    # Normalize once and use the normalized mapping everywhere, so a missing
    # mapping is treated the same as an empty one instead of raising AttributeError.
    metadata_by_key = airbyte_assets.metadata_by_key or {}
    if not metadata_by_key:
        # Without this guard, `next` on an empty iterator would leak a bare StopIteration.
        raise DagsterInvariantViolationError(
            "Expected to find airbyte translator metadata, but the assets definition has no"
            " asset metadata. Did you pass in assets that weren't generated by @airbyte_assets?"
        )

    first_asset_key = next(iter(metadata_by_key))
    first_metadata = metadata_by_key[first_asset_key]
    dagster_airbyte_translator = first_metadata.get(DAGSTER_AIRBYTE_TRANSLATOR_METADATA_KEY)
    if dagster_airbyte_translator is None:
        raise DagsterInvariantViolationError(
            f"Expected to find airbyte translator metadata on asset {first_asset_key.to_user_string()},"
            " but did not. Did you pass in assets that weren't generated by @airbyte_assets?"
        )
    return dagster_airbyte_translator
Loading

0 comments on commit 7d77d65

Please sign in to comment.