Skip to content

Commit

Permalink
[external-assets] Update callsites for partition property asset graph…
Browse files Browse the repository at this point in the history
… accessors (#20330)

## Summary & Motivation

Internal companion PR: dagster-io/internal#8615

This is the second round of method deletions and callsite updates to adopt the new
node-based `AssetGraph` APIs. It removes the following methods from `AssetGraph`:

- `get_partitions_def`
- `is_partitioned`
- `have_same_partitioning`
- `have_same_or_no_partitioning`
- `get_child_nodes` (`get_children`, which previously returned keys, now
fulfills this function)
- `get_parent_nodes` (`get_parents`, which previously returned keys, now
fulfills this function)
- `get_execution_set_asset_keys`
- `has_self_dependency`

## How I Tested These Changes

Existing test suite.
  • Loading branch information
smackesey authored Mar 8, 2024
1 parent e3381e2 commit 01715ac
Show file tree
Hide file tree
Showing 24 changed files with 216 additions and 246 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_asset_backfill_preview(
asset_partitions = []

for asset_key in asset_backfill_data.get_targeted_asset_keys_topological_order(asset_graph):
if asset_graph.get_partitions_def(asset_key):
if asset_graph.get(asset_key).partitions_def:
partitions_subset = asset_backfill_data.target_subset.partitions_subsets_by_asset_key[
asset_key
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -603,12 +603,12 @@ def test_launch_asset_backfill_with_upstream_anchor_asset():
asset_graph = repo.asset_graph
assert target_subset == AssetGraphSubset(
partitions_subsets_by_asset_key={
AssetKey("hourly"): asset_graph.get_partitions_def(
AssetKey("hourly"): asset_graph.get(
AssetKey("hourly")
).subset_with_partition_keys(hourly_partitions),
AssetKey("daily"): asset_graph.get_partitions_def(
).partitions_def.subset_with_partition_keys(hourly_partitions),
AssetKey("daily"): asset_graph.get(
AssetKey("daily")
).subset_with_partition_keys(["2020-01-02", "2020-01-03"]),
).partitions_def.subset_with_partition_keys(["2020-01-02", "2020-01-03"]),
},
)

Expand Down Expand Up @@ -668,15 +668,15 @@ def test_launch_asset_backfill_with_two_anchor_assets():
asset_graph = repo.asset_graph
assert target_subset == AssetGraphSubset(
partitions_subsets_by_asset_key={
AssetKey("hourly1"): asset_graph.get_partitions_def(
AssetKey("hourly1"): asset_graph.get(
AssetKey("hourly1")
).subset_with_partition_keys(hourly_partitions),
AssetKey("hourly2"): asset_graph.get_partitions_def(
).partitions_def.subset_with_partition_keys(hourly_partitions),
AssetKey("hourly2"): asset_graph.get(
AssetKey("hourly2")
).subset_with_partition_keys(hourly_partitions),
AssetKey("daily"): asset_graph.get_partitions_def(
).partitions_def.subset_with_partition_keys(hourly_partitions),
AssetKey("daily"): asset_graph.get(
AssetKey("daily")
).subset_with_partition_keys(["2020-01-02", "2020-01-03"]),
).partitions_def.subset_with_partition_keys(["2020-01-02", "2020-01-03"]),
},
)

Expand Down Expand Up @@ -724,13 +724,13 @@ def test_launch_asset_backfill_with_upstream_anchor_asset_and_non_partitioned_as
non_partitioned_asset_keys={AssetKey("non_partitioned")},
partitions_subsets_by_asset_key={
AssetKey("hourly"): (
asset_graph.get_partitions_def(AssetKey("hourly"))
.empty_subset()
asset_graph.get(AssetKey("hourly"))
.partitions_def.empty_subset()
.with_partition_keys(hourly_partitions)
),
AssetKey("daily"): (
asset_graph.get_partitions_def(AssetKey("daily"))
.empty_subset()
asset_graph.get(AssetKey("daily"))
.partitions_def.empty_subset()
.with_partition_keys(["2020-01-02", "2020-01-03"])
),
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def _mock_asset_backfill_runs(
status: DagsterRunStatus,
partition_key: Optional[str],
):
partitions_def = asset_graph.get_partitions_def(asset_key)
partitions_def = asset_graph.get(asset_key).partitions_def

@asset(
partitions_def=partitions_def,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def create(
evaluation_state_by_key: Mapping[AssetKey, "AssetConditionEvaluationState"],
expected_data_time_mapping: Mapping[AssetKey, Optional[datetime.datetime]],
) -> "AssetConditionEvaluationContext":
partitions_def = instance_queryer.asset_graph.get_partitions_def(asset_key)
partitions_def = instance_queryer.asset_graph.get(asset_key).partitions_def

return AssetConditionEvaluationContext(
asset_key=asset_key,
Expand Down Expand Up @@ -136,7 +136,7 @@ def asset_graph(self) -> BaseAssetGraph:

@property
def partitions_def(self) -> Optional[PartitionsDefinition]:
return self.asset_graph.get_partitions_def(self.asset_key)
return self.asset_graph.get(self.asset_key).partitions_def

@property
def evaluation_time(self) -> datetime.datetime:
Expand Down Expand Up @@ -190,7 +190,7 @@ def parent_will_update_subset(self) -> ValidAssetSubset:
can be materialized in the same run as this asset.
"""
subset = self.empty_subset()
for parent_key in self.asset_graph.get_parents(self.asset_key):
for parent_key in self.asset_graph.get(self.asset_key).parent_keys:
if not self.materializable_in_same_run(self.asset_key, parent_key):
continue
parent_info = self.evaluation_state_by_key.get(parent_key)
Expand Down Expand Up @@ -302,17 +302,19 @@ def materializable_in_same_run(self, child_key: AssetKey, parent_key: AssetKey)
"""Returns whether a child asset can be materialized in the same run as a parent asset."""
from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph

child_node = self.asset_graph.get(child_key)
parent_node = self.asset_graph.get(parent_key)
return (
# both assets must be materializable
child_key in self.asset_graph.materializable_asset_keys
and parent_key in self.asset_graph.materializable_asset_keys
child_node.is_materializable
and parent_node.is_materializable
# the parent must have the same partitioning
and self.asset_graph.have_same_partitioning(child_key, parent_key)
and child_node.partitions_def == parent_node.partitions_def
# the parent must have a simple partition mapping to the child
and (
not self.asset_graph.is_partitioned(parent_key)
not parent_node.is_partitioned
or isinstance(
self.asset_graph.get_partition_mapping(child_key, parent_key),
self.asset_graph.get_partition_mapping(child_node.key, parent_node.key),
(TimeWindowPartitionMapping, IdentityPartitionMapping),
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_implicit_auto_materialize_policy(
"""For backcompat with pre-auto materialize policy graphs, assume a default scope of 1 day."""
auto_materialize_policy = asset_graph.get(asset_key).auto_materialize_policy
if auto_materialize_policy is None:
time_partitions_def = get_time_partitions_def(asset_graph.get_partitions_def(asset_key))
time_partitions_def = get_time_partitions_def(asset_graph.get(asset_key).partitions_def)
if time_partitions_def is None:
max_materializations_per_minute = None
elif time_partitions_def.schedule_type == ScheduleType.HOURLY:
Expand Down Expand Up @@ -141,7 +141,7 @@ def auto_materialize_asset_keys_and_parents(self) -> AbstractSet[AssetKey]:
return {
parent
for asset_key in self.auto_materialize_asset_keys
for parent in self.asset_graph.get_parents(asset_key)
for parent in self.asset_graph.get(asset_key).parent_keys
} | self.auto_materialize_asset_keys

@property
Expand Down Expand Up @@ -267,9 +267,9 @@ def get_asset_condition_evaluations(

# if we need to materialize any partitions of a non-subsettable multi-asset, we need to
# materialize all of them
execution_unit_keys = self.asset_graph.get_execution_set_asset_keys(asset_key)
if len(execution_unit_keys) > 1 and num_requested > 0:
for neighbor_key in execution_unit_keys:
execution_set_keys = self.asset_graph.get(asset_key).execution_set_asset_keys
if len(execution_set_keys) > 1 and num_requested > 0:
for neighbor_key in execution_set_keys:
expected_data_time_mapping[neighbor_key] = expected_data_time

# make sure that the true_subset of the neighbor is accurate -- when it was
Expand Down Expand Up @@ -353,7 +353,7 @@ def build_run_requests(

for asset_partition in asset_partitions:
assets_to_reconcile_by_partitions_def_partition_key[
asset_graph.get_partitions_def(asset_partition.asset_key), asset_partition.partition_key
asset_graph.get(asset_partition.asset_key).partitions_def, asset_partition.partition_key
].add(asset_partition.asset_key)

run_requests = []
Expand Down Expand Up @@ -414,7 +414,7 @@ def build_run_requests_with_backfill_policies(
# here we are grouping assets by their partitions def and partition keys selected.
for asset_key, partition_keys in asset_partition_keys.items():
assets_to_reconcile_by_partitions_def_partition_keys[
asset_graph.get_partitions_def(asset_key),
asset_graph.get(asset_key).partitions_def,
frozenset(partition_keys) if partition_keys else None,
].add(asset_key)

Expand Down Expand Up @@ -581,7 +581,7 @@ def get_auto_observe_run_requests(
for repository_asset_keys in asset_graph.split_asset_keys_by_repository(assets_to_auto_observe):
asset_keys_by_partitions_def = defaultdict(list)
for asset_key in repository_asset_keys:
partitions_def = asset_graph.get_partitions_def(asset_key)
partitions_def = asset_graph.get(asset_key).partitions_def
asset_keys_by_partitions_def[partitions_def].append(asset_key)
partitions_def_and_asset_key_groups.extend(asset_keys_by_partitions_def.values())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def backcompat_deserialize_asset_daemon_cursor_str(
partition_subsets_by_asset_key = {}
for key_str, serialized_str in data.get("handled_root_partitions_by_asset_key", {}).items():
asset_key = AssetKey.from_user_string(key_str)
partitions_def = asset_graph.get_partitions_def(asset_key) if asset_graph else None
partitions_def = asset_graph.get(asset_key).partitions_def if asset_graph else None
if not partitions_def:
continue
try:
Expand All @@ -221,7 +221,7 @@ def backcompat_deserialize_asset_daemon_cursor_str(
latest_evaluation_by_asset_key = {}
for key_str, serialized_evaluation in serialized_latest_evaluation_by_asset_key.items():
key = AssetKey.from_user_string(key_str)
partitions_def = asset_graph.get_partitions_def(key) if asset_graph else None
partitions_def = asset_graph.get(key).partitions_def if asset_graph else None

evaluation = deserialize_auto_materialize_asset_evaluation_to_asset_condition_evaluation_with_run_ids(
serialized_evaluation, partitions_def
Expand All @@ -239,7 +239,7 @@ def backcompat_deserialize_asset_daemon_cursor_str(
latest_evaluation_result = latest_evaluation_by_asset_key.get(asset_key)
# create a placeholder evaluation result if we don't have one
if not latest_evaluation_result:
partitions_def = asset_graph.get_partitions_def(asset_key) if asset_graph else None
partitions_def = asset_graph.get(asset_key).partitions_def if asset_graph else None
latest_evaluation_result = AssetConditionEvaluation(
condition_snapshot=AssetConditionSnapshot("", "", ""),
true_subset=AssetSubset.empty(asset_key, partitions_def),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,23 +128,25 @@ def _compare_base_and_branch_assets(self, asset_key: "AssetKey") -> Sequence[Cha
):
changes.append(ChangeReason.CODE_VERSION)

if self.branch_asset_graph.get_parents(asset_key) != self.base_asset_graph.get_parents(
asset_key
if (
self.branch_asset_graph.get(asset_key).parent_keys
!= self.base_asset_graph.get(asset_key).parent_keys
):
changes.append(ChangeReason.INPUTS)
else:
# if the set of inputs is different, then we don't need to check if the partition mappings
# for inputs have changed since ChangeReason.INPUTS is already in the list of changes
for upstream_asset in self.branch_asset_graph.get_parents(asset_key):
for upstream_asset in self.branch_asset_graph.get(asset_key).parent_keys:
if self.branch_asset_graph.get_partition_mapping(
asset_key, upstream_asset
) != self.base_asset_graph.get_partition_mapping(asset_key, upstream_asset):
changes.append(ChangeReason.INPUTS)
break

if self.branch_asset_graph.get_partitions_def(
asset_key
) != self.base_asset_graph.get_partitions_def(asset_key):
if (
self.branch_asset_graph.get(asset_key).partitions_def
!= self.base_asset_graph.get(asset_key).partitions_def
):
changes.append(ChangeReason.PARTITIONS_DEFINITION)

return changes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def get_asset_subset(self, asset_key: AssetKey, asset_graph: BaseAssetGraph) ->
"""Returns an AssetSubset representing the subset of a specific asset that this
AssetGraphSubset contains.
"""
partitions_def = asset_graph.get_partitions_def(asset_key)
partitions_def = asset_graph.get(asset_key).partitions_def
if partitions_def is None:
return AssetSubset(
asset_key=asset_key, value=asset_key in self.non_partitioned_asset_keys
Expand All @@ -100,7 +100,7 @@ def get_partitions_subset(
self, asset_key: AssetKey, asset_graph: Optional[BaseAssetGraph] = None
) -> PartitionsSubset:
if asset_graph:
partitions_def = asset_graph.get_partitions_def(asset_key)
partitions_def = asset_graph.get(asset_key).partitions_def
if partitions_def is None:
check.failed("Can only call get_partitions_subset on a partitioned asset")

Expand Down Expand Up @@ -148,15 +148,15 @@ def to_storage_dict(
},
"serializable_partitions_def_ids_by_asset_key": {
key.to_user_string(): check.not_none(
asset_graph.get_partitions_def(key)
asset_graph.get(key).partitions_def
).get_serializable_unique_identifier(
dynamic_partitions_store=dynamic_partitions_store
)
for key, _ in self.partitions_subsets_by_asset_key.items()
},
"partitions_def_class_names_by_asset_key": {
key.to_user_string(): check.not_none(
asset_graph.get_partitions_def(key)
asset_graph.get(key).partitions_def
).__class__.__name__
for key, _ in self.partitions_subsets_by_asset_key.items()
},
Expand Down Expand Up @@ -255,7 +255,7 @@ def from_asset_partition_set(
return AssetGraphSubset(
partitions_subsets_by_asset_key={
asset_key: (
cast(PartitionsDefinition, asset_graph.get_partitions_def(asset_key))
cast(PartitionsDefinition, asset_graph.get(asset_key).partitions_def)
.empty_subset()
.with_partition_keys(partition_keys)
)
Expand All @@ -278,7 +278,7 @@ def can_deserialize(

for key, value in serialized_dict["partitions_subsets_by_asset_key"].items():
asset_key = AssetKey.from_user_string(key)
partitions_def = asset_graph.get_partitions_def(asset_key)
partitions_def = asset_graph.get(asset_key).partitions_def

if partitions_def is None:
# Asset had a partitions definition at storage time, but no longer does
Expand Down Expand Up @@ -320,7 +320,7 @@ def from_storage_dict(
)
continue

partitions_def = asset_graph.get_partitions_def(asset_key)
partitions_def = asset_graph.get(asset_key).partitions_def

if partitions_def is None:
if not allow_partial:
Expand Down Expand Up @@ -382,7 +382,7 @@ def from_asset_keys(
non_partitioned_asset_keys: Set[AssetKey] = set()

for asset_key in asset_keys:
partitions_def = asset_graph.get_partitions_def(asset_key)
partitions_def = asset_graph.get(asset_key).partitions_def
if partitions_def:
partitions_subsets_by_asset_key[asset_key] = (
partitions_def.empty_subset().with_partition_keys(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def resolve_inner(self, asset_graph: BaseAssetGraph) -> AbstractSet[AssetKey]:
selection = self.child.resolve_inner(asset_graph)
output = set(selection)
for asset_key in selection:
output.update(asset_graph.get_execution_set_asset_keys(asset_key))
output.update(asset_graph.get(asset_key).execution_set_asset_keys)
return output

def to_serializable_asset_selection(self, asset_graph: BaseAssetGraph) -> "AssetSelection":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ def evaluate_for_asset(
else:
# At least one upstream partition in each upstream asset must be updated in order
# for the candidate to be updated
parent_asset_keys = context.asset_graph.get_parents(context.asset_key)
parent_asset_keys = context.asset_graph.get(context.asset_key).parent_keys
updated_parent_keys = {ap.asset_key for ap in updated_parent_partitions}
non_updated_parent_keys = parent_asset_keys - updated_parent_keys

Expand Down Expand Up @@ -964,11 +964,11 @@ def get_parent_subsets_updated_since_cron_by_key(
partitioned parents, as their partitions encode the time windows they have processed.
"""
updated_subsets_by_key = {}
for parent_asset_key in context.asset_graph.get_parents(context.asset_key):
for parent_asset_key in context.asset_graph.get(context.asset_key).parent_keys:
# no need to incrementally calculate updated time-window partitions definitions, as
# their partitions encode the time windows they have processed.
if isinstance(
context.asset_graph.get_partitions_def(parent_asset_key),
context.asset_graph.get(parent_asset_key).partitions_def,
TimeWindowPartitionsDefinition,
):
continue
Expand All @@ -988,7 +988,7 @@ def parent_updated_since_cron(
"""Returns whether, for a given child asset partition, the given parent asset has been updated
with information from the required time window.
"""
parent_partitions_def = context.asset_graph.get_partitions_def(parent_asset_key)
parent_partitions_def = context.asset_graph.get(parent_asset_key).partitions_def

if isinstance(parent_partitions_def, TimeWindowPartitionsDefinition):
# for time window partitions definitions, we simply assert that all time partitions that
Expand Down Expand Up @@ -1069,7 +1069,7 @@ def evaluate_for_asset(
candidate,
updated_subsets_by_key.get(parent_asset_key, context.empty_subset()),
)
for parent_asset_key in context.asset_graph.get_parents(candidate.asset_key)
for parent_asset_key in context.asset_graph.get(candidate.asset_key).parent_keys
)
},
)
Expand Down
Loading

0 comments on commit 01715ac

Please sign in to comment.