Skip to content

Commit

Permalink
graph_asset_no_defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
johannkm committed Sep 26, 2023
1 parent eb1f7f0 commit 2fec27e
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 41 deletions.
45 changes: 26 additions & 19 deletions python_modules/dagster/dagster/_core/definitions/asset_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,12 @@ def get_attributes_dict(self) -> Dict[str, Any]:


@experimental
def build_blocking_asset_check(
def build_asset_with_blocking_check(
asset_def: "AssetsDefinition",
checks: Sequence[AssetChecksDefinition],
) -> "AssetsDefinition":
from dagster import AssetIn, In, OpExecutionContext, Output, graph_asset, op
from dagster import AssetIn, In, OpExecutionContext, Output, op
from dagster._core.definitions.decorators.asset_decorator import graph_asset_no_defaults
from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus

check_specs = []
Expand All @@ -162,9 +163,13 @@ def build_blocking_asset_check(
check.invariant(len(asset_def.op.output_defs) == 1)
asset_out_type = asset_def.op.output_defs[0].dagster_type

@op(ins={"materialization": In(asset_out_type), "check_evaluations": In(Nothing)})
def fan_in_checks_and_materialization(context: OpExecutionContext, materialization):
yield Output(materialization)
@op(ins={"asset_return_value": In(asset_out_type), "check_evaluations": In(Nothing)})
def fan_in_checks_and_asset_return_value(context: OpExecutionContext, asset_return_value: Any):
# we pass the asset_return_value through and store it again so that downstream assets can load it.
# This is a little silly- we only do this because this op has the asset key in its StepOutputProperties
# so the output is written to the right path. We could probably get the asset_def.op to write to the
# asset path (and make sure we don't override it here) to avoid the double write.
yield Output(asset_return_value)

for check_spec in check_specs:
executions = context.instance.event_log_storage.get_asset_check_executions(
Expand All @@ -181,7 +186,21 @@ def fan_in_checks_and_materialization(context: OpExecutionContext, materializati
if execution.status != AssetCheckExecutionRecordStatus.SUCCEEDED:
raise DagsterAssetCheckFailedError()

@graph_asset(
# kwargs are the inputs to the asset_def.op that we are wrapping
def blocking_asset(**kwargs):
asset_return_value = asset_def.op.with_replaced_properties(name="asset_op")(**kwargs)
check_evaluations = [check.node_def(asset_return_value) for check in checks]

return {
"result": fan_in_checks_and_asset_return_value(asset_return_value, check_evaluations),
**{
check_output_name: check_result
for check_output_name, check_result in zip(check_output_names, check_evaluations)
},
}

return graph_asset_no_defaults(
compose_fn=blocking_asset,
name=asset_def.key.path[-1],
key_prefix=asset_def.key.path[:-1] if len(asset_def.key.path) > 1 else None,
group_name=asset_def.group_names_by_key.get(asset_def.key),
Expand All @@ -194,17 +213,5 @@ def fan_in_checks_and_materialization(context: OpExecutionContext, materializati
freshness_policy=asset_def.freshness_policies_by_key.get(asset_def.key),
auto_materialize_policy=asset_def.auto_materialize_policies_by_key.get(asset_def.key),
backfill_policy=asset_def.backfill_policy,
config=None, # gets config from asset_def.op
)
def blocking_asset(**kwargs):
asset_result = asset_def.op.with_replaced_properties(name="asset_op")(**kwargs)
check_evaluations = [check.node_def(asset_result) for check in checks]

return {
"result": fan_in_checks_and_materialization(asset_result, check_evaluations),
**{
check_output_name: check_result
for check_output_name, check_result in zip(check_output_names, check_evaluations)
},
}

return blocking_asset
Original file line number Diff line number Diff line change
Expand Up @@ -1095,34 +1095,94 @@ def slack_files_table():
**check_outs_by_output_name,
}

op_graph = graph(
name=out_asset_key.to_python_identifier(),
return graph_asset_no_defaults(
compose_fn=compose_fn,
name=name,
description=description,
ins=ins,
config=config,
ins={input_name: GraphIn() for _, (input_name, _) in asset_ins.items()},
out=combined_outs_by_output_name,
)(compose_fn)
return AssetsDefinition.from_graph(
op_graph,
keys_by_input_name=keys_by_input_name,
keys_by_output_name={"result": out_asset_key},
partitions_def=partitions_def,
partition_mappings=partition_mappings if partition_mappings else None,
key_prefix=key_prefix,
group_name=group_name,
metadata_by_output_name={"result": metadata} if metadata else None,
freshness_policies_by_output_name=(
{"result": freshness_policy} if freshness_policy else None
),
auto_materialize_policies_by_output_name=(
{"result": auto_materialize_policy} if auto_materialize_policy else None
),
partitions_def=partitions_def,
metadata=metadata,
freshness_policy=freshness_policy,
auto_materialize_policy=auto_materialize_policy,
backfill_policy=backfill_policy,
descriptions_by_output_name={"result": description} if description else None,
resource_defs=resource_defs,
check_specs=check_specs,
)


def graph_asset_no_defaults(
*,
compose_fn: Callable,
name: Optional[str],
description: Optional[str],
ins: Optional[Mapping[str, AssetIn]],
config: Optional[Union[ConfigMapping, Mapping[str, Any]]],
key_prefix: Optional[CoercibleToAssetKeyPrefix],
group_name: Optional[str],
partitions_def: Optional[PartitionsDefinition],
metadata: Optional[MetadataUserInput],
freshness_policy: Optional[FreshnessPolicy],
auto_materialize_policy: Optional[AutoMaterializePolicy],
backfill_policy: Optional[BackfillPolicy],
resource_defs: Optional[Mapping[str, ResourceDefinition]],
check_specs: Optional[Sequence[AssetCheckSpec]],
) -> AssetsDefinition:
key_prefix = [key_prefix] if isinstance(key_prefix, str) else key_prefix
ins = ins or {}
asset_name = name or compose_fn.__name__
asset_ins = build_asset_ins(compose_fn, ins or {}, set())
out_asset_key = AssetKey(list(filter(None, [*(key_prefix or []), asset_name])))

keys_by_input_name = {input_name: asset_key for asset_key, (input_name, _) in asset_ins.items()}
partition_mappings = {
input_name: asset_in.partition_mapping
for input_name, asset_in in ins.items()
if asset_in.partition_mapping
}

check_specs_by_output_name = _validate_and_assign_output_names_to_check_specs(
check_specs, [out_asset_key]
)
check_outs_by_output_name: Mapping[str, GraphOut] = {
output_name: GraphOut() for output_name in check_specs_by_output_name.keys()
}

combined_outs_by_output_name: Mapping = {
"result": GraphOut(),
**check_outs_by_output_name,
}

op_graph = graph(
name=out_asset_key.to_python_identifier(),
description=description,
config=config,
ins={input_name: GraphIn() for _, (input_name, _) in asset_ins.items()},
out=combined_outs_by_output_name,
)(compose_fn)
return AssetsDefinition.from_graph(
op_graph,
keys_by_input_name=keys_by_input_name,
keys_by_output_name={"result": out_asset_key},
partitions_def=partitions_def,
partition_mappings=partition_mappings if partition_mappings else None,
group_name=group_name,
metadata_by_output_name={"result": metadata} if metadata else None,
freshness_policies_by_output_name=(
{"result": freshness_policy} if freshness_policy else None
),
auto_materialize_policies_by_output_name=(
{"result": auto_materialize_policy} if auto_materialize_policy else None
),
backfill_policy=backfill_policy,
descriptions_by_output_name={"result": description} if description else None,
resource_defs=resource_defs,
check_specs=check_specs,
)


def graph_multi_asset(
*,
outs: Mapping[str, AssetOut],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
asset,
asset_check,
)
from dagster._core.definitions.asset_checks import build_blocking_asset_check
from dagster._core.definitions.asset_checks import build_asset_with_blocking_check
from dagster._core.definitions.asset_in import AssetIn


Expand Down Expand Up @@ -45,7 +45,7 @@ def fail_check_if_tagged(context):
)


blocking_asset = build_blocking_asset_check(
blocking_asset = build_asset_with_blocking_check(
asset_def=my_asset, checks=[pass_check, fail_check_if_tagged]
)

Expand Down Expand Up @@ -114,7 +114,7 @@ def fail_check_if_tagged_2(context, my_asset_with_managed_input):
)


blocking_asset_with_managed_input = build_blocking_asset_check(
blocking_asset_with_managed_input = build_asset_with_blocking_check(
asset_def=my_asset_with_managed_input, checks=[fail_check_if_tagged_2]
)

Expand Down

0 comments on commit 2fec27e

Please sign in to comment.