156 changes: 81 additions & 75 deletions torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -122,7 +122,7 @@ def forward(
round_scales_to_power_of_2=True,
)
A_scaled = A.to(torch.float32) * A_scales
A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
A_data_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)

# Convert B to float8, column-major for right operand of grouped GEMM.
# B_t shape: (E, K, N)
@@ -136,18 +136,18 @@
round_scales_to_power_of_2=True,
)
B_t_scaled = B_t.to(torch.float32) * B_t_scales
B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)
B_t_data_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)

# Store what we need for backward.
ctx.save_for_backward(A, B_t, offs)
ctx.out_dtype = out_dtype

# Perform scaled grouped GEMM and return result.
# output shape: scaled grouped mm of (M,K) @ (E,K,N) = (M,N)
assert not _is_column_major(A_fp8_row_major), (
assert not _is_column_major(A_data_row_major), (
"A must be row-major for output = A @ B"
)
assert _is_column_major(B_t_fp8_col_major), (
assert _is_column_major(B_t_data_col_major), (
"B must be column-major for output = A @ B"
)

@@ -157,8 +157,8 @@
A_scales = A_scales.squeeze(-1)
B_t_scales = B_t_scales.squeeze(1)
return torch._scaled_grouped_mm(
A_fp8_row_major,
B_t_fp8_col_major,
A_data_row_major,
B_t_data_col_major,
A_scales.reciprocal(), # Reciprocals are needed for rescaling the output.
B_t_scales.reciprocal(),
offs,
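
Note: for orientation, here is a minimal pure-PyTorch sketch of the rowwise fp8 recipe this hunk renames. rowwise_fp8_quant below is a hypothetical stand-in for tensor_to_scale + to_fp8_saturated (the real kernels also round scales to powers of two), and the emulated matmul shows why the reciprocals of the quantization scales are what get passed to torch._scaled_grouped_mm.

import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def rowwise_fp8_quant(x: torch.Tensor, dim: int):
    # One scale per row (or column): map the per-slice amax onto the e4m3 max.
    amax = x.abs().amax(dim=dim, keepdim=True).clamp(min=1e-12)
    scales = E4M3_MAX / amax.to(torch.float32)
    data = (x.to(torch.float32) * scales).clamp(-E4M3_MAX, E4M3_MAX)
    return data.to(torch.float8_e4m3fn), scales

M, K, N = 16, 32, 8
A, B = torch.randn(M, K), torch.randn(K, N)
A_data, A_scales = rowwise_fp8_quant(A, dim=1)  # A_scales: (M, 1)
B_data, B_scales = rowwise_fp8_quant(B, dim=0)  # B_scales: (1, N)

# The GEMM runs on fp8 data; the output is then rescaled by the reciprocal
# of each quantization scale, matching the .reciprocal() calls above.
out = (A_data.to(torch.float32) @ B_data.to(torch.float32)) / (A_scales * B_scales)
print((out - A @ B).abs().max())  # small quantization error
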
@@ -184,13 +184,13 @@ def backward(ctx, grad_output: torch.Tensor):
round_scales_to_power_of_2=True,
)
grad_output_scaled = grad_output.to(torch.float32) * grad_output_scales
grad_output_fp8_row_major = to_fp8_saturated(
grad_output_data_row_major = to_fp8_saturated(
grad_output_scaled, torch.float8_e4m3fn
)

# Compute B fp8 column-major for right operand of grouped GEMM:
# grad_A = grad_output @ B.
B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
B_data_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
B_t._data if hasattr(B_t, "_data") else B_t,
output_dtype=torch.float8_e4m3fn,
round_scales_to_power_of_2=True,
@@ -199,10 +199,10 @@ def backward(ctx, grad_output: torch.Tensor):
# Compute grad_A.
# grad_A = grad_output @ B
# grad_A = scaled grouped mm of (M,N) @ (E,N,K) = (M,K)
assert not _is_column_major(grad_output_fp8_row_major), (
assert not _is_column_major(grad_output_data_row_major), (
"grad_output must be row-major for grad_A = grad_output @ B"
)
assert _is_column_major(B_fp8_col_major), (
assert _is_column_major(B_data_col_major), (
"B must be column-major for grad_A = grad_output @ B"
)

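Note: a plausible sketch of the layout predicate these assertions rely on; the real _is_column_major helper lives elsewhere in this module, so treat this as an assumption about its behavior rather than a quote of it.

def _is_column_major(x: torch.Tensor) -> bool:
    # Column-major for GEMM purposes: the second-to-last dim is the
    # fastest-moving one in memory; holds for 2D and batched 3D tensors.
    assert x.ndim in (2, 3), "input tensor must be 2D or 3D"
    return x.stride(-2) == 1 and x.stride(-1) > 1
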
@@ -212,8 +212,8 @@ def backward(ctx, grad_output: torch.Tensor):
grad_output_scales = grad_output_scales.squeeze(-1)
B_scales = B_scales.squeeze(1)
grad_A = torch._scaled_grouped_mm(
grad_output_fp8_row_major,
B_fp8_col_major,
grad_output_data_row_major,
B_data_col_major,
grad_output_scales.reciprocal(),
B_scales.reciprocal(),
offs,
@@ -227,18 +227,18 @@ def backward(ctx, grad_output: torch.Tensor):
# Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
# needed for grad_B: grad_output_t @ A
# Use transpose method to avoid uncoalesced memory accesses.
grad_out_fp8_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
grad_out_data_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
grad_output.t()
.contiguous()
.t(), # Quantization is over 2x faster when input is col major, even with this transformation
offs,
torch.float8_e4m3fn,
round_scales_to_power_of_2=True,
)
grad_output_t_fp8_row_major = grad_out_fp8_colwise.t()
grad_output_t_data_row_major = grad_out_data_colwise.t()
grad_output_t_scales = grad_out_scales.t()

A_fp8_col_major, A_scales = triton_fp8_per_group_colwise_scales(
A_data_col_major, A_scales = triton_fp8_per_group_colwise_scales(
A.t()
.contiguous()
.t(), # Quantization is over 2x faster when input is col major, even with this transformation
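
Note: the .t().contiguous().t() idiom above deserves a word. It keeps the logical shape and values unchanged but materializes the data in column-major memory, so the column-wise quantization kernel reads coalesced memory. A quick stride check demonstrates this:

import torch

x = torch.randn(128, 64)          # row-major: strides (64, 1)
x_cm = x.t().contiguous().t()     # same shape and values, column-major memory

print(x.stride(), x_cm.stride())  # (64, 1) (1, 128)
print(torch.equal(x, x_cm))       # True -- only the memory layout changed
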
@@ -249,19 +249,19 @@ def backward(ctx, grad_output: torch.Tensor):

# Compute grad_B = grad_output_t @ A.
assert not _is_column_major(grad_output_t_fp8_row_major), (
assert not _is_column_major(grad_output_t_data_row_major), (
"grad_output_t must be row-major for grad_B = grad_output_t @ A"
)
assert _is_column_major(A_fp8_col_major), (
assert _is_column_major(A_data_col_major), (
"A must be column-major for grad_B = grad_output_t @ A"
)

# Per-token group scales computed via triton kernels above do not have
# the empty dim like the scales computed via tensor_to_scale, so we
# don't need to squeeze here.
grad_B = torch._scaled_grouped_mm(
grad_output_t_fp8_row_major,
A_fp8_col_major,
grad_output_t_data_row_major,
A_data_col_major,
grad_output_t_scales.reciprocal(),
A_scales.reciprocal(),
offs,
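
Note: for reviewers who want the semantics in one place, here is a plain full-precision reference loop (hypothetical sizes) for the three grouped GEMMs this autograd function performs, with offs holding the exclusive end offset of each expert's token group:

import torch

E, M, K, N = 4, 256, 64, 32
offs = torch.tensor([64, 128, 192, 256], dtype=torch.int32)

A = torch.randn(M, K)          # tokens, grouped by expert along M
B_t = torch.randn(E, K, N)     # per-expert weights
grad_out = torch.randn(M, N)

out = torch.empty(M, N)
grad_A = torch.empty(M, K)
grad_B = torch.empty(E, N, K)

start = 0
for e in range(E):
    end = int(offs[e])
    out[start:end] = A[start:end] @ B_t[e]                # (m_e, K) @ (K, N)
    grad_A[start:end] = grad_out[start:end] @ B_t[e].t()  # (m_e, N) @ (N, K)
    grad_B[e] = grad_out[start:end].t() @ A[start:end]    # (N, m_e) @ (m_e, K)
    start = end
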
@@ -295,13 +295,15 @@ def forward(
ctx.out_dtype = out_dtype
ctx.emulated = emulated

# A_mx shape: (M, K)
# A_data shape: (M, K)
# A_scale shape: (M, K//block_size)
A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
A_scale, A_data = to_mx(
A, elem_dtype=torch.float8_e4m3fn, block_size=block_size
)

# B_mx shape: (E, N, K)
# B_data shape: (E, N, K)
# B_scale shape: (E, N, K//block_size)
B_scales, B_mx = to_mx(
B_scales, B_data = to_mx(
B_t.transpose(-2, -1),
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
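
Note: to_mx block-quantizes along the last dimension, so an (M, K) input yields data of shape (M, K) plus one scale per block_size elements, i.e. scales of shape (M, K//block_size). A rough emulation of that layout (emulate_to_mx is hypothetical; the real to_mx uses E8M0 power-of-two scales, and plain fp32 scales are used here for clarity):

import torch

def emulate_to_mx(x: torch.Tensor, block_size: int = 32):
    # One scale per block of block_size contiguous elements along the last dim.
    *lead, K = x.shape
    xb = x.reshape(*lead, K // block_size, block_size)
    amax = xb.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax.to(torch.float32) / torch.finfo(torch.float8_e4m3fn).max
    data = (xb.to(torch.float32) / scale).to(torch.float8_e4m3fn)
    return scale.squeeze(-1), data.reshape(x.shape)

A = torch.randn(16, 64)
A_scale, A_data = emulate_to_mx(A)
print(A_data.shape, A_scale.shape)  # torch.Size([16, 64]) torch.Size([16, 2])
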
@@ -315,9 +317,9 @@ def forward(
else fbgemm_mxfp8_grouped_mm_2d_3d
)
out = mxfp8_2d_3d_grouped_mm(
A_mx,
A_data,
A_scale,
B_mx,
B_data,
B_scales,
offs=offs,
block_size=block_size,
@@ -332,15 +334,15 @@ def backward(ctx, grad_out: torch.Tensor):
out_dtype = ctx.out_dtype
emulated = ctx.emulated

# grad_out_mx shape: (M, N)
# grad_out_data shape: (M, N)
# grad_out_scale shape: (M, N//block_size)
grad_out_scale, grad_out_mx = to_mx(
grad_out_scale, grad_out_data = to_mx(
grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
)

# B_mx shape: (E, K, N)
# B_data shape: (E, K, N)
# B_scale shape: (E, K, N//block_size)
B_scales, B_mx = to_mx(
B_scales, B_data = to_mx(
# TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
B_t.contiguous(),
elem_dtype=torch.float8_e4m3fn,
@@ -354,43 +356,43 @@ def backward(ctx, grad_out: torch.Tensor):
else fbgemm_mxfp8_grouped_mm_2d_3d
)
grad_A = mxfp8_2d_3d_grouped_mm(
grad_out_mx,
grad_out_data,
grad_out_scale,
B_mx,
B_data,
B_scales,
offs=offs,
out_dtype=out_dtype,
)

# grad_out_t_mx shape: (N, M)
# grad_out_t_data shape: (N, M)
# grad_out_t_scales shape: (N, M//block_size)
grad_out_t_scales, grad_out_t_mx = to_mx(
grad_out_t_scales, grad_out_t_data = to_mx(
# TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
grad_out.transpose(-2, -1).contiguous(),
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
)

# Transpose A so we can scale along the M dimension, then un-transpose.
# A_t_mx shape: (K, M)
# A_t_data shape: (K, M)
# A_t_scales shape: (K, M//block_size)
A_t_scales, A_t_mx = to_mx(
A_t_scales, A_t_data = to_mx(
A.transpose(-2, -1).contiguous(),
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
)

# A_mx shape = (M, K)
A_mx = A_t_mx.transpose(-2, -1)
# A_data shape = (M, K)
A_data = A_t_data.transpose(-2, -1)

# A_scales shape = (M//block_size, K)
A_scales = A_t_scales.transpose(-2, -1)

# grad_B = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
grad_out_t_mx,
grad_out_t_data,
grad_out_t_scales,
A_mx,
A_data,
A_scales,
offs=offs,
)
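
Note: the transpose/un-transpose dance above exists because to_mx only scales along the last dimension, while grad_B needs A quantized along M, the contraction dim of grad_out_t @ A. Reusing the hypothetical emulate_to_mx sketch from earlier, the resulting shapes are:

import torch

M, K, block_size = 256, 64, 32
A = torch.randn(M, K)

# Quantize along M by transposing first (to_mx scales along the last dim) ...
A_t_scales, A_t_data = emulate_to_mx(A.transpose(-2, -1).contiguous(), block_size)

# ... then un-transpose, so the data is logically (M, K) again but each
# run of block_size rows within a column shares one scale.
A_data = A_t_data.transpose(-2, -1)      # (M, K)
A_scales = A_t_scales.transpose(-2, -1)  # (M // block_size, K)
print(A_data.shape, A_scales.shape)      # torch.Size([256, 64]) torch.Size([8, 64])
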
@@ -402,64 +404,68 @@ def backward(ctx, grad_out: torch.Tensor):


def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
A_mx: torch.Tensor,
A_data: torch.Tensor,
A_scale: torch.Tensor,
B_mx: torch.Tensor,
B_data: torch.Tensor,
B_scale: torch.Tensor,
offs: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = torch.bfloat16,
block_size: int = 32,
) -> torch.Tensor:
assert A_mx.ndim == 2, f"A must be 2D, got {A_mx.ndim}"
assert B_mx.ndim == 3, f"B must be 3D, got {B_mx.ndim}"
assert A_scale.shape[0] == A_mx.shape[0], (
f"A_scale must have same M dim as A_mx, got A={A_mx.shape} and A_scale={A_scale.shape}"
assert A_data.ndim == 2, f"A must be 2D, got {A_data.ndim}"
assert B_data.ndim == 3, f"B must be 3D, got {B_data.ndim}"
assert A_scale.shape[0] == A_data.shape[0], (
f"A_scale must have same M dim as A_data, got A={A_data.shape} and A_scale={A_scale.shape}"
)
assert A_scale.shape[1] == A_mx.shape[1] // block_size, (
f"A_scale dim1 should be size K//block_size, got A={A_mx.shape} and A_scale={A_scale.shape}"
assert A_scale.shape[1] == A_data.shape[1] // block_size, (
f"A_scale dim1 should be size K//block_size, got A={A_data.shape} and A_scale={A_scale.shape}"
)
assert B_scale.shape[0] == B_mx.shape[0], (
f"B_scale must have same E dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
assert B_scale.shape[0] == B_data.shape[0], (
f"B_scale must have same E dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
)
assert B_scale.shape[1] == B_mx.shape[1], (
f"B_scale must have same N dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
assert B_scale.shape[1] == B_data.shape[1], (
f"B_scale must have same N dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
)
assert B_scale.shape[2] == B_mx.shape[2] // block_size, (
f"B_scale dim2 should be size K//block_size, got B={B_mx.shape} and B_scale={B_scale.shape}"
assert B_scale.shape[2] == B_data.shape[2] // block_size, (
f"B_scale dim2 should be size K//block_size, got B={B_data.shape} and B_scale={B_scale.shape}"
)

# Dequantize input
# A_mx shape: (M, K)
# A_data shape: (M, K)
# A_scale shape: (M, K//block_size)
A_orig_shape = A_mx.shape
A_orig_shape = A_data.shape

# Reshape to be able to do per-scaling group multiplication
# A_mx shape: (M, K//block_size, block_size)
# A_data shape: (M, K//block_size, block_size)
# A_scale shape: (M, K//block_size, 1)
A_mx = A_mx.reshape(*A_mx.shape[:-1], A_mx.shape[-1] // block_size, block_size)
A_data = A_data.reshape(
*A_data.shape[:-1], A_data.shape[-1] // block_size, block_size
)
A_scale = A_scale.unsqueeze(-1)

# Rescale and cast to bfloat16
A = A_mx.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
A = A_data.to(torch.bfloat16) * A_scale.to(torch.bfloat16)

# Reshape back to original shape
# A shape: (M, K)
A = A.reshape(A_orig_shape)

# Dequantize weights
# Transpose to get block_size on rightmost dim
# B_mx shape: (E, N, K)
# B_data shape: (E, N, K)
# B_scale shape: (E, N, K//block_size)
E, N, K = B_mx.shape
E, N, K = B_data.shape

# Reshape to be able to do per-scaling group multiplication
# B_mx shape: (E, N, K//block_size, block_size)
# B_data shape: (E, N, K//block_size, block_size)
# B_scale shape: (E, N, K//block_size, 1)
B_mx = B_mx.reshape(*B_mx.shape[:-1], B_mx.shape[-1] // block_size, block_size)
B_data = B_data.reshape(
*B_data.shape[:-1], B_data.shape[-1] // block_size, block_size
)
B_scale = B_scale.unsqueeze(-1)

# Rescale and cast to bfloat16
B = B_mx.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
B = B_data.to(torch.bfloat16) * B_scale.to(torch.bfloat16)

# Reshape back to original shape
# B shape: (E, K, N)
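
Note: the dequantization pattern used throughout this function, in miniature: give each scaling block its own trailing dim, broadcast-multiply by the unsqueezed scale, and reshape back (same fp32-scale assumption as the emulate_to_mx sketch above):

import torch

M, K, block_size = 16, 64, 32
A_scale, A_data = emulate_to_mx(torch.randn(M, K), block_size)

# Give each scaling block its own trailing dim, broadcast the scale, reshape back.
A = (
    A_data.reshape(M, K // block_size, block_size).to(torch.bfloat16)
    * A_scale.unsqueeze(-1).to(torch.bfloat16)
).reshape(M, K)
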
@@ -471,27 +477,27 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_3d(


def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
A_mx: torch.Tensor, # (M, K)
A_data: torch.Tensor, # (M, K)
A_scale: torch.Tensor, # (M, K//block_size)
B_mx: torch.Tensor, # (K, N)
B_data: torch.Tensor, # (K, N)
B_scale: torch.Tensor, # (K//block_size, N)
offs: torch.Tensor,
out_dtype: Optional[torch.dtype] = torch.bfloat16,
block_size: int = 32,
) -> torch.Tensor:
assert A_mx.ndim == 2, "A must be 2D"
assert B_mx.ndim == 2, "B must be 2D"
assert A_data.ndim == 2, "A must be 2D"
assert B_data.ndim == 2, "B must be 2D"
A = torch.zeros(
A_mx.shape,
A_data.shape,
dtype=torch.bfloat16,
device=A_mx.device,
requires_grad=A_mx.requires_grad,
device=A_data.device,
requires_grad=A_data.requires_grad,
)
B = torch.zeros(
B_mx.shape,
B_data.shape,
dtype=torch.bfloat16,
device=B_mx.device,
requires_grad=B_mx.requires_grad,
device=B_data.device,
requires_grad=B_data.requires_grad,
)

# Dequantize input per each scaling group
@@ -507,7 +513,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
# -- Dequantize A tensor
# A_group shape: (M, group_size)
# A_scale shape: (M, group_size//block_size)
A_group = A_mx[:, group_start_idx:group_end_idx]
A_group = A_data[:, group_start_idx:group_end_idx]
A_group_shape = A_group.shape

# Get scales for this group.
@@ -532,7 +538,7 @@

# -- Dequantize B tensor
# B_group shape is (group_size, N)
B_group = B_mx[group_start_idx:group_end_idx, :]
B_group = B_data[group_start_idx:group_end_idx, :]
B_group_shape = B_group.shape

# Scales shape is (group_size//block_size, N)
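
Note: the loop this 2d-2d emulation runs can be summarized as follows. Each expert's group occupies a contiguous [start, end) slice of the contraction dim shared by both operands, delimited by offs, and produces its own (N, K) output slab (hypothetical sizes; full-precision tensors stand in for the dequantized fp8 data):

import torch

N, M, K, E = 16, 96, 8, 2
offs = torch.tensor([32, 96], dtype=torch.int32)  # exclusive group ends along M

G_t = torch.randn(N, M)  # stands in for the dequantized grad_out_t
A = torch.randn(M, K)    # stands in for the dequantized A
out = torch.empty(E, N, K, dtype=torch.bfloat16)

start = 0
for e in range(E):
    end = int(offs[e])
    # Each expert's weight grad uses only its own tokens' slice of M.
    out[e] = (G_t[:, start:end] @ A[start:end, :]).to(torch.bfloat16)
    start = end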