[#2452] Fix Rank Mismatch in Quantization for Conv3d Layers #2454

Open · wants to merge 1 commit into base: main
190 changes: 96 additions & 94 deletions coremltools/converters/mil/frontend/torch/converter.py
@@ -911,110 +911,112 @@ def _interleave_repeat_scale_zp(
        return scale_repeated, zero_point_repeated

    def _construct_quantization_op(
        self,
        weight: np.ndarray,
        compression_info: CompressionInfo,
        name: str,
        compressed_var: Optional[Var] = None,
    ) -> Var:
        """
        The weight is constructed by `weight = scale * (quantized_data - zero_point)`.
        We need to restore the quantized_data to construct the quantization op.

        If compressed_var is not None, it's the var constructed by a previous compression function,
        which means this is a joint compression. For example, if the compression_info.compression_type
        is [CompressionType.PRUNING, CompressionType.QUANTIZATION], the compressed_var is the var
        produced by the pruning.
        """
        if compression_info.quantization_n_bits is None:
            raise ValueError("quantization_n_bits must be specified in quantization.")
        if compression_info.quantization_scale is None:
            raise ValueError("quantization_scale must be specified in quantization.")

        scale = compression_info.quantization_scale.detach().numpy()
        zero_point: Optional[np.ndarray] = None
        if compression_info.zero_point is not None:
            zero_point = compression_info.zero_point.detach().numpy()
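The docstring's relation `weight = scale * (quantized_data - zero_point)` is inverted further down as `round(weight / scale) + zero_point`. A minimal per-tensor round trip in plain NumPy, with made-up values purely for illustration:

```python
import numpy as np

scale = np.float32(0.05)
zero_point = np.int8(3)
quantized_data = np.array([-4, 0, 7], dtype=np.int8)

# Dequantize, then recover the stored integers the way the converter does.
weight = scale * (quantized_data.astype(np.float32) - zero_point)
recovered = np.round(weight / scale) + zero_point

assert np.array_equal(recovered.astype(np.int8), quantized_data)
```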
-        # For conv/conv_transpose, the weight has rank=4, so we auto-expand scale and zero-point if
-        # it only has two elements.
-        if len(weight.shape) == 4 and len(scale.shape) == 2:
-            scale = np.expand_dims(np.expand_dims(scale, axis=-1), axis=-1)
-            if zero_point is not None:
-                zero_point = np.expand_dims(np.expand_dims(zero_point, axis=-1), axis=-1)
+        # Expand the scale tensor to match the weight tensor's rank
+        if len(scale.shape) < len(weight.shape):
+            for _ in range(len(weight.shape) - len(scale.shape)):
+                scale = np.expand_dims(scale, axis=-1)
+            if zero_point is not None:
+                for _ in range(len(weight.shape) - len(zero_point.shape)):
+                    zero_point = np.expand_dims(zero_point, axis=-1)
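This generalization is the core of the fix: rather than special-casing rank-4 conv weights, trailing singleton axes are appended until `scale` (and `zero_point`) match the weight's rank, which covers rank-5 Conv3d weights as well. A minimal NumPy sketch of the new behavior, using hypothetical shapes (a 16-output-channel Conv3d weight and a rank-2 per-channel scale):

```python
import numpy as np

weight = np.zeros((16, 8, 3, 3, 3))  # rank-5 Conv3d weight (C_out, C_in, D, H, W)
scale = np.ones((16, 1))             # rank-2 per-channel scale

# Old behavior: the expansion only fired for rank-4 weights, so a rank-5
# weight reached the rank check below and raised a rank-mismatch ValueError.
# New behavior: append trailing singleton axes until the ranks agree.
while scale.ndim < weight.ndim:
    scale = np.expand_dims(scale, axis=-1)

assert scale.shape == (16, 1, 1, 1, 1)  # now broadcastable against the weight
```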

        if compressed_var is not None and compressed_var.op.op_type == "constexpr_lut_to_dense":
            # The quantization on lut could lead to extra two dims at the end.
            if len(scale.shape) == len(weight.shape) + 2 and scale.shape[-2:] == (1, 1):
                scale = np.squeeze(np.squeeze(scale, axis=-1), axis=-1)
                if zero_point is not None:
                    zero_point = np.squeeze(np.squeeze(zero_point, axis=-1), axis=-1)
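The squeeze here mirrors the expansion above: joint palettization+quantization can hand over a scale with two trailing singleton dims relative to the weight, and they are dropped before the rank check. An illustration with hypothetical shapes:

```python
import numpy as np

weight_rank = 2
scale = np.ones((4, 1, 1, 1))  # two trailing singleton dims beyond weight_rank

if scale.ndim == weight_rank + 2 and scale.shape[-2:] == (1, 1):
    scale = np.squeeze(np.squeeze(scale, axis=-1), axis=-1)

assert scale.shape == (4, 1)
```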

        if len(weight.shape) != len(scale.shape):
            raise ValueError(
                f"In {name}, the `weight` should have same rank as `scale`, but got {weight.shape} vs {scale.shape}"
            )
        if zero_point is not None:
            if len(weight.shape) != len(zero_point.shape):
                raise ValueError(
-                    f"In {name}, the `weight` should have same rank as `scale`, but got {weight.shape} vs {scale.shape}"
+                    f"In {name}, the `weight` should have same rank as `zero_point`, but got {weight.shape} vs {zero_point.shape}"
                )

        scale_repeated, zero_point_repeated = self._interleave_repeat_scale_zp(
            weight, scale, zero_point
        )
        quantized_data = np.round(weight / scale_repeated)
        if zero_point_repeated is not None:
            quantized_data += zero_point_repeated

        # Adjust dtype based on nbits.
        dtype_str_prefix = "int"
        if quantized_data.min() >= 0 and (zero_point is None or zero_point.min() >= 0):
            dtype_str_prefix = "uint"
        dtype_str = dtype_str_prefix + str(compression_info.quantization_n_bits)
        builtin_dtype = types.string_to_builtin(dtype_str)
        np_dtype = types.nptype_from_builtin(builtin_dtype)

        builtin_range = types.type_mapping.builtin_to_range(builtin_dtype)
        quantized_data = np.clip(quantized_data, builtin_range.low, builtin_range.high).astype(
            np_dtype
        )
        if zero_point is not None:
            zero_point = zero_point.astype(np_dtype)
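The dtype bookkeeping above goes through coremltools' internal `types` utilities: pick `int` vs. `uint` from the data's sign, build a dtype string such as `"uint4"`, then clip into that type's representable range. A rough stand-in in plain NumPy — `quant_range` is a hypothetical helper for illustration, not the library API:

```python
import numpy as np

def quant_range(n_bits, signed):
    """Representable range of an n-bit integer type (hypothetical helper)."""
    if signed:
        return -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1
    return 0, 2 ** n_bits - 1

low, high = quant_range(4, signed=True)   # int4 -> (-8, 7)
data = np.array([-12.0, 3.0, 9.0])
clipped = np.clip(data, low, high)        # [-8., 3., 7.]
```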

        if compressed_var is None:
            return frontend_utils._construct_constexpr_dequant_op(
                quantized_data, zero_point, scale, name=name
            )
        else:
            # Specially handles joint compression, such as using sparse op if joint with pruning.
            if compressed_var.op.op_type == "constexpr_sparse_to_dense":
                mask, nonzero_data = mb.constexpr_sparse_blockwise_shift_scale(
                    data_mask=compressed_var.op.mask,
                    nonzero_data=quantized_data[compressed_var.op.mask.val != 0].flatten(),
                    scale=scale,
                    offset=zero_point,
                    before_op=compressed_var.op,
                    name=compressed_var.op.name + "_quantized",
                )
                return mb.constexpr_sparse_to_dense(nonzero_data=nonzero_data, mask=mask, name=name)
            elif compressed_var.op.op_type == "constexpr_lut_to_dense":
                if not types.is_int(compressed_var.dtype):
                    raise ValueError(
                        "The joint palettization+quantization only supports lut with "
                        f"int entries, but got {types.builtin_to_string(compressed_var.dtype)}"
                    )
                return mb.constexpr_blockwise_shift_scale(
                    data=compressed_var,
                    scale=scale,
                    offset=zero_point,
                    name=name,
                )
            else:
                raise ValueError(
                    "Unsupported joint compression combination. The quantization can only be joint "
                    f"with pruning or palettization, but got {compressed_var.op.op_type}. Please check the value of "
                    "'compression_type' in your registered buffers."
                )
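In the pruning branch, only the values at the mask's nonzero positions are handed to the sparse op (`nonzero_data=quantized_data[compressed_var.op.mask.val != 0].flatten()`). A small sketch of that selection with a hypothetical mask standing in for `compressed_var.op.mask.val`:

```python
import numpy as np

mask = np.array([[1, 0, 1],
                 [0, 0, 1]], dtype=np.int8)            # pruning mask
quantized_data = np.array([[5, 9, -2],
                           [7, 7, 3]], dtype=np.int8)  # quantized weights

# The sparse op stores only the surviving (unpruned) values.
nonzero_data = quantized_data[mask != 0].flatten()     # array([ 5, -2,  3])
```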

def _construct_palettization_op(
self,