diff --git a/python_modules/dagster/dagster/_core/definitions/assets.py b/python_modules/dagster/dagster/_core/definitions/assets.py index 064507025685d..e273bdaa86c65 100644 --- a/python_modules/dagster/dagster/_core/definitions/assets.py +++ b/python_modules/dagster/dagster/_core/definitions/assets.py @@ -80,6 +80,7 @@ from dagster._utils.warnings import ExperimentalWarning, disable_dagster_warnings if TYPE_CHECKING: + from dagster._core.definitions.asset_checks import AssetChecksDefinition from dagster._core.definitions.graph_definition import GraphDefinition ASSET_SUBSET_INPUT_PREFIX = "__subset_input__" @@ -1180,6 +1181,30 @@ def get_op_def_for_asset_key(self, key: AssetKey) -> Optional[OpDefinition]: output_name = self.get_output_name_for_asset_key(key) return self.node_def.resolve_output_to_origin_op_def(output_name) + def coerce_to_checks_def(self) -> "AssetChecksDefinition": + from dagster._core.definitions.asset_checks import ( + AssetChecksDefinition, + has_only_asset_checks, + ) + + if not has_only_asset_checks(self): + raise DagsterInvalidDefinitionError( + "Cannot coerce an AssetsDefinition to an AssetChecksDefinition if it contains " + "non-check assets." + ) + if len(self.check_keys) == 0: + raise DagsterInvalidDefinitionError( + "Cannot coerce an AssetsDefinition to an AssetChecksDefinition if it contains no " + "checks." + ) + return AssetChecksDefinition.create( + keys_by_input_name=self.keys_by_input_name, + node_def=self.op, + check_specs_by_output_name=self.check_specs_by_output_name, + resource_defs=self.resource_defs, + can_subset=self.can_subset, + ) + def with_attributes( self, *, @@ -1903,15 +1928,18 @@ def replace_specs_on_asset( from dagster._builtins import Nothing from dagster._core.definitions.input import In - new_deps = set().union(*(spec.deps for spec in replaced_specs)) - previous_deps = set().union(*(spec.deps for spec in assets_def.specs)) - added_deps = new_deps - previous_deps - removed_deps = previous_deps - new_deps - remaining_original_deps = previous_deps - removed_deps + new_deps_by_key = {dep.asset_key: dep for spec in replaced_specs for dep in spec.deps} + previous_deps_by_key = {dep.asset_key: dep for spec in assets_def.specs for dep in spec.deps} + added_dep_keys = set(new_deps_by_key.keys()) - set(previous_deps_by_key.keys()) + removed_dep_keys = set(previous_deps_by_key.keys()) - set(new_deps_by_key.keys()) + remaining_original_deps_by_key = { + key: previous_deps_by_key[key] + for key in set(previous_deps_by_key.keys()) - removed_dep_keys + } original_key_to_input_mapping = reverse_dict(assets_def.node_keys_by_input_name) # If there are no changes to the dependency structure, we don't need to make any changes to the underlying node. - if not assets_def.is_executable or (not added_deps and not removed_deps): + if not assets_def.is_executable or (not added_dep_keys and not removed_dep_keys): return assets_def.__class__.dagster_internal_init( **{**assets_def.get_attributes_dict(), "specs": replaced_specs} ) @@ -1925,7 +1953,8 @@ def replace_specs_on_asset( "Can only add additional deps to an op-backed asset.", ) # for each deleted dep, we need to make sure it is not an argument-based dep. Argument-based deps cannot be removed. - for dep in removed_deps: + for dep_key in removed_dep_keys: + dep = previous_deps_by_key[dep_key] input_name = original_key_to_input_mapping[dep.asset_key] input_def = assets_def.node_def.input_def_named(input_name) check.invariant( @@ -1933,7 +1962,6 @@ def replace_specs_on_asset( f"Attempted to remove argument-backed dependency {dep.asset_key} (mapped to argument {input_name}) from the asset. Only non-argument dependencies can be changed or removed using map_asset_specs.", ) - remaining_original_deps_by_key = {dep.asset_key: dep for dep in remaining_original_deps} remaining_ins = { input_name: the_in for input_name, the_in in assets_def.node_def.input_dict.items() @@ -1943,7 +1971,7 @@ def replace_specs_on_asset( remaining_ins, { stringify_asset_key_to_input_name(dep.asset_key): In(dagster_type=Nothing) - for dep in new_deps + for dep in new_deps_by_key.values() }, ) diff --git a/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py b/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py index 63cd3730bff43..bdec14601d014 100644 --- a/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py +++ b/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py @@ -446,15 +446,11 @@ def replace_keys_in_asset( key: key_replacements.get(key, key) for key in asset.keys_by_input_name.values() }, ) - if isinstance(asset, AssetsDefinition) and has_only_asset_checks(asset): - updated_object = AssetChecksDefinition.create( - keys_by_input_name=updated_object.keys_by_input_name, - node_def=updated_object.op, - check_specs_by_output_name=updated_object.check_specs_by_output_name, - resource_defs=updated_object.resource_defs, - can_subset=updated_object.can_subset, - ) - return updated_object + return ( + updated_object.coerce_to_checks_def() + if has_only_asset_checks(updated_object) + else updated_object + ) class ResolvedAssetObjectList: @@ -564,15 +560,11 @@ def with_attributes( ).with_attributes( backfill_policy=backfill_policy, freshness_policy=freshness_policy ) - if isinstance(asset, AssetChecksDefinition): - new_asset = AssetChecksDefinition.create( - keys_by_input_name=new_asset.keys_by_input_name, - node_def=new_asset.op, - check_specs_by_output_name=new_asset.check_specs_by_output_name, - resource_defs=new_asset.resource_defs, - can_subset=new_asset.can_subset, - ) - return_list.append(new_asset) + return_list.append( + new_asset.coerce_to_checks_def() + if has_only_asset_checks(new_asset) + else new_asset + ) elif isinstance(asset, SourceAsset): return_list.append( asset.with_attributes(group_name=group_name if group_name else asset.group_name) diff --git a/python_modules/dagster/dagster_tests/definitions_tests/test_asset_spec.py b/python_modules/dagster/dagster_tests/definitions_tests/test_asset_spec.py index a47062ec3a623..5e8643905acc8 100644 --- a/python_modules/dagster/dagster_tests/definitions_tests/test_asset_spec.py +++ b/python_modules/dagster/dagster_tests/definitions_tests/test_asset_spec.py @@ -357,3 +357,44 @@ def foo(): with pytest.raises(CheckError): dg.map_asset_specs(lambda spec: spec.merge_attributes(deps=["baz"]), [foo]) + + +def test_static_partition_mapping_dep() -> None: + @dg.asset(partitions_def=dg.StaticPartitionsDefinition(["1", "2"])) + def b(): + pass + + @dg.multi_asset( + specs=[ + AssetSpec( + key="a", + partitions_def=dg.StaticPartitionsDefinition(["1", "2"]), + deps=[ + AssetDep("b", partition_mapping=dg.StaticPartitionMapping({"1": "1", "2": "2"})) + ], + ) + ] + ) + def my_asset(): + pass + + a_asset = next( + iter( + dg.map_asset_specs( + lambda spec: spec.merge_attributes( + deps=[ + AssetDep( + "c", partition_mapping=dg.StaticPartitionMapping({"1": "1", "2": "2"}) + ) + ] + ), + [my_asset], + ) + ) + ) + + a_spec = next(iter(a_asset.specs)) + b_dep = next(iter(dep for dep in a_spec.deps if dep.asset_key == AssetKey("b"))) + c_dep = next(iter(dep for dep in a_spec.deps if dep.asset_key == AssetKey("c"))) + assert b_dep.partition_mapping == dg.StaticPartitionMapping({"1": "1", "2": "2"}) + assert c_dep.partition_mapping == dg.StaticPartitionMapping({"1": "1", "2": "2"})