Skip to content

Commit

Permalink
Added OpenAI SDK dependency (#414)
Browse files Browse the repository at this point in the history
  • Loading branch information
TheR1D authored Dec 23, 2023
1 parent 482ec9d commit 4b670cf
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 190 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies = [
"click >= 7.1.1, < 9.0.0",
"rich >= 13.1.0, < 14.0.0",
"distro >= 1.8.0, < 2.0.0",
"openai >= 1.6.1, < 2.0.0",
'pyreadline3 >= 3.4.1, < 4.0.0; sys_platform == "win32"',
]

Expand Down
9 changes: 4 additions & 5 deletions sgpt/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def main(
prompt,
model=model,
temperature=temperature,
top_probability=top_probability,
top_p=top_probability,
chat_id=repl,
caching=cache,
)
Expand All @@ -170,7 +170,7 @@ def main(
prompt,
model=model,
temperature=temperature,
top_probability=top_probability,
top_p=top_probability,
chat_id=chat,
caching=cache,
)
Expand All @@ -179,7 +179,7 @@ def main(
prompt,
model=model,
temperature=temperature,
top_probability=top_probability,
top_p=top_probability,
caching=cache,
)

Expand All @@ -199,15 +199,14 @@ def main(
full_completion,
model=model,
temperature=temperature,
top_probability=top_probability,
top_p=top_probability,
caching=cache,
)
continue
break


def entry_point() -> None:
# Python package entry point defined in setup.py
typer.run(main)


Expand Down
111 changes: 0 additions & 111 deletions sgpt/client.py

This file was deleted.

2 changes: 1 addition & 1 deletion sgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"CACHE_LENGTH": int(os.getenv("CHAT_CACHE_LENGTH", "100")),
"REQUEST_TIMEOUT": int(os.getenv("REQUEST_TIMEOUT", "60")),
"DEFAULT_MODEL": os.getenv("DEFAULT_MODEL", "gpt-4-1106-preview"),
"OPENAI_API_HOST": os.getenv("OPENAI_API_HOST", "https://api.openai.com"),
"OPENAI_BASE_URL": os.getenv("OPENAI_API_HOST", "https://api.openai.com/v1"),
"DEFAULT_COLOR": os.getenv("DEFAULT_COLOR", "magenta"),
"ROLE_STORAGE_PATH": os.getenv("ROLE_STORAGE_PATH", str(ROLE_STORAGE_PATH)),
"DEFAULT_EXECUTE_SHELL_CMD": os.getenv("DEFAULT_EXECUTE_SHELL_CMD", "false"),
Expand Down
27 changes: 20 additions & 7 deletions sgpt/handlers/handler.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
from pathlib import Path
from typing import Any, Dict, Generator, List

import typer
from openai import OpenAI
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown

from ..client import OpenAIClient
from ..cache import Cache
from ..config import cfg
from ..role import SystemRole

cache = Cache(int(cfg.get("CACHE_LENGTH")), Path(cfg.get("CACHE_PATH")))


class Handler:
def __init__(self, role: SystemRole) -> None:
self.client = OpenAIClient(
cfg.get("OPENAI_API_HOST"), cfg.get("OPENAI_API_KEY")
self.client = OpenAI(
base_url=cfg.get("OPENAI_BASE_URL"),
api_key=cfg.get("OPENAI_API_KEY"),
timeout=int(cfg.get("REQUEST_TIMEOUT")),
)
self.role = role
self.disable_stream = cfg.get("DISABLE_STREAMING") == "false"
self.disable_stream = cfg.get("DISABLE_STREAMING") == "true"
self.color = cfg.get("DEFAULT_COLOR")
self.theme_name = cfg.get("CODE_THEME")

Expand All @@ -28,7 +34,7 @@ def _handle_with_markdown(self, prompt: str, **kwargs: Any) -> str:
console=Console(),
refresh_per_second=8,
) as live:
if not self.disable_stream:
if self.disable_stream:
live.update(
Markdown(markup="Loading...\r", code_theme=self.theme_name),
refresh=True,
Expand All @@ -44,7 +50,7 @@ def _handle_with_markdown(self, prompt: str, **kwargs: Any) -> str:
def _handle_with_plain_text(self, prompt: str, **kwargs: Any) -> str:
messages = self.make_messages(prompt.strip())
full_completion = ""
if not self.disable_stream:
if self.disable_stream:
typer.echo("Loading...\r", nl=False)
for word in self.get_completion(messages=messages, **kwargs):
typer.secho(word, fg=self.color, bold=True, nl=False)
Expand All @@ -56,8 +62,15 @@ def _handle_with_plain_text(self, prompt: str, **kwargs: Any) -> str:
def make_messages(self, prompt: str) -> List[Dict[str, str]]:
raise NotImplementedError

@cache
def get_completion(self, **kwargs: Any) -> Generator[str, None, None]:
yield from self.client.get_completion(**kwargs)
if self.disable_stream:
completion = self.client.chat.completions.create(**kwargs)
yield completion.choices[0].message.content
return

for chunk in self.client.chat.completions.create(**kwargs, stream=True):
yield from chunk.choices[0].delta.content or ""

def handle(self, prompt: str, **kwargs: Any) -> str:
if self.role.name == "ShellGPT" or self.role.name == "Shell Command Descriptor":
Expand Down
4 changes: 2 additions & 2 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def test_zsh_command(self):
# assert "command not found" not in result.stdout
# assert "hello world" in stdout.split("\n")[-1]

@patch("sgpt.client.OpenAIClient.get_completion")
@patch("sgpt.handlers.handler.Handler.get_completion")
def test_model_option(self, mocked_get_completion):
dict_arguments = {
"prompt": "What is the capital of the Czech Republic?",
Expand All @@ -389,7 +389,7 @@ def test_model_option(self, mocked_get_completion):
messages=ANY,
model="gpt-4",
temperature=0.0,
top_probability=1.0,
top_p=1.0,
caching=False,
)
assert result.exit_code == 0
Expand Down
67 changes: 3 additions & 64 deletions tests/test_unit.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,11 @@
import os
import unittest

import requests
import requests_mock

from sgpt.client import OpenAIClient


class TestMain(unittest.TestCase):
API_HOST = os.getenv("OPENAI_HOST", "https://api.openai.com")
API_URL = f"{API_HOST}/v1/chat/completions"
# TODO: Fix tests.

def setUp(self):
self.api_key = os.environ["OPENAI_API_KEY"] = "test key"
self.prompt = "What is the capital of France?"
self.shell = False
self.execute = False
self.code = False
self.animation = True
self.spinner = True
self.temperature = 1.0
self.top_p = 1.0
self.response_text = "Paris"
self.model = "gpt-3.5-turbo"
self.client = OpenAIClient(self.API_HOST, self.api_key)

@requests_mock.Mocker()
def test_openai_request(self, mock):
# TODO: Fix tests.
mocked_json = {"choices": [{"message": {"content": self.response_text}}]}
mock.post(self.API_URL, json=mocked_json, status_code=200)
result = yield from self.client.get_completion(
messages=[{"role": "user", "content": self.prompt}],
model=self.model,
temperature=self.temperature,
top_probability=self.top_p,
caching=False,
)
# TODO: Fix tests with generators.
self.assertEqual(result, self.response_text)
expected_json = {
"messages": [{"role": "user", "content": self.prompt}],
"model": "gpt-3.5-turbo",
"temperature": self.temperature,
"top_p": self.top_p,
}
expected_headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
request = mock.request_history[0]
self.assertEqual(request.json(), expected_json)
for key, value in expected_headers.items():
self.assertEqual(request.headers[key], value)
# TODO: Write these tests.

@requests_mock.Mocker()
def test_openai_request_fail(self, mock):
# TODO: Fix tests.
mock.post(self.API_URL, status_code=400)
with self.assertRaises(requests.exceptions.HTTPError):
yield from self.client.get_completion(
messages=[{"role": "user", "content": self.prompt}],
model=self.model,
temperature=self.temperature,
top_probability=self.top_p,
caching=False,
)
def test_main(self):
self.assertEqual(1, 1)


if __name__ == "__main__":
Expand Down

0 comments on commit 4b670cf

Please sign in to comment.