
[MoE fp8 rowwise training] Runtime of quantizing 3d expert weights scales worse than linearly #2880

@danielvegamyhre

Description

[Screenshots: lines with bank conflicts]

Additional context

With #2864 we have modest speedups for MoE fp8 rowwise training when the number of experts per device is <= 16.

I did some perf analysis to determine why perf regressed as the number of experts grew, and found that 2 specific kernels are the culprits. Both quantize the 3d expert weights tensor: the 1st in the forward pass, for `out = input @ weight.t()`, and the 2nd on the non-transposed tensor, for `grad_input = grad_output_t @ weight`.
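For context on why the weight has to be quantized twice: rowwise scales are computed along the last dimension, so they are not transpose-invariant, and the transposed and non-transposed weight each need their own quantization pass. A minimal NumPy sketch (illustrative only, not the actual torchao kernels; fp8 e4m3 max of 448 assumed):

```python
import numpy as np

E4M3_MAX = 448.0  # max representable magnitude in fp8 e4m3

def rowwise_scales(t):
    # One scale per row (last dim), mapping the row's absmax to E4M3_MAX.
    return E4M3_MAX / np.maximum(np.abs(t).max(axis=-1), 1e-12)

w = np.random.randn(8, 16).astype(np.float32)  # (out_features, in_features)
s_fwd = rowwise_scales(w.T)  # forward: out = input @ weight.t()
s_bwd = rowwise_scales(w)    # backward: grad_input = grad_output_t @ weight
# Different shapes -> different scales -> two separate quantization kernels.
assert s_fwd.shape == (16,) and s_bwd.shape == (8,)
```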

When scaling up from 4 experts to 16 experts, the kernels quantizing the input activations have the same runtime, as expected, since the inputs are the same size. However, the 2 kernels quantizing the weights described above take ~6x as long, even though the weights are only 4x as big.

  1. The kernel used in the forward pass to quantize weight^T is codegenned by Inductor.
  2. The kernel used in the backward pass to quantize weight (non-transposed) is handwritten (this one), since Inductor was too slow. This kernel is ~20% faster than the Inductor version but still scales poorly, as described above.
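For reference, the operation both kernels implement is just a rowwise absmax reduction plus a scaled cast over the 3d (num_experts, N, K) weight tensor. A hedged NumPy sketch (fp8 e4m3 simulated by clamping to ±448; not the actual Triton/Inductor code):

```python
import numpy as np

E4M3_MAX = 448.0  # fp8 e4m3 max magnitude

def quantize_3d_rowwise(w):
    # w: (num_experts, N, K) float32 expert weights.
    # One scale per row of the last dim, shape (num_experts, N, 1).
    amax = np.abs(w).max(axis=-1, keepdims=True)
    scale = E4M3_MAX / np.maximum(amax, 1e-12)
    w_fp8 = np.clip(w * scale, -E4M3_MAX, E4M3_MAX)  # simulated fp8 payload
    return w_fp8, scale

w = np.random.randn(16, 8, 32).astype(np.float32)  # 16 experts
w_q, s = quantize_3d_rowwise(w)
assert w_q.shape == w.shape and s.shape == (16, 8, 1)
```

Since the reduction is a pure memory-bound pass over the weights, runtime would ideally grow linearly with num_experts; the reported 6x runtime for a 4x size increase is what makes this worse than linear.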

NCU shows 71% memory bandwidth utilization, but flags (1) scoreboard stalls / bank conflicts and (2) low occupancy due to register pressure:

[Screenshot: NCU profiler output]
