[dagster-airlift] airflow instance API
dpeng817 committed Aug 3, 2024
1 parent 430e24f commit 264168d
Showing 4 changed files with 168 additions and 97 deletions.
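In short, this change replaces the separate webserver-URL and auth-tuple arguments with an AirflowInstance object backed by a pluggable auth backend. A minimal sketch of the new surface, using the placeholder localhost URL and admin credentials that appear in the updated test at the bottom of this diff:

```python
from dagster_airlift import AirflowInstance, BasicAuthBackend

# Placeholder URL and credentials (taken from the test below); any
# AirflowAuthBackend implementation can be supplied instead.
instance = AirflowInstance(
    airflow_webserver_url="http://localhost:8080",
    auth_backend=BasicAuthBackend(username="admin", password="admin"),
    name="airflow_instance",
)

# REST lookups that previously required hand-rolled requests calls now live on the instance.
dag_infos = instance.list_dags()
task_info = instance.get_task_info(dag_id="print_dag", task_id="print_task")
```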
@@ -1,4 +1,6 @@
from .airflow_utils import (
AirflowInstance as AirflowInstance,
BasicAuthBackend as BasicAuthBackend,
TaskMapping as TaskMapping,
airflow_task_mappings_from_dbt_project as airflow_task_mappings_from_dbt_project,
assets_defs_from_airflow_instance as assets_defs_from_airflow_instance,
235 changes: 148 additions & 87 deletions examples/experimental/dagster-airlift/dagster_airlift/airflow_utils.py
@@ -1,7 +1,8 @@
import datetime
import json
from abc import ABC
from datetime import timedelta
from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence, Tuple
from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence

import requests
from dagster import (
@@ -43,98 +44,172 @@ class TaskMapping(NamedTuple):
group: Optional[str] = None


def assets_defs_from_airflow_instance(
airflow_webserver_url: str,
auth: Tuple[str, str],
instance_name: str,
task_maps: Sequence[TaskMapping] = [],
) -> List[AssetsDefinition]:
api_url = f"{airflow_webserver_url}/api/v1"
dag_infos: List[DagInfo] = []

# First, we attempt to fetch all the DAGs present in the Airflow instance.
response = requests.get(f"{api_url}/dags", auth=auth)
if response.status_code == 200:
dags = response.json()
for dag in dags["dags"]:
dag_infos.append(
class AirflowAuthBackend(ABC):
def get_session(self) -> requests.Session:
raise NotImplementedError("This method must be implemented by subclasses.")


class BasicAuthBackend(AirflowAuthBackend):
def __init__(self, username: str, password: str):
self.username = username
self.password = password

def get_session(self) -> requests.Session:
session = requests.Session()
session.auth = (self.username, self.password)
return session


class AirflowInstance(NamedTuple):
airflow_webserver_url: str
auth_backend: AirflowAuthBackend
name: str

@property
def api_url(self) -> str:
return f"{self.airflow_webserver_url}/api/v1"

def list_dags(self) -> List[DagInfo]:
response = self.auth_backend.get_session().get(f"{self.api_url}/dags")
if response.status_code == 200:
dags = response.json()
return [
DagInfo(
dag_id=dag["dag_id"],
metadata=dag,
)
for dag in dags["dags"]
]
else:
raise Exception(
f"Failed to fetch DAGs. Status code: {response.status_code}, Message: {response.text}"
)

else:
raise Exception(
f"Failed to fetch DAGs. Status code: {response.status_code}, Message: {response.text}"
def get_task_info(self, dag_id: str, task_id: str) -> Dict[str, Any]:
response = self.auth_backend.get_session().get(
f"{self.api_url}/dags/{dag_id}/tasks/{task_id}"
)
if response.status_code == 200:
return response.json()
else:
raise Exception(
f"Failed to fetch task info for {dag_id}/{task_id}. Status code: {response.status_code}, Message: {response.text}"
)

asset_specs = []
def get_task_url(self, dag_id: str, task_id: str) -> str:
return f"{self.airflow_webserver_url}/dags/{dag_id}/{task_id}"

# All the assets which map to tasks within a given dag should be considered "upstream" of the dag.
dag_id_to_upstream_specs: Dict[str, List[AssetSpec]] = {}
def get_dag_url(self, dag_id: str) -> str:
return f"{self.airflow_webserver_url}/dags/{dag_id}"

for task in task_maps:
response = requests.get(f"{api_url}/dags/{task.dag_id}/tasks/{task.task_id}", auth=auth)
def get_dag_run_url(self, dag_id: str, run_id: str) -> str:
return f"{self.airflow_webserver_url}/dags/{dag_id}/grid?dag_run_id={run_id}&tab=details"

def get_dag_run_asset_key(self, dag_id: str) -> AssetKey:
return AssetKey([self.name, "dag", f"{dag_id}__successful_run"])

def get_dag_source_code(self, file_token: str) -> str:
response = self.auth_backend.get_session().get(f"{self.api_url}/dagSources/{file_token}")
if response.status_code == 200:
task_info = response.json()
joined_metadata = {
**task.metadata,
**{
"Task Info (raw)": JsonMetadataValue(task_info),
"Task ID": task.task_id,
"Dag ID": task.dag_id,
"Link to Task": MarkdownMetadataValue(
f"[View Task]({airflow_webserver_url}/dags/{task.dag_id}/{task.task_id})"
),
},
}
joined_tags = {
**task.tags,
**{"dagster/compute_kind": "airflow"},
}
asset_spec = AssetSpec(
deps=task.deps,
key=task.key,
description=f"A data asset materialized by task {task.task_id} within airflow dag {task.dag_id}.",
metadata=joined_metadata,
tags=joined_tags,
group_name=task.group,
return response.text
else:
raise Exception(
f"Failed to fetch source code for {file_token}. Status code: {response.status_code}, Message: {response.text}"
)
asset_specs.append(asset_spec)
dag_id_to_upstream_specs[task.dag_id] = dag_id_to_upstream_specs.get(
task.dag_id, []
) + [asset_spec]

@staticmethod
def airflow_str_from_datetime(dt: datetime.datetime) -> str:
return dt.strftime("%Y-%m-%dT%H:%M:%S+00:00")

def get_dag_runs(
self, dag_id: str, start_date: datetime.datetime, end_date: datetime.datetime
) -> List[Dict[str, Any]]:
response = self.auth_backend.get_session().get(
f"{self.api_url}/dags/{dag_id}/dagRuns",
params={
"updated_at_gte": self.airflow_str_from_datetime(start_date),
"updated_at_lte": self.airflow_str_from_datetime(end_date),
},
)
if response.status_code == 200:
return response.json()["dag_runs"]
else:
raise Exception(
f"Failed to fetch task info for {task.dag_id}/{task.task_id}. Status code: {response.status_code}, Message: {response.text}"
f"Failed to fetch dag runs for {dag_id}. Status code: {response.status_code}, Message: {response.text}"
)

@staticmethod
def timestamp_from_airflow_date(airflow_date: str) -> float:
try:
return datetime.datetime.strptime(airflow_date, "%Y-%m-%dT%H:%M:%S+00:00").timestamp()
except ValueError:
return datetime.datetime.strptime(
airflow_date, "%Y-%m-%dT%H:%M:%S.%f+00:00"
).timestamp()


def assets_defs_from_airflow_instance(
airflow_instance: AirflowInstance,
task_maps: Sequence[TaskMapping] = [],
) -> List[AssetsDefinition]:
dag_infos = airflow_instance.list_dags()

asset_specs = []

# All the assets which map to tasks within a given dag should be considered "upstream" of the dag.
dag_id_to_upstream_specs: Dict[str, List[AssetSpec]] = {}

for task in task_maps:
task_info = airflow_instance.get_task_info(task.dag_id, task.task_id)
joined_metadata = {
**task.metadata,
**{
"Task Info (raw)": JsonMetadataValue(task_info),
"Task ID": task.task_id,
"Dag ID": task.dag_id,
"Link to Task": MarkdownMetadataValue(
f"[View Task]({airflow_instance.get_task_url(task.dag_id, task.task_id)})"
),
},
}
joined_tags = {
**task.tags,
**{"dagster/compute_kind": "airflow"},
}
asset_spec = AssetSpec(
deps=task.deps,
key=task.key,
description=f"A data asset materialized by task {task.task_id} within airflow dag {task.dag_id}.",
metadata=joined_metadata,
tags=joined_tags,
group_name=task.group,
)
asset_specs.append(asset_spec)
dag_id_to_upstream_specs[task.dag_id] = dag_id_to_upstream_specs.get(task.dag_id, []) + [
asset_spec
]

dag_id_to_asset_key: Dict[str, AssetKey] = {}

for dag_info in dag_infos:
dag_id_to_asset_key[dag_info.dag_id] = AssetKey(
[instance_name, "dag", f"{dag_info.dag_id}__successful_run"]
dag_id_to_asset_key[dag_info.dag_id] = airflow_instance.get_dag_run_asset_key(
dag_info.dag_id
)
metadata = {
"Dag Info (raw)": JsonMetadataValue(dag_info.metadata),
"Dag ID": dag_info.dag_id,
"Link to DAG": MarkdownMetadataValue(
f"[View DAG]({airflow_webserver_url}/dags/{dag_info.dag_id})"
f"[View DAG]({airflow_instance.get_dag_url(dag_info.dag_id)})"
),
}
# Attempt to retrieve source code from the DAG.
file_token = dag_info.metadata["file_token"]
url = f"{api_url}/dagSources/{file_token}"
response = requests.get(url, auth=auth)
if response.status_code == 200:
metadata["Source Code"] = MarkdownMetadataValue(
f"""
metadata["Source Code"] = MarkdownMetadataValue(
f"""
```python
{response.text}
{airflow_instance.get_dag_source_code(dag_info.metadata["file_token"])}
```
"""
)
)
upstream_specs = dag_id_to_upstream_specs.get(dag_info.dag_id, [])
leaf_upstreams = get_leaf_specs(upstream_specs)
asset_specs.append(
@@ -143,7 +218,7 @@ def assets_defs_from_airflow_instance(
description=f"A materialization corresponds to a successful run of airflow DAG {dag_info.dag_id}.",
metadata=metadata,
tags={"dagster/compute_kind": "airflow"},
group_name=f"{instance_name}__dags",
group_name=f"{airflow_instance.name}__dags",
deps=[AssetDep(asset=spec.key) for spec in leaf_upstreams],
)
)
@@ -160,12 +235,9 @@ def _the_asset():


def build_airflow_polling_sensor(
airflow_webserver_url: str,
auth: Tuple[str, str],
airflow_instance: AirflowInstance,
airflow_asset_specs: List[AssetSpec],
) -> SensorDefinition:
api_url = f"{airflow_webserver_url}/api/v1"

@sensor(name="airflow_dag_status_sensor")
def airflow_dag_sensor(context: SensorEvaluationContext) -> SensorResult:
"""Sensor to report materialization events for each asset as new runs come in."""
@@ -183,17 +255,9 @@ def airflow_dag_sensor(context: SensorEvaluationContext) -> SensorResult:
specs_by_dag_id[dag_id] = []
specs_by_dag_id[dag_id].append(spec)
for dag, specs in specs_by_dag_id.items():
response = requests.get(
f"{api_url}/dags/{dag}/dagRuns",
auth=auth,
params={
"updated_at_gte": last_effective_date.strftime("%Y-%m-%dT%H:%M:%S+00:00"),
"updated_at_lte": current_date.strftime("%Y-%m-%dT%H:%M:%S+00:00"),
},
)
# For now, we materialize assets representing tasks only when the whole dag completes.
# Once we have a more robust cursor that can tell us whether we've already seen a particular task run, we can relax this constraint.
for dag_run in response.json()["dag_runs"]:
for dag_run in airflow_instance.get_dag_runs(dag, last_effective_date, current_date):
# If the dag run succeeded, add materializations for all assets referring to dags.
if dag_run["state"] != "success":
raise Exception("Should not have seen a non-successful dag run.")
@@ -208,15 +272,19 @@ def airflow_dag_sensor(context: SensorEvaluationContext) -> SensorResult:
"Airflow Run ID": dag_run["dag_run_id"],
"Run Metadata (raw)": JsonMetadataValue(dag_run),
"Start Date": TimestampMetadataValue(
timestamp_from_airflow_date(dag_run["start_date"])
airflow_instance.timestamp_from_airflow_date(
dag_run["start_date"]
)
),
"End Date": TimestampMetadataValue(
timestamp_from_airflow_date(dag_run["end_date"])
airflow_instance.timestamp_from_airflow_date(
dag_run["end_date"]
)
),
"Run Type": dag_run["run_type"],
"Airflow Config": JsonMetadataValue(dag_run["conf"]),
"Link to Run": MarkdownMetadataValue(
f"[View Run]({airflow_webserver_url}/dags/{dag_id}/grid?dag_run_id={dag_run['dag_run_id']}&tab=details)"
f"[View Run]({airflow_instance.get_dag_run_url(dag, dag_run['dag_run_id'])})"
),
"Creation Timestamp": TimestampMetadataValue(
get_current_timestamp()
@@ -234,13 +302,6 @@ def airflow_dag_sensor(context: SensorEvaluationContext) -> SensorResult:
return airflow_dag_sensor


def timestamp_from_airflow_date(airflow_date: str) -> float:
try:
return datetime.datetime.strptime(airflow_date, "%Y-%m-%dT%H:%M:%S+00:00").timestamp()
except ValueError:
return datetime.datetime.strptime(airflow_date, "%Y-%m-%dT%H:%M:%S.%f+00:00").timestamp()


def add_prefix_to_specs(prefix: Sequence[str], specs: Sequence[AssetSpec]) -> Sequence[AssetSpec]:
return [
AssetSpec(
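The AirflowAuthBackend abstract class above is the extension point for authentication schemes beyond basic auth. As a hypothetical illustration (not part of this commit), a bearer-token backend would follow the same shape as BasicAuthBackend:

```python
import requests

from dagster_airlift.airflow_utils import AirflowAuthBackend


class TokenAuthBackend(AirflowAuthBackend):
    """Hypothetical bearer-token backend; not introduced by this commit."""

    def __init__(self, token: str):
        self.token = token

    def get_session(self) -> requests.Session:
        # Attach the token to every request made against the Airflow REST API.
        session = requests.Session()
        session.headers.update({"Authorization": f"Bearer {self.token}"})
        return session
```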
@@ -12,6 +12,8 @@
from dagster._core.definitions.asset_key import AssetKey
from dagster._core.test_utils import instance_for_test
from dagster_airlift import (
AirflowInstance,
BasicAuthBackend,
TaskMapping,
assets_defs_from_airflow_instance,
build_airflow_polling_sensor,
@@ -22,10 +24,13 @@ def test_dag_peering(
airflow_instance: None,
) -> None:
"""Test that dags can be correctly peered from airflow, and certain metadata properties are retained."""
assets_defs = assets_defs_from_airflow_instance(
instance = AirflowInstance(
airflow_webserver_url="http://localhost:8080",
auth=("admin", "admin"),
instance_name="airflow_instance",
auth_backend=BasicAuthBackend(username="admin", password="admin"),
name="airflow_instance",
)
assets_defs = assets_defs_from_airflow_instance(
airflow_instance=instance,
task_maps=[
TaskMapping(
dag_id="print_dag",
@@ -65,8 +70,7 @@ def test_dag_peering(
assert task_spec.metadata["Task ID"] == "print_task"

sensor_def = build_airflow_polling_sensor(
airflow_webserver_url="http://localhost:8080",
auth=("admin", "admin"),
airflow_instance=instance,
airflow_asset_specs=[list(assets_def.specs)[0] for assets_def in assets_defs], # noqa
)

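Outside the test, the peered asset definitions and the polling sensor would typically be registered together in a Dagster code location. A sketch under the assumption that assets_defs and sensor_def are built as in the test above (this wiring is not part of the diff):

```python
from dagster import Definitions

# Hypothetical wiring: expose the Airflow-peered assets and the polling sensor.
defs = Definitions(assets=assets_defs, sensors=[sensor_def])
```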