From 8500be5de55dc8faf35357dd0c28c9b71f163ae0 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Tue, 6 Aug 2024 15:41:45 -0700 Subject: [PATCH] grpc branch-name: partition-set-to-job/grpc --- .../implementation/execution/backfill.py | 5 +- .../implementation/fetch_partition_sets.py | 24 +-- .../dagster_graphql/schema/partition_sets.py | 11 +- .../dagster/_api/snapshot_partition.py | 37 ++-- python_modules/dagster/dagster/_cli/asset.py | 5 +- python_modules/dagster/dagster/_cli/job.py | 5 +- .../_core/definitions/job_definition.py | 66 +++++--- .../dagster/_core/definitions/run_request.py | 10 +- .../remote_representation/code_location.py | 100 ++++++----- .../dagster/_core/workspace/context.py | 41 +++-- python_modules/dagster/dagster/_grpc/impl.py | 83 ++++----- .../dagster/dagster/_grpc/server.py | 16 +- .../api_tests/test_api_snapshot_partition.py | 159 +++++++++++++++--- .../dagster/dagster_tests/api_tests/utils.py | 22 +++ .../command_tests/test_materialize_command.py | 2 +- 15 files changed, 401 insertions(+), 185 deletions(-) diff --git a/python_modules/dagster-graphql/dagster_graphql/implementation/execution/backfill.py b/python_modules/dagster-graphql/dagster_graphql/implementation/execution/backfill.py index e1aaf823190a0..b1163612eeebc 100644 --- a/python_modules/dagster-graphql/dagster_graphql/implementation/execution/backfill.py +++ b/python_modules/dagster-graphql/dagster_graphql/implementation/execution/backfill.py @@ -140,7 +140,10 @@ def create_and_launch_partition_backfill( if backfill_params.get("allPartitions"): result = graphene_info.context.get_external_partition_names( - external_partition_set, instance=graphene_info.context.instance + repository_handle=repository.handle, + job_name=external_partition_set.job_name, + instance=graphene_info.context.instance, + selected_asset_keys=None, ) if isinstance(result, ExternalPartitionExecutionErrorData): raise DagsterUserCodeProcessError.from_error_info(result.error) diff --git a/python_modules/dagster-graphql/dagster_graphql/implementation/fetch_partition_sets.py b/python_modules/dagster-graphql/dagster_graphql/implementation/fetch_partition_sets.py index 8f60467179dee..10f36245eaae7 100644 --- a/python_modules/dagster-graphql/dagster_graphql/implementation/fetch_partition_sets.py +++ b/python_modules/dagster-graphql/dagster_graphql/implementation/fetch_partition_sets.py @@ -1,7 +1,8 @@ from collections import defaultdict -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, AbstractSet, Optional, Sequence, Union import dagster._check as check +from dagster._core.definitions.asset_key import AssetKey from dagster._core.definitions.selector import RepositorySelector from dagster._core.errors import DagsterUserCodeProcessError from dagster._core.remote_representation import ExternalPartitionSet, RepositoryHandle @@ -107,20 +108,18 @@ def get_partition_by_name( def get_partition_config( graphene_info: ResolveInfo, repository_handle: RepositoryHandle, - partition_set_name: str, + job_name: str, partition_name: str, + selected_asset_keys: Optional[AbstractSet[AssetKey]], ) -> "GraphenePartitionRunConfig": from ..schema.partition_sets import GraphenePartitionRunConfig check.inst_param(repository_handle, "repository_handle", RepositoryHandle) - check.str_param(partition_set_name, "partition_set_name") + check.str_param(job_name, "job_name") check.str_param(partition_name, "partition_name") result = graphene_info.context.get_external_partition_config( - repository_handle, - partition_set_name, - partition_name, - graphene_info.context.instance, + repository_handle, job_name, partition_name, graphene_info.context.instance ) if isinstance(result, ExternalPartitionExecutionErrorData): @@ -132,18 +131,23 @@ def get_partition_config( def get_partition_tags( graphene_info: ResolveInfo, repository_handle: RepositoryHandle, - partition_set_name: str, + job_name: str, partition_name: str, + selected_asset_keys: Optional[AbstractSet[AssetKey]], ) -> "GraphenePartitionTags": from ..schema.partition_sets import GraphenePartitionTags from ..schema.tags import GraphenePipelineTag check.inst_param(repository_handle, "repository_handle", RepositoryHandle) - check.str_param(partition_set_name, "partition_set_name") + check.str_param(job_name, "job_name") check.str_param(partition_name, "partition_name") result = graphene_info.context.get_external_partition_tags( - repository_handle, partition_set_name, partition_name, graphene_info.context.instance + repository_handle, + job_name, + partition_name, + graphene_info.context.instance, + selected_asset_keys=selected_asset_keys, ) if isinstance(result, ExternalPartitionExecutionErrorData): diff --git a/python_modules/dagster-graphql/dagster_graphql/schema/partition_sets.py b/python_modules/dagster-graphql/dagster_graphql/schema/partition_sets.py index 6c1da95485d08..cc00bbd286750 100644 --- a/python_modules/dagster-graphql/dagster_graphql/schema/partition_sets.py +++ b/python_modules/dagster-graphql/dagster_graphql/schema/partition_sets.py @@ -12,6 +12,7 @@ ExternalPartitionsDefinitionData, ExternalStaticPartitionsDefinitionData, ExternalTimeWindowPartitionsDefinitionData, + job_name_for_external_partition_set_name, ) from dagster._core.storage.dagster_run import RunsFilter from dagster._core.storage.tags import PARTITION_NAME_TAG, PARTITION_SET_TAG @@ -227,8 +228,9 @@ def resolve_runConfigOrError(self, graphene_info: ResolveInfo): return get_partition_config( graphene_info, self._external_repository_handle, - self._external_partition_set.name, + job_name_for_external_partition_set_name(self._external_partition_set.name), self._partition_name, + selected_asset_keys=None, ) @capture_error @@ -236,8 +238,9 @@ def resolve_tagsOrError(self, graphene_info: ResolveInfo): return get_partition_tags( graphene_info, self._external_repository_handle, - self._external_partition_set.name, + job_name_for_external_partition_set_name(self._external_partition_set.name), self._partition_name, + selected_asset_keys=None, ) def resolve_runs( @@ -326,8 +329,10 @@ def __init__( def _get_partition_names(self, graphene_info: ResolveInfo) -> Sequence[str]: if self._partition_names is None: result = graphene_info.context.get_external_partition_names( - self._external_partition_set, + repository_handle=self._external_repository_handle, + job_name=self._external_partition_set.job_name, instance=graphene_info.context.instance, + selected_asset_keys=None, ) if isinstance(result, ExternalPartitionExecutionErrorData): raise DagsterUserCodeProcessError.from_error_info(result.error) diff --git a/python_modules/dagster/dagster/_api/snapshot_partition.py b/python_modules/dagster/dagster/_api/snapshot_partition.py index 3377e1509d18b..c77c573fb1252 100644 --- a/python_modules/dagster/dagster/_api/snapshot_partition.py +++ b/python_modules/dagster/dagster/_api/snapshot_partition.py @@ -1,6 +1,7 @@ -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, AbstractSet, Optional, Sequence import dagster._check as check +from dagster._core.definitions.asset_key import AssetKey from dagster._core.errors import DagsterUserCodeProcessError from dagster._core.instance import DagsterInstance from dagster._core.remote_representation.external_data import ( @@ -9,7 +10,7 @@ ExternalPartitionNamesData, ExternalPartitionSetExecutionParamData, ExternalPartitionTagsData, - job_name_for_external_partition_set_name, + external_partition_set_name_for_job_name, ) from dagster._core.remote_representation.handle import RepositoryHandle from dagster._grpc.types import PartitionArgs, PartitionNamesArgs, PartitionSetExecutionParamArgs @@ -20,21 +21,24 @@ def sync_get_external_partition_names_grpc( - api_client: "DagsterGrpcClient", repository_handle: RepositoryHandle, partition_set_name: str + api_client: "DagsterGrpcClient", + repository_handle: RepositoryHandle, + job_name: str, + selected_asset_keys: Optional[AbstractSet[AssetKey]], ) -> ExternalPartitionNamesData: from dagster._grpc.client import DagsterGrpcClient check.inst_param(api_client, "api_client", DagsterGrpcClient) check.inst_param(repository_handle, "repository_handle", RepositoryHandle) - check.str_param(partition_set_name, "partition_set_name") + check.str_param(job_name, "job_name") repository_origin = repository_handle.get_external_origin() result = deserialize_value( api_client.external_partition_names( partition_names_args=PartitionNamesArgs( repository_origin=repository_origin, - job_name=job_name_for_external_partition_set_name(partition_set_name), - partition_set_name=partition_set_name, - selected_asset_keys=None, + job_name=job_name, + partition_set_name=external_partition_set_name_for_job_name(job_name), + selected_asset_keys=selected_asset_keys, ), ), (ExternalPartitionNamesData, ExternalPartitionExecutionErrorData), @@ -48,7 +52,7 @@ def sync_get_external_partition_names_grpc( def sync_get_external_partition_config_grpc( api_client: "DagsterGrpcClient", repository_handle: RepositoryHandle, - partition_set_name: str, + job_name: str, partition_name: str, instance: DagsterInstance, ) -> ExternalPartitionConfigData: @@ -56,15 +60,15 @@ def sync_get_external_partition_config_grpc( check.inst_param(api_client, "api_client", DagsterGrpcClient) check.inst_param(repository_handle, "repository_handle", RepositoryHandle) - check.str_param(partition_set_name, "partition_set_name") + check.str_param(job_name, "job_name") check.str_param(partition_name, "partition_name") repository_origin = repository_handle.get_external_origin() result = deserialize_value( api_client.external_partition_config( partition_args=PartitionArgs( repository_origin=repository_origin, - job_name=job_name_for_external_partition_set_name(partition_set_name), - partition_set_name=partition_set_name, + job_name=job_name, + partition_set_name=external_partition_set_name_for_job_name(job_name), partition_name=partition_name, instance_ref=instance.get_ref(), selected_asset_keys=None, @@ -81,15 +85,16 @@ def sync_get_external_partition_config_grpc( def sync_get_external_partition_tags_grpc( api_client: "DagsterGrpcClient", repository_handle: RepositoryHandle, - partition_set_name: str, + job_name: str, partition_name: str, instance: DagsterInstance, + selected_asset_keys: Optional[AbstractSet[AssetKey]], ) -> ExternalPartitionTagsData: from dagster._grpc.client import DagsterGrpcClient check.inst_param(api_client, "api_client", DagsterGrpcClient) check.inst_param(repository_handle, "repository_handle", RepositoryHandle) - check.str_param(partition_set_name, "partition_set_name") + check.str_param(job_name, "job_name") check.str_param(partition_name, "partition_name") repository_origin = repository_handle.get_external_origin() @@ -97,11 +102,11 @@ def sync_get_external_partition_tags_grpc( api_client.external_partition_tags( partition_args=PartitionArgs( repository_origin=repository_origin, - job_name=job_name_for_external_partition_set_name(partition_set_name), - partition_set_name=partition_set_name, + job_name=job_name, + partition_set_name=external_partition_set_name_for_job_name(job_name), partition_name=partition_name, instance_ref=instance.get_ref(), - selected_asset_keys=None, + selected_asset_keys=selected_asset_keys, ), ), (ExternalPartitionTagsData, ExternalPartitionExecutionErrorData), diff --git a/python_modules/dagster/dagster/_cli/asset.py b/python_modules/dagster/dagster/_cli/asset.py index d2e17afd4b46f..07dfa32d18bdd 100644 --- a/python_modules/dagster/dagster/_cli/asset.py +++ b/python_modules/dagster/dagster/_cli/asset.py @@ -60,9 +60,12 @@ def execute_materialize_command(instance: DagsterInstance, kwargs: Mapping[str, check.failed("Provided '--partition' option, but none of the assets are partitioned") try: - tags = implicit_job_def.get_tags_for_partition_key( + implicit_job_def.validate_partition_key( partition, selected_asset_keys=asset_keys, dynamic_partitions_store=instance ) + tags = implicit_job_def.get_tags_for_partition_key( + partition, selected_asset_keys=asset_keys + ) except DagsterUnknownPartitionError: raise DagsterInvalidSubsetError( "All selected assets must have a PartitionsDefinition containing the passed" diff --git a/python_modules/dagster/dagster/_cli/job.py b/python_modules/dagster/dagster/_cli/job.py index 92ff1e2124822..918289d1a51ce 100644 --- a/python_modules/dagster/dagster/_cli/job.py +++ b/python_modules/dagster/dagster/_cli/job.py @@ -650,7 +650,10 @@ def _execute_backfill_command_at_location( try: partition_names_or_error = code_location.get_external_partition_names( - job_partition_set, instance=instance + repository_handle=repo_handle, + job_name=external_job.name, + instance=instance, + selected_asset_keys=None, ) except Exception as e: error_info = serializable_error_info_from_exc_info(sys.exc_info()) diff --git a/python_modules/dagster/dagster/_core/definitions/job_definition.py b/python_modules/dagster/dagster/_core/definitions/job_definition.py index de8e0ec9c4fc8..3c17f999f3398 100644 --- a/python_modules/dagster/dagster/_core/definitions/job_definition.py +++ b/python_modules/dagster/dagster/_core/definitions/job_definition.py @@ -685,11 +685,15 @@ def execute_in_process( merged_tags = merge_dicts(self.tags, tags or {}) if partition_key: - tags_for_partition_key = ephemeral_job.get_tags_for_partition_key( + ephemeral_job.validate_partition_key( partition_key, selected_asset_keys=asset_selection, dynamic_partitions_store=instance, ) + tags_for_partition_key = ephemeral_job.get_tags_for_partition_key( + partition_key, + selected_asset_keys=asset_selection, + ) if not run_config and self.partitioned_config: run_config = self.partitioned_config.get_run_config_for_partition_key(partition_key) @@ -714,18 +718,11 @@ def execute_in_process( asset_selection=frozenset(asset_selection), ) - def get_tags_for_partition_key( - self, - partition_key: str, - dynamic_partitions_store: Optional["DynamicPartitionsStore"], - selected_asset_keys: Optional[Iterable[AssetKey]], - ) -> Mapping[str, str]: - """Gets tags for the given partition key and ensures that it's a member of the PartitionsDefinition - corresponding to every asset in the selection. - """ - partitions_def = None + def _get_partitions_def( + self, selected_asset_keys: Optional[Iterable[AssetKey]] + ) -> PartitionsDefinition: if self.partitions_def: - partitions_def = self.partitions_def + return self.partitions_def elif self.asset_layer: if selected_asset_keys: resolved_selected_asset_keys = selected_asset_keys @@ -736,21 +733,48 @@ def get_tags_for_partition_key( key for key in self.asset_layer.asset_keys_by_node_output_handle.values() ] - unique_partitions_defs = { - self.asset_layer.get(asset_key).partitions_def - for asset_key in resolved_selected_asset_keys - } - {None} + unique_partitions_defs: Set[PartitionsDefinition] = set() + for asset_key in resolved_selected_asset_keys: + partitions_def = self.asset_layer.get(asset_key).partitions_def + if partitions_def is not None: + unique_partitions_defs.add(partitions_def) + if len(unique_partitions_defs) == 1: - partitions_def = next(iter(unique_partitions_defs)) - elif len(unique_partitions_defs) > 1: - check.failed("Attempted to execute a run for assets with different partitions") + return check.not_none(next(iter(unique_partitions_defs))) + + if selected_asset_keys is not None: + check.failed("There is no PartitionsDefinition shared by all the provided assets") + else: + check.failed("Job has no PartitionsDefinition") - if partitions_def is None: - check.failed("Attempted to execute a partitioned run for a non-partitioned job") + def get_partition_keys( + self, selected_asset_keys: Optional[Iterable[AssetKey]] + ) -> Sequence[str]: + partitions_def = self._get_partitions_def(selected_asset_keys) + return partitions_def.get_partition_keys() + def validate_partition_key( + self, + partition_key: str, + dynamic_partitions_store: Optional["DynamicPartitionsStore"], + selected_asset_keys: Optional[Iterable[AssetKey]], + ) -> None: + """Ensures that the given partition_key is a member of the PartitionsDefinition + corresponding to every asset in the selection. + """ + partitions_def = self._get_partitions_def(selected_asset_keys) partitions_def.validate_partition_key( partition_key, dynamic_partitions_store=dynamic_partitions_store ) + + def get_tags_for_partition_key( + self, partition_key: str, selected_asset_keys: Optional[Iterable[AssetKey]] + ) -> Mapping[str, str]: + """Gets tags for the given partition key.""" + if self._partitioned_config is not None: + return self._partitioned_config.get_tags_for_partition_key(partition_key, self.name) + + partitions_def = self._get_partitions_def(selected_asset_keys) return partitions_def.get_tags_for_partition_key(partition_key) def get_run_config_for_partition_key(self, partition_key: str) -> Mapping[str, Any]: diff --git a/python_modules/dagster/dagster/_core/definitions/run_request.py b/python_modules/dagster/dagster/_core/definitions/run_request.py index 6797e635724a7..f877c49f5468b 100644 --- a/python_modules/dagster/dagster/_core/definitions/run_request.py +++ b/python_modules/dagster/dagster/_core/definitions/run_request.py @@ -203,12 +203,16 @@ def with_resolved_tags_and_config( if dynamic_partitions_store else None ) + target_definition.validate_partition_key( + self.partition_key, + dynamic_partitions_store=dynamic_partitions_store_after_requests, + selected_asset_keys=self.asset_selection, + ) + tags = { **(self.tags or {}), **target_definition.get_tags_for_partition_key( - self.partition_key, - dynamic_partitions_store=dynamic_partitions_store_after_requests, - selected_asset_keys=self.asset_selection, + self.partition_key, selected_asset_keys=self.asset_selection ), } diff --git a/python_modules/dagster/dagster/_core/remote_representation/code_location.py b/python_modules/dagster/dagster/_core/remote_representation/code_location.py index 20d78015b8f3c..e3c6f9da3dd3b 100644 --- a/python_modules/dagster/dagster/_core/remote_representation/code_location.py +++ b/python_modules/dagster/dagster/_core/remote_representation/code_location.py @@ -2,7 +2,18 @@ import threading from abc import abstractmethod from contextlib import AbstractContextManager -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + AbstractSet, + Any, + Dict, + Mapping, + Optional, + Sequence, + Tuple, + Union, + cast, +) import dagster._check as check from dagster._api.get_server_id import sync_get_server_id @@ -19,6 +30,7 @@ from dagster._api.snapshot_repository import sync_get_streaming_external_repositories_data_grpc from dagster._api.snapshot_schedule import sync_get_external_schedule_execution_data_grpc from dagster._core.code_pointer import CodePointer +from dagster._core.definitions.asset_key import AssetKey from dagster._core.definitions.reconstruct import ReconstructableJob from dagster._core.definitions.repository_definition import RepositoryDefinition from dagster._core.definitions.selector import JobSubsetSelector @@ -37,13 +49,13 @@ from dagster._core.remote_representation.external import ( ExternalExecutionPlan, ExternalJob, - ExternalPartitionSet, ExternalRepository, ) from dagster._core.remote_representation.external_data import ( ExternalPartitionNamesData, ExternalScheduleExecutionErrorData, ExternalSensorExecutionErrorData, + external_partition_set_name_for_job_name, external_repository_data_from_def, ) from dagster._core.remote_representation.grpc_server_registry import GrpcServerRegistry @@ -172,7 +184,7 @@ def get_subset_external_job_result( def get_external_partition_config( self, repository_handle: RepositoryHandle, - partition_set_name: str, + job_name: str, partition_name: str, instance: DagsterInstance, ) -> Union["ExternalPartitionConfigData", "ExternalPartitionExecutionErrorData"]: @@ -182,15 +194,20 @@ def get_external_partition_config( def get_external_partition_tags( self, repository_handle: RepositoryHandle, - partition_set_name: str, + job_name: str, partition_name: str, instance: DagsterInstance, + selected_asset_keys: Optional[AbstractSet[AssetKey]], ) -> Union["ExternalPartitionTagsData", "ExternalPartitionExecutionErrorData"]: pass @abstractmethod def get_external_partition_names( - self, external_partition_set: ExternalPartitionSet, instance: DagsterInstance + self, + repository_handle: RepositoryHandle, + job_name: str, + instance: DagsterInstance, + selected_asset_keys: Optional[AbstractSet[AssetKey]], ) -> Union["ExternalPartitionNamesData", "ExternalPartitionExecutionErrorData"]: pass @@ -428,17 +445,17 @@ def get_external_execution_plan( def get_external_partition_config( self, repository_handle: RepositoryHandle, - partition_set_name: str, + job_name: str, partition_name: str, instance: DagsterInstance, ) -> Union["ExternalPartitionConfigData", "ExternalPartitionExecutionErrorData"]: check.inst_param(repository_handle, "repository_handle", RepositoryHandle) - check.str_param(partition_set_name, "partition_set_name") + check.str_param(job_name, "job_name") check.str_param(partition_name, "partition_name") return get_partition_config( self._get_repo_def(repository_handle.repository_name), - partition_set_name=partition_set_name, + job_name=job_name, partition_key=partition_name, instance_ref=instance.get_ref(), ) @@ -446,37 +463,35 @@ def get_external_partition_config( def get_external_partition_tags( self, repository_handle: RepositoryHandle, - partition_set_name: str, + job_name: str, partition_name: str, instance: DagsterInstance, + selected_asset_keys: Optional[AbstractSet[AssetKey]], ) -> Union["ExternalPartitionTagsData", "ExternalPartitionExecutionErrorData"]: check.inst_param(repository_handle, "repository_handle", RepositoryHandle) - check.str_param(partition_set_name, "partition_set_name") + check.str_param(job_name, "job_name") check.str_param(partition_name, "partition_name") check.inst_param(instance, "instance", DagsterInstance) return get_partition_tags( self._get_repo_def(repository_handle.repository_name), - partition_set_name=partition_set_name, + job_name=job_name, partition_name=partition_name, instance_ref=instance.get_ref(), + selected_asset_keys=selected_asset_keys, ) def get_external_partition_names( - self, external_partition_set: ExternalPartitionSet, instance: DagsterInstance + self, + repository_handle: RepositoryHandle, + job_name: str, + instance: DagsterInstance, + selected_asset_keys: Optional[AbstractSet[AssetKey]], ) -> Union["ExternalPartitionNamesData", "ExternalPartitionExecutionErrorData"]: - check.inst_param(external_partition_set, "external_partition_set", ExternalPartitionSet) - - # Prefer to return the names without calling out to user code if the - # partition set allows it - if external_partition_set.has_partition_name_data(): - return ExternalPartitionNamesData( - partition_names=external_partition_set.get_partition_names(instance) - ) - return get_partition_names( - self._get_repo_def(external_partition_set.repository_handle.repository_name), - partition_set_name=external_partition_set.name, + self._get_repo_def(repository_handle.repository_name), + job_name=job_name, + selected_asset_keys=selected_asset_keys, ) def get_external_schedule_execution_data( @@ -826,47 +841,56 @@ def get_subset_external_job_result( def get_external_partition_config( self, repository_handle: RepositoryHandle, - partition_set_name: str, + job_name: str, partition_name: str, instance: DagsterInstance, ) -> "ExternalPartitionConfigData": check.inst_param(repository_handle, "repository_handle", RepositoryHandle) - check.str_param(partition_set_name, "partition_set_name") + check.str_param(job_name, "job_name") check.str_param(partition_name, "partition_name") return sync_get_external_partition_config_grpc( - self.client, repository_handle, partition_set_name, partition_name, instance + self.client, repository_handle, job_name, partition_name, instance ) def get_external_partition_tags( self, repository_handle: RepositoryHandle, - partition_set_name: str, + job_name: str, partition_name: str, instance: DagsterInstance, + selected_asset_keys: Optional[AbstractSet[AssetKey]], ) -> "ExternalPartitionTagsData": check.inst_param(repository_handle, "repository_handle", RepositoryHandle) - check.str_param(partition_set_name, "partition_set_name") + check.str_param(job_name, "job_name") check.str_param(partition_name, "partition_name") return sync_get_external_partition_tags_grpc( - self.client, repository_handle, partition_set_name, partition_name, instance + self.client, repository_handle, job_name, partition_name, instance, selected_asset_keys ) def get_external_partition_names( - self, external_partition_set: ExternalPartitionSet, instance: DagsterInstance + self, + repository_handle: RepositoryHandle, + job_name: str, + instance: DagsterInstance, + selected_asset_keys: Optional[AbstractSet[AssetKey]], ) -> "ExternalPartitionNamesData": - check.inst_param(external_partition_set, "external_partition_set", ExternalPartitionSet) + external_repo = self.get_repository(repository_handle.repository_name) + partition_set_name = external_partition_set_name_for_job_name(job_name) - # Prefer to return the names without calling out to user code if the - # partition set allows it - if external_partition_set.has_partition_name_data(): - return ExternalPartitionNamesData( - partition_names=external_partition_set.get_partition_names(instance=instance) - ) + # Prefer to return the names without calling out to user code if there's a corresponding + # partition set that allows it + if external_repo.has_external_partition_set(partition_set_name): + external_partition_set = external_repo.get_external_partition_set(partition_set_name) + + if external_partition_set.has_partition_name_data(): + return ExternalPartitionNamesData( + partition_names=external_partition_set.get_partition_names(instance=instance) + ) return sync_get_external_partition_names_grpc( - self.client, external_partition_set.repository_handle, external_partition_set.name + self.client, repository_handle, job_name, selected_asset_keys ) def get_external_schedule_execution_data( diff --git a/python_modules/dagster/dagster/_core/workspace/context.py b/python_modules/dagster/dagster/_core/workspace/context.py index 2b66e9ca2f990..ecde813391236 100644 --- a/python_modules/dagster/dagster/_core/workspace/context.py +++ b/python_modules/dagster/dagster/_core/workspace/context.py @@ -6,11 +6,24 @@ from abc import ABC, abstractmethod from contextlib import ExitStack from itertools import count -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence, Set, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + AbstractSet, + Any, + Dict, + Mapping, + Optional, + Sequence, + Set, + Type, + TypeVar, + Union, +) from typing_extensions import Self import dagster._check as check +from dagster._core.definitions.asset_key import AssetKey from dagster._core.definitions.selector import JobSubsetSelector from dagster._core.errors import DagsterCodeLocationLoadError, DagsterCodeLocationNotFoundError from dagster._core.execution.plan.state import KnownExecutionState @@ -20,7 +33,6 @@ CodeLocationOrigin, ExternalExecutionPlan, ExternalJob, - ExternalPartitionSet, GrpcServerCodeLocation, RepositoryHandle, ) @@ -256,7 +268,7 @@ def get_external_execution_plan( def get_external_partition_config( self, repository_handle: RepositoryHandle, - partition_set_name: str, + job_name: str, partition_name: str, instance: DagsterInstance, ) -> Union["ExternalPartitionConfigData", "ExternalPartitionExecutionErrorData"]: @@ -264,7 +276,7 @@ def get_external_partition_config( repository_handle.location_name ).get_external_partition_config( repository_handle=repository_handle, - partition_set_name=partition_set_name, + job_name=job_name, partition_name=partition_name, instance=instance, ) @@ -272,23 +284,32 @@ def get_external_partition_config( def get_external_partition_tags( self, repository_handle: RepositoryHandle, - partition_set_name: str, + job_name: str, partition_name: str, instance: DagsterInstance, + selected_asset_keys: Optional[AbstractSet[AssetKey]], ) -> Union["ExternalPartitionTagsData", "ExternalPartitionExecutionErrorData"]: return self.get_code_location(repository_handle.location_name).get_external_partition_tags( repository_handle=repository_handle, - partition_set_name=partition_set_name, + job_name=job_name, partition_name=partition_name, instance=instance, + selected_asset_keys=selected_asset_keys, ) def get_external_partition_names( - self, external_partition_set: ExternalPartitionSet, instance: DagsterInstance + self, + repository_handle: RepositoryHandle, + job_name: str, + instance: DagsterInstance, + selected_asset_keys: Optional[AbstractSet[AssetKey]], ) -> Union["ExternalPartitionNamesData", "ExternalPartitionExecutionErrorData"]: - return self.get_code_location( - external_partition_set.repository_handle.location_name - ).get_external_partition_names(external_partition_set, instance=instance) + return self.get_code_location(repository_handle.location_name).get_external_partition_names( + repository_handle=repository_handle, + job_name=job_name, + instance=instance, + selected_asset_keys=selected_asset_keys, + ) def get_external_partition_set_execution_param_data( self, diff --git a/python_modules/dagster/dagster/_grpc/impl.py b/python_modules/dagster/dagster/_grpc/impl.py index b8d5cfd1ba75d..c2de8165917f1 100644 --- a/python_modules/dagster/dagster/_grpc/impl.py +++ b/python_modules/dagster/dagster/_grpc/impl.py @@ -430,30 +430,22 @@ def _get_job_partitions_and_config_for_partition_set_name( def get_partition_config( repo_def: RepositoryDefinition, - partition_set_name: str, + job_name: str, partition_key: str, instance_ref: Optional[InstanceRef] = None, ) -> Union[ExternalPartitionConfigData, ExternalPartitionExecutionErrorData]: try: - ( - _, - partitions_def, - partitioned_config, - ) = _get_job_partitions_and_config_for_partition_set_name(repo_def, partition_set_name) + job_def = repo_def.get_job(job_name) - with _instance_from_ref_for_dynamic_partitions(instance_ref, partitions_def) as instance: - with user_code_error_boundary( - PartitionExecutionError, - lambda: ( - "Error occurred during the evaluation of the `run_config_for_partition`" - f" function for partition set {partition_set_name}" - ), - ): - partitions_def.validate_partition_key( - partition_key, dynamic_partitions_store=instance - ) - run_config = partitioned_config.get_run_config_for_partition_key(partition_key) - return ExternalPartitionConfigData(name=partition_key, run_config=run_config) + with user_code_error_boundary( + PartitionExecutionError, + lambda: ( + "Error occurred during the evaluation of the `run_config_for_partition`" + f" function for job {job_name}" + ), + ): + run_config = job_def.get_run_config_for_partition_key(partition_key) + return ExternalPartitionConfigData(name=partition_key, run_config=run_config) except Exception: return ExternalPartitionExecutionErrorData( error=serializable_error_info_from_exc_info(sys.exc_info()) @@ -462,14 +454,11 @@ def get_partition_config( def get_partition_names( repo_def: RepositoryDefinition, - partition_set_name: str, + job_name: str, + selected_asset_keys: Optional[AbstractSet[AssetKey]], ) -> Union[ExternalPartitionNamesData, ExternalPartitionExecutionErrorData]: try: - ( - job_def, - partitions_def, - _, - ) = _get_job_partitions_and_config_for_partition_set_name(repo_def, partition_set_name) + job_def = repo_def.get_job(job_name) with user_code_error_boundary( PartitionExecutionError, @@ -478,7 +467,9 @@ def get_partition_names( f" partitioned config on job '{job_def.name}'" ), ): - return ExternalPartitionNamesData(partition_names=partitions_def.get_partition_keys()) + return ExternalPartitionNamesData( + partition_names=job_def.get_partition_keys(selected_asset_keys) + ) except Exception: return ExternalPartitionExecutionErrorData( error=serializable_error_info_from_exc_info(sys.exc_info()) @@ -487,35 +478,25 @@ def get_partition_names( def get_partition_tags( repo_def: RepositoryDefinition, - partition_set_name: str, + job_name: str, partition_name: str, + selected_asset_keys: Optional[AbstractSet[AssetKey]], instance_ref: Optional[InstanceRef] = None, -): +) -> Union[ExternalPartitionTagsData, ExternalPartitionExecutionErrorData]: try: - ( - job_def, - partitions_def, - partitioned_config, - ) = _get_job_partitions_and_config_for_partition_set_name(repo_def, partition_set_name) - - # Certain gRPC servers do not have access to the instance, so we only attempt to instantiate - # the instance when necessary for dynamic partitions: https://github.com/dagster-io/dagster/issues/12440 + job_def = repo_def.get_job(job_name) - with _instance_from_ref_for_dynamic_partitions(instance_ref, partitions_def) as instance: - with user_code_error_boundary( - PartitionExecutionError, - lambda: ( - "Error occurred during the evaluation of the `tags_for_partition` function for" - f" partitioned config on job '{job_def.name}'" - ), - ): - partitions_def.validate_partition_key( - partition_name, dynamic_partitions_store=instance - ) - tags = partitioned_config.get_tags_for_partition_key( - partition_name, job_name=job_def.name - ) - return ExternalPartitionTagsData(name=partition_name, tags=tags) + with user_code_error_boundary( + PartitionExecutionError, + lambda: ( + "Error occurred during the evaluation of the `tags_for_partition` function for" + f" partitioned config on job '{job_def.name}'" + ), + ): + tags = job_def.get_tags_for_partition_key( + partition_name, selected_asset_keys=selected_asset_keys + ) + return ExternalPartitionTagsData(name=partition_name, tags=tags) except Exception: return ExternalPartitionExecutionErrorData( diff --git a/python_modules/dagster/dagster/_grpc/server.py b/python_modules/dagster/dagster/_grpc/server.py index 121345f5a0b72..66c740e442dcb 100644 --- a/python_modules/dagster/dagster/_grpc/server.py +++ b/python_modules/dagster/dagster/_grpc/server.py @@ -613,13 +613,14 @@ def ExternalPartitionNames( ) -> api_pb2.ExternalPartitionNamesReply: try: partition_names_args = deserialize_value( - request.serialized_partition_names_args, - PartitionNamesArgs, + request.serialized_partition_names_args, PartitionNamesArgs ) + serialized_response = serialize_value( get_partition_names( self._get_repo_for_origin(partition_names_args.repository_origin), - partition_names_args.partition_set_name, + job_name=partition_names_args.job_name, + selected_asset_keys=partition_names_args.selected_asset_keys, ) ) except Exception: @@ -681,8 +682,8 @@ def ExternalPartitionConfig( serialized_data = serialize_value( get_partition_config( self._get_repo_for_origin(args.repository_origin), - args.partition_set_name, - args.partition_name, + job_name=args.job_name, + partition_key=args.partition_name, instance_ref=instance_ref, ) ) @@ -710,8 +711,9 @@ def ExternalPartitionTags( serialized_data = serialize_value( get_partition_tags( self._get_repo_for_origin(partition_args.repository_origin), - partition_args.partition_set_name, - partition_args.partition_name, + job_name=partition_args.job_name, + partition_name=partition_args.partition_name, + selected_asset_keys=partition_args.selected_asset_keys, instance_ref=instance_ref, ) ) diff --git a/python_modules/dagster/dagster_tests/api_tests/test_api_snapshot_partition.py b/python_modules/dagster/dagster_tests/api_tests/test_api_snapshot_partition.py index c50a4470cc40e..a97081acc557f 100644 --- a/python_modules/dagster/dagster_tests/api_tests/test_api_snapshot_partition.py +++ b/python_modules/dagster/dagster_tests/api_tests/test_api_snapshot_partition.py @@ -1,12 +1,15 @@ import string import pytest +from dagster import AssetKey, ConfigurableResource, Definitions, StaticPartitionsDefinition, asset from dagster._api.snapshot_partition import ( sync_get_external_partition_config_grpc, sync_get_external_partition_names_grpc, sync_get_external_partition_set_execution_param_data_grpc, sync_get_external_partition_tags_grpc, ) +from dagster._core.definitions.asset_job import IMPLICIT_ASSET_JOB_NAME +from dagster._core.definitions.repository_definition import SINGLETON_REPOSITORY_NAME from dagster._core.errors import DagsterUserCodeProcessError from dagster._core.instance import DagsterInstance from dagster._core.remote_representation import ( @@ -16,22 +19,74 @@ ExternalPartitionSetExecutionParamData, ExternalPartitionTagsData, ) +from dagster._core.test_utils import ensure_dagster_tests_import from dagster._grpc.types import PartitionArgs, PartitionNamesArgs, PartitionSetExecutionParamArgs from dagster._serdes import deserialize_value -from .utils import get_bar_repo_code_location +ensure_dagster_tests_import() + +from dagster_tests.api_tests.utils import get_bar_repo_code_location, get_code_location # noqa: I001 + + +def get_repo_with_differently_partitioned_assets(): + @asset(partitions_def=StaticPartitionsDefinition(["1", "2"])) + def asset1(): ... + + ab_partitions_def = StaticPartitionsDefinition(["a", "b"]) + + @asset(partitions_def=ab_partitions_def) + def asset2(): ... + + class MyResource(ConfigurableResource): + foo: str + + @asset(partitions_def=ab_partitions_def) + def asset3(resource1: MyResource): ... + + return Definitions( + assets=[asset1, asset2, asset3], resources={"resource1": MyResource(foo="bar")} + ).get_repository_def() def test_external_partition_names_grpc(instance: DagsterInstance): with get_bar_repo_code_location(instance) as code_location: repository_handle = code_location.get_repository("bar_repo").handle data = sync_get_external_partition_names_grpc( - code_location.client, repository_handle, "baz_partition_set" + code_location.client, repository_handle, "baz", None ) assert isinstance(data, ExternalPartitionNamesData) assert data.partition_names == list(string.ascii_lowercase) +def test_external_partition_names(instance: DagsterInstance): + with get_bar_repo_code_location(instance) as code_location: + data = code_location.get_external_partition_names( + repository_handle=code_location.get_repository("bar_repo").handle, + job_name="baz", + instance=instance, + selected_asset_keys=None, + ) + assert isinstance(data, ExternalPartitionNamesData) + assert data.partition_names == list(string.ascii_lowercase) + + +def test_external_partition_names_asset_selection(instance: DagsterInstance): + with get_code_location( + python_file=__file__, + attribute="get_repo_with_differently_partitioned_assets", + location_name="something", + instance=instance, + ) as code_location: + data = code_location.get_external_partition_names( + repository_handle=code_location.get_repository(SINGLETON_REPOSITORY_NAME).handle, + job_name=IMPLICIT_ASSET_JOB_NAME, + instance=instance, + selected_asset_keys={AssetKey("asset2"), AssetKey("asset3")}, + ) + assert isinstance(data, ExternalPartitionNamesData) + assert data.partition_names == ["a", "b"] + + def test_external_partition_names_deserialize_error_grpc(instance: DagsterInstance): with get_bar_repo_code_location(instance) as code_location: api_client = code_location.client @@ -57,24 +112,52 @@ def test_external_partitions_config_grpc(instance: DagsterInstance): repository_handle = code_location.get_repository("bar_repo").handle data = sync_get_external_partition_config_grpc( - code_location.client, repository_handle, "baz_partition_set", "c", instance + code_location.client, repository_handle, "baz", "c", instance + ) + assert isinstance(data, ExternalPartitionConfigData) + assert data.run_config + assert data.run_config["ops"]["do_input"]["inputs"]["x"]["value"] == "c" # type: ignore + + +def test_external_partition_config(instance: DagsterInstance): + with get_bar_repo_code_location(instance) as code_location: + data = code_location.get_external_partition_config( + job_name="baz", + repository_handle=code_location.get_repository("bar_repo").handle, + partition_name="c", + instance=instance, ) + assert isinstance(data, ExternalPartitionConfigData) assert data.run_config assert data.run_config["ops"]["do_input"]["inputs"]["x"]["value"] == "c" # type: ignore +def test_external_partition_config_different_partitions_defs(instance: DagsterInstance): + with get_code_location( + python_file=__file__, + attribute="get_repo_with_differently_partitioned_assets", + location_name="something", + instance=instance, + ) as code_location: + data = code_location.get_external_partition_config( + job_name=IMPLICIT_ASSET_JOB_NAME, + repository_handle=code_location.get_repository(SINGLETON_REPOSITORY_NAME).handle, + selected_asset_keys={AssetKey("asset2"), AssetKey("asset3")}, + partition_name="b", + instance=instance, + ) + assert isinstance(data, ExternalPartitionConfigData) + assert data.run_config == {} + + def test_external_partitions_config_error_grpc(instance: DagsterInstance): with get_bar_repo_code_location(instance) as code_location: repository_handle = code_location.get_repository("bar_repo").handle with pytest.raises(DagsterUserCodeProcessError): sync_get_external_partition_config_grpc( - code_location.client, - repository_handle, - "error_partition_config", - "c", - instance, + code_location.client, repository_handle, "error_partition_config", "c", instance ) @@ -105,13 +188,52 @@ def test_external_partitions_tags_grpc(instance: DagsterInstance): repository_handle = code_location.get_repository("bar_repo").handle data = sync_get_external_partition_tags_grpc( - code_location.client, repository_handle, "baz_partition_set", "c", instance=instance + code_location.client, + repository_handle, + "baz_partition_set", + "c", + instance=instance, + selected_asset_keys=None, ) assert isinstance(data, ExternalPartitionTagsData) assert data.tags assert data.tags["foo"] == "bar" +def test_external_partition_tags(instance: DagsterInstance): + with get_bar_repo_code_location(instance) as code_location: + data = code_location.get_external_partition_tags( + repository_handle=code_location.get_repository("bar_repo").handle, + job_name="baz", + partition_name="c", + instance=instance, + selected_asset_keys=None, + ) + + assert isinstance(data, ExternalPartitionTagsData) + assert data.tags + assert data.tags["foo"] == "bar" + + +def test_external_partition_tags_different_partitions_defs(instance: DagsterInstance): + with get_code_location( + python_file=__file__, + attribute="get_repo_with_differently_partitioned_assets", + location_name="something", + instance=instance, + ) as code_location: + data = code_location.get_external_partition_tags( + repository_handle=code_location.get_repository(SINGLETON_REPOSITORY_NAME).handle, + job_name=IMPLICIT_ASSET_JOB_NAME, + selected_asset_keys={AssetKey("asset2"), AssetKey("asset3")}, + partition_name="b", + instance=instance, + ) + assert isinstance(data, ExternalPartitionTagsData) + assert data.tags + assert data.tags["dagster/partition"] == "b" + + def test_external_partitions_tags_deserialize_error_grpc(instance: DagsterInstance): with get_bar_repo_code_location(instance) as code_location: repository_handle = code_location.get_repository("bar_repo").handle @@ -140,7 +262,7 @@ def test_external_partitions_tags_error_grpc(instance: DagsterInstance): with pytest.raises(DagsterUserCodeProcessError): sync_get_external_partition_tags_grpc( - code_location.client, repository_handle, "error_partition_tags", "c", instance + code_location.client, repository_handle, "error_partition_tags", "c", instance, None ) @@ -197,22 +319,14 @@ def test_dynamic_partition_set_grpc(instance: DagsterInstance): assert len(data.partition_data) == 3 data = sync_get_external_partition_config_grpc( - code_location.client, - repository_handle, - "dynamic_job_partition_set", - "a", - instance, + code_location.client, repository_handle, "dynamic_job", "a", instance ) assert isinstance(data, ExternalPartitionConfigData) assert data.name == "a" assert data.run_config == {} data = sync_get_external_partition_tags_grpc( - code_location.client, - repository_handle, - "dynamic_job_partition_set", - "a", - instance, + code_location.client, repository_handle, "dynamic_job", "a", instance, None ) assert isinstance(data, ExternalPartitionTagsData) assert data.tags @@ -232,7 +346,7 @@ def test_dynamic_partition_set_grpc(instance: DagsterInstance): sync_get_external_partition_config_grpc( code_location.client, repository_handle, - "dynamic_job_partition_set", + "dynamic_job", "nonexistent_partition", instance, ) @@ -241,7 +355,8 @@ def test_dynamic_partition_set_grpc(instance: DagsterInstance): sync_get_external_partition_tags_grpc( code_location.client, repository_handle, - "dynamic_job_partition_set", + "dynamic_job", "nonexistent_partition", instance, + None, ) diff --git a/python_modules/dagster/dagster_tests/api_tests/utils.py b/python_modules/dagster/dagster_tests/api_tests/utils.py index 5fb01961a60af..e77193bdf453d 100644 --- a/python_modules/dagster/dagster_tests/api_tests/utils.py +++ b/python_modules/dagster/dagster_tests/api_tests/utils.py @@ -48,6 +48,28 @@ def get_bar_repo_code_location( yield location +@contextmanager +def get_code_location( + python_file: str, + attribute: str, + location_name: str, + instance: Optional[DagsterInstance] = None, +) -> Iterator[GrpcServerCodeLocation]: + with ExitStack() as stack: + if not instance: + instance = stack.enter_context(instance_for_test()) + + loadable_target_origin = LoadableTargetOrigin( + executable_path=sys.executable, + python_file=python_file, + attribute=attribute, + ) + origin = ManagedGrpcPythonEnvCodeLocationOrigin(loadable_target_origin, location_name) + + with origin.create_single_location(instance) as location: + yield location + + @contextmanager def get_bar_repo_handle(instance: Optional[DagsterInstance] = None) -> Iterator[RepositoryHandle]: with ExitStack() as stack: diff --git a/python_modules/dagster/dagster_tests/cli_tests/command_tests/test_materialize_command.py b/python_modules/dagster/dagster_tests/cli_tests/command_tests/test_materialize_command.py index de3b266215437..87ca70a58732e 100644 --- a/python_modules/dagster/dagster_tests/cli_tests/command_tests/test_materialize_command.py +++ b/python_modules/dagster/dagster_tests/cli_tests/command_tests/test_materialize_command.py @@ -117,7 +117,7 @@ def test_one_of_the_asset_keys_missing(): def test_conflicting_partitions(): with instance_for_test(): result = invoke_materialize("partitioned_asset,differently_partitioned_asset", "one") - assert "Attempted to execute a run for assets with different partitions" in str( + assert "There is no PartitionsDefinition shared by all the provided asset" in str( result.exception )