diff --git a/python_modules/dagster/dagster/_core/definitions/op_invocation.py b/python_modules/dagster/dagster/_core/definitions/op_invocation.py index 115ccfcf55ae0..7a3254413eb69 100644 --- a/python_modules/dagster/dagster/_core/definitions/op_invocation.py +++ b/python_modules/dagster/dagster/_core/definitions/op_invocation.py @@ -32,7 +32,7 @@ from .result import MaterializeResult if TYPE_CHECKING: - from ..execution.context.invocation import RunlessOpExecutionContext + from ..execution.context.invocation import DirectOpExecutionContext from .assets import AssetsDefinition from .composition import PendingNodeInvocation from .decorators.op_decorator import DecoratedOpFunction @@ -109,7 +109,7 @@ def direct_invocation_result( ) -> Any: from dagster._config.pythonic_config import Config from dagster._core.execution.context.invocation import ( - RunlessOpExecutionContext, + DirectOpExecutionContext, build_op_context, ) @@ -149,12 +149,12 @@ def direct_invocation_result( " no context was provided when invoking." ) if len(args) > 0: - if args[0] is not None and not isinstance(args[0], RunlessOpExecutionContext): + if args[0] is not None and not isinstance(args[0], DirectOpExecutionContext): raise DagsterInvalidInvocationError( f"Decorated function '{compute_fn.name}' has context argument, " "but no context was provided when invoking." ) - context = cast(RunlessOpExecutionContext, args[0]) + context = cast(DirectOpExecutionContext, args[0]) # update args to omit context args = args[1:] else: # context argument is provided under kwargs @@ -165,14 +165,14 @@ def direct_invocation_result( f"'{context_param_name}', but no value for '{context_param_name}' was " f"found when invoking. Provided kwargs: {kwargs}" ) - context = cast(RunlessOpExecutionContext, kwargs[context_param_name]) + context = cast(DirectOpExecutionContext, kwargs[context_param_name]) # update kwargs to remove context kwargs = { kwarg: val for kwarg, val in kwargs.items() if not kwarg == context_param_name } # allow passing context, even if the function doesn't have an arg for it - elif len(args) > 0 and isinstance(args[0], RunlessOpExecutionContext): - context = cast(RunlessOpExecutionContext, args[0]) + elif len(args) > 0 and isinstance(args[0], DirectOpExecutionContext): + context = cast(DirectOpExecutionContext, args[0]) args = args[1:] resource_arg_mapping = {arg.name: arg.name for arg in compute_fn.get_resource_args()} @@ -230,7 +230,7 @@ def direct_invocation_result( def _resolve_inputs( - op_def: "OpDefinition", args, kwargs, context: "RunlessOpExecutionContext" + op_def: "OpDefinition", args, kwargs, context: "DirectOpExecutionContext" ) -> Mapping[str, Any]: from dagster._core.execution.plan.execute_step import do_type_check @@ -333,7 +333,7 @@ def _resolve_inputs( return input_dict -def _key_for_result(result: MaterializeResult, context: "RunlessOpExecutionContext") -> AssetKey: +def _key_for_result(result: MaterializeResult, context: "DirectOpExecutionContext") -> AssetKey: if not context.bound_properties.assets_def: raise DagsterInvariantViolationError( f"Op {context.bound_properties.alias} does not have an assets definition." @@ -352,7 +352,7 @@ def _key_for_result(result: MaterializeResult, context: "RunlessOpExecutionConte def _output_name_for_result_obj( event: MaterializeResult, - context: "RunlessOpExecutionContext", + context: "DirectOpExecutionContext", ): if not context.bound_properties.assets_def: raise DagsterInvariantViolationError( @@ -365,7 +365,7 @@ def _output_name_for_result_obj( def _handle_gen_event( event: T, op_def: "OpDefinition", - context: "RunlessOpExecutionContext", + context: "DirectOpExecutionContext", output_defs: Mapping[str, OutputDefinition], outputs_seen: Set[str], ) -> T: @@ -399,7 +399,7 @@ def _handle_gen_event( def _type_check_output_wrapper( - op_def: "OpDefinition", result: Any, context: "RunlessOpExecutionContext" + op_def: "OpDefinition", result: Any, context: "DirectOpExecutionContext" ) -> Any: """Type checks and returns the result of a op. @@ -493,7 +493,7 @@ def type_check_gen(gen): def _type_check_function_output( - op_def: "OpDefinition", result: T, context: "RunlessOpExecutionContext" + op_def: "OpDefinition", result: T, context: "DirectOpExecutionContext" ) -> T: from ..execution.plan.compute_generator import validate_and_coerce_op_result_to_iterator @@ -512,14 +512,14 @@ def _type_check_function_output( def _type_check_output( output_def: "OutputDefinition", output: Union[Output, DynamicOutput], - context: "RunlessOpExecutionContext", + context: "DirectOpExecutionContext", ) -> None: """Validates and performs core type check on a provided output. Args: output_def (OutputDefinition): The output definition to validate against. output (Any): The output to validate. - context (RunlessOpExecutionContext): Context containing resources to be used for type + context (DirectOpExecutionContext): Context containing resources to be used for type check. """ from ..execution.plan.execute_step import do_type_check diff --git a/python_modules/dagster/dagster/_core/execution/context/invocation.py b/python_modules/dagster/dagster/_core/execution/context/invocation.py index 2fe58e73cb7a7..bdd18be343880 100644 --- a/python_modules/dagster/dagster/_core/execution/context/invocation.py +++ b/python_modules/dagster/dagster/_core/execution/context/invocation.py @@ -106,124 +106,24 @@ def __new__( ) -class RunlessExecutionProperties: +class DirectExecutionProperties: """Maintains information about the invocation that is updated during execution time. This information needs to be available to the user once invocation is complete, so that they can assert on events and outputs. It needs to be cleared before the context is used for another invocation. + + This is not implemented as a NamedTuple because the various attributes will be mutated during + execution. """ def __init__(self): - self._events: List[UserEvent] = [] - self._seen_outputs = {} - self._output_metadata = {} - self._requires_typed_event_stream = False - self._typed_event_stream_error_message = None - - @property - def user_events(self): - return self._events - - @property - def seen_outputs(self): - return self._seen_outputs - - @property - def output_metadata(self): - return self._output_metadata - - @property - def requires_typed_event_stream(self) -> bool: - return self._requires_typed_event_stream - - @property - def typed_event_stream_error_message(self) -> Optional[str]: - return self._typed_event_stream_error_message - - def log_event(self, event: UserEvent) -> None: - check.inst_param( - event, - "event", - (AssetMaterialization, AssetObservation, ExpectationResult), - ) - self._events.append(event) - - def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None: - if mapping_key: - if output_name not in self.seen_outputs: - self._seen_outputs[output_name] = set() - cast(Set[str], self._seen_outputs[output_name]).add(mapping_key) - else: - self._seen_outputs[output_name] = "seen" - - def has_seen_output(self, output_name: str, mapping_key: Optional[str] = None) -> bool: - if mapping_key: - return ( - output_name in self.seen_outputs and mapping_key in self.seen_outputs[output_name] - ) - return output_name in self.seen_outputs - - def add_output_metadata( - self, - metadata: Mapping[str, Any], - op_def: OpDefinition, - output_name: Optional[str] = None, - mapping_key: Optional[str] = None, - ) -> None: - metadata = check.mapping_param(metadata, "metadata", key_type=str) - output_name = check.opt_str_param(output_name, "output_name") - mapping_key = check.opt_str_param(mapping_key, "mapping_key") + self.user_events: List[UserEvent] = [] + self.seen_outputs: Dict[str, Union[str, Set[str]]] = {} + self.output_metadata: Dict[str, Dict[str, Union[Any, Mapping[str, Any]]]] = {} + self.requires_typed_event_stream: bool = False + self.typed_event_stream_error_message: Optional[str] = None - if output_name is None and len(op_def.output_defs) == 1: - output_def = op_def.output_defs[0] - output_name = output_def.name - elif output_name is None: - raise DagsterInvariantViolationError( - "Attempted to log metadata without providing output_name, but multiple outputs" - " exist. Please provide an output_name to the invocation of" - " `context.add_output_metadata`." - ) - else: - output_def = op_def.output_def_named(output_name) - if self.has_seen_output(output_name, mapping_key): - output_desc = ( - f"output '{output_def.name}'" - if not mapping_key - else f"output '{output_def.name}' with mapping_key '{mapping_key}'" - ) - raise DagsterInvariantViolationError( - f"In {op_def.node_type_str} '{op_def.name}', attempted to log output" - f" metadata for {output_desc} which has already been yielded. Metadata must be" - " logged before the output is yielded." - ) - if output_def.is_dynamic and not mapping_key: - raise DagsterInvariantViolationError( - f"In {op_def.node_type_str} '{op_def.name}', attempted to log metadata" - f" for dynamic output '{output_def.name}' without providing a mapping key. When" - " logging metadata for a dynamic output, it is necessary to provide a mapping key." - ) - - output_name = output_def.name - if output_name in self.output_metadata: - if not mapping_key or mapping_key in self.output_metadata[output_name]: - raise DagsterInvariantViolationError( - f"In {op_def.node_type_str} '{op_def.name}', attempted to log" - f" metadata for output '{output_name}' more than once." - ) - if mapping_key: - if output_name not in self.output_metadata: - self._output_metadata[output_name] = {} - self._output_metadata[output_name][mapping_key] = metadata - - else: - self._output_metadata[output_name] = metadata - - def set_requires_typed_event_stream(self, *, error_message: Optional[str]) -> None: - self._requires_typed_event_stream = True - self._typed_event_stream_error_message = error_message - - -class RunlessOpExecutionContext(OpExecutionContext): +class DirectOpExecutionContext(OpExecutionContext): """The ``context`` object available as the first argument to an op's compute function when being invoked directly. Can also be used as a context manager. """ @@ -292,7 +192,7 @@ def __init__( # my_op(ctx) # ctx._execution_properties.output_metadata # information is retained after invocation # my_op(ctx) # ctx._execution_properties is cleared at the beginning of the next invocation - self._execution_properties = RunlessExecutionProperties() + self._execution_properties = DirectExecutionProperties() def __enter__(self): self._cm_scope_entered = True @@ -318,7 +218,7 @@ def bind( assets_def: Optional[AssetsDefinition], config_from_args: Optional[Mapping[str, Any]], resources_from_args: Optional[Mapping[str, Any]], - ) -> "RunlessOpExecutionContext": + ) -> "DirectOpExecutionContext": from dagster._core.definitions.resource_invocation import resolve_bound_config if self._bound_properties is not None: @@ -327,7 +227,7 @@ def bind( ) # reset execution_properties - self._execution_properties = RunlessExecutionProperties() + self._execution_properties = DirectExecutionProperties() # update the bound context with properties relevant to the execution of the op @@ -413,7 +313,7 @@ def is_bound(self) -> bool: return self._bound_properties is not None @property - def execution_properties(self) -> RunlessExecutionProperties: + def execution_properties(self) -> DirectExecutionProperties: return self._execution_properties @property @@ -637,16 +537,29 @@ def describe_op(self) -> str: def log_event(self, event: UserEvent) -> None: self._check_bound(fn_name="log_event", fn_type="method") - self._execution_properties.log_event(event) + check.inst_param( + event, + "event", + (AssetMaterialization, AssetObservation, ExpectationResult), + ) + self._execution_properties.user_events.append(event) def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None: self._check_bound(fn_name="observe_output", fn_type="method") - self._execution_properties.observe_output(output_name=output_name, mapping_key=mapping_key) + if mapping_key: + if output_name not in self._execution_properties.seen_outputs: + self._execution_properties.seen_outputs[output_name] = set() + cast(Set[str], self._execution_properties.seen_outputs[output_name]).add(mapping_key) + else: + self._execution_properties.seen_outputs[output_name] = "seen" def has_seen_output(self, output_name: str, mapping_key: Optional[str] = None) -> bool: - return self._execution_properties.has_seen_output( - output_name=output_name, mapping_key=mapping_key - ) + if mapping_key: + return ( + output_name in self._execution_properties.seen_outputs + and mapping_key in self._execution_properties.seen_outputs[output_name] + ) + return output_name in self._execution_properties.seen_outputs def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow: self._check_bound(fn_name="asset_partitions_time_window_for_output", fn_type="method") @@ -699,9 +612,57 @@ def add_metadata_two_outputs(context) -> Tuple[str, int]: """ self._check_bound(fn_name="add_output_metadata", fn_type="method") - self._execution_properties.add_output_metadata( - metadata=metadata, op_def=self.op_def, output_name=output_name, mapping_key=mapping_key - ) + metadata = check.mapping_param(metadata, "metadata", key_type=str) + output_name = check.opt_str_param(output_name, "output_name") + mapping_key = check.opt_str_param(mapping_key, "mapping_key") + + if output_name is None and len(self.op_def.output_defs) == 1: + output_def = self.op_def.output_defs[0] + output_name = output_def.name + elif output_name is None: + raise DagsterInvariantViolationError( + "Attempted to log metadata without providing output_name, but multiple outputs" + " exist. Please provide an output_name to the invocation of" + " `context.add_output_metadata`." + ) + else: + output_def = self.op_def.output_def_named(output_name) + + if self.has_seen_output(output_name, mapping_key): + output_desc = ( + f"output '{output_def.name}'" + if not mapping_key + else f"output '{output_def.name}' with mapping_key '{mapping_key}'" + ) + raise DagsterInvariantViolationError( + f"In {self.op_def.node_type_str} '{self.op_def.name}', attempted to log output" + f" metadata for {output_desc} which has already been yielded. Metadata must be" + " logged before the output is yielded." + ) + if output_def.is_dynamic and not mapping_key: + raise DagsterInvariantViolationError( + f"In {self.op_def.node_type_str} '{self.op_def.name}', attempted to log metadata" + f" for dynamic output '{output_def.name}' without providing a mapping key. When" + " logging metadata for a dynamic output, it is necessary to provide a mapping key." + ) + + output_name = output_def.name + if output_name in self._execution_properties.output_metadata: + if ( + not mapping_key + or mapping_key in self._execution_properties.output_metadata[output_name] + ): + raise DagsterInvariantViolationError( + f"In {self.op_def.node_type_str} '{self.op_def.name}', attempted to log" + f" metadata for output '{output_name}' more than once." + ) + if mapping_key: + if output_name not in self._execution_properties.output_metadata: + self._execution_properties.output_metadata[output_name] = {} + self._execution_properties.output_metadata[output_name][mapping_key] = metadata + + else: + self._execution_properties.output_metadata[output_name] = metadata # In bound mode no conversion is done on returned values and missing but expected outputs are not # allowed. @@ -717,7 +678,8 @@ def typed_event_stream_error_message(self) -> Optional[str]: def set_requires_typed_event_stream(self, *, error_message: Optional[str]) -> None: self._check_bound(fn_name="set_requires_typed_event_stream", fn_type="method") - self._execution_properties.set_requires_typed_event_stream(error_message=error_message) + self._execution_properties.requires_typed_event_stream = True + self._execution_properties.typed_event_stream_error_message = error_message def _validate_resource_requirements( @@ -740,7 +702,7 @@ def build_op_context( partition_key_range: Optional[PartitionKeyRange] = None, mapping_key: Optional[str] = None, _assets_def: Optional[AssetsDefinition] = None, -) -> RunlessOpExecutionContext: +) -> DirectOpExecutionContext: """Builds op execution context from provided parameters. ``build_op_context`` can be used as either a function or context manager. If there is a @@ -788,7 +750,7 @@ def build_op_context( ) op_config = op_config if op_config else config - return RunlessOpExecutionContext( + return DirectOpExecutionContext( resources_dict=check.opt_mapping_param(resources, "resources", key_type=str), resources_config=check.opt_mapping_param( resources_config, "resources_config", key_type=str diff --git a/python_modules/dagster/dagster/_core/pipes/context.py b/python_modules/dagster/dagster/_core/pipes/context.py index 4bd0ca2a2745a..ad3468bf953f7 100644 --- a/python_modules/dagster/dagster/_core/pipes/context.py +++ b/python_modules/dagster/dagster/_core/pipes/context.py @@ -39,7 +39,7 @@ from dagster._core.errors import DagsterPipesExecutionError from dagster._core.events import EngineEventData from dagster._core.execution.context.compute import OpExecutionContext -from dagster._core.execution.context.invocation import RunlessOpExecutionContext +from dagster._core.execution.context.invocation import DirectOpExecutionContext from dagster._utils.error import ( ExceptionInfo, SerializableErrorInfo, @@ -400,8 +400,8 @@ def build_external_execution_context_data( _convert_time_window(partition_time_window) if partition_time_window else None ), run_id=context.run_id, - job_name=None if isinstance(context, RunlessOpExecutionContext) else context.job_name, - retry_number=0 if isinstance(context, RunlessOpExecutionContext) else context.retry_number, + job_name=None if isinstance(context, DirectOpExecutionContext) else context.job_name, + retry_number=0 if isinstance(context, DirectOpExecutionContext) else context.retry_number, extras=extras or {}, ) diff --git a/python_modules/dagster/dagster_tests/core_tests/test_op_invocation.py b/python_modules/dagster/dagster_tests/core_tests/test_op_invocation.py index b97e6f69938f3..b41669e98b265 100644 --- a/python_modules/dagster/dagster_tests/core_tests/test_op_invocation.py +++ b/python_modules/dagster/dagster_tests/core_tests/test_op_invocation.py @@ -45,7 +45,7 @@ ) from dagster._core.execution.context.compute import AssetExecutionContext, OpExecutionContext from dagster._core.execution.context.invocation import ( - RunlessOpExecutionContext, + DirectOpExecutionContext, build_asset_context, ) from dagster._utils.test import wrap_op_in_graph_and_execute @@ -1388,22 +1388,22 @@ async def main(): asyncio.run(main()) -def assert_context_unbound(context: RunlessOpExecutionContext): +def assert_context_unbound(context: DirectOpExecutionContext): # to assert that the context is correctly unbound after op invocation assert not context.is_bound -def assert_context_bound(context: RunlessOpExecutionContext): +def assert_context_bound(context: DirectOpExecutionContext): # to assert that the context is correctly bound during op invocation assert context.is_bound -def assert_execution_properties_cleared(context: RunlessOpExecutionContext): +def assert_execution_properties_cleared(context: DirectOpExecutionContext): # to assert that the invocation properties are reset at the beginning of op invocation assert len(context.execution_properties.output_metadata.keys()) == 0 -def assert_execution_properties_exist(context: RunlessOpExecutionContext): +def assert_execution_properties_exist(context: DirectOpExecutionContext): # to assert that the invocation properties remain accessible after op invocation assert len(context.execution_properties.output_metadata.keys()) > 0