Commit e4f79be

add chunked-attention kernels for blackwell
Signed-off-by: Perkz Zheng <[email protected]>
1 parent 414eca0 commit e4f79be

6 files changed: +57 -22 lines

cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 21 additions & 13 deletions

@@ -685,6 +685,9 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
     size_t const cu_seqlens_size = sizeof(int) * (batch_size + 1);
     size_t const rotary_inv_freq_size = sizeof(float) * batch_size * mRotaryEmbeddingDim / 2;
 
+    // Does the fmha kernel need FP8 inputQ?
+    bool isFP8InputQ = mFmhaDispatcher->isFP8InputQ();
+
     size_t q_buf_2_size = 0;
     if (!mEnableContextFMHA)
     {
@@ -694,7 +697,7 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
     else if (mFmhaDispatcher->isSeparateQAndKvInput())
     {
         // Paged context fmha
-        q_buf_2_size = (mFP8ContextFMHA ? 1 : size) * max_num_tokens * local_hidden_units_qo;
+        q_buf_2_size = (isFP8InputQ ? 1 : size) * max_num_tokens * local_hidden_units_qo;
     }
 
     size_t const k_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * kv_seq_length * local_hidden_units_kv;
@@ -704,17 +707,16 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
     size_t const qkv_buf_2_size = mEnableContextFMHA ? 0 : size * max_num_tokens * local_hidden_units_qo;
     size_t const qk_buf_float_size
         = mEnableContextFMHA ? 0 : sizeof(float) * batch_size * mNumHeads * input_seq_length * kv_seq_length;
-    size_t const fp8_qkv_buffer_size
-        = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
+    size_t const fp8_qkv_buffer_size = isFP8InputQ && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
         ? max_num_tokens * size_t(local_hidden_units_qo + 2 * local_hidden_units_kv)
         : 0;
     size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
     size_t const encoder_padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
     // Each token holds (batch_idx, token_idx_in_seq) int2.
     size_t const tokens_info_size = sizeof(int2) * max_num_tokens;
     size_t const fmha_scheduler_counter = mEnableContextFMHA ? sizeof(uint32_t) : 0;
-    size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0;
-    size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0;
+    size_t const fmha_bmm1_scale_size = isFP8InputQ ? sizeof(float) * 2 : 0;
+    size_t const fmha_bmm2_scale_size = isFP8InputQ ? sizeof(float) : 0;
 
     // cp workspace size upper bound
     size_t const cpMaxPaddedSequenceLength = max_num_tokens + batch_size * (mCpSize - 1);
@@ -1254,6 +1256,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         = mEnableContextFMHA ? 0 : sizeof(T) * params.batch_size * params.input_seq_length * kv_seq_length;
     size_t const cu_seqlens_size = sizeof(int) * (params.batch_size + 1);
     size_t const rotary_inv_freq_size = sizeof(float) * params.batch_size * mRotaryEmbeddingDim / 2;
+    bool const isFP8InputQ = mFmhaDispatcher->isFP8InputQ();
     size_t q_buf_2_size = 0;
     if (!mEnableContextFMHA)
     {
@@ -1263,7 +1266,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     else if (mFmhaDispatcher->isSeparateQAndKvInput())
     {
         // Paged context fmha
-        q_buf_2_size = (mFP8ContextFMHA ? 1 : sizeof(T)) * params.num_tokens * local_hidden_units_qo;
+        q_buf_2_size = (isFP8InputQ ? 1 : sizeof(T)) * params.num_tokens * local_hidden_units_qo;
     }
 
     size_t const k_buf_2_size
@@ -1277,8 +1280,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     size_t const qk_buf_float_size = mEnableContextFMHA
         ? 0
         : sizeof(float) * params.batch_size * mNumHeads * params.input_seq_length * kv_seq_length;
-    size_t const fp8_qkv_buffer_size
-        = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
+    size_t const fp8_qkv_buffer_size = mEnableContextFMHA && isFP8InputQ && !mFmhaDispatcher->isSeparateQAndKvInput()
         ? params.num_tokens * (local_hidden_units_qo + 2 * local_hidden_units_kv)
         : 0;
     size_t const padding_offset_size
@@ -1288,8 +1290,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     // Each token holds (batch_idx, token_idx_in_seq) int2.
     size_t const tokens_info_size = sizeof(int2) * params.num_tokens;
     size_t const fmha_scheduler_counter = mEnableContextFMHA ? sizeof(uint32_t) : 0;
-    size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0;
-    size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0;
+    size_t const fmha_bmm1_scale_size = isFP8InputQ ? sizeof(float) * 2 : 0;
+    size_t const fmha_bmm2_scale_size = isFP8InputQ ? sizeof(float) : 0;
 
     // cp workspace size upper bound
     size_t const cpMaxPadedSequenceLength = params.num_tokens + params.batch_size * (mCpSize - 1);
@@ -1514,7 +1516,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     preprocessingParams.position_shift_enabled = mPosShiftEnabled;
     preprocessingParams.cache_type = cache_type;
     preprocessingParams.separate_q_kv_output = enablePagedKVContextFMHA || isCrossAttention();
-    preprocessingParams.quantized_fp8_output = mFP8ContextFMHA;
+    preprocessingParams.quantized_fp8_output = isFP8InputQ;
     preprocessingParams.generation_phase = false;
     preprocessingParams.multi_processor_count = mMultiProcessorCount;
 
@@ -1614,8 +1616,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     // TODO: set it correctly for contiguous kv buffer (cross-attention).
     fmhaParams.totalKvSeqLen = isCrossAttention() ? params.num_encoder_tokens : params.num_tokens;
     // Device buffer pointers.
-    fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast<void const*>(fp8_qkv_buffer)
-                                        : reinterpret_cast<void const*>(attention_input);
+    fmhaParams.qkvPtr = isFP8InputQ ? reinterpret_cast<void const*>(fp8_qkv_buffer)
+                                    : reinterpret_cast<void const*>(attention_input);
     fmhaParams.qPtr = reinterpret_cast<void const*>(q_buf_2_);
     // TODO: add contiguous kv buffer (cross-attention).
     fmhaParams.kvPtr = nullptr;
@@ -2423,6 +2425,12 @@ int AttentionOp::initialize() noexcept
         if (mKVCacheQuantMode.hasFp8KvCache())
         {
             fmhaParams.dataTypeKv = DATA_TYPE_E4M3;
+            // Trtllm-gen kernels force using FP8 FMHA kernels if the FP8 KV cache is used.
+            // FP8 Q/KV input and FP8/BF16/FP16 output are supported.
+            if (mUseTllmGen)
+            {
+                fmhaParams.dataType = DATA_TYPE_E4M3;
+            }
        }
         // TODO: add FP4 KV cache support.
     }
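
The workspace math above now keys the element size of the packed Q buffer, the FP8 QKV staging buffer, and the BMM1/BMM2 scale buffers on whether the selected FMHA kernel actually consumes FP8 Q (isFP8InputQ), not on mFP8ContextFMHA alone; with trtllm-gen forcing DATA_TYPE_E4M3 whenever the KV cache is FP8, the two flags can now disagree. A minimal sketch of that sizing rule for the context-FMHA path, using illustrative names rather than the real AttentionOp members:

#include <cstddef>

// Sketch only: mirrors the ternaries above for the context-FMHA path.
struct FmhaWorkspaceSizes
{
    size_t qBuf2;      // packed Q written by the QKV preprocessing kernel
    size_t fp8Qkv;     // FP8 staging copy of the fused QKV tensor (packed layout only)
    size_t bmm1Scales; // two device floats, matching fmha_bmm1_scale_size
    size_t bmm2Scale;  // one device float, matching fmha_bmm2_scale_size
};

FmhaWorkspaceSizes computeFmhaWorkspaceSizes(size_t elemSize, // sizeof(T) of the activation type
    size_t numTokens, size_t hiddenQo, size_t hiddenKv, bool separateQAndKvInput, bool isFP8InputQ)
{
    FmhaWorkspaceSizes ws{};
    // FP8 Q input means 1 byte per element in the separate-Q (paged context) buffer.
    ws.qBuf2 = separateQAndKvInput ? (isFP8InputQ ? 1 : elemSize) * numTokens * hiddenQo : 0;
    // The packed-QKV path instead quantizes the whole fused tensor into an FP8 buffer.
    ws.fp8Qkv = (isFP8InputQ && !separateQAndKvInput) ? numTokens * (hiddenQo + 2 * hiddenKv) : 0;
    // Device-side scales are only needed when the kernel runs in FP8.
    ws.bmm1Scales = isFP8InputQ ? sizeof(float) * 2 : 0;
    ws.bmm2Scale = isFP8InputQ ? sizeof(float) : 0;
    return ws;
}

The same predicate also drives preprocessingParams.quantized_fp8_output and the choice between fp8_qkv_buffer and attention_input for fmhaParams.qkvPtr, so buffer sizing and pointer selection cannot drift apart.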

cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp

Lines changed: 2 additions & 0 deletions

@@ -183,6 +183,8 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
     tllmRunnerParams.mMaxSeqLenQ = runnerParams.qSeqLen;
     tllmRunnerParams.mMaxSeqLenKv = runnerParams.kvSeqLen;
     tllmRunnerParams.mAttentionWindowSize = runnerParams.slidingWindowSize;
+    // Set chunked attention size to INT_MAX to disable chunked attention for now.
+    tllmRunnerParams.mChunkedAttentionSize = INT_MAX;
     tllmRunnerParams.mSumOfSeqLensQ = runnerParams.totalQSeqLen;
     tllmRunnerParams.mSumOfSeqLensKv = runnerParams.totalKvSeqLen;
     tllmRunnerParams.mMaxNumPagesPerSeqKv = maxBlocksPerSeq;
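
Pinning mChunkedAttentionSize to INT_MAX keeps the new parameter plumbed through while leaving the feature inert: the mask-selection check in fmhaKernels.h only promotes a causal mask when mMaxSeqLenKv exceeds the chunk size, which an int sequence length can never do against INT_MAX. A hedged sketch of how a runner could opt in later; the 8192 chunk length is a made-up example and this helper is not an existing TensorRT-LLM switch:

#include <climits>
// Assumes fmhaRunnerParams.h is included so TllmGenFmhaRunnerParams is visible.

void configureChunkedAttention(TllmGenFmhaRunnerParams& tllmRunnerParams, bool enableChunkedAttention)
{
    // INT_MAX disables chunked attention; any power-of-two chunk length would enable it
    // once the rest of the stack stops hard-coding INT_MAX here.
    tllmRunnerParams.mChunkedAttentionSize = enableChunkedAttention ? 8192 : INT_MAX;
}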

cpp/tensorrt_llm/kernels/fmhaDispatcher.h

Lines changed: 6 additions & 0 deletions

@@ -37,6 +37,12 @@ class FmhaDispatcher
     // Destructor.
     ~FmhaDispatcher() = default;
 
+    // Does the fmha kernel need FP8 inputQ?
+    bool isFP8InputQ() const
+    {
+        return mFixedParams.dataType == DATA_TYPE_E4M3;
+    }
+
     // Check if any fmha kernel meets the requirements.
     bool isSupported();
 
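
isFP8InputQ() reports whether the FMHA kernel that will actually run expects FP8 (E4M3) Q input, which is what the AttentionOp workspace and pointer logic above now consults. A tiny self-contained sketch of the idea; the enum here is a placeholder for the real Data_type used by mFixedParams.dataType:

#include <cassert>

// Placeholder for the real Data_type enum from the fused MHA headers.
enum Data_type { DATA_TYPE_FP16, DATA_TYPE_BF16, DATA_TYPE_E4M3 };

struct FmhaDispatcherSketch
{
    struct { Data_type dataType; } mFixedParams;
    bool isFP8InputQ() const { return mFixedParams.dataType == DATA_TYPE_E4M3; }
};

int main()
{
    FmhaDispatcherSketch dispatcher;
    // E4M3 can come from an explicit FP8 context FMHA request, or (after this commit)
    // from the trtllm-gen path forcing FP8 FMHA whenever the KV cache is FP8.
    dispatcher.mFixedParams.dataType = DATA_TYPE_E4M3;
    assert(dispatcher.isFP8InputQ());
    return 0;
}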

cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h

Lines changed: 6 additions & 3 deletions

@@ -465,10 +465,13 @@ class TllmGenFmhaKernel
 
     // The mask type.
     TrtllmGenAttentionMaskType maskType = params.mMaskType;
-    // Enable sliding window causal if the max kv sequence length exceeds attention window size.
-    if (params.mAttentionWindowSize < params.mMaxSeqLenKv && maskType == TrtllmGenAttentionMaskType::Causal)
+    // Enable sliding window or chunked causal if the max kv sequence length exceeds attention window size or
+    // chunked attention size.
+    if (maskType == TrtllmGenAttentionMaskType::Causal
+        && (params.mMaxSeqLenKv > params.mAttentionWindowSize
+            || params.mMaxSeqLenKv > params.mChunkedAttentionSize))
     {
-        maskType = TrtllmGenAttentionMaskType::SlidingWindowCausal;
+        maskType = TrtllmGenAttentionMaskType::SlidingOrChunkedCausal;
     }
     // NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels.
     int numTokensPerPage = (!isPagedKv(params.mQkvLayout)) ? 0 : params.mNumTokensPerPage;
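
The promotion rule is small but worth spelling out: a plain Causal request becomes SlidingOrChunkedCausal as soon as the max KV length exceeds either the attention window or the chunk size, and since the dispatcher currently passes mChunkedAttentionSize = INT_MAX, only the sliding-window condition can fire, preserving the previous behaviour. A distilled sketch of the rule (assumes the enum from fmhaRunnerParams.h; not the actual selection code):

// Assumes TrtllmGenAttentionMaskType from fmhaRunnerParams.h.
TrtllmGenAttentionMaskType selectMaskType(
    TrtllmGenAttentionMaskType requested, int maxSeqLenKv, int attentionWindowSize, int chunkedAttentionSize)
{
    if (requested == TrtllmGenAttentionMaskType::Causal
        && (maxSeqLenKv > attentionWindowSize || maxSeqLenKv > chunkedAttentionSize))
    {
        return TrtllmGenAttentionMaskType::SlidingOrChunkedCausal;
    }
    return requested;
}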

cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h

Lines changed: 9 additions & 6 deletions

@@ -32,8 +32,8 @@ enum class TrtllmGenAttentionMaskType
     Dense = 0,
     // Causal mask.
     Causal,
-    // Sliding window causal mask.
-    SlidingWindowCausal,
+    // Sliding window or chunked causal mask.
+    SlidingOrChunkedCausal,
     // Custom mask.
     Custom
 };
@@ -50,7 +50,7 @@ enum class TrtllmGenAttentionMaskType
 
 ATTENTION_MASK_TYPE_FUNCTION(Dense)
 ATTENTION_MASK_TYPE_FUNCTION(Causal)
-ATTENTION_MASK_TYPE_FUNCTION(SlidingWindowCausal)
+ATTENTION_MASK_TYPE_FUNCTION(SlidingOrChunkedCausal)
 ATTENTION_MASK_TYPE_FUNCTION(Custom)
 
 #undef ATTENTION_MASK_TYPE_FUNCTION
@@ -246,8 +246,11 @@ struct TllmGenFmhaRunnerParams
     int mMaxSeqLenQ;
     // The max kv sequence length.
     int mMaxSeqLenKv;
-    // The attention window size for sliding window attention.
+    // The attention window size for sliding window attention (sliding-window-attention is enabled when seqLenKv >
+    // mAttentionWindowSize).
     int mAttentionWindowSize;
+    // The chunked attention size (chunked-context is enabled when seqLenKv > mChunkedAttentionSize).
+    int mChunkedAttentionSize;
     // The sum of sequence lengths for Q and K/V. (Only used when mSupportsVarSeqLens = true)
     int mSumOfSeqLensQ;
     int mSumOfSeqLensKv;
@@ -283,8 +286,8 @@ struct TllmGenFmhaRunnerParams
     case 1: // tensorrt_llm::kernels::ContextAttentionMaskType::CAUSAL
         mMaskType = TrtllmGenAttentionMaskType::Causal;
         break;
-    case 2: // tensorrt_llm::kernels::ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL
-        mMaskType = TrtllmGenAttentionMaskType::SlidingWindowCausal;
+    case 2: // tensorrt_llm::kernels::ContextAttentionMaskType::SLIDING_OR_CHUNKED_CAUSAL
+        mMaskType = TrtllmGenAttentionMaskType::SlidingOrChunkedCausal;
         break;
     case 3: // tensorrt_llm::kernels::ContextAttentionMaskType::CUSTOM_MASK
         mMaskType = TrtllmGenAttentionMaskType::Custom;
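
The renamed SlidingOrChunkedCausal value covers two window shapes that the comments above describe only by their trigger conditions (seqLenKv > mAttentionWindowSize, seqLenKv > mChunkedAttentionSize). Under the conventional definitions, which is an assumption about semantics rather than something stated in this diff, sliding window keeps the most recent window of KV positions while chunked attention keeps only positions in the same fixed-size chunk as the query. A small illustration:

#include <cstdint>

// Illustrative only: first KV position a query at 0-based position q may attend to
// (the last is always q under a causal mask).
int64_t firstAttendableKv(int64_t q, bool chunked, int64_t windowOrChunkSize)
{
    if (chunked)
    {
        // Chunked causal: attend only within the chunk containing q.
        return (q / windowOrChunkSize) * windowOrChunkSize;
    }
    // Sliding-window causal: attend to the last windowOrChunkSize positions.
    int64_t start = q - windowOrChunkSize + 1;
    return start < 0 ? 0 : start;
}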

cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h

Lines changed: 13 additions & 0 deletions

@@ -108,6 +108,8 @@ struct KernelParams
     int32_t mAttentionWindowSize;
     // The batch size
     int32_t mBatchSize;
+    // The chunked attention size in log2.
+    int32_t mChunkedAttentionSizeLog2;
     // The log of the Sage Attention block size for K.
     int32_t mLogNumEltsPerSageAttnBlkK;
     // The log of the Sage Attention block size for P.
@@ -741,6 +743,17 @@ struct KernelParams
     params.ptrSoftmaxStats = options.softmaxStatsPtr;
 
     params.mAttentionWindowSize = options.mAttentionWindowSize;
+    if (isSlidingOrChunkedCausalMask(options.mMaskType) && options.mMaxSeqLenKv > options.mChunkedAttentionSize)
+    {
+        TLLM_CHECK_WITH_INFO((options.mChunkedAttentionSize & (options.mChunkedAttentionSize - 1)) == 0,
+            "Chunked attention size must be a power of 2");
+        params.mChunkedAttentionSizeLog2 = std::log2(options.mChunkedAttentionSize);
+    }
+    else
+    {
+        // Default 0 means that chunked attention is disabled.
+        params.mChunkedAttentionSizeLog2 = 0;
+    }
     params.mMaxSeqLenQ = options.mMaxSeqLenQ;
     params.mMaxSeqLenKv = options.mMaxSeqLenKv;
     params.mMaxNumCtasQ = maxNumCtasQ;
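
Storing the chunk size as a log2 only works because of the power-of-two check above, and it lets device code turn chunk arithmetic into shifts. A short sketch of the conversion and of the kind of per-token math the log2 form enables; the shift-based chunk-start computation is an assumption about how the kernels might use mChunkedAttentionSizeLog2, not code taken from them:

#include <cassert>
#include <cstdint>

// Host-side conversion mirroring the TLLM_CHECK_WITH_INFO + std::log2 logic above.
int32_t toChunkedAttentionSizeLog2(int32_t chunkedAttentionSize)
{
    assert(chunkedAttentionSize > 0 && (chunkedAttentionSize & (chunkedAttentionSize - 1)) == 0
        && "Chunked attention size must be a power of 2");
    int32_t log2Size = 0;
    while ((1 << log2Size) < chunkedAttentionSize)
    {
        ++log2Size;
    }
    return log2Size;
}

// With a power-of-two chunk size, the chunk start of any token reduces to two shifts,
// which is cheap to evaluate per token on the GPU.
int32_t chunkStartOfToken(int32_t tokenIdx, int32_t chunkSizeLog2)
{
    return (tokenIdx >> chunkSizeLog2) << chunkSizeLog2;
}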
