Skip to content

Commit

Permalink
test(ANN): add integration tests and retrofit
Browse files Browse the repository at this point in the history
Includes the respective adjustments needed to ensure that ANN
works with the integration tests.
  • Loading branch information
odeke-em committed Feb 4, 2025
1 parent 5a25f91 commit 4f75597
Show file tree
Hide file tree
Showing 2 changed files with 309 additions and 7 deletions.
28 changes: 24 additions & 4 deletions src/langchain_google_spanner/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ def __init__(
embedding_service: Embeddings,
id_column: str = ID_COLUMN_NAME,
content_column: str = CONTENT_COLUMN_NAME,
embedding_column: str = EMBEDDING_COLUMN_NAME,
embedding_column: Optional[str | TableColumn] = None,
client: Optional[spanner.Client] = None,
metadata_columns: Optional[List[str]] = None,
ignore_metadata_columns: Optional[List[str]] = None,
Expand Down Expand Up @@ -667,7 +667,14 @@ def __init__(
self._client = client_with_user_agent(client, USER_AGENT_VECTOR_STORE)
self._id_column = id_column
self._content_column = content_column
if embedding_column is None:
embedding_column = EMBEDDING_COLUMN_NAME
self._embedding_column = embedding_column
self._embedding_column_type = None
if isinstance(embedding_column, TableColumn):
self._embedding_column_type = embedding_column.type
self._embedding_column = embedding_column.name
embedding_column = embedding_column.name
self._metadata_json_column = metadata_json_column

self._query_parameters = query_parameters
Expand Down Expand Up @@ -1057,14 +1064,17 @@ def _get_rows_by_similarity_search_ann(
k,
self._query_parameters.distance_strategy,
pre_filter=pre_filter,
embedding_column_type=self._embedding_column_type,
embedding_column_is_nullable=embedding_column_is_nullable,
ascending=ascending,
return_columns=return_columns or self._columns_to_insert,
)
print("ANN sql", sql)
staleness = self._query_parameters.staleness
with self._database.snapshot(
**staleness if staleness is not None else {}
) as snapshot:
param_types = {}
results = snapshot.execute_sql(sql=sql)
column_order_map = {
value: index for index, value in enumerate(self._columns_to_insert)
Expand All @@ -1081,13 +1091,17 @@ def _generate_sql_for_ANN(
k: int,
strategy: DistanceStrategy = DistanceStrategy.COSINE,
pre_filter: Optional[str] = None,
embedding_column_type: str = "ARRAY<FLOAT32>",
embedding_column_is_nullable: bool = False,
ascending: bool = True,
return_columns: Optional[List[str]] = None,
) -> str:
if not embedding_column_name:
raise Exception("embedding_column_name must be set")

if not index_name:
raise Exception("index_name must be set")

ann_strategy_name = GOOGLE_DIALECT_TO_ANN_DISTANCE_FUNCTIONS.get(strategy, None)
if not ann_strategy_name:
raise Exception(f"{strategy} is not supported for ANN")
Expand All @@ -1112,7 +1126,7 @@ def _generate_sql_for_ANN(
+ "\n"
)
+ f"ORDER BY {ann_strategy_name}(\n"
+ f" ARRAY<FLOAT32>{embedding}, {embedding_column_name}, options => JSON '"
+ f" {embedding_column_type}{embedding}, {embedding_column_name}, options => JSON '"
+ '{"num_leaves_to_search": %s}\')%s\n'
% (num_leaves, "" if ascending else " DESC")
)
Expand Down Expand Up @@ -1221,7 +1235,10 @@ def similarity_search(
"""
embedding = self._embedding_service.embed_query(query)
documents = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, pre_filter=pre_filter
embedding=embedding,
k=k,
pre_filter=pre_filter,
**kwargs,
)
return [doc for doc, _ in documents]

Expand All @@ -1245,7 +1262,10 @@ def similarity_search_with_score(
"""
embedding = self._embedding_service.embed_query(query)
documents = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, pre_filter=pre_filter
embedding=embedding,
k=k,
pre_filter=pre_filter,
**kwargs,
)
return documents

Expand Down
Loading

0 comments on commit 4f75597

Please sign in to comment.