Skip to content

Commit

Permalink
[lazy-defs] Restrict reconstruction metadata to string values (#24747)
Browse files Browse the repository at this point in the history
## Summary & Motivation

Make reconstruction metadata accept only string values. Add an error
message explaining that values must be pre-serialized when a non-string value is passed.

## How I Tested These Changes

New unit tests.

## Changelog

NOCHANGELOG

- [ ] `NEW` _(added new feature or capability)_
- [ ] `BUGFIX` _(fixed a bug)_
- [ ] `DOCS` _(added or updated documentation)_
  • Loading branch information
smackesey authored Sep 27, 2024
1 parent 41fbdbd commit 7998da4
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -723,20 +723,30 @@ def get_all_asset_specs(self) -> Sequence[AssetSpec]:
return [asset_node.to_asset_spec() for asset_node in asset_graph.asset_nodes]

@experimental
def with_reconstruction_metadata(self, state_metadata: Mapping[str, Any]) -> Self:
"""Add metadata to the Definitions object. This is typically used to cache data
def with_reconstruction_metadata(self, reconstruction_metadata: Mapping[str, str]) -> Self:
"""Add reconstruction metadata to the Definitions object. This is typically used to cache data
loaded from some external API that is computed during initialization of a code server.
The cached data is then made available on the DefinitionsLoadContext during
reconstruction of the same code location context (such as a run worker), allowing use of the
cached data to avoid additional external API queries. Values must be JSON-serializable.
cached data to avoid additional external API queries. Values are expected to be serialized
in advance and must be strings.
"""
state_metadata = {
k: CodeLocationReconstructionMetadataValue(v) for k, v in state_metadata.items()
check.mapping_param(reconstruction_metadata, "reconstruction_metadata", key_type=str)
for k, v in reconstruction_metadata.items():
if not isinstance(v, str):
raise DagsterInvariantViolationError(
f"Reconstruction metadata values must be strings. State-representing values are"
f" expected to be serialized before being passed as reconstruction metadata."
f" Got for key {k}:\n\n{v}"
)
normalized_metadata = {
k: CodeLocationReconstructionMetadataValue(v)
for k, v in reconstruction_metadata.items()
}
return copy(
self,
metadata={
**(self.metadata or {}),
**state_metadata,
**normalized_metadata,
},
)
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
)
from dagster._core.errors import DagsterInvalidMetadata
from dagster._serdes import whitelist_for_serdes
from dagster._serdes.errors import SerializationError
from dagster._serdes.serdes import PackableValue, serialize_value
from dagster._serdes.serdes import PackableValue

T_Packable = TypeVar("T_Packable", bound=PackableValue, default=PackableValue, covariant=True)
from dagster._serdes import pack_value
Expand Down Expand Up @@ -492,13 +491,13 @@ def null() -> "NullMetadataValue":

# not public because rest of code location metadata API is not public
@staticmethod
def code_location_reconstruction(data: Any) -> "CodeLocationReconstructionMetadataValue":
def code_location_reconstruction(data: str) -> "CodeLocationReconstructionMetadataValue":
"""Static constructor for a metadata value wrapping arbitrary code location data useful during reconstruction as
:py:class:`CodeLocationReconstructionMetadataValue`. Can be used as the value type for the `metadata`
parameter for supported events.
Args:
data (Any): The code location state for a metadata entry.
data (str): The serialized code location state for a metadata entry.
"""
return CodeLocationReconstructionMetadataValue(data)

Expand Down Expand Up @@ -1034,26 +1033,24 @@ def value(self) -> None:

@whitelist_for_serdes
class CodeLocationReconstructionMetadataValue(
NamedTuple("_CodeLocationReconstructionMetadataValue", [("data", PublicAttr[Any])]),
MetadataValue[Any],
NamedTuple("_CodeLocationReconstructionMetadataValue", [("data", PublicAttr[str])]),
MetadataValue[str],
):
"""Representation of some state data used to define the Definitions in a code location.
"""Representation of some state data used to define the Definitions in a code location. Users
are expected to serialize data before passing it to this class.
Args:
data (Any): Arbitrary JSON-serializable data used to define the Definitions in a
data (str): A string representing data used to define the Definitions in a
code location.
"""

def __new__(cls, data: Any):
try:
serialize_value(data)
except SerializationError:
raise DagsterInvalidMetadata("Value is not JSON-serializable.")

return super(CodeLocationReconstructionMetadataValue, cls).__new__(cls, data)
def __new__(cls, data: str):
return super(CodeLocationReconstructionMetadataValue, cls).__new__(
cls, check.str_param(data, "data")
)

@public
@property
def value(self) -> Any:
"""None: The wrapped code location state data."""
def value(self) -> str:
"""str: The wrapped code location state data."""
return self.data
12 changes: 5 additions & 7 deletions python_modules/dagster/dagster_tests/core_tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
UrlMetadataValue,
op,
)
from dagster._check.functions import CheckError
from dagster._core.definitions.metadata.metadata_value import (
CodeLocationReconstructionMetadataValue,
)
from dagster._core.errors import DagsterInvalidMetadata
from dagster._serdes.serdes import deserialize_value, serialize_value


