From e12b6b53252209413196192fa3fc450e0f7917f1 Mon Sep 17 00:00:00 2001 From: Chris DeCarolis Date: Fri, 2 Aug 2024 12:48:29 -0700 Subject: [PATCH] [dagster-airlift] dagster operator --- .../dagster_airlift/__init__.py | 5 +- .../dagster_airlift/test/shared_fixtures.py | 2 +- .../dagster_airlift/within_airflow.py | 208 ++++++++++++++++++ .../af_migrated_operator/dags/migrated_dag.py | 29 +++ .../af_migrated_operator/dagster_defs.py | 19 ++ .../test_migrated_operator.py | 122 ++++++++++ 6 files changed, 383 insertions(+), 2 deletions(-) create mode 100644 examples/experimental/dagster-airlift/dagster_airlift_tests/af_migrated_operator/dags/migrated_dag.py create mode 100644 examples/experimental/dagster-airlift/dagster_airlift_tests/af_migrated_operator/dagster_defs.py create mode 100644 examples/experimental/dagster-airlift/dagster_airlift_tests/test_migrated_operator.py diff --git a/examples/experimental/dagster-airlift/dagster_airlift/__init__.py b/examples/experimental/dagster-airlift/dagster_airlift/__init__.py index f475349e868b9..ee1f2569d4715 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/__init__.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/__init__.py @@ -8,4 +8,7 @@ PythonDefs as PythonDefs, load_defs_from_yaml as load_defs_from_yaml, ) -from .within_airflow import mark_as_dagster_migrating as mark_as_dagster_migrating +from .within_airflow import ( + build_dagster_migrated_operator as build_dagster_migrated_operator, + mark_as_dagster_migrating as mark_as_dagster_migrating, +) diff --git a/examples/experimental/dagster-airlift/dagster_airlift/test/shared_fixtures.py b/examples/experimental/dagster-airlift/dagster_airlift/test/shared_fixtures.py index 476eb0ca22b3c..dd297226f75f3 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/test/shared_fixtures.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/test/shared_fixtures.py @@ -31,7 +31,7 @@ def airflow_instance_fixture(setup: None) -> Generator[Any, None, None]: initial_time = get_current_timestamp() airflow_ready = False - while get_current_timestamp() - initial_time < 30: + while get_current_timestamp() - initial_time < 60: if airflow_is_ready(): airflow_ready = True break diff --git a/examples/experimental/dagster-airlift/dagster_airlift/within_airflow.py b/examples/experimental/dagster-airlift/dagster_airlift/within_airflow.py index 5daaa2f9d287b..a07f70b6d7f5d 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/within_airflow.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/within_airflow.py @@ -1,7 +1,10 @@ import json +import os import sys +import requests from airflow import DAG +from airflow.operators.python import PythonOperator def mark_as_dagster_migrating( @@ -59,3 +62,208 @@ def mark_as_dagster_migrating( ) globals_to_update[var] = new_dag global_vars.update(globals_to_update) + + +ASSET_NODES_QUERY = """ +query AssetNodeQuery { + assetNodes { + id + assetKey { + path + } + opName + jobs { + id + name + repository { + id + name + location { + id + name + } + } + } + } +} +""" + +TRIGGER_ASSETS_MUTATION = """ +mutation LaunchAssetsExecution($executionParams: ExecutionParams!) { + launchPipelineExecution(executionParams: $executionParams) { + ... on LaunchRunSuccess { + run { + id + pipelineName + __typename + } + __typename + } + ... on PipelineNotFoundError { + message + __typename + } + ... on InvalidSubsetError { + message + __typename + } + ... on RunConfigValidationInvalid { + errors { + message + __typename + } + __typename + } + ...PythonErrorFragment + __typename + } +} + +fragment PythonErrorFragment on PythonError { + message + stack + errorChain { + ...PythonErrorChain + __typename + } + __typename +} + +fragment PythonErrorChain on ErrorChainLink { + isExplicitLink + error { + message + stack + __typename + } + __typename +} +""" + +# request format +# { +# "executionParams": { +# "mode": "default", +# "executionMetadata": { +# "tags": [] +# }, +# "runConfigData": "{}", +# "selector": { +# "repositoryLocationName": "toys", +# "repositoryName": "__repository__", +# "pipelineName": "__ASSET_JOB_0", +# "assetSelection": [ +# { +# "path": [ +# "bigquery", +# "raw_customers" +# ] +# } +# ], +# "assetCheckSelection": [] +# } +# } +# } + +RUNS_QUERY = """ +query RunQuery($runId: ID!) { + runOrError(runId: $runId) { + __typename + ...PythonErrorFragment + ...NotFoundFragment + ... on Run { + id + status + __typename + } + } +} +fragment NotFoundFragment on RunNotFoundError { + __typename + message +} +fragment PythonErrorFragment on PythonError { + __typename + message + stack + causes { + message + stack + __typename + } +} +""" + + +def compute_fn() -> None: + # https://github.com/apache/airflow/discussions/24463 + os.environ["NO_PROXY"] = "*" + dag_id = os.environ["AIRFLOW_CTX_DAG_ID"] + task_id = os.environ["AIRFLOW_CTX_TASK_ID"] + expected_op_name = f"{dag_id}__{task_id}" + assets_to_trigger = {} # key is (repo_location, repo_name, job_name), value is list of asset keys + # create graphql client + dagster_url = os.environ["DAGSTER_URL"] + response = requests.post(f"{dagster_url}/graphql", json={"query": ASSET_NODES_QUERY}, timeout=3) + for asset_node in response.json()["data"]["assetNodes"]: + if asset_node["opName"] == expected_op_name: + repo_location = asset_node["jobs"][0]["repository"]["location"]["name"] + repo_name = asset_node["jobs"][0]["repository"]["name"] + job_name = asset_node["jobs"][0]["name"] + if (repo_location, repo_name, job_name) not in assets_to_trigger: + assets_to_trigger[(repo_location, repo_name, job_name)] = [] + assets_to_trigger[(repo_location, repo_name, job_name)].append( + asset_node["assetKey"]["path"] + ) + print(f"Found assets to trigger: {assets_to_trigger}") # noqa: T201 + triggered_runs = [] + for (repo_location, repo_name, job_name), asset_keys in assets_to_trigger.items(): + execution_params = { + "mode": "default", + "executionMetadata": {"tags": []}, + "runConfigData": "{}", + "selector": { + "repositoryLocationName": repo_location, + "repositoryName": repo_name, + "pipelineName": job_name, + "assetSelection": [{"path": asset_key} for asset_key in asset_keys], + "assetCheckSelection": [], + }, + } + print(f"Triggering run for {repo_location}/{repo_name}/{job_name} with assets {asset_keys}") # noqa: T201 + response = requests.post( + f"{dagster_url}/graphql", + json={ + "query": TRIGGER_ASSETS_MUTATION, + "variables": {"executionParams": execution_params}, + }, + timeout=3, + ) + run_id = response.json()["data"]["launchPipelineExecution"]["run"]["id"] + print(f"Launched run {run_id}...") # noqa: T201 + triggered_runs.append(run_id) + completed_runs = {} # key is run_id, value is status + while len(completed_runs) < len(triggered_runs): + for run_id in triggered_runs: + if run_id in completed_runs: + continue + response = requests.post( + f"{dagster_url}/graphql", + json={"query": RUNS_QUERY, "variables": {"runId": run_id}}, + timeout=3, + ) + run_status = response.json()["data"]["runOrError"]["status"] + if run_status in ["SUCCESS", "FAILURE", "CANCELED"]: + print(f"Run {run_id} completed with status {run_status}") # noqa: T201 + completed_runs[run_id] = run_status + non_successful_runs = [ + run_id for run_id, status in completed_runs.items() if status != "SUCCESS" + ] + if non_successful_runs: + raise Exception(f"Runs {non_successful_runs} did not complete successfully.") + print("All runs completed successfully.") # noqa: T201 + return None + + +def build_dagster_migrated_operator(task_id: str, dag: DAG, **kwargs): + return PythonOperator(task_id=task_id, dag=dag, python_callable=compute_fn, **kwargs) diff --git a/examples/experimental/dagster-airlift/dagster_airlift_tests/af_migrated_operator/dags/migrated_dag.py b/examples/experimental/dagster-airlift/dagster_airlift_tests/af_migrated_operator/dags/migrated_dag.py new file mode 100644 index 0000000000000..1ca397c999ad9 --- /dev/null +++ b/examples/experimental/dagster-airlift/dagster_airlift_tests/af_migrated_operator/dags/migrated_dag.py @@ -0,0 +1,29 @@ +import logging +from datetime import datetime + +from airflow import DAG +from dagster_airlift import build_dagster_migrated_operator + +logging.basicConfig() +logging.getLogger().setLevel(logging.INFO) +requests_log = logging.getLogger("requests.packages.urllib3") +requests_log.setLevel(logging.INFO) +requests_log.propagate = True + + +def print_hello(): + print("Hello") # noqa: T201 + + +default_args = { + "owner": "airflow", + "depends_on_past": False, + "start_date": datetime(2023, 1, 1), + "retries": 1, +} + +dag = DAG( + "the_dag", default_args=default_args, schedule_interval=None, is_paused_upon_creation=False +) +migrated_op = build_dagster_migrated_operator(task_id="some_task", dag=dag) +other_migrated_op = build_dagster_migrated_operator(task_id="other_task", dag=dag) diff --git a/examples/experimental/dagster-airlift/dagster_airlift_tests/af_migrated_operator/dagster_defs.py b/examples/experimental/dagster-airlift/dagster_airlift_tests/af_migrated_operator/dagster_defs.py new file mode 100644 index 0000000000000..91b9287203081 --- /dev/null +++ b/examples/experimental/dagster-airlift/dagster_airlift_tests/af_migrated_operator/dagster_defs.py @@ -0,0 +1,19 @@ +from dagster import Definitions, asset + + +@asset +def the_dag__some_task(): + return "asset_value" + + +@asset +def unrelated(): + return "unrelated_value" + + +@asset +def the_dag__other_task(): + return "other_task_value" + + +defs = Definitions(assets=[the_dag__other_task, the_dag__some_task, unrelated]) diff --git a/examples/experimental/dagster-airlift/dagster_airlift_tests/test_migrated_operator.py b/examples/experimental/dagster-airlift/dagster_airlift_tests/test_migrated_operator.py new file mode 100644 index 0000000000000..9aebe559c5306 --- /dev/null +++ b/examples/experimental/dagster-airlift/dagster_airlift_tests/test_migrated_operator.py @@ -0,0 +1,122 @@ +import os +import signal +import subprocess +import time +from tempfile import TemporaryDirectory +from typing import Any, Generator + +import pytest +import requests +from dagster import AssetKey, DagsterInstance, DagsterRunStatus +from dagster._core.test_utils import environ +from dagster._time import get_current_timestamp + + +@pytest.fixture(name="airflow_home") +def setup_airflow_home() -> Generator[str, None, None]: + with TemporaryDirectory() as tmpdir: + yield tmpdir + + +@pytest.fixture(name="setup") +def setup_fixture(airflow_home: str) -> Generator[str, None, None]: + # run chmod +x create_airflow_cfg.sh and then run create_airflow_cfg.sh tmpdir + temp_env = { + **os.environ.copy(), + "AIRFLOW_HOME": airflow_home, + "DAGSTER_URL": "http://localhost:3333", + } + # go up one directory from current + path_to_script = os.path.join(os.path.dirname(__file__), "..", "airflow_setup.sh") + path_to_dags = os.path.join(os.path.dirname(__file__), "af_migrated_operator", "dags") + subprocess.run(["chmod", "+x", path_to_script], check=True, env=temp_env) + subprocess.run([path_to_script, path_to_dags], check=True, env=temp_env) + with environ({"AIRFLOW_HOME": airflow_home, "DAGSTER_URL": "http://localhost:3333"}): + yield airflow_home + + +def dagster_is_ready() -> bool: + try: + response = requests.get("http://localhost:3333") + return response.status_code == 200 + except: + return False + + +@pytest.fixture(name="dagster_home") +def setup_dagster_home() -> Generator[str, None, None]: + with TemporaryDirectory() as tmpdir: + yield tmpdir + + +@pytest.fixture(name="dagster_dev") +def setup_dagster(dagster_home: str) -> Generator[Any, None, None]: + temp_env = {**os.environ.copy(), "DAGSTER_HOME": dagster_home} + path_to_defs = os.path.join( + os.path.dirname(__file__), "af_migrated_operator", "dagster_defs.py" + ) + process = subprocess.Popen( + ["dagster", "dev", "-f", path_to_defs, "-p", "3333"], + env=temp_env, + shell=False, + preexec_fn=os.setsid, # noqa + ) + # Give dagster a second to stand up + time.sleep(5) + + dagster_ready = False + initial_time = get_current_timestamp() + while get_current_timestamp() - initial_time < 60: + if dagster_is_ready(): + dagster_ready = True + break + time.sleep(1) + + assert dagster_ready, "Dagster did not start within 30 seconds..." + yield process + os.killpg(process.pid, signal.SIGKILL) + + +def test_migrated_operator( + airflow_instance: None, dagster_dev: None, dagster_home: str, airflow_home: str +) -> None: + """Tests that dagster migrated operator can correctly map airflow tasks to dagster tasks, and kick off executions.""" + response = requests.post( + "http://localhost:8080/api/v1/dags/the_dag/dagRuns", auth=("admin", "admin"), json={} + ) + assert response.status_code == 200, response.json() + # Wait until the run enters a terminal state + terminal_status = None + start_time = get_current_timestamp() + while get_current_timestamp() - start_time < 30: + response = requests.get( + "http://localhost:8080/api/v1/dags/the_dag/dagRuns", auth=("admin", "admin") + ) + assert response.status_code == 200, response.json() + dag_runs = response.json()["dag_runs"] + if dag_runs[0]["state"] in ["success", "failed"]: + terminal_status = dag_runs[0]["state"] + break + time.sleep(1) + assert terminal_status == "success", ( + "Never reached terminal status" + if terminal_status is None + else f"terminal status was {terminal_status}" + ) + with environ({"DAGSTER_HOME": dagster_home}): + instance = DagsterInstance.get() + runs = instance.get_runs() + # The graphql endpoint kicks off a run for each of the tasks in the dag + assert len(runs) == 2 + some_task_run = [ # noqa + run + for run in runs + if set(list(run.asset_selection)) == {AssetKey(["the_dag__other_task"])} # type: ignore + ][0] + other_task_run = [ # noqa + run + for run in runs + if set(list(run.asset_selection)) == {AssetKey(["the_dag__some_task"])} # type: ignore + ][0] + assert some_task_run.status == DagsterRunStatus.SUCCESS + assert other_task_run.status == DagsterRunStatus.SUCCESS