
[dagster-airlift] AirflowInstance #23343

Merged · 1 commit · Aug 4, 2024
examples/experimental/dagster-airlift/dagster_airlift/__init__.py
@@ -1,4 +1,6 @@
from .airflow_utils import (
AirflowInstance as AirflowInstance,
BasicAuthBackend as BasicAuthBackend,
TaskMapping as TaskMapping,
airflow_task_mappings_from_dbt_project as airflow_task_mappings_from_dbt_project,
assets_defs_from_airflow_instance as assets_defs_from_airflow_instance,
235 changes: 148 additions & 87 deletions examples/experimental/dagster-airlift/dagster_airlift/airflow_utils.py
@@ -1,7 +1,8 @@
import datetime
import json
from abc import ABC
from datetime import timedelta
from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence, Tuple
from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence

import requests
from dagster import (
@@ -43,98 +44,172 @@ class TaskMapping(NamedTuple):
group: Optional[str] = None


def assets_defs_from_airflow_instance(
airflow_webserver_url: str,
auth: Tuple[str, str],
instance_name: str,
task_maps: Sequence[TaskMapping] = [],
) -> List[AssetsDefinition]:
api_url = f"{airflow_webserver_url}/api/v1"
dag_infos: List[DagInfo] = []

# First, we attempt to fetch all the DAGs present in the Airflow instance.
response = requests.get(f"{api_url}/dags", auth=auth)
if response.status_code == 200:
dags = response.json()
for dag in dags["dags"]:
dag_infos.append(
class AirflowAuthBackend(ABC):
def get_session(self) -> requests.Session:
raise NotImplementedError("This method must be implemented by subclasses.")


class BasicAuthBackend(AirflowAuthBackend):
def __init__(self, username: str, password: str):
self.username = username
self.password = password

def get_session(self) -> requests.Session:
session = requests.Session()
session.auth = (self.username, self.password)
return session


class AirflowInstance(NamedTuple):
Member:
seems a little goofy to have this be a NamedTuple but nbd

Contributor Author:
what's goofy about it being a namedtuple? I guess just vs a regular class you mean?

Honestly it's just learned behavior from the codebase at this point

Member:

Generally, NamedTuples are "dumb" value objects that represent a value/record at a point in time. This is a more complex object that will likely accumulate state, require caching, that sort of thing.
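
To make the distinction concrete, here is a minimal sketch of the two shapes being contrasted; `DagInfoRecord` and `AirflowClient` are hypothetical names for illustration, not part of this PR:

```python
from typing import NamedTuple, Optional

import requests


class DagInfoRecord(NamedTuple):
    # A "dumb" value object: immutable, hashable, compared field-by-field.
    dag_id: str
    is_paused: bool


class AirflowClient:
    # A service object: owns a session and lazily builds and caches state.
    def __init__(self, webserver_url: str) -> None:
        self.webserver_url = webserver_url
        self._session: Optional[requests.Session] = None

    @property
    def session(self) -> requests.Session:
        if self._session is None:  # created once, reused across calls
            self._session = requests.Session()
        return self._session
```
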

airflow_webserver_url: str
auth_backend: AirflowAuthBackend
name: str

@property
def api_url(self) -> str:
return f"{self.airflow_webserver_url}/api/v1"

def list_dags(self) -> List[DagInfo]:
response = self.auth_backend.get_session().get(f"{self.api_url}/dags")
if response.status_code == 200:
dags = response.json()
return [
DagInfo(
dag_id=dag["dag_id"],
metadata=dag,
)
for dag in dags["dags"]
]
else:
raise Exception(
f"Failed to fetch DAGs. Status code: {response.status_code}, Message: {response.text}"
)

else:
raise Exception(
f"Failed to fetch DAGs. Status code: {response.status_code}, Message: {response.text}"
def get_task_info(self, dag_id: str, task_id: str) -> Dict[str, Any]:
response = self.auth_backend.get_session().get(
f"{self.api_url}/dags/{dag_id}/tasks/{task_id}"
)
if response.status_code == 200:
return response.json()
else:
raise Exception(
f"Failed to fetch task info for {dag_id}/{task_id}. Status code: {response.status_code}, Message: {response.text}"
)

asset_specs = []
def get_task_url(self, dag_id: str, task_id: str) -> str:
return f"{self.airflow_webserver_url}/dags/{dag_id}/{task_id}"

# All the assets which map to tasks within a given dag should be considered "upstream" of the dag.
dag_id_to_upstream_specs: Dict[str, List[AssetSpec]] = {}
def get_dag_url(self, dag_id: str) -> str:
return f"{self.airflow_webserver_url}/dags/{dag_id}"

for task in task_maps:
response = requests.get(f"{api_url}/dags/{task.dag_id}/tasks/{task.task_id}", auth=auth)
def get_dag_run_url(self, dag_id: str, run_id: str) -> str:
return f"{self.airflow_webserver_url}/dags/{dag_id}/grid?dag_run_id={run_id}&tab=details"

def get_dag_run_asset_key(self, dag_id: str) -> AssetKey:
return AssetKey([self.name, "dag", f"{dag_id}__successful_run"])

def get_dag_source_code(self, file_token: str) -> str:
response = self.auth_backend.get_session().get(f"{self.api_url}/dagSources/{file_token}")
if response.status_code == 200:
task_info = response.json()
joined_metadata = {
**task.metadata,
**{
"Task Info (raw)": JsonMetadataValue(task_info),
"Task ID": task.task_id,
"Dag ID": task.dag_id,
"Link to Task": MarkdownMetadataValue(
f"[View Task]({airflow_webserver_url}/dags/{task.dag_id}/{task.task_id})"
),
},
}
joined_tags = {
**task.tags,
**{"dagster/compute_kind": "airflow"},
}
asset_spec = AssetSpec(
deps=task.deps,
key=task.key,
description=f"A data asset materialized by task {task.task_id} within airflow dag {task.dag_id}.",
metadata=joined_metadata,
tags=joined_tags,
group_name=task.group,
return response.text
else:
raise Exception(
f"Failed to fetch source code for {file_token}. Status code: {response.status_code}, Message: {response.text}"
)
asset_specs.append(asset_spec)
dag_id_to_upstream_specs[task.dag_id] = dag_id_to_upstream_specs.get(
task.dag_id, []
) + [asset_spec]

@staticmethod
def airflow_str_from_datetime(dt: datetime.datetime) -> str:
return dt.strftime("%Y-%m-%dT%H:%M:%S+00:00")

def get_dag_runs(
self, dag_id: str, start_date: datetime.datetime, end_date: datetime.datetime
) -> List[Dict[str, Any]]:
response = self.auth_backend.get_session().get(
f"{self.api_url}/dags/{dag_id}/dagRuns",
params={
"updated_at_gte": self.airflow_str_from_datetime(start_date),
"updated_at_lte": self.airflow_str_from_datetime(end_date),
},
)
if response.status_code == 200:
return response.json()["dag_runs"]
else:
raise Exception(
f"Failed to fetch task info for {task.dag_id}/{task.task_id}. Status code: {response.status_code}, Message: {response.text}"
f"Failed to fetch dag runs for {dag_id}. Status code: {response.status_code}, Message: {response.text}"
)

@staticmethod
def timestamp_from_airflow_date(airflow_date: str) -> float:
try:
return datetime.datetime.strptime(airflow_date, "%Y-%m-%dT%H:%M:%S+00:00").timestamp()
except ValueError:
return datetime.datetime.strptime(
airflow_date, "%Y-%m-%dT%H:%M:%S.%f+00:00"
).timestamp()


def assets_defs_from_airflow_instance(
airflow_instance: AirflowInstance,
task_maps: Sequence[TaskMapping] = [],
) -> List[AssetsDefinition]:
dag_infos = airflow_instance.list_dags()

asset_specs = []

# All the assets which map to tasks within a given dag should be considered "upstream" of the dag.
dag_id_to_upstream_specs: Dict[str, List[AssetSpec]] = {}

for task in task_maps:
task_info = airflow_instance.get_task_info(task.dag_id, task.task_id)
joined_metadata = {
**task.metadata,
**{
"Task Info (raw)": JsonMetadataValue(task_info),
"Task ID": task.task_id,
"Dag ID": task.dag_id,
"Link to Task": MarkdownMetadataValue(
f"[View Task]({airflow_instance.get_task_url(task.dag_id, task.task_id)})"
),
},
}
joined_tags = {
**task.tags,
**{"dagster/compute_kind": "airflow"},
}
asset_spec = AssetSpec(
deps=task.deps,
key=task.key,
description=f"A data asset materialized by task {task.task_id} within airflow dag {task.dag_id}.",
metadata=joined_metadata,
tags=joined_tags,
group_name=task.group,
)
asset_specs.append(asset_spec)
dag_id_to_upstream_specs[task.dag_id] = dag_id_to_upstream_specs.get(task.dag_id, []) + [
asset_spec
]

dag_id_to_asset_key: Dict[str, AssetKey] = {}

for dag_info in dag_infos:
dag_id_to_asset_key[dag_info.dag_id] = AssetKey(
[instance_name, "dag", f"{dag_info.dag_id}__successful_run"]
dag_id_to_asset_key[dag_info.dag_id] = airflow_instance.get_dag_run_asset_key(
dag_info.dag_id
)
metadata = {
"Dag Info (raw)": JsonMetadataValue(dag_info.metadata),
"Dag ID": dag_info.dag_id,
"Link to DAG": MarkdownMetadataValue(
f"[View DAG]({airflow_webserver_url}/dags/{dag_info.dag_id})"
f"[View DAG]({airflow_instance.get_dag_url(dag_info.dag_id)})"
),
}
# Attempt to retrieve source code from the DAG.
file_token = dag_info.metadata["file_token"]
url = f"{api_url}/dagSources/{file_token}"
response = requests.get(url, auth=auth)
if response.status_code == 200:
metadata["Source Code"] = MarkdownMetadataValue(
f"""
metadata["Source Code"] = MarkdownMetadataValue(
f"""
```python
{response.text}
{airflow_instance.get_dag_source_code(dag_info.metadata["file_token"])}
```
"""
)
)
upstream_specs = dag_id_to_upstream_specs.get(dag_info.dag_id, [])
leaf_upstreams = get_leaf_specs(upstream_specs)
asset_specs.append(
@@ -143,7 +218,7 @@ def assets_defs_from_airflow_instance(
description=f"A materialization corresponds to a successful run of airflow DAG {dag_info.dag_id}.",
metadata=metadata,
tags={"dagster/compute_kind": "airflow"},
group_name=f"{instance_name}__dags",
group_name=f"{airflow_instance.name}__dags",
deps=[AssetDep(asset=spec.key) for spec in leaf_upstreams],
)
)
@@ -160,12 +235,9 @@ def _the_asset():


def build_airflow_polling_sensor(
airflow_webserver_url: str,
auth: Tuple[str, str],
airflow_instance: AirflowInstance,
airflow_asset_specs: List[AssetSpec],
) -> SensorDefinition:
api_url = f"{airflow_webserver_url}/api/v1"

@sensor(name="airflow_dag_status_sensor")
def airflow_dag_sensor(context: SensorEvaluationContext) -> SensorResult:
"""Sensor to report materialization events for each asset as new runs come in."""
@@ -183,17 +255,9 @@ def airflow_dag_sensor(context: SensorEvaluationContext) -> SensorResult:
specs_by_dag_id[dag_id] = []
specs_by_dag_id[dag_id].append(spec)
for dag, specs in specs_by_dag_id.items():
response = requests.get(
f"{api_url}/dags/{dag}/dagRuns",
auth=auth,
params={
"updated_at_gte": last_effective_date.strftime("%Y-%m-%dT%H:%M:%S+00:00"),
"updated_at_lte": current_date.strftime("%Y-%m-%dT%H:%M:%S+00:00"),
},
)
# For now, we materialize assets representing tasks only when the whole dag completes.
# With a more robust cursor that can let us know when we've seen a particular task run already, then we can relax this constraint.
for dag_run in response.json()["dag_runs"]:
for dag_run in airflow_instance.get_dag_runs(dag, last_effective_date, current_date):
# If the dag run succeeded, add materializations for all assets referring to dags.
if dag_run["state"] != "success":
raise Exception("Should not have seen a non-successful dag run.")
@@ -208,15 +272,19 @@ def airflow_dag_sensor(context: SensorEvaluationContext) -> SensorResult:
"Airflow Run ID": dag_run["dag_run_id"],
"Run Metadata (raw)": JsonMetadataValue(dag_run),
"Start Date": TimestampMetadataValue(
timestamp_from_airflow_date(dag_run["start_date"])
airflow_instance.timestamp_from_airflow_date(
dag_run["start_date"]
)
),
"End Date": TimestampMetadataValue(
timestamp_from_airflow_date(dag_run["end_date"])
airflow_instance.timestamp_from_airflow_date(
dag_run["end_date"]
)
),
"Run Type": dag_run["run_type"],
"Airflow Config": JsonMetadataValue(dag_run["conf"]),
"Link to Run": MarkdownMetadataValue(
f"[View Run]({airflow_webserver_url}/dags/{dag_id}/grid?dag_run_id={dag_run['dag_run_id']}&tab=details)"
f"[View Run]({airflow_instance.get_dag_run_url(dag, dag_run['dag_run_id'])})"
),
"Creation Timestamp": TimestampMetadataValue(
get_current_timestamp()
@@ -234,13 +302,6 @@ def airflow_dag_sensor(context: SensorEvaluationContext) -> SensorResult:
return airflow_dag_sensor


def timestamp_from_airflow_date(airflow_date: str) -> float:
try:
return datetime.datetime.strptime(airflow_date, "%Y-%m-%dT%H:%M:%S+00:00").timestamp()
except ValueError:
return datetime.datetime.strptime(airflow_date, "%Y-%m-%dT%H:%M:%S.%f+00:00").timestamp()


def add_prefix_to_specs(prefix: Sequence[str], specs: Sequence[AssetSpec]) -> Sequence[AssetSpec]:
return [
AssetSpec(
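
The `AirflowAuthBackend` ABC above requires only `get_session`, so auth schemes beyond basic auth can be plugged in without touching `AirflowInstance`. A minimal sketch, assuming `AirflowAuthBackend` is importable from `dagster_airlift.airflow_utils` as defined in this diff; the bearer-token backend itself is hypothetical, not part of this PR:

```python
import requests

from dagster_airlift import AirflowInstance
from dagster_airlift.airflow_utils import AirflowAuthBackend


class BearerTokenAuthBackend(AirflowAuthBackend):
    """Hypothetical backend: attaches a bearer token to every request."""

    def __init__(self, token: str) -> None:
        self.token = token

    def get_session(self) -> requests.Session:
        # Satisfies the single method the ABC requires.
        session = requests.Session()
        session.headers["Authorization"] = f"Bearer {self.token}"
        return session


instance = AirflowInstance(
    airflow_webserver_url="http://localhost:8080",
    auth_backend=BearerTokenAuthBackend(token="my-api-token"),  # hypothetical token
    name="airflow_instance",
)
```
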
@@ -12,6 +12,8 @@
from dagster._core.definitions.asset_key import AssetKey
from dagster._core.test_utils import instance_for_test
from dagster_airlift import (
AirflowInstance,
BasicAuthBackend,
TaskMapping,
assets_defs_from_airflow_instance,
build_airflow_polling_sensor,
@@ -22,10 +24,13 @@ def test_dag_peering(
airflow_instance: None,
) -> None:
"""Test that dags can be correctly peered from airflow, and certain metadata properties are retained."""
assets_defs = assets_defs_from_airflow_instance(
instance = AirflowInstance(
airflow_webserver_url="http://localhost:8080",
auth=("admin", "admin"),
instance_name="airflow_instance",
auth_backend=BasicAuthBackend(username="admin", password="admin"),
name="airflow_instance",
)
assets_defs = assets_defs_from_airflow_instance(
airflow_instance=instance,
task_maps=[
TaskMapping(
dag_id="print_dag",
@@ -65,8 +70,7 @@ def test_dag_peering(
assert task_spec.metadata["Task ID"] == "print_task"

sensor_def = build_airflow_polling_sensor(
airflow_webserver_url="http://localhost:8080",
auth=("admin", "admin"),
airflow_instance=instance,
airflow_asset_specs=[list(assets_def.specs)[0] for assets_def in assets_defs], # noqa
)

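
Outside of tests, wiring the new API together might look like the following sketch. The `Definitions` wiring is an assumption (this PR's tests stop at building the sensor); the `AirflowInstance`, `assets_defs_from_airflow_instance`, and `build_airflow_polling_sensor` calls mirror the test above:

```python
from dagster import Definitions
from dagster_airlift import (
    AirflowInstance,
    BasicAuthBackend,
    assets_defs_from_airflow_instance,
    build_airflow_polling_sensor,
)

# One instance object carries the webserver URL, auth, and instance name.
instance = AirflowInstance(
    airflow_webserver_url="http://localhost:8080",
    auth_backend=BasicAuthBackend(username="admin", password="admin"),
    name="airflow_instance",
)

# One assets definition per peered dag (plus any mapped tasks).
assets_defs = assets_defs_from_airflow_instance(airflow_instance=instance)

# The sensor polls Airflow for completed dag runs and reports materializations.
dag_sensor = build_airflow_polling_sensor(
    airflow_instance=instance,
    airflow_asset_specs=[
        spec for assets_def in assets_defs for spec in assets_def.specs
    ],
)

defs = Definitions(assets=assets_defs, sensors=[dag_sensor])
```
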