From 6c722d425ff4d085cc4a93374a1eaa1aa6682f6f Mon Sep 17 00:00:00 2001 From: Chris DeCarolis Date: Fri, 2 Aug 2024 17:01:21 -0700 Subject: [PATCH] [dagster-airlift] airflow operator switcher --- .../dagster_buildkite/steps/packages.py | 6 ++ .../dagster-airlift/airflow_setup.sh | 3 +- .../core/airflow_cacheable_assets_def.py | 2 +- .../in_airflow/dagster_operator.py | 96 +++++++++++++++++-- .../in_airflow/mark_as_migrating.py | 66 ++++++++++--- .../dagster_airlift/migration_state.py | 9 +- .../dags/switcheroo_dag.py | 68 +++++++++++++ .../airflow_op_switcheroo/dagster_defs.py | 9 ++ .../dags/migrated_dag.py | 17 ++-- .../test_operator_switcheroo.py | 57 +++++++++++ 10 files changed, 300 insertions(+), 33 deletions(-) create mode 100644 examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo/dags/switcheroo_dag.py create mode 100644 examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo/dagster_defs.py create mode 100644 examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/test_operator_switcheroo.py diff --git a/.buildkite/dagster-buildkite/dagster_buildkite/steps/packages.py b/.buildkite/dagster-buildkite/dagster_buildkite/steps/packages.py index 77f3ee213fe11..cdb5753daede2 100644 --- a/.buildkite/dagster-buildkite/dagster_buildkite/steps/packages.py +++ b/.buildkite/dagster-buildkite/dagster_buildkite/steps/packages.py @@ -366,6 +366,12 @@ def k8s_extra_cmds(version: str, _) -> List[str]: AvailablePythonVersion.V3_12, ], ), + PackageSpec( + "examples/experimental/dagster-airlift/examples/simple-migration", + unsupported_python_versions=[ + AvailablePythonVersion.V3_12, + ], + ), ] diff --git a/examples/experimental/dagster-airlift/airflow_setup.sh b/examples/experimental/dagster-airlift/airflow_setup.sh index 0312e5ddd5d82..8d6bc4422898c 100755 --- a/examples/experimental/dagster-airlift/airflow_setup.sh +++ b/examples/experimental/dagster-airlift/airflow_setup.sh @@ -14,10 +14,11 @@ if [[ "$DAGS_FOLDER" != /* ]]; then exit 1 fi -# Create the airflow.cfg file in $AIRFLOW_HOME +# Create the airflow.cfg file in $AIRFLOW_HOME. We set a super high import timeout so that we can attach a debugger and mess around with the code. cat < $AIRFLOW_HOME/airflow.cfg [core] dags_folder = $DAGS_FOLDER +dagbag_import_timeout = 30000 load_examples = False [api] auth_backend = airflow.api.auth.backend.basic_auth diff --git a/examples/experimental/dagster-airlift/dagster_airlift/core/airflow_cacheable_assets_def.py b/examples/experimental/dagster-airlift/dagster_airlift/core/airflow_cacheable_assets_def.py index b8969fcfe5dab..fe3b1da77b671 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/core/airflow_cacheable_assets_def.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/core/airflow_cacheable_assets_def.py @@ -442,5 +442,5 @@ def _get_migration_state_for_task( migration_state: Optional[AirflowMigrationState], dag_id: str, task_id: str ) -> bool: if migration_state: - return migration_state.get_migration_state_for_task(dag_id, task_id) + return migration_state.get_migration_state_for_task(dag_id, task_id) or False return False diff --git a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/dagster_operator.py b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/dagster_operator.py index 5f26b2acf214f..2f3fca60a0a13 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/dagster_operator.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/dagster_operator.py @@ -1,21 +1,30 @@ +import inspect +import logging import os +from typing import Any, Callable, Dict, Set, Tuple import requests -from airflow import DAG +from airflow.models.operator import BaseOperator from airflow.operators.python import PythonOperator from .gql_queries import ASSET_NODES_QUERY, RUNS_QUERY, TRIGGER_ASSETS_MUTATION +logger = logging.getLogger(__name__) + def compute_fn() -> None: # https://github.com/apache/airflow/discussions/24463 os.environ["NO_PROXY"] = "*" dag_id = os.environ["AIRFLOW_CTX_DAG_ID"] task_id = os.environ["AIRFLOW_CTX_TASK_ID"] + dagster_url = os.environ["DAGSTER_URL"] + return launch_runs_for_task(dag_id, task_id, dagster_url) + + +def launch_runs_for_task(dag_id: str, task_id: str, dagster_url: str) -> None: expected_op_name = f"{dag_id}__{task_id}" assets_to_trigger = {} # key is (repo_location, repo_name, job_name), value is list of asset keys # create graphql client - dagster_url = os.environ["DAGSTER_URL"] response = requests.post(f"{dagster_url}/graphql", json={"query": ASSET_NODES_QUERY}, timeout=3) for asset_node in response.json()["data"]["assetNodes"]: if asset_node["opName"] == expected_op_name: @@ -27,7 +36,7 @@ def compute_fn() -> None: assets_to_trigger[(repo_location, repo_name, job_name)].append( asset_node["assetKey"]["path"] ) - print(f"Found assets to trigger: {assets_to_trigger}") # noqa: T201 + logger.debug(f"Found assets to trigger: {assets_to_trigger}") triggered_runs = [] for (repo_location, repo_name, job_name), asset_keys in assets_to_trigger.items(): execution_params = { @@ -42,7 +51,9 @@ def compute_fn() -> None: "assetCheckSelection": [], }, } - print(f"Triggering run for {repo_location}/{repo_name}/{job_name} with assets {asset_keys}") # noqa: T201 + logger.debug( + f"Triggering run for {repo_location}/{repo_name}/{job_name} with assets {asset_keys}" + ) response = requests.post( f"{dagster_url}/graphql", json={ @@ -52,7 +63,7 @@ def compute_fn() -> None: timeout=3, ) run_id = response.json()["data"]["launchPipelineExecution"]["run"]["id"] - print(f"Launched run {run_id}...") # noqa: T201 + logger.debug(f"Launched run {run_id}...") triggered_runs.append(run_id) completed_runs = {} # key is run_id, value is status while len(completed_runs) < len(triggered_runs): @@ -66,16 +77,83 @@ def compute_fn() -> None: ) run_status = response.json()["data"]["runOrError"]["status"] if run_status in ["SUCCESS", "FAILURE", "CANCELED"]: - print(f"Run {run_id} completed with status {run_status}") # noqa: T201 + logger.debug(f"Run {run_id} completed with status {run_status}") completed_runs[run_id] = run_status non_successful_runs = [ run_id for run_id, status in completed_runs.items() if status != "SUCCESS" ] if non_successful_runs: raise Exception(f"Runs {non_successful_runs} did not complete successfully.") - print("All runs completed successfully.") # noqa: T201 + logger.debug("All runs completed successfully.") return None -def build_dagster_task(task_id: str, dag: DAG, **kwargs): - return PythonOperator(task_id=task_id, dag=dag, python_callable=compute_fn, **kwargs) +class DagsterOperator(PythonOperator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, python_callable=compute_fn) + + +def build_dagster_task(original_task: BaseOperator) -> DagsterOperator: + return instantiate_dagster_operator(original_task) + + +def instantiate_dagster_operator(original_task: BaseOperator) -> DagsterOperator: + """Instantiates a DagsterOperator as a copy of the provided airflow task. + + We attempt to copy as many of the original task's attributes as possible, while respecting + that attributes may change between airflow versions. In order to do this, we inspect the + arguments available to the BaseOperator constructor and copy over any of those arguments that + are available as attributes on the original task. + This approach has limitations: + - If the task attribute is transformed and stored on another property, it will not be copied. + - If the task attribute is transformed in a way that makes it incompatible with the constructor arg + and stored in the same property, that will attempt to be copied and potentiall break initialization. + In the future, if we hit problems with this, we may need to add argument overrides to ensure we either + attempt to include certain additional attributes, or exclude others. If this continues to be a problem + across airflow versions, it may be necessary to revise this approach to one that explicitly maps airflow + version to a set of expected arguments and attributes. + """ + base_operator_args, base_operator_args_with_defaults = get_params(BaseOperator.__init__) + init_kwargs = {} + + ignore_args = ["kwargs", "args", "dag"] + for arg in base_operator_args: + if arg in ignore_args or getattr(original_task, arg, None) is None: + continue + init_kwargs[arg] = getattr(original_task, arg) + for kwarg, default in base_operator_args_with_defaults.items(): + if kwarg in ignore_args or getattr(original_task, kwarg, None) is None: + continue + init_kwargs[kwarg] = getattr(original_task, kwarg, default) + + return DagsterOperator(**init_kwargs) + + +def get_params(func: Callable[..., Any]) -> Tuple[Set[str], Dict[str, Any]]: + """Retrieves the args and kwargs from the signature of a given function or method. + For kwargs, default values are retrieved as well. + + Args: + func (Callable[..., Any]): The function or method to inspect. + + Returns: + Tuple[Set[str], Dict[str, Any]]: + - A set of argument names that do not have default values. + - A dictionary of keyword argument names and their default values. + """ + # Get the function's signature + sig = inspect.signature(func) + + # Initialize sets for args without defaults and kwargs with defaults + args_with_defaults = {} + args = set() + + # Iterate over function parameters + for name, param in sig.parameters.items(): + if param.default is inspect.Parameter.empty and name != "self": # Exclude 'self' + args.add(name) + else: + if name != "self": # Exclude 'self' + args_with_defaults[name] = param.default + + return args, args_with_defaults diff --git a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/mark_as_migrating.py b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/mark_as_migrating.py index b5fbfd1285edc..17304b9fa1782 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/mark_as_migrating.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/mark_as_migrating.py @@ -1,8 +1,11 @@ import json import logging -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from airflow import DAG +from airflow.models import BaseOperator + +from dagster_airlift.in_airflow.dagster_operator import build_dagster_task from ..migration_state import AirflowMigrationState @@ -29,21 +32,60 @@ def mark_as_dagster_migrating( if not logger: logger = logging.getLogger("dagster_airlift") logger.debug(f"Searching for dags migrating to dagster{suffix}...") - num_dags = 0 + migrating_dags: List[DAG] = [] + # Do a pass to collect dags and ensure that migration information is set correctly. for obj in global_vars.values(): if not isinstance(obj, DAG): continue dag: DAG = obj - logger.debug(f"Checking dag with id `{dag.dag_id}` for migration state.") - migration_state_for_dag = migration_state.get_migration_dict_for_dag(dag.dag_id) - if migration_state_for_dag is None: - logger.debug( - f"Dag with id `{dag.dag_id} hasn't been marked with migration state. Skipping..." + if not migration_state.dag_has_migration_state(dag.dag_id): + logger.debug(f"Dag with id `{dag.dag_id}` has no migration state. Skipping...") + continue + logger.debug(f"Dag with id `{dag.dag_id}` has migration state.") + migration_state_for_dag = migration_state.dags[dag.dag_id] + for task_id in migration_state_for_dag.tasks.keys(): + if task_id not in dag.task_dict: + raise Exception( + f"Task with id `{task_id}` not found in dag `{dag.dag_id}`. Found tasks: {list(dag.task_dict.keys())}" + ) + if not isinstance(dag.task_dict[task_id], BaseOperator): + raise Exception( + f"Task with id `{task_id}` in dag `{dag.dag_id}` is not an instance of BaseOperator. This likely means a MappedOperator was attempted, which is not yet supported by airlift." + ) + migrating_dags.append(dag) + + for dag in migrating_dags: + logger.debug(f"Tagging dag {dag.dag_id} as migrating.") + dag.tags.append( + json.dumps( + {"DAGSTER_MIGRATION_STATUS": migration_state.get_migration_dict_for_dag(dag.dag_id)} ) - else: - dag.tags.append(json.dumps({"DAGSTER_MIGRATION_STATUS": migration_state_for_dag})) + ) + migration_state_for_dag = migration_state.dags[dag.dag_id] + migrated_tasks = set() + for task_id, task_state in migration_state_for_dag.tasks.items(): + if not task_state.migrated: + logger.debug( + f"Task {task_id} in dag {dag.dag_id} has `migrated` set to False. Skipping..." + ) + continue + + # At this point, we should be assured that the task exists within the task_dict of the dag, and is a BaseOperator. + original_op: BaseOperator = dag.task_dict[task_id] # type: ignore # we already confirmed this is BaseOperator + del dag.task_dict[task_id] + if original_op.task_group is not None: + del original_op.task_group.children[task_id] logger.debug( - f"Dag with id `{dag.dag_id}` has been marked with migration state. Adding state to tags for dag." + f"Creating new operator for task {original_op.task_id} in dag {original_op.dag_id}" ) - num_dags += 1 - logger.info(f"Marked {num_dags} dags as migrating to dagster{suffix}.") + new_op = build_dagster_task(original_op) + original_op.dag.task_dict[original_op.task_id] = new_op + + new_op.upstream_task_ids = original_op.upstream_task_ids + new_op.downstream_task_ids = original_op.downstream_task_ids + new_op.dag = original_op.dag + original_op.dag = None + migrated_tasks.add(task_id) + logger.debug(f"Migrated tasks {migrated_tasks} in dag {dag.dag_id}.") + logging.debug(f"Migrated {len(migrating_dags)}.") + logging.debug(f"Completed marking dags and tasks as migrating to dagster{suffix}.") diff --git a/examples/experimental/dagster-airlift/dagster_airlift/migration_state.py b/examples/experimental/dagster-airlift/dagster_airlift/migration_state.py index 028074457f794..5d17b99c66547 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/migration_state.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/migration_state.py @@ -15,9 +15,16 @@ class DagMigrationState(NamedTuple): class AirflowMigrationState(NamedTuple): dags: Dict[str, DagMigrationState] - def get_migration_state_for_task(self, dag_id: str, task_id: str) -> bool: + def get_migration_state_for_task(self, dag_id: str, task_id: str) -> Optional[bool]: + if dag_id not in self.dags: + return None + if task_id not in self.dags[dag_id].tasks: + return None return self.dags[dag_id].tasks[task_id].migrated + def dag_has_migration_state(self, dag_id: str) -> bool: + return self.get_migration_dict_for_dag(dag_id) is not None + def get_migration_dict_for_dag(self, dag_id: str) -> Optional[Dict[str, Dict[str, Any]]]: if dag_id not in self.dags: return None diff --git a/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo/dags/switcheroo_dag.py b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo/dags/switcheroo_dag.py new file mode 100644 index 0000000000000..2954f517b1e3e --- /dev/null +++ b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo/dags/switcheroo_dag.py @@ -0,0 +1,68 @@ +import logging +import os +from datetime import datetime + +from airflow import DAG +from airflow.operators.python import PythonOperator +from dagster_airlift.in_airflow import mark_as_dagster_migrating +from dagster_airlift.migration_state import ( + AirflowMigrationState, + DagMigrationState, + TaskMigrationState, +) + +logging.basicConfig() +logging.getLogger().setLevel(logging.INFO) +requests_log = logging.getLogger("requests.packages.urllib3") +requests_log.setLevel(logging.INFO) +requests_log.propagate = True + + +def write_to_file_in_airflow_home() -> None: + airflow_home = os.environ["AIRFLOW_HOME"] + with open(os.path.join(airflow_home, "airflow_home_file.txt"), "w") as f: + f.write("Hello") + + +def write_to_other_file_in_airflow_home() -> None: + airflow_home = os.environ["AIRFLOW_HOME"] + with open(os.path.join(airflow_home, "other_airflow_home_file.txt"), "w") as f: + f.write("Hello") + + +default_args = { + "owner": "airflow", + "depends_on_past": False, + "start_date": datetime(2023, 1, 1), + "retries": 1, +} + +dag = DAG( + "the_dag", default_args=default_args, schedule_interval=None, is_paused_upon_creation=False +) +op_to_migrate = PythonOperator( + task_id="some_task", python_callable=write_to_file_in_airflow_home, dag=dag +) +op_doesnt_migrate = PythonOperator( + task_id="other_task", python_callable=write_to_other_file_in_airflow_home, dag=dag +) +# Add a dependency between the two tasks +op_doesnt_migrate.set_upstream(op_to_migrate) + +# # set up the debugger +# print("Waiting for debugger to attach...") +# debugpy.listen(("localhost", 7778)) +# debugpy.wait_for_client() +mark_as_dagster_migrating( + global_vars=globals(), + migration_state=AirflowMigrationState( + dags={ + "the_dag": DagMigrationState( + tasks={ + "some_task": TaskMigrationState(migrated=True), + "other_task": TaskMigrationState(migrated=True), + } + ) + } + ), +) diff --git a/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo/dagster_defs.py b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo/dagster_defs.py new file mode 100644 index 0000000000000..a2def0e909262 --- /dev/null +++ b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/airflow_op_switcheroo/dagster_defs.py @@ -0,0 +1,9 @@ +from dagster import Definitions, asset + + +@asset +def the_dag__some_task(): + return "asset_value" + + +defs = Definitions(assets=[the_dag__some_task]) diff --git a/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/operator_test_project/dags/migrated_dag.py b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/operator_test_project/dags/migrated_dag.py index 7b6faf7eef018..03a08757906c1 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/operator_test_project/dags/migrated_dag.py +++ b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/operator_test_project/dags/migrated_dag.py @@ -2,7 +2,7 @@ from datetime import datetime from airflow import DAG -from dagster_airlift.in_airflow.dagster_operator import build_dagster_task +from dagster_airlift.in_airflow.dagster_operator import DagsterOperator logging.basicConfig() logging.getLogger().setLevel(logging.INFO) @@ -11,19 +11,18 @@ requests_log.propagate = True -def print_hello(): - print("Hello") # noqa: T201 - - default_args = { "owner": "airflow", "depends_on_past": False, - "start_date": datetime(2023, 1, 1), "retries": 1, } dag = DAG( - "the_dag", default_args=default_args, schedule_interval=None, is_paused_upon_creation=False + "the_dag", + default_args=default_args, + schedule_interval=None, + is_paused_upon_creation=False, + start_date=datetime(2023, 1, 1), ) -migrated_op = build_dagster_task(task_id="some_task", dag=dag) -other_migrated_op = build_dagster_task(task_id="other_task", dag=dag) +print_task = DagsterOperator(task_id="some_task", dag=dag) +other_task = DagsterOperator(task_id="other_task", dag=dag) diff --git a/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/test_operator_switcheroo.py b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/test_operator_switcheroo.py new file mode 100644 index 0000000000000..3643ff59a13c0 --- /dev/null +++ b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/test_operator_switcheroo.py @@ -0,0 +1,57 @@ +import time +from pathlib import Path + +import pytest +import requests +from dagster import AssetKey, DagsterInstance, DagsterRunStatus +from dagster._core.test_utils import environ +from dagster._time import get_current_timestamp + + +@pytest.fixture(name="dags_dir") +def setup_dags_dir() -> Path: + return Path(__file__).parent / "airflow_op_switcheroo" / "dags" + + +@pytest.fixture(name="dagster_defs_path") +def setup_dagster_defs_path() -> str: + return str(Path(__file__).parent / "airflow_op_switcheroo" / "dagster_defs.py") + + +def test_migrated_operator( + airflow_instance: None, dagster_dev: None, dagster_home: str, airflow_home: str +) -> None: + """Tests that dagster migrated operator can correctly map airflow tasks to dagster tasks, and kick off executions.""" + response = requests.post( + "http://localhost:8080/api/v1/dags/the_dag/dagRuns", auth=("admin", "admin"), json={} + ) + assert response.status_code == 200, response.json() + # Wait until the run enters a terminal state + terminal_status = None + start_time = get_current_timestamp() + while get_current_timestamp() - start_time < 30: + response = requests.get( + "http://localhost:8080/api/v1/dags/the_dag/dagRuns", auth=("admin", "admin") + ) + assert response.status_code == 200, response.json() + dag_runs = response.json()["dag_runs"] + if dag_runs[0]["state"] in ["success", "failed"]: + terminal_status = dag_runs[0]["state"] + break + time.sleep(1) + assert terminal_status == "success", ( + "Never reached terminal status" + if terminal_status is None + else f"terminal status was {terminal_status}" + ) + with environ({"DAGSTER_HOME": dagster_home}): + instance = DagsterInstance.get() + runs = instance.get_runs() + # The graphql endpoint kicks off a run for each of the tasks in the dag + assert len(runs) == 1 + some_task_run = [ # noqa + run + for run in runs + if set(list(run.asset_selection)) == {AssetKey(["the_dag__some_task"])} # type: ignore + ][0] + assert some_task_run.status == DagsterRunStatus.SUCCESS