Skip to content

add auto-round related vllm and transformers UT #613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added test/__init__.py
Empty file.
209 changes: 209 additions & 0 deletions test/test_cuda/test_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest

from transformers import AutoModelForCausalLM, AutoRoundConfig, AutoTokenizer
from transformers.testing_utils import (
require_accelerate,
require_intel_extension_for_pytorch,
require_torch_gpu,
require_torch_multi_gpu,
slow,
torch_device,
)
from transformers.utils import is_torch_available


if is_torch_available():
import torch


# @slow
@require_torch_gpu
@require_accelerate
class AutoRoundTest(unittest.TestCase):
model_name = "OPEA/Qwen2.5-1.5B-Instruct-int4-sym-inc"
input_text = "There is a girl who likes adventure,"
EXPECTED_OUTPUTS = set()
## Different backends may produce slight variations in output
EXPECTED_OUTPUTS.add(
"There is a girl who likes adventure, and she has been exploring the world "
"for many years. She travels to different countries and cultures, trying new "
"things every day. One of her favorite places to visit is a small village in "
"the mountains where"
)
EXPECTED_OUTPUTS.add(
"There is a girl who likes adventure, and she has been exploring the world for many years. She has visited every country in Europe and has even traveled to some of the most remote parts of Africa. She enjoys hiking through the mountains and discovering"
)

device_map = "cuda"

# called only once for all test in this class
@classmethod
def setUpClass(cls):
"""
Setup quantized model
"""
torch.cuda.synchronize()
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name, device_map=cls.device_map, torch_dtype=torch.float16
)

def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()

def test_quantized_model(self):
"""
Simple test that checks if the quantized model is working properly
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = self.quantized_model.generate(**input_ids, max_new_tokens=40, do_sample=False)
self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

def test_raise_if_non_quantized(self):
model_id = "facebook/opt-125m"
quantization_config = AutoRoundConfig(bits=4)
with self.assertRaises(ValueError):
_ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)

def test_quantized_model_bf16(self):
"""
Simple test that checks if the quantized model is working properly with bf16
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
quantization_config = AutoRoundConfig(backend="triton")
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16,
device_map=self.device_map,
quantization_config=quantization_config,
)

output = quantized_model.generate(**input_ids, max_new_tokens=40, do_sample=False)
self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

@require_intel_extension_for_pytorch
def test_quantized_model_on_cpu(self):
"""
Simple test that checks if the quantized model is working properly
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt")

quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto")
output = quantized_model.generate(**input_ids, max_new_tokens=40, do_sample=False)

self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

def test_save_pretrained(self):
"""
Simple test that checks if the quantized model is working properly after being saved and loaded
"""

## some backends like marlin/ipex will repack the weight that caused the weight shape changed
with tempfile.TemporaryDirectory() as tmpdirname:
quantization_config = AutoRoundConfig(backend="triton")
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=self.device_map,
torch_dtype=torch.float16,
quantization_config=quantization_config,
)

quantized_model.save_pretrained(tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="cuda")

input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

output = model.generate(**input_ids, max_new_tokens=40, do_sample=False)
self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

@require_torch_multi_gpu
def test_quantized_model_multi_gpu(self):
"""
Simple test that checks if the quantized model is working properly with multiple GPUs
"""
quantization_config = AutoRoundConfig(backend="triton")
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name, device_map="auto", quantization_config=quantization_config, torch_dtype="auto"
)
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(quantized_model.device)
output = quantized_model.generate(**input_ids, max_new_tokens=40, do_sample=False)
self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

def test_convert_from_gptq(self):
"""
Simple test that checks if auto-round work properly with gptq format
"""
model_name = "ybelkada/opt-125m-gptq-4bit"

quantization_config = AutoRoundConfig()

model = AutoModelForCausalLM.from_pretrained(
model_name, device_map="cuda", quantization_config=quantization_config, torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0])

@require_intel_extension_for_pytorch
def test_convert_from_awq_cpu(self):
"""
Simple test that checks if auto-round work properly with awq format
"""
model_name = "casperhansen/opt-125m-awq"

quantization_config = AutoRoundConfig()

model = AutoModelForCausalLM.from_pretrained(
model_name, device_map="cpu", quantization_config=quantization_config, torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0])

def test_mixed_bits(self):
"""
Simple test that checks if auto-round work properly with mixed bits
"""
model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
layer_config = {
"model.decoder.layers.0.self_attn.k_proj": {"bits": 8},
"model.decoder.layers.6.self_attn.out_proj": {"bits": 2, "group_size": 32},
}

bits, group_size, sym = 4, 128, True
from auto_round import AutoRound

autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym, layer_config=layer_config)
with tempfile.TemporaryDirectory() as tmpdirname:
autoround.quantize_and_save(output_dir=tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname, torch_dtype=torch.float16, device_map="cuda")
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0])


if __name__ == "__main__":
unittest.main()
45 changes: 45 additions & 0 deletions test/test_cuda/test_vllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
"""Test model set-up and inference for quantized HF models supported
on the AutoRound.

Validating the configuration and printing results for manual checking.

Run `pytest tests/quantization/test_auto_round.py`.
"""

import pytest
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform

MODELS = [
"OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc", ##auto_round:auto_gptq
"Intel/Qwen2-0.5B-Instruct-int4-sym-AutoRound" ##auto_round:auto_awq
]


@pytest.mark.skipif(not current_platform.is_cpu()
and not current_platform.is_xpu()
and not current_platform.is_cuda(),
reason="only supports CPU/XPU/CUDA backend.")
@pytest.mark.parametrize("model", MODELS)
def test_auto_round(model):
# Sample prompts.
prompts = [
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
QUANTIZATION = "auto-round"
llm = LLM(model=model, quantization=QUANTIZATION, trust_remote_code=True, tensor_parallel_size=1)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
if "France" in prompt:
assert "Paris" in generated_text