Skip to content

Commit

Permalink
[ext] add asset checks to ext
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Sep 14, 2023
1 parent 39d10f2 commit 736aa11
Show file tree
Hide file tree
Showing 10 changed files with 243 additions and 8 deletions.
67 changes: 66 additions & 1 deletion python_modules/dagster-ext/dagster_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Generic,
Iterator,
Literal,
Expand All @@ -26,6 +27,7 @@
Type,
TypedDict,
TypeVar,
Union,
cast,
get_args,
)
Expand Down Expand Up @@ -98,6 +100,16 @@ class ExtDataProvenance(TypedDict):
is_user_provided: bool


# Severity levels that may accompany a reported asset check result.
ExtAssetCheckSeverity = Literal["WARN", "ERROR"]

# A metadata value supplied without an explicit metadata type attached.
ExtMetadataRawValue = Union[int, float, str, Mapping[str, Any], Sequence[Any], bool, None]


# Metadata value paired with an optional explicit metadata type. The field names of this
# TypedDict define the exact key set a dict-valued metadata entry must have.
class ExtMetadataValue(TypedDict):
# Explicit metadata type, or None when no type was specified.
metadata_type: Optional["ExtMetadataType"]
# The underlying raw metadata value.
value: ExtMetadataRawValue


ExtMetadataType = Literal[
"text",
"url",
Expand Down Expand Up @@ -248,6 +260,30 @@ def _assert_opt_param_value(
return value


def _normalize_param_metadata(
    metadata: Mapping[str, Union[ExtMetadataRawValue, ExtMetadataValue]], method: str, param: str
) -> Mapping[str, Union[ExtMetadataRawValue, ExtMetadataValue]]:
    """Convert every metadata entry to the ``{value, metadata_type}`` form.

    Entries that are dicts must have exactly the keys declared on ``ExtMetadataValue``
    and are passed through unchanged; any other value is wrapped with a
    ``metadata_type`` of ``None``.

    Args:
        metadata: The user-supplied metadata; checked against ``dict`` before use.
        method: Name of the calling method, used in error messages.
        param: Name of the parameter being validated, used in error messages.

    Returns:
        A new dict mapping each key to its normalized metadata value.

    Raises:
        DagsterExtError: If a key is not a string, or a dict-valued entry does not have
            exactly the expected keys.
    """
    _assert_param_type(metadata, dict, method, param)
    # The schema keys never change, so compute them once rather than per entry.
    expected_fields = set(ExtMetadataValue.__annotations__.keys())
    normalized: Dict[str, ExtMetadataValue] = {}
    for key, value in metadata.items():
        if not isinstance(key, str):
            raise DagsterExtError(
                f"Invalid type for parameter `{param}` of `{method}`. Expected a dict with string"
                f" keys, got a key `{key}` of type `{type(key)}`."
            )
        if not isinstance(value, dict):
            # Raw value: wrap it with no explicit metadata type.
            normalized[key] = {"value": value, "metadata_type": None}
            continue
        if set(value.keys()) != expected_fields:
            raise DagsterExtError(
                f"Invalid type for parameter `{param}` of `{method}`. Expected a dict with"
                " string keys and values that are either raw metadata values or dictionaries"
                f" with schema `{{value: ..., metadata_type: ...}}`. Got a value `{value}`."
            )
        normalized[key] = cast(ExtMetadataValue, value)
    return normalized


def _assert_param_json_serializable(value: _T, method: str, param: str) -> _T:
try:
json.dumps(value)
Expand Down Expand Up @@ -701,7 +737,7 @@ def extras(self) -> Mapping[str, Any]:
def report_asset_metadata(
self,
label: str,
value: Any,
value: ExtMetadataRawValue,
metadata_type: Optional[ExtMetadataType] = None,
asset_key: Optional[str] = None,
) -> None:
Expand Down Expand Up @@ -729,6 +765,35 @@ def report_asset_data_version(self, data_version: str, asset_key: Optional[str]
"report_asset_data_version", {"asset_key": asset_key, "data_version": data_version}
)

def report_asset_check_result(
    self,
    check_name: str,
    success: bool,
    severity: ExtAssetCheckSeverity = "ERROR",
    metadata: Optional[Mapping[str, Union[ExtMetadataRawValue, ExtMetadataValue]]] = None,
    asset_key: Optional[str] = None,
) -> None:
    """Report the result of an asset check by writing a ``report_asset_check`` message.

    Args:
        check_name: Name of the check being reported; must be a ``str``.
        success: Whether the check passed; must be a ``bool``.
        severity: Severity of the check, one of the ``ExtAssetCheckSeverity`` literals.
        metadata: Optional metadata for the check. Values may be raw metadata values or
            ``{value, metadata_type}`` dicts; all are normalized to the latter form.
            A falsy value (``None`` or an empty mapping) is sent as ``None``.
        asset_key: Asset the check belongs to, resolved through
            ``_resolve_optionally_passed_asset_key`` against the context data.

    Raises:
        Whatever the validation helpers raise for an invalid argument.
    """
    method = "report_asset_check_result"
    asset_key = _resolve_optionally_passed_asset_key(self._data, asset_key, method)
    check_name = _assert_param_type(check_name, str, method, "check_name")
    success = _assert_param_type(success, bool, method, "success")
    # Validate severity here, as `log` does for `level`, so an invalid value is rejected
    # at the call site instead of being written into the message.
    severity = _assert_param_value(
        severity, get_args(ExtAssetCheckSeverity), method, "severity"
    )
    metadata = _normalize_param_metadata(metadata, method, "metadata") if metadata else None
    self._write_message(
        "report_asset_check",
        {
            "asset_key": asset_key,
            "check_name": check_name,
            "success": success,
            "metadata": metadata,
            "severity": severity,
        },
    )

def log(self, message: str, level: str = "info") -> None:
message = _assert_param_type(message, str, "log", "asset_key")
level = _assert_param_value(level, ["info", "warning", "error"], "log", "level")
Expand Down
15 changes: 15 additions & 0 deletions python_modules/dagster-ext/dagster_ext_tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,20 @@ def test_single_asset_context():
context.report_asset_metadata("bar", "boo")
context.report_asset_metadata("baz", 2, "int")
context.report_asset_data_version("bar")
context.report_asset_check_result(
"foo_check",
True,
metadata={
"meta_1": 1,
"meta_2": {"value": "foo", "metadata_type": "text"},
},
)

_assert_unknown_asset_key(context, "report_asset_metadata", "bar", "baz", asset_key="fake")
_assert_unknown_asset_key(context, "report_asset_data_version", "bar", asset_key="fake")
_assert_unknown_asset_key(
context, "report_asset_check_result", "foo_check", True, asset_key="fake"
)


def test_multi_asset_context():
Expand Down Expand Up @@ -114,6 +125,10 @@ def test_multi_asset_context():
_assert_unknown_asset_key(context, "report_asset_metadata", "bar", "baz", asset_key="fake")
_assert_undefined_asset_key(context, "report_asset_data_version", "bar")
_assert_unknown_asset_key(context, "report_asset_data_version", "bar", asset_key="fake")
_assert_undefined_asset_key(context, "report_asset_check_result", "foo_check", True)
_assert_unknown_asset_key(
context, "report_asset_check_result", "foo_check", True, asset_key="fake"
)


def test_no_partition_context():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import boto3
import pytest
from dagster._core.definitions.asset_check_spec import AssetCheckSpec
from dagster._core.definitions.data_version import (
DATA_VERSION_IS_USER_PROVIDED_TAG,
DATA_VERSION_TAG,
Expand Down Expand Up @@ -43,6 +44,7 @@
ext_protocol,
)
from dagster._core.instance_for_test import instance_for_test
from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus
from dagster_aws.ext import ExtS3MessageReader
from moto.server import ThreadedMotoServer

Expand Down Expand Up @@ -93,6 +95,15 @@ def script_fn():
time.sleep(0.1) # sleep to make sure that we encompass multiple intervals for blob store IO
context.report_asset_metadata("bar", context.get_extra("bar"), metadata_type="md")
context.report_asset_data_version("alpha")
context.report_asset_check_result(
"foo_check",
success=True,
severity="WARN",
metadata={
"meta_1": 1,
"meta_2": {"value": "foo", "metadata_type": "text"},
},
)

with temp_script(script_fn) as script_path:
yield script_path
Expand Down Expand Up @@ -144,7 +155,7 @@ def test_ext_subprocess(
else:
assert False, "Unreachable"

@asset
@asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["foo"]))])
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
Expand Down Expand Up @@ -177,6 +188,14 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
captured = capsys.readouterr()
assert re.search(r"dagster - INFO - [^\n]+ - hello world\n", captured.err, re.MULTILINE)

