Documenting issues with mxfp8 grouped gemm and repro commands:
Prerequisite: install the torch and fbgemm-gpu-genai nightlies (CUDA 12.8 on B200): `pip3 install --pre torch fbgemm-gpu-genai --index-url https://download.pytorch.org/whl/nightly/cu128`
Using uniform group sizes, test cases pass for M=2048, but other values such as M=1024 or M=16640 hit a CUDA illegal memory access.
- Repro:
  - Check out torchao PR: [mxfp8 moe] add support for fbgemm 2d-3d mx8mx8bf16 grouped gemm #2848
  - Verify the current unit tests pass, enabling logging so we can log input shapes/strides:
    `pytest test/prototype/moe_training/test_scaled_grouped_mm.py -k test_mxfp8_grouped_gemm_with_dq_fwd_bwd -s --log-cli-level=INFO`
  - In the test `test_mxfp8_grouped_gemm_with_dq_fwd_bwd`, change `M` to 1024. Rerun the pytest command above and observe the CUDA errors.
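For context, uniform group offsets are just equal steps up to `M` (the test builds them with `torch.arange`). A minimal sketch in plain Python, with illustrative names rather than the actual test code:

```python
def uniform_offsets(M: int, n_groups: int) -> list[int]:
    """End offsets of n_groups equal-sized token groups covering M rows.
    Illustrative sketch; the real test constructs these with torch.arange."""
    assert M % n_groups == 0, "uniform groups require n_groups to divide M"
    group_size = M // n_groups
    return [group_size * (i + 1) for i in range(n_groups)]
```

So for M=1024 and 4 groups, the offsets are `[256, 512, 768, 1024]` — only the total row count changes between the passing and failing cases, not the offset structure.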
Using non-uniform group sizes results in CUDA illegal memory access errors.
- Repro:
  - Check out torchao PR: [mxfp8 moe] add support for fbgemm 2d-3d mx8mx8bf16 grouped gemm #2848
  - Verify the current unit tests pass (same as above), enabling logging so we can log input shapes/strides:
    `pytest test/prototype/moe_training/test_scaled_grouped_mm.py -k test_mxfp8_grouped_gemm_with_dq_fwd_bwd -s --log-cli-level=INFO`
  - In the same test, `test_mxfp8_grouped_gemm_with_dq_fwd_bwd`, change the group offsets to randomly generated ones (using multiples of 32) by commenting out the line `offs = torch.arange(...)` and un-commenting the line `offs = generate_jagged_offs(...)`.
  - Rerun the unit test and observe the CUDA illegal memory access errors.
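A generator like `generate_jagged_offs` can be sketched as follows (hypothetical implementation in plain Python; the real helper lives in the torchao test utilities and may differ in details):

```python
import random

def jagged_offsets(M: int, n_groups: int, multiple: int = 32,
                   seed: int = 0) -> list[int]:
    """Random non-uniform end offsets over M rows, where every group
    size is a positive multiple of `multiple`. Hypothetical sketch of
    what generate_jagged_offs produces; not the actual torchao helper."""
    assert M % multiple == 0 and M // multiple >= n_groups
    slots = M // multiple  # number of `multiple`-sized chunks to distribute
    rng = random.Random(seed)
    # pick distinct interior cut points so every group is non-empty
    cuts = sorted(rng.sample(range(1, slots), n_groups - 1))
    return [c * multiple for c in cuts] + [M]
```

The key property is that group sizes stay multiples of 32 (the MX block size) but are otherwise arbitrary, so they are generally not multiples of 128.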
- For this one, I suspect the issue may be in my `to_blocked_per_group_2d` or `to_blocked_per_group_3d` functions, which convert MXFP8 e8m0 scales to a blocked format on a per-token-group basis. I implemented these functions using the FBGEMM unit test as a reference, but that test only exercises uniform group sizes, so there could be a gap.
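To make the suspected gap concrete, here is a sketch (in numpy, with illustrative names; a best-effort rendering of the 128x4-tile blocked scale layout, not the exact torchao/FBGEMM code) of this kind of per-group rearrangement. Each group's scale rows are padded to a 128-row boundary independently before swizzling, so group sizes that are not multiples of 128 change the padded footprint relative to a single whole-tensor conversion:

```python
import numpy as np

def to_blocked(scales: np.ndarray) -> np.ndarray:
    """Rearrange a (rows, cols) e8m0 scale matrix into a 128x4-tiled
    blocked layout, zero-padding rows to 128 and cols to 4.
    Sketch of the commonly used block-scale swizzle; details may
    differ from the actual kernel's expected layout."""
    rows, cols = scales.shape
    padded = np.pad(scales, ((0, -rows % 128), (0, -cols % 4)))
    R, C = padded.shape
    # split into 128x4 tiles
    blocks = padded.reshape(R // 128, 128, C // 4, 4).transpose(0, 2, 1, 3)
    # within each tile, interleave rows into a 32x16 pattern
    rearranged = (blocks.reshape(-1, 4, 32, 4)
                        .transpose(0, 2, 1, 3)
                        .reshape(-1, 32, 16))
    return rearranged.reshape(-1)

def to_blocked_per_group_2d(scales: np.ndarray,
                            offsets: list[int]) -> np.ndarray:
    """Apply the blocked rearrangement to each token group independently.
    Each group is padded to 128 rows on its own -- the step where a
    non-multiple-of-128 group size could introduce a layout mismatch."""
    start, out = 0, []
    for end in offsets:
        out.append(to_blocked(scales[start:end]))
        start = end
    return np.concatenate(out)
```

For example, a (128, 4) scale tensor split into two 64-row groups produces 1024 blocked elements (each group pads to a full 128x4 tile), versus 512 for the same tensor converted as one group — exactly the kind of size/offset disagreement that could produce an illegal memory access if the kernel assumes a different convention.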