Skip to content

Commit 083d0c3

Browse files
[mxfp8 moe training] use dim1 cast cuda kernel in bwd (#2897)
1 parent fbe3df9 commit 083d0c3

File tree

3 files changed

+92
-79
lines changed

3 files changed

+92
-79
lines changed

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 17 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,13 @@
2020
from torchao.prototype.moe_training.utils import (
2121
_is_column_major,
2222
)
23+
from torchao.prototype.mx_formats.config import (
24+
MXFP8Dim1CastKernelChoice,
25+
MXGemmKernelChoice,
26+
ScaleCalculationMode,
27+
)
2328
from torchao.prototype.mx_formats.mx_tensor import to_mx
29+
from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper
2430

2531
logger: logging.Logger = logging.getLogger(__name__)
2632

@@ -376,17 +382,18 @@ def backward(ctx, grad_out: torch.Tensor):
376382
# Transpose A so we can scale along the M dimension, then un-transpose.
377383
# A_t_data shape: (K, M)
378384
# A_t_scales shape: (K, M//block_size)
379-
A_t_scales, A_t_data = to_mx(
380-
A.transpose(-2, -1).contiguous(),
385+
A_t_mx = _to_mxfp8_dim1_kernel_wrapper(
386+
A,
387+
block_size,
381388
elem_dtype=torch.float8_e4m3fn,
382-
block_size=block_size,
383-
)
384-
385-
# A_data shape = (M, K)
386-
A_data = A_t_data.transpose(-2, -1)
387-
388-
# A_scales shape = (M//block_size, K)
389-
A_scales = A_t_scales.transpose(-2, -1)
389+
hp_dtype=A.dtype,
390+
gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, # Not used
391+
cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
392+
scale_calculation_mode=ScaleCalculationMode.FLOOR,
393+
)
394+
A_mx = A_t_mx.t()
395+
A_data = A_mx.qdata
396+
A_scales = A_mx._scale_e8m0.t()
390397

391398
# grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
392399
grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(

torchao/prototype/mx_formats/mx_linear.py

Lines changed: 1 addition & 67 deletions
Original file line number | Diff line number | Diff line change
@@ -11,86 +11,20 @@
1111
from typing import Any, Optional
1212

1313
import torch
14-
from torch.distributed._tensor import DTensor
1514

1615
from torchao.prototype.mx_formats.config import (
1716
MXFP8Dim1CastKernelChoice,
1817
MXGemmKernelChoice,
1918
MXLinearConfig,
2019
ScaleCalculationMode,
2120
)
22-
from torchao.prototype.mx_formats.kernels import (
23-
mxfp8_quantize_cuda,
24-
triton_to_mxfp8_dim1,
25-
)
2621
from torchao.prototype.mx_formats.mx_tensor import MXTensor
22+
from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper
2723
from torchao.quantization.transform_module import (
2824
register_quantize_module_handler,
2925
)
3026

3127

32-
def _to_mxfp8_dim1_kernel_wrapper(
33-
a,
34-
block_size,
35-
elem_dtype,
36-
hp_dtype,
37-
gemm_kernel_choice,
38-
cast_kernel_choice,
39-
scale_calculation_mode: ScaleCalculationMode,
40-
):
41-
if cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
42-
assert scale_calculation_mode == ScaleCalculationMode.FLOOR
43-
a_data, a_scale = triton_to_mxfp8_dim1(a, block_size)
44-
elif cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
45-
assert scale_calculation_mode in (
46-
ScaleCalculationMode.FLOOR,
47-
ScaleCalculationMode.RCEIL,
48-
)
49-
_, a_data, _, a_scale = mxfp8_quantize_cuda(
50-
a,
51-
rowwise=False,
52-
colwise=True,
53-
scaling_mode=scale_calculation_mode.value,
54-
)
55-
else:
56-
raise ValueError(f"must be one of [CUDA, TRITON], got {cast_kernel_choice}")
57-
58-
if isinstance(a_data, DTensor):
59-
assert isinstance(a_scale, DTensor)
60-
a_data_local = a_data.to_local()
61-
a_scale_local = a_scale.to_local()
62-
inner = MXTensor(
63-
a_data_local.t(),
64-
a_scale_local,
65-
elem_dtype,
66-
block_size,
67-
hp_dtype,
68-
gemm_kernel_choice,
69-
False,
70-
None,
71-
)
72-
mx_tensor = DTensor.from_local(
73-
inner,
74-
a_data.device_mesh,
75-
a_data.placements,
76-
run_check=False,
77-
shape=a_data.t().size(),
78-
stride=a_data.t().stride(),
79-
)
80-
else:
81-
mx_tensor = MXTensor(
82-
a_data.t(),
83-
a_scale,
84-
elem_dtype,
85-
block_size,
86-
hp_dtype,
87-
gemm_kernel_choice,
88-
False,
89-
None,
90-
)
91-
return mx_tensor
92-
93-
9428
@torch._dynamo.allow_in_graph
9529
class mx_mm(torch.autograd.Function):
9630
# There are three gemms in a forward + backward of a Linear layer:

torchao/prototype/mx_formats/utils.py

Lines changed: 74 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,18 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8-
9-
from torchao.prototype.mx_formats.kernels import triton_mx_block_rearrange
8+
from torch.distributed._tensor import DTensor
9+
10+
from torchao.prototype.mx_formats.config import (
11+
MXFP8Dim1CastKernelChoice,
12+
ScaleCalculationMode,
13+
)
14+
from torchao.prototype.mx_formats.kernels import (
15+
mxfp8_quantize_cuda,
16+
triton_mx_block_rearrange,
17+
triton_to_mxfp8_dim1,
18+
)
19+
from torchao.prototype.mx_formats.mx_tensor import MXTensor
1020

1121
Tensor = torch.Tensor
1222

@@ -99,3 +109,65 @@ def _to_blocked_single(scales: Tensor) -> Tensor:
99109
assert scales.shape == (128, 4)
100110
scales_tiled = scales.view(4, 32, 4) # view as 4 - (32, 4) tiles
101111
return scales_tiled.transpose(0, 1).reshape(32, 16) # Interleave tiles
112+
113+
114+
def _to_mxfp8_dim1_kernel_wrapper(
115+
a,
116+
block_size,
117+
elem_dtype,
118+
hp_dtype,
119+
gemm_kernel_choice,
120+
cast_kernel_choice,
121+
scale_calculation_mode: ScaleCalculationMode,
122+
):
123+
if cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
124+
assert scale_calculation_mode == ScaleCalculationMode.FLOOR
125+
a_data, a_scale = triton_to_mxfp8_dim1(a, block_size)
126+
elif cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
127+
assert scale_calculation_mode in (
128+
ScaleCalculationMode.FLOOR,
129+
ScaleCalculationMode.RCEIL,
130+
)
131+
_, a_data, _, a_scale = mxfp8_quantize_cuda(
132+
a,
133+
rowwise=False,
134+
colwise=True,
135+
scaling_mode=scale_calculation_mode.value,
136+
)
137+
else:
138+
raise ValueError(f"must be one of [CUDA, TRITON], got {cast_kernel_choice}")
139+
140+
if isinstance(a_data, DTensor):
141+
assert isinstance(a_scale, DTensor)
142+
a_data_local = a_data.to_local()
143+
a_scale_local = a_scale.to_local()
144+
inner = MXTensor(
145+
a_data_local.t(),
146+
a_scale_local,
147+
elem_dtype,
148+
block_size,
149+
hp_dtype,
150+
gemm_kernel_choice,
151+
False,
152+
None,
153+
)
154+
mx_tensor = DTensor.from_local(
155+
inner,
156+
a_data.device_mesh,
157+
a_data.placements,
158+
run_check=False,
159+
shape=a_data.t().size(),
160+
stride=a_data.t().stride(),
161+
)
162+
else:
163+
mx_tensor = MXTensor(
164+
a_data.t(),
165+
a_scale,
166+
elem_dtype,
167+
block_size,
168+
hp_dtype,
169+
gemm_kernel_choice,
170+
False,
171+
None,
172+
)
173+
return mx_tensor

0 commit comments

Comments
 (0)