
Commit 0167831

update
1 parent 88c304e commit 0167831

File tree

1 file changed: +44, -44 lines


examples/dynamo/aot_plugin.py

Lines changed: 44 additions & 44 deletions
@@ -52,47 +52,50 @@ def add_plugin_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
     return X.like()
 
 
-# @trtp.aot_impl("my::add_one")
-# def add_plugin_aot_impl(
-#     X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
-# ) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:
-#     type_str = "fp32" if X.dtype == trt.float32 else "fp16"
-
-#     block_size = 256
-#     src = triton.compiler.ASTSource(
-#         fn=add_one_kernel,
-#         signature={
-#             "x_ptr": f"*{type_str}",
-#             "n_elements": "i32",
-#             "y_ptr": f"*{type_str}",
-#             "BLOCK_SIZE": "constexpr",
-#         },
-#         constants={
-#             "BLOCK_SIZE": block_size,
-#         },
-#     )
-
-#     compiled_kernel = triton.compile(src)
-
-#     N = X.shape_expr.numel()
-#     launch_params = trtp.KernelLaunchParams()
-
-#     # grid dims
-#     launch_params.grid_x = trtp.cdiv(N, block_size)
-#     # block dims
-#     launch_params.block_x = compiled_kernel.metadata.num_warps * 32
-#     # shared memory
-#     launch_params.shared_mem = compiled_kernel.metadata.shared
-
-#     extra_args = trtp.SymIntExprs(1)
-#     extra_args[0] = trtp.SymInt32(N)
-
-#     return (
-#         compiled_kernel.metadata.name,
-#         compiled_kernel.asm["ptx"],
-#         launch_params,
-#         extra_args,
-#     )
+@trtp.aot_impl("my::add_one")
+def add_plugin_aot_impl(
+    X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
+) -> Tuple[
+    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
+]:
+    type_str = "fp32" if X.dtype == trt.float32 else "fp16"
+
+    block_size = 256
+    src = triton.compiler.ASTSource(
+        fn=add_one_kernel,
+        signature={
+            "x_ptr": f"*{type_str}",
+            "n_elements": "i32",
+            "y_ptr": f"*{type_str}",
+            "BLOCK_SIZE": "constexpr",
+        },
+        constants={
+            "BLOCK_SIZE": block_size,
+        },
+    )
+
+    compiled_kernel = triton.compile(src)
+
+    N = X.shape_expr.numel()
+    launch_params = trtp.KernelLaunchParams()
+
+    # grid dims
+    launch_params.grid_x = trtp.cdiv(N, block_size)
+    # block dims
+    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
+    # shared memory
+    launch_params.shared_mem = compiled_kernel.metadata.shared
+
+    extra_args = trtp.SymIntExprs(1)
+    extra_args[0] = trtp.SymInt32(N)
+
+    return (
+        compiled_kernel.metadata.name,
+        compiled_kernel.asm["ptx"],
+        launch_params,
+        extra_args,
+    )
+
 
 torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
     "my::add_one",
@@ -113,7 +116,6 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
 
 
 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--aot", action="store_true", help="Try to use AOT compilation", default=False
@@ -123,7 +125,6 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
     my_model = MyModel().to("cuda")
     m = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
 
-    # This works!
     assert my_model(X=m)[0][0] == 3.0
 
     with torch_tensorrt.logging.debug():
@@ -141,4 +142,3 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
     assert torch.allclose(res, my_model(m)), "Results do not match!"
 
     print("Inference successful!")
-    print(res)
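The last two hunks compare res, produced by a Torch-TensorRT-compiled copy of the model, against the eager my_model, but the call that actually produces res sits inside the torch_tensorrt.logging.debug() block and is cut off by this diff. As a rough sketch only (the compile arguments below are assumptions, not the file's actual options, and how the --aot flag is honored is not shown here), the flow is typically:

    with torch_tensorrt.logging.debug():
        # Hypothetical compile call; real settings in aot_plugin.py may differ.
        trt_model = torch_tensorrt.compile(my_model, inputs=[m])

    res = trt_model(m)  # res then feeds the torch.allclose check in the final hunk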
