
Conversation

Green-Sky
Contributor

fixes #756

@leejet
Owner

leejet commented Aug 29, 2025

For the WAN model, flash attention is almost essential; without it, generation is very slow. So in PR #778 I removed many of the restrictions on flash attention; it is now controlled almost entirely by parameters. As for the possible side effects, this is for the user to weigh and decide.

@wbruna
Contributor

wbruna commented Aug 29, 2025

As for the possible side effects, this is for the user to weigh and decide.

@leejet, so we'll need to add an extra runtime flag and command-line parameter to disable the kv_pad? Fine by me, but isn't that a bit cluttered for the interface? (Note that FA otherwise seems to work fine for image generation.)

I could implement that flag, but it would necessarily conflict with the Wan PR; and since right now it just crashes for image generation, I won't be able to test it. Wouldn't it be better to apply this PR for now, and I'll figure out how not to break it again inside the Wan branch?

@Green-Sky
Contributor Author

No, not before. We should think of a solution on top of the Wan PR, since it touches that code.

@leejet
Owner

leejet commented Aug 29, 2025

I'm not sure if this will be of any help.

diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 110bbbc..f583869 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -1013,6 +1013,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     if (flash_attn) {
         // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
         bool can_use_flash_attn = true;
+        if (L_k < 256) {
+            can_use_flash_attn = false;
+        }
         if (can_use_flash_attn && L_k % 256 != 0) {
             kv_pad = GGML_PAD(L_k, 256) - L_k;
         }

@wbruna
Contributor

wbruna commented Aug 29, 2025

That would turn flash attention off for the length-77 tensors that seemed to be the problem, yes.
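
For reference, the arithmetic behind that: padding the 77-token text context up to the next multiple of 256 adds more padded keys than real ones. A quick sketch of the numbers (PAD_TO is just a stand-in that mirrors GGML_PAD's rounding; the 77 comes from #756, not from the code above):

#include <stdio.h>

// Illustrative only: how much padding a 77-token KV length picks up when it
// is rounded up to a multiple of 256, as in the kv_pad computation above.
#define PAD_TO(x, n) (((x) + (n) - 1) / (n) * (n))

int main(void) {
    int L_k    = 77;                // CLIP text context length from #756
    int padded = PAD_TO(L_k, 256);  // 256
    int kv_pad = padded - L_k;      // 179: the padded keys outnumber the real ones
    printf("L_k=%d padded=%d kv_pad=%d\n", L_k, padded, kv_pad);
    return 0;
}

With the L_k < 256 check, flash attention is simply skipped for such short sequences, so no padding happens at all.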

But... if I may get in way over my head here, would something like this make sense instead?

diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 110bbbc..d5721d5 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -1031,12 +1031,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     ggml_tensor* kqv = nullptr;
     if (flash_attn) {
         // LOG_DEBUG(" uses flash attention");
         if (kv_pad != 0) {
             // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
             k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
+            scale *= sqrtf((float)(L_k + kv_pad) / (float)L_k);
         }
         k = ggml_cast(ctx, k, GGML_TYPE_F16);
 
         v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
         v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);     // [N * n_head, L_k, d_head]
         if (kv_pad != 0) {

As I (barely) understand it, this would compensate for the attention mechanism being applied over the larger, padded tensor. It does seem to fix #756 for me, although I can't really say if it could cause other issues.
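
My rough reading of why the unmasked padding "weakens" the result in the first place: the padded K rows are zero, so each padded slot contributes a logit of 0, i.e. an extra exp(0) = 1 term in the softmax denominator, while the zero-padded V rows add nothing to the numerator. The whole attention output is then scaled down by roughly D / (D + kv_pad), where D is the denominator over the real keys. A toy worst-case calculation (made-up numbers, not measured from the model):

#include <stdio.h>

// Toy illustration: if the real logits are all near 0, each of the 77 real
// keys and each of the 179 padded slots contributes about exp(0) = 1 to the
// softmax denominator, so the attention output shrinks to roughly 77/256.
int main(void) {
    float D      = 77.0f;   // denominator over the real keys (worst case: logits ~ 0)
    float kv_pad = 179.0f;  // extra exp(0) = 1 terms from the zero padding
    printf("attenuation ~ %.2f\n", D / (D + kv_pad));  // ~0.30
    return 0;
}

The sqrt((L_k + kv_pad) / L_k) factor rescales the real logits relative to the padded ones, which seems to compensate on average, but it doesn't remove the extra terms the way a mask would.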

@leejet
Owner

leejet commented Aug 30, 2025

This issue occurred because no mask was applied after the kv_pad. I fixed it in commit aa5566f.
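
To illustrate the idea (a rough sketch only, not the code from aa5566f): the mask adds -INFINITY to the logits of the padded key slots, so they contribute exp(-inf) = 0 to the softmax and drop out entirely instead of diluting it.

#include <math.h>
#include <stdio.h>

// Sketch of a masked softmax over a row of logits: real key positions get a
// mask value of 0, padded positions get -INFINITY and end up with weight 0.
static void masked_softmax(float* logits, const float* mask, int n) {
    float max_v = -INFINITY;
    for (int i = 0; i < n; i++) {
        logits[i] += mask[i];
        if (logits[i] > max_v) max_v = logits[i];
    }
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        logits[i] = expf(logits[i] - max_v);  // padded slots become exactly 0
        sum += logits[i];
    }
    for (int i = 0; i < n; i++) {
        logits[i] /= sum;  // the weights are spread over the real keys only
    }
}

int main(void) {
    float logits[4] = {0.1f, 0.2f, 0.0f, 0.0f};          // 2 real keys + 2 padded slots
    float mask[4]   = {0.0f, 0.0f, -INFINITY, -INFINITY};
    masked_softmax(logits, mask, 4);
    printf("%.3f %.3f %.3f %.3f\n", logits[0], logits[1], logits[2], logits[3]);
    return 0;
}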

@Green-Sky closed this Aug 30, 2025

Successfully merging this pull request may close these issues.

img2img 'weakened' by diffusion-fa