Skip to content

Commit

Permalink
feat: add Approximate Nearest Neighbor support to distance strategies
Browse files Browse the repository at this point in the history
This change adds ANN distance strategies for GoogleSQL semantics.
While here started unit tests to effectively test out components
without having to have a running Cloud Spanner instance.

Updates #94
  • Loading branch information
odeke-em committed Dec 26, 2024
1 parent af7b616 commit aeab06b
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 10 deletions.
5 changes: 5 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,8 @@ Disclaimer

This is not an officially supported Google product.


Limitations
----------

* Approximate Nearest Neighbors (ANN) strategies are only support for the GoogleSQL dialect
11 changes: 11 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def format(session):
)


@nox.session(python=DEFAULT_PYTHON_VERSION)
def unit(session):
install_unittest_dependencies(session)
session.run(
Expand All @@ -192,3 +193,13 @@ def unit(session):
def install_unittest_dependencies(session, *constraints):
standard_deps = UNIT_TEST_STANDARD_DEPENDENCIES + UNIT_TEST_DEPENDENCIES
session.install(*standard_deps, *constraints)
session.run(
"pip",
"install",
"--no-compile", # To ensure no byte recompliation which is usually super slow
"-q",
"--disable-pip-version-check", # Avoid the slow version check
".",
"-r",
"requirements.txt",
)
42 changes: 32 additions & 10 deletions src/langchain_google_spanner/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,20 @@

from __future__ import annotations

import datetime
import logging
from abc import ABC, abstractmethod
import datetime
from enum import Enum
import logging
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union

import numpy as np
from google.cloud import spanner # type: ignore
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
from google.cloud.spanner_v1 import JsonObject, param_types
from langchain_community.vectorstores.utils import maximal_marginal_relevance
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
import numpy as np

from .version import __version__

Expand Down Expand Up @@ -104,6 +104,10 @@ class DistanceStrategy(Enum):

COSINE = 1
EUCLIDEIAN = 2
DOT_PRODUCT = 3
APPROX_DOT_PRODUCT = 4
APPROX_COSINE = 5
APPROX_EUCLIDEAN = 6


class DialectSemantics(ABC):
Expand Down Expand Up @@ -139,16 +143,23 @@ def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]:
)


_GOOGLE_DISTANCE_ALGO_NAMES = {
DistanceStrategy.APPROX_COSINE: "APPROX_COSINE_DISTANCE",
DistanceStrategy.APPROX_DOT_PRODUCT: "APPROX_DOT_PRODUCT",
DistanceStrategy.APPROX_EUCLIDEAN: "APPROX_EUCLIDEAN_DISTANCE",
DistanceStrategy.COSINE: "COSINE_DISTANCE",
DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT",
DistanceStrategy.EUCLIDEIAN: "EUCLIDEAN_DISTANCE",
}


class GoogleSqlSemnatics(DialectSemantics):
"""
Implementation of dialect semantics for Google SQL.
"""

def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str:
if distance_strategy == DistanceStrategy.COSINE:
return "COSINE_DISTANCE"

return "EUCLIDEAN_DISTANCE"
return _GOOGLE_DISTANCE_ALGO_NAMES.get(distance_strategy, "EUCLIDEAN")

def getDeleteDocumentsParameters(self, columns) -> Tuple[str, Any]:
where_clause_condition = " AND ".join(
Expand All @@ -163,15 +174,25 @@ def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]:
return dict(zip(columns, values))


_PG_DISTANCE_ALGO_NAMES = {
DistanceStrategy.COSINE: "spanner.cosine_distance",
DistanceStrategy.DOT_PRODUCT: "spanner.dot_product",
DistanceStrategy.EUCLIDEIAN: "spanner.euclidean_distance",
}


class PGSqlSemnatics(DialectSemantics):
"""
Implementation of dialect semantics for PostgreSQL.
"""

def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str:
if distance_strategy == DistanceStrategy.COSINE:
return "spanner.cosine_distance"
return "spanner.euclidean_distance"
name = _PG_DISTANCE_ALGO_NAMES.get(distance_strategy, None)
if name is None:
raise Exception(
"Unsupported PostgreSQL distance strategy: {}".format(distance_strategy)
)
return name

def getDeleteDocumentsParameters(self, columns) -> Tuple[str, Any]:
where_clause_condition = " AND ".join(
Expand Down Expand Up @@ -210,6 +231,7 @@ class NearestNeighborsAlgorithm(Enum):
"""

EXACT_NEAREST_NEIGHBOR = 1
APPROXIMATE_NEAREST_NEIGHBOR = 2

def __init__(
self,
Expand Down
71 changes: 71 additions & 0 deletions tests/unit/test_vectore_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import unittest

from langchain_google_spanner.vector_store import (
DistanceStrategy,
GoogleSqlSemnatics,
PGSqlSemnatics,
)


class TestGoogleSqlSemnatics(unittest.TestCase):
def test_distance_function_to_string(self):
cases = [
(DistanceStrategy.COSINE, "COSINE_DISTANCE"),
(DistanceStrategy.DOT_PRODUCT, "DOT_PRODUCT"),
(DistanceStrategy.EUCLIDEIAN, "EUCLIDEAN_DISTANCE"),
(DistanceStrategy.APPROX_COSINE, "APPROX_COSINE_DISTANCE"),
(DistanceStrategy.APPROX_DOT_PRODUCT, "APPROX_DOT_PRODUCT"),
(DistanceStrategy.APPROX_EUCLIDEAN, "APPROX_EUCLIDEAN_DISTANCE"),
]

sem = GoogleSqlSemnatics()
got_results = []
want_results = []
for strategy, want_str in cases:
got_results.append(sem.getDistanceFunction(strategy))
want_results.append(want_str)

assert got_results == want_results


class TestPGSqlSemnatics(unittest.TestCase):
def test_distance_function_to_string(self):
cases = [
(DistanceStrategy.COSINE, "spanner.cosine_distance"),
(DistanceStrategy.DOT_PRODUCT, "spanner.dot_product"),
(DistanceStrategy.EUCLIDEIAN, "spanner.euclidean_distance"),
]

sem = PGSqlSemnatics()
got_results = []
want_results = []
for strategy, want_str in cases:
got_results.append(sem.getDistanceFunction(strategy))
want_results.append(want_str)

assert got_results == want_results

def test_distance_function_raises_exception_if_unknown(self):
strategies = [
DistanceStrategy.APPROX_COSINE,
DistanceStrategy.APPROX_DOT_PRODUCT,
DistanceStrategy.APPROX_EUCLIDEAN,
]

for strategy in strategies:
with self.assertRaises(Exception):
sem.getDistanceFunction(strategy)

0 comments on commit aeab06b

Please sign in to comment.