Commit e42be48

add llama2 7, 13B

deep-diver committed Jul 19, 2023
1 parent 0c5528e commit e42be48
Showing 15 changed files with 695 additions and 136 deletions.
13 changes: 10 additions & 3 deletions app.py
@@ -502,6 +502,10 @@ def gradio_main(args):
                 with gr.Column(min_width=20):
                     xgen_7b = gr.Button("xgen-7b", elem_id="xgen-7b", elem_classes=["square"])
                     gr.Markdown("XGen", elem_classes=["center"])
+
+                with gr.Column(min_width=20):
+                    llama2_7b = gr.Button("llama2-7b", elem_id="llama2-7b", elem_classes=["square"])
+                    gr.Markdown("LLaMA 2", elem_classes=["center"])
 
         gr.Markdown("## ~ 20B Parameters")
         with gr.Row(elem_classes=["sub-container"]):
@@ -592,6 +596,9 @@ def gradio_main(args):
                 with gr.Column(min_width=20):
                     orcamini_13b = gr.Button("orcamini-13b", elem_id="orcamini-13b", elem_classes=["square"])
                     gr.Markdown("Orca Mini", elem_classes=["center"])
+                with gr.Column(min_width=20):
+                    llama2_13b = gr.Button("llama2-13b", elem_id="llama2-13b", elem_classes=["square"])
+                    gr.Markdown("LLaMA 2", elem_classes=["center"])
 
         gr.Markdown("## ~ 30B Parameters", visible=False)
         with gr.Row(elem_classes=["sub-container"], visible=False):
@@ -880,12 +887,12 @@ def gradio_main(args):
         gpt4_alpaca_7b, os_stablelm7b, mpt_7b, redpajama_7b, redpajama_instruct_7b, llama_deus_7b,
         evolinstruct_vicuna_7b, alpacoom_7b, baize_7b, guanaco_7b, vicuna_7b_1_3,
         falcon_7b, wizard_falcon_7b, airoboros_7b, samantha_7b, openllama_7b, orcamini_7b,
-        xgen_7b,
+        xgen_7b, llama2_7b,
         flan11b, koalpaca, kullm, alpaca_lora13b, gpt4_alpaca_13b, stable_vicuna_13b,
         starchat_15b, starchat_beta_15b, vicuna_7b, vicuna_13b, evolinstruct_vicuna_13b,
         baize_13b, guanaco_13b, nous_hermes_13b, airoboros_13b, samantha_13b, chronos_13b,
         wizardlm_13b, wizard_vicuna_13b, wizard_coder_15b, vicuna_13b_1_3, openllama_13b, orcamini_13b,
-        camel20b,
+        llama2_13b, camel20b,
         guanaco_33b, falcon_40b, wizard_falcon_40b, samantha_33b, lazarus_30b, chronos_33b,
         wizardlm_30b, wizard_vicuna_30b, vicuna_33b_1_3, mpt_30b
     ]
@@ -1077,7 +1084,7 @@ def gradio_main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--root-path', default="")
-    parser.add_argument('--local-files-only', default=local_files_only, action=argparse.BooleanOptionalAction)
+    parser.add_argument('--local-files-only', default=False, action=argparse.BooleanOptionalAction)
     parser.add_argument('--share', default=False, action=argparse.BooleanOptionalAction)
     parser.add_argument('--debug', default=False, action=argparse.BooleanOptionalAction)
     parser.add_argument('--serper-api-key', default=None, type=str)
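With `--local-files-only` now defaulting to False and declared via argparse.BooleanOptionalAction (like `--share` and `--debug`), each flag can be switched on or off explicitly at launch. A typical invocation (the API key value is a placeholder) might be:

    python app.py --local-files-only --share --serper-api-key YOUR_SERPER_KEY

Passing `--no-local-files-only` flips the flag back so checkpoints are fetched from the Hugging Face Hub.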
12 changes: 11 additions & 1 deletion chats/central.py
@@ -15,6 +15,7 @@
 from chats import falcon
 from chats import wizard_falcon
 from chats import xgen
+from chats import llama2
 from chats import custom
 
 def chat_stream(
@@ -40,14 +41,23 @@ def chat_stream(
             internet_option, serper_api_key
         )
 
+    elif model_type == "llama2":
+        cs = llama2.chat_stream(
+            idx, local_data, user_message, state,
+            global_context, ctx_num_lconv, ctx_sum_prompt,
+            res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
+            sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
+            internet_option, serper_api_key
+        )
+
     elif model_type == "xgen":
         cs = xgen.chat_stream(
             idx, local_data, user_message, state,
             global_context, ctx_num_lconv, ctx_sum_prompt,
             res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
             sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
             internet_option, serper_api_key
         )
 
     elif model_type == "stablelm":
         cs = stablelm.chat_stream(
51 changes: 51 additions & 0 deletions chats/llama2.py
@@ -0,0 +1,51 @@
import copy
import json
import global_vars
from chats import pre, post
from pingpong import PingPong
from gens.batch_gen import get_output_batch

from chats.utils import build_prompts, text_stream, internet_search

def chat_stream(
    idx, local_data, user_message, state,
    global_context, ctx_num_lconv, ctx_sum_prompt,
    res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
    sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
    internet_option, serper_api_key
):
    res = [
        state["ppmanager_type"].from_json(json.dumps(ppm))
        for ppm in local_data
    ]

    ppm = res[idx]

    # add_ping returns a prompt structured in Alpaca form
    ppm.add_pingpong(
        PingPong(user_message, "")
    )
    prompt = build_prompts(ppm, global_context, ctx_num_lconv)

    #######
    if internet_option:
        search_prompt = None
        for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
            search_prompt = tmp_prompt
            yield "", uis, prompt, str(res)

    # prepare text generating streamer & start generating
    gen_kwargs, streamer = pre.build(
        search_prompt if internet_option else prompt,
        res_temp, res_topp, res_topk, res_rpen, res_mnts,
        res_beams, res_cache, res_sample, res_eosid, res_padid,
        return_token_type_ids=False
    )
    pre.start_gen(gen_kwargs)

    # handling stream
    for ppmanager, uis in text_stream(ppm, streamer):
        yield "", uis, prompt, str(res)

    ppm = post.strip_pong(ppm)
    yield "", ppm.build_uis(), prompt, str(res)
12 changes: 12 additions & 0 deletions configs/response_configs/llama2.yaml
@@ -0,0 +1,12 @@
generation_config:
  temperature: 0.95
  top_p: 0.9
  top_k: 50
  num_beams: 1
  use_cache: True
  repetition_penalty: 1.2
  max_new_tokens: 4096
  do_sample: True
  pad_token_id: 32000
  bos_token_id: 1
  eos_token_id: 2
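These keys map directly onto Hugging Face generation arguments. A minimal sketch of loading the file, assuming PyYAML and that the values are passed through to `model.generate` (the loader below is illustrative; the repo's own config plumbing may differ):

    import yaml
    from transformers import GenerationConfig

    # Illustrative loader for the YAML above.
    with open("configs/response_configs/llama2.yaml") as f:
        cfg = yaml.safe_load(f)["generation_config"]

    gen_config = GenerationConfig(**cfg)  # temperature=0.95, top_p=0.9, top_k=50, ...
    # later: model.generate(**inputs, generation_config=gen_config)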
171 changes: 82 additions & 89 deletions discord_app.py
@@ -4,6 +4,7 @@
 import types
 import asyncio
 import argparse
+from urlextract import URLExtract
 from urllib.request import urlopen
 from concurrent.futures import ThreadPoolExecutor

@@ -17,6 +18,8 @@
     sync_task, build_prompt, build_ppm
 )
 from discordbot.flags import parse_req
+from discordbot import helps, post
+from dumb_utils import URLSearchStrategy
 
 model_info = json.load(open("model_cards.json"))
 
@@ -33,116 +36,106 @@
 max_response_length = 2000
 
 async def build_prompt_and_reply(executor, user_name, user_id):
+    other_job_on_progress = False
     loop = asyncio.get_running_loop()
 
     print(queue.qsize())
     msg = await queue.get()
-    user_msg, user_args, err_msg = parse_req(
-        msg.content.replace(f"@{user_name} ", "").replace(f"<@{user_id}> ", ""), None
+    user_msg, user_args = parse_req(
+        msg.content.replace(f"@{user_name} ", "").replace(f"<@{user_id}> ", ""), global_vars.gen_config
     )
 
     if user_msg == "help":
-        help_msg = """Type one of the following for more information about this chatbot
-- **`help`:** list of supported commands
-- **`model-info`:** get currently selected model card
-- **`default-params`:** get default parameters of the Generation Config
-You can start a conversation by mentioning the chatbot `@{chatbot name} {your prompt} {options}`, and the following options are supported.
-- **`--top-p {float}`**: determines how many tokens to pick from the top tokens based on the sum of their probabilities (<= `top-p`).
-- **`--temperature {float}`**: used to modulate the next token probabilities.
-- **`--max-new-tokens {integer}`**: maximum number of tokens to generate, ignoring the number of tokens in the prompt.
-- **`--do-sample`**: determines whether or not to use sampling; uses greedy decoding otherwise.
-- **`--max-windows {integer}`**: determines how many past conversations to look up as a reference.
-- **`--internet`**: determines whether or not to use internet search capabilities.
-If you want to continue a conversation based on past conversation histories, you can simply `reply` to the chatbot's message. At this time, you don't need to mention its name. However, you need to specify options in every turn. For instance, if you want to `reply` based on internet search information, then you should specify `--internet` in your message.
-"""
-        await msg.channel.send(help_msg)
+        await msg.channel.send(helps.get_help())
     elif user_msg == "model-info":
-        selected_model_info = model_info[model_name]
-        help_msg = f"""## {model_name}
-- **Description:** {selected_model_info['desc']}
-- **Number of parameters:** {selected_model_info['parameters']}
-- **Hugging Face Hub (base):** {selected_model_info['hub(base)']}
-- **Hugging Face Hub (ckpt):** {selected_model_info['hub(ckpt)']}
-"""
-        await msg.channel.send(help_msg)
+        await msg.channel.send(helps.get_model_info(model_name, model_info))
     elif user_msg == "default-params":
-        help_msg = f"""{global_vars.gen_config}, max-windows = {user_args.max_windows}"""
-        await msg.channel.send(help_msg)
+        await msg.channel.send(helps.get_default_params(global_vars.gen_config, user_args["max-windows"]))
     else:
-        if err_msg is None:
-            try:
-                ppm = await build_ppm(msg, user_msg, user_name, user_id)
-
-                if user_args.internet and serper_api_key is not None:
-                    progress_msg = await msg.reply("Progress 🚧", mention_author=False)
-
-                    internet_search_ppm = copy.deepcopy(ppm)
-                    internet_search_prompt = f"My question is '{user_msg}'. Based on the conversation history, give me an appropriate query to answer my question for google search. You should not say more than query. You should not say any words except the query."
-                    internet_search_ppm.pingpongs[-1].ping = internet_search_prompt
-                    internet_search_prompt = await build_prompt(
-                        internet_search_ppm,
-                        ctx_include=False,
-                        win_size=user_args.max_windows
-                    )
-                    internet_search_prompt_response = await loop.run_in_executor(
-                        executor, sync_task, internet_search_prompt, user_args
-                    )
-                    if internet_search_prompt_response.endswith("</s>"):
-                        internet_search_prompt_response = internet_search_prompt_response[:-len("</s>")]
-                    if internet_search_prompt_response.endswith("<|endoftext|>"):
-                        internet_search_prompt_response = internet_search_prompt_response[:-len("<|endoftext|>")]
-
-                    ppm.pingpongs[-1].ping = internet_search_prompt_response
-
-                    await progress_msg.edit(
-                        content=f"• Search query re-organized by LLM: {internet_search_prompt_response}",
-                        suppress=True
-                    )
-
-                    searcher = SimilaritySearcher.from_pretrained(device="cuda")
-
-                    logs = ""
-                    for step_ppm, step_msg in InternetSearchStrategy(
-                        searcher, serper_api_key=serper_api_key
-                    )(ppm, search_query=internet_search_prompt_response, top_k=8):
-                        ppm = step_ppm
-                        logs = logs + step_msg + "\n"
-                        await progress_msg.edit(content=logs, suppress=True)
-
-                prompt = await build_prompt(ppm, win_size=user_args.max_windows)
-                response = await loop.run_in_executor(
-                    executor, sync_task,
-                    prompt, user_args
-                )
-                if response.endswith("</s>"):
-                    response = response[:-len("</s>")]
-                if response.endswith("<|endoftext|>"):
-                    response = response[:-len("<|endoftext|>")]
-
-                response = f"**{model_name}** 💬\n{response.strip()}"
-                if len(response) >= max_response_length:
-                    response = response[:max_response_length]
-
-                if user_args.internet and serper_api_key is not None:
-                    await progress_msg.delete()
-
-                await msg.reply(response, mention_author=False)
-            except IndexError:
-                err_msg = "Index error"
-                await msg.channel.send(err_msg)
-            except HTTPException:
-                pass
-        else:
-            await msg.channel.send(err_msg)
+        try:
+            ppm = await build_ppm(msg, user_msg, user_name, user_id)
+
+            if user_args["internet"] and serper_api_key is not None:
+                other_job_on_progress = True
+                progress_msg = await msg.reply("Progress 🚧", mention_author=False)
+
+                internet_search_ppm = copy.deepcopy(ppm)
+                internet_search_prompt = f"My question is '{user_msg}'. Based on the conversation history, give me an appropriate query to answer my question for google search. You should not say more than query. You should not say any words except the query."
+                internet_search_ppm.pingpongs[-1].ping = internet_search_prompt
+                internet_search_prompt = await build_prompt(
+                    internet_search_ppm,
+                    ctx_include=False,
+                    win_size=user_args["max-windows"]
+                )
+                internet_search_prompt_response = await loop.run_in_executor(
+                    executor, sync_task, internet_search_prompt, user_args
+                )
+                internet_search_prompt_response = post.clean(internet_search_prompt_response)
+
+                ppm.pingpongs[-1].ping = internet_search_prompt_response
+
+                await progress_msg.edit(
+                    content=f"• Search query re-organized by LLM: {internet_search_prompt_response}",
+                    suppress=True
+                )
+
+                searcher = SimilaritySearcher.from_pretrained(device=global_vars.device)
+
+                logs = ""
+                for step_ppm, step_msg in InternetSearchStrategy(
+                    searcher, serper_api_key=serper_api_key
+                )(ppm, search_query=internet_search_prompt_response, top_k=8):
+                    ppm = step_ppm
+                    logs = logs + step_msg + "\n"
+                    await progress_msg.edit(content=logs, suppress=True)
+            else:
+                url_extractor = URLExtract()
+                urls = url_extractor.find_urls(user_msg)
+                print(f"urls = {urls}")
+
+                if len(urls) > 0:
+                    progress_msg = await msg.reply("Progress 🚧", mention_author=False)
+
+                    other_job_on_progress = True
+                    searcher = SimilaritySearcher.from_pretrained(device=global_vars.device)
+
+                    logs = ""
+                    for step_result, step_ppm, step_msg in URLSearchStrategy(searcher)(ppm, urls, top_k=8):
+                        if step_result is True:
+                            ppm = step_ppm
+                            logs = logs + step_msg + "\n"
+                            await progress_msg.edit(content=logs, suppress=True)
+                        else:
+                            ppm = step_ppm
+                            logs = logs + step_msg + "\n"
+                            await progress_msg.edit(content=logs, suppress=True)
+                            await asyncio.sleep(2)
+                            break
+
+            prompt = await build_prompt(ppm, win_size=user_args["max-windows"])
+            response = await loop.run_in_executor(
+                executor, sync_task,
+                prompt, user_args
+            )
+            response = post.clean(response)
+
+            response = f"**{model_name}** 💬\n{response.strip()}"
+            if len(response) >= max_response_length:
+                response = response[:max_response_length]
+
+            if other_job_on_progress is True:
+                await progress_msg.delete()
+
+            await msg.reply(response, mention_author=False)
+        except IndexError:
+            await msg.channel.send("Index error")
+        except HTTPException:
+            pass
 
 async def background_task(user_name, user_id, max_workers):
     executor = ThreadPoolExecutor(max_workers=max_workers)
     print("Task Started. Waiting for inputs.")
     while True:
         # await asyncio.sleep(5)
         await build_prompt_and_reply(executor, user_name, user_id)
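The new `post.clean` helper replaces the inline `endswith` stripping that the old code duplicated in two places. A plausible sketch of such a helper, reconstructed from the exact suffixes the removed lines handled (an assumption about `discordbot/post.py`, not its verbatim contents):

    # Hypothetical reconstruction of discordbot/post.py's clean();
    # the removed code stripped exactly these end-of-sequence markers.
    def clean(response: str) -> str:
        for eos in ("</s>", "<|endoftext|>"):
            if response.endswith(eos):
                response = response[:-len(eos)]
        return response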

@client.event