[dagster-airlift] test refactors and fixes (#24146)
Refactors the dagster-airlift tests to use the new test framework: the bespoke DummyInstance mocks are replaced with the shared make_instance/make_dag_run/make_task_instance factories from dagster_airlift.test.
dpeng817 authored Sep 3, 2024
1 parent 553d74a commit 793f354
Showing 8 changed files with 189 additions and 320 deletions.
@@ -289,6 +289,14 @@ def config(self) -> Dict[str, Any]:
     def start_date(self) -> float:
         return AirflowInstance.timestamp_from_airflow_date(self.metadata["start_date"])
 
+    @property
+    def start_datetime(self) -> datetime.datetime:
+        return datetime.datetime.strptime(self.metadata["start_date"], "%Y-%m-%dT%H:%M:%S+00:00")
+
     @property
     def end_date(self) -> float:
         return AirflowInstance.timestamp_from_airflow_date(self.metadata["end_date"])
+
+    @property
+    def end_datetime(self) -> datetime.datetime:
+        return datetime.datetime.strptime(self.metadata["end_date"], "%Y-%m-%dT%H:%M:%S+00:00")
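The new datetime properties assume the fixed UTC timestamp format that the Airflow REST API uses for run metadata. A quick sketch of the parse they perform (the timestamp value here is made up):

    import datetime

    # Hypothetical API value in the format the properties above expect.
    raw = "2024-09-03T12:00:00+00:00"
    parsed = datetime.datetime.strptime(raw, "%Y-%m-%dT%H:%M:%S+00:00")
    assert parsed == datetime.datetime(2024, 9, 3, 12, 0)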
@@ -3,6 +3,7 @@
     DummyAuthBackend as DummyAuthBackend,
     make_dag_info as make_dag_info,
     make_dag_run as make_dag_run,
+    make_instance as make_instance,
     make_task_info as make_task_info,
     make_task_instance as make_task_instance,
 )
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from datetime import datetime
+from datetime import datetime, timedelta
 from typing import Dict, List, Optional, Tuple
 
 import requests
@@ -47,6 +47,9 @@ def __init__(
             name="test_instance",
         )
 
+    def list_dags(self) -> List[DagInfo]:
+        return list(self._dag_infos_by_dag_id.values())
+
     def get_dag_runs(self, dag_id: str, start_date: datetime, end_date: datetime) -> List[DagRun]:
         if dag_id not in self._dag_runs_by_dag_id:
             raise ValueError(f"Dag run not found for dag_id {dag_id}")
@@ -108,13 +111,19 @@ def make_task_info(dag_id: str, task_id: str) -> TaskInfo:
     )
 
 
-def make_task_instance(dag_id: str, task_id: str, run_id: str) -> TaskInstance:
+def make_task_instance(
+    dag_id: str, task_id: str, run_id: str, start_date: datetime, end_date: datetime
+) -> TaskInstance:
     return TaskInstance(
         webserver_url="http://dummy.domain",
         dag_id=dag_id,
         task_id=task_id,
         run_id=run_id,
-        metadata={},
+        metadata={
+            "state": "success",
+            "start_date": AirflowInstance.airflow_date_from_datetime(start_date),
+            "end_date": AirflowInstance.airflow_date_from_datetime(end_date),
+        },
     )


@@ -127,5 +136,50 @@ def make_dag_run(dag_id: str, run_id: str, start_date: datetime, end_date: datetime) -> DagRun:
"state": "success",
"start_date": AirflowInstance.airflow_date_from_datetime(start_date),
"end_date": AirflowInstance.airflow_date_from_datetime(end_date),
"run_type": "manual",
"note": "dummy note",
"conf": {},
},
)


def make_instance(
dag_and_task_structure: Dict[str, List[str]],
dag_runs: List[DagRun] = [],
) -> AirflowInstanceFake:
"""Constructs DagInfo, TaskInfo, and TaskInstance objects from provided data.
Args:
dag_and_task_structure: A dictionary mapping dag_id to a list of task_ids.
dag_runs: A list of DagRun objects to include in the instance. A TaskInstance object will be
created for each task_id in the dag, for each DagRun in dag_runs pertaining to a particular dag.
"""
dag_infos = []
task_infos = []
for dag_id, task_ids in dag_and_task_structure.items():
dag_info = make_dag_info(dag_id=dag_id, file_token=dag_id)
dag_infos.append(dag_info)
task_infos.extend([make_task_info(dag_id=dag_id, task_id=task_id) for task_id in task_ids])
task_instances = []
for dag_run in dag_runs:
task_instances.extend(
[
make_task_instance(
dag_id=dag_run.dag_id,
task_id=task_id,
run_id=dag_run.run_id,
start_date=dag_run.start_datetime,
end_date=dag_run.end_datetime
- timedelta(
seconds=1
), # Ensure that the task ends before the full "dag" completes.
)
for task_id in dag_and_task_structure[dag_run.dag_id]
]
)
return AirflowInstanceFake(
dag_infos=dag_infos,
task_infos=task_infos,
task_instances=task_instances,
dag_runs=dag_runs,
)
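A minimal usage sketch of the new factory helpers (the dag, task, and run names here are hypothetical):

    from datetime import datetime

    run = make_dag_run(
        dag_id="dag1",
        run_id="run-dag1",
        start_date=datetime(2024, 9, 3, 0, 0),
        end_date=datetime(2024, 9, 3, 0, 10),
    )
    instance = make_instance(
        dag_and_task_structure={"dag1": ["task1", "task2"]},
        dag_runs=[run],
    )
    # make_instance fabricates one TaskInstance per task per run,
    # each ending one second before its dag run completes.
    assert len(instance.list_dags()) == 1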
@@ -1,210 +1,75 @@
 from collections import defaultdict
 from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Sequence
+from typing import Dict, List, Sequence, Tuple, Union
 
-import requests
-from dagster import (
-    AssetsDefinition,
-    AssetSpec,
-    SensorResult,
-    asset,
-    build_sensor_context,
-    multi_asset,
-    repository,
-)
+from dagster import AssetObservation, AssetSpec, Definitions, SensorResult, build_sensor_context
+from dagster._core.definitions.asset_check_evaluation import AssetCheckEvaluation
 from dagster._core.definitions.events import AssetMaterialization
 from dagster._core.definitions.repository_definition.repository_definition import (
     RepositoryDefinition,
 )
 from dagster._time import get_current_datetime
-from dagster_airlift.core import AirflowInstance
-from dagster_airlift.core.airflow_instance import DagRun, TaskInfo, TaskInstance
-from dagster_airlift.core.basic_auth import AirflowAuthBackend
-from dagster_airlift.core.sensor import build_airflow_polling_sensor
+from dagster_airlift.core import build_defs_from_airflow_instance
+from dagster_airlift.test import make_dag_run, make_instance
 
 
-def strip_to_first_of_month(dt: datetime) -> datetime:
-    return dt.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
-
-
-class DummyAuthBackend(AirflowAuthBackend):
-    def get_session(self) -> requests.Session:
-        raise NotImplementedError("This shouldn't be called from this mock context.")
-
-    def get_webserver_url(self) -> str:
-        return "http://dummy.domain"
-
-
-class DummyInstance(AirflowInstance):
-    """A dummy instance that returns a single dag run and task instance for each call.
-    Designed in such a way that timestamps mirror the task_id, so that we can easily test ordering.
-    If you want some task to complete after a different task, you can simply set the task_id to a higher number.
-    The dag id should be a number higher than any task id it contains, so that it will complete after all constituent tasks.
-    This instance is designed to be used with "frozen" time, so that a baseline can be established for testing.
-    """
-
-    def __init__(self) -> None:
-        super().__init__(
-            auth_backend=DummyAuthBackend(),
-            name="dummy_instance",
-        )
-
-    def get_dag_runs(self, dag_id: str, start_date: datetime, end_date: datetime) -> List[DagRun]:
-        """Return a single dag run that started and finished within the given range."""
-        cur_date = strip_to_first_of_month(get_current_datetime())
-        return [
-            make_dag_run(cur_date, cur_date + timedelta(days=int(dag_id) + 1), dag_id),
-        ]
-
-    def get_task_instance(self, dag_id: str, task_id: str, run_id: str) -> TaskInstance:
-        """Return a task instance that started and finished within the given range. Expects that time has been frozen."""
-        cur_date = strip_to_first_of_month(get_current_datetime())
-        return make_task_instance(
-            dag_id,
-            task_id,
-            cur_date + timedelta(days=int(task_id)),
-            cur_date + timedelta(days=int(task_id) + 1),
-        )
-
-    def get_task_info(self, dag_id, task_id) -> TaskInfo:
-        return TaskInfo(
-            webserver_url="http://localhost:8080", dag_id=dag_id, task_id=task_id, metadata={}
-        )
-
-    def get_dag_source_code(self, file_token: str) -> str:
-        return "source code"
-
-
-def make_dag_run(dag_start: datetime, dag_end: datetime, dag_id: str) -> DagRun:
-    return DagRun(
-        metadata={
-            "run_type": "manual",
-            "conf": {},
-            "start_date": dag_start.strftime("%Y-%m-%dT%H:%M:%S+00:00"),
-            "end_date": dag_end.strftime("%Y-%m-%dT%H:%M:%S+00:00"),
-            "state": "success",
-        },
-        dag_id=dag_id,
-        run_id="run",
-        webserver_url="http://localhost:8080",
-    )
-
-
-def make_task_instance(
-    dag_id: str, task_id: str, task_start: datetime, task_end: datetime
-) -> TaskInstance:
-    return TaskInstance(
-        metadata={
-            "note": "note",
-            "start_date": task_start.strftime("%Y-%m-%dT%H:%M:%S+00:00"),
-            "end_date": task_end.strftime("%Y-%m-%dT%H:%M:%S+00:00"),
-            "state": "success",
-        },
-        dag_id=dag_id,
-        task_id=task_id,
-        webserver_url="http://localhost:8080",
-        run_id="run",
-    )
-
-
-def build_task_asset(
-    deps_graph: Dict[str, List[str]],
-    task_id: str,
-    dag_id: str,
-) -> AssetsDefinition:
-    asset_specs = [AssetSpec(key=key, deps=deps) for key, deps in deps_graph.items()]
-
-    @multi_asset(specs=asset_specs, op_tags={"airlift/task_id": task_id, "airlift/dag_id": dag_id})
-    def asset_fn():
-        pass
-
-    return asset_fn
-
-
-def build_dag_asset(
-    dag_id: str,
-) -> AssetsDefinition:
-    @asset(op_tags={"airlift/dag_id": dag_id}, key=dag_id)
-    def asset_fn():
-        pass
-
-    return asset_fn
-
-
-def make_test_instance(
-    get_task_instance_override=None, get_dag_runs_override=None, list_dags_override=None
-) -> DummyInstance:
-    klass_to_instantiate = DummyInstance
-    if get_task_instance_override:
-
-        class TaskInstanceOverride(klass_to_instantiate):
-            def get_task_instance(self, dag_id: str, task_id: str, run_id: str) -> TaskInstance:
-                return get_task_instance_override(self, dag_id, task_id, run_id)
-
-        klass_to_instantiate = TaskInstanceOverride
-
-    if get_dag_runs_override:
-
-        class DagRunsOverride(klass_to_instantiate):  # type: ignore
-            def get_dag_runs(
-                self, dag_id: str, start_date: datetime, end_date: datetime
-            ) -> List[DagRun]:
-                return get_dag_runs_override(self, dag_id, start_date, end_date)
-
-        klass_to_instantiate = DagRunsOverride
-
-    if list_dags_override:
-
-        class ListDagsOverride(klass_to_instantiate):  # type: ignore
-            def list_dags(self):
-                return list_dags_override(self)
-
-        klass_to_instantiate = ListDagsOverride
-
-    return klass_to_instantiate()
-
-
-def repo_from_defs(assets_defs: List[AssetsDefinition]) -> RepositoryDefinition:
-    @repository
-    def repo():
-        return assets_defs
-
-    return repo
+def build_defs_from_airflow_asset_graph(
+    assets_per_task: Dict[str, Dict[str, List[Tuple[str, List[str]]]]],
+    additional_defs: Definitions = Definitions(),
+) -> RepositoryDefinition:
+    specs = []
+    dag_and_task_structure = defaultdict(list)
+    for dag_id, task_structure in assets_per_task.items():
+        for task_id, asset_structure in task_structure.items():
+            dag_and_task_structure[dag_id].append(task_id)
+            for asset_key, deps in asset_structure:
+                specs.append(
+                    AssetSpec(
+                        asset_key,
+                        deps=deps,
+                        tags={"airlift/dag_id": dag_id, "airlift/task_id": task_id},
+                    )
+                )
+    instance = make_instance(
+        dag_and_task_structure=dag_and_task_structure,
+        dag_runs=[
+            make_dag_run(
+                dag_id=dag_id,
+                run_id=f"run-{dag_id}",
+                start_date=get_current_datetime() - timedelta(minutes=10),
+                end_date=get_current_datetime(),
+            )
+            for dag_id in dag_and_task_structure.keys()
+        ],
+    )
+    defs = Definitions.merge(
+        additional_defs,
+        Definitions(assets=specs),
+    )
+    repo_def = build_defs_from_airflow_instance(instance, defs=defs).get_repository_def()
+    repo_def.load_all_definitions()
+    return repo_def
 
 
 def build_and_invoke_sensor(
-    instance: AirflowInstance,
-    defs: List[AssetsDefinition],
+    assets_per_task: Dict[str, Dict[str, List[Tuple[str, List[str]]]]],
+    additional_defs: Definitions = Definitions(),
 ) -> SensorResult:
-    sensor = build_airflow_polling_sensor(instance)
-    context = build_sensor_context(repository_def=repo_from_defs(defs))
+    repo_def = build_defs_from_airflow_asset_graph(assets_per_task, additional_defs=additional_defs)
+    sensor = next(iter(repo_def.sensor_defs))
+    context = build_sensor_context(repository_def=repo_def)
     result = sensor(context)
     assert isinstance(result, SensorResult)
     return result
 
 
-def build_dag_assets(
-    tasks_to_asset_deps_graph: Dict[str, Dict[str, List[str]]],
-    dag_id: Optional[str] = None,
-) -> List[AssetsDefinition]:
-    resolved_dag_id = dag_id or str(
-        max(int(task_id) for task_id in tasks_to_asset_deps_graph.keys()) + 1
-    )
-    assets = []
-    for task_id, deps_graph in tasks_to_asset_deps_graph.items():
-        assets.append(build_task_asset(deps_graph, task_id, resolved_dag_id))
-    assets.append(build_dag_asset(resolved_dag_id))
-    return assets
-
-
 def assert_expected_key_order(
-    mats: Sequence[AssetMaterialization], expected_key_order: Sequence[str]
+    mats: Sequence[Union[AssetMaterialization, AssetObservation, AssetCheckEvaluation]],
+    expected_key_order: Sequence[str],
 ) -> None:
-    assert all(isinstance(mat, AssetMaterialization) for mat in mats)
     assert [mat.asset_key.to_user_string() for mat in mats] == expected_key_order
-
-
-def make_asset(key, deps):
-    @asset(key=key, deps=deps)
-    def the_asset():
-        pass
-
-    return the_asset
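Taken together, the rewritten helpers let a test declare an Airflow-backed asset graph as plain data. A hedged usage sketch (the dag, task, and asset names are hypothetical, and the exact events emitted depend on the sensor built by build_defs_from_airflow_instance):

    result = build_and_invoke_sensor(
        assets_per_task={
            "dag1": {
                "task1": [("a", []), ("b", ["a"])],
                "task2": [("c", ["b"])],
            },
        },
    )
    # Each tuple is (asset_key, upstream_deps); the airlift/dag_id and
    # airlift/task_id tags map the specs back to dag1's tasks.
    print([event.asset_key.to_user_string() for event in result.asset_events])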