feat: TensorRT AOT Plugin #3504
Conversation
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/aot_plugin.py 2025-05-05 05:52:23.878918+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/aot_plugin.py 2025-05-05 05:52:44.176344+00:00
@@ -23,13 +23,11 @@
output = x + 1
tl.store(y_ptr + offsets, output, mask=mask)
@torch.library.custom_op("my::add_one", mutates_args=()) # type: ignore[misc]
-def add_one(
- X: torch.Tensor
-) -> torch.Tensor:
+def add_one(X: torch.Tensor) -> torch.Tensor:
# Ensure the tensors are on the GPU
assert X.is_cuda
# Create output tensor
Y = torch.empty_like(X)
@@ -53,19 +51,22 @@
# torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
# "my::add_one"
# )
+
@trtp.register("my::add_one")
def add_plugin_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
return X.like()
+
@trtp.aot_impl("my::add_one")
def add_plugin_aot_impl(
X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
-) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:
-
+) -> Tuple[
+ Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
+]:
type_str = "fp32" if X.dtype == trt.float32 else "fp16"
block_size = 256
src = triton.compiler.ASTSource(
@@ -101,10 +102,11 @@
compiled_kernel.asm["ptx"],
launch_params,
extra_args,
)
+
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
"my::add_one",
supports_dynamic_shapes=False,
requires_output_allocator=False,
aot=True,
@@ -127,18 +129,15 @@
parser.add_argument(
"--aot", action="store_true", help="Try to use AOT compilation", default=False
)
args = parser.parse_args()
-
-
my_model = MyModel().to("cuda")
m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
# This works!
assert my_model(X=m)[0][0] == 3.0
-
with torch_tensorrt.logging.debug():
trt_inputs = [m]
model_trt = torch_tensorrt.compile(
my_model,
@@ -151,6 +150,6 @@
for i in range(10):
res = model_trt(m)
assert torch.allclose(res, my_model(m)), "Results do not match!"
print("Inference successful!")
- print(res)
\ No newline at end of file
+ print(res)
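For readers skimming the lint diff: the first hunk starts mid-kernel. The full Triton kernel in the example is a standard elementwise kernel along the lines of the reconstruction below; the visible tail (`output = x + 1`, `tl.store(y_ptr + offsets, output, mask=mask)`) is from the diff, while the signature and the parameter names other than `y_ptr`, `offsets`, and `mask` are assumptions.

```python
import triton
import triton.language as tl


@triton.jit
def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous block of elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask out-of-range lanes so the last block is safe.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    output = x + 1
    tl.store(y_ptr + offsets, output, mask=mask)
```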
py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py
@@ -31,7 +31,7 @@ def _generate_plugin_converter(
     priority: ConverterPriority = ConverterPriority.STANDARD,
     supports_dynamic_shapes: bool = False,
     requires_output_allocator: bool = False,
-    aot: bool = False,
+    use_aot_if_available: bool = False,
Default to true
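Applying the suggested rename (and, per the comment above, a `True` default) to the example's registration call would give something like the sketch below; the fall-back-to-JIT behavior implied by the new name is an assumption, not confirmed in this thread.

```python
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "my::add_one",
    supports_dynamic_shapes=False,
    requires_output_allocator=False,
    # Renamed from `aot`; assumed to fall back to the JIT plugin path
    # when no AOT implementation is registered for the op.
    use_aot_if_available=True,
)
```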
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Description
This PR demonstrates how to use an AOT (ahead-of-time compiled) TensorRT plugin in Torch-TensorRT.
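Concretely, the example wires up the AOT plugin in four steps. The condensed sketch below mirrors examples/dynamo/aot_plugin.py from this PR, with the kernel-launch and PTX-compilation details elided (see the diff above for the full bodies).

```python
from typing import Tuple

import tensorrt.plugin as trtp
import torch
import torch_tensorrt


# 1. A PyTorch custom op backed by the Triton kernel (launch code elided).
@torch.library.custom_op("my::add_one", mutates_args=())
def add_one(X: torch.Tensor) -> torch.Tensor: ...


# 2. Shape/dtype propagation for the TensorRT plugin: output matches input.
@trtp.register("my::add_one")
def add_plugin_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
    return X.like()


# 3. AOT implementation: compiles the Triton kernel to PTX and returns it
#    together with kernel launch parameters (body elided).
@trtp.aot_impl("my::add_one")
def add_plugin_aot_impl(
    X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
): ...


# 4. Auto-generate a Torch-TensorRT converter that uses the AOT plugin.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "my::add_one",
    supports_dynamic_shapes=False,
    requires_output_allocator=False,
    aot=True,  # renamed to use_aot_if_available during review
)
```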
Fixes # (issue)