diff --git a/coremltools/converters/mil/frontend/torch/converter.py b/coremltools/converters/mil/frontend/torch/converter.py
index a6a4c72da..4dd07c63a 100644
--- a/coremltools/converters/mil/frontend/torch/converter.py
+++ b/coremltools/converters/mil/frontend/torch/converter.py
@@ -911,110 +911,112 @@ def _interleave_repeat_scale_zp(
         return scale_repeated, zero_point_repeated
 
     def _construct_quantization_op(
         self,
         weight: np.ndarray,
         compression_info: CompressionInfo,
         name: str,
         compressed_var: Optional[Var] = None,
     ) -> Var:
         """
         The weight is constructed by `weight = scale * (quantized_data - zero_point)`.
         We need to restore the quantized_data to construct the quantization op.
 
         If compressed_var is not None, it's the var constructed by a previous compression function,
         which means this is a joint compression. For example, if the compression_info.compression_type
         is [CompressionType.PRUNING, CompressionType.QUANTIZATION], the compressed_var is the var
         produced by the pruning.
         """
         if compression_info.quantization_n_bits is None:
             raise ValueError("quantization_n_bits must be specified in quantization.")
         if compression_info.quantization_scale is None:
             raise ValueError("quantization_scale must be specified in quantization.")
 
         scale = compression_info.quantization_scale.detach().numpy()
         zero_point: Optional[np.ndarray] = None
         if compression_info.zero_point is not None:
             zero_point = compression_info.zero_point.detach().numpy()
-        # For conv/conv_transpose, the weight has rank=4, so we auto-expand scale and zero-point if
-        # it only has two elements.
-        if len(weight.shape) == 4 and len(scale.shape) == 2:
-            scale = np.expand_dims(np.expand_dims(scale, axis=-1), axis=-1)
-            if zero_point is not None:
-                zero_point = np.expand_dims(np.expand_dims(zero_point, axis=-1), axis=-1)
+
+        # Expand the scale tensor to match the weight tensor's rank
+        if len(scale.shape) < len(weight.shape):
+            for _ in range(len(weight.shape) - len(scale.shape)):
+                scale = np.expand_dims(scale, axis=-1)
+            if zero_point is not None:
+                for _ in range(len(weight.shape) - len(zero_point.shape)):
+                    zero_point = np.expand_dims(zero_point, axis=-1)
 
         if compressed_var is not None and compressed_var.op.op_type == "constexpr_lut_to_dense":
             # The quantization on lut could lead to extra two dims at the end.
             if len(scale.shape) == len(weight.shape) + 2 and scale.shape[-2:] == (1, 1):
                 scale = np.squeeze(np.squeeze(scale, axis=-1), axis=-1)
                 if zero_point is not None:
                     zero_point = np.squeeze(np.squeeze(zero_point, axis=-1), axis=-1)
 
         if len(weight.shape) != len(scale.shape):
             raise ValueError(
                 f"In {name}, the `weight` should have same rank as `scale`, but got {weight.shape} vs {scale.shape}"
             )
         if zero_point is not None:
             if len(weight.shape) != len(zero_point.shape):
                 raise ValueError(
                     f"In {name}, the `weight` should have same rank as `zero_point`, but got {weight.shape} vs {zero_point.shape}"
                 )
 
         scale_repeated, zero_point_repeated = self._interleave_repeat_scale_zp(
             weight, scale, zero_point
         )
         quantized_data = np.round(weight / scale_repeated)
         if zero_point_repeated is not None:
             quantized_data += zero_point_repeated
 
         # Adjust dtype based on nbits.
         dtype_str_prefix = "int"
         if quantized_data.min() >= 0 and (zero_point is None or zero_point.min() >= 0):
             dtype_str_prefix = "uint"
         dtype_str = dtype_str_prefix + str(compression_info.quantization_n_bits)
         builtin_dtype = types.string_to_builtin(dtype_str)
         np_dtype = types.nptype_from_builtin(builtin_dtype)
 
         builtin_range = types.type_mapping.builtin_to_range(builtin_dtype)
         quantized_data = np.clip(quantized_data, builtin_range.low, builtin_range.high).astype(
             np_dtype
         )
         if zero_point is not None:
             zero_point = zero_point.astype(np_dtype)
 
         if compressed_var is None:
             return frontend_utils._construct_constexpr_dequant_op(
                 quantized_data, zero_point, scale, name=name
             )
         else:
             # Specially handles joint compression, such as using sparse op if joint with pruning.
             if compressed_var.op.op_type == "constexpr_sparse_to_dense":
                 mask, nonzero_data = mb.constexpr_sparse_blockwise_shift_scale(
                     data_mask=compressed_var.op.mask,
                     nonzero_data=quantized_data[compressed_var.op.mask.val != 0].flatten(),
                     scale=scale,
                     offset=zero_point,
                     before_op=compressed_var.op,
                     name=compressed_var.op.name + "_quantized",
                 )
                 return mb.constexpr_sparse_to_dense(nonzero_data=nonzero_data, mask=mask, name=name)
             elif compressed_var.op.op_type == "constexpr_lut_to_dense":
                 if not types.is_int(compressed_var.dtype):
                     raise ValueError(
                         "The joint palettization+quantization only supports lut with "
                         f"int entries, but got {types.builtin_to_string(compressed_var.dtype)}"
                     )
                 return mb.constexpr_blockwise_shift_scale(
                     data=compressed_var,
                     scale=scale,
                     offset=zero_point,
                     name=name,
                 )
             else:
                 raise ValueError(
                     "Unsupported joint compression combination. The quantization can only be joint "
                     f"with pruning or palettization, but got {compressed_var.op.op_type}. Please check the value of "
                     "'compression_type' in your registered buffers."
                 )
 
     def _construct_palettization_op(
         self,
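The substantive change in this hunk replaces a conv-specific special case (rank-4 weight with rank-2 scale) with a general loop that appends trailing singleton axes until the scale and zero-point match the weight's rank, so any lower-rank scale now broadcasts against the weight. The following is a standalone numpy sketch of that behavior, not part of the patch; expand_to_weight_rank is an illustrative helper, not a coremltools API.

import numpy as np

def expand_to_weight_rank(param: np.ndarray, weight_rank: int) -> np.ndarray:
    # Illustrative helper (not coremltools API): append trailing singleton
    # axes until `param` matches the weight's rank, mirroring the patched
    # loop over np.expand_dims(..., axis=-1).
    while param.ndim < weight_rank:
        param = np.expand_dims(param, axis=-1)
    return param

weight = np.random.rand(32, 16, 3, 3).astype(np.float32)  # rank-4 conv weight
scale = np.full((32, 1), 0.02, dtype=np.float32)          # per-channel scale, rank 2

scale = expand_to_weight_rank(scale, weight.ndim)
assert scale.shape == (32, 1, 1, 1)  # broadcasts against (32, 16, 3, 3)

# The quantized data can then be restored as in the patch:
quantized_data = np.round(weight / scale)

Unlike the removed special case, which only fired for the exact rank-4/rank-2 combination, the loop handles any weight whose scale or zero-point has lower rank.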
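Further down, the unchanged context picks a signed or unsigned storage dtype from the sign of the quantized data and zero-point, then clips to the n-bit range. A minimal sketch of that selection, assuming standard two's-complement ranges; quantized_dtype_range is a hypothetical stand-in for what types.string_to_builtin plus types.type_mapping.builtin_to_range resolve in the real code.

import numpy as np

def quantized_dtype_range(n_bits: int, unsigned: bool):
    # Hypothetical helper: e.g. int4 -> [-8, 7], uint4 -> [0, 15].
    if unsigned:
        return 0, 2**n_bits - 1
    return -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1

quantized_data = np.array([-3.0, 5.0, 9.0])
zero_point = None

# Same sign test as the converter: unsigned only if both the data and the
# zero-point (when present) are non-negative.
unsigned = quantized_data.min() >= 0 and (zero_point is None or zero_point.min() >= 0)
low, high = quantized_dtype_range(4, unsigned)  # signed here: [-8, 7]
clipped = np.clip(quantized_data, low, high)    # 9.0 clamps to 7.0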