From 59b67c243e63a118871d9547669225cd9f6645b7 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Tue, 22 Oct 2024 11:24:55 +0800 Subject: [PATCH 01/14] frame code Signed-off-by: shanhaikang.shk --- embedchain/embedchain/config/__init__.py | 1 + .../embedchain/config/vector_db/oceanbase.py | 83 ++++++++++ embedchain/embedchain/factory.py | 1 + embedchain/embedchain/vectordb/oceanbase.py | 45 ++++++ embedchain/poetry.lock | 145 +++++++++++------- embedchain/pyproject.toml | 1 + 6 files changed, 222 insertions(+), 54 deletions(-) create mode 100644 embedchain/embedchain/config/vector_db/oceanbase.py create mode 100644 embedchain/embedchain/vectordb/oceanbase.py diff --git a/embedchain/embedchain/config/__init__.py b/embedchain/embedchain/config/__init__.py index 768408b785..9f9c1a8cff 100644 --- a/embedchain/embedchain/config/__init__.py +++ b/embedchain/embedchain/config/__init__.py @@ -11,5 +11,6 @@ from .mem0_config import Mem0Config from .vector_db.chroma import ChromaDbConfig from .vector_db.elasticsearch import ElasticsearchDBConfig +from .vector_db.oceanbase import OceanBaseConfig from .vector_db.opensearch import OpenSearchDBConfig from .vector_db.zilliz import ZillizDBConfig diff --git a/embedchain/embedchain/config/vector_db/oceanbase.py b/embedchain/embedchain/config/vector_db/oceanbase.py new file mode 100644 index 0000000000..fa0ea18321 --- /dev/null +++ b/embedchain/embedchain/config/vector_db/oceanbase.py @@ -0,0 +1,83 @@ +import os +from typing import Optional + +from embedchain.config.vector_db.base import BaseVectorDbConfig +from embedchain.helpers.json_serializable import register_deserializable + +DEFAULT_OCEANBASE_COLLECTION_NAME = "embedchain_vector" +DEFAULT_OCEANBASE_HOST = "localhost" +DEFAULT_OCEANBASE_PORT = "2881" +DEFAULT_OCEANBASE_USER = "root@test" +DEFAULT_OCEANBASE_PASSWORD = "" +DEFAULT_OCEANBASE_DBNAME = "test" +DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE = "l2" +DEFAULT_OCEANBASE_HNSW_BUILD_PARAM = {"M": 16, "efConstruction": 256} + 
@register_deserializable
class OceanBaseConfig(BaseVectorDbConfig):
    """Configuration for the OceanBase vector database integration."""

    def __init__(
        self,
        collection_name: Optional[str] = None,
        dir: str = "db",
        host: Optional[str] = None,
        port: Optional[str] = None,
        user: Optional[str] = None,
        dbname: Optional[str] = None,
        vidx_metric_type: str = DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE,
        vidx_algo_params: Optional[dict] = None,
        drop_old: bool = False,
        normalize: bool = False,
    ):
        """
        Initializes a configuration class instance for OceanBase.

        :param collection_name: Default name for the collection, defaults to None
        :type collection_name: Optional[str], optional
        :param dir: Path to the database directory, where the database is stored,
            defaults to "db". In OceanBase, this parameter is not valid.
        :type dir: str, optional
        :param host: Database connection remote host.
        :type host: Optional[str], optional
        :param port: Database connection remote port.
        :type port: Optional[str], optional
        :param user: Database user name.
        :type user: Optional[str], optional
        :param dbname: OceanBase database name
        :type dbname: Optional[str], optional
        :param vidx_metric_type: vector index metric type, 'l2' or 'inner_product'.
        :type vidx_metric_type: Optional[str], optional
        :param vidx_algo_params: vector index building params,
            refer to `DEFAULT_OCEANBASE_HNSW_BUILD_PARAM` for an example.
        :type vidx_algo_params: Optional[dict], optional
        :param drop_old: drop old table before creating.
        :type drop_old: bool
        :param normalize: normalize vector before storing into OceanBase.
        :type normalize: bool

        The connection password is deliberately NOT a parameter: it is read
        from the ``OB_PASSWORD`` environment variable (empty string default).
        :raises ValueError: if `vidx_metric_type` is not 'l2' or 'inner_product'.
        """
        self.collection_name = collection_name or DEFAULT_OCEANBASE_COLLECTION_NAME
        self.host = host or DEFAULT_OCEANBASE_HOST
        self.port = port or DEFAULT_OCEANBASE_PORT
        # Password comes from the environment so it never appears in config files.
        self.passwd = os.environ.get("OB_PASSWORD", "")
        super().__init__(
            collection_name=self.collection_name,
            dir=dir,
            host=self.host,
            port=self.port,
        )
        self.user = user or DEFAULT_OCEANBASE_USER
        self.dbname = dbname or DEFAULT_OCEANBASE_DBNAME
        # Normalize case so "L2" / "Inner_Product" are accepted too.
        self.vidx_metric_type = vidx_metric_type.lower()
        if self.vidx_metric_type not in ("l2", "inner_product"):
            raise ValueError(
                "`vidx_metric_type` should be set in `l2`/`inner_product`."
            )
        self.vidx_algo_params = (
            vidx_algo_params
            if vidx_algo_params is not None
            else DEFAULT_OCEANBASE_HNSW_BUILD_PARAM
        )
        self.drop_old = drop_old
        self.normalize = normalize
