From 9c8d3c523dc801ac6c6909cfc81ce719b473804b Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Thu, 22 Aug 2024 16:28:46 -0700 Subject: [PATCH] Add tests for grpc and graphql against single implicit asset jobs (#23807) ## Summary & Motivation While working on https://github.com/dagster-io/dagster/pull/23491 and https://github.com/dagster-io/dagster/pull/23494, I wrote a set of tests that validate operations like `get_tags_for_partition` over GraphQL and gRPC when an asset job includes assets with different `PartitionsDefinition`s and an asset selection is provided. However, I had to take them out due to the way the stack was sequenced. This PR adds them back in. ## How I Tested These Changes ## Changelog NOCHANGELOG --- .../graphql/test_job_partitions.py | 114 +++++++++++++++++- .../api_tests/test_api_snapshot_partition.py | 78 +++++++++++- 2 files changed, 190 insertions(+), 2 deletions(-) diff --git a/python_modules/dagster-graphql/dagster_graphql_tests/graphql/test_job_partitions.py b/python_modules/dagster-graphql/dagster_graphql_tests/graphql/test_job_partitions.py index d4f66eb908f81..80c723f2baf46 100644 --- a/python_modules/dagster-graphql/dagster_graphql_tests/graphql/test_job_partitions.py +++ b/python_modules/dagster-graphql/dagster_graphql_tests/graphql/test_job_partitions.py @@ -1,4 +1,13 @@ -from dagster import Definitions, job, op, static_partitioned_config +from dagster import ( + AssetKey, + ConfigurableResource, + Definitions, + StaticPartitionsDefinition, + asset, + job, + op, + static_partitioned_config, +) from dagster._core.definitions.repository_definition import SINGLETON_REPOSITORY_NAME from dagster._core.test_utils import ensure_dagster_tests_import, instance_for_test from dagster_graphql.test.utils import define_out_of_process_context, execute_dagster_graphql @@ -74,6 +83,26 @@ def job1(): return Definitions(jobs=[job1]).get_repository_def() +def get_repo_with_differently_partitioned_assets(): + 
@asset(partitions_def=StaticPartitionsDefinition(["1", "2"])) + def asset1(): ... + + ab_partitions_def = StaticPartitionsDefinition(["a", "b"]) + + @asset(partitions_def=ab_partitions_def) + def asset2(): ... + + class MyResource(ConfigurableResource): + foo: str + + @asset(partitions_def=ab_partitions_def) + def asset3(resource1: MyResource): ... + + return Definitions( + assets=[asset1, asset2, asset3], resources={"resource1": MyResource(foo="bar")} + ).get_repository_def() + + def test_get_partition_names(): with instance_for_test() as instance: with define_out_of_process_context( @@ -97,6 +126,33 @@ def test_get_partition_names(): ] +def test_get_partition_names_asset_selection(): + with instance_for_test() as instance: + with define_out_of_process_context( + __file__, "get_repo_with_differently_partitioned_assets", instance + ) as context: + result = execute_dagster_graphql( + context, + GET_PARTITIONS_QUERY, + variables={ + "selector": { + "repositoryLocationName": context.code_location_names[0], + "repositoryName": SINGLETON_REPOSITORY_NAME, + "pipelineName": "__ASSET_JOB", + }, + "selectedAssetKeys": [ + AssetKey("asset2").to_graphql_input(), + AssetKey("asset3").to_graphql_input(), + ], + }, + ) + assert result.data["pipelineOrError"]["name"] == "__ASSET_JOB" + assert result.data["pipelineOrError"]["partitionKeysOrError"]["partitionKeys"] == [ + "a", + "b", + ] + + def test_get_partition_tags(): with instance_for_test() as instance: with define_out_of_process_context( @@ -125,6 +181,35 @@ def test_get_partition_tags(): } +def test_get_partition_tags_asset_selection(): + with instance_for_test() as instance: + with define_out_of_process_context( + __file__, "get_repo_with_differently_partitioned_assets", instance + ) as context: + result = execute_dagster_graphql( + context, + GET_PARTITION_TAGS_QUERY, + variables={ + "selector": { + "repositoryLocationName": context.code_location_names[0], + "repositoryName": SINGLETON_REPOSITORY_NAME, + "pipelineName": 
"__ASSET_JOB", + }, + "selectedAssetKeys": [ + AssetKey("asset2").to_graphql_input(), + AssetKey("asset3").to_graphql_input(), + ], + "partitionName": "b", + }, + ) + assert result.data["pipelineOrError"]["name"] == "__ASSET_JOB" + result_partition = result.data["pipelineOrError"]["partition"] + assert result_partition["name"] == "b" + assert { + item["key"]: item["value"] for item in result_partition["tagsOrError"]["results"] + } == {"dagster/partition": "b"} + + def test_get_partition_config(): with instance_for_test() as instance: with define_out_of_process_context( @@ -149,3 +234,30 @@ def test_get_partition_config(): result_partition["runConfigOrError"]["yaml"] == """ops:\n op1:\n config:\n p: '1'\n""" ) + + +def test_get_partition_config_asset_selection(): + with instance_for_test() as instance: + with define_out_of_process_context( + __file__, "get_repo_with_differently_partitioned_assets", instance + ) as context: + result = execute_dagster_graphql( + context, + GET_PARTITION_RUN_CONFIG_QUERY, + variables={ + "selector": { + "repositoryLocationName": context.code_location_names[0], + "repositoryName": SINGLETON_REPOSITORY_NAME, + "pipelineName": "__ASSET_JOB", + }, + "selectedAssetKeys": [ + AssetKey("asset2").to_graphql_input(), + AssetKey("asset3").to_graphql_input(), + ], + "partitionName": "b", + }, + ) + assert result.data["pipelineOrError"]["name"] == "__ASSET_JOB" + result_partition = result.data["pipelineOrError"]["partition"] + assert result_partition["name"] == "b" + assert result_partition["runConfigOrError"]["yaml"] == "{}\n" diff --git a/python_modules/dagster/dagster_tests/api_tests/test_api_snapshot_partition.py b/python_modules/dagster/dagster_tests/api_tests/test_api_snapshot_partition.py index a96047046d541..5187bb7336355 100644 --- a/python_modules/dagster/dagster_tests/api_tests/test_api_snapshot_partition.py +++ b/python_modules/dagster/dagster_tests/api_tests/test_api_snapshot_partition.py @@ -1,12 +1,15 @@ import string import pytest 
+from dagster import AssetKey, ConfigurableResource, Definitions, StaticPartitionsDefinition, asset from dagster._api.snapshot_partition import ( sync_get_external_partition_config_grpc, sync_get_external_partition_names_grpc, sync_get_external_partition_set_execution_param_data_grpc, sync_get_external_partition_tags_grpc, ) +from dagster._core.definitions.asset_job import IMPLICIT_ASSET_JOB_NAME +from dagster._core.definitions.repository_definition import SINGLETON_REPOSITORY_NAME from dagster._core.errors import DagsterUserCodeProcessError from dagster._core.instance import DagsterInstance from dagster._core.remote_representation import ( @@ -22,7 +25,27 @@ ensure_dagster_tests_import() -from dagster_tests.api_tests.utils import get_bar_repo_code_location +from dagster_tests.api_tests.utils import get_bar_repo_code_location, get_code_location # noqa: I001 + + +def get_repo_with_differently_partitioned_assets(): + @asset(partitions_def=StaticPartitionsDefinition(["1", "2"])) + def asset1(): ... + + ab_partitions_def = StaticPartitionsDefinition(["a", "b"]) + + @asset(partitions_def=ab_partitions_def) + def asset2(): ... + + class MyResource(ConfigurableResource): + foo: str + + @asset(partitions_def=ab_partitions_def) + def asset3(resource1: MyResource): ... 
+ + return Definitions( + assets=[asset1, asset2, asset3], resources={"resource1": MyResource(foo="bar")} + ).get_repository_def() def test_external_partition_names_grpc(instance: DagsterInstance): @@ -47,6 +70,23 @@ def test_external_partition_names(instance: DagsterInstance): assert data.partition_names == list(string.ascii_lowercase) +def test_external_partition_names_asset_selection(instance: DagsterInstance): + with get_code_location( + python_file=__file__, + attribute="get_repo_with_differently_partitioned_assets", + location_name="something", + instance=instance, + ) as code_location: + data = code_location.get_external_partition_names( + repository_handle=code_location.get_repository(SINGLETON_REPOSITORY_NAME).handle, + job_name=IMPLICIT_ASSET_JOB_NAME, + instance=instance, + selected_asset_keys={AssetKey("asset2"), AssetKey("asset3")}, + ) + assert isinstance(data, ExternalPartitionNamesData) + assert data.partition_names == ["a", "b"] + + def test_external_partition_names_deserialize_error_grpc(instance: DagsterInstance): with get_bar_repo_code_location(instance) as code_location: api_client = code_location.client @@ -93,6 +133,23 @@ def test_external_partition_config(instance: DagsterInstance): assert data.run_config["ops"]["do_input"]["inputs"]["x"]["value"] == "c" # type: ignore +def test_external_partition_config_different_partitions_defs(instance: DagsterInstance): + with get_code_location( + python_file=__file__, + attribute="get_repo_with_differently_partitioned_assets", + location_name="something", + instance=instance, + ) as code_location: + data = code_location.get_external_partition_config( + job_name=IMPLICIT_ASSET_JOB_NAME, + repository_handle=code_location.get_repository(SINGLETON_REPOSITORY_NAME).handle, + partition_name="b", + instance=instance, + ) + assert isinstance(data, ExternalPartitionConfigData) + assert data.run_config == {} + + def test_external_partitions_config_error_grpc(instance: DagsterInstance): with 
get_bar_repo_code_location(instance) as code_location: repository_handle = code_location.get_repository("bar_repo").handle @@ -157,6 +214,25 @@ def test_external_partition_tags(instance: DagsterInstance): assert data.tags["foo"] == "bar" +def test_external_partition_tags_different_partitions_defs(instance: DagsterInstance): + with get_code_location( + python_file=__file__, + attribute="get_repo_with_differently_partitioned_assets", + location_name="something", + instance=instance, + ) as code_location: + data = code_location.get_external_partition_tags( + repository_handle=code_location.get_repository(SINGLETON_REPOSITORY_NAME).handle, + job_name=IMPLICIT_ASSET_JOB_NAME, + selected_asset_keys={AssetKey("asset2"), AssetKey("asset3")}, + partition_name="b", + instance=instance, + ) + assert isinstance(data, ExternalPartitionTagsData) + assert data.tags + assert data.tags["dagster/partition"] == "b" + + def test_external_partitions_tags_deserialize_error_grpc(instance: DagsterInstance): with get_bar_repo_code_location(instance) as code_location: repository_handle = code_location.get_repository("bar_repo").handle