
[Flex Attn][CPU] support flash decoding for cpu #159835


Draft — wants to merge 7 commits into base: main
Conversation

@Valentine233 (Collaborator) commented Aug 5, 2025

Description:

  1. Support flash decoding in CppFlexAttentionTemplate. Flash decoding is preferred over flash attention when the query length is 1.
  2. For flash decoding, add a kernel option PARTITION_SIZE that defines the partition size used to parallelize over the KV length dimension. The default value is 128; it should be a multiple of the KV cache block size for flash decoding to be used.
  3. As mentioned in "Fix large_tensor_test skipping cpu" #158617, the flex_attn UTs for the CPU backend were disabled because of their long duration. Here we re-enable the essential ones.
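The partitioned reduction described in item 2 can be sketched as follows. This is an illustrative NumPy reference for the single-query (decoding) case, not the actual CppFlexAttentionTemplate code; the `flash_decode` name, shapes, and `partition_size` argument are hypothetical stand-ins for the kernel's PARTITION_SIZE option. Each KV partition computes an unnormalized partial output plus its softmax statistics (local max and sum), and the partials are combined with a log-sum-exp rescaling:

```python
import numpy as np

def flash_decode(q, k, v, partition_size=128):
    # q: (d,) single decoding query; k, v: (T, d) KV cache.
    scale = q.shape[-1] ** -0.5
    maxes, sums, outs = [], [], []
    for start in range(0, k.shape[0], partition_size):
        kb = k[start:start + partition_size]
        vb = v[start:start + partition_size]
        s = (kb @ q) * scale           # partial attention scores
        m = s.max()                    # partition-local max for numerical stability
        p = np.exp(s - m)
        maxes.append(m)
        sums.append(p.sum())           # partial softmax denominator
        outs.append(p @ vb)            # unnormalized partial output
    m_glob = max(maxes)
    # Rescale each partition's contribution by exp(m_i - m_glob), then combine.
    alphas = [np.exp(m - m_glob) for m in maxes]
    num = sum(a * o for a, o in zip(alphas, outs))
    den = sum(a * z for a, z in zip(alphas, sums))
    return num / den

rng = np.random.default_rng(0)
q = rng.standard_normal(64)
k = rng.standard_normal((512, 64))
v = rng.standard_normal((512, 64))

s = (k @ q) * 64 ** -0.5
w = np.exp(s - s.max())
ref = (w / w.sum()) @ v                # plain softmax attention reference

out = flash_decode(q, k, v, partition_size=128)
print(np.allclose(out, ref))           # → True
```

Because each partition carries its own max and sum, the partitions can be computed in parallel across threads and merged in a single reduction, which is where the speedup over single-threaded flash attention comes from when the query length is 1.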

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben


pytorch-bot bot commented Aug 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159835

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 35eb3f4 with merge base 8a2f53c:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

  • pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, linux.12xlarge, unstable) (gh) (#158876)
    /var/lib/jenkins/workspace/xla/torch_xla/csrc/runtime/BUILD:476:14: Compiling torch_xla/csrc/runtime/xla_util_test.cpp failed: (Exit 1): gcc failed: error executing CppCompile command (from target //torch_xla/csrc/runtime:xla_util_test) /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections ... (remaining 229 arguments skipped)

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Valentine233 (Collaborator, Author) commented Aug 6, 2025

@jianan-gu @CaoE Please help review, thanks~

@CaoE CaoE added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 12, 2025
3 participants