def stringify_asset_key_to_input_name(asset_key: AssetKey) -> str:
    """Build an input name string from an asset key.

    The components of ``asset_key.path`` are joined with underscores, and every hyphen in the
    joined string is then replaced with an underscore.
    """
    joined_path = "_".join(asset_key.path)
    return joined_path.replace("-", "_")
@property
def input_names_by_node_key(self) -> Mapping[AssetKey, str]:
    """Input name on the underlying NodeDefinition for each AssetKey.

    This is ``node_keys_by_input_name`` with keys and values swapped. If several input names
    map to the same key, the one that comes last in iteration order is kept.
    """
    inverted = {}
    for name, asset_key in self.node_keys_by_input_name.items():
        inverted[asset_key] = name
    return inverted
def replace_specs_on_asset(
    assets_def: AssetsDefinition, replaced_specs: Sequence[AssetSpec]
) -> "AssetsDefinition":
    """Return a copy of ``assets_def`` whose specs are ``replaced_specs``.

    If the asset is not executable, or the union of deps across ``replaced_specs`` equals the
    union of deps across the current specs, only the specs are swapped. Otherwise the underlying
    op is rebuilt so that its inputs reflect the new dependency structure: inputs of surviving
    deps are kept as they are, and each newly added dep gets a ``Nothing``-typed input.

    Args:
        assets_def: The definition whose specs are being replaced.
        replaced_specs: The specs to install on the returned definition.

    Returns:
        A new definition, built with the same class as ``assets_def``.

    Raises:
        CheckError: If the dependency structure changed and the asset is not op-backed, or if a
            removed dep is backed by an input whose dagster type is not ``Nothing`` (i.e. an
            argument-backed dep).
    """
    from dagster._builtins import Nothing
    from dagster._core.definitions.input import In

    new_deps = set().union(*(spec.deps for spec in replaced_specs))
    previous_deps = set().union(*(spec.deps for spec in assets_def.specs))
    added_deps = new_deps - previous_deps
    removed_deps = previous_deps - new_deps

    # If there are no changes to the dependency structure, we don't need to make any changes to
    # the underlying node.
    if not assets_def.is_executable or (not added_deps and not removed_deps):
        return assets_def.__class__.dagster_internal_init(
            **{**assets_def.get_attributes_dict(), "specs": replaced_specs}
        )

    # Otherwise, there are changes to the dependency structure. We need to update the node_def.
    # Graph-backed assets do not currently support non-argument-based deps. Every argument to a
    # graph-backed asset must map to an input on an internal asset node in the graph structure.
    # IMPROVEME BUILD-529
    check.invariant(
        isinstance(assets_def.node_def, OpDefinition),
        "Can only add additional deps to an op-backed asset.",
    )

    # For each deleted dep, we need to make sure it is not an argument-based dep. Argument-based
    # deps cannot be removed.
    original_key_to_input_mapping = reverse_dict(assets_def.node_keys_by_input_name)
    for dep in removed_deps:
        input_name = original_key_to_input_mapping[dep.asset_key]
        input_def = assets_def.node_def.input_def_named(input_name)
        check.invariant(
            input_def.dagster_type.is_nothing,
            f"Attempted to remove argument-backed dependency {dep.asset_key} (mapped to argument {input_name}) from the asset. Only non-argument dependencies can be changed or removed using map_asset_specs.",
        )

    # Deps present both before and after the mapping keep their existing inputs. The input
    # definitions are converted to `In` objects because `ins` is a mapping of `In`s.
    remaining_keys = {dep.asset_key for dep in previous_deps & new_deps}
    remaining_ins = {
        input_name: In.from_definition(input_def)
        for input_name, input_def in assets_def.node_def.input_dict.items()
        if assets_def.node_keys_by_input_name[input_name] in remaining_keys
    }
    # Only deps that were actually added get a fresh Nothing-typed input. Generating one for
    # every dep would replace the typed input of a surviving argument-backed dep with a Nothing
    # input (or add a second input for the same key when the input name differs from the
    # stringified key).
    added_ins = {
        stringify_asset_key_to_input_name(dep.asset_key): In(dagster_type=Nothing)
        for dep in added_deps
    }
    all_ins = merge_dicts(remaining_ins, added_ins)

    return assets_def.__class__.dagster_internal_init(
        **{
            **assets_def.get_attributes_dict(),
            "node_def": assets_def.op.with_replaced_properties(
                name=assets_def.op.name, ins=all_ins
            ),
            "specs": replaced_specs,
        }
    )
def reverse_dict(d: Mapping[V, K]) -> Dict[K, V]:
    """Return a new dictionary that maps each value of ``d`` back to its key.

    When several keys of ``d`` share the same value, the key that comes last in iteration order
    is the one kept in the result.
    """
    check.dict_param(d, "d")
    swapped = {}
    for original_key, original_value in d.items():
        swapped[original_value] = original_key
    return swapped
def test_map_asset_specs_nonarg_dep_removal() -> None:
    """Removing a dep that is not a function argument also removes the op input backing it."""

    @dg.multi_asset(specs=[AssetSpec(key="a", deps=[AssetDep("b")])])
    def my_asset():
        pass

    def _drop_all_deps(spec):
        return spec.replace_attributes(deps=[])

    mapped_defs = dg.map_asset_specs(_drop_all_deps, [my_asset])
    mapped_asset = next(iter(mapped_defs))
    mapped_spec = next(iter(mapped_asset.specs))
    assert mapped_spec.deps == []

    # Ensure that dep removal propagated to the underlying op
    assert mapped_asset.keys_by_input_name == {}
    assert len(mapped_asset.op.input_defs) == 0
def test_graph_backed_asset_additional_deps() -> None:
    """Adding a dep to a graph-backed asset through map_asset_specs raises a CheckError."""

    @dg.op
    def foo_op():
        pass

    @dg.graph_asset()
    def foo():
        return foo_op()

    def _add_baz_dep(spec):
        return spec.merge_attributes(deps=["baz"])

    with pytest.raises(CheckError):
        dg.map_asset_specs(_add_baz_dep, [foo])