Skip to content

Commit

Permalink
fix(datasets): share the cache of Ibis connections
Browse files Browse the repository at this point in the history
Signed-off-by: Deepyaman Datta <[email protected]>
  • Loading branch information
deepyaman committed Nov 25, 2024
1 parent 07aef5a commit d6c1fc7
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 46 deletions.
1 change: 1 addition & 0 deletions kedro-datasets/kedro_datasets/_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .connection_mixin import ConnectionMixin
23 changes: 23 additions & 0 deletions kedro-datasets/kedro_datasets/_utils/connection_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from collections.abc import Hashable
from typing import Any, ClassVar


class ConnectionMixin:
_connections: ClassVar[dict[Hashable, Any]] = {}

@property
def _connection(self) -> Any:
def hashable(value: Any) -> Hashable:
"""Return a hashable key for a potentially-nested object."""
if isinstance(value, dict):
return tuple((k, hashable(v)) for k, v in sorted(value.items()))
if isinstance(value, list):
return tuple(hashable(x) for x in value)
return value

cls = type(self)
key = self._CONNECTION_GROUP, hashable(self._connection_config)
if key not in cls._connections:
cls._connections[key] = self._connect()

return cls._connections[key]
33 changes: 12 additions & 21 deletions kedro-datasets/kedro_datasets/ibis/file_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
import ibis.expr.types as ir
from kedro.io import AbstractVersionedDataset, DatasetError, Version

from kedro_datasets._utils import ConnectionMixin

if TYPE_CHECKING:
from ibis import BaseBackend


class FileDataset(AbstractVersionedDataset[ir.Table, ir.Table]):
class FileDataset(ConnectionMixin, AbstractVersionedDataset[ir.Table, ir.Table]):
"""``FileDataset`` loads/saves data from/to a specified file format.
Example usage for the
Expand Down Expand Up @@ -73,7 +75,7 @@ class FileDataset(AbstractVersionedDataset[ir.Table, ir.Table]):
DEFAULT_LOAD_ARGS: ClassVar[dict[str, Any]] = {}
DEFAULT_SAVE_ARGS: ClassVar[dict[str, Any]] = {}

_connections: ClassVar[dict[tuple[tuple[str, str]], BaseBackend]] = {}
_CONNECTION_GROUP: ClassVar[str] = "ibis"

def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -143,28 +145,17 @@ def __init__( # noqa: PLR0913
if save_args is not None:
self._save_args.update(save_args)

def _connect(self) -> BaseBackend:
import ibis

config = deepcopy(self._connection_config)
backend = getattr(ibis, config.pop("backend"))
return backend.connect(**config)

@property
def connection(self) -> BaseBackend:
"""The ``Backend`` instance for the connection configuration."""

def hashable(value):
"""Return a hashable key for a potentially-nested object."""
if isinstance(value, dict):
return tuple((k, hashable(v)) for k, v in sorted(value.items()))
if isinstance(value, list):
return tuple(hashable(x) for x in value)
return value

cls = type(self)
key = hashable(self._connection_config)
if key not in cls._connections:
import ibis

config = deepcopy(self._connection_config)
backend = getattr(ibis, config.pop("backend"))
cls._connections[key] = backend.connect(**config)

return cls._connections[key]
return self._connection

def load(self) -> ir.Table:
load_path = self._get_load_path()
Expand Down
32 changes: 11 additions & 21 deletions kedro-datasets/kedro_datasets/ibis/table_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
from kedro.io import AbstractDataset, DatasetError

from kedro_datasets import KedroDeprecationWarning
from kedro_datasets._utils import ConnectionMixin

if TYPE_CHECKING:
from ibis import BaseBackend


class TableDataset(AbstractDataset[ir.Table, ir.Table]):
class TableDataset(ConnectionMixin, AbstractDataset[ir.Table, ir.Table]):
"""``TableDataset`` loads/saves data from/to Ibis table expressions.
Example usage for the
Expand Down Expand Up @@ -70,7 +71,7 @@ class TableDataset(AbstractDataset[ir.Table, ir.Table]):
"overwrite": True,
}

_connections: ClassVar[dict[tuple[tuple[str, str]], BaseBackend]] = {}
_CONNECTION_GROUP: ClassVar[str] = "ibis"

def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -145,28 +146,17 @@ def __init__( # noqa: PLR0913

self._materialized = self._save_args.pop("materialized")

def _connect(self) -> BaseBackend:
import ibis

config = deepcopy(self._connection_config)
backend = getattr(ibis, config.pop("backend"))
return backend.connect(**config)

@property
def connection(self) -> BaseBackend:
"""The ``Backend`` instance for the connection configuration."""

def hashable(value):
"""Return a hashable key for a potentially-nested object."""
if isinstance(value, dict):
return tuple((k, hashable(v)) for k, v in sorted(value.items()))
if isinstance(value, list):
return tuple(hashable(x) for x in value)
return value

cls = type(self)
key = hashable(self._connection_config)
if key not in cls._connections:
import ibis

config = deepcopy(self._connection_config)
backend = getattr(ibis, config.pop("backend"))
cls._connections[key] = backend.connect(**config)

return cls._connections[key]
return self._connection

def load(self) -> ir.Table:
if self._filepath is not None:
Expand Down
4 changes: 2 additions & 2 deletions kedro-datasets/tests/ibis/test_file_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def dummy_table():


class TestFileDataset:
def test_save_and_load(self, file_dataset, dummy_table, database):
def test_save_and_load(self, file_dataset, dummy_table):
"""Test saving and reloading the data set."""
file_dataset.save(dummy_table)
reloaded = file_dataset.load()
Expand Down Expand Up @@ -127,7 +127,7 @@ def test_connection_config(self, mocker, file_dataset, connection_config, key):
)
mocker.patch(f"ibis.{backend}")
file_dataset.load()
assert key in file_dataset._connections
assert "ibis", key in file_dataset._connections


class TestFileDatasetVersioned:
Expand Down
22 changes: 20 additions & 2 deletions kedro-datasets/tests/ibis/test_table_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from kedro.io import DatasetError
from pandas.testing import assert_frame_equal

from kedro_datasets.ibis import TableDataset
from kedro_datasets.ibis import FileDataset, TableDataset

_SENTINEL = object()

Expand Down Expand Up @@ -56,6 +56,17 @@ def dummy_table(table_dataset_from_csv):
return table_dataset_from_csv.load()


@pytest.fixture
def file_dataset(filepath_csv, connection_config, load_args, save_args):
return FileDataset(
filepath=filepath_csv,
file_format="csv",
connection=connection_config,
load_args=load_args,
save_args=save_args,
)


class TestTableDataset:
def test_save_and_load(self, table_dataset, dummy_table, database):
"""Test saving and reloading the dataset."""
Expand Down Expand Up @@ -146,4 +157,11 @@ def test_connection_config(self, mocker, table_dataset, connection_config, key):
)
mocker.patch(f"ibis.{backend}")
table_dataset.load()
assert key in table_dataset._connections
assert "ibis", key in table_dataset._connections

def test_save_data_loaded_using_file_dataset(self, file_dataset, table_dataset):
"""Test interoperability of Ibis datasets sharing a database."""
dummy_table = file_dataset.load()
assert not table_dataset.exists()
table_dataset.save(dummy_table)
assert table_dataset.exists()

0 comments on commit d6c1fc7

Please sign in to comment.