From e0af0e28896df27d8db26bd80e12ccb4ee56c62f Mon Sep 17 00:00:00 2001 From: Farkhod Sadykov Date: Sat, 17 Feb 2024 03:47:28 +0100 Subject: [PATCH] Markdown option and config variable --- sgpt/app.py | 12 ++++++++---- sgpt/config.py | 1 + sgpt/handlers/chat_handler.py | 4 ++-- sgpt/handlers/default_handler.py | 4 ++-- sgpt/handlers/handler.py | 13 +++++++++---- sgpt/handlers/repl_handler.py | 10 +++++----- tests/test_code.py | 21 ++++++++++++++++++--- tests/test_default.py | 26 ++++++++++++++++++++++++++ tests/test_shell.py | 15 +++++++++++++++ 9 files changed, 86 insertions(+), 20 deletions(-) diff --git a/sgpt/app.py b/sgpt/app.py index 1851bae6..bd092165 100644 --- a/sgpt/app.py +++ b/sgpt/app.py @@ -45,6 +45,10 @@ def main( max=1.0, help="Limits highest probable tokens (words).", ), + md: bool = typer.Option( + cfg.get("PRETTIFY_MARKDOWN") == "true", + help="Prettify markdown output.", + ), shell: bool = typer.Option( False, "--shell", @@ -203,7 +207,7 @@ def main( if repl: # Will be in infinite loop here until user exits with Ctrl+C. - ReplHandler(repl, role_class).handle( + ReplHandler(repl, role_class, md).handle( init_prompt=prompt, model=model, temperature=temperature, @@ -213,7 +217,7 @@ def main( ) if chat: - full_completion = ChatHandler(chat, role_class).handle( + full_completion = ChatHandler(chat, role_class, md).handle( prompt=prompt, model=model, temperature=temperature, @@ -222,7 +226,7 @@ def main( functions=function_schemas, ) else: - full_completion = DefaultHandler(role_class).handle( + full_completion = DefaultHandler(role_class, md).handle( prompt=prompt, model=model, temperature=temperature, @@ -243,7 +247,7 @@ def main( # "y" option is for keeping compatibility with old version. run_command(full_completion) elif option == "d": - DefaultHandler(DefaultRoles.DESCRIBE_SHELL.get_role()).handle( + DefaultHandler(DefaultRoles.DESCRIBE_SHELL.get_role(), md).handle( full_completion, model=model, temperature=temperature, diff --git a/sgpt/config.py b/sgpt/config.py index 20c71048..4c603470 100644 --- a/sgpt/config.py +++ b/sgpt/config.py @@ -32,6 +32,7 @@ "OPENAI_USE_FUNCTIONS": os.getenv("OPENAI_USE_FUNCTIONS", "true"), "SHOW_FUNCTIONS_OUTPUT": os.getenv("SHOW_FUNCTIONS_OUTPUT", "false"), "API_BASE_URL": os.getenv("API_BASE_URL", "default"), + "PRETTIFY_MARKDOWN": os.getenv("PRETTIFY_MARKDOWN", "true"), # New features might add their own config variables here. } diff --git a/sgpt/handlers/chat_handler.py b/sgpt/handlers/chat_handler.py index 8704507d..67190d16 100644 --- a/sgpt/handlers/chat_handler.py +++ b/sgpt/handlers/chat_handler.py @@ -94,8 +94,8 @@ def list(self) -> List[Path]: class ChatHandler(Handler): chat_session = ChatSession(CHAT_CACHE_LENGTH, CHAT_CACHE_PATH) - def __init__(self, chat_id: str, role: SystemRole) -> None: - super().__init__(role) + def __init__(self, chat_id: str, role: SystemRole, markdown: bool) -> None: + super().__init__(role, markdown) self.chat_id = chat_id self.role = role diff --git a/sgpt/handlers/default_handler.py b/sgpt/handlers/default_handler.py index 3a56f723..e0fdad13 100644 --- a/sgpt/handlers/default_handler.py +++ b/sgpt/handlers/default_handler.py @@ -10,8 +10,8 @@ class DefaultHandler(Handler): - def __init__(self, role: SystemRole) -> None: - super().__init__(role) + def __init__(self, role: SystemRole, markdown: bool) -> None: + super().__init__(role, markdown) self.role = role def make_messages(self, prompt: str) -> List[Dict[str, str]]: diff --git a/sgpt/handlers/handler.py b/sgpt/handlers/handler.py index 0cc7dc40..d4d05c2a 100644 --- a/sgpt/handlers/handler.py +++ b/sgpt/handlers/handler.py @@ -16,18 +16,23 @@ class Handler: cache = Cache(int(cfg.get("CACHE_LENGTH")), Path(cfg.get("CACHE_PATH"))) - def __init__(self, role: SystemRole) -> None: + def __init__(self, role: SystemRole, markdown: bool) -> None: self.role = role api_base_url = cfg.get("API_BASE_URL") self.base_url = None if api_base_url == "default" else api_base_url self.timeout = int(cfg.get("REQUEST_TIMEOUT")) + self.markdown = "APPLY MARKDOWN" in self.role.role and markdown + self.code_theme, self.color = cfg.get("CODE_THEME"), cfg.get("DEFAULT_COLOR") + @property def printer(self) -> Printer: - use_markdown = "APPLY MARKDOWN" in self.role.role - code_theme, color = cfg.get("CODE_THEME"), cfg.get("DEFAULT_COLOR") - return MarkdownPrinter(code_theme) if use_markdown else TextPrinter(color) + return ( + MarkdownPrinter(self.code_theme) + if self.markdown + else TextPrinter(self.color) + ) def make_messages(self, prompt: str) -> List[Dict[str, str]]: raise NotImplementedError diff --git a/sgpt/handlers/repl_handler.py b/sgpt/handlers/repl_handler.py index 0014bb61..df03f7df 100644 --- a/sgpt/handlers/repl_handler.py +++ b/sgpt/handlers/repl_handler.py @@ -11,8 +11,8 @@ class ReplHandler(ChatHandler): - def __init__(self, chat_id: str, role: SystemRole) -> None: - super().__init__(chat_id, role) + def __init__(self, chat_id: str, role: SystemRole, markdown: bool) -> None: + super().__init__(chat_id, role, markdown) @classmethod def _get_multiline_input(cls) -> str: @@ -59,8 +59,8 @@ def handle(self, init_prompt: str, **kwargs: Any) -> None: # type: ignore typer.echo() rich_print(Rule(style="bold magenta")) elif self.role.name == DefaultRoles.SHELL.value and prompt == "d": - DefaultHandler(DefaultRoles.DESCRIBE_SHELL.get_role()).handle( - prompt=full_completion, **kwargs - ) + DefaultHandler( + DefaultRoles.DESCRIBE_SHELL.get_role(), self.markdown + ).handle(prompt=full_completion, **kwargs) else: full_completion = super().handle(prompt=prompt, **kwargs) diff --git a/tests/test_code.py b/tests/test_code.py index 1b49f8da..e53166db 100644 --- a/tests/test_code.py +++ b/tests/test_code.py @@ -10,17 +10,32 @@ @patch("litellm.completion") -def test_code_generation(mock): - mock.return_value = mock_comp("print('Hello World')") +def test_code_generation(completion): + completion.return_value = mock_comp("print('Hello World')") args = {"prompt": "hello world python", "--code": True} result = runner.invoke(app, cmd_args(**args)) - mock.assert_called_once_with(**comp_args(role, args["prompt"])) + completion.assert_called_once_with(**comp_args(role, args["prompt"])) assert result.exit_code == 0 assert "print('Hello World')" in result.stdout +@patch("sgpt.printer.TextPrinter.live_print") +@patch("sgpt.printer.MarkdownPrinter.live_print") +@patch("litellm.completion") +def test_code_generation_no_markdown(completion, markdown_printer, text_printer): + completion.return_value = mock_comp("print('Hello World')") + + args = {"prompt": "make a commit using git", "--code": True, "--md": True} + result = runner.invoke(app, cmd_args(**args)) + + assert result.exit_code == 0 + # Should ignore --md for --code option and output code without markdown. + markdown_printer.assert_not_called() + text_printer.assert_called() + + @patch("litellm.completion") def test_code_generation_stdin(completion): completion.return_value = mock_comp("# Hello\nprint('Hello')") diff --git a/tests/test_default.py b/tests/test_default.py index 36be294f..b1c95aa5 100644 --- a/tests/test_default.py +++ b/tests/test_default.py @@ -223,3 +223,29 @@ def test_version(completion): completion.assert_not_called() assert __version__ in result.stdout + + +@patch("sgpt.printer.TextPrinter.live_print") +@patch("sgpt.printer.MarkdownPrinter.live_print") +@patch("litellm.completion") +def test_markdown(completion, markdown_printer, text_printer): + completion.return_value = mock_comp("pong") + + args = {"prompt": "ping", "--md": True} + result = runner.invoke(app, cmd_args(**args)) + assert result.exit_code == 0 + markdown_printer.assert_called() + text_printer.assert_not_called() + + +@patch("sgpt.printer.TextPrinter.live_print") +@patch("sgpt.printer.MarkdownPrinter.live_print") +@patch("litellm.completion") +def test_no_markdown(completion, markdown_printer, text_printer): + completion.return_value = mock_comp("pong") + + args = {"prompt": "ping", "--no-md": True} + result = runner.invoke(app, cmd_args(**args)) + assert result.exit_code == 0 + markdown_printer.assert_not_called() + text_printer.assert_called() diff --git a/tests/test_shell.py b/tests/test_shell.py index a008625a..f99f4936 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -22,6 +22,21 @@ def test_shell(completion): assert "[E]xecute, [D]escribe, [A]bort:" in result.stdout +@patch("sgpt.printer.TextPrinter.live_print") +@patch("sgpt.printer.MarkdownPrinter.live_print") +@patch("litellm.completion") +def test_shell_no_markdown(completion, markdown_printer, text_printer): + completion.return_value = mock_comp("git commit -m test") + + args = {"prompt": "make a commit using git", "--shell": True, "--md": True} + result = runner.invoke(app, cmd_args(**args)) + + assert result.exit_code == 0 + # Should ignore --md for --shell option and output text without markdown. + markdown_printer.assert_not_called() + text_printer.assert_called() + + @patch("litellm.completion") def test_shell_stdin(completion): completion.return_value = mock_comp("ls -l | sort")