Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
torch_to_blocked_per_group_3d,
triton_mx_block_rearrange_per_group_3d,
)

# Every tensor created by this benchmark is allocated on the CUDA device.
device = torch.device("cuda")

# Needed since changing args to function causes recompiles; the dynamo cache
# limit is raised so those recompiles do not exhaust the cache mid-benchmark.
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    """Immutable description of a single benchmark case.

    Attributes:
        input_shape: Three-element shape of the uint8 scale tensor that is
            generated for the run and passed to both implementations.
    """

    # The shape always has exactly three entries, so the annotation spells
    # out all three (``tuple[int]`` would describe a 1-tuple).
    input_shape: tuple[int, int, int]


@dataclass(frozen=True)
class ExperimentResult:
    """Measured runtimes and derived memory bandwidth for one benchmark case."""

    # Runtime of the torch.compile'd reference implementation, in microseconds.
    torch_time_us: float
    # Runtime of the Triton kernel implementation, in microseconds.
    triton_time_us: float
    # (bytes read + bytes written) / runtime for the reference, in GB/s.
    torch_mem_bw_gbps: float
    # (bytes read + bytes written) / runtime for the Triton kernel, in GB/s.
    triton_mem_bw_gbps: float


@dataclass(frozen=True)
class Experiment:
    """Pairs a benchmark configuration with the result measured for it."""

    # The case that was benchmarked.
    config: ExperimentConfig
    # The timings / bandwidth numbers obtained for that case.
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    """Return the benchmark cases, in the order they are run.

    Llama4 shapes. The last dimension of every shape is the scaled K
    dimension divided by the MX block size (32), i.e. the number of scale
    columns. The leading dimension sweeps 1, 2, 4, 8 and 16.
    """
    block_size = 32
    leading_dims = (1, 2, 4, 8, 16)

    # w1, w3 scaled along K (fwd)
    w13_shapes = [(e, 8192, 5120 // block_size) for e in leading_dims]
    # w2 scaled along K (fwd)
    w2_shapes = [(e, 5120, 8192 // block_size) for e in leading_dims]

    return [
        ExperimentConfig(input_shape=shape) for shape in w13_shapes + w2_shapes
    ]


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    """Benchmark the torch.compile'd reference against the Triton kernel.

    A random uint8 tensor of shape ``config.input_shape`` is created on the
    CUDA device and passed to both implementations. Each implementation is
    warmed up (5 calls) before it is timed.

    Args:
        config: Benchmark case; only ``config.input_shape`` is read.

    Returns:
        ExperimentResult with both runtimes in microseconds and the achieved
        memory bandwidth in GB/s, computed as
        ``(bytes read + bytes written) / runtime``.
    """
    input_tensor = torch.randint(
        low=0,
        high=256,
        size=config.input_shape,
        dtype=torch.uint8,
        device=device,
    )

    def warmup(fn, *args, **kwargs):
        # Run a few untimed iterations so one-time costs (compilation,
        # caching) are not attributed to the measured call.
        for _ in range(5):
            fn(*args, **kwargs)

    # bench torch
    compiled_run_torch = torch.compile(torch_to_blocked_per_group_3d)
    warmup(compiled_run_torch, input_tensor)
    torch_time_us = benchmark_cuda_function_in_microseconds(
        compiled_run_torch,
        input_tensor,
    )

    # bench triton
    # The output of this first call is kept to size the write traffic below.
    triton_out_scales = triton_mx_block_rearrange_per_group_3d(input_tensor)
    warmup(triton_mx_block_rearrange_per_group_3d, input_tensor)
    triton_time_us = benchmark_cuda_function_in_microseconds(
        triton_mx_block_rearrange_per_group_3d,
        input_tensor,
    )

    # mem bw calculations
    # Byte sizes come from the tensors that are actually read and written,
    # rather than from unrelated float8 dtypes that happen to be 1 byte wide.
    read_bytes = input_tensor.numel() * input_tensor.element_size()
    write_bytes = triton_out_scales.numel() * triton_out_scales.element_size()
    total_gb = (read_bytes + write_bytes) / 1e9

    torch_mem_bw_gbps = total_gb / (torch_time_us / 1e6)
    triton_mem_bw_gbps = total_gb / (triton_time_us / 1e6)

    return ExperimentResult(
        torch_time_us=torch_time_us,
        triton_time_us=triton_time_us,
        torch_mem_bw_gbps=torch_mem_bw_gbps,
        triton_mem_bw_gbps=triton_mem_bw_gbps,
    )


def print_results(experiments: List[Experiment]):
    """Print a table with one row per experiment.

    Columns are the input shape, both runtimes in microseconds, both memory
    bandwidths rounded to 3 decimals, and the Triton speedup
    (torch time / triton time) formatted with two decimals.
    """
    headers = [
        "input_shape",
        "torch_time_us",
        "triton_time_us",
        "torch_mem_bw_gbps",
        "triton_mem_bw_gbps",
        "triton_speedup",
    ]

    def to_row(experiment):
        # Flatten one experiment into the cell values of its table row.
        shape = experiment.config.input_shape
        res = experiment.result
        speedup = res.torch_time_us / res.triton_time_us
        return [
            f"({shape[0]}, {shape[1]}, {shape[2]})",
            res.torch_time_us,
            res.triton_time_us,
            round(res.torch_mem_bw_gbps, 3),
            round(res.triton_mem_bw_gbps, 3),
            f"{speedup:.2f}x",
        ]

    rows = [to_row(experiment) for experiment in experiments]
    print(tabulate(rows, headers=headers))


def main():
    """Run every benchmark case and print the collected results as a table."""
    # Fixed seed so the random inputs are reproducible across runs.
    torch.random.manual_seed(123)

    experiments = [
        Experiment(config=config, result=run_experiment(config))
        for config in tqdm(get_configs())
    ]

    # Use Tabulate to print results
    print_results(experiments)


if __name__ == "__main__":
    main()
26 changes: 26 additions & 0 deletions test/prototype/moe_training/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
compute_per_group_blocked_scale_offsets,
torch_to_blocked_per_group_2d,
torch_to_blocked_per_group_3d,
triton_mx_block_rearrange_per_group_2d,
triton_mx_block_rearrange_per_group_3d,
)
from torchao.prototype.moe_training.utils import (
_is_column_major,
Expand Down Expand Up @@ -240,3 +242,27 @@ def test_mxfp8_per_group_blocked_scales_2d(
assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), (
"blocked scales not equal"
)


@skip_if_rocm("ROCm enablement in progress")
@pytest.mark.parametrize("e,n,k", [(1, 8192, 5120), (2, 8192, 5120), (8, 5120, 8192)])
def test_mxfp8_per_group_blocked_scales_3d(
    e: int,
    n: int,
    k: int,
):
    """The Triton 3d blocked-scale rearrangement must match the torch reference exactly."""
    block_size = 32

    # NOTE(review): the last dim is already k // block_size before to_mx is
    # applied with the same block_size - confirm whether (e, n, k) was intended.
    data = torch.randn(e, n, k // block_size, device="cuda")
    scales, _ = to_mx(data, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

    # torch reference
    expected = torch_to_blocked_per_group_3d(scales)

    # triton kernel
    actual = triton_mx_block_rearrange_per_group_3d(scales)

    # Zero tolerances: the rearrangement only moves values, so outputs must match exactly.
    outputs_match = torch.allclose(expected, actual, atol=0, rtol=0)
    assert outputs_match, "blocked scales not equal"
120 changes: 120 additions & 0 deletions torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def triton_scale_swizzle_per_group_2d(
# We track how many row blocks we have iterated through.
block_row_id = 0
current_start_row = input_group_start_row

# TODO: Investigate if it is possible and beneficial to parallelize along
# row blocks as well, and get rid of this loop.
while current_start_row < input_group_end_row:
Expand All @@ -237,3 +238,122 @@ def triton_scale_swizzle_per_group_2d(
# Update row block id to next block
block_row_id += 1
current_start_row += BLOCK_ROWS


def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.Tensor:
    """
    Rearranges a 3d E8M0 scale tensor to block-scaled swizzle format.

    Every 2d slice along dim 0 (one per group) is rearranged independently.
    Rows are padded up to a multiple of 128 and columns up to a multiple of 4.

    This format is suitable for Tmem as described in NVIDIA documentation:
    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        scale_tensor: 3d input tensor of shape (num_groups, rows, cols) with
            8-bit elements. Its strides are forwarded to the kernel.

    Returns:
        Rearranged tensor in block-scaled swizzle format, with shape
        (num_groups, padded_rows * padded_cols) and the same dtype as the input.

    Raises:
        AssertionError: if the input is not 3d or its elements are not 1 byte.
    """
    assert scale_tensor.ndim == 3, "scales tensor must be 3d"
    assert scale_tensor.element_size() == 1, (
        "Expected element size to be 1 byte (8 bits)"
    )

    # Tile handled by a single program. We probably want to handle multiple
    # blocks per tile, but for now keep it simple.
    BLOCK_ROWS, BLOCK_COLS = 128, 4

    num_groups, rows, cols = scale_tensor.shape

    # Calculate blocks needed and allocate output tensor
    num_row_blocks = triton.cdiv(rows, BLOCK_ROWS)
    num_col_blocks = triton.cdiv(cols, BLOCK_COLS)
    padded_rows = num_row_blocks * BLOCK_ROWS
    padded_cols = num_col_blocks * BLOCK_COLS
    output = scale_tensor.new_empty((num_groups, padded_rows * padded_cols))

    # Output block stride for the rearranged format: the number of output
    # elements covered by one full row of (BLOCK_ROWS, BLOCK_COLS) tiles.
    output_block_stride = BLOCK_ROWS * BLOCK_COLS * num_col_blocks

    # One program per (group, row block, column block); the grid does not
    # depend on kernel meta-parameters, so a plain tuple is sufficient.
    grid = (num_groups, num_row_blocks, num_col_blocks)

    triton_scale_swizzle_per_group_3d[grid](
        scale_tensor.view(torch.uint8),
        scale_tensor.stride(0),
        scale_tensor.stride(1),
        scale_tensor.stride(2),
        output.view(torch.uint8),
        output.stride(0),
        output_block_stride,
        rows,
        cols,
        BLOCK_ROWS=BLOCK_ROWS,
        BLOCK_COLS=BLOCK_COLS,
    )

    return output


@triton.jit
def triton_scale_swizzle_per_group_3d(
    input_ptr,
    input_stride_dim0,
    input_stride_dim1,
    input_stride_dim2,
    output_ptr,
    output_stride_dim0,
    output_block_stride,
    scale_rows,
    scale_cols,
    BLOCK_ROWS: tl.constexpr,
    BLOCK_COLS: tl.constexpr,
):
    # Each program rearranges one (BLOCK_ROWS, BLOCK_COLS) tile of one group:
    # grid axis 0 selects the group, axis 1 the row tile, axis 2 the column tile.
    group_id = tl.program_id(0)
    row_block_id = tl.program_id(1)
    col_block_id = tl.program_id(2)

    # NOTE(review): from the PR discussion (non-blocking) - rather than
    # extending the existing dense swizzle kernel with a different launch grid
    # and extra params, a suggested follow-up is to factor the shared body into
    # smaller @triton.jit helpers that both kernels build on.

    # Advance both base pointers to the start of this group's data.
    input_ptr += group_id * input_stride_dim0
    output_ptr += group_id * output_stride_dim0

    # Tile-local coordinates, shaped for broadcasting to (BLOCK_ROWS, BLOCK_COLS).
    local_rows = tl.arange(0, BLOCK_ROWS)[:, None]
    local_cols = tl.arange(0, BLOCK_COLS)[None, :]

    # Coordinates of this tile's elements within the group's scale matrix.
    global_rows = row_block_id * BLOCK_ROWS + local_rows
    global_cols = col_block_id * BLOCK_COLS + local_cols

    # Elements outside the real (scale_rows, scale_cols) extent are loaded as 0.
    in_bounds = (global_rows < scale_rows) & (global_cols < scale_cols)

    input_scales = tl.load(
        input_ptr + global_rows * input_stride_dim1 + global_cols * input_stride_dim2,
        mask=in_bounds,
        other=0.0,
    )

    # Rearrange to (32, 4, 4) then to final (32, 16) coordinates: the row
    # index is split into (row // 32, row % 32) and combined with the column.
    dest_indices = (local_rows % 32) * 16 + (local_rows // 32) * 4 + local_cols

    # Flatten
    dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
    scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS))

    # Calculate block offset using provided output block stride: tiles in the
    # same row of tiles are LOCAL_NUMEL apart, rows of tiles are
    # output_block_stride apart.
    LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS
    block_offset = col_block_id * LOCAL_NUMEL + (row_block_id * output_block_stride)

    # The store is unmasked, so the whole tile (including zero padding) is written.
    tl.store(
        output_ptr + block_offset + dest_indices_flat,
        scales_flat,
    )
2 changes: 1 addition & 1 deletion torchao/prototype/mx_formats/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,7 +1179,7 @@ def triton_scale_swizzle(
@torch.library.custom_op("torchao::triton_mx_block_rearrange", mutates_args=())
def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
"""
Rearranges an E8M0 tensor scale from row-major format to block-scaled swizzle format.
Rearranges an E8M0 tensor scale to block-scaled swizzle format.

This format is suitable for Tmem as described in NVIDIA documentation:
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
Expand Down