[dagster-airlift] dagster operator (#23386)
Adds a Dagster operator that can remotely invoke Dagster from Airflow. It searches for a node def whose name matches the task (and matches our opinionated file format for blueprints).
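
Concretely, an Airflow task `some_task` in a dag `the_dag` maps to a node def named `the_dag__some_task`, as in the test fixtures further down:

```python
from dagster import asset


# The asset-backed op's name must be "<dag_id>__<task_id>" for the
# operator to find it.
@asset
def the_dag__some_task():
    return "asset_value"
```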

Includes a unit test that runs against live Airflow and Dagster, checking that runs of the respective assets are invoked for each task.

The tests are probably not entirely sufficient. I'd ideally like to run against a few other cases:
- a workspace with multiple defined code locations
- a multi-asset with multiple keys, ensuring they all get picked up within the same run (see the sketch after this list)
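
For the multi-asset case, I'm picturing a fixture along these lines (hypothetical; `AssetSpec` and `multi_asset` are standard Dagster APIs):

```python
from dagster import AssetSpec, multi_asset


# Hypothetical test fixture: one node def matching a single task, producing
# two asset keys. Both keys should land in the same triggered run.
@multi_asset(specs=[AssetSpec("first_key"), AssetSpec("second_key")])
def the_dag__multi_task():
    pass
```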

The API surface area here won't actually be exposed to the user. What
I'm imagining is something like this:
```python
... # dag code
create_migrating_dag(
    migrating_dict={...},
    dagster_instance=DagsterInstance(url=...) # or, in cloud case DagsterCloudInstance(url=..., auth_token=...)
)
```
Then, under the hood, we swap out operators via the stack, the same way we do with dag construction.
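
Roughly, the swap might look like this (a sketch only; `migrating_dict` maps task ids to a migrated flag, and the edge bookkeeping is hand-waved):

```python
from airflow import DAG

from dagster_airlift.in_airflow.dagster_operator import build_dagster_task


def swap_migrated_operators(dag: DAG, migrating_dict: dict) -> None:
    # Sketch: replace each migrated task's operator with the Dagster-invoking
    # PythonOperator, keeping the task id and graph edges intact.
    for task_id, migrated in migrating_dict.items():
        if not migrated:
            continue
        original = dag.task_dict.pop(task_id)
        replacement = build_dagster_task(task_id=task_id, dag=dag)
        replacement.upstream_task_ids.update(original.upstream_task_ids)
        replacement.downstream_task_ids.update(original.downstream_task_ids)
```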

Any real use case will need new GraphQL endpoints so that we can easily retrieve asset info per node. The number of steps here feels gratuitous (although only because the way we're retrieving information is a bit hacky).
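
For illustration, a hypothetical consolidated query (purely a sketch of an endpoint that does not exist yet) might collapse those steps into one:

```python
# Hypothetical -- no such field exists today; a sketch of a single-hop lookup
# that would replace the assetNodes -> jobs -> repository traversal below.
ASSET_INFO_FOR_OP_QUERY = """
query AssetInfoForOp($opName: String!) {
  assetInfoForOp(opName: $opName) {
    assetKeys { path }
    jobName
    repositoryName
    locationName
  }
}
"""
```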
dpeng817 authored and PedramNavid committed Aug 14, 2024
1 parent 6500acc commit c434a56
Showing 9 changed files with 364 additions and 12 deletions.
dagster_airlift/in_airflow/dagster_operator.py (new file, +81 lines):
```python
import os

import requests
from airflow import DAG
from airflow.operators.python import PythonOperator

from .gql_queries import ASSET_NODES_QUERY, RUNS_QUERY, TRIGGER_ASSETS_MUTATION


def compute_fn() -> None:
    # https://github.com/apache/airflow/discussions/24463
    os.environ["NO_PROXY"] = "*"
    dag_id = os.environ["AIRFLOW_CTX_DAG_ID"]
    task_id = os.environ["AIRFLOW_CTX_TASK_ID"]
    expected_op_name = f"{dag_id}__{task_id}"
    assets_to_trigger = {}  # key is (repo_location, repo_name, job_name), value is list of asset keys
    # query the Dagster instance's GraphQL endpoint directly
    dagster_url = os.environ["DAGSTER_URL"]
    response = requests.post(f"{dagster_url}/graphql", json={"query": ASSET_NODES_QUERY}, timeout=3)
    for asset_node in response.json()["data"]["assetNodes"]:
        if asset_node["opName"] == expected_op_name:
            repo_location = asset_node["jobs"][0]["repository"]["location"]["name"]
            repo_name = asset_node["jobs"][0]["repository"]["name"]
            job_name = asset_node["jobs"][0]["name"]
            if (repo_location, repo_name, job_name) not in assets_to_trigger:
                assets_to_trigger[(repo_location, repo_name, job_name)] = []
            assets_to_trigger[(repo_location, repo_name, job_name)].append(
                asset_node["assetKey"]["path"]
            )
    print(f"Found assets to trigger: {assets_to_trigger}")  # noqa: T201
    triggered_runs = []
    for (repo_location, repo_name, job_name), asset_keys in assets_to_trigger.items():
        execution_params = {
            "mode": "default",
            "executionMetadata": {"tags": []},
            "runConfigData": "{}",
            "selector": {
                "repositoryLocationName": repo_location,
                "repositoryName": repo_name,
                "pipelineName": job_name,
                "assetSelection": [{"path": asset_key} for asset_key in asset_keys],
                "assetCheckSelection": [],
            },
        }
        print(f"Triggering run for {repo_location}/{repo_name}/{job_name} with assets {asset_keys}")  # noqa: T201
        response = requests.post(
            f"{dagster_url}/graphql",
            json={
                "query": TRIGGER_ASSETS_MUTATION,
                "variables": {"executionParams": execution_params},
            },
            timeout=3,
        )
        run_id = response.json()["data"]["launchPipelineExecution"]["run"]["id"]
        print(f"Launched run {run_id}...")  # noqa: T201
        triggered_runs.append(run_id)
    completed_runs = {}  # key is run_id, value is status
    while len(completed_runs) < len(triggered_runs):
        for run_id in triggered_runs:
            if run_id in completed_runs:
                continue
            response = requests.post(
                f"{dagster_url}/graphql",
                json={"query": RUNS_QUERY, "variables": {"runId": run_id}},
                timeout=3,
            )
            run_status = response.json()["data"]["runOrError"]["status"]
            if run_status in ["SUCCESS", "FAILURE", "CANCELED"]:
                print(f"Run {run_id} completed with status {run_status}")  # noqa: T201
                completed_runs[run_id] = run_status
    non_successful_runs = [
        run_id for run_id, status in completed_runs.items() if status != "SUCCESS"
    ]
    if non_successful_runs:
        raise Exception(f"Runs {non_successful_runs} did not complete successfully.")
    print("All runs completed successfully.")  # noqa: T201
    return None


def build_dagster_task(task_id: str, dag: DAG, **kwargs):
    return PythonOperator(task_id=task_id, dag=dag, python_callable=compute_fn, **kwargs)
```
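
Since compute_fn reads everything from the environment, it can be smoke-tested outside a live Airflow worker by faking the context variables (assuming a `dagster dev` server on port 3333, as in the tests below):

```python
import os

# Fake the Airflow task context; these are the only inputs compute_fn reads.
os.environ["AIRFLOW_CTX_DAG_ID"] = "the_dag"
os.environ["AIRFLOW_CTX_TASK_ID"] = "some_task"
os.environ["DAGSTER_URL"] = "http://localhost:3333"

compute_fn()  # resolves the op "the_dag__some_task" and triggers its assets
```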
dagster_airlift/in_airflow/gql_queries.py (new file, +104 lines):
```python
ASSET_NODES_QUERY = """
query AssetNodeQuery {
  assetNodes {
    id
    assetKey {
      path
    }
    opName
    jobs {
      id
      name
      repository {
        id
        name
        location {
          id
          name
        }
      }
    }
  }
}
"""

TRIGGER_ASSETS_MUTATION = """
mutation LaunchAssetsExecution($executionParams: ExecutionParams!) {
  launchPipelineExecution(executionParams: $executionParams) {
    ... on LaunchRunSuccess {
      run {
        id
        pipelineName
        __typename
      }
      __typename
    }
    ... on PipelineNotFoundError {
      message
      __typename
    }
    ... on InvalidSubsetError {
      message
      __typename
    }
    ... on RunConfigValidationInvalid {
      errors {
        message
        __typename
      }
      __typename
    }
    ...PythonErrorFragment
    __typename
  }
}
fragment PythonErrorFragment on PythonError {
  message
  stack
  errorChain {
    ...PythonErrorChain
    __typename
  }
  __typename
}
fragment PythonErrorChain on ErrorChainLink {
  isExplicitLink
  error {
    message
    stack
    __typename
  }
  __typename
}
"""

RUNS_QUERY = """
query RunQuery($runId: ID!) {
  runOrError(runId: $runId) {
    __typename
    ...PythonErrorFragment
    ...NotFoundFragment
    ... on Run {
      id
      status
      __typename
    }
  }
}
fragment NotFoundFragment on RunNotFoundError {
  __typename
  message
}
fragment PythonErrorFragment on PythonError {
  __typename
  message
  stack
  causes {
    message
    stack
    __typename
  }
}
"""
```
Airflow test fixture:

```diff
@@ -31,7 +31,7 @@ def airflow_instance_fixture(setup: None) -> Generator[Any, None, None]:
     initial_time = get_current_timestamp()

     airflow_ready = False
-    while get_current_timestamp() - initial_time < 30:
+    while get_current_timestamp() - initial_time < 60:
         if airflow_is_ready():
             airflow_ready = True
             break
```
Test fixtures (Dagster dev harness):

```diff
@@ -1,12 +1,15 @@
 import os
+import signal
 import subprocess
+import time
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Any, Generator

 import pytest
+import requests
 from dagster._core.test_utils import environ
+from dagster._time import get_current_timestamp


 def assert_link_exists(link_name: str, link_url: Any):
@@ -19,17 +22,26 @@ def default_dags_dir():
     return Path(__file__).parent / "dags"


-@pytest.fixture(name="setup")
-def setup_fixture(dags_dir: Path) -> Generator[str, None, None]:
+@pytest.fixture(name="airflow_home")
+def default_airflow_home() -> Generator[str, None, None]:
     with TemporaryDirectory() as tmpdir:
-        # run chmod +x create_airflow_cfg.sh and then run create_airflow_cfg.sh tmpdir
-        temp_env = {**os.environ.copy(), "AIRFLOW_HOME": tmpdir}
-        # go up one directory from current
-        path_to_script = Path(__file__).parent.parent.parent / "airflow_setup.sh"
-        subprocess.run(["chmod", "+x", path_to_script], check=True, env=temp_env)
-        subprocess.run([path_to_script, dags_dir], check=True, env=temp_env)
-        with environ({"AIRFLOW_HOME": tmpdir}):
-            yield tmpdir
+        yield tmpdir
+
+
+@pytest.fixture(name="setup")
+def setup_fixture(airflow_home: Path, dags_dir: Path) -> Generator[str, None, None]:
+    # run chmod +x create_airflow_cfg.sh and then run create_airflow_cfg.sh tmpdir
+    temp_env = {
+        **os.environ.copy(),
+        "AIRFLOW_HOME": str(airflow_home),
+        "DAGSTER_URL": "http://localhost:3333",
+    }
+    # go up one directory from current
+    path_to_script = Path(__file__).parent.parent.parent / "airflow_setup.sh"
+    subprocess.run(["chmod", "+x", path_to_script], check=True, env=temp_env)
+    subprocess.run([path_to_script, dags_dir], check=True, env=temp_env)
+    with environ({"AIRFLOW_HOME": str(airflow_home), "DAGSTER_URL": "http://localhost:3333"}):
+        yield str(airflow_home)


 @pytest.fixture(name="dbt_project_dir")
@@ -49,3 +61,46 @@ def dbt_project(dbt_project_dir: Path) -> None:
         check=True,
         env=os.environ.copy(),
     )
+
+
+def dagster_is_ready() -> bool:
+    try:
+        response = requests.get("http://localhost:3333")
+        return response.status_code == 200
+    except requests.RequestException:
+        return False
+
+
+@pytest.fixture(name="dagster_home")
+def setup_dagster_home() -> Generator[str, None, None]:
+    """Instantiate a temporary directory to serve as the DAGSTER_HOME."""
+    with TemporaryDirectory() as tmpdir:
+        yield tmpdir
+
+
+@pytest.fixture(name="dagster_dev")
+def setup_dagster(dagster_home: str, dagster_defs_path: str) -> Generator[Any, None, None]:
+    """Stands up a dagster instance using the dagster dev CLI. dagster_defs_path must be provided
+    by a fixture included in the callsite.
+    """
+    temp_env = {**os.environ.copy(), "DAGSTER_HOME": dagster_home}
+    process = subprocess.Popen(
+        ["dagster", "dev", "-f", dagster_defs_path, "-p", "3333"],
+        env=temp_env,
+        shell=False,
+        preexec_fn=os.setsid,  # noqa
+    )
+    # Give dagster a second to stand up
+    time.sleep(5)
+
+    dagster_ready = False
+    initial_time = get_current_timestamp()
+    while get_current_timestamp() - initial_time < 60:
+        if dagster_is_ready():
+            dagster_ready = True
+            break
+        time.sleep(1)
+
+    assert dagster_ready, "Dagster did not start within 60 seconds..."
+    yield process
+    os.killpg(process.pid, signal.SIGKILL)
```
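
A test module satisfies the dagster_defs_path requirement with a fixture along these lines (file name is hypothetical):

```python
from pathlib import Path

import pytest


@pytest.fixture(name="dagster_defs_path")
def dagster_defs_path_fixture() -> str:
    # Hypothetical path; points the dagster_dev fixture at the defs file below.
    return str(Path(__file__).parent / "dagster_defs.py")
```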
Test DAG (new file, +29 lines):
```python
import logging
from datetime import datetime

from airflow import DAG
from dagster_airlift.in_airflow.dagster_operator import build_dagster_task

logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
requests_log = logging.getLogger("requests.packages.urllib3")
requests_log.setLevel(logging.INFO)
requests_log.propagate = True


def print_hello():
    print("Hello")  # noqa: T201


default_args = {
    "owner": "airflow",
    "depends_on_past": False,
    "start_date": datetime(2023, 1, 1),
    "retries": 1,
}

dag = DAG(
    "the_dag", default_args=default_args, schedule_interval=None, is_paused_upon_creation=False
)
migrated_op = build_dagster_task(task_id="some_task", dag=dag)
other_migrated_op = build_dagster_task(task_id="other_task", dag=dag)
```
Test Dagster definitions (new file, +19 lines):
```python
from dagster import Definitions, asset


@asset
def the_dag__some_task():
    return "asset_value"


@asset
def unrelated():
    return "unrelated_value"


@asset
def the_dag__other_task():
    return "other_task_value"


defs = Definitions(assets=[the_dag__other_task, the_dag__some_task, unrelated])
```