16
16
17
17
from torchao .prototype .moe_training .kernels .float8_rowwise import (
18
18
triton_fp8_rowwise_3d_transpose_rhs ,
19
+ triton_fp8_rowwise_3d_transpose_rhs_fused_reduction ,
19
20
)
20
21
from torchao .prototype .moe_training .utils import (
21
22
torch_to_3d_rowwise_float8_transpose_rhs ,
@@ -37,9 +38,11 @@ class ExperimentConfig:
37
38
@dataclass (frozen = True )
38
39
class ExperimentResult :
39
40
torch_time_us : float
40
- triton_time_us : float
41
+ triton_atomic_time_us : float
42
+ triton_reduction_time_us : float
41
43
torch_mem_bw_gbps : float
42
- triton_mem_bw_gbps : float
44
+ triton_atomic_mem_bw_gbps : float
45
+ triton_reduction_mem_bw_gbps : float
43
46
44
47
45
48
@dataclass (frozen = True )
@@ -59,7 +62,7 @@ def get_configs() -> List[ExperimentConfig]:
59
62
(128 , 5120 , 8192 ), # w2
60
63
]
61
64
high_precision_dtypes = [torch .bfloat16 ]
62
- power_of_2_scales = [True , False ]
65
+ power_of_2_scales = [True ]
63
66
configs = []
64
67
for input_shape , high_precision_dtype , power_of_2_scale in itertools .product (
65
68
input_shapes , high_precision_dtypes , power_of_2_scales
@@ -94,14 +97,22 @@ def run_torch(input_tensor: torch.Tensor):
94
97
)
95
98
return out
96
99
97
- def run_triton (input_tensor : torch .Tensor ):
100
+ def run_triton_atomic (input_tensor : torch .Tensor ):
98
101
out = triton_fp8_rowwise_3d_transpose_rhs (
99
102
input_tensor ,
100
103
output_dtype = torch .float8_e4m3fn ,
101
104
round_scales_to_power_of_2 = config .power_of_2_scales ,
102
105
)
103
106
return out
104
107
108
+ def run_triton_reduction (input_tensor : torch .Tensor ):
109
+ out = triton_fp8_rowwise_3d_transpose_rhs_fused_reduction (
110
+ input_tensor ,
111
+ output_dtype = torch .float8_e4m3fn ,
112
+ round_scales_to_power_of_2 = config .power_of_2_scales ,
113
+ )
114
+ return out
115
+
105
116
# bench torch
106
117
compiled_run_torch = torch .compile (run_torch )
107
118
warmup (run_torch , input_tensor )
@@ -110,10 +121,19 @@ def run_triton(input_tensor: torch.Tensor):
110
121
input_tensor ,
111
122
)
112
123
113
- # bench triton
114
- warmup (run_triton , input_tensor )
115
- triton_time_us = benchmark_cuda_function_in_microseconds (
116
- run_triton ,
124
+ # bench triton atomic method
125
+ run_triton_atomic_c = torch .compile (run_triton_atomic )
126
+ warmup (run_triton_atomic_c , input_tensor )
127
+ triton_atomic_time_us = benchmark_cuda_function_in_microseconds (
128
+ run_triton_atomic_c ,
129
+ input_tensor ,
130
+ )
131
+
132
+ # bench triton reduction method
133
+ run_triton_reduction_c = torch .compile (run_triton_reduction )
134
+ warmup (run_triton_reduction_c , input_tensor )
135
+ triton_reduction_time_us = benchmark_cuda_function_in_microseconds (
136
+ run_triton_reduction_c ,
117
137
input_tensor ,
118
138
)
119
139
@@ -129,13 +149,20 @@ def run_triton(input_tensor: torch.Tensor):
129
149
# Both torch.compile codegen and the triton kernel read the input tensor twice
130
150
# (once for scale calculations, once for scaling + casting).
131
151
torch_mem_bw_gbps = ((read_bytes * 2 + write_bytes ) / 1e9 ) / (torch_time_us / 1e6 )
132
- triton_mem_bw_gbps = ((read_bytes * 2 + write_bytes ) / 1e9 ) / (triton_time_us / 1e6 )
152
+ triton_atomic_mem_bw_gbps = ((read_bytes * 2 + write_bytes ) / 1e9 ) / (
153
+ triton_atomic_time_us / 1e6
154
+ )
155
+ triton_reduction_mem_bw_gbps = ((read_bytes * 2 + write_bytes ) / 1e9 ) / (
156
+ triton_reduction_time_us / 1e6
157
+ )
133
158
134
159
return ExperimentResult (
135
160
torch_time_us = torch_time_us ,
136
- triton_time_us = triton_time_us ,
161
+ triton_atomic_time_us = triton_atomic_time_us ,
162
+ triton_reduction_time_us = triton_reduction_time_us ,
137
163
torch_mem_bw_gbps = torch_mem_bw_gbps ,
138
- triton_mem_bw_gbps = triton_mem_bw_gbps ,
164
+ triton_atomic_mem_bw_gbps = triton_atomic_mem_bw_gbps ,
165
+ triton_reduction_mem_bw_gbps = triton_reduction_mem_bw_gbps ,
139
166
)
140
167
141
168
@@ -144,10 +171,13 @@ def print_results(experiments: List[Experiment]):
144
171
"input_shape" ,
145
172
"power_of_2_scales" ,
146
173
"torch_time_us" ,
147
- "triton_time_us" ,
174
+ "triton_atomic_time_us" ,
175
+ "triton_reduction_time_us" ,
148
176
"torch_mem_bw_gbps" ,
149
- "triton_mem_bw_gbps" ,
150
- "triton_speedup" ,
177
+ "triton_atomic_mem_bw_gbps" ,
178
+ "triton_reduction_mem_bw_gbps" ,
179
+ "triton_atomic_speedup" ,
180
+ "triton_reduction_speedup" ,
151
181
]
152
182
rows = []
153
183
for experiment in experiments :
@@ -157,10 +187,13 @@ def print_results(experiments: List[Experiment]):
157
187
input_shape ,
158
188
experiment .config .power_of_2_scales ,
159
189
experiment .result .torch_time_us ,
160
- experiment .result .triton_time_us ,
190
+ experiment .result .triton_atomic_time_us ,
191
+ experiment .result .triton_reduction_time_us ,
161
192
round (experiment .result .torch_mem_bw_gbps , 3 ),
162
- round (experiment .result .triton_mem_bw_gbps , 3 ),
163
- f"{ experiment .result .torch_time_us / experiment .result .triton_time_us :.2f} x" ,
193
+ round (experiment .result .triton_atomic_mem_bw_gbps , 3 ),
194
+ round (experiment .result .triton_reduction_mem_bw_gbps , 3 ),
195
+ f"{ experiment .result .torch_time_us / experiment .result .triton_atomic_time_us :.2f} x" ,
196
+ f"{ experiment .result .torch_time_us / experiment .result .triton_reduction_time_us :.2f} x" ,
164
197
]
165
198
)
166
199
print (tabulate (rows , headers = headers ))
0 commit comments