diff --git a/kedro-datasets/kedro_datasets/_utils/__init__.py b/kedro-datasets/kedro_datasets/_utils/__init__.py
index e69de29bb..d6f540fe0 100644
--- a/kedro-datasets/kedro_datasets/_utils/__init__.py
+++ b/kedro-datasets/kedro_datasets/_utils/__init__.py
@@ -0,0 +1 @@
+from .connection_mixin import ConnectionMixin
diff --git a/kedro-datasets/kedro_datasets/_utils/connection_mixin.py b/kedro-datasets/kedro_datasets/_utils/connection_mixin.py
new file mode 100644
index 000000000..620aaacf6
--- /dev/null
+++ b/kedro-datasets/kedro_datasets/_utils/connection_mixin.py
@@ -0,0 +1,23 @@
+from collections.abc import Hashable
+from typing import Any, ClassVar
+
+
+class ConnectionMixin:
+    _connections: ClassVar[dict[Hashable, Any]] = {}
+
+    @property
+    def _connection(self) -> Any:
+        def hashable(value: Any) -> Hashable:
+            """Return a hashable key for a potentially-nested object."""
+            if isinstance(value, dict):
+                return tuple((k, hashable(v)) for k, v in sorted(value.items()))
+            if isinstance(value, list):
+                return tuple(hashable(x) for x in value)
+            return value
+
+        cls = type(self)
+        key = self._CONNECTION_GROUP, hashable(self._connection_config)
+        if key not in cls._connections:
+            cls._connections[key] = self._connect()
+
+        return cls._connections[key]
diff --git a/kedro-datasets/kedro_datasets/ibis/file_dataset.py b/kedro-datasets/kedro_datasets/ibis/file_dataset.py
index f204e297b..c3c43b74f 100644
--- a/kedro-datasets/kedro_datasets/ibis/file_dataset.py
+++ b/kedro-datasets/kedro_datasets/ibis/file_dataset.py
@@ -8,11 +8,13 @@
 import ibis.expr.types as ir
 from kedro.io import AbstractVersionedDataset, DatasetError, Version
 
+from kedro_datasets._utils import ConnectionMixin
+
 if TYPE_CHECKING:
     from ibis import BaseBackend
 
 
-class FileDataset(AbstractVersionedDataset[ir.Table, ir.Table]):
+class FileDataset(ConnectionMixin, AbstractVersionedDataset[ir.Table, ir.Table]):
     """``FileDataset`` loads/saves data from/to a specified file format.
 
     Example usage for the
@@ -73,7 +75,7 @@ class FileDataset(AbstractVersionedDataset[ir.Table, ir.Table]):
     DEFAULT_LOAD_ARGS: ClassVar[dict[str, Any]] = {}
     DEFAULT_SAVE_ARGS: ClassVar[dict[str, Any]] = {}
 
-    _connections: ClassVar[dict[tuple[tuple[str, str]], BaseBackend]] = {}
+    _CONNECTION_GROUP: ClassVar[str] = "ibis"
 
     def __init__(  # noqa: PLR0913
         self,
@@ -143,28 +145,17 @@ def __init__(  # noqa: PLR0913
         if save_args is not None:
             self._save_args.update(save_args)
 
+    def _connect(self) -> BaseBackend:
+        import ibis
+
+        config = deepcopy(self._connection_config)
+        backend = getattr(ibis, config.pop("backend"))
+        return backend.connect(**config)
+
     @property
     def connection(self) -> BaseBackend:
         """The ``Backend`` instance for the connection configuration."""
-
-        def hashable(value):
-            """Return a hashable key for a potentially-nested object."""
-            if isinstance(value, dict):
-                return tuple((k, hashable(v)) for k, v in sorted(value.items()))
-            if isinstance(value, list):
-                return tuple(hashable(x) for x in value)
-            return value
-
-        cls = type(self)
-        key = hashable(self._connection_config)
-        if key not in cls._connections:
-            import ibis
-
-            config = deepcopy(self._connection_config)
-            backend = getattr(ibis, config.pop("backend"))
-            cls._connections[key] = backend.connect(**config)
-
-        return cls._connections[key]
+        return self._connection
 
     def load(self) -> ir.Table:
         load_path = self._get_load_path()
diff --git a/kedro-datasets/kedro_datasets/ibis/table_dataset.py b/kedro-datasets/kedro_datasets/ibis/table_dataset.py
index 30709d08e..f2e6f23fc 100644
--- a/kedro-datasets/kedro_datasets/ibis/table_dataset.py
+++ b/kedro-datasets/kedro_datasets/ibis/table_dataset.py
@@ -9,12 +9,13 @@
 from kedro.io import AbstractDataset, DatasetError
 
 from kedro_datasets import KedroDeprecationWarning
+from kedro_datasets._utils import ConnectionMixin
 
 if TYPE_CHECKING:
     from ibis import BaseBackend
 
 
-class TableDataset(AbstractDataset[ir.Table, ir.Table]):
+class TableDataset(ConnectionMixin, AbstractDataset[ir.Table, ir.Table]):
     """``TableDataset`` loads/saves data from/to Ibis table expressions.
 
     Example usage for the
@@ -70,7 +71,7 @@ class TableDataset(AbstractDataset[ir.Table, ir.Table]):
         "overwrite": True,
     }
 
-    _connections: ClassVar[dict[tuple[tuple[str, str]], BaseBackend]] = {}
+    _CONNECTION_GROUP: ClassVar[str] = "ibis"
 
     def __init__(  # noqa: PLR0913
         self,
@@ -145,28 +146,17 @@ def __init__(  # noqa: PLR0913
 
         self._materialized = self._save_args.pop("materialized")
 
+    def _connect(self) -> BaseBackend:
+        import ibis
+
+        config = deepcopy(self._connection_config)
+        backend = getattr(ibis, config.pop("backend"))
+        return backend.connect(**config)
+
     @property
     def connection(self) -> BaseBackend:
         """The ``Backend`` instance for the connection configuration."""
-
-        def hashable(value):
-            """Return a hashable key for a potentially-nested object."""
-            if isinstance(value, dict):
-                return tuple((k, hashable(v)) for k, v in sorted(value.items()))
-            if isinstance(value, list):
-                return tuple(hashable(x) for x in value)
-            return value
-
-        cls = type(self)
-        key = hashable(self._connection_config)
-        if key not in cls._connections:
-            import ibis
-
-            config = deepcopy(self._connection_config)
-            backend = getattr(ibis, config.pop("backend"))
-            cls._connections[key] = backend.connect(**config)
-
-        return cls._connections[key]
+        return self._connection
 
     def load(self) -> ir.Table:
         if self._filepath is not None:
diff --git a/kedro-datasets/tests/ibis/test_file_dataset.py b/kedro-datasets/tests/ibis/test_file_dataset.py
index c21a06466..762a3c9d1 100644
--- a/kedro-datasets/tests/ibis/test_file_dataset.py
+++ b/kedro-datasets/tests/ibis/test_file_dataset.py
@@ -59,7 +59,7 @@ def dummy_table():
 
 
 class TestFileDataset:
-    def test_save_and_load(self, file_dataset, dummy_table, database):
+    def test_save_and_load(self, file_dataset, dummy_table):
         """Test saving and reloading the data set."""
         file_dataset.save(dummy_table)
         reloaded = file_dataset.load()
@@ -127,7 +127,7 @@ def test_connection_config(self, mocker, file_dataset, connection_config, key):
         )
         mocker.patch(f"ibis.{backend}")
         file_dataset.load()
-        assert key in file_dataset._connections
+        assert ("ibis", key) in file_dataset._connections
 
 
 class TestFileDatasetVersioned:
diff --git a/kedro-datasets/tests/ibis/test_table_dataset.py b/kedro-datasets/tests/ibis/test_table_dataset.py
index a778b08e0..ee93d4f38 100644
--- a/kedro-datasets/tests/ibis/test_table_dataset.py
+++ b/kedro-datasets/tests/ibis/test_table_dataset.py
@@ -4,7 +4,7 @@
 from kedro.io import DatasetError
 from pandas.testing import assert_frame_equal
 
-from kedro_datasets.ibis import TableDataset
+from kedro_datasets.ibis import FileDataset, TableDataset
 
 _SENTINEL = object()
 
@@ -56,6 +56,17 @@ def dummy_table(table_dataset_from_csv):
     return table_dataset_from_csv.load()
 
 
+@pytest.fixture
+def file_dataset(filepath_csv, connection_config, load_args, save_args):
+    return FileDataset(
+        filepath=filepath_csv,
+        file_format="csv",
+        connection=connection_config,
+        load_args=load_args,
+        save_args=save_args,
+    )
+
+
 class TestTableDataset:
     def test_save_and_load(self, table_dataset, dummy_table, database):
         """Test saving and reloading the dataset."""
@@ -146,4 +157,11 @@ def test_connection_config(self, mocker, table_dataset, connection_config, key):
         )
         mocker.patch(f"ibis.{backend}")
         table_dataset.load()
-        assert key in table_dataset._connections
+        assert ("ibis", key) in table_dataset._connections
+
+    def test_save_data_loaded_using_file_dataset(self, file_dataset, table_dataset):
+        """Test interoperability of Ibis datasets sharing a database."""
+        dummy_table = file_dataset.load()
+        assert not table_dataset.exists()
+        table_dataset.save(dummy_table)
+        assert table_dataset.exists()
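
A minimal sketch of the contract ConnectionMixin expects from an adopting dataset, for anyone porting another connection-based dataset to it. The SQLiteDataset class, its constructor signature, and the use of the standard-library sqlite3 module are assumptions made purely for illustration; the actual adopters in this diff are ibis.FileDataset and ibis.TableDataset.

# Illustrative sketch (not part of the diff above): a hypothetical dataset
# plugging into the ConnectionMixin contract. The three ingredients are a
# _CONNECTION_GROUP class attribute, a _connection_config dict set in
# __init__, and a _connect() method; the mixin caches one live connection
# per unique (group, config) pair.
import sqlite3
from typing import Any, ClassVar

from kedro_datasets._utils import ConnectionMixin


class SQLiteDataset(ConnectionMixin):
    # Namespaces the shared cache so different dataset families cannot
    # collide on identical connection configs.
    _CONNECTION_GROUP: ClassVar[str] = "sqlite"

    def __init__(self, connection: dict[str, Any]):
        # ConnectionMixin reads this attribute to build the cache key.
        self._connection_config = connection

    def _connect(self) -> sqlite3.Connection:
        # Invoked lazily, at most once per unique (group, config) pair.
        return sqlite3.connect(**self._connection_config)


if __name__ == "__main__":
    ds1 = SQLiteDataset({"database": ":memory:"})
    ds2 = SQLiteDataset({"database": ":memory:"})
    # Equal configs resolve to the same cached connection object.
    assert ds1._connection is ds2._connection

The Ibis datasets in this diff follow the same pattern, keeping their public connection property as a thin wrapper that returns self._connection.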