
Commit f4cde98

update default priority order setting
1 parent e5490c7 commit f4cde98

4 files changed: +19 -17 lines changed


aten/src/ATen/Context.h

Lines changed: 2 additions & 2 deletions
@@ -432,9 +432,9 @@ class TORCH_API Context {
   std::array<at::SDPBackend, at::num_sdp_backends> sdp_priority_order = {
       at::SDPBackend::flash_attention,
       at::SDPBackend::efficient_attention,
-      at::SDPBackend::overrideable,
       at::SDPBackend::math,
-      at::SDPBackend::cudnn_attention};
+      at::SDPBackend::cudnn_attention,
+      at::SDPBackend::overrideable};
   bool enabled_flashSDP = true;
   bool enabled_mem_efficientSDP = true;
   bool enabled_mathSDP = true;
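
With this change, the built-in default priority becomes flash > efficient > math > cudnn > overrideable. A minimal sketch for inspecting the effective order from Python, reusing the private _cur_sdpa_kernel_backends helper that the updated test below relies on (lower index means higher priority); note that on an XPU build the Attention.cpp change further down seeds its own order:

    import torch
    from torch.nn.attention import SDPBackend, _cur_sdpa_kernel_backends

    # Query the current priority order; lower index = tried first.
    order = _cur_sdpa_kernel_backends(with_priority=True)
    print(order)

    # With the new default, cuDNN attention ranks ahead of the overrideable
    # backend, which moves to the end of the list.
    assert order.index(SDPBackend.CUDNN_ATTENTION) < order.index(SDPBackend.OVERRIDEABLE)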

aten/src/ATen/native/mkldnn/xpu/Attention.cpp

Lines changed: 12 additions & 0 deletions
@@ -93,8 +93,20 @@ bool can_use_mem_efficien_attention(sdp::sdp_params const& params, bool debug) {
   return false;
 }
 
+bool priority_order_init = false;
+
 std::array<sdp::SDPBackend, sdp::num_backends> priority_order(
     sdp::sdp_params const& params) {
+  if (!priority_order_init) {
+    priority_order_init = true;
+    const std::vector<int64_t> priority_order = {
+        static_cast<int64_t>(at::SDPBackend::overrideable),
+        static_cast<int64_t>(at::SDPBackend::math),
+        static_cast<int64_t>(at::SDPBackend::flash_attention),
+        static_cast<int64_t>(at::SDPBackend::efficient_attention),
+        static_cast<int64_t>(at::SDPBackend::cudnn_attention)};
+    at::globalContext().setSDPPriorityOrder(priority_order);
+  }
   return at::globalContext().sDPPriorityOrder();
 }
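
This hunk seeds the global context with the XPU-specific default (overrideable > math > flash > efficient > cudnn) the first time the XPU dispatcher asks for the priority order, guarded by a file-scope init flag. Because the seeding lives inside priority_order(), it appears to take effect only after the first SDPA dispatch on XPU; a rough sketch of checking this from Python, assuming an XPU build (shapes and dtype are arbitrary placeholders):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, _cur_sdpa_kernel_backends

    # Run one SDPA dispatch on XPU so priority_order() executes and seeds
    # the XPU default into the global context.
    q = k = v = torch.randn(2, 4, 8, 16, device="xpu", dtype=torch.float16)
    _ = F.scaled_dot_product_attention(q, k, v)

    # The queried order should now start with the overrideable backend.
    order = _cur_sdpa_kernel_backends(with_priority=True)
    assert order.index(SDPBackend.OVERRIDEABLE) < order.index(SDPBackend.MATH)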

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 1 addition & 11 deletions
@@ -805,13 +805,6 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
   return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug);
 }
 
-inline bool can_use_overrideable_attention(sdp_params const& params, bool debug) {
-  if (debug) {
-    TORCH_WARN("CUDA don't support SDPA overrideable attention backend.");
-  }
-  return false;
-}
-
 SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
   // This function defines the priority order of the different sdp backends
   // 1. Flash Attention

@@ -851,8 +844,7 @@ SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
       }
       break;
     case SDPBackend::overrideable:
-      if (ctx.userEnabledOverrideableSDP() &&
-          sdp::can_use_overrideable_attention(kernel_params, print_debug)) {
+      if (ctx.userEnabledOverrideableSDP()) {
         TORCH_CHECK(false, "Invalid backend");
       }
       break;

@@ -874,8 +866,6 @@ SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
     sdp::can_use_flash_attention(kernel_params, print_debug);
     TORCH_WARN("CuDNN attention kernel not used because:");
     sdp::can_use_cudnn_attention(kernel_params, print_debug);
-    TORCH_WARN("Overrideable attention kernel not used because:");
-    sdp::can_use_overrideable_attention(kernel_params, print_debug);
     TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.")
     return SDPBackend::error;
 }
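
On the CUDA side, the dedicated can_use_overrideable_attention check (which always returned false) is removed; if the overrideable backend is reached while still enabled, select_sdp_backend now fails immediately with "Invalid backend", and the debug output no longer mentions the overrideable kernel. The surrounding dispatch is a first-match walk over the priority order; a simplified Python sketch of that logic (names are hypothetical, and the CUDA-specific hard error for overrideable is omitted):

    def select_sdp_backend(params, priority_order, can_use, enabled):
        # Try backends in priority order; return the first one that is
        # both user-enabled and passes its capability check.
        for backend in priority_order:
            if enabled(backend) and can_use[backend](params, debug=False):
                return backend
        # Nothing matched: re-run the checks with debug=True so each backend
        # reports why it was rejected, then abort.
        for backend in priority_order:
            can_use[backend](params, debug=True)
        raise RuntimeError("No available kernel. Aborting execution.")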

test/test_transformers.py

Lines changed: 4 additions & 4 deletions
@@ -4259,15 +4259,15 @@ def test_backends_set_to_math(self, device):
         _ = F.scaled_dot_product_attention(q, k, v)
 
     def test_default_priority_order(self, device):
-        # The default priority order is flash, efficient, overrideable, math, cudnn
-        # For non-cuda backend, we need to make sure that flash > overrideable > math
+        # The default priority order of xpu is overrideable, math, flash, efficient, cudnn
+        # For xpu backend, we need to make sure that overrideable > math > flash
         from torch.nn.attention import _cur_sdpa_kernel_backends
         default_priority = _cur_sdpa_kernel_backends(with_priority=True)
         flash_index = default_priority.index(SDPBackend.FLASH_ATTENTION)
         overrideable_index = default_priority.index(SDPBackend.OVERRIDEABLE)
         math_index = default_priority.index(SDPBackend.MATH)
-        self.assertTrue(flash_index < overrideable_index < math_index,
-                        f"Expected flash < overrideable < math, got {flash_index}, {overrideable_index}, {math_index}")
+        self.assertTrue(overrideable_index < math_index < flash_index,
+                        f"Expected overrideable < math < flash, got {overrideable_index}, {math_index}, {flash_index}")
 
     def test_scaled_dot_product_attention_fused_kernels_safe_softmax(self, device):
         dtype = torch.bfloat16
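
The updated assertion matches the XPU-seeded order (overrideable before math before flash). The priority order only matters when several backends are allowed; as the neighbouring test_backends_set_to_math suggests, a caller can still pin a single backend with the sdpa_kernel context manager. A minimal usage sketch (tensor shapes are placeholders):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q = k = v = torch.randn(2, 4, 8, 16)

    # Restrict SDPA to the math backend only; with a single backend allowed,
    # the priority order does not come into play.
    with sdpa_kernel([SDPBackend.MATH]):
        out = F.scaled_dot_product_attention(q, k, v)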
