Added optimised version of MQA for MLA to llm_build_kqv()
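With cparams.mla_attn enabled for DeepSeek2, the KV cache holds only the shared compressed latent (n_lora_kv values) plus the decoupled RoPE part (n_rot values) per token, so attention effectively runs as MQA with a single KV head. This commit adds a dedicated non-flash-attention path that exploits this: Q is flattened so that K*Q and the multiply by V become plain 2D matrix multiplies instead of batched per-head ones, and the per-head value decompression wv_b is applied only once, to the final result. Below is a minimal standalone sketch (plain C++, made-up dimensions and names, not the actual llama.cpp tensors) of the identity that makes the late decompression valid: because wv_b is linear, applying the attention weights to decompressed values equals decompressing the attention-weighted compressed latents.

// Minimal sketch with made-up dimensions (n_lat stands in for n_lora_kv,
// Wv for the per-head wv_b decompression). Not llama.cpp code.
#include <cassert>
#include <cmath>
#include <vector>

int main() {
    const int n_kv   = 4;  // cached positions
    const int n_lat  = 3;  // compressed (latent) value width
    const int n_head = 2;
    const int d_v    = 5;  // per-head value width after decompression

    std::vector<float> C(n_kv * n_lat);          // shared compressed cache, one row per position
    std::vector<float> A(n_head * n_kv);         // attention weights for one query token (arbitrary values for the sketch)
    std::vector<float> Wv(n_head * d_v * n_lat); // per-head decompression matrices

    for (size_t i = 0; i < C.size();  ++i) C[i]  = std::sin(0.1f * (float) i);
    for (size_t i = 0; i < A.size();  ++i) A[i]  = 0.05f * (float) (i + 1);
    for (size_t i = 0; i < Wv.size(); ++i) Wv[i] = std::cos(0.2f * (float) i);

    for (int h = 0; h < n_head; ++h) {
        for (int o = 0; o < d_v; ++o) {
            // (a) decompress first (per-head V = Wv_h * c_p), then apply the attention weights
            float full = 0.0f;
            for (int p = 0; p < n_kv; ++p) {
                float v_po = 0.0f;
                for (int l = 0; l < n_lat; ++l) {
                    v_po += Wv[(h*d_v + o)*n_lat + l] * C[p*n_lat + l];
                }
                full += A[h*n_kv + p] * v_po;
            }

            // (b) what the new path does: attend over the compressed cache once,
            //     then decompress the weighted latent with Wv_h
            float late = 0.0f;
            for (int l = 0; l < n_lat; ++l) {
                float lat = 0.0f;
                for (int p = 0; p < n_kv; ++p) {
                    lat += A[h*n_kv + p] * C[p*n_lat + l];
                }
                late += Wv[(h*d_v + o)*n_lat + l] * lat;
            }

            assert(std::fabs(full - late) < 1e-4f);
        }
    }
    return 0;
}

The same reasoning holds per token and per head, which is why the diff can keep a single compressed V cache and expand with wv_b_view only after the attention weights have been applied.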
jukofyork committed Mar 10, 2025
1 parent 3649714 commit c684535
Showing 1 changed file with 133 additions and 82 deletions.
215 changes: 133 additions & 82 deletions src/llama.cpp
@@ -176,7 +176,7 @@ static void llm_build_kv_store(
int64_t n_embd_k;
int64_t n_embd_v;

// note: deepseek-mla stores the compressed versions
// note: deepseek-mla converts MLA to MQA so n_embd_k/n_embd_v change too
if (cparams.mla_attn && model.arch == LLM_ARCH_DEEPSEEK2) {
n_embd_k = hparams.n_lora_kv + hparams.n_rot;
n_embd_v = hparams.n_lora_kv;
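For scale, a rough per-token cache width under these settings, using assumed DeepSeek-V2-style hyper-parameters (illustrative values, not taken from this diff):

// illustrative only: assumed DeepSeek-V2-style values
const int64_t n_lora_kv = 512;               // compressed (latent) KV width
const int64_t n_rot     = 64;                // decoupled RoPE key width
const int64_t n_embd_k  = n_lora_kv + n_rot; // 576 K values cached per token
const int64_t n_embd_v  = n_lora_kv;         // 512 V values cached per token, independent of n_head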
@@ -568,59 +568,36 @@ static struct ggml_tensor * llm_build_kqv(
const llama_hparams & hparams = lctx.model.hparams;
const llama_cparams & cparams = lctx.cparams;

const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head(il);
const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head(il);
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
const int64_t n_embd_head_v = hparams.n_embd_head_v;

int64_t n_head_kv;
int64_t n_embd_k;
int64_t n_embd_head_k;
int64_t n_embd_v;
int64_t n_embd_head_v;
int64_t n_embd_head_v_final;

// note: MLA caches compressed KV and acts as MQA until the final wv_b expansion
if (cparams.mla_attn && model.arch == LLM_ARCH_DEEPSEEK2) {
GGML_ASSERT(wv_b);
n_head_kv = 1;
n_embd_head_k = hparams.n_lora_kv + hparams.n_rot;
n_embd_k = n_embd_head_k;
n_embd_head_v = hparams.n_lora_kv;
n_embd_v = n_embd_head_v;
n_embd_head_v_final = hparams.n_embd_head_v; // after multiplying by wv_b
} else {
n_head_kv = hparams.n_head_kv(il);
n_embd_head_k = hparams.n_embd_head_k;
n_embd_k = hparams.n_embd_k_gqa(il);
n_embd_head_v = hparams.n_embd_head_v;
n_embd_v = hparams.n_embd_v_gqa(il);
n_embd_head_v_final = n_embd_head_v;
}
struct ggml_tensor * cur;

struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
cb(q, "q", il);

struct ggml_tensor * k =
ggml_view_3d(ctx, kv.k_l[il],
n_embd_head_k, n_kv, n_head_kv,
ggml_row_size(kv.k_l[il]->type, n_embd_k),
ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
0);
cb(k, "k", il);

struct ggml_tensor * cur;

if (cparams.flash_attn) {
GGML_UNUSED(model);
GGML_UNUSED(n_ctx);

// note: MLA creates embeddings too large for FA, see: https://github.com/ggml-org/llama.cpp/pull/12227
GGML_ASSERT(!wv_b);
struct ggml_tensor * k =
ggml_view_3d(ctx, kv.k_l[il],
n_embd_head_k, n_kv, n_head_kv,
ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
0);
cb(k, "k", il);

// split cached v into n_head heads (not transposed)
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(kv.v_l[il]->type, n_embd_v),
ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
0);
cb(v, "v", il);
@@ -630,65 +607,139 @@

ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

cur = ggml_reshape_2d(ctx, cur, n_embd_head_v_final*n_head, n_tokens);
cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
} else {
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);

// note: this op tends to require high floating point range
// while for some models F16 is enough, for others it is not, so we default to F32 here
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
// MLA converted to MQA, optimised to use non-batched matrix multiplies
if (cparams.mla_attn && model.arch == LLM_ARCH_DEEPSEEK2) {
const int64_t n_embd_head_k_mqa = hparams.n_lora_kv + hparams.n_rot;
const int64_t n_embd_head_v_mqa = hparams.n_lora_kv;

if (model.arch == LLM_ARCH_GROK) {
// need to do the following:
// multiply by attn_output_multiplier of 0.08838834764831845
// and then:
// kq = 30 * tanh(kq / 30)
// before the softmax below
// must cont for the 2D view or else kq will have n_tokens <-> n_head swapped...
q = ggml_cont(ctx, q);
cb(q, "q_cont", il);

kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
kq = ggml_scale(ctx, kq, 30);
}
q = ggml_view_2d(ctx, q,
n_embd_head_k_mqa, n_head * n_tokens,
ggml_row_size(q->type, n_embd_head_k_mqa),
0);
cb(q, "q_view", il);

if (hparams.attn_soft_cap) {
kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
kq = ggml_tanh(ctx, kq);
kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
}
struct ggml_tensor * k =
ggml_view_2d(ctx, kv.k_l[il],
n_embd_head_k_mqa, n_kv,
ggml_row_size(kv.k_l[il]->type, n_embd_head_k_mqa),
0);
cb(k, "k", il);

kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);

GGML_ASSERT(kv.size == n_ctx);
// note: this doesn't seem necessary
//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

// split cached v into n_head heads
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
n_kv, n_embd_head_v, n_head_kv,
ggml_element_size(kv.v_l[il])*n_ctx,
ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
kq = ggml_view_3d(ctx, kq,
n_kv, n_tokens, n_head,
ggml_row_size(kq->type, n_kv),
ggml_row_size(kq->type, n_kv * n_tokens),
0);
cb(v, "v", il);
cb(kq, "kq_view", il);

struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
cb(kqv, "kqv", il);
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);

// note: MLA needs to expand KQV from MQA into MHA
if (wv_b) {
struct ggml_tensor * wv_b_view = ggml_view_3d(ctx, wv_b, n_embd_head_v, n_embd_head_v_final, n_head,
ggml_row_size(model.layers[il].wv_b->type, n_embd_head_v),
ggml_row_size(model.layers[il].wv_b->type, n_embd_head_v * n_embd_head_v_final),
kq = ggml_view_2d(ctx, kq,
n_kv, n_tokens * n_head,
ggml_row_size(kq->type, n_kv),
0);
cb(kq, "kq_soft_max_view", il);

GGML_ASSERT(kv.size == n_ctx);

struct ggml_tensor * v =
ggml_view_2d(ctx, kv.v_l[il],
n_kv, n_embd_head_v_mqa,
ggml_element_size(kv.v_l[il])*n_ctx,
0);
cb(v, "v", il);

struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
cb(kqv, "kqv_compressed", il);

kqv = ggml_view_3d(ctx, kqv,
n_embd_head_v_mqa, n_tokens, n_head,
ggml_row_size(kqv->type, n_embd_head_v_mqa),
ggml_row_size(kqv->type, n_embd_head_v_mqa * n_tokens),
0);
cb(kqv, "kqv_view", il);

struct ggml_tensor * wv_b_view =
ggml_view_3d(ctx, wv_b, n_embd_head_v_mqa, n_embd_head_v, n_head,
ggml_row_size(wv_b->type, n_embd_head_v_mqa),
ggml_row_size(wv_b->type, n_embd_head_v * n_embd_head_v_mqa),
0);
cb(wv_b_view, "wv_b_view", il);

kqv = ggml_mul_mat(ctx, wv_b_view, kqv);
cb(kqv, "kqv_wv_b", il);
// decompress the MQA result back to MHA
cur = ggml_mul_mat(ctx, wv_b_view, kqv);
cb(cur, "kqv", il);

// standard MHA/GQA non-flash-attention case
} else {
struct ggml_tensor * k =
ggml_view_3d(ctx, kv.k_l[il],
n_embd_head_k, n_kv, n_head_kv,
ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
0);
cb(k, "k", il);

struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);

// note: this op tends to require high floating point range
// while for some models F16 is enough, for others it is not, so we default to F32 here
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

if (model.arch == LLM_ARCH_GROK) {
// need to do the following:
// multiply by attn_output_multiplier of 0.08838834764831845
// and then:
// kq = 30 * tanh(kq / 30)
// before the softmax below

kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
kq = ggml_scale(ctx, kq, 30);
}

if (hparams.attn_soft_cap) {
kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
kq = ggml_tanh(ctx, kq);
kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
}

kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);

GGML_ASSERT(kv.size == n_ctx);

// split cached v into n_head heads
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
n_kv, n_embd_head_v, n_head_kv,
ggml_element_size(kv.v_l[il])*n_ctx,
ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
0);
cb(v, "v", il);

cur = ggml_mul_mat(ctx, v, kq);
cb(cur, "kqv", il);
}

struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
cb(kqv_merged, "kqv_merged", il);
cur = ggml_permute(ctx, cur, 0, 2, 1, 3);
cb(cur, "kqv_merged", il);

cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v_final*n_head, n_tokens);
cur = ggml_cont_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
cb(cur, "kqv_merged_cont", il);
}
