From a7916903cfc9127d2e51caf4a326f43da320e8d9 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Wed, 29 Jan 2025 23:03:15 +0200 Subject: [PATCH] Update based off use case carving --- requirements.txt | 4 +- samples/main.py | 180 ++++++++++++++++++ src/langchain_google_spanner/vector_store.py | 73 +++---- .../integration/test_spanner_vector_store.py | 2 +- 4 files changed, 214 insertions(+), 45 deletions(-) create mode 100644 samples/main.py diff --git a/requirements.txt b/requirements.txt index 9e16179..621d9c9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -google-cloud-spanner==3.49.1 -langchain-core==0.3.9 +google-cloud-spanner==3.51.0 +langchain-core==0.3.15 langchain-community==0.3.1 pydantic==2.9.1 diff --git a/samples/main.py b/samples/main.py new file mode 100644 index 0000000..a0bf21e --- /dev/null +++ b/samples/main.py @@ -0,0 +1,180 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union +import datetime +import os +import time +import uuid + +from google.cloud.spanner import Client # type: ignore +from langchain_community.document_loaders import HNLoader +from langchain_community.embeddings import FakeEmbeddings + +from langchain_google_spanner.vector_store import ( # type: ignore + DistanceStrategy, + QueryParameters, + SpannerVectorStore, + TableColumn, + VectorSearchIndex, +) + +project_id = 'quip-441723' +instance_id = 'contracting' +google_database = 'ann' +zone = os.environ.get("GOOGLE_DATABASE_ZONE", "us-west2") +table_name_ANN = "products" +OPERATION_TIMEOUT_SECONDS = 240 + +def use_case(): + # Initialize the vector store table if necessary. + distance_strategy = DistanceStrategy.COSINE + SpannerVectorStore.init_vector_store_table( + instance_id=instance_id, + database_id=google_database, + table_name=table_name_ANN, + id_column=TableColumn("productId", type="INT64"), + vector_size=758, + embedding_column=TableColumn( + name="productDescriptionEmbedding", + type="ARRAY", + is_null=True, + ), + metadata_columns=[ + TableColumn(name="categoryId", type="INT64", is_null=False), + TableColumn(name="productName", type="STRING(MAX)", is_null=False), + TableColumn( + name="productDescription", type="STRING(MAX)", is_null=False + ), + TableColumn(name="inventoryCount", type="INT64", is_null=False), + TableColumn(name="priceInCents", type="INT64", is_null=True), + TableColumn(name="createTime", type="TIMESTAMP", is_null=False), + ], + secondary_indexes=[ + VectorSearchIndex( + index_name="ProductDescriptionEmbeddingIndex", + columns=["productDescriptionEmbedding"], + nullable_column=True, + num_branches=1000, + tree_depth=3, + index_type=distance_strategy, + num_leaves=100000, + ), + ], + ) + + # Create the models if necessary. + client = Client(project=project_id) + database = client.instance(instance_id).database(google_database) + + model_ddl_statements = [ + f""" + CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT( + content STRING(MAX), + ) OUTPUT( + embeddings STRUCT, values ARRAY>, + ) REMOTE OPTIONS ( + endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/text-embedding-004' + ) + """, + f""" + CREATE MODEL IF NOT EXISTS LLMModel INPUT( + prompt STRING(MAX), + ) OUTPUT( + content STRING(MAX), + ) REMOTE OPTIONS ( + endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/gemini-pro', + default_batch_size = 1 + ) + """, + ] + operation = database.update_ddl(model_ddl_statements) + operation.result(OPERATION_TIMEOUT_SECONDS) + + def clear_and_insert_data(tx): + tx.execute_update("DELETE FROM products WHERE 1=1") + tx.insert( + 'products', + columns=[ + 'categoryId', 'productId', 'productName', + 'productDescription', + 'createTime', 'inventoryCount', 'priceInCents', + ], + values=raw_data, + ) + + tx.execute_update( + """UPDATE products p1 + SET productDescriptionEmbedding = + (SELECT embeddings.values from ML.PREDICT(MODEL EmbeddingsModel, + (SELECT productDescription as content FROM products p2 where p2.productId=p1.productId))) + WHERE categoryId=1""", + ) + + embeddings = [] + rows = tx.execute_sql( + """SELECT embeddings.values + FROM ML.PREDICT( + MODEL EmbeddingsModel, + (SELECT "I'd like to buy a starter bike for my 3 year old child" as content) + )""") + + for row in rows: + for nesting in row: + embeddings.extend(nesting) + + return embeddings + + embeddings = database.run_in_transaction(clear_and_insert_data) + + vec_store = SpannerVectorStore( + instance_id=instance_id, + database_id=google_database, + table_name=table_name_ANN, + id_column="categoryId", + embedding_service=embeddings, + embedding_column="productDescriptionEmbedding", + skip_not_nullable_columns=True, + ) + vec_store.search_by_ANN( + 'ProductDescriptionEmbeddingIndex', + 1000, + k=20, + embedding_column_is_nullable=True, + return_columns=['productName', 'productDescription', 'inventoryCount'], + ) + +def main(): + use_case() + + +def PENDING_COMMIT_TIMESTAMP(): + return (datetime.datetime.utcnow() + datetime.timedelta(days=1)).isoformat() + "Z" + return 'PENDING_COMMIT_TIMESTAMP()' + +raw_data = [ + (1, 1, "Cymbal Helios Helmet", "Safety meets style with the Cymbal children's bike helmet. Its lightweight design, superior ventilation, and adjustable fit ensure comfort and protection on every ride. Stay bright and keep your child safe under the sun with Cymbal Helios!", PENDING_COMMIT_TIMESTAMP(), 100, 10999), + (1, 2, "Cymbal Sprout", "Let their cycling journey begin with the Cymbal Sprout, the ideal balance bike for beginning riders ages 2-4 years. Its lightweight frame, low seat height, and puncture-proof tires promote stability and confidence as little ones learn to balance and steer. Watch them sprout into cycling enthusiasts with Cymbal Sprout!", PENDING_COMMIT_TIMESTAMP(), 10, 13999), + (1, 3, "Cymbal Spark Jr.", "Light, vibrant, and ready for adventure, the Spark Jr. is the perfect first bike for young riders (ages 5-8). Its sturdy frame, easy-to-use brakes, and puncture-resistant tires inspire confidence and endless playtime. Let the spark of cycling ignite with Cymbal!", PENDING_COMMIT_TIMESTAMP(), 34, 13900), + (1, 4, "Cymbal Summit", "Conquering trails is a breeze with the Summit mountain bike. Its lightweight aluminum frame, responsive suspension, and powerful disc brakes provide exceptional control and comfort for experienced bikers navigating rocky climbs or shredding downhill. Reach new heights with Cymbal Summit!", PENDING_COMMIT_TIMESTAMP(), 0, 79999), + (1, 5, "Cymbal Breeze", "Cruise in style and embrace effortless pedaling with the Breeze electric bike. Its whisper-quiet motor and long-lasting battery let you conquer hills and distances with ease. Enjoy scenic rides, commutes, or errands with a boost of confidence from Cymbal Breeze!", PENDING_COMMIT_TIMESTAMP(), 72, 129999), + (1, 6, "Cymbal Trailblazer Backpack", "Carry all your essentials in style with the Trailblazer backpack. Its water-resistant material, multiple compartments, and comfortable straps keep your gear organized and accessible, allowing you to focus on the adventure. Blaze new trails with Cymbal Trailblazer!", PENDING_COMMIT_TIMESTAMP(), 24, 7999), + (1, 7, "Cymbal Phoenix Lights", "See and be seen with the Phoenix bike lights. Powerful LEDs and multiple light modes ensure superior visibility, enhancing your safety and enjoyment during day or night rides. Light up your journey with Cymbal Phoenix!", PENDING_COMMIT_TIMESTAMP(), 87, 3999), + (1, 8, "Cymbal Windstar Pump", "Flat tires are no match for the Windstar pump. Its compact design, lightweight construction, and high-pressure capacity make inflating tires quick and effortless. Get back on the road in no time with Cymbal Windstar!", PENDING_COMMIT_TIMESTAMP(), 36, 24999), + (1, 9,"Cymbal Odyssey Multi-Tool","Be prepared for anything with the Odyssey multi-tool. This handy gadget features essential tools like screwdrivers, hex wrenches, and tire levers, keeping you ready for minor repairs and adjustments on the go. Conquer your journey with Cymbal Odyssey!", PENDING_COMMIT_TIMESTAMP(), 52, 999), + (1, 10,"Cymbal Nomad Water Bottle","Stay hydrated on every ride with the Nomad water bottle. Its sleek design, BPA-free construction, and secure lock lid make it the perfect companion for staying refreshed and motivated throughout your adventures. Hydrate and explore with Cymbal Nomad!", PENDING_COMMIT_TIMESTAMP(), 42, 1299), +] + +columns = [ + "categoryId", + "productId", + "productName", + "productDescription", + "createTime", + "inventoryCount", + "priceInCents", +] + + +if __name__ == '__main__': + main() + diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index c905404..de9eb1b 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -337,7 +337,7 @@ def __init__( class SpannerVectorStore(VectorStore): GSQL_TYPES = { CONTENT_COLUMN_NAME: ["STRING"], - EMBEDDING_COLUMN_NAME: ["ARRAY"], + EMBEDDING_COLUMN_NAME: ["ARRAY", "ARRAY"], "metadata_json_column": ["JSON"], } @@ -405,7 +405,6 @@ def init_vector_store_table( vector_size, ) - print("ddl", "\n".join(ddl)) operation = database.update_ddl(ddl) print("Waiting for operation to complete...") @@ -466,7 +465,7 @@ def _generate_sql( ann_indices = list( filter( - lambda index: isinstance(index, VectorSearchIndex), secondary_indexes + lambda index: type(index) is VectorSearchIndex, secondary_indexes ) ) ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_ANN( @@ -475,9 +474,9 @@ def _generate_sql( secondary_indexes=list(ann_indices), ) - knn_indices = filter( - lambda index: isinstance(index, SecondaryIndex), secondary_indexes - ) + knn_indices = list(filter( + lambda index: type(index) is SecondaryIndex, secondary_indexes + )) ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_KNN( table_name, embedding_column, @@ -544,13 +543,8 @@ def _generate_create_table_sql( # Append column name and data type column_sql = f" {column_config.name} {column_config.type}" - vector_len = vector_length - if column_config.vector_length and column_config.vector_length >= 1: - vector_len = column_config.vector_length - - if vector_len and vector_len > 0: - column_sql += f"(vector_length=>{vector_len})" + column_sql += f"(vector_length=>{column_config.vector_length})" # Add nullable constraint if specified if not column_config.is_null: @@ -571,7 +565,6 @@ def _generate_create_table_sql( + ")" ) - # print(create_table_statement) return create_table_statement @staticmethod @@ -616,7 +609,7 @@ def _generate_secondary_indices_ddl_ANN( for secondary_index in secondary_indexes: column_name = secondary_index.columns[0] - statement = f"CREATE VECTOR INDEX {secondary_index.index_name}\n\tON {table_name}({column_name})" + statement = f"CREATE VECTOR INDEX IF NOT EXISTS {secondary_index.index_name}\n\tON {table_name}({column_name})" if getattr(secondary_index, "nullable_column", False): statement += f"\n\tWHERE {column_name} IS NOT NULL" options_segments = [f"distance_type='{secondary_index.index_type}'"] @@ -651,6 +644,7 @@ def __init__( ignore_metadata_columns: Optional[List[str]] = None, metadata_json_column: Optional[str] = None, query_parameters: QueryParameters = QueryParameters(), + skip_not_nullable_columns = False, ): """ Initialize the SpannerVectorStore. @@ -681,6 +675,7 @@ def __init__( self._query_parameters = query_parameters self._embedding_service = embedding_service self.__strategy = None + self._skip_not_nullable_columns = skip_not_nullable_columns if metadata_columns is not None and ignore_metadata_columns is not None: raise Exception( @@ -795,9 +790,9 @@ def _validate_table_schema(self, column_type_map, types, default_columns): for substring in types[EMBEDDING_COLUMN_NAME] ): raise Exception( - "Embedding Column is not of correct type. Expected one of: {} but found: {}", + "Embedding Column is not of correct type. Expected one of: {} but found: {}".format( types[EMBEDDING_COLUMN_NAME], - embedding_column_type, + embedding_column_type) ) if self._metadata_json_column is not None: @@ -813,13 +808,14 @@ def _validate_table_schema(self, column_type_map, types, default_columns): embedding_column_type, ) - for column_name, column_config in column_type_map.items(): - if column_name not in self._columns_to_insert: - if "NO" == column_config[2].upper(): - raise Exception( - "Found not nullable constraint on column: {}.", - column_name, - ) + if not self._skip_not_nullable_columns: + for column_name, column_config in column_type_map.items(): + if column_name not in self._columns_to_insert: + if "NO" == column_config[2].upper(): + raise Exception( + "Found not nullable constraint on column: {}.", + column_name, + ) def _select_relevance_score_fn(self) -> Callable[[float], float]: if self._query_parameters.distance_strategy == DistanceStrategy.COSINE: @@ -1050,43 +1046,36 @@ def set_strategy(strategy: DistanceStrategy): def search_by_ANN( self, - table_name: str, - column_name: str, index_name: str, - embedding_column_name: str, - embedding: List[float], num_leaves: int, + embedding: List[float] = None, k: int = None, is_embedding_nullable: bool = False, where_condition: str = None, embedding_column_is_nullable: bool = False, + ascending: bool = True, + return_columns: List[str] = None, + strategy: DistanceStrategy = DistanceStrategy.COSINE, ) -> List[Any]: - # Firstly only the GoogleSQL dialect is supported. - if self._dialect_semantics != DatabaseDialect.GOOGLE_STANDARD_SQL: - raise Exception( - f"ANN is only supported for the GoogleSQL dialect not {dialect}. File an issue on Github?" - ) - sql = SpannerVectorStore._query_ANN( - table_name, - column_name, + self._table_name, index_name, - embedding_column_name, - embedding, + self._embedding_column, + embedding or self._embedding_service, num_leaves, - self._strategy, + strategy, k, is_embedding_nullable, where_condition, embedding_column_is_nullable=embedding_column_is_nullable, + ascending=ascending, + return_columns=return_columns, ) staleness = self._query_parameters.staleness with self._database.snapshot( **staleness if staleness is not None else {} ) as snapshot: - results = snapshot.execute_sql( - sql=sql_query, - ) + results = snapshot.execute_sql(sql=sql) return list(results) @staticmethod @@ -1149,7 +1138,7 @@ def _query_ANN( ) + f"ORDER BY {ann_strategy_name}(\n" + f" ARRAY{embedding}, {embedding_column_name}, options => JSON '" - + '{"num_leaves_to_search": %s})%s\n' + + '{"num_leaves_to_search": %s}\')%s\n' % (num_leaves, "" if ascending else " DESC") ) diff --git a/tests/integration/test_spanner_vector_store.py b/tests/integration/test_spanner_vector_store.py index c427c69..f4e5afc 100644 --- a/tests/integration/test_spanner_vector_store.py +++ b/tests/integration/test_spanner_vector_store.py @@ -565,7 +565,7 @@ def setup_database(self, client): (SELECT productDescription as content FROM products p2 where p2.productId=p1.productId) ) ) - WHERE categoryId=1; + WHERE categoryId=1 """, ] database = client.instance(instance_id).database(google_database)