Skip to content

Commit

Permalink
Merge branch 'master' into docs/revamp
Browse files Browse the repository at this point in the history
  • Loading branch information
PedramNavid committed Aug 9, 2024
2 parents 805b36a + 319959f commit 36da245
Show file tree
Hide file tree
Showing 22 changed files with 521 additions and 224 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ..migration_state import load_migration_state_from_yaml as load_migration_state_from_yaml
from .basic_auth import BasicAuthBackend as BasicAuthBackend
from .defs_from_airflow import (
AirflowInstance as AirflowInstance,
build_defs_from_airflow_instance as build_defs_from_airflow_instance,
)
from .migration_state import load_migration_state_from_yaml as load_migration_state_from_yaml
from .multi_asset import PythonDefs as PythonDefs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
unpack_value,
)

from dagster_airlift.core.migration_state import AirflowMigrationState
from dagster_airlift.migration_state import AirflowMigrationState

from .airflow_instance import AirflowInstance, DagInfo, TaskInfo
from .utils import get_dag_id_from_asset, get_task_id_from_asset
Expand Down Expand Up @@ -149,16 +149,14 @@ def compute_cacheable_data(self) -> Sequence[AssetsDefinitionCacheableData]:
dag_specs_per_key: Dict[AssetKey, CacheableAssetSpec] = {}
for dag in self.airflow_instance.list_dags():
source_code = self.airflow_instance.get_dag_source_code(dag.metadata["file_token"])
dag_specs_per_key[self.airflow_instance.get_dag_run_asset_key(dag.dag_id)] = (
get_cached_spec_for_dag(
airflow_instance=self.airflow_instance,
task_asset_keys_in_dag=cacheable_task_data.all_asset_keys_per_dag_id.get(
dag.dag_id, set()
),
downstreams_asset_dependency_graph=cacheable_task_data.downstreams_asset_dependency_graph,
dag_info=dag,
source_code=source_code,
)
dag_specs_per_key[dag.dag_asset_key] = get_cached_spec_for_dag(
airflow_instance=self.airflow_instance,
task_asset_keys_in_dag=cacheable_task_data.all_asset_keys_per_dag_id.get(
dag.dag_id, set()
),
downstreams_asset_dependency_graph=cacheable_task_data.downstreams_asset_dependency_graph,
dag_info=dag,
source_code=source_code,
)
return [
AssetsDefinitionCacheableData(
Expand Down Expand Up @@ -213,9 +211,7 @@ def get_cached_spec_for_dag(
metadata = {
"Dag Info (raw)": JsonMetadataValue(dag_info.metadata),
"Dag ID": dag_info.dag_id,
"Link to DAG": MarkdownMetadataValue(
f"[View DAG]({airflow_instance.get_dag_url(dag_info.dag_id)})"
),
"Link to DAG": MarkdownMetadataValue(f"[View DAG]({dag_info.url})"),
}
# Attempt to retrieve source code from the DAG.
metadata["Source Code"] = MarkdownMetadataValue(
Expand All @@ -227,7 +223,7 @@ def get_cached_spec_for_dag(
)

return CacheableAssetSpec(
asset_key=airflow_instance.get_dag_run_asset_key(dag_info.dag_id),
asset_key=dag_info.dag_asset_key,
description=f"A materialization corresponds to a successful run of airflow DAG {dag_info.dag_id}.",
metadata=metadata,
tags={"dagster/compute_kind": "airflow", DAG_ID_TAG: dag_info.dag_id},
Expand Down Expand Up @@ -277,9 +273,7 @@ def construct_cacheable_assets_and_infer_dependencies(
"Task Info (raw)": JsonMetadataValue(task_info.metadata),
# In this case,
"Dag ID": task_info.dag_id,
"Link to DAG": MarkdownMetadataValue(
f"[View DAG]({airflow_instance.get_dag_url(task_info.dag_id)})"
),
"Link to DAG": MarkdownMetadataValue(f"[View DAG]({task_info.dag_url})"),
}
migration_state_for_task = _get_migration_state_for_task(
migration_state, task_info.dag_id, task_info.task_id
Expand Down Expand Up @@ -408,24 +402,40 @@ def get_task_info_for_asset(
return airflow_instance.get_task_info(dag_id, task_id)


def list_intersection(list1, list2):
    """Return the elements common to *list1* and *list2*, as a list.

    Order is unspecified (set semantics) and duplicates are collapsed.
    """
    common = set(list1).intersection(list2)
    return list(common)


def get_leaf_assets_for_dag(
    asset_keys_in_dag: Set[AssetKey],
    downstreams_asset_dependency_graph: Dict[AssetKey, Set[AssetKey]],
) -> List[AssetKey]:
    """Return the asset keys that are "leaves" within the dag.

    An asset is a "leaf" for the dag if it has no transitive dependencies
    _within_ the dag. It may have dependencies _outside_ the dag.

    Args:
        asset_keys_in_dag: All asset keys mapped to tasks of this dag.
        downstreams_asset_dependency_graph: Maps an asset key to the set of
            asset keys directly downstream of it.
    """
    leaf_assets = []
    # Memoizes transitive-dependency sets across the per-asset calls below.
    cache: Dict[AssetKey, Set[AssetKey]] = {}
    for asset_key in asset_keys_in_dag:
        transitive_deps = get_transitive_dependencies_for_asset(
            asset_key, downstreams_asset_dependency_graph, cache
        )
        # Leaf: none of its transitive downstream deps land inside the dag.
        if not transitive_deps.intersection(asset_keys_in_dag):
            leaf_assets.append(asset_key)
    return leaf_assets


def get_transitive_dependencies_for_asset(
    asset_key: AssetKey,
    downstreams_asset_dependency_graph: Dict[AssetKey, Set[AssetKey]],
    cache: Dict[AssetKey, Set[AssetKey]],
) -> Set[AssetKey]:
    """Return every asset key transitively downstream of *asset_key*.

    Results are memoized into *cache* (keyed by asset key) so repeated calls
    over the same graph share work.

    NOTE(review): recursion assumes the dependency graph is acyclic — a cycle
    would recurse forever; confirm upstream construction guarantees this.
    """
    if asset_key in cache:
        return cache[asset_key]
    transitive_deps: Set[AssetKey] = set()
    # .get() so assets with no recorded downstreams don't raise KeyError.
    for dep in downstreams_asset_dependency_graph.get(asset_key, set()):
        transitive_deps.add(dep)
        transitive_deps.update(
            get_transitive_dependencies_for_asset(dep, downstreams_asset_dependency_graph, cache)
        )
    cache[asset_key] = transitive_deps
    return transitive_deps


def _get_migration_state_for_task(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
from typing import Any, Dict, List

import requests
from dagster import AssetKey
from dagster._core.definitions.asset_key import AssetKey
from dagster._core.errors import DagsterError
from dagster._record import record
from dagster._time import get_current_datetime

TERMINAL_STATES = {"success", "failed", "skipped", "up_for_retry", "up_for_reschedule"}


class AirflowAuthBackend(ABC):
Expand All @@ -32,8 +35,10 @@ def list_dags(self) -> List["DagInfo"]:
response = self.auth_backend.get_session().get(f"{self.get_api_url()}/dags")
if response.status_code == 200:
dags = response.json()
webserver_url = self.auth_backend.get_webserver_url()
return [
DagInfo(
webserver_url=webserver_url,
dag_id=dag["dag_id"],
metadata=dag,
)
Expand All @@ -44,12 +49,30 @@ def list_dags(self) -> List["DagInfo"]:
f"Failed to fetch DAGs. Status code: {response.status_code}, Message: {response.text}"
)

def get_task_instance(self, dag_id: str, task_id: str, run_id: str) -> "TaskInstance":
    """Fetch a single task instance of a dag run from the Airflow REST API.

    Raises:
        DagsterError: if the API responds with anything other than 200.
    """
    session = self.auth_backend.get_session()
    response = session.get(
        f"{self.get_api_url()}/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}"
    )
    # Fail fast on any non-200 so callers never see a half-built record.
    if response.status_code != 200:
        raise DagsterError(
            f"Failed to fetch task instance for {dag_id}/{task_id}/{run_id}. Status code: {response.status_code}, Message: {response.text}"
        )
    return TaskInstance(
        webserver_url=self.auth_backend.get_webserver_url(),
        dag_id=dag_id,
        task_id=task_id,
        run_id=run_id,
        metadata=response.json(),
    )

def get_task_info(self, dag_id: str, task_id: str) -> "TaskInfo":
    """Fetch static metadata for one task of a DAG from the Airflow REST API.

    Raises:
        DagsterError: if the API responds with anything other than 200.
    """
    response = self.auth_backend.get_session().get(
        f"{self.get_api_url()}/dags/{dag_id}/tasks/{task_id}"
    )
    if response.status_code == 200:
        return TaskInfo(
            webserver_url=self.auth_backend.get_webserver_url(),
            dag_id=dag_id,
            task_id=task_id,
            metadata=response.json(),
        )
    else:
        # Reconstructed error path (the source view truncated these lines);
        # mirrors the sibling get_task_instance handler.
        raise DagsterError(
            f"Failed to fetch task info for {dag_id}/{task_id}. Status code: {response.status_code}, Message: {response.text}"
        )

def get_dag_url(self, dag_id: str) -> str:
    """Return the webserver URL of the DAG's overview page."""
    base_url = self.auth_backend.get_webserver_url()
    return f"{base_url}/dags/{dag_id}"

def get_dag_run_url(self, dag_id: str, run_id: str) -> str:
    """Return the grid-view URL for one dag run, details tab selected."""
    base_url = self.auth_backend.get_webserver_url()
    return f"{base_url}/dags/{dag_id}/grid?dag_run_id={run_id}&tab=details"

def get_task_instance_url(self, dag_id: str, task_id: str, run_id: str) -> str:
    """Return the grid-view URL focused on one task instance of a dag run.

    e.g. http://localhost:8080/dags/print_dag/grid?dag_run_id=manual__2024-08-08T17%3A21%3A22.427241%2B00%3A00&task_id=print_task
    """
    base_url = self.auth_backend.get_webserver_url()
    return f"{base_url}/dags/{dag_id}/grid?dag_run_id={run_id}&task_id={task_id}"

def get_task_instance_log_url(self, dag_id: str, task_id: str, run_id: str) -> str:
    """Return the task-instance grid URL with the logs tab selected."""
    base = self.get_task_instance_url(dag_id, task_id, run_id)
    return f"{base}&tab=logs"

def get_dag_run_asset_key(self, dag_id: str) -> AssetKey:
    # Conventional asset key for a run of the given dag, namespaced under this
    # instance's normalized name.
    return AssetKey([self.normalized_name, "dag", dag_id])

def get_dag_source_code(self, file_token: str) -> str:
response = self.auth_backend.get_session().get(
f"{self.get_api_url()}/dagSources/{file_token}"
Expand All @@ -93,7 +99,7 @@ def airflow_str_from_datetime(dt: datetime.datetime) -> str:

def get_dag_runs(
self, dag_id: str, start_date: datetime.datetime, end_date: datetime.datetime
) -> List[Dict[str, Any]]:
) -> List["DagRun"]:
response = self.auth_backend.get_session().get(
f"{self.get_api_url()}/dags/{dag_id}/dagRuns",
params={
Expand All @@ -103,12 +109,55 @@ def get_dag_runs(
},
)
if response.status_code == 200:
return response.json()["dag_runs"]
webserver_url = self.auth_backend.get_webserver_url()
return [
DagRun(
webserver_url=webserver_url,
dag_id=dag_id,
run_id=dag_run["dag_run_id"],
metadata=dag_run,
)
for dag_run in response.json()["dag_runs"]
]
else:
raise DagsterError(
f"Failed to fetch dag runs for {dag_id}. Status code: {response.status_code}, Message: {response.text}"
)

def trigger_dag(self, dag_id: str) -> str:
    """Kick off a new run of *dag_id* via the Airflow REST API.

    Returns:
        The run id of the newly created dag run.

    Raises:
        DagsterError: if the API responds with anything other than 200.
    """
    response = self.auth_backend.get_session().post(
        f"{self.get_api_url()}/dags/{dag_id}/dagRuns",
        json={},
    )
    if response.status_code == 200:
        return response.json()["dag_run_id"]
    raise DagsterError(
        f"Failed to launch run for {dag_id}. Status code: {response.status_code}, Message: {response.text}"
    )

def get_dag_run(self, dag_id: str, run_id: str) -> "DagRun":
    """Fetch one dag run's metadata from the Airflow REST API.

    Raises:
        DagsterError: if the API responds with anything other than 200.
    """
    session = self.auth_backend.get_session()
    response = session.get(f"{self.get_api_url()}/dags/{dag_id}/dagRuns/{run_id}")
    if response.status_code != 200:
        raise DagsterError(
            f"Failed to fetch dag run for {dag_id}/{run_id}. Status code: {response.status_code}, Message: {response.text}"
        )
    return DagRun(
        webserver_url=self.auth_backend.get_webserver_url(),
        dag_id=dag_id,
        run_id=run_id,
        metadata=response.json(),
    )

def wait_for_run_completion(self, dag_id: str, run_id: str, timeout: int = 30) -> None:
    """Poll the Airflow API until run *run_id* of *dag_id* reaches a terminal state.

    Args:
        dag_id: The dag whose run to wait on.
        run_id: The run to wait on.
        timeout: Maximum number of seconds to wait.

    Raises:
        DagsterError: if the run has not finished within *timeout* seconds.
    """
    import time  # local import so this fix is self-contained

    start_time = get_current_datetime()
    while get_current_datetime() - start_time < datetime.timedelta(seconds=timeout):
        dag_run = self.get_dag_run(dag_id, run_id)
        if dag_run.finished:
            return
        # The original loop spun without pausing, hammering the API between
        # polls; back off briefly instead.
        time.sleep(0.5)
    raise DagsterError(f"Timed out waiting for airflow run {run_id} to finish.")

@staticmethod
def timestamp_from_airflow_date(airflow_date: str) -> float:
try:
Expand All @@ -121,12 +170,96 @@ def timestamp_from_airflow_date(airflow_date: str) -> float:

@record
class DagInfo:
    """Identifying info plus raw REST-API metadata for one Airflow DAG."""

    # Base URL of the Airflow webserver this DAG was fetched from.
    webserver_url: str
    dag_id: str
    # Raw JSON payload returned by the Airflow REST API for this DAG.
    metadata: Dict[str, Any]

    @property
    def url(self) -> str:
        """Webserver URL of this DAG's page."""
        return f"{self.webserver_url}/dags/{self.dag_id}"

    @property
    def dag_asset_key(self) -> AssetKey:
        # Conventional asset key representing a successful run of an airflow dag.
        return AssetKey(["airflow_instance", "dag", self.dag_id])


@record
class TaskInfo:
    """Identifying info plus raw REST-API metadata for one task of a DAG."""

    # Base URL of the Airflow webserver this task was fetched from.
    webserver_url: str
    dag_id: str
    task_id: str
    # Raw JSON payload returned by the Airflow REST API for this task.
    metadata: Dict[str, Any]

    @property
    def dag_url(self) -> str:
        """Webserver URL of the parent DAG's page."""
        return f"{self.webserver_url}/dags/{self.dag_id}"


@record
class TaskInstance:
    """One execution of a task within a dag run, plus its raw API metadata."""

    # Base URL of the Airflow webserver this instance was fetched from.
    webserver_url: str
    dag_id: str
    task_id: str
    run_id: str
    # Raw JSON payload returned by the Airflow REST API for this task instance.
    metadata: Dict[str, Any]

    @property
    def note(self) -> str:
        """User-entered note on the task instance, or "" when absent/None."""
        return self.metadata.get("note") or ""

    @property
    def details_url(self) -> str:
        """Grid-view URL focused on this task instance."""
        return f"{self.webserver_url}/dags/{self.dag_id}/grid?dag_run_id={self.run_id}&task_id={self.task_id}"

    @property
    def log_url(self) -> str:
        """Details URL with the logs tab selected."""
        return f"{self.details_url}&tab=logs"

    @property
    def start_date(self) -> float:
        """Start time as a unix timestamp parsed from the Airflow date string."""
        return AirflowInstance.timestamp_from_airflow_date(self.metadata["start_date"])

    @property
    def end_date(self) -> float:
        # NOTE(review): KeyErrors if the instance has not finished ("end_date"
        # missing/None) — confirm callers only use this on completed instances.
        return AirflowInstance.timestamp_from_airflow_date(self.metadata["end_date"])


@record
class DagRun:
    """One run of a DAG, plus its raw REST-API metadata."""

    # Base URL of the Airflow webserver this run was fetched from.
    webserver_url: str
    dag_id: str
    run_id: str
    # Raw JSON payload returned by the Airflow REST API for this dag run.
    metadata: Dict[str, Any]

    @property
    def note(self) -> str:
        """User-entered note on the run, or "" when absent/None."""
        return self.metadata.get("note") or ""

    @property
    def url(self) -> str:
        """Grid-view URL for this run, details tab selected."""
        return f"{self.webserver_url}/dags/{self.dag_id}/grid?dag_run_id={self.run_id}&tab=details"

    @property
    def success(self) -> bool:
        """True when the run's state is exactly "success"."""
        return self.metadata["state"] == "success"

    @property
    def finished(self) -> bool:
        """True when the run's state is in TERMINAL_STATES."""
        return self.metadata["state"] in TERMINAL_STATES

    @property
    def run_type(self) -> str:
        # e.g. manual vs scheduled, as reported by the API's "run_type" field.
        return self.metadata["run_type"]

    @property
    def config(self) -> Dict[str, Any]:
        """The run's configuration dict (the API's "conf" field)."""
        return self.metadata["conf"]

    @property
    def start_date(self) -> float:
        """Start time as a unix timestamp parsed from the Airflow date string."""
        return AirflowInstance.timestamp_from_airflow_date(self.metadata["start_date"])

    @property
    def end_date(self) -> float:
        # NOTE(review): KeyErrors if the run has not finished ("end_date"
        # missing/None) — confirm callers only use this on completed runs.
        return AirflowInstance.timestamp_from_airflow_date(self.metadata["end_date"])
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from dagster_airlift.core.sensor import build_airflow_polling_sensor

from ..migration_state import AirflowMigrationState
from .airflow_cacheable_assets_def import DEFAULT_POLL_INTERVAL, AirflowCacheableAssetsDefinition
from .airflow_instance import AirflowInstance
from .migration_state import AirflowMigrationState


def build_defs_from_airflow_instance(
Expand Down
Loading

0 comments on commit 36da245

Please sign in to comment.