From 852b74e2011579eb18f7c8066dcce3d3c0707da3 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Thu, 26 Dec 2024 02:56:46 -0800 Subject: [PATCH 01/16] fix(testing+linting): add nox lint+format directives This change introduces new nox directives: * blacken: `nox -s blacken` * format: `nox -s format` to apply formatting to files * lint: `nox -s lint` to flag linting issues * unit: to run unit tests locally which are the basis to enable scalable development and continuous testing as I prepare to bring in Approximate Nearest Neighors (ANN) functionality into this package. Also while here, fixed a typo in the README.rst file that didn't have the correct import path. --- README.rst | 2 +- noxfile.py | 2 +- src/langchain_google_spanner/graph_store.py | 20 +++++++++---------- src/langchain_google_spanner/loader.py | 2 +- .../test_spanner_chat_message_history.py | 2 +- tests/integration/test_spanner_graph_qa.py | 2 +- tests/integration/test_spanner_loader.py | 4 ++-- .../integration/test_spanner_vector_store.py | 2 +- 8 files changed, 18 insertions(+), 18 deletions(-) diff --git a/README.rst b/README.rst index 48fafdd..fd38ddc 100644 --- a/README.rst +++ b/README.rst @@ -73,7 +73,7 @@ Use a vector store to store embedded data and perform vector search. .. code-block:: python - from langchain_google_sapnner import SpannerVectorstore + from langchain_google_spanner import SpannerVectorstore from langchain.embeddings import VertexAIEmbeddings embeddings_service = VertexAIEmbeddings(model_name="textembedding-gecko@003") diff --git a/noxfile.py b/noxfile.py index d524d41..9d4c457 100644 --- a/noxfile.py +++ b/noxfile.py @@ -18,8 +18,8 @@ import os import pathlib -import shutil from pathlib import Path +import shutil import nox diff --git a/src/langchain_google_spanner/graph_store.py b/src/langchain_google_spanner/graph_store.py index d3e03a0..e6e211a 100644 --- a/src/langchain_google_spanner/graph_store.py +++ b/src/langchain_google_spanner/graph_store.py @@ -14,9 +14,9 @@ from __future__ import annotations +from abc import abstractmethod import re import string -from abc import abstractmethod from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union from google.cloud import spanner @@ -211,9 +211,9 @@ def from_nodes(name: str, nodes: List[Node]) -> ElementSchema: for k, v in n.properties.items() } ) - node.types[ElementSchema.NODE_KEY_COLUMN_NAME] = ( - TypeUtility.value_to_param_type(nodes[0].id) - ) + node.types[ + ElementSchema.NODE_KEY_COLUMN_NAME + ] = TypeUtility.value_to_param_type(nodes[0].id) return node @staticmethod @@ -264,12 +264,12 @@ def from_edges(name: str, edges: List[Relationship]) -> ElementSchema: for k, v in e.properties.items() } ) - edge.types[ElementSchema.NODE_KEY_COLUMN_NAME] = ( - TypeUtility.value_to_param_type(edges[0].source.id) - ) - edge.types[ElementSchema.TARGET_NODE_KEY_COLUMN_NAME] = ( - TypeUtility.value_to_param_type(edges[0].target.id) - ) + edge.types[ + ElementSchema.NODE_KEY_COLUMN_NAME + ] = TypeUtility.value_to_param_type(edges[0].source.id) + edge.types[ + ElementSchema.TARGET_NODE_KEY_COLUMN_NAME + ] = TypeUtility.value_to_param_type(edges[0].target.id) edge.source = NodeReference( edges[0].source.type, diff --git a/src/langchain_google_spanner/loader.py b/src/langchain_google_spanner/loader.py index 5bc8286..fdc466e 100644 --- a/src/langchain_google_spanner/loader.py +++ b/src/langchain_google_spanner/loader.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass import datetime import json -from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Optional, Union from google.cloud.spanner import Client, KeySet # type: ignore diff --git a/tests/integration/test_spanner_chat_message_history.py b/tests/integration/test_spanner_chat_message_history.py index 397af03..b83d6fb 100644 --- a/tests/integration/test_spanner_chat_message_history.py +++ b/tests/integration/test_spanner_chat_message_history.py @@ -16,10 +16,10 @@ import os import uuid -import pytest # noqa from google.cloud.spanner import Client # type: ignore from langchain_core.messages.ai import AIMessage from langchain_core.messages.human import HumanMessage +import pytest # noqa from langchain_google_spanner import SpannerChatMessageHistory diff --git a/tests/integration/test_spanner_graph_qa.py b/tests/integration/test_spanner_graph_qa.py index 48758fc..55e8153 100644 --- a/tests/integration/test_spanner_graph_qa.py +++ b/tests/integration/test_spanner_graph_qa.py @@ -16,12 +16,12 @@ import random import string -import pytest from google.cloud import spanner from langchain.evaluation import load_evaluator from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship from langchain_core.documents import Document from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings +import pytest from langchain_google_spanner.graph_qa import SpannerGraphQAChain from langchain_google_spanner.graph_store import SpannerGraphStore diff --git a/tests/integration/test_spanner_loader.py b/tests/integration/test_spanner_loader.py index fd8f028..f384cb0 100644 --- a/tests/integration/test_spanner_loader.py +++ b/tests/integration/test_spanner_loader.py @@ -15,9 +15,9 @@ import os import uuid -import pytest -from google.cloud.spanner import Client # type: ignore +from google.cloud.spanner import Client from langchain_core.documents import Document +import pytest from langchain_google_spanner.loader import Column, SpannerDocumentSaver, SpannerLoader diff --git a/tests/integration/test_spanner_vector_store.py b/tests/integration/test_spanner_vector_store.py index bf4de63..2403d7d 100644 --- a/tests/integration/test_spanner_vector_store.py +++ b/tests/integration/test_spanner_vector_store.py @@ -16,10 +16,10 @@ import os import uuid -import pytest from google.cloud.spanner import Client # type: ignore from langchain_community.document_loaders import HNLoader from langchain_community.embeddings import FakeEmbeddings +import pytest from langchain_google_spanner.vector_store import ( # type: ignore DistanceStrategy, From 3845f3faf7cb6a5698a4858b1b162904ae8cb3fe Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Thu, 26 Dec 2024 04:34:34 -0800 Subject: [PATCH 02/16] feat: add Approximate Nearest Neighbor support to distance strategies This change adds ANN distance strategies for GoogleSQL semantics. While here started unit tests to effectively test out components without having to have a running Cloud Spanner instance. Updates #94 --- README.rst | 8 +- noxfile.py | 6 +- requirements.txt | 4 +- samples/main.py | 180 ++++++++ src/langchain_google_spanner/graph_qa.py | 1 + src/langchain_google_spanner/graph_store.py | 20 +- src/langchain_google_spanner/loader.py | 2 +- src/langchain_google_spanner/vector_store.py | 423 +++++++++++++++--- .../test_spanner_chat_message_history.py | 2 +- tests/integration/test_spanner_graph_qa.py | 3 +- tests/integration/test_spanner_loader.py | 2 +- .../integration/test_spanner_vector_store.py | 235 +++++++++- tests/unit/test_vectore_store.py | 323 +++++++++++++ 13 files changed, 1132 insertions(+), 77 deletions(-) create mode 100644 samples/main.py create mode 100644 tests/unit/test_vectore_store.py diff --git a/README.rst b/README.rst index fd38ddc..cb6f9b0 100644 --- a/README.rst +++ b/README.rst @@ -73,7 +73,7 @@ Use a vector store to store embedded data and perform vector search. .. code-block:: python - from langchain_google_spanner import SpannerVectorstore + from langchain_google_sapnner import SpannerVectorstore from langchain.embeddings import VertexAIEmbeddings embeddings_service = VertexAIEmbeddings(model_name="textembedding-gecko@003") @@ -253,3 +253,9 @@ Disclaimer This is not an officially supported Google product. + +Limitations +---------- + +* Approximate Nearest Neighbors (ANN) strategies are only supported for the GoogleSQL dialect +* ANN's `ALTER VECTOR INDEX` is not yet supported by [Google Cloud Spanner](https://cloud.google.com/spanner/docs/find-approximate-nearest-neighbors#limitations) diff --git a/noxfile.py b/noxfile.py index 9d4c457..4425e3e 100644 --- a/noxfile.py +++ b/noxfile.py @@ -18,7 +18,6 @@ import os import pathlib -from pathlib import Path import shutil import nox @@ -33,6 +32,7 @@ "docfx", "docs", "format", + "integration", "lint", "unit", ] @@ -41,7 +41,7 @@ nox.options.error_on_missing_interpreters = True -@nox.session(python=DEFAULT_PYTHON_VERSION) +@nox.session(python="3.10") def docs(session): """Build the docs for this library.""" @@ -76,7 +76,7 @@ def docs(session): ) -@nox.session(python=DEFAULT_PYTHON_VERSION) +@nox.session(python="3.10") def docfx(session): """Build the docfx yaml files for this library.""" diff --git a/requirements.txt b/requirements.txt index 9e16179..621d9c9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -google-cloud-spanner==3.49.1 -langchain-core==0.3.9 +google-cloud-spanner==3.51.0 +langchain-core==0.3.15 langchain-community==0.3.1 pydantic==2.9.1 diff --git a/samples/main.py b/samples/main.py new file mode 100644 index 0000000..a0bf21e --- /dev/null +++ b/samples/main.py @@ -0,0 +1,180 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union +import datetime +import os +import time +import uuid + +from google.cloud.spanner import Client # type: ignore +from langchain_community.document_loaders import HNLoader +from langchain_community.embeddings import FakeEmbeddings + +from langchain_google_spanner.vector_store import ( # type: ignore + DistanceStrategy, + QueryParameters, + SpannerVectorStore, + TableColumn, + VectorSearchIndex, +) + +project_id = 'quip-441723' +instance_id = 'contracting' +google_database = 'ann' +zone = os.environ.get("GOOGLE_DATABASE_ZONE", "us-west2") +table_name_ANN = "products" +OPERATION_TIMEOUT_SECONDS = 240 + +def use_case(): + # Initialize the vector store table if necessary. + distance_strategy = DistanceStrategy.COSINE + SpannerVectorStore.init_vector_store_table( + instance_id=instance_id, + database_id=google_database, + table_name=table_name_ANN, + id_column=TableColumn("productId", type="INT64"), + vector_size=758, + embedding_column=TableColumn( + name="productDescriptionEmbedding", + type="ARRAY", + is_null=True, + ), + metadata_columns=[ + TableColumn(name="categoryId", type="INT64", is_null=False), + TableColumn(name="productName", type="STRING(MAX)", is_null=False), + TableColumn( + name="productDescription", type="STRING(MAX)", is_null=False + ), + TableColumn(name="inventoryCount", type="INT64", is_null=False), + TableColumn(name="priceInCents", type="INT64", is_null=True), + TableColumn(name="createTime", type="TIMESTAMP", is_null=False), + ], + secondary_indexes=[ + VectorSearchIndex( + index_name="ProductDescriptionEmbeddingIndex", + columns=["productDescriptionEmbedding"], + nullable_column=True, + num_branches=1000, + tree_depth=3, + index_type=distance_strategy, + num_leaves=100000, + ), + ], + ) + + # Create the models if necessary. + client = Client(project=project_id) + database = client.instance(instance_id).database(google_database) + + model_ddl_statements = [ + f""" + CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT( + content STRING(MAX), + ) OUTPUT( + embeddings STRUCT, values ARRAY>, + ) REMOTE OPTIONS ( + endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/text-embedding-004' + ) + """, + f""" + CREATE MODEL IF NOT EXISTS LLMModel INPUT( + prompt STRING(MAX), + ) OUTPUT( + content STRING(MAX), + ) REMOTE OPTIONS ( + endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/gemini-pro', + default_batch_size = 1 + ) + """, + ] + operation = database.update_ddl(model_ddl_statements) + operation.result(OPERATION_TIMEOUT_SECONDS) + + def clear_and_insert_data(tx): + tx.execute_update("DELETE FROM products WHERE 1=1") + tx.insert( + 'products', + columns=[ + 'categoryId', 'productId', 'productName', + 'productDescription', + 'createTime', 'inventoryCount', 'priceInCents', + ], + values=raw_data, + ) + + tx.execute_update( + """UPDATE products p1 + SET productDescriptionEmbedding = + (SELECT embeddings.values from ML.PREDICT(MODEL EmbeddingsModel, + (SELECT productDescription as content FROM products p2 where p2.productId=p1.productId))) + WHERE categoryId=1""", + ) + + embeddings = [] + rows = tx.execute_sql( + """SELECT embeddings.values + FROM ML.PREDICT( + MODEL EmbeddingsModel, + (SELECT "I'd like to buy a starter bike for my 3 year old child" as content) + )""") + + for row in rows: + for nesting in row: + embeddings.extend(nesting) + + return embeddings + + embeddings = database.run_in_transaction(clear_and_insert_data) + + vec_store = SpannerVectorStore( + instance_id=instance_id, + database_id=google_database, + table_name=table_name_ANN, + id_column="categoryId", + embedding_service=embeddings, + embedding_column="productDescriptionEmbedding", + skip_not_nullable_columns=True, + ) + vec_store.search_by_ANN( + 'ProductDescriptionEmbeddingIndex', + 1000, + k=20, + embedding_column_is_nullable=True, + return_columns=['productName', 'productDescription', 'inventoryCount'], + ) + +def main(): + use_case() + + +def PENDING_COMMIT_TIMESTAMP(): + return (datetime.datetime.utcnow() + datetime.timedelta(days=1)).isoformat() + "Z" + return 'PENDING_COMMIT_TIMESTAMP()' + +raw_data = [ + (1, 1, "Cymbal Helios Helmet", "Safety meets style with the Cymbal children's bike helmet. Its lightweight design, superior ventilation, and adjustable fit ensure comfort and protection on every ride. Stay bright and keep your child safe under the sun with Cymbal Helios!", PENDING_COMMIT_TIMESTAMP(), 100, 10999), + (1, 2, "Cymbal Sprout", "Let their cycling journey begin with the Cymbal Sprout, the ideal balance bike for beginning riders ages 2-4 years. Its lightweight frame, low seat height, and puncture-proof tires promote stability and confidence as little ones learn to balance and steer. Watch them sprout into cycling enthusiasts with Cymbal Sprout!", PENDING_COMMIT_TIMESTAMP(), 10, 13999), + (1, 3, "Cymbal Spark Jr.", "Light, vibrant, and ready for adventure, the Spark Jr. is the perfect first bike for young riders (ages 5-8). Its sturdy frame, easy-to-use brakes, and puncture-resistant tires inspire confidence and endless playtime. Let the spark of cycling ignite with Cymbal!", PENDING_COMMIT_TIMESTAMP(), 34, 13900), + (1, 4, "Cymbal Summit", "Conquering trails is a breeze with the Summit mountain bike. Its lightweight aluminum frame, responsive suspension, and powerful disc brakes provide exceptional control and comfort for experienced bikers navigating rocky climbs or shredding downhill. Reach new heights with Cymbal Summit!", PENDING_COMMIT_TIMESTAMP(), 0, 79999), + (1, 5, "Cymbal Breeze", "Cruise in style and embrace effortless pedaling with the Breeze electric bike. Its whisper-quiet motor and long-lasting battery let you conquer hills and distances with ease. Enjoy scenic rides, commutes, or errands with a boost of confidence from Cymbal Breeze!", PENDING_COMMIT_TIMESTAMP(), 72, 129999), + (1, 6, "Cymbal Trailblazer Backpack", "Carry all your essentials in style with the Trailblazer backpack. Its water-resistant material, multiple compartments, and comfortable straps keep your gear organized and accessible, allowing you to focus on the adventure. Blaze new trails with Cymbal Trailblazer!", PENDING_COMMIT_TIMESTAMP(), 24, 7999), + (1, 7, "Cymbal Phoenix Lights", "See and be seen with the Phoenix bike lights. Powerful LEDs and multiple light modes ensure superior visibility, enhancing your safety and enjoyment during day or night rides. Light up your journey with Cymbal Phoenix!", PENDING_COMMIT_TIMESTAMP(), 87, 3999), + (1, 8, "Cymbal Windstar Pump", "Flat tires are no match for the Windstar pump. Its compact design, lightweight construction, and high-pressure capacity make inflating tires quick and effortless. Get back on the road in no time with Cymbal Windstar!", PENDING_COMMIT_TIMESTAMP(), 36, 24999), + (1, 9,"Cymbal Odyssey Multi-Tool","Be prepared for anything with the Odyssey multi-tool. This handy gadget features essential tools like screwdrivers, hex wrenches, and tire levers, keeping you ready for minor repairs and adjustments on the go. Conquer your journey with Cymbal Odyssey!", PENDING_COMMIT_TIMESTAMP(), 52, 999), + (1, 10,"Cymbal Nomad Water Bottle","Stay hydrated on every ride with the Nomad water bottle. Its sleek design, BPA-free construction, and secure lock lid make it the perfect companion for staying refreshed and motivated throughout your adventures. Hydrate and explore with Cymbal Nomad!", PENDING_COMMIT_TIMESTAMP(), 42, 1299), +] + +columns = [ + "categoryId", + "productId", + "productName", + "productDescription", + "createTime", + "inventoryCount", + "priceInCents", +] + + +if __name__ == '__main__': + main() + diff --git a/src/langchain_google_spanner/graph_qa.py b/src/langchain_google_spanner/graph_qa.py index 36251f4..f071773 100644 --- a/src/langchain_google_spanner/graph_qa.py +++ b/src/langchain_google_spanner/graph_qa.py @@ -288,6 +288,7 @@ def _call( inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Dict[str, str]: + intermediate_steps: List = [] """Generate gql statement, uses it to look up in db and answer question.""" diff --git a/src/langchain_google_spanner/graph_store.py b/src/langchain_google_spanner/graph_store.py index e6e211a..d3e03a0 100644 --- a/src/langchain_google_spanner/graph_store.py +++ b/src/langchain_google_spanner/graph_store.py @@ -14,9 +14,9 @@ from __future__ import annotations -from abc import abstractmethod import re import string +from abc import abstractmethod from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union from google.cloud import spanner @@ -211,9 +211,9 @@ def from_nodes(name: str, nodes: List[Node]) -> ElementSchema: for k, v in n.properties.items() } ) - node.types[ - ElementSchema.NODE_KEY_COLUMN_NAME - ] = TypeUtility.value_to_param_type(nodes[0].id) + node.types[ElementSchema.NODE_KEY_COLUMN_NAME] = ( + TypeUtility.value_to_param_type(nodes[0].id) + ) return node @staticmethod @@ -264,12 +264,12 @@ def from_edges(name: str, edges: List[Relationship]) -> ElementSchema: for k, v in e.properties.items() } ) - edge.types[ - ElementSchema.NODE_KEY_COLUMN_NAME - ] = TypeUtility.value_to_param_type(edges[0].source.id) - edge.types[ - ElementSchema.TARGET_NODE_KEY_COLUMN_NAME - ] = TypeUtility.value_to_param_type(edges[0].target.id) + edge.types[ElementSchema.NODE_KEY_COLUMN_NAME] = ( + TypeUtility.value_to_param_type(edges[0].source.id) + ) + edge.types[ElementSchema.TARGET_NODE_KEY_COLUMN_NAME] = ( + TypeUtility.value_to_param_type(edges[0].target.id) + ) edge.source = NodeReference( edges[0].source.type, diff --git a/src/langchain_google_spanner/loader.py b/src/langchain_google_spanner/loader.py index fdc466e..5bc8286 100644 --- a/src/langchain_google_spanner/loader.py +++ b/src/langchain_google_spanner/loader.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass import datetime import json +from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Optional, Union from google.cloud.spanner import Client, KeySet # type: ignore diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 444c37a..c5ac979 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -14,13 +14,12 @@ from __future__ import annotations -import datetime -import logging from abc import ABC, abstractmethod +import datetime from enum import Enum +import logging from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union -import numpy as np from google.cloud import spanner # type: ignore from google.cloud.spanner_admin_database_v1.types import DatabaseDialect from google.cloud.spanner_v1 import JsonObject, param_types @@ -28,6 +27,7 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore +import numpy as np from .version import __version__ @@ -67,11 +67,13 @@ class TableColumn: column_name (str): The name of the column. type (str): The type of the column. is_null (bool): Indicates whether the column allows null values. + vector_length Optional(int): for ANN, mandatory and must be >=1 for the embedding column. """ name: str type: str is_null: bool = True + vector_length: int = None def __post_init__(self): # Check if column_name is None after initialization @@ -81,12 +83,20 @@ def __post_init__(self): if self.type is None: raise ValueError("type is mandatory and cannot be None.") + if (self.vector_length is not None) and (self.vector_length <= 0): + raise ValueError("vector_length must be >=1") + -@dataclass class SecondaryIndex: - index_name: str - columns: list[str] - storing_columns: Optional[list[str]] = None + def __init__( + self, + index_name: str, + columns: list[str], + storing_columns: Optional[list[str]] = None, + ): + self.index_name = index_name + self.columns = columns + self.storing_columns = storing_columns def __post_init__(self): # Check if column_name is None after initialization @@ -97,13 +107,51 @@ def __post_init__(self): raise ValueError("Index Columns can't be None") +class VectorSearchIndex(SecondaryIndex): + """ + The index for use with Approximate Nearest Neighbor (ANN) vector search. + """ + + def __init__( + self, + num_leaves: int, + num_branches: int, + tree_depth: int, + index_type: DistanceStrategy, + nullable_column: bool = False, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.num_leaves = num_leaves + self.num_branches = num_branches + self.tree_depth = tree_depth + self.index_type = index_type + self.nullable_column = nullable_column + + def __post_init__(self): + if self.index_name is None: + raise ValueError("index_name must be set") + + if len(self.columns) == 0: + raise ValueError("columns must be set") + + ok_tree_depth = self.tree_depth in (2, 3) + if not ok_tree_depth: + raise ValueError("tree_depth must be either 2 or 3") + + class DistanceStrategy(Enum): """ Enum for distance calculation strategies. """ COSINE = 1 - EUCLIDEIAN = 2 + EUCLIDEAN = 2 + DOT_PRODUCT = 3 + + def __str__(self): + return self.name class DialectSemantics(ABC): @@ -112,7 +160,7 @@ class DialectSemantics(ABC): """ @abstractmethod - def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str: + def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEAN) -> str: """ Abstract method to get the distance function based on the provided distance strategy. @@ -139,16 +187,30 @@ def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]: ) +# Maps between distance strategy enums and the appropriate vector search index name. +GOOGLE_DIALECT_TO_KNN_DISTANCE_FUNCTIONS = { + DistanceStrategy.COSINE: "COSINE_DISTANCE", + DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT", + DistanceStrategy.EUCLIDEAN: "EUCLIDEAN_DISTANCE", +} + +# Maps between distance strategy and the appropriate ANN search function name. +GOOGLE_DIALECT_TO_ANN_DISTANCE_FUNCTIONS = { + DistanceStrategy.COSINE: "APPROX_COSINE_DISTANCE", + DistanceStrategy.DOT_PRODUCT: "APPROX_DOT_PRODUCT", + DistanceStrategy.EUCLIDEAN: "APPROX_EUCLIDEAN_DISTANCE", +} + + class GoogleSqlSemantics(DialectSemantics): """ Implementation of dialect semantics for Google SQL. """ - def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str: - if distance_strategy == DistanceStrategy.COSINE: - return "COSINE_DISTANCE" - - return "EUCLIDEAN_DISTANCE" + def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEAN) -> str: + return GOOGLE_DIALECT_TO_KNN_DISTANCE_FUNCTIONS.get( + distance_strategy, "EUCLIDEAN" + ) def getDeleteDocumentsParameters(self, columns) -> Tuple[str, Any]: where_clause_condition = " AND ".join( @@ -162,16 +224,33 @@ def getDeleteDocumentsParameters(self, columns) -> Tuple[str, Any]: def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]: return dict(zip(columns, values)) + def getIndexDistanceType(self, distance_strategy) -> str: + value = _GOOGLE_ALGO_INDEX_NAME.get(distance_strategy, None) + if value is None: + raise Exception(f"{distance_strategy} is unsupported for distance_type") + return value + + +# Maps between DistanceStrategy and the expected PostgreSQL distance equivalent. +PG_DIALECT_TO_KNN_DISTANCE_FUNCTIONS = { + DistanceStrategy.COSINE: "spanner.cosine_distance", + DistanceStrategy.DOT_PRODUCT: "spanner.dot_product", + DistanceStrategy.EUCLIDEAN: "spanner.euclidean_distance", +} + class PGSqlSemantics(DialectSemantics): """ Implementation of dialect semantics for PostgreSQL. """ - def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str: - if distance_strategy == DistanceStrategy.COSINE: - return "spanner.cosine_distance" - return "spanner.euclidean_distance" + def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEAN) -> str: + name = PG_DIALECT_TO_KNN_DISTANCE_FUNCTIONS.get(distance_strategy, None) + if name is None: + raise Exception( + "Unsupported PostgreSQL distance strategy: {}".format(distance_strategy) + ) + return name def getDeleteDocumentsParameters(self, columns) -> Tuple[str, Any]: where_clause_condition = " AND ".join( @@ -206,15 +285,16 @@ class QueryParameters: class NearestNeighborsAlgorithm(Enum): """ - Enum for nearest neighbors search algorithms. + Enum for k-nearest neighbors search algorithms. """ EXACT_NEAREST_NEIGHBOR = 1 + APPROXIMATE_NEAREST_NEIGHBOR = 2 def __init__( self, algorithm=NearestNeighborsAlgorithm.EXACT_NEAREST_NEIGHBOR, - distance_strategy=DistanceStrategy.EUCLIDEIAN, + distance_strategy=DistanceStrategy.EUCLIDEAN, read_timestamp: Optional[datetime.datetime] = None, min_read_timestamp: Optional[datetime.datetime] = None, max_staleness: Optional[datetime.timedelta] = None, @@ -257,7 +337,7 @@ def __init__( class SpannerVectorStore(VectorStore): GSQL_TYPES = { CONTENT_COLUMN_NAME: ["STRING"], - EMBEDDING_COLUMN_NAME: ["ARRAY"], + EMBEDDING_COLUMN_NAME: ["ARRAY", "ARRAY"], "metadata_json_column": ["JSON"], } @@ -283,7 +363,7 @@ def init_vector_store_table( metadata_columns: Optional[List[TableColumn]] = None, primary_key: Optional[str] = None, vector_size: Optional[int] = None, - secondary_indexes: Optional[List[SecondaryIndex]] = None, + secondary_indexes: Optional[List[SecondaryIndex | VectorSearchIndex]] = None, ) -> bool: """ Initialize the vector store new table in Google Cloud Spanner. @@ -297,7 +377,7 @@ def init_vector_store_table( - content_column (str): The name of the content column. Defaults to CONTENT_COLUMN_NAME. - embedding_column (str): The name of the embedding column. Defaults to EMBEDDING_COLUMN_NAME. - metadata_columns (Optional[List[Tuple]]): List of tuples containing metadata column information. Defaults to None. - - vector_size (Optional[int]): The size of the vector. Defaults to None. + - vector_size (Optional[int]): The size of the vector for KNN or ANN. Defaults to None. It is presumed that exactly ONLY 1 field will have the vector. """ client = client_with_user_agent(client, USER_AGENT_VECTOR_STORE) @@ -322,6 +402,7 @@ def init_vector_store_table( metadata_columns, primary_key, secondary_indexes, + vector_size, ) operation = database.update_ddl(ddl) @@ -340,7 +421,8 @@ def _generate_sql( embedding_column, column_configs, primary_key, - secondary_indexes: Optional[List[SecondaryIndex]] = None, + secondary_indexes: Optional[List[SecondaryIndex | VectorSearchIndex]] = None, + vector_size: int = None, ): """ Generate SQL for creating the vector store table. @@ -352,11 +434,70 @@ def _generate_sql( - content_column: The name of the content column. - embedding_column: The name of the embedding column. - column_names: List of tuples containing metadata column information. + - vector_size: The vector length to be used by default. It is presumed by proxy of the langchain usage patterns, that exactly ONE column will be used as the embedding. Returns: - str: The generated SQL. """ - create_table_statement = f"CREATE TABLE {table_name} (\n" + + # 1. If any of the columns is a VectorSearchIndex + embedding_config = list( + filter(lambda x: x.name == embedding_column, column_configs) + ) + print("column_configs", column_configs, "\nembedding_config", embedding_config) + if embedding_column and len(embedding_config) > 0: + config = embedding_config[0] + if config.vector_length is None or config.vector_length <= 0: + raise ValueError("vector_length is mandatory and must be >=1") + + ddl_statements = [ + SpannerVectorStore._generate_create_table_sql( + table_name, + id_column, + content_column, + embedding_column, + column_configs, + primary_key, + dialect, + vector_length=vector_size, + ) + ] + + ann_indices = list( + filter( + lambda index: type(index) is VectorSearchIndex, secondary_indexes + ) + ) + ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_ANN( + table_name, + dialect, + secondary_indexes=list(ann_indices), + ) + + knn_indices = list(filter( + lambda index: type(index) is SecondaryIndex, secondary_indexes + )) + ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_KNN( + table_name, + embedding_column, + dialect, + secondary_indexes=list(knn_indices), + ) + + return ddl_statements + + @staticmethod + def _generate_create_table_sql( + table_name, + id_column, + content_column, + embedding_column, + column_configs, + primary_key, + dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, + vector_length=None, + ): + create_table_statement = f"CREATE TABLE IF NOT EXISTS {table_name} (\n" if not isinstance(id_column, TableColumn): if dialect == DatabaseDialect.POSTGRESQL: @@ -382,6 +523,11 @@ def _generate_sql( embedding_column, "ARRAY", is_null=True ) + if not embedding_column.vector_length: + ok_vector_length = vector_length and vector_length > 0 + if ok_vector_length: + embedding_column.vector_length = vector_length + configs = [id_column, content_column, embedding_column] if column_configs is not None: @@ -397,6 +543,9 @@ def _generate_sql( # Append column name and data type column_sql = f" {column_config.name} {column_config.type}" + if column_config.vector_length and column_config.vector_length >= 1: + column_sql += f"(vector_length=>{column_config.vector_length})" + # Add nullable constraint if specified if not column_config.is_null: column_sql += " NOT NULL" @@ -416,30 +565,70 @@ def _generate_sql( + ")" ) + return create_table_statement + + @staticmethod + def _generate_secondary_indices_ddl_KNN( + table_name, embedding_column, dialect, secondary_indexes=None + ): + if not secondary_indexes: + return [] + secondary_index_ddl_statements = [] + for secondary_index in secondary_indexes: + statement = f"CREATE INDEX {secondary_index.index_name} ON {table_name}(" + statement = statement + ",".join(secondary_index.columns) + ") " - if secondary_indexes is not None: - for secondary_index in secondary_indexes: - statement = ( - f"CREATE INDEX {secondary_index.index_name} ON {table_name}(" - ) - statement = statement + ",".join(secondary_index.columns) + ") " + if dialect == DatabaseDialect.POSTGRESQL: + statement = statement + "INCLUDE (" + else: + statement = statement + "STORING (" - if dialect == DatabaseDialect.POSTGRESQL: - statement = statement + "INCLUDE (" - else: - statement = statement + "STORING (" + if secondary_index.storing_columns is None: + secondary_index.storing_columns = [embedding_column.name] + elif embedding_column not in secondary_index.storing_columns: + secondary_index.storing_columns.append(embedding_column.name) - if secondary_index.storing_columns is None: - secondary_index.storing_columns = [embedding_column.name] - elif embedding_column not in secondary_index.storing_columns: - secondary_index.storing_columns.append(embedding_column.name) + statement = statement + ",".join(secondary_index.storing_columns) + ")" + secondary_index_ddl_statements.append(statement) + return secondary_index_ddl_statements - statement = statement + ",".join(secondary_index.storing_columns) + ")" + @staticmethod + def _generate_secondary_indices_ddl_ANN( + table_name, dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, secondary_indexes=[] + ): + if dialect != DatabaseDialect.GOOGLE_STANDARD_SQL: + raise Exception( + f"ANN is only supported for the GoogleSQL dialect not {dialect}. File an issue on Github?" + ) + + if not secondary_indexes: + return [] + + secondary_index_ddl_statements = [] - secondary_index_ddl_statements.append(statement) + for secondary_index in secondary_indexes: + column_name = secondary_index.columns[0] + statement = f"CREATE VECTOR INDEX IF NOT EXISTS {secondary_index.index_name}\n\tON {table_name}({column_name})" + if getattr(secondary_index, "nullable_column", False): + statement += f"\n\tWHERE {column_name} IS NOT NULL" + options_segments = [f"distance_type='{secondary_index.index_type}'"] + if getattr(secondary_index, "tree_depth", 0) > 0: + tree_depth = secondary_index.tree_depth + if tree_depth not in (2, 3): + raise Exception(f"tree_depth: {tree_depth} must be either 2 or 3") + options_segments.append(f"tree_depth={secondary_index.tree_depth}") - return [create_table_statement] + secondary_index_ddl_statements + if secondary_index.num_branches > 0: + options_segments.append(f"num_branches={secondary_index.num_branches}") + + if secondary_index.num_leaves > 0: + options_segments.append(f"num_leaves={secondary_index.num_leaves}") + + statement += "\n\tOPTIONS(" + ", ".join(options_segments) + ")" + secondary_index_ddl_statements.append(statement.strip()) + + return secondary_index_ddl_statements def __init__( self, @@ -455,6 +644,7 @@ def __init__( ignore_metadata_columns: Optional[List[str]] = None, metadata_json_column: Optional[str] = None, query_parameters: QueryParameters = QueryParameters(), + skip_not_nullable_columns = False, ): """ Initialize the SpannerVectorStore. @@ -484,6 +674,8 @@ def __init__( self._query_parameters = query_parameters self._embedding_service = embedding_service + self.__strategy = None + self._skip_not_nullable_columns = skip_not_nullable_columns if metadata_columns is not None and ignore_metadata_columns is not None: raise Exception( @@ -532,6 +724,7 @@ def __init__( ] else: self._metadata_columns = [] + if metadata_columns is not None: columns_to_insert.extend(metadata_columns) self._metadata_columns.extend(metadata_columns) @@ -597,9 +790,9 @@ def _validate_table_schema(self, column_type_map, types, default_columns): for substring in types[EMBEDDING_COLUMN_NAME] ): raise Exception( - "Embedding Column is not of correct type. Expected one of: {} but found: {}", + "Embedding Column is not of correct type. Expected one of: {} but found: {}".format( types[EMBEDDING_COLUMN_NAME], - embedding_column_type, + embedding_column_type) ) if self._metadata_json_column is not None: @@ -615,18 +808,19 @@ def _validate_table_schema(self, column_type_map, types, default_columns): embedding_column_type, ) - for column_name, column_config in column_type_map.items(): - if column_name not in self._columns_to_insert: - if "NO" == column_config[2].upper(): - raise Exception( - "Found not nullable constraint on column: {}.", - column_name, - ) + if not self._skip_not_nullable_columns: + for column_name, column_config in column_type_map.items(): + if column_name not in self._columns_to_insert: + if "NO" == column_config[2].upper(): + raise Exception( + "Found not nullable constraint on column: {}.", + column_name, + ) def _select_relevance_score_fn(self) -> Callable[[float], float]: if self._query_parameters.distance_strategy == DistanceStrategy.COSINE: return self._cosine_relevance_score_fn - elif self._query_parameters.distance_strategy == DistanceStrategy.EUCLIDEIAN: + elif self._query_parameters.distance_strategy == DistanceStrategy.EUCLIDEAN: return self._euclidean_relevance_score_fn else: raise Exception( @@ -715,6 +909,13 @@ def _insert_data(self, records, columns_to_insert): values=records, ) + def add_ann_rows( + self, data: List[Tuple], id_column_index: int, columns=Dict[str, str] + ) -> List[str]: + self._insert_data(data, columns) + ids = list(map(lambda row: row[id_column_index], data)) + return ids + def add_documents( self, documents: List[Document], @@ -832,7 +1033,7 @@ def similarity_search_with_score_by_vector( List[Document]: List of documents most similar to the query. """ - results, column_order_map = self._get_rows_by_similarity_search( + results, column_order_map = self._get_rows_by_similarity_search_knn( embedding, k, pre_filter ) documents = self._get_documents_from_query_results( @@ -840,7 +1041,119 @@ def similarity_search_with_score_by_vector( ) return documents - def _get_rows_by_similarity_search( + def set_strategy(strategy: DistanceStrategy): + self.__strategy = strategy + + def search_by_ANN( + self, + index_name: str, + num_leaves: int, + embedding: List[float] = None, + k: int = None, + is_embedding_nullable: bool = False, + where_condition: str = None, + embedding_column_is_nullable: bool = False, + ascending: bool = True, + return_columns: List[str] = None, + strategy: DistanceStrategy = DistanceStrategy.COSINE, + ) -> List[Any]: + sql = SpannerVectorStore._query_ANN( + self._table_name, + index_name, + self._embedding_column, + embedding or self._embedding_service, + num_leaves, + strategy, + k, + is_embedding_nullable, + where_condition, + embedding_column_is_nullable=embedding_column_is_nullable, + ascending=ascending, + return_columns=return_columns, + ) + staleness = self._query_parameters.staleness + with self._database.snapshot( + **staleness if staleness is not None else {} + ) as snapshot: + results = snapshot.execute_sql(sql=sql) + return list(results) + + @staticmethod + def _query_ANN( + table_name: str, + index_name: str, + embedding_column_name: str, + embedding: List[float], + num_leaves: int, + strategy: DistanceStrategy = DistanceStrategy.COSINE, + k: int = None, + is_embedding_nullable: bool = False, + where_condition: str = None, + embedding_column_is_nullable: bool = False, + ascending: bool = True, + return_columns: List[str] = None, + ): + """ + Sample query: + SELECT DocId + FROM Documents@{FORCE_INDEX=DocEmbeddingIndex} + ORDER BY APPROX_EUCLIDEAN_DISTANCE( + ARRAY[1.0, 2.0, 3.0], DocEmbedding, + options => JSON '{"num_leaves_to_search": 10}') + LIMIT 100 + + OR + + SELECT DocId + FROM Documents@{FORCE_INDEX=DocEmbeddingIndex} + WHERE NullableDocEmbedding IS NOT NULL + ORDER BY APPROX_EUCLIDEAN_DISTANCE( + ARRAY[1.0, 2.0, 3.0], NullableDocEmbedding, + options => JSON '{"num_leaves_to_search": 10}') + LIMIT 100 + """ + + if not embedding_column_name: + raise Exception("embedding_column_name must be set") + + ann_strategy_name = GOOGLE_DIALECT_TO_ANN_DISTANCE_FUNCTIONS.get(strategy, None) + if not ann_strategy_name: + raise Exception(f"{strategy} is not supported for ANN") + + column_names = None + if return_columns: + column_names = ",".join(return_columns) + + if not column_names: + column_names = "*" + + sql = ( + f"SELECT {column_names} FROM {table_name}" + + "@{FORCE_INDEX=" + + f"{index_name}" + + ( + "}\n" + if (not embedding_column_is_nullable) + else "}\nWHERE " + f"{embedding_column_name} IS NOT NULL\n" + ) + + f"ORDER BY {ann_strategy_name}(\n" + + f" ARRAY{embedding}, {embedding_column_name}, options => JSON '" + + '{"num_leaves_to_search": %s}\')%s\n' + % (num_leaves, "" if ascending else " DESC") + ) + + if where_condition: + sql += " WHERE " + where_condition + "\n" + + if k: + sql += f"LIMIT {k}" + + return sql.strip() + + def _get_rows_by_similarity_search_ann(): + pass + + def _get_rows_by_similarity_search_knn( self, embedding: List[float], k: int, @@ -1017,7 +1330,7 @@ def max_marginal_relevance_search_with_score_by_vector( List of Documents and similarity scores selected by maximal marginal relevance and score for each. """ - results, column_order_map = self._get_rows_by_similarity_search( + results, column_order_map = self._get_rows_by_similarity_search_knn( embedding, fetch_k, pre_filter ) diff --git a/tests/integration/test_spanner_chat_message_history.py b/tests/integration/test_spanner_chat_message_history.py index b83d6fb..397af03 100644 --- a/tests/integration/test_spanner_chat_message_history.py +++ b/tests/integration/test_spanner_chat_message_history.py @@ -16,10 +16,10 @@ import os import uuid +import pytest # noqa from google.cloud.spanner import Client # type: ignore from langchain_core.messages.ai import AIMessage from langchain_core.messages.human import HumanMessage -import pytest # noqa from langchain_google_spanner import SpannerChatMessageHistory diff --git a/tests/integration/test_spanner_graph_qa.py b/tests/integration/test_spanner_graph_qa.py index 55e8153..8bac7b8 100644 --- a/tests/integration/test_spanner_graph_qa.py +++ b/tests/integration/test_spanner_graph_qa.py @@ -16,12 +16,12 @@ import random import string +import pytest from google.cloud import spanner from langchain.evaluation import load_evaluator from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship from langchain_core.documents import Document from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings -import pytest from langchain_google_spanner.graph_qa import SpannerGraphQAChain from langchain_google_spanner.graph_store import SpannerGraphStore @@ -145,6 +145,7 @@ def load_data(graph: SpannerGraphStore): class TestSpannerGraphQAChain: + @pytest.fixture(scope="module") def setup_db_load_data(self): graph = get_spanner_graph() diff --git a/tests/integration/test_spanner_loader.py b/tests/integration/test_spanner_loader.py index f384cb0..568b41a 100644 --- a/tests/integration/test_spanner_loader.py +++ b/tests/integration/test_spanner_loader.py @@ -16,8 +16,8 @@ import uuid from google.cloud.spanner import Client -from langchain_core.documents import Document import pytest +from langchain_core.documents import Document from langchain_google_spanner.loader import Column, SpannerDocumentSaver, SpannerLoader diff --git a/tests/integration/test_spanner_vector_store.py b/tests/integration/test_spanner_vector_store.py index 2403d7d..2c3bb4d 100644 --- a/tests/integration/test_spanner_vector_store.py +++ b/tests/integration/test_spanner_vector_store.py @@ -26,13 +26,16 @@ QueryParameters, SpannerVectorStore, TableColumn, + VectorSearchIndex, ) project_id = os.environ["PROJECT_ID"] instance_id = os.environ["INSTANCE_ID"] google_database = os.environ["GOOGLE_DATABASE"] -pg_database = os.environ["PG_DATABASE"] +pg_database = os.environ.get("PG_DATABASE", None) +zone = os.environ.get("GOOGLE_DATABASE_ZONE", "us-west2") table_name = "test_table" + str(uuid.uuid4()).replace("-", "_") +table_name_ANN = "products" OPERATION_TIMEOUT_SECONDS = 240 @@ -207,7 +210,7 @@ def test_init_vector_store_table4(self): ) -class TestSpannerVectorStoreGoogleSQL: +class TestSpannerVectorStoreGoogleSQL_KNN: @pytest.fixture(scope="class") def setup_database(self, client): SpannerVectorStore.init_vector_store_table( @@ -389,6 +392,234 @@ def test_spanner_vector_search_data4(self, setup_database): assert len(docs) == 3 +class TestSpannerVectorStoreGoogleSQL_ANN: + @pytest.fixture(scope="class") + def setup_database(self, client): + """ + CREATE TABLE products ( + categoryId INT64 NOT NULL, + productId INT64 NOT NULL, + productName STRING(MAX) NOT NULL, + productDescription STRING(MAX) NOT NULL, + productDescriptionEmbedding ARRAY(vector_length=>728), + createTime TIMESTAMP NOT NULL OPTIONS ( + allow_commit_timestamp = true + ), + inventoryCount INT64 NOT NULL, + priceInCents INT64, + ) PRIMARY KEY(categoryId, productId); + """ + distance_strategy = DistanceStrategy.COSINE + SpannerVectorStore.init_vector_store_table( + instance_id=instance_id, + database_id=google_database, + table_name=table_name_ANN, + id_column=TableColumn("productId", type="INT64"), + vector_size=758, + embedding_column=TableColumn( + name="productDescriptionEmbedding", + type="ARRAY", + is_null=True, + ), + metadata_columns=[ + TableColumn(name="categoryId", type="INT64", is_null=False), + TableColumn(name="productName", type="STRING(MAX)", is_null=False), + TableColumn( + name="productDescription", type="STRING(MAX)", is_null=False + ), + TableColumn(name="inventoryCount", type="INT64", is_null=False), + TableColumn(name="priceInCents", type="INT64", is_null=True), + ], + secondary_indexes=[ + VectorSearchIndex( + index_name="ProductDescriptionEmbeddingIndex", + columns=["productDescriptionEmbedding"], + nullable_column=True, + num_branches=1000, + tree_depth=3, + index_type=distance_strategy, + num_leaves=100000, + ), + ], + ) + + raw_data = [ + ( + 1, + 1, + "Cymbal Helios Helmet", + "Safety meets style with the Cymbal children's bike helmet. Its lightweight design, superior ventilation, and adjustable fit ensure comfort and protection on every ride. Stay bright and keep your child safe under the sun with Cymbal Helios!", + 100, + 10999, + ), + ( + 1, + 2, + "Cymbal Sprout", + "Let their cycling journey begin with the Cymbal Sprout, the ideal balance bike for beginning riders ages 2-4 years. Its lightweight frame, low seat height, and puncture-proof tires promote stability and confidence as little ones learn to balance and steer. Watch them sprout into cycling enthusiasts with Cymbal Sprout!", + 10, + 13999, + ), + ( + 1, + 3, + "Cymbal Spark Jr.", + "Light, vibrant, and ready for adventure, the Spark Jr. is the perfect first bike for young riders (ages 5-8). Its sturdy frame, easy-to-use brakes, and puncture-resistant tires inspire confidence and endless playtime. Let the spark of cycling ignite with Cymbal!", + 34, + 13900, + ), + ( + 1, + 4, + "Cymbal Summit", + "Conquering trails is a breeze with the Summit mountain bike. Its lightweight aluminum frame, responsive suspension, and powerful disc brakes provide exceptional control and comfort for experienced bikers navigating rocky climbs or shredding downhill. Reach new heights with Cymbal Summit!", + 0, + 79999, + ), + ( + 1, + 5, + "Cymbal Breeze", + "Cruise in style and embrace effortless pedaling with the Breeze electric bike. Its whisper-quiet motor and long-lasting battery let you conquer hills and distances with ease. Enjoy scenic rides, commutes, or errands with a boost of confidence from Cymbal Breeze!", + 72, + 129999, + ), + ( + 1, + 6, + "Cymbal Trailblazer Backpack", + "Carry all your essentials in style with the Trailblazer backpack. Its water-resistant material, multiple compartments, and comfortable straps keep your gear organized and accessible, allowing you to focus on the adventure. Blaze new trails with Cymbal Trailblazer!", + 24, + 7999, + ), + ( + 1, + 7, + "Cymbal Phoenix Lights", + "See and be seen with the Phoenix bike lights. Powerful LEDs and multiple light modes ensure superior visibility, enhancing your safety and enjoyment during day or night rides. Light up your journey with Cymbal Phoenix!", + 87, + 3999, + ), + ( + 1, + 8, + "Cymbal Windstar Pump", + "Flat tires are no match for the Windstar pump. Its compact design, lightweight construction, and high-pressure capacity make inflating tires quick and effortless. Get back on the road in no time with Cymbal Windstar!", + 36, + 24999, + ), + ( + 1, + 9, + "Cymbal Odyssey Multi-Tool", + "Be prepared for anything with the Odyssey multi-tool. This handy gadget features essential tools like screwdrivers, hex wrenches, and tire levers, keeping you ready for minor repairs and adjustments on the go. Conquer your journey with Cymbal Odyssey!", + 52, + 999, + ), + ( + 1, + 10, + "Cymbal Nomad Water Bottle", + "Stay hydrated on every ride with the Nomad water bottle. Its sleek design, BPA-free construction, and secure lock lid make it the perfect companion for staying refreshed and motivated throughout your adventures. Hydrate and explore with Cymbal Nomad!", + 42, + 1299, + ), + ] + + columns = [ + "categoryId", + "productId", + "productName", + "productDescription", + "createTime", + "inventoryCount", + "priceInCents", + ] + + model_ddl_statements = [ + f""" + CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT( + content STRING(MAX), + ) OUTPUT( + embeddings STRUCT, values ARRAY>, + ) REMOTE OPTIONS ( + endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/{zone}/publishers/google/models/text-embedding-004' + ) + """, + f""" + CREATE MODEL IF NOT EXISTS LLMModel INPUT( + prompt STRING(MAX), + ) OUTPUT( + content STRING(MAX), + ) REMOTE OPTIONS ( + endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/{zone}/publishers/google/models/gemini-pro', + default_batch_size = 1 + ) + """, + """ + UPDATE products p1 + SET productDescriptionEmbedding = + ( + SELECT embeddings.values from ML.PREDICT( + MODEL EmbeddingsModel, + (SELECT productDescription as content FROM products p2 where p2.productId=p1.productId) + ) + ) + WHERE categoryId=1 + """, + ] + database = client.instance(instance_id).database(google_database) + + def create_models(): + operation = database.update_ddl(model_ddl_statements) + return operation.result(OPERATION_TIMEOUT_SECONDS) + + def get_embeddings(self): + sql = """SELECT embeddings.values FROM ML.PREDICT( + MODEL EmbeddingsModel, + (SELECT "I'd like to buy a starter bike for my 3 year old child" as content) + )""" + + with database.snapshot() as snapshot: + res = snapshot.execute_sql(sql) + return list(res) + + yield raw_data, columns, create_models, get_embeddings + + print("\nPerforming GSQL cleanup after each ANN test...") + + operation = database.update_ddl( + [ + f"DROP TABLE IF EXISTS {table_name_ANN}", + "DROP MODEL IF EXISTS EmbeddingsModel", + "DROP MODEL IF EXISTS LLMModel", + "DROP Index IF EXISTS ProductDescriptionEmbeddingIndex", + ] + ) + if False: # Creating a vector index takes 30+ minutes, so avoiding this. + operation.result(OPERATION_TIMEOUT_SECONDS) + + # Code to perform teardown after each test goes here + print("\nGSQL Cleanup complete.") + + def test_ann_add_data1(self, setup_database): + raw_data, columns, create_models, get_embeddings = setup_database + + # Retrieve embeddings using ML_PREDICT. + embeddings = get_embeddings() + print("embeddings", embeddings) + + db = SpannerVectorStore( + instance_id=instance_id, + database_id=google_database, + table_name=table_name_ANN, + id_column="categoryId", + ignore_metadata_columns=[], + embedding_service=embeddings, + metadata_json_column="metadata", + ) + + class TestSpannerVectorStorePGSQL: @pytest.fixture(scope="class") def setup_database(self, client): diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py new file mode 100644 index 0000000..70d4473 --- /dev/null +++ b/tests/unit/test_vectore_store.py @@ -0,0 +1,323 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +from collections import namedtuple +import unittest + +from google.cloud.spanner_admin_database_v1.types import DatabaseDialect + +from langchain_google_spanner.vector_store import ( + DistanceStrategy, + GoogleSqlSemantics, + PGSqlSemantics, + SecondaryIndex, + SpannerVectorStore, + VectorSearchIndex, +) + + +class TestGoogleSqlSemantics(unittest.TestCase): + def test_distance_function_to_string(self): + cases = [ + (DistanceStrategy.COSINE, "COSINE_DISTANCE"), + (DistanceStrategy.DOT_PRODUCT, "DOT_PRODUCT"), + (DistanceStrategy.EUCLIDEAN, "EUCLIDEAN_DISTANCE"), + ] + + sem = GoogleSqlSemantics() + got_results = [] + want_results = [] + for strategy, want_str in cases: + got_results.append(sem.getDistanceFunction(strategy)) + want_results.append(want_str) + + assert got_results == want_results + + +class TestPGSqlSemantics(unittest.TestCase): + sem = PGSqlSemantics() + + def test_distance_function_to_string(self): + cases = [ + (DistanceStrategy.COSINE, "spanner.cosine_distance"), + (DistanceStrategy.DOT_PRODUCT, "spanner.dot_product"), + (DistanceStrategy.EUCLIDEAN, "spanner.euclidean_distance"), + ] + + got_results = [] + want_results = [] + for strategy, want_str in cases: + got_results.append(self.sem.getDistanceFunction(strategy)) + want_results.append(want_str) + + assert got_results == want_results + + def test_distance_function_raises_exception_if_unknown(self): + strategies = [ + 100, + -1, + ] + + for strategy in strategies: + with self.assertRaises(Exception): + self.sem.getDistanceFunction(strategy) + + +class TestSpannerVectorStore(unittest.TestCase): + def test_generate_create_table_sql(self): + got = SpannerVectorStore._generate_create_table_sql( + "users", + "id", + "essays", + "science_scores", + [], + "id", + ) + want = ( + "CREATE TABLE IF NOT EXISTS users (\n id STRING(36),\n essays STRING(MAX)," + + "\n science_scores ARRAY\n) PRIMARY KEY(id)" + ) + assert got == want + + def test_generate_secondary_indices_ddl_ANN(self): + strategies = [ + DistanceStrategy.COSINE, + DistanceStrategy.DOT_PRODUCT, + DistanceStrategy.EUCLIDEAN, + ] + + nullables = [True, False] + for distance_strategy in strategies: + for nullable in nullables: + got = SpannerVectorStore._generate_secondary_indices_ddl_ANN( + "Documents", + secondary_indexes=[ + VectorSearchIndex( + index_name="DocEmbeddingIndex", + columns=["DocEmbedding"], + nullable_column=nullable, + num_branches=1000, + tree_depth=3, + index_type=distance_strategy, + num_leaves=100000, + ) + ], + ) + + want = [ + "CREATE VECTOR INDEX DocEmbeddingIndex\n" + + " ON Documents(DocEmbedding)\n" + + " WHERE DocEmbedding IS NOT NULL\n" + + f" OPTIONS(distance_type='{distance_strategy}', " + + "tree_depth=3, num_branches=1000, num_leaves=100000)" + ] + if not nullable: + want = [ + "CREATE VECTOR INDEX DocEmbeddingIndex\n" + + " ON Documents(DocEmbedding)\n" + + f" OPTIONS(distance_type='{distance_strategy}', " + + "tree_depth=3, num_branches=1000, num_leaves=100000)" + ] + + assert canonicalize(got) == canonicalize(want) + + def test_generate_ANN_indices_exception_for_non_GoogleSQL_dialect( + self, + ): + strategies = [ + DistanceStrategy.COSINE, + DistanceStrategy.DOT_PRODUCT, + DistanceStrategy.EUCLIDEAN, + ] + + for strategy in strategies: + with self.assertRaises(Exception): + SpannerVectorStore._generate_secondary_indices_ddl_ANN( + "Documents", + dialect=DatabaseDialect.POSTGRESQL, + secondary_indexes=[ + VectorSearchIndex( + index_name="DocEmbeddingIndex", + columns=["DocEmbedding"], + num_branches=1000, + tree_depth=3, + index_type=strategy, + num_leaves=100000, + ) + ], + ) + + def test_generate_secondary_indices_ddl_KNN_GoogleDialect(self): + embed_column = namedtuple("Column", ["name"]) + embed_column.name = "text" + got = SpannerVectorStore._generate_secondary_indices_ddl_KNN( + "Documents", + embedding_column=embed_column, + dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, + secondary_indexes=[ + SecondaryIndex( + index_name="DocEmbeddingIndex", + columns=["DocEmbedding"], + ) + ], + ) + + want = [ + "CREATE INDEX DocEmbeddingIndex ON " + + "Documents(DocEmbedding) STORING (text)" + ] + + assert canonicalize(got) == canonicalize(want) + + def test_generate_secondary_indices_ddl_KNN_PostgresDialect(self): + embed_column = namedtuple("Column", ["name"]) + embed_column.name = "text" + got = SpannerVectorStore._generate_secondary_indices_ddl_KNN( + "Documents", + embedding_column=embed_column, + dialect=DatabaseDialect.POSTGRESQL, + secondary_indexes=[ + SecondaryIndex( + index_name="DocEmbeddingIndex", + columns=["DocEmbedding"], + ) + ], + ) + + want = [ + "CREATE INDEX DocEmbeddingIndex ON " + + "Documents(DocEmbedding) INCLUDE (text)" + ] + + assert canonicalize(got) == canonicalize(want) + + def test_query_ANN(self): + got = SpannerVectorStore._query_ANN( + "Documents", + "DocEmbeddingIndex", + "DocEmbedding", + [1.0, 2.0, 3.0], + 10, + DistanceStrategy.COSINE, + limit=100, + return_columns=["DocId"], + ) + + want = ( + "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + + "ORDER BY APPROX_COSINE_DISTANCE(\n" + + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " + + '\'{"num_leaves_to_search": 10})\n' + + "LIMIT 100" + ) + + assert got == want + + def test_query_ANN_column_is_nullable(self): + got = SpannerVectorStore._query_ANN( + "Documents", + "DocEmbeddingIndex", + "DocEmbedding", + [1.0, 2.0, 3.0], + 10, + DistanceStrategy.COSINE, + limit=100, + embedding_column_is_nullable=True, + return_columns=["DocId"], + ) + + want = ( + "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + + "WHERE DocEmbedding IS NOT NULL\n" + + "ORDER BY APPROX_COSINE_DISTANCE(\n" + + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " + + '\'{"num_leaves_to_search": 10})\n' + + "LIMIT 100" + ) + + assert got == want + + def test_query_ANN_column_unspecified_return_columns_star_result(self): + got = SpannerVectorStore._query_ANN( + "Documents", + "DocEmbeddingIndex", + "DocEmbedding", + [1.0, 2.0, 3.0], + 10, + DistanceStrategy.COSINE, + limit=100, + embedding_column_is_nullable=True, + ) + + want = ( + "SELECT * FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + + "WHERE DocEmbedding IS NOT NULL\n" + + "ORDER BY APPROX_COSINE_DISTANCE(\n" + + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " + + '\'{"num_leaves_to_search": 10})\n' + + "LIMIT 100" + ) + + assert got == want + + def test_query_ANN_order_DESC(self): + got = SpannerVectorStore._query_ANN( + "Documents", + "DocEmbeddingIndex", + "DocEmbedding", + [1.0, 2.0, 3.0], + 10, + DistanceStrategy.COSINE, + limit=100, + return_columns=["DocId"], + ascending=False, + ) + + want = ( + "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + + "ORDER BY APPROX_COSINE_DISTANCE(\n" + + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " + + '\'{"num_leaves_to_search": 10}) DESC\n' + + "LIMIT 100" + ) + + assert got == want + + def test_query_ANN_unspecified_limit(self): + got = SpannerVectorStore._query_ANN( + "Documents", + "DocEmbeddingIndex", + "DocEmbedding", + [1.0, 2.0, 3.0], + 10, + DistanceStrategy.COSINE, + return_columns=["DocId"], + ) + + want = ( + "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + + "ORDER BY APPROX_COSINE_DISTANCE(\n" + + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " + + '\'{"num_leaves_to_search": 10})' + ) + + assert got == want + + +def trimSpaces(x: str) -> str: + return x.lstrip("\n").rstrip("\n").replace("\t", " ").strip() + + +def canonicalize(s): + return list(map(trimSpaces, s)) From 430d14dac19cb8ddc83775de4b869db266c45479 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Thu, 30 Jan 2025 13:00:09 +0200 Subject: [PATCH 03/16] Ensure vector fits within limits in sample --- noxfile.py | 6 +- samples/main.py | 180 ------------ samples/search_ann.py | 271 ++++++++++++++++++ .../graph_retriever.py | 2 +- src/langchain_google_spanner/vector_store.py | 27 +- tests/integration/test_spanner_loader.py | 2 +- .../integration/test_spanner_vector_store.py | 2 +- tests/unit/test_vectore_store.py | 2 +- 8 files changed, 291 insertions(+), 201 deletions(-) delete mode 100644 samples/main.py create mode 100644 samples/search_ann.py diff --git a/noxfile.py b/noxfile.py index 4425e3e..4da86cc 100644 --- a/noxfile.py +++ b/noxfile.py @@ -24,7 +24,7 @@ DEFAULT_PYTHON_VERSION = "3.10" CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() -LINT_PATHS = ["src", "tests", "noxfile.py"] +LINT_PATHS = ["samples", "src", "tests", "noxfile.py"] nox.options.sessions = [ @@ -41,7 +41,7 @@ nox.options.error_on_missing_interpreters = True -@nox.session(python="3.10") +@nox.session(python=DEFAULT_PYTHON_VERSION) def docs(session): """Build the docs for this library.""" @@ -76,7 +76,7 @@ def docs(session): ) -@nox.session(python="3.10") +@nox.session(python=DEFAULT_PYTHON_VERSION) def docfx(session): """Build the docfx yaml files for this library.""" diff --git a/samples/main.py b/samples/main.py deleted file mode 100644 index a0bf21e..0000000 --- a/samples/main.py +++ /dev/null @@ -1,180 +0,0 @@ -from dataclasses import dataclass -from enum import Enum -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union -import datetime -import os -import time -import uuid - -from google.cloud.spanner import Client # type: ignore -from langchain_community.document_loaders import HNLoader -from langchain_community.embeddings import FakeEmbeddings - -from langchain_google_spanner.vector_store import ( # type: ignore - DistanceStrategy, - QueryParameters, - SpannerVectorStore, - TableColumn, - VectorSearchIndex, -) - -project_id = 'quip-441723' -instance_id = 'contracting' -google_database = 'ann' -zone = os.environ.get("GOOGLE_DATABASE_ZONE", "us-west2") -table_name_ANN = "products" -OPERATION_TIMEOUT_SECONDS = 240 - -def use_case(): - # Initialize the vector store table if necessary. - distance_strategy = DistanceStrategy.COSINE - SpannerVectorStore.init_vector_store_table( - instance_id=instance_id, - database_id=google_database, - table_name=table_name_ANN, - id_column=TableColumn("productId", type="INT64"), - vector_size=758, - embedding_column=TableColumn( - name="productDescriptionEmbedding", - type="ARRAY", - is_null=True, - ), - metadata_columns=[ - TableColumn(name="categoryId", type="INT64", is_null=False), - TableColumn(name="productName", type="STRING(MAX)", is_null=False), - TableColumn( - name="productDescription", type="STRING(MAX)", is_null=False - ), - TableColumn(name="inventoryCount", type="INT64", is_null=False), - TableColumn(name="priceInCents", type="INT64", is_null=True), - TableColumn(name="createTime", type="TIMESTAMP", is_null=False), - ], - secondary_indexes=[ - VectorSearchIndex( - index_name="ProductDescriptionEmbeddingIndex", - columns=["productDescriptionEmbedding"], - nullable_column=True, - num_branches=1000, - tree_depth=3, - index_type=distance_strategy, - num_leaves=100000, - ), - ], - ) - - # Create the models if necessary. - client = Client(project=project_id) - database = client.instance(instance_id).database(google_database) - - model_ddl_statements = [ - f""" - CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT( - content STRING(MAX), - ) OUTPUT( - embeddings STRUCT, values ARRAY>, - ) REMOTE OPTIONS ( - endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/text-embedding-004' - ) - """, - f""" - CREATE MODEL IF NOT EXISTS LLMModel INPUT( - prompt STRING(MAX), - ) OUTPUT( - content STRING(MAX), - ) REMOTE OPTIONS ( - endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/gemini-pro', - default_batch_size = 1 - ) - """, - ] - operation = database.update_ddl(model_ddl_statements) - operation.result(OPERATION_TIMEOUT_SECONDS) - - def clear_and_insert_data(tx): - tx.execute_update("DELETE FROM products WHERE 1=1") - tx.insert( - 'products', - columns=[ - 'categoryId', 'productId', 'productName', - 'productDescription', - 'createTime', 'inventoryCount', 'priceInCents', - ], - values=raw_data, - ) - - tx.execute_update( - """UPDATE products p1 - SET productDescriptionEmbedding = - (SELECT embeddings.values from ML.PREDICT(MODEL EmbeddingsModel, - (SELECT productDescription as content FROM products p2 where p2.productId=p1.productId))) - WHERE categoryId=1""", - ) - - embeddings = [] - rows = tx.execute_sql( - """SELECT embeddings.values - FROM ML.PREDICT( - MODEL EmbeddingsModel, - (SELECT "I'd like to buy a starter bike for my 3 year old child" as content) - )""") - - for row in rows: - for nesting in row: - embeddings.extend(nesting) - - return embeddings - - embeddings = database.run_in_transaction(clear_and_insert_data) - - vec_store = SpannerVectorStore( - instance_id=instance_id, - database_id=google_database, - table_name=table_name_ANN, - id_column="categoryId", - embedding_service=embeddings, - embedding_column="productDescriptionEmbedding", - skip_not_nullable_columns=True, - ) - vec_store.search_by_ANN( - 'ProductDescriptionEmbeddingIndex', - 1000, - k=20, - embedding_column_is_nullable=True, - return_columns=['productName', 'productDescription', 'inventoryCount'], - ) - -def main(): - use_case() - - -def PENDING_COMMIT_TIMESTAMP(): - return (datetime.datetime.utcnow() + datetime.timedelta(days=1)).isoformat() + "Z" - return 'PENDING_COMMIT_TIMESTAMP()' - -raw_data = [ - (1, 1, "Cymbal Helios Helmet", "Safety meets style with the Cymbal children's bike helmet. Its lightweight design, superior ventilation, and adjustable fit ensure comfort and protection on every ride. Stay bright and keep your child safe under the sun with Cymbal Helios!", PENDING_COMMIT_TIMESTAMP(), 100, 10999), - (1, 2, "Cymbal Sprout", "Let their cycling journey begin with the Cymbal Sprout, the ideal balance bike for beginning riders ages 2-4 years. Its lightweight frame, low seat height, and puncture-proof tires promote stability and confidence as little ones learn to balance and steer. Watch them sprout into cycling enthusiasts with Cymbal Sprout!", PENDING_COMMIT_TIMESTAMP(), 10, 13999), - (1, 3, "Cymbal Spark Jr.", "Light, vibrant, and ready for adventure, the Spark Jr. is the perfect first bike for young riders (ages 5-8). Its sturdy frame, easy-to-use brakes, and puncture-resistant tires inspire confidence and endless playtime. Let the spark of cycling ignite with Cymbal!", PENDING_COMMIT_TIMESTAMP(), 34, 13900), - (1, 4, "Cymbal Summit", "Conquering trails is a breeze with the Summit mountain bike. Its lightweight aluminum frame, responsive suspension, and powerful disc brakes provide exceptional control and comfort for experienced bikers navigating rocky climbs or shredding downhill. Reach new heights with Cymbal Summit!", PENDING_COMMIT_TIMESTAMP(), 0, 79999), - (1, 5, "Cymbal Breeze", "Cruise in style and embrace effortless pedaling with the Breeze electric bike. Its whisper-quiet motor and long-lasting battery let you conquer hills and distances with ease. Enjoy scenic rides, commutes, or errands with a boost of confidence from Cymbal Breeze!", PENDING_COMMIT_TIMESTAMP(), 72, 129999), - (1, 6, "Cymbal Trailblazer Backpack", "Carry all your essentials in style with the Trailblazer backpack. Its water-resistant material, multiple compartments, and comfortable straps keep your gear organized and accessible, allowing you to focus on the adventure. Blaze new trails with Cymbal Trailblazer!", PENDING_COMMIT_TIMESTAMP(), 24, 7999), - (1, 7, "Cymbal Phoenix Lights", "See and be seen with the Phoenix bike lights. Powerful LEDs and multiple light modes ensure superior visibility, enhancing your safety and enjoyment during day or night rides. Light up your journey with Cymbal Phoenix!", PENDING_COMMIT_TIMESTAMP(), 87, 3999), - (1, 8, "Cymbal Windstar Pump", "Flat tires are no match for the Windstar pump. Its compact design, lightweight construction, and high-pressure capacity make inflating tires quick and effortless. Get back on the road in no time with Cymbal Windstar!", PENDING_COMMIT_TIMESTAMP(), 36, 24999), - (1, 9,"Cymbal Odyssey Multi-Tool","Be prepared for anything with the Odyssey multi-tool. This handy gadget features essential tools like screwdrivers, hex wrenches, and tire levers, keeping you ready for minor repairs and adjustments on the go. Conquer your journey with Cymbal Odyssey!", PENDING_COMMIT_TIMESTAMP(), 52, 999), - (1, 10,"Cymbal Nomad Water Bottle","Stay hydrated on every ride with the Nomad water bottle. Its sleek design, BPA-free construction, and secure lock lid make it the perfect companion for staying refreshed and motivated throughout your adventures. Hydrate and explore with Cymbal Nomad!", PENDING_COMMIT_TIMESTAMP(), 42, 1299), -] - -columns = [ - "categoryId", - "productId", - "productName", - "productDescription", - "createTime", - "inventoryCount", - "priceInCents", -] - - -if __name__ == '__main__': - main() - diff --git a/samples/search_ann.py b/samples/search_ann.py new file mode 100644 index 0000000..009939b --- /dev/null +++ b/samples/search_ann.py @@ -0,0 +1,271 @@ +import datetime +import os +import time +import uuid +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union + +from google.cloud.spanner import Client # type: ignore +from langchain_community.document_loaders import HNLoader +from langchain_community.embeddings import FakeEmbeddings + +from langchain_google_spanner.vector_store import ( # type: ignore + DistanceStrategy, + QueryParameters, + SpannerVectorStore, + TableColumn, + VectorSearchIndex, +) + +project_id = os.environ["PROJECT_ID"] +instance_id = os.environ["INSTANCE_ID"] +google_database = os.environ["GOOGLE_DATABASE"] +zone = os.environ["GOOGLE_DATABASE_ZONE"] +table_name_ANN = "products" +OPERATION_TIMEOUT_SECONDS = 240 + + +def use_case(): + # Initialize the vector store table if necessary. + distance_strategy = DistanceStrategy.COSINE + model_vector_size = 758 + SpannerVectorStore.init_vector_store_table( + instance_id=instance_id, + database_id=google_database, + table_name=table_name_ANN, + id_column=TableColumn("productId", type="INT64"), + vector_size=model_vector_size, + embedding_column=TableColumn( + name="productDescriptionEmbedding", + type="ARRAY", + is_null=True, + ), + metadata_columns=[ + TableColumn(name="categoryId", type="INT64", is_null=False), + TableColumn(name="productName", type="STRING(MAX)", is_null=False), + TableColumn(name="productDescription", type="STRING(MAX)", is_null=False), + TableColumn(name="inventoryCount", type="INT64", is_null=False), + TableColumn(name="priceInCents", type="INT64", is_null=True), + TableColumn(name="createTime", type="TIMESTAMP", is_null=False), + ], + secondary_indexes=[ + VectorSearchIndex( + index_name="ProductDescriptionEmbeddingIndex", + columns=["productDescriptionEmbedding"], + nullable_column=True, + num_branches=1000, + tree_depth=3, + index_type=distance_strategy, + num_leaves=100000, + ), + ], + ) + + # Create the models if necessary. + client = Client(project=project_id) + database = client.instance(instance_id).database(google_database) + + model_ddl_statements = [ + f""" + CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT( + content STRING(MAX), + ) OUTPUT( + embeddings STRUCT, values ARRAY>, + ) REMOTE OPTIONS ( + endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/text-embedding-004' + ) + """, + f""" + CREATE MODEL IF NOT EXISTS LLMModel INPUT( + prompt STRING(MAX), + ) OUTPUT( + content STRING(MAX), + ) REMOTE OPTIONS ( + endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/gemini-pro', + default_batch_size = 1 + ) + """, + ] + operation = database.update_ddl(model_ddl_statements) + operation.result(OPERATION_TIMEOUT_SECONDS) + + def clear_and_insert_data(tx): + tx.execute_update("DELETE FROM products WHERE 1=1") + tx.insert( + "products", + columns=[ + "categoryId", + "productId", + "productName", + "productDescription", + "createTime", + "inventoryCount", + "priceInCents", + ], + values=raw_data, + ) + + tx.execute_update( + """UPDATE products p1 + SET productDescriptionEmbedding = + (SELECT embeddings.values from ML.PREDICT(MODEL EmbeddingsModel, + (SELECT productDescription as content FROM products p2 where p2.productId=p1.productId))) + WHERE categoryId=1""", + ) + + embeddings = [] + rows = tx.execute_sql( + """SELECT embeddings.values + FROM ML.PREDICT( + MODEL EmbeddingsModel, + (SELECT "I'd like to buy a starter bike for my 3 year old child" as content) + )""" + ) + + for row in rows: + for nesting in row: + embeddings.extend(nesting) + + return embeddings + + embeddings = database.run_in_transaction(clear_and_insert_data) + if len(embeddings) > model_vector_size: + embeddings = embeddings[:model_vector_size] + + vec_store = SpannerVectorStore( + instance_id=instance_id, + database_id=google_database, + table_name=table_name_ANN, + id_column="categoryId", + embedding_service=embeddings, + embedding_column="productDescriptionEmbedding", + skip_not_nullable_columns=True, + ) + results = vec_store.search_by_ANN( + "ProductDescriptionEmbeddingIndex", + 1000, + limit=20, + embedding_column_is_nullable=True, + return_columns=["productName", "productDescription", "inventoryCount"], + ) + + print("Search by ANN results") + for res in results: + print(res) + + +def main(): + use_case() + + +def PENDING_COMMIT_TIMESTAMP(): + return (datetime.datetime.utcnow() + datetime.timedelta(days=1)).isoformat() + "Z" + + +raw_data = [ + ( + 1, + 1, + "Cymbal Helios Helmet", + "Safety meets style with the Cymbal children's bike helmet. Its lightweight design, superior ventilation, and adjustable fit ensure comfort and protection on every ride. Stay bright and keep your child safe under the sun with Cymbal Helios!", + PENDING_COMMIT_TIMESTAMP(), + 100, + 10999, + ), + ( + 1, + 2, + "Cymbal Sprout", + "Let their cycling journey begin with the Cymbal Sprout, the ideal balance bike for beginning riders ages 2-4 years. Its lightweight frame, low seat height, and puncture-proof tires promote stability and confidence as little ones learn to balance and steer. Watch them sprout into cycling enthusiasts with Cymbal Sprout!", + PENDING_COMMIT_TIMESTAMP(), + 10, + 13999, + ), + ( + 1, + 3, + "Cymbal Spark Jr.", + "Light, vibrant, and ready for adventure, the Spark Jr. is the perfect first bike for young riders (ages 5-8). Its sturdy frame, easy-to-use brakes, and puncture-resistant tires inspire confidence and endless playtime. Let the spark of cycling ignite with Cymbal!", + PENDING_COMMIT_TIMESTAMP(), + 34, + 13900, + ), + ( + 1, + 4, + "Cymbal Summit", + "Conquering trails is a breeze with the Summit mountain bike. Its lightweight aluminum frame, responsive suspension, and powerful disc brakes provide exceptional control and comfort for experienced bikers navigating rocky climbs or shredding downhill. Reach new heights with Cymbal Summit!", + PENDING_COMMIT_TIMESTAMP(), + 0, + 79999, + ), + ( + 1, + 5, + "Cymbal Breeze", + "Cruise in style and embrace effortless pedaling with the Breeze electric bike. Its whisper-quiet motor and long-lasting battery let you conquer hills and distances with ease. Enjoy scenic rides, commutes, or errands with a boost of confidence from Cymbal Breeze!", + PENDING_COMMIT_TIMESTAMP(), + 72, + 129999, + ), + ( + 1, + 6, + "Cymbal Trailblazer Backpack", + "Carry all your essentials in style with the Trailblazer backpack. Its water-resistant material, multiple compartments, and comfortable straps keep your gear organized and accessible, allowing you to focus on the adventure. Blaze new trails with Cymbal Trailblazer!", + PENDING_COMMIT_TIMESTAMP(), + 24, + 7999, + ), + ( + 1, + 7, + "Cymbal Phoenix Lights", + "See and be seen with the Phoenix bike lights. Powerful LEDs and multiple light modes ensure superior visibility, enhancing your safety and enjoyment during day or night rides. Light up your journey with Cymbal Phoenix!", + PENDING_COMMIT_TIMESTAMP(), + 87, + 3999, + ), + ( + 1, + 8, + "Cymbal Windstar Pump", + "Flat tires are no match for the Windstar pump. Its compact design, lightweight construction, and high-pressure capacity make inflating tires quick and effortless. Get back on the road in no time with Cymbal Windstar!", + PENDING_COMMIT_TIMESTAMP(), + 36, + 24999, + ), + ( + 1, + 9, + "Cymbal Odyssey Multi-Tool", + "Be prepared for anything with the Odyssey multi-tool. This handy gadget features essential tools like screwdrivers, hex wrenches, and tire levers, keeping you ready for minor repairs and adjustments on the go. Conquer your journey with Cymbal Odyssey!", + PENDING_COMMIT_TIMESTAMP(), + 52, + 999, + ), + ( + 1, + 10, + "Cymbal Nomad Water Bottle", + "Stay hydrated on every ride with the Nomad water bottle. Its sleek design, BPA-free construction, and secure lock lid make it the perfect companion for staying refreshed and motivated throughout your adventures. Hydrate and explore with Cymbal Nomad!", + PENDING_COMMIT_TIMESTAMP(), + 42, + 1299, + ), +] + +columns = [ + "categoryId", + "productId", + "productName", + "productDescription", + "createTime", + "inventoryCount", + "priceInCents", +] + + +if __name__ == "__main__": + main() diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py index 0e3b59e..5bf502b 100644 --- a/src/langchain_google_spanner/graph_retriever.py +++ b/src/langchain_google_spanner/graph_retriever.py @@ -225,7 +225,7 @@ def __clean_element(self, element: dict[str, Any], embedding_column: str) -> Non del element["properties"][embedding_column] def __get_distance_function( - self, distance_strategy=DistanceStrategy.EUCLIDEIAN + self, distance_strategy=DistanceStrategy.EUCLIDEAN ) -> str: """Gets the vector distance function.""" if distance_strategy == DistanceStrategy.COSINE: diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index c5ac979..3af28a4 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -14,12 +14,13 @@ from __future__ import annotations -from abc import ABC, abstractmethod import datetime -from enum import Enum import logging +from abc import ABC, abstractmethod +from enum import Enum from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union +import numpy as np from google.cloud import spanner # type: ignore from google.cloud.spanner_admin_database_v1.types import DatabaseDialect from google.cloud.spanner_v1 import JsonObject, param_types @@ -27,7 +28,6 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore -import numpy as np from .version import __version__ @@ -464,9 +464,7 @@ def _generate_sql( ] ann_indices = list( - filter( - lambda index: type(index) is VectorSearchIndex, secondary_indexes - ) + filter(lambda index: type(index) is VectorSearchIndex, secondary_indexes) ) ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_ANN( table_name, @@ -474,9 +472,9 @@ def _generate_sql( secondary_indexes=list(ann_indices), ) - knn_indices = list(filter( - lambda index: type(index) is SecondaryIndex, secondary_indexes - )) + knn_indices = list( + filter(lambda index: type(index) is SecondaryIndex, secondary_indexes) + ) ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_KNN( table_name, embedding_column, @@ -644,7 +642,7 @@ def __init__( ignore_metadata_columns: Optional[List[str]] = None, metadata_json_column: Optional[str] = None, query_parameters: QueryParameters = QueryParameters(), - skip_not_nullable_columns = False, + skip_not_nullable_columns=False, ): """ Initialize the SpannerVectorStore. @@ -791,8 +789,8 @@ def _validate_table_schema(self, column_type_map, types, default_columns): ): raise Exception( "Embedding Column is not of correct type. Expected one of: {} but found: {}".format( - types[EMBEDDING_COLUMN_NAME], - embedding_column_type) + types[EMBEDDING_COLUMN_NAME], embedding_column_type + ) ) if self._metadata_json_column is not None: @@ -1048,8 +1046,8 @@ def search_by_ANN( self, index_name: str, num_leaves: int, + limit: int, embedding: List[float] = None, - k: int = None, is_embedding_nullable: bool = False, where_condition: str = None, embedding_column_is_nullable: bool = False, @@ -1064,7 +1062,7 @@ def search_by_ANN( embedding or self._embedding_service, num_leaves, strategy, - k, + limit, is_embedding_nullable, where_condition, embedding_column_is_nullable=embedding_column_is_nullable, @@ -1075,6 +1073,7 @@ def search_by_ANN( with self._database.snapshot( **staleness if staleness is not None else {} ) as snapshot: + print("search by ANN sql", sql) results = snapshot.execute_sql(sql=sql) return list(results) diff --git a/tests/integration/test_spanner_loader.py b/tests/integration/test_spanner_loader.py index 568b41a..e046bf5 100644 --- a/tests/integration/test_spanner_loader.py +++ b/tests/integration/test_spanner_loader.py @@ -15,8 +15,8 @@ import os import uuid -from google.cloud.spanner import Client import pytest +from google.cloud.spanner import Client from langchain_core.documents import Document from langchain_google_spanner.loader import Column, SpannerDocumentSaver, SpannerLoader diff --git a/tests/integration/test_spanner_vector_store.py b/tests/integration/test_spanner_vector_store.py index 2c3bb4d..692916b 100644 --- a/tests/integration/test_spanner_vector_store.py +++ b/tests/integration/test_spanner_vector_store.py @@ -16,10 +16,10 @@ import os import uuid +import pytest from google.cloud.spanner import Client # type: ignore from langchain_community.document_loaders import HNLoader from langchain_community.embeddings import FakeEmbeddings -import pytest from langchain_google_spanner.vector_store import ( # type: ignore DistanceStrategy, diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py index 70d4473..6ed66d7 100644 --- a/tests/unit/test_vectore_store.py +++ b/tests/unit/test_vectore_store.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License -from collections import namedtuple import unittest +from collections import namedtuple from google.cloud.spanner_admin_database_v1.types import DatabaseDialect From 78d9d1d3a9e8eda96115ac21ce2800b6cb81686c Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Thu, 30 Jan 2025 14:45:26 +0200 Subject: [PATCH 04/16] Update ANN query names + test expectations --- src/langchain_google_spanner/vector_store.py | 14 +++++----- tests/unit/test_vectore_store.py | 28 +++++++++++--------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 3af28a4..fafccfa 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -117,7 +117,7 @@ def __init__( num_leaves: int, num_branches: int, tree_depth: int, - index_type: DistanceStrategy, + distance_type: DistanceStrategy, nullable_column: bool = False, *args, **kwargs, @@ -126,7 +126,7 @@ def __init__( self.num_leaves = num_leaves self.num_branches = num_branches self.tree_depth = tree_depth - self.index_type = index_type + self.distance_type = distance_type self.nullable_column = nullable_column def __post_init__(self): @@ -610,7 +610,7 @@ def _generate_secondary_indices_ddl_ANN( statement = f"CREATE VECTOR INDEX IF NOT EXISTS {secondary_index.index_name}\n\tON {table_name}({column_name})" if getattr(secondary_index, "nullable_column", False): statement += f"\n\tWHERE {column_name} IS NOT NULL" - options_segments = [f"distance_type='{secondary_index.index_type}'"] + options_segments = [f"distance_type='{secondary_index.distance_type}'"] if getattr(secondary_index, "tree_depth", 0) > 0: tree_depth = secondary_index.tree_depth if tree_depth not in (2, 3): @@ -1061,8 +1061,8 @@ def search_by_ANN( self._embedding_column, embedding or self._embedding_service, num_leaves, - strategy, limit, + strategy, is_embedding_nullable, where_condition, embedding_column_is_nullable=embedding_column_is_nullable, @@ -1084,8 +1084,8 @@ def _query_ANN( embedding_column_name: str, embedding: List[float], num_leaves: int, + limit: int, strategy: DistanceStrategy = DistanceStrategy.COSINE, - k: int = None, is_embedding_nullable: bool = False, where_condition: str = None, embedding_column_is_nullable: bool = False, @@ -1144,8 +1144,8 @@ def _query_ANN( if where_condition: sql += " WHERE " + where_condition + "\n" - if k: - sql += f"LIMIT {k}" + if limit: + sql += f"LIMIT {limit}" return sql.strip() diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py index 6ed66d7..d253dbd 100644 --- a/tests/unit/test_vectore_store.py +++ b/tests/unit/test_vectore_store.py @@ -109,14 +109,14 @@ def test_generate_secondary_indices_ddl_ANN(self): nullable_column=nullable, num_branches=1000, tree_depth=3, - index_type=distance_strategy, + distance_type=distance_strategy, num_leaves=100000, ) ], ) want = [ - "CREATE VECTOR INDEX DocEmbeddingIndex\n" + "CREATE VECTOR INDEX IF NOT EXISTS DocEmbeddingIndex\n" + " ON Documents(DocEmbedding)\n" + " WHERE DocEmbedding IS NOT NULL\n" + f" OPTIONS(distance_type='{distance_strategy}', " @@ -124,7 +124,7 @@ def test_generate_secondary_indices_ddl_ANN(self): ] if not nullable: want = [ - "CREATE VECTOR INDEX DocEmbeddingIndex\n" + "CREATE VECTOR INDEX IF NOT EXISTS DocEmbeddingIndex\n" + " ON Documents(DocEmbedding)\n" + f" OPTIONS(distance_type='{distance_strategy}', " + "tree_depth=3, num_branches=1000, num_leaves=100000)" @@ -152,7 +152,7 @@ def test_generate_ANN_indices_exception_for_non_GoogleSQL_dialect( columns=["DocEmbedding"], num_branches=1000, tree_depth=3, - index_type=strategy, + distance_type=strategy, num_leaves=100000, ) ], @@ -209,8 +209,8 @@ def test_query_ANN(self): "DocEmbedding", [1.0, 2.0, 3.0], 10, + 100, DistanceStrategy.COSINE, - limit=100, return_columns=["DocId"], ) @@ -218,7 +218,7 @@ def test_query_ANN(self): "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + "ORDER BY APPROX_COSINE_DISTANCE(\n" + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " - + '\'{"num_leaves_to_search": 10})\n' + + '\'{"num_leaves_to_search": 10}\')\n' + "LIMIT 100" ) @@ -231,8 +231,8 @@ def test_query_ANN_column_is_nullable(self): "DocEmbedding", [1.0, 2.0, 3.0], 10, + 100, DistanceStrategy.COSINE, - limit=100, embedding_column_is_nullable=True, return_columns=["DocId"], ) @@ -242,7 +242,7 @@ def test_query_ANN_column_is_nullable(self): + "WHERE DocEmbedding IS NOT NULL\n" + "ORDER BY APPROX_COSINE_DISTANCE(\n" + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " - + '\'{"num_leaves_to_search": 10})\n' + + '\'{"num_leaves_to_search": 10}\')\n' + "LIMIT 100" ) @@ -255,8 +255,8 @@ def test_query_ANN_column_unspecified_return_columns_star_result(self): "DocEmbedding", [1.0, 2.0, 3.0], 10, + 100, DistanceStrategy.COSINE, - limit=100, embedding_column_is_nullable=True, ) @@ -265,7 +265,7 @@ def test_query_ANN_column_unspecified_return_columns_star_result(self): + "WHERE DocEmbedding IS NOT NULL\n" + "ORDER BY APPROX_COSINE_DISTANCE(\n" + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " - + '\'{"num_leaves_to_search": 10})\n' + + '\'{"num_leaves_to_search": 10}\')\n' + "LIMIT 100" ) @@ -278,8 +278,8 @@ def test_query_ANN_order_DESC(self): "DocEmbedding", [1.0, 2.0, 3.0], 10, + 100, DistanceStrategy.COSINE, - limit=100, return_columns=["DocId"], ascending=False, ) @@ -288,7 +288,7 @@ def test_query_ANN_order_DESC(self): "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + "ORDER BY APPROX_COSINE_DISTANCE(\n" + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " - + '\'{"num_leaves_to_search": 10}) DESC\n' + + '\'{"num_leaves_to_search": 10}\') DESC\n' + "LIMIT 100" ) @@ -301,6 +301,7 @@ def test_query_ANN_unspecified_limit(self): "DocEmbedding", [1.0, 2.0, 3.0], 10, + 100, DistanceStrategy.COSINE, return_columns=["DocId"], ) @@ -309,7 +310,8 @@ def test_query_ANN_unspecified_limit(self): "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + "ORDER BY APPROX_COSINE_DISTANCE(\n" + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " - + '\'{"num_leaves_to_search": 10})' + + '\'{"num_leaves_to_search": 10}\')\n' + + "LIMIT 100" ) assert got == want From 5f01339f964fdfe32ba252c2fc10729a04aac149 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Thu, 30 Jan 2025 14:56:43 +0200 Subject: [PATCH 05/16] Pass in strategy inferred from initialization --- src/langchain_google_spanner/vector_store.py | 3 +-- tests/unit/test_vectore_store.py | 10 +++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index fafccfa..d859bb1 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -1053,7 +1053,6 @@ def search_by_ANN( embedding_column_is_nullable: bool = False, ascending: bool = True, return_columns: List[str] = None, - strategy: DistanceStrategy = DistanceStrategy.COSINE, ) -> List[Any]: sql = SpannerVectorStore._query_ANN( self._table_name, @@ -1062,7 +1061,7 @@ def search_by_ANN( embedding or self._embedding_service, num_leaves, limit, - strategy, + self._query_parameters.distance_strategy, is_embedding_nullable, where_condition, embedding_column_is_nullable=embedding_column_is_nullable, diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py index d253dbd..64bf307 100644 --- a/tests/unit/test_vectore_store.py +++ b/tests/unit/test_vectore_store.py @@ -218,7 +218,7 @@ def test_query_ANN(self): "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + "ORDER BY APPROX_COSINE_DISTANCE(\n" + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " - + '\'{"num_leaves_to_search": 10}\')\n' + + "'{\"num_leaves_to_search\": 10}')\n" + "LIMIT 100" ) @@ -242,7 +242,7 @@ def test_query_ANN_column_is_nullable(self): + "WHERE DocEmbedding IS NOT NULL\n" + "ORDER BY APPROX_COSINE_DISTANCE(\n" + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " - + '\'{"num_leaves_to_search": 10}\')\n' + + "'{\"num_leaves_to_search\": 10}')\n" + "LIMIT 100" ) @@ -265,7 +265,7 @@ def test_query_ANN_column_unspecified_return_columns_star_result(self): + "WHERE DocEmbedding IS NOT NULL\n" + "ORDER BY APPROX_COSINE_DISTANCE(\n" + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " - + '\'{"num_leaves_to_search": 10}\')\n' + + "'{\"num_leaves_to_search\": 10}')\n" + "LIMIT 100" ) @@ -288,7 +288,7 @@ def test_query_ANN_order_DESC(self): "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + "ORDER BY APPROX_COSINE_DISTANCE(\n" + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " - + '\'{"num_leaves_to_search": 10}\') DESC\n' + + "'{\"num_leaves_to_search\": 10}') DESC\n" + "LIMIT 100" ) @@ -310,7 +310,7 @@ def test_query_ANN_unspecified_limit(self): "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + "ORDER BY APPROX_COSINE_DISTANCE(\n" + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " - + '\'{"num_leaves_to_search": 10}\')\n' + + "'{\"num_leaves_to_search\": 10}')\n" + "LIMIT 100" ) From 039b1bd49369fed2d513116aa3ca9fe3f27e837b Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Thu, 30 Jan 2025 15:46:40 +0200 Subject: [PATCH 06/16] Hook up get_documents_from_query_results --- samples/search_ann.py | 6 +++++- src/langchain_google_spanner/vector_store.py | 9 +++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/samples/search_ann.py b/samples/search_ann.py index 009939b..ff7fb2b 100644 --- a/samples/search_ann.py +++ b/samples/search_ann.py @@ -56,7 +56,7 @@ def use_case(): nullable_column=True, num_branches=1000, tree_depth=3, - index_type=distance_strategy, + distance_type=distance_strategy, num_leaves=100000, ), ], @@ -140,6 +140,10 @@ def clear_and_insert_data(tx): id_column="categoryId", embedding_service=embeddings, embedding_column="productDescriptionEmbedding", + query_parameters=QueryParameters( + algorithm=QueryParameters.NearestNeighborsAlgorithm.APPROXIMATE_NEAREST_NEIGHBOR, + distance_strategy=distance_strategy, + ), skip_not_nullable_columns=True, ) results = vec_store.search_by_ANN( diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index d859bb1..5c60fdd 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -1053,7 +1053,7 @@ def search_by_ANN( embedding_column_is_nullable: bool = False, ascending: bool = True, return_columns: List[str] = None, - ) -> List[Any]: + ) -> List[Document]: sql = SpannerVectorStore._query_ANN( self._table_name, index_name, @@ -1074,7 +1074,12 @@ def search_by_ANN( ) as snapshot: print("search by ANN sql", sql) results = snapshot.execute_sql(sql=sql) - return list(results) + column_order_map = { + value: index for index, value in enumerate(self._columns_to_insert) + } + return self._get_documents_from_query_results( + list(results), column_order_map + ) @staticmethod def _query_ANN( From 8be267da7f2226233c126ff71da1fcf0526154c7 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Fri, 31 Jan 2025 09:42:20 +0200 Subject: [PATCH 07/16] Link up __search_by_ANN to similarity_search_by_vector --- samples/search_ann.py | 9 +++--- src/langchain_google_spanner/vector_store.py | 30 ++++++++++++++----- .../integration/test_spanner_vector_store.py | 1 + 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/samples/search_ann.py b/samples/search_ann.py index ff7fb2b..d1439b4 100644 --- a/samples/search_ann.py +++ b/samples/search_ann.py @@ -146,10 +146,11 @@ def clear_and_insert_data(tx): ), skip_not_nullable_columns=True, ) - results = vec_store.search_by_ANN( - "ProductDescriptionEmbeddingIndex", - 1000, - limit=20, + results = vec_store.similarity_search_by_vector( + embedding=embeddings, + index_name="ProductDescriptionEmbeddingIndex", + num_leaves=1000, + k=20, embedding_column_is_nullable=True, return_columns=["productName", "productDescription", "inventoryCount"], ) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 5c60fdd..e631c1d 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -225,7 +225,7 @@ def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]: return dict(zip(columns, values)) def getIndexDistanceType(self, distance_strategy) -> str: - value = _GOOGLE_ALGO_INDEX_NAME.get(distance_strategy, None) + value = GOOGLE_DIALECT_TO_ANN_DISTANCE_FUNCTIONS.get(distance_strategy, None) if value is None: raise Exception(f"{distance_strategy} is unsupported for distance_type") return value @@ -1039,10 +1039,7 @@ def similarity_search_with_score_by_vector( ) return documents - def set_strategy(strategy: DistanceStrategy): - self.__strategy = strategy - - def search_by_ANN( + def __search_by_ANN( self, index_name: str, num_leaves: int, @@ -1301,9 +1298,26 @@ def similarity_search_by_vector( Returns: List[Document]: List of documents most similar to the query. """ - documents = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, pre_filter=pre_filter - ) + documents: List[Document] = None + if ( + self._query_parameters.algorithm + == QueryParameters.NearestNeighborsAlgorithm.APPROXIMATE_NEAREST_NEIGHBOR + ): + documents = self.__search_by_ANN( + index_name=kwargs.get("index_name", None), + num_leaves=kwargs.get("num_leaves", 1000), + limit=k, + embedding=embedding, + is_embedding_nullable=kwargs.get("is_embedding_nullable", False), + where_condition=kwargs.get("where_condition", ""), + ascending=kwargs.get("ascending", True), + return_columns=kwargs.get("return_columns", []), + ) + else: + documents = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k, pre_filter=pre_filter + ) + return [doc for doc, _ in documents] def max_marginal_relevance_search_with_score_by_vector( diff --git a/tests/integration/test_spanner_vector_store.py b/tests/integration/test_spanner_vector_store.py index 692916b..1fa0097 100644 --- a/tests/integration/test_spanner_vector_store.py +++ b/tests/integration/test_spanner_vector_store.py @@ -618,6 +618,7 @@ def test_ann_add_data1(self, setup_database): embedding_service=embeddings, metadata_json_column="metadata", ) + _ = db class TestSpannerVectorStorePGSQL: From 66930b4c38fde3681a2513c8da479e054ca4cf58 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Sat, 1 Feb 2025 17:29:10 +0200 Subject: [PATCH 08/16] Incorporate pre_filter and post_filter plus update tests --- samples/search_ann.py | 125 ++++++++++--------- src/langchain_google_spanner/vector_store.py | 45 ++++--- tests/unit/test_vectore_store.py | 79 +++++++++++- 3 files changed, 168 insertions(+), 81 deletions(-) diff --git a/samples/search_ann.py b/samples/search_ann.py index d1439b4..dff45d0 100644 --- a/samples/search_ann.py +++ b/samples/search_ann.py @@ -66,69 +66,9 @@ def use_case(): client = Client(project=project_id) database = client.instance(instance_id).database(google_database) - model_ddl_statements = [ - f""" - CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT( - content STRING(MAX), - ) OUTPUT( - embeddings STRUCT, values ARRAY>, - ) REMOTE OPTIONS ( - endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/text-embedding-004' - ) - """, - f""" - CREATE MODEL IF NOT EXISTS LLMModel INPUT( - prompt STRING(MAX), - ) OUTPUT( - content STRING(MAX), - ) REMOTE OPTIONS ( - endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/gemini-pro', - default_batch_size = 1 - ) - """, - ] operation = database.update_ddl(model_ddl_statements) operation.result(OPERATION_TIMEOUT_SECONDS) - def clear_and_insert_data(tx): - tx.execute_update("DELETE FROM products WHERE 1=1") - tx.insert( - "products", - columns=[ - "categoryId", - "productId", - "productName", - "productDescription", - "createTime", - "inventoryCount", - "priceInCents", - ], - values=raw_data, - ) - - tx.execute_update( - """UPDATE products p1 - SET productDescriptionEmbedding = - (SELECT embeddings.values from ML.PREDICT(MODEL EmbeddingsModel, - (SELECT productDescription as content FROM products p2 where p2.productId=p1.productId))) - WHERE categoryId=1""", - ) - - embeddings = [] - rows = tx.execute_sql( - """SELECT embeddings.values - FROM ML.PREDICT( - MODEL EmbeddingsModel, - (SELECT "I'd like to buy a starter bike for my 3 year old child" as content) - )""" - ) - - for row in rows: - for nesting in row: - embeddings.extend(nesting) - - return embeddings - embeddings = database.run_in_transaction(clear_and_insert_data) if len(embeddings) > model_vector_size: embeddings = embeddings[:model_vector_size] @@ -149,7 +89,8 @@ def clear_and_insert_data(tx): results = vec_store.similarity_search_by_vector( embedding=embeddings, index_name="ProductDescriptionEmbeddingIndex", - num_leaves=1000, + num_leaves=10, + tree_depth=3, k=20, embedding_column_is_nullable=True, return_columns=["productName", "productDescription", "inventoryCount"], @@ -271,6 +212,68 @@ def PENDING_COMMIT_TIMESTAMP(): "priceInCents", ] +model_ddl_statements = [ + f""" + CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT( + content STRING(MAX), + ) OUTPUT( + embeddings STRUCT, values ARRAY>, + ) REMOTE OPTIONS ( + endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/text-embedding-004' + ) + """, + f""" + CREATE MODEL IF NOT EXISTS LLMModel INPUT( + prompt STRING(MAX), + ) OUTPUT( + content STRING(MAX), + ) REMOTE OPTIONS ( + endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/gemini-pro', + default_batch_size = 1 + ) + """, +] + + +def clear_and_insert_data(tx): + tx.execute_update("DELETE FROM products WHERE 1=1") + tx.insert( + "products", + columns=[ + "categoryId", + "productId", + "productName", + "productDescription", + "createTime", + "inventoryCount", + "priceInCents", + ], + values=raw_data, + ) + + tx.execute_update( + """UPDATE products p1 + SET productDescriptionEmbedding = + (SELECT embeddings.values from ML.PREDICT(MODEL EmbeddingsModel, + (SELECT productDescription as content FROM products p2 where p2.productId=p1.productId))) + WHERE categoryId=1""", + ) + + embeddings = [] + rows = tx.execute_sql( + """SELECT embeddings.values + FROM ML.PREDICT( + MODEL EmbeddingsModel, + (SELECT "I'd like to buy a starter bike for my 3 year old child" as content) + )""" + ) + + for row in rows: + for nesting in row: + embeddings.extend(nesting) + + return embeddings + if __name__ == "__main__": main() diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index e631c1d..5fbb579 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -1043,10 +1043,10 @@ def __search_by_ANN( self, index_name: str, num_leaves: int, - limit: int, + k: int, # Defines the limit embedding: List[float] = None, is_embedding_nullable: bool = False, - where_condition: str = None, + pre_filter: str = None, embedding_column_is_nullable: bool = False, ascending: bool = True, return_columns: List[str] = None, @@ -1057,10 +1057,10 @@ def __search_by_ANN( self._embedding_column, embedding or self._embedding_service, num_leaves, - limit, + k, self._query_parameters.distance_strategy, is_embedding_nullable, - where_condition, + pre_filter=pre_filter, embedding_column_is_nullable=embedding_column_is_nullable, ascending=ascending, return_columns=return_columns, @@ -1085,12 +1085,13 @@ def _query_ANN( embedding_column_name: str, embedding: List[float], num_leaves: int, - limit: int, + k: int, strategy: DistanceStrategy = DistanceStrategy.COSINE, is_embedding_nullable: bool = False, - where_condition: str = None, + pre_filter: str = None, embedding_column_is_nullable: bool = False, ascending: bool = True, + post_filter: str = None, # TODO(@odeke-em): Not yet supported return_columns: List[str] = None, ): """ @@ -1132,9 +1133,12 @@ def _query_ANN( + "@{FORCE_INDEX=" + f"{index_name}" + ( - "}\n" + ("}\nWHERE " + ("1=1" if not pre_filter else f"{pre_filter}") + "\n") if (not embedding_column_is_nullable) - else "}\nWHERE " + f"{embedding_column_name} IS NOT NULL\n" + else "}\nWHERE " + + f"{embedding_column_name} IS NOT NULL" + + ("" if not pre_filter else f" AND {pre_filter}") + + "\n" ) + f"ORDER BY {ann_strategy_name}(\n" + f" ARRAY{embedding}, {embedding_column_name}, options => JSON '" @@ -1142,17 +1146,11 @@ def _query_ANN( % (num_leaves, "" if ascending else " DESC") ) - if where_condition: - sql += " WHERE " + where_condition + "\n" - - if limit: - sql += f"LIMIT {limit}" + if k: + sql += f"LIMIT {k}" return sql.strip() - def _get_rows_by_similarity_search_ann(): - pass - def _get_rows_by_similarity_search_knn( self, embedding: List[float], @@ -1205,6 +1203,15 @@ def _get_rows_by_similarity_search_knn( return list(results), column_order_map + def _get_rows_by_similarity_search_ann( + self, + embedding: List[float], + k: int, + pre_filter: Optional[str] = None, + **kwargs: Any, + ): + raise RuntimeError("Unimplemented") + def _get_documents_from_query_results( self, results: List[List], column_order_map: Dict[str, int] ) -> List[Tuple[Document, float]]: @@ -1306,10 +1313,10 @@ def similarity_search_by_vector( documents = self.__search_by_ANN( index_name=kwargs.get("index_name", None), num_leaves=kwargs.get("num_leaves", 1000), - limit=k, + k=k, embedding=embedding, - is_embedding_nullable=kwargs.get("is_embedding_nullable", False), - where_condition=kwargs.get("where_condition", ""), + embedding_column_is_nullable=kwargs.get("embedding_column_is_nullable", False), + pre_filter=kwargs.get("pre_filter", ""), ascending=kwargs.get("ascending", True), return_columns=kwargs.get("return_columns", []), ) diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py index 64bf307..4064ea2 100644 --- a/tests/unit/test_vectore_store.py +++ b/tests/unit/test_vectore_store.py @@ -216,6 +216,7 @@ def test_query_ANN(self): want = ( "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + + "WHERE 1=1\n" + "ORDER BY APPROX_COSINE_DISTANCE(\n" + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " + "'{\"num_leaves_to_search\": 10}')\n" @@ -286,6 +287,7 @@ def test_query_ANN_order_DESC(self): want = ( "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + + "WHERE 1=1\n" + "ORDER BY APPROX_COSINE_DISTANCE(\n" + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " + "'{\"num_leaves_to_search\": 10}') DESC\n" @@ -294,7 +296,7 @@ def test_query_ANN_order_DESC(self): assert got == want - def test_query_ANN_unspecified_limit(self): + def test_query_ANN_specified_limit(self): got = SpannerVectorStore._query_ANN( "Documents", "DocEmbeddingIndex", @@ -308,6 +310,81 @@ def test_query_ANN_unspecified_limit(self): want = ( "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + + "WHERE 1=1\n" + + "ORDER BY APPROX_COSINE_DISTANCE(\n" + + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " + + "'{\"num_leaves_to_search\": 10}')\n" + + "LIMIT 100" + ) + + assert got == want + + def test_query_ANN_specified_pre_filter(self): + got = SpannerVectorStore._query_ANN( + "Documents", + "DocEmbeddingIndex", + "DocEmbedding", + [1.0, 2.0, 3.0], + 10, + 100, + DistanceStrategy.COSINE, + return_columns=["DocId"], + pre_filter="categoryId!=20", + ) + + want = ( + "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + + "WHERE categoryId!=20\n" + + "ORDER BY APPROX_COSINE_DISTANCE(\n" + + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " + + "'{\"num_leaves_to_search\": 10}')\n" + + "LIMIT 100" + ) + + assert got == want + + def test_query_ANN_specified_pre_filter_with_nullable_column(self): + got = SpannerVectorStore._query_ANN( + "Documents", + "DocEmbeddingIndex", + "DocEmbedding", + [1.0, 2.0, 3.0], + 10, + 100, + DistanceStrategy.COSINE, + return_columns=["DocId"], + pre_filter="categoryId!=9", + embedding_column_is_nullable=True, + ) + + want = ( + "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + + "WHERE DocEmbedding IS NOT NULL AND categoryId!=9\n" + + "ORDER BY APPROX_COSINE_DISTANCE(\n" + + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " + + "'{\"num_leaves_to_search\": 10}')\n" + + "LIMIT 100" + ) + + assert got == want + + def test_query_ANN_no_pre_filter_non_nullable(self): + got = SpannerVectorStore._query_ANN( + "Documents", + "DocEmbeddingIndex", + "DocEmbedding", + [1.0, 2.0, 3.0], + 10, + 100, + DistanceStrategy.COSINE, + embedding_column_is_nullable=True, + return_columns=["DocId"], + pre_filter="DocId!=2", + ) + + want = ( + "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + + "WHERE DocEmbedding IS NOT NULL AND DocId!=2\n" + "ORDER BY APPROX_COSINE_DISTANCE(\n" + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " + "'{\"num_leaves_to_search\": 10}')\n" From c0ae25d91522595cb840795ea4a26256f70f9462 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Mon, 3 Feb 2025 11:53:09 +0200 Subject: [PATCH 09/16] Review addressing --- src/langchain_google_spanner/vector_store.py | 38 ++++++++------------ tests/unit/test_vectore_store.py | 32 ++++++++--------- 2 files changed, 31 insertions(+), 39 deletions(-) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 5fbb579..4630df4 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -422,7 +422,7 @@ def _generate_sql( column_configs, primary_key, secondary_indexes: Optional[List[SecondaryIndex | VectorSearchIndex]] = None, - vector_size: int = None, + vector_size: Optional[int] = None, ): """ Generate SQL for creating the vector store table. @@ -672,7 +672,6 @@ def __init__( self._query_parameters = query_parameters self._embedding_service = embedding_service - self.__strategy = None self._skip_not_nullable_columns = skip_not_nullable_columns if metadata_columns is not None and ignore_metadata_columns is not None: @@ -907,13 +906,6 @@ def _insert_data(self, records, columns_to_insert): values=records, ) - def add_ann_rows( - self, data: List[Tuple], id_column_index: int, columns=Dict[str, str] - ) -> List[str]: - self._insert_data(data, columns) - ids = list(map(lambda row: row[id_column_index], data)) - return ids - def add_documents( self, documents: List[Document], @@ -1043,15 +1035,14 @@ def __search_by_ANN( self, index_name: str, num_leaves: int, - k: int, # Defines the limit - embedding: List[float] = None, - is_embedding_nullable: bool = False, - pre_filter: str = None, + k: int, # Defines the limit + embedding: Optional[List[float]] = None, + pre_filter: Optional[str] = None, embedding_column_is_nullable: bool = False, ascending: bool = True, - return_columns: List[str] = None, - ) -> List[Document]: - sql = SpannerVectorStore._query_ANN( + return_columns: Optional[List[str]] = None, + ) -> List[Tuple[Document, float]]: + sql = SpannerVectorStore._generate_sql_for_ANN( self._table_name, index_name, self._embedding_column, @@ -1069,7 +1060,6 @@ def __search_by_ANN( with self._database.snapshot( **staleness if staleness is not None else {} ) as snapshot: - print("search by ANN sql", sql) results = snapshot.execute_sql(sql=sql) column_order_map = { value: index for index, value in enumerate(self._columns_to_insert) @@ -1079,7 +1069,7 @@ def __search_by_ANN( ) @staticmethod - def _query_ANN( + def _generate_sql_for_ANN( table_name: str, index_name: str, embedding_column_name: str, @@ -1088,12 +1078,12 @@ def _query_ANN( k: int, strategy: DistanceStrategy = DistanceStrategy.COSINE, is_embedding_nullable: bool = False, - pre_filter: str = None, + pre_filter: Optional[str] = None, embedding_column_is_nullable: bool = False, ascending: bool = True, - post_filter: str = None, # TODO(@odeke-em): Not yet supported + post_filter: Optional[str] = None, # TODO(@odeke-em): Not yet supported return_columns: List[str] = None, - ): + ) -> str: """ Sample query: SELECT DocId @@ -1305,7 +1295,7 @@ def similarity_search_by_vector( Returns: List[Document]: List of documents most similar to the query. """ - documents: List[Document] = None + documents: List[Tuple[Document, float]] = [] if ( self._query_parameters.algorithm == QueryParameters.NearestNeighborsAlgorithm.APPROXIMATE_NEAREST_NEIGHBOR @@ -1315,7 +1305,9 @@ def similarity_search_by_vector( num_leaves=kwargs.get("num_leaves", 1000), k=k, embedding=embedding, - embedding_column_is_nullable=kwargs.get("embedding_column_is_nullable", False), + embedding_column_is_nullable=kwargs.get( + "embedding_column_is_nullable", False + ), pre_filter=kwargs.get("pre_filter", ""), ascending=kwargs.get("ascending", True), return_columns=kwargs.get("return_columns", []), diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py index 4064ea2..fbfefbe 100644 --- a/tests/unit/test_vectore_store.py +++ b/tests/unit/test_vectore_store.py @@ -202,8 +202,8 @@ def test_generate_secondary_indices_ddl_KNN_PostgresDialect(self): assert canonicalize(got) == canonicalize(want) - def test_query_ANN(self): - got = SpannerVectorStore._query_ANN( + def test_generate_sql_for_ANN(self): + got = SpannerVectorStore._generate_sql_for_ANN( "Documents", "DocEmbeddingIndex", "DocEmbedding", @@ -225,8 +225,8 @@ def test_query_ANN(self): assert got == want - def test_query_ANN_column_is_nullable(self): - got = SpannerVectorStore._query_ANN( + def test_generate_sql_for_ANN_column_is_nullable(self): + got = SpannerVectorStore._generate_sql_for_ANN( "Documents", "DocEmbeddingIndex", "DocEmbedding", @@ -249,8 +249,8 @@ def test_query_ANN_column_is_nullable(self): assert got == want - def test_query_ANN_column_unspecified_return_columns_star_result(self): - got = SpannerVectorStore._query_ANN( + def test_generate_sql_for_ANN_column_unspecified_return_columns_star_result(self): + got = SpannerVectorStore._generate_sql_for_ANN( "Documents", "DocEmbeddingIndex", "DocEmbedding", @@ -272,8 +272,8 @@ def test_query_ANN_column_unspecified_return_columns_star_result(self): assert got == want - def test_query_ANN_order_DESC(self): - got = SpannerVectorStore._query_ANN( + def test_generate_sql_for_ANN_order_DESC(self): + got = SpannerVectorStore._generate_sql_for_ANN( "Documents", "DocEmbeddingIndex", "DocEmbedding", @@ -296,8 +296,8 @@ def test_query_ANN_order_DESC(self): assert got == want - def test_query_ANN_specified_limit(self): - got = SpannerVectorStore._query_ANN( + def test_generate_sql_for_ANN_specified_limit(self): + got = SpannerVectorStore._generate_sql_for_ANN( "Documents", "DocEmbeddingIndex", "DocEmbedding", @@ -319,8 +319,8 @@ def test_query_ANN_specified_limit(self): assert got == want - def test_query_ANN_specified_pre_filter(self): - got = SpannerVectorStore._query_ANN( + def test_generate_sql_for_ANN_specified_pre_filter(self): + got = SpannerVectorStore._generate_sql_for_ANN( "Documents", "DocEmbeddingIndex", "DocEmbedding", @@ -343,8 +343,8 @@ def test_query_ANN_specified_pre_filter(self): assert got == want - def test_query_ANN_specified_pre_filter_with_nullable_column(self): - got = SpannerVectorStore._query_ANN( + def test_generate_sql_for_ANN_specified_pre_filter_with_nullable_column(self): + got = SpannerVectorStore._generate_sql_for_ANN( "Documents", "DocEmbeddingIndex", "DocEmbedding", @@ -368,8 +368,8 @@ def test_query_ANN_specified_pre_filter_with_nullable_column(self): assert got == want - def test_query_ANN_no_pre_filter_non_nullable(self): - got = SpannerVectorStore._query_ANN( + def test_generate_sql_for_ANN_no_pre_filter_non_nullable(self): + got = SpannerVectorStore._generate_sql_for_ANN( "Documents", "DocEmbeddingIndex", "DocEmbedding", From c39881085a0e760de7a8cd63a66099e27e72fff0 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Mon, 3 Feb 2025 12:36:02 +0200 Subject: [PATCH 10/16] Simplified checking if using ANN --- src/langchain_google_spanner/vector_store.py | 23 ++++++++------------ 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 4630df4..d2cf473 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -1193,15 +1193,6 @@ def _get_rows_by_similarity_search_knn( return list(results), column_order_map - def _get_rows_by_similarity_search_ann( - self, - embedding: List[float], - k: int, - pre_filter: Optional[str] = None, - **kwargs: Any, - ): - raise RuntimeError("Unimplemented") - def _get_documents_from_query_results( self, results: List[List], column_order_map: Dict[str, int] ) -> List[Tuple[Document, float]]: @@ -1277,6 +1268,13 @@ def similarity_search_with_score( ) return documents + @property + def __using_ANN(self): + return ( + self._query_parameters.algorithm + == QueryParameters.NearestNeighborsAlgorithm.APPROXIMATE_NEAREST_NEIGHBOR + ) + def similarity_search_by_vector( self, embedding: List[float], @@ -1296,10 +1294,7 @@ def similarity_search_by_vector( List[Document]: List of documents most similar to the query. """ documents: List[Tuple[Document, float]] = [] - if ( - self._query_parameters.algorithm - == QueryParameters.NearestNeighborsAlgorithm.APPROXIMATE_NEAREST_NEIGHBOR - ): + if self.__using_ANN: documents = self.__search_by_ANN( index_name=kwargs.get("index_name", None), num_leaves=kwargs.get("num_leaves", 1000), @@ -1308,7 +1303,7 @@ def similarity_search_by_vector( embedding_column_is_nullable=kwargs.get( "embedding_column_is_nullable", False ), - pre_filter=kwargs.get("pre_filter", ""), + pre_filter=pre_filter, ascending=kwargs.get("ascending", True), return_columns=kwargs.get("return_columns", []), ) From 3c65bc93c3e15b2e4192cf0e7e1afd4f2556c226 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Mon, 3 Feb 2025 12:57:38 +0200 Subject: [PATCH 11/16] Reduce the amount of changes --- samples/search_ann.py | 279 ------------------ src/langchain_google_spanner/vector_store.py | 24 +- .../integration/test_spanner_vector_store.py | 14 - 3 files changed, 10 insertions(+), 307 deletions(-) delete mode 100644 samples/search_ann.py diff --git a/samples/search_ann.py b/samples/search_ann.py deleted file mode 100644 index dff45d0..0000000 --- a/samples/search_ann.py +++ /dev/null @@ -1,279 +0,0 @@ -import datetime -import os -import time -import uuid -from dataclasses import dataclass -from enum import Enum -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union - -from google.cloud.spanner import Client # type: ignore -from langchain_community.document_loaders import HNLoader -from langchain_community.embeddings import FakeEmbeddings - -from langchain_google_spanner.vector_store import ( # type: ignore - DistanceStrategy, - QueryParameters, - SpannerVectorStore, - TableColumn, - VectorSearchIndex, -) - -project_id = os.environ["PROJECT_ID"] -instance_id = os.environ["INSTANCE_ID"] -google_database = os.environ["GOOGLE_DATABASE"] -zone = os.environ["GOOGLE_DATABASE_ZONE"] -table_name_ANN = "products" -OPERATION_TIMEOUT_SECONDS = 240 - - -def use_case(): - # Initialize the vector store table if necessary. - distance_strategy = DistanceStrategy.COSINE - model_vector_size = 758 - SpannerVectorStore.init_vector_store_table( - instance_id=instance_id, - database_id=google_database, - table_name=table_name_ANN, - id_column=TableColumn("productId", type="INT64"), - vector_size=model_vector_size, - embedding_column=TableColumn( - name="productDescriptionEmbedding", - type="ARRAY", - is_null=True, - ), - metadata_columns=[ - TableColumn(name="categoryId", type="INT64", is_null=False), - TableColumn(name="productName", type="STRING(MAX)", is_null=False), - TableColumn(name="productDescription", type="STRING(MAX)", is_null=False), - TableColumn(name="inventoryCount", type="INT64", is_null=False), - TableColumn(name="priceInCents", type="INT64", is_null=True), - TableColumn(name="createTime", type="TIMESTAMP", is_null=False), - ], - secondary_indexes=[ - VectorSearchIndex( - index_name="ProductDescriptionEmbeddingIndex", - columns=["productDescriptionEmbedding"], - nullable_column=True, - num_branches=1000, - tree_depth=3, - distance_type=distance_strategy, - num_leaves=100000, - ), - ], - ) - - # Create the models if necessary. - client = Client(project=project_id) - database = client.instance(instance_id).database(google_database) - - operation = database.update_ddl(model_ddl_statements) - operation.result(OPERATION_TIMEOUT_SECONDS) - - embeddings = database.run_in_transaction(clear_and_insert_data) - if len(embeddings) > model_vector_size: - embeddings = embeddings[:model_vector_size] - - vec_store = SpannerVectorStore( - instance_id=instance_id, - database_id=google_database, - table_name=table_name_ANN, - id_column="categoryId", - embedding_service=embeddings, - embedding_column="productDescriptionEmbedding", - query_parameters=QueryParameters( - algorithm=QueryParameters.NearestNeighborsAlgorithm.APPROXIMATE_NEAREST_NEIGHBOR, - distance_strategy=distance_strategy, - ), - skip_not_nullable_columns=True, - ) - results = vec_store.similarity_search_by_vector( - embedding=embeddings, - index_name="ProductDescriptionEmbeddingIndex", - num_leaves=10, - tree_depth=3, - k=20, - embedding_column_is_nullable=True, - return_columns=["productName", "productDescription", "inventoryCount"], - ) - - print("Search by ANN results") - for res in results: - print(res) - - -def main(): - use_case() - - -def PENDING_COMMIT_TIMESTAMP(): - return (datetime.datetime.utcnow() + datetime.timedelta(days=1)).isoformat() + "Z" - - -raw_data = [ - ( - 1, - 1, - "Cymbal Helios Helmet", - "Safety meets style with the Cymbal children's bike helmet. Its lightweight design, superior ventilation, and adjustable fit ensure comfort and protection on every ride. Stay bright and keep your child safe under the sun with Cymbal Helios!", - PENDING_COMMIT_TIMESTAMP(), - 100, - 10999, - ), - ( - 1, - 2, - "Cymbal Sprout", - "Let their cycling journey begin with the Cymbal Sprout, the ideal balance bike for beginning riders ages 2-4 years. Its lightweight frame, low seat height, and puncture-proof tires promote stability and confidence as little ones learn to balance and steer. Watch them sprout into cycling enthusiasts with Cymbal Sprout!", - PENDING_COMMIT_TIMESTAMP(), - 10, - 13999, - ), - ( - 1, - 3, - "Cymbal Spark Jr.", - "Light, vibrant, and ready for adventure, the Spark Jr. is the perfect first bike for young riders (ages 5-8). Its sturdy frame, easy-to-use brakes, and puncture-resistant tires inspire confidence and endless playtime. Let the spark of cycling ignite with Cymbal!", - PENDING_COMMIT_TIMESTAMP(), - 34, - 13900, - ), - ( - 1, - 4, - "Cymbal Summit", - "Conquering trails is a breeze with the Summit mountain bike. Its lightweight aluminum frame, responsive suspension, and powerful disc brakes provide exceptional control and comfort for experienced bikers navigating rocky climbs or shredding downhill. Reach new heights with Cymbal Summit!", - PENDING_COMMIT_TIMESTAMP(), - 0, - 79999, - ), - ( - 1, - 5, - "Cymbal Breeze", - "Cruise in style and embrace effortless pedaling with the Breeze electric bike. Its whisper-quiet motor and long-lasting battery let you conquer hills and distances with ease. Enjoy scenic rides, commutes, or errands with a boost of confidence from Cymbal Breeze!", - PENDING_COMMIT_TIMESTAMP(), - 72, - 129999, - ), - ( - 1, - 6, - "Cymbal Trailblazer Backpack", - "Carry all your essentials in style with the Trailblazer backpack. Its water-resistant material, multiple compartments, and comfortable straps keep your gear organized and accessible, allowing you to focus on the adventure. Blaze new trails with Cymbal Trailblazer!", - PENDING_COMMIT_TIMESTAMP(), - 24, - 7999, - ), - ( - 1, - 7, - "Cymbal Phoenix Lights", - "See and be seen with the Phoenix bike lights. Powerful LEDs and multiple light modes ensure superior visibility, enhancing your safety and enjoyment during day or night rides. Light up your journey with Cymbal Phoenix!", - PENDING_COMMIT_TIMESTAMP(), - 87, - 3999, - ), - ( - 1, - 8, - "Cymbal Windstar Pump", - "Flat tires are no match for the Windstar pump. Its compact design, lightweight construction, and high-pressure capacity make inflating tires quick and effortless. Get back on the road in no time with Cymbal Windstar!", - PENDING_COMMIT_TIMESTAMP(), - 36, - 24999, - ), - ( - 1, - 9, - "Cymbal Odyssey Multi-Tool", - "Be prepared for anything with the Odyssey multi-tool. This handy gadget features essential tools like screwdrivers, hex wrenches, and tire levers, keeping you ready for minor repairs and adjustments on the go. Conquer your journey with Cymbal Odyssey!", - PENDING_COMMIT_TIMESTAMP(), - 52, - 999, - ), - ( - 1, - 10, - "Cymbal Nomad Water Bottle", - "Stay hydrated on every ride with the Nomad water bottle. Its sleek design, BPA-free construction, and secure lock lid make it the perfect companion for staying refreshed and motivated throughout your adventures. Hydrate and explore with Cymbal Nomad!", - PENDING_COMMIT_TIMESTAMP(), - 42, - 1299, - ), -] - -columns = [ - "categoryId", - "productId", - "productName", - "productDescription", - "createTime", - "inventoryCount", - "priceInCents", -] - -model_ddl_statements = [ - f""" - CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT( - content STRING(MAX), - ) OUTPUT( - embeddings STRUCT, values ARRAY>, - ) REMOTE OPTIONS ( - endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/text-embedding-004' - ) - """, - f""" - CREATE MODEL IF NOT EXISTS LLMModel INPUT( - prompt STRING(MAX), - ) OUTPUT( - content STRING(MAX), - ) REMOTE OPTIONS ( - endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/gemini-pro', - default_batch_size = 1 - ) - """, -] - - -def clear_and_insert_data(tx): - tx.execute_update("DELETE FROM products WHERE 1=1") - tx.insert( - "products", - columns=[ - "categoryId", - "productId", - "productName", - "productDescription", - "createTime", - "inventoryCount", - "priceInCents", - ], - values=raw_data, - ) - - tx.execute_update( - """UPDATE products p1 - SET productDescriptionEmbedding = - (SELECT embeddings.values from ML.PREDICT(MODEL EmbeddingsModel, - (SELECT productDescription as content FROM products p2 where p2.productId=p1.productId))) - WHERE categoryId=1""", - ) - - embeddings = [] - rows = tx.execute_sql( - """SELECT embeddings.values - FROM ML.PREDICT( - MODEL EmbeddingsModel, - (SELECT "I'd like to buy a starter bike for my 3 year old child" as content) - )""" - ) - - for row in rows: - for nesting in row: - embeddings.extend(nesting) - - return embeddings - - -if __name__ == "__main__": - main() diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index d2cf473..87c7cf4 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -672,7 +672,6 @@ def __init__( self._query_parameters = query_parameters self._embedding_service = embedding_service - self._skip_not_nullable_columns = skip_not_nullable_columns if metadata_columns is not None and ignore_metadata_columns is not None: raise Exception( @@ -805,14 +804,13 @@ def _validate_table_schema(self, column_type_map, types, default_columns): embedding_column_type, ) - if not self._skip_not_nullable_columns: - for column_name, column_config in column_type_map.items(): - if column_name not in self._columns_to_insert: - if "NO" == column_config[2].upper(): - raise Exception( - "Found not nullable constraint on column: {}.", - column_name, - ) + for column_name, column_config in column_type_map.items(): + if column_name not in self._columns_to_insert: + if "NO" == column_config[2].upper(): + raise Exception( + "Found not nullable constraint on column: {}.", + column_name, + ) def _select_relevance_score_fn(self) -> Callable[[float], float]: if self._query_parameters.distance_strategy == DistanceStrategy.COSINE: @@ -1046,15 +1044,14 @@ def __search_by_ANN( self._table_name, index_name, self._embedding_column, - embedding or self._embedding_service, + embedding, num_leaves, k, self._query_parameters.distance_strategy, - is_embedding_nullable, pre_filter=pre_filter, embedding_column_is_nullable=embedding_column_is_nullable, ascending=ascending, - return_columns=return_columns, + return_columns=return_columns or self._columns_to_insert, ) staleness = self._query_parameters.staleness with self._database.snapshot( @@ -1077,12 +1074,11 @@ def _generate_sql_for_ANN( num_leaves: int, k: int, strategy: DistanceStrategy = DistanceStrategy.COSINE, - is_embedding_nullable: bool = False, pre_filter: Optional[str] = None, embedding_column_is_nullable: bool = False, ascending: bool = True, post_filter: Optional[str] = None, # TODO(@odeke-em): Not yet supported - return_columns: List[str] = None, + return_columns: Optional[List[str]] = None, ) -> str: """ Sample query: diff --git a/tests/integration/test_spanner_vector_store.py b/tests/integration/test_spanner_vector_store.py index 1fa0097..bb71125 100644 --- a/tests/integration/test_spanner_vector_store.py +++ b/tests/integration/test_spanner_vector_store.py @@ -395,20 +395,6 @@ def test_spanner_vector_search_data4(self, setup_database): class TestSpannerVectorStoreGoogleSQL_ANN: @pytest.fixture(scope="class") def setup_database(self, client): - """ - CREATE TABLE products ( - categoryId INT64 NOT NULL, - productId INT64 NOT NULL, - productName STRING(MAX) NOT NULL, - productDescription STRING(MAX) NOT NULL, - productDescriptionEmbedding ARRAY(vector_length=>728), - createTime TIMESTAMP NOT NULL OPTIONS ( - allow_commit_timestamp = true - ), - inventoryCount INT64 NOT NULL, - priceInCents INT64, - ) PRIMARY KEY(categoryId, productId); - """ distance_strategy = DistanceStrategy.COSINE SpannerVectorStore.init_vector_store_table( instance_id=instance_id, From aa6a6c2a0932b28ac3da56c9b53046326f1ef44b Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Mon, 3 Feb 2025 13:05:02 +0200 Subject: [PATCH 12/16] More reductions --- noxfile.py | 2 +- src/langchain_google_spanner/vector_store.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/noxfile.py b/noxfile.py index 4da86cc..9a91d75 100644 --- a/noxfile.py +++ b/noxfile.py @@ -24,7 +24,7 @@ DEFAULT_PYTHON_VERSION = "3.10" CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() -LINT_PATHS = ["samples", "src", "tests", "noxfile.py"] +LINT_PATHS = ["src", "tests", "noxfile.py"] nox.options.sessions = [ diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 87c7cf4..23609bf 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -73,7 +73,7 @@ class TableColumn: name: str type: str is_null: bool = True - vector_length: int = None + vector_length: Optional[int] = None def __post_init__(self): # Check if column_name is None after initialization @@ -285,7 +285,7 @@ class QueryParameters: class NearestNeighborsAlgorithm(Enum): """ - Enum for k-nearest neighbors search algorithms. + Enum for nearest neighbors search algorithms. """ EXACT_NEAREST_NEIGHBOR = 1 @@ -444,7 +444,6 @@ def _generate_sql( embedding_config = list( filter(lambda x: x.name == embedding_column, column_configs) ) - print("column_configs", column_configs, "\nembedding_config", embedding_config) if embedding_column and len(embedding_config) > 0: config = embedding_config[0] if config.vector_length is None or config.vector_length <= 0: @@ -642,7 +641,6 @@ def __init__( ignore_metadata_columns: Optional[List[str]] = None, metadata_json_column: Optional[str] = None, query_parameters: QueryParameters = QueryParameters(), - skip_not_nullable_columns=False, ): """ Initialize the SpannerVectorStore. From 130bc461ca84f7284ab3cf9895350822307dd263 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Mon, 3 Feb 2025 13:08:48 +0200 Subject: [PATCH 13/16] More reductions to ease code review --- src/langchain_google_spanner/vector_store.py | 1 - tests/integration/test_spanner_loader.py | 2 +- .../integration/test_spanner_vector_store.py | 220 +----------------- 3 files changed, 2 insertions(+), 221 deletions(-) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 23609bf..33ee668 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -440,7 +440,6 @@ def _generate_sql( - str: The generated SQL. """ - # 1. If any of the columns is a VectorSearchIndex embedding_config = list( filter(lambda x: x.name == embedding_column, column_configs) ) diff --git a/tests/integration/test_spanner_loader.py b/tests/integration/test_spanner_loader.py index e046bf5..fd8f028 100644 --- a/tests/integration/test_spanner_loader.py +++ b/tests/integration/test_spanner_loader.py @@ -16,7 +16,7 @@ import uuid import pytest -from google.cloud.spanner import Client +from google.cloud.spanner import Client # type: ignore from langchain_core.documents import Document from langchain_google_spanner.loader import Column, SpannerDocumentSaver, SpannerLoader diff --git a/tests/integration/test_spanner_vector_store.py b/tests/integration/test_spanner_vector_store.py index bb71125..81dd1b9 100644 --- a/tests/integration/test_spanner_vector_store.py +++ b/tests/integration/test_spanner_vector_store.py @@ -26,16 +26,13 @@ QueryParameters, SpannerVectorStore, TableColumn, - VectorSearchIndex, ) project_id = os.environ["PROJECT_ID"] instance_id = os.environ["INSTANCE_ID"] google_database = os.environ["GOOGLE_DATABASE"] -pg_database = os.environ.get("PG_DATABASE", None) -zone = os.environ.get("GOOGLE_DATABASE_ZONE", "us-west2") +pg_database = os.environ["PG_DATABASE"] table_name = "test_table" + str(uuid.uuid4()).replace("-", "_") -table_name_ANN = "products" OPERATION_TIMEOUT_SECONDS = 240 @@ -392,221 +389,6 @@ def test_spanner_vector_search_data4(self, setup_database): assert len(docs) == 3 -class TestSpannerVectorStoreGoogleSQL_ANN: - @pytest.fixture(scope="class") - def setup_database(self, client): - distance_strategy = DistanceStrategy.COSINE - SpannerVectorStore.init_vector_store_table( - instance_id=instance_id, - database_id=google_database, - table_name=table_name_ANN, - id_column=TableColumn("productId", type="INT64"), - vector_size=758, - embedding_column=TableColumn( - name="productDescriptionEmbedding", - type="ARRAY", - is_null=True, - ), - metadata_columns=[ - TableColumn(name="categoryId", type="INT64", is_null=False), - TableColumn(name="productName", type="STRING(MAX)", is_null=False), - TableColumn( - name="productDescription", type="STRING(MAX)", is_null=False - ), - TableColumn(name="inventoryCount", type="INT64", is_null=False), - TableColumn(name="priceInCents", type="INT64", is_null=True), - ], - secondary_indexes=[ - VectorSearchIndex( - index_name="ProductDescriptionEmbeddingIndex", - columns=["productDescriptionEmbedding"], - nullable_column=True, - num_branches=1000, - tree_depth=3, - index_type=distance_strategy, - num_leaves=100000, - ), - ], - ) - - raw_data = [ - ( - 1, - 1, - "Cymbal Helios Helmet", - "Safety meets style with the Cymbal children's bike helmet. Its lightweight design, superior ventilation, and adjustable fit ensure comfort and protection on every ride. Stay bright and keep your child safe under the sun with Cymbal Helios!", - 100, - 10999, - ), - ( - 1, - 2, - "Cymbal Sprout", - "Let their cycling journey begin with the Cymbal Sprout, the ideal balance bike for beginning riders ages 2-4 years. Its lightweight frame, low seat height, and puncture-proof tires promote stability and confidence as little ones learn to balance and steer. Watch them sprout into cycling enthusiasts with Cymbal Sprout!", - 10, - 13999, - ), - ( - 1, - 3, - "Cymbal Spark Jr.", - "Light, vibrant, and ready for adventure, the Spark Jr. is the perfect first bike for young riders (ages 5-8). Its sturdy frame, easy-to-use brakes, and puncture-resistant tires inspire confidence and endless playtime. Let the spark of cycling ignite with Cymbal!", - 34, - 13900, - ), - ( - 1, - 4, - "Cymbal Summit", - "Conquering trails is a breeze with the Summit mountain bike. Its lightweight aluminum frame, responsive suspension, and powerful disc brakes provide exceptional control and comfort for experienced bikers navigating rocky climbs or shredding downhill. Reach new heights with Cymbal Summit!", - 0, - 79999, - ), - ( - 1, - 5, - "Cymbal Breeze", - "Cruise in style and embrace effortless pedaling with the Breeze electric bike. Its whisper-quiet motor and long-lasting battery let you conquer hills and distances with ease. Enjoy scenic rides, commutes, or errands with a boost of confidence from Cymbal Breeze!", - 72, - 129999, - ), - ( - 1, - 6, - "Cymbal Trailblazer Backpack", - "Carry all your essentials in style with the Trailblazer backpack. Its water-resistant material, multiple compartments, and comfortable straps keep your gear organized and accessible, allowing you to focus on the adventure. Blaze new trails with Cymbal Trailblazer!", - 24, - 7999, - ), - ( - 1, - 7, - "Cymbal Phoenix Lights", - "See and be seen with the Phoenix bike lights. Powerful LEDs and multiple light modes ensure superior visibility, enhancing your safety and enjoyment during day or night rides. Light up your journey with Cymbal Phoenix!", - 87, - 3999, - ), - ( - 1, - 8, - "Cymbal Windstar Pump", - "Flat tires are no match for the Windstar pump. Its compact design, lightweight construction, and high-pressure capacity make inflating tires quick and effortless. Get back on the road in no time with Cymbal Windstar!", - 36, - 24999, - ), - ( - 1, - 9, - "Cymbal Odyssey Multi-Tool", - "Be prepared for anything with the Odyssey multi-tool. This handy gadget features essential tools like screwdrivers, hex wrenches, and tire levers, keeping you ready for minor repairs and adjustments on the go. Conquer your journey with Cymbal Odyssey!", - 52, - 999, - ), - ( - 1, - 10, - "Cymbal Nomad Water Bottle", - "Stay hydrated on every ride with the Nomad water bottle. Its sleek design, BPA-free construction, and secure lock lid make it the perfect companion for staying refreshed and motivated throughout your adventures. Hydrate and explore with Cymbal Nomad!", - 42, - 1299, - ), - ] - - columns = [ - "categoryId", - "productId", - "productName", - "productDescription", - "createTime", - "inventoryCount", - "priceInCents", - ] - - model_ddl_statements = [ - f""" - CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT( - content STRING(MAX), - ) OUTPUT( - embeddings STRUCT, values ARRAY>, - ) REMOTE OPTIONS ( - endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/{zone}/publishers/google/models/text-embedding-004' - ) - """, - f""" - CREATE MODEL IF NOT EXISTS LLMModel INPUT( - prompt STRING(MAX), - ) OUTPUT( - content STRING(MAX), - ) REMOTE OPTIONS ( - endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/{zone}/publishers/google/models/gemini-pro', - default_batch_size = 1 - ) - """, - """ - UPDATE products p1 - SET productDescriptionEmbedding = - ( - SELECT embeddings.values from ML.PREDICT( - MODEL EmbeddingsModel, - (SELECT productDescription as content FROM products p2 where p2.productId=p1.productId) - ) - ) - WHERE categoryId=1 - """, - ] - database = client.instance(instance_id).database(google_database) - - def create_models(): - operation = database.update_ddl(model_ddl_statements) - return operation.result(OPERATION_TIMEOUT_SECONDS) - - def get_embeddings(self): - sql = """SELECT embeddings.values FROM ML.PREDICT( - MODEL EmbeddingsModel, - (SELECT "I'd like to buy a starter bike for my 3 year old child" as content) - )""" - - with database.snapshot() as snapshot: - res = snapshot.execute_sql(sql) - return list(res) - - yield raw_data, columns, create_models, get_embeddings - - print("\nPerforming GSQL cleanup after each ANN test...") - - operation = database.update_ddl( - [ - f"DROP TABLE IF EXISTS {table_name_ANN}", - "DROP MODEL IF EXISTS EmbeddingsModel", - "DROP MODEL IF EXISTS LLMModel", - "DROP Index IF EXISTS ProductDescriptionEmbeddingIndex", - ] - ) - if False: # Creating a vector index takes 30+ minutes, so avoiding this. - operation.result(OPERATION_TIMEOUT_SECONDS) - - # Code to perform teardown after each test goes here - print("\nGSQL Cleanup complete.") - - def test_ann_add_data1(self, setup_database): - raw_data, columns, create_models, get_embeddings = setup_database - - # Retrieve embeddings using ML_PREDICT. - embeddings = get_embeddings() - print("embeddings", embeddings) - - db = SpannerVectorStore( - instance_id=instance_id, - database_id=google_database, - table_name=table_name_ANN, - id_column="categoryId", - ignore_metadata_columns=[], - embedding_service=embeddings, - metadata_json_column="metadata", - ) - _ = db - - class TestSpannerVectorStorePGSQL: @pytest.fixture(scope="class") def setup_database(self, client): From ed45ddce4d02909337b17d625e40b72197fe500f Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Mon, 3 Feb 2025 19:20:54 +0200 Subject: [PATCH 14/16] Fit with get_rows_by_similarity_search_ann --- src/langchain_google_spanner/vector_store.py | 69 ++++++++++---------- 1 file changed, 36 insertions(+), 33 deletions(-) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 33ee668..398be20 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -1017,26 +1017,34 @@ def similarity_search_with_score_by_vector( Returns: List[Document]: List of documents most similar to the query. """ + if self.__using_ANN: + results, column_order_map = self._get_rows_by_similarity_search_ann( + embedding, + k, + pre_filter, + **kwargs, + ) + else: + results, column_order_map = self._get_rows_by_similarity_search_knn( + embedding, k, pre_filter + ) - results, column_order_map = self._get_rows_by_similarity_search_knn( - embedding, k, pre_filter - ) documents = self._get_documents_from_query_results( list(results), column_order_map ) return documents - def __search_by_ANN( + def _get_rows_by_similarity_search_ann( self, - index_name: str, - num_leaves: int, - k: int, # Defines the limit - embedding: Optional[List[float]] = None, + embedding: List[float], + k: int, pre_filter: Optional[str] = None, + index_name: str = "", + num_leaves: int = 1000, embedding_column_is_nullable: bool = False, ascending: bool = True, return_columns: Optional[List[str]] = None, - ) -> List[Tuple[Document, float]]: + ): sql = SpannerVectorStore._generate_sql_for_ANN( self._table_name, index_name, @@ -1058,9 +1066,7 @@ def __search_by_ANN( column_order_map = { value: index for index, value in enumerate(self._columns_to_insert) } - return self._get_documents_from_query_results( - list(results), column_order_map - ) + return results, column_order_map @staticmethod def _generate_sql_for_ANN( @@ -1286,24 +1292,12 @@ def similarity_search_by_vector( Returns: List[Document]: List of documents most similar to the query. """ - documents: List[Tuple[Document, float]] = [] - if self.__using_ANN: - documents = self.__search_by_ANN( - index_name=kwargs.get("index_name", None), - num_leaves=kwargs.get("num_leaves", 1000), - k=k, - embedding=embedding, - embedding_column_is_nullable=kwargs.get( - "embedding_column_is_nullable", False - ), - pre_filter=pre_filter, - ascending=kwargs.get("ascending", True), - return_columns=kwargs.get("return_columns", []), - ) - else: - documents = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, pre_filter=pre_filter - ) + documents = self.similarity_search_with_score_by_vector( + embedding=embedding, + k=k, + pre_filter=pre_filter, + **kwargs, + ) return [doc for doc, _ in documents] @@ -1314,6 +1308,7 @@ def max_marginal_relevance_search_with_score_by_vector( fetch_k: int = 20, lambda_mult: float = 0.5, pre_filter: Optional[str] = None, + **kwargs, ) -> List[Tuple[Document, float]]: """Return docs and their similarity scores selected using the maximal marginal relevance. @@ -1334,9 +1329,17 @@ def max_marginal_relevance_search_with_score_by_vector( List of Documents and similarity scores selected by maximal marginal relevance and score for each. """ - results, column_order_map = self._get_rows_by_similarity_search_knn( - embedding, fetch_k, pre_filter - ) + if self.__using_ANN: + results, column_order_map = self._get_rows_by_similarity_search_ann( + embedding, + fetch_k, + pre_filter, + **kwargs, + ) + else: + results, column_order_map = self._get_rows_by_similarity_search_knn( + embedding, fetch_k, pre_filter + ) embeddings = [ result[column_order_map[self._embedding_column]] for result in results From e4eac6d9f9f1b7ad0269ab2010a1361aba36739a Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Tue, 4 Feb 2025 11:03:01 +0200 Subject: [PATCH 15/16] Updates from nox --- README.rst | 2 +- src/langchain_google_spanner/vector_store.py | 24 +++----------------- 2 files changed, 4 insertions(+), 22 deletions(-) diff --git a/README.rst b/README.rst index cb6f9b0..ee95ba2 100644 --- a/README.rst +++ b/README.rst @@ -255,7 +255,7 @@ This is not an officially supported Google product. Limitations ----------- +----------- * Approximate Nearest Neighbors (ANN) strategies are only supported for the GoogleSQL dialect * ANN's `ALTER VECTOR INDEX` is not yet supported by [Google Cloud Spanner](https://cloud.google.com/spanner/docs/find-approximate-nearest-neighbors#limitations) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 398be20..5a5eb61 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -461,6 +461,9 @@ def _generate_sql( ) ] + if not secondary_indexes: + secondary_indexes = [] + ann_indices = list( filter(lambda index: type(index) is VectorSearchIndex, secondary_indexes) ) @@ -1080,29 +1083,8 @@ def _generate_sql_for_ANN( pre_filter: Optional[str] = None, embedding_column_is_nullable: bool = False, ascending: bool = True, - post_filter: Optional[str] = None, # TODO(@odeke-em): Not yet supported return_columns: Optional[List[str]] = None, ) -> str: - """ - Sample query: - SELECT DocId - FROM Documents@{FORCE_INDEX=DocEmbeddingIndex} - ORDER BY APPROX_EUCLIDEAN_DISTANCE( - ARRAY[1.0, 2.0, 3.0], DocEmbedding, - options => JSON '{"num_leaves_to_search": 10}') - LIMIT 100 - - OR - - SELECT DocId - FROM Documents@{FORCE_INDEX=DocEmbeddingIndex} - WHERE NullableDocEmbedding IS NOT NULL - ORDER BY APPROX_EUCLIDEAN_DISTANCE( - ARRAY[1.0, 2.0, 3.0], NullableDocEmbedding, - options => JSON '{"num_leaves_to_search": 10}') - LIMIT 100 - """ - if not embedding_column_name: raise Exception("embedding_column_name must be set") From 3b26ec9ad3b86f177519fd3244cbd322c7d1e770 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Tue, 4 Feb 2025 11:35:10 +0200 Subject: [PATCH 16/16] Fix PostGreSQL --- src/langchain_google_spanner/vector_store.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 5a5eb61..a81cff8 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -596,14 +596,14 @@ def _generate_secondary_indices_ddl_KNN( def _generate_secondary_indices_ddl_ANN( table_name, dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, secondary_indexes=[] ): + if not secondary_indexes: + return [] + if dialect != DatabaseDialect.GOOGLE_STANDARD_SQL: raise Exception( f"ANN is only supported for the GoogleSQL dialect not {dialect}. File an issue on Github?" ) - if not secondary_indexes: - return [] - secondary_index_ddl_statements = [] for secondary_index in secondary_indexes: