Commit

modularization done
deep-diver committed Apr 4, 2023
1 parent ab1304d commit 50c5080
Showing 19 changed files with 304 additions and 258 deletions.
Empty file added __init__.py
Empty file.
128 changes: 128 additions & 0 deletions alpaca.py
@@ -0,0 +1,128 @@
import global_vars

import gradio as gr

from gens.batch_gen import get_output_batch
from miscs.strings import SPECIAL_STRS
from miscs.constants import num_of_characters_to_keep
from miscs.utils import generate_prompt
from miscs.utils import common_post_process, post_processes_batch, post_process_stream

def chat_stream(
context,
instruction,
state_chatbot,
):
if len(context) > 1000 or len(instruction) > 300:
raise gr.Error("context or prompt is too long!")

bot_summarized_response = ''
# user input should be appropriately formatted (don't be confused by the function name)
instruction_display = common_post_process(instruction)
instruction_prompt, conv_length = generate_prompt(instruction, state_chatbot, context)

if conv_length > num_of_characters_to_keep:
instruction_prompt = generate_prompt(SPECIAL_STRS["summarize"], state_chatbot, context)[0]

state_chatbot = state_chatbot + [
(
None,
"![](https://s2.gifyu.com/images/icons8-loading-circle.gif) too long conversations, so let's summarize..."
)
]
yield (state_chatbot, state_chatbot, context)

bot_summarized_response = get_output_batch(
global_vars.model, global_vars.tokenizer, [instruction_prompt], global_vars.generation_config
)[0]
bot_summarized_response = bot_summarized_response.split("### Response:")[-1].strip()

state_chatbot[-1] = (
None,
"✅ summarization is done and set as context"
)
print(f"bot_summarized_response: {bot_summarized_response}")
yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())

instruction_prompt = generate_prompt(instruction, state_chatbot, f"{context} {bot_summarized_response}")[0]

bot_response = global_vars.stream_model(
instruction_prompt,
max_tokens=256,
temperature=1,
top_p=0.9
)

instruction_display = None if instruction_display == SPECIAL_STRS["continue"] else instruction_display
state_chatbot = state_chatbot + [(instruction_display, None)]
yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())

prev_index = 0
agg_tokens = ""
cutoff_idx = 0
for tokens in bot_response:
tokens = tokens.strip()
cur_token = tokens[prev_index:]

if "#" in cur_token and agg_tokens == "":
cutoff_idx = tokens.find("#")
agg_tokens = tokens[cutoff_idx:]

if agg_tokens != "":
if len(agg_tokens) < len("### Instruction:") :
agg_tokens = agg_tokens + cur_token
elif len(agg_tokens) >= len("### Instruction:"):
if tokens.find("### Instruction:") > -1:
processed_response, _ = post_process_stream(tokens[:tokens.find("### Instruction:")].strip())

state_chatbot[-1] = (
instruction_display,
processed_response
)
yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())
break
else:
agg_tokens = ""
cutoff_idx = 0

if agg_tokens == "":
processed_response, to_exit = post_process_stream(tokens)
state_chatbot[-1] = (instruction_display, processed_response)
yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())

if to_exit:
break

prev_index = len(tokens)

yield (
state_chatbot,
state_chatbot,
f"{context} {bot_summarized_response}".strip()
)


def chat_batch(
contexts,
instructions,
state_chatbots,
):
state_results = []
ctx_results = []

instruct_prompts = [
generate_prompt(instruct, histories, ctx)[0]
for ctx, instruct, histories in zip(contexts, instructions, state_chatbots)
]

bot_responses = get_output_batch(
global_vars.model, global_vars.tokenizer, instruct_prompts, global_vars.generation_config
)
bot_responses = post_processes_batch(bot_responses)

for ctx, instruction, bot_response, state_chatbot in zip(contexts, instructions, bot_responses, state_chatbots):
new_state_chatbot = state_chatbot + [('' if instruction == SPECIAL_STRS["continue"] else instruction, bot_response)]
ctx_results.append(gr.Textbox.update(value=bot_response) if instruction == SPECIAL_STRS["summarize"] else ctx)
state_results.append(new_state_chatbot)

return (state_results, state_results, ctx_results)
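
Both handlers above are wired into Gradio in app.py below, but chat_stream can also be driven directly: it is a generator that yields an updated (chatbot_history, chatbot_history, context) triple after every chunk of streamed text. A minimal sketch, assuming global_vars.initialize_globals(args) has already loaded the model, tokenizer, and stream_model (its body is not part of this excerpt):

```python
# Minimal sketch: drive alpaca.chat_stream outside the Gradio UI.
# Assumes global_vars.initialize_globals(args) has already run, so
# global_vars.stream_model, model, tokenizer, and generation_config exist.
import alpaca

context = "You are a helpful assistant."
history = []  # list of (user_message, bot_message) tuples, the gr.Chatbot format

for history, _, context in alpaca.chat_stream(context, "Explain LoRA briefly.", history):
    user_turn, partial_answer = history[-1]
    if partial_answer:             # None until the first streamed chunk arrives
        print(partial_answer)      # grows as more tokens are generated
```
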
162 changes: 14 additions & 148 deletions app.py
@@ -1,134 +1,16 @@
from strings import TITLE, ABSTRACT, BOTTOM_LINE
from strings import DEFAULT_EXAMPLES
from strings import SPECIAL_STRS
from styles import PARENT_BLOCK_CSS

from constants import num_of_characters_to_keep

import time
import gradio as gr

from args import parse_args
from model import load_model
from gen import get_output_batch, StreamModel
from utils import generate_prompt, post_processes_batch, post_process_stream, get_generation_config, common_post_process

def chat_stream(
context,
instruction,
state_chatbot,
):
if len(context) > 500 or len(instruction) > 150:
raise gr.Error("context or prompt is too long!")

bot_summarized_response = ''
# user input should be appropriately formatted (don't be confused by the function name)
instruction_display = common_post_process(instruction)
instruction_prompt, conv_length = generate_prompt(instruction, state_chatbot, context)

if conv_length > num_of_characters_to_keep:
instruction_prompt = generate_prompt(SPECIAL_STRS["summarize"], state_chatbot, context)[0]

state_chatbot = state_chatbot + [
(
None,
"![](https://s2.gifyu.com/images/icons8-loading-circle.gif) too long conversations, so let's summarize..."
)
]
yield (state_chatbot, state_chatbot, context)

bot_summarized_response = get_output_batch(
model, tokenizer, [instruction_prompt], gen_config_summarization
)[0]
bot_summarized_response = bot_summarized_response.split("### Response:")[-1].strip()

state_chatbot[-1] = (
None,
"✅ summarization is done and set as context"
)
print(f"bot_summarized_response: {bot_summarized_response}")
yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}")

instruction_prompt = generate_prompt(instruction, state_chatbot, f"{context} {bot_summarized_response}")[0]

bot_response = stream_model(
instruction_prompt,
max_tokens=256,
temperature=1,
top_p=0.9
)

instruction_display = None if instruction_display == SPECIAL_STRS["continue"] else instruction_display
state_chatbot = state_chatbot + [(instruction_display, None)]

prev_index = 0
agg_tokens = ""
cutoff_idx = 0
for tokens in bot_response:
tokens = tokens.strip()
cur_token = tokens[prev_index:]

if "#" in cur_token and agg_tokens == "":
cutoff_idx = tokens.find("#")
agg_tokens = tokens[cutoff_idx:]

if agg_tokens != "":
if len(agg_tokens) < len("### Instruction:") :
agg_tokens = agg_tokens + cur_token
elif len(agg_tokens) >= len("### Instruction:"):
if tokens.find("### Instruction:") > -1:
processed_response, _ = post_process_stream(tokens[:tokens.find("### Instruction:")].strip())

state_chatbot[-1] = (
instruction_display,
processed_response
)
yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}")
break
else:
agg_tokens = ""
cutoff_idx = 0
import global_vars
import alpaca

if agg_tokens == "":
processed_response, to_exit = post_process_stream(tokens)
state_chatbot[-1] = (instruction_display, processed_response)
yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}")

if to_exit:
break

prev_index = len(tokens)

yield (
state_chatbot,
state_chatbot,
f"{context} {bot_summarized_response}"
)

def chat_batch(
contexts,
instructions,
state_chatbots,
):
state_results = []
ctx_results = []

instruct_prompts = [
generate_prompt(instruct, histories, ctx)[0]
for ctx, instruct, histories in zip(contexts, instructions, state_chatbots)
]

bot_responses = get_output_batch(
model, tokenizer, instruct_prompts, generation_config
)
bot_responses = post_processes_batch(bot_responses)

for ctx, instruction, bot_response, state_chatbot in zip(contexts, instructions, bot_responses, state_chatbots):
new_state_chatbot = state_chatbot + [('' if instruction == SPECIAL_STRS["continue"] else instruction, bot_response)]
ctx_results.append(gr.Textbox.update(value=bot_response) if instruction == SPECIAL_STRS["summarize"] else ctx)
state_results.append(new_state_chatbot)
from args import parse_args
from miscs.strings import TITLE, ABSTRACT, BOTTOM_LINE
from miscs.strings import DEFAULT_EXAMPLES
from miscs.styles import PARENT_BLOCK_CSS
from miscs.strings import SPECIAL_STRS

return (state_results, state_results, ctx_results)
from utils import get_chat_interface

def reset_textbox():
return gr.Textbox.update(value='')
@@ -148,25 +30,9 @@ def reset_everything(
)

def run(args):
global model, stream_model, tokenizer, generation_config, gen_config_summarization, batch_enabled

batch_enabled = True if args.batch_size > 1 else False

model, tokenizer = load_model(
base=args.base_url,
finetuned=args.ft_ckpt_url,
multi_gpu=args.multi_gpu
)

generation_config = get_generation_config(
args.gen_config_path
)
gen_config_summarization = get_generation_config(
"gen_config_summarization.yaml"
)

if not batch_enabled:
stream_model = StreamModel(model, tokenizer)
global_vars.initialize_globals(args)
batch_enabled = global_vars.batch_enabled
chat_interface = get_chat_interface(global_vars.model_type, batch_enabled)

with gr.Blocks(css=PARENT_BLOCK_CSS) as demo:
state_chatbot = gr.State([])
@@ -219,7 +85,7 @@ def run(args):
gr.Markdown(f"{BOTTOM_LINE}")

send_event = instruction_txtbox.submit(
chat_batch if batch_enabled else chat_stream,
chat_interface,
[context_txtbox, instruction_txtbox, state_chatbot],
[state_chatbot, chatbot, context_txtbox],
batch=batch_enabled,
@@ -232,7 +98,7 @@ def run(args):
)

continue_event = continue_btn.click(
chat_batch if batch_enabled else chat_stream,
chat_interface,
[context_txtbox, continue_txtbox, state_chatbot],
[state_chatbot, chatbot, context_txtbox],
batch=batch_enabled,
@@ -245,7 +111,7 @@ def run(args):
)

summarize_event = summarize_btn.click(
chat_batch if batch_enabled else chat_stream,
chat_interface,
[context_txtbox, summrize_txtbox, state_chatbot],
[state_chatbot, chatbot, context_txtbox],
batch=batch_enabled,
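
The rewritten run() above delegates to two modules whose bodies are not part of this excerpt: global_vars (initialize_globals absorbs the model and config loading that previously lived inline) and utils.get_chat_interface (which picks the handler that used to be chosen with `chat_batch if batch_enabled else chat_stream`). A minimal sketch of what they might look like, reconstructed from the removed code; every name or path that does not appear in the diff is an assumption:

```python
# global_vars.py -- hypothetical sketch. Only the attribute names used by
# app.py and alpaca.py appear in this diff; the body below is reconstructed
# from the code removed from run(), and the import paths are guesses.
from model import load_model
from gen import StreamModel                  # StreamModel's new location is not shown here
from utils import get_generation_config      # may have moved to miscs.utils in this commit

model = None
tokenizer = None
stream_model = None
generation_config = None
gen_config_summarization = None
batch_enabled = False
model_type = "alpaca"                        # guessed value; only the name appears in the diff

def initialize_globals(args):
    global model, tokenizer, stream_model
    global generation_config, gen_config_summarization, batch_enabled

    batch_enabled = args.batch_size > 1
    model, tokenizer = load_model(
        base=args.base_url,
        finetuned=args.ft_ckpt_url,
        multi_gpu=args.multi_gpu,
    )
    generation_config = get_generation_config(args.gen_config_path)
    gen_config_summarization = get_generation_config(args.gen_config_summarization_path)
    if not batch_enabled:
        stream_model = StreamModel(model, tokenizer)
```

Likewise, utils.get_chat_interface is only visible at its call site; the replacement of `chat_batch if batch_enabled else chat_stream` suggests a dispatcher along these lines:

```python
# utils.get_chat_interface -- hypothetical sketch of the dispatch implied by the
# call site in run(); the real selection logic is not shown in this excerpt.
import alpaca

def get_chat_interface(model_type, batch_enabled):
    if model_type == "alpaca":
        return alpaca.chat_batch if batch_enabled else alpaca.chat_stream
    raise ValueError(f"unsupported model type: {model_type}")
```
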
8 changes: 7 additions & 1 deletion args.py
@@ -42,9 +42,15 @@ def parse_args():
parser.add_argument(
"--gen_config_path",
help="path to GenerationConfig file used in batch mode",
default="generation_config_default.yaml",
default="configs/generation_config_default.yaml",
type=str
)
parser.add_argument(
"--gen_config_summarization_path",
help="path to GenerationConfig file used in context summarization",
default="configs/gen_config_summarization.yaml",
type=str
)
parser.add_argument(
"--multi_gpu",
help="Enable multi gpu mode. This will force not to use Int8 but float16, so you need to check if your system has enough GPU memory",
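
Both generation-config arguments now default to YAML files under configs/. The get_generation_config helper that reads them (imported by the old app.py above) is not included in this excerpt; a plausible loader, assuming the YAML simply holds GenerationConfig fields, might look like:

```python
# Hypothetical sketch of a YAML-backed GenerationConfig loader; the real
# get_generation_config in this repo is not shown in the excerpt above.
import yaml
from transformers import GenerationConfig

def get_generation_config(path):
    with open(path) as f:
        cfg = yaml.safe_load(f)
    # Accept either a flat mapping or one nested under a "generation_config" key.
    cfg = cfg.get("generation_config", cfg)
    return GenerationConfig(**cfg)
```
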
Empty file added chats/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.
Empty file added gens/__init__.py
Empty file.
30 changes: 30 additions & 0 deletions gens/batch_gen.py
@@ -0,0 +1,30 @@
import torch

def get_output_batch(
model, tokenizer, prompts, generation_config
):
if len(prompts) == 1:
encoding = tokenizer(prompts, return_tensors="pt")
input_ids = encoding["input_ids"].cuda()
generated_id = model.generate(
input_ids=input_ids,
generation_config=generation_config,
max_new_tokens=256
)

decoded = tokenizer.batch_decode(generated_id)
del input_ids, generated_id
torch.cuda.empty_cache()
return decoded
else:
encodings = tokenizer(prompts, padding=True, return_tensors="pt").to('cuda')
generated_ids = model.generate(
**encodings,
generation_config=generation_config,
max_new_tokens=256
)

decoded = tokenizer.batch_decode(generated_ids)
del encodings, generated_ids
torch.cuda.empty_cache()
return decoded
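
get_output_batch keeps two branches: a single prompt is encoded without padding, while multiple prompts are padded and generated together; both branches decode, free the CUDA tensors, and return the decoded strings. A minimal usage sketch, with the checkpoint path as an explicit placeholder and generation settings chosen arbitrarily:

```python
# Minimal usage sketch for gens.batch_gen.get_output_batch.
# "path/to/base-llama-checkpoint" is a placeholder; the padded (multi-prompt)
# branch additionally needs a pad token configured on the tokenizer.
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from gens.batch_gen import get_output_batch

tokenizer = AutoTokenizer.from_pretrained("path/to/base-llama-checkpoint")
tokenizer.pad_token_id = 0
model = AutoModelForCausalLM.from_pretrained("path/to/base-llama-checkpoint").half().cuda()

gen_config = GenerationConfig(temperature=0.9, top_p=0.75, num_beams=1)
prompts = ["### Instruction:\nName three primary colors.\n\n### Response:\n"]

decoded = get_output_batch(model, tokenizer, prompts, gen_config)
print(decoded[0].split("### Response:")[-1].strip())
```
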
(Diffs for the remaining changed files in this commit are not shown.)
