Skip to content

Commit

Permalink
Incorporate pre_filter and post_filter plus update tests
Browse files — browse the repository at this point in the history
  • Loading branch information
odeke-em committed Feb 1, 2025
1 parent 8be267d commit b8948a3
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 79 deletions.
122 changes: 62 additions & 60 deletions samples/search_ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,69 +66,9 @@ def use_case():
client = Client(project=project_id)
database = client.instance(instance_id).database(google_database)

# DDL for the two remote Vertex AI models this sample depends on:
#  - EmbeddingsModel: text-embedding-004, maps text to a FLOAT32 vector.
#  - LLMModel: gemini-pro, generates text from a prompt (default_batch_size=1).
model_ddl_statements = [
    f"""
CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT(
content STRING(MAX),
) OUTPUT(
embeddings STRUCT<statistics STRUCT<truncated BOOL, token_count FLOAT32>, values ARRAY<FLOAT32>>,
) REMOTE OPTIONS (
endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/text-embedding-004'
)
""",
    f"""
CREATE MODEL IF NOT EXISTS LLMModel INPUT(
prompt STRING(MAX),
) OUTPUT(
content STRING(MAX),
) REMOTE OPTIONS (
endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/gemini-pro',
default_batch_size = 1
)
""",
]
# Apply the schema change and block until it completes (bounded by the timeout).
operation = database.update_ddl(model_ddl_statements)
operation.result(OPERATION_TIMEOUT_SECONDS)

def clear_and_insert_data(tx):
    """Reset the products table, backfill embeddings, and embed a sample query.

    Runs inside one read-write transaction:
      1. delete every existing row,
      2. insert the sample rows from ``raw_data``,
      3. populate productDescriptionEmbedding for category 1 via ML.PREDICT,
      4. compute and return the embedding for a sample shopping query.
    """
    tx.execute_update("DELETE FROM products WHERE 1=1")
    tx.insert(
        "products",
        columns=[
            "categoryId",
            "productId",
            "productName",
            "productDescription",
            "createTime",
            "inventoryCount",
            "priceInCents",
        ],
        values=raw_data,
    )

    # Backfill embeddings for category-1 products using the remote model.
    tx.execute_update(
        """UPDATE products p1
SET productDescriptionEmbedding =
(SELECT embeddings.values from ML.PREDICT(MODEL EmbeddingsModel,
(SELECT productDescription as content FROM products p2 where p2.productId=p1.productId)))
WHERE categoryId=1""",
    )

    # Embed the sample customer query; rows nest the vector, so flatten it.
    embeddings = []
    rows = tx.execute_sql(
        """SELECT embeddings.values
FROM ML.PREDICT(
MODEL EmbeddingsModel,
(SELECT "I'd like to buy a starter bike for my 3 year old child" as content)
)"""
    )

    for row in rows:
        for nesting in row:
            embeddings.extend(nesting)

    return embeddings

embeddings = database.run_in_transaction(clear_and_insert_data)
# Truncate if the model returns more dimensions than the table's vector size.
if len(embeddings) > model_vector_size:
    embeddings = embeddings[:model_vector_size]
Expand Down Expand Up @@ -271,6 +211,68 @@ def PENDING_COMMIT_TIMESTAMP():
"priceInCents",
]

# DDL registering the remote Vertex AI models used by this sample:
#  - EmbeddingsModel: text-embedding-004, maps text to a FLOAT32 vector.
#  - LLMModel: gemini-pro, generates text from a prompt (default_batch_size=1).
# Applied elsewhere via database.update_ddl(model_ddl_statements).
model_ddl_statements = [
    f"""
CREATE MODEL IF NOT EXISTS EmbeddingsModel INPUT(
content STRING(MAX),
) OUTPUT(
embeddings STRUCT<statistics STRUCT<truncated BOOL, token_count FLOAT32>, values ARRAY<FLOAT32>>,
) REMOTE OPTIONS (
endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/text-embedding-004'
)
""",
    f"""
CREATE MODEL IF NOT EXISTS LLMModel INPUT(
prompt STRING(MAX),
) OUTPUT(
content STRING(MAX),
) REMOTE OPTIONS (
endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/us-central1/publishers/google/models/gemini-pro',
default_batch_size = 1
)
""",
]


