
Commit d20ab7d

add heuristic for xqa hopper spec-dec kernel, add test for fp8 llama3 ckpt
Signed-off-by: Jhao-Ting Chen <[email protected]>
1 parent 039f7e3 commit d20ab7d

File tree

5 files changed: +110 additions, -5 deletions

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp
cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp
tests/integration/defs/examples/test_eagle.py
tests/integration/test_lists/test-db/l0_h100.yml

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h

Lines changed: 79 additions & 0 deletions
@@ -368,5 +368,84 @@ inline int computeMultiBlockCount(XQAParams const& xqaParams, int batch_size, in
     return multi_block_count;
 }
 
+inline int computeMultiBlockCountSpecDecGMMA(
+    XQAParams const& xqaParams, int batch_size, int multiprocessor_count, int specDecBlocks)
+{
+    auto const userSpecified = tensorrt_llm::common::getEnvXqaBlocksPerSequence();
+    if (userSpecified.has_value())
+    {
+        return userSpecified.value();
+    }
+    int multi_block_count = 1;
+
+    // Skip large batch sizes.
+    TLLM_CHECK_WITH_INFO(batch_size <= 32, "Multi-block tuning should only be used for batch size <= 32");
+
+    int num_kv_heads = xqaParams.num_kv_heads;
+    int history_length = xqaParams.max_past_kv_length;
+
+    // gridDim = dim3{specDecBlocks, multi_block, nbKVHeads * xqaParams.batch_size}
+    int single_block_count = specDecBlocks * num_kv_heads * batch_size;
+    double wave_count = (double) single_block_count / (double) multiprocessor_count;
+
+    // Multi-block tuning for low CTA counts: populate CTAs up to at most 1 wave of SMs.
+    if (wave_count < 1)
+    {
+        auto highestPowerof2 = [](int x)
+        {
+            x |= x >> 1;
+            x |= x >> 2;
+            x |= x >> 4;
+            x |= x >> 8;
+            x |= x >> 16;
+            return x ^ (x >> 1);
+        };
+
+        // Calculate the maximum number of blocks that still populates at most 1 wave.
+        multi_block_count = floor(multiprocessor_count / single_block_count);
+        // Make multi_block_count a power of 2 for tuning convenience.
+        multi_block_count = highestPowerof2(multi_block_count);
+        // Make multi_block_count at most 64 and at least 1.
+        multi_block_count = std::min(multi_block_count, 64);
+        multi_block_count = std::max(multi_block_count, 1);
+
+        // Tune only when the original CTA count is too small, multi_block_count is too big,
+        // and history length < 2^16.
+        // For Hopper, most parts have 114, 132, or 144 SMs; H20 has about 78.
+        // single_block_count = [1..8]
+        // multi_block_count = [16,32,64,128]
+        // history_length = [1024..65536]
+        if (single_block_count <= 8 && multi_block_count >= 16 && history_length < 65536)
+        {
+            if (history_length <= 1024)
+            {
+                // For history length <= 1024 and low CTA counts, scaling is not effective,
+                // so we set a hard limit of multi_block_count = 4.
+                multi_block_count = std::min(multi_block_count, 4);
+            }
+            else if (history_length < 65536)
+            {
+                // At single_block_count == 8, multi_block_count can only be 16 (SM / 8 ~= 16).
+                // Tune only 1024 < kvlen <= 8192.
+                if (single_block_count == 8 && history_length <= 8192)
+                {
+                    multi_block_count >>= 1;
+                }
+                else
+                {
+                    auto getLog2 = [](int x) { return x ? 31 - __builtin_clz(x) : -1; };
+                    auto history_length_log2 = getLog2(history_length);
+                    multi_block_count >>= 3 - (history_length_log2 - 10) / 2;
+                    // 2^15 (< 65536)      -> shift 1
+                    // 2^13, 2^14          -> shift 2
+                    // 2^11, 2^12 (> 1024) -> shift 3
+                }
+            }
+        }
+        TLLM_CHECK_WITH_INFO((multi_block_count * single_block_count) <= multiprocessor_count,
+            "The adjusted multi-block count exceeds the number of SMs; adding an additional wave may cause a perf drop.");
+    }
+    return multi_block_count;
+}
+
 } // namespace kernels
 } // namespace tensorrt_llm
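
For intuition, here is a standalone sketch (not part of this commit) that mirrors the heuristic's arithmetic for a hypothetical request shape: an assumed 132-SM Hopper part, batch size 1, 8 KV heads, and one spec-dec block. All helper names and sample values below are illustrative assumptions, not code or numbers from the change.

// Hypothetical standalone sketch mirroring computeMultiBlockCountSpecDecGMMA; values are assumed.
#include <algorithm>
#include <cstdio>

// Same power-of-2 rounding as the lambda in the kernel heuristic.
static int highestPowerOf2(int x)
{
    x |= x >> 1;
    x |= x >> 2;
    x |= x >> 4;
    x |= x >> 8;
    x |= x >> 16;
    return x ^ (x >> 1);
}

static int pickMultiBlock(int smCount, int specDecBlocks, int numKvHeads, int batchSize, int historyLength)
{
    int singleBlockCount = specDecBlocks * numKvHeads * batchSize; // CTAs with multi_block == 1
    if (singleBlockCount >= smCount)
    {
        return 1; // already at least one full wave of SMs
    }
    int multiBlock = highestPowerOf2(smCount / singleBlockCount); // fill at most one wave
    multiBlock = std::min(std::max(multiBlock, 1), 64);
    if (singleBlockCount <= 8 && multiBlock >= 16 && historyLength < 65536)
    {
        if (historyLength <= 1024)
        {
            multiBlock = std::min(multiBlock, 4); // scaling is not effective for short history
        }
        else if (singleBlockCount == 8 && historyLength <= 8192)
        {
            multiBlock >>= 1;
        }
        else
        {
            int log2Len = 31 - __builtin_clz(historyLength); // GCC/Clang builtin, as in the kernel code
            multiBlock >>= 3 - (log2Len - 10) / 2;           // 2^11..2^12 -> 3, 2^13..2^14 -> 2, 2^15 -> 1
        }
    }
    return multiBlock;
}

int main()
{
    int const histories[] = {512, 2048, 8192, 32768}; // assumed sample history lengths
    for (int historyLength : histories)
    {
        // Assumed shape: 132 SMs, 1 spec-dec block, 8 KV heads, batch size 1.
        printf("history=%6d -> multi_block=%d\n", historyLength, pickMultiBlock(132, 1, 8, 1, historyLength));
    }
    return 0;
}

With these assumed inputs the sketch prints 4, 8, 8, and 8, i.e. 32 or 64 CTAs instead of 8, still within a single wave.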

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp

Lines changed: 6 additions & 0 deletions
@@ -380,6 +380,12 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
     {
         multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count);
     }
+    // A WAR to enable Hopper XQA multi-token multi_block mode for low batch size
+    if (isSpecDec && isGMMAKernel && xqaParams.batch_size <= 32)
+    {
+        multi_block
+            = computeMultiBlockCountSpecDecGMMA(xqaParams, xqaParams.batch_size, multiprocessor_count, specDecBlocks);
+    }
     uint32_t const nbKVHeads = xqaParams.num_kv_heads;
     auto const gridDim = (isGMMAKernel ? dim3{specDecBlocks, multi_block, nbKVHeads * xqaParams.batch_size}
                                        : dim3{multi_block, nbKVHeads, xqaParams.batch_size});
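
To see why this WAR matters at low batch size, the short comparison below uses assumed values only (132 SMs, one spec-dec block, 8 KV heads, batch size 1, and a heuristic pick of multi_block = 8; none of these numbers come from the commit) to contrast the GMMA grid dim3{specDecBlocks, multi_block, nbKVHeads * batch_size} with and without the extra multi_block dimension.

// Illustrative only: CTA count and wave occupancy for the GMMA spec-dec grid, with assumed values.
#include <cstdio>

int main()
{
    int const smCount = 132;     // assumed H100 SM count
    int const specDecBlocks = 1; // assumed spec-dec blocks per sequence
    int const nbKVHeads = 8;     // assumed number of KV heads
    int const batchSize = 1;

    int const ctasWithoutWar = specDecBlocks * 1 * nbKVHeads * batchSize; // multi_block == 1
    int const ctasWithWar = specDecBlocks * 8 * nbKVHeads * batchSize;    // assumed heuristic pick of 8

    printf("multi_block=1: %3d CTAs (%.0f%% of one wave)\n", ctasWithoutWar, 100.0 * ctasWithoutWar / smCount);
    printf("multi_block=8: %3d CTAs (%.0f%% of one wave)\n", ctasWithWar, 100.0 * ctasWithWar / smCount);
    return 0;
}

Splitting the KV history across multi_block CTAs raises occupancy from roughly 6% to roughly 48% of one wave in this assumed case, which is the low-batch gap the heuristic targets.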

cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp

Lines changed: 1 addition & 0 deletions
@@ -705,6 +705,7 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
         mUseSpecDecoding = useSpecDecoding;
         // change mMultiBlockMode to default
         mMultiBlockMode = mUseSpecDecoding ? false : true;
+        // if the Hopper XQA kernel is enabled, multi-block mode will be turned on in decoderXQAImplJIT::runImpl
     }
 
     [[maybe_unused]] MlaParams<T> mla_params;

tests/integration/defs/examples/test_eagle.py

Lines changed: 22 additions & 4 deletions
@@ -96,14 +96,16 @@ def test_llm_eagle_1gpu(batch_size, data_type, use_dynamic_tree,
 # TODO: remove skip_post_blackwell after Speculative decoding is supported.
 @skip_post_blackwell
 @skip_pre_ada
+@pytest.mark.parametrize("use_dynamic_tree", [False, True],
+                         ids=['eagle1', 'eagle2'])
 @pytest.mark.parametrize("batch_size", [8], ids=['bs8'])
 @pytest.mark.parametrize("data_type", ['float16'])
 @pytest.mark.parametrize("eagle_model_roots", ["llama3.1-eagle-8b-hf_v0.5"],
                          indirect=True)
-def test_llm_eagle_1gpu_modelopt_ckpt(batch_size, data_type, eagle_model_roots,
-                                      eagle_example_root, llm_datasets_root,
-                                      llm_rouge_root, llm_venv, cmodel_dir,
-                                      engine_dir):
+def test_llm_eagle_1gpu_modelopt_ckpt(batch_size, data_type, use_dynamic_tree,
+                                      eagle_model_roots, eagle_example_root,
+                                      llm_datasets_root, llm_rouge_root,
+                                      llm_venv, cmodel_dir, engine_dir):
     print("Build engines...")
     model_name = "eagle"

@@ -141,6 +143,22 @@ def test_llm_eagle_1gpu_modelopt_ckpt(batch_size, data_type, eagle_model_roots,
 
     venv_check_call(llm_venv, run_cmd)
 
+    print("Run summarize...")
+    summary_cmd = [
+        f"{eagle_example_root}/../summarize.py", "--test_trt_llm",
+        "--hf_model_dir", f"{eagle_model_roots}", "--tokenizer_dir",
+        f"{eagle_model_roots}", f"--engine_dir={engine_dir}",
+        "--check_accuracy", "--tensorrt_llm_rouge1_threshold=24",
+        "--eagle_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]",
+        f"--max_ite=40", f"--batch_size={batch_size}",
+        f"--dataset_dir={llm_datasets_root}", f"--rouge_dir={llm_rouge_root}"
+    ]
+    if use_dynamic_tree:
+        summary_cmd.extend(
+            [f"--eagle_dynamic_tree_max_top_k={3}", "--eagle_use_dynamic_tree"])
+
+    venv_check_call(llm_venv, summary_cmd)
+
 
 def test_with_dummy_eagle(hf_model_root,
                           use_dynamic_tree,

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 2 additions & 1 deletion
@@ -220,7 +220,8 @@ l0_h100:
       backend: tensorrt
   tests:
   # ------------- TRT tests ---------------
-  - examples/test_eagle.py::test_llm_eagle_1gpu_modelopt_ckpt[llama3.1-eagle-8b-hf_v0.5-float16-bs8] # 9 mins
+  - examples/test_eagle.py::test_llm_eagle_1gpu_modelopt_ckpt[llama3.1-eagle-8b-hf_v0.5-float16-bs8-eagle1] # 9 mins
+  - examples/test_eagle.py::test_llm_eagle_1gpu_modelopt_ckpt[llama3.1-eagle-8b-hf_v0.5-float16-bs8-eagle2] # 9 mins
   - examples/test_eagle.py::test_llm_eagle_1gpu[EAGLE-Vicuna-7B-v1.3-float16-bs1-eagle1]
   - examples/test_eagle.py::test_llm_eagle_1gpu[EAGLE-Vicuna-7B-v1.3-float16-bs1-eagle2] # 5 mins
   - accuracy/test_llm_api.py::TestMistral_NeMo_Minitron_8B_Instruct::test_fp8
