diff --git a/examples/experimental/dagster-airlift/dagster_airlift/core/airflow_defs_data.py b/examples/experimental/dagster-airlift/dagster_airlift/core/airflow_defs_data.py index 36b3e4163cc1a..ebbff08f075a3 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/core/airflow_defs_data.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/core/airflow_defs_data.py @@ -5,11 +5,15 @@ from dagster import AssetKey, Definitions from dagster._record import record -from dagster_airlift.constants import DAG_MAPPING_METADATA_KEY from dagster_airlift.core.airflow_instance import AirflowInstance from dagster_airlift.core.serialization.compute import AirliftMetadataMappingInfo -from dagster_airlift.core.serialization.serialized_data import TaskHandle -from dagster_airlift.core.utils import is_mapped_asset_spec, task_handles_for_spec +from dagster_airlift.core.serialization.serialized_data import DagHandle, TaskHandle +from dagster_airlift.core.utils import ( + dag_handles_for_spec, + is_dag_mapped_asset_spec, + is_task_mapped_asset_spec, + task_handles_for_spec, +) @record @@ -36,20 +40,21 @@ def dag_ids_with_mapped_asset_keys(self) -> AbstractSet[str]: def asset_keys_per_task_handle(self) -> Mapping[TaskHandle, AbstractSet[AssetKey]]: asset_keys_per_handle = defaultdict(set) for spec in self.mapped_defs.get_all_asset_specs(): - if is_mapped_asset_spec(spec): + if is_task_mapped_asset_spec(spec): task_handles = task_handles_for_spec(spec) for task_handle in task_handles: asset_keys_per_handle[task_handle].add(spec.key) return asset_keys_per_handle @cached_property - def asset_keys_per_dag(self) -> Mapping[str, AbstractSet[AssetKey]]: - dag_id_to_asset_key = defaultdict(set) + def asset_keys_per_dag_handle(self) -> Mapping[DagHandle, AbstractSet[AssetKey]]: + asset_keys_per_handle = defaultdict(set) for spec in self.mapped_defs.get_all_asset_specs(): - if DAG_MAPPING_METADATA_KEY in spec.metadata: - for mapping in spec.metadata[DAG_MAPPING_METADATA_KEY]: - dag_id_to_asset_key[mapping["dag_id"]].add(spec.key) - return dag_id_to_asset_key + if is_dag_mapped_asset_spec(spec): + dag_handles = dag_handles_for_spec(spec) + for dag_handle in dag_handles: + asset_keys_per_handle[dag_handle].add(spec.key) + return asset_keys_per_handle def asset_keys_in_task(self, dag_id: str, task_id: str) -> AbstractSet[AssetKey]: return self.asset_keys_per_task_handle[TaskHandle(dag_id=dag_id, task_id=task_id)] diff --git a/examples/experimental/dagster-airlift/dagster_airlift/core/sensor/event_translation.py b/examples/experimental/dagster-airlift/dagster_airlift/core/sensor/event_translation.py index 8bf534ccf589f..2a0244e2c8586 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/core/sensor/event_translation.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/core/sensor/event_translation.py @@ -16,6 +16,7 @@ from dagster_airlift.constants import EFFECTIVE_TIMESTAMP_METADATA_KEY from dagster_airlift.core.airflow_defs_data import AirflowDefinitionsData from dagster_airlift.core.airflow_instance import DagRun, TaskInstance +from dagster_airlift.core.serialization.serialized_data import DagHandle AssetEvent = Union[AssetMaterialization, AssetObservation, AssetCheckEvaluation] DagsterEventTransformerFn = Callable[ @@ -38,7 +39,7 @@ def materializations_for_dag_run( AssetMaterialization( asset_key=asset_key, description=dag_run.note, metadata=get_dag_run_metadata(dag_run) ) - for asset_key in airflow_data.asset_keys_per_dag[dag_run.dag_id] + for asset_key in airflow_data.asset_keys_per_dag_handle[DagHandle(dag_run.dag_id)] ] diff --git a/examples/experimental/dagster-airlift/dagster_airlift/core/serialization/compute.py b/examples/experimental/dagster-airlift/dagster_airlift/core/serialization/compute.py index dcf5208211a53..7aa1f8ef5a038 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/core/serialization/compute.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/core/serialization/compute.py @@ -14,7 +14,11 @@ TaskHandle, TaskInfo, ) -from dagster_airlift.core.utils import is_mapped_asset_spec, spec_iterator, task_handles_for_spec +from dagster_airlift.core.utils import ( + is_task_mapped_asset_spec, + spec_iterator, + task_handles_for_spec, +) @record @@ -23,7 +27,7 @@ class AirliftMetadataMappingInfo: @cached_property def mapped_asset_specs(self) -> List[AssetSpec]: - return [spec for spec in self.asset_specs if is_mapped_asset_spec(spec)] + return [spec for spec in self.asset_specs if is_task_mapped_asset_spec(spec)] @cached_property def dag_ids(self) -> Set[str]: @@ -53,7 +57,7 @@ def asset_key_map(self) -> Dict[str, Dict[str, Set[AssetKey]]]: """Mapping of dag_id to task_id to set of asset_keys mapped from that task.""" asset_key_map: Dict[str, Dict[str, Set[AssetKey]]] = defaultdict(lambda: defaultdict(set)) for spec in self.asset_specs: - if is_mapped_asset_spec(spec): + if is_task_mapped_asset_spec(spec): for task_handle in task_handles_for_spec(spec): asset_key_map[task_handle.dag_id][task_handle.task_id].add(spec.key) return asset_key_map diff --git a/examples/experimental/dagster-airlift/dagster_airlift/core/serialization/serialized_data.py b/examples/experimental/dagster-airlift/dagster_airlift/core/serialization/serialized_data.py index a6a4879baec0c..d3a9c137c1511 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/core/serialization/serialized_data.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/core/serialization/serialized_data.py @@ -48,6 +48,11 @@ class TaskHandle(NamedTuple): task_id: str +@whitelist_for_serdes +class DagHandle(NamedTuple): + dag_id: str + + ################################################################################################### # Serialized data that scopes to airflow DAGs and tasks. ################################################################################################### diff --git a/examples/experimental/dagster-airlift/dagster_airlift/core/utils.py b/examples/experimental/dagster-airlift/dagster_airlift/core/utils.py index 4c22f6b899ae3..e7f77be5ad071 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/core/utils.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/core/utils.py @@ -11,10 +11,14 @@ from dagster._core.errors import DagsterInvariantViolationError from dagster._core.storage.tags import KIND_PREFIX -from dagster_airlift.constants import AIRFLOW_SOURCE_METADATA_KEY_PREFIX, TASK_MAPPING_METADATA_KEY +from dagster_airlift.constants import ( + AIRFLOW_SOURCE_METADATA_KEY_PREFIX, + DAG_MAPPING_METADATA_KEY, + TASK_MAPPING_METADATA_KEY, +) if TYPE_CHECKING: - from dagster_airlift.core.serialization.serialized_data import TaskHandle + from dagster_airlift.core.serialization.serialized_data import DagHandle, TaskHandle def convert_to_valid_dagster_name(name: str) -> str: @@ -50,17 +54,31 @@ def get_metadata_key(instance_name: str) -> str: return f"{AIRFLOW_SOURCE_METADATA_KEY_PREFIX}/{instance_name}" -def is_mapped_asset_spec(spec: AssetSpec) -> bool: +def is_task_mapped_asset_spec(spec: AssetSpec) -> bool: return TASK_MAPPING_METADATA_KEY in spec.metadata +def is_dag_mapped_asset_spec(spec: AssetSpec) -> bool: + return DAG_MAPPING_METADATA_KEY in spec.metadata + + def task_handles_for_spec(spec: AssetSpec) -> Set["TaskHandle"]: from dagster_airlift.core.serialization.serialized_data import TaskHandle - check.param_invariant(is_mapped_asset_spec(spec), "spec", "Must be mappped spec") + check.param_invariant(is_task_mapped_asset_spec(spec), "spec", "Must be mapped spec") task_handles = [] for task_handle_dict in spec.metadata[TASK_MAPPING_METADATA_KEY]: task_handles.append( TaskHandle(dag_id=task_handle_dict["dag_id"], task_id=task_handle_dict["task_id"]) ) return set(task_handles) + + +def dag_handles_for_spec(spec: AssetSpec) -> Set["DagHandle"]: + from dagster_airlift.core.serialization.serialized_data import DagHandle + + check.param_invariant(is_dag_mapped_asset_spec(spec), "spec", "Must be mapped spec") + return { + DagHandle(dag_id=dag_handle_dict["dag_id"]) + for dag_handle_dict in spec.metadata[DAG_MAPPING_METADATA_KEY] + } diff --git a/examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/core_tests/test_load_defs.py b/examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/core_tests/test_load_defs.py index 4a58c6d600315..dfa2e2aa13b25 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/core_tests/test_load_defs.py +++ b/examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/core_tests/test_load_defs.py @@ -44,7 +44,7 @@ TaskHandle, ) from dagster_airlift.core.top_level_dag_def_api import assets_with_task_mappings -from dagster_airlift.core.utils import is_mapped_asset_spec, metadata_for_task_mapping +from dagster_airlift.core.utils import is_task_mapped_asset_spec, metadata_for_task_mapping from dagster_airlift.test import make_instance from dagster_airlift.utils import DAGSTER_AIRLIFT_PROXIED_STATE_DIR_ENV_VAR @@ -229,7 +229,7 @@ def test_transitive_asset_deps() -> None: b_asset = repo_def.assets_defs_by_key[b_key] assert [dep.asset_key for dep in next(iter(b_asset.specs)).deps] == [a_key] - assert not is_mapped_asset_spec(next(iter(b_asset.specs))) + assert not is_task_mapped_asset_spec(next(iter(b_asset.specs))) c_asset = repo_def.assets_defs_by_key[c_key] assert [dep.asset_key for dep in next(iter(c_asset.specs)).deps] == [b_key]