Skip to content

Commit

Permalink
[io managers] Add dagster/relation_identifier metadata to IO Manager …
Browse files Browse the repository at this point in the history
…outputs (#24196)

## Summary

Automatically populate `dagster/relation_identifier` metadata on assets
which are emitted by a DbIOManager.

This should hopefully make it easier to pull data from Dagster into
external catalogs.

## Test Plan

Unit tests.

## Changelog [New]

`dagster/relation_identifier` metadata is automatically attached to
assets which are stored using a DbIOManager-based IO manager.
  • Loading branch information
benpankow authored Sep 4, 2024
1 parent 9b3bf2d commit 3ff5055
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 10 deletions.
27 changes: 26 additions & 1 deletion python_modules/dagster/dagster/_core/storage/db_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import dagster._check as check
from dagster._check import CheckError
from dagster._core.definitions.metadata import RawMetadataValue
from dagster._core.definitions.metadata.metadata_set import TableMetadataSet
from dagster._core.definitions.multi_dimensional_partitions import (
MultiPartitionKey,
MultiPartitionsDefinition,
Expand All @@ -27,7 +28,7 @@
TimeWindow,
TimeWindowPartitionsDefinition,
)
from dagster._core.errors import DagsterInvariantViolationError
from dagster._core.errors import DagsterInvalidMetadata, DagsterInvariantViolationError
from dagster._core.execution.context.input import InputContext
from dagster._core.execution.context.output import OutputContext
from dagster._core.storage.io_manager import IOManager
Expand Down Expand Up @@ -76,6 +77,17 @@ def delete_table_slice(
@abstractmethod
def get_select_statement(table_slice: TableSlice) -> str: ...

@staticmethod
def get_relation_identifier(table_slice: TableSlice) -> Optional[str]:
"""Returns a string which is set as the dagster/relation_identifier metadata value for an
emitted asset. This value should be the fully qualified name of the table, including the
schema and database, if applicable.
"""
if not table_slice.database:
return f"{table_slice.schema}.{table_slice.table}"

return f"{table_slice.database}.{table_slice.schema}.{table_slice.table}"

@staticmethod
@abstractmethod
def ensure_schema_exists(
Expand Down Expand Up @@ -155,6 +167,19 @@ def handle_output(self, context: OutputContext, obj: object) -> None:
}
)

# Try to attach relation identifier metadata to the output asset, but
# don't fail if it errors because the user has already attached it.
try:
context.add_output_metadata(
dict(
TableMetadataSet(
relation_identifier=self._db_client.get_relation_identifier(table_slice)
)
)
)
except DagsterInvalidMetadata:
pass

def load_input(self, context: InputContext) -> object:
obj_type = context.dagster_type.typing_type
if obj_type is Any and self._default_load_type is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
}


def mock_relation_identifier(*args, **kwargs) -> str:
    """Test stand-in for a relation identifier lookup.

    Accepts and ignores any positional or keyword arguments and always
    returns the same fixed string.
    """
    fixed_identifier = "relation_identifier"
    return fixed_identifier


