Commit 50c5080 (1 parent: ab1304d)
Showing 19 changed files with 304 additions and 258 deletions.
Empty file.
@@ -0,0 +1,128 @@
import global_vars

import gradio as gr

from gens.batch_gen import get_output_batch
from miscs.strings import SPECIAL_STRS
from miscs.constants import num_of_characters_to_keep
from miscs.utils import generate_prompt
from miscs.utils import common_post_process, post_processes_batch, post_process_stream


def chat_stream(
    context,
    instruction,
    state_chatbot,
):
    if len(context) > 1000 or len(instruction) > 300:
        raise gr.Error("context or prompt is too long!")

    bot_summarized_response = ''
    # user input should be appropriately formatted (don't be confused by the function name)
    instruction_display = common_post_process(instruction)
    instruction_prompt, conv_length = generate_prompt(instruction, state_chatbot, context)

    # if the conversation has grown too long, ask the model to summarize it first
    # and fold the summary back into the context before answering
    if conv_length > num_of_characters_to_keep:
        instruction_prompt = generate_prompt(SPECIAL_STRS["summarize"], state_chatbot, context)[0]

        state_chatbot = state_chatbot + [
            (
                None,
                " too long conversations, so let's summarize..."
            )
        ]
        yield (state_chatbot, state_chatbot, context)

        bot_summarized_response = get_output_batch(
            global_vars.model, global_vars.tokenizer, [instruction_prompt], global_vars.generation_config
        )[0]
        bot_summarized_response = bot_summarized_response.split("### Response:")[-1].strip()

        state_chatbot[-1] = (
            None,
            "✅ summarization is done and set as context"
        )
        print(f"bot_summarized_response: {bot_summarized_response}")
        yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())

    instruction_prompt = generate_prompt(instruction, state_chatbot, f"{context} {bot_summarized_response}")[0]

    bot_response = global_vars.stream_model(
        instruction_prompt,
        max_tokens=256,
        temperature=1,
        top_p=0.9
    )

    instruction_display = None if instruction_display == SPECIAL_STRS["continue"] else instruction_display
    state_chatbot = state_chatbot + [(instruction_display, None)]
    yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())

    # stream tokens to the UI, buffering anything that looks like the start of a
    # "### Instruction:" marker so the stop sequence is never shown to the user
    prev_index = 0
    agg_tokens = ""
    cutoff_idx = 0
    for tokens in bot_response:
        tokens = tokens.strip()
        cur_token = tokens[prev_index:]

        if "#" in cur_token and agg_tokens == "":
            cutoff_idx = tokens.find("#")
            agg_tokens = tokens[cutoff_idx:]

        if agg_tokens != "":
            if len(agg_tokens) < len("### Instruction:"):
                agg_tokens = agg_tokens + cur_token
            elif len(agg_tokens) >= len("### Instruction:"):
                if tokens.find("### Instruction:") > -1:
                    processed_response, _ = post_process_stream(tokens[:tokens.find("### Instruction:")].strip())

                    state_chatbot[-1] = (
                        instruction_display,
                        processed_response
                    )
                    yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())
                    break
                else:
                    agg_tokens = ""
                    cutoff_idx = 0

        if agg_tokens == "":
            processed_response, to_exit = post_process_stream(tokens)
            state_chatbot[-1] = (instruction_display, processed_response)
            yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())

            if to_exit:
                break

        prev_index = len(tokens)

    yield (
        state_chatbot,
        state_chatbot,
        f"{context} {bot_summarized_response}".strip()
    )


def chat_batch(
    contexts,
    instructions,
    state_chatbots,
):
    state_results = []
    ctx_results = []

    instruct_prompts = [
        generate_prompt(instruct, histories, ctx)[0]
        for ctx, instruct, histories in zip(contexts, instructions, state_chatbots)
    ]

    bot_responses = get_output_batch(
        global_vars.model, global_vars.tokenizer, instruct_prompts, global_vars.generation_config
    )
    bot_responses = post_processes_batch(bot_responses)

    for ctx, instruction, bot_response, state_chatbot in zip(contexts, instructions, bot_responses, state_chatbots):
        new_state_chatbot = state_chatbot + [('' if instruction == SPECIAL_STRS["continue"] else instruction, bot_response)]
        ctx_results.append(gr.Textbox.update(value=bot_response) if instruction == SPECIAL_STRS["summarize"] else ctx)
        state_results.append(new_state_chatbot)

    return (state_results, state_results, ctx_results)
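For orientation, here is a minimal sketch of how chat_stream could be wired into a Gradio Blocks UI. The component names (context_box, instruction_box, chatbot, state) and the chats.central import path are illustrative assumptions; the commit itself does not show the UI wiring.

import gradio as gr

# hypothetical import path; the diff above does not name the module
from chats.central import chat_stream

with gr.Blocks() as demo:
    state = gr.State([])                        # chat history as (user, bot) tuples
    context_box = gr.Textbox(label="Context")
    chatbot = gr.Chatbot()
    instruction_box = gr.Textbox(label="Instruction")

    # chat_stream is a generator, so each yielded (history, history, context)
    # triple is streamed back to the State, Chatbot, and Textbox components
    instruction_box.submit(
        chat_stream,
        inputs=[context_box, instruction_box, state],
        outputs=[state, chatbot, context_box],
    )

demo.queue().launch()

Because chat_stream raises gr.Error for over-long inputs, that case surfaces as an error message in the UI instead of breaking the generator.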
Empty file.
File renamed without changes.
File renamed without changes.
Empty file.
@@ -0,0 +1,30 @@
import torch

def get_output_batch(
    model, tokenizer, prompts, generation_config
):
    if len(prompts) == 1:
        # single prompt: no padding needed, tokenize and move the ids to the GPU
        encoding = tokenizer(prompts, return_tensors="pt")
        input_ids = encoding["input_ids"].cuda()
        generated_id = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            max_new_tokens=256
        )

        decoded = tokenizer.batch_decode(generated_id)
        del input_ids, generated_id
        torch.cuda.empty_cache()
        return decoded
    else:
        # multiple prompts: pad to a common length so they can be generated as one batch
        encodings = tokenizer(prompts, padding=True, return_tensors="pt").to('cuda')
        generated_ids = model.generate(
            **encodings,
            generation_config=generation_config,
            max_new_tokens=256
        )

        decoded = tokenizer.batch_decode(generated_ids)
        del encodings, generated_ids
        torch.cuda.empty_cache()
        return decoded
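For reference, a hedged usage sketch of get_output_batch follows. The checkpoint name and generation settings are placeholders, not values taken from this commit, and the padded (multi-prompt) branch assumes the tokenizer has a pad token configured.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from gens.batch_gen import get_output_batch

model_name = "your-org/your-causal-lm"  # placeholder checkpoint, not from this commit

tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    # padding=True in the batched branch requires a pad token
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda()
generation_config = GenerationConfig(temperature=0.9, top_p=0.75, num_beams=1)

prompts = [
    "### Instruction:\nSay hello.\n\n### Response:\n",
    "### Instruction:\nName three colors.\n\n### Response:\n",
]
outputs = get_output_batch(model, tokenizer, prompts, generation_config)
for text in outputs:
    # keep only the model's answer, mirroring the "### Response:" split used in chat_stream
    print(text.split("### Response:")[-1].strip())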