[dagster-airlift] assets_from_task (#23813)
Adds a simple defs builder that ingests a list of specs for a particular
dag/task id, making the observe step more straightforward.
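
A rough usage sketch (the dag/task ids and asset keys below are illustrative, not taken from this change):

from dagster import AssetSpec, Definitions, asset
from dagster_airlift.core import combine_defs, specs_from_task

# Plain keys are coerced to AssetSpecs tagged with the owning dag/task id;
# anything that is already an AssetSpec passes through unchanged.
specs = specs_from_task(
    task_id="load_iris",      # illustrative task id
    dag_id="load_lakehouse",  # illustrative dag id
    assets=["iris_table", AssetSpec(key="iris_summary")],
)

@asset
def downstream() -> None: ...

# Merge the specs and any other definitions into a single Definitions object.
defs = combine_defs(*specs, Definitions(assets=[downstream]))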
dpeng817 authored Aug 23, 2024
1 parent 3e179cd commit 72023ef
Showing 10 changed files with 119 additions and 67 deletions.
@@ -4,6 +4,10 @@
     DefsFactory as DefsFactory,
     defs_from_factories as defs_from_factories,
 )
+from .defs_builders import (
+    combine_defs as combine_defs,
+    specs_from_task as specs_from_task,
+)
 from .defs_from_airflow import (
     AirflowInstance as AirflowInstance,
     build_defs_from_airflow_instance as build_defs_from_airflow_instance,
@@ -0,0 +1,43 @@
+from typing import Sequence, Union
+
+from dagster import AssetsDefinition, AssetSpec
+from dagster._core.definitions.asset_key import CoercibleToAssetKey
+from dagster._core.definitions.definitions_class import Definitions
+from typing_extensions import TypeAlias
+
+from .utils import DAG_ID_TAG, TASK_ID_TAG
+
+CoercibleToAssetSpec: TypeAlias = Union[AssetSpec, CoercibleToAssetKey]
+
+
+def specs_from_task(
+    *, task_id: str, dag_id: str, assets: Sequence[CoercibleToAssetSpec]
+) -> Sequence[AssetSpec]:
+    """Construct a sequence of :py:class:`AssetSpec` objects from the provided assets,
+    tagging each coerced asset key with the airflow dag and task that produce it.
+    """
+    return [
+        asset
+        if isinstance(asset, AssetSpec)
+        else AssetSpec(key=asset, tags={DAG_ID_TAG: dag_id, TASK_ID_TAG: task_id})
+        for asset in assets
+    ]
+
+
+def combine_defs(*defs: Union[AssetsDefinition, Definitions, AssetSpec]) -> Definitions:
+    """Combine the provided :py:class:`Definitions` objects and assets into a single object, which contains all constituent definitions."""
+    assets = []
+    for _def in defs:
+        if isinstance(_def, Definitions):
+            continue
+        elif isinstance(_def, AssetsDefinition):
+            assets.append(_def)
+        elif isinstance(_def, AssetSpec):
+            assets.append(_def)
+        else:
+            raise Exception(f"Unexpected type: {type(_def)}")
+
+    return Definitions.merge(
+        *[the_def for the_def in defs if isinstance(the_def, Definitions)],
+        Definitions(assets=assets),
+    )
@@ -1,6 +1,7 @@
 import pytest
-from dagster import AssetKey, AssetSpec, asset, multi_asset
+from dagster import AssetKey, AssetSpec, Definitions, asset, multi_asset
 from dagster._check.functions import CheckError
+from dagster_airlift.core import combine_defs, specs_from_task
 from dagster_airlift.core.utils import get_dag_id_from_asset, get_task_id_from_asset
 
 
@@ -145,3 +146,33 @@ def other_dag__other_task():
 
     with pytest.raises(CheckError):
         get_task_id_from_asset(other_dag__other_task)
+
+
+def test_specs_to_tasks() -> None:
+    """Tests basic conversion of specs to tasks."""
+    specs = ["1", AssetSpec(key=AssetKey(["2"]))]
+    defs = specs_from_task(task_id="task", dag_id="dag", assets=specs)
+    assert all(isinstance(_def, AssetSpec) for _def in defs)
+    assert len(list(defs)) == 2
+    spec = next(iter(defs))
+    assert spec.tags["airlift/dag_id"] == "dag"
+
+
+def test_combine_defs() -> None:
+    """Tests functionality of combine_defs."""
+
+    @asset
+    def a():
+        pass
+
+    @asset
+    def b():
+        pass
+
+    @asset
+    def c():
+        pass
+
+    defs = combine_defs(a, Definitions(assets=[b]), Definitions(assets=[c]))
+    assert defs.assets
+    assert len(list(defs.assets)) == 3
@@ -10,34 +10,18 @@
 from dbt_example.shared.load_iris import load_csv_to_duckdb
 
 
-class LoadCSVToDuckDB(BaseOperator):
-    def __init__(
-        self,
-        table_name: str,
-        csv_path: Path,
-        duckdb_path: Path,
-        column_names: List[str],
-        duckdb_schema: str,
-        duckdb_database_name: str,
-        *args,
-        **kwargs,
-    ):
-        self._table_name = table_name
+class LoadToLakehouseOperator(BaseOperator):
+    def __init__(self, csv_path: Path, db_path: Path, columns: List[str], *args, **kwargs):
         self._csv_path = csv_path
-        self._duckdb_path = duckdb_path
-        self._column_names = column_names
-        self._duckdb_schema = duckdb_schema
-        self._duckdb_database_name = duckdb_database_name
+        self._db_path = db_path
+        self._column_names = columns
         super().__init__(*args, **kwargs)
 
     def execute(self, context) -> None:
         load_csv_to_duckdb(
-            table_name=self._table_name,
             csv_path=self._csv_path,
-            duckdb_path=self._duckdb_path,
-            names=self._column_names,
-            duckdb_schema=self._duckdb_schema,
-            duckdb_database_name=self._duckdb_database_name,
+            db_path=self._db_path,
+            columns=self._column_names,
        )
 
 
@@ -49,21 +33,18 @@ def execute(self, context) -> None:
 }
 
 dag = DAG("load_lakehouse", default_args=default_args, schedule_interval=None)
-load_iris = LoadCSVToDuckDB(
+load_iris = LoadToLakehouseOperator(
     task_id="load_iris",
     dag=dag,
-    table_name="iris_lakehouse_table",
     csv_path=Path(__file__).parent / "iris.csv",
-    duckdb_path=Path(os.environ["AIRFLOW_HOME"]) / "jaffle_shop.duckdb",
-    column_names=[
+    db_path=Path(os.environ["AIRFLOW_HOME"]) / "jaffle_shop.duckdb",
+    columns=[
         "sepal_length_cm",
         "sepal_width_cm",
         "petal_length_cm",
         "petal_width_cm",
         "species",
     ],
-    duckdb_schema="iris_dataset",
-    duckdb_database_name="jaffle_shop",
 )
 mark_as_dagster_migrating(
     global_vars=globals(),
@@ -5,34 +5,28 @@
 from dagster import AssetKey, AssetSpec, Definitions, multi_asset
 from dagster_airlift.core import DefsFactory
 
-from dbt_example.shared.load_iris import load_csv_to_duckdb
+from dbt_example.shared.load_iris import id_from_path, load_csv_to_duckdb
 
 
 @dataclass
 class CSVToDuckdbDefs(DefsFactory):
-    table_name: str
-    duckdb_schema: str
     name: str
-    csv_path: Optional[Path] = None
+    csv_path: Path
     duckdb_path: Optional[Path] = None
-    column_names: Optional[List[str]] = None
-    duckdb_database_name: Optional[str] = None
+    columns: Optional[List[str]] = None
 
     def build_defs(self) -> Definitions:
-        asset_spec = AssetSpec(key=AssetKey([self.duckdb_schema, self.table_name]))
+        asset_spec = AssetSpec(key=AssetKey(["lakehouse", id_from_path(self.csv_path)]))
 
         @multi_asset(specs=[asset_spec], name=self.name)
-        def _multi_asset():
-            if self.duckdb_path is None:
+        def _multi_asset() -> None:
+            if self.duckdb_path is None or self.columns is None:
                 raise Exception("This asset is not yet executable. Need to provide a duckdb_path.")
             else:
                 load_csv_to_duckdb(
-                    table_name=self.table_name,
                     csv_path=self.csv_path,
-                    duckdb_path=self.duckdb_path,
-                    names=self.column_names,
-                    duckdb_schema=self.duckdb_schema,
-                    duckdb_database_name=self.duckdb_database_name,
+                    db_path=self.duckdb_path,
+                    columns=self.columns,
                 )
 
         return Definitions(assets=[_multi_asset])
@@ -6,7 +6,7 @@
 from dagster_airlift.dbt import DbtProjectDefs
 from dagster_dbt import DbtProject
 
-from dbt_example.dagster_defs.csv_to_duckdb_defs import CSVToDuckdbDefs
+from dbt_example.dagster_defs.lakehouse import CSVToDuckdbDefs
 
 from .constants import AIRFLOW_BASE_URL, AIRFLOW_INSTANCE_NAME, PASSWORD, USERNAME, dbt_project_path
 
@@ -23,18 +23,15 @@
     orchestrated_defs=defs_from_factories(
         CSVToDuckdbDefs(
             name="load_lakehouse__load_iris",
-            table_name="iris_lakehouse_table",
             csv_path=Path("iris.csv"),
             duckdb_path=Path(os.environ["AIRFLOW_HOME"]) / "jaffle_shop.duckdb",
-            column_names=[
+            columns=[
                 "sepal_length_cm",
                 "sepal_width_cm",
                 "petal_length_cm",
                 "petal_width_cm",
                 "species",
             ],
-            duckdb_schema="iris_dataset",
-            duckdb_database_name="jaffle_shop",
         ),
         DbtProjectDefs(
             name="dbt_dag__build_dbt_models",
@@ -1,8 +1,10 @@
+from pathlib import Path
+
 from dagster_airlift.core import AirflowInstance, BasicAuthBackend, build_defs_from_airflow_instance
 from dagster_airlift.core.def_factory import defs_from_factories
 from dagster_airlift.dbt import DbtProjectDefs
 
-from dbt_example.dagster_defs.csv_to_duckdb_defs import CSVToDuckdbDefs
+from dbt_example.dagster_defs.lakehouse import CSVToDuckdbDefs
 
 from .constants import AIRFLOW_BASE_URL, AIRFLOW_INSTANCE_NAME, PASSWORD, USERNAME, dbt_project_path
 
@@ -19,8 +21,7 @@
     orchestrated_defs=defs_from_factories(
         CSVToDuckdbDefs(
             name="load_lakehouse__load_iris",
-            table_name="iris_lakehouse_table",
-            duckdb_schema="iris_dataset",
+            csv_path=Path("iris.csv"),
         ),
         DbtProjectDefs(
             name="dbt_dag__build_dbt_models",
@@ -1 +1 @@
-select * from {{ source('iris_dataset', 'iris_lakehouse_table') }} where species = 'Iris-setosa'
+select * from {{ source('lakehouse', 'iris_lakehouse_table') }} where species = 'Iris-setosa'
@@ -1,8 +1,8 @@
 version: 2
 
 sources:
-  - name: iris_dataset
+  - name: lakehouse
     database: jaffle_shop
-    schema: iris_dataset
+    schema: lakehouse
     tables:
       - name: iris_lakehouse_table
@@ -5,30 +5,31 @@
 import pandas as pd
 
 
+def id_from_path(csv_path: Path) -> str:
+    return csv_path.stem
+
+
 def load_csv_to_duckdb(
     *,
-    table_name: str,
     csv_path: Path,
-    duckdb_path: Path,
-    names: List[str],
-    duckdb_schema: str,
-    duckdb_database_name: str,
+    db_path: Path,
+    columns: List[str],
 ) -> None:
     # Ensure that path exists
     if not csv_path.exists():
         raise ValueError(f"CSV file not found at {csv_path}")
-    if not duckdb_path.exists():
-        raise ValueError(f"DuckDB database not found at {duckdb_path}")
+    # Duckdb database stored in airflow home
+    if not db_path.exists():
+        raise ValueError(f"Database not found at {db_path}")
     df = pd.read_csv(  # noqa: F841 # used by duckdb
         csv_path,
-        names=names,
+        names=columns,
     )
 
+    table_name = id_from_path(csv_path)
+    db_name = id_from_path(db_path)
     # Connect to DuckDB and create a new table
-    con = duckdb.connect(str(duckdb_path))
-    con.execute(f"CREATE SCHEMA IF NOT EXISTS {duckdb_schema}").fetchall()
+    con = duckdb.connect(str(db_path))
+    con.execute("CREATE SCHEMA IF NOT EXISTS lakehouse").fetchall()
     con.execute(
-        f"CREATE TABLE IF NOT EXISTS {duckdb_database_name}.{duckdb_schema}.{table_name} AS SELECT * FROM df"
+        f"CREATE TABLE IF NOT EXISTS {db_name}.lakehouse.{table_name} AS SELECT * FROM df"
     ).fetchall()
     con.close()

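For reference, a minimal sketch of calling the reworked loader directly (the paths here are illustrative). The table and database names are derived from the file stems, so this call creates jaffle_shop.lakehouse.iris:

from pathlib import Path

from dbt_example.shared.load_iris import load_csv_to_duckdb

# Creates the "lakehouse" schema if needed, then writes the CSV contents
# to a table named after the CSV file's stem.
load_csv_to_duckdb(
    csv_path=Path("/tmp/iris.csv"),           # illustrative location
    db_path=Path("/tmp/jaffle_shop.duckdb"),  # illustrative location; must already exist
    columns=[
        "sepal_length_cm",
        "sepal_width_cm",
        "petal_length_cm",
        "petal_width_cm",
        "species",
    ],
)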