[dagster-airlift] dagster operator #23386

Merged 2 commits on Aug 11, 2024
@@ -0,0 +1 @@
from .mark_as_migrating import mark_as_dagster_migrating as mark_as_dagster_migrating
@@ -0,0 +1,81 @@
import os

import requests
from airflow import DAG
from airflow.operators.python import PythonOperator

from .gql_queries import ASSET_NODES_QUERY, RUNS_QUERY, TRIGGER_ASSETS_MUTATION


def compute_fn() -> None:
    # https://github.com/apache/airflow/discussions/24463
    os.environ["NO_PROXY"] = "*"
    dag_id = os.environ["AIRFLOW_CTX_DAG_ID"]
    task_id = os.environ["AIRFLOW_CTX_TASK_ID"]
    expected_op_name = f"{dag_id}__{task_id}"
    assets_to_trigger = {}  # key is (repo_location, repo_name, job_name), value is list of asset keys
    # create graphql client
    dagster_url = os.environ["DAGSTER_URL"]
    response = requests.post(f"{dagster_url}/graphql", json={"query": ASSET_NODES_QUERY}, timeout=3)
    for asset_node in response.json()["data"]["assetNodes"]:
        if asset_node["opName"] == expected_op_name:
            repo_location = asset_node["jobs"][0]["repository"]["location"]["name"]
            repo_name = asset_node["jobs"][0]["repository"]["name"]
            job_name = asset_node["jobs"][0]["name"]
            if (repo_location, repo_name, job_name) not in assets_to_trigger:
                assets_to_trigger[(repo_location, repo_name, job_name)] = []
            assets_to_trigger[(repo_location, repo_name, job_name)].append(
                asset_node["assetKey"]["path"]
            )
    print(f"Found assets to trigger: {assets_to_trigger}")  # noqa: T201
    triggered_runs = []
    for (repo_location, repo_name, job_name), asset_keys in assets_to_trigger.items():
        execution_params = {
            "mode": "default",
            "executionMetadata": {"tags": []},
            "runConfigData": "{}",
            "selector": {
                "repositoryLocationName": repo_location,
                "repositoryName": repo_name,
                "pipelineName": job_name,
                "assetSelection": [{"path": asset_key} for asset_key in asset_keys],
                "assetCheckSelection": [],
            },
        }
        print(f"Triggering run for {repo_location}/{repo_name}/{job_name} with assets {asset_keys}")  # noqa: T201
        response = requests.post(
            f"{dagster_url}/graphql",
            json={
                "query": TRIGGER_ASSETS_MUTATION,
                "variables": {"executionParams": execution_params},
            },
            timeout=3,
        )
        run_id = response.json()["data"]["launchPipelineExecution"]["run"]["id"]
        print(f"Launched run {run_id}...")  # noqa: T201
        triggered_runs.append(run_id)
    completed_runs = {}  # key is run_id, value is status
    while len(completed_runs) < len(triggered_runs):
        for run_id in triggered_runs:
            if run_id in completed_runs:
                continue
            response = requests.post(
                f"{dagster_url}/graphql",
                json={"query": RUNS_QUERY, "variables": {"runId": run_id}},
                timeout=3,
            )
            run_status = response.json()["data"]["runOrError"]["status"]
            if run_status in ["SUCCESS", "FAILURE", "CANCELED"]:
                print(f"Run {run_id} completed with status {run_status}")  # noqa: T201
                completed_runs[run_id] = run_status
    non_successful_runs = [
        run_id for run_id, status in completed_runs.items() if status != "SUCCESS"
    ]
    if non_successful_runs:
        raise Exception(f"Runs {non_successful_runs} did not complete successfully.")
    print("All runs completed successfully.")  # noqa: T201
    return None


def build_dagster_task(task_id: str, dag: DAG, **kwargs):
return PythonOperator(task_id=task_id, dag=dag, python_callable=compute_fn, **kwargs)
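
For orientation, a minimal sketch of how the operator built by `build_dagster_task` might be wired into a dag file. The dag id, task id, and import path are illustrative assumptions, not part of this diff; the PR itself only adds the builder above.

# Hypothetical usage sketch -- dag id, task id, and import path are assumptions.
from airflow import DAG

from dagster_airlift.in_airflow.dagster_operator import build_dagster_task  # assumed import path

dag = DAG(dag_id="rebuild_customers_list")

# DAGSTER_URL must be set in the task's runtime environment, e.g. http://localhost:3000.
# At execution time, compute_fn looks up Dagster assets whose op name is
# "rebuild_customers_list__load_customers" and triggers a run for them.
load_customers = build_dagster_task(task_id="load_customers", dag=dag)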
@@ -0,0 +1,104 @@
ASSET_NODES_QUERY = """
query AssetNodeQuery {
  assetNodes {
    id
    assetKey {
      path
    }
    opName
    jobs {
      id
      name
      repository {
        id
        name
        location {
          id
          name
        }
      }
    }
  }
}
"""

TRIGGER_ASSETS_MUTATION = """
mutation LaunchAssetsExecution($executionParams: ExecutionParams!) {
  launchPipelineExecution(executionParams: $executionParams) {
    ... on LaunchRunSuccess {
      run {
        id
        pipelineName
        __typename
      }
      __typename
    }
    ... on PipelineNotFoundError {
      message
      __typename
    }
    ... on InvalidSubsetError {
      message
      __typename
    }
    ... on RunConfigValidationInvalid {
      errors {
        message
        __typename
      }
      __typename
    }
    ...PythonErrorFragment
    __typename
  }
}

fragment PythonErrorFragment on PythonError {
  message
  stack
  errorChain {
    ...PythonErrorChain
    __typename
  }
  __typename
}

fragment PythonErrorChain on ErrorChainLink {
  isExplicitLink
  error {
    message
    stack
    __typename
  }
  __typename
}
"""

RUNS_QUERY = """
query RunQuery($runId: ID!) {
  runOrError(runId: $runId) {
    __typename
    ...PythonErrorFragment
    ...NotFoundFragment
    ... on Run {
      id
      status
      __typename
    }
  }
}
fragment NotFoundFragment on RunNotFoundError {
  __typename
  message
}
fragment PythonErrorFragment on PythonError {
  __typename
  message
  stack
  causes {
    message
    stack
    __typename
  }
}
"""
@@ -0,0 +1,49 @@
import json
import logging
from typing import Any, Dict, Optional

from airflow import DAG

from ..migration_state import AirflowMigrationState


def mark_as_dagster_migrating(
    *,
    global_vars: Dict[str, Any],
    migration_state: AirflowMigrationState,
    logger: Optional[logging.Logger] = None,
) -> None:
    """Alters all airflow dags in the current context to be marked as migrating to dagster.
    Uses a migration dictionary to determine the status of the migration for each task within each dag.
    Should only ever be the last line in a dag file.

    Args:
        global_vars (Dict[str, Any]): The global variables in the current context. In most cases, retrieved with `globals()` (no import required).
            This is equivalent to what airflow already does to introspect the dags which exist in a given module context:
            https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/dags.html#loading-dags
        migration_state (AirflowMigrationState): The migration state for the dags.
        logger (Optional[logging.Logger]): The logger to use. Defaults to logging.getLogger("dagster_airlift").
    """
    caller_module = global_vars.get("__module__")
    suffix = f" in module `{caller_module}`" if caller_module else ""
    if not logger:
        logger = logging.getLogger("dagster_airlift")
    logger.debug(f"Searching for dags migrating to dagster{suffix}...")
    num_dags = 0
    for obj in global_vars.values():
        if not isinstance(obj, DAG):
            continue
        dag: DAG = obj
        logger.debug(f"Checking dag with id `{dag.dag_id}` for migration state.")
        migration_state_for_dag = migration_state.get_migration_dict_for_dag(dag.dag_id)
        if migration_state_for_dag is None:
            logger.debug(
                f"Dag with id `{dag.dag_id}` hasn't been marked with migration state. Skipping..."
            )
        else:
            dag.tags.append(json.dumps({"DAGSTER_MIGRATION_STATUS": migration_state_for_dag}))
            logger.debug(
                f"Dag with id `{dag.dag_id}` has been marked with migration state. Adding state to tags for dag."
            )
            num_dags += 1
    logger.info(f"Marked {num_dags} dags as migrating to dagster{suffix}.")
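
A usage sketch for `mark_as_dagster_migrating`, assuming a dag file that keeps its migration state in a yaml directory next to it; the paths, dag id, and import locations below are assumptions rather than part of this diff.

# Hypothetical dag-file usage -- paths, dag id, and import locations are assumptions.
from pathlib import Path

from airflow import DAG

from dagster_airlift.in_airflow import mark_as_dagster_migrating  # assumed import path
from dagster_airlift.migration_state import load_migration_state_from_yaml  # assumed import path

dag = DAG(dag_id="rebuild_customers_list")
# ... task definitions ...

# Last line of the dag file: tags every DAG found in globals() with its
# migration status, read from the per-dag yaml files in the directory.
mark_as_dagster_migrating(
    global_vars=globals(),
    migration_state=load_migration_state_from_yaml(Path(__file__).parent / "migration_state"),
)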
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Dict, NamedTuple
from typing import Any, Dict, NamedTuple, Optional

import yaml

@@ -18,11 +18,38 @@ class AirflowMigrationState(NamedTuple):
    def get_migration_state_for_task(self, dag_id: str, task_id: str) -> bool:
        return self.dags[dag_id].tasks[task_id].migrated

    def get_migration_dict_for_dag(self, dag_id: str) -> Optional[Dict[str, Dict[str, Any]]]:
        if dag_id not in self.dags:
            return None
        return {
            "tasks": {
                task_id: {"migrated": task_state.migrated}
                for task_id, task_state in self.dags[dag_id].tasks.items()
            }
        }


class MigrationStateParsingError(Exception):
    pass


def load_dag_migration_state_from_dict(dag_dict: Dict[str, Dict[str, Any]]) -> DagMigrationState:
    if "tasks" not in dag_dict:
        raise Exception("Expected a 'tasks' key in the yaml")
    task_migration_states = {}
    for task_id, task_dict in dag_dict["tasks"].items():
        if not isinstance(task_dict, dict):
            raise Exception("Expected a dictionary for each task")
        if "migrated" not in task_dict:
            raise Exception("Expected a 'migrated' key in the task dictionary")
        if set(task_dict.keys()) != {"migrated"}:
            raise Exception("Expected only a 'migrated' key in the task dictionary")
        if task_dict["migrated"] not in [True, False]:
            raise Exception("Expected 'migrated' key to be a boolean")
        task_migration_states[task_id] = TaskMigrationState(migrated=task_dict["migrated"])
    return DagMigrationState(tasks=task_migration_states)


def load_migration_state_from_yaml(migration_yaml_path: Path) -> AirflowMigrationState:
    # Expect migration_yaml_path to be a directory, where each file represents a dag, and each
    # entry under that file's "tasks" key represents a task. The dictionary for each task should consist of a single bit:
@@ -37,20 +64,7 @@ def load_migration_state_from_yaml(migration_yaml_path: Path) -> AirflowMigrationState:
            yaml_dict = yaml.safe_load(dag_file.read_text())
            if not isinstance(yaml_dict, dict):
                raise Exception("Expected a dictionary")
            if "tasks" not in yaml_dict:
                raise Exception("Expected a 'tasks' key in the yaml")
            task_migration_states = {}
            for task_id, task_dict in yaml_dict["tasks"].items():
                if not isinstance(task_dict, dict):
                    raise Exception("Expected a dictionary for each task")
                if "migrated" not in task_dict:
                    raise Exception("Expected a 'migrated' key in the task dictionary")
                if set(task_dict.keys()) != {"migrated"}:
                    raise Exception("Expected only a 'migrated' key in the task dictionary")
                if task_dict["migrated"] not in [True, False]:
                    raise Exception("Expected 'migrated' key to be a boolean")
                task_migration_states[task_id] = TaskMigrationState(migrated=task_dict["migrated"])
            dag_migration_states[dag_id] = DagMigrationState(tasks=task_migration_states)
            dag_migration_states[dag_id] = load_dag_migration_state_from_dict(yaml_dict)
    except Exception as e:
        raise MigrationStateParsingError("Error parsing migration yaml") from e
    return AirflowMigrationState(dags=dag_migration_states)
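
For illustration, a sketch of the yaml layout this loader expects: one file per dag in the directory, each with a `tasks` mapping of task id to a single `migrated` bit. The file name, dag id, and task ids here are assumptions, and the sketch assumes `load_migration_state_from_yaml` from this module is in scope.

# Hypothetical example -- file name, dag id, and task ids are assumptions.
from pathlib import Path

migration_dir = Path("migration_state")
migration_dir.mkdir(exist_ok=True)
# One yaml file per dag; the file name is assumed here to map to the dag id.
(migration_dir / "rebuild_customers_list.yaml").write_text(
    "tasks:\n"
    "  load_customers:\n"
    "    migrated: True\n"
    "  build_customers_table:\n"
    "    migrated: False\n"
)

state = load_migration_state_from_yaml(migration_dir)
assert state.get_migration_state_for_task("rebuild_customers_list", "load_customers") is True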
@@ -31,7 +31,7 @@ def airflow_instance_fixture(setup: None) -> Generator[Any, None, None]:
    initial_time = get_current_timestamp()

    airflow_ready = False
    while get_current_timestamp() - initial_time < 30:
    while get_current_timestamp() - initial_time < 60:
        if airflow_is_ready():
            airflow_ready = True
            break