
Optimized DeepSeek V2/V3 implementation (MLA) #11446

Draft
wants to merge 10 commits into base: master

Conversation

Collaborator

@fairydreaming fairydreaming commented Jan 27, 2025

This PR introduces various optimizations for the DeepSeek V2/V3 implementation.

Note that you need to reconvert the model to use this implementation.

Performance compared to the previous "naive" implementation:

[Plots: deepseek-mla, deepseek-lite-mla-pp, deepseek-r1-mla, deepseek-mla-pp]

CUDA performance is worse for short context lengths, but the curve is flatter:

[Plots: deepseek-lite-mla, deepseek-lite-cuda-mla-pp]

TODO:

  • remove unused kv_b tensor from the model
  • maybe add support for old model files (compute k_b and v_b during inference with reduced performance)
  • wait for completion of: llama : refactor llama_kv_cache, llama_context and llm_build_context #11213
  • implement MLA KV cache
  • address regressions in prompt processing performance (different permutations of tensors?) - I don't think it's possible, as this implementation is more compute-intensive than the regular attention implementation

@fairydreaming fairydreaming marked this pull request as draft January 28, 2025 11:23
@wronkiew

@fairydreaming do you have a converted model available or instructions for replicating your setup? I would like to run some benchmarks on these changes.

@fairydreaming
Collaborator Author

@fairydreaming do you have a converted model available or instructions for replicating your setup? I would like to run some benchmarks on these changes.

@wronkiew What model would you like to test?

@wronkiew

@fairydreaming do you have a converted model available or instructions for replicating your setup? I would like to run some benchmarks on these changes.

@wronkiew What model would you like to test?

V3/R1, Q4_K_S.

@fairydreaming
Collaborator Author

@fairydreaming do you have a converted model available or instructions for replicating your setup? I would like to run some benchmarks on these changes.

@wronkiew What model would you like to test?

V3/R1, Q4_K_S.

@wronkiew I don't have the model uploaded (my upload bandwidth is too low), so you have to download the original model, convert it to bf16, convert that to GGUF and quantize it yourself (or download one that is already converted to bf16, which will save you one step).

@fairydreaming
Collaborator Author

I spent some time investigating this hint from the DeepSeek V2 paper:

Fortunately, due to the associative law of matrix multiplication, we can absorb $W^{UK}$ into $W^{UQ}$, and $W^{UV}$ into $W^O$

At first glance it looks reasonable: each absorbed matrix allows us to replace two matrix multiplications with a single one, thus reducing the number of operations.

However, when we take a look at the dimensions of these matrices, this stops being reasonable. For example, in DeepSeek V2 Lite:

  • $W^{UQ}$ tensor has shape [2048, 2048], that is [16, 2048, 128] after reshaping to 3d and permutation
  • $W^{UK}$ tensor has shape [128, 8192], that is [16, 512, 128] after reshaping to 3d and permutation
  • combined "absorbed" tensor has shape [16, 512, 2048]

So (let's ignore the head dimension) this allows us to replace two multiplications (by a [2048, 128] matrix and a [512, 128] matrix) with a single multiplication by a [512, 2048] matrix. The combined matrix has over 3x as many elements as the two individual matrices, so it will take more memory and will actually be slower to multiply than the two multiplications with smaller matrices.
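
A quick sanity check of the "over 3x" claim, using the per-head shapes listed above (editorial arithmetic, not from the PR):

$$
\underbrace{2048 \times 128}_{W^{UQ}\ \text{slice}} + \underbrace{512 \times 128}_{W^{UK}\ \text{slice}} = 327{,}680
\qquad\text{vs.}\qquad
\underbrace{512 \times 2048}_{\text{absorbed}} = 1{,}048{,}576 \;\approx\; 3.2\times\ \text{more elements.}
$$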

With $W^{UV}$ and $W^O$ it's the same story:

  • $W^{UV}$ tensor has shape [2048, 512], that is [16, 512, 128] after reshaping to 3d and permutation
  • $W^O$ tensor has shape [2048, 2048], that is [16, 2048, 128] after reshaping to 3d and permutation
  • combined "absorbed" tensor has shape [16, 512, 2048]

I also found this blog post: https://github.com/xjdr-alt/mla_blog_translation where they mention:

Compared to performing projection with these particularly large low-rank matrices, it is obviously more advantageous to multiply them successively according to the low-rank decomposition form. Therefore, we believe that this optimization step is not very necessary.

So it looks like a dead end; it won't give us any speed gains.

@divine-taco

I ran into an issue with DeepSeek-R1-UD-Q2_K_XL from unsloth/DeepSeek-R1-GGUF

llama_model_load: error loading model: missing tensor 'blk.0.attn_k_b.weight'
llama_model_load_from_file_impl: failed to load model

@fairydreaming
Collaborator Author

fairydreaming commented Jan 31, 2025

I ran into an issue with DeepSeek-R1-UD-Q2_K_XL from unsloth/DeepSeek-R1-GGUF

llama_model_load: error loading model: missing tensor 'blk.0.attn_k_b.weight'
llama_model_load_from_file_impl: failed to load model

As I wrote in the PR:

Note that you need to reconvert the model to use this implementation.

Existing GGUFs won't work, you have to convert and quantize one with the code from this PR.

@danielhanchen
Contributor

Ohh hmm should I re-quantize the ones in https://huggingface.co/unsloth/DeepSeek-R1-GGUF?

@fairydreaming
Collaborator Author

Ohh hmm should I re-quantize the ones in https://huggingface.co/unsloth/DeepSeek-R1-GGUF?

I think it's best to wait a bit until this is stable and merged; it's possible that there will be some changes that would cause them to stop working, and you'd have to repeat the conversion.

@fairydreaming
Collaborator Author

I updated the token generation performance plots in the PR post and also added some new ones showing prompt processing performance. The optimized implementation generally performs WORSE in prompt processing - DeepSeek R1 671B Q4_K_S running on CPU performs only a little worse (~10% with a 4k prompt), but DeepSeek V2 Lite Q8_0 running on an RTX 4090 performs MUCH WORSE (~30% with a 16k prompt), and in both cases the gap widens as the prompt length increases. So it's not all sunshine and rainbows.

Considering all these performance regressions, I think the best course of action would be to put the optimized implementation into a separate model architecture (LLM_ARCH_DEEPSEEK2_MLA or something like this). This would prevent issues with existing GGUFs - they would keep working with the existing architecture. I guess in this case the convert script would have to allow selection of the target model architecture with some option, but that shouldn't be difficult to add. @ggerganov what do you think?

Comment on lines +6406 to +6409
// whether to use n_tokens as the matrix dimension during multiplication or n_head
// n_tokens is higher during prompt processing, this allows to optimize for this case
bool pp_opt = n_tokens > n_head;

Member

I'm not really sure this is the right approach. Haven't followed through the logic yet, but it seems strange to involve so many permutes and conts.

I would first look into improving the FA kernels to support DeepSeek head sizes.

Collaborator Author

I'm not really sure this is the right approach. Haven't followed through the logic yet, but it seems strange to involve so many permutes and conts.

Hmm? I'm quite sure there's only one ggml_cont() call (excluding the ones for CUDA compatibility that already existed in the previous implementation).

As for the permutes, the idea is to multiply by a matrix with a second dimension equal to the number of heads instead of the number of tokens (which is 1 during single-sequence token generation); that increased the performance on a CPU a bit.

So during prompt processing we have 2 permutes and 1 cont. During token generation we have 5 permutes (yeah, that may be a lot) and 0 conts.

Member

Thanks for the correction - I did imagine the extra conts when I saw the permutes.

@ggerganov
Member

Considering all these performance regressions, I think the best course of action would be to put the optimized implementation into a separate model architecture (LLM_ARCH_DEEPSEEK2_MLA or something like this). This would prevent issues with existing GGUFs - they would keep working with the existing architecture. I guess in this case the convert script would have to allow selection of the target model architecture with some option, but that shouldn't be difficult to add. @ggerganov what do you think?

While this is possible to do, I think it has a lot of cons. It will make it difficult for everyone to know which model variation on which hardware to use for better performance. Ideally, we want to have a single implementation that is optimal in all use cases, which can be deprecated at some point for a better alternative. But having 2 alternatives neither of which is optimal is not great.

Also, I'm not sure how this implementation fits with multiple parallel sequences and it introduces extra KV cache logic, specific to this type of arch.

I know there is a lot of interest in the DeepSeek arch right now and such optimizations are really important for people. But I think that we have to keep this work in a PR for a while. It is much more important to fix the software architecture in libllama after which such changes should become easier.

@fairydreaming
Collaborator Author

Considering all these performance regressions, I think the best course of action would be to put the optimized implementation into a separate model architecture (LLM_ARCH_DEEPSEEK2_MLA or something like this). This would prevent issues with existing GGUFs - they would keep working with the existing architecture. I guess in this case the convert script would have to allow selection of the target model architecture with some option, but that shouldn't be difficult to add. @ggerganov what do you think?

While this is possible to do, I think it has a lot of cons. It will make it difficult for everyone to know which model variation on which hardware to use for better performance. Ideally, we want to have a single implementation that is optimal in all use cases, which can be deprecated at some point for a better alternative. But having 2 alternatives neither of which is optimal is not great.

That may not be possible - IMHO an MLA attention implementation that caches "compressed" latent kv representations introduces unavoidable computational overhead due to the need to "decompress" these representations in order to calculate attention scores and attention output. So a "naive" attention implementation that caches full K/V vectors will always use less compute but more memory bandwidth, while caching latent representations results in using more compute but less memory bandwidth. So there can't be a single implementation optimal in all use cases. I'd be happy to be proven wrong about this, though.
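
To put rough numbers on that trade-off (a back-of-the-envelope estimate, using the DeepSeek V3/R1 attention hyper-parameters as generally reported: 128 heads, 128+64 per-head key dims, 128 per-head value dim, kv_lora_rank 512):

$$
\underbrace{128\times(128+64) + 128\times 128}_{\text{full K/V cached per token per layer}} = 40{,}960\ \text{values}
\qquad\text{vs.}\qquad
\underbrace{512 + 64}_{\text{cached latent}} = 576\ \text{values},
$$

roughly a 71x reduction in KV-cache reads/writes, paid for by re-expanding the latents with extra matrix multiplications on every step.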

Also, I'm not sure how this implementation fits with multiple parallel sequences and it introduces extra KV cache logic, specific to this type of arch.

I think there shouldn't be any problems with this, as there is a straightforward direct mapping between the cached representations and full K/V vectors.

I know there is a lot of interest in the DeepSeek arch right now and such optimizations are really important for people. But I think that we have to keep this work in a PR for a while. It is much more important to fix the software architecture in libllama after which such changes should become easier.

That's fine with me. I'm taking a break from this anyway, got bored with tensor shuffling looking for 0.1 t/s more performance. 😉

@saood06

saood06 commented Feb 2, 2025

@fairydreaming
Is there any reason this should cause issues with RPC?
Encountered:

ggml_cuda_compute_forward: cannot compute kqv-31: src0->ne[3] = 1, src1->ne[3] = 2 - fallback to CPU
evaluate_and_capture_cuda_graph: op not supported kqv-31 (MUL_MAT)
[...]\llama.cpp\ggml\src\ggml-cuda\ggml-cuda.cu:2660: GGML_ASSERT(ok) failed

I don't have a quant on hand that I can test without this branch; this branch does give me a nice performance boost for TG at longer contexts, but RPC to CUDA does not work.

@JohannesGaessler
Collaborator

What ggml_cuda_op_mul_mat is doing is beyond me

The function makes input tensors contiguous and presents them as single-batch matrix multiplications to other kernels. The conversion to q8_1 is only done for kernels that use quantized data.

@jukofyork
Contributor

jukofyork commented Feb 16, 2025

What ggml_cuda_op_mul_mat is doing is beyond me

The function makes input tensors contiguous and presents them as single-batch matrix multiplications to other kernels. The conversion to q8_1 is only done for kernels that use quantized data.

I don't know if it's worth looking at yet as this is still a draft PR, but it should be quite easy to replicate the slowdown I saw using Q8_0 via the smaller deepseek-v2-lite model.

I just looked at my logs for running the first 16 chunks of llama-perplexity and it seems to affect prompt processing too:

BF16 for attn_k_b:

llama_model_loader: - type  f32:  361 tensors
llama_model_loader: - type q8_0:  551 tensors
llama_model_loader: - type q5_K:  116 tensors
llama_model_loader: - type q6_K:   58 tensors
llama_model_loader: - type bf16:   61 tensors

perplexity: tokenization took 1190.46 ms
perplexity: calculating perplexity over 561 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 147.01 seconds per pass - ETA 5 hours 43.63 minutes
[1]2.5077,[2]3.2865,[3]2.3722,[4]1.9792,[5]1.7845,[6]1.6442,[7]1.5532,[8]1.4876,[9]1.4382,[10]1.3990,[11]1.3828,[12]1.4122,[13]1.4242,[14]1.5514,[15]1.6815,[16]1.7411,^C

F16 for attn_k_b:

llama_model_loader: - type  f32:  361 tensors
llama_model_loader: - type  f16:   61 tensors
llama_model_loader: - type q8_0:  551 tensors
llama_model_loader: - type q5_K:  116 tensors
llama_model_loader: - type q6_K:   58 tensors

perplexity: calculating perplexity over 561 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 146.17 seconds per pass - ETA 5 hours 41.67 minutes
[1]nan,[2]nan,[3]nan,[4]nan,^C

Q8_0 for attn_k_b:

llama_model_loader: - type  f32:  361 tensors
llama_model_loader: - type q8_0:  612 tensors
llama_model_loader: - type q5_K:  116 tensors
llama_model_loader: - type q6_K:   58 tensors

perplexity: calculating perplexity over 561 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 330.34 seconds per pass - ETA 12 hours 52.15 minutes
[1]2.4944,[2]3.2788,[3]2.3639,[4]1.9761,[5]1.7833,[6]1.6414,[7]1.5508,[8]1.4850,[9]1.4362,[10]1.3974,[11]1.3819,[12]1.4160,[13]1.4275,[14]1.5542,[15]1.6846,[16]1.7440,^C

Which is nearly 2.5x longer for Q8_0.

This isn't quite the same custom quant, but uses F32 for attn_k_b and looks to have similar timing to F16 and BF16:

llama_model_loader: - type  f32:  544 tensors
llama_model_loader: - type q8_0:  429 tensors
llama_model_loader: - type q5_K:   84 tensors
llama_model_loader: - type q6_K:   90 tensors

perplexity: calculating perplexity over 561 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 146.66 seconds per pass - ETA 5 hours 42.80 minutes
[1]2.4987,[2]3.2797,[3]2.3680,[4]1.9800,[5]1.7890,[6]1.6474,[7]1.5545,[8]1.4885

I can't find the logs for the token generation but the speed difference wasn't as bad as this: something like 2.8 tokens per second for Q8_0 and 3.6-3.8 tokens per second for BF16.

The F16 overflow might be specific to deepseek-v3 or deepseek-r1 only, but as well as causing nan in llama-perplexity (as the example above shows), it just writes <think> and then the same word over and over.

I went as far as testing the magnitude of all the attn_k_b tensors and IIRC, none even had a magnitude over 256. It's also multiplying from the compressed KV-cache, which has been passed through a layer_norm before being stored, so I'm at a loss to see what else could be overflowing. I did try adding a ggml_mul_mat_set_prec() call right before all the mul_mat() calls before I narrowed the overflow down to attn_k_b, but from my brief skimming of the code this afternoon, I don't think that looks to be used here?

@JohannesGaessler
Collaborator

Generally speaking, the KQ matrix is susceptible to overflow. So it is preferable to use BF16 or FP32 accumulators for its calculation. However, I was never able to get FP16 matrix multiplication with FP32 accumulation to work with cuBLAS. The documentation says it should be possible but the kernel fails to launch when I try it. Currently the precision argument for KQ is not used for cuBLAS GEMM. For a FP16 K matrix FP16 accumulation is used unconditionally.
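
(As a rough illustration of why KQ overflows in FP16 even when no individual input looks extreme: a dot product over $d$ elements of typical magnitude $a$ can reach about $d\,a^2$; with the MLA path's effective dot length of $d = 512 + 64 = 576$, that crosses the FP16 maximum of 65504 once $a \gtrsim \sqrt{65504/576} \approx 10.7$.)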

@JohannesGaessler
Collaborator

I think I misremembered. After looking at the documentation again I think the problem was that FP16, FP16 -> FP32 GEMM is supported but the performance was so much worse that there was basically no point in using it.

@fairydreaming
Collaborator Author

I investigated possible reasons for poor scaling of token generation when using DeepSeek V3/R1 on dual CPU systems.

My current working hypothesis is that the DeepSeek V3/R1 expert FFN matrices are so small (7168 x 2048) that the overhead of using two CPUs for matrix-vector multiplication during token generation negates almost all performance gains.
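
(A rough back-of-the-envelope supporting this, assuming ~4.5 bits per weight and ~100 GB/s of effective per-socket memory bandwidth: $7168 \times 2048 \approx 14.7$M weights is roughly 8 MB per expert matrix, i.e. on the order of only ~80 µs of memory traffic per matrix-vector product, so even tens of microseconds of cross-socket synchronization per multiplication wipes out the gain from splitting it across two CPUs.)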

I suppose this is the reason why ktransformers folks in their v3.0-preview have two copies of experts in memory, one for each CPU.

I'm going to create a NUMA-aware matrix vector multiplication benchmark to verify this hypothesis.

I thought about possible solutions. One would be to assign the experts in each layer to N sets, where N is equal to the number of CPUs, and then use the top n_expert_used/N experts from each set during inference (see the sketch below). In this solution each CPU would handle only its assigned local experts and there would be no communication overhead. But it can result in non-optimal expert choices; I'm not sure how that would affect the model's output quality.
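
As a minimal illustration of the idea (a standalone sketch, not code from this PR or from llama.cpp; the function and parameter names are made up):

#include <algorithm>
#include <vector>

// Pick the top (n_expert_used / n_nodes) experts independently within each
// NUMA-local subset, so each CPU only ever touches its own experts.
// Experts are assumed to be assigned round-robin: expert e lives on node e % n_nodes.
std::vector<int> pick_experts_numa_local(const std::vector<float> & scores, // router score per expert
                                         int n_nodes,                        // number of CPUs / NUMA nodes
                                         int n_expert_used) {                // total experts to activate
    const int per_node = n_expert_used / n_nodes;
    std::vector<int> chosen;
    for (int node = 0; node < n_nodes; ++node) {
        std::vector<int> local;
        for (int e = node; e < (int) scores.size(); e += n_nodes) {
            local.push_back(e); // experts owned by this node
        }
        const int k = std::min<int>(per_node, (int) local.size());
        // keep only the k best-scoring experts of this node
        std::partial_sort(local.begin(), local.begin() + k, local.end(),
                          [&](int a, int b) { return scores[a] > scores[b]; });
        chosen.insert(chosen.end(), local.begin(), local.begin() + k);
    }
    return chosen; // may differ from the global top-k, as noted above
}

With round-robin assignment the selection stays balanced across sockets, but, as noted above, it is no longer guaranteed to match the global top-k.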

@jukofyork
Contributor

jukofyork commented Feb 19, 2025

    // ######
    if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS        // IQ4_XXS : 4.25 BPW for experts | 345 GiB (4.41 BPW) (-28.4%) | PPL = 3.3850 +/- 0.01877 (+1.51%) | 15.05 tokens per second ( +8.0%)
        || ftype == LLAMA_FTYPE_MOSTLY_Q4_0       // Q4_0_XS : 4.5 BPW for experts  | 365 GiB (4.66 BPW) (-24.4%) | PPL = 3.3944 +/- 0.01885 (+1.95%) | 14.17 tokens per second ( +1.6%)
        || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S     // Q4_K_XS : 4.5 BPW for experts  | 365 GiB (4.66 BPW) (-24.4%) | PPL = 3.3724 +/- 0.01866 (+0.66%) | 18.81 tokens per second (+34.9%)
        || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S     // Q5_K_XS : 5.5 BPW for experts  | 441 GiB (5.63 BPW) ( -8.6%) | PPL = 3.3546 +/- 0.01852 (+0.16%) | 13.84 tokens per second ( -0.7%)
                                                  // -----------------------------------------------------------------------------------------------------------------------------------
        || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M     // Q4_K_XM : ~5.0 BPW for experts | 404 GiB (5.16 BPW) (-16.2%) | PPL = 3.3666 +/- 0.01863 (+0.48%) | 15.82 tokens per second (+13.5%)
        || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL     // Q4_K_XL : ~5.5 BPW for experts | 446 GiB (5.69 BPW) ( -7.6%) | PPL = 3.3614 +/- 0.01858 (+0.33%) | 16.03 tokens per second (+15.0%)
        || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M     // Q5_K_XM : ~6.0 BPW for experts | 483 GiB (6.16 BPW)          | PPL = 3.3504 +/- 0.01849          | 13.94 tokens per second
                                                  // -----------------------------------------------------------------------------------------------------------------------------------
        || ftype == LLAMA_FTYPE_MOSTLY_Q5_1       // Q5_K_XH : 5.0 BPW for experts  | 403 GiB (5.15 BPW)          | PPL = 3.3695 +/- 0.01864 (+0.57%) | 15.90 tokens per second (+14.1%)
        || ftype == LLAMA_FTYPE_MOSTLY_Q6_K       // Q6_K_XH : 6.0 BPW for experts  | 481 GiB (6.15 BPW) (-16.2%) | PPL = 3.3548 +/- 0.01853 (+0.13%) | 13.87 tokens per second ( -0.5%)
                                                  // -----------------------------------------------------------------------------------------------------------------------------------
        ) {                                       // iQ4_K_XS (Q4_K_XS using Bartowski imatrix for experts only)  : PPL = 3.3734 +/- 0.01866 (+0.69%) | 18.76 tokens per second (+34.6%)
        if (name.find("_exps") != std::string::npos) {
            int i_layer;
            if (sscanf(name.c_str(), "blk.%d.", &i_layer) != 1) {
                throw std::runtime_error(format("Failed to determine layer for tensor %s", name.c_str()));
            }
            if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) {
                new_type = GGML_TYPE_IQ4_XS;    // IQ4_XXS
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_0) {
                new_type = GGML_TYPE_Q4_0;      // Q4_0_XS
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S) {
                new_type = GGML_TYPE_Q4_K;      // Q4_K_XS
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) {
                new_type = GGML_TYPE_Q5_K;      // Q5_K_XS
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_1) {
                new_type = (i_layer <= 31 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K);  // Q5_K_XH first and last 29 experts
            }
            else if (ftype == LLAMA_FTYPE_MOSTLY_Q6_K) {
                new_type = (i_layer <= 31 ? GGML_TYPE_Q6_K : GGML_TYPE_Q5_K);  // Q6_K_XH first and last 29 experts
            }
            else if (name.find("ffn_down") != std::string::npos || i_layer <= 10 || i_layer >= 53) {  // First 8 and last 8 experts (ie: 16/58 experts)
                if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
                    new_type = GGML_TYPE_Q5_K;  // Q4_K_XM
                }
                else { 
                    new_type = GGML_TYPE_Q6_K;  // Q4_K_XL & Q5_K_XM
                }
            }
            else {
                if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) {
                    new_type = GGML_TYPE_Q5_K;  // Q5_K_XM
                }
                else {
                    new_type = GGML_TYPE_Q4_K;  // Q4_K_XM & Q4_K_XL
                }
            }
        }
        else if (name.find("attn_kv_a_mqa") != std::string::npos || name.find("attn_k_b") != std::string::npos || name.find("attn_v_b") != std::string::npos) {
            new_type = GGML_TYPE_F32;  // Also used: type_kr = type_kv = GGML_TYPE_F32
        }
        else {
            new_type = GGML_TYPE_Q8_0;
        }
    }
    else
    // ######

I've finished the testing of the custom quants:

  • Using Q4_K for all expert tensors seems a clear winner (the 365GB model would likely fit on 2 x M2 Ultra 192GB too).
  • Using Q4_0 and IQ4_XS gave particularly bad performance in comparison (Q4_0 surprisingly in terms of tokens/s too).
  • Using Bartowski's imatrix for experts made no measurable difference for Q4_K.
  • Mixing different quants for the first/last tensors and bumping up_proj had very little gain (ie: might as well just use Q5_K).
  • Not included, but found that any mixtures involving Q3_K really start to hurt performance badly.

Just running one last test on pure Q4_K to see if using type_kr = type_kv = GGML_TYPE_F16 vs type_kr = type_kv = GGML_TYPE_F32 makes any difference.

EDIT:

// Q4_K_XS : 4.5 BPW for experts  | 365 GiB (4.66 BPW) (-24.4%) | PPL = 3.3724 +/- 0.01866 (+0.66%)
// Q4_K_XS using type_kr = type_kv = GGML_TYPE_F16 & 44 threads : PPL = 3.3728 +/- 0.01866 (+0.67%)

No difference.

@jukofyork
Contributor

Generally speaking, the KQ matrix is susceptible to overflow. So it is preferable to use BF16 or FP32 accumulators for its calculation. However, I was never able to get FP16 matrix multiplication with FP32 accumulation to work with cuBLAS. The documentation says it should be possible but the kernel fails to launch when I try it. Currently the precision argument for KQ is not used for cuBLAS GEMM. For a FP16 K matrix FP16 accumulation is used unconditionally.

I think the safest option then is probably to use F32 as an override for now.

@Thomas-MMJ

In Daniel's 1.58-bit quantization he kept the shared expert at a higher resolution than the routed experts.

@jukofyork
Contributor

        // whether to use n_tokens as the matrix dimension during multiplication or n_head
        // n_tokens is higher during prompt processing, this allows to optimize for this case
        bool pp_opt = n_tokens > n_head;

I think this might be causing some weird problem in the CUDA back-end where a different code-path is taken.

If I leave it as default and use this 127-token test prompt:

> Varis adjusted the noose, its hemp fibers grinding beneath his calluses. “Last chance,” he said, voice like gravel dragged through mud. “Confess, and your soul stays your own.”
>
> Jurl laughed—a wet, gurgling sound. “You’re knee-deep in it, Coldwater. ” The thing inside him twisted the boy’s lips into a grin too wide for his face. “The Great Wolf’s howlin’ again. The Dead’s Gate’s rusted through… ”

Turn this into the opening chapter of a Grimdark trilogy.

The model won't say the actual phrases and it feels "off" - like there is something wrong with the attention mechanism (it sometimes "sort of" says the phrases, but not quite, and often not at all).

If I fix the flag to always be true, eg:

    bool pp_opt = true;

Then all of a sudden the model starts to say those phrases and seems way better at writing in general (I suspect this triggers a different code-path - possibly something to do with the matrix-vector vs matrix-matrix stuff I remember seeing the other day?)

If I fix the flag to always be false eg:

    bool pp_opt = false;

and then run llama-perplexity, I get almost (but not quite) the same PPL as the default (ie: where n_tokens > n_head --> 512 > 128 --> pp_opt = true always), so I think the problematic code-path is triggered by a batch size of exactly 1 and is not related to the actual series of ggml_permute and ggml_cont this triggers in llama.cpp::build_deepseek2().

So I thought I'd try running llama-perplexity with bool pp_opt = false to test this idea, and weirdly:

perplexity: 607.12 seconds per pass - ETA 23 hours 39.13 minutes
[1]2.4873,[2]3.2776,[3]2.3693,[4]1.9780

It actually seems to get better PPL for these first few values (sorry no way I can run the 24h to completion) and the difference is almost the size of the error bar from the full PPL calculated over the default setting.

I don't know how else to help diagnose what's going on 😕

Could it be that the 127-token test prompt is not a multiple of 32 and when it gets permuted it's causing some problem there?

@slaren
Member

slaren commented Feb 19, 2025

@jukofyork If you think that some operation is producing wrong results with CUDA, an easy way to test that would be to add a test case to test-backend-ops. It should be fairly straightforward: you would need to add a test case for the relevant operations, with the same shapes and types as are used with this model, in make_test_cases_eval.

@jukofyork
Contributor

@slaren @JohannesGaessler @fairydreaming

I've got a little further now and think it's the same overflow problem that affected float16 tensors - just with pp_opt set it must cause more severe problems and/or some kind of catastrophic-cancellation due to the rows/columns being swapped.

Both the existing attention implementations use set_prec(cur, GGML_PREC_F32) here:

    if (cparams.flash_attn) {
        GGML_UNUSED(model);
        GGML_UNUSED(n_ctx);

        // split cached v into n_head heads (not transposed)
        struct ggml_tensor * v =
            ggml_view_3d(ctx, kv.v_l[il],
                    n_embd_head_v, n_kv, n_head_kv,
                    ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
                    ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
                    0);
        cb(v, "v", il);

        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);

        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

        cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
    } else {
        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
        cb(kq, "kq", il);

        // note: this op tends to require high floating point range
        //       while for some models F16 is enough, for others it is not, so we default to F32 here
        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

and by trial and error with --temp 0.0, I've found that these 3 also need to be upped for the MLA implementation:

                struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2);
                ggml_mul_mat_set_prec(kq_nope, GGML_PREC_F32); // ***
                cb(kq_nope, "kq_nope", il);
                struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
                ggml_mul_mat_set_prec(kq_pe, GGML_PREC_F32); // ***
                cb(kq_pe, "kq_pe", il);
                struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
                ggml_mul_mat_set_prec(kqv_compressed, GGML_PREC_F32); // ***
                cb(kqv_compressed, "kqv_compressed", il);

but this needs to be compiled with -DGGML_CUDA_FORCE_CUBLAS=1 to allow those to be set.

This was with bool pp_opt = true; fixed, and it still gives different output with bool pp_opt = n_tokens > n_head, but it's not as obviously broken as it was before.

When it comes times to merge the official MLA implementation, then I think this needs to be tested more thoroughly than I can do.

@JohannesGaessler
Collaborator

but this needs to be compiled with -DGGML_CUDA_FORCE_CUBLAS=1 to allow those to be set.

The MMQ kernels always use FP32 for the accumulators; if there are numerical issues they must be due to extremal values in the inputs, since FP16 is used for the scales of the quantized data.

@slaren
Member

slaren commented Feb 19, 2025

ggml_cuda_mul_mat_batched_cublas always converts src1 to F16 regardless of the value of ggml_mul_mat_set_prec, and that may be a problem. This function is used in most KV operations.

We should avoid these conversions regardless, because the memory required for the intermediate copies is too high with big contexts. However, that would require a f16 x f32 -> f32 matrix multiplication, and I am not sure that cuBLAS can do that.

@jukofyork
Contributor

I think just changing bool pp_opt = n_tokens > n_head; to bool pp_opt = true; has fixed whatever I was getting. It possibly runs a tiny bit slower though (~0.25 tokens/second).

Here's my full script that merges the PRs and applies all the hacks (including the commented out ones I'm not using):

#!/bin/bash

function safe_sed() {
    local file=$1
    local pattern=$2
    local replacement=$3

    # Check if pattern exists
    if ! sed -n "s/${pattern}/${replacement}/p" "$file" | grep -q .; then
        echo "Error: Pattern not found in $file: $pattern"
        return 1
    fi

    # Create backup
    cp "$file" "$file.bak"

    # Perform the replacement
    sed -i "s/${pattern}/${replacement}/g" "$file"

    # Show diff
    echo "Changes in $file:"
    diff "$file.bak" "$file"

    # Clean up
    rm "$file.bak"

    echo "Successfully replaced in $file"
    echo "-------------------"
}

rm -rf llama.cpp

git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
git remote add fairydreaming https://github.com/fairydreaming/llama.cpp.git
git remote add sl https://github.com/ggerganov/llama.cpp.git
git fetch fairydreaming
git fetch sl
git checkout -b merged_features

# For MLA compressed KV-cache
git merge --no-edit fairydreaming/deepseek2-mla-exp

# To save having to wait ages for the warmup (~2.5x less wait)
git merge --no-edit fairydreaming/experts-warmup

# To allow the use of --override-tensor exps=CPU (and --override-tensor attn_kv_b=CPU)
git merge --no-edit sl/sl/custom-tensor-offload

# Allocate the minimum possible for the unused KV-cache.
safe_sed "src/llama-kv-cache.cpp" "ggml_tensor \* k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa\*kv_size);" "ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, 1);"
safe_sed "src/llama-kv-cache.cpp" "ggml_tensor \* v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa\*kv_size);" "ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, 1);"

# Don't offload to GPU.
safe_sed "ggml/src/ggml-cuda/ggml-cuda.cu" "const int min_batch_size = 32" "const int min_batch_size = 9999999"

safe_sed "src/llama.cpp" "bool pp_opt = n_tokens > n_head;" "bool pp_opt = true;"

#safe_sed "src/llama.cpp" "kv_cache, q_nope2);" "kv_cache, q_nope2);\n                ggml_mul_mat_set_prec(kq_nope, GGML_PREC_F32);"
#safe_sed "src/llama.cpp" "kr_cache, q_pe);" "kr_cache, q_pe);\n                ggml_mul_mat_set_prec(kq_pe, GGML_PREC_F32);"
#safe_sed "src/llama.cpp" "kv_cache_trans, kq);" "kv_cache_trans, kq);\n                ggml_mul_mat_set_prec(kqv_compressed, GGML_PREC_F32);"

# Use float32 for the compressed KV-cache.
#safe_sed "src/llama-kv-cache.h" "ggml_type type_kr = GGML_TYPE_F16" "ggml_type type_kr = GGML_TYPE_F32"
#safe_sed "src/llama-kv-cache.h" "ggml_type type_kv = GGML_TYPE_F16" "ggml_type type_kv = GGML_TYPE_F32"

# Hack llama_tensor_get_type() to use our chosen custom quant.
safe_sed "src/llama-quant.cpp" \
  "llama_tensor_get_type(qs, new_type, tensor, ftype);" \
  "name.find(\"_exps\") != std::string::npos ? name.find(\"ffn_down\") != std::string::npos ? GGML_TYPE_Q6_K : GGML_TYPE_Q5_K : GGML_TYPE_BF16;"

# Must set GGML_SCHED_MAX_COPIES=1 for use with --override-tensor exps=CPU
#cmake -B build -DGGML_CUDA=ON -DGGML_NATIVE=ON -DGGML_SCHED_MAX_COPIES=1 -DGGML_RPC=ON
#cmake -B build -DGGML_CUDA=ON -DGGML_NATIVE=ON -DGGML_SCHED_MAX_COPIES=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=9999999
cmake -B build -DGGML_CUDA=ON -DGGML_NATIVE=ON -DGGML_SCHED_MAX_COPIES=1
#cmake -B build -DGGML_CUDA=ON -DGGML_NATIVE=ON -DGGML_SCHED_MAX_COPIES=1 -DGGML_CUDA_FORCE_CUBLAS=1

cmake --build build --config Release -- -j 44

Which gets run using:

numactl --interleave=all ./llama.cpp/build/bin/llama-server --host 192.168.1.111 --port 8080 \
  --model ./DeepSeek-R1-mla-Q5_K_XL.gguf --chat-template deepseek3 --alias "DeepSeek-R1-mla-Q5_K_XL" --ctx_size 32768 \
  --n-gpu-layers 99 --override-tensor exps=CPU --override-tensor attn_kv_b=CPU --numa distribute \
  --temp 0.6 --min-p 0.0 --top-p 1.0 --top-k 0 --threads 30 --threads-batch 44

The quant is in the script on 1 line:

    // ######
    if (name.find("_exps") != std::string::npos) {
        if (name.find("ffn_down") != std::string::npos) {
            new_type = GGML_TYPE_Q6_K;
        }
        else {
            new_type = GGML_TYPE_Q5_K;
        }
    }
    else {
        new_type = GGML_TYPE_BF16;
    }
    else
    // ######

and gave this on wiki.test.raw:

Q5_K_XL : 479.64 GiB (6.13 BPW) | PPL = 3.3499 +/- 0.01849 | 19.72 tokens per second

I can't see the thought tags on openrouter, but this custom BF16/Q6_K/Q5_K appears to be working as well as any they are hosting now (the official deepseek openrouter just seems to always be down so I can't test against them), and gives similar responses.

@JohannesGaessler
Collaborator

We should avoid these conversions regardless, because the memory required for the intermediate copies is too high with big contexts. However, that would require a f16 x f32 -> f32 matrix multiplication, and I am not sure that cuBLAS can do that.

The PTX documentation has a table with the data types that are supported by tensor cores. In all cases the input matrices must have the same data type. So if the KV cache stays FP16 the activations must be converted to FP16. Alternative approaches would be to use BF16 which has the same numerical range as FP32 or to convert the FP16 data to TF32 in SRAM (this is to my knowledge not supported by cuBLAS, I did not check CUTLASS). Both BF16 and TF32 need Ampere or newer. In terms of speed FP16, FP16 -> FP16 > FP16, FP16 -> FP32 > TF32, TF32 -> FP32.

@jukofyork
Contributor

A quick update on the F16 overflow issue:

I've found that fixing bool pp_opt = true; (which essentially removes all the extra perms), and keeping only the _a and _b tensors set as F16, eg:

"static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {\n\
    const std::string name = ggml_get_name(tensor);\n\
    if (name.find(\"_exps\") != std::string::npos) {\n\
        return name.find(\"ffn_down\") != std::string::npos ? GGML_TYPE_Q6_K : GGML_TYPE_Q5_K;\n\
    } else if (name.find(\"attn_\") != std::string::npos && name.find(\"_output\") == std::string::npos) {\n\
        return GGML_TYPE_BF16;\n\
    }\n\
    return GGML_TYPE_Q8_0;\n\

It does somewhat work, and no longer gives nan for perplexity or repeats the same word over and over during token generation.

It works quite a bit faster (3.6 tokens/s vs 3.1-3.2 tokens/s) compared to using the same custom quant with BF16 or F32 (probably by not having to do lots of up/down casting), but it still isn't working 100% correctly, as the perplexity run shows:

[1]8.0332,[2]10.0018,[3]8.4663,[4]7.7059,[5]6.9553,[6]6.6773,[7]6.4792,[8]6.8003,[9]6.8766,[10]6.7664,[11]6.7516,[12]7.0069,^C

These should be [1]2.5 and so on.

I've tried using -DGGML_CUDA_FORCE_CUBLAS=1 and then using ggml_mul_mat_set_prec(XXX, GGML_PREC_F32) and this didn't seem to help (one actually made the perplexity go up into the 80s!).

Hopefully after the attention refactoring is over and MLA gets looked at again, some of these problems can be ironed out.

@jukofyork
Contributor

jukofyork commented Feb 22, 2025

We should avoid these conversions regardless, because the memory required for the intermediate copies is too high with big contexts. However, that would require a f16 x f32 -> f32 matrix multiplication, and I am not sure that cuBLAS can do that.

The PTX documentation has a table with the data types that are supported by tensor cores. In all cases the input matrices must have the same data type. So if the KV cache stays FP16 the activations must be converted to FP16. Alternative approaches would be to use BF16 which has the same numerical range as FP32 or to convert the FP16 data to TF32 in SRAM (this is to my knowledge not supported by cuBLAS, I did not check CUTLASS). Both BF16 and TF32 need Ampere or newer. In terms of speed FP16, FP16 -> FP16 > FP16, FP16 -> FP32 > TF32, TF32 -> FP32.

I don't think BF16 will save us here. I've tried both these variants in ggml_cuda_op_mul_mat_cublas:

1. BF16 x BF16 --> F32 using F32 compute type

    if (src0->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_BF16) {

        ggml_cuda_pool_alloc<nv_bfloat16> src0_as_bf16(ctx.pool(id));
        if (src0->type != GGML_TYPE_BF16) {
            const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src0->type);
            GGML_ASSERT(to_bf16_cuda != nullptr);
            src0_as_bf16.alloc(row_diff*ne00);
            to_bf16_cuda(src0_dd_i, src0_as_bf16.get(), row_diff*ne00, stream);
        }
        const nv_bfloat16 * src0_ptr = src0->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src0_dd_i : src0_as_bf16.get();

        ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
        if (src1->type != GGML_TYPE_BF16) {
            const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
            GGML_ASSERT(to_bf16_cuda != nullptr);
            src1_as_bf16.alloc(src1_ncols*ne10);
            to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), src1_ncols*ne10, stream);
        }
        const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get();

        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));

        const float alpha = 1.0f;
        const float beta = 0.0f;
        CUBLAS_CHECK(
            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
                    row_diff, src1_ncols, ne10,
                    &alpha, src0_ptr,  CUDA_R_16BF, ne00,
                            src1_ptr,  CUDA_R_16BF, ne10,
                    &beta,   dst_dd_i, CUDA_R_32F, ldc,
                    CUBLAS_COMPUTE_32F,
                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    }

which has to have this conversion type added:

to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type);
to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32:
            return convert_unary_cuda<float>;
        default:
            return nullptr;
    }
}

but as @JohannesGaessler found:

I think I misremembered. After looking at the documentation again I think the problem was that FP16, FP16 -> FP32 GEMM is supported but the performance was so much worse that there was basically no point in using it.

It seems to be a complete waste of time and appears to just upcast the BF16 to F32 internally.

2. F32 x F32 --> F32 using 32F_FAST_16BF compute type:

    if (src0->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_BF16) {

        ggml_cuda_pool_alloc<float> src0_as_f32(ctx.pool(id));
        if (src0->type != GGML_TYPE_F32) {
            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
            GGML_ASSERT(to_fp32_cuda != nullptr);
            src0_as_f32.alloc(row_diff*ne00);
            to_fp32_cuda(src0_dd_i, src0_as_f32.get(), row_diff*ne00, stream);
        }
        const float * src0_ptr = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_as_f32.get();

        ggml_cuda_pool_alloc<float> src1_as_f32(ctx.pool(id));
        if (src1->type != GGML_TYPE_F32) {
            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
            GGML_ASSERT(to_fp32_cuda != nullptr);
            src1_as_f32.alloc(src1_ncols*ne10);
            to_fp32_cuda(src1_ddf_i, src1_as_f32.get(), src1_ncols*ne10, stream);
        }
        const float * src1_ptr = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_as_f32.get();

        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));

        const float alpha = 1.0f;
        const float beta = 0.0f;
        CUBLAS_CHECK(
            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
                    row_diff, src1_ncols, ne10,
                    &alpha, src0_ptr,  CUDA_R_32F, ne00,
                            src1_ptr,  CUDA_R_32F, ne10,
                    &beta,   dst_dd_i, CUDA_R_32F, ldc,
                    CUBLAS_COMPUTE_32F_FAST_16BF,
                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    }

This probably is a little faster, but most of the gain is lost to the stupid BF16 --> F32 --> BF16 conversions.

But more importantly: using 16BF as the compute type now trades the old overflow errors we had for F16 for low-order bit summation errors, and to solve this we would need to implement a custom GEMM function using Kahan summation or similar, which would likely negate the gains of using tensor cores anyway.
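
(For reference, this is the kind of compensated accumulation that would be needed - a generic Kahan-summation dot product, not tied to ggml or cuBLAS; doing this per element is exactly what forfeits the tensor-core throughput:)

#include <cstddef>

// Kahan-compensated dot product: `c` carries the low-order bits that a
// plain `sum += x` would drop, so long sums lose far less precision.
float dot_kahan(const float * a, const float * b, size_t n) {
    float sum = 0.0f;
    float c   = 0.0f; // running compensation
    for (size_t i = 0; i < n; ++i) {
        const float x = a[i] * b[i] - c;
        const float t = sum + x; // low-order bits of x are lost here...
        c   = (t - sum) - x;     // ...and recovered here
        sum = t;
    }
    return sum;
}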


These two look to be the only options for using BF16:

https://docs.nvidia.com/cuda/cublas/#cublasgemmex

(I assume the lack of BF16 x BF16 --> BF16 using BF16 compute type is because they know it would suffer badly from low-order bit summation errors and be pretty much useless?)


I also found this:

https://docs.nvidia.com/cuda/cublas/#gemm-algorithms-numerical-behavior

For the routines cublasgemmEx() and cublasGemmEx(), when the compute type is greater than the output type, the sum of the split chunks can potentially lead to some intermediate overflows thus producing a final resulting matrix with some overflows. Those overflows might not have occurred if all the dot products had been accumulated in the compute type before being converted at the end in the output type. This computation side-effect can be easily exposed when the computeType is CUDA_R_32F and Atype, Btype and Ctype are in CUDA_R_16F.

which explains why the existing ggml_cuda_op_mul_mat_cublas code that uses F16 x F16 --> F32:

        if (GGML_CUDA_CC_IS_CDNA(compute_capability)) {
            const float alpha = 1.0f;
            const float beta = 0.0f;
            CUBLAS_CHECK(
                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
                        row_diff, src1_ncols, ne10,
                        &alpha, src0_ptr,  CUDA_R_16F, ne00,
                                src1_ptr,  CUDA_R_16F, ne10,
                        &beta,   dst_dd_i, CUDA_R_32F, ldc,
                        CUBLAS_COMPUTE_32F,
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
        } else {
            ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);

            const half alpha_f16 = 1.0f;
            const half beta_f16 = 0.0f;

            CUBLAS_CHECK(
                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
                        row_diff, src1_ncols, ne10,
                        &alpha_f16, src0_ptr,      CUDA_R_16F, ne00,
                                    src1_ptr,      CUDA_R_16F, ne10,
                        &beta_f16,  dst_f16.get(), CUDA_R_16F, ldc,
                        CUBLAS_COMPUTE_16F,
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));

            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
            to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
        }

can cause the weird/unpredictable behaviour I saw with the code in this PR related to the pp_opt flag.

@slaren
Member

slaren commented Feb 22, 2025

I suppose we could use BF16 for all intermediate results and the KV, but that would require significant changes to ggml. Is the performance of BF16, BF16 -> BF16 good, or does the F32 accumulator cause it to be significantly slower than F16?

@JohannesGaessler
Collaborator

I suppose we could use BF16 for all intermediate results and the KV, but that would require significant changes to ggml

There are issues with using BF16 outside of GEMM though, particularly when it comes to support for BF16 instructions. BF16 softmax would for example require Hopper or newer if it is to be done without conversions to FP32.

An alternative approach that we could investigate is normalizing the activations when converting them to FP16 and applying the inverse scale when converting the output back to FP32.
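
(One way to write that idea down: since the product is linear in the activations, with $s = \max_{ij} |B_{ij}| / 65504$ we have $A B = s \cdot \big( A \cdot \mathrm{fp16}(B / s) \big)$, so dividing the activations by a per-tensor scale before the FP16 conversion and multiplying the FP32 output by the same scale afterwards keeps the FP16 inputs and partial sums in range without changing the result beyond rounding.)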

Is the performance of BF16, BF16 -> BF16 good, or does the F32 accumulator cause it to be significantly slower than F16?

I don't know the ratios for 16 vs 32 bit accumulation but FP16 and BF16 tensor cores are I think the same speed.

@jukofyork
Contributor

jukofyork commented Feb 22, 2025

An alternative approach that we could investigate is normalizing the activations when converting them to FP16 and applying the inverse scale when converting the output back to FP32.

Yeah, I'm going to try this tomorrow if I get a chance. I also think it might be possible to scale down the tensors and use the product of the inverse scales when converting back to FP32:

I noticed in the past that the Cohere models look like they've had all their tensors scaled to fit within the numerical precision of FP16, as changing the logit_scale parameter even by a minuscule amount causes numerical problems (I was using this as a "training time temperature scaling" to trick the model into thinking its output distribution was sharper than it really was during fine-tuning, to increase the entropy of the outputs of the new model).

The name:

  "_name_or_path": "/home/ahmet_cohere_com/HF_Final_weight_tie"

and the different logit_scale values for the different models also hint at this.

It's possible they took a sample of activations to do this though.

@jukofyork
Contributor

Just saw this linked on Reddit:

https://github.com/deepseek-ai/FlashMLA

I'm still unsure what this thread is trying to say:

deepseek-ai/FlashMLA#16

For MLA the q absorb and o absorb steps can be done separately from the attention. e.g. q: [bs, num_q_heads, 128 (head dim)] -> q: [bs, num_q_heads, 512 (latent dim)] concat q_nope: [bs, num_q_heads, 64)] the output of MLA will be [bs, num_q_heads, 512)], which can then be down_projed independently.

It also links this paper:

https://arxiv.org/abs/2502.14837

@jukofyork
Contributor

jukofyork commented Mar 2, 2025

I've found a problem with the compute buffer calculation in this PR:

It looks like it builds the compute graph with bs=1 and (twice) with bs=max-context to calculate the compute buffer sizes, but seems to assume that the bs=1 case will always be the smaller of the two.

For this particular PR this assumption is wrong (due to the extra permutes for bs=1) and it will underestimate the compute buffer requirements and crash exactly 2/3rds of the way into the context:

llama_init_from_model: pipeline parallelism enabled (n_copies=1)
llama_init_from_model:      CUDA0 compute buffer size = 16518.88 MiB
llama_init_from_model:      CUDA1 compute buffer size = 16518.88 MiB
llama_init_from_model:  CUDA_Host compute buffer size =   103.63 MiB
llama_init_from_model: graph nodes  = 5208 (with bs=512), 5330 (with bs=1)
llama_init_from_model: graph splits = 119
.
.
.
llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:73: CUDA error: an internal operation failed
  current device: 0, in function ggml_cuda_op_mul_mat_cublas at llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:1268
  cublasSgemm_v2(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, row_diff, src1_ncols, ne10, &alpha, src0_ddf_i, ne00, src1_ddf1_i, ne10, &beta, dst_dd_i, ldc)
CUDA error

Not sure if this is the best solution, but by detecting the case where it's doing these runs (ie: n_kv == kv_self.size) you can set pp_opt to false to make it see the larger requirements:

bool pp_opt = (n_tokens > n_head && n_kv < static_cast<int32_t>(kv_self.size));

and it seems to work:

llama_init_from_model:  CUDA_Host  output buffer size =     0.49 MiB
llama_init_from_model: pipeline parallelism enabled (n_copies=1)
llama_init_from_model:      CUDA0 compute buffer size = 24781.75 MiB
llama_init_from_model:      CUDA1 compute buffer size = 24781.75 MiB
llama_init_from_model:  CUDA_Host compute buffer size =   143.25 MiB
llama_init_from_model: graph nodes  = 5330
llama_init_from_model: graph splits = 119

although it doesn't print the separate 5208 (with bs=512), 5330 (with bs=1) line.

@jukofyork
Contributor

jukofyork commented Mar 2, 2025

No, it still crashes.

Fixing pp_opt=true or pp_opt=false works, but the compute buffer estimation doesn't seem to expect anyone to try doing this sort of dynamic permutation.

I think ultimately it should probably not be the compute graph that detects/converts batches of vectors into matrices like this, but rather the backends themselves or whatever hands the work off to the backends?

@jukofyork
Contributor

jukofyork commented Mar 3, 2025

It's not to do with the bs=1 stuff anyway and looks to just be a bug.

I've completely rewritten all the code which used the permutes to swap the batch dimension in/out, and instead used views, since these weren't really batched matrix multiplies:

struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0);
cb(wk_b, "wk_b", il);

struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0);
cb(wv_b, "wv_b", il);

q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
cb(q_nope, "q_nope_perm", il);

q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
cb(q_pe, "q_pe_perm", il);

// cont so that kq_view will have last 2 dims in correct order
if (n_tokens > 1) {
    q_nope = ggml_cont(ctx0, q_nope);
    cb(q_nope, "q_nope_cont", il);

    q_pe = ggml_cont(ctx0, q_pe);
    cb(q_pe, "q_pe_cont", il);
}

struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope);
cb(q_nope2, "q_nope2", il);

struct ggml_tensor * q_nope2_view = ggml_view_2d(ctx0, q_nope2, kv_lora_rank, n_tokens * n_head, ggml_row_size(q_nope2->type, kv_lora_rank), 0);
cb(q_nope2_view, "q_nope2_view", il);

struct ggml_tensor * q_pe_view = ggml_view_2d(ctx0, q_pe, n_embd_head_qk_rope, n_tokens * n_head, ggml_row_size(q_pe->type, n_embd_head_qk_rope), 0);
cb(q_pe_view, "q_pe_view", il);

struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2_view);
//ggml_mul_mat_set_prec(kq_nope, GGML_PREC_F32);
cb(kq_nope, "kq_nope", il);

struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe_view);
//ggml_mul_mat_set_prec(kq_pe, GGML_PREC_F32);
cb(kq_pe, "kq_pe", il);

struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
cb(kq, "kq", il);

struct ggml_tensor * kq_view = ggml_view_3d(ctx0, kq, n_kv, n_tokens, n_head, ggml_row_size(kq->type, n_kv), ggml_row_size(kq->type, n_kv * n_tokens), 0);
cb(kq_view, "kq_view", il);

struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq_view, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq_soft_max, "kq_soft_max", il);

struct ggml_tensor * kq_soft_max_view = ggml_view_2d(ctx0, kq_soft_max, n_kv, n_tokens * n_head, ggml_row_size(kq_soft_max->type, n_kv), 0);
cb(kq_soft_max_view, "kq_soft_max_view", il);

struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_soft_max_view);
//ggml_mul_mat_set_prec(kqv_compressed, GGML_PREC_F32);
cb(kqv_compressed, "kqv_compressed,", il);

struct ggml_tensor * kqv_compressed_view = ggml_view_3d(ctx0, kqv_compressed, kv_lora_rank, n_tokens, n_head, ggml_row_size(kqv_compressed->type, kv_lora_rank), ggml_row_size(kqv_compressed->type, kv_lora_rank * n_tokens), 0);
cb(kqv_compressed_view, "kqv_compressed_view", il);

struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed_view);
cb(kqv, "kqv", il);

if (n_tokens > 1) {
    kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
    cb(kqv, "kqv_perm", il);
}

cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0);
cb(cur, "kqv_2d", il);

ggml_build_forward_expand(gf, cur);

cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
cb(cur, "kqv_out", il);

With or without the if (n_tokens > 1) cases in the new code.

The new code runs about the same on CUDA (likely because the old code ultimately hands off the 2D x 3D multiplies to cublasGemmBatchedEx() in the backend), but it does now seem to use split-mode row properly (which doesn't work with 3D tensors AFAIK) and may be quite a bit faster for other backends due to not having to run 128 batches for the critical compressed KV-cache multiplies (it just treats them as huge 2D x 2D multiplies instead).

I think I can see how to use the flash attention code now too: by storing the 576 "NOPE+ROPE" vectors in the KV-cache, then setting it up as standard Multi-Query Attention (MQA) and then either chopping off the last 64 elements of the output or adding 64 zeroed vectors to wv_b (ie: pad the 512 dimension to 576 with zeros so when it gets multiplied back out the extra 64 elements from the flash attention output get nullified). It will waste a tiny bit of compute taking the linear weighted sums of these last 64 elements in the flash attention code, but not much.
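
(Written out, the zero-padding trick is just: if the per-head flash-attention output is split as $o = [\,o_{512};\ o_{64}\,]$, then $[\,W^{UV}\ \ 0\,]\,o = W^{UV} o_{512}$, so the 64 extra RoPE components carried through the MQA-style attention are annihilated by the zero columns at the cost of a little wasted compute.)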

@jukofyork
Contributor

jukofyork commented Mar 6, 2025

So I've got the MLA stuff working with flash attention now:

    if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v && model->arch != LLM_ARCH_DEEPSEEK2) {
        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
        params.flash_attn = false;
    }
.
.
.
                if (cparams.flash_attn) {

                    struct ggml_tensor * k_cache = ggml_view_3d(ctx0, kv_self.k_l[il], kv_lora_rank + n_embd_head_qk_rope, n_kv, 1, ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), ggml_row_size(kv_self.k_l[il]->type, (kv_lora_rank + n_embd_head_qk_rope) * n_kv), 0);
                    cb(k_cache, "k_cache", il);

                    struct ggml_tensor * v_cache = ggml_view_3d(ctx0, kv_self.k_l[il], kv_lora_rank + n_embd_head_qk_rope, n_kv, 1, ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), ggml_row_size(kv_self.k_l[il]->type, (kv_lora_rank + n_embd_head_qk_rope) * n_kv), 0);
                    cb(v_cache, "v_cache", il);

                    struct ggml_tensor * kqv_compressed = ggml_flash_attn_ext(ctx0, q_compressed_concat, k_cache, v_cache, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.0f);
                    cb(kqv_compressed, "kqv_compressed_flash", il);

                    //ggml_flash_attn_ext_set_prec(kqv_compressed, GGML_PREC_F32);

                    kqv_compressed = ggml_cont(ctx0, ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3));
                    cb(kqv_compressed, "kqv_compressed_perm", il);

                    kqv_compressed_view = ggml_view_3d(ctx0, kqv_compressed, kv_lora_rank, n_tokens, n_head, ggml_row_size(kqv_compressed->type, kv_lora_rank+ n_embd_head_qk_rope), ggml_row_size(kqv_compressed->type, (kv_lora_rank + n_embd_head_qk_rope) * n_tokens), 0);
                    cb(kqv_compressed_view, "kqv_compressed_view", il);

                } else {

                    struct ggml_tensor * k_cache = ggml_view_2d(ctx0, kv_self.k_l[il], kv_lora_rank + n_embd_head_qk_rope, n_kv, ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
                    cb(k_cache, "k_cache", il);

                    struct ggml_tensor * v_cache_trans = ggml_view_2d(ctx0, kv_self.v_l[il], n_kv, kv_lora_rank, ggml_row_size(kv_self.v_l[il]->type, kv_self.size), 0);
                    cb(v_cache_trans, "v_cache_trans", il);

                    struct ggml_tensor * q_view = ggml_view_2d(ctx0, q_compressed_concat, kv_lora_rank + n_embd_head_qk_rope, n_tokens * n_head, ggml_row_size(q_nope_absorbed->type, kv_lora_rank + n_embd_head_qk_rope), 0);
                    cb(q_view, "q_view", il);

                    struct ggml_tensor * kq = ggml_mul_mat(ctx0, k_cache, q_view);
                    //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
                    cb(kq, "kq", il);

                    struct ggml_tensor * kq_view = ggml_view_3d(ctx0, kq, n_kv, n_tokens, n_head, ggml_row_size(kq->type, n_kv), ggml_row_size(kq->type, n_kv * n_tokens), 0);
                    cb(kq_view, "kq_view", il);

                    struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq_view, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
                    cb(kq_soft_max, "kq_soft_max", il);

                    struct ggml_tensor * kq_soft_max_view = ggml_view_2d(ctx0, kq_soft_max, n_kv, n_tokens * n_head, ggml_row_size(kq_soft_max->type, n_kv), 0);
                    cb(kq_soft_max_view, "kq_soft_max_view", il);

                    struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, v_cache_trans, kq_soft_max_view);
                    //ggml_mul_mat_set_prec(kqv_compressed, GGML_PREC_F32);
                    cb(kqv_compressed, "kqv_compressed", il);

                    kqv_compressed_view = ggml_view_3d(ctx0, kqv_compressed, kv_lora_rank, n_tokens, n_head, ggml_row_size(kqv_compressed->type, kv_lora_rank), ggml_row_size(kqv_compressed->type, kv_lora_rank * n_tokens), 0);
                    cb(kqv_compressed_view, "kqv_compressed_view", il);
                }
The compute buffers are now:

llama_init_from_model:      CUDA0 compute buffer size =   318.00 MiB
llama_init_from_model:      CUDA1 compute buffer size =   317.75 MiB

These were 16GB each before for the same context and batch size.

I haven't added it yet, but the other advantage is that there is no need to store the transposed version of c: the code above just takes v_cache as another view of kv_self.k_l[il] (v_cache = ggml_view_3d(ctx0, kv_self.k_l[il], ...)), so this also halves the storage needed for the context!
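As a rough sanity check of the saving (assuming DeepSeek V3's kv_lora_rank = 512, n_embd_head_qk_rope = 64, 61 layers and an F16 cache): storing only the compressed latent plus the RoPE'd key part costs about $(512 + 64) \cdot 61 \cdot 2 \approx 70$ KB per token, whereas also keeping the transposed copy of the 512-dim latent costs about $(512 + 64 + 512) \cdot 61 \cdot 2 \approx 133$ KB per token.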

It looks to be around 15% slower for prompt processing and 2-5% slower for token generation (I'm running perplexity to compare now, and the PPL looks almost identical so far). These slowdowns might get fixed after I go through all the code again to see if I can avoid the kqv_compressed = ggml_cont(ctx0, ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3)) call, e.g. by permuting the copy of wv_b.

It also turned out that the "move elision" stuff was pointless in C++, so with that removed it became a lot clearer how to adapt the code for flash attention.

It's all a bit of a mess ATM and I will make a proper PR after some more testing, but I need some input from @JohannesGaessler first:

I've only managed to get the vector kernel working, via:

void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV  = dst;
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * V    = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    ggml_cuda_set_device(ctx.device);
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

    // On AMD the tile kernels perform poorly, use the vec kernel instead:
    if (true || cc >= GGML_CUDA_CC_OFFSET_AMD) {
        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        }
        return;
    }
.
.
.

and:

static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
.
.
.
FATTN_VEC_F16_CASE(576, GGML_TYPE_F16, GGML_TYPE_F16)
static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
.
.
.
FATTN_VEC_F32_CASE(576, GGML_TYPE_F16, GGML_TYPE_F16)

I did try to get the tile kernels working via the 576 overrides, but I ran into more and more problems and warnings that I had to bypass, and ultimately couldn't get it to run...

BUT: does it make any sense to use the tile kernels when the tiles would be 576 wide, or is it almost certain that the vector kernels will work as well or better? AFAIK the whole point of the tiles is to keep the working set in fast on-chip memory rather than HBM, and huge tiles like this aren't likely to fit well anyway?

If so, then I could turn if (true || cc >= GGML_CUDA_CC_OFFSET_AMD) { into a test for D > 512 and force the vector kernel for this?
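For what it's worth, a minimal sketch of that dispatch change (not a tested patch; the D > 512 threshold and reading the head size from K->ne[0] are assumptions based on the snippet above):

    // inside ggml_cuda_flash_attn_ext(), after K, cc and prec have been set up:
    const int64_t D = K->ne[0]; // head size of K (576 for the MLA NOPE+ROPE layout)

    // Oversized heads (no tile kernel instantiated for D > 512) and AMD both
    // take the vector kernels:
    if (D > 512 || cc >= GGML_CUDA_CC_OFFSET_AMD) {
        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        }
        return;
    }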

@jukofyork
Contributor

I've tidied it up a bit and made a proper PR: #12227

@fairydreaming I'd be interested in seeing whether removing the "copy elision" and converting the ggml_permute to 2D views speeds up the CPU backend: in the CUDA backend the old code already ends up calling the cuBLAS batched MM (so it didn't make much difference there), but the CPU backend might save splitting the work into batches and instead power through one huge 2D x 2D multiply.

@jukofyork
Contributor

jukofyork commented Mar 10, 2025

See #12313 for my current attempt to refactor this to use the proper lm_build_kv() calls and (vastly) simplify build_deepseek2().

NOTE: It currently performs slightly worse as a result and requires re-quantising all models, but I hope to regain most of the performance over the next couple of days, and breaking up the tensors like this should help get to the bottom of the numerical problems. It is probably best to wait until I take it out of draft before using it seriously...
