Skip to content

Commit

Permalink
chore: Add more database-related tests (follow up to apache#31948) (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
Vitor-Avila authored and tmsjordan committed Feb 1, 2025
1 parent 589c62e commit 77d2685
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 14 deletions.
6 changes: 6 additions & 0 deletions superset/commands/database/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def run(self) -> dict[str, Any]:
catalog=self._catalog_name,
schema=self._schema_name,
datasource_names=sorted(
# get_all_table_names_in_schema may return raw (unserialized) cached
# results, so we wrap them as DatasourceName objects here instead of
# directly in the method to ensure consistency.
DatasourceName(*datasource_name)
for datasource_name in self._model.get_all_table_names_in_schema(
catalog=self._catalog_name,
Expand All @@ -76,6 +79,9 @@ def run(self) -> dict[str, Any]:
catalog=self._catalog_name,
schema=self._schema_name,
datasource_names=sorted(
# get_all_view_names_in_schema may return raw (unserialized) cached
# results, so we wrap them as DatasourceName objects here instead of
# directly in the method to ensure consistency.
DatasourceName(*datasource_name)
for datasource_name in self._model.get_all_view_names_in_schema(
catalog=self._catalog_name,
Expand Down
28 changes: 14 additions & 14 deletions tests/unit_tests/commands/databases/tables_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock:

database = mocker.MagicMock()
database.database_name = "test_database"
database.get_all_table_names_in_schema.return_value = [
DatasourceName("table1", "schema1", "catalog1"),
DatasourceName("table2", "schema1", "catalog1"),
]
database.get_all_view_names_in_schema.return_value = [
DatasourceName("view1", "schema1", "catalog1"),
]
database.get_all_table_names_in_schema.return_value = {
("table1", "schema1", "catalog1"),
("table2", "schema1", "catalog1"),
}
database.get_all_view_names_in_schema.return_value = {
("view1", "schema1", "catalog1"),
}

DatabaseDAO = mocker.patch("superset.commands.database.tables.DatabaseDAO") # noqa: N806
DatabaseDAO.find_by_id.return_value = database
Expand All @@ -57,13 +57,13 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock:

database = mocker.MagicMock()
database.database_name = "test_database"
database.get_all_table_names_in_schema.return_value = [
DatasourceName("table1", "schema1"),
DatasourceName("table2", "schema1"),
]
database.get_all_view_names_in_schema.return_value = [
DatasourceName("view1", "schema1"),
]
database.get_all_table_names_in_schema.return_value = {
("table1", "schema1", None),
("table2", "schema1", None),
}
database.get_all_view_names_in_schema.return_value = {
("view1", "schema1", None),
}

DatabaseDAO = mocker.patch("superset.commands.database.tables.DatabaseDAO") # noqa: N806
DatabaseDAO.find_by_id.return_value = database
Expand Down
56 changes: 56 additions & 0 deletions tests/unit_tests/models/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,3 +768,59 @@ def test_compile_sqla_query(query: Select) -> None:
WHERE
TRUE AND TRUE"""
)


def test_get_all_table_names_in_schema(mocker: MockerFixture) -> None:
"""
Test the `get_all_table_names_in_schema` method.
"""
database = Database(
database_name="db",
sqlalchemy_uri="postgresql://user:password@host:5432/examples",
)

mocker.patch.object(database, "get_inspector")
get_table_names = mocker.patch(
"superset.db_engine_specs.postgres.PostgresEngineSpec.get_table_names"
)
get_table_names.return_value = {"first_table", "second_table", "third_table"}

tables_list = database.get_all_table_names_in_schema(
catalog="examples",
schema="public",
)
assert sorted(tables_list) == sorted(
{
("first_table", "public", "examples"),
("second_table", "public", "examples"),
("third_table", "public", "examples"),
}
)


def test_get_all_view_names_in_schema(mocker: MockerFixture) -> None:
"""
Test the `get_all_view_names_in_schema` method.
"""
database = Database(
database_name="db",
sqlalchemy_uri="postgresql://user:password@host:5432/examples",
)

mocker.patch.object(database, "get_inspector")
get_view_names = mocker.patch(
"superset.db_engine_specs.base.BaseEngineSpec.get_view_names"
)
get_view_names.return_value = {"first_view", "second_view", "third_view"}

views_list = database.get_all_view_names_in_schema(
catalog="examples",
schema="public",
)
assert sorted(views_list) == sorted(
{
("first_view", "public", "examples"),
("second_view", "public", "examples"),
("third_view", "public", "examples"),
}
)

0 comments on commit 77d2685

Please sign in to comment.