Skip to content

Commit

Permalink
bugfix: various AOT issues (#752)
Browse files Browse the repository at this point in the history
1. Add missing instantiations for batch prefill and single prefill.
2. Skip FP8 in sm90 prefill dispatch.
3. Fix the incorrect prefill pybind declaration.
4. Fix mismatched URI for batch prefill.
5. Add a `DISPATCH_head_dim_sm90` macro, since SM90 only supports head dims 64, 128, and 256.
6. Remove `csrc/aot_default_additional_params.h` and add it to `.gitignore`.
  • Loading branch information
abcdabcd987 authored Jan 25, 2025
1 parent e840db1 commit 6e6f38d
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 127 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ flashinfer/_build_meta.py
flashinfer/data/
flashinfer/jit/aot_config.py
src/generated/
csrc/aot_default_additional_params.h

# DS_Store files
.DS_store
Expand Down
14 changes: 0 additions & 14 deletions aot_build_utils/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
generate_batch_paged_decode_inst,
generate_batch_paged_prefill_inst,
generate_batch_ragged_prefill_inst,
generate_dispatch_inc,
generate_single_decode_inst,
generate_single_prefill_inst,
)
Expand All @@ -48,19 +47,6 @@ def write_if_different(path: Path, content: str) -> None:

path.mkdir(parents=True, exist_ok=True)

# dispatch.inc
write_if_different(
path / "dispatch.inc",
generate_dispatch_inc.get_dispatch_inc_str(
argparse.Namespace(
head_dims=head_dims,
pos_encoding_modes=pos_encoding_modes,
use_fp16_qk_reductions=use_fp16_qk_reductions,
mask_modes=mask_modes,
)
),
)

write_if_different(
path / "aot_default_additional_params.h",
generate_aot_default_additional_params_header.get_aot_default_additional_params_header_str(),
Expand Down
12 changes: 12 additions & 0 deletions aot_build_utils/generate_batch_paged_decode_inst.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,24 @@ def get_cu_file_str(
using Params = BatchDecodeParams<{dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>;
template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
/*use_custom_mask=*/false, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>(
Params params,
{dtype_out}* tmp_v, float* tmp_s,
cudaStream_t stream);
template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
/*use_custom_mask=*/false, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>(
Params params,
{dtype_out}* tmp_v, float* tmp_s,
cudaStream_t stream);
template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
/*use_custom_mask=*/false, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>, Params>(
Params params,
{dtype_out}* tmp_v, float* tmp_s,
cudaStream_t stream);
template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
/*use_custom_mask=*/false, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>, Params>(
Params params,
Expand Down
12 changes: 12 additions & 0 deletions aot_build_utils/generate_dispatch_inc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str:
dispatch_head_dims_str = f"""#define _DISPATCH_CASES_head_dim(case_var, ...) \\
{dispatch_head_dims_entries}
// EOL
"""
# head dims for sm90
dispatch_head_dims_sm90_entries = "\n".join(
[
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(_)
for _ in args.head_dims_sm90
]
)
dispatch_head_dims_sm90_str = f"""#define _DISPATCH_CASES_head_dim_sm90(case_var, ...) \\
{dispatch_head_dims_sm90_entries}
// EOL
"""
# positional encoding modes
dispatch_pos_encoding_modes_entries = "\n".join(
Expand Down Expand Up @@ -73,6 +84,7 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str:
return "\n".join(
[
dispatch_head_dims_str,
dispatch_head_dims_sm90_str,
dispatch_pos_encoding_modes_str,
dispatch_use_fp16_qk_reductions_str,
dispatch_mask_mode_str,
Expand Down
13 changes: 13 additions & 0 deletions aot_build_utils/generate_single_decode_inst.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ def get_cu_file_str(
using Params = SingleDecodeParams<{dtype_q}, {dtype_kv}, {dtype_out}>;
template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
/*use_custom_mask=*/false, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>(
Params params,
{dtype_out}* tmp,
cudaStream_t stream);
template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
/*use_custom_mask=*/false, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>, Params>(
Params params,
{dtype_out}* tmp,
cudaStream_t stream);
template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
/*use_custom_mask=*/false, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>(
Params params,
Expand All @@ -45,6 +57,7 @@ def get_cu_file_str(
Params params,
{dtype_out}* tmp,
cudaStream_t stream);
}}
""".format(
head_dim=head_dim,
Expand Down
66 changes: 0 additions & 66 deletions csrc/aot_default_additional_params.h

This file was deleted.

3 changes: 3 additions & 0 deletions csrc/aot_extension_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
// Dispatch on runtime head_dim to a compile-time constant via the generated
// _DISPATCH_CASES_head_dim table (emitted by aot_build_utils/generate_dispatch_inc.py).
#define DISPATCH_head_dim(expr, const_expr, ...) \
_DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__))

// SM90-only head_dim dispatch: uses the separate _DISPATCH_CASES_head_dim_sm90
// case table because SM90 kernels support a narrower set of head dims
// (64/128/256 per this commit's notes — confirm against the head_dims_sm90
// list passed to generate_dispatch_inc.py).
#define DISPATCH_head_dim_sm90(expr, const_expr, ...) \
_DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim_sm90(const_expr, __VA_ARGS__))

// Dispatch on runtime positional-encoding mode to a compile-time constant.
#define DISPATCH_pos_encoding_mode(expr, const_expr, ...) \
_DISPATCH_SWITCH("positional encoding mode", expr, \
_DISPATCH_CASES_pos_encoding_mode(const_expr, __VA_ARGS__))
Expand Down
32 changes: 17 additions & 15 deletions csrc/batch_prefill_sm90_config.inc
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,24 @@ using IdType = int32_t;
USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, ...) \
{ \
DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { \
return DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE( \
q_scalar_type, kv_scalar_type, dtype_q, dtype_kv, [&] { \
using DTypeQ = cutlass_dtype_t<dtype_q>; \
using DTypeKV = cutlass_dtype_t<dtype_kv>; \
using DTypeO = DTypeQ; \
using RaggedParams = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
using PagedParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { \
return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \
return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \
using AttentionVariant = DefaultAttention<USE_LOGITS_SOFT_CAP>; \
__VA_ARGS__(); \
return true; \
}); \
}); \
if (q_scalar_type != kv_scalar_type) { \
return false; \
} \
return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, dtype_q, [&] { \
using DTypeQ = cutlass_dtype_t<dtype_q>; \
using DTypeKV = DTypeQ; \
using DTypeO = DTypeQ; \
using RaggedParams = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
using PagedParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
return DISPATCH_head_dim_sm90(head_dim, HEAD_DIM, [&] { \
return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \
return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \
using AttentionVariant = DefaultAttention<USE_LOGITS_SOFT_CAP>; \
__VA_ARGS__(); \
return true; \
}); \
}); \
}); \
}); \
}); \
}
6 changes: 3 additions & 3 deletions csrc/flashinfer_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ void single_prefill_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::
std::vector<int64_t> BatchPrefillWithKVCachePlan(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
unsigned total_num_rows, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph,
unsigned int head_dim, bool causal, int64_t cuda_stream);
at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
bool enable_cuda_graph, unsigned int head_dim, bool causal, int64_t cuda_stream);

void BatchPrefillWithRaggedKVCacheRun(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
Expand Down
14 changes: 7 additions & 7 deletions csrc/flashinfer_ops_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@ void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_wo
int64_t cuda_stream);

void single_prefill_with_kv_cache_sm90(
at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor o, std::optional<at::Tensor> maybe_lse,
unsigned int mask_mode_code, unsigned int layout,
at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, at::Tensor o,
std::optional<at::Tensor> maybe_lse, unsigned int mask_mode_code, unsigned int layout,
int32_t window_left SINGLE_PREFILL_SM90_ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream);

std::vector<int64_t> BatchPrefillWithKVCacheSM90Plan(
unsigned int head_dim, bool causal, at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer,
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int total_num_rows,
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream);
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
bool enable_cuda_graph, unsigned int head_dim, bool causal, int64_t cuda_stream);

void BatchPrefillWithRaggedKVCacheSM90Run(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
Expand Down
30 changes: 16 additions & 14 deletions csrc/single_prefill_sm90_config.inc
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,23 @@ using IdType = int32_t;
USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) \
{ \
DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { \
return DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE( \
q_scalar_type, kv_scalar_type, dtype_q, dtype_kv, [&] { \
using DTypeQ = cutlass_dtype_t<dtype_q>; \
using DTypeKV = cutlass_dtype_t<dtype_kv>; \
using DTypeO = DTypeQ; \
using Params = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>; \
return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { \
return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \
return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \
using AttentionVariant = DefaultAttention<USE_LOGITS_SOFT_CAP>; \
__VA_ARGS__(); \
return true; \
}); \
}); \
if (q_scalar_type != kv_scalar_type) { \
return false; \
} \
return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, dtype_q, [&] { \
using DTypeQ = cutlass_dtype_t<dtype_q>; \
using DTypeKV = DTypeQ; \
using DTypeO = DTypeQ; \
using Params = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>; \
return DISPATCH_head_dim_sm90(head_dim, HEAD_DIM, [&] { \
return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \
return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \
using AttentionVariant = DefaultAttention<USE_LOGITS_SOFT_CAP>; \
__VA_ARGS__(); \
return true; \
}); \
}); \
}); \
}); \
}); \
}
2 changes: 1 addition & 1 deletion flashinfer/jit/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def get_batch_prefill_uri(
use_fp16_qk_reduction: bool,
) -> str:
return (
f"batch_prefill_{backend}_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
f"batch_prefill_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_"
f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
Expand Down
Loading

0 comments on commit 6e6f38d

Please sign in to comment.