[Draft][WIP] Enable XPU path for FlexAttention #143553
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143553
Note: Links to docs will display an error until the docs builds have been completed.
❌ 9 New Failures, 1 Pending
As of commit 3de28ca with merge base 01bcf9a.
NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
To add the ciflow label: this helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.
@pytorchbot rebase
You don't have permission to rebase this PR since you are a first-time contributor. If you think this is a mistake, please contact PyTorch Dev Infra.
if config.max_autotune:
    if config.max_autotune_flex_search_space == "EXHAUSTIVE":
        return self.exhaustive_flex_attn_fwd_configs
    flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs
Hi @hoshibara, could you define flex_attn_fwd_autotune_configs in XPUConfigHeuristic instead of inheriting it from the base class?
It is very likely that these configs will differ for XPU. For now they can be left the same, but owning them on the XPU side will make patching easier for us. More details: intel/intel-xpu-backend-for-triton#4265 (comment)
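A minimal sketch of the kind of override being asked for, assuming the shared heuristic exposes the config list as an instance attribute; the base-class name and the tuple values below are placeholders, not PyTorch's actual heuristic API:

```python
class BaseConfigHeuristic:
    """Stand-in for the shared heuristic base class (name assumed for this sketch)."""

    def __init__(self) -> None:
        # Shared (CUDA-oriented) defaults that XPU currently inherits.
        self.flex_attn_fwd_autotune_configs = [
            (128, 64, 3, 4),  # (block_m, block_n, num_stages, num_warps) -- placeholder values
            (64, 64, 3, 4),
        ]


class XPUConfigHeuristic(BaseConfigHeuristic):
    def __init__(self) -> None:
        super().__init__()
        # Own a separate copy of the forward flex-attention autotune configs on
        # the XPU side. Starting from the inherited values keeps behavior
        # unchanged for now, while giving Intel's Triton backend one place to
        # patch with XPU-specific tuning later.
        self.flex_attn_fwd_autotune_configs = list(self.flex_attn_fwd_autotune_configs)
```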
OK, I'll add it
Added
index_put fix in torch xpu ops
test/inductor/test_flex_attention.py
Outdated
@@ -461,8 +466,9 @@ def run_test(
        block_mask = create_block_mask(
            noop_mask, Q_B, Q_H, Q_S, KV_S, device=self.device
        )
        gold_dtype = torch.float64 if not HAS_XPU else torch.float32
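For context, a hedged sketch of how a "gold" reference dtype is typically used in this kind of accuracy check; the helper and tolerances below are illustrative, not the actual run_test implementation:

```python
import torch

# Assumption for this sketch: HAS_XPU mirrors the test-suite flag.
HAS_XPU = hasattr(torch, "xpu") and torch.xpu.is_available()
gold_dtype = torch.float64 if not HAS_XPU else torch.float32


def check_against_gold(fn, q, k, v, atol=2e-2, rtol=2e-2):
    # Compute the reference in the highest precision the backend supports,
    # then compare the normal-precision output against it.
    gold = fn(q.to(gold_dtype), k.to(gold_dtype), v.to(gold_dtype))
    out = fn(q, k, v)
    torch.testing.assert_close(out, gold.to(out.dtype), atol=atol, rtol=rtol)
```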
Why does the Intel GPU need to downgrade to FP32?
This has been removed.
.ci/docker/common/install_xpu.sh
Outdated
@chuanqi129, have you submitted a PR to update the infra?
@@ -3803,7 +3875,7 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3
 class mask_graph0(torch.nn.Module):
     def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
-        full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
+        full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
Suggested change:
- full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
+ full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type=GPU_TYPE, index=0), pin_memory = False)
GPU_TYPE is a placeholder to be replaced by the actual running device type, so it should not be modified.
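For illustration, a minimal sketch of how such a placeholder substitution typically works when comparing against an expected graph dump; the helper name and exact mechanism here are assumptions, not the test file's actual code:

```python
import torch

# GPU_TYPE stands in for whichever device the test actually runs on,
# e.g. "cuda" or "xpu"; here we pick it up the same way for the sketch.
GPU_TYPE = "xpu" if (hasattr(torch, "xpu") and torch.xpu.is_available()) else "cuda"

EXPECTED_GRAPH = "full_default: \"b8[]\" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)"


def render_expected(expected: str) -> str:
    # Replace the literal placeholder before comparing with the captured graph
    # text, so one expected string serves both CUDA and XPU runs.
    return expected.replace("GPU_TYPE", GPU_TYPE)
```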
third_party/xpu.txt
Outdated
Does FlexAttention depend on torch-xpu-ops?
We need to test Yutao's fix PR; this is just a temporary modification.
.ci/docker/ubuntu-xpu/Dockerfile
Outdated
@chuanqi129, ditto.
This change is to test the CI pass rate on the rolling driver. It will be removed before merging.
Removed
update xpu ops pin
add intel gpu specific autotune configs
Motivation
What does this PR do?
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela @yf225 @ColinPeppler @desertfire