Skip to content

Commit

Permalink
[dagster-airlift] airflow operator switcher
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Aug 9, 2024
1 parent e7baac3 commit 485bdd5
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 20 deletions.
3 changes: 2 additions & 1 deletion examples/experimental/dagster-airlift/airflow_setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ if [[ "$DAGS_FOLDER" != /* ]]; then
exit 1
fi

# Create the airflow.cfg file in $AIRFLOW_HOME
# Create the airflow.cfg file in $AIRFLOW_HOME. We set a super high import timeout so that we can attach a debugger and mess around with the code.
cat <<EOL > $AIRFLOW_HOME/airflow.cfg
[core]
dags_folder = $DAGS_FOLDER
dagbag_import_timeout = 30000
load_examples = False
[api]
auth_backend = airflow.api.auth.backend.basic_auth
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

import requests
from airflow import DAG
from airflow.operators.python import PythonOperator

from .gql_queries import ASSET_NODES_QUERY, RUNS_QUERY, TRIGGER_ASSETS_MUTATION
Expand Down Expand Up @@ -77,5 +76,5 @@ def compute_fn() -> None:
return None


def build_dagster_task(task_id: str, dag: DAG, **kwargs):
return PythonOperator(task_id=task_id, dag=dag, python_callable=compute_fn, **kwargs)
def build_dagster_task(task_id: str, **kwargs):
    """Build an airflow operator that delegates execution of this task to dagster.

    Any additional keyword arguments are forwarded verbatim to the underlying
    ``PythonOperator`` constructor.
    """
    forwarded_kwargs = dict(kwargs)
    return PythonOperator(
        task_id=task_id,
        python_callable=compute_fn,
        **forwarded_kwargs,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import sys

from airflow import DAG
from airflow.models import BaseOperator

from dagster_airlift.in_airflow.dagster_operator import build_dagster_task

from ..migration_state import AirflowMigrationState

Expand All @@ -15,20 +18,55 @@ def mark_as_dagster_migrating(
"""
# get global context from above frame
global_vars = sys._getframe(1).f_globals # noqa: SLF001
dag_vars_to_mark = set()
task_vars_to_migrate = set()
all_dags_by_id = {}
for var, obj in global_vars.items():
if isinstance(obj, DAG):
if migration_state.dag_is_being_migrated(obj.dag_id):
dag_vars_to_mark.add(var)
all_dags_by_id[obj.dag_id] = obj
if isinstance(obj, BaseOperator) and migration_state.get_migration_state_for_task(
dag_id=obj.dag_id, task_id=obj.task_id
):
task_vars_to_migrate.add(var)

for var in dag_vars_to_mark:
dag: DAG = global_vars[var]
print(f"Tagging dag {dag.dag_id} as migrating.") # noqa: T201
dag.tags.append(
json.dumps(
{"DAGSTER_MIGRATION_STATUS": migration_state.get_migration_dict_for_dag(dag.dag_id)}
)
)

any_dags_marked = False
for obj in global_vars.values():
if not isinstance(obj, DAG):
continue
dag: DAG = obj
migration_status = migration_state.get_migration_dict_for_dag(dag.dag_id)
# If there are migrated tasks, then add a tag to the dag to indicate that it is migrating.
if migration_status is not None:
any_dags_marked = True
dag.tags.append(json.dumps({"DAGSTER_MIGRATION_STATUS": migration_status}))

if not any_dags_marked:
# Should we warn here?
raise Exception(
"No dags were marked as migrating. This is likely an error in the migration state file."
for var in task_vars_to_migrate:
original_op: BaseOperator = global_vars[var]
# TODO: make this constructor resilient to signature changes across Airflow versions.
print(f"Creating new operator for task {original_op.task_id} in dag {original_op.dag_id}") # noqa: T201
# First, flush the existing operator from the dag.
new_op = build_dagster_task(
task_id=original_op.task_id,
owner=original_op.owner,
email=original_op.email,
email_on_retry=original_op.email_on_retry,
email_on_failure=original_op.email_on_failure,
retries=original_op.retries,
retry_delay=original_op.retry_delay,
retry_exponential_backoff=original_op.retry_exponential_backoff,
max_retry_delay=original_op.max_retry_delay,
start_date=original_op.start_date,
end_date=original_op.end_date,
depends_on_past=original_op.depends_on_past,
wait_for_downstream=original_op.wait_for_downstream,
params=original_op.params,
doc_md="This task has been migrated to dagster.",
)
original_op.dag.task_dict[original_op.task_id] = new_op

new_op.upstream_task_ids = original_op.upstream_task_ids
new_op.downstream_task_ids = original_op.downstream_task_ids
new_op.dag = original_op.dag
original_op.dag = None
print(f"Switching global state var to dagster operator for {var}.") # noqa: T201
global_vars[var] = new_op
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@ class DagMigrationState(NamedTuple):
class AirflowMigrationState(NamedTuple):
dags: Dict[str, DagMigrationState]

def get_migration_state_for_task(self, dag_id: str, task_id: str) -> bool:
def get_migration_state_for_task(self, dag_id: str, task_id: str) -> Optional[bool]:
    """Return the task's ``migrated`` flag, or ``None`` when the dag or task is unknown."""
    dag_state = self.dags.get(dag_id)
    if dag_state is None:
        return None
    task_state = dag_state.tasks.get(task_id)
    if task_state is None:
        return None
    return task_state.migrated

def dag_is_being_migrated(self, dag_id: str) -> bool:
    """True when any migration state exists for ``dag_id``."""
    migration_dict = self.get_migration_dict_for_dag(dag_id)
    return migration_dict is not None

def get_migration_dict_for_dag(self, dag_id: str) -> Optional[Dict[str, Dict[str, Any]]]:
if dag_id not in self.dags:
return None
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import logging
import os
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator
from dagster_airlift.in_airflow import mark_as_dagster_migrating
from dagster_airlift.migration_state import (
AirflowMigrationState,
DagMigrationState,
TaskMigrationState,
)

# Configure verbose logging so activity from this dag file (and the HTTP calls
# made by the dagster operator) is visible in the airflow scheduler/worker logs.
logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
# Surface urllib3/requests wire-level logs as well.
requests_log = logging.getLogger("requests.packages.urllib3")
requests_log.setLevel(logging.INFO)
requests_log.propagate = True


def write_to_file_in_airflow_home() -> None:
    """Write a marker file into $AIRFLOW_HOME so a test can observe the task ran."""
    target_path = os.path.join(os.environ["AIRFLOW_HOME"], "airflow_home_file.txt")
    with open(target_path, "w") as f:
        f.write("Hello")


def write_to_other_file_in_airflow_home() -> None:
    """Write a second marker file into $AIRFLOW_HOME so a test can observe this task ran."""
    target_path = os.path.join(os.environ["AIRFLOW_HOME"], "other_airflow_home_file.txt")
    with open(target_path, "w") as f:
        f.write("Hello")


# Default arguments applied to every task in the dag below.
default_args = {
    "owner": "airflow",
    "depends_on_past": False,
    "start_date": datetime(2023, 1, 1),
    "retries": 1,
}

# Unpaused on creation so the test suite can trigger a run immediately via the REST API.
dag = DAG(
    "the_dag", default_args=default_args, schedule_interval=None, is_paused_upon_creation=False
)
# Task expected to be swapped out for a dagster operator by mark_as_dagster_migrating.
op_to_migrate = PythonOperator(
    task_id="some_task", python_callable=write_to_file_in_airflow_home, dag=dag
)
# NOTE(review): despite the variable name, this task is also marked migrated=True
# below — confirm whether the name or the migration state is the intended one.
op_doesnt_migrate = PythonOperator(
    task_id="other_task", python_callable=write_to_other_file_in_airflow_home, dag=dag
)
# Add a dependency between the two tasks
op_doesnt_migrate.set_upstream(op_to_migrate)

# # set up the debugger
# print("Waiting for debugger to attach...")
# debugpy.listen(("localhost", 7778))
# debugpy.wait_for_client()
# Replace each task whose migration state is migrated=True with a dagster-invoking operator.
mark_as_dagster_migrating(
    migration_state=AirflowMigrationState(
        dags={
            "the_dag": DagMigrationState(
                tasks={
                    "some_task": TaskMigrationState(migrated=True),
                    "other_task": TaskMigrationState(migrated=True),
                }
            )
        }
    )
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from dagster import Definitions, asset


@asset
def the_dag__some_task():
    """Dagster asset standing in for the migrated airflow task ``some_task`` in ``the_dag``."""
    asset_value = "asset_value"
    return asset_value


# Definitions object loaded by the dagster webserver for this test.
defs = Definitions(assets=[the_dag__some_task])
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import time
from pathlib import Path

import pytest
import requests
from dagster import AssetKey, DagsterInstance, DagsterRunStatus
from dagster._core.test_utils import environ
from dagster._time import get_current_timestamp


@pytest.fixture(name="dags_dir")
def setup_dags_dir() -> Path:
    """Directory containing the airflow dags exercised by this test module."""
    return Path(__file__).parent.joinpath("airflow_op_switcheroo", "dags")


@pytest.fixture(name="dagster_defs_path")
def setup_dagster_defs_path() -> str:
    """Absolute path (as a string) to the dagster definitions file under test."""
    defs_file = Path(__file__).parent.joinpath("airflow_op_switcheroo", "dagster_defs.py")
    return str(defs_file)


def test_migrated_operator(
    airflow_instance: None, dagster_dev: None, dagster_home: str, airflow_home: str
) -> None:
    """Tests that dagster migrated operator can correctly map airflow tasks to dagster tasks, and kick off executions."""
    # Trigger a run of the_dag via the airflow REST API (basic auth credentials
    # presumably set up by the airflow_instance fixture — confirm against conftest).
    response = requests.post(
        "http://localhost:8080/api/v1/dags/the_dag/dagRuns", auth=("admin", "admin"), json={}
    )
    assert response.status_code == 200, response.json()
    # Wait until the run enters a terminal state
    terminal_status = None
    start_time = get_current_timestamp()
    # Poll once per second for up to 30 seconds.
    while get_current_timestamp() - start_time < 30:
        response = requests.get(
            "http://localhost:8080/api/v1/dags/the_dag/dagRuns", auth=("admin", "admin")
        )
        assert response.status_code == 200, response.json()
        dag_runs = response.json()["dag_runs"]
        # Only one run was triggered, so inspect the first entry.
        if dag_runs[0]["state"] in ["success", "failed"]:
            terminal_status = dag_runs[0]["state"]
            break
        time.sleep(1)
    assert terminal_status == "success", (
        "Never reached terminal status"
        if terminal_status is None
        else f"terminal status was {terminal_status}"
    )
    # Point DagsterInstance at the test's DAGSTER_HOME to inspect runs the operator launched.
    with environ({"DAGSTER_HOME": dagster_home}):
        instance = DagsterInstance.get()
        runs = instance.get_runs()
        # The graphql endpoint kicks off a run for each of the tasks in the dag
        assert len(runs) == 1
        # Find the run that materialized the asset mapped from some_task.
        some_task_run = [  # noqa
            run
            for run in runs
            if set(list(run.asset_selection)) == {AssetKey(["the_dag__some_task"])}  # type: ignore
        ][0]
        assert some_task_run.status == DagsterRunStatus.SUCCESS
1 change: 1 addition & 0 deletions examples/experimental/dagster-airlift/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ deps =
-e ../../../python_modules/dagster-test
-e ../../../python_modules/dagster-pipes
-e ../../../python_modules/dagster-webserver
-e ../../../python_modules/dagster-graphql
-e ../../../python_modules/libraries/dagster-dbt
-e .[core,mwaa,dbt,test,in-airflow]
dbt-duckdb
Expand Down

0 comments on commit 485bdd5

Please sign in to comment.