Skip to content

Commit

Permalink
[2/n] Pass around a CodeLocationEntriesSnapshot object into the WorkspaceRequestContext instead of a dictionary (#24069)
Browse files

Summary:
Passing a snapshot object instead of a plain dictionary lets us keep a
cached asset graph for as long as the underlying code locations have not
changed. That makes create_request_context cheap enough to call more
often in the daemon. The next PR will use this to resolve a race
condition in which the asset graph gets out of sync with the workspace
in the asset daemon.

Test Plan: BK

## Summary & Motivation

## How I Tested These Changes

## Changelog [New | Bug | Docs]

> Replace this message with a changelog entry, or `NOCHANGELOG`
  • Loading branch information
gibsondan authored Sep 4, 2024
1 parent e1de9e8 commit 32e3f8c
Show file tree
Hide file tree
Showing 10 changed files with 135 additions and 191 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def define_out_of_process_context(
) as workspace_process_context:
yield WorkspaceRequestContext(
instance=instance,
workspace_snapshot=workspace_process_context.create_snapshot(),
workspace_snapshot=workspace_process_context.get_workspace_snapshot(),
process_context=workspace_process_context,
version=workspace_process_context.version,
source=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,10 @@ def test_code_server_cli_reload_location(self, graphql_context):
assert result.data["reloadRepositoryLocation"]["name"] == "test"
assert result.data["reloadRepositoryLocation"]["loadStatus"] == "LOADED"

new_location = graphql_context.process_context.create_snapshot()["test"].code_location
new_location = (
graphql_context.process_context.get_workspace_snapshot()
.code_location_entries["test"]
.code_location
)

assert new_location.server_id != old_server_id # Reload actually happened
3 changes: 2 additions & 1 deletion python_modules/dagster-webserver/dagster_webserver/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
IWorkspaceProcessContext,
WorkspaceRequestContext,
)
from dagster._core.workspace.workspace import WorkspaceSnapshot
from dagster._serdes.serdes import deserialize_value

from dagster_webserver.cli import (
Expand All @@ -34,7 +35,7 @@ def __init__(
def create_request_context(self, source: Optional[Any] = None) -> BaseWorkspaceRequestContext:
return WorkspaceRequestContext(
instance=self._instance,
workspace_snapshot={},
workspace_snapshot=WorkspaceSnapshot(code_location_entries={}),
process_context=self,
version=__version__,
source=source,
Expand Down
66 changes: 35 additions & 31 deletions python_modules/dagster/dagster/_core/workspace/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,14 @@
CodeLocationLoadStatus,
CodeLocationStatusEntry,
IWorkspace,
WorkspaceSnapshot,
location_status_from_location_entry,
)
from dagster._utils.aiodataloader import DataLoader
from dagster._utils.error import SerializableErrorInfo, serializable_error_info_from_exc_info

if TYPE_CHECKING:
from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph
from dagster._core.remote_representation import (
ExternalPartitionConfigData,
ExternalPartitionExecutionErrorData,
Expand Down Expand Up @@ -94,18 +96,10 @@ class BaseWorkspaceRequestContext(IWorkspace, LoadingContext):
def instance(self) -> DagsterInstance:
pass

@abstractmethod
def get_code_location_entries(self) -> Mapping[str, CodeLocationEntry]:
pass

@abstractmethod
def get_location_entry(self, name: str) -> Optional[CodeLocationEntry]:
pass

@abstractmethod
def get_code_location_statuses(self) -> Sequence[CodeLocationStatusEntry]:
pass

@property
@abstractmethod
def process_context(self) -> "IWorkspaceProcessContext":
Expand Down Expand Up @@ -347,7 +341,7 @@ class WorkspaceRequestContext(BaseWorkspaceRequestContext):
def __init__(
self,
instance: DagsterInstance,
workspace_snapshot: Mapping[str, CodeLocationEntry],
workspace_snapshot: WorkspaceSnapshot,
process_context: "IWorkspaceProcessContext",
version: Optional[str],
source: Optional[object],
Expand Down Expand Up @@ -376,15 +370,15 @@ def instance(self) -> DagsterInstance:
return self._instance

def get_code_location_entries(self) -> Mapping[str, CodeLocationEntry]:
return self._workspace_snapshot
return self._workspace_snapshot.code_location_entries

def get_location_entry(self, name: str) -> Optional[CodeLocationEntry]:
return self._workspace_snapshot.get(name)
return self._workspace_snapshot.code_location_entries.get(name)

def get_code_location_statuses(self) -> Sequence[CodeLocationStatusEntry]:
return [
location_status_from_location_entry(entry)
for entry in self._workspace_snapshot.values()
for entry in self._workspace_snapshot.code_location_entries.values()
]

@property
Expand Down Expand Up @@ -434,6 +428,10 @@ def source(self) -> Optional[object]:
def loaders(self) -> Dict[Type, DataLoader]:
return self._loaders

@property
def asset_graph(self) -> "RemoteAssetGraph":
return self._workspace_snapshot.asset_graph


class IWorkspaceProcessContext(ABC):
"""Class that stores process-scoped information about a webserver session.
Expand Down Expand Up @@ -519,7 +517,7 @@ def __init__(

self._version = version

# Guards changes to _location_entry_dict, _watch_thread_shutdown_events and _watch_threads
# Guards changes to _workspace_snapshot, _watch_thread_shutdown_events and _watch_threads
self._lock = threading.Lock()
self._watch_thread_shutdown_events: Dict[str, threading.Event] = {}
self._watch_threads: Dict[str, threading.Thread] = {}
Expand All @@ -544,7 +542,7 @@ def __init__(
)
)

self._location_entry_dict: Dict[str, CodeLocationEntry] = {}
self._workspace_snapshot: WorkspaceSnapshot = WorkspaceSnapshot(code_location_entries={})
self._update_workspace(
{
origin.location_name: self._load_location(origin, reload=False)
Expand Down Expand Up @@ -672,47 +670,51 @@ def _load_location(self, origin: CodeLocationOrigin, reload: bool) -> CodeLocati
update_timestamp=time.time(),
)

def create_snapshot(self) -> Mapping[str, CodeLocationEntry]:
def get_workspace_snapshot(self) -> WorkspaceSnapshot:
with self._lock:
return self._location_entry_dict.copy()
return self._workspace_snapshot

@property
def code_locations_count(self) -> int:
with self._lock:
return len(self._location_entry_dict)
return len(self._workspace_snapshot.code_location_entries)

@property
def code_location_names(self) -> Sequence[str]:
with self._lock:
return list(self._location_entry_dict)
return list(self._workspace_snapshot.code_location_entries)

def has_code_location(self, location_name: str) -> bool:
check.str_param(location_name, "location_name")

with self._lock:
return (
location_name in self._location_entry_dict
and self._location_entry_dict[location_name].code_location is not None
location_name in self._workspace_snapshot.code_location_entries
and self._workspace_snapshot.code_location_entries[location_name].code_location
is not None
)

def has_code_location_error(self, location_name: str) -> bool:
check.str_param(location_name, "location_name")
with self._lock:
return (
location_name in self._location_entry_dict
and self._location_entry_dict[location_name].load_error is not None
location_name in self._workspace_snapshot.code_location_entries
and self._workspace_snapshot.code_location_entries[location_name].load_error
is not None
)

def reload_code_location(self, name: str) -> None:
new = self._load_location(self._location_entry_dict[name].origin, reload=True)
new_entry = self._load_location(
self._workspace_snapshot.code_location_entries[name].origin, reload=True
)
with self._lock:
# Relying on GC to clean up the old location once nothing else
# is referencing it
self._location_entry_dict[name] = new
self._workspace_snapshot = self._workspace_snapshot.with_code_location(name, new_entry)

def shutdown_code_location(self, name: str) -> None:
with self._lock:
self._location_entry_dict[name].origin.shutdown_server()
self._workspace_snapshot.code_location_entries[name].origin.shutdown_server()

def refresh_workspace(self) -> None:
updated_locations = {
Expand All @@ -737,11 +739,11 @@ def _update_workspace(self, new_locations: Dict[str, CodeLocationEntry]):
previous_threads = self._watch_threads
self._watch_threads = {}

previous_locations = self._location_entry_dict
self._location_entry_dict = new_locations
previous_locations = self._workspace_snapshot.code_location_entries
self._workspace_snapshot = WorkspaceSnapshot(code_location_entries=new_locations)

# start monitoring for new locations
for entry in self._location_entry_dict.values():
for entry in new_locations.values():
if isinstance(entry.origin, GrpcServerCodeLocationOrigin):
self._start_watch_thread(entry.origin)

Expand All @@ -759,7 +761,7 @@ def _update_workspace(self, new_locations: Dict[str, CodeLocationEntry]):
def create_request_context(self, source: Optional[object] = None) -> WorkspaceRequestContext:
return WorkspaceRequestContext(
instance=self._instance,
workspace_snapshot=self.create_snapshot(),
workspace_snapshot=self.get_workspace_snapshot(),
process_context=self,
version=self.version,
source=source,
Expand All @@ -785,11 +787,13 @@ def _location_state_events_handler(self, event: LocationStateChangeEvent) -> Non
def refresh_code_location(self, name: str) -> None:
# This method reloads the webserver's copy of the code from the remote gRPC server without
# restarting it, and returns a new request context created from the updated process context
new = self._load_location(self._location_entry_dict[name].origin, reload=False)
new_entry = self._load_location(
self._workspace_snapshot.code_location_entries[name].origin, reload=False
)
with self._lock:
# Relying on GC to clean up the old location once nothing else
# is referencing it
self._location_entry_dict[name] = new
self._workspace_snapshot = self._workspace_snapshot.with_code_location(name, new_entry)

def __enter__(self):
return self
Expand Down
45 changes: 29 additions & 16 deletions python_modules/dagster/dagster/_core/workspace/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Mapping, NamedTuple, Optional, Sequence, Tuple

from dagster._record import record
from dagster._utils.error import SerializableErrorInfo

if TYPE_CHECKING:
Expand Down Expand Up @@ -40,29 +41,17 @@ class CodeLocationStatusEntry(NamedTuple):
update_timestamp: float


class IWorkspace(ABC):
"""Manages a set of CodeLocations."""

@abstractmethod
def get_code_location(self, location_name: str) -> "CodeLocation":
"""Return the CodeLocation for the given location name, or raise an error if there is an error loading it."""

@abstractmethod
def get_code_location_entries(self) -> Mapping[str, CodeLocationEntry]:
"""Return an entry for each location in the workspace."""

@abstractmethod
def get_code_location_statuses(self) -> Sequence[CodeLocationStatusEntry]:
pass
@record
class WorkspaceSnapshot:
code_location_entries: Mapping[str, CodeLocationEntry]

@cached_property
def asset_graph(self) -> "RemoteAssetGraph":
"""Returns a workspace scoped RemoteAssetGraph."""
from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph

code_locations = (
location_entry.code_location
for location_entry in self.get_code_location_entries().values()
for location_entry in self.code_location_entries.values()
if location_entry.code_location
)
repos = (
Expand All @@ -86,6 +75,30 @@ def asset_graph(self) -> "RemoteAssetGraph":
external_asset_checks=asset_checks,
)

def with_code_location(self, name: str, entry: CodeLocationEntry) -> "WorkspaceSnapshot":
return WorkspaceSnapshot(code_location_entries={**self.code_location_entries, name: entry})


class IWorkspace(ABC):
"""Manages a set of CodeLocations."""

@abstractmethod
def get_code_location(self, location_name: str) -> "CodeLocation":
"""Return the CodeLocation for the given location name, or raise an error if there is an error loading it."""

@abstractmethod
def get_code_location_entries(self) -> Mapping[str, CodeLocationEntry]:
"""Return an entry for each location in the workspace."""

@abstractmethod
def get_code_location_statuses(self) -> Sequence[CodeLocationStatusEntry]:
pass

@property
@abstractmethod
def asset_graph(self) -> "RemoteAssetGraph":
pass


def location_status_from_location_entry(
entry: CodeLocationEntry,
Expand Down
Loading

0 comments on commit 32e3f8c

Please sign in to comment.