@@ -52,47 +52,50 @@ def add_plugin_desc(X: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
52
52
return X .like ()
53
53
54
54
55
- # @trtp.aot_impl("my::add_one")
56
- # def add_plugin_aot_impl(
57
- # X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
58
- # ) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:
59
- # type_str = "fp32" if X.dtype == trt.float32 else "fp16"
60
-
61
- # block_size = 256
62
- # src = triton.compiler.ASTSource(
63
- # fn=add_one_kernel,
64
- # signature={
65
- # "x_ptr": f"*{type_str}",
66
- # "n_elements": "i32",
67
- # "y_ptr": f"*{type_str}",
68
- # "BLOCK_SIZE": "constexpr",
69
- # },
70
- # constants={
71
- # "BLOCK_SIZE": block_size,
72
- # },
73
- # )
74
-
75
- # compiled_kernel = triton.compile(src)
76
-
77
- # N = X.shape_expr.numel()
78
- # launch_params = trtp.KernelLaunchParams()
79
-
80
- # # grid dims
81
- # launch_params.grid_x = trtp.cdiv(N, block_size)
82
- # # block dims
83
- # launch_params.block_x = compiled_kernel.metadata.num_warps * 32
84
- # # shared memory
85
- # launch_params.shared_mem = compiled_kernel.metadata.shared
86
-
87
- # extra_args = trtp.SymIntExprs(1)
88
- # extra_args[0] = trtp.SymInt32(N)
89
-
90
- # return (
91
- # compiled_kernel.metadata.name,
92
- # compiled_kernel.asm["ptx"],
93
- # launch_params,
94
- # extra_args,
95
- # )
55
@trtp.aot_impl("my::add_one")
def add_plugin_aot_impl(
    X: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    """AOT implementation of the ``my::add_one`` TensorRT plugin.

    Compiles the Triton ``add_one_kernel`` ahead of time and hands TensorRT
    everything it needs to launch it: the compiled kernel's name, its PTX,
    the launch configuration, and the extra scalar argument(s) to pass at
    runtime.

    Args:
        X: Descriptor of the input tensor (shape is symbolic at build time).
        outputs: Descriptors of the plugin outputs (unused here; the kernel
            writes one output the same shape as ``X``).
        tactic: Tactic index selected by TensorRT (unused; single tactic).

    Returns:
        A 4-tuple of (kernel name, PTX source, launch parameters,
        symbolic extra kernel arguments).
    """
    # Triton signature type string must match X's element type.
    # NOTE: no space inside the braces/quotes — "*fp32 " (trailing space)
    # would not be a valid Triton pointer-type signature.
    type_str = "fp32" if X.dtype == trt.float32 else "fp16"

    block_size = 256
    src = triton.compiler.ASTSource(
        fn=add_one_kernel,
        signature={
            "x_ptr": f"*{type_str}",
            "n_elements": "i32",
            "y_ptr": f"*{type_str}",
            "BLOCK_SIZE": "constexpr",
        },
        constants={
            "BLOCK_SIZE": block_size,
        },
    )

    compiled_kernel = triton.compile(src)

    # Symbolic element count: the launch config is expressed in terms of
    # TensorRT shape expressions so it adapts to the actual runtime shape.
    N = X.shape_expr.numel()
    launch_params = trtp.KernelLaunchParams()

    # grid dims: one 1-D grid of ceil(N / block_size) blocks
    launch_params.grid_x = trtp.cdiv(N, block_size)
    # block dims: Triton reports warps, CUDA wants threads (32 per warp)
    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
    # shared memory requested by the compiled kernel
    launch_params.shared_mem = compiled_kernel.metadata.shared

    # Extra scalar kernel argument: n_elements, passed symbolically.
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(N)

    return (
        compiled_kernel.metadata.name,
        compiled_kernel.asm["ptx"],
        launch_params,
        extra_args,
    )
98
+
96
99
97
100
torch_tensorrt .dynamo .conversion .plugins .generate_plugin_converter (
98
101
"my::add_one" ,
@@ -113,7 +116,6 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
113
116
114
117
115
118
if __name__ == "__main__" :
116
-
117
119
parser = argparse .ArgumentParser ()
118
120
parser .add_argument (
119
121
"--aot" , action = "store_true" , help = "Try to use AOT compilation" , default = False
@@ -123,7 +125,6 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
123
125
my_model = MyModel ().to ("cuda" )
124
126
m = torch .full ((64 , 64 ), 2 , device = "cuda" , dtype = torch .float )
125
127
126
- # This works!
127
128
assert my_model (X = m )[0 ][0 ] == 3.0
128
129
129
130
with torch_tensorrt .logging .debug ():
@@ -141,4 +142,3 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
141
142
assert torch .allclose (res , my_model (m )), "Results do not match!"
142
143
143
144
print ("Inference successful!" )
144
- print (res )
0 commit comments