From 063083da81d64ce5ad9259b4e411bac350f89d2d Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Fri, 1 Aug 2025 00:53:18 -0500 Subject: [PATCH 1/2] persistent_counter is also needed when SWA is requested --- .../native/transformers/hip/flash_attn/aot/mha_all_aot.hip | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip index 05523f75caa4..8dae54e91979 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -275,7 +275,7 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0; auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr() : nullptr); auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr() : nullptr); - auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); + auto persistent_counter = mk_atomictensor(is_causal || uses_swa ? atomic_counter.data_ptr() : nullptr); if (uses_swa) { #if V3_API using aotriton::v3::flash::CausalType; @@ -476,7 +476,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot auto nullscalar = mk_philoxtensor(nullptr); auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : nullscalar; auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : nullscalar; - auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr()) : nullscalar; + auto persistent_counter = is_causal || uses_swa ? mk_philoxtensor(atomic_counter.data_ptr()) : nullscalar; if (uses_swa) { #if V3_API using aotriton::v3::flash::CausalType; From 759a982562651b662218ff45281e8cb29520b26d Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Fri, 1 Aug 2025 18:50:27 -0500 Subject: [PATCH 2/2] Fix atomic_counter creation --- .../hip/flash_attn/aot/mha_all_aot.hip | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip index 8dae54e91979..1d4926c02274 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -243,12 +243,6 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x } else { softmax_fa_t = at::empty({ 0, 0, 0, 0 }, opts); } - - at::Tensor atomic_counter; - if (is_causal) { - atomic_counter = at::zeros({1}, opts.dtype(at::kInt)); - } - auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left, window_size_right, seqlen_q, @@ -262,6 +256,14 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x constexpr bool uses_swa = false; #endif + // SWA in AOTriton Kernels is treated as "Generalized Causal masks" + is_causal = is_causal || uses_swa; + + at::Tensor atomic_counter; + if (is_causal) { + atomic_counter = at::zeros({1}, opts.dtype(at::kInt)); + } + hipError_t err; // TODO: Error handling using aotriton::v2::flash::attn_fwd; using sdp::aotriton_adapter::mk_aotensor; @@ -275,7 +277,7 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0; auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr() : nullptr); auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr() : nullptr); - auto persistent_counter = mk_atomictensor(is_causal || uses_swa ? atomic_counter.data_ptr() : nullptr); + auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); if (uses_swa) { #if V3_API using aotriton::v3::flash::CausalType; @@ -455,6 +457,9 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot constexpr bool uses_swa = false; #endif + // SWA in AOTriton Kernels is treated as "Generalized Causal masks" + is_causal = is_causal || needs_swa; + auto [seed_t, offset_t, philox_state, use_philox_state] = prepare_philox_arguments(p_dropout, batch_size * num_heads * 32); @@ -476,7 +481,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot auto nullscalar = mk_philoxtensor(nullptr); auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : nullscalar; auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : nullscalar; - auto persistent_counter = is_causal || uses_swa ? mk_philoxtensor(atomic_counter.data_ptr()) : nullscalar; + auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr()) : nullscalar; if (uses_swa) { #if V3_API using aotriton::v3::flash::CausalType;