Skip to content

Commit

Permalink
Markdown option and config variable
Browse files Browse the repository at this point in the history
  • Loading branch information
TheR1D committed Feb 17, 2024
1 parent ecb7b26 commit b278c30
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 17 deletions.
12 changes: 8 additions & 4 deletions sgpt/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def main(
max=1.0,
help="Limits highest probable tokens (words).",
),
md: bool = typer.Option(
cfg.get("PRETTIFY_MARKDOWN") == "true",
help="Prettify markdown output.",
),
shell: bool = typer.Option(
False,
"--shell",
Expand Down Expand Up @@ -203,7 +207,7 @@ def main(

if repl:
# Will be in infinite loop here until user exits with Ctrl+C.
ReplHandler(repl, role_class).handle(
ReplHandler(repl, role_class, md).handle(
init_prompt=prompt,
model=model,
temperature=temperature,
Expand All @@ -213,7 +217,7 @@ def main(
)

if chat:
full_completion = ChatHandler(chat, role_class).handle(
full_completion = ChatHandler(chat, role_class, md).handle(
prompt=prompt,
model=model,
temperature=temperature,
Expand All @@ -222,7 +226,7 @@ def main(
functions=function_schemas,
)
else:
full_completion = DefaultHandler(role_class).handle(
full_completion = DefaultHandler(role_class, md).handle(
prompt=prompt,
model=model,
temperature=temperature,
Expand All @@ -243,7 +247,7 @@ def main(
# "y" option is for keeping compatibility with old version.
run_command(full_completion)
elif option == "d":
DefaultHandler(DefaultRoles.DESCRIBE_SHELL.get_role()).handle(
DefaultHandler(DefaultRoles.DESCRIBE_SHELL.get_role(), md).handle(
full_completion,
model=model,
temperature=temperature,
Expand Down
1 change: 1 addition & 0 deletions sgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"OPENAI_USE_FUNCTIONS": os.getenv("OPENAI_USE_FUNCTIONS", "true"),
"SHOW_FUNCTIONS_OUTPUT": os.getenv("SHOW_FUNCTIONS_OUTPUT", "false"),
"API_BASE_URL": os.getenv("API_BASE_URL", "default"),
"PRETTIFY_MARKDOWN": os.getenv("PRETTIFY_MARKDOWN", "true"),
# New features might add their own config variables here.
}

Expand Down
4 changes: 2 additions & 2 deletions sgpt/handlers/chat_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def list(self) -> List[Path]:
class ChatHandler(Handler):
chat_session = ChatSession(CHAT_CACHE_LENGTH, CHAT_CACHE_PATH)

def __init__(self, chat_id: str, role: SystemRole) -> None:
super().__init__(role)
def __init__(self, chat_id: str, role: SystemRole, markdown: bool) -> None:
super().__init__(role, markdown)
self.chat_id = chat_id
self.role = role

Expand Down
4 changes: 2 additions & 2 deletions sgpt/handlers/default_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@


class DefaultHandler(Handler):
def __init__(self, role: SystemRole) -> None:
super().__init__(role)
def __init__(self, role: SystemRole, markdown: bool) -> None:
super().__init__(role, markdown)
self.role = role

def make_messages(self, prompt: str) -> List[Dict[str, str]]:
Expand Down
15 changes: 11 additions & 4 deletions sgpt/handlers/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,25 @@
class Handler:
cache = Cache(int(cfg.get("CACHE_LENGTH")), Path(cfg.get("CACHE_PATH")))

def __init__(self, role: SystemRole) -> None:
def __init__(self, role: SystemRole, markdown: bool) -> None:
self.role = role

api_base_url = cfg.get("API_BASE_URL")
self.base_url = None if api_base_url == "default" else api_base_url
self.timeout = int(cfg.get("REQUEST_TIMEOUT"))

config_markdown = cfg.get("PRETTIFY_MARKDOWN") == "true"
role_markdown = "APPLY MARKDOWN" in self.role.role
self.markdown = markdown and config_markdown and role_markdown
self.code_theme, self.color = cfg.get("CODE_THEME"), cfg.get("DEFAULT_COLOR")

@property
def printer(self) -> Printer:
use_markdown = "APPLY MARKDOWN" in self.role.role
code_theme, color = cfg.get("CODE_THEME"), cfg.get("DEFAULT_COLOR")
return MarkdownPrinter(code_theme) if use_markdown else TextPrinter(color)
return (
MarkdownPrinter(self.code_theme)
if self.markdown
else TextPrinter(self.color)
)

def make_messages(self, prompt: str) -> List[Dict[str, str]]:
raise NotImplementedError
Expand Down
10 changes: 5 additions & 5 deletions sgpt/handlers/repl_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@


class ReplHandler(ChatHandler):
def __init__(self, chat_id: str, role: SystemRole) -> None:
super().__init__(chat_id, role)
def __init__(self, chat_id: str, role: SystemRole, markdown: bool) -> None:
super().__init__(chat_id, role, markdown)

@classmethod
def _get_multiline_input(cls) -> str:
Expand Down Expand Up @@ -59,8 +59,8 @@ def handle(self, init_prompt: str, **kwargs: Any) -> None: # type: ignore
typer.echo()
rich_print(Rule(style="bold magenta"))
elif self.role.name == DefaultRoles.SHELL.value and prompt == "d":
DefaultHandler(DefaultRoles.DESCRIBE_SHELL.get_role()).handle(
prompt=full_completion, **kwargs
)
DefaultHandler(
DefaultRoles.DESCRIBE_SHELL.get_role(), self.markdown
).handle(prompt=full_completion, **kwargs)
else:
full_completion = super().handle(prompt=prompt, **kwargs)

0 comments on commit b278c30

Please sign in to comment.