Skip to content

Commit

Permalink
fix inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
johannkm committed Sep 26, 2023
1 parent 342c406 commit eb1f7f0
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 25 deletions.
14 changes: 12 additions & 2 deletions python_modules/dagster/dagster/_core/definitions/asset_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ def build_blocking_asset_check(

check_output_names = [c.get_python_identifier() for c in check_specs]

@op(ins={"materialization": In(Any), "check_evaluations": In(Nothing)})
check.invariant(len(asset_def.op.output_defs) == 1)
asset_out_type = asset_def.op.output_defs[0].dagster_type

@op(ins={"materialization": In(asset_out_type), "check_evaluations": In(Nothing)})
def fan_in_checks_and_materialization(context: OpExecutionContext, materialization):
yield Output(materialization)

Expand All @@ -181,9 +184,16 @@ def fan_in_checks_and_materialization(context: OpExecutionContext, materializati
@graph_asset(
name=asset_def.key.path[-1],
key_prefix=asset_def.key.path[:-1] if len(asset_def.key.path) > 1 else None,
group_name=asset_def.group_names_by_key.get(asset_def.key),
partitions_def=asset_def.partitions_def,
check_specs=check_specs,
description=asset_def.descriptions_by_key.get(asset_def.key),
ins={name: AssetIn(key) for name, key in asset_def.keys_by_input_name.items()}
ins={name: AssetIn(key) for name, key in asset_def.keys_by_input_name.items()},
resource_defs=asset_def.resource_defs,
metadata=asset_def.metadata_by_key.get(asset_def.key),
freshness_policy=asset_def.freshness_policies_by_key.get(asset_def.key),
auto_materialize_policy=asset_def.auto_materialize_policies_by_key.get(asset_def.key),
backfill_policy=asset_def.backfill_policy,
)
def blocking_asset(**kwargs):
asset_result = asset_def.op.with_replaced_properties(name="asset_op")(**kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ def execute_assets_and_checks(
job_def = defs.get_implicit_global_asset_job_def()
return job_def.execute_in_process(raise_on_error=raise_on_error, instance=instance, tags=tags)


@asset
def upstream_asset():
pass
return "foo"


@asset(deps=[upstream_asset])
def my_asset():
Expand Down Expand Up @@ -52,6 +54,7 @@ def fail_check_if_tagged(context):
def downstream_asset():
pass


def test_check_pass():
result = execute_assets_and_checks(
assets=[upstream_asset, blocking_asset, downstream_asset], raise_on_error=False
Expand All @@ -76,7 +79,9 @@ def test_check_pass():

def test_check_fail_and_block():
result = execute_assets_and_checks(
assets=[upstream_asset, blocking_asset, downstream_asset], raise_on_error=False, tags={"fail_check": "true"}
assets=[upstream_asset, blocking_asset, downstream_asset],
raise_on_error=False,
tags={"fail_check": "true"},
)
assert not result.success

Expand All @@ -95,68 +100,71 @@ def test_check_fail_and_block():
assert materialization_events[1].asset_key == AssetKey(["my_asset"])



@asset
def my_asset_with_managed_input(upstream_asset):
pass

assert upstream_asset == "foo"
return "bar"


@asset_check(asset="my_asset_with_managed_input")
def fail_check_if_tagged_2(context):
def fail_check_if_tagged_2(context, my_asset_with_managed_input):
assert my_asset_with_managed_input == "bar"
return AssetCheckResult(
success=not context.has_tag("fail_check"), check_name="fail_check_if_tagged"
success=not context.has_tag("fail_check"), check_name="fail_check_if_tagged_2"
)


blocking_asset_with_managed_input = build_blocking_asset_check(
asset_def=my_asset_with_managed_input, checks=[fail_check_if_tagged_2]
)


@asset(ins={"input_asset": AssetIn(blocking_asset_with_managed_input.key)})
def downstream_asset_2(input_asset):
pass
assert input_asset == "bar"


def test_check_pass_with_inputs():
result = execute_assets_and_checks(
assets=[upstream_asset, blocking_asset_with_managed_input, downstream_asset_2], raise_on_error=False
assets=[upstream_asset, blocking_asset_with_managed_input, downstream_asset_2],
raise_on_error=False,
)
assert result.success

check_evals = result.get_asset_check_evaluations()
assert len(check_evals) == 2
assert len(check_evals) == 1
check_evals_by_name = {check_eval.check_name: check_eval for check_eval in check_evals}
assert check_evals_by_name["pass_check"].success
assert check_evals_by_name["pass_check"].asset_key == AssetKey(["my_asset_with_managed_input"])
assert check_evals_by_name["fail_check_if_tagged"].success
assert check_evals_by_name["fail_check_if_tagged"].asset_key == AssetKey(["my_asset_with_managed_input"])
assert check_evals_by_name["fail_check_if_tagged_2"].success
assert check_evals_by_name["fail_check_if_tagged_2"].asset_key == AssetKey(
["my_asset_with_managed_input"]
)

# downstream asset materializes
materialization_events = result.get_asset_materialization_events()
assert len(materialization_events) == 3
assert materialization_events[0].asset_key == AssetKey(["upstream_asset"])
assert materialization_events[1].asset_key == AssetKey(["my_asset_with_managed_input"])
assert materialization_events[2].asset_key == AssetKey(["downstream_asset"])
assert materialization_events[2].asset_key == AssetKey(["downstream_asset_2"])


def test_check_fail_and_block_with_inputs():
result = execute_assets_and_checks(
assets=[upstream_asset, blocking_asset_with_managed_input, downstream_asset_2], raise_on_error=False, tags={"fail_check": "true"}
assets=[upstream_asset, blocking_asset_with_managed_input, downstream_asset_2],
raise_on_error=False,
tags={"fail_check": "true"},
)
assert not result.success

check_evals = result.get_asset_check_evaluations()
assert len(check_evals) == 2
assert len(check_evals) == 1
check_evals_by_name = {check_eval.check_name: check_eval for check_eval in check_evals}
assert check_evals_by_name["pass_check"].success
assert check_evals_by_name["pass_check"].asset_key == AssetKey(["my_asset_with_managed_input"])
assert not check_evals_by_name["fail_check_if_tagged"].success
assert check_evals_by_name["fail_check_if_tagged"].asset_key == AssetKey(["my_asset_with_managed_input"])
assert not check_evals_by_name["fail_check_if_tagged_2"].success
assert check_evals_by_name["fail_check_if_tagged_2"].asset_key == AssetKey(
["my_asset_with_managed_input"]
)

# downstream asset should not have been materialized
materialization_events = result.get_asset_materialization_events()
assert len(materialization_events) == 2
assert materialization_events[0].asset_key == AssetKey(["upstream_asset"])
assert materialization_events[1].asset_key == AssetKey(["my_asset_with_managed_input"])


0 comments on commit eb1f7f0

Please sign in to comment.