Skip to content

Commit

Permalink
Link up __search_by_ANN to similarity_search_by_vector
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Jan 31, 2025
1 parent 039b1bd commit 8be267d
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 12 deletions.
9 changes: 5 additions & 4 deletions samples/search_ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,11 @@ def clear_and_insert_data(tx):
),
skip_not_nullable_columns=True,
)
results = vec_store.search_by_ANN(
"ProductDescriptionEmbeddingIndex",
1000,
limit=20,
results = vec_store.similarity_search_by_vector(
embedding=embeddings,
index_name="ProductDescriptionEmbeddingIndex",
num_leaves=1000,
k=20,
embedding_column_is_nullable=True,
return_columns=["productName", "productDescription", "inventoryCount"],
)
Expand Down
30 changes: 22 additions & 8 deletions src/langchain_google_spanner/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]:
return dict(zip(columns, values))

def getIndexDistanceType(self, distance_strategy) -> str:
value = _GOOGLE_ALGO_INDEX_NAME.get(distance_strategy, None)
value = GOOGLE_DIALECT_TO_ANN_DISTANCE_FUNCTIONS.get(distance_strategy, None)
if value is None:
raise Exception(f"{distance_strategy} is unsupported for distance_type")
return value
Expand Down Expand Up @@ -1039,10 +1039,7 @@ def similarity_search_with_score_by_vector(
)
return documents

def set_strategy(strategy: DistanceStrategy):
self.__strategy = strategy

def search_by_ANN(
def __search_by_ANN(
self,
index_name: str,
num_leaves: int,
Expand Down Expand Up @@ -1301,9 +1298,26 @@ def similarity_search_by_vector(
Returns:
List[Document]: List of documents most similar to the query.
"""
documents = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, pre_filter=pre_filter
)
documents: List[Document] = None
if (
self._query_parameters.algorithm
== QueryParameters.NearestNeighborsAlgorithm.APPROXIMATE_NEAREST_NEIGHBOR
):
documents = self.__search_by_ANN(
index_name=kwargs.get("index_name", None),
num_leaves=kwargs.get("num_leaves", 1000),
limit=k,
embedding=embedding,
is_embedding_nullable=kwargs.get("is_embedding_nullable", False),
where_condition=kwargs.get("where_condition", ""),
ascending=kwargs.get("ascending", True),
return_columns=kwargs.get("return_columns", []),
)
else:
documents = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, pre_filter=pre_filter
)

return [doc for doc, _ in documents]

def max_marginal_relevance_search_with_score_by_vector(
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_spanner_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,7 @@ def test_ann_add_data1(self, setup_database):
embedding_service=embeddings,
metadata_json_column="metadata",
)
_ = db


class TestSpannerVectorStorePGSQL:
Expand Down

0 comments on commit 8be267d

Please sign in to comment.