Skip to content

Commit

Permalink
[dagster-airlift] dagster operator
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Aug 5, 2024
1 parent f212aea commit 40cea78
Show file tree
Hide file tree
Showing 6 changed files with 383 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,7 @@
PythonDefs as PythonDefs,
load_defs_from_yaml as load_defs_from_yaml,
)
from .within_airflow import mark_as_dagster_migrating as mark_as_dagster_migrating
from .within_airflow import (
build_dagster_migrated_operator as build_dagster_migrated_operator,
mark_as_dagster_migrating as mark_as_dagster_migrating,
)
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def airflow_instance_fixture(setup: None) -> Generator[Any, None, None]:
initial_time = get_current_timestamp()

airflow_ready = False
while get_current_timestamp() - initial_time < 30:
while get_current_timestamp() - initial_time < 60:
if airflow_is_ready():
airflow_ready = True
break
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import json
import os
import sys

import requests
from airflow import DAG
from airflow.operators.python import PythonOperator


def mark_as_dagster_migrating(
Expand Down Expand Up @@ -59,3 +62,208 @@ def mark_as_dagster_migrating(
)
globals_to_update[var] = new_dag
global_vars.update(globals_to_update)


ASSET_NODES_QUERY = """
query AssetNodeQuery {
assetNodes {
id
assetKey {
path
}
opName
jobs {
id
name
repository {
id
name
location {
id
name
}
}
}
}
}
"""

TRIGGER_ASSETS_MUTATION = """
mutation LaunchAssetsExecution($executionParams: ExecutionParams!) {
launchPipelineExecution(executionParams: $executionParams) {
... on LaunchRunSuccess {
run {
id
pipelineName
__typename
}
__typename
}
... on PipelineNotFoundError {
message
__typename
}
... on InvalidSubsetError {
message
__typename
}
... on RunConfigValidationInvalid {
errors {
message
__typename
}
__typename
}
...PythonErrorFragment
__typename
}
}
fragment PythonErrorFragment on PythonError {
message
stack
errorChain {
...PythonErrorChain
__typename
}
__typename
}
fragment PythonErrorChain on ErrorChainLink {
isExplicitLink
error {
message
stack
__typename
}
__typename
}
"""

# request format
# {
# "executionParams": {
# "mode": "default",
# "executionMetadata": {
# "tags": []
# },
# "runConfigData": "{}",
# "selector": {
# "repositoryLocationName": "toys",
# "repositoryName": "__repository__",
# "pipelineName": "__ASSET_JOB_0",
# "assetSelection": [
# {
# "path": [
# "bigquery",
# "raw_customers"
# ]
# }
# ],
# "assetCheckSelection": []
# }
# }
# }

RUNS_QUERY = """
query RunQuery($runId: ID!) {
runOrError(runId: $runId) {
__typename
...PythonErrorFragment
...NotFoundFragment
... on Run {
id
status
__typename
}
}
}
fragment NotFoundFragment on RunNotFoundError {
__typename
message
}
fragment PythonErrorFragment on PythonError {
__typename
message
stack
causes {
message
stack
__typename
}
}
"""


def compute_fn() -> None:
# https://github.com/apache/airflow/discussions/24463
os.environ["NO_PROXY"] = "*"
dag_id = os.environ["AIRFLOW_CTX_DAG_ID"]
task_id = os.environ["AIRFLOW_CTX_TASK_ID"]
expected_op_name = f"{dag_id}__{task_id}"
assets_to_trigger = {} # key is (repo_location, repo_name, job_name), value is list of asset keys
# create graphql client
dagster_url = os.environ["DAGSTER_URL"]
response = requests.post(f"{dagster_url}/graphql", json={"query": ASSET_NODES_QUERY}, timeout=3)
for asset_node in response.json()["data"]["assetNodes"]:
if asset_node["opName"] == expected_op_name:
repo_location = asset_node["jobs"][0]["repository"]["location"]["name"]
repo_name = asset_node["jobs"][0]["repository"]["name"]
job_name = asset_node["jobs"][0]["name"]
if (repo_location, repo_name, job_name) not in assets_to_trigger:
assets_to_trigger[(repo_location, repo_name, job_name)] = []
assets_to_trigger[(repo_location, repo_name, job_name)].append(
asset_node["assetKey"]["path"]
)
print(f"Found assets to trigger: {assets_to_trigger}") # noqa: T201
triggered_runs = []
for (repo_location, repo_name, job_name), asset_keys in assets_to_trigger.items():
execution_params = {
"mode": "default",
"executionMetadata": {"tags": []},
"runConfigData": "{}",
"selector": {
"repositoryLocationName": repo_location,
"repositoryName": repo_name,
"pipelineName": job_name,
"assetSelection": [{"path": asset_key} for asset_key in asset_keys],
"assetCheckSelection": [],
},
}
print(f"Triggering run for {repo_location}/{repo_name}/{job_name} with assets {asset_keys}") # noqa: T201
response = requests.post(
f"{dagster_url}/graphql",
json={
"query": TRIGGER_ASSETS_MUTATION,
"variables": {"executionParams": execution_params},
},
timeout=3,
)
run_id = response.json()["data"]["launchPipelineExecution"]["run"]["id"]
print(f"Launched run {run_id}...") # noqa: T201
triggered_runs.append(run_id)
completed_runs = {} # key is run_id, value is status
while len(completed_runs) < len(triggered_runs):
for run_id in triggered_runs:
if run_id in completed_runs:
continue
response = requests.post(
f"{dagster_url}/graphql",
json={"query": RUNS_QUERY, "variables": {"runId": run_id}},
timeout=3,
)
run_status = response.json()["data"]["runOrError"]["status"]
if run_status in ["SUCCESS", "FAILURE", "CANCELED"]:
print(f"Run {run_id} completed with status {run_status}") # noqa: T201
completed_runs[run_id] = run_status
non_successful_runs = [
run_id for run_id, status in completed_runs.items() if status != "SUCCESS"
]
if non_successful_runs:
raise Exception(f"Runs {non_successful_runs} did not complete successfully.")
print("All runs completed successfully.") # noqa: T201
return None


def build_dagster_migrated_operator(task_id: str, dag: DAG, **kwargs):
return PythonOperator(task_id=task_id, dag=dag, python_callable=compute_fn, **kwargs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import logging
from datetime import datetime

from airflow import DAG
from dagster_airlift import build_dagster_migrated_operator

logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
requests_log = logging.getLogger("requests.packages.urllib3")
requests_log.setLevel(logging.INFO)
requests_log.propagate = True


def print_hello():
print("Hello") # noqa: T201


default_args = {
"owner": "airflow",
"depends_on_past": False,
"start_date": datetime(2023, 1, 1),
"retries": 1,
}

dag = DAG(
"the_dag", default_args=default_args, schedule_interval=None, is_paused_upon_creation=False
)
migrated_op = build_dagster_migrated_operator(task_id="some_task", dag=dag)
other_migrated_op = build_dagster_migrated_operator(task_id="other_task", dag=dag)
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from dagster import Definitions, asset


@asset
def the_dag__some_task():
return "asset_value"


@asset
def unrelated():
return "unrelated_value"


@asset
def the_dag__other_task():
return "other_task_value"


defs = Definitions(assets=[the_dag__other_task, the_dag__some_task, unrelated])
Loading

0 comments on commit 40cea78

Please sign in to comment.