add internet search cap
deep-diver committed Jul 14, 2023
1 parent 8261ba7 commit 944a989
Showing 26 changed files with 507 additions and 1,328 deletions.
21 changes: 17 additions & 4 deletions app.py
@@ -1,3 +1,4 @@
import os
import time
import json
import copy
@@ -828,6 +829,15 @@ def gradio_main(args):
elem_id="global-context"
)

gr.Markdown("#### Internet search")
with gr.Row():
internet_option = gr.Radio(choices=["on", "off"], value="off", label="mode")
serper_api_key = gr.Textbox(
value= "" if args.serper_api_key is None else args.serper_api_key,
placeholder="Get one by visiting serper.dev",
label="Serper api key"
)

gr.Markdown("#### GenConfig for **response** text generation")
with gr.Row():
res_temp = gr.Slider(0.0, 2.0, 0, step=0.1, label="temp", interactive=True)
@@ -998,7 +1008,8 @@ def gradio_main(args):
[idx, local_data, instruction_txtbox, chat_state,
global_context, ctx_num_lconv, ctx_sum_prompt,
res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid],
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
internet_option, serper_api_key],
[instruction_txtbox, chatbot, context_inspector, local_data],
)

@@ -1016,8 +1027,9 @@ def gradio_main(args):
[idx, local_data, instruction_txtbox, chat_state,
global_context, ctx_num_lconv, ctx_sum_prompt,
res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid],
[instruction_txtbox, chatbot, context_inspector, local_data],
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
internet_option, serper_api_key],
[instruction_txtbox, chatbot, context_inspector, local_data],
).then(
lambda: gr.update(interactive=True),
None,
@@ -1065,9 +1077,10 @@ def gradio_main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--root-path', default="")
parser.add_argument('--local-files-only', default=False, action=argparse.BooleanOptionalAction)
parser.add_argument('--local-files-only', default=local_files_only, action=argparse.BooleanOptionalAction)
parser.add_argument('--share', default=False, action=argparse.BooleanOptionalAction)
parser.add_argument('--debug', default=False, action=argparse.BooleanOptionalAction)
parser.add_argument('--serper-api-key', default=None, type=str)
args = parser.parse_args()

gradio_main(args)
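
The new --serper-api-key flag and the "Serper api key" textbox carry the same value: the textbox pre-fills from args.serper_api_key when the flag is supplied (python app.py --serper-api-key <key>). A standalone sketch of that default-resolution logic, built only from the argparse lines in the hunk above (the key string is a placeholder):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--serper-api-key', default=None, type=str)

# No flag given: args.serper_api_key is None, so the textbox starts empty.
args = parser.parse_args([])
print("" if args.serper_api_key is None else args.serper_api_key)  # -> ""

# Flag given: the textbox pre-fills with the supplied key.
args = parser.parse_args(["--serper-api-key", "placeholder-key"])
print("" if args.serper_api_key is None else args.serper_api_key)  # -> placeholder-key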
83 changes: 13 additions & 70 deletions chats/alpaca.py
@@ -5,64 +5,14 @@
from pingpong import PingPong
from gens.batch_gen import get_output_batch

from pingpong.context import CtxLastWindowStrategy

def build_prompts(ppmanager, user_message, global_context, win_size=3):
dummy_ppm = copy.deepcopy(ppmanager)

dummy_ppm.ctx = global_context
for pingpong in dummy_ppm.pingpongs:
pong = pingpong.pong
first_sentence = pong.split("\n")[0]
if first_sentence != "" and \
pre.contains_image_markdown(first_sentence):
pong = ' '.join(pong.split("\n")[1:]).strip()
pingpong.pong = pong

lws = CtxLastWindowStrategy(win_size)

prompt = lws(dummy_ppm)
return prompt

def text_stream(ppmanager, streamer):
count = 0

for new_text in streamer:
if count == 0:
ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n")
count = count + 1

ppmanager.append_pong(new_text)
yield ppmanager, ppmanager.build_uis()

yield ppmanager, ppmanager.build_uis()

def summarize(
ppmanager, prompt_to_summarize, win_size,
temperature, top_p, top_k, repetition_penalty, max_new_tokens,
num_beams, use_cache, do_sample, eos_token_id, pad_token_id
):
ctx = ppmanager.ctx
last_pong = ppmanager.pingpongs[-1].pong
ppmanager.add_pingpong(PingPong(prompt_to_summarize, ""))
prompt = ppmanager.build_prompts(from_idx=-win_size)

_, gen_config_summarization = pre.build_gen_config(
temperature, top_p, top_k, repetition_penalty, max_new_tokens,
num_beams, use_cache, do_sample, eos_token_id, pad_token_id
)
summarize_output = get_output_batch(
global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization
)[0].split("### Response:")[-1].strip()
ppmanager.ctx = summarize_output
ppmanager.pop_pingpong()
return ppmanager
from chats.utils import build_prompts, text_stream, internet_search

def chat_stream(
idx, local_data, user_message, state,
global_context, ctx_num_lconv, ctx_sum_prompt,
res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
internet_option, serper_api_key
):
res = [
state["ppmanager_type"].from_json(json.dumps(ppm))
@@ -75,11 +25,18 @@ def chat_stream(
ppm.add_pingpong(
PingPong(user_message, "")
)
prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv)

prompt = build_prompts(ppm, global_context, ctx_num_lconv)

#######
if internet_option:
search_prompt = None
for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
search_prompt = tmp_prompt
yield "", uis, prompt, str(res)

# prepare text generating streamer & start generating
gen_kwargs, streamer = pre.build(
prompt,
search_prompt if internet_option else prompt,
res_temp, res_topp, res_topk, res_rpen, res_mnts,
res_beams, res_cache, res_sample, res_eosid, res_padid,
return_token_type_ids=False
@@ -91,18 +48,4 @@ def chat_stream(
yield "", uis, prompt, str(res)

ppm = post.strip_pong(ppm)
yield "", ppm.build_uis(), prompt, str(res)

# summarization
# ppm.add_pingpong(
# PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)")
# )
# yield "", ppm.build_uis(), prompt, state
# ppm.pop_pingpong()

# ppm = summarize(
# ppm, ctx_sum_prompt, ctx_num_lconv,
# sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts,
# sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
# )
yield "", ppm.build_uis(), prompt, str(res)
83 changes: 13 additions & 70 deletions chats/alpaca_gpt4.py
@@ -5,64 +5,14 @@
from pingpong import PingPong
from gens.batch_gen import get_output_batch

from pingpong.context import CtxLastWindowStrategy

def build_prompts(ppmanager, user_message, global_context, win_size=3):
dummy_ppm = copy.deepcopy(ppmanager)

dummy_ppm.ctx = global_context
for pingpong in dummy_ppm.pingpongs:
pong = pingpong.pong
first_sentence = pong.split("\n")[0]
if first_sentence != "" and \
pre.contains_image_markdown(first_sentence):
pong = ' '.join(pong.split("\n")[1:]).strip()
pingpong.pong = pong

lws = CtxLastWindowStrategy(win_size)

prompt = lws(dummy_ppm)
return prompt

def text_stream(ppmanager, streamer):
count = 0

for new_text in streamer:
if count == 0:
ppmanager.append_pong(f"![]({global_vars.model_thumbnail_tiny})***[{global_vars.model_type}]***\n")
count = count + 1

ppmanager.append_pong(new_text)
yield ppmanager, ppmanager.build_uis()

yield ppmanager, ppmanager.build_uis()

def summarize(
ppmanager, prompt_to_summarize, win_size,
temperature, top_p, top_k, repetition_penalty, max_new_tokens,
num_beams, use_cache, do_sample, eos_token_id, pad_token_id
):
ctx = ppmanager.ctx
last_pong = ppmanager.pingpongs[-1].pong
ppmanager.add_pingpong(PingPong(prompt_to_summarize, ""))
prompt = ppmanager.build_prompts(from_idx=-win_size)

_, gen_config_summarization = pre.build_gen_config(
temperature, top_p, top_k, repetition_penalty, max_new_tokens,
num_beams, use_cache, do_sample, eos_token_id, pad_token_id
)
summarize_output = get_output_batch(
global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization
)[0].split("### Response:")[-1].strip()
ppmanager.ctx = summarize_output
ppmanager.pop_pingpong()
return ppmanager
from chats.utils import build_prompts, text_stream, internet_search

def chat_stream(
idx, local_data, user_message, state,
global_context, ctx_num_lconv, ctx_sum_prompt,
res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
internet_option, serper_api_key
):
res = [
state["ppmanager_type"].from_json(json.dumps(ppm))
@@ -75,11 +25,18 @@ def chat_stream(
ppm.add_pingpong(
PingPong(user_message, "")
)
prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv)

prompt = build_prompts(ppm, global_context, ctx_num_lconv)

#######
if internet_option:
search_prompt = None
for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
search_prompt = tmp_prompt
yield "", uis, prompt, str(res)

# prepare text generating streamer & start generating
gen_kwargs, streamer = pre.build(
prompt,
search_prompt if internet_option else prompt,
res_temp, res_topp, res_topk, res_rpen, res_mnts,
res_beams, res_cache, res_sample, res_eosid, res_padid,
return_token_type_ids=False
@@ -91,18 +48,4 @@ def chat_stream(
yield "", uis, prompt, str(res)

ppm = post.strip_pong(ppm)
yield "", ppm.build_uis(), prompt, str(res)

# summarization
# ppm.add_pingpong(
# PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)")
# )
# yield "", ppm.build_uis(), prompt, state
# ppm.pop_pingpong()

# ppm = summarize(
# ppm, ctx_sum_prompt, ctx_num_lconv,
# sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts,
# sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
# )
yield "", ppm.build_uis(), prompt, str(res)
77 changes: 13 additions & 64 deletions chats/alpacoom.py
@@ -5,58 +5,14 @@
from pingpong import PingPong
from gens.batch_gen import get_output_batch

from pingpong.context import CtxLastWindowStrategy

def build_prompts(ppmanager, user_message, global_context, win_size=3):
dummy_ppm = copy.deepcopy(ppmanager)

dummy_ppm.ctx = global_context
for pingpong in dummy_ppm.pingpongs:
pong = pingpong.pong
first_sentence = pong.split(" ")[0]
if first_sentence != "" and \
pre.contains_image_markdown(first_sentence):
pong = ' '.join(pong.split(" ")[1:])
pingpong.pong = pong

lws = CtxLastWindowStrategy(win_size)

prompt = lws(dummy_ppm)
return prompt

def text_stream(ppmanager, streamer):
for new_text in streamer:
ppmanager.append_pong(new_text)
yield ppmanager, ppmanager.build_uis()

yield ppmanager, ppmanager.build_uis()

def summarize(
ppmanager, prompt_to_summarize, win_size,
temperature, top_p, top_k, repetition_penalty, max_new_tokens,
num_beams, use_cache, do_sample, eos_token_id, pad_token_id
):
ctx = ppmanager.ctx
last_pong = ppmanager.pingpongs[-1].pong
ppmanager.add_pingpong(PingPong(prompt_to_summarize, ""))
prompt = ppmanager.build_prompts(from_idx=-win_size)

_, gen_config_summarization = pre.build_gen_config(
temperature, top_p, top_k, repetition_penalty, max_new_tokens,
num_beams, use_cache, do_sample, eos_token_id, pad_token_id
)
summarize_output = get_output_batch(
global_vars.model, global_vars.tokenizer, [prompt], gen_config_summarization
)[0].split("### Response:")[-1].strip()
ppmanager.ctx = summarize_output
ppmanager.pop_pingpong()
return ppmanager
from chats.utils import build_prompts, text_stream, internet_search

def chat_stream(
idx, local_data, user_message, state,
global_context, ctx_num_lconv, ctx_sum_prompt,
res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
internet_option, serper_api_key
):
res = [
state["ppmanager_type"].from_json(json.dumps(ppm))
@@ -69,11 +25,18 @@ def chat_stream(
ppm.add_pingpong(
PingPong(user_message, "")
)
prompt = build_prompts(ppm, user_message, global_context, ctx_num_lconv)

prompt = build_prompts(ppm, global_context, ctx_num_lconv)

#######
if internet_option:
search_prompt = None
for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
search_prompt = tmp_prompt
yield "", uis, prompt, str(res)

# prepare text generating streamer & start generating
gen_kwargs, streamer = pre.build(
prompt,
search_prompt if internet_option else prompt,
res_temp, res_topp, res_topk, res_rpen, res_mnts,
res_beams, res_cache, res_sample, res_eosid, res_padid,
return_token_type_ids=False
@@ -85,18 +48,4 @@ def chat_stream(
yield "", uis, prompt, str(res)

ppm = post.strip_pong(ppm)
yield "", ppm.build_uis(), prompt, str(res)

# summarization
# ppm.add_pingpong(
# PingPong(None, "![](https://i.postimg.cc/ZKNKDPBd/Vanilla-1s-209px.gif)")
# )
# yield "", ppm.build_uis(), prompt, state
# ppm.pop_pingpong()

# ppm = summarize(
# ppm, ctx_sum_prompt, ctx_num_lconv,
# sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts,
# sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid
# )
yield "", ppm.build_uis(), prompt, str(res)
