Skip to content

Commit

Permalink
deprecate asset_partition_*_for_output on AssetExecutionContext
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Jan 29, 2024
1 parent 8891731 commit d7edead
Show file tree
Hide file tree
Showing 20 changed files with 194 additions and 174 deletions.
64 changes: 37 additions & 27 deletions python_modules/dagster/dagster/_core/execution/context/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -1363,6 +1363,11 @@ def _copy_docs_from_op_execution_context(obj):
"dagster_run": "run",
"run_config": "run.run_config",
"run_tags": "run.tags",
"asset_partition_key_for_output": "partition_key",
"asset_partitions_time_window_for_output": "partition_time_window",
"asset_partition_key_range_for_output": "partition_key_range",
"asset_partitions_def_for_output": "assets_def.partitions_def",
"asset_partition_keys_for_output": "partition_keys",
}

ALTERNATE_EXPRESSIONS = {
Expand Down Expand Up @@ -1496,6 +1501,38 @@ def has_tag(self, key: str) -> bool:
def get_tag(self, key: str) -> Optional[str]:
return self.op_execution_context.get_tag(key)

@deprecated(**_get_deprecation_kwargs("asset_partition_key_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_key_for_output(self, output_name: str = "result") -> str:
return self.op_execution_context.asset_partition_key_for_output(output_name=output_name)

@deprecated(**_get_deprecation_kwargs("asset_partitions_time_window_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow:
return self.op_execution_context.asset_partitions_time_window_for_output(output_name)

@deprecated(**_get_deprecation_kwargs("asset_partition_key_range_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_key_range_for_output(
self, output_name: str = "result"
) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range_for_output(output_name)

@deprecated(**_get_deprecation_kwargs("asset_partitions_def_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition:
return self.op_execution_context.asset_partitions_def_for_output(output_name=output_name)

@deprecated(**_get_deprecation_kwargs("asset_partition_keys_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]:
return self.op_execution_context.asset_partition_keys_for_output(output_name=output_name)

########## pass-through to op context

#### op related
Expand Down Expand Up @@ -1652,23 +1689,6 @@ def partition_key_range(self) -> PartitionKeyRange:
def partition_time_window(self) -> TimeWindow:
return self.op_execution_context.partition_time_window

@public
@_copy_docs_from_op_execution_context
def asset_partition_key_for_output(self, output_name: str = "result") -> str:
return self.op_execution_context.asset_partition_key_for_output(output_name=output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow:
return self.op_execution_context.asset_partitions_time_window_for_output(output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_key_range_for_output(
self, output_name: str = "result"
) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range_for_output(output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange:
Expand All @@ -1679,21 +1699,11 @@ def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRa
def asset_partition_key_for_input(self, input_name: str) -> str:
return self.op_execution_context.asset_partition_key_for_input(input_name)

@public
@_copy_docs_from_op_execution_context
def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition:
return self.op_execution_context.asset_partitions_def_for_output(output_name=output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partitions_def_for_input(self, input_name: str) -> PartitionsDefinition:
return self.op_execution_context.asset_partitions_def_for_input(input_name=input_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]:
return self.op_execution_context.asset_partition_keys_for_output(output_name=output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_keys_for_input(self, input_name: str) -> Sequence[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,15 @@ def load_input(self, context):
assert context.asset_partitions_def == upstream_partitions_def

@asset(partitions_def=upstream_partitions_def)
def upstream_asset(context):
assert context.asset_partition_key_for_output() == "2"
def upstream_asset(context: AssetExecutionContext):
assert context.partition_key == "2"

@asset(
partitions_def=downstream_partitions_def,
ins={"upstream_asset": AssetIn(partition_mapping=TrailingWindowPartitionMapping())},
)
def downstream_asset(context, upstream_asset):
assert context.asset_partition_key_for_output() == "2"
def downstream_asset(context: AssetExecutionContext, upstream_asset):
assert context.partition_key == "2"
assert upstream_asset is None
assert context.asset_partitions_def_for_input("upstream_asset") == upstream_partitions_def

Expand Down Expand Up @@ -341,9 +341,8 @@ def test_partition_keys_in_range():
]

@asset(partitions_def=DailyPartitionsDefinition(start_date="2022-09-11"))
def upstream(context):
assert context.asset_partition_keys_for_output("result") == ["2022-09-11"]
assert context.asset_partition_keys_for_output() == ["2022-09-11"]
def upstream(context: AssetExecutionContext):
assert context.partition_keys == ["2022-09-11"]

@asset(partitions_def=WeeklyPartitionsDefinition(start_date="2022-09-11"))
def downstream(context, upstream):
Expand Down Expand Up @@ -383,8 +382,8 @@ def test_dependency_resolution_partition_mapping():
partitions_def=DailyPartitionsDefinition(start_date="2020-01-01"),
key_prefix=["staging"],
)
def upstream(context):
partition_date_str = context.asset_partition_key_for_output()
def upstream(context: AssetExecutionContext):
partition_date_str = context.partition_key
return partition_date_str

@asset(
Expand Down Expand Up @@ -441,11 +440,8 @@ def upstream(context):
return 1

@asset(partitions_def=composite)
def downstream(context, upstream):
assert (
context.asset_partition_keys_for_input("upstream")
== context.asset_partition_keys_for_output()
)
def downstream(context: AssetExecutionContext, upstream):
assert context.asset_partition_keys_for_input("upstream") == context.partition_keys
return 1

asset_graph = AssetGraph.from_assets([upstream, downstream])
Expand All @@ -471,16 +467,14 @@ def test_multipartitions_def_partition_mapping_infer_single_dim_to_multi():
)

@asset(partitions_def=abc_def)
def upstream(context):
assert context.asset_partition_keys_for_output("result") == ["a"]
def upstream(context: AssetExecutionContext):
assert context.partition_keys == ["a"]
return 1

@asset(partitions_def=composite)
def downstream(context, upstream):
def downstream(context: AssetExecutionContext, upstream):
assert context.asset_partition_keys_for_input("upstream") == ["a"]
assert context.asset_partition_keys_for_output("result") == [
MultiPartitionKey({"abc": "a", "123": "1"})
]
assert context.partition_keys == [MultiPartitionKey({"abc": "a", "123": "1"})]
return 1

asset_graph = AssetGraph.from_assets([upstream, downstream])
Expand Down Expand Up @@ -533,9 +527,9 @@ def upstream(context):
return 1

@asset(partitions_def=abc_def)
def downstream(context, upstream):
def downstream(context: AssetExecutionContext, upstream):
assert set(context.asset_partition_keys_for_input("upstream")) == a_multipartition_keys
assert context.asset_partition_keys_for_output("result") == ["a"]
assert context.partition_keys == ["a"]
return 1

asset_graph = AssetGraph.from_assets([upstream, downstream])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
from dagster import (
AssetExecutionContext,
AssetKey,
AssetOut,
AssetsDefinition,
Expand Down Expand Up @@ -278,8 +279,8 @@ def the_asset(context):

def test_materialize_partition_key():
@asset(partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"))
def the_asset(context):
assert context.asset_partition_key_for_output() == "2022-02-02"
def the_asset(context: AssetExecutionContext):
assert context.partition_key == "2022-02-02"

with instance_for_test() as instance:
result = materialize([the_asset], partition_key="2022-02-02", instance=instance)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
with_resources,
)
from dagster._core.errors import DagsterInvalidInvocationError
from dagster._core.execution.context.compute import AssetExecutionContext


def test_basic_materialize_to_memory():
Expand Down Expand Up @@ -243,8 +244,8 @@ def multi_asset_with_internal_deps(thing):

def test_materialize_to_memory_partition_key():
@asset(partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"))
def the_asset(context):
assert context.asset_partition_key_for_output() == "2022-02-02"
def the_asset(context: AssetExecutionContext):
assert context.partition_key == "2022-02-02"

result = materialize_to_memory([the_asset], partition_key="2022-02-02")
assert result.success
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pendulum
import pytest
from dagster import (
AssetExecutionContext,
AssetMaterialization,
AssetOut,
AssetsDefinition,
Expand Down Expand Up @@ -155,8 +156,8 @@ def load_input(self, context):
assert False, "shouldn't get here"

@asset(partitions_def=partitions_def)
def my_asset(context):
assert context.asset_partitions_def_for_output() == partitions_def
def my_asset(context: AssetExecutionContext):
assert context.assets_def.partitions_def == partitions_def

my_job = build_assets_job(
"my_job",
Expand Down Expand Up @@ -212,24 +213,24 @@ def test_access_partition_keys_from_context_direct_invocation():
partitions_def = StaticPartitionsDefinition(["a"])

@asset(partitions_def=partitions_def)
def partitioned_asset(context):
assert context.asset_partition_key_for_output() == "a"
def partitioned_asset(context: AssetExecutionContext):
assert context.partition_key == "a"

context = build_asset_context(partition_key="a")

# check unbound context
assert context.asset_partition_key_for_output() == "a"
assert context.partition_key == "a"

# check bound context
partitioned_asset(context)

# check failure for non-partitioned asset
@asset
def non_partitioned_asset(context):
def non_partitioned_asset(context: AssetExecutionContext):
with pytest.raises(
CheckError, match="Tried to access partition_key for a non-partitioned run"
):
context.asset_partition_key_for_output()
_ = context.partition_key

context = build_asset_context()
non_partitioned_asset(context)
Expand Down Expand Up @@ -257,8 +258,8 @@ def load_input(self, context):
assert context.asset_partition_key_range == PartitionKeyRange("a", "c")

@asset(partitions_def=upstream_partitions_def)
def upstream_asset(context):
assert context.asset_partition_key_for_output() == "b"
def upstream_asset(context: AssetExecutionContext):
assert context.partition_key == "b"

@asset
def downstream_asset(upstream_asset):
Expand Down Expand Up @@ -606,7 +607,6 @@ def test_partition_range_single_run():
@asset(partitions_def=partitions_def)
def upstream_asset(context) -> None:
key_range = PartitionKeyRange(start="2020-01-01", end="2020-01-03")
assert context.asset_partition_key_range_for_output() == key_range
assert context.partition_key_range == key_range
assert context.partition_time_window == TimeWindow(
partitions_def.time_window_for_partition_key(key_range.start).start,
Expand All @@ -615,11 +615,11 @@ def upstream_asset(context) -> None:
assert context.partition_keys == partitions_def.get_partition_keys_in_range(key_range)

@asset(partitions_def=partitions_def, deps=["upstream_asset"])
def downstream_asset(context) -> None:
def downstream_asset(context: AssetExecutionContext) -> None:
assert context.asset_partition_key_range_for_input("upstream_asset") == PartitionKeyRange(
start="2020-01-01", end="2020-01-03"
)
assert context.asset_partition_key_range_for_output() == PartitionKeyRange(
assert context.partition_key_range == PartitionKeyRange(
start="2020-01-01", end="2020-01-03"
)

Expand Down Expand Up @@ -653,17 +653,15 @@ def test_multipartition_range_single_run():
)

@asset(partitions_def=partitions_def)
def multipartitioned_asset(context) -> None:
key_range = context.asset_partition_key_range_for_output()
def multipartitioned_asset(context: AssetExecutionContext) -> None:
key_range = context.partition_key_range

assert isinstance(key_range.start, MultiPartitionKey)
assert isinstance(key_range.end, MultiPartitionKey)
assert key_range.start == MultiPartitionKey({"date": "2020-01-01", "abc": "a"})
assert key_range.end == MultiPartitionKey({"date": "2020-01-03", "abc": "a"})

assert all(
isinstance(key, MultiPartitionKey) for key in context.asset_partition_keys_for_output()
)
assert all(isinstance(key, MultiPartitionKey) for key in context.partition_keys)

the_job = define_asset_job("job").resolve(
asset_graph=AssetGraph.from_assets([multipartitioned_asset])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,11 @@ def test_deprecation_warnings():
"asset_key_for_input",
"asset_key_for_output",
"asset_partition_key_for_input",
"asset_partition_key_for_output",
"asset_partition_key_range",
"asset_partition_key_range_for_input",
"asset_partition_key_range_for_output",
"asset_partition_keys_for_input",
"asset_partition_keys_for_output",
"asset_partitions_def_for_input",
"asset_partitions_def_for_output",
"asset_partitions_time_window_for_input",
"asset_partitions_time_window_for_output",
"assets_def",
"get_output_metadata",
"has_asset_checks_def",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1259,8 +1259,8 @@ def test_partitions_time_window_asset_invocation():
@asset(
partitions_def=partitions_def,
)
def partitioned_asset(context):
start, end = context.asset_partitions_time_window_for_output()
def partitioned_asset(context: AssetExecutionContext):
start, end = context.partition_time_window
assert start == pendulum.instance(datetime(2023, 2, 2), tz=partitions_def.timezone)
assert end == pendulum.instance(datetime(2023, 2, 3), tz=partitions_def.timezone)

Expand All @@ -1279,7 +1279,7 @@ def test_multipartitioned_time_window_asset_invocation():
)

@asset(partitions_def=partitions_def)
def my_asset(context):
def my_asset(context: AssetExecutionContext):
time_window = TimeWindow(
start=pendulum.instance(
datetime(year=2020, month=1, day=1),
Expand All @@ -1290,7 +1290,7 @@ def my_asset(context):
tz=get_time_partitions_def(partitions_def).timezone,
),
)
assert context.asset_partitions_time_window_for_output() == time_window
assert context.partition_time_window == time_window
return 1

context = build_asset_context(
Expand All @@ -1306,12 +1306,12 @@ def my_asset(context):
)

@asset(partitions_def=partitions_def)
def static_multipartitioned_asset(context):
def static_multipartitioned_asset(context: AssetExecutionContext):
with pytest.raises(
DagsterInvariantViolationError,
match="with a single time dimension",
):
context.asset_partitions_time_window_for_output()
_ = context.partition_time_window

context = build_asset_context(
partition_key="a|a",
Expand Down
Loading

0 comments on commit d7edead

Please sign in to comment.