Expand Down Expand Up @@ -113,13 +113,11 @@ def test_json_metadata_value():


def test_code_location_reconstruction_metadata_value():
assert CodeLocationReconstructionMetadataValue({"a": "b"}).data == {"a": "b"}
assert CodeLocationReconstructionMetadataValue({"a": "b"}).value == {"a": "b"}
assert CodeLocationReconstructionMetadataValue("abc").data == "abc"
assert CodeLocationReconstructionMetadataValue(1).data == 1
assert CodeLocationReconstructionMetadataValue("foo").data == "foo"
assert CodeLocationReconstructionMetadataValue("foo").value == "foo"

with pytest.raises(DagsterInvalidMetadata, match="not JSON-serializable"):
CodeLocationReconstructionMetadataValue(object())
with pytest.raises(CheckError, match="not a str"):
CodeLocationReconstructionMetadataValue({"foo": "bar"})


def test_serdes_json_metadata():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -25,6 +26,7 @@
RepositoryLoadData,
)
from dagster._core.definitions.unresolved_asset_job_definition import define_asset_job
from dagster._core.errors import DagsterInvariantViolationError
from dagster._core.execution.api import execute_job
from dagster._core.instance_for_test import instance_for_test
from dagster._record import record
Expand All @@ -50,14 +52,16 @@ def _get_foo_integration_defs(context: DefinitionsLoadContext, workspace_id: str
context.load_type == DefinitionsLoadType.RECONSTRUCTION
and cache_key in context.reconstruction_metadata
):
payload = context.reconstruction_metadata[cache_key]
serialized_payload = context.reconstruction_metadata[cache_key]
payload = json.loads(serialized_payload)
else:
payload = fetch_foo_integration_asset_info(workspace_id)
serialized_payload = json.dumps(payload)
asset_specs = [AssetSpec(item["id"]) for item in payload]
assets = external_assets_from_specs(asset_specs)
return Definitions(
assets=assets,
).with_reconstruction_metadata({cache_key: payload})
).with_reconstruction_metadata({cache_key: serialized_payload})


@lazy_definitions
Expand Down Expand Up @@ -97,7 +101,7 @@ def test_reconstruction_metadata():
cacheable_asset_data={},
reconstruction_metadata={
f"{FOO_INTEGRATION_SOURCE_KEY}/{WORKSPACE_ID}": MetadataValue.code_location_reconstruction(
fetch_foo_integration_asset_info(WORKSPACE_ID)
json.dumps(fetch_foo_integration_asset_info(WORKSPACE_ID))
)
},
)
Expand All @@ -111,6 +115,13 @@ def test_reconstruction_metadata():
mock_fetch.assert_not_called()


def test_invalid_reconstruction_metadata():
with pytest.raises(
DagsterInvariantViolationError, match=r"Reconstruction metadata values must be strings"
):
Definitions().with_reconstruction_metadata({"foo": {"not": "a string"}})


def test_default_global_context():
instance = DefinitionsLoadContext.get()
DefinitionsLoadContext._instance = None # noqa: SLF001
Expand Down

0 comments on commit 7998da4

Please sign in to comment.