Skip to content

Commit

Permalink
Update based off use case carving
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Jan 29, 2025
1 parent 0865ef8 commit a791690
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 45 deletions.
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
google-cloud-spanner==3.49.1
langchain-core==0.3.9
google-cloud-spanner==3.51.0
langchain-core==0.3.15
langchain-community==0.3.1
pydantic==2.9.1
180 changes: 180 additions & 0 deletions samples/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
import datetime
import os
import time
import uuid

from google.cloud.spanner import Client # type: ignore
from langchain_community.document_loaders import HNLoader
from langchain_community.embeddings import FakeEmbeddings

from langchain_google_spanner.vector_store import ( # type: ignore
DistanceStrategy,
QueryParameters,
SpannerVectorStore,
TableColumn,
VectorSearchIndex,
)

project_id = 'quip-441723'
instance_id = 'contracting'
google_database = 'ann'
zone = os.environ.get("GOOGLE_DATABASE_ZONE", "us-west2")
table_name_ANN = "products"
OPERATION_TIMEOUT_SECONDS = 240

def use_case():
# Initialize the vector store table if necessary.
distance_strategy = DistanceStrategy.COSINE
SpannerVectorStore.init_vector_store_table(
instance_id=instance_id,
database_id=google_database,
table_name=table_name_ANN,
id_column=TableColumn("productId", type="INT64"),
vector_size=758,
embedding_column=TableColumn(
name="productDescriptionEmbedding",
type="ARRAY<FLOAT32>",
is_null=True,
),
metadata_columns=[
TableColumn(name="categoryId", type="INT64", is_null=False),
TableColumn(name="productName", type="STRING(MAX)", is_null=False),
TableColumn(
name="productDescription", type="STRING(MAX)", is_null=False
),
TableColumn(name="inventoryCount", type="INT64", is_null=False),
TableColumn(name="priceInCents", type="INT64", is_null=True),
TableColumn(name="createTime", type="TIMESTAMP", is_null=False),
],
secondary_indexes=[
VectorSearchIndex(
index_name="ProductDescriptionEmbeddingIndex",
columns=["productDescriptionEmbedding"],
nullable_column=True,
num_branches=1000,
tree_depth=3,
index_type=distance_strategy,
num_leaves=100000,
),
],
)

# Create the models if necessary.
client = Client(project=project_id)
database = client.instance(instance_id).database(google_database)

model_ddl_statements = [
f"""
CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT(
content STRING(MAX),
) OUTPUT(
embeddings STRUCT<statistics STRUCT<truncated BOOL, token_count FLOAT32>, values ARRAY<FLOAT32>>,
) REMOTE OPTIONS (
endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/text-embedding-004'
)
""",
f"""
CREATE MODEL IF NOT EXISTS LLMModel INPUT(
prompt STRING(MAX),
) OUTPUT(
content STRING(MAX),
) REMOTE OPTIONS (
endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/gemini-pro',
default_batch_size = 1
)
""",
]
operation = database.update_ddl(model_ddl_statements)
operation.result(OPERATION_TIMEOUT_SECONDS)

def clear_and_insert_data(tx):
tx.execute_update("DELETE FROM products WHERE 1=1")
tx.insert(
'products',
columns=[
'categoryId', 'productId', 'productName',
'productDescription',
'createTime', 'inventoryCount', 'priceInCents',
],
values=raw_data,
)

tx.execute_update(
"""UPDATE products p1
SET productDescriptionEmbedding =
(SELECT embeddings.values from ML.PREDICT(MODEL EmbeddingsModel,
(SELECT productDescription as content FROM products p2 where p2.productId=p1.productId)))
WHERE categoryId=1""",
)

embeddings = []
rows = tx.execute_sql(
"""SELECT embeddings.values
FROM ML.PREDICT(
MODEL EmbeddingsModel,
(SELECT "I'd like to buy a starter bike for my 3 year old child" as content)
)""")

