[1.8][external assets] partitions_def on AssetSpec (#23201)
## Summary & Motivation

Adds a `partitions_def` attribute to `AssetSpec`. This is a step towards
deprecating `SourceAsset` in favor of `AssetSpec`.
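
As a minimal sketch of the new attribute (the asset key and the daily partitions definition here are purely illustrative):

```python
from dagster import AssetSpec, DailyPartitionsDefinition

daily = DailyPartitionsDefinition(start_date="2024-01-01")

# The spec itself now carries its partitioning scheme.
spec = AssetSpec("my_asset", partitions_def=daily)

assert spec.partitions_def == daily
```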

The biggest question this raises is what happens to `@multi_asset`,
which accepts `AssetSpec`s and also a `partitions_def` argument. This PR
retains the constraint that all the assets within a multi-asset must
have the same `partitions_def`, though I hope to relax it soon after.

Valid parameter combinations for `@multi_asset` (a short sketch of these combinations follows the two lists below):
- `partitions_def` passed to `@multi_asset` is `None`, and `partitions_def`s passed to individual `AssetSpec`s are all `None` -> use `None`.
- `partitions_def` passed to `@multi_asset` is X, and `partitions_def`s
passed to individual `AssetSpec`s are all `None` -> use X.
- `partitions_def` passed to `@multi_asset` is X, and `partitions_def`s
passed to individual `AssetSpec`s are all X -> use X.
- `partitions_def` passed to `@multi_asset` is `None`, and
`partitions_def`s passed to individual `AssetSpec`s are all X -> use X.

Invalid parameter combinations for `@multi_asset`:
- `partitions_def`s passed to individual `AssetSpec`s are not all the same.
  - In the future, we'll allow this, though we'll require that the `partitions_def` passed to `@multi_asset` is `None`.
- `partitions_def` passed to `@multi_asset` is X, and `partitions_def`s
passed to individual `AssetSpec`s are all Y.
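
As a rough sketch of how these combinations play out (mirroring the tests added in this PR; asset and function names are illustrative):

```python
from dagster import (
    AssetSpec,
    DailyPartitionsDefinition,
    StaticPartitionsDefinition,
    multi_asset,
)

daily = DailyPartitionsDefinition(start_date="2020-01-01")


# Valid: partitions_def set only on the spec -- the multi-asset picks it up.
@multi_asset(specs=[AssetSpec("asset1", partitions_def=daily)])
def assets(): ...


assert assets.partitions_def == daily

# Invalid: the spec's partitions_def conflicts with the one passed to the
# decorator, so definition-time validation fails (see the new tests below).
# @multi_asset(
#     specs=[AssetSpec("asset1", partitions_def=StaticPartitionsDefinition(["a", "b"]))],
#     partitions_def=daily,
# )
# def conflicting_assets(): ...
```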

A big question is whether we should deprecate the `partitions_def`
argument to `@multi_asset` as part of this PR. I suspect we should
eventually do this, but I think the disruption will be more palatable
when users get something in return, like the ability to apply different
`partitions_def`s to different assets within a multi-asset.

## How I Tested These Changes
sryza authored Aug 1, 2024
1 parent 0118892 commit 114f3ac
Showing 6 changed files with 134 additions and 60 deletions.
@@ -161,6 +161,7 @@ def to_spec(self, key: AssetKey, deps: Sequence[AssetDep]) -> AssetSpec:
tags=self.tags,
deps=deps,
auto_materialize_policy=None,
partitions_def=None,
)

@property
@@ -7,6 +7,7 @@
from dagster._core.definitions.declarative_automation.automation_condition import (
AutomationCondition,
)
from dagster._core.definitions.partition import PartitionsDefinition
from dagster._core.definitions.utils import validate_asset_owner
from dagster._serdes.serdes import whitelist_for_serdes
from dagster._utils.internal_init import IHasInternalInit
@@ -81,6 +82,7 @@ class AssetSpec(
("automation_condition", PublicAttr[Optional[AutomationCondition]]),
("owners", PublicAttr[Sequence[str]]),
("tags", PublicAttr[Mapping[str, str]]),
("partitions_def", PublicAttr[Optional[PartitionsDefinition]]),
],
),
IHasInternalInit,
@@ -130,6 +132,7 @@ def __new__(
tags: Optional[Mapping[str, str]] = None,
# TODO: FOU-243
auto_materialize_policy: Optional[AutoMaterializePolicy] = None,
partitions_def: Optional[PartitionsDefinition] = None,
):
from dagster._core.definitions.asset_dep import coerce_to_deps_and_check_duplicates

@@ -161,6 +164,9 @@ def __new__(
),
owners=owners,
tags=validate_tags_strict(tags) or {},
partitions_def=check.opt_inst_param(
partitions_def, "partitions_def", PartitionsDefinition
),
)

@staticmethod
@@ -178,6 +184,7 @@ def dagster_internal_init(
owners: Optional[Sequence[str]],
tags: Optional[Mapping[str, str]],
auto_materialize_policy: Optional[AutoMaterializePolicy],
partitions_def: Optional[PartitionsDefinition],
) -> "AssetSpec":
check.invariant(auto_materialize_policy is None)
return AssetSpec(
@@ -192,6 +199,7 @@
automation_condition=automation_condition,
owners=owners,
tags=tags,
partitions_def=partitions_def,
)

@cached_property
36 changes: 35 additions & 1 deletion python_modules/dagster/dagster/_core/definitions/assets.py
@@ -227,7 +227,7 @@ def __init__(
execution_type=execution_type or AssetExecutionType.MATERIALIZATION,
)

self._partitions_def = partitions_def
self._partitions_def = _resolve_partitions_def(specs, partitions_def)

self._resource_defs = wrap_resources_for_execution(
check.opt_mapping_param(resource_defs, "resource_defs")
@@ -345,6 +345,7 @@ def __init__(
metadata=metadata,
description=description,
skippable=skippable,
partitions_def=self._partitions_def,
)
)

@@ -1729,6 +1730,7 @@ def _asset_specs_from_attr_key_params(
# NodeDefinition
skippable=False,
auto_materialize_policy=None,
partitions_def=None,
)
)

@@ -1789,6 +1791,38 @@ def get_self_dep_time_window_partition_mapping(
return None


def _resolve_partitions_def(
specs: Optional[Sequence[AssetSpec]], partitions_def: Optional[PartitionsDefinition]
) -> Optional[PartitionsDefinition]:
if specs:
asset_keys_by_partitions_def = defaultdict(set)
for spec in specs:
asset_keys_by_partitions_def[spec.partitions_def].add(spec.key)
if len(asset_keys_by_partitions_def) > 1:
partition_1_asset_keys, partition_2_asset_keys, *_ = (
asset_keys_by_partitions_def.values()
)
check.failed(
f"All AssetSpecs must have the same partitions_def, but "
f"{next(iter(partition_1_asset_keys)).to_user_string()} and "
f"{next(iter(partition_2_asset_keys)).to_user_string()} have different "
"partitions_defs."
)
common_partitions_def = next(iter(asset_keys_by_partitions_def.keys()))
if (
common_partitions_def is not None
and partitions_def is not None
and common_partitions_def != partitions_def
):
check.failed(
f"AssetSpec for {next(iter(specs)).key.to_user_string()} has partitions_def which is different "
"than the partitions_def provided to AssetsDefinition.",
)
return partitions_def or common_partitions_def
else:
return partitions_def


def get_partition_mappings_from_deps(
partition_mappings: Dict[AssetKey, PartitionMapping], deps: Iterable[AssetDep], asset_name: str
) -> Mapping[AssetKey, PartitionMapping]:
@@ -7,6 +7,7 @@
AssetMaterialization,
AssetOut,
AssetsDefinition,
AssetSpec,
DagsterInstance,
DagsterInvalidDefinitionError,
DailyPartitionsDefinition,
@@ -762,3 +763,45 @@ def downstream_asset(context, upstream_asset):
[downstream_asset, upstream_asset.to_source_asset()],
partition_key="2020-01-02-05:00",
)


def test_asset_spec_partitions_def():
partitions_def = DailyPartitionsDefinition(start_date="2020-01-01")

@multi_asset(
specs=[AssetSpec("asset1", partitions_def=partitions_def)], partitions_def=partitions_def
)
def assets1(): ...

assert assets1.partitions_def == partitions_def
assert next(iter(assets1.specs)).partitions_def == partitions_def

@multi_asset(specs=[AssetSpec("asset1", partitions_def=partitions_def)])
def assets2(): ...

assert assets2.partitions_def == partitions_def
assert next(iter(assets2.specs)).partitions_def == partitions_def

with pytest.raises(
CheckError,
match="AssetSpec for asset1 has partitions_def which is different than the partitions_def provided to AssetsDefinition.",
):

@multi_asset(
specs=[AssetSpec("asset1", partitions_def=StaticPartitionsDefinition(["a", "b"]))],
partitions_def=partitions_def,
)
def assets3(): ...

with pytest.raises(
CheckError,
match="All AssetSpecs must have the same partitions_def, but asset1 and asset2 have different partitions_defs.",
):

@multi_asset(
specs=[
AssetSpec("asset1", partitions_def=partitions_def),
AssetSpec("asset2", partitions_def=StaticPartitionsDefinition(["a", "b"])),
],
)
def assets4(): ...
@@ -813,54 +813,66 @@ def external_asset_graph_from_assets_by_repo_name(
)
def test_serialization(static_serialization, time_window_serialization):
time_window_partitions = DailyPartitionsDefinition(start_date="2015-05-05")

@asset(partitions_def=time_window_partitions)
def daily_asset():
return 1

keys = ["a", "b", "c", "d", "e", "f"]
static_partitions = StaticPartitionsDefinition(keys)

@asset(partitions_def=static_partitions)
def static_asset():
return 1
def make_asset_graph1():
@asset(partitions_def=time_window_partitions)
def daily_asset(): ...

asset_graph = external_asset_graph_from_assets_by_repo_name(
{"repo": [daily_asset, static_asset]}
)
@asset(partitions_def=static_partitions)
def static_asset(): ...

assert AssetBackfillData.is_valid_serialization(time_window_serialization, asset_graph) is True
assert AssetBackfillData.is_valid_serialization(static_serialization, asset_graph) is True
return external_asset_graph_from_assets_by_repo_name({"repo": [daily_asset, static_asset]})

daily_asset._partitions_def = static_partitions # noqa: SLF001
static_asset._partitions_def = time_window_partitions # noqa: SLF001
asset_graph1 = make_asset_graph1()
assert AssetBackfillData.is_valid_serialization(time_window_serialization, asset_graph1) is True
assert AssetBackfillData.is_valid_serialization(static_serialization, asset_graph1) is True

asset_graph = external_asset_graph_from_assets_by_repo_name(
{"repo": [daily_asset, static_asset]}
)
def make_asset_graph2():
@asset(partitions_def=static_partitions)
def daily_asset(): ...

assert AssetBackfillData.is_valid_serialization(time_window_serialization, asset_graph) is False
assert AssetBackfillData.is_valid_serialization(static_serialization, asset_graph) is False
@asset(partitions_def=time_window_partitions)
def static_asset(): ...

static_asset._partitions_def = StaticPartitionsDefinition(keys + ["x"]) # noqa: SLF001
return external_asset_graph_from_assets_by_repo_name({"repo": [daily_asset, static_asset]})

asset_graph = external_asset_graph_from_assets_by_repo_name(
{"repo": [daily_asset, static_asset]}
asset_graph2 = make_asset_graph2()
assert (
AssetBackfillData.is_valid_serialization(time_window_serialization, asset_graph2) is False
)
assert AssetBackfillData.is_valid_serialization(static_serialization, asset_graph2) is False

assert AssetBackfillData.is_valid_serialization(static_serialization, asset_graph) is True
def make_asset_graph3():
@asset(partitions_def=StaticPartitionsDefinition(keys + ["x"]))
def daily_asset(): ...

@asset(partitions_def=static_partitions)
def daily_asset_renamed():
return 1
@asset(partitions_def=static_partitions)
def static_asset(): ...

asset_graph_renamed = external_asset_graph_from_assets_by_repo_name(
{"repo": [daily_asset_renamed, static_asset]}
)
return external_asset_graph_from_assets_by_repo_name({"repo": [daily_asset, static_asset]})

asset_graph3 = make_asset_graph3()

assert AssetBackfillData.is_valid_serialization(static_serialization, asset_graph3) is True

def make_asset_graph4():
@asset(partitions_def=static_partitions)
def daily_asset_renamed():
return 1

@asset(partitions_def=time_window_partitions)
def static_asset(): ...

return external_asset_graph_from_assets_by_repo_name(
{"repo": [daily_asset_renamed, static_asset]}
)

asset_graph4 = make_asset_graph4()

assert (
AssetBackfillData.is_valid_serialization(time_window_serialization, asset_graph_renamed)
is False
AssetBackfillData.is_valid_serialization(time_window_serialization, asset_graph4) is False
)


@@ -4,7 +4,6 @@
import logging
import os
import sys
from collections import namedtuple
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import AbstractSet, Iterable, NamedTuple, Optional, Sequence, Union, cast
@@ -106,15 +105,6 @@ def _get_code_location_origin_from_repository(repository: RepositoryDefinition,
)


class AssetSpecWithPartitionsDef(
namedtuple(
"AssetSpecWithPartitionsDef",
AssetSpec._fields + ("partitions_def",),
defaults=(None,) * (1 + len(AssetSpec._fields)),
)
): ...


class MultiAssetSpec(NamedTuple):
specs: Sequence[AssetSpec]
partitions_def: Optional[PartitionsDefinition] = None
@@ -125,7 +115,7 @@ class MultiAssetSpec(NamedTuple):
class ScenarioSpec:
"""A construct for declaring and modifying a desired Definitions object."""

asset_specs: Sequence[Union[AssetSpec, AssetSpecWithPartitionsDef, MultiAssetSpec]]
asset_specs: Sequence[Union[AssetSpec, MultiAssetSpec]]
current_time: datetime.datetime = field(default_factory=lambda: get_current_datetime())
sensors: Sequence[SensorDefinition] = field(default_factory=list)
additional_repo_specs: Sequence["ScenarioSpec"] = field(default_factory=list)
@@ -160,20 +150,12 @@ def _multi_asset(context: AssetExecutionContext):
)
# create an observable_source_asset or regular asset depending on the execution type
if execution_type == AssetExecutionType.OBSERVATION:
if isinstance(spec, AssetSpecWithPartitionsDef):
sd = spec._asdict()
partitions_def = sd.pop("partitions_def")
specs = [AssetSpec(**sd)]
else:
partitions_def = None
specs = [spec]

@op
def noop(): ...

osa = AssetsDefinition(
specs=specs,
partitions_def=partitions_def,
specs=[spec],
execution_type=execution_type,
keys_by_output_name={"result": spec.key},
node_def=noop,
@@ -253,13 +235,7 @@ def with_asset_properties(
)
else:
if keys is None or spec.key in {AssetKey.from_coercible(key) for key in keys}:
if "partitions_def" in kwargs:
# partitions_def is not a field on AssetSpec, so we need to do this hack
new_asset_specs.append(
AssetSpecWithPartitionsDef(**{**spec._asdict(), **kwargs})
)
else:
new_asset_specs.append(spec._replace(**kwargs))
new_asset_specs.append(spec._replace(**kwargs))
else:
new_asset_specs.append(spec)
return dataclasses.replace(self, asset_specs=new_asset_specs)