[Intel GPU] Enable backward for SDPA XPU [WIP] #156272
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/156272
Note: Links to docs will display an error until the docs builds have been completed.
❌ 4 New Failures
As of commit d6ddf7a with merge base 908c5cc. The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
Attention! native_functions.yaml was changed
If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.
Caused by:
To add the ciflow label: this helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.
@EikanWang Could you also help review this PR?
I need more time to go through this PR...
@guangyey Sure. Thank you for taking the time to help review. For more details on OneDNN, you can refer to https://uxlfoundation.github.io/oneDNN/dev_guide_graph_sdpa.html#floating-point-sdpa-for-training-backpropagation.
Pull Request Overview
This PR enables backward pass support for Scaled Dot Product Attention (SDPA) on Intel XPU devices by implementing the backward kernel for the OVERRIDEABLE backend. The changes focus on extending the existing forward-only implementation to support gradient computation.
- Adds a compute_log_sumexp parameter to the forward pass to conditionally compute the logsumexp values needed for the backward pass (see the sketch after this list)
- Implements the complete backward pass kernel _scaled_dot_product_fused_attention_overrideable_backward_xpu using OneDNN graph operations
- Updates test coverage to include gradient computation validation against the math reference implementation
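To make the intent concrete, here is a minimal sketch (not code from this PR) of how a caller could decide the flag and pass it to the overrideable forward. The helper name and exact argument order are assumptions based on the native_functions.yaml signature quoted later in this thread; later in the thread the flag is proposed to move to the last position, which would change this call.

#include <ATen/ATen.h>
#include <ATen/core/grad_mode.h>
#include <optional>
#include <tuple>

// Hypothetical helper: decide whether logsumexp is needed and call the
// overrideable forward (dispatches to the XPU implementation). Output 1
// (logsumexp) is only meaningful when compute_log_sumexp is true; the
// backward kernel consumes it.
at::Tensor sdpa_overrideable_forward_sketch(
    const at::Tensor& query, const at::Tensor& key, const at::Tensor& value,
    const std::optional<at::Tensor>& attn_bias, double dropout_p,
    bool is_causal, std::optional<double> scale) {
  bool compute_log_sumexp = at::GradMode::is_enabled() &&
      (query.requires_grad() || key.requires_grad() || value.requires_grad());
  auto outputs = at::_scaled_dot_product_fused_attention_overrideable(
      query, key, value, attn_bias, compute_log_sumexp,
      dropout_p, is_causal, /*return_debug_mask=*/false, scale);
  return std::get<0>(outputs);  // attention output
}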
Reviewed Changes
Copilot reviewed 9 out of 12 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
torch/_meta_registrations.py | Adds meta registration for backward function and updates forward signature with compute_log_sumexp parameter |
tools/autograd/derivatives.yaml | Updates function signature to include compute_log_sumexp parameter |
test/test_transformers.py | Enables training mode tests and updates test logic to validate gradient computation |
cmake/Modules/FindMKLDNN.cmake | Updates OneDNN dependency to version 3.9 which supports required backward operations |
aten/src/ATen/native/transformers/attention.cpp | Adds logic to determine when logsumexp computation is needed for backward pass |
aten/src/ATen/native/native_functions.yaml | Updates function signatures and adds XPU dispatch for backward function |
aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h | Adds function declarations for forward and backward SDPA operations |
aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp | Implements complete OneDNN graph-based backward pass logic |
aten/src/ATen/native/mkldnn/xpu/Attention.cpp | Implements XPU-specific backward function wrapper and updates gradient checking logic |
Comments suppressed due to low confidence (2)
aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp:963
- Typo in comment: 'explict' should be 'explicit'.
// and the reference implementation is worse than aten math + explict causal
@LuFinch, the newly added parameter of the SDPA overrideable op is not backward compatible.
@@ -14965,7 +14965,7 @@
    CPU: _scaled_dot_product_flash_attention_cpu
  tags: nondeterministic_seeded

-- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, bool compute_log_sumexp=False, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
@LuFinch, why do we need to add compute_log_sumexp? It breaks the ABI backward compatibility.
Ideally, we could check the input tensors' attributes, like compute_logsumexp = query.requires_grad() || key.requires_grad() || value.requires_grad(), to decide whether to compute logsumexp, and this check works in eager mode.
However, in torch.compile mode, input query/key/value tensors with requires_grad()==True at the beginning can become requires_grad()==False inside the op after aot_autograd in some models. Hence a bool flag is needed to indicate that this op should compute logsumexp. I am not an expert on aot_autograd and am not sure why it acts like this, but cudnn and efficient attention also have this parameter. I guess they hit the same issue; otherwise they would be able to move this check into the op.
pytorch/aten/src/ATen/native/transformers/attention.cpp
Lines 739 to 742 in 78d7f0c

    case SDPBackend::cudnn_attention: {
      bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
      auto out_lse_softmax = at::_scaled_dot_product_cudnn_attention(
          query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, false /*return_debug_mask*/, scale);

pytorch/aten/src/ATen/native/transformers/attention.cpp
Lines 761 to 769 in 78d7f0c

    case SDPBackend::efficient_attention: {
      bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
      if (attn_mask.has_value()) {
        attn_mask.value() = preprocess_mask(attn_mask.value(), query_, key, value);;
      }
      auto out_and_lse = at::_scaled_dot_product_efficient_attention(
          query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, scale);
      return std::get<0>(out_and_lse);
    }
> @LuFinch, why do we need to add compute_log_sumexp? It breaks the ABI backward compatibility.

By default compute_log_sumexp=False, so this should not break API-level BC, right?
Eikan recommended moving this argument to the last position; then it will not break BC.
Cool!
      (attn_mask.has_value() && attn_mask.value().requires_grad()));
}

bool check_grad(sdp::sdp_params const& params, bool debug) {
According to the implementation details, it returns True when:
- Grad mode is not enabled
- All input tensors do not require a gradient
- It is not Group Query Attention and the attention mask does not require a gradient

@LuFinch, is my understanding correct? If so, I would suggest refining the name of check_grad a little bit. Something like is_onednn_attention_backward_supported would illustrate your idea more clearly.
This function should check the grad requirements of the inputs to determine whether they are suitable for the overrideable SDPA on XPU going forward.
As Guangye said, it is used to determine whether to use the overrideable SDPA. If it returns True, then the overrideable SDPA can be used.
In OneDNN v3.9, SDPA training forward and backward don't support GQA and won't output a grad for attn_mask.
Hence this function means (a sketch of the logic follows this list):
- If grad mode is not enabled, we can use overrideable SDPA to run the OneDNN SDPA inference forward graph.
- If grad mode is enabled but none of q/k/v needs grad, we can use overrideable SDPA to run the OneDNN SDPA inference forward graph.
- If we need to compute grad, it is not GQA, and attn_mask doesn't require grad, then we can use overrideable SDPA to run the OneDNN SDPA training forward graph.
- Otherwise, it should fall back to the MATH backend.
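A minimal sketch of that decision logic, assuming OneDNN v3.9's constraints as described above. The function name and warning text are illustrative, not the exact code in this PR; the sdp_params fields mirror the snippets quoted in this review.

#include <ATen/core/grad_mode.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <c10/util/Exception.h>

bool check_grad_sketch(sdp::sdp_params const& params, bool debug) {
  // Inference: the OneDNN SDPA inference forward graph can always be used.
  if (!at::GradMode::is_enabled()) {
    return true;
  }
  bool any_input_requires_grad = params.query.requires_grad() ||
      params.key.requires_grad() || params.value.requires_grad();
  if (!any_input_requires_grad) {
    return true;  // grad mode is on, but nothing needs a gradient
  }
  // Training: OneDNN v3.9 SDPA backward supports neither GQA nor a grad for
  // attn_mask, so those cases must fall back to the MATH backend.
  auto q_num_heads = params.query.sym_size(-3);
  auto k_num_heads = params.key.sym_size(-3);
  auto v_num_heads = params.value.sym_size(-3);
  bool is_gqa = q_num_heads != k_num_heads || q_num_heads != v_num_heads;
  bool attn_mask_needs_grad = params.attn_mask.has_value() &&
      params.attn_mask.value().requires_grad();
  if (debug && (is_gqa || attn_mask_needs_grad)) {
    TORCH_WARN("Overrideable SDPA backward on XPU does not support GQA or attn_mask grad.");
  }
  return !is_gqa && !attn_mask_needs_grad;
}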
  auto k_num_heads = params.key.sym_size(-3);
  auto v_num_heads = params.value.sym_size(-3);
  bool is_gqa = q_num_heads != k_num_heads || q_num_heads != v_num_heads;
  if (debug && is_gqa)
Since it is gqa, why does this function return false?
In OneDNN v3.9, SDPA training forward and backward don't support GQA and won't compute grad for attn_mask.
  bool attn_mask_needs_grad =
      params.attn_mask.has_value() && params.attn_mask.value().requires_grad();
  if (debug && attn_mask_needs_grad) {
ditto
In OneDNN v3.9, SDPA training forward and backward don't support GQA and won't compute grad for attn_mask.
  auto grad_attn_bias = attn_bias_opt.has_value()
      ? at::empty_like(attn_bias_opt.value())
      : at::Tensor();
  at::native::onednn::gpu_float_sdpa_backward(
gpu_float_sdpa_backward has been defined? Does it mean the backward function only supports float?
I have the same question.
It supports FP32/FP16/BF16. I directly copied this function name from the SDPA inference path. We could rename it.
      grad_out.dim() == 4 && out.dim() == 4 &&
          grad_out.size(0) == out.size(0) && grad_out.size(1) == out.size(1) &&
          grad_out.size(2) == out.size(2) && grad_out.size(3) == out.size(3),
      "scaled_dot_product_fused_attention_overrideable_backward_xpu: grad_out and out should have the same shape of {(B), H, T, K}");
What's the meaning of (B)?
      is_causal, logical_params);
  auto i = logical_params.get_input();
  auto o = logical_params.get_output();
  auto compiled_partition = partition_.compile(i, o, eng);
This variable shadows a similar declaration at line 972. It works, but it's not good style. I recommend renaming it to avoid name shadowing.
  if (is_causal) {
    neg_inf = at::full(
        {},
        -INFINITY,
at::numeric_limits<>::lower_bound is better.
  inputs.reserve(l_inputs.size());
  inputs.emplace_back(l_inputs[i++], eng, grad_out.data_ptr());
  inputs.emplace_back(l_inputs[i++], eng, query.data_ptr());
  inputs.emplace_back(l_inputs[i++], eng, key.data_ptr());
  inputs.emplace_back(l_inputs[i++], eng, value.data_ptr());
  inputs.emplace_back(l_inputs[i++], eng, out.data_ptr());
  inputs.emplace_back(l_inputs[i++], eng, logsumexp.data_ptr());
  inputs.emplace_back(l_inputs[i++], eng, softmax_scale.data_ptr());
  if (neg_inf.has_value()) {
    inputs.emplace_back(l_inputs[i++], eng, neg_inf->data_ptr());
  }
  if (attn_mask.has_value()) {
    inputs.emplace_back(l_inputs[i++], eng, attn_mask->data_ptr());
Use a macro to reduce the duplicated code, such as:

#define ADD_INPUT(variable) \
  inputs.emplace_back(l_inputs[i++], eng, (variable).data_ptr())

ADD_INPUT(grad_out);
ADD_INPUT(query);
...
#undef ADD_INPUT
partition& find_or_create_backward_graph_partition(
    bool is_causal,
    const SDPABackwardLogicalParams& params) {
  thread_local static PartitionCache cache;
Suggested change:
- thread_local static PartitionCache cache;
+ thread_local PartitionCache cache;
  std::bitset<32> patternID;
  if (dtype == data_type::f32) {
    // bit 3 corresponds to float32 dtype
    patternID.set(3, 1);
This is fine, but I recommend using a name like kBitFloat32 instead of the hardcoded number 3. Also, kBitFloat32 could be shared with find_or_create_graph_partition.
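For illustration, a hedged sketch of that suggestion; kBitFloat32 and make_pattern_id are hypothetical names, and data_type stands in for oneDNN graph's dnnl::graph::logical_tensor::data_type as used in the surrounding file.

#include <bitset>
#include <cstddef>
#include <oneapi/dnnl/dnnl_graph.hpp>

using data_type = dnnl::graph::logical_tensor::data_type;

// Shared constant so the forward (find_or_create_graph_partition) and the
// backward partition lookups agree on which patternID bit marks fp32 graphs.
constexpr std::size_t kBitFloat32 = 3;

std::bitset<32> make_pattern_id(data_type dtype) {
  std::bitset<32> patternID;
  if (dtype == data_type::f32) {
    patternID.set(kBitFloat32, true);
  }
  return patternID;
}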
  at::Tensor reshaped_attn_mask = attn_mask_.value_or(at::Tensor());
  if (at::native::onednn::is_broadcast(reshaped_query)) {
With this code change, SDPA will not support broadcast anymore. Is this BC-breaking? Is there any impact on old scripts?
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @gujinghui @PenghuiCheng @jianyuh @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @snadampal @EikanWang @fengyuan14 @guangyey