Skip to content

Commit a832bf0

Browse files
committed
[reland] Refactor TorchAOBaseTensor for better BC (#2793)
Summary: After this PR, tensors inheriting from TorchAOBaseTensor will have better BC (backward compatibility) support; that is, if they add an optional tensor data attribute or an optional non-tensor attribute, we will still have BC without any additional changes. More Details: The BC story we are looking at is that, after we land some tensor, e.g. Int4Tensor or Float8Tensor, future changes should only add optional Tensor data attributes and optional non-Tensor attributes to the Tensor (other, bigger changes will require a version bump; we need to add that too). The current TorchAOBaseTensor doesn’t support this very well. Also see #2840 for a real test that adds both an optional tensor and an optional non-tensor attribute to Float8Tensor, and the BC test in https://github.com/pytorch/ao/blob/main/test/integration/test_load_and_run_checkpoint.py that checks that Float8Tensor does not fail. Docs for current TorchAOBaseTensor: https://github.com/pytorch/ao/blob/e6b38bb0e1477ae6aaca0a3d30de70598be43290/torchao/utils.py#L726-L731 `tensor_data_names` (List[str]): list of names of all required tensor_data; order should match the `__init__` list of the tensor subclass. `optional_tensor_data_names` (List[str]): it's optional to define this field to have the additional boilerplate functions implemented for you, but it will be needed if there are some optional Tensor attributes; when defined, this will be a list of names of Tensors that can be optional. `tensor_attribute_names` (List[str]): list of names of non-Tensor attributes; order should match the `__init__` list of the tensor subclass, following all the `tensor_data_names` arguments and `optional_tensor_data_names`. Problems: the current optional_tensor_data_names is not truly optional, since it is followed by tensor_attribute_names, which contains both required and optional attributes. So if we add a tensor data attribute to a Tensor, it will break BC. 
Here are a few options: ``` class Int4Tensor(TorchAOBaseTensor): tensor_data_names = ["qdata", "scale", "zero_point"] optional_tensor_data_names = ["act_scale"] tensor_attribute_names = ["block_size", "shape", "_demo_only_optional_attr"] def __init__(self, qdata, scale, zero_point, act_scale=None, block_size=None, shape=None, _demo_only_optional_attr=None): ... # for BC def __setstate__(self, state): torch._utils._set_obj_state(self, state) if "act_scale" not in self.__dict__: self.act_scale = None ``` ``` class Int4Tensor(TorchAOBaseTensor): tensor_data_names = ["qdata", "scale", "zero_point"] optional_tensor_data_names = ["act_scale"] required_tensor_attribute_names = ["block_size", "shape"] optional_tensor_attribute_names = ["_demo_only_optional_attr"] def __init__(self, qdata, scale, zero_point, block_size, shape, act_scale=None, _demo_only_optional_attr = None): ... # for BC def __setstate__(self, state): torch._utils._set_obj_state(self, state) if "act_scale" not in self.__dict__: self.act_scale = None ``` ``` class Int4Tensor(TorchAOBaseTensor): tensor_data_names = ["qdata", "scale", "zero_point"] tensor_attribute_names = ["block_size", "shape", "_demo_only_optional_attr"] optional_tensor_data_names = ["act_scale"] def __init__(self, qdata, scale, zero_point, block_size, shape, _demo_only_optional_attr = None, act_scale = None): ... # for BC def __setstate__(self, state): torch._utils._set_obj_state(self, state) if "act_scale" not in self.__dict__: self.act_scale = None ``` Test Plan: python test/integration/test_load_and_run_checkpoint.py Reviewers: Subscribers: Tasks: Tags:
1 parent 27f4d75 commit a832bf0

File tree

6 files changed

+235
-90
lines changed

6 files changed

+235
-90
lines changed

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def test_nvfp4_swizzled_scales_serialization():
307307
tensor_list, ctx = original_tensor.__tensor_flatten__()
308308

309309
# Verify swizzled flag is preserved in context
310-
assert NVFP4Tensor.tensor_attribute_names[2] == "_is_swizzled_scales"
310+
assert NVFP4Tensor.optional_tensor_attribute_names[0] == "_is_swizzled_scales"
311311
assert ctx[2] == True
312312

313313
# Test deserialization

test/test_utils.py

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -186,60 +186,103 @@ class MyTensor(TorchAOBaseTensor):
186186
tensor_data_names = ["qdata"]
187187
tensor_attribute_names = ["attr", "device"]
188188

189-
def __new__(cls, qdata, attr, device=None):
189+
def __new__(cls, qdata, attr, device):
190190
shape = qdata.shape
191191
if device is None:
192192
device = qdata.device
193193
kwargs = {"device": device}
194194
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
195195

196-
def __init__(self, qdata, attr, device=None):
196+
def __init__(self, qdata, attr, device):
197197
self.qdata = qdata
198198
self.attr = attr
199199

200200
l = torch.nn.Linear(2, 3)
201-
l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
201+
l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr", None))
202202
lp_tensor = l.weight
203203

204204
another_tensor = torch.nn.Linear(2, 3).weight
205205
# attribute has to be the same
206-
lp_tensor_for_copy = MyTensor(another_tensor, "attr")
206+
lp_tensor_for_copy = MyTensor(another_tensor, "attr", None)
207207
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
208208

209209
@skip_if_no_cuda()
210210
def test_default_impls_with_optional_data(self):
211211
class MyTensorWithOptionalData(TorchAOBaseTensor):
212212
tensor_data_names = ["qdata"]
213-
optional_tensor_data_names = ["zero_point"]
214213
tensor_attribute_names = ["attr", "device"]
214+
optional_tensor_data_names = ["zero_point"]
215215

216-
def __new__(cls, qdata, zero_point=None, attr=1.0, device=None):
216+
def __new__(cls, qdata, attr, device, zero_point=None):
217217
shape = qdata.shape
218218
if device is None:
219219
device = qdata.device
220220
kwargs = {"device": device}
221221
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
222222

223-
def __init__(self, qdata, zero_point=None, attr=1.0, device=None):
223+
def __init__(self, qdata, attr, device, zero_point=None):
224224
self.qdata = qdata
225+
self.attr = attr
225226
self.zero_point = zero_point
227+
228+
# test both the optional Tensor is None
229+
# and not None
230+
l = torch.nn.Linear(2, 3)
231+
lp_tensor = MyTensorWithOptionalData(l.weight, "attr", None, None)
232+
l = torch.nn.Linear(2, 3)
233+
lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, "attr", None, None)
234+
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
235+
236+
l = torch.nn.Linear(2, 3)
237+
lp_tensor = MyTensorWithOptionalData(
238+
l.weight, "attr", None, torch.zeros_like(l.weight)
239+
)
240+
l = torch.nn.Linear(2, 3)
241+
lp_tensor_for_copy = MyTensorWithOptionalData(
242+
l.weight, "attr", None, torch.zeros_like(l.weight)
243+
)
244+
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
245+
246+
@skip_if_no_cuda()
247+
def test_default_impls_with_optional_attr(self):
248+
class MyTensorWithOptionalData(TorchAOBaseTensor):
249+
tensor_data_names = ["qdata"]
250+
tensor_attribute_names = ["attr", "device"]
251+
optional_tensor_data_names = ["zero_point"]
252+
optional_tensor_attribute_names = ["optional_attr"]
253+
254+
def __new__(cls, qdata, attr, device, zero_point=None, optional_attr=None):
255+
shape = qdata.shape
256+
if device is None:
257+
device = qdata.device
258+
kwargs = {"device": device}
259+
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
260+
261+
def __init__(
262+
self, qdata, attr, device, zero_point=None, optional_attr=None
263+
):
264+
self.qdata = qdata
226265
self.attr = attr
266+
self.zero_point = zero_point
267+
self.optional_attr = optional_attr
227268

