Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: option for classes to create instances from factory path #240

Merged
merged 3 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/evaluation/document-search/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import hydra
from omegaconf import DictConfig, OmegaConf

from ragbits.core.utils.config_handling import get_cls_from_config
from ragbits.core.utils.config_handling import import_by_path
from ragbits.evaluate.loaders import dataloader_factory
from ragbits.evaluate.metrics import metric_set_factory
from ragbits.evaluate.optimizer import Optimizer
Expand All @@ -21,12 +21,12 @@ def main(config: DictConfig) -> None:
config: Hydra configuration.
"""
dataloader = dataloader_factory(config.data)
pipeline_class = get_cls_from_config(config.pipeline.type, module)
pipeline_class = import_by_path(config.pipeline.type, module)
metrics = metric_set_factory(config.metrics)
callback_configurators = None
if getattr(config, "callbacks", None):
callback_configurators = [
get_cls_from_config(callback_cfg.type, module)(callback_cfg.args) for callback_cfg in config.callbacks
import_by_path(callback_cfg.type, module)(callback_cfg.args) for callback_cfg in config.callbacks
]

optimization_cfg = OmegaConf.create({"direction": "maximize", "n_trials": 10})
Expand Down
6 changes: 2 additions & 4 deletions packages/ragbits-core/src/ragbits/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ragbits.cli.app import CLI
from ragbits.core.config import core_config
from ragbits.core.llms.base import LLMType
from ragbits.core.llms.base import LLM, LLMType
from ragbits.core.prompt.prompt import ChatFormat, Prompt


Expand Down Expand Up @@ -91,13 +91,11 @@ def execute(
Raises:
ValueError: If `llm_factory` is not provided.
"""
from ragbits.core.llms.factory import get_llm_from_factory

prompt = _render(prompt_path=prompt_path, payload=payload)

if llm_factory is None:
raise ValueError("`llm_factory` must be provided")
llm = get_llm_from_factory(llm_factory)
llm: LLM = LLM.subclass_from_factory(llm_factory)

llm_output = asyncio.run(llm.generate(prompt))
response = LLMResponseCliOutput(question=prompt.chat, answer=llm_output)
Expand Down
6 changes: 3 additions & 3 deletions packages/ragbits-core/src/ragbits/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ class CoreConfig(BaseModel):

# Path to a functions that returns LLM objects, e.g. "my_project.llms.get_llm"
default_llm_factories: dict[LLMType, str] = {
LLMType.TEXT: "ragbits.core.llms.factory.simple_litellm_factory",
LLMType.VISION: "ragbits.core.llms.factory.simple_litellm_vision_factory",
LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory.simple_litellm_structured_output_factory",
LLMType.TEXT: "ragbits.core.llms.factory:simple_litellm_factory",
LLMType.VISION: "ragbits.core.llms.factory:simple_litellm_vision_factory",
LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory:simple_litellm_structured_output_factory",
}


Expand Down
20 changes: 1 addition & 19 deletions packages/ragbits-core/src/ragbits/core/llms/factory.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,8 @@
import importlib

from ragbits.core.config import core_config
from ragbits.core.llms.base import LLM, LLMType
from ragbits.core.llms.litellm import LiteLLM


def get_llm_from_factory(factory_path: str) -> LLM:
"""
Get an instance of an LLM using a factory function specified by the user.

Args:
factory_path (str): The path to the factory function.

Returns:
LLM: An instance of the LLM class.
"""
module_name, function_name = factory_path.rsplit(".", 1)
module = importlib.import_module(module_name)
function = getattr(module, function_name)
return function()


def get_default_llm(llm_type: LLMType = LLMType.TEXT) -> LLM:
"""
Get an instance of the default LLM using the factory function
Expand All @@ -34,7 +16,7 @@ def get_default_llm(llm_type: LLMType = LLMType.TEXT) -> LLM:

"""
factory = core_config.default_llm_factories[llm_type]
return get_llm_from_factory(factory)
return LLM.subclass_from_factory(factory)


def simple_litellm_factory() -> LLM:
Expand Down
3 changes: 1 addition & 2 deletions packages/ragbits-core/src/ragbits/core/prompt/lab/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from ragbits.core.config import core_config
from ragbits.core.llms import LLM
from ragbits.core.llms.base import LLMType
from ragbits.core.llms.factory import get_llm_from_factory
from ragbits.core.prompt import Prompt
from ragbits.core.prompt.discovery import PromptDiscovery

