diff --git a/pyproject.toml b/pyproject.toml index e97586ad..400b6d1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,7 @@ ignore = [ "E501", # line too long, handled by black. "C901", # too complex. "B008", # do not perform function calls in argument defaults. + "E731", # do not assign a lambda expression, use a def. ] [tool.codespell] diff --git a/sgpt/config.py b/sgpt/config.py index 4c603470..6dad7161 100644 --- a/sgpt/config.py +++ b/sgpt/config.py @@ -33,6 +33,7 @@ "SHOW_FUNCTIONS_OUTPUT": os.getenv("SHOW_FUNCTIONS_OUTPUT", "false"), "API_BASE_URL": os.getenv("API_BASE_URL", "default"), "PRETTIFY_MARKDOWN": os.getenv("PRETTIFY_MARKDOWN", "true"), + "USE_LITELLM": os.getenv("USE_LITELLM", "false"), # New features might add their own config variables here. } diff --git a/sgpt/handlers/handler.py b/sgpt/handlers/handler.py index d4d05c2a..130941ee 100644 --- a/sgpt/handlers/handler.py +++ b/sgpt/handlers/handler.py @@ -1,8 +1,6 @@ import json from pathlib import Path -from typing import Any, Dict, Generator, List, Optional - -import litellm # type: ignore +from typing import Any, Callable, Dict, Generator, List, Optional from ..cache import Cache from ..config import cfg @@ -10,7 +8,26 @@ from ..printer import MarkdownPrinter, Printer, TextPrinter from ..role import DefaultRoles, SystemRole -litellm.suppress_debug_info = True +completion: Callable[..., Any] = lambda *args, **kwargs: None +base_url = cfg.get("API_BASE_URL") +use_litellm = cfg.get("USE_LITELLM") == "true" +additional_kwargs = { + "timeout": int(cfg.get("REQUEST_TIMEOUT")), + "api_key": cfg.get("OPENAI_API_KEY"), + "base_url": None if base_url == "default" else base_url, +} + +if use_litellm: + import litellm # type: ignore + + completion = litellm.completion + litellm.suppress_debug_info = True +else: + from openai import OpenAI + + client = OpenAI(**additional_kwargs) # type: ignore + completion = client.chat.completions.create + additional_kwargs = {} class Handler: @@ -79,19 +96,20 @@ def get_completion( if is_shell_role or is_code_role or is_dsc_shell_role: functions = None - for chunk in litellm.completion( + for chunk in completion( model=model, temperature=temperature, top_p=top_p, messages=messages, functions=functions, stream=True, - api_key=cfg.get("OPENAI_API_KEY"), - base_url=self.base_url, - timeout=self.timeout, + **additional_kwargs, ): delta = chunk.choices[0].delta - function_call = delta.get("function_call") + # LiteLLM uses dict instead of Pydantic object like OpenAI does. + function_call = ( + delta.get("function_call") if use_litellm else delta.function_call + ) if function_call: if function_call.name: name = function_call.name @@ -104,7 +122,7 @@ def get_completion( ) return - yield delta.get("content") or "" + yield delta.content or "" def handle( self, diff --git a/tests/test_code.py b/tests/test_code.py index e53166db..12a4c532 100644 --- a/tests/test_code.py +++ b/tests/test_code.py @@ -9,7 +9,7 @@ role = SystemRole.get(DefaultRoles.CODE.value) -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_code_generation(completion): completion.return_value = mock_comp("print('Hello World')") @@ -23,7 +23,7 @@ def test_code_generation(completion): @patch("sgpt.printer.TextPrinter.live_print") @patch("sgpt.printer.MarkdownPrinter.live_print") -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_code_generation_no_markdown(completion, markdown_printer, text_printer): completion.return_value = mock_comp("print('Hello World')") @@ -36,7 +36,7 @@ def test_code_generation_no_markdown(completion, markdown_printer, text_printer) text_printer.assert_called() -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_code_generation_stdin(completion): completion.return_value = mock_comp("# Hello\nprint('Hello')") @@ -51,7 +51,7 @@ def test_code_generation_stdin(completion): assert "print('Hello')" in result.stdout -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_code_chat(completion): completion.side_effect = [ mock_comp("print('hello')"), @@ -92,7 +92,7 @@ def test_code_chat(completion): # TODO: Code chat can be recalled without --code option. -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_code_repl(completion): completion.side_effect = [ mock_comp("print('hello')"), @@ -124,7 +124,7 @@ def test_code_repl(completion): assert "print('world')" in result.stdout -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_code_and_shell(completion): args = {"--code": True, "--shell": True} result = runner.invoke(app, cmd_args(**args)) @@ -134,7 +134,7 @@ def test_code_and_shell(completion): assert "Error" in result.stdout -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_code_and_describe_shell(completion): args = {"--code": True, "--describe-shell": True} result = runner.invoke(app, cmd_args(**args)) diff --git a/tests/test_default.py b/tests/test_default.py index b1c95aa5..a9b11a25 100644 --- a/tests/test_default.py +++ b/tests/test_default.py @@ -14,7 +14,7 @@ cfg = config.cfg -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_default(completion): completion.return_value = mock_comp("Prague") @@ -26,7 +26,7 @@ def test_default(completion): assert "Prague" in result.stdout -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_default_stdin(completion): completion.return_value = mock_comp("Prague") @@ -39,7 +39,7 @@ def test_default_stdin(completion): @patch("rich.console.Console.print") -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_show_chat_use_markdown(completion, console_print): completion.return_value = mock_comp("ok") chat_name = "_test" @@ -57,7 +57,7 @@ def test_show_chat_use_markdown(completion, console_print): @patch("rich.console.Console.print") -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_show_chat_no_use_markdown(completion, console_print): completion.return_value = mock_comp("ok") chat_name = "_test" @@ -75,7 +75,7 @@ def test_show_chat_no_use_markdown(completion, console_print): console_print.assert_not_called() -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_default_chat(completion): completion.side_effect = [mock_comp("ok"), mock_comp("4")] chat_name = "_test" @@ -127,7 +127,7 @@ def test_default_chat(completion): chat_path.unlink() -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_default_repl(completion): completion.side_effect = [mock_comp("ok"), mock_comp("8")] chat_name = "_test" @@ -156,7 +156,7 @@ def test_default_repl(completion): assert "8" in result.stdout -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_default_repl_stdin(completion): completion.side_effect = [mock_comp("ok init"), mock_comp("ok another")] chat_name = "_test" @@ -190,7 +190,7 @@ def test_default_repl_stdin(completion): assert "ok another" in result.stdout -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_llm_options(completion): completion.return_value = mock_comp("Berlin") @@ -216,7 +216,7 @@ def test_llm_options(completion): assert "Berlin" in result.stdout -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_version(completion): args = {"--version": True} result = runner.invoke(app, cmd_args(**args)) @@ -227,7 +227,7 @@ def test_version(completion): @patch("sgpt.printer.TextPrinter.live_print") @patch("sgpt.printer.MarkdownPrinter.live_print") -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_markdown(completion, markdown_printer, text_printer): completion.return_value = mock_comp("pong") @@ -240,7 +240,7 @@ def test_markdown(completion, markdown_printer, text_printer): @patch("sgpt.printer.TextPrinter.live_print") @patch("sgpt.printer.MarkdownPrinter.live_print") -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_no_markdown(completion, markdown_printer, text_printer): completion.return_value = mock_comp("pong") diff --git a/tests/test_roles.py b/tests/test_roles.py index dcf1b94a..a7cfe33c 100644 --- a/tests/test_roles.py +++ b/tests/test_roles.py @@ -8,7 +8,7 @@ from .utils import app, cmd_args, comp_args, mock_comp, runner -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_role(completion): completion.return_value = mock_comp('{"foo": "bar"}') path = Path(cfg.get("ROLE_STORAGE_PATH")) / "json_gen_test.json" diff --git a/tests/test_shell.py b/tests/test_shell.py index f99f4936..b78e2c96 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -8,7 +8,7 @@ from .utils import app, cmd_args, comp_args, mock_comp, runner -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_shell(completion): role = SystemRole.get(DefaultRoles.SHELL.value) completion.return_value = mock_comp("git commit -m test") @@ -24,7 +24,7 @@ def test_shell(completion): @patch("sgpt.printer.TextPrinter.live_print") @patch("sgpt.printer.MarkdownPrinter.live_print") -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_shell_no_markdown(completion, markdown_printer, text_printer): completion.return_value = mock_comp("git commit -m test") @@ -37,7 +37,7 @@ def test_shell_no_markdown(completion, markdown_printer, text_printer): text_printer.assert_called() -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_shell_stdin(completion): completion.return_value = mock_comp("ls -l | sort") role = SystemRole.get(DefaultRoles.SHELL.value) @@ -53,7 +53,7 @@ def test_shell_stdin(completion): assert "[E]xecute, [D]escribe, [A]bort:" in result.stdout -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_describe_shell(completion): completion.return_value = mock_comp("lists the contents of a folder") role = SystemRole.get(DefaultRoles.DESCRIBE_SHELL.value) @@ -66,7 +66,7 @@ def test_describe_shell(completion): assert "lists" in result.stdout -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_describe_shell_stdin(completion): completion.return_value = mock_comp("lists the contents of a folder") role = SystemRole.get(DefaultRoles.DESCRIBE_SHELL.value) @@ -82,7 +82,7 @@ def test_describe_shell_stdin(completion): @patch("os.system") -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_shell_run_description(completion, system): completion.side_effect = [mock_comp("echo hello"), mock_comp("prints hello")] args = {"prompt": "echo hello", "--shell": True} @@ -95,7 +95,7 @@ def test_shell_run_description(completion, system): assert "prints hello" in result.stdout -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_shell_chat(completion): completion.side_effect = [mock_comp("ls"), mock_comp("ls | sort")] role = SystemRole.get(DefaultRoles.SHELL.value) @@ -134,7 +134,7 @@ def test_shell_chat(completion): @patch("os.system") -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_shell_repl(completion, mock_system): completion.side_effect = [mock_comp("ls"), mock_comp("ls | sort")] role = SystemRole.get(DefaultRoles.SHELL.value) @@ -166,7 +166,7 @@ def test_shell_repl(completion, mock_system): assert "ls | sort" in result.stdout -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_shell_and_describe_shell(completion): args = {"prompt": "ls", "--describe-shell": True, "--shell": True} result = runner.invoke(app, cmd_args(**args)) @@ -176,7 +176,7 @@ def test_shell_and_describe_shell(completion): assert "Error" in result.stdout -@patch("litellm.completion") +@patch("sgpt.handlers.handler.completion") def test_shell_no_interaction(completion): completion.return_value = mock_comp("git commit -m test") role = SystemRole.get(DefaultRoles.SHELL.value) diff --git a/tests/utils.py b/tests/utils.py index f2a022f7..8da7d947 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,7 +1,9 @@ -from unittest.mock import ANY +from datetime import datetime import typer -from litellm import completion as completion +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.types.chat.chat_completion_chunk import Choice as StreamChoice +from openai.types.chat.chat_completion_chunk import ChoiceDelta from typer.testing import CliRunner from sgpt import main @@ -13,8 +15,22 @@ def mock_comp(tokens_string): - model = cfg.get("DEFAULT_MODEL") - return completion(model=model, mock_response=tokens_string, stream=True) + return [ + ChatCompletionChunk( + id="foo", + model=cfg.get("DEFAULT_MODEL"), + object="chat.completion.chunk", + choices=[ + StreamChoice( + index=0, + finish_reason=None, + delta=ChoiceDelta(content=token, role="assistant"), + ), + ], + created=int(datetime.now().timestamp()), + ) + for token in tokens_string + ] def cmd_args(prompt="", **kwargs): @@ -40,8 +56,5 @@ def comp_args(role, prompt, **kwargs): "top_p": 1.0, "functions": None, "stream": True, - "api_key": ANY, - "base_url": ANY, - "timeout": ANY, **kwargs, }