Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Dec 6, 2023
1 parent 4ea1df7 commit 29b57ea
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1420,7 +1420,7 @@ def __new__(
)


class AssetExecutionContext(OpExecutionContext):
class AssetExecutionContext:
def __init__(self, op_execution_context: OpExecutionContext) -> None:
self._op_execution_context = check.inst_param(
op_execution_context, "op_execution_context", OpExecutionContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def create_op_compute_wrapper(

@wraps(fn)
def compute(
context: OpExecutionContext,
context: Union[OpExecutionContext, AssetExecutionContext],
input_defs: Mapping[str, InputDefinition],
) -> Union[Iterator[Output], AsyncIterator[Output]]:
kwargs = {}
Expand Down Expand Up @@ -105,7 +105,9 @@ def compute(

# called in this file (create_op_compute_wrapper)
async def _coerce_async_op_to_async_gen(
awaitable: Awaitable[Any], context: OpExecutionContext, output_defs: Sequence[OutputDefinition]
awaitable: Awaitable[Any],
context: Union[OpExecutionContext, AssetExecutionContext],
output_defs: Sequence[OutputDefinition],
) -> AsyncIterator[Any]:
result = await awaitable
for event in validate_and_coerce_op_result_to_iterator(result, context, output_defs):
Expand All @@ -115,7 +117,7 @@ async def _coerce_async_op_to_async_gen(
# called in this file, and in op_invocation for direct invocation
def invoke_compute_fn(
fn: Callable,
context: OpExecutionContext,
context: Union[OpExecutionContext, AssetExecutionContext],
kwargs: Mapping[str, Any],
context_arg_provided: bool,
config_arg_cls: Optional[Type[Config]],
Expand Down Expand Up @@ -258,7 +260,9 @@ def _check_output_object_name(

# called in op_invocation and this file
def validate_and_coerce_op_result_to_iterator(
result: Any, context: OpExecutionContext, output_defs: Sequence[OutputDefinition]
result: Any,
context: Union[OpExecutionContext, AssetExecutionContext],
output_defs: Sequence[OutputDefinition],
) -> Iterator[Any]:
if inspect.isgenerator(result):
# this happens when a user explicitly returns a generator in the op
Expand Down

0 comments on commit 29b57ea

Please sign in to comment.