Skip to content

Commit

Permalink
[dagster-airlift] mark dags as migrating (#23370)
Browse files Browse the repository at this point in the history
Function that allows users to mark airflow dags as "migrating", and
injects a tag into the dag with information about the migration.
When placing the tag in the dag, there are two options:
1. Construct a new dag using essentially a shallow copy of the old dag,
and then inject this into global scope.
2. Inject a tag into the existing dag object you find in global scope
using the mutability of airflow dag's data structures.

I prefer (2) in this approach, because the surface area is way lower
than (1), and it should be relatively resistant to changes in airflow's
API other than this one tiny surface area ([which hasn't changed since
1.10, when tags were first
introduced)](https://airflow.apache.org/docs/apache-airflow/1.10.10/_modules/airflow/models/dag.html#DAG).
Unless we can figure out a reliable way to create an arbitrary copy
constructor for dags across any airflow version, we're likely to run
into brittleness with trying to reconstruct a dag from the pure object,
I think. In general, taking advantage of airflow's mutability when we
can seems like a good approach to injecting migration state.

Another point of discussion; what to do when there are no dags in scope
which the state migration object has reference to. For now I throw an
exception, but wondering if this is too harsh.

Finally, there's the question of how these tags show up in airflow's UI.
It's pretty ugly to see this json blob appear in the airflow UI after
setting the tag, but there doesn't seem to be any other data structures
we can use for this (except maybe params? But I feel more hesitant
hooking into that since it's significantly more complex implementation
wise). So might be the best we can do for now.
  • Loading branch information
dpeng817 authored Aug 11, 2024
1 parent 7348ca3 commit 6e511c6
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .mark_as_migrating import mark_as_dagster_migrating as mark_as_dagster_migrating
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import json
import logging
from typing import Any, Dict, Optional

from airflow import DAG

from ..migration_state import AirflowMigrationState


def mark_as_dagster_migrating(
*,
global_vars: Dict[str, Any],
migration_state: AirflowMigrationState,
logger: Optional[logging.Logger] = None,
) -> None:
"""Alters all airflow dags in the current context to be marked as migrating to dagster.
Uses a migration dictionary to determine the status of the migration for each task within each dag.
Should only ever be the last line in a dag file.
Args:
global_vars (Dict[str, Any]): The global variables in the current context. In most cases, retrieved with `globals()` (no import required).
This is equivalent to what airflow already does to introspect the dags which exist in a given module context:
https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/dags.html#loading-dags
migration_state (AirflowMigrationState): The migration state for the dags.
logger (Optional[logging.Logger]): The logger to use. Defaults to logging.getLogger("dagster_airlift").
"""
caller_module = global_vars.get("__module__")
suffix = f" in module `{caller_module}`" if caller_module else ""
if not logger:
logger = logging.getLogger("dagster_airlift")
logger.debug(f"Searching for dags migrating to dagster{suffix}...")
num_dags = 0
for obj in global_vars.values():
if not isinstance(obj, DAG):
continue
dag: DAG = obj
logger.debug(f"Checking dag with id `{dag.dag_id}` for migration state.")
migration_state_for_dag = migration_state.get_migration_dict_for_dag(dag.dag_id)
if migration_state_for_dag is None:
logger.debug(
f"Dag with id `{dag.dag_id} hasn't been marked with migration state. Skipping..."
)
else:
dag.tags.append(json.dumps({"DAGSTER_MIGRATION_STATUS": migration_state_for_dag}))
logger.debug(
f"Dag with id `{dag.dag_id}` has been marked with migration state. Adding state to tags for dag."
)
num_dags += 1
logger.info(f"Marked {num_dags} dags as migrating to dagster{suffix}.")
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Dict, NamedTuple
from typing import Any, Dict, NamedTuple, Optional

import yaml

Expand All @@ -18,11 +18,38 @@ class AirflowMigrationState(NamedTuple):
def get_migration_state_for_task(self, dag_id: str, task_id: str) -> bool:
return self.dags[dag_id].tasks[task_id].migrated

def get_migration_dict_for_dag(self, dag_id: str) -> Optional[Dict[str, Dict[str, Any]]]:
if dag_id not in self.dags:
return None
return {
"tasks": {
task_id: {"migrated": task_state.migrated}
for task_id, task_state in self.dags[dag_id].tasks.items()
}
}


class MigrationStateParsingError(Exception):
pass


def load_dag_migration_state_from_dict(dag_dict: Dict[str, Dict[str, Any]]) -> DagMigrationState:
if "tasks" not in dag_dict:
raise Exception("Expected a 'tasks' key in the yaml")
task_migration_states = {}
for task_id, task_dict in dag_dict["tasks"].items():
if not isinstance(task_dict, dict):
raise Exception("Expected a dictionary for each task")
if "migrated" not in task_dict:
raise Exception("Expected a 'migrated' key in the task dictionary")
if set(task_dict.keys()) != {"migrated"}:
raise Exception("Expected only a 'migrated' key in the task dictionary")
if task_dict["migrated"] not in [True, False]:
raise Exception("Expected 'migrated' key to be a boolean")
task_migration_states[task_id] = TaskMigrationState(migrated=task_dict["migrated"])
return DagMigrationState(tasks=task_migration_states)


def load_migration_state_from_yaml(migration_yaml_path: Path) -> AirflowMigrationState:
# Expect migration_yaml_path to be a directory, where each file represents a dag, and each
# file in the subdir represents a task. The dictionary each task should consist of a single bit:
Expand All @@ -37,20 +64,7 @@ def load_migration_state_from_yaml(migration_yaml_path: Path) -> AirflowMigratio
yaml_dict = yaml.safe_load(dag_file.read_text())
if not isinstance(yaml_dict, dict):
raise Exception("Expected a dictionary")
if "tasks" not in yaml_dict:
raise Exception("Expected a 'tasks' key in the yaml")
task_migration_states = {}
for task_id, task_dict in yaml_dict["tasks"].items():
if not isinstance(task_dict, dict):
raise Exception("Expected a dictionary for each task")
if "migrated" not in task_dict:
raise Exception("Expected a 'migrated' key in the task dictionary")
if set(task_dict.keys()) != {"migrated"}:
raise Exception("Expected only a 'migrated' key in the task dictionary")
if task_dict["migrated"] not in [True, False]:
raise Exception("Expected 'migrated' key to be a boolean")
task_migration_states[task_id] = TaskMigrationState(migrated=task_dict["migrated"])
dag_migration_states[dag_id] = DagMigrationState(tasks=task_migration_states)
dag_migration_states[dag_id] = load_dag_migration_state_from_dict(yaml_dict)
except Exception as e:
raise MigrationStateParsingError("Error parsing migration yaml") from e
return AirflowMigrationState(dags=dag_migration_states)
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from datetime import datetime
from pathlib import Path

from airflow import DAG
from airflow.operators.python import PythonOperator
from dagster_airlift.in_airflow import mark_as_dagster_migrating
from dagster_airlift.migration_state import load_migration_state_from_yaml


def print_hello():
print("Hello") # noqa: T201


default_args = {
"owner": "airflow",
"depends_on_past": False,
"start_date": datetime(2023, 1, 1),
"retries": 1,
}

marked_dag = DAG(
"marked_dag", default_args=default_args, schedule_interval=None, is_paused_upon_creation=False
)
print_op = PythonOperator(task_id="print_task", python_callable=print_hello, dag=marked_dag)
downstream_print_op = PythonOperator(
task_id="downstream_print_task", python_callable=print_hello, dag=marked_dag
)


path_to_migration_state = Path(__file__).parent.parent / "migration_state"
mark_as_dagster_migrating(
migration_state=load_migration_state_from_yaml(path_to_migration_state), global_vars=globals()
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
tasks:
print_task:
migrated: False
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from datetime import datetime
from pathlib import Path

from airflow import DAG
from airflow.operators.python import PythonOperator
from dagster_airlift.in_airflow import mark_as_dagster_migrating
from dagster_airlift.migration_state import load_migration_state_from_yaml


def print_hello():
print("Hello") # noqa: T201


default_args = {
"owner": "airflow",
"depends_on_past": False,
"start_date": datetime(2023, 1, 1),
"retries": 1,
}

marked_dag = DAG(
"marked_dag", default_args=default_args, schedule_interval=None, is_paused_upon_creation=False
)
print_op = PythonOperator(task_id="print_task", python_callable=print_hello, dag=marked_dag)
downstream_print_op = PythonOperator(
task_id="downstream_print_task", python_callable=print_hello, dag=marked_dag
)

# There is no entry for marked_dag in the migration state directory. There shouldn't be an exception, the dag just shouldn't be marked.
path_to_migration_state = Path(__file__).parent.parent / "migration_state"
mark_as_dagster_migrating(
migration_state=load_migration_state_from_yaml(path_to_migration_state), global_vars=globals()
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
tasks:
print_task:
migrated: False
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import json
import os
import subprocess
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Generator

import pytest
import requests
from dagster._core.test_utils import environ


@pytest.fixture(name="setup")
def setup_fixture() -> Generator[str, None, None]:
with TemporaryDirectory() as tmpdir:
# run chmod +x create_airflow_cfg.sh and then run create_airflow_cfg.sh tmpdir
temp_env = {**os.environ.copy(), "AIRFLOW_HOME": tmpdir}
# go up one directory from current
path_to_script = Path(__file__).parent.parent.parent / "airflow_setup.sh"
path_to_dags = Path(__file__).parent / "correctly_marked_dag" / "dags"
subprocess.run(["chmod", "+x", path_to_script], check=True, env=temp_env)
subprocess.run([path_to_script, path_to_dags], check=True, env=temp_env)
with environ({"AIRFLOW_HOME": tmpdir}):
yield tmpdir


def test_migrating_dag(airflow_instance: None) -> None:
"""Tests that a correctly marked dag is marked as migrating via a tag on the dag object."""
response = requests.get("http://localhost:8080/api/v1/dags/marked_dag", auth=("admin", "admin"))
assert response.status_code == 200
tags = response.json()["tags"]
assert len(tags) == 1
assert json.loads(tags[0]["name"]) == {
"DAGSTER_MIGRATION_STATUS": {
"tasks": {
"print_task": {
"migrated": False,
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import os
import subprocess
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Generator

import pytest
import requests
from dagster._core.test_utils import environ


@pytest.fixture(name="setup")
def setup_fixture() -> Generator[str, None, None]:
with TemporaryDirectory() as tmpdir:
# run chmod +x create_airflow_cfg.sh and then run create_airflow_cfg.sh tmpdir
temp_env = {**os.environ.copy(), "AIRFLOW_HOME": tmpdir}
# go up one directory from current
path_to_script = Path(__file__).parent.parent.parent / "airflow_setup.sh"
path_to_dags = Path(__file__).parent / "incorrectly_marked_dag" / "dags"
subprocess.run(["chmod", "+x", path_to_script], check=True, env=temp_env)
subprocess.run([path_to_script, path_to_dags], check=True, env=temp_env)
with environ({"AIRFLOW_HOME": tmpdir}):
yield tmpdir


def test_migrating_dag(airflow_instance: None) -> None:
"""Tests that an incorrectly marked dag throws an exception, and is not loaded."""
response = requests.get("http://localhost:8080/api/v1/dags/marked_dag", auth=("admin", "admin"))
assert response.status_code == 200
tags = response.json()["tags"]
assert len(tags) == 0

0 comments on commit 6e511c6

Please sign in to comment.