From b4f7f9ddcbb58451a60a938c000739e6e931ed5e Mon Sep 17 00:00:00 2001
From: Emmanuel T Odeke
Date: Tue, 28 Jan 2025 15:34:15 +0200
Subject: [PATCH] Infer vector_size/length from arguments

---
 src/langchain_google_spanner/vector_store.py  |  19 ++-
 .../integration/test_spanner_vector_store.py  | 128 ------------------
 2 files changed, 17 insertions(+), 130 deletions(-)

diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py
index 1475b8d..587c044 100644
--- a/src/langchain_google_spanner/vector_store.py
+++ b/src/langchain_google_spanner/vector_store.py
@@ -382,7 +382,7 @@ def init_vector_store_table(
         - content_column (str): The name of the content column. Defaults to CONTENT_COLUMN_NAME.
         - embedding_column (str): The name of the embedding column. Defaults to EMBEDDING_COLUMN_NAME.
         - metadata_columns (Optional[List[Tuple]]): List of tuples containing metadata column information. Defaults to None.
-        - vector_size (Optional[int]): The size of the vector for KNN. Defaults to None.
+        - vector_size (Optional[int]): The size of the vector for KNN or ANN. Defaults to None. It is assumed that exactly one column holds the vector.
         """

         client = client_with_user_agent(client, USER_AGENT_VECTOR_STORE)
@@ -407,6 +407,7 @@ def init_vector_store_table(
             metadata_columns,
             primary_key,
             secondary_indexes,
+            vector_size,
         )

         print("ddl", "\n".join(ddl))
@@ -427,6 +428,7 @@ def _generate_sql(
         column_configs,
         primary_key,
         secondary_indexes: Optional[List[SecondaryIndex | VectorSearchIndex]] = None,
+        vector_size: Optional[int] = None,
     ):
         """
         Generate SQL for creating the vector store table.
@@ -438,6 +440,7 @@ def _generate_sql(
         - content_column: The name of the content column.
         - embedding_column: The name of the embedding column.
         - column_names: List of tuples containing metadata column information.
+        - vector_size: The default vector length. Following the usual LangChain usage pattern, it is assumed that exactly one column serves as the embedding.

         Returns:
         - str: The generated SQL.
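Taken together, the hunks above thread a single vector_size argument from init_vector_store_table down into SQL generation, so the caller states the embedding dimensionality once. A minimal usage sketch of the intended call, assuming the classmethod keeps the keyword arguments named in its docstring; the instance, database, table names and the 768-dimension value are illustrative placeholders, not values from this patch:

    from langchain_google_spanner import SpannerVectorStore

    # Hypothetical call: vector_size becomes the default vector_length for the
    # embedding column's DDL, so the caller does not need to construct a
    # TableColumn with an explicit vector_length of its own.
    SpannerVectorStore.init_vector_store_table(
        instance_id="my-instance",      # placeholder
        database_id="my-database",      # placeholder
        table_name="my_vector_table",   # placeholder
        id_column="row_id",
        content_column="content",
        embedding_column="embedding",
        vector_size=768,                # e.g. the embedding model's dimensionality
    )

With the changes below, the embedding column should then be rendered with vector_length=>768 while the other columns are left untouched.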
@@ -462,6 +465,7 @@ def _generate_sql(
                 column_configs,
                 primary_key,
                 dialect,
+                vector_length=vector_size,
             )
         ]

@@ -497,6 +501,7 @@ def _generate_create_table_sql(
         column_configs,
         primary_key,
         dialect=DatabaseDialect.GOOGLE_STANDARD_SQL,
+        vector_length=None,
     ):

         create_table_statement = f"CREATE TABLE IF NOT EXISTS {table_name} (\n"
@@ -524,6 +529,11 @@
             embedding_column, "ARRAY", is_null=True
         )

+        if not embedding_column.vector_length:
+            ok_vector_length = vector_length and vector_length > 0
+            if ok_vector_length:
+                embedding_column.vector_length = vector_length
+
         configs = [id_column, content_column, embedding_column]

         if column_configs is not None:
@@ -539,8 +549,13 @@
         # Append column name and data type
         column_sql = f" {column_config.name} {column_config.type}"

+        vector_len = vector_length if column_config is embedding_column else None
+
         if column_config.vector_length and column_config.vector_length >= 1:
-            column_sql += f"(vector_length=>{column_config.vector_length})"
+            vector_len = column_config.vector_length
+
+        if vector_len and vector_len > 0:
+            column_sql += f"(vector_length=>{vector_len})"

         # Add nullable constraint if specified
         if not column_config.is_null:
diff --git a/tests/integration/test_spanner_vector_store.py b/tests/integration/test_spanner_vector_store.py
index f00bc62..1bc5e80 100644
--- a/tests/integration/test_spanner_vector_store.py
+++ b/tests/integration/test_spanner_vector_store.py
@@ -619,134 +619,6 @@ def test_ann_add_data1(self, setup_database):
             metadata_json_column="metadata",
         )

-    def test_ann_add_data2(self, setup_database):
-        loader, embeddings = setup_database
-
-        db = SpannerVectorStore(
-            instance_id=instance_id,
-            database_id=google_database,
-            table_name=table_name,
-            id_column="row_id",
-            ignore_metadata_columns=[],
-            embedding_service=embeddings,
-            metadata_json_column="metadata",
-        )
-
-        texts = [
-            "Langchain Test Text 1",
-            "Langchain Test Text 2",
-            "Langchain Test Text 3",
-        ]
-        ids = [str(uuid.uuid4()) for _ in range(len(texts))]
-        ids_row_inserted = db.add_texts(
-            texts=texts,
-            ids=ids,
-            metadatas=[
-                {"title": "Title 1"},
-                {"title": "Title 2"},
-                {"title": "Title 3"},
-            ],
-        )
-        assert ids == ids_row_inserted
-
-    def test_delete_data(self, setup_database):
-        loader, embeddings = setup_database
-
-        db = SpannerVectorStore(
-            instance_id=instance_id,
-            database_id=google_database,
-            table_name=table_name,
-            id_column="row_id",
-            ignore_metadata_columns=[],
-            embedding_service=embeddings,
-            metadata_json_column="metadata",
-        )
-
-        docs = loader.load()
-        deleted = db.delete(documents=[docs[0], docs[1]])
-        assert deleted == True
-
-    def test_search_data1(self, setup_database):
-        # loader, embeddings = setup_database
-        # db = SpannerVectorStore(
-        #     instance_id=instance_id,
-        #     database_id=google_database,
-        #     table_name=table_name,
-        #     id_column="row_id",
-        #     ignore_metadata_columns=[],
-        #     embedding_service=embeddings,
-        #     metadata_json_column="metadata",
-        # )
-        # docs = db.similarity_search(
-        #     "Testing the langchain integration with spanner", k=2
-        # )
-        # assert len(docs) == 2
-        pass
-
-    def test_search_data2(self, setup_database):
-        # TODO: Implement me
-        # loader, embeddings = setup_database
-        # db = SpannerVectorStore(
-        #     instance_id=instance_id,
-        #     database_id=google_database,
-        #     table_name=table_name,
-        #     id_column="row_id",
-        #     ignore_metadata_columns=[],
-        #     embedding_service=embeddings,
-        #     metadata_json_column="metadata",
-        # )
-        # embeds = embeddings.embed_query(
-        #     "Testing the langchain integration with spanner"
spanner" - # ) - # docs = db.similarity_search_by_vector(embeds, k=3, pre_filter="1 = 1") - # assert len(docs) == 3 - pass - - def test_search_data3(self, setup_database): - # TODO: Implement me - # loader, embeddings = setup_database - # db = SpannerVectorStore( - # instance_id=instance_id, - # database_id=google_database, - # table_name=table_name, - # id_column="row_id", - # ignore_metadata_columns=[], - # embedding_service=embeddings, - # metadata_json_column="metadata", - # query_parameters=QueryParameters( - # distance_strategy=DistanceStrategy.COSINE, - # max_staleness=datetime.timedelta(seconds=15), - # ), - # ) - # - # docs = db.similarity_search( - # "Testing the langchain integration with spanner", k=3 - # ) - # - # assert len(docs) == 3 - pass - - def test_search_data4(self, setup_database): - # loader, embeddings = setup_database - # db = SpannerVectorStore( - # instance_id=instance_id, - # database_id=google_database, - # table_name=table_name, - # id_column="row_id", - # ignore_metadata_columns=[], - # embedding_service=embeddings, - # metadata_json_column="metadata", - # query_parameters=QueryParameters( - # distance_strategy=DistanceStrategy.COSINE, - # max_staleness=datetime.timedelta(seconds=15), - # ), - # ) - # docs = db.max_marginal_relevance_search( - # "Testing the langchain integration with spanner", k=3 - # ) - # assert len(docs) == 3 - pass - class TestSpannerVectorStorePGSQL: @pytest.fixture(scope="class")