Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make job_name param on PartitionArgs and PartitionNames args optional #23983

Merged
merged 2 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python_modules/dagster/dagster/_grpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,8 +636,8 @@ def ExternalPartitionNames(
serialized_response = serialize_value(
get_partition_names(
self._get_repo_for_origin(partition_names_args.repository_origin),
job_name=partition_names_args.job_name,
selected_asset_keys=partition_names_args.selected_asset_keys,
job_name=partition_names_args.get_job_name(),
)
)
except Exception:
Expand Down Expand Up @@ -701,7 +701,7 @@ def ExternalPartitionConfig(
serialized_data = serialize_value(
get_partition_config(
self._get_repo_for_origin(args.repository_origin),
job_name=args.job_name,
job_name=args.get_job_name(),
partition_key=args.partition_name,
instance_ref=instance_ref,
)
Expand Down Expand Up @@ -731,7 +731,7 @@ def ExternalPartitionTags(
serialized_data = serialize_value(
get_partition_tags(
self._get_repo_for_origin(partition_args.repository_origin),
job_name=partition_args.job_name,
job_name=partition_args.get_job_name(),
partition_name=partition_args.partition_name,
selected_asset_keys=partition_args.selected_asset_keys,
instance_ref=instance_ref,
Expand Down
29 changes: 22 additions & 7 deletions python_modules/dagster/dagster/_grpc/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from dagster._core.execution.retries import RetryMode
from dagster._core.instance.ref import InstanceRef
from dagster._core.origin import JobPythonOrigin, get_python_environment_entry_point
from dagster._core.remote_representation.external_data import DEFAULT_MODE_NAME
from dagster._core.remote_representation.external_data import (
DEFAULT_MODE_NAME,
job_name_for_external_partition_set_name,
)
from dagster._core.remote_representation.origin import (
CodeLocationOrigin,
RemoteJobOrigin,
Expand Down Expand Up @@ -405,10 +408,10 @@ class PartitionArgs(
"_PartitionArgs",
[
("repository_origin", RemoteRepositoryOrigin),
("job_name", str),
# This is here for backcompat. it's expected to always be f"{job_name}_partition_set".
("partition_set_name", str),
("partition_name", str),
("job_name", Optional[str]),
("instance_ref", Optional[InstanceRef]),
# This is introduced in the same release that we're making it possible for an asset job
# to target assets with different PartitionsDefinitions. Prior user code versions can
Expand All @@ -421,9 +424,9 @@ class PartitionArgs(
def __new__(
cls,
repository_origin: RemoteRepositoryOrigin,
job_name: str,
partition_set_name: str,
partition_name: str,
job_name: Optional[str] = None,
instance_ref: Optional[InstanceRef] = None,
selected_asset_keys: Optional[AbstractSet[AssetKey]] = None,
):
Expand All @@ -435,51 +438,63 @@ def __new__(
RemoteRepositoryOrigin,
),
partition_set_name=check.str_param(partition_set_name, "partition_set_name"),
job_name=check.str_param(job_name, "job_name"),
job_name=check.opt_str_param(job_name, "job_name"),
partition_name=check.str_param(partition_name, "partition_name"),
instance_ref=check.opt_inst_param(instance_ref, "instance_ref", InstanceRef),
selected_asset_keys=check.opt_nullable_set_param(
selected_asset_keys, "selected_asset_keys", of_type=AssetKey
),
)

def get_job_name(self) -> str:
if self.job_name:
return self.job_name
else:
return job_name_for_external_partition_set_name(self.partition_set_name)


@whitelist_for_serdes
class PartitionNamesArgs(
NamedTuple(
"_PartitionNamesArgs",
[
("repository_origin", RemoteRepositoryOrigin),
("job_name", str),
# This is here for backcompat. it's expected to always be f"{job_name}_partition_set".
("partition_set_name", str),
# This is introduced in the same release that we're making it possible for an asset job
# to target assets with different PartitionsDefinitions. Prior user code versions can
# (and do) safely ignore this parameter, because, in those versions, the job name on its
# own is enough to specify which PartitionsDefinition to use.
("job_name", Optional[str]),
("selected_asset_keys", Optional[AbstractSet[AssetKey]]),
],
)
):
def __new__(
cls,
repository_origin: RemoteRepositoryOrigin,
job_name: str,
partition_set_name: str,
job_name: Optional[str] = None,
selected_asset_keys: Optional[AbstractSet[AssetKey]] = None,
):
return super(PartitionNamesArgs, cls).__new__(
cls,
repository_origin=check.inst_param(
repository_origin, "repository_origin", RemoteRepositoryOrigin
),
job_name=check.str_param(job_name, "job_name"),
job_name=check.opt_str_param(job_name, "job_name"),
partition_set_name=check.str_param(partition_set_name, "partition_set_name"),
selected_asset_keys=check.opt_nullable_set_param(
selected_asset_keys, "selected_asset_keys", of_type=AssetKey
),
)

def get_job_name(self) -> str:
if self.job_name:
return self.job_name
else:
return job_name_for_external_partition_set_name(self.partition_set_name)


@whitelist_for_serdes
class PartitionSetExecutionParamArgs(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,3 +340,44 @@ def test_dynamic_partition_set_grpc(instance: DagsterInstance):
)
assert isinstance(data, ExternalPartitionSetExecutionParamData)
assert data.partition_data == []


def test_external_partition_tags_grpc_backcompat_no_job_name(instance: DagsterInstance):
with get_bar_repo_code_location(instance) as code_location:
repository_handle = code_location.get_repository("bar_repo").handle

api_client = code_location.client

result = deserialize_value(
api_client.external_partition_tags(
partition_args=PartitionArgs(
repository_origin=repository_handle.get_external_origin(),
partition_set_name="baz_partition_set",
partition_name="c",
instance_ref=instance.get_ref(),
)
)
)

assert isinstance(result, ExternalPartitionTagsData)
assert result.tags
assert result.tags["foo"] == "bar"


def test_external_partition_names_grpc_backcompat_no_job_name(instance: DagsterInstance):
with get_bar_repo_code_location(instance) as code_location:
repository_handle = code_location.get_repository("bar_repo").handle

api_client = code_location.client

result = deserialize_value(
api_client.external_partition_names(
partition_names_args=PartitionNamesArgs(
repository_origin=repository_handle.get_external_origin(),
partition_set_name="baz_partition_set",
)
)
)

assert isinstance(result, ExternalPartitionNamesData)
assert result.partition_names == list(string.ascii_lowercase)