diff --git a/integration_tests/test_suites/k8s-test-suite/tests/test_external_asset.py b/integration_tests/test_suites/k8s-test-suite/tests/test_external_asset.py index b72e2f5e16aa6..450ee9874dd46 100644 --- a/integration_tests/test_suites/k8s-test-suite/tests/test_external_asset.py +++ b/integration_tests/test_suites/k8s-test-suite/tests/test_external_asset.py @@ -26,7 +26,7 @@ def number_y( context: AssetExecutionContext, ext_k8s_pod: ExtK8sPod, ): - return ext_k8s_pod.run( + yield from ext_k8s_pod.run( context=context, namespace=namespace, image=docker_image, @@ -138,7 +138,7 @@ def number_y( ], ) - return ext_k8s_pod.run( + yield from ext_k8s_pod.run( context=context, namespace=namespace, extras={ @@ -197,7 +197,7 @@ def number_y_job(context: AssetExecutionContext): k8s_job_name=job_name, ) reader.consume_pod_logs(core_api, job_name, namespace) - return ext_context.get_materialize_results() + yield from ext_context.get_results() result = materialize( [number_y_job], diff --git a/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py b/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py index 714d0fc7c030b..c2b0b540904cf 100644 --- a/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py +++ b/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py @@ -9,11 +9,12 @@ import boto3 import pytest +from dagster._core.definitions.asset_spec import AssetSpec from dagster._core.definitions.data_version import ( DATA_VERSION_IS_USER_PROVIDED_TAG, DATA_VERSION_TAG, ) -from dagster._core.definitions.decorators.asset_decorator import asset +from dagster._core.definitions.decorators.asset_decorator import asset, multi_asset from dagster._core.definitions.events import AssetKey from dagster._core.definitions.materialize import materialize from dagster._core.definitions.metadata import ( @@ -30,8 +31,8 @@ TextMetadataValue, UrlMetadataValue, ) -from dagster._core.errors import DagsterExternalExecutionError -from dagster._core.execution.context.compute import AssetExecutionContext +from dagster._core.errors import DagsterExternalExecutionError, DagsterInvariantViolationError +from dagster._core.execution.context.compute import AssetExecutionContext, OpExecutionContext from dagster._core.execution.context.invocation import build_asset_context from dagster._core.ext.subprocess import ( ExtSubprocess, @@ -150,7 +151,7 @@ def test_ext_subprocess( def foo(context: AssetExecutionContext, ext: ExtSubprocess): extras = {"bar": "baz"} cmd = [_PYTHON_EXECUTABLE, external_script] - return ext.run( + yield from ext.run( cmd, context=context, extras=extras, @@ -163,11 +164,7 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess): resource = ExtSubprocess(context_injector=context_injector, message_reader=message_reader) with instance_for_test() as instance: - materialize( - [foo], - instance=instance, - resources={"ext": resource}, - ) + materialize([foo], instance=instance, resources={"ext": resource}) mat = instance.get_latest_materialization_event(foo.key) assert mat and mat.asset_materialization assert isinstance(mat.asset_materialization.metadata["bar"], MarkdownMetadataValue) @@ -180,6 +177,35 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess): assert re.search(r"dagster - INFO - [^\n]+ - hello world\n", captured.err, re.MULTILINE) +def test_ext_multi_asset(): + def script_fn(): + from dagster_ext import init_dagster_ext + + context = init_dagster_ext() + context.report_asset_materialization( + {"foo_meta": "ok"}, data_version="alpha", asset_key="foo" + ) + context.report_asset_materialization(data_version="alpha", asset_key="bar") + + @multi_asset(specs=[AssetSpec("foo"), AssetSpec("bar")]) + def foo_bar(context: AssetExecutionContext, ext: ExtSubprocess): + with temp_script(script_fn) as script_path: + cmd = [_PYTHON_EXECUTABLE, script_path] + yield from ext.run(cmd, context=context) + + with instance_for_test() as instance: + materialize([foo_bar], instance=instance, resources={"ext": ExtSubprocess()}) + foo_mat = instance.get_latest_materialization_event(AssetKey(["foo"])) + assert foo_mat and foo_mat.asset_materialization + assert foo_mat.asset_materialization.metadata["foo_meta"].value == "ok" + assert foo_mat.asset_materialization.tags + assert foo_mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha" + bar_mat = instance.get_latest_materialization_event(AssetKey(["foo"])) + assert bar_mat and bar_mat.asset_materialization + assert bar_mat.asset_materialization.tags + assert bar_mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha" + + def test_ext_typed_metadata(): def script_fn(): from dagster_ext import init_dagster_ext @@ -207,7 +233,7 @@ def script_fn(): def foo(context: AssetExecutionContext, ext: ExtSubprocess): with temp_script(script_fn) as script_path: cmd = [_PYTHON_EXECUTABLE, script_path] - return ext.run(cmd, context=context) + yield from ext.run(cmd, context=context) with instance_for_test() as instance: materialize( @@ -254,7 +280,7 @@ def script_fn(): def foo(context: AssetExecutionContext, ext: ExtSubprocess): with temp_script(script_fn) as script_path: cmd = [_PYTHON_EXECUTABLE, script_path] - ext.run(cmd, context=context) + yield from ext.run(cmd, context=context) with pytest.raises(DagsterExternalExecutionError): materialize([foo], resources={"ext": ExtSubprocess()}) @@ -321,9 +347,7 @@ def subproc_run(context: AssetExecutionContext): extras=extras, ) as ext_context: subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False) - _ext_context = ext_context - mat_results = _ext_context.get_materialize_results() - return mat_results[0] if len(mat_results) == 1 else mat_results + yield from ext_context.get_results() with instance_for_test() as instance: materialize( @@ -338,28 +362,26 @@ def subproc_run(context: AssetExecutionContext): assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG] -def test_ext_no_client_premature_get_results(external_script): - @asset - def subproc_run(context: AssetExecutionContext): - extras = {"bar": "baz"} - cmd = [_PYTHON_EXECUTABLE, external_script] +def test_ext_no_client_no_yield(): + def script_fn(): + pass - with ext_protocol( - context, - ExtTempFileContextInjector(), - ExtTempFileMessageReader(), - extras=extras, - ) as ext_context: - subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False) - return ext_context.get_materialize_results() + @asset + def foo(context: OpExecutionContext): + with temp_script(script_fn) as external_script: + with ext_protocol( + context, + ExtTempFileContextInjector(), + ExtTempFileMessageReader(), + ) as ext_context: + cmd = [_PYTHON_EXECUTABLE, external_script] + subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False) with pytest.raises( - DagsterExternalExecutionError, + DagsterInvariantViolationError, match=( - "`get_materialize_results` must be called after the `ext_protocol` context manager has" - " exited." + r"did not yield or return expected outputs.*Did you forget to `yield from" + r" ext_context.get_results\(\)`?" ), ): - materialize( - [subproc_run], - ) + materialize([foo]) diff --git a/python_modules/dagster/dagster/_core/ext/client.py b/python_modules/dagster/dagster/_core/ext/client.py index ab3c9fdd7e97d..47270207bb62e 100644 --- a/python_modules/dagster/dagster/_core/ext/client.py +++ b/python_modules/dagster/dagster/_core/ext/client.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import TYPE_CHECKING, Iterator, Optional, Tuple, Union +from typing import TYPE_CHECKING, Iterator, Optional from dagster_ext import ( ExtContextData, @@ -23,7 +23,7 @@ def run( *, context: OpExecutionContext, extras: Optional[ExtExtras] = None, - ) -> Union["MaterializeResult", Tuple["MaterializeResult", ...]]: ... + ) -> Iterator["MaterializeResult"]: ... class ExtContextInjector(ABC): diff --git a/python_modules/dagster/dagster/_core/ext/context.py b/python_modules/dagster/dagster/_core/ext/context.py index 7fb81fca83bba..4b3731a393a40 100644 --- a/python_modules/dagster/dagster/_core/ext/context.py +++ b/python_modules/dagster/dagster/_core/ext/context.py @@ -1,8 +1,7 @@ +from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Mapping, Optional -from typing import Any, Mapping, Optional, get_args +from queue import Queue from typing import Any, Dict, Iterator, Mapping, Optional, Set -from typing import Any, Dict, Mapping, Optional, Tuple, get_args from dagster_ext import ( DAGSTER_EXT_ENV_KEYS, @@ -25,14 +24,31 @@ from dagster._core.definitions.partition_key_range import PartitionKeyRange from dagster._core.definitions.result import MaterializeResult from dagster._core.definitions.time_window_partitions import TimeWindow -from dagster._core.errors import DagsterExternalExecutionError from dagster._core.execution.context.compute import OpExecutionContext from dagster._core.execution.context.invocation import BoundOpExecutionContext +from dagster._core.ext.client import ExtMessageReader class ExtMessageHandler: def __init__(self, context: OpExecutionContext) -> None: self._context = context + # Queue is thread-safe + self._result_queue: Queue[MaterializeResult] = Queue() + # Only read by the main thread after all messages are handled, so no need for a lock + self._unmaterialized_assets: Set[AssetKey] = set(context.selected_asset_keys) + self._metadata: Dict[AssetKey, Dict[str, MetadataValue]] = {} + self._data_versions: Dict[AssetKey, DataVersion] = {} + + @contextmanager + def handle_messages(self, message_reader: ExtMessageReader) -> Iterator[ExtParams]: + with message_reader.read_messages(self) as params: + yield params + for key in self._unmaterialized_assets: + self._result_queue.put(MaterializeResult(asset_key=key)) + + def clear_result_queue(self) -> Iterator[MaterializeResult]: + while not self._result_queue.empty(): + yield self._result_queue.get() def _resolve_metadata_value(self, value: Any, metadata_type: ExtMetadataType) -> MetadataValue: if metadata_type == EXT_METADATA_TYPE_INFER: @@ -83,11 +99,14 @@ def _handle_report_asset_materialization( resolved_metadata = { k: self._resolve_metadata_value(v["raw_value"], v["type"]) for k, v in metadata.items() } - if data_version is not None: - self._context.set_data_version(resolved_asset_key, DataVersion(data_version)) - if resolved_metadata: - output_name = self._context.output_for_asset_key(resolved_asset_key) - self._context.add_output_metadata(resolved_metadata, output_name) + resolved_data_version = None if data_version is None else DataVersion(data_version) + result = MaterializeResult( + asset_key=resolved_asset_key, + metadata=resolved_metadata, + data_version=resolved_data_version, + ) + self._result_queue.put(result) + self._unmaterialized_assets.remove(resolved_asset_key) def _handle_log(self, message: str, level: str = "info") -> None: check.str_param(message, "message") @@ -109,7 +128,6 @@ class ExtOrchestrationContext: message_handler: ExtMessageHandler context_injector_params: ExtParams message_reader_params: ExtParams - is_task_finished: bool = False def get_external_process_env_vars(self): return { @@ -120,23 +138,8 @@ def get_external_process_env_vars(self): ), } - def get_materialize_results(self) -> Tuple[MaterializeResult, ...]: - if not self.is_task_finished: - raise DagsterExternalExecutionError( - "`get_materialize_results` must be called after the `ext_protocol` context manager" - " has exited." - ) - return tuple( - self._materialize_result_for_asset(AssetKey.from_user_string(key)) - for key in self.context_data["asset_keys"] or [] - ) - - def _materialize_result_for_asset(self, asset_key: AssetKey): - return MaterializeResult( - asset_key=asset_key, - metadata=self.message_handler.metadata.get(asset_key), - data_version=self.message_handler.data_versions.get(asset_key), - ) + def get_results(self) -> Iterator[MaterializeResult]: + yield from self.message_handler.clear_result_queue() def build_external_execution_context_data( diff --git a/python_modules/dagster/dagster/_core/ext/subprocess.py b/python_modules/dagster/dagster/_core/ext/subprocess.py index 41b6e8d526009..30a4717e5d286 100644 --- a/python_modules/dagster/dagster/_core/ext/subprocess.py +++ b/python_modules/dagster/dagster/_core/ext/subprocess.py @@ -1,5 +1,5 @@ from subprocess import Popen -from typing import Mapping, Optional, Sequence, Tuple, Union +from typing import Iterator, Mapping, Optional, Sequence, Union from dagster_ext import ExtExtras @@ -67,7 +67,7 @@ def run( extras: Optional[ExtExtras] = None, env: Optional[Mapping[str, str]] = None, cwd: Optional[str] = None, - ) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]: + ) -> Iterator[MaterializeResult]: with ext_protocol( context=context, context_injector=self.context_injector, @@ -83,14 +83,16 @@ def run( **(env or {}), }, ) - process.wait() + while True: + yield from ext_context.get_results() + if process.poll() is not None: + break if process.returncode != 0: raise DagsterExternalExecutionError( f"External execution process failed with code {process.returncode}" ) - mat_results = ext_context.get_materialize_results() - return mat_results[0] if len(mat_results) == 1 else mat_results + yield from ext_context.get_results() ExtSubprocess = ResourceParam[_ExtSubprocess] diff --git a/python_modules/dagster/dagster/_core/ext/utils.py b/python_modules/dagster/dagster/_core/ext/utils.py index 48b0efd8b69a7..81549b4e12e5a 100644 --- a/python_modules/dagster/dagster/_core/ext/utils.py +++ b/python_modules/dagster/dagster/_core/ext/utils.py @@ -185,6 +185,12 @@ def extract_message_or_forward_to_stdout(handler: "ExtMessageHandler", log_line: sys.stdout.writelines((log_line, "\n")) +_FAIL_TO_YIELD_ERROR_MESSAGE = ( + "Did you forget to `yield from ext_context.get_results()`? This should be called once after the" + " `ext_protocol` block has exited to yield any remaining buffered results." +) + + @contextmanager def ext_protocol( context: OpExecutionContext, @@ -195,18 +201,16 @@ def ext_protocol( """Enter the context managed context injector and message reader that power the EXT protocol and receive the environment variables that need to be provided to the external process. """ + # This will trigger an error if expected outputs are not yielded + context.set_require_typed_event_stream(error_message=_FAIL_TO_YIELD_ERROR_MESSAGE) context_data = build_external_execution_context_data(context, extras) message_handler = ExtMessageHandler(context) with context_injector.inject_context( - context_data, - ) as ci_params, message_reader.read_messages( - message_handler, - ) as mr_params: - ext_context = ExtOrchestrationContext( + context_data + ) as ci_params, message_handler.handle_messages(message_reader) as mr_params: + yield ExtOrchestrationContext( context_data=context_data, message_handler=message_handler, context_injector_params=ci_params, message_reader_params=mr_params, ) - yield ext_context - ext_context.is_task_finished = True diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py index b7f522a1e463d..deda20b3c665d 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py @@ -5,7 +5,7 @@ import string import time from contextlib import contextmanager -from typing import Iterator, Mapping, Optional, Tuple, Union +from typing import Iterator, Mapping, Optional import dagster._check as check from dagster._core.definitions.resource_annotation import ResourceParam @@ -67,7 +67,7 @@ def run( context: OpExecutionContext, extras: Optional[ExtExtras] = None, submit_args: Optional[Mapping[str, str]] = None, - ) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]: + ) -> Iterator[MaterializeResult]: """Run a Databricks job with the EXT protocol. Args: @@ -116,8 +116,9 @@ def run( raise DagsterExternalExecutionError( f"Error running Databricks job: {run.state.state_message}" ) + yield from ext_context.get_results() time.sleep(5) - return ext_context.get_materialize_results() + yield from ext_context.get_results() ExtDatabricks = ResourceParam[_ExtDatabricks] diff --git a/python_modules/libraries/dagster-docker/dagster_docker/ext.py b/python_modules/libraries/dagster-docker/dagster_docker/ext.py index aa379722416a2..b7a88238f3a04 100644 --- a/python_modules/libraries/dagster-docker/dagster_docker/ext.py +++ b/python_modules/libraries/dagster-docker/dagster_docker/ext.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Iterator, Mapping, Optional, Sequence, Union import docker from dagster import ( @@ -95,7 +95,7 @@ def run( registry: Optional[Mapping[str, str]] = None, container_kwargs: Optional[Mapping[str, Any]] = None, extras: Optional[ExtExtras] = None, - ) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]: + ) -> Iterator[MaterializeResult]: """Create a docker container and run it to completion, enriched with the ext protocol. Args: @@ -163,7 +163,7 @@ def run( raise DagsterExtError(f"Container exited with non-zero status code: {result}") finally: container.stop() - return ext_context.get_materialize_results() + return ext_context.get_results() def _create_container( self, diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py b/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py index ff4f99ee048c2..93bdcf42df06b 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py @@ -1,7 +1,7 @@ import random import string from contextlib import contextmanager -from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Iterator, Mapping, Optional, Sequence, Union import kubernetes from dagster import ( @@ -124,7 +124,7 @@ def run( base_pod_meta: Optional[Mapping[str, Any]] = None, base_pod_spec: Optional[Mapping[str, Any]] = None, extras: Optional[ExtExtras] = None, - ) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]: + ) -> Iterator[MaterializeResult]: """Publish a kubernetes pod and wait for it to complete, enriched with the ext protocol. Args: @@ -197,8 +197,7 @@ def run( ) finally: client.core_api.delete_namespaced_pod(pod_name, namespace) - mats = ext_context.get_materialize_results() - return mats[0] if len(mats) == 1 else mats + return ext_context.get_results() def build_pod_body(