TorchAO compile + offloading tests #11697

Merged (8 commits) on Jun 27, 2025
26 changes: 15 additions & 11 deletions tests/quantization/bnb/test_4bit.py
@@ -866,15 +866,17 @@ def test_fp4_double_safe(self):

@require_torch_version_greater("2.7.1")
class Bnb4BitCompileTests(QuantCompileTests):
-    quantization_config = PipelineQuantizationConfig(
-        quant_backend="bitsandbytes_8bit",
-        quant_kwargs={
-            "load_in_4bit": True,
-            "bnb_4bit_quant_type": "nf4",
-            "bnb_4bit_compute_dtype": torch.bfloat16,
-        },
-        components_to_quantize=["transformer", "text_encoder_2"],
-    )
+    @property
+    def quantization_config(self):
+        return PipelineQuantizationConfig(
+            quant_backend="bitsandbytes_8bit",
+            quant_kwargs={
+                "load_in_4bit": True,
+                "bnb_4bit_quant_type": "nf4",
+                "bnb_4bit_compute_dtype": torch.bfloat16,
+            },
+            components_to_quantize=["transformer", "text_encoder_2"],
+        )

def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -883,5 +885,7 @@ def test_torch_compile(self):
def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

-    def test_torch_compile_with_group_offload(self):
-        super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
+    def test_torch_compile_with_group_offload_leaf(self):
+        super()._test_torch_compile_with_group_offload_leaf(
+            quantization_config=self.quantization_config, use_stream=True
+        )
18 changes: 10 additions & 8 deletions tests/quantization/bnb/test_mixed_int8.py
@@ -831,11 +831,13 @@ def test_serialization_sharded(self):

@require_torch_version_greater_equal("2.6.0")
class Bnb8BitCompileTests(QuantCompileTests):
-    quantization_config = PipelineQuantizationConfig(
-        quant_backend="bitsandbytes_8bit",
-        quant_kwargs={"load_in_8bit": True},
-        components_to_quantize=["transformer", "text_encoder_2"],
-    )
+    @property
+    def quantization_config(self):
+        return PipelineQuantizationConfig(
+            quant_backend="bitsandbytes_8bit",
+            quant_kwargs={"load_in_8bit": True},
+            components_to_quantize=["transformer", "text_encoder_2"],
+        )

def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -847,7 +849,7 @@ def test_torch_compile_with_cpu_offload(self):
)

@pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
-    def test_torch_compile_with_group_offload(self):
-        super()._test_torch_compile_with_group_offload(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16
+    def test_torch_compile_with_group_offload_leaf(self):
+        super()._test_torch_compile_with_group_offload_leaf(
+            quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
)
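As context for readers: the `quant_backend`/`quant_kwargs` form of `PipelineQuantizationConfig` used by both bnb test classes applies a single backend to every component named in `components_to_quantize`. A minimal sketch of how such a config is consumed, mirroring `Bnb8BitCompileTests` above — the checkpoint id is an assumption for illustration, the tests build their pipeline through their own `_init_pipeline` helper:

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# One backend applied to every listed component, as in Bnb8BitCompileTests.
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_8bit",
    quant_kwargs={"load_in_8bit": True},
    components_to_quantize=["transformer", "text_encoder_2"],
)

# Hypothetical checkpoint id; the tests construct their pipeline elsewhere.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)
pipe.transformer.compile()
```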
13 changes: 9 additions & 4 deletions tests/quantization/test_torch_compile_utils.py
@@ -24,7 +24,11 @@
@require_torch_gpu
@slow
class QuantCompileTests(unittest.TestCase):
-    quantization_config = None
+    @property
+    def quantization_config(self):
+        raise NotImplementedError(
+            "This property should be implemented in the subclass to return the appropriate quantization config."
+        )
Member comment on lines +27 to +31: 👌
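The switch from a class attribute to a property is the core refactor here: a class-level config object is shared by every test (and every subclass), so any in-place mutation leaks into later tests, whereas a property rebuilds the config on each access. A minimal, self-contained illustration of that design choice (not from the PR):

```python
class SharedConfig:
    config = {"load_in_8bit": True}  # one dict object shared by every access


class FreshConfig:
    @property
    def config(self):
        return {"load_in_8bit": True}  # a new dict on every access


shared = SharedConfig()
shared.config["load_in_8bit"] = False          # mutates the shared object
assert SharedConfig.config["load_in_8bit"] is False

fresh = FreshConfig()
fresh.config["load_in_8bit"] = False           # mutates a throwaway copy
assert fresh.config["load_in_8bit"] is True    # the next access is unaffected
```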
def setUp(self):
super().setUp()
@@ -64,16 +68,17 @@ def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

-    def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16):
+    def _test_torch_compile_with_group_offload_leaf(
+        self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
+    ):
torch._dynamo.config.cache_size_limit = 10000

pipe = self._init_pipeline(quantization_config, torch_dtype)
group_offload_kwargs = {
"onload_device": torch.device("cuda"),
"offload_device": torch.device("cpu"),
"offload_type": "leaf_level",
"use_stream": True,
"non_blocking": True,
"use_stream": use_stream,
}
pipe.transformer.enable_group_offload(**group_offload_kwargs)
pipe.transformer.compile()
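For reference, the helper above maps directly onto the public `enable_group_offload` API. A minimal sketch of leaf-level group offloading with a transfer stream — the model class and checkpoint id are assumptions for illustration:

```python
import torch
from diffusers import FluxTransformer2DModel

# Hypothetical checkpoint; the tests build their pipeline elsewhere.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Leaf-level offloading keeps each leaf module on the offload device and moves
# it to the onload device just before its forward pass; use_stream=True moves
# weights on a separate CUDA stream so transfers can overlap with compute.
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    non_blocking=True,
)
transformer.compile()
```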
51 changes: 51 additions & 0 deletions tests/quantization/torchao/test_torchao.py
@@ -19,6 +19,7 @@
from typing import List

import numpy as np
+from parameterized import parameterized
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

from diffusers import (
@@ -29,6 +30,7 @@
TorchAoConfig,
)
from diffusers.models.attention_processor import Attention
+from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils.testing_utils import (
backend_empty_cache,
backend_synchronize,
@@ -44,6 +46,8 @@
torch_device,
)

+from ..test_torch_compile_utils import QuantCompileTests


enable_full_determinism()

@@ -625,6 +629,53 @@ def test_int_a16w8_cpu(self):
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)


@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoCompileTest(QuantCompileTests):
@property
def quantization_config(self):
return PipelineQuantizationConfig(
quant_mapping={
"transformer": TorchAoConfig(quant_type="int8_weight_only"),
},
)

def test_torch_compile(self):
super()._test_torch_compile(quantization_config=self.quantization_config)

@unittest.skip(
"Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
"when compiling."
)
def test_torch_compile_with_cpu_offload(self):
# RuntimeError: _apply(): Couldn't swap Linear.weight
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

@unittest.skip(
"""
For `use_stream=False`:
- Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation
is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure.
For `use_stream=True`:
Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
"""
)
@parameterized.expand([False, True])
def test_torch_compile_with_group_offload_leaf(self):
# For use_stream=False:
# If we run group offloading without compilation, we will see:
# RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
# When running with compilation, the error ends up being different:
# Dynamo failed to run FX node with fake tensors: call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
# requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
# Looks like something that will have to be looked into upstream.
# for linear layers, weight.tensor_impl shows cuda... but:
# weight.tensor_impl.{data,scale,zero_point}.device will be cpu

# For use_stream=True:
# NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)


# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
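For reference, the `quant_mapping` form used by `TorchAoCompileTest` assigns an explicit per-component config instead of one backend plus `components_to_quantize`. A minimal end-to-end sketch — the checkpoint id is an assumption for illustration, and the prompt/arguments mirror the compile-test helper above:

```python
import torch
from diffusers import DiffusionPipeline, TorchAoConfig
from diffusers.quantizers import PipelineQuantizationConfig

# Per-component mapping: only the transformer is quantized (int8 weight-only).
quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig(quant_type="int8_weight_only")},
)

# Hypothetical checkpoint id; the tests build their pipeline via _init_pipeline.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")

pipe.transformer.compile()
# Small resolution and few steps, as in the tests, to keep execution speedy.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
```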