diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip
index 05523f75caa4..1d4926c02274 100644
--- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip
+++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip
@@ -243,12 +243,6 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
   } else {
     softmax_fa_t = at::empty({ 0, 0, 0, 0 }, opts);
   }
-
-  at::Tensor atomic_counter;
-  if (is_causal) {
-    atomic_counter = at::zeros({1}, opts.dtype(at::kInt));
-  }
-
   auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left,
                                                               window_size_right,
                                                               seqlen_q,
@@ -262,6 +256,14 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
   constexpr bool uses_swa = false;
 #endif
 
+  // SWA in AOTriton Kernels is treated as "Generalized Causal masks"
+  is_causal = is_causal || uses_swa;
+
+  at::Tensor atomic_counter;
+  if (is_causal) {
+    atomic_counter = at::zeros({1}, opts.dtype(at::kInt));
+  }
+
   hipError_t err; // TODO: Error handling
   using aotriton::v2::flash::attn_fwd;
   using sdp::aotriton_adapter::mk_aotensor;
@@ -455,6 +457,9 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
   constexpr bool uses_swa = false;
 #endif
 
+  // SWA in AOTriton Kernels is treated as "Generalized Causal masks"
+  is_causal = is_causal || needs_swa;
+
   auto [seed_t, offset_t, philox_state, use_philox_state] = prepare_philox_arguments(p_dropout,
                                                                                      batch_size * num_heads * 32);