llama : remove ggml_cont where possible #14568
Sample M2 Ultra numbers:
| Model | FlashAttention | Test | t/s master | t/s cisc/remove-unnecessary-conts | Speedup |
| --- | --- | --- | --- | --- | --- |
| falcon 7B Q4_0 | No | pp512 | 1367.69 | 1382.38 | 1.01 |
| falcon 7B Q4_0 | No | tg32 | 99.63 | 104.74 | 1.05 |
| falcon 7B Q4_0 | Yes | pp512 | 1384.67 | 1396.75 | 1.01 |
| falcon 7B Q4_0 | Yes | tg32 | 102.58 | 108.62 | 1.06 |
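(Numbers in this form usually come from `llama-bench`; a hypothetical reproduction, with the model path assumed:)

```sh
# hypothetical: compare pp512/tg32 with FlashAttention off and on
./llama-bench -m falcon-7b-q4_0.gguf -p 512 -n 32 -fa 0,1
```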
Nice boost for tg, not seeing a gain for pp:
Some CUDA numbers before:
after:
Looks like MiniCPM broke, investigating...
Ah, hm, nope, it was already broken.
@ggerganov
Likely since #12449.
The problem is simply that `v_states` gets an incorrect shape.
Is the
Yes, due to the
I'll make a separate PR for MiniCPM3 for visibility.
Huh, for some weird reason this seems to have broken GPTNeoX:
* origin/master:
  model : fix hunyuan moe chat template (ggml-org#14584)
  model : add SmolLM3 (ggml-org#14581)
  memory : fix broken batch splits for recurrent cache (ggml-org#14575)
  vulkan : fix rope with partial rotation and non-cont src (ggml-org#14582)
  server : Add ability to mount server at prefix (ggml-org#14544)
  model : add hunyuan moe (ggml-org#14425)
  vulkan : increase timeout for CI (ggml-org#14574)
  cuda : fix rope with partial rotation and non-cont src (ggml-org#14580)
  CUDA : add bilinear interpolation for upscale (ggml-org#14563)
  musa : fix build warnings (unused variable) (ggml-org#14561)
  llama : fix incorrect minicpm3 v_states shape (ggml-org#14571)
  llama : remove ggml_cont where possible (ggml-org#14568)
AFAIK all backends support non-contiguous norm/rope, so we can remove a lot of unnecessary `ggml_cont()`s by directly creating a 3D view before these ops.
This gives quite a hefty PP boost for the affected models.
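A minimal sketch of the pattern (names like `cur`, `ctx0` and the rope parameters follow llama.cpp's `build_*` graph conventions, but the shapes and offsets here are illustrative, not the exact PR diff):

```cpp
// before (simplified): copy the Q slice to contiguous memory just so it
// can be reshaped to 3D for rope
// ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur,
//         n_embd_head * n_head, n_tokens, cur->nb[1], 0));
// Qcur = ggml_rope_ext(ctx0,
//         ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens),
//         inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
//         freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);

// after (simplified): take a strided 3D view straight into the fused tensor;
// rope accepts the non-contiguous src, so the copy disappears
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur,
        n_embd_head, n_head, n_tokens,
        n_embd_head * ggml_element_size(cur), // nb1: stride between heads
        cur->nb[1],                           // nb2: stride between tokens
        0);                                   // byte offset of Q within cur
Qcur = ggml_rope_ext(ctx0, Qcur,
        inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
```

Norm gets the same treatment: where a graph previously did `ggml_norm(ctx0, ggml_cont(ctx0, view), eps)`, the view can be passed to `ggml_norm` directly.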