4 changes: 2 additions & 2 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -907,8 +907,8 @@ def test_nvfp4_swizzled_scales_serialization():
     tensor_list, ctx = original_tensor.__tensor_flatten__()
 
     # Verify swizzled flag is preserved in context
-    assert NVFP4Tensor.optional_tensor_attribute_names[0] == "_is_swizzled_scales"
-    assert ctx[2] == True
+    assert "_is_swizzled_scales" in ctx
+    assert ctx["_is_swizzled_scales"] == True
 
     # Test deserialization
     inner_tensors = {}
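The assertion change works because `__tensor_flatten__` now returns the non-Tensor attributes as a dict keyed by attribute name rather than a positional list. A minimal sketch with hypothetical context values (not the real tensor's attributes) of why keyed lookup is more robust than indexing:

    import torch

    # Hypothetical "before" context: a positional list, where the test had to
    # know that index 2 happened to hold _is_swizzled_scales.
    ctx_before = [32, torch.bfloat16, True]
    assert ctx_before[2] == True

    # "After" context: a dict keyed by attribute name, independent of ordering.
    ctx_after = {"_block_size": 32, "_orig_dtype": torch.bfloat16, "_is_swizzled_scales": True}
    assert "_is_swizzled_scales" in ctx_after
    assert ctx_after["_is_swizzled_scales"] == True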
4 changes: 2 additions & 2 deletions test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -307,8 +307,8 @@ def test_nvfp4_swizzled_scales_serialization():
     tensor_list, ctx = original_tensor.__tensor_flatten__()
 
     # Verify swizzled flag is preserved in context
-    assert NVFP4Tensor.optional_tensor_attribute_names[0] == "_is_swizzled_scales"
-    assert ctx[2] == True
+    assert "_is_swizzled_scales" in ctx
+    assert ctx["_is_swizzled_scales"] == True
 
     # Test deserialization
     inner_tensors = {}
12 changes: 6 additions & 6 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -96,9 +96,9 @@ def __new__(
         blockwise_scales,
         block_size,
         orig_dtype,
-        per_tensor_scale,
-        act_per_tensor_scale,
-        is_swizzled_scales=False,
+        _per_tensor_scale=None,
+        _act_per_tensor_scale=None,
+        _is_swizzled_scales=False,
         use_triton_kernel=False,
         act_quant_kwargs=None,
     ):
@@ -122,9 +122,9 @@ def __new__(
         self._scale_e4m3 = blockwise_scales
         self._block_size = block_size
         self._orig_dtype = orig_dtype
-        self._per_tensor_scale = per_tensor_scale
-        self._act_per_tensor_scale = act_per_tensor_scale
-        self._is_swizzled_scales = is_swizzled_scales
+        self._per_tensor_scale = _per_tensor_scale
+        self._act_per_tensor_scale = _act_per_tensor_scale
+        self._is_swizzled_scales = _is_swizzled_scales
         self.use_triton_kernel = use_triton_kernel
         self.act_quant_kwargs = act_quant_kwargs
         return self
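The renamed, underscore-prefixed keyword parameters are what allow the generic `__tensor_unflatten__` in torchao/utils.py (below) to pass optional attributes as keyword arguments keyed by the entries of `optional_tensor_attribute_names`. A toy illustration of that constraint, using a hypothetical class rather than the real NVFP4Tensor:

    class Toy:
        # entries here are used verbatim as keyword argument names on reconstruction
        optional_tensor_attribute_names = ["_is_swizzled_scales"]

        def __init__(self, data, _is_swizzled_scales=False):
            # the parameter name must match the list entry exactly, because
            # callers reconstruct via Toy(data, **{"_is_swizzled_scales": True})
            self.data = data
            self._is_swizzled_scales = _is_swizzled_scales

    t = Toy([1.0], **{"_is_swizzled_scales": True})
    assert t._is_swizzled_scales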
68 changes: 37 additions & 31 deletions torchao/utils.py
@@ -778,10 +778,10 @@ class variables to define to simplify implmentation of tensor subclasses:
     Note: Argument order in __init__ and __new__ should match exaclty with tensor_data_names + tensor_attribute_names + optional_tensor_data_names (if present) + optional_tensor_attribute_names (if present)
 
 
-    If `tensor_data_names` and `tensor_attribute_names` are defined, there are some additional
+    If `tensor_data_names` (torch.Tensor data attribute names) and `tensor_attribute_names` (non-torch.Tensor attribute names) are defined, there are some additional
     functions that will be added, this includes:
     `__tensor_flatten__`: flattens a subclassed tensor instance, returns a tuple, first element is tensor data names for valid tensor data,
-        second element is a list of non-Tensor attributes
+        second element is a dict from attribute_name to non-Tensor attributes
     `__tensor_unflatten__`: takes a tensor_data_dict (a map from tensor name to Tensor), and list of non-tensor attributes, returns a new instance of the subclassed tensor
     `_apply_fn_to_data`: takes a function (Tensor -> Tensor), applies function to all tensor data and
         recreate a new subclassed Tensor with the transformed tensor data
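To make that contract concrete, here is a self-contained sketch of a minimal wrapper subclass implementing the same three methods with the dict-based attribute context this PR adopts. `Demo` and its field names are hypothetical, and the real generated methods additionally handle optional data/attribute names and error cases:

    import torch

    class Demo(torch.Tensor):
        tensor_data_names = ["qdata"]
        tensor_attribute_names = ["block_size"]

        def __new__(cls, qdata, block_size):
            self = torch.Tensor._make_wrapper_subclass(cls, qdata.shape, dtype=qdata.dtype)
            self.qdata = qdata
            self.block_size = block_size
            return self

        def __tensor_flatten__(self):
            # first element: names of tensor components; second: name -> attribute dict
            return self.tensor_data_names, {"block_size": self.block_size}

        @classmethod
        def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride):
            # attributes are looked up by name, not by position
            return cls(tensor_data_dict["qdata"], tensor_attributes["block_size"])

        def _apply_fn_to_data(self, fn):
            return self.__class__(fn(self.qdata), self.block_size)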
@@ -871,15 +871,17 @@ def __tensor_flatten__(self):
                     if maybe_tensor is not None:
                         tensor_data_names.append(tensor_data_name)
 
-            attrs = [getattr(self, attr) for attr in self.tensor_attribute_names]
+            attr_dict = {
+                attr: getattr(self, attr) for attr in self.tensor_attribute_names
+            }
             if hasattr(self, "optional_tensor_attribute_names"):
-                attrs += [
-                    getattr(self, attr) for attr in self.optional_tensor_attribute_names
-                ]
+                attr_dict = attr_dict | {
+                    attr: getattr(self, attr)
+                    for attr in self.optional_tensor_attribute_names
+                }
 
-            # TODO(future PR): also return names of tensor attributes for easier
-            # debugging
-            return tensor_data_names, attrs
+            return tensor_data_names, attr_dict
+
         raise NotImplementedError(
             "Subclasses should implement __tensor_flatten__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class before using it"
         )
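Aside: the `attr_dict | {...}` merge relies on the PEP 584 dict-union operator, which needs Python 3.9+. If older interpreters matter for torchao's support matrix (an assumption on my part), the equivalent older spelling is:

    attr_dict = {"a": 1}
    extra = {"b": 2}
    merged = {**attr_dict, **extra}      # dict unpacking, Python 3.5+
    assert merged == attr_dict | extra   # dict union, Python 3.9+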
@@ -892,27 +894,30 @@ def __tensor_unflatten__(
             required_tensors = [
                 tensor_data_dict[name] for name in cls.tensor_data_names
             ]
-            optional_tensors = []
+            optional_tensor_dict = {}
             if hasattr(cls, "optional_tensor_data_names"):
-                for tensor_data_name in cls.optional_tensor_data_names:
-                    if tensor_data_name in tensor_data_dict:
-                        optional_tensors.append(tensor_data_dict[tensor_data_name])
-                    else:
-                        optional_tensors.append(None)
+                optional_tensor_dict = {
+                    tensor_data_name: tensor_data_dict.get(tensor_data_name, None)
+                    for tensor_data_name in cls.optional_tensor_data_names
+                }
 
-            required_attributes = tensor_attributes[: len(cls.tensor_attribute_names)]
-            optional_attributes = []
+            required_attributes = [
+                tensor_attributes[name] for name in cls.tensor_attribute_names
+            ]
+            optional_attribute_dict = {}
             if hasattr(cls, "optional_tensor_attribute_names"):
-                optional_attributes = tensor_attributes[
-                    len(cls.tensor_attribute_names) :
-                ]
+                optional_attribute_dict = {
+                    name: tensor_attributes[name]
+                    for name in cls.optional_tensor_attribute_names
+                }
 
             return cls(
                 *required_tensors,
                 *required_attributes,
-                *optional_tensors,
-                *optional_attributes,
+                **optional_tensor_dict,
+                **optional_attribute_dict,
             )
 
         raise NotImplementedError(
             "Subclasses should implement __tensor_unflatten__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class before using it"
         )
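The move away from list slicing also removes a silent ordering dependency: the old `tensor_attributes[: len(cls.tensor_attribute_names)]` assumed required attributes always preceded optional ones in the flattened list. A small illustration with made-up attribute names of how name-keyed lookup fails loudly instead of misassigning:

    # made-up attribute context for illustration
    tensor_attributes = {"_block_size": 16, "_is_swizzled_scales": True}

    # order-independent; a missing or misspelled name raises KeyError
    required = [tensor_attributes[name] for name in ["_block_size"]]
    optional = {name: tensor_attributes[name] for name in ["_is_swizzled_scales"]}
    assert required == [16] and optional == {"_is_swizzled_scales": True}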
@@ -924,29 +929,30 @@ def _apply_fn_to_data(self, fn):
             required_tensors = [
                 fn(getattr(self, attr)) for attr in self.tensor_data_names
             ]
-            optional_tensors = []
+            optional_tensor_dict = {}
             if hasattr(self, "optional_tensor_data_names"):
                 for tensor_data_name in self.optional_tensor_data_names:
                     maybe_tensor = getattr(self, tensor_data_name)
                     if maybe_tensor is not None:
-                        optional_tensors.append(fn(maybe_tensor))
+                        optional_tensor_dict[tensor_data_name] = fn(maybe_tensor)
                     else:
-                        optional_tensors.append(None)
+                        optional_tensor_dict[tensor_data_name] = None
 
             required_attributes = [
                 getattr(self, attr) for attr in self.tensor_attribute_names
             ]
-            optional_attributes = []
+            optional_attribute_dict = {}
             if hasattr(self, "optional_tensor_attribute_names"):
-                optional_attributes = [
-                    getattr(self, attr) for attr in self.optional_tensor_attribute_names
-                ]
+                optional_attribute_dict = {
+                    attr_name: getattr(self, attr_name)
+                    for attr_name in self.optional_tensor_attribute_names
+                }
 
             return self.__class__(
                 *required_tensors,
                 *required_attributes,
-                *optional_tensors,
-                *optional_attributes,
+                **optional_tensor_dict,
+                **optional_attribute_dict,
             )
 
         raise NotImplementedError(
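As a closing usage sketch, the hypothetical `Demo` subclass from above round-trips through the new protocol as follows (an illustration of the convention, not the real torchao API surface):

    import torch

    d = Demo(torch.randn(8), block_size=4)

    # flatten: tensor component names plus a name -> attribute dict
    names, ctx = d.__tensor_flatten__()
    assert ctx == {"block_size": 4}

    # unflatten: rebuild from a name -> Tensor map and the attribute dict
    rebuilt = Demo.__tensor_unflatten__(
        {n: getattr(d, n) for n in names}, ctx, d.shape, d.stride()
    )
    assert torch.equal(rebuilt.qdata, d.qdata) and rebuilt.block_size == 4

    # _apply_fn_to_data: rebuild with each tensor component transformed
    detached = d._apply_fn_to_data(torch.Tensor.detach)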