[WARNING | py.warnings ]: /usr/local/lib/python3.12/dist-packages/torch/fx/graph.py:1801: UserWarning: Node prediction_src_pe_lifted_tensor_0 target Prediction.src_pe.lifted_tensor_0 lifted_tensor_0 of Prediction.src_pe does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(
[WARNING | py.warnings ]: /usr/local/lib/python3.12/dist-packages/torch/fx/graph.py:1801: UserWarning: Node prediction_src_pe_lifted_tensor_1 target Prediction.src_pe.lifted_tensor_1 lifted_tensor_1 of Prediction.src_pe does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(
[WARNING | py.warnings ]: /usr/local/lib/python3.12/dist-packages/torch/fx/graph.py:1801: UserWarning: Node prediction_trg_pe_lifted_tensor_2 target Prediction.trg_pe.lifted_tensor_2 lifted_tensor_2 of Prediction.trg_pe does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(
[WARNING | py.warnings ]: /usr/local/lib/python3.12/dist-packages/torch/fx/graph.py:1801: UserWarning: Node prediction_trg_pe_lifted_tensor_3 target Prediction.trg_pe.lifted_tensor_3 lifted_tensor_3 of Prediction.trg_pe does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(
[WARNING | py.warnings ]: /usr/local/lib/python3.12/dist-packages/torch/fx/graph.py:1801: UserWarning: Node prediction_trg_pe_lifted_tensor_4 target Prediction.trg_pe.lifted_tensor_4 lifted_tensor_4 of Prediction.trg_pe does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(
[WARNING | py.warnings ]: /usr/local/lib/python3.12/dist-packages/torch/fx/graph.py:1810: UserWarning: Additional 22 warnings suppressed about get_attr references
warnings.warn(
Traceback (most recent call last):
File "convert_torch_tensorrt.py", line 146, in <module>
convert_tensorrt(opt)
File "convert_torch_tensorrt.py", line 50, in convert_tensorrt
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/_compile.py", line 289, in compile
trt_graph_module = dynamo_compile(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 670, in compile
exported_program = exported_program.run_decompositions(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 128, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 1310, in run_decompositions
return _decompose_exported_program(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 784, in _decompose_exported_program
) = _decompose_and_get_gm_with_new_signature_constants(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/exported_program.py", line 472, in _decompose_and_get_gm_with_new_signature_constants
aten_export_artifact = _export_to_aten_ir(
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/export/_trace.py", line 743, in _export_to_aten_ir
gm, graph_signature = transform(aot_export_module)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 1357, in aot_export_module
fx_g, metadata, in_spec, out_spec = _aot_export_function(
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 1596, in _aot_export_function
fx_g, meta = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 582, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 832, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 118, in aot_dispatch_export
graph, _, _ = aot_dispatch_base_graph(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 153, in aot_dispatch_base_graph
fw_module = _create_graph(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 55, in _create_graph
fx_g = make_fx(
^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2200, in wrapped
return make_fx_tracer.trace(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2138, in trace
return self._trace_inner(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2109, in _trace_inner
t = dispatch_trace(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_compile.py", line 51, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 755, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1142, in dispatch_trace
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1698, in trace
res = super().trace(root, concrete_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 755, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/_symbolic_trace.py", line 843, in trace
(self.create_arg(fn(*args)),),
^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1197, in wrapped
out = f(*tensors) # type:ignore[call-arg]
^^^^^^^^^^^
File "<string>", line 1, in <lambda>
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 693, in inner_fn
outs = fn(*args)
^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 413, in _functionalized_f_helper
f_outs = fn(*f_args)
^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 78, in inner_fn
outs = fn(*args)
^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
tree_out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 875, in functional_call
out = PropagateUnbackedSymInts(mod).run(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/interpreter.py", line 167, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6826, in run_node
result = super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/interpreter.py", line 230, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/interpreter.py", line 310, in call_function
return target(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 758, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1245, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 758, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/functional_tensor.py", line 527, in __torch_dispatch__
outs_unwrapped = func._op_dk(
^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_stats.py", line 26, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1347, in __torch_dispatch__
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 793, in proxy_call
r = maybe_handle_decomp(proxy_mode, func, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2268, in maybe_handle_decomp
out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/lowering/_decompositions.py", line 208, in slice_scatter_decomposition
assert isinstance(end, int), "end must be an integer"
^^^^^^^^^^^^^^^^^^^^
AssertionError: end must be an integer
While executing %copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%slice_11, %clone_1), kwargs = {})
Original traceback:
File "/model.py", line 60, in forward
outs, probs = self.Prediction(features)
File "prediction.py", line 85, in forward
probs[:, step, :] = prob.clone().detach()
My model code (prediction.py) is as follows:
def forward(self, src):
    memory = self.encode(src)
    b = memory.size(0)
    # filling with [BOS] (index=1)
    outs = torch.ones(b, 1).fill_(1).long().to(self.device)
    probs = torch.zeros(b, self.max_len, self.num_class).to(self.device)
    for step in range(self.max_len - 1):
        probs = probs.clone().detach()
        # [B, step+1, d_model]
        out = self.decode(memory, outs, subsequent_mask(outs.size(1)).long().to(self.device))
        prob = self.generator(out[:, -1])  # [B, num_class]
        _, next_word = torch.max(prob, dim=1)  # [B]
        outs = torch.cat([outs, next_word.unsqueeze(1)], dim=1)  # [B, step+2]
        probs[:, step, :] = prob.clone().detach()  # <--- Error occurs at this line
    return outs, probs
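The assertion comes from torch_tensorrt's slice_scatter decomposition, which requires the slice end to be a concrete int; with a dynamic batch, the in-place write probs[:, step, :] = ... lowers to aten.slice_scatter with a symbolic end. A sketch of one possible workaround (not a verified fix) is to avoid the in-place slice assignment entirely and collect the per-step outputs in a list, stacking once at the end:

# Sketch of a forward() that avoids slice_scatter; assumes the same
# encode/decode/generator/subsequent_mask helpers as the original code.
def forward(self, src):
    memory = self.encode(src)
    b = memory.size(0)
    outs = torch.ones(b, 1).fill_(1).long().to(self.device)  # [BOS] = 1
    step_probs = []  # collect per-step probabilities instead of writing in place
    for step in range(self.max_len - 1):
        out = self.decode(memory, outs,
                          subsequent_mask(outs.size(1)).long().to(self.device))
        prob = self.generator(out[:, -1])      # [B, num_class]
        _, next_word = torch.max(prob, dim=1)  # [B]
        outs = torch.cat([outs, next_word.unsqueeze(1)], dim=1)  # [B, step+2]
        step_probs.append(prob)
    # pad the unused final step so the result matches [B, max_len, num_class]
    step_probs.append(torch.zeros_like(step_probs[-1]))
    probs = torch.stack(step_probs, dim=1)
    return outs, probs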
Expected behavior
The model file "rt_model.ep" should be created and saved.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages.
- Torch-TensorRT Version: 2.6.0a0
- PyTorch Version: 2.7.0a0+ecf3bae40a.nv25.2
- CPU Architecture: x86_64
- OS: Ubuntu 24.04.1 LTS
- How you installed PyTorch: Docker image "PyTorch Release 25.02" at link
- Python version: 3.12.3
- CUDA version: 12.8
- GPU models and configuration: RTX4000
Additional context
Even after adding or removing calls such as clone() and detach() on the probs variable in the example code where the error occurs, the same error is raised.
!!! important !!!
Model conversion succeeds when a static batch size is set for the inputs; setting a dynamic batch size for the inputs produces the error shown above.
Is there a way to take a dynamic batch size as input with code like the example provided?
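For clarity, a minimal sketch of the kind of dynamic-batch input spec meant here, using torch_tensorrt.Input with min/opt/max shapes; all shape values are hypothetical placeholders, not the real model's values:

import torch
import torch_tensorrt

# Hypothetical shapes; only the batch dimension is dynamic.
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 3, 32, 128),   # smallest batch
        opt_shape=(8, 3, 32, 128),   # batch the engine is tuned for
        max_shape=(32, 3, 32, 128),  # largest batch
        dtype=torch.float32,
    )
]
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)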
Thanks for the issue. Trying to repro the above.
A couple of questions: what values of min_batch_size, max_batch_size, channel, height, and width are you using? What does subsequent_mask look like? Do you have a simple repro?
A couple of other things are missing for the repro: the opt in model = MyModel(opt), and the model.pth used in model.load_state_dict(torch.load("model.pth")).
@apbose
MyModel is composed of a transformer encoder and a decoder. However, it is difficult to share the entire structure of the model. Is it possible to debug with only the given example?
opt is nothing special. It is an argument namespace with values such as:
--img_w, --img_h, --transformer_encoder_layer_num, ...
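For illustration, a minimal sketch of such an opt namespace; the argument names beyond those listed above and all default values are hypothetical:

import argparse

# Hypothetical defaults -- the real model uses its own values.
parser = argparse.ArgumentParser()
parser.add_argument("--img_w", type=int, default=128)
parser.add_argument("--img_h", type=int, default=32)
parser.add_argument("--transformer_encoder_layer_num", type=int, default=6)
opt = parser.parse_args()

# model = MyModel(opt)  # as in the original script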