diff --git a/README.rst b/README.rst index f09c809..a379847 100644 --- a/README.rst +++ b/README.rst @@ -210,4 +210,5 @@ This is not an officially supported Google product. Limitations ---------- -* Approximate Nearest Neighbors (ANN) strategies are only support for the GoogleSQL dialect +* Approximate Nearest Neighbors (ANN) strategies are only supported for the GoogleSQL dialect +* ANN's `ALTER VECTOR INDEX` is not supported by [Google Cloud Spanner](https://cloud.google.com/spanner/docs/find-approximate-nearest-neighbors#limitations) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 1b46b90..de5df68 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -88,6 +88,7 @@ class SecondaryIndex: columns: list[str] storing_columns: Optional[list[str]] = None num_leaves: Optional[int] = None # Only necessary for ANN + nullable_column: Optional[bool] = False # Only necessary for ANN num_branches: Optional[int] = None # Only necessary for ANN tree_depth: Optional[int] = None # Only necessary for ANN index_type: Optional[DistanceStrategy] = None # Only necessary for ANN @@ -551,7 +552,10 @@ def _generate_secondary_indices_ddl_ANN( secondary_index_ddl_statements = [] for secondary_index in secondary_indexes: - statement = f"CREATE VECTOR INDEX {secondary_index.index_name}\n\tON {table_name}({secondary_index.columns[0]})" + column_name = secondary_index.columns[0] + statement = f"CREATE VECTOR INDEX {secondary_index.index_name}\n\tON {table_name}({column_name})" + if secondary_index.nullable_column: + statement += f"\n\tWHERE {column_name} IS NOT NULL" options_segments = [f"distance_type='{secondary_index.index_type}'"] if secondary_index.tree_depth > 0: tree_depth = secondary_index.tree_depth @@ -983,6 +987,7 @@ def search_by_ANN( limit: int = None, is_embedding_nullable: bool = False, where_condition: str = None, + column_is_nullable: bool = False, ) -> List[Any]: sql = SpannerVectorStore._query_ANN( column_name, @@ -995,6 +1000,7 @@ def search_by_ANN( limit, is_embedding_nullable, where_condition, + column_is_nullable=column_is_nullable, ) staleness = self._query_parameters.staleness with self._database.snapshot( @@ -1017,6 +1023,7 @@ def _query_ANN( limit: int = None, is_embedding_nullable: bool = False, where_condition: str = None, + column_is_nullable: bool = False, ): """ Sample query: @@ -1026,6 +1033,16 @@ def _query_ANN( ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON '{"num_leaves_to_search": 10}') LIMIT 100 + + OR + + SELECT DocId + FROM Documents@{FORCE_INDEX=DocEmbeddingIndex} + WHERE NullableDocEmbedding IS NOT NULL + ORDER BY APPROX_EUCLIDEAN_DISTANCE( + ARRAY[1.0, 2.0, 3.0], NullableDocEmbedding, + options => JSON '{"num_leaves_to_search": 10}') + LIMIT 100 """ ann_strategy_name = distance_strategy_to_ANN_function.get(strategy, None) @@ -1036,8 +1053,12 @@ def _query_ANN( f"SELECT {column_name} FROM {table_name}" + "@{FORCE_INDEX=" + f"{index_name}" - + "}\n" - + f" ORDER BY {ann_strategy_name}(\n" + + ( + "}\n" + if (not column_is_nullable) + else "}\nWHERE " + f"{embedding_column_name} IS NOT NULL\n" + ) + + f"ORDER BY {ann_strategy_name}(\n" + f" ARRAY{embedding}, {embedding_column_name}, options => JSON '" + '{"num_leaves_to_search": %s})\n' % (num_leaves) ) diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py index b4bb433..e791537 100644 --- a/tests/unit/test_vectore_store.py +++ b/tests/unit/test_vectore_store.py @@ -93,28 +93,38 @@ def test_generate_secondary_indices_ddl_ANN(self): DistanceStrategy.EUCLIDEIAN, ] + nullables = [True, False] for distance_strategy in strategies: - got = SpannerVectorStore._generate_secondary_indices_ddl_ANN( - "Documents", - secondary_indexes=[ - SecondaryIndex( - index_name="DocEmbeddingIndex", - columns=["DocEmbedding"], - num_branches=1000, - tree_depth=3, - index_type=distance_strategy, - num_leaves=100000, - ) - ], - ) - - want = [ - "CREATE VECTOR INDEX DocEmbeddingIndex\n" - + " ON Documents(DocEmbedding)\n" - + f" OPTIONS(distance_type='{distance_strategy}', tree_depth=3, num_branches=1000, num_leaves=100000)" - ] - - assert canonicalize(got) == canonicalize(want) + for nullable in nullables: + got = SpannerVectorStore._generate_secondary_indices_ddl_ANN( + "Documents", + secondary_indexes=[ + SecondaryIndex( + index_name="DocEmbeddingIndex", + columns=["DocEmbedding"], + nullable_column=nullable, + num_branches=1000, + tree_depth=3, + index_type=distance_strategy, + num_leaves=100000, + ) + ], + ) + + want = [ + "CREATE VECTOR INDEX DocEmbeddingIndex\n" + + " ON Documents(DocEmbedding)\n" + + " WHERE DocEmbedding IS NOT NULL\n" + + f" OPTIONS(distance_type='{distance_strategy}', tree_depth=3, num_branches=1000, num_leaves=100000)" + ] + if not nullable: + want = [ + "CREATE VECTOR INDEX DocEmbeddingIndex\n" + + " ON Documents(DocEmbedding)\n" + + f" OPTIONS(distance_type='{distance_strategy}', tree_depth=3, num_branches=1000, num_leaves=100000)" + ] + + assert canonicalize(got) == canonicalize(want) def test_generate_secondary_indices_ddl_ANN_raises_exception_for_non_GoogleSQL_dialect( self, @@ -206,13 +216,35 @@ def test_query_ANN(self): want = ( "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" - + " ORDER BY APPROX_COSINE_DISTANCE(\n" + + "ORDER BY APPROX_COSINE_DISTANCE(\n" + ' ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON \'{"num_leaves_to_search": 10})\n' + "LIMIT 100" ) print("got", got) - print("want", want) + assert got == want + + def test_query_ANN_column_is_nullable(self): + got = SpannerVectorStore._query_ANN( + "DocId", + "Documents", + "DocEmbeddingIndex", + [1.0, 2.0, 3.0], + "DocEmbedding", + 10, + DistanceStrategy.COSINE, + limit=100, + column_is_nullable=True, + ) + + want = ( + "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + + "WHERE DocEmbedding IS NOT NULL\n" + + "ORDER BY APPROX_COSINE_DISTANCE(\n" + + ' ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON \'{"num_leaves_to_search": 10})\n' + + "LIMIT 100" + ) + assert got == want