Skip to content

Commit

Permalink
[dagster-airlift] remove naming convention (#24417)
Browse files Browse the repository at this point in the history
## Summary & Motivation
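Remove the convention-based approaches (the `dag_id__task_id` naming format and op tags) for setting the dag ID and task ID; both IDs are now read from asset metadata only.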

## Changelog
`NOCHANGELOG`
  • Loading branch information
dpeng817 authored Sep 12, 2024
1 parent 431964d commit af4fab9
Show file tree
Hide file tree
Showing 7 changed files with 8 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,6 @@ def build_defs_from_airflow_instance(
Each airflow dag will be represented as a dagster asset, and each dag run will be represented as an asset materialization. A :py:class:`dagster.SensorDefinition` provided in the returned Definitions will be used to poll for dag runs.
The provided orchestrated defs are expected to contain fully qualified :py:class:`dagster.AssetsDefinition` objects, each of which should be mapped to a task and dag in the provided airflow instance. Using the airflow instance,
dagster will provide dependency information between the assets representing tasks, and the dags that they contain. The included :py:class:`dagster.SensorDefinition` will poll for dag runs and materialize runs including each task as an asset for that task.
There are two ways that the mapping can be done on a provided definition:
1. By using the `airlift/dag_id` and `airlift/task_id` op tags on the underlying :py:class:`dagster.NodeDefinition` for the asset.
2. By using an opinionated naming format on the :py:class:`dagster.NodeDefinition` for the asset. The naming format is `dag_id__task_id`.
Args:
airflow_instance (AirflowInstance): The airflow instance to peer with.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,41 +17,11 @@ def convert_to_valid_dagster_name(name: str) -> str:


def get_task_id_from_asset(asset: Union[AssetsDefinition, AssetSpec]) -> Optional[str]:
return _get_prop_from_asset(asset, TASK_ID_METADATA_KEY, 1)
return prop_from_metadata(asset, TASK_ID_METADATA_KEY)


def get_dag_id_from_asset(asset: Union[AssetsDefinition, AssetSpec]) -> Optional[str]:
return _get_prop_from_asset(asset, DAG_ID_METADATA_KEY, 0)


def _get_prop_from_asset(
asset: Union[AssetSpec, AssetsDefinition], prop_metadata_key: str, position: int
) -> Optional[str]:
prop_from_asset_tags = prop_from_metadata(asset, prop_metadata_key)
if isinstance(asset, AssetSpec) or not asset.is_executable:
return prop_from_asset_tags
prop_from_op_tags = None
if asset.node_def.tags and prop_metadata_key in asset.node_def.tags:
prop_from_op_tags = asset.node_def.tags[prop_metadata_key]
prop_from_name = None
if len(asset.node_def.name.split("__")) == 2:
prop_from_name = asset.node_def.name.split("__")[position]
if prop_from_asset_tags and prop_from_op_tags:
check.invariant(
prop_from_asset_tags == prop_from_op_tags,
f"ID mismatch between asset tags and op tags: {prop_from_asset_tags} != {prop_from_op_tags}",
)
if prop_from_asset_tags and prop_from_name:
check.invariant(
prop_from_asset_tags == prop_from_name,
f"ID mismatch between tags and name: {prop_from_asset_tags} != {prop_from_name}",
)
if prop_from_op_tags and prop_from_name:
check.invariant(
prop_from_op_tags == prop_from_name,
f"ID mismatch between op tags and name: {prop_from_op_tags} != {prop_from_name}",
)
return prop_from_asset_tags or prop_from_op_tags or prop_from_name
return prop_from_metadata(asset, DAG_ID_METADATA_KEY)


def prop_from_metadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def get_dagster_url(self, context: Context) -> str:

def launch_runs_for_task(self, context: Context, dag_id: str, task_id: str) -> None:
"""Launches runs for the given task in Dagster."""
expected_op_name = f"{dag_id}__{task_id}"
session = self._get_validated_session(context)

dagster_url = self.get_dagster_url(context)
Expand All @@ -64,8 +63,7 @@ def launch_runs_for_task(self, context: Context, dag_id: str, task_id: str) -> N
for entry in asset_node["metadataEntries"]
if entry["__typename"] == "TextMetadataEntry"
}
# match assets based on conventional dag_id__task_id naming or based on explicit tags
if asset_node["opName"] == expected_op_name or (
if (
text_metadata_entries.get(DAG_ID_METADATA_KEY) == dag_id
and text_metadata_entries.get(TASK_ID_METADATA_KEY) == task_id
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
}
__typename
}
opName
jobs {
id
name
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from dagster import Definitions, asset
from dagster_airlift.constants import DAG_ID_METADATA_KEY, TASK_ID_METADATA_KEY


@asset
@asset(metadata={DAG_ID_METADATA_KEY: "the_dag", TASK_ID_METADATA_KEY: "some_task"})
def the_dag__some_task():
return "asset_value"

Expand All @@ -11,7 +12,7 @@ def unrelated():
return "unrelated_value"


@asset
@asset(metadata={DAG_ID_METADATA_KEY: "the_dag", TASK_ID_METADATA_KEY: "other_task"})
def the_dag__other_task():
return "other_task_value"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,6 @@ def some_schedule():
pass


@asset
def dag__task():
pass


@asset
def a():
pass
Expand All @@ -63,11 +58,6 @@ def a_check():
pass


@asset_check(asset=dag__task)
def other_check():
pass


@job
def the_job():
pass
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
from dagster import AssetKey, AssetSpec, asset, multi_asset
from dagster._check.functions import CheckError
from dagster_airlift.core import specs_from_task
from dagster_airlift.core.utils import get_dag_id_from_asset, get_task_id_from_asset


Expand All @@ -16,8 +15,8 @@ def no_op():
assert get_task_id_from_asset(no_op) is None


def test_retrieve_by_asset_tag() -> None:
"""Test that we can retrieve the dag and task id from the asset tags. Test that error edge cases are properly handled."""
def test_retrieve_by_asset_metadata() -> None:
"""Test that we can retrieve the dag and task id from the asset metadata. Test that error edge cases are properly handled."""

# 1. Single spec retrieval
@asset(metadata={"airlift/dag_id": "print_dag", "airlift/task_id": "print_task"})
Expand Down Expand Up @@ -85,74 +84,3 @@ def multi_spec_task_mismatch():

with pytest.raises(CheckError):
get_task_id_from_asset(multi_spec_task_mismatch)


def test_retrieve_by_op_tag() -> None:
"""Test that we can retrieve the dag and task id from the op tags."""

@asset(op_tags={"airlift/dag_id": "print_dag", "airlift/task_id": "print_task"})
def the_asset():
pass

assert get_dag_id_from_asset(the_asset) == "print_dag"
assert get_task_id_from_asset(the_asset) == "print_task"


def test_retrieve_by_name() -> None:
"""Test that we can retrieve the dag and task id from the name."""

@asset
def print_dag__print_task():
pass

assert get_dag_id_from_asset(print_dag__print_task) == "print_dag"
assert get_task_id_from_asset(print_dag__print_task) == "print_task"


def test_op_asset_tag_mismatch() -> None:
@asset(
metadata={"airlift/dag_id": "print_dag", "airlift/task_id": "print_task"},
op_tags={"airlift/dag_id": "other_dag", "airlift/task_id": "other_task"},
)
def mismatched():
pass

with pytest.raises(CheckError):
get_dag_id_from_asset(mismatched)

with pytest.raises(CheckError):
get_task_id_from_asset(mismatched)


def test_op_asset_name_mismatch() -> None:
@asset(metadata={"airlift/dag_id": "print_dag", "airlift/task_id": "print_task"})
def other_dag__other_task():
pass

with pytest.raises(CheckError):
get_dag_id_from_asset(other_dag__other_task)

with pytest.raises(CheckError):
get_task_id_from_asset(other_dag__other_task)


def test_op_tag_name_mismatch() -> None:
@asset(op_tags={"airlift/dag_id": "print_dag", "airlift/task_id": "print_task"})
def other_dag__other_task():
pass

with pytest.raises(CheckError):
get_dag_id_from_asset(other_dag__other_task)

with pytest.raises(CheckError):
get_task_id_from_asset(other_dag__other_task)


def test_specs_to_tasks() -> None:
"""Tests basic conversion of specs to tasks."""
specs = ["1", AssetSpec(key=AssetKey(["2"]))]
defs = specs_from_task(task_id="task", dag_id="dag", assets=specs)
assert all(isinstance(_def, AssetSpec) for _def in defs)
assert len(list(defs)) == 2
spec = next(iter(defs))
assert spec.metadata["airlift/dag_id"] == "dag"

0 comments on commit af4fab9

Please sign in to comment.