From 383061fa90affbb08d6fde1d41e9e96207d8401b Mon Sep 17 00:00:00 2001 From: luis Date: Wed, 24 May 2023 23:26:01 -0400 Subject: [PATCH] inital bnb --- requirements.txt | 11 ++++++--- training/hf_trainer.py | 52 ++++++++++++++++++++++++++++++++++-------- 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/requirements.txt b/requirements.txt index 9d96232..8e6b927 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,10 @@ -accelerate +torch parallel_pandas -peft +git+https://github.com/huggingface/peft pyarrow -transformers +git+https://github.com/huggingface/transformers +wandb +deepspeed +xformers +scipy +git+https://github.com/huggingface/accelerate \ No newline at end of file diff --git a/training/hf_trainer.py b/training/hf_trainer.py index 48b2b5e..c52d0a6 100644 --- a/training/hf_trainer.py +++ b/training/hf_trainer.py @@ -5,6 +5,7 @@ import torch import transformers +import typing from dataset import DataCollatorForMmapedDataset, MmappedArrowDataset from profiling import ProfilerCallback, build_profiler_configuration @@ -48,16 +49,33 @@ class LoraArguments: lora_target_modules: t.Optional[str] = field(metadata={"help": "Target modules, comma-separated."}, default=None) + +@dataclass +class BitsAndBytesArguments: + load_in_4bit: t.Optional[bool] = field(metadata={ + "help": "Weather to use 4-Bit quantization."}, + default=False) + load_in_8bit: t.Optional[bool] = field(metadata={ + "help": "Weather to use 8-Bit quantization."}, + default=False) + quant_type: t.Optional[typing.Literal['nf4', 'fp4']] = field(metadata={ + "help": "Weather to use NF4 quant type. Only used for 4bit."}, + default="fp4") + use_double_quant: t.Optional[bool] = field(metadata={ + "help": "Weather to use double quant. Only used for 4bit."}, + default=False) + def main() -> None: parser = transformers.HfArgumentParser(( ModelArguments, DataArguments, LoraArguments, + BitsAndBytesArguments, OtherArguments, transformers.TrainingArguments, )) - model_args, data_args, lora_args, \ + model_args, data_args, lora_args, bnb_args, \ other_args, training_args = parser.parse_args_into_dataclasses() tokenizer = transformers.AutoTokenizer.from_pretrained( @@ -88,11 +106,26 @@ def main() -> None: elif training_args.fp16: model_load_dtype = torch.float16 - model = transformers.AutoModelForCausalLM.from_pretrained( - model_args.model_name_or_path, - low_cpu_mem_usage=True, - torch_dtype=model_load_dtype, - ).cuda() + if bnb_args.load_in_8bit or bnb_args.load_in_4bit: + quantization_config = transformers.BitsAndBytesConfig( + load_in_8bit=bnb_args.load_in_8bit, + load_in_4bit=bnb_args.load_in_4bit, + bnb_4bit_compute_dtype=model_load_dtype, + bnb_4bit_quant_type=bnb_args.quant_type, + bnb_4bit_use_double_quant=bnb_args.use_double_quant, + ) + model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + low_cpu_mem_usage=True, + torch_dtype=model_load_dtype, + quantization_config=quantization_config, + ).cuda() + else: + model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + low_cpu_mem_usage=True, + torch_dtype=model_load_dtype, + ).cuda() if other_args.add_special_tokens is not None: # MAINTENANCE(11b): Big fat warning: the snippet below is copy-pasted @@ -185,7 +218,7 @@ def main() -> None: class SavePeftModelCallback(transformers.TrainerCallback): - ''' + """ At some point, PEFT stopped saving just the adapter and instead started storing full model weights. Extracting the adapter from the weights is doable, but seems to result in subpar results for some unknown reason, so @@ -194,7 +227,7 @@ class SavePeftModelCallback(transformers.TrainerCallback): https://github.com/huggingface/peft/issues/286#issuecomment-1512611968 https://github.com/huggingface/peft/blob/main/examples/int8_training/peft_bnb_whisper_large_v2_training.ipynb - ''' + """ def on_save( self, @@ -254,8 +287,9 @@ def _add_special_tokens_to_tokenizer_and_resize_model_embeddings( def _nearest_divisible(num: int, divisor: int) -> int: - '''Returns the nearest number to `num` that is divisible by `divisor`.''' + """Returns the nearest number to `num` that is divisible by `divisor`.""" return (num + divisor - 1) // divisor * divisor + if __name__ == "__main__": main()