for row in rows:
for nesting in row:
embeddings.extend(nesting)

return embeddings

embeddings = database.run_in_transaction(clear_and_insert_data)

vec_store = SpannerVectorStore(
instance_id=instance_id,
database_id=google_database,
table_name=table_name_ANN,
id_column="categoryId",
embedding_service=embeddings,
embedding_column="productDescriptionEmbedding",
skip_not_nullable_columns=True,
)
vec_store.search_by_ANN(
'ProductDescriptionEmbeddingIndex',
1000,
k=20,
embedding_column_is_nullable=True,
return_columns=['productName', 'productDescription', 'inventoryCount'],
)

def main():
use_case()


def PENDING_COMMIT_TIMESTAMP():
return (datetime.datetime.utcnow() + datetime.timedelta(days=1)).isoformat() + "Z"
return 'PENDING_COMMIT_TIMESTAMP()'

raw_data = [
(1, 1, "Cymbal Helios Helmet", "Safety meets style with the Cymbal children's bike helmet. Its lightweight design, superior ventilation, and adjustable fit ensure comfort and protection on every ride. Stay bright and keep your child safe under the sun with Cymbal Helios!", PENDING_COMMIT_TIMESTAMP(), 100, 10999),
(1, 2, "Cymbal Sprout", "Let their cycling journey begin with the Cymbal Sprout, the ideal balance bike for beginning riders ages 2-4 years. Its lightweight frame, low seat height, and puncture-proof tires promote stability and confidence as little ones learn to balance and steer. Watch them sprout into cycling enthusiasts with Cymbal Sprout!", PENDING_COMMIT_TIMESTAMP(), 10, 13999),
(1, 3, "Cymbal Spark Jr.", "Light, vibrant, and ready for adventure, the Spark Jr. is the perfect first bike for young riders (ages 5-8). Its sturdy frame, easy-to-use brakes, and puncture-resistant tires inspire confidence and endless playtime. Let the spark of cycling ignite with Cymbal!", PENDING_COMMIT_TIMESTAMP(), 34, 13900),
(1, 4, "Cymbal Summit", "Conquering trails is a breeze with the Summit mountain bike. Its lightweight aluminum frame, responsive suspension, and powerful disc brakes provide exceptional control and comfort for experienced bikers navigating rocky climbs or shredding downhill. Reach new heights with Cymbal Summit!", PENDING_COMMIT_TIMESTAMP(), 0, 79999),
(1, 5, "Cymbal Breeze", "Cruise in style and embrace effortless pedaling with the Breeze electric bike. Its whisper-quiet motor and long-lasting battery let you conquer hills and distances with ease. Enjoy scenic rides, commutes, or errands with a boost of confidence from Cymbal Breeze!", PENDING_COMMIT_TIMESTAMP(), 72, 129999),
(1, 6, "Cymbal Trailblazer Backpack", "Carry all your essentials in style with the Trailblazer backpack. Its water-resistant material, multiple compartments, and comfortable straps keep your gear organized and accessible, allowing you to focus on the adventure. Blaze new trails with Cymbal Trailblazer!", PENDING_COMMIT_TIMESTAMP(), 24, 7999),
(1, 7, "Cymbal Phoenix Lights", "See and be seen with the Phoenix bike lights. Powerful LEDs and multiple light modes ensure superior visibility, enhancing your safety and enjoyment during day or night rides. Light up your journey with Cymbal Phoenix!", PENDING_COMMIT_TIMESTAMP(), 87, 3999),
(1, 8, "Cymbal Windstar Pump", "Flat tires are no match for the Windstar pump. Its compact design, lightweight construction, and high-pressure capacity make inflating tires quick and effortless. Get back on the road in no time with Cymbal Windstar!", PENDING_COMMIT_TIMESTAMP(), 36, 24999),
(1, 9,"Cymbal Odyssey Multi-Tool","Be prepared for anything with the Odyssey multi-tool. This handy gadget features essential tools like screwdrivers, hex wrenches, and tire levers, keeping you ready for minor repairs and adjustments on the go. Conquer your journey with Cymbal Odyssey!", PENDING_COMMIT_TIMESTAMP(), 52, 999),
(1, 10,"Cymbal Nomad Water Bottle","Stay hydrated on every ride with the Nomad water bottle. Its sleek design, BPA-free construction, and secure lock lid make it the perfect companion for staying refreshed and motivated throughout your adventures. Hydrate and explore with Cymbal Nomad!", PENDING_COMMIT_TIMESTAMP(), 42, 1299),
]

