
DeepSeek V2/V3 implementation refactored to allow non-MLA and MLA #12313

Closed
wants to merge 11 commits into from

Conversation

@jukofyork
Contributor

@jukofyork commented Mar 10, 2025

IMPORTANT: This will require re-quantising any model you want to use with this PR!!!


This is a vastly tidied-up continuation of #11446 and #12227, which allows the use of the -mla (--mla-attn) option:

  • By default it won't use MLA and instead essentially converts MLA into MHA (with very large KV-cache overhead).
  • With the -mla option it essentially converts MLA into MQA (with very low KV-cache overhead, but at the cost of more compute); a rough per-token size comparison is sketched just after this list.
  • The build_deepseek2() code now uses the proper llm_build_kv() calls for both the non-MLA and MLA branches.
  • There will likely be some performance regression due to the forced F32 upcast, the lack of 2D x 2D optimisations, and the splitting of the q_b and kv_b tensors to extract the MQA (i.e. RoPE) part separately (see below).
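
As a rough, back-of-the-envelope illustration of the KV-cache difference (a sketch, not code from this PR), the snippet below compares the per-token, per-layer cache footprint of the two paths, assuming DeepSeek-V2/V3-style hyper-parameters (n_head = 128, qk_nope = 128, qk_rope = 64, v_head_dim = 128, kv_lora_rank = 512); the real values come from the model's GGUF metadata:

#include <cstdio>

int main() {
    // illustrative DeepSeek-V2/V3-style dimensions -- check the GGUF metadata for the actual values
    const int n_head       = 128;  // attention heads
    const int qk_nope_dim  = 128;  // non-RoPE part of each Q/K head
    const int qk_rope_dim  =  64;  // RoPE part of each Q/K head
    const int v_head_dim   = 128;  // V head size
    const int kv_lora_rank = 512;  // compressed latent dimension used by MLA

    // default (non -mla) path: MLA is decompressed into MHA, so full K and V rows are cached per head
    const int mha_elems = n_head * (qk_nope_dim + qk_rope_dim) + n_head * v_head_dim;

    // -mla path: MLA is treated as MQA, so only the shared latent plus its RoPE part is cached
    const int mqa_elems = kv_lora_rank + qk_rope_dim;

    printf("per token, per layer: non-MLA %d elements, -mla %d elements (~%.0fx smaller)\n",
           mha_elems, mqa_elems, (double) mha_elems / mqa_elems);
    return 0;
}

The element counts still need to be multiplied by the cache type's element size and the number of layers, but the ratio is the point: the MQA-style cache stores a single shared row per token instead of one K row and one V row per head.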

NOTE: This will require re-quantising all models that use this, but the new layout won't change again; I intend to run some experiments over the next few days to find better quant rules for the newly split-up tensors (to hopefully avoid some of the numerical problems that seem to plague this model).

I also plan to see if I can recover some of the performance my previous PR gained (but at the cost of a vastly more complex/unmaintainable build_deepseek2() due to all the 2D/3D views it used). DONE

I have left context shifting disabled for now, but I have been careful to move the RoPE parts to the first n_rot parameters, so it should eventually be possible to get it working with build_k_shift() and build_defrag(), etc. I can't cleanly add this at the moment: if I try, it will likely end up a confusing mess of overriding the GGUF file parameters for n_embd_k_gqa and n_embd_v_gqa. I've tried to do this as cleanly as the current code allows in llama-kv-cache.cpp::llama_kv_cache_init(), llama.cpp::llm_build_kv_store() and llama.cpp::llm_build_kqv(). I'm also not 100% clear on the ins and outs of the YaRN implementation and how it interacts with context shifting, etc.
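
To show why the RoPE-first layout matters, here is a minimal sketch (my illustration, not code from this PR) of the kind of view a future context shift could use. It assumes the -mla K cache stores one row of n_rot + kv_lora_rank elements per token, with the RoPE part first; only that first n_rot slice is position-dependent and needs rotating:

// hypothetical view for a future build_k_shift(): rotate only the RoPE part of each cached K row
const int64_t kv_row = n_rot + kv_lora_rank;  // assumed per-token K row size in -mla mode (RoPE part first)

struct ggml_tensor * k_rope_part = ggml_view_3d(ctx0, kv_self.k_l[il],
        n_rot, 1, n_ctx,                                   // n_rot dims x 1 KV head x n_ctx cells
        ggml_row_size(kv_self.k_l[il]->type, n_rot),       // nb1 (only one "head", so effectively unused)
        ggml_row_size(kv_self.k_l[il]->type, kv_row),      // nb2: step over the latent part to the next cell
        0);
// this view could then be passed through the same RoPE shift that build_k_shift() applies for other models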


Things in llama.cpp and ggml I'm still a bit unsure of:

  1. There seem to be so many different places to add and copy the -mla option, and I'm not entirely confident I have found them all (I looked at how the -fa option was handled and tried to copy that as closely as I could).
  2. When taking views of possibly (likely) quantised tensors:
// {n_embd_head_qk_nope, kv_lora_rank, n_head}
struct ggml_tensor * wk_b_trans_view = ggml_view_3d(ctx0, model.layers[il].wk_b_trans,
        n_embd_head_qk_nope, kv_lora_rank, n_head,
        ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope),
        ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope),
        0);
cb(wk_b_trans_view, "wk_b_trans_view", il);

Should I be using the nb[] values? I'm currently just quantising everything to BF16 (for the attention tensors anyway), so it's possible some of my views are not going to work once quantised... 😕 (See the sketch just after this list.)

  3. What tests, if any, should I add for this? The only existing tests for KV-cache creation look very outdated.
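
For item 2, as an illustration only (a sketch, not the code in this PR), the same view could be taken from the tensor's own nb[] strides, which already account for the stored type's row/block size; this assumes wk_b_trans is stored contiguously as a 3D tensor of shape {n_embd_head_qk_nope, kv_lora_rank, n_head}:

// {n_embd_head_qk_nope, kv_lora_rank, n_head}
struct ggml_tensor * wk_b_trans = model.layers[il].wk_b_trans;
struct ggml_tensor * wk_b_trans_view = ggml_view_3d(ctx0, wk_b_trans,
        n_embd_head_qk_nope, kv_lora_rank, n_head,
        wk_b_trans->nb[1],   // stride between rows, as stored for this tensor's type
        wk_b_trans->nb[2],   // stride between per-head planes, as stored
        0);
cb(wk_b_trans_view, "wk_b_trans_view", il);

If the tensor is actually stored as 2D ({n_embd_head_qk_nope, kv_lora_rank * n_head}), the plane stride would instead be kv_lora_rank * wk_b_trans->nb[1]; either way the strides come from the tensor being viewed rather than from ggml_row_size() recomputed against a possibly different tensor's type.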

@fairydreaming
Collaborator

@jukofyork I wanted to try this, but there seems to be a problem with DeepSeek R1 model conversion in your branch:

(llama.cpp) phm@epyc:~/projects/llama.cpp-mla-final-refactor$ git branch
* mla-final-refactor
(llama.cpp) phm@epyc:~/projects/llama.cpp-mla-final-refactor$ python3 convert_hf_to_gguf.py /mnt/md0/huggingface/hub/models--deepseek-ai--DeepSeek-R1-bf16/ --outfile models/deepseek-r1-PR12313.gguf --outtype "f16"
INFO:hf-to-gguf:Loading model: models--deepseek-ai--DeepSeek-R1-bf16
INFO:gguf.gguf_writer:gguf: This GGUF file is for Little Endian only
INFO:hf-to-gguf:Exporting model...
INFO:hf-to-gguf:gguf: loading model weight map from 'model.safetensors.index.json'
INFO:hf-to-gguf:gguf: loading model part 'model-00001-of-000163.safetensors'
INFO:hf-to-gguf:token_embd.weight,            torch.bfloat16 --> F16, shape = {7168, 129280}
INFO:hf-to-gguf:blk.0.attn_norm.weight,       torch.bfloat16 --> F32, shape = {7168}
INFO:hf-to-gguf:blk.0.ffn_down.weight,        torch.bfloat16 --> F16, shape = {18432, 7168}
INFO:hf-to-gguf:blk.0.ffn_gate.weight,        torch.bfloat16 --> F16, shape = {7168, 18432}
INFO:hf-to-gguf:blk.0.ffn_up.weight,          torch.bfloat16 --> F16, shape = {7168, 18432}
INFO:hf-to-gguf:blk.0.ffn_norm.weight,        torch.bfloat16 --> F32, shape = {7168}
INFO:hf-to-gguf:blk.0.attn_kv_a_norm.weight,  torch.bfloat16 --> F32, shape = {512}
Traceback (most recent call last):
  File "/home/phm/projects/llama.cpp-mla-final-refactor/convert_hf_to_gguf.py", line 5189, in <module>
    main()
  File "/home/phm/projects/llama.cpp-mla-final-refactor/convert_hf_to_gguf.py", line 5183, in main
    model_instance.write()
  File "/home/phm/projects/llama.cpp-mla-final-refactor/convert_hf_to_gguf.py", line 439, in write
    self.prepare_tensors()
  File "/home/phm/projects/llama.cpp-mla-final-refactor/convert_hf_to_gguf.py", line 4219, in prepare_tensors
    super().prepare_tensors()
  File "/home/phm/projects/llama.cpp-mla-final-refactor/convert_hf_to_gguf.py", line 298, in prepare_tensors
    for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/phm/projects/llama.cpp-mla-final-refactor/convert_hf_to_gguf.py", line 4194, in modify_tensors
    (self.map_tensor_name(name.replace("kv_a_proj_with_mqa", "kv_a_proj")), kv_a_proj),
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/phm/projects/llama.cpp-mla-final-refactor/convert_hf_to_gguf.py", line 214, in map_tensor_name
    raise ValueError(f"Can not map tensor {name!r}")
ValueError: Can not map tensor 'model.layers.0.self_attn.kv_a_proj.weight'

@jukofyork
Contributor Author

@fairydreaming I'm actually just reverting this, as I realised it was going to be really hard to maintain llm_build_kqv(): every time somebody adds a model, it will require lots of changes to make this mergeable again. It also seems that splitting off the two _mqa tensors loses all the gains I made in the last PR, and even after adding most of the code back into llm_build_kqv() it was still much worse...

I'm now just merging the older "with flash attention" PR with the "-mla" options, but trying to use at least llm_build_kv_store():

                    struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
                    cb(q_states, "q_states", il);

                    struct ggml_tensor * k_states = ggml_concat(ctx0, kv_compressed, k_pe_view, 0);
                    cb(k_states, "k_states", il);

                    struct ggml_tensor * v_states = kv_compressed;
                    cb(v_states, "v_states", il);

                    // these nodes are added to the graph together so that they are not reordered
                    // by doing so, the number of splits in the graph is reduced
                    ggml_build_forward_expand(gf, q_states);
                    ggml_build_forward_expand(gf, k_states);
                    ggml_build_forward_expand(gf, v_states);

                    llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, k_states, v_states, n_tokens, kv_head, cb, il);

I'll have it done in a couple of hours, and there won't be any need to re-quantise either (closing this for now).

@jukofyork closed this Mar 11, 2025