Skip to content

Commit

Permalink
refactor: generalize creating objects from config (#233)
Browse files Browse the repository at this point in the history
  • Loading branch information
ludwiktrammer authored Dec 5, 2024
1 parent ba4a59d commit 0b6e1e1
Show file tree
Hide file tree
Showing 34 changed files with 606 additions and 372 deletions.
4 changes: 0 additions & 4 deletions docs/how-to/document_search/use_reranker.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,4 @@ class CustomReranker(Reranker):
options: RerankerOptions | None = None,
) -> Sequence[Element]:
pass

@classmethod
def from_config(cls, config: dict) -> "CustomReranker":
pass
```
24 changes: 0 additions & 24 deletions packages/ragbits-core/src/ragbits/core/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,5 @@
import sys

from ragbits.core.utils.config_handling import get_cls_from_config

from .base import Embeddings, EmbeddingType
from .litellm import LiteLLMEmbeddings
from .noop import NoopEmbeddings

__all__ = ["EmbeddingType", "Embeddings", "LiteLLMEmbeddings", "NoopEmbeddings"]

module = sys.modules[__name__]


def get_embeddings(embedder_config: dict) -> Embeddings:
    """
    Build an Embeddings object from a configuration dictionary.

    Args:
        embedder_config: Dictionary with a required "type" entry (the class to resolve through
            `get_cls_from_config`, using this module as the default module) and an optional
            "config" entry with the keyword arguments for that class.

    Returns:
        The result of calling the resolved class with the "config" entries as keyword arguments,
        or with no arguments when "config" is absent.

    Raises:
        KeyError: If `embedder_config` has no "type" entry.
    """
    # The tuple is evaluated left to right, so a missing "type" is reported before anything else.
    type_path, constructor_kwargs = embedder_config["type"], embedder_config.get("config", {})
    return get_cls_from_config(type_path, module)(**constructor_kwargs)
8 changes: 7 additions & 1 deletion packages/ragbits-core/src/ragbits/core/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import ClassVar

from ragbits.core import embeddings
from ragbits.core.utils.config_handling import WithConstructionConfig


class EmbeddingType(Enum):
Expand All @@ -17,11 +21,13 @@ class EmbeddingType(Enum):
IMAGE: str = "image"


class Embeddings(ABC):
class Embeddings(WithConstructionConfig, ABC):
"""
Abstract client for communication with embedding models.
"""

default_module: ClassVar = embeddings

@abstractmethod
async def embed_text(self, data: list[str]) -> list[list[float]]:
"""
Expand Down
36 changes: 0 additions & 36 deletions packages/ragbits-core/src/ragbits/core/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,4 @@
import sys

from ragbits.core.utils.config_handling import get_cls_from_config

from .base import LLM
from .litellm import LiteLLM

__all__ = ["LLM", "LiteLLM"]

module = sys.modules[__name__]


def get_llm(config: dict) -> LLM:
    """
    Initializes and returns an LLM object based on the provided configuration.

    The "config" entry of `config` holds the constructor keyword arguments. Its optional
    "default_options" entry is converted into an instance of the LLM class's `_options_cls`
    before being passed on. Neither `config` nor its nested "config" dictionary is modified.

    Args:
        config: A dictionary containing configuration details for the LLM: a required "type"
            entry and an optional "config" entry.

    Returns:
        An instance of the specified LLM class, initialized with the provided config
        (if any) or default arguments.

    Raises:
        KeyError: If the configuration dictionary does not contain a "type" key.
        ValueError: If the resolved object is not a class, or is not a subclass of LLM.
    """
    llm_type = config["type"]
    # Work on a shallow copy: popping "default_options" straight from the caller's dictionary
    # would make a second call with the same configuration silently lose the default options.
    llm_config = dict(config.get("config", {}))
    default_options = llm_config.pop("default_options", None)
    llm_cls = get_cls_from_config(llm_type, module)

    # issubclass() raises TypeError for non-class objects, so rule those out first and report
    # them through the documented ValueError.
    if not (isinstance(llm_cls, type) and issubclass(llm_cls, LLM)):
        raise ValueError(f"Invalid LLM class: {llm_cls}")

    # We need to infer the options class from the LLM class.
    # pylint: disable=protected-access
    options = llm_cls._options_cls(**default_options) if default_options else None  # type: ignore

    return llm_cls(**llm_config, default_options=options)
26 changes: 24 additions & 2 deletions packages/ragbits-core/src/ragbits/core/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from functools import cached_property
from typing import Generic, cast, overload
from typing import ClassVar, Generic, cast, overload

from typing_extensions import Self

from ragbits.core import llms
from ragbits.core.prompt.base import BasePrompt, BasePromptWithParser, ChatFormat, OutputT
from ragbits.core.utils.config_handling import WithConstructionConfig

from .clients.base import LLMClient, LLMClientOptions, LLMOptions

Expand All @@ -20,12 +24,13 @@ class LLMType(enum.Enum):
STRUCTURED_OUTPUT = "structured_output"


class LLM(Generic[LLMClientOptions], ABC):
class LLM(WithConstructionConfig, Generic[LLMClientOptions], ABC):
"""
Abstract class for interaction with Large Language Model.
"""

_options_cls: type[LLMClientOptions]
default_module: ClassVar = llms

def __init__(self, model_name: str, default_options: LLMOptions | None = None) -> None:
"""
Expand Down Expand Up @@ -160,3 +165,20 @@ def _format_chat_for_llm(self, prompt: BasePrompt) -> ChatFormat:
if prompt.list_images():
wrngs.warn(message=f"Image input not implemented for {self.__class__.__name__}")
return prompt.chat

@classmethod
def from_config(cls, config: dict) -> Self:
"""
Initializes the class with the provided configuration.
Args:
config: A dictionary containing configuration details for the class.
Returns:
An instance of the class initialized with the provided configuration.
"""
default_options = config.pop("default_options", None)

options = cls._options_cls(**default_options) if default_options else None

return cls(**config, default_options=options)
26 changes: 0 additions & 26 deletions packages/ragbits-core/src/ragbits/core/metadata_stores/__init__.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,4 @@
import sys

from ragbits.core.utils.config_handling import get_cls_from_config

from .base import MetadataStore
from .in_memory import InMemoryMetadataStore

__all__ = ["InMemoryMetadataStore", "MetadataStore"]

module = sys.modules[__name__]


def get_metadata_store(metadata_store_config: dict | None) -> MetadataStore | None:
    """
    Build a MetadataStore from a configuration dictionary, if one is given.

    Args:
        metadata_store_config: Either None, or a dictionary with a required "type" entry (the
            class to resolve through `get_cls_from_config`, using this module as the default
            module) and an optional "config" entry with the keyword arguments for that class.

    Returns:
        None when no configuration is given; otherwise the result of calling the resolved class
        with the "config" entries as keyword arguments (no arguments when "config" is absent).

    Raises:
        KeyError: If a configuration is given but has no "type" entry.
    """
    if metadata_store_config is not None:
        store_cls = get_cls_from_config(metadata_store_config["type"], module)
        return store_cls(**metadata_store_config.get("config", {}))
    return None
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from abc import ABC, abstractmethod
from typing import ClassVar

from ragbits.core import metadata_stores
from ragbits.core.utils.config_handling import WithConstructionConfig

class MetadataStore(ABC):

class MetadataStore(WithConstructionConfig, ABC):
"""
An abstract class for metadata storage. Allows to store, query and retrieve metadata in form of key value pairs.
"""

default_module: ClassVar = metadata_stores

@abstractmethod
async def store(self, ids: list[str], metadatas: list[dict]) -> None:
"""
Expand Down
69 changes: 67 additions & 2 deletions packages/ragbits-core/src/ragbits/core/utils/config_handling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import abc
from importlib import import_module
from types import ModuleType
from typing import Any
from typing import Any, ClassVar

from pydantic import BaseModel
from typing_extensions import Self


class InvalidConfigError(Exception):
Expand All @@ -9,7 +13,7 @@ class InvalidConfigError(Exception):
"""


def get_cls_from_config(cls_path: str, default_module: ModuleType) -> Any: # noqa: ANN401
def get_cls_from_config(cls_path: str, default_module: ModuleType | None) -> Any: # noqa: ANN401
"""
Retrieves and returns a class based on the given type string. The class can be either in the
default module or a specified module if provided in the type string.
Expand All @@ -23,6 +27,9 @@ def get_cls_from_config(cls_path: str, default_module: ModuleType) -> Any: # no
Returns:
Any: The object retrieved from the specified or default module.
Raises:
InvalidConfigError: The requested class is not found under the specified module
"""
if ":" in cls_path:
try:
Expand All @@ -32,7 +39,65 @@ def get_cls_from_config(cls_path: str, default_module: ModuleType) -> Any: # no
except AttributeError as err:
raise InvalidConfigError(f"Class {object_stringified} not found in module {module_stringified}") from err

if default_module is None:
raise InvalidConfigError("Given type string does not contain a module and no default module provided")

try:
return getattr(default_module, cls_path)
except AttributeError as err:
raise InvalidConfigError(f"Class {cls_path} not found in module {default_module}") from err


class ObjectContructionConfig(BaseModel):
    """
    A model for object construction configuration.

    Pairs the path of a class (`type`) with the configuration used to build it (`config`).

    NOTE(review): the class name is spelled "Contruction" (missing "s"). Other modules import it
    under this spelling, so renaming it would need a coordinated change.
    """

    # Path to the class to be constructed. A ":" in the string selects an explicit module instead
    # of the default one (presumably "module.path:ClassName" - confirm against get_cls_from_config).
    type: str

    # Configuration details for the class; an empty dictionary when not provided.
    config: dict[str, Any] = {}


class WithConstructionConfig(abc.ABC):
    """
    A mixin class that provides methods for initializing classes from configuration.

    Subclasses may set `default_module` so that type strings without an explicit module can be
    resolved, and may override `from_config` when their constructor needs more than plain
    keyword arguments.
    """

    # The default module to search for the subclass if no specific module is provided in the type string.
    default_module: ClassVar[ModuleType | None] = None

    @classmethod
    def subclass_from_config(cls, config: ObjectContructionConfig) -> Self:
        """
        Initializes the class with the provided configuration. May return a subclass of the class,
        if requested by the configuration.

        Args:
            config: A model containing configuration details for the class.

        Returns:
            An instance of the class initialized with the provided configuration.

        Raises:
            InvalidConfigError: The class can't be found, the resolved object is not a class,
                or it is not a subclass of the current class.
        """
        subclass = get_cls_from_config(config.type, cls.default_module)
        # issubclass() raises TypeError when its first argument is not a class (for example when
        # the type string resolves to a function or a constant). Check that first so the problem
        # is reported as a configuration error, like every other lookup failure.
        if not (isinstance(subclass, type) and issubclass(subclass, cls)):
            raise InvalidConfigError(f"{subclass} is not a subclass of {cls}")

        return subclass.from_config(config.config)

    @classmethod
    def from_config(cls, config: dict) -> Self:
        """
        Initializes the class with the provided configuration.

        Args:
            config: A dictionary containing configuration details for the class. Its entries are
                passed to the constructor as keyword arguments.

        Returns:
            An instance of the class initialized with the provided configuration.
        """
        return cls(**config)
23 changes: 0 additions & 23 deletions packages/ragbits-core/src/ragbits/core/vector_stores/__init__.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,4 @@
import sys

from ragbits.core.utils.config_handling import get_cls_from_config
from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery
from ragbits.core.vector_stores.in_memory import InMemoryVectorStore

__all__ = ["InMemoryVectorStore", "VectorStore", "VectorStoreEntry", "VectorStoreOptions", "WhereQuery"]


def get_vector_store(config: dict) -> VectorStore:
    """
    Build a VectorStore from a configuration dictionary.

    Args:
        config: Dictionary with a required "type" entry (the class to resolve through
            `get_cls_from_config`, using this module as the default module) and an optional
            "config" entry that is handed to the resolved class's `from_config`.

    Returns:
        Whatever `from_config` of the resolved class returns for the "config" entry
        (an empty dictionary is passed when "config" is absent).

    Raises:
        KeyError: If `config` has no "type" entry.
        InvalidConfigError: If the requested class is not found in the searched module.
    """
    type_path = config["type"]
    store_cls = get_cls_from_config(type_path, sys.modules[__name__])
    build = store_cls.from_config
    return build(config.get("config", {}))
31 changes: 24 additions & 7 deletions packages/ragbits-core/src/ragbits/core/vector_stores/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from abc import ABC, abstractmethod
from typing import ClassVar

from pydantic import BaseModel
from typing_extensions import Self

from ragbits.core import vector_stores
from ragbits.core.metadata_stores.base import MetadataStore
from ragbits.core.utils.config_handling import ObjectContructionConfig, WithConstructionConfig

WhereQuery = dict[str, str | int | float | bool]

Expand All @@ -27,11 +31,13 @@ class VectorStoreOptions(BaseModel, ABC):
max_distance: float | None = None


class VectorStore(ABC):
class VectorStore(WithConstructionConfig, ABC):
"""
A class with an implementation of Vector Store, allowing to store and retrieve vectors by similarity function.
"""

default_module: ClassVar = vector_stores

def __init__(
self,
default_options: VectorStoreOptions | None = None,
Expand All @@ -49,20 +55,31 @@ def __init__(
self._metadata_store = metadata_store

@classmethod
def from_config(cls, config: dict) -> "VectorStore":
def from_config(cls, config: dict) -> Self:
"""
Creates and returns an instance of the Reranker class from the given configuration.
Initializes the class with the provided configuration.
Args:
config: A dictionary containing the configuration for initializing the Reranker instance.
config: A dictionary containing configuration details for the class.
Returns:
An initialized instance of the Reranker class.
An instance of the class initialized with the provided configuration.
Raises:
NotImplementedError: If the class cannot be created from the provided configuration.
ValidationError: The metadata_store configuration doesn't follow the expected format.
InvalidConfigError: The metadata_store class can't be found or is not the correct type.
"""
raise NotImplementedError(f"Cannot create class {cls.__name__} from config.")
default_options = config.pop("default_options", None)
options = VectorStoreOptions(**default_options) if default_options else None

store_config = config.pop("metadata_store", None)
store = (
MetadataStore.subclass_from_config(ObjectContructionConfig.model_validate(store_config))
if store_config
else None
)

return cls(**config, default_options=options, metadata_store=store)

@abstractmethod
async def store(self, entries: list[VectorStoreEntry]) -> None:
Expand Down
Loading

0 comments on commit 0b6e1e1

Please sign in to comment.