
CUDA: mul_mat_v support for batch sizes > 1 #14262


Open · wants to merge 3 commits into master

Conversation

JohannesGaessler
Collaborator

This PR extends the mul_mat_vec kernels to batch sizes > 1; they seem to be viable up to a batch size of 8. The primary purpose is to help with speculative decoding and batched inference.

Performance changes
| GPU | Model | Microbatch size | Test | t/s master | t/s PR | Speedup |
| --- | --- | ---: | --- | ---: | ---: | ---: |
| RX 6800 | llama 1B all F32 | 1 | pp512 | 93.97 | 94.58 | 1.01 |
| RX 6800 | llama 1B all F32 | 2 | pp512 | 44.37 | 169.10 | 3.81 |
| RX 6800 | llama 1B all F32 | 3 | pp512 | 65.45 | 211.74 | 3.24 |
| RX 6800 | llama 1B all F32 | 4 | pp512 | 86.84 | 260.85 | 3.00 |
| RX 6800 | llama 1B all F32 | 5 | pp512 | 101.39 | 248.44 | 2.45 |
| RX 6800 | llama 1B all F32 | 6 | pp512 | 120.59 | 275.06 | 2.28 |
| RX 6800 | llama 1B all F32 | 7 | pp512 | 141.30 | 295.12 | 2.09 |
| RX 6800 | llama 1B all F32 | 8 | pp512 | 162.56 | 305.50 | 1.88 |
| RX 6800 | llama 8B BF16 | 1 | pp512 | 21.20 | 21.36 | 1.01 |
| RX 6800 | llama 8B BF16 | 2 | pp512 | 1.89 | 38.77 | 20.55 |
| RX 6800 | llama 8B BF16 | 3 | pp512 | 2.81 | 48.75 | 17.34 |
| RX 6800 | llama 8B BF16 | 4 | pp512 | 3.76 | 56.86 | 15.12 |
| RX 6800 | llama 8B BF16 | 5 | pp512 | 4.61 | 57.07 | 12.39 |
| RX 6800 | llama 8B BF16 | 6 | pp512 | 5.52 | 59.26 | 10.74 |
| RX 6800 | llama 8B BF16 | 7 | pp512 | 6.55 | 60.94 | 9.31 |
| RX 6800 | llama 8B BF16 | 8 | pp512 | 7.47 | 61.88 | 8.28 |
| RX 6800 | llama 8B F16 | 1 | pp512 | 21.38 | 21.31 | 1.00 |
| RX 6800 | llama 8B F16 | 2 | pp512 | 9.60 | 39.14 | 4.08 |
| RX 6800 | llama 8B F16 | 3 | pp512 | 14.16 | 48.64 | 3.43 |
| RX 6800 | llama 8B F16 | 4 | pp512 | 18.90 | 54.80 | 2.90 |
| RX 6800 | llama 8B F16 | 5 | pp512 | 22.84 | 56.92 | 2.49 |
| RX 6800 | llama 8B F16 | 6 | pp512 | 27.33 | 60.27 | 2.21 |
| RX 6800 | llama 8B F16 | 7 | pp512 | 32.07 | 61.76 | 1.93 |
| RX 6800 | llama 8B F16 | 8 | pp512 | 36.72 | 63.06 | 1.72 |
| P40 | llama 1B all F32 | 1 | pp512 | 75.35 | 75.65 | 1.00 |
| P40 | llama 1B all F32 | 2 | pp512 | 140.43 | 143.76 | 1.02 |
| P40 | llama 1B all F32 | 3 | pp512 | 186.86 | 212.35 | 1.14 |
| P40 | llama 1B all F32 | 4 | pp512 | 259.12 | 260.10 | 1.00 |
| P40 | llama 1B all F32 | 5 | pp512 | 304.59 | 304.61 | 1.00 |
| P40 | llama 1B all F32 | 6 | pp512 | 357.97 | 358.68 | 1.00 |
| P40 | llama 1B all F32 | 7 | pp512 | 414.78 | 415.16 | 1.00 |
| P40 | llama 1B all F32 | 8 | pp512 | 475.44 | 476.04 | 1.00 |
| P40 | llama 8B BF16 | 1 | pp512 | 21.15 | 21.21 | 1.00 |
| P40 | llama 8B BF16 | 2 | pp512 | 8.60 | 35.31 | 4.10 |
| P40 | llama 8B BF16 | 3 | pp512 | 12.83 | 39.42 | 3.07 |
| P40 | llama 8B BF16 | 4 | pp512 | 17.09 | 45.63 | 2.67 |
| P40 | llama 8B BF16 | 5 | pp512 | 21.14 | 43.44 | 2.06 |
| P40 | llama 8B BF16 | 6 | pp512 | 25.26 | 53.78 | 2.13 |
| P40 | llama 8B BF16 | 7 | pp512 | 29.71 | 47.35 | 1.59 |
| P40 | llama 8B BF16 | 8 | pp512 | 33.90 | 46.15 | 1.36 |
| P40 | llama 8B F16 | 1 | pp512 | 20.95 | 21.15 | 1.01 |
| P40 | llama 8B F16 | 2 | pp512 | 6.96 | 35.44 | 5.09 |
| P40 | llama 8B F16 | 3 | pp512 | 10.20 | 39.67 | 3.89 |
| P40 | llama 8B F16 | 4 | pp512 | 13.70 | 46.57 | 3.40 |
| P40 | llama 8B F16 | 5 | pp512 | 16.54 | 48.39 | 2.93 |
| P40 | llama 8B F16 | 6 | pp512 | 19.77 | 53.76 | 2.72 |
| P40 | llama 8B F16 | 7 | pp512 | 22.95 | 47.02 | 2.05 |
| P40 | llama 8B F16 | 8 | pp512 | 26.10 | 46.37 | 1.78 |
| RTX 3090 | llama 1B all F32 | 1 | pp512 | 201.17 | 200.97 | 1.00 |
| RTX 3090 | llama 1B all F32 | 2 | pp512 | 325.44 | 379.24 | 1.17 |
| RTX 3090 | llama 1B all F32 | 3 | pp512 | 464.06 | 538.10 | 1.16 |
| RTX 3090 | llama 1B all F32 | 4 | pp512 | 601.38 | 683.79 | 1.14 |
| RTX 3090 | llama 1B all F32 | 5 | pp512 | 743.95 | 740.42 | 1.00 |
| RTX 3090 | llama 1B all F32 | 6 | pp512 | 885.69 | 887.77 | 1.00 |
| RTX 3090 | llama 1B all F32 | 7 | pp512 | 1025.44 | 1024.07 | 1.00 |
| RTX 3090 | llama 1B all F32 | 8 | pp512 | 1178.03 | 1178.45 | 1.00 |
| RTX 3090 | llama 8B BF16 | 1 | pp512 | 58.06 | 58.27 | 1.00 |
| RTX 3090 | llama 8B BF16 | 2 | pp512 | 98.48 | 109.70 | 1.11 |
| RTX 3090 | llama 8B BF16 | 3 | pp512 | 146.26 | 148.08 | 1.01 |
| RTX 3090 | llama 8B BF16 | 4 | pp512 | 195.15 | 194.32 | 1.00 |
| RTX 3090 | llama 8B BF16 | 5 | pp512 | 239.12 | 238.88 | 1.00 |
| RTX 3090 | llama 8B BF16 | 6 | pp512 | 285.15 | 284.49 | 1.00 |
| RTX 3090 | llama 8B BF16 | 7 | pp512 | 330.18 | 329.39 | 1.00 |
| RTX 3090 | llama 8B BF16 | 8 | pp512 | 380.56 | 378.83 | 1.00 |
| RTX 3090 | llama 8B F16 | 1 | pp512 | 58.27 | 58.39 | 1.00 |
| RTX 3090 | llama 8B F16 | 2 | pp512 | 101.39 | 108.35 | 1.07 |
| RTX 3090 | llama 8B F16 | 3 | pp512 | 149.68 | 150.05 | 1.00 |
| RTX 3090 | llama 8B F16 | 4 | pp512 | 198.52 | 198.50 | 1.00 |
| RTX 3090 | llama 8B F16 | 5 | pp512 | 243.57 | 244.09 | 1.00 |
| RTX 3090 | llama 8B F16 | 6 | pp512 | 290.06 | 290.72 | 1.00 |
| RTX 3090 | llama 8B F16 | 7 | pp512 | 340.58 | 340.60 | 1.00 |
| RTX 3090 | llama 8B F16 | 8 | pp512 | 391.75 | 392.27 | 1.00 |
| RTX 4090 | llama 1B all F32 | 1 | pp512 | 231.53 | 232.40 | 1.00 |
| RTX 4090 | llama 1B all F32 | 2 | pp512 | 371.68 | 435.37 | 1.17 |
| RTX 4090 | llama 1B all F32 | 3 | pp512 | 550.96 | 642.04 | 1.17 |
| RTX 4090 | llama 1B all F32 | 4 | pp512 | 733.60 | 851.59 | 1.16 |
| RTX 4090 | llama 1B all F32 | 5 | pp512 | 908.50 | 1031.69 | 1.14 |
| RTX 4090 | llama 1B all F32 | 6 | pp512 | 1102.94 | 1205.03 | 1.09 |
| RTX 4090 | llama 1B all F32 | 7 | pp512 | 1278.15 | 1375.06 | 1.08 |
| RTX 4090 | llama 1B all F32 | 8 | pp512 | 1478.59 | 1560.42 | 1.06 |
| RTX 4090 | llama 8B BF16 | 1 | pp512 | 66.49 | 66.67 | 1.00 |
| RTX 4090 | llama 8B BF16 | 2 | pp512 | 119.44 | 127.02 | 1.06 |
| RTX 4090 | llama 8B BF16 | 3 | pp512 | 177.66 | 187.72 | 1.06 |
| RTX 4090 | llama 8B BF16 | 4 | pp512 | 236.78 | 247.97 | 1.05 |
| RTX 4090 | llama 8B BF16 | 5 | pp512 | 291.99 | 291.87 | 1.00 |
| RTX 4090 | llama 8B BF16 | 6 | pp512 | 348.79 | 348.97 | 1.00 |
| RTX 4090 | llama 8B BF16 | 7 | pp512 | 404.26 | 403.96 | 1.00 |
| RTX 4090 | llama 8B BF16 | 8 | pp512 | 466.13 | 465.14 | 1.00 |
| RTX 4090 | llama 8B F16 | 1 | pp512 | 66.56 | 66.66 | 1.00 |
| RTX 4090 | llama 8B F16 | 2 | pp512 | 117.49 | 126.75 | 1.08 |
| RTX 4090 | llama 8B F16 | 3 | pp512 | 177.31 | 188.96 | 1.07 |
| RTX 4090 | llama 8B F16 | 4 | pp512 | 235.92 | 247.04 | 1.05 |
| RTX 4090 | llama 8B F16 | 5 | pp512 | 290.63 | 289.39 | 1.00 |
| RTX 4090 | llama 8B F16 | 6 | pp512 | 346.52 | 345.44 | 1.00 |
| RTX 4090 | llama 8B F16 | 7 | pp512 | 401.26 | 399.99 | 1.00 |
| RTX 4090 | llama 8B F16 | 8 | pp512 | 462.38 | 461.38 | 1.00 |

On modern NVIDIA GPUs the speedup vs. cuBLAS for FP16 and BF16 is relatively small, though the speedup for FP32 is larger than I expected. Conversely, the FP32 speedup for Pascal is much smaller, if there is any. What I think happened is that the NVIDIA engineers simply put less work into optimizing FP32 GEMM on more modern GPUs. The cuBLAS performance on old NVIDIA GPUs and the hipBLAS performance seem to be very bad for FP16/BF16, so this PR achieves a ridiculous 20x speedup for some use cases; maybe we are running the BLAS libraries in a suboptimal way.

@IMbackK @yeahdongcn it may be worth checking whether the logic I implemented in ggml_cuda_should_use_mmv can be improved for non-NVIDIA hardware.
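For context, the logic in question boils down to a per-type batch-size cutoff. A simplified, hypothetical sketch of that kind of heuristic (not the actual ggml_cuda_should_use_mmv from this PR, which also takes the device into account):

```cpp
#include "ggml.h"

// Hypothetical sketch of a batch-size based dispatch heuristic; the real
// ggml_cuda_should_use_mmv in this PR differs and is the function to tune per vendor.
static bool should_use_mmv_sketch(const ggml_tensor * src0, const ggml_tensor * src1) {
    const int64_t batch_size         = src1->ne[1]; // number of src1 columns == micro-batch size
    const int64_t mmv_max_batch_size = 8;           // above this, dense GEMM tends to win

    switch (src0->type) {
        case GGML_TYPE_F32:
        case GGML_TYPE_F16:
        case GGML_TYPE_BF16:
            return batch_size <= mmv_max_batch_size;
        default:
            return false; // quantized types go through the MMQ/dequantization paths instead
    }
}
```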

@IMbackK
Collaborator

IMbackK commented Jun 18, 2025

(hip)rocBLAS performing very poorly on RDNA is a known issue and not down to the exact calls we are using. It's pretty bad for RDNA2, but it gets worse for RDNA3, and for RDNA4 it might as well be broken performance-wise.

On MI hardware the performance is much better, so possibly we will not want to do this there, but it needs benchmarking.

@github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Jun 18, 2025
@JohannesGaessler
Collaborator Author

I forgot: I changed the integer size in the kernel from 64 bit to 32 bit due to issues with register pressure.

@slaren
Member

slaren commented Jun 18, 2025

> I forgot: I changed the integer size in the kernel from 64 bit to 32 bit due to issues with register pressure.

I think this is ok as long as the pointers or indices into the weight matrix are still computed with 64-bit math; otherwise it will result in overflows with large matrices. E.g. the Command-R output matrix has 256000*8192 = 2,097,152,000 elements, which is very close to the limit of a 32-bit int (2,147,483,647).

@JohannesGaessler
Collaborator Author

JohannesGaessler commented Jun 18, 2025

I specifically changed the calculation of the initial offsets to use 64-bit math. That is the only part of the kernel where the pointer offsets scale with the product of two tensor dimensions. The pointer offsets that scale with a single tensor dimension are at least 1024x lower.
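Schematically, the split looks like this (a minimal sketch with hypothetical names, not the actual mmv.cu kernel):

```cuda
#include <cuda_fp16.h>
#include <cstdint>

__global__ void mul_mat_vec_sketch(const half * x, const float * y, float * dst,
                                   const int ncols, const int nrows) {
    const int row     = blockIdx.x;
    const int channel = blockIdx.y;

    // The initial offsets scale with the product of two tensor dimensions
    // (e.g. 256000*8192 for the Command-R output matrix), so compute them in 64 bit.
    const half  * x_row = x + (int64_t) channel*nrows*ncols + (int64_t) row*ncols;
    const float * y_col = y + (int64_t) channel*ncols;

    // Within a row the offsets only scale with a single dimension (ncols),
    // so 32-bit indexing is sufficient and keeps register pressure down.
    float sum = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        sum += __half2float(x_row[col]) * y_col[col];
    }

    // Warp-level reduction of the partial sums (assumes blockDim.x == 32).
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xFFFFFFFF, sum, offset);
    }
    if (threadIdx.x == 0) {
        dst[(int64_t) channel*nrows + row] = sum;
    }
}
```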

@slaren
Member

slaren commented Jun 18, 2025

test-backend-ops crashes:

```
MUL_MAT_ID(type_a=f32,type_b=f32,n_mats=4,n_used=2,b=0,m=512,n=1,k=256): ggml/src/ggml-cuda/mmv.cu:352: GGML_ASSERT(!ids || ne1 == 1) failed
```

@JohannesGaessler
Collaborator Author

Thank you, I forgot to check MUL_MAT_ID for the final version.

@yeahdongcn
Collaborator

Merged your changes along with #13842 and tested on MTT S80 and S4000. All test-backend-ops tests passed.

However, I noticed a slight performance drop on the S4000 when running llama-bench. I’ll investigate further to understand the cause.

@IMbackK
Collaborator

IMbackK commented Jun 19, 2025

On CDNA I am seeing a large (2x+) slowdown starting at batch size 4 for all data types.

I will try to take a look soon, maybe Sunday.

@yeahdongcn
Collaborator

> On CDNA I am seeing a large (2x+) slowdown starting at batch size 4 for all data types.
>
> I will try to take a look soon, maybe Sunday.

I'm currently applying for a trial cloud instance with an AMD GPU. @IMbackK Could you please share the steps you used to run the above tests? I’d like to reproduce them for comparison as well.

@yeahdongcn
Collaborator

@JohannesGaessler I saw a small performance gain (on MTT S80) after switching to __fmaf_rn (see: makllama@e8bbcaa).

@JohannesGaessler
Collaborator Author

How much of a speedup and for which data types? Are you sure this is due to __fmaf_rn and not due to you moving the type conversion out of the loop?
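To isolate the two effects, one would compare variants that differ only in the accumulation. A hypothetical sketch (not the code from the linked commit):

```cuda
#include <cuda_fp16.h>

// Hypothetical dot-product loops to isolate the effect of __fmaf_rn;
// neither is the actual mmv.cu or makllama code.
__device__ float dot_plain(const half * x, const float * y, const int n) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        // conversion inside the loop, a*b+c contraction left to the compiler
        sum += __half2float(x[i]) * y[i];
    }
    return sum;
}

__device__ float dot_fmaf(const half * x, const float * y, const int n) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        // same conversion placement, but an explicit fused multiply-add (round to nearest)
        sum = __fmaf_rn(__half2float(x[i]), y[i], sum);
    }
    return sum;
}
```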

@yeahdongcn
Collaborator

> How much of a speedup and for which data types? Are you sure this is due to __fmaf_rn and not due to you moving the type conversion out of the loop?

How do I run tests for each data type? I just ran E2E tests using llama-bench on several models, and the performance gain is about 3-4% on MTT S80.

@JohannesGaessler
Collaborator Author

For the table in the OP I used commands like:

```sh
export model_name=llama_3.2-1b && export quantization=f32
build/bin/llama-bench --model models/opt/${model_name}-${quantization}.gguf -r 1 -fa 1 -n 0 -ub 1,2,3,4,5,6,7,8 -o sql | sqlite3 llama-bench.sqlite
```
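Here `-ub 1,2,3,4,5,6,7,8` sweeps the micro-batch sizes from the table, `-n 0` disables token generation so only the pp512 prompt processing is measured, `-fa 1` enables FlashAttention, and `-r 1` runs a single repetition per configuration.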

@IMbackK
Collaborator

IMbackK commented Jun 22, 2025

@JohannesGaessler I added #14324 with path selection optimized for CDNA; we can merge it immediately after this one lands.

I would like to spend some time optimizing the kernel itself, as its performance for batch sizes 2 and 3 is not great, but I won't have the time in the near future, so let's leave it here for now.

Collaborator

@IMbackK IMbackK left a comment


Otherwise LGTM

@IMbackK
Collaborator

IMbackK commented Jun 22, 2025

Btw, experimenting with rocblas-bench shows we would be better off using rocblas_hshgemv_batched and friends.
