Skip to content

Commit

Permalink
[RemoteAssetGraph] from workspace and repository builders (#25156)
Browse files Browse the repository at this point in the history
refactor to make it more explicit the conditions in which
RemoteAssetGraphs are built and ease future changes

## How I Tested These Changes

existing tests
  • Loading branch information
alangenfeld authored Oct 10, 2024
1 parent 8fb2133 commit 11fd93f
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 69 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
import warnings
from collections import defaultdict
from enum import Enum
from functools import cached_property
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -39,6 +40,7 @@
from dagster._core.definitions.utils import DEFAULT_GROUP_NAME
from dagster._core.remote_representation.external import RemoteRepository
from dagster._core.remote_representation.handle import RepositoryHandle
from dagster._core.workspace.workspace import WorkspaceSnapshot

if TYPE_CHECKING:
from dagster._core.remote_representation.external_data import AssetCheckNodeSnap, AssetNodeSnap
Expand Down Expand Up @@ -232,14 +234,23 @@ def _observable_node_snap(self) -> "AssetNodeSnap":
check.failed("No observable node found")


class RemoteAssetGraphScope(Enum):
"""Was this asset graph built from a single repository or all repositories across the whole workspace."""

REPOSITORY = "REPOSITORY"
WORKSPACE = "WORKSPACE"


class RemoteAssetGraph(BaseAssetGraph[RemoteAssetNode]):
def __init__(
self,
scope: RemoteAssetGraphScope,
asset_nodes_by_key: Mapping[AssetKey, RemoteAssetNode],
asset_checks_by_key: Mapping[AssetCheckKey, "AssetCheckNodeSnap"],
asset_check_execution_sets_by_key: Mapping[AssetCheckKey, AbstractSet[EntityKey]],
repository_handles_by_asset_check_key: Mapping[AssetCheckKey, RepositoryHandle],
):
self._scope = scope
self._asset_nodes_by_key = asset_nodes_by_key
self._asset_checks_by_key = asset_checks_by_key
self._asset_check_nodes_by_key = {
Expand All @@ -250,8 +261,49 @@ def __init__(
self._repository_handles_by_asset_check_key = repository_handles_by_asset_check_key

@classmethod
def from_repository_handles_and_asset_node_snaps(
def from_remote_repository(cls, repo: RemoteRepository):
return cls._build(
scope=RemoteAssetGraphScope.REPOSITORY,
repo_handle_assets=[
(repo.handle, node_snap) for node_snap in repo.get_asset_node_snaps()
],
repo_handle_asset_checks=[
(repo.handle, asset_check_node)
for asset_check_node in repo.get_asset_check_node_snaps()
],
)

@classmethod
def from_workspace_snapshot(cls, workspace: WorkspaceSnapshot):
code_locations = (
location_entry.code_location
for location_entry in workspace.code_location_entries.values()
if location_entry.code_location
)
repos = (
repo
for code_location in code_locations
for repo in code_location.get_repositories().values()
)

repo_handle_assets: Sequence[Tuple["RepositoryHandle", "AssetNodeSnap"]] = []
repo_handle_asset_checks: Sequence[Tuple["RepositoryHandle", "AssetCheckNodeSnap"]] = []
for repo in repos:
for asset_node_snap in repo.get_asset_node_snaps():
repo_handle_assets.append((repo.handle, asset_node_snap))
for asset_check_node_snap in repo.get_asset_check_node_snaps():
repo_handle_asset_checks.append((repo.handle, asset_check_node_snap))

return cls._build(
scope=RemoteAssetGraphScope.WORKSPACE,
repo_handle_assets=repo_handle_assets,
repo_handle_asset_checks=repo_handle_asset_checks,
)

@classmethod
def _build(
cls,
scope: RemoteAssetGraphScope,
repo_handle_assets: Sequence[Tuple[RepositoryHandle, "AssetNodeSnap"]],
repo_handle_asset_checks: Sequence[Tuple[RepositoryHandle, "AssetCheckNodeSnap"]],
) -> "RemoteAssetGraph":
Expand Down Expand Up @@ -313,6 +365,7 @@ def from_repository_handles_and_asset_node_snaps(
}

return cls(
scope,
asset_nodes_by_key,
asset_checks_by_key,
asset_check_execution_sets_by_key,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,15 +396,7 @@ def asset_graph(self) -> "RemoteAssetGraph":
"""Returns a repository scoped RemoteAssetGraph."""
from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph

return RemoteAssetGraph.from_repository_handles_and_asset_node_snaps(
repo_handle_assets=[
(self.handle, node_snap) for node_snap in self.get_asset_node_snaps()
],
repo_handle_asset_checks=[
(self.handle, asset_check_node)
for asset_check_node in self.get_asset_check_node_snaps()
],
)
return RemoteAssetGraph.from_remote_repository(self)

def get_partition_names_for_asset_job(
self,
Expand Down
25 changes: 25 additions & 0 deletions python_modules/dagster/dagster/_core/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
from dagster._core.definitions.graph_definition import GraphDefinition
from dagster._core.definitions.job_definition import JobDefinition
from dagster._core.definitions.node_definition import NodeDefinition
from dagster._core.definitions.repository_definition.repository_definition import (
RepositoryDefinition,
)
from dagster._core.definitions.source_asset import SourceAsset
from dagster._core.definitions.unresolved_asset_job_definition import define_asset_job
from dagster._core.errors import DagsterUserCodeUnreachableError
Expand All @@ -61,13 +64,17 @@
)
from dagster._core.launcher import RunLauncher
from dagster._core.remote_representation import RemoteRepository
from dagster._core.remote_representation.code_location import CodeLocation
from dagster._core.remote_representation.external_data import RepositorySnap
from dagster._core.remote_representation.handle import RepositoryHandle
from dagster._core.remote_representation.origin import InProcessCodeLocationOrigin
from dagster._core.run_coordinator import RunCoordinator, SubmitRunContext
from dagster._core.secrets import SecretsLoader
from dagster._core.storage.dagster_run import DagsterRun, DagsterRunStatus, RunsFilter
from dagster._core.types.loadable_target_origin import LoadableTargetOrigin
from dagster._core.workspace.context import WorkspaceProcessContext, WorkspaceRequestContext
from dagster._core.workspace.load_target import WorkspaceLoadTarget
from dagster._core.workspace.workspace import CodeLocationEntry, WorkspaceSnapshot
from dagster._serdes import ConfigurableClass
from dagster._serdes.config_class import ConfigurableClassData
from dagster._time import create_datetime, get_timezone
Expand Down Expand Up @@ -763,3 +770,21 @@ def freeze_time(new_now: Union[datetime.datetime, float]):


class TestType: ...


def mock_workspace_from_repos(repos: Sequence[RepositoryDefinition]) -> WorkspaceSnapshot:
remote_repos = {}
for repo in repos:
remote_repos[repo.name] = RemoteRepository(
RepositorySnap.from_def(repo),
repository_handle=RepositoryHandle.for_test(
location_name="test",
repository_name=repo.name,
),
instance=DagsterInstance.ephemeral(),
)
mock_entry = unittest.mock.MagicMock(spec=CodeLocationEntry)
mock_location = unittest.mock.MagicMock(spec=CodeLocation)
mock_location.get_repositories.return_value = remote_repos
type(mock_entry).code_location = unittest.mock.PropertyMock(return_value=mock_location)
return WorkspaceSnapshot(code_location_entries={"test": mock_entry})
28 changes: 2 additions & 26 deletions python_modules/dagster/dagster/_core/workspace/workspace.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, Mapping, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Mapping, Optional

from typing_extensions import Annotated

Expand All @@ -10,8 +10,6 @@
if TYPE_CHECKING:
from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph
from dagster._core.remote_representation import CodeLocation, CodeLocationOrigin
from dagster._core.remote_representation.external_data import AssetCheckNodeSnap, AssetNodeSnap
from dagster._core.remote_representation.handle import RepositoryHandle


# For locations that are loaded asynchronously
Expand Down Expand Up @@ -53,29 +51,7 @@ class WorkspaceSnapshot:
def asset_graph(self) -> "RemoteAssetGraph":
from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph

code_locations = (
location_entry.code_location
for location_entry in self.code_location_entries.values()
if location_entry.code_location
)
repos = (
repo
for code_location in code_locations
for repo in code_location.get_repositories().values()
)
repo_handle_assets: Sequence[Tuple["RepositoryHandle", "AssetNodeSnap"]] = []
repo_handle_asset_checks: Sequence[Tuple["RepositoryHandle", "AssetCheckNodeSnap"]] = []

for repo in repos:
for asset_node_snap in repo.get_asset_node_snaps():
repo_handle_assets.append((repo.handle, asset_node_snap))
for asset_check_node_snap in repo.get_asset_check_node_snaps():
repo_handle_asset_checks.append((repo.handle, asset_check_node_snap))

return RemoteAssetGraph.from_repository_handles_and_asset_node_snaps(
repo_handle_assets=repo_handle_assets,
repo_handle_asset_checks=repo_handle_asset_checks,
)
return RemoteAssetGraph.from_workspace_snapshot(self)

def with_code_location(self, name: str, entry: CodeLocationEntry) -> "WorkspaceSnapshot":
return WorkspaceSnapshot(code_location_entries={**self.code_location_entries, name: entry})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
AssetOut,
AssetsDefinition,
AutomationCondition,
DagsterInstance,
DailyPartitionsDefinition,
GraphOut,
HourlyPartitionsDefinition,
Expand Down Expand Up @@ -37,12 +38,10 @@
from dagster._core.definitions.source_asset import SourceAsset
from dagster._core.errors import DagsterDefinitionChangedDeserializationError
from dagster._core.instance import DynamicPartitionsStore
from dagster._core.remote_representation.external_data import (
asset_check_node_snaps_from_repo,
asset_node_snaps_from_repo,
)
from dagster._core.remote_representation.external import RemoteRepository
from dagster._core.remote_representation.external_data import RepositorySnap
from dagster._core.remote_representation.handle import RepositoryHandle
from dagster._core.test_utils import freeze_time, instance_for_test
from dagster._core.test_utils import freeze_time, instance_for_test, mock_workspace_from_repos
from dagster._time import create_datetime, get_current_datetime


Expand All @@ -51,12 +50,12 @@ def to_remote_asset_graph(assets, asset_checks=None) -> RemoteAssetGraph:
def repo():
return assets + (asset_checks or [])

asset_node_snaps = asset_node_snaps_from_repo(repo)
handle = RepositoryHandle.for_test(location_name="fake", repository_name="repo")
return RemoteAssetGraph.from_repository_handles_and_asset_node_snaps(
[(handle, asset_node) for asset_node in asset_node_snaps],
[(handle, asset_check) for asset_check in asset_check_node_snaps_from_repo(repo)],
remote_repo = RemoteRepository(
RepositorySnap.from_def(repo),
repository_handle=RepositoryHandle.for_test(location_name="fake", repository_name="repo"),
instance=DagsterInstance.ephemeral(),
)
return RemoteAssetGraph.from_remote_repository(remote_repo)


@pytest.fixture(
Expand Down Expand Up @@ -889,11 +888,8 @@ def b(): ...
def repo_b():
return [b]

a_nodes = asset_node_snaps_from_repo(repo_a)
b_nodes = asset_node_snaps_from_repo(repo_b)
handle = RepositoryHandle.for_test(location_name="foo", repository_name="bar")
asset_graph = RemoteAssetGraph.from_repository_handles_and_asset_node_snaps(
[(handle, asset_node) for asset_node in [*a_nodes, *b_nodes]], []
asset_graph = RemoteAssetGraph.from_workspace_snapshot(
mock_workspace_from_repos([repo_a, repo_b])
)

assert isinstance(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
)
from unittest.mock import MagicMock, patch

import mock
import pytest
from dagster import (
AssetIn,
Expand All @@ -24,7 +23,6 @@
DagsterInstance,
DagsterRunStatus,
DailyPartitionsDefinition,
Definitions,
HourlyPartitionsDefinition,
LastPartitionMapping,
Nothing,
Expand All @@ -41,6 +39,7 @@
from dagster._core.asset_graph_view.asset_graph_view import AssetGraphView, TemporalContext
from dagster._core.definitions.asset_graph_subset import AssetGraphSubset
from dagster._core.definitions.base_asset_graph import BaseAssetGraph
from dagster._core.definitions.decorators.repository_decorator import repository
from dagster._core.definitions.events import AssetKeyPartitionKey
from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph
from dagster._core.definitions.selector import (
Expand All @@ -56,16 +55,19 @@
execute_asset_backfill_iteration_inner,
get_canceling_asset_backfill_iteration_data,
)
from dagster._core.remote_representation.external_data import asset_node_snaps_from_repo
from dagster._core.remote_representation.handle import RepositoryHandle
from dagster._core.storage.dagster_run import RunsFilter
from dagster._core.storage.tags import (
ASSET_PARTITION_RANGE_END_TAG,
ASSET_PARTITION_RANGE_START_TAG,
BACKFILL_ID_TAG,
PARTITION_NAME_TAG,
)
from dagster._core.test_utils import environ, freeze_time, instance_for_test
from dagster._core.test_utils import (
environ,
freeze_time,
instance_for_test,
mock_workspace_from_repos,
)
from dagster._serdes import deserialize_value, serialize_value
from dagster._time import create_datetime, get_current_datetime, get_current_timestamp
from dagster._utils import Counter, traced_counter
Expand Down Expand Up @@ -293,9 +295,9 @@ def _single_backfill_iteration_create_but_do_not_submit_runs(
backfill_id, backfill_data, asset_graph, instance, assets_by_repo_name
) -> AssetBackfillData:
# Patch the run execution to not actually execute the run, but instead just create it
with mock.patch(
with patch(
"dagster._core.execution.execute_in_process.ExecuteRunWithPlanIterable",
return_value=mock.MagicMock(),
return_value=MagicMock(),
):
return _single_backfill_iteration(
backfill_id, backfill_data, asset_graph, instance, assets_by_repo_name
Expand Down Expand Up @@ -741,19 +743,16 @@ def _requested_asset_partitions_in_run_request(
def remote_asset_graph_from_assets_by_repo_name(
assets_by_repo_name: Mapping[str, Sequence[AssetsDefinition]],
) -> RemoteAssetGraph:
from_repository_handles_and_asset_node_snaps = []

repos = []
for repo_name, assets in assets_by_repo_name.items():
repo = Definitions(assets=assets).get_repository_def()
asset_node_snaps = asset_node_snaps_from_repo(repo)
handle = RepositoryHandle.for_test(location_name="test", repository_name=repo_name)
from_repository_handles_and_asset_node_snaps.extend(
[(handle, asset_node) for asset_node in asset_node_snaps]
)

return RemoteAssetGraph.from_repository_handles_and_asset_node_snaps(
from_repository_handles_and_asset_node_snaps, []
)
@repository(name=repo_name)
def repo(assets=assets):
return assets

repos.append(repo)

return RemoteAssetGraph.from_workspace_snapshot(mock_workspace_from_repos(repos))


@pytest.mark.parametrize(
Expand Down

0 comments on commit 11fd93f

Please sign in to comment.