Skip to content

Commit

Permalink
Extract _get_op_def_compute_fn into wrap_source_asset_observe_fn_in_o…
Browse files Browse the repository at this point in the history
…p_compute_fn

This refactoring will be useful in a subsequent PR
  • Loading branch information
schrockn committed Sep 19, 2023
1 parent 2fa51cc commit 11f8bf6
Showing 1 changed file with 76 additions and 62 deletions.
138 changes: 76 additions & 62 deletions python_modules/dagster/dagster/_core/definitions/source_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
cast,
)

from typing_extensions import TypeAlias
from typing_extensions import TYPE_CHECKING, TypeAlias

import dagster._check as check
from dagster._annotations import PublicAttr, experimental_param, public
Expand Down Expand Up @@ -50,10 +50,84 @@
from dagster._utils.merger import merge_dicts
from dagster._utils.warnings import disable_dagster_warnings

if TYPE_CHECKING:
from dagster._core.definitions.decorators.op_decorator import (
DecoratedOpFunction,
)

# Going with this catch-all for the time-being to permit pythonic resources
SourceAssetObserveFunction: TypeAlias = Callable[..., Any]


@staticmethod
def wrap_source_asset_observe_fn_in_op_compute_fn(
source_asset: "SourceAsset",
) -> "DecoratedOpFunction":
from dagster._core.definitions.decorators.op_decorator import (
DecoratedOpFunction,
is_context_provided,
)
from dagster._core.execution.context.compute import (
OpExecutionContext,
)

check.not_none(source_asset.observe_fn, "Must be an observable source asset")
assert source_asset.observe_fn # for type checker

observe_fn = source_asset.observe_fn

observe_fn_has_context = is_context_provided(get_function_params(observe_fn))

def fn(context: OpExecutionContext):
resource_kwarg_keys = [param.name for param in get_resource_args(observe_fn)]
resource_kwargs = {key: getattr(context.resources, key) for key in resource_kwarg_keys}
observe_fn_return_value = (
observe_fn(context, **resource_kwargs)
if observe_fn_has_context
else observe_fn(**resource_kwargs)
)

if isinstance(observe_fn_return_value, DataVersion):
if source_asset.partitions_def is not None:
raise DagsterInvalidObservationError(
f"{source_asset.key} is partitioned, so its observe function should return a"
" DataVersionsByPartition, not a DataVersion"
)

context.log_event(
AssetObservation(
asset_key=source_asset.key,
tags={DATA_VERSION_TAG: observe_fn_return_value.value},
)
)
elif isinstance(observe_fn_return_value, DataVersionsByPartition):
if source_asset.partitions_def is None:
raise DagsterInvalidObservationError(
f"{source_asset.key} is not partitioned, so its observe function should return"
" a DataVersion, not a DataVersionsByPartition"
)

for (
partition_key,
data_version,
) in observe_fn_return_value.data_versions_by_partition.items():
context.log_event(
AssetObservation(
asset_key=source_asset.key,
tags={DATA_VERSION_TAG: data_version.value},
partition=partition_key,
)
)
else:
raise DagsterInvalidObservationError(
f"Observe function for {source_asset.key} must return a DataVersion or"
" DataVersionsByPartition, but returned a value of type"
f" {type(observe_fn_return_value)}"
)

return DecoratedOpFunction(fn)


@experimental_param(param="resource_defs")
@experimental_param(param="io_manager_def")
class SourceAsset(ResourceAddable):
Expand Down Expand Up @@ -180,66 +254,6 @@ def is_observable(self) -> bool:
"""bool: Whether the asset is observable."""
return self.node_def is not None

def _get_op_def_compute_fn(self, observe_fn: SourceAssetObserveFunction):
from dagster._core.definitions.decorators.op_decorator import (
DecoratedOpFunction,
is_context_provided,
)
from dagster._core.execution.context.compute import (
OpExecutionContext,
)

observe_fn_has_context = is_context_provided(get_function_params(observe_fn))

def fn(context: OpExecutionContext):
resource_kwarg_keys = [param.name for param in get_resource_args(observe_fn)]
resource_kwargs = {key: getattr(context.resources, key) for key in resource_kwarg_keys}
observe_fn_return_value = (
observe_fn(context, **resource_kwargs)
if observe_fn_has_context
else observe_fn(**resource_kwargs)
)

if isinstance(observe_fn_return_value, DataVersion):
if self.partitions_def is not None:
raise DagsterInvalidObservationError(
f"{self.key} is partitioned, so its observe function should return a"
" DataVersionsByPartition, not a DataVersion"
)

context.log_event(
AssetObservation(
asset_key=self.key,
tags={DATA_VERSION_TAG: observe_fn_return_value.value},
)
)
elif isinstance(observe_fn_return_value, DataVersionsByPartition):
if self.partitions_def is None:
raise DagsterInvalidObservationError(
f"{self.key} is not partitioned, so its observe function should return a"
" DataVersion, not a DataVersionsByPartition"
)

for (
partition_key,
data_version,
) in observe_fn_return_value.data_versions_by_partition.items():
context.log_event(
AssetObservation(
asset_key=self.key,
tags={DATA_VERSION_TAG: data_version.value},
partition=partition_key,
)
)
else:
raise DagsterInvalidObservationError(
f"Observe function for {self.key} must return a DataVersion or"
" DataVersionsByPartition, but returned a value of type"
f" {type(observe_fn_return_value)}"
)

return DecoratedOpFunction(fn)

@property
def required_resource_keys(self) -> AbstractSet[str]:
return {requirement.key for requirement in self.get_resource_requirements()}
Expand All @@ -252,7 +266,7 @@ def node_def(self) -> Optional[OpDefinition]:

if self._node_def is None:
self._node_def = OpDefinition(
compute_fn=self._get_op_def_compute_fn(self.observe_fn),
compute_fn=wrap_source_asset_observe_fn_in_op_compute_fn(self),
name=self.key.to_python_identifier(),
description=self.description,
required_resource_keys=self._required_resource_keys,
Expand Down

0 comments on commit 11f8bf6

Please sign in to comment.