Skip to content

Commit 7afa9b7

Browse files
committed
fix: flash attention
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent c951b89 commit 7afa9b7

File tree

2 files changed

+0
-10
lines changed

2 files changed

+0
-10
lines changed

CMakeLists.txt

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ option(SD_METAL "sd: metal backend" OFF)
3434
option(SD_VULKAN "sd: vulkan backend" OFF)
3535
option(SD_CANN "sd: cann backend" OFF)
3636
option(SD_SYCL "sd: sycl backend" OFF)
37-
option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF)
3837
option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
3938
#option(SD_BUILD_SERVER "sd: build server example" ON)
4039

@@ -68,11 +67,6 @@ if (SD_HIPBLAS)
6867
add_definitions(-DSD_USE_CUBLAS)
6968
endif ()
7069

71-
if(SD_FLASH_ATTN)
72-
message("-- Use Flash Attention for memory optimization")
73-
add_definitions(-DSD_USE_FLASH_ATTENTION)
74-
endif()
75-
7670
set(SD_LIB stable-diffusion)
7771

7872
file(GLOB SD_LIB_SOURCES

ggml_extend.hpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -654,9 +654,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
654654
struct ggml_tensor* k,
655655
struct ggml_tensor* v,
656656
bool mask = false) {
657-
#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_CANN) && !defined(SD_USE_SYCL)
658-
struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head]
659-
#else
660657
float d_head = (float)q->ne[0];
661658

662659
struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, n_token, n_k]
@@ -667,7 +664,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
667664
kq = ggml_soft_max_inplace(ctx, kq);
668665

669666
struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, n_token, d_head]
670-
#endif
671667
return kqv;
672668
}
673669

0 commit comments

Comments
 (0)