Refactor LanceDB client implementation and error handling
Signed-off-by: Marcel Coetzee <[email protected]>
Pipboyguy committed Jun 4, 2024
1 parent 092fcf0 commit b1783b2
Showing 4 changed files with 77 additions and 36 deletions.
29 changes: 29 additions & 0 deletions dlt/destinations/impl/lancedb/errors.py
@@ -0,0 +1,29 @@
from functools import wraps
from typing import (
    Any,
)

from lancedb.exceptions import MissingValueError, MissingColumnError

from dlt.common.destination.exceptions import (
    DestinationUndefinedEntity,
    DestinationTerminalException,
)
from dlt.common.destination.reference import JobClientBase
from dlt.common.typing import TFun


def lancedb_error(f: TFun) -> TFun:
    @wraps(f)
    def _wrap(self: JobClientBase, *args: Any, **kwargs: Any) -> Any:
        try:
            return f(self, *args, **kwargs)
        except (
            MissingValueError,
            MissingColumnError,
        ) as status_ex:
            raise DestinationUndefinedEntity(status_ex) from status_ex
        except Exception as e:
            raise DestinationTerminalException(e) from e

    return _wrap
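
A minimal sketch of how this decorator behaves when applied to a client method (DemoClient and failing_call are illustrative names, not part of this commit; it assumes dlt and lancedb are installed so errors.py imports cleanly):

# Hypothetical illustration only, not part of the commit's diff.
from dlt.common.destination.exceptions import DestinationTerminalException
from dlt.destinations.impl.lancedb.errors import lancedb_error


class DemoClient:  # stand-in for a JobClientBase subclass
    @lancedb_error
    def failing_call(self) -> None:
        # LanceDB's MissingValueError/MissingColumnError would surface as
        # DestinationUndefinedEntity; any other exception is re-raised as
        # DestinationTerminalException.
        raise ValueError("simulated LanceDB failure")


try:
    DemoClient().failing_call()
except DestinationTerminalException as exc:
    print(f"terminal destination error: {exc}")

Mapping missing-entity errors to DestinationUndefinedEntity rather than a terminal failure is what the update_stored_schema change later in this diff relies on when the version table does not exist yet.
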
80 changes: 46 additions & 34 deletions dlt/destinations/impl/lancedb/lancedb_client.py
@@ -26,6 +26,7 @@

from dlt.common import json, pendulum, logger
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.exceptions import DestinationUndefinedEntity
from dlt.common.destination.reference import (
    JobClientBase,
    WithStateSync,
@@ -43,8 +44,9 @@
from dlt.destinations.impl.lancedb.configuration import (
    LanceDBClientConfiguration,
)
from dlt.destinations.impl.lancedb.errors import lancedb_error
from dlt.destinations.impl.lancedb.lancedb_adapter import VECTORIZE_HINT
from dlt.destinations.impl.lancedb.schema_conversion import (
from dlt.destinations.impl.lancedb.schema import (
    TLanceModel,
    create_template_schema,
    make_fields,
@@ -151,10 +153,9 @@ def dataset_name(self) -> str:

    @property
    def sentinel_table(self) -> str:
        # If no dataset name is provided, we still want to create a sentinel table.
        return self.dataset_name or "DltSentinelTable"

    def _make_qualified_table_name(self, table_name: str) -> str:
    def make_qualified_table_name(self, table_name: str) -> str:
        return (
            f"{self.dataset_name}{self.config.dataset_separator}{table_name}"
            if self.dataset_name
@@ -164,7 +165,7 @@ def _make_qualified_table_name(self, table_name: str) -> str:
    def get_table_schema(self, table_name: str) -> pa.Schema:
        return cast(pa.Schema, self.db_client[table_name].schema)

    def _create_table(
    def create_table(
        self, table_name: str, schema: Union[pa.Schema, LanceModel]
    ) -> Table:
        """Create a LanceDB Table from the provided LanceModel or PyArrow schema.
@@ -244,6 +245,7 @@ def add_to_table(
            data, mode, on_bad_vectors, fill_value
        )

    @lancedb_error
    def drop_storage(self) -> None:
        """Drop the dataset from the LanceDB instance.
@@ -266,41 +268,51 @@ def drop_storage(self) -> None:

        self._delete_sentinel_table()

    @lancedb_error
    def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None:
        if not self.is_storage_initialized():
            self._create_sentinel_table()
        elif truncate_tables:
            for table_name in truncate_tables:
                fq_table_name = self._make_qualified_table_name(table_name)
                if not self._table_exists(fq_table_name):
                fq_table_name = self.make_qualified_table_name(table_name)
                if not self.table_exists(fq_table_name):
                    continue
                self.db_client.drop_table(fq_table_name)
                self._create_table(
                self.create_table(
                    table_name=fq_table_name,
                    schema=self.get_table_schema(fq_table_name),
                )

    @lancedb_error
    def is_storage_initialized(self) -> bool:
        return self._table_exists(self.sentinel_table)
        return self.table_exists(self.sentinel_table)

    def _create_sentinel_table(self) -> None:
        """Create an empty table to indicate that the storage is initialized."""
        self._create_table(
        self.create_table(
            schema=cast(LanceModel, NullSchema), table_name=self.sentinel_table
        )

    def _delete_sentinel_table(self) -> None:
        """Delete the sentinel table."""
        self.db_client.drop_table(self.sentinel_table)

    @lancedb_error
    def update_stored_schema(
        self,
        only_tables: Iterable[str] = None,
        expected_update: TSchemaTables = None,
    ) -> Optional[TSchemaTables]:
        super().update_stored_schema(only_tables, expected_update)
        applied_update: TSchemaTables = {}
        schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash)

        try:
            schema_info = self.get_stored_schema_by_hash(
                self.schema.stored_version_hash
            )
        except DestinationUndefinedEntity:
            schema_info = None

        if schema_info is None:
            logger.info(
                f"Schema with hash {self.schema.stored_version_hash} "
@@ -315,7 +327,17 @@ def update_stored_schema(
            )
        return applied_update

    def _update_schema_in_storage(self, schema: Schema) -> None:
    def _execute_schema_update(self, only_tables: Iterable[str]) -> None:
        for table_name in only_tables or self.schema.tables:
            exists = self.table_exists(self.make_qualified_table_name(table_name))
            if not exists:
                self.create_table(
                    self.make_qualified_table_name(table_name),
                    schema=cast(LanceModel, NullSchema),
                )
        self.update_schema_in_storage(self.schema)

    def update_schema_in_storage(self, schema: Schema) -> None:
        properties = {
            "version_hash": schema.stored_version_hash,
            "schema_name": schema.name,
@@ -324,12 +346,12 @@ def _update_schema_in_storage(self, schema: Schema) -> None:
            "inserted_at": str(pendulum.now()),
            "schema": json.dumps(schema.to_dict()),
        }
        version_table_name = self._make_qualified_table_name(
        version_table_name = self.make_qualified_table_name(
            self.schema.version_table_name
        )
        self._create_record(properties, VersionSchema, version_table_name)
        self.create_record(properties, VersionSchema, version_table_name)

    def _create_record(
    def create_record(
        self, record: DictStrAny, lancedb_model: TLanceModel, table_name: str
    ) -> None:
        """Inserts a record into a LanceDB table without a vector.
@@ -340,31 +362,21 @@ def _create_record(
            lancedb_model (LanceModel): Pydantic model to parse records.
        """
        try:
            tbl = self.db_client.open_table(self._make_qualified_table_name(table_name))
            tbl = self.db_client.open_table(self.make_qualified_table_name(table_name))
        except FileNotFoundError:
            tbl = self.db_client.create_table(
                self._make_qualified_table_name(table_name)
                self.make_qualified_table_name(table_name)
            )
        except Exception:
            raise

        tbl.add(lancedb_model(**record))

    def _execute_schema_update(self, only_tables: Iterable[str]) -> None:
        for table_name in only_tables or self.schema.tables:
            exists = self._table_exists(self._make_qualified_table_name(table_name))
            if not exists:
                self._create_table(
                    self._make_qualified_table_name(table_name),
                    schema=cast(LanceModel, NullSchema),
                )
        self._update_schema_in_storage(self.schema)

    def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
        """Loads compressed state from destination storage by finding a load ID that was completed."""
        while True:
            try:
                state_table_name = self._make_qualified_table_name(
                state_table_name = self.make_qualified_table_name(
                    self.schema.state_table_name
                )
                state_records = (
@@ -378,7 +390,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
                    return None
                for state in state_records:
                    load_id = state["_dlt_load_id"]
                    loads_table_name = self._make_qualified_table_name(
                    loads_table_name = self.make_qualified_table_name(
                        self.schema.loads_table_name
                    )
                    load_records = (
@@ -395,7 +407,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:

    def get_stored_schema_by_hash(self, schema_hash: str) -> StorageSchemaInfo:
        try:
            table_name = self._make_qualified_table_name(self.schema.version_table_name)
            table_name = self.make_qualified_table_name(self.schema.version_table_name)
            response = (
                self.db_client[table_name]
                .search()
@@ -410,7 +422,7 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> StorageSchemaInfo:
    def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
        """Retrieves newest schema from destination storage."""
        try:
            version_table_name = self._make_qualified_table_name(
            version_table_name = self.make_qualified_table_name(
                self.schema.version_table_name
            )
            response = (
@@ -443,8 +455,8 @@ def complete_load(self, load_id: str) -> None:
            "status": 0,
            "inserted_at": str(pendulum.now()),
        }
        loads_table_name = self._make_qualified_table_name(self.schema.loads_table_name)
        self._create_record(properties, LoadsSchema, loads_table_name)
        loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name)
        self.create_record(properties, LoadsSchema, loads_table_name)

    def restore_file_load(self, file_path: str) -> LoadJob:
        return EmptyLoadJob.from_file_path(file_path, "completed")
@@ -459,11 +471,11 @@ def start_file_load(
            type_mapper=self.type_mapper,
            db_client=self.db_client,
            client_config=self.config,
            table_name=self._make_qualified_table_name(table["name"]),
            table_name=self.make_qualified_table_name(table["name"]),
            model_func=self.model_func,
        )

    def _table_exists(self, table_name: str) -> bool:
    def table_exists(self, table_name: str) -> bool:
        return table_name in self.db_client.table_names()

    def _from_db_type(
File renamed without changes.
4 changes: 2 additions & 2 deletions tests/load/lancedb/utils.py
@@ -23,10 +23,10 @@ def assert_table(
) -> None:
    client: LanceDBClient = pipeline.destination_client() # type: ignore[assignment]

    exists = client._table_exists(collection_name)
    exists = client.table_exists(collection_name)
    assert exists

    qualified_collection_name = client._make_qualified_table_name(collection_name)
    qualified_collection_name = client.make_qualified_table_name(collection_name)
    records = client.db_client.open_table(qualified_collection_name).search().limit(50).to_list()

    if expected_items_count is not None:
