Skip to content

cuda : fix rope with partial rotation and non-cont src #14580

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 8, 2025

Conversation

ggerganov
Copy link
Member

The problem was revealed by #14573 and #14568

The problem occurs for ropes with n_dim < ne00 + non-cont src.

We never really had proper tests for this case because of the all = false; logic in the test-backend-ops.

@github-actions github-actions bot added testing Everything test related Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Jul 8, 2025
@ggerganov
Copy link
Member Author

@qnixsynapse Could you push a SYCL fix directly in this branch, as I don't have an environment to test with?

@ggerganov
Copy link
Member Author

ggerganov commented Jul 8, 2025

Qwen2VL models might have also been affected. Probably not

@github-actions github-actions bot added the SYCL https://en.wikipedia.org/wiki/SYCL - GPU programming language label Jul 8, 2025
Comment on lines 62 to 63
dst[i + 0] = x[ix + 0];
dst[i + 1] = x[ix + 1];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dst[i + 0] = x[ix + 0];
dst[i + 1] = x[ix + 1];
dst[idst + 0] = x[ix + 0];
dst[idst + 1] = x[ix + 1];

From what I can tell i and idst are the same.

@ggerganov ggerganov merged commit 4d0dcd4 into master Jul 8, 2025
55 of 56 checks passed
@ggerganov ggerganov deleted the gg/cuda-fix-rope-non-cont branch July 8, 2025 07:15
@qnixsynapse
Copy link
Collaborator

@qnixsynapse Could you push a SYCL fix directly in this branch, as I don't have an environment to test with?

Oh you already seems to fix it... Thank you!!

@CISC
Copy link
Collaborator

CISC commented Jul 8, 2025

gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request Jul 8, 2025
* origin/master:
model : fix hunyuan moe chat template (ggml-org#14584)
model : add SmolLM3 (ggml-org#14581)
memory : fix broken batch splits for recurrent cache (ggml-org#14575)
vulkan : fix rope with partial rotation and non-cont src (ggml-org#14582)
server: Add ability to mount server at prefix (ggml-org#14544)
model : add hunyuan moe (ggml-org#14425)
vulkan: increase timeout for CI (ggml-org#14574)
cuda : fix rope with partial rotation and non-cont src (ggml-org#14580)
CUDA: add bilinear interpolation for upscale (ggml-org#14563)
musa: fix build warnings (unused variable) (ggml-org#14561)
llama : fix incorrect minicpm3 v_states shape (ggml-org#14571)
llama : remove ggml_cont where possible (ggml-org#14568)
qnixsynapse pushed a commit to menloresearch/llama.cpp that referenced this pull request Jul 10, 2025
* cuda : fix rope non-cont

ggml-ci

* cont : fix multi-rope + add test

ggml-ci

* sycl : try fix

ggml-ci

* cont : fix sycl + clean-up cuda

ggml-ci
qnixsynapse pushed a commit to menloresearch/llama.cpp that referenced this pull request Jul 10, 2025
* cuda : fix rope non-cont

ggml-ci

* cont : fix multi-rope + add test

ggml-ci

* sycl : try fix

ggml-ci

* cont : fix sycl + clean-up cuda

ggml-ci
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ggml changes relating to the ggml tensor library for machine learning Nvidia GPU Issues specific to Nvidia GPUs SYCL https://en.wikipedia.org/wiki/SYCL - GPU programming language testing Everything test related
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants