Skip to content

Commit 2cdfdfe

Browse files
[moe fp8 training] fused reduction kernel along dim1 for 3d expert weights in backward
1 parent 3848c56 commit 2cdfdfe

File tree

3 files changed

+293
-21
lines changed

3 files changed

+293
-21
lines changed

benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from torchao.prototype.moe_training.kernels.float8_rowwise import (
1818
triton_fp8_rowwise_3d_transpose_rhs,
19+
triton_fp8_rowwise_3d_transpose_rhs_fused_reduction,
1920
)
2021
from torchao.prototype.moe_training.utils import (
2122
torch_to_3d_rowwise_float8_transpose_rhs,
@@ -37,9 +38,11 @@ class ExperimentConfig:
3738
@dataclass(frozen=True)
3839
class ExperimentResult:
3940
torch_time_us: float
40-
triton_time_us: float
41+
triton_atomic_time_us: float
42+
triton_reduction_time_us: float
4143
torch_mem_bw_gbps: float
42-
triton_mem_bw_gbps: float
44+
triton_atomic_mem_bw_gbps: float
45+
triton_reduction_mem_bw_gbps: float
4346

4447

4548
@dataclass(frozen=True)
@@ -59,7 +62,7 @@ def get_configs() -> List[ExperimentConfig]:
5962
(128, 5120, 8192), # w2
6063
]
6164
high_precision_dtypes = [torch.bfloat16]
62-
power_of_2_scales = [True, False]
65+
power_of_2_scales = [True]
6366
configs = []
6467
for input_shape, high_precision_dtype, power_of_2_scale in itertools.product(
6568
input_shapes, high_precision_dtypes, power_of_2_scales
@@ -94,14 +97,22 @@ def run_torch(input_tensor: torch.Tensor):
9497
)
9598
return out
9699

97-
def run_triton(input_tensor: torch.Tensor):
100+
def run_triton_atomic(input_tensor: torch.Tensor):
98101
out = triton_fp8_rowwise_3d_transpose_rhs(
99102
input_tensor,
100103
output_dtype=torch.float8_e4m3fn,
101104
round_scales_to_power_of_2=config.power_of_2_scales,
102105
)
103106
return out
104107

108+
def run_triton_reduction(input_tensor: torch.Tensor):
109+
out = triton_fp8_rowwise_3d_transpose_rhs_fused_reduction(
110+
input_tensor,
111+
output_dtype=torch.float8_e4m3fn,
112+
round_scales_to_power_of_2=config.power_of_2_scales,
113+
)
114+
return out
115+
105116
# bench torch
106117
compiled_run_torch = torch.compile(run_torch)
107118
warmup(run_torch, input_tensor)
@@ -110,10 +121,19 @@ def run_triton(input_tensor: torch.Tensor):
110121
input_tensor,
111122
)
112123

113-
# bench triton
114-
warmup(run_triton, input_tensor)
115-
triton_time_us = benchmark_cuda_function_in_microseconds(
116-
run_triton,
124+
# bench triton atomic method
125+
run_triton_atomic_c = torch.compile(run_triton_atomic)
126+
warmup(run_triton_atomic_c, input_tensor)
127+
triton_atomic_time_us = benchmark_cuda_function_in_microseconds(
128+
run_triton_atomic_c,
129+
input_tensor,
130+
)
131+
132+
# bench triton reduction method
133+
run_triton_reduction_c = torch.compile(run_triton_reduction)
134+
warmup(run_triton_reduction_c, input_tensor)
135+
triton_reduction_time_us = benchmark_cuda_function_in_microseconds(
136+
run_triton_reduction_c,
117137
input_tensor,
118138
)
119139

@@ -129,13 +149,20 @@ def run_triton(input_tensor: torch.Tensor):
129149
# Both torch.compile codegen and the triton kernel read the input tensor twice
130150
# (once for scale calculations, once for scaling + casting).
131151
torch_mem_bw_gbps = ((read_bytes * 2 + write_bytes) / 1e9) / (torch_time_us / 1e6)
132-
triton_mem_bw_gbps = ((read_bytes * 2 + write_bytes) / 1e9) / (triton_time_us / 1e6)
152+
triton_atomic_mem_bw_gbps = ((read_bytes * 2 + write_bytes) / 1e9) / (
153+
triton_atomic_time_us / 1e6
154+
)
155+
triton_reduction_mem_bw_gbps = ((read_bytes * 2 + write_bytes) / 1e9) / (
156+
triton_reduction_time_us / 1e6
157+
)
133158

134159
return ExperimentResult(
135160
torch_time_us=torch_time_us,
136-
triton_time_us=triton_time_us,
161+
triton_atomic_time_us=triton_atomic_time_us,
162+
triton_reduction_time_us=triton_reduction_time_us,
137163
torch_mem_bw_gbps=torch_mem_bw_gbps,
138-
triton_mem_bw_gbps=triton_mem_bw_gbps,
164+
triton_atomic_mem_bw_gbps=triton_atomic_mem_bw_gbps,
165+
triton_reduction_mem_bw_gbps=triton_reduction_mem_bw_gbps,
139166
)
140167

141168

@@ -144,10 +171,13 @@ def print_results(experiments: List[Experiment]):
144171
"input_shape",
145172
"power_of_2_scales",
146173
"torch_time_us",
147-
"triton_time_us",
174+
"triton_atomic_time_us",
175+
"triton_reduction_time_us",
148176
"torch_mem_bw_gbps",
149-
"triton_mem_bw_gbps",
150-
"triton_speedup",
177+
"triton_atomic_mem_bw_gbps",
178+
"triton_reduction_mem_bw_gbps",
179+
"triton_atomic_speedup",
180+
"triton_reduction_speedup",
151181
]
152182
rows = []
153183
for experiment in experiments:
@@ -157,10 +187,13 @@ def print_results(experiments: List[Experiment]):
157187
input_shape,
158188
experiment.config.power_of_2_scales,
159189
experiment.result.torch_time_us,
160-
experiment.result.triton_time_us,
190+
experiment.result.triton_atomic_time_us,
191+
experiment.result.triton_reduction_time_us,
161192
round(experiment.result.torch_mem_bw_gbps, 3),
162-
round(experiment.result.triton_mem_bw_gbps, 3),
163-
f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
193+
round(experiment.result.triton_atomic_mem_bw_gbps, 3),
194+
round(experiment.result.triton_reduction_mem_bw_gbps, 3),
195+
f"{experiment.result.torch_time_us / experiment.result.triton_atomic_time_us:.2f}x",
196+
f"{experiment.result.torch_time_us / experiment.result.triton_reduction_time_us:.2f}x",
164197
]
165198
)
166199
print(tabulate(rows, headers=headers))

