Skip to content

Commit

Permalink
add DI asset context
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Dec 6, 2023
1 parent 5d2ef81 commit ee347b4
Showing 1 changed file with 81 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@
from dagster._utils.merger import merge_dicts
from dagster._utils.warnings import deprecation_warning

from .compute import OpExecutionContext
from .compute import AssetExecutionContext, OpExecutionContext, RunProperties
from .system import StepExecutionContext, TypeCheckContext


def _property_msg(prop_name: str, method_name: str) -> str:
# TODO - update to handle assets too
return (
f"The {prop_name} {method_name} is not set on the context when an op is directly invoked."
)
Expand Down Expand Up @@ -716,6 +717,82 @@ def set_requires_typed_event_stream(self, *, error_message: Optional[str]) -> No
self._execution_properties.set_requires_typed_event_stream(error_message=error_message)


class RunlessAssetExecutionContext(AssetExecutionContext):
"""The ``context`` object available as the first argument to an asset's compute function when
being invoked directly. Can also be used as a context manager.
"""

def __init__(self, op_execution_context: RunlessOpExecutionContext):
self._op_execution_context = op_execution_context

self._run_props = None

def _check_bound(self, fn_name: str, fn_type: str):
if not self._op_execution_context._bound_properties: # noqa: SLF001
raise DagsterInvalidPropertyError(_property_msg(fn_name, fn_type))

def bind(
self,
op_def: OpDefinition,
pending_invocation: Optional[PendingNodeInvocation[OpDefinition]],
assets_def: Optional[AssetsDefinition],
config_from_args: Optional[Mapping[str, Any]],
resources_from_args: Optional[Mapping[str, Any]],
) -> "RunlessAssetExecutionContext":
if assets_def is None:
raise DagsterInvariantViolationError(
"RunlessAssetExecutionContext can only being used to invoke an asset."
)
if self._op_execution_context._bound_properties is not None: # noqa: SLF001
raise DagsterInvalidInvocationError(
f"This context is currently being used to execute {self.op_execution_context.alias}."
" The context cannot be used to execute another asset until"
f" {self.op_execution_context.alias} has finished executing."
)

self._op_execution_context = self._op_execution_context.bind(
op_def=op_def,
pending_invocation=pending_invocation,
assets_def=assets_def,
config_from_args=config_from_args,
resources_from_args=resources_from_args,
)

return self

def unbind(self):
self._op_execution_context = self._op_execution_context.unbind()

@property
def op_execution_context(self) -> RunlessOpExecutionContext:
return self._op_execution_context

def for_type(self, dagster_type: DagsterType) -> TypeCheckContext:
self._check_bound(fn_name="for_type", fn_type="method")
resources = cast(NamedTuple, self.resources)
return TypeCheckContext(
self.run_id,
self.log,
ScopedResourcesBuilder(resources._asdict()),
dagster_type,
)

def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None:
self.op_execution_context.observe_output(output_name=output_name, mapping_key=mapping_key)

@property
def run_properties(self) -> RunProperties:
self._check_bound(fn_name="run_properties", fn_type="property")
if self._run_props is None:
self._run_props = RunProperties(
run_id=self.op_execution_context.run_id,
run_config=self.op_execution_context.run_config,
dagster_run=self.op_execution_context.run,
retry_number=self.op_execution_context.retry_number,
)
return self._run_props


def _validate_resource_requirements(
resource_defs: Mapping[str, ResourceDefinition], op_def: OpDefinition
) -> None:
Expand Down Expand Up @@ -833,11 +910,13 @@ def build_asset_context(
with build_asset_context(resources={"foo": context_manager_resource}) as context:
asset_to_invoke(context)
"""
return build_op_context(
op_context = build_op_context(
op_config=asset_config,
resources=resources,
resources_config=resources_config,
partition_key=partition_key,
partition_key_range=partition_key_range,
instance=instance,
)

return RunlessAssetExecutionContext(op_execution_context=op_context)

0 comments on commit ee347b4

Please sign in to comment.