Skip to content

Commit

Permalink
addressing comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Jan 7, 2025
1 parent ee5d49d commit 984c8d2
Showing 1 changed file with 80 additions and 137 deletions.
217 changes: 80 additions & 137 deletions docs/quantization.md
Original file line number Diff line number Diff line change
@@ -1,35 +1,37 @@
# Quantization in StableHLO
# StableHLO Quantization

StableHLO quantization follows the [LiteRT quantization
specification](https://ai.google.dev/edge/litert/models/quantization_spec),
using a uniform quantization scheme with support for both per-tensor and
per-axis quantization. It inherits its type expression from MLIR's [Quant
dialect](https://mlir.llvm.org/docs/Dialects/QuantDialect/), providing a
standardized way to represent quantized data types.

## Uniform Quantization
## Quantization Types in StableHLO

Quantization is a technique to optimize machine learning models by
converting floating-point numbers (like those used in original models)
into lower-precision integers. This reduces memory usage and speeds up
computations, making models more efficient for deployment on devices with
limited resources.

One common approach is uniform quantization, where we map the
floating-point values to integers that are evenly spaced. To do this, we
use two key quantization parameters:
StableHLO quantization follows the [LiteRT quantization
specification](https://ai.google.dev/edge/litert/models/quantization_spec),
using a uniform quantization scheme with support for both per-tensor and
per-axis quantization. It inherits its type expression from MLIR's [Quant
dialect](https://mlir.llvm.org/docs/Dialects/QuantDialect/), providing a
standardized way to represent quantized data types.

Uniform quantization maps floating-point values to integers using a uniform step
size, resulting in evenly spaced quantized values. This is achieved through an
affine ralationship using two key quantization parameters.

- **Scale:** This determines the step size between consecutive quantized values.
A smaller scale means the quantized values are closer together, allowing for
finer-grained representation.
- **Zero Point:** This integer value represents the real value 0 in the
quantized space.
Uniform quantization simplifies the representation of floating-point numbers by
mapping them to integers that are evenly spaced. This mapping is achieved
through an affine transformation that uses two key parameters: **scale** and
**zero point**. The scale determines determines the step size between
consecutive quantized values. A smaller scale means the quantized values are
closer together. The zero point defines the integer value that represents zero
in the original floating-point space.

The relationship between the original floating-point value (`real_value`) and
the quantized integer value (`quantized_value`) in uniform quantization is:

```python
real_value = scale * (quantized_value + zero_point)
real_value = scale * (quantized_value - zero_point)
```

### Per-tensor Quantization
Expand All @@ -56,27 +58,27 @@ along a specific dimension `quantized_dimension` of the tensor. This is
particularly useful when values vary significantly across different dimensions,
allowing for better preservation of information and accuracy.

Consider a tensor t with dimensions sizes `[4, 3, 2, 1]`. We choose to quantize
Consider a tensor t with dimensions sizes `[4, 3, 2]`. We choose to quantize
this tensor along the second dimension (`quantized_dimension = 1`). This means
we'll have three slices (since the second dimension has a size of 3), each with
its own scale and zero point:

```python
t[:, 0, :, :]: This slice gets scale[0] and zero_point[0].
t[:, 1, :, :]: This slice gets scale[1] and zero_point[1].
t[:, 2, :, :]: This slice gets scale[2] and zero_point[2].
t[:, 0, :]: This slice gets scale[0] and zero_point[0].
t[:, 1, :]: This slice gets scale[1] and zero_point[1].
t[:, 2, :]: This slice gets scale[2] and zero_point[2].
```

In StableHLO, per-axis quantized type is expressed as:

```mlir
!quant.uniform<storage_type:expressed_type:quantized_dimension, {scale:zero_point, scale:zero_point, ...}>
!quant.uniform<storage_type:expressed_type:quantized_dimension, {scale0:zero_point0, scale1:zero_point1, ...}>
```

where the length of the `scale:zero_point` matches the number of slices along
the `quantized_dimension` of the containing tensor.

**Example**: `!quant.uniform<i8:f32:1, {0.2:20, 0.1:10, 0.3:30}>`
**Example**: `tensor<4x3x2x!quant.uniform<i8:f32:1, {0.2:20, 0.1:10, 0.3:30}>>`

**Note**: StableHLO will soon support _sub-channel quantization_, which allows
for quantization along a subset of dimensions. This feature is currently in
Expand All @@ -89,103 +91,29 @@ StableHLO provides several compiler passes which allow for different
transformations and optimizations related to quantization, giving you
flexibility in how you handle quantized models. These passes are:

### `stablehlo-legalize-qdq-to-quantized-op` [code](https://github.com/openxla/stablehlo/blob/main/stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp)
### `stablehlo-legalize-qdq-to-quantized-op`

This pass fuses a common pattern in quantized models, a dequantize operation
followed by a floating-point operation, and finally a quantize operation, into
a single quantized operation.

**Example:**
a single quantized operation. [details](https://openxla.org/stablehlo/generated/stablehlo_passes#-stablehlo-legalize-qdq-to-quantized-op)

```mlir
// Before the pass
func.func @add(%arg0: tensor<16x16x!quant.uniform<ui8:f32, 34.0:16>>) -> tensor<16x16x!quant.uniform<ui8:f32, 34.0:16>> {
%0 = stablehlo.uniform_dequantize %arg0 : (tensor<16x16x!quant.uniform<ui8:f32, 34.0:16>>) -> tensor<16x16xf32>
%1 = stablehlo.abs %0 : tensor<16x16xf32>
%2 = stablehlo.uniform_quantize %1 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform<ui8:f32, 34.0:16>>
func.return %2 : tensor<16x16x!quant.uniform<ui8:f32, 34.0:16>>
}
// After the pass
func.func @add(%arg0: tensor<16x16x!quant.uniform<u8:f32, 3.400000e+01:16>>) -> tensor<16x16x!quant.uniform<u8:f32, 3.400000e+01:16>> {
%0 = stablehlo.uniform_dequantize %arg0 : (tensor<16x16x!quant.uniform<u8:f32, 3.400000e+01:16>>) -> tensor<16x16xf32>
%1 = stablehlo.abs %0 : tensor<16x16xf32>
%2 = stablehlo.abs %arg0 : tensor<16x16x!quant.uniform<u8:f32, 3.400000e+01:16>>
%3 = stablehlo.uniform_quantize %1 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform<u8:f32, 3.400000e+01:16>>
return %2 : tensor<16x16x!quant.uniform<u8:f32, 3.400000e+01:16>>
}
```

### stablehlo-legalize-quantized-op-to-qdq [code](https://github.com/openxla/stablehlo/blob/main/stablehlo/transforms/StablehloLegalizeQuantizedOpToQDQ.cpp)
### stablehlo-legalize-quantized-op-to-qdq

This pass does the opposite of the previous pass. It decomposes a quantized
operation into its equivalent sequence of dequantize, floating-point operation,
and quantize operations.
and quantize operations.
[details](https://openxla.org/stablehlo/generated/stablehlo_passes#-stablehlo-legalize-quantized-op-to-qdq)

**Example:**

```mlir
// Before the pass
func.func @add(%arg0: tensor<!quant.uniform<i8:f32,1.0:0>>, %arg1: tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>> {
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32,1.0:0>>, tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>>
func.return %0 : tensor<!quant.uniform<i8:f32,3.0:2>>
}
// After the pass
func.func @add(%arg0: tensor<!quant.uniform<i8:f32, 1.000000e+00>>, %arg1: tensor<!quant.uniform<i8:f32, 2.000000e+00:1>>) -> tensor<!quant.uniform<i8:f32, 3.000000e+00:2>> {
%0 = stablehlo.uniform_dequantize %arg0 : (tensor<!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<f32>
%1 = stablehlo.uniform_dequantize %arg1 : (tensor<!quant.uniform<i8:f32, 2.000000e+00:1>>) -> tensor<f32>
%2 = stablehlo.add %0, %1 : tensor<f32>
%3 = stablehlo.uniform_quantize %2 : (tensor<f32>) -> tensor<!quant.uniform<i8:f32, 3.000000e+00:2>>
return %3 : tensor<!quant.uniform<i8:f32, 3.000000e+00:2>>
}
```

### stablehlo-legalize-quant-to-math [code](https://github.com/openxla/stablehlo/blob/main/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp)
### stablehlo-legalize-quant-to-math

This pass converts StableHLO operations on quantized types into equivalent
operations on integer types. It essentially implements the quantization
arithmetic using standard mathematical operations. This decompsition is useful
for systems that do not support quantization natively, but can still use the
quantization arithmetic to express the semantics of quantized models.
[details](https://openxla.org/stablehlo/generated/stablehlo_passes#-stablehlo-legalize-quant-to-math)

**Example:**

```mlir
// Before the pass
func.func @add(%arg0: tensor<!quant.uniform<i8:f32,1.0:0>>, %arg1: tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>> {
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32,1.0:0>>, tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>>
func.return %0 : tensor<!quant.uniform<i8:f32,3.0:2>>
}
// After the pass
func.func @add(%arg0: tensor<i8>, %arg1: tensor<i8>) -> tensor<i8> {
%0 = stablehlo.convert %arg0 : (tensor<i8>) -> tensor<f32>
%cst = stablehlo.constant dense<0.333333343> : tensor<f32>
%1 = chlo.broadcast_multiply %0, %cst : (tensor<f32>, tensor<f32>) -> tensor<f32>
%cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
%2 = chlo.broadcast_add %1, %cst_0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%3 = stablehlo.round_nearest_even %2 : tensor<f32>
%4 = stablehlo.convert %3 : (tensor<f32>) -> tensor<i32>
%5 = stablehlo.convert %arg1 : (tensor<i8>) -> tensor<f32>
%cst_1 = stablehlo.constant dense<0.666666686> : tensor<f32>
%6 = chlo.broadcast_multiply %5, %cst_1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%cst_2 = stablehlo.constant dense<1.33333337> : tensor<f32>
%7 = chlo.broadcast_add %6, %cst_2 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%8 = stablehlo.round_nearest_even %7 : tensor<f32>
%9 = stablehlo.convert %8 : (tensor<f32>) -> tensor<i32>
%c = stablehlo.constant dense<2> : tensor<i32>
%10 = chlo.broadcast_add %4, %9 : (tensor<i32>, tensor<i32>) -> tensor<i32>
%11 = chlo.broadcast_subtract %10, %c : (tensor<i32>, tensor<i32>) -> tensor<i32>
%c_3 = stablehlo.constant dense<-128> : tensor<i32>
%c_4 = stablehlo.constant dense<127> : tensor<i32>
%12 = stablehlo.clamp %c_3, %11, %c_4 : tensor<i32>
%13 = stablehlo.convert %12 : (tensor<i32>) -> tensor<i8>
return %13 : tensor<i8>
}
```

## stablehlo-quant-legalize-to-tosa-rescale [code](https://github.com/openxla/stablehlo/blob/main/stablehlo/conversions/tosa/transforms/StablehloQuantLegalizeToTosaRescale.cpp)
## stablehlo-quant-legalize-to-tosa-rescale

StableHLO offers the capability to legalize quantized operations to their
corresponding representations in the [TOSA
Expand All @@ -196,42 +124,14 @@ StableHLO and TOSA operations, with the TOSA dialect primarily employed for the
`rescale` operation. The `tosa.rescale` op plays a crucial role in adjusting the
scale and zero point of quantized values, enabling accurate representation of
quantized data within the TOSA framework.
[details](https://openxla.org/stablehlo/generated/stablehlo_tosa_passes#-stablehlo-quant-legalize-to-tosa-rescale)

**Example:**

Consider the following StableHLO code snippet:

```mlir
// Before the pass
func.func @add(%arg0 : tensor<2x2x!quant.uniform<i8:f32, 0.025:-1>>,
%arg1 : tensor<2x2x!quant.uniform<i8:f32, 0.075:-1>>) -> tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-1>> {
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<2x2x!quant.uniform<i8:f32, 0.025:-1>>, tensor<2x2x!quant.uniform<i8:f32, 0.075:-1>>) -> tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-1>>
return %0 : tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-1>>
}
// After the pass
func.func @add(%arg0: tensor<2x2x!quant.uniform<i8:f32, 0.025:-1>>,
%arg1 : tensor<2x2x!quant.uniform<i8:f32, 0.075:-1>>) -> tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-1>> {
%r0 = "tosa.rescale"(%0) {double_round = true, input_zp = -1 : i32, multiplier = array<i32: 1431655765>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 13>} : (tensor<2x2x!quant.uniform<i8:f32, 0.025:-1>>) -> tensor<2x2xi32>
%r1 = "tosa.rescale"(%1) {double_round = true, input_zp = -1 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 11>} : (tensor<2x2x!quant.uniform<i8:f32, 0.075:-1>>) -> tensor<2x2xi32>
%add = "stablehlo.add"(%r0, %r1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
%r2 = "tosa.rescale"(%add) {double_round = true, input_zp = -1 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 50>} : (tensor<2x2xi32>) -> tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-1>>
return %r2 : tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-1>>
}
```

The `stablehlo.add` operation now works on integer types (`i32`), and the
`tosa.rescale` operations are strategically placed to manage the necessary
scaling and zero-point adjustments for accurate quantization.

## tosa-rescale-legalize-to-stablehlo [code](https://github.com/openxla/stablehlo/blob/main/stablehlo/conversions/tosa/transforms/TosaRescaleLegalizeToStablehlo.cpp)
## tosa-rescale-legalize-to-stablehlo

This pass rewrites TOSA rescale operations to StableHLO primitive math
operations. One of the main use cases for this pass is to allow the StableHLO
interpreter to evaluate programs containing TOSA rescale operations.
[details](https://openxla.org/stablehlo/generated/stablehlo_tosa_passes#-tosa-rescale-legalize-to-stablehlo)

## Evaluating Quantized Programs

Expand Down Expand Up @@ -264,3 +164,46 @@ stablehlo-canonicalize-dynamism
**Note:** There is an ongoing effort to improve the efficiency of this lowering
process. You can track the progress in this [open
issue](https://github.com/openxla/stablehlo/issues/2390).

## Quantized Test Cases

StableHLO provides a comprehensive suite of quantized test cases to validate the
correctness and behavior of quantized operations. These test cases serve as unit
tests, covering various StableHLO operations in quantized scenarios.

A typical example of a quantized test case looks like

```mlir
func.func @main() -> tensor<11xf32> {
%operand_0 = stablehlo.constant dense<...> : tensor<11xf32>
%operand_1 = stablehlo.constant dense<...> : tensor<11xf32>
%golden = stablehlo.constant dense<...> : tensor<11xf32>
%0 = stablehlo.uniform_quantize %operand_0 : (tensor<11xf32>) -> tensor<11x!quant.uniform<i8:f32, 0.3>>
%1 = stablehlo.uniform_quantize %operand_1 : (tensor<11xf32>) -> tensor<11x!quant.uniform<i8:f32, 0.3>>
%2 = stablehlo.add %1, %0 : tensor<11x!quant.uniform<i8:f32, 0.3>>
%result = stablehlo.uniform_dequantize %2 : (tensor<11x!quant.uniform<i8:f32, 0.3>>) -> tensor<11xf32>
%4 = stablehlo.custom_call @check.eq(%golden, %result) : (tensor<11xf32>, tensor<11xf32>) -> tensor<i1>
return %3 : tensor<11xf32>
}
```

and includes:

- **Input data:** Representative input values for the operation.
- **Golden output:** The expected output of the operation when applied to the
input data, complying with the StableHLO reference interpreter and the HLO
evaluator.

These test cases are valuable for:

- **Validating StableHLO quantization:** Ensuring that the quantization behavior
of StableHLO operations aligns with the expected results.
- **Cross-validation:** Comparing the behavior of StableHLO quantization with
other implementations or frameworks.
- **Debugging and development:** Aiding in the development and debugging of new
quantization features or optimizations.

0 comments on commit 984c8d2

Please sign in to comment.