
Kernels generated with use_fp16_qk_reductions=true break the LogitsTransform implementation used by prefill kernels #936

Open
diptorupd opened this issue Mar 12, 2025 · 0 comments · May be fixed by #962

Comments

@diptorupd (Contributor)

I want to clarify the semantics of the LogitsTransform function, declared as:

```cpp
REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
```

The logits parameter is templated and presumably can support __half. However, the computation of logits at

```cpp
logits = float(math::tanh(logits * soft_cap_pre_tanh_scale));
```

causes a compilation failure: operator* cannot be resolved when the operand types are __half and float. The assignment is also likely to break because of the implicit float-to-__half conversion.

The LogitsTransform template is invoked inside the prefill kernel at

```cpp
variant.LogitsTransform(params, s_frag[mma_q][mma_kv][reg_id], batch_idx, q_idx, kv_idx,
```

I discovered this issue while working on a fix for #806 and compiling kernels generated with the use_fp16_qk_reductions=true flag passed to aot_build_utils.generate.

I can apply a fix using an explicit cast from fp16 to fp32 and back, either at the call site or inside LogitsTransform. But before making any mechanical changes, I wanted to clarify the intent of the implementation.
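To illustrate the shape of the proposed fix, here is a minimal, host-runnable C++ sketch. `HalfLike`, `to_float`, `to_half`, and `soft_cap_transform` are hypothetical names standing in for CUDA's `__half` and the real LogitsTransform body; they are not FlashInfer APIs. The point is only the fp16 -> fp32 -> fp16 round-trip around the mixed-type arithmetic:

```cpp
#include <cassert>
#include <cmath>

// Hypothetical stand-in for CUDA's __half: a storage-only 16-bit type that
// does not support implicit arithmetic against float. (A float field keeps
// this sketch runnable on the host; a real __half stores 16 bits.)
struct HalfLike {
  float storage;
};

// Explicit conversion helpers, analogous to __half2float / __float2half.
inline float to_float(HalfLike h) { return h.storage; }
inline HalfLike to_half(float f) { return HalfLike{f}; }

// Sketch of the soft-cap logits transform with explicit fp16 <-> fp32
// round-tripping. `soft_cap_pre_tanh_scale` mirrors the variable quoted
// from the failing line.
inline HalfLike soft_cap_transform(HalfLike logits,
                                   float soft_cap_pre_tanh_scale) {
  // Promote to fp32 before the mixed-type multiply; without this cast,
  // `logits * soft_cap_pre_tanh_scale` fails to resolve for __half * float.
  float x = to_float(logits) * soft_cap_pre_tanh_scale;
  // Convert back explicitly instead of relying on an implicit
  // float -> __half narrowing assignment.
  return to_half(std::tanh(x));
}
```

The same pattern can live at the call site instead (cast `s_frag[...]` to float before calling LogitsTransform and cast the result back), which keeps the transform body type-agnostic.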
