Skip to content

Commit

Permalink
[dagster-airlift] airflow operator switcher
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Aug 9, 2024
1 parent 9cc01d3 commit 87da942
Show file tree
Hide file tree
Showing 8 changed files with 217 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,12 @@ def k8s_extra_cmds(version: str, _) -> List[str]:
AvailablePythonVersion.V3_12,
],
),
PackageSpec(
"examples/experimental/dagster-airlift/examples/simple-migration",
unsupported_python_versions=[
AvailablePythonVersion.V3_12,
],
),
]


Expand Down
3 changes: 2 additions & 1 deletion examples/experimental/dagster-airlift/airflow_setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ if [[ "$DAGS_FOLDER" != /* ]]; then
exit 1
fi

# Create the airflow.cfg file in $AIRFLOW_HOME
# Create the airflow.cfg file in $AIRFLOW_HOME. We set a super high import timeout so that we can attach a debugger and mess around with the code.
cat <<EOL > $AIRFLOW_HOME/airflow.cfg
[core]
dags_folder = $DAGS_FOLDER
dagbag_import_timeout = 30000
load_examples = False
[api]
auth_backend = airflow.api.auth.backend.basic_auth
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

import requests
from airflow import DAG
from airflow.operators.python import PythonOperator

from .gql_queries import ASSET_NODES_QUERY, RUNS_QUERY, TRIGGER_ASSETS_MUTATION
Expand Down Expand Up @@ -77,5 +76,5 @@ def compute_fn() -> None:
return None


def build_dagster_task(task_id: str, dag: DAG, **kwargs):
return PythonOperator(task_id=task_id, dag=dag, python_callable=compute_fn, **kwargs)
def build_dagster_task(task_id: str, **kwargs):
    """Construct the Airflow operator that stands in for a dagster-migrated task.

    The returned operator runs ``compute_fn`` (which proxies execution to
    dagster) instead of the task's original callable. Any extra keyword
    arguments (retries, owner, dates, ...) are forwarded to ``PythonOperator``.
    """
    operator_kwargs = dict(kwargs)
    operator_kwargs["task_id"] = task_id
    operator_kwargs["python_callable"] = compute_fn
    return PythonOperator(**operator_kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from typing import Any, Dict

from airflow import DAG
from airflow.models import BaseOperator

from dagster_airlift.in_airflow.dagster_operator import build_dagster_task

from ..migration_state import AirflowMigrationState

Expand All @@ -25,21 +28,66 @@ def mark_as_dagster_migrating(
caller_module = global_vars.get("__module__")
suffix = f" in module `{caller_module}`" if caller_module else ""
logger.debug(f"Searching for dags migrating to dagster{suffix}...")
num_dags = 0
for obj in global_vars.values():
if not isinstance(obj, DAG):
continue
dag: DAG = obj
logger.debug(f"Checking dag with id `{dag.dag_id}` for migration state.")
migration_state_for_dag = migration_state.get_migration_dict_for_dag(dag.dag_id)
if migration_state_for_dag is None:
logger.debug(
f"Dag with id `{dag.dag_id} hasn't been marked with migration state. Skipping..."
)
else:
dag.tags.append(json.dumps({"DAGSTER_MIGRATION_STATUS": migration_state_for_dag}))
logger.debug(
f"Dag with id `{dag.dag_id}` has been marked with migration state. Adding state to tags for dag."
dag_vars_to_mark = set()
task_vars_to_migrate = set()
all_dags_by_id = {}
for var, obj in global_vars.items():
if isinstance(obj, DAG):
dag: DAG = obj
if migration_state.dag_is_being_migrated(obj.dag_id):
logger.debug(f"Dag with id `{dag.dag_id}` has migration state.")
dag_vars_to_mark.add(var)
else:
logger.debug(
f"Dag with id `{dag.dag_id} has no associated migration state. Skipping..."
)
all_dags_by_id[obj.dag_id] = obj
if isinstance(obj, BaseOperator) and migration_state.get_migration_state_for_task(
dag_id=obj.dag_id, task_id=obj.task_id
):
task_vars_to_migrate.add(var)

for var in dag_vars_to_mark:
dag: DAG = global_vars[var]
logger.debug(f"Tagging dag {dag.dag_id} as migrating.")
dag.tags.append(
json.dumps(
{"DAGSTER_MIGRATION_STATUS": migration_state.get_migration_dict_for_dag(dag.dag_id)}
)
num_dags += 1
logger.info(f"Marked {num_dags} dags as migrating to dagster{suffix}.")
)
logging.debug(f"Marked {len(dag_vars_to_mark)} dags as migrating to dagster via tag.")

for var in task_vars_to_migrate:
original_op: BaseOperator = global_vars[var]
logger.debug(
f"Creating new operator for task {original_op.task_id} in dag {original_op.dag_id}"
)
new_op = build_dagster_task(
task_id=original_op.task_id,
owner=original_op.owner,
email=original_op.email,
email_on_retry=original_op.email_on_retry,
email_on_failure=original_op.email_on_failure,
retries=original_op.retries,
retry_delay=original_op.retry_delay,
retry_exponential_backoff=original_op.retry_exponential_backoff,
max_retry_delay=original_op.max_retry_delay,
start_date=original_op.start_date,
end_date=original_op.end_date,
depends_on_past=original_op.depends_on_past,
wait_for_downstream=original_op.wait_for_downstream,
params=original_op.params,
doc_md="This task has been migrated to dagster.",
)
original_op.dag.task_dict[original_op.task_id] = new_op

new_op.upstream_task_ids = original_op.upstream_task_ids
new_op.downstream_task_ids = original_op.downstream_task_ids
new_op.dag = original_op.dag
original_op.dag = None
logger.debug(
f"Switching global state var to dagster operator for task {original_op.task_id}."
)
global_vars[var] = new_op
logging.debug(f"Marked {len(task_vars_to_migrate)} tasks as migrating to dagster.")
logging.debug(f"Completed marking dags and tasks as migrating to dagster{suffix}.")
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@ class DagMigrationState(NamedTuple):
class AirflowMigrationState(NamedTuple):
dags: Dict[str, DagMigrationState]

def get_migration_state_for_task(self, dag_id: str, task_id: str) -> bool:
def get_migration_state_for_task(self, dag_id: str, task_id: str) -> Optional[bool]:
if dag_id not in self.dags:
return None
if task_id not in self.dags[dag_id].tasks:
return None
return self.dags[dag_id].tasks[task_id].migrated

def dag_is_being_migrated(self, dag_id: str) -> bool:
return self.get_migration_dict_for_dag(dag_id) is not None

def get_migration_dict_for_dag(self, dag_id: str) -> Optional[Dict[str, Dict[str, Any]]]:
if dag_id not in self.dags:
return None
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import logging
import os
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator
from dagster_airlift.in_airflow import mark_as_dagster_migrating
from dagster_airlift.migration_state import (
AirflowMigrationState,
DagMigrationState,
TaskMigrationState,
)

logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
requests_log = logging.getLogger("requests.packages.urllib3")
requests_log.setLevel(logging.INFO)
requests_log.propagate = True


def write_to_file_in_airflow_home() -> None:
    """Write the marker string "Hello" to $AIRFLOW_HOME/airflow_home_file.txt.

    Used as the python_callable of ``some_task`` so the integration test can
    observe whether the original Airflow task body ran.
    """
    target_path = os.path.join(os.environ["AIRFLOW_HOME"], "airflow_home_file.txt")
    with open(target_path, "w") as handle:
        handle.write("Hello")


def write_to_other_file_in_airflow_home() -> None:
    """Write the marker string "Hello" to $AIRFLOW_HOME/other_airflow_home_file.txt.

    Used as the python_callable of ``other_task``; writes to a distinct file
    so the two tasks' side effects can be told apart.
    """
    target_path = os.path.join(os.environ["AIRFLOW_HOME"], "other_airflow_home_file.txt")
    with open(target_path, "w") as handle:
        handle.write("Hello")


# Shared operator defaults for the test dag; start_date is in the past so a
# triggered run is schedulable immediately.
default_args = {
    "owner": "airflow",
    "depends_on_past": False,
    "start_date": datetime(2023, 1, 1),
    "retries": 1,
}

# schedule_interval=None: the dag only runs when triggered explicitly (the
# integration test triggers it over the REST API).
dag = DAG(
    "the_dag", default_args=default_args, schedule_interval=None, is_paused_upon_creation=False
)
op_to_migrate = PythonOperator(
    task_id="some_task", python_callable=write_to_file_in_airflow_home, dag=dag
)
op_doesnt_migrate = PythonOperator(
    task_id="other_task", python_callable=write_to_other_file_in_airflow_home, dag=dag
)
# Add a dependency between the two tasks
op_doesnt_migrate.set_upstream(op_to_migrate)

# # set up the debugger
# print("Waiting for debugger to attach...")
# debugpy.listen(("localhost", 7778))
# debugpy.wait_for_client()
# Swap every task marked migrated=True for a dagster proxy operator in this
# module's globals. Must run after all DAG/operator definitions above.
# NOTE(review): the variable is named `op_doesnt_migrate` but other_task is
# declared migrated=True below — confirm whether other_task should migrate.
mark_as_dagster_migrating(
    global_vars=globals(),
    migration_state=AirflowMigrationState(
        dags={
            "the_dag": DagMigrationState(
                tasks={
                    "some_task": TaskMigrationState(migrated=True),
                    "other_task": TaskMigrationState(migrated=True),
                }
            )
        }
    ),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from dagster import Definitions, asset


@asset
def the_dag__some_task():
    """Dagster asset standing in for Airflow task `some_task` of dag `the_dag`.

    The `<dag_id>__<task_id>` naming links it to the migrated Airflow task.
    """
    return "asset_value"


# Definitions object loaded by dagster dev as the code location for the test.
defs = Definitions(assets=[the_dag__some_task])
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import time
from pathlib import Path

import pytest
import requests
from dagster import AssetKey, DagsterInstance, DagsterRunStatus
from dagster._core.test_utils import environ
from dagster._time import get_current_timestamp


@pytest.fixture(name="dags_dir")
def setup_dags_dir() -> Path:
    """Point the airflow test harness at this test's dags folder."""
    return Path(__file__).parent / "airflow_op_switcheroo" / "dags"


@pytest.fixture(name="dagster_defs_path")
def setup_dagster_defs_path() -> str:
    """Path to the dagster Definitions module loaded by the dagster_dev fixture."""
    return str(Path(__file__).parent / "airflow_op_switcheroo" / "dagster_defs.py")


def test_migrated_operator(
    airflow_instance: None, dagster_dev: None, dagster_home: str, airflow_home: str
) -> None:
    """Tests that dagster migrated operator can correctly map airflow tasks to dagster tasks, and kick off executions.

    Requires live services from the fixtures: an Airflow webserver on
    localhost:8080 (basic auth admin/admin) and a dagster dev instance.
    """
    # Trigger a run of the_dag via Airflow's stable REST API.
    response = requests.post(
        "http://localhost:8080/api/v1/dags/the_dag/dagRuns", auth=("admin", "admin"), json={}
    )
    assert response.status_code == 200, response.json()
    # Wait until the run enters a terminal state, polling once per second for
    # up to 30 seconds.
    terminal_status = None
    start_time = get_current_timestamp()
    while get_current_timestamp() - start_time < 30:
        response = requests.get(
            "http://localhost:8080/api/v1/dags/the_dag/dagRuns", auth=("admin", "admin")
        )
        assert response.status_code == 200, response.json()
        dag_runs = response.json()["dag_runs"]
        # NOTE(review): dag_runs[0] assumes at least one run is listed; an
        # empty list would raise IndexError instead of a clear assertion.
        if dag_runs[0]["state"] in ["success", "failed"]:
            terminal_status = dag_runs[0]["state"]
            break
        time.sleep(1)
    assert terminal_status == "success", (
        "Never reached terminal status"
        if terminal_status is None
        else f"terminal status was {terminal_status}"
    )
    # Verify on the dagster side that the migrated task produced a run.
    with environ({"DAGSTER_HOME": dagster_home}):
        instance = DagsterInstance.get()
        runs = instance.get_runs()
        # The graphql endpoint kicks off a run for each of the tasks in the dag
        # NOTE(review): only one run is asserted although two tasks are marked
        # migrated; only some_task has a dagster asset defined — confirm intent.
        assert len(runs) == 1
        some_task_run = [  # noqa
            run
            for run in runs
            if set(list(run.asset_selection)) == {AssetKey(["the_dag__some_task"])}  # type: ignore
        ][0]
        assert some_task_run.status == DagsterRunStatus.SUCCESS

0 comments on commit 87da942

Please sign in to comment.