columns = [
"categoryId",
"productId",
"productName",
"productDescription",
"createTime",
"inventoryCount",
"priceInCents",
]


if __name__ == '__main__':
main()

73 changes: 31 additions & 42 deletions src/langchain_google_spanner/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def __init__(
class SpannerVectorStore(VectorStore):
GSQL_TYPES = {
CONTENT_COLUMN_NAME: ["STRING"],
EMBEDDING_COLUMN_NAME: ["ARRAY<FLOAT64>"],
EMBEDDING_COLUMN_NAME: ["ARRAY<FLOAT64>", "ARRAY<FLOAT32>"],
"metadata_json_column": ["JSON"],
}

Expand Down Expand Up @@ -405,7 +405,6 @@ def init_vector_store_table(
vector_size,
)

print("ddl", "\n".join(ddl))
operation = database.update_ddl(ddl)

print("Waiting for operation to complete...")
Expand Down Expand Up @@ -466,7 +465,7 @@ def _generate_sql(

ann_indices = list(
filter(
lambda index: isinstance(index, VectorSearchIndex), secondary_indexes
lambda index: type(index) is VectorSearchIndex, secondary_indexes
)
)
ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_ANN(
Expand All @@ -475,9 +474,9 @@ def _generate_sql(
secondary_indexes=list(ann_indices),
)

knn_indices = filter(
lambda index: isinstance(index, SecondaryIndex), secondary_indexes
)
knn_indices = list(filter(
lambda index: type(index) is SecondaryIndex, secondary_indexes
))
ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_KNN(
table_name,
embedding_column,
Expand Down Expand Up @@ -544,13 +543,8 @@ def _generate_create_table_sql(
# Append column name and data type
column_sql = f" {column_config.name} {column_config.type}"

vector_len = vector_length

if column_config.vector_length and column_config.vector_length >= 1:
vector_len = column_config.vector_length

if vector_len and vector_len > 0:
column_sql += f"(vector_length=>{vector_len})"
column_sql += f"(vector_length=>{column_config.vector_length})"

# Add nullable constraint if specified
if not column_config.is_null:
Expand All @@ -571,7 +565,6 @@ def _generate_create_table_sql(
+ ")"
)

# print(create_table_statement)
return create_table_statement

@staticmethod
Expand Down Expand Up @@ -616,7 +609,7 @@ def _generate_secondary_indices_ddl_ANN(

for secondary_index in secondary_indexes:
column_name = secondary_index.columns[0]
statement = f"CREATE VECTOR INDEX {secondary_index.index_name}\n\tON {table_name}({column_name})"
statement = f"CREATE VECTOR INDEX IF NOT EXISTS {secondary_index.index_name}\n\tON {table_name}({column_name})"
if getattr(secondary_index, "nullable_column", False):
statement += f"\n\tWHERE {column_name} IS NOT NULL"
options_segments = [f"distance_type='{secondary_index.index_type}'"]
Expand Down Expand Up @@ -651,6 +644,7 @@ def __init__(
ignore_metadata_columns: Optional[List[str]] = None,
metadata_json_column: Optional[str] = None,
query_parameters: QueryParameters = QueryParameters(),
skip_not_nullable_columns = False,
):
"""
Initialize the SpannerVectorStore.
Expand Down Expand Up @@ -681,6 +675,7 @@ def __init__(
self._query_parameters = query_parameters
self._embedding_service = embedding_service
self.__strategy = None
self._skip_not_nullable_columns = skip_not_nullable_columns

if metadata_columns is not None and ignore_metadata_columns is not None:
raise Exception(
Expand Down Expand Up @@ -795,9 +790,9 @@ def _validate_table_schema(self, column_type_map, types, default_columns):
for substring in types[EMBEDDING_COLUMN_NAME]
):
raise Exception(
"Embedding Column is not of correct type. Expected one of: {} but found: {}",
"Embedding Column is not of correct type. Expected one of: {} but found: {}".format(
types[EMBEDDING_COLUMN_NAME],
embedding_column_type,
embedding_column_type)
)

if self._metadata_json_column is not None:
Expand All @@ -813,13 +808,14 @@ def _validate_table_schema(self, column_type_map, types, default_columns):
embedding_column_type,
)

for column_name, column_config in column_type_map.items():
if column_name not in self._columns_to_insert:
if "NO" == column_config[2].upper():
raise Exception(
"Found not nullable constraint on column: {}.",
column_name,
)
if not self._skip_not_nullable_columns:
for column_name, column_config in column_type_map.items():
if column_name not in self._columns_to_insert:
if "NO" == column_config[2].upper():
raise Exception(
"Found not nullable constraint on column: {}.",
column_name,
)

def _select_relevance_score_fn(self) -> Callable[[float], float]:
if self._query_parameters.distance_strategy == DistanceStrategy.COSINE:
Expand Down Expand Up @@ -1050,43 +1046,36 @@ def set_strategy(strategy: DistanceStrategy):

def search_by_ANN(
self,
table_name: str,
column_name: str,
index_name: str,
embedding_column_name: str,
embedding: List[float],
num_leaves: int,
embedding: List[float] = None,
k: int = None,
is_embedding_nullable: bool = False,
where_condition: str = None,
embedding_column_is_nullable: bool = False,
ascending: bool = True,
return_columns: List[str] = None,
strategy: DistanceStrategy = DistanceStrategy.COSINE,
) -> List[Any]:
# Firstly only the GoogleSQL dialect is supported.
if self._dialect_semantics != DatabaseDialect.GOOGLE_STANDARD_SQL:
raise Exception(
f"ANN is only supported for the GoogleSQL dialect not {dialect}. File an issue on Github?"
)

sql = SpannerVectorStore._query_ANN(
table_name,
column_name,
self._table_name,
index_name,
embedding_column_name,
embedding,
self._embedding_column,
embedding or self._embedding_service,
num_leaves,
self._strategy,
strategy,
k,
is_embedding_nullable,
where_condition,
embedding_column_is_nullable=embedding_column_is_nullable,
ascending=ascending,
return_columns=return_columns,
)
staleness = self._query_parameters.staleness
with self._database.snapshot(
**staleness if staleness is not None else {}
) as snapshot:
results = snapshot.execute_sql(
sql=sql_query,
)
results = snapshot.execute_sql(sql=sql)
return list(results)

@staticmethod
Expand Down Expand Up @@ -1149,7 +1138,7 @@ def _query_ANN(
)
+ f"ORDER BY {ann_strategy_name}(\n"
+ f" ARRAY<FLOAT32>{embedding}, {embedding_column_name}, options => JSON '"
+ '{"num_leaves_to_search": %s})%s\n'
+ '{"num_leaves_to_search": %s}\')%s\n'
% (num_leaves, "" if ascending else " DESC")
)

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_spanner_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def setup_database(self, client):
(SELECT productDescription as content FROM products p2 where p2.productId=p1.productId)
)
)
WHERE categoryId=1;
WHERE categoryId=1
""",
]
database = client.instance(instance_id).database(google_database)
Expand Down

0 comments on commit a791690

Please sign in to comment.