diff --git a/pyproject.toml b/pyproject.toml
index fe2d5848..473aa9f7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,6 +33,7 @@ dependencies = [
     "click >= 7.1.1, < 9.0.0",
     "rich >= 13.1.0, < 14.0.0",
     "distro >= 1.8.0, < 2.0.0",
+    "openai >= 1.6.1, < 2.0.0",
     'pyreadline3 >= 3.4.1, < 4.0.0; sys_platform == "win32"',
 ]
diff --git a/sgpt/app.py b/sgpt/app.py
index 83f48e92..65231d34 100644
--- a/sgpt/app.py
+++ b/sgpt/app.py
@@ -160,7 +160,7 @@ def main(
             prompt,
             model=model,
             temperature=temperature,
-            top_probability=top_probability,
+            top_p=top_probability,
             chat_id=repl,
             caching=cache,
         )
@@ -170,7 +170,7 @@ def main(
             prompt,
             model=model,
             temperature=temperature,
-            top_probability=top_probability,
+            top_p=top_probability,
             chat_id=chat,
             caching=cache,
         )
@@ -179,7 +179,7 @@ def main(
         prompt,
         model=model,
         temperature=temperature,
-        top_probability=top_probability,
+        top_p=top_probability,
         caching=cache,
     )
@@ -199,7 +199,7 @@ def main(
             full_completion,
             model=model,
             temperature=temperature,
-            top_probability=top_probability,
+            top_p=top_probability,
             caching=cache,
         )
         continue
@@ -207,7 +207,6 @@ def main(
 
 
 def entry_point() -> None:
-    # Python package entry point defined in setup.py
     typer.run(main)
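Context for the top_p= rename above: the official openai package's Chat Completions API takes this parameter under its native name top_p, so the handlers can forward the CLI's top_probability value unchanged. A minimal sketch of the resulting call, assuming openai>=1.x with OPENAI_API_KEY set in the environment; the model and prompt are placeholders:

    from openai import OpenAI

    client = OpenAI()  # api_key defaults to the OPENAI_API_KEY environment variable

    # top_p is forwarded verbatim to the Chat Completions endpoint,
    # so no translation from the old top_probability name is needed.
    completion = client.chat.completions.create(
        model="gpt-4-1106-preview",
        messages=[{"role": "user", "content": "Say hello."}],
        temperature=0.0,
        top_p=1.0,
    )
    print(completion.choices[0].message.content)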
diff --git a/sgpt/client.py b/sgpt/client.py
deleted file mode 100644
index 355b255e..00000000
--- a/sgpt/client.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import json
-from pathlib import Path
-from typing import Dict, Generator, List
-
-import requests
-import typer
-
-from .cache import Cache
-from .config import SHELL_GPT_CONFIG_PATH, cfg
-
-CACHE_LENGTH = int(cfg.get("CACHE_LENGTH"))
-CACHE_PATH = Path(cfg.get("CACHE_PATH"))
-REQUEST_TIMEOUT = int(cfg.get("REQUEST_TIMEOUT"))
-DISABLE_STREAMING = str(cfg.get("DISABLE_STREAMING"))
-
-
-class OpenAIClient:
-    cache = Cache(CACHE_LENGTH, CACHE_PATH)
-
-    def __init__(self, api_host: str, api_key: str) -> None:
-        self.__api_key = api_key
-        self.api_host = api_host
-
-    @cache
-    def _request(
-        self,
-        messages: List[Dict[str, str]],
-        model: str = "gpt-3.5-turbo",
-        temperature: float = 1,
-        top_probability: float = 1,
-    ) -> Generator[str, None, None]:
-        """
-        Make request to OpenAI API, read more:
-        https://platform.openai.com/docs/api-reference/chat
-
-        :param messages: List of messages {"role": user or assistant, "content": message_string}
-        :param model: String gpt-3.5-turbo or gpt-3.5-turbo-0301
-        :param temperature: Float in 0.0 - 2.0 range.
-        :param top_probability: Float in 0.0 - 1.0 range.
-        :return: Response body JSON.
-        """
-        stream = DISABLE_STREAMING == "false"
-        data = {
-            "messages": messages,
-            "model": model,
-            "temperature": temperature,
-            "top_p": top_probability,
-            "stream": stream,
-        }
-        endpoint = f"{self.api_host}/v1/chat/completions"
-        response = requests.post(
-            endpoint,
-            # Hide API key from Rich traceback.
-            headers={
-                "Content-Type": "application/json",
-                "Authorization": f"Bearer {self.__api_key}",
-            },
-            json=data,
-            timeout=REQUEST_TIMEOUT,
-            stream=stream,
-        )
-        # Check if OPENAI_API_KEY is valid
-        if response.status_code == 401 or response.status_code == 403:
-            typer.secho(
-                f"Invalid OpenAI API key, update your config file: {SHELL_GPT_CONFIG_PATH}",
-                fg="red",
-            )
-        response.raise_for_status()
-        # TODO: Optimise.
-        # https://github.com/openai/openai-python/blob/237448dc072a2c062698da3f9f512fae38300c1c/openai/api_requestor.py#L98
-        if not stream:
-            data = response.json()
-            yield data["choices"][0]["message"]["content"]  # type: ignore
-            return
-        for line in response.iter_lines():
-            data = line.lstrip(b"data: ").decode("utf-8")
-            if data == "[DONE]":  # type: ignore
-                break
-            if not data:
-                continue
-            data = json.loads(data)  # type: ignore
-            delta = data["choices"][0]["delta"]  # type: ignore
-            if "content" not in delta:
-                continue
-            yield delta["content"]
-
-    def get_completion(
-        self,
-        messages: List[Dict[str, str]],
-        model: str = "gpt-3.5-turbo",
-        temperature: float = 1,
-        top_probability: float = 1,
-        caching: bool = True,
-    ) -> Generator[str, None, None]:
-        """
-        Generates single completion for prompt (message).
-
-        :param messages: List of dict with messages and roles.
-        :param model: String gpt-3.5-turbo or gpt-3.5-turbo-0301.
-        :param temperature: Float in 0.0 - 1.0 range.
-        :param top_probability: Float in 0.0 - 1.0 range.
-        :param caching: Boolean value to enable/disable caching.
-        :return: String generated completion.
-        """
-        yield from self._request(
-            messages,
-            model,
-            temperature,
-            top_probability,
-            caching=caching,
-        )
diff --git a/sgpt/config.py b/sgpt/config.py
index af128add..34072cdb 100644
--- a/sgpt/config.py
+++ b/sgpt/config.py
@@ -22,7 +22,7 @@
     "CACHE_LENGTH": int(os.getenv("CHAT_CACHE_LENGTH", "100")),
     "REQUEST_TIMEOUT": int(os.getenv("REQUEST_TIMEOUT", "60")),
     "DEFAULT_MODEL": os.getenv("DEFAULT_MODEL", "gpt-4-1106-preview"),
-    "OPENAI_API_HOST": os.getenv("OPENAI_API_HOST", "https://api.openai.com"),
+    "OPENAI_BASE_URL": os.getenv("OPENAI_API_HOST", "https://api.openai.com/v1"),
     "DEFAULT_COLOR": os.getenv("DEFAULT_COLOR", "magenta"),
     "ROLE_STORAGE_PATH": os.getenv("ROLE_STORAGE_PATH", str(ROLE_STORAGE_PATH)),
     "DEFAULT_EXECUTE_SHELL_CMD": os.getenv("DEFAULT_EXECUTE_SHELL_CMD", "false"),
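A note on the OPENAI_BASE_URL change above: the openai client joins endpoint paths such as /chat/completions directly onto base_url, so the /v1 suffix now belongs in the configured value instead of being appended when building the request. A minimal sketch of pointing the client at a different backend; the localhost URL is a hypothetical OpenAI-compatible server:

    from openai import OpenAI

    client = OpenAI(
        base_url="http://localhost:8080/v1",  # hypothetical OpenAI-compatible endpoint
        api_key="placeholder-key",  # any non-empty string if the server ignores auth
    )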
diff --git a/sgpt/handlers/handler.py b/sgpt/handlers/handler.py
index 0034214e..6d6e62f7 100644
--- a/sgpt/handlers/handler.py
+++ b/sgpt/handlers/handler.py
@@ -1,22 +1,28 @@
+from pathlib import Path
 from typing import Any, Dict, Generator, List
 
 import typer
+from openai import OpenAI
 from rich.console import Console
 from rich.live import Live
 from rich.markdown import Markdown
 
-from ..client import OpenAIClient
+from ..cache import Cache
 from ..config import cfg
 from ..role import SystemRole
 
+cache = Cache(int(cfg.get("CACHE_LENGTH")), Path(cfg.get("CACHE_PATH")))
+
 
 class Handler:
     def __init__(self, role: SystemRole) -> None:
-        self.client = OpenAIClient(
-            cfg.get("OPENAI_API_HOST"), cfg.get("OPENAI_API_KEY")
+        self.client = OpenAI(
+            base_url=cfg.get("OPENAI_BASE_URL"),
+            api_key=cfg.get("OPENAI_API_KEY"),
+            timeout=int(cfg.get("REQUEST_TIMEOUT")),
         )
         self.role = role
-        self.disable_stream = cfg.get("DISABLE_STREAMING") == "false"
+        self.disable_stream = cfg.get("DISABLE_STREAMING") == "true"
         self.color = cfg.get("DEFAULT_COLOR")
         self.theme_name = cfg.get("CODE_THEME")
 
@@ -28,7 +34,7 @@ def _handle_with_markdown(self, prompt: str, **kwargs: Any) -> str:
             console=Console(),
             refresh_per_second=8,
         ) as live:
-            if not self.disable_stream:
+            if self.disable_stream:
                 live.update(
                     Markdown(markup="Loading...\r", code_theme=self.theme_name),
                     refresh=True,
@@ -44,7 +50,7 @@ def _handle_with_markdown(self, prompt: str, **kwargs: Any) -> str:
     def _handle_with_plain_text(self, prompt: str, **kwargs: Any) -> str:
         messages = self.make_messages(prompt.strip())
         full_completion = ""
-        if not self.disable_stream:
+        if self.disable_stream:
             typer.echo("Loading...\r", nl=False)
         for word in self.get_completion(messages=messages, **kwargs):
             typer.secho(word, fg=self.color, bold=True, nl=False)
@@ -56,8 +62,15 @@ def _handle_with_plain_text(self, prompt: str, **kwargs: Any) -> str:
     def make_messages(self, prompt: str) -> List[Dict[str, str]]:
         raise NotImplementedError
 
+    @cache
     def get_completion(self, **kwargs: Any) -> Generator[str, None, None]:
-        yield from self.client.get_completion(**kwargs)
+        if self.disable_stream:
+            completion = self.client.chat.completions.create(**kwargs)
+            yield completion.choices[0].message.content
+            return
+
+        for chunk in self.client.chat.completions.create(**kwargs, stream=True):
+            yield from chunk.choices[0].delta.content or ""
 
     def handle(self, prompt: str, **kwargs: Any) -> str:
         if self.role.name == "ShellGPT" or self.role.name == "Shell Command Descriptor":
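The streaming branch of the new get_completion above consumes the chunk objects returned when stream=True; a chunk's delta.content can be None (for example on the final chunk), which is what the `or ""` guard absorbs. A minimal standalone sketch of the same pattern, assuming openai>=1.6 and a placeholder model and prompt:

    from openai import OpenAI

    client = OpenAI()

    stream = client.chat.completions.create(
        model="gpt-4-1106-preview",
        messages=[{"role": "user", "content": "Count to three."}],
        stream=True,
    )
    for chunk in stream:
        # delta.content is None on chunks that carry no text, hence the guard.
        print(chunk.choices[0].delta.content or "", end="", flush=True)
    print()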
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 3d290674..cd348b3f 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -378,7 +378,7 @@ def test_zsh_command(self):
         # assert "command not found" not in result.stdout
         # assert "hello world" in stdout.split("\n")[-1]
 
-    @patch("sgpt.client.OpenAIClient.get_completion")
+    @patch("sgpt.handlers.handler.Handler.get_completion")
     def test_model_option(self, mocked_get_completion):
         dict_arguments = {
             "prompt": "What is the capital of the Czech Republic?",
@@ -389,7 +389,7 @@ def test_model_option(self, mocked_get_completion):
             messages=ANY,
             model="gpt-4",
             temperature=0.0,
-            top_probability=1.0,
+            top_p=1.0,
             caching=False,
         )
         assert result.exit_code == 0
diff --git a/tests/test_unit.py b/tests/test_unit.py
index 888305c0..1aacdc6e 100644
--- a/tests/test_unit.py
+++ b/tests/test_unit.py
@@ -1,72 +1,11 @@
-import os
 import unittest
 
-import requests
-import requests_mock
-
-from sgpt.client import OpenAIClient
-
 
 class TestMain(unittest.TestCase):
-    API_HOST = os.getenv("OPENAI_HOST", "https://api.openai.com")
-    API_URL = f"{API_HOST}/v1/chat/completions"
-    # TODO: Fix tests.
-
-    def setUp(self):
-        self.api_key = os.environ["OPENAI_API_KEY"] = "test key"
-        self.prompt = "What is the capital of France?"
-        self.shell = False
-        self.execute = False
-        self.code = False
-        self.animation = True
-        self.spinner = True
-        self.temperature = 1.0
-        self.top_p = 1.0
-        self.response_text = "Paris"
-        self.model = "gpt-3.5-turbo"
-        self.client = OpenAIClient(self.API_HOST, self.api_key)
-
-    @requests_mock.Mocker()
-    def test_openai_request(self, mock):
-        # TODO: Fix tests.
-        mocked_json = {"choices": [{"message": {"content": self.response_text}}]}
-        mock.post(self.API_URL, json=mocked_json, status_code=200)
-        result = yield from self.client.get_completion(
-            messages=[{"role": "user", "content": self.prompt}],
-            model=self.model,
-            temperature=self.temperature,
-            top_probability=self.top_p,
-            caching=False,
-        )
-        # TODO: Fix tests with generators.
-        self.assertEqual(result, self.response_text)
-        expected_json = {
-            "messages": [{"role": "user", "content": self.prompt}],
-            "model": "gpt-3.5-turbo",
-            "temperature": self.temperature,
-            "top_p": self.top_p,
-        }
-        expected_headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {self.api_key}",
-        }
-        request = mock.request_history[0]
-        self.assertEqual(request.json(), expected_json)
-        for key, value in expected_headers.items():
-            self.assertEqual(request.headers[key], value)
+    # TODO: Write these tests.
 
-    @requests_mock.Mocker()
-    def test_openai_request_fail(self, mock):
-        # TODO: Fix tests.
-        mock.post(self.API_URL, status_code=400)
-        with self.assertRaises(requests.exceptions.HTTPError):
-            yield from self.client.get_completion(
-                messages=[{"role": "user", "content": self.prompt}],
-                model=self.model,
-                temperature=self.temperature,
-                top_probability=self.top_p,
-                caching=False,
-            )
+    def test_main(self):
+        self.assertEqual(1, 1)
 
 
 if __name__ == "__main__":