diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/test_cuda/requirements.txt b/test/test_cuda/requirements.txt index 9168419c..cd1f13dc 100644 --- a/test/test_cuda/requirements.txt +++ b/test/test_cuda/requirements.txt @@ -17,3 +17,5 @@ torch torchvision tqdm transformers==4.51.3 +vllm>=0.8.5.post1 + diff --git a/test/test_cuda/test_transformers.py b/test/test_cuda/test_transformers.py new file mode 100644 index 00000000..527f5aec --- /dev/null +++ b/test/test_cuda/test_transformers.py @@ -0,0 +1,209 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import tempfile +import unittest + +from transformers import AutoModelForCausalLM, AutoRoundConfig, AutoTokenizer +from transformers.testing_utils import ( + require_accelerate, + require_intel_extension_for_pytorch, + require_torch_gpu, + require_torch_multi_gpu, + slow, + torch_device, +) +from transformers.utils import is_torch_available + + +if is_torch_available(): + import torch + + +# @slow +@require_torch_gpu +@require_accelerate +class AutoRoundTest(unittest.TestCase): + model_name = "OPEA/Qwen2.5-1.5B-Instruct-int4-sym-inc" + input_text = "There is a girl who likes adventure," + EXPECTED_OUTPUTS = set() + ## Different backends may produce slight variations in output + EXPECTED_OUTPUTS.add( + "There is a girl who likes adventure, and she has been exploring the world " + "for many years. She travels to different countries and cultures, trying new " + "things every day. One of her favorite places to visit is a small village in " + "the mountains where" + ) + EXPECTED_OUTPUTS.add( + "There is a girl who likes adventure, and she has been exploring the world for many years. She has visited every country in Europe and has even traveled to some of the most remote parts of Africa. She enjoys hiking through the mountains and discovering" + ) + + device_map = "cuda" + + # called only once for all test in this class + @classmethod + def setUpClass(cls): + """ + Setup quantized model + """ + torch.cuda.synchronize() + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) + cls.quantized_model = AutoModelForCausalLM.from_pretrained( + cls.model_name, device_map=cls.device_map, torch_dtype=torch.float16 + ) + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + def test_quantized_model(self): + """ + Simple test that checks if the quantized model is working properly + """ + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + output = self.quantized_model.generate(**input_ids, max_new_tokens=40, do_sample=False) + self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) + + def test_raise_if_non_quantized(self): + model_id = "facebook/opt-125m" + quantization_config = AutoRoundConfig(bits=4) + with self.assertRaises(ValueError): + _ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config) + + def test_quantized_model_bf16(self): + """ + Simple test that checks if the quantized model is working properly with bf16 + """ + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + quantization_config = AutoRoundConfig(backend="triton") + quantized_model = AutoModelForCausalLM.from_pretrained( + self.model_name, + torch_dtype=torch.bfloat16, + device_map=self.device_map, + quantization_config=quantization_config, + ) + + output = quantized_model.generate(**input_ids, max_new_tokens=40, do_sample=False) + self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) + + @require_intel_extension_for_pytorch + def test_quantized_model_on_cpu(self): + """ + Simple test that checks if the quantized model is working properly + """ + input_ids = self.tokenizer(self.input_text, return_tensors="pt") + + quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto") + output = quantized_model.generate(**input_ids, max_new_tokens=40, do_sample=False) + + self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) + + def test_save_pretrained(self): + """ + Simple test that checks if the quantized model is working properly after being saved and loaded + """ + + ## some backends like marlin/ipex will repack the weight that caused the weight shape changed + with tempfile.TemporaryDirectory() as tmpdirname: + quantization_config = AutoRoundConfig(backend="triton") + quantized_model = AutoModelForCausalLM.from_pretrained( + self.model_name, + device_map=self.device_map, + torch_dtype=torch.float16, + quantization_config=quantization_config, + ) + + quantized_model.save_pretrained(tmpdirname) + model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="cuda") + + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = model.generate(**input_ids, max_new_tokens=40, do_sample=False) + self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) + + @require_torch_multi_gpu + def test_quantized_model_multi_gpu(self): + """ + Simple test that checks if the quantized model is working properly with multiple GPUs + """ + quantization_config = AutoRoundConfig(backend="triton") + quantized_model = AutoModelForCausalLM.from_pretrained( + self.model_name, device_map="auto", quantization_config=quantization_config, torch_dtype="auto" + ) + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(quantized_model.device) + output = quantized_model.generate(**input_ids, max_new_tokens=40, do_sample=False) + self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) + + def test_convert_from_gptq(self): + """ + Simple test that checks if auto-round work properly with gptq format + """ + model_name = "ybelkada/opt-125m-gptq-4bit" + + quantization_config = AutoRoundConfig() + + model = AutoModelForCausalLM.from_pretrained( + model_name, device_map="cuda", quantization_config=quantization_config, torch_dtype="auto" + ) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + text = "There is a girl who likes adventure," + inputs = tokenizer(text, return_tensors="pt").to(model.device) + tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0]) + + @require_intel_extension_for_pytorch + def test_convert_from_awq_cpu(self): + """ + Simple test that checks if auto-round work properly with awq format + """ + model_name = "casperhansen/opt-125m-awq" + + quantization_config = AutoRoundConfig() + + model = AutoModelForCausalLM.from_pretrained( + model_name, device_map="cpu", quantization_config=quantization_config, torch_dtype="auto" + ) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + text = "There is a girl who likes adventure," + inputs = tokenizer(text, return_tensors="pt").to(model.device) + tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0]) + + def test_mixed_bits(self): + """ + Simple test that checks if auto-round work properly with mixed bits + """ + model_name = "facebook/opt-125m" + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto") + tokenizer = AutoTokenizer.from_pretrained(model_name) + layer_config = { + "model.decoder.layers.0.self_attn.k_proj": {"bits": 8}, + "model.decoder.layers.6.self_attn.out_proj": {"bits": 2, "group_size": 32}, + } + + bits, group_size, sym = 4, 128, True + from auto_round import AutoRound + + autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym, layer_config=layer_config) + with tempfile.TemporaryDirectory() as tmpdirname: + autoround.quantize_and_save(output_dir=tmpdirname) + model = AutoModelForCausalLM.from_pretrained(tmpdirname, torch_dtype=torch.float16, device_map="cuda") + text = "There is a girl who likes adventure," + inputs = tokenizer(text, return_tensors="pt").to(model.device) + tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0]) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/test/test_cuda/test_vllm.py b/test/test_cuda/test_vllm.py new file mode 100644 index 00000000..88a252d2 --- /dev/null +++ b/test/test_cuda/test_vllm.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test model set-up and inference for quantized HF models supported + on the AutoRound. + + Validating the configuration and printing results for manual checking. + + Run `pytest tests/quantization/test_auto_round.py`. +""" + +import pytest +from vllm import LLM, SamplingParams +from vllm.platforms import current_platform + +MODELS = [ + "OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc", ##auto_round:auto_gptq + "Intel/Qwen2-0.5B-Instruct-int4-sym-AutoRound" ##auto_round:auto_awq +] + + +@pytest.mark.skipif(not current_platform.is_cpu() + and not current_platform.is_xpu() + and not current_platform.is_cuda(), + reason="only supports CPU/XPU/CUDA backend.") +@pytest.mark.parametrize("model", MODELS) +def test_auto_round(model): + # Sample prompts. + prompts = [ + "The capital of France is", + "The future of AI is", + ] + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + # Create an LLM. + QUANTIZATION = "auto-round" + llm = LLM(model=model, quantization=QUANTIZATION, trust_remote_code=True, tensor_parallel_size=1) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + if "France" in prompt: + assert "Paris" in generated_text