diff --git a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/dagster_operator.py b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/dagster_operator.py index e6a75d3ab64e3..658f4d5c241f1 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/dagster_operator.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/dagster_operator.py @@ -1,110 +1,147 @@ import inspect import logging import os -from typing import Any, Callable, Dict, Set, Tuple +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Set, Tuple, Type import requests from airflow.models.operator import BaseOperator -from airflow.operators.python import PythonOperator +from airflow.utils.context import Context from dagster_airlift.core.utils import DAG_ID_TAG, TASK_ID_TAG -from .gql_queries import ASSET_NODES_QUERY, RUNS_QUERY, TRIGGER_ASSETS_MUTATION +from .gql_queries import ASSET_NODES_QUERY, RUNS_QUERY, TRIGGER_ASSETS_MUTATION, VERIFICATION_QUERY logger = logging.getLogger(__name__) -def compute_fn() -> None: - # https://github.com/apache/airflow/discussions/24463 - os.environ["NO_PROXY"] = "*" - dag_id = os.environ["AIRFLOW_CTX_DAG_ID"] - task_id = os.environ["AIRFLOW_CTX_TASK_ID"] - dagster_url = os.environ["DAGSTER_URL"] - return launch_runs_for_task(dag_id, task_id, dagster_url) - - -def launch_runs_for_task(dag_id: str, task_id: str, dagster_url: str) -> None: - expected_op_name = f"{dag_id}__{task_id}" - - assets_to_trigger = {} # key is (repo_location, repo_name, job_name), value is list of asset keys - # create graphql client - response = requests.post(f"{dagster_url}/graphql", json={"query": ASSET_NODES_QUERY}, timeout=3) - for asset_node in response.json()["data"]["assetNodes"]: - tags = {tag["key"]: tag["value"] for tag in asset_node["tags"]} - # match assets based on conventional dag_id__task_id naming or based on explicit tags - if asset_node["opName"] == expected_op_name or ( - tags.get(DAG_ID_TAG) == dag_id and tags.get(TASK_ID_TAG) == task_id - ): - repo_location = asset_node["jobs"][0]["repository"]["location"]["name"] - repo_name = asset_node["jobs"][0]["repository"]["name"] - job_name = asset_node["jobs"][0]["name"] - if (repo_location, repo_name, job_name) not in assets_to_trigger: - assets_to_trigger[(repo_location, repo_name, job_name)] = [] - assets_to_trigger[(repo_location, repo_name, job_name)].append( - asset_node["assetKey"]["path"] - ) - logger.debug(f"Found assets to trigger: {assets_to_trigger}") - triggered_runs = [] - for (repo_location, repo_name, job_name), asset_keys in assets_to_trigger.items(): - execution_params = { - "mode": "default", - "executionMetadata": {"tags": []}, - "runConfigData": "{}", - "selector": { - "repositoryLocationName": repo_location, - "repositoryName": repo_name, - "pipelineName": job_name, - "assetSelection": [{"path": asset_key} for asset_key in asset_keys], - "assetCheckSelection": [], - }, - } - logger.debug( - f"Triggering run for {repo_location}/{repo_name}/{job_name} with assets {asset_keys}" +class BaseProxyToDagsterOperator(BaseOperator, ABC): + """Interface for a DagsterOperator. + + This interface is used to create a custom operator that will be used to replace the original airflow operator when a task is marked as migrated. + """ + + @abstractmethod + def get_dagster_session(self, context: Context) -> requests.Session: + """Returns a requests session that can be used to make requests to the Dagster API.""" + + def _get_validated_session(self, context: Context) -> requests.Session: + session = self.get_dagster_session(context) + dagster_url = self.get_dagster_url(context) + response = session.post( + f"{dagster_url}/graphql", json={"query": VERIFICATION_QUERY}, timeout=3 ) - response = requests.post( - f"{dagster_url}/graphql", - json={ - "query": TRIGGER_ASSETS_MUTATION, - "variables": {"executionParams": execution_params}, - }, - timeout=3, + if response.status_code != 200: + raise Exception( + f"Failed to connect to Dagster at {dagster_url}. Response: {response.text}" + ) + return session + + @abstractmethod + def get_dagster_url(self, context: Context) -> str: + """Returns the URL for the Dagster instance.""" + + def launch_runs_for_task(self, context: Context, dag_id: str, task_id: str) -> None: + """Launches runs for the given task in Dagster.""" + expected_op_name = f"{dag_id}__{task_id}" + session = self._get_validated_session(context) + + dagster_url = self.get_dagster_url(context) + assets_to_trigger = {} # key is (repo_location, repo_name, job_name), value is list of asset keys + # create graphql client + response = session.post( + f"{dagster_url}/graphql", json={"query": ASSET_NODES_QUERY}, timeout=3 ) - run_id = response.json()["data"]["launchPipelineExecution"]["run"]["id"] - logger.debug(f"Launched run {run_id}...") - triggered_runs.append(run_id) - completed_runs = {} # key is run_id, value is status - while len(completed_runs) < len(triggered_runs): - for run_id in triggered_runs: - if run_id in completed_runs: - continue - response = requests.post( + for asset_node in response.json()["data"]["assetNodes"]: + tags = {tag["key"]: tag["value"] for tag in asset_node["tags"]} + # match assets based on conventional dag_id__task_id naming or based on explicit tags + if asset_node["opName"] == expected_op_name or ( + tags.get(DAG_ID_TAG) == dag_id and tags.get(TASK_ID_TAG) == task_id + ): + repo_location = asset_node["jobs"][0]["repository"]["location"]["name"] + repo_name = asset_node["jobs"][0]["repository"]["name"] + job_name = asset_node["jobs"][0]["name"] + if (repo_location, repo_name, job_name) not in assets_to_trigger: + assets_to_trigger[(repo_location, repo_name, job_name)] = [] + assets_to_trigger[(repo_location, repo_name, job_name)].append( + asset_node["assetKey"]["path"] + ) + logger.debug(f"Found assets to trigger: {assets_to_trigger}") + triggered_runs = [] + for (repo_location, repo_name, job_name), asset_keys in assets_to_trigger.items(): + execution_params = { + "mode": "default", + "executionMetadata": {"tags": []}, + "runConfigData": "{}", + "selector": { + "repositoryLocationName": repo_location, + "repositoryName": repo_name, + "pipelineName": job_name, + "assetSelection": [{"path": asset_key} for asset_key in asset_keys], + "assetCheckSelection": [], + }, + } + logger.debug( + f"Triggering run for {repo_location}/{repo_name}/{job_name} with assets {asset_keys}" + ) + response = session.post( f"{dagster_url}/graphql", - json={"query": RUNS_QUERY, "variables": {"runId": run_id}}, + json={ + "query": TRIGGER_ASSETS_MUTATION, + "variables": {"executionParams": execution_params}, + }, timeout=3, ) - run_status = response.json()["data"]["runOrError"]["status"] - if run_status in ["SUCCESS", "FAILURE", "CANCELED"]: - logger.debug(f"Run {run_id} completed with status {run_status}") - completed_runs[run_id] = run_status - non_successful_runs = [ - run_id for run_id, status in completed_runs.items() if status != "SUCCESS" - ] - if non_successful_runs: - raise Exception(f"Runs {non_successful_runs} did not complete successfully.") - logger.debug("All runs completed successfully.") - return None - - -class DagsterOperator(PythonOperator): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs, python_callable=compute_fn) - - -def build_dagster_task(original_task: BaseOperator) -> DagsterOperator: - return instantiate_dagster_operator(original_task) - - -def instantiate_dagster_operator(original_task: BaseOperator) -> DagsterOperator: + run_id = response.json()["data"]["launchPipelineExecution"]["run"]["id"] + logger.debug(f"Launched run {run_id}...") + triggered_runs.append(run_id) + completed_runs = {} # key is run_id, value is status + while len(completed_runs) < len(triggered_runs): + for run_id in triggered_runs: + if run_id in completed_runs: + continue + response = session.post( + f"{dagster_url}/graphql", + json={"query": RUNS_QUERY, "variables": {"runId": run_id}}, + timeout=3, + ) + run_status = response.json()["data"]["runOrError"]["status"] + if run_status in ["SUCCESS", "FAILURE", "CANCELED"]: + logger.debug(f"Run {run_id} completed with status {run_status}") + completed_runs[run_id] = run_status + non_successful_runs = [ + run_id for run_id, status in completed_runs.items() if status != "SUCCESS" + ] + if non_successful_runs: + raise Exception(f"Runs {non_successful_runs} did not complete successfully.") + logger.debug("All runs completed successfully.") + return None + + def execute(self, context: Context) -> Any: + # https://github.com/apache/airflow/discussions/24463 + os.environ["NO_PROXY"] = "*" + dag_id = os.environ["AIRFLOW_CTX_DAG_ID"] + task_id = os.environ["AIRFLOW_CTX_TASK_ID"] + return self.launch_runs_for_task(context, dag_id, task_id) + + +class DefaultProxyToDagsterOperator(BaseProxyToDagsterOperator): + def get_dagster_session(self, context: Context) -> requests.Session: + return requests.Session() + + def get_dagster_url(self, context: Context) -> str: + return os.environ["DAGSTER_URL"] + + +def build_dagster_task( + original_task: BaseOperator, dagster_operator_klass: Type[BaseProxyToDagsterOperator] +) -> BaseProxyToDagsterOperator: + return instantiate_dagster_operator(original_task, dagster_operator_klass) + + +def instantiate_dagster_operator( + original_task: BaseOperator, dagster_operator_klass: Type[BaseProxyToDagsterOperator] +) -> BaseProxyToDagsterOperator: """Instantiates a DagsterOperator as a copy of the provided airflow task. We attempt to copy as many of the original task's attributes as possible, while respecting @@ -133,7 +170,7 @@ def instantiate_dagster_operator(original_task: BaseOperator) -> DagsterOperator continue init_kwargs[kwarg] = getattr(original_task, kwarg, default) - return DagsterOperator(**init_kwargs) + return dagster_operator_klass(**init_kwargs) def get_params(func: Callable[..., Any]) -> Tuple[Set[str], Dict[str, Any]]: diff --git a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/gql_queries.py b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/gql_queries.py index f4de7bfab76a6..ea98638466928 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/gql_queries.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/gql_queries.py @@ -1,3 +1,9 @@ +VERIFICATION_QUERY = """ +query VerificationQuery { + version +} +""" + ASSET_NODES_QUERY = """ query AssetNodeQuery { assetNodes { diff --git a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/mark_as_migrating.py b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/mark_as_migrating.py index 17304b9fa1782..27e4b61185223 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/mark_as_migrating.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/in_airflow/mark_as_migrating.py @@ -1,11 +1,15 @@ import json import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type from airflow import DAG from airflow.models import BaseOperator -from dagster_airlift.in_airflow.dagster_operator import build_dagster_task +from dagster_airlift.in_airflow.dagster_operator import ( + BaseProxyToDagsterOperator, + DefaultProxyToDagsterOperator, + build_dagster_task, +) from ..migration_state import AirflowMigrationState @@ -15,6 +19,7 @@ def mark_as_dagster_migrating( global_vars: Dict[str, Any], migration_state: AirflowMigrationState, logger: Optional[logging.Logger] = None, + dagster_operator_klass: Type[BaseProxyToDagsterOperator] = DefaultProxyToDagsterOperator, ) -> None: """Alters all airflow dags in the current context to be marked as migrating to dagster. Uses a migration dictionary to determine the status of the migration for each task within each dag. @@ -78,7 +83,7 @@ def mark_as_dagster_migrating( logger.debug( f"Creating new operator for task {original_op.task_id} in dag {original_op.dag_id}" ) - new_op = build_dagster_task(original_op) + new_op = build_dagster_task(original_op, dagster_operator_klass) original_op.dag.task_dict[original_op.task_id] = new_op new_op.upstream_task_ids = original_op.upstream_task_ids diff --git a/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/operator_test_project/dags/migrated_dag.py b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/operator_test_project/dags/migrated_dag.py index 03a08757906c1..411e7dbe6044c 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/operator_test_project/dags/migrated_dag.py +++ b/examples/experimental/dagster-airlift/dagster_airlift_tests/integration_tests/operator_test_project/dags/migrated_dag.py @@ -2,7 +2,7 @@ from datetime import datetime from airflow import DAG -from dagster_airlift.in_airflow.dagster_operator import DagsterOperator +from dagster_airlift.in_airflow.dagster_operator import DefaultProxyToDagsterOperator logging.basicConfig() logging.getLogger().setLevel(logging.INFO) @@ -24,5 +24,5 @@ is_paused_upon_creation=False, start_date=datetime(2023, 1, 1), ) -print_task = DagsterOperator(task_id="some_task", dag=dag) -other_task = DagsterOperator(task_id="other_task", dag=dag) +print_task = DefaultProxyToDagsterOperator(task_id="some_task", dag=dag) +other_task = DefaultProxyToDagsterOperator(task_id="other_task", dag=dag)