DeepSeek V2/V3 implementation refactored to allow non-MLA and MLA #12313

Closed
wants to merge 11 commits
7 changes: 7 additions & 0 deletions common/arg.cpp
@@ -811,6 +811,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.flash_attn = true;
}
).set_env("LLAMA_ARG_FLASH_ATTN"));
add_opt(common_arg(
{"-mla", "--mla-attn"},
string_format("enable Multi-head Latent Attention (default: %s)", params.mla_attn ? "enabled" : "disabled"),
[](common_params & params) {
params.mla_attn = true;
}
).set_env("LLAMA_ARG_MLA_ATTN"));
add_opt(common_arg(
{"-p", "--prompt"}, "PROMPT",
"prompt to start generation with; for system message, use -sys",
1 change: 1 addition & 0 deletions common/common.cpp
@@ -1132,6 +1132,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
cparams.mla_attn = params.mla_attn;
cparams.no_perf = params.no_perf;

if (params.reranking) {
1 change: 1 addition & 0 deletions common/common.h
@@ -325,6 +325,7 @@ struct common_params {
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
bool mla_attn = false; // MLA attention for deepseek2
bool no_perf = false; // disable performance metrics
bool ctx_shift = true; // context shift on infinite text generation

72 changes: 72 additions & 0 deletions convert_hf_to_gguf.py
@@ -4141,6 +4141,78 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
else:
return []

n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
n_head_kv = self.hparams["num_key_value_heads"]
qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
qk_rope_head_dim = self.hparams["qk_rope_head_dim"]
v_head_dim = self.hparams["v_head_dim"]
kv_lora_rank = self.hparams["kv_lora_rank"]

# (v2-lite) split q_proj into: q_proj and q_mqa_proj
if name.endswith("q_proj.weight"):
assert data_torch.shape[0] == n_head_kv * (qk_nope_head_dim + qk_rope_head_dim)
assert data_torch.shape[1] == n_embed

q_proj_with_mqa = data_torch.view(n_head_kv, qk_nope_head_dim + qk_rope_head_dim, n_embed)
q_proj, q_mqa_proj = torch.split(q_proj_with_mqa, [qk_nope_head_dim, qk_rope_head_dim], dim = 1)

q_proj = q_proj.reshape(n_head_kv * qk_nope_head_dim, n_embed)
q_mqa_proj = q_mqa_proj.reshape(n_head_kv * qk_rope_head_dim, n_embed)

return [
(self.map_tensor_name(name), q_proj),
(self.map_tensor_name(name.replace("q_proj", "q_mqa_proj")), q_mqa_proj)
]

# (v2/v3/r1) split q_b_proj into: q_b_proj and q_b_mqa_proj
if name.endswith("q_b_proj.weight"):
q_lora_rank = self.hparams["q_lora_rank"]

assert data_torch.shape[0] == n_head_kv * (qk_nope_head_dim + qk_rope_head_dim)
assert data_torch.shape[1] == q_lora_rank

q_b_proj_with_mqa = data_torch.view(n_head_kv, qk_nope_head_dim + qk_rope_head_dim, q_lora_rank)
q_b_proj, q_b_mqa_proj = torch.split(q_b_proj_with_mqa, [qk_nope_head_dim, qk_rope_head_dim], dim = 1)

q_b_proj = q_b_proj.reshape(n_head_kv * qk_nope_head_dim, q_lora_rank)
q_b_mqa_proj = q_b_mqa_proj.reshape(n_head_kv * qk_rope_head_dim, q_lora_rank)

return [
(self.map_tensor_name(name), q_b_proj),
(self.map_tensor_name(name.replace("q_b_proj", "q_b_mqa_proj")), q_b_mqa_proj)
]

# split kv_a_proj_with_mqa into: kv_a_proj and k_mqa_proj
if name.endswith("kv_a_proj_with_mqa.weight"):
assert data_torch.shape[0] == kv_lora_rank + qk_rope_head_dim
assert data_torch.shape[1] == n_embed

kv_a_proj_with_mqa = data_torch.view(kv_lora_rank + qk_rope_head_dim, n_embed)
kv_a_proj, k_mqa_proj = torch.split(kv_a_proj_with_mqa, [kv_lora_rank, qk_rope_head_dim], dim = 0)

return [
(self.map_tensor_name(name.replace("kv_a_proj_with_mqa", "kv_a_proj")), kv_a_proj),
(self.map_tensor_name(name.replace("kv_a_proj_with_mqa", "k_mqa_proj")), k_mqa_proj)
]

# split kv_b_proj into: k_b_proj, v_b_proj, and k_b_trans_proj (for deepseek-mla)
if name.endswith("kv_b_proj.weight"):
assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
assert data_torch.shape[1] == kv_lora_rank

kv_b_proj = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, kv_lora_rank)
k_b_proj, v_b_proj = torch.split(kv_b_proj, [qk_nope_head_dim, v_head_dim], dim = 1)

k_b_trans_proj = k_b_proj.transpose(1, 2).reshape(n_head_kv * kv_lora_rank, qk_nope_head_dim)
k_b_proj = k_b_proj.reshape(n_head_kv * qk_nope_head_dim, kv_lora_rank)
v_b_proj = v_b_proj.reshape(n_head_kv * v_head_dim, kv_lora_rank)

return [
(self.map_tensor_name(name.replace("kv_b_proj", "k_b_trans_proj")), k_b_trans_proj),
(self.map_tensor_name(name.replace("kv_b_proj", "k_b_proj")), k_b_proj),
(self.map_tensor_name(name.replace("kv_b_proj", "v_b_proj")), v_b_proj)
]

return [(self.map_tensor_name(name), data_torch)]

def prepare_tensors(self):
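
A minimal standalone sketch of the kv_b_proj split above, for reference. The dimensions (n_head_kv = 16, qk_nope_head_dim = 128, v_head_dim = 128, kv_lora_rank = 512, roughly DeepSeek-V2-Lite-sized) are illustrative assumptions, not values read from a real config; the point is only that the reshapes yield the shapes the new loader expects.

import torch

n_head_kv, qk_nope_head_dim, v_head_dim, kv_lora_rank = 16, 128, 128, 512  # assumed dims

# HF layout: rows pack (k_nope, v) per head, columns are the compressed latent rank
kv_b_proj = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

kv_b = kv_b_proj.view(n_head_kv, qk_nope_head_dim + v_head_dim, kv_lora_rank)
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)

# extra transposed copy of k_b used by the MLA path
k_b_trans = k_b.transpose(1, 2).reshape(n_head_kv * kv_lora_rank, qk_nope_head_dim)
k_b = k_b.reshape(n_head_kv * qk_nope_head_dim, kv_lora_rank)
v_b = v_b.reshape(n_head_kv * v_head_dim, kv_lora_rank)

assert k_b_trans.shape == (n_head_kv * kv_lora_rank, qk_nope_head_dim)
assert k_b.shape == (n_head_kv * qk_nope_head_dim, kv_lora_rank)
assert v_b.shape == (n_head_kv * v_head_dim, kv_lora_rank)
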
1 change: 1 addition & 0 deletions examples/server/README.md
@@ -46,6 +46,7 @@ The project is under active development, and we are [looking for feedback and co
| `-ub, --ubatch-size N` | physical maximum batch size (default: 512)<br/>(env: LLAMA_ARG_UBATCH) |
| `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) |
| `-fa, --flash-attn` | enable Flash Attention (default: disabled)<br/>(env: LLAMA_ARG_FLASH_ATTN) |
| `-mla, --mla-attn` | enable Multi-head Latent Attention (default: disabled)<br/>(env: LLAMA_ARG_MLA_ATTN) |
| `--no-perf` | disable internal libllama performance timings (default: false)<br/>(env: LLAMA_ARG_NO_PERF) |
| `-e, --escape` | process escape sequences (\n, \r, \t, \', \", \\) (default: true) |
| `--no-escape` | do not process escape sequences |
21 changes: 21 additions & 0 deletions gguf-py/gguf/constants.py
@@ -356,6 +356,13 @@ class MODEL_TENSOR(IntEnum):
ATTN_Q_B = auto()
ATTN_KV_A_MQA = auto()
ATTN_KV_B = auto()
ATTN_Q_MQA = auto()
ATTN_Q_B_MQA = auto()
ATTN_KV_A = auto()
ATTN_K_MQA = auto()
ATTN_K_B_TRANS = auto()
ATTN_K_B = auto()
ATTN_V_B = auto()
ATTN_Q_A_NORM = auto()
ATTN_KV_A_NORM = auto()
FFN_SUB_NORM = auto()
@@ -543,6 +550,13 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
MODEL_TENSOR.ATTN_Q_MQA: "blk.{bid}.attn_q_mqa",
MODEL_TENSOR.ATTN_Q_B_MQA: "blk.{bid}.attn_q_b_mqa",
MODEL_TENSOR.ATTN_KV_A: "blk.{bid}.attn_kv_a",
MODEL_TENSOR.ATTN_K_MQA: "blk.{bid}.attn_k_mqa",
MODEL_TENSOR.ATTN_K_B_TRANS: "blk.{bid}.attn_k_b_trans",
MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b",
MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b",
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
@@ -1041,6 +1055,13 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_Q_MQA,
MODEL_TENSOR.ATTN_Q_B_MQA,
MODEL_TENSOR.ATTN_KV_A,
MODEL_TENSOR.ATTN_K_MQA,
MODEL_TENSOR.ATTN_K_B_TRANS,
MODEL_TENSOR.ATTN_K_B,
MODEL_TENSOR.ATTN_V_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_OUT,
28 changes: 28 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -586,6 +586,34 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
),

MODEL_TENSOR.ATTN_Q_MQA: (
"model.layers.{bid}.self_attn.q_mqa_proj", # deepseek2 (v2-lite)
),

MODEL_TENSOR.ATTN_Q_B_MQA: (
"model.layers.{bid}.self_attn.q_b_mqa_proj", # deepseek2 (v2/v3/r1)
),

MODEL_TENSOR.ATTN_KV_A: (
"model.layers.{bid}.self_attn.kv_a_proj", # deepseek2
),

MODEL_TENSOR.ATTN_K_MQA: (
"model.layers.{bid}.self_attn.k_mqa_proj", # deepseek2
),

MODEL_TENSOR.ATTN_K_B_TRANS: (
"model.layers.{bid}.self_attn.k_b_trans_proj", # deepseek2 (mla only)
),

MODEL_TENSOR.ATTN_K_B: (
"model.layers.{bid}.self_attn.k_b_proj", # deepseek2
),

MODEL_TENSOR.ATTN_V_B: (
"model.layers.{bid}.self_attn.v_b_proj", # deepseek2
),

MODEL_TENSOR.ATTN_Q_A_NORM: (
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
),
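
To make the new mappings concrete: under the entries above, an HF tensor such as model.layers.3.self_attn.k_b_trans_proj.weight ends up as blk.3.attn_k_b_trans.weight in the GGUF file. A toy resolver over just the new deepseek2 entries, purely for illustration (this is not the gguf-py TensorNameMap implementation):

# Toy name resolution over the new deepseek2 entries only (illustration, not the real TensorNameMap)
NEW_DEEPSEEK2_MAPPINGS = {
    "model.layers.{bid}.self_attn.q_mqa_proj":     "blk.{bid}.attn_q_mqa",
    "model.layers.{bid}.self_attn.q_b_mqa_proj":   "blk.{bid}.attn_q_b_mqa",
    "model.layers.{bid}.self_attn.kv_a_proj":      "blk.{bid}.attn_kv_a",
    "model.layers.{bid}.self_attn.k_mqa_proj":     "blk.{bid}.attn_k_mqa",
    "model.layers.{bid}.self_attn.k_b_trans_proj": "blk.{bid}.attn_k_b_trans",
    "model.layers.{bid}.self_attn.k_b_proj":       "blk.{bid}.attn_k_b",
    "model.layers.{bid}.self_attn.v_b_proj":       "blk.{bid}.attn_v_b",
}

def to_gguf_name(hf_name: str, bid: int) -> str:
    base = hf_name.removesuffix(".weight")
    for hf_tmpl, gguf_tmpl in NEW_DEEPSEEK2_MAPPINGS.items():
        if base == hf_tmpl.format(bid=bid):
            return gguf_tmpl.format(bid=bid) + ".weight"
    raise KeyError(hf_name)

assert to_gguf_name("model.layers.3.self_attn.k_b_trans_proj.weight", 3) == "blk.3.attn_k_b_trans.weight"
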
1 change: 1 addition & 0 deletions include/llama.h
@@ -343,6 +343,7 @@ extern "C" {
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool mla_attn; // MLA attention for deepseek2
bool no_perf; // whether to measure performance timings

// Abort callback
31 changes: 14 additions & 17 deletions src/llama-arch.cpp
@@ -997,6 +997,13 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" },
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
{ LLM_TENSOR_ATTN_Q_MQA, "blk.%d.attn_q_mqa" },
{ LLM_TENSOR_ATTN_Q_B_MQA, "blk.%d.attn_q_b_mqa" },
{ LLM_TENSOR_ATTN_KV_A, "blk.%d.attn_kv_a" },
{ LLM_TENSOR_ATTN_K_MQA, "blk.%d.attn_k_mqa" },
{ LLM_TENSOR_ATTN_K_B_TRANS, "blk.%d.attn_k_b_trans" },
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
@@ -1333,23 +1340,13 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_B_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K_B_TRANS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
7 changes: 7 additions & 0 deletions src/llama-arch.h
@@ -277,6 +277,13 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
LLM_TENSOR_ATTN_Q_MQA,
LLM_TENSOR_ATTN_Q_B_MQA,
LLM_TENSOR_ATTN_KV_A,
LLM_TENSOR_ATTN_K_MQA,
LLM_TENSOR_ATTN_K_B_TRANS,
LLM_TENSOR_ATTN_K_B,
LLM_TENSOR_ATTN_V_B,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_ATTN_SUB_NORM,
1 change: 1 addition & 0 deletions src/llama-cparams.h
@@ -28,6 +28,7 @@ struct llama_cparams {
bool causal_attn;
bool offload_kqv;
bool flash_attn;
bool mla_attn;
bool no_perf;

enum llama_pooling_type pooling_type;
19 changes: 16 additions & 3 deletions src/llama-kv-cache.cpp
@@ -32,7 +32,7 @@ bool llama_kv_cache_init(

cache.recurrent = llama_model_is_recurrent(&model);
cache.v_trans = !cache.recurrent && !cparams.flash_attn;
cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported yet

LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
__func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, cache.can_shift);
@@ -91,8 +91,21 @@ bool llama_kv_cache_init(
return false;
}

ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
int64_t n_embd_k;
int64_t n_embd_v;

// note: deepseek-mla stores the compressed versions
if (cparams.mla_attn && model.arch == LLM_ARCH_DEEPSEEK2) {
n_embd_k = hparams.n_lora_kv + hparams.n_rot;
n_embd_v = hparams.n_lora_kv;
} else {
n_embd_k = hparams.n_embd_k_gqa(i);
n_embd_v = hparams.n_embd_v_gqa(i);
}

ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k*kv_size);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v*kv_size);

ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
cache.k_l.push_back(k);
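
A back-of-the-envelope sketch of what the compressed cache saves per token and per layer. The head dimensions below (n_head = 128, qk_nope_head_dim = 128, qk_rope_head_dim = n_rot = 64, v_head_dim = 128, kv_lora_rank = 512) are assumed DeepSeek-V3-style values used only for illustration:

# Rough per-token, per-layer KV-cache element counts (assumed DeepSeek-V3-like dims)
n_head       = 128  # assumed
qk_nope      = 128  # assumed qk_nope_head_dim
n_rot        = 64   # assumed qk_rope_head_dim
v_head       = 128  # assumed v_head_dim
kv_lora_rank = 512  # assumed n_lora_kv

# non-MLA path: full per-head K and V are cached (n_embd_k_gqa / n_embd_v_gqa)
k_full = n_head * (qk_nope + n_rot)  # 24576
v_full = n_head * v_head             # 16384

# MLA path: only the compressed latent plus the shared rope part is cached
k_mla = kv_lora_rank + n_rot         # 576
v_mla = kv_lora_rank                 # 512

print(f"non-MLA: {k_full + v_full} elements/token/layer")              # 40960
print(f"MLA:     {k_mla + v_mla} elements/token/layer")                # 1088
print(f"ratio:   {(k_full + v_full) / (k_mla + v_mla):.1f}x smaller")  # ~37.6x
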
16 changes: 11 additions & 5 deletions src/llama-model.cpp
@@ -2890,14 +2890,20 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

if (!is_lite) {
layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_qk_nope}, 0);
layer.wq_b_mqa = create_tensor(tn(LLM_TENSOR_ATTN_Q_B_MQA, "weight", i), {q_lora_rank, n_head * n_embd_head_qk_rope}, 0);
} else {
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_qk_nope}, 0);
layer.wq_mqa = create_tensor(tn(LLM_TENSOR_ATTN_Q_MQA, "weight", i), {n_embd, n_head * n_embd_head_qk_rope}, 0);
}

layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0);
layer.wkv_a = create_tensor(tn(LLM_TENSOR_ATTN_KV_A, "weight", i), {n_embd, kv_lora_rank}, 0);
layer.wk_mqa = create_tensor(tn(LLM_TENSOR_ATTN_K_MQA, "weight", i), {n_embd, n_embd_head_qk_rope}, 0);
layer.wk_b_trans = create_tensor(tn(LLM_TENSOR_ATTN_K_B_TRANS, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 0);
layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_qk_nope}, 0);
layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 0);

layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v, n_embd}, 0);

layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);

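
One detail that is easy to trip over when comparing this with the conversion script: create_tensor takes ggml-style dimensions, where ne[0] is the fastest-varying dimension, i.e. the reverse of the torch shape order written by convert_hf_to_gguf.py. A small sketch with assumed V2-Lite-like dimensions (n_embd = 2048, n_head = 16, the rest as before) checking that the new tensor shapes on both sides line up:

# ggml ne[] order is the reverse of torch .shape order; check the new tensors agree
n_embd, n_head = 2048, 16                                   # assumed V2-Lite-like values
qk_nope, qk_rope, v_head, kv_lora_rank = 128, 64, 128, 512  # assumed

# name: (torch shape written by the converter, ggml ne expected by create_tensor)
pairs = {
    "attn_kv_a":      ((kv_lora_rank,          n_embd),       (n_embd,       kv_lora_rank)),
    "attn_k_mqa":     ((qk_rope,               n_embd),       (n_embd,       qk_rope)),
    "attn_k_b_trans": ((n_head * kv_lora_rank, qk_nope),      (qk_nope,      n_head * kv_lora_rank)),
    "attn_k_b":       ((n_head * qk_nope,      kv_lora_rank), (kv_lora_rank, n_head * qk_nope)),
    "attn_v_b":       ((n_head * v_head,       kv_lora_rank), (kv_lora_rank, n_head * v_head)),
}

for name, (torch_shape, ggml_ne) in pairs.items():
    assert tuple(reversed(torch_shape)) == ggml_ne, name
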