[MoE] Cleanup MoE examples #1576

Draft · wants to merge 20 commits into main
1 change: 1 addition & 0 deletions README.md
@@ -16,6 +16,7 @@

Big updates have landed in LLM Compressor! Check out these exciting new features:

* **DeepSeekV3 and Sequential Onloading Support:** As of llm-compressor>=0.6.0, you can now quantize DeepSeekV3 and other large models on a single GPU. Models are broken into disjoint layers, which are onloaded to the GPU one layer at a time. For more information on sequential onloading, see [Big Modeling with Sequential Onloading](examples/big_models_with_sequential_onloading/README.md) as well as the [DeepSeekV3 Example](examples/quantizing_moe/deepseekv3_example.py).
* **Preliminary FP4 Quantization Support:** Quantize weights and activations to FP4 and seamlessly run the compressed model in vLLM. Model weights and activations are quantized following the NVFP4 [configuration](https://github.com/neuralmagic/compressed-tensors/blob/f5dbfc336b9c9c361b9fe7ae085d5cb0673e56eb/src/compressed_tensors/quantization/quant_scheme.py#L104). See examples of [weight-only quantization](examples/quantization_w4a16_fp4/llama3_example.py) and [FP4 activation support](examples/quantization_w4a4_fp4/llama3_example.py); a minimal weight-only sketch also follows this list. Support is currently preliminary; additional support for MoE models is planned.
* **Axolotl Sparse Finetuning Integration:** Seamlessly finetune sparse LLMs with our Axolotl integration. Learn how to create [fast sparse open-source models with Axolotl and LLM Compressor](https://developers.redhat.com/articles/2025/06/17/axolotl-meets-llm-compressor-fast-sparse-open). See also the [Axolotl integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#llmcompressor).
* **AutoAWQ Integration:** Perform low-bit weight-only quantization efficiently using AutoAWQ, now part of LLM Compressor. *Note: This integration should be considered experimental for now. Enhanced support, including for MoE models and improved handling of larger models via layer sequential pipelining, is planned for upcoming releases.* [See the details](https://github.com/vllm-project/llm-compressor/pull/1177).
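
As a quick illustration of the preliminary FP4 flow listed above, here is a minimal weight-only sketch. It assumes the `NVFP4A16` preset scheme name from the linked compressed-tensors configuration and that data-free calibration suffices for weight-only FP4; the maintained version lives in `examples/quantization_w4a16_fp4/llama3_example.py`, and the model ID is only an example.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Example model; any causal LM from the Hub or a local path can be substituted.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Weight-only FP4: the "NVFP4A16" preset name is assumed here, and no
# calibration dataset is passed since only weights are quantized.
recipe = QuantizationModifier(targets="Linear", scheme="NVFP4A16", ignore=["lm_head"])

oneshot(model=model, recipe=recipe)

# Save in compressed-tensors format for loading in vLLM.
SAVE_DIR = MODEL_ID.split("/")[-1] + "-NVFP4A16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
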
44 changes: 22 additions & 22 deletions examples/quantizing_moe/README.md
@@ -1,6 +1,6 @@
# Quantizing Mixtral-8x7B-Instruct-v0.1 Model with FP8
# Quantizing Mixtral-8x7B-Instruct-v0.1 Model with W4A16

This directory contains an example script for quantizing the `Mixtral-8x7B-Instruct-v0.1` model using the static per-tensor FP8 quantization scheme.
This directory contains an example script for quantizing the `Mixtral-8x7B-Instruct-v0.1` model using the W4A16 (4-bit weight, 16-bit activation) quantization scheme.

## Installation

@@ -17,17 +17,17 @@ pip install -e .
The provided example script demonstrates an end-to-end process for applying the quantization algorithm:

```bash
python3 mixtral_moe_w8a8_fp8.py
python3 mixtral_example.py
```

## Creating a Quantized MoE Model

This example leverages `llm-compressor` and `compressed-tensors` to create an FP8-quantized `Mixtral-8x7B-Instruct-v0.1` model. The model is calibrated and trained using the `open_platypus` dataset.
This example leverages `llm-compressor` and `compressed-tensors` to create a W4A16-quantized `Mixtral-8x7B-Instruct-v0.1` model. The model is calibrated using the `ultrachat_200k` dataset.

You can follow the detailed steps below or simply run the example script with:

```bash
python mixtral_moe_w8a8_fp8.py
python mixtral_example.py
```

### Step 1: Select a Model, Dataset, and Recipe
@@ -36,24 +36,24 @@ In this step, you'll choose a baseline model for quantization, a dataset for cal

- **Models**: Can be referenced from a local directory or retrieved from the Hugging Face Hub.
- **Datasets**: Can also be from a local directory or the Hugging Face Hub.
- **Recipes**: These are YAML files or Python modifier objects that describe how a model should be optimized during or after training. In this example, we use a `QuantizationModifier` object with the scheme set to `FP8`.
- **Recipes**: These are YAML files or Python modifier objects that describe how a model should be optimized during or after training. In this example, we use a `QuantizationModifier` object with the scheme set to `W4A16`.

```python
from llmcompressor.modifiers.quantization import QuantizationModifier

recipe = QuantizationModifier(scheme="FP8", targets="Linear", ignore=["lm_head", "re:.*block_sparse_moe.gate"])
recipe = QuantizationModifier(scheme="W4A16", targets="Linear", ignore=["lm_head", "re:.*block_sparse_moe.gate"])
```

NOTE: `.*block_sparse_moe.gate` layers do not quantize well, hence they are ignored!
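
For reference, here is a minimal sketch of the model and calibration-dataset selection described above, mirroring the patterns used by the example scripts in this directory (chat templating and tokenization are handled in `mixtral_example.py`):

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Load the baseline model and tokenizer from the Hugging Face Hub (or a local path).
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load a slice of the calibration dataset; MoE models benefit from enough
# samples that every expert is exercised during calibration.
NUM_CALIBRATION_SAMPLES = 512
ds = load_dataset(
    "HuggingFaceH4/ultrachat_200k", split=f"train_sft[:{NUM_CALIBRATION_SAMPLES}]"
)
ds = ds.shuffle(seed=42)
```
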

### Step 2: Run Quantization Using Oneshot

The `oneshot` method applies the selected recipe to your model and dataset without requiring any fine-tuning. The model will be sparsified and saved to `Mixtral-8x7B-Instruct-v0.1-FP8`.
The `oneshot` method applies the selected recipe to your model and dataset without requiring any fine-tuning. The model will be quantized and saved to `Mixtral-8x7B-Instruct-v0.1-W4A16-G128`.

```python
from llmcompressor import oneshot

output_dir = "Mixtral-8x7B-Instruct-v0.1-FP8"
output_dir = "Mixtral-8x7B-Instruct-v0.1-W4A16-G128"

oneshot(
model=model,
@@ -74,7 +74,7 @@ NOTE: Only per-tensor quantization is supported in vLLM as of now (`vllm==0.6.1`

The repository supports multiple quantization techniques configured via a recipe. Supported strategies include `tensor`, `group`, and `channel` quantization.

In the above example, FP8 per-tensor quantization is used as specified by the `FP8` scheme. For other preset schemes, refer to the [quantization schemes](https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/quant_scheme.py) in the `compressed-tensors` library.
In the above example, quantization is specified by the `W4A16` scheme. For other preset schemes, refer to the [quantization schemes](https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/quant_scheme.py) in the `compressed-tensors` library.

A custom scheme can also be specified using `config_groups`:

@@ -84,18 +84,18 @@ A custom scheme can also be specified using `config_groups`:
from llmcompressor.modifiers.quantization.gptq import GPTQModifier

config_groups = {
    "group_0": {
        "targets": ["Linear"],
        "input_activations": None,
        "output_activations": None,
        "weights": {
            "num_bits": 8,
            "type": "int",
            "symmetric": True,
            "strategy": "group",
            "group_size": 128,
        },
    },
}

recipe = GPTQModifier(config_groups=config_groups)
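
The custom recipe is then passed to `oneshot` just like the preset scheme above; a short sketch reusing the `model` and calibration dataset `ds` from Step 1:

```python
# ds should be chat-templated and tokenized as in the example script
# before calibration; the length/sample values below are illustrative.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)
```
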
125 changes: 0 additions & 125 deletions examples/quantizing_moe/deepseek_moe_w4a16.py

This file was deleted.

8 changes: 0 additions & 8 deletions examples/quantizing_moe/deepseek_recipe_w4a16.yaml

This file was deleted.

@@ -12,18 +12,17 @@
# previous version or upgrading to a version where this bug is fixed

# select a Mixture of Experts model for quantization
MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
MODEL_ID = "deepseek-ai/DeepSeek-V2.5"

model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
# it's recommended to use more calibration samples for MoE models so each expert is hit
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 2048
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048


@@ -57,16 +56,12 @@ def tokenize(sample):

ds = ds.map(tokenize, remove_columns=ds.column_names)

# define a llmcompressor recipe for INT8 W8A8 quantization
# Configure the quantization algorithm to run.
# since the MoE gate layers are sensitive to quantization, we add them to the ignore
# list so they remain at full precision
recipe = [
GPTQModifier(
targets="Linear",
scheme="W8A8",
ignore=["lm_head", "re:.*mlp.gate$"],
),
]
recipe = GPTQModifier(
targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*mlp.gate$"]
)

oneshot(
model=model,
@@ -82,12 +77,10 @@ def tokenize(sample):
if Version(__version__) < Version("4.48"):
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
SAMPLE_INPUT = ["I love quantization because"]
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
inputs = tokenizer(SAMPLE_INPUT, return_tensors="pt", padding=True).to(model.device)
output = model.generate(**inputs, max_length=50)
text_output = tokenizer.batch_decode(output)
print(text_output)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to("cuda") for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================")
else:
print(
@@ -96,6 +89,6 @@ def tokenize(sample):
)

# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8"
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
88 changes: 88 additions & 0 deletions examples/quantizing_moe/deepseekv3_example.py
@@ -0,0 +1,88 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modeling import prepare_for_quantization
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
# For DeepSeekv3, we require a full precision model in order to properly calibrate
# `DeepSeek-V3-BF16` is a DeepSeek-V3 FP8 model which has been converted to BF16
model_id = "RedHatAI/DeepSeek-V3-BF16"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = prepare_for_quantization(model)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
return {
"text": tokenizer.apply_chat_template(
example["messages"],
tokenize=False,
)
}


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
return tokenizer(
sample["text"],
padding=False,
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
add_special_tokens=False,
)


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
# since the MoE gate layers are sensitive to quantization, we add them to the ignore
# list so they remain at full precision
recipe = GPTQModifier(
targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*mlp.gate$"]
)

# Apply algorithms.
# due to the large size of DeepSeekV3, we specify sequential targets such that
# only one MLP is loaded into GPU memory at a time
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
sequential_targets=["DeepseekV3Attention", "DeepseekV3MLP"],
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to("cuda") for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
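
As a follow-up usage sketch, the saved compressed checkpoint can be loaded for inference with vLLM, as described in the README. This assumes a vLLM build with compressed-tensors support and sufficient GPU memory for a model of this size; the parallelism setting is illustrative.

```python
from vllm import LLM, SamplingParams

# The directory name matches SAVE_DIR written by the example above,
# e.g. "DeepSeek-V3-BF16-W4A16-G128"; tensor_parallel_size is illustrative.
llm = LLM(model="DeepSeek-V3-BF16-W4A16-G128", tensor_parallel_size=8)
outputs = llm.generate(["Hello my name is"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```
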