Skip to content

Commit

Permalink
add more models + update tgi interactions
Browse files Browse the repository at this point in the history
  • Loading branch information
deep-diver committed Aug 25, 2023
1 parent 60d510c commit e36ab3c
Show file tree
Hide file tree
Showing 7 changed files with 388 additions and 56 deletions.
110 changes: 76 additions & 34 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ def move_to_second_view_from_tb(tb, evt: gr.SelectData):
info["thumb"],
f"## {selected_model}",
f"**Parameters**\n: Approx. {info['parameters']}",
f"**🤗 Hub(base)**\n: {info['hub(base)']}",
f"**🤗 Hub(LoRA)**\n: {info['hub(ckpt)']}",
f"**🤗 Hub(GPTQ)**\n: {info['hub(gptq)']}",
f"**🤗 Hub(GPTQ_BASE)**\n: {info['hub(gptq_base)']}",
f"**Hugging Face Hub(base)**\n: {info['hub(base)']}",
f"**Hugging Face Hub(LoRA)**\n: {info['hub(ckpt)']}",
f"**Hugging Face Hub(GPTQ)**\n: {info['hub(gptq)']}",
f"**Hugging Face Hub(GPTQ_BASE)**\n: {info['hub(gptq_base)']}",
info['desc'],
f"""**Min VRAM requirements** :
| half precision | load_in_8bit | load_in_4bit |
Expand Down Expand Up @@ -415,10 +415,10 @@ def move_to_second_view(btn):
info["thumb"],
f"## {btn}",
f"**Parameters**\n: Approx. {info['parameters']}",
f"**🤗 Hub(base)**\n: {info['hub(base)']}",
f"**🤗 Hub(LoRA)**\n: {info['hub(ckpt)']}",
f"**🤗 Hub(GPTQ)**\n: {info['hub(gptq)']}",
f"**🤗 Hub(GPTQ_BASE)**\n: {info['hub(gptq_base)']}",
f"**Hugging Face Hub(base)**\n: {info['hub(base)']}",
f"**Hugging Face Hub(LoRA)**\n: {info['hub(ckpt)']}",
f"**Hugging Face Hub(GPTQ)**\n: {info['hub(gptq)']}",
f"**Hugging Face Hub(GPTQ_BASE)**\n: {info['hub(gptq_base)']}",
info['desc'],
f"""**Min VRAM requirements** :
| half precision | load_in_8bit | load_in_4bit |
Expand Down Expand Up @@ -459,14 +459,15 @@ def download_completed(
):
global local_files_only

print(model_gptq_base)
print(f"model_name: {model_name}")
print(f"model_base: {model_base}")

tmp_args = types.SimpleNamespace()
tmp_args.model_name = model_name[4:-6]
tmp_args.base_url = model_base.split(":")[-1].split("</p")[0].strip()
tmp_args.ft_ckpt_url = model_ckpt.split(":")[-1].split("</p")[0].strip()
tmp_args.gptq_url = model_gptq.split(":")[-1].split("</p")[0].strip()
tmp_args.gptq_base_url = model_gptq_base.split(":")[-1].split("</p")[0].strip().replace(' ', '')
tmp_args.model_name = model_name[3:]
tmp_args.base_url = model_base.split(":")[-1].strip()
tmp_args.ft_ckpt_url = model_ckpt.split(":")[-1].strip()
tmp_args.gptq_url = model_gptq.split(":")[-1].strip()
tmp_args.gptq_base_url = model_gptq_base.split(":")[-1].strip().replace(' ', '')
tmp_args.gen_config_path = gen_config_path
tmp_args.gen_config_summarization_path = gen_config_sum_path
tmp_args.force_download_ckpt = force_download
Expand Down Expand Up @@ -640,9 +641,17 @@ def gradio_main(args):
gr.Markdown("## Recent Releases")
with gr.Row(elem_classes=["sub-container"]):
with gr.Column(min_width=20):
stable_beluga2_70b_rr = gr.Button("stable-beluga2-70b", elem_id="stable-beluga2-70b", elem_classes=["square"])
gr.Markdown("Stable Beluga 2 (70B)", elem_classes=["center"])

