@@ -31,6 +31,7 @@
class ExperimentConfig:
high_precision_dtype: torch.dtype
input_shape: tuple[int]
power_of_2_scales: bool


@dataclass(frozen=True)
@@ -48,7 +49,7 @@ class Experiment:


def get_configs() -> List[ExperimentConfig]:
# Llama4 shapes
# Llama4 shapes (E, N, K)
input_shapes = [
(1, 8192, 5120), # w1, w3
(1, 5120, 8192), # w2
@@ -58,14 +59,16 @@ def get_configs() -> List[ExperimentConfig]:
(128, 5120, 8192), # w2
]
high_precision_dtypes = [torch.bfloat16]
power_of_2_scales = [True, False]
configs = []
for input_shape, high_precision_dtype in itertools.product(
input_shapes, high_precision_dtypes
for input_shape, high_precision_dtype, power_of_2_scale in itertools.product(
input_shapes, high_precision_dtypes, power_of_2_scales
):
configs.append(
ExperimentConfig(
input_shape=input_shape,
high_precision_dtype=high_precision_dtype,
power_of_2_scales=power_of_2_scale,
)
)
return configs
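
The new power_of_2_scales axis controls whether the rowwise scales are rounded down to the nearest power of two before casting. As a rough illustration of what that rounding does (a sketch only; the helper name below is made up and this is not the torchao implementation), it amounts to flooring the scale's base-2 exponent so the scale is exactly representable:

import torch

def round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # Keep only the exponent: 2 ** floor(log2(scale)). Zeros and infs would
    # need extra handling in a real kernel.
    return torch.exp2(torch.floor(torch.log2(scale)))

amax = torch.tensor([0.7, 3.2, 100.0])
fp8_max = torch.finfo(torch.float8_e4m3fn).max        # 448.0
scales = fp8_max / amax                               # rowwise scales
print(scales)                                         # ~[640.0, 140.0, 4.48]
print(round_scale_down_to_power_of_2(scales))         # [512.0, 128.0, 4.0]
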
@@ -87,18 +90,16 @@ def run_torch(input_tensor: torch.Tensor):
out = torch_to_3d_rowwise_float8_transpose_rhs(
input_tensor,
target_dtype=torch.float8_e4m3fn,
round_scales_to_power_of_2=True,
round_scales_to_power_of_2=config.power_of_2_scales,
)
torch.cuda.synchronize()
return out

def run_triton(input_tensor: torch.Tensor):
out = triton_fp8_rowwise_3d_transpose_rhs(
input_tensor,
output_dtype=torch.float8_e4m3fn,
round_scales_to_power_of_2=True,
round_scales_to_power_of_2=config.power_of_2_scales,
)
torch.cuda.synchronize()
return out

# bench torch
@@ -141,6 +142,7 @@ def run_triton(input_tensor: torch.Tensor):
def print_results(experiments: List[Experiment]):
headers = [
"input_shape",
"power_of_2_scales",
"torch_time_us",
"triton_time_us",
"torch_mem_bw_gbps",
@@ -153,6 +155,7 @@ def print_results(experiments: List[Experiment]):
rows.append(
[
input_shape,
experiment.config.power_of_2_scales,
experiment.result.torch_time_us,
experiment.result.triton_time_us,
round(experiment.result.torch_mem_bw_gbps, 3),
@@ -49,7 +49,7 @@ def get_configs() -> List[ExperimentConfig]:
# Llama4 shapes
A_shapes = [(16640, 5120)]
B_shapes = [(1, 8192, 5120), (4, 8192, 5120), (16, 8192, 5120), (64, 8192, 5120)]
recipes = [MoEScalingType.FP8_ROWWISE]
recipes = [MoEScalingType.FP8_ROWWISE, MoEScalingType.MXFP8]
high_precision_dtypes = [torch.bfloat16]
configs = []
for A_shape, B_shape, recipe, high_precision_dtype in itertools.product(
27 changes: 10 additions & 17 deletions torchao/prototype/moe_training/kernels/float8_rowwise.py
@@ -26,10 +26,10 @@
torch.float64: tl.float64,
}

block_sizes_n = [32, 128, 256] # large dim (output_features)
block_sizes_k = [32, 128, 256] # small dim (input_features)
num_warps = [2, 4]
num_stages = [2, 3, 4, 5, 6]
block_sizes_n = [128] # large dim (output_features)
block_sizes_k = [128] # small dim (input_features)
num_warps = [4]
num_stages = [4]
kernel_configs_2D = [
triton.Config(
{"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k},
@@ -172,9 +172,7 @@ def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel(
+ (n_offs[None, :] * stride_input_dim2)
)
input_mask = (k_offs[:, None] < K) & (n_offs[None, :] < N)
input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0).to(
input_dtype
)
input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0)

# In a normal torch implementation, we should transpose the tensor then compute the amax
# along the dim1 (N), to compute colwise scales for a RHS operand of a scaled grouped gemm:
@@ -243,25 +241,20 @@ def _triton_fp8_rowwise_3d_transpose_cast_rhs_kernel(
+ (n_offs[None, :] * stride_input_dim2)
)
input_mask = (k_offs[:, None] < K) & (n_offs[None, :] < N)
input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0).to(
input_dtype
)
input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0)
input_data = input_data.trans(1, 0) # (K, N) -> (N, K)

# load global scales for this block of the given expert - shape (1, K)
scales_offs = (
expert_idx[:, None] * stride_scales_dim0 + k_offs[None, :] * stride_scales_dim1
)
scales_mask = k_offs[None, :] < K
scales = tl.load(scales_ptr + scales_offs, mask=scales_mask, other=0.0).to(
tl.float32
)
scales = tl.load(scales_ptr + scales_offs, mask=scales_mask, other=0.0)

# transpose data and apply scales - shape (N,K) * (1,K) = (N,K)
scaled_data = input_data * scales
output_data = tl.clamp(scaled_data, min=fp8_dtype_min, max=fp8_dtype_max).to(
output_dtype
)
output_data = tl.clamp(
input_data * scales, min=fp8_dtype_min, max=fp8_dtype_max
).to(output_dtype)

# store transposed output data - shape (N, K)
output_offs = (
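
Taken together, the two kernels in this file compute rowwise (per-K) scales from the amax along N, then transpose each (K, N) expert slice to (N, K), apply the scales, clamp to the fp8 range, and cast. A single-expert eager sketch of that math (function name, eps clamp, and dtype defaults are mine, and this ignores the memory-layout handling the kernels exist for):

import torch

def rowwise_cast_transpose_one_expert(
    x: torch.Tensor,                              # (K, N) high-precision slice
    fp8_dtype: torch.dtype = torch.float8_e4m3fn,
    round_scales_to_power_of_2: bool = True,
):
    finfo = torch.finfo(fp8_dtype)
    # One scale per K row, computed from the amax over N.
    amax = x.abs().amax(dim=1).clamp(min=1e-12).to(torch.float64)   # (K,)
    scales = (finfo.max / amax).to(torch.float32)
    if round_scales_to_power_of_2:
        scales = torch.exp2(torch.floor(torch.log2(scales)))
    # Transpose, apply scales, saturate to the fp8 range, and cast.
    x_t = x.transpose(0, 1).to(torch.float32)                       # (N, K)
    x_fp8_t = (x_t * scales).clamp(finfo.min, finfo.max).to(fp8_dtype)
    return x_fp8_t, scales
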
@@ -31,8 +31,8 @@
torch.float64: tl.float64,
}

block_sizes = [1, 16, 32, 64]
block_sizes_iter = [64, 128, 256]
block_sizes = [32] # [16, 32, 64]
block_sizes_iter = [128] # [64, 128, 256]
num_warps = [4]
num_stages = [3]
kernel_configs_2D = [
26 changes: 15 additions & 11 deletions torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -14,7 +14,6 @@
from torchao.prototype.moe_training.conversion_utils import MoEScalingType
from torchao.prototype.moe_training.kernels import (
triton_fp8_per_group_colwise_scales,
triton_fp8_per_group_rowwise_scales,
triton_fp8_rowwise_3d_transpose_rhs,
)
from torchao.prototype.moe_training.utils import (
@@ -174,8 +173,8 @@ def backward(ctx, grad_output: torch.Tensor):
# Convert grad_output to float8, row-major for left operand of grouped GEMM
# needed for grad_A: grad_output @ B
#
# grad_output shape: (M, N)
# grad_output_scale shape: (M, 1)
# grad_output shape: (Mg, N)
# grad_output_scale shape: (Mg, 1)
grad_output_scales = tensor_to_scale(
grad_output,
torch.float8_e4m3fn,
@@ -226,17 +225,22 @@ def backward(ctx, grad_output: torch.Tensor):

# Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
# needed for grad_B: grad_output_t @ A
grad_output_t_fp8_row_major, grad_output_t_scales = (
triton_fp8_per_group_rowwise_scales(
grad_output.transpose(-2, -1),
offs,
torch.float8_e4m3fn,
round_scales_to_power_of_2=True,
)
# Use transpose method to avoid uncoalesced memory accesses.
grad_out_fp8_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
grad_output.t()
.contiguous()
.t(), # Quantization is over 2x faster when input is col major, even with this transformation
offs,
torch.float8_e4m3fn,
round_scales_to_power_of_2=True,
)
grad_output_t_fp8_row_major = grad_out_fp8_colwise.t()
grad_output_t_scales = grad_out_scales.t()

A_fp8_col_major, A_scales = triton_fp8_per_group_colwise_scales(
A,
A.t()
.contiguous()
.t(), # Quantization is over 2x faster when input is col major, even with this transformation
offs,
torch.float8_e4m3fn,
round_scales_to_power_of_2=True,
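
The .t().contiguous().t() pattern above leaves the logical values of grad_output (and A) unchanged; it only rewrites storage so the 2D tensor is column-major, which is the layout the per-group colwise-scales kernel reads fastest, and transposing the resulting fp8 tensor and its scales then gives the row-major quantized grad_output_t that grad_B needs. A small sketch of the layout part (shapes and variable names here are illustrative):

import torch

# Illustrative shape; in backward() this is grad_output with shape (Mg, N).
x = torch.randn(256, 128, dtype=torch.bfloat16)

# Same values, different memory layout: x_col is column-major.
x_col = x.t().contiguous().t()
assert torch.equal(x, x_col)
assert x.stride() == (128, 1)       # row-major
assert x_col.stride() == (1, 256)   # column-major
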
12 changes: 6 additions & 6 deletions torchao/prototype/moe_training/utils.py
@@ -163,21 +163,21 @@ def torch_to_3d_rowwise_float8_transpose_rhs(
Scales shape: (E, 1, K)
"""
assert _is_column_major(input_hp_t), "input tensor must be column-major"
input_hp = input_hp_t.transpose(-2, -1) # (E, N, K)
scales = tensor_to_scale(
input_hp,
input_hp_t,
target_dtype,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=-2,
axiswise_dim=-1,
round_scales_to_power_of_2=round_scales_to_power_of_2,
) # (E, 1, K)
) # (E, K, 1)

# Apply scales to tensor and convert to float8.
tensor_scaled = input_hp.to(torch.float32) * scales
tensor_scaled = input_hp_t.to(torch.float32) * scales
float8_tensor = to_fp8_saturated(tensor_scaled, target_dtype)

# To column major
float8_tensor = float8_tensor.transpose(-2, -1).contiguous().transpose(-2, -1)
float8_tensor = float8_tensor.contiguous().transpose(-2, -1)
scales = scales.transpose(-2, -1)
return float8_tensor, scales

