Skip to content

Commit

Permalink
grpc
Browse files Browse the repository at this point in the history
branch-name: partition-set-to-job/grpc
  • Loading branch information
sryza committed Aug 8, 2024
1 parent 34a855a commit 8500be5
Show file tree
Hide file tree
Showing 15 changed files with 401 additions and 185 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,10 @@ def create_and_launch_partition_backfill(

if backfill_params.get("allPartitions"):
result = graphene_info.context.get_external_partition_names(
external_partition_set, instance=graphene_info.context.instance
repository_handle=repository.handle,
job_name=external_partition_set.job_name,
instance=graphene_info.context.instance,
selected_asset_keys=None,
)
if isinstance(result, ExternalPartitionExecutionErrorData):
raise DagsterUserCodeProcessError.from_error_info(result.error)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Optional, Sequence, Union
from typing import TYPE_CHECKING, AbstractSet, Optional, Sequence, Union

import dagster._check as check
from dagster._core.definitions.asset_key import AssetKey
from dagster._core.definitions.selector import RepositorySelector
from dagster._core.errors import DagsterUserCodeProcessError
from dagster._core.remote_representation import ExternalPartitionSet, RepositoryHandle
Expand Down Expand Up @@ -107,20 +108,18 @@ def get_partition_by_name(
def get_partition_config(
graphene_info: ResolveInfo,
repository_handle: RepositoryHandle,
partition_set_name: str,
job_name: str,
partition_name: str,
selected_asset_keys: Optional[AbstractSet[AssetKey]],
) -> "GraphenePartitionRunConfig":
from ..schema.partition_sets import GraphenePartitionRunConfig

check.inst_param(repository_handle, "repository_handle", RepositoryHandle)
check.str_param(partition_set_name, "partition_set_name")
check.str_param(job_name, "job_name")
check.str_param(partition_name, "partition_name")

result = graphene_info.context.get_external_partition_config(
repository_handle,
partition_set_name,
partition_name,
graphene_info.context.instance,
repository_handle, job_name, partition_name, graphene_info.context.instance
)

if isinstance(result, ExternalPartitionExecutionErrorData):
Expand All @@ -132,18 +131,23 @@ def get_partition_config(
def get_partition_tags(
graphene_info: ResolveInfo,
repository_handle: RepositoryHandle,
partition_set_name: str,
job_name: str,
partition_name: str,
selected_asset_keys: Optional[AbstractSet[AssetKey]],
) -> "GraphenePartitionTags":
from ..schema.partition_sets import GraphenePartitionTags
from ..schema.tags import GraphenePipelineTag

check.inst_param(repository_handle, "repository_handle", RepositoryHandle)
check.str_param(partition_set_name, "partition_set_name")
check.str_param(job_name, "job_name")
check.str_param(partition_name, "partition_name")

result = graphene_info.context.get_external_partition_tags(
repository_handle, partition_set_name, partition_name, graphene_info.context.instance
repository_handle,
job_name,
partition_name,
graphene_info.context.instance,
selected_asset_keys=selected_asset_keys,
)

if isinstance(result, ExternalPartitionExecutionErrorData):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ExternalPartitionsDefinitionData,
ExternalStaticPartitionsDefinitionData,
ExternalTimeWindowPartitionsDefinitionData,
job_name_for_external_partition_set_name,
)
from dagster._core.storage.dagster_run import RunsFilter
from dagster._core.storage.tags import PARTITION_NAME_TAG, PARTITION_SET_TAG
Expand Down Expand Up @@ -227,17 +228,19 @@ def resolve_runConfigOrError(self, graphene_info: ResolveInfo):
return get_partition_config(
graphene_info,
self._external_repository_handle,
self._external_partition_set.name,
job_name_for_external_partition_set_name(self._external_partition_set.name),
self._partition_name,
selected_asset_keys=None,
)

@capture_error
def resolve_tagsOrError(self, graphene_info: ResolveInfo):
return get_partition_tags(
graphene_info,
self._external_repository_handle,
self._external_partition_set.name,
job_name_for_external_partition_set_name(self._external_partition_set.name),
self._partition_name,
selected_asset_keys=None,
)

def resolve_runs(
Expand Down Expand Up @@ -326,8 +329,10 @@ def __init__(
def _get_partition_names(self, graphene_info: ResolveInfo) -> Sequence[str]:
if self._partition_names is None:
result = graphene_info.context.get_external_partition_names(
self._external_partition_set,
repository_handle=self._external_repository_handle,
job_name=self._external_partition_set.job_name,
instance=graphene_info.context.instance,
selected_asset_keys=None,
)
if isinstance(result, ExternalPartitionExecutionErrorData):
raise DagsterUserCodeProcessError.from_error_info(result.error)
Expand Down
37 changes: 21 additions & 16 deletions python_modules/dagster/dagster/_api/snapshot_partition.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, AbstractSet, Optional, Sequence

import dagster._check as check
from dagster._core.definitions.asset_key import AssetKey
from dagster._core.errors import DagsterUserCodeProcessError
from dagster._core.instance import DagsterInstance
from dagster._core.remote_representation.external_data import (
Expand All @@ -9,7 +10,7 @@
ExternalPartitionNamesData,
ExternalPartitionSetExecutionParamData,
ExternalPartitionTagsData,
job_name_for_external_partition_set_name,
external_partition_set_name_for_job_name,
)
from dagster._core.remote_representation.handle import RepositoryHandle
from dagster._grpc.types import PartitionArgs, PartitionNamesArgs, PartitionSetExecutionParamArgs
Expand All @@ -20,21 +21,24 @@


def sync_get_external_partition_names_grpc(
api_client: "DagsterGrpcClient", repository_handle: RepositoryHandle, partition_set_name: str
api_client: "DagsterGrpcClient",
repository_handle: RepositoryHandle,
job_name: str,
selected_asset_keys: Optional[AbstractSet[AssetKey]],
) -> ExternalPartitionNamesData:
from dagster._grpc.client import DagsterGrpcClient

check.inst_param(api_client, "api_client", DagsterGrpcClient)
check.inst_param(repository_handle, "repository_handle", RepositoryHandle)
check.str_param(partition_set_name, "partition_set_name")
check.str_param(job_name, "job_name")
repository_origin = repository_handle.get_external_origin()
result = deserialize_value(
api_client.external_partition_names(
partition_names_args=PartitionNamesArgs(
repository_origin=repository_origin,
job_name=job_name_for_external_partition_set_name(partition_set_name),
partition_set_name=partition_set_name,
selected_asset_keys=None,
job_name=job_name,
partition_set_name=external_partition_set_name_for_job_name(job_name),
selected_asset_keys=selected_asset_keys,
),
),
(ExternalPartitionNamesData, ExternalPartitionExecutionErrorData),
Expand All @@ -48,23 +52,23 @@ def sync_get_external_partition_names_grpc(
def sync_get_external_partition_config_grpc(
api_client: "DagsterGrpcClient",
repository_handle: RepositoryHandle,
partition_set_name: str,
job_name: str,
partition_name: str,
instance: DagsterInstance,
) -> ExternalPartitionConfigData:
from dagster._grpc.client import DagsterGrpcClient

check.inst_param(api_client, "api_client", DagsterGrpcClient)
check.inst_param(repository_handle, "repository_handle", RepositoryHandle)
check.str_param(partition_set_name, "partition_set_name")
check.str_param(job_name, "job_name")
check.str_param(partition_name, "partition_name")
repository_origin = repository_handle.get_external_origin()
result = deserialize_value(
api_client.external_partition_config(
partition_args=PartitionArgs(
repository_origin=repository_origin,
job_name=job_name_for_external_partition_set_name(partition_set_name),
partition_set_name=partition_set_name,
job_name=job_name,
partition_set_name=external_partition_set_name_for_job_name(job_name),
partition_name=partition_name,
instance_ref=instance.get_ref(),
selected_asset_keys=None,
Expand All @@ -81,27 +85,28 @@ def sync_get_external_partition_config_grpc(
def sync_get_external_partition_tags_grpc(
api_client: "DagsterGrpcClient",
repository_handle: RepositoryHandle,
partition_set_name: str,
job_name: str,
partition_name: str,
instance: DagsterInstance,
selected_asset_keys: Optional[AbstractSet[AssetKey]],
) -> ExternalPartitionTagsData:
from dagster._grpc.client import DagsterGrpcClient

check.inst_param(api_client, "api_client", DagsterGrpcClient)
check.inst_param(repository_handle, "repository_handle", RepositoryHandle)
check.str_param(partition_set_name, "partition_set_name")
check.str_param(job_name, "job_name")
check.str_param(partition_name, "partition_name")

repository_origin = repository_handle.get_external_origin()
result = deserialize_value(
api_client.external_partition_tags(
partition_args=PartitionArgs(
repository_origin=repository_origin,
job_name=job_name_for_external_partition_set_name(partition_set_name),
partition_set_name=partition_set_name,
job_name=job_name,
partition_set_name=external_partition_set_name_for_job_name(job_name),
partition_name=partition_name,
instance_ref=instance.get_ref(),
selected_asset_keys=None,
selected_asset_keys=selected_asset_keys,
),
),
(ExternalPartitionTagsData, ExternalPartitionExecutionErrorData),
Expand Down
5 changes: 4 additions & 1 deletion python_modules/dagster/dagster/_cli/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,12 @@ def execute_materialize_command(instance: DagsterInstance, kwargs: Mapping[str,
check.failed("Provided '--partition' option, but none of the assets are partitioned")

try:
tags = implicit_job_def.get_tags_for_partition_key(
implicit_job_def.validate_partition_key(
partition, selected_asset_keys=asset_keys, dynamic_partitions_store=instance
)
tags = implicit_job_def.get_tags_for_partition_key(
partition, selected_asset_keys=asset_keys
)
except DagsterUnknownPartitionError:
raise DagsterInvalidSubsetError(
"All selected assets must have a PartitionsDefinition containing the passed"
Expand Down
5 changes: 4 additions & 1 deletion python_modules/dagster/dagster/_cli/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,10 @@ def _execute_backfill_command_at_location(

try:
partition_names_or_error = code_location.get_external_partition_names(
job_partition_set, instance=instance
repository_handle=repo_handle,
job_name=external_job.name,
instance=instance,
selected_asset_keys=None,
)
except Exception as e:
error_info = serializable_error_info_from_exc_info(sys.exc_info())
Expand Down
66 changes: 45 additions & 21 deletions python_modules/dagster/dagster/_core/definitions/job_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,11 +685,15 @@ def execute_in_process(

merged_tags = merge_dicts(self.tags, tags or {})
if partition_key:
tags_for_partition_key = ephemeral_job.get_tags_for_partition_key(
ephemeral_job.validate_partition_key(
partition_key,
selected_asset_keys=asset_selection,
dynamic_partitions_store=instance,
)
tags_for_partition_key = ephemeral_job.get_tags_for_partition_key(
partition_key,
selected_asset_keys=asset_selection,
)

if not run_config and self.partitioned_config:
run_config = self.partitioned_config.get_run_config_for_partition_key(partition_key)
Expand All @@ -714,18 +718,11 @@ def execute_in_process(
asset_selection=frozenset(asset_selection),
)

def get_tags_for_partition_key(
self,
partition_key: str,
dynamic_partitions_store: Optional["DynamicPartitionsStore"],
selected_asset_keys: Optional[Iterable[AssetKey]],
) -> Mapping[str, str]:
"""Gets tags for the given partition key and ensures that it's a member of the PartitionsDefinition
corresponding to every asset in the selection.
"""
partitions_def = None
def _get_partitions_def(
self, selected_asset_keys: Optional[Iterable[AssetKey]]
) -> PartitionsDefinition:
if self.partitions_def:
partitions_def = self.partitions_def
return self.partitions_def
elif self.asset_layer:
if selected_asset_keys:
resolved_selected_asset_keys = selected_asset_keys
Expand All @@ -736,21 +733,48 @@ def get_tags_for_partition_key(
key for key in self.asset_layer.asset_keys_by_node_output_handle.values()
]

unique_partitions_defs = {
self.asset_layer.get(asset_key).partitions_def
for asset_key in resolved_selected_asset_keys
} - {None}
unique_partitions_defs: Set[PartitionsDefinition] = set()
for asset_key in resolved_selected_asset_keys:
partitions_def = self.asset_layer.get(asset_key).partitions_def
if partitions_def is not None:
unique_partitions_defs.add(partitions_def)

if len(unique_partitions_defs) == 1:
partitions_def = next(iter(unique_partitions_defs))
elif len(unique_partitions_defs) > 1:
check.failed("Attempted to execute a run for assets with different partitions")
return check.not_none(next(iter(unique_partitions_defs)))

if selected_asset_keys is not None:
check.failed("There is no PartitionsDefinition shared by all the provided assets")
else:
check.failed("Job has no PartitionsDefinition")

if partitions_def is None:
check.failed("Attempted to execute a partitioned run for a non-partitioned job")
def get_partition_keys(
    self, selected_asset_keys: Optional[Iterable[AssetKey]]
) -> Sequence[str]:
    """Return the partition keys of the PartitionsDefinition resolved for the selection.

    Args:
        selected_asset_keys: Asset selection handed unchanged to
            ``_get_partitions_def`` to pick the PartitionsDefinition.

    Returns:
        Whatever ``get_partition_keys()`` on the resolved PartitionsDefinition returns.
    """
    # NOTE(review): unlike validate_partition_key, no dynamic_partitions_store is
    # forwarded here -- confirm this is intended for dynamically partitioned jobs.
    return self._get_partitions_def(selected_asset_keys).get_partition_keys()

def validate_partition_key(
    self,
    partition_key: str,
    dynamic_partitions_store: Optional["DynamicPartitionsStore"],
    selected_asset_keys: Optional[Iterable[AssetKey]],
) -> None:
    """Check that ``partition_key`` belongs to the PartitionsDefinition for the selection.

    The PartitionsDefinition is resolved through ``_get_partitions_def`` from
    ``selected_asset_keys``; the actual membership check is delegated to that
    definition's own ``validate_partition_key``.

    Args:
        partition_key: The partition key to check.
        dynamic_partitions_store: Forwarded by keyword to the
            PartitionsDefinition's ``validate_partition_key``.
        selected_asset_keys: Asset selection used to resolve the PartitionsDefinition.
    """
    self._get_partitions_def(selected_asset_keys).validate_partition_key(
        partition_key, dynamic_partitions_store=dynamic_partitions_store
    )

def get_tags_for_partition_key(
    self, partition_key: str, selected_asset_keys: Optional[Iterable[AssetKey]]
) -> Mapping[str, str]:
    """Look up the tags for ``partition_key``.

    When the job has a partitioned config, the tags come from that config
    (which is also given the job name). Otherwise they come from the
    PartitionsDefinition resolved through ``_get_partitions_def``.

    No membership check is performed here; use ``validate_partition_key``
    for that.

    Args:
        partition_key: The partition key whose tags are requested.
        selected_asset_keys: Asset selection used to resolve the
            PartitionsDefinition when there is no partitioned config.
    """
    if self._partitioned_config is None:
        # No partitioned config: fall back to the resolved PartitionsDefinition.
        resolved_def = self._get_partitions_def(selected_asset_keys)
        return resolved_def.get_tags_for_partition_key(partition_key)

    return self._partitioned_config.get_tags_for_partition_key(partition_key, self.name)

def get_run_config_for_partition_key(self, partition_key: str) -> Mapping[str, Any]:
Expand Down
10 changes: 7 additions & 3 deletions python_modules/dagster/dagster/_core/definitions/run_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,16 @@ def with_resolved_tags_and_config(
if dynamic_partitions_store
else None
)
target_definition.validate_partition_key(
self.partition_key,
dynamic_partitions_store=dynamic_partitions_store_after_requests,
selected_asset_keys=self.asset_selection,
)

tags = {
**(self.tags or {}),
**target_definition.get_tags_for_partition_key(
self.partition_key,
dynamic_partitions_store=dynamic_partitions_store_after_requests,
selected_asset_keys=self.asset_selection,
self.partition_key, selected_asset_keys=self.asset_selection
),
}

Expand Down
Loading

0 comments on commit 8500be5

Please sign in to comment.