add GGML_USE_NUMA_MIGRATE feature to optimize cross NUMA op computation #13649


Closed
wants to merge 8 commits

Conversation


@wenlujon commented May 20, 2025

This PR adds the GGML_USE_NUMA_MIGRATE feature to optimize cross-NUMA op computation, since cross-NUMA memory access becomes the bottleneck when threads are spawned across NUMA nodes:

  1. optimize ggml_barrier() for the cross-NUMA case by adding ggml_barrier_numa_aware(), currently enabled for the aarch64 forward_mul_mat()
  2. add the build option GGML_USE_NUMA_MIGRATE to enable the feature
  3. add the --numa migrate runtime option when GGML_USE_NUMA_MIGRATE is enabled, which migrates pages across NUMA nodes so that the mul_mat op computes only on the NUMA-local part of the src0 and dst tensor data (selected by the thread index ith); cores are also pinned with CPU affinity by the option (see the sketch after this list)
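
As a rough illustration of item 3 (a minimal sketch assuming libnuma's move_pages(); migrate_tensor_pages is a hypothetical helper, not code from this PR), migrating a page-aligned tensor buffer so its first half lands on node 0 and its second half on node 1 could look like this:

#include <numaif.h>     // move_pages(), MPOL_MF_MOVE; link with -lnuma
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

// Split the pages of a page-aligned buffer evenly across num_nodes NUMA nodes,
// so threads pinned to each node mostly read local memory.
static void migrate_tensor_pages(void * data, size_t size, int num_nodes) {
    const size_t page_size = (size_t) sysconf(_SC_PAGESIZE);
    if (size == 0 || size % page_size != 0) {
        return; // only whole-page, page-aligned buffers are migrated
    }
    const size_t num_pages = size / page_size;

    void ** pages  = malloc(num_pages * sizeof(void *));
    int   * nodes  = malloc(num_pages * sizeof(int));
    int   * status = malloc(num_pages * sizeof(int));

    for (size_t i = 0; i < num_pages; ++i) {
        pages[i] = (char *) data + i * page_size;
        nodes[i] = (int) (i * num_nodes / num_pages); // first half -> node 0, second half -> node 1
    }

    if (move_pages(0 /* calling process */, num_pages, pages, nodes, status, MPOL_MF_MOVE) != 0) {
        perror("move_pages");
    }

    free(pages); free(nodes); free(status);
}

The feature additionally pins the worker threads to the cores of the corresponding node, so the ith-based work split in mul_mat lines up with the page placement.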

Test Results

With the feature enabled, testing on a NeoVerse N2 system with multiple NUMA nodes (64 cores per NUMA node) shows performance improvements when running a Llama 3 model.

Base data

running command:

$ numactl -m 0 -N 0 $LLAMA/build/bin/llama-batched-bench -m ./models/Meta-Llama-3-8B-Instruct.Q4_0.gguf -c 4096 -b 2048 -ub 512 -npp 128 -ntg 128 -npl 1,4,8,16 -t 64

results:

|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|   128 |    128 |    1 |    256 |    0.286 |   447.56 |    5.119 |    25.01 |    5.405 |    47.37 |
|   128 |    128 |    4 |   1024 |    1.391 |   368.03 |    6.025 |    84.98 |    7.416 |   138.08 |
|   128 |    128 |    8 |   2048 |    2.925 |   350.03 |    8.641 |   118.51 |   11.566 |   177.07 |
|   128 |    128 |   16 |   4096 |    6.853 |   298.83 |   11.240 |   182.21 |   18.093 |   226.38 |

Results with the feature

running command:

$ ./llama-batched-bench -m $MODEL_DIR/$MODEL -c 4096 -b 2048 -ub 512 -npp 128 -ntg 128 -npl 1,2,4,8,16 -t 128 --cache-type-k q8_0 --cache-type-v q8_0 -fa --numa migrate

results:

|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|   128 |    128 |    1 |    256 |    0.291 |   439.45 |    4.186 |    30.58 |    4.477 |    57.18 |
|   128 |    128 |    4 |   1024 |    1.458 |   351.08 |    5.909 |    86.65 |    7.367 |   139.00 |
|   128 |    128 |    8 |   2048 |    2.943 |   347.95 |    7.499 |   136.55 |   10.442 |   196.13 |
|   128 |    128 |   16 |   4096 |    7.738 |   264.68 |   10.933 |   187.33 |   18.670 |   219.39 |

For example, with B=1 there is a 22% S_TG t/s improvement and a slight S_PP t/s drop (1.8%), for an overall 20.7% S t/s improvement.

Perplexity was also tested; there is no regression with the feature:

$ numactl -m 0,1 -N 0,1 $LLAMA/build/bin/llama-perplexity -m ./models/Meta-Llama-3-8B-Instruct.Q4_0.gguf -f $LLAMA/wikitext-2-raw/wiki.test.raw -t 128 --numa migrate
Final estimate: PPL = 8.7538 +/- 0.06481

Enablement

The feature is disabled by default; here is the guidance to enable it:

  1. enable GGML_NUMA_MIGRATE during cmake:
$ cmake -B build -DGGML_NUMA_MIGRATE=ON
  2. add the --numa migrate option at run time and set the thread count according to the number of NUMA nodes in use. The default is 2 nodes, which was the best trade-off on the NeoVerse N2 platform under test; it can be changed with the GGML_NUMA_MIGRATE_NODES build option. It is also better to bind llama.cpp to NUMA nodes 0 and 1 via numactl -m 0,1 -N 0,1 (see the combined example below).
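
For example (assuming a two-node system with 64 cores per node and the model used in the base run above), the full build-and-run sequence could look like:

$ cmake -B build -DGGML_NUMA_MIGRATE=ON
$ cmake --build build --config Release -j
$ numactl -m 0,1 -N 0,1 ./build/bin/llama-batched-bench -m ./models/Meta-Llama-3-8B-Instruct.Q4_0.gguf -c 4096 -b 2048 -ub 512 -npp 128 -ntg 128 -npl 1,4,8,16 -t 128 --numa migrate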

@github-actions bot added the examples and ggml (changes relating to the ggml tensor library for machine learning) labels on May 20, 2025
@Warshmao

enum ggml_status ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor) {
    size_t size = ggml_backend_buffer_get_alloc_size(talloc->buffer, tensor);
    size = GGML_PAD(size, talloc->alignment);

    if (talloc->offset + size > ggml_backend_buffer_get_size(talloc->buffer)) {
        GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\n",
                __func__, tensor->name, size, ggml_backend_buffer_get_size(talloc->buffer) - talloc->offset);
        GGML_ABORT("not enough space in the buffer");
    }

    void * addr = (char *)ggml_backend_buffer_get_base(talloc->buffer) + talloc->offset;
    talloc->offset += size;

    assert(((uintptr_t)addr % talloc->alignment) == 0);

    return ggml_backend_tensor_alloc(talloc->buffer, tensor, addr);
}

load model error: "not enough space in the buffer"

@wenlujon (Author)

wenlujon commented May 21, 2025

The issue should be caused by GGML_PAD with page-size alignment, which makes the total size exceed the size originally allocated in ggml_backend_cpu_buffer_type_alloc_buffer().
The issue seems related to the test environment, as I don't see it with the models (Llama 3, DeepSeek, and Qwen3) and platforms (NeoVerse and Rome) I tested. Could you please share the environment you're using (model and test platform information) so that I can prepare a fix and verify it?
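
To illustrate the failure mode (a sketch with made-up sizes, using a pad helper equivalent to GGML_PAD; this is not code from ggml): a buffer that was sized with the original 64-byte alignment can become too small once each tensor is re-padded to page granularity.

#include <stdio.h>
#include <stddef.h>

// round x up to a multiple of n (n must be a power of two), as GGML_PAD does
static size_t pad_to(size_t x, size_t n) { return (x + n - 1) & ~(n - 1); }

int main(void) {
    const size_t sizes[] = { 4100, 4200, 4300 };   // hypothetical tensor sizes in bytes
    size_t sized_with_64 = 0, needed_with_4k = 0;
    for (int i = 0; i < 3; ++i) {
        sized_with_64  += pad_to(sizes[i], 64);    // how the buffer was originally sized
        needed_with_4k += pad_to(sizes[i], 4096);  // what page-size alignment now requires
    }
    // 12736 vs 24576: the page-aligned total no longer fits in the original buffer,
    // which triggers "not enough space in the buffer" in ggml_tallocr_alloc().
    printf("%zu vs %zu\n", sized_with_64, needed_with_4k);
    return 0;
}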

@Warshmao

[screenshots]

Model: Qwen3-30B-A3B-Q4_K_M
System: Ubuntu 22.04
CPU: Intel Xeon 8480C × 2
Memory: 1 TB

@wenlujon (Author)

OK, it seems the issue is also platform-related. Could you please retry the latest commit? I increased the allocation by (number of tensors) * page_size, which should address the issue. It should not have a big impact on memory consumption; for example, for Qwen3-235B-A22B Q8_0, which has 1131 tensors, the increase is 1131 * 4 KiB ≈ 4.4 MiB.
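
A sketch of that sizing change (hypothetical helper name; the real change lives in the CPU buffer allocation path): since page-aligning a tensor wastes less than one page per tensor in the worst case, reserving one extra page per tensor is always enough.

#include <stddef.h>

// Extra slack to add to the buffer so every tensor can be padded to a page
// boundary, e.g. 1131 tensors * 4096 B ~= 4.4 MiB for Qwen3-235B-A22B Q8_0.
static size_t numa_migrate_extra_size(size_t n_tensors, size_t page_size) {
    return n_tensors * page_size;
}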

@Warshmao

It doesn't work; segmentation fault.
[screenshot]

@wenlujon (Author)

Good news is that it seems to have passed the buffer allocation now. What's the backtrace of the segfault?

@Warshmao

[screenshot: debug log info]

@wenlujon (Author)

Do you have a core dump? If so, you can open it with gdb to see the backtrace:

$ gdb [your program] [core file]
(gdb) bt

It would be even better to make a debug build and run it under gdb:

$ gdb [your program]
(gdb) run [with your arguments]

BTW, what's the command line of your test run?

@Warshmao

llama-bench -m Qwen3-30B-A3B-Q4_K_M.gguf -t 50 -p 512 -n 128 --numa migrate --verbnose

[screenshot]

@wenlujon (Author)

So the cgraph when running llama-bench is null; I've fixed that in the latest commit. Could you please give it a try?

@Warshmao

It worked.

@Warshmao

But the performance improvement is not up to ~20%:

numa_migrate: [screenshot]

b5439: [screenshot]

@wenlujon (Author)

I notice you have 112 cores; could you run with -t 112? Only in that case will the feature leverage more cores for op computation.

@Warshmao

numa_migrate -t 112: [screenshot]

@wenlujon (Author)

wenlujon commented May 21, 2025

OK, so we can see an obvious performance uplift in the pp512 test, though there's no improvement for tg128 (I guess cross-NUMA traffic isn't the bottleneck in that case).
BTW, could you also collect base data with 112 threads so we have a better comparison?

[commit pushed] …rier as the barrier only happens in few cores
@Warshmao

b5439 -t 112: [screenshot]

@wenlujon (Author)

I see there are two runs with the feature: the first at 29.51 t/s and the second at 34.63 t/s, while the base data is stable at 29.72 t/s. Could you rerun pp512 with the feature to see whether it stays around 34 t/s? If so, that would be about a 17% uplift. In any case, x86 has better cross-NUMA latency than aarch64.

@wenlujon (Author)

wenlujon commented May 22, 2025

Hi @Warshmao, will you do more testing with pp512? Or do you have an aarch64 system to verify it?
BTW, I tested Llama 3 on an Ice Lake system and do see improvements in the pp512 tests (8%–16%) and a slight drop in tg128 (about 5%):

$ lscpu
On-line CPU(s) list:                0-93
Off-line CPU(s) list:               94,95
Thread(s) per core:                 1
Core(s) per socket:                 24
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              106
Model name:                         Intel(R) Xeon(R) Gold 6336Y CPU @ 2.40GHz


Base:
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |      94 |           pp512 |       225.78 ± 17.27 |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |      94 |           tg128 |         33.76 ± 0.19 |
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |      94 |           pp512 |       241.63 ± 33.17 |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |      94 |           tg128 |         34.70 ± 0.66 |

| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |      94 |           pp512 |        237.33 ± 9.02 |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |      94 |           tg128 |         33.01 ± 1.42 |

NUMA opt:

| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |      94 |           pp512 |        262.09 ± 3.29 |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |      94 |           tg128 |         31.90 ± 0.10 |
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |      94 |           pp512 |        262.37 ± 2.27 |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |      94 |           tg128 |         31.90 ± 0.12 |
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |      94 |           pp512 |        260.17 ± 4.15 |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |      94 |           tg128 |         32.18 ± 0.26 |

And for aarch64, there is a 16.8% uplift for pp512 and a 25.9% uplift for tg128 against the 64-thread base data (64 cores per NUMA node on this platform, so the base gets its best performance with 64 threads on aarch64); compared against the 128-thread base data, the uplift is 27.9% for pp512 and about 140% for tg128:

Base:
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |      64 |           pp512 |        498.54 ± 1.69 |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |      64 |           tg128 |         25.06 ± 0.01 |
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |      64 |           pp512 |        499.53 ± 0.45 |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |      64 |           tg128 |         25.17 ± 0.13 |

Base with 128 threads (see performance drop compared to 64 threads):
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |     128 |           pp512 |        451.64 ± 9.04 |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |     128 |           tg128 |         11.06 ± 0.92 |
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |     128 |           pp512 |        471.37 ± 5.17 |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |     128 |           tg128 |         13.03 ± 0.48 |


NUMA OPT:
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |     128 |           pp512 |        582.72 ± 0.44 |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |     128 |           tg128 |         31.30 ± 0.07 |
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |     128 |           pp512 |        577.65 ± 0.46 |
| llama 8B Q4_0                  |   4.33 GiB |     8.03 B | CPU        |     128 |           tg128 |         31.71 ± 0.06 |

So the feature significantly improves performance on the aarch64 platform.

@Nor7th

Nor7th commented Jun 17, 2025

Hi, that's very nice work.
I have some confusion about the patch and hope you can help address my doubts.


The two figures below (base and opt) show my overall understanding after going through the code:

Base: [figure]
In the base case, tensor data may be randomly distributed over the memory of the two NUMA nodes, and all compute threads are dynamically scheduled over the cores of both nodes. This can cause massive uncontrolled cross-NUMA remote memory access.

Opt: [figure]
The patch tries to bring this under control as much as possible: threads on node 0 cores process data in node 0 memory, and likewise for node 1.

But I'm still not quite sure about the root cause of the performance gain. In particular, I can't fully understand this statement:

"mul_mat op would only do the computation in local numa part of tensor data src0 and dst according to the ith"

My doubt is, is this 100% achieved?

Let me first explain my doubt from two aspects: the src0 (k×n) and src1 (k×m) inputs of the forward_mul_mat kernel.

  1. For the src0 tensor, the task of the ith thread is restricted by:
        int64_t      src0_start      = (ith * ne01) / nth;
        int64_t      src0_end        = ((ith + 1) * ne01) / nth;

This is a partition over dimension ne01 (the "n" of k×n, the 2nd dimension).
Since this patch can physically distribute the data of src0 evenly over the two NUMA nodes, it looks like the first half of the threads can process the first half of the data on node 0, and the other half on node 1 (see the sketch after this list).
However, at least in my view, the tensor data cannot be distributed fully evenly (exactly 50% per node is nearly impossible), because the tensor data area is aligned to the system page size, which means the effective weight data is concentrated on one side of the aligned area with some padding on the other side. Therefore, in most cases node 0 may still hold slightly more data than node 1:
[figure]
This means node 0 may fully achieve the goal that its threads only access local data, but node 1 cannot, as a few threads on node 1 may still need to access data on node 0.
That's why I doubt it is 100% true that each node only accesses its own local data; still, the overall quantity of cross-NUMA memory access is massively reduced, and that really contributes to the performance gain.

  2. For the src1 tensor, I noticed there is a specific commit in this patch for src1. It lets each NUMA node have its own wdata space for the online quantization of src1. I think the overall situation for src1 is similar to what I explained for src0: the goal is not 100% achieved, but the uncontrolled remote memory access is largely reduced.
    BTW, I wonder what the performance gain would be without that particular src1 commit (only with the barrier update)? Does the src1 optimization help a lot for the overall forward_mul_mat performance?
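
As a sketch of the src0 partition discussed in item 1 (hypothetical shape, with a 2-node, 128-thread layout inferred from the patch description, not code copied from it):

#include <stdio.h>
#include <stdint.h>

int main(void) {
    const int64_t ne01 = 4096;   // rows of src0 (hypothetical)
    const int     nth  = 128;    // threads: 64 per NUMA node

    for (int ith = 0; ith < nth; ith += 64) {            // show the first thread of each node
        const int64_t src0_start = ( (int64_t) ith      * ne01) / nth;
        const int64_t src0_end   = (((int64_t) ith + 1) * ne01) / nth;
        const int     node       = ith < nth / 2 ? 0 : 1;
        printf("ith=%3d -> rows [%lld, %lld) on node %d\n",
               ith, (long long) src0_start, (long long) src0_end, node);
    }
    // ith=  0 -> rows [0, 32)      on node 0
    // ith= 64 -> rows [2048, 2080) on node 1
    // With the first half of src0's pages on node 0 and the second half on node 1,
    // the first 64 threads (node 0) touch rows that live on node 0, and vice versa.
    return 0;
}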

The last doubt I have is about the forward_mul_mat_id kernel.
Your patch doesn't modify much of this kernel beyond updating the barrier statements. Would the src1 optimization from forward_mul_mat also be helpful here?
I think forward_mul_mat_id has an obvious difference from forward_mul_mat in that the src0 tensor is 3-dimensional. Although each thread's task is still restricted by ith on the 2nd dimension as in forward_mul_mat, this happens not in the main path but inside the loop over the 3rd dimension.
This means the page migration of the src0 tensor in forward_mul_mat_id may not achieve the same effect as in forward_mul_mat:
[figure]
Because of the 3rd dimension, the data along the 2nd dimension is partitioned multiple times, not just once as in forward_mul_mat. Page migration cannot naturally split the two halves of the 2nd-dimension data onto the two NUMA nodes; they end up more like an interleaved layout.
So I guess the performance gain for forward_mul_mat_id is not as high as for forward_mul_mat, since it still has a lot of cross-NUMA memory access, but it should still improve over the base case, right? Although I cannot pinpoint the specific root cause.
Anyway, I suspect the page migration mechanism is more friendly to 2-dimensional tensor data, and higher-dimensional data cannot benefit as much. Is this conclusion reasonable?


It would be really appreciated to get your comments.

@wenlujon (Author)

wenlujon commented Jun 17, 2025

@Nor7th thanks for your review.

My doubt is, is this 100% achieved?

No, it's not 100% achieved; it depends on the tensor data. In some models like Llama 3, the src0 and dst tensor data are basically multiples of 2 pages, so each NUMA node can handle its data locally (for example, node 0 handles page 0 while node 1 handles page 1; with 64 cores on node 0, each core handles 4096/64 = 64 bytes). But tensor data that is not a multiple of 2 pages is not handled in this case (its pages won't be moved).

I think the overall case in src1 is similar with what I explain for src0, it's not a 100% goal achieved optimization, but largely reduced the uncontrolled remote memory access.

src1 has to be read in its entirety during gemv/gemm, but the quantization is not the bottleneck (it's much faster than gemm/gemv), so introducing a NUMA-local wdata speeds up the gemm/gemv.
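
A rough sketch of the per-node wdata idea (my illustration using libnuma's numa_alloc_onnode(); not the PR's actual code): each node quantizes src1 into its own local scratch buffer so that the gemm/gemv reads of wdata stay node-local.

#include <stddef.h>
#include <numa.h>   // numa_alloc_onnode(); link with -lnuma

// One quantized-src1 scratch buffer per NUMA node, so each node's threads
// read a local copy of wdata during gemm/gemv.
static int alloc_wdata_per_node(void * wdata[], size_t wsize, int num_nodes) {
    for (int node = 0; node < num_nodes; ++node) {
        wdata[node] = numa_alloc_onnode(wsize, node);
        if (wdata[node] == NULL) {
            return -1;   // caller can fall back to a single shared buffer
        }
    }
    return 0;
}
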
The performance uplift from the src1 optimization is obvious, especially for pp512.

Without the src1 optimization:
pp512: +16.3%
tg128: +23.6%

With the src1 optimization:
pp512: +53.5%
tg128: +33.1%

The last doubt I have is for kernel forward_mul_mat_id.

forward_mul_mat_id is not the hotspot, so it isn't touched here (it wouldn't show an obvious improvement from the change).

BTW, I'll create a new PR with the changes rebased into a single commit and merged with the latest llama.cpp, so I'll close this PR.

@wenlujon (Author)

Closing this PR; I filed a new one with all commits rebased into one and merged with the latest upstream commit:
#14232

@wenlujon closed this Jun 17, 2025
@Nor7th

Nor7th commented Jun 17, 2025

@wenlujon Thank you for clarifying most of my doubts. Really appreciate that!

The only point you might have missed is the "fake" even distribution of tensor data pages:
[figure]
The two nodes each hold two pages, but that does not mean each node processes only its own data, right?
Because when the tensor data memory is allocated, the allocated size is aligned to the system page size, which can introduce padding bytes in the last page. So the last page may not be filled with effective tensor data, the real effective tensor data is not evenly distributed between the two nodes, and threads on node 1 may still need to access data on node 0.

However, such cases should be a small proportion; the overall improvement is still good enough.

@wenlujon (Author)

Because when you do the tensor data memory allocation, the allocated size is actually aligned to system pagesize, which will bring in possible padding bytes in the last page.

Please note that not only is the tensor data size a whole number of pages, but the start address of the tensor data is also guaranteed to be page-aligned (see the get-alignment change). So if there are 4 pages in this case, each node processes 2 pages; why would there be padding bytes in the last page?

@Nor7th

Nor7th commented Jun 17, 2025

Perhaps I was wrong; I reconfirmed the following code:
[screenshot]

You directly use the result from ggml_backend_buffer_get_alloc_size, without an extra padding step like in:
[screenshot]

The this_size here is the real size allocated for tensor->data; it includes the alloc size plus a padding step, which creates a padding area in the tensor->data space.

But your code doesn't introduce that issue. I was a bit careless here.

@wenlujon (Author)

@Nor7th that's a good catch. For the model I'm using (Llama 3), the sizes for the mul_mat op returned by ggml_backend_buft_get_alloc_size() are page-aligned, so migrate_pages_with_cache() migrates their pages; other ops are not intended to be moved, so migrate_pages_with_cache() just bails out:

if (num_pages && ((size % ggml_backend_page_size) == 0)) {
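
As a small illustration of that guard (a sketch; the real migrate_pages_with_cache() does more than this): only buffers whose size is a whole number of pages are treated as migration candidates, and everything else is left in place.

#include <stdbool.h>
#include <stddef.h>

static bool is_migration_candidate(size_t size, size_t page_size) {
    const size_t num_pages = size / page_size;
    return num_pages > 0 && (size % page_size) == 0;
}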
