diff --git a/samples/search_ann.py b/samples/search_ann.py
index d1439b4..8cc6d8b 100644
--- a/samples/search_ann.py
+++ b/samples/search_ann.py
@@ -66,69 +66,9 @@ def use_case():
     client = Client(project=project_id)
     database = client.instance(instance_id).database(google_database)
 
-    model_ddl_statements = [
-        f"""
-        CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT(
-            content STRING(MAX),
-        ) OUTPUT(
-            embeddings STRUCT<statistics STRUCT<truncated BOOL, token_count FLOAT64>, values ARRAY<FLOAT64>>,
-        ) REMOTE OPTIONS (
-            endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/text-embedding-004'
-        )
-        """,
-        f"""
-        CREATE MODEL IF NOT EXISTS LLMModel INPUT(
-            prompt STRING(MAX),
-        ) OUTPUT(
-            content STRING(MAX),
-        ) REMOTE OPTIONS (
-            endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/gemini-pro',
-            default_batch_size = 1
-        )
-        """,
-    ]
     operation = database.update_ddl(model_ddl_statements)
     operation.result(OPERATION_TIMEOUT_SECONDS)
 
-    def clear_and_insert_data(tx):
-        tx.execute_update("DELETE FROM products WHERE 1=1")
-        tx.insert(
-            "products",
-            columns=[
-                "categoryId",
-                "productId",
-                "productName",
-                "productDescription",
-                "createTime",
-                "inventoryCount",
-                "priceInCents",
-            ],
-            values=raw_data,
-        )
-
-        tx.execute_update(
-            """UPDATE products p1
-            SET productDescriptionEmbedding =
-            (SELECT embeddings.values from ML.PREDICT(MODEL EmbeddingsModel,
-            (SELECT productDescription as content FROM products p2 where p2.productId=p1.productId)))
-            WHERE categoryId=1""",
-        )
-
-        embeddings = []
-        rows = tx.execute_sql(
-            """SELECT embeddings.values
-            FROM ML.PREDICT(
-                MODEL EmbeddingsModel,
-                (SELECT "I'd like to buy a starter bike for my 3 year old child" as content)
-            )"""
-        )
-
-        for row in rows:
-            for nesting in row:
-                embeddings.extend(nesting)
-
-        return embeddings
-
     embeddings = database.run_in_transaction(clear_and_insert_data)
     if len(embeddings) > model_vector_size:
         embeddings = embeddings[:model_vector_size]
@@ -271,6 +211,68 @@ def PENDING_COMMIT_TIMESTAMP():
     "priceInCents",
 ]
 
+model_ddl_statements = [
+    f"""
+    CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT(
+        content STRING(MAX),
+    ) OUTPUT(
+        embeddings STRUCT<statistics STRUCT<truncated BOOL, token_count FLOAT64>, values ARRAY<FLOAT64>>,
+    ) REMOTE OPTIONS (
+        endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/text-embedding-004'
+    )
+    """,
+    f"""
+    CREATE MODEL IF NOT EXISTS LLMModel INPUT(
+        prompt STRING(MAX),
+    ) OUTPUT(
+        content STRING(MAX),
+    ) REMOTE OPTIONS (
+        endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/gemini-pro',
+        default_batch_size = 1
+    )
+    """,
+]
+
+
+def clear_and_insert_data(tx):
+    tx.execute_update("DELETE FROM products WHERE 1=1")
+    tx.insert(
+        "products",
+        columns=[
+            "categoryId",
+            "productId",
+            "productName",
+            "productDescription",
+            "createTime",
+            "inventoryCount",
+            "priceInCents",
+        ],
+        values=raw_data,
+    )
+
+    tx.execute_update(
+        """UPDATE products p1
+        SET productDescriptionEmbedding =
+        (SELECT embeddings.values from ML.PREDICT(MODEL EmbeddingsModel,
+        (SELECT productDescription as content FROM products p2 where p2.productId=p1.productId)))
+        WHERE categoryId=1""",
+    )
+
+    embeddings = []
+    rows = tx.execute_sql(
+        """SELECT embeddings.values
+        FROM ML.PREDICT(
+            MODEL EmbeddingsModel,
+            (SELECT "I'd like to buy a starter bike for my 3 year old child" as content)
+        )"""
+    )
+
+    for row in rows:
+        for nesting in row:
+            embeddings.extend(nesting)
+
+    return embeddings
+
 
 if __name__ == "__main__":
     main()
diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py
index e631c1d..20679f4 100644
--- a/src/langchain_google_spanner/vector_store.py
+++ b/src/langchain_google_spanner/vector_store.py
@@ -1043,10 +1043,10 @@ def __search_by_ANN(
         self,
         index_name: str,
         num_leaves: int,
-        limit: int,
+        k: int,  # Defines the limit
         embedding: List[float] = None,
         is_embedding_nullable: bool = False,
-        where_condition: str = None,
+        pre_filter: str = None,
         embedding_column_is_nullable: bool = False,
         ascending: bool = True,
         return_columns: List[str] = None,
@@ -1057,10 +1057,10 @@ def __search_by_ANN(
             self._embedding_column,
             embedding or self._embedding_service,
             num_leaves,
-            limit,
+            k,
             self._query_parameters.distance_strategy,
             is_embedding_nullable,
-            where_condition,
+            pre_filter=pre_filter,
             embedding_column_is_nullable=embedding_column_is_nullable,
             ascending=ascending,
             return_columns=return_columns,
@@ -1085,12 +1085,13 @@ def _query_ANN(
         embedding_column_name: str,
         embedding: List[float],
         num_leaves: int,
-        limit: int,
+        k: int,
         strategy: DistanceStrategy = DistanceStrategy.COSINE,
         is_embedding_nullable: bool = False,
-        where_condition: str = None,
+        pre_filter: str = None,
         embedding_column_is_nullable: bool = False,
         ascending: bool = True,
+        post_filter: str = None,  # TODO(@odeke-em): Not yet supported
         return_columns: List[str] = None,
     ):
         """
@@ -1132,9 +1133,12 @@ def _query_ANN(
             + "@{FORCE_INDEX="
             + f"{index_name}"
             + (
-                "}\n"
+                ("}\nWHERE " + ("1=1" if not pre_filter else f"{pre_filter}") + "\n")
                 if (not embedding_column_is_nullable)
-                else "}\nWHERE " + f"{embedding_column_name} IS NOT NULL\n"
+                else "}\nWHERE "
+                + f"{embedding_column_name} IS NOT NULL"
+                + ("" if not pre_filter else f" AND {pre_filter}")
+                + "\n"
             )
             + f"ORDER BY {ann_strategy_name}(\n"
             + f" ARRAY{embedding}, {embedding_column_name}, options => JSON '"
@@ -1142,17 +1146,11 @@ def _query_ANN(
             % (num_leaves, "" if ascending else " DESC")
         )
 
-        if where_condition:
-            sql += " WHERE " + where_condition + "\n"
-
-        if limit:
-            sql += f"LIMIT {limit}"
+        if k:
+            sql += f"LIMIT {k}"
 
         return sql.strip()
 
-    def _get_rows_by_similarity_search_ann():
-        pass
-
     def _get_rows_by_similarity_search_knn(
         self,
         embedding: List[float],
@@ -1205,6 +1203,15 @@ def _get_rows_by_similarity_search_knn(
 
         return list(results), column_order_map
 
+    def _get_rows_by_similarity_search_ann(
+        self,
+        embedding: List[float],
+        k: int,
+        pre_filter: Optional[str] = None,
+        **kwargs: Any,
+    ):
+        raise RuntimeError("Unimplemented")
+
     def _get_documents_from_query_results(
         self, results: List[List], column_order_map: Dict[str, int]
     ) -> List[Tuple[Document, float]]:
@@ -1306,10 +1313,10 @@ def similarity_search_by_vector(
         documents = self.__search_by_ANN(
             index_name=kwargs.get("index_name", None),
             num_leaves=kwargs.get("num_leaves", 1000),
-            limit=k,
+            k=k,
             embedding=embedding,
             is_embedding_nullable=kwargs.get("is_embedding_nullable", False),
-            where_condition=kwargs.get("where_condition", ""),
+            pre_filter=kwargs.get("pre_filter", ""),
             ascending=kwargs.get("ascending", True),
             return_columns=kwargs.get("return_columns", []),
         )
diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py
index 64bf307..4064ea2 100644
--- a/tests/unit/test_vectore_store.py
+++ b/tests/unit/test_vectore_store.py
@@ -216,6 +216,7 @@ def test_query_ANN(self):
 
         want = (
            "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
+            + "WHERE 1=1\n"
             + "ORDER BY APPROX_COSINE_DISTANCE(\n"
             + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
             + "'{\"num_leaves_to_search\": 10}')\n"
@@ -286,6 +287,7 @@ def test_query_ANN_order_DESC(self):
 
         want = (
            "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
+            + "WHERE 1=1\n"
             + "ORDER BY APPROX_COSINE_DISTANCE(\n"
             + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
             + "'{\"num_leaves_to_search\": 10}') DESC\n"
@@ -294,7 +296,7 @@
 
         assert got == want
 
-    def test_query_ANN_unspecified_limit(self):
+    def test_query_ANN_specified_limit(self):
         got = SpannerVectorStore._query_ANN(
             "Documents",
             "DocEmbeddingIndex",
@@ -308,6 +310,81 @@
 
         want = (
            "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
+            + "WHERE 1=1\n"
+            + "ORDER BY APPROX_COSINE_DISTANCE(\n"
+            + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
+            + "'{\"num_leaves_to_search\": 10}')\n"
+            + "LIMIT 100"
+        )
+
+        assert got == want
+
+    def test_query_ANN_specified_pre_filter(self):
+        got = SpannerVectorStore._query_ANN(
+            "Documents",
+            "DocEmbeddingIndex",
+            "DocEmbedding",
+            [1.0, 2.0, 3.0],
+            10,
+            100,
+            DistanceStrategy.COSINE,
+            return_columns=["DocId"],
+            pre_filter="categoryId!=20",
+        )
+
+        want = (
+            "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
+            + "WHERE categoryId!=20\n"
+            + "ORDER BY APPROX_COSINE_DISTANCE(\n"
+            + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
+            + "'{\"num_leaves_to_search\": 10}')\n"
+            + "LIMIT 100"
+        )
+
+        assert got == want
+
+    def test_query_ANN_specified_pre_filter_with_nullable_column(self):
+        got = SpannerVectorStore._query_ANN(
+            "Documents",
+            "DocEmbeddingIndex",
+            "DocEmbedding",
+            [1.0, 2.0, 3.0],
+            10,
+            100,
+            DistanceStrategy.COSINE,
+            return_columns=["DocId"],
+            pre_filter="categoryId!=9",
+            embedding_column_is_nullable=True,
+        )
+
+        want = (
+            "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
+            + "WHERE DocEmbedding IS NOT NULL AND categoryId!=9\n"
+            + "ORDER BY APPROX_COSINE_DISTANCE(\n"
+            + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
+            + "'{\"num_leaves_to_search\": 10}')\n"
+            + "LIMIT 100"
+        )
+
+        assert got == want
+
+    def test_query_ANN_no_pre_filter_non_nullable(self):
+        got = SpannerVectorStore._query_ANN(
+            "Documents",
+            "DocEmbeddingIndex",
+            "DocEmbedding",
+            [1.0, 2.0, 3.0],
+            10,
+            100,
+            DistanceStrategy.COSINE,
+            embedding_column_is_nullable=True,
+            return_columns=["DocId"],
+            pre_filter="DocId!=2",
+        )
+
+        want = (
+            "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
+            + "WHERE DocEmbedding IS NOT NULL AND DocId!=2\n"
             + "ORDER BY APPROX_COSINE_DISTANCE(\n"
             + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
             + "'{\"num_leaves_to_search\": 10}')\n"