test/prototype/moe_training/test_kernels.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from torchao.prototype.moe_training.kernels.float8_rowwise import (
1717
triton_fp8_rowwise_3d_transpose_rhs,
18+
triton_fp8_rowwise_3d_transpose_rhs_fused_reduction,
1819
)
1920
from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
2021
triton_fp8_per_group_colwise_scales,
@@ -128,7 +129,7 @@ def test_column_major_with_jagged_colwise_scales(round_scales_to_power_of_2: boo
128129

129130
@skip_if_rocm("ROCm not supported")
130131
@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
131-
def test_fp8_rowwise_3d_transpose_rhs(round_scales_to_power_of_2: bool):
132+
def test_fp8_rowwise_3d_transpose_rhs_atomic(round_scales_to_power_of_2: bool):
132133
device = "cuda"
133134
experts, n, k = 8, 4 * 5120, 5120
134135

@@ -159,3 +160,38 @@ def test_fp8_rowwise_3d_transpose_rhs(round_scales_to_power_of_2: bool):
159160
assert ref_fp8.shape == triton_fp8.shape, "output shapes not equal"
160161
assert ref_fp8.stride() == triton_fp8.stride(), "output strides not equal"
161162
assert torch.allclose(ref_fp8, triton_fp8, rtol=0, atol=0), "fp8 data not equal"
163+
164+
165+
@skip_if_rocm("ROCm not supported")
166+
@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
167+
def test_fp8_rowwise_3d_transpose_rhs_reduction(round_scales_to_power_of_2: bool):
168+
device = "cuda"
169+
experts, n, k = 8, 4 * 5120, 5120
170+
171+
# Example expert weights as it comes into forward transposed
172+
torch.manual_seed(0)
173+
x = torch.randn((experts, n, k), dtype=torch.bfloat16, device=device).transpose(
174+
-2, -1
175+
)
176+
177+
# Compute reference with torch impl
178+
ref_fp8, ref_scales = torch_to_3d_rowwise_float8_transpose_rhs(
179+
x,
180+
target_dtype=torch.float8_e4m3fn,
181+
round_scales_to_power_of_2=round_scales_to_power_of_2,
182+
)
183+
# Torch impl keeps empty scaled dim, so we squeeze it out to be consistent with triton impl
184+
ref_scales = ref_scales.squeeze(1)
185+
186+
triton_fp8, triton_scales = triton_fp8_rowwise_3d_transpose_rhs_fused_reduction(
187+
x,
188+
output_dtype=torch.float8_e4m3fn,
189+
round_scales_to_power_of_2=round_scales_to_power_of_2,
190+
)
191+
assert ref_scales.shape == triton_scales.shape, "scale shapes not equal"
192+
assert ref_scales.stride() == triton_scales.stride(), "scale strides not equal"
193+
assert torch.allclose(ref_scales, triton_scales, rtol=0, atol=0), "scales not equal"
194+
195+
assert ref_fp8.shape == triton_fp8.shape, "output shapes not equal"
196+
assert ref_fp8.stride() == triton_fp8.stride(), "output strides not equal"
197+
assert torch.allclose(ref_fp8, triton_fp8, rtol=0, atol=0), "fp8 data not equal"

0 commit comments

Comments
 (0)