Install with `pip install --upgrade pyobvector`" + ) from None + +logger = logging.getLogger(__name__) + +@register_deserializable +class OceanBaseVectorDB(BaseVectorDB): + """`OceanBase` vector store. + + + """ + def __init__(self, config: OceanBaseConfig = None): + if config is None: + self.config = OceanBaseConfig() + else: + self.config = config + + self.client = ObVecClient( + uri=( + self.config.host + ":" + self.config.port + ), + user=self.config.user, + password=self.config.passwd, + db_name=self.config.dbname, + ) + + super().__init__(config=self.config) + + def _initialize(self): + """ + This method is needed because `embedder` attribute needs to be set externally before it can be initialized. + + So it's can't be done in __init__ in one step. + """ + return super()._initialize() \ No newline at end of file diff --git a/embedchain/poetry.lock b/embedchain/poetry.lock index 1edf130d35..e7661193a9 100644 --- a/embedchain/poetry.lock +++ b/embedchain/poetry.lock @@ -2644,7 +2644,6 @@ description = "Client library to connect to the LangSmith LLM Tracing and Evalua optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.126-py3-none-any.whl", hash = "sha256:16c38ba5dae37a3cc715b6bc5d87d9579228433c2f34d6fa328345ee2b2bcc2a"}, {file = "langsmith-0.1.126.tar.gz", hash = "sha256:40f72e2d1d975473dd69269996053122941c1252915bcea55787607e2a7f949a"}, ] @@ -4503,6 +4502,37 @@ bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "r dev = ["black", "grpcio (==1.62.2)", "grpcio-testing (==1.62.2)", "grpcio-tools (==1.62.2)", "pytest (>=5.3.4)", "pytest-cov (>=2.8.1)", "pytest-timeout (>=1.3.4)", "ruff (>0.4.0)"] model = ["milvus-model (>=0.1.0)"] +[[package]] +name = "pymysql" +version = "1.1.1" +description = "Pure Python MySQL Driver" +optional = false +python-versions = ">=3.7" +files = [ + {file = "PyMySQL-1.1.1-py3-none-any.whl", hash = "sha256:4de15da4c61dc132f4fb9ab763063e693d521a80fd0e87943b9a453dd4c19d6c"}, + 
{file = "pymysql-1.1.1.tar.gz", hash = "sha256:e127611aaf2b417403c60bf4dc570124aeb4a57f5f37b8e95ae399a42f904cd0"}, +] + +[package.extras] +ed25519 = ["PyNaCl (>=1.4.0)"] +rsa = ["cryptography"] + +[[package]] +name = "pyobvector" +version = "0.1.6" +description = "A python SDK for OceanBase Vector Store, based on SQLAlchemy, compatible with Milvus API." +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "pyobvector-0.1.6-py3-none-any.whl", hash = "sha256:0d700e865a85b4716b9a03384189e49288cd9d5f3cef88aed4740bc82d5fd136"}, + {file = "pyobvector-0.1.6.tar.gz", hash = "sha256:05551addcac8c596992d5e38b480c83ca3481c6cfc6f56a1a1bddfb2e6ae037e"}, +] + +[package.dependencies] +numpy = ">=1.26.0,<2.0.0" +pymysql = ">=1.1.1,<2.0.0" +sqlalchemy = ">=2.0.32,<3.0.0" + [[package]] name = "pyparsing" version = "3.1.2" @@ -5470,60 +5500,68 @@ files = [ [[package]] name = "sqlalchemy" -version = "2.0.31" +version = "2.0.36" description = "Database Abstraction Library" optional = false python-versions = ">=3.7" files = [ - {file = "SQLAlchemy-2.0.31-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f2a213c1b699d3f5768a7272de720387ae0122f1becf0901ed6eaa1abd1baf6c"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9fea3d0884e82d1e33226935dac990b967bef21315cbcc894605db3441347443"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3ad7f221d8a69d32d197e5968d798217a4feebe30144986af71ada8c548e9fa"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f2bee229715b6366f86a95d497c347c22ddffa2c7c96143b59a2aa5cc9eebbc"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:cd5b94d4819c0c89280b7c6109c7b788a576084bf0a480ae17c227b0bc41e109"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:750900a471d39a7eeba57580b11983030517a1f512c2cb287d5ad0fcf3aebd58"}, - {file = 
"SQLAlchemy-2.0.31-cp310-cp310-win32.whl", hash = "sha256:7bd112be780928c7f493c1a192cd8c5fc2a2a7b52b790bc5a84203fb4381c6be"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-win_amd64.whl", hash = "sha256:5a48ac4d359f058474fadc2115f78a5cdac9988d4f99eae44917f36aa1476327"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f68470edd70c3ac3b6cd5c2a22a8daf18415203ca1b036aaeb9b0fb6f54e8298"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e2c38c2a4c5c634fe6c3c58a789712719fa1bf9b9d6ff5ebfce9a9e5b89c1ca"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd15026f77420eb2b324dcb93551ad9c5f22fab2c150c286ef1dc1160f110203"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2196208432deebdfe3b22185d46b08f00ac9d7b01284e168c212919891289396"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:352b2770097f41bff6029b280c0e03b217c2dcaddc40726f8f53ed58d8a85da4"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:56d51ae825d20d604583f82c9527d285e9e6d14f9a5516463d9705dab20c3740"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-win32.whl", hash = "sha256:6e2622844551945db81c26a02f27d94145b561f9d4b0c39ce7bfd2fda5776dac"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-win_amd64.whl", hash = "sha256:ccaf1b0c90435b6e430f5dd30a5aede4764942a695552eb3a4ab74ed63c5b8d3"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3b74570d99126992d4b0f91fb87c586a574a5872651185de8297c6f90055ae42"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6f77c4f042ad493cb8595e2f503c7a4fe44cd7bd59c7582fd6d78d7e7b8ec52c"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cd1591329333daf94467e699e11015d9c944f44c94d2091f4ac493ced0119449"}, - {file = 
"SQLAlchemy-2.0.31-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:74afabeeff415e35525bf7a4ecdab015f00e06456166a2eba7590e49f8db940e"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b9c01990d9015df2c6f818aa8f4297d42ee71c9502026bb074e713d496e26b67"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:66f63278db425838b3c2b1c596654b31939427016ba030e951b292e32b99553e"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-win32.whl", hash = "sha256:0b0f658414ee4e4b8cbcd4a9bb0fd743c5eeb81fc858ca517217a8013d282c96"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-win_amd64.whl", hash = "sha256:fa4b1af3e619b5b0b435e333f3967612db06351217c58bfb50cee5f003db2a5a"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:f43e93057cf52a227eda401251c72b6fbe4756f35fa6bfebb5d73b86881e59b0"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d337bf94052856d1b330d5fcad44582a30c532a2463776e1651bd3294ee7e58b"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c06fb43a51ccdff3b4006aafee9fcf15f63f23c580675f7734245ceb6b6a9e05"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:b6e22630e89f0e8c12332b2b4c282cb01cf4da0d26795b7eae16702a608e7ca1"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:79a40771363c5e9f3a77f0e28b3302801db08040928146e6808b5b7a40749c88"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-win32.whl", hash = "sha256:501ff052229cb79dd4c49c402f6cb03b5a40ae4771efc8bb2bfac9f6c3d3508f"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-win_amd64.whl", hash = "sha256:597fec37c382a5442ffd471f66ce12d07d91b281fd474289356b1a0041bdf31d"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:dc6d69f8829712a4fd799d2ac8d79bdeff651c2301b081fd5d3fe697bd5b4ab9"}, - {file = 
"SQLAlchemy-2.0.31-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:23b9fbb2f5dd9e630db70fbe47d963c7779e9c81830869bd7d137c2dc1ad05fb"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a21c97efcbb9f255d5c12a96ae14da873233597dfd00a3a0c4ce5b3e5e79704"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26a6a9837589c42b16693cf7bf836f5d42218f44d198f9343dd71d3164ceeeac"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:dc251477eae03c20fae8db9c1c23ea2ebc47331bcd73927cdcaecd02af98d3c3"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:2fd17e3bb8058359fa61248c52c7b09a97cf3c820e54207a50af529876451808"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-win32.whl", hash = "sha256:c76c81c52e1e08f12f4b6a07af2b96b9b15ea67ccdd40ae17019f1c373faa227"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-win_amd64.whl", hash = "sha256:4b600e9a212ed59355813becbcf282cfda5c93678e15c25a0ef896b354423238"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5b6cf796d9fcc9b37011d3f9936189b3c8074a02a4ed0c0fbbc126772c31a6d4"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:78fe11dbe37d92667c2c6e74379f75746dc947ee505555a0197cfba9a6d4f1a4"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2fc47dc6185a83c8100b37acda27658fe4dbd33b7d5e7324111f6521008ab4fe"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a41514c1a779e2aa9a19f67aaadeb5cbddf0b2b508843fcd7bafdf4c6864005"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:afb6dde6c11ea4525318e279cd93c8734b795ac8bb5dda0eedd9ebaca7fa23f1"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:3f9faef422cfbb8fd53716cd14ba95e2ef655400235c3dfad1b5f467ba179c8c"}, - {file = 
"SQLAlchemy-2.0.31-cp39-cp39-win32.whl", hash = "sha256:fc6b14e8602f59c6ba893980bea96571dd0ed83d8ebb9c4479d9ed5425d562e9"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-win_amd64.whl", hash = "sha256:3cb8a66b167b033ec72c3812ffc8441d4e9f5f78f5e31e54dcd4c90a4ca5bebc"}, - {file = "SQLAlchemy-2.0.31-py3-none-any.whl", hash = "sha256:69f3e3c08867a8e4856e92d7afb618b95cdee18e0bc1647b77599722c9a28911"}, - {file = "SQLAlchemy-2.0.31.tar.gz", hash = "sha256:b607489dd4a54de56984a0c7656247504bd5523d9d0ba799aef59d4add009484"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:59b8f3adb3971929a3e660337f5dacc5942c2cdb760afcabb2614ffbda9f9f72"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:37350015056a553e442ff672c2d20e6f4b6d0b2495691fa239d8aa18bb3bc908"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8318f4776c85abc3f40ab185e388bee7a6ea99e7fa3a30686580b209eaa35c08"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c245b1fbade9c35e5bd3b64270ab49ce990369018289ecfde3f9c318411aaa07"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:69f93723edbca7342624d09f6704e7126b152eaed3cdbb634cb657a54332a3c5"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f9511d8dd4a6e9271d07d150fb2f81874a3c8c95e11ff9af3a2dfc35fe42ee44"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-win32.whl", hash = "sha256:c3f3631693003d8e585d4200730616b78fafd5a01ef8b698f6967da5c605b3fa"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-win_amd64.whl", hash = "sha256:a86bfab2ef46d63300c0f06936bd6e6c0105faa11d509083ba8f2f9d237fb5b5"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:fd3a55deef00f689ce931d4d1b23fa9f04c880a48ee97af488fd215cf24e2a6c"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:4f5e9cd989b45b73bd359f693b935364f7e1f79486e29015813c338450aa5a71"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0ddd9db6e59c44875211bc4c7953a9f6638b937b0a88ae6d09eb46cced54eff"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2519f3a5d0517fc159afab1015e54bb81b4406c278749779be57a569d8d1bb0d"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:59b1ee96617135f6e1d6f275bbe988f419c5178016f3d41d3c0abb0c819f75bb"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:39769a115f730d683b0eb7b694db9789267bcd027326cccc3125e862eb03bfd8"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-win32.whl", hash = "sha256:66bffbad8d6271bb1cc2f9a4ea4f86f80fe5e2e3e501a5ae2a3dc6a76e604e6f"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-win_amd64.whl", hash = "sha256:23623166bfefe1487d81b698c423f8678e80df8b54614c2bf4b4cfcd7c711959"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f7b64e6ec3f02c35647be6b4851008b26cff592a95ecb13b6788a54ef80bbdd4"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:46331b00096a6db1fdc052d55b101dbbfc99155a548e20a0e4a8e5e4d1362855"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdf3386a801ea5aba17c6410dd1dc8d39cf454ca2565541b5ac42a84e1e28f53"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac9dfa18ff2a67b09b372d5db8743c27966abf0e5344c555d86cc7199f7ad83a"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:90812a8933df713fdf748b355527e3af257a11e415b613dd794512461eb8a686"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1bc330d9d29c7f06f003ab10e1eaced295e87940405afe1b110f2eb93a233588"}, + {file = 
"SQLAlchemy-2.0.36-cp312-cp312-win32.whl", hash = "sha256:79d2e78abc26d871875b419e1fd3c0bca31a1cb0043277d0d850014599626c2e"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-win_amd64.whl", hash = "sha256:b544ad1935a8541d177cb402948b94e871067656b3a0b9e91dbec136b06a2ff5"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b5cc79df7f4bc3d11e4b542596c03826063092611e481fcf1c9dfee3c94355ef"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3c01117dd36800f2ecaa238c65365b7b16497adc1522bf84906e5710ee9ba0e8"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9bc633f4ee4b4c46e7adcb3a9b5ec083bf1d9a97c1d3854b92749d935de40b9b"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e46ed38affdfc95d2c958de328d037d87801cfcbea6d421000859e9789e61c2"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:b2985c0b06e989c043f1dc09d4fe89e1616aadd35392aea2844f0458a989eacf"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4a121d62ebe7d26fec9155f83f8be5189ef1405f5973ea4874a26fab9f1e262c"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-win32.whl", hash = "sha256:0572f4bd6f94752167adfd7c1bed84f4b240ee6203a95e05d1e208d488d0d436"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-win_amd64.whl", hash = "sha256:8c78ac40bde930c60e0f78b3cd184c580f89456dd87fc08f9e3ee3ce8765ce88"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:be9812b766cad94a25bc63bec11f88c4ad3629a0cec1cd5d4ba48dc23860486b"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50aae840ebbd6cdd41af1c14590e5741665e5272d2fee999306673a1bb1fdb4d"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4557e1f11c5f653ebfdd924f3f9d5ebfc718283b0b9beebaa5dd6b77ec290971"}, + {file = 
"SQLAlchemy-2.0.36-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:07b441f7d03b9a66299ce7ccf3ef2900abc81c0db434f42a5694a37bd73870f2"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:28120ef39c92c2dd60f2721af9328479516844c6b550b077ca450c7d7dc68575"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-win32.whl", hash = "sha256:b81ee3d84803fd42d0b154cb6892ae57ea6b7c55d8359a02379965706c7efe6c"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-win_amd64.whl", hash = "sha256:f942a799516184c855e1a32fbc7b29d7e571b52612647866d4ec1c3242578fcb"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:3d6718667da04294d7df1670d70eeddd414f313738d20a6f1d1f379e3139a545"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:72c28b84b174ce8af8504ca28ae9347d317f9dba3999e5981a3cd441f3712e24"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b11d0cfdd2b095e7b0686cf5fabeb9c67fae5b06d265d8180715b8cfa86522e3"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e32092c47011d113dc01ab3e1d3ce9f006a47223b18422c5c0d150af13a00687"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6a440293d802d3011028e14e4226da1434b373cbaf4a4bbb63f845761a708346"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c54a1e53a0c308a8e8a7dffb59097bff7facda27c70c286f005327f21b2bd6b1"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-win32.whl", hash = "sha256:1e0d612a17581b6616ff03c8e3d5eff7452f34655c901f75d62bd86449d9750e"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-win_amd64.whl", hash = "sha256:8958b10490125124463095bbdadda5aa22ec799f91958e410438ad6c97a7b793"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:dc022184d3e5cacc9579e41805a681187650e170eb2fd70e28b86192a479dcaa"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:b817d41d692bf286abc181f8af476c4fbef3fd05e798777492618378448ee689"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4e46a888b54be23d03a89be510f24a7652fe6ff660787b96cd0e57a4ebcb46d"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4ae3005ed83f5967f961fd091f2f8c5329161f69ce8480aa8168b2d7fe37f06"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:03e08af7a5f9386a43919eda9de33ffda16b44eb11f3b313e6822243770e9763"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:3dbb986bad3ed5ceaf090200eba750b5245150bd97d3e67343a3cfed06feecf7"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-win32.whl", hash = "sha256:9fe53b404f24789b5ea9003fc25b9a3988feddebd7e7b369c8fac27ad6f52f28"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-win_amd64.whl", hash = "sha256:af148a33ff0349f53512a049c6406923e4e02bf2f26c5fb285f143faf4f0e46a"}, + {file = "SQLAlchemy-2.0.36-py3-none-any.whl", hash = "sha256:fddbe92b4760c6f5d48162aef14824add991aeda8ddadb3c31d56eb15ca69f8e"}, + {file = "sqlalchemy-2.0.36.tar.gz", hash = "sha256:7f2767680b6d2398aea7082e45a774b2b0767b5c8d8ffb9c8b683088ea9b29c5"}, ] [package.dependencies] @@ -5536,7 +5574,7 @@ aioodbc = ["aioodbc", "greenlet (!=0.4.17)"] aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing_extensions (!=3.10.0.1)"] asyncio = ["greenlet (!=0.4.17)"] asyncmy = ["asyncmy (>=0.2.3,!=0.2.4,!=0.2.6)", "greenlet (!=0.4.17)"] -mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5)"] +mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5,!=1.1.10)"] mssql = ["pyodbc"] mssql-pymssql = ["pymssql"] mssql-pyodbc = ["pyodbc"] @@ -6750,5 +6788,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<=3.13" -content-hash = "ec8a87e5281b7fa0c2c28f24c2562e823f0c546a24da2bb285b2f239b7b1758d" - +content-hash = 
"6f83d33b608d27686a75a476bbf717111d82dfdbf0b1e77f7118ef07f16cb40d" diff --git a/embedchain/pyproject.toml b/embedchain/pyproject.toml index 6e7cc44d67..fdc4328e9e 100644 --- a/embedchain/pyproject.toml +++ b/embedchain/pyproject.toml @@ -141,6 +141,7 @@ langchain-community = "^0.3.1" langchain-aws = {version = "^0.2.1", optional = true} langsmith = "^0.1.17" +pyobvector = "^0.1.6" [tool.poetry.group.dev.dependencies] black = "^23.3.0" pre-commit = "^3.2.2" From df338ff4e0367d478970bdd1e9fd945f95845b65 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Tue, 22 Oct 2024 17:59:41 +0800 Subject: [PATCH 02/14] basically support oceanbase Signed-off-by: shanhaikang.shk --- .../embedchain/config/vector_db/oceanbase.py | 1 - embedchain/embedchain/vectordb/oceanbase.py | 301 +++++++++++++++++- 2 files changed, 292 insertions(+), 10 deletions(-) diff --git a/embedchain/embedchain/config/vector_db/oceanbase.py b/embedchain/embedchain/config/vector_db/oceanbase.py index fa0ea18321..38b5f9f09b 100644 --- a/embedchain/embedchain/config/vector_db/oceanbase.py +++ b/embedchain/embedchain/config/vector_db/oceanbase.py @@ -15,7 +15,6 @@ @register_deserializable class OceanBaseConfig(BaseVectorDbConfig): - pass def __init__( self, collection_name: Optional[str] = None, diff --git a/embedchain/embedchain/vectordb/oceanbase.py b/embedchain/embedchain/vectordb/oceanbase.py index 77eaf1d79c..cf83fb12c8 100644 --- a/embedchain/embedchain/vectordb/oceanbase.py +++ b/embedchain/embedchain/vectordb/oceanbase.py @@ -1,11 +1,18 @@ +import json import logging +import math +from typing import Any, List, Optional, Union + +import numpy as np +from sqlalchemy import JSON, Column, String, Table, func, text +from sqlalchemy.dialects.mysql import LONGTEXT from embedchain.config import OceanBaseConfig from embedchain.helpers.json_serializable import register_deserializable from embedchain.vectordb.base import BaseVectorDB try: - from pyobvector import ObVecClient + from pyobvector import VECTOR, 
logger = logging.getLogger(__name__)

# Default column / index names for the backing OceanBase table.
DEFAULT_OCEANBASE_ID_COL = "id"
DEFAULT_OCEANBASE_TEXT_COL = "text"
DEFAULT_OCEANBASE_EMBEDDING_COL = "embeddings"
DEFAULT_OCEANBASE_METADATA_COL = "metadata"
DEFAULT_OCEANBASE_VIDX_NAME = "vidx"
DEFAULT_OCEANBASE_VIDX_TYPE = "hnsw"
DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM = {"efSearch": 64}


def _normalize(vector: List[float]) -> List[float]:
    """Return *vector* scaled to unit L2 norm.

    Bugfix: this module-level function was declared as
    ``def _normalize(self, vector)`` — the stray ``self`` made every call
    fail (``_normalize(vec)`` raised TypeError for the missing argument,
    and ``self._normalize(vec)`` raised AttributeError on the class).

    :param vector: the embedding to normalize
    :return: the normalized embedding as a plain list of floats
    """
    arr = np.array(vector)
    norm = np.linalg.norm(arr)
    return (arr / norm).tolist()


def _euclidean_similarity(distance: float) -> float:
    """Map an L2 distance to a similarity score (1.0 at distance 0)."""
    return 1.0 - distance / math.sqrt(2)


def _neg_inner_product_similarity(distance: float) -> float:
    """Map a negative-inner-product distance back to a similarity score."""
    return -distance
""" - return super()._initialize() \ No newline at end of file + self._get_or_create_collection() + + def _get_or_create_db(self): + """Called during initialization""" + return self.client + + def _load_table(self): + table = Table( + self.obconfig.collection_name, + self.client.metadata_obj, + autoload_with=self.client.engine, + ) + column_names = [column.name for column in table.columns] + assert len(column_names) == 4 + + logging.info(f"load exist table with {column_names} columns") + self.id_field = column_names[0] + self.text_field = column_names[1] + self.embed_field = column_names[2] + self.metadata_field = column_names[3] + + def _get_or_create_collection(self): + """Get or create a named collection.""" + if self.client.check_table_exists( + table_name=self.obconfig.collection_name, + ): + self._load_table() + return + + cols = [ + Column( + self.id_field, String(4096), primary_key=True, autoincrement=False + ), + Column(self.text_field, LONGTEXT), + Column(self.embed_field, VECTOR(self.embedder.vector_dimension)), + Column(self.metadata_field, JSON), + ] + + vidx_params = self.client.prepare_index_params() + vidx_params.add_index( + field_name=self.embed_field, + index_type=DEFAULT_OCEANBASE_VIDX_TYPE, + index_name=DEFAULT_OCEANBASE_VIDX_NAME, + metric_type=self.obconfig.vidx_metric_type, + params=self.obconfig.vidx_algo_params, + ) + + self.client.create_table_with_index_params( + table_name=self.obconfig.collection_name, + columns=cols, + indexes=None, + vidxs=vidx_params, + ) + + def get( + self, + ids: Optional[list[str]] = None, + where: Optional[dict[str, any]] = None, + limit: Optional[int] = None + ): + """ + Get existing doc ids present in vector database + + :param ids: list of doc ids to check for existence + :type ids: list[str] + :param where: Optional. to filter data + :type where: dict[str, Any] + :param limit: Optional. maximum number of documents + :type limit: Optional[int] + :return: Existing documents. 
+ :rtype: Set[str] + """ + res = self.client.get( + table_name=self.obconfig.collection_name, + ids=ids, + where_clause=text(self._generate_oceanbase_filter(where)), + output_column_name=[self.id_field, self.metadata_field], + ) + + data_ids = [] + metadatas = [] + for r in res.fetchall(): + data_ids.append(r[0]) + metadatas.append(json.loads(r[1])) + + return {"ids": data_ids, "metadatas": metadatas} + + def add( + self, + documents: list[str], + metadatas: list[object], + ids: list[str], + **kwargs: Optional[dict[str, any]], + ): + """Add to database""" + batch_size = 100 + embeddings = self.embedder.embedding_fn(documents) + + total_count = len(embeddings) + for i in range(0, total_count, batch_size): + data = [ + { + self.id_field: id, + self.text_field: text, + self.embed_field: ( + embedding + if not self.obconfig.normalize else + self._normalize(embedding) + ), + self.metadata_field: metadata, + } + for id, text, embedding, metadata in zip( + ids[i : i + batch_size], + documents[i : i + batch_size], + embeddings[i : i + batch_size], + metadatas[i : i + batch_size], + ) + ] + self.client.insert( + table_name=self.table_name, + data=data, + ) + + def _parse_metric_type_str_to_dist_func(self) -> Any: + if self.obconfig.vidx_metric_type == "l2": + return func.l2_distance + if self.obconfig.vidx_metric_type == "cosine": + return func.cosine_distance + if self.obconfig.vidx_metric_type == "inner_product": + return func.negative_inner_product + raise ValueError( + f"Invalid vector index metric type: {self.obconfig.vidx_metric_type}" + ) + + def _parse_distance_to_similarities(self, distance: float) -> float: + if self.obconfig.vidx_metric_type == "l2": + return _euclidean_similarity(distance) + elif self.obconfig.vidx_metric_type == "inner_product": + return _neg_inner_product_similarity(distance) + raise ValueError(f"Metric Type {self._vidx_metric_type} is not supported") + + def query( + self, + input_query: str, + n_results: int, + where: dict[str, Any], + 
citations: bool = False, + param: Optional[dict] = None, + **kwargs: Optional[dict[str, Any]], + ) -> Union[list[tuple[str, dict]], list[str]]: + """ + Query contents from vector database based on vector similarity + + :param input_query: query string + :type input_query: str + :param n_results: no of similar documents to fetch from database + :type n_results: int + :param where: to filter data + :type where: dict[str, Any] + :raises InvalidDimensionException: Dimensions do not match. + :param citations: we use citations boolean param to return context along with the answer. + :type citations: bool, default is False. + :return: The content of the document that matched your query, + along with url of the source and doc_id (if citations flag is true) + :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]] + """ + search_param = ( + param + if param is not None else + DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM + ) + ef_search = search_param.get( + "efSearch", + DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM["efSearch"] + ) + if ef_search != self.hnsw_ef_search: + self.client.set_ob_hnsw_ef_search(ef_search) + self.hnsw_ef_search = ef_search + + input_query_vector = self.embedder.embedding_fn([input_query]) + res = self.client.ann_search( + table_name=self.obconfig.collection_name, + vec_data=( + input_query_vector[0] + if not self.obconfig.normalize else + _normalize(input_query_vector[0]) + ), + vec_column_name=self.embed_field, + distance_func=self._parse_metric_type_str_to_dist_func(), + with_dist=True, + topk=n_results, + output_column_names=[ + self.text_field, + self.metadata_field + ], + where_clause=( + text(self._generate_oceanbase_filter(where)) + ), + **kwargs, + ) + + contexts = [] + for r in res: + context = r[0] + metadata = json.loads(r[1]) + score = self._parse_distance_to_similarities(r[2]) + + if citations: + metadata["score"] = score + contexts.append((context, metadata)) + else: + contexts.append(context) + return contexts + + def count(self) 
-> int: + """ + Count number of documents/chunks embedded in the database. + + :return: number of documents + :rtype: int + """ + res = self.client.perform_raw_text_sql( + f"SELECT COUNT(*) FROM {self.obconfig.collection_name}" + ) + return res.fetchall()[0][0] + + def reset(self, collection_names: list[str] = None): + """ + Resets the database. Deletes all embeddings irreversibly. + """ + if collection_names: + for collection_name in collection_names: + self.client.drop_table_if_exist(table_name=collection_name) + + def set_collection_name(self, name: str): + """ + Set the name of the collection. A collection is an isolated space for vectors. + + :param name: Name of the collection. + :type name: str + """ + if not isinstance(name, str): + raise TypeError("Collection name must be a string") + self.obconfig.collection_name = name + + def _generate_oceanbase_filter(self, where: dict[str, str]): + operands = [] + for key, value in where.items(): + operands.append(f"({self.metadata_field}->'$.{key}' == '{value}')") + return " and ".join(operands) From fb70448383fafc35fa36697cd9b3592bd8adc05f Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Tue, 22 Oct 2024 18:01:44 +0800 Subject: [PATCH 03/14] fix factory Signed-off-by: shanhaikang.shk --- embedchain/embedchain/factory.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/embedchain/embedchain/factory.py b/embedchain/embedchain/factory.py index b8e31b786c..cec34eaecc 100644 --- a/embedchain/embedchain/factory.py +++ b/embedchain/embedchain/factory.py @@ -98,6 +98,7 @@ class VectorDBFactory: "qdrant": "embedchain.vectordb.qdrant.QdrantDB", "weaviate": "embedchain.vectordb.weaviate.WeaviateDB", "zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB", + "oceanbase": "embedchain.vectordb.oceanbase.OceanBaseVectorDB", } provider_to_config_class = { "chroma": "embedchain.config.vector_db.chroma.ChromaDbConfig", @@ -108,7 +109,7 @@ class VectorDBFactory: "qdrant": 
"embedchain.config.vector_db.qdrant.QdrantDBConfig", "weaviate": "embedchain.config.vector_db.weaviate.WeaviateDBConfig", "zilliz": "embedchain.config.vector_db.zilliz.ZillizDBConfig", - "oceanbase": "embedchain.config.vector_db.zilliz.OceanBaseConfig" + "oceanbase": "embedchain.config.vector_db.oceanbase.OceanBaseConfig" } @classmethod From 7f1bf39772c7f4316a61faafc47037ca35852a4a Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Mon, 28 Oct 2024 14:42:40 +0800 Subject: [PATCH 04/14] fix some bug Signed-off-by: shanhaikang.shk --- embedchain/embedchain/vectordb/oceanbase.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/embedchain/embedchain/vectordb/oceanbase.py b/embedchain/embedchain/vectordb/oceanbase.py index cf83fb12c8..a7f3b4ac55 100644 --- a/embedchain/embedchain/vectordb/oceanbase.py +++ b/embedchain/embedchain/vectordb/oceanbase.py @@ -76,6 +76,12 @@ def _initialize(self): So it's can't be done in __init__ in one step. """ + if not hasattr(self, "embedder") or not self.embedder: + raise ValueError("Cannot create a OceanBase database collection without an embedder.") + if self.obconfig.drop_old: + self.client.drop_table_if_exist( + table_name=self.obconfig.collection_name + ) self._get_or_create_collection() def _get_or_create_db(self): @@ -151,7 +157,7 @@ def get( res = self.client.get( table_name=self.obconfig.collection_name, ids=ids, - where_clause=text(self._generate_oceanbase_filter(where)), + where_clause=self._generate_oceanbase_filter(where), output_column_name=[self.id_field, self.metadata_field], ) @@ -195,7 +201,7 @@ def add( ) ] self.client.insert( - table_name=self.table_name, + table_name=self.obconfig.collection_name, data=data, ) @@ -272,7 +278,7 @@ def query( self.metadata_field ], where_clause=( - text(self._generate_oceanbase_filter(where)) + self._generate_oceanbase_filter(where) ), **kwargs, ) @@ -322,7 +328,9 @@ def set_collection_name(self, name: str): self.obconfig.collection_name = name 
def _generate_oceanbase_filter(self, where: dict[str, str]): + if len(where.keys()) == 0: + return None operands = [] for key, value in where.items(): operands.append(f"({self.metadata_field}->'$.{key}' == '{value}')") - return " and ".join(operands) + return text(" and ".join(operands)) From 603c44535b07ffda433ccca51e80dd49c6422fe6 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Mon, 28 Oct 2024 15:32:18 +0800 Subject: [PATCH 05/14] add unittest for OceanBaseVector Signed-off-by: shanhaikang.shk --- embedchain/embedchain/vectordb/oceanbase.py | 5 +-- embedchain/tests/vectordb/test_oceanbase.py | 45 +++++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) create mode 100644 embedchain/tests/vectordb/test_oceanbase.py diff --git a/embedchain/embedchain/vectordb/oceanbase.py b/embedchain/embedchain/vectordb/oceanbase.py index a7f3b4ac55..d80bf79f62 100644 --- a/embedchain/embedchain/vectordb/oceanbase.py +++ b/embedchain/embedchain/vectordb/oceanbase.py @@ -42,10 +42,7 @@ def _neg_inner_product_similarity(distance: float) -> float: @register_deserializable class OceanBaseVectorDB(BaseVectorDB): - """`OceanBase` vector store. 
- - - """ + """`OceanBase` vector store.""" def __init__(self, config: OceanBaseConfig = None): if config is None: self.obconfig = OceanBaseConfig() diff --git a/embedchain/tests/vectordb/test_oceanbase.py b/embedchain/tests/vectordb/test_oceanbase.py new file mode 100644 index 0000000000..62614f0881 --- /dev/null +++ b/embedchain/tests/vectordb/test_oceanbase.py @@ -0,0 +1,45 @@ +# ruff: noqa: E501 + +import logging +from unittest.mock import patch + +import pytest + +from embedchain.config.vector_db.oceanbase import OceanBaseConfig +from embedchain.vectordb.oceanbase import OceanBaseVectorDB + +logger = logging.getLogger(__name__) + +class TestOceanBaseVector: + @pytest.fixture + def mock_embedder(self, mocker): + return mocker.Mock() + + @patch("embedchain.vectordb.oceanbase.ObVecClient", autospec=True) + def test_query(self, mock_client, mock_embedder): + ob_config = OceanBaseConfig(drop_old=True) + ob = OceanBaseVectorDB(config=ob_config) + ob.embedder = mock_embedder + + with patch.object(ob.client, "ann_search") as mock_search: + mock_embedder.embedding_fn.return_value = ["query_vector"] + mock_search.return_value = [ + ( + 'result_doc', + '{"url": "url_1", "doc_id": "doc_id_1"}', + 0.0 + ) + ] + + query_result = ob.query(input_query="query_text", n_results=1, where={}) + + assert query_result == ["result_doc"] + + query_result = ob.query(input_query="query_text", n_results=1, where={}, citations=True) + + assert query_result == [ + ( + "result_doc", + {"url": "url_1", "doc_id": "doc_id_1", "score": 1.0} + ) + ] From 8183ef7400da7fce6aced0089cf177b68b8063b0 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Wed, 30 Oct 2024 11:32:16 +0800 Subject: [PATCH 06/14] add config test Signed-off-by: shanhaikang.shk --- embedchain/embedchain/vectordb/oceanbase.py | 5 +++-- embedchain/tests/vectordb/test_oceanbase.py | 9 +++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/embedchain/embedchain/vectordb/oceanbase.py 
b/embedchain/embedchain/vectordb/oceanbase.py index d80bf79f62..9a5c8cd6f9 100644 --- a/embedchain/embedchain/vectordb/oceanbase.py +++ b/embedchain/embedchain/vectordb/oceanbase.py @@ -238,12 +238,13 @@ def query( :type n_results: int :param where: to filter data :type where: dict[str, Any] - :raises InvalidDimensionException: Dimensions do not match. :param citations: we use citations boolean param to return context along with the answer. :type citations: bool, default is False. + :param param: search parameters for hnsw. + :type param: Optional[dict] :return: The content of the document that matched your query, along with url of the source and doc_id (if citations flag is true) - :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]] + :rtype: list[str], if citations=False, otherwise list[tuple[str, dict]] """ search_param = ( param diff --git a/embedchain/tests/vectordb/test_oceanbase.py b/embedchain/tests/vectordb/test_oceanbase.py index 62614f0881..c9eed07d96 100644 --- a/embedchain/tests/vectordb/test_oceanbase.py +++ b/embedchain/tests/vectordb/test_oceanbase.py @@ -1,6 +1,8 @@ # ruff: noqa: E501 import logging +import os +from unittest import mock from unittest.mock import patch import pytest @@ -10,6 +12,13 @@ logger = logging.getLogger(__name__) +class TestOceanBaseConifg: + @mock.patch.dict(os.environ, {"OB_PASSWORD": "ob_password"}) + def test_init_oceanbase_config(self): + expect_password = "ob_password" + ob_config = OceanBaseConfig() + assert ob_config.passwd == expect_password + class TestOceanBaseVector: @pytest.fixture def mock_embedder(self, mocker): From 899aa3e0200cb9424cd88e9afa24fa2ed559a6d2 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Wed, 30 Oct 2024 11:45:50 +0800 Subject: [PATCH 07/14] make format Signed-off-by: shanhaikang.shk --- .../embedchain/config/vector_db/oceanbase.py | 15 +--- embedchain/embedchain/vectordb/oceanbase.py | 76 ++++++------------- embedchain/tests/vectordb/test_oceanbase.py | 21 ++--- 
3 files changed, 33 insertions(+), 79 deletions(-) diff --git a/embedchain/embedchain/config/vector_db/oceanbase.py b/embedchain/embedchain/config/vector_db/oceanbase.py index 38b5f9f09b..64f814c2f8 100644 --- a/embedchain/embedchain/config/vector_db/oceanbase.py +++ b/embedchain/embedchain/config/vector_db/oceanbase.py @@ -13,6 +13,7 @@ DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE = "l2" DEFAULT_OCEANBASE_HNSW_BUILD_PARAM = {"M": 16, "efConstruction": 256} + @register_deserializable class OceanBaseConfig(BaseVectorDbConfig): def __init__( @@ -54,9 +55,7 @@ def __init__( :param normalize: normalize vector before storing into OceanBase. :type normalize: bool """ - self.collection_name = ( - collection_name or DEFAULT_OCEANBASE_COLLECTION_NAME - ) + self.collection_name = collection_name or DEFAULT_OCEANBASE_COLLECTION_NAME self.host = host or DEFAULT_OCEANBASE_HOST self.port = port or DEFAULT_OCEANBASE_PORT self.passwd = os.environ.get("OB_PASSWORD", "") @@ -70,13 +69,7 @@ def __init__( self.dbname = dbname or DEFAULT_OCEANBASE_DBNAME self.vidx_metric_type = vidx_metric_type.lower() if self.vidx_metric_type not in ("l2", "inner_product"): - raise ValueError( - "`vidx_metric_type` should be set in `l2`/`inner_product`." 
- ) - self.vidx_algo_params = ( - vidx_algo_params - if vidx_algo_params is not None - else DEFAULT_OCEANBASE_HNSW_BUILD_PARAM - ) + raise ValueError("`vidx_metric_type` should be set in `l2`/`inner_product`.") + self.vidx_algo_params = vidx_algo_params if vidx_algo_params is not None else DEFAULT_OCEANBASE_HNSW_BUILD_PARAM self.drop_old = drop_old self.normalize = normalize diff --git a/embedchain/embedchain/vectordb/oceanbase.py b/embedchain/embedchain/vectordb/oceanbase.py index 9a5c8cd6f9..d5b656f1db 100644 --- a/embedchain/embedchain/vectordb/oceanbase.py +++ b/embedchain/embedchain/vectordb/oceanbase.py @@ -28,27 +28,32 @@ DEFAULT_OCEANBASE_VIDX_TYPE = "hnsw" DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM = {"efSearch": 64} + def _normalize(self, vector: List[float]) -> List[float]: arr = np.array(vector) norm = np.linalg.norm(arr) arr = arr / norm return arr.tolist() + def _euclidean_similarity(distance: float) -> float: return 1.0 - distance / math.sqrt(2) + def _neg_inner_product_similarity(distance: float) -> float: return -distance + @register_deserializable class OceanBaseVectorDB(BaseVectorDB): """`OceanBase` vector store.""" + def __init__(self, config: OceanBaseConfig = None): if config is None: self.obconfig = OceanBaseConfig() else: self.obconfig = config - + self.id_field = DEFAULT_OCEANBASE_ID_COL self.text_field = DEFAULT_OCEANBASE_TEXT_COL self.embed_field = DEFAULT_OCEANBASE_EMBEDDING_COL @@ -57,9 +62,7 @@ def __init__(self, config: OceanBaseConfig = None): self.hnsw_ef_search = -1 self.client = ObVecClient( - uri=( - self.obconfig.host + ":" + self.obconfig.port - ), + uri=(self.obconfig.host + ":" + self.obconfig.port), user=self.obconfig.user, password=self.obconfig.passwd, db_name=self.obconfig.dbname, @@ -76,15 +79,13 @@ def _initialize(self): if not hasattr(self, "embedder") or not self.embedder: raise ValueError("Cannot create a OceanBase database collection without an embedder.") if self.obconfig.drop_old: - self.client.drop_table_if_exist( - 
table_name=self.obconfig.collection_name - ) + self.client.drop_table_if_exist(table_name=self.obconfig.collection_name) self._get_or_create_collection() def _get_or_create_db(self): """Called during initialization""" return self.client - + def _load_table(self): table = Table( self.obconfig.collection_name, @@ -107,11 +108,9 @@ def _get_or_create_collection(self): ): self._load_table() return - + cols = [ - Column( - self.id_field, String(4096), primary_key=True, autoincrement=False - ), + Column(self.id_field, String(4096), primary_key=True, autoincrement=False), Column(self.text_field, LONGTEXT), Column(self.embed_field, VECTOR(self.embedder.vector_dimension)), Column(self.metadata_field, JSON), @@ -133,12 +132,7 @@ def _get_or_create_collection(self): vidxs=vidx_params, ) - def get( - self, - ids: Optional[list[str]] = None, - where: Optional[dict[str, any]] = None, - limit: Optional[int] = None - ): + def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None): """ Get existing doc ids present in vector database @@ -165,7 +159,7 @@ def get( metadatas.append(json.loads(r[1])) return {"ids": data_ids, "metadatas": metadatas} - + def add( self, documents: list[str], @@ -176,18 +170,14 @@ def add( """Add to database""" batch_size = 100 embeddings = self.embedder.embedding_fn(documents) - + total_count = len(embeddings) for i in range(0, total_count, batch_size): data = [ { self.id_field: id, self.text_field: text, - self.embed_field: ( - embedding - if not self.obconfig.normalize else - self._normalize(embedding) - ), + self.embed_field: (embedding if not self.obconfig.normalize else self._normalize(embedding)), self.metadata_field: metadata, } for id, text, embedding, metadata in zip( @@ -209,10 +199,8 @@ def _parse_metric_type_str_to_dist_func(self) -> Any: return func.cosine_distance if self.obconfig.vidx_metric_type == "inner_product": return func.negative_inner_product - raise ValueError( - f"Invalid 
vector index metric type: {self.obconfig.vidx_metric_type}" - ) - + raise ValueError(f"Invalid vector index metric type: {self.obconfig.vidx_metric_type}") + def _parse_distance_to_similarities(self, distance: float) -> float: if self.obconfig.vidx_metric_type == "l2": return _euclidean_similarity(distance) @@ -246,15 +234,8 @@ def query( along with url of the source and doc_id (if citations flag is true) :rtype: list[str], if citations=False, otherwise list[tuple[str, dict]] """ - search_param = ( - param - if param is not None else - DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM - ) - ef_search = search_param.get( - "efSearch", - DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM["efSearch"] - ) + search_param = param if param is not None else DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM + ef_search = search_param.get("efSearch", DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM["efSearch"]) if ef_search != self.hnsw_ef_search: self.client.set_ob_hnsw_ef_search(ef_search) self.hnsw_ef_search = ef_search @@ -262,22 +243,13 @@ def query( input_query_vector = self.embedder.embedding_fn([input_query]) res = self.client.ann_search( table_name=self.obconfig.collection_name, - vec_data=( - input_query_vector[0] - if not self.obconfig.normalize else - _normalize(input_query_vector[0]) - ), + vec_data=(input_query_vector[0] if not self.obconfig.normalize else _normalize(input_query_vector[0])), vec_column_name=self.embed_field, distance_func=self._parse_metric_type_str_to_dist_func(), with_dist=True, topk=n_results, - output_column_names=[ - self.text_field, - self.metadata_field - ], - where_clause=( - self._generate_oceanbase_filter(where) - ), + output_column_names=[self.text_field, self.metadata_field], + where_clause=(self._generate_oceanbase_filter(where)), **kwargs, ) @@ -301,9 +273,7 @@ def count(self) -> int: :return: number of documents :rtype: int """ - res = self.client.perform_raw_text_sql( - f"SELECT COUNT(*) FROM {self.obconfig.collection_name}" - ) + res = self.client.perform_raw_text_sql(f"SELECT 
COUNT(*) FROM {self.obconfig.collection_name}") return res.fetchall()[0][0] def reset(self, collection_names: list[str] = None): diff --git a/embedchain/tests/vectordb/test_oceanbase.py b/embedchain/tests/vectordb/test_oceanbase.py index c9eed07d96..08819a160f 100644 --- a/embedchain/tests/vectordb/test_oceanbase.py +++ b/embedchain/tests/vectordb/test_oceanbase.py @@ -10,7 +10,8 @@ from embedchain.config.vector_db.oceanbase import OceanBaseConfig from embedchain.vectordb.oceanbase import OceanBaseVectorDB -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) + class TestOceanBaseConifg: @mock.patch.dict(os.environ, {"OB_PASSWORD": "ob_password"}) @@ -19,11 +20,12 @@ def test_init_oceanbase_config(self): ob_config = OceanBaseConfig() assert ob_config.passwd == expect_password + class TestOceanBaseVector: @pytest.fixture def mock_embedder(self, mocker): return mocker.Mock() - + @patch("embedchain.vectordb.oceanbase.ObVecClient", autospec=True) def test_query(self, mock_client, mock_embedder): ob_config = OceanBaseConfig(drop_old=True) @@ -32,13 +34,7 @@ def test_query(self, mock_client, mock_embedder): with patch.object(ob.client, "ann_search") as mock_search: mock_embedder.embedding_fn.return_value = ["query_vector"] - mock_search.return_value = [ - ( - 'result_doc', - '{"url": "url_1", "doc_id": "doc_id_1"}', - 0.0 - ) - ] + mock_search.return_value = [("result_doc", '{"url": "url_1", "doc_id": "doc_id_1"}', 0.0)] query_result = ob.query(input_query="query_text", n_results=1, where={}) @@ -46,9 +42,4 @@ def test_query(self, mock_client, mock_embedder): query_result = ob.query(input_query="query_text", n_results=1, where={}, citations=True) - assert query_result == [ - ( - "result_doc", - {"url": "url_1", "doc_id": "doc_id_1", "score": 1.0} - ) - ] + assert query_result == [("result_doc", {"url": "url_1", "doc_id": "doc_id_1", "score": 1.0})] From 783b5d8c01d848e3c0713e4c0891d3177d4b8802 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: 
Fri, 8 Nov 2024 10:55:28 +0800 Subject: [PATCH 08/14] add oceanbase in misc.py Signed-off-by: shanhaikang.shk --- embedchain/embedchain/utils/misc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/embedchain/embedchain/utils/misc.py b/embedchain/embedchain/utils/misc.py index 7c5468ec93..eefcd98fe2 100644 --- a/embedchain/embedchain/utils/misc.py +++ b/embedchain/embedchain/utils/misc.py @@ -449,7 +449,8 @@ def validate_config(config_data): }, Optional("vectordb"): { Optional("provider"): Or( - "chroma", "elasticsearch", "opensearch", "lancedb", "pinecone", "qdrant", "weaviate", "zilliz" + "chroma", "elasticsearch", "opensearch", "lancedb", "pinecone", "qdrant", "weaviate", "zilliz", + "oceanbase" ), Optional("config"): object, # TODO: add particular config schema for each provider }, From c0c2796ebf5f8b56793c5e7bb78777f41d4d4963 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Fri, 8 Nov 2024 14:56:18 +0800 Subject: [PATCH 09/14] fix where bug Signed-off-by: shanhaikang.shk --- embedchain/embedchain/vectordb/oceanbase.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/embedchain/embedchain/vectordb/oceanbase.py b/embedchain/embedchain/vectordb/oceanbase.py index d5b656f1db..781a24b776 100644 --- a/embedchain/embedchain/vectordb/oceanbase.py +++ b/embedchain/embedchain/vectordb/oceanbase.py @@ -148,7 +148,7 @@ def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = res = self.client.get( table_name=self.obconfig.collection_name, ids=ids, - where_clause=self._generate_oceanbase_filter(where), + where_clause=[self._generate_oceanbase_filter(where)], output_column_name=[self.id_field, self.metadata_field], ) @@ -249,7 +249,7 @@ def query( with_dist=True, topk=n_results, output_column_names=[self.text_field, self.metadata_field], - where_clause=(self._generate_oceanbase_filter(where)), + where_clause=[self._generate_oceanbase_filter(where)], **kwargs, ) From 
d5386893f7b37a41bd67d7ffdb2310867cf505c8 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Fri, 8 Nov 2024 16:36:51 +0800 Subject: [PATCH 10/14] fix oceanbase filter Signed-off-by: shanhaikang.shk --- embedchain/embedchain/vectordb/oceanbase.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/embedchain/embedchain/vectordb/oceanbase.py b/embedchain/embedchain/vectordb/oceanbase.py index 781a24b776..590571f6b2 100644 --- a/embedchain/embedchain/vectordb/oceanbase.py +++ b/embedchain/embedchain/vectordb/oceanbase.py @@ -300,5 +300,5 @@ def _generate_oceanbase_filter(self, where: dict[str, str]): return None operands = [] for key, value in where.items(): - operands.append(f"({self.metadata_field}->'$.{key}' == '{value}')") + operands.append(f"({self.metadata_field}->'$.{key}' = '{value}')") return text(" and ".join(operands)) From 42a93c24887357625ab2909932e64688279cda22 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Tue, 12 Nov 2024 12:00:29 +0800 Subject: [PATCH 11/14] update pyobvector to v0.1.13 Signed-off-by: shanhaikang.shk --- embedchain/poetry.lock | 32 ++++++++++++++++++++++++++------ embedchain/pyproject.toml | 5 ++--- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/embedchain/poetry.lock b/embedchain/poetry.lock index e7661193a9..dd0b05cebe 100644 --- a/embedchain/poetry.lock +++ b/embedchain/poetry.lock @@ -96,6 +96,24 @@ yarl = ">=1.0,<2.0" [package.extras] speedups = ["Brotli", "aiodns", "brotlicffi"] +[[package]] +name = "aiomysql" +version = "0.2.0" +description = "MySQL driver for asyncio." 
+optional = true +python-versions = ">=3.7" +files = [ + {file = "aiomysql-0.2.0-py3-none-any.whl", hash = "sha256:b7c26da0daf23a5ec5e0b133c03d20657276e4eae9b73e040b72787f6f6ade0a"}, + {file = "aiomysql-0.2.0.tar.gz", hash = "sha256:558b9c26d580d08b8c5fd1be23c5231ce3aeff2dadad989540fee740253deb67"}, +] + +[package.dependencies] +PyMySQL = ">=1.0" + +[package.extras] +rsa = ["PyMySQL[rsa] (>=1.0)"] +sa = ["sqlalchemy (>=1.3,<1.4)"] + [[package]] name = "aiosignal" version = "1.3.1" @@ -4506,7 +4524,7 @@ model = ["milvus-model (>=0.1.0)"] name = "pymysql" version = "1.1.1" description = "Pure Python MySQL Driver" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "PyMySQL-1.1.1-py3-none-any.whl", hash = "sha256:4de15da4c61dc132f4fb9ab763063e693d521a80fd0e87943b9a453dd4c19d6c"}, @@ -4519,16 +4537,17 @@ rsa = ["cryptography"] [[package]] name = "pyobvector" -version = "0.1.6" +version = "0.1.13" description = "A python SDK for OceanBase Vector Store, based on SQLAlchemy, compatible with Milvus API." 
-optional = false +optional = true python-versions = "<4.0,>=3.9" files = [ - {file = "pyobvector-0.1.6-py3-none-any.whl", hash = "sha256:0d700e865a85b4716b9a03384189e49288cd9d5f3cef88aed4740bc82d5fd136"}, - {file = "pyobvector-0.1.6.tar.gz", hash = "sha256:05551addcac8c596992d5e38b480c83ca3481c6cfc6f56a1a1bddfb2e6ae037e"}, + {file = "pyobvector-0.1.13-py3-none-any.whl", hash = "sha256:b6a9e7a4673aebeefe835e04f7474d2f2ef8b9c96982af41cf9ce6f3e3500fdb"}, + {file = "pyobvector-0.1.13.tar.gz", hash = "sha256:e4b8f3ba3ad142cd7584b36278a38c0ef2fe7b6af142cdf5467d988e0737e03e"}, ] [package.dependencies] +aiomysql = ">=0.2.0,<0.3.0" numpy = ">=1.26.0,<2.0.0" pymysql = ">=1.1.1,<2.0.0" sqlalchemy = ">=2.0.32,<3.0.0" @@ -6780,6 +6799,7 @@ mysql = ["mysql-connector-python"] opensearch = ["opensearch-py"] opensource = ["gpt4all", "sentence-transformers", "torch"] postgres = ["psycopg", "psycopg-binary", "psycopg-pool"] +pyobvector = ["pyobvector"] qdrant = ["qdrant-client"] together = ["together"] vertexai = ["langchain-google-vertexai"] @@ -6788,4 +6808,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<=3.13" -content-hash = "6f83d33b608d27686a75a476bbf717111d82dfdbf0b1e77f7118ef07f16cb40d" +content-hash = "2c89658261a3d4f92f84dc7ddec4e1d4a55acc2d61b8203c866cf5a8be184e79" diff --git a/embedchain/pyproject.toml b/embedchain/pyproject.toml index fdc4328e9e..85a1c655ec 100644 --- a/embedchain/pyproject.toml +++ b/embedchain/pyproject.toml @@ -140,8 +140,8 @@ langchain-cohere = "^0.3.0" langchain-community = "^0.3.1" langchain-aws = {version = "^0.2.1", optional = true} langsmith = "^0.1.17" +pyobvector = {version = "^0.1.13", optional = true} -pyobvector = "^0.1.6" [tool.poetry.group.dev.dependencies] black = "^23.3.0" pre-commit = "^3.2.2" @@ -181,8 +181,7 @@ mysql = ["mysql-connector-python"] google = ["google-generativeai"] mistralai = ["langchain-mistralai"] aws = ["langchain-aws"] - -[tool.poetry.group.docs.dependencies] 
+pyobvector = ["pyobvector"] [tool.poetry.scripts] ec = "embedchain.cli:cli" \ No newline at end of file From 5f0205c155e77fb36c4715087b4e9b1519447b7b Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Wed, 13 Nov 2024 10:51:27 +0800 Subject: [PATCH 12/14] fix: oceanbase result type Signed-off-by: shanhaikang.shk --- embedchain/embedchain/vectordb/oceanbase.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/embedchain/embedchain/vectordb/oceanbase.py b/embedchain/embedchain/vectordb/oceanbase.py index 590571f6b2..865a991153 100644 --- a/embedchain/embedchain/vectordb/oceanbase.py +++ b/embedchain/embedchain/vectordb/oceanbase.py @@ -156,7 +156,12 @@ def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = metadatas = [] for r in res.fetchall(): data_ids.append(r[0]) - metadatas.append(json.loads(r[1])) + if isinstance(r[1], str) or isinstance(r[1], bytes): + metadatas.append(json.loads(r[1])) + elif isinstance(r[1], dict): + metadatas.append(r[1]) + else: + raise ValueError("invalid json type") return {"ids": data_ids, "metadatas": metadatas} @@ -256,7 +261,12 @@ def query( contexts = [] for r in res: context = r[0] - metadata = json.loads(r[1]) + if isinstance(r[1], str) or isinstance(r[1], bytes): + metadata = json.loads(r[1]) + elif isinstance(r[1], dict): + metadata = r[1] + else: + raise ValueError("invalid json type") score = self._parse_distance_to_similarities(r[2]) if citations: From 83035cef8e968577651e0b35ed486c3651651696 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Fri, 15 Nov 2024 00:50:31 +0800 Subject: [PATCH 13/14] fix _generate_oceanbase_filter for where=None Signed-off-by: shanhaikang.shk --- embedchain/embedchain/vectordb/oceanbase.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/embedchain/embedchain/vectordb/oceanbase.py b/embedchain/embedchain/vectordb/oceanbase.py index 865a991153..b03590f507 100644 --- 
a/embedchain/embedchain/vectordb/oceanbase.py +++ b/embedchain/embedchain/vectordb/oceanbase.py @@ -148,7 +148,7 @@ def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = res = self.client.get( table_name=self.obconfig.collection_name, ids=ids, - where_clause=[self._generate_oceanbase_filter(where)], + where_clause=self._generate_oceanbase_filter(where), output_column_name=[self.id_field, self.metadata_field], ) @@ -254,7 +254,7 @@ def query( with_dist=True, topk=n_results, output_column_names=[self.text_field, self.metadata_field], - where_clause=[self._generate_oceanbase_filter(where)], + where_clause=self._generate_oceanbase_filter(where), **kwargs, ) @@ -311,4 +311,4 @@ def _generate_oceanbase_filter(self, where: dict[str, str]): operands = [] for key, value in where.items(): operands.append(f"({self.metadata_field}->'$.{key}' = '{value}')") - return text(" and ".join(operands)) + return [text(" and ".join(operands))] From 0e3c5f3fd6d7077b325f16707afd7170356ba918 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Fri, 15 Nov 2024 09:13:20 +0800 Subject: [PATCH 14/14] fix where is none Signed-off-by: shanhaikang.shk --- embedchain/embedchain/vectordb/oceanbase.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/embedchain/embedchain/vectordb/oceanbase.py b/embedchain/embedchain/vectordb/oceanbase.py index b03590f507..0895e89aa8 100644 --- a/embedchain/embedchain/vectordb/oceanbase.py +++ b/embedchain/embedchain/vectordb/oceanbase.py @@ -306,7 +306,7 @@ def set_collection_name(self, name: str): self.obconfig.collection_name = name def _generate_oceanbase_filter(self, where: dict[str, str]): - if len(where.keys()) == 0: + if where is None or len(where.keys()) == 0: return None operands = [] for key, value in where.items():