From 29bb6421ee76d4e117041cccdf935be7bcff539c Mon Sep 17 00:00:00 2001
From: YAO Matrix
Date: Tue, 22 Apr 2025 23:58:46 -0700
Subject: [PATCH 1/5] enable gguf test cases on XPU

Signed-off-by: YAO Matrix
---
 src/diffusers/models/modeling_utils.py          | 1 +
 src/diffusers/quantizers/gguf/gguf_quantizer.py | 2 +-
 tests/quantization/gguf/test_gguf.py            | 6 +++++-
 3 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 2a22bc09ad7a..3e7cbdc031ff 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -760,6 +760,7 @@ def dequantize(self):
         dequantization.
         """
         hf_quantizer = getattr(self, "hf_quantizer", None)
+        print(hf_quantizer)
         if hf_quantizer is None:
             raise ValueError("You need to first quantize your model in order to dequantize it")
 
diff --git a/src/diffusers/quantizers/gguf/gguf_quantizer.py b/src/diffusers/quantizers/gguf/gguf_quantizer.py
index 6da69c7bd60c..ff68a8aa5ab2 100644
--- a/src/diffusers/quantizers/gguf/gguf_quantizer.py
+++ b/src/diffusers/quantizers/gguf/gguf_quantizer.py
@@ -152,7 +152,7 @@ def _dequantize(self, model):
             logger.info(
                 "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
             )
-            model.to(torch.cuda.current_device())
+            model.to(torch.accelerator.current_accelerator())
 
         model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert)
         if is_model_on_cpu:
diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
index e4cf1dfee1e5..0b1d80b0858c 100644
--- a/tests/quantization/gguf/test_gguf.py
+++ b/tests/quantization/gguf/test_gguf.py
@@ -16,11 +16,13 @@
     StableDiffusion3Pipeline,
 )
 from diffusers.utils import load_image
+from diffusers.utils.torch_utils import get_device
 from diffusers.utils.testing_utils import (
     is_gguf_available,
     nightly,
     numpy_cosine_similarity_distance,
     require_accelerate,
+    require_big_accelerator,
     require_big_gpu_with_torch_cuda,
     require_gguf_version_greater_or_equal,
     require_peft_backend,
@@ -33,7 +35,7 @@
 
 
 @nightly
-@require_big_gpu_with_torch_cuda
+@require_big_accelerator
 @require_accelerate
 @require_gguf_version_greater_or_equal("0.10.0")
 class GGUFSingleFileTesterMixin:
@@ -68,6 +70,8 @@ def test_gguf_memory_usage(self):
         model = self.model_cls.from_single_file(
             self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
         )
+        device_type = get_device()
+        print(f"device_type: {device_type}")
         model.to("cuda")
         assert (model.get_memory_footprint() / 1024**3) < self.expected_memory_use_in_gb
         inputs = self.get_dummy_inputs()

From e24dd9b8c1308ebe4bf5f458b4534422b8dee170 Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 23 Apr 2025 17:14:20 -0700
Subject: [PATCH 2/5] make SD35LargeGGUFSingleFileTests::test_pipeline_inference pass

Signed-off-by: root
---
 src/diffusers/models/modeling_utils.py |   1 -
 .../quantizers/gguf/gguf_quantizer.py   |   5 +-
 tests/quantization/gguf/test_gguf.py    | 124 ++++++++++++------
 3 files changed, 88 insertions(+), 42 deletions(-)

diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 3e7cbdc031ff..2a22bc09ad7a 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -760,7 +760,6 @@ def dequantize(self):
         dequantization.
""" hf_quantizer = getattr(self, "hf_quantizer", None) - print(hf_quantizer) if hf_quantizer is None: raise ValueError("You need to first quantize your model in order to dequantize it") diff --git a/src/diffusers/quantizers/gguf/gguf_quantizer.py b/src/diffusers/quantizers/gguf/gguf_quantizer.py index ff68a8aa5ab2..554b536974da 100644 --- a/src/diffusers/quantizers/gguf/gguf_quantizer.py +++ b/src/diffusers/quantizers/gguf/gguf_quantizer.py @@ -150,9 +150,10 @@ def _dequantize(self, model): is_model_on_cpu = model.device.type == "cpu" if is_model_on_cpu: logger.info( - "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device." + "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to accelerator. After dequantization, will move the model back to CPU again to preserve the previous device." ) - model.to(torch.accelerator.current_accelerator()) + device = torch.accelerator.current_accelerator() if hasattr(torch, "accelerator") else torch.cuda.current_device() + model.to(device) model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert) if is_model_on_cpu: diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index 0b1d80b0858c..a27e2399447b 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -18,6 +18,8 @@ from diffusers.utils import load_image from diffusers.utils.torch_utils import get_device from diffusers.utils.testing_utils import ( + Expectations, + enable_full_determinism, is_gguf_available, nightly, numpy_cosine_similarity_distance, @@ -33,6 +35,8 @@ if is_gguf_available(): from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter +enable_full_determinism() + @nightly @require_big_accelerator @@ -70,17 +74,17 @@ def test_gguf_memory_usage(self): model = self.model_cls.from_single_file( self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype ) - device_type = get_device() - print(f"device_type: {device_type}") - model.to("cuda") + device = get_device() + model.to(device) assert (model.get_memory_footprint() / 1024**3) < self.expected_memory_use_in_gb inputs = self.get_dummy_inputs() - torch.cuda.reset_peak_memory_stats() - torch.cuda.empty_cache() + torch_accelerator_module = getattr(torch, device) + torch_accelerator_module.reset_peak_memory_stats() + torch_accelerator_module.empty_cache() with torch.no_grad(): model(**inputs) - max_memory = torch.cuda.max_memory_allocated() + max_memory = torch_accelerator_module.max_memory_allocated() assert (max_memory / 1024**3) < self.expected_memory_use_in_gb def test_keep_modules_in_fp32(self): @@ -104,13 +108,16 @@ def test_dtype_assignment(self): quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config) + device = get_device() + with self.assertRaises(ValueError): # Tries with a `dtype` model.to(torch.float16) with self.assertRaises(ValueError): # Tries with a `device` and `dtype` - model.to(device="cuda:0", dtype=torch.float16) + device_0 = f"{device}:0" + model.to(device=device_0, dtype=torch.float16) with self.assertRaises(ValueError): # Tries with a cast @@ -121,7 +128,7 @@ def test_dtype_assignment(self): model.half() # This should work - model.to("cuda") + model.to(device) def 
test_dequantize_model(self): quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) @@ -271,40 +278,79 @@ def test_pipeline_inference(self): prompt = "a cat holding a sign that says hello" output = pipe( - prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np" + prompt=prompt, + num_inference_steps=2, + generator=torch.Generator("cpu").manual_seed(0), + output_type="np", ).images[0] output_slice = output[:3, :3, :].flatten() - expected_slice = np.array( - [ - 0.17578125, - 0.27539062, - 0.27734375, - 0.11914062, - 0.26953125, - 0.25390625, - 0.109375, - 0.25390625, - 0.25, - 0.15039062, - 0.26171875, - 0.28515625, - 0.13671875, - 0.27734375, - 0.28515625, - 0.12109375, - 0.26757812, - 0.265625, - 0.16210938, - 0.29882812, - 0.28515625, - 0.15625, - 0.30664062, - 0.27734375, - 0.14648438, - 0.29296875, - 0.26953125, - ] + expected_slices = Expectations( + { + ("xpu", 3): np.array( + [ + 0.19335938, + 0.3125, + 0.3203125, + 0.1328125, + 0.3046875, + 0.296875, + 0.11914062, + 0.2890625, + 0.2890625, + 0.16796875, + 0.30273438, + 0.33203125, + 0.14648438, + 0.31640625, + 0.33007812, + 0.12890625, + 0.3046875, + 0.30859375, + 0.17773438, + 0.33789062, + 0.33203125, + 0.16796875, + 0.34570312, + 0.32421875, + 0.15625, + 0.33203125, + 0.31445312, + ] + ), + ("cuda", 7): np.array( + [ + 0.17578125, + 0.27539062, + 0.27734375, + 0.11914062, + 0.26953125, + 0.25390625, + 0.109375, + 0.25390625, + 0.25, + 0.15039062, + 0.26171875, + 0.28515625, + 0.13671875, + 0.27734375, + 0.28515625, + 0.12109375, + 0.26757812, + 0.265625, + 0.16210938, + 0.29882812, + 0.28515625, + 0.15625, + 0.30664062, + 0.27734375, + 0.14648438, + 0.29296875, + 0.26953125, + ] + ), + } ) + expected_slice = expected_slices.get_expectation() max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice) assert max_diff < 1e-4 From 68a020f4b002a9169c1c72040e5fdfa67e1638bf Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Thu, 24 Apr 2025 06:08:36 +0000 Subject: [PATCH 3/5] make FluxControlLoRAGGUFTests::test_lora_loading pass Signed-off-by: Yao Matrix --- src/diffusers/loaders/lora_pipeline.py | 7 ++++--- tests/quantization/gguf/test_gguf.py | 14 ++++++++------ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index fb2cdf6ce304..50a99cee1d23 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -91,18 +91,19 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module): ) weight_on_cpu = False - if not module.weight.is_cuda: + if module.weight.device.type == "cpu": weight_on_cpu = True + device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda" if is_bnb_4bit_quantized: module_weight = dequantize_bnb_weight( - module.weight.cuda() if weight_on_cpu else module.weight, + module.weight.to(device) if weight_on_cpu else module.weight, state=module.weight.quant_state, dtype=model.dtype, ).data elif is_gguf_quantized: module_weight = dequantize_gguf_tensor( - module.weight.cuda() if weight_on_cpu else module.weight, + module.weight.to(device) if weight_on_cpu else module.weight, ) module_weight = module_weight.to(model.dtype) else: diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index a27e2399447b..5edff276e317 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -19,6 +19,9 @@ from 
diffusers.utils.torch_utils import get_device from diffusers.utils.testing_utils import ( Expectations, + backend_empty_cache, + backend_max_memory_allocated, + backend_reset_peak_memory_stats, enable_full_determinism, is_gguf_available, nightly, @@ -79,12 +82,11 @@ def test_gguf_memory_usage(self): assert (model.get_memory_footprint() / 1024**3) < self.expected_memory_use_in_gb inputs = self.get_dummy_inputs() - torch_accelerator_module = getattr(torch, device) - torch_accelerator_module.reset_peak_memory_stats() - torch_accelerator_module.empty_cache() + backend_reset_peak_memory_stats(device) + backend_empty_cache(device) with torch.no_grad(): model(**inputs) - max_memory = torch_accelerator_module.max_memory_allocated() + max_memory = backend_max_memory_allocated(device) assert (max_memory / 1024**3) < self.expected_memory_use_in_gb def test_keep_modules_in_fp32(self): @@ -513,7 +515,7 @@ def test_pipeline_inference(self): @require_peft_backend @nightly -@require_big_gpu_with_torch_cuda +@require_big_accelerator @require_accelerate @require_gguf_version_greater_or_equal("0.10.0") class FluxControlLoRAGGUFTests(unittest.TestCase): @@ -528,7 +530,7 @@ def test_lora_loading(self): "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16, - ).to("cuda") + ).to(torch_device) pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora") prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." From 9b2cda6f433e43adb65da87302dfa4555456ea4e Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Thu, 24 Apr 2025 06:55:13 +0000 Subject: [PATCH 4/5] polish code Signed-off-by: Yao Matrix --- tests/quantization/gguf/test_gguf.py | 33 ++++++++++++---------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index 5edff276e317..9f54ecf6c67c 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -16,7 +16,6 @@ StableDiffusion3Pipeline, ) from diffusers.utils import load_image -from diffusers.utils.torch_utils import get_device from diffusers.utils.testing_utils import ( Expectations, backend_empty_cache, @@ -28,7 +27,6 @@ numpy_cosine_similarity_distance, require_accelerate, require_big_accelerator, - require_big_gpu_with_torch_cuda, require_gguf_version_greater_or_equal, require_peft_backend, torch_device, @@ -77,16 +75,15 @@ def test_gguf_memory_usage(self): model = self.model_cls.from_single_file( self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype ) - device = get_device() - model.to(device) + model.to(torch_device) assert (model.get_memory_footprint() / 1024**3) < self.expected_memory_use_in_gb inputs = self.get_dummy_inputs() - backend_reset_peak_memory_stats(device) - backend_empty_cache(device) + backend_reset_peak_memory_stats(torch_device) + backend_empty_cache(torch_device) with torch.no_grad(): model(**inputs) - max_memory = backend_max_memory_allocated(device) + max_memory = backend_max_memory_allocated(torch_device) assert (max_memory / 1024**3) < self.expected_memory_use_in_gb def test_keep_modules_in_fp32(self): @@ -110,15 +107,13 @@ def test_dtype_assignment(self): quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config) - device = get_device() - with self.assertRaises(ValueError): # Tries with a 
`dtype` model.to(torch.float16) with self.assertRaises(ValueError): # Tries with a `device` and `dtype` - device_0 = f"{device}:0" + device_0 = f"{torch_device}:0" model.to(device=device_0, dtype=torch.float16) with self.assertRaises(ValueError): @@ -130,7 +125,7 @@ def test_dtype_assignment(self): model.half() # This should work - model.to(device) + model.to(torch_device) def test_dequantize_model(self): quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) @@ -159,11 +154,11 @@ class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): def setUp(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_dummy_inputs(self): return { @@ -246,11 +241,11 @@ class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase) def setUp(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_dummy_inputs(self): return { @@ -365,11 +360,11 @@ class SD35MediumGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase def setUp(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_dummy_inputs(self): return { @@ -445,11 +440,11 @@ class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): def setUp(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_dummy_inputs(self): return { From 4be067b1db0af4bd130d124a11acff43cd186fa5 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 28 Apr 2025 00:50:56 +0000 Subject: [PATCH 5/5] Apply style fixes --- src/diffusers/quantizers/gguf/gguf_quantizer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/quantizers/gguf/gguf_quantizer.py b/src/diffusers/quantizers/gguf/gguf_quantizer.py index 554b536974da..97f03b07a345 100644 --- a/src/diffusers/quantizers/gguf/gguf_quantizer.py +++ b/src/diffusers/quantizers/gguf/gguf_quantizer.py @@ -152,7 +152,11 @@ def _dequantize(self, model): logger.info( "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to accelerator. After dequantization, will move the model back to CPU again to preserve the previous device." ) - device = torch.accelerator.current_accelerator() if hasattr(torch, "accelerator") else torch.cuda.current_device() + device = ( + torch.accelerator.current_accelerator() + if hasattr(torch, "accelerator") + else torch.cuda.current_device() + ) model.to(device) model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert)
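
A note on the device-selection pattern used across these patches: `torch.accelerator.current_accelerator()` resolves to whichever accelerator backend is active (CUDA, XPU, ...), and the `hasattr(torch, "accelerator")` guard keeps older PyTorch releases on the CUDA path. Below is a minimal standalone sketch of that fallback, assuming a PyTorch build recent enough to ship the `torch.accelerator` API; the helper name is illustrative and not part of this series.

    import torch

    def current_accelerator_device():
        # Illustrative helper, not part of the patch series.
        # torch.accelerator ships with recent PyTorch releases and resolves to
        # the active backend (CUDA, XPU, ...); older releases fall back to the
        # current CUDA device, mirroring the guarded expression in the patches.
        if hasattr(torch, "accelerator"):
            return torch.accelerator.current_accelerator()
        return torch.device("cuda", torch.cuda.current_device())

    # Usage mirroring GGUFQuantizer._dequantize: move a CPU-offloaded model to
    # the accelerator before dequantizing, then restore it to CPU afterwards.
    # model.to(current_accelerator_device())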