[dagster-airlift] sensor revamp (#24333)
- Batch requests to the Airflow REST API.
- Move to an iteration model that can pause partway through processing
  (essentially after handling any individual dag run) and resume gracefully
  on the next evaluation.
- Handle skipped tasks gracefully.
- Cap each iteration at 40 seconds to leave adequate time for sensor startup
  and teardown.

This makes sensor iteration nearly 10x faster, even locally where IPC cost
is low. See
https://gist.github.com/dpeng817/7979124e3286b1473aa57f252817b488
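As a rough illustration of the pause-and-resume model described above, the iteration loop amounts to processing one dag run at a time against a fixed time budget and yielding a cursor. This is a minimal sketch with hypothetical names, not the actual sensor code:

from datetime import datetime, timedelta
from typing import Callable, Iterator, Optional, Sequence

# Hypothetical sketch of the iteration model; names and signature are assumptions.
ITERATION_TIME_BUDGET_SECONDS = 40  # the per-evaluation budget described above


def process_dag_runs(
    dag_run_ids: Sequence[str],
    start_after: Optional[str],
    process_one: Callable[[str], None],
) -> Iterator[str]:
    """Process dag runs one at a time, yielding each processed run id as a cursor."""
    deadline = datetime.now() + timedelta(seconds=ITERATION_TIME_BUDGET_SECONDS)
    resumed = start_after is None
    for run_id in dag_run_ids:
        if not resumed:
            # Skip runs already handled in a previous sensor evaluation.
            resumed = run_id == start_after
            continue
        process_one(run_id)
        yield run_id
        if datetime.now() >= deadline:
            # Budget exhausted: stop here so startup/teardown still fit in the
            # tick; the caller persists the last yielded run id and resumes later.
            return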
dpeng817 authored Sep 11, 2024
1 parent 3fa4cf6 commit 718c516
Showing 6 changed files with 472 additions and 91 deletions.
@@ -2,7 +2,7 @@
import json
from abc import ABC
from functools import cached_property
from typing import Any, Dict, List
from typing import Any, Dict, List, Sequence

import requests
from dagster._core.definitions.asset_key import AssetKey
@@ -15,6 +15,14 @@
from .utils import convert_to_valid_dagster_name

TERMINAL_STATES = {"success", "failed", "skipped", "up_for_retry", "up_for_reschedule"}
# Limits the number of task ids we attempt to query from Airflow's task instance REST API at a time.
# Airflow's batch task instance retrieval API has no limit parameter, but since we query a single run at a
# time, we expect a single task instance per task id.
# Airflow task instance batch API: https://airflow.apache.org/docs/apache-airflow/stable/stable-rest-api-ref.html#operation/get_task_instances_batch
DEFAULT_BATCH_TASK_RETRIEVAL_LIMIT = 100
# Corresponds directly to the page_limit parameter on Airflow's batch dag runs REST API.
# Airflow dag run batch API: https://airflow.apache.org/docs/apache-airflow/stable/stable-rest-api-ref.html#operation/get_dag_runs_batch
DEFAULT_BATCH_DAG_RUNS_LIMIT = 100


class AirflowAuthBackend(ABC):
@@ -26,9 +34,17 @@ def get_webserver_url(self) -> str:


class AirflowInstance:
    def __init__(self, auth_backend: AirflowAuthBackend, name: str) -> None:
    def __init__(
        self,
        auth_backend: AirflowAuthBackend,
        name: str,
        batch_task_instance_limit: int = DEFAULT_BATCH_TASK_RETRIEVAL_LIMIT,
        batch_dag_runs_limit: int = DEFAULT_BATCH_DAG_RUNS_LIMIT,
    ) -> None:
        self.auth_backend = auth_backend
        self.name = name
        self.batch_task_instance_limit = batch_task_instance_limit
        self.batch_dag_runs_limit = batch_dag_runs_limit

    @property
    def normalized_name(self) -> str:
@@ -74,6 +90,43 @@ def get_migration_state(self) -> AirflowMigrationState:
            dag_dict[dag_id] = DagMigrationState.from_dict(migration_dict)
        return AirflowMigrationState(dags=dag_dict)

    def get_task_instance_batch(
        self, dag_id: str, task_ids: Sequence[str], run_id: str, states: Sequence[str]
    ) -> List["TaskInstance"]:
        """Get all task instances for a given dag_id, task_ids, and run_id."""
        task_instances = []
        task_id_chunks = [
            task_ids[i : i + self.batch_task_instance_limit]
            for i in range(0, len(task_ids), self.batch_task_instance_limit)
        ]
        for task_id_chunk in task_id_chunks:
            response = self.auth_backend.get_session().post(
                f"{self.get_api_url()}/dags/~/dagRuns/~/taskInstances/list",
                json={
                    "dag_ids": [dag_id],
                    "task_ids": task_id_chunk,
                    "dag_run_ids": [run_id],
                },
            )

            if response.status_code == 200:
                for task_instance_json in response.json()["task_instances"]:
                    task_id = task_instance_json["task_id"]
                    task_instance = TaskInstance(
                        webserver_url=self.auth_backend.get_webserver_url(),
                        dag_id=dag_id,
                        task_id=task_id,
                        run_id=run_id,
                        metadata=task_instance_json,
                    )
                    if task_instance.state in states:
                        task_instances.append(task_instance)
            else:
                raise DagsterError(
                    f"Failed to fetch task instances for {dag_id}/{task_id_chunk}/{run_id}. Status code: {response.status_code}, Message: {response.text}"
                )
        return task_instances

    def get_task_instance(self, dag_id: str, task_id: str, run_id: str) -> "TaskInstance":
        response = self.auth_backend.get_session().get(
            f"{self.get_api_url()}/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}"
@@ -149,6 +202,42 @@ def get_dag_runs(
                f"Failed to fetch dag runs for {dag_id}. Status code: {response.status_code}, Message: {response.text}"
            )

    def get_dag_runs_batch(
        self,
        dag_ids: Sequence[str],
        end_date_gte: datetime.datetime,
        end_date_lte: datetime.datetime,
        offset: int = 0,
    ) -> List["DagRun"]:
        """Return a batch of dag runs for a list of dag_ids. Ordered by end_date."""
        response = self.auth_backend.get_session().post(
            f"{self.get_api_url()}/dags/~/dagRuns/list",
            json={
                "dag_ids": dag_ids,
                "end_date_gte": self.airflow_str_from_datetime(end_date_gte),
                "end_date_lte": self.airflow_str_from_datetime(end_date_lte),
                "order_by": "end_date",
                "states": ["success"],
                "page_offset": offset,
                "page_limit": self.batch_dag_runs_limit,
            },
        )
        if response.status_code == 200:
            webserver_url = self.auth_backend.get_webserver_url()
            return [
                DagRun(
                    webserver_url=webserver_url,
                    dag_id=dag_run["dag_id"],
                    run_id=dag_run["dag_run_id"],
                    metadata=dag_run,
                )
                for dag_run in response.json()["dag_runs"]
            ]
        else:
            raise DagsterError(
                f"Failed to fetch dag runs for {dag_ids}. Status code: {response.status_code}, Message: {response.text}"
            )

    def trigger_dag(self, dag_id: str) -> str:
        response = self.auth_backend.get_session().post(
            f"{self.get_api_url()}/dags/{dag_id}/dagRuns",
@@ -242,6 +331,10 @@ class TaskInstance:
    run_id: str
    metadata: Dict[str, Any]

    @property
    def state(self) -> str:
        return self.metadata["state"]

    @property
    def note(self) -> str:
        return self.metadata.get("note") or ""
@@ -2,7 +2,10 @@

from dagster import Definitions

from dagster_airlift.core.sensor import build_airflow_polling_sensor
from dagster_airlift.core.sensor import (
    DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS,
    build_airflow_polling_sensor,
)
from dagster_airlift.migration_state import AirflowMigrationState

from .airflow_cacheable_assets_def import DEFAULT_POLL_INTERVAL, AirflowCacheableAssetsDefinition
@@ -16,6 +19,7 @@ def build_defs_from_airflow_instance(
    # This parameter will go away once we can derive the migration state from Airflow itself, using our built-in utilities.
    # Alternatively, we can keep it around to let people override the migration state if they want.
    migration_state_override: Optional[AirflowMigrationState] = None,
    sensor_minimum_interval_seconds: int = DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS,
) -> Definitions:
    """From a provided airflow instance and a set of airflow-orchestrated dagster definitions, build a set of dagster definitions to peer and observe the airflow instance.
@@ -45,7 +49,9 @@ def build_defs_from_airflow_instance(
        migration_state_override=migration_state_override,
    )
    # Now, we construct the sensor that will poll airflow for dag runs.
    airflow_sensor = build_airflow_polling_sensor(airflow_instance=airflow_instance)
    airflow_sensor = build_airflow_polling_sensor(
        airflow_instance=airflow_instance, minimum_interval_seconds=sensor_minimum_interval_seconds
    )
    return Definitions(
        assets=[assets_defs],
        asset_checks=defs.asset_checks if defs else None,
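A hypothetical call site for the new sensor_minimum_interval_seconds parameter; the auth backend and instance name are assumptions, and other arguments are omitted:

defs = build_defs_from_airflow_instance(
    airflow_instance=AirflowInstance(auth_backend=my_auth_backend, name="prod"),
    sensor_minimum_interval_seconds=30,  # evaluate the polling sensor at most once every 30 seconds
)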