[Intel GPU] Support SDPA backend selection and priority setting on XPU #159464
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159464
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 3d6a1cd with merge base b602ea9. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@chunhuanMeng @guangyey Could you help take a look?
I am wondering if we support `with sdpa_kernel(order, set_priority=True):`, i.e. whether `set_priority=True` is honored here (see the usage sketch after the quoted lines below).
sdp::SDPBackend::flash_attention,
sdp::SDPBackend::math,
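For reference, a minimal sketch (not part of this PR) of the usage being asked about, assuming the torch.nn.attention API and an XPU-enabled build; with `set_priority=True` the backend list is interpreted as a priority order rather than a plain allow-list:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Hypothetical priority order: try the overrideable (XPU) kernel first, then fall back to math.
order = [SDPBackend.OVERRIDEABLE, SDPBackend.MATH]

# Illustrative shapes and dtype; requires a PyTorch build with XPU support.
q, k, v = (torch.randn(2, 4, 128, 64, device="xpu", dtype=torch.half) for _ in range(3))

with sdpa_kernel(order, set_priority=True):
    out = F.scaled_dot_product_attention(q, k, v)
```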
I guess we currently don't support `set_priority=True`, since our `priority_order` is a constant value. For PyTorch, see pytorch/aten/src/ATen/SDPBackend.h, lines 7 to 14 (at 1465757). Do we really want to change this behavior for XPU?
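For readers following along, a rough Python-side view of the same set of backends (a sketch only; it assumes torch.nn.attention mirrors the C++ SDPBackend enum, and member availability may vary by version):

```python
from torch.nn.attention import SDPBackend

# Print the backends that can appear in an SDPA priority order.
# Members are expected to include MATH, FLASH_ATTENTION, EFFICIENT_ATTENTION,
# CUDNN_ATTENTION and OVERRIDEABLE (the hook used by XPU / PrivateUse1 devices).
print(list(SDPBackend))
```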
@guangyey To support this feature, XPU needs to use the `priority_order` from the global context instead of its own constant one.
As you know, the overrideable backend has the lowest priority in the default order; see pytorch/aten/src/ATen/Context.h, lines 432 to 437 (at d7a5ec9).
Now we have added a flash_attention entry that falls back to overrideable, so we can use the logic in pytorch/aten/src/ATen/native/mkldnn/xpu/Attention.cpp, lines 110 to 116 (at d7a5ec9).
However, in the near future we will add a cutlass-sycl version of SDPA to the FlashAttention backend; see pytorch/aten/src/ATen/Context.h, lines 432 to 437 (at d7a5ec9).
Let's have a discussion offline.
Add a UT to ensure `overrideable` by default has higher priority than `math`; see the sketch below for one possible shape of such a test.
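A rough sketch of such a test (hedged): it relies on torch._fused_sdp_choice, the private helper that test_transformers.py already uses to inspect backend selection, and assumes SDPBackend.OVERRIDEABLE is exposed; shapes and dtype are illustrative.

```python
import torch
from torch.nn.attention import SDPBackend

def test_overrideable_preferred_over_math_by_default():
    # No sdpa_kernel context manager here: we check the default priority order on XPU.
    q, k, v = (torch.randn(2, 4, 128, 64, device="xpu", dtype=torch.half) for _ in range(3))
    chosen = torch._fused_sdp_choice(q, k, v)
    # With the new default order, the overrideable backend should win before math.
    assert chosen == SDPBackend.OVERRIDEABLE.value
    assert chosen != SDPBackend.MATH.value
```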
Thanks for your update!
@drisspg May I know if this PR looks reasonable to you? It doesn't change CUDA behavior, and it ensures all non-CUDA backends align with the CUDA SDPA priority behavior.
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased.
Force-pushed from e5490c7 to 6506ea3.
Does this change the global order for all privateuse1 backends?
@drisspg I think privateuse1 will also benefit from this PR. It doesn't change the ordering between "flash_attention", "efficient_attention", and "math".
aten/src/ATen/Context.h (outdated)
@@ -432,9 +432,9 @@ class TORCH_API Context {
  std::array<at::SDPBackend, at::num_sdp_backends> sdp_priority_order = {
      at::SDPBackend::flash_attention,
      at::SDPBackend::efficient_attention,
      at::SDPBackend::overrideable,
Can you undo this change? This code in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp, lines 60 to 100 (at 731ee31):
// tracks whether we've set the default priority order once, to avoid setting
// it redundantly or overwriting a user-specified priority order
// when the priority order context manager is used before the default priority
// order is initialized the following happens:
// (1) the current priority order is queried
// (2) priority_order() is called, which initializes it to the default as init_ is false
// (3) the user-specified priority order is set
// (3.1) we are in the priority context...
// (3.2) we exit the priority context...
// (4) the previous priority order (default) is restored
bool priority_order_init_ = false;
// TODO(eqy): more benchmarking to determine whether this should include sm86/89
// Needs to be kept in-sync with test_fused_chocie in test_transformers.py
bool check_prefer_cudnn_attention() {
  static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_PREFERRED") == true;
  if (!prefer_cudnn) {
    return false;
  }
#if (defined(CUDNN_VERSION) && (CUDNN_VERSION > 90000))
  auto dprops = at::cuda::getCurrentDeviceProperties();
  return dprops->major >= 9 && !dprops->minor;
#else
  return false;
#endif
}

// flash_attention V2 is universally faster than efficient_attention and Math
std::array<SDPBackend, num_backends> priority_order(sdp_params const& params) {
  if (!priority_order_init_) {
    priority_order_init_ = true;
    if (check_prefer_cudnn_attention()) {
      const std::vector<int64_t> cudnn_order = {
          static_cast<int64_t>(at::SDPBackend::cudnn_attention),
          static_cast<int64_t>(at::SDPBackend::flash_attention),
          static_cast<int64_t>(at::SDPBackend::efficient_attention),
          static_cast<int64_t>(at::SDPBackend::math)};
      at::globalContext().setSDPPriorityOrder(cudnn_order);
    }
  }
  return at::globalContext().sDPPriorityOrder();
}
shows how to set an order for a specific backend.
Thanks, it looks good. We will follow this code to set the XPU priority order. cc @LuFinch
@drisspg May I know if we have addressed your comments?
Force-pushed from 6506ea3 to f4cde98.
@guangyey I undid the changes on CUDA and updated the XPU priority order setting code. Please help review.
Looks better.
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Successfully rebased.
Force-pushed from f4cde98 to 3d6a1cd.
Currently SDPA on XPU uses its own `priority_order` instead of the one from the global context, so it does not support `with sdpa_kernel(order, set_priority=True)`. This PR enables this feature. To make the default `priority_order` from the global context work for XPU, I also move the MATH backend to the lowest priority; otherwise `cudnn attention` and `overrideable attention` would never be selected.

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168
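A usage sketch of what this enables on XPU (assumptions: an XPU build, and that torch.nn.attention exposes OVERRIDEABLE); the list passed with `set_priority=True` is tried in the given order, and the previous priority order is restored when the block exits:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q, k, v = (torch.randn(2, 8, 256, 64, device="xpu", dtype=torch.half) for _ in range(3))

# Temporarily prefer the math reference implementation, e.g. for an accuracy comparison.
with sdpa_kernel([SDPBackend.MATH, SDPBackend.OVERRIDEABLE], set_priority=True):
    ref = F.scaled_dot_product_attention(q, k, v)

# Back to the default order: with MATH now at the lowest priority, the overrideable
# (oneDNN) path is preferred over math for eligible inputs.
out = F.scaled_dot_product_attention(q, k, v)
```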