Skip to content

Commit

Permalink
[dagster-airlift] Introduce dag handle, use in sensor (#25242)
Browse files Browse the repository at this point in the history
## Summary & Motivation
Introduce the notion of a "DagHandle" and use it in the
AirflowDefinitionsData object for checking the sensor
## How I Tested These Changes
Existing tests
## Changelog
NOCHANGELOG
  • Loading branch information
dpeng817 authored Oct 14, 2024
1 parent e2e5c91 commit dc9b6a7
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
from dagster import AssetKey, Definitions
from dagster._record import record

from dagster_airlift.constants import DAG_MAPPING_METADATA_KEY
from dagster_airlift.core.airflow_instance import AirflowInstance
from dagster_airlift.core.serialization.compute import AirliftMetadataMappingInfo
from dagster_airlift.core.serialization.serialized_data import TaskHandle
from dagster_airlift.core.utils import is_mapped_asset_spec, task_handles_for_spec
from dagster_airlift.core.serialization.serialized_data import DagHandle, TaskHandle
from dagster_airlift.core.utils import (
dag_handles_for_spec,
is_dag_mapped_asset_spec,
is_task_mapped_asset_spec,
task_handles_for_spec,
)


@record
Expand All @@ -36,20 +40,21 @@ def dag_ids_with_mapped_asset_keys(self) -> AbstractSet[str]:
def asset_keys_per_task_handle(self) -> Mapping[TaskHandle, AbstractSet[AssetKey]]:
asset_keys_per_handle = defaultdict(set)
for spec in self.mapped_defs.get_all_asset_specs():
if is_mapped_asset_spec(spec):
if is_task_mapped_asset_spec(spec):
task_handles = task_handles_for_spec(spec)
for task_handle in task_handles:
asset_keys_per_handle[task_handle].add(spec.key)
return asset_keys_per_handle

@cached_property
def asset_keys_per_dag(self) -> Mapping[str, AbstractSet[AssetKey]]:
dag_id_to_asset_key = defaultdict(set)
def asset_keys_per_dag_handle(self) -> Mapping[DagHandle, AbstractSet[AssetKey]]:
asset_keys_per_handle = defaultdict(set)
for spec in self.mapped_defs.get_all_asset_specs():
if DAG_MAPPING_METADATA_KEY in spec.metadata:
for mapping in spec.metadata[DAG_MAPPING_METADATA_KEY]:
dag_id_to_asset_key[mapping["dag_id"]].add(spec.key)
return dag_id_to_asset_key
if is_dag_mapped_asset_spec(spec):
dag_handles = dag_handles_for_spec(spec)
for dag_handle in dag_handles:
asset_keys_per_handle[dag_handle].add(spec.key)
return asset_keys_per_handle

def asset_keys_in_task(self, dag_id: str, task_id: str) -> AbstractSet[AssetKey]:
return self.asset_keys_per_task_handle[TaskHandle(dag_id=dag_id, task_id=task_id)]
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dagster_airlift.constants import EFFECTIVE_TIMESTAMP_METADATA_KEY
from dagster_airlift.core.airflow_defs_data import AirflowDefinitionsData
from dagster_airlift.core.airflow_instance import DagRun, TaskInstance
from dagster_airlift.core.serialization.serialized_data import DagHandle

AssetEvent = Union[AssetMaterialization, AssetObservation, AssetCheckEvaluation]
DagsterEventTransformerFn = Callable[
Expand All @@ -38,7 +39,7 @@ def materializations_for_dag_run(
AssetMaterialization(
asset_key=asset_key, description=dag_run.note, metadata=get_dag_run_metadata(dag_run)
)
for asset_key in airflow_data.asset_keys_per_dag[dag_run.dag_id]
for asset_key in airflow_data.asset_keys_per_dag_handle[DagHandle(dag_run.dag_id)]
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
TaskHandle,
TaskInfo,
)
from dagster_airlift.core.utils import is_mapped_asset_spec, spec_iterator, task_handles_for_spec
from dagster_airlift.core.utils import (
is_task_mapped_asset_spec,
spec_iterator,
task_handles_for_spec,
)


@record
Expand All @@ -23,7 +27,7 @@ class AirliftMetadataMappingInfo:

@cached_property
def mapped_asset_specs(self) -> List[AssetSpec]:
return [spec for spec in self.asset_specs if is_mapped_asset_spec(spec)]
return [spec for spec in self.asset_specs if is_task_mapped_asset_spec(spec)]

@cached_property
def dag_ids(self) -> Set[str]:
Expand Down Expand Up @@ -53,7 +57,7 @@ def asset_key_map(self) -> Dict[str, Dict[str, Set[AssetKey]]]:
"""Mapping of dag_id to task_id to set of asset_keys mapped from that task."""
asset_key_map: Dict[str, Dict[str, Set[AssetKey]]] = defaultdict(lambda: defaultdict(set))
for spec in self.asset_specs:
if is_mapped_asset_spec(spec):
if is_task_mapped_asset_spec(spec):
for task_handle in task_handles_for_spec(spec):
asset_key_map[task_handle.dag_id][task_handle.task_id].add(spec.key)
return asset_key_map
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ class TaskHandle(NamedTuple):
task_id: str


@whitelist_for_serdes
class DagHandle(NamedTuple):
dag_id: str


###################################################################################################
# Serialized data that scopes to airflow DAGs and tasks.
###################################################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@
from dagster._core.errors import DagsterInvariantViolationError
from dagster._core.storage.tags import KIND_PREFIX

from dagster_airlift.constants import AIRFLOW_SOURCE_METADATA_KEY_PREFIX, TASK_MAPPING_METADATA_KEY
from dagster_airlift.constants import (
AIRFLOW_SOURCE_METADATA_KEY_PREFIX,
DAG_MAPPING_METADATA_KEY,
TASK_MAPPING_METADATA_KEY,
)

if TYPE_CHECKING:
from dagster_airlift.core.serialization.serialized_data import TaskHandle
from dagster_airlift.core.serialization.serialized_data import DagHandle, TaskHandle


def convert_to_valid_dagster_name(name: str) -> str:
Expand Down Expand Up @@ -50,17 +54,31 @@ def get_metadata_key(instance_name: str) -> str:
return f"{AIRFLOW_SOURCE_METADATA_KEY_PREFIX}/{instance_name}"


def is_mapped_asset_spec(spec: AssetSpec) -> bool:
def is_task_mapped_asset_spec(spec: AssetSpec) -> bool:
return TASK_MAPPING_METADATA_KEY in spec.metadata


def is_dag_mapped_asset_spec(spec: AssetSpec) -> bool:
return DAG_MAPPING_METADATA_KEY in spec.metadata


def task_handles_for_spec(spec: AssetSpec) -> Set["TaskHandle"]:
from dagster_airlift.core.serialization.serialized_data import TaskHandle

check.param_invariant(is_mapped_asset_spec(spec), "spec", "Must be mappped spec")
check.param_invariant(is_task_mapped_asset_spec(spec), "spec", "Must be mapped spec")
task_handles = []
for task_handle_dict in spec.metadata[TASK_MAPPING_METADATA_KEY]:
task_handles.append(
TaskHandle(dag_id=task_handle_dict["dag_id"], task_id=task_handle_dict["task_id"])
)
return set(task_handles)


def dag_handles_for_spec(spec: AssetSpec) -> Set["DagHandle"]:
from dagster_airlift.core.serialization.serialized_data import DagHandle

check.param_invariant(is_dag_mapped_asset_spec(spec), "spec", "Must be mapped spec")
return {
DagHandle(dag_id=dag_handle_dict["dag_id"])
for dag_handle_dict in spec.metadata[DAG_MAPPING_METADATA_KEY]
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
TaskHandle,
)
from dagster_airlift.core.top_level_dag_def_api import assets_with_task_mappings
from dagster_airlift.core.utils import is_mapped_asset_spec, metadata_for_task_mapping
from dagster_airlift.core.utils import is_task_mapped_asset_spec, metadata_for_task_mapping
from dagster_airlift.test import make_instance
from dagster_airlift.utils import DAGSTER_AIRLIFT_PROXIED_STATE_DIR_ENV_VAR

Expand Down Expand Up @@ -229,7 +229,7 @@ def test_transitive_asset_deps() -> None:

b_asset = repo_def.assets_defs_by_key[b_key]
assert [dep.asset_key for dep in next(iter(b_asset.specs)).deps] == [a_key]
assert not is_mapped_asset_spec(next(iter(b_asset.specs)))
assert not is_task_mapped_asset_spec(next(iter(b_asset.specs)))

c_asset = repo_def.assets_defs_by_key[c_key]
assert [dep.asset_key for dep in next(iter(c_asset.specs)).deps] == [b_key]
Expand Down

0 comments on commit dc9b6a7

Please sign in to comment.