diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 75c2897..080f1a6 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -87,6 +87,10 @@ class SecondaryIndex: index_name: str columns: list[str] storing_columns: Optional[list[str]] = None + num_leaves: Optional[int] = None # Only necessary for ANN + num_branches: Optional[int] = None # Only necessary for ANN + tree_depth: Optional[int] = None # Only necessary for ANN + index_type: Optional[DistanceStrategy] = None # Only necessary for ANN def __post_init__(self): # Check if column_name is None after initialization @@ -109,6 +113,16 @@ class DistanceStrategy(Enum): APPROX_COSINE = 5 APPROX_EUCLIDEAN = 6 + def __str__(self): + return DISTANCE_STRATEGY_STRING.get(self, self.name) + + +DISTANCE_STRATEGY_STRING = { + DistanceStrategy.COSINE: "COSINE", + DistanceStrategy.EUCLIDEIAN: "EUCLIDEAN", + DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT", +} + class DialectSemantics(ABC): """ @@ -152,6 +166,12 @@ def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]: DistanceStrategy.EUCLIDEIAN: "EUCLIDEAN_DISTANCE", } +_GOOGLE_ALGO_INDEX_NAME = { + DistanceStrategy.COSINE: "COSINE", + DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT", + DistanceStrategy.EUCLIDEIAN: "EUCLIDEAN", +} + class GoogleSqlSemnatics(DialectSemantics): """ @@ -173,6 +193,12 @@ def getDeleteDocumentsParameters(self, columns) -> Tuple[str, Any]: def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]: return dict(zip(columns, values)) + def getIndexDistanceType(self, distance_strategy) -> str: + value = _GOOGLE_ALGO_INDEX_NAME.get(distance_strategy, None) + if value is None: + raise Exception(f"{distance_strategy} is unsupported for distance_type") + return value + _PG_DISTANCE_ALGO_NAMES = { DistanceStrategy.COSINE: "spanner.cosine_distance", @@ -276,6 +302,15 @@ def __init__( self.staleness = {key: value} 
+DEFAULT_ANN_TREE_DEPTH = 2 +ANN_ACCEPTABLE_TREE_DEPTHS = (2, 3) + + +class AlgoKind(Enum): + KNN = 0 + ANN = 1 + + class SpannerVectorStore(VectorStore): GSQL_TYPES = { CONTENT_COLUMN_NAME: ["STRING"], @@ -306,6 +341,7 @@ def init_vector_store_table( primary_key: Optional[str] = None, vector_size: Optional[int] = None, secondary_indexes: Optional[List[SecondaryIndex]] = None, + kind: AlgoKind = None, ) -> bool: """ Initialize the vector store new table in Google Cloud Spanner. @@ -344,6 +380,7 @@ def init_vector_store_table( metadata_columns, primary_key, secondary_indexes, + kind=kind, ) operation = database.update_ddl(ddl) @@ -363,6 +400,7 @@ def _generate_sql( column_configs, primary_key, secondary_indexes: Optional[List[SecondaryIndex]] = None, + kind: Optional[AlgoKind] = AlgoKind.KNN, ): """ Generate SQL for creating the vector store table. @@ -378,6 +416,40 @@ def _generate_sql( Returns: - str: The generated SQL. """ + + ddl_statements = [ + SpannerVectorStore._generate_create_table_sql( + table_name, + id_column, + content_column, + embedding_column, + column_configs, + primary_key, + dialect, + ) + ] + + if kind == AlgoKind.ANN: + ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_ANN( + table_name, dialect, secondary_indexes + ) + else: + ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_KNN( + table_name, embedding_column, dialect, secondary_indexes + ) + + return ddl_statements + + @staticmethod + def _generate_create_table_sql( + table_name, + id_column, + content_column, + embedding_column, + column_configs, + primary_key, + dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, + ): create_table_statement = f"CREATE TABLE {table_name} (\n" if not isinstance(id_column, TableColumn): @@ -438,30 +510,66 @@ def _generate_sql( + ")" ) + return create_table_statement + + @staticmethod + def _generate_secondary_indices_ddl_KNN( + table_name, embedding_column, dialect, secondary_indexes=None + ): + if not secondary_indexes: + 
return [] + secondary_index_ddl_statements = [] + for secondary_index in secondary_indexes: + statement = f"CREATE INDEX {secondary_index.index_name} ON {table_name}(" + statement = statement + ",".join(secondary_index.columns) + ") " - if secondary_indexes is not None: - for secondary_index in secondary_indexes: - statement = ( - f"CREATE INDEX {secondary_index.index_name} ON {table_name}(" - ) - statement = statement + ",".join(secondary_index.columns) + ") " + if dialect == DatabaseDialect.POSTGRESQL: + statement = statement + "INCLUDE (" + else: + statement = statement + "STORING (" + embedding_name = getattr(embedding_column, "name", embedding_column) + if secondary_index.storing_columns is None: + secondary_index.storing_columns = [embedding_name] + elif embedding_name not in secondary_index.storing_columns: + secondary_index.storing_columns.append(embedding_name) + + statement = statement + ",".join(secondary_index.storing_columns) + ")" + secondary_index_ddl_statements.append(statement) + return secondary_index_ddl_statements + + @staticmethod + def _generate_secondary_indices_ddl_ANN( + table_name, dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, secondary_indexes=None + ): + if dialect != DatabaseDialect.GOOGLE_STANDARD_SQL: + raise Exception( + f"ANN is only supported for the GoogleSQL dialect not {dialect}" + ) + + secondary_index_ddl_statements = [] - if dialect == DatabaseDialect.POSTGRESQL: - statement = statement + "INCLUDE (" - else: - statement = statement + "STORING (" + for secondary_index in secondary_indexes or []: + statement = f"CREATE VECTOR INDEX {secondary_index.index_name}\n\tON {table_name}({secondary_index.columns[0]})" + options_segments = [f"distance_type='{secondary_index.index_type}'"] + if (secondary_index.tree_depth or 0) > 0: + tree_depth = secondary_index.tree_depth + if tree_depth not in ANN_ACCEPTABLE_TREE_DEPTHS: + raise Exception( + f"tree_depth: {tree_depth} is not in the acceptable values: {ANN_ACCEPTABLE_TREE_DEPTHS}" + ) + options_segments.append(f"tree_depth={secondary_index.tree_depth}") 
- if secondary_index.storing_columns is None: - secondary_index.storing_columns = [embedding_column.name] - elif embedding_column not in secondary_index.storing_columns: - secondary_index.storing_columns.append(embedding_column.name) + if (secondary_index.num_branches or 0) > 0: + options_segments.append(f"num_branches={secondary_index.num_branches}") - statement = statement + ",".join(secondary_index.storing_columns) + ")" + if (secondary_index.num_leaves or 0) > 0: + options_segments.append(f"num_leaves={secondary_index.num_leaves}") - secondary_index_ddl_statements.append(statement) + statement += "\n\tOPTIONS(" + ", ".join(options_segments) + ")" + secondary_index_ddl_statements.append(statement.strip()) - return [create_table_statement] + secondary_index_ddl_statements + return secondary_index_ddl_statements def __init__( self, diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py index 31f9e3d..350dfad 100644 --- a/tests/unit/test_vectore_store.py +++ b/tests/unit/test_vectore_store.py @@ -18,6 +18,8 @@ DistanceStrategy, GoogleSqlSemnatics, PGSqlSemnatics, + SecondaryIndex, + SpannerVectorStore, ) @@ -69,3 +71,37 @@ def test_distance_function_raises_exception_if_unknown(self): for strategy in strategies: with self.assertRaises(Exception): sem.getDistanceFunction(strategy) + + +class TestSpannerVectorStore_KNN(unittest.TestCase): + def test_generate_create_table_sql(self): + got = SpannerVectorStore._generate_create_table_sql( + "users", + "id", + "essays", + "science_scores", + [], + "id", + ) + want = "CREATE TABLE users (\n id STRING(36),\n essays STRING(MAX),\n science_scores ARRAY\n) PRIMARY KEY(id)" + assert got == want + + def test_generate_secondary_indices_ddl_ANN(self): + got = SpannerVectorStore._generate_secondary_indices_ddl_ANN( + "Documents", + secondary_indexes=[ + SecondaryIndex( + index_name="DocEmbeddingIndex", + columns=["DocEmbedding"], + num_branches=1000, + tree_depth=3, + index_type=DistanceStrategy.COSINE, + 
num_leaves=100000, + ) + ], + ) + want = [ + "CREATE VECTOR INDEX DocEmbeddingIndex\n\tON Documents(DocEmbedding)\n\tOPTIONS(distance_type='COSINE', tree_depth=3, num_branches=1000, num_leaves=100000)" + ] + + assert got == want