
CUDA: fix FTZ in FA for Gemma 3 #13991


Merged: 1 commit merged into ggml-org:master on Jun 4, 2025

Conversation

JohannesGaessler (Collaborator)

Fixes #12433 (comment).

What I think is happening is that there is an underflow in the FlashAttention code when rescaling the FP16 VKQ accumulators. This PR flushes the scale to 0 if it's < 2.06e-9. I don't have multimodal Gemma 3 set up, so I did not reproduce the issue on my machine.
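For context, here is a minimal sketch of the flush-to-zero idea, not the exact kernel code from this PR; `ftz_scale` and the `-20.0f` log-space threshold are illustrative assumptions (exp(-20) ≈ 2.06e-9, matching the cutoff mentioned above):

```cuda
// Simplified sketch, not the actual ggml FlashAttention kernel.
// During the online softmax, previously accumulated FP16 VKQ values are
// rescaled by exp(old_max - new_max) whenever a larger running maximum is
// found. If that factor is tiny, the rescale can underflow in FP16, so the
// scale is flushed to zero instead (the old contribution is negligible anyway).
#include <cuda_fp16.h>

#define SOFTMAX_FTZ_THRESHOLD (-20.0f) // assumed cutoff: expf(-20.0f) ~= 2.06e-9

__device__ __forceinline__ half ftz_scale(const float kq_max_old, const float kq_max_new) {
    const float diff  = kq_max_old - kq_max_new; // always <= 0 by construction
    float       scale = expf(diff);
    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
        scale = 0.0f; // flush to zero rather than producing an FP16 denormal
    }
    return __float2half(scale);
}
```

Comparing `diff` in log space keeps the check cheap and avoids inspecting the result of `expf` itself; the value that finally multiplies the FP16 accumulators is then either exactly zero or safely representable.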

@github-actions github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Jun 3, 2025
@ggerganov ggerganov (Member) left a comment


This seems like a good solution, though I have some small remaining concerns that there might be something else going on. I tried the same approach with the Metal implementation (i.e. keep accumulating the output in F16 and FTZ the scores like in the CUDA code) and Gemma 3 27B keeps outputting garbage for large prompts. It's hard to say what the root cause is, as the Metal implementation does not provide many tools for debugging.

Anyway, this should be OK to merge since @mostlygeek confirmed it is working for them, but we should keep an eye out for any remaining issues.

I don't have multimodal Gemma 3 set up

Btw, you don't need multimodal Gemma to reproduce the issue. Just load the text-only model and ask it to summarize something around ~100k tokens long (for example, server.cpp + llama-context.cpp).

@JohannesGaessler JohannesGaessler merged commit 0b4be4c into ggml-org:master on Jun 4, 2025
42 checks passed
JohannesGaessler (Collaborator, Author)

This seems like a good solution, though I have some small remaining concerns that there might be something else going on.

Well, I hope not. If the CUDA code had to use FP32 for the accumulation of VKQ, that would be a pretty big headache for me due to register pressure. BF16 could partially solve the issue, but then the new problem would be that the necessary instructions are not available on all GPUs.

Labels
ggml (changes relating to the ggml tensor library for machine learning), Nvidia GPU (Issues specific to Nvidia GPUs)
Development

Successfully merging this pull request may close these issues.

Eval bug: Gemma3 <unused32> spam
2 participants