
enable 28 GGUF test cases on XPU #11404


Merged · 11 commits · Apr 28, 2025
7 changes: 4 additions & 3 deletions src/diffusers/loaders/lora_pipeline.py
@@ -91,18 +91,19 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
         )

     weight_on_cpu = False
-    if not module.weight.is_cuda:
+    if module.weight.device.type == "cpu":
         weight_on_cpu = True

+    device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
     if is_bnb_4bit_quantized:
         module_weight = dequantize_bnb_weight(
-            module.weight.cuda() if weight_on_cpu else module.weight,
+            module.weight.to(device) if weight_on_cpu else module.weight,
             state=module.weight.quant_state,
             dtype=model.dtype,
         ).data
     elif is_gguf_quantized:
         module_weight = dequantize_gguf_tensor(
-            module.weight.cuda() if weight_on_cpu else module.weight,
+            module.weight.to(device) if weight_on_cpu else module.weight,
         )
         module_weight = module_weight.to(model.dtype)
     else:
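The new `device` line relies on the `torch.accelerator` API, which only exists in newer PyTorch releases; on older builds the code keeps the previous CUDA behavior. A minimal sketch of the same fallback, with an extra None-guard that the diff itself does not carry, could look like this:

import torch

def preferred_device_type() -> str:
    # torch.accelerator is only present in newer PyTorch releases; when it is,
    # current_accelerator() reports the active backend ("cuda", "xpu", ...).
    if hasattr(torch, "accelerator"):
        accelerator = torch.accelerator.current_accelerator()
        if accelerator is not None:  # extra guard for CPU-only machines
            return accelerator.type
    return "cuda"  # fallback mirroring the pre-existing behavior

# A CPU-resident quantized weight can then be moved to whichever accelerator
# is available before dequantization, e.g.:
#     module.weight.to(preferred_device_type())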
9 changes: 7 additions & 2 deletions src/diffusers/quantizers/gguf/gguf_quantizer.py
@@ -150,9 +150,14 @@ def _dequantize(self, model):
         is_model_on_cpu = model.device.type == "cpu"
         if is_model_on_cpu:
             logger.info(
-                "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
+                "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to accelerator. After dequantization, will move the model back to CPU again to preserve the previous device."
             )
-            model.to(torch.cuda.current_device())
+            device = (
+                torch.accelerator.current_accelerator()
+                if hasattr(torch, "accelerator")
+                else torch.cuda.current_device()
+            )
+            model.to(device)

         model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert)
         if is_model_on_cpu:
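The same hop-to-accelerator-and-back flow can be written as a small standalone helper. The sketch below is an illustration only: it assumes a plain nn.Module whose parameters reveal the device, and it takes the dequantization routine as a callable instead of calling diffusers' internal `_dequantize_gguf_and_restore_linear` directly:

import torch
import torch.nn as nn

def dequantize_on_accelerator(model: nn.Module, dequantize_fn) -> nn.Module:
    # Sketch of the flow above: if the model sits on CPU (e.g. after
    # enable_model_cpu_offload()), move it to the available accelerator,
    # dequantize there, then restore the original device for the caller.
    was_on_cpu = next(model.parameters()).device.type == "cpu"
    if was_on_cpu:
        device = (
            torch.accelerator.current_accelerator()  # torch.device on newer PyTorch
            if hasattr(torch, "accelerator")
            else torch.cuda.current_device()  # CUDA device index on older builds
        )
        model.to(device)
    model = dequantize_fn(model)  # stand-in for the real dequantization routine
    if was_on_cpu:
        model.to("cpu")  # preserve the caller-visible device
    return model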
145 changes: 96 additions & 49 deletions tests/quantization/gguf/test_gguf.py
@@ -17,11 +17,16 @@
 )
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
+    Expectations,
+    backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_peak_memory_stats,
+    enable_full_determinism,
     is_gguf_available,
     nightly,
     numpy_cosine_similarity_distance,
     require_accelerate,
-    require_big_gpu_with_torch_cuda,
+    require_big_accelerator,
     require_gguf_version_greater_or_equal,
     require_peft_backend,
     torch_device,
@@ -31,9 +36,11 @@
 if is_gguf_available():
     from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter

+enable_full_determinism()
+

 @nightly
-@require_big_gpu_with_torch_cuda
+@require_big_accelerator
 @require_accelerate
 @require_gguf_version_greater_or_equal("0.10.0")
 class GGUFSingleFileTesterMixin:
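Calling enable_full_determinism() matters here because the tests now compare against fixed per-backend numeric slices. A rough approximation of what such a switch typically turns on (the actual diffusers helper may do more, or differ in detail) is:

import os
import torch

def enable_full_determinism_sketch(seed: int = 0) -> None:
    # Rough approximation: pin the RNG seed and ask PyTorch to prefer
    # deterministic kernels so per-device expected slices stay reproducible.
    torch.manual_seed(seed)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # required by some CUDA GEMMs
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False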
@@ -68,15 +75,15 @@ def test_gguf_memory_usage(self):
         model = self.model_cls.from_single_file(
             self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
         )
-        model.to("cuda")
+        model.to(torch_device)
         assert (model.get_memory_footprint() / 1024**3) < self.expected_memory_use_in_gb
         inputs = self.get_dummy_inputs()

-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
+        backend_reset_peak_memory_stats(torch_device)
+        backend_empty_cache(torch_device)
         with torch.no_grad():
             model(**inputs)
-        max_memory = torch.cuda.max_memory_allocated()
+        max_memory = backend_max_memory_allocated(torch_device)
         assert (max_memory / 1024**3) < self.expected_memory_use_in_gb

     def test_keep_modules_in_fp32(self):
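The backend_* helpers keep the memory bookkeeping device-agnostic by dispatching to the torch submodule that matches torch_device. A simplified sketch of that dispatch idea (not the actual diffusers.utils.testing_utils implementation) is:

import torch

def backend_max_memory_allocated_sketch(device: str) -> int:
    # Dispatch on the device type string: torch.cuda for "cuda", torch.xpu for
    # "xpu", and so on, instead of hard-coding torch.cuda.*.
    device_type = torch.device(device).type
    backend = getattr(torch, device_type, None)
    if backend is None or not hasattr(backend, "max_memory_allocated"):
        raise RuntimeError(f"no peak-memory statistics available for {device_type!r}")
    return backend.max_memory_allocated()

# e.g. backend_max_memory_allocated_sketch("xpu") on an Intel GPU machine,
#      backend_max_memory_allocated_sketch("cuda:0") on an NVIDIA one.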
@@ -106,7 +113,8 @@ def test_dtype_assignment(self):

         with self.assertRaises(ValueError):
             # Tries with a `device` and `dtype`
-            model.to(device="cuda:0", dtype=torch.float16)
+            device_0 = f"{torch_device}:0"
+            model.to(device=device_0, dtype=torch.float16)

         with self.assertRaises(ValueError):
             # Tries with a cast
@@ -117,7 +125,7 @@
             model.half()

         # This should work
-        model.to("cuda")
+        model.to(torch_device)

     def test_dequantize_model(self):
         quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
@@ -146,11 +154,11 @@ class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):

     def setUp(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_inputs(self):
         return {
@@ -233,11 +241,11 @@ class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):

     def setUp(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_inputs(self):
         return {
@@ -267,40 +275,79 @@ def test_pipeline_inference(self):

         prompt = "a cat holding a sign that says hello"
         output = pipe(
-            prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
+            prompt=prompt,
+            num_inference_steps=2,
+            generator=torch.Generator("cpu").manual_seed(0),
+            output_type="np",
         ).images[0]
         output_slice = output[:3, :3, :].flatten()
-        expected_slice = np.array(
-            [
-                0.17578125,
-                0.27539062,
-                0.27734375,
-                0.11914062,
-                0.26953125,
-                0.25390625,
-                0.109375,
-                0.25390625,
-                0.25,
-                0.15039062,
-                0.26171875,
-                0.28515625,
-                0.13671875,
-                0.27734375,
-                0.28515625,
-                0.12109375,
-                0.26757812,
-                0.265625,
-                0.16210938,
-                0.29882812,
-                0.28515625,
-                0.15625,
-                0.30664062,
-                0.27734375,
-                0.14648438,
-                0.29296875,
-                0.26953125,
-            ]
+        expected_slices = Expectations(
+            {
+                ("xpu", 3): np.array(
+                    [
+                        0.19335938,
+                        0.3125,
+                        0.3203125,
+                        0.1328125,
+                        0.3046875,
+                        0.296875,
+                        0.11914062,
+                        0.2890625,
+                        0.2890625,
+                        0.16796875,
+                        0.30273438,
+                        0.33203125,
+                        0.14648438,
+                        0.31640625,
+                        0.33007812,
+                        0.12890625,
+                        0.3046875,
+                        0.30859375,
+                        0.17773438,
+                        0.33789062,
+                        0.33203125,
+                        0.16796875,
+                        0.34570312,
+                        0.32421875,
+                        0.15625,
+                        0.33203125,
+                        0.31445312,
+                    ]
+                ),
+                ("cuda", 7): np.array(
+                    [
+                        0.17578125,
+                        0.27539062,
+                        0.27734375,
+                        0.11914062,
+                        0.26953125,
+                        0.25390625,
+                        0.109375,
+                        0.25390625,
+                        0.25,
+                        0.15039062,
+                        0.26171875,
+                        0.28515625,
+                        0.13671875,
+                        0.27734375,
+                        0.28515625,
+                        0.12109375,
+                        0.26757812,
+                        0.265625,
+                        0.16210938,
+                        0.29882812,
+                        0.28515625,
+                        0.15625,
+                        0.30664062,
+                        0.27734375,
+                        0.14648438,
+                        0.29296875,
+                        0.26953125,
+                    ]
+                ),
+            }
         )
+        expected_slice = expected_slices.get_expectation()
         max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
         assert max_diff < 1e-4

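Expectations lets one test carry per-backend reference values: entries are keyed by a (device type, version bucket) tuple and get_expectation() returns the one matching the machine the test runs on. A simplified stand-in that illustrates the idea (with the lookup key passed explicitly rather than auto-detected, unlike the real helper) could be:

import numpy as np

class ExpectationsSketch(dict):
    # Simplified stand-in for the testing_utils helper: values are keyed by
    # (device_type, version_bucket) tuples.
    def get_expectation(self, device_type: str, version_bucket: int):
        if (device_type, version_bucket) in self:
            return self[(device_type, version_bucket)]
        for (dev, _), value in self.items():  # fall back to same device type
            if dev == device_type:
                return value
        raise KeyError(f"no expectation registered for {device_type!r}")

# Hypothetical usage mirroring the test above (values truncated):
expected = ExpectationsSketch({
    ("xpu", 3): np.array([0.19335938, 0.3125, 0.3203125]),
    ("cuda", 7): np.array([0.17578125, 0.27539062, 0.27734375]),
})
reference = expected.get_expectation("cuda", 7)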
@@ -313,11 +360,11 @@ class SD35MediumGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):

     def setUp(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_inputs(self):
         return {
@@ -393,11 +440,11 @@ class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):

     def setUp(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_inputs(self):
         return {
@@ -463,7 +510,7 @@ def test_pipeline_inference(self):

 @require_peft_backend
 @nightly
-@require_big_gpu_with_torch_cuda
+@require_big_accelerator
 @require_accelerate
 @require_gguf_version_greater_or_equal("0.10.0")
 class FluxControlLoRAGGUFTests(unittest.TestCase):
Expand All @@ -478,7 +525,7 @@ def test_lora_loading(self):
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
).to(torch_device)
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."