Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use OpenAI by default, make LiteLLM optional #488

Merged
merged 1 commit on
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,8 @@ OPENAI_FUNCTIONS_PATH=/Users/user/.config/shell_gpt/functions
SHOW_FUNCTIONS_OUTPUT=false
# Allows LLM to use functions.
OPENAI_USE_FUNCTIONS=true
# Enforce LiteLLM usage (for local LLMs).
USE_LITELLM=false
```
Possible options for `DEFAULT_COLOR`: black, red, green, yellow, blue, magenta, cyan, white, bright_black, bright_red, bright_green, bright_yellow, bright_blue, bright_magenta, bright_cyan, bright_white.
Possible options for `CODE_THEME`: https://pygments.org/styles/
Expand All @@ -423,6 +425,7 @@ Possible options for `CODE_THEME`: https://pygments.org/styles/
│ --model TEXT Large language model to use. [default: gpt-4-1106-preview] │
│ --temperature FLOAT RANGE [0.0<=x<=2.0] Randomness of generated output. [default: 0.0] │
│ --top-p FLOAT RANGE [0.0<=x<=1.0] Limits highest probable tokens (words). [default: 1.0] │
│ --md --no-md Prettify markdown output. [default: md] │
│ --editor Open $EDITOR to provide a prompt. [default: no-editor] │
│ --cache Cache completion results. [default: cache] │
│ --version Show version. │
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
]
dependencies = [
"litellm == 1.24.5",
"openai >= 1.6.1, < 2.0.0",
"typer >= 0.7.0, < 1.0.0",
"click >= 7.1.1, < 9.0.0",
"rich >= 13.1.0, < 14.0.0",
Expand All @@ -38,6 +38,9 @@ dependencies = [
]

[project.optional-dependencies]
litellm = [
"litellm == 1.24.5"
]
test = [
"pytest >= 7.2.2, < 8.0.0",
"requests-mock[fixture] >= 1.10.0, < 2.0.0",
Expand Down Expand Up @@ -95,6 +98,7 @@ ignore = [
"E501", # line too long, handled by black.
"C901", # too complex.
"B008", # do not perform function calls in argument defaults.
"E731", # do not assign a lambda expression, use a def.
]

[tool.codespell]
Expand Down
2 changes: 1 addition & 1 deletion sgpt/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.3.1"
__version__ = "1.4.1"
1 change: 1 addition & 0 deletions sgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"SHOW_FUNCTIONS_OUTPUT": os.getenv("SHOW_FUNCTIONS_OUTPUT", "false"),
"API_BASE_URL": os.getenv("API_BASE_URL", "default"),
"PRETTIFY_MARKDOWN": os.getenv("PRETTIFY_MARKDOWN", "true"),
"USE_LITELLM": os.getenv("USE_LITELLM", "false"),
# New features might add their own config variables here.
}

Expand Down
38 changes: 28 additions & 10 deletions sgpt/handlers/handler.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,33 @@
import json
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional

import litellm # type: ignore
from typing import Any, Callable, Dict, Generator, List, Optional

from ..cache import Cache
from ..config import cfg
from ..function import get_function
from ..printer import MarkdownPrinter, Printer, TextPrinter
from ..role import DefaultRoles, SystemRole

litellm.suppress_debug_info = True
completion: Callable[..., Any] = lambda *args, **kwargs: None
base_url = cfg.get("API_BASE_URL")
use_litellm = cfg.get("USE_LITELLM") == "true"
additional_kwargs = {
"timeout": int(cfg.get("REQUEST_TIMEOUT")),
"api_key": cfg.get("OPENAI_API_KEY"),
"base_url": None if base_url == "default" else base_url,
}

if use_litellm:
import litellm # type: ignore

completion = litellm.completion
litellm.suppress_debug_info = True
else:
from openai import OpenAI

client = OpenAI(**additional_kwargs) # type: ignore
completion = client.chat.completions.create
additional_kwargs = {}


class Handler:
Expand Down Expand Up @@ -79,19 +96,20 @@ def get_completion(
if is_shell_role or is_code_role or is_dsc_shell_role:
functions = None

for chunk in litellm.completion(
for chunk in completion(
model=model,
temperature=temperature,
top_p=top_p,
messages=messages,
functions=functions,
stream=True,
api_key=cfg.get("OPENAI_API_KEY"),
base_url=self.base_url,
timeout=self.timeout,
**additional_kwargs,
):
delta = chunk.choices[0].delta
function_call = delta.get("function_call")
# LiteLLM uses dict instead of Pydantic object like OpenAI does.
function_call = (
delta.get("function_call") if use_litellm else delta.function_call
)
if function_call:
if function_call.name:
name = function_call.name
Expand All @@ -104,7 +122,7 @@ def get_completion(
)
return

yield delta.get("content") or ""
yield delta.content or ""

def handle(
self,
Expand Down
14 changes: 7 additions & 7 deletions tests/test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
role = SystemRole.get(DefaultRoles.CODE.value)


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_code_generation(completion):
completion.return_value = mock_comp("print('Hello World')")

Expand All @@ -23,7 +23,7 @@ def test_code_generation(completion):

@patch("sgpt.printer.TextPrinter.live_print")
@patch("sgpt.printer.MarkdownPrinter.live_print")
@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_code_generation_no_markdown(completion, markdown_printer, text_printer):
completion.return_value = mock_comp("print('Hello World')")

Expand All @@ -36,7 +36,7 @@ def test_code_generation_no_markdown(completion, markdown_printer, text_printer)
text_printer.assert_called()


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_code_generation_stdin(completion):
completion.return_value = mock_comp("# Hello\nprint('Hello')")

Expand All @@ -51,7 +51,7 @@ def test_code_generation_stdin(completion):
assert "print('Hello')" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_code_chat(completion):
completion.side_effect = [
mock_comp("print('hello')"),
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_code_chat(completion):
# TODO: Code chat can be recalled without --code option.


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_code_repl(completion):
completion.side_effect = [
mock_comp("print('hello')"),
Expand Down Expand Up @@ -124,7 +124,7 @@ def test_code_repl(completion):
assert "print('world')" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_code_and_shell(completion):
args = {"--code": True, "--shell": True}
result = runner.invoke(app, cmd_args(**args))
Expand All @@ -134,7 +134,7 @@ def test_code_and_shell(completion):
assert "Error" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_code_and_describe_shell(completion):
args = {"--code": True, "--describe-shell": True}
result = runner.invoke(app, cmd_args(**args))
Expand Down
22 changes: 11 additions & 11 deletions tests/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
cfg = config.cfg


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_default(completion):
completion.return_value = mock_comp("Prague")

Expand All @@ -26,7 +26,7 @@ def test_default(completion):
assert "Prague" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_default_stdin(completion):
completion.return_value = mock_comp("Prague")

Expand All @@ -39,7 +39,7 @@ def test_default_stdin(completion):


@patch("rich.console.Console.print")
@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_show_chat_use_markdown(completion, console_print):
completion.return_value = mock_comp("ok")
chat_name = "_test"
Expand All @@ -57,7 +57,7 @@ def test_show_chat_use_markdown(completion, console_print):


@patch("rich.console.Console.print")
@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_show_chat_no_use_markdown(completion, console_print):
completion.return_value = mock_comp("ok")
chat_name = "_test"
Expand All @@ -75,7 +75,7 @@ def test_show_chat_no_use_markdown(completion, console_print):
console_print.assert_not_called()


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_default_chat(completion):
completion.side_effect = [mock_comp("ok"), mock_comp("4")]
chat_name = "_test"
Expand Down Expand Up @@ -127,7 +127,7 @@ def test_default_chat(completion):
chat_path.unlink()


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_default_repl(completion):
completion.side_effect = [mock_comp("ok"), mock_comp("8")]
chat_name = "_test"
Expand Down Expand Up @@ -156,7 +156,7 @@ def test_default_repl(completion):
assert "8" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_default_repl_stdin(completion):
completion.side_effect = [mock_comp("ok init"), mock_comp("ok another")]
chat_name = "_test"
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_default_repl_stdin(completion):
assert "ok another" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_llm_options(completion):
completion.return_value = mock_comp("Berlin")

Expand All @@ -216,7 +216,7 @@ def test_llm_options(completion):
assert "Berlin" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_version(completion):
args = {"--version": True}
result = runner.invoke(app, cmd_args(**args))
Expand All @@ -227,7 +227,7 @@ def test_version(completion):

@patch("sgpt.printer.TextPrinter.live_print")
@patch("sgpt.printer.MarkdownPrinter.live_print")
@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_markdown(completion, markdown_printer, text_printer):
completion.return_value = mock_comp("pong")

Expand All @@ -240,7 +240,7 @@ def test_markdown(completion, markdown_printer, text_printer):

@patch("sgpt.printer.TextPrinter.live_print")
@patch("sgpt.printer.MarkdownPrinter.live_print")
@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_no_markdown(completion, markdown_printer, text_printer):
completion.return_value = mock_comp("pong")

Expand Down
2 changes: 1 addition & 1 deletion tests/test_roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .utils import app, cmd_args, comp_args, mock_comp, runner


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_role(completion):
completion.return_value = mock_comp('{"foo": "bar"}')
path = Path(cfg.get("ROLE_STORAGE_PATH")) / "json_gen_test.json"
Expand Down
20 changes: 10 additions & 10 deletions tests/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .utils import app, cmd_args, comp_args, mock_comp, runner


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_shell(completion):
role = SystemRole.get(DefaultRoles.SHELL.value)
completion.return_value = mock_comp("git commit -m test")
Expand All @@ -24,7 +24,7 @@ def test_shell(completion):

@patch("sgpt.printer.TextPrinter.live_print")
@patch("sgpt.printer.MarkdownPrinter.live_print")
@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_shell_no_markdown(completion, markdown_printer, text_printer):
completion.return_value = mock_comp("git commit -m test")

Expand All @@ -37,7 +37,7 @@ def test_shell_no_markdown(completion, markdown_printer, text_printer):
text_printer.assert_called()


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_shell_stdin(completion):
completion.return_value = mock_comp("ls -l | sort")
role = SystemRole.get(DefaultRoles.SHELL.value)
Expand All @@ -53,7 +53,7 @@ def test_shell_stdin(completion):
assert "[E]xecute, [D]escribe, [A]bort:" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_describe_shell(completion):
completion.return_value = mock_comp("lists the contents of a folder")
role = SystemRole.get(DefaultRoles.DESCRIBE_SHELL.value)
Expand All @@ -66,7 +66,7 @@ def test_describe_shell(completion):
assert "lists" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_describe_shell_stdin(completion):
completion.return_value = mock_comp("lists the contents of a folder")
role = SystemRole.get(DefaultRoles.DESCRIBE_SHELL.value)
Expand All @@ -82,7 +82,7 @@ def test_describe_shell_stdin(completion):


@patch("os.system")
@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_shell_run_description(completion, system):
completion.side_effect = [mock_comp("echo hello"), mock_comp("prints hello")]
args = {"prompt": "echo hello", "--shell": True}
Expand All @@ -95,7 +95,7 @@ def test_shell_run_description(completion, system):
assert "prints hello" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_shell_chat(completion):
completion.side_effect = [mock_comp("ls"), mock_comp("ls | sort")]
role = SystemRole.get(DefaultRoles.SHELL.value)
Expand Down Expand Up @@ -134,7 +134,7 @@ def test_shell_chat(completion):


@patch("os.system")
@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_shell_repl(completion, mock_system):
completion.side_effect = [mock_comp("ls"), mock_comp("ls | sort")]
role = SystemRole.get(DefaultRoles.SHELL.value)
Expand Down Expand Up @@ -166,7 +166,7 @@ def test_shell_repl(completion, mock_system):
assert "ls | sort" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_shell_and_describe_shell(completion):
args = {"prompt": "ls", "--describe-shell": True, "--shell": True}
result = runner.invoke(app, cmd_args(**args))
Expand All @@ -176,7 +176,7 @@ def test_shell_and_describe_shell(completion):
assert "Error" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_shell_no_interaction(completion):
completion.return_value = mock_comp("git commit -m test")
role = SystemRole.get(DefaultRoles.SHELL.value)
Expand Down
Loading
Loading