From 3ff50556dc1929b19b41e5263dd4e617ed33c1b6 Mon Sep 17 00:00:00 2001 From: Ben Pankow Date: Wed, 4 Sep 2024 14:34:20 -0700 Subject: [PATCH] [io managers] Add dagster/relation_identifier metadata to IO Manager outputs (#24196) ## Summary Automatically populate `dagster/relation_identifier` metadata on assets which are emitted by a DbIOManager. This should hopefully make it easier to pull data from Dagster into external catalogs. ## Test Plan Unit tests. ## Changelog [New] `dagster/relation_identifier` metadata is automatically attached to assets which are stored using an IO manager. --- .../dagster/_core/storage/db_io_manager.py | 27 +++++++++- .../storage_tests/test_db_io_manager.py | 51 +++++++++++++++---- .../test_type_handler.py | 27 ++++++++++ .../bigquery/test_type_handler.py | 27 ++++++++++ .../test_snowflake_pandas_type_handler.py | 29 +++++++++++ .../test_snowflake_pyspark_type_handler.py | 32 ++++++++++++ 6 files changed, 183 insertions(+), 10 deletions(-) diff --git a/python_modules/dagster/dagster/_core/storage/db_io_manager.py b/python_modules/dagster/dagster/_core/storage/db_io_manager.py index 37eeadb7ec148..39ae22947446d 100644 --- a/python_modules/dagster/dagster/_core/storage/db_io_manager.py +++ b/python_modules/dagster/dagster/_core/storage/db_io_manager.py @@ -19,6 +19,7 @@ import dagster._check as check from dagster._check import CheckError from dagster._core.definitions.metadata import RawMetadataValue +from dagster._core.definitions.metadata.metadata_set import TableMetadataSet from dagster._core.definitions.multi_dimensional_partitions import ( MultiPartitionKey, MultiPartitionsDefinition, @@ -27,7 +28,7 @@ TimeWindow, TimeWindowPartitionsDefinition, ) -from dagster._core.errors import DagsterInvariantViolationError +from dagster._core.errors import DagsterInvalidMetadata, DagsterInvariantViolationError from dagster._core.execution.context.input import InputContext from dagster._core.execution.context.output import OutputContext from 
dagster._core.storage.io_manager import IOManager @@ -76,6 +77,17 @@ def delete_table_slice( @abstractmethod def get_select_statement(table_slice: TableSlice) -> str: ... + @staticmethod + def get_relation_identifier(table_slice: TableSlice) -> Optional[str]: + """Returns a string which is set as the dagster/relation_identifier metadata value for an + emitted asset. This value should be the fully qualified name of the table, including the + schema and database, if applicable. + """ + if not table_slice.database: + return f"{table_slice.schema}.{table_slice.table}" + + return f"{table_slice.database}.{table_slice.schema}.{table_slice.table}" + @staticmethod @abstractmethod def ensure_schema_exists( @@ -155,6 +167,19 @@ def handle_output(self, context: OutputContext, obj: object) -> None: } ) + # Try to attach relation identifier metadata to the output asset, but + # don't fail if it errors because the user has already attached it. + try: + context.add_output_metadata( + dict( + TableMetadataSet( + relation_identifier=self._db_client.get_relation_identifier(table_slice) + ) + ) + ) + except DagsterInvalidMetadata: + pass + def load_input(self, context: InputContext) -> object: obj_type = context.dagster_type.typing_type if obj_type is Any and self._default_load_type is not None: diff --git a/python_modules/dagster/dagster_tests/storage_tests/test_db_io_manager.py b/python_modules/dagster/dagster_tests/storage_tests/test_db_io_manager.py index ac586637c5da4..547168064cc12 100644 --- a/python_modules/dagster/dagster_tests/storage_tests/test_db_io_manager.py +++ b/python_modules/dagster/dagster_tests/storage_tests/test_db_io_manager.py @@ -25,6 +25,10 @@ } +def mock_relation_identifier(*args, **kwargs) -> str: + return "relation_identifier" + + class IntHandler(DbTypeHandler[int]): def __init__(self): self.handle_input_calls = [] @@ -74,7 +78,10 @@ def test_asset_out(): handler = IntHandler() connect_mock = MagicMock() db_client = MagicMock( - spec=DbClient, 
get_select_statement=MagicMock(return_value=""), connect=connect_mock + spec=DbClient, + get_select_statement=MagicMock(return_value=""), + connect=connect_mock, + get_relation_identifier=mock_relation_identifier, ) manager = build_db_io_manager(type_handlers=[handler], db_client=db_client) asset_key = AssetKey(["schema1", "table1"]) @@ -107,7 +114,10 @@ def test_asset_out_columns(): handler = IntHandler() connect_mock = MagicMock() db_client = MagicMock( - spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock + spec=DbClient, + get_select_statement=MagicMock(return_value=""), + connect=connect_mock, + get_relation_identifier=mock_relation_identifier, ) manager = build_db_io_manager(type_handlers=[handler], db_client=db_client) asset_key = AssetKey(["schema1", "table1"]) @@ -146,7 +156,10 @@ def test_asset_out_partitioned(): handler = IntHandler() connect_mock = MagicMock() db_client = MagicMock( - spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock + spec=DbClient, + get_select_statement=MagicMock(return_value=""), + connect=connect_mock, + get_relation_identifier=mock_relation_identifier, ) manager = build_db_io_manager(type_handlers=[handler], db_client=db_client) asset_key = AssetKey(["schema1", "table1"]) @@ -202,7 +215,10 @@ def test_asset_out_static_partitioned(): handler = IntHandler() connect_mock = MagicMock() db_client = MagicMock( - spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock + spec=DbClient, + get_select_statement=MagicMock(return_value=""), + connect=connect_mock, + get_relation_identifier=mock_relation_identifier, ) manager = build_db_io_manager(type_handlers=[handler], db_client=db_client) asset_key = AssetKey(["schema1", "table1"]) @@ -251,7 +267,10 @@ def test_asset_out_multiple_static_partitions(): handler = IntHandler() connect_mock = MagicMock() db_client = MagicMock( - spec=DbClient, get_select_statement=MagicMock(return_value=""), 
connect=connect_mock + spec=DbClient, + get_select_statement=MagicMock(return_value=""), + connect=connect_mock, + get_relation_identifier=mock_relation_identifier, ) manager = build_db_io_manager(type_handlers=[handler], db_client=db_client) asset_key = AssetKey(["schema1", "table1"]) @@ -301,7 +320,10 @@ def test_different_output_and_input_types(): str_handler = StringHandler() connect_mock = MagicMock() db_client = MagicMock( - spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock + spec=DbClient, + get_select_statement=MagicMock(return_value=""), + connect=connect_mock, + get_relation_identifier=mock_relation_identifier, ) manager = build_db_io_manager(type_handlers=[int_handler, str_handler], db_client=db_client) asset_key = AssetKey(["schema1", "table1"]) @@ -336,7 +358,10 @@ def test_non_asset_out(): handler = IntHandler() connect_mock = MagicMock() db_client = MagicMock( - spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock + spec=DbClient, + get_select_statement=MagicMock(return_value=""), + connect=connect_mock, + get_relation_identifier=mock_relation_identifier, ) manager = build_db_io_manager(type_handlers=[handler], db_client=db_client) output_context = build_output_context( @@ -371,7 +396,11 @@ def test_non_asset_out(): def test_asset_schema_defaults(): handler = IntHandler() - db_client = MagicMock(spec=DbClient, get_select_statement=MagicMock(return_value="")) + db_client = MagicMock( + spec=DbClient, + get_select_statement=MagicMock(return_value=""), + get_relation_identifier=mock_relation_identifier, + ) manager = build_db_io_manager(type_handlers=[handler], db_client=db_client) asset_key = AssetKey(["schema1", "table1"]) @@ -529,7 +558,11 @@ def test_non_supported_type(): def test_default_load_type(): handler = IntHandler() - db_client = MagicMock(spec=DbClient, get_select_statement=MagicMock(return_value="")) + db_client = MagicMock( + spec=DbClient, + 
get_select_statement=MagicMock(return_value=""), + get_relation_identifier=mock_relation_identifier, + ) manager = DbIOManager( type_handlers=[handler], database=resource_config["database"], diff --git a/python_modules/libraries/dagster-duckdb-pandas/dagster_duckdb_pandas_tests/test_type_handler.py b/python_modules/libraries/dagster-duckdb-pandas/dagster_duckdb_pandas_tests/test_type_handler.py index a25bbcfd6b700..afb7be5766b40 100644 --- a/python_modules/libraries/dagster-duckdb-pandas/dagster_duckdb_pandas_tests/test_type_handler.py +++ b/python_modules/libraries/dagster-duckdb-pandas/dagster_duckdb_pandas_tests/test_type_handler.py @@ -9,7 +9,9 @@ AssetIn, AssetKey, DailyPartitionsDefinition, + Definitions, DynamicPartitionsDefinition, + MetadataValue, MultiPartitionKey, MultiPartitionsDefinition, Out, @@ -103,6 +105,31 @@ def test_duckdb_io_manager_with_assets(tmp_path, io_managers): duckdb_conn.close() +def test_io_manager_asset_metadata(tmp_path) -> None: + @asset + def my_pandas_df() -> pd.DataFrame: + return pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + db_file = os.path.join(tmp_path, "unit_test.duckdb") + defs = Definitions( + assets=[my_pandas_df], + resources={ + "io_manager": DuckDBPandasIOManager(database=db_file, schema="custom_schema"), + }, + ) + + res = defs.get_implicit_global_asset_job_def().execute_in_process() + assert res.success + + mats = res.get_asset_materialization_events() + assert len(mats) == 1 + mat = mats[0] + + assert mat.materialization.metadata["dagster/relation_identifier"] == MetadataValue.text( + f"{db_file}.custom_schema.my_pandas_df" + ) + + def test_duckdb_io_manager_with_schema(tmp_path): @asset def my_df() -> pd.DataFrame: diff --git a/python_modules/libraries/dagster-gcp-pandas/dagster_gcp_pandas_tests/bigquery/test_type_handler.py b/python_modules/libraries/dagster-gcp-pandas/dagster_gcp_pandas_tests/bigquery/test_type_handler.py index 2db8a2596357d..2ef95ac030d71 100644 --- 
a/python_modules/libraries/dagster-gcp-pandas/dagster_gcp_pandas_tests/bigquery/test_type_handler.py +++ b/python_modules/libraries/dagster-gcp-pandas/dagster_gcp_pandas_tests/bigquery/test_type_handler.py @@ -11,8 +11,10 @@ AssetIn, AssetKey, DailyPartitionsDefinition, + Definitions, DynamicPartitionsDefinition, EnvVar, + MetadataValue, MultiPartitionKey, MultiPartitionsDefinition, Out, @@ -57,6 +59,31 @@ def temporary_bigquery_table(schema_name: Optional[str]) -> Iterator[str]: ).result() +@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE bigquery DB") +@pytest.mark.integration +def test_io_manager_asset_metadata() -> None: + with temporary_bigquery_table(schema_name=SCHEMA) as table_name: + + @asset(key_prefix=SCHEMA, name=table_name) + def my_pandas_df() -> pd.DataFrame: + return pd.DataFrame({"foo": ["bar", "baz"], "quux": [1, 2]}) + + defs = Definitions( + assets=[my_pandas_df], resources={"io_manager": pythonic_bigquery_io_manager} + ) + + res = defs.get_implicit_global_asset_job_def().execute_in_process() + assert res.success + + mats = res.get_asset_materialization_events() + assert len(mats) == 1 + mat = mats[0] + + assert mat.materialization.metadata["dagster/relation_identifier"] == MetadataValue.text( + f"{os.getenv('GCP_PROJECT_ID')}.{SCHEMA}.{table_name}" + ) + + @pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE bigquery DB") @pytest.mark.parametrize("io_manager", [(old_bigquery_io_manager), (pythonic_bigquery_io_manager)]) @pytest.mark.integration diff --git a/python_modules/libraries/dagster-snowflake-pandas/dagster_snowflake_pandas_tests/test_snowflake_pandas_type_handler.py b/python_modules/libraries/dagster-snowflake-pandas/dagster_snowflake_pandas_tests/test_snowflake_pandas_type_handler.py index fd2d853c6c9bf..a869f15a093b3 100644 --- a/python_modules/libraries/dagster-snowflake-pandas/dagster_snowflake_pandas_tests/test_snowflake_pandas_type_handler.py +++ 
b/python_modules/libraries/dagster-snowflake-pandas/dagster_snowflake_pandas_tests/test_snowflake_pandas_type_handler.py @@ -11,6 +11,7 @@ AssetIn, AssetKey, DailyPartitionsDefinition, + Definitions, DynamicPartitionsDefinition, IOManagerDefinition, MetadataValue, @@ -208,6 +209,34 @@ def io_manager_test_job(): assert res.success +@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB") +@pytest.mark.integration +def test_io_manager_asset_metadata() -> None: + with temporary_snowflake_table( + schema_name=SCHEMA, + db_name=DATABASE, + ) as table_name: + + @asset(key_prefix=SCHEMA, name=table_name) + def my_pandas_df(): + return pandas.DataFrame({"foo": ["bar", "baz"], "quux": [1, 2]}) + + defs = Definitions( + assets=[my_pandas_df], resources={"io_manager": pythonic_snowflake_io_manager} + ) + + res = defs.get_implicit_global_asset_job_def().execute_in_process() + assert res.success + + mats = res.get_asset_materialization_events() + assert len(mats) == 1 + mat = mats[0] + + assert mat.materialization.metadata["dagster/relation_identifier"] == MetadataValue.text( + f"{DATABASE}.{SCHEMA}.{table_name}" + ) + + @pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB") @pytest.mark.parametrize( "io_manager", [(snowflake_pandas_io_manager), (SnowflakePandasIOManager.configure_at_launch())] diff --git a/python_modules/libraries/dagster-snowflake-pyspark/dagster_snowflake_pyspark_tests/test_snowflake_pyspark_type_handler.py b/python_modules/libraries/dagster-snowflake-pyspark/dagster_snowflake_pyspark_tests/test_snowflake_pyspark_type_handler.py index 94719db55ccc4..5c8586a10d62d 100644 --- a/python_modules/libraries/dagster-snowflake-pyspark/dagster_snowflake_pyspark_tests/test_snowflake_pyspark_type_handler.py +++ b/python_modules/libraries/dagster-snowflake-pyspark/dagster_snowflake_pyspark_tests/test_snowflake_pyspark_type_handler.py @@ -10,6 +10,7 @@ AssetIn, AssetKey, DailyPartitionsDefinition, 
+ Definitions, DynamicPartitionsDefinition, EnvVar, IOManagerDefinition, @@ -184,6 +185,37 @@ def io_manager_test_job(): assert res.success +@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB") +@pytest.mark.integration +def test_io_manager_asset_metadata(spark) -> None: + with temporary_snowflake_table( + schema_name=SCHEMA, + db_name=DATABASE, + ) as table_name: + + @asset(key_prefix=SCHEMA, name=table_name) + def my_spark_df(): + columns = ["foo", "quux"] + data = [("bar", 1), ("baz", 2)] + df = spark.createDataFrame(data).toDF(*columns) + return df + + defs = Definitions( + assets=[my_spark_df], resources={"io_manager": pythonic_snowflake_io_manager} + ) + + res = defs.get_implicit_global_asset_job_def().execute_in_process() + assert res.success + + mats = res.get_asset_materialization_events() + assert len(mats) == 1 + mat = mats[0] + + assert mat.materialization.metadata["dagster/relation_identifier"] == MetadataValue.text( + f"{DATABASE}.{SCHEMA}.{table_name}" + ) + + @pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB") @pytest.mark.parametrize( "io_manager", [(old_snowflake_io_manager), (pythonic_snowflake_io_manager)]