
llama : remove ggml_cont where possible #14568


Merged

merged 1 commit into master from cisc/remove-unnecessary-conts on Jul 7, 2025
Conversation

@CISC
Collaborator

CISC commented Jul 7, 2025

AFAIK all backends support non-contiguous norm/rope, therefore we can remove a lot of unnecessary ggml_cont()s by directly creating a 3D view before these ops.

Gives a quite hefty PP boost for the affected models.
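For illustration, a minimal sketch of the pattern in question (hypothetical tensor names and strides, not the exact llama.cpp graph-building code): instead of forcing a contiguous copy and then reshaping before rope/norm, a non-contiguous 3D view is taken directly.

```cpp
// Sketch only: illustrates the ggml_cont removal, not the actual llama.cpp code.
// Assumes `ctx0` is a ggml_context* and `qcur` is a 2D view (possibly
// non-contiguous, e.g. a slice of a fused QKV tensor) of shape
// [n_embd_head * n_head, n_tokens] (hypothetical names).
#include "ggml.h"

static struct ggml_tensor * build_q_for_rope(struct ggml_context * ctx0,
                                             struct ggml_tensor  * qcur,
                                             int64_t n_embd_head,
                                             int64_t n_head,
                                             int64_t n_tokens) {
    // old pattern: copy to contiguous memory, then reshape to 3D for rope/norm
    // struct ggml_tensor * q = ggml_cont(ctx0, qcur);
    // q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens);

    // new pattern: take a (possibly non-contiguous) 3D view directly;
    // rope/norm handle non-contiguous inputs on all backends
    struct ggml_tensor * q = ggml_view_3d(ctx0, qcur,
            n_embd_head, n_head, n_tokens,
            ggml_element_size(qcur) * n_embd_head,   // nb1: byte stride between heads
            qcur->nb[1],                             // nb2: byte stride between tokens
            0);                                      // byte offset into qcur
    return q;
}
```

The rope/norm op is then applied to `q` as before; only the extra copy disappears from the graph, the downstream math is unchanged.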

CISC requested a review from ggerganov July 7, 2025 15:21

@ggerganov
Member

ggerganov left a comment

Sample M2 Ultra numbers:

| Model          | FlashAttention | Test  | t/s master | t/s cisc/remove-unnecessary-conts | Speedup |
| -------------- | -------------- | ----- | ---------: | --------------------------------: | ------: |
| falcon 7B Q4_0 | No             | pp512 |    1367.69 |                           1382.38 |    1.01 |
| falcon 7B Q4_0 | No             | tg32  |      99.63 |                            104.74 |    1.05 |
| falcon 7B Q4_0 | Yes            | pp512 |    1384.67 |                           1396.75 |    1.01 |
| falcon 7B Q4_0 | Yes            | tg32  |     102.58 |                            108.62 |    1.06 |

@jeffbolznv
Collaborator

Nice boost for tg, not seeing a gain for pp:

before:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           pp512 |    15172.83 ± 221.32 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           tg128 |        282.68 ± 1.78 |

after:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           pp512 |    15177.85 ± 494.64 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           tg128 |        293.20 ± 1.74 |

@CISC
Collaborator Author

CISC commented Jul 7, 2025

Some CUDA numbers before:

| model                     | size     | params | backend | ngl | threads | test  | t/s             |
| ------------------------- | -------- | ------ | ------- | --: | ------: | ----- | --------------: |
| phi3 3B IQ4_XS - 4.25 bpw | 1.92 GiB | 3.82 B | CUDA    |  99 |       7 | pp512 | 7276.01 ± 54.48 |
| phi3 3B IQ4_XS - 4.25 bpw | 1.92 GiB | 3.82 B | CUDA    |  99 |       7 | tg128 |   193.72 ± 1.09 |

after:

| model                     | size     | params | backend | ngl | threads | test  | t/s             |
| ------------------------- | -------- | ------ | ------- | --: | ------: | ----- | --------------: |
| phi3 3B IQ4_XS - 4.25 bpw | 1.92 GiB | 3.82 B | CUDA    |  99 |       7 | pp512 | 7506.65 ± 22.00 |
| phi3 3B IQ4_XS - 4.25 bpw | 1.92 GiB | 3.82 B | CUDA    |  99 |       7 | tg128 |   196.35 ± 1.01 |

@CISC
Collaborator Author

CISC commented Jul 7, 2025

Looks like MiniCPM broke, investigating...

@CISC
Collaborator Author

CISC commented Jul 7, 2025

Ah, hm, nope, it was already broken.

@CISC
Collaborator Author

CISC commented Jul 7, 2025

@ggerganov `ggml_reshape_2d` in `llama_kv_cache_unified::cpy_v` fails with `GGML_ASSERT(ggml_nelements(a) == ne0*ne1)`, wonder how long this has been broken for MiniCPM3?

@ggerganov
Member

Likely since #12449

@CISC
Collaborator Author

CISC commented Jul 7, 2025

> Likely since #12449

The problem is simply that `v_states` is turned into a 2D view, removing that fixes it. :P

@ggerganov
Member

Is the ggml_cont right before that still needed?

@CISC
Collaborator Author

CISC commented Jul 7, 2025

> Is the ggml_cont right before that still needed?

Yes, due to the `ggml_reshape_2d`.
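
For context, a minimal sketch of why that `ggml_cont` has to stay (hypothetical names and shapes, not the actual MiniCPM3 / `cpy_v` code): in current ggml, `ggml_reshape_2d` asserts both that its input is contiguous and that the element count matches the target shape, which is the assert quoted above.

```cpp
// Sketch only: illustrates the ggml_reshape_2d constraints discussed above,
// not the actual MiniCPM3 / cpy_v code (names and shapes are hypothetical).
#include "ggml.h"

static struct ggml_tensor * flatten_v_states(struct ggml_context * ctx,
                                             struct ggml_tensor  * v_states, // 3D, possibly a non-contiguous view
                                             int64_t n_embd_v,
                                             int64_t n_tokens) {
    // ggml_reshape_2d asserts ggml_is_contiguous(a) and
    // ggml_nelements(a) == ne0*ne1, so the data must be made contiguous first,
    // and the view must still hold all n_embd_v * n_tokens elements
    // (turning v_states into a trimmed 2D view beforehand trips the assert).
    struct ggml_tensor * v = ggml_cont(ctx, v_states);
    return ggml_reshape_2d(ctx, v, n_embd_v, n_tokens);
}
```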

@CISC
Collaborator Author

CISC commented Jul 7, 2025

I'll make a separate PR for MiniCPM3 for visibility.

@CISC CISC merged commit 12f55c3 into master Jul 7, 2025
48 checks passed
@CISC CISC deleted the cisc/remove-unnecessary-conts branch July 7, 2025 19:35
@CISC
Collaborator Author

CISC commented Jul 7, 2025

gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request Jul 8, 2025
* origin/master:
model : fix hunyuan moe chat template (ggml-org#14584)
model : add SmolLM3 (ggml-org#14581)
memory : fix broken batch splits for recurrent cache (ggml-org#14575)
vulkan : fix rope with partial rotation and non-cont src (ggml-org#14582)
server: Add ability to mount server at prefix (ggml-org#14544)
model : add hunyuan moe (ggml-org#14425)
vulkan: increase timeout for CI (ggml-org#14574)
cuda : fix rope with partial rotation and non-cont src (ggml-org#14580)
CUDA: add bilinear interpolation for upscale (ggml-org#14563)
musa: fix build warnings (unused variable) (ggml-org#14561)
llama : fix incorrect minicpm3 v_states shape (ggml-org#14571)
llama : remove ggml_cont where possible (ggml-org#14568)
qnixsynapse pushed a commit to menloresearch/llama.cpp that referenced this pull request Jul 10, 2025