[Intel GPU] Enable backward for SDPA XPU [WIP] #156272

Status: Open · wants to merge 12 commits into base: main
161 changes: 146 additions & 15 deletions aten/src/ATen/native/mkldnn/xpu/Attention.cpp
@@ -39,14 +39,37 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) {
return true;
}

bool check_no_grad(sdp::sdp_params const& params, bool debug) {
const bool any_inputs_require_grad = params.query.requires_grad() ||
params.key.requires_grad() || params.value.requires_grad();
const bool gradmode_enabled = at::GradMode::is_enabled();
if (debug && any_inputs_require_grad && gradmode_enabled) {
TORCH_WARN("Backward or grad to be supported.");
bool input_require_grad(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const std::optional<at::Tensor>& attn_mask) {
return at::GradMode::is_enabled() &&
(query.requires_grad() || key.requires_grad() || value.requires_grad() ||
(attn_mask.has_value() && attn_mask.value().requires_grad()));
}

bool check_grad(sdp::sdp_params const& params, bool debug) {
Collaborator:

According to the implementation, it returns true when:

  • grad mode is not enabled, or
  • none of the input tensors requires a gradient, or
  • it is not Group Query Attention and the attention mask does not require a gradient.

@LuFinch, is my understanding correct? If so, I would suggest refining the name check_grad a little; something like is_onednn_attention_backward_supported would illustrate the idea more clearly.

Collaborator:

This function should check the grad requirements of the inputs to determine whether they are suitable for the overrideable SDPA path on XPU going forward.

Contributor Author:

As Guangye said, it is used to determine whether to use overrideable SDPA. If it returns true, overrideable SDPA can be used.

In oneDNN v3.9, the SDPA training forward and backward do not support GQA and do not output a gradient for attn_mask.

Hence this function means (see the sketch after this list):

  • If grad mode is not enabled, we can use overrideable SDPA to run the oneDNN SDPA inference forward graph.
  • If grad mode is enabled but none of q/k/v requires grad, we can use overrideable SDPA to run the oneDNN SDPA inference forward graph.
  • If we need to compute grads, it is not GQA, and attn_mask does not require grad, then we can use overrideable SDPA to run the oneDNN SDPA training forward graph.
  • Otherwise, it should fall back to the MATH backend.
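To make these cases concrete, a minimal caller-side sketch (illustrative only; it assumes a libtorch build with Intel XPU support, and the shapes and dtype are arbitrary):

#include <ATen/ATen.h>

int main() {
  // 2 batches, 8 heads, sequence length 128, head_dim 64; equal Q/K/V head counts, so no GQA.
  auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kXPU);
  auto q = at::randn({2, 8, 128, 64}, opts).requires_grad_(true);
  auto k = at::randn({2, 8, 128, 64}, opts).requires_grad_(true);
  auto v = at::randn({2, 8, 128, 64}, opts).requires_grad_(true);

  // q/k/v require grad, there is no GQA, and no mask requires grad, so check_grad()
  // returns true and the overrideable oneDNN training-forward path is eligible.
  auto out = at::scaled_dot_product_attention(q, k, v);
  out.sum().backward();

  // If an attn_mask with requires_grad() == true were passed, check_grad() would
  // return false and the call would fall back to the MATH backend instead.
  return 0;
}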

if (!input_require_grad(
params.query, params.key, params.value, params.attn_mask))
return true;

auto q_num_heads = params.query.sym_size(-3);
auto k_num_heads = params.key.sym_size(-3);
auto v_num_heads = params.value.sym_size(-3);
bool is_gqa = q_num_heads != k_num_heads || q_num_heads != v_num_heads;
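// Grouped Query Attention (GQA): the Q head count differs from the K/V head count.
// oneDNN v3.9's SDPA training forward/backward do not support GQA, so gradient
// computation must fall back in that case.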
if (debug && is_gqa)
Collaborator:

Since it is GQA, why does this function return false?

Contributor Author:

In oneDNN v3.9, the SDPA training forward and backward do not support GQA and do not compute a gradient for attn_mask.

TORCH_WARN(
"scale_dot_product_attention with gqa is not supported for gradient computation on xpu.");

bool attn_mask_needs_grad =
params.attn_mask.has_value() && params.attn_mask.value().requires_grad();
if (debug && attn_mask_needs_grad) {
Collaborator:

ditto

Contributor Author:

In oneDNN v3.9, the SDPA training forward and backward do not support GQA and do not compute a gradient for attn_mask.

TORCH_WARN(
"scale_dot_product_attention on xpu is not supported when attn_mask.requires_grad() == True.");
}
return !any_inputs_require_grad || !gradmode_enabled;

return !is_gqa && !attn_mask_needs_grad;
}

bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) {
@@ -64,7 +87,7 @@ bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) {
sdp::check_nonzero_sequence_lengths_dense,
sdp::check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim*/>,
check_head_dim_size_xpu,
check_no_grad);
check_grad);
for (auto& constraint : constraints) {
if (!constraint(params, debug)) {
return false;
@@ -171,6 +194,7 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
const at::Tensor& key,
const at::Tensor& value,
const std::optional<at::Tensor>& attn_bias,
bool compute_logsumexp,
double dropout_p,
bool is_causal,
bool return_debug_mask,
@@ -194,6 +218,9 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
TORCH_INTERNAL_ASSERT(
!(attn_bias.has_value() && is_causal),
"scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot present with is_causal");
TORCH_INTERNAL_ASSERT(
!(attn_bias.has_value() && attn_bias.value().requires_grad()),
"scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot have requires_grad=True");

const int64_t batch_size = query.size(0);
const int64_t num_head_q = query.size(1);
@@ -203,11 +230,14 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
const int64_t seq_len_q = query.size(2);
const int64_t seq_len_kv = key.size(2);

at::Tensor output;
std::vector<int64_t> output_shape = {
at::Tensor attention;
std::vector<int64_t> attention_shape = {
batch_size, num_head_q, seq_len_q, head_dim_v};
alloc_with_matching_layout(query, output, output_shape);
at::Tensor logsumexp, debug_attn_mask; // not supported
alloc_with_matching_layout(query, attention, attention_shape);

auto opts = query.options();
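// logsumexp stores, for each (batch, head, query position), the log-sum-exp of the
// scaled attention scores; the backward pass uses it to reconstruct the softmax
// probabilities without re-reducing over the key dimension. It is requested via the
// compute_logsumexp flag passed to the kernel below.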
at::Tensor logsumexp =
at::empty({batch_size, num_head_q, seq_len_q}, opts.dtype(at::kFloat));

at::native::onednn::gpu_float_sdpa(
batch_size,
@@ -223,21 +253,122 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
attn_bias,
is_causal,
scale.has_value() ? scale.value() : (1.0 / std::sqrt(head_dim_qk)),
output);
attention,
compute_logsumexp,
logsumexp);

// rng not used
auto philox_seed = at::empty({}, at::dtype(at::kLong));
auto philox_offset = at::empty({}, at::dtype(at::kLong));
return std::make_tuple(
output,
attention,
logsumexp,
/* cum_seq_q */ at::Tensor(),
/* cum_seq_k */ at::Tensor(),
seq_len_q,
seq_len_kv,
philox_seed,
philox_offset,
debug_attn_mask);
/*debug_attn_mask */ at::Tensor());
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_fused_attention_overrideable_backward_xpu(
const at::Tensor& grad_out,
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const at::Tensor& attn_bias,
std::array<bool, 4> grad_input_mask,
const at::Tensor& out,
const at::Tensor& logsumexp,
const at::Tensor& cum_seq_q,
const at::Tensor& cum_seq_k,
int64_t max_q,
int64_t max_k,
double dropout_p,
bool is_causal,
const at::Tensor& philox_seed,
const at::Tensor& philox_offset,
std::optional<double> scale) {
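// Shape convention: {B, H, T, K} = {batch, num_heads, seq_len, head_dim};
// grad_out/out use head_dim_v, while query/key use head_dim_qk.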
TORCH_INTERNAL_ASSERT(
grad_out.dim() == 4 && out.dim() == 4 &&
grad_out.size(0) == out.size(0) && grad_out.size(1) == out.size(1) &&
grad_out.size(2) == out.size(2) && grad_out.size(3) == out.size(3),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: grad_out and out should have the same shape of {(B), H, T, K}");
Collaborator:

What's the meaning of (B)?

TORCH_INTERNAL_ASSERT(
query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
"scaled_dot_product_fused_attention_overrideable_backward_xpu: Accept only 4 dims inputs shape of {(B), H, T, K}");
TORCH_INTERNAL_ASSERT(
(key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) &&
(key.size(2) == value.size(2)),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: K/V should have the same batch / seq / num_head");
TORCH_INTERNAL_ASSERT(
query.size(0) == grad_out.size(0) && query.size(1) == grad_out.size(1) &&
query.size(2) == grad_out.size(2),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: Q should have the same batch / num_head / seq_len as grad_out");
TORCH_INTERNAL_ASSERT(
query.size(3) == key.size(3),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: Q/K should have the same head_dim");
TORCH_INTERNAL_ASSERT(
value.size(3) == grad_out.size(3),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: V should have the same head_dim as grad_out");
TORCH_INTERNAL_ASSERT(
query.size(1) == key.size(1),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: number of heads in K/V must equal to number of heads in Q");
TORCH_INTERNAL_ASSERT(
dropout_p == 0.0,
"scaled_dot_product_fused_attention_overrideable_backward_xpu: Currently do not support dropout > 0");
TORCH_INTERNAL_ASSERT(
logsumexp.dim() == 3 && logsumexp.size(0) == query.size(0) &&
logsumexp.size(1) == query.size(1) &&
logsumexp.size(2) == query.size(2),
"scaled_dot_product_fused_attention_overrideable_backward_xpu: logsumexp should have the shape of {(B), H, T}");

std::optional<Tensor> attn_bias_opt;
if (attn_bias.defined()) {
attn_bias_opt = attn_bias;
}

const int64_t batch_size = query.size(0);
const int64_t num_head_q = query.size(1);
const int64_t num_head_kv = key.size(1);
const int64_t seq_len_q = query.size(2);
const int64_t seq_len_kv = key.size(2);
const int64_t head_dim_qk = query.size(3);
const int64_t head_dim_v = value.size(3);

auto grad_q = at::empty_like(query);
auto grad_k = at::empty_like(key);
auto grad_v = at::empty_like(value);
auto grad_attn_bias = attn_bias_opt.has_value()
? at::empty_like(attn_bias_opt.value())
: at::Tensor();
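// oneDNN does not emit a gradient for the attention mask, so grad_attn_bias is only
// allocated here and is not written by gpu_float_sdpa_backward below; check_grad()
// prevents selecting this backend when attn_mask.requires_grad() is true.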
at::native::onednn::gpu_float_sdpa_backward(
Collaborator:

Has gpu_float_sdpa_backward already been defined? Does the name mean the backward function only supports float?

Collaborator:

I have the same question.

Contributor Author:

It supports FP32/FP16/BF16. I directly copied the function name from the SDPA inference path. We could rename it. (A bf16 sketch follows this function.)

batch_size,
num_head_q,
num_head_kv,
seq_len_q,
seq_len_kv,
head_dim_qk,
head_dim_v,
grad_out,
query,
key,
value,
out,
logsumexp,
attn_bias_opt,
is_causal,
scale.has_value() ? scale.value() : (1.0 / std::sqrt(query.size(3))),
grad_q,
grad_k,
grad_v);
return std::make_tuple(
std::move(grad_q),
std::move(grad_k),
std::move(grad_v),
std::move(grad_attn_bias));
}
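On the dtype question above: per the author, the kernel handles FP32/FP16/BF16 despite the "float" in its name. A minimal bf16 sketch (illustrative only; assumes a libtorch build with Intel XPU support and that the overrideable backend is selected):

#include <ATen/ATen.h>

int main() {
  // Same fused forward/backward path as the fp16 sketch earlier, but in bfloat16.
  auto opts = at::TensorOptions().dtype(at::kBFloat16).device(at::kXPU);
  auto q = at::randn({1, 4, 64, 64}, opts).requires_grad_(true);
  auto k = at::randn({1, 4, 64, 64}, opts).requires_grad_(true);
  auto v = at::randn({1, 4, 64, 64}, opts).requires_grad_(true);

  auto out = at::scaled_dot_product_attention(
      q, k, v, /*attn_mask=*/{}, /*dropout_p=*/0.0, /*is_causal=*/true);
  out.sum().backward();  // q.grad(), k.grad(), v.grad() come back in bf16
  return 0;
}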

REGISTER_XPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_xpu);