Skip to content

Commit

Permalink
do the if else thing
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Dec 6, 2023
1 parent 1247272 commit 4ea1df7
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 9 deletions.
26 changes: 18 additions & 8 deletions python_modules/dagster/dagster/_core/execution/plan/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@
]


def _get_op_context(
context: Union[OpExecutionContext, AssetExecutionContext]
) -> OpExecutionContext:
if isinstance(context, AssetExecutionContext):
return context.op_execution_context
return context


def create_step_outputs(
node: Node,
handle: NodeHandle,
Expand Down Expand Up @@ -189,12 +197,12 @@ def _yield_compute_results(
),
user_event_generator,
):
if compute_context.has_events():
yield from compute_context.consume_events()
if _get_op_context(compute_context).has_events():
yield from _get_op_context(compute_context).consume_events()
yield _validate_event(event, step_context)

if compute_context.has_events():
yield from compute_context.consume_events()
if _get_op_context(compute_context).has_events():
yield from _get_op_context(compute_context).consume_events()


def execute_core_compute(
Expand Down Expand Up @@ -245,7 +253,8 @@ def execute_core_compute(
output.name
for output in step.step_outputs
# checks are required if we're in requires_typed_event_stream mode
if compute_context.requires_typed_event_stream or output.properties.asset_check_key
if _get_op_context(compute_context).requires_typed_event_stream
or output.properties.asset_check_key
}
omitted_outputs = expected_op_output_names.difference(emitted_result_names)
if omitted_outputs:
Expand All @@ -254,9 +263,10 @@ def execute_core_compute(
f"expected outputs {omitted_outputs!r}."
)

if compute_context.requires_typed_event_stream:
if compute_context.typed_event_stream_error_message:
message += " " + compute_context.typed_event_stream_error_message
if _get_op_context(compute_context).requires_typed_event_stream:
error_message = _get_op_context(compute_context).typed_event_stream_error_message
if error_message:
message += " " + error_message
raise DagsterInvariantViolationError(message)
else:
step_context.log.info(message)
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,18 @@
from dagster._utils import is_named_tuple_instance
from dagster._utils.warnings import disable_dagster_warnings

from ..context.compute import OpExecutionContext
from ..context.compute import AssetExecutionContext, OpExecutionContext


def _get_op_context(
context: Union[OpExecutionContext, AssetExecutionContext]
) -> OpExecutionContext:
if isinstance(context, AssetExecutionContext):
return context.op_execution_context
return context


# called in execute_step if the fn is not decorated
def create_op_compute_wrapper(
op_def: OpDefinition,
) -> Callable[[OpExecutionContext, Mapping[str, InputDefinition]], Any]:
Expand Down Expand Up @@ -94,6 +103,7 @@ def compute(
return compute


# called in this file (create_op_compute_wrapper)
async def _coerce_async_op_to_async_gen(
awaitable: Awaitable[Any], context: OpExecutionContext, output_defs: Sequence[OutputDefinition]
) -> AsyncIterator[Any]:
Expand All @@ -102,6 +112,7 @@ async def _coerce_async_op_to_async_gen(
yield event


# called in this file, and in op_invocation for direct invocation
def invoke_compute_fn(
fn: Callable,
context: OpExecutionContext,
Expand All @@ -125,6 +136,7 @@ def invoke_compute_fn(
return fn(context, **args_to_pass) if context_arg_provided else fn(**args_to_pass)


# called in this file (create_op_compute_wrapper)
def _coerce_op_compute_fn_to_iterator(
fn, output_defs, context, context_arg_provided, kwargs, config_arg_class, resource_arg_mapping
):
Expand All @@ -135,6 +147,7 @@ def _coerce_op_compute_fn_to_iterator(
yield event


# called in this file (validate_and_coerce_op_result_to_iterator)
def _zip_and_iterate_op_result(
result: Any, context: OpExecutionContext, output_defs: Sequence[OutputDefinition]
) -> Iterator[Tuple[int, Any, OutputDefinition]]:
Expand Down Expand Up @@ -162,6 +175,7 @@ def _zip_and_iterate_op_result(

# Filter out output_defs corresponding to asset check results that already exist on a
# MaterializeResult.
# called in this file (_zip_and_iterate_op_result)
def _filter_expected_output_defs(
result: Any, context: OpExecutionContext, output_defs: Sequence[OutputDefinition]
) -> Sequence[OutputDefinition]:
Expand All @@ -177,6 +191,7 @@ def _filter_expected_output_defs(
return [out for out in output_defs if out.name not in remove_outputs]


# called in this file (_zip_and_iterate_op_result)
def _validate_multi_return(
context: OpExecutionContext,
result: Any,
Expand Down Expand Up @@ -212,6 +227,7 @@ def _validate_multi_return(
return result


# called in this file (validate_and_coerce_op_result_to_iterator)
def _get_annotation_for_output_position(
position: int, op_def: OpDefinition, output_defs: Sequence[OutputDefinition]
) -> Any:
Expand All @@ -226,6 +242,7 @@ def _get_annotation_for_output_position(
return inspect.Parameter.empty


# called in this file (validate_and_coerce_op_result_to_iterator)
def _check_output_object_name(
output: Union[DynamicOutput, Output], output_def: OutputDefinition, position: int
) -> None:
Expand All @@ -239,6 +256,7 @@ def _check_output_object_name(
)


# called in op_invocation and this file
def validate_and_coerce_op_result_to_iterator(
result: Any, context: OpExecutionContext, output_defs: Sequence[OutputDefinition]
) -> Iterator[Any]:
Expand Down

0 comments on commit 4ea1df7

Please sign in to comment.