class IntHandler(DbTypeHandler[int]):
def __init__(self):
self.handle_input_calls = []
Expand Down Expand Up @@ -74,7 +78,10 @@ def test_asset_out():
handler = IntHandler()
connect_mock = MagicMock()
db_client = MagicMock(
spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock
spec=DbClient,
get_select_statement=MagicMock(return_value=""),
connect=connect_mock,
get_relation_identifier=mock_relation_identifier,
)
manager = build_db_io_manager(type_handlers=[handler], db_client=db_client)
asset_key = AssetKey(["schema1", "table1"])
Expand Down Expand Up @@ -107,7 +114,10 @@ def test_asset_out_columns():
handler = IntHandler()
connect_mock = MagicMock()
db_client = MagicMock(
spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock
spec=DbClient,
get_select_statement=MagicMock(return_value=""),
connect=connect_mock,
get_relation_identifier=mock_relation_identifier,
)
manager = build_db_io_manager(type_handlers=[handler], db_client=db_client)
asset_key = AssetKey(["schema1", "table1"])
Expand Down Expand Up @@ -146,7 +156,10 @@ def test_asset_out_partitioned():
handler = IntHandler()
connect_mock = MagicMock()
db_client = MagicMock(
spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock
spec=DbClient,
get_select_statement=MagicMock(return_value=""),
connect=connect_mock,
get_relation_identifier=mock_relation_identifier,
)
manager = build_db_io_manager(type_handlers=[handler], db_client=db_client)
asset_key = AssetKey(["schema1", "table1"])
Expand Down Expand Up @@ -202,7 +215,10 @@ def test_asset_out_static_partitioned():
handler = IntHandler()
connect_mock = MagicMock()
db_client = MagicMock(
spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock
spec=DbClient,
get_select_statement=MagicMock(return_value=""),
connect=connect_mock,
get_relation_identifier=mock_relation_identifier,
)
manager = build_db_io_manager(type_handlers=[handler], db_client=db_client)
asset_key = AssetKey(["schema1", "table1"])
Expand Down Expand Up @@ -251,7 +267,10 @@ def test_asset_out_multiple_static_partitions():
handler = IntHandler()
connect_mock = MagicMock()
db_client = MagicMock(
spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock
spec=DbClient,
get_select_statement=MagicMock(return_value=""),
connect=connect_mock,
get_relation_identifier=mock_relation_identifier,
)
manager = build_db_io_manager(type_handlers=[handler], db_client=db_client)
asset_key = AssetKey(["schema1", "table1"])
Expand Down Expand Up @@ -301,7 +320,10 @@ def test_different_output_and_input_types():
str_handler = StringHandler()
connect_mock = MagicMock()
db_client = MagicMock(
spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock
spec=DbClient,
get_select_statement=MagicMock(return_value=""),
connect=connect_mock,
get_relation_identifier=mock_relation_identifier,
)
manager = build_db_io_manager(type_handlers=[int_handler, str_handler], db_client=db_client)
asset_key = AssetKey(["schema1", "table1"])
Expand Down Expand Up @@ -336,7 +358,10 @@ def test_non_asset_out():
handler = IntHandler()
connect_mock = MagicMock()
db_client = MagicMock(
spec=DbClient, get_select_statement=MagicMock(return_value=""), connect=connect_mock
spec=DbClient,
get_select_statement=MagicMock(return_value=""),
connect=connect_mock,
get_relation_identifier=mock_relation_identifier,
)
manager = build_db_io_manager(type_handlers=[handler], db_client=db_client)
output_context = build_output_context(
Expand Down Expand Up @@ -371,7 +396,11 @@ def test_non_asset_out():

def test_asset_schema_defaults():
handler = IntHandler()
db_client = MagicMock(spec=DbClient, get_select_statement=MagicMock(return_value=""))
db_client = MagicMock(
spec=DbClient,
get_select_statement=MagicMock(return_value=""),
get_relation_identifier=mock_relation_identifier,
)
manager = build_db_io_manager(type_handlers=[handler], db_client=db_client)

asset_key = AssetKey(["schema1", "table1"])
Expand Down Expand Up @@ -529,7 +558,11 @@ def test_non_supported_type():

def test_default_load_type():
handler = IntHandler()
db_client = MagicMock(spec=DbClient, get_select_statement=MagicMock(return_value=""))
db_client = MagicMock(
spec=DbClient,
get_select_statement=MagicMock(return_value=""),
get_relation_identifier=mock_relation_identifier,
)
manager = DbIOManager(
type_handlers=[handler],
database=resource_config["database"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
AssetIn,
AssetKey,
DailyPartitionsDefinition,
Definitions,
DynamicPartitionsDefinition,
MetadataValue,
MultiPartitionKey,
MultiPartitionsDefinition,
Out,
Expand Down Expand Up @@ -103,6 +105,31 @@ def test_duckdb_io_manager_with_assets(tmp_path, io_managers):
duckdb_conn.close()


def test_io_manager_asset_metadata(tmp_path) -> None:
    """Check that a materialized asset carries ``dagster/relation_identifier`` metadata.

    The asset is stored through ``DuckDBPandasIOManager`` and the expected
    identifier has the form ``<database file>.<schema>.<asset name>``.
    """

    @asset
    def my_pandas_df() -> pd.DataFrame:
        return pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    db_file = os.path.join(tmp_path, "unit_test.duckdb")
    duckdb_io_manager = DuckDBPandasIOManager(database=db_file, schema="custom_schema")
    defs = Definitions(
        assets=[my_pandas_df],
        resources={"io_manager": duckdb_io_manager},
    )

    asset_job = defs.get_implicit_global_asset_job_def()
    result = asset_job.execute_in_process()
    assert result.success

    # Exactly one asset is defined, so exactly one materialization is expected.
    materialization_events = result.get_asset_materialization_events()
    assert len(materialization_events) == 1
    event = materialization_events[0]

    actual_identifier = event.materialization.metadata["dagster/relation_identifier"]
    expected_identifier = MetadataValue.text(f"{db_file}.custom_schema.my_pandas_df")
    assert actual_identifier == expected_identifier


def test_duckdb_io_manager_with_schema(tmp_path):
@asset
def my_df() -> pd.DataFrame:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
AssetIn,
AssetKey,
DailyPartitionsDefinition,
Definitions,
DynamicPartitionsDefinition,
EnvVar,
MetadataValue,
MultiPartitionKey,
MultiPartitionsDefinition,
Out,
Expand Down Expand Up @@ -57,6 +59,31 @@ def temporary_bigquery_table(schema_name: Optional[str]) -> Iterator[str]:
).result()


# The skip reason names BigQuery (it previously said "snowflake", a copy-paste
# slip) so that skipped-test reports point at the right backing service.
@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE bigquery DB")
@pytest.mark.integration
def test_io_manager_asset_metadata() -> None:
    """Check that a BigQuery-stored asset carries ``dagster/relation_identifier`` metadata.

    The expected identifier has the form ``<GCP project id>.<schema>.<table>``,
    where the project id is read from the ``GCP_PROJECT_ID`` environment variable.
    """
    with temporary_bigquery_table(schema_name=SCHEMA) as table_name:

        @asset(key_prefix=SCHEMA, name=table_name)
        def my_pandas_df() -> pd.DataFrame:
            return pd.DataFrame({"foo": ["bar", "baz"], "quux": [1, 2]})

        defs = Definitions(
            assets=[my_pandas_df], resources={"io_manager": pythonic_bigquery_io_manager}
        )

        res = defs.get_implicit_global_asset_job_def().execute_in_process()
        assert res.success

        # Exactly one asset is defined, so exactly one materialization is expected.
        mats = res.get_asset_materialization_events()
        assert len(mats) == 1
        mat = mats[0]

        assert mat.materialization.metadata["dagster/relation_identifier"] == MetadataValue.text(
            f"{os.getenv('GCP_PROJECT_ID')}.{SCHEMA}.{table_name}"
        )


@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE bigquery DB")
@pytest.mark.parametrize("io_manager", [(old_bigquery_io_manager), (pythonic_bigquery_io_manager)])
@pytest.mark.integration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
AssetIn,
AssetKey,
DailyPartitionsDefinition,
Definitions,
DynamicPartitionsDefinition,
IOManagerDefinition,
MetadataValue,
Expand Down Expand Up @@ -208,6 +209,34 @@ def io_manager_test_job():
assert res.success


@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB")
@pytest.mark.integration
def test_io_manager_asset_metadata() -> None:
    """Check that a Snowflake-stored pandas asset carries ``dagster/relation_identifier`` metadata.

    The expected identifier has the form ``<database>.<schema>.<table>``.
    """
    with temporary_snowflake_table(schema_name=SCHEMA, db_name=DATABASE) as table_name:

        @asset(key_prefix=SCHEMA, name=table_name)
        def my_pandas_df():
            return pandas.DataFrame({"foo": ["bar", "baz"], "quux": [1, 2]})

        defs = Definitions(
            assets=[my_pandas_df],
            resources={"io_manager": pythonic_snowflake_io_manager},
        )

        asset_job = defs.get_implicit_global_asset_job_def()
        result = asset_job.execute_in_process()
        assert result.success

        # Exactly one asset is defined, so exactly one materialization is expected.
        materialization_events = result.get_asset_materialization_events()
        assert len(materialization_events) == 1
        event = materialization_events[0]

        actual_identifier = event.materialization.metadata["dagster/relation_identifier"]
        expected_identifier = MetadataValue.text(f"{DATABASE}.{SCHEMA}.{table_name}")
        assert actual_identifier == expected_identifier


@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB")
@pytest.mark.parametrize(
"io_manager", [(snowflake_pandas_io_manager), (SnowflakePandasIOManager.configure_at_launch())]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
AssetIn,
AssetKey,
DailyPartitionsDefinition,
Definitions,
DynamicPartitionsDefinition,
EnvVar,
IOManagerDefinition,
Expand Down Expand Up @@ -184,6 +185,37 @@ def io_manager_test_job():
assert res.success


@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB")
@pytest.mark.integration
def test_io_manager_asset_metadata(spark) -> None:
    """Check that a Snowflake-stored Spark asset carries ``dagster/relation_identifier`` metadata.

    The asset builds its DataFrame with the ``spark`` fixture, and the expected
    identifier has the form ``<database>.<schema>.<table>``.
    """
    with temporary_snowflake_table(schema_name=SCHEMA, db_name=DATABASE) as table_name:

        @asset(key_prefix=SCHEMA, name=table_name)
        def my_spark_df():
            column_names = ["foo", "quux"]
            rows = [("bar", 1), ("baz", 2)]
            return spark.createDataFrame(rows).toDF(*column_names)

        defs = Definitions(
            assets=[my_spark_df],
            resources={"io_manager": pythonic_snowflake_io_manager},
        )

        asset_job = defs.get_implicit_global_asset_job_def()
        result = asset_job.execute_in_process()
        assert result.success

        # Exactly one asset is defined, so exactly one materialization is expected.
        materialization_events = result.get_asset_materialization_events()
        assert len(materialization_events) == 1
        event = materialization_events[0]

        actual_identifier = event.materialization.metadata["dagster/relation_identifier"]
        expected_identifier = MetadataValue.text(f"{DATABASE}.{SCHEMA}.{table_name}")
        assert actual_identifier == expected_identifier


@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB")
@pytest.mark.parametrize(
"io_manager", [(old_snowflake_io_manager), (pythonic_snowflake_io_manager)]
Expand Down

0 comments on commit 3ff5055

Please sign in to comment.