Skip to content

Commit

Permalink
Use OpenAI by default, make LiteLLM optional
Browse files — browse the repository at this point in the history
  • Loading branch information
TheR1D committed Feb 21, 2024
1 parent 779749f commit 3aa49d0
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 46 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ ignore = [
"E501", # line too long, handled by black.
"C901", # too complex.
"B008", # do not perform function calls in argument defaults.
"E731", # do not assign a lambda expression, use a def.
]

[tool.codespell]
Expand Down
1 change: 1 addition & 0 deletions sgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"SHOW_FUNCTIONS_OUTPUT": os.getenv("SHOW_FUNCTIONS_OUTPUT", "false"),
"API_BASE_URL": os.getenv("API_BASE_URL", "default"),
"PRETTIFY_MARKDOWN": os.getenv("PRETTIFY_MARKDOWN", "true"),
"USE_LITELLM": os.getenv("USE_LITELLM", "false"),
# New features might add their own config variables here.
}

Expand Down
38 changes: 28 additions & 10 deletions sgpt/handlers/handler.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,33 @@
import json
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional

import litellm # type: ignore
from typing import Any, Callable, Dict, Generator, List, Optional

from ..cache import Cache
from ..config import cfg
from ..function import get_function
from ..printer import MarkdownPrinter, Printer, TextPrinter
from ..role import DefaultRoles, SystemRole

litellm.suppress_debug_info = True
completion: Callable[..., Any] = lambda *args, **kwargs: None
base_url = cfg.get("API_BASE_URL")
use_litellm = cfg.get("USE_LITELLM") == "true"
additional_kwargs = {
"timeout": int(cfg.get("REQUEST_TIMEOUT")),
"api_key": cfg.get("OPENAI_API_KEY"),
"base_url": None if base_url == "default" else base_url,
}

if use_litellm:
import litellm # type: ignore

completion = litellm.completion
litellm.suppress_debug_info = True
else:
from openai import OpenAI

client = OpenAI(**additional_kwargs) # type: ignore
completion = client.chat.completions.create
additional_kwargs = {}


