Skip to content

Commit

Permalink
small updates and testing
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Sep 22, 2023
1 parent bb2dbb4 commit c8e9c94
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1133,9 +1133,10 @@ def self_dependent_asset(context: AssetExecutionContext, self_dependent_asset):
# running a backfill of the 2023-08-21 through 2023-08-25 partitions of this asset will log:
# ["2023-08-20", "2023-08-21", "2023-08-22", "2023-08-23", "2023-08-24"]
"""
asset_key = self.asset_key_for_input(input_name)
return list(
self._step_execution_context.asset_partitions_subset_for_input(
input_name
self._step_execution_context.partitions_subset_for_upstream_asset(
asset_key
).get_partition_keys()
)

Expand Down
22 changes: 5 additions & 17 deletions python_modules/dagster/dagster/_core/execution/context/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,7 @@ def partition_key_range_for_asset(
self, asset: Optional[CoercibleToAssetKey], is_dependency: bool = False
) -> PartitionKeyRange:
if asset is None:
check.failed(f"Tried to access partition key range with invalid asset key: {asset}")
check.failed(f"Tried to access partition key range for invalid asset key: {asset}")
if self._load_partition_info_as_upstream_asset(
current_asset=asset, is_dependency=is_dependency
):
Expand All @@ -1095,23 +1095,11 @@ def partition_key_range_for_asset(
else:
return self.asset_partition_key_range

# def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange:
# subset = self.asset_partitions_subset_for_input(input_name)
# partition_key_ranges = subset.get_partition_key_ranges(
# dynamic_partitions_store=self.instance
# )

# if len(partition_key_ranges) != 1:
# check.failed(
# "Tried to access asset partition key range, but there are "
# f"({len(partition_key_ranges)}) key ranges associated with this input.",
# )

# return partition_key_ranges[0]

def partitions_subset_for_upstream_asset(
self, asset: CoercibleToAssetKey, *, require_valid_partitions: bool = True
self, asset: Optional[CoercibleToAssetKey], *, require_valid_partitions: bool = True
) -> PartitionsSubset:
if asset is None:
check.failed(f"Tried to access partition for invalid asset key: {asset}")
asset_layer = self.job_def.asset_layer
assets_def = asset_layer.assets_def_for_node(self.node_handle)
upstream_asset_key = AssetKey.from_coercible(asset)
Expand Down Expand Up @@ -1156,7 +1144,7 @@ def partitions_subset_for_upstream_asset(

return mapped_partitions_result.partitions_subset

check.failed("The input has no asset partitions")
check.failed(f"The asset {asset} has no partitions")

def asset_partition_key_for_input(self, input_name: str) -> str:
asset_key = self.job_def.asset_layer.asset_key_for_input(self.node_handle, input_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pendulum
import pytest
from dagster import (
AssetExecutionContext,
AssetMaterialization,
AssetOut,
AssetsDefinition,
Expand Down Expand Up @@ -140,7 +141,7 @@ def test_single_partitioned_asset_job():

class MyIOManager(IOManager):
def handle_output(self, context, obj):
assert context.asset_partition_key == "b"
assert context.partition_key == "b"
assert context.asset_partitions_def == partitions_def

def load_input(self, context):
Expand Down Expand Up @@ -204,24 +205,24 @@ def test_access_partition_keys_from_context_direct_invocation():
partitions_def = StaticPartitionsDefinition(["a"])

@asset(partitions_def=partitions_def)
def partitioned_asset(context):
assert context.asset_partition_key_for_output() == "a"
def partitioned_asset(context: AssetExecutionContext):
assert context.partition_key == "a"

context = build_asset_context(partition_key="a")

# check unbound context
assert context.asset_partition_key_for_output() == "a"
assert context.partition_key == "a"

# check bound context
partitioned_asset(context)

# check failure for non-partitioned asset
@asset
def non_partitioned_asset(context):
def non_partitioned_asset(context: AssetExecutionContext):
with pytest.raises(
CheckError, match="Tried to access partition_key for a non-partitioned asset"
):
context.asset_partition_key_for_output()
context.partition_key # noqa: B018

context = build_asset_context()
non_partitioned_asset(context)
Expand Down Expand Up @@ -249,8 +250,8 @@ def load_input(self, context):
assert context.asset_partition_key_range == PartitionKeyRange("a", "c")

@asset(partitions_def=upstream_partitions_def)
def upstream_asset(context):
assert context.asset_partition_key_for_output() == "b"
def upstream_asset(context: AssetExecutionContext):
assert context.partition_key == "b"

@asset
def downstream_asset(upstream_asset):
Expand Down Expand Up @@ -572,8 +573,8 @@ def test_mismatched_job_partitioned_config_with_asset_partitions():
daily_partitions_def = DailyPartitionsDefinition(start_date="2020-01-01")

@asset(config_schema={"day_of_month": int}, partitions_def=daily_partitions_def)
def asset1(context):
assert context.op_config["day_of_month"] == 1
def asset1(context: AssetExecutionContext):
assert context.op_execution_context.op_config["day_of_month"] == 1
assert context.partition_key == "2020-01-01"

@hourly_partitioned_config(start_date="2020-01-01-00:00")
Expand All @@ -596,8 +597,8 @@ def test_partition_range_single_run():
partitions_def = DailyPartitionsDefinition(start_date="2020-01-01")

@asset(partitions_def=partitions_def)
def upstream_asset(context) -> None:
assert context.asset_partition_key_range_for_output() == PartitionKeyRange(
def upstream_asset(context: AssetExecutionContext) -> None:
assert context.partition_key_range == PartitionKeyRange(
start="2020-01-01", end="2020-01-03"
)

Expand Down Expand Up @@ -640,16 +641,17 @@ def test_multipartition_range_single_run():
)

@asset(partitions_def=partitions_def)
def multipartitioned_asset(context) -> None:
key_range = context.asset_partition_key_range_for_output()
def multipartitioned_asset(context: AssetExecutionContext) -> None:
key_range = context.partition_key_range

assert isinstance(key_range.start, MultiPartitionKey)
assert isinstance(key_range.end, MultiPartitionKey)
assert key_range.start == MultiPartitionKey({"date": "2020-01-01", "abc": "a"})
assert key_range.end == MultiPartitionKey({"date": "2020-01-03", "abc": "a"})

assert all(
isinstance(key, MultiPartitionKey) for key in context.asset_partition_keys_for_output()
isinstance(key, MultiPartitionKey)
for key in partitions_def.get_partitions_keys_in_range(context.partition_key_range)
)

the_job = define_asset_job("job").resolve(
Expand Down

0 comments on commit c8e9c94

Please sign in to comment.