From 8be267da7f2226233c126ff71da1fcf0526154c7 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Fri, 31 Jan 2025 09:42:20 +0200 Subject: [PATCH] Link up __search_by_ANN to similarity_search_by_vector --- samples/search_ann.py | 9 +++--- src/langchain_google_spanner/vector_store.py | 30 ++++++++++++++----- .../integration/test_spanner_vector_store.py | 1 + 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/samples/search_ann.py b/samples/search_ann.py index ff7fb2b..d1439b4 100644 --- a/samples/search_ann.py +++ b/samples/search_ann.py @@ -146,10 +146,11 @@ def clear_and_insert_data(tx): ), skip_not_nullable_columns=True, ) - results = vec_store.search_by_ANN( - "ProductDescriptionEmbeddingIndex", - 1000, - limit=20, + results = vec_store.similarity_search_by_vector( + embedding=embeddings, + index_name="ProductDescriptionEmbeddingIndex", + num_leaves=1000, + k=20, embedding_column_is_nullable=True, return_columns=["productName", "productDescription", "inventoryCount"], ) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 5c60fdd..e631c1d 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -225,7 +225,7 @@ def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]: return dict(zip(columns, values)) def getIndexDistanceType(self, distance_strategy) -> str: - value = _GOOGLE_ALGO_INDEX_NAME.get(distance_strategy, None) + value = GOOGLE_DIALECT_TO_ANN_DISTANCE_FUNCTIONS.get(distance_strategy, None) if value is None: raise Exception(f"{distance_strategy} is unsupported for distance_type") return value @@ -1039,10 +1039,7 @@ def similarity_search_with_score_by_vector( ) return documents - def set_strategy(strategy: DistanceStrategy): - self.__strategy = strategy - - def search_by_ANN( + def __search_by_ANN( self, index_name: str, num_leaves: int, @@ -1301,9 +1298,26 @@ def similarity_search_by_vector( Returns: List[Document]: List of documents most similar to the query. """ - documents = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, pre_filter=pre_filter - ) + documents: List[Document] = None + if ( + self._query_parameters.algorithm + == QueryParameters.NearestNeighborsAlgorithm.APPROXIMATE_NEAREST_NEIGHBOR + ): + documents = self.__search_by_ANN( + index_name=kwargs.get("index_name", None), + num_leaves=kwargs.get("num_leaves", 1000), + limit=k, + embedding=embedding, + is_embedding_nullable=kwargs.get("is_embedding_nullable", False), + where_condition=kwargs.get("where_condition", ""), + ascending=kwargs.get("ascending", True), + return_columns=kwargs.get("return_columns", []), + ) + else: + documents = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k, pre_filter=pre_filter + ) + return [doc for doc, _ in documents] def max_marginal_relevance_search_with_score_by_vector( diff --git a/tests/integration/test_spanner_vector_store.py b/tests/integration/test_spanner_vector_store.py index 692916b..1fa0097 100644 --- a/tests/integration/test_spanner_vector_store.py +++ b/tests/integration/test_spanner_vector_store.py @@ -618,6 +618,7 @@ def test_ann_add_data1(self, setup_database): embedding_service=embeddings, metadata_json_column="metadata", ) + _ = db class TestSpannerVectorStorePGSQL: