From 3d5141f49164a63e666e77d68ab2c93a1c27ead4 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Fri, 14 Jun 2024 12:47:41 -0700 Subject: [PATCH] `AssetGraphComputation` inside `AssetsDefinition` (#22554) ## Summary & Motivation Sprouting off of Nick's comment [here](https://github.com/dagster-io/dagster/pull/22165#discussion_r1638984982), this moves all the computation-related properties of `AssetsDefinition` into their own class. While in this PR it's restricted to the internals of `AssetsDefinition`, I think that this object could eventually take on a wider role. It basically corresponds to the nodes on the right in this diagram I've been bandying about: image ## How I Tested These Changes --- .../dagster/_core/definitions/assets.py | 215 ++++++++++-------- .../asset_defs_tests/test_assets.py | 2 +- 2 files changed, 126 insertions(+), 91 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/assets.py b/python_modules/dagster/dagster/_core/definitions/assets.py index 37a5b0e8efcf9..f4efbfc909887 100644 --- a/python_modules/dagster/dagster/_core/definitions/assets.py +++ b/python_modules/dagster/dagster/_core/definitions/assets.py @@ -59,6 +59,7 @@ DagsterInvalidInvocationError, DagsterInvariantViolationError, ) +from dagster._model import dagster_model from dagster._utils import IHasInternalInit from dagster._utils.merger import merge_dicts from dagster._utils.security import non_secure_md5_hash_str @@ -86,6 +87,28 @@ ASSET_SUBSET_INPUT_PREFIX = "__subset_input__" +@dagster_model +class AssetGraphComputation: + """A computation whose purpose is to materialize assets, observe assets, and/or evaluate asset + checks. + + Binds a NodeDefinition to the asset keys and asset check keys that it interacts with. + """ + + node_def: NodeDefinition + keys_by_input_name: Mapping[str, AssetKey] + keys_by_output_name: Mapping[str, AssetKey] + backfill_policy: Optional[BackfillPolicy] + can_subset: bool + is_subset: bool + selected_asset_keys: AbstractSet[AssetKey] + selected_asset_check_keys: AbstractSet[AssetCheckKey] + + @cached_property + def output_names_by_key(self): + return {key: name for name, key in self.keys_by_output_name.items()} + + class AssetsDefinition(ResourceAddable, RequiresResources, IHasInternalInit): """Defines a set of assets that are produced by the same op or graph. @@ -106,21 +129,14 @@ class AssetsDefinition(ResourceAddable, RequiresResources, IHasInternalInit): "owners_by_key", } - _node_def: Optional[NodeDefinition] - _keys_by_input_name: Mapping[str, AssetKey] - _keys_by_output_name: Mapping[str, AssetKey] _partitions_def: Optional[PartitionsDefinition] # partition mappings are also tracked inside the AssetSpecs, but this enables faster access by # upstream asset key _partition_mappings: Mapping[AssetKey, PartitionMapping] _resource_defs: Mapping[str, ResourceDefinition] - _selected_asset_keys: AbstractSet[AssetKey] - _can_subset: bool - _backfill_policy: Optional[BackfillPolicy] - _selected_asset_check_keys: AbstractSet[AssetCheckKey] - _is_subset: bool _specs_by_key: Mapping[AssetKey, AssetSpec] + _computation: Optional[AssetGraphComputation] @experimental_param(param="specs") def __init__( @@ -181,19 +197,45 @@ def __init__( check.invariant( not can_subset, "node_def is None, so backfill_policy must not be provided" ) + self._computation = None + else: + selected_asset_keys, selected_asset_check_keys = _resolve_selections( + all_asset_keys={spec.key for spec in specs} + if specs + else set(check.not_none(keys_by_output_name).values()), + all_check_keys={spec.key for spec in (check_specs_by_output_name or {}).values()}, + selected_asset_keys=selected_asset_keys, + selected_asset_check_keys=selected_asset_check_keys, + ) - self._node_def = node_def - self._keys_by_input_name = check.opt_mapping_param( - keys_by_input_name, - "keys_by_input_name", - key_type=str, - value_type=AssetKey, - ) - self._keys_by_output_name = check.opt_mapping_param( - keys_by_output_name, - "keys_by_output_name", + self._computation = AssetGraphComputation( + node_def=node_def, + keys_by_input_name=check.opt_mapping_param( + keys_by_input_name, + "keys_by_input_name", + key_type=str, + value_type=AssetKey, + ), + keys_by_output_name=check.opt_mapping_param( + keys_by_output_name, + "keys_by_output_name", + key_type=str, + value_type=AssetKey, + ), + can_subset=can_subset, + backfill_policy=check.opt_inst_param( + backfill_policy, "backfill_policy", BackfillPolicy + ), + is_subset=check.bool_param(is_subset, "is_subset"), + selected_asset_keys=selected_asset_keys, + selected_asset_check_keys=selected_asset_check_keys, + ) + + check_specs_by_output_name = check.opt_mapping_param( + check_specs_by_output_name, + "check_specs_by_output_name", key_type=str, - value_type=AssetKey, + value_type=AssetCheckSpec, ) self._check_specs_by_output_name = check.opt_mapping_param( @@ -209,12 +251,6 @@ def __init__( check.opt_mapping_param(resource_defs, "resource_defs") ) - self._can_subset = can_subset - - self._backfill_policy = check.opt_inst_param( - backfill_policy, "backfill_policy", BackfillPolicy - ) - if self._partitions_def is None: # check if backfill policy is BackfillPolicyType.SINGLE_RUN if asset is not partitioned check.param_invariant( @@ -227,8 +263,6 @@ def __init__( "Non partitioned asset can only have single run backfill policy", ) - self._is_subset = check.bool_param(is_subset, "is_subset") - if specs is not None: check.invariant(group_names_by_key is None) check.invariant(metadata_by_key is None) @@ -242,7 +276,10 @@ def __init__( resolved_specs = specs else: - all_asset_keys = set(self._keys_by_output_name.values()) + computation_not_none = check.not_none( + self._computation, "If specs are not provided, a node_def must be provided" + ) + all_asset_keys = set(computation_not_none.keys_by_output_name.values()) if asset_deps: check.invariant( @@ -256,7 +293,7 @@ def __init__( if partition_mappings: _validate_partition_mappings( partition_mappings=partition_mappings, - input_asset_keys=set(self._keys_by_input_name.values()), + input_asset_keys=set(computation_not_none.keys_by_input_name.values()), all_asset_keys=all_asset_keys, ) @@ -264,7 +301,7 @@ def __init__( resolved_specs = _asset_specs_from_attr_key_params( all_asset_keys=all_asset_keys, - keys_by_input_name=self._keys_by_input_name, + keys_by_input_name=computation_not_none.keys_by_input_name, deps_by_asset_key=asset_deps, partition_mappings=partition_mappings, tags_by_key=tags_by_key, @@ -278,7 +315,6 @@ def __init__( ) normalized_specs: List[AssetSpec] = [] - output_names_by_key = {key: name for name, key in self._keys_by_output_name.items()} for spec in resolved_specs: if spec.owners: @@ -287,11 +323,11 @@ def __init__( group_name = normalize_group_name(spec.group_name) - if node_def is not None: - output_def, _ = node_def.resolve_output_to_origin( - output_names_by_key[spec.key], None + if self._computation is not None: + output_def, _ = self._computation.node_def.resolve_output_to_origin( + self._computation.output_names_by_key[spec.key], None ) - node_def_description = node_def.description + node_def_description = self._computation.node_def.description output_def_metadata = output_def.metadata output_def_description = output_def.description output_def_code_version = output_def.code_version @@ -336,32 +372,25 @@ def __init__( [dep for spec in normalized_specs for dep in spec.deps], node_def.name if node_def else "external assets", ) - self._selected_asset_keys, self._selected_asset_check_keys = _resolve_selections( - all_asset_keys=self._specs_by_key.keys(), - all_check_keys={spec.key for spec in (check_specs_by_output_name or {}).values()}, - selected_asset_keys=selected_asset_keys, - selected_asset_check_keys=selected_asset_check_keys, - ) self._check_specs_by_key = { - spec.key: spec - for spec in self._check_specs_by_output_name.values() - if spec.key in self._selected_asset_check_keys + spec.key: spec for spec in self._check_specs_by_output_name.values() } - _validate_self_deps( - input_keys=[ - key - # filter out the special inputs which are used for cases when a multi-asset is - # subsetted, as these are not the same as self-dependencies and are never loaded - # in the same step that their corresponding output is produced - for input_name, key in self._keys_by_input_name.items() - if not input_name.startswith(ASSET_SUBSET_INPUT_PREFIX) - ], - output_keys=self._selected_asset_keys, - partition_mappings=self._partition_mappings, - partitions_def=self._partitions_def, - ) + if self._computation: + _validate_self_deps( + input_keys=[ + key + # filter out the special inputs which are used for cases when a multi-asset is + # subsetted, as these are not the same as self-dependencies and are never loaded + # in the same step that their corresponding output is produced + for input_name, key in self._computation.keys_by_input_name.items() + if not input_name.startswith(ASSET_SUBSET_INPUT_PREFIX) + ], + output_keys=self._computation.selected_asset_keys, + partition_mappings=self._partition_mappings, + partitions_def=self._partitions_def, + ) def dagster_internal_init( *, @@ -400,8 +429,10 @@ def __call__(self, *args: object, **kwargs: object) -> object: from .graph_definition import GraphDefinition # defer to GraphDefinition.__call__ for graph backed assets, or if invoked in composition - if isinstance(self._node_def, GraphDefinition) or is_in_composition(): - return check.not_none(self._node_def)(*args, **kwargs) + if ( + self._computation and isinstance(self._computation.node_def, GraphDefinition) + ) or is_in_composition(): + return self.node_def(*args, **kwargs) # invoke against self to allow assets def information to be used return direct_invocation_result(self, *args, **kwargs) @@ -769,7 +800,7 @@ def can_subset(self) -> bool: asset keys in a given computation (as opposed to being required to materialize all asset keys). """ - return self._can_subset + return self._computation.can_subset if self._computation else False @property def specs(self) -> Iterable[AssetSpec]: @@ -807,11 +838,12 @@ def op(self) -> OpDefinition: """OpDefinition: Returns the OpDefinition that is used to materialize the assets in this AssetsDefinition. """ + node_def = self.node_def check.invariant( - isinstance(self._node_def, OpDefinition), + isinstance(node_def, OpDefinition), "The NodeDefinition for this AssetsDefinition is not of type OpDefinition.", ) - return cast(OpDefinition, self._node_def) + return cast(OpDefinition, node_def) @public @property @@ -819,7 +851,7 @@ def node_def(self) -> NodeDefinition: """NodeDefinition: Returns the OpDefinition or GraphDefinition that is used to materialize the assets in this AssetsDefinition. """ - return check.not_none(self._node_def, "This AssetsDefinition has no node_def") + return check.not_none(self._computation, "This AssetsDefinition has no node_def").node_def @public @property @@ -849,7 +881,7 @@ def key(self) -> AssetKey: check.invariant( len(self.keys) == 1, "Tried to retrieve asset key from an assets definition with multiple asset keys: " - + ", ".join([str(ak.to_string()) for ak in self._keys_by_output_name.values()]), + + ", ".join([str(ak.to_string()) for ak in self.keys]), ) return next(iter(self.keys)) @@ -866,7 +898,10 @@ def resource_defs(self) -> Mapping[str, ResourceDefinition]: @property def keys(self) -> AbstractSet[AssetKey]: """AbstractSet[AssetKey]: The asset keys associated with this AssetsDefinition.""" - return self._selected_asset_keys + if self._computation: + return self._computation.selected_asset_keys + else: + return self._specs_by_key.keys() @property def has_keys(self) -> bool: @@ -883,19 +918,17 @@ def dependency_keys(self) -> Iterable[AssetKey]: AssetsDefinition. """ # the input asset keys that are directly upstream of a selected asset key - upstream_keys = {dep_key for key in self.keys for dep_key in self.asset_deps[key]} - input_keys = set(self._keys_by_input_name.values()) - return upstream_keys.intersection(input_keys) + return {dep_key for key in self.keys for dep_key in self.asset_deps[key]} @property def node_keys_by_output_name(self) -> Mapping[str, AssetKey]: """AssetKey for each output on the underlying NodeDefinition.""" - return self._keys_by_output_name + return self._computation.keys_by_output_name if self._computation else {} @property def node_keys_by_input_name(self) -> Mapping[str, AssetKey]: """AssetKey for each input on the underlying NodeDefinition.""" - return self._keys_by_input_name + return self._computation.keys_by_input_name if self._computation else {} @property def node_check_specs_by_output_name(self) -> Mapping[str, AssetCheckSpec]: @@ -907,7 +940,7 @@ def check_specs_by_output_name(self) -> Mapping[str, AssetCheckSpec]: return { name: spec for name, spec in self._check_specs_by_output_name.items() - if spec.key in self._selected_asset_check_keys + if self._computation is None or spec.key in self._computation.selected_asset_check_keys } def get_spec_for_check_key(self, asset_check_key: AssetCheckKey) -> AssetCheckSpec: @@ -990,7 +1023,7 @@ def _get_external_asset_metadata_value(self, metadata_key: str) -> object: @property def backfill_policy(self) -> Optional[BackfillPolicy]: - return self._backfill_policy + return self._computation.backfill_policy if self._computation else None @public @property @@ -1048,7 +1081,11 @@ def check_keys(self) -> AbstractSet[AssetCheckKey]: AbstractSet[Tuple[AssetKey, str]]: The selected asset checks. An asset check is identified by the asset key and the name of the check. """ - return self._selected_asset_check_keys + if self._computation: + return self._computation.selected_asset_check_keys + else: + check.invariant(not self._check_specs_by_output_name) + return set() @property def check_key(self) -> AssetCheckKey: @@ -1062,7 +1099,7 @@ def check_key(self) -> AssetCheckKey: @property def execution_type(self) -> AssetExecutionType: - if self._node_def is None: + if self._computation is None: return AssetExecutionType.UNEXECUTABLE value = self._get_external_asset_metadata_value(SYSTEM_METADATA_KEY_ASSET_EXECUTION_TYPE) @@ -1123,7 +1160,7 @@ def get_op_def_for_asset_key(self, key: AssetKey) -> Optional[OpDefinition]: """If this is an op-backed asset, returns the op def. If it's a graph-backed asset, returns the op def within the graph that produces the given asset key. """ - if self._node_def is None: + if self._computation is None: return None output_name = self.get_output_name_for_asset_key(key) @@ -1224,15 +1261,13 @@ def update_replace_dict_and_conflicts( replaced_attributes = dict( keys_by_input_name={ input_name: input_asset_key_replacements.get(key, key) - for input_name, key in self._keys_by_input_name.items() + for input_name, key in self.node_keys_by_input_name.items() }, keys_by_output_name={ output_name: output_asset_key_replacements.get(key, key) - for output_name, key in self._keys_by_output_name.items() - }, - selected_asset_keys={ - output_asset_key_replacements.get(key, key) for key in self._selected_asset_keys + for output_name, key in self.node_keys_by_output_name.items() }, + selected_asset_keys={output_asset_key_replacements.get(key, key) for key in self.keys}, backfill_policy=backfill_policy if backfill_policy else self.backfill_policy, is_subset=self.is_subset, check_specs_by_output_name=check_specs_by_output_name, @@ -1360,7 +1395,7 @@ def subset_for( @property def is_subset(self) -> bool: - return self._is_subset + return self._computation.is_subset if self._computation else False @public def to_source_assets(self) -> Sequence[SourceAsset]: @@ -1416,7 +1451,7 @@ def _output_to_source_asset(self, output_name: str) -> SourceAsset: output_def = self.node_def.resolve_output_to_origin( output_name, NodeHandle(self.node_def.name, parent=None) )[0] - key = self._keys_by_output_name[output_name] + key = self.node_keys_by_output_name[output_name] spec = self.specs_by_key[key] return SourceAsset( @@ -1431,7 +1466,7 @@ def _output_to_source_asset(self, output_name: str) -> SourceAsset: ) def get_io_manager_key_for_asset_key(self, key: AssetKey) -> str: - if self._node_def is None: + if self._computation is None: return self._specs_by_key[key].metadata.get( SYSTEM_METADATA_KEY_IO_MANAGER_KEY, DEFAULT_IO_MANAGER_KEY ) @@ -1488,16 +1523,16 @@ def with_resources(self, resource_defs: Mapping[str, ResourceDefinition]) -> "As def get_attributes_dict(self) -> Dict[str, Any]: return dict( - keys_by_input_name=self._keys_by_input_name, - keys_by_output_name=self._keys_by_output_name, - node_def=self._node_def, + keys_by_input_name=self.node_keys_by_input_name, + keys_by_output_name=self.node_keys_by_output_name, + node_def=self._computation.node_def if self._computation else None, partitions_def=self._partitions_def, - selected_asset_keys=self._selected_asset_keys, - can_subset=self._can_subset, + selected_asset_keys=self.keys, + can_subset=self.can_subset, resource_defs=self._resource_defs, - backfill_policy=self._backfill_policy, + backfill_policy=self.backfill_policy, check_specs_by_output_name=self._check_specs_by_output_name, - selected_asset_check_keys=self._selected_asset_check_keys, + selected_asset_check_keys=self.check_keys, specs=self.specs, is_subset=self.is_subset, ) diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets.py b/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets.py index 14311047e928b..2c54e1dbedd9b 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets.py @@ -2179,7 +2179,7 @@ def op1(): def test_construct_assets_definition_no_args() -> None: - with pytest.raises(CheckError, match="Must provide node_def if not providing specs"): + with pytest.raises(CheckError, match="If specs are not provided, a node_def must be provided"): AssetsDefinition()