Skip to content

Commit

Permalink
[dagster-airlift] dagster operator
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Aug 9, 2024
1 parent 4fc7a51 commit daefdd8
Show file tree
Hide file tree
Showing 7 changed files with 360 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import os

import requests
from airflow import DAG
from airflow.operators.python import PythonOperator

from .gql_queries import ASSET_NODES_QUERY, RUNS_QUERY, TRIGGER_ASSETS_MUTATION


def compute_fn() -> None:
# https://github.com/apache/airflow/discussions/24463
os.environ["NO_PROXY"] = "*"
dag_id = os.environ["AIRFLOW_CTX_DAG_ID"]
task_id = os.environ["AIRFLOW_CTX_TASK_ID"]
expected_op_name = f"{dag_id}__{task_id}"
assets_to_trigger = {} # key is (repo_location, repo_name, job_name), value is list of asset keys
# create graphql client
dagster_url = os.environ["DAGSTER_URL"]
response = requests.post(f"{dagster_url}/graphql", json={"query": ASSET_NODES_QUERY}, timeout=3)
for asset_node in response.json()["data"]["assetNodes"]:
if asset_node["opName"] == expected_op_name:
repo_location = asset_node["jobs"][0]["repository"]["location"]["name"]
repo_name = asset_node["jobs"][0]["repository"]["name"]
job_name = asset_node["jobs"][0]["name"]
if (repo_location, repo_name, job_name) not in assets_to_trigger:
assets_to_trigger[(repo_location, repo_name, job_name)] = []
assets_to_trigger[(repo_location, repo_name, job_name)].append(
asset_node["assetKey"]["path"]
)
print(f"Found assets to trigger: {assets_to_trigger}") # noqa: T201
triggered_runs = []
for (repo_location, repo_name, job_name), asset_keys in assets_to_trigger.items():
execution_params = {
"mode": "default",
"executionMetadata": {"tags": []},
"runConfigData": "{}",
"selector": {
"repositoryLocationName": repo_location,
"repositoryName": repo_name,
"pipelineName": job_name,
"assetSelection": [{"path": asset_key} for asset_key in asset_keys],
"assetCheckSelection": [],
},
}
print(f"Triggering run for {repo_location}/{repo_name}/{job_name} with assets {asset_keys}") # noqa: T201
response = requests.post(
f"{dagster_url}/graphql",
json={
"query": TRIGGER_ASSETS_MUTATION,
"variables": {"executionParams": execution_params},
},
timeout=3,
)
run_id = response.json()["data"]["launchPipelineExecution"]["run"]["id"]
print(f"Launched run {run_id}...") # noqa: T201
triggered_runs.append(run_id)
completed_runs = {} # key is run_id, value is status
while len(completed_runs) < len(triggered_runs):
for run_id in triggered_runs:
if run_id in completed_runs:
continue
response = requests.post(
f"{dagster_url}/graphql",
json={"query": RUNS_QUERY, "variables": {"runId": run_id}},
timeout=3,
)
run_status = response.json()["data"]["runOrError"]["status"]
if run_status in ["SUCCESS", "FAILURE", "CANCELED"]:
print(f"Run {run_id} completed with status {run_status}") # noqa: T201
completed_runs[run_id] = run_status
non_successful_runs = [
run_id for run_id, status in completed_runs.items() if status != "SUCCESS"
]
if non_successful_runs:
raise Exception(f"Runs {non_successful_runs} did not complete successfully.")
print("All runs completed successfully.") # noqa: T201
return None