asset_check_executions = instance.event_log_storage.get_asset_check_executions(
asset_key=foo.key,
check_name="foo_check",
limit=1,
)
assert len(asset_check_executions) == 1
assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED


def test_ext_typed_metadata():
def script_fn():
Expand Down Expand Up @@ -301,7 +320,7 @@ def script_fn():


def test_ext_no_client(external_script):
@asset
@asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["subproc_run"]))])
def subproc_run(context: AssetExecutionContext):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
Expand All @@ -325,3 +344,11 @@ def subproc_run(context: AssetExecutionContext):
assert mat.asset_materialization.tags
assert mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"
assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG]

asset_check_executions = instance.event_log_storage.get_asset_check_executions(
asset_key=subproc_run.key,
check_name="foo_check",
limit=1,
)
assert len(asset_check_executions) == 1
assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,15 @@ def asset_checks_def_for_node(
def asset_checks_defs(self) -> Iterable[AssetChecksDefinition]:
return self.asset_checks_defs_by_node_handle.values()

def get_asset_check_for_output_name(self, output_name: str) -> Optional[AssetCheckHandle]:
    """Return the asset check handle whose node output is named ``output_name``.

    Entries of ``node_output_handles_by_asset_check_handle`` are scanned in iteration
    order and the first match wins; ``None`` is returned when nothing matches.
    """
    handles_by_check = self.node_output_handles_by_asset_check_handle
    return next(
        (
            check_handle
            for check_handle, output_handle in handles_by_check.items()
            if output_handle.output_name == output_name
        ),
        None,
    )

def get_output_name_for_asset_check(self, asset_check_handle: AssetCheckHandle) -> str:
"""Output name in the leaf op."""
return self.node_output_handles_by_asset_check_handle[asset_check_handle].output_name
Expand Down
46 changes: 46 additions & 0 deletions python_modules/dagster/dagster/_core/execution/context/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import dagster._check as check
from dagster._annotations import deprecated, experimental, public
from dagster._core.definitions.asset_check_result import AssetCheckResult
from dagster._core.definitions.asset_check_spec import AssetCheckSpec
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.data_version import (
Expand Down Expand Up @@ -464,6 +465,51 @@ def get_output_metadata(
output_name=output_name, mapping_key=mapping_key
)

@public
@experimental
def add_asset_check_result(self, asset_check_result: AssetCheckResult) -> None:
    """Add an asset check result for an asset being materialized in the current step.

    The result is validated to be an ``AssetCheckResult`` and then recorded on the
    underlying step execution context.

    Args:
        asset_check_result (AssetCheckResult): The asset check result to add.

    **Examples:**

    .. code-block:: python

        from dagster import (
            AssetCheckResult,
            AssetCheckSeverity,
            AssetCheckSpec,
            AssetKey,
            asset,
        )

        @asset(check_specs=[AssetCheckSpec(name="my_check", asset=AssetKey("foo_asset"))])
        def foo_asset(context):
            ...
            context.add_asset_check_result(
                AssetCheckResult(
                    asset_key=AssetKey("foo_asset"),
                    check_name="my_check",
                    success=True,
                    severity=AssetCheckSeverity.WARN,
                    metadata={"foo": "bar"},
                )
            )
            ...
    """
    check.inst_param(asset_check_result, "asset_check_result", AssetCheckResult)
    self._step_execution_context.add_result_object(asset_check_result)

def has_asset_check_result_for_output(self, output_name: str) -> bool:
    """Return whether an asset check result was added for the check behind ``output_name``.

    The output name is mapped to an asset check handle through the job's asset layer.
    If no check is associated with the output, the answer is ``False``. Otherwise the
    step's recorded result objects are searched for an ``AssetCheckResult`` whose asset
    key and check name both match the handle.
    """
    check_handle = self.job_def.asset_layer.get_asset_check_for_output_name(output_name)
    if check_handle is None:
        return False
    return any(
        isinstance(result, AssetCheckResult)
        and result.asset_key == check_handle.asset_key
        and result.check_name == check_handle.name
        for result in self.get_step_execution_context().result_objects
    )

def get_step_execution_context(self) -> StepExecutionContext:
"""Allows advanced users (e.g. framework authors) to punch through to the underlying
step execution context.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)

import dagster._check as check
from dagster._core.definitions.asset_check_result import AssetCheckResult
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.composition import PendingNodeInvocation
from dagster._core.definitions.decorators.op_decorator import DecoratedOpFunction
Expand Down Expand Up @@ -48,6 +49,7 @@
DagsterInvariantViolationError,
)
from dagster._core.execution.build_resources import build_resources, wrap_resources_for_execution
from dagster._core.execution.plan.compute import OpOutputUnion
from dagster._core.instance import DagsterInstance
from dagster._core.log_manager import DagsterLogManager
from dagster._core.storage.dagster_run import DagsterRun
Expand Down Expand Up @@ -116,6 +118,7 @@ def __init__(
self._partition_key_range = partition_key_range
self._user_events: List[UserEvent] = []
self._output_metadata: Dict[str, Any] = {}
self._result_objects: List[OpOutputUnion] = []

self._assets_def = check.opt_inst_param(assets_def, "assets_def", AssetsDefinition)

Expand Down Expand Up @@ -324,6 +327,7 @@ def bind(
),
user_events=self._user_events,
output_metadata=self._output_metadata,
result_objects=self._result_objects,
mapping_key=self._mapping_key,
partition_key=self._partition_key,
partition_key_range=self._partition_key_range,
Expand Down Expand Up @@ -406,6 +410,7 @@ class BoundOpExecutionContext(OpExecutionContext):
_user_events: List[UserEvent]
_seen_outputs: Dict[str, Union[str, Set[str]]]
_output_metadata: Dict[str, Any]
_result_objects: List[OpOutputUnion]
_mapping_key: Optional[str]
_partition_key: Optional[str]
_partition_key_range: Optional[PartitionKeyRange]
Expand All @@ -425,6 +430,7 @@ def __init__(
alias: Optional[str],
user_events: List[UserEvent],
output_metadata: Dict[str, Any],
result_objects: List[OpOutputUnion],
mapping_key: Optional[str],
partition_key: Optional[str],
partition_key_range: Optional[PartitionKeyRange],
Expand All @@ -443,6 +449,7 @@ def __init__(
self._user_events = user_events
self._seen_outputs = {}
self._output_metadata = output_metadata
self._result_objects = result_objects
self._mapping_key = mapping_key
self._partition_key = partition_key
self._partition_key_range = partition_key_range
Expand Down Expand Up @@ -714,6 +721,15 @@ def add_metadata_two_outputs(context) -> Tuple[str, int]:
else:
self._output_metadata[output_name] = metadata

def add_asset_check_result(self, asset_check_result: AssetCheckResult) -> None:
    """Reject the call: asset check results cannot be added during op invocation.

    Raises:
        DagsterInvariantViolationError: Always.
    """
    raise DagsterInvariantViolationError(
        "`add_asset_check_result` is not supported during op invocation"
    )

def has_asset_check_result_for_output(self, output_name: str) -> bool:
    """Return whether an asset check result exists for ``output_name``.

    Always ``False``: `add_asset_check_result` can't be used with invocation, so no
    result can have been recorded.
    """
    return False


def build_op_context(
resources: Optional[Mapping[str, Any]] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from dagster._core.definitions.dependency import NodeHandle
from dagster._core.definitions.resource_definition import Resources
from dagster._core.event_api import EventLogRecord
from dagster._core.execution.plan.compute import OpOutputUnion
from dagster._core.execution.plan.plan import ExecutionPlan
from dagster._core.execution.plan.state import KnownExecutionState
from dagster._core.instance import DagsterInstance
Expand Down Expand Up @@ -553,6 +554,7 @@ def __init__(
self._step_output_capture = {}

self._output_metadata: Dict[str, Any] = {}
self._result_objects: List["OpOutputUnion"] = []
self._seen_outputs: Dict[str, Union[str, Set[str]]] = {}

self._input_asset_version_info: Dict[AssetKey, Optional["InputAssetVersionInfo"]] = {}
Expand Down Expand Up @@ -790,6 +792,13 @@ def get_output_metadata(
return metadata.get(mapping_key)
return metadata

def add_result_object(self, obj: "OpOutputUnion") -> None:
    """Append ``obj`` to this context's list of recorded result objects."""
    recorded = self._result_objects
    recorded.append(obj)

@property
def result_objects(self) -> Sequence["OpOutputUnion"]:
    """Result objects recorded so far, in insertion order.

    The underlying list itself is returned, not a copy.
    """
    recorded = self._result_objects
    return recorded

def _get_source_run_id_from_logs(self, step_output_handle: StepOutputHandle) -> Optional[str]:
# walk through event logs to find the right run_id based on the run lineage

Expand Down
Loading

0 comments on commit 736aa11

Please sign in to comment.