diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/alpaca.py b/alpaca.py new file mode 100644 index 0000000..c76c3ee --- /dev/null +++ b/alpaca.py @@ -0,0 +1,128 @@ +import global_vars + +import gradio as gr + +from gens.batch_gen import get_output_batch +from miscs.strings import SPECIAL_STRS +from miscs.constants import num_of_characters_to_keep +from miscs.utils import generate_prompt +from miscs.utils import common_post_process, post_processes_batch, post_process_stream + +def chat_stream( + context, + instruction, + state_chatbot, +): + if len(context) > 1000 or len(instruction) > 300: + raise gr.Error("context or prompt is too long!") + + bot_summarized_response = '' + # user input should be appropriately formatted (don't be confused by the function name) + instruction_display = common_post_process(instruction) + instruction_prompt, conv_length = generate_prompt(instruction, state_chatbot, context) + + if conv_length > num_of_characters_to_keep: + instruction_prompt = generate_prompt(SPECIAL_STRS["summarize"], state_chatbot, context)[0] + + state_chatbot = state_chatbot + [ + ( + None, + "![](https://s2.gifyu.com/images/icons8-loading-circle.gif) too long conversations, so let's summarize..." + ) + ] + yield (state_chatbot, state_chatbot, context) + + bot_summarized_response = get_output_batch( + global_vars.model, global_vars.tokenizer, [instruction_prompt], global_vars.generation_config + )[0] + bot_summarized_response = bot_summarized_response.split("### Response:")[-1].strip() + + state_chatbot[-1] = ( + None, + "✅ summarization is done and set as context" + ) + print(f"bot_summarized_response: {bot_summarized_response}") + yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip()) + + instruction_prompt = generate_prompt(instruction, state_chatbot, f"{context} {bot_summarized_response}")[0] + + bot_response = global_vars.stream_model( + instruction_prompt, + max_tokens=256, + temperature=1, + top_p=0.9 + ) + + instruction_display = None if instruction_display == SPECIAL_STRS["continue"] else instruction_display + state_chatbot = state_chatbot + [(instruction_display, None)] + yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip()) + + prev_index = 0 + agg_tokens = "" + cutoff_idx = 0 + for tokens in bot_response: + tokens = tokens.strip() + cur_token = tokens[prev_index:] + + if "#" in cur_token and agg_tokens == "": + cutoff_idx = tokens.find("#") + agg_tokens = tokens[cutoff_idx:] + + if agg_tokens != "": + if len(agg_tokens) < len("### Instruction:") : + agg_tokens = agg_tokens + cur_token + elif len(agg_tokens) >= len("### Instruction:"): + if tokens.find("### Instruction:") > -1: + processed_response, _ = post_process_stream(tokens[:tokens.find("### Instruction:")].strip()) + + state_chatbot[-1] = ( + instruction_display, + processed_response + ) + yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip()) + break + else: + agg_tokens = "" + cutoff_idx = 0 + + if agg_tokens == "": + processed_response, to_exit = post_process_stream(tokens) + state_chatbot[-1] = (instruction_display, processed_response) + yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip()) + + if to_exit: + break + + prev_index = len(tokens) + + yield ( + state_chatbot, + state_chatbot, + f"{context} {bot_summarized_response}".strip() + ) + + +def chat_batch( + contexts, + instructions, + state_chatbots, +): + state_results = [] + ctx_results = [] + + instruct_prompts = [ + generate_prompt(instruct, histories, ctx)[0] + for ctx, instruct, histories in zip(contexts, instructions, state_chatbots) + ] + + bot_responses = get_output_batch( + global_vars.model, global_vars.tokenizer, instruct_prompts, global_vars.generation_config + ) + bot_responses = post_processes_batch(bot_responses) + + for ctx, instruction, bot_response, state_chatbot in zip(contexts, instructions, bot_responses, state_chatbots): + new_state_chatbot = state_chatbot + [('' if instruction == SPECIAL_STRS["continue"] else instruction, bot_response)] + ctx_results.append(gr.Textbox.update(value=bot_response) if instruction == SPECIAL_STRS["summarize"] else ctx) + state_results.append(new_state_chatbot) + + return (state_results, state_results, ctx_results) \ No newline at end of file diff --git a/app.py b/app.py index b2827e8..666031d 100644 --- a/app.py +++ b/app.py @@ -1,134 +1,16 @@ -from strings import TITLE, ABSTRACT, BOTTOM_LINE -from strings import DEFAULT_EXAMPLES -from strings import SPECIAL_STRS -from styles import PARENT_BLOCK_CSS - -from constants import num_of_characters_to_keep - import time import gradio as gr -from args import parse_args -from model import load_model -from gen import get_output_batch, StreamModel -from utils import generate_prompt, post_processes_batch, post_process_stream, get_generation_config, common_post_process - -def chat_stream( - context, - instruction, - state_chatbot, -): - if len(context) > 500 or len(instruction) > 150: - raise gr.Error("context or prompt is too long!") - - bot_summarized_response = '' - # user input should be appropriately formatted (don't be confused by the function name) - instruction_display = common_post_process(instruction) - instruction_prompt, conv_length = generate_prompt(instruction, state_chatbot, context) - - if conv_length > num_of_characters_to_keep: - instruction_prompt = generate_prompt(SPECIAL_STRS["summarize"], state_chatbot, context)[0] - - state_chatbot = state_chatbot + [ - ( - None, - "![](https://s2.gifyu.com/images/icons8-loading-circle.gif) too long conversations, so let's summarize..." - ) - ] - yield (state_chatbot, state_chatbot, context) - - bot_summarized_response = get_output_batch( - model, tokenizer, [instruction_prompt], gen_config_summarization - )[0] - bot_summarized_response = bot_summarized_response.split("### Response:")[-1].strip() - - state_chatbot[-1] = ( - None, - "✅ summarization is done and set as context" - ) - print(f"bot_summarized_response: {bot_summarized_response}") - yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}") - - instruction_prompt = generate_prompt(instruction, state_chatbot, f"{context} {bot_summarized_response}")[0] - - bot_response = stream_model( - instruction_prompt, - max_tokens=256, - temperature=1, - top_p=0.9 - ) - - instruction_display = None if instruction_display == SPECIAL_STRS["continue"] else instruction_display - state_chatbot = state_chatbot + [(instruction_display, None)] - - prev_index = 0 - agg_tokens = "" - cutoff_idx = 0 - for tokens in bot_response: - tokens = tokens.strip() - cur_token = tokens[prev_index:] - - if "#" in cur_token and agg_tokens == "": - cutoff_idx = tokens.find("#") - agg_tokens = tokens[cutoff_idx:] - - if agg_tokens != "": - if len(agg_tokens) < len("### Instruction:") : - agg_tokens = agg_tokens + cur_token - elif len(agg_tokens) >= len("### Instruction:"): - if tokens.find("### Instruction:") > -1: - processed_response, _ = post_process_stream(tokens[:tokens.find("### Instruction:")].strip()) - - state_chatbot[-1] = ( - instruction_display, - processed_response - ) - yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}") - break - else: - agg_tokens = "" - cutoff_idx = 0 +import global_vars +import alpaca - if agg_tokens == "": - processed_response, to_exit = post_process_stream(tokens) - state_chatbot[-1] = (instruction_display, processed_response) - yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}") - - if to_exit: - break - - prev_index = len(tokens) - - yield ( - state_chatbot, - state_chatbot, - f"{context} {bot_summarized_response}" - ) - -def chat_batch( - contexts, - instructions, - state_chatbots, -): - state_results = [] - ctx_results = [] - - instruct_prompts = [ - generate_prompt(instruct, histories, ctx)[0] - for ctx, instruct, histories in zip(contexts, instructions, state_chatbots) - ] - - bot_responses = get_output_batch( - model, tokenizer, instruct_prompts, generation_config - ) - bot_responses = post_processes_batch(bot_responses) - - for ctx, instruction, bot_response, state_chatbot in zip(contexts, instructions, bot_responses, state_chatbots): - new_state_chatbot = state_chatbot + [('' if instruction == SPECIAL_STRS["continue"] else instruction, bot_response)] - ctx_results.append(gr.Textbox.update(value=bot_response) if instruction == SPECIAL_STRS["summarize"] else ctx) - state_results.append(new_state_chatbot) +from args import parse_args +from miscs.strings import TITLE, ABSTRACT, BOTTOM_LINE +from miscs.strings import DEFAULT_EXAMPLES +from miscs.styles import PARENT_BLOCK_CSS +from miscs.strings import SPECIAL_STRS - return (state_results, state_results, ctx_results) +from utils import get_chat_interface def reset_textbox(): return gr.Textbox.update(value='') @@ -148,25 +30,9 @@ def reset_everything( ) def run(args): - global model, stream_model, tokenizer, generation_config, gen_config_summarization, batch_enabled - - batch_enabled = True if args.batch_size > 1 else False - - model, tokenizer = load_model( - base=args.base_url, - finetuned=args.ft_ckpt_url, - multi_gpu=args.multi_gpu - ) - - generation_config = get_generation_config( - args.gen_config_path - ) - gen_config_summarization = get_generation_config( - "gen_config_summarization.yaml" - ) - - if not batch_enabled: - stream_model = StreamModel(model, tokenizer) + global_vars.initialize_globals(args) + batch_enabled = global_vars.batch_enabled + chat_interface = get_chat_interface(global_vars.model_type, batch_enabled) with gr.Blocks(css=PARENT_BLOCK_CSS) as demo: state_chatbot = gr.State([]) @@ -219,7 +85,7 @@ def run(args): gr.Markdown(f"{BOTTOM_LINE}") send_event = instruction_txtbox.submit( - chat_batch if batch_enabled else chat_stream, + chat_interface, [context_txtbox, instruction_txtbox, state_chatbot], [state_chatbot, chatbot, context_txtbox], batch=batch_enabled, @@ -232,7 +98,7 @@ def run(args): ) continue_event = continue_btn.click( - chat_batch if batch_enabled else chat_stream, + chat_interface, [context_txtbox, continue_txtbox, state_chatbot], [state_chatbot, chatbot, context_txtbox], batch=batch_enabled, @@ -245,7 +111,7 @@ def run(args): ) summarize_event = summarize_btn.click( - chat_batch if batch_enabled else chat_stream, + chat_interface, [context_txtbox, summrize_txtbox, state_chatbot], [state_chatbot, chatbot, context_txtbox], batch=batch_enabled, diff --git a/args.py b/args.py index f3f63ab..37f14a8 100644 --- a/args.py +++ b/args.py @@ -42,9 +42,15 @@ def parse_args(): parser.add_argument( "--gen_config_path", help="path to GenerationConfig file used in batch mode", - default="generation_config_default.yaml", + default="configs/generation_config_default.yaml", type=str ) + parser.add_argument( + "--gen_config_summarization_path", + help="path to GenerationConfig file used in context summarization", + default="configs/gen_config_summarization.yaml", + type=str + ) parser.add_argument( "--multi_gpu", help="Enable multi gpu mode. This will force not to use Int8 but float16, so you need to check if your system has enough GPU memory", diff --git a/chats/__init__.py b/chats/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gen_config_summarization.yaml b/configs/gen_config_summarization.yaml similarity index 100% rename from gen_config_summarization.yaml rename to configs/gen_config_summarization.yaml diff --git a/generation_config_default.yaml b/configs/generation_config_default.yaml similarity index 100% rename from generation_config_default.yaml rename to configs/generation_config_default.yaml diff --git a/gens/__init__.py b/gens/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gens/batch_gen.py b/gens/batch_gen.py new file mode 100644 index 0000000..33c351c --- /dev/null +++ b/gens/batch_gen.py @@ -0,0 +1,30 @@ +import torch + +def get_output_batch( + model, tokenizer, prompts, generation_config +): + if len(prompts) == 1: + encoding = tokenizer(prompts, return_tensors="pt") + input_ids = encoding["input_ids"].cuda() + generated_id = model.generate( + input_ids=input_ids, + generation_config=generation_config, + max_new_tokens=256 + ) + + decoded = tokenizer.batch_decode(generated_id) + del input_ids, generated_id + torch.cuda.empty_cache() + return decoded + else: + encodings = tokenizer(prompts, padding=True, return_tensors="pt").to('cuda') + generated_ids = model.generate( + **encodings, + generation_config=generation_config, + max_new_tokens=256 + ) + + decoded = tokenizer.batch_decode(generated_ids) + del encodings, generated_ids + torch.cuda.empty_cache() + return decoded diff --git a/gen.py b/gens/stream_gen.py similarity index 90% rename from gen.py rename to gens/stream_gen.py index 81b4e89..71d901c 100644 --- a/gen.py +++ b/gens/stream_gen.py @@ -15,34 +15,6 @@ TopPLogitsWarper, ) -def get_output_batch( - model, tokenizer, prompts, generation_config -): - if len(prompts) == 1: - encoding = tokenizer(prompts, return_tensors="pt") - input_ids = encoding["input_ids"].cuda() - generated_id = model.generate( - input_ids=input_ids, - generation_config=generation_config, - max_new_tokens=256 - ) - - decoded = tokenizer.batch_decode(generated_id) - del input_ids, generated_id - torch.cuda.empty_cache() - return decoded - else: - encodings = tokenizer(prompts, padding=True, return_tensors="pt").to('cuda') - generated_ids = model.generate( - **encodings, - generation_config=generation_config, - max_new_tokens=256 - ) - - decoded = tokenizer.batch_decode(generated_ids) - del encodings, generated_ids - torch.cuda.empty_cache() - return decoded # StreamModel is borrowed from basaran project diff --git a/global_vars.py b/global_vars.py new file mode 100644 index 0000000..0f7aad6 --- /dev/null +++ b/global_vars.py @@ -0,0 +1,36 @@ +from models.alpaca_model import load_model +from gens.stream_gen import StreamModel + +from miscs.utils import get_generation_config + +def initialize_globals(args): + global model, stream_model, tokenizer + global generation_config, gen_config_summarization + global model_type, batch_enabled + + model_type = "alpaca" + batch_enabled = True if args.batch_size > 1 else False + + model, tokenizer = load_model( + base=args.base_url, + finetuned=args.ft_ckpt_url, + multi_gpu=args.multi_gpu + ) + + if "alpaca" in args.ft_ckpt_url: + model_type = "alpaca" + elif "baize" in args.ft_ckpt_url: + model_type = "baize" + else: + print("unsupported model type. only alpaca and baize are supported") + quit() + + generation_config = get_generation_config( + args.gen_config_path + ) + gen_config_summarization = get_generation_config( + args.gen_config_summarization_path + ) + + if not batch_enabled: + stream_model = StreamModel(model, tokenizer) \ No newline at end of file diff --git a/miscs/__init__.py b/miscs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/constants.py b/miscs/constants.py similarity index 100% rename from constants.py rename to miscs/constants.py diff --git a/strings.py b/miscs/strings.py similarity index 100% rename from strings.py rename to miscs/strings.py diff --git a/styles.py b/miscs/styles.py similarity index 100% rename from styles.py rename to miscs/styles.py diff --git a/miscs/utils.py b/miscs/utils.py new file mode 100644 index 0000000..188eb97 --- /dev/null +++ b/miscs/utils.py @@ -0,0 +1,81 @@ +import re +import yaml + +from transformers import GenerationConfig + +from miscs.strings import SPECIAL_STRS +from miscs.constants import html_tag_pattern, multi_line_pattern, multi_space_pattern +from miscs.constants import repl_empty_str, repl_br_tag, repl_span_tag_multispace, repl_linebreak + +def get_generation_config(path): + with open(path, 'rb') as f: + generation_config = yaml.safe_load(f.read()) + + return GenerationConfig(**generation_config["generation_config"]) + +def generate_prompt(prompt, histories, ctx=None): + convs = f"""Below is a history of instructions that describe tasks, paired with an input that provides further context. Write a response that appropriately completes the request by remembering the conversation history. + +""" + if ctx is not None: + convs = f"""{ctx} + +""" + sub_convs = "" + start_idx = 0 + + for idx, history in enumerate(histories): + history_prompt = history[0] + history_response = history[1] + if history_response == "✅ summarization is done and set as context" or history_prompt == SPECIAL_STRS["summarize"]: + start_idx = idx + + # drop the previous conversations if user has summarized + for history in histories[start_idx if start_idx == 0 else start_idx+1:]: + history_prompt = history[0] + history_response = history[1] + + history_response = history_response.replace("
", "\n") + history_response = re.sub( + html_tag_pattern, repl_empty_str, history_response + ) + + sub_convs = sub_convs + f"""### Instruction:{history_prompt} + +### Response:{history_response} + +""" + + sub_convs = sub_convs + f"""### Instruction:{prompt} + +### Response:""" + + convs = convs + sub_convs + return convs, len(sub_convs) + +# applicable to instruction to be displayed as well +def common_post_process(original_str): + original_str = re.sub( + multi_line_pattern, repl_br_tag, original_str + ) + original_str = re.sub( + multi_space_pattern, repl_span_tag_multispace, original_str + ) + + return original_str + +def post_process_stream(bot_response): + # sometimes model spits out text containing + # "### Response:" and "### Instruction: -> in this case, we want to stop generating + if "### Response:" in bot_response or "### Input:" in bot_response: + bot_response = bot_response.replace("### Response:", '').replace("### Input:", '').strip() + return bot_response, True + + return common_post_process(bot_response), False + +def post_process_batch(bot_response): + bot_response = bot_response.split("### Response:")[-1].strip() + return common_post_process(bot_response) + +def post_processes_batch(bot_responses): + return [post_process_batch(r) for r in bot_responses] \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model.py b/models/alpaca_model.py similarity index 100% rename from model.py rename to models/alpaca_model.py diff --git a/utils.py b/utils.py index 023a6d3..2d974cb 100644 --- a/utils.py +++ b/utils.py @@ -1,81 +1,8 @@ -import re -import yaml - -from transformers import GenerationConfig - -from strings import SPECIAL_STRS -from constants import html_tag_pattern, multi_line_pattern, multi_space_pattern -from constants import repl_empty_str, repl_br_tag, repl_span_tag_multispace, repl_linebreak - -def get_generation_config(path): - with open(path, 'rb') as f: - generation_config = yaml.safe_load(f.read()) - - return GenerationConfig(**generation_config["generation_config"]) - -def generate_prompt(prompt, histories, ctx=None): - convs = f"""Below is a history of instructions that describe tasks, paired with an input that provides further context. Write a response that appropriately completes the request by remembering the conversation history. - -""" - if ctx is not None: - convs = f"""{ctx} - -""" - sub_convs = "" - start_idx = 0 - - for idx, history in enumerate(histories): - history_prompt = history[0] - history_response = history[1] - if history_response == "✅ summarization is done and set as context" or history_prompt == SPECIAL_STRS["summarize"]: - start_idx = idx - - # drop the previous conversations if user has summarized - for history in histories[start_idx if start_idx == 0 else start_idx+1:]: - history_prompt = history[0] - history_response = history[1] - - history_response = history_response.replace("
", "\n") - history_response = re.sub( - html_tag_pattern, repl_empty_str, history_response - ) - - sub_convs = sub_convs + f"""### Instruction:{history_prompt} - -### Response:{history_response} - -""" - - sub_convs = sub_convs + f"""### Instruction:{prompt} - -### Response:""" - - convs = convs + sub_convs - return convs, len(sub_convs) - -# applicable to instruction to be displayed as well -def common_post_process(original_str): - original_str = re.sub( - multi_line_pattern, repl_br_tag, original_str - ) - original_str = re.sub( - multi_space_pattern, repl_span_tag_multispace, original_str - ) - - return original_str - -def post_process_stream(bot_response): - # sometimes model spits out text containing - # "### Response:" and "### Instruction: -> in this case, we want to stop generating - if "### Response:" in bot_response or "### Input:" in bot_response: - bot_response = bot_response.replace("### Response:", '').replace("### Input:", '').strip() - return bot_response, True - - return common_post_process(bot_response), False - -def post_process_batch(bot_response): - bot_response = bot_response.split("### Response:")[-1].strip() - return common_post_process(bot_response) - -def post_processes_batch(bot_responses): - return [post_process_batch(r) for r in bot_responses] +import alpaca + +def get_chat_interface(model_type, batch_enabled): + match model_type: + case 'alpaca': + return alpaca.chat_batch if batch_enabled else alpaca.chat_stream + case other: + return None