initial bnb #10

Draft: wants to merge 1 commit into main
11 changes: 8 additions & 3 deletions requirements.txt
@@ -1,5 +1,10 @@
-accelerate
 torch
+parallel_pandas
-peft
+git+https://github.com/huggingface/peft
 pyarrow
-transformers
+git+https://github.com/huggingface/transformers
+wandb
+deepspeed
+xformers
+scipy
+git+https://github.com/huggingface/accelerate
52 changes: 43 additions & 9 deletions training/hf_trainer.py
@@ -5,6 +5,7 @@
 
 import torch
 import transformers
+import typing
 from dataset import DataCollatorForMmapedDataset, MmappedArrowDataset
 from profiling import ProfilerCallback, build_profiler_configuration
 
@@ -48,16 +49,33 @@ class LoraArguments:
     lora_target_modules: t.Optional[str] = field(metadata={"help": "Target modules, comma-separated."},
                                                  default=None)
 
+
+@dataclass
+class BitsAndBytesArguments:
+    load_in_4bit: t.Optional[bool] = field(metadata={
+        "help": "Whether to use 4-bit quantization."},
+        default=False)
+    load_in_8bit: t.Optional[bool] = field(metadata={
+        "help": "Whether to use 8-bit quantization."},
+        default=False)
+    quant_type: t.Optional[typing.Literal['nf4', 'fp4']] = field(metadata={
+        "help": "Quantization data type to use, 'nf4' or 'fp4'. Only used for 4-bit."},
+        default="fp4")
+    use_double_quant: t.Optional[bool] = field(metadata={
+        "help": "Whether to use double quantization. Only used for 4-bit."},
+        default=False)
+
 
 def main() -> None:
     parser = transformers.HfArgumentParser((
         ModelArguments,
         DataArguments,
         LoraArguments,
+        BitsAndBytesArguments,
         OtherArguments,
         transformers.TrainingArguments,
     ))
-    model_args, data_args, lora_args, \
+    model_args, data_args, lora_args, bnb_args, \
         other_args, training_args = parser.parse_args_into_dataclasses()
 
     tokenizer = transformers.AutoTokenizer.from_pretrained(
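For context, the new flags flow through HfArgumentParser the same way as the existing dataclasses. Below is a minimal standalone sketch of how they parse; the flag values are made up for illustration, and whether your transformers version accepts typing.Literal annotations here is an assumption (older releases only handled Enum choices).

import typing
from dataclasses import dataclass, field

import transformers


@dataclass
class BitsAndBytesArguments:
    # Mirrors the dataclass in the diff above, trimmed to the essentials.
    load_in_4bit: typing.Optional[bool] = field(default=False)
    load_in_8bit: typing.Optional[bool] = field(default=False)
    quant_type: typing.Optional[typing.Literal["nf4", "fp4"]] = field(default="fp4")
    use_double_quant: typing.Optional[bool] = field(default=False)


parser = transformers.HfArgumentParser((BitsAndBytesArguments,))
# parse_args_into_dataclasses returns one instance per dataclass, in order.
(bnb_args,) = parser.parse_args_into_dataclasses(
    args=["--load_in_4bit", "true", "--quant_type", "nf4"]
)
assert bnb_args.load_in_4bit is True and bnb_args.quant_type == "nf4"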
@@ -88,11 +106,26 @@ def main() -> None:
     elif training_args.fp16:
         model_load_dtype = torch.float16
 
-    model = transformers.AutoModelForCausalLM.from_pretrained(
-        model_args.model_name_or_path,
-        low_cpu_mem_usage=True,
-        torch_dtype=model_load_dtype,
-    ).cuda()
+    if bnb_args.load_in_8bit or bnb_args.load_in_4bit:
+        quantization_config = transformers.BitsAndBytesConfig(
+            load_in_8bit=bnb_args.load_in_8bit,
+            load_in_4bit=bnb_args.load_in_4bit,
+            bnb_4bit_compute_dtype=model_load_dtype,
+            bnb_4bit_quant_type=bnb_args.quant_type,
+            bnb_4bit_use_double_quant=bnb_args.use_double_quant,
+        )
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_args.model_name_or_path,
+            low_cpu_mem_usage=True,
+            torch_dtype=model_load_dtype,
+            quantization_config=quantization_config,
+        ).cuda()
+    else:
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_args.model_name_or_path,
+            low_cpu_mem_usage=True,
+            torch_dtype=model_load_dtype,
+        ).cuda()
 
     if other_args.add_special_tokens is not None:
         # MAINTENANCE(11b): Big fat warning: the snippet below is copy-pasted
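Standalone, the quantized branch above corresponds roughly to the sketch below. The model name is hypothetical, and one caveat is hedged in as a comment: recent transformers releases refuse .cuda() on a model that was loaded quantized, so this sketch relies on device_map placement instead.

import torch
import transformers

# Hypothetical model name, for illustration only.
MODEL_NAME = "EleutherAI/pythia-410m"

quantization_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for the matmuls
    bnb_4bit_quant_type="nf4",              # "nf4" or "fp4"
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    low_cpu_mem_usage=True,
    quantization_config=quantization_config,
    device_map="auto",  # let accelerate place the quantized weights; newer
                        # transformers versions raise if .cuda() is called instead
)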
@@ -185,7 +218,7 @@ def main() -> None:
 
 
 class SavePeftModelCallback(transformers.TrainerCallback):
-    '''
+    """
     At some point, PEFT stopped saving just the adapter and instead started
     storing full model weights. Extracting the adapter from the weights is
     doable, but seems to result in subpar results for some unknown reason, so
@@ -194,7 +227,7 @@ class SavePeftModelCallback(transformers.TrainerCallback):
 
     https://github.com/huggingface/peft/issues/286#issuecomment-1512611968
     https://github.com/huggingface/peft/blob/main/examples/int8_training/peft_bnb_whisper_large_v2_training.ipynb
-    '''
+    """
 
     def on_save(
         self,
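The body of on_save is collapsed in this diff view. Based on the PEFT issue linked in the docstring, the pattern it implements looks roughly like the following sketch; the checkpoint directory naming and the "model" kwarg are assumptions drawn from the TrainerCallback API, not the PR's exact code.

import os
import transformers


class SavePeftModelCallback(transformers.TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        # Save only the PEFT adapter weights into the checkpoint directory...
        kwargs["model"].save_pretrained(checkpoint_dir)
        # ...and drop the full model weights the Trainer wrote alongside them.
        full_weights = os.path.join(checkpoint_dir, "pytorch_model.bin")
        if os.path.exists(full_weights):
            os.remove(full_weights)
        return control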
@@ -254,8 +287,9 @@ def _add_special_tokens_to_tokenizer_and_resize_model_embeddings(
 
 
 def _nearest_divisible(num: int, divisor: int) -> int:
-    '''Returns the nearest number to `num` that is divisible by `divisor`.'''
+    """Returns the nearest number to `num` that is divisible by `divisor`."""
     return (num + divisor - 1) // divisor * divisor
 
+
 if __name__ == "__main__":
     main()
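One nuance worth noting: despite its name, _nearest_divisible rounds up to the next multiple rather than to the truly nearest one, which is presumably what the embedding-resize call site wants, since rounding down could truncate the resized embedding table. A quick check:

def _nearest_divisible(num: int, divisor: int) -> int:
    """Returns the nearest number to `num` that is divisible by `divisor`."""
    return (num + divisor - 1) // divisor * divisor


# Rounds up: 48 is nearer to 50 than 56 is, but the function returns 56.
assert _nearest_divisible(50, 8) == 56
# Exact multiples are returned unchanged.
assert _nearest_divisible(64, 8) == 64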