From eb1f7f04383ed2420990f3bed1bbe5e367022d8d Mon Sep 17 00:00:00 2001 From: Johann Miller Date: Wed, 20 Sep 2023 18:11:02 -0400 Subject: [PATCH] fix inputs --- .../dagster/_core/definitions/asset_checks.py | 14 ++++- .../test_blocking_asset_checks.py | 54 +++++++++++-------- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/asset_checks.py b/python_modules/dagster/dagster/_core/definitions/asset_checks.py index 176560332d1f1..d0bd4ae9da4aa 100644 --- a/python_modules/dagster/dagster/_core/definitions/asset_checks.py +++ b/python_modules/dagster/dagster/_core/definitions/asset_checks.py @@ -159,7 +159,10 @@ def build_blocking_asset_check( check_output_names = [c.get_python_identifier() for c in check_specs] - @op(ins={"materialization": In(Any), "check_evaluations": In(Nothing)}) + check.invariant(len(asset_def.op.output_defs) == 1) + asset_out_type = asset_def.op.output_defs[0].dagster_type + + @op(ins={"materialization": In(asset_out_type), "check_evaluations": In(Nothing)}) def fan_in_checks_and_materialization(context: OpExecutionContext, materialization): yield Output(materialization) @@ -181,9 +184,16 @@ def fan_in_checks_and_materialization(context: OpExecutionContext, materializati @graph_asset( name=asset_def.key.path[-1], key_prefix=asset_def.key.path[:-1] if len(asset_def.key.path) > 1 else None, + group_name=asset_def.group_names_by_key.get(asset_def.key), + partitions_def=asset_def.partitions_def, check_specs=check_specs, description=asset_def.descriptions_by_key.get(asset_def.key), - ins={name: AssetIn(key) for name, key in asset_def.keys_by_input_name.items()} + ins={name: AssetIn(key) for name, key in asset_def.keys_by_input_name.items()}, + resource_defs=asset_def.resource_defs, + metadata=asset_def.metadata_by_key.get(asset_def.key), + freshness_policy=asset_def.freshness_policies_by_key.get(asset_def.key), + auto_materialize_policy=asset_def.auto_materialize_policies_by_key.get(asset_def.key), + backfill_policy=asset_def.backfill_policy, ) def blocking_asset(**kwargs): asset_result = asset_def.op.with_replaced_properties(name="asset_op")(**kwargs) diff --git a/python_modules/dagster/dagster_tests/definitions_tests/test_blocking_asset_checks.py b/python_modules/dagster/dagster_tests/definitions_tests/test_blocking_asset_checks.py index 2b9fedca1846c..0ef0e27ad1b91 100644 --- a/python_modules/dagster/dagster_tests/definitions_tests/test_blocking_asset_checks.py +++ b/python_modules/dagster/dagster_tests/definitions_tests/test_blocking_asset_checks.py @@ -22,9 +22,11 @@ def execute_assets_and_checks( job_def = defs.get_implicit_global_asset_job_def() return job_def.execute_in_process(raise_on_error=raise_on_error, instance=instance, tags=tags) + @asset def upstream_asset(): - pass + return "foo" + @asset(deps=[upstream_asset]) def my_asset(): @@ -52,6 +54,7 @@ def fail_check_if_tagged(context): def downstream_asset(): pass + def test_check_pass(): result = execute_assets_and_checks( assets=[upstream_asset, blocking_asset, downstream_asset], raise_on_error=False @@ -76,7 +79,9 @@ def test_check_pass(): def test_check_fail_and_block(): result = execute_assets_and_checks( - assets=[upstream_asset, blocking_asset, downstream_asset], raise_on_error=False, tags={"fail_check": "true"} + assets=[upstream_asset, blocking_asset, downstream_asset], + raise_on_error=False, + tags={"fail_check": "true"}, ) assert not result.success @@ -95,17 +100,17 @@ def test_check_fail_and_block(): assert materialization_events[1].asset_key == AssetKey(["my_asset"]) - @asset def my_asset_with_managed_input(upstream_asset): - pass - + assert upstream_asset == "foo" + return "bar" @asset_check(asset="my_asset_with_managed_input") -def fail_check_if_tagged_2(context): +def fail_check_if_tagged_2(context, my_asset_with_managed_input): + assert my_asset_with_managed_input == "bar" return AssetCheckResult( - success=not context.has_tag("fail_check"), check_name="fail_check_if_tagged" + success=not context.has_tag("fail_check"), check_name="fail_check_if_tagged_2" ) @@ -113,50 +118,53 @@ def fail_check_if_tagged_2(context): asset_def=my_asset_with_managed_input, checks=[fail_check_if_tagged_2] ) + @asset(ins={"input_asset": AssetIn(blocking_asset_with_managed_input.key)}) def downstream_asset_2(input_asset): - pass + assert input_asset == "bar" + def test_check_pass_with_inputs(): result = execute_assets_and_checks( - assets=[upstream_asset, blocking_asset_with_managed_input, downstream_asset_2], raise_on_error=False + assets=[upstream_asset, blocking_asset_with_managed_input, downstream_asset_2], + raise_on_error=False, ) assert result.success check_evals = result.get_asset_check_evaluations() - assert len(check_evals) == 2 + assert len(check_evals) == 1 check_evals_by_name = {check_eval.check_name: check_eval for check_eval in check_evals} - assert check_evals_by_name["pass_check"].success - assert check_evals_by_name["pass_check"].asset_key == AssetKey(["my_asset_with_managed_input"]) - assert check_evals_by_name["fail_check_if_tagged"].success - assert check_evals_by_name["fail_check_if_tagged"].asset_key == AssetKey(["my_asset_with_managed_input"]) + assert check_evals_by_name["fail_check_if_tagged_2"].success + assert check_evals_by_name["fail_check_if_tagged_2"].asset_key == AssetKey( + ["my_asset_with_managed_input"] + ) # downstream asset materializes materialization_events = result.get_asset_materialization_events() assert len(materialization_events) == 3 assert materialization_events[0].asset_key == AssetKey(["upstream_asset"]) assert materialization_events[1].asset_key == AssetKey(["my_asset_with_managed_input"]) - assert materialization_events[2].asset_key == AssetKey(["downstream_asset"]) + assert materialization_events[2].asset_key == AssetKey(["downstream_asset_2"]) def test_check_fail_and_block_with_inputs(): result = execute_assets_and_checks( - assets=[upstream_asset, blocking_asset_with_managed_input, downstream_asset_2], raise_on_error=False, tags={"fail_check": "true"} + assets=[upstream_asset, blocking_asset_with_managed_input, downstream_asset_2], + raise_on_error=False, + tags={"fail_check": "true"}, ) assert not result.success check_evals = result.get_asset_check_evaluations() - assert len(check_evals) == 2 + assert len(check_evals) == 1 check_evals_by_name = {check_eval.check_name: check_eval for check_eval in check_evals} - assert check_evals_by_name["pass_check"].success - assert check_evals_by_name["pass_check"].asset_key == AssetKey(["my_asset_with_managed_input"]) - assert not check_evals_by_name["fail_check_if_tagged"].success - assert check_evals_by_name["fail_check_if_tagged"].asset_key == AssetKey(["my_asset_with_managed_input"]) + assert not check_evals_by_name["fail_check_if_tagged_2"].success + assert check_evals_by_name["fail_check_if_tagged_2"].asset_key == AssetKey( + ["my_asset_with_managed_input"] + ) # downstream asset should not have been materialized materialization_events = result.get_asset_materialization_events() assert len(materialization_events) == 2 assert materialization_events[0].asset_key == AssetKey(["upstream_asset"]) assert materialization_events[1].asset_key == AssetKey(["my_asset_with_managed_input"]) - -