
Commit afd9cb6

[moe fp8 training] test and bench new faster method for per group rowwise scaling
1 parent: 253d65a

3 files changed (+340 −68 lines)


benchmarks/prototype/moe_training/benchmark_per_group_scaling_kernels.py renamed to benchmarks/prototype/moe_training/benchmark_per_group_colwise_scaling_kernels.py

Lines changed: 59 additions & 64 deletions
@@ -16,12 +16,10 @@
 
 from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_per_group_colwise_scales,
-    triton_fp8_per_group_rowwise_scales,
 )
 from torchao.prototype.moe_training.utils import (
     generate_jagged_offs,
     torch_to_float8_per_group_colwise,
-    torch_to_float8_per_group_rowwise,
 )
 
 device = torch.device("cuda")
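
For reference, both the torch baseline and the Triton kernel retained above convert a high-precision tensor to float8 with one scale per column per group. A minimal sketch of the idea, as a hypothetical helper rather than the torchao implementation (the power-of-2 scale rounding used in the benchmark is omitted for brevity):

import torch

def per_group_colwise_fp8_sketch(x, offs, target_dtype=torch.float8_e4m3fn):
    # x: (Mg, K) high-precision tensor; offs: cumulative group-end offsets
    # along the Mg (row) dimension. One fp32 scale per (group, column).
    fp8_max = torch.finfo(target_dtype).max
    out = torch.empty_like(x, dtype=target_dtype)
    scales = []
    start = 0
    for end in offs.tolist():
        group = x[start:end]                     # (group_rows, K)
        amax = group.abs().amax(dim=0)           # per-column amax, shape (K,)
        scale = fp8_max / amax.clamp(min=1e-12)  # avoid division by zero
        out[start:end] = (group * scale).clamp(-fp8_max, fp8_max).to(target_dtype)
        scales.append(scale.float())
        start = end
    return out, torch.stack(scales)              # scales: (n_groups, K) fp32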
@@ -39,7 +37,7 @@ class ExperimentConfig:
 
 @dataclass(frozen=True)
 class ExperimentResult:
-    torch_time_us: float
+    torch_loop_time_us: float
     triton_time_us: float
     torch_mem_bw_gbps: float
     triton_mem_bw_gbps: float
@@ -53,7 +51,7 @@ class Experiment:
 
 def get_configs() -> List[ExperimentConfig]:
     input_shapes = [(16640, 5120)]  # (Mg, K)
-    n_groups_list = [1, 16, 128]
+    n_groups_list = [1, 16, 64]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
     for input_shape, n_groups, high_precision_dtype in itertools.product(
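
The group sizes swept here (n_groups in {1, 16, 64}) are jagged rather than uniform: generate_jagged_offs, used in run_experiment below, returns the cumulative end offset of each group along Mg. A rough, hypothetical stand-in to show the shape of its output (the real torchao util may differ; this simplified version can even produce empty groups):

import torch

def jagged_offs_sketch(n_groups, Mg, multiple_of=16):
    # Pick n_groups - 1 random cut points in (0, Mg), snap each to a
    # multiple of 16 (a common alignment requirement for fp8 grouped GEMMs),
    # and append Mg so the final group ends at the last row.
    cuts = torch.randint(1, Mg // multiple_of, (n_groups - 1,)).sort().values
    return torch.cat([cuts * multiple_of, torch.tensor([Mg])])

# e.g. jagged_offs_sketch(4, 64) might return tensor([16, 32, 48, 64])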
@@ -70,85 +68,82 @@
 
 
 def run_experiment(config: ExperimentConfig) -> ExperimentResult:
-    # define test inputs
-    input_tensor = torch.randn(
-        *config.input_shape,
-        dtype=config.high_precision_dtype,
-        device=device,
+    # Define test inputs
+    Mg, K = config.input_shape
+
+    # Column major input tensor.
+    # Right operand in grad_weight = grad_output_t @ input
+    input_tensor = (
+        torch.randn(
+            Mg,
+            K,
+            dtype=config.high_precision_dtype,
+            device=device,
+        )
+        .transpose(-2, -1)
+        .contiguous()
+        .transpose(-2, -1)
     )
-    input_row_major = input_tensor.clone().detach()
-    input_col_major = input_tensor.clone().detach().t()
 
     # - configure input to be row-major with groups divided along the column dimension,
     # representing the left operand of grad_weight = grad_output_t @ input
     # that occurs in the backward pass of the differentiable scaled grouped mm.
     # - the transposed tensor in col-major format with groups along the row dimension,
     # which represents the right operand.
     n_groups = config.n_groups
-    Mg = input_row_major.shape[0]
     offs = generate_jagged_offs(n_groups, Mg, multiple_of=16)
 
     def warmup(func, *args, **kwargs):
         for _ in range(10):
             func(*args, **kwargs)
 
-    def run_torch(
-        input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor
-    ):
-        _ = torch_to_float8_per_group_rowwise(
-            input_row_major,
-            offs,
-            target_dtype=torch.float8_e4m3fn,
-            round_scales_to_power_of_2=True,
-        )
-        _ = torch_to_float8_per_group_colwise(
-            input_col_major,
-            offs,
-            target_dtype=torch.float8_e4m3fn,
-            round_scales_to_power_of_2=True,
-        )
-
-    def run_triton(
-        input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor
-    ):
-        _ = triton_fp8_per_group_rowwise_scales(
-            input_row_major,
-            offs,
-            output_dtype=torch.float8_e4m3fn,
-            round_scales_to_power_of_2=True,
-        )
-        _ = triton_fp8_per_group_colwise_scales(
-            input_col_major,
-            offs,
-            output_dtype=torch.float8_e4m3fn,
-            round_scales_to_power_of_2=True,
-        )
-
-    # bench torch
-    compiled_run_torch = torch.compile(run_torch)
-    warmup(compiled_run_torch, input_row_major, input_col_major, offs)
-    torch_time_us = benchmark_cuda_function_in_microseconds(
-        compiled_run_torch, input_row_major, input_col_major, offs
+    # Bench torch per group colwise
+    torch_to_float8_per_group_colwise_c = torch.compile(
+        torch_to_float8_per_group_colwise
+    )
+    warmup(
+        torch_to_float8_per_group_colwise_c,
+        input_tensor,
+        offs,
+        target_dtype=torch.float8_e4m3fn,
+    )
+    torch_loop_time_us = benchmark_cuda_function_in_microseconds(
+        torch_to_float8_per_group_colwise_c,
+        input_tensor,
+        offs,
+        target_dtype=torch.float8_e4m3fn,
     )
 
-    # bench triton
-    warmup(run_triton, input_row_major, input_col_major, offs)
+    # Bench triton per group colwise
+    warmup(
+        triton_fp8_per_group_colwise_scales,
+        input_tensor,
+        offs,
+        output_dtype=torch.float8_e4m3fn,
+        round_scales_to_power_of_2=True,
+    )
     triton_time_us = benchmark_cuda_function_in_microseconds(
-        run_triton, input_row_major, input_col_major, offs
+        triton_fp8_per_group_colwise_scales,
+        input_tensor,
+        offs,
+        output_dtype=torch.float8_e4m3fn,
+        round_scales_to_power_of_2=True,
     )
 
-    # mem bw calculations - excluding scales to simplify calculation
-    # but still get an accurate estimate.
+    # Mem bw calculations
     bytes_per_input_el = torch.finfo(config.high_precision_dtype).bits / 8
     num_elements = input_tensor.numel()
-    read_bytes = num_elements * bytes_per_input_el
-    write_bytes = num_elements  # 1 byte per element in float8_e4m3fn
+    # 2x read_bytes because we are reading the input tensor twice (once to compute scales, once to apply them)
+    read_bytes = 2 * num_elements * bytes_per_input_el
+    write_bytes = num_elements + 4 * (
+        n_groups * K
+    )  # 1 byte per element in float8_e4m3fn + 4 bytes per fp32 scale
     read_write_bytes = read_bytes + write_bytes
-    torch_mem_bw_gbps = (read_write_bytes) / (torch_time_us / 1e6) / 1e9
+    torch_mem_bw_gbps = (read_write_bytes) / (torch_loop_time_us / 1e6) / 1e9
     triton_mem_bw_gbps = (read_write_bytes) / (triton_time_us / 1e6) / 1e9
 
     return ExperimentResult(
-        torch_time_us=torch_time_us,
+        torch_loop_time_us=torch_loop_time_us,
         triton_time_us=triton_time_us,
         torch_mem_bw_gbps=torch_mem_bw_gbps,
         triton_mem_bw_gbps=triton_mem_bw_gbps,
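
To make the units in the bandwidth model concrete, here is the byte count for the default benchmark shape, evaluated by hand from the formulas in run_experiment above (the numbers in comments are plain arithmetic, not measured results):

# Worked example of the read/write byte model for (Mg, K) = (16640, 5120), bf16, 16 groups.
Mg, K, n_groups = 16640, 5120, 16
bytes_per_el = 2                               # bf16 = 2 bytes/element
num_elements = Mg * K                          # 85_196_800
read_bytes = 2 * num_elements * bytes_per_el   # 340_787_200 (two passes over the input)
write_bytes = num_elements + 4 * n_groups * K  # 85_524_480 (fp8 output + fp32 scales)
read_write_bytes = read_bytes + write_bytes    # 426_311_680 bytes, i.e. ~0.43 GB moved
# mem_bw_gbps = read_write_bytes / (time_us / 1e6) / 1e9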
@@ -157,10 +152,10 @@ def run_triton(
 
 def print_results(experiments: List[Experiment]):
     headers = [
-        "input_shape",
+        "Mg,K",
         "n_groups",
         "high_precision_dtype",
-        "torch_time_us",
+        "torch_loop_time_us",
         "triton_time_us",
         "torch_mem_bw_gbps",
         "triton_mem_bw_gbps",
@@ -176,18 +171,18 @@ def print_results(experiments: List[Experiment]):
                 input_shape,
                 experiment.config.n_groups,
                 experiment.config.high_precision_dtype,
-                experiment.result.torch_time_us,
+                experiment.result.torch_loop_time_us,
                 experiment.result.triton_time_us,
                 round(experiment.result.torch_mem_bw_gbps, 3),
                 round(experiment.result.triton_mem_bw_gbps, 3),
-                f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
+                f"{experiment.result.torch_loop_time_us / experiment.result.triton_time_us:.2f}x",
             ]
         )
     print(tabulate(rows, headers=headers))
 
 
-def benchmark_cuda_function_in_microseconds(f, *args):
-    return do_bench(lambda: f(*args), return_mode="median") * 1e3
+def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
+    return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3
 
 
 def main():
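
One detail worth calling out from the diff: the .transpose(-2, -1).contiguous().transpose(-2, -1) chain in run_experiment builds a tensor that reports shape (Mg, K) but is laid out column-major in memory, matching the right operand of grad_weight = grad_output_t @ input. A quick illustration of what the chain does to strides:

import torch

# Transpose to (K, Mg), make that layout contiguous, then transpose the view
# back: the tensor reports shape (Mg, K) while its data stays column-major.
Mg, K = 8, 4
x = torch.randn(Mg, K).transpose(-2, -1).contiguous().transpose(-2, -1)
print(x.shape)            # torch.Size([8, 4])
print(x.stride())         # (1, 8) -> column-major: rows are adjacent in memory
print(x.is_contiguous())  # False (not row-major contiguous)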
