From 718c516e9d991ba1d0d4a18ae821d81920ee4811 Mon Sep 17 00:00:00 2001
From: Christopher DeCarolis
Date: Wed, 11 Sep 2024 12:51:21 -0700
Subject: [PATCH] [dagster-airlift] sensor revamp (#24333)

- Batch requests to the Airflow REST API
- Move to an iteration model where we can pause iteration in the middle of processing (basically after the processing step of any dag run) and restart gracefully
- Handle skipped tasks gracefully
- Stop iteration at 40 seconds to give adequate time for startup and teardown

This nearly 10xes the speed of sensor iteration, even locally where IPC cost is low. See https://gist.github.com/dpeng817/7979124e3286b1473aa57f252817b488

---
 .../dagster_airlift/core/airflow_instance.py | 97 ++++++-
 .../dagster_airlift/core/defs_from_airflow.py | 10 +-
 .../dagster_airlift/core/sensor.py | 242 ++++++++++++------
 .../test/airflow_test_instance.py | 35 ++-
 .../unit_tests/conftest.py | 24 +-
 .../unit_tests/core_tests/test_sensor.py | 155 ++++++++++-
 6 files changed, 472 insertions(+), 91 deletions(-)

diff --git a/examples/experimental/dagster-airlift/dagster_airlift/core/airflow_instance.py b/examples/experimental/dagster-airlift/dagster_airlift/core/airflow_instance.py index eaceda7b59849..8e550b83f4442 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/core/airflow_instance.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/core/airflow_instance.py @@ -2,7 +2,7 @@ import json from abc import ABC from functools import cached_property -from typing import Any, Dict, List +from typing import Any, Dict, List, Sequence import requests from dagster._core.definitions.asset_key import AssetKey @@ -15,6 +15,14 @@ from .utils import convert_to_valid_dagster_name TERMINAL_STATES = {"success", "failed", "skipped", "up_for_retry", "up_for_reschedule"} +# This limits the number of task ids that we attempt to query from airflow's task instance rest API at a given time. +# Airflow's batch task instance retrieval rest API doesn't have a limit parameter, but we query a single run at a time, meaning we should be getting +# a single task instance per task id. +# Airflow task instance batch API: https://airflow.apache.org/docs/apache-airflow/stable/stable-rest-api-ref.html#operation/get_task_instances_batch +DEFAULT_BATCH_TASK_RETRIEVAL_LIMIT = 100 +# This corresponds directly to the page_limit parameter on airflow's batch dag runs rest API.
+# Airflow dag run batch API: https://airflow.apache.org/docs/apache-airflow/stable/stable-rest-api-ref.html#operation/get_dag_runs_batch +DEFAULT_BATCH_DAG_RUNS_LIMIT = 100 class AirflowAuthBackend(ABC): @@ -26,9 +34,17 @@ def get_webserver_url(self) -> str: class AirflowInstance: - def __init__(self, auth_backend: AirflowAuthBackend, name: str) -> None: + def __init__( + self, + auth_backend: AirflowAuthBackend, + name: str, + batch_task_instance_limit: int = DEFAULT_BATCH_TASK_RETRIEVAL_LIMIT, + batch_dag_runs_limit: int = DEFAULT_BATCH_DAG_RUNS_LIMIT, + ) -> None: self.auth_backend = auth_backend self.name = name + self.batch_task_instance_limit = batch_task_instance_limit + self.batch_dag_runs_limit = batch_dag_runs_limit @property def normalized_name(self) -> str: @@ -74,6 +90,43 @@ def get_migration_state(self) -> AirflowMigrationState: dag_dict[dag_id] = DagMigrationState.from_dict(migration_dict) return AirflowMigrationState(dags=dag_dict) + def get_task_instance_batch( + self, dag_id: str, task_ids: Sequence[str], run_id: str, states: Sequence[str] + ) -> List["TaskInstance"]: + """Get all task instances for a given dag_id, task_ids, and run_id.""" + task_instances = [] + task_id_chunks = [ + task_ids[i : i + self.batch_task_instance_limit] + for i in range(0, len(task_ids), self.batch_task_instance_limit) + ] + for task_id_chunk in task_id_chunks: + response = self.auth_backend.get_session().post( + f"{self.get_api_url()}/dags/~/dagRuns/~/taskInstances/list", + json={ + "dag_ids": [dag_id], + "task_ids": task_id_chunk, + "dag_run_ids": [run_id], + }, + ) + + if response.status_code == 200: + for task_instance_json in response.json()["task_instances"]: + task_id = task_instance_json["task_id"] + task_instance = TaskInstance( + webserver_url=self.auth_backend.get_webserver_url(), + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + metadata=task_instance_json, + ) + if task_instance.state in states: + task_instances.append(task_instance) + else: + raise DagsterError( + f"Failed to fetch task instances for {dag_id}/{task_id_chunk}/{run_id}. Status code: {response.status_code}, Message: {response.text}" + ) + return task_instances + def get_task_instance(self, dag_id: str, task_id: str, run_id: str) -> "TaskInstance": response = self.auth_backend.get_session().get( f"{self.get_api_url()}/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}" @@ -149,6 +202,42 @@ def get_dag_runs( f"Failed to fetch dag runs for {dag_id}. Status code: {response.status_code}, Message: {response.text}" ) + def get_dag_runs_batch( + self, + dag_ids: Sequence[str], + end_date_gte: datetime.datetime, + end_date_lte: datetime.datetime, + offset: int = 0, + ) -> List["DagRun"]: + """Return a batch of dag runs for a list of dag_ids. Ordered by end_date.""" + response = self.auth_backend.get_session().post( + f"{self.get_api_url()}/dags/~/dagRuns/list", + json={ + "dag_ids": dag_ids, + "end_date_gte": self.airflow_str_from_datetime(end_date_gte), + "end_date_lte": self.airflow_str_from_datetime(end_date_lte), + "order_by": "end_date", + "states": ["success"], + "page_offset": offset, + "page_limit": self.batch_dag_runs_limit, + }, + ) + if response.status_code == 200: + webserver_url = self.auth_backend.get_webserver_url() + return [ + DagRun( + webserver_url=webserver_url, + dag_id=dag_run["dag_id"], + run_id=dag_run["dag_run_id"], + metadata=dag_run, + ) + for dag_run in response.json()["dag_runs"] + ] + else: + raise DagsterError( + f"Failed to fetch dag runs for {dag_ids}. 
Status code: {response.status_code}, Message: {response.text}" + ) + def trigger_dag(self, dag_id: str) -> str: response = self.auth_backend.get_session().post( f"{self.get_api_url()}/dags/{dag_id}/dagRuns", @@ -242,6 +331,10 @@ class TaskInstance: run_id: str metadata: Dict[str, Any] + @property + def state(self) -> str: + return self.metadata["state"] + @property def note(self) -> str: return self.metadata.get("note") or "" diff --git a/examples/experimental/dagster-airlift/dagster_airlift/core/defs_from_airflow.py b/examples/experimental/dagster-airlift/dagster_airlift/core/defs_from_airflow.py index a8f5d53ce50a8..1abfdb5f15f12 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/core/defs_from_airflow.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/core/defs_from_airflow.py @@ -2,7 +2,10 @@ from dagster import Definitions -from dagster_airlift.core.sensor import build_airflow_polling_sensor +from dagster_airlift.core.sensor import ( + DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS, + build_airflow_polling_sensor, +) from dagster_airlift.migration_state import AirflowMigrationState from .airflow_cacheable_assets_def import DEFAULT_POLL_INTERVAL, AirflowCacheableAssetsDefinition @@ -16,6 +19,7 @@ def build_defs_from_airflow_instance( # This parameter will go away once we can derive the migration state from airflow itself, using our built in utilities. # Alternatively, we can keep it around to let people override the migration state if they want. migration_state_override: Optional[AirflowMigrationState] = None, + sensor_minimum_interval_seconds: int = DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS, ) -> Definitions: """From a provided airflow instance and a set of airflow-orchestrated dagster definitions, build a set of dagster definitions to peer and observe the airflow instance. @@ -45,7 +49,9 @@ def build_defs_from_airflow_instance( migration_state_override=migration_state_override, ) # Now, we construct the sensor that will poll airflow for dag runs. 
- airflow_sensor = build_airflow_polling_sensor(airflow_instance=airflow_instance) + airflow_sensor = build_airflow_polling_sensor( + airflow_instance=airflow_instance, minimum_interval_seconds=sensor_minimum_interval_seconds + ) return Definitions( assets=[assets_defs], asset_checks=defs.asset_checks if defs else None, diff --git a/examples/experimental/dagster-airlift/dagster_airlift/core/sensor.py b/examples/experimental/dagster-airlift/dagster_airlift/core/sensor.py index bae6e3df5c8f4..d3e1801e69acb 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/core/sensor.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/core/sensor.py @@ -1,6 +1,6 @@ from collections import defaultdict from datetime import timedelta -from typing import Dict, List, Sequence, Set, Tuple +from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple from dagster import ( AssetCheckKey, @@ -22,20 +22,38 @@ RepositoryDefinition, ) from dagster._core.utils import toposort_flatten +from dagster._grpc.client import DEFAULT_SENSOR_GRPC_TIMEOUT from dagster._record import record +from dagster._serdes import deserialize_value, serialize_value +from dagster._serdes.serdes import whitelist_for_serdes from dagster._time import datetime_from_timestamp, get_current_datetime, get_current_timestamp from dagster_airlift.constants import MIGRATED_TAG -from dagster_airlift.core.airflow_instance import AirflowInstance, TaskInstance +from dagster_airlift.core.airflow_instance import AirflowInstance from dagster_airlift.core.utils import get_dag_id_from_asset, get_task_id_from_asset +MAIN_LOOP_TIMEOUT_SECONDS = DEFAULT_SENSOR_GRPC_TIMEOUT - 20 +DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS = 1 +START_LOOKBACK_SECONDS = 60 # Lookback one minute in time for the initial setting of the cursor. + + +@whitelist_for_serdes +@record +class AirflowPollingSensorCursor: + """A cursor that stores the last effective timestamp and the last polled dag id.""" + + end_date_gte: Optional[float] = None + end_date_lte: Optional[float] = None + dag_query_offset: Optional[int] = None + def build_airflow_polling_sensor( airflow_instance: AirflowInstance, + minimum_interval_seconds: int = DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS, ) -> SensorDefinition: @sensor( name="airflow_dag_status_sensor", - minimum_interval_seconds=1, + minimum_interval_seconds=minimum_interval_seconds, default_status=DefaultSensorStatus.RUNNING, # This sensor will only ever execute asset checks and not asset materializations. asset_selection=AssetSelection.all_asset_checks(), @@ -43,83 +61,66 @@ def build_airflow_polling_sensor( def airflow_dag_sensor(context: SensorEvaluationContext) -> SensorResult: """Sensor to report materialization events for each asset as new runs come in.""" repository_def = check.not_none(context.repository_def) - last_effective_date = ( - datetime_from_timestamp(float(context.cursor)) - if context.cursor - else get_current_datetime() - timedelta(days=1) - ) + try: + cursor = ( + deserialize_value(context.cursor, AirflowPollingSensorCursor) + if context.cursor + else AirflowPollingSensorCursor() + ) + except Exception as e: + context.log.info(f"Failed to interpret cursor. Starting from scratch. 
Error: {e}") + cursor = AirflowPollingSensorCursor() current_date = get_current_datetime() - materializations_to_report: List[Tuple[float, AssetMaterialization]] = [] toposorted_keys = toposorted_asset_keys(repository_def) - asset_check_keys_to_request = set() unmigrated_info = get_unmigrated_info(repository_def) - for dag_id, peered_dag_asset_info in unmigrated_info.asset_info_by_dag_id.items(): - dag_key = peered_dag_asset_info.dag_asset_key - task_keys = peered_dag_asset_info.task_asset_keys - # For now, we materialize assets representing tasks only when the whole dag completes. - # With a more robust cursor that can let us know when we've seen a particular task run already, then we can relax this constraint. - for dag_run in airflow_instance.get_dag_runs(dag_id, last_effective_date, current_date): - if not dag_run.success: - raise Exception("Should only see successful dag runs at this point.") - - metadata = { - "Airflow Run ID": dag_run.run_id, - "Run Metadata (raw)": JsonMetadataValue(dag_run.metadata), - "Run Type": dag_run.run_type, - "Airflow Config": JsonMetadataValue(dag_run.config), - "Creation Timestamp": TimestampMetadataValue(get_current_timestamp()), - } - # Add dag materialization - dag_metadata = { - **metadata, - "Run Details": MarkdownMetadataValue(f"[View Run]({dag_run.url})"), - "Start Date": TimestampMetadataValue(dag_run.start_date), - "End Date": TimestampMetadataValue(dag_run.end_date), - } - materializations_to_report.append( - ( - dag_run.end_date, - AssetMaterialization( - asset_key=dag_key, - description=dag_run.note, - metadata=dag_metadata, - ), - ) - ) - asset_check_keys_to_request.update(unmigrated_info.checks_per_key[dag_key]) - task_runs = {} - for task_id, asset_key in task_keys: - task_run: TaskInstance = task_runs.get( - task_id, airflow_instance.get_task_instance(dag_id, task_id, dag_run.run_id) - ) - task_runs[task_id] = task_run - task_metadata = { - **metadata, - "Run Details": MarkdownMetadataValue(f"[View Run]({task_run.details_url})"), - "Task Logs": MarkdownMetadataValue(f"[View Logs]({task_run.log_url})"), - "Start Date": TimestampMetadataValue(task_run.start_date), - "End Date": TimestampMetadataValue(task_run.end_date), - } - materializations_to_report.append( - ( - task_run.end_date, - AssetMaterialization( - asset_key=asset_key, - description=task_run.note, - metadata=task_metadata, - ), - ) - ) - asset_check_keys_to_request.update(unmigrated_info.checks_per_key[asset_key]) + current_dag_offset = cursor.dag_query_offset or 0 + end_date_gte = ( + cursor.end_date_gte + or (current_date - timedelta(seconds=START_LOOKBACK_SECONDS)).timestamp() + ) + end_date_lte = cursor.end_date_lte or current_date.timestamp() + sensor_iter = materializations_and_requests_from_batch_iter( + end_date_gte=end_date_gte, + end_date_lte=end_date_lte, + offset=current_dag_offset, + airflow_instance=airflow_instance, + unmigrated_info=unmigrated_info, + ) + all_materializations: List[Tuple[float, AssetMaterialization]] = [] + all_check_keys: Set[AssetCheckKey] = set() + latest_offset = current_dag_offset + while get_current_datetime() - current_date < timedelta(seconds=MAIN_LOOP_TIMEOUT_SECONDS): + batch_result = next(sensor_iter, None) + if batch_result is None: + break + all_materializations.extend(batch_result.materializations_and_timestamps) + + for asset_key in batch_result.all_asset_keys_materialized: + all_check_keys.update(unmigrated_info.checks_per_key[asset_key]) + latest_offset = batch_result.idx + # Sort materializations by end date and toposort order 
sorted_mats = sorted( - materializations_to_report, key=lambda x: (x[0], toposorted_keys.index(x[1].asset_key)) + all_materializations, key=lambda x: (x[0], toposorted_keys.index(x[1].asset_key)) ) - context.update_cursor(str(current_date.timestamp())) + if batch_result is not None: + new_cursor = AirflowPollingSensorCursor( + end_date_gte=end_date_gte, + end_date_lte=end_date_lte, + dag_query_offset=latest_offset + 1, + ) + else: + # We have completed iteration for this range + new_cursor = AirflowPollingSensorCursor( + end_date_gte=end_date_lte, + end_date_lte=None, + dag_query_offset=0, + ) + context.update_cursor(serialize_value(new_cursor)) return SensorResult( asset_events=[sorted_mat[1] for sorted_mat in sorted_mats], - run_requests=[RunRequest(asset_check_keys=list(asset_check_keys_to_request))] - if asset_check_keys_to_request + run_requests=[RunRequest(asset_check_keys=list(all_check_keys))] + if all_check_keys else None, ) @@ -131,12 +132,23 @@ class PeeredDagAssetInfo: dag_asset_key: AssetKey task_asset_keys: Set[Tuple[str, AssetKey]] + @property + def task_ids(self) -> Sequence[str]: + return [task_id for task_id, _ in self.task_asset_keys] + + def asset_keys_for_task(self, task_id: str) -> Sequence[AssetKey]: + return [asset_key for task_id_, asset_key in self.task_asset_keys if task_id_ == task_id] + @record class UnmigratedInfo: asset_info_by_dag_id: Dict[str, PeeredDagAssetInfo] checks_per_key: Dict[AssetKey, Set[AssetCheckKey]] + @property + def dag_ids(self) -> Sequence[str]: + return list(self.asset_info_by_dag_id.keys()) + def get_unmigrated_info( repository_def: RepositoryDefinition, @@ -190,3 +202,89 @@ def toposorted_asset_keys( asset_dep_graph[spec.key].update(dep.asset_key for dep in spec.deps) return toposort_flatten(asset_dep_graph) + + +@record +class BatchResult: + idx: int + materializations_and_timestamps: List[Tuple[float, AssetMaterialization]] + all_asset_keys_materialized: Set[AssetKey] + + +def materializations_and_requests_from_batch_iter( + end_date_gte: float, + end_date_lte: float, + offset: int, + airflow_instance: AirflowInstance, + unmigrated_info: UnmigratedInfo, +) -> Iterator[Optional[BatchResult]]: + runs = airflow_instance.get_dag_runs_batch( + dag_ids=unmigrated_info.dag_ids, + end_date_gte=datetime_from_timestamp(end_date_gte), + end_date_lte=datetime_from_timestamp(end_date_lte), + offset=offset, + ) + for i, dag_run in enumerate(runs): + peered_dag_asset_info = unmigrated_info.asset_info_by_dag_id[dag_run.dag_id] + materializations_for_run = [] + all_asset_keys_materialized = set() + metadata = { + "Airflow Run ID": dag_run.run_id, + "Run Metadata (raw)": JsonMetadataValue(dag_run.metadata), + "Run Type": dag_run.run_type, + "Airflow Config": JsonMetadataValue(dag_run.config), + "Creation Timestamp": TimestampMetadataValue(get_current_timestamp()), + } + # Add dag materialization + dag_metadata = { + **metadata, + "Run Details": MarkdownMetadataValue(f"[View Run]({dag_run.url})"), + "Start Date": TimestampMetadataValue(dag_run.start_date), + "End Date": TimestampMetadataValue(dag_run.end_date), + } + materializations_for_run.append( + ( + dag_run.end_date, + AssetMaterialization( + asset_key=unmigrated_info.asset_info_by_dag_id[dag_run.dag_id].dag_asset_key, + description=dag_run.note, + metadata=dag_metadata, + ), + ) + ) + all_asset_keys_materialized.add(peered_dag_asset_info.dag_asset_key) + for task_run in airflow_instance.get_task_instance_batch( + run_id=dag_run.run_id, + dag_id=dag_run.dag_id, + 
task_ids=peered_dag_asset_info.task_ids, + states=["success"], + ): + asset_keys = peered_dag_asset_info.asset_keys_for_task(task_run.task_id) + task_metadata = { + **metadata, + "Run Details": MarkdownMetadataValue(f"[View Run]({task_run.details_url})"), + "Task Logs": MarkdownMetadataValue(f"[View Logs]({task_run.log_url})"), + "Start Date": TimestampMetadataValue(task_run.start_date), + "End Date": TimestampMetadataValue(task_run.end_date), + } + for asset_key in asset_keys: + materializations_for_run.append( + ( + task_run.end_date, + AssetMaterialization( + asset_key=asset_key, + description=task_run.note, + metadata=task_metadata, + ), + ) + ) + all_asset_keys_materialized.add(asset_key) + yield ( + BatchResult( + idx=i + offset, + materializations_and_timestamps=materializations_for_run, + all_asset_keys_materialized=all_asset_keys_materialized, + ) + if materializations_for_run + else None + ) diff --git a/examples/experimental/dagster-airlift/dagster_airlift/test/airflow_test_instance.py b/examples/experimental/dagster-airlift/dagster_airlift/test/airflow_test_instance.py index a9bbb8030bbb5..5a302b605e2e6 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift/test/airflow_test_instance.py +++ b/examples/experimental/dagster-airlift/dagster_airlift/test/airflow_test_instance.py @@ -1,6 +1,6 @@ from collections import defaultdict from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Sequence, Tuple import requests @@ -65,6 +65,39 @@ def get_dag_runs(self, dag_id: str, start_date: datetime, end_date: datetime) -> and start_date.timestamp() <= run.end_date <= end_date.timestamp() ] + def get_dag_runs_batch( + self, + dag_ids: Sequence[str], + end_date_gte: datetime, + end_date_lte: datetime, + offset: int = 0, + ) -> List[DagRun]: + runs = [ + (run.end_date, run) + for runs in self._dag_runs_by_dag_id.values() + for run in runs + if end_date_gte.timestamp() <= run.end_date <= end_date_lte.timestamp() + and run.dag_id in dag_ids + ] + sorted_by_end_date = [run for _, run in sorted(runs, key=lambda x: x[0])] + return sorted_by_end_date[offset:] + + def get_task_instance_batch( + self, dag_id: str, task_ids: Sequence[str], run_id: str, states: Sequence[str] + ) -> List[TaskInstance]: + task_instances = [] + for task_id in set(task_ids): + if (dag_id, task_id) not in self._task_instances_by_dag_and_task_id: + continue + task_instances.extend( + [ + task_instance + for task_instance in self._task_instances_by_dag_and_task_id[(dag_id, task_id)] + if task_instance.run_id == run_id and task_instance.state in states + ] + ) + return task_instances + def get_task_instance(self, dag_id: str, task_id: str, run_id: str) -> TaskInstance: if (dag_id, task_id) not in self._task_instances_by_dag_and_task_id: raise ValueError(f"Task instance not found for dag_id {dag_id} and task_id {task_id}") diff --git a/examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/conftest.py b/examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/conftest.py index a2427ad35e732..9dd77cd788ce7 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/conftest.py +++ b/examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/conftest.py @@ -7,6 +7,7 @@ AssetObservation, AssetSpec, Definitions, + SensorEvaluationContext, SensorResult, build_sensor_context, ) @@ -27,8 +28,11 @@ def strip_to_first_of_month(dt: datetime) -> datetime: def 
fully_loaded_repo_from_airflow_asset_graph( assets_per_task: Dict[str, Dict[str, List[Tuple[str, List[str]]]]], additional_defs: Definitions = Definitions(), + create_runs: bool = True, ) -> RepositoryDefinition: - defs = build_definitions_airflow_asset_graph(assets_per_task, additional_defs=additional_defs) + defs = build_definitions_airflow_asset_graph( + assets_per_task, additional_defs=additional_defs, create_runs=create_runs + ) repo_def = defs.get_repository_def() repo_def.load_all_definitions() return repo_def @@ -37,6 +41,7 @@ def fully_loaded_repo_from_airflow_asset_graph( def build_definitions_airflow_asset_graph( assets_per_task: Dict[str, Dict[str, List[Tuple[str, List[str]]]]], additional_defs: Definitions = Definitions(), + create_runs: bool = True, ) -> Definitions: specs = [] dag_and_task_structure = defaultdict(list) @@ -51,9 +56,8 @@ def build_definitions_airflow_asset_graph( metadata={"airlift/dag_id": dag_id, "airlift/task_id": task_id}, ) ) - instance = make_instance( - dag_and_task_structure=dag_and_task_structure, - dag_runs=[ + runs = ( + [ make_dag_run( dag_id=dag_id, run_id=f"run-{dag_id}", @@ -61,7 +65,13 @@ def build_definitions_airflow_asset_graph( end_date=get_current_datetime(), ) for dag_id in dag_and_task_structure.keys() - ], + ] + if create_runs + else [] + ) + instance = make_instance( + dag_and_task_structure=dag_and_task_structure, + dag_runs=runs, ) defs = Definitions.merge( additional_defs, @@ -73,7 +83,7 @@ def build_definitions_airflow_asset_graph( def build_and_invoke_sensor( assets_per_task: Dict[str, Dict[str, List[Tuple[str, List[str]]]]], additional_defs: Definitions = Definitions(), -) -> SensorResult: +) -> Tuple[SensorResult, SensorEvaluationContext]: repo_def = fully_loaded_repo_from_airflow_asset_graph( assets_per_task, additional_defs=additional_defs ) @@ -81,7 +91,7 @@ def build_and_invoke_sensor( context = build_sensor_context(repository_def=repo_def) result = sensor(context) assert isinstance(result, SensorResult) - return result + return result, context def assert_expected_key_order( diff --git a/examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/core_tests/test_sensor.py b/examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/core_tests/test_sensor.py index 4c91fd0076a57..624512a5b443b 100644 --- a/examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/core_tests/test_sensor.py +++ b/examples/experimental/dagster-airlift/dagster_airlift_tests/unit_tests/core_tests/test_sensor.py @@ -1,12 +1,24 @@ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone -from dagster import AssetCheckKey, AssetKey, AssetSpec, Definitions, asset_check +import mock +from dagster import ( + AssetCheckKey, + AssetKey, + AssetSpec, + Definitions, + SensorResult, + asset_check, + build_sensor_context, +) from dagster._core.definitions.events import AssetMaterialization from dagster._core.test_utils import freeze_time +from dagster._serdes import deserialize_value +from dagster_airlift.core.sensor import AirflowPollingSensorCursor from dagster_airlift_tests.unit_tests.conftest import ( assert_expected_key_order, build_and_invoke_sensor, + fully_loaded_repo_from_airflow_asset_graph, ) @@ -15,7 +27,7 @@ def test_dag_and_task_metadata() -> None: freeze_datetime = datetime(2021, 1, 1) with freeze_time(freeze_datetime): - result = build_and_invoke_sensor( + result, _ = build_and_invoke_sensor( assets_per_task={ "dag": {"task": [("a", [])]}, }, @@ -91,7 +103,7 @@ def 
test_interleaved_exeutions() -> None: # c -> d where c and d are each in their own airflow tasks, in a different dag. freeze_datetime = datetime(2021, 1, 1) with freeze_time(freeze_datetime): - result = build_and_invoke_sensor( + result, context = build_and_invoke_sensor( assets_per_task={ "dag1": {"task1": [("a", [])], "task2": [("b", ["a"])]}, "dag2": {"task1": [("c", [])], "task2": [("d", ["c"])]}, @@ -109,6 +121,11 @@ def test_interleaved_exeutions() -> None: # dag1 and dag2 should be after all task-mapped assets assert mats_order.index("airflow_instance/dag/dag1") >= 4 assert mats_order.index("airflow_instance/dag/dag2") >= 4 + assert context.cursor + cursor = deserialize_value(context.cursor, AirflowPollingSensorCursor) + assert cursor.end_date_gte == freeze_datetime.timestamp() + assert cursor.end_date_lte is None + assert cursor.dag_query_offset == 0 def test_dependencies_within_tasks() -> None: @@ -125,7 +142,7 @@ def test_dependencies_within_tasks() -> None: # e f freeze_datetime = datetime(2021, 1, 1) with freeze_time(freeze_datetime): - result = build_and_invoke_sensor( + result, context = build_and_invoke_sensor( assets_per_task={ "dag": { "task1": [("a", []), ("b", ["a"]), ("c", ["a"])], @@ -137,6 +154,11 @@ def test_dependencies_within_tasks() -> None: assert_expected_key_order( result.asset_events, ["a", "b", "c", "d", "e", "f", "airflow_instance/dag/dag"] ) + assert context.cursor + cursor = deserialize_value(context.cursor, AirflowPollingSensorCursor) + assert cursor.end_date_gte == freeze_datetime.timestamp() + assert cursor.end_date_lte is None + assert cursor.dag_query_offset == 0 def test_outside_of_dag_dependency() -> None: @@ -144,7 +166,7 @@ def test_outside_of_dag_dependency() -> None: # a -> b -> c where a and c are in the same task, and b is not in any dag. 
freeze_datetime = datetime(2021, 1, 1) with freeze_time(freeze_datetime): - result = build_and_invoke_sensor( + result, context = build_and_invoke_sensor( assets_per_task={ "dag": {"task": [("a", []), ("c", ["b"])]}, }, @@ -153,6 +175,11 @@ def test_outside_of_dag_dependency() -> None: assert len(result.asset_events) == 3 assert all(isinstance(event, AssetMaterialization) for event in result.asset_events) assert_expected_key_order(result.asset_events, ["a", "c", "airflow_instance/dag/dag"]) + assert context.cursor + cursor = deserialize_value(context.cursor, AirflowPollingSensorCursor) + assert cursor.end_date_gte == freeze_datetime.timestamp() + assert cursor.end_date_lte is None + assert cursor.dag_query_offset == 0 def test_request_asset_checks() -> None: @@ -172,7 +199,7 @@ def check_unrelated_asset(): pass with freeze_time(freeze_datetime): - result = build_and_invoke_sensor( + result, context = build_and_invoke_sensor( assets_per_task={ "dag": {"task": [("a", []), ("b", ["a"])]}, }, @@ -193,3 +220,117 @@ def check_unrelated_asset(): name="check_dag_asset", asset_key=AssetKey(["airflow_instance", "dag", "dag"]) ), } + assert context.cursor + cursor = deserialize_value(context.cursor, AirflowPollingSensorCursor) + assert cursor.end_date_gte == freeze_datetime.timestamp() + assert cursor.end_date_lte is None + assert cursor.dag_query_offset == 0 + + +_CALLCOUNT = [0] + + +def _mock_get_current_datetime() -> datetime: + if _CALLCOUNT[0] < 2: + _CALLCOUNT[0] += 1 + return datetime(2021, 2, 1, tzinfo=timezone.utc) + next_time = datetime(2021, 2, 1, tzinfo=timezone.utc) + timedelta(seconds=46 * _CALLCOUNT[0]) + _CALLCOUNT[0] += 1 + return next_time + + +def test_cursor() -> None: + """Test expected cursor behavior for sensor.""" + asset_and_dag_structure = { + "dag1": {"task1": [("a", [])]}, + "dag2": {"task1": [("b", [])]}, + } + + with freeze_time(datetime(2021, 1, 1, tzinfo=timezone.utc)): + # First, run through a full successful iteration of the sensor. Expect time to move forward, and polled dag id to be None, since we completed iteration of all dags. + # Then, run through a partial iteration of the sensor. We mock get_current_datetime to return a time after timeout passes iteration start after the first call, meaning we should pause iteration. 
+ repo_def = fully_loaded_repo_from_airflow_asset_graph(asset_and_dag_structure) + sensor = next(iter(repo_def.sensor_defs)) + context = build_sensor_context(repository_def=repo_def) + result = sensor(context) + assert isinstance(result, SensorResult) + assert context.cursor + new_cursor = deserialize_value(context.cursor, AirflowPollingSensorCursor) + assert new_cursor.end_date_gte == datetime(2021, 1, 1, tzinfo=timezone.utc).timestamp() + assert new_cursor.end_date_lte is None + assert new_cursor.dag_query_offset == 0 + + with mock.patch( + "dagster._time._mockable_get_current_datetime", wraps=_mock_get_current_datetime + ): + result = sensor(context) + assert isinstance(result, SensorResult) + new_cursor = deserialize_value(context.cursor, AirflowPollingSensorCursor) + # We didn't advance to the next effective timestamp, since we didn't complete iteration + assert new_cursor.end_date_gte == datetime(2021, 1, 1, tzinfo=timezone.utc).timestamp() + # We have not yet moved forward + assert new_cursor.end_date_lte == datetime(2021, 2, 1, tzinfo=timezone.utc).timestamp() + assert new_cursor.dag_query_offset == 1 + + _CALLCOUNT[0] = 0 + # We weren't able to complete iteration, so we should pause iteration again + result = sensor(context) + assert isinstance(result, SensorResult) + new_cursor = deserialize_value(context.cursor, AirflowPollingSensorCursor) + assert new_cursor.end_date_gte == datetime(2021, 1, 1, tzinfo=timezone.utc).timestamp() + assert new_cursor.end_date_lte == datetime(2021, 2, 1, tzinfo=timezone.utc).timestamp() + assert new_cursor.dag_query_offset == 2 + + _CALLCOUNT[0] = 0 + # Now it should finish iteration. + result = sensor(context) + assert isinstance(result, SensorResult) + new_cursor = deserialize_value(context.cursor, AirflowPollingSensorCursor) + assert new_cursor.end_date_gte == datetime(2021, 2, 1, tzinfo=timezone.utc).timestamp() + assert new_cursor.end_date_lte is None + assert new_cursor.dag_query_offset == 0 + + +def test_legacy_cursor() -> None: + """Test the case where a legacy/uninterpretable cursor is provided to the sensor execution.""" + freeze_datetime = datetime(2021, 1, 1, tzinfo=timezone.utc) + with freeze_time(freeze_datetime): + repo_def = fully_loaded_repo_from_airflow_asset_graph( + { + "dag": {"task": [("a", [])]}, + } + ) + sensor = next(iter(repo_def.sensor_defs)) + context = build_sensor_context( + repository_def=repo_def, cursor=str(freeze_datetime.timestamp()) + ) + result = sensor(context) + assert isinstance(result, SensorResult) + assert context.cursor + new_cursor = deserialize_value(context.cursor, AirflowPollingSensorCursor) + assert new_cursor.end_date_gte == datetime(2021, 1, 1, tzinfo=timezone.utc).timestamp() + assert new_cursor.end_date_lte is None + assert new_cursor.dag_query_offset == 0 + + +def test_no_runs() -> None: + """Test the case with no runs.""" + freeze_datetime = datetime(2021, 1, 1, tzinfo=timezone.utc) + with freeze_time(freeze_datetime): + repo_def = fully_loaded_repo_from_airflow_asset_graph( + { + "dag": {"task": [("a", [])]}, + }, + create_runs=False, + ) + sensor = next(iter(repo_def.sensor_defs)) + context = build_sensor_context(repository_def=repo_def) + result = sensor(context) + assert isinstance(result, SensorResult) + assert context.cursor + new_cursor = deserialize_value(context.cursor, AirflowPollingSensorCursor) + assert new_cursor.end_date_gte == datetime(2021, 1, 1, tzinfo=timezone.utc).timestamp() + assert new_cursor.end_date_lte is None + assert new_cursor.dag_query_offset == 0 + assert not 
result.asset_events + assert not result.run_requests
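
A minimal usage sketch of the configuration surface this patch adds. It is a sketch only: BasicAuthBackend, the webserver URL, and the credentials below are illustrative assumptions, and get_session is assumed to return a requests.Session based on how it is used in the diff; AirflowInstance's batch limits, build_airflow_polling_sensor, and its minimum_interval_seconds parameter come from the changes above.

import requests

from dagster import Definitions
from dagster_airlift.core.airflow_instance import AirflowAuthBackend, AirflowInstance
from dagster_airlift.core.sensor import build_airflow_polling_sensor


class BasicAuthBackend(AirflowAuthBackend):
    """Hypothetical auth backend for this sketch; a real deployment supplies its own."""

    def __init__(self, webserver_url: str, username: str, password: str) -> None:
        self._webserver_url = webserver_url
        self._username = username
        self._password = password

    def get_session(self) -> requests.Session:
        # One authenticated session reused for all REST calls, including the new
        # batch dag-run and task-instance endpoints.
        session = requests.Session()
        session.auth = (self._username, self._password)
        return session

    def get_webserver_url(self) -> str:
        return self._webserver_url


airflow_instance = AirflowInstance(
    auth_backend=BasicAuthBackend("http://localhost:8080", "admin", "admin"),
    name="my_airflow",
    # Caps on how much each batch REST call requests; both default to 100.
    batch_task_instance_limit=100,
    batch_dag_runs_limit=100,
)

# The sensor serializes an AirflowPollingSensorCursor so it can stop before the
# ~40 second main-loop timeout and resume from the same dag-run offset next tick.
airflow_sensor = build_airflow_polling_sensor(
    airflow_instance=airflow_instance,
    minimum_interval_seconds=30,
)

defs = Definitions(sensors=[airflow_sensor])

The default interval remains 1 second (DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS); the 30 above is only to show that the knob is now exposed both here and through build_defs_from_airflow_instance's sensor_minimum_interval_seconds argument.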