diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index b0e3f63597a76..b9d0952698c08 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1432,21 +1432,24 @@ ggml_tensor * llm_graph_context::build_attn(
 
         v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
 
-        ggml_tensor * v_cache_view = nullptr;
+        // note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
+        if (!v_mla || v_trans) {
+            ggml_tensor * v_cache_view = nullptr;
 
-        if (!v_trans) {
-            v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
-        } else {
-            // note: the V cache is transposed when not using flash attention
-            v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
-                    (  n_ctx)*ggml_element_size(kv_self->v_l[il]),
-                    (kv_head)*ggml_element_size(kv_self->v_l[il]));
+            if (!v_trans) {
+                v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
+            } else {
+                // note: the V cache is transposed when not using flash attention
+                v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
+                        (  n_ctx)*ggml_element_size(kv_self->v_l[il]),
+                        (kv_head)*ggml_element_size(kv_self->v_l[il]));
 
-            v_cur = ggml_transpose(ctx0, v_cur);
-        }
-        //cb(v_cache_view, "v_cache_view", il);
+                v_cur = ggml_transpose(ctx0, v_cur);
+            }
+            //cb(v_cache_view, "v_cache_view", il);
 
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
+        }
     }
 
     const bool is_swa = hparams.is_swa(il);
@@ -1471,17 +1474,28 @@ ggml_tensor * llm_graph_context::build_attn(
                 0);
     //cb(k, "k", il);
 
-    ggml_tensor * v = !v_trans ?
-        ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_embd_head_v, n_kv, n_head_kv,
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
-                0) :
-        ggml_view_3d(ctx0, kv_self->v_l[il],
+    ggml_tensor * v = nullptr;
+
+    if (v_trans) {
+        v = ggml_view_3d(ctx0, kv_self->v_l[il],
                 n_kv, n_embd_head_v, n_head_kv,
                 ggml_element_size(kv_self->v_l[il])*n_ctx,
                 ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
                 0);
+    } else if (!v_mla) {
+        v = ggml_view_3d(ctx0, kv_self->v_l[il],
+                n_embd_head_v, n_kv, n_head_kv,
+                ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
+                ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
+                0);
+    } else {
+        // note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
+        v = ggml_view_3d(ctx0, kv_self->k_l[il],
+                n_embd_head_v, n_kv, n_head_kv,
+                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
+                n_embd_head_k-n_embd_head_v); // offset by n_rot elements
+    }
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
     cb(cur, "kqv_out", il);
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 3dcad65bb6a85..ea401c2ab5b48 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -32,8 +32,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         uint32_t   padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
     const int32_t n_layer = hparams.n_layer;
 
+    const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+
     has_shift = false;
-    can_shift = true;
+    can_shift = !is_mla || v_trans; // TODO: allow context shifting for MLA with flash attention
 
     LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n",
             __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift, padding);
@@ -100,8 +102,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             throw std::runtime_error("failed to create ggml context for kv cache");
         }
 
+        // note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
         ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, !is_mla || v_trans ? n_embd_v_gqa*kv_size : 0);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         k_l.push_back(k);