🐛 [Bug] API Usage Error when using bfloat16 and use_explicit_typing #3439

HolyWu opened this issue Mar 12, 2025 · 1 comment · May be fixed by #3445

HolyWu commented Mar 12, 2025

To Reproduce

from __future__ import annotations

import os

import torch
import torch_tensorrt

os.environ["CI_BUILD"] = "1"

dtype = torch.bfloat16
device = torch.device("cuda", 0)


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


with torch.inference_mode():
    model = MyModule().eval().to(device, dtype)
    inputs = (torch.randn(1, 3, 224, 224, dtype=dtype, device=device),)
    exported_program = torch.export.export(model, inputs)

    trt_model = torch_tensorrt.dynamo.compile(
        exported_program,
        inputs,
        device=device,
        enabled_precisions={torch.float32},
        debug=True,
        min_block_size=1,
        use_explicit_typing=True,
    )
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %x : [num_users=1] = placeholder[target=x]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
    return (conv2d,)
WARNING:py.warnings:/home/holywu/.local/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
  torch._C._set_onednn_allow_tf32(_allow_tf32)

DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %conv_weight : [num_users=1] = get_attr[target=conv.weight]
    %conv_bias : [num_users=1] = get_attr[target=conv.bias]
    %x : [num_users=1] = placeholder[target=x]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %conv_weight, %conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    return (convolution,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %conv_weight : [num_users=1] = get_attr[target=conv.weight]
    %conv_bias : [num_users=1] = get_attr[target=conv.bias]
    %x : [num_users=1] = placeholder[target=x]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %conv_weight, %conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    return (convolution,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_nodes:Removed 0 assert_scalar nodes:
graph():
    %conv_weight : [num_users=1] = get_attr[target=conv.weight]
    %conv_bias : [num_users=1] = get_attr[target=conv.bias]
    %x : [num_users=1] = placeholder[target=x]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %conv_weight, %conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    return (convolution,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %conv_weight : [num_users=1] = get_attr[target=conv.weight]
    %conv_bias : [num_users=1] = get_attr[target=conv.bias]
    %x : [num_users=1] = placeholder[target=x]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %conv_weight, %conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    return (convolution,)
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.convolution.default: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.convolution.default
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.convolution.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.convolution.default: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.convolution.default
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.convolution.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
 Input shapes: [(1, 3, 224, 224)]
 graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv_weight : [num_users=1] = get_attr[target=conv.weight]
    %conv_bias : [num_users=1] = get_attr[target=conv.bias]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %conv_weight, %conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    return convolution
WARNING:py.warnings:/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/utils.py:423: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  return torch.tensor(tensor).dtype

DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.convolution.default: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.convolution.default
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[1, 3, 224, 224], dtype=DataType.BF16]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (1, 3, 224, 224)@torch.bfloat16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node conv_weight (kind: conv.weight, args: ())
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node conv_weight [conv.weight] (Inputs: () | Outputs: (conv_weight: (3, 3, 3, 3)@float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node conv_bias (kind: conv.bias, args: ())
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node conv_bias [conv.bias] (Inputs: () | Outputs: (conv_bias: (3,)@float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node conv/convolution (kind: aten.convolution.default, args: ('x <Node>', 'conv_weight <Node>', 'conv_bias <Node>', ['1 <int>', '1 <int>'], ['0 <int>', '0 <int>'], ['1 <int>', '1 <int>'], 'False <bool>', ['0 <int>', '0 <int>'], '1 <int>'))
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.convolution.default: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.convolution.default
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node conv/convolution [aten.convolution.default] (Inputs: (x: (1, 3, 224, 224)@torch.bfloat16, conv_weight: (3, 3, 3, 3)@float32, conv_bias: (3,)@float32, [1, 1], [0, 0], [1, 1], False, [0, 0], 1) | Outputs: (convolution: (1, 3, 222, 222)@torch.bfloat16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('convolution <Node>',))
ERROR:torch_tensorrt [TensorRT Conversion Context]:ITensor::getDimensions: Error Code 4: API Usage Error ([CONVOLUTION]-[aten_ops.convolution.default]-[conv/convolution]: IConvolutionLayer `input` and `kernel` must be of same type. `input` type is BFloat16 but `kernel` is of type Float.)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(81), dtype=DataType.BF16]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (convolution: (1, 3, 222, 222)@torch.bfloat16) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.017973
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
ERROR:torch_tensorrt [TensorRT Conversion Context]:ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [CONVOLUTION]-[aten_ops.convolution.default]-[conv/convolution].)
ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 4: API Usage Error ([CONVOLUTION]-[aten_ops.convolution.default]-[conv/convolution]: IConvolutionLayer `input` and `kernel` must be of same type. `input` type is BFloat16 but `kernel` is of type Float.)
Traceback (most recent call last):
  File "/home/holywu/test.py", line 28, in <module>
    trt_model = torch_tensorrt.dynamo.compile(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 681, in compile
    trt_gm = compile_module(
             ^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 885, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 90, in convert_module
    interpreter_result = interpret_module_to_result(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 69, in interpret_module_to_result
    interpreter_result = interpreter.run()
                         ^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 717, in run
    assert serialized_engine
           ^^^^^^^^^^^^^^^^^
AssertionError

Environment

  • Torch-TensorRT Version: 2.7.0.dev20250311+cu128
  • PyTorch Version: 2.7.0.dev20250311+cu128
  • CPU Architecture: x64
  • OS: Ubuntu 24.04.1 LTS
  • How you installed PyTorch: pip
  • Python version: 3.12.3
  • CUDA version: 12.8
  • GPU models and configuration: RTX 4060 Ti
HolyWu added the bug label on Mar 12, 2025

peri044 commented Mar 14, 2025

@HolyWu Thanks for reporting this. The issue occurs because the model's bfloat16 weights are converted to FP32 during the conversion phase, so the convolution's input and kernel end up with different data types. One approach that seems to work is casting the weights back to BF16 within the converter:

if isinstance(weight, torch.Tensor) and weight.dtype == torch.bfloat16:
    # numpy has no bfloat16, so materialize the weight as an FP32
    # network constant first...
    weight_fp32 = weight.to(torch.float32)
    weight_trt_tensor = get_trt_tensor(ctx, weight_fp32, f"{name}_weight_fp32")
    # ...then cast it back to BF16 so the IConvolutionLayer sees
    # matching input and kernel types.
    weight = cast_trt_tensor(ctx, weight_trt_tensor, trt.bfloat16, f"{name}_weight_bf16")

The above change ensures the engine builds successfully. I'm investigating other alternatives and will open a PR once I have the fix ready.
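For anyone who wants to see the pattern in isolation, below is a minimal sketch of the same constant-then-cast idea using the raw TensorRT Python API rather than Torch-TensorRT internals. It assumes TensorRT 10.x and a strongly typed network (which is what use_explicit_typing=True maps to); the names and layer wiring are illustrative only, not the actual fix.

import tensorrt as trt
import torch

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
# Strongly typed networks correspond to use_explicit_typing=True.
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
)

# BF16 activation input, matching the repro above.
x = network.add_input("x", trt.DataType.BF16, (1, 3, 224, 224))

# BF16 weights from PyTorch. numpy has no bfloat16, so upcast to FP32
# before handing them to TensorRT as a constant...
weight = torch.nn.Conv2d(3, 3, 3).weight.detach().to(torch.bfloat16)
weight_fp32 = weight.to(torch.float32).numpy()
const = network.add_constant(weight_fp32.shape, trt.Weights(weight_fp32))

# ...then cast the constant back to BF16 in-network so the convolution
# sees matching input and kernel types.
cast = network.add_cast(const.get_output(0), trt.DataType.BF16)

# Supply the kernel as a tensor input: pass empty Weights and wire the
# casted constant into input slot 1 of the convolution layer.
conv = network.add_convolution_nd(
    input=x,
    num_output_maps=3,
    kernel_shape=(3, 3),
    kernel=trt.Weights(),
    bias=trt.Weights(),
)
conv.set_input(1, cast.get_output(0))
network.mark_output(conv.get_output(0))

Building an engine from this network still needs a builder config; the point is only that the input and kernel reach the IConvolutionLayer with the same type, which is what the in-converter cast above achieves.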
