diff --git a/docs/cli/main.md b/docs/cli/main.md new file mode 100644 index 000000000..aef6d9e29 --- /dev/null +++ b/docs/cli/main.md @@ -0,0 +1,11 @@ +# Ragbits CLI + +Ragbits comes with a command line interface (CLI) that provides a number of commands for working with the Ragbits platform. It can be accessed by running the `ragbits` command in your terminal. + +::: mkdocs-click + :module: ragbits.cli + :command: _click_app + :prog_name: ragbits + :style: table + :list_subcommands: true + :depth: 1 \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 75f9d05aa..89b163158 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -29,6 +29,8 @@ nav: - how-to/evaluate/custom_evaluation_pipeline.md - how-to/evaluate/custom_metric.md - how-to/evaluate/custom_dataloader.md + - CLI: + - cli/main.md - API Reference: - Core: - api_reference/core/prompt.md @@ -41,6 +43,8 @@ nav: - Ingestion: - api_reference/document_search/processing.md - api_reference/document_search/execution_strategies.md +hooks: + - mkdocs_hooks.py theme: name: material icon: @@ -69,6 +73,8 @@ theme: - navigation.top - content.code.annotate - content.code.copy + - toc.integrate + - toc.follow extra_css: - stylesheets/extra.css markdown_extensions: @@ -94,6 +100,7 @@ markdown_extensions: alternate_style: true - toc: permalink: "#" + - mkdocs-click plugins: - search - autorefs diff --git a/mkdocs_hooks.py b/mkdocs_hooks.py new file mode 100644 index 000000000..5881b3856 --- /dev/null +++ b/mkdocs_hooks.py @@ -0,0 +1,14 @@ +from typing import Literal + +from ragbits import cli + + +def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool) -> None: + """ + Hook that runs during mkdocs startup. + + Args: + command: The command that is being run. + dirty: whether --dirty flag was passed. 
+ """ + cli._init_for_mkdocs() diff --git a/packages/ragbits-cli/src/ragbits/cli/__init__.py b/packages/ragbits-cli/src/ragbits/cli/__init__.py index 2fb825409..ec70a3c38 100644 --- a/packages/ragbits-cli/src/ragbits/cli/__init__.py +++ b/packages/ragbits-cli/src/ragbits/cli/__init__.py @@ -3,32 +3,39 @@ from pathlib import Path from typing import Annotated +import click import typer +from typer.main import get_command import ragbits -from .app import CLI, OutputType +from .state import OutputType, cli_state, print_output -app = CLI(no_args_is_help=True) +__all__ = [ + "OutputType", + "app", + "cli_state", + "print_output", +] + +app = typer.Typer(no_args_is_help=True) +_click_app: click.Command | None = None # initialized in the `_init_for_mkdocs` function @app.callback() -def output_type( +def ragbits_cli( # `OutputType.text.value` used as a workaround for the issue with `typer.Option` not accepting Enum values output: Annotated[ OutputType, typer.Option("--output", "-o", help="Set the output type (text or json)") ] = OutputType.text.value, # type: ignore ) -> None: - """Sets an output type for the CLI - Args: - output: type of output to be set - """ - app.set_output_type(output_type=output) + """Common CLI arguments for all ragbits commands.""" + cli_state.output_type = output -def main() -> None: +def autoregister() -> None: """ - Main entry point for the CLI. + Autodiscover and register all the CLI modules in the ragbits packages. This function registers all the CLI modules in the ragbits packages: - iterates over every package in the ragbits.* namespace @@ -46,4 +53,23 @@ def main() -> None: register_func = importlib.import_module(f"ragbits.{module.name}.cli").register register_func(app) + +def _init_for_mkdocs() -> None: + """ + Initializes the CLI app for the mkdocs environment. + + This function registers all the CLI commands and sets the `_click_app` variable to a click + command object containing all the CLI commands. 
This way the `mkdocs-click` plugin can + create an automatic CLI documentation. + """ + global _click_app # noqa: PLW0603 + autoregister() + _click_app = get_command(app) + + +def main() -> None: + """ + Main entry point for the CLI. Registers all the CLI commands and runs the app. + """ + autoregister() app() diff --git a/packages/ragbits-cli/src/ragbits/cli/app.py b/packages/ragbits-cli/src/ragbits/cli/app.py deleted file mode 100644 index 37a29067f..000000000 --- a/packages/ragbits-cli/src/ragbits/cli/app.py +++ /dev/null @@ -1,76 +0,0 @@ -import json -from dataclasses import dataclass -from enum import Enum -from typing import Any - -import typer -from pydantic import BaseModel -from rich.console import Console -from rich.table import Table - - -class OutputType(Enum): - """Indicates a type of CLI output formatting""" - - text = "text" - json = "json" - - -@dataclass() -class CliState: - """A dataclass describing CLI state""" - - output_type: OutputType = OutputType.text - - -class CLI(typer.Typer): - """A CLI class with output formatting""" - - def __init__(self, *args: Any, **kwargs: Any): # noqa: ANN401 - super().__init__(*args, **kwargs) - self.state: CliState = CliState() - self.console: Console = Console() - - def set_output_type(self, output_type: OutputType) -> None: - """ - Set the output type in the app state - Args: - output_type: OutputType - """ - self.state.output_type = output_type - - def print_output(self, data: list[BaseModel] | BaseModel) -> None: - """ - Process and display output based on the current state's output type. 
- - Args: - data: list of ditionaries or list of pydantic models representing output of CLI function - """ - if isinstance(data, BaseModel): - data = [data] - if len(data) == 0: - self._print_empty_list() - return - first_el_instance = type(data[0]) - if any(not isinstance(datapoint, first_el_instance) for datapoint in data): - raise ValueError("All the rows need to be of the same type") - data_dicts: list[dict] = [output.model_dump(mode="python") for output in data] - output_type = self.state.output_type - if output_type == OutputType.json: - print(json.dumps(data_dicts, indent=4)) - elif output_type == OutputType.text: - table = Table(show_header=True, header_style="bold magenta") - properties = data[0].model_json_schema()["properties"] - for key in properties: - table.add_column(properties[key]["title"]) - for row in data_dicts: - table.add_row(*[str(value) for value in row.values()]) - self.console.print(table) - else: - raise ValueError(f"Output type: {output_type} not supported") - - def _print_empty_list(self) -> None: - if self.state.output_type == OutputType.text: - print("Empty data list") - elif self.state.output_type == OutputType.json: - print(json.dumps([])) diff --git a/packages/ragbits-cli/src/ragbits/cli/state.py b/packages/ragbits-cli/src/ragbits/cli/state.py new file mode 100644 index 000000000..032d8d283 --- /dev/null +++ b/packages/ragbits-cli/src/ragbits/cli/state.py @@ -0,0 +1,64 @@ +import json +from collections.abc import Sequence +from dataclasses import dataclass +from enum import Enum + +from pydantic import BaseModel +from rich.console import Console +from rich.table import Table + + +class OutputType(Enum): + """Indicates a type of CLI output formatting""" + + text = "text" + json = "json" + + +@dataclass() +class CliState: + """A dataclass describing CLI state""" + + output_type: OutputType = OutputType.text + + +cli_state = CliState() + + +def print_output(data: Sequence[BaseModel] | BaseModel) -> None: + """ + Process and display 
output based on the current state's output type. + + Args: + data: a list of pydantic models representing output of CLI function + """ + console = Console() + if isinstance(data, BaseModel): + data = [data] + if len(data) == 0: + _print_empty_list() + return + first_el_instance = type(data[0]) + if any(not isinstance(datapoint, first_el_instance) for datapoint in data): + raise ValueError("All the rows need to be of the same type") + data_dicts: list[dict] = [output.model_dump(mode="python") for output in data] + output_type = cli_state.output_type + if output_type == OutputType.json: + console.print(json.dumps(data_dicts, indent=4)) + elif output_type == OutputType.text: + table = Table(show_header=True, header_style="bold magenta") + properties = data[0].model_json_schema()["properties"] + for key in properties: + table.add_column(properties[key]["title"]) + for row in data_dicts: + table.add_row(*[str(value) for value in row.values()]) + console.print(table) + else: + raise ValueError(f"Output type: {output_type} not supported") + + +def _print_empty_list() -> None: + if cli_state.output_type == OutputType.text: + print("Empty data list") + elif cli_state.output_type == OutputType.json: + print(json.dumps([])) diff --git a/packages/ragbits-core/src/ragbits/core/cli.py b/packages/ragbits-core/src/ragbits/core/cli.py index 4e962b769..92768e649 100644 --- a/packages/ragbits-core/src/ragbits/core/cli.py +++ b/packages/ragbits-core/src/ragbits/core/cli.py @@ -1,104 +1,15 @@ -# pylint: disable=import-outside-toplevel -# pylint: disable=missing-param-doc -import asyncio -import json -from importlib import import_module -from pathlib import Path - import typer -from pydantic import BaseModel - -from ragbits.cli.app import CLI -from ragbits.core.config import core_config -from ragbits.core.llms.base import LLM, LLMType -from ragbits.core.prompt.prompt import ChatFormat, Prompt - - -def _render(prompt_path: str, payload: str | None) -> Prompt: - module_stringified, 
object_stringified = prompt_path.split(":") - prompt_cls = getattr(import_module(module_stringified), object_stringified) - - if payload is not None: - payload = json.loads(payload) - inputs = prompt_cls.input_type(**payload) - return prompt_cls(inputs) - - return prompt_cls() - - -class LLMResponseCliOutput(BaseModel): - """An output model for llm responses in CLI""" - - question: ChatFormat - answer: str | BaseModel | None = None - -prompts_app = typer.Typer(no_args_is_help=True) +from ragbits.core.prompt._cli import prompts_app +from ragbits.core.vector_stores._cli import vector_stores_app -def register(app: CLI) -> None: +def register(app: typer.Typer) -> None: """ Register the CLI commands for the package. Args: app: The Typer object to register the commands with. """ - - @prompts_app.command() - def lab( - file_pattern: str = core_config.prompt_path_pattern, - llm_factory: str = core_config.default_llm_factories[LLMType.TEXT], - ) -> None: - """ - Launches the interactive application for listing, rendering, and testing prompts - defined within the current project. - """ - from ragbits.core.prompt.lab.app import lab_app - - lab_app(file_pattern=file_pattern, llm_factory=llm_factory) - - @prompts_app.command() - def generate_promptfoo_configs( - file_pattern: str = core_config.prompt_path_pattern, - root_path: Path = Path.cwd(), # noqa: B008 - target_path: Path = Path("promptfooconfigs"), - ) -> None: - """ - Generates the configuration files for the PromptFoo prompts. - """ - from ragbits.core.prompt.promptfoo import generate_configs - - generate_configs(file_pattern=file_pattern, root_path=root_path, target_path=target_path) - - @prompts_app.command() - def render(prompt_path: str, payload: str | None = None) -> None: - """ - Renders a prompt by loading a class from a module and initializing it with a given payload. 
- """ - prompt = _render(prompt_path=prompt_path, payload=payload) - response = LLMResponseCliOutput(question=prompt.chat) - app.print_output(response) - - @prompts_app.command(name="exec") - def execute( - prompt_path: str, - payload: str | None = None, - llm_factory: str | None = core_config.default_llm_factories[LLMType.TEXT], - ) -> None: - """ - Executes a prompt using the specified prompt class and LLM factory. - - Raises: - ValueError: If `llm_factory` is not provided. - """ - prompt = _render(prompt_path=prompt_path, payload=payload) - - if llm_factory is None: - raise ValueError("`llm_factory` must be provided") - llm: LLM = LLM.subclass_from_factory(llm_factory) - - llm_output = asyncio.run(llm.generate(prompt)) - response = LLMResponseCliOutput(question=prompt.chat, answer=llm_output) - app.print_output(response) - app.add_typer(prompts_app, name="prompts", help="Commands for managing prompts") + app.add_typer(vector_stores_app, name="vector-store", help="Commands for managing vector stores") diff --git a/packages/ragbits-core/src/ragbits/core/config.py b/packages/ragbits-core/src/ragbits/core/config.py index 77168f2ac..4601a4507 100644 --- a/packages/ragbits-core/src/ragbits/core/config.py +++ b/packages/ragbits-core/src/ragbits/core/config.py @@ -1,7 +1,10 @@ +from functools import cached_property +from pathlib import Path + from pydantic import BaseModel from ragbits.core.llms.base import LLMType -from ragbits.core.utils._pyproject import get_config_instance +from ragbits.core.utils._pyproject import get_config_from_yaml, get_config_instance class CoreConfig(BaseModel): @@ -9,6 +12,9 @@ class CoreConfig(BaseModel): Configuration for the ragbits-core package, loaded from downstream projects' pyproject.toml files. 
""" + # Path to the base directory of the project, defaults to the directory of the pyproject.toml file + project_base_path: Path | None = None + # Pattern used to search for prompt files prompt_path_pattern: str = "**/prompt_*.py" @@ -19,5 +25,24 @@ class CoreConfig(BaseModel): LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory:simple_litellm_structured_output_factory", } + # Paths to functions that return instances of different types of Ragbits objects + default_factories: dict[str, str] = {} + + # Path to a YAML file with default configuration of various Ragbits objects + default_instaces_config_path: Path | None = None + + @cached_property + def default_instances_config(self) -> dict: + """ + Get the configuration from the file specified in default_instaces_config_path. + + Returns: + dict: The configuration from the file. + """ + if self.default_instaces_config_path is None or not self.project_base_path: + return {} + + return get_config_from_yaml(self.project_base_path / self.default_instaces_config_path) + core_config = get_config_instance(CoreConfig, subproject="core") diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/base.py b/packages/ragbits-core/src/ragbits/core/embeddings/base.py index 460476c95..130113fa2 100644 --- a/packages/ragbits-core/src/ragbits/core/embeddings/base.py +++ b/packages/ragbits-core/src/ragbits/core/embeddings/base.py @@ -27,6 +27,7 @@ class Embeddings(WithConstructionConfig, ABC): """ default_module: ClassVar = embeddings + configuration_key: ClassVar = "embedder" @abstractmethod async def embed_text(self, data: list[str]) -> list[list[float]]: diff --git a/packages/ragbits-core/src/ragbits/core/llms/base.py b/packages/ragbits-core/src/ragbits/core/llms/base.py index 8ff01821f..0f46811e1 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/base.py +++ b/packages/ragbits-core/src/ragbits/core/llms/base.py @@ -31,6 +31,7 @@ class LLM(WithConstructionConfig, Generic[LLMClientOptions], ABC): _options_cls: 
type[LLMClientOptions] default_module: ClassVar = llms + configuration_key: ClassVar = "llm" def __init__(self, model_name: str, default_options: LLMOptions | None = None) -> None: """ diff --git a/packages/ragbits-core/src/ragbits/core/metadata_stores/base.py b/packages/ragbits-core/src/ragbits/core/metadata_stores/base.py index f5a5767b2..6ad011c1a 100644 --- a/packages/ragbits-core/src/ragbits/core/metadata_stores/base.py +++ b/packages/ragbits-core/src/ragbits/core/metadata_stores/base.py @@ -11,6 +11,7 @@ class MetadataStore(WithConstructionConfig, ABC): """ default_module: ClassVar = metadata_stores + configuration_key: ClassVar = "metadata_store" @abstractmethod async def store(self, ids: list[str], metadatas: list[dict]) -> None: diff --git a/packages/ragbits-core/src/ragbits/core/prompt/_cli.py b/packages/ragbits-core/src/ragbits/core/prompt/_cli.py new file mode 100644 index 000000000..605e517ac --- /dev/null +++ b/packages/ragbits-core/src/ragbits/core/prompt/_cli.py @@ -0,0 +1,97 @@ +import asyncio +import json +from importlib import import_module +from pathlib import Path + +import typer +from pydantic import BaseModel + +from ragbits.cli import print_output +from ragbits.core.config import core_config +from ragbits.core.llms.base import LLM, LLMType +from ragbits.core.prompt.prompt import ChatFormat, Prompt + +prompts_app = typer.Typer(no_args_is_help=True) + + +class LLMResponseCliOutput(BaseModel): + """An output model for llm responses in CLI""" + + question: ChatFormat + answer: str | BaseModel | None = None + + +def _render(prompt_path: str, payload: str | None) -> Prompt: + module_stringified, object_stringified = prompt_path.split(":") + prompt_cls = getattr(import_module(module_stringified), object_stringified) + + if payload is not None: + payload = json.loads(payload) + inputs = prompt_cls.input_type(**payload) + return prompt_cls(inputs) + + return prompt_cls() + + +@prompts_app.command() +def lab( + file_pattern: str = 
core_config.prompt_path_pattern, + llm_factory: str = core_config.default_llm_factories[LLMType.TEXT], +) -> None: + """ + Launches the interactive application for listing, rendering, and testing prompts + defined within the current project. + + For more information, see the [Prompts Lab documentation](../how-to/prompts_lab.md). + """ + from ragbits.core.prompt.lab.app import lab_app + + lab_app(file_pattern=file_pattern, llm_factory=llm_factory) + + +@prompts_app.command() +def generate_promptfoo_configs( + file_pattern: str = core_config.prompt_path_pattern, + root_path: Path = Path.cwd(), # noqa: B008 + target_path: Path = Path("promptfooconfigs"), +) -> None: + """ + Generates the configuration files for the PromptFoo prompts. + + For more information, see the [Promptfoo integration documentation](../how-to/integrations/promptfoo.md). + """ + from ragbits.core.prompt.promptfoo import generate_configs + + generate_configs(file_pattern=file_pattern, root_path=root_path, target_path=target_path) + + +@prompts_app.command() +def render(prompt_path: str, payload: str | None = None) -> None: + """ + Renders a prompt by loading a class from a module and initializing it with a given payload. + """ + prompt = _render(prompt_path=prompt_path, payload=payload) + response = LLMResponseCliOutput(question=prompt.chat) + print_output(response) + + +@prompts_app.command(name="exec") +def execute( + prompt_path: str, + payload: str | None = None, + llm_factory: str = core_config.default_llm_factories[LLMType.TEXT], +) -> None: + """ + Executes a prompt using the specified prompt class and LLM factory. + + For an example of how to use this command, see the [Quickstart guide](../quickstart/quickstart1_prompts.md). 
+ """ + prompt = _render(prompt_path=prompt_path, payload=payload) + + if llm_factory is None: + raise ValueError("`llm_factory` must be provided") + llm: LLM = LLM.subclass_from_factory(llm_factory) + + llm_output = asyncio.run(llm.generate(prompt)) + response = LLMResponseCliOutput(question=prompt.chat, answer=llm_output) + print_output(response) diff --git a/packages/ragbits-core/src/ragbits/core/utils/_pyproject.py b/packages/ragbits-core/src/ragbits/core/utils/_pyproject.py index 2c3f9831f..479e33f7a 100644 --- a/packages/ragbits-core/src/ragbits/core/utils/_pyproject.py +++ b/packages/ragbits-core/src/ragbits/core/utils/_pyproject.py @@ -1,12 +1,10 @@ -import enum from pathlib import Path from typing import Any, TypeVar import tomli +import yaml from pydantic import BaseModel -from ragbits.core.llms.base import LLMType - def find_pyproject(current_dir: Path | None = None) -> Path: """ @@ -57,7 +55,13 @@ def get_ragbits_config(current_dir: Path | None = None) -> dict[str, Any]: with pyproject.open("rb") as f: pyproject_data = tomli.load(f) - return pyproject_data.get("tool", {}).get("ragbits", {}) + + config = pyproject_data.get("tool", {}).get("ragbits", {}) + + # Detect project base path from pyproject.toml location + if "project_base_path" not in config: + config["project_base_path"] = str(pyproject.absolute().parent) + return config ConfigModelT = TypeVar("ConfigModelT", bound=BaseModel) @@ -83,16 +87,28 @@ def get_config_instance( config = get_ragbits_config(current_dir) if subproject: - config = config.get(subproject, {}) - if "default_llm_factories" in config: - config["default_llm_factories"] = { - _resolve_enum_member(k): v for k, v in config["default_llm_factories"].items() + config = { + **config.get(subproject, {}), + "project_base_path": config.get("project_base_path"), } - return model(**config) + return model.model_validate(config) -def _resolve_enum_member(enum_string: str) -> enum.Enum: - try: - return LLMType(enum_string) - except ValueError 
as err: - raise ValueError("Unsupported LLMType value provided in default_llm_factories in pyproject.toml") from err +def get_config_from_yaml(yaml_path: Path) -> dict: + """ + Reads a YAML file and returns its content as a dictionary. + + Args: + yaml_path: The path to the YAML file. + + Returns: + dict: The content of the YAML file as a dictionary. + + Raises: + ValueError: If the YAML file does not contain a dictionary. + """ + with open(yaml_path) as file: + obj = yaml.safe_load(file) + if not isinstance(obj, dict): + raise ValueError(f"Expected a dictionary in {yaml_path}") + return obj diff --git a/packages/ragbits-core/src/ragbits/core/utils/config_handling.py b/packages/ragbits-core/src/ragbits/core/utils/config_handling.py index 861ffff41..b69596972 100644 --- a/packages/ragbits-core/src/ragbits/core/utils/config_handling.py +++ b/packages/ragbits-core/src/ragbits/core/utils/config_handling.py @@ -1,11 +1,19 @@ +from __future__ import annotations + import abc from importlib import import_module +from pathlib import Path from types import ModuleType -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar from pydantic import BaseModel from typing_extensions import Self +from ragbits.core.utils._pyproject import get_config_from_yaml + +if TYPE_CHECKING: + from ragbits.core.config import CoreConfig + class InvalidConfigError(Exception): """ @@ -68,6 +76,9 @@ class WithConstructionConfig(abc.ABC): # The default module to search for the subclass if no specific module is provided in the type string. default_module: ClassVar[ModuleType | None] = None + # The key under which the configuration for this class (and its subclasses) can be found. 
+ configuration_key: ClassVar[str] + @classmethod def subclass_from_config(cls, config: ObjectContructionConfig) -> Self: """ @@ -112,6 +123,40 @@ def subclass_from_factory(cls, factory_path: str) -> Self: raise InvalidConfigError(f"The object returned by factory {factory_path} is not an instance of {cls}") return obj + @classmethod + def subclass_from_defaults( + cls, defaults: CoreConfig, factory_path_override: str | None = None, yaml_path_override: Path | None = None + ) -> Self: + """ + Tries to create an instance by looking at default configuration file, and default factory function. + Takes optional overrides for both, which takes a higher precedence. + + Args: + defaults: The CoreConfig instance containing default factory and configuration details. + factory_path_override: A string representing the path to the factory function + in the format of "module.submodule:factory_name". + yaml_path_override: A string representing the path to the YAML file containing + the Ragstack instance configuration. + + Raises: + InvalidConfigError: If the default factory or configuration can't be found. 
+ """ + if yaml_path_override: + config = get_config_from_yaml(yaml_path_override) + if type_config := config.get(cls.configuration_key): + return cls.subclass_from_config(ObjectContructionConfig.model_validate(type_config)) + + if factory_path_override: + return cls.subclass_from_factory(factory_path_override) + + if default_factory := defaults.default_factories.get(cls.configuration_key): + return cls.subclass_from_factory(default_factory) + + if default_config := defaults.default_instances_config.get(cls.configuration_key): + return cls.subclass_from_config(ObjectContructionConfig.model_validate(default_config)) + + raise InvalidConfigError(f"Could not find default factory or configuration for {cls.configuration_key}") + @classmethod def from_config(cls, config: dict) -> Self: """ diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/_cli.py b/packages/ragbits-core/src/ragbits/core/vector_stores/_cli.py new file mode 100644 index 000000000..5bac47daf --- /dev/null +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/_cli.py @@ -0,0 +1,117 @@ +import asyncio +from dataclasses import dataclass +from pathlib import Path + +import typer +from pydantic import BaseModel +from rich.console import Console + +from ragbits.cli import cli_state, print_output +from ragbits.cli.state import OutputType +from ragbits.core.config import core_config +from ragbits.core.embeddings.base import Embeddings +from ragbits.core.utils.config_handling import InvalidConfigError +from ragbits.core.vector_stores.base import VectorStore, VectorStoreOptions + +vector_stores_app = typer.Typer(no_args_is_help=True) + + +@dataclass +class CLIState: + vector_store: VectorStore | None = None + + +state: CLIState = CLIState() + + +@vector_stores_app.callback() +def common_args( + factory_path: str | None = None, + yaml_path: str | None = None, +) -> None: + try: + state.vector_store = VectorStore.subclass_from_defaults( + core_config, + factory_path_override=factory_path, + 
yaml_path_override=Path.cwd() / yaml_path if yaml_path else None, + ) + except InvalidConfigError as e: + Console(stderr=True).print(e) + raise typer.Exit(1) from e + + +@vector_stores_app.command(name="list") +def list_entries(limit: int = 10, offset: int = 0) -> None: + """ + List all objects in the chosen vector store. + """ + + async def run() -> None: + if state.vector_store is None: + raise ValueError("Vector store not initialized") + + entries = await state.vector_store.list(limit=limit, offset=offset) + print_output(entries) + + asyncio.run(run()) + + +class RemovedItem(BaseModel): + id: str + + +@vector_stores_app.command() +def remove(ids: list[str]) -> None: + """ + Remove objects from the chosen vector store. + """ + + async def run() -> None: + if state.vector_store is None: + raise ValueError("Vector store not initialized") + + await state.vector_store.remove(ids) + if cli_state.output_type == OutputType.text: + typer.echo(f"Removed entries with IDs: {', '.join(ids)}") + else: + print_output([RemovedItem(id=id) for id in ids]) + + asyncio.run(run()) + + +@vector_stores_app.command() +def query( + text: str, + k: int = 5, + max_distance: float | None = None, + embedder_factory_path: str | None = None, + embedder_yaml_path: str | None = None, +) -> None: + """ + Query the chosen vector store. 
+ """ + + async def run() -> None: + if state.vector_store is None: + raise ValueError("Vector store not initialized") + + try: + embedder = Embeddings.subclass_from_defaults( + core_config, + factory_path_override=embedder_factory_path, + yaml_path_override=Path.cwd() / embedder_yaml_path if embedder_yaml_path else None, + ) + except InvalidConfigError as e: + Console(stderr=True).print(e) + raise typer.Exit(1) from e + + search_vector = await embedder.embed_text([text]) + + options = VectorStoreOptions(k=k, max_distance=max_distance) + entries = await state.vector_store.retrieve( + vector=search_vector[0], + options=options, + ) + print_output(entries) + + asyncio.run(run()) diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/base.py b/packages/ragbits-core/src/ragbits/core/vector_stores/base.py index 1b74b52b3..e8adb746a 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/base.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/base.py @@ -37,6 +37,7 @@ class VectorStore(WithConstructionConfig, ABC): """ default_module: ClassVar = vector_stores + configuration_key: ClassVar = "vector_store" def __init__( self, diff --git a/packages/ragbits-core/tests/cli/__init__.py b/packages/ragbits-core/tests/cli/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/packages/ragbits-core/tests/cli/test_vector_store.py b/packages/ragbits-core/tests/cli/test_vector_store.py new file mode 100644 index 000000000..2b8a9e570 --- /dev/null +++ b/packages/ragbits-core/tests/cli/test_vector_store.py @@ -0,0 +1,175 @@ +import asyncio +import json + +import pytest +from typer.testing import CliRunner + +from ragbits.cli import app as root_app +from ragbits.cli import autoregister +from ragbits.core.embeddings.base import Embeddings +from ragbits.core.embeddings.noop import NoopEmbeddings +from ragbits.core.vector_stores import InMemoryVectorStore, VectorStore +from ragbits.core.vector_stores._cli import vector_stores_app +from 
ragbits.core.vector_stores.base import VectorStoreEntry + +example_entries = [ + VectorStoreEntry(id="1", key="entry 1", vector=[4.0, 5.0], metadata={"key": "value"}), + VectorStoreEntry(id="2", key="entry 2", vector=[1.0, 2.0], metadata={"another_key": "another_value"}), + VectorStoreEntry(id="3", key="entry 3", vector=[7.0, 8.0], metadata={"foo": "bar", "baz": "qux"}), +] + + +def vector_store_factory() -> VectorStore: + """ + A factory function that creates an instance of the VectorStore with example entries. + + Returns: + VectorStore: An instance of the VectorStore. + """ + + async def add_examples(store: VectorStore) -> None: + await store.store(example_entries) + + store = InMemoryVectorStore() + asyncio.new_event_loop().run_until_complete(add_examples(store)) + return store + + +# A vector store that's persistent between factory runs, +# to test the remove command. +_vector_store_for_remove: VectorStore | None = None + + +@pytest.fixture(autouse=True) +def reset_vector_store_for_remove(): + """ + Make sure that the global variable for the vector store used in the remove test is reset before each test. + """ + global _vector_store_for_remove # noqa: PLW0603 + _vector_store_for_remove = None + + +def vector_store_factory_for_remove() -> VectorStore: + """ + A factory function that creates an instance of the VectorStore with example entries, + and stores it in a global variable to be used in the remove test. + + Returns: + VectorStore: An instance of the VectorStore. + """ + + async def add_examples(store: VectorStore) -> None: + await store.store(example_entries) + + global _vector_store_for_remove # noqa: PLW0603 + if _vector_store_for_remove is None: + _vector_store_for_remove = InMemoryVectorStore() + asyncio.new_event_loop().run_until_complete(add_examples(_vector_store_for_remove)) + return _vector_store_for_remove + + +def embedder_factory() -> Embeddings: + """ + A factory function that creates an instance of no-op Embeddings. 
+ + Returns: + Embeddings: An instance of the Embeddings. + """ + return NoopEmbeddings() + + +def test_vector_store_cli_no_store(): + """ + Test the vector-store CLI command with no store. + + Args: + cli_runner: A CLI runner fixture. + """ + runner = CliRunner(mix_stderr=False) + result = runner.invoke(vector_stores_app, ["list"]) + assert "Could not find default factory or configuration" in result.stderr + + +def test_vector_store_list(): + runner = CliRunner(mix_stderr=False) + result = runner.invoke( + vector_stores_app, + ["--factory-path", "cli.test_vector_store:vector_store_factory", "list"], + ) + assert result.exit_code == 0 + assert "entry 1" in result.stdout + assert "entry 2" in result.stdout + assert "entry 3" in result.stdout + + +def test_vector_store_list_limit_offset(): + runner = CliRunner(mix_stderr=False) + result = runner.invoke( + vector_stores_app, + ["--factory-path", "cli.test_vector_store:vector_store_factory", "list", "--limit", "1", "--offset", "1"], + ) + assert result.exit_code == 0 + assert "entry 1" not in result.stdout + assert "entry 2" in result.stdout + assert "entry 3" not in result.stdout + + +def test_vector_store_remove(): + runner = CliRunner(mix_stderr=False) + result = runner.invoke( + vector_stores_app, + ["--factory-path", "cli.test_vector_store:vector_store_factory_for_remove", "remove", "1", "3"], + ) + assert result.exit_code == 0 + assert "Removed entries with IDs: 1, 3" in result.stdout + + result = runner.invoke( + vector_stores_app, ["--factory-path", "cli.test_vector_store:vector_store_factory_for_remove", "list"] + ) + assert result.exit_code == 0 + assert "entry 1" not in result.stdout + assert "entry 2" in result.stdout + assert "entry 3" not in result.stdout + + +def test_vector_store_query(): + runner = CliRunner(mix_stderr=False) + result = runner.invoke( + vector_stores_app, + [ + "--factory-path", + "cli.test_vector_store:vector_store_factory", + "query", + "--embedder-factory-path", + 
"cli.test_vector_store:embedder_factory", + "--k", + "1", + "example query", + ], + ) + print(result.stderr) + assert result.exit_code == 0 + assert "entry 1" not in result.stdout + assert "entry 2" in result.stdout + assert "entry 3" not in result.stdout + + +def test_vector_store_list_json(): + autoregister() + runner = CliRunner(mix_stderr=False) + result = runner.invoke( + root_app, + [ + "--output", + "json", + "vector-store", + "--factory-path", + "cli.test_vector_store:vector_store_factory", + "list", + ], + ) + print(result.stderr) + assert result.exit_code == 0 + dicts = json.loads(result.stdout) + entries = [VectorStoreEntry.model_validate(entry) for entry in dicts] + assert entries == example_entries diff --git a/packages/ragbits-core/tests/unit/utils/pyproject/test_find.py b/packages/ragbits-core/tests/unit/utils/pyproject/test_find.py index 2694721ad..7c9fcde3e 100644 --- a/packages/ragbits-core/tests/unit/utils/pyproject/test_find.py +++ b/packages/ragbits-core/tests/unit/utils/pyproject/test_find.py @@ -4,7 +4,7 @@ from ragbits.core.utils._pyproject import find_pyproject -projects_dir = Path(__file__).parent / "testprojects" +projects_dir = Path(__file__).parent.parent / "testprojects" def test_find_in_current_dir(): diff --git a/packages/ragbits-core/tests/unit/utils/pyproject/test_get_config.py b/packages/ragbits-core/tests/unit/utils/pyproject/test_get_config.py index 2c12dabde..4a0a663e3 100644 --- a/packages/ragbits-core/tests/unit/utils/pyproject/test_get_config.py +++ b/packages/ragbits-core/tests/unit/utils/pyproject/test_get_config.py @@ -2,7 +2,7 @@ from ragbits.core.utils._pyproject import get_ragbits_config -projects_dir = Path(__file__).parent / "testprojects" +projects_dir = Path(__file__).parent.parent / "testprojects" def test_get_config(): @@ -16,6 +16,7 @@ def test_get_config(): "is_happy": True, "happiness_level": 100, }, + "project_base_path": str(projects_dir / "happy_project"), } diff --git 
a/packages/ragbits-core/tests/unit/utils/pyproject/test_get_instace.py b/packages/ragbits-core/tests/unit/utils/pyproject/test_get_instace.py index 6263bf4cc..489131feb 100644 --- a/packages/ragbits-core/tests/unit/utils/pyproject/test_get_instace.py +++ b/packages/ragbits-core/tests/unit/utils/pyproject/test_get_instace.py @@ -7,7 +7,7 @@ from ragbits.core.llms.base import LLMType from ragbits.core.utils._pyproject import get_config_instance -projects_dir = Path(__file__).parent / "testprojects" +projects_dir = Path(__file__).parent.parent / "testprojects" class HappyProjectConfig(BaseModel): @@ -95,4 +95,4 @@ def test_get_config_instance_bad_factories(): current_dir=projects_dir / "bad_factory_project", ) - assert "Unsupported LLMType value provided in default_llm_factories in pyproject.toml" in str(err.value) + assert "Input should be 'text', 'vision' or 'structured_output'" in str(err.value) diff --git a/packages/ragbits-core/tests/unit/utils/test_config_handling.py b/packages/ragbits-core/tests/unit/utils/test_config_handling.py index dd2ca8ce0..e76cbd7d9 100644 --- a/packages/ragbits-core/tests/unit/utils/test_config_handling.py +++ b/packages/ragbits-core/tests/unit/utils/test_config_handling.py @@ -1,12 +1,18 @@ import sys +from pathlib import Path import pytest +from ragbits.core.config import CoreConfig, core_config +from ragbits.core.utils._pyproject import get_config_instance from ragbits.core.utils.config_handling import InvalidConfigError, ObjectContructionConfig, WithConstructionConfig +projects_dir = Path(__file__).parent / "testprojects" + class ExampleClassWithConfigMixin(WithConstructionConfig): default_module = sys.modules[__name__] + configuration_key = "example" def __init__(self, foo: str, bar: int) -> None: self.foo = foo @@ -78,3 +84,37 @@ def test_subclass_from_factory(): def test_subclass_from_factory_incorrect_class(): with pytest.raises(InvalidConfigError): 
ExampleWithNoDefaultModule.subclass_from_factory("unit.utils.test_config_handling:example_factory") + + +def test_subclass_from_defaults_factory_override(): + instance = ExampleClassWithConfigMixin.subclass_from_defaults( + core_config, factory_path_override="unit.utils.test_config_handling:example_factory" + ) + assert isinstance(instance, ExampleSubclass) + assert instance.foo == "aligator" + assert instance.bar == 42 + + +def test_subclass_from_defaults_pyproject_factory(): + config = get_config_instance( + CoreConfig, + subproject="core", + current_dir=projects_dir / "project_with_instance_factory", + ) + instance = ExampleClassWithConfigMixin.subclass_from_defaults(config) + assert isinstance(instance, ExampleSubclass) + assert instance.foo == "aligator" + assert instance.bar == 42 + + +def test_subclass_from_defaults_instance_yaml(): + config = get_config_instance( + CoreConfig, + subproject="core", + current_dir=projects_dir / "project_with_instances_yaml", + ) + print(config) + instance = ExampleClassWithConfigMixin.subclass_from_defaults(config) + assert isinstance(instance, ExampleSubclass) + assert instance.foo == "I am a foo" + assert instance.bar == 122 diff --git a/packages/ragbits-core/tests/unit/utils/pyproject/testprojects/bad_factory_project/pyproject.toml b/packages/ragbits-core/tests/unit/utils/testprojects/bad_factory_project/pyproject.toml similarity index 100% rename from packages/ragbits-core/tests/unit/utils/pyproject/testprojects/bad_factory_project/pyproject.toml rename to packages/ragbits-core/tests/unit/utils/testprojects/bad_factory_project/pyproject.toml diff --git a/packages/ragbits-core/tests/unit/utils/pyproject/testprojects/factory_project/pyproject.toml b/packages/ragbits-core/tests/unit/utils/testprojects/factory_project/pyproject.toml similarity index 100% rename from packages/ragbits-core/tests/unit/utils/pyproject/testprojects/factory_project/pyproject.toml rename to 
packages/ragbits-core/tests/unit/utils/testprojects/factory_project/pyproject.toml diff --git a/packages/ragbits-core/tests/unit/utils/pyproject/testprojects/happy_project/pyproject.toml b/packages/ragbits-core/tests/unit/utils/testprojects/happy_project/pyproject.toml similarity index 100% rename from packages/ragbits-core/tests/unit/utils/pyproject/testprojects/happy_project/pyproject.toml rename to packages/ragbits-core/tests/unit/utils/testprojects/happy_project/pyproject.toml diff --git a/packages/ragbits-core/tests/unit/utils/testprojects/project_with_instance_factory/pyproject.toml b/packages/ragbits-core/tests/unit/utils/testprojects/project_with_instance_factory/pyproject.toml new file mode 100644 index 000000000..aca545a9e --- /dev/null +++ b/packages/ragbits-core/tests/unit/utils/testprojects/project_with_instance_factory/pyproject.toml @@ -0,0 +1,5 @@ +[project] +name = "instance_factory_project" + +[tool.ragbits.core.default_factories] +example = "unit.utils.test_config_handling:example_factory" diff --git a/packages/ragbits-core/tests/unit/utils/testprojects/project_with_instances_yaml/instances.yaml b/packages/ragbits-core/tests/unit/utils/testprojects/project_with_instances_yaml/instances.yaml new file mode 100644 index 000000000..c08338ad1 --- /dev/null +++ b/packages/ragbits-core/tests/unit/utils/testprojects/project_with_instances_yaml/instances.yaml @@ -0,0 +1,5 @@ +example: + type: unit.utils.test_config_handling:ExampleSubclass + config: + foo: I am a foo + bar: 122 diff --git a/packages/ragbits-core/tests/unit/utils/testprojects/project_with_instances_yaml/pyproject.toml b/packages/ragbits-core/tests/unit/utils/testprojects/project_with_instances_yaml/pyproject.toml new file mode 100644 index 000000000..e5d5fc86f --- /dev/null +++ b/packages/ragbits-core/tests/unit/utils/testprojects/project_with_instances_yaml/pyproject.toml @@ -0,0 +1,5 @@ +[project] +name = "project_with_instances_yaml" + +[tool.ragbits.core] +default_instaces_config_path 
= "instances.yaml" diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/base.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/base.py index 15656db92..69f90c617 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/base.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/base.py @@ -23,6 +23,7 @@ class BaseProvider(WithConstructionConfig, ABC): """ default_module: ClassVar = providers + configuration_key: ClassVar = "provider" SUPPORTED_DOCUMENT_TYPES: set[DocumentType] diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py index 88859d795..0b68b54ab 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py @@ -11,6 +11,7 @@ class QueryRephraser(WithConstructionConfig, ABC): """ default_module: ClassVar = rephrasers + configuration_key: ClassVar = "rephraser" @abstractmethod async def rephrase(self, query: str) -> list[str]: diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py index a7e08a951..6b335652e 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py @@ -25,6 +25,7 @@ class Reranker(WithConstructionConfig, ABC): """ default_module: ClassVar = rerankers + configuration_key: ClassVar = "reranker" def __init__(self, default_options: RerankerOptions | None = None) -> None: """ diff --git a/pyproject.toml b/pyproject.toml index 25d92cd21..2502b26b0 100644 
--- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dev-dependencies = [ "mkdocs-autorefs>=1.2.0", "mkdocs-material>=9.5.39", "mkdocs-material-extensions>=1.3.1", + "mkdocs-click>=0.8.1", "mkdocstrings>=0.26.1", "mkdocstrings-python>=1.11.1", "griffe>=1.3.2", diff --git a/uv.lock b/uv.lock index 8f8cd916d..3b60bd59e 100644 --- a/uv.lock +++ b/uv.lock @@ -2222,6 +2222,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/71/26/4d39d52ea2219604053a4d05b98e90d6a335511cc01806436ec4886b1028/mkdocs_autorefs-1.2.0-py3-none-any.whl", hash = "sha256:d588754ae89bd0ced0c70c06f58566a4ee43471eeeee5202427da7de9ef85a2f", size = 16522 }, ] +[[package]] +name = "mkdocs-click" +version = "0.8.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "markdown" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/61/d6b68573b4c399cd201502e4ea4cbfc12e274333d9ee622668cfbc9940ac/mkdocs_click-0.8.1.tar.gz", hash = "sha256:0a88cce04870c5d70ff63138e2418219c3c4119cc928a59c66b76eb5214edba6", size = 17874 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/ce/12158add31617ea579f7975f502812555371d7b8a4410c993a27d7e20727/mkdocs_click-0.8.1-py3-none-any.whl", hash = "sha256:a100ff938be63911f86465a1c21d29a669a7c51932b700fdb3daa90d13b61ee4", size = 14862 }, +] + [[package]] name = "mkdocs-get-deps" version = "0.2.0" @@ -4054,6 +4067,7 @@ dev = [ { name = "griffe-typingdoc" }, { name = "mkdocs" }, { name = "mkdocs-autorefs" }, + { name = "mkdocs-click" }, { name = "mkdocs-material" }, { name = "mkdocs-material-extensions" }, { name = "mkdocstrings" }, @@ -4082,6 +4096,7 @@ dev = [ { name = "griffe-typingdoc", specifier = ">=0.2.7" }, { name = "mkdocs", specifier = ">=1.6.1" }, { name = "mkdocs-autorefs", specifier = ">=1.2.0" }, + { name = "mkdocs-click", specifier = ">=0.8.1" }, { name = "mkdocs-material", specifier = ">=9.5.39" }, { name = "mkdocs-material-extensions", specifier = ">=1.3.1" }, { name 
= "mkdocstrings", specifier = ">=0.26.1" },