feat: TensorRT AOT Plugin #3504

Open · wants to merge 1 commit into base: main
156 changes: 156 additions & 0 deletions examples/dynamo/aot_plugin.py
@@ -0,0 +1,156 @@
import argparse
from typing import Tuple, Union


import tensorrt as trt
import tensorrt.plugin as trtp
import torch
import torch_tensorrt
import triton
import triton.language as tl


trt_logger = trt.Logger(trt.Logger.VERBOSE)


@triton.jit
def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
output = x + 1
tl.store(y_ptr + offsets, output, mask=mask)


@torch.library.custom_op("my::add_one", mutates_args=()) # type: ignore[misc]
def add_one(
X: torch.Tensor
) -> torch.Tensor:
# Ensure the tensors are on the GPU
assert X.is_cuda

# Create output tensor
Y = torch.empty_like(X)

# Define block size
BLOCK_SIZE = 256

# Grid of programs
grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),)

# Launch the kernel
add_one_kernel[grid](X, X.numel(), Y, BLOCK_SIZE=BLOCK_SIZE)

return Y


@torch.library.register_fake("my::add_one")
def _(X: torch.Tensor) -> torch.Tensor:
return X


# torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
# "my::add_one"
# )

@trtp.register("my::add_one")
Collaborator:
Can we not use torch_tensorrt.dynamo.conversion.custom_op here?

def add_plugin_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
return X.like()
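
The review comment above asks whether the plugins helpers could generate this registration instead of the hand-written trtp calls. For reference, a minimal sketch of that JIT (non-AOT) path using the two helpers already referenced in this file; whether a single custom_op wrapper can also cover the AOT implementation below is exactly the open question, so treat the call pattern as an assumption rather than the PR's approach:

import torch_tensorrt

# Let Torch-TensorRT derive the TensorRT plugin and its converter from the
# registered custom op, instead of hand-writing trtp.register / trtp.aot_impl.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin("my::add_one")
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "my::add_one",
    supports_dynamic_shapes=False,
)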

@trtp.aot_impl("my::add_one")
def add_plugin_aot_impl(
X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:
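    # AOT implementation: compile the Triton kernel to PTX at build time and
    # hand TensorRT the kernel name, the PTX, the launch parameters, and the
    # extra scalar arguments, so no JIT compilation is needed at runtime.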


type_str = "fp32" if X.dtype == trt.float32 else "fp16"

block_size = 256
src = triton.compiler.ASTSource(
fn=add_one_kernel,
signature={
"x_ptr": f"*{type_str}",
"n_elements": "i32",
"y_ptr": f"*{type_str}",
"BLOCK_SIZE": "constexpr",
},
constants={
"BLOCK_SIZE": block_size,
},
)

compiled_kernel = triton.compile(src)

N = X.shape_expr.numel()
launch_params = trtp.KernelLaunchParams()

# grid dims
launch_params.grid_x = trtp.cdiv(N, block_size)
# block dims
launch_params.block_x = compiled_kernel.metadata.num_warps * 32
# shared memory
launch_params.shared_mem = compiled_kernel.metadata.shared

extra_args = trtp.SymIntExprs(1)
extra_args[0] = trtp.SymInt32(N)

return (
compiled_kernel.metadata.name,
compiled_kernel.asm["ptx"],
launch_params,
extra_args,
)

torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
"my::add_one",
supports_dynamic_shapes=False,
requires_output_allocator=False,
aot=True,
Collaborator:
So I think we need two things: (1) a flag, something like use_aot_if_available, and (2) a function in generate_plugin_converter that checks for an aot_impl registration.

)
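
The comment above asks for two pieces: a use_aot_if_available flag, and a check inside generate_plugin_converter for an existing aot_impl registration. A minimal sketch of how that decision could be wired; the names resolve_aot_flag and has_aot_impl are hypothetical and do not exist in this PR or in tensorrt.plugin:

from typing import Callable


def resolve_aot_flag(
    plugin_id: str,
    use_aot_if_available: bool,
    has_aot_impl: Callable[[str], bool],
) -> bool:
    # has_aot_impl stands in for the registration check the reviewer asks for;
    # no public lookup API on tensorrt.plugin is confirmed here, so the probe
    # is injected rather than implemented.
    return use_aot_if_available and has_aot_impl(plugin_id)


# Example wiring, assuming such a probe exists:
# aot = resolve_aot_flag("my::add_one", use_aot_if_available=True, has_aot_impl=my_probe)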


class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, X: torch.Tensor) -> torch.Tensor:
res = torch.ops.my.add_one.default(X)

return res


if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument(
"--aot", action="store_true", help="Try to use AOT compilation", default=False
)
args = parser.parse_args()



my_model = MyModel().to("cuda")
m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)

# This works!
assert my_model(X=m)[0][0] == 3.0


with torch_tensorrt.logging.debug():
trt_inputs = [m]
model_trt = torch_tensorrt.compile(
my_model,
inputs=trt_inputs,
debug=True,
min_block_size=1,
)
print("Model compiled successfully!")
print("Running inference with compiled model...")
for i in range(10):
res = model_trt(m)
assert torch.allclose(res, my_model(m)), "Results do not match!"

print("Inference successful!")
print(res)
@@ -31,6 +31,7 @@ def _generate_plugin_converter(
priority: ConverterPriority = ConverterPriority.STANDARD,
supports_dynamic_shapes: bool = False,
requires_output_allocator: bool = False,
aot: bool = False,
) -> DynamoConverterImplSignature:
torch_target = getattr(getattr(torch.ops, namespace), op_name)
overload_str = overload if overload else ""
@@ -80,7 +81,7 @@ def custom_kernel_converter(
if isinstance(v, torch.fx.immutable_collections.immutable_list):
kwargs[k] = np.array(v)

layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs))
layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs), aot=aot)
Collaborator:
There should be a utility function that checks for aot_impl registrations.

assert layer, f"{namespace}::{name} plugin layer was not able to be created"
_LOGGER.debug(
f"Adding generated plugin for {namespace}::{name} to tensorrt network"
@@ -107,6 +108,7 @@ def generate_plugin_converter(
priority: ConverterPriority = ConverterPriority.STANDARD,
supports_dynamic_shapes: bool = False,
requires_output_allocator: bool = False,
aot: bool = False,
) -> DynamoConverterImplSignature:
plugin_ns, plugin_name = plugin_id.split("::")
return _generate_plugin_converter(
@@ -116,4 +118,5 @@ def generate_plugin_converter(
priority=priority,
supports_dynamic_shapes=supports_dynamic_shapes,
requires_output_allocator=requires_output_allocator,
aot=aot,
)