Expand Down Expand Up @@ -166,7 +165,7 @@ def lab_app( # pylint: disable=missing-param-doc
prompts_state = gr.State(
PromptState(
prompts=list(prompts),
llm=get_llm_from_factory(llm_factory) if llm_factory else None,
llm=LLM.subclass_from_factory(llm_factory) if llm_factory else None,
)
)

Expand Down
55 changes: 39 additions & 16 deletions packages/ragbits-core/src/ragbits/core/utils/config_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,39 +13,39 @@ class InvalidConfigError(Exception):
"""


def get_cls_from_config(cls_path: str, default_module: ModuleType | None) -> Any: # noqa: ANN401
def import_by_path(path: str, default_module: ModuleType | None) -> Any: # noqa: ANN401
ludwiktrammer marked this conversation as resolved.
Show resolved Hide resolved
"""
Retrieves and returns a class based on the given type string. The class can be either in the
default module or a specified module if provided in the type string.
Retrieves and returns an object based on the string in the format of "module.submodule:object_name".
If the first part is ommited, the default module is used.

Args:
cls_path: A string representing the path to the class or object. This can either be a
path implicitly referencing the default module or a full path (module.submodule:ClassName)
if the class is located in a different module.
default_module: The default module to search for the class if no specific module
is provided in the type string.
path: A string representing the path to the object. This can either be a
path implicitly referencing the default module or a full path (module.submodule:object_name)
if the object is located in a different module.
default_module: The default module to search for the object if no specific module
is provided in the path string.

Returns:
Any: The object retrieved from the specified or default module.

Raises:
InvalidConfigError: The requested class is not found under the specified module
InvalidConfigError: The requested object is not found under the specified module
"""
if ":" in cls_path:
if ":" in path:
try:
module_stringified, object_stringified = cls_path.split(":")
module_stringified, object_stringified = path.split(":")
module = import_module(module_stringified)
return getattr(module, object_stringified)
except AttributeError as err:
raise InvalidConfigError(f"Class {object_stringified} not found in module {module_stringified}") from err
raise InvalidConfigError(f"{object_stringified} not found in module {module_stringified}") from err

if default_module is None:
raise InvalidConfigError("Given type string does not contain a module and no default module provided")
raise InvalidConfigError("Not provided a full path and no default module specified")

try:
return getattr(default_module, cls_path)
return getattr(default_module, path)
except AttributeError as err:
raise InvalidConfigError(f"Class {cls_path} not found in module {default_module}") from err
raise InvalidConfigError(f"{path} not found in module {default_module}") from err


class ObjectContructionConfig(BaseModel):
Expand Down Expand Up @@ -83,12 +83,35 @@ def subclass_from_config(cls, config: ObjectContructionConfig) -> Self:
Raises:
InvalidConfigError: The class can't be found or is not a subclass of the current class.
"""
subclass = get_cls_from_config(config.type, cls.default_module)
subclass = import_by_path(config.type, cls.default_module)
if not issubclass(subclass, cls):
raise InvalidConfigError(f"{subclass} is not a subclass of {cls}")

return subclass.from_config(config.config)

@classmethod
def subclass_from_factory(cls, factory_path: str) -> Self:
"""
Creates the class using the provided factory function. May return a subclass of the class,
if requested by the factory.

Args:
factory_path: A string representing the path to the factory function
in the format of "module.submodule:factory_name".

Returns:
An instance of the class initialized with the provided factory function.

Raises:
InvalidConfigError: The factory can't be found or the object returned
is not a subclass of the current class.
"""
factory = import_by_path(factory_path, cls.default_module)
obj = factory()
if not isinstance(obj, cls):
raise InvalidConfigError(f"The object returned by factory {factory_path} is not an instance of {cls}")
return obj

@classmethod
def from_config(cls, config: dict) -> Self:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ragbits.core.audit import traceable
from ragbits.core.metadata_stores.base import MetadataStore
from ragbits.core.utils.config_handling import ObjectContructionConfig, get_cls_from_config
from ragbits.core.utils.config_handling import ObjectContructionConfig, import_by_path
from ragbits.core.utils.dict_transformations import flatten_dict, unflatten_dict
from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery

Expand Down Expand Up @@ -59,7 +59,7 @@ def from_config(cls, config: dict) -> Self:
InvalidConfigError: The client or metadata_store class can't be found or is not the correct type.
"""
client_options = ObjectContructionConfig.model_validate(config["client"])
client_cls = get_cls_from_config(client_options.type, chromadb)
client_cls = import_by_path(client_options.type, chromadb)
config["client"] = client_cls(**client_options.config)
return super().from_config(config)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ragbits.core.audit import traceable
from ragbits.core.metadata_stores.base import MetadataStore
from ragbits.core.utils.config_handling import ObjectContructionConfig, get_cls_from_config
from ragbits.core.utils.config_handling import ObjectContructionConfig, import_by_path
from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions


Expand Down Expand Up @@ -56,7 +56,7 @@ def from_config(cls, config: dict) -> Self:
InvalidConfigError: The client or metadata_store class can't be found or is not the correct type.
"""
client_options = ObjectContructionConfig.model_validate(config["client"])
client_cls = get_cls_from_config(client_options.type, qdrant_client)
client_cls = import_by_path(client_options.type, qdrant_client)
config["client"] = client_cls(**client_options.config)
return super().from_config(config)

Expand Down
5 changes: 0 additions & 5 deletions packages/ragbits-core/tests/unit/llms/factory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +0,0 @@
import sys
from pathlib import Path

# Add "llms" to sys.path
sys.path.append(str(Path(__file__).parent.parent))
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,22 @@
from ragbits.core.llms.litellm import LiteLLM


def mock_llm_factory() -> LiteLLM:
"""
A mock LLM factory that creates a LiteLLM instance with a mock model name.

Returns:
LiteLLM: An instance of the LiteLLM.
"""
return LiteLLM(model_name="mock_model")


def test_get_default_llm(monkeypatch: pytest.MonkeyPatch) -> None:
"""
Test the get_llm_from_factory function.
"""
monkeypatch.setattr(
core_config, "default_llm_factories", {LLMType.TEXT: "factory.test_get_llm_from_factory.mock_llm_factory"}
core_config, "default_llm_factories", {LLMType.TEXT: "unit.llms.factory.test_get_default_llm:mock_llm_factory"}
)

llm = get_default_llm()
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def test_get_config_instance_factories():
)