def build_dagster_task(task_id: str, dag: DAG, **kwargs):
return PythonOperator(task_id=task_id, dag=dag, python_callable=compute_fn, **kwargs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
ASSET_NODES_QUERY = """
query AssetNodeQuery {
assetNodes {
id
assetKey {
path
}
opName
jobs {
id
name
repository {
id
name
location {
id
name
}
}
}
}
}
"""

TRIGGER_ASSETS_MUTATION = """
mutation LaunchAssetsExecution($executionParams: ExecutionParams!) {
launchPipelineExecution(executionParams: $executionParams) {
... on LaunchRunSuccess {
run {
id
pipelineName
__typename
}
__typename
}
... on PipelineNotFoundError {
message
__typename
}
... on InvalidSubsetError {
message
__typename
}
... on RunConfigValidationInvalid {
errors {
message
__typename
}
__typename
}
...PythonErrorFragment
__typename
}
}
fragment PythonErrorFragment on PythonError {
message
stack
errorChain {
...PythonErrorChain
__typename
}
__typename
}
fragment PythonErrorChain on ErrorChainLink {
isExplicitLink
error {
message
stack
__typename
}
__typename
}
"""

RUNS_QUERY = """
query RunQuery($runId: ID!) {
runOrError(runId: $runId) {
__typename
...PythonErrorFragment
...NotFoundFragment
... on Run {
id
status
__typename
}
}
}
fragment NotFoundFragment on RunNotFoundError {
__typename
message
}
fragment PythonErrorFragment on PythonError {
__typename
message
stack
causes {
message
stack
__typename
}
}
"""
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def airflow_instance_fixture(setup: None) -> Generator[Any, None, None]:
initial_time = get_current_timestamp()

airflow_ready = False
while get_current_timestamp() - initial_time < 30:
while get_current_timestamp() - initial_time < 60:
if airflow_is_ready():
airflow_ready = True
break
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import os
import signal
import subprocess
import time
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Generator

import pytest
import requests
from dagster._core.test_utils import environ
from dagster._time import get_current_timestamp


def assert_link_exists(link_name: str, link_url: Any):
Expand All @@ -19,17 +22,26 @@ def default_dags_dir():
return Path(__file__).parent / "dags"


@pytest.fixture(name="setup")
def setup_fixture(dags_dir: Path) -> Generator[str, None, None]:
@pytest.fixture(name="airflow_home")
def default_airflow_home() -> Generator[str, None, None]:
with TemporaryDirectory() as tmpdir:
# run chmod +x create_airflow_cfg.sh and then run create_airflow_cfg.sh tmpdir
temp_env = {**os.environ.copy(), "AIRFLOW_HOME": tmpdir}
# go up one directory from current
path_to_script = Path(__file__).parent.parent.parent / "airflow_setup.sh"
subprocess.run(["chmod", "+x", path_to_script], check=True, env=temp_env)
subprocess.run([path_to_script, dags_dir], check=True, env=temp_env)
with environ({"AIRFLOW_HOME": tmpdir}):
yield tmpdir
yield tmpdir


@pytest.fixture(name="setup")
def setup_fixture(airflow_home: Path, dags_dir: Path) -> Generator[str, None, None]:
# run chmod +x create_airflow_cfg.sh and then run create_airflow_cfg.sh tmpdir
temp_env = {
**os.environ.copy(),
"AIRFLOW_HOME": str(airflow_home),
"DAGSTER_URL": "http://localhost:3333",
}
# go up one directory from current
path_to_script = Path(__file__).parent.parent.parent / "airflow_setup.sh"
subprocess.run(["chmod", "+x", path_to_script], check=True, env=temp_env)
subprocess.run([path_to_script, dags_dir], check=True, env=temp_env)
with environ({"AIRFLOW_HOME": str(airflow_home), "DAGSTER_URL": "http://localhost:3333"}):
yield str(airflow_home)


@pytest.fixture(name="dbt_project_dir")
Expand All @@ -49,3 +61,46 @@ def dbt_project(dbt_project_dir: Path) -> None:
check=True,
env=os.environ.copy(),
)


def dagster_is_ready() -> bool:
try:
response = requests.get("http://localhost:3333")
return response.status_code == 200
except:
return False


@pytest.fixture(name="dagster_home")
def setup_dagster_home() -> Generator[str, None, None]:
"""Instantiate a temporary directory to serve as the DAGSTER_HOME."""
with TemporaryDirectory() as tmpdir:
yield tmpdir


@pytest.fixture(name="dagster_dev")
def setup_dagster(dagster_home: str, dagster_defs_path: str) -> Generator[Any, None, None]:
"""Stands up a dagster instance using the dagster dev CLI. dagster_defs_path must be provided
by a fixture included in the callsite.
"""
temp_env = {**os.environ.copy(), "DAGSTER_HOME": dagster_home}
process = subprocess.Popen(
["dagster", "dev", "-f", dagster_defs_path, "-p", "3333"],
env=temp_env,
shell=False,
preexec_fn=os.setsid, # noqa
)
# Give dagster a second to stand up
time.sleep(5)

dagster_ready = False
initial_time = get_current_timestamp()
while get_current_timestamp() - initial_time < 60:
if dagster_is_ready():
dagster_ready = True
break
time.sleep(1)

assert dagster_ready, "Dagster did not start within 30 seconds..."
yield process
os.killpg(process.pid, signal.SIGKILL)
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import logging
from datetime import datetime

from airflow import DAG
from dagster_airlift.in_airflow.dagster_operator import build_dagster_task

logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
requests_log = logging.getLogger("requests.packages.urllib3")
requests_log.setLevel(logging.INFO)
requests_log.propagate = True


def print_hello():
print("Hello") # noqa: T201


default_args = {
"owner": "airflow",
"depends_on_past": False,
"start_date": datetime(2023, 1, 1),
"retries": 1,
}

dag = DAG(
"the_dag", default_args=default_args, schedule_interval=None, is_paused_upon_creation=False
)
migrated_op = build_dagster_task(task_id="some_task", dag=dag)
other_migrated_op = build_dagster_task(task_id="other_task", dag=dag)
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from dagster import Definitions, asset


@asset
def the_dag__some_task():
return "asset_value"


@asset
def unrelated():
return "unrelated_value"


@asset
def the_dag__other_task():
return "other_task_value"


defs = Definitions(assets=[the_dag__other_task, the_dag__some_task, unrelated])
Loading

0 comments on commit daefdd8

Please sign in to comment.