
🧠 any4: Learns optimal 4-bit representation • Outperforms INT4, FP4, NF4
🚀 tinygemm: Fast Small Batch CUDA kernels for BF16/FP16 • INT4/NF4/MX4/any4 quantization
✨ Effortless Scripts: Evaluate LLMs on NLP, Code, & Perplexity • Analyze & Visualize Weights/Activations
🤗 Hugging Face Compatible
This code release is meant to accompany our paper any4: Learned 4-bit Numeric Representation for LLMs, ICML 2025, by Mostafa Elhoushi and Jeff Johnson.
The technique and code for learning any4 representations and quantizing a model were authored by Mostafa Elhoushi (previously Meta FAIR SysML research). The Nvidia GPU tinygemm library was authored by Jeff Johnson (currently Meta FAIR SysML research). An extremely early version of the tinygemm kernels, without any4/MX4 support, was upstreamed to PyTorch core in Q4 2023 for use by the torch compiler.
What is any4?
There is a wide variety of 4-bit numerical formats implemented on CPU/GPU for ML inference, such as uniform int4 quantization, "fp4", NF4, AF4 and the like, all of which fix their dequantization values a priori. any4 instead uses a lookup table (LUT) to translate the 16 possible 4-bit quantization codes to arbitrary bfloat16 or float16 floating-point values; this LUT is held in GPU registers and consulted at dequantization time. Each row of a weight matrix can use a different 16-entry bfloat16/float16 LUT, so the quantization codes can be tailored to each row of the matrix. k-means or neural network based clustering is used to learn the any4 LUTs from the weight matrix's data distribution. Effectively, any4 is 4-bit grouped quantization like typical int4 quantization, except that the code dequantization values prior to scale and offset are arbitrary floating-point values from the LUT rather than integers in the range [-8, +7] or [0, 15]. any4 is thus a very efficient means of implementing NormalFloat4 (NF4) or AbnormalFloat4 (AF4), whose initial implementations used GPU-unfriendly deeply nested if/else blocks or switch statements.
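To make this concrete, here is a minimal PyTorch sketch of any4-style dequantization with a per-row LUT and group-wise scales. The function name and signature are illustrative only (they are not tinygemm's actual API), and per-group offsets are omitted for brevity:

```python
import torch

def any4_dequantize(codes, luts, scales, group_size=128):
    """Illustrative any4-style dequantization (not tinygemm's real kernel).

    codes:  (rows, cols) integer tensor of 4-bit codes in [0, 15]
    luts:   (rows, 16) per-row lookup table of learned bfloat16 values
    scales: (rows, cols // group_size) per-group scale factors
    """
    rows, cols = codes.shape
    # Translate each 4-bit code to its learned value via the row's own LUT.
    values = torch.gather(luts, 1, codes.long())
    # Apply group-wise scales, as in ordinary 4-bit grouped quantization.
    values = values.view(rows, cols // group_size, group_size) * scales.unsqueeze(-1)
    return values.view(rows, cols)

# Example: a 4 x 256 weight matrix with a different 16-entry LUT per row.
codes = torch.randint(0, 16, (4, 256))
luts = torch.randn(4, 16, dtype=torch.bfloat16)
scales = torch.rand(4, 256 // 128, dtype=torch.bfloat16)
w = any4_dequantize(codes, luts, scales)  # (4, 256) bfloat16 weights
```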
What is tinygemm?
The tinygemm low-latency GPU GEMM library implements any4 quantization; learning the any4 quantization codes is not part of tinygemm itself. While tinygemm supports nearly arbitrary GEMM sizes (assuming the reduction/`k` dimension is a multiple of 16 or 32), it is primarily meant for matrix multiplication problems where one of the `m` or `n` problem dimensions (for a `(m x k) x (n x k)^t` matrix multiplication) is *smaller* than a GPU tensor core tile size (e.g., 1 <= m <= 16 or 1 <= n <= 8), which is usually the case for the "activation" vector in neural networks.
tinygemm has two modes: one computes `Y = X W^t` and the other computes `Y = (W X^t)^t`. Both produce the same result; they differ only in whether the "weight" matrix is the "A" or the "B" matrix for tensor core usage. All needed transpositions are performed on the fly by tinygemm. For the m16n8k16 A100+ bf16/fp16 tensor core tile, the "A" matrix tile is 16 x 16 and the "B" tile is 8 x 16 (or 16 x 8 as desired). Putting the activations (e.g., a `1 x k` matrix) on the right and the weights on the left, so that the `1 x k` matrix occupies the "B" tile, ensures that the tensor core unit runs at 1/8th throughput rather than 1/16th throughput. We have found that using the tensor core in this fashion for, e.g., GEMV is quite fast. tinygemm does not use larger tensor core multiplication primitives (again, because a typical use case is something like a `(1 x k) x (n x k)^t` GEMM). All matrices presented to tinygemm must be row-major with the reduction dimension `k` innermost.
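As a quick sanity check that the two modes are algebraically equivalent (plain PyTorch, no tinygemm involved; float64 is used only to make the comparison exact up to rounding):

```python
import torch

# A 1 x k activation against an n x k row-major weight (k innermost).
k, n = 4096, 4096
x = torch.randn(1, k, dtype=torch.float64)
w = torch.randn(n, k, dtype=torch.float64)

y_right = x @ w.t()        # mode 1: Y = X W^t
y_left = (w @ x.t()).t()   # mode 2: Y = (W X^t)^t
torch.testing.assert_close(y_right, y_left)  # same result either way
```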
To further reduce latency, it is best to lay out weight matrices in "tensor core" format so that no shared memory transposition is needed. Because the weight matrix is usually not reused, the kernels avoid shared memory entirely for buffering or transposition and load data directly from gmem into registers, with some degree of multi-buffering into registers (though nvcc/ptxas' register usage heuristics are at odds with this); loads from gmem into a register remain asynchronous until the point of first use.
Please refer to the paper for additional details.
- Clone Repo
git clone git@github.com:fairinternal/any4.git
cd any4
- Setup Environment
conda create --name any4 python=3.10
conda activate any4
pip install -r requirements.txt
- Access Models
Some models (e.g., Llama) require permission. Follow these steps to access them:
a. Submit a request to access a Llama checkpoint, e.g., https://huggingface.co/meta-llama/Llama-3.2-1B.
b. Set up Hugging Face token access by following the steps described here.
c. Log in to Hugging Face by running the command below and entering the token you obtained in step b:
huggingface-cli login
- Install tinygemm kernels
cd tinygemm_lib
python setup.py install
cd ..
Most of the scripts below run the baseline fp16 model by default. To quantize, add the following arguments:
- `--model-args`: pass any arguments that should be forwarded to Hugging Face's `from_pretrained()`.
- `--quantize`: selects one of the (fake) quantization algorithms implemented in this codebase: `intq` (integer quantization), `fp4` (4-bit float quantization), `nf4` (4-bit NormalFloat quantization), or `any4` (the proposed lookup-table quantization).
- `--quantize-args`: comma-separated arguments passed to the quantization algorithm, e.g., `--quantize-args n_bit=4,group_size=128` performs 4-bit quantization with group size 128.
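For reference, here is a minimal sketch of what `n_bit=4,group_size=128` means in terms of symmetric group-wise integer (fake) quantization. It is illustrative only and not necessarily identical to the `intq` implementation in this codebase:

```python
import torch

def fake_quant_int4(w, group_size=128):
    # Round-to-nearest 4-bit quantization with one scale per group of
    # `group_size` consecutive values along each row, dequantized back to
    # the original dtype ("fake" quantization).
    rows, cols = w.shape
    g = w.view(rows, cols // group_size, group_size).float()
    scale = g.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / 7.0
    q = torch.clamp(torch.round(g / scale), -8, 7)
    return (q * scale).view(rows, cols).to(w.dtype)

w = torch.randn(256, 1024, dtype=torch.bfloat16)
w_q = fake_quant_int4(w)  # same shape and dtype, at most 16 levels per group
```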
You should now be able to run a code snippet like the one below, where you can quantize a model simply by calling `int4(...)`, `int8(...)`, `nf4(...)`, `fp4(...)`, or `any4(...)`:
from transformers import AutoModelForCausalLM, AutoTokenizer
from quantize import int4, any4, int8, nf4, fp4

# Load the baseline model and tokenizer, then quantize the model in place.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").cuda().bfloat16()
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = any4(model)

# Generate text with the quantized model.
inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, do_sample=True, max_new_tokens=256)
print(tokenizer.batch_decode(outputs)[0])
Feel free to also edit example.py and run it:
python example.py
We also encourage you to play around with our Jupyter Notebook tutorial.
Evaluate a model (with or without quantization) on downstream tasks.
- Baseline fp16 model:
python eval.py --model-name facebook/opt-125m --tasks piqa
- Quantized int4 model:
python eval.py --model-name facebook/opt-125m --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --tasks piqa
- Quantized any4 model:
python eval.py --model-name facebook/opt-125m --quantize anyq --quantize-args n_bit=4,sample_weight=calibrate,scale_sample_weight=True,skip_modules=lm_head --tasks piqa
Arguments:
- `--tasks`: by default, runs a large number of natural language, coding, and perplexity evaluation tasks.
  - You can specify a space-separated list of tasks, e.g., `--tasks piqa mbpp`.
  - You can pass in any task supported by Eleuther LM Eval Harness or BigCode Eval Harness, as well as any Hugging Face dataset to measure its perplexity.
To benchmark the runtime of a single linear layer with tinygemm's kernels, you can run:
python microbenchmark.py --input-dim 4096 --output-dim 4096 --batch-size 1 --quantize anyq
To benchmark a model end-to-end with tinygemm's kernels:
python benchmark.py --batch-size 1 --seqlen 1 --model-name meta-llama/Llama-3.2-1B --quantize anyq --quantize-args skip_modules=lm_head
To analyze weights and measure the mean squared error of weights and activations between the baseline and quantized models at each layer:
python analyze.py --model-name meta-llama/Llama-3.2-1B --quantize nf4 --quantize-args skip_modules=lm_head
To pass a dataset or prompt through a model and store the output activations of each layer:
python calibrate.py --model-name meta-llama/Llama-3.2-1B --dataset cerebras/SlimPajama-627B --num-batches 10
To pass a prompt through both a baseline model and a quantized model and measure the mean squared error at each layer:
python analyze.py --model-name meta-llama/Llama-3.2-1B --quantize anyq
To run all unit test cases:
python -m pytest .
In this section we provide the results from the paper and the command to reproduce each result.
Please note: you need to expand the "Commands to reproduce results" block below each table in order for the links to commands in each row to work.
Llama3.2 1B
| | WikiText-2↓ | C4↓ | PTB↓ | CodeParrot↓ | HumanEval↑ | MBPP↑ | MMLU↑ | HellaSwag↑ | GSM8K↑ | BBH↑ |
---|---|---|---|---|---|---|---|---|---|---|
FP16 [1] | 9.76 | 12.77 | 16.56 | 3.49 | 16.46% | 21.4% | 36.1% | 47.7% | 6.60% | 31.1% |
INT4 [2] | 11.89 | 15.74 | 20.32 | 4.08 | 9.76% | 11.4% | 30.1% | 44.7% | 3.18% | 26.2% |
FP4 [3] | 13.01 | 17.11 | 21.89 | 4.28 | 8.54% | 5.8% | 29.3% | 43.6% | 2.27% | 23.3% |
NF4 [4] | 10.99 | 14.63 | 18.78 | 3.82 | 13.4% | 13.8% | 33.3% | 45.8% | 3.65% | 26.8% |
ANY4 [5] | 10.63 | 13.95 | 17.94 | 3.71 | 11.0% | 18.6% | 32.9% | 46.7% | 3.71% | 29.0% |
Commands to reproduce results:
-
python eval.py --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3.2-1b/bf16
-
python eval.py --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3.2-1b/int4
-
python eval.py --quantize fp4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3.2-1b/fp4
-
python eval.py --quantize nf4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3.2-1b/nf4
-
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3.2-1b/any4
Llama3 8B
| | WikiText-2↓ | C4↓ | PTB↓ | CodeParrot↓ | HumanEval↑ | MBPP↑ | MMLU↑ | HellaSwag↑ | GSM8K↑ | BBH↑ |
---|---|---|---|---|---|---|---|---|---|---|
FP16 [6] | 6.14 | 8.93 | 10.59 | 2.54 | 29.3% | 41.4% | 62.0% | 60.1% | 50.7% | 62.8% |
INT4 [7] | 6.87 | 9.89 | 11.37 | 2.83 | 23.2% | 35.4% | 59.6% | 58.6% | 40.6% | 58.5% |
FP4 [8] | 7.10 | 10.22 | 11.81 | 2.89 | 22.0% | 36.8% | 57.1% | 58.5% | 35.0% | 53.2% |
NF4 [9] | 6.63 | 9.52 | 11.14 | 2.72 | 23.2% | 39.2% | 60.7% | 59.1% | 41.1% | 59.0% |
ANY4 [10] | 6.51 | 9.40 | 11.07 | 2.68 | 21.3% | 39.2% | 61.0% | 59.5% | 41.7% | 59.2% |
Commands to reproduce results:
-
python eval.py --model-name meta-llama/Meta-Llama-3-8B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-8b/bf16
-
python eval.py --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Meta-Llama-3-8B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-8b/int4
-
python eval.py --quantize fp4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Meta-Llama-3-8B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-8b/fp4
-
python eval.py --quantize nf4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Meta-Llama-3-8B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-8b/nf4
-
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Meta-Llama-3-8B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-8b/any4
Llama3 70B
| | WikiText-2↓ | C4↓ | PTB↓ | CodeParrot↓ | HumanEval↑ | MBPP↑ | MMLU↑ | HellaSwag↑ | GSM8K↑ | BBH↑ |
---|---|---|---|---|---|---|---|---|---|---|
FP16 [11] | 2.86 | 6.77 | 8.16 | 1.91 | 17.7% | 60.8% | 75.4% | 66.3% | 80.6% | 82.4% |
INT4 [12] | 3.63 | 7.97 | 8.86 | 2.21 | 18.3% | 45.0% | 73.0% | 66.2% | 73.9% | 78.4% |
FP4 [13] | 3.94 | 7.76 | 8.99 | 2.17 | 22.0% | 50.8% | 71.9% | 65.6% | 75.3% | 77.9% |
NF4 [14] | 3.43 | 7.67 | 8.84 | 2.15 | 18.9% | 39.6% | 73.7% | 66.1% | 75.9% | 79.3% |
ANY4 [15] | 3.20 | 7.01 | 8.33 | 1.99 | 17.1% | 57.4% | 75.1% | 66.1% | 78.5% | 81.8% |
Commands to reproduce results:
-
python eval.py --model-name meta-llama/Meta-Llama-3-70B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-70b/bf16
-
python eval.py --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Meta-Llama-3-70B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-70b/int4
-
python eval.py --quantize fp4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Meta-Llama-3-70B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-70b/fp4
-
python eval.py --quantize nf4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Meta-Llama-3-70B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-70b/nf4
-
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Meta-Llama-3-70B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-70b/any4
Llama2 7B
| | WikiText-2↓ | C4↓ | PTB↓ | CodeParrot↓ | HumanEval↑ | MBPP↑ | MMLU↑ | HellaSwag↑ | GSM8K↑ | BBH↑ |
---|---|---|---|---|---|---|---|---|---|---|
FP16 [16] | 5.47 | 6.97 | 20.83 | 2.54 | 17.1% | 20.0% | 41.3% | 57.2% | 13.6% | 39.8% |
INT4 [17] | 5.74 | 7.30 | 24.00 | 2.63 | 14.0% | 18.2% | 38.1% | 56.4% | 10.6% | 36.5% |
FP4 [18] | 5.83 | 7.37 | 22.57 | 2.65 | 11.0% | 16.8% | 36.5% | 56.6% | 11.2% | 35.5% |
NF4 [19] | 5.66 | 7.19 | 22.82 | 2.60 | 11.6% | 19.2% | 37.4% | 56.8% | 10.2% | 36.8% |
ANY4 [20] | 5.59 | 7.10 | 21.23 | 2.57 | 14.0% | 18.4% | 40.3% | 56.7% | 12.7% | 36.9% |
Commands to reproduce results:
-
python eval.py --model-name meta-llama/Llama-2-7b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-7b/fp16
-
python eval.py --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-7b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-7b/int4
-
python eval.py --quantize fp4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-7b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-7b/fp4
-
python eval.py --quantize nf4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-7b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-7b/nf4
-
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-2-7b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-7b/any4
Llama2 13B
| | WikiText-2↓ | C4↓ | PTB↓ | CodeParrot↓ | HumanEval↑ | MBPP↑ | MMLU↑ | HellaSwag↑ | GSM8K↑ | BBH↑ |
---|---|---|---|---|---|---|---|---|---|---|
FP16 [21] | 4.88 | 6.47 | 28.93 | 2.40 | 19.5% | 18.4% | 50.5% | 60.0% | 23.2% | 47.4% |
INT4 [22] | 5.05 | 6.65 | 30.79 | 2.45 | 15.2% | 16.4% | 48.8% | 59.3% | 20.8% | 44.2% |
FP4 [23] | 5.07 | 6.67 | 30.96 | 2.46 | 15.6% | 16.6% | 49.0% | 59.7% | 21.2% | 44.1% |
NF4 [24] | 4.99 | 6.58 | 31.17 | 2.43 | 15.9% | 16.6% | 49.9% | 59.9% | 22.1% | 44.6% |
ANY4 [25] | 4.97 | 6.55 | 28.83 | 2.42 | 15.2% | 18.0% | 49.3% | 59.5% | 21.6% | 44.6% |
Commands to reproduce results:
-
python eval.py --model-name meta-llama/Llama-2-13b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-13b/fp16
-
python eval.py --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-13b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-13b/int4
-
python eval.py --quantize fp4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-13b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-13b/fp4
-
python eval.py --quantize nf4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-13b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-13b/nf4
-
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-2-13b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-13b/any4
Llama2 70B
| | WikiText-2↓ | C4↓ | PTB↓ | CodeParrot↓ | HumanEval↑ | MBPP↑ | MMLU↑ | HellaSwag↑ | GSM8K↑ | BBH↑ |
---|---|---|---|---|---|---|---|---|---|---|
FP16 [26] | 3.32 | 5.52 | 14.44 | 2.11 | 31.7% | 37.4% | 65.2% | 64.8% | 53.3% | 67.1% |
INT4 [27] | 3.46 | 5.61 | 14.61 | 2.14 | 26.8% | 37.8% | 64.4% | 64.6% | 51.4% | 65.9% |
FP4 [28] | 3.53 | 5.67 | 14.34 | 2.16 | 29.0% | 36.8% | 63.6% | 63.9% | 51.2% | 65.5% |
NF4 [29] | 3.44 | 5.61 | 14.36 | 2.13 | 29.9% | 37.2% | 64.4% | 63.9% | 51.9% | 66.5% |
ANY4 [30] | 3.40 | 5.58 | 14.64 | 2.13 | 26.8% | 38.5% | 64.8% | 63.6% | 51.6% | 66.6% |
Commands to reproduce results:
-
python eval.py --model-name meta-llama/Llama-2-70b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-70b/fp16
-
python eval.py --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-70b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-70b/int4
-
python eval.py --quantize fp4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-70b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-70b/fp4
-
python eval.py --quantize nf4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-70b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-70b/nf4
-
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-2-70b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-70b/any4
Mistral-7B Instruct v0.2
| | WikiText-2↓ | C4↓ | PTB↓ | CodeParrot↓ | MMLU↑ | HellaSwag↑ | GSM8K↑ | BigBench↑ |
---|---|---|---|---|---|---|---|---|
FP16 [31] | 5.95 | 8.82 | 21.77 | 2.63 | 58.7% | 66.1% | 41.7% | 51.7% |
INT4 [32] | 6.14 | 9.03 | 22.02 | 2.70 | 57.1% | 65.1% | 39.7% | 50.4% |
FP4 [33] | 6.19 | 9.10 | 21.62 | 2.70 | 56.6% | 64.7% | 38.2% | 47.7% |
NF4 [34] | 6.06 | 8.93 | 24.72 | 2.66 | 58.0% | 65.5% | 38.5% | 51.8% |
ANY4 [35] | 6.00 | 8.85 | 23.24 | 2.64 | 58.6% | 65.4% | 41.1% | 51.7% |
Commands to reproduce results:
-
python eval.py --model-name mistralai/Mistral-7B-Instruct-v0.2 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mistral-7b-instruct-v0.2/fp16
-
python eval.py --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --model-name mistralai/Mistral-7B-Instruct-v0.2 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mistral-7b-instruct-v0.2/int4
-
python eval.py --quantize fp4 --quantize-args n_bit=4,skip_modules=lm_head --model-name mistralai/Mistral-7B-Instruct-v0.2 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mistral-7b-instruct-v0.2/fp4
-
python eval.py --quantize nf4 --quantize-args n_bit=4,skip_modules=lm_head --model-name mistralai/Mistral-7B-Instruct-v0.2 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mistral-7b-instruct-v0.2/nf4
-
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name mistralai/Mistral-7B-Instruct-v0.2 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mistral-7b-instruct-v0.2/any4
Mixtral-8x7B Instruct v0.1
| | WikiText-2↓ | C4↓ | PTB↓ | CodeParrot↓ | MMLU↑ | HellaSwag↑ | GSM8K↑ | BigBench↑ |
---|---|---|---|---|---|---|---|---|
FP16 [36] | 4.14 | 7.18 | 16.47 | 2.20 | 68.2% | 67.6% | 64.8% | 68.1% |
INT4 [37] | 4.35 | 7.45 | 16.84 | 2.26 | 66.5% | 66.3% | 57.8% | 61.8% |
FP4 [38] | 4.46 | 7.48 | 18.42 | 2.27 | 66.8% | 66.5% | 59.4% | 62.8% |
NF4 [39] | 4.30 | 7.32 | 15.00 | 2.24 | 67.6% | 67.2% | 61.0% | 66.5% |
ANY4 [40] | 4.27 | 7.27 | 16.14 | 2.22 | 67.7% | 67.1% | 62.8% | 65.8% |
Commands to reproduce results:
-
python eval.py --model-name mistralai/Mixtral-8x7B-Instruct-v0.1 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mixtral-8x7b-instruct-v0.1/fp16
-
python eval.py --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --model-name mistralai/Mixtral-8x7B-Instruct-v0.1 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mixtral-8x7b-instruct-v0.1/int4
-
python eval.py --quantize fp4 --quantize-args n_bit=4,skip_modules=lm_head --model-name mistralai/Mixtral-8x7B-Instruct-v0.1 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mixtral-8x7b-instruct-v0.1/fp4
-
python eval.py --quantize nf4 --quantize-args n_bit=4,skip_modules=lm_head --model-name mistralai/Mixtral-8x7B-Instruct-v0.1 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mixtral-8x7b-instruct-v0.1/nf4
-
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name mistralai/Mixtral-8x7B-Instruct-v0.1 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mixtral-8x7b-instruct-v0.1/any4
Group Size
Referencing the paper, Table 4: C4 perplexity after quantizing with different group sizes.
| | 64 | 128 | 256 | 512 | 1024 |
---|---|---|---|---|---|
FP4 | 16.19 [41] | 17.11 [42] | 18.12 [43] | 20.43 [44] | 2.3E6 [45] |
NF4 | 14.27 [46] | 14.63 [47] | 14.98 [48] | 15.38 [49] | 7.8E5 [50] |
ANY4 | 13.75 [51] | 13.95 [52] | 14.09 [53] | 14.24 [54] | 14.34 [55] |
Commands to reproduce results:
-
python eval.py --quantize fp4 --quantize-args n_bit=4,group_size=64,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
-
python eval.py --quantize fp4 --quantize-args n_bit=4,group_size=128,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
-
python eval.py --quantize fp4 --quantize-args n_bit=4,group_size=256,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
-
python eval.py --quantize fp4 --quantize-args n_bit=4,group_size=512,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
-
python eval.py --quantize fp4 --quantize-args n_bit=4,group_size=1024,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
-
python eval.py --quantize nf4 --quantize-args n_bit=4,group_size=64,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
-
python eval.py --quantize nf4 --quantize-args n_bit=4,group_size=128,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
-
python eval.py --quantize nf4 --quantize-args n_bit=4,group_size=256,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
-
python eval.py --quantize nf4 --quantize-args n_bit=4,group_size=512,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
-
python eval.py --quantize nf4 --quantize-args n_bit=4,group_size=1024,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
-
python eval.py --quantize anyq --quantize-args n_bit=4,group_size=64,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks c4
-
python eval.py --quantize anyq --quantize-args n_bit=4,group_size=128,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks c4
-
python eval.py --quantize anyq --quantize-args n_bit=4,group_size=256,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks c4
-
python eval.py --quantize anyq --quantize-args n_bit=4,group_size=512,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks c4
-
python eval.py --quantize anyq --quantize-args n_bit=4,group_size=1024,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks c4
Calibration Data
Referencing the paper, Table 3: any4 quantization with different calibration data.
| | Calibration Data | Number of Samples | Sequence Length per Sample | WikiText-2↓ | C4↓ | PTB↓ | CodeParrot↓ |
---|---|---|---|---|---|---|---|
ANY4 [57] | WikiText-2 | 128 | 2048 | 10.70 | 14.08 | 18.02 | 3.74 |
ANY4 [58] | Pile | 128 | 2048 | 10.70 | 13.99 | 18.26 | 3.74 |
ANY4 [59] | C4 | 128 | 2048 | 10.67 | 14.05 | 17.97 | 3.74 |
ANY4 [60] | C4 | 128 | 4096 | 10.74 | 14.14 | 18.10 | 3.75 |
ANY4 [61] | C4 | 128 | 512 | 10.62 | 13.96 | 18.03 | 3.72 |
ANY4 [62] | Handwritten Prompt | 1 | - | 10.63 | 13.95 | 17.94 | 3.71 |
Commands to reproduce results:
-
a. Offline Calibration:
python calibrate.py --dataset wikitext-2 --dataloader-type gptq --num-samples 128 --max-seq-len 2048 --model-name meta-llama/Llama-3.2-1B --log-dir ./calibrations/llama3.2-1b/wikitext2_128_2048/
followed by python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=./calibrations/llama3.2-1b/wikitext2_128_2048/wikitext-2.pt,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot
b. Online Calibration [will lead to different results]:
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --calibrate-args dataset=wikitext-2,dataloader_type=gptq,num_samples=128,max_seq_len=2048 --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot
-
a. Offline Calibration:
python calibrate.py --dataset monology/pile-uncopyrighted --num-samples 128 --max-seq-len 2048 --model-name meta-llama/Llama-3.2-1B --log-dir ./calibrations/llama3.2-1b/pile_128_2048/
followed by python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=./calibrations/llama3.2-1b/pile_128_2048/pile-uncopyrighted.pt,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot
b. Online Calibration [will lead to different results]:
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --calibrate-args dataset=monology/pile-uncopyrighted,num_samples=128,max_seq_len=2048 --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot
-
a. Offline Calibration:
python calibrate.py --dataset c4 --dataloader-type gptq --num-samples 128 --max-seq-len 2048 --model-name meta-llama/Llama-3.2-1B --log-dir ./calibrations/llama3.2-1b/c4_128_2048/
followed by python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=./calibrations/llama3.2-1b/c4_128_2048/c4.pt,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot
b. Online Calibration [will lead to different results]:
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --calibrate-args dataset=c4,dataloader_type=gptq,num_samples=128,max_seq_len=2048 --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot
-
a. Offline Calibration:
python calibrate.py --dataset c4 --dataloader-type gptq --num-samples 128 --max-seq-len 4096 --model-name meta-llama/Llama-3.2-1B --log-dir ./calibrations/llama3.2-1b/c4_128_4096/
followed by python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=./calibrations/llama3.2-1b/c4_128_4096/c4.pt,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot
b. Online Calibration [will lead to different results]:
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --calibrate-args dataset=c4,dataloader_type=gptq,num_samples=128,max_seq_len=4096 --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot
-
a. Offline Calibration:
python calibrate.py --dataset c4 --dataloader-type gptq --num-samples 128 --max-seq-len 512 --model-name meta-llama/Llama-3.2-1B --log-dir ./calibrations/llama3.2-1b/c4_128_512/
followed by python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=./calibrations/llama3.2-1b/c4_128_512/c4.pt,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot
b. Online Calibration [will lead to different results]:
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --calibrate-args dataset=c4,dataloader_type=gptq,num_samples=128,max_seq_len=512 --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot
-
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot
Term to Minimize
Perplexity after quantizing Llama3.2 1B with LUTs created by minimizing different terms. (A small illustrative sketch of weighted LUT learning follows the commands below.)
Term to Minimize | WikiText-2↓ | C4↓ | PTB↓ | CodeParrot↓ |
---|---|---|---|---|
Weights Only [63] | 11.143 | 14.740 | 18.715 | 3.858 |
Weights × Activations [64] | 10.636 | 13.954 | 18.031 | 3.719 |
Weights × Activations × Group Scales [65] [Ours] | 10.603 | 13.949 | 18.085 | 3.710 |
Commands to reproduce results:
-
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot
-
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot
-
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot
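For intuition on how these terms enter LUT learning, the sketch below shows per-row weighted k-means clustering; it is a hedged illustration only (the function name and the exact weighting are assumptions, and the repo's implementation may differ):

```python
import torch
from sklearn.cluster import KMeans

def learn_lut_for_row(w_row, sample_weight=None, n_codes=16):
    # Cluster one weight row into 16 centroids (that row's candidate LUT).
    # A per-weight importance (e.g., derived from calibration activations,
    # optionally scaled by group scales) biases the centroids toward the
    # values that matter most at inference time.
    x = w_row.detach().float().cpu().numpy().reshape(-1, 1)
    sw = None if sample_weight is None else sample_weight.detach().float().cpu().numpy()
    km = KMeans(n_clusters=n_codes, init="k-means++", n_init=10)
    km.fit(x, sample_weight=sw)
    return torch.from_numpy(km.cluster_centers_.reshape(-1))  # 16-entry LUT
```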
K-Means Initialization
Referencing the paper, Table A4: any4 quantization with k-means clustering initialized with different algorithms and values.
| | K-Means Initialization | WikiText-2↓ | C4↓ | PTB↓ |
---|---|---|---|---|
FP16 [66] | - | 9.76 | 12.77 | 16.56 |
ANY4 [67] | k-means++ | 10.63 | 13.95 | 17.94 |
ANY4 [68] | random | 10.66 | 13.97 | 18.17 |
ANY4 [69] | int4 | 10.83 | 14.21 | 18.69 |
ANY4 [70] | nf4 | 10.65 | 13.96 | 18.21 |
Commands to reproduce results:
-
python eval.py --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb
-
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True,init=k-means++ --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb
-
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True,init=random --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb
-
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True,init=int --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb
-
python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True,init=nf4 --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb
Referencing the paper, Figure 3: speedup of our tinygemm CUDA kernels on matrix multiplication of a 1 × K input by a K × K weight, relative to PyTorch's bfloat16 implementation. Please note that the results below were measured on an Nvidia A5000, while the paper's figure was based on an Nvidia A100.
Dimension ($DIM) | INT4 [71] | NF4 [73] | ANY4 [72] |
---|---|---|---|
1024 | 1.45x | 1.37x | 1.36x |
2048 | 2.75x | 2.17x | 2.32x |
3072 | 2.60x | 2.07x | 2.15x |
4096 | 3.26x | 2.23x | 2.29x |
5120 | 3.19x | 2.26x | 2.27x |
6144 | 3.40x | 2.27x | 2.23x |
7168 | 3.26x | 2.19x | 2.24x |
8192 | 3.52x | 2.24x | 2.25x |
Commands to reproduce results:
-
python microbenchmark.py --input-dim $DIM --output-dim $DIM --quantize intq
-
python microbenchmark.py --input-dim $DIM --output-dim $DIM --quantize anyq
-
python microbenchmark.py --input-dim $DIM --output-dim $DIM --quantize anyq --quantize-args per_row=False
(Note we have not yet implemented NF4 end-to-end modules. See #16).
We encourage contributions from the community. Please feel free to check our Issues for tasks to contribute to, especially our TODOs issue, and see our Contributing Guidelines.
tinygemm and any4 quantization code are CC-BY-NC 4.0 licensed, as found in the LICENSE file.
If you use the any4 quantization algorithm and/or the tinygemm quantization library, please use the following BibTeX entry:
@inproceedings{any4,
title={any4: Learned 4-bit Numeric Representation for {LLM}s},
author={Mostafa Elhoushi and Jeff Johnson},
booktitle={Forty-second International Conference on Machine Learning},
year={2025},
url={https://openreview.net/forum?id=tJmhOPkWCj}
}