-
Notifications
You must be signed in to change notification settings - Fork 1.4k
[draft] [feat] Multi-block mode for Hopper spec dec XQA kernel #4416
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
/bot run --disable-fail-fast |
PR_Github #5603 [ run ] triggered by Bot |
PR_Github #5603 [ run ] completed with state |
c317adc
to
e635efe
Compare
… ckpt Signed-off-by: Jhao-Ting Chen <[email protected]>
e635efe
to
d20ab7d
Compare
/bot run --disable-fail-fast |
/bot run --disable-fail-fast --post-merge |
PR_Github #5631 [ run ] triggered by Bot |
Signed-off-by: Jhao-Ting Chen <[email protected]>
/bot kill |
/bot kill --post-merge |
PR_Github #5636 [ kill ] triggered by Bot |
PR_Github #5631 [ run ] completed with state |
PR_Github #5636 [ kill ] completed with state |
PR_Github #5637 Bot args parsing error: usage: /bot [-h] |
/bot run --disable-fail-fast --post-merge |
PR_Github #5654 [ run ] triggered by Bot |
/bot kill |
Signed-off-by: Jhao-Ting Chen <[email protected]>
/bot kill --post-merge |
/bot run --disable-fail-fast --post-merge |
PR_Github #5667 Bot args parsing error: usage: /bot [-h] |
PR_Github #5668 [ run ] triggered by Bot |
PR_Github #5654 [ run ] completed with state |
Good work! gitlab/ftp/xqa content is already opensource in trtllm repo, so please update the link. |
TLLM_CHECK_WITH_INFO(batch_size <= 32, "Multiblock tuning should be for only batch size <= 32"); | ||
|
||
int num_kv_heads = xqaParams.num_kv_heads; | ||
int history_length = xqaParams.max_past_kv_length; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One concern about this is that the kv cache length is non-uniform inside one batch and this is the max. That's why the previous heuristic did not use this value. Did you try this PR on real e2e workload?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right, I only tested on same ISL case. Might need to do more testings.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we can try a few options: 1. max; 2. max * factor (factor < 1); 3. mean seqlen and see which works better
Signed-off-by: Jhao-Ting Chen <[email protected]>
/bot run --disable-fail-fast --post-merge |
PR_Github #5674 [ run ] triggered by Bot |
PR_Github #5668 [ run ] completed with state |
PR_Github #5674 [ run ] completed with state |
/bot run --disable-fail-fast |
PR_Github #5751 [ run ] triggered by Bot |
PR_Github #5751 [ run ] completed with state |
[feat] Multi-block mode for Hopper spec dec XQA kernel
Description
Following PR 3269, it's observed that at low batch size and low draft len, the XQA hopper spec dec kernel has low CTA.
The reason is root caused to not having multi-block mode when Spec-dec is turned on.
The XQA Hopper Spec-dec kernel is launched with
gridDim = dim3{specDecBlocks, multi_block, nbKVHeads * xqaParams.batch_size}
, wherespecDecBlocks = divUp(specDecParams.qSeqLen, 64 / num_q_heads_over_kv)
, num_q_heads_over_kv for Llama is 8.nbKVHeads
is num_kv_head per TP rank, for Llama 3, it's 8 when TP=1, 2 when TP=4, 1 when TP=8.Before the fix, multi_block = 1.
At a very common use case, where eagle draft length set to < 8, running TP=8, using Llama 3 70b / Llama 3.1 8b. The number of blocks launched could be only
batch_size
blocks. At BS=1, only 1 block will be launched.Therefore, multi-block mode is crucial for low BS, low draft length case.
Heuristic design:
A series of sweeps was done with xqa. The experiment showed that when original gridDim is less than a wave of SM, there is benefit for multi-block mode. Furthermore, when original block count is <= 8, 64k > ISL >= 1k, populating all SMs is not always good. The experiments are shown in Appendix.
Speedup:
Kernel Speedup for ISL=32k, draft length 7, batch size 2.
By increasing gridDim from (1,1,2) to (1,32,2) yield a 7.8x speedup.
Kernel Speedup for ISL=32k, draft length 7, batch size 8.
By increasing gridDim from (1,1,8) to (1,8,8) yield a ** 5.4x speedup.**
Kernel Speedup for ISL=10k, draft length 7, batch size 8.
By increasing gridDim from (1,1,8) to (1,4,8) yield a ** 2.5x speedup.**
Generation Step Speedup for ISL=32k, ISL=10k, ISL=1k, running TP8PP1 Llama 3 70B Eagle, Linear Tree (depth 6, max_draft_len = 7).
Accuracy verification:
Test Coverage
Appendix
xqa sweep result
BS=16, ISL=1024 slight regression:
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user friendly way for developers to interact with a Jenkins server.
Run
/bot [-h|--help]
to print this help message.See details below for each supported subcommand.
run [--disable-fail-fast --skip-test --stage-list "A10-1, xxx" --gpu-type "A30, H100_PCIe" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-[Post-Merge]-1, xxx"]
Launch build/test pipelines. All previously running jobs will be killed.
--disable-fail-fast
(OPTIONAL) : Disable fail fast on build/tests/infra failures.--skip-test
(OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.--stage-list "A10-1, xxx"
(OPTIONAL) : Only run the specified test stages. Examples: "A10-1, xxx". Note: Does NOT update GitHub check status.--gpu-type "A30, H100_PCIe"
(OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.--only-multi-gpu-test
(OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.--disable-multi-gpu-test
(OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.--add-multi-gpu-test
(OPTIONAL) : Force run the multi-GPU tests. Will also run L0 pre-merge pipeline.--post-merge
(OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.--extra-stage "H100_PCIe-[Post-Merge]-1, xxx"
(OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-[Post-Merge]-1, xxx".kill
kill
Kill all running builds associated with pull request.
skip
skip --comment COMMENT
Skip testing for latest commit on pull request.
--comment "Reason for skipping build/test"
is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.reuse-pipeline
reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.