Skip to content

Commit

Permalink
[ext] remove report_asset_metadata, report_asset_data_version (#16683)
Browse files Browse the repository at this point in the history
## Summary & Motivation

Remove `report_asset_metadata` and `report_asset_data_version` in favor
of a single `report_asset_materialization` method, which requires
reporting all of the data associated with a materialization at once.
Calling `report_asset_materialization` twice for the same asset is an
error.

## How I Tested These Changes

Updated unit tests.
  • Loading branch information
smackesey authored Sep 21, 2023
1 parent 7aa1b1e commit c402db0
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ def query(self, query_str: str) -> None:

client = SomeSqlClient()
client.query(sql)
context.report_asset_metadata("sql", sql)
context.report_asset_materialization(metadata={"sql": sql})
context.log(f"Ran {sql}")
90 changes: 66 additions & 24 deletions python_modules/dagster-ext/dagster_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,20 @@
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Generic,
Iterator,
Literal,
Mapping,
Optional,
Sequence,
Set,
TextIO,
Type,
TypedDict,
TypeVar,
Union,
cast,
get_args,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -98,7 +100,19 @@ class ExtDataProvenance(TypedDict):
is_user_provided: bool


ExtMetadataRawValue = Union[int, float, str, Mapping[str, Any], Sequence[Any], bool, None]


class ExtMetadataValue(TypedDict):
    """A metadata entry that pairs a raw value with an explicit metadata type.

    The key set of this TypedDict (``type`` and ``raw_value``) is what
    ``_METADATA_VALUE_KEYS`` is derived from, and dict-valued metadata passed
    to ``_normalize_param_metadata`` must match it exactly.
    """

    # Metadata type tag; one of the `ExtMetadataType` literals.
    type: "ExtMetadataType"
    # The reported value, restricted to the `ExtMetadataRawValue` union.
    raw_value: ExtMetadataRawValue


# Infer the type from the raw value on the orchestration end
EXT_METADATA_TYPE_INFER = "__infer__"

ExtMetadataType = Literal[
"__infer__",
"text",
"url",
"path",
Expand Down Expand Up @@ -148,7 +162,10 @@ def _assert_single_asset(data: ExtContextData, key: str) -> None:


def _resolve_optionally_passed_asset_key(
data: ExtContextData, asset_key: Optional[str], method: str
data: ExtContextData,
asset_key: Optional[str],
method: str,
already_materialized_assets: Set[str],
) -> str:
asset_keys = _assert_defined_asset_property(data["asset_keys"], method)
asset_key = _assert_opt_param_type(asset_key, str, method, "asset_key")
Expand All @@ -163,6 +180,11 @@ def _resolve_optionally_passed_asset_key(
" targets multiple assets."
)
asset_key = asset_keys[0]
if asset_key in already_materialized_assets:
raise DagsterExtError(
f"Calling `{method}` with asset key `{asset_key}` is undefined. Asset has already been"
" materialized, so no additional data can be reported for it."
)
return asset_key


Expand Down Expand Up @@ -259,6 +281,33 @@ def _assert_param_json_serializable(value: _T, method: str, param: str) -> _T:
return value


_METADATA_VALUE_KEYS = frozenset(ExtMetadataValue.__annotations__.keys())


def _normalize_param_metadata(
    metadata: Mapping[str, Union[ExtMetadataRawValue, ExtMetadataValue]], method: str, param: str
) -> Mapping[str, ExtMetadataValue]:
    """Validate user-supplied metadata and coerce every entry to an `ExtMetadataValue`.

    Each value may be given either as a raw value or as a dict with exactly the
    keys `raw_value` and `type`. Raw values are wrapped with the
    `EXT_METADATA_TYPE_INFER` type tag; dict values are passed through after
    their key set and `type` tag have been checked.

    Args:
        metadata: The metadata dict passed by the caller.
        method: Name of the public method being validated (used in error messages).
        param: Name of the parameter being validated (used in error messages).

    Returns:
        A new dict mapping each key to an `ExtMetadataValue`.

    Raises:
        DagsterExtError: If a key is not a string, if a dict value does not have
            exactly the `ExtMetadataValue` keys, or if its `type` is not one of
            the `ExtMetadataType` literals.
    """
    # Imported locally so this function does not depend on the file-level
    # typing import list.
    from typing import get_args

    _assert_param_type(metadata, dict, method, param)
    valid_types = get_args(ExtMetadataType)
    new_metadata: Dict[str, ExtMetadataValue] = {}
    for key, value in metadata.items():
        if not isinstance(key, str):
            raise DagsterExtError(
                f"Invalid type for parameter `{param}` of `{method}`. Expected a dict with string"
                f" keys, got a key `{key}` of type `{type(key)}`."
            )
        if isinstance(value, dict):
            # Any dict value is treated as an explicit {raw_value, type} pair.
            if set(value) != _METADATA_VALUE_KEYS:
                raise DagsterExtError(
                    f"Invalid type for parameter `{param}` of `{method}`. Expected a dict with"
                    " string keys and values that are either raw metadata values or dictionaries"
                    f" with schema `{{raw_value: ..., type: ...}}`. Got a value `{value}`."
                )
            # Reject unknown type tags here rather than letting them fail later
            # on the orchestration side, where the error is far from its cause.
            type_tag = value["type"]
            if type_tag not in valid_types:
                raise DagsterExtError(
                    f"Invalid value for `type` of metadata entry `{key}` in parameter `{param}` of"
                    f" `{method}`. Expected one of `{valid_types}`, got `{type_tag}`."
                )
            new_metadata[key] = cast(ExtMetadataValue, value)
        else:
            new_metadata[key] = {"raw_value": value, "type": EXT_METADATA_TYPE_INFER}
    return new_metadata


def _param_from_env_var(key: str) -> Any:
raw_value = os.environ.get(_param_name_to_env_var(key))
return decode_env_var(raw_value) if raw_value is not None else None
Expand Down Expand Up @@ -625,6 +674,7 @@ def __init__(
) -> None:
self._data = data
self._message_channel = message_channel
self._materialized_assets: set[str] = set()

def _write_message(self, method: str, params: Optional[Mapping[str, Any]] = None) -> None:
message = ExtMessage(method=method, params=params)
Expand Down Expand Up @@ -727,36 +777,28 @@ def extras(self) -> Mapping[str, Any]:

# ##### WRITE

def report_asset_metadata(
def report_asset_materialization(
self,
label: str,
value: Any,
metadata_type: Optional[ExtMetadataType] = None,
metadata: Optional[Mapping[str, Union[ExtMetadataRawValue, ExtMetadataValue]]] = None,
data_version: Optional[str] = None,
asset_key: Optional[str] = None,
) -> None:
):
asset_key = _resolve_optionally_passed_asset_key(
self._data, asset_key, "report_asset_metadata"
self._data, asset_key, "report_asset_materialization", self._materialized_assets
)
label = _assert_param_type(label, str, "report_asset_metadata", "label")
value = _assert_param_json_serializable(value, "report_asset_metadata", "value")
metadata_type = _assert_opt_param_value(
metadata_type, get_args(ExtMetadataType), "report_asset_metadata", "type"
)
self._write_message(
"report_asset_metadata",
{"asset_key": asset_key, "label": label, "value": value, "type": metadata_type},
)

def report_asset_data_version(self, data_version: str, asset_key: Optional[str] = None) -> None:
asset_key = _resolve_optionally_passed_asset_key(
self._data, asset_key, "report_asset_data_version"
metadata = (
_normalize_param_metadata(metadata, "report_asset_materialization", "metadata")
if metadata
else None
)
data_version = _assert_param_type(
data_version, str, "report_asset_data_version", "data_version"
data_version = _assert_opt_param_type(
data_version, str, "report_asset_materialization", "data_version"
)
self._write_message(
"report_asset_data_version", {"asset_key": asset_key, "data_version": data_version}
"report_asset_materialization",
{"asset_key": asset_key, "data_version": data_version, "metadata": metadata},
)
self._materialized_assets.add(asset_key)

def log(self, message: str, level: str = "info") -> None:
message = _assert_param_type(message, str, "log", "asset_key")
Expand Down
26 changes: 17 additions & 9 deletions python_modules/dagster-ext/dagster_ext_tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,15 @@ def test_single_asset_context():
assert context.code_version_by_asset_key == {"foo": "beta"}
assert context.provenance == foo_provenance
assert context.provenance_by_asset_key == {"foo": foo_provenance}
context.report_asset_metadata("bar", "boo")
context.report_asset_metadata("baz", 2, "int")
context.report_asset_data_version("bar")
context.report_asset_materialization(
metadata={
"bar": "boo",
"baz": {"raw_value": 2, "type": "int"},
},
data_version="bar",
)

_assert_unknown_asset_key(context, "report_asset_metadata", "bar", "baz", asset_key="fake")
_assert_unknown_asset_key(context, "report_asset_data_version", "bar", asset_key="fake")
_assert_unknown_asset_key(context, "report_asset_materialization", asset_key="fake")


def test_multi_asset_context():
Expand All @@ -110,10 +113,8 @@ def test_multi_asset_context():
_assert_undefined(context, "provenance")
assert context.provenance_by_asset_key == {"foo": foo_provenance, "bar": bar_provenance}

_assert_undefined_asset_key(context, "report_asset_metadata", "bar", "baz")
_assert_unknown_asset_key(context, "report_asset_metadata", "bar", "baz", asset_key="fake")
_assert_undefined_asset_key(context, "report_asset_data_version", "bar")
_assert_unknown_asset_key(context, "report_asset_data_version", "bar", asset_key="fake")
_assert_undefined_asset_key(context, "report_asset_materialization", "bar")
_assert_unknown_asset_key(context, "report_asset_materialization", "bar", asset_key="fake")


def test_no_partition_context():
Expand Down Expand Up @@ -162,3 +163,10 @@ def test_extras_context():
assert context.get_extra("foo") == "bar"
with pytest.raises(DagsterExtError, match="Extra `bar` is undefined"):
context.get_extra("bar")


def test_report_twice_materialized():
    """Reporting a materialization twice for the same asset must raise."""
    context = _make_external_execution_context(asset_keys=["foo"])
    # The first report must succeed. It is kept outside `pytest.raises` so that
    # an unexpected failure here cannot satisfy the assertion below.
    context.report_asset_materialization(asset_key="foo")
    with pytest.raises(DagsterExtError, match="already been materialized"):
        context.report_asset_materialization(asset_key="foo")
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,10 @@ def script_fn():
context = ExtContext.get()
context.log("hello world")
time.sleep(0.1) # sleep to make sure that we encompass multiple intervals for blob store IO
context.report_asset_metadata("bar", context.get_extra("bar"), metadata_type="md")
context.report_asset_data_version("alpha")
context.report_asset_materialization(
metadata={"bar": {"raw_value": context.get_extra("bar"), "type": "md"}},
data_version="alpha",
)

with temp_script(script_fn) as script_path:
yield script_path
Expand Down Expand Up @@ -183,19 +185,23 @@ def script_fn():
from dagster_ext import init_dagster_ext

context = init_dagster_ext()
context.report_asset_metadata("untyped_meta", "bar")
context.report_asset_metadata("text_meta", "bar", metadata_type="text")
context.report_asset_metadata("url_meta", "http://bar.com", metadata_type="url")
context.report_asset_metadata("path_meta", "/bar", metadata_type="path")
context.report_asset_metadata("notebook_meta", "/bar.ipynb", metadata_type="notebook")
context.report_asset_metadata("json_meta", ["bar"], metadata_type="json")
context.report_asset_metadata("md_meta", "bar", metadata_type="md")
context.report_asset_metadata("float_meta", 1.0, metadata_type="float")
context.report_asset_metadata("int_meta", 1, metadata_type="int")
context.report_asset_metadata("bool_meta", True, metadata_type="bool")
context.report_asset_metadata("dagster_run_meta", "foo", metadata_type="dagster_run")
context.report_asset_metadata("asset_meta", "bar/baz", metadata_type="asset")
context.report_asset_metadata("null_meta", None, metadata_type="null")
context.report_asset_materialization(
metadata={
"infer_meta": "bar",
"text_meta": {"raw_value": "bar", "type": "text"},
"url_meta": {"raw_value": "http://bar.com", "type": "url"},
"path_meta": {"raw_value": "/bar", "type": "path"},
"notebook_meta": {"raw_value": "/bar.ipynb", "type": "notebook"},
"json_meta": {"raw_value": ["bar"], "type": "json"},
"md_meta": {"raw_value": "bar", "type": "md"},
"float_meta": {"raw_value": 1.0, "type": "float"},
"int_meta": {"raw_value": 1, "type": "int"},
"bool_meta": {"raw_value": True, "type": "bool"},
"dagster_run_meta": {"raw_value": "foo", "type": "dagster_run"},
"asset_meta": {"raw_value": "bar/baz", "type": "asset"},
"null_meta": {"raw_value": None, "type": "null"},
}
)

@asset
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
Expand All @@ -212,8 +218,8 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
mat = instance.get_latest_materialization_event(foo.key)
assert mat and mat.asset_materialization
metadata = mat.asset_materialization.metadata
assert isinstance(metadata["untyped_meta"], TextMetadataValue)
assert metadata["untyped_meta"].value == "bar"
# assert isinstance(metadata["infer_meta"], TextMetadataValue)
# assert metadata["infer_meta"].value == "bar"
assert isinstance(metadata["text_meta"], TextMetadataValue)
assert metadata["text_meta"].value == "bar"
assert isinstance(metadata["url_meta"], UrlMetadataValue)
Expand Down Expand Up @@ -286,8 +292,10 @@ def script_fn():
init_dagster_ext()
context = ExtContext.get()
context.log("hello world")
context.report_asset_metadata("bar", context.get_extra("bar"))
context.report_asset_data_version("alpha")
context.report_asset_materialization(
metadata={"bar": context.get_extra("bar")},
data_version="alpha",
)

with temp_script(script_fn) as script_path:
cmd = ["python", script_path]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
store_asset_value("number_sum", storage_root, value)

context.log(f"{context.asset_key}: {number_x} + {number_y} = {value}")
context.report_asset_data_version(compute_data_version(value))
context.report_asset_materialization(data_version=compute_data_version(value))
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
store_asset_value("number_x", storage_root, value)

context.log(f"{context.asset_key}: {2} * {multiplier} = {value}")
context.report_asset_data_version(compute_data_version(value))
context.report_asset_materialization(data_version=compute_data_version(value))
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,7 @@
store_asset_value("number_y", storage_root, value)

context.log(f"{context.asset_key}: {value} read from $NUMBER_Y environment variable.")
context.report_asset_metadata("is_even", value % 2 == 0)
context.report_asset_data_version(compute_data_version(value))
context.report_asset_materialization(
metadata={"is_even": value % 2 == 0},
data_version=compute_data_version(value),
)
54 changes: 25 additions & 29 deletions python_modules/dagster/dagster/_core/ext/context.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from dataclasses import dataclass
from typing import Any, Mapping, Optional, get_args
from typing import Any, Mapping, Optional

from dagster_ext import (
DAGSTER_EXT_ENV_KEYS,
EXT_METADATA_TYPE_INFER,
IS_DAGSTER_EXT_PROCESS_ENV_VAR,
ExtContextData,
ExtDataProvenance,
Expand All @@ -28,30 +29,8 @@ class ExtMessageHandler:
def __init__(self, context: OpExecutionContext) -> None:
self._context = context

# Type ignores because we currently validate in individual handlers
def handle_message(self, message: ExtMessage) -> None:
if message["method"] == "report_asset_metadata":
self._handle_report_asset_metadata(**message["params"]) # type: ignore
elif message["method"] == "report_asset_data_version":
self._handle_report_asset_data_version(**message["params"]) # type: ignore
elif message["method"] == "log":
self._handle_log(**message["params"]) # type: ignore

def _handle_report_asset_metadata(
self, asset_key: str, label: str, value: Any, type: ExtMetadataType # noqa: A002
) -> None:
check.str_param(asset_key, "asset_key")
check.str_param(label, "label")
check.opt_literal_param(type, "type", get_args(ExtMetadataType))
key = AssetKey.from_user_string(asset_key)
output_name = self._context.output_for_asset_key(key)
metadata_value = self._resolve_metadata_value(value, type)
self._context.add_output_metadata({label: metadata_value}, output_name)

def _resolve_metadata_value(
self, value: Any, metadata_type: Optional[ExtMetadataType]
) -> MetadataValue:
if metadata_type is None:
def _resolve_metadata_value(self, value: Any, metadata_type: ExtMetadataType) -> MetadataValue:
if metadata_type == EXT_METADATA_TYPE_INFER:
return normalize_metadata_value(value)
elif metadata_type == "text":
return MetadataValue.text(value)
Expand Down Expand Up @@ -82,11 +61,28 @@ def _resolve_metadata_value(
else:
check.failed(f"Unexpected metadata type {metadata_type}")

def _handle_report_asset_data_version(self, asset_key: str, data_version: str) -> None:
# Parameters are validated inside the individual handlers, hence the type ignore.
def handle_message(self, message: ExtMessage) -> None:
    """Dispatch an incoming message to the handler matching its `method`.

    Messages whose method is neither `report_asset_materialization` nor `log`
    are ignored; their params are never touched.
    """
    method = message["method"]
    if method == "report_asset_materialization":
        handler = self._handle_report_asset_materialization
    elif method == "log":
        handler = self._handle_log
    else:
        return
    handler(**message["params"])  # type: ignore

def _handle_report_asset_materialization(
self, asset_key: str, metadata: Optional[Mapping[str, Any]], data_version: Optional[str]
) -> None:
check.str_param(asset_key, "asset_key")
check.str_param(data_version, "data_version")
key = AssetKey.from_user_string(asset_key)
self._context.set_data_version(key, DataVersion(data_version))
check.opt_str_param(data_version, "data_version")
metadata = check.opt_mapping_param(metadata, "metadata", key_type=str)
resolved_asset_key = AssetKey.from_user_string(asset_key)
resolved_metadata = {
k: self._resolve_metadata_value(v["raw_value"], v["type"]) for k, v in metadata.items()
}
if data_version is not None:
self._context.set_data_version(resolved_asset_key, DataVersion(data_version))
if resolved_metadata:
output_name = self._context.output_for_asset_key(resolved_asset_key)
self._context.add_output_metadata(resolved_metadata, output_name)

def _handle_log(self, message: str, level: str = "info") -> None:
check.str_param(message, "message")
Expand Down
Loading

0 comments on commit c402db0

Please sign in to comment.