

any4 & tinygemm



🧠 any4: Learns optimal 4-bit representation • Outperforms INT4, FP4, NF4
🚀 tinygemm: Fast Small Batch CUDA kernels for BF16/FP16 • INT4/NF4/MX4/any4 quantization
Effortless Scripts: Evaluate LLMs on NLP, Code, & Perplexity • Analyze & Visualize Weights/Activations
🤗 Hugging Face Compatible

This code release is meant to accompany our paper any4: Learned 4-bit Numeric Representation for LLMs, ICML 2025, by Mostafa Elhoushi and Jeff Johnson.

The technique and code for learning any4 representations and quantizing a model were authored by Mostafa Elhoushi (previously Meta FAIR SysML research). The Nvidia GPU tinygemm library was authored by Jeff Johnson (currently Meta FAIR SysML research). An extremely early version of the tinygemm kernels, without any4/MX4 support, was upstreamed to PyTorch core in Q4 2023 for use by the torch compiler.

What is any4? There is a wide variety of 4-bit numerical formats implemented on CPU/GPU for ML inference, such as uniform int4 quantization, "fp4", NF4, AF4 and the like, all of which fix their dequantization values a priori. any4 instead uses a lookup table (LUT) to translate the 16 possible 4-bit quantization codes to arbitrary bfloat16 or float16 floating-point values, and this GPU in-register LUT is applied at dequantization time. Each row of a weight matrix can use a different 16 x bfloat16/float16 LUT, so the quantization codes can be tailored to each row of the matrix. k-means or neural network based clustering is used to learn the any4 LUTs from the weight matrix data distribution. Effectively, any4 is 4-bit grouped quantization just like typical int4 quantization, except that the dequantization values prior to scale and offset, rather than being integers in the range [-8, +7] or [0, 15], are arbitrary floating-point values from the LUT. any4 is thus a very efficient means of implementing NormalFloat4 (NF4) or AbnormalFloat4 (AF4), whose initial implementations used GPU-unfriendly, deeply nested if/else blocks or switch statements.
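To make this concrete, below is a minimal, self-contained sketch in plain PyTorch (not tinygemm's actual layout or API; the names and shapes are illustrative) of what any4-style dequantization looks like: 4-bit codes per weight, one 16-entry LUT per row, and per-group scales, just like grouped int4 except the 16 dequantization values are learned floats.

import torch

n, k, group_size = 8, 256, 64
codes  = torch.randint(0, 16, (n, k), dtype=torch.uint8)        # one 4-bit code per weight element
luts   = torch.randn(n, 16, dtype=torch.bfloat16)               # one learned 16-entry LUT per row
scales = torch.rand(n, k // group_size, dtype=torch.bfloat16)   # per-group scales (offsets omitted)

def dequantize_any4(codes, luts, scales, group_size):
    # per-row LUT lookup: gather each row's codes from that row's own 16-entry table
    vals = torch.gather(luts, 1, codes.long())                   # (n, k) in bfloat16
    # re-apply group scales, as in ordinary grouped int4 quantization
    vals = vals.view(vals.shape[0], -1, group_size) * scales.unsqueeze(-1)
    return vals.view(codes.shape)

W_hat = dequantize_any4(codes, luts, scales, group_size)
print(W_hat.shape, W_hat.dtype)   # torch.Size([8, 256]) torch.bfloat16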
What is tinygemm? The tinygemm low-latency GPU GEMM library implements any4 quantization. Learning the any4 quantization codes is not part of tinygemm itself. While tinygemm supports almost any GEMM size (assuming the reduction/k dimension is a multiple of 16 or 32), it is primarily meant for matrix multiplication problems where one of the `m` or `n` problem dimensions (for a `(m x k) x (n x k)^t` matrix multiplication) is *smaller* than a GPU tensor core tile size (e.g., 1 <= m <= 16 or 1 <= n <= 8), usually the "activation" vector in neural networks.

tinygemm has two different modes, one that computes Y = X W^t and the other that computes Y = (W X^t)^t (both produce the same result; the difference is whether the "weight" matrix is the "A" or "B" matrix for tensor core usage). All needed transpositions are performed on the fly by tinygemm. For the m16n8k16 A100+ bf16/fp16 tensor core tile, the "A" matrix tile size is 16 x 16 and "B" is 8 x 16 (or 16 x 8 as desired). Putting activations (e.g., a 1 x k matrix) on the right and the weight on the left (so that the 1 x k matrix occupies the "B" tile) ensures that we run the tensor core unit at 1/8th throughput rather than 1/16th throughput. We have found that using the tensor core in this fashion for, e.g., GEMV is quite fast. tinygemm does not use larger tensor core multiplication primitives (again, because a typical use case is something like a (1 x k) x (n x k) GEMM). All matrices presented to tinygemm must be row-major with the reduction dimension k innermost.
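As a quick sanity check of the two modes (shown in plain PyTorch rather than tinygemm, with illustrative shapes), the following confirms that both formulations produce the same result for a decode-style GEMV shape:

import torch

m, n, k = 1, 4096, 4096                       # typical decode-time shape: a single activation row
X = torch.randn(m, k)                         # activations, row-major, k innermost
W = torch.randn(n, k)                         # weights, row-major, k innermost

y1 = X @ W.t()                                # mode 1: Y = X W^t
y2 = (W @ X.t()).t()                          # mode 2: Y = (W X^t)^t, weight on the left
print(torch.allclose(y1, y2, rtol=1e-4, atol=1e-3))   # True: both modes compute the same matrix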

To further reduce latency, it is best to lay out weight matrices in "tensor core" format, so that no shared memory transposition is needed. Because there is also typically no reuse of the weight matrix, we avoid shared memory entirely for buffering or transposition and the kernels load data directly from gmem into registers (with some degree of multi-buffering in registers, although nvcc/ptxas' register usage heuristics are at odds with this; loads from gmem into a register remain asynchronous until the point of first use).

Please refer to the paper for additional details.

Getting Started

  1. Clone Repo
git clone git@github.com:facebookresearch/any4.git

cd any4
  2. Setup Environment
conda create --name any4 python=3.10
conda activate any4

pip install -r requirements.txt
  3. Access Models

Some models (e.g., Llama) require permission. Follow these steps to access them:

a. Submit a request to access a Llama checkpoint, e.g., https://huggingface.co/meta-llama/Llama-3.2-1B.

b. Set up Hugging Face token access by following the steps described here.

c. Then you will be able to log in to Hugging Face by running the command below and entering the token you obtained in step b above:

huggingface-cli login
  4. Install tinygemm kernels
cd tinygemm_lib
python setup.py install
cd ..

Run

Most of the scripts below run the baseline fp16 model by default. To quantize, add the following arguments:

  • --model-args: pass any arguments that should be forwarded to Hugging Face's from_pretrained().
  • --quantize: selects one of the (fake) quantization algorithms implemented in this codebase. It can take: intq (integer quantization), fp4 (4-bit float quantization), nf4 (4-bit normal float quantization), any4 (the proposed lookup table quantization).
    • --quantize-args: comma-separated arguments to pass to the quantization algorithm, e.g., --quantize-args n_bit=4,group_size=128 will perform 4-bit quantization with group size 128.

Quick Example

You should now be able to run a code snippet like the one below, where you can quantize a model simply by calling int4(...), int8(...), nf4(...), fp4(...), or any4(...).

from transformers import AutoModelForCausalLM, AutoTokenizer
from quantize import int4, any4, int8, nf4, fp4

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").cuda().bfloat16()
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

model = any4(model)

inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, do_sample=True, max_new_tokens=256)
print(tokenizer.batch_decode(outputs)[0])

Feel free to also edit example.py and run it:

python example.py

and we encourage you to play around with our Jupyter Notebook tutorial.

Evaluation

Evaluate a model (with or without quantization) on downstream tasks.

  • Baseline fp16 model:
python eval.py --model-name facebook/opt-125m --tasks piqa
  • Quantized int4 model:
python eval.py --model-name facebook/opt-125m --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --tasks piqa
  • Quantized any4 model:
python eval.py --model-name facebook/opt-125m --quantize anyq --quantize-args n_bit=4,sample_weight=calibrate,scale_sample_weight=True,skip_modules=lm_head --tasks piqa

Arguments:

  • --tasks: by default, it runs a large number of natural language, coding, and perplexity evaluation tasks.

Benchmark

To benchmark the runtime of a single linear layer with tinygemm's kernels, you can run:

python microbenchmark.py --input-dim 4096 --output-dim 4096 --batch-size 1 --quantize anyq

To benchmark a model end-to-end with tinygemm's kernels:

python benchmark.py --batch-size 1 --seqlen 1 --model-name meta-llama/Llama-3.2-1B --quantize anyq --quantize-args skip_modules=lm_head
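For reference, the sketch below shows roughly the kind of latency measurement microbenchmark.py performs for a single layer, hand-rolled here in plain PyTorch with CUDA events against an ordinary bf16 linear layer (it does not use tinygemm's kernels; the dimensions are illustrative):

import torch

batch, in_dim, out_dim = 1, 4096, 4096
x = torch.randn(batch, in_dim, device="cuda", dtype=torch.bfloat16)
linear = torch.nn.Linear(in_dim, out_dim, bias=False, device="cuda", dtype=torch.bfloat16)

# warm up, then time with CUDA events
for _ in range(10):
    linear(x)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
iters = 100
for _ in range(iters):
    linear(x)
end.record()
torch.cuda.synchronize()
print(f"bf16 linear: {start.elapsed_time(end) / iters * 1e3:.1f} us/iter")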

Analyze

To analyze weights, and to measure the mean squared error between the baseline model and the quantized model on both weights and activations at each layer:

python analyze.py --model-name meta-llama/Llama-3.2-1B --quantize nf4 --quantize-args skip_modules=lm_head
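A rough sketch of the kind of per-layer comparison this performs (illustrative only, not analyze.py's actual implementation): run the same prompt through the baseline and quantized models, capture each layer's output with forward hooks, and report the mean squared error per layer. It assumes layer names are preserved by quantization.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from quantize import nf4

name = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(name)
baseline = AutoModelForCausalLM.from_pretrained(name).cuda().bfloat16()
quantized = nf4(AutoModelForCausalLM.from_pretrained(name).cuda().bfloat16())

# capture the output of every leaf module, keyed by module name
captured = {"baseline": {}, "quantized": {}}
def make_hook(store, layer_name):
    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor):
            store[layer_name] = output.detach().float().cpu()
    return hook

for tag, model in (("baseline", baseline), ("quantized", quantized)):
    for layer_name, module in model.named_modules():
        if len(list(module.children())) == 0:          # leaf modules only
            module.register_forward_hook(make_hook(captured[tag], layer_name))

inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda")
with torch.no_grad():
    baseline(**inputs)
    quantized(**inputs)

for layer_name, ref in captured["baseline"].items():
    out = captured["quantized"].get(layer_name)
    if out is not None and out.shape == ref.shape:
        print(f"{layer_name}: mse={(ref - out).pow(2).mean().item():.3e}")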

Calibrate

To pass a dataset or prompt through a model and store the output activations of each layer:

python calibrate.py --model-name meta-llama/Llama-3.2-1B --dataset cerebras/SlimPajama-627B --num-batches 10

Diff

To pass a prompt through both a baseline model and a quantized model and measure the mean squared error at each layer:

python analyze.py --model-name meta-llama/Llama-3.2-1B --quantize anyq

Test

To run all unit test cases:

python -m pytest .

Experiments

In this section we provide the results in the paper and the command to reproduce each result.

Please note: you need to expand the "Commands to reproduce results" block below each table in order for the links to commands in each row to work.

Accuracy Results

Llama3.2 1B

Method WikiText-2↓ C4↓ PTB↓ CodeParrot↓ HumanEval↑ MBPP↑ MMLU↑ HellaSwag↑ GSM8K↑ BBH↑
FP16 [1] 9.76 12.77 16.56 3.49 16.46% 21.4% 36.1% 47.7% 6.60% 31.1%
INT4 [2] 11.89 15.74 20.32 4.08 9.76% 11.4% 30.1% 44.7% 3.18% 26.2%
FP4 [3] 13.01 17.11 21.89 4.28 8.54% 5.8% 29.3% 43.6% 2.27% 23.3%
NF4 [4] 10.99 14.63 18.78 3.82 13.4% 13.8% 33.3% 45.8% 3.65% 26.8%
ANY4 [5] 10.63 13.95 17.94 3.71 11.0% 18.6% 32.9% 46.7% 3.71% 29.0%
Commands to reproduce results:
  1. python eval.py --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3.2-1b/bf16
  2. python eval.py --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3.2-1b/int4
  3. python eval.py --quantize fp4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3.2-1b/fp4
  4. python eval.py --quantize nf4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3.2-1b/nf4
  5. python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3.2-1b/any4

Llama3 8B

Method WikiText-2↓ C4↓ PTB↓ CodeParrot↓ HumanEval↑ MBPP↑ MMLU↑ HellaSwag↑ GSM8K↑ BBH↑
FP16 [6] 6.14 8.93 10.59 2.54 29.3% 41.4% 62.0% 60.1% 50.7% 62.8%
INT4 [7] 6.87 9.89 11.37 2.83 23.2% 35.4% 59.6% 58.6% 40.6% 58.5%
FP4 [8] 7.10 10.22 11.81 2.89 22.0% 36.8% 57.1% 58.5% 35.0% 53.2%
NF4 [9] 6.63 9.52 11.14 2.72 23.2% 39.2% 60.7% 59.1% 41.1% 59.0%
ANY4 [10] 6.51 9.40 11.07 2.68 21.3% 39.2% 61.0% 59.5% 41.7% 59.2%
Commands to reproduce results:
  1. python eval.py --model-name meta-llama/Meta-Llama-3-8B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-8b/bf16
  2. python eval.py --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Meta-Llama-3-8B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-8b/int4
  3. python eval.py --quantize fp4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Meta-Llama-3-8B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-8b/fp4
  4. python eval.py --quantize nf4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Meta-Llama-3-8B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-8b/nf4
  5. python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Meta-Llama-3-8B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-8b/any4

Llama3 70B

Method WikiText-2↓ C4↓ PTB↓ CodeParrot↓ HumanEval↑ MBPP↑ MMLU↑ HellaSwag↑ GSM8K↑ BBH↑
FP16 [11] 2.86 6.77 8.16 1.91 17.7% 60.8% 75.4% 66.3% 80.6% 82.4%
INT4 [12] 3.63 7.97 8.86 2.21 18.3% 45.0% 73.0% 66.2% 73.9% 78.4%
FP4 [13] 3.94 7.76 8.99 2.17 22.0% 50.8% 71.9% 65.6% 75.3% 77.9%
NF4 [14] 3.43 7.67 8.84 2.15 18.9% 39.6% 73.7% 66.1% 75.9% 79.3%
ANY4 [15] 3.20 7.01 8.33 1.99 17.1% 57.4% 75.1% 66.1% 78.5% 81.8%
Commands to reproduce results:
  1. python eval.py --model-name meta-llama/Meta-Llama-3-70B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-70b/bf16
  2. python eval.py --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Meta-Llama-3-70B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-70b/int4
  3. python eval.py --quantize fp4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Meta-Llama-3-70B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-70b/fp4
  4. python eval.py --quantize nf4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Meta-Llama-3-70B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-70b/nf4
  5. python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Meta-Llama-3-70B --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama3-70b/any4

Llama2 7B

Method WikiText-2↓ C4↓ PTB↓ CodeParrot↓ HumanEval↑ MBPP↑ MMLU↑ HellaSwag↑ GSM8K↑ BBH↑
FP16 [16] 5.47 6.97 20.83 2.54 17.1% 20.0% 41.3% 57.2% 13.6% 39.8%
INT4 [17] 5.74 7.30 24.00 2.63 14.0% 18.2% 38.1% 56.4% 10.6% 36.5%
FP4 [18] 5.83 7.37 22.57 2.65 11.0% 16.8% 36.5% 56.6% 11.2% 35.5%
NF4 [19] 5.66 7.19 22.82 2.60 11.6% 19.2% 37.4% 56.8% 10.2% 36.8%
ANY4 [20] 5.59 7.10 21.23 2.57 14.0% 18.4% 40.3% 56.7% 12.7% 36.9%
Commands to reproduce results:
  1. python eval.py --model-name meta-llama/Llama-2-7b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-7b/fp16
  2. python eval.py --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-7b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-7b/int4
  3. python eval.py --quantize fp4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-7b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-7b/fp4
  4. python eval.py --quantize nf4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-7b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-7b/nf4
  5. python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-2-7b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-7b/any4

Llama2 13B

Method WikiText-2↓ C4↓ PTB↓ CodeParrot↓ HumanEval↑ MBPP↑ MMLU↑ HellaSwag↑ GSM8K↑ BBH↑
FP16 [21] 4.88 6.47 28.93 2.40 19.5% 18.4% 50.5% 60.0% 23.2% 47.4%
INT4 [22] 5.05 6.65 30.79 2.45 15.2% 16.4% 48.8% 59.3% 20.8% 44.2%
FP4 [23] 5.07 6.67 30.96 2.46 15.6% 16.6% 49.0% 59.7% 21.2% 44.1%
NF4 [24] 4.99 6.58 31.17 2.43 15.9% 16.6% 49.9% 59.9% 22.1% 44.6%
ANY4 [25] 4.97 6.55 28.83 2.42 15.2% 18.0% 49.3% 59.5% 21.6% 44.6%
Commands to reproduce results:
  1. python eval.py --model-name meta-llama/Llama-2-13b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-13b/fp16
  2. python eval.py --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-13b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-13b/int4
  3. python eval.py --quantize fp4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-13b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-13b/fp4
  4. python eval.py --quantize nf4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-13b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-13b/nf4
  5. python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-2-13b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-13b/any4

Llama2 70B

Method WikiText-2↓ C4↓ PTB↓ CodeParrot↓ HumanEval↑ MBPP↑ MMLU↑ HellaSwag↑ GSM8K↑ BBH↑
FP16 [26] 3.32 5.52 14.44 2.11 31.7% 37.4% 65.2% 64.8% 53.3% 67.1%
INT4 [27] 3.46 5.61 14.61 2.14 26.8% 37.8% 64.4% 64.6% 51.4% 65.9%
FP4 [28] 3.53 5.67 14.34 2.16 29.0% 36.8% 63.6% 63.9% 51.2% 65.5%
NF4 [29] 3.44 5.61 14.36 2.13 29.9% 37.2% 64.4% 63.9% 51.9% 66.5%
ANY4 [30] 3.40 5.58 14.64 2.13 26.8% 38.5% 64.8% 63.6% 51.6% 66.6%
Commands to reproduce results:
  1. python eval.py --model-name meta-llama/Llama-2-70b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-70b/fp16
  2. python eval.py --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-70b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-70b/int4
  3. python eval.py --quantize fp4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-70b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-70b/fp4
  4. python eval.py --quantize nf4 --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-2-70b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-70b/nf4
  5. python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-2-70b-hf --tasks wikitext-2 c4 ptb codeparrot humaneval mbpp mmlu hellaswag gsm8k bbh --log-dir ./logs/llama2-70b/any4

Mistral-7B Instruct v0.2

Method WikiText-2↓ C4↓ PTB↓ CodeParrot↓ MMLU↑ HellaSwag↑ GSM8K↑ BigBench↑
FP16 [31] 5.95 8.82 21.77 2.63 58.7% 66.1% 41.7% 51.7%
INT4 [32] 6.14 9.03 22.02 2.70 57.1% 65.1% 39.7% 50.4%
FP4 [33] 6.19 9.10 21.62 2.70 56.6% 64.7% 38.2% 47.7%
NF4 [34] 6.06 8.93 24.72 2.66 58.0% 65.5% 38.5% 51.8%
ANY4 [35] 6.00 8.85 23.24 2.64 58.6% 65.4% 41.1% 51.7%
Commands to reproduce results:
  1. python eval.py --model-name mistralai/Mistral-7B-Instruct-v0.2 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mistral-7b-instruct-v0.2/fp16
  2. python eval.py --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --model-name mistralai/Mistral-7B-Instruct-v0.2 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mistral-7b-instruct-v0.2/int4
  3. python eval.py --quantize fp4 --quantize-args n_bit=4,skip_modules=lm_head --model-name mistralai/Mistral-7B-Instruct-v0.2 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mistral-7b-instruct-v0.2/fp4
  4. python eval.py --quantize nf4 --quantize-args n_bit=4,skip_modules=lm_head --model-name mistralai/Mistral-7B-Instruct-v0.2 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mistral-7b-instruct-v0.2/nf4
  5. python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name mistralai/Mistral-7B-Instruct-v0.2 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mistral-7b-instruct-v0.2/any4

Mixtral-8x7B Instruct v0.1

Method WikiText-2↓ C4↓ PTB↓ CodeParrot↓ MMLU↑ HellaSwag↑ GSM8K↑ BigBench↑
FP16 [36] 4.14 7.18 16.47 2.20 68.2% 67.6% 64.8% 68.1%
INT4 [37] 4.35 7.45 16.84 2.26 66.5% 66.3% 57.8% 61.8%
FP4 [38] 4.46 7.48 18.42 2.27 66.8% 66.5% 59.4% 62.8%
NF4 [39] 4.30 7.32 15.00 2.24 67.6% 67.2% 61.0% 66.5%
ANY4 [40] 4.27 7.27 16.14 2.22 67.7% 67.1% 62.8% 65.8%
Commands to reproduce results:
  1. python eval.py --model-name mistralai/Mixtral-8x7B-Instruct-v0.1 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mixtral-8x7b-instruct-v0.1/fp16
  2. python eval.py --quantize intq --quantize-args n_bit=4,skip_modules=lm_head --model-name mistralai/Mixtral-8x7B-Instruct-v0.1 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mixtral-8x7b-instruct-v0.1/int4
  3. python eval.py --quantize fp4 --quantize-args n_bit=4,skip_modules=lm_head --model-name mistralai/Mixtral-8x7B-Instruct-v0.1 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mixtral-8x7b-instruct-v0.1/fp4
  4. python eval.py --quantize nf4 --quantize-args n_bit=4,skip_modules=lm_head --model-name mistralai/Mixtral-8x7B-Instruct-v0.1 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mixtral-8x7b-instruct-v0.1/nf4
  5. python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name mistralai/Mixtral-8x7B-Instruct-v0.1 --tasks wikitext-2 c4 ptb codeparrot mmlu hellaswag gsm8k bigbench --log-dir ./logs/mixtral-8x7b-instruct-v0.1/any4

Ablation Studies

Group Size. Referencing the paper, Table 4: C4 perplexity after quantizing with different group sizes.

Method 64 128 256 512 1024
FP4 16.19 [41] 17.11 [42] 18.12 [43] 20.43 [44] 2.3E6 [45]
NF4 14.27 [46] 14.63 [47] 14.98 [48] 15.38 [49] 7.8E5 [50]
ANY4 13.75 [51] 13.95 [52] 14.09 [53] 14.24 [54] 14.34 [55]
Commands to reproduce results:
  1. python eval.py --quantize fp4 --quantize-args n_bit=4,group_size=64,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
  2. python eval.py --quantize fp4 --quantize-args n_bit=4,group_size=128,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
  3. python eval.py --quantize fp4 --quantize-args n_bit=4,group_size=256,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
  4. python eval.py --quantize fp4 --quantize-args n_bit=4,group_size=512,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
  5. python eval.py --quantize fp4 --quantize-args n_bit=4,group_size=1024,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
  6. python eval.py --quantize nf4 --quantize-args n_bit=4,group_size=64,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
  7. python eval.py --quantize nf4 --quantize-args n_bit=4,group_size=128,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
  8. python eval.py --quantize nf4 --quantize-args n_bit=4,group_size=256,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
  9. python eval.py --quantize nf4 --quantize-args n_bit=4,group_size=512,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
  10. python eval.py --quantize nf4 --quantize-args n_bit=4,group_size=1024,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks c4
  11. python eval.py --quantize anyq --quantize-args n_bit=4,group_size=64,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks c4
  12. python eval.py --quantize anyq --quantize-args n_bit=4,group_size=128,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks c4
  13. python eval.py --quantize anyq --quantize-args n_bit=4,group_size=256,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks c4
  14. python eval.py --quantize anyq --quantize-args n_bit=4,group_size=512,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks c4
  15. python eval.py --quantize anyq --quantize-args n_bit=4,group_size=1024,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks c4

Calibration Data. Referencing the paper, Table 3: any4 quantization with different calibration data.

Method Calibration Data Number of Samples Sequence Length per Sample WikiText-2↓ C4↓ PTB↓ CodeParrot↓
ANY4 [57] WikiText-2 128 2048 10.70 14.08 18.02 3.74
ANY4 [58] Pile 128 2048 10.70 13.99 18.26 3.74
ANY4 [59] C4 128 2048 10.67 14.05 17.97 3.74
ANY4 [60] C4 128 4096 10.74 14.14 18.10 3.75
ANY4 [61] C4 128 512 10.62 13.96 18.03 3.72
ANY4 [62] Handwritten Prompt 1 - 10.63 13.95 17.94 3.71
Commands to reproduce results:
  1. a. Offline Calibration: python calibrate.py --dataset wikitext-2 --dataloader-type gptq --num-samples 128 --max-seq-len 2048 --model-name meta-llama/Llama-3.2-1B --log-dir ./calibrations/llama3.2-1b/wikitext2_128_2048/ followed by python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=./calibrations/llama3.2-1b/wikitext2_128_2048/wikitext-2.pt,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot

    b. Online Calibration [will lead to different results]: python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --calibrate-args dataset=wikitext-2,dataloader_type=gptq,num_samples=128,max_seq_len=2048 --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot

  2. a. Offline Calibration: python calibrate.py --dataset monology/pile-uncopyrighted --num-samples 128 --max-seq-len 2048 --model-name meta-llama/Llama-3.2-1B --log-dir ./calibrations/llama3.2-1b/pile_128_2048/ followed by python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=./calibrations/llama3.2-1b/pile_128_2048/pile-uncopyrighted.pt,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot

    b. Online Calibration [will lead to different results]: python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --calibrate-args dataset=monology/pile-uncopyrighted,num_samples=128,max_seq_len=2048 --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot

  3. a. Offline Calibration: python calibrate.py --dataset c4 --dataloader-type gptq --num-samples 128 --max-seq-len 2048 --model-name meta-llama/Llama-3.2-1B --log-dir ./calibrations/llama3.2-1b/c4_128_2048/ followed by python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=./calibrations/llama3.2-1b/c4_128_2048/c4.pt,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot

    b. Online Calibration [will lead to different results]: python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --calibrate-args dataset=c4,dataloader_type=gptq,num_samples=128,max_seq_len=2048 --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot

  4. a. Offline Calibration: python calibrate.py --dataset c4 --dataloader-type gptq --num-samples 128 --max-seq-len 4096 --model-name meta-llama/Llama-3.2-1B --log-dir ./calibrations/llama3.2-1b/c4_128_4096/ followed by python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=./calibrations/llama3.2-1b/c4_128_4096/c4.pt,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot

    b. Online Calibration [will lead to different results]: python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --calibrate-args dataset=c4,dataloader_type=gptq,num_samples=128,max_seq_len=4096 --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot

  5. a. Offline Calibration: python calibrate.py --dataset c4 --dataloader-type gptq --num-samples 128 --max-seq-len 512 --model-name meta-llama/Llama-3.2-1B --log-dir ./calibrations/llama3.2-1b/c4_128_512/ followed by python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=./calibrations/llama3.2-1b/c4_128_512/c4.pt,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot

    b. Online Calibration [will lead to different results]: python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --calibrate-args dataset=c4,dataloader_type=gptq,num_samples=128,max_seq_len=512 --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot

  6. python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot

Term to Minimize. Perplexity after quantizing Llama3.2 1B with LUTs created by minimizing different terms (a toy sketch of this weighted-clustering idea follows the reproduction commands below).

Term to Minimize Expression WikiText-2↓ C4↓ PTB↓ CodeParrot↓
Weights Only [63] $(w_{S_{i,j}} - w_{Q_{i,j}})$ 11.143 14.740 18.715 3.858
Weights × Activations [64] $(w_{S_{i,j}}x_j - w_{Q_{i,j}}x_j)$ 10.636 13.954 18.031 3.719
Weights × Activations × Group Scales [65] [Ours] $(\alpha_{i,j}w_{S_{i,j}}x_j - \alpha_{i,j}w_{Q_{i,j}}x_j)$ 10.603 13.949 18.085 3.710
Commands to reproduce results:
  1. python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot
  2. python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot
  3. python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb codeparrot
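The sketch below is a hedged illustration (not the repo's implementation; names and shapes are made up) of the idea behind this ablation: the 16 LUT entries for a row are fit by k-means, and weighting each weight element by its activation magnitude (and, optionally, by the group scale) steers the clustering toward accurately reconstructing the elements that contribute most to the layer output.

import torch

def weighted_kmeans_1d(values, sample_weight, n_centroids=16, iters=25):
    # initialize centroids at evenly spaced quantiles of the data
    lut = torch.quantile(values, torch.linspace(0, 1, n_centroids))
    for _ in range(iters):
        codes = (values[:, None] - lut[None, :]).abs().argmin(dim=1)
        for c in range(n_centroids):
            mask = codes == c
            if mask.any():
                w = sample_weight[mask]
                lut[c] = (w * values[mask]).sum() / w.sum()    # weighted centroid update
    return lut, codes

w_row = torch.randn(4096)                 # one (already scaled) row of a weight matrix
x_mag = torch.rand(4096) + 0.1            # e.g., mean |activation| per input channel from calibration

lut_plain, _ = weighted_kmeans_1d(w_row, torch.ones_like(w_row))   # "weights only" objective
lut_activ, _ = weighted_kmeans_1d(w_row, x_mag)                    # "weights x activations" objective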

K-Means Initialization. Referencing the paper, Table A4: any4 quantization with k-means clustering initialized with different algorithms and values.

Method K-Means Initialization WikiText-2↓ C4↓ PTB↓
FP16 [66] - 9.76 12.77 16.56
ANY4 [67] k-means++ 10.63 13.95 17.94
ANY4 [68] random 10.66 13.97 18.17
ANY4 [69] int4 10.83 14.21 18.69
ANY4 [70] nf4 10.65 13.96 18.21
Commands to reproduce results:
  1. python eval.py --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb
  2. python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True,init=k-means++ --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb
  3. python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True,init=random --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb
  4. python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True,init=int --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb
  5. python eval.py --quantize anyq --quantize-args n_bit=4,skip_modules=lm_head,sample_weight=calibrate,scale_sample_weight=True,init=nf4 --model-name meta-llama/Llama-3.2-1B --tasks wikitext-2 c4 ptb

Benchmarking

Referencing the paper, Figure 3: speedup of our tinygemm CUDA kernels on matrix multiplication of a 1 × K input by a K × K weight, w.r.t. PyTorch's bfloat16 implementation. Please note that the results below are from an Nvidia A5000, while the paper's figure was based on an Nvidia A100.

Microbenchmark Results

Dimension ($DIM) INT4 [71] NF4 [73] ANY4 [72]
1024 1.45x 1.37x 1.36x
2048 2.75x 2.17x 2.32x
3072 2.60x 2.07x 2.15x
4096 3.26x 2.23x 2.29x
5120 3.19x 2.26x 2.27x
6144 3.40x 2.27x 2.23x
7168 3.26x 2.19x 2.24x
8192 3.52x 2.24x 2.25x
Commands to reproduce results:
  1. python microbenchmark.py --input-dim $DIM --output-dim $DIM --quantize intq
  2. python microbenchmark.py --input-dim $DIM --output-dim $DIM --quantize anyq
  3. python microbenchmark.py --input-dim $DIM --output-dim $DIM --quantize anyq --quantize-args per_row=False (Note we have not yet implemented NF4 end-to-end modules. See #16).

Contribution

We encourage contributions from the community. Please feel free to check our Issues for tasks to contribute to, especially our TODOs issue, and see our Contributing Guidelines.

License

tinygemm and any4 quantization code are CC-BY-NC 4.0 licensed, as found in the LICENSE file.

Citation

If you use the any4 quantization algorithm and/or the tinygemm GEMM library, please use the following BibTeX entry:

@inproceedings{any4,
    title={any4: Learned 4-bit Numeric Representation for {LLM}s},
    author={Mostafa Elhoushi and Jeff Johnson},
    booktitle={Forty-second International Conference on Machine Learning},
    year={2025},
    url={https://openreview.net/forum?id=tJmhOPkWCj}
}
