diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index 6251da3faa..38eefbff07 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -907,8 +907,8 @@ def test_nvfp4_swizzled_scales_serialization():
     tensor_list, ctx = original_tensor.__tensor_flatten__()
 
     # Verify swizzled flag is preserved in context
-    assert NVFP4Tensor.optional_tensor_attribute_names[0] == "_is_swizzled_scales"
-    assert ctx[2] == True
+    assert "_is_swizzled_scales" in ctx
+    assert ctx["_is_swizzled_scales"] == True
 
     # Test deserialization
     inner_tensors = {}
diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py
index 443a5f2ec8..1eaa335c1e 100644
--- a/test/prototype/mx_formats/test_nvfp4_tensor.py
+++ b/test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -307,8 +307,8 @@ def test_nvfp4_swizzled_scales_serialization():
    tensor_list, ctx = original_tensor.__tensor_flatten__()
 
    # Verify swizzled flag is preserved in context
-    assert NVFP4Tensor.optional_tensor_attribute_names[0] == "_is_swizzled_scales"
-    assert ctx[2] == True
+    assert "_is_swizzled_scales" in ctx
+    assert ctx["_is_swizzled_scales"] == True
 
    # Test deserialization
    inner_tensors = {}
diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py
index 97fcbea25b..3f2e8eeef3 100644
--- a/torchao/prototype/mx_formats/nvfp4_tensor.py
+++ b/torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -96,9 +96,9 @@ def __new__(
         blockwise_scales,
         block_size,
         orig_dtype,
-        per_tensor_scale,
-        act_per_tensor_scale,
-        is_swizzled_scales=False,
+        _per_tensor_scale=None,
+        _act_per_tensor_scale=None,
+        _is_swizzled_scales=False,
         use_triton_kernel=False,
         act_quant_kwargs=None,
     ):
@@ -122,9 +122,9 @@ def __new__(
         self._scale_e4m3 = blockwise_scales
         self._block_size = block_size
         self._orig_dtype = orig_dtype
-        self._per_tensor_scale = per_tensor_scale
-        self._act_per_tensor_scale = act_per_tensor_scale
-        self._is_swizzled_scales = is_swizzled_scales
+        self._per_tensor_scale = _per_tensor_scale
+        self._act_per_tensor_scale = _act_per_tensor_scale
+        self._is_swizzled_scales = _is_swizzled_scales
         self.use_triton_kernel = use_triton_kernel
         self.act_quant_kwargs = act_quant_kwargs
         return self
diff --git a/torchao/utils.py b/torchao/utils.py
index 5c84bca8ff..4c401d40cd 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -778,10 +778,10 @@ class variables to define to simplify implmentation of tensor subclasses:
     Note: Argument order in __init__ and __new__ should match exaclty with tensor_data_names + tensor_attribute_names + optional_tensor_data_names (if present) + optional_tensor_attribute_names (if present)
 
-    If `tensor_data_names` and `tensor_attribute_names` are defined, there are some additional
+    If `tensor_data_names` (torch.Tensor data attribute names) and `tensor_attribute_names` (non-torch.Tensor attribute names) are defined, there are some additional
     functions that will be added, this includes:
     `__tensor_flatten__`: flattens a subclassed tensor instance, returns a tuple, first element is tensor data names for valid tensor data,
-        second element is a list of non-Tensor attributes
+        second element is a dict from attribute_name to non-Tensor attributes
     `__tensor_unflatten__`: takes a tensor_data_dict (a map from tensor name to Tensor), and list of non-tensor
         attributes, returns a new instance of the subclassed tensor
     `_apply_fn_to_data`: takes a function (Tensor -> Tensor),
         applies function to all tensor data and recreate a new subclassed Tensor with the transformed tensor data
@@ -871,15 +871,17 @@ def __tensor_flatten__(self):
                    if maybe_tensor is not None:
                        tensor_data_names.append(tensor_data_name)
 
-            attrs = [getattr(self, attr) for attr in self.tensor_attribute_names]
+            attr_dict = {
+                attr: getattr(self, attr) for attr in self.tensor_attribute_names
+            }
            if hasattr(self, "optional_tensor_attribute_names"):
-                attrs += [
-                    getattr(self, attr) for attr in self.optional_tensor_attribute_names
-                ]
+                attr_dict = attr_dict | {
+                    attr: getattr(self, attr)
+                    for attr in self.optional_tensor_attribute_names
+                }
+
+            return tensor_data_names, attr_dict
 
-            # TODO(future PR): also return names of tensor attributes for easier
-            # debugging
-            return tensor_data_names, attrs
        raise NotImplementedError(
            "Subclasses should implement __tensor_flatten__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class before using it"
        )
@@ -892,27 +894,30 @@ def __tensor_unflatten__(
            required_tensors = [
                tensor_data_dict[name] for name in cls.tensor_data_names
            ]
-            optional_tensors = []
+            optional_tensor_dict = {}
            if hasattr(cls, "optional_tensor_data_names"):
-                for tensor_data_name in cls.optional_tensor_data_names:
-                    if tensor_data_name in tensor_data_dict:
-                        optional_tensors.append(tensor_data_dict[tensor_data_name])
-                    else:
-                        optional_tensors.append(None)
+                optional_tensor_dict = {
+                    tensor_data_name: tensor_data_dict.get(tensor_data_name, None)
+                    for tensor_data_name in cls.optional_tensor_data_names
+                }
 
-            required_attributes = tensor_attributes[: len(cls.tensor_attribute_names)]
-            optional_attributes = []
+            required_attributes = [
+                tensor_attributes[name] for name in cls.tensor_attribute_names
+            ]
+            optional_attribute_dict = {}
            if hasattr(cls, "optional_tensor_attribute_names"):
-                optional_attributes = tensor_attributes[
-                    len(cls.tensor_attribute_names) :
-                ]
+                optional_attribute_dict = {
+                    name: tensor_attributes[name]
+                    for name in cls.optional_tensor_attribute_names
+                }
 
            return cls(
                *required_tensors,
                *required_attributes,
-                *optional_tensors,
-                *optional_attributes,
+                **optional_tensor_dict,
+                **optional_attribute_dict,
            )
+
        raise NotImplementedError(
            "Subclasses should implement __tensor_unflatten__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class before using it"
        )
@@ -924,29 +929,30 @@ def _apply_fn_to_data(self, fn):
            required_tensors = [
                fn(getattr(self, attr)) for attr in self.tensor_data_names
            ]
-            optional_tensors = []
+            optional_tensor_dict = {}
            if hasattr(self, "optional_tensor_data_names"):
                for tensor_data_name in self.optional_tensor_data_names:
                    maybe_tensor = getattr(self, tensor_data_name)
                    if maybe_tensor is not None:
-                        optional_tensors.append(fn(maybe_tensor))
+                        optional_tensor_dict[tensor_data_name] = fn(maybe_tensor)
                    else:
-                        optional_tensors.append(None)
+                        optional_tensor_dict[tensor_data_name] = None
 
            required_attributes = [
                getattr(self, attr) for attr in self.tensor_attribute_names
            ]
-            optional_attributes = []
+            optional_attribute_dict = {}
            if hasattr(self, "optional_tensor_attribute_names"):
-                optional_attributes = [
-                    getattr(self, attr) for attr in self.optional_tensor_attribute_names
-                ]
+                optional_attribute_dict = {
+                    attr_name: getattr(self, attr_name)
+                    for attr_name in self.optional_tensor_attribute_names
+                }
 
            return self.__class__(
                *required_tensors,
                *required_attributes,
-                *optional_tensors,
-                *optional_attributes,
+                **optional_tensor_dict,
+                **optional_attribute_dict,
            )
 
        raise NotImplementedError(
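
With this change, the flatten context becomes a dict keyed by attribute name rather than a positional list, and optional tensors/attributes are forwarded to the constructor as keyword arguments. Below is a minimal sketch of the resulting round trip, not part of the diff, assuming `original_tensor` is an `NVFP4Tensor` instance as in the tests above (the two trailing `None` arguments are the unused `outer_size`/`outer_stride`):

# Sketch only: illustrates the name-keyed flatten/unflatten
# contract introduced by this diff.
tensor_list, ctx = original_tensor.__tensor_flatten__()

# ctx is keyed by attribute name, so callers no longer depend on
# positional indexing (previously e.g. ctx[2]).
assert ctx["_is_swizzled_scales"] == True

# Round trip: rebuild the inner-tensor dict and unflatten by name.
inner_tensors = {name: getattr(original_tensor, name) for name in tensor_list}
reconstructed = type(original_tensor).__tensor_unflatten__(
    inner_tensors, ctx, None, None
)
assert reconstructed._is_swizzled_scales == original_tensor._is_swizzled_scales

Keying by name is what lets `__tensor_unflatten__` pass optional entries as `**kwargs`: adding, removing, or reordering optional attributes no longer silently shifts positional indices in serialized state.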