@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torchao.quantization import (
    Int4WeightOnlyConfig,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
    torch_version_at_least,
)


def get_config(group_size):
    return Int4WeightOnlyConfig(
        group_size=group_size,
        packing_format="plain_int32",
        version=2,
    )


@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
class Int4PlainInt32Tensor(TestCase):
    @parametrize(
        "sizes",
        [
            ((128,), 256, 128),
            ((32, 128), 512, 128),
            ((2, 32, 128), 256, 12),
        ],
    )
    @parametrize("dtype", [torch.bfloat16, torch.half])
    @parametrize("group_size", [32, 64, 128])
    def test_linear(self, sizes, dtype, group_size):
        device = "xpu"
        M, N, K = sizes
        input = torch.randn(*M, K, dtype=dtype, device=device)
        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
        original = linear(input)
        quantize_(linear, get_config(group_size))
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

        compiled_linear = torch.compile(linear)
        quantized_and_compiled = compiled_linear(input)
        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

    @parametrize("dtype", [torch.bfloat16, torch.half])
    def test_module_path(self, dtype):
        linear = torch.nn.Linear(128, 256, dtype=dtype, device="xpu")
        quantize_(linear, get_config(group_size=128))
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
        )

        with tempfile.NamedTemporaryFile() as f:
            torch.save(linear.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
            self.assertEqual(
                str(type(state_dict["weight"])),
                "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
            )


instantiate_parametrized_tests(Int4PlainInt32Tensor)


if __name__ == "__main__":
    run_tests()
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -92,6 +92,7 @@
    Float8Tensor,
    Int4MarlinSparseTensor,
    Int4OpaqueTensor,
    Int4PlainInt32Tensor,
    Int4PreshuffledTensor,
    Int4Tensor,
    IntxOpaqueTensor,
@@ -162,6 +163,7 @@
"FbgemmConfig",
# tensor subclasses
"Int4Tensor",
"Int4PlainInt32Tensor",
"Int4PreshuffledTensor",
"Int4MarlinSparseTensor",
"IntxOpaqueTensor",
8 changes: 7 additions & 1 deletion torchao/quantization/quant_api.py
@@ -74,6 +74,7 @@
    Float8Tensor,
    Int4MarlinSparseTensor,
    Int4OpaqueTensor,
    Int4PlainInt32Tensor,
    Int4PreshuffledTensor,
    Int4Tensor,
    IntxOpaqueTensor,
@@ -521,7 +522,6 @@ def quantize_(
    torch._C._log_api_usage_once("torchao.quantization.quantize_")

    filter_fn = _is_linear if filter_fn is None else filter_fn

    if isinstance(config, ModuleFqnToConfig):
        _replace_with_custom_fn_if_matches_filter_with_name(
            model,
@@ -1130,6 +1130,12 @@ def _int4_weight_only_quantize_tensor(weight, config):
            block_size,
        )
        return new_weight
    elif packing_format == PackingFormat.PLAIN_INT32:
        new_weight = Int4PlainInt32Tensor.from_hp(
            weight,
            block_size,
        )
        return new_weight
    elif packing_format == PackingFormat.MARLIN_SPARSE:
        new_weight = Int4MarlinSparseTensor.from_hp(
            weight,
6 changes: 6 additions & 0 deletions torchao/quantization/quantize_/common/packing_format.py
@@ -41,6 +41,12 @@ class PackingFormat(str, Enum):
"""
UNPACKED_TO_INT8 = "unpacked_to_int8"

"""
plain_int32 is referring to the format used by int4 weight-only quantization.
which is a groupwise quantization format 2*int4 is store in a byte and 4*(int4*2) is stored in a int32.
"""
PLAIN_INT32 = "plain_int32"

"""
Opaque packing format that's used for tensors that does not have a predefined packing format
(that may be decided on hardware, tensor shape, library availability etc.) and it's not
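To make the packing described in the docstring above concrete, here is a minimal, self-contained sketch (CPU-only, with made-up values) of the byte-level step; the final tiling of four such bytes into each int32 is done by torch.ops.aten._convert_weight_to_int4pack in the new tensor subclass below:

import torch

# eight int4 values (range 0..15) along the K dimension of one output row
int4_vals = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.int32)

# 2*int4 per byte: odd columns go to the high nibble, even columns to the low nibble
packed_bytes = (int4_vals[:, 1::2] << 4 | int4_vals[:, ::2]).to(torch.uint8)
print(packed_bytes)  # tensor([[ 33,  67, 101, 135]], dtype=torch.uint8)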
4 changes: 4 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
@@ -8,6 +8,9 @@
from .int4.int4_opaque_tensor import (
    Int4OpaqueTensor,
)
from .int4.int4_plain_int32_tensor import (
    Int4PlainInt32Tensor,
)
from .int4.int4_preshuffled_tensor import (
    Int4PreshuffledTensor,
)
@@ -25,6 +28,7 @@
"Int4Tensor",
"Int4PreshuffledTensor",
"Int4MarlinSparseTensor",
"Int4PlainInt32Tensor",
"Float8Tensor",
"QuantizeTensorToFloat8Kwargs",
"IntxOpaqueTensor",
2 changes: 2 additions & 0 deletions torchao/quantization/quantize_/workflows/int4/__init__.py
@@ -1,7 +1,9 @@
from .int4_plain_int32_tensor import Int4PlainInt32Tensor
from .int4_preshuffled_tensor import Int4PreshuffledTensor
from .int4_tensor import Int4Tensor

__all__ = [
"Int4PreshuffledTensor",
"Int4Tensor",
"Int4PlainInt32Tensor",
]
@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


from typing import List

import torch

from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)
from torchao.utils import (
    TorchAOBaseTensor,
)

__all__ = [
    "Int4PlainInt32Tensor",
]

aten = torch.ops.aten


class Int4PlainInt32Tensor(TorchAOBaseTensor):
"""
int4 weight-only quantization on XPU with oneDNN as backend (groupwise quantization only)

Tensor Attributes:
qdata: (N, K/8), packed int4 weight, the data type is int32 here with 4*(int4*2)
scale: (K/group_size, N), dtype is the same as the original Tensor dtype
zero_point: (K/group_size, N), dtype is int8

Non-Tensor Attributes:
block_size: the block size for quantization, representing the granularity.
shape: shape of the original Tensor

"""

tensor_data_names = ["qdata", "scale", "zero_point"]
tensor_attribute_names = ["block_size", "shape"]

    def __new__(
        cls,
        qdata,
        scale,
        zero_point,
        block_size,
        shape,
    ):
        kwargs = {}
        kwargs["device"] = qdata.device
        kwargs["dtype"] = scale.dtype
        kwargs["requires_grad"] = False
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(self, qdata, scale, zero_point, block_size, shape):
        self.qdata = qdata
        self.scale = scale
        self.zero_point = zero_point
        self.block_size = block_size

    def _quantization_type(self):
        return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"

    @classmethod
    def from_hp(
        cls,
        w: torch.Tensor,
        block_size: List[int],
    ):
        assert w.ndim == 2 and w.device.type == "xpu", (
            f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}"
        )
        assert len(block_size) == w.ndim

        original_shape = w.shape
        mapping_type = MappingType.ASYMMETRIC
        target_dtype = torch.int32
        quant_min = 0
        quant_max = 15
        eps = 1e-6
        scale_dtype = None
        zero_point_dtype = torch.int32
        scale, zero_point = choose_qparams_affine(
            w,
            mapping_type,
            block_size,
            target_dtype,
            quant_min,
            quant_max,
            eps,
            scale_dtype,
            zero_point_dtype,
        )
        int_data = quantize_affine(
            w,
            block_size,
            scale,
            zero_point,
            target_dtype,
            quant_min,
            quant_max,
        )
        assert int_data.dtype == torch.int32, (
            "torch.ops.aten._convert_weight_to_int4pack expects `int32` dtype"
        )
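        # pack two adjacent int4 values along K into one uint8 (even columns in the
        # low nibble, odd columns in the high nibble); the aten op below then tiles
        # the bytes into the int32 layout consumed by
        # torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros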
        packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
        packed_weight = torch.ops.aten._convert_weight_to_int4pack(
            packed_weight.contiguous(), 8
        )
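        # reshape the quantization parameters to (N, K/group_size); the constructor
        # call below transposes them to the (K/group_size, N) layout documented on
        # the class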
        scale = scale.reshape(int_data.shape[0], -1)
        zero_point = zero_point.reshape(int_data.shape[0], -1)
        return Int4PlainInt32Tensor(
            packed_weight,
            scale.transpose(0, 1).contiguous(),
            zero_point.transpose(0, 1).contiguous().to(torch.int8),
            block_size,
            original_shape,
        )


implements = Int4PlainInt32Tensor.implements


@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    assert input_tensor.device.type == "xpu", (
        f"For XPU device only but got: {input_tensor.device}"
    )
    assert isinstance(weight_tensor, Int4PlainInt32Tensor), (
        f"Expected weight_tensor to be Int4PlainInt32Tensor, got: {type(weight_tensor)}"
    )
    assert weight_tensor.block_size[0] == 1, (
        f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
    )
    assert input_tensor.shape[-1] == weight_tensor.shape[1], (
        f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}"
    )

    act_mat = input_tensor
    packed_weight = weight_tensor.qdata
    scale = weight_tensor.scale
    zero_point = weight_tensor.zero_point

    orig_act_size = act_mat.size()
    orig_dtype = act_mat.dtype

    # reshape to 2D
    act_mat = act_mat.reshape(-1, act_mat.shape[-1])

    # groupwise int4 quantization
    groupsize = weight_tensor.block_size[1]
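    # scale and zero_point are stored as (K/group_size, N), matching the packed weight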
    y = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
        act_mat, packed_weight, groupsize, scale, zero_point
    )

    # remove out_feature padding
    assert weight_tensor.ndim == 2
    orig_out_features = weight_tensor.shape[-2]
    y = y[:, :orig_out_features]
    y = y.reshape(*orig_act_size[:-1], orig_out_features)

    if bias is not None:
        y += bias
    return y.to(orig_dtype)


Int4PlainInt32Tensor.__module__ = "torchao.quantization"

# Allow a model with Int4PlainInt32Tensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Int4PlainInt32Tensor])
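
For reference, a minimal usage sketch mirroring the new test above; it assumes PyTorch 2.8+ with an available XPU device, and the layer shapes and group size here are illustrative only:

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="xpu")

# groupwise int4 weight-only quantization with the new plain_int32 packing format
quantize_(linear, Int4WeightOnlyConfig(group_size=64, packing_format="plain_int32", version=2))

x = torch.randn(32, 128, dtype=torch.bfloat16, device="xpu")
y = linear(x)                          # eager path
y_compiled = torch.compile(linear)(x)  # also exercised by the test under torch.compile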