228269
# test both the optional Tensor is None
229270
# and not None
230271
l = torch.nn.Linear(2, 3)
231-
lp_tensor = MyTensorWithOptionalData(l.weight, None, "attr")
272+
lp_tensor = MyTensorWithOptionalData(l.weight, "attr", None, zero_point=None)
232273
l = torch.nn.Linear(2, 3)
233-
lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, None, "attr")
274+
lp_tensor_for_copy = MyTensorWithOptionalData(
275+
l.weight, "attr", None, zero_point=None
276+
)
234277
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
235278

236279
l = torch.nn.Linear(2, 3)
237280
lp_tensor = MyTensorWithOptionalData(
238-
l.weight, torch.zeros_like(l.weight), "attr"
281+
l.weight, "attr", None, zero_point=None, optional_attr="value"
239282
)
240283
l = torch.nn.Linear(2, 3)
241284
lp_tensor_for_copy = MyTensorWithOptionalData(
242-
l.weight, torch.zeros_like(l.weight), "attr"
285+
l.weight, "attr", None, zero_point=None, optional_attr="value"
243286
)
244287
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
245288

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,12 @@ class NVFP4Tensor(TorchAOBaseTensor):
7979
"""
8080

8181
tensor_data_names = ["qdata", "_scale_e4m3"]
82-
optional_tensor_data_names = ["_per_tensor_scale", "_act_per_tensor_scale"]
8382
tensor_attribute_names = [
8483
"_block_size",
8584
"_orig_dtype",
85+
]
86+
optional_tensor_data_names = ["_per_tensor_scale", "_act_per_tensor_scale"]
87+
optional_tensor_attribute_names = [
8688
"_is_swizzled_scales",
8789
"use_triton_kernel",
8890
"act_quant_kwargs",
@@ -92,10 +94,10 @@ def __new__(
9294
cls,
9395
qdata,
9496
blockwise_scales,
95-
per_tensor_scale,
96-
act_per_tensor_scale,
9797
block_size,
9898
orig_dtype,
99+
per_tensor_scale,
100+
act_per_tensor_scale,
99101
is_swizzled_scales=False,
100102
use_triton_kernel=False,
101103
act_quant_kwargs=None,
@@ -116,13 +118,13 @@ def __new__(
116118
requires_grad=False,
117119
)
118120

119-
self._scale_e4m3 = blockwise_scales
120-
self._is_swizzled_scales = is_swizzled_scales
121-
self._per_tensor_scale = per_tensor_scale
122-
self._act_per_tensor_scale = act_per_tensor_scale
123121
self.qdata = qdata
122+
self._scale_e4m3 = blockwise_scales
124123
self._block_size = block_size
125124
self._orig_dtype = orig_dtype
125+
self._per_tensor_scale = per_tensor_scale
126+
self._act_per_tensor_scale = act_per_tensor_scale
127+
self._is_swizzled_scales = is_swizzled_scales
126128
self.use_triton_kernel = use_triton_kernel
127129
self.act_quant_kwargs = act_quant_kwargs
128130
return self
@@ -184,10 +186,10 @@ def to_nvfp4(
184186
return NVFP4Tensor(
185187
data_lp,
186188
blockwise_scales,
187-
per_tensor_scale,
188-
act_per_tensor_scale,
189189
block_size,
190190
data_hp.dtype,
191+
per_tensor_scale,
192+
act_per_tensor_scale,
191193
is_swizzled_scales,
192194
use_triton_kernel,
193195
act_quant_kwargs,
@@ -312,10 +314,10 @@ def nvfp4_to_copy(func, types, args, kwargs):
312314
res = NVFP4Tensor(
313315
tensor.qdata,
314316
tensor._scale_e4m3,
315-
tensor._per_tensor_scale,
316-
tensor._act_per_tensor_scale,
317317
tensor._block_size,
318318
dtype,
319+
tensor._per_tensor_scale,
320+
tensor._act_per_tensor_scale,
319321
tensor._is_swizzled_scales,
320322
tensor.use_triton_kernel,
321323
tensor.act_quant_kwargs,
@@ -513,10 +515,10 @@ def nvfp4_slice(func, types, args, kwargs):
513515
result = NVFP4Tensor(
514516
sliced_data,
515517
sliced_scale,
516-
x._per_tensor_scale,
517-
x._act_per_tensor_scale,
518518
x._block_size,
519519
x._orig_dtype,
520+
x._per_tensor_scale,
521+
x._act_per_tensor_scale,
520522
x._is_swizzled_scales,
521523
x.use_triton_kernel,
522524
x.act_quant_kwargs,
@@ -532,10 +534,10 @@ def nvfp4_t(func, types, args, kwargs):
532534
new = NVFP4Tensor(
533535
old.qdata.t(),
534536
old._scale_e4m3,
535-
old._per_tensor_scale,
536-
old._act_per_tensor_scale,
537537
old._block_size,
538538
old._orig_dtype,
539+
old._per_tensor_scale,
540+
old._act_per_tensor_scale,
539541
old._is_swizzled_scales,
540542
old.use_triton_kernel,
541543
old.act_quant_kwargs,
@@ -552,10 +554,10 @@ def nvfp4_view_op(func, types, args, kwargs):
552554
return NVFP4Tensor(
553555
new_data,
554556
args[0]._scale_e4m3,
555-
args[0]._per_tensor_scale,
556-
args[0]._act_per_tensor_scale,
557557
args[0]._block_size,
558558
args[0]._orig_dtype,
559+
args[0]._per_tensor_scale,
560+
args[0]._act_per_tensor_scale,
559561
args[0]._is_swizzled_scales,
560562
args[0].use_triton_kernel,
561563
args[0].act_quant_kwargs,

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ class Float8Tensor(TorchAOBaseTensor):
9494
"""
9595

