diff --git a/embedchain/embedchain/config/__init__.py b/embedchain/embedchain/config/__init__.py
index 768408b785..9f9c1a8cff 100644
--- a/embedchain/embedchain/config/__init__.py
+++ b/embedchain/embedchain/config/__init__.py
@@ -11,5 +11,6 @@
 from .mem0_config import Mem0Config
 from .vector_db.chroma import ChromaDbConfig
 from .vector_db.elasticsearch import ElasticsearchDBConfig
+from .vector_db.oceanbase import OceanBaseConfig
 from .vector_db.opensearch import OpenSearchDBConfig
 from .vector_db.zilliz import ZillizDBConfig
diff --git a/embedchain/embedchain/config/vector_db/oceanbase.py b/embedchain/embedchain/config/vector_db/oceanbase.py
new file mode 100644
index 0000000000..64f814c2f8
--- /dev/null
+++ b/embedchain/embedchain/config/vector_db/oceanbase.py
@@ -0,0 +1,84 @@
+import os
+from typing import Optional
+
+from embedchain.config.vector_db.base import BaseVectorDbConfig
+from embedchain.helpers.json_serializable import register_deserializable
+
+DEFAULT_OCEANBASE_COLLECTION_NAME = "embedchain_vector"
+DEFAULT_OCEANBASE_HOST = "localhost"
+DEFAULT_OCEANBASE_PORT = "2881"
+DEFAULT_OCEANBASE_USER = "root@test"
+DEFAULT_OCEANBASE_PASSWORD = ""
+DEFAULT_OCEANBASE_DBNAME = "test"
+DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE = "l2"
+DEFAULT_OCEANBASE_HNSW_BUILD_PARAM = {"M": 16, "efConstruction": 256}
+
+
+@register_deserializable
+class OceanBaseConfig(BaseVectorDbConfig):
+    """Connection and vector index settings for the OceanBase vector database."""
+
+    def __init__(
+        self,
+        collection_name: Optional[str] = None,
+        dir: str = "db",
+        host: Optional[str] = None,
+        port: Optional[str] = None,
+        user: Optional[str] = None,
+        dbname: Optional[str] = None,
+        vidx_metric_type: str = DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE,
+        vidx_algo_params: Optional[dict] = None,
+        drop_old: bool = False,
+        normalize: bool = False,
+    ):
+        """
+        Initializes a configuration class instance for OceanBase.
+
+        The password is not an argument: it is read from the `OB_PASSWORD` environment variable.
+
+        :param collection_name: Default name for the collection, defaults to None
+        :type collection_name: Optional[str], optional
+        :param dir: Path to the database directory, where the database is stored, defaults to "db".
+            In OceanBase, this parameter is not valid.
+        :type dir: str, optional
+        :param host: Database connection remote host.
+        :type host: Optional[str], optional
+        :param port: Database connection remote port. It is always stored as a string.
+        :type port: Optional[str], optional
+        :param user: Database user name.
+        :type user: Optional[str], optional
+        :param dbname: OceanBase database name
+        :type dbname: Optional[str], optional
+        :param vidx_metric_type: vector index metric type, 'l2' or 'inner_product'. None selects the default.
+        :type vidx_metric_type: Optional[str], optional
+        :param vidx_algo_params: vector index building params,
+            refer to `DEFAULT_OCEANBASE_HNSW_BUILD_PARAM` for an example.
+        :type vidx_algo_params: Optional[dict], optional
+        :param drop_old: drop old table before creating.
+        :type drop_old: bool
+        :param normalize: normalize vector before storing into OceanBase.
+        :type normalize: bool
+        :raises ValueError: if `vidx_metric_type` is neither 'l2' nor 'inner_product'.
+        """
+        self.collection_name = collection_name or DEFAULT_OCEANBASE_COLLECTION_NAME
+        self.host = host or DEFAULT_OCEANBASE_HOST
+        # A port read from a YAML/JSON config may be an integer; keep a string so it can be joined to the host.
+        self.port = str(port or DEFAULT_OCEANBASE_PORT)
+        self.passwd = os.environ.get("OB_PASSWORD", DEFAULT_OCEANBASE_PASSWORD)
+        super().__init__(
+            collection_name=self.collection_name,
+            dir=dir,
+            host=self.host,
+            port=self.port,
+        )
+        self.user = user or DEFAULT_OCEANBASE_USER
+        self.dbname = dbname or DEFAULT_OCEANBASE_DBNAME
+        self.vidx_metric_type = (vidx_metric_type or DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE).lower()
+        if self.vidx_metric_type not in ("l2", "inner_product"):
+            raise ValueError("`vidx_metric_type` should be set in `l2`/`inner_product`.")
+        # Copy the defaults so that changing one instance's params never alters the module-level constant.
+        self.vidx_algo_params = (
+            dict(DEFAULT_OCEANBASE_HNSW_BUILD_PARAM) if vidx_algo_params is None else vidx_algo_params
+        )
+        self.drop_old = drop_old
+        self.normalize = normalize
diff --git a/embedchain/embedchain/factory.py b/embedchain/embedchain/factory.py
index 69636286cf..cec34eaecc 100644
--- a/embedchain/embedchain/factory.py
+++ b/embedchain/embedchain/factory.py
@@ -98,6 +98,7 @@ class VectorDBFactory:
         "qdrant": "embedchain.vectordb.qdrant.QdrantDB",
         "weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
         "zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB",
+        "oceanbase": "embedchain.vectordb.oceanbase.OceanBaseVectorDB",
     }
     provider_to_config_class = {
         "chroma": "embedchain.config.vector_db.chroma.ChromaDbConfig",
@@ -108,6 +109,7 @@ class VectorDBFactory:
         "qdrant": "embedchain.config.vector_db.qdrant.QdrantDBConfig",
         "weaviate": "embedchain.config.vector_db.weaviate.WeaviateDBConfig",
         "zilliz": "embedchain.config.vector_db.zilliz.ZillizDBConfig",
+        "oceanbase": "embedchain.config.vector_db.oceanbase.OceanBaseConfig",
     }
 
     @classmethod
diff --git a/embedchain/embedchain/utils/misc.py b/embedchain/embedchain/utils/misc.py
index 7c5468ec93..eefcd98fe2 100644
--- a/embedchain/embedchain/utils/misc.py
+++ b/embedchain/embedchain/utils/misc.py
@@ -449,7 +449,8 @@ def validate_config(config_data):
             },
             Optional("vectordb"): {
Optional("provider"): Or( - "chroma", "elasticsearch", "opensearch", "lancedb", "pinecone", "qdrant", "weaviate", "zilliz" + "chroma", "elasticsearch", "opensearch", "lancedb", "pinecone", "qdrant", "weaviate", "zilliz", + "oceanbase" ), Optional("config"): object, # TODO: add particular config schema for each provider }, diff --git a/embedchain/embedchain/vectordb/oceanbase.py b/embedchain/embedchain/vectordb/oceanbase.py new file mode 100644 index 0000000000..0895e89aa8 --- /dev/null +++ b/embedchain/embedchain/vectordb/oceanbase.py @@ -0,0 +1,314 @@ +import json +import logging +import math +from typing import Any, List, Optional, Union + +import numpy as np +from sqlalchemy import JSON, Column, String, Table, func, text +from sqlalchemy.dialects.mysql import LONGTEXT + +from embedchain.config import OceanBaseConfig +from embedchain.helpers.json_serializable import register_deserializable +from embedchain.vectordb.base import BaseVectorDB + +try: + from pyobvector import VECTOR, ObVecClient +except ImportError: + raise ImportError( + "OceanBase requires extra dependencies. 
Install with `pip install --upgrade pyobvector`" + ) from None + +logger = logging.getLogger(__name__) + +DEFAULT_OCEANBASE_ID_COL = "id" +DEFAULT_OCEANBASE_TEXT_COL = "text" +DEFAULT_OCEANBASE_EMBEDDING_COL = "embeddings" +DEFAULT_OCEANBASE_METADATA_COL = "metadata" +DEFAULT_OCEANBASE_VIDX_NAME = "vidx" +DEFAULT_OCEANBASE_VIDX_TYPE = "hnsw" +DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM = {"efSearch": 64} + + +def _normalize(self, vector: List[float]) -> List[float]: + arr = np.array(vector) + norm = np.linalg.norm(arr) + arr = arr / norm + return arr.tolist() + + +def _euclidean_similarity(distance: float) -> float: + return 1.0 - distance / math.sqrt(2) + + +def _neg_inner_product_similarity(distance: float) -> float: + return -distance + + +@register_deserializable +class OceanBaseVectorDB(BaseVectorDB): + """`OceanBase` vector store.""" + + def __init__(self, config: OceanBaseConfig = None): + if config is None: + self.obconfig = OceanBaseConfig() + else: + self.obconfig = config + + self.id_field = DEFAULT_OCEANBASE_ID_COL + self.text_field = DEFAULT_OCEANBASE_TEXT_COL + self.embed_field = DEFAULT_OCEANBASE_EMBEDDING_COL + self.metadata_field = DEFAULT_OCEANBASE_METADATA_COL + self.vidx_name = DEFAULT_OCEANBASE_VIDX_NAME + self.hnsw_ef_search = -1 + + self.client = ObVecClient( + uri=(self.obconfig.host + ":" + self.obconfig.port), + user=self.obconfig.user, + password=self.obconfig.passwd, + db_name=self.obconfig.dbname, + ) + + super().__init__(config=self.obconfig) + + def _initialize(self): + """ + This method is needed because `embedder` attribute needs to be set externally before it can be initialized. + + So it's can't be done in __init__ in one step. 
+ """ + if not hasattr(self, "embedder") or not self.embedder: + raise ValueError("Cannot create a OceanBase database collection without an embedder.") + if self.obconfig.drop_old: + self.client.drop_table_if_exist(table_name=self.obconfig.collection_name) + self._get_or_create_collection() + + def _get_or_create_db(self): + """Called during initialization""" + return self.client + + def _load_table(self): + table = Table( + self.obconfig.collection_name, + self.client.metadata_obj, + autoload_with=self.client.engine, + ) + column_names = [column.name for column in table.columns] + assert len(column_names) == 4 + + logging.info(f"load exist table with {column_names} columns") + self.id_field = column_names[0] + self.text_field = column_names[1] + self.embed_field = column_names[2] + self.metadata_field = column_names[3] + + def _get_or_create_collection(self): + """Get or create a named collection.""" + if self.client.check_table_exists( + table_name=self.obconfig.collection_name, + ): + self._load_table() + return + + cols = [ + Column(self.id_field, String(4096), primary_key=True, autoincrement=False), + Column(self.text_field, LONGTEXT), + Column(self.embed_field, VECTOR(self.embedder.vector_dimension)), + Column(self.metadata_field, JSON), + ] + + vidx_params = self.client.prepare_index_params() + vidx_params.add_index( + field_name=self.embed_field, + index_type=DEFAULT_OCEANBASE_VIDX_TYPE, + index_name=DEFAULT_OCEANBASE_VIDX_NAME, + metric_type=self.obconfig.vidx_metric_type, + params=self.obconfig.vidx_algo_params, + ) + + self.client.create_table_with_index_params( + table_name=self.obconfig.collection_name, + columns=cols, + indexes=None, + vidxs=vidx_params, + ) + + def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None): + """ + Get existing doc ids present in vector database + + :param ids: list of doc ids to check for existence + :type ids: list[str] + :param where: Optional. 
to filter data + :type where: dict[str, Any] + :param limit: Optional. maximum number of documents + :type limit: Optional[int] + :return: Existing documents. + :rtype: Set[str] + """ + res = self.client.get( + table_name=self.obconfig.collection_name, + ids=ids, + where_clause=self._generate_oceanbase_filter(where), + output_column_name=[self.id_field, self.metadata_field], + ) + + data_ids = [] + metadatas = [] + for r in res.fetchall(): + data_ids.append(r[0]) + if isinstance(r[1], str) or isinstance(r[1], bytes): + metadatas.append(json.loads(r[1])) + elif isinstance(r[1], dict): + metadatas.append(r[1]) + else: + raise ValueError("invalid json type") + + return {"ids": data_ids, "metadatas": metadatas} + + def add( + self, + documents: list[str], + metadatas: list[object], + ids: list[str], + **kwargs: Optional[dict[str, any]], + ): + """Add to database""" + batch_size = 100 + embeddings = self.embedder.embedding_fn(documents) + + total_count = len(embeddings) + for i in range(0, total_count, batch_size): + data = [ + { + self.id_field: id, + self.text_field: text, + self.embed_field: (embedding if not self.obconfig.normalize else self._normalize(embedding)), + self.metadata_field: metadata, + } + for id, text, embedding, metadata in zip( + ids[i : i + batch_size], + documents[i : i + batch_size], + embeddings[i : i + batch_size], + metadatas[i : i + batch_size], + ) + ] + self.client.insert( + table_name=self.obconfig.collection_name, + data=data, + ) + + def _parse_metric_type_str_to_dist_func(self) -> Any: + if self.obconfig.vidx_metric_type == "l2": + return func.l2_distance + if self.obconfig.vidx_metric_type == "cosine": + return func.cosine_distance + if self.obconfig.vidx_metric_type == "inner_product": + return func.negative_inner_product + raise ValueError(f"Invalid vector index metric type: {self.obconfig.vidx_metric_type}") + + def _parse_distance_to_similarities(self, distance: float) -> float: + if self.obconfig.vidx_metric_type == "l2": + return 
_euclidean_similarity(distance) + elif self.obconfig.vidx_metric_type == "inner_product": + return _neg_inner_product_similarity(distance) + raise ValueError(f"Metric Type {self._vidx_metric_type} is not supported") + + def query( + self, + input_query: str, + n_results: int, + where: dict[str, Any], + citations: bool = False, + param: Optional[dict] = None, + **kwargs: Optional[dict[str, Any]], + ) -> Union[list[tuple[str, dict]], list[str]]: + """ + Query contents from vector database based on vector similarity + + :param input_query: query string + :type input_query: str + :param n_results: no of similar documents to fetch from database + :type n_results: int + :param where: to filter data + :type where: dict[str, Any] + :param citations: we use citations boolean param to return context along with the answer. + :type citations: bool, default is False. + :param param: search parameters for hnsw. + :type param: Optional[dict] + :return: The content of the document that matched your query, + along with url of the source and doc_id (if citations flag is true) + :rtype: list[str], if citations=False, otherwise list[tuple[str, dict]] + """ + search_param = param if param is not None else DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM + ef_search = search_param.get("efSearch", DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM["efSearch"]) + if ef_search != self.hnsw_ef_search: + self.client.set_ob_hnsw_ef_search(ef_search) + self.hnsw_ef_search = ef_search + + input_query_vector = self.embedder.embedding_fn([input_query]) + res = self.client.ann_search( + table_name=self.obconfig.collection_name, + vec_data=(input_query_vector[0] if not self.obconfig.normalize else _normalize(input_query_vector[0])), + vec_column_name=self.embed_field, + distance_func=self._parse_metric_type_str_to_dist_func(), + with_dist=True, + topk=n_results, + output_column_names=[self.text_field, self.metadata_field], + where_clause=self._generate_oceanbase_filter(where), + **kwargs, + ) + + contexts = [] + for r in res: 
+ context = r[0] + if isinstance(r[1], str) or isinstance(r[1], bytes): + metadata = json.loads(r[1]) + elif isinstance(r[1], dict): + metadata = r[1] + else: + raise ValueError("invalid json type") + score = self._parse_distance_to_similarities(r[2]) + + if citations: + metadata["score"] = score + contexts.append((context, metadata)) + else: + contexts.append(context) + return contexts + + def count(self) -> int: + """ + Count number of documents/chunks embedded in the database. + + :return: number of documents + :rtype: int + """ + res = self.client.perform_raw_text_sql(f"SELECT COUNT(*) FROM {self.obconfig.collection_name}") + return res.fetchall()[0][0] + + def reset(self, collection_names: list[str] = None): + """ + Resets the database. Deletes all embeddings irreversibly. + """ + if collection_names: + for collection_name in collection_names: + self.client.drop_table_if_exist(table_name=collection_name) + + def set_collection_name(self, name: str): + """ + Set the name of the collection. A collection is an isolated space for vectors. + + :param name: Name of the collection. + :type name: str + """ + if not isinstance(name, str): + raise TypeError("Collection name must be a string") + self.obconfig.collection_name = name + + def _generate_oceanbase_filter(self, where: dict[str, str]): + if where is None or len(where.keys()) == 0: + return None + operands = [] + for key, value in where.items(): + operands.append(f"({self.metadata_field}->'$.{key}' = '{value}')") + return [text(" and ".join(operands))] diff --git a/embedchain/poetry.lock b/embedchain/poetry.lock index b1b1a13044..53be9ac3a6 100644 --- a/embedchain/poetry.lock +++ b/embedchain/poetry.lock @@ -96,6 +96,24 @@ yarl = ">=1.0,<2.0" [package.extras] speedups = ["Brotli", "aiodns", "brotlicffi"] +[[package]] +name = "aiomysql" +version = "0.2.0" +description = "MySQL driver for asyncio." 
+optional = true +python-versions = ">=3.7" +files = [ + {file = "aiomysql-0.2.0-py3-none-any.whl", hash = "sha256:b7c26da0daf23a5ec5e0b133c03d20657276e4eae9b73e040b72787f6f6ade0a"}, + {file = "aiomysql-0.2.0.tar.gz", hash = "sha256:558b9c26d580d08b8c5fd1be23c5231ce3aeff2dadad989540fee740253deb67"}, +] + +[package.dependencies] +PyMySQL = ">=1.0" + +[package.extras] +rsa = ["PyMySQL[rsa] (>=1.0)"] +sa = ["sqlalchemy (>=1.3,<1.4)"] + [[package]] name = "aiosignal" version = "1.3.1" @@ -4437,6 +4455,38 @@ bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "r dev = ["black", "grpcio (==1.62.2)", "grpcio-testing (==1.62.2)", "grpcio-tools (==1.62.2)", "pytest (>=5.3.4)", "pytest-cov (>=2.8.1)", "pytest-timeout (>=1.3.4)", "ruff (>0.4.0)"] model = ["milvus-model (>=0.1.0)"] +[[package]] +name = "pymysql" +version = "1.1.1" +description = "Pure Python MySQL Driver" +optional = true +python-versions = ">=3.7" +files = [ + {file = "PyMySQL-1.1.1-py3-none-any.whl", hash = "sha256:4de15da4c61dc132f4fb9ab763063e693d521a80fd0e87943b9a453dd4c19d6c"}, + {file = "pymysql-1.1.1.tar.gz", hash = "sha256:e127611aaf2b417403c60bf4dc570124aeb4a57f5f37b8e95ae399a42f904cd0"}, +] + +[package.extras] +ed25519 = ["PyNaCl (>=1.4.0)"] +rsa = ["cryptography"] + +[[package]] +name = "pyobvector" +version = "0.1.13" +description = "A python SDK for OceanBase Vector Store, based on SQLAlchemy, compatible with Milvus API." 
+optional = true +python-versions = "<4.0,>=3.9" +files = [ + {file = "pyobvector-0.1.13-py3-none-any.whl", hash = "sha256:b6a9e7a4673aebeefe835e04f7474d2f2ef8b9c96982af41cf9ce6f3e3500fdb"}, + {file = "pyobvector-0.1.13.tar.gz", hash = "sha256:e4b8f3ba3ad142cd7584b36278a38c0ef2fe7b6af142cdf5467d988e0737e03e"}, +] + +[package.dependencies] +aiomysql = ">=0.2.0,<0.3.0" +numpy = ">=1.26.0,<2.0.0" +pymysql = ">=1.1.1,<2.0.0" +sqlalchemy = ">=2.0.32,<3.0.0" + [[package]] name = "pyparsing" version = "3.1.2" @@ -5387,60 +5437,68 @@ files = [ [[package]] name = "sqlalchemy" -version = "2.0.31" +version = "2.0.36" description = "Database Abstraction Library" optional = false python-versions = ">=3.7" files = [ - {file = "SQLAlchemy-2.0.31-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f2a213c1b699d3f5768a7272de720387ae0122f1becf0901ed6eaa1abd1baf6c"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9fea3d0884e82d1e33226935dac990b967bef21315cbcc894605db3441347443"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3ad7f221d8a69d32d197e5968d798217a4feebe30144986af71ada8c548e9fa"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f2bee229715b6366f86a95d497c347c22ddffa2c7c96143b59a2aa5cc9eebbc"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:cd5b94d4819c0c89280b7c6109c7b788a576084bf0a480ae17c227b0bc41e109"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:750900a471d39a7eeba57580b11983030517a1f512c2cb287d5ad0fcf3aebd58"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-win32.whl", hash = "sha256:7bd112be780928c7f493c1a192cd8c5fc2a2a7b52b790bc5a84203fb4381c6be"}, - {file = "SQLAlchemy-2.0.31-cp310-cp310-win_amd64.whl", hash = "sha256:5a48ac4d359f058474fadc2115f78a5cdac9988d4f99eae44917f36aa1476327"}, - {file = 
"SQLAlchemy-2.0.31-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f68470edd70c3ac3b6cd5c2a22a8daf18415203ca1b036aaeb9b0fb6f54e8298"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e2c38c2a4c5c634fe6c3c58a789712719fa1bf9b9d6ff5ebfce9a9e5b89c1ca"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd15026f77420eb2b324dcb93551ad9c5f22fab2c150c286ef1dc1160f110203"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2196208432deebdfe3b22185d46b08f00ac9d7b01284e168c212919891289396"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:352b2770097f41bff6029b280c0e03b217c2dcaddc40726f8f53ed58d8a85da4"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:56d51ae825d20d604583f82c9527d285e9e6d14f9a5516463d9705dab20c3740"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-win32.whl", hash = "sha256:6e2622844551945db81c26a02f27d94145b561f9d4b0c39ce7bfd2fda5776dac"}, - {file = "SQLAlchemy-2.0.31-cp311-cp311-win_amd64.whl", hash = "sha256:ccaf1b0c90435b6e430f5dd30a5aede4764942a695552eb3a4ab74ed63c5b8d3"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3b74570d99126992d4b0f91fb87c586a574a5872651185de8297c6f90055ae42"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6f77c4f042ad493cb8595e2f503c7a4fe44cd7bd59c7582fd6d78d7e7b8ec52c"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cd1591329333daf94467e699e11015d9c944f44c94d2091f4ac493ced0119449"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:74afabeeff415e35525bf7a4ecdab015f00e06456166a2eba7590e49f8db940e"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-musllinux_1_2_aarch64.whl", hash = 
"sha256:b9c01990d9015df2c6f818aa8f4297d42ee71c9502026bb074e713d496e26b67"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:66f63278db425838b3c2b1c596654b31939427016ba030e951b292e32b99553e"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-win32.whl", hash = "sha256:0b0f658414ee4e4b8cbcd4a9bb0fd743c5eeb81fc858ca517217a8013d282c96"}, - {file = "SQLAlchemy-2.0.31-cp312-cp312-win_amd64.whl", hash = "sha256:fa4b1af3e619b5b0b435e333f3967612db06351217c58bfb50cee5f003db2a5a"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:f43e93057cf52a227eda401251c72b6fbe4756f35fa6bfebb5d73b86881e59b0"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d337bf94052856d1b330d5fcad44582a30c532a2463776e1651bd3294ee7e58b"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c06fb43a51ccdff3b4006aafee9fcf15f63f23c580675f7734245ceb6b6a9e05"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:b6e22630e89f0e8c12332b2b4c282cb01cf4da0d26795b7eae16702a608e7ca1"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:79a40771363c5e9f3a77f0e28b3302801db08040928146e6808b5b7a40749c88"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-win32.whl", hash = "sha256:501ff052229cb79dd4c49c402f6cb03b5a40ae4771efc8bb2bfac9f6c3d3508f"}, - {file = "SQLAlchemy-2.0.31-cp37-cp37m-win_amd64.whl", hash = "sha256:597fec37c382a5442ffd471f66ce12d07d91b281fd474289356b1a0041bdf31d"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:dc6d69f8829712a4fd799d2ac8d79bdeff651c2301b081fd5d3fe697bd5b4ab9"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:23b9fbb2f5dd9e630db70fbe47d963c7779e9c81830869bd7d137c2dc1ad05fb"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:2a21c97efcbb9f255d5c12a96ae14da873233597dfd00a3a0c4ce5b3e5e79704"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26a6a9837589c42b16693cf7bf836f5d42218f44d198f9343dd71d3164ceeeac"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:dc251477eae03c20fae8db9c1c23ea2ebc47331bcd73927cdcaecd02af98d3c3"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:2fd17e3bb8058359fa61248c52c7b09a97cf3c820e54207a50af529876451808"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-win32.whl", hash = "sha256:c76c81c52e1e08f12f4b6a07af2b96b9b15ea67ccdd40ae17019f1c373faa227"}, - {file = "SQLAlchemy-2.0.31-cp38-cp38-win_amd64.whl", hash = "sha256:4b600e9a212ed59355813becbcf282cfda5c93678e15c25a0ef896b354423238"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5b6cf796d9fcc9b37011d3f9936189b3c8074a02a4ed0c0fbbc126772c31a6d4"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:78fe11dbe37d92667c2c6e74379f75746dc947ee505555a0197cfba9a6d4f1a4"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2fc47dc6185a83c8100b37acda27658fe4dbd33b7d5e7324111f6521008ab4fe"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a41514c1a779e2aa9a19f67aaadeb5cbddf0b2b508843fcd7bafdf4c6864005"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:afb6dde6c11ea4525318e279cd93c8734b795ac8bb5dda0eedd9ebaca7fa23f1"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:3f9faef422cfbb8fd53716cd14ba95e2ef655400235c3dfad1b5f467ba179c8c"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-win32.whl", hash = "sha256:fc6b14e8602f59c6ba893980bea96571dd0ed83d8ebb9c4479d9ed5425d562e9"}, - {file = "SQLAlchemy-2.0.31-cp39-cp39-win_amd64.whl", hash = 
"sha256:3cb8a66b167b033ec72c3812ffc8441d4e9f5f78f5e31e54dcd4c90a4ca5bebc"}, - {file = "SQLAlchemy-2.0.31-py3-none-any.whl", hash = "sha256:69f3e3c08867a8e4856e92d7afb618b95cdee18e0bc1647b77599722c9a28911"}, - {file = "SQLAlchemy-2.0.31.tar.gz", hash = "sha256:b607489dd4a54de56984a0c7656247504bd5523d9d0ba799aef59d4add009484"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:59b8f3adb3971929a3e660337f5dacc5942c2cdb760afcabb2614ffbda9f9f72"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:37350015056a553e442ff672c2d20e6f4b6d0b2495691fa239d8aa18bb3bc908"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8318f4776c85abc3f40ab185e388bee7a6ea99e7fa3a30686580b209eaa35c08"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c245b1fbade9c35e5bd3b64270ab49ce990369018289ecfde3f9c318411aaa07"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:69f93723edbca7342624d09f6704e7126b152eaed3cdbb634cb657a54332a3c5"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f9511d8dd4a6e9271d07d150fb2f81874a3c8c95e11ff9af3a2dfc35fe42ee44"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-win32.whl", hash = "sha256:c3f3631693003d8e585d4200730616b78fafd5a01ef8b698f6967da5c605b3fa"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-win_amd64.whl", hash = "sha256:a86bfab2ef46d63300c0f06936bd6e6c0105faa11d509083ba8f2f9d237fb5b5"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:fd3a55deef00f689ce931d4d1b23fa9f04c880a48ee97af488fd215cf24e2a6c"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4f5e9cd989b45b73bd359f693b935364f7e1f79486e29015813c338450aa5a71"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:d0ddd9db6e59c44875211bc4c7953a9f6638b937b0a88ae6d09eb46cced54eff"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2519f3a5d0517fc159afab1015e54bb81b4406c278749779be57a569d8d1bb0d"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:59b1ee96617135f6e1d6f275bbe988f419c5178016f3d41d3c0abb0c819f75bb"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:39769a115f730d683b0eb7b694db9789267bcd027326cccc3125e862eb03bfd8"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-win32.whl", hash = "sha256:66bffbad8d6271bb1cc2f9a4ea4f86f80fe5e2e3e501a5ae2a3dc6a76e604e6f"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-win_amd64.whl", hash = "sha256:23623166bfefe1487d81b698c423f8678e80df8b54614c2bf4b4cfcd7c711959"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f7b64e6ec3f02c35647be6b4851008b26cff592a95ecb13b6788a54ef80bbdd4"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:46331b00096a6db1fdc052d55b101dbbfc99155a548e20a0e4a8e5e4d1362855"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdf3386a801ea5aba17c6410dd1dc8d39cf454ca2565541b5ac42a84e1e28f53"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac9dfa18ff2a67b09b372d5db8743c27966abf0e5344c555d86cc7199f7ad83a"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:90812a8933df713fdf748b355527e3af257a11e415b613dd794512461eb8a686"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1bc330d9d29c7f06f003ab10e1eaced295e87940405afe1b110f2eb93a233588"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-win32.whl", hash = "sha256:79d2e78abc26d871875b419e1fd3c0bca31a1cb0043277d0d850014599626c2e"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-win_amd64.whl", hash = 
"sha256:b544ad1935a8541d177cb402948b94e871067656b3a0b9e91dbec136b06a2ff5"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b5cc79df7f4bc3d11e4b542596c03826063092611e481fcf1c9dfee3c94355ef"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3c01117dd36800f2ecaa238c65365b7b16497adc1522bf84906e5710ee9ba0e8"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9bc633f4ee4b4c46e7adcb3a9b5ec083bf1d9a97c1d3854b92749d935de40b9b"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e46ed38affdfc95d2c958de328d037d87801cfcbea6d421000859e9789e61c2"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:b2985c0b06e989c043f1dc09d4fe89e1616aadd35392aea2844f0458a989eacf"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4a121d62ebe7d26fec9155f83f8be5189ef1405f5973ea4874a26fab9f1e262c"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-win32.whl", hash = "sha256:0572f4bd6f94752167adfd7c1bed84f4b240ee6203a95e05d1e208d488d0d436"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-win_amd64.whl", hash = "sha256:8c78ac40bde930c60e0f78b3cd184c580f89456dd87fc08f9e3ee3ce8765ce88"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:be9812b766cad94a25bc63bec11f88c4ad3629a0cec1cd5d4ba48dc23860486b"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50aae840ebbd6cdd41af1c14590e5741665e5272d2fee999306673a1bb1fdb4d"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4557e1f11c5f653ebfdd924f3f9d5ebfc718283b0b9beebaa5dd6b77ec290971"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:07b441f7d03b9a66299ce7ccf3ef2900abc81c0db434f42a5694a37bd73870f2"}, + {file = 
"SQLAlchemy-2.0.36-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:28120ef39c92c2dd60f2721af9328479516844c6b550b077ca450c7d7dc68575"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-win32.whl", hash = "sha256:b81ee3d84803fd42d0b154cb6892ae57ea6b7c55d8359a02379965706c7efe6c"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-win_amd64.whl", hash = "sha256:f942a799516184c855e1a32fbc7b29d7e571b52612647866d4ec1c3242578fcb"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:3d6718667da04294d7df1670d70eeddd414f313738d20a6f1d1f379e3139a545"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:72c28b84b174ce8af8504ca28ae9347d317f9dba3999e5981a3cd441f3712e24"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b11d0cfdd2b095e7b0686cf5fabeb9c67fae5b06d265d8180715b8cfa86522e3"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e32092c47011d113dc01ab3e1d3ce9f006a47223b18422c5c0d150af13a00687"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6a440293d802d3011028e14e4226da1434b373cbaf4a4bbb63f845761a708346"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c54a1e53a0c308a8e8a7dffb59097bff7facda27c70c286f005327f21b2bd6b1"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-win32.whl", hash = "sha256:1e0d612a17581b6616ff03c8e3d5eff7452f34655c901f75d62bd86449d9750e"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-win_amd64.whl", hash = "sha256:8958b10490125124463095bbdadda5aa22ec799f91958e410438ad6c97a7b793"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:dc022184d3e5cacc9579e41805a681187650e170eb2fd70e28b86192a479dcaa"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b817d41d692bf286abc181f8af476c4fbef3fd05e798777492618378448ee689"}, + {file = 
"SQLAlchemy-2.0.36-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4e46a888b54be23d03a89be510f24a7652fe6ff660787b96cd0e57a4ebcb46d"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4ae3005ed83f5967f961fd091f2f8c5329161f69ce8480aa8168b2d7fe37f06"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:03e08af7a5f9386a43919eda9de33ffda16b44eb11f3b313e6822243770e9763"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:3dbb986bad3ed5ceaf090200eba750b5245150bd97d3e67343a3cfed06feecf7"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-win32.whl", hash = "sha256:9fe53b404f24789b5ea9003fc25b9a3988feddebd7e7b369c8fac27ad6f52f28"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-win_amd64.whl", hash = "sha256:af148a33ff0349f53512a049c6406923e4e02bf2f26c5fb285f143faf4f0e46a"}, + {file = "SQLAlchemy-2.0.36-py3-none-any.whl", hash = "sha256:fddbe92b4760c6f5d48162aef14824add991aeda8ddadb3c31d56eb15ca69f8e"}, + {file = "sqlalchemy-2.0.36.tar.gz", hash = "sha256:7f2767680b6d2398aea7082e45a774b2b0767b5c8d8ffb9c8b683088ea9b29c5"}, ] [package.dependencies] @@ -5453,7 +5511,7 @@ aioodbc = ["aioodbc", "greenlet (!=0.4.17)"] aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing_extensions (!=3.10.0.1)"] asyncio = ["greenlet (!=0.4.17)"] asyncmy = ["asyncmy (>=0.2.3,!=0.2.4,!=0.2.6)", "greenlet (!=0.4.17)"] -mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5)"] +mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5,!=1.1.10)"] mssql = ["pyodbc"] mssql-pymssql = ["pymssql"] mssql-pyodbc = ["pyodbc"] @@ -6659,6 +6717,7 @@ mysql = ["mysql-connector-python"] opensearch = ["opensearch-py"] opensource = ["gpt4all", "sentence-transformers", "torch"] postgres = ["psycopg", "psycopg-binary", "psycopg-pool"] +pyobvector = ["pyobvector"] qdrant = ["qdrant-client"] together = ["together"] vertexai = ["langchain-google-vertexai"] @@ -6667,4 +6726,4 @@ weaviate = 
["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<=3.13" -content-hash = "98ab80c76f001e35dd37dd705b4d8e88ecd4377383457aab38c0c94d3b872f33" +content-hash = "2c89658261a3d4f92f84dc7ddec4e1d4a55acc2d61b8203c866cf5a8be184e79" diff --git a/embedchain/pyproject.toml b/embedchain/pyproject.toml index a4ed156e7a..0021e87734 100644 --- a/embedchain/pyproject.toml +++ b/embedchain/pyproject.toml @@ -140,6 +140,7 @@ langchain-cohere = "^0.3.0" langchain-community = "^0.3.1" langchain-aws = {version = "^0.2.1", optional = true} langsmith = "^0.1.17" +pyobvector = {version = "^0.1.13", optional = true} [tool.poetry.group.dev.dependencies] black = "^23.3.0" @@ -180,8 +181,7 @@ mysql = ["mysql-connector-python"] google = ["google-generativeai"] mistralai = ["langchain-mistralai"] aws = ["langchain-aws"] - -[tool.poetry.group.docs.dependencies] +pyobvector = ["pyobvector"] [tool.poetry.scripts] ec = "embedchain.cli:cli" \ No newline at end of file diff --git a/embedchain/tests/vectordb/test_oceanbase.py b/embedchain/tests/vectordb/test_oceanbase.py new file mode 100644 index 0000000000..08819a160f --- /dev/null +++ b/embedchain/tests/vectordb/test_oceanbase.py @@ -0,0 +1,45 @@ +# ruff: noqa: E501 + +import logging +import os +from unittest import mock +from unittest.mock import patch + +import pytest + +from embedchain.config.vector_db.oceanbase import OceanBaseConfig +from embedchain.vectordb.oceanbase import OceanBaseVectorDB + +logger = logging.getLogger(__name__) + + +class TestOceanBaseConifg: + @mock.patch.dict(os.environ, {"OB_PASSWORD": "ob_password"}) + def test_init_oceanbase_config(self): + expect_password = "ob_password" + ob_config = OceanBaseConfig() + assert ob_config.passwd == expect_password + + +class TestOceanBaseVector: + @pytest.fixture + def mock_embedder(self, mocker): + return mocker.Mock() + + @patch("embedchain.vectordb.oceanbase.ObVecClient", autospec=True) + def test_query(self, mock_client, mock_embedder): + ob_config = 
OceanBaseConfig(drop_old=True) + ob = OceanBaseVectorDB(config=ob_config) + ob.embedder = mock_embedder + + with patch.object(ob.client, "ann_search") as mock_search: + mock_embedder.embedding_fn.return_value = ["query_vector"] + mock_search.return_value = [("result_doc", '{"url": "url_1", "doc_id": "doc_id_1"}', 0.0)] + + query_result = ob.query(input_query="query_text", n_results=1, where={}) + + assert query_result == ["result_doc"] + + query_result = ob.query(input_query="query_text", n_results=1, where={}, citations=True) + + assert query_result == [("result_doc", {"url": "url_1", "doc_id": "doc_id_1", "score": 1.0})]