From 0c913b38d64e2f8e769ee398fdc73eb49cee6352 Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Wed, 18 Jun 2025 20:35:29 -0700 Subject: [PATCH 01/11] integrate sdpa training forward/backward --- aten/src/ATen/native/mkldnn/xpu/Attention.cpp | 171 ++++- .../native/mkldnn/xpu/detail/Attention.cpp | 677 +++++++++++++++++- .../ATen/native/mkldnn/xpu/detail/oneDNN.h | 25 +- aten/src/ATen/native/native_functions.yaml | 1 + cmake/Modules/FindMKLDNN.cmake | 5 +- test/test_transformers.py | 37 +- torch/_meta_registrations.py | 26 + 7 files changed, 874 insertions(+), 68 deletions(-) diff --git a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp index 813db7a97ef9..02ce3cf739cd 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp @@ -39,14 +39,32 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) { return true; } -bool check_no_grad(sdp::sdp_params const& params, bool debug) { - const bool any_inputs_require_grad = params.query.requires_grad() || - params.key.requires_grad() || params.value.requires_grad(); - const bool gradmode_enabled = at::GradMode::is_enabled(); - if (debug && any_inputs_require_grad && gradmode_enabled) { - TORCH_WARN("Backward or grad to be supported."); +bool input_require_grad(const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_mask) { + return at::GradMode::is_enabled() && (query.requires_grad() || key.requires_grad() || value.requires_grad() + || (attn_mask.has_value() && attn_mask.value().requires_grad())); +} + +bool check_grad(sdp::sdp_params const& params, bool debug) { + if (!input_require_grad(params.query, params.key, params.value, params.attn_mask)) + return true; + + auto q_num_heads = params.query.sym_size(-3); + auto k_num_heads = params.key.sym_size(-3); + auto v_num_heads = params.value.sym_size(-3); + bool is_gqa = q_num_heads != k_num_heads || q_num_heads != v_num_heads; + if (debug && is_gqa) + TORCH_WARN("scale_dot_product_attention with gqa is not supported for gradient computation on xpu."); + + bool attn_mask_needs_grad = params.attn_mask.has_value() && params.attn_mask.value().requires_grad(); + if (debug && attn_mask_needs_grad) { + TORCH_WARN( + "scale_dot_product_attention on xpu is not supported with attn_mask.requires_grad() == True."); } - return !any_inputs_require_grad || !gradmode_enabled; + + return !is_gqa && !attn_mask_needs_grad; } bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) { @@ -64,7 +82,7 @@ bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) { sdp::check_nonzero_sequence_lengths_dense, sdp::check_last_dim_stride_equals_1_dense, check_head_dim_size_xpu, - check_no_grad); + check_grad); for (auto& constraint : constraints) { if (!constraint(params, debug)) { return false; @@ -194,6 +212,9 @@ _scaled_dot_product_fused_attention_overrideable_xpu( TORCH_INTERNAL_ASSERT( !(attn_bias.has_value() && is_causal), "scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot present with is_causal"); + TORCH_INTERNAL_ASSERT( + !(attn_bias.has_value() && attn_bias.value().requires_grad()), + "scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot have requires_grad=True"); const int64_t batch_size = query.size(0); const int64_t num_head_q = query.size(1); @@ -202,12 +223,18 @@ _scaled_dot_product_fused_attention_overrideable_xpu( const int64_t head_dim_v = value.size(3); const int64_t seq_len_q = 
query.size(2); const int64_t seq_len_kv = key.size(2); + const bool compute_logsumexp = input_require_grad(query, key, value, attn_bias); - at::Tensor output; - std::vector output_shape = { + at::Tensor attention; + std::vector attention_shape = { batch_size, num_head_q, seq_len_q, head_dim_v}; - alloc_with_matching_layout(query, output, output_shape); - at::Tensor logsumexp, debug_attn_mask; // not supported + alloc_with_matching_layout(query, attention, attention_shape); + + at::Tensor logsumexp; + if (compute_logsumexp) { + logsumexp = at::empty( + {batch_size, num_head_q, seq_len_q, 1}, opts.dtype(at::kFloat)); + } at::native::onednn::gpu_float_sdpa( batch_size, @@ -223,13 +250,15 @@ _scaled_dot_product_fused_attention_overrideable_xpu( attn_bias, is_causal, scale.has_value() ? scale.value() : (1.0 / std::sqrt(head_dim_qk)), - output); + attention, + compute_logsumexp, + logsumexp); // rng not used auto philox_seed = at::empty({}, at::dtype(at::kLong)); auto philox_offset = at::empty({}, at::dtype(at::kLong)); return std::make_tuple( - output, + attention, logsumexp, /* cum_seq_q */ at::Tensor(), /* cum_seq_k */ at::Tensor(), @@ -237,7 +266,119 @@ _scaled_dot_product_fused_attention_overrideable_xpu( seq_len_kv, philox_seed, philox_offset, - debug_attn_mask); + /*debug_attn_mask */ at::Tensor()); +} + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +_scaled_dot_product_fused_attention_overrideable_backward_xpu( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& attn_bias, + std::array grad_input_mask, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cum_seq_q, + const at::Tensor& cum_seq_k, + int64_t max_q, + int64_t max_k, + double dropout_p, + bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + std::optional scale) { + if (!grad_out.defined()) { + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}); + } + + TORCH_INTERNAL_ASSERT( + grad_out.dim() == 4 && out.dim() == 4 && + grad_out.size(0) == out.size(0) && + grad_out.size(1) == out.size(1) && + grad_out.size(2) == out.size(2) && + grad_out.size(3) == out.size(3), + "scaled_dot_product_fused_attention_overrideable_backward_xpu: grad_out and out should have the same shape of {(B), H, T, K}"); + TORCH_INTERNAL_ASSERT( + query.dim() == 4 && key.dim() == 4 && value.dim() == 4, + "scaled_dot_product_fused_attention_overrideable_backward_xpu: Accept only 4 dims inputs shape of {(B), H, T, K}"); + TORCH_INTERNAL_ASSERT( + (key.size(0) == value.size(0)) && (key.size(1) == value.size(1)) && + (key.size(2) == value.size(2)), + "scaled_dot_product_fused_attention_overrideable_backward_xpu: K/V should have the same batch / seq / num_head"); + TORCH_INTERNAL_ASSERT( + query.size(0) == grad_out.size(0) && + query.size(1) == grad_out.size(1) && + query.size(2) == grad_out.size(2), + "scaled_dot_product_fused_attention_overrideable_backward_xpu: Q should have the same batch / num_head / seq_len as grad_out"); + TORCH_INTERNAL_ASSERT( + query.size(3) == key.size(3), + "scaled_dot_product_fused_attention_overrideable_backward_xpu: Q/K should have the same head_dim"); + TORCH_INTERNAL_ASSERT( + value.size(3) == grad_out.size(3), + "scaled_dot_product_fused_attention_overrideable_backward_xpu: V should have the same head_dim as grad_out"); + TORCH_INTERNAL_ASSERT( + query.size(1) == key.size(1), + "scaled_dot_product_fused_attention_overrideable_backward_xpu: number of heads in K/V 
must equal to number of heads in Q"); + TORCH_INTERNAL_ASSERT( + dropout_p == 0.0, + "scaled_dot_product_fused_attention_overrideable_backward_xpu: Currently do not support dropout > 0"); + TORCH_INTERNAL_ASSERT(logsumexp.dim() == 4 && + logsumexp.size(0) == query.size(0) && + logsumexp.size(1) == query.size(1) && + logsumexp.size(2) == query.size(2) && + logsumexp.size(3) == 1, + "scaled_dot_product_fused_attention_overrideable_backward_xpu: logsumexp should have the shape of {(B), H, T, 1}"); + + std::optional attn_bias_opt; + if (attn_bias.defined()) { + attn_bias_opt = attn_bias; + } + TORCH_INTERNAL_ASSERT( + !(attn_bias_opt.has_value() && is_causal), + "scaled_dot_product_fused_attention_overrideable_backward_xpu: attn_bias cannot present with is_causal"); + + const int64_t batch_size = query.size(0); + const int64_t num_head_q = query.size(1); + const int64_t num_head_kv = key.size(1); + const int64_t seq_len_q = query.size(2); + const int64_t seq_len_kv = key.size(2); + const int64_t head_dim_qk = query.size(3); + const int64_t head_dim_v = value.size(3); + + auto grad_q = at::empty_like(query); + auto grad_k = at::empty_like(key); + auto grad_v = at::empty_like(value); + + at::native::onednn::gpu_float_sdpa_backward( + batch_size, + num_head_q, + num_head_kv, + seq_len_q, + seq_len_kv, + head_dim_qk, + head_dim_v, + grad_out, + query, + key, + value, + out, + logsumexp, + attn_bias_opt, + is_causal, + scale.has_value() ? scale.value() : (1.0 / std::sqrt(query.size(3))), + grad_q, + grad_k, + grad_v); + return std::make_tuple( + std::move(grad_q), + std::move(grad_k), + std::move(grad_v), + at::Tensor()); } REGISTER_XPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_xpu); diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp index 1d90711f6e38..3f33d6b85c80 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp @@ -13,6 +13,9 @@ using dims = logical_tensor::dims; using op = dnnl::graph::op; using partition = dnnl::graph::partition; +constexpr logical_tensor::data_type sdpa_intermedia_dtype = + logical_tensor::data_type::f32; + inline data_type to_logical_tensor_data_type(c10::ScalarType scalar_type) { return scalar_type == c10::ScalarType::Float ? data_type::f32 : scalar_type == c10::ScalarType::Half ? 
data_type::f16 @@ -20,6 +23,8 @@ inline data_type to_logical_tensor_data_type(c10::ScalarType scalar_type) { : data_type::undef; } +namespace sdpa_forward { + struct SDPALogicalParams { enum class TensorID { query, @@ -28,7 +33,8 @@ struct SDPALogicalParams { neg_inf, attn_mask, value, - output, + attention, + logsumexp, end, }; @@ -38,14 +44,16 @@ struct SDPALogicalParams { std::optional neg_inf; std::optional attn_mask; logical_tensor value{}; - logical_tensor output{}; + logical_tensor attention{}; + std::optional logsumexp; SDPALogicalParams( const at::Tensor& query_, const at::Tensor& key_, const at::Tensor& value_, const std::optional& attn_mask_, - const at::Tensor& output_, + const at::Tensor& attention_, + const at::Tensor& logsumexp_, int batch_size, int seq_len_q, int seq_len_kv, @@ -53,31 +61,27 @@ struct SDPALogicalParams { int num_head_kv, int head_dim_qk, int head_dim_v, - bool is_causal) { + bool is_causal, + bool compute_logsumexp) { const data_type dtype = to_logical_tensor_data_type(query_.scalar_type()); TORCH_INTERNAL_ASSERT( (dtype != data_type::undef), "Only FP16/BF16/FP32 datatypes are currently supported"); + TORCH_INTERNAL_ASSERT( + query_.scalar_type() == attention_.scalar_type(), + "scaled_dot_product_attention_xpu: query and attention tensors should have the same data type."); + TORCH_INTERNAL_ASSERT( + !at::native::onednn::is_broadcast(query_) && + !at::native::onednn::is_broadcast(key_) && + !at::native::onednn::is_broadcast(value_), + "scaled_dot_product_attention_xpu: tensors q/k/v should not be broadcasted tensor."); const dims scalar_shape = {1}; - std::vector inputLogicalTensors; at::Tensor reshaped_query = query_; at::Tensor reshaped_key = key_; at::Tensor reshaped_value = value_; - at::Tensor reshaped_output = output_; + at::Tensor reshaped_attention = attention_; at::Tensor reshaped_attn_mask = attn_mask_.value_or(at::Tensor()); - if (at::native::onednn::is_broadcast(reshaped_query)) { - at::native::onednn::undo_broadcast(reshaped_query); - } - if (at::native::onednn::is_broadcast(reshaped_key)) { - at::native::onednn::undo_broadcast(reshaped_key); - } - if (at::native::onednn::is_broadcast(reshaped_value)) { - at::native::onednn::undo_broadcast(reshaped_value); - } - if (at::native::onednn::is_broadcast(reshaped_output)) { - at::native::onednn::undo_broadcast(reshaped_output); - } if (attn_mask_.has_value() && at::native::onednn::is_broadcast(reshaped_attn_mask)) { at::native::onednn::undo_broadcast(reshaped_attn_mask); @@ -95,7 +99,7 @@ struct SDPALogicalParams { {batch_size, group_num, group_size, seq_len_q, head_dim_qk}); reshaped_key = key_.unsqueeze(2); reshaped_value = value_.unsqueeze(2); - reshaped_output = output_.view( + reshaped_attention = attention_.view( {batch_size, group_num, group_size, seq_len_q, head_dim_v}); if (attn_mask_.has_value() && attn_mask_.value().dim() == 4) { reshaped_attn_mask = attn_mask_.value().unsqueeze(2); @@ -143,11 +147,23 @@ struct SDPALogicalParams { dtype, reshaped_value.sizes().vec(), reshaped_value.strides().vec()}; - output = { - static_cast(TensorID::output), + attention = { + static_cast(TensorID::attention), dtype, - reshaped_output.sizes().vec(), - reshaped_output.strides().vec()}; + reshaped_attention.sizes().vec(), + reshaped_attention.strides().vec()}; + if (compute_logsumexp) { + TORCH_INTERNAL_ASSERT( + logsumexp_.scalar_type() == at::kFloat, + "scaled_dot_product_attention: Expected logsumexp data type in FP32, but got ", + logsumexp_.scalar_type(), + " instead."); + logsumexp = { + 
static_cast(TensorID::logsumexp), + sdpa_intermedia_dtype, + logsumexp_.sizes().vec(), + logsumexp_.strides().vec()}; + } } std::vector get_input() const { std::vector input = {query, key, scale}; @@ -161,16 +177,21 @@ struct SDPALogicalParams { return input; } std::vector get_output() const { - return {output}; + std::vector output; + output.push_back(attention); + if (logsumexp.has_value()) { + output.push_back(logsumexp.value()); + } + return output; } }; partition create_sdpa_graph_partition( bool is_causal, + bool compute_logsumexp, data_type dtype, const SDPALogicalParams& params) { // graph building and partitioning - // currently, we assume that Q and K have same sequence length size_t lt_id = static_cast(SDPALogicalParams::TensorID::end); size_t op_id = 0; @@ -180,7 +201,7 @@ partition create_sdpa_graph_partition( // Matrix Extensions (Intel(R) XMX) support, which means the // Q/K/V tensors have bf16 or f16 data type while the output of the first // MatMul, Scale, Mask, and the input of SoftMax are in f32 data type. - logical_tensor matmul_qk_out{lt_id++, data_type::f32}; + logical_tensor matmul_qk_out{lt_id++, sdpa_intermedia_dtype}; op matmul_qk{ op_id++, op::kind::MatMul, @@ -189,7 +210,7 @@ partition create_sdpa_graph_partition( "matmul_qk"}; matmul_qk.set_attr(op::attr::transpose_b, true); - logical_tensor scaled_qk_out{lt_id++, data_type::f32}; + logical_tensor scaled_qk_out{lt_id++, sdpa_intermedia_dtype}; op scale_mul{ op_id++, op::kind::Multiply, @@ -214,7 +235,7 @@ partition create_sdpa_graph_partition( if (params.attn_mask.has_value()) { TORCH_INTERNAL_ASSERT( !is_causal, "Additive mask cannot use with is_causal."); - masked_qk_out = {lt_id++, data_type::f32}; + masked_qk_out = {lt_id++, sdpa_intermedia_dtype}; mask_add = { op_id++, op::kind::Add, @@ -249,7 +270,7 @@ partition create_sdpa_graph_partition( {mask_gt_out.value()}, "mask_gt"}; - masked_qk_out = {lt_id++, data_type::f32}; + masked_qk_out = {lt_id++, sdpa_intermedia_dtype}; mask_select = { op_id++, op::kind::Select, @@ -270,12 +291,15 @@ partition create_sdpa_graph_partition( logical_tensor softmax_out{lt_id++, dtype}; softmax.add_input(masked_qk_out.value_or(scaled_qk_out)); softmax.add_output(softmax_out); + if (compute_logsumexp) { + softmax.add_output(params.logsumexp.value()); + } op matmul_v{ op_id++, op::kind::MatMul, {softmax_out, params.value}, - {params.output}, + {params.attention}, "matmul_v"}; constexpr auto ekind = dnnl::engine::kind::gpu; @@ -304,6 +328,7 @@ partition create_sdpa_graph_partition( partition& find_or_create_graph_partition( bool is_causal, + bool compute_logsumexp, const SDPALogicalParams& params) { thread_local static PartitionCache cache; const data_type dtype = params.query.get_data_type(); @@ -327,17 +352,470 @@ partition& find_or_create_graph_partition( // attn_mask patternID.set(pos++, params.attn_mask.has_value()); patternID.set(pos++, is_causal); + // compute_logsumexp + patternID.set(pos++, compute_logsumexp); auto partition_ = cache.find_partition(patternID); if (!partition_.has_value()) { // partition cache no hit // graph building and partitioning - partition sdp_partition = - create_sdpa_graph_partition(is_causal, dtype, params); + partition sdp_partition = create_sdpa_graph_partition( + is_causal, compute_logsumexp, dtype, params); partition_ = cache.insert_partition_cache(patternID, sdp_partition); } return *partition_; } +} // namespace sdpa_forward + +namespace sdpa_backward { + +struct SDPABackwardLogicalParams { + enum class TensorID { + grad_out, + query, + 
key, + value, + out, + logsumexp, + scale, + neg_inf, + attn_mask, + grad_query, + grad_key, + grad_value, + end, + }; + + logical_tensor grad_out{}; + logical_tensor query{}; + logical_tensor key{}; + logical_tensor value{}; + logical_tensor out{}; + logical_tensor logsumexp{}; + logical_tensor scale{}; + std::optional neg_inf; + std::optional attn_mask; + logical_tensor grad_query{}; + logical_tensor grad_key{}; + logical_tensor grad_value{}; + + SDPABackwardLogicalParams( + const at::Tensor& grad_out_, + const at::Tensor& query_, + const at::Tensor& key_, + const at::Tensor& value_, + const at::Tensor& out_, + const at::Tensor& logsumexp_, + const std::optional& attn_mask_, + const at::Tensor& grad_query_, + const at::Tensor& grad_key_, + const at::Tensor& grad_value_, + int batch_size, + int num_head_q, + int num_head_kv, + int seq_len_q, + int seq_len_kv, + int head_dim_qk, + int head_dim_v, + bool is_causal) { + const data_type dtype = to_logical_tensor_data_type(query_.scalar_type()); + TORCH_INTERNAL_ASSERT( + (dtype != data_type::undef), + "Only FP16/BF16/FP32 datatypes are currently supported"); + TORCH_INTERNAL_ASSERT( + grad_out_.scalar_type() == query_.scalar_type() && + grad_out_.scalar_type() == key_.scalar_type() && + grad_out_.scalar_type() == value_.scalar_type() && + grad_out_.scalar_type() == out_.scalar_type(), + "scaled_dot_product_attention_backward_xpu: Expected grad_out, q, k, v and out to have the same data type, but got ", + " grad_out: ", + grad_out_.scalar_type(), + ", q: ", + query_.scalar_type(), + ", k: ", + key_.scalar_type(), + ", v: ", + value_.scalar_type(), + ", out: ", + out_.scalar_type()); + TORCH_INTERNAL_ASSERT( + logsumexp_.defined() && logsumexp_.scalar_type() == at::kFloat, + "scaled_dot_product_attention_backward_xpu: Expected logsumexp to be defined and have FP32 data type"); + TORCH_INTERNAL_ASSERT( + !at::native::onednn::is_broadcast(query_) && + !at::native::onednn::is_broadcast(key_) && + !at::native::onednn::is_broadcast(value_) && + !at::native::onednn::is_broadcast(out_) && + !at::native::onednn::is_broadcast(logsumexp_), + "scaled_dot_product_attention_backward_xpu: tensors grad_out, q, k, v, out and logsumexp should not be broadcasted tensor."); + + const dims scalar_shape = {1}; + + at::Tensor reshaped_grad_out = grad_out_; + at::Tensor reshaped_query = query_; + at::Tensor reshaped_key = key_; + at::Tensor reshaped_value = value_; + at::Tensor reshaped_out = out_; + at::Tensor reshaped_logsumexp = logsumexp_; + at::Tensor reshaped_attn_mask = attn_mask_.value_or(at::Tensor()); + if (at::native::onednn::is_broadcast(reshaped_grad_out)) { + at::native::onednn::undo_broadcast(reshaped_grad_out); + } + if (attn_mask_.has_value() && + at::native::onednn::is_broadcast(reshaped_attn_mask)) { + at::native::onednn::undo_broadcast(reshaped_attn_mask); + } + + // TODO: Support GQA in backward pass once OneDNN supports it. 
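+    // Shape reference for the logical tensors built below (illustrative
+    // summary only; B = batch, Hq/Hkv = query and key/value head counts,
+    // Tq/Tkv = sequence lengths, D = head dims). These mirror the shape
+    // checks done by the caller and are not enforced here:
+    //   query:          [B, Hq,  Tq,  Dqk]
+    //   key:            [B, Hkv, Tkv, Dqk]
+    //   value:          [B, Hkv, Tkv, Dv]
+    //   grad_out / out: [B, Hq,  Tq,  Dv]
+    //   logsumexp:      [B, Hq,  Tq,  1]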
+ + grad_out = { + static_cast(TensorID::grad_out), + dtype, + reshaped_grad_out.sizes().vec(), + reshaped_grad_out.strides().vec()}; + query = { + static_cast(TensorID::query), + dtype, + reshaped_query.sizes().vec(), + reshaped_query.strides().vec()}; + key = { + static_cast(TensorID::key), + dtype, + reshaped_key.sizes().vec(), + reshaped_key.strides().vec()}; + value = { + static_cast(TensorID::value), + dtype, + reshaped_value.sizes().vec(), + reshaped_value.strides().vec()}; + out = { + static_cast(TensorID::out), + dtype, + reshaped_out.sizes().vec(), + reshaped_out.strides().vec()}; + logsumexp = { + static_cast(TensorID::logsumexp), + sdpa_intermedia_dtype, + reshaped_logsumexp.sizes().vec(), + reshaped_logsumexp.strides().vec()}; + scale = { + static_cast(TensorID::scale), + to_logical_tensor_data_type(at::toOpMathType(query_.scalar_type())), + scalar_shape, + logical_tensor::layout_type::strided, + logical_tensor::property_type::constant}; + if (is_causal) { + neg_inf = { + static_cast(TensorID::neg_inf), + to_logical_tensor_data_type(at::toOpMathType(query_.scalar_type())), + scalar_shape, + logical_tensor::layout_type::strided, + logical_tensor::property_type::constant}; + } + if (attn_mask_.has_value()) { + const data_type mask_dtype = + to_logical_tensor_data_type(attn_mask_->scalar_type()); + TORCH_INTERNAL_ASSERT( + (mask_dtype != data_type::undef), + "Only FP16/BF16/FP32 datatypes are currently supported for attn_mask"); + attn_mask = { + static_cast(TensorID::attn_mask), + mask_dtype, + reshaped_attn_mask.sizes().vec(), + reshaped_attn_mask.strides().vec()}; + } + grad_query = { + static_cast(TensorID::grad_query), + dtype, + reshaped_query.sizes().vec(), + reshaped_query.strides().vec()}; + grad_key = { + static_cast(TensorID::grad_key), + dtype, + reshaped_key.sizes().vec(), + reshaped_key.strides().vec()}; + grad_value = { + static_cast(TensorID::grad_value), + dtype, + reshaped_value.sizes().vec(), + reshaped_value.strides().vec()}; + } + std::vector get_input() const { + std::vector input = { + grad_out, query, key, value, out, logsumexp, scale}; + if (neg_inf.has_value()) { + input.push_back(neg_inf.value()); + } + if (attn_mask.has_value()) { + input.push_back(attn_mask.value()); + } + return input; + } + std::vector get_output() const { + std::vector output = {grad_query, grad_key, grad_value}; + return output; + } +}; + +partition create_sdpa_backward_graph_partition( + bool is_causal, + data_type dtype, + const SDPABackwardLogicalParams& params) { + // graph building and partitioning + size_t lt_id = static_cast(SDPABackwardLogicalParams::TensorID::end); + size_t op_id = 0; + + // OneDNN graph has optimized implementation for `f16` or `bf16` SDPA with + // `f32` intermediate data type on Intel Graphics Products with Intel(R) Xe + // Matrix Extensions (Intel(R) XMX) support, which means the + // Q/K/V tensors have bf16 or f16 data type while the output of the first + // MatMul, Scale, Mask, and the input of SoftMax are in f32 data type. 
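+  // Notational sketch of the backward computation the graph below encodes
+  // (reference only, not code that runs here; dO = grad_out, P = attention
+  // probabilities recovered from the saved logsumexp):
+  //   S  = scale * (Q @ K^T) + mask
+  //   P  = exp(S - logsumexp)
+  //   dV = P^T @ dO
+  //   dP = dO @ V^T
+  //   dS = SoftMaxBackward(dP, P) * scale
+  //   dQ = dS @ K
+  //   dK = dS^T @ Q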
+ logical_tensor matmul_qk_out{lt_id++, sdpa_intermedia_dtype}; + op matmul_qk{ + op_id++, + op::kind::MatMul, + {params.query, params.key}, + {matmul_qk_out}, + "matmul_qk"}; + matmul_qk.set_attr(op::attr::transpose_b, true); + + logical_tensor scaled_qk_out{lt_id++, sdpa_intermedia_dtype}; + op scale_mul{ + op_id++, + op::kind::Multiply, + {matmul_qk_out, params.scale}, + {scaled_qk_out}, + "scale_mul"}; + + std::optional masked_qk_out; + + // For optional additive mask + std::optional mask_add; + + // For optional implicite causal mask + std::optional mask_gen_idx_row; + std::optional mask_row_idx; + std::optional mask_gen_idx_col; + std::optional mask_col_idx; + std::optional mask_gt; + std::optional mask_gt_out; + std::optional mask_select; + + if (params.attn_mask.has_value()) { + TORCH_INTERNAL_ASSERT( + !is_causal, "Additive mask cannot use with is_causal."); + masked_qk_out = {lt_id++, sdpa_intermedia_dtype}; + mask_add = { + op_id++, + op::kind::Add, + {scaled_qk_out, params.attn_mask.value()}, + {masked_qk_out.value()}, + "mask_add"}; + } else if (is_causal) { + mask_row_idx = {lt_id++, data_type::s32}; + mask_gen_idx_row = { + op_id++, + op::kind::GenIndex, + {scaled_qk_out}, + {mask_row_idx.value()}, + "mask_gen_idx_row"}; + mask_gen_idx_row->set_attr(op::attr::axis, -2); + + mask_col_idx = {lt_id++, data_type::s32}; + mask_gen_idx_col = { + op_id++, + op::kind::GenIndex, + {scaled_qk_out}, + {mask_col_idx.value()}, + "mask_gen_idx_col"}; + mask_gen_idx_col->set_attr(op::attr::axis, -1); + + mask_gt_out = {lt_id++, data_type::boolean}; + mask_gt = { + op_id++, + op::kind::GreaterEqual, + {mask_row_idx.value(), mask_col_idx.value()}, + {mask_gt_out.value()}, + "mask_gt"}; + + masked_qk_out = {lt_id++, sdpa_intermedia_dtype}; + mask_select = { + op_id++, + op::kind::Select, + {mask_gt_out.value(), scaled_qk_out, params.neg_inf.value()}, + {masked_qk_out.value()}, + "mask_select"}; + } + + // attention_probs = softmax(masked_score) = exp(masked_score - logsumexp) + logical_tensor sub_out{lt_id++, sdpa_intermedia_dtype}; + op subtract{ + op_id++, + op::kind::Subtract, + {masked_qk_out.value_or(scaled_qk_out), params.logsumexp}, + {sub_out}, + "subtract"}; + logical_tensor prob{lt_id++, sdpa_intermedia_dtype}; + op exp{op_id++, op::kind::Exp, {sub_out}, {prob}, "exp"}; + + // The following matmul doesn't support different input dtypes, insert a + // typecast + logical_tensor prob_casted = prob; + op typecast = op(op_id++, op::kind::TypeCast, "typecast"); + if (dtype != sdpa_intermedia_dtype) { + prob_casted = logical_tensor(lt_id++, dtype); + typecast.add_inputs({prob}); + typecast.add_outputs({prob_casted}); + } + + // grad_value = prob^T * grad_out + // TODO: handle GQA headnum because (batch_size, num_head_kv, seq_len_kv, + // head_dim_v) != (batch_size, num_head_q, seqlen_kv, seq_len_q) * + // (batch_size, num_head_q, seqlen_q, head_dim_v) + op matmul_grad_value{ + op_id++, + op::kind::MatMul, + {prob_casted, params.grad_out}, + {params.grad_value}, + "matmul_grad_value"}; + matmul_grad_value.set_attr(op::attr::transpose_a, true); + + // grad_prop = grad_out * value^T + // TODO: handle GQA headnum because (batch_size, num_head_q, seq_len_q, + // seq_len_kv) != (batch_size, num_head_q, seq_len_q, head_dim_v) * + // (batch_size, num_head_kv, head_dim_v, seq_len_kv) + logical_tensor grad_prop{lt_id++, sdpa_intermedia_dtype}; + op matmul_grad_prop{ + op_id++, + op::kind::MatMul, + {params.grad_out, params.value}, + {grad_prop}, + "matmul_grad_prop"}; + 
matmul_grad_prop.set_attr(op::attr::transpose_b, true); + + // grad_masked_score = softmaxbackward(grad_prop) + logical_tensor grad_masked_score{lt_id++, sdpa_intermedia_dtype}; + op softmax_backward{ + op_id++, + op::kind::SoftMaxBackward, + {grad_prop, prob}, + {grad_masked_score}, + "softmax_backward"}; + softmax_backward.set_attr(op::attr::axis, -1); + + // TODO: add output tensor grad_attn_mask = grad_masked_score once OneDNN + // supports output grad_attn_mask. + + // grad_scaled_score = grad_masked_score * scale + logical_tensor grad_scaled_score{lt_id++, sdpa_intermedia_dtype}; + op grad_scale_mul{ + op_id++, + op::kind::Multiply, + {grad_masked_score, params.scale}, + {grad_scaled_score}, + "grad_scale_mul"}; + + // The following matmul doesn't support different input dtypes, insert a + // typecast + logical_tensor grad_scaled_score_cast = grad_scaled_score; + op typecast2 = op(op_id++, op::kind::TypeCast, "typecast2"); + if (dtype != sdpa_intermedia_dtype) { + grad_scaled_score_cast = logical_tensor(lt_id++, dtype); + typecast2.add_inputs({grad_scaled_score}); + typecast2.add_outputs({grad_scaled_score_cast}); + } + + // grad_query = grad_scaled_score_cast * key + // TODO: handle GQA headnum because (batch_size, num_head_q, seq_len_q, + // head_dim_qk) != (batch_size, num_head_q, seq_len_q, seq_len_kv) * + // (batch_size, num_head_kv, seq_len_kv, head_dim_qk) + op matmul_grad_query{ + op_id++, + op::kind::MatMul, + {grad_scaled_score_cast, params.key}, + {params.grad_query}, + "matmul_grad_query"}; + + // grad_key = grad_scaled_score_cast^T * query + op matmul_grad_key{ + op_id++, + op::kind::MatMul, + {grad_scaled_score_cast, params.query}, + {params.grad_key}, + "matmul_grad_key"}; + matmul_grad_key.set_attr(op::attr::transpose_a, true); + + constexpr auto ekind = dnnl::engine::kind::gpu; + dnnl::graph::graph g(ekind); + g.add_op(matmul_qk); + g.add_op(scale_mul); + if (mask_add.has_value()) { + g.add_op(mask_add.value()); + } + // if (is_causal) { + // g.add_op(mask_gen_idx_row.value()); + // g.add_op(mask_gen_idx_col.value()); + // g.add_op(mask_gt.value()); + // g.add_op(mask_select.value()); + // } + g.add_op(subtract); + g.add_op(exp); + g.add_op(matmul_grad_value); + g.add_op(matmul_grad_prop); + g.add_op(softmax_backward); + g.add_op(grad_scale_mul); + g.add_op(matmul_grad_query); + g.add_op(matmul_grad_key); + if (dtype != sdpa_intermedia_dtype) { + g.add_op(typecast); + g.add_op(typecast2); + } + g.finalize(); + auto partitions = g.get_partitions(); + TORCH_INTERNAL_ASSERT( + (partitions.size() == 1) && partitions[0].is_supported(), + "oneDNN doesn't support this fusion pattern. If you'd like its support, please submit a issue."); + return partitions[0]; +} + +partition& find_or_create_backward_graph_partition( + bool is_causal, + const SDPABackwardLogicalParams& params) { + thread_local static PartitionCache cache; + const data_type dtype = params.query.get_data_type(); + + // cache key creation + // patternID is determined on the basis of the arguments provided + std::bitset<32> patternID; + if (dtype == data_type::f32) { + // bit 3 corresponds to float32 dtype + patternID.set(3, 1); + } + if (dtype == data_type::bf16) { + // bit 2 corresponds to fp16/bf16 dtype + patternID.set(2, 1); + } + // sdpa backward pattern + patternID.set(5, 1); + + // Refer to comments in Utils.h. 
The first 8 bits are reserved + int pos = 8; + // attn_mask + patternID.set(pos++, params.attn_mask.has_value()); + patternID.set(pos++, is_causal); + + auto partition_ = cache.find_partition(patternID); + if (!partition_.has_value()) { + // partition cache no hit + // graph building and partitioning + partition sdpa_backward_partition = + create_sdpa_backward_graph_partition(is_causal, dtype, params); + partition_ = + cache.insert_partition_cache(patternID, sdpa_backward_partition); + } + return *partition_; +} +} // namespace sdpa_backward } // namespace namespace at::native::onednn { @@ -355,7 +833,9 @@ void gpu_float_sdpa( std::optional attn_mask, bool is_causal, float softmax_scale, - const Tensor& output) { + const Tensor& attention, + bool compute_logsumexp, + const Tensor& logsumexp) { auto& eng = GpuEngineManager::Instance().get_engine(); auto& strm = GpuStreamManager::Instance().get_stream(); @@ -382,12 +862,13 @@ void gpu_float_sdpa( std::optional compiled_partition; auto get_compiled_partition = [&]() { - const SDPALogicalParams logical_params( + const sdpa_forward::SDPALogicalParams logical_params( query, key, value, attn_mask, - output, + attention, + logsumexp, batch_size, seq_len_q, seq_len_kv, @@ -395,9 +876,10 @@ void gpu_float_sdpa( num_head_kv, head_dim_qk, head_dim_v, - is_causal); - auto& partition_ = - find_or_create_graph_partition(is_causal, logical_params); + is_causal, + compute_logsumexp); + auto& partition_ = sdpa_forward::find_or_create_graph_partition( + is_causal, compute_logsumexp, logical_params); auto i = logical_params.get_input(); auto o = logical_params.get_output(); auto compiled_partition = partition_.compile(i, o, eng); @@ -421,8 +903,12 @@ void gpu_float_sdpa( } std::vector outputs = { - {l_outputs[0], eng, output.data_ptr()}, + {l_outputs[0], eng, attention.data_ptr()}, }; + if (compute_logsumexp) { + outputs.emplace_back(l_outputs[1], eng, logsumexp.data_ptr()); + } + size_t i = 0; std::vector inputs; inputs.reserve(l_inputs.size()); @@ -438,4 +924,117 @@ void gpu_float_sdpa( inputs.emplace_back(l_inputs[i++], eng, value.data_ptr()); compiled_partition->execute(strm, inputs, outputs); } + +void gpu_float_sdpa_backward( + int batch_size, + int num_head_q, + int num_head_kv, + int seq_len_q, + int seq_len_kv, + int head_dim_qk, + int head_dim_v, + const Tensor& grad_out, + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& out, + const Tensor& logsumexp, + std::optional attn_mask, + bool is_causal, + double scale, + Tensor& grad_query, + Tensor& grad_key, + Tensor& grad_value) { + auto& eng = GpuEngineManager::Instance().get_engine(); + auto& strm = GpuStreamManager::Instance().get_stream(); + + const auto get_tril_mask = [&]() { + auto opts = query.options(); + auto bool_tril = + at::ones_symint({seq_len_q, seq_len_kv}, opts.dtype(at::kBool)).tril(); + return at::where( + bool_tril, + 0.f, + at::scalar_tensor(-std::numeric_limits::infinity(), opts)); + }; + + // OneDNN doesn't support fp32 ukernel for implicit causal mask, + // and the reference implementation is worse than aten math + explict causal + // mask. Fall back to explict causal mask until OneDNN v3.9 which has fp32 + // ukernel for implicit causal mask. + // TODO: support causal once OneDNN support causal in backward pass. 
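+  // For reference, get_tril_mask() above materializes the causal constraint
+  // as an additive float mask, e.g. for seq_len_q = seq_len_kv = 3
+  // (illustrative values only):
+  //   [[   0, -inf, -inf],
+  //    [   0,    0, -inf],
+  //    [   0,    0,    0]]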
+ if (true) { // || (is_causal && query.dtype() == at::kFloat)) { + attn_mask = get_tril_mask(); + is_causal = false; + } + + std::vector l_inputs, l_outputs; + std::optional compiled_partition; + + auto get_compiled_partition = [&]() { + const sdpa_backward::SDPABackwardLogicalParams logical_params( + grad_out, + query, + key, + value, + out, + logsumexp, + attn_mask, + grad_query, + grad_key, + grad_value, + batch_size, + num_head_q, + num_head_kv, + seq_len_q, + seq_len_kv, + head_dim_qk, + head_dim_v, + is_causal); + auto& partition_ = sdpa_backward::find_or_create_backward_graph_partition( + is_causal, logical_params); + auto i = logical_params.get_input(); + auto o = logical_params.get_output(); + auto compiled_partition = partition_.compile(i, o, eng); + l_inputs = std::move(i); + l_outputs = std::move(o); + return compiled_partition; + }; + + compiled_partition = get_compiled_partition(); + + Tensor softmax_scale = at::full( + {}, scale, query.options().dtype(at::toOpMathType(query.scalar_type()))); + std::optional neg_inf; + if (is_causal) { + neg_inf = at::full( + {}, + -INFINITY, + query.options().dtype(at::toOpMathType(query.scalar_type()))); + } + + std::vector outputs = { + {l_outputs[0], eng, grad_query.data_ptr()}, + {l_outputs[1], eng, grad_key.data_ptr()}, + {l_outputs[2], eng, grad_value.data_ptr()}, + }; + + size_t i = 0; + std::vector inputs; + inputs.reserve(l_inputs.size()); + inputs.emplace_back(l_inputs[i++], eng, grad_out.data_ptr()); + inputs.emplace_back(l_inputs[i++], eng, query.data_ptr()); + inputs.emplace_back(l_inputs[i++], eng, key.data_ptr()); + inputs.emplace_back(l_inputs[i++], eng, value.data_ptr()); + inputs.emplace_back(l_inputs[i++], eng, out.data_ptr()); + inputs.emplace_back(l_inputs[i++], eng, logsumexp.data_ptr()); + inputs.emplace_back(l_inputs[i++], eng, softmax_scale.data_ptr()); + if (neg_inf.has_value()) { + inputs.emplace_back(l_inputs[i++], eng, neg_inf->data_ptr()); + } + if (attn_mask.has_value()) { + inputs.emplace_back(l_inputs[i++], eng, attn_mask->data_ptr()); + } + compiled_partition->execute(strm, inputs, outputs); +} } // namespace at::native::onednn diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h index e73cb73e8b1e..a07d2fd3c254 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h +++ b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h @@ -178,5 +178,28 @@ void gpu_float_sdpa( std::optional attn_mask, bool is_causal, float softmax_scale, - const Tensor& output); + const Tensor& attention, + bool compute_logsumexp, + const Tensor& logsumexp); + +void gpu_float_sdpa_backward( + int batch_size, + int num_head_q, + int num_head_kv, + int seq_len_q, + int seq_len_kv, + int head_dim_qk, + int head_dim_v, + const Tensor& grad_out, + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& out, + const Tensor& logsumexp, + std::optional attn_mask, + bool is_causal, + double scale, + Tensor& grad_query, + Tensor& grad_key, + Tensor& grad_value); } // namespace at::native::onednn diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b1bb48647743..0f1f3f925646 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -14957,6 +14957,7 @@ variants: function dispatch: CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable_backward + XPU: _scaled_dot_product_fused_attention_overrideable_backward_xpu - func: 
_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset) dispatch: diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index 00fd0130d834..09b8e1ab887e 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -46,8 +46,8 @@ IF(NOT MKLDNN_FOUND) endif() endif() ExternalProject_Add(xpu_mkldnn_proj - GIT_REPOSITORY https://github.com/oneapi-src/oneDNN - GIT_TAG v3.8.1 + GIT_REPOSITORY https://github.com/uxlfoundation/oneDNN.git + GIT_TAG yixin/sdpa-training-impl PREFIX ${XPU_MKLDNN_DIR_PREFIX} BUILD_IN_SOURCE 0 CMAKE_ARGS -DCMAKE_C_COMPILER=icx @@ -56,6 +56,7 @@ IF(NOT MKLDNN_FOUND) -DDNNL_CPU_RUNTIME=THREADPOOL -DDNNL_BUILD_TESTS=OFF -DDNNL_BUILD_EXAMPLES=OFF + -DONEDNN_ENABLE_GRAPH_DUMP=ON -DONEDNN_BUILD_GRAPH=ON -DDNNL_LIBRARY_TYPE=STATIC -DDNNL_DPCPP_HOST_COMPILER=${DNNL_HOST_COMPILER} # Use global cxx compiler as host compiler diff --git a/test/test_transformers.py b/test/test_transformers.py index 89db8d798c26..ea6bbc60a039 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -4323,7 +4323,7 @@ def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: s self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2) - @parametrize("fused_kernel", [SDPBackend.MATH, SDPBackend.OVERRIDEABLE]) + @parametrize("fused_kernel", [SDPBackend.OVERRIDEABLE]) @parametrize("dtype", [torch.half, torch.bfloat16, torch.float32]) @parametrize("batch_size,n_head,q_size,kv_size,head_dim", [ (2, 5, 9216, 9216, 64), @@ -4341,8 +4341,8 @@ def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: s (1, 32, 2016, 2016, 128), (4, 32, 2016, 2016, 128), ]) - @parametrize("mask_type", ["float", "causal"]) - @parametrize("train", [False]) + @parametrize("mask_type", ["float",]) #"causal"]) + @parametrize("train", [False, True]) def test_scaled_dot_product_fused_attention_mask_vs_math( self, device, @@ -4362,7 +4362,7 @@ def test_scaled_dot_product_fused_attention_mask_vs_math( tol = Tolerances(5e-2, 5e-2) if dtype is torch.float16: tol = Tolerances(1e-2, 1e-2) - mask_shape = [batch_size, 1, 1, kv_size] + mask_shape = [batch_size, 1, q_size, kv_size] make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False) q_shape = SdpaShape(batch_size, n_head, q_size, head_dim) kv_shape = SdpaShape(batch_size, n_head, kv_size, head_dim) @@ -4398,18 +4398,33 @@ def test_scaled_dot_product_fused_attention_mask_vs_math( v2 = v2.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2) attn_mask2 = attn_mask.float() if attn_mask is not None else None - if fused_kernel == SDPBackend.MATH: - actual = torch.ops.aten._scaled_dot_product_attention_math( - q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)[0] - elif fused_kernel == SDPBackend.OVERRIDEABLE: - actual = torch.ops.aten._scaled_dot_product_fused_attention_overrideable( - q, k, v, attn_bias=attn_mask, dropout_p=0.0, is_causal=is_causal)[0] + if fused_kernel == SDPBackend.OVERRIDEABLE: + with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]): + actual = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal) + else: + raise NotImplementedError("Only OVERRIDEABLE backend is supported in this test") math_ref = 
torch.ops.aten._scaled_dot_product_attention_math( q2, k2, v2, attn_mask=attn_mask2, dropout_p=0.0, is_causal=is_causal)[0] - self.assertEqual(actual.float(), math_ref, atol=tol.atol, rtol=tol.rtol) + if dtype in [torch.float16, torch.bfloat16]: + math_ref = math_ref.to(dtype) + + self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol) + + if train: + loss = torch.mean(actual) + loss_ref = torch.mean(math_ref) + loss.backward() + loss_ref.backward() + + grad_q_actual, grad_k_actual, grad_v_actual = q.grad, k.grad, v.grad + grad_q_ref, grad_k_ref, grad_v_ref = q2.grad, k2.grad, v2.grad + self.assertEqual(grad_q_actual, grad_q_ref, atol=tol.atol, rtol=tol.rtol) + self.assertEqual(grad_k_actual, grad_k_ref, atol=tol.atol, rtol=tol.rtol) + self.assertEqual(grad_v_actual, grad_v_ref, atol=tol.atol, rtol=tol.rtol) class TestAttnBias(NNTestCase): diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 07f22eab3f01..f8727ded4217 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5758,6 +5758,32 @@ def meta__scaled_dot_product_fused_attention_overrideable( ) +@register_meta([aten._scaled_dot_product_fused_attention_overrideable_backward]) +def meta__scaled_dot_product_fused_attention_overrideable_backward( + grad_out: Tensor, + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Tensor, + grad_input_mask: list[bool], + out: Tensor, + logsumexp: Tensor, + cum_seq_q: Tensor, + cum_seq_k: Tensor, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + philox_seed: Tensor, + philox_offset: Tensor, + scale: Optional[float] = None, +): + grad_q = torch.empty_like(query) + grad_k = torch.empty_like(key) + grad_v = torch.empty_like(value) + return grad_q, grad_k, grad_v, None + + @register_meta( [ aten._scaled_dot_product_flash_attention_backward, From 9c74464180fcfdc97d974b75ca9246bb8557c352 Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Fri, 20 Jun 2025 03:05:28 -0700 Subject: [PATCH 02/11] fix metashape and causal --- aten/src/ATen/native/mkldnn/xpu/Attention.cpp | 28 ++-- .../native/mkldnn/xpu/detail/Attention.cpp | 140 +++++++++--------- test/test_transformers.py | 4 +- torch/_meta_registrations.py | 6 +- 4 files changed, 93 insertions(+), 85 deletions(-) diff --git a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp index 02ce3cf739cd..4de49495ee37 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp @@ -64,7 +64,12 @@ bool check_grad(sdp::sdp_params const& params, bool debug) { "scale_dot_product_attention on xpu is not supported with attn_mask.requires_grad() == True."); } - return !is_gqa && !attn_mask_needs_grad; + bool is_causal = params.is_causal; + if (debug && is_causal) { + TORCH_WARN( + "scale_dot_product_attention on xpu is not supported with is_causal == True for training."); + } + return !is_gqa && !attn_mask_needs_grad && !is_causal; } bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) { @@ -233,7 +238,7 @@ _scaled_dot_product_fused_attention_overrideable_xpu( at::Tensor logsumexp; if (compute_logsumexp) { logsumexp = at::empty( - {batch_size, num_head_q, seq_len_q, 1}, opts.dtype(at::kFloat)); + {batch_size, num_head_q, seq_len_q}, opts.dtype(at::kFloat)); } at::native::onednn::gpu_float_sdpa( @@ -292,10 +297,6 @@ _scaled_dot_product_fused_attention_overrideable_backward_xpu( const at::Tensor& philox_seed, const at::Tensor& philox_offset, std::optional scale) { - if 
(!grad_out.defined()) { - return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}); - } - TORCH_INTERNAL_ASSERT( grad_out.dim() == 4 && out.dim() == 4 && grad_out.size(0) == out.size(0) && @@ -327,20 +328,19 @@ _scaled_dot_product_fused_attention_overrideable_backward_xpu( TORCH_INTERNAL_ASSERT( dropout_p == 0.0, "scaled_dot_product_fused_attention_overrideable_backward_xpu: Currently do not support dropout > 0"); - TORCH_INTERNAL_ASSERT(logsumexp.dim() == 4 && + TORCH_INTERNAL_ASSERT(logsumexp.dim() == 3 && logsumexp.size(0) == query.size(0) && logsumexp.size(1) == query.size(1) && logsumexp.size(2) == query.size(2) && - logsumexp.size(3) == 1, - "scaled_dot_product_fused_attention_overrideable_backward_xpu: logsumexp should have the shape of {(B), H, T, 1}"); + "scaled_dot_product_fused_attention_overrideable_backward_xpu: logsumexp should have the shape of {(B), H, T}"); std::optional attn_bias_opt; if (attn_bias.defined()) { attn_bias_opt = attn_bias; } TORCH_INTERNAL_ASSERT( - !(attn_bias_opt.has_value() && is_causal), - "scaled_dot_product_fused_attention_overrideable_backward_xpu: attn_bias cannot present with is_causal"); + !is_causal, + "scaled_dot_product_fused_attention_overrideable_backward_xpu: Curently do not support is_causal = True"); const int64_t batch_size = query.size(0); const int64_t num_head_q = query.size(1); @@ -353,7 +353,9 @@ _scaled_dot_product_fused_attention_overrideable_backward_xpu( auto grad_q = at::empty_like(query); auto grad_k = at::empty_like(key); auto grad_v = at::empty_like(value); - + auto grad_attn_bias = attn_bias_opt.has_value() + ? at::empty_like(attn_bias_opt.value()) + : at::Tensor(); at::native::onednn::gpu_float_sdpa_backward( batch_size, num_head_q, @@ -378,7 +380,7 @@ _scaled_dot_product_fused_attention_overrideable_backward_xpu( std::move(grad_q), std::move(grad_k), std::move(grad_v), - at::Tensor()); + std::move(grad_attn_bias)); } REGISTER_XPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_xpu); diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp index 3f33d6b85c80..16e5e9c1944d 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp @@ -81,6 +81,7 @@ struct SDPALogicalParams { at::Tensor reshaped_key = key_; at::Tensor reshaped_value = value_; at::Tensor reshaped_attention = attention_; + at::Tensor reshaped_logsumexp = logsumexp_.unsqueeze(-1); at::Tensor reshaped_attn_mask = attn_mask_.value_or(at::Tensor()); if (attn_mask_.has_value() && at::native::onednn::is_broadcast(reshaped_attn_mask)) { @@ -161,8 +162,8 @@ struct SDPALogicalParams { logsumexp = { static_cast(TensorID::logsumexp), sdpa_intermedia_dtype, - logsumexp_.sizes().vec(), - logsumexp_.strides().vec()}; + reshaped_logsumexp.sizes().vec(), + reshaped_logsumexp.strides().vec()}; } } std::vector get_input() const { @@ -456,7 +457,7 @@ struct SDPABackwardLogicalParams { at::Tensor reshaped_key = key_; at::Tensor reshaped_value = value_; at::Tensor reshaped_out = out_; - at::Tensor reshaped_logsumexp = logsumexp_; + at::Tensor reshaped_logsumexp = logsumexp_.unsqueeze(-1); at::Tensor reshaped_attn_mask = attn_mask_.value_or(at::Tensor()); if (at::native::onednn::is_broadcast(reshaped_grad_out)) { at::native::onednn::undo_broadcast(reshaped_grad_out); @@ -504,14 +505,14 @@ struct SDPABackwardLogicalParams { scalar_shape, logical_tensor::layout_type::strided, logical_tensor::property_type::constant}; - if (is_causal) { - neg_inf = 
{ - static_cast(TensorID::neg_inf), - to_logical_tensor_data_type(at::toOpMathType(query_.scalar_type())), - scalar_shape, - logical_tensor::layout_type::strided, - logical_tensor::property_type::constant}; - } + // if (is_causal) { + // neg_inf = { + // static_cast(TensorID::neg_inf), + // to_logical_tensor_data_type(at::toOpMathType(query_.scalar_type())), + // scalar_shape, + // logical_tensor::layout_type::strided, + // logical_tensor::property_type::constant}; + // } if (attn_mask_.has_value()) { const data_type mask_dtype = to_logical_tensor_data_type(attn_mask_->scalar_type()); @@ -543,9 +544,9 @@ struct SDPABackwardLogicalParams { std::vector get_input() const { std::vector input = { grad_out, query, key, value, out, logsumexp, scale}; - if (neg_inf.has_value()) { - input.push_back(neg_inf.value()); - } + // if (neg_inf.has_value()) { + // input.push_back(neg_inf.value()); + // } if (attn_mask.has_value()) { input.push_back(attn_mask.value()); } @@ -593,13 +594,13 @@ partition create_sdpa_backward_graph_partition( std::optional mask_add; // For optional implicite causal mask - std::optional mask_gen_idx_row; - std::optional mask_row_idx; - std::optional mask_gen_idx_col; - std::optional mask_col_idx; - std::optional mask_gt; - std::optional mask_gt_out; - std::optional mask_select; + // std::optional mask_gen_idx_row; + // std::optional mask_row_idx; + // std::optional mask_gen_idx_col; + // std::optional mask_col_idx; + // std::optional mask_gt; + // std::optional mask_gt_out; + // std::optional mask_select; if (params.attn_mask.has_value()) { TORCH_INTERNAL_ASSERT( @@ -611,41 +612,42 @@ partition create_sdpa_backward_graph_partition( {scaled_qk_out, params.attn_mask.value()}, {masked_qk_out.value()}, "mask_add"}; - } else if (is_causal) { - mask_row_idx = {lt_id++, data_type::s32}; - mask_gen_idx_row = { - op_id++, - op::kind::GenIndex, - {scaled_qk_out}, - {mask_row_idx.value()}, - "mask_gen_idx_row"}; - mask_gen_idx_row->set_attr(op::attr::axis, -2); - - mask_col_idx = {lt_id++, data_type::s32}; - mask_gen_idx_col = { - op_id++, - op::kind::GenIndex, - {scaled_qk_out}, - {mask_col_idx.value()}, - "mask_gen_idx_col"}; - mask_gen_idx_col->set_attr(op::attr::axis, -1); - - mask_gt_out = {lt_id++, data_type::boolean}; - mask_gt = { - op_id++, - op::kind::GreaterEqual, - {mask_row_idx.value(), mask_col_idx.value()}, - {mask_gt_out.value()}, - "mask_gt"}; - - masked_qk_out = {lt_id++, sdpa_intermedia_dtype}; - mask_select = { - op_id++, - op::kind::Select, - {mask_gt_out.value(), scaled_qk_out, params.neg_inf.value()}, - {masked_qk_out.value()}, - "mask_select"}; - } + } + // else if (is_causal) { + // mask_row_idx = {lt_id++, data_type::s32}; + // mask_gen_idx_row = { + // op_id++, + // op::kind::GenIndex, + // {scaled_qk_out}, + // {mask_row_idx.value()}, + // "mask_gen_idx_row"}; + // mask_gen_idx_row->set_attr(op::attr::axis, -2); + + // mask_col_idx = {lt_id++, data_type::s32}; + // mask_gen_idx_col = { + // op_id++, + // op::kind::GenIndex, + // {scaled_qk_out}, + // {mask_col_idx.value()}, + // "mask_gen_idx_col"}; + // mask_gen_idx_col->set_attr(op::attr::axis, -1); + + // mask_gt_out = {lt_id++, data_type::boolean}; + // mask_gt = { + // op_id++, + // op::kind::GreaterEqual, + // {mask_row_idx.value(), mask_col_idx.value()}, + // {mask_gt_out.value()}, + // "mask_gt"}; + + // masked_qk_out = {lt_id++, sdpa_intermedia_dtype}; + // mask_select = { + // op_id++, + // op::kind::Select, + // {mask_gt_out.value(), scaled_qk_out, params.neg_inf.value()}, + // 
{masked_qk_out.value()}, + // "mask_select"}; + // } // attention_probs = softmax(masked_score) = exp(masked_score - logsumexp) logical_tensor sub_out{lt_id++, sdpa_intermedia_dtype}; @@ -963,10 +965,10 @@ void gpu_float_sdpa_backward( // mask. Fall back to explict causal mask until OneDNN v3.9 which has fp32 // ukernel for implicit causal mask. // TODO: support causal once OneDNN support causal in backward pass. - if (true) { // || (is_causal && query.dtype() == at::kFloat)) { - attn_mask = get_tril_mask(); - is_causal = false; - } + // if (is_causal && query.dtype() == at::kFloat) { + // attn_mask = get_tril_mask(); + // is_causal = false; + // } std::vector l_inputs, l_outputs; std::optional compiled_partition; @@ -1006,12 +1008,12 @@ void gpu_float_sdpa_backward( Tensor softmax_scale = at::full( {}, scale, query.options().dtype(at::toOpMathType(query.scalar_type()))); std::optional neg_inf; - if (is_causal) { - neg_inf = at::full( - {}, - -INFINITY, - query.options().dtype(at::toOpMathType(query.scalar_type()))); - } + // if (is_causal) { + // neg_inf = at::full( + // {}, + // -INFINITY, + // query.options().dtype(at::toOpMathType(query.scalar_type()))); + // } std::vector outputs = { {l_outputs[0], eng, grad_query.data_ptr()}, @@ -1029,9 +1031,9 @@ void gpu_float_sdpa_backward( inputs.emplace_back(l_inputs[i++], eng, out.data_ptr()); inputs.emplace_back(l_inputs[i++], eng, logsumexp.data_ptr()); inputs.emplace_back(l_inputs[i++], eng, softmax_scale.data_ptr()); - if (neg_inf.has_value()) { - inputs.emplace_back(l_inputs[i++], eng, neg_inf->data_ptr()); - } + // if (neg_inf.has_value()) { + // inputs.emplace_back(l_inputs[i++], eng, neg_inf->data_ptr()); + // } if (attn_mask.has_value()) { inputs.emplace_back(l_inputs[i++], eng, attn_mask->data_ptr()); } diff --git a/test/test_transformers.py b/test/test_transformers.py index ea6bbc60a039..4c217c920b8f 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -4341,7 +4341,7 @@ def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: s (1, 32, 2016, 2016, 128), (4, 32, 2016, 2016, 128), ]) - @parametrize("mask_type", ["float",]) #"causal"]) + @parametrize("mask_type", ["float", "causal"]) @parametrize("train", [False, True]) def test_scaled_dot_product_fused_attention_mask_vs_math( self, @@ -4399,7 +4399,7 @@ def test_scaled_dot_product_fused_attention_mask_vs_math( attn_mask2 = attn_mask.float() if attn_mask is not None else None if fused_kernel == SDPBackend.OVERRIDEABLE: - with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]): + with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE, SDPBackend.MATH]): actual = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal) else: diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index f8727ded4217..dec68892cd8c 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5781,7 +5781,11 @@ def meta__scaled_dot_product_fused_attention_overrideable_backward( grad_q = torch.empty_like(query) grad_k = torch.empty_like(key) grad_v = torch.empty_like(value) - return grad_q, grad_k, grad_v, None + + grad_attn_bias = None + if attn_bias is not None: + grad_attn_bias = torch.empty_like(attn_bias) + return grad_q, grad_k, grad_v, grad_attn_bias @register_meta( From 16768029181fc89786b84cbcbc0cf09a35789798 Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Mon, 23 Jun 2025 23:27:00 -0700 Subject: [PATCH 03/11] rebase --- aten/src/ATen/native/mkldnn/xpu/Attention.cpp | 6 ++---- 1 file 
changed, 2 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp index 4de49495ee37..8699f2dbd6ba 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp @@ -235,11 +235,9 @@ _scaled_dot_product_fused_attention_overrideable_xpu( batch_size, num_head_q, seq_len_q, head_dim_v}; alloc_with_matching_layout(query, attention, attention_shape); - at::Tensor logsumexp; - if (compute_logsumexp) { - logsumexp = at::empty( + auto opts = query.options(); + at::Tensor logsumexp = at::empty( {batch_size, num_head_q, seq_len_q}, opts.dtype(at::kFloat)); - } at::native::onednn::gpu_float_sdpa( batch_size, From 2fac2fd280c74bfa8b2ae6599fe1bc99c614cf3d Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Wed, 25 Jun 2025 03:17:51 -0700 Subject: [PATCH 04/11] add back causal mask --- aten/src/ATen/native/mkldnn/xpu/Attention.cpp | 10 +- .../native/mkldnn/xpu/detail/Attention.cpp | 146 +++++++++--------- 2 files changed, 73 insertions(+), 83 deletions(-) diff --git a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp index 8699f2dbd6ba..9370e6b9584e 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp @@ -64,12 +64,7 @@ bool check_grad(sdp::sdp_params const& params, bool debug) { "scale_dot_product_attention on xpu is not supported with attn_mask.requires_grad() == True."); } - bool is_causal = params.is_causal; - if (debug && is_causal) { - TORCH_WARN( - "scale_dot_product_attention on xpu is not supported with is_causal == True for training."); - } - return !is_gqa && !attn_mask_needs_grad && !is_causal; + return !is_gqa && !attn_mask_needs_grad; } bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) { @@ -336,9 +331,6 @@ _scaled_dot_product_fused_attention_overrideable_backward_xpu( if (attn_bias.defined()) { attn_bias_opt = attn_bias; } - TORCH_INTERNAL_ASSERT( - !is_causal, - "scaled_dot_product_fused_attention_overrideable_backward_xpu: Curently do not support is_causal = True"); const int64_t batch_size = query.size(0); const int64_t num_head_q = query.size(1); diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp index 16e5e9c1944d..69d6946d1dbe 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp @@ -505,14 +505,14 @@ struct SDPABackwardLogicalParams { scalar_shape, logical_tensor::layout_type::strided, logical_tensor::property_type::constant}; - // if (is_causal) { - // neg_inf = { - // static_cast(TensorID::neg_inf), - // to_logical_tensor_data_type(at::toOpMathType(query_.scalar_type())), - // scalar_shape, - // logical_tensor::layout_type::strided, - // logical_tensor::property_type::constant}; - // } + if (is_causal) { + neg_inf = { + static_cast(TensorID::neg_inf), + to_logical_tensor_data_type(at::toOpMathType(query_.scalar_type())), + scalar_shape, + logical_tensor::layout_type::strided, + logical_tensor::property_type::constant}; + } if (attn_mask_.has_value()) { const data_type mask_dtype = to_logical_tensor_data_type(attn_mask_->scalar_type()); @@ -544,9 +544,9 @@ struct SDPABackwardLogicalParams { std::vector get_input() const { std::vector input = { grad_out, query, key, value, out, logsumexp, scale}; - // if (neg_inf.has_value()) { - // input.push_back(neg_inf.value()); - // } + if 
(neg_inf.has_value()) { + input.push_back(neg_inf.value()); + } if (attn_mask.has_value()) { input.push_back(attn_mask.value()); } @@ -594,13 +594,13 @@ partition create_sdpa_backward_graph_partition( std::optional mask_add; // For optional implicite causal mask - // std::optional mask_gen_idx_row; - // std::optional mask_row_idx; - // std::optional mask_gen_idx_col; - // std::optional mask_col_idx; - // std::optional mask_gt; - // std::optional mask_gt_out; - // std::optional mask_select; + std::optional mask_gen_idx_row; + std::optional mask_row_idx; + std::optional mask_gen_idx_col; + std::optional mask_col_idx; + std::optional mask_gt; + std::optional mask_gt_out; + std::optional mask_select; if (params.attn_mask.has_value()) { TORCH_INTERNAL_ASSERT( @@ -612,42 +612,41 @@ partition create_sdpa_backward_graph_partition( {scaled_qk_out, params.attn_mask.value()}, {masked_qk_out.value()}, "mask_add"}; - } - // else if (is_causal) { - // mask_row_idx = {lt_id++, data_type::s32}; - // mask_gen_idx_row = { - // op_id++, - // op::kind::GenIndex, - // {scaled_qk_out}, - // {mask_row_idx.value()}, - // "mask_gen_idx_row"}; - // mask_gen_idx_row->set_attr(op::attr::axis, -2); - - // mask_col_idx = {lt_id++, data_type::s32}; - // mask_gen_idx_col = { - // op_id++, - // op::kind::GenIndex, - // {scaled_qk_out}, - // {mask_col_idx.value()}, - // "mask_gen_idx_col"}; - // mask_gen_idx_col->set_attr(op::attr::axis, -1); - - // mask_gt_out = {lt_id++, data_type::boolean}; - // mask_gt = { - // op_id++, - // op::kind::GreaterEqual, - // {mask_row_idx.value(), mask_col_idx.value()}, - // {mask_gt_out.value()}, - // "mask_gt"}; - - // masked_qk_out = {lt_id++, sdpa_intermedia_dtype}; - // mask_select = { - // op_id++, - // op::kind::Select, - // {mask_gt_out.value(), scaled_qk_out, params.neg_inf.value()}, - // {masked_qk_out.value()}, - // "mask_select"}; - // } + } else if (is_causal) { + mask_row_idx = {lt_id++, data_type::s32}; + mask_gen_idx_row = { + op_id++, + op::kind::GenIndex, + {scaled_qk_out}, + {mask_row_idx.value()}, + "mask_gen_idx_row"}; + mask_gen_idx_row->set_attr(op::attr::axis, -2); + + mask_col_idx = {lt_id++, data_type::s32}; + mask_gen_idx_col = { + op_id++, + op::kind::GenIndex, + {scaled_qk_out}, + {mask_col_idx.value()}, + "mask_gen_idx_col"}; + mask_gen_idx_col->set_attr(op::attr::axis, -1); + + mask_gt_out = {lt_id++, data_type::boolean}; + mask_gt = { + op_id++, + op::kind::GreaterEqual, + {mask_row_idx.value(), mask_col_idx.value()}, + {mask_gt_out.value()}, + "mask_gt"}; + + masked_qk_out = {lt_id++, sdpa_intermedia_dtype}; + mask_select = { + op_id++, + op::kind::Select, + {mask_gt_out.value(), scaled_qk_out, params.neg_inf.value()}, + {masked_qk_out.value()}, + "mask_select"}; + } // attention_probs = softmax(masked_score) = exp(masked_score - logsumexp) logical_tensor sub_out{lt_id++, sdpa_intermedia_dtype}; @@ -754,12 +753,12 @@ partition create_sdpa_backward_graph_partition( if (mask_add.has_value()) { g.add_op(mask_add.value()); } - // if (is_causal) { - // g.add_op(mask_gen_idx_row.value()); - // g.add_op(mask_gen_idx_col.value()); - // g.add_op(mask_gt.value()); - // g.add_op(mask_select.value()); - // } + if (is_causal) { + g.add_op(mask_gen_idx_row.value()); + g.add_op(mask_gen_idx_col.value()); + g.add_op(mask_gt.value()); + g.add_op(mask_select.value()); + } g.add_op(subtract); g.add_op(exp); g.add_op(matmul_grad_value); @@ -964,11 +963,10 @@ void gpu_float_sdpa_backward( // and the reference implementation is worse than aten math + explict causal // mask. 
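As a reading aid for the GenIndex / GreaterEqual / Select chain restored in create_sdpa_backward_graph_partition above, the following is a minimal PyTorch sketch of the masking semantics those three graph ops encode. It is illustrative only: the function name is hypothetical and it does not use the oneDNN Graph API.

import torch

def causal_select_mask(scaled_qk: torch.Tensor, neg_inf: float = float("-inf")) -> torch.Tensor:
    # Emulates the GenIndex -> GreaterEqual -> Select chain:
    # keep a score where row index >= column index, otherwise substitute -inf.
    seq_len_q, seq_len_kv = scaled_qk.shape[-2], scaled_qk.shape[-1]
    row_idx = torch.arange(seq_len_q, device=scaled_qk.device).view(-1, 1)   # GenIndex, axis = -2
    col_idx = torch.arange(seq_len_kv, device=scaled_qk.device).view(1, -1)  # GenIndex, axis = -1
    keep = row_idx >= col_idx                                                # GreaterEqual
    return torch.where(keep, scaled_qk, torch.full_like(scaled_qk, neg_inf)) # Select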
Fall back to explict causal mask until OneDNN v3.9 which has fp32 // ukernel for implicit causal mask. - // TODO: support causal once OneDNN support causal in backward pass. - // if (is_causal && query.dtype() == at::kFloat) { - // attn_mask = get_tril_mask(); - // is_causal = false; - // } + if (is_causal && query.dtype() == at::kFloat) { + attn_mask = get_tril_mask(); + is_causal = false; + } std::vector l_inputs, l_outputs; std::optional compiled_partition; @@ -1008,12 +1006,12 @@ void gpu_float_sdpa_backward( Tensor softmax_scale = at::full( {}, scale, query.options().dtype(at::toOpMathType(query.scalar_type()))); std::optional neg_inf; - // if (is_causal) { - // neg_inf = at::full( - // {}, - // -INFINITY, - // query.options().dtype(at::toOpMathType(query.scalar_type()))); - // } + if (is_causal) { + neg_inf = at::full( + {}, + -INFINITY, + query.options().dtype(at::toOpMathType(query.scalar_type()))); + } std::vector outputs = { {l_outputs[0], eng, grad_query.data_ptr()}, @@ -1031,9 +1029,9 @@ void gpu_float_sdpa_backward( inputs.emplace_back(l_inputs[i++], eng, out.data_ptr()); inputs.emplace_back(l_inputs[i++], eng, logsumexp.data_ptr()); inputs.emplace_back(l_inputs[i++], eng, softmax_scale.data_ptr()); - // if (neg_inf.has_value()) { - // inputs.emplace_back(l_inputs[i++], eng, neg_inf->data_ptr()); - // } + if (neg_inf.has_value()) { + inputs.emplace_back(l_inputs[i++], eng, neg_inf->data_ptr()); + } if (attn_mask.has_value()) { inputs.emplace_back(l_inputs[i++], eng, attn_mask->data_ptr()); } From 2cfb19a7859ef2509aa8c4c7c818e44dffb7435d Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Wed, 25 Jun 2025 07:43:12 -0700 Subject: [PATCH 05/11] add compute_logsumexp to fix aot autograd --- aten/src/ATen/native/mkldnn/xpu/Attention.cpp | 2 +- aten/src/ATen/native/native_functions.yaml | 2 +- aten/src/ATen/native/transformers/attention.cpp | 6 +++++- tools/autograd/derivatives.yaml | 2 +- torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h | 2 +- torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h | 2 +- torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h | 2 +- 7 files changed, 11 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp index 9370e6b9584e..b79c27762c23 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp @@ -189,6 +189,7 @@ _scaled_dot_product_fused_attention_overrideable_xpu( const at::Tensor& key, const at::Tensor& value, const std::optional& attn_bias, + bool compute_logsumexp, double dropout_p, bool is_causal, bool return_debug_mask, @@ -223,7 +224,6 @@ _scaled_dot_product_fused_attention_overrideable_xpu( const int64_t head_dim_v = value.size(3); const int64_t seq_len_q = query.size(2); const int64_t seq_len_kv = key.size(2); - const bool compute_logsumexp = input_require_grad(query, key, value, attn_bias); at::Tensor attention; std::vector attention_shape = { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 0f1f3f925646..ae700b1dcd00 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -14933,7 +14933,7 @@ CPU: _scaled_dot_product_flash_attention_cpu tags: nondeterministic_seeded -- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? 
scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) +- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, bool compute_log_sumexp=False, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) dispatch: CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable XPU: _scaled_dot_product_fused_attention_overrideable_xpu diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 7aad4309924d..dc51385161a7 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -768,8 +768,11 @@ Tensor scaled_dot_product_attention( return std::get<0>(out_and_lse); } case SDPBackend::overrideable: { + bool compute_logsumexp = should_compute_logsumexp(query_, key, value); + compute_logsumexp = compute_logsumexp || + (at::GradMode::is_enabled() && attn_mask.has_value() && attn_mask.value().requires_grad()); auto out_lse_softmax = at::_scaled_dot_product_fused_attention_overrideable( - query_, key, value, attn_mask, dropout_p, is_causal, false /*return_debug_mask*/, scale); + query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, false /*return_debug_mask*/, scale); return std::get<0>(out_lse_softmax); } case SDPBackend::math: { @@ -1012,6 +1015,7 @@ _scaled_dot_product_fused_attention_overrideable( const at::Tensor & key, const at::Tensor & value, const std::optional & attn_bias, + bool compute_logsumexp, double dropout_p, bool is_causal, bool return_debug_mask, diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index a778c1a85da0..5aaf6a278a5d 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2908,7 +2908,7 @@ output_differentiability: [True, False, False, False, False, False, False, False, False] query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale) -- name: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) +- name: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, bool compute_log_sumexp=False, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? 
scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) output_differentiability: [True, False, False, False, False, False, False, False, False] query, key, value, attn_bias: _scaled_dot_product_fused_attention_overrideable_backward_symint(grad, query, key, value, attn_bias, grad_input_mask, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale) diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index aced2b2f539d..eff0769c8516 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -36,7 +36,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, int32_t compute_log_sumexp, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* 
scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index 92d30ded855f..06db40911f65 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -42,7 +42,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_efficient_a AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_efficient_attention_backward(AtenTensorHandle grad_out_, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double dropout_p, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, int32_t is_causal, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, int32_t compute_log_sumexp, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, 
AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h index 6fc51bd0c8f8..759b18f0fc31 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h @@ -14,7 +14,7 @@ extern "C" { AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, int32_t compute_log_sumexp, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, 
AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__weight_int4pack_mm_with_scales_and_zeros(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScale, AtenTensorHandle qZeros, AtenTensorHandle* ret0); From 6714c5d87dc76cf780c7c5f404c7381eab477f90 Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Thu, 26 Jun 2025 04:38:34 -0700 Subject: [PATCH 06/11] add meta --- torch/_meta_registrations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index dec68892cd8c..19bdc14fc75b 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5721,6 +5721,7 @@ def meta__scaled_dot_product_fused_attention_overrideable( key: Tensor, value: Tensor, attn_bias: Optional[Tensor] = None, + compute_log_sumexp: bool = False, dropout_p: float = 0.0, is_causal: bool = False, return_debug_mask: bool = False, From c8d7c3b351e91cb50b7db1dffcd9c3b34cdaca69 Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Thu, 24 Jul 2025 18:56:43 -0700 Subject: [PATCH 07/11] update oneDNN commit to v3.9 --- cmake/Modules/FindMKLDNN.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index 09b8e1ab887e..805524ca011e 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -47,7 +47,7 @@ IF(NOT MKLDNN_FOUND) endif() ExternalProject_Add(xpu_mkldnn_proj GIT_REPOSITORY https://github.com/uxlfoundation/oneDNN.git - GIT_TAG yixin/sdpa-training-impl + GIT_TAG rls-v3.9 PREFIX ${XPU_MKLDNN_DIR_PREFIX} BUILD_IN_SOURCE 0 CMAKE_ARGS -DCMAKE_C_COMPILER=icx From 01c00852b65373b4fd9d463a15d33a9cc66903fa Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Mon, 4 Aug 2025 23:29:38 -0700 Subject: [PATCH 08/11] fix code lint --- aten/src/ATen/native/mkldnn/xpu/Attention.cpp | 54 +++++++++---------- .../ATen/native/transformers/attention.cpp | 2 +- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp index b79c27762c23..c171bed8f364 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp @@ -39,16 +39,19 @@ bool check_head_dim_size_xpu(sdp::sdp_params const& params, bool debug) { return true; } -bool input_require_grad(const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - const std::optional& attn_mask) { - return at::GradMode::is_enabled() && (query.requires_grad() || key.requires_grad() || value.requires_grad() - || (attn_mask.has_value() && attn_mask.value().requires_grad())); +bool input_require_grad( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_mask) { + return at::GradMode::is_enabled() && + (query.requires_grad() || key.requires_grad() || 
value.requires_grad() || + (attn_mask.has_value() && attn_mask.value().requires_grad())); } bool check_grad(sdp::sdp_params const& params, bool debug) { - if (!input_require_grad(params.query, params.key, params.value, params.attn_mask)) + if (!input_require_grad( + params.query, params.key, params.value, params.attn_mask)) return true; auto q_num_heads = params.query.sym_size(-3); @@ -56,9 +59,11 @@ bool check_grad(sdp::sdp_params const& params, bool debug) { auto v_num_heads = params.value.sym_size(-3); bool is_gqa = q_num_heads != k_num_heads || q_num_heads != v_num_heads; if (debug && is_gqa) - TORCH_WARN("scale_dot_product_attention with gqa is not supported for gradient computation on xpu."); + TORCH_WARN( + "scale_dot_product_attention with gqa is not supported for gradient computation on xpu."); - bool attn_mask_needs_grad = params.attn_mask.has_value() && params.attn_mask.value().requires_grad(); + bool attn_mask_needs_grad = + params.attn_mask.has_value() && params.attn_mask.value().requires_grad(); if (debug && attn_mask_needs_grad) { TORCH_WARN( "scale_dot_product_attention on xpu is not supported with attn_mask.requires_grad() == True."); @@ -214,7 +219,7 @@ _scaled_dot_product_fused_attention_overrideable_xpu( !(attn_bias.has_value() && is_causal), "scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot present with is_causal"); TORCH_INTERNAL_ASSERT( - !(attn_bias.has_value() && attn_bias.value().requires_grad()), + !(attn_bias.has_value() && attn_bias.value().requires_grad()), "scaled_dot_product_fused_attention_overrideable_xpu: attn_bias cannot have requires_grad=True"); const int64_t batch_size = query.size(0); @@ -231,8 +236,8 @@ _scaled_dot_product_fused_attention_overrideable_xpu( alloc_with_matching_layout(query, attention, attention_shape); auto opts = query.options(); - at::Tensor logsumexp = at::empty( - {batch_size, num_head_q, seq_len_q}, opts.dtype(at::kFloat)); + at::Tensor logsumexp = + at::empty({batch_size, num_head_q, seq_len_q}, opts.dtype(at::kFloat)); at::native::onednn::gpu_float_sdpa( batch_size, @@ -267,11 +272,7 @@ _scaled_dot_product_fused_attention_overrideable_xpu( /*debug_attn_mask */ at::Tensor()); } -std::tuple< - at::Tensor, - at::Tensor, - at::Tensor, - at::Tensor> +std::tuple _scaled_dot_product_fused_attention_overrideable_backward_xpu( const at::Tensor& grad_out, const at::Tensor& query, @@ -291,12 +292,10 @@ _scaled_dot_product_fused_attention_overrideable_backward_xpu( const at::Tensor& philox_offset, std::optional scale) { TORCH_INTERNAL_ASSERT( - grad_out.dim() == 4 && out.dim() == 4 && - grad_out.size(0) == out.size(0) && - grad_out.size(1) == out.size(1) && - grad_out.size(2) == out.size(2) && - grad_out.size(3) == out.size(3), - "scaled_dot_product_fused_attention_overrideable_backward_xpu: grad_out and out should have the same shape of {(B), H, T, K}"); + grad_out.dim() == 4 && out.dim() == 4 && + grad_out.size(0) == out.size(0) && grad_out.size(1) == out.size(1) && + grad_out.size(2) == out.size(2) && grad_out.size(3) == out.size(3), + "scaled_dot_product_fused_attention_overrideable_backward_xpu: grad_out and out should have the same shape of {(B), H, T, K}"); TORCH_INTERNAL_ASSERT( query.dim() == 4 && key.dim() == 4 && value.dim() == 4, "scaled_dot_product_fused_attention_overrideable_backward_xpu: Accept only 4 dims inputs shape of {(B), H, T, K}"); @@ -305,9 +304,8 @@ _scaled_dot_product_fused_attention_overrideable_backward_xpu( (key.size(2) == value.size(2)), 
"scaled_dot_product_fused_attention_overrideable_backward_xpu: K/V should have the same batch / seq / num_head"); TORCH_INTERNAL_ASSERT( - query.size(0) == grad_out.size(0) && - query.size(1) == grad_out.size(1) && - query.size(2) == grad_out.size(2), + query.size(0) == grad_out.size(0) && query.size(1) == grad_out.size(1) && + query.size(2) == grad_out.size(2), "scaled_dot_product_fused_attention_overrideable_backward_xpu: Q should have the same batch / num_head / seq_len as grad_out"); TORCH_INTERNAL_ASSERT( query.size(3) == key.size(3), @@ -321,8 +319,8 @@ _scaled_dot_product_fused_attention_overrideable_backward_xpu( TORCH_INTERNAL_ASSERT( dropout_p == 0.0, "scaled_dot_product_fused_attention_overrideable_backward_xpu: Currently do not support dropout > 0"); - TORCH_INTERNAL_ASSERT(logsumexp.dim() == 3 && - logsumexp.size(0) == query.size(0) && + TORCH_INTERNAL_ASSERT( + logsumexp.dim() == 3 && logsumexp.size(0) == query.size(0) && logsumexp.size(1) == query.size(1) && logsumexp.size(2) == query.size(2) && "scaled_dot_product_fused_attention_overrideable_backward_xpu: logsumexp should have the shape of {(B), H, T}"); diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index dc51385161a7..9ae0bbb96442 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -769,7 +769,7 @@ Tensor scaled_dot_product_attention( } case SDPBackend::overrideable: { bool compute_logsumexp = should_compute_logsumexp(query_, key, value); - compute_logsumexp = compute_logsumexp || + compute_logsumexp = compute_logsumexp || (at::GradMode::is_enabled() && attn_mask.has_value() && attn_mask.value().requires_grad()); auto out_lse_softmax = at::_scaled_dot_product_fused_attention_overrideable( query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, false /*return_debug_mask*/, scale); From b028d72428bcbf2b592c285da948c9ceb56451ad Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Mon, 4 Aug 2025 23:46:59 -0700 Subject: [PATCH 09/11] remove debug flag --- cmake/Modules/FindMKLDNN.cmake | 1 - 1 file changed, 1 deletion(-) diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index 805524ca011e..62ce07f915fe 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -56,7 +56,6 @@ IF(NOT MKLDNN_FOUND) -DDNNL_CPU_RUNTIME=THREADPOOL -DDNNL_BUILD_TESTS=OFF -DDNNL_BUILD_EXAMPLES=OFF - -DONEDNN_ENABLE_GRAPH_DUMP=ON -DONEDNN_BUILD_GRAPH=ON -DDNNL_LIBRARY_TYPE=STATIC -DDNNL_DPCPP_HOST_COMPILER=${DNNL_HOST_COMPILER} # Use global cxx compiler as host compiler From 7b46e0f0aa3add66d2ba904c5b9570fb81bc68a3 Mon Sep 17 00:00:00 2001 From: mayuyuace Date: Sun, 10 Aug 2025 22:16:28 -0700 Subject: [PATCH 10/11] code refine --- .../native/mkldnn/xpu/detail/Attention.cpp | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp index 69d6946d1dbe..a2d3392df55b 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp @@ -13,7 +13,7 @@ using dims = logical_tensor::dims; using op = dnnl::graph::op; using partition = dnnl::graph::partition; -constexpr logical_tensor::data_type sdpa_intermedia_dtype = +constexpr logical_tensor::data_type sdpa_intermediate_dtype = logical_tensor::data_type::f32; inline data_type to_logical_tensor_data_type(c10::ScalarType 
scalar_type) { @@ -161,7 +161,7 @@ struct SDPALogicalParams { " instead."); logsumexp = { static_cast(TensorID::logsumexp), - sdpa_intermedia_dtype, + sdpa_intermediate_dtype, reshaped_logsumexp.sizes().vec(), reshaped_logsumexp.strides().vec()}; } @@ -202,7 +202,7 @@ partition create_sdpa_graph_partition( // Matrix Extensions (Intel(R) XMX) support, which means the // Q/K/V tensors have bf16 or f16 data type while the output of the first // MatMul, Scale, Mask, and the input of SoftMax are in f32 data type. - logical_tensor matmul_qk_out{lt_id++, sdpa_intermedia_dtype}; + logical_tensor matmul_qk_out{lt_id++, sdpa_intermediate_dtype}; op matmul_qk{ op_id++, op::kind::MatMul, @@ -211,7 +211,7 @@ partition create_sdpa_graph_partition( "matmul_qk"}; matmul_qk.set_attr(op::attr::transpose_b, true); - logical_tensor scaled_qk_out{lt_id++, sdpa_intermedia_dtype}; + logical_tensor scaled_qk_out{lt_id++, sdpa_intermediate_dtype}; op scale_mul{ op_id++, op::kind::Multiply, @@ -236,7 +236,7 @@ partition create_sdpa_graph_partition( if (params.attn_mask.has_value()) { TORCH_INTERNAL_ASSERT( !is_causal, "Additive mask cannot use with is_causal."); - masked_qk_out = {lt_id++, sdpa_intermedia_dtype}; + masked_qk_out = {lt_id++, sdpa_intermediate_dtype}; mask_add = { op_id++, op::kind::Add, @@ -271,7 +271,7 @@ partition create_sdpa_graph_partition( {mask_gt_out.value()}, "mask_gt"}; - masked_qk_out = {lt_id++, sdpa_intermedia_dtype}; + masked_qk_out = {lt_id++, sdpa_intermediate_dtype}; mask_select = { op_id++, op::kind::Select, @@ -496,7 +496,7 @@ struct SDPABackwardLogicalParams { reshaped_out.strides().vec()}; logsumexp = { static_cast(TensorID::logsumexp), - sdpa_intermedia_dtype, + sdpa_intermediate_dtype, reshaped_logsumexp.sizes().vec(), reshaped_logsumexp.strides().vec()}; scale = { @@ -571,7 +571,7 @@ partition create_sdpa_backward_graph_partition( // Matrix Extensions (Intel(R) XMX) support, which means the // Q/K/V tensors have bf16 or f16 data type while the output of the first // MatMul, Scale, Mask, and the input of SoftMax are in f32 data type. 
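The comment above describes the mixed-precision convention behind sdpa_intermediate_dtype: Q/K/V stay in bf16/f16 while the first MatMul output, scaling, masking, and the softmax input are computed in f32. A rough eager-mode sketch of that convention is shown below; the helper name is hypothetical and this is not the fused oneDNN graph, only the numerics it mirrors.

import math
import torch

def sdpa_fp32_intermediates(q, k, v, attn_mask=None, scale=None):
    # bf16/f16 inputs, but QK^T, scale, mask and softmax carried out in fp32,
    # mirroring sdpa_intermediate_dtype in the graph partition above.
    scale = scale if scale is not None else 1.0 / math.sqrt(q.size(-1))
    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * scale
    if attn_mask is not None:
        scores = scores + attn_mask.float()
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs.to(q.dtype), v)  # TypeCast back before the second MatMul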
- logical_tensor matmul_qk_out{lt_id++, sdpa_intermedia_dtype}; + logical_tensor matmul_qk_out{lt_id++, sdpa_intermediate_dtype}; op matmul_qk{ op_id++, op::kind::MatMul, @@ -580,7 +580,7 @@ partition create_sdpa_backward_graph_partition( "matmul_qk"}; matmul_qk.set_attr(op::attr::transpose_b, true); - logical_tensor scaled_qk_out{lt_id++, sdpa_intermedia_dtype}; + logical_tensor scaled_qk_out{lt_id++, sdpa_intermediate_dtype}; op scale_mul{ op_id++, op::kind::Multiply, @@ -605,7 +605,7 @@ partition create_sdpa_backward_graph_partition( if (params.attn_mask.has_value()) { TORCH_INTERNAL_ASSERT( !is_causal, "Additive mask cannot use with is_causal."); - masked_qk_out = {lt_id++, sdpa_intermedia_dtype}; + masked_qk_out = {lt_id++, sdpa_intermediate_dtype}; mask_add = { op_id++, op::kind::Add, @@ -639,7 +639,7 @@ partition create_sdpa_backward_graph_partition( {mask_gt_out.value()}, "mask_gt"}; - masked_qk_out = {lt_id++, sdpa_intermedia_dtype}; + masked_qk_out = {lt_id++, sdpa_intermediate_dtype}; mask_select = { op_id++, op::kind::Select, @@ -649,21 +649,21 @@ partition create_sdpa_backward_graph_partition( } // attention_probs = softmax(masked_score) = exp(masked_score - logsumexp) - logical_tensor sub_out{lt_id++, sdpa_intermedia_dtype}; + logical_tensor sub_out{lt_id++, sdpa_intermediate_dtype}; op subtract{ op_id++, op::kind::Subtract, {masked_qk_out.value_or(scaled_qk_out), params.logsumexp}, {sub_out}, "subtract"}; - logical_tensor prob{lt_id++, sdpa_intermedia_dtype}; + logical_tensor prob{lt_id++, sdpa_intermediate_dtype}; op exp{op_id++, op::kind::Exp, {sub_out}, {prob}, "exp"}; // The following matmul doesn't support different input dtypes, insert a // typecast logical_tensor prob_casted = prob; op typecast = op(op_id++, op::kind::TypeCast, "typecast"); - if (dtype != sdpa_intermedia_dtype) { + if (dtype != sdpa_intermediate_dtype) { prob_casted = logical_tensor(lt_id++, dtype); typecast.add_inputs({prob}); typecast.add_outputs({prob_casted}); @@ -685,7 +685,7 @@ partition create_sdpa_backward_graph_partition( // TODO: handle GQA headnum because (batch_size, num_head_q, seq_len_q, // seq_len_kv) != (batch_size, num_head_q, seq_len_q, head_dim_v) * // (batch_size, num_head_kv, head_dim_v, seq_len_kv) - logical_tensor grad_prop{lt_id++, sdpa_intermedia_dtype}; + logical_tensor grad_prop{lt_id++, sdpa_intermediate_dtype}; op matmul_grad_prop{ op_id++, op::kind::MatMul, @@ -695,7 +695,7 @@ partition create_sdpa_backward_graph_partition( matmul_grad_prop.set_attr(op::attr::transpose_b, true); // grad_masked_score = softmaxbackward(grad_prop) - logical_tensor grad_masked_score{lt_id++, sdpa_intermedia_dtype}; + logical_tensor grad_masked_score{lt_id++, sdpa_intermediate_dtype}; op softmax_backward{ op_id++, op::kind::SoftMaxBackward, @@ -708,7 +708,7 @@ partition create_sdpa_backward_graph_partition( // supports output grad_attn_mask. 
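The backward graph above re-materializes the attention probabilities from the saved logsumexp rather than storing them, via the identity attention_probs = softmax(masked_score) = exp(masked_score - logsumexp). A minimal sketch of that step, assuming logsumexp has the (B, H, T) shape asserted in the backward entry point (the function name is illustrative):

import torch

def probs_from_logsumexp(masked_score: torch.Tensor, logsumexp: torch.Tensor) -> torch.Tensor:
    # masked_score: (B, H, seq_len_q, seq_len_kv); logsumexp: (B, H, seq_len_q).
    # Broadcasting the logsumexp over the last dimension reproduces the Subtract + Exp ops.
    return torch.exp(masked_score - logsumexp.unsqueeze(-1))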
// grad_scaled_score = grad_masked_score * scale - logical_tensor grad_scaled_score{lt_id++, sdpa_intermedia_dtype}; + logical_tensor grad_scaled_score{lt_id++, sdpa_intermediate_dtype}; op grad_scale_mul{ op_id++, op::kind::Multiply, @@ -720,7 +720,7 @@ partition create_sdpa_backward_graph_partition( // typecast logical_tensor grad_scaled_score_cast = grad_scaled_score; op typecast2 = op(op_id++, op::kind::TypeCast, "typecast2"); - if (dtype != sdpa_intermedia_dtype) { + if (dtype != sdpa_intermediate_dtype) { grad_scaled_score_cast = logical_tensor(lt_id++, dtype); typecast2.add_inputs({grad_scaled_score}); typecast2.add_outputs({grad_scaled_score_cast}); @@ -767,7 +767,7 @@ partition create_sdpa_backward_graph_partition( g.add_op(grad_scale_mul); g.add_op(matmul_grad_query); g.add_op(matmul_grad_key); - if (dtype != sdpa_intermedia_dtype) { + if (dtype != sdpa_intermediate_dtype) { g.add_op(typecast); g.add_op(typecast2); } From d6ddf7aa325eb7650bb98ca8c30b17529dba3526 Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Mon, 11 Aug 2025 18:44:43 -0700 Subject: [PATCH 11/11] fix typo --- aten/src/ATen/native/mkldnn/xpu/Attention.cpp | 2 +- aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp index c171bed8f364..81618400b7ae 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp @@ -66,7 +66,7 @@ bool check_grad(sdp::sdp_params const& params, bool debug) { params.attn_mask.has_value() && params.attn_mask.value().requires_grad(); if (debug && attn_mask_needs_grad) { TORCH_WARN( - "scale_dot_product_attention on xpu is not supported with attn_mask.requires_grad() == True."); + "scale_dot_product_attention on xpu is not supported when attn_mask.requires_grad() == True."); } return !is_gqa && !attn_mask_needs_grad; diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp index a2d3392df55b..8d39b80f7d0f 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp @@ -851,8 +851,8 @@ void gpu_float_sdpa( }; // OneDNN doesn't support fp32 ukernel for implicit causal mask, - // and the reference implementation is worse than aten math + explict causal - // mask. Fall back to explict causal mask until OneDNN v3.9 which has fp32 + // and the reference implementation is worse than aten math + explicit causal + // mask. Fall back to explicit causal mask until OneDNN v3.9 which has fp32 // ukernel for implicit causal mask. if (is_causal && query.dtype() == at::kFloat) { attn_mask = get_tril_mask(); @@ -960,8 +960,8 @@ void gpu_float_sdpa_backward( }; // OneDNN doesn't support fp32 ukernel for implicit causal mask, - // and the reference implementation is worse than aten math + explict causal - // mask. Fall back to explict causal mask until OneDNN v3.9 which has fp32 + // and the reference implementation is worse than aten math + explicit causal + // mask. Fall back to explicit causal mask until OneDNN v3.9 which has fp32 // ukernel for implicit causal mask. if (is_causal && query.dtype() == at::kFloat) { attn_mask = get_tril_mask();
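For fp32 inputs the code above swaps the implicit causal mask for an explicit one via get_tril_mask(). The patch does not show that helper, but an additive lower-triangular mask consumed by the Add op would look roughly like the sketch below (hypothetical helper; layout assumed, not confirmed by the patch):

import torch

def explicit_causal_mask(seq_len_q: int, seq_len_kv: int, dtype=torch.float32) -> torch.Tensor:
    # Additive mask: 0 on and below the diagonal, -inf above it, so adding it to the
    # scaled QK^T scores reproduces is_causal=True without the implicit-mask ukernel.
    above_diag = torch.triu(torch.ones(seq_len_q, seq_len_kv, dtype=torch.bool), diagonal=1)
    return torch.zeros(seq_len_q, seq_len_kv, dtype=dtype).masked_fill(above_diag, float("-inf"))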