Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ext] Use MaterializeResult for ext_protocol #16624

Merged
merged 2 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def number_y(
context: AssetExecutionContext,
ext_k8s_pod: ExtK8sPod,
):
ext_k8s_pod.run(
yield from ext_k8s_pod.run(
context=context,
namespace=namespace,
image=docker_image,
Expand Down Expand Up @@ -138,7 +138,7 @@ def number_y(
],
)

ext_k8s_pod.run(
yield from ext_k8s_pod.run(
context=context,
namespace=namespace,
extras={
Expand Down Expand Up @@ -197,6 +197,7 @@ def number_y_job(context: AssetExecutionContext):
k8s_job_name=job_name,
)
reader.consume_pod_logs(core_api, job_name, namespace)
yield from ext_context.get_results()

result = materialize(
[number_y_job],
Expand Down
15 changes: 7 additions & 8 deletions python_modules/dagster-ext/dagster_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
Mapping,
Optional,
Sequence,
Set,
TextIO,
Type,
TypedDict,
Expand Down Expand Up @@ -165,7 +164,6 @@ def _resolve_optionally_passed_asset_key(
data: ExtContextData,
asset_key: Optional[str],
method: str,
already_materialized_assets: Set[str],
) -> str:
asset_keys = _assert_defined_asset_property(data["asset_keys"], method)
asset_key = _assert_opt_param_type(asset_key, str, method, "asset_key")
Expand All @@ -180,11 +178,6 @@ def _resolve_optionally_passed_asset_key(
" targets multiple assets."
)
asset_key = asset_keys[0]
if asset_key in already_materialized_assets:
raise DagsterExtError(
f"Calling `{method}` with asset key `{asset_key}` is undefined. Asset has already been"
" materialized, so no additional data can be reported for it."
)
return asset_key


Expand Down Expand Up @@ -784,8 +777,14 @@ def report_asset_materialization(
asset_key: Optional[str] = None,
):
asset_key = _resolve_optionally_passed_asset_key(
self._data, asset_key, "report_asset_materialization", self._materialized_assets
self._data, asset_key, "report_asset_materialization"
)
if asset_key in self._materialized_assets:
raise DagsterExtError(
f"Calling `report_asset_materialization` with asset key `{asset_key}` is undefined."
" Asset has already been materialized, so no additional data can be reported"
" for it."
)
metadata = (
_normalize_param_metadata(metadata, "report_asset_materialization", "metadata")
if metadata
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@

import boto3
import pytest
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.data_version import (
DATA_VERSION_IS_USER_PROVIDED_TAG,
DATA_VERSION_TAG,
)
from dagster._core.definitions.decorators.asset_decorator import asset
from dagster._core.definitions.decorators.asset_decorator import asset, multi_asset
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.materialize import materialize
from dagster._core.definitions.metadata import (
Expand All @@ -30,8 +31,8 @@
TextMetadataValue,
UrlMetadataValue,
)
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import AssetExecutionContext
from dagster._core.errors import DagsterExternalExecutionError, DagsterInvariantViolationError
from dagster._core.execution.context.compute import AssetExecutionContext, OpExecutionContext
from dagster._core.execution.context.invocation import build_asset_context
from dagster._core.ext.subprocess import (
ExtSubprocess,
Expand Down Expand Up @@ -150,7 +151,7 @@ def test_ext_subprocess(
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
ext.run(
yield from ext.run(
cmd,
context=context,
extras=extras,
Expand All @@ -163,11 +164,7 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
resource = ExtSubprocess(context_injector=context_injector, message_reader=message_reader)

with instance_for_test() as instance:
materialize(
[foo],
instance=instance,
resources={"ext": resource},
)
materialize([foo], instance=instance, resources={"ext": resource})
mat = instance.get_latest_materialization_event(foo.key)
assert mat and mat.asset_materialization
assert isinstance(mat.asset_materialization.metadata["bar"], MarkdownMetadataValue)
Expand All @@ -180,6 +177,35 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
assert re.search(r"dagster - INFO - [^\n]+ - hello world\n", captured.err, re.MULTILINE)


def test_ext_multi_asset():
def script_fn():
from dagster_ext import init_dagster_ext

context = init_dagster_ext()
context.report_asset_materialization(
{"foo_meta": "ok"}, data_version="alpha", asset_key="foo"
)
context.report_asset_materialization(data_version="alpha", asset_key="bar")

@multi_asset(specs=[AssetSpec("foo"), AssetSpec("bar")])
def foo_bar(context: AssetExecutionContext, ext: ExtSubprocess):
with temp_script(script_fn) as script_path:
cmd = [_PYTHON_EXECUTABLE, script_path]
yield from ext.run(cmd, context=context)

with instance_for_test() as instance:
materialize([foo_bar], instance=instance, resources={"ext": ExtSubprocess()})
foo_mat = instance.get_latest_materialization_event(AssetKey(["foo"]))
assert foo_mat and foo_mat.asset_materialization
assert foo_mat.asset_materialization.metadata["foo_meta"].value == "ok"
assert foo_mat.asset_materialization.tags
assert foo_mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"
bar_mat = instance.get_latest_materialization_event(AssetKey(["foo"]))
assert bar_mat and bar_mat.asset_materialization
assert bar_mat.asset_materialization.tags
assert bar_mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"


def test_ext_typed_metadata():
def script_fn():
from dagster_ext import init_dagster_ext
Expand Down Expand Up @@ -207,7 +233,7 @@ def script_fn():
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
with temp_script(script_fn) as script_path:
cmd = [_PYTHON_EXECUTABLE, script_path]
ext.run(cmd, context=context)
yield from ext.run(cmd, context=context)

with instance_for_test() as instance:
materialize(
Expand Down Expand Up @@ -254,7 +280,7 @@ def script_fn():
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
with temp_script(script_fn) as script_path:
cmd = [_PYTHON_EXECUTABLE, script_path]
ext.run(cmd, context=context)
yield from ext.run(cmd, context=context)

with pytest.raises(DagsterExternalExecutionError):
materialize([foo], resources={"ext": ExtSubprocess()})
Expand Down Expand Up @@ -321,6 +347,7 @@ def subproc_run(context: AssetExecutionContext):
extras=extras,
) as ext_context:
subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False)
yield from ext_context.get_results()

with instance_for_test() as instance:
materialize(
Expand All @@ -333,3 +360,28 @@ def subproc_run(context: AssetExecutionContext):
assert mat.asset_materialization.tags
assert mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"
assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG]


def test_ext_no_client_no_yield():
def script_fn():
pass

@asset
def foo(context: OpExecutionContext):
with temp_script(script_fn) as external_script:
with ext_protocol(
context,
ExtTempFileContextInjector(),
ExtTempFileMessageReader(),
) as ext_context:
cmd = [_PYTHON_EXECUTABLE, external_script]
subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False)

with pytest.raises(
DagsterInvariantViolationError,
match=(
r"did not yield or return expected outputs.*Did you forget to `yield from"
r" ext_context.get_results\(\)`?"
),
):
materialize([foo])
4 changes: 2 additions & 2 deletions python_modules/dagster/dagster/_core/ext/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from dagster._core.execution.context.compute import OpExecutionContext

if TYPE_CHECKING:
from .context import ExtMessageHandler
from .context import ExtMessageHandler, ExtResult


class ExtClient(ABC):
Expand All @@ -21,7 +21,7 @@ def run(
*,
context: OpExecutionContext,
extras: Optional[ExtExtras] = None,
) -> None: ...
) -> Iterator["ExtResult"]: ...


class ExtContextInjector(ABC):
Expand Down
59 changes: 49 additions & 10 deletions python_modules/dagster/dagster/_core/ext/context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Mapping, Optional
from queue import Queue
from typing import Any, Dict, Iterator, Mapping, Optional, Set

from dagster_ext import (
DAGSTER_EXT_ENV_KEYS,
Expand All @@ -10,24 +12,54 @@
ExtExtras,
ExtMessage,
ExtMetadataType,
ExtMetadataValue,
ExtParams,
ExtTimeWindow,
encode_env_var,
)
from typing_extensions import TypeAlias

import dagster._check as check
from dagster._core.definitions.data_version import DataProvenance, DataVersion
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.metadata import MetadataValue, normalize_metadata_value
from dagster._core.definitions.partition_key_range import PartitionKeyRange
from dagster._core.definitions.result import MaterializeResult
from dagster._core.definitions.time_window_partitions import TimeWindow
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.execution.context.invocation import BoundOpExecutionContext
from dagster._core.ext.client import ExtMessageReader

ExtResult: TypeAlias = MaterializeResult


class ExtMessageHandler:
def __init__(self, context: OpExecutionContext) -> None:
self._context = context
# Queue is thread-safe
self._result_queue: Queue[ExtResult] = Queue()
# Only read by the main thread after all messages are handled, so no need for a lock
self._unmaterialized_assets: Set[AssetKey] = set(context.selected_asset_keys)
self._metadata: Dict[AssetKey, Dict[str, MetadataValue]] = {}
self._data_versions: Dict[AssetKey, DataVersion] = {}

@contextmanager
def handle_messages(self, message_reader: ExtMessageReader) -> Iterator[ExtParams]:
with message_reader.read_messages(self) as params:
yield params
for key in self._unmaterialized_assets:
self._result_queue.put(MaterializeResult(asset_key=key))

def clear_result_queue(self) -> Iterator[ExtResult]:
while not self._result_queue.empty():
yield self._result_queue.get()

def _resolve_metadata(
self, metadata: Mapping[str, ExtMetadataValue]
) -> Mapping[str, MetadataValue]:
return {
k: self._resolve_metadata_value(v["raw_value"], v["type"]) for k, v in metadata.items()
}

def _resolve_metadata_value(self, value: Any, metadata_type: ExtMetadataType) -> MetadataValue:
if metadata_type == EXT_METADATA_TYPE_INFER:
Expand Down Expand Up @@ -69,20 +101,24 @@ def handle_message(self, message: ExtMessage) -> None:
self._handle_log(**message["params"]) # type: ignore

def _handle_report_asset_materialization(
self, asset_key: str, metadata: Optional[Mapping[str, Any]], data_version: Optional[str]
self,
asset_key: str,
metadata: Optional[Mapping[str, ExtMetadataValue]],
data_version: Optional[str],
) -> None:
check.str_param(asset_key, "asset_key")
check.opt_str_param(data_version, "data_version")
metadata = check.opt_mapping_param(metadata, "metadata", key_type=str)
resolved_asset_key = AssetKey.from_user_string(asset_key)
resolved_metadata = {
k: self._resolve_metadata_value(v["raw_value"], v["type"]) for k, v in metadata.items()
}
if data_version is not None:
self._context.set_data_version(resolved_asset_key, DataVersion(data_version))
if resolved_metadata:
output_name = self._context.output_for_asset_key(resolved_asset_key)
self._context.add_output_metadata(resolved_metadata, output_name)
resolved_metadata = self._resolve_metadata(metadata)
resolved_data_version = None if data_version is None else DataVersion(data_version)
result = MaterializeResult(
asset_key=resolved_asset_key,
metadata=resolved_metadata,
data_version=resolved_data_version,
)
self._result_queue.put(result)
self._unmaterialized_assets.remove(resolved_asset_key)

def _handle_log(self, message: str, level: str = "info") -> None:
check.str_param(message, "message")
Expand Down Expand Up @@ -114,6 +150,9 @@ def get_external_process_env_vars(self):
),
}

def get_results(self) -> Iterator[ExtResult]:
yield from self.message_handler.clear_result_queue()


def build_external_execution_context_data(
context: OpExecutionContext,
Expand Down
9 changes: 6 additions & 3 deletions python_modules/dagster/dagster/_core/ext/subprocess.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from subprocess import Popen
from typing import Mapping, Optional, Sequence, Union
from typing import Iterator, Mapping, Optional, Sequence, Union

from dagster_ext import ExtExtras

Expand All @@ -12,6 +12,7 @@
ExtContextInjector,
ExtMessageReader,
)
from dagster._core.ext.context import ExtResult
from dagster._core.ext.utils import (
ExtTempFileContextInjector,
ExtTempFileMessageReader,
Expand Down Expand Up @@ -66,7 +67,7 @@ def run(
extras: Optional[ExtExtras] = None,
env: Optional[Mapping[str, str]] = None,
cwd: Optional[str] = None,
) -> None:
) -> Iterator[ExtResult]:
with ext_protocol(
context=context,
context_injector=self.context_injector,
Expand All @@ -82,12 +83,14 @@ def run(
**(env or {}),
},
)
process.wait()
while process.poll() is None:
yield from ext_context.get_results()

if process.returncode != 0:
raise DagsterExternalExecutionError(
f"External execution process failed with code {process.returncode}"
)
yield from ext_context.get_results()


ExtSubprocess = ResourceParam[_ExtSubprocess]
Loading