[Intel GPU] Support SDPA backend selection and priority setting on XPU #159464

Open · wants to merge 10 commits into main from lfq/change_sdpa_priority

Conversation

LuFinch
Contributor

@LuFinch LuFinch commented Jul 30, 2025

Currently, SDPA on XPU uses its own priority_order instead of the one from the global context. Hence it does not support with sdpa_kernel(order, set_priority=True), i.e. setting the backend priority through the context manager.

This PR enables this feature. To make the default priority_order from the global context work for XPU, I also move the MATH backend to the lowest priority; otherwise the cudnn attention and overrideable attention backends would never be selected.
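
For reference, a minimal usage sketch of the behavior this PR enables (assuming an XPU build of PyTorch, an available XPU device, and the SDPBackend.OVERRIDEABLE entry; shapes and dtypes are illustrative):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q, k, v = (torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.half) for _ in range(3))

# With set_priority=True, the order of the listed backends is honored as the
# selection priority instead of the global default priority_order.
order = [SDPBackend.OVERRIDEABLE, SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]
with sdpa_kernel(order, set_priority=True):
    out = F.scaled_dot_product_attention(q, k, v)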

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168

@pytorch-bot pytorch-bot bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Jul 30, 2025

pytorch-bot bot commented Jul 30, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159464

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3d6a1cd with merge base b602ea9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@LuFinch
Contributor Author

LuFinch commented Jul 30, 2025

@chunhuanMeng @guangyey Could you help take a look?

@albanD albanD requested review from drisspg and removed request for albanD July 30, 2025 13:34
@albanD albanD added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jul 30, 2025
Collaborator

@guangyey guangyey left a comment

I am wondering whether we support with sdpa_kernel(order, set_priority=True), in particular whether set_priority=True is honored.

@guangyey guangyey moved this to Pre-Review Required in PyTorch Intel Jul 30, 2025
sdp::SDPBackend::flash_attention,
sdp::SDPBackend::math,
Collaborator

I guess we currently don't support set_priority=True since our priority_order is a constant value.

@guangyey
Collaborator

For PyTorch, math has a higher priority than flash_attention:

enum class SDPBackend {
  error = -1,
  math = 0,
  flash_attention = 1,
  efficient_attention = 2,
  cudnn_attention = 3,
  overrideable = 4
};

Do we really want to change this behavior for XPU?

@LuFinch
Contributor Author

LuFinch commented Jul 31, 2025

@guangyey
For the above comments, I think we can support set_priority=True.

To support this feature, XPU needs to use priority_order like this:

const auto ordering = priority_order(kernel_params);

As you know, the overrideable backend has the lowest priority in the default priority_order, and we didn't have an entry for the flash attention backend at the beginning. This meant it would run into math instead of overrideable before. I guess this is the reason why XPU used a hard-coded priority_order instead of the priority_order from the global context.

std::array<at::SDPBackend, at::num_sdp_backends> sdp_priority_order = {
    at::SDPBackend::flash_attention,
    at::SDPBackend::efficient_attention,
    at::SDPBackend::math,
    at::SDPBackend::cudnn_attention,
    at::SDPBackend::overrideable};

Now that we have added the flash attention entry and made it fall back to overrideable, we can use the priority_order from the global context:

case sdp::SDPBackend::flash_attention:
  if (ctx.userEnabledFlashSDP() &&
      use_overrideable_xpu(kernel_params, print_debug)) {
    TORCH_WARN(
        "Flash Attention is not supported on XPU, falling back to overrideable kernel.");
    return sdp::SDPBackend::overrideable;
  }
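
In user-facing terms, the fallback above means that requesting the flash attention backend on XPU should dispatch to the overrideable kernel and emit the warning. A hedged illustration (device, shapes, and the warning propagation are assumptions based on the snippet above):

import warnings
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q, k, v = (torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.half) for _ in range(3))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Only flash attention is requested; per the snippet above, XPU falls back to overrideable.
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
        out = F.scaled_dot_product_attention(q, k, v)
# TORCH_WARN messages normally surface as Python warnings, so `caught` should
# contain the "falling back to overrideable kernel" message (assumption).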

@LuFinch
Contributor Author

LuFinch commented Jul 31, 2025

However, in the near future we will add a cutlass-sycl version of SDPA to the FlashAttention backend.
If a user relies on the default priority_order and the input params don't select the FlashAttention backend, it will then run into the MATH backend instead of the OVERRIDEABLE backend because of the default priority_order:

std::array<at::SDPBackend, at::num_sdp_backends> sdp_priority_order = {
    at::SDPBackend::flash_attention,
    at::SDPBackend::efficient_attention,
    at::SDPBackend::math,
    at::SDPBackend::cudnn_attention,
    at::SDPBackend::overrideable};

@guangyey
Collaborator

Let's have a discussion offline.

@LuFinch LuFinch changed the title [Intel GPU] Make SDPA FLASH_ATTENTION backend has higher priority than MATH backend on XPU [Intel GPU] Support SDPA backend selection and priority setting on XPU Jul 31, 2025
@LuFinch LuFinch requested a review from guangyey July 31, 2025 08:31
@guangyey guangyey added the topic: not user facing topic category label Jul 31, 2025
Collaborator

@guangyey guangyey left a comment

Add a unit test to ensure that overrideable has a higher priority than math by default.
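
A hedged sketch of what such a test could look like; the torch._C._get_sdp_priority_order() accessor and its return format are assumptions (illustrative only), not necessarily the test added in this PR:

import torch
from torch.nn.attention import SDPBackend

def test_default_priority_prefers_overrideable_over_math():
    # Assumed accessor: returns the current backend priority order as a list of
    # integer backend ids, highest priority first.
    order = torch._C._get_sdp_priority_order()
    overrideable_rank = order.index(int(SDPBackend.OVERRIDEABLE))
    math_rank = order.index(int(SDPBackend.MATH))
    # A lower index means a higher priority, so overrideable must come before math.
    assert overrideable_rank < math_rank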

Collaborator

@guangyey guangyey left a comment

Thanks for your update!

@guangyey guangyey added the ciflow/xpu Run XPU CI tasks label Aug 1, 2025
@guangyey
Collaborator

guangyey commented Aug 1, 2025

@drisspg May I know if this PR looks reasonable to you? It doesn't change CUDA behavior, and it ensures all non-CUDA backends align with the CUDA SDPA priority behavior.

@guangyey
Collaborator

guangyey commented Aug 8, 2025

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased lfq/change_sdpa_priority onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout lfq/change_sdpa_priority && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the lfq/change_sdpa_priority branch from e5490c7 to 6506ea3 Compare August 8, 2025 03:09
@pytorch-bot pytorch-bot bot removed the ciflow/xpu Run XPU CI tasks label Aug 8, 2025
@guangyey guangyey added the ciflow/xpu Run XPU CI tasks label Aug 8, 2025
Contributor

@drisspg drisspg left a comment

Does this change the global order for all privateuse1 backends?

@guangyey
Collaborator

guangyey commented Aug 8, 2025

@drisspg I think privateuse1 will also benefit from this PR. It doesn't change the ordering between "flash_attention", "efficient_attention", and "math".
If privateuse1 isn't using "overrideable", this PR has no effect on that case. If it is, then it's reasonable for "overrideable" to take precedence over "math", since "math" is intended as the fallback implementation.
@LuFinch Please correct me if I am wrong.

@mikaylagawarecki mikaylagawarecki removed their request for review August 8, 2025 21:16
@@ -432,9 +432,9 @@ class TORCH_API Context {
std::array<at::SDPBackend, at::num_sdp_backends> sdp_priority_order = {
at::SDPBackend::flash_attention,
at::SDPBackend::efficient_attention,
at::SDPBackend::overrideable,
Contributor

Can you undo this change? This code

// tracks whether we've set the default priority order once, to avoid setting
// it redundantly or overwriting a user-specified priority order
// when the priority order context manager is used before the default priority
// order is initialized the following happens:
// (1) the current priority order is queried
// (2) priority_order() is called, which initializes it to the default as init_ is false
// (3) the user-specified priority order is set
// (3.1) we are in the priority context...
// (3.2) we exit the priority context...
// (4) the previous priority order (default) is restored
bool priority_order_init_ = false;

// TODO(eqy): more benchmarking to determine whether this should include sm86/89
// Needs to be kept in-sync with test_fused_chocie in test_transformers.py
bool check_prefer_cudnn_attention() {
  static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_PREFERRED") == true;
  if (!prefer_cudnn) {
    return false;
  }
#if (defined(CUDNN_VERSION) && (CUDNN_VERSION > 90000))
  auto dprops = at::cuda::getCurrentDeviceProperties();
  return dprops->major >= 9 && !dprops->minor;
#else
  return false;
#endif
}

// flash_attention V2 is universally faster than efficient_attention and Math
std::array<SDPBackend, num_backends> priority_order(sdp_params const& params) {
  if (!priority_order_init_) {
    priority_order_init_ = true;
    if (check_prefer_cudnn_attention()) {
      const std::vector<int64_t> cudnn_order = {
          static_cast<int64_t>(at::SDPBackend::cudnn_attention),
          static_cast<int64_t>(at::SDPBackend::flash_attention),
          static_cast<int64_t>(at::SDPBackend::efficient_attention),
          static_cast<int64_t>(at::SDPBackend::math)};
      at::globalContext().setSDPPriorityOrder(cudnn_order);
    }
  }
  return at::globalContext().sDPPriorityOrder();
}

shows how to set an order for a specific backend.

Collaborator

Thanks, it looks good. We will follow this code to set XPU priority order. cc @LuFinch

Collaborator

@drisspg May I know if we have addressed your comments?

@LuFinch LuFinch force-pushed the lfq/change_sdpa_priority branch from 6506ea3 to f4cde98 Compare August 11, 2025 03:42
@pytorch-bot pytorch-bot bot removed the ciflow/xpu Run XPU CI tasks label Aug 11, 2025
@LuFinch
Contributor Author

LuFinch commented Aug 11, 2025

@guangyey I undid the changes to CUDA and updated the XPU priority-order setting code. Please help review.

Collaborator

@guangyey guangyey left a comment

Looks better.

@guangyey guangyey requested a review from drisspg August 11, 2025 05:16
@guangyey
Collaborator

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased lfq/change_sdpa_priority onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout lfq/change_sdpa_priority && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the lfq/change_sdpa_priority branch from f4cde98 to 3d6a1cd Compare August 11, 2025 05:19
@guangyey guangyey added the ciflow/xpu Run XPU CI tasks label Aug 11, 2025
Labels
ciflow/xpu Run XPU CI tasks module: cpu CPU specific problem (e.g., perf, algorithm) open source topic: not user facing topic category triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Projects
Status: Review Required
7 participants