From ebb8603231201bd131ff602d8aabed3ba3e54077 Mon Sep 17 00:00:00 2001 From: JamieDeMaria Date: Wed, 6 Dec 2023 17:00:57 -0500 Subject: [PATCH] add DI asset context --- .../_core/execution/context/invocation.py | 83 ++++++++++++++++++- 1 file changed, 81 insertions(+), 2 deletions(-) diff --git a/python_modules/dagster/dagster/_core/execution/context/invocation.py b/python_modules/dagster/dagster/_core/execution/context/invocation.py index 49adc714175ff..f99850726526a 100644 --- a/python_modules/dagster/dagster/_core/execution/context/invocation.py +++ b/python_modules/dagster/dagster/_core/execution/context/invocation.py @@ -56,11 +56,12 @@ from dagster._utils.merger import merge_dicts from dagster._utils.warnings import deprecation_warning -from .compute import OpExecutionContext +from .compute import AssetExecutionContext, OpExecutionContext, RunProperties from .system import StepExecutionContext, TypeCheckContext def _property_msg(prop_name: str, method_name: str) -> str: + # TODO - update to handle assets too return ( f"The {prop_name} {method_name} is not set on the context when an op is directly invoked." ) @@ -716,6 +717,82 @@ def set_requires_typed_event_stream(self, *, error_message: Optional[str]) -> No self._execution_properties.set_requires_typed_event_stream(error_message=error_message) +class RunlessAssetExecutionContext(AssetExecutionContext): + """The ``context`` object available as the first argument to an asset's compute function when + being invoked directly. Can also be used as a context manager. + """ + + def __init__(self, op_execution_context: RunlessOpExecutionContext): + self._op_execution_context = op_execution_context + + self._run_props = None + + def _check_bound(self, fn_name: str, fn_type: str): + if not self._op_execution_context._bound_properties: # noqa: SLF001 + raise DagsterInvalidPropertyError(_property_msg(fn_name, fn_type)) + + def bind( + self, + op_def: OpDefinition, + pending_invocation: Optional[PendingNodeInvocation[OpDefinition]], + assets_def: Optional[AssetsDefinition], + config_from_args: Optional[Mapping[str, Any]], + resources_from_args: Optional[Mapping[str, Any]], + ) -> "RunlessAssetExecutionContext": + if assets_def is None: + raise DagsterInvariantViolationError( + "RunlessAssetExecutionContext can only being used to invoke an asset." + ) + if self._op_execution_context._bound_properties is not None: # noqa: SLF001 + raise DagsterInvalidInvocationError( + f"This context is currently being used to execute {self.op_execution_context.alias}." + " The context cannot be used to execute another asset until" + f" {self.op_execution_context.alias} has finished executing." + ) + + self._op_execution_context = self._op_execution_context.bind( + op_def=op_def, + pending_invocation=pending_invocation, + assets_def=assets_def, + config_from_args=config_from_args, + resources_from_args=resources_from_args, + ) + + return self + + def unbind(self): + self._op_execution_context = self._op_execution_context.unbind() + + @property + def op_execution_context(self) -> RunlessOpExecutionContext: + return self._op_execution_context + + def for_type(self, dagster_type: DagsterType) -> TypeCheckContext: + self._check_bound(fn_name="for_type", fn_type="method") + resources = cast(NamedTuple, self.resources) + return TypeCheckContext( + self.run_id, + self.log, + ScopedResourcesBuilder(resources._asdict()), + dagster_type, + ) + + def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None: + self.op_execution_context.observe_output(output_name=output_name, mapping_key=mapping_key) + + @property + def run_properties(self) -> RunProperties: + self._check_bound(fn_name="run_properties", fn_type="property") + if self._run_props is None: + self._run_props = RunProperties( + run_id=self.op_execution_context.run_id, + run_config=self.op_execution_context.run_config, + dagster_run=self.op_execution_context.run, + retry_number=self.op_execution_context.retry_number, + ) + return self._run_props + + def _validate_resource_requirements( resource_defs: Mapping[str, ResourceDefinition], op_def: OpDefinition ) -> None: @@ -833,7 +910,7 @@ def build_asset_context( with build_asset_context(resources={"foo": context_manager_resource}) as context: asset_to_invoke(context) """ - return build_op_context( + op_context = build_op_context( op_config=asset_config, resources=resources, resources_config=resources_config, @@ -841,3 +918,5 @@ def build_asset_context( partition_key_range=partition_key_range, instance=instance, ) + + return RunlessAssetExecutionContext(op_execution_context=op_context)