@@ -685,6 +685,9 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
     size_t const cu_seqlens_size = sizeof(int) * (batch_size + 1);
     size_t const rotary_inv_freq_size = sizeof(float) * batch_size * mRotaryEmbeddingDim / 2;
 
+    // Does the fmha kernel need FP8 inputQ?
+    bool isFP8InputQ = mFmhaDispatcher->isFP8InputQ();
+
     size_t q_buf_2_size = 0;
     if (!mEnableContextFMHA)
     {
@@ -694,7 +697,7 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
     else if (mFmhaDispatcher->isSeparateQAndKvInput())
     {
         // Paged context fmha
-        q_buf_2_size = (mFP8ContextFMHA ? 1 : size) * max_num_tokens * local_hidden_units_qo;
+        q_buf_2_size = (isFP8InputQ ? 1 : size) * max_num_tokens * local_hidden_units_qo;
     }
 
     size_t const k_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * kv_seq_length * local_hidden_units_kv;
@@ -704,17 +707,16 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
     size_t const qkv_buf_2_size = mEnableContextFMHA ? 0 : size * max_num_tokens * local_hidden_units_qo;
     size_t const qk_buf_float_size
         = mEnableContextFMHA ? 0 : sizeof(float) * batch_size * mNumHeads * input_seq_length * kv_seq_length;
-    size_t const fp8_qkv_buffer_size
-        = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
+    size_t const fp8_qkv_buffer_size = isFP8InputQ && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
         ? max_num_tokens * size_t(local_hidden_units_qo + 2 * local_hidden_units_kv)
         : 0;
     size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
     size_t const encoder_padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
     // Each token holds (batch_idx, token_idx_in_seq) int2.
     size_t const tokens_info_size = sizeof(int2) * max_num_tokens;
     size_t const fmha_scheduler_counter = mEnableContextFMHA ? sizeof(uint32_t) : 0;
-    size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0;
-    size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0;
+    size_t const fmha_bmm1_scale_size = isFP8InputQ ? sizeof(float) * 2 : 0;
+    size_t const fmha_bmm2_scale_size = isFP8InputQ ? sizeof(float) : 0;
 
     // cp workspace size upper bound
     size_t const cpMaxPaddedSequenceLength = max_num_tokens + batch_size * (mCpSize - 1);
@@ -1254,6 +1256,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         = mEnableContextFMHA ? 0 : sizeof(T) * params.batch_size * params.input_seq_length * kv_seq_length;
     size_t const cu_seqlens_size = sizeof(int) * (params.batch_size + 1);
     size_t const rotary_inv_freq_size = sizeof(float) * params.batch_size * mRotaryEmbeddingDim / 2;
+    bool const isFP8InputQ = mFmhaDispatcher->isFP8InputQ();
     size_t q_buf_2_size = 0;
     if (!mEnableContextFMHA)
     {
@@ -1263,7 +1266,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     else if (mFmhaDispatcher->isSeparateQAndKvInput())
     {
         // Paged context fmha
-        q_buf_2_size = (mFP8ContextFMHA ? 1 : sizeof(T)) * params.num_tokens * local_hidden_units_qo;
+        q_buf_2_size = (isFP8InputQ ? 1 : sizeof(T)) * params.num_tokens * local_hidden_units_qo;
     }
 
     size_t const k_buf_2_size
@@ -1277,8 +1280,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     size_t const qk_buf_float_size = mEnableContextFMHA
         ? 0
         : sizeof(float) * params.batch_size * mNumHeads * params.input_seq_length * kv_seq_length;
-    size_t const fp8_qkv_buffer_size
-        = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
+    size_t const fp8_qkv_buffer_size = mEnableContextFMHA && isFP8InputQ && !mFmhaDispatcher->isSeparateQAndKvInput()
         ? params.num_tokens * (local_hidden_units_qo + 2 * local_hidden_units_kv)
         : 0;
     size_t const padding_offset_size
@@ -1288,8 +1290,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     // Each token holds (batch_idx, token_idx_in_seq) int2.
     size_t const tokens_info_size = sizeof(int2) * params.num_tokens;
     size_t const fmha_scheduler_counter = mEnableContextFMHA ? sizeof(uint32_t) : 0;
-    size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0;
-    size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0;
+    size_t const fmha_bmm1_scale_size = isFP8InputQ ? sizeof(float) * 2 : 0;
+    size_t const fmha_bmm2_scale_size = isFP8InputQ ? sizeof(float) : 0;
 
     // cp workspace size upper bound
     size_t const cpMaxPadedSequenceLength = params.num_tokens + params.batch_size * (mCpSize - 1);
@@ -1514,7 +1516,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     preprocessingParams.position_shift_enabled = mPosShiftEnabled;
     preprocessingParams.cache_type = cache_type;
     preprocessingParams.separate_q_kv_output = enablePagedKVContextFMHA || isCrossAttention();
-    preprocessingParams.quantized_fp8_output = mFP8ContextFMHA;
+    preprocessingParams.quantized_fp8_output = isFP8InputQ;
     preprocessingParams.generation_phase = false;
     preprocessingParams.multi_processor_count = mMultiProcessorCount;
 
@@ -1614,8 +1616,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         // TODO: set it correctly for contiguous kv buffer (cross-attention).
         fmhaParams.totalKvSeqLen = isCrossAttention() ? params.num_encoder_tokens : params.num_tokens;
         // Device buffer pointers.
-        fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast<void const*>(fp8_qkv_buffer)
-                                            : reinterpret_cast<void const*>(attention_input);
+        fmhaParams.qkvPtr = isFP8InputQ ? reinterpret_cast<void const*>(fp8_qkv_buffer)
+                                        : reinterpret_cast<void const*>(attention_input);
         fmhaParams.qPtr = reinterpret_cast<void const*>(q_buf_2_);
         // TODO: add contiguous kv buffer (cross-attention).
         fmhaParams.kvPtr = nullptr;
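
At enqueue time the same flag decides which device buffer feeds the FMHA kernel. A small sketch of that pointer selection, using hypothetical names (not the TensorRT-LLM API):

    #include <cstdint>

    struct FmhaRunParamsSketch
    {
        void const* qkvPtr = nullptr;
    };

    // Sketch only: route the kernel input to the quantized FP8 scratch buffer when
    // the kernel consumes FP8 Q/K/V; otherwise pass the original activations through.
    inline void setQkvPtr(FmhaRunParamsSketch& params, bool isFP8InputQ,
        std::uint8_t const* fp8QkvBuffer, void const* attentionInput)
    {
        params.qkvPtr = isFP8InputQ ? static_cast<void const*>(fp8QkvBuffer) : attentionInput;
    }
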
@@ -2423,6 +2425,12 @@ int AttentionOp::initialize() noexcept
         if (mKVCacheQuantMode.hasFp8KvCache())
         {
             fmhaParams.dataTypeKv = DATA_TYPE_E4M3;
+            // Trtllm-gen kernels force using FP8 FMHA kernels if the FP8 KV cache is used.
+            // FP8 Q/KV input and FP8/BF16/FP16 output are supported.
+            if (mUseTllmGen)
+            {
+                fmhaParams.dataType = DATA_TYPE_E4M3;
+            }
         }
         // TODO: add FP4 KV cache support.
     }
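
The initialize() change encodes the constraint stated in the new comment: with the trtllm-gen backend, an FP8 KV cache forces the FMHA input data type to E4M3 as well, since those kernels only ship FP8 FMHA variants for an FP8 KV cache (output may still be FP8, BF16, or FP16). A minimal sketch of that selection, using hypothetical enum/struct names (not the TensorRT-LLM API):

    // Sketch only: mirror the data-type selection shown in the diff above.
    enum class DataTypeSketch { FP16, BF16, E4M3 };

    struct FmhaSetupSketch
    {
        DataTypeSketch dataType = DataTypeSketch::FP16;   // Q (and fused KV) input type
        DataTypeSketch dataTypeKv = DataTypeSketch::FP16; // KV-cache element type
    };

    void configureKvCacheTypes(FmhaSetupSketch& setup, bool hasFp8KvCache, bool useTllmGen)
    {
        if (hasFp8KvCache)
        {
            setup.dataTypeKv = DataTypeSketch::E4M3;
            // The FP8 KV cache forces FP8 FMHA input when the trtllm-gen backend is used.
            if (useTllmGen)
            {
                setup.dataType = DataTypeSketch::E4M3;
            }
        }
    }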