[draft] [feat] Multi-block mode for Hopper spec dec XQA kernel #4416


Open · wants to merge 4 commits into main from xqa_hopper_cta_perf
Conversation

@jhaotingc (Collaborator) commented May 18, 2025

[feat] Multi-block mode for Hopper spec dec XQA kernel

Description

Following PR 3269, it was observed that at low batch size and low draft length, the Hopper spec-dec XQA kernel launches very few CTAs.

The root cause is that multi-block mode is not enabled when spec-dec is turned on.

The XQA Hopper Spec-dec kernel is launched with gridDim = dim3{specDecBlocks, multi_block, nbKVHeads * xqaParams.batch_size}, where

  • specDecBlocks = divUp(specDecParams.qSeqLen, 64 / num_q_heads_over_kv), num_q_heads_over_kv for Llama is 8.
  • nbKVHeads is num_kv_head per TP rank, for Llama 3, it's 8 when TP=1, 2 when TP=4, 1 when TP=8.

Before the fix, multi_block = 1.

In a very common use case (eagle draft length < 8, TP=8, Llama 3 70B or Llama 3.1 8B), the number of blocks launched can be as low as batch_size. At BS=1, only one block is launched:

gridDim = dim3{ divUp(7, 8), 1, 1 * xqaParams.batch_size } # = {1, 1, batch_size}

Therefore, multi-block mode is crucial for the low-BS, low-draft-length case.

Heuristic design:

A series of sweeps was done with xqa. The experiments showed that multi-block mode is beneficial when the original gridDim is smaller than one wave of SMs. However, when the original block count is <= 8 and 1k <= ISL < 64k, populating all SMs is not always best. The experiments are shown in the Appendix.

Speedup:

Kernel Speedup for ISL=32k, draft length 7, batch size 2.

Increasing gridDim from (1,1,2) to (1,32,2) yields a **7.9x speedup** (418.620 μs → 52.991 μs).

before: 
kernel_mha
Begins: 39.1859s
Ends: 39.1863s (+418.620 μs)
grid:  <<<1, 1, 2>>>
block: <<<128, 1, 3>>>

after: 
kernel_mha
Begins: 33.2649s
Ends: 33.265s (+52.991 μs)
grid:  <<<1, 32, 2>>>
block: <<<128, 1, 3>>>

Kernel Speedup for ISL=32k, draft length 7, batch size 8.

Increasing gridDim from (1,1,8) to (1,8,8) yields a **5.4x speedup** (420.063 μs → 77.663 μs).

before: 
kernel_mha
Begins: 62.8263s
Ends: 62.8268s (+420.063 μs)
grid:  <<<1, 1, 8>>>
block: <<<128, 1, 3>>>

after: 
kernel_mha
Begins: 42.189s
Ends: 42.1891s (+77.663 μs)
grid:  <<<1, 8, 8>>>
block: <<<128, 1, 3>>>

Kernel Speedup for ISL=10k, draft length 7, batch size 8.

Increasing gridDim from (1,1,8) to (1,4,8) yields a **2.5x speedup** (137.759 μs → 55.776 μs).

before: 
kernel_mha
Begins: 39.3397s
Ends: 39.3399s (+137.759 μs)
grid:  <<<1, 1, 8>>>
block: <<<128, 1, 3>>>

after: 
kernel_mha
Begins: 38.6683s
Ends: 38.6683s (+55.776 μs)
grid:  <<<1, 4, 8>>>
block: <<<128, 1, 3>>>

Generation Step Speedup for ISL=32k, ISL=10k, ISL=1k, running TP8PP1 Llama 3 70B Eagle, Linear Tree (depth 6, max_draft_len = 7).

[Figures: generation-step speedup charts for ISL=32k, ISL=10k, ISL=1k]

Accuracy verification:

# add speculative_config in lm_eval_tensorrt_llm.py
python lm_eval_tensorrt_llm.py --model trt-llm \
    --model_args tokenizer=$HF_DIR,model=$ENGINE_DIR \
    --tasks gsm8k

# (Before:)
trt-llm (tokenizer=/scratch_1/tmp/hf_models/Meta-Llama-3-70B-Instruct,model=/scratch_1/tmp/trt_engines/Meta-Llama-3-70B-Instruct_eagle_fp8/tp8_pp1), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9075|±  | 0.008|
|     |       |strict-match    |     5|exact_match|↑  |0.9067|±  | 0.008|


# (After:)
trt-llm (tokenizer=/scratch_1/tmp/hf_models/Meta-Llama-3-70B-Instruct,model=/scratch_1/tmp/trt_engines/Meta-Llama-3-70B-Instruct_eagle_fp8/tp8_pp1), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9121|±  |0.0078|
|     |       |strict-match    |     5|exact_match|↑  |0.9121|±  |0.0078|

Test Coverage

Appendix

xqa sweep result

[Figures: xqa sweep results]

BS=16, ISL=1024 shows a slight regression (21.600 μs → 27.584 μs):