def clear_and_insert_data(tx):
    """Repopulate the products table and compute a sample query embedding.

    Executed as a single read-write transaction: wipes the table, re-inserts
    the sample rows from ``raw_data``, backfills description embeddings for
    category 1 through the remote EmbeddingsModel, then returns the embedding
    vector for a sample customer query.
    """
    # Start from a clean slate, then reload the sample inventory.
    tx.execute_update("DELETE FROM products WHERE 1=1")
    product_columns = [
        "categoryId",
        "productId",
        "productName",
        "productDescription",
        "createTime",
        "inventoryCount",
        "priceInCents",
    ]
    tx.insert("products", columns=product_columns, values=raw_data)

    # Backfill embeddings for every category-1 product via the remote model.
    tx.execute_update(
        """UPDATE products p1
SET productDescriptionEmbedding =
(SELECT embeddings.values from ML.PREDICT(MODEL EmbeddingsModel,
(SELECT productDescription as content FROM products p2 where p2.productId=p1.productId)))
WHERE categoryId=1""",
    )

    # Embed the sample shopping query. Each result row nests the vector one
    # level deep, so flatten two levels down to a plain list of floats.
    result_rows = tx.execute_sql(
        """SELECT embeddings.values
FROM ML.PREDICT(
MODEL EmbeddingsModel,
(SELECT "I'd like to buy a starter bike for my 3 year old child" as content)
)"""
    )
    return [value for row in result_rows for nested in row for value in nested]


# Script entry point: run the sample end to end when invoked directly.
if __name__ == "__main__":
    main()
