Imports
Signed-off-by: Marcel Coetzee <[email protected]>
Pipboyguy committed Jun 16, 2024
1 parent a8e4e62 commit b977e5c
Showing 1 changed file with 28 additions and 86 deletions.
114 changes: 28 additions & 86 deletions dlt/destinations/impl/lancedb/lancedb_client.py
@@ -1,5 +1,5 @@
 import uuid
-from datetime import timedelta
+from dlt.common.pendulum import timedelta
 from types import TracebackType
 from typing import (
     ClassVar,
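The swap above routes timedelta through dlt.common.pendulum instead of the stdlib, presumably so all date/time helpers used by destination code come from a single module. A minimal sketch of using the swapped import, assuming a dlt installation; the retention value is illustrative:

from dlt.common.pendulum import timedelta

retention = timedelta(days=7)  # behaves like a stdlib timedelta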
@@ -74,9 +74,7 @@


 TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"}
-UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {
-    v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()
-}
+UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()}


 class LanceDBTypeMapper(TypeMapper):
@@ -183,9 +181,7 @@ def upload_batch(
         tbl.add(records, mode="overwrite")
     elif write_disposition == "merge":
         if not id_field_name:
-            raise ValueError(
-                "To perform a merge update, 'id_field_name' must be specified."
-            )
+            raise ValueError("To perform a merge update, 'id_field_name' must be specified.")
         tbl.merge_insert(
             id_field_name
         ).when_matched_update_all().when_not_matched_insert_all().execute(records)
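The merge branch above uses LanceDB's merge_insert builder to upsert records, exactly as shown in the diff. A minimal standalone sketch of that pattern, assuming a local lancedb install; the database path, table name, and id_ key are illustrative:

import lancedb

db = lancedb.connect("/tmp/.lancedb")  # hypothetical local database path
tbl = db.create_table("docs", data=[{"id_": "1", "text": "hello"}], mode="overwrite")

# Upsert: rows whose id_ matches are updated, the rest are inserted.
(
    tbl.merge_insert("id_")
    .when_matched_update_all()
    .when_not_matched_insert_all()
    .execute([{"id_": "1", "text": "updated"}, {"id_": "2", "text": "new"}])
)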
@@ -260,9 +256,7 @@ def get_table_schema(self, table_name: str) -> TArrowSchema:
         )

     @lancedb_error
-    def create_table(
-        self, table_name: str, schema: TArrowSchema, mode: str = "create"
-    ) -> Table:
+    def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table:
         """Create a LanceDB Table from the provided LanceModel or PyArrow schema.

         Args:
@@ -315,11 +309,7 @@ def _get_table_names(self) -> List[str]:
         else:
             table_names = self.db_client.table_names()

-        return [
-            table_name
-            for table_name in table_names
-            if table_name != self.sentinel_table
-        ]
+        return [table_name for table_name in table_names if table_name != self.sentinel_table]

     @lancedb_error
     def drop_storage(self) -> None:
@@ -372,9 +362,7 @@ def update_stored_schema(
         applied_update: TSchemaTables = {}

         try:
-            schema_info = self.get_stored_schema_by_hash(
-                self.schema.stored_version_hash
-            )
+            schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash)
         except DestinationUndefinedEntity:
             schema_info = None

@@ -428,47 +416,35 @@ def add_table_fields(

         # Check if any of the new fields already exist in the table.
         existing_fields = set(arrow_table.schema.names)
-        new_fields = [
-            field for field in field_schemas if field.name not in existing_fields
-        ]
+        new_fields = [field for field in field_schemas if field.name not in existing_fields]

         if not new_fields:
             # All fields already present, skip.
             return None

-        null_arrays = [
-            pa.nulls(len(arrow_table), type=field.type) for field in new_fields
-        ]
+        null_arrays = [pa.nulls(len(arrow_table), type=field.type) for field in new_fields]

         for field, null_array in zip(new_fields, null_arrays):
             arrow_table = arrow_table.append_column(field, null_array)

         try:
-            return self.db_client.create_table(
-                table_name, arrow_table, mode="overwrite"
-            )
+            return self.db_client.create_table(table_name, arrow_table, mode="overwrite")
         except OSError:
             # Error occurred while creating the table, skip.
             return None

     def _execute_schema_update(self, only_tables: Iterable[str]) -> None:
         for table_name in only_tables or self.schema.tables:
             exists, existing_columns = self.get_storage_table(table_name)
-            new_columns = self.schema.get_new_table_columns(
-                table_name, existing_columns
-            )
+            new_columns = self.schema.get_new_table_columns(table_name, existing_columns)
             embedding_fields: List[str] = get_columns_names_with_prop(
                 self.schema.get_table(table_name), VECTORIZE_HINT
             )
-            logger.info(
-                f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}"
-            )
+            logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}")
             if len(new_columns) > 0:
                 if exists:
                     field_schemas: List[TArrowField] = [
-                        make_arrow_field_schema(
-                            column["name"], column, self.type_mapper
-                        )
+                        make_arrow_field_schema(column["name"], column, self.type_mapper)
                         for column in new_columns
                     ]
                     fq_table_name = self.make_qualified_table_name(table_name)
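add_table_fields above evolves an existing table by appending an all-null column per new field and rewriting the table with mode="overwrite". A minimal sketch of the Arrow-level step, assuming pyarrow is installed; the column names and data are illustrative:

import pyarrow as pa

tbl = pa.table({"id": [1, 2], "text": ["a", "b"]})
new_field = pa.field("score", pa.float64())

# Size the null column to the existing table, as add_table_fields does.
tbl = tbl.append_column(new_field, pa.nulls(len(tbl), type=new_field.type))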
@@ -481,9 +457,7 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None:
                         vector_field_name = self.vector_field_name
                         id_field_name = self.id_field_name
                         embedding_model_func = self.model_func
-                        embedding_model_dimensions = (
-                            self.config.embedding_model_dimensions
-                        )
+                        embedding_model_dimensions = self.config.embedding_model_dimensions
                     else:
                         embedding_fields = None
                         vector_field_name = None
@@ -518,9 +492,7 @@ def update_schema_in_storage(self) -> None:
                 "schema": json.dumps(self.schema.to_dict()),
             }
         ]
-        fq_version_table_name = self.make_qualified_table_name(
-            self.schema.version_table_name
-        )
+        fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)
         write_disposition = self.schema.get_table(self.schema.version_table_name).get(
             "write_disposition"
         )
@@ -534,12 +506,8 @@ def update_schema_in_storage(self) -> None:
     @lancedb_error
     def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
         """Retrieves the latest completed state for a pipeline."""
-        fq_state_table_name = self.make_qualified_table_name(
-            self.schema.state_table_name
-        )
-        fq_loads_table_name = self.make_qualified_table_name(
-            self.schema.loads_table_name
-        )
+        fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name)
+        fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name)

         state_table_: Table = self.db_client.open_table(fq_state_table_name)
         state_table_.checkout_latest()
@@ -550,17 +518,11 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
         # Read the tables into memory as Arrow tables, with pushdown predicates, so we pull
         # as little data into memory as possible.
         state_table = (
-            state_table_
-            .search()
+            state_table_.search()
             .where(f"pipeline_name = '{pipeline_name}'", prefilter=True)
             .to_arrow()
         )
-        loads_table = (
-            loads_table_
-            .search()
-            .where("status = 0", prefilter=True)
-            .to_arrow()
-        )
+        loads_table = loads_table_.search().where("status = 0", prefilter=True).to_arrow()

         # Join arrow tables in-memory.
         joined_table: pa.Table = state_table.join(
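Both reads above push their filters down into LanceDB (prefilter=True) before materializing Arrow tables, so only matching rows reach memory. A minimal sketch of the same pattern against the docs table from the earlier merge_insert sketch; the filter expression is illustrative:

# Filter storage-side, then materialize only the matching rows as Arrow.
matches = (
    tbl.search()
    .where("text = 'updated'", prefilter=True)
    .to_arrow()
)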
@@ -576,54 +538,40 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
         return StateInfo(**{k: v for k, v in state.items() if k in StateInfo._fields})

     @lancedb_error
-    def get_stored_schema_by_hash(
-        self, schema_hash: str
-    ) -> Optional[StorageSchemaInfo]:
-        fq_version_table_name = self.make_qualified_table_name(
-            self.schema.version_table_name
-        )
+    def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]:
+        fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)

         version_table: Table = self.db_client.open_table(fq_version_table_name)
         version_table.checkout_latest()

         try:
             schemas = (
-                version_table
-                .search()
-                .where(f'version_hash = "{schema_hash}"', prefilter=True)
+                version_table.search().where(f'version_hash = "{schema_hash}"', prefilter=True)
             ).to_list()

             # LanceDB's ORDER BY clause doesn't seem to work.
             # See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341
-            most_recent_schema = sorted(
-                schemas, key=lambda x: x["inserted_at"], reverse=True
-            )[0]
+            most_recent_schema = sorted(schemas, key=lambda x: x["inserted_at"], reverse=True)[0]
             return StorageSchemaInfo(**most_recent_schema)
         except IndexError:
             return None

     @lancedb_error
     def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
         """Retrieves newest schema from destination storage."""
-        fq_version_table_name = self.make_qualified_table_name(
-            self.schema.version_table_name
-        )
+        fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)

         version_table: Table = self.db_client.open_table(fq_version_table_name)
         version_table.checkout_latest()

         try:
             schemas = (
-                version_table
-                .search()
-                .where(f'schema_name = "{self.schema.name}"', prefilter=True)
+                version_table.search().where(f'schema_name = "{self.schema.name}"', prefilter=True)
             ).to_list()

             # LanceDB's ORDER BY clause doesn't seem to work.
             # See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341
-            most_recent_schema = sorted(
-                schemas, key=lambda x: x["inserted_at"], reverse=True
-            )[0]
+            most_recent_schema = sorted(schemas, key=lambda x: x["inserted_at"], reverse=True)[0]
             return StorageSchemaInfo(**most_recent_schema)
         except IndexError:
             return None
@@ -650,9 +598,7 @@ def complete_load(self, load_id: str) -> None:
                 "schema_version_hash": None,  # Payload schema must match the target schema.
             }
         ]
-        fq_loads_table_name = self.make_qualified_table_name(
-            self.schema.loads_table_name
-        )
+        fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name)
         write_disposition = self.schema.get_table(self.schema.loads_table_name).get(
             "write_disposition"
         )
@@ -666,9 +612,7 @@ def complete_load(self, load_id: str) -> None:
     def restore_file_load(self, file_path: str) -> LoadJob:
         return EmptyLoadJob.from_file_path(file_path, "completed")

-    def start_file_load(
-        self, table: TTableSchema, file_path: str, load_id: str
-    ) -> LoadJob:
+    def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
         return LoadLanceDBJob(
             self.schema,
             table,
@@ -707,9 +651,7 @@ def __init__(
         self.table_name: str = table_schema["name"]
         self.fq_table_name: str = fq_table_name
         self.unique_identifiers: Sequence[str] = list_merge_identifiers(table_schema)
-        self.embedding_fields: List[str] = get_columns_names_with_prop(
-            table_schema, VECTORIZE_HINT
-        )
+        self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT)
         self.embedding_model_func: TextEmbeddingFunction = model_func
         self.embedding_model_dimensions: int = client_config.embedding_model_dimensions
         self.id_field_name: str = client_config.id_field_name
