From 34d9560b3a352fa67f5b936cc44e3418c8ef07ef Mon Sep 17 00:00:00 2001 From: Alex Langenfeld Date: Tue, 8 Oct 2024 17:00:00 -0500 Subject: [PATCH] [graphql] make GrapheneRepository lazy load the repository (#25122) restructure `GrapheneRepository` to init with minimal information and load objects when they are needed ## How I Tested These Changes existing coverage --- .../implementation/external.py | 10 +- .../implementation/fetch_assets.py | 9 +- .../dagster_graphql/schema/asset_graph.py | 6 +- .../dagster_graphql/schema/external.py | 162 ++++++++++-------- .../schema/pipelines/pipeline.py | 8 +- .../_core/remote_representation/external.py | 32 +--- .../_core/remote_representation/handle.py | 37 ++++ 7 files changed, 129 insertions(+), 135 deletions(-) diff --git a/python_modules/dagster-graphql/dagster_graphql/implementation/external.py b/python_modules/dagster-graphql/dagster_graphql/implementation/external.py index 7dc7ee7d1a9cc..28a21334c299e 100644 --- a/python_modules/dagster-graphql/dagster_graphql/implementation/external.py +++ b/python_modules/dagster-graphql/dagster_graphql/implementation/external.py @@ -114,11 +114,7 @@ def fetch_repositories(graphene_info: "ResolveInfo") -> "GrapheneRepositoryConne return GrapheneRepositoryConnection( nodes=[ - GrapheneRepository( - workspace_context=graphene_info.context, - repository=repository, - repository_location=location, - ) + GrapheneRepository(repository.handle) for location in graphene_info.context.code_locations for repository in location.get_repositories().values() ] @@ -137,9 +133,7 @@ def fetch_repository( repo_loc = graphene_info.context.get_code_location(repository_selector.location_name) if repo_loc.has_repository(repository_selector.repository_name): return GrapheneRepository( - workspace_context=graphene_info.context, - repository=repo_loc.get_repository(repository_selector.repository_name), - repository_location=repo_loc, + repo_loc.get_repository(repository_selector.repository_name).handle, ) raise UserFacingGraphQLError( diff --git a/python_modules/dagster-graphql/dagster_graphql/implementation/fetch_assets.py b/python_modules/dagster-graphql/dagster_graphql/implementation/fetch_assets.py index 10585aecbe34a..8f9731d351112 100644 --- a/python_modules/dagster-graphql/dagster_graphql/implementation/fetch_assets.py +++ b/python_modules/dagster-graphql/dagster_graphql/implementation/fetch_assets.py @@ -166,14 +166,7 @@ def get_asset_node_definition_collisions( if not is_defined: continue - code_location = graphene_info.context.get_code_location(repo_handle.location_name) - repos[asset_node_snap.asset_key].append( - GrapheneRepository( - workspace_context=graphene_info.context, - repository=code_location.get_repository(repo_handle.repository_name), - repository_location=code_location, - ) - ) + repos[asset_node_snap.asset_key].append(GrapheneRepository(repo_handle)) results: List[GrapheneAssetNodeDefinitionCollision] = [] for asset_key in repos.keys(): diff --git a/python_modules/dagster-graphql/dagster_graphql/schema/asset_graph.py b/python_modules/dagster-graphql/dagster_graphql/schema/asset_graph.py index 7c872b126933e..3c227be2d463e 100644 --- a/python_modules/dagster-graphql/dagster_graphql/schema/asset_graph.py +++ b/python_modules/dagster-graphql/dagster_graphql/schema/asset_graph.py @@ -1280,11 +1280,7 @@ def resolve_partitionDefinition( return None def resolve_repository(self, graphene_info: ResolveInfo) -> "GrapheneRepository": - return external.GrapheneRepository( - graphene_info.context, - 
graphene_info.context.get_repository(self._repository_selector), - graphene_info.context.get_code_location(self._repository_selector.location_name), - ) + return external.GrapheneRepository(self._repository_handle) def resolve_required_resources( self, graphene_info: ResolveInfo diff --git a/python_modules/dagster-graphql/dagster_graphql/schema/external.py b/python_modules/dagster-graphql/dagster_graphql/schema/external.py index a81f3d3d3352d..e5b3d69552723 100644 --- a/python_modules/dagster-graphql/dagster_graphql/schema/external.py +++ b/python_modules/dagster-graphql/dagster_graphql/schema/external.py @@ -2,26 +2,24 @@ from typing import TYPE_CHECKING, Dict, List, Optional import graphene -from dagster import ( - DagsterInstance, - _check as check, -) +from dagster import _check as check from dagster._core.definitions.asset_graph_differ import AssetGraphDiffer from dagster._core.definitions.partition import CachingDynamicPartitionsLoader from dagster._core.definitions.sensor_definition import SensorType from dagster._core.remote_representation import ( CodeLocation, - ExternalRepository, GrpcServerCodeLocation, ManagedGrpcPythonEnvCodeLocationOrigin, ) +from dagster._core.remote_representation.external import ExternalRepository from dagster._core.remote_representation.feature_flags import get_feature_flags_for_location from dagster._core.remote_representation.grpc_server_state_subscriber import ( LocationStateChangeEvent, LocationStateChangeEventType, LocationStateSubscriber, ) -from dagster._core.workspace.context import BaseWorkspaceRequestContext, WorkspaceProcessContext +from dagster._core.remote_representation.handle import RepositoryHandle +from dagster._core.workspace.context import WorkspaceProcessContext from dagster._core.workspace.workspace import CodeLocationEntry, CodeLocationLoadStatus from dagster_graphql.implementation.asset_checks_loader import AssetChecksLoader @@ -110,7 +108,7 @@ def resolve_id(self, _) -> str: def resolve_repositories(self, graphene_info: ResolveInfo): return [ - GrapheneRepository(graphene_info.context, repository, self._location) + GrapheneRepository(repository.handle) for repository in self._location.get_repositories().values() ] @@ -273,110 +271,101 @@ class Meta: def __init__( self, - workspace_context: BaseWorkspaceRequestContext, - repository: ExternalRepository, - repository_location: CodeLocation, + handle: RepositoryHandle, ): # Warning! GrapheneAssetNode contains a GrapheneRepository. Any computation in this # __init__ will be done **once per asset**. Ensure that any expensive work is done # elsewhere or cached. 
- instance = workspace_context.instance - self._repository = check.inst_param(repository, "repository", ExternalRepository) - self._repository_location = check.inst_param( - repository_location, "repository_location", CodeLocation - ) - check.inst_param(instance, "instance", DagsterInstance) - self._batch_loader = RepositoryScopedBatchLoader(instance, repository) - self._stale_status_loader = StaleStatusLoader( - instance=instance, - asset_graph=lambda: repository.asset_graph, - loading_context=workspace_context, - ) - self._dynamic_partitions_loader = CachingDynamicPartitionsLoader(instance) + self._handle = handle - self._asset_graph_differ = None - # get_base_deployment_context is cached so there will only be one context per query - base_deployment_context = workspace_context.get_base_deployment_context() - if base_deployment_context is not None: - # then we are in a branch deployment - self._asset_graph_differ = AssetGraphDiffer.from_external_repositories( - code_location_name=self._repository_location.name, - repository_name=self._repository.name, - branch_workspace=workspace_context, - base_workspace=base_deployment_context, + self._batch_loader = None + + super().__init__(name=handle.repository_name) + + def get_repository(self, graphene_info: ResolveInfo) -> ExternalRepository: + return graphene_info.context.get_repository(self._handle.to_selector()) + + def get_batch_loader(self, graphene_info: ResolveInfo): + if self._batch_loader is None: + self._batch_loader = RepositoryScopedBatchLoader( + graphene_info.context.instance, self.get_repository(graphene_info) ) - super().__init__(name=repository.name) + return self._batch_loader def resolve_id(self, _graphene_info: ResolveInfo) -> str: - return self._repository.get_compound_id().to_string() + return self._handle.get_compound_id().to_string() def resolve_origin(self, _graphene_info: ResolveInfo): - origin = self._repository.get_remote_origin() + origin = self._handle.get_remote_origin() return GrapheneRepositoryOrigin(origin) - def resolve_location(self, _graphene_info: ResolveInfo): - return GrapheneRepositoryLocation(self._repository_location) + def resolve_location(self, graphene_info: ResolveInfo): + return GrapheneRepositoryLocation( + graphene_info.context.get_code_location(self._handle.location_name) + ) - def resolve_schedules(self, _graphene_info: ResolveInfo): + def resolve_schedules(self, graphene_info: ResolveInfo): + batch_loader = self.get_batch_loader(graphene_info) + repository = self.get_repository(graphene_info) return sorted( [ GrapheneSchedule( schedule, - self._repository, - self._batch_loader.get_schedule_state(schedule.name), - self._batch_loader, + repository, + batch_loader.get_schedule_state(schedule.name), + batch_loader, ) - for schedule in self._repository.get_external_schedules() + for schedule in repository.get_external_schedules() ], key=lambda schedule: schedule.name, ) - def resolve_sensors(self, _graphene_info: ResolveInfo, sensorType: Optional[SensorType] = None): + def resolve_sensors(self, graphene_info: ResolveInfo, sensorType: Optional[SensorType] = None): + batch_loader = self.get_batch_loader(graphene_info) + repository = self.get_repository(graphene_info) return [ GrapheneSensor( sensor, - self._repository, - self._batch_loader.get_sensor_state(sensor.name), - self._batch_loader, - ) - for sensor in sorted( - self._repository.get_external_sensors(), key=lambda sensor: sensor.name + repository, + batch_loader.get_sensor_state(sensor.name), + batch_loader, ) + for sensor in 
sorted(repository.get_external_sensors(), key=lambda sensor: sensor.name) if not sensorType or sensor.sensor_type == sensorType ] - def resolve_pipelines(self, _graphene_info: ResolveInfo): + def resolve_pipelines(self, graphene_info: ResolveInfo): return [ GraphenePipeline(pipeline) for pipeline in sorted( - self._repository.get_all_external_jobs(), + self.get_repository(graphene_info).get_all_external_jobs(), key=lambda pipeline: pipeline.name, ) ] - def resolve_jobs(self, _graphene_info: ResolveInfo): + def resolve_jobs(self, graphene_info: ResolveInfo): return [ GrapheneJob(pipeline) for pipeline in sorted( - self._repository.get_all_external_jobs(), + self.get_repository(graphene_info).get_all_external_jobs(), key=lambda pipeline: pipeline.name, ) ] - def resolve_usedSolid(self, _graphene_info: ResolveInfo, name): - return get_solid(self._repository, name) + def resolve_usedSolid(self, graphene_info: ResolveInfo, name): + return get_solid(self.get_repository(graphene_info), name) - def resolve_usedSolids(self, _graphene_info: ResolveInfo): - return get_solids(self._repository) + def resolve_usedSolids(self, graphene_info: ResolveInfo): + return get_solids(self.get_repository(graphene_info)) - def resolve_partitionSets(self, _graphene_info: ResolveInfo): + def resolve_partitionSets(self, graphene_info: ResolveInfo): return ( - GraphenePartitionSet(self._repository.handle, partition_set) - for partition_set in self._repository.get_external_partition_sets() + GraphenePartitionSet(self._handle, partition_set) + for partition_set in self.get_repository(graphene_info).get_external_partition_sets() ) - def resolve_displayMetadata(self, _graphene_info: ResolveInfo): - metadata = self._repository.get_display_metadata() + def resolve_displayMetadata(self, graphene_info: ResolveInfo): + metadata = self._handle.display_metadata return [ GrapheneRepositoryMetadata(key=key, value=value) for key, value in metadata.items() @@ -384,26 +373,47 @@ def resolve_displayMetadata(self, _graphene_info: ResolveInfo): ] def resolve_assetNodes(self, graphene_info: ResolveInfo): - asset_node_snaps = self._repository.get_asset_node_snaps() + asset_node_snaps = self.get_repository(graphene_info).get_asset_node_snaps() asset_checks_loader = AssetChecksLoader( context=graphene_info.context, asset_keys=[node.asset_key for node in asset_node_snaps], ) + + asset_graph_differ = None + base_deployment_context = graphene_info.context.get_base_deployment_context() + if base_deployment_context is not None: + # then we are in a branch deployment + asset_graph_differ = AssetGraphDiffer.from_external_repositories( + code_location_name=self._handle.location_name, + repository_name=self._handle.repository_name, + branch_workspace=graphene_info.context, + base_workspace=base_deployment_context, + ) + + dynamic_partitions_loader = CachingDynamicPartitionsLoader( + graphene_info.context.instance, + ) + stale_status_loader = StaleStatusLoader( + instance=graphene_info.context.instance, + asset_graph=lambda: self.get_repository(graphene_info).asset_graph, + loading_context=graphene_info.context, + ) + return [ GrapheneAssetNode( - repository_handle=self._repository.handle, + repository_handle=self._handle, asset_node_snap=asset_node_snap, asset_checks_loader=asset_checks_loader, - stale_status_loader=self._stale_status_loader, - dynamic_partitions_loader=self._dynamic_partitions_loader, - asset_graph_differ=self._asset_graph_differ, + stale_status_loader=stale_status_loader, + dynamic_partitions_loader=dynamic_partitions_loader, + 
asset_graph_differ=asset_graph_differ, ) - for asset_node_snap in self._repository.get_asset_node_snaps() + for asset_node_snap in self.get_repository(graphene_info).get_asset_node_snaps() ] - def resolve_assetGroups(self, _graphene_info: ResolveInfo): + def resolve_assetGroups(self, graphene_info: ResolveInfo): groups: Dict[str, List[AssetNodeSnap]] = {} - for asset_node_snap in self._repository.get_asset_node_snaps(): + for asset_node_snap in self.get_repository(graphene_info).get_asset_node_snaps(): if not asset_node_snap.group_name: continue external_assets = groups.setdefault(asset_node_snap.group_name, []) @@ -411,22 +421,22 @@ def resolve_assetGroups(self, _graphene_info: ResolveInfo): return [ GrapheneAssetGroup( - f"{self._repository_location.name}-{self._repository.name}-{group_name}", + f"{self._handle.location_name}-{self._handle.repository_name}-{group_name}", group_name, [external_node.asset_key for external_node in external_nodes], ) for group_name, external_nodes in groups.items() ] - def resolve_allTopLevelResourceDetails(self, _graphene_info) -> List[GrapheneResourceDetails]: + def resolve_allTopLevelResourceDetails(self, graphene_info) -> List[GrapheneResourceDetails]: return [ GrapheneResourceDetails( - location_name=self._repository_location.name, - repository_name=self._repository.name, + location_name=self._handle.location_name, + repository_name=self._handle.repository_name, external_resource=resource, ) for resource in sorted( - self._repository.get_external_resources(), + self.get_repository(graphene_info).get_external_resources(), key=lambda resource: resource.name, ) if resource.is_top_level diff --git a/python_modules/dagster-graphql/dagster_graphql/schema/pipelines/pipeline.py b/python_modules/dagster-graphql/dagster_graphql/schema/pipelines/pipeline.py index 342b9a120652e..cb76e4ad9c02d 100644 --- a/python_modules/dagster-graphql/dagster_graphql/schema/pipelines/pipeline.py +++ b/python_modules/dagster-graphql/dagster_graphql/schema/pipelines/pipeline.py @@ -962,13 +962,7 @@ def resolve_isAssetJob(self, graphene_info: ResolveInfo): def resolve_repository(self, graphene_info: ResolveInfo): from dagster_graphql.schema.external import GrapheneRepository - handle = self._external_job.repository_handle - location = graphene_info.context.get_code_location(handle.location_name) - return GrapheneRepository( - graphene_info.context, - location.get_repository(handle.repository_name), - location, - ) + return GrapheneRepository(self._external_job.repository_handle) @capture_error def resolve_partitionKeysOrError( diff --git a/python_modules/dagster/dagster/_core/remote_representation/external.py b/python_modules/dagster/dagster/_core/remote_representation/external.py index 4fd52d20adc3c..0f95e616c7a7e 100644 --- a/python_modules/dagster/dagster/_core/remote_representation/external.py +++ b/python_modules/dagster/dagster/_core/remote_representation/external.py @@ -37,7 +37,6 @@ DefaultSensorStatus, SensorType, ) -from dagster._core.errors import DagsterInvariantViolationError from dagster._core.execution.plan.handle import ResolvedFromDynamicStepHandle, StepHandle from dagster._core.instance import DagsterInstance from dagster._core.origin import JobPythonOrigin, RepositoryPythonOrigin @@ -61,6 +60,7 @@ TargetSnap, ) from dagster._core.remote_representation.handle import ( + CompoundID, InstigatorHandle, JobHandle, PartitionSetHandle, @@ -77,7 +77,6 @@ from dagster._core.snap import ExecutionPlanSnapshot from dagster._core.snap.job_snapshot import JobSnapshot from 
dagster._core.utils import toposort -from dagster._record import record from dagster._serdes import create_snapshot_id from dagster._utils.cached_method import cached_method from dagster._utils.schedules import schedule_execution_time_iterator @@ -88,35 +87,6 @@ from dagster._core.scheduler.instigation import InstigatorState from dagster._core.snap.execution_plan_snapshot import ExecutionStepSnap -_DELIMITER = "::" - - -@record -class CompoundID: - """Compound ID object for the two id schemes that state is recorded in the database against.""" - - remote_origin_id: str - selector_id: str - - def to_string(self) -> str: - return f"{self.remote_origin_id}{_DELIMITER}{self.selector_id}" - - @staticmethod - def from_string(serialized: str): - parts = serialized.split(_DELIMITER) - if len(parts) != 2: - raise DagsterInvariantViolationError(f"Invalid serialized InstigatorID: {serialized}") - - return CompoundID( - remote_origin_id=parts[0], - selector_id=parts[1], - ) - - @staticmethod - def is_valid_string(serialized: str): - parts = serialized.split(_DELIMITER) - return len(parts) == 2 - class ExternalRepository: """ExternalRepository is a object that represents a loaded repository definition that diff --git a/python_modules/dagster/dagster/_core/remote_representation/handle.py b/python_modules/dagster/dagster/_core/remote_representation/handle.py index c574fc76cea4e..ece5f1a629234 100644 --- a/python_modules/dagster/dagster/_core/remote_representation/handle.py +++ b/python_modules/dagster/dagster/_core/remote_representation/handle.py @@ -4,6 +4,7 @@ import dagster._check as check from dagster._core.code_pointer import ModuleCodePointer from dagster._core.definitions.selector import JobSubsetSelector, RepositorySelector +from dagster._core.errors import DagsterInvariantViolationError from dagster._core.origin import RepositoryPythonOrigin from dagster._core.remote_representation.origin import ( CodeLocationOrigin, @@ -54,6 +55,12 @@ def to_selector(self) -> RepositorySelector: repository_name=self.repository_name, ) + def get_compound_id(self) -> "CompoundID": + return CompoundID( + remote_origin_id=self.get_remote_origin().get_id(), + selector_id=self.to_selector().selector_id, + ) + @staticmethod def for_test( *, @@ -158,3 +165,33 @@ def get_remote_origin(self): return self.repository_handle.get_remote_origin().get_partition_set_origin( self.partition_set_name ) + + +_DELIMITER = "::" + + +@record +class CompoundID: + """Compound ID object for the two id schemes that state is recorded in the database against.""" + + remote_origin_id: str + selector_id: str + + def to_string(self) -> str: + return f"{self.remote_origin_id}{_DELIMITER}{self.selector_id}" + + @staticmethod + def from_string(serialized: str): + parts = serialized.split(_DELIMITER) + if len(parts) != 2: + raise DagsterInvariantViolationError(f"Invalid serialized InstigatorID: {serialized}") + + return CompoundID( + remote_origin_id=parts[0], + selector_id=parts[1], + ) + + @staticmethod + def is_valid_string(serialized: str): + parts = serialized.split(_DELIMITER) + return len(parts) == 2
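

For readers who want the shape of this change without the GraphQL scaffolding: `GrapheneRepository.__init__` now stores only a `RepositoryHandle` and defers every expensive lookup (the `ExternalRepository` itself, the `RepositoryScopedBatchLoader`, the `AssetGraphDiffer`, and the partition/staleness loaders) to resolve time. The sketch below is a minimal approximation of that pattern; `FakeRequestContext`, `LazyRepository`, and the dict standing in for `ExternalRepository` are illustrative names, not Dagster APIs.

```python
from typing import Optional


class RepositoryHandle:
    """Illustrative stand-in for Dagster's RepositoryHandle: just enough
    identity to locate the repository later, and cheap to hold."""

    def __init__(self, repository_name: str, location_name: str) -> None:
        self.repository_name = repository_name
        self.location_name = location_name


class FakeRequestContext:
    """Illustrative stand-in for the per-request workspace context, which
    owns (and caches) the heavyweight repository objects."""

    def get_repository(self, handle: RepositoryHandle) -> dict:
        # The real context returns an ExternalRepository; a dict stands in.
        return {"name": handle.repository_name, "location": handle.location_name}


class LazyRepository:
    """Mirrors the shape of the patched GrapheneRepository: __init__ keeps
    only the handle, the repository is fetched at resolve time, and the
    batch loader is built and memoized on first use."""

    def __init__(self, handle: RepositoryHandle) -> None:
        self._handle = handle
        self._batch_loader: Optional[list] = None

    def get_repository(self, context: FakeRequestContext) -> dict:
        # Fetched per call; caching the underlying object is the context's job.
        return context.get_repository(self._handle)

    def get_batch_loader(self, context: FakeRequestContext) -> list:
        if self._batch_loader is None:
            self._batch_loader = [self.get_repository(context)]
        return self._batch_loader


context = FakeRequestContext()
repo = LazyRepository(RepositoryHandle("my_repo", "my_location"))
assert repo._batch_loader is None  # construction did no loading
loader = repo.get_batch_loader(context)  # first access builds the loader
assert repo.get_batch_loader(context) is loader  # memoized afterwards
```

The warning comment retained in `__init__` is the motivation: a `GrapheneRepository` can be constructed once per asset, so construction has to stay cheap; per-request caching is left to the workspace context, with only lightweight members like the batch loader memoized on the object itself.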
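The patch also moves `CompoundID` from `external.py` into `handle.py`, which lets the new `RepositoryHandle.get_compound_id` build an id from the handle alone, without loading an `ExternalRepository`. A self-contained approximation of its string round-trip, with stdlib stand-ins for Dagster's `@record` decorator and `DagsterInvariantViolationError`:

```python
from dataclasses import dataclass

_DELIMITER = "::"


@dataclass(frozen=True)
class CompoundID:
    """Simplified copy of the moved class: a remote-origin id and a
    selector id joined by a delimiter, with a string round-trip."""

    remote_origin_id: str
    selector_id: str

    def to_string(self) -> str:
        return f"{self.remote_origin_id}{_DELIMITER}{self.selector_id}"

    @staticmethod
    def from_string(serialized: str) -> "CompoundID":
        parts = serialized.split(_DELIMITER)
        if len(parts) != 2:
            raise ValueError(f"Invalid serialized CompoundID: {serialized}")
        return CompoundID(remote_origin_id=parts[0], selector_id=parts[1])


cid = CompoundID(remote_origin_id="origin", selector_id="selector")
assert CompoundID.from_string(cid.to_string()) == cid  # round-trips cleanly
```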