Commit cadf1f0

Merge branch 'main' into loguru
2 parents: 700f690 + 341e27c

82 files changed: +2102 lines, -9146 lines

examples/awq/qwen3_moe_example.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

# Select model and load it.
MODEL_ID = "Qwen/Qwen3-30B-A3B"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Select calibration dataset.
DATASET_ID = "mit-han-lab/pile-val-backup"
DATASET_SPLIT = "validation"

# Select number of samples. 256 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            [{"role": "user", "content": example["text"]}],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


# Configure the quantization algorithm to run.
# NOTE: vllm currently does not support asym MoE, using symmetric here
recipe = [
    AWQModifier(
        ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"],
        scheme="W4A16",
        targets=["Linear"],
    ),
]

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[-1] + "-awq-sym"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

examples/multimodal_audio/README.md

Lines changed: 0 additions & 6 deletions
@@ -47,12 +47,6 @@ Sequential targets are the modules which determine the granularity of error prop
 
 Choosing sequential targets with higher granularity (for example "Linear" instead of "LlamaDecoderLayer") will result in fewer hessians being allocated at the same time, decreasing the memory requirements for compression. This may also increase the recovered accuracy of the model, as compression error is propagated at a higher granularity. However, using higher granularity sequential targets may also increase compression time, as more time is spent offloading and onloading activations.
 
-### Ignore ###
-If your model is not traceable for your desired dataset, first consider adding any problematic modules to the ignore list. Doing this prevents the model tracer from tracing the internals of those modules, thereby avoiding the untraceable operations.
-
-## Tracing Errors ##
-Because the architectures of audio-language models are often more complex than those of typical decoder-only text models, you may encounter `torch.fx.TraceError`s when attempting to quantize your model. For more information on `torch.fx.TraceError`s, why they occur, and how to resolve them, please see the [Model Tracing Guide](/src/llmcompressor/transformers/tracing/GUIDE.md).
-
 ## Adding Your Own Smoothquant Mappings ##
 For a guide on adding smoothquant mappings for your dataset, see the [SmoothQuant Guide](/src/llmcompressor/modifiers/smoothquant/README.md).
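The retained paragraph above refers to the `sequential_targets` argument accepted by modifiers such as GPTQModifier (used as `sequential_targets=["MistralDecoderLayer"]` in the Mistral3 example added by this commit). A minimal sketch of the two granularities it compares, using the class names from the README text purely for illustration:

from llmcompressor.modifiers.quantization import GPTQModifier

# Coarse granularity: error is propagated once per decoder layer, so hessians for
# every Linear inside a layer are held in memory at the same time.
recipe_coarse = GPTQModifier(
    targets="Linear",
    scheme="W4A16",
    sequential_targets=["LlamaDecoderLayer"],  # example layer class from the README
)

# Fine granularity: error is propagated per Linear module, so fewer hessians are
# allocated at once, at the cost of more activation offloading and onloading.
recipe_fine = GPTQModifier(
    targets="Linear",
    scheme="W4A16",
    sequential_targets=["Linear"],
)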

examples/multimodal_vision/README.md

Lines changed: 0 additions & 6 deletions
@@ -51,12 +51,6 @@ Sequential targets are the modules which determine the granularity of error prop
 
 Choosing sequential targets with higher granularity (for example "Linear" instead of "LlamaDecoderLayer") will result in fewer hessians being allocated at the same time, decreasing the memory requirements for compression. This may also increase the recovered accuracy of the model, as compression error is propagated at a higher granularity. However, using higher granularity sequential targets may also increase compression time, as more time is spent offloading and onloading activations.
 
-### Ignore ###
-If your model is not traceable for your desired dataset, first consider adding any problematic modules to the ignore list. Doing this prevents the model tracer from tracing the internals of those modules, thereby avoiding the untraceable operations.
-
-## Tracing Errors ##
-Because the architectures of vision-language models are often more complex than those of typical decoder-only text models, you may encounter `torch.fx.TraceError`s when attempting to quantize your model. For more information on `torch.fx.TraceError`s, why they occur, and how to resolve them, please see the [Model Tracing Guide](/src/llmcompressor/transformers/tracing/GUIDE.md).
-
 ## Adding Your Own Smoothquant Mappings ##
 For a guide on adding smoothquant mappings for your dataset, see the [SmoothQuant Guide](/src/llmcompressor/modifiers/smoothquant/README.md).
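Although the tracing-oriented section is removed, the `ignore` argument remains the way to exclude modules from quantization; the new Mistral3 example in this commit uses it to leave the vision tower and multimodal projector in full precision. A minimal sketch of that pattern (the regex patterns are model-specific and may need adjusting for other architectures):

from llmcompressor.modifiers.quantization import GPTQModifier

recipe = GPTQModifier(
    targets="Linear",
    scheme="W4A16",
    # Quantize only the text decoder; skip the LM head and all vision-side modules.
    ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"],
)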

examples/multimodal_vision/gemma3_example.py

Lines changed: 4 additions & 4 deletions
@@ -1,15 +1,14 @@
 import requests
 import torch
 from PIL import Image
-from transformers import AutoProcessor
+from transformers import AutoProcessor, Gemma3ForConditionalGeneration
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
-from llmcompressor.transformers.tracing import TraceableGemma3ForConditionalGeneration
 
 # Load model.
 model_id = "google/gemma-3-4b-it"
-model = TraceableGemma3ForConditionalGeneration.from_pretrained(
+model = Gemma3ForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype="auto"
 )
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
@@ -64,8 +63,9 @@ def data_collator(batch):
 image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
 raw_image = Image.open(requests.get(image_url, stream=True).raw)
 
+# Note: compile is disabled: https://github.com/huggingface/transformers/issues/38333
 inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
-output = model.generate(**inputs, max_new_tokens=100)
+output = model.generate(**inputs, max_new_tokens=100, disable_compile=True)
 print(processor.decode(output[0], skip_special_tokens=True))
 print("==========================================")

examples/multimodal_vision/idefics3_example.py

Lines changed: 2 additions & 3 deletions
@@ -2,15 +2,14 @@
 import torch
 from datasets import load_dataset
 from PIL import Image
-from transformers import AutoProcessor
+from transformers import AutoProcessor, Idefics3ForConditionalGeneration
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
-from llmcompressor.transformers.tracing import TraceableIdefics3ForConditionalGeneration
 
 # Load model.
 model_id = "HuggingFaceM4/Idefics3-8B-Llama3"  # or "HuggingFaceTB/SmolVLM-Instruct"
-model = TraceableIdefics3ForConditionalGeneration.from_pretrained(
+model = Idefics3ForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype="auto"
 )
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

examples/multimodal_vision/llava_example.py

Lines changed: 2 additions & 3 deletions
@@ -1,15 +1,14 @@
 import requests
 import torch
 from PIL import Image
-from transformers import AutoProcessor
+from transformers import AutoProcessor, LlavaForConditionalGeneration
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
-from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration
 
 # Load model.
 model_id = "llava-hf/llava-1.5-7b-hf"
-model = TraceableLlavaForConditionalGeneration.from_pretrained(
+model = LlavaForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype="auto"
 )
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
examples/multimodal_vision/mistral3_chat_template.json

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
{
"chat_template": "{%- set default_system_message = \"You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI\" %}\n\n{{- bos_token }}\n\n{%- if messages[0]['role'] == 'system' %}\n {%- if messages[0]['content'] is string %}\n {%- set system_message = messages[0]['content'] %}\n {%- else %}\n {%- set system_message = messages[0]['content'][0]['text'] %}\n {%- endif %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set system_message = default_system_message %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}\n\n{%- for message in loop_messages %}\n {%- if message['role'] == 'user' %}\n {%- if message['content'] is string %}\n {{- '[INST]' + message['content'] + '[/INST]' }}\n {%- else %}\n {{- '[INST]' }}\n {%- for block in message['content'] %}\n {%- if block['type'] == 'text' %}\n {{- block['text'] }}\n {%- elif block['type'] in ['image', 'image_url'] %}\n {{- '[IMG]' }}\n {%- else %}\n {{- raise_exception('Only text and image blocks are supported in message content!') }}\n {%- endif %}\n {%- endfor %}\n {{- '[/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'system' %}\n {%- if message['content'] is string %}\n {{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }}\n {%- else %}\n {{- '[SYSTEM_PROMPT]' + message['content'][0]['text'] + '[/SYSTEM_PROMPT]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {%- if message['content'] is string %}\n {{- message['content'] + eos_token }}\n {%- else %}\n {{- message['content'][0]['text'] + eos_token }}\n {%- endif %}\n {%- else %}\n {{- raise_exception('Only user, system and assistant roles are supported!') }}\n {%- endif %}\n{%- endfor %}"
}
examples/multimodal_vision/mistral3_example.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
import json
import os

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Mistral3ForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

# Load model.
model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
model = Mistral3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype="auto"
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Use a custom calibration chat template, rather than the overly-verbose default
file_path = os.path.join(os.path.dirname(__file__), "mistral3_chat_template.json")
with open(file_path, "r") as file:
    processor.chat_template = json.load(file)["chat_template"]

# Oneshot arguments
DATASET_ID = "flickr30k"
DATASET_SPLIT = "test"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048


# Define a oneshot data collator for multimodal inputs.
def data_collator(batch):
    assert len(batch) == 1
    return {
        key: torch.tensor(value)
        if key != "pixel_values"
        else torch.tensor(value, dtype=model.dtype)
        for key, value in batch[0].items()
    }


# Recipe
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        sequential_targets=["MistralDecoderLayer"],
        ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"],
    ),
]

# Perform oneshot
oneshot(
    model=model,
    tokenizer=model_id,
    dataset=DATASET_ID,
    splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=data_collator,
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please describe the animal in this image\n"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)

inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)  # fix dtype
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")

# Save to disk compressed.
SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)

examples/multimodal_vision/mllama_example.py

Lines changed: 2 additions & 3 deletions
@@ -1,15 +1,14 @@
 import requests
 import torch
 from PIL import Image
-from transformers import AutoProcessor
+from transformers import AutoProcessor, MllamaForConditionalGeneration
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
-from llmcompressor.transformers.tracing import TraceableMllamaForConditionalGeneration
 
 # Load model.
 model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
-model = TraceableMllamaForConditionalGeneration.from_pretrained(
+model = MllamaForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype="auto"
 )
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

examples/multimodal_vision/pixtral_example.py

Lines changed: 2 additions & 3 deletions
@@ -1,15 +1,14 @@
 import requests
 import torch
 from PIL import Image
-from transformers import AutoProcessor
+from transformers import AutoProcessor, LlavaForConditionalGeneration
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
-from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration
 
 # Load model.
 model_id = "mgoin/pixtral-12b"
-model = TraceableLlavaForConditionalGeneration.from_pretrained(
+model = LlavaForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype="auto"
 )
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

examples/multimodal_vision/qwen2_vl_example.py

Lines changed: 2 additions & 3 deletions
@@ -4,15 +4,14 @@
 import torch
 from datasets import load_dataset
 from qwen_vl_utils import process_vision_info
-from transformers import AutoProcessor
+from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
-from llmcompressor.transformers.tracing import TraceableQwen2VLForConditionalGeneration
 
 # Load model.
 model_id = "Qwen/Qwen2-VL-2B-Instruct"
-model = TraceableQwen2VLForConditionalGeneration.from_pretrained(
+model = Qwen2VLForConditionalGeneration.from_pretrained(
     model_id,
     device_map="auto",
     torch_dtype="auto",

examples/multimodal_vision/qwen_2_5_vl_example.py

Lines changed: 2 additions & 5 deletions
@@ -4,17 +4,14 @@
 import torch
 from datasets import load_dataset
 from qwen_vl_utils import process_vision_info
-from transformers import AutoProcessor
+from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
 
 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.transformers import oneshot
-from llmcompressor.transformers.tracing import (
-    TraceableQwen2_5_VLForConditionalGeneration,
-)
 
 # Load model.
 model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
-model = TraceableQwen2_5_VLForConditionalGeneration.from_pretrained(
+model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     model_id,
     device_map="auto",
     torch_dtype="auto",

examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py

Lines changed: 0 additions & 2 deletions
@@ -68,15 +68,13 @@
     model=model,
     **oneshot_kwargs,
     stage="sparsity_stage",
-    output_dir=output_dir,
 )
 
 # Sparse finetune
 finetune_applied_model = train(
     model=oneshot_applied_model,
     **oneshot_kwargs,
     **training_kwargs,
-    output_dir=output_dir,
     stage="finetuning_stage",
 )

examples/quantization_w4a16_fp4/llama3_example.py

Lines changed: 8 additions & 0 deletions
@@ -19,6 +19,14 @@
 # Apply quantization.
 oneshot(model=model, recipe=recipe)
 
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
+
 # Save to disk in compressed-tensors format.
 SAVE_DIR = MODEL_ID.split("/")[1] + "-NVFP4A16"
 model.save_pretrained(SAVE_DIR, save_compressed=True)
