diff --git a/samples/main.py b/samples/main.py index a0bf21e..d879d76 100644 --- a/samples/main.py +++ b/samples/main.py @@ -18,131 +18,138 @@ VectorSearchIndex, ) -project_id = 'quip-441723' -instance_id = 'contracting' -google_database = 'ann' -zone = os.environ.get("GOOGLE_DATABASE_ZONE", "us-west2") +project_id = os.environ['PROJECT_ID'] +instance_id = os.environ['INSTANCE_ID'] +google_database = os.environ['GOOGLE_DATABASE'] +zone = os.environ["GOOGLE_DATABASE_ZONE"] table_name_ANN = "products" OPERATION_TIMEOUT_SECONDS = 240 def use_case(): - # Initialize the vector store table if necessary. - distance_strategy = DistanceStrategy.COSINE - SpannerVectorStore.init_vector_store_table( - instance_id=instance_id, - database_id=google_database, - table_name=table_name_ANN, - id_column=TableColumn("productId", type="INT64"), - vector_size=758, - embedding_column=TableColumn( - name="productDescriptionEmbedding", - type="ARRAY", - is_null=True, + # Initialize the vector store table if necessary. + distance_strategy = DistanceStrategy.COSINE + model_vector_size = 758 + SpannerVectorStore.init_vector_store_table( + instance_id=instance_id, + database_id=google_database, + table_name=table_name_ANN, + id_column=TableColumn("productId", type="INT64"), + vector_size=model_vector_size, + embedding_column=TableColumn( + name="productDescriptionEmbedding", + type="ARRAY", + is_null=True, + ), + metadata_columns=[ + TableColumn(name="categoryId", type="INT64", is_null=False), + TableColumn(name="productName", type="STRING(MAX)", is_null=False), + TableColumn( + name="productDescription", type="STRING(MAX)", is_null=False ), - metadata_columns=[ - TableColumn(name="categoryId", type="INT64", is_null=False), - TableColumn(name="productName", type="STRING(MAX)", is_null=False), - TableColumn( - name="productDescription", type="STRING(MAX)", is_null=False - ), - TableColumn(name="inventoryCount", type="INT64", is_null=False), - TableColumn(name="priceInCents", type="INT64", is_null=True), - TableColumn(name="createTime", type="TIMESTAMP", is_null=False), - ], - secondary_indexes=[ - VectorSearchIndex( - index_name="ProductDescriptionEmbeddingIndex", - columns=["productDescriptionEmbedding"], - nullable_column=True, - num_branches=1000, - tree_depth=3, - index_type=distance_strategy, - num_leaves=100000, - ), + TableColumn(name="inventoryCount", type="INT64", is_null=False), + TableColumn(name="priceInCents", type="INT64", is_null=True), + TableColumn(name="createTime", type="TIMESTAMP", is_null=False), + ], + secondary_indexes=[ + VectorSearchIndex( + index_name="ProductDescriptionEmbeddingIndex", + columns=["productDescriptionEmbedding"], + nullable_column=True, + num_branches=1000, + tree_depth=3, + index_type=distance_strategy, + num_leaves=100000, + ), + ], + ) + + # Create the models if necessary. + client = Client(project=project_id) + database = client.instance(instance_id).database(google_database) + + model_ddl_statements = [ + f""" + CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT( + content STRING(MAX), + ) OUTPUT( + embeddings STRUCT, values ARRAY>, + ) REMOTE OPTIONS ( + endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/text-embedding-004' + ) + """, + f""" + CREATE MODEL IF NOT EXISTS LLMModel INPUT( + prompt STRING(MAX), + ) OUTPUT( + content STRING(MAX), + ) REMOTE OPTIONS ( + endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/gemini-pro', + default_batch_size = 1 + ) + """, + ] + operation = database.update_ddl(model_ddl_statements) + operation.result(OPERATION_TIMEOUT_SECONDS) + + def clear_and_insert_data(tx): + tx.execute_update("DELETE FROM products WHERE 1=1") + tx.insert( + 'products', + columns=[ + 'categoryId', 'productId', 'productName', + 'productDescription', + 'createTime', 'inventoryCount', 'priceInCents', ], + values=raw_data, ) - # Create the models if necessary. - client = Client(project=project_id) - database = client.instance(instance_id).database(google_database) - - model_ddl_statements = [ - f""" - CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT( - content STRING(MAX), - ) OUTPUT( - embeddings STRUCT, values ARRAY>, - ) REMOTE OPTIONS ( - endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/text-embedding-004' - ) - """, - f""" - CREATE MODEL IF NOT EXISTS LLMModel INPUT( - prompt STRING(MAX), - ) OUTPUT( - content STRING(MAX), - ) REMOTE OPTIONS ( - endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/gemini-pro', - default_batch_size = 1 - ) - """, - ] - operation = database.update_ddl(model_ddl_statements) - operation.result(OPERATION_TIMEOUT_SECONDS) - - def clear_and_insert_data(tx): - tx.execute_update("DELETE FROM products WHERE 1=1") - tx.insert( - 'products', - columns=[ - 'categoryId', 'productId', 'productName', - 'productDescription', - 'createTime', 'inventoryCount', 'priceInCents', - ], - values=raw_data, - ) - - tx.execute_update( - """UPDATE products p1 - SET productDescriptionEmbedding = - (SELECT embeddings.values from ML.PREDICT(MODEL EmbeddingsModel, - (SELECT productDescription as content FROM products p2 where p2.productId=p1.productId))) - WHERE categoryId=1""", - ) - - embeddings = [] - rows = tx.execute_sql( - """SELECT embeddings.values - FROM ML.PREDICT( - MODEL EmbeddingsModel, - (SELECT "I'd like to buy a starter bike for my 3 year old child" as content) - )""") - - for row in rows: - for nesting in row: - embeddings.extend(nesting) - - return embeddings - - embeddings = database.run_in_transaction(clear_and_insert_data) - - vec_store = SpannerVectorStore( - instance_id=instance_id, - database_id=google_database, - table_name=table_name_ANN, - id_column="categoryId", - embedding_service=embeddings, - embedding_column="productDescriptionEmbedding", - skip_not_nullable_columns=True, - ) - vec_store.search_by_ANN( - 'ProductDescriptionEmbeddingIndex', - 1000, - k=20, - embedding_column_is_nullable=True, - return_columns=['productName', 'productDescription', 'inventoryCount'], + tx.execute_update( + """UPDATE products p1 + SET productDescriptionEmbedding = + (SELECT embeddings.values from ML.PREDICT(MODEL EmbeddingsModel, + (SELECT productDescription as content FROM products p2 where p2.productId=p1.productId))) + WHERE categoryId=1""", ) + embeddings = [] + rows = tx.execute_sql( + """SELECT embeddings.values + FROM ML.PREDICT( + MODEL EmbeddingsModel, + (SELECT "I'd like to buy a starter bike for my 3 year old child" as content) + )""") + + for row in rows: + for nesting in row: + embeddings.extend(nesting) + + return embeddings + + embeddings = database.run_in_transaction(clear_and_insert_data) + if len(embeddings) > model_vector_size: + embeddings = embeddings[:model_vector_size] + + vec_store = SpannerVectorStore( + instance_id=instance_id, + database_id=google_database, + table_name=table_name_ANN, + id_column="categoryId", + embedding_service=embeddings, + embedding_column="productDescriptionEmbedding", + skip_not_nullable_columns=True, + ) + results = vec_store.search_by_ANN( + 'ProductDescriptionEmbeddingIndex', + 1000, + limit=20, + embedding_column_is_nullable=True, + return_columns=['productName', 'productDescription', 'inventoryCount'], + ) + + print("Search by ANN results") + for res in results: + print(res) + def main(): use_case() diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index 0e3b59e..5bf502b 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -225,7 +225,7 @@ def __clean_element(self, element: dict[str, Any], embedding_column: str) -> Non del element["properties"][embedding_column] def __get_distance_function( - self, distance_strategy=DistanceStrategy.EUCLIDEIAN + self, distance_strategy=DistanceStrategy.EUCLIDEAN ) -> str: """Gets the vector distance function.""" if distance_strategy == DistanceStrategy.COSINE: diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index c5ac979..1ba0692 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -1048,8 +1048,8 @@ def search_by_ANN( self, index_name: str, num_leaves: int, + limit: int, embedding: List[float] = None, - k: int = None, is_embedding_nullable: bool = False, where_condition: str = None, embedding_column_is_nullable: bool = False, @@ -1064,7 +1064,7 @@ def search_by_ANN( embedding or self._embedding_service, num_leaves, strategy, - k, + limit, is_embedding_nullable, where_condition, embedding_column_is_nullable=embedding_column_is_nullable, @@ -1075,6 +1075,7 @@ def search_by_ANN( with self._database.snapshot( **staleness if staleness is not None else {} ) as snapshot: + print("search by ANN sql", sql) results = snapshot.execute_sql(sql=sql) return list(results)