❓ [Question] Manually Annotate Quantization Parameters in FX Graph #3522

patrick-botco opened this issue May 16, 2025 · 11 comments

patrick-botco commented May 16, 2025

❓ Question

Is there a way to manually annotate quantization parameters so that they are respected throughout torch_tensorrt conversion (e.g., by manually adding Q/DQ nodes, or by specifying some tensor metadata) via Dynamo? Thank you!

patrick-botco added the question label on May 16, 2025
patrick-botco changed the title from "❓ [Question] Quantization IR w/ Dynamo" to "❓ [Question] Manually Annotate Quantization Parameters in FX Graph" on May 16, 2025
patrick-botco (Author)

cc @narendasan @peri044 maybe? 🙏

narendasan (Collaborator)

This should be possible, as it is effectively what the TensorRT Model Optimizer toolkit does. @peri044 or @lanluo-nvidia could maybe give more specific guidance.

peri044 (Collaborator) commented May 19, 2025

We currently use the NVIDIA Model Optimizer toolkit, which inserts quantization nodes into the torch model via its quantize API:

  1. https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/9c54aa1c47871d0541801a20962996461d805162/modelopt/torch/quantization/model_quant.py#L126
  2. https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/9c54aa1c47871d0541801a20962996461d805162/modelopt/torch/quantization/tensor_quant.py#L229-L243 (definition of the custom ops which do the quantization). We have converters for these quantization custom ops (which call the Q & DQ APIs in TensorRT).

You can also manually insert a quantization custom op by implementing a lowering pass which adds these nodes to the torch.fx.GraphModule, and then implementing/registering a custom converter for it. You can attach custom metadata to such a node by updating node.meta["val"] (a rough sketch follows the links below).

  1. https://docs.pytorch.org/TensorRT/contributors/writing_dynamo_aten_lowering_passes.html (existing lowering passes)
  2. https://docs.pytorch.org/TensorRT/contributors/dynamo_converters.html
    These can be done outside the Torch-TRT codebase by using the decorators documented above to register your lowering pass/converter.
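
For illustration, here is a rough, self-contained sketch of such a pass. Note that myns::fake_quant is a made-up op defined only for this example (as are the placeholder amax=1.0 and num_bits=8); in practice you would insert the quantize custom op that Torch-TRT already converts, or register your own converter for your op as described in the links above.

import torch

# Hypothetical fake-quant custom op, for illustration only.
@torch.library.custom_op("myns::fake_quant", mutates_args=())
def fake_quant(x: torch.Tensor, amax: float, num_bits: int) -> torch.Tensor:
    qmax = 2 ** (num_bits - 1) - 1
    scale = amax / qmax
    return torch.clamp(torch.round(x / scale), -qmax - 1, qmax) * scale

@fake_quant.register_fake
def _(x: torch.Tensor, amax: float, num_bits: int) -> torch.Tensor:
    return torch.empty_like(x)

def insert_input_quant(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Sketch of a lowering pass: quantize every graph input with a fixed amax."""
    for node in list(gm.graph.nodes):
        if node.op != "placeholder":
            continue
        with gm.graph.inserting_after(node):
            q = gm.graph.call_function(
                torch.ops.myns.fake_quant.default, args=(node, 1.0, 8)
            )
        # reroute every other user of the input through the new quantize node
        node.replace_all_uses_with(q, delete_user_cb=lambda user: user is not q)
        # attach tensor metadata so downstream passes/converters can inspect it
        q.meta["val"] = node.meta.get("val")
    gm.graph.lint()
    gm.recompile()
    return gm

To use it with Torch-TRT you would register the pass and a converter for the op via the registration mechanisms described in the two links above.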

Please let me know if you have any further questions.

patrick-botco (Author)

Hey @peri044, thanks for the response. I tried modelopt -> export on the simple model below. Am I using this wrong or missing something obvious? I'm using non-strict export (strict runs into torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(_DMAttributeManager)), but I'm hitting ValueError: Node type mismatch; expected <class 'tuple'>, but got <class 'torch.Size'>. Thanks!

import modelopt.torch.quantization as mtq
import torch
from modelopt.torch.quantization.utils import export_torch_mode


class JustAConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, inputs):
        return self.conv(inputs)


if __name__ == "__main__":
    model = JustAConv().to("cuda").eval()
    sample_input = torch.ones(1, 3, 224, 224).to("cuda")
    quant_cfg = mtq.INT8_DEFAULT_CFG
    # calibrate the model and insert fake-quant (Q/DQ) nodes
    mtq.quantize(
        model,
        quant_cfg,
        forward_loop=lambda model: model(sample_input),
    )

    with torch.no_grad():
        # export_torch_mode() lets torch.export trace the ModelOpt quantizers
        with export_torch_mode():
            exported_program = torch.export.export(model, (sample_input,), strict=False)

lanluo-nvidia (Collaborator)

@patrick-botco I have tried your example with our latest main; with strict=False it works as expected.
I suspect your error is related to your specific versions.
Could you please let me know which versions you are using?

patrick-botco (Author)

Hey @lanluo-nvidia, thanks for checking! Here are my pytorch and modelopt versions:

nvidia-modelopt           0.29.0
nvidia-modelopt-core      0.29.0
torch                     2.5.1

lanluo-nvidia (Collaborator)

@patrick-botco
I also remember that torch.export.export fails with strict=False at some point around torch 2.5.
If you cannot use a newer torch version, the following workaround might help bypass the torch.export.export error.

from torch.export._trace import _export  # note: private torch.export API
exp_program = _export(model, (input_tensor,))

patrick-botco (Author) commented Jun 1, 2025

Thanks @lanluo-nvidia, upgrading to torch 2.6 resolves the issue. Compiling the exported program gives me something unexpected, though.

For reference, the model (after mtq.quantize()) is:

JustAConv(
  (conv): QuantConv2d(
    3, 3, kernel_size=(3, 3), stride=(1, 1)
    (input_quantizer): TensorQuantizer(8 bit fake per-tensor amax=1.0000 calibrator=MaxCalibrator quant)
    (output_quantizer): TensorQuantizer(disabled)
    (weight_quantizer): TensorQuantizer(8 bit fake axis=0 amax=[0.1883, 0.1920](3) calibrator=MaxCalibrator quant)
  )
)

The issue appears when compiling the exported program:

# continuing from above
trt_model = torch_tensorrt.dynamo.compile(
    exported_program,
    inputs=(sample_input,),
    enabled_precisions={torch.int8},
    min_block_size=1,
    debug=True,
)

The graph after the initial lowering passes looks good:

graph():
    %conv_weight : [num_users=1] = get_attr[target=conv.weight]
    %conv_bias : [num_users=1] = get_attr[target=conv.bias]
    %conv_input_quantizer__amax : [num_users=1] = get_attr[target=conv.input_quantizer._amax]
    %conv_weight_quantizer__amax : [num_users=1] = get_attr[target=conv.weight_quantizer._amax]
    %inputs : [num_users=1] = placeholder[target=inputs]
    %quantize_op : [num_users=1] = call_function[target=torch.ops.tensorrt.quantize_op.default](args = (%inputs, %conv_input_quantizer__amax, 8, 0, False, False), kwargs = {})
    %quantize_op_1 : [num_users=1] = call_function[target=torch.ops.tensorrt.quantize_op.default](args = (%conv_weight, %conv_weight_quantizer__amax, 8, 0, False, False), kwargs = {})
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%quantize_op, %quantize_op_1, %conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    return (convolution,)

However, after constant folding, %quantize_op_1 is optimized away and replaced by %_frozen_param0:

graph():
    %conv_bias : [num_users=1] = get_attr[target=conv.bias]
    %conv_input_quantizer__amax : [num_users=1] = get_attr[target=conv.input_quantizer._amax]
    %inputs : [num_users=1] = placeholder[target=inputs]
    %quantize_op : [num_users=1] = call_function[target=torch.ops.tensorrt.quantize_op.default](args = (%inputs, %conv_input_quantizer__amax, 8, 0, False, False), kwargs = {})
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%quantize_op, %_frozen_param0, %conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    return (convolution,)

Per-channel weight quantization is not respected; %_frozen_param0 appears to be float32 (_frozen_param0: (3, 3, 3, 3)@float32):

INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node conv/convolution [aten.convolution.default] (Inputs: (quantize_op: (1, 3, 224, 224)@torch.float32, _frozen_param0: (3, 3, 3, 3)@float32, conv_bias: (3,)@float32, [1, 1], [0, 0], [1, 1], False, [0, 0], 1) | Outputs: (convolution: (1, 3, 222, 222)@torch.float32))

More importantly, the GEMM kernel itself is f32f32_f32f32_f32 (obtained through torch.profiler). The int8 layout conversions cuInt8::nchwToNcqhw4 and cuInt8::ncqhw4ToNchw make it look like we are effectively doing fake quantization:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
sm80_xmma_fprop_implicit_gemm_f32f32_f32f32_f32_nchw...         0.00%       0.000us         0.00%       0.000us       0.000us       7.968us        60.73%       7.968us       7.968us             1  
cuInt8::nchwToNcqhw4(float const*, unsigned int*, in...         0.00%       0.000us         0.00%       0.000us       0.000us       2.784us        21.22%       2.784us       2.784us             1  
cuInt8::ncqhw4ToNchw(signed char const*, float*, int...         0.00%       0.000us         0.00%       0.000us       0.000us       2.368us        18.05%       2.368us       2.368us             1  
                           cudaStreamCreateWithPriority        90.65%       9.377ms        90.65%       9.377ms      73.255us       0.000us         0.00%       0.000us       0.000us           128  
                                        cudaEventRecord         0.07%       7.380us         0.07%       7.380us       3.690us       0.000us         0.00%       0.000us       0.000us             2  
                                    cudaStreamWaitEvent         0.08%       8.602us         0.08%       8.602us       4.301us       0.000us         0.00%       0.000us       0.000us             2  
                                       cudaLaunchKernel         0.49%      50.595us         8.81%     911.265us     455.632us       0.000us         0.00%       0.000us       0.000us             2  
                                           Unrecognized         8.32%     860.670us         8.32%     860.670us     215.168us       0.000us         0.00%       0.000us       0.000us             4  
                                       cuLaunchKernelEx         0.07%       7.510us         0.07%       7.510us       7.510us       0.000us         0.00%       0.000us       0.000us             1  
                                  cudaDeviceSynchronize         0.31%      32.371us         0.31%      32.371us      32.371us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------

Do you happen to know what the issue is? Am I using this wrong / missing something? Thanks! cc @peri044 @narendasan as well 🙏

I am using these versions to test:

torch                     2.6.0
torch_tensorrt            2.6.0
nvidia-modelopt           0.29.0
nvidia-modelopt-core      0.29.0

lanluo-nvidia (Collaborator)

@patrick-botco
Yes, constant folding should skip the quantize ops. The fix (excluding torch.ops.tensorrt.quantize_op.default from constant folding) is already part of another PR. Let me first create a separate bug-fix PR for you, so that it can be merged to main ASAP.
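
For context, the point of the fix is that constant folding must treat the quantize op as non-foldable; otherwise the fake-quantized weight gets baked back into an fp32 constant (the %_frozen_param0 above). A standalone sketch of that check (not the actual Torch-TRT code, just an illustration of the idea) could look like:

import torch

def is_foldable(node: torch.fx.Node, disallowed_targets: set) -> bool:
    """Conceptual constant-folding check: a call_function node whose tensor
    inputs are all constants (get_attr) may be folded, unless its target is on
    the exclusion list."""
    if node.op != "call_function" or node.target in disallowed_targets:
        return False
    return all(
        (not isinstance(arg, torch.fx.Node)) or arg.op == "get_attr"
        for arg in node.args
    )

# With the fix, the exclusion list would contain the quantize op, e.g.:
# disallowed_targets = {torch.ops.tensorrt.quantize_op.default}
# (the op is only registered once torch_tensorrt is imported)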

lanluo-nvidia (Collaborator) commented Jun 1, 2025

Here is the PR I raised: #3543
I have verified with your example that the kernel invoked is sm80_xmma_fprop_implicit_gemm_interleaved_i8f32_i8i32_f32:

Name: conv.weight_quantizer/quantize_op_1 + [QUANTIZE]-[aten_ops.quantize_op.default]-[conv.weight_quantizer/quantize_op_1_quantize] + [CONVOLUTION]-[aten_ops.convolution.default]-[conv/convolution], LayerType: CaskConvolution, Inputs: [ { Name: (Unnamed Layer* 1) [Quantize]_output, Location: Device, Dimensions: [1,3,224,224], Format/Datatype: Int8 }, { Name: (Unnamed Layer* 7) [Constant]_output, Location: Device, Dimensions: [3], Format/Datatype: Float }], Outputs: [ { Name: output0, Location: Device, Dimensions: [1,3,222,222], Format/Datatype: Float }], ParameterType: Convolution, Kernel: [3,3], PaddingMode: kEXPLICIT_ROUND_DOWN, PrePadding: [0,0], PostPadding: [0,0], Stride: [1,1], Dilation: [1,1], OutMaps: 3, Groups: 1, Weights: {"Type": "Int8", "Count": 81}, Bias: {"Type": "Float", "Count": 0}, HasBias: 0, HasReLU: 0, HasSparseWeights: 0, HasDynamicFilter: 0, HasDynamicBias: 1, HasResidual: 0, ConvXAsActInputIdx: -1, BiasAsActInputIdx: -1, ResAsActInputIdx: -1, Activation: NONE, TacticName: sm80_xmma_fprop_implicit_gemm_interleaved_i8f32_i8i32_f32_nchw_vect_c_32kcrs_vect_c_32_nchw_tilesize128x32x64_stage4_warpsize4x1x1_g1_tensor16x8x32_t1r3s3_alignc4, TacticValue: 0xa8b56a226b057463, StreamId: 0, Metadata:
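
For anyone who wants to reproduce the verification, the kernel name can be confirmed with torch.profiler (a rough sketch; trt_model and sample_input refer to the objects from the repro earlier in this thread):

import torch
from torch.profiler import ProfilerActivity, profile

# Profile a single forward pass of the compiled TensorRT module.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    trt_model(sample_input)
    torch.cuda.synchronize()

# A real int8 convolution shows up as an ...i8f32_i8i32... implicit-GEMM kernel,
# whereas fake quantization falls back to the f32f32_f32f32_f32 kernel.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))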

patrick-botco (Author)

Thanks so much @lanluo-nvidia!

lanluo-nvidia self-assigned this Jun 2, 2025