Skip to content

Commit

Permalink
Remove anti-pattern of using dataclass which does not allow effective…
Browse files Browse the repository at this point in the history
… inheritance
  • Loading branch information
odeke-em committed Jan 28, 2025
1 parent 5a0be1b commit 605acd3
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 13 deletions.
1 change: 0 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def unit(session):


UNIT_TEST_STANDARD_DEPENDENCIES = [
"mock",
"pytest",
]
UNIT_TEST_DEPENDENCIES: List[str] = []
Expand Down
37 changes: 26 additions & 11 deletions src/langchain_google_spanner/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,16 @@ def __post_init__(self):
raise ValueError("vector_length must be >=1")


@dataclass
class SecondaryIndex:
index_name: str
columns: list[str]
storing_columns: Optional[list[str]] = None
def __init__(
self,
index_name: str,
columns: list[str],
storing_columns: Optional[list[str]] = None,
):
self.index_name = index_name
self.columns = columns
self.storing_columns = storing_columns

def __post_init__(self):
# Check if column_name is None after initialization
Expand All @@ -102,17 +107,27 @@ def __post_init__(self):
raise ValueError("Index Columns can't be None")


@dataclass
class VectorSearchIndex(SecondaryIndex):
"""
The index for use with Approximate Nearest Neighbor (ANN) vector search.
"""

num_leaves: int
num_branches: int
tree_depth: int
index_type: DistanceStrategy
nullable_column: bool = False
def __init__(
self,
num_leaves: int,
num_branches: int,
tree_depth: int,
index_type: DistanceStrategy,
nullable_column: bool = False,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.num_leaves = num_leaves
self.num_branches = num_branches
self.tree_depth = tree_depth
self.index_type = index_type
self.nullable_column = nullable_column

def __post_init__(self):
if self.index_name is None:
Expand Down Expand Up @@ -506,7 +521,7 @@ def _generate_create_table_sql(
)
else:
embedding_column = TableColumn(
embedding_column, "ARRAY<FLOAT64p>", is_null=True
embedding_column, "ARRAY<FLOAT64>", is_null=True
)

configs = [id_column, content_column, embedding_column]
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_vectore_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_generate_create_table_sql(self):
"id",
)
want = (
"CREATE TABLE users (\n id STRING(36),\n essays STRING(MAX),"
"CREATE TABLE IF NOT EXISTS users (\n id STRING(36),\n essays STRING(MAX),"
+ "\n science_scores ARRAY<FLOAT64>\n) PRIMARY KEY(id)"
)
assert got == want
Expand Down

0 comments on commit 605acd3

Please sign in to comment.