Skip to content

Commit

Permalink
Ensure vector fits within limits in sample
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Jan 30, 2025
1 parent 3845f3f commit d654428
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 120 deletions.
4 changes: 2 additions & 2 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
nox.options.error_on_missing_interpreters = True


@nox.session(python="3.10")
@nox.session(python=DEFAULT_PYTHON_VERSION)
def docs(session):
"""Build the docs for this library."""

Expand Down Expand Up @@ -76,7 +76,7 @@ def docs(session):
)


@nox.session(python="3.10")
@nox.session(python=DEFAULT_PYTHON_VERSION)
def docfx(session):
"""Build the docfx yaml files for this library."""

Expand Down
237 changes: 122 additions & 115 deletions samples/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,131 +18,138 @@
VectorSearchIndex,
)

project_id = 'quip-441723'
instance_id = 'contracting'
google_database = 'ann'
zone = os.environ.get("GOOGLE_DATABASE_ZONE", "us-west2")
project_id = os.environ['PROJECT_ID']
instance_id = os.environ['INSTANCE_ID']
google_database = os.environ['GOOGLE_DATABASE']
zone = os.environ["GOOGLE_DATABASE_ZONE"]
table_name_ANN = "products"
OPERATION_TIMEOUT_SECONDS = 240

def use_case():
# Initialize the vector store table if necessary.
distance_strategy = DistanceStrategy.COSINE
SpannerVectorStore.init_vector_store_table(
instance_id=instance_id,
database_id=google_database,
table_name=table_name_ANN,
id_column=TableColumn("productId", type="INT64"),
vector_size=758,
embedding_column=TableColumn(
name="productDescriptionEmbedding",
type="ARRAY<FLOAT32>",
is_null=True,
# Initialize the vector store table if necessary.
distance_strategy = DistanceStrategy.COSINE
model_vector_size = 758
SpannerVectorStore.init_vector_store_table(
instance_id=instance_id,
database_id=google_database,
table_name=table_name_ANN,
id_column=TableColumn("productId", type="INT64"),
vector_size=model_vector_size,
embedding_column=TableColumn(
name="productDescriptionEmbedding",
type="ARRAY<FLOAT32>",
is_null=True,
),
metadata_columns=[
TableColumn(name="categoryId", type="INT64", is_null=False),
TableColumn(name="productName", type="STRING(MAX)", is_null=False),
TableColumn(
name="productDescription", type="STRING(MAX)", is_null=False
),
metadata_columns=[
TableColumn(name="categoryId", type="INT64", is_null=False),
TableColumn(name="productName", type="STRING(MAX)", is_null=False),
TableColumn(
name="productDescription", type="STRING(MAX)", is_null=False
),
TableColumn(name="inventoryCount", type="INT64", is_null=False),
TableColumn(name="priceInCents", type="INT64", is_null=True),
TableColumn(name="createTime", type="TIMESTAMP", is_null=False),
],
secondary_indexes=[
VectorSearchIndex(
index_name="ProductDescriptionEmbeddingIndex",
columns=["productDescriptionEmbedding"],
nullable_column=True,
num_branches=1000,
tree_depth=3,
index_type=distance_strategy,
num_leaves=100000,
),
TableColumn(name="inventoryCount", type="INT64", is_null=False),
TableColumn(name="priceInCents", type="INT64", is_null=True),
TableColumn(name="createTime", type="TIMESTAMP", is_null=False),
],
secondary_indexes=[
VectorSearchIndex(
index_name="ProductDescriptionEmbeddingIndex",
columns=["productDescriptionEmbedding"],
nullable_column=True,
num_branches=1000,
tree_depth=3,
index_type=distance_strategy,
num_leaves=100000,
),
],
)

# Create the models if necessary.
client = Client(project=project_id)
database = client.instance(instance_id).database(google_database)

model_ddl_statements = [
f"""
CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT(
content STRING(MAX),
) OUTPUT(
embeddings STRUCT<statistics STRUCT<truncated BOOL, token_count FLOAT32>, values ARRAY<FLOAT32>>,
) REMOTE OPTIONS (
endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/text-embedding-004'
)
""",
f"""
CREATE MODEL IF NOT EXISTS LLMModel INPUT(
prompt STRING(MAX),
) OUTPUT(
content STRING(MAX),
) REMOTE OPTIONS (
endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/gemini-pro',
default_batch_size = 1
)
""",
]
operation = database.update_ddl(model_ddl_statements)
operation.result(OPERATION_TIMEOUT_SECONDS)

def clear_and_insert_data(tx):
tx.execute_update("DELETE FROM products WHERE 1=1")
tx.insert(
'products',
columns=[
'categoryId', 'productId', 'productName',
'productDescription',
'createTime', 'inventoryCount', 'priceInCents',
],
values=raw_data,
)

# Create the models if necessary.
client = Client(project=project_id)
database = client.instance(instance_id).database(google_database)

model_ddl_statements = [
f"""
CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT(
content STRING(MAX),
) OUTPUT(
embeddings STRUCT<statistics STRUCT<truncated BOOL, token_count FLOAT32>, values ARRAY<FLOAT32>>,
) REMOTE OPTIONS (
endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/text-embedding-004'
)
""",
f"""
CREATE MODEL IF NOT EXISTS LLMModel INPUT(
prompt STRING(MAX),
) OUTPUT(
content STRING(MAX),
) REMOTE OPTIONS (
endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/gemini-pro',
default_batch_size = 1
)
""",
]
operation = database.update_ddl(model_ddl_statements)
operation.result(OPERATION_TIMEOUT_SECONDS)

def clear_and_insert_data(tx):
tx.execute_update("DELETE FROM products WHERE 1=1")
tx.insert(
'products',
columns=[
'categoryId', 'productId', 'productName',
'productDescription',
'createTime', 'inventoryCount', 'priceInCents',
],
values=raw_data,
)

tx.execute_update(
"""UPDATE products p1
SET productDescriptionEmbedding =
(SELECT embeddings.values from ML.PREDICT(MODEL EmbeddingsModel,
(SELECT productDescription as content FROM products p2 where p2.productId=p1.productId)))
WHERE categoryId=1""",
)

embeddings = []
rows = tx.execute_sql(
"""SELECT embeddings.values
FROM ML.PREDICT(
MODEL EmbeddingsModel,
(SELECT "I'd like to buy a starter bike for my 3 year old child" as content)
)""")

for row in rows:
for nesting in row:
embeddings.extend(nesting)

return embeddings

embeddings = database.run_in_transaction(clear_and_insert_data)

vec_store = SpannerVectorStore(
instance_id=instance_id,
database_id=google_database,
table_name=table_name_ANN,
id_column="categoryId",
embedding_service=embeddings,
embedding_column="productDescriptionEmbedding",
skip_not_nullable_columns=True,
)
vec_store.search_by_ANN(
'ProductDescriptionEmbeddingIndex',
1000,
k=20,
embedding_column_is_nullable=True,
return_columns=['productName', 'productDescription', 'inventoryCount'],
tx.execute_update(
"""UPDATE products p1
SET productDescriptionEmbedding =
(SELECT embeddings.values from ML.PREDICT(MODEL EmbeddingsModel,
(SELECT productDescription as content FROM products p2 where p2.productId=p1.productId)))
WHERE categoryId=1""",
)

embeddings = []
rows = tx.execute_sql(
"""SELECT embeddings.values
FROM ML.PREDICT(
MODEL EmbeddingsModel,
(SELECT "I'd like to buy a starter bike for my 3 year old child" as content)
)""")

for row in rows:
for nesting in row:
embeddings.extend(nesting)

return embeddings

embeddings = database.run_in_transaction(clear_and_insert_data)
if len(embeddings) > model_vector_size:
embeddings = embeddings[:model_vector_size]

vec_store = SpannerVectorStore(
instance_id=instance_id,
database_id=google_database,
table_name=table_name_ANN,
id_column="categoryId",
embedding_service=embeddings,
embedding_column="productDescriptionEmbedding",
skip_not_nullable_columns=True,
)
results = vec_store.search_by_ANN(
'ProductDescriptionEmbeddingIndex',
1000,
limit=20,
embedding_column_is_nullable=True,
return_columns=['productName', 'productDescription', 'inventoryCount'],
)

print("Search by ANN results")
for res in results:
print(res)

def main():
use_case()

Expand Down
2 changes: 1 addition & 1 deletion src/langchain_google_spanner/graph_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def __clean_element(self, element: dict[str, Any], embedding_column: str) -> Non
del element["properties"][embedding_column]

def __get_distance_function(
self, distance_strategy=DistanceStrategy.EUCLIDEIAN
self, distance_strategy=DistanceStrategy.EUCLIDEAN
) -> str:
"""Gets the vector distance function."""
if distance_strategy == DistanceStrategy.COSINE:
Expand Down
5 changes: 3 additions & 2 deletions src/langchain_google_spanner/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,8 +1048,8 @@ def search_by_ANN(
self,
index_name: str,
num_leaves: int,
limit: int,
embedding: List[float] = None,
k: int = None,
is_embedding_nullable: bool = False,
where_condition: str = None,
embedding_column_is_nullable: bool = False,
Expand All @@ -1064,7 +1064,7 @@ def search_by_ANN(
embedding or self._embedding_service,
num_leaves,
strategy,
k,
limit,
is_embedding_nullable,
where_condition,
embedding_column_is_nullable=embedding_column_is_nullable,
Expand All @@ -1075,6 +1075,7 @@ def search_by_ANN(
with self._database.snapshot(
**staleness if staleness is not None else {}
) as snapshot:
print("search by ANN sql", sql)
results = snapshot.execute_sql(sql=sql)
return list(results)

Expand Down

0 comments on commit d654428

Please sign in to comment.