
Commit 6f035e8

[CPU] Introduce Int4OpaqueTensor to replace Int4CPULayout in AQT (#2798)
* [CPU] Introduce Int4WoqCpuTensor to replace Int4CPULayout in AQT
* refine code
* refine code
* Refine code
* Update UT
* Rename tensor & format to opaque
* Rename OpaqueTensor -> Int4OpaqueTensor
1 parent 9056c46 commit 6f035e8
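
For context, a minimal end-to-end usage sketch of the new "opaque" packing format, mirroring the test added in this commit (the layer sizes, dtype, and group size below are illustrative, not part of the change):

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.utils import compute_error

# CPU-only: quantize the weight of a Linear layer to int4 with the new opaque packing
linear = torch.nn.Linear(512, 256, dtype=torch.bfloat16, device="cpu")
x = torch.randn(32, 512, dtype=torch.bfloat16)
ref = linear(x)

quantize_(linear, Int4WeightOnlyConfig(group_size=128, packing_format="opaque", version=2))
out = linear(x)
assert compute_error(ref, out) > 20  # SQNR in dB, same threshold as the new tests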

File tree

5 files changed: +294 -0 lines changed

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torchao.quantization import (
    Int4WeightOnlyConfig,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
    torch_version_at_least,
)


def get_config(group_size):
    return Int4WeightOnlyConfig(
        group_size=group_size,
        packing_format="opaque",
        version=2,
    )


@unittest.skipIf(not torch_version_at_least("2.6.0"), "Need pytorch 2.6+")
class TestInt4OpaqueTensor(TestCase):
    @parametrize(
        "sizes",
        [
            ((128,), 256, 128),
            ((32, 128), 512, 128),
            ((2, 32, 128), 256, 12),
        ],
    )
    @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
    @parametrize("group_size", [32, 64, 128])
    def test_linear(self, sizes, dtype, group_size):
        device = "cpu"
        M, N, K = sizes
        input = torch.randn(*M, K, dtype=dtype, device=device)
        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
        original = linear(input)
        quantize_(linear, get_config(group_size))
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

        compiled_linear = torch.compile(linear)
        quantized_and_compiled = compiled_linear(input)
        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

    @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
    def test_module_path(self, dtype):
        linear = torch.nn.Linear(128, 256, dtype=dtype)
        quantize_(linear, get_config(group_size=128))
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Int4OpaqueTensor'>",
        )

        with tempfile.NamedTemporaryFile() as f:
            torch.save(linear.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
            self.assertEqual(
                str(type(state_dict["weight"])),
                "<class 'torchao.quantization.Int4OpaqueTensor'>",
            )


instantiate_parametrized_tests(TestInt4OpaqueTensor)


if __name__ == "__main__":
    run_tests()
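
The accuracy checks above use torchao's compute_error helper, which reports a signal-to-quantization-noise ratio (SQNR) in dB; the tests require > 20 dB between the float and quantized outputs. A rough, hypothetical stand-in for readers unfamiliar with the helper (sqnr_db is illustrative, not part of this commit):

import torch

def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: 20 * log10(||ref|| / ||ref - approx||); higher means closer to the reference
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - approx))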

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -91,6 +91,7 @@
 from .quantize_.workflows import (
     Float8Tensor,
     Int4MarlinSparseTensor,
+    Int4OpaqueTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
     IntxUnpackedToInt8Tensor,
@@ -164,6 +165,7 @@
     "Int4MarlinSparseTensor",
     "IntxUnpackedToInt8Tensor",
     "Float8Tensor",
+    "Int4OpaqueTensor",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",

torchao/quantization/quant_api.py

Lines changed: 7 additions & 0 deletions
@@ -73,6 +73,7 @@
 from torchao.quantization.quantize_.workflows import (
     Float8Tensor,
     Int4MarlinSparseTensor,
+    Int4OpaqueTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
     IntxUnpackedToInt8Tensor,
@@ -1120,6 +1121,12 @@ def _int4_weight_only_quantize_tensor(weight, config):
             block_size,
         )
         return new_weight
+    elif packing_format == PackingFormat.OPAQUE:
+        new_weight = Int4OpaqueTensor.from_hp(
+            weight,
+            block_size,
+        )
+        return new_weight
     else:
         raise ValueError(f"Unsupported packing format: {packing_format}")
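
The new branch simply forwards the weight and block_size to Int4OpaqueTensor.from_hp. A hedged sketch of that conversion invoked directly (the weight shape and group size are illustrative):

import torch
from torchao.quantization import Int4OpaqueTensor

group_size = 128
weight = torch.randn(256, 512, dtype=torch.bfloat16)  # (N, K) high-precision weight on CPU
# block_size (1, group_size) requests groupwise quantization along K, as from_hp expects
qweight = Int4OpaqueTensor.from_hp(weight, [1, group_size])
print(qweight.shape, qweight.block_size, qweight.dtype)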

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -5,6 +5,9 @@
 from .int4.int4_marlin_sparse_tensor import (
     Int4MarlinSparseTensor,
 )
+from .int4.int4_opaque_tensor import (
+    Int4OpaqueTensor,
+)
 from .int4.int4_preshuffled_tensor import (
     Int4PreshuffledTensor,
 )
@@ -21,5 +24,7 @@
     "Int4MarlinSparseTensor",
     "Float8Tensor",
     "QuantizeTensorToFloat8Kwargs",
+    "Int4OpaqueTensor",
+    "IntxUnpackedTensor",
     "IntxUnpackedToInt8Tensor",
 ]
torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py

Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


from typing import List

import torch

from torchao.quantization.quant_primitives import (
    MappingType,
    _choose_qparams_affine_tinygemm,
    _quantize_affine_tinygemm,
)
from torchao.utils import (
    TorchAOBaseTensor,
)

__all__ = [
    "Int4OpaqueTensor",
]

aten = torch.ops.aten


class Int4OpaqueTensor(TorchAOBaseTensor):
    """
    int4 weight-only quantization on CPU with tinygemm (groupwise quantization only). The packing format is determined by the ISA and the shape.
    This is an opaque tensor subclass; the packing format is not exposed to the rest of the system. See the note below for more details.

    Tensor Attributes:
        qdata: preshuffled and packed int4 weight for the CPU tinygemm kernel, always viewed as a 2D (N, K/2) tensor with the last dimension packed.
            The preshuffling is specific to CPU kernels and depends on ISA and shape; see the note below.
        scale_and_zero: (K/group_size, N, 2), dtype is the same as the original tensor dtype

    Non-Tensor Attributes:
        block_size: the block size for quantization, representing the granularity. For groupwise quantization, block_size is (1, group_size).
            Only group_size = 32/64/128 is supported.
        shape: shape of the original tensor

    Note on the data layout for the CPU tinygemm kernel:

    We use AVX512 to compute tinygemm on CPU. We can also leverage AVX512_VNNI and AMX instructions with torch.compile and max-autotune.
    For data locality, we preshuffle the data from the plain layout (N, K/2) to (N/block_n, K, block_n/2), where block_n = 64/32/16.
    See https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L583 for more details.
    """

    tensor_data_names = ["qdata", "scale_and_zero"]
    tensor_attribute_names = ["block_size", "shape"]

    def __new__(
        cls,
        qdata,
        scale_and_zero,
        block_size,
        shape,
    ):
        kwargs = {}
        kwargs["device"] = qdata.device
        kwargs["dtype"] = scale_and_zero.dtype
        kwargs["requires_grad"] = False
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        qdata: torch.Tensor,
        scale_and_zero: torch.Tensor,
        block_size: List[int],
        shape: torch.Size,
    ):
        self.qdata = qdata
        self.scale_and_zero = scale_and_zero
        self.block_size = block_size

    def _quantization_type(self):
        return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"

    @classmethod
    def from_hp(
        cls,
        w: torch.Tensor,
        block_size: List[int],
    ):
        assert w.ndim == 2 and w.device.type == "cpu", (
            f"Expecting 2D tensor on CPU, but got: {w.shape} on {w.device.type}"
        )
        assert len(block_size) == w.ndim
        assert block_size[0] == 1 and block_size[1] in (32, 64, 128), (
            f"Expecting groupwise quantization with group size = 32/64/128, but got block_size: {block_size}"
        )
        original_shape = w.shape
        mapping_type = MappingType.ASYMMETRIC
        target_dtype = torch.int32
        quant_min = 0
        quant_max = 15
        eps = 1e-6
        scale_dtype = None
        zero_point_dtype = w.dtype
        scale, zero_point = _choose_qparams_affine_tinygemm(
            w,
            mapping_type,
            block_size,
            target_dtype,
            quant_min,
            quant_max,
            eps,
            scale_dtype,
            zero_point_dtype,
        )
        int_data = _quantize_affine_tinygemm(
            w,
            block_size,
            scale,
            zero_point,
            target_dtype,
            quant_min,
            quant_max,
        )
        assert int_data.dtype == torch.int32, (
            "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
        )
        packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
            int_data,
            1,  # innerKTiles is not needed for CPU
        )

        scale = scale.reshape(int_data.shape[0], -1)
        zero_point = zero_point.reshape(int_data.shape[0], -1)
        from torchao.quantization.utils import pack_tinygemm_scales_and_zeros

        scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype)
        return Int4OpaqueTensor(
            qdata=packed_weight,
            scale_and_zero=scale_and_zero,
            block_size=block_size,
            shape=original_shape,
        )


implements = Int4OpaqueTensor.implements


@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    assert input_tensor.device.type == "cpu", (
        f"For CPU device only but got: {input_tensor.device}"
    )
    assert isinstance(weight_tensor, Int4OpaqueTensor), (
        f"Expected weight_tensor to be Int4OpaqueTensor, got: {type(weight_tensor)}"
    )
    assert weight_tensor.block_size[0] == 1, (
        f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
    )
    assert input_tensor.shape[-1] == weight_tensor.shape[1], (
        f"Shapes of input and weight do not match, input: {input_tensor.shape}, weight: {weight_tensor.shape}"
    )

    act_mat = input_tensor
    packed_weight = weight_tensor.qdata
    scale_and_zero = weight_tensor.scale_and_zero

    orig_act_size = act_mat.size()
    orig_dtype = act_mat.dtype

    # reshape to 2D
    act_mat = act_mat.reshape(-1, act_mat.shape[-1])

    # groupwise int4 quantization
    groupsize = weight_tensor.block_size[1]
    y = torch.ops.aten._weight_int4pack_mm_for_cpu(
        act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
    )

    # remove out_feature padding
    assert weight_tensor.ndim == 2
    orig_out_features = weight_tensor.shape[-2]
    y = y[:, :orig_out_features]
    y = y.reshape(*orig_act_size[:-1], orig_out_features)

    if bias is not None:
        y += bias
    return y.to(orig_dtype)


Int4OpaqueTensor.__module__ = "torchao.quantization"

# Allow a model with Int4OpaqueTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Int4OpaqueTensor])
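
Since the class is registered via torch.serialization.add_safe_globals, a state_dict containing Int4OpaqueTensor weights should also load with an explicit weights_only=True. A small sketch mirroring test_module_path above (shapes and dtype are illustrative):

import tempfile

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, Int4WeightOnlyConfig(group_size=128, packing_format="opaque", version=2))

with tempfile.NamedTemporaryFile() as f:
    torch.save(linear.state_dict(), f)
    f.seek(0)
    # expected to work because Int4OpaqueTensor is an allowed safe global
    state_dict = torch.load(f, weights_only=True)
print(type(state_dict["weight"]))  # torchao.quantization.Int4OpaqueTensor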
