From 11fd93fe3eb96c75620adf719a7f957ae8339b7c Mon Sep 17 00:00:00 2001 From: Alex Langenfeld Date: Thu, 10 Oct 2024 11:49:30 -0500 Subject: [PATCH] [RemoteAssetGraph] from workspace and repository builders (#25156) Refactor to make the conditions under which RemoteAssetGraphs are built more explicit, and to ease future changes. ## How I Tested These Changes existing tests --- .../_core/definitions/remote_asset_graph.py | 55 ++++++++++++++++++- .../_core/remote_representation/external.py | 10 +--- .../dagster/dagster/_core/test_utils.py | 25 +++++++++ .../dagster/_core/workspace/workspace.py | 28 +--------- .../asset_defs_tests/test_asset_graph.py | 26 ++++----- .../execution_tests/test_asset_backfill.py | 35 ++++++------ 6 files changed, 110 insertions(+), 69 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/remote_asset_graph.py b/python_modules/dagster/dagster/_core/definitions/remote_asset_graph.py index b1ca797533286..7a7117416552a 100644 --- a/python_modules/dagster/dagster/_core/definitions/remote_asset_graph.py +++ b/python_modules/dagster/dagster/_core/definitions/remote_asset_graph.py @@ -1,6 +1,7 @@ import itertools import warnings from collections import defaultdict +from enum import Enum from functools import cached_property from typing import ( TYPE_CHECKING, @@ -39,6 +40,7 @@ from dagster._core.definitions.utils import DEFAULT_GROUP_NAME from dagster._core.remote_representation.external import RemoteRepository from dagster._core.remote_representation.handle import RepositoryHandle +from dagster._core.workspace.workspace import WorkspaceSnapshot if TYPE_CHECKING: from dagster._core.remote_representation.external_data import AssetCheckNodeSnap, AssetNodeSnap @@ -232,14 +234,23 @@ def _observable_node_snap(self) -> "AssetNodeSnap": check.failed("No observable node found") +class RemoteAssetGraphScope(Enum): + """Was this asset graph built from a single repository or all repositories across the whole workspace.""" + + REPOSITORY = 
"REPOSITORY" + WORKSPACE = "WORKSPACE" + + class RemoteAssetGraph(BaseAssetGraph[RemoteAssetNode]): def __init__( self, + scope: RemoteAssetGraphScope, asset_nodes_by_key: Mapping[AssetKey, RemoteAssetNode], asset_checks_by_key: Mapping[AssetCheckKey, "AssetCheckNodeSnap"], asset_check_execution_sets_by_key: Mapping[AssetCheckKey, AbstractSet[EntityKey]], repository_handles_by_asset_check_key: Mapping[AssetCheckKey, RepositoryHandle], ): + self._scope = scope self._asset_nodes_by_key = asset_nodes_by_key self._asset_checks_by_key = asset_checks_by_key self._asset_check_nodes_by_key = { @@ -250,8 +261,49 @@ def __init__( self._repository_handles_by_asset_check_key = repository_handles_by_asset_check_key @classmethod - def from_repository_handles_and_asset_node_snaps( + def from_remote_repository(cls, repo: RemoteRepository): + return cls._build( + scope=RemoteAssetGraphScope.REPOSITORY, + repo_handle_assets=[ + (repo.handle, node_snap) for node_snap in repo.get_asset_node_snaps() + ], + repo_handle_asset_checks=[ + (repo.handle, asset_check_node) + for asset_check_node in repo.get_asset_check_node_snaps() + ], + ) + + @classmethod + def from_workspace_snapshot(cls, workspace: WorkspaceSnapshot): + code_locations = ( + location_entry.code_location + for location_entry in workspace.code_location_entries.values() + if location_entry.code_location + ) + repos = ( + repo + for code_location in code_locations + for repo in code_location.get_repositories().values() + ) + + repo_handle_assets: Sequence[Tuple["RepositoryHandle", "AssetNodeSnap"]] = [] + repo_handle_asset_checks: Sequence[Tuple["RepositoryHandle", "AssetCheckNodeSnap"]] = [] + for repo in repos: + for asset_node_snap in repo.get_asset_node_snaps(): + repo_handle_assets.append((repo.handle, asset_node_snap)) + for asset_check_node_snap in repo.get_asset_check_node_snaps(): + repo_handle_asset_checks.append((repo.handle, asset_check_node_snap)) + + return cls._build( + scope=RemoteAssetGraphScope.WORKSPACE, + 
repo_handle_assets=repo_handle_assets, + repo_handle_asset_checks=repo_handle_asset_checks, + ) + + @classmethod + def _build( cls, + scope: RemoteAssetGraphScope, repo_handle_assets: Sequence[Tuple[RepositoryHandle, "AssetNodeSnap"]], repo_handle_asset_checks: Sequence[Tuple[RepositoryHandle, "AssetCheckNodeSnap"]], ) -> "RemoteAssetGraph": @@ -313,6 +365,7 @@ def from_repository_handles_and_asset_node_snaps( } return cls( + scope, asset_nodes_by_key, asset_checks_by_key, asset_check_execution_sets_by_key, diff --git a/python_modules/dagster/dagster/_core/remote_representation/external.py b/python_modules/dagster/dagster/_core/remote_representation/external.py index 8cc853e33eaea..eec98f1474081 100644 --- a/python_modules/dagster/dagster/_core/remote_representation/external.py +++ b/python_modules/dagster/dagster/_core/remote_representation/external.py @@ -396,15 +396,7 @@ def asset_graph(self) -> "RemoteAssetGraph": """Returns a repository scoped RemoteAssetGraph.""" from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph - return RemoteAssetGraph.from_repository_handles_and_asset_node_snaps( - repo_handle_assets=[ - (self.handle, node_snap) for node_snap in self.get_asset_node_snaps() - ], - repo_handle_asset_checks=[ - (self.handle, asset_check_node) - for asset_check_node in self.get_asset_check_node_snaps() - ], - ) + return RemoteAssetGraph.from_remote_repository(self) def get_partition_names_for_asset_job( self, diff --git a/python_modules/dagster/dagster/_core/test_utils.py b/python_modules/dagster/dagster/_core/test_utils.py index bc131c34c3946..c312adaca2996 100644 --- a/python_modules/dagster/dagster/_core/test_utils.py +++ b/python_modules/dagster/dagster/_core/test_utils.py @@ -47,6 +47,9 @@ from dagster._core.definitions.graph_definition import GraphDefinition from dagster._core.definitions.job_definition import JobDefinition from dagster._core.definitions.node_definition import NodeDefinition +from 
dagster._core.definitions.repository_definition.repository_definition import ( + RepositoryDefinition, +) from dagster._core.definitions.source_asset import SourceAsset from dagster._core.definitions.unresolved_asset_job_definition import define_asset_job from dagster._core.errors import DagsterUserCodeUnreachableError @@ -61,6 +64,9 @@ ) from dagster._core.launcher import RunLauncher from dagster._core.remote_representation import RemoteRepository +from dagster._core.remote_representation.code_location import CodeLocation +from dagster._core.remote_representation.external_data import RepositorySnap +from dagster._core.remote_representation.handle import RepositoryHandle from dagster._core.remote_representation.origin import InProcessCodeLocationOrigin from dagster._core.run_coordinator import RunCoordinator, SubmitRunContext from dagster._core.secrets import SecretsLoader @@ -68,6 +74,7 @@ from dagster._core.types.loadable_target_origin import LoadableTargetOrigin from dagster._core.workspace.context import WorkspaceProcessContext, WorkspaceRequestContext from dagster._core.workspace.load_target import WorkspaceLoadTarget +from dagster._core.workspace.workspace import CodeLocationEntry, WorkspaceSnapshot from dagster._serdes import ConfigurableClass from dagster._serdes.config_class import ConfigurableClassData from dagster._time import create_datetime, get_timezone @@ -763,3 +770,21 @@ def freeze_time(new_now: Union[datetime.datetime, float]): class TestType: ... 
+ + +def mock_workspace_from_repos(repos: Sequence[RepositoryDefinition]) -> WorkspaceSnapshot: + remote_repos = {} + for repo in repos: + remote_repos[repo.name] = RemoteRepository( + RepositorySnap.from_def(repo), + repository_handle=RepositoryHandle.for_test( + location_name="test", + repository_name=repo.name, + ), + instance=DagsterInstance.ephemeral(), + ) + mock_entry = unittest.mock.MagicMock(spec=CodeLocationEntry) + mock_location = unittest.mock.MagicMock(spec=CodeLocation) + mock_location.get_repositories.return_value = remote_repos + type(mock_entry).code_location = unittest.mock.PropertyMock(return_value=mock_location) + return WorkspaceSnapshot(code_location_entries={"test": mock_entry}) diff --git a/python_modules/dagster/dagster/_core/workspace/workspace.py b/python_modules/dagster/dagster/_core/workspace/workspace.py index 9490c863e1bd7..0bf85648ec883 100644 --- a/python_modules/dagster/dagster/_core/workspace/workspace.py +++ b/python_modules/dagster/dagster/_core/workspace/workspace.py @@ -1,6 +1,6 @@ from enum import Enum from functools import cached_property -from typing import TYPE_CHECKING, Mapping, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Mapping, Optional from typing_extensions import Annotated @@ -10,8 +10,6 @@ if TYPE_CHECKING: from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph from dagster._core.remote_representation import CodeLocation, CodeLocationOrigin - from dagster._core.remote_representation.external_data import AssetCheckNodeSnap, AssetNodeSnap - from dagster._core.remote_representation.handle import RepositoryHandle # For locations that are loaded asynchronously @@ -53,29 +51,7 @@ class WorkspaceSnapshot: def asset_graph(self) -> "RemoteAssetGraph": from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph - code_locations = ( - location_entry.code_location - for location_entry in self.code_location_entries.values() - if location_entry.code_location - ) - repos = ( - 
repo - for code_location in code_locations - for repo in code_location.get_repositories().values() - ) - repo_handle_assets: Sequence[Tuple["RepositoryHandle", "AssetNodeSnap"]] = [] - repo_handle_asset_checks: Sequence[Tuple["RepositoryHandle", "AssetCheckNodeSnap"]] = [] - - for repo in repos: - for asset_node_snap in repo.get_asset_node_snaps(): - repo_handle_assets.append((repo.handle, asset_node_snap)) - for asset_check_node_snap in repo.get_asset_check_node_snaps(): - repo_handle_asset_checks.append((repo.handle, asset_check_node_snap)) - - return RemoteAssetGraph.from_repository_handles_and_asset_node_snaps( - repo_handle_assets=repo_handle_assets, - repo_handle_asset_checks=repo_handle_asset_checks, - ) + return RemoteAssetGraph.from_workspace_snapshot(self) def with_code_location(self, name: str, entry: CodeLocationEntry) -> "WorkspaceSnapshot": return WorkspaceSnapshot(code_location_entries={**self.code_location_entries, name: entry}) diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/test_asset_graph.py b/python_modules/dagster/dagster_tests/asset_defs_tests/test_asset_graph.py index b230279c898c9..f94f92bdd8f7a 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/test_asset_graph.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/test_asset_graph.py @@ -9,6 +9,7 @@ AssetOut, AssetsDefinition, AutomationCondition, + DagsterInstance, DailyPartitionsDefinition, GraphOut, HourlyPartitionsDefinition, @@ -37,12 +38,10 @@ from dagster._core.definitions.source_asset import SourceAsset from dagster._core.errors import DagsterDefinitionChangedDeserializationError from dagster._core.instance import DynamicPartitionsStore -from dagster._core.remote_representation.external_data import ( - asset_check_node_snaps_from_repo, - asset_node_snaps_from_repo, -) +from dagster._core.remote_representation.external import RemoteRepository +from dagster._core.remote_representation.external_data import RepositorySnap from 
dagster._core.remote_representation.handle import RepositoryHandle -from dagster._core.test_utils import freeze_time, instance_for_test +from dagster._core.test_utils import freeze_time, instance_for_test, mock_workspace_from_repos from dagster._time import create_datetime, get_current_datetime @@ -51,12 +50,12 @@ def to_remote_asset_graph(assets, asset_checks=None) -> RemoteAssetGraph: def repo(): return assets + (asset_checks or []) - asset_node_snaps = asset_node_snaps_from_repo(repo) - handle = RepositoryHandle.for_test(location_name="fake", repository_name="repo") - return RemoteAssetGraph.from_repository_handles_and_asset_node_snaps( - [(handle, asset_node) for asset_node in asset_node_snaps], - [(handle, asset_check) for asset_check in asset_check_node_snaps_from_repo(repo)], + remote_repo = RemoteRepository( + RepositorySnap.from_def(repo), + repository_handle=RepositoryHandle.for_test(location_name="fake", repository_name="repo"), + instance=DagsterInstance.ephemeral(), ) + return RemoteAssetGraph.from_remote_repository(remote_repo) @pytest.fixture( @@ -889,11 +888,8 @@ def b(): ... 
def repo_b(): return [b] - a_nodes = asset_node_snaps_from_repo(repo_a) - b_nodes = asset_node_snaps_from_repo(repo_b) - handle = RepositoryHandle.for_test(location_name="foo", repository_name="bar") - asset_graph = RemoteAssetGraph.from_repository_handles_and_asset_node_snaps( - [(handle, asset_node) for asset_node in [*a_nodes, *b_nodes]], [] + asset_graph = RemoteAssetGraph.from_workspace_snapshot( + mock_workspace_from_repos([repo_a, repo_b]) ) assert isinstance( diff --git a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_asset_backfill.py b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_asset_backfill.py index 10c1ba732e8f0..31430b56c44d4 100644 --- a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_asset_backfill.py +++ b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_asset_backfill.py @@ -14,7 +14,6 @@ ) from unittest.mock import MagicMock, patch -import mock import pytest from dagster import ( AssetIn, @@ -24,7 +23,6 @@ DagsterInstance, DagsterRunStatus, DailyPartitionsDefinition, - Definitions, HourlyPartitionsDefinition, LastPartitionMapping, Nothing, @@ -41,6 +39,7 @@ from dagster._core.asset_graph_view.asset_graph_view import AssetGraphView, TemporalContext from dagster._core.definitions.asset_graph_subset import AssetGraphSubset from dagster._core.definitions.base_asset_graph import BaseAssetGraph +from dagster._core.definitions.decorators.repository_decorator import repository from dagster._core.definitions.events import AssetKeyPartitionKey from dagster._core.definitions.remote_asset_graph import RemoteAssetGraph from dagster._core.definitions.selector import ( @@ -56,8 +55,6 @@ execute_asset_backfill_iteration_inner, get_canceling_asset_backfill_iteration_data, ) -from dagster._core.remote_representation.external_data import asset_node_snaps_from_repo -from dagster._core.remote_representation.handle import RepositoryHandle from dagster._core.storage.dagster_run 
import RunsFilter from dagster._core.storage.tags import ( ASSET_PARTITION_RANGE_END_TAG, @@ -65,7 +62,12 @@ BACKFILL_ID_TAG, PARTITION_NAME_TAG, ) -from dagster._core.test_utils import environ, freeze_time, instance_for_test +from dagster._core.test_utils import ( + environ, + freeze_time, + instance_for_test, + mock_workspace_from_repos, +) from dagster._serdes import deserialize_value, serialize_value from dagster._time import create_datetime, get_current_datetime, get_current_timestamp from dagster._utils import Counter, traced_counter @@ -293,9 +295,9 @@ def _single_backfill_iteration_create_but_do_not_submit_runs( backfill_id, backfill_data, asset_graph, instance, assets_by_repo_name ) -> AssetBackfillData: # Patch the run execution to not actually execute the run, but instead just create it - with mock.patch( + with patch( "dagster._core.execution.execute_in_process.ExecuteRunWithPlanIterable", - return_value=mock.MagicMock(), + return_value=MagicMock(), ): return _single_backfill_iteration( backfill_id, backfill_data, asset_graph, instance, assets_by_repo_name @@ -741,19 +743,16 @@ def _requested_asset_partitions_in_run_request( def remote_asset_graph_from_assets_by_repo_name( assets_by_repo_name: Mapping[str, Sequence[AssetsDefinition]], ) -> RemoteAssetGraph: - from_repository_handles_and_asset_node_snaps = [] - + repos = [] for repo_name, assets in assets_by_repo_name.items(): - repo = Definitions(assets=assets).get_repository_def() - asset_node_snaps = asset_node_snaps_from_repo(repo) - handle = RepositoryHandle.for_test(location_name="test", repository_name=repo_name) - from_repository_handles_and_asset_node_snaps.extend( - [(handle, asset_node) for asset_node in asset_node_snaps] - ) - return RemoteAssetGraph.from_repository_handles_and_asset_node_snaps( - from_repository_handles_and_asset_node_snaps, [] - ) + @repository(name=repo_name) + def repo(assets=assets): + return assets + + repos.append(repo) + + return 
RemoteAssetGraph.from_workspace_snapshot(mock_workspace_from_repos(repos)) @pytest.mark.parametrize(