Description
- Kernel name: `_triton_fp8_rowwise_3d_transpose_cast_rhs_kernel`
- Inputs: column-major input tensor of shape (E, K, N) and scales tensor of shape (E, K)
- Outputs: column-major output tensor, transposed and cast to fp8 rowwise, of shape (E, N, K)
- Repro:
- Checkout this PR: [moe fp8 training] use transpose method when quantizing to avoid uncoalesced gmem accesses #2864
- Option 1: Run test for this kernel:
pytest test/prototype/moe_training/test_kernels.py -k test_fp8_rowwise_3d_transpose_rhs_atomic
- Option 2: Run bench comparing this kernel to other implementations:
python benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py
- This will run 3 implementations to compare perf, but I am only concerned about the `_triton_fp8_rowwise_3d_transpose_scales_rhs_kernel`
- TTIR: https://www.internalfb.com/phabricator/paste/view/P1918713582
- TTGIR: https://www.internalfb.com/phabricator/paste/view/P1918714134
- Warnings/remarks: https://www.internalfb.com/phabricator/paste/view/P1918025647
(Screenshot omitted: source lines flagged with bank conflicts.)

Additional context
With #2864 we have modest speedups for MoE fp8 rowwise training when the number of experts per device is <= 16.
I did some perf analysis to determine why perf regressed as the number of experts grew, and found that 2 specific kernels are the culprits. Both quantize the 3d expert weights tensor: the 1st time in the forward pass for `out = input @ weight.t()`, and the 2nd time on the non-transposed tensor for `grad_input = grad_output_t @ weight`.
When scaling up from 4 experts to 16 experts, the kernels quantizing the input activations have the same runtime, as expected, since the inputs are the same size. However, the 2 weight-quantization kernels described above take 6x as long (for weights that are only 4x as big).
- The kernel used in the forward pass to quantize `weight^T` is code-generated by Inductor.
- The kernel used in the backward pass to quantize `weight` (non-transposed) is handwritten (this one), since Inductor was too slow. This kernel is ~20% faster than the Inductor version but still scales poorly, as described above.
NCU shows 71% memory bandwidth utilization, but flags (1) scoreboard stalls / shared-memory bank conflicts and (2) low occupancy due to register pressure:
