Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Markdown option and config variable #481

Merged
merged 1 commit into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions sgpt/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def main(
max=1.0,
help="Limits highest probable tokens (words).",
),
md: bool = typer.Option(
cfg.get("PRETTIFY_MARKDOWN") == "true",
help="Prettify markdown output.",
),
shell: bool = typer.Option(
False,
"--shell",
Expand Down Expand Up @@ -203,7 +207,7 @@ def main(

if repl:
# Will be in infinite loop here until user exits with Ctrl+C.
ReplHandler(repl, role_class).handle(
ReplHandler(repl, role_class, md).handle(
init_prompt=prompt,
model=model,
temperature=temperature,
Expand All @@ -213,7 +217,7 @@ def main(
)

if chat:
full_completion = ChatHandler(chat, role_class).handle(
full_completion = ChatHandler(chat, role_class, md).handle(
prompt=prompt,
model=model,
temperature=temperature,
Expand All @@ -222,7 +226,7 @@ def main(
functions=function_schemas,
)
else:
full_completion = DefaultHandler(role_class).handle(
full_completion = DefaultHandler(role_class, md).handle(
prompt=prompt,
model=model,
temperature=temperature,
Expand All @@ -243,7 +247,7 @@ def main(
# "y" option is for keeping compatibility with old version.
run_command(full_completion)
elif option == "d":
DefaultHandler(DefaultRoles.DESCRIBE_SHELL.get_role()).handle(
DefaultHandler(DefaultRoles.DESCRIBE_SHELL.get_role(), md).handle(
full_completion,
model=model,
temperature=temperature,
Expand Down
1 change: 1 addition & 0 deletions sgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"OPENAI_USE_FUNCTIONS": os.getenv("OPENAI_USE_FUNCTIONS", "true"),
"SHOW_FUNCTIONS_OUTPUT": os.getenv("SHOW_FUNCTIONS_OUTPUT", "false"),
"API_BASE_URL": os.getenv("API_BASE_URL", "default"),
"PRETTIFY_MARKDOWN": os.getenv("PRETTIFY_MARKDOWN", "true"),
# New features might add their own config variables here.
}

Expand Down
4 changes: 2 additions & 2 deletions sgpt/handlers/chat_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def list(self) -> List[Path]:
class ChatHandler(Handler):
chat_session = ChatSession(CHAT_CACHE_LENGTH, CHAT_CACHE_PATH)

def __init__(self, chat_id: str, role: SystemRole) -> None:
super().__init__(role)
def __init__(self, chat_id: str, role: SystemRole, markdown: bool) -> None:
super().__init__(role, markdown)
self.chat_id = chat_id
self.role = role

Expand Down
4 changes: 2 additions & 2 deletions sgpt/handlers/default_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@


class DefaultHandler(Handler):
def __init__(self, role: SystemRole) -> None:
super().__init__(role)
def __init__(self, role: SystemRole, markdown: bool) -> None:
super().__init__(role, markdown)
self.role = role

def make_messages(self, prompt: str) -> List[Dict[str, str]]:
Expand Down
13 changes: 9 additions & 4 deletions sgpt/handlers/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,23 @@
class Handler:
cache = Cache(int(cfg.get("CACHE_LENGTH")), Path(cfg.get("CACHE_PATH")))

def __init__(self, role: SystemRole) -> None:
def __init__(self, role: SystemRole, markdown: bool) -> None:
self.role = role

api_base_url = cfg.get("API_BASE_URL")
self.base_url = None if api_base_url == "default" else api_base_url
self.timeout = int(cfg.get("REQUEST_TIMEOUT"))

self.markdown = "APPLY MARKDOWN" in self.role.role and markdown
self.code_theme, self.color = cfg.get("CODE_THEME"), cfg.get("DEFAULT_COLOR")

@property
def printer(self) -> Printer:
use_markdown = "APPLY MARKDOWN" in self.role.role
code_theme, color = cfg.get("CODE_THEME"), cfg.get("DEFAULT_COLOR")
return MarkdownPrinter(code_theme) if use_markdown else TextPrinter(color)
return (
MarkdownPrinter(self.code_theme)
if self.markdown
else TextPrinter(self.color)
)

def make_messages(self, prompt: str) -> List[Dict[str, str]]:
raise NotImplementedError
Expand Down
10 changes: 5 additions & 5 deletions sgpt/handlers/repl_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@


class ReplHandler(ChatHandler):
def __init__(self, chat_id: str, role: SystemRole) -> None:
super().__init__(chat_id, role)
def __init__(self, chat_id: str, role: SystemRole, markdown: bool) -> None:
super().__init__(chat_id, role, markdown)

@classmethod
def _get_multiline_input(cls) -> str:
Expand Down Expand Up @@ -59,8 +59,8 @@ def handle(self, init_prompt: str, **kwargs: Any) -> None: # type: ignore
typer.echo()
rich_print(Rule(style="bold magenta"))
elif self.role.name == DefaultRoles.SHELL.value and prompt == "d":
DefaultHandler(DefaultRoles.DESCRIBE_SHELL.get_role()).handle(
prompt=full_completion, **kwargs
)
DefaultHandler(
DefaultRoles.DESCRIBE_SHELL.get_role(), self.markdown
).handle(prompt=full_completion, **kwargs)
else:
full_completion = super().handle(prompt=prompt, **kwargs)
21 changes: 18 additions & 3 deletions tests/test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,32 @@


@patch("litellm.completion")
def test_code_generation(mock):
mock.return_value = mock_comp("print('Hello World')")
def test_code_generation(completion):
completion.return_value = mock_comp("print('Hello World')")

args = {"prompt": "hello world python", "--code": True}
result = runner.invoke(app, cmd_args(**args))

mock.assert_called_once_with(**comp_args(role, args["prompt"]))
completion.assert_called_once_with(**comp_args(role, args["prompt"]))
assert result.exit_code == 0
assert "print('Hello World')" in result.stdout


@patch("sgpt.printer.TextPrinter.live_print")
@patch("sgpt.printer.MarkdownPrinter.live_print")
@patch("litellm.completion")
def test_code_generation_no_markdown(completion, markdown_printer, text_printer):
completion.return_value = mock_comp("print('Hello World')")

args = {"prompt": "make a commit using git", "--code": True, "--md": True}
result = runner.invoke(app, cmd_args(**args))

assert result.exit_code == 0
# Should ignore --md for --code option and output code without markdown.
markdown_printer.assert_not_called()
text_printer.assert_called()


@patch("litellm.completion")
def test_code_generation_stdin(completion):
completion.return_value = mock_comp("# Hello\nprint('Hello')")
Expand Down
26 changes: 26 additions & 0 deletions tests/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,29 @@ def test_version(completion):

completion.assert_not_called()
assert __version__ in result.stdout


@patch("sgpt.printer.TextPrinter.live_print")
@patch("sgpt.printer.MarkdownPrinter.live_print")
@patch("litellm.completion")
def test_markdown(completion, markdown_printer, text_printer):
completion.return_value = mock_comp("pong")

args = {"prompt": "ping", "--md": True}
result = runner.invoke(app, cmd_args(**args))
assert result.exit_code == 0
markdown_printer.assert_called()
text_printer.assert_not_called()


@patch("sgpt.printer.TextPrinter.live_print")
@patch("sgpt.printer.MarkdownPrinter.live_print")
@patch("litellm.completion")
def test_no_markdown(completion, markdown_printer, text_printer):
completion.return_value = mock_comp("pong")

args = {"prompt": "ping", "--no-md": True}
result = runner.invoke(app, cmd_args(**args))
assert result.exit_code == 0
markdown_printer.assert_not_called()
text_printer.assert_called()
15 changes: 15 additions & 0 deletions tests/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@ def test_shell(completion):
assert "[E]xecute, [D]escribe, [A]bort:" in result.stdout


@patch("sgpt.printer.TextPrinter.live_print")
@patch("sgpt.printer.MarkdownPrinter.live_print")
@patch("litellm.completion")
def test_shell_no_markdown(completion, markdown_printer, text_printer):
completion.return_value = mock_comp("git commit -m test")

args = {"prompt": "make a commit using git", "--shell": True, "--md": True}
result = runner.invoke(app, cmd_args(**args))

assert result.exit_code == 0
# Should ignore --md for --shell option and output text without markdown.
markdown_printer.assert_not_called()
text_printer.assert_called()


@patch("litellm.completion")
def test_shell_stdin(completion):
completion.return_value = mock_comp("ls -l | sort")
Expand Down
Loading