Skip to content

Commit

Permalink
Direct invocation asset execution context (#18549)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria authored Jan 31, 2024
1 parent 1cf9e98 commit f3e4817
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 27 deletions.
44 changes: 28 additions & 16 deletions python_modules/dagster/dagster/_core/definitions/op_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
from .result import MaterializeResult

if TYPE_CHECKING:
from ..execution.context.invocation import DirectOpExecutionContext
from ..execution.context.compute import OpExecutionContext
from ..execution.context.invocation import BaseDirectExecutionContext
from .assets import AssetsDefinition
from .composition import PendingNodeInvocation
from .decorators.op_decorator import DecoratedOpFunction
Expand Down Expand Up @@ -100,6 +101,16 @@ def _separate_args_and_kwargs(
)


def _get_op_context(
context,
) -> "OpExecutionContext":
from dagster._core.execution.context.compute import AssetExecutionContext

if isinstance(context, AssetExecutionContext):
return context.op_execution_context
return context


def direct_invocation_result(
def_or_invocation: Union[
"OpDefinition", "PendingNodeInvocation[OpDefinition]", "AssetsDefinition"
Expand All @@ -109,7 +120,7 @@ def direct_invocation_result(
) -> Any:
from dagster._config.pythonic_config import Config
from dagster._core.execution.context.invocation import (
DirectOpExecutionContext,
BaseDirectExecutionContext,
build_op_context,
)

Expand Down Expand Up @@ -149,12 +160,12 @@ def direct_invocation_result(
" no context was provided when invoking."
)
if len(args) > 0:
if args[0] is not None and not isinstance(args[0], DirectOpExecutionContext):
if args[0] is not None and not isinstance(args[0], BaseDirectExecutionContext):
raise DagsterInvalidInvocationError(
f"Decorated function '{compute_fn.name}' has context argument, "
"but no context was provided when invoking."
)
context = cast(DirectOpExecutionContext, args[0])
context = args[0]
# update args to omit context
args = args[1:]
else: # context argument is provided under kwargs
Expand All @@ -165,14 +176,14 @@ def direct_invocation_result(
f"'{context_param_name}', but no value for '{context_param_name}' was "
f"found when invoking. Provided kwargs: {kwargs}"
)
context = cast(DirectOpExecutionContext, kwargs[context_param_name])
context = kwargs[context_param_name]
# update kwargs to remove context
kwargs = {
kwarg: val for kwarg, val in kwargs.items() if not kwarg == context_param_name
}
# allow passing context, even if the function doesn't have an arg for it
elif len(args) > 0 and isinstance(args[0], DirectOpExecutionContext):
context = cast(DirectOpExecutionContext, args[0])
elif len(args) > 0 and isinstance(args[0], BaseDirectExecutionContext):
context = args[0]
args = args[1:]

resource_arg_mapping = {arg.name: arg.name for arg in compute_fn.get_resource_args()}
Expand Down Expand Up @@ -230,7 +241,7 @@ def direct_invocation_result(


def _resolve_inputs(
op_def: "OpDefinition", args, kwargs, context: "DirectOpExecutionContext"
op_def: "OpDefinition", args, kwargs, context: "BaseDirectExecutionContext"
) -> Mapping[str, Any]:
from dagster._core.execution.plan.execute_step import do_type_check

Expand Down Expand Up @@ -333,7 +344,7 @@ def _resolve_inputs(
return input_dict


def _key_for_result(result: MaterializeResult, context: "DirectOpExecutionContext") -> AssetKey:
def _key_for_result(result: MaterializeResult, context: "BaseDirectExecutionContext") -> AssetKey:
if not context.per_invocation_properties.assets_def:
raise DagsterInvariantViolationError(
f"Op {context.per_invocation_properties.alias} does not have an assets definition."
Expand All @@ -355,7 +366,7 @@ def _key_for_result(result: MaterializeResult, context: "DirectOpExecutionContex

def _output_name_for_result_obj(
event: MaterializeResult,
context: "DirectOpExecutionContext",
context: "BaseDirectExecutionContext",
):
if not context.per_invocation_properties.assets_def:
raise DagsterInvariantViolationError(
Expand All @@ -368,7 +379,7 @@ def _output_name_for_result_obj(
def _handle_gen_event(
event: T,
op_def: "OpDefinition",
context: "DirectOpExecutionContext",
context: "BaseDirectExecutionContext",
output_defs: Mapping[str, OutputDefinition],
outputs_seen: Set[str],
) -> T:
Expand Down Expand Up @@ -402,7 +413,7 @@ def _handle_gen_event(


def _type_check_output_wrapper(
op_def: "OpDefinition", result: Any, context: "DirectOpExecutionContext"
op_def: "OpDefinition", result: Any, context: "BaseDirectExecutionContext"
) -> Any:
"""Type checks and returns the result of a op.
Expand Down Expand Up @@ -496,12 +507,13 @@ def type_check_gen(gen):


def _type_check_function_output(
op_def: "OpDefinition", result: T, context: "DirectOpExecutionContext"
op_def: "OpDefinition", result: T, context: "BaseDirectExecutionContext"
) -> T:
from ..execution.plan.compute_generator import validate_and_coerce_op_result_to_iterator

output_defs_by_name = {output_def.name: output_def for output_def in op_def.output_defs}
for event in validate_and_coerce_op_result_to_iterator(result, context, op_def.output_defs):
op_context = _get_op_context(context)
for event in validate_and_coerce_op_result_to_iterator(result, op_context, op_def.output_defs):
if isinstance(event, (Output, DynamicOutput)):
_type_check_output(output_defs_by_name[event.output_name], event, context)
elif isinstance(event, (MaterializeResult)):
Expand All @@ -515,14 +527,14 @@ def _type_check_function_output(
def _type_check_output(
output_def: "OutputDefinition",
output: Union[Output, DynamicOutput],
context: "DirectOpExecutionContext",
context: "BaseDirectExecutionContext",
) -> None:
"""Validates and performs core type check on a provided output.
Args:
output_def (OutputDefinition): The output definition to validate against.
output (Any): The output to validate.
context (DirectOpExecutionContext): Context containing resources to be used for type
context (BaseDirectExecutionContext): Context containing resources to be used for type
check.
"""
from ..execution.plan.execute_step import do_type_check
Expand Down
140 changes: 133 additions & 7 deletions python_modules/dagster/dagster/_core/execution/context/invocation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import abstractmethod
from contextlib import ExitStack
from typing import (
AbstractSet,
Expand Down Expand Up @@ -56,14 +57,60 @@
from dagster._utils.merger import merge_dicts
from dagster._utils.warnings import deprecation_warning

from .compute import OpExecutionContext
from .compute import AssetExecutionContext, OpExecutionContext
from .system import StepExecutionContext, TypeCheckContext


def _property_msg(prop_name: str, method_name: str) -> str:
return (
f"The {prop_name} {method_name} is not set on the context when an op is directly invoked."
)
return f"The {prop_name} {method_name} is not set on the context when an asset or op is directly invoked."


class BaseDirectExecutionContext:
"""Base class for any direct invocation execution contexts. Each type of execution context
(ex. OpExecutionContext, AssetExecutionContext) needs to have a variant for direct invocation.
Those direct invocation contexts have some methods that are not available until the context
is bound to a particular op/asset. The "bound" properties are held in PerInvocationProperties.
There are also some properties that are specific to a particular execution of an op/asset, these
properties are held in DirectExecutionProperties. Direct invocation contexts must
be able to be bound and unbound from a particular op/asset. Additionally, there are some methods
that all direct invocation contexts must implement so that the will be usable in the execution
code path.
"""

@abstractmethod
def bind(
self,
op_def: OpDefinition,
pending_invocation: Optional[PendingNodeInvocation[OpDefinition]],
assets_def: Optional[AssetsDefinition],
config_from_args: Optional[Mapping[str, Any]],
resources_from_args: Optional[Mapping[str, Any]],
):
"""Subclasses of BaseDirectExecutionContext must implement bind."""

@abstractmethod
def unbind(self):
"""Subclasses of BaseDirectExecutionContext must implement unbind."""

@property
@abstractmethod
def per_invocation_properties(self) -> "PerInvocationProperties":
"""Subclasses of BaseDirectExecutionContext must contain a PerInvocationProperties object."""

@property
@abstractmethod
def execution_properties(self) -> "DirectExecutionProperties":
"""Subclasses of BaseDirectExecutionContext must contain a DirectExecutionProperties object."""

@abstractmethod
def for_type(self, dagster_type: DagsterType) -> TypeCheckContext:
"""Subclasses of BaseDirectExecutionContext must implement for_type."""
pass

@abstractmethod
def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None:
"""Subclasses of BaseDirectExecutionContext must implement observe_output."""
pass


class PerInvocationProperties(
Expand Down Expand Up @@ -127,7 +174,7 @@ def __init__(self):
self.typed_event_stream_error_message: Optional[str] = None


class DirectOpExecutionContext(OpExecutionContext):
class DirectOpExecutionContext(OpExecutionContext, BaseDirectExecutionContext):
"""The ``context`` object available as the first argument to an op's compute function when
being invoked directly. Can also be used as a context manager.
"""
Expand Down Expand Up @@ -706,6 +753,83 @@ def set_requires_typed_event_stream(self, *, error_message: Optional[str]) -> No
self._execution_properties.typed_event_stream_error_message = error_message


class DirectAssetExecutionContext(AssetExecutionContext, BaseDirectExecutionContext):
"""The ``context`` object available as the first argument to an asset's compute function when
being invoked directly. Can also be used as a context manager.
"""

def __init__(self, op_execution_context: DirectOpExecutionContext):
self._op_execution_context = op_execution_context

def __enter__(self):
self.op_execution_context._cm_scope_entered = True # noqa: SLF001
return self

def __exit__(self, *exc):
self.op_execution_context._exit_stack.close() # noqa: SLF001

def __del__(self):
self.op_execution_context._exit_stack.close() # noqa: SLF001

def _check_bound_to_invocation(self, fn_name: str, fn_type: str):
if not self._op_execution_context._per_invocation_properties: # noqa: SLF001
raise DagsterInvalidPropertyError(_property_msg(fn_name, fn_type))

def bind(
self,
op_def: OpDefinition,
pending_invocation: Optional[PendingNodeInvocation[OpDefinition]],
assets_def: Optional[AssetsDefinition],
config_from_args: Optional[Mapping[str, Any]],
resources_from_args: Optional[Mapping[str, Any]],
) -> "DirectAssetExecutionContext":
if assets_def is None:
raise DagsterInvariantViolationError(
"DirectAssetExecutionContext can only being used to invoke an asset."
)
if self._op_execution_context._per_invocation_properties is not None: # noqa: SLF001
raise DagsterInvalidInvocationError(
f"This context is currently being used to execute {self.op_execution_context.alias}."
" The context cannot be used to execute another asset until"
f" {self.op_execution_context.alias} has finished executing."
)

self._op_execution_context = self._op_execution_context.bind(
op_def=op_def,
pending_invocation=pending_invocation,
assets_def=assets_def,
config_from_args=config_from_args,
resources_from_args=resources_from_args,
)

return self

def unbind(self):
self._op_execution_context.unbind()

@property
def per_invocation_properties(self) -> PerInvocationProperties:
return self.op_execution_context.per_invocation_properties

@property
def is_bound(self) -> bool:
return self.op_execution_context.is_bound

@property
def execution_properties(self) -> DirectExecutionProperties:
return self.op_execution_context.execution_properties

@property
def op_execution_context(self) -> DirectOpExecutionContext:
return self._op_execution_context

def for_type(self, dagster_type: DagsterType) -> TypeCheckContext:
return self.op_execution_context.for_type(dagster_type)

def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None:
self.op_execution_context.observe_output(output_name=output_name, mapping_key=mapping_key)


def _validate_resource_requirements(
resource_defs: Mapping[str, ResourceDefinition], op_def: OpDefinition
) -> None:
Expand Down Expand Up @@ -796,7 +920,7 @@ def build_asset_context(
instance: Optional[DagsterInstance] = None,
partition_key: Optional[str] = None,
partition_key_range: Optional[PartitionKeyRange] = None,
):
) -> DirectAssetExecutionContext:
"""Builds asset execution context from provided parameters.
``build_asset_context`` can be used as either a function or context manager. If there is a
Expand All @@ -823,11 +947,13 @@ def build_asset_context(
with build_asset_context(resources={"foo": context_manager_resource}) as context:
asset_to_invoke(context)
"""
return build_op_context(
op_context = build_op_context(
op_config=asset_config,
resources=resources,
resources_config=resources_config,
partition_key=partition_key,
partition_key_range=partition_key_range,
instance=instance,
)

return DirectAssetExecutionContext(op_execution_context=op_context)
6 changes: 3 additions & 3 deletions python_modules/dagster/dagster/_core/pipes/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from dagster._core.errors import DagsterPipesExecutionError
from dagster._core.events import EngineEventData
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.execution.context.invocation import DirectOpExecutionContext
from dagster._core.execution.context.invocation import BaseDirectExecutionContext
from dagster._utils.error import (
ExceptionInfo,
SerializableErrorInfo,
Expand Down Expand Up @@ -406,8 +406,8 @@ def build_external_execution_context_data(
_convert_time_window(partition_time_window) if partition_time_window else None
),
run_id=context.run_id,
job_name=None if isinstance(context, DirectOpExecutionContext) else context.job_name,
retry_number=0 if isinstance(context, DirectOpExecutionContext) else context.retry_number,
job_name=None if isinstance(context, BaseDirectExecutionContext) else context.job_name,
retry_number=0 if isinstance(context, BaseDirectExecutionContext) else context.retry_number,
extras=extras or {},
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1383,7 +1383,7 @@ async def main():
with pytest.raises(
DagsterInvalidInvocationError,
match=r"This context is currently being used to execute .* The context"
r" cannot be used to execute another op until .* has finished executing",
r" cannot be used to execute another asset until .* has finished executing",
):
asyncio.run(main())

Expand Down

0 comments on commit f3e4817

Please sign in to comment.