[global asset graph] precompute targeting instigators (#25202)
To allow all required information for an asset node to be resolved from the global-scope remote asset node alone, track which sensors and schedules target each repository-scoped asset.
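
A minimal sketch of the idea, using illustrative names (`InstigatorHandle`, `build_targeting_index`) rather than Dagster's actual API: while building the workspace-level asset graph, walk each repository's sensors and schedules once and record, per asset key, handles to the instigators that target it, so a global asset node can list its targeting instigators without re-resolving selections on every request.

```python
# Hypothetical sketch, not the code in this commit: precompute, per asset key,
# which sensors and schedules target it, so lookups are cheap at resolve time.
from collections import defaultdict
from typing import AbstractSet, Dict, List, Mapping, NamedTuple


class InstigatorHandle(NamedTuple):
    repository_name: str
    name: str
    kind: str  # "sensor" or "schedule"


def build_targeting_index(
    asset_job_names: Mapping[str, AbstractSet[str]],  # asset key -> its job names
    selections: Mapping[InstigatorHandle, AbstractSet[str]],  # handle -> selected asset keys
    targeted_jobs: Mapping[InstigatorHandle, AbstractSet[str]],  # handle -> targeted job names
) -> Dict[str, List[InstigatorHandle]]:
    index: Dict[str, List[InstigatorHandle]] = defaultdict(list)
    for asset_key, job_names in asset_job_names.items():
        for handle in selections:
            # An instigator targets an asset if the asset is in its asset
            # selection, or if it targets one of the asset's jobs.
            in_selection = asset_key in selections[handle]
            via_job = bool(job_names & targeted_jobs.get(handle, frozenset()))
            if in_selection or via_job:
                index[asset_key].append(handle)
    return index
```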

## How I Tested These Changes

Existing coverage, plus additional coverage in dagster-io/internal#11840.
alangenfeld authored Oct 15, 2024
1 parent 566d6cd commit d7c8769
Showing 11 changed files with 233 additions and 94 deletions.
--- changed file ---
@@ -144,7 +144,10 @@ def get_schedules_or_error(

results = [
GrapheneSchedule(
schedule, repository, schedule_states_by_name.get(schedule.name), batch_loader
schedule,
repository.handle,
schedule_states_by_name.get(schedule.name),
batch_loader,
)
for schedule in filtered
]
@@ -172,7 +175,8 @@ def get_schedules_for_pipeline(
schedule.get_remote_origin_id(),
schedule.selector_id,
)
results.append(GrapheneSchedule(schedule, repository, schedule_state))

results.append(GrapheneSchedule(schedule, repository.handle, schedule_state))

return results

@@ -197,7 +201,7 @@ def get_schedule_or_error(
schedule_state = graphene_info.context.instance.get_instigator_state(
schedule.get_remote_origin_id(), schedule.selector_id
)
return GrapheneSchedule(schedule, repository, schedule_state)
return GrapheneSchedule(schedule, repository.handle, schedule_state)


def get_schedule_next_tick(
--- changed file ---
@@ -61,7 +61,12 @@ def get_sensors_or_error(

return GrapheneSensors(
results=[
GrapheneSensor(sensor, repository, sensor_states_by_name.get(sensor.name), batch_loader)
GrapheneSensor(
sensor,
repository.handle,
sensor_states_by_name.get(sensor.name),
batch_loader,
)
for sensor in filtered
]
)
@@ -83,7 +88,7 @@ def get_sensor_or_error(graphene_info: ResolveInfo, selector: SensorSelector) ->
sensor.selector_id,
)

return GrapheneSensor(sensor, repository, sensor_state)
return GrapheneSensor(sensor, repository.handle, sensor_state)


def start_sensor(graphene_info: ResolveInfo, sensor_selector: SensorSelector) -> "GrapheneSensor":
@@ -98,7 +103,7 @@ def start_sensor(graphene_info: ResolveInfo, sensor_selector: SensorSelector) ->
raise UserFacingGraphQLError(GrapheneSensorNotFoundError(sensor_selector.sensor_name))
sensor = repository.get_sensor(sensor_selector.sensor_name)
sensor_state = graphene_info.context.instance.start_sensor(sensor)
return GrapheneSensor(sensor, repository, sensor_state)
return GrapheneSensor(sensor, repository.handle, sensor_state)


def stop_sensor(
@@ -151,7 +156,7 @@ def reset_sensor(graphene_info: ResolveInfo, sensor_selector: SensorSelector) ->
sensor = repository.get_sensor(sensor_selector.sensor_name)
sensor_state = graphene_info.context.instance.reset_sensor(sensor)

return GrapheneSensor(sensor, repository, sensor_state)
return GrapheneSensor(sensor, repository.handle, sensor_state)


def get_sensors_for_pipeline(
@@ -174,7 +179,7 @@ def get_sensors_for_pipeline(
sensor.get_remote_origin_id(),
sensor.selector_id,
)
results.append(GrapheneSensor(sensor, repository, sensor_state))
results.append(GrapheneSensor(sensor, repository.handle, sensor_state))

return results

@@ -259,4 +264,4 @@ def set_sensor_cursor(
else:
instance.update_instigator_state(updated_state)

return GrapheneSensor(sensor, repository, updated_state)
return GrapheneSensor(sensor, repository.handle, updated_state)
--- changed file ---
@@ -1,13 +1,11 @@
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Set, Union, cast
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union, cast

import graphene
from dagster import (
AssetKey,
DagsterError,
_check as check,
)
from dagster._core.definitions.asset_graph_differ import AssetDefinitionChangeType, AssetGraphDiffer
from dagster._core.definitions.asset_job import IMPLICIT_ASSET_JOB_NAME
from dagster._core.definitions.data_time import CachingDataTimeResolver
from dagster._core.definitions.data_version import (
NULL_DATA_VERSION,
@@ -16,7 +14,7 @@
)
from dagster._core.definitions.partition import CachingDynamicPartitionsLoader, PartitionsDefinition
from dagster._core.definitions.partition_mapping import PartitionMapping
from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph, RemoteAssetNode
from dagster._core.definitions.remote_asset_graph import RemoteAssetNode, RemoteWorkspaceAssetNode
from dagster._core.definitions.selector import JobSelector
from dagster._core.definitions.sensor_definition import SensorType
from dagster._core.errors import DagsterInvariantViolationError
@@ -847,53 +845,49 @@ def resolve_automationCondition(
return GrapheneAutomationCondition(self._asset_node_snap.automation_condition)
return None

def _sensor_targets_asset(
self, sensor: RemoteSensor, asset_graph: RemoteAssetGraph, job_names: Set[str]
) -> bool:
asset_key = self._asset_node_snap.asset_key

if sensor.asset_selection is not None:
try:
asset_selection = sensor.asset_selection.resolve(asset_graph)
except DagsterError:
return False

if asset_key in asset_selection:
return True

return any(target.job_name in job_names for target in sensor.get_targets())

def resolve_targetingInstigators(self, graphene_info: ResolveInfo) -> Sequence[GrapheneSensor]:
repo = graphene_info.context.get_repository(self._repository_selector)
sensors = repo.get_sensors()
schedules = repo.get_schedules()

asset_graph = repo.asset_graph

job_names = {
job_name
for job_name in self._asset_node_snap.job_names
if not job_name == IMPLICIT_ASSET_JOB_NAME
}
if isinstance(self._remote_node, RemoteWorkspaceAssetNode):
# global nodes have saved references to their targeting instigators
schedules = [
graphene_info.context.get_schedule(schedule_handle)
for schedule_handle in self._remote_node.get_targeting_schedule_handles()
]
sensors = [
graphene_info.context.get_sensor(sensor_handle)
for sensor_handle in self._remote_node.get_targeting_sensor_handles()
]
else:
# fallback to using the repository
repo = graphene_info.context.get_repository(self._repository_selector)
schedules = repo.get_schedules_targeting(self._asset_node_snap.asset_key)
sensors = repo.get_sensors_targeting(self._asset_node_snap.asset_key)

results = []
for sensor in sensors:
if not self._sensor_targets_asset(sensor, asset_graph, job_names):
continue

sensor_state = graphene_info.context.instance.get_instigator_state(
sensor.get_remote_origin_id(),
sensor.selector_id,
)
results.append(GrapheneSensor(sensor, repo, sensor_state))
results.append(
GrapheneSensor(
sensor,
sensor.handle.repository_handle,
sensor_state,
)
)

for schedule in schedules:
if schedule.job_name in job_names:
schedule_state = graphene_info.context.instance.get_instigator_state(
schedule.get_remote_origin_id(),
schedule.selector_id,
schedule_state = graphene_info.context.instance.get_instigator_state(
schedule.get_remote_origin_id(),
schedule.selector_id,
)
results.append(
GrapheneSchedule(
schedule,
schedule.handle.repository_handle,
schedule_state,
)
results.append(GrapheneSchedule(schedule, repo, schedule_state))
)

return results
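
For context, the fallback branch above calls repository-level helpers (`get_schedules_targeting`, `get_sensors_targeting`) whose bodies are not shown in this diff. A hedged sketch of what such a sensor helper might look like, written as a standalone function and mirroring the matching logic that the removed `_sensor_targets_asset` helper ran inline (the `asset_job_names` parameter and the exact signature are assumptions):

```python
# Hedged sketch, not the actual helper added in this commit: a sensor targets an
# asset if the asset is in its resolved asset selection, or if it targets one of
# the asset's explicit (non-implicit) job names.
from dagster import AssetKey, DagsterError


def sensors_targeting_asset(repository, asset_key: AssetKey, asset_job_names: set):
    targeting = []
    for sensor in repository.get_sensors():
        if sensor.asset_selection is not None:
            try:
                selected = sensor.asset_selection.resolve(repository.asset_graph)
            except DagsterError:
                continue  # unresolvable selection: treat as not targeting
            if asset_key in selected:
                targeting.append(sensor)
                continue
        # fall back to job-name matching, as the removed helper did
        if any(target.job_name in asset_job_names for target in sensor.get_targets()):
            targeting.append(sensor)
    return targeting
```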

--- changed file ---
@@ -1,15 +1,14 @@
from functools import cached_property
from typing import TYPE_CHECKING, Sequence

import graphene
from dagster._core.definitions.asset_key import AssetKey
from dagster._core.definitions.asset_selection import AssetSelection
from dagster._core.remote_representation.external import RemoteRepository
from dagster._core.remote_representation.handle import RepositoryHandle

from dagster_graphql.implementation.fetch_assets import get_asset_nodes_by_asset_key
from dagster_graphql.implementation.fetch_assets import get_asset
from dagster_graphql.implementation.utils import capture_error
from dagster_graphql.schema.asset_key import GrapheneAssetKey
from dagster_graphql.schema.util import non_null_list
from dagster_graphql.schema.util import ResolveInfo, non_null_list

if TYPE_CHECKING:
from dagster_graphql.schema.roots.assets import GrapheneAssetConnection
@@ -21,34 +20,40 @@ class GrapheneAssetSelection(graphene.ObjectType):
assets = non_null_list("dagster_graphql.schema.pipelines.pipeline.GrapheneAsset")
assetsOrError = graphene.NonNull("dagster_graphql.schema.roots.assets.GrapheneAssetsOrError")

def __init__(self, asset_selection: AssetSelection, remote_repository: RemoteRepository):
def __init__(
self,
asset_selection: AssetSelection,
repository_handle: RepositoryHandle,
):
self._asset_selection = asset_selection
self._remote_repository = remote_repository
self._repository_handle = repository_handle
self._resolved_keys = None

def resolve_assetSelectionString(self, _graphene_info):
def resolve_assetSelectionString(self, _graphene_info) -> str:
return str(self._asset_selection)

def resolve_assetKeys(self, _graphene_info):
def resolve_assetKeys(self, graphene_info: ResolveInfo):
return [
GrapheneAssetKey(path=asset_key.path) for asset_key in self._resolved_and_sorted_keys
GrapheneAssetKey(path=asset_key.path)
for asset_key in self._get_resolved_and_sorted_keys(graphene_info)
]

def _get_assets(self, graphene_info):
from dagster_graphql.schema.pipelines.pipeline import GrapheneAsset

asset_nodes_by_asset_key = get_asset_nodes_by_asset_key(graphene_info)
def _get_assets(self, graphene_info: ResolveInfo):
return [
GrapheneAsset(key=asset_key, definition=asset_nodes_by_asset_key.get(asset_key))
for asset_key in self._resolved_and_sorted_keys
get_asset(graphene_info, asset_key)
for asset_key in self._get_resolved_and_sorted_keys(graphene_info)
]

def resolve_assets(self, graphene_info):
def resolve_assets(self, graphene_info: ResolveInfo):
return self._get_assets(graphene_info)

@cached_property
def _resolved_and_sorted_keys(self) -> Sequence[AssetKey]:
def _get_resolved_and_sorted_keys(self, graphene_info: ResolveInfo) -> Sequence[AssetKey]:
"""Use this to maintain stability in ordering."""
return sorted(self._asset_selection.resolve(self._remote_repository.asset_graph), key=str)
if self._resolved_keys is None:
repo = graphene_info.context.get_repository(self._repository_handle)
self._resolved_keys = sorted(self._asset_selection.resolve(repo.asset_graph), key=str)

return self._resolved_keys

@capture_error
def resolve_assetsOrError(self, graphene_info) -> "GrapheneAssetConnection":
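
The broader pattern in this file: `GrapheneAssetSelection` now stores a lightweight `RepositoryHandle` instead of a `RemoteRepository`, and resolves (then memoizes) the selected asset keys only once a request-scoped `graphene_info` is available. An illustrative, non-Dagster sketch of that pattern:

```python
# Illustrative sketch (not Dagster's classes) of "hold a cheap handle, resolve
# lazily, memoize": the heavy resolution runs once, on first access, using a
# context supplied by the caller at request time.
from typing import Callable, Dict, List, Optional


class LazySelection:
    def __init__(self, handle: str, select: Callable[[object], List[str]]):
        self._handle = handle                  # lightweight reference
        self._select = select                  # runs against the resolved repository
        self._keys: Optional[List[str]] = None

    def keys(self, context: Dict[str, object]) -> List[str]:
        if self._keys is None:                 # resolve once; keep a stable ordering
            repository = context[self._handle]
            self._keys = sorted(self._select(repository))
        return self._keys


# usage: the context (here just a dict) is only needed at resolve time
selection = LazySelection("repo_a", lambda repo: list(repo))
print(selection.keys({"repo_a": {"asset_b", "asset_a"}}))  # ['asset_a', 'asset_b']
```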
--- changed file ---
@@ -311,7 +311,7 @@ def resolve_schedules(self, graphene_info: ResolveInfo):
[
GrapheneSchedule(
schedule,
repository,
repository.handle,
batch_loader.get_schedule_state(schedule.name),
batch_loader,
)
@@ -326,7 +326,7 @@ def resolve_sensors(self, graphene_info: ResolveInfo, sensorType: Optional[Senso
return [
GrapheneSensor(
sensor,
repository,
repository.handle,
batch_loader.get_sensor_state(sensor.name),
batch_loader,
)
--- changed file ---
@@ -5,7 +5,7 @@
import graphene
from dagster import DefaultScheduleStatus
from dagster._core.remote_representation import RemoteSchedule
from dagster._core.remote_representation.external import RemoteRepository
from dagster._core.remote_representation.handle import RepositoryHandle
from dagster._core.scheduler.instigation import InstigatorState, InstigatorStatus
from dagster._time import get_current_timestamp

@@ -66,12 +66,11 @@ class Meta:
def __init__(
self,
remote_schedule: RemoteSchedule,
remote_repository: RemoteRepository,
repository_handle: RepositoryHandle,
schedule_state: Optional[InstigatorState],
batch_loader: Optional[RepositoryScopedBatchLoader] = None,
):
self._remote_schedule = check.inst_param(remote_schedule, "remote_schedule", RemoteSchedule)
self._remote_repository = remote_repository

# optional run loader, provided by a parent graphene object (e.g. GrapheneRepository)
# that instantiates multiple schedules
@@ -98,7 +97,7 @@ def __init__(
description=remote_schedule.description,
assetSelection=GrapheneAssetSelection(
asset_selection=remote_schedule.asset_selection,
remote_repository=self._remote_repository,
repository_handle=repository_handle,
)
if remote_schedule.asset_selection
else None,
--- changed file ---
@@ -7,7 +7,8 @@
from dagster._core.definitions.sensor_definition import SensorType
from dagster._core.errors import DagsterInvariantViolationError
from dagster._core.remote_representation import RemoteSensor, TargetSnap
from dagster._core.remote_representation.external import CompoundID, RemoteRepository
from dagster._core.remote_representation.external import CompoundID
from dagster._core.remote_representation.handle import RepositoryHandle
from dagster._core.scheduler.instigation import InstigatorState, InstigatorStatus
from dagster._core.workspace.permissions import Permissions

@@ -94,12 +95,11 @@ class Meta:
def __init__(
self,
remote_sensor: RemoteSensor,
remote_repo: RemoteRepository,
repository_handle: RepositoryHandle,
sensor_state: Optional[InstigatorState],
batch_loader: Optional[RepositoryScopedBatchLoader] = None,
):
self._remote_sensor = check.inst_param(remote_sensor, "remote_sensor", RemoteSensor)
self._remote_repository = remote_repo

# optional run loader, provided by a parent GrapheneRepository object that instantiates
# multiple sensors
@@ -122,7 +122,7 @@ def __init__(
sensorType=remote_sensor.sensor_type.value,
assetSelection=GrapheneAssetSelection(
asset_selection=remote_sensor.asset_selection,
remote_repository=self._remote_repository,
repository_handle=repository_handle,
)
if remote_sensor.asset_selection
else None,
--- changed file ---
@@ -265,29 +265,29 @@ def entity_dep_graph(self) -> DependencyGraph[EntityKey]:
"downstream": {node.key: node.child_entity_keys for node in self.nodes},
}

@cached_property
@property
def all_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes}
return set(self._asset_nodes_by_key)

@cached_property
def materializable_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.is_materializable}
return {key for key, node in self._asset_nodes_by_key.items() if node.is_materializable}

@cached_property
def observable_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.is_observable}
return {key for key, node in self._asset_nodes_by_key.items() if node.is_observable}

@cached_property
def external_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.is_external}
return {key for key, node in self._asset_nodes_by_key.items() if node.is_external}

@cached_property
def executable_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if node.is_executable}
return {key for key, node in self._asset_nodes_by_key.items() if node.is_executable}

@cached_property
def unexecutable_asset_keys(self) -> AbstractSet[AssetKey]:
return {node.key for node in self.asset_nodes if not node.is_executable}
return {key for key, node in self._asset_nodes_by_key.items() if not node.is_executable}

@cached_property
def toposorted_asset_keys(self) -> Sequence[AssetKey]:
(diffs for the remaining changed files not shown)