codellama_7b_rr = gr.Button("codellama-7b", elem_id="codellama-7b", elem_classes=["square"])
gr.Markdown("Code LLaMA (7B)", elem_classes=["center"])

with gr.Column(min_width=20):
codellama_13b_rr = gr.Button("codellama-13b", elem_id="codellama-13b", elem_classes=["square"])
gr.Markdown("Code LLaMA (13B)", elem_classes=["center"])

with gr.Column(min_width=20):
codellama_34b_rr = gr.Button("codellama-34b", elem_id="codellama-34b", elem_classes=["square"])
gr.Markdown("Code LLaMA (34B)", elem_classes=["center"])

with gr.Column(min_width=20):
upstage_llama2_70b_2_rr = gr.Button("upstage-llama2-70b-2", elem_id="upstage-llama2-70b-2", elem_classes=["square"])
gr.Markdown("Upstage2 v2 (70B)", elem_classes=["center"])
Expand All @@ -656,16 +665,20 @@ def gradio_main(args):
gr.Markdown("WizardLM (70B)", elem_classes=["center"])

with gr.Column(min_width=20):
nous_hermes_13b_v2_rr = gr.Button("nous-hermes-13b-llama2", elem_id="nous-hermes-13b-llama2", elem_classes=["square"])
gr.Markdown("Nous Hermes 2 (13B)", elem_classes=["center"])

orcamini_70b_rr = gr.Button("orcamini-70b", elem_id="orcamini-70b", elem_classes=["square"])
gr.Markdown("Orca Mini (70B)", elem_classes=["center"])
with gr.Column(min_width=20):
nous_puffin_13b_v2_rr = gr.Button("nous-puffin-13b-llama2", elem_id="nous-puffin-13b-llama2", elem_classes=["square"])
gr.Markdown("Nous Puffin 2 (13B)", elem_classes=["center"])

samantha_70b_rr = gr.Button("samantha-70b", elem_id="samantha-70b", elem_classes=["square"])
gr.Markdown("Samantha (70B)", elem_classes=["center"])

with gr.Column(min_width=20):
godzilla_70b_rr = gr.Button("godzilla-70b", elem_id="godzilla-70b", elem_classes=["square"])
gr.Markdown("GadziLLa (70B)", elem_classes=["center"])

with gr.Column(min_width=20):
wizardlm_13b_1_2_rr = gr.Button("wizardlm-13b-1-2", elem_id="wizardlm-13b-1-2", elem_classes=["square"])
gr.Markdown("WizardLM 1.2 (13B)", elem_classes=["center"])
nous_hermes_70b_rr = gr.Button("nous-hermes-70b", elem_id="nous-hermes-70b", elem_classes=["square"])
gr.Markdown("Nous Hermes 2 (70B)", elem_classes=["center"])

with gr.Column(visible=False) as full_section:
gr.Markdown("## ~ 10B Parameters")
Expand All @@ -678,10 +691,6 @@ def gradio_main(args):
flan3b = gr.Button("flan-3b", elem_id="flan-3b", elem_classes=["square"])
gr.Markdown("Flan-XL", elem_classes=["center"])

# with gr.Column(min_width=20):
# replit_3b = gr.Button("replit-3b", elem_id="replit-3b", elem_classes=["square"])
# gr.Markdown("Replit Instruct", elem_classes=["center"])

with gr.Column(min_width=20):
camel5b = gr.Button("camel-5b", elem_id="camel-5b", elem_classes=["square"])
gr.Markdown("Camel", elem_classes=["center"])
Expand Down Expand Up @@ -778,6 +787,10 @@ def gradio_main(args):
nous_hermes_7b_v2 = gr.Button("nous-hermes-7b-llama2", elem_id="nous-hermes-7b-llama2", elem_classes=["square"])
gr.Markdown("Nous Hermes 2", elem_classes=["center"])

with gr.Column(min_width=20):
codellama_7b = gr.Button("codellama-7b", elem_id="codellama-7b", elem_classes=["square"])
gr.Markdown("Code LLaMA", elem_classes=["center"])