9696
tensor_data_names = ["qdata", "scale"]
97-
tensor_attribute_names = [
97+
tensor_attribute_names = []
98+
optional_tensor_attribute_names = [
9899
"block_size",
99100
"mm_config",
100101
"hp_value_lb",
@@ -106,15 +107,15 @@ class Float8Tensor(TorchAOBaseTensor):
106107

107108
def __new__(
108109
cls,
109-
qdata,
110-
scale,
111-
block_size,
112-
mm_config,
113-
hp_value_lb,
114-
hp_value_ub,
115-
act_quant_kwargs,
116-
kernel_preference,
117-
dtype,
110+
qdata: torch.Tensor,
111+
scale: torch.Tensor,
112+
block_size: Optional[List[int]] = None,
113+
mm_config: Optional[Float8MMConfig] = None,
114+
hp_value_lb: Optional[float] = None,
115+
hp_value_ub: Optional[float] = None,
116+
act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None,
117+
kernel_preference: KernelPreference = KernelPreference.AUTO,
118+
dtype: Optional[torch.dtype] = None,
118119
):
119120
shape = qdata.shape
120121
kwargs = {}

torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,17 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
7575
"""
7676

7777
tensor_data_names = ["qdata", "group_scale"]
78-
optional_tensor_data_names = ["group_zero", "row_scale"]
7978
tensor_attribute_names = ["block_size", "shape"]
79+
optional_tensor_data_names = ["group_zero", "row_scale"]
8080

8181
def __new__(
8282
cls,
83-
qdata,
84-
group_scale,
85-
group_zero,
86-
row_scale,
87-
block_size,
88-
shape,
83+
qdata: torch.Tensor,
84+
group_scale: torch.Tensor,
85+
block_size: List[int],
86+
shape: List[int],
87+
group_zero: Optional[torch.Tensor] = None,
88+
row_scale: Optional[torch.Tensor] = None,
8989
):
9090
kwargs = {}
9191
kwargs["device"] = qdata.device
@@ -97,19 +97,19 @@ def __init__(
9797
self,
9898
qdata: torch.Tensor,
9999
group_scale: torch.Tensor,
100-
group_zero: Optional[torch.Tensor],
101-
row_scale: Optional[torch.Tensor],
102100
block_size: List[int],
103101
shape: List[int],
102+
group_zero: Optional[torch.Tensor] = None,
103+
row_scale: Optional[torch.Tensor] = None,
104104
):
105105
# at least one of group_zero and row_scale must be None (they cannot both be set)
106106
assert group_zero is None or row_scale is None
107107
assert not (group_zero is not None and row_scale is not None)
108108
self.qdata = qdata
109-
self.group_scale = group_scale
110-
self.group_zero = group_zero
111109
self.row_scale = row_scale
112110
self.block_size = block_size
111+
self.group_scale = group_scale
112+
self.group_zero = group_zero
113113

114114
def _quantization_type(self):
115115
return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"
@@ -178,10 +178,10 @@ def from_hp(
178178
return Int4PreshuffledTensor(
179179
qdata=wq,
180180
group_scale=group_scale,
181-
group_zero=group_zero,
182-
row_scale=row_scale,
183181
block_size=block_size,
184182
shape=original_shape,
183+
group_zero=group_zero,
184+
row_scale=row_scale,
185185
)
186186

187187

0 commit comments

Comments
 (0)