diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index de5df68..90fb31c 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -87,11 +87,6 @@ class SecondaryIndex: index_name: str columns: list[str] storing_columns: Optional[list[str]] = None - num_leaves: Optional[int] = None # Only necessary for ANN - nullable_column: Optional[bool] = False # Only necessary for ANN - num_branches: Optional[int] = None # Only necessary for ANN - tree_depth: Optional[int] = None # Only necessary for ANN - index_type: Optional[DistanceStrategy] = None # Only necessary for ANN def __post_init__(self): # Check if column_name is None after initialization @@ -102,24 +97,39 @@ def __post_init__(self): raise ValueError("Index Columns can't be None") +@dataclass +class VectorSearchIndex: + index_name: str + columns: list[str] + num_leaves: int + num_branches: int + tree_depth: int + index_type: DistanceStrategy + nullable_column: bool = False + + def __post_init__(self): + if self.index_name is None: + raise ValueError("index_name must be set") + + if len(self.columns) == 0: + raise ValueError("columns must be set") + + ok_tree_depth = self.tree_depth in (2, 3) + if not ok_tree_depth: + raise ValueError("tree_depth must be either 2 or 3") + + class DistanceStrategy(Enum): """ Enum for distance calculation strategies. """ COSINE = 1 - EUCLIDEIAN = 2 + EUCLIDEAN = 2 DOT_PRODUCT = 3 def __str__(self): - return DISTANCE_STRATEGY_STRING[self] - - -DISTANCE_STRATEGY_STRING = { - DistanceStrategy.COSINE: "COSINE", - DistanceStrategy.EUCLIDEIAN: "EUCLIDEIAN", - DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT", -} + return self.name class DialectSemantics(ABC): @@ -128,7 +138,7 @@ class DialectSemantics(ABC): """ @abstractmethod - def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str: + def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEAN) -> str: """ Abstract method to get the distance function based on the provided distance strategy. @@ -155,22 +165,18 @@ def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]: ) -_GOOGLE_DISTANCE_ALGO_NAMES = { +# Maps between distance strategy enums and the appropriate vector search index name. +GOOGLE_DIALECT_DISTANCE_FUCNTIONS = { DistanceStrategy.COSINE: "COSINE_DISTANCE", DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT", - DistanceStrategy.EUCLIDEIAN: "EUCLIDEAN_DISTANCE", + DistanceStrategy.EUCLIDEAN: "EUCLIDEAN_DISTANCE", } +# Maps between distance strategy and the appropriate ANN search function name. distance_strategy_to_ANN_function = { DistanceStrategy.COSINE: "APPROX_COSINE_DISTANCE", DistanceStrategy.DOT_PRODUCT: "APPROX_DOT_PRODUCT", - DistanceStrategy.EUCLIDEIAN: "APPROX_EUCLIDEAN_DISTANCE", -} - -_GOOGLE_ALGO_INDEX_NAME = { - DistanceStrategy.COSINE: "COSINE", - DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT", - DistanceStrategy.EUCLIDEIAN: "EUCLIDEAN", + DistanceStrategy.EUCLIDEAN: "APPROX_EUCLIDEAN_DISTANCE", } @@ -179,8 +185,8 @@ class GoogleSqlSemnatics(DialectSemantics): Implementation of dialect semantics for Google SQL. """ - def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str: - return _GOOGLE_DISTANCE_ALGO_NAMES.get(distance_strategy, "EUCLIDEAN") + def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEAN) -> str: + return GOOGLE_DIALECT_DISTANCE_FUCNTIONS.get(distance_strategy, "EUCLIDEAN") def getDeleteDocumentsParameters(self, columns) -> Tuple[str, Any]: where_clause_condition = " AND ".join( @@ -201,10 +207,11 @@ def getIndexDistanceType(self, distance_strategy) -> str: return value -_PG_DISTANCE_ALGO_NAMES = { +# Maps between DistanceStrategy and the expected PostgreSQL distance equivalent. +PG_DIALECT_DISTANCE_FUNCTIONS = { DistanceStrategy.COSINE: "spanner.cosine_distance", DistanceStrategy.DOT_PRODUCT: "spanner.dot_product", - DistanceStrategy.EUCLIDEIAN: "spanner.euclidean_distance", + DistanceStrategy.EUCLIDEAN: "spanner.euclidean_distance", } @@ -213,8 +220,8 @@ class PGSqlSemnatics(DialectSemantics): Implementation of dialect semantics for PostgreSQL. """ - def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str: - name = _PG_DISTANCE_ALGO_NAMES.get(distance_strategy, None) + def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEAN) -> str: + name = PG_DIALECT_DISTANCE_FUNCTIONS.get(distance_strategy, None) if name is None: raise Exception( "Unsupported PostgreSQL distance strategy: {}".format(distance_strategy) @@ -254,7 +261,7 @@ class QueryParameters: class NearestNeighborsAlgorithm(Enum): """ - Enum for nearest neighbors search algorithms. + Enum for k-nearest neighbors search algorithms. """ EXACT_NEAREST_NEIGHBOR = 1 @@ -263,7 +270,7 @@ class NearestNeighborsAlgorithm(Enum): def __init__( self, algorithm=NearestNeighborsAlgorithm.EXACT_NEAREST_NEIGHBOR, - distance_strategy=DistanceStrategy.EUCLIDEIAN, + distance_strategy=DistanceStrategy.EUCLIDEAN, read_timestamp: Optional[datetime.datetime] = None, min_read_timestamp: Optional[datetime.datetime] = None, max_staleness: Optional[datetime.timedelta] = None, @@ -303,10 +310,6 @@ def __init__( self.staleness = {key: value} -DEFAULT_ANN_TREE_DEPTH = 2 -ANN_ACCEPTABLE_TREE_DEPTHS = (2, 3) - - class AlgoKind(Enum): KNN = 0 ANN = 1 @@ -341,8 +344,8 @@ def init_vector_store_table( metadata_columns: Optional[List[TableColumn]] = None, primary_key: Optional[str] = None, vector_size: Optional[int] = None, - secondary_indexes: Optional[List[SecondaryIndex]] = None, - kind: AlgoKind = None, + secondary_indexes: Optional[List[SecondaryIndex | VectorSearchIndex]] = None, + kind: AlgoKind = AlgoKind.KNN, ) -> bool: """ Initialize the vector store new table in Google Cloud Spanner. @@ -357,6 +360,7 @@ def init_vector_store_table( - embedding_column (str): The name of the embedding column. Defaults to EMBEDDING_COLUMN_NAME. - metadata_columns (Optional[List[Tuple]]): List of tuples containing metadata column information. Defaults to None. - vector_size (Optional[int]): The size of the vector. Defaults to None. + - kind (AlgoKind): Defines whether to use k-Nearest Neighbors or Approximate Nearest Neighbors. Defaults to kNN. """ client = client_with_user_agent(client, USER_AGENT_VECTOR_STORE) @@ -400,7 +404,7 @@ def _generate_sql( embedding_column, column_configs, primary_key, - secondary_indexes: Optional[List[SecondaryIndex]] = None, + secondary_indexes: Optional[List[SecondaryIndex | VectorSearchIndex]] = None, kind: Optional[AlgoKind] = AlgoKind.KNN, limit=None, ): @@ -546,7 +550,7 @@ def _generate_secondary_indices_ddl_ANN( ): if dialect != DatabaseDialect.GOOGLE_STANDARD_SQL: raise Exception( - f"ANN is only supported for the GoogleSQL dialect not {dialect}" + f"ANN is only supported for the GoogleSQL dialect not {dialect}. File an issue on Github?" ) secondary_index_ddl_statements = [] @@ -554,15 +558,13 @@ def _generate_secondary_indices_ddl_ANN( for secondary_index in secondary_indexes: column_name = secondary_index.columns[0] statement = f"CREATE VECTOR INDEX {secondary_index.index_name}\n\tON {table_name}({column_name})" - if secondary_index.nullable_column: + if getattr(secondary_index, "nullable_column", False): statement += f"\n\tWHERE {column_name} IS NOT NULL" options_segments = [f"distance_type='{secondary_index.index_type}'"] - if secondary_index.tree_depth > 0: + if getattr(secondary_index, "tree_depth", 0) > 0: tree_depth = secondary_index.tree_depth - if tree_depth not in ANN_ACCEPTABLE_TREE_DEPTHS: - raise Exception( - f"tree_depth: {tree_depth} is not in the acceptable values: {ANN_ACCEPTABLE_TREE_DEPTHS}" - ) + if tree_depth not in (2, 3): + raise Exception(f"tree_depth: {tree_depth} must be either 2 or 3") options_segments.append(f"tree_depth={secondary_index.tree_depth}") if secondary_index.num_branches > 0: @@ -761,7 +763,7 @@ def _validate_table_schema(self, column_type_map, types, default_columns): def _select_relevance_score_fn(self) -> Callable[[float], float]: if self._query_parameters.distance_strategy == DistanceStrategy.COSINE: return self._cosine_relevance_score_fn - elif self._query_parameters.distance_strategy == DistanceStrategy.EUCLIDEIAN: + elif self._query_parameters.distance_strategy == DistanceStrategy.EUCLIDEAN: return self._euclidean_relevance_score_fn else: raise Exception( diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py index c86bef0..b488975 100644 --- a/tests/unit/test_vectore_store.py +++ b/tests/unit/test_vectore_store.py @@ -18,12 +18,12 @@ from google.cloud.spanner_admin_database_v1.types import DatabaseDialect from langchain_google_spanner.vector_store import ( - AlgoKind, DistanceStrategy, GoogleSqlSemnatics, PGSqlSemnatics, SecondaryIndex, SpannerVectorStore, + VectorSearchIndex, ) @@ -32,7 +32,7 @@ def test_distance_function_to_string(self): cases = [ (DistanceStrategy.COSINE, "COSINE_DISTANCE"), (DistanceStrategy.DOT_PRODUCT, "DOT_PRODUCT"), - (DistanceStrategy.EUCLIDEIAN, "EUCLIDEAN_DISTANCE"), + (DistanceStrategy.EUCLIDEAN, "EUCLIDEAN_DISTANCE"), ] sem = GoogleSqlSemnatics() @@ -46,18 +46,19 @@ def test_distance_function_to_string(self): class TestPGSqlSemnatics(unittest.TestCase): + sem = PGSqlSemnatics() + def test_distance_function_to_string(self): cases = [ (DistanceStrategy.COSINE, "spanner.cosine_distance"), (DistanceStrategy.DOT_PRODUCT, "spanner.dot_product"), - (DistanceStrategy.EUCLIDEIAN, "spanner.euclidean_distance"), + (DistanceStrategy.EUCLIDEAN, "spanner.euclidean_distance"), ] - sem = PGSqlSemnatics() got_results = [] want_results = [] for strategy, want_str in cases: - got_results.append(sem.getDistanceFunction(strategy)) + got_results.append(self.sem.getDistanceFunction(strategy)) want_results.append(want_str) assert got_results == want_results @@ -70,7 +71,7 @@ def test_distance_function_raises_exception_if_unknown(self): for strategy in strategies: with self.assertRaises(Exception): - sem.getDistanceFunction(strategy) + self.sem.getDistanceFunction(strategy) class TestSpannerVectorStore_KNN(unittest.TestCase): @@ -83,14 +84,17 @@ def test_generate_create_table_sql(self): [], "id", ) - want = "CREATE TABLE users (\n id STRING(36),\n essays STRING(MAX),\n science_scores ARRAY\n) PRIMARY KEY(id)" + want = ( + "CREATE TABLE users (\n id STRING(36),\n essays STRING(MAX)," + + "\n science_scores ARRAY\n) PRIMARY KEY(id)" + ) assert got == want def test_generate_secondary_indices_ddl_ANN(self): strategies = [ DistanceStrategy.COSINE, DistanceStrategy.DOT_PRODUCT, - DistanceStrategy.EUCLIDEIAN, + DistanceStrategy.EUCLIDEAN, ] nullables = [True, False] @@ -99,7 +103,7 @@ def test_generate_secondary_indices_ddl_ANN(self): got = SpannerVectorStore._generate_secondary_indices_ddl_ANN( "Documents", secondary_indexes=[ - SecondaryIndex( + VectorSearchIndex( index_name="DocEmbeddingIndex", columns=["DocEmbedding"], nullable_column=nullable, @@ -115,24 +119,26 @@ def test_generate_secondary_indices_ddl_ANN(self): "CREATE VECTOR INDEX DocEmbeddingIndex\n" + " ON Documents(DocEmbedding)\n" + " WHERE DocEmbedding IS NOT NULL\n" - + f" OPTIONS(distance_type='{distance_strategy}', tree_depth=3, num_branches=1000, num_leaves=100000)" + + f" OPTIONS(distance_type='{distance_strategy}', " + + "tree_depth=3, num_branches=1000, num_leaves=100000)" ] if not nullable: want = [ "CREATE VECTOR INDEX DocEmbeddingIndex\n" + " ON Documents(DocEmbedding)\n" - + f" OPTIONS(distance_type='{distance_strategy}', tree_depth=3, num_branches=1000, num_leaves=100000)" + + f" OPTIONS(distance_type='{distance_strategy}', " + + "tree_depth=3, num_branches=1000, num_leaves=100000)" ] assert canonicalize(got) == canonicalize(want) - def test_generate_secondary_indices_ddl_ANN_raises_exception_for_non_GoogleSQL_dialect( + def test_generate_ANN_indices_exception_for_non_GoogleSQL_dialect( self, ): strategies = [ DistanceStrategy.COSINE, DistanceStrategy.DOT_PRODUCT, - DistanceStrategy.EUCLIDEIAN, + DistanceStrategy.EUCLIDEAN, ] for strategy in strategies: @@ -141,7 +147,7 @@ def test_generate_secondary_indices_ddl_ANN_raises_exception_for_non_GoogleSQL_d "Documents", dialect=DatabaseDialect.POSTGRESQL, secondary_indexes=[ - SecondaryIndex( + VectorSearchIndex( index_name="DocEmbeddingIndex", columns=["DocEmbedding"], num_branches=1000, @@ -163,16 +169,13 @@ def test_generate_secondary_indices_ddl_KNN_GoogleDialect(self): SecondaryIndex( index_name="DocEmbeddingIndex", columns=["DocEmbedding"], - num_branches=1000, - tree_depth=3, - index_type=DistanceStrategy.COSINE, - num_leaves=100000, ) ], ) want = [ - "CREATE INDEX DocEmbeddingIndex ON Documents(DocEmbedding) STORING (text)" + "CREATE INDEX DocEmbeddingIndex ON " + + "Documents(DocEmbedding) STORING (text)" ] assert canonicalize(got) == canonicalize(want) @@ -188,16 +191,13 @@ def test_generate_secondary_indices_ddl_KNN_PostgresDialect(self): SecondaryIndex( index_name="DocEmbeddingIndex", columns=["DocEmbedding"], - num_branches=1000, - tree_depth=3, - index_type=DistanceStrategy.COSINE, - num_leaves=100000, ) ], ) want = [ - "CREATE INDEX DocEmbeddingIndex ON Documents(DocEmbedding) INCLUDE (text)" + "CREATE INDEX DocEmbeddingIndex ON " + + "Documents(DocEmbedding) INCLUDE (text)" ] assert canonicalize(got) == canonicalize(want) @@ -217,7 +217,8 @@ def test_query_ANN(self): want = ( "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + "ORDER BY APPROX_COSINE_DISTANCE(\n" - + ' ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON \'{"num_leaves_to_search": 10})\n' + + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " + + '\'{"num_leaves_to_search": 10})\n' + "LIMIT 100" ) @@ -240,7 +241,8 @@ def test_query_ANN_column_is_nullable(self): "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + "WHERE DocEmbedding IS NOT NULL\n" + "ORDER BY APPROX_COSINE_DISTANCE(\n" - + ' ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON \'{"num_leaves_to_search": 10})\n' + + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " + + '\'{"num_leaves_to_search": 10})\n' + "LIMIT 100" )