
CUDA: skip masked out KQ slices in mma FA kernel #14735


Open · wants to merge 1 commit into master
Conversation

JohannesGaessler (Collaborator)

This PR extends the mma-based CUDA FlashAttention kernel with logic for skipping fully masked-out KQ slices. However, this kernel uses asynchronous data loading to preload KV data, so the skip is not straightforward: the mask and K data are preloaded at the same time, so if a KQ slice turns out to be skippable the K I/O is wasted, and the GPU compute pipes idle until the mask and K data for the next potential KQ slice have been fetched. For the new batched inference setup it turns out to be faster not to preload any data at all, but this hampers overall performance elsewhere. With LLaMA 3 8b q4_0 and the command

./batched-bench -t 8 -tb 8 -m models/opt/${model_name}-${quantization}.gguf -c 65536 -b 1024 -ub 1024 -npp 1024 -ntg 128 -npl 1,2,4,8,16,32 -fa -ngl 999

I'm seeing net speedups of ~3% for my RTX 3090/4090, with

LLAMA_ARG_MODEL=/opt/models/llama_3.2_instruct-1b-q4_k_m.gguf py server-bench.py --path_server /home/johannesg/Projects/llama.cpp/build/bin/llama-server --prompt_source rng-1024-40000

I'm seeing net speedups of ~10% (LLAMA_SET_ROWS=1 set for both cases).

I think that the approach of loading a tile of the mask and checking whether it's all == -inf is bad. I could write a much better kernel if I instead had a list of those 256x64 mask slices that are not all == -inf. Then I could simply iterate over those indices and preload data without potentially wasting any I/O. It would also let me solve the problem that I cannot distribute work to streaming multiprocessors in an optimal way because I don't know ahead of time how the KQ slices needing compute are distributed; if one sequence is much longer than the rest, I would ideally assign more SMs to that sequence. For non-batched inference with a diagonal mask I could also skip a few KQ slices; that would reduce the amount of KV data to iterate over per token by half the physical batch size (probably not worthwhile on its own). Ideally one bit of the indices should be reserved to indicate whether all elements in the mask slice are == 0; the mask for those slices would not need to be loaded at all (as of yet unclear whether that would make a meaningful difference). @ggerganov thoughts?
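For reference, the tile-check approach described above boils down to something like the following sketch. This is a minimal, hypothetical illustration (assuming a 1D thread block and an fp16 mask tile already resident in shared memory), not the PR's actual kernel code:

```cuda
#include <cuda_fp16.h>
#include <math.h>

// Illustrative sketch only: after a mask tile has been loaded to shared memory,
// the thread block checks whether every element is -inf; if so, the corresponding
// KQ slice can be skipped entirely.
__device__ bool kq_slice_is_masked_out(const half * mask_tile, const int tile_size) {
    bool thread_all_masked = true;
    for (int i = threadIdx.x; i < tile_size; i += blockDim.x) {
        thread_all_masked = thread_all_masked && (__half2float(mask_tile[i]) == -INFINITY);
    }

    // Combine the per-thread results across the block via shared memory.
    __shared__ int any_unmasked;
    if (threadIdx.x == 0) {
        any_unmasked = 0;
    }
    __syncthreads();
    if (!thread_all_masked) {
        atomicOr(&any_unmasked, 1);
    }
    __syncthreads();
    return any_unmasked == 0;
}
```

The drawback discussed above is that the check can only happen after the mask tile (and, with preloading, the K data) has already been fetched.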

github-actions bot added the Nvidia GPU, python, and ggml labels on Jul 17, 2025.
ggerganov (Member) commented Jul 17, 2025

I think it is not very difficult to provide the indices list (or some other meta information about the mask) without making breaking changes. What we can do is, during the construction of the KQ mask, create a second tensor that contains the non-masked block indices. This tensor could be attached to the flash attention op as an optional src:

cur = ggml_flash_attn_ext(ctx, ...);
ggml_flash_attn_ext_add_idxs(cur, kq_idxs);

The backend can decide whether to use these indices or not. We could also pass a bitmask if that would be more appropriate, or even both. So from the PoV of the ggml and llama.cpp APIs I think we can relatively easily create and provide such additional information to the operator.
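A minimal sketch of what attaching such an optional src could look like, using the hypothetical helper name from the snippet above (the chosen src slot is an assumption, not existing ggml API):

```c
#include "ggml.h"

// Hypothetical helper mirroring the snippet above; not part of the current ggml API.
// The index tensor is stored in a spare src slot of the flash attention node, so a
// backend that does not know about it can simply ignore the extra src.
void ggml_flash_attn_ext_add_idxs(struct ggml_tensor * cur, struct ggml_tensor * kq_idxs) {
    GGML_ASSERT(cur->op == GGML_OP_FLASH_ATTN_EXT);
    GGML_ASSERT(cur->src[4] == NULL); // q, k, v and the mask occupy src[0..3] (assumed layout)
    cur->src[4] = kq_idxs;
}
```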

On the backend side: if you are seeing significant benefits, I think it would be worth it. I think the existing FA Metal kernel can benefit from a bitmask. Not sure about the list of indices, but once it is provided, we can attempt to write an alternative kernel that uses it.

> For the new batched inference setup it turns out to be faster not to preload any data at all, but this hampers overall performance elsewhere.

The first test that you did with llama-batched-bench does not benefit at all from the -INF masking optimization because the sequences are always of equal length. But it is performing faster without preloading the mask, which is unexpected. Could it be that this pre-loading does not have a significant effect? Can you post the resulting numbers from running llama-batched-bench on master and on this PR to see how the respective numbers change?

JohannesGaessler (Collaborator, Author)

> What we can do is, during the construction of the KQ mask, create a second tensor that contains the non-masked block indices.

Yes, that is how I would have implemented it as well.

> We could also pass a bitmask if that would be more appropriate.

For CUDA a bitmask would not work well because I would not know where in the bitmask to set the start and end indices for evenly distributing the data across streaming multiprocessors.
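To illustrate the point (a sketch with made-up names, not actual kernel code): with a compact list of active slice indices, each block can be handed a contiguous, near-equal share of the work up front, whereas with only a bitmask each block would first have to count bits to find its start and end.

```cuda
// Sketch only: distribute n_active active KQ slices evenly across the grid.
// active_slice_idx is the hypothetical compact index list discussed above.
__global__ void fa_over_active_slices(const int * active_slice_idx, const int n_active) {
    const int start = ( blockIdx.x      * n_active) / gridDim.x;
    const int end   = ((blockIdx.x + 1) * n_active) / gridDim.x;

    for (int i = start; i < end; ++i) {
        const int kq_slice = active_slice_idx[i];
        // ... preload K/V and the mask for exactly this slice, no wasted I/O ...
        (void) kq_slice;
    }
}
```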

> The first test that you did with llama-batched-bench does not benefit at all from the -INF masking optimization because the sequences are always of equal length. But it is performing faster without preloading the mask, which is unexpected. Could it be that this pre-loading does not have a significant effect?

For a single sequence the preloading definitely does provide a benefit; the difference is something like a 10% end-to-end speedup. Also, I just noticed that I had forgotten that master and the PR are running different kernels on the RTX 4090: master runs the vector kernel while the PR runs the mma kernel. There is also a difference for the RTX 3090; what could be happening there is that the lower shared memory use of a single-stage pipeline yields some benefit from higher occupancy.

> Can you post the resulting numbers from running llama-batched-bench on master and on this PR to see how the respective numbers change?

master, RTX 4090:

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|------|-----|----|-------|--------|----------|--------|----------|-------|---------|
| 1024 | 128 | 1  | 1152  | 0.083 | 12306.07 | 0.823 | 155.61  | 0.906 | 1271.84 |
| 1024 | 128 | 2  | 2304  | 0.153 | 13425.63 | 0.922 | 277.69  | 1.074 | 2144.39 |
| 1024 | 128 | 4  | 4608  | 0.307 | 13352.46 | 0.990 | 517.06  | 1.297 | 3552.90 |
| 1024 | 128 | 8  | 9216  | 0.613 | 13371.35 | 1.274 | 803.66  | 1.887 | 4884.38 |
| 1024 | 128 | 16 | 18432 | 1.228 | 13344.11 | 1.668 | 1227.51 | 2.896 | 6364.14 |
| 1024 | 128 | 32 | 36864 | 2.463 | 13306.32 | 2.304 | 1777.46 | 4.767 | 7733.16 |

PR, RTX 4090:

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|------|-----|----|-------|--------|----------|--------|----------|-------|---------|
| 1024 | 128 | 1  | 1152  | 0.084 | 12165.57 | 0.823 | 155.51  | 0.907 | 1269.75 |
| 1024 | 128 | 2  | 2304  | 0.153 | 13403.58 | 0.936 | 273.58  | 1.089 | 2116.59 |
| 1024 | 128 | 4  | 4608  | 0.307 | 13328.60 | 1.002 | 511.02  | 1.309 | 3519.66 |
| 1024 | 128 | 8  | 9216  | 0.614 | 13333.18 | 1.282 | 798.63  | 1.897 | 4859.21 |
| 1024 | 128 | 16 | 18432 | 1.231 | 13313.54 | 1.675 | 1222.85 | 2.905 | 6344.04 |
| 1024 | 128 | 32 | 36864 | 2.469 | 13272.45 | 2.286 | 1791.48 | 4.755 | 7752.27 |

master, RTX 3090:

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|------|-----|----|-------|--------|----------|--------|----------|-------|---------|
| 1024 | 128 | 1  | 1152  | 0.192 | 5336.50 | 0.990 | 129.33  | 1.182 | 974.96  |
| 1024 | 128 | 2  | 2304  | 0.374 | 5482.68 | 1.108 | 231.05  | 1.482 | 1555.14 |
| 1024 | 128 | 4  | 4608  | 0.751 | 5450.97 | 1.345 | 380.63  | 2.097 | 2197.87 |
| 1024 | 128 | 8  | 9216  | 1.522 | 5380.98 | 2.224 | 460.37  | 3.747 | 2459.77 |
| 1024 | 128 | 16 | 18432 | 3.058 | 5358.17 | 2.308 | 887.22  | 5.366 | 3434.89 |
| 1024 | 128 | 32 | 36864 | 6.143 | 5334.01 | 3.322 | 1232.96 | 9.465 | 3894.64 |

PR, RTX 3090:

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|------|-----|----|-------|--------|----------|--------|----------|-------|---------|
| 1024 | 128 | 1  | 1152  | 0.194 | 5270.15 | 0.993 | 128.89  | 1.187 | 970.17  |
| 1024 | 128 | 2  | 2304  | 0.381 | 5373.56 | 1.133 | 225.92  | 1.514 | 1521.55 |
| 1024 | 128 | 4  | 4608  | 0.761 | 5382.02 | 1.350 | 379.32  | 2.111 | 2183.00 |
| 1024 | 128 | 8  | 9216  | 1.534 | 5339.30 | 2.240 | 457.08  | 3.775 | 2441.60 |
| 1024 | 128 | 16 | 18432 | 3.089 | 5304.12 | 2.310 | 886.48  | 5.399 | 3413.85 |
| 1024 | 128 | 32 | 36864 | 6.201 | 5284.08 | 3.209 | 1276.23 | 9.411 | 3917.24 |

I'll make a prototype for a kernel using a list of mask slices.

JohannesGaessler (Collaborator, Author)

> I'm seeing net speedups of ~3%

I think I did the math in my head wrong, sorry.

ggerganov (Member)

Try to make the list of indices compatible with both unified and split KV cache buffers.

@slaren You mentioned recently that such a modification might not be practical (#14363 (comment)). What do you think about the approach discussed above of passing the indices as an optional tensor that the backends can choose to ignore?

slaren (Member) commented Jul 17, 2025

Could the backend itself generate the list of indices from the KQ mask in a pre-processing step?

JohannesGaessler (Collaborator, Author)

> Could the backend itself generate the list of indices from the KQ mask in a pre-processing step?

I considered this, but I think it would be inefficient to parallelize. With CUDA I can efficiently generate e.g. a bitmask of which mask slices are all -inf, but a compact list of the indices of the active KQ slices needs significantly more communication between CUDA threads.
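To illustrate the difference (a sketch only, with hypothetical names such as slice_is_all_inf): a bitmask can be written with one independent operation per slice, while a compact index list needs global coordination, e.g. an atomic counter or a prefix sum.

```cuda
// Sketch: one thread per mask slice; slice_is_all_inf is assumed to be precomputed.
__global__ void build_slice_info(const char * slice_is_all_inf, const int n_slices,
                                 unsigned int * bitmask, int * idx_list, int * n_active) {
    const int slice = blockIdx.x*blockDim.x + threadIdx.x;
    if (slice >= n_slices || slice_is_all_inf[slice]) {
        return;
    }

    // Bitmask: a purely local decision, one atomic per 32-slice word.
    atomicOr(&bitmask[slice/32], 1u << (slice % 32));

    // Compact index list: the write position depends on every other thread,
    // resolved here with a global atomic counter (the resulting list is unordered).
    const int pos = atomicAdd(n_active, 1);
    idx_list[pos] = slice;
}
```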

I think I'll just make a prototype with an extended interface for ggml_flash_attn_ext first. If it turns out that the performance difference is negligible anyway, the approach can be scrapped. Otherwise I can implement the generation of the list in the backend and compare it against generation in llama.cpp.
