From c3cd651bc2b644708b7193a4f5a79cdcaa88abe2 Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Mon, 5 May 2025 05:50:38 +0000
Subject: [PATCH] feat: enable AOT tensorrt plugin example

---
 examples/dynamo/aot_plugin.py                 | 156 ++++++++++++++++++
 .../plugins/_generate_plugin_converter.py     |   5 +-
 2 files changed, 160 insertions(+), 1 deletion(-)
 create mode 100644 examples/dynamo/aot_plugin.py

diff --git a/examples/dynamo/aot_plugin.py b/examples/dynamo/aot_plugin.py
new file mode 100644
index 0000000000..5c3bc2def4
--- /dev/null
+++ b/examples/dynamo/aot_plugin.py
@@ -0,0 +1,156 @@
+import argparse
+from typing import Tuple, Union
+
+import tensorrt as trt
+import tensorrt.plugin as trtp
+import torch
+import torch_tensorrt
+import triton
+import triton.language as tl
+
+trt_logger = trt.Logger(trt.Logger.VERBOSE)
+
+
+@triton.jit
+def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    x = tl.load(x_ptr + offsets, mask=mask)
+    output = x + 1
+    tl.store(y_ptr + offsets, output, mask=mask)
+
+
+@torch.library.custom_op("my::add_one", mutates_args=())  # type: ignore[misc]
+def add_one(X: torch.Tensor) -> torch.Tensor:
+    # Ensure the tensors are on the GPU
+    assert X.is_cuda
+
+    # Create output tensor
+    Y = torch.empty_like(X)
+
+    # Define block size
+    BLOCK_SIZE = 256
+
+    # Grid of programs
+    grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),)
+
+    # Launch the kernel
+    add_one_kernel[grid](X, X.numel(), Y, BLOCK_SIZE=BLOCK_SIZE)
+
+    return Y
+
+
+@torch.library.register_fake("my::add_one")
+def _(X: torch.Tensor) -> torch.Tensor:
+    return X
+
+
+# torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
+#     "my::add_one"
+# )
+
+
+@trtp.register("my::add_one")
+def add_plugin_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
+    return X.like()
+
+
+@trtp.aot_impl("my::add_one")
+def add_plugin_aot_impl(
+    X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
+) -> Tuple[
+    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
+]:
+    type_str = "fp32" if X.dtype == trt.float32 else "fp16"
+
+    block_size = 256
+    src = triton.compiler.ASTSource(
+        fn=add_one_kernel,
+        signature={
+            "x_ptr": f"*{type_str}",
+            "n_elements": "i32",
+            "y_ptr": f"*{type_str}",
+            "BLOCK_SIZE": "constexpr",
+        },
+        constants={
+            "BLOCK_SIZE": block_size,
+        },
+    )
+
+    compiled_kernel = triton.compile(src)
+
+    N = X.shape_expr.numel()
+    launch_params = trtp.KernelLaunchParams()
+
+    # grid dims
+    launch_params.grid_x = trtp.cdiv(N, block_size)
+    # block dims
+    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
+    # shared memory
+    launch_params.shared_mem = compiled_kernel.metadata.shared
+
+    extra_args = trtp.SymIntExprs(1)
+    extra_args[0] = trtp.SymInt32(N)
+
+    return (
+        compiled_kernel.metadata.name,
+        compiled_kernel.asm["ptx"],
+        launch_params,
+        extra_args,
+    )
+
+
+torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
+    "my::add_one",
+    supports_dynamic_shapes=False,
+    requires_output_allocator=False,
+    aot=True,
+)
+
+
+class MyModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
+        res = torch.ops.my.add_one.default(X)
+
+        return res
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--aot", action="store_true", help="Try to use AOT compilation", default=False
+    )
+    args = parser.parse_args()
+
+    my_model = MyModel().to("cuda")
+    m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
+
+    # This works!
+    assert my_model(X=m)[0][0] == 3.0
+
+    with torch_tensorrt.logging.debug():
+        trt_inputs = [m]
+        model_trt = torch_tensorrt.compile(
+            my_model,
+            inputs=trt_inputs,
+            debug=True,
+            min_block_size=1,
+        )
+        print("Model compiled successfully!")
+        print("Running inference with compiled model...")
+        for i in range(10):
+            res = model_trt(m)
+            assert torch.allclose(res, my_model(m)), "Results do not match!"
+
+    print("Inference successful!")
+    print(res)
\ No newline at end of file
diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py
index 99ea3bc356..fffc988336 100644
--- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py
+++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py
@@ -31,6 +31,7 @@ def _generate_plugin_converter(
     priority: ConverterPriority = ConverterPriority.STANDARD,
     supports_dynamic_shapes: bool = False,
     requires_output_allocator: bool = False,
+    aot: bool = False,
 ) -> DynamoConverterImplSignature:
     torch_target = getattr(getattr(torch.ops, namespace), op_name)
     overload_str = overload if overload else ""
@@ -80,7 +81,7 @@ def custom_kernel_converter(
             if isinstance(v, torch.fx.immutable_collections.immutable_list):
                 kwargs[k] = np.array(v)

-        layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs))
+        layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs), aot=aot)
         assert layer, f"{namespace}::{name} plugin layer was not able to be created"
         _LOGGER.debug(
             f"Adding generated plugin for {namespace}::{name} to tensorrt network"
@@ -107,6 +108,7 @@ def generate_plugin_converter(
     priority: ConverterPriority = ConverterPriority.STANDARD,
     supports_dynamic_shapes: bool = False,
     requires_output_allocator: bool = False,
+    aot: bool = False,
 ) -> DynamoConverterImplSignature:
     plugin_ns, plugin_name = plugin_id.split("::")
     return _generate_plugin_converter(
@@ -116,4 +118,5 @@ def generate_plugin_converter(
         priority=priority,
         supports_dynamic_shapes=supports_dynamic_shapes,
         requires_output_allocator=requires_output_allocator,
+        aot=aot,
     )
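
Note on the new `aot` flag (not part of the patch): the example registers its converter with `aot=True` unconditionally at module level, even though it also defines an `--aot` command-line switch, so the switch currently has no effect. A minimal sketch, assuming the plugin and fake-tensor registrations above have already run, of letting the switch drive the flag instead would be to move the module-level `generate_plugin_converter(...)` call below `parser.parse_args()`:

    # Sketch only: register the converter after argument parsing so the --aot
    # switch decides whether the converter requests the AOT (@trtp.aot_impl)
    # path; the flag is forwarded to ctx.net.add_plugin(..., aot=aot).
    args = parser.parse_args()
    torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
        "my::add_one",
        supports_dynamic_shapes=False,
        requires_output_allocator=False,
        aot=args.aot,
    )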