43 changes: 25 additions & 18 deletions src/langchain_google_spanner/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,10 +1043,10 @@ def __search_by_ANN(
self,
index_name: str,
num_leaves: int,
limit: int,
k: int, # Defines the limit
embedding: List[float] = None,
is_embedding_nullable: bool = False,
where_condition: str = None,
pre_filter: str = None,
embedding_column_is_nullable: bool = False,
ascending: bool = True,
return_columns: List[str] = None,
Expand All @@ -1057,10 +1057,10 @@ def __search_by_ANN(
self._embedding_column,
embedding or self._embedding_service,
num_leaves,
limit,
k,
self._query_parameters.distance_strategy,
is_embedding_nullable,
where_condition,
pre_filter=pre_filter,
embedding_column_is_nullable=embedding_column_is_nullable,
ascending=ascending,
return_columns=return_columns,
Expand All @@ -1085,12 +1085,13 @@ def _query_ANN(
embedding_column_name: str,
embedding: List[float],
num_leaves: int,
limit: int,
k: int,
strategy: DistanceStrategy = DistanceStrategy.COSINE,
is_embedding_nullable: bool = False,
where_condition: str = None,
pre_filter: str = None,
embedding_column_is_nullable: bool = False,
ascending: bool = True,
post_filter: str = None, # TODO(@odeke-em): Not yet supported
return_columns: List[str] = None,
):
"""
Expand Down Expand Up @@ -1132,27 +1133,24 @@ def _query_ANN(
+ "@{FORCE_INDEX="
+ f"{index_name}"
+ (
"}\n"
("}\nWHERE " + ("1=1" if not pre_filter else f"{pre_filter}") + "\n")
if (not embedding_column_is_nullable)
else "}\nWHERE " + f"{embedding_column_name} IS NOT NULL\n"
else "}\nWHERE "
+ f"{embedding_column_name} IS NOT NULL"
+ ("" if not pre_filter else f" AND {pre_filter}")
+ "\n"
)
+ f"ORDER BY {ann_strategy_name}(\n"
+ f" ARRAY<FLOAT32>{embedding}, {embedding_column_name}, options => JSON '"
+ '{"num_leaves_to_search": %s}\')%s\n'
% (num_leaves, "" if ascending else " DESC")
)

if where_condition:
sql += " WHERE " + where_condition + "\n"

if limit:
sql += f"LIMIT {limit}"
if k:
sql += f"LIMIT {k}"

return sql.strip()

def _get_rows_by_similarity_search_ann():
    # Placeholder for ANN-based row retrieval; intentionally does nothing yet.
    pass

def _get_rows_by_similarity_search_knn(
self,
embedding: List[float],
Expand Down Expand Up @@ -1205,6 +1203,15 @@ def _get_rows_by_similarity_search_knn(

return list(results), column_order_map

def _get_rows_by_similarity_search_ann(
    self,
    embedding: List[float],
    k: int,
    pre_filter: Optional[str] = None,
    **kwargs: Any,
):
    """Retrieve the ``k`` rows nearest to ``embedding`` via the ANN index.

    Args:
        embedding: Query vector to search with.
        k: Maximum number of rows to return.
        pre_filter: Optional SQL predicate applied before the ANN search.
        **kwargs: Additional ANN options (e.g. index name, num_leaves).

    Raises:
        NotImplementedError: Always — ANN row retrieval is not implemented yet.
    """
    # NotImplementedError is the idiomatic exception for unfinished methods;
    # it subclasses RuntimeError, so existing handlers keep working.
    raise NotImplementedError(
        "_get_rows_by_similarity_search_ann is not yet implemented"
    )

def _get_documents_from_query_results(
self, results: List[List], column_order_map: Dict[str, int]
) -> List[Tuple[Document, float]]:
Expand Down Expand Up @@ -1306,10 +1313,10 @@ def similarity_search_by_vector(
documents = self.__search_by_ANN(
index_name=kwargs.get("index_name", None),
num_leaves=kwargs.get("num_leaves", 1000),
limit=k,
k=k,
embedding=embedding,
is_embedding_nullable=kwargs.get("is_embedding_nullable", False),
where_condition=kwargs.get("where_condition", ""),
pre_filter=kwargs.get("pre_filter", ""),
ascending=kwargs.get("ascending", True),
return_columns=kwargs.get("return_columns", []),
)
Expand Down
79 changes: 78 additions & 1 deletion tests/unit/test_vectore_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def test_query_ANN(self):

want = (
"SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
+ "WHERE 1=1\n"
+ "ORDER BY APPROX_COSINE_DISTANCE(\n"
+ " ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
+ "'{\"num_leaves_to_search\": 10}')\n"
Expand Down Expand Up @@ -286,6 +287,7 @@ def test_query_ANN_order_DESC(self):

want = (
"SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
+ "WHERE 1=1\n"
+ "ORDER BY APPROX_COSINE_DISTANCE(\n"
+ " ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
+ "'{\"num_leaves_to_search\": 10}') DESC\n"
Expand All @@ -294,7 +296,7 @@ def test_query_ANN_order_DESC(self):

assert got == want

def test_query_ANN_unspecified_limit(self):
def test_query_ANN_specified_limit(self):
got = SpannerVectorStore._query_ANN(
"Documents",
"DocEmbeddingIndex",
Expand All @@ -308,6 +310,81 @@ def test_query_ANN_unspecified_limit(self):

want = (
"SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
+ "WHERE 1=1\n"
+ "ORDER BY APPROX_COSINE_DISTANCE(\n"
+ " ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
+ "'{\"num_leaves_to_search\": 10}')\n"
+ "LIMIT 100"
)

assert got == want

def test_query_ANN_specified_pre_filter(self):
    """A caller-supplied pre_filter becomes the WHERE clause verbatim."""
    sql = SpannerVectorStore._query_ANN(
        "Documents",
        "DocEmbeddingIndex",
        "DocEmbedding",
        [1.0, 2.0, 3.0],
        10,
        100,
        DistanceStrategy.COSINE,
        return_columns=["DocId"],
        pre_filter="categoryId!=20",
    )

    expected = "\n".join(
        [
            "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}",
            "WHERE categoryId!=20",
            "ORDER BY APPROX_COSINE_DISTANCE(",
            " ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
            "'{\"num_leaves_to_search\": 10}')",
            "LIMIT 100",
        ]
    )

    assert sql == expected

def test_query_ANN_specified_pre_filter_with_nullable_column(self):
    """pre_filter is ANDed onto the NOT NULL guard for a nullable column."""
    sql = SpannerVectorStore._query_ANN(
        "Documents",
        "DocEmbeddingIndex",
        "DocEmbedding",
        [1.0, 2.0, 3.0],
        10,
        100,
        DistanceStrategy.COSINE,
        return_columns=["DocId"],
        pre_filter="categoryId!=9",
        embedding_column_is_nullable=True,
    )

    expected = "\n".join(
        [
            "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}",
            "WHERE DocEmbedding IS NOT NULL AND categoryId!=9",
            "ORDER BY APPROX_COSINE_DISTANCE(",
            " ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
            "'{\"num_leaves_to_search\": 10}')",
            "LIMIT 100",
        ]
    )

    assert sql == expected

def test_query_ANN_no_pre_filter_non_nullable(self):
got = SpannerVectorStore._query_ANN(
"Documents",
"DocEmbeddingIndex",
"DocEmbedding",
[1.0, 2.0, 3.0],
10,
100,
DistanceStrategy.COSINE,
embedding_column_is_nullable=True,
return_columns=["DocId"],
pre_filter="DocId!=2",
)

want = (
"SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
+ "WHERE DocEmbedding IS NOT NULL AND DocId!=2\n"
+ "ORDER BY APPROX_COSINE_DISTANCE(\n"
+ " ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON "
+ "'{\"num_leaves_to_search\": 10}')\n"
Expand Down

0 comments on commit b8948a3

Please sign in to comment.