Feat(vector-store): CLI commands for managing vector stores #244
base: main
@@ -0,0 +1,64 @@

Reviewer comment (on this new file): Note: the code in this file is almost entirely copied from the now-deleted […]. The only other difference is the typing of […].

```python
import json

from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum

from pydantic import BaseModel
from rich.console import Console
from rich.table import Table


class OutputType(Enum):
    """Indicates a type of CLI output formatting"""

    text = "text"
    json = "json"


@dataclass()
class CliState:
    """A dataclass describing CLI state"""

    output_type: OutputType = OutputType.text


cli_state = CliState()


def print_output(data: Sequence[BaseModel] | BaseModel) -> None:
    """
    Process and display output based on the current state's output type.

    Args:
        data: a list of pydantic models representing the output of a CLI function
    """
    console = Console()
    if isinstance(data, BaseModel):
        data = [data]
    if len(data) == 0:
        _print_empty_list()
        return
    first_el_instance = type(data[0])
    if any(not isinstance(datapoint, first_el_instance) for datapoint in data):
        raise ValueError("All the rows need to be of the same type")
    data_dicts: list[dict] = [output.model_dump(mode="python") for output in data]
    output_type = cli_state.output_type
    if output_type == OutputType.json:
        console.print(json.dumps(data_dicts, indent=4))
    elif output_type == OutputType.text:
        table = Table(show_header=True, header_style="bold magenta")
        properties = data[0].model_json_schema()["properties"]
        for key in properties:
            table.add_column(properties[key]["title"])
        for row in data_dicts:
            table.add_row(*[str(value) for value in row.values()])
        console.print(table)
    else:
        raise ValueError(f"Output type: {output_type} not supported")


def _print_empty_list() -> None:
    if cli_state.output_type == OutputType.text:
        print("Empty data list")
    elif cli_state.output_type == OutputType.json:
        print(json.dumps([]))
```

Reviewer comment (on `print_output`): In its current state I would rename it to […].
Reply: Why? It's still under the […].
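For context, a minimal usage sketch of `print_output` follows. The `Row` model and its field values are made up for illustration, and the sketch assumes the definitions above (`print_output`, `cli_state`, `OutputType`) are in scope:

```python
from pydantic import BaseModel


class Row(BaseModel):  # hypothetical model, not part of the PR
    id: str
    document: str


rows = [Row(id="1", document="foo"), Row(id="2", document="bar")]

# With the default state, a rich table is rendered; its column headers come
# from the model's JSON-schema titles ("Id", "Document").
print_output(rows)

# Flipping the shared state makes the same call emit a JSON array instead.
cli_state.output_type = OutputType.json
print_output(rows)
```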
```diff
@@ -1,104 +1,15 @@
-# pylint: disable=import-outside-toplevel
-# pylint: disable=missing-param-doc
-import asyncio
-import json
-from importlib import import_module
-from pathlib import Path
-
 import typer
-from pydantic import BaseModel
 
-from ragbits.cli.app import CLI
-from ragbits.core.config import core_config
-from ragbits.core.llms.base import LLM, LLMType
-from ragbits.core.prompt.prompt import ChatFormat, Prompt
-
-
-def _render(prompt_path: str, payload: str | None) -> Prompt:
-    module_stringified, object_stringified = prompt_path.split(":")
-    prompt_cls = getattr(import_module(module_stringified), object_stringified)
-
-    if payload is not None:
-        payload = json.loads(payload)
-        inputs = prompt_cls.input_type(**payload)
-        return prompt_cls(inputs)
-
-    return prompt_cls()
-
-
-class LLMResponseCliOutput(BaseModel):
-    """An output model for llm responses in CLI"""
-
-    question: ChatFormat
-    answer: str | BaseModel | None = None
-
-
-prompts_app = typer.Typer(no_args_is_help=True)
+from ragbits.core.prompt._cli import prompts_app
+from ragbits.core.vector_stores._cli import vector_stores_app
 
 
-def register(app: CLI) -> None:
+def register(app: typer.Typer) -> None:
     """
     Register the CLI commands for the package.
 
     Args:
         app: The Typer object to register the commands with.
     """
-
-    @prompts_app.command()
-    def lab(
-        file_pattern: str = core_config.prompt_path_pattern,
-        llm_factory: str = core_config.default_llm_factories[LLMType.TEXT],
-    ) -> None:
-        """
-        Launches the interactive application for listing, rendering, and testing prompts
-        defined within the current project.
-        """
-        from ragbits.core.prompt.lab.app import lab_app
-
-        lab_app(file_pattern=file_pattern, llm_factory=llm_factory)
-
-    @prompts_app.command()
-    def generate_promptfoo_configs(
-        file_pattern: str = core_config.prompt_path_pattern,
-        root_path: Path = Path.cwd(),  # noqa: B008
-        target_path: Path = Path("promptfooconfigs"),
-    ) -> None:
-        """
-        Generates the configuration files for the PromptFoo prompts.
-        """
-        from ragbits.core.prompt.promptfoo import generate_configs
-
-        generate_configs(file_pattern=file_pattern, root_path=root_path, target_path=target_path)
-
-    @prompts_app.command()
-    def render(prompt_path: str, payload: str | None = None) -> None:
-        """
-        Renders a prompt by loading a class from a module and initializing it with a given payload.
-        """
-        prompt = _render(prompt_path=prompt_path, payload=payload)
-        response = LLMResponseCliOutput(question=prompt.chat)
-        app.print_output(response)
-
-    @prompts_app.command(name="exec")
-    def execute(
-        prompt_path: str,
-        payload: str | None = None,
-        llm_factory: str | None = core_config.default_llm_factories[LLMType.TEXT],
-    ) -> None:
-        """
-        Executes a prompt using the specified prompt class and LLM factory.
-
-        Raises:
-            ValueError: If `llm_factory` is not provided.
-        """
-        prompt = _render(prompt_path=prompt_path, payload=payload)
-
-        if llm_factory is None:
-            raise ValueError("`llm_factory` must be provided")
-        llm: LLM = LLM.subclass_from_factory(llm_factory)
-
-        llm_output = asyncio.run(llm.generate(prompt))
-        response = LLMResponseCliOutput(question=prompt.chat, answer=llm_output)
-        app.print_output(response)
-
     app.add_typer(prompts_app, name="prompts", help="Commands for managing prompts")
+    app.add_typer(vector_stores_app, name="vector-store", help="Commands for managing vector stores")
```
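The `vector_stores_app` sub-app itself lives in `ragbits.core.vector_stores._cli`, which is not part of this excerpt. As a sketch only, a Typer sub-app of that shape could look like the following (the `list` command and its body are assumptions, not the PR's actual commands):

```python
import typer

# Hypothetical stand-in for ragbits.core.vector_stores._cli; the real
# module's commands are not shown in this diff.
vector_stores_app = typer.Typer(no_args_is_help=True)


@vector_stores_app.command(name="list")
def list_entries() -> None:
    """Illustrative stub: list entries from the configured vector store."""
    typer.echo("(entries would be printed here)")
```

Once registered via `app.add_typer(..., name="vector-store")`, such a command would be invoked as `<root-cli> vector-store list`.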
```diff
@@ -1,14 +1,20 @@
+from functools import cached_property
 from pathlib import Path
 
 from pydantic import BaseModel
 
 from ragbits.core.llms.base import LLMType
-from ragbits.core.utils._pyproject import get_config_instance
+from ragbits.core.utils._pyproject import get_config_from_yaml, get_config_instance
 
 
 class CoreConfig(BaseModel):
     """
     Configuration for the ragbits-core package, loaded from downstream projects' pyproject.toml files.
     """
 
     # Path to the base directory of the project, defaults to the directory of the pyproject.toml file
     project_base_path: Path | None = None
 
     # Pattern used to search for prompt files
     prompt_path_pattern: str = "**/prompt_*.py"
@@ -19,5 +25,24 @@ class CoreConfig(BaseModel):
         LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory:simple_litellm_structured_output_factory",
     }
 
+    # Path to functions that return instances of different types of Ragbits objects
+    default_factories: dict[str, str] = {}
+
+    # Path to a YAML file with default configuration of various Ragbits objects
+    default_instaces_config_path: Path | None = None
+
+    @cached_property
+    def default_instances_config(self) -> dict:
+        """
+        Get the configuration from the file specified in default_instaces_config_path.
+
+        Returns:
+            dict: The configuration from the file.
+        """
+        if self.default_instaces_config_path is None or not self.project_base_path:
+            return {}
+
+        return get_config_from_yaml(self.project_base_path / self.default_instaces_config_path)
 
 
 core_config = get_config_instance(CoreConfig, subproject="core")
```

Reviewer comment (on `default_factories`): Note: for LLMs […]
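A short sketch of how the new property behaves (the paths below are hypothetical, and it assumes `get_config_from_yaml` parses the YAML file into a dict, as its name and the docstring suggest):

```python
from pathlib import Path

from ragbits.core.config import CoreConfig

# With neither path configured, the property short-circuits to an empty dict.
assert CoreConfig().default_instances_config == {}

# With both paths set (hypothetical values), the first access would read
# <project_base_path>/instances.yaml via get_config_from_yaml, and
# @cached_property would cache the parsed dict for later accesses.
config = CoreConfig(
    project_base_path=Path("/my/project"),
    default_instaces_config_path=Path("instances.yaml"),
)
# config.default_instances_config  # -> contents of /my/project/instances.yaml
```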
Reviewer comment (on `register`): imho `autoregister` would be clearer - or `find_and_register`.

Reply: Ok, I don't have a preference here, will change to `autoregister`.
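For reference, one common shape for an `autoregister`/`find_and_register` helper is module discovery; the sketch below only illustrates the naming discussion and is not code from this PR:

```python
import importlib
import pkgutil

import typer


def autoregister(app: typer.Typer) -> None:
    """Hypothetical: walk the ragbits namespace and let every package that
    exposes a register() function attach its own sub-commands."""
    import ragbits  # assumed to be a namespace package

    for module_info in pkgutil.iter_modules(ragbits.__path__, prefix="ragbits."):
        module = importlib.import_module(module_info.name)
        register_fn = getattr(module, "register", None)
        if callable(register_fn):
            register_fn(app)
```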