[dagster-airlift][rfc] genericize dagster operator (#24187)
Genericizes session auth for dagster operator.

Conflicted on whether we should fully genericize the `launch_runs_for_task` method for now. Personally, I think it's fine to ship a default implementation and see how needs evolve, since it can be overridden.
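For example, a deployment that fronts Dagster with bearer-token auth could override the session hook. A minimal sketch, assuming a hypothetical DAGSTER_API_TOKEN environment variable and Authorization header (neither is part of this change):

import os

import requests
from airflow.utils.context import Context

from dagster_airlift.in_airflow.dagster_operator import BaseProxyToDagsterOperator


class TokenAuthProxyToDagsterOperator(BaseProxyToDagsterOperator):
    def get_dagster_session(self, context: Context) -> requests.Session:
        # launch_runs_for_task reuses this session for every GraphQL call,
        # so attaching the header here authenticates all of them.
        session = requests.Session()
        session.headers["Authorization"] = f"Bearer {os.environ['DAGSTER_API_TOKEN']}"
        return session

    def get_dagster_url(self, context: Context) -> str:
        return os.environ["DAGSTER_URL"]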
dpeng817 authored Sep 3, 2024
1 parent 8721e8c commit 3b7dad1
Showing 4 changed files with 144 additions and 96 deletions.
@@ -1,110 +1,147 @@
import inspect
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Set, Tuple, Type

import requests
from airflow.models.operator import BaseOperator
from airflow.utils.context import Context

from dagster_airlift.core.utils import DAG_ID_TAG, TASK_ID_TAG

from .gql_queries import ASSET_NODES_QUERY, RUNS_QUERY, TRIGGER_ASSETS_MUTATION, VERIFICATION_QUERY

logger = logging.getLogger(__name__)


class BaseProxyToDagsterOperator(BaseOperator, ABC):
    """Interface for a DagsterOperator.

    This interface is used to create a custom operator that replaces the original Airflow
    operator when a task is marked as migrated.
    """

    @abstractmethod
    def get_dagster_session(self, context: Context) -> requests.Session:
        """Returns a requests session that can be used to make requests to the Dagster API."""

    def _get_validated_session(self, context: Context) -> requests.Session:
        # Verify that the Dagster webserver is reachable before doing any real work.
        session = self.get_dagster_session(context)
        dagster_url = self.get_dagster_url(context)
        response = session.post(
            f"{dagster_url}/graphql", json={"query": VERIFICATION_QUERY}, timeout=3
        )
        if response.status_code != 200:
            raise Exception(
                f"Failed to connect to Dagster at {dagster_url}. Response: {response.text}"
            )
        return session

    @abstractmethod
    def get_dagster_url(self, context: Context) -> str:
        """Returns the URL for the Dagster instance."""

    def launch_runs_for_task(self, context: Context, dag_id: str, task_id: str) -> None:
        """Launches runs for the given task in Dagster."""
        expected_op_name = f"{dag_id}__{task_id}"
        session = self._get_validated_session(context)

        dagster_url = self.get_dagster_url(context)
        assets_to_trigger = {}  # key is (repo_location, repo_name, job_name), value is list of asset keys
        # Query asset nodes to find the assets that correspond to this dag task.
        response = session.post(
            f"{dagster_url}/graphql", json={"query": ASSET_NODES_QUERY}, timeout=3
        )
        for asset_node in response.json()["data"]["assetNodes"]:
            tags = {tag["key"]: tag["value"] for tag in asset_node["tags"]}
            # Match assets based on conventional dag_id__task_id naming or based on explicit tags.
            if asset_node["opName"] == expected_op_name or (
                tags.get(DAG_ID_TAG) == dag_id and tags.get(TASK_ID_TAG) == task_id
            ):
                repo_location = asset_node["jobs"][0]["repository"]["location"]["name"]
                repo_name = asset_node["jobs"][0]["repository"]["name"]
                job_name = asset_node["jobs"][0]["name"]
                if (repo_location, repo_name, job_name) not in assets_to_trigger:
                    assets_to_trigger[(repo_location, repo_name, job_name)] = []
                assets_to_trigger[(repo_location, repo_name, job_name)].append(
                    asset_node["assetKey"]["path"]
                )
        logger.debug(f"Found assets to trigger: {assets_to_trigger}")

        triggered_runs = []
        for (repo_location, repo_name, job_name), asset_keys in assets_to_trigger.items():
            execution_params = {
                "mode": "default",
                "executionMetadata": {"tags": []},
                "runConfigData": "{}",
                "selector": {
                    "repositoryLocationName": repo_location,
                    "repositoryName": repo_name,
                    "pipelineName": job_name,
                    "assetSelection": [{"path": asset_key} for asset_key in asset_keys],
                    "assetCheckSelection": [],
                },
            }
            logger.debug(
                f"Triggering run for {repo_location}/{repo_name}/{job_name} with assets {asset_keys}"
            )
            response = session.post(
                f"{dagster_url}/graphql",
                json={
                    "query": TRIGGER_ASSETS_MUTATION,
                    "variables": {"executionParams": execution_params},
                },
                timeout=3,
            )
            run_id = response.json()["data"]["launchPipelineExecution"]["run"]["id"]
            logger.debug(f"Launched run {run_id}...")
            triggered_runs.append(run_id)

        # Poll until every triggered run reaches a terminal state.
        completed_runs = {}  # key is run_id, value is status
        while len(completed_runs) < len(triggered_runs):
            for run_id in triggered_runs:
                if run_id in completed_runs:
                    continue
                response = session.post(
                    f"{dagster_url}/graphql",
                    json={"query": RUNS_QUERY, "variables": {"runId": run_id}},
                    timeout=3,
                )
                run_status = response.json()["data"]["runOrError"]["status"]
                if run_status in ["SUCCESS", "FAILURE", "CANCELED"]:
                    logger.debug(f"Run {run_id} completed with status {run_status}")
                    completed_runs[run_id] = run_status
        non_successful_runs = [
            run_id for run_id, status in completed_runs.items() if status != "SUCCESS"
        ]
        if non_successful_runs:
            raise Exception(f"Runs {non_successful_runs} did not complete successfully.")
        logger.debug("All runs completed successfully.")
        return None

    def execute(self, context: Context) -> Any:
        # https://github.com/apache/airflow/discussions/24463
        os.environ["NO_PROXY"] = "*"
        dag_id = os.environ["AIRFLOW_CTX_DAG_ID"]
        task_id = os.environ["AIRFLOW_CTX_TASK_ID"]
        return self.launch_runs_for_task(context, dag_id, task_id)


class DefaultProxyToDagsterOperator(BaseProxyToDagsterOperator):
    def get_dagster_session(self, context: Context) -> requests.Session:
        return requests.Session()

    def get_dagster_url(self, context: Context) -> str:
        return os.environ["DAGSTER_URL"]


def build_dagster_task(
    original_task: BaseOperator, dagster_operator_klass: Type[BaseProxyToDagsterOperator]
) -> BaseProxyToDagsterOperator:
    return instantiate_dagster_operator(original_task, dagster_operator_klass)


def instantiate_dagster_operator(
    original_task: BaseOperator, dagster_operator_klass: Type[BaseProxyToDagsterOperator]
) -> BaseProxyToDagsterOperator:
    """Instantiates a DagsterOperator as a copy of the provided airflow task.
    We attempt to copy as many of the original task's attributes as possible, while respecting
@@ -133,7 +170,7 @@ def instantiate_dagster_operator(original_task: BaseOperator) -> DagsterOperator
            continue
        init_kwargs[kwarg] = getattr(original_task, kwarg, default)

    return dagster_operator_klass(**init_kwargs)


def get_params(func: Callable[..., Any]) -> Tuple[Set[str], Dict[str, Any]]:
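The body of get_params is truncated in this view. Judging from the signature and its use when copying constructor kwargs above, an inspect-based sketch of the pattern could look like the following (an assumption, not the actual implementation):

import inspect
from typing import Any, Callable, Dict, Set, Tuple


def get_params(func: Callable[..., Any]) -> Tuple[Set[str], Dict[str, Any]]:
    # Return every parameter name, plus defaults for the parameters that declare one.
    params = inspect.signature(func).parameters
    names = set(params)
    defaults = {
        name: param.default
        for name, param in params.items()
        if param.default is not inspect.Parameter.empty
    }
    return names, defaults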
@@ -1,3 +1,9 @@
VERIFICATION_QUERY = """
query VerificationQuery {
  version
}
"""

ASSET_NODES_QUERY = """
query AssetNodeQuery {
  assetNodes {
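Since VERIFICATION_QUERY only asks the webserver for its version, the connection check can be smoke-tested by hand. A quick sketch against a local instance (the URL and version value are illustrative):

import requests

resp = requests.post(
    "http://localhost:3000/graphql",
    json={"query": "query VerificationQuery { version }"},
    timeout=3,
)
print(resp.status_code, resp.json())  # e.g. 200 {'data': {'version': '1.8.2'}}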
@@ -1,11 +1,15 @@
import json
import logging
from typing import Any, Dict, List, Optional, Type

from airflow import DAG
from airflow.models import BaseOperator

from dagster_airlift.in_airflow.dagster_operator import (
    BaseProxyToDagsterOperator,
    DefaultProxyToDagsterOperator,
    build_dagster_task,
)

from ..migration_state import AirflowMigrationState

@@ -15,6 +19,7 @@ def mark_as_dagster_migrating(
    global_vars: Dict[str, Any],
    migration_state: AirflowMigrationState,
    logger: Optional[logging.Logger] = None,
    dagster_operator_klass: Type[BaseProxyToDagsterOperator] = DefaultProxyToDagsterOperator,
) -> None:
    """Alters all airflow dags in the current context to be marked as migrating to dagster.
    Uses a migration dictionary to determine the status of the migration for each task within each dag.
@@ -78,7 +83,7 @@ def mark_as_dagster_migrating(
            logger.debug(
                f"Creating new operator for task {original_op.task_id} in dag {original_op.dag_id}"
            )
            new_op = build_dagster_task(original_op, dagster_operator_klass)
            original_op.dag.task_dict[original_op.task_id] = new_op

            new_op.upstream_task_ids = original_op.upstream_task_ids
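Wired together, a dag file opting into a custom operator class would pass it through the new keyword argument. A sketch, reusing the hypothetical TokenAuthProxyToDagsterOperator from above (the import path and migration-state loading are assumptions):

from dagster_airlift.in_airflow import mark_as_dagster_migrating  # import path assumed

migration_state = load_migration_state()  # placeholder for however the AirflowMigrationState is built

mark_as_dagster_migrating(
    global_vars=globals(),
    migration_state=migration_state,
    dagster_operator_klass=TokenAuthProxyToDagsterOperator,
)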
@@ -2,7 +2,7 @@
from datetime import datetime

from airflow import DAG
from dagster_airlift.in_airflow.dagster_operator import DefaultProxyToDagsterOperator

logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
@@ -24,5 +24,5 @@
    is_paused_upon_creation=False,
    start_date=datetime(2023, 1, 1),
)
print_task = DefaultProxyToDagsterOperator(task_id="some_task", dag=dag)
other_task = DefaultProxyToDagsterOperator(task_id="other_task", dag=dag)
