Skip to content

Commit

Permalink
[module-loaders] Further simplify asset checks path (#26527)
Browse files Browse the repository at this point in the history
## Summary & Motivation
We can further simplify the load_asset_from_x code path by making the
asset checks loader return AssetsDefinitions, and changing the
assertions to just make sure that the returned assets only contain
checks.

## How I Tested These Changes
Altered existing tests
  • Loading branch information
dpeng817 committed Dec 19, 2024
1 parent bc06a99 commit 26d808d
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 27 deletions.
46 changes: 37 additions & 9 deletions python_modules/dagster/dagster/_core/definitions/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from dagster._utils.warnings import ExperimentalWarning, disable_dagster_warnings

if TYPE_CHECKING:
from dagster._core.definitions.asset_checks import AssetChecksDefinition
from dagster._core.definitions.graph_definition import GraphDefinition

ASSET_SUBSET_INPUT_PREFIX = "__subset_input__"
Expand Down Expand Up @@ -1180,6 +1181,30 @@ def get_op_def_for_asset_key(self, key: AssetKey) -> Optional[OpDefinition]:
output_name = self.get_output_name_for_asset_key(key)
return self.node_def.resolve_output_to_origin_op_def(output_name)

def coerce_to_checks_def(self) -> "AssetChecksDefinition":
from dagster._core.definitions.asset_checks import (
AssetChecksDefinition,
has_only_asset_checks,
)

if not has_only_asset_checks(self):
raise DagsterInvalidDefinitionError(
"Cannot coerce an AssetsDefinition to an AssetChecksDefinition if it contains "
"non-check assets."
)
if len(self.check_keys) == 0:
raise DagsterInvalidDefinitionError(
"Cannot coerce an AssetsDefinition to an AssetChecksDefinition if it contains no "
"checks."
)
return AssetChecksDefinition.create(
keys_by_input_name=self.keys_by_input_name,
node_def=self.op,
check_specs_by_output_name=self.check_specs_by_output_name,
resource_defs=self.resource_defs,
can_subset=self.can_subset,
)

def with_attributes(
self,
*,
Expand Down Expand Up @@ -1903,15 +1928,18 @@ def replace_specs_on_asset(
from dagster._builtins import Nothing
from dagster._core.definitions.input import In

new_deps = set().union(*(spec.deps for spec in replaced_specs))
previous_deps = set().union(*(spec.deps for spec in assets_def.specs))
added_deps = new_deps - previous_deps
removed_deps = previous_deps - new_deps
remaining_original_deps = previous_deps - removed_deps
new_deps_by_key = {dep.asset_key: dep for spec in replaced_specs for dep in spec.deps}
previous_deps_by_key = {dep.asset_key: dep for spec in assets_def.specs for dep in spec.deps}
added_dep_keys = set(new_deps_by_key.keys()) - set(previous_deps_by_key.keys())
removed_dep_keys = set(previous_deps_by_key.keys()) - set(new_deps_by_key.keys())
remaining_original_deps_by_key = {
key: previous_deps_by_key[key]
for key in set(previous_deps_by_key.keys()) - removed_dep_keys
}
original_key_to_input_mapping = reverse_dict(assets_def.node_keys_by_input_name)

# If there are no changes to the dependency structure, we don't need to make any changes to the underlying node.
if not assets_def.is_executable or (not added_deps and not removed_deps):
if not assets_def.is_executable or (not added_dep_keys and not removed_dep_keys):
return assets_def.__class__.dagster_internal_init(
**{**assets_def.get_attributes_dict(), "specs": replaced_specs}
)
Expand All @@ -1925,15 +1953,15 @@ def replace_specs_on_asset(
"Can only add additional deps to an op-backed asset.",
)
# for each deleted dep, we need to make sure it is not an argument-based dep. Argument-based deps cannot be removed.
for dep in removed_deps:
for dep_key in removed_dep_keys:
dep = previous_deps_by_key[dep_key]
input_name = original_key_to_input_mapping[dep.asset_key]
input_def = assets_def.node_def.input_def_named(input_name)
check.invariant(
input_def.dagster_type.is_nothing,
f"Attempted to remove argument-backed dependency {dep.asset_key} (mapped to argument {input_name}) from the asset. Only non-argument dependencies can be changed or removed using map_asset_specs.",
)

remaining_original_deps_by_key = {dep.asset_key: dep for dep in remaining_original_deps}
remaining_ins = {
input_name: the_in
for input_name, the_in in assets_def.node_def.input_dict.items()
Expand All @@ -1943,7 +1971,7 @@ def replace_specs_on_asset(
remaining_ins,
{
stringify_asset_key_to_input_name(dep.asset_key): In(dagster_type=Nothing)
for dep in new_deps
for dep in new_deps_by_key.values()
},
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -446,15 +446,11 @@ def replace_keys_in_asset(
key: key_replacements.get(key, key) for key in asset.keys_by_input_name.values()
},
)
if isinstance(asset, AssetsDefinition) and has_only_asset_checks(asset):
updated_object = AssetChecksDefinition.create(
keys_by_input_name=updated_object.keys_by_input_name,
node_def=updated_object.op,
check_specs_by_output_name=updated_object.check_specs_by_output_name,
resource_defs=updated_object.resource_defs,
can_subset=updated_object.can_subset,
)
return updated_object
return (
updated_object.coerce_to_checks_def()
if has_only_asset_checks(updated_object)
else updated_object
)


class ResolvedAssetObjectList:
Expand Down Expand Up @@ -564,15 +560,11 @@ def with_attributes(
).with_attributes(
backfill_policy=backfill_policy, freshness_policy=freshness_policy
)
if isinstance(asset, AssetChecksDefinition):
new_asset = AssetChecksDefinition.create(
keys_by_input_name=new_asset.keys_by_input_name,
node_def=new_asset.op,
check_specs_by_output_name=new_asset.check_specs_by_output_name,
resource_defs=new_asset.resource_defs,
can_subset=new_asset.can_subset,
)
return_list.append(new_asset)
return_list.append(
new_asset.coerce_to_checks_def()
if has_only_asset_checks(new_asset)
else new_asset
)
elif isinstance(asset, SourceAsset):
return_list.append(
asset.with_attributes(group_name=group_name if group_name else asset.group_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,44 @@ def foo():

with pytest.raises(CheckError):
dg.map_asset_specs(lambda spec: spec.merge_attributes(deps=["baz"]), [foo])


def test_static_partition_mapping_dep() -> None:
@dg.asset(partitions_def=dg.StaticPartitionsDefinition(["1", "2"]))
def b():
pass

@dg.multi_asset(
specs=[
AssetSpec(
key="a",
partitions_def=dg.StaticPartitionsDefinition(["1", "2"]),
deps=[
AssetDep("b", partition_mapping=dg.StaticPartitionMapping({"1": "1", "2": "2"}))
],
)
]
)
def my_asset():
pass

a_asset = next(
iter(
dg.map_asset_specs(
lambda spec: spec.merge_attributes(
deps=[
AssetDep(
"c", partition_mapping=dg.StaticPartitionMapping({"1": "1", "2": "2"})
)
]
),
[my_asset],
)
)
)

a_spec = next(iter(a_asset.specs))
b_dep = next(iter(dep for dep in a_spec.deps if dep.asset_key == AssetKey("b")))
c_dep = next(iter(dep for dep in a_spec.deps if dep.asset_key == AssetKey("c")))
assert b_dep.partition_mapping == dg.StaticPartitionMapping({"1": "1", "2": "2"})
assert c_dep.partition_mapping == dg.StaticPartitionMapping({"1": "1", "2": "2"})

0 comments on commit 26d808d

Please sign in to comment.