Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Direct invocation asset execution context #18549

Merged
merged 17 commits into from
Jan 31, 2024
44 changes: 28 additions & 16 deletions python_modules/dagster/dagster/_core/definitions/op_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
from .result import MaterializeResult

if TYPE_CHECKING:
from ..execution.context.invocation import DirectOpExecutionContext
from ..execution.context.compute import OpExecutionContext
from ..execution.context.invocation import BaseDirectExecutionContext
from .assets import AssetsDefinition
from .composition import PendingNodeInvocation
from .decorators.op_decorator import DecoratedOpFunction
Expand Down Expand Up @@ -100,6 +101,16 @@ def _separate_args_and_kwargs(
)


def _get_op_context(
context,
) -> "OpExecutionContext":
from dagster._core.execution.context.compute import AssetExecutionContext

if isinstance(context, AssetExecutionContext):
return context.op_execution_context
return context


def direct_invocation_result(
def_or_invocation: Union[
"OpDefinition", "PendingNodeInvocation[OpDefinition]", "AssetsDefinition"
Expand All @@ -109,7 +120,7 @@ def direct_invocation_result(
) -> Any:
from dagster._config.pythonic_config import Config
from dagster._core.execution.context.invocation import (
DirectOpExecutionContext,
BaseDirectExecutionContext,
build_op_context,
)

Expand Down Expand Up @@ -149,12 +160,12 @@ def direct_invocation_result(
" no context was provided when invoking."
)
if len(args) > 0:
if args[0] is not None and not isinstance(args[0], DirectOpExecutionContext):
if args[0] is not None and not isinstance(args[0], BaseDirectExecutionContext):
raise DagsterInvalidInvocationError(
f"Decorated function '{compute_fn.name}' has context argument, "
"but no context was provided when invoking."
)
context = cast(DirectOpExecutionContext, args[0])
context = args[0]
# update args to omit context
args = args[1:]
else: # context argument is provided under kwargs
Expand All @@ -165,14 +176,14 @@ def direct_invocation_result(
f"'{context_param_name}', but no value for '{context_param_name}' was "
f"found when invoking. Provided kwargs: {kwargs}"
)
context = cast(DirectOpExecutionContext, kwargs[context_param_name])
context = kwargs[context_param_name]
# update kwargs to remove context
kwargs = {
kwarg: val for kwarg, val in kwargs.items() if not kwarg == context_param_name
}
# allow passing context, even if the function doesn't have an arg for it
elif len(args) > 0 and isinstance(args[0], DirectOpExecutionContext):
context = cast(DirectOpExecutionContext, args[0])
elif len(args) > 0 and isinstance(args[0], BaseDirectExecutionContext):
context = args[0]
args = args[1:]

resource_arg_mapping = {arg.name: arg.name for arg in compute_fn.get_resource_args()}
Expand Down Expand Up @@ -230,7 +241,7 @@ def direct_invocation_result(


def _resolve_inputs(
op_def: "OpDefinition", args, kwargs, context: "DirectOpExecutionContext"
op_def: "OpDefinition", args, kwargs, context: "BaseDirectExecutionContext"
) -> Mapping[str, Any]:
from dagster._core.execution.plan.execute_step import do_type_check

Expand Down Expand Up @@ -333,7 +344,7 @@ def _resolve_inputs(
return input_dict


def _key_for_result(result: MaterializeResult, context: "DirectOpExecutionContext") -> AssetKey:
def _key_for_result(result: MaterializeResult, context: "BaseDirectExecutionContext") -> AssetKey:
if not context.per_invocation_properties.assets_def:
raise DagsterInvariantViolationError(
f"Op {context.per_invocation_properties.alias} does not have an assets definition."
Expand All @@ -355,7 +366,7 @@ def _key_for_result(result: MaterializeResult, context: "DirectOpExecutionContex

def _output_name_for_result_obj(
event: MaterializeResult,
context: "DirectOpExecutionContext",
context: "BaseDirectExecutionContext",
):
if not context.per_invocation_properties.assets_def:
raise DagsterInvariantViolationError(
Expand All @@ -368,7 +379,7 @@ def _output_name_for_result_obj(
def _handle_gen_event(
event: T,
op_def: "OpDefinition",
context: "DirectOpExecutionContext",
context: "BaseDirectExecutionContext",
output_defs: Mapping[str, OutputDefinition],
outputs_seen: Set[str],
) -> T:
Expand Down Expand Up @@ -402,7 +413,7 @@ def _handle_gen_event(


def _type_check_output_wrapper(
op_def: "OpDefinition", result: Any, context: "DirectOpExecutionContext"
op_def: "OpDefinition", result: Any, context: "BaseDirectExecutionContext"
) -> Any:
"""Type checks and returns the result of a op.

Expand Down Expand Up @@ -496,12 +507,13 @@ def type_check_gen(gen):


def _type_check_function_output(
op_def: "OpDefinition", result: T, context: "DirectOpExecutionContext"
op_def: "OpDefinition", result: T, context: "BaseDirectExecutionContext"
) -> T:
from ..execution.plan.compute_generator import validate_and_coerce_op_result_to_iterator

output_defs_by_name = {output_def.name: output_def for output_def in op_def.output_defs}
for event in validate_and_coerce_op_result_to_iterator(result, context, op_def.output_defs):
op_context = _get_op_context(context)
for event in validate_and_coerce_op_result_to_iterator(result, op_context, op_def.output_defs):
if isinstance(event, (Output, DynamicOutput)):
_type_check_output(output_defs_by_name[event.output_name], event, context)
elif isinstance(event, (MaterializeResult)):
Expand All @@ -515,14 +527,14 @@ def _type_check_function_output(
def _type_check_output(
output_def: "OutputDefinition",
output: Union[Output, DynamicOutput],
context: "DirectOpExecutionContext",
context: "BaseDirectExecutionContext",
) -> None:
"""Validates and performs core type check on a provided output.

Args:
output_def (OutputDefinition): The output definition to validate against.
output (Any): The output to validate.
context (DirectOpExecutionContext): Context containing resources to be used for type
context (BaseDirectExecutionContext): Context containing resources to be used for type
check.
"""
from ..execution.plan.execute_step import do_type_check
Expand Down
140 changes: 133 additions & 7 deletions python_modules/dagster/dagster/_core/execution/context/invocation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import abstractmethod
from contextlib import ExitStack
from typing import (
AbstractSet,
Expand Down Expand Up @@ -56,14 +57,60 @@
from dagster._utils.merger import merge_dicts
from dagster._utils.warnings import deprecation_warning

from .compute import OpExecutionContext
from .compute import AssetExecutionContext, OpExecutionContext
from .system import StepExecutionContext, TypeCheckContext


def _property_msg(prop_name: str, method_name: str) -> str:
return (
f"The {prop_name} {method_name} is not set on the context when an op is directly invoked."
)
return f"The {prop_name} {method_name} is not set on the context when an asset or op is directly invoked."


class BaseDirectExecutionContext:
jamiedemaria marked this conversation as resolved.
Show resolved Hide resolved
"""Base class for any direct invocation execution contexts. Each type of execution context
(ex. OpExecutionContext, AssetExecutionContext) needs to have a variant for direct invocation.
Those direct invocation contexts have some methods that are not available until the context
is bound to a particular op/asset. The "bound" properties are held in PerInvocationProperties.
There are also some properties that are specific to a particular execution of an op/asset, these
properties are held in DirectExecutionProperties. Direct invocation contexts must
be able to be bound and unbound from a particular op/asset. Additionally, there are some methods
that all direct invocation contexts must implement so that the will be usable in the execution
code path.
"""

@abstractmethod
def bind(
self,
op_def: OpDefinition,
pending_invocation: Optional[PendingNodeInvocation[OpDefinition]],
assets_def: Optional[AssetsDefinition],
config_from_args: Optional[Mapping[str, Any]],
resources_from_args: Optional[Mapping[str, Any]],
):
"""Subclasses of BaseDirectExecutionContext must implement bind."""

@abstractmethod
def unbind(self):
"""Subclasses of BaseDirectExecutionContext must implement unbind."""

@property
@abstractmethod
def per_invocation_properties(self) -> "PerInvocationProperties":
"""Subclasses of BaseDirectExecutionContext must contain a PerInvocationProperties object."""

@property
@abstractmethod
def execution_properties(self) -> "DirectExecutionProperties":
"""Subclasses of BaseDirectExecutionContext must contain a DirectExecutionProperties object."""

@abstractmethod
def for_type(self, dagster_type: DagsterType) -> TypeCheckContext:
"""Subclasses of BaseDirectExecutionContext must implement for_type."""
pass

@abstractmethod
def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None:
"""Subclasses of BaseDirectExecutionContext must implement observe_output."""
pass


class PerInvocationProperties(
Expand Down Expand Up @@ -127,7 +174,7 @@ def __init__(self):
self.typed_event_stream_error_message: Optional[str] = None


class DirectOpExecutionContext(OpExecutionContext):
class DirectOpExecutionContext(OpExecutionContext, BaseDirectExecutionContext):
"""The ``context`` object available as the first argument to an op's compute function when
being invoked directly. Can also be used as a context manager.
"""
Expand Down Expand Up @@ -706,6 +753,83 @@ def set_requires_typed_event_stream(self, *, error_message: Optional[str]) -> No
self._execution_properties.typed_event_stream_error_message = error_message


class DirectAssetExecutionContext(AssetExecutionContext, BaseDirectExecutionContext):
"""The ``context`` object available as the first argument to an asset's compute function when
being invoked directly. Can also be used as a context manager.
"""

def __init__(self, op_execution_context: DirectOpExecutionContext):
self._op_execution_context = op_execution_context

def __enter__(self):
self.op_execution_context._cm_scope_entered = True # noqa: SLF001
return self

def __exit__(self, *exc):
self.op_execution_context._exit_stack.close() # noqa: SLF001

def __del__(self):
self.op_execution_context._exit_stack.close() # noqa: SLF001

def _check_bound_to_invocation(self, fn_name: str, fn_type: str):
if not self._op_execution_context._per_invocation_properties: # noqa: SLF001
raise DagsterInvalidPropertyError(_property_msg(fn_name, fn_type))

def bind(
self,
op_def: OpDefinition,
pending_invocation: Optional[PendingNodeInvocation[OpDefinition]],
assets_def: Optional[AssetsDefinition],
config_from_args: Optional[Mapping[str, Any]],
resources_from_args: Optional[Mapping[str, Any]],
) -> "DirectAssetExecutionContext":
if assets_def is None:
raise DagsterInvariantViolationError(
"DirectAssetExecutionContext can only being used to invoke an asset."
)
if self._op_execution_context._per_invocation_properties is not None: # noqa: SLF001
raise DagsterInvalidInvocationError(
f"This context is currently being used to execute {self.op_execution_context.alias}."
" The context cannot be used to execute another asset until"
f" {self.op_execution_context.alias} has finished executing."
)

self._op_execution_context = self._op_execution_context.bind(
op_def=op_def,
pending_invocation=pending_invocation,
assets_def=assets_def,
config_from_args=config_from_args,
resources_from_args=resources_from_args,
)

return self

def unbind(self):
self._op_execution_context.unbind()

@property
def per_invocation_properties(self) -> PerInvocationProperties:
return self.op_execution_context.per_invocation_properties

@property
def is_bound(self) -> bool:
return self.op_execution_context.is_bound

@property
def execution_properties(self) -> DirectExecutionProperties:
return self.op_execution_context.execution_properties

@property
def op_execution_context(self) -> DirectOpExecutionContext:
return self._op_execution_context

def for_type(self, dagster_type: DagsterType) -> TypeCheckContext:
return self.op_execution_context.for_type(dagster_type)

def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None:
self.op_execution_context.observe_output(output_name=output_name, mapping_key=mapping_key)


def _validate_resource_requirements(
resource_defs: Mapping[str, ResourceDefinition], op_def: OpDefinition
) -> None:
Expand Down Expand Up @@ -796,7 +920,7 @@ def build_asset_context(
instance: Optional[DagsterInstance] = None,
partition_key: Optional[str] = None,
partition_key_range: Optional[PartitionKeyRange] = None,
):
) -> DirectAssetExecutionContext:
"""Builds asset execution context from provided parameters.

``build_asset_context`` can be used as either a function or context manager. If there is a
Expand All @@ -823,11 +947,13 @@ def build_asset_context(
with build_asset_context(resources={"foo": context_manager_resource}) as context:
asset_to_invoke(context)
"""
return build_op_context(
op_context = build_op_context(
op_config=asset_config,
resources=resources,
resources_config=resources_config,
partition_key=partition_key,
partition_key_range=partition_key_range,
instance=instance,
)

return DirectAssetExecutionContext(op_execution_context=op_context)
6 changes: 3 additions & 3 deletions python_modules/dagster/dagster/_core/pipes/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from dagster._core.errors import DagsterPipesExecutionError
from dagster._core.events import EngineEventData
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.execution.context.invocation import DirectOpExecutionContext
from dagster._core.execution.context.invocation import BaseDirectExecutionContext
from dagster._utils.error import (
ExceptionInfo,
SerializableErrorInfo,
Expand Down Expand Up @@ -406,8 +406,8 @@ def build_external_execution_context_data(
_convert_time_window(partition_time_window) if partition_time_window else None
),
run_id=context.run_id,
job_name=None if isinstance(context, DirectOpExecutionContext) else context.job_name,
retry_number=0 if isinstance(context, DirectOpExecutionContext) else context.retry_number,
job_name=None if isinstance(context, BaseDirectExecutionContext) else context.job_name,
retry_number=0 if isinstance(context, BaseDirectExecutionContext) else context.retry_number,
extras=extras or {},
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1383,7 +1383,7 @@ async def main():
with pytest.raises(
DagsterInvalidInvocationError,
match=r"This context is currently being used to execute .* The context"
r" cannot be used to execute another op until .* has finished executing",
r" cannot be used to execute another asset until .* has finished executing",
):
asyncio.run(main())

Expand Down