From 944a98987e6223962fd49b1fd8c405591b8b09ea Mon Sep 17 00:00:00 2001 From: chansung Date: Fri, 14 Jul 2023 03:57:26 +0000 Subject: [PATCH] add internet search cap --- app.py | 21 ++++++-- chats/alpaca.py | 83 +++++------------------------- chats/alpaca_gpt4.py | 83 +++++------------------------- chats/alpacoom.py | 77 +++++----------------------- chats/baize.py | 68 +++++-------------------- chats/central.py | 113 ++++++++++++++++++++++++++++------------- chats/custom.py | 67 +++++------------------- chats/falcon.py | 85 +++++-------------------------- chats/flan_alpaca.py | 83 +++++------------------------- chats/guanaco.py | 87 ++++++------------------------- chats/koalpaca.py | 83 +++++------------------------- chats/mpt.py | 85 +++++-------------------------- chats/os_stablelm.py | 87 ++++++------------------------- chats/redpajama.py | 87 +++++-------------------------- chats/stable_vicuna.py | 83 +++++------------------------- chats/stablelm.py | 87 ++++++------------------------- chats/starchat.py | 85 +++++-------------------------- chats/utils.py | 52 +++++++++++++++++++ chats/vicuna.py | 83 +++++------------------------- chats/wizard_coder.py | 85 +++++-------------------------- chats/wizard_falcon.py | 85 +++++-------------------------- chats/xgen.py | 35 +++++-------- discord_app.py | 99 +++++++++++++++++++++++++++++------- discordbot/flags.py | 9 ++-- discordbot/req.py | 20 ++++++-- entry_point.py | 3 ++ 26 files changed, 507 insertions(+), 1328 deletions(-) create mode 100644 chats/utils.py diff --git a/app.py b/app.py index d30c35dd..388dfed8 100644 --- a/app.py +++ b/app.py @@ -1,3 +1,4 @@ +import os import time import json import copy @@ -828,6 +829,15 @@ def gradio_main(args): elem_id="global-context" ) + gr.Markdown("#### Internet search") + with gr.Row(): + internet_option = gr.Radio(choices=["on", "off"], value="off", label="mode") + serper_api_key = gr.Textbox( + value= "" if args.serper_api_key is None else args.serper_api_key, + placeholder="Get one by visiting serper.dev", + label="Serper api key" + ) + gr.Markdown("#### GenConfig for **response** text generation") with gr.Row(): res_temp = gr.Slider(0.0, 2.0, 0, step=0.1, label="temp", interactive=True) @@ -998,7 +1008,8 @@ def gradio_main(args): [idx, local_data, instruction_txtbox, chat_state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid], + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key], [instruction_txtbox, chatbot, context_inspector, local_data], ) @@ -1016,8 +1027,9 @@ def gradio_main(args): [idx, local_data, instruction_txtbox, chat_state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid], - [instruction_txtbox, chatbot, context_inspector, local_data], + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key], + [instruction_txtbox, chatbot, context_inspector, local_data], ).then( lambda: gr.update(interactive=True), None, @@ -1065,9 +1077,10 @@ def gradio_main(args): if __name__ == "__main__": parser = 
argparse.ArgumentParser() parser.add_argument('--root-path', default="") parser.add_argument('--local-files-only', default=False, action=argparse.BooleanOptionalAction) parser.add_argument('--share', default=False, action=argparse.BooleanOptionalAction) parser.add_argument('--debug', default=False, action=argparse.BooleanOptionalAction) + parser.add_argument('--serper-api-key', default=None, type=str) args = parser.parse_args() gradio_main(args) diff --git a/chats/alpaca.py b/chats/alpaca.py index f6f06c0e..20b610d6 100644 --- a/chats/alpaca.py +++ b/chats/alpaca.py @@ -5,64 +5,14 @@ from pingpong import PingPong from gens.batch_gen import get_output_batch -from pingpong.context import CtxLastWindowStrategy - -def build_prompts(ppmanager, user_message, global_context, win_size=3): - dummy_ppm = copy.deepcopy(ppmanager) - - dummy_ppm.ctx = global_context - for pingpong in dummy_ppm.pingpongs: - pong = pingpong.pong - first_sentence = pong.split("\n")[0] - if first_sentence != "" and \ - pre.contains_image_markdown(first_sentence): - pong = ' '.join(pong.split("\n")[1:]).strip() - pingpong.pong = pong - - lws = CtxLastWindowStrategy(win_size) - - prompt = lws(dummy_ppm) - return prompt - -def text_stream(ppmanager, streamer): - count = 0 - - for new_text in streamer: - if count == 0: - ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n") - count = count + 1 - - ppmanager.append_pong(new_text) - yield ppmanager, ppmanager.build_uis() - - yield ppmanager, ppmanager.build_uis() - -def summarize( - ppmanager, prompt_to_summarize, win_size, - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id -): - ctx = ppmanager.ctx - last_pong = ppmanager.pingpongs[-1].pong - ppmanager.add_pingpong(PingPong(prompt_to_summarize, "")) - prompt = ppmanager.build_prompts(from_idx=-win_size) - - _, gen_config_summarization = pre.build_gen_config( - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id - ) - summarize_output = get_output_batch( - global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization - )[0].split("### Response:")[-1].strip() - ppmanager.ctx = summarize_output - ppmanager.pop_pingpong() - return ppmanager +from chats.utils import build_prompts, text_stream, internet_search def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ): res = [ state["ppmanager_type"].from_json(json.dumps(ppm)) @@ -75,11 +25,18 @@ def chat_stream( ppm.add_pingpong( PingPong(user_message, "") ) - prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv) - + prompt = build_prompts(ppm, global_context, ctx_num_lconv) + + ####### + if internet_option: + search_prompt = None + for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): + search_prompt = tmp_prompt + yield "", uis, prompt, str(res) + # prepare text generating streamer & start generating gen_kwargs,
streamer = pre.build( - prompt, + search_prompt if internet_option else prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, return_token_type_ids=False @@ -91,18 +48,4 @@ def chat_stream( yield "", uis, prompt, str(res) ppm = post.strip_pong(ppm) - yield "", ppm.build_uis(), prompt, str(res) - - # summarization - # ppm.add_pingpong( - # PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)") - # ) - # yield "", ppm.build_uis(), prompt, state - # ppm.pop_pingpong() - - # ppm = summarize( - # ppm, ctx_sum_prompt, ctx_num_lconv, - # sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, - # sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid - # ) yield "", ppm.build_uis(), prompt, str(res) \ No newline at end of file diff --git a/chats/alpaca_gpt4.py b/chats/alpaca_gpt4.py index f6f06c0e..20b610d6 100644 --- a/chats/alpaca_gpt4.py +++ b/chats/alpaca_gpt4.py @@ -5,64 +5,14 @@ from pingpong import PingPong from gens.batch_gen import get_output_batch -from pingpong.context import CtxLastWindowStrategy - -def build_prompts(ppmanager, user_message, global_context, win_size=3): - dummy_ppm = copy.deepcopy(ppmanager) - - dummy_ppm.ctx = global_context - for pingpong in dummy_ppm.pingpongs: - pong = pingpong.pong - first_sentence = pong.split("\n")[0] - if first_sentence != "" and \ - pre.contains_image_markdown(first_sentence): - pong = ' '.join(pong.split("\n")[1:]).strip() - pingpong.pong = pong - - lws = CtxLastWindowStrategy(win_size) - - prompt = lws(dummy_ppm) - return prompt - -def text_stream(ppmanager, streamer): - count = 0 - - for new_text in streamer: - if count == 0: - ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n") - count = count + 1 - - ppmanager.append_pong(new_text) - yield ppmanager, ppmanager.build_uis() - - yield ppmanager, ppmanager.build_uis() - -def summarize( - ppmanager, prompt_to_summarize, win_size, - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id -): - ctx = ppmanager.ctx - last_pong = ppmanager.pingpongs[-1].pong - ppmanager.add_pingpong(PingPong(prompt_to_summarize, "")) - prompt = ppmanager.build_prompts(from_idx=-win_size) - - _, gen_config_summarization = pre.build_gen_config( - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id - ) - summarize_output = get_output_batch( - global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization - )[0].split("### Response:")[-1].strip() - ppmanager.ctx = summarize_output - ppmanager.pop_pingpong() - return ppmanager +from chats.utils import build_prompts, text_stream, internet_search def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ): res = [ state["ppmanager_type"].from_json(json.dumps(ppm)) @@ -75,11 +25,18 @@ def chat_stream( ppm.add_pingpong( PingPong(user_message, "") ) - prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv) - + prompt = build_prompts(ppm, global_context, ctx_num_lconv) + + ####### + if 
internet_option: + search_prompt = None + for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): + search_prompt = tmp_prompt + yield "", uis, prompt, str(res) + # prepare text generating streamer & start generating gen_kwargs, streamer = pre.build( - prompt, + search_prompt if internet_option else prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, return_token_type_ids=False @@ -91,18 +48,4 @@ def chat_stream( yield "", uis, prompt, str(res) ppm = post.strip_pong(ppm) - yield "", ppm.build_uis(), prompt, str(res) - - # summarization - # ppm.add_pingpong( - # PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)") - # ) - # yield "", ppm.build_uis(), prompt, state - # ppm.pop_pingpong() - - # ppm = summarize( - # ppm, ctx_sum_prompt, ctx_num_lconv, - # sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, - # sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid - # ) yield "", ppm.build_uis(), prompt, str(res) \ No newline at end of file diff --git a/chats/alpacoom.py b/chats/alpacoom.py index 51dd51d2..20b610d6 100644 --- a/chats/alpacoom.py +++ b/chats/alpacoom.py @@ -5,58 +5,14 @@ from pingpong import PingPong from gens.batch_gen import get_output_batch -from pingpong.context import CtxLastWindowStrategy - -def build_prompts(ppmanager, user_message, global_context, win_size=3): - dummy_ppm = copy.deepcopy(ppmanager) - - dummy_ppm.ctx = global_context - for pingpong in dummy_ppm.pingpongs: - pong = pingpong.pong - first_sentence = pong.split(" ")[0] - if first_sentence != "" and \ - pre.contains_image_markdown(first_sentence): - pong = ' '.join(pong.split(" ")[1:]) - pingpong.pong = pong - - lws = CtxLastWindowStrategy(win_size) - - prompt = lws(dummy_ppm) - return prompt - -def text_stream(ppmanager, streamer): - for new_text in streamer: - ppmanager.append_pong(new_text) - yield ppmanager, ppmanager.build_uis() - - yield ppmanager, ppmanager.build_uis() - -def summarize( - ppmanager, prompt_to_summarize, win_size, - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id -): - ctx = ppmanager.ctx - last_pong = ppmanager.pingpongs[-1].pong - ppmanager.add_pingpong(PingPong(prompt_to_summarize, "")) - prompt = ppmanager.build_prompts(from_idx=-win_size) - - _, gen_config_summarization = pre.build_gen_config( - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id - ) - summarize_output = get_output_batch( - global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization - )[0].split("### Response:")[-1].strip() - ppmanager.ctx = summarize_output - ppmanager.pop_pingpong() - return ppmanager +from chats.utils import build_prompts, text_stream, internet_search def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ): res = [ state["ppmanager_type"].from_json(json.dumps(ppm)) @@ -69,11 +25,18 @@ def chat_stream( ppm.add_pingpong( PingPong(user_message, "") ) - prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv) 
- + prompt = build_prompts(ppm, global_context, ctx_num_lconv) + + ####### + if internet_option: + search_prompt = None + for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): + search_prompt = tmp_prompt + yield "", uis, prompt, str(res) + # prepare text generating streamer & start generating gen_kwargs, streamer = pre.build( - prompt, + search_prompt if internet_option else prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, return_token_type_ids=False @@ -85,18 +48,4 @@ def chat_stream( yield "", uis, prompt, str(res) ppm = post.strip_pong(ppm) - yield "", ppm.build_uis(), prompt, str(res) - - # summarization - # ppm.add_pingpong( - # PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)") - # ) - # yield "", ppm.build_uis(), prompt, state - # ppm.pop_pingpong() - - # ppm = summarize( - # ppm, ctx_sum_prompt, ctx_num_lconv, - # sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, - # sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid - # ) yield "", ppm.build_uis(), prompt, str(res) \ No newline at end of file diff --git a/chats/baize.py b/chats/baize.py index b634a893..f72d2a09 100644 --- a/chats/baize.py +++ b/chats/baize.py @@ -5,24 +5,7 @@ from pingpong import PingPong from gens.batch_gen import get_output_batch -from pingpong.context import CtxLastWindowStrategy - -def build_prompts(ppmanager, user_message, global_context, win_size=3): - dummy_ppm = copy.deepcopy(ppmanager) - - dummy_ppm.ctx = global_context - for pingpong in dummy_ppm.pingpongs: - pong = pingpong.pong - first_sentence = pong.split("\n")[0] - if first_sentence != "" and \ - pre.contains_image_markdown(first_sentence): - pong = ' '.join(pong.split("\n")[1:]).strip() - pingpong.pong = pong - - lws = CtxLastWindowStrategy(win_size) - - prompt = lws(dummy_ppm) - return prompt +from chats.utils import build_prompts, internet_search def text_stream(ppmanager, streamer): count = 0 @@ -41,32 +24,12 @@ def text_stream(ppmanager, streamer): yield ppmanager, ppmanager.build_uis() -def summarize( - ppmanager, prompt_to_summarize, win_size, - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id -): - ctx = ppmanager.ctx - last_pong = ppmanager.pingpongs[-1].pong - ppmanager.add_pingpong(PingPong(prompt_to_summarize, "")) - prompt = ppmanager.build_prompts(from_idx=-win_size) - - _, gen_config_summarization = pre.build_gen_config( - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id - ) - summarize_output = get_output_batch( - global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization - )[0].split("### Response:")[-1].strip() - ppmanager.ctx = summarize_output - ppmanager.pop_pingpong() - return ppmanager - def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ): res = [ state["ppmanager_type"].from_json(json.dumps(ppm)) @@ -79,11 +42,18 @@ def chat_stream( ppm.add_pingpong( PingPong(user_message, "") ) - prompt = build_prompts(ppm, user_message, 
global_context, ctx_num_lconv) + prompt = build_prompts(ppm, global_context, ctx_num_lconv) + ####### + if internet_option: + search_prompt = None + for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): + search_prompt = tmp_prompt + yield "", uis, prompt, str(res) + # prepare text generating streamer & start generating gen_kwargs, streamer = pre.build( - prompt, + search_prompt if internet_option else prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, return_token_type_ids=False @@ -95,18 +65,4 @@ def chat_stream( yield "", uis, prompt, str(res) ppm = post.strip_pong(ppm) - yield "", ppm.build_uis(), prompt, str(res) - - # summarization - # ppm.add_pingpong( - # PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)") - # ) - # yield "", ppm.build_uis(), prompt, state - # ppm.pop_pingpong() - - # ppm = summarize( - # ppm, ctx_sum_prompt, ctx_num_lconv, - # sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, - # sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid - # ) yield "", ppm.build_uis(), prompt, str(res) \ No newline at end of file diff --git a/chats/central.py b/chats/central.py index 0f4ca1fb..446fffa0 100644 --- a/chats/central.py +++ b/chats/central.py @@ -21,16 +21,23 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ): model_type = state["model_type"] + + if internet_option == "on" and serper_api_key.strip() != "": + internet_option = True + else: + internet_option = False if model_type == "custom": cs = custom.chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "xgen": @@ -38,7 +45,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "stablelm": @@ -46,7 +54,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "falcon": @@ -54,7 +63,8 @@ def chat_stream( idx, local_data, user_message, state, 
global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "wizard-falcon": @@ -62,7 +72,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "baize": @@ -70,7 +81,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "alpaca": @@ -78,7 +90,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "openllama": @@ -86,7 +99,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "orcamini": @@ -94,7 +108,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "alpaca-gpt4": @@ -102,7 +117,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "nous-hermes": @@ -110,7 +126,8 @@ def chat_stream( idx, local_data, 
user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "replit-instruct": @@ -118,7 +135,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "alpacoom": @@ -126,7 +144,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "llama-deus": @@ -134,7 +153,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "camel": @@ -142,7 +162,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "koalpaca-polyglot": @@ -150,7 +171,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "kullm-polyglot": @@ -158,7 +180,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "flan-alpaca": @@ -166,7 
+189,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "os-stablelm": @@ -174,7 +198,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "t5-vicuna": @@ -182,7 +207,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "stable-vicuna": @@ -190,7 +216,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "vicuna": @@ -198,7 +225,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "wizardlm": @@ -206,7 +234,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "wizard-vicuna": @@ -214,7 +243,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif 
model_type == "airoboros": @@ -222,7 +252,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "samantha-vicuna": @@ -230,7 +261,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "evolinstruct-vicuna": @@ -238,7 +270,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "starchat": @@ -246,7 +279,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "wizard-coder": @@ -254,7 +288,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "mpt": @@ -262,7 +297,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "redpajama": @@ -270,7 +306,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + 
internet_option, serper_api_key ) elif model_type == "redpajama-instruct": @@ -278,7 +315,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "guanaco": @@ -286,7 +324,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "lazarus": @@ -294,7 +333,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) elif model_type == "chronos": @@ -302,7 +342,8 @@ def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ) for idx, x in enumerate(cs): diff --git a/chats/custom.py b/chats/custom.py index a2d9feaf..510fb2f6 100644 --- a/chats/custom.py +++ b/chats/custom.py @@ -5,24 +5,8 @@ from pingpong import PingPong from gens.batch_gen import get_output_batch -from pingpong.context import CtxLastWindowStrategy +from chats.utils import build_prompts, internet_search -def build_prompts(ppmanager, user_message, global_context, win_size=3): - dummy_ppm = copy.deepcopy(ppmanager) - - dummy_ppm.ctx = global_context - for pingpong in dummy_ppm.pingpongs: - pong = pingpong.pong - first_sentence = pong.split("\n")[0] - if first_sentence != "" and \ - pre.contains_image_markdown(first_sentence): - pong = ' '.join(pong.split("\n")[1:]).strip() - pingpong.pong = pong - - lws = CtxLastWindowStrategy(win_size) - - prompt = lws(dummy_ppm) - return prompt def text_stream(ppmanager, streamer): count = 0 thumbnail_tiny = "https://i.ibb.co/f80BpgR/byom.png" @@ -37,32 +21,12 @@ def text_stream(ppmanager, streamer): yield ppmanager, ppmanager.build_uis() -def summarize( - ppmanager, prompt_to_summarize, win_size, - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id -): - ctx = ppmanager.ctx - last_pong = ppmanager.pingpongs[-1].pong - ppmanager.add_pingpong(PingPong(prompt_to_summarize, "")) - prompt = ppmanager.build_prompts(from_idx=-win_size) - - _, gen_config_summarization = pre.build_gen_config( - 
temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id - ) - summarize_output = get_output_batch( - global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization - )[0].split("### Response:")[-1].strip() - ppmanager.ctx = summarize_output - ppmanager.pop_pingpong() - return ppmanager - def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ): res = [ state["ppmanager_type"].from_json(json.dumps(ppm)) @@ -75,11 +39,18 @@ def chat_stream( ppm.add_pingpong( PingPong(user_message, "") ) - prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv) + prompt = build_prompts(ppm, global_context, ctx_num_lconv) + ####### + if internet_option: + search_prompt = None + for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): + search_prompt = tmp_prompt + yield "", uis, prompt, str(res) + # prepare text generating streamer & start generating gen_kwargs, streamer = pre.build( - prompt, + search_prompt if internet_option else prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, return_token_type_ids=False @@ -91,18 +62,4 @@ def chat_stream( yield "", uis, prompt, str(res) ppm = post.strip_pong(ppm) - yield "", ppm.build_uis(), prompt, str(res) - - # summarization - # ppm.add_pingpong( - # PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)") - # ) - # yield "", ppm.build_uis(), prompt, state - # ppm.pop_pingpong() - - # ppm = summarize( - # ppm, ctx_sum_prompt, ctx_num_lconv, - # sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, - # sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid - # ) yield "", ppm.build_uis(), prompt, str(res) \ No newline at end of file diff --git a/chats/falcon.py b/chats/falcon.py index 287bbde3..79042bbf 100644 --- a/chats/falcon.py +++ b/chats/falcon.py @@ -8,7 +8,7 @@ from pingpong import PingPong from gens.batch_gen import get_output_batch -from pingpong.context import CtxLastWindowStrategy +from chats.utils import build_prompts, text_stream, internet_search class StopOnTokens(StoppingCriteria): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: @@ -18,62 +18,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa return True return False -def build_prompts(ppmanager, user_message, global_context, win_size=3): - dummy_ppm = copy.deepcopy(ppmanager) - - dummy_ppm.ctx = global_context - for pingpong in dummy_ppm.pingpongs: - pong = pingpong.pong - first_sentence = pong.split("\n")[0] - if first_sentence != "" and \ - pre.contains_image_markdown(first_sentence): - pong = ' '.join(pong.split("\n")[1:]).strip() - pingpong.pong = pong - - lws = CtxLastWindowStrategy(win_size) - - prompt = lws(dummy_ppm) - return prompt - -def text_stream(ppmanager, streamer): - count = 0 - - for new_text in streamer: - if count == 0: - ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n") - count = count + 1 - - ppmanager.append_pong(new_text) - yield ppmanager, 
ppmanager.build_uis() - - yield ppmanager, ppmanager.build_uis() - -def summarize( - ppmanager, prompt_to_summarize, win_size, - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id -): - ctx = ppmanager.ctx - last_pong = ppmanager.pingpongs[-1].pong - ppmanager.add_pingpong(PingPong(prompt_to_summarize, "")) - prompt = ppmanager.build_prompts(from_idx=-win_size) - - _, gen_config_summarization = pre.build_gen_config( - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id - ) - summarize_output = get_output_batch( - global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization - )[0].split("### Response:")[-1].split("### Input:")[0].strip() - ppmanager.ctx = summarize_output - ppmanager.pop_pingpong() - return ppmanager - def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ): res = [ state["ppmanager_type"].from_json(json.dumps(ppm)) @@ -86,14 +36,21 @@ def chat_stream( ppm.add_pingpong( PingPong(user_message, "") ) - prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv) - + prompt = build_prompts(ppm, global_context, ctx_num_lconv) + + ####### + if internet_option: + search_prompt = None + for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): + search_prompt = tmp_prompt + yield "", uis, prompt, str(res) + # prepare text generating streamer & start generating gen_kwargs, streamer = pre.build( - prompt, + search_prompt if internet_option else prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - StoppingCriteriaList([StopOnTokens()]), False + return_token_type_ids=False ) pre.start_gen(gen_kwargs) @@ -102,18 +59,4 @@ def chat_stream( yield "", uis, prompt, str(res) ppm = post.strip_pong(ppm) - yield "", ppm.build_uis(), prompt, str(res) - - # summarization - # ppm.add_pingpong( - # PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)") - # ) - # yield "", ppm.build_uis(), prompt, state - # ppm.pop_pingpong() - - # ppm = summarize( - # ppm, ctx_sum_prompt, ctx_num_lconv, - # sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, - # sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid - # ) yield "", ppm.build_uis(), prompt, str(res) \ No newline at end of file diff --git a/chats/flan_alpaca.py b/chats/flan_alpaca.py index f7478583..20b610d6 100644 --- a/chats/flan_alpaca.py +++ b/chats/flan_alpaca.py @@ -5,64 +5,14 @@ from pingpong import PingPong from gens.batch_gen import get_output_batch -from pingpong.context import CtxLastWindowStrategy - -def build_prompts(ppmanager, user_message, global_context, win_size=3): - dummy_ppm = copy.deepcopy(ppmanager) - - dummy_ppm.ctx = global_context - for pingpong in dummy_ppm.pingpongs: - pong = pingpong.pong - first_sentence = pong.split("\n")[0] - if first_sentence != "" and \ - pre.contains_image_markdown(first_sentence): - pong = ' '.join(pong.split("\n")[1:]).strip() - pingpong.pong = pong - - lws = CtxLastWindowStrategy(win_size) - - prompt = lws(dummy_ppm) - 
return prompt - -def text_stream(ppmanager, streamer): - count = 0 - - for new_text in streamer: - if count == 0: - ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n") - count = count + 1 - - ppmanager.append_pong(new_text) - yield ppmanager, ppmanager.build_uis() - - yield ppmanager, ppmanager.build_uis() - -def summarize( - ppmanager, prompt_to_summarize, win_size, - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id -): - ctx = ppmanager.ctx - last_pong = ppmanager.pingpongs[-1].pong - ppmanager.add_pingpong(PingPong(prompt_to_summarize, "")) - prompt = ppmanager.build_prompts(from_idx=-win_size) - - _, gen_config_summarization = pre.build_gen_config( - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id - ) - summarize_output = get_output_batch( - global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization - )[0].split("-----")[-1].strip() - ppmanager.ctx = summarize_output - ppmanager.pop_pingpong() - return ppmanager +from chats.utils import build_prompts, text_stream, internet_search def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ): res = [ state["ppmanager_type"].from_json(json.dumps(ppm)) @@ -75,11 +25,18 @@ def chat_stream( ppm.add_pingpong( PingPong(user_message, "") ) - prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv) - + prompt = build_prompts(ppm, global_context, ctx_num_lconv) + + ####### + if internet_option: + search_prompt = None + for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): + search_prompt = tmp_prompt + yield "", uis, prompt, str(res) + # prepare text generating streamer & start generating gen_kwargs, streamer = pre.build( - prompt, + search_prompt if internet_option else prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, return_token_type_ids=False @@ -91,18 +48,4 @@ def chat_stream( yield "", uis, prompt, str(res) ppm = post.strip_pong(ppm) - yield "", ppm.build_uis(), prompt, str(res) - - # summarization - # ppm.add_pingpong( - # PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)") - # ) - # yield "", ppm.build_uis(), prompt, state - # ppm.pop_pingpong() - - # ppm = summarize( - # ppm, ctx_sum_prompt, ctx_num_lconv, - # sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, - # sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid - # ) yield "", ppm.build_uis(), prompt, str(res) \ No newline at end of file diff --git a/chats/guanaco.py b/chats/guanaco.py index 29c6accf..14334328 100644 --- a/chats/guanaco.py +++ b/chats/guanaco.py @@ -8,7 +8,7 @@ from pingpong import PingPong from gens.batch_gen import get_output_batch -from pingpong.context import CtxLastWindowStrategy +from chats.utils import build_prompts, text_stream, internet_search class StopOnTokens(StoppingCriteria): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: @@ -19,62 +19,12 @@ def __call__(self, 
input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa return True return False -def build_prompts(ppmanager, user_message, global_context, win_size=3): - dummy_ppm = copy.deepcopy(ppmanager) - - dummy_ppm.ctx = global_context - for pingpong in dummy_ppm.pingpongs: - pong = pingpong.pong - first_sentence = pong.split("\n")[0] - if first_sentence != "" and \ - pre.contains_image_markdown(first_sentence): - pong = ' '.join(pong.split("\n")[1:]).strip() - pingpong.pong = pong - - lws = CtxLastWindowStrategy(win_size) - - prompt = lws(dummy_ppm) - return prompt - -def text_stream(ppmanager, streamer): - count = 0 - - for new_text in streamer: - if count == 0: - ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n") - count = count + 1 - - ppmanager.append_pong(new_text) - yield ppmanager, ppmanager.build_uis() - - yield ppmanager, ppmanager.build_uis() - -def summarize( - ppmanager, prompt_to_summarize, win_size, - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id -): - ctx = ppmanager.ctx - last_pong = ppmanager.pingpongs[-1].pong - ppmanager.add_pingpong(PingPong(prompt_to_summarize, "")) - prompt = ppmanager.build_prompts(from_idx=-win_size) - - _, gen_config_summarization = pre.build_gen_config( - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id - ) - summarize_output = get_output_batch( - global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization - )[0].split(prompt_to_summarize)[-1].strip() - ppmanager.ctx = summarize_output - ppmanager.pop_pingpong() - return ppmanager - def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ): res = [ state["ppmanager_type"].from_json(json.dumps(ppm)) @@ -87,14 +37,21 @@ def chat_stream( ppm.add_pingpong( PingPong(user_message, "") ) - prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv) - + prompt = build_prompts(ppm, global_context, ctx_num_lconv) + + ####### + if internet_option: + search_prompt = None + for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): + search_prompt = tmp_prompt + yield "", uis, prompt, str(res) + # prepare text generating streamer & start generating gen_kwargs, streamer = pre.build( - prompt, + search_prompt if internet_option else prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - StoppingCriteriaList([StopOnTokens()]), False + return_token_type_ids=False ) pre.start_gen(gen_kwargs) @@ -103,18 +60,4 @@ def chat_stream( yield "", uis, prompt, str(res) ppm = post.strip_pong(ppm) - yield "", ppm.build_uis(), prompt, str(res) - - # summarization - # ppm.add_pingpong( - # PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)") - # ) - # yield "", ppm.build_uis(), prompt, state - # ppm.pop_pingpong() - - # ppm = summarize( - # ppm, ctx_sum_prompt, ctx_num_lconv, - # sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, - # sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid - # ) - 
yield "", ppm.build_uis(), prompt, str(res) \ No newline at end of file + yield "", ppm.build_uis(), prompt, str(res) \ No newline at end of file diff --git a/chats/koalpaca.py b/chats/koalpaca.py index a27d8103..20b610d6 100644 --- a/chats/koalpaca.py +++ b/chats/koalpaca.py @@ -5,64 +5,14 @@ from pingpong import PingPong from gens.batch_gen import get_output_batch -from pingpong.context import CtxLastWindowStrategy - -def build_prompts(ppmanager, user_message, global_context, win_size=3): - dummy_ppm = copy.deepcopy(ppmanager) - - dummy_ppm.ctx = global_context - for pingpong in dummy_ppm.pingpongs: - pong = pingpong.pong - first_sentence = pong.split("\n")[0] - if first_sentence != "" and \ - pre.contains_image_markdown(first_sentence): - pong = ' '.join(pong.split("\n")[1:]).strip() - pingpong.pong = pong - - lws = CtxLastWindowStrategy(win_size) - - prompt = lws(dummy_ppm) - return prompt - -def text_stream(ppmanager, streamer): - count = 0 - - for new_text in streamer: - if count == 0: - ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n") - count = count + 1 - - ppmanager.append_pong(new_text) - yield ppmanager, ppmanager.build_uis() - - yield ppmanager, ppmanager.build_uis() - -def summarize( - ppmanager, prompt_to_summarize, win_size, - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id -): - ctx = ppmanager.ctx - last_pong = ppmanager.pingpongs[-1].pong - ppmanager.add_pingpong(PingPong(prompt_to_summarize, "")) - prompt = ppmanager.build_prompts(from_idx=-win_size) - - _, gen_config_summarization = pre.build_gen_config( - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id - ) - summarize_output = get_output_batch( - global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization - )[0].split("### 응답:")[-1].strip() - ppmanager.ctx = summarize_output - ppmanager.pop_pingpong() - return ppmanager +from chats.utils import build_prompts, text_stream, internet_search def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ): res = [ state["ppmanager_type"].from_json(json.dumps(ppm)) @@ -75,11 +25,18 @@ def chat_stream( ppm.add_pingpong( PingPong(user_message, "") ) - prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv) - + prompt = build_prompts(ppm, global_context, ctx_num_lconv) + + ####### + if internet_option: + search_prompt = None + for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): + search_prompt = tmp_prompt + yield "", uis, prompt, str(res) + # prepare text generating streamer & start generating gen_kwargs, streamer = pre.build( - prompt, + search_prompt if internet_option else prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, return_token_type_ids=False @@ -91,18 +48,4 @@ def chat_stream( yield "", uis, prompt, str(res) ppm = post.strip_pong(ppm) - yield "", ppm.build_uis(), prompt, str(res) - - # summarization - # ppm.add_pingpong( - # PingPong(None, 
"![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)") - # ) - # yield "", ppm.build_uis(), prompt, state - # ppm.pop_pingpong() - - # ppm = summarize( - # ppm, ctx_sum_prompt, ctx_num_lconv, - # sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, - # sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid - # ) yield "", ppm.build_uis(), prompt, str(res) \ No newline at end of file diff --git a/chats/mpt.py b/chats/mpt.py index 6c03bd15..eb69d597 100644 --- a/chats/mpt.py +++ b/chats/mpt.py @@ -8,7 +8,7 @@ from pingpong import PingPong from gens.batch_gen import get_output_batch -from pingpong.context import CtxLastWindowStrategy +from chats.utils import build_prompts, text_stream, internet_search class StopOnTokens(StoppingCriteria): def __init__(self, tokenizer): @@ -23,63 +23,13 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa if input_ids[0][-1] == stop_id: return True return False - -def build_prompts(ppmanager, user_message, global_context, win_size=3): - dummy_ppm = copy.deepcopy(ppmanager) - - dummy_ppm.ctx = global_context - for pingpong in dummy_ppm.pingpongs: - pong = pingpong.pong - first_sentence = pong.split("\n")[0] - if first_sentence != "" and \ - pre.contains_image_markdown(first_sentence): - pong = ' '.join(pong.split("\n")[1:]).strip() - pingpong.pong = pong - - lws = CtxLastWindowStrategy(win_size) - - prompt = lws(dummy_ppm) - return prompt - -def text_stream(ppmanager, streamer): - count = 0 - - for new_text in streamer: - if count == 0: - ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n") - count = count + 1 - - ppmanager.append_pong(new_text) - yield ppmanager, ppmanager.build_uis() - - yield ppmanager, ppmanager.build_uis() - -def summarize( - ppmanager, prompt_to_summarize, win_size, - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id -): - ctx = ppmanager.ctx - last_pong = ppmanager.pingpongs[-1].pong - ppmanager.add_pingpong(PingPong(prompt_to_summarize, "")) - prompt = ppmanager.build_prompts(from_idx=-win_size) - - _, gen_config_summarization = pre.build_gen_config( - temperature, top_p, top_k, repetition_penalty, max_new_tokens, - num_beams, use_cache, do_sample, eos_token_id, pad_token_id - ) - summarize_output = get_output_batch( - global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization - )[0].strip() - ppmanager.ctx = summarize_output - ppmanager.pop_pingpong() - return ppmanager def chat_stream( idx, local_data, user_message, state, global_context, ctx_num_lconv, ctx_sum_prompt, res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid, - sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid + sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid, + internet_option, serper_api_key ): res = [ state["ppmanager_type"].from_json(json.dumps(ppm)) @@ -92,14 +42,21 @@ def chat_stream( ppm.add_pingpong( PingPong(user_message, "") ) - prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv) + prompt = build_prompts(ppm, global_context, ctx_num_lconv) + ####### + if internet_option: + search_prompt = None + for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv): + search_prompt = tmp_prompt + yield "", uis, prompt, str(res) + # prepare text generating streamer & start generating 
     gen_kwargs, streamer = pre.build(
-        prompt,
+        search_prompt if internet_option else prompt,
         res_temp, res_topp, res_topk, res_rpen, res_mnts,
         res_beams, res_cache, res_sample, res_eosid, res_padid,
-        StoppingCriteriaList([StopOnTokens(global_vars.tokenizer)]), False
+        return_token_type_ids=False
     )
     pre.start_gen(gen_kwargs)
@@ -108,18 +65,4 @@ def chat_stream(
         yield "", uis, prompt, str(res)
 
     ppm = post.strip_pong(ppm)
-    yield "", ppm.build_uis(), prompt, str(res)
-
-    # summarization
-    # ppm.add_pingpong(
-    #     PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)")
-    # )
-    # yield "", ppm.build_uis(), prompt, state
-    # ppm.pop_pingpong()
-
-    # ppm = summarize(
-    #     ppm, ctx_sum_prompt, ctx_num_lconv,
-    #     sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts,
-    #     sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
-    # )
-    yield "", ppm.build_uis(), prompt, str(res)
\ No newline at end of file
+    yield "", ppm.build_uis(), prompt, str(res)
\ No newline at end of file
diff --git a/chats/os_stablelm.py b/chats/os_stablelm.py
index d9bdfda4..0781f5a3 100644
--- a/chats/os_stablelm.py
+++ b/chats/os_stablelm.py
@@ -8,7 +8,7 @@
 from pingpong import PingPong
 from gens.batch_gen import get_output_batch
 
-from pingpong.context import CtxLastWindowStrategy
+from chats.utils import build_prompts, text_stream, internet_search
 
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
@@ -18,62 +18,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
                 return True
         return False
 
-def build_prompts(ppmanager, user_message, global_context, win_size=3):
-    dummy_ppm = copy.deepcopy(ppmanager)
-
-    dummy_ppm.ctx = global_context
-    for pingpong in dummy_ppm.pingpongs:
-        pong = pingpong.pong
-        first_sentence = pong.split("\n")[0]
-        if first_sentence != "" and \
-            pre.contains_image_markdown(first_sentence):
-            pong = ' '.join(pong.split("\n")[1:]).strip()
-            pingpong.pong = pong
-
-    lws = CtxLastWindowStrategy(win_size)
-
-    prompt = lws(dummy_ppm)
-    return prompt
-
-def text_stream(ppmanager, streamer):
-    count = 0
-
-    for new_text in streamer:
-        if count == 0:
-            ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n")
-            count = count + 1
-
-        ppmanager.append_pong(new_text)
-        yield ppmanager, ppmanager.build_uis()
-
-    yield ppmanager, ppmanager.build_uis()
-
-def summarize(
-    ppmanager, prompt_to_summarize, win_size,
-    temperature, top_p, top_k, repetition_penalty, max_new_tokens,
-    num_beams, use_cache, do_sample, eos_token_id, pad_token_id
-):
-    ctx = ppmanager.ctx
-    last_pong = ppmanager.pingpongs[-1].pong
-    ppmanager.add_pingpong(PingPong(prompt_to_summarize, ""))
-    prompt = ppmanager.build_prompts(from_idx=-win_size)
-
-    _, gen_config_summarization = pre.build_gen_config(
-        temperature, top_p, top_k, repetition_penalty, max_new_tokens,
-        num_beams, use_cache, do_sample, eos_token_id, pad_token_id
-    )
-    summarize_output = get_output_batch(
-        global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization
-    )[0].split(prompt_to_summarize)[-1].strip()
-    ppmanager.ctx = summarize_output
-    ppmanager.pop_pingpong()
-    return ppmanager
-
 def chat_stream(
     idx, local_data, user_message, state,
     global_context, ctx_num_lconv, ctx_sum_prompt,
     res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
-    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
+    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
+    internet_option, serper_api_key
 ):
     res = [
         state["ppmanager_type"].from_json(json.dumps(ppm))
@@ -86,14 +36,21 @@ def chat_stream(
     ppm.add_pingpong(
         PingPong(user_message, "")
     )
-    prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv)
-
+    prompt = build_prompts(ppm, global_context, ctx_num_lconv)
+
+    #######
+    if internet_option:
+        search_prompt = None
+        for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
+            search_prompt = tmp_prompt
+            yield "", uis, prompt, str(res)
+
     # prepare text generating streamer & start generating
     gen_kwargs, streamer = pre.build(
-        prompt,
+        search_prompt if internet_option else prompt,
         res_temp, res_topp, res_topk, res_rpen, res_mnts,
         res_beams, res_cache, res_sample, res_eosid, res_padid,
-        StoppingCriteriaList([StopOnTokens()]), False
+        return_token_type_ids=False
     )
     pre.start_gen(gen_kwargs)
@@ -102,18 +59,4 @@ def chat_stream(
         yield "", uis, prompt, str(res)
 
     ppm = post.strip_pong(ppm)
-    yield "", ppm.build_uis(), prompt, str(res)
-
-    # summarization
-    # ppm.add_pingpong(
-    #     PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)")
-    # )
-    # yield "", ppm.build_uis(), prompt, state
-    # ppm.pop_pingpong()
-
-    # ppm = summarize(
-    #     ppm, ctx_sum_prompt, ctx_num_lconv,
-    #     sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts,
-    #     sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
-    # )
-    yield "", ppm.build_uis(), prompt, str(res)
\ No newline at end of file
+    yield "", ppm.build_uis(), prompt, str(res)
\ No newline at end of file
diff --git a/chats/redpajama.py b/chats/redpajama.py
index 63d36438..f285f5f1 100644
--- a/chats/redpajama.py
+++ b/chats/redpajama.py
@@ -8,7 +8,7 @@
 from pingpong import PingPong
 from gens.batch_gen import get_output_batch
 
-from pingpong.context import CtxLastWindowStrategy
+from chats.utils import build_prompts, text_stream, internet_search
 
 class StopOnTokens(StoppingCriteria):
     # ref: https://github.com/togethercomputer/OpenChatKit/blob/7a931c7d7cf3602c93e00db6e27bdc09d3b5f70f/inference/bot.py
@@ -29,62 +29,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
                 return True
         return False
 
-def build_prompts(ppmanager, user_message, global_context, win_size=3):
-    dummy_ppm = copy.deepcopy(ppmanager)
-
-    dummy_ppm.ctx = global_context
-    for pingpong in dummy_ppm.pingpongs:
-        pong = pingpong.pong
-        first_sentence = pong.split("\n")[0]
-        if first_sentence != "" and \
-            pre.contains_image_markdown(first_sentence):
-            pong = ' '.join(pong.split("\n")[1:]).strip()
-            pingpong.pong = pong
-
-    lws = CtxLastWindowStrategy(win_size)
-
-    prompt = lws(dummy_ppm)
-    return prompt
-
-def text_stream(ppmanager, streamer):
-    count = 0
-
-    for new_text in streamer:
-        if count == 0:
-            ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n")
-            count = count + 1
-
-        ppmanager.append_pong(new_text)
-        yield ppmanager, ppmanager.build_uis()
-
-    yield ppmanager, ppmanager.build_uis()
-
-def summarize(
-    ppmanager, prompt_to_summarize, win_size,
-    temperature, top_p, top_k, repetition_penalty, max_new_tokens,
-    num_beams, use_cache, do_sample, eos_token_id, pad_token_id
-):
-    ctx = ppmanager.ctx
-    last_pong = ppmanager.pingpongs[-1].pong
-    ppmanager.add_pingpong(PingPong(prompt_to_summarize, ""))
-    prompt = ppmanager.build_prompts(from_idx=-win_size)
-
-    _, gen_config_summarization = pre.build_gen_config(
-        temperature, top_p, top_k, repetition_penalty, max_new_tokens,
-        num_beams, use_cache, do_sample, eos_token_id, pad_token_id
-    )
-    summarize_output = get_output_batch(
-        global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization
-    )[0].strip()
-    ppmanager.ctx = summarize_output
-    ppmanager.pop_pingpong()
-    return ppmanager
-
 def chat_stream(
     idx, local_data, user_message, state,
     global_context, ctx_num_lconv, ctx_sum_prompt,
     res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
-    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
+    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
+    internet_option, serper_api_key
 ):
     res = [
         state["ppmanager_type"].from_json(json.dumps(ppm))
@@ -97,18 +47,21 @@ def chat_stream(
     ppm.add_pingpong(
         PingPong(user_message, "")
     )
-    prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv)
+    prompt = build_prompts(ppm, global_context, ctx_num_lconv)
 
+    #######
+    if internet_option:
+        search_prompt = None
+        for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
+            search_prompt = tmp_prompt
+            yield "", uis, prompt, str(res)
+
     # prepare text generating streamer & start generating
     gen_kwargs, streamer = pre.build(
-        prompt,
+        search_prompt if internet_option else prompt,
         res_temp, res_topp, res_topk, res_rpen, res_mnts,
         res_beams, res_cache, res_sample, res_eosid, res_padid,
-        StoppingCriteriaList([StopOnTokens(
-            global_vars.tokenizer,
-            [":"],
-            None,
-        )]), False
+        return_token_type_ids=False
     )
     pre.start_gen(gen_kwargs)
@@ -120,18 +73,4 @@ def chat_stream(
         ppm.pingpongs[-1].pong = ppm.pingpongs[-1].pong[:-1]
 
     ppm = post.strip_pong(ppm)
-    yield "", ppm.build_uis(), prompt, str(res)
-
-    # summarization
-    # ppm.add_pingpong(
-    #     PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)")
-    # )
-    # yield "", ppm.build_uis(), prompt, state
-    # ppm.pop_pingpong()
-
-    # ppm = summarize(
-    #     ppm, ctx_sum_prompt, ctx_num_lconv,
-    #     sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts,
-    #     sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
-    # )
     yield "", ppm.build_uis(), prompt, str(res)
\ No newline at end of file
diff --git a/chats/stable_vicuna.py b/chats/stable_vicuna.py
index 5ee3ece3..e4a7d0ba 100644
--- a/chats/stable_vicuna.py
+++ b/chats/stable_vicuna.py
@@ -8,7 +8,7 @@
 from pingpong import PingPong
 from gens.batch_gen import get_output_batch
 
-from pingpong.context import CtxLastWindowStrategy
+from chats.utils import build_prompts, text_stream, internet_search
 
 class StopOnTokens(StoppingCriteria):
     def __init__(self, tokenizer):
@@ -24,62 +24,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
                 return True
         return False
 
-def build_prompts(ppmanager, user_message, global_context, win_size=3):
-    dummy_ppm = copy.deepcopy(ppmanager)
-
-    dummy_ppm.ctx = global_context
-    for pingpong in dummy_ppm.pingpongs:
-        pong = pingpong.pong
-        first_sentence = pong.split(" ")[0]
-        if first_sentence != "" and \
-            pre.contains_image_markdown(first_sentence):
-            pong = ' '.join(pong.split(" ")[1:])
-            pingpong.pong = pong
-
-    lws = CtxLastWindowStrategy(win_size)
-
-    prompt = lws(dummy_ppm)
-    return prompt
-
-def text_stream(ppmanager, streamer):
-    count = 0
-
-    for new_text in streamer:
-        if count == 0:
-            ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]:*** ")
-            count = count + 1
-
-        ppmanager.append_pong(new_text)
-        yield ppmanager, ppmanager.build_uis()
-
-    yield ppmanager, ppmanager.build_uis()
-
-def summarize(
-    ppmanager, prompt_to_summarize, win_size,
-    temperature, top_p, top_k, repetition_penalty, max_new_tokens,
-    num_beams, use_cache, do_sample, eos_token_id, pad_token_id
-):
-    ctx = ppmanager.ctx
-    last_pong = ppmanager.pingpongs[-1].pong
-    ppmanager.add_pingpong(PingPong(prompt_to_summarize, ""))
-    prompt = ppmanager.build_prompts(from_idx=-win_size)
-
-    _, gen_config_summarization = pre.build_gen_config(
-        temperature, top_p, top_k, repetition_penalty, max_new_tokens,
-        num_beams, use_cache, do_sample, eos_token_id, pad_token_id
-    )
-    summarize_output = get_output_batch(
-        global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization
-    )[0].strip()
-    ppmanager.ctx = summarize_output
-    ppmanager.pop_pingpong()
-    return ppmanager
-
 def chat_stream(
     idx, local_data, user_message, state,
     global_context, ctx_num_lconv, ctx_sum_prompt,
     res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
-    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
+    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
+    internet_option, serper_api_key
 ):
     res = [
         state["ppmanager_type"].from_json(json.dumps(ppm))
@@ -92,14 +42,21 @@ def chat_stream(
     ppm.add_pingpong(
         PingPong(user_message, "")
     )
-    prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv)
+    prompt = build_prompts(ppm, global_context, ctx_num_lconv)
 
+    #######
+    if internet_option:
+        search_prompt = None
+        for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
+            search_prompt = tmp_prompt
+            yield "", uis, prompt, str(res)
+
     # prepare text generating streamer & start generating
     gen_kwargs, streamer = pre.build(
-        prompt,
+        search_prompt if internet_option else prompt,
         res_temp, res_topp, res_topk, res_rpen, res_mnts,
         res_beams, res_cache, res_sample, res_eosid, res_padid,
-        StoppingCriteriaList([StopOnTokens(global_vars.tokenizer)]), False
+        return_token_type_ids=False
     )
     pre.start_gen(gen_kwargs)
@@ -108,18 +65,4 @@ def chat_stream(
         yield "", uis, prompt, str(res)
 
     ppm = post.strip_pong(ppm)
-    yield "", ppm.build_uis(), prompt, str(res)
-
-    # summarization
-    # ppm.add_pingpong(
-    #     PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)")
-    # )
-    # yield "", ppm.build_uis(), prompt, state
-    # ppm.pop_pingpong()
-
-    # ppm = summarize(
-    #     ppm, ctx_sum_prompt, ctx_num_lconv,
-    #     sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts,
-    #     sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
-    # )
     yield "", ppm.build_uis(), prompt, str(res)
\ No newline at end of file
diff --git a/chats/stablelm.py b/chats/stablelm.py
index d9bdfda4..0781f5a3 100644
--- a/chats/stablelm.py
+++ b/chats/stablelm.py
@@ -8,7 +8,7 @@
 from pingpong import PingPong
 from gens.batch_gen import get_output_batch
 
-from pingpong.context import CtxLastWindowStrategy
+from chats.utils import build_prompts, text_stream, internet_search
 
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
@@ -18,62 +18,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
                 return True
         return False
 
-def build_prompts(ppmanager, user_message, global_context, win_size=3):
-    dummy_ppm = copy.deepcopy(ppmanager)
-
-    dummy_ppm.ctx = global_context
-    for pingpong in dummy_ppm.pingpongs:
-        pong = pingpong.pong
-        first_sentence = pong.split("\n")[0]
-        if first_sentence != "" and \
-            pre.contains_image_markdown(first_sentence):
-            pong = ' '.join(pong.split("\n")[1:]).strip()
-            pingpong.pong = pong
-
-    lws = CtxLastWindowStrategy(win_size)
-
-    prompt = lws(dummy_ppm)
-    return prompt
-
-def text_stream(ppmanager, streamer):
-    count = 0
-
-    for new_text in streamer:
-        if count == 0:
-            ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n")
-            count = count + 1
-
-        ppmanager.append_pong(new_text)
-        yield ppmanager, ppmanager.build_uis()
-
-    yield ppmanager, ppmanager.build_uis()
-
-def summarize(
-    ppmanager, prompt_to_summarize, win_size,
-    temperature, top_p, top_k, repetition_penalty, max_new_tokens,
-    num_beams, use_cache, do_sample, eos_token_id, pad_token_id
-):
-    ctx = ppmanager.ctx
-    last_pong = ppmanager.pingpongs[-1].pong
-    ppmanager.add_pingpong(PingPong(prompt_to_summarize, ""))
-    prompt = ppmanager.build_prompts(from_idx=-win_size)
-
-    _, gen_config_summarization = pre.build_gen_config(
-        temperature, top_p, top_k, repetition_penalty, max_new_tokens,
-        num_beams, use_cache, do_sample, eos_token_id, pad_token_id
-    )
-    summarize_output = get_output_batch(
-        global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization
-    )[0].split(prompt_to_summarize)[-1].strip()
-    ppmanager.ctx = summarize_output
-    ppmanager.pop_pingpong()
-    return ppmanager
-
 def chat_stream(
     idx, local_data, user_message, state,
     global_context, ctx_num_lconv, ctx_sum_prompt,
     res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
-    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
+    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
+    internet_option, serper_api_key
 ):
     res = [
         state["ppmanager_type"].from_json(json.dumps(ppm))
@@ -86,14 +36,21 @@ def chat_stream(
     ppm.add_pingpong(
         PingPong(user_message, "")
     )
-    prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv)
-
+    prompt = build_prompts(ppm, global_context, ctx_num_lconv)
+
+    #######
+    if internet_option:
+        search_prompt = None
+        for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
+            search_prompt = tmp_prompt
+            yield "", uis, prompt, str(res)
+
     # prepare text generating streamer & start generating
     gen_kwargs, streamer = pre.build(
-        prompt,
+        search_prompt if internet_option else prompt,
         res_temp, res_topp, res_topk, res_rpen, res_mnts,
         res_beams, res_cache, res_sample, res_eosid, res_padid,
-        StoppingCriteriaList([StopOnTokens()]), False
+        return_token_type_ids=False
    )
     pre.start_gen(gen_kwargs)
@@ -102,18 +59,4 @@ def chat_stream(
         yield "", uis, prompt, str(res)
 
     ppm = post.strip_pong(ppm)
-    yield "", ppm.build_uis(), prompt, str(res)
-
-    # summarization
-    # ppm.add_pingpong(
-    #     PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)")
-    # )
-    # yield "", ppm.build_uis(), prompt, state
-    # ppm.pop_pingpong()
-
-    # ppm = summarize(
-    #     ppm, ctx_sum_prompt, ctx_num_lconv,
-    #     sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts,
-    #     sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
-    # )
-    yield "", ppm.build_uis(), prompt, str(res)
\ No newline at end of file
+    yield "", ppm.build_uis(), prompt, str(res)
\ No newline at end of file
diff --git a/chats/starchat.py b/chats/starchat.py
index 9910bea4..c205bcb7 100644
--- a/chats/starchat.py
+++ b/chats/starchat.py
@@ -8,7 +8,7 @@
 from pingpong import PingPong
 from gens.batch_gen import get_output_batch
 
-from pingpong.context import CtxLastWindowStrategy
+from chats.utils import build_prompts, text_stream, internet_search
 
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
@@ -18,62 +18,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
                 return True
         return False
 
-def build_prompts(ppmanager, user_message, global_context, win_size=3):
-    dummy_ppm = copy.deepcopy(ppmanager)
-
-    dummy_ppm.ctx = global_context
-    for pingpong in dummy_ppm.pingpongs:
-        pong = pingpong.pong
-        first_sentence = pong.split(" ")[0]
-        if first_sentence != "" and \
-            pre.contains_image_markdown(first_sentence):
-            pong = ' '.join(pong.split(" ")[1:])
-            pingpong.pong = pong
-
-    lws = CtxLastWindowStrategy(win_size)
-
-    prompt = lws(dummy_ppm)
-    return prompt
-
-def text_stream(ppmanager, streamer):
-    count = 0
-
-    for new_text in streamer:
-        if count == 0:
-            ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n")
-            count = count + 1
-
-        ppmanager.append_pong(new_text)
-        yield ppmanager, ppmanager.build_uis()
-
-    yield ppmanager, ppmanager.build_uis()
-
-def summarize(
-    ppmanager, prompt_to_summarize, win_size,
-    temperature, top_p, top_k, repetition_penalty, max_new_tokens,
-    num_beams, use_cache, do_sample, eos_token_id, pad_token_id
-):
-    ctx = ppmanager.ctx
-    last_pong = ppmanager.pingpongs[-1].pong
-    ppmanager.add_pingpong(PingPong(prompt_to_summarize, ""))
-    prompt = ppmanager.build_prompts(from_idx=-win_size)
-
-    _, gen_config_summarization = pre.build_gen_config(
-        temperature, top_p, top_k, repetition_penalty, max_new_tokens,
-        num_beams, use_cache, do_sample, eos_token_id, pad_token_id
-    )
-    summarize_output = get_output_batch(
-        global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization
-    )[0].strip()
-    ppmanager.ctx = summarize_output
-    ppmanager.pop_pingpong()
-    return ppmanager
-
 def chat_stream(
     idx, local_data, user_message, state,
     global_context, ctx_num_lconv, ctx_sum_prompt,
     res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
-    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
+    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
+    internet_option, serper_api_key
 ):
     res = [
         state["ppmanager_type"].from_json(json.dumps(ppm))
@@ -86,14 +36,21 @@ def chat_stream(
     ppm.add_pingpong(
         PingPong(user_message, "")
     )
-    prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv)
-
+    prompt = build_prompts(ppm, global_context, ctx_num_lconv)
+
+    #######
+    if internet_option:
+        search_prompt = None
+        for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
+            search_prompt = tmp_prompt
+            yield "", uis, prompt, str(res)
+
     # prepare text generating streamer & start generating
     gen_kwargs, streamer = pre.build(
-        prompt,
+        search_prompt if internet_option else prompt,
         res_temp, res_topp, res_topk, res_rpen, res_mnts,
         res_beams, res_cache, res_sample, res_eosid, res_padid,
-        StoppingCriteriaList([StopOnTokens()]), False
+        return_token_type_ids=False
     )
     pre.start_gen(gen_kwargs)
@@ -102,18 +59,4 @@ def chat_stream(
         yield "", uis, prompt, str(res)
 
     ppm = post.strip_pong(ppm)
-    yield "", ppm.build_uis(), prompt, str(res)
-
-    # summarization
-    # ppm.add_pingpong(
-    #     PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)")
-    # )
-    # yield "", ppm.build_uis(), prompt, state
-    # ppm.pop_pingpong()
-
-    # ppm = summarize(
-    #     ppm, ctx_sum_prompt, ctx_num_lconv,
-    #     sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts,
-    #     sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
-    # )
     yield "", ppm.build_uis(), prompt, str(res)
\ No newline at end of file
diff --git a/chats/utils.py b/chats/utils.py
new file mode 100644
index 00000000..1282ddaf
--- /dev/null
+++ b/chats/utils.py
@@ -0,0 +1,52 @@
+import copy
+import global_vars
+
+from pingpong.context import CtxLastWindowStrategy
+from pingpong.context import InternetSearchStrategy, SimilaritySearcher
+
+from chats import pre, post
+
+def build_prompts(ppmanager, global_context, win_size=3):
+    dummy_ppm = copy.deepcopy(ppmanager)
+
+    dummy_ppm.ctx = global_context
+    for pingpong in dummy_ppm.pingpongs:
+        pong = pingpong.pong
+        first_sentence = pong.split("\n")[0]
+        if first_sentence != "" and \
+            pre.contains_image_markdown(first_sentence):
+            pong = ' '.join(pong.split("\n")[1:]).strip()
+            pingpong.pong = pong
+
+    lws = CtxLastWindowStrategy(win_size)
+
+    prompt = lws(dummy_ppm)
+    return prompt
+
+def text_stream(ppmanager, streamer):
+    count = 0
+
+    for new_text in streamer:
+        if count == 0:
+            ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n")
+            count = count + 1
+
+        ppmanager.append_pong(new_text)
+        yield ppmanager, ppmanager.build_uis()
+
+    yield ppmanager, ppmanager.build_uis()
+
+def internet_search(ppmanager, serper_api_key, global_context, ctx_num_lconv, device="cpu"):
+    searcher = SimilaritySearcher.from_pretrained(device=device)
+    iss = InternetSearchStrategy(searcher, serper_api_key=serper_api_key)(ppmanager)
+
+    step_ppm = None
+    while True:
+        try:
+            step_ppm, _ = next(iss)
+            yield "", step_ppm.build_uis()
+        except StopIteration:
+            break
+
+    search_prompt = build_prompts(step_ppm, global_context, ctx_num_lconv)
+    yield search_prompt, ppmanager.build_uis()
\ No newline at end of file
diff --git a/chats/vicuna.py b/chats/vicuna.py
index 8e6d492f..20b610d6 100644
--- a/chats/vicuna.py
+++ b/chats/vicuna.py
@@ -5,64 +5,14 @@
 from pingpong import PingPong
 from gens.batch_gen import get_output_batch
 
-from pingpong.context import CtxLastWindowStrategy
-
-def build_prompts(ppmanager, user_message, global_context, win_size=3):
-    dummy_ppm = copy.deepcopy(ppmanager)
-
-    dummy_ppm.ctx = global_context
-    for pingpong in dummy_ppm.pingpongs:
-        pong = pingpong.pong
-        first_sentence = pong.split(" ")[0]
-        if first_sentence != "" and \
-            pre.contains_image_markdown(first_sentence):
-            pong = ' '.join(pong.split(" ")[1:])
-            pingpong.pong = pong
-
-    lws = CtxLastWindowStrategy(win_size)
-
-    prompt = lws(dummy_ppm)
-    return prompt
-
-def text_stream(ppmanager, streamer):
-    count = 0
-
-    for new_text in streamer:
-        if count == 0:
-            ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n")
-            count = count + 1
-
-        ppmanager.append_pong(new_text)
-        yield ppmanager, ppmanager.build_uis()
-
-    yield ppmanager, ppmanager.build_uis()
-
-def summarize(
-    ppmanager, prompt_to_summarize, win_size,
-    temperature, top_p, top_k, repetition_penalty, max_new_tokens,
-    num_beams, use_cache, do_sample, eos_token_id, pad_token_id
-):
-    ctx = ppmanager.ctx
-    last_pong = ppmanager.pingpongs[-1].pong
-    ppmanager.add_pingpong(PingPong(prompt_to_summarize, ""))
-    prompt = ppmanager.build_prompts(from_idx=-win_size)
-
-    _, gen_config_summarization = pre.build_gen_config(
-        temperature, top_p, top_k, repetition_penalty, max_new_tokens,
-        num_beams, use_cache, do_sample, eos_token_id, pad_token_id
-    )
-    summarize_output = get_output_batch(
-        global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization
-    )[0].strip()
-    ppmanager.ctx = summarize_output
-    ppmanager.pop_pingpong()
-    return ppmanager
+from chats.utils import build_prompts, text_stream, internet_search
 
 def chat_stream(
     idx, local_data, user_message, state,
     global_context, ctx_num_lconv, ctx_sum_prompt,
     res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
-    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
+    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
+    internet_option, serper_api_key
 ):
     res = [
         state["ppmanager_type"].from_json(json.dumps(ppm))
@@ -75,11 +25,18 @@ def chat_stream(
     ppm.add_pingpong(
         PingPong(user_message, "")
     )
-    prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv)
-
+    prompt = build_prompts(ppm, global_context, ctx_num_lconv)
+
+    #######
+    if internet_option:
+        search_prompt = None
+        for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
+            search_prompt = tmp_prompt
+            yield "", uis, prompt, str(res)
+
     # prepare text generating streamer & start generating
     gen_kwargs, streamer = pre.build(
-        prompt,
+        search_prompt if internet_option else prompt,
         res_temp, res_topp, res_topk, res_rpen, res_mnts,
         res_beams, res_cache, res_sample, res_eosid, res_padid,
         return_token_type_ids=False
     )
     pre.start_gen(gen_kwargs)
@@ -91,18 +48,4 @@ def chat_stream(
         yield "", uis, prompt, str(res)
 
     ppm = post.strip_pong(ppm)
-    yield "", ppm.build_uis(), prompt, str(res)
-
-    # summarization
-    # ppm.add_pingpong(
-    #     PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)")
-    # )
-    # yield "", ppm.build_uis(), prompt, state
-    # ppm.pop_pingpong()
-
-    # ppm = summarize(
-    #     ppm, ctx_sum_prompt, ctx_num_lconv,
-    #     sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts,
-    #     sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
-    # )
     yield "", ppm.build_uis(), prompt, str(res)
\ No newline at end of file
diff --git a/chats/wizard_coder.py b/chats/wizard_coder.py
index de8621c3..eea145a7 100644
--- a/chats/wizard_coder.py
+++ b/chats/wizard_coder.py
@@ -8,7 +8,7 @@
 from pingpong import PingPong
 from gens.batch_gen import get_output_batch
 
-from pingpong.context import CtxLastWindowStrategy
+from chats.utils import build_prompts, text_stream, internet_search
 
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
@@ -18,62 +18,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
                 return True
         return False
 
-def build_prompts(ppmanager, user_message, global_context, win_size=3):
-    dummy_ppm = copy.deepcopy(ppmanager)
-
-    dummy_ppm.ctx = global_context
-    for pingpong in dummy_ppm.pingpongs:
-        pong = pingpong.pong
-        first_sentence = pong.split(" ")[0]
-        if first_sentence != "" and \
-            pre.contains_image_markdown(first_sentence):
-            pong = ' '.join(pong.split(" ")[1:])
-            pingpong.pong = pong
-
-    lws = CtxLastWindowStrategy(win_size)
-
-    prompt = lws(dummy_ppm)
-    return prompt
-
-def text_stream(ppmanager, streamer):
-    count = 0
-
-    for new_text in streamer:
-        if count == 0:
-            ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n")
-            count = count + 1
-
-        ppmanager.append_pong(new_text)
-        yield ppmanager, ppmanager.build_uis()
-
-    yield ppmanager, ppmanager.build_uis()
-
-def summarize(
-    ppmanager, prompt_to_summarize, win_size,
-    temperature, top_p, top_k, repetition_penalty, max_new_tokens,
-    num_beams, use_cache, do_sample, eos_token_id, pad_token_id
-):
-    ctx = ppmanager.ctx
-    last_pong = ppmanager.pingpongs[-1].pong
-    ppmanager.add_pingpong(PingPong(prompt_to_summarize, ""))
-    prompt = ppmanager.build_prompts(from_idx=-win_size)
-
-    _, gen_config_summarization = pre.build_gen_config(
-        temperature, top_p, top_k, repetition_penalty, max_new_tokens,
-        num_beams, use_cache, do_sample, eos_token_id, pad_token_id
-    )
-    summarize_output = get_output_batch(
-        global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization
-    )[0].strip()
-    ppmanager.ctx = summarize_output
-    ppmanager.pop_pingpong()
-    return ppmanager
-
 def chat_stream(
     idx, local_data, user_message, state,
     global_context, ctx_num_lconv, ctx_sum_prompt,
     res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
-    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
+    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
+    internet_option, serper_api_key
 ):
     res = [
         state["ppmanager_type"].from_json(json.dumps(ppm))
@@ -86,14 +36,21 @@ def chat_stream(
     ppm.add_pingpong(
         PingPong(user_message, "")
     )
-    prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv)
-
+    prompt = build_prompts(ppm, global_context, ctx_num_lconv)
+
+    #######
+    if internet_option:
+        search_prompt = None
+        for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
+            search_prompt = tmp_prompt
+            yield "", uis, prompt, str(res)
+
     # prepare text generating streamer & start generating
     gen_kwargs, streamer = pre.build(
-        prompt,
+        search_prompt if internet_option else prompt,
         res_temp, res_topp, res_topk, res_rpen, res_mnts,
         res_beams, res_cache, res_sample, res_eosid, res_padid,
-        StoppingCriteriaList([StopOnTokens()]), False
+        return_token_type_ids=False
     )
     pre.start_gen(gen_kwargs)
@@ -102,18 +59,4 @@ def chat_stream(
         yield "", uis, prompt, str(res)
 
     ppm = post.strip_pong(ppm)
-    yield "", ppm.build_uis(), prompt, str(res)
-
-    # summarization
-    # ppm.add_pingpong(
-    #     PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)")
-    # )
-    # yield "", ppm.build_uis(), prompt, state
-    # ppm.pop_pingpong()
-
-    # ppm = summarize(
-    #     ppm, ctx_sum_prompt, ctx_num_lconv,
-    #     sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts,
-    #     sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
-    # )
     yield "", ppm.build_uis(), prompt, str(res)
\ No newline at end of file
diff --git a/chats/wizard_falcon.py b/chats/wizard_falcon.py
index b95577dc..79042bbf 100644
--- a/chats/wizard_falcon.py
+++ b/chats/wizard_falcon.py
@@ -8,7 +8,7 @@
 from pingpong import PingPong
 from gens.batch_gen import get_output_batch
 
-from pingpong.context import CtxLastWindowStrategy
+from chats.utils import build_prompts, text_stream, internet_search
 
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
@@ -18,62 +18,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
                 return True
         return False
 
-def build_prompts(ppmanager, user_message, global_context, win_size=3):
-    dummy_ppm = copy.deepcopy(ppmanager)
-
-    dummy_ppm.ctx = global_context
-    for pingpong in dummy_ppm.pingpongs:
-        pong = pingpong.pong
-        first_sentence = pong.split(" ")[0]
-        if first_sentence != "" and \
-            pre.contains_image_markdown(first_sentence):
-            pong = ' '.join(pong.split(" ")[1:])
-            pingpong.pong = pong
-
-    lws = CtxLastWindowStrategy(win_size)
-
-    prompt = lws(dummy_ppm)
-    return prompt
-
-def text_stream(ppmanager, streamer):
-    count = 0
-
-    for new_text in streamer:
-        if count == 0:
-            ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n")
-            count = count + 1
-
-        ppmanager.append_pong(new_text)
-        yield ppmanager, ppmanager.build_uis()
-
-    yield ppmanager, ppmanager.build_uis()
-
-def summarize(
-    ppmanager, prompt_to_summarize, win_size,
-    temperature, top_p, top_k, repetition_penalty, max_new_tokens,
-    num_beams, use_cache, do_sample, eos_token_id, pad_token_id
-):
-    ctx = ppmanager.ctx
-    last_pong = ppmanager.pingpongs[-1].pong
-    ppmanager.add_pingpong(PingPong(prompt_to_summarize, ""))
-    prompt = ppmanager.build_prompts(from_idx=-win_size)
-
-    _, gen_config_summarization = pre.build_gen_config(
-        temperature, top_p, top_k, repetition_penalty, max_new_tokens,
-        num_beams, use_cache, do_sample, eos_token_id, pad_token_id
-    )
-    summarize_output = get_output_batch(
-        global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization
-    )[0].split("### Response:")[-1].split("### Input:")[0].strip()
-    ppmanager.ctx = summarize_output
-    ppmanager.pop_pingpong()
-    return ppmanager
-
 def chat_stream(
     idx, local_data, user_message, state,
     global_context, ctx_num_lconv, ctx_sum_prompt,
     res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
-    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
+    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
+    internet_option, serper_api_key
 ):
     res = [
         state["ppmanager_type"].from_json(json.dumps(ppm))
@@ -86,14 +36,21 @@ def chat_stream(
     ppm.add_pingpong(
         PingPong(user_message, "")
     )
-    prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv)
-
+    prompt = build_prompts(ppm, global_context, ctx_num_lconv)
+
+    #######
+    if internet_option:
+        search_prompt = None
+        for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
+            search_prompt = tmp_prompt
+            yield "", uis, prompt, str(res)
+
     # prepare text generating streamer & start generating
     gen_kwargs, streamer = pre.build(
-        prompt,
+        search_prompt if internet_option else prompt,
         res_temp, res_topp, res_topk, res_rpen, res_mnts,
         res_beams, res_cache, res_sample, res_eosid, res_padid,
-        StoppingCriteriaList([StopOnTokens()]), False
+        return_token_type_ids=False
     )
     pre.start_gen(gen_kwargs)
@@ -102,18 +59,4 @@ def chat_stream(
         yield "", uis, prompt, str(res)
 
     ppm = post.strip_pong(ppm)
-    yield "", ppm.build_uis(), prompt, str(res)
-
-    # summarization
-    # ppm.add_pingpong(
-    #     PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)")
-    # )
-    # yield "", ppm.build_uis(), prompt, state
-    # ppm.pop_pingpong()
-
-    # ppm = summarize(
-    #     ppm, ctx_sum_prompt, ctx_num_lconv,
-    #     sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts,
-    #     sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
-    # )
     yield "", ppm.build_uis(), prompt, str(res)
\ No newline at end of file
diff --git a/chats/xgen.py b/chats/xgen.py
index 09a97ae4..03fca056 100644
--- a/chats/xgen.py
+++ b/chats/xgen.py
@@ -9,24 +9,7 @@
 from pingpong import PingPong
 from gens.batch_gen import get_output_batch
 
-from pingpong.context import CtxLastWindowStrategy
-
-def build_prompts(ppmanager, user_message, global_context, win_size=3):
-    dummy_ppm = copy.deepcopy(ppmanager)
-
-    dummy_ppm.ctx = global_context
-    for pingpong in dummy_ppm.pingpongs:
-        pong = pingpong.pong
-        first_sentence = pong.split("\n")[0]
-        if first_sentence != "" and \
-            pre.contains_image_markdown(first_sentence):
-            pong = ' '.join(pong.split("\n")[1:]).strip()
-            pingpong.pong = pong
-
-    lws = CtxLastWindowStrategy(win_size)
-
-    prompt = lws(dummy_ppm)
-    return prompt
+from chats.utils import build_prompts, internet_search
 
 def text_stream(ppmanager, streamer):
     count = 0
@@ -63,7 +46,8 @@ def chat_stream(
     idx, local_data, user_message, state,
     global_context, ctx_num_lconv, ctx_sum_prompt,
     res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
-    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
+    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
+    internet_option, serper_api_key
 ):
     res = [
         state["ppmanager_type"].from_json(json.dumps(ppm))
@@ -76,14 +60,21 @@ def chat_stream(
     ppm.add_pingpong(
         PingPong(user_message, "")
     )
-    prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv)
+    prompt = build_prompts(ppm, global_context, ctx_num_lconv)
 
+    #######
+    if internet_option:
+        search_prompt = None
+        for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
+            search_prompt = tmp_prompt
+            yield "", uis, prompt, str(res)
+
     # prepare text generating streamer & start generating
     gen_kwargs, streamer = pre.build(
-        prompt,
+        search_prompt if internet_option else prompt,
         res_temp, res_topp, res_topk, res_rpen, res_mnts,
         res_beams, res_cache, res_sample, res_eosid, res_padid,
-        None, False
+        return_token_type_ids=False
     )
     pre.start_gen(gen_kwargs)
diff --git a/discord_app.py b/discord_app.py
index 68fcb91f..1f1246d4 100644
--- a/discord_app.py
+++ b/discord_app.py
@@ -8,8 +8,10 @@
 from concurrent.futures import ThreadPoolExecutor
 
 import discord
+from discord.errors import HTTPException
 
 import global_vars
+from pingpong.context import InternetSearchStrategy, SimilaritySearcher
 from discordbot.req import (
     sync_task, build_prompt, build_ppm
 )
@@ -44,6 +46,16 @@ async def build_prompt_and_reply(executor, user_name, user_id):
 - **`help`:** list of supported commands
 - **`model-info`:** get currently selected model card
 - **`default-params`:** get default parameters of the Generation Config
+
+You can start a conversation by mentioning the chatbot `@{chatbot name} {your prompt} {options}`, and the following options are supported.
+- **`--top-p {float}`**: determines how many tokens to pick from the top tokens based on the sum of their probabilities (<= `top-p`).
+- **`--temperature {float}`**: used to modulate the next token probabilities.
+- **`--max-new-tokens {integer}`**: maximum number of tokens to generate, ignoring the number of tokens in the prompt.
+- **`--do-sample`**: determines whether or not to use sampling; use greedy decoding otherwise.
+- **`--max-windows {integer}`**: determines how many past conversations to look up as a reference.
+- **`--internet`**: determines whether or not to use internet search capabilities.
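+
+For example: `@ChatBot What is the latest stable release of Python? --internet --max-new-tokens 512` (the bot name and prompt here are only illustrative; mention this bot's actual name).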
+
+If you want to continue a conversation based on past conversation histories, you can simply `reply` to the chatbot's message. In this case, you don't need to mention its name. However, you do need to specify options in every turn. For instance, if you want to `reply` based on internet search information, you should specify `--internet` in your message.
 """
         await msg.channel.send(help_msg)
     elif user_msg == "model-info":
@@ -56,27 +68,73 @@ async def build_prompt_and_reply(executor, user_name, user_id):
 """
         await msg.channel.send(help_msg)
     elif user_msg == "default-params":
-        help_msg = f"""{global_vars.gen_config}"""
+        help_msg = f"""{global_vars.gen_config}, max-windows = {user_args.max_windows}"""
         await msg.channel.send(help_msg)
     else:
         if err_msg is None:
-            ppm = await build_ppm(msg, user_msg, user_name, user_id)
-
-            prompt = await build_prompt(ppm, user_args.max_windows)
-            response = await loop.run_in_executor(
-                executor, sync_task,
-                prompt, user_args
-            )
-            if response.endswith(""):
-                response = response[:-len("")]
-
-            if response.endswith("<|endoftext|>"):
-                response = response[:-len("<|endoftext|>")]
+            try:
+                ppm = await build_ppm(msg, user_msg, user_name, user_id)
+
+                if user_args.internet and serper_api_key is not None:
+                    progress_msg = await msg.reply("Progress 🚧", mention_author=False)
+
+                    internet_search_ppm = copy.deepcopy(ppm)
+                    internet_search_prompt = f"My question is '{user_msg}'. Based on the conversation history, give me an appropriate query to answer my question for Google search. You should not say more than the query. You should not say any words except the query."
+                    internet_search_ppm.pingpongs[-1].ping = internet_search_prompt
+                    internet_search_prompt = await build_prompt(
+                        internet_search_ppm,
+                        ctx_include=False,
+                        win_size=user_args.max_windows
+                    )
+                    internet_search_prompt_response = await loop.run_in_executor(
+                        executor, sync_task, internet_search_prompt, user_args
+                    )
+                    if internet_search_prompt_response.endswith(""):
+                        internet_search_prompt_response = internet_search_prompt_response[:-len("")]
+                    if internet_search_prompt_response.endswith("<|endoftext|>"):
+                        internet_search_prompt_response = internet_search_prompt_response[:-len("<|endoftext|>")]
+
+                    ppm.pingpongs[-1].ping = internet_search_prompt_response
+
+                    await progress_msg.edit(
+                        content=f"• Search query re-organized by LLM: {internet_search_prompt_response}",
+                        suppress=True
+                    )
+
+                    searcher = SimilaritySearcher.from_pretrained(device="cuda")
+
+                    logs = ""
+                    for step_ppm, step_msg in InternetSearchStrategy(
+                        searcher, serper_api_key=serper_api_key
+                    )(ppm, search_query=internet_search_prompt_response, top_k=8):
+                        ppm = step_ppm
+                        logs = logs + step_msg + "\n"
+                        await progress_msg.edit(content=logs, suppress=True)
+
+                prompt = await build_prompt(ppm, win_size=user_args.max_windows)
+                response = await loop.run_in_executor(
+                    executor, sync_task,
+                    prompt, user_args
+                )
+                if response.endswith(""):
+                    response = response[:-len("")]
+
+                if response.endswith("<|endoftext|>"):
+                    response = response[:-len("<|endoftext|>")]
+
+                response = f"**{model_name}** 💬\n{response.strip()}"
+                if len(response) >= max_response_length:
+                    response = response[:max_response_length]
+
+                if user_args.internet and serper_api_key is not None:
+                    await progress_msg.delete()
-            response = f"**{model_name}** 💬\n{response.strip()}"
-            if len(response) >= max_response_length:
-                response = response[:max_response_length]
-            await msg.reply(response, mention_author=False)
+
+                await msg.reply(response, mention_author=False)
+            except IndexError:
+                err_msg = "Index error"
+                await msg.channel.send(err_msg)
+            except HTTPException:
+                pass
         else:
             await msg.channel.send(err_msg)
@@ -148,11 +206,13 @@ def discord_main(args):
     elif mode == "HALF":
         off_modes(args)
         args.mode_full_gpu = True
-    
+
     global max_workers
     global model_name
+    global serper_api_key
 
     max_workers = args.max_workers
     model_name = args.model_name
+    serper_api_key = args.serper_api_key
 
     selected_model_info = model_info[model_name]
@@ -179,7 +239,7 @@ def discord_main(args):
 
     client.run(args.token)
 
-if __name__ == "__main__":
+if __name__ == "__main__": 
     parser = argparse.ArgumentParser()
     # can be set via environment variable
     # --token == DISCORD_BOT_TOKEN
@@ -193,6 +253,7 @@ def discord_main(args):
     parser.add_argument('--mode-4bit', default=False, action=argparse.BooleanOptionalAction)
     parser.add_argument('--mode-full-gpu', default=True, action=argparse.BooleanOptionalAction)
     parser.add_argument('--local-files-only', default=False, action=argparse.BooleanOptionalAction)
+    parser.add_argument('--serper-api-key', default=None, type=str)
 
     args = parser.parse_args()
     discord_main(args)
diff --git a/discordbot/flags.py b/discordbot/flags.py
index 44e6dcfd..f218600e 100644
--- a/discordbot/flags.py
+++ b/discordbot/flags.py
@@ -3,11 +3,12 @@
 def parse_req(message, gen_config_path):
     parser = argparse.ArgumentParser()
 
-    parser.add_argument('--max-new-tokens', default=256, type=int)
-    parser.add_argument('--temperature', default=0.8, type=float)
+    parser.add_argument('--max-new-tokens', default=None, type=int)
+    parser.add_argument('--temperature', default=None, type=float)
     parser.add_argument('--max-windows', default=3, type=int)
-    parser.add_argument('--do-sample', default=True, action=argparse.BooleanOptionalAction)
-    parser.add_argument('--top-p', default=0.75, type=float)
+    parser.add_argument('--do-sample', default=None, action=argparse.BooleanOptionalAction)
+    parser.add_argument('--top-p', default=None, type=float)
+    parser.add_argument('--internet', default=False, action=argparse.BooleanOptionalAction)
 
     msg = message.strip()
     multiparts_msg = msg.split("--")
diff --git a/discordbot/req.py b/discordbot/req.py
index 5f19c48d..f1909488 100644
--- a/discordbot/req.py
+++ b/discordbot/req.py
@@ -13,16 +13,30 @@
 def sync_task(prompt, args):
     input_ids = global_vars.tokenizer(prompt, return_tensors="pt").input_ids.to(global_vars.device)
+
+    gen_config = copy.deepcopy(global_vars.gen_config)
+    if args.max_new_tokens is not None:
+        gen_config.max_new_tokens = args.max_new_tokens
+    if args.temperature is not None:
+        gen_config.temperature = args.temperature
+    if args.do_sample is not None:
+        gen_config.do_sample = args.do_sample
+    if args.top_p is not None:
+        gen_config.top_p = args.top_p
+
     generated_ids = global_vars.model.generate(
         input_ids=input_ids,
-        generation_config=global_vars.gen_config
+        generation_config=gen_config
     )
     response = global_vars.tokenizer.decode(generated_ids[0][input_ids.shape[-1]:])
     return response
 
-async def build_prompt(ppmanager, win_size=3):
+async def build_prompt(ppmanager, ctx_include=True, win_size=3):
     dummy_ppm = copy.deepcopy(ppmanager)
-    dummy_ppm.ctx = get_global_context(global_vars.model_type)
+    if ctx_include:
+        dummy_ppm.ctx = get_global_context(global_vars.model_type)
+    else:
+        dummy_ppm.ctx = ""
     lws = CtxLastWindowStrategy(win_size)
     return lws(dummy_ppm)
diff --git a/entry_point.py b/entry_point.py
index e0572478..64bcd2f4 100644
--- a/entry_point.py
+++ b/entry_point.py
@@ -9,6 +9,7 @@
 
 app_mode = os.getenv("LLMCHAT_APP_MODE")
 local_files_only = os.getenv("LLMCHAT_LOCAL_FILES_ONLY")
+serper_api_key = os.getenv("LLMCHAT_SERPER_API_KEY")
 
 if app_mode is None or \
     app_mode not in ["GRADIO", "DISCORD"]:
@@ -24,6 +25,7 @@
         parser.add_argument('--local-files-only', default=False, action=argparse.BooleanOptionalAction)
         parser.add_argument('--share', default=False, action=argparse.BooleanOptionalAction)
         parser.add_argument('--debug', default=False, action=argparse.BooleanOptionalAction)
+        parser.add_argument('--serper-api-key', default=serper_api_key, type=str)
 
         args = parser.parse_args()
         gradio_main(args)
@@ -37,5 +39,6 @@
         parser.add_argument('--mode-4bit', default=False, action=argparse.BooleanOptionalAction)
         parser.add_argument('--mode-full-gpu', default=True, action=argparse.BooleanOptionalAction)
        parser.add_argument('--local-files-only', default=local_files_only, action=argparse.BooleanOptionalAction)
+        parser.add_argument('--serper-api-key', default=serper_api_key, type=str)
 
         args = parser.parse_args()
         discord_main(args)