class Handler:
Expand Down Expand Up @@ -79,19 +96,20 @@ def get_completion(
if is_shell_role or is_code_role or is_dsc_shell_role:
functions = None

for chunk in litellm.completion(
for chunk in completion(
model=model,
temperature=temperature,
top_p=top_p,
messages=messages,
functions=functions,
stream=True,
api_key=cfg.get("OPENAI_API_KEY"),
base_url=self.base_url,
timeout=self.timeout,
**additional_kwargs,
):
delta = chunk.choices[0].delta
function_call = delta.get("function_call")
# LiteLLM uses a dict instead of a Pydantic object like OpenAI does.
function_call = (
delta.get("function_call") if use_litellm else delta.function_call
)
if function_call:
if function_call.name:
name = function_call.name
Expand All @@ -104,7 +122,7 @@ def get_completion(
)
return

yield delta.get("content") or ""
yield delta.content or ""

def handle(
self,
Expand Down
14 changes: 7 additions & 7 deletions tests/test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
role = SystemRole.get(DefaultRoles.CODE.value)


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_code_generation(completion):
completion.return_value = mock_comp("print('Hello World')")

Expand All @@ -23,7 +23,7 @@ def test_code_generation(completion):

@patch("sgpt.printer.TextPrinter.live_print")
@patch("sgpt.printer.MarkdownPrinter.live_print")
@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_code_generation_no_markdown(completion, markdown_printer, text_printer):
completion.return_value = mock_comp("print('Hello World')")

Expand All @@ -36,7 +36,7 @@ def test_code_generation_no_markdown(completion, markdown_printer, text_printer)
text_printer.assert_called()


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_code_generation_stdin(completion):
completion.return_value = mock_comp("# Hello\nprint('Hello')")

Expand All @@ -51,7 +51,7 @@ def test_code_generation_stdin(completion):
assert "print('Hello')" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_code_chat(completion):
completion.side_effect = [
mock_comp("print('hello')"),
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_code_chat(completion):
# TODO: Code chat can be recalled without --code option.


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_code_repl(completion):
completion.side_effect = [
mock_comp("print('hello')"),
Expand Down Expand Up @@ -124,7 +124,7 @@ def test_code_repl(completion):
assert "print('world')" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_code_and_shell(completion):
args = {"--code": True, "--shell": True}
result = runner.invoke(app, cmd_args(**args))
Expand All @@ -134,7 +134,7 @@ def test_code_and_shell(completion):
assert "Error" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_code_and_describe_shell(completion):
args = {"--code": True, "--describe-shell": True}
result = runner.invoke(app, cmd_args(**args))
Expand Down
22 changes: 11 additions & 11 deletions tests/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
cfg = config.cfg


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_default(completion):
completion.return_value = mock_comp("Prague")

Expand All @@ -26,7 +26,7 @@ def test_default(completion):
assert "Prague" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_default_stdin(completion):
completion.return_value = mock_comp("Prague")

Expand All @@ -39,7 +39,7 @@ def test_default_stdin(completion):


@patch("rich.console.Console.print")
@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_show_chat_use_markdown(completion, console_print):
completion.return_value = mock_comp("ok")
chat_name = "_test"
Expand All @@ -57,7 +57,7 @@ def test_show_chat_use_markdown(completion, console_print):


@patch("rich.console.Console.print")
@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_show_chat_no_use_markdown(completion, console_print):
completion.return_value = mock_comp("ok")
chat_name = "_test"
Expand All @@ -75,7 +75,7 @@ def test_show_chat_no_use_markdown(completion, console_print):
console_print.assert_not_called()


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_default_chat(completion):
completion.side_effect = [mock_comp("ok"), mock_comp("4")]
chat_name = "_test"
Expand Down Expand Up @@ -127,7 +127,7 @@ def test_default_chat(completion):
chat_path.unlink()


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_default_repl(completion):
completion.side_effect = [mock_comp("ok"), mock_comp("8")]
chat_name = "_test"
Expand Down Expand Up @@ -156,7 +156,7 @@ def test_default_repl(completion):
assert "8" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_default_repl_stdin(completion):
completion.side_effect = [mock_comp("ok init"), mock_comp("ok another")]
chat_name = "_test"
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_default_repl_stdin(completion):
assert "ok another" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_llm_options(completion):
completion.return_value = mock_comp("Berlin")

Expand All @@ -216,7 +216,7 @@ def test_llm_options(completion):
assert "Berlin" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_version(completion):
args = {"--version": True}
result = runner.invoke(app, cmd_args(**args))
Expand All @@ -227,7 +227,7 @@ def test_version(completion):

@patch("sgpt.printer.TextPrinter.live_print")
@patch("sgpt.printer.MarkdownPrinter.live_print")
@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_markdown(completion, markdown_printer, text_printer):
completion.return_value = mock_comp("pong")

Expand All @@ -240,7 +240,7 @@ def test_markdown(completion, markdown_printer, text_printer):

@patch("sgpt.printer.TextPrinter.live_print")
@patch("sgpt.printer.MarkdownPrinter.live_print")
@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_no_markdown(completion, markdown_printer, text_printer):
completion.return_value = mock_comp("pong")

Expand Down
2 changes: 1 addition & 1 deletion tests/test_roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .utils import app, cmd_args, comp_args, mock_comp, runner


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_role(completion):
completion.return_value = mock_comp('{"foo": "bar"}')
path = Path(cfg.get("ROLE_STORAGE_PATH")) / "json_gen_test.json"
Expand Down
20 changes: 10 additions & 10 deletions tests/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .utils import app, cmd_args, comp_args, mock_comp, runner


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_shell(completion):
role = SystemRole.get(DefaultRoles.SHELL.value)
completion.return_value = mock_comp("git commit -m test")
Expand All @@ -24,7 +24,7 @@ def test_shell(completion):

@patch("sgpt.printer.TextPrinter.live_print")
@patch("sgpt.printer.MarkdownPrinter.live_print")
@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_shell_no_markdown(completion, markdown_printer, text_printer):
completion.return_value = mock_comp("git commit -m test")

Expand All @@ -37,7 +37,7 @@ def test_shell_no_markdown(completion, markdown_printer, text_printer):
text_printer.assert_called()


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_shell_stdin(completion):
completion.return_value = mock_comp("ls -l | sort")
role = SystemRole.get(DefaultRoles.SHELL.value)
Expand All @@ -53,7 +53,7 @@ def test_shell_stdin(completion):
assert "[E]xecute, [D]escribe, [A]bort:" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_describe_shell(completion):
completion.return_value = mock_comp("lists the contents of a folder")
role = SystemRole.get(DefaultRoles.DESCRIBE_SHELL.value)
Expand All @@ -66,7 +66,7 @@ def test_describe_shell(completion):
assert "lists" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_describe_shell_stdin(completion):
completion.return_value = mock_comp("lists the contents of a folder")
role = SystemRole.get(DefaultRoles.DESCRIBE_SHELL.value)
Expand All @@ -82,7 +82,7 @@ def test_describe_shell_stdin(completion):


@patch("os.system")
@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_shell_run_description(completion, system):
completion.side_effect = [mock_comp("echo hello"), mock_comp("prints hello")]
args = {"prompt": "echo hello", "--shell": True}
Expand All @@ -95,7 +95,7 @@ def test_shell_run_description(completion, system):
assert "prints hello" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_shell_chat(completion):
completion.side_effect = [mock_comp("ls"), mock_comp("ls | sort")]
role = SystemRole.get(DefaultRoles.SHELL.value)
Expand Down Expand Up @@ -134,7 +134,7 @@ def test_shell_chat(completion):


@patch("os.system")
@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_shell_repl(completion, mock_system):
completion.side_effect = [mock_comp("ls"), mock_comp("ls | sort")]
role = SystemRole.get(DefaultRoles.SHELL.value)
Expand Down Expand Up @@ -166,7 +166,7 @@ def test_shell_repl(completion, mock_system):
assert "ls | sort" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_shell_and_describe_shell(completion):
args = {"prompt": "ls", "--describe-shell": True, "--shell": True}
result = runner.invoke(app, cmd_args(**args))
Expand All @@ -176,7 +176,7 @@ def test_shell_and_describe_shell(completion):
assert "Error" in result.stdout


@patch("litellm.completion")
@patch("sgpt.handlers.handler.completion")
def test_shell_no_interaction(completion):
completion.return_value = mock_comp("git commit -m test")
role = SystemRole.get(DefaultRoles.SHELL.value)
Expand Down
27 changes: 20 additions & 7 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from unittest.mock import ANY
from datetime import datetime

import typer
from litellm import completion as completion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as StreamChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from typer.testing import CliRunner

from sgpt import main
Expand All @@ -13,8 +15,22 @@


def mock_comp(tokens_string):
model = cfg.get("DEFAULT_MODEL")
return completion(model=model, mock_response=tokens_string, stream=True)
return [
ChatCompletionChunk(
id="foo",
model=cfg.get("DEFAULT_MODEL"),
object="chat.completion.chunk",
choices=[
StreamChoice(
index=0,
finish_reason=None,
delta=ChoiceDelta(content=token, role="assistant"),
),
],
created=int(datetime.now().timestamp()),
)
for token in tokens_string
]


def cmd_args(prompt="", **kwargs):
Expand All @@ -40,8 +56,5 @@ def comp_args(role, prompt, **kwargs):
"top_p": 1.0,
"functions": None,
"stream": True,
"api_key": ANY,
"base_url": ANY,
"timeout": ANY,
**kwargs,
}

0 comments on commit 3aa49d0

Please sign in to comment.