Skip to content

Commit

Permalink
AIP-82 Handle trigger serialization (apache#45562)
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck authored Jan 31, 2025
1 parent b625c70 commit 53e1723
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 35 deletions.
39 changes: 20 additions & 19 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import json
import logging
import traceback
from typing import TYPE_CHECKING, Any, NamedTuple
from typing import TYPE_CHECKING, Any, NamedTuple, cast

from sqlalchemy import and_, delete, exists, func, insert, select, tuple_
from sqlalchemy.exc import OperationalError
Expand All @@ -53,8 +53,7 @@
from airflow.models.errors import ParseImportError
from airflow.models.trigger import Trigger
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUriRef
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.triggers.base import BaseTrigger
from airflow.serialization.serialized_objects import BaseSerialization, SerializedAssetWatcher
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.timezone import utcnow
Expand All @@ -68,7 +67,6 @@

from airflow.models.dagwarning import DagWarning
from airflow.serialization.serialized_objects import MaybeSerializedDAG
from airflow.triggers.base import BaseTrigger
from airflow.typing_compat import Self

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -747,16 +745,23 @@ def add_asset_trigger_references(
# Update references from assets being used
refs_to_add: dict[tuple[str, str], set[int]] = {}
refs_to_remove: dict[tuple[str, str], set[int]] = {}
triggers: dict[int, BaseTrigger] = {}
triggers: dict[int, dict] = {}

# Optimization: if no asset collected, skip fetching active assets
active_assets = _find_active_assets(self.assets.keys(), session=session) if self.assets else {}

for name_uri, asset in self.assets.items():
# If the asset belong to a DAG not active or paused, consider there is no watcher associated to it
asset_watchers = asset.watchers if name_uri in active_assets else []
trigger_hash_to_trigger_dict: dict[int, BaseTrigger] = {
self._get_base_trigger_hash(trigger): trigger for trigger in asset_watchers
asset_watchers: list[SerializedAssetWatcher] = (
[cast(SerializedAssetWatcher, watcher) for watcher in asset.watchers]
if name_uri in active_assets
else []
)
trigger_hash_to_trigger_dict: dict[int, dict] = {
self._get_trigger_hash(
watcher.trigger["classpath"], watcher.trigger["kwargs"]
): watcher.trigger
for watcher in asset_watchers
}
triggers.update(trigger_hash_to_trigger_dict)
trigger_hash_from_asset: set[int] = set(trigger_hash_to_trigger_dict.keys())
Expand All @@ -783,7 +788,10 @@ def add_asset_trigger_references(
}

all_trigger_keys: set[tuple[str, str]] = {
self._encrypt_trigger_kwargs(triggers[trigger_hash])
(
triggers[trigger_hash]["classpath"],
Trigger.encrypt_kwargs(triggers[trigger_hash]["kwargs"]),
)
for trigger_hashes in refs_to_add.values()
for trigger_hash in trigger_hashes
}
Expand All @@ -800,7 +808,9 @@ def add_asset_trigger_references(
new_trigger_models = [
trigger
for trigger in [
Trigger.from_object(triggers[trigger_hash])
Trigger(
classpath=triggers[trigger_hash]["classpath"], kwargs=triggers[trigger_hash]["kwargs"]
)
for trigger_hash in all_trigger_hashes
if trigger_hash not in orm_triggers
]
Expand Down Expand Up @@ -836,11 +846,6 @@ def add_asset_trigger_references(
if (asset_model.name, asset_model.uri) not in self.assets:
asset_model.triggers = []

@staticmethod
def _encrypt_trigger_kwargs(trigger: BaseTrigger) -> tuple[str, str]:
classpath, kwargs = trigger.serialize()
return classpath, Trigger.encrypt_kwargs(kwargs)

@staticmethod
def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
"""
Expand All @@ -852,7 +857,3 @@ def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
This is not true for event driven scheduling.
"""
return hash((classpath, json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))

def _get_base_trigger_hash(self, trigger: BaseTrigger) -> int:
classpath, kwargs = trigger.serialize()
return self._get_trigger_hash(classpath, kwargs)
7 changes: 3 additions & 4 deletions airflow/example_dags/example_asset_with_watchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@
from __future__ import annotations

import os
import tempfile

from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
from airflow.providers.standard.triggers.file import FileTrigger
from airflow.sdk.definitions.asset import Asset
from airflow.sdk import Asset, AssetWatcher

file_path = tempfile.NamedTemporaryFile().name
file_path = "/tmp/test"

with DAG(
dag_id="example_create_file",
Expand All @@ -44,7 +43,7 @@ def create_file():
chain(create_file())

trigger = FileTrigger(filepath=file_path, poke_interval=10)
asset = Asset("example_asset", watchers=[trigger])
asset = Asset("example_asset", watchers=[AssetWatcher(name="test_file_watcher", trigger=trigger)])

with DAG(
dag_id="example_asset_with_watchers",
Expand Down
12 changes: 12 additions & 0 deletions airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@
{"type": "null"},
{ "$ref": "#/definitions/dict" }
]
},
"watchers": {
"type": "array",
"items": { "$ref": "#/definitions/trigger" }
}
},
"required": [ "uri", "extra" ]
Expand Down Expand Up @@ -126,6 +130,14 @@
],
"additionalProperties": false
},
"trigger": {
"type": "object",
"properties": {
"classpath": { "type": "string" },
"kwargs": { "$ref": "#/definitions/dict" }
},
"required": [ "classpath", "kwargs" ]
},
"dict": {
"description": "A python dictionary containing values of any type",
"type": "object"
Expand Down
47 changes: 44 additions & 3 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
AssetAny,
AssetRef,
AssetUniqueKey,
AssetWatcher,
BaseAsset,
)
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
Expand Down Expand Up @@ -251,13 +252,34 @@ def encode_asset_condition(var: BaseAsset) -> dict[str, Any]:
:meta private:
"""
if isinstance(var, Asset):
return {

def _encode_watcher(watcher: AssetWatcher):
return {
"name": watcher.name,
"trigger": _encode_trigger(watcher.trigger),
}

def _encode_trigger(trigger: BaseTrigger | dict):
if isinstance(trigger, dict):
return trigger
classpath, kwargs = trigger.serialize()
return {
"classpath": classpath,
"kwargs": kwargs,
}

asset = {
"__type": DAT.ASSET,
"name": var.name,
"uri": var.uri,
"group": var.group,
"extra": var.extra,
}

if len(var.watchers) > 0:
asset["watchers"] = [_encode_watcher(watcher) for watcher in var.watchers]

return asset
if isinstance(var, AssetAlias):
return {"__type": DAT.ASSET_ALIAS, "name": var.name, "group": var.group}
if isinstance(var, AssetAll):
Expand All @@ -283,7 +305,7 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset:
"""
dat = var["__type"]
if dat == DAT.ASSET:
return Asset(name=var["name"], uri=var["uri"], group=var["group"], extra=var["extra"])
return decode_asset(var)
if dat == DAT.ASSET_ALL:
return AssetAll(*(decode_asset_condition(x) for x in var["objects"]))
if dat == DAT.ASSET_ANY:
Expand All @@ -295,6 +317,19 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset:
raise ValueError(f"deserialization not implemented for DAT {dat!r}")


def decode_asset(var: dict[str, Any]):
watchers = var.get("watchers", [])
return Asset(
name=var["name"],
uri=var["uri"],
group=var["group"],
extra=var["extra"],
watchers=[
SerializedAssetWatcher(name=watcher["name"], trigger=watcher["trigger"]) for watcher in watchers
],
)


def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]:
key = var.key
return {
Expand Down Expand Up @@ -874,7 +909,7 @@ def deserialize(cls, encoded_var: Any) -> Any:
elif type_ == DAT.XCOM_REF:
return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG.
elif type_ == DAT.ASSET:
return Asset(**var)
return decode_asset(var)
elif type_ == DAT.ASSET_ALIAS:
return AssetAlias(**var)
elif type_ == DAT.ASSET_ANY:
Expand Down Expand Up @@ -1810,6 +1845,12 @@ def set_ref(task: Operator) -> Operator:
return group


class SerializedAssetWatcher(AssetWatcher):
"""JSON serializable representation of an asset watcher."""

trigger: dict


def _has_kubernetes() -> bool:
global HAS_KUBERNETES
if "HAS_KUBERNETES" in globals():
Expand Down
7 changes: 6 additions & 1 deletion task_sdk/src/airflow/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from typing import TYPE_CHECKING

__all__ = [
"__version__",
"Asset",
"AssetWatcher",
"BaseOperator",
"Connection",
"DAG",
Expand All @@ -27,7 +30,6 @@
"MappedOperator",
"TaskGroup",
"XComArg",
"__version__",
"dag",
"get_current_context",
"get_parsing_context",
Expand All @@ -36,6 +38,7 @@
__version__ = "1.0.0.alpha1"

if TYPE_CHECKING:
from airflow.sdk.definitions.asset import Asset, AssetWatcher
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.context import get_current_context, get_parsing_context
Expand All @@ -60,6 +63,8 @@
"dag": ".definitions.dag",
"get_current_context": ".definitions.context",
"get_parsing_context": ".definitions.context",
"Asset": ".definitions.asset",
"AssetWatcher": ".definitions.asset",
}


Expand Down
25 changes: 20 additions & 5 deletions task_sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from sqlalchemy.orm import Session

from airflow.models.asset import AssetModel
from airflow.serialization.serialized_objects import SerializedAssetWatcher
from airflow.triggers.base import BaseTrigger

AttrsInstance = attrs.AttrsInstance
Expand All @@ -54,6 +55,7 @@
"AssetNameRef",
"AssetRef",
"AssetUriRef",
"AssetWatcher",
]


Expand Down Expand Up @@ -252,6 +254,19 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe
raise NotImplementedError


@attrs.define(frozen=True)
class AssetWatcher:
"""A representation of an asset watcher. The name uniquely identifies the watch."""

name: str
# This attribute serves double purpose.
# For a "normal" asset instance loaded from DAG, this holds the trigger used to monitor an external
# resource. In that case, ``AssetWatcher`` is used directly by users.
# For an asset recreated from a serialized DAG, this holds the serialized data of the trigger. In that
# case, `SerializedAssetWatcher` is used. We need to keep the two types to make mypy happy.
trigger: BaseTrigger | dict


@attrs.define(init=False, unsafe_hash=False)
class Asset(os.PathLike, BaseAsset):
"""A representation of data asset dependencies between workflows."""
Expand All @@ -271,7 +286,7 @@ class Asset(os.PathLike, BaseAsset):
factory=dict,
converter=_set_extra_default,
)
watchers: list[BaseTrigger] = attrs.field(
watchers: list[AssetWatcher | SerializedAssetWatcher] = attrs.field(
factory=list,
)

Expand All @@ -286,7 +301,7 @@ def __init__(
*,
group: str = ...,
extra: dict | None = None,
watchers: list[BaseTrigger] = ...,
watchers: list[AssetWatcher | SerializedAssetWatcher] = ...,
) -> None:
"""Canonical; both name and uri are provided."""

Expand All @@ -297,7 +312,7 @@ def __init__(
*,
group: str = ...,
extra: dict | None = None,
watchers: list[BaseTrigger] = ...,
watchers: list[AssetWatcher | SerializedAssetWatcher] = ...,
) -> None:
"""It's possible to only provide the name, either by keyword or as the only positional argument."""

Expand All @@ -308,7 +323,7 @@ def __init__(
uri: str,
group: str = ...,
extra: dict | None = None,
watchers: list[BaseTrigger] = ...,
watchers: list[AssetWatcher | SerializedAssetWatcher] = ...,
) -> None:
"""It's possible to only provide the URI as a keyword argument."""

Expand All @@ -319,7 +334,7 @@ def __init__(
*,
group: str | None = None,
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
watchers: list[AssetWatcher | SerializedAssetWatcher] | None = None,
) -> None:
if name is None and uri is None:
raise TypeError("Asset() requires either 'name' or 'uri'")
Expand Down
8 changes: 6 additions & 2 deletions tests/dag_processing/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.definitions.asset import Asset, AssetWatcher
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
from airflow.utils import timezone as tz
from airflow.utils.session import create_session
Expand Down Expand Up @@ -131,7 +131,11 @@ def per_test(self) -> Generator:
)
def test_add_asset_trigger_references(self, is_active, is_paused, expected_num_triggers, dag_maker):
trigger = TimeDeltaTrigger(timedelta(seconds=0))
asset = Asset("test_add_asset_trigger_references_asset", watchers=[trigger])
classpath, kwargs = trigger.serialize()
asset = Asset(
"test_add_asset_trigger_references_asset",
watchers=[AssetWatcher(name="test", trigger={"classpath": classpath, "kwargs": kwargs})],
)

with dag_maker(dag_id="test_add_asset_trigger_references_dag", schedule=[asset]) as dag:
EmptyOperator(task_id="mytask")
Expand Down
Loading

0 comments on commit 53e1723

Please sign in to comment.