Skip to content

Commit

Permalink
Infer vector_size/length from arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Jan 28, 2025
1 parent 605acd3 commit b4f7f9d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 130 deletions.
19 changes: 17 additions & 2 deletions src/langchain_google_spanner/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def init_vector_store_table(
- content_column (str): The name of the content column. Defaults to CONTENT_COLUMN_NAME.
- embedding_column (str): The name of the embedding column. Defaults to EMBEDDING_COLUMN_NAME.
- metadata_columns (Optional[List[Tuple]]): List of tuples containing metadata column information. Defaults to None.
- vector_size (Optional[int]): The size of the vector for KNN. Defaults to None.
- vector_size (Optional[int]): The size of the vector for KNN or ANN. Defaults to None. It is assumed that exactly one column holds the vector.
"""

client = client_with_user_agent(client, USER_AGENT_VECTOR_STORE)
Expand All @@ -407,6 +407,7 @@ def init_vector_store_table(
metadata_columns,
primary_key,
secondary_indexes,
vector_size,
)

print("ddl", "\n".join(ddl))
Expand All @@ -427,6 +428,7 @@ def _generate_sql(
column_configs,
primary_key,
secondary_indexes: Optional[List[SecondaryIndex | VectorSearchIndex]] = None,
vector_size: int = None,
):
"""
Generate SQL for creating the vector store table.
Expand All @@ -438,6 +440,7 @@ def _generate_sql(
- content_column: The name of the content column.
- embedding_column: The name of the embedding column.
- column_names: List of tuples containing metadata column information.
- vector_size: The default vector length to use. Based on LangChain usage patterns, it is assumed that exactly one column serves as the embedding.
Returns:
- str: The generated SQL.
Expand All @@ -462,6 +465,7 @@ def _generate_sql(
column_configs,
primary_key,
dialect,
vector_length=vector_size,
)
]

Expand Down Expand Up @@ -497,6 +501,7 @@ def _generate_create_table_sql(
column_configs,
primary_key,
dialect=DatabaseDialect.GOOGLE_STANDARD_SQL,
vector_length=None,
):
create_table_statement = f"CREATE TABLE IF NOT EXISTS {table_name} (\n"

Expand Down Expand Up @@ -524,6 +529,11 @@ def _generate_create_table_sql(
embedding_column, "ARRAY<FLOAT64>", is_null=True
)

if not embedding_column.vector_length:
ok_vector_length = vector_length and vector_length > 0
if ok_vector_length:
embedding_column.vector_length = vector_length

configs = [id_column, content_column, embedding_column]

if column_configs is not None:
Expand All @@ -539,8 +549,13 @@ def _generate_create_table_sql(
# Append column name and data type
column_sql = f" {column_config.name} {column_config.type}"

vector_len = vector_length

if column_config.vector_length and column_config.vector_length >= 1:
column_sql += f"(vector_length=>{column_config.vector_length})"
vector_len = column_config.vector_length

if vector_len and vector_len > 0:
column_sql += f"(vector_length=>{vector_len})"

# Add nullable constraint if specified
if not column_config.is_null:
Expand Down
128 changes: 0 additions & 128 deletions tests/integration/test_spanner_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,134 +619,6 @@ def test_ann_add_data1(self, setup_database):
metadata_json_column="metadata",
)

def test_ann_add_data2(self, setup_database):
    """Insert three texts with explicit ids and metadata, and verify that
    ``add_texts`` echoes back exactly the ids that were supplied."""
    loader, embeddings = setup_database

    store = SpannerVectorStore(
        instance_id=instance_id,
        database_id=google_database,
        table_name=table_name,
        id_column="row_id",
        ignore_metadata_columns=[],
        embedding_service=embeddings,
        metadata_json_column="metadata",
    )

    texts = [
        "Langchain Test Text 1",
        "Langchain Test Text 2",
        "Langchain Test Text 3",
    ]
    metadatas = [
        {"title": "Title 1"},
        {"title": "Title 2"},
        {"title": "Title 3"},
    ]
    # Pre-generate one UUID per text so the round-trip can be checked.
    ids = [str(uuid.uuid4()) for _ in texts]

    returned_ids = store.add_texts(
        texts=texts,
        ids=ids,
        metadatas=metadatas,
    )

    # The store must report the same ids, in the same order, as inserted.
    assert returned_ids == ids

def test_delete_data(self, setup_database):
    """Delete the first two loaded documents and confirm the call reports success."""
    loader, embeddings = setup_database

    db = SpannerVectorStore(
        instance_id=instance_id,
        database_id=google_database,
        table_name=table_name,
        id_column="row_id",
        ignore_metadata_columns=[],
        embedding_service=embeddings,
        metadata_json_column="metadata",
    )

    docs = loader.load()
    deleted = db.delete(documents=[docs[0], docs[1]])
    # delete() returns a bool; use identity instead of the loose `== True`
    # comparison flagged by PEP 8 / flake8 E712.
    assert deleted is True

def test_search_data1(self, setup_database):
# loader, embeddings = setup_database
# db = SpannerVectorStore(
# instance_id=instance_id,
# database_id=google_database,
# table_name=table_name,
# id_column="row_id",
# ignore_metadata_columns=[],
# embedding_service=embeddings,
# metadata_json_column="metadata",
# )
# docs = db.similarity_search(
# "Testing the langchain integration with spanner", k=2
# )
# assert len(docs) == 2
pass

def test_search_data2(self, setup_database):
# TODO: Implement me
# loader, embeddings = setup_database
# db = SpannerVectorStore(
# instance_id=instance_id,
# database_id=google_database,
# table_name=table_name,
# id_column="row_id",
# ignore_metadata_columns=[],
# embedding_service=embeddings,
# metadata_json_column="metadata",
# )
# embeds = embeddings.embed_query(
# "Testing the langchain integration with spanner"
# )
# docs = db.similarity_search_by_vector(embeds, k=3, pre_filter="1 = 1")
# assert len(docs) == 3
pass

def test_search_data3(self, setup_database):
# TODO: Implement me
# loader, embeddings = setup_database
# db = SpannerVectorStore(
# instance_id=instance_id,
# database_id=google_database,
# table_name=table_name,
# id_column="row_id",
# ignore_metadata_columns=[],
# embedding_service=embeddings,
# metadata_json_column="metadata",
# query_parameters=QueryParameters(
# distance_strategy=DistanceStrategy.COSINE,
# max_staleness=datetime.timedelta(seconds=15),
# ),
# )
#
# docs = db.similarity_search(
# "Testing the langchain integration with spanner", k=3
# )
#
# assert len(docs) == 3
pass

def test_search_data4(self, setup_database):
# loader, embeddings = setup_database
# db = SpannerVectorStore(
# instance_id=instance_id,
# database_id=google_database,
# table_name=table_name,
# id_column="row_id",
# ignore_metadata_columns=[],
# embedding_service=embeddings,
# metadata_json_column="metadata",
# query_parameters=QueryParameters(
# distance_strategy=DistanceStrategy.COSINE,
# max_staleness=datetime.timedelta(seconds=15),
# ),
# )
# docs = db.max_marginal_relevance_search(
# "Testing the langchain integration with spanner", k=3
# )
# assert len(docs) == 3
pass


class TestSpannerVectorStorePGSQL:
@pytest.fixture(scope="class")
Expand Down

0 comments on commit b4f7f9d

Please sign in to comment.