
[Intel GPU] Enable backward for SDPA XPU [WIP] #156272


Open
LuFinch wants to merge 12 commits into main from lfq/sdpa_traning
Conversation

@LuFinch (Contributor) commented Jun 18, 2025


pytorch-bot bot commented Jun 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/156272

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures

As of commit d6ddf7a with merge base 908c5cc:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added module: cpu CPU specific problem (e.g., perf, algorithm) module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration labels Jun 18, 2025
@LuFinch LuFinch changed the title [Intel GPU] Enable training for SDPA XPU [Intel GPU] Enable training for SDPA XPU [WIP] Jun 18, 2025
@LuFinch (Contributor, Author) commented Jun 18, 2025

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Jun 18, 2025

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

@LuFinch LuFinch force-pushed the lfq/sdpa_traning branch from 7014f89 to 232387d on June 18, 2025 11:03
@guangyey guangyey moved this to In Progress in PyTorch Intel Jun 19, 2025
@LuFinch LuFinch force-pushed the lfq/sdpa_traning branch 7 times, most recently from 40bafeb to 435e14a on June 24, 2025 06:27
@LuFinch LuFinch changed the title [Intel GPU] Enable training for SDPA XPU [WIP] [Intel GPU] Enable backward for SDPA XPU [WIP] Jun 25, 2025
@LuFinch LuFinch force-pushed the lfq/sdpa_traning branch from 672a28f to 4d05632 on July 16, 2025 06:32
@LuFinch LuFinch force-pushed the lfq/sdpa_traning branch from 4d05632 to 6de83d4 on July 24, 2025 01:56
@LuFinch LuFinch force-pushed the lfq/sdpa_traning branch from be64cfc to c8d7c3b on July 30, 2025 03:33
@guangyey guangyey added the ciflow/xpu Run XPU CI tasks label Jul 31, 2025

pytorch-bot bot commented Jul 31, 2025

To add the ciflow label ciflow/xpu please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed the ciflow/xpu Run XPU CI tasks label Jul 31, 2025
@guangyey guangyey added the module: xpu Intel XPU related issues label Jul 31, 2025
@guangyey guangyey added the ciflow/xpu Run XPU CI tasks label Aug 5, 2025
@pytorch-bot pytorch-bot bot removed the ciflow/xpu Run XPU CI tasks label Aug 5, 2025
@albanD albanD removed their request for review August 5, 2025 15:24
@LuFinch (Contributor, Author) commented Aug 8, 2025

@EikanWang Could you also help review this PR?

@guangyey guangyey added the ciflow/xpu Run XPU CI tasks label Aug 8, 2025
@guangyey (Collaborator) left a comment:

I need more time to go through this PR...

@pytorch-bot pytorch-bot bot removed the ciflow/xpu Run XPU CI tasks label Aug 11, 2025
@LuFinch (Contributor, Author) commented Aug 11, 2025

> I need more time to go through this PR...

@guangyey Sure. Thank you for taking the time to review.

For more details on oneDNN's SDPA training backpropagation, you can refer to https://uxlfoundation.github.io/oneDNN/dev_guide_graph_sdpa.html#floating-point-sdpa-for-training-backpropagation.

@guangyey guangyey added the ciflow/xpu Run XPU CI tasks label Aug 11, 2025
@EikanWang EikanWang requested a review from Copilot August 11, 2025 22:44
@Copilot Copilot AI left a comment:

Pull Request Overview

This PR enables backward pass support for Scaled Dot Product Attention (SDPA) on Intel XPU devices by implementing the backward kernel for the OVERRIDEABLE backend. The changes focus on extending the existing forward-only implementation to support gradient computation.

  • Adds compute_log_sumexp parameter to the forward pass to conditionally compute logsumexp values needed for backward pass
  • Implements the complete backward pass kernel _scaled_dot_product_fused_attention_overrideable_backward_xpu using OneDNN graph operations
  • Updates test coverage to include gradient computation validation against the math reference implementation

Reviewed Changes

Copilot reviewed 9 out of 12 changed files in this pull request and generated 2 comments.

File | Description
torch/_meta_registrations.py | Adds meta registration for backward function and updates forward signature with compute_log_sumexp parameter
tools/autograd/derivatives.yaml | Updates function signature to include compute_log_sumexp parameter
test/test_transformers.py | Enables training mode tests and updates test logic to validate gradient computation
cmake/Modules/FindMKLDNN.cmake | Updates oneDNN dependency to version 3.9, which supports the required backward operations
aten/src/ATen/native/transformers/attention.cpp | Adds logic to determine when logsumexp computation is needed for backward pass
aten/src/ATen/native/native_functions.yaml | Updates function signatures and adds XPU dispatch for backward function
aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h | Adds function declarations for forward and backward SDPA operations
aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp | Implements complete oneDNN graph-based backward pass logic
aten/src/ATen/native/mkldnn/xpu/Attention.cpp | Implements XPU-specific backward function wrapper and updates gradient checking logic
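For orientation, the sketch below shows roughly what the new declaration in oneDNN.h could look like. The name gpu_float_sdpa_backward is taken from the review comments further down, but the parameter list here is an assumption inferred from the quoted call sites and may not match the PR.

#include <ATen/core/Tensor.h>
#include <optional>

namespace at::native::onednn {

// Illustrative declaration only; parameters inferred from the call sites
// quoted later in this review, not copied from the PR.
void gpu_float_sdpa_backward(
    const at::Tensor& grad_out,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    const at::Tensor& out,
    const at::Tensor& logsumexp,
    const std::optional<at::Tensor>& attn_mask,
    bool is_causal,
    double scale,
    at::Tensor& grad_query,
    at::Tensor& grad_key,
    at::Tensor& grad_value);

} // namespace at::native::onednn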
Comments suppressed due to low confidence (2)

aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp:963

  • Typo in comment: 'explict' should be 'explicit'.
  // and the reference implementation is worse than aten math + explict causal

@pytorch-bot pytorch-bot bot removed the ciflow/xpu Run XPU CI tasks label Aug 12, 2025
@EikanWang (Collaborator) left a comment:

@LuFinch, the newly added parameter for SDPA overrideable backward is not backward compatible.

@@ -14965,7 +14965,7 @@
     CPU: _scaled_dot_product_flash_attention_cpu
   tags: nondeterministic_seeded

-- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, bool compute_log_sumexp=False, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
Collaborator:

@LuFinch, why do we need to add compute_log_sumexp? It breaks ABI backward compatibility.

LuFinch (Contributor, Author):

Ideally, we could check the input tensors' attributes, e.g. compute_logsumexp = query.requires_grad() || key.requires_grad() || value.requires_grad(), to decide whether to compute logsumexp, and that check works in eager mode.

However, in torch.compile mode, query/key/value that have requires_grad() == True at the beginning can end up with requires_grad() == False inside the op after aot_autograd in some models. Hence the op needs a bool flag to indicate that it should compute logsumexp. I am not an expert on aot_autograd and am not sure why it behaves like this, but cudnn and efficient attention also have this parameter. I guess they hit the same issue; otherwise they could have moved this check into the op.

case SDPBackend::cudnn_attention: {
  bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
  auto out_lse_softmax = at::_scaled_dot_product_cudnn_attention(
      query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal,
      false /*return_debug_mask*/, scale);

case SDPBackend::efficient_attention: {
  bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
  if (attn_mask.has_value()) {
    attn_mask.value() = preprocess_mask(attn_mask.value(), query_, key, value);
  }
  auto out_and_lse = at::_scaled_dot_product_efficient_attention(
      query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, scale);
  return std::get<0>(out_and_lse);
}
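Presumably the overrideable branch used by XPU follows the same pattern once the new flag exists. The snippet below is only a sketch assuming the updated schema with compute_log_sumexp, not the PR's exact code.

case SDPBackend::overrideable: {
  // Same idea as the cudnn/efficient branches above: decide on the caller side
  // whether logsumexp is needed for backward and pass it down as a flag.
  bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
  auto out_lse_softmax = at::_scaled_dot_product_fused_attention_overrideable(
      query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal,
      false /*return_debug_mask*/, scale);
  return std::get<0>(out_lse_softmax);
}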

Collaborator:

> @LuFinch, why do we need to add compute_log_sumexp? It breaks ABI backward compatibility.

By default compute_log_sumexp=False, so this should not break API-level BC, right?

(attn_mask.has_value() && attn_mask.value().requires_grad()));
}

bool check_grad(sdp::sdp_params const& params, bool debug) {
Collaborator:

According to the implementation details, it returns true when:

  • grad mode is not enabled
  • none of the input tensors requires a gradient
  • it is not grouped-query attention and the attention mask does not require a gradient

@LuFinch, is my understanding correct? If so, I would suggest refining the name of check_grad a little bit; something like is_onednn_attention_backward_supported would illustrate the intent more clearly.
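A minimal sketch of what the suggested helper could look like, assuming the three conditions above; the grouping of the conditions and the warning text are assumptions, not the PR's actual code.

bool is_onednn_attention_backward_supported(
    sdp::sdp_params const& params, bool debug) {
  // Trivially supported when no gradient will be requested at all.
  const bool inputs_require_grad = params.query.requires_grad() ||
      params.key.requires_grad() || params.value.requires_grad();
  if (!at::GradMode::is_enabled() || !inputs_require_grad) {
    return true;
  }
  // Otherwise backward is only supported without grouped-query attention and
  // without a gradient w.r.t. the attention mask.
  const auto q_num_heads = params.query.sym_size(-3);
  const auto k_num_heads = params.key.sym_size(-3);
  const auto v_num_heads = params.value.sym_size(-3);
  const bool is_gqa =
      q_num_heads != k_num_heads || q_num_heads != v_num_heads;
  const bool attn_mask_needs_grad =
      params.attn_mask.has_value() && params.attn_mask.value().requires_grad();
  if (debug && (is_gqa || attn_mask_needs_grad)) {
    TORCH_WARN(
        "oneDNN SDPA backward on XPU does not support GQA or an attn_mask that requires grad.");
  }
  return !is_gqa && !attn_mask_needs_grad;
}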

Collaborator:

This function should be used to check the grad requirements of inputs to determine whether they are suitable for supporting overrideable SDPA on XPU in the future.

auto k_num_heads = params.key.sym_size(-3);
auto v_num_heads = params.value.sym_size(-3);
bool is_gqa = q_num_heads != k_num_heads || q_num_heads != v_num_heads;
if (debug && is_gqa)
Collaborator:

Since it is GQA, why does this function return false?


bool attn_mask_needs_grad =
params.attn_mask.has_value() && params.attn_mask.value().requires_grad();
if (debug && attn_mask_needs_grad) {
Collaborator:

ditto

auto grad_attn_bias = attn_bias_opt.has_value()
? at::empty_like(attn_bias_opt.value())
: at::Tensor();
at::native::onednn::gpu_float_sdpa_backward(
Collaborator:

gpu_float_sdpa_backward has been defined? Does it mean the backward function only supports float?

Collaborator:

I have the same question.

grad_out.dim() == 4 && out.dim() == 4 &&
grad_out.size(0) == out.size(0) && grad_out.size(1) == out.size(1) &&
grad_out.size(2) == out.size(2) && grad_out.size(3) == out.size(3),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: grad_out and out should have the same shape of {(B), H, T, K}");
Collaborator:

What's the meaning of (B)?

is_causal, logical_params);
auto i = logical_params.get_input();
auto o = logical_params.get_output();
auto compiled_partition = partition_.compile(i, o, eng);
Collaborator:

This variable shadows a similar declaration at line 972. It's fine but not good. I recommend renaming this to avoid name shadowing.

if (is_causal) {
neg_inf = at::full(
{},
-INFINITY,
Collaborator:

at::numeric_limits<>::lower_bound is better.

Comment on lines +1024 to +1036
inputs.reserve(l_inputs.size());
inputs.emplace_back(l_inputs[i++], eng, grad_out.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, query.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, key.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, value.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, out.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, logsumexp.data_ptr());
inputs.emplace_back(l_inputs[i++], eng, softmax_scale.data_ptr());
if (neg_inf.has_value()) {
inputs.emplace_back(l_inputs[i++], eng, neg_inf->data_ptr());
}
if (attn_mask.has_value()) {
inputs.emplace_back(l_inputs[i++], eng, attn_mask->data_ptr());
Collaborator:

Use a macro to reduce the duplicated code, such as:

#define ADD_INPUT(variable) \
  inputs.emplace_back(l_inputs[i++], eng, variable.data_ptr())
ADD_INPUT(grad_out);
ADD_INPUT(query);
...
#undef ADD_INPUT
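Spelled out against the block quoted above (lines +1024 to +1036), that suggestion could look roughly like this; the extra parentheses around variable are just defensive macro hygiene.

#define ADD_INPUT(variable) \
  inputs.emplace_back(l_inputs[i++], eng, (variable).data_ptr())
ADD_INPUT(grad_out);
ADD_INPUT(query);
ADD_INPUT(key);
ADD_INPUT(value);
ADD_INPUT(out);
ADD_INPUT(logsumexp);
ADD_INPUT(softmax_scale);
if (neg_inf.has_value()) {
  ADD_INPUT(neg_inf.value());
}
if (attn_mask.has_value()) {
  ADD_INPUT(attn_mask.value());
}
#undef ADD_INPUT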

partition& find_or_create_backward_graph_partition(
bool is_causal,
const SDPABackwardLogicalParams& params) {
thread_local static PartitionCache cache;
Collaborator:

Suggested change:
- thread_local static PartitionCache cache;
+ thread_local PartitionCache cache;

std::bitset<32> patternID;
if (dtype == data_type::f32) {
// bit 3 corresponds to float32 dtype
patternID.set(3, 1);
Collaborator:

This is fine, but I recommend using a name like kBitFloat32 instead of the hardcoded number 3. Also, kBitFloat32 could be shared with find_or_create_graph_partition.
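A small sketch of that suggestion; kBitFloat32 is an illustrative name rather than an identifier from the PR, and it would live somewhere visible to both find_or_create_graph_partition and find_or_create_backward_graph_partition.

// Shared pattern-ID bit layout (illustrative name).
constexpr size_t kBitFloat32 = 3;  // bit 3 marks a float32 pattern

std::bitset<32> patternID;
if (dtype == data_type::f32) {
  patternID.set(kBitFloat32, /*value=*/true);
}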

at::Tensor reshaped_attn_mask = attn_mask_.value_or(at::Tensor());
if (at::native::onednn::is_broadcast(reshaped_query)) {
Collaborator:

With this code change, SDPA will not support broadcast anymore. Is this BC-breaking? Is there any impact on existing scripts?

Labels
module: cpu, module: mkldnn, module: xpu, open source, release notes: inductor (aoti), topic: not user facing
Projects
Status: In Progress
Development


5 participants