before:
kernel_mha
Begins: 36.3784s
Ends: 36.3785s (+21.600 μs)
grid:  <<<1, 1, 16>>>
block: <<<128, 1, 3>>>

after:
kernel_mha
Begins: 39.7469s
Ends: 39.7469s (+27.584 μs)
grid:  <<<1, 8, 16>>>
block: <<<128, 1, 3>>>

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with the Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--disable-fail-fast --skip-test --stage-list "A10-1, xxx" --gpu-type "A30, H100_PCIe" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-[Post-Merge]-1, xxx"]

Launch build/test pipelines. All previously running jobs will be killed.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests. Will also run L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-[Post-Merge]-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-[Post-Merge]-1, xxx".

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@jhaotingc (Collaborator Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #5603 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #5603 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #4089 completed with status: 'SUCCESS'

@jhaotingc jhaotingc force-pushed the xqa_hopper_cta_perf branch 3 times, most recently from c317adc to e635efe Compare May 18, 2025 20:27
@jhaotingc jhaotingc force-pushed the xqa_hopper_cta_perf branch from e635efe to d20ab7d Compare May 18, 2025 20:30
@jhaotingc (Collaborator Author)

/bot run --disable-fail-fast

@jhaotingc jhaotingc changed the title initial heuristic for xqa hopper spec dec Multi-block mode for Hopper spec dec XQA kernel May 18, 2025
@jhaotingc jhaotingc changed the title Multi-block mode for Hopper spec dec XQA kernel [feat] Multi-block mode for Hopper spec dec XQA kernel May 18, 2025
@jhaotingc (Collaborator Author)

/bot run --disable-fail-fast --post-merge

@tensorrt-cicd (Collaborator)

PR_Github #5631 [ run ] triggered by Bot

Signed-off-by: Jhao-Ting Chen <[email protected]>
@jhaotingc (Collaborator Author)

/bot kill

@jhaotingc (Collaborator Author)

/bot kill --post-merge

@tensorrt-cicd (Collaborator)

PR_Github #5636 [ kill ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #5631 [ run ] completed with state ABORTED

@tensorrt-cicd (Collaborator)

PR_Github #5636 [ kill ] completed with state SUCCESS
Successfully killed previous jobs for commit 043e0dc

@tensorrt-cicd (Collaborator)

PR_Github #5637 Bot args parsing error: usage: /bot [-h]
{run,kill,skip,submit,reviewers,reuse-pipeline,reuse-review} ...
/bot: error: unrecognized arguments: --post-merge

@jhaotingc (Collaborator Author)

/bot run --disable-fail-fast --post-merge

@tensorrt-cicd (Collaborator)

PR_Github #5654 [ run ] triggered by Bot

@jhaotingc (Collaborator Author)

/bot kill

Signed-off-by: Jhao-Ting Chen <[email protected]>
@jhaotingc jhaotingc requested review from lowsfer and symphonylyh May 19, 2025 03:37
@jhaotingc (Collaborator Author)

/bot kill --post-merge

@jhaotingc (Collaborator Author)

/bot run --disable-fail-fast --post-merge

@tensorrt-cicd (Collaborator)

PR_Github #5667 Bot args parsing error: usage: /bot [-h]
{run,kill,skip,submit,reviewers,reuse-pipeline,reuse-review} ...
/bot: error: unrecognized arguments: --post-merge

@tensorrt-cicd (Collaborator)

PR_Github #5668 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #5654 [ run ] completed with state ABORTED

@lowsfer (Member) commented May 19, 2025

Good work! The gitlab/ftp/xqa content is already open-sourced in the trtllm repo, so please update the link.

TLLM_CHECK_WITH_INFO(batch_size <= 32, "Multiblock tuning should be for only batch size <= 32");

int num_kv_heads = xqaParams.num_kv_heads;
int history_length = xqaParams.max_past_kv_length;
Review comment (Member):

One concern about this is that the KV cache length is non-uniform within a batch, and this value is the max. That's why the previous heuristic did not use it. Did you try this PR on a real end-to-end workload?

Reply (Collaborator Author):

You're right, I only tested the uniform-ISL case. More testing might be needed.

Reply (Member):

Maybe we can try a few options: 1. max; 2. max * factor (factor < 1); 3. mean seqlen, and see which works best.

Signed-off-by: Jhao-Ting Chen <[email protected]>
@jhaotingc (Collaborator Author)

/bot run --disable-fail-fast --post-merge

@tensorrt-cicd (Collaborator)

PR_Github #5674 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #5668 [ run ] completed with state ABORTED

@tensorrt-cicd (Collaborator)

PR_Github #5674 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #4145 completed with status: 'FAILURE'

@jhaotingc (Collaborator Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #5751 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #5751 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #4206 completed with status: 'SUCCESS'

@jhaotingc jhaotingc changed the title [feat] Multi-block mode for Hopper spec dec XQA kernel [draft] [feat] Multi-block mode for Hopper spec dec XQA kernel May 20, 2025
3 participants