assert config.default_llm_factories == {
LLMType.TEXT: "ragbits.core.llms.factory.simple_litellm_factory",
LLMType.VISION: "ragbits.core.llms.factory.simple_litellm_vision_factory",
LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory.simple_litellm_vision_factory",
LLMType.TEXT: "ragbits.core.llms.factory:simple_litellm_factory",
LLMType.VISION: "ragbits.core.llms.factory:simple_litellm_vision_factory",
LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory:simple_litellm_vision_factory",
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
name = "bad_factory_project"

[tool.ragbits.core.default_llm_factories]
non_existing = "ragbits.core.llms.factory.simple_litellm_factory"
vision = "ragbits.core.llms.factory.simple_litellm_vision_factory"
structured_output = "ragbits.core.llms.factory.simple_litellm_vision_factory"
non_existing = "ragbits.core.llms.factory:simple_litellm_factory"
vision = "ragbits.core.llms.factory:simple_litellm_vision_factory"
micpst marked this conversation as resolved.
Show resolved Hide resolved
structured_output = "ragbits.core.llms.factory:simple_litellm_vision_factory"
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
name = "factory_project"

[tool.ragbits.core.default_llm_factories]
text = "ragbits.core.llms.factory.simple_litellm_factory"
vision = "ragbits.core.llms.factory.simple_litellm_vision_factory"
structured_output = "ragbits.core.llms.factory.simple_litellm_vision_factory"
text = "ragbits.core.llms.factory:simple_litellm_factory"
vision = "ragbits.core.llms.factory:simple_litellm_vision_factory"
micpst marked this conversation as resolved.
Show resolved Hide resolved
structured_output = "ragbits.core.llms.factory:simple_litellm_vision_factory"
16 changes: 16 additions & 0 deletions packages/ragbits-core/tests/unit/utils/test_config_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def __init__(self, foo: str, bar: int) -> None:
self.bar = bar


def example_factory() -> ExampleClassWithConfigMixin:
return ExampleSubclass("aligator", 42)


def test_defacult_from_config():
config = {"foo": "foo", "bar": 1}
instance = ExampleClassWithConfigMixin.from_config(config)
Expand Down Expand Up @@ -62,3 +66,15 @@ def test_no_default_module():
)
with pytest.raises(InvalidConfigError):
ExampleWithNoDefaultModule.subclass_from_config(config)


def test_subclass_from_factory():
instance = ExampleClassWithConfigMixin.subclass_from_factory("unit.utils.test_config_handling:example_factory")
assert isinstance(instance, ExampleSubclass)
assert instance.foo == "aligator"
assert instance.bar == 42


def test_subclass_from_factory_incorrect_class():
with pytest.raises(InvalidConfigError):
ExampleWithNoDefaultModule.subclass_from_factory("unit.utils.test_config_handling:example_factory")
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import BaseModel

from ragbits.core.prompt.prompt import Prompt
from ragbits.core.utils.config_handling import get_cls_from_config
from ragbits.core.utils.config_handling import import_by_path

module = sys.modules[__name__]

Expand Down Expand Up @@ -46,7 +46,7 @@ def get_rephraser_prompt(prompt: str) -> type[Prompt[QueryRephraserInput, Any]]:
Raises:
ValueError: If the prompt class is not a subclass of `Prompt`.
"""
prompt_cls = get_cls_from_config(prompt, module)
prompt_cls = import_by_path(prompt, module)

if not issubclass(prompt_cls, Prompt):
raise ValueError(f"Invalid rephraser prompt class: {prompt_cls}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import neptune

from ragbits.core.utils.config_handling import get_cls_from_config
from ragbits.core.utils.config_handling import import_by_path

from .base import CallbackConfigurator

Expand All @@ -21,6 +21,6 @@ def get_callback(self) -> Callable:
Returns:
Callable: configured neptune callback
"""
callback_class = get_cls_from_config(self.config.callback_type, module)
callback_class = import_by_path(self.config.callback_type, module)
run = neptune.init_run(project=self.config.project)
return callback_class(run)
Loading
Loading