gr.Markdown("## ~ 20B Parameters")
with gr.Row(elem_classes=["sub-container"]):
with gr.Column(min_width=20, visible=False):
Expand Down Expand Up @@ -883,6 +896,10 @@ def gradio_main(args):
with gr.Column(min_width=20):
wizardlm_13b_1_2 = gr.Button("wizardlm-13b-1-2", elem_id="wizardlm-13b-1-2", elem_classes=["square"])
gr.Markdown("WizardLM 1.2", elem_classes=["center"])

with gr.Column(min_width=20):
codellama_13b = gr.Button("codellama-13b", elem_id="codellama-13b", elem_classes=["square"])
gr.Markdown("Code LLaMA", elem_classes=["center"])

gr.Markdown("## ~ 30B Parameters", visible=False)
with gr.Row(elem_classes=["sub-container"], visible=False):
Expand Down Expand Up @@ -935,6 +952,10 @@ def gradio_main(args):
with gr.Column(min_width=20):
upstage_llama_30b = gr.Button("upstage-llama-30b", elem_id="upstage-llama-30b", elem_classes=["square"])
gr.Markdown("Upstage LLaMA", elem_classes=["center"])

with gr.Column(min_width=20):
codellama_34b = gr.Button("codellama-34b", elem_id="codellama-34b", elem_classes=["square"])
gr.Markdown("Code LLaMA", elem_classes=["center"])

gr.Markdown("## ~ 70B Parameters")
with gr.Row(elem_classes=["sub-container"]):
Expand All @@ -957,6 +978,22 @@ def gradio_main(args):
with gr.Column(min_width=20):
wizardlm_70b = gr.Button("wizardlm-70b", elem_id="wizardlm-70b", elem_classes=["square"])
gr.Markdown("WizardLM", elem_classes=["center"])

with gr.Column(min_width=20):
orcamini_70b = gr.Button("orcamini-70b", elem_id="orcamini-70b", elem_classes=["square"])
gr.Markdown("Orca Mini", elem_classes=["center"])

with gr.Column(min_width=20):
samantha_70b = gr.Button("samantha-70b", elem_id="samantha-70b", elem_classes=["square"])
gr.Markdown("Samantha", elem_classes=["center"])

with gr.Column(min_width=20):
godzilla_70b = gr.Button("godzilla-70b", elem_id="godzilla-70b", elem_classes=["square"])
gr.Markdown("GadziLLa", elem_classes=["center"])

with gr.Column(min_width=20):
nous_hermes_70b = gr.Button("nous-hermes-70b", elem_id="nous-hermes-70b", elem_classes=["square"])
gr.Markdown("Nous Hermes 2", elem_classes=["center"])

progress_view = gr.Textbox(label="Progress", elem_classes=["progress-view"])

Expand Down Expand Up @@ -1248,17 +1285,22 @@ def gradio_main(args):
gpt4_alpaca_7b, os_stablelm7b, mpt_7b, redpajama_7b, redpajama_instruct_7b, llama_deus_7b,
evolinstruct_vicuna_7b, alpacoom_7b, baize_7b, guanaco_7b, vicuna_7b_1_3,
falcon_7b, wizard_falcon_7b, airoboros_7b, samantha_7b, openllama_7b, orcamini_7b,
xgen_7b, llama2_7b, nous_hermes_7b_v2,
xgen_7b, llama2_7b, nous_hermes_7b_v2, codellama_7b,

flan11b, koalpaca, kullm, alpaca_lora13b, gpt4_alpaca_13b, stable_vicuna_13b,
starchat_15b, starchat_beta_15b, vicuna_7b, vicuna_13b, evolinstruct_vicuna_13b,
baize_13b, guanaco_13b, nous_hermes_13b, airoboros_13b, samantha_13b, chronos_13b,
wizardlm_13b, wizard_vicuna_13b, wizard_coder_15b, vicuna_13b_1_3, openllama_13b, orcamini_13b,
llama2_13b, nous_hermes_13b_v2, nous_puffin_13b_v2, wizardlm_13b_1_2, camel20b,
llama2_13b, nous_hermes_13b_v2, nous_puffin_13b_v2, wizardlm_13b_1_2, codellama_13b, camel20b,

