From 46bb42908b8a0084aeb0b4fa48a54550e57c58b3 Mon Sep 17 00:00:00 2001 From: JamieDeMaria Date: Fri, 17 Nov 2023 16:19:27 -0500 Subject: [PATCH] update interfaces --- .../_core/definitions/op_invocation.py | 45 ++++++++++--------- .../_core/execution/context/invocation.py | 35 ++++++++++++--- 2 files changed, 53 insertions(+), 27 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/op_invocation.py b/python_modules/dagster/dagster/_core/definitions/op_invocation.py index afee2c768b34c..d72b0ce2584b2 100644 --- a/python_modules/dagster/dagster/_core/definitions/op_invocation.py +++ b/python_modules/dagster/dagster/_core/definitions/op_invocation.py @@ -33,7 +33,7 @@ if TYPE_CHECKING: from ..execution.context.invocation import ( - DirectInvocationOpExecutionContext, + BaseDirectInvocationContext, ) from .assets import AssetsDefinition from .composition import PendingNodeInvocation @@ -225,7 +225,7 @@ def direct_invocation_result( def _resolve_inputs( - op_def: "OpDefinition", args, kwargs, context: "DirectInvocationOpExecutionContext" + op_def: "OpDefinition", args, kwargs, context: "BaseDirectInvocationContext" ) -> Mapping[str, Any]: from dagster._core.execution.plan.execute_step import do_type_check @@ -263,9 +263,8 @@ def _resolve_inputs( "but no context parameter was defined for the op." ) - node_label = op_def.node_type_str raise DagsterInvalidInvocationError( - f"Too many input arguments were provided for {node_label} '{context.alias}'." + f"Too many input arguments were provided for {context.execution_properties.step_description}'." f" {suggestion}" ) @@ -308,7 +307,7 @@ def _resolve_inputs( input_dict[k] = v # Type check inputs - op_label = context.describe_op() + step_label = context.execution_properties.step_description for input_name, val in input_dict.items(): input_def = input_defs_by_name[input_name] @@ -317,7 +316,7 @@ def _resolve_inputs( if not type_check.success: raise DagsterTypeCheckDidNotPass( description=( - f'Type check failed for {op_label} input "{input_def.name}" - ' + f'Type check failed for {step_label} input "{input_def.name}" - ' f'expected type "{dagster_type.display_name}". ' f"Description: {type_check.description}" ), @@ -328,33 +327,35 @@ def _resolve_inputs( return input_dict -def _key_for_result( - result: MaterializeResult, context: "DirectInvocationOpExecutionContext" -) -> AssetKey: +def _key_for_result(result: MaterializeResult, context: "BaseDirectInvocationContext") -> AssetKey: if result.asset_key: return result.asset_key - if len(context.assets_def.keys) == 1: - return next(iter(context.assets_def.keys)) + if len(context.execution_properties.op_execution_context.assets_def.keys) == 1: + return next(iter(context.execution_properties.op_execution_context.assets_def.keys)) raise DagsterInvariantViolationError( "MaterializeResult did not include asset_key and it can not be inferred. Specify which" - f" asset_key, options are: {context.assets_def.keys}" + f" asset_key, options are: {context.execution_properties.op_execution_context.assets_def.keys}" ) def _output_name_for_result_obj( event: MaterializeResult, - context: "DirectInvocationOpExecutionContext", + context: "BaseDirectInvocationContext", ): asset_key = _key_for_result(event, context) - return context.assets_def.get_output_name_for_asset_key(asset_key) + return ( + context.execution_properties.op_execution_context.assets_def.get_output_name_for_asset_key( + asset_key + ) + ) def _handle_gen_event( event: T, op_def: "OpDefinition", - context: "DirectInvocationOpExecutionContext", + context: "BaseDirectInvocationContext", output_defs: Mapping[str, OutputDefinition], outputs_seen: Set[str], ) -> T: @@ -380,7 +381,7 @@ def _handle_gen_event( output_def, DynamicOutputDefinition ): raise DagsterInvariantViolationError( - f"Invocation of {op_def.node_type_str} '{context.alias}' yielded" + f"Invocation of {context.execution_properties.step_description} yielded" f" an output '{output_def.name}' multiple times." ) outputs_seen.add(output_def.name) @@ -388,7 +389,7 @@ def _handle_gen_event( def _type_check_output_wrapper( - op_def: "OpDefinition", result: Any, context: "DirectInvocationOpExecutionContext" + op_def: "OpDefinition", result: Any, context: "BaseDirectInvocationContext" ) -> Any: """Type checks and returns the result of a op. @@ -462,7 +463,7 @@ def type_check_gen(gen): def _type_check_function_output( - op_def: "OpDefinition", result: T, context: "DirectInvocationOpExecutionContext" + op_def: "OpDefinition", result: T, context: "BaseDirectInvocationContext" ) -> T: from ..execution.plan.compute_generator import validate_and_coerce_op_result_to_iterator @@ -480,25 +481,25 @@ def _type_check_function_output( def _type_check_output( output_def: "OutputDefinition", output: Union[Output, DynamicOutput], - context: "DirectInvocationOpExecutionContext", + context: "BaseDirectInvocationContext", ) -> None: """Validates and performs core type check on a provided output. Args: output_def (OutputDefinition): The output definition to validate against. output (Any): The output to validate. - context (DirectInvocationOpExecutionContext): Context containing resources to be used for type + context (BaseDirectInvocationContext): Context containing resources to be used for type check. """ from ..execution.plan.execute_step import do_type_check - op_label = context.describe_op() + step_label = context.execution_properties.step_description dagster_type = output_def.dagster_type type_check = do_type_check(context.for_type(dagster_type), dagster_type, output.value) if not type_check.success: raise DagsterTypeCheckDidNotPass( description=( - f'Type check failed for {op_label} output "{output.output_name}" - ' + f'Type check failed for {step_label} output "{output.output_name}" - ' f'expected type "{dagster_type.display_name}". ' f"Description: {type_check.description}" ), diff --git a/python_modules/dagster/dagster/_core/execution/context/invocation.py b/python_modules/dagster/dagster/_core/execution/context/invocation.py index 7d847d2fd414b..69957e16bf53a 100644 --- a/python_modules/dagster/dagster/_core/execution/context/invocation.py +++ b/python_modules/dagster/dagster/_core/execution/context/invocation.py @@ -1,4 +1,5 @@ import warnings +from abc import abstractmethod from contextlib import ExitStack from typing import ( AbstractSet, @@ -55,7 +56,13 @@ from dagster._utils.forked_pdb import ForkedPdb from dagster._utils.merger import merge_dicts -from .compute import AssetExecutionContext, ExecutionProperties, OpExecutionContext, RunProperties +from .compute import ( + AssetExecutionContext, + ContextHasExecutionProperties, + ExecutionProperties, + OpExecutionContext, + RunProperties, +) from .system import StepExecutionContext, TypeCheckContext @@ -63,7 +70,8 @@ def _property_msg(prop_name: str, method_name: str, step_type: str) -> str: return f"The {prop_name} {method_name} is not set on the context when an {step_type} is directly invoked." -class BaseDirectInvocationContext: +class BaseDirectInvocationContext(ContextHasExecutionProperties): + @abstractmethod def bind( self, op_def: OpDefinition, @@ -74,6 +82,14 @@ def bind( ): pass + @abstractmethod + def for_type(self, dagster_type: DagsterType) -> TypeCheckContext: + pass + + @abstractmethod + def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None: + pass + class DirectInvocationOpExecutionContext(OpExecutionContext, BaseDirectInvocationContext): """The ``context`` object available as the first argument to an op's compute function when @@ -716,9 +732,18 @@ def unbind(self): self._bound = False - @property - def op_execution_context(self) -> OpExecutionContext: - return self._op_execution_context + def for_type(self, dagster_type: DagsterType) -> TypeCheckContext: + self._check_bound(fn_name="for_type", fn_type="method") + resources = cast(NamedTuple, self.resources) + return TypeCheckContext( + self.run_id, + self.log, + ScopedResourcesBuilder(resources._asdict()), + dagster_type, + ) + + def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None: + self._op_execution_context.observe_output(output_name=output_name, mapping_key=mapping_key) @property def run_properties(self) -> RunProperties: