diff --git a/python_modules/dagster/dagster/_core/definitions/asset_out.py b/python_modules/dagster/dagster/_core/definitions/asset_out.py index 3ef432235af57..bc0f2c5fd4e1e 100644 --- a/python_modules/dagster/dagster/_core/definitions/asset_out.py +++ b/python_modules/dagster/dagster/_core/definitions/asset_out.py @@ -161,6 +161,7 @@ def to_spec(self, key: AssetKey, deps: Sequence[AssetDep]) -> AssetSpec: tags=self.tags, deps=deps, auto_materialize_policy=None, + partitions_def=None, ) @property diff --git a/python_modules/dagster/dagster/_core/definitions/asset_spec.py b/python_modules/dagster/dagster/_core/definitions/asset_spec.py index dd3509216ee33..af9fe2d0ebeba 100644 --- a/python_modules/dagster/dagster/_core/definitions/asset_spec.py +++ b/python_modules/dagster/dagster/_core/definitions/asset_spec.py @@ -7,6 +7,7 @@ from dagster._core.definitions.declarative_automation.automation_condition import ( AutomationCondition, ) +from dagster._core.definitions.partition import PartitionsDefinition from dagster._core.definitions.utils import validate_asset_owner from dagster._serdes.serdes import whitelist_for_serdes from dagster._utils.internal_init import IHasInternalInit @@ -81,6 +82,7 @@ class AssetSpec( ("automation_condition", PublicAttr[Optional[AutomationCondition]]), ("owners", PublicAttr[Sequence[str]]), ("tags", PublicAttr[Mapping[str, str]]), + ("partitions_def", PublicAttr[Optional[PartitionsDefinition]]), ], ), IHasInternalInit, @@ -130,6 +132,7 @@ def __new__( tags: Optional[Mapping[str, str]] = None, # TODO: FOU-243 auto_materialize_policy: Optional[AutoMaterializePolicy] = None, + partitions_def: Optional[PartitionsDefinition] = None, ): from dagster._core.definitions.asset_dep import coerce_to_deps_and_check_duplicates @@ -161,6 +164,9 @@ def __new__( ), owners=owners, tags=validate_tags_strict(tags) or {}, + partitions_def=check.opt_inst_param( + partitions_def, "partitions_def", PartitionsDefinition + ), ) @staticmethod @@ -178,6 +184,7 @@ def dagster_internal_init( owners: Optional[Sequence[str]], tags: Optional[Mapping[str, str]], auto_materialize_policy: Optional[AutoMaterializePolicy], + partitions_def: Optional[PartitionsDefinition], ) -> "AssetSpec": check.invariant(auto_materialize_policy is None) return AssetSpec( @@ -192,6 +199,7 @@ def dagster_internal_init( automation_condition=automation_condition, owners=owners, tags=tags, + partitions_def=partitions_def, ) @cached_property diff --git a/python_modules/dagster/dagster/_core/definitions/assets.py b/python_modules/dagster/dagster/_core/definitions/assets.py index 84220ddd997c6..fc637d33f90df 100644 --- a/python_modules/dagster/dagster/_core/definitions/assets.py +++ b/python_modules/dagster/dagster/_core/definitions/assets.py @@ -227,7 +227,7 @@ def __init__( execution_type=execution_type or AssetExecutionType.MATERIALIZATION, ) - self._partitions_def = partitions_def + self._partitions_def = _resolve_partitions_def(specs, partitions_def) self._resource_defs = wrap_resources_for_execution( check.opt_mapping_param(resource_defs, "resource_defs") @@ -345,6 +345,7 @@ def __init__( metadata=metadata, description=description, skippable=skippable, + partitions_def=self._partitions_def, ) ) @@ -1729,6 +1730,7 @@ def _asset_specs_from_attr_key_params( # NodeDefinition skippable=False, auto_materialize_policy=None, + partitions_def=None, ) ) @@ -1789,6 +1791,38 @@ def get_self_dep_time_window_partition_mapping( return None +def _resolve_partitions_def( + specs: Optional[Sequence[AssetSpec]], partitions_def: Optional[PartitionsDefinition] +) -> Optional[PartitionsDefinition]: + if specs: + asset_keys_by_partitions_def = defaultdict(set) + for spec in specs: + asset_keys_by_partitions_def[spec.partitions_def].add(spec.key) + if len(asset_keys_by_partitions_def) > 1: + partition_1_asset_keys, partition_2_asset_keys, *_ = ( + asset_keys_by_partitions_def.values() + ) + check.failed( + f"All AssetSpecs must have the same partitions_def, but " + f"{next(iter(partition_1_asset_keys)).to_user_string()} and " + f"{next(iter(partition_2_asset_keys)).to_user_string()} have different " + "partitions_defs." + ) + common_partitions_def = next(iter(asset_keys_by_partitions_def.keys())) + if ( + common_partitions_def is not None + and partitions_def is not None + and common_partitions_def != partitions_def + ): + check.failed( + f"AssetSpec for {next(iter(specs)).key.to_user_string()} has partitions_def which is different " + "than the partitions_def provided to AssetsDefinition.", + ) + return partitions_def or common_partitions_def + else: + return partitions_def + + def get_partition_mappings_from_deps( partition_mappings: Dict[AssetKey, PartitionMapping], deps: Iterable[AssetDep], asset_name: str ) -> Mapping[AssetKey, PartitionMapping]: diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitioned_assets.py b/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitioned_assets.py index bf5c2cf0aae37..3844fe5c8c48e 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitioned_assets.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/test_partitioned_assets.py @@ -7,6 +7,7 @@ AssetMaterialization, AssetOut, AssetsDefinition, + AssetSpec, DagsterInstance, DagsterInvalidDefinitionError, DailyPartitionsDefinition, @@ -762,3 +763,45 @@ def downstream_asset(context, upstream_asset): [downstream_asset, upstream_asset.to_source_asset()], partition_key="2020-01-02-05:00", ) + + +def test_asset_spec_partitions_def(): + partitions_def = DailyPartitionsDefinition(start_date="2020-01-01") + + @multi_asset( + specs=[AssetSpec("asset1", partitions_def=partitions_def)], partitions_def=partitions_def + ) + def assets1(): ... + + assert assets1.partitions_def == partitions_def + assert next(iter(assets1.specs)).partitions_def == partitions_def + + @multi_asset(specs=[AssetSpec("asset1", partitions_def=partitions_def)]) + def assets2(): ... + + assert assets2.partitions_def == partitions_def + assert next(iter(assets2.specs)).partitions_def == partitions_def + + with pytest.raises( + CheckError, + match="AssetSpec for asset1 has partitions_def which is different than the partitions_def provided to AssetsDefinition.", + ): + + @multi_asset( + specs=[AssetSpec("asset1", partitions_def=StaticPartitionsDefinition(["a", "b"]))], + partitions_def=partitions_def, + ) + def assets3(): ... + + with pytest.raises( + CheckError, + match="All AssetSpecs must have the same partitions_def, but asset1 and asset2 have different partitions_defs.", + ): + + @multi_asset( + specs=[ + AssetSpec("asset1", partitions_def=partitions_def), + AssetSpec("asset2", partitions_def=StaticPartitionsDefinition(["a", "b"])), + ], + ) + def assets4(): ... diff --git a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_asset_backfill.py b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_asset_backfill.py index c55b17c1cb48d..913bd2838f586 100644 --- a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_asset_backfill.py +++ b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_asset_backfill.py @@ -813,54 +813,66 @@ def external_asset_graph_from_assets_by_repo_name( ) def test_serialization(static_serialization, time_window_serialization): time_window_partitions = DailyPartitionsDefinition(start_date="2015-05-05") - - @asset(partitions_def=time_window_partitions) - def daily_asset(): - return 1 - keys = ["a", "b", "c", "d", "e", "f"] static_partitions = StaticPartitionsDefinition(keys) - @asset(partitions_def=static_partitions) - def static_asset(): - return 1 + def make_asset_graph1(): + @asset(partitions_def=time_window_partitions) + def daily_asset(): ... - asset_graph = external_asset_graph_from_assets_by_repo_name( - {"repo": [daily_asset, static_asset]} - ) + @asset(partitions_def=static_partitions) + def static_asset(): ... - assert AssetBackfillData.is_valid_serialization(time_window_serialization, asset_graph) is True - assert AssetBackfillData.is_valid_serialization(static_serialization, asset_graph) is True + return external_asset_graph_from_assets_by_repo_name({"repo": [daily_asset, static_asset]}) - daily_asset._partitions_def = static_partitions # noqa: SLF001 - static_asset._partitions_def = time_window_partitions # noqa: SLF001 + asset_graph1 = make_asset_graph1() + assert AssetBackfillData.is_valid_serialization(time_window_serialization, asset_graph1) is True + assert AssetBackfillData.is_valid_serialization(static_serialization, asset_graph1) is True - asset_graph = external_asset_graph_from_assets_by_repo_name( - {"repo": [daily_asset, static_asset]} - ) + def make_asset_graph2(): + @asset(partitions_def=static_partitions) + def daily_asset(): ... - assert AssetBackfillData.is_valid_serialization(time_window_serialization, asset_graph) is False - assert AssetBackfillData.is_valid_serialization(static_serialization, asset_graph) is False + @asset(partitions_def=time_window_partitions) + def static_asset(): ... - static_asset._partitions_def = StaticPartitionsDefinition(keys + ["x"]) # noqa: SLF001 + return external_asset_graph_from_assets_by_repo_name({"repo": [daily_asset, static_asset]}) - asset_graph = external_asset_graph_from_assets_by_repo_name( - {"repo": [daily_asset, static_asset]} + asset_graph2 = make_asset_graph2() + assert ( + AssetBackfillData.is_valid_serialization(time_window_serialization, asset_graph2) is False ) + assert AssetBackfillData.is_valid_serialization(static_serialization, asset_graph2) is False - assert AssetBackfillData.is_valid_serialization(static_serialization, asset_graph) is True + def make_asset_graph3(): + @asset(partitions_def=StaticPartitionsDefinition(keys + ["x"])) + def daily_asset(): ... - @asset(partitions_def=static_partitions) - def daily_asset_renamed(): - return 1 + @asset(partitions_def=static_partitions) + def static_asset(): ... - asset_graph_renamed = external_asset_graph_from_assets_by_repo_name( - {"repo": [daily_asset_renamed, static_asset]} - ) + return external_asset_graph_from_assets_by_repo_name({"repo": [daily_asset, static_asset]}) + + asset_graph3 = make_asset_graph3() + + assert AssetBackfillData.is_valid_serialization(static_serialization, asset_graph3) is True + + def make_asset_graph4(): + @asset(partitions_def=static_partitions) + def daily_asset_renamed(): + return 1 + + @asset(partitions_def=time_window_partitions) + def static_asset(): ... + + return external_asset_graph_from_assets_by_repo_name( + {"repo": [daily_asset_renamed, static_asset]} + ) + + asset_graph4 = make_asset_graph4() assert ( - AssetBackfillData.is_valid_serialization(time_window_serialization, asset_graph_renamed) - is False + AssetBackfillData.is_valid_serialization(time_window_serialization, asset_graph4) is False ) diff --git a/python_modules/dagster/dagster_tests/definitions_tests/auto_materialize_tests/scenario_state.py b/python_modules/dagster/dagster_tests/definitions_tests/auto_materialize_tests/scenario_state.py index c8d7c591321b0..11d3acad81b5a 100644 --- a/python_modules/dagster/dagster_tests/definitions_tests/auto_materialize_tests/scenario_state.py +++ b/python_modules/dagster/dagster_tests/definitions_tests/auto_materialize_tests/scenario_state.py @@ -4,7 +4,6 @@ import logging import os import sys -from collections import namedtuple from contextlib import contextmanager from dataclasses import dataclass, field from typing import AbstractSet, Iterable, NamedTuple, Optional, Sequence, Union, cast @@ -106,15 +105,6 @@ def _get_code_location_origin_from_repository(repository: RepositoryDefinition, ) -class AssetSpecWithPartitionsDef( - namedtuple( - "AssetSpecWithPartitionsDef", - AssetSpec._fields + ("partitions_def",), - defaults=(None,) * (1 + len(AssetSpec._fields)), - ) -): ... - - class MultiAssetSpec(NamedTuple): specs: Sequence[AssetSpec] partitions_def: Optional[PartitionsDefinition] = None @@ -125,7 +115,7 @@ class MultiAssetSpec(NamedTuple): class ScenarioSpec: """A construct for declaring and modifying a desired Definitions object.""" - asset_specs: Sequence[Union[AssetSpec, AssetSpecWithPartitionsDef, MultiAssetSpec]] + asset_specs: Sequence[Union[AssetSpec, MultiAssetSpec]] current_time: datetime.datetime = field(default_factory=lambda: get_current_datetime()) sensors: Sequence[SensorDefinition] = field(default_factory=list) additional_repo_specs: Sequence["ScenarioSpec"] = field(default_factory=list) @@ -160,20 +150,12 @@ def _multi_asset(context: AssetExecutionContext): ) # create an observable_source_asset or regular asset depending on the execution type if execution_type == AssetExecutionType.OBSERVATION: - if isinstance(spec, AssetSpecWithPartitionsDef): - sd = spec._asdict() - partitions_def = sd.pop("partitions_def") - specs = [AssetSpec(**sd)] - else: - partitions_def = None - specs = [spec] @op def noop(): ... osa = AssetsDefinition( - specs=specs, - partitions_def=partitions_def, + specs=[spec], execution_type=execution_type, keys_by_output_name={"result": spec.key}, node_def=noop, @@ -253,13 +235,7 @@ def with_asset_properties( ) else: if keys is None or spec.key in {AssetKey.from_coercible(key) for key in keys}: - if "partitions_def" in kwargs: - # partitions_def is not a field on AssetSpec, so we need to do this hack - new_asset_specs.append( - AssetSpecWithPartitionsDef(**{**spec._asdict(), **kwargs}) - ) - else: - new_asset_specs.append(spec._replace(**kwargs)) + new_asset_specs.append(spec._replace(**kwargs)) else: new_asset_specs.append(spec) return dataclasses.replace(self, asset_specs=new_asset_specs)