Commit f259aaa

jerryzh168 and liangel-02 authored and committed
Refactor TorchAOBaseTensor for better BC support (#2793)
Summary:

After this PR, tensors inheriting from TorchAOBaseTensor will have better BC support: if a later change adds an optional tensor data attribute or an optional non-tensor attribute, BC is preserved without any additional changes.

More details:

The BC story we are aiming for is that after we land a tensor (e.g. Int4Tensor, Float8Tensor), future changes should only add optional Tensor data attributes and optional non-Tensor attributes to it (bigger changes will require a version bump, which we still need to add). The current TorchAOBaseTensor does not support this very well. Also see #2840 for a real test that adds both an optional tensor attribute and an optional non-tensor attribute to Float8Tensor, while the BC test for Float8Tensor in https://github.com/pytorch/ao/blob/main/test/integration/test_load_and_run_checkpoint.py keeps passing.

Docs for the current TorchAOBaseTensor (https://github.com/pytorch/ao/blob/e6b38bb0e1477ae6aaca0a3d30de70598be43290/torchao/utils.py#L726-L731):

- `tensor_data_names` (List[str]): list of names of all required tensor data; order should match the `__init__` argument list of the tensor subclass
- `optional_tensor_data_names` (List[str]): defining this field is optional for getting the additional boilerplate functions implemented for you, but it is needed when there are optional Tensor attributes; when defined, it is a list of names of Tensors that can be optional
- `tensor_attribute_names` (List[str]): list of names of non-Tensor attributes; order should match the `__init__` argument list of the tensor subclass, following all the `tensor_data_names` and `optional_tensor_data_names` arguments

Problem: the current `optional_tensor_data_names` is not truly optional, since it is followed by `tensor_attribute_names`, which contains both required and optional attributes. So adding an optional tensor data attribute to a Tensor breaks BC. Here are a few options:

Option 1: keep a single `tensor_attribute_names` list, with the optional tensor data arguments placed right after the required ones (which forces the attributes to take defaults too):

```python
class Int4Tensor(TorchAOBaseTensor):
    tensor_data_names = ["qdata", "scale", "zero_point"]
    optional_tensor_data_names = ["act_scale"]
    tensor_attribute_names = ["block_size", "shape", "_demo_only_optional_attr"]

    def __init__(self, qdata, scale, zero_point, act_scale=None,
                 block_size=None, shape=None, _demo_only_optional_attr=None):
        ...

    # for BC
    def __setstate__(self, state):
        torch._utils._set_obj_state(self, state)
        if "act_scale" not in self.__dict__:
            self.act_scale = None
```

Option 2: split the attribute list into required and optional parts:

```python
class Int4Tensor(TorchAOBaseTensor):
    tensor_data_names = ["qdata", "scale", "zero_point"]
    optional_tensor_data_names = ["act_scale"]
    required_tensor_attribute_names = ["block_size", "shape"]
    optional_tensor_attribute_names = ["_demo_only_optional_attr"]

    def __init__(self, qdata, scale, zero_point, block_size, shape,
                 act_scale=None, _demo_only_optional_attr=None):
        ...

    # for BC
    def __setstate__(self, state):
        torch._utils._set_obj_state(self, state)
        if "act_scale" not in self.__dict__:
            self.act_scale = None
```

Option 3: keep a single `tensor_attribute_names` list but order the optional tensor data arguments last:

```python
class Int4Tensor(TorchAOBaseTensor):
    tensor_data_names = ["qdata", "scale", "zero_point"]
    tensor_attribute_names = ["block_size", "shape", "_demo_only_optional_attr"]
    optional_tensor_data_names = ["act_scale"]

    def __init__(self, qdata, scale, zero_point, block_size, shape,
                 _demo_only_optional_attr=None, act_scale=None):
        ...

    # for BC
    def __setstate__(self, state):
        torch._utils._set_obj_state(self, state)
        if "act_scale" not in self.__dict__:
            self.act_scale = None
```

Test Plan: python test/integration/test_load_and_run_checkpoint.py
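To make the `__setstate__` pattern above concrete, here is a minimal, self-contained sketch of why it preserves BC when an optional attribute is added; it uses a plain class and `pickle` instead of a real tensor subclass, and the `Quantized` class and its fields are made up for illustration:

```python
import pickle

# "Old" version of a class: no optional attribute yet.
class Quantized:  # hypothetical stand-in for a TorchAOBaseTensor subclass
    def __init__(self, qdata, scale):
        self.qdata = qdata
        self.scale = scale

old_bytes = pickle.dumps(Quantized([1, 2, 3], 0.5))  # an "old checkpoint"

# "New" version of the same class: adds an optional attribute plus the
# __setstate__ hook that backfills it when loading old state.
class Quantized:  # intentional redefinition under the same name
    def __init__(self, qdata, scale, act_scale=None):
        self.qdata = qdata
        self.scale = scale
        self.act_scale = act_scale

    def __setstate__(self, state):
        # state is the pickled instance's __dict__; restore it, then
        # backfill any optional attribute the old version did not have
        self.__dict__.update(state)
        if "act_scale" not in self.__dict__:
            self.act_scale = None

# Unpickling resolves the class by name and finds the *new* definition.
loaded = pickle.loads(old_bytes)
assert loaded.qdata == [1, 2, 3]
assert loaded.act_scale is None  # backfilled by __setstate__
```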
1 parent f46fa25 commit f259aaa

4 files changed: +215 −72 lines

test/test_utils.py

Lines changed: 54 additions & 11 deletions
```diff
@@ -186,60 +186,103 @@ class MyTensor(TorchAOBaseTensor):
             tensor_data_names = ["qdata"]
             tensor_attribute_names = ["attr", "device"]
 
-            def __new__(cls, qdata, attr, device=None):
+            def __new__(cls, qdata, attr, device):
                 shape = qdata.shape
                 if device is None:
                     device = qdata.device
                 kwargs = {"device": device}
                 return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
-            def __init__(self, qdata, attr, device=None):
+            def __init__(self, qdata, attr, device):
                 self.qdata = qdata
                 self.attr = attr
 
         l = torch.nn.Linear(2, 3)
-        l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
+        l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr", None))
         lp_tensor = l.weight
 
         another_tensor = torch.nn.Linear(2, 3).weight
         # attribute has to be the same
-        lp_tensor_for_copy = MyTensor(another_tensor, "attr")
+        lp_tensor_for_copy = MyTensor(another_tensor, "attr", None)
         self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
 
     @skip_if_no_cuda()
     def test_default_impls_with_optional_data(self):
         class MyTensorWithOptionalData(TorchAOBaseTensor):
             tensor_data_names = ["qdata"]
-            optional_tensor_data_names = ["zero_point"]
             tensor_attribute_names = ["attr", "device"]
+            optional_tensor_data_names = ["zero_point"]
 
-            def __new__(cls, qdata, zero_point=None, attr=1.0, device=None):
+            def __new__(cls, qdata, attr, device, zero_point=None):
                 shape = qdata.shape
                 if device is None:
                     device = qdata.device
                 kwargs = {"device": device}
                 return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
-            def __init__(self, qdata, zero_point=None, attr=1.0, device=None):
+            def __init__(self, qdata, attr, device, zero_point=None):
                 self.qdata = qdata
+                self.attr = attr
                 self.zero_point = zero_point
+
+        # test both the optional Tensor is None
+        # and not None
+        l = torch.nn.Linear(2, 3)
+        lp_tensor = MyTensorWithOptionalData(l.weight, "attr", None, None)
+        l = torch.nn.Linear(2, 3)
+        lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, "attr", None, None)
+        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
+
+        l = torch.nn.Linear(2, 3)
+        lp_tensor = MyTensorWithOptionalData(
+            l.weight, "attr", None, torch.zeros_like(l.weight)
+        )
+        l = torch.nn.Linear(2, 3)
+        lp_tensor_for_copy = MyTensorWithOptionalData(
+            l.weight, "attr", None, torch.zeros_like(l.weight)
+        )
+        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
+
+    @skip_if_no_cuda()
+    def test_default_impls_with_optional_attr(self):
+        class MyTensorWithOptionalData(TorchAOBaseTensor):
+            tensor_data_names = ["qdata"]
+            tensor_attribute_names = ["attr", "device"]
+            optional_tensor_data_names = ["zero_point"]
+            optional_tensor_attribute_names = ["optional_attr"]
+
+            def __new__(cls, qdata, attr, device, zero_point=None, optional_attr=None):
+                shape = qdata.shape
+                if device is None:
+                    device = qdata.device
+                kwargs = {"device": device}
+                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+            def __init__(
+                self, qdata, attr, device, zero_point=None, optional_attr=None
+            ):
+                self.qdata = qdata
                 self.attr = attr
+                self.zero_point = zero_point
+                self.optional_attr = optional_attr
 
         # test both the optional Tensor is None
         # and not None
         l = torch.nn.Linear(2, 3)
-        lp_tensor = MyTensorWithOptionalData(l.weight, None, "attr")
+        lp_tensor = MyTensorWithOptionalData(l.weight, "attr", None, zero_point=None)
         l = torch.nn.Linear(2, 3)
-        lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, None, "attr")
+        lp_tensor_for_copy = MyTensorWithOptionalData(
+            l.weight, "attr", None, zero_point=None
+        )
         self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
 
         l = torch.nn.Linear(2, 3)
         lp_tensor = MyTensorWithOptionalData(
-            l.weight, torch.zeros_like(l.weight), "attr"
+            l.weight, "attr", None, zero_point=None, optional_attr="value"
         )
         l = torch.nn.Linear(2, 3)
         lp_tensor_for_copy = MyTensorWithOptionalData(
-            l.weight, torch.zeros_like(l.weight), "attr"
+            l.weight, "attr", None, zero_point=None, optional_attr="value"
         )
         self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
```
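For orientation, the default implementations that `_test_default_impls_helper` exercises are generated from these class-level name lists. Below is a rough sketch of the idea using PyTorch's `__tensor_flatten__`/`__tensor_unflatten__` protocol; it is not torchao's actual implementation, and `TensorWithNameLists` is a hypothetical name:

```python
from typing import Any, Dict, List, Tuple

import torch


class TensorWithNameLists(torch.Tensor):
    # Sketch only: subclasses would fill these in, e.g.
    # tensor_data_names = ["qdata"]; tensor_attribute_names = ["attr", "device"]
    tensor_data_names: List[str] = []
    tensor_attribute_names: List[str] = []
    optional_tensor_data_names: List[str] = []
    optional_tensor_attribute_names: List[str] = []

    def __tensor_flatten__(self) -> Tuple[List[str], Any]:
        # inner tensors: every required one, plus optional ones that are set
        names = list(self.tensor_data_names)
        names += [
            n for n in self.optional_tensor_data_names
            if getattr(self, n) is not None
        ]
        # non-Tensor attributes travel in the context
        ctx = (
            [getattr(self, n) for n in self.tensor_attribute_names],
            [getattr(self, n, None) for n in self.optional_tensor_attribute_names],
        )
        return names, ctx

    @classmethod
    def __tensor_unflatten__(
        cls, inner_tensors: Dict[str, torch.Tensor], ctx, outer_size, outer_stride
    ):
        attrs, opt_attrs = ctx
        required = [inner_tensors[n] for n in cls.tensor_data_names]
        optional = [inner_tensors.get(n) for n in cls.optional_tensor_data_names]
        # required tensor data, required attributes, then all optionals last,
        # the argument ordering this PR standardizes on
        return cls(*required, *attrs, *optional, *opt_attrs)
```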

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 11 additions & 10 deletions
```diff
@@ -94,7 +94,8 @@ class Float8Tensor(TorchAOBaseTensor):
     """
 
     tensor_data_names = ["qdata", "scale"]
-    tensor_attribute_names = [
+    tensor_attribute_names = []
+    optional_tensor_attribute_names = [
         "block_size",
         "mm_config",
         "hp_value_lb",
@@ -106,15 +107,15 @@ class Float8Tensor(TorchAOBaseTensor):
 
     def __new__(
         cls,
-        qdata,
-        scale,
-        block_size,
-        mm_config,
-        hp_value_lb,
-        hp_value_ub,
-        act_quant_kwargs,
-        kernel_preference,
-        dtype,
+        qdata: torch.Tensor,
+        scale: torch.Tensor,
+        block_size: Optional[List[int]] = None,
+        mm_config: Optional[Float8MMConfig] = None,
+        hp_value_lb: Optional[float] = None,
+        hp_value_ub: Optional[float] = None,
+        act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None,
+        kernel_preference: KernelPreference = KernelPreference.AUTO,
+        dtype: Optional[torch.dtype] = None,
     ):
         shape = qdata.shape
         kwargs = {}
```
torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py

Lines changed: 13 additions & 13 deletions
```diff
@@ -75,17 +75,17 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
     """
 
     tensor_data_names = ["qdata", "group_scale"]
-    optional_tensor_data_names = ["group_zero", "row_scale"]
     tensor_attribute_names = ["block_size", "shape"]
+    optional_tensor_data_names = ["group_zero", "row_scale"]
 
     def __new__(
         cls,
-        qdata,
-        group_scale,
-        group_zero,
-        row_scale,
-        block_size,
-        shape,
+        qdata: torch.Tensor,
+        group_scale: torch.Tensor,
+        block_size: List[int],
+        shape: List[int],
+        group_zero: Optional[torch.Tensor] = None,
+        row_scale: Optional[torch.Tensor] = None,
     ):
         kwargs = {}
         kwargs["device"] = qdata.device
@@ -97,19 +97,19 @@ def __init__(
         self,
         qdata: torch.Tensor,
         group_scale: torch.Tensor,
-        group_zero: Optional[torch.Tensor],
-        row_scale: Optional[torch.Tensor],
         block_size: List[int],
         shape: List[int],
+        group_zero: Optional[torch.Tensor] = None,
+        row_scale: Optional[torch.Tensor] = None,
     ):
         # one and only one of group_scale and group_zero should be None
         assert group_zero is None or row_scale is None
         assert not (group_zero is not None and row_scale is not None)
         self.qdata = qdata
-        self.group_scale = group_scale
-        self.group_zero = group_zero
         self.row_scale = row_scale
         self.block_size = block_size
+        self.group_scale = group_scale
+        self.group_zero = group_zero
 
     def _quantization_type(self):
         return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"
@@ -178,10 +178,10 @@ def from_hp(
         return Int4PreshuffledTensor(
             qdata=wq,
             group_scale=group_scale,
-            group_zero=group_zero,
-            row_scale=row_scale,
             block_size=block_size,
             shape=original_shape,
+            group_zero=group_zero,
+            row_scale=row_scale,
         )
```
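Note that `from_hp` passes every argument by keyword, which is why the parameter reordering above leaves its call site intact; only positional callers would need updating. A small illustrative sketch (function names and values are made up):

```python
# Hedged illustration: call sites that pass arguments by keyword are
# unaffected by parameter reordering; positional ones are not.

def old_signature(qdata, group_scale, group_zero, row_scale, block_size, shape):
    return (qdata, group_scale, group_zero, row_scale, block_size, shape)

def new_signature(qdata, group_scale, block_size, shape,
                  group_zero=None, row_scale=None):
    return (qdata, group_scale, group_zero, row_scale, block_size, shape)

kwargs = dict(qdata="q", group_scale="s", block_size=[1, 8],
              shape=[4, 8], group_zero=None, row_scale="r")

# identical results from both signatures when called by keyword
assert old_signature(**kwargs) == new_signature(**kwargs)
```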
