
Commit 22facfa

[mxfp8 moe training] use dim1 cast cuda kernel in bwd
stack-info: PR: #2897, branch: danielvegamyhre/stack/64
1 parent 95c3e51 commit 22facfa

1 file changed

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 17 additions & 10 deletions
@@ -20,6 +20,12 @@
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
 )
+from torchao.prototype.mx_formats.config import (
+    MXFP8Dim1CastKernelChoice,
+    MXGemmKernelChoice,
+    ScaleCalculationMode,
+)
+from torchao.prototype.mx_formats.mx_linear import _to_mxfp8_dim1_kernel_wrapper
 from torchao.prototype.mx_formats.mx_tensor import to_mx
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -376,17 +382,18 @@ def backward(ctx, grad_out: torch.Tensor):
         # Transpose A so we can scale along the M dimension, then un-transpose.
         # A_t_data shape: (K, M)
         # A_t_scales shape: (K, M//block_size)
-        A_t_scales, A_t_data = to_mx(
-            A.transpose(-2, -1).contiguous(),
+        A_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+            A,
+            block_size,
             elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-        )
-
-        # A_data shape = (M, K)
-        A_data = A_t_data.transpose(-2, -1)
-
-        # A_scales shape = (M//block_size, K)
-        A_scales = A_t_scales.transpose(-2, -1)
+            hp_dtype=A.dtype,
+            gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+            cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+            scale_calculation_mode=ScaleCalculationMode.FLOOR,
+        )
+        A_mx = A_t_mx.t()
+        A_data = A_mx.qdata
+        A_scales = A_mx._scale_e8m0.t()
 
         # grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
         grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
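For context, the removed lines above are the pure-PyTorch emulation of the dim1 (column-wise) cast that the CUDA kernel now performs. A minimal standalone sketch of that emulated path, useful as a numerical reference when validating the kernel, is below; the shapes, device, and block size chosen here are illustrative assumptions, not values from this commit:

    import torch
    from torchao.prototype.mx_formats.mx_tensor import to_mx

    # Illustrative shapes; 32 is the standard MX block size.
    M, K, block_size = 128, 256, 32
    A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")

    # Emulated dim1 cast: transpose so scaling runs along the M dimension,
    # quantize to mxfp8, then un-transpose the outputs.
    A_t_scales, A_t_data = to_mx(
        A.transpose(-2, -1).contiguous(),
        elem_dtype=torch.float8_e4m3fn,
        block_size=block_size,
    )
    A_data = A_t_data.transpose(-2, -1)      # (M, K) fp8 data
    A_scales = A_t_scales.transpose(-2, -1)  # (M // block_size, K) e8m0 scales

The new code path produces the same logical result in one fused kernel: `_to_mxfp8_dim1_kernel_wrapper` returns a transposed MXTensor, and calling `.t()` on it recovers the row-major `(M, K)` view whose `qdata` and (transposed) `_scale_e8m0` replace `A_data` and `A_scales`.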
