Skip to content

Commit

Permalink
Consolidate cacheable methods
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Dec 18, 2024
1 parent 3005364 commit 1c471fe
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 69 deletions.
15 changes: 14 additions & 1 deletion python_modules/dagster/dagster/_core/definitions/asset_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,9 @@ def replace_attributes(
skippable=skippable if skippable is not ... else self.skippable,
group_name=group_name if group_name is not ... else self.group_name,
code_version=code_version if code_version is not ... else self.code_version,
freshness_policy=self.freshness_policy,
freshness_policy=freshness_policy
if freshness_policy is not ...
else self.freshness_policy,
automation_condition=automation_condition
if automation_condition is not ...
else self.automation_condition,
Expand All @@ -340,6 +342,17 @@ def replace_attributes(
partitions_def=partitions_def if partitions_def is not ... else self.partitions_def,
)

def with_attributes(
self,
group_name: Optional[str] = ...,
automation_condition: Optional[AutomationCondition] = ...,
**kwargs,
) -> "AssetSpec":
"""Returns a new AssetSpec with the specified attributes replaced."""
return self.replace_attributes(
group_name=group_name, automation_condition=automation_condition
)

@public
def merge_attributes(
self,
Expand Down
7 changes: 7 additions & 0 deletions python_modules/dagster/dagster/_core/definitions/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,14 @@ def with_attributes(
Union[AutomationCondition, Mapping[AssetKey, AutomationCondition]]
] = None,
backfill_policy: Optional[BackfillPolicy] = None,
group_name: Optional[str] = None,
) -> "AssetsDefinition":
check.invariant(
not (group_name and group_names_by_key),
"Cannot use both group_name, which specifies the group for every contained asset, and group_names_by_key, which specifies group on a per-asset basis.",
)
if group_name:
group_names_by_key = {key: group_name for key in self.keys}
conflicts_by_attr_name: Dict[str, Set[AssetKey]] = defaultdict(set)
replaced_specs = []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy
from dagster._core.definitions.backfill_policy import BackfillPolicy
from dagster._core.definitions.declarative_automation.automation_condition import (
AutomationCondition,
)
from dagster._core.definitions.events import AssetKey, CoercibleToAssetKeyPrefix
from dagster._core.definitions.freshness_policy import FreshnessPolicy
from dagster._core.definitions.metadata import RawMetadataMapping
Expand Down Expand Up @@ -159,13 +162,21 @@ def with_attributes(
freshness_policy: Optional[
Union[FreshnessPolicy, Mapping[AssetKey, FreshnessPolicy]]
] = None,
group_name: Optional[str] = None,
backfill_policy: Optional[BackfillPolicy] = None,
automation_condition: Optional[AutomationCondition] = None,
) -> "CacheableAssetsDefinition":
return PrefixOrGroupWrappedCacheableAssetsDefinition(
self,
output_asset_key_replacements=output_asset_key_replacements,
input_asset_key_replacements=input_asset_key_replacements,
group_names_by_key=group_names_by_key,
freshness_policy=freshness_policy,
group_name_for_all_assets=group_name,
auto_materialize_policy=automation_condition.as_auto_materialize_policy()
if automation_condition
else None,
backfill_policy=backfill_policy,
)

def with_prefix_for_all(self, prefix: CoercibleToAssetKeyPrefix) -> "CacheableAssetsDefinition":
Expand All @@ -178,25 +189,6 @@ def with_prefix_for_all(self, prefix: CoercibleToAssetKeyPrefix) -> "CacheableAs
prefix = check.is_list(prefix, of_type=str)
return PrefixOrGroupWrappedCacheableAssetsDefinition(self, prefix_for_all_assets=prefix)

def with_attributes_for_all(
self,
group_name: Optional[str],
freshness_policy: Optional[FreshnessPolicy],
auto_materialize_policy: Optional[AutoMaterializePolicy],
backfill_policy: Optional[BackfillPolicy],
) -> "CacheableAssetsDefinition":
"""Utility method which allows setting attributes for all assets in this
CacheableAssetsDefinition, since the keys may not be known at the time of
construction.
"""
return PrefixOrGroupWrappedCacheableAssetsDefinition(
self,
group_name_for_all_assets=group_name,
freshness_policy=freshness_policy,
auto_materialize_policy=auto_materialize_policy,
backfill_policy=backfill_policy,
)


class WrappedCacheableAssetsDefinition(CacheableAssetsDefinition):
"""Wraps an instance of CacheableAssetsDefinition, applying transformed_assets_def to the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)
from dagster._core.definitions.freshness_policy import FreshnessPolicy
from dagster._core.definitions.source_asset import SourceAsset
from dagster._core.definitions.utils import DEFAULT_GROUP_NAME, resolve_automation_condition
from dagster._core.definitions.utils import resolve_automation_condition
from dagster._core.errors import DagsterInvalidDefinitionError


Expand Down Expand Up @@ -154,28 +154,6 @@ def to_post_load(self) -> "ResolvedAssetObjectList":
return ResolvedAssetObjectList(self.deduped_objects)


def _spec_mapper_disallow_group_override(
group_name: Optional[str],
automation_condition: Optional[AutomationCondition],
) -> Callable[[AssetSpec], AssetSpec]:
def _inner(spec: AssetSpec) -> AssetSpec:
if (
group_name is not None
and spec.group_name is not None
and group_name != spec.group_name
and spec.group_name != DEFAULT_GROUP_NAME
):
raise DagsterInvalidDefinitionError(
f"Asset spec {spec.key.to_user_string()} has group name {spec.group_name}, which conflicts with the group name {group_name} provided in load_assets_from_modules."
)
return spec.replace_attributes(
group_name=group_name if group_name else ...,
automation_condition=automation_condition if automation_condition else ...,
)

return _inner


def key_iterator(
asset: Union[AssetsDefinition, SourceAsset, AssetSpec], included_targeted_keys: bool = False
) -> Iterator[AssetKey]:
Expand Down Expand Up @@ -568,30 +546,12 @@ def with_attributes(
)
return_list = []
for asset in assets_list.loaded_objects:
if isinstance(asset, AssetsDefinition):
new_asset = asset.map_asset_specs(
_spec_mapper_disallow_group_override(group_name, automation_condition)
).with_attributes(
backfill_policy=backfill_policy, freshness_policy=freshness_policy
)
return_list.append(new_asset)
elif isinstance(asset, SourceAsset):
return_list.append(
asset.with_attributes(group_name=group_name if group_name else asset.group_name)
)
elif isinstance(asset, AssetSpec):
return_list.append(
_spec_mapper_disallow_group_override(group_name, automation_condition)(asset)
)
else:
return_list.append(
asset.with_attributes_for_all(
group_name,
freshness_policy=freshness_policy,
auto_materialize_policy=automation_condition.as_auto_materialize_policy()
if automation_condition
else None,
backfill_policy=backfill_policy,
)
return_list.append(
asset.with_attributes(
group_name=group_name,
freshness_policy=freshness_policy,
automation_condition=automation_condition,
backfill_policy=backfill_policy,
)
)
return ResolvedAssetObjectList(return_list)
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def with_resources(self, resource_defs) -> "SourceAsset":
)

def with_attributes(
self, group_name: Optional[str] = None, key: Optional[AssetKey] = None
self, group_name: Optional[str] = None, key: Optional[AssetKey] = None, **_kwargs
) -> "SourceAsset":
if group_name is not None and self.group_name != DEFAULT_GROUP_NAME:
raise DagsterInvalidDefinitionError(
Expand Down

0 comments on commit 1c471fe

Please sign in to comment.