[Intel GPU] Enable backward for SDPA XPU [WIP] #156272
@@ -39,14 +39,37 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) {
   return true;
 }
 
-bool check_no_grad(sdp::sdp_params const& params, bool debug) {
-  const bool any_inputs_require_grad = params.query.requires_grad() ||
-      params.key.requires_grad() || params.value.requires_grad();
-  const bool gradmode_enabled = at::GradMode::is_enabled();
-  if (debug && any_inputs_require_grad && gradmode_enabled) {
-    TORCH_WARN("Backward or grad to be supported.");
-  }
-  return !any_inputs_require_grad || !gradmode_enabled;
-}
+bool input_require_grad(
+    const at::Tensor& query,
+    const at::Tensor& key,
+    const at::Tensor& value,
+    const std::optional<at::Tensor>& attn_mask) {
+  return at::GradMode::is_enabled() &&
+      (query.requires_grad() || key.requires_grad() || value.requires_grad() ||
+       (attn_mask.has_value() && attn_mask.value().requires_grad()));
+}
+
+bool check_grad(sdp::sdp_params const& params, bool debug) {
+  if (!input_require_grad(
+          params.query, params.key, params.value, params.attn_mask))
+    return true;
+
+  auto q_num_heads = params.query.sym_size(-3);
+  auto k_num_heads = params.key.sym_size(-3);
+  auto v_num_heads = params.value.sym_size(-3);
+  bool is_gqa = q_num_heads != k_num_heads || q_num_heads != v_num_heads;
+  if (debug && is_gqa)
+    TORCH_WARN(
+        "scaled_dot_product_attention with GQA is not supported for gradient computation on xpu.");
+
+  bool attn_mask_needs_grad =
+      params.attn_mask.has_value() && params.attn_mask.value().requires_grad();
+  if (debug && attn_mask_needs_grad) {
+    TORCH_WARN(
+        "scaled_dot_product_attention on xpu is not supported when attn_mask.requires_grad() == True.");
+  }
+
+  return !is_gqa && !attn_mask_needs_grad;
+}
 
 bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) {

Review comment (on the GQA and attn_mask checks): In oneDNN v3.9, SDPA training forward and backward don't support GQA and won't compute grad for attn_mask.
@@ -64,7 +87,7 @@ bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) {
       sdp::check_nonzero_sequence_lengths_dense,
       sdp::check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim*/>,
       check_head_dim_size_xpu,
-      check_no_grad);
+      check_grad);
   for (auto& constraint : constraints) {
     if (!constraint(params, debug)) {
       return false;
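Note (not part of the diff): a rough Python mirror of the new check_grad predicate, handy for reasoning about which configurations stay on the overrideable XPU backend during training; the function and variable names below are illustrative only.

```python
import torch

def check_grad_mirror(query, key, value, attn_mask=None):
    # If nothing needs a gradient, the overrideable XPU backend is acceptable.
    needs_grad = torch.is_grad_enabled() and (
        query.requires_grad or key.requires_grad or value.requires_grad
        or (attn_mask is not None and attn_mask.requires_grad))
    if not needs_grad:
        return True
    # Training path: reject GQA (mismatched head counts on dim -3) and masks
    # that require grad, which oneDNN v3.9 cannot handle.
    is_gqa = not (query.size(-3) == key.size(-3) == value.size(-3))
    mask_needs_grad = attn_mask is not None and attn_mask.requires_grad
    return not is_gqa and not mask_needs_grad

q = torch.randn(2, 8, 128, 64, requires_grad=True)
k = torch.randn(2, 2, 128, 64, requires_grad=True)  # 2 KV heads -> GQA
v = torch.randn(2, 2, 128, 64, requires_grad=True)
print(check_grad_mirror(q, k, v))  # False: GQA + grad falls back to another backend
```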
@@ -171,6 +194,7 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
     const at::Tensor& key,
     const at::Tensor& value,
     const std::optional<at::Tensor>& attn_bias,
+    bool compute_logsumexp,
     double dropout_p,
     bool is_causal,
     bool return_debug_mask,
@@ -194,6 +218,9 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
   TORCH_INTERNAL_ASSERT(
       !(attn_bias.has_value() && is_causal),
       "scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot present with is_causal");
+  TORCH_INTERNAL_ASSERT(
+      !(attn_bias.has_value() && attn_bias.value().requires_grad()),
+      "scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot have requires_grad=True");
 
   const int64_t batch_size = query.size(0);
   const int64_t num_head_q = query.size(1);
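Note (illustrative, assumes an XPU-enabled PyTorch build): the situation the new assert guards against. A mask that itself requires grad makes check_grad return false, so the dispatcher is expected to route the call to a different backend instead of this kernel; the assert is a defensive check in case the kernel is reached anyway.

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 32, 64, device="xpu", requires_grad=True)
k = torch.randn(1, 4, 32, 64, device="xpu", requires_grad=True)
v = torch.randn(1, 4, 32, 64, device="xpu", requires_grad=True)
mask = torch.zeros(1, 4, 32, 32, device="xpu", requires_grad=True)

# check_grad returns false for this call, so it should not be dispatched to
# the overrideable XPU kernel; another backend handles the mask gradient.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out.sum().backward()
```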
@@ -203,11 +230,14 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
   const int64_t seq_len_q = query.size(2);
   const int64_t seq_len_kv = key.size(2);
 
-  at::Tensor output;
-  std::vector<int64_t> output_shape = {
+  at::Tensor attention;
+  std::vector<int64_t> attention_shape = {
       batch_size, num_head_q, seq_len_q, head_dim_v};
-  alloc_with_matching_layout(query, output, output_shape);
-  at::Tensor logsumexp, debug_attn_mask; // not supported
+  alloc_with_matching_layout(query, attention, attention_shape);
+
+  auto opts = query.options();
+  at::Tensor logsumexp =
+      at::empty({batch_size, num_head_q, seq_len_q}, opts.dtype(at::kFloat));
 
   at::native::onednn::gpu_float_sdpa(
       batch_size,
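Note (reference formulation, not the oneDNN kernel): for each query row, logsumexp stores log Σ_j exp(s_ij) over the pre-softmax scores, which is enough to rebuild the softmax during the backward pass without saving the full attention matrix. A rough eager-mode sketch of the shape and dtype allocated above:

```python
import math
import torch

B, Hq, Tq, Tkv, D = 2, 8, 128, 128, 64
q = torch.randn(B, Hq, Tq, D, dtype=torch.float16)
k = torch.randn(B, Hq, Tkv, D, dtype=torch.float16)

scale = 1.0 / math.sqrt(D)
scores = (q.float() @ k.float().transpose(-2, -1)) * scale  # [B, Hq, Tq, Tkv]
lse = scores.logsumexp(dim=-1)                              # [B, Hq, Tq], fp32
# softmax can later be recovered as exp(scores - lse.unsqueeze(-1))
```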
@@ -223,21 +253,122 @@ _scaled_dot_product_fused_attention_overrideable_xpu(
       attn_bias,
       is_causal,
       scale.has_value() ? scale.value() : (1.0 / std::sqrt(head_dim_qk)),
-      output);
+      attention,
+      compute_logsumexp,
+      logsumexp);
 
   // rng not used
   auto philox_seed = at::empty({}, at::dtype(at::kLong));
   auto philox_offset = at::empty({}, at::dtype(at::kLong));
   return std::make_tuple(
-      output,
+      attention,
       logsumexp,
       /* cum_seq_q */ at::Tensor(),
       /* cum_seq_k */ at::Tensor(),
       seq_len_q,
       seq_len_kv,
       philox_seed,
       philox_offset,
-      debug_attn_mask);
+      /*debug_attn_mask */ at::Tensor());
 }
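Note (usage sketch, assumes a PyTorch build with XPU support and this patch applied): the training path this PR is meant to enable, i.e. SDPA forward plus backward on the xpu device.

```python
import torch
import torch.nn.functional as F

device = "xpu"
q = torch.randn(2, 8, 128, 64, device=device, dtype=torch.float16, requires_grad=True)
k = torch.randn(2, 8, 128, 64, device=device, dtype=torch.float16, requires_grad=True)
v = torch.randn(2, 8, 128, 64, device=device, dtype=torch.float16, requires_grad=True)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out.sum().backward()
print(q.grad.shape, k.grad.shape, v.grad.shape)
```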
 
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+_scaled_dot_product_fused_attention_overrideable_backward_xpu(
+    const at::Tensor& grad_out,
+    const at::Tensor& query,
+    const at::Tensor& key,
+    const at::Tensor& value,
+    const at::Tensor& attn_bias,
+    std::array<bool, 4> grad_input_mask,
+    const at::Tensor& out,
+    const at::Tensor& logsumexp,
+    const at::Tensor& cum_seq_q,
+    const at::Tensor& cum_seq_k,
+    int64_t max_q,
+    int64_t max_k,
+    double dropout_p,
+    bool is_causal,
+    const at::Tensor& philox_seed,
+    const at::Tensor& philox_offset,
+    std::optional<double> scale) {
+  TORCH_INTERNAL_ASSERT(
+      grad_out.dim() == 4 && out.dim() == 4 &&
+          grad_out.size(0) == out.size(0) && grad_out.size(1) == out.size(1) &&
+          grad_out.size(2) == out.size(2) && grad_out.size(3) == out.size(3),
+      "scaled_dot_product_fused_attention_overrideable_backward_xpu: grad_out and out should have the same shape of {(B), H, T, K}");
+  TORCH_INTERNAL_ASSERT(
+      query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
+      "scaled_dot_product_fused_attention_overrideable_backward_xpu: Accept only 4 dims inputs shape of {(B), H, T, K}");
+  TORCH_INTERNAL_ASSERT(
+      (key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) &&
+          (key.size(2) == value.size(2)),
+      "scaled_dot_product_fused_attention_overrideable_backward_xpu: K/V should have the same batch / seq / num_head");
+  TORCH_INTERNAL_ASSERT(
+      query.size(0) == grad_out.size(0) && query.size(1) == grad_out.size(1) &&
+          query.size(2) == grad_out.size(2),
+      "scaled_dot_product_fused_attention_overrideable_backward_xpu: Q should have the same batch / num_head / seq_len as grad_out");
+  TORCH_INTERNAL_ASSERT(
+      query.size(3) == key.size(3),
+      "scaled_dot_product_fused_attention_overrideable_backward_xpu: Q/K should have the same head_dim");
+  TORCH_INTERNAL_ASSERT(
+      value.size(3) == grad_out.size(3),
+      "scaled_dot_product_fused_attention_overrideable_backward_xpu: V should have the same head_dim as grad_out");
+  TORCH_INTERNAL_ASSERT(
+      query.size(1) == key.size(1),
+      "scaled_dot_product_fused_attention_overrideable_backward_xpu: number of heads in K/V must equal to number of heads in Q");
+  TORCH_INTERNAL_ASSERT(
+      dropout_p == 0.0,
+      "scaled_dot_product_fused_attention_overrideable_backward_xpu: Currently do not support dropout > 0");
+  TORCH_INTERNAL_ASSERT(
+      logsumexp.dim() == 3 && logsumexp.size(0) == query.size(0) &&
+          logsumexp.size(1) == query.size(1) &&
+          logsumexp.size(2) == query.size(2),
+      "scaled_dot_product_fused_attention_overrideable_backward_xpu: logsumexp should have the shape of {(B), H, T}");
+
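Note (reviewer aid; the shorthand names are not from the code): the shape contract the asserts above pin down, with B = batch, H = heads, T = sequence length, D = head dim.

```python
# No GQA here, so H is shared by query, key and value.
expected_shapes = {
    "query":     ("B", "H", "T_q",  "D_qk"),
    "key":       ("B", "H", "T_kv", "D_qk"),
    "value":     ("B", "H", "T_kv", "D_v"),
    "out":       ("B", "H", "T_q",  "D_v"),
    "grad_out":  ("B", "H", "T_q",  "D_v"),  # must match out exactly
    "logsumexp": ("B", "H", "T_q"),          # fp32, saved by the forward pass
}
```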
+  std::optional<Tensor> attn_bias_opt;
+  if (attn_bias.defined()) {
+    attn_bias_opt = attn_bias;
+  }
+
+  const int64_t batch_size = query.size(0);
+  const int64_t num_head_q = query.size(1);
+  const int64_t num_head_kv = key.size(1);
+  const int64_t seq_len_q = query.size(2);
+  const int64_t seq_len_kv = key.size(2);
+  const int64_t head_dim_qk = query.size(3);
+  const int64_t head_dim_v = value.size(3);
+
+  auto grad_q = at::empty_like(query);
+  auto grad_k = at::empty_like(key);
+  auto grad_v = at::empty_like(value);
+  auto grad_attn_bias = attn_bias_opt.has_value()
+      ? at::empty_like(attn_bias_opt.value())
+      : at::Tensor();
+  at::native::onednn::gpu_float_sdpa_backward(
+      batch_size,
+      num_head_q,
+      num_head_kv,
+      seq_len_q,
+      seq_len_kv,
+      head_dim_qk,
+      head_dim_v,
+      grad_out,
+      query,
+      key,
+      value,
+      out,
+      logsumexp,
+      attn_bias_opt,
+      is_causal,
+      scale.has_value() ? scale.value() : (1.0 / std::sqrt(query.size(3))),
+      grad_q,
+      grad_k,
+      grad_v);
+  return std::make_tuple(
+      std::move(grad_q),
+      std::move(grad_k),
+      std::move(grad_v),
+      std::move(grad_attn_bias));
+}

Review comment (on the gpu_float_sdpa_backward call): it supports FP32/FP16/BF16; the name was copied directly from the SDPA inference path and could be renamed.
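Note (reference math; illustrative only, not the oneDNN implementation): what the fused backward has to compute, expressed in terms of the saved out and logsumexp. This is the standard SDPA backward recurrence and is the reason both tensors are passed down to gpu_float_sdpa_backward.

```python
import torch

def sdpa_backward_reference(grad_out, q, k, v, out, lse, scale, attn_bias=None):
    # For clarity, assume float32 inputs throughout.
    scores = (q @ k.transpose(-2, -1)) * scale            # [B, H, T_q, T_kv]
    if attn_bias is not None:
        scores = scores + attn_bias
    p = torch.exp(scores - lse.unsqueeze(-1))              # softmax rebuilt from logsumexp

    grad_v = p.transpose(-2, -1) @ grad_out                # [B, H, T_kv, D_v]
    grad_p = grad_out @ v.transpose(-2, -1)                # [B, H, T_q, T_kv]
    # Softmax backward, row-wise: dS = P * (dP - sum(dO * O, dim=-1))
    delta = (grad_out * out).sum(dim=-1, keepdim=True)
    grad_scores = p * (grad_p - delta)
    grad_q = (grad_scores @ k) * scale                     # [B, H, T_q, D_qk]
    grad_k = (grad_scores.transpose(-2, -1) @ q) * scale   # [B, H, T_kv, D_qk]
    return grad_q, grad_k, grad_v
```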
+
 REGISTER_XPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_xpu);
Review comment (on check_grad): According to the implementation details, it returns true either when no input requires grad, or when grads are required but in a configuration the backward supports. @LuFinch, is my understanding correct? If so, I would suggest refining the name of check_grad a little bit; something like is_onednn_attention_backward_supported would illustrate your idea more clearly.

Reply: This function should be used to check the grad requirements of the inputs to determine whether they are suitable for the overrideable SDPA on XPU.

Reply: As Guangye said, it is used to determine whether to use the overrideable SDPA: if it returns true, the overrideable SDPA can be used. In oneDNN v3.9, SDPA training forward and backward don't support GQA and won't output grad for attn_mask.
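Note (test sketch, assumes an XPU build with this patch; tolerances are illustrative): a quick way to exercise the new backward and compare its gradients against the CPU reference.

```python
import torch
import torch.nn.functional as F

def sdpa_grads(device):
    torch.manual_seed(0)
    q = torch.randn(2, 8, 64, 64).to(device).requires_grad_()
    k = torch.randn(2, 8, 64, 64).to(device).requires_grad_()
    v = torch.randn(2, 8, 64, 64).to(device).requires_grad_()
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out.sum().backward()
    return [t.grad.cpu() for t in (q, k, v)]

ref = sdpa_grads("cpu")   # CPU reference
res = sdpa_grads("xpu")   # routed through the XPU kernels
for r, x in zip(ref, res):
    torch.testing.assert_close(r, x, atol=1e-4, rtol=1e-4)
```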