Skip to content

Commit

Permalink
[branching io manager][fix] Fix loading partitioned parent assets in …
Browse files Browse the repository at this point in the history
…BranchingIOManager (#18491)

## Summary

Tweak of fix in #17303 which uses the upstream asset partition rather
than the current asset partition when using the branching IO manager.
Uses `context.asset_partition_key` rather than the upstream output's
partition key, which can sometimes be `None` when the upstream output is
not available.

## Test Plan

Adds a new unit test with a non-standard partition mapping, which
previously failed.

---------

Co-authored-by: Ben <ben@Bens-MacBook-Pro.local>
  • Loading branch information
benpankow and Ben authored Dec 12, 2023
1 parent 7322b74 commit 262853f
Showing 4 changed files with 338 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -69,18 +69,38 @@ def load_input(self, context: InputContext) -> Any:
return self.branch_io_manager.load_input(context)
else:
# we are dealing with an asset input
event_log_entry = latest_materialization_log_entry(
instance=context.instance,
asset_key=context.asset_key,
partition_key=context.partition_key if context.has_partition_key else None,
)
if (
event_log_entry
# figure out which partition keys are loaded, if any
partition_keys = []
if context.has_asset_partitions:
partition_keys = context.asset_partition_keys

# we'll fetch materializations with key=None if we aren't loading
# a partitioned asset, this will return us the latest materialization
# of an unpartitioned asset
if len(partition_keys) == 0:
partition_keys = [None]

# grab the latest materialization for each partition that we
# need to load, OR just the latest materialization if not partitioned
event_log_entries = [
latest_materialization_log_entry(
instance=context.instance,
asset_key=context.asset_key,
partition_key=partition_key,
)
for partition_key in partition_keys
]

# if all partitions are available in the branch, we can load from the branch
# otherwise we need to load from the parent
if all(
event_log_entry is not None
and event_log_entry.asset_materialization
and get_text_metadata_value(
event_log_entry.asset_materialization, self.branch_metadata_key
)
== self.branch_name
for event_log_entry in event_log_entries
):
context.log.info(
f'Branching Manager: Loading "{context.asset_key.to_user_string()}" from'
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import math
import time
from typing import Optional
from typing import Any, Optional, cast

from dagster import Definitions, asset
from dagster import Definitions, In, asset, job, op
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.events import AssetKey, AssetMaterialization
from dagster._core.definitions.metadata import TextMetadataValue
@@ -246,3 +246,38 @@ def test_basic_workflow():
assert dev_t1_runner.load_asset_value("now_time_plus_20_after_plus_N") == (
now_time_prod_value_1 + 17 + 20
)


@op
def now_time_op() -> int:
    """Return the current wall-clock time in hundredths of a second."""
    now_hundredths = time.time() * 100
    return int(math.floor(now_hundredths))


@op(ins={"now_time": In(int)})
def now_time_divide_by_2(now_time: int) -> int:
    """Halve the upstream time value using floor division."""
    halved = now_time // 2
    return halved


@job
def now_time_job():
    """Wire the output of now_time_op into now_time_divide_by_2."""
    produced_time = now_time_op()
    now_time_divide_by_2(produced_time)


def test_job_op_usecase() -> Any:
    """A plain op/job (non-asset) pipeline executes successfully under a BranchingIOManager."""
    defs = Definitions(
        jobs=[now_time_job],
        resources={
            "io_manager": BranchingIOManager(
                parent_io_manager=AssetBasedInMemoryIOManager(),
                branch_io_manager=AssetBasedInMemoryIOManager(),
            )
        },
    )
    with DefinitionsRunner.ephemeral(defs) as runner:
        job_def = cast(DefinitionsRunner, runner).defs.get_job_def("now_time_job")
        result = job_def.execute_in_process(instance=runner.instance)
        assert result.success
Original file line number Diff line number Diff line change
@@ -1,21 +1,54 @@
import math
import time

from dagster import DagsterInstance, Definitions, asset
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.partition import StaticPartitionsDefinition
from typing import Any, Dict, List, cast

from dagster import (
AssetExecutionContext,
AssetIn,
AssetsDefinition,
DagsterInstance,
Definitions,
In,
StaticPartitionMapping,
StaticPartitionsDefinition,
asset,
job,
op,
)
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.storage.branching.branching_io_manager import BranchingIOManager

from .utils import AssetBasedInMemoryIOManager, DefinitionsRunner

# Partition schemes used by the assets in this module.
partitioning_scheme = StaticPartitionsDefinition(["A", "B", "C"])
secondary_partitioning_scheme = StaticPartitionsDefinition(["1", "2", "3"])
tertiary_partitioning_scheme = StaticPartitionsDefinition(["ab", "bc", "ca"])

# One-to-one mapping from the primary partitions to the secondary ones (A->1, B->2, C->3).
primary_secondary_partition_mapping = StaticPartitionMapping({"A": "1", "B": "2", "C": "3"})
# Many-to-many mapping: each primary partition feeds two tertiary partitions, so a single
# tertiary partition depends on two primary partitions (e.g. "ab" depends on "A" and "B").
primary_tertiary_partition_mapping = StaticPartitionMapping(
    {
        "A": ["ab", "ca"],
        "B": ["bc", "ab"],
        "C": ["bc", "ca"],
    }
)


@asset(partitions_def=partitioning_scheme)
def now_time():
    """Current wall-clock time in hundredths of a second, floored to an int."""
    hundredths = time.time() * 100
    return int(math.floor(hundredths))


@asset(
    partitions_def=secondary_partitioning_scheme,
    ins={
        "now_time": AssetIn(key="now_time", partition_mapping=primary_secondary_partition_mapping)
    },
)
def now_time_times_two(now_time: int) -> int:
    """Double the value of the mapped upstream `now_time` partition."""
    doubled = 2 * now_time
    return doubled


def get_now_time_plus_N(N: int) -> AssetsDefinition:
@asset(partitions_def=partitioning_scheme)
def now_time_plus_N(now_time: int) -> int:
@@ -217,3 +250,214 @@ def test_basic_partitioning_workflow():
dev_runner_t1.load_asset_value("now_time_plus_N", partition_key="B")
== dev_now_time_B + 15
)


def test_partition_mapping_workflow() -> Any:
    """With a one-to-one partition mapping, a dev branch loads mapped upstream partitions
    from prod when they have not been materialized in the dev branch.
    """
    prod_io_manager = AssetBasedInMemoryIOManager()
    dev_io_manager = AssetBasedInMemoryIOManager()

    prod_defs = Definitions(
        assets=[now_time, now_time_times_two],
        resources={
            "io_manager": prod_io_manager,
        },
    )

    # Dev definitions mirror prod but wrap both IO managers in a BranchingIOManager, so reads
    # fall back to prod when a value is missing from the dev branch.
    dev_defs = Definitions(
        assets=[now_time, now_time_times_two],
        resources={
            "io_manager": BranchingIOManager(
                parent_io_manager=prod_io_manager, branch_io_manager=dev_io_manager
            )
        },
    )

    with DagsterInstance.ephemeral() as dev_instance, DagsterInstance.ephemeral() as prod_instance:
        # Simulate a full prod run. All partitions are full
        prod_runner = DefinitionsRunner(prod_defs, prod_instance)
        prod_runner.materialize_asset("now_time", partition_key="A")
        prod_runner.materialize_asset("now_time", partition_key="B")
        prod_runner.materialize_asset("now_time", partition_key="C")

        prod_runner.materialize_asset("now_time_times_two", partition_key="1")
        prod_runner.materialize_asset("now_time_times_two", partition_key="2")
        prod_runner.materialize_asset("now_time_times_two", partition_key="3")

        for partition_key in ["A", "B", "C"]:
            assert prod_io_manager.has_value("now_time", partition_key)

        for partition_key in ["1", "2", "3"]:
            assert prod_io_manager.has_value("now_time_times_two", partition_key)

        dev_runner = DefinitionsRunner(dev_defs, dev_instance)

        # Materializing partition "2" in dev reads the mapped upstream partition "B" of
        # now_time — that read should come from prod, not dev.
        dev_runner.materialize_asset("now_time_times_two", partition_key="2")

        assert dev_runner.load_asset_value(
            "now_time", partition_key="A"
        ) == prod_runner.load_asset_value("now_time", partition_key="A")

        # The upstream asset was never written to the dev branch.
        assert not dev_io_manager.has_value("now_time", partition_key="A")

        # now_time_times_two has been rematerialized in the dev branch but still with the same
        # logic, so its value matches prod.
        assert dev_runner.load_asset_value(
            "now_time_times_two", partition_key="2"
        ) == prod_runner.load_asset_value("now_time_times_two", partition_key="2")

        assert dev_io_manager.has_value("now_time_times_two", partition_key="2")


# Asset factory which produces a partitioned asset w/ each partition having a different seeded value
def get_base_values(seed_values: List[int]) -> AssetsDefinition:
    """Build a `base_values` asset whose partitions A/B/C return the three given seed values."""
    assert len(seed_values) == 3
    seed_value_dict = dict(zip(["A", "B", "C"], seed_values))

    @asset(partitions_def=partitioning_scheme)
    def base_values(context: AssetExecutionContext) -> int:
        # Each partition yields its own seeded value.
        return seed_value_dict[context.partition_key]

    return base_values


# Asset with a many-to-one mapping from input partitions to output partitions
# e.g. to materialize partition "ab" in the output asset, we need as input partitions "A" and "B" of
# now_time
@asset(
    partitions_def=tertiary_partitioning_scheme,
    ins={
        "upstream_values": AssetIn(
            key="base_values", partition_mapping=primary_tertiary_partition_mapping
        )
    },
)
def average_upstream(upstream_values: Dict[str, int]) -> int:
    """Integer mean of all mapped upstream partition values."""
    total = sum(upstream_values.values())
    count = len(upstream_values)
    return total // count


def test_multi_partition_mapping_workflow() -> Any:
    """With a many-to-one partition mapping, the branching IO manager only loads upstream
    values from the dev branch once *all* mapped upstream partitions have been materialized
    in dev; until then reads fall back to prod.
    """
    prod_io_manager = AssetBasedInMemoryIOManager()
    dev_io_manager = AssetBasedInMemoryIOManager()

    prod_defs = Definitions(
        assets=[get_base_values([10, 20, 30]), average_upstream],
        resources={
            "io_manager": prod_io_manager,
        },
    )

    # Dev seeds differ from prod so we can tell which branch a value was read from.
    dev_defs = Definitions(
        assets=[get_base_values([50, 100, 150]), average_upstream],
        resources={
            "io_manager": BranchingIOManager(
                parent_io_manager=prod_io_manager, branch_io_manager=dev_io_manager
            )
        },
    )

    with DagsterInstance.ephemeral() as dev_instance, DagsterInstance.ephemeral() as prod_instance:
        # Simulate a full prod run. All partitions are full
        prod_runner = DefinitionsRunner(prod_defs, prod_instance)
        prod_runner.materialize_asset("base_values", partition_key="A")
        prod_runner.materialize_asset("base_values", partition_key="B")
        prod_runner.materialize_asset("base_values", partition_key="C")

        prod_runner.materialize_asset("average_upstream", partition_key="ab")
        prod_runner.materialize_asset("average_upstream", partition_key="bc")
        prod_runner.materialize_asset("average_upstream", partition_key="ca")

        for partition_key in ["A", "B", "C"]:
            assert prod_io_manager.has_value("base_values", partition_key)

        for partition_key in ["ab", "bc", "ca"]:
            assert prod_io_manager.has_value("average_upstream", partition_key)

        # Verify that, since we haven't materialized the upstream asset in dev, we are reading from
        # the values generated in prod
        assert prod_io_manager.get_value("base_values", partition_key="A") == 10
        assert not dev_io_manager.has_value("base_values", partition_key="A")

        assert prod_io_manager.get_value("average_upstream", partition_key="ab") == 15
        # BUGFIX: original read `assert not dev_io_manager.has_value(...) == 15`, which parses as
        # `not (bool == 15)` and is vacuously true. We only want to assert the value is absent.
        assert not dev_io_manager.has_value("average_upstream", partition_key="ab")

        dev_runner = DefinitionsRunner(dev_defs, dev_instance)

        # First, we try rematerializing the averages in dev. Since we haven't materialized the
        # upstream asset, we should be reading from prod, and the values should be the same
        dev_runner.materialize_asset("average_upstream", partition_key="ab")
        assert dev_io_manager.has_value("average_upstream", partition_key="ab")
        assert dev_io_manager.get_value("average_upstream", partition_key="ab") == 15

        # Now, we materialize the upstream asset in dev, but only a single partition
        # The branching IO manager logic will only read from the upstream asset in dev if all
        # upstream partitions are materialized in dev, so the average will be unchanged
        dev_runner.materialize_asset("base_values", partition_key="A")
        assert dev_io_manager.has_value("base_values", partition_key="A")
        assert dev_io_manager.get_value("base_values", partition_key="A") == 50

        dev_runner.materialize_asset("average_upstream", partition_key="ab")
        assert dev_io_manager.get_value("average_upstream", partition_key="ab") == 15

        # Now, we materialize the upstream asset in dev for the "B" partition. Since we have
        # materialized all needed upstream partitions in dev, the branching IO manager logic
        # will read from the upstream asset in dev, and the average will be updated
        dev_runner.materialize_asset("base_values", partition_key="B")
        assert dev_io_manager.has_value("base_values", partition_key="B")
        assert dev_io_manager.get_value("base_values", partition_key="B") == 100

        dev_runner.materialize_asset("average_upstream", partition_key="ab")
        assert dev_io_manager.get_value("average_upstream", partition_key="ab") == 75


@op
def fixed_value_op(context: OpExecutionContext) -> int:
    """Return a fixed per-partition value (A=10, B=20, C=30)."""
    values_by_partition = {"A": 10, "B": 20, "C": 30}
    key = context.partition_key
    if key not in values_by_partition:
        raise Exception("Invalid partition key")
    return values_by_partition[key]


@op(ins={"input_value": In(int)})
def divide_input_by_two(input_value: int) -> int:
    """Halve the input using floor division."""
    half = input_value // 2
    return half


@job(partitions_def=partitioning_scheme)
def my_math_job():
    """Partitioned job: produce a fixed per-partition value, then halve it."""
    fixed = fixed_value_op()
    divide_input_by_two(fixed)


def test_job_op_usecase_partitioned() -> Any:
    """A partitioned op-based job produces correct per-partition outputs under a
    BranchingIOManager.
    """
    defs = Definitions(
        jobs=[my_math_job],
        resources={
            "io_manager": BranchingIOManager(
                parent_io_manager=AssetBasedInMemoryIOManager(),
                branch_io_manager=AssetBasedInMemoryIOManager(),
            )
        },
    )
    with DefinitionsRunner.ephemeral(defs) as runner:
        job_def = cast(DefinitionsRunner, runner).defs.get_job_def("my_math_job")
        # A -> 10 // 2 == 5, B -> 20 // 2 == 10
        for partition_key, expected in [("A", 5), ("B", 10)]:
            result = job_def.execute_in_process(
                instance=runner.instance, partition_key=partition_key
            )
            assert result.success
            assert result.output_for_node("divide_input_by_two") == expected
Loading

0 comments on commit 262853f

Please sign in to comment.