diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/alpaca.py b/alpaca.py
new file mode 100644
index 0000000..c76c3ee
--- /dev/null
+++ b/alpaca.py
@@ -0,0 +1,128 @@
+import global_vars
+
+import gradio as gr
+
+from gens.batch_gen import get_output_batch
+from miscs.strings import SPECIAL_STRS
+from miscs.constants import num_of_characters_to_keep
+from miscs.utils import generate_prompt
+from miscs.utils import common_post_process, post_processes_batch, post_process_stream
+
+def chat_stream(
+ context,
+ instruction,
+ state_chatbot,
+):
+ if len(context) > 1000 or len(instruction) > 300:
+ raise gr.Error("context or prompt is too long!")
+
+ bot_summarized_response = ''
+ # user input should be appropriately formatted (don't be confused by the function name)
+ instruction_display = common_post_process(instruction)
+ instruction_prompt, conv_length = generate_prompt(instruction, state_chatbot, context)
+
+ if conv_length > num_of_characters_to_keep:
+ instruction_prompt = generate_prompt(SPECIAL_STRS["summarize"], state_chatbot, context)[0]
+
+ state_chatbot = state_chatbot + [
+ (
+ None,
+ " too long conversations, so let's summarize..."
+ )
+ ]
+ yield (state_chatbot, state_chatbot, context)
+
+ bot_summarized_response = get_output_batch(
+ global_vars.model, global_vars.tokenizer, [instruction_prompt], global_vars.generation_config
+ )[0]
+ bot_summarized_response = bot_summarized_response.split("### Response:")[-1].strip()
+
+ state_chatbot[-1] = (
+ None,
+ "✅ summarization is done and set as context"
+ )
+ print(f"bot_summarized_response: {bot_summarized_response}")
+ yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())
+
+ instruction_prompt = generate_prompt(instruction, state_chatbot, f"{context} {bot_summarized_response}")[0]
+
+ bot_response = global_vars.stream_model(
+ instruction_prompt,
+ max_tokens=256,
+ temperature=1,
+ top_p=0.9
+ )
+
+ instruction_display = None if instruction_display == SPECIAL_STRS["continue"] else instruction_display
+ state_chatbot = state_chatbot + [(instruction_display, None)]
+ yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}".strip())
+
+ prev_index = 0
+ agg_tokens = ""
+ cutoff_idx = 0
+ for tokens in bot_response:
+ tokens = tokens.strip()
+ cur_token = tokens[prev_index:]
+
+ if "#" in cur_token and agg_tokens == "":
+ cutoff_idx = tokens.find("#")
+ agg_tokens = tokens[cutoff_idx:]
+
+ if agg_tokens != "":
+ if len(agg_tokens) < len("### Instruction:") :
+ agg_tokens = agg_tokens + cur_token
+ elif len(agg_tokens) >= len("### Instruction:"):
+ if tokens.find("### Instruction:") > -1:
+ processed_response, _ = post_process_stream(tokens[:tokens.find("### Instruction:")].strip())
+
+ state_chatbot[-1] = (
+ instruction_display,
+ processed_response
+ )
+ yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())
+ break
+ else:
+ agg_tokens = ""
+ cutoff_idx = 0
+
+ if agg_tokens == "":
+ processed_response, to_exit = post_process_stream(tokens)
+ state_chatbot[-1] = (instruction_display, processed_response)
+ yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}".strip())
+
+ if to_exit:
+ break
+
+ prev_index = len(tokens)
+
+ yield (
+ state_chatbot,
+ state_chatbot,
+ f"{context} {bot_summarized_response}".strip()
+ )
+
+
+def chat_batch(
+ contexts,
+ instructions,
+ state_chatbots,
+):
+ state_results = []
+ ctx_results = []
+
+ instruct_prompts = [
+ generate_prompt(instruct, histories, ctx)[0]
+ for ctx, instruct, histories in zip(contexts, instructions, state_chatbots)
+ ]
+
+ bot_responses = get_output_batch(
+ global_vars.model, global_vars.tokenizer, instruct_prompts, global_vars.generation_config
+ )
+ bot_responses = post_processes_batch(bot_responses)
+
+ for ctx, instruction, bot_response, state_chatbot in zip(contexts, instructions, bot_responses, state_chatbots):
+ new_state_chatbot = state_chatbot + [('' if instruction == SPECIAL_STRS["continue"] else instruction, bot_response)]
+ ctx_results.append(gr.Textbox.update(value=bot_response) if instruction == SPECIAL_STRS["summarize"] else ctx)
+ state_results.append(new_state_chatbot)
+
+ return (state_results, state_results, ctx_results)
\ No newline at end of file
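
A minimal consumption sketch for the chat_stream generator above, assuming global_vars.initialize_globals(args) has already run in non-batch mode; the context and instruction strings are made up for illustration. Each yield carries (state, chatbot_update, context), and the newest partial response sits in the last tuple of chatbot_update.

    import alpaca

    state = []
    for state, chatbot_update, context in alpaca.chat_stream(
        context="You are a helpful assistant.",
        instruction="Explain what LoRA fine-tuning is.",
        state_chatbot=state,
    ):
        # the last pair is (displayed_instruction, partial_response); the
        # response slot stays None until the first streamed tokens arrive
        print(chatbot_update[-1][1])
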
diff --git a/app.py b/app.py
index b2827e8..666031d 100644
--- a/app.py
+++ b/app.py
@@ -1,134 +1,16 @@
-from strings import TITLE, ABSTRACT, BOTTOM_LINE
-from strings import DEFAULT_EXAMPLES
-from strings import SPECIAL_STRS
-from styles import PARENT_BLOCK_CSS
-
-from constants import num_of_characters_to_keep
-
import time
import gradio as gr
-from args import parse_args
-from model import load_model
-from gen import get_output_batch, StreamModel
-from utils import generate_prompt, post_processes_batch, post_process_stream, get_generation_config, common_post_process
-
-def chat_stream(
- context,
- instruction,
- state_chatbot,
-):
- if len(context) > 500 or len(instruction) > 150:
- raise gr.Error("context or prompt is too long!")
-
- bot_summarized_response = ''
- # user input should be appropriately formatted (don't be confused by the function name)
- instruction_display = common_post_process(instruction)
- instruction_prompt, conv_length = generate_prompt(instruction, state_chatbot, context)
-
- if conv_length > num_of_characters_to_keep:
- instruction_prompt = generate_prompt(SPECIAL_STRS["summarize"], state_chatbot, context)[0]
-
- state_chatbot = state_chatbot + [
- (
- None,
- " too long conversations, so let's summarize..."
- )
- ]
- yield (state_chatbot, state_chatbot, context)
-
- bot_summarized_response = get_output_batch(
- model, tokenizer, [instruction_prompt], gen_config_summarization
- )[0]
- bot_summarized_response = bot_summarized_response.split("### Response:")[-1].strip()
-
- state_chatbot[-1] = (
- None,
- "✅ summarization is done and set as context"
- )
- print(f"bot_summarized_response: {bot_summarized_response}")
- yield (state_chatbot, state_chatbot, f"{context}. {bot_summarized_response}")
-
- instruction_prompt = generate_prompt(instruction, state_chatbot, f"{context} {bot_summarized_response}")[0]
-
- bot_response = stream_model(
- instruction_prompt,
- max_tokens=256,
- temperature=1,
- top_p=0.9
- )
-
- instruction_display = None if instruction_display == SPECIAL_STRS["continue"] else instruction_display
- state_chatbot = state_chatbot + [(instruction_display, None)]
-
- prev_index = 0
- agg_tokens = ""
- cutoff_idx = 0
- for tokens in bot_response:
- tokens = tokens.strip()
- cur_token = tokens[prev_index:]
-
- if "#" in cur_token and agg_tokens == "":
- cutoff_idx = tokens.find("#")
- agg_tokens = tokens[cutoff_idx:]
-
- if agg_tokens != "":
- if len(agg_tokens) < len("### Instruction:") :
- agg_tokens = agg_tokens + cur_token
- elif len(agg_tokens) >= len("### Instruction:"):
- if tokens.find("### Instruction:") > -1:
- processed_response, _ = post_process_stream(tokens[:tokens.find("### Instruction:")].strip())
-
- state_chatbot[-1] = (
- instruction_display,
- processed_response
- )
- yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}")
- break
- else:
- agg_tokens = ""
- cutoff_idx = 0
+import global_vars
+import alpaca
- if agg_tokens == "":
- processed_response, to_exit = post_process_stream(tokens)
- state_chatbot[-1] = (instruction_display, processed_response)
- yield (state_chatbot, state_chatbot, f"{context} {bot_summarized_response}")
-
- if to_exit:
- break
-
- prev_index = len(tokens)
-
- yield (
- state_chatbot,
- state_chatbot,
- f"{context} {bot_summarized_response}"
- )
-
-def chat_batch(
- contexts,
- instructions,
- state_chatbots,
-):
- state_results = []
- ctx_results = []
-
- instruct_prompts = [
- generate_prompt(instruct, histories, ctx)[0]
- for ctx, instruct, histories in zip(contexts, instructions, state_chatbots)
- ]
-
- bot_responses = get_output_batch(
- model, tokenizer, instruct_prompts, generation_config
- )
- bot_responses = post_processes_batch(bot_responses)
-
- for ctx, instruction, bot_response, state_chatbot in zip(contexts, instructions, bot_responses, state_chatbots):
- new_state_chatbot = state_chatbot + [('' if instruction == SPECIAL_STRS["continue"] else instruction, bot_response)]
- ctx_results.append(gr.Textbox.update(value=bot_response) if instruction == SPECIAL_STRS["summarize"] else ctx)
- state_results.append(new_state_chatbot)
+from args import parse_args
+from miscs.strings import TITLE, ABSTRACT, BOTTOM_LINE
+from miscs.strings import DEFAULT_EXAMPLES
+from miscs.styles import PARENT_BLOCK_CSS
+from miscs.strings import SPECIAL_STRS
- return (state_results, state_results, ctx_results)
+from utils import get_chat_interface
def reset_textbox():
return gr.Textbox.update(value='')
@@ -148,25 +30,9 @@ def reset_everything(
)
def run(args):
- global model, stream_model, tokenizer, generation_config, gen_config_summarization, batch_enabled
-
- batch_enabled = True if args.batch_size > 1 else False
-
- model, tokenizer = load_model(
- base=args.base_url,
- finetuned=args.ft_ckpt_url,
- multi_gpu=args.multi_gpu
- )
-
- generation_config = get_generation_config(
- args.gen_config_path
- )
- gen_config_summarization = get_generation_config(
- "gen_config_summarization.yaml"
- )
-
- if not batch_enabled:
- stream_model = StreamModel(model, tokenizer)
+ global_vars.initialize_globals(args)
+ batch_enabled = global_vars.batch_enabled
+ chat_interface = get_chat_interface(global_vars.model_type, batch_enabled)
with gr.Blocks(css=PARENT_BLOCK_CSS) as demo:
state_chatbot = gr.State([])
@@ -219,7 +85,7 @@ def run(args):
gr.Markdown(f"{BOTTOM_LINE}")
send_event = instruction_txtbox.submit(
- chat_batch if batch_enabled else chat_stream,
+ chat_interface,
[context_txtbox, instruction_txtbox, state_chatbot],
[state_chatbot, chatbot, context_txtbox],
batch=batch_enabled,
@@ -232,7 +98,7 @@ def run(args):
)
continue_event = continue_btn.click(
- chat_batch if batch_enabled else chat_stream,
+ chat_interface,
[context_txtbox, continue_txtbox, state_chatbot],
[state_chatbot, chatbot, context_txtbox],
batch=batch_enabled,
@@ -245,7 +111,7 @@ def run(args):
)
summarize_event = summarize_btn.click(
- chat_batch if batch_enabled else chat_stream,
+ chat_interface,
[context_txtbox, summrize_txtbox, state_chatbot],
[state_chatbot, chatbot, context_txtbox],
batch=batch_enabled,
diff --git a/args.py b/args.py
index f3f63ab..37f14a8 100644
--- a/args.py
+++ b/args.py
@@ -42,9 +42,15 @@ def parse_args():
parser.add_argument(
"--gen_config_path",
help="path to GenerationConfig file used in batch mode",
- default="generation_config_default.yaml",
+ default="configs/generation_config_default.yaml",
type=str
)
+ parser.add_argument(
+ "--gen_config_summarization_path",
+ help="path to GenerationConfig file used in context summarization",
+ default="configs/gen_config_summarization.yaml",
+ type=str
+ )
parser.add_argument(
"--multi_gpu",
help="Enable multi gpu mode. This will force not to use Int8 but float16, so you need to check if your system has enough GPU memory",
diff --git a/chats/__init__.py b/chats/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/gen_config_summarization.yaml b/configs/gen_config_summarization.yaml
similarity index 100%
rename from gen_config_summarization.yaml
rename to configs/gen_config_summarization.yaml
diff --git a/generation_config_default.yaml b/configs/generation_config_default.yaml
similarity index 100%
rename from generation_config_default.yaml
rename to configs/generation_config_default.yaml
diff --git a/gens/__init__.py b/gens/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/gens/batch_gen.py b/gens/batch_gen.py
new file mode 100644
index 0000000..33c351c
--- /dev/null
+++ b/gens/batch_gen.py
@@ -0,0 +1,30 @@
+import torch
+
+def get_output_batch(
+ model, tokenizer, prompts, generation_config
+):
+ if len(prompts) == 1:
+ encoding = tokenizer(prompts, return_tensors="pt")
+ input_ids = encoding["input_ids"].cuda()
+ generated_id = model.generate(
+ input_ids=input_ids,
+ generation_config=generation_config,
+ max_new_tokens=256
+ )
+
+ decoded = tokenizer.batch_decode(generated_id)
+ del input_ids, generated_id
+ torch.cuda.empty_cache()
+ return decoded
+ else:
+ encodings = tokenizer(prompts, padding=True, return_tensors="pt").to('cuda')
+ generated_ids = model.generate(
+ **encodings,
+ generation_config=generation_config,
+ max_new_tokens=256
+ )
+
+ decoded = tokenizer.batch_decode(generated_ids)
+ del encodings, generated_ids
+ torch.cuda.empty_cache()
+ return decoded
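
A usage sketch for get_output_batch, assuming the model, tokenizer, and generation config were loaded elsewhere (for example via global_vars.initialize_globals) and that a CUDA device is available; the prompt string is illustrative only.

    import global_vars
    from gens.batch_gen import get_output_batch

    decoded = get_output_batch(
        global_vars.model,
        global_vars.tokenizer,
        ["### Instruction:Say hello.\n\n### Response:"],
        global_vars.generation_config,
    )
    # the decoded text still contains the prompt, so keep only what follows
    # the final "### Response:" marker, as alpaca.chat_stream does
    print(decoded[0].split("### Response:")[-1].strip())
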
diff --git a/gen.py b/gens/stream_gen.py
similarity index 90%
rename from gen.py
rename to gens/stream_gen.py
index 81b4e89..71d901c 100644
--- a/gen.py
+++ b/gens/stream_gen.py
@@ -15,34 +15,6 @@
TopPLogitsWarper,
)
-def get_output_batch(
- model, tokenizer, prompts, generation_config
-):
- if len(prompts) == 1:
- encoding = tokenizer(prompts, return_tensors="pt")
- input_ids = encoding["input_ids"].cuda()
- generated_id = model.generate(
- input_ids=input_ids,
- generation_config=generation_config,
- max_new_tokens=256
- )
-
- decoded = tokenizer.batch_decode(generated_id)
- del input_ids, generated_id
- torch.cuda.empty_cache()
- return decoded
- else:
- encodings = tokenizer(prompts, padding=True, return_tensors="pt").to('cuda')
- generated_ids = model.generate(
- **encodings,
- generation_config=generation_config,
- max_new_tokens=256
- )
-
- decoded = tokenizer.batch_decode(generated_ids)
- del encodings, generated_ids
- torch.cuda.empty_cache()
- return decoded
# StreamModel is borrowed from basaran project
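
A sketch of how the relocated StreamModel is driven, mirroring the call in alpaca.chat_stream; it assumes initialize_globals ran with batch_size == 1 so that global_vars.stream_model exists, and that each yielded value is the accumulated decoded text so far.

    import global_vars

    prompt = "### Instruction:Say hello.\n\n### Response:"
    for partial_text in global_vars.stream_model(
        prompt,
        max_tokens=256,
        temperature=1,
        top_p=0.9,
    ):
        # each iteration returns the full text generated so far,
        # which the caller slices with a running index
        print(partial_text)
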
diff --git a/global_vars.py b/global_vars.py
new file mode 100644
index 0000000..0f7aad6
--- /dev/null
+++ b/global_vars.py
@@ -0,0 +1,36 @@
+from models.alpaca_model import load_model
+from gens.stream_gen import StreamModel
+
+from miscs.utils import get_generation_config
+
+def initialize_globals(args):
+ global model, stream_model, tokenizer
+ global generation_config, gen_config_summarization
+ global model_type, batch_enabled
+
+ model_type = "alpaca"
+ batch_enabled = True if args.batch_size > 1 else False
+
+ model, tokenizer = load_model(
+ base=args.base_url,
+ finetuned=args.ft_ckpt_url,
+ multi_gpu=args.multi_gpu
+ )
+
+ if "alpaca" in args.ft_ckpt_url:
+ model_type = "alpaca"
+ elif "baize" in args.ft_ckpt_url:
+ model_type = "baize"
+ else:
+ print("unsupported model type. only alpaca and baize are supported")
+ quit()
+
+ generation_config = get_generation_config(
+ args.gen_config_path
+ )
+ gen_config_summarization = get_generation_config(
+ args.gen_config_summarization_path
+ )
+
+ if not batch_enabled:
+ stream_model = StreamModel(model, tokenizer)
\ No newline at end of file
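
The intended call order, mirroring run() in app.py: parse the CLI arguments, let initialize_globals populate the module-level state, then read the globals back from global_vars wherever they are needed.

    import global_vars
    from args import parse_args

    args = parse_args()
    global_vars.initialize_globals(args)

    print(global_vars.model_type)     # "alpaca" or "baize", inferred from --ft_ckpt_url
    print(global_vars.batch_enabled)  # True only when --batch_size > 1
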
diff --git a/miscs/__init__.py b/miscs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/constants.py b/miscs/constants.py
similarity index 100%
rename from constants.py
rename to miscs/constants.py
diff --git a/strings.py b/miscs/strings.py
similarity index 100%
rename from strings.py
rename to miscs/strings.py
diff --git a/styles.py b/miscs/styles.py
similarity index 100%
rename from styles.py
rename to miscs/styles.py
diff --git a/miscs/utils.py b/miscs/utils.py
new file mode 100644
index 0000000..188eb97
--- /dev/null
+++ b/miscs/utils.py
@@ -0,0 +1,81 @@
+import re
+import yaml
+
+from transformers import GenerationConfig
+
+from miscs.strings import SPECIAL_STRS
+from miscs.constants import html_tag_pattern, multi_line_pattern, multi_space_pattern
+from miscs.constants import repl_empty_str, repl_br_tag, repl_span_tag_multispace, repl_linebreak
+
+def get_generation_config(path):
+ with open(path, 'rb') as f:
+ generation_config = yaml.safe_load(f.read())
+
+ return GenerationConfig(**generation_config["generation_config"])
+
+def generate_prompt(prompt, histories, ctx=None):
+ convs = f"""Below is a history of instructions that describe tasks, paired with an input that provides further context. Write a response that appropriately completes the request by remembering the conversation history.
+
+"""
+ if ctx is not None:
+ convs = f"""{ctx}
+
+"""
+ sub_convs = ""
+ start_idx = 0
+
+ for idx, history in enumerate(histories):
+ history_prompt = history[0]
+ history_response = history[1]
+ if history_response == "✅ summarization is done and set as context" or history_prompt == SPECIAL_STRS["summarize"]:
+ start_idx = idx
+
+ # drop the previous conversations if user has summarized
+ for history in histories[start_idx if start_idx == 0 else start_idx+1:]:
+ history_prompt = history[0]
+ history_response = history[1]
+
+        history_response = history_response.replace("<br>", "\n")
+ history_response = re.sub(
+ html_tag_pattern, repl_empty_str, history_response
+ )
+
+ sub_convs = sub_convs + f"""### Instruction:{history_prompt}
+
+### Response:{history_response}
+
+"""
+
+ sub_convs = sub_convs + f"""### Instruction:{prompt}
+
+### Response:"""
+
+ convs = convs + sub_convs
+ return convs, len(sub_convs)
+
+# applicable to instruction to be displayed as well
+def common_post_process(original_str):
+ original_str = re.sub(
+ multi_line_pattern, repl_br_tag, original_str
+ )
+ original_str = re.sub(
+ multi_space_pattern, repl_span_tag_multispace, original_str
+ )
+
+ return original_str
+
+def post_process_stream(bot_response):
+    # sometimes the model emits text containing "### Response:" or "### Input:"
+    # -> in that case, strip the marker and signal the caller to stop generating
+ if "### Response:" in bot_response or "### Input:" in bot_response:
+ bot_response = bot_response.replace("### Response:", '').replace("### Input:", '').strip()
+ return bot_response, True
+
+ return common_post_process(bot_response), False
+
+def post_process_batch(bot_response):
+ bot_response = bot_response.split("### Response:")[-1].strip()
+ return common_post_process(bot_response)
+
+def post_processes_batch(bot_responses):
+ return [post_process_batch(r) for r in bot_responses]
\ No newline at end of file
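
An illustrative check of the prompt layout produced by generate_prompt; the history and context below are made-up values, and the second return value is the character count of the turns appended after the context (compared against num_of_characters_to_keep in chat_stream).

    from miscs.utils import generate_prompt

    history = [("Hi", "Hello! How can I help?")]
    prompt, conv_length = generate_prompt(
        "What is LoRA?", history, ctx="You are a concise assistant."
    )
    print(prompt)
    # You are a concise assistant.
    #
    # ### Instruction:Hi
    #
    # ### Response:Hello! How can I help?
    #
    # ### Instruction:What is LoRA?
    #
    # ### Response:
    print(conv_length)  # length of everything appended after the context block
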
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/model.py b/models/alpaca_model.py
similarity index 100%
rename from model.py
rename to models/alpaca_model.py
diff --git a/utils.py b/utils.py
index 023a6d3..2d974cb 100644
--- a/utils.py
+++ b/utils.py
@@ -1,81 +1,8 @@
-import re
-import yaml
-
-from transformers import GenerationConfig
-
-from strings import SPECIAL_STRS
-from constants import html_tag_pattern, multi_line_pattern, multi_space_pattern
-from constants import repl_empty_str, repl_br_tag, repl_span_tag_multispace, repl_linebreak
-
-def get_generation_config(path):
- with open(path, 'rb') as f:
- generation_config = yaml.safe_load(f.read())
-
- return GenerationConfig(**generation_config["generation_config"])
-
-def generate_prompt(prompt, histories, ctx=None):
- convs = f"""Below is a history of instructions that describe tasks, paired with an input that provides further context. Write a response that appropriately completes the request by remembering the conversation history.
-
-"""
- if ctx is not None:
- convs = f"""{ctx}
-
-"""
- sub_convs = ""
- start_idx = 0
-
- for idx, history in enumerate(histories):
- history_prompt = history[0]
- history_response = history[1]
- if history_response == "✅ summarization is done and set as context" or history_prompt == SPECIAL_STRS["summarize"]:
- start_idx = idx
-
- # drop the previous conversations if user has summarized
- for history in histories[start_idx if start_idx == 0 else start_idx+1:]:
- history_prompt = history[0]
- history_response = history[1]
-
-        history_response = history_response.replace("<br>", "\n")
- history_response = re.sub(
- html_tag_pattern, repl_empty_str, history_response
- )
-
- sub_convs = sub_convs + f"""### Instruction:{history_prompt}
-
-### Response:{history_response}
-
-"""
-
- sub_convs = sub_convs + f"""### Instruction:{prompt}
-
-### Response:"""
-
- convs = convs + sub_convs
- return convs, len(sub_convs)
-
-# applicable to instruction to be displayed as well
-def common_post_process(original_str):
- original_str = re.sub(
- multi_line_pattern, repl_br_tag, original_str
- )
- original_str = re.sub(
- multi_space_pattern, repl_span_tag_multispace, original_str
- )
-
- return original_str
-
-def post_process_stream(bot_response):
- # sometimes model spits out text containing
- # "### Response:" and "### Instruction: -> in this case, we want to stop generating
- if "### Response:" in bot_response or "### Input:" in bot_response:
- bot_response = bot_response.replace("### Response:", '').replace("### Input:", '').strip()
- return bot_response, True
-
- return common_post_process(bot_response), False
-
-def post_process_batch(bot_response):
- bot_response = bot_response.split("### Response:")[-1].strip()
- return common_post_process(bot_response)
-
-def post_processes_batch(bot_responses):
- return [post_process_batch(r) for r in bot_responses]
+import alpaca
+
+def get_chat_interface(model_type, batch_enabled):
+ match model_type:
+ case 'alpaca':
+ return alpaca.chat_batch if batch_enabled else alpaca.chat_stream
+        case _:
+ return None
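
A quick sanity check of the new dispatcher, assuming the project's dependencies import cleanly; only the "alpaca" model type is wired up in this change, so any other value falls through to None.

    from utils import get_chat_interface

    assert get_chat_interface("alpaca", batch_enabled=True) is not None   # alpaca.chat_batch
    assert get_chat_interface("alpaca", batch_enabled=False) is not None  # alpaca.chat_stream
    assert get_chat_interface("baize", batch_enabled=False) is None       # not handled yet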