Add fused_transpose_split_quant kernel #10657

Merged

Conversation

@lshpku commented on May 26, 2025

PR types

Performance optimization

PR changes

APIs

Description

This PR adds the fused_transpose_split_quant operator, defined as follows:

/**
 * Quantize on dim[0] of X, transpose dim[0] and dim[1] of X, then
 * split the result into out and scale.
 *
 * Inputs:
 *   X     : [SUM(M_1...M_N), K], bfloat16
 *
 * Outputs:
 *   out   : {[K, M_1], [K, M_2], ..., [K, M_N]}, float8_e4m3fn
 *   scale : {[M_1/128, K], [M_2/128, K], ..., [M_N/128, K]}, float
 *
 * Attrs:
 *   pow_2_scales
 *         : bool that indicates whether to use power-of-2 scaling
 *
 * Requirements:
 *   1) M_i % 128 == 0 for M_i in [M_1, M_2, ..., M_N]
 *   2) K <= 65535 * 128
 */
void fused_transpose_split_quant(const paddle::Tensor& X,
                                 std::vector<paddle::Tensor>& outs,
                                 std::vector<paddle::Tensor>& scales,
                                 bool pow_2_scales);
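
For reference, the semantics can be sketched in pure Paddle. This is a hedged reconstruction: the scaling scheme (amax over each 128-row-by-1-column block mapped to the float8_e4m3fn maximum of 448, with power-of-2 scales rounded up, and scale stored as the dequant factor) is an assumption based on common 1x128-block FP8 recipes, and reference_transpose_split_quant is a name introduced here, not part of the PR.

import paddle

E4M3_MAX = 448.0  # largest finite float8_e4m3fn magnitude

def reference_transpose_split_quant(x, tokens_per_expert, pow_2_scales):
    out, scale = [], []
    offset = 0
    K = x.shape[1]
    for tokens in tokens_per_expert:
        xi = x[offset:offset + tokens].astype('float32')    # [M_i, K]
        offset += tokens
        blocks = xi.reshape([tokens // 128, 128, K])        # 128-row blocks on dim 0
        amax = blocks.abs().max(axis=1)                     # [M_i/128, K]
        s = amax / E4M3_MAX                                 # assumed dequant scale: x ~ q * s
        if pow_2_scales:
            # round each scale up to the next power of 2 (assumed direction)
            s = paddle.pow(paddle.ones_like(s) * 2.0, paddle.ceil(paddle.log2(s)))
        s = paddle.where(s > 0, s, paddle.ones_like(s))     # guard all-zero blocks
        q = (blocks / s.unsqueeze(1)).reshape([tokens, K])
        # fp8 cast support depends on the Paddle version
        out.append(q.T.astype('float8_e4m3fn'))             # [K, M_i]
        scale.append(s)                                     # [M_i/128, K]
    return out, scale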

Note: this operator is somewhat cumbersome to use. Since custom ops do not support a variable number of outputs, the outputs are passed in as arguments, and the caller must allocate their storage. A thin wrapper on the Python side makes it look more like a functional API; for example:

import paddle
import FusedQuantOps as FQO

def fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales):
    # Pre-allocate the per-expert outputs, then let the kernel fill them in place.
    out, scale = [], []
    for tokens in tokens_per_expert:
        out.append(paddle.empty([x.shape[1], tokens], dtype='float8_e4m3fn'))
        scale.append(paddle.empty([tokens // 128, x.shape[1]], dtype='float32'))
    FQO.fused_transpose_split_quant(x, out, scale, pow_2_scales)
    return out, scale

tokens_per_expert = [24*128, 0, 128*128, 1*128, 13*128]
seq_len = 7168

x = paddle.randn([sum(tokens_per_expert), seq_len], dtype='bfloat16')

out, scale = fused_transpose_split_quant(x, tokens_per_expert, False)

# out[0].shape=[7168,  3072]  scale[0].shape=[ 24, 7168]
# out[1].shape=[7168,     0]  scale[1].shape=[  0, 7168]
# out[2].shape=[7168, 16384]  scale[2].shape=[128, 7168]
# out[3].shape=[7168,   128]  scale[3].shape=[  1, 7168]
# out[4].shape=[7168,  1664]  scale[4].shape=[ 13, 7168]

Correctness was verified against the standard library on Hopper (H-series) GPUs: scale agrees to within 1e-7, and out agrees at rtol=0.01, atol=0.2.
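
A tolerance check in that spirit could reuse the hypothetical reference sketch above:

ref_out, ref_scale = reference_transpose_split_quant(x, tokens_per_expert, False)
for a, b in zip(out, ref_out):
    assert paddle.allclose(a.astype('float32'), b.astype('float32'), rtol=0.01, atol=0.2)
for a, b in zip(scale, ref_scale):
    assert paddle.allclose(a, b, atol=1e-7)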

Performance test

Taking the shapes above as an example, the kernel reaches 81.8% memory-bandwidth utilization on an A100.
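
Bandwidth utilization here means achieved bytes moved per second divided by the device's peak HBM bandwidth. A rough way to reproduce such a number (a sketch; bench is a helper introduced here, and the 2.039 TB/s peak assumes an A100-80GB, use 1.555 TB/s for the 40GB variant):

import time

def bench(fn, iters=100):
    fn()                                  # warm-up
    paddle.device.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    paddle.device.cuda.synchronize()
    return (time.perf_counter() - t0) / iters

elapsed = bench(lambda: FQO.fused_transpose_split_quant(x, out, scale, False))

M, K = x.shape
# bf16 read + fp8 write + fp32 scale write (one scale per 128 rows per column)
bytes_moved = M * K * 2 + M * K * 1 + (M // 128) * K * 4
PEAK_BW = 2.039e12  # bytes/s, A100-80GB HBM2e peak (assumption)
print(f"utilization: {bytes_moved / (elapsed * PEAK_BW):.1%}")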

Pcard-85711

@paddle-bot (bot) commented on May 26, 2025

Thanks for your contribution!

@lshpku force-pushed the add-transpose-split-quant-kernel branch 4 times, most recently from 3ff52f0 to ace25e1 on May 29, 2025 at 09:18
@lshpku force-pushed the add-transpose-split-quant-kernel branch from ace25e1 to 0095350 on May 29, 2025 at 11:37
@phlrain merged commit f2712b7 into PaddlePaddle:dsv3_dev on Jun 3, 2025 (2 of 5 checks passed)