```diff
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
 )
+from torchao.prototype.mx_formats.config import (
+    MXFP8Dim1CastKernelChoice,
+    MXGemmKernelChoice,
+    ScaleCalculationMode,
+)
+from torchao.prototype.mx_formats.mx_linear import _to_mxfp8_dim1_kernel_wrapper
 from torchao.prototype.mx_formats.mx_tensor import to_mx

 logger: logging.Logger = logging.getLogger(__name__)
@@ -376,17 +382,18 @@ def backward(ctx, grad_out: torch.Tensor):
         # Transpose A so we can scale along the M dimension, then un-transpose.
         # A_t_data shape: (K, M)
         # A_t_scales shape: (K, M//block_size)
-        A_t_scales, A_t_data = to_mx(
-            A.transpose(-2, -1).contiguous(),
+        A_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+            A,
+            block_size,
             elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-        )
-
-        # A_data shape = (M, K)
-        A_data = A_t_data.transpose(-2, -1)
-
-        # A_scales shape = (M//block_size, K)
-        A_scales = A_t_scales.transpose(-2, -1)
+            hp_dtype=A.dtype,
+            gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
+            cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+            scale_calculation_mode=ScaleCalculationMode.FLOOR,
+        )
+        A_mx = A_t_mx.t()
+        A_data = A_mx.qdata
+        A_scales = A_mx._scale_e8m0.t()

         # grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
         grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
```
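For reviewers, the removed `to_mx` path doubles as a readable reference for what the new CUDA kernel computes: an mxfp8 cast of `A` along dim1 (the M dimension) with FLOOR scale calculation. Below is a minimal side-by-side sketch of that correspondence, assuming a CUDA device and a torchao build that ships the mxfp8 dim1 cast CUDA kernel; the tensor sizes and the final equality check are illustrative and not part of this change.

```python
# Sketch only: compares the removed to_mx()-based dim1 cast (emulated via
# transpose + cast along the last dim) against the new CUDA kernel wrapper.
import torch

from torchao.prototype.mx_formats.config import (
    MXFP8Dim1CastKernelChoice,
    MXGemmKernelChoice,
    ScaleCalculationMode,
)
from torchao.prototype.mx_formats.mx_linear import _to_mxfp8_dim1_kernel_wrapper
from torchao.prototype.mx_formats.mx_tensor import to_mx

M, K, block_size = 128, 256, 32  # illustrative sizes
A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")

# Old path: transpose so the M dimension is last, cast with to_mx, un-transpose.
A_t_scales, A_t_data = to_mx(
    A.transpose(-2, -1).contiguous(),
    elem_dtype=torch.float8_e4m3fn,
    block_size=block_size,
)
A_data_ref = A_t_data.transpose(-2, -1)      # (M, K)
A_scales_ref = A_t_scales.transpose(-2, -1)  # (M//block_size, K)

# New path: cast along dim1 directly via the CUDA kernel wrapper.
A_t_mx = _to_mxfp8_dim1_kernel_wrapper(
    A,
    block_size,
    elem_dtype=torch.float8_e4m3fn,
    hp_dtype=A.dtype,
    gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # not used by the cast itself
    cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
    scale_calculation_mode=ScaleCalculationMode.FLOOR,
)
A_mx = A_t_mx.t()
A_data = A_mx.qdata              # corresponds to A_data_ref
A_scales = A_mx._scale_e8m0.t()  # corresponds to A_scales_ref

# Expected to print True when the CUDA kernel matches the emulated FLOOR-mode
# cast bit-for-bit; any rounding differences between the kernels would show up here.
print(torch.equal(A_data.float(), A_data_ref.float()))
```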