guanaco_33b, falcon_40b, wizard_falcon_40b, samantha_33b, lazarus_30b, chronos_33b,
wizardlm_30b, wizard_vicuna_30b, vicuna_33b_1_3, mpt_30b, upstage_llama_30b,
stable_beluga2_70b, upstage_llama2_70b, upstage_llama2_70b_2, platypus2_70b, wizardlm_70b,
wizardlm_30b, wizard_vicuna_30b, vicuna_33b_1_3, mpt_30b, upstage_llama_30b, codellama_34b,

stable_beluga2_70b, upstage_llama2_70b, upstage_llama2_70b_2, platypus2_70b, wizardlm_70b, orcamini_70b,
samantha_70b, godzilla_70b, nous_hermes_70b,

stable_beluga2_70b_rr, upstage_llama2_70b_2_rr, platypus2_70b_rr, wizardlm_70b_rr, nous_hermes_13b_v2_rr, nous_puffin_13b_v2_rr, wizardlm_13b_1_2_rr
codellama_7b_rr, codellama_13b_rr, codellama_34b_rr, upstage_llama2_70b_2_rr, platypus2_70b_rr,
wizardlm_70b_rr, orcamini_70b_rr, samantha_70b_rr, godzilla_70b_rr, nous_hermes_70b_rr
]
for btn in btns:
btn.click(
Expand Down
47 changes: 43 additions & 4 deletions chats/central.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def chat_stream(
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
internet_option, serper_api_key
):
if global_vars.remote_addr != "":
if global_vars.remote_addr is not None and global_vars.remote_addr != "":
if internet_option == "on" and serper_api_key.strip() != "":
internet_option = True
else:
Expand Down Expand Up @@ -161,15 +161,18 @@ def sync_chat_stream(
internet_option, serper_api_key
)

elif model_type == "llama2":
elif model_type == "llama2" or \
model_type == "codellama" or \
model_type == "llama2-70b" or \
model_type == "codellama2-70b":
cs = llama2.chat_stream(
idx, local_data, user_message, state,
global_context, ctx_num_lconv, ctx_sum_prompt,
res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
internet_option, serper_api_key
)

)
elif model_type == "xgen":
cs = xgen.chat_stream(
idx, local_data, user_message, state,
Expand Down Expand Up @@ -224,6 +227,15 @@ def sync_chat_stream(
internet_option, serper_api_key
)

elif model_type == "godzilla2":
cs = alpaca.chat_stream(
idx, local_data, user_message, state,
global_context, ctx_num_lconv, ctx_sum_prompt,
res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
internet_option, serper_api_key
)

elif model_type == "openllama":
cs = alpaca.chat_stream(
idx, local_data, user_message, state,
Expand All @@ -234,6 +246,15 @@ def sync_chat_stream(
)

elif model_type == "orcamini":
cs = alpaca.chat_stream(
idx, local_data, user_message, state,
global_context, ctx_num_lconv, ctx_sum_prompt,
res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
internet_option, serper_api_key
)

elif model_type == "orcamini2":
cs = alpaca.chat_stream(
idx, local_data, user_message, state,
global_context, ctx_num_lconv, ctx_sum_prompt,
Expand All @@ -259,6 +280,15 @@ def sync_chat_stream(
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
internet_option, serper_api_key
)

elif model_type == "nous-hermes2":
cs = alpaca.chat_stream(
idx, local_data, user_message, state,
global_context, ctx_num_lconv, ctx_sum_prompt,
res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
internet_option, serper_api_key
)

elif model_type == "replit-instruct":
cs = alpaca.chat_stream(
Expand Down Expand Up @@ -387,6 +417,15 @@ def sync_chat_stream(
)

elif model_type == "samantha-vicuna":
cs = vicuna.chat_stream(
idx, local_data, user_message, state,
global_context, ctx_num_lconv, ctx_sum_prompt,
res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
internet_option, serper_api_key
)

elif model_type == "samantha2":
cs = vicuna.chat_stream(
idx, local_data, user_message, state,
global_context, ctx_num_lconv, ctx_sum_prompt,
Expand Down
2 changes: 1 addition & 1 deletion chats/remote_tgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ async def gen_text(
headers={
'Content-type': 'application/json'
}
if remote_token is not None:
if remote_token is not None and remote_token != "":
headers["Authorization"] = f'Bearer {remote_token}'

data = {
Expand Down
Loading

0 comments on commit e36ab3c

Please sign in to comment.