Skip to content

Commit

Permalink
move RemoteAssetGraph.from_external_repository to cached property on …
Browse files Browse the repository at this point in the history
…ExternalRepository (#20468)

`RemoteAssetGraph` is a relatively heavy object. Given a large asset
graph, if any code ends up creating copies on the order of number of
assets, it can be very expensive.

We already do a good job of memoizing `ExternalRepository` via
`CodeLocation` via `WorkspaceRequestContext` via
`WorkspaceProcessContext`, so caching these `RemoteAssetGraph` on the
`ExternalRepository` that they are a pure function of should avoid
excess computation and memory until the `WorkspaceProcessContext`
refreshes its snapshots on code server update.

## How I Tested These Changes

existing coverage
a `memray` run of a script doing per asset `RemoteAssetGraph` showed
drastic reduction in memory
  • Loading branch information
alangenfeld authored and PedramNavid committed Mar 28, 2024
1 parent a549070 commit 27c0ef2
Show file tree
Hide file tree
Showing 10 changed files with 31 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ def resolve_assetMaterializationUsedData(
return []

instance = graphene_info.context.instance
asset_graph = RemoteAssetGraph.from_external_repository(self._external_repository)
asset_graph = self._external_repository.asset_graph
asset_key = self._external_asset_node.asset_key

# in the future, we can share this same CachingInstanceQueryer across all
Expand Down Expand Up @@ -885,7 +885,7 @@ def resolve_freshnessInfo(
self, graphene_info: ResolveInfo
) -> Optional[GrapheneAssetFreshnessInfo]:
if self._external_asset_node.freshness_policy:
asset_graph = RemoteAssetGraph.from_external_repository(self._external_repository)
asset_graph = self._external_repository.asset_graph
return get_freshness_info(
asset_key=self._external_asset_node.asset_key,
# in the future, we can share this same CachingInstanceQueryer across all
Expand Down Expand Up @@ -929,7 +929,7 @@ def resolve_targetingInstigators(self, graphene_info) -> Sequence[GrapheneSensor
external_sensors = self._external_repository.get_external_sensors()
external_schedules = self._external_repository.get_external_schedules()

asset_graph = RemoteAssetGraph.from_external_repository(self._external_repository)
asset_graph = self._external_repository.asset_graph

job_names = {
job_name
Expand Down Expand Up @@ -959,7 +959,7 @@ def resolve_targetingInstigators(self, graphene_info) -> Sequence[GrapheneSensor
return results

def _get_auto_materialize_external_sensor(self) -> Optional[ExternalSensor]:
asset_graph = RemoteAssetGraph.from_external_repository(self._external_repository)
asset_graph = self._external_repository.asset_graph

asset_key = self._external_asset_node.asset_key
matching_sensors = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import graphene
from dagster._core.definitions.asset_selection import AssetSelection
from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph
from dagster._core.remote_representation.external import ExternalRepository

from ..implementation.fetch_assets import get_asset_nodes_by_asset_key
Expand All @@ -21,7 +20,7 @@ def resolve_assetSelectionString(self, _graphene_info):
return str(self._asset_selection)

def resolve_assetKeys(self, _graphene_info):
asset_graph = RemoteAssetGraph.from_external_repository(self._external_repository)
asset_graph = self._external_repository.asset_graph
return [
GrapheneAssetKey(path=asset_key.path)
for asset_key in self._asset_selection.resolve(asset_graph)
Expand All @@ -30,7 +29,7 @@ def resolve_assetKeys(self, _graphene_info):
def resolve_assets(self, graphene_info):
from dagster_graphql.schema.pipelines.pipeline import GrapheneAsset

asset_graph = RemoteAssetGraph.from_external_repository(self._external_repository)
asset_graph = self._external_repository.asset_graph
asset_nodes_by_asset_key = get_asset_nodes_by_asset_key(graphene_info)

return [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
)
from dagster._core.definitions.asset_graph_differ import AssetGraphDiffer
from dagster._core.definitions.partition import CachingDynamicPartitionsLoader
from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph
from dagster._core.definitions.sensor_definition import (
SensorType,
)
Expand Down Expand Up @@ -275,7 +274,7 @@ def __init__(
self._batch_loader = RepositoryScopedBatchLoader(instance, repository)
self._stale_status_loader = StaleStatusLoader(
instance=instance,
asset_graph=lambda: RemoteAssetGraph.from_external_repository(repository),
asset_graph=lambda: repository.asset_graph,
)
self._dynamic_partitions_loader = CachingDynamicPartitionsLoader(instance)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,7 @@ def resolve_assetNodes(

def load_asset_graph() -> RemoteAssetGraph:
if repo is not None:
return RemoteAssetGraph.from_external_repository(repo)
return repo.asset_graph
else:
return RemoteAssetGraph.from_workspace(graphene_info.context)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def test_cancel_asset_backfill(self, graphql_context):
# since launching the run will cause test process will hang forever.
code_location = graphql_context.get_code_location("test")
repository = code_location.get_repository("test_repo")
asset_graph = RemoteAssetGraph.from_external_repository(repository)
asset_graph = repository.asset_graph
_execute_asset_backfill_iteration_no_side_effects(graphql_context, backfill_id, asset_graph)

# Launch the run that runs forever
Expand Down Expand Up @@ -793,7 +793,7 @@ def test_asset_backfill_partition_stats(self, graphql_context):

code_location = graphql_context.get_code_location("test")
repository = code_location.get_repository("test_repo")
asset_graph = RemoteAssetGraph.from_external_repository(repository)
asset_graph = repository.asset_graph

_execute_asset_backfill_iteration_no_side_effects(graphql_context, backfill_id, asset_graph)

Expand Down Expand Up @@ -836,7 +836,7 @@ def test_asset_backfill_partition_stats(self, graphql_context):
def test_asset_backfill_status_with_upstream_failure(self, graphql_context):
code_location = graphql_context.get_code_location("test")
repository = code_location.get_repository("test_repo")
asset_graph = RemoteAssetGraph.from_external_repository(repository)
asset_graph = repository.asset_graph

asset_keys = [
AssetKey("unpartitioned_upstream_of_partitioned"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,8 @@ def from_external_repositories(
base_workspace, code_location_name, repository_name
)
return AssetGraphDiffer(
branch_asset_graph=lambda: RemoteAssetGraph.from_external_repository(branch_repo),
base_asset_graph=(lambda: RemoteAssetGraph.from_external_repository(base_repo))
if base_repo is not None
else None,
branch_asset_graph=lambda: branch_repo.asset_graph,
base_asset_graph=(lambda: base_repo.asset_graph) if base_repo is not None else None,
)

def _compare_base_and_branch_assets(self, asset_key: "AssetKey") -> Sequence[ChangeReason]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,18 +237,6 @@ def from_workspace(cls, context: IWorkspace) -> "RemoteAssetGraph":
external_asset_checks=asset_checks,
)

@classmethod
def from_external_repository(
cls, external_repository: ExternalRepository
) -> "RemoteAssetGraph":
return cls.from_repository_handles_and_external_asset_nodes(
repo_handle_external_asset_nodes=[
(external_repository.handle, asset_node)
for asset_node in external_repository.get_external_asset_nodes()
],
external_asset_checks=external_repository.get_external_asset_checks(),
)

@classmethod
def from_repository_handles_and_external_asset_nodes(
cls,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from functools import cached_property
from threading import RLock
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -75,6 +76,7 @@
from .represented import RepresentedJob

if TYPE_CHECKING:
from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph
from dagster._core.scheduler.instigation import InstigatorState
from dagster._core.snap.execution_plan_snapshot import ExecutionStepSnap

Expand Down Expand Up @@ -186,15 +188,13 @@ def get_default_auto_materialize_sensor_name(self):
@property
@cached_method
def _external_sensors(self) -> Dict[str, "ExternalSensor"]:
from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph

sensor_datas = {
external_sensor_data.name: ExternalSensor(external_sensor_data, self._handle)
for external_sensor_data in self.external_repository_data.external_sensor_datas
}

if self._instance.auto_materialize_use_sensors:
asset_graph = RemoteAssetGraph.from_external_repository(self)
asset_graph = self.asset_graph

has_any_auto_observe_source_assets = False

Expand Down Expand Up @@ -373,6 +373,18 @@ def get_external_asset_checks(
def get_display_metadata(self) -> Mapping[str, str]:
return self.handle.display_metadata

@cached_property
def asset_graph(self) -> "RemoteAssetGraph":
"""Returns a repository scoped RemoteAssetGraph."""
from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph

return RemoteAssetGraph.from_repository_handles_and_external_asset_nodes(
repo_handle_external_asset_nodes=[
(self.handle, asset_node) for asset_node in self.get_external_asset_nodes()
],
external_asset_checks=self.get_external_asset_checks(),
)


class ExternalJob(RepresentedJob):
"""ExternalJob is a object that represents a loaded job definition that
Expand Down
2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/_daemon/asset_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def _process_auto_materialize_tick_generator(

if sensor:
eligible_keys = check.not_none(sensor.asset_selection).resolve(
RemoteAssetGraph.from_external_repository(check.not_none(repository))
check.not_none(repository).asset_graph
)
else:
eligible_keys = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from dagster._core.definitions.auto_materialize_sensor_definition import (
AutoMaterializeSensorDefinition,
)
from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph
from dagster._core.definitions.sensor_definition import (
SensorType,
)
Expand Down Expand Up @@ -223,7 +222,7 @@ def test_combine_default_sensors_with_non_default_sensors(instance_with_auto_mat
assert external_repo.has_external_sensor("default_auto_materialize_sensor")
assert external_repo.has_external_sensor("my_custom_policy_sensor")

asset_graph = RemoteAssetGraph.from_external_repository(external_repo)
asset_graph = external_repo.asset_graph

# default sensor includes all assets that weren't covered by the custom one

Expand Down Expand Up @@ -293,7 +292,7 @@ def test_custom_sensors_cover_all(instance_with_auto_materialize_sensors):
assert external_repo.has_external_sensor("normal_sensor")
assert external_repo.has_external_sensor("my_custom_policy_sensor")

asset_graph = RemoteAssetGraph.from_external_repository(external_repo)
asset_graph = external_repo.asset_graph

# Custom sensor covered all the valid assets
custom_sensor = external_repo.get_external_sensor("my_custom_policy_sensor")
Expand Down

0 comments on commit 27c0ef2

Please sign in to comment.