0️⃣1️⃣🤗 BitNet-Transformers: Hugging Face Transformers Implementation of "BitNet: Scaling 1-bit Transformers for Large Language Models" in PyTorch with Mistral Architecture
- Paper Link: https://arxiv.org/pdf/2310.11453.pdf
# Clone this repo
git clone https://github.com/DewEfresh/bitnet-transformers
cd bitnet-transformers
# Install requirements
pip install -r clm_requirements.txt
# Clone transformers repo
git clone https://github.com/huggingface/transformers
pip install -e transformers
# Update Llama(2) model
rm ./transformers/src/transformers/models/llama/modeling_llama.py
ln -s $(pwd)/bitnet_mistral/modeling_llama.py ./transformers/src/transformers/models/llama/modeling_llama.py
# Update Mistral model
rm ./transformers/src/transformers/models/mistral/modeling_mistral.py
ln -s $(pwd)/bitnet_mistral/modeling_mistral.py ./transformers/src/transformers/models/mistral/modeling_mistral.py
We'll overwrite the modeling files in transformers with the ones from bitnet_mistral/. Since the files are symlinked, any changes made to them will be reflected in the transformers repo.
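To confirm the links are in place after running the commands above, a small check like the following can help. This is a minimal sketch: check_link is a hypothetical helper, not part of this repo, and the script assumes it is run from the repo root with the paths used in the install commands.

```python
from pathlib import Path


def check_link(link_path, expected_target):
    """Return True if link_path is a symlink that resolves to expected_target."""
    link = Path(link_path)
    return link.is_symlink() and link.resolve() == Path(expected_target).resolve()


if __name__ == "__main__":
    # (link, target) pairs copied from the ln -s commands above.
    pairs = [
        ("transformers/src/transformers/models/llama/modeling_llama.py",
         "bitnet_mistral/modeling_llama.py"),
        ("transformers/src/transformers/models/mistral/modeling_mistral.py",
         "bitnet_mistral/modeling_mistral.py"),
    ]
    for link, target in pairs:
        status = "ok" if check_link(link, target) else "NOT linked"
        print(f"{link}: {status}")
```

If a file is reported as not linked, re-run the matching rm and ln -s commands.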
You can track metrics via wandb
./train_wikitext.sh
Train Config
- Batch size: 1
- Gradient accumulation: 1
- Seq length: 2048
- Model: LlamaForCausalLM with BitLinear layer
- Model size: 47,452,672 (47.5M)
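To put the parameter count in perspective, the raw storage needed for the weights alone can be estimated from the bits used per weight. The sketch below is plain arithmetic, not a measurement: weight_megabytes is a hypothetical helper, and it assumes every parameter is stored at the same precision, which is a simplification since only some layers are BitLinear. GPU memory observed in practice is typically higher, because it also includes buffers and allocator overhead.

```python
NUM_PARAMS = 47_452_672  # model size reported in the train config


def weight_megabytes(num_params, bits_per_weight):
    """Raw storage for the weights alone, in MB (10**6 bytes)."""
    return num_params * bits_per_weight / 8 / 1e6


for bits in (16, 8, 1):
    print(f"{bits:>2}-bit: {weight_megabytes(NUM_PARAMS, bits):.1f} MB")
```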
Original LLAMA - 16bit
- Uses 250MB GPU memory for Model weights
BitLLAMA - Mixed 16bit
- Uses 200MB GPU memory for Model weights
- Use bf16(or fp16) to store model weights
- Use int8 to store -1/1 1-bit weights
- Use more memory when training than original LLAMA: It saves 1-bit weight and 16bit weight together
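The mixed setup can be illustrated with a small binarization sketch. It assumes the sign-after-centering scheme described in the linked BitNet paper (subtract the mean, take the sign, keep the mean absolute value as a scale); the BitLinear layer in this repo may differ in detail. Plain Python lists stand in for tensors, and binarize is a hypothetical name used only for illustration.

```python
def binarize(weights):
    """Binarize a flat list of floats to -1/+1 and return (signs, beta).

    Assumed scheme: center by the mean, take the sign, and keep
    beta = mean(|w|) as a single scale factor.
    """
    n = len(weights)
    alpha = sum(weights) / n                 # mean used for centering
    beta = sum(abs(w) for w in weights) / n  # average magnitude (scale)
    # Values at or below the mean map to -1, the rest to +1; both fit in int8.
    signs = [1 if w - alpha > 0 else -1 for w in weights]
    return signs, beta


# The 16-bit weights are kept for training, next to their 1-bit version.
latent = [0.5, -0.25, 0.25, -1.0]
signs, beta = binarize(latent)
print(signs, beta)
```

Because both latent and signs are held at the same time, training needs more memory than a plain 16-bit model, which matches the note above.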
BitLLAMA - 8bit
- Uses 100MB GPU memory for Model weights
- Use bf16(or fp16) on-the-fly when needed
- Use 8bit to save 1-bit BitLinear weight & other weights
BitLLAMA - 1bit
- Use bf16(or fp16) on-the-fly when needed
- Use 1bit to save 1-bit weight
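One possible way to store -1/1 weights at one bit each is to pack eight of them into a byte. The sketch below is an assumption about how such storage could look, not the layout used by this repo: pack_signs and unpack_signs are hypothetical helpers, and the bit order (first weight in the lowest bit) is an arbitrary choice.

```python
def pack_signs(signs):
    """Pack a list of -1/+1 values into bytes, 8 weights per byte (+1 -> bit 1)."""
    packed = bytearray((len(signs) + 7) // 8)
    for i, s in enumerate(signs):
        if s > 0:
            packed[i // 8] |= 1 << (i % 8)
    return bytes(packed)


def unpack_signs(packed, count):
    """Recover the first `count` -1/+1 values from packed bytes."""
    return [1 if (packed[i // 8] >> (i % 8)) & 1 else -1 for i in range(count)]


signs = [1, -1, -1, 1, 1, 1, -1, 1, -1, 1]
packed = pack_signs(signs)
restored = unpack_signs(packed, len(signs))
print(len(signs), "weights ->", len(packed), "bytes")
```

The weight count has to be stored alongside the bytes, since the last byte may be only partly used.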
TBD
- Add BitLinear layer
- Add LlamaForCausalLM model with BitLinear layer
  - Update .save_pretrained method (for 1-bit weight saving)
- Add sample code for LM training
- Update BitLinear layer to use 1-bit weight
  - Use uint8 instead of bfloat16
  - Use custom cuda kernel for 1-bit weight