Skip to content

Commit

Permalink
Only have one kind of context for direct invocation (#17554)
Browse files (browse the repository at this point in the history)
  • Loading branch information
jamiedemaria authored Jan 29, 2024
1 parent cc65e5f commit 07923e8
Show file tree
Hide file tree
Showing 6 changed files with 676 additions and 419 deletions.
123 changes: 81 additions & 42 deletions python_modules/dagster/dagster/_core/definitions/op_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from .result import MaterializeResult

if TYPE_CHECKING:
from ..execution.context.invocation import BoundOpExecutionContext
from ..execution.context.invocation import DirectOpExecutionContext
from .assets import AssetsDefinition
from .composition import PendingNodeInvocation
from .decorators.op_decorator import DecoratedOpFunction
Expand Down Expand Up @@ -109,7 +109,7 @@ def direct_invocation_result(
) -> Any:
from dagster._config.pythonic_config import Config
from dagster._core.execution.context.invocation import (
UnboundOpExecutionContext,
DirectOpExecutionContext,
build_op_context,
)

Expand Down Expand Up @@ -149,12 +149,12 @@ def direct_invocation_result(
" no context was provided when invoking."
)
if len(args) > 0:
if args[0] is not None and not isinstance(args[0], UnboundOpExecutionContext):
if args[0] is not None and not isinstance(args[0], DirectOpExecutionContext):
raise DagsterInvalidInvocationError(
f"Decorated function '{compute_fn.name}' has context argument, "
"but no context was provided when invoking."
)
context = cast(UnboundOpExecutionContext, args[0])
context = cast(DirectOpExecutionContext, args[0])
# update args to omit context
args = args[1:]
else: # context argument is provided under kwargs
Expand All @@ -165,14 +165,14 @@ def direct_invocation_result(
f"'{context_param_name}', but no value for '{context_param_name}' was "
f"found when invoking. Provided kwargs: {kwargs}"
)
context = cast(UnboundOpExecutionContext, kwargs[context_param_name])
context = cast(DirectOpExecutionContext, kwargs[context_param_name])
# update kwargs to remove context
kwargs = {
kwarg: val for kwarg, val in kwargs.items() if not kwarg == context_param_name
}
# allow passing context, even if the function doesn't have an arg for it
elif len(args) > 0 and isinstance(args[0], UnboundOpExecutionContext):
context = cast(UnboundOpExecutionContext, args[0])
elif len(args) > 0 and isinstance(args[0], DirectOpExecutionContext):
context = cast(DirectOpExecutionContext, args[0])
args = args[1:]

resource_arg_mapping = {arg.name: arg.name for arg in compute_fn.get_resource_args()}
Expand Down Expand Up @@ -206,24 +206,31 @@ def direct_invocation_result(
),
)

input_dict = _resolve_inputs(op_def, input_args, input_kwargs, bound_context)
try:
# if the compute function fails, we want to ensure we unbind the context. This
# try-except handles "vanilla" asset and op invocation (generators and async handled in
# _type_check_output_wrapper)

result = invoke_compute_fn(
fn=compute_fn.decorated_fn,
context=bound_context,
kwargs=input_dict,
context_arg_provided=compute_fn.has_context_arg(),
config_arg_cls=(
compute_fn.get_config_arg().annotation if compute_fn.has_config_arg() else None
),
resource_args=resource_arg_mapping,
)
input_dict = _resolve_inputs(op_def, input_args, input_kwargs, bound_context)

return _type_check_output_wrapper(op_def, result, bound_context)
result = invoke_compute_fn(
fn=compute_fn.decorated_fn,
context=bound_context,
kwargs=input_dict,
context_arg_provided=compute_fn.has_context_arg(),
config_arg_cls=(
compute_fn.get_config_arg().annotation if compute_fn.has_config_arg() else None
),
resource_args=resource_arg_mapping,
)
return _type_check_output_wrapper(op_def, result, bound_context)
except Exception:
bound_context.unbind()
raise


def _resolve_inputs(
op_def: "OpDefinition", args, kwargs, context: "BoundOpExecutionContext"
op_def: "OpDefinition", args, kwargs, context: "DirectOpExecutionContext"
) -> Mapping[str, Any]:
from dagster._core.execution.plan.execute_step import do_type_check

Expand Down Expand Up @@ -263,7 +270,7 @@ def _resolve_inputs(

node_label = op_def.node_type_str
raise DagsterInvalidInvocationError(
f"Too many input arguments were provided for {node_label} '{context.alias}'."
f"Too many input arguments were provided for {node_label} '{context.per_invocation_properties.alias}'."
f" {suggestion}"
)

Expand Down Expand Up @@ -306,7 +313,7 @@ def _resolve_inputs(
input_dict[k] = v

# Type check inputs
op_label = context.describe_op()
op_label = context.per_invocation_properties.step_description

for input_name, val in input_dict.items():
input_def = input_defs_by_name[input_name]
Expand All @@ -326,31 +333,42 @@ def _resolve_inputs(
return input_dict


def _key_for_result(result: MaterializeResult, context: "BoundOpExecutionContext") -> AssetKey:
def _key_for_result(result: MaterializeResult, context: "DirectOpExecutionContext") -> AssetKey:
if not context.per_invocation_properties.assets_def:
raise DagsterInvariantViolationError(
f"Op {context.per_invocation_properties.alias} does not have an assets definition."
)
if result.asset_key:
return result.asset_key

if len(context.assets_def.keys) == 1:
return next(iter(context.assets_def.keys))
if (
context.per_invocation_properties.assets_def
and len(context.per_invocation_properties.assets_def.keys) == 1
):
return next(iter(context.per_invocation_properties.assets_def.keys))

raise DagsterInvariantViolationError(
"MaterializeResult did not include asset_key and it can not be inferred. Specify which"
f" asset_key, options are: {context.assets_def.keys}"
f" asset_key, options are: {context.per_invocation_properties.assets_def.keys}"
)


def _output_name_for_result_obj(
event: MaterializeResult,
context: "BoundOpExecutionContext",
context: "DirectOpExecutionContext",
):
if not context.per_invocation_properties.assets_def:
raise DagsterInvariantViolationError(
f"Op {context.per_invocation_properties.alias} does not have an assets definition."
)
asset_key = _key_for_result(event, context)
return context.assets_def.get_output_name_for_asset_key(asset_key)
return context.per_invocation_properties.assets_def.get_output_name_for_asset_key(asset_key)


def _handle_gen_event(
event: T,
op_def: "OpDefinition",
context: "BoundOpExecutionContext",
context: "DirectOpExecutionContext",
output_defs: Mapping[str, OutputDefinition],
outputs_seen: Set[str],
) -> T:
Expand All @@ -376,15 +394,15 @@ def _handle_gen_event(
output_def, DynamicOutputDefinition
):
raise DagsterInvariantViolationError(
f"Invocation of {op_def.node_type_str} '{context.alias}' yielded"
f"Invocation of {op_def.node_type_str} '{context.per_invocation_properties.alias}' yielded"
f" an output '{output_def.name}' multiple times."
)
outputs_seen.add(output_def.name)
return event


def _type_check_output_wrapper(
op_def: "OpDefinition", result: Any, context: "BoundOpExecutionContext"
op_def: "OpDefinition", result: Any, context: "DirectOpExecutionContext"
) -> Any:
"""Type checks and returns the result of a op.
Expand All @@ -399,8 +417,14 @@ def _type_check_output_wrapper(
async def to_gen(async_gen):
outputs_seen = set()

async for event in async_gen:
yield _handle_gen_event(event, op_def, context, output_defs, outputs_seen)
try:
# if the compute function fails, we want to ensure we unbind the context. For
# async generators, the errors will only be surfaced here
async for event in async_gen:
yield _handle_gen_event(event, op_def, context, output_defs, outputs_seen)
except Exception:
context.unbind()
raise

for output_def in op_def.output_defs:
if (
Expand All @@ -413,17 +437,24 @@ async def to_gen(async_gen):
yield Output(output_name=output_def.name, value=None)
else:
raise DagsterInvariantViolationError(
f"Invocation of {op_def.node_type_str} '{context.alias}' did not"
f"Invocation of {op_def.node_type_str} '{context.per_invocation_properties.alias}' did not"
f" return an output for non-optional output '{output_def.name}'"
)
context.unbind()

return to_gen(result)

# Coroutine result case
elif inspect.iscoroutine(result):

async def type_check_coroutine(coro):
out = await coro
try:
# if the compute function fails, we want to ensure we unbind the context. For
# async, the errors will only be surfaced here
out = await coro
except Exception:
context.unbind()
raise
return _type_check_function_output(op_def, out, context)

return type_check_coroutine(result)
Expand All @@ -433,8 +464,14 @@ async def type_check_coroutine(coro):

def type_check_gen(gen):
outputs_seen = set()
for event in gen:
yield _handle_gen_event(event, op_def, context, output_defs, outputs_seen)
try:
# if the compute function fails, we want to ensure we unbind the context. For
# generators, the errors will only be surfaced here
for event in gen:
yield _handle_gen_event(event, op_def, context, output_defs, outputs_seen)
except Exception:
context.unbind()
raise

for output_def in op_def.output_defs:
if (
Expand All @@ -447,9 +484,10 @@ def type_check_gen(gen):
yield Output(output_name=output_def.name, value=None)
else:
raise DagsterInvariantViolationError(
f'Invocation of {op_def.node_type_str} "{context.alias}" did not'
f'Invocation of {op_def.node_type_str} "{context.per_invocation_properties.alias}" did not'
f' return an output for non-optional output "{output_def.name}"'
)
context.unbind()

return type_check_gen(result)

Expand All @@ -458,7 +496,7 @@ def type_check_gen(gen):


def _type_check_function_output(
op_def: "OpDefinition", result: T, context: "BoundOpExecutionContext"
op_def: "OpDefinition", result: T, context: "DirectOpExecutionContext"
) -> T:
from ..execution.plan.compute_generator import validate_and_coerce_op_result_to_iterator

Expand All @@ -470,25 +508,26 @@ def _type_check_function_output(
# ensure result objects are contextually valid
_output_name_for_result_obj(event, context)

context.unbind()
return result


def _type_check_output(
output_def: "OutputDefinition",
output: Union[Output, DynamicOutput],
context: "BoundOpExecutionContext",
context: "DirectOpExecutionContext",
) -> None:
"""Validates and performs core type check on a provided output.
Args:
output_def (OutputDefinition): The output definition to validate against.
output (Any): The output to validate.
context (BoundOpExecutionContext): Context containing resources to be used for type
context (DirectOpExecutionContext): Context containing resources to be used for type
check.
"""
from ..execution.plan.execute_step import do_type_check

op_label = context.describe_op()
op_label = context.per_invocation_properties.step_description
dagster_type = output_def.dagster_type
type_check = do_type_check(context.for_type(dagster_type), dagster_type, output.value)
if not type_check.success:
Expand Down
Loading

0 comments on commit 07923e8

Please sign in to comment.