[dagster-airlift] airflow operator switcher
dpeng817 committed Aug 12, 2024
1 parent dde6733 commit 6c722d4
Showing 10 changed files with 300 additions and 33 deletions.
@@ -366,6 +366,12 @@ def k8s_extra_cmds(version: str, _) -> List[str]:
AvailablePythonVersion.V3_12,
],
),
PackageSpec(
"examples/experimental/dagster-airlift/examples/simple-migration",
unsupported_python_versions=[
AvailablePythonVersion.V3_12,
],
),
]


3 changes: 2 additions & 1 deletion examples/experimental/dagster-airlift/airflow_setup.sh
@@ -14,10 +14,11 @@ if [[ "$DAGS_FOLDER" != /* ]]; then
exit 1
fi

# Create the airflow.cfg file in $AIRFLOW_HOME
# Create the airflow.cfg file in $AIRFLOW_HOME. We set a very high import timeout so that a debugger can be attached while DAGs are being parsed.
cat <<EOL > $AIRFLOW_HOME/airflow.cfg
[core]
dags_folder = $DAGS_FOLDER
dagbag_import_timeout = 30000
load_examples = False
[api]
auth_backend = airflow.api.auth.backend.basic_auth
@@ -442,5 +442,5 @@ def _get_migration_state_for_task(
migration_state: Optional[AirflowMigrationState], dag_id: str, task_id: str
) -> bool:
if migration_state:
return migration_state.get_migration_state_for_task(dag_id, task_id)
return migration_state.get_migration_state_for_task(dag_id, task_id) or False
return False
@@ -1,21 +1,30 @@
import inspect
import logging
import os
from typing import Any, Callable, Dict, Set, Tuple

import requests
from airflow import DAG
from airflow.models.operator import BaseOperator
from airflow.operators.python import PythonOperator

from .gql_queries import ASSET_NODES_QUERY, RUNS_QUERY, TRIGGER_ASSETS_MUTATION

logger = logging.getLogger(__name__)


def compute_fn() -> None:
# https://github.com/apache/airflow/discussions/24463
os.environ["NO_PROXY"] = "*"
dag_id = os.environ["AIRFLOW_CTX_DAG_ID"]
task_id = os.environ["AIRFLOW_CTX_TASK_ID"]
dagster_url = os.environ["DAGSTER_URL"]
return launch_runs_for_task(dag_id, task_id, dagster_url)


def launch_runs_for_task(dag_id: str, task_id: str, dagster_url: str) -> None:
expected_op_name = f"{dag_id}__{task_id}"
assets_to_trigger = {} # key is (repo_location, repo_name, job_name), value is list of asset keys
# Query the Dagster GraphQL API for asset nodes. `dagster_url` is already provided as a parameter.
response = requests.post(f"{dagster_url}/graphql", json={"query": ASSET_NODES_QUERY}, timeout=3)
for asset_node in response.json()["data"]["assetNodes"]:
if asset_node["opName"] == expected_op_name:
@@ -27,7 +36,7 @@ def compute_fn() -> None:
assets_to_trigger[(repo_location, repo_name, job_name)].append(
asset_node["assetKey"]["path"]
)
print(f"Found assets to trigger: {assets_to_trigger}") # noqa: T201
logger.debug(f"Found assets to trigger: {assets_to_trigger}")
triggered_runs = []
for (repo_location, repo_name, job_name), asset_keys in assets_to_trigger.items():
execution_params = {
@@ -42,7 +51,9 @@ def compute_fn() -> None:
"assetCheckSelection": [],
},
}
print(f"Triggering run for {repo_location}/{repo_name}/{job_name} with assets {asset_keys}") # noqa: T201
logger.debug(
f"Triggering run for {repo_location}/{repo_name}/{job_name} with assets {asset_keys}"
)
response = requests.post(
f"{dagster_url}/graphql",
json={
@@ -52,7 +63,7 @@ def compute_fn() -> None:
timeout=3,
)
run_id = response.json()["data"]["launchPipelineExecution"]["run"]["id"]
print(f"Launched run {run_id}...") # noqa: T201
logger.debug(f"Launched run {run_id}...")
triggered_runs.append(run_id)
completed_runs = {} # key is run_id, value is status
while len(completed_runs) < len(triggered_runs):
@@ -66,16 +77,83 @@ def compute_fn() -> None:
)
run_status = response.json()["data"]["runOrError"]["status"]
if run_status in ["SUCCESS", "FAILURE", "CANCELED"]:
print(f"Run {run_id} completed with status {run_status}") # noqa: T201
logger.debug(f"Run {run_id} completed with status {run_status}")
completed_runs[run_id] = run_status
non_successful_runs = [
run_id for run_id, status in completed_runs.items() if status != "SUCCESS"
]
if non_successful_runs:
raise Exception(f"Runs {non_successful_runs} did not complete successfully.")
print("All runs completed successfully.") # noqa: T201
logger.debug("All runs completed successfully.")
return None


def build_dagster_task(task_id: str, dag: DAG, **kwargs):
return PythonOperator(task_id=task_id, dag=dag, python_callable=compute_fn, **kwargs)
class DagsterOperator(PythonOperator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, python_callable=compute_fn)


def build_dagster_task(original_task: BaseOperator) -> DagsterOperator:
return instantiate_dagster_operator(original_task)


def instantiate_dagster_operator(original_task: BaseOperator) -> DagsterOperator:
"""Instantiates a DagsterOperator as a copy of the provided airflow task.
We attempt to copy as many of the original task's attributes as possible, while respecting
that attributes may change between airflow versions. In order to do this, we inspect the
arguments available to the BaseOperator constructor and copy over any of those arguments that
are available as attributes on the original task.
This approach has limitations:
- If the task attribute is transformed and stored on another property, it will not be copied.
- If the task attribute is transformed in a way that makes it incompatible with the constructor arg
and stored in the same property, that will attempt to be copied and potentiall break initialization.
In the future, if we hit problems with this, we may need to add argument overrides to ensure we either
attempt to include certain additional attributes, or exclude others. If this continues to be a problem
across airflow versions, it may be necessary to revise this approach to one that explicitly maps airflow
version to a set of expected arguments and attributes.
"""
base_operator_args, base_operator_args_with_defaults = get_params(BaseOperator.__init__)
init_kwargs = {}

ignore_args = ["kwargs", "args", "dag"]
for arg in base_operator_args:
if arg in ignore_args or getattr(original_task, arg, None) is None:
continue
init_kwargs[arg] = getattr(original_task, arg)
for kwarg, default in base_operator_args_with_defaults.items():
if kwarg in ignore_args or getattr(original_task, kwarg, None) is None:
continue
init_kwargs[kwarg] = getattr(original_task, kwarg, default)

return DagsterOperator(**init_kwargs)
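
For reference, a minimal usage sketch of the copying approach described in the docstring above, assuming the build_dagster_task and compute_fn from this diff; the example task itself is hypothetical:

from airflow.operators.python import PythonOperator

# Hypothetical standalone task; with no DAG attached, `dag` and `task_group`
# are skipped by the ignore-list and None checks above.
original = PythonOperator(
    task_id="some_task",
    python_callable=lambda: None,
    retries=2,
    owner="data-team",
)

dagster_op = build_dagster_task(original)
# BaseOperator constructor args found on the original are copied over;
# python_callable is not a BaseOperator arg, so DagsterOperator supplies compute_fn.
assert dagster_op.task_id == "some_task"
assert dagster_op.retries == 2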


def get_params(func: Callable[..., Any]) -> Tuple[Set[str], Dict[str, Any]]:
"""Retrieves the args and kwargs from the signature of a given function or method.
For kwargs, default values are retrieved as well.
Args:
func (Callable[..., Any]): The function or method to inspect.
Returns:
Tuple[Set[str], Dict[str, Any]]:
- A set of argument names that do not have default values.
- A dictionary of keyword argument names and their default values.
"""
# Get the function's signature
sig = inspect.signature(func)

# Initialize sets for args without defaults and kwargs with defaults
args_with_defaults = {}
args = set()

# Iterate over function parameters
for name, param in sig.parameters.items():
if param.default is inspect.Parameter.empty and name != "self": # Exclude 'self'
args.add(name)
else:
if name != "self": # Exclude 'self'
args_with_defaults[name] = param.default

return args, args_with_defaults
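
And a quick sketch of what get_params returns, using a hypothetical function:

def example(a: int, b: str, c: int = 1, d: str = "x") -> None:  # hypothetical
    pass

args, kwargs_with_defaults = get_params(example)
assert args == {"a", "b"}
assert kwargs_with_defaults == {"c": 1, "d": "x"}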
@@ -1,8 +1,11 @@
import json
import logging
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from airflow import DAG
from airflow.models import BaseOperator

from dagster_airlift.in_airflow.dagster_operator import build_dagster_task

from ..migration_state import AirflowMigrationState

@@ -29,21 +32,60 @@ def mark_as_dagster_migrating(
if not logger:
logger = logging.getLogger("dagster_airlift")
logger.debug(f"Searching for dags migrating to dagster{suffix}...")
num_dags = 0
migrating_dags: List[DAG] = []
# Do a pass to collect dags and ensure that migration information is set correctly.
for obj in global_vars.values():
if not isinstance(obj, DAG):
continue
dag: DAG = obj
logger.debug(f"Checking dag with id `{dag.dag_id}` for migration state.")
migration_state_for_dag = migration_state.get_migration_dict_for_dag(dag.dag_id)
if migration_state_for_dag is None:
logger.debug(
f"Dag with id `{dag.dag_id}` hasn't been marked with migration state. Skipping..."
)
if not migration_state.dag_has_migration_state(dag.dag_id):
logger.debug(f"Dag with id `{dag.dag_id}` has no migration state. Skipping...")
continue
logger.debug(f"Dag with id `{dag.dag_id}` has migration state.")
migration_state_for_dag = migration_state.dags[dag.dag_id]
for task_id in migration_state_for_dag.tasks.keys():
if task_id not in dag.task_dict:
raise Exception(
f"Task with id `{task_id}` not found in dag `{dag.dag_id}`. Found tasks: {list(dag.task_dict.keys())}"
)
if not isinstance(dag.task_dict[task_id], BaseOperator):
raise Exception(
f"Task with id `{task_id}` in dag `{dag.dag_id}` is not an instance of BaseOperator. This likely means a MappedOperator was attempted, which is not yet supported by airlift."
)
migrating_dags.append(dag)

for dag in migrating_dags:
logger.debug(f"Tagging dag {dag.dag_id} as migrating.")
else:
dag.tags.append(json.dumps({"DAGSTER_MIGRATION_STATUS": migration_state_for_dag}))
dag.tags.append(
json.dumps(
{"DAGSTER_MIGRATION_STATUS": migration_state.get_migration_dict_for_dag(dag.dag_id)}
)
)
migration_state_for_dag = migration_state.dags[dag.dag_id]
migrated_tasks = set()
for task_id, task_state in migration_state_for_dag.tasks.items():
if not task_state.migrated:
logger.debug(
f"Task {task_id} in dag {dag.dag_id} has `migrated` set to False. Skipping..."
)
continue

# At this point, we should be assured that the task exists within the task_dict of the dag, and is a BaseOperator.
original_op: BaseOperator = dag.task_dict[task_id] # type: ignore # we already confirmed this is BaseOperator
del dag.task_dict[task_id]
if original_op.task_group is not None:
del original_op.task_group.children[task_id]
logger.debug(
f"Dag with id `{dag.dag_id}` has been marked with migration state. Adding state to tags for dag."
)
num_dags += 1
logger.info(f"Marked {num_dags} dags as migrating to dagster{suffix}.")
logger.debug(
f"Creating new operator for task {original_op.task_id} in dag {original_op.dag_id}"
)
new_op = build_dagster_task(original_op)
original_op.dag.task_dict[original_op.task_id] = new_op

new_op.upstream_task_ids = original_op.upstream_task_ids
new_op.downstream_task_ids = original_op.downstream_task_ids
new_op.dag = original_op.dag
original_op.dag = None
migrated_tasks.add(task_id)
logger.debug(f"Migrated tasks {migrated_tasks} in dag {dag.dag_id}.")
logging.debug(f"Migrated {len(migrating_dags)}.")
logging.debug(f"Completed marking dags and tasks as migrating to dagster{suffix}.")
@@ -15,9 +15,16 @@ class DagMigrationState(NamedTuple):
class AirflowMigrationState(NamedTuple):
dags: Dict[str, DagMigrationState]

def get_migration_state_for_task(self, dag_id: str, task_id: str) -> bool:
def get_migration_state_for_task(self, dag_id: str, task_id: str) -> Optional[bool]:
if dag_id not in self.dags:
return None
if task_id not in self.dags[dag_id].tasks:
return None
return self.dags[dag_id].tasks[task_id].migrated

def dag_has_migration_state(self, dag_id: str) -> bool:
return self.get_migration_dict_for_dag(dag_id) is not None

def get_migration_dict_for_dag(self, dag_id: str) -> Optional[Dict[str, Dict[str, Any]]]:
if dag_id not in self.dags:
return None
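
A small sketch of how these lookups behave, using the same NamedTuples as the example DAG later in this commit, and assuming get_migration_dict_for_dag returns a dict for known dags:

state = AirflowMigrationState(
    dags={
        "the_dag": DagMigrationState(tasks={"some_task": TaskMigrationState(migrated=True)})
    }
)
assert state.get_migration_state_for_task("the_dag", "some_task") is True
assert state.get_migration_state_for_task("the_dag", "missing_task") is None  # unknown task
assert state.get_migration_state_for_task("missing_dag", "some_task") is None  # unknown dag
assert state.dag_has_migration_state("the_dag")
assert not state.dag_has_migration_state("missing_dag")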
@@ -0,0 +1,68 @@
import logging
import os
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator
from dagster_airlift.in_airflow import mark_as_dagster_migrating
from dagster_airlift.migration_state import (
AirflowMigrationState,
DagMigrationState,
TaskMigrationState,
)

logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
requests_log = logging.getLogger("requests.packages.urllib3")
requests_log.setLevel(logging.INFO)
requests_log.propagate = True


def write_to_file_in_airflow_home() -> None:
airflow_home = os.environ["AIRFLOW_HOME"]
with open(os.path.join(airflow_home, "airflow_home_file.txt"), "w") as f:
f.write("Hello")


def write_to_other_file_in_airflow_home() -> None:
airflow_home = os.environ["AIRFLOW_HOME"]
with open(os.path.join(airflow_home, "other_airflow_home_file.txt"), "w") as f:
f.write("Hello")


default_args = {
"owner": "airflow",
"depends_on_past": False,
"start_date": datetime(2023, 1, 1),
"retries": 1,
}

dag = DAG(
"the_dag", default_args=default_args, schedule_interval=None, is_paused_upon_creation=False
)
op_to_migrate = PythonOperator(
task_id="some_task", python_callable=write_to_file_in_airflow_home, dag=dag
)
op_doesnt_migrate = PythonOperator(
task_id="other_task", python_callable=write_to_other_file_in_airflow_home, dag=dag
)
# Add a dependency between the two tasks
op_doesnt_migrate.set_upstream(op_to_migrate)

# Uncomment to attach a debugger (requires `import debugpy`):
# print("Waiting for debugger to attach...")
# debugpy.listen(("localhost", 7778))
# debugpy.wait_for_client()
mark_as_dagster_migrating(
global_vars=globals(),
migration_state=AirflowMigrationState(
dags={
"the_dag": DagMigrationState(
tasks={
"some_task": TaskMigrationState(migrated=True),
"other_task": TaskMigrationState(migrated=True),
}
)
}
),
)
@@ -0,0 +1,9 @@
from dagster import Definitions, asset


@asset
def the_dag__some_task():
return "asset_value"


defs = Definitions(assets=[the_dag__some_task])
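
Note that the asset name follows the {dag_id}__{task_id} convention that launch_runs_for_task matches on, so this asset corresponds to some_task in the_dag:

# Matching logic from dagster_operator.py earlier in this commit:
dag_id, task_id = "the_dag", "some_task"
expected_op_name = f"{dag_id}__{task_id}"
assert expected_op_name == "the_dag__some_task"  # equals the asset/op name above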