diff --git a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/proxied_state.py b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/proxied_state.py index 7a6f2afc06e4c..0f17c5d39d243 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/proxied_state.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/proxied_state.py @@ -23,20 +23,29 @@ def to_dict(self) -> Dict[str, Any]: class DagProxiedState(NamedTuple): + proxied: Optional[bool] tasks: Dict[str, TaskProxiedState] @staticmethod - def from_dict(dag_dict: Dict[str, Sequence[Dict[str, Any]]]) -> "DagProxiedState": - if "tasks" not in dag_dict: + def from_dict(dag_dict: Dict[str, Any]) -> "DagProxiedState": + if "tasks" not in dag_dict and "proxied" not in dag_dict: raise Exception( - f"Expected a 'tasks' key in the dag dictionary. Instead; got: {dag_dict}" + f"Expected a 'tasks' or 'proxied' top-level key in the dag dictionary. Instead; got: {dag_dict}" + ) + if "tasks" in dag_dict and "proxied" in dag_dict: + raise Exception( + f"Expected only one of 'tasks' or 'proxied' top-level keys in the dag dictionary. Instead; got: {dag_dict}" ) - task_list = dag_dict["tasks"] task_proxied_states = {} - for task_dict in task_list: - task_state = TaskProxiedState.from_dict(task_dict) - task_proxied_states[task_state.task_id] = task_state - return DagProxiedState(tasks=task_proxied_states) + if "tasks" in dag_dict: + task_list: Sequence[Dict[str, Any]] = dag_dict["tasks"] + for task_dict in task_list: + task_state = TaskProxiedState.from_dict(task_dict) + task_proxied_states[task_state.task_id] = task_state + dag_proxied_state: Optional[bool] = dag_dict.get("proxied") + if dag_proxied_state not in [True, False, None]: + raise Exception("Expected 'proxied' key to be a boolean or None") + return DagProxiedState(tasks=task_proxied_states, proxied=dag_proxied_state) def to_dict(self) -> Dict[str, Sequence[Dict[str, Any]]]: return {"tasks": [task_state.to_dict() for task_state in self.tasks.values()]} diff --git a/examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/test_migration_state.py b/examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/test_proxied_state.py similarity index 76% rename from examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/test_migration_state.py rename to examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/test_proxied_state.py index c61ac63002d17..568b34ab6e954 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/test_migration_state.py +++ b/examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/test_proxied_state.py @@ -24,13 +24,15 @@ def test_proxied_state() -> None: "first_task": TaskProxiedState(task_id="first_task", proxied=True), "second_task": TaskProxiedState(task_id="second_task", proxied=False), "third_task": TaskProxiedState(task_id="third_task", proxied=True), - } + }, + proxied=None, ), "second": DagProxiedState( tasks={ "some_task": TaskProxiedState("some_task", proxied=True), "other_task": TaskProxiedState("other_task", proxied=False), - } + }, + proxied=None, ), } ) @@ -58,3 +60,23 @@ def test_proxied_state_from_yaml() -> None: assert dag_proxied_state.is_task_proxied("load_raw_customers") is False assert dag_proxied_state.is_task_proxied("build_dbt_models") is False assert dag_proxied_state.is_task_proxied("export_customers") is True + + +def test_dag_level_proxied_state_from_yaml() -> None: + proxied_state_dict = yaml.safe_load(""" +proxied: True +""") + dag_proxied_state = DagProxiedState.from_dict(proxied_state_dict) + assert dag_proxied_state.proxied is True + + proxied_state_dict = yaml.safe_load(""" +proxied: False +""") + dag_proxied_state = DagProxiedState.from_dict(proxied_state_dict) + assert dag_proxied_state.proxied is False + + proxied_state_dict = yaml.safe_load(""" +proxied: Fish +""") + with pytest.raises(Exception, match="Expected 'proxied' key to be a boolean or None"): + DagProxiedState.from_dict(proxied_state_dict)