Skip to content

Commit

Permalink
[dagster-airlift][dag] Dag-level override proxied state (#25159)
Browse files Browse the repository at this point in the history
## Summary & Motivation
Proxied state scaffolding for dag-level overrides
## How I Tested These Changes
Tests for new behavior
## Changelog
NOCHANGELOG
  • Loading branch information
dpeng817 authored Oct 14, 2024
1 parent 9cdf47a commit 5c014ea
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,29 @@ def to_dict(self) -> Dict[str, Any]:


class DagProxiedState(NamedTuple):
proxied: Optional[bool]
tasks: Dict[str, TaskProxiedState]

@staticmethod
def from_dict(dag_dict: Dict[str, Sequence[Dict[str, Any]]]) -> "DagProxiedState":
if "tasks" not in dag_dict:
def from_dict(dag_dict: Dict[str, Any]) -> "DagProxiedState":
if "tasks" not in dag_dict and "proxied" not in dag_dict:
raise Exception(
f"Expected a 'tasks' key in the dag dictionary. Instead; got: {dag_dict}"
f"Expected a 'tasks' or 'proxied' top-level key in the dag dictionary. Instead; got: {dag_dict}"
)
if "tasks" in dag_dict and "proxied" in dag_dict:
raise Exception(
f"Expected only one of 'tasks' or 'proxied' top-level keys in the dag dictionary. Instead; got: {dag_dict}"
)
task_list = dag_dict["tasks"]
task_proxied_states = {}
for task_dict in task_list:
task_state = TaskProxiedState.from_dict(task_dict)
task_proxied_states[task_state.task_id] = task_state
return DagProxiedState(tasks=task_proxied_states)
if "tasks" in dag_dict:
task_list: Sequence[Dict[str, Any]] = dag_dict["tasks"]
for task_dict in task_list:
task_state = TaskProxiedState.from_dict(task_dict)
task_proxied_states[task_state.task_id] = task_state
dag_proxied_state: Optional[bool] = dag_dict.get("proxied")
if dag_proxied_state not in [True, False, None]:
raise Exception("Expected 'proxied' key to be a boolean or None")
return DagProxiedState(tasks=task_proxied_states, proxied=dag_proxied_state)

def to_dict(self) -> Dict[str, Sequence[Dict[str, Any]]]:
return {"tasks": [task_state.to_dict() for task_state in self.tasks.values()]}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ def test_proxied_state() -> None:
"first_task": TaskProxiedState(task_id="first_task", proxied=True),
"second_task": TaskProxiedState(task_id="second_task", proxied=False),
"third_task": TaskProxiedState(task_id="third_task", proxied=True),
}
},
proxied=None,
),
"second": DagProxiedState(
tasks={
"some_task": TaskProxiedState("some_task", proxied=True),
"other_task": TaskProxiedState("other_task", proxied=False),
}
},
proxied=None,
),
}
)
Expand Down Expand Up @@ -58,3 +60,23 @@ def test_proxied_state_from_yaml() -> None:
assert dag_proxied_state.is_task_proxied("load_raw_customers") is False
assert dag_proxied_state.is_task_proxied("build_dbt_models") is False
assert dag_proxied_state.is_task_proxied("export_customers") is True


def test_dag_level_proxied_state_from_yaml() -> None:
proxied_state_dict = yaml.safe_load("""
proxied: True
""")
dag_proxied_state = DagProxiedState.from_dict(proxied_state_dict)
assert dag_proxied_state.proxied is True

proxied_state_dict = yaml.safe_load("""
proxied: False
""")
dag_proxied_state = DagProxiedState.from_dict(proxied_state_dict)
assert dag_proxied_state.proxied is False

proxied_state_dict = yaml.safe_load("""
proxied: Fish
""")
with pytest.raises(Exception, match="Expected 'proxied' key to be a boolean or None"):
DagProxiedState.from_dict(proxied_state_dict)

0 comments on commit 5c014ea

Please sign in to comment.