Description
Creating this issue as a roadmap/tracker for enabling float8 training for MoEs with token-choice routing. Both core requirements as well as ideas for additional performance optimizations are included.
UPDATE 07/22/2025: revised priorities to reflect shifting focus from fp8 rowwise => fp8 blockwise and mxfp8
This is not an exhaustive list, but it highlights the primary milestones and requirements.
Compute
- fp8 rowwise
- Add torch._scaled_grouped_mm kernel in core
- Add differentiable scaled grouped mm with dynamic float8 rowwise quant in torchao
- Add custom kernels in torchao for performing per-group scaling on device, to avoid host-device sync
- Faster inductor codegen kernels for dynamic quant of 3d tensors along dim1: Inductor codegen for float8 dynamic quantization ops for scaled_grouped_mm backward pass is slow pytorch#159769
- alternatively, a handwritten Triton kernel that is faster than torch.compile for this ([moe training] add fp8 rowwise kernels for expert weights #2696)
- quantizing the 3d expert weights also needs to be faster: [MoE fp8 rowwise training] Runtime of quantizing 3d expert weights scales worse than linearly #2880
- fp8 blockwise
- quant primitives
- DeepGEMM integration for fp8 blockwise grouped GEMM
- triton kernels to do scaling per group without d2h sync
- mxfp8
- mxfp8 scaled grouped gemm Add MXFP8 Support to scaled_grouped_gemm pytorch#153502
- 2d-3d gemm for output and dX ([mxfp8 moe] add support for fbgemm 2d-3d mx8mx8bf16 grouped gemm #2848)
- 2d-2d gemm for dW
- torchao differentiable _scaled_grouped_mm support for mxfp8 recipe for dynamic quant before grouped GEMMs
- triton kernels for per token group scale conversion to blocked swizzled format
- for 2d inputs ([mxfp8 moe training] add per group blocked scale kernels #2886)
- for 3d expert weights
Communication
I looked at traces and validated that the "all-to-all dispatch -> grouped GEMM -> all-to-all combine" stages are sequentially dependent, so in theory faster/lower-precision comms should improve performance. There is some overlap with the shared expert computation, but it is not 100%, so there is room for optimization. This will be especially important if/when the all-to-all spans multiple nodes, where inter-node network bandwidth is lower than intra-node NVLink bandwidth.
This is also inspired by the DeepSeekV3 paper where, if I understand correctly, they do the a2a dispatch in fp8 but keep the a2a combine in bf16, as they found the combine more sensitive to low precision during training.
- Add on device all_to_all_v kernels compatible with:
- mxfp8 (P0)
- float8 blockwise (P0)
- float8 rowwise (P1)
- token permutation kernel supports low-precision dtypes by permuting the scales into the proper order for the permuted tokens (link)
- mxfp8 (P0)
- float8 blockwise (P0)
- float8 rowwise (P1)
Torchao UX
- Add a tensor subclass (ScaledGroupedMMTensor) with an op override for torch.ops.aten._grouped_mm that runs the differentiable scaled grouped mm
- Add a one-line model conversion API that recursively swaps the nn.Parameter data tensors of the expert weights with ScaledGroupedMMTensor
- support configurable recipe (fp8 blockwise/rowwise, mxfp8)
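A rough sketch of what this UX could look like. Both ScaledGroupedMMTensor's internals and the conversion entry point are simplified stand-ins, not the actual torchao implementation:

```python
import torch
import torch.nn as nn

class ScaledGroupedMMTensor(torch.Tensor):
    """Minimal stand-in for the proposed subclass. The real one would
    override torch.ops.aten._grouped_mm (via __torch_function__ /
    __torch_dispatch__) to run the differentiable scaled grouped mm."""
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_subclass(cls, data, data.requires_grad)

def convert_moe_training(model: nn.Module, target_fqns=("experts",)) -> nn.Module:
    """Hypothetical one-line conversion API: recursively swap the expert
    weights' nn.Parameter data with ScaledGroupedMMTensor. The actual
    torchao entry point and its recipe config may differ."""
    for fqn, module in model.named_modules():
        if any(t in fqn for t in target_fqns):
            for name, param in list(module.named_parameters(recurse=False)):
                new_param = nn.Parameter(
                    ScaledGroupedMMTensor(param.data),
                    requires_grad=param.requires_grad,
                )
                setattr(module, name, new_param)
    return model
```

Matching on fully-qualified names keeps the router and shared experts in bf16 while only the routed expert weights get the subclass treatment.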
Compile support
- Compile support for torch._grouped_mm
- Differentiable _scaled_grouped_mm can compile with fullgraph=True
- E2E compilation of each TransformerBlock in torchtitan after MoE conversion via the tensor subclass approach (fullgraph=False)
- E2E compilation of each TransformerBlock in torchtitan after MoE conversion via the tensor subclass approach (fullgraph=True)
Distributed support
- Composability with FSDP2 (will likely need something like this for the new tensor subclass)
- mxfp8 (P0)
- float8 blockwise (P0)
- float8 rowwise (P1) [float8 moe training] FSDP support #2413
- Composability with TP
- mxfp8 (P0)
- float8 blockwise (P0)
- float8 rowwise (P1) [moe training] Add TP support for routed experts #2473
- Composability with FSDP + TP
- mxfp8 (P0)
- float8 blockwise (P0)
- float8 rowwise (P1) [moe training] Add 2D parallel (FSDP2 + TP) tests for routed experts #2475
- Composability with dp2ep as implemented here: dp2ep Expert Parallel torchtitan#1324
- mxfp8 (P0)
- float8 blockwise (P0)
- float8 rowwise (P1) [WIP] [moe training] Add tests for 3D parallel (FSDP + TP + EP) for routed experts #2481