Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: generalize creating objects from config #233

Merged
merged 7 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions docs/how-to/document_search/use_reranker.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,4 @@ class CustomReranker(Reranker):
options: RerankerOptions | None = None,
) -> Sequence[Element]:
pass

@classmethod
def from_config(cls, config: dict) -> "CustomReranker":
pass
```
24 changes: 0 additions & 24 deletions packages/ragbits-core/src/ragbits/core/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,4 @@
import sys

from ragbits.core.utils.config_handling import get_cls_from_config

from .base import Embeddings, EmbeddingType
from .noop import NoopEmbeddings

__all__ = ["EmbeddingType", "Embeddings", "NoopEmbeddings"]

module = sys.modules[__name__]


def get_embeddings(embedder_config: dict) -> Embeddings:
    """
    Build an Embeddings object described by the given embedder configuration.

    Args:
        embedder_config: A dictionary with a "type" entry naming the class to construct and
            an optional "config" entry holding the keyword arguments for its constructor.

    Returns:
        An instance of the resolved class, created with the entries of "config" as keyword
        arguments, or without arguments when "config" is absent.

    Raises:
        KeyError: If the configuration dictionary does not contain a "type" key.
    """
    cls_path = embedder_config["type"]
    init_kwargs = embedder_config.get("config", {})

    # The class is looked up in this package unless the type string points elsewhere.
    return get_cls_from_config(cls_path, module)(**init_kwargs)
8 changes: 7 additions & 1 deletion packages/ragbits-core/src/ragbits/core/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import ClassVar

from ragbits.core import embeddings
from ragbits.core.utils.config_handling import WithConstructionConfig


class EmbeddingType(Enum):
Expand All @@ -17,11 +21,13 @@ class EmbeddingType(Enum):
IMAGE: str = "image"


class Embeddings(ABC):
class Embeddings(WithConstructionConfig, ABC):
"""
Abstract client for communication with embedding models.
"""

default_module: ClassVar = embeddings

@abstractmethod
async def embed_text(self, data: list[str]) -> list[list[float]]:
"""
Expand Down
36 changes: 0 additions & 36 deletions packages/ragbits-core/src/ragbits/core/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,3 @@
import sys

from ragbits.core.utils.config_handling import get_cls_from_config

from .base import LLM

__all__ = ["LLM"]

module = sys.modules[__name__]


def get_llm(config: dict) -> LLM:
    """
    Initializes and returns an LLM object based on the provided configuration.

    The provided dictionary (including its nested "config" dictionary) is left unmodified.

    Args:
        config: A dictionary containing configuration details for the LLM. The "type" entry names
            the class to construct; the optional "config" entry holds its constructor arguments,
            where "default_options" (if present) is converted to the options class of the LLM.

    Returns:
        An instance of the specified LLM class, initialized with the provided config
        (if any) or default arguments.

    Raises:
        KeyError: If the configuration dictionary does not contain a "type" key.
        ValueError: If the LLM class is not a subclass of LLM.
    """
    llm_type = config["type"]
    # Shallow copy: popping "default_options" below must not alter the caller's configuration,
    # otherwise the same config could not be used to build a second, identical LLM.
    llm_config = dict(config.get("config", {}))
    default_options = llm_config.pop("default_options", None)
    llm_cls = get_cls_from_config(llm_type, module)

    if not issubclass(llm_cls, LLM):
        raise ValueError(f"Invalid LLM class: {llm_cls}")

    # We need to infer the options class from the LLM class.
    # pylint: disable=protected-access
    options = llm_cls._options_cls(**default_options) if default_options else None  # type: ignore

    return llm_cls(**llm_config, default_options=options)
26 changes: 24 additions & 2 deletions packages/ragbits-core/src/ragbits/core/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from functools import cached_property
from typing import Generic, cast, overload
from typing import ClassVar, Generic, cast, overload

from typing_extensions import Self

from ragbits.core import llms
from ragbits.core.prompt.base import BasePrompt, BasePromptWithParser, ChatFormat, OutputT
from ragbits.core.utils.config_handling import WithConstructionConfig

from .clients.base import LLMClient, LLMClientOptions, LLMOptions

Expand All @@ -20,12 +24,13 @@ class LLMType(enum.Enum):
STRUCTURED_OUTPUT = "structured_output"


class LLM(Generic[LLMClientOptions], ABC):
class LLM(WithConstructionConfig, Generic[LLMClientOptions], ABC):
"""
Abstract class for interaction with Large Language Model.
"""

_options_cls: type[LLMClientOptions]
default_module: ClassVar = llms

def __init__(self, model_name: str, default_options: LLMOptions | None = None) -> None:
"""
Expand Down Expand Up @@ -160,3 +165,20 @@ def _format_chat_for_llm(self, prompt: BasePrompt) -> ChatFormat:
if prompt.list_images():
wrngs.warn(message=f"Image input not implemented for {self.__class__.__name__}")
return prompt.chat

@classmethod
def from_config(cls, config: dict) -> Self:
    """
    Initializes the class with the provided configuration.

    The "default_options" entry, if present and non-empty, is converted to an instance of the
    options class declared on the LLM class (`_options_cls`). All remaining entries are passed
    to the constructor as keyword arguments. The provided dictionary is not modified.

    Args:
        config: A dictionary containing configuration details for the class.

    Returns:
        An instance of the class initialized with the provided configuration.
    """
    # Copy before popping: the dictionary belongs to the caller and may be reused afterwards.
    init_kwargs = dict(config)
    default_options = init_kwargs.pop("default_options", None)

    options = cls._options_cls(**default_options) if default_options else None

    return cls(**init_kwargs, default_options=options)
26 changes: 0 additions & 26 deletions packages/ragbits-core/src/ragbits/core/metadata_stores/__init__.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,4 @@
import sys

from ragbits.core.utils.config_handling import get_cls_from_config

from .base import MetadataStore
from .in_memory import InMemoryMetadataStore

__all__ = ["InMemoryMetadataStore", "MetadataStore"]

module = sys.modules[__name__]


def get_metadata_store(metadata_store_config: dict | None) -> MetadataStore | None:
    """
    Build a MetadataStore object described by the given configuration, if there is one.

    Args:
        metadata_store_config: A dictionary with a "type" entry naming the class to construct and
            an optional "config" entry holding the keyword arguments for its constructor, or None.

    Returns:
        None when no configuration is given; otherwise an instance of the resolved class, created
        with the entries of "config" as keyword arguments (no arguments when "config" is absent).
    """
    if metadata_store_config is not None:
        store_cls = get_cls_from_config(metadata_store_config["type"], module)
        init_kwargs = metadata_store_config.get("config", {})
        return store_cls(**init_kwargs)

    # Nothing configured, so there is no store to create.
    return None
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from abc import ABC, abstractmethod
from typing import ClassVar

from ragbits.core import metadata_stores
from ragbits.core.utils.config_handling import WithConstructionConfig

class MetadataStore(ABC):

class MetadataStore(WithConstructionConfig, ABC):
"""
An abstract class for metadata storage. Allows to store, query and retrieve metadata in form of key value pairs.
"""

default_module: ClassVar = metadata_stores

@abstractmethod
async def store(self, ids: list[str], metadatas: list[dict]) -> None:
"""
Expand Down
63 changes: 61 additions & 2 deletions packages/ragbits-core/src/ragbits/core/utils/config_handling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import abc
from importlib import import_module
from types import ModuleType
from typing import Any
from typing import Any, ClassVar

from pydantic import BaseModel
from typing_extensions import Self


class InvalidConfigError(Exception):
Expand All @@ -9,7 +13,7 @@ class InvalidConfigError(Exception):
"""


def get_cls_from_config(cls_path: str, default_module: ModuleType) -> Any: # noqa: ANN401
def get_cls_from_config(cls_path: str, default_module: ModuleType | None) -> Any: # noqa: ANN401
"""
Retrieves and returns a class based on the given type string. The class can be either in the
default module or a specified module if provided in the type string.
Expand All @@ -32,7 +36,62 @@ def get_cls_from_config(cls_path: str, default_module: ModuleType) -> Any: # no
except AttributeError as err:
raise InvalidConfigError(f"Class {object_stringified} not found in module {module_stringified}") from err

if default_module is None:
raise InvalidConfigError("Given type string does not contain a module and no default module provided")

try:
return getattr(default_module, cls_path)
except AttributeError as err:
raise InvalidConfigError(f"Class {cls_path} not found in module {default_module}") from err


class ObjectContructionConfig(BaseModel):
    """
    Describes an object to construct: which class to use and the configuration to build it from.
    """

    # Path to the class to be constructed
    type: str

    # Configuration details handed to the class; empty when not provided
    config: dict[str, Any] = {}


class WithConstructionConfig(abc.ABC):
    """
    A mixin class that provides methods for initializing classes from configuration.
    """

    # The default module to search for the subclass if no specific module is provided in the type string.
    default_module: ClassVar[ModuleType | None] = None

    @classmethod
    def subclass_from_config(cls, config: ObjectContructionConfig) -> Self:
        """
        Initializes the class with the provided configuration. May return a subclass of the class,
        if requested by the configuration.

        Args:
            config: A model containing configuration details for the class.

        Returns:
            An instance of the class initialized with the provided configuration.

        Raises:
            InvalidConfigError: If the object named by the configuration is not a class,
                or is a class that does not derive from this class.
        """
        subclass = get_cls_from_config(config.type, cls.default_module)
        # issubclass() raises TypeError when its first argument is not a class (for example when
        # the type string names a function or a constant), so that case is rejected explicitly
        # and reported as a configuration error like any other wrong type.
        if not (isinstance(subclass, type) and issubclass(subclass, cls)):
            raise InvalidConfigError(f"{subclass} is not a subclass of {cls}")

        return subclass.from_config(config.config)

    @classmethod
    def from_config(cls, config: dict) -> Self:
        """
        Initializes the class with the provided configuration.

        The entries of the dictionary are passed to the constructor as keyword arguments.

        Args:
            config: A dictionary containing configuration details for the class.

        Returns:
            An instance of the class initialized with the provided configuration.
        """
        return cls(**config)
23 changes: 0 additions & 23 deletions packages/ragbits-core/src/ragbits/core/vector_stores/__init__.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,4 @@
import sys

from ragbits.core.utils.config_handling import get_cls_from_config
from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery
from ragbits.core.vector_stores.in_memory import InMemoryVectorStore

__all__ = ["InMemoryVectorStore", "VectorStore", "VectorStoreEntry", "VectorStoreOptions", "WhereQuery"]


def get_vector_store(config: dict) -> VectorStore:
    """
    Build a VectorStore object described by the given configuration.

    Args:
        config: A dictionary with a "type" entry naming the class to construct and an optional
            "config" entry that is handed to the `from_config` method of that class.

    Returns:
        The object returned by `from_config` of the resolved class, called with the "config"
        entry (an empty dictionary when "config" is absent).

    Raises:
        KeyError: If the provided configuration does not contain a "type" key.
        InvalidConfigError: If the class named by "type" cannot be found.
    """
    cls_path = config["type"]
    # Classes given without a module are looked up in this package.
    this_module = sys.modules[__name__]
    vector_store_cls = get_cls_from_config(cls_path, this_module)

    store_config = config.get("config", {})
    return vector_store_cls.from_config(store_config)
31 changes: 22 additions & 9 deletions packages/ragbits-core/src/ragbits/core/vector_stores/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from abc import ABC, abstractmethod
from typing import ClassVar

from pydantic import BaseModel
from typing_extensions import Self

from ragbits.core import vector_stores
from ragbits.core.metadata_stores.base import MetadataStore
from ragbits.core.utils.config_handling import ObjectContructionConfig, WithConstructionConfig

WhereQuery = dict[str, str | int | float | bool]

Expand All @@ -27,11 +31,13 @@ class VectorStoreOptions(BaseModel, ABC):
max_distance: float | None = None


class VectorStore(ABC):
class VectorStore(WithConstructionConfig, ABC):
"""
A class with an implementation of Vector Store, allowing to store and retrieve vectors by similarity function.
"""

default_module: ClassVar = vector_stores

def __init__(
self,
default_options: VectorStoreOptions | None = None,
Expand All @@ -49,20 +55,27 @@ def __init__(
self._metadata_store = metadata_store

@classmethod
def from_config(cls, config: dict) -> Self:
    """
    Initializes the class with the provided configuration.

    Two entries are converted before the constructor is called: "default_options" (if present
    and non-empty) becomes a VectorStoreOptions instance, and "metadata_store" (if present and
    non-empty) is validated as an object construction config and turned into a MetadataStore.
    All remaining entries are passed to the constructor unchanged. The provided dictionary
    is not modified.

    Args:
        config: A dictionary containing configuration details for the class.

    Returns:
        An instance of the class initialized with the provided configuration.
    """
    # Copy before popping: the dictionary belongs to the caller and may be reused afterwards.
    init_kwargs = dict(config)

    default_options = init_kwargs.pop("default_options", None)
    options = VectorStoreOptions(**default_options) if default_options else None

    store_config = init_kwargs.pop("metadata_store", None)
    store = (
        MetadataStore.subclass_from_config(ObjectContructionConfig.model_validate(store_config))
        if store_config
        else None
    )

    return cls(**init_kwargs, default_options=options, metadata_store=store)

@abstractmethod
async def store(self, entries: list[VectorStoreEntry]) -> None:
Expand Down
Loading
Loading