
[Draft][WIP] Enable XPU path for FlexAttention #143553


Draft · liangan1 wants to merge 107 commits into base: main

Conversation


@liangan1 liangan1 commented Dec 19, 2024

Motivation

  1. Attention has become the critical performance bottleneck in current LLM models, and FlexAttention is a good choice for covering the broad range of attention variants in the transformers model family. With FlexAttention, it is easy to enable paged attention and fused SDPA in the transformers repo on the XPU device. It also provides a candidate attention path for LLM ecosystem libraries, e.g., vLLM and SGLang, on the XPU device.
  2. FlexAttention is a good starting point for maturing the Intel Triton-based GEMM kernels. FlexAttention provides both a flex attention kernel and a flex decoding kernel, covering compute-bound and memory-bound GEMM computation, and the different shapes needed to serve LLM inference should also be supported, e.g., head_dim = 64, 96, 128, 256.

What does this PR do?

  1. Enable the XPU device type for the FlexAttention kernels and unit tests, and ensure that all important UTs pass on the XPU device.
  2. For E2E model inference, ensure that LLM inference with FlexAttention is functional on the XPU device (a usage sketch follows below).
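
As a usage-level illustration, the sketch below shows the flow this PR is meant to enable: compiling torch.nn.attention.flex_attention with a block mask on an "xpu" device. It is a minimal sketch only; the shapes, dtype, and the availability of an XPU-enabled PyTorch build with this PR applied are assumptions for illustration.

```python
# Minimal sketch of FlexAttention on XPU (assumes an XPU-enabled PyTorch build
# with this PR applied; shapes and dtype are illustrative assumptions).
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "xpu"
B, H, S, D = 2, 8, 1024, 64  # head_dim=64 is one of the sizes called out above

q = torch.randn(B, H, S, D, device=device, dtype=torch.float16)
k = torch.randn(B, H, S, D, device=device, dtype=torch.float16)
v = torch.randn(B, H, S, D, device=device, dtype=torch.float16)

# A causal mask expressed as a mask_mod; torch.compile lowers the attention
# call to a Triton kernel through Inductor, which is the path enabled on XPU here.
def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, B, H, S, S, device=device)
compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v, block_mask=block_mask)
```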

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela @yf225 @ColinPeppler @desertfire


pytorch-bot bot commented Dec 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143553

Note: Links to docs will display an error until the docs builds have been completed.

❌ 9 New Failures, 1 Pending

As of commit 3de28ca with merge base 01bcf9a:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.


linux-foundation-easycla bot commented Dec 19, 2024

@liangan1 liangan1 marked this pull request as draft December 19, 2024 04:39
@EikanWang EikanWang added the topic: not user facing, triaged, and ciflow/xpu labels Dec 24, 2024

pytorch-bot bot commented Dec 24, 2024

To add the ciflow label ciflow/xpu please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed the ciflow/xpu label Dec 24, 2024
@EikanWang EikanWang self-requested a review December 24, 2024 02:14
@EikanWang EikanWang added the ciflow/xpu label Dec 24, 2024

pytorch-bot bot commented Dec 24, 2024

To add the ciflow label ciflow/xpu please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed the ciflow/xpu label Dec 24, 2024
@liangan1 (Author)

@pytorchbot rebase


pytorch-bot bot commented Feb 10, 2025

You don't have permissions to rebase this PR since you are a first time contributor. If you think this is a mistake, please contact PyTorch Dev Infra.

if config.max_autotune:
    if config.max_autotune_flex_search_space == "EXHAUSTIVE":
        return self.exhaustive_flex_attn_fwd_configs
    flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs
Collaborator

Hi @hoshibara, could you define flex_attn_fwd_autotune_configs in XPUConfigHeuristic instead of using it from the base class?

It is very likely that these configs will differ for XPU; for now they can be left the same, but a separate definition will make patching easier for us. More details: intel/intel-xpu-backend-for-triton#4265 (comment)
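
For concreteness, a rough sketch of the requested override is shown below. BaseConfigHeuristic, FlexConfig, and the config field layout are assumptions inferred from the snippet above and Inductor's existing heuristics module; the concrete values are placeholders, not tuned XPU configs.

```python
# Hypothetical sketch: give XPUConfigHeuristic its own copy of the forward
# autotune configs so they can be patched/tuned for Intel GPUs independently
# of the base class. BaseConfigHeuristic and FlexConfig are assumed to be the
# existing definitions in Inductor's template-heuristics module.
class XPUConfigHeuristic(BaseConfigHeuristic):
    def __init__(self) -> None:
        super().__init__()
        self.flex_attn_fwd_autotune_configs = [
            # (BLOCK_M, BLOCK_N, num_stages, num_warps) -- assumed field order,
            # placeholder values rather than tuned XPU settings.
            FlexConfig(128, 64, 3, 4),
            FlexConfig(128, 128, 3, 4),
            FlexConfig(64, 64, 3, 4),
        ]
```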

Contributor

OK, I'll add it

Contributor

Added

@@ -461,8 +466,9 @@ def run_test(
         block_mask = create_block_mask(
             noop_mask, Q_B, Q_H, Q_S, KV_S, device=self.device
         )
+        gold_dtype = torch.float64 if not HAS_XPU else torch.float32
Collaborator

Why does Intel GPU need to downgrade to FP32?

Contributor

This has been removed.

Collaborator

@chuanqi129, have you submitted a PR to update the infra?

@@ -3803,7 +3875,7 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3

 class mask_graph0(torch.nn.Module):
     def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
-        full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
+        full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
Collaborator

Suggested change
-        full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
+        full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type=GPU_TYPE, index=0), pin_memory = False)

Contributor

GPU_TYPE is a placeholder to be replaced by the actual running device type, so it should not be modified.
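
For readers outside the test suite, the placeholder mechanism works roughly as sketched below: the expected graph text carries a GPU_TYPE token that the harness swaps for the device actually under test before comparing. This is an illustrative sketch under that assumption, not the harness's actual code.

```python
# Illustrative sketch of the GPU_TYPE placeholder idea (not the actual test-harness code).
expected_graph = (
    'full_default: "b8[]" = torch.ops.aten.full.default([], True, '
    "dtype = torch.bool, layout = torch.strided, "
    "device = device(type='GPU_TYPE', index=0), pin_memory = False)"
)
running_device = "xpu"  # would be "cuda" on NVIDIA runners; detected at test time in practice
expected_graph = expected_graph.replace("GPU_TYPE", running_device)
```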

Collaborator

Does FlexAttention depend on torch-xpu-ops?

Contributor

We need to test Yutao's fix PR. This is just a temp modification.

Collaborator

@chuanqi129 , ditto

Contributor

This change is to test the CI pass rate on the rolling driver. It will be removed before merging.

Contributor

Removed


pytorch-bot bot commented Aug 7, 2025

To add the ciflow label ciflow/inductor please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed the ciflow/xpu label Aug 8, 2025
@etaf etaf added the ciflow/xpu label Aug 12, 2025
Labels
ciflow/xpu (Run XPU CI tasks)
keep-going (Don't stop on first failure, keep running tests until the end)
module: dynamo
module: inductor
open source
topic: not user facing (topic category)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Projects
Status: In Progress