@@ -694,7 +694,7 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
     else if (mFmhaDispatcher->isSeparateQAndKvInput())
     {
         // Paged context fmha
-        q_buf_2_size = (mFP8ContextFMHA ? 1 : size) * max_num_tokens * local_hidden_units_qo;
+        q_buf_2_size = (isFP8ContextFmhaInput() ? 1 : size) * max_num_tokens * local_hidden_units_qo;
     }
 
     size_t const k_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * kv_seq_length * local_hidden_units_kv;
@@ -705,16 +705,16 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
     size_t const qk_buf_float_size
         = mEnableContextFMHA ? 0 : sizeof(float) * batch_size * mNumHeads * input_seq_length * kv_seq_length;
     size_t const fp8_qkv_buffer_size
-        = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
+        = isFP8ContextFmhaInput() && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
         ? max_num_tokens * size_t(local_hidden_units_qo + 2 * local_hidden_units_kv)
         : 0;
     size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
     size_t const encoder_padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
     // Each token holds (batch_idx, token_idx_in_seq) int2.
     size_t const tokens_info_size = sizeof(int2) * max_num_tokens;
     size_t const fmha_scheduler_counter = mEnableContextFMHA ? sizeof(uint32_t) : 0;
-    size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0;
-    size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0;
+    size_t const fmha_bmm1_scale_size = isFP8ContextFmhaInput() ? sizeof(float) * 2 : 0;
+    size_t const fmha_bmm2_scale_size = isFP8ContextFmhaInput() ? sizeof(float) : 0;
 
     // cp workspace size upper bound
     size_t const cpMaxPaddedSequenceLength = max_num_tokens + batch_size * (mCpSize - 1);
@@ -1254,6 +1254,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         = mEnableContextFMHA ? 0 : sizeof(T) * params.batch_size * params.input_seq_length * kv_seq_length;
     size_t const cu_seqlens_size = sizeof(int) * (params.batch_size + 1);
     size_t const rotary_inv_freq_size = sizeof(float) * params.batch_size * mRotaryEmbeddingDim / 2;
+    TLLM_CHECK_WITH_INFO(isFP8ContextFmhaInput() == mFmhaDispatcher->isFP8InputQ(), "internal error");
     size_t q_buf_2_size = 0;
     if (!mEnableContextFMHA)
     {
@@ -1263,7 +1264,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     else if (mFmhaDispatcher->isSeparateQAndKvInput())
     {
         // Paged context fmha
-        q_buf_2_size = (mFP8ContextFMHA ? 1 : sizeof(T)) * params.num_tokens * local_hidden_units_qo;
+        q_buf_2_size = (isFP8ContextFmhaInput() ? 1 : sizeof(T)) * params.num_tokens * local_hidden_units_qo;
     }
 
     size_t const k_buf_2_size
@@ -1278,7 +1279,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         ? 0
         : sizeof(float) * params.batch_size * mNumHeads * params.input_seq_length * kv_seq_length;
     size_t const fp8_qkv_buffer_size
-        = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
+        = mEnableContextFMHA && isFP8ContextFmhaInput() && !mFmhaDispatcher->isSeparateQAndKvInput()
         ? params.num_tokens * (local_hidden_units_qo + 2 * local_hidden_units_kv)
         : 0;
     size_t const padding_offset_size
@@ -1288,8 +1289,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     // Each token holds (batch_idx, token_idx_in_seq) int2.
     size_t const tokens_info_size = sizeof(int2) * params.num_tokens;
     size_t const fmha_scheduler_counter = mEnableContextFMHA ? sizeof(uint32_t) : 0;
-    size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0;
-    size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0;
+    size_t const fmha_bmm1_scale_size = isFP8ContextFmhaInput() ? sizeof(float) * 2 : 0;
+    size_t const fmha_bmm2_scale_size = isFP8ContextFmhaInput() ? sizeof(float) : 0;
 
     // cp workspace size upper bound
     size_t const cpMaxPadedSequenceLength = params.num_tokens + params.batch_size * (mCpSize - 1);
@@ -1514,7 +1515,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     preprocessingParams.position_shift_enabled = mPosShiftEnabled;
     preprocessingParams.cache_type = cache_type;
     preprocessingParams.separate_q_kv_output = enablePagedKVContextFMHA || isCrossAttention();
-    preprocessingParams.quantized_fp8_output = mFP8ContextFMHA;
+    preprocessingParams.quantized_fp8_output = isFP8ContextFmhaInput();
     preprocessingParams.generation_phase = false;
     preprocessingParams.multi_processor_count = mMultiProcessorCount;
 
@@ -1614,8 +1615,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     // TODO: set it correctly for contiguous kv buffer (cross-attention).
     fmhaParams.totalKvSeqLen = isCrossAttention() ? params.num_encoder_tokens : params.num_tokens;
     // Device buffer pointers.
-    fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast<void const*>(fp8_qkv_buffer)
-                                        : reinterpret_cast<void const*>(attention_input);
+    fmhaParams.qkvPtr = isFP8ContextFmhaInput() ? reinterpret_cast<void const*>(fp8_qkv_buffer)
+                                                : reinterpret_cast<void const*>(attention_input);
     fmhaParams.qPtr = reinterpret_cast<void const*>(q_buf_2_);
     // TODO: add contiguous kv buffer (cross-attention).
     fmhaParams.kvPtr = nullptr;
@@ -2423,6 +2424,12 @@ int AttentionOp::initialize() noexcept
         if (mKVCacheQuantMode.hasFp8KvCache())
        {
            fmhaParams.dataTypeKv = DATA_TYPE_E4M3;
+            // Trtllm-gen kernels force the use of FP8 FMHA kernels when the FP8 KV cache is used.
+            // FP8 Q/KV inputs and FP8/BF16/FP16 outputs are supported.
+            if (mUseTllmGen)
+            {
+                fmhaParams.dataType = DATA_TYPE_E4M3;
+            }
         }
         // TODO: add FP4 KV cache support.
     }
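
Note on the helper this change introduces: the hunks above replace direct checks of mFP8ContextFMHA with isFP8ContextFmhaInput(), and the new TLLM_CHECK_WITH_INFO ties that helper to mFmhaDispatcher->isFP8InputQ(). The helper's definition is not part of the hunks shown; as a minimal sketch only, assuming it simply folds the trtllm-gen FP8-KV-cache rule (see the added comment in initialize()) into the existing flag, it could look like this:

// Hypothetical sketch, not part of this commit: the actual definition of
// isFP8ContextFmhaInput() is outside the hunks shown above.
bool AttentionOp::isFP8ContextFmhaInput() const
{
    // FP8 context FMHA input is either requested explicitly (mFP8ContextFMHA)
    // or forced by the trtllm-gen backend when the KV cache is stored in FP8.
    return mFP8ContextFMHA || (mUseTllmGen && mKVCacheQuantMode.hasFp8KvCache());
}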