[WIP] cast to bf16 before mul op in flex bwd #154922
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154922
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures, 1 Unrelated Failure as of commit be3407e with merge base 48807d5.
NEW FAILURES - The following jobs have failed:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a
Commits: a15c930 to be3407e
# If mul was upcasted from fp8 to bf16, we need to downcast it back to fp8.
if upcast_from_fp8:
    mul_delta = lowerings[prims.convert_element_type](mul_delta, orig_fp8_dtype)

delta = lowerings[aten.sum](mul_delta, axis=-1)
delta = lowerings[aten.sub](delta, grad_lse_exp2)
hmm, separate from the error you're seeing, are you going to run into similar issues trying to run aten.sum/sub in triton with fp8 inputs? You may need to delay downcasting back to fp8 until after these ops.
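A sketch of that reordering, reusing the names from the diff above (this is not code from the PR, only an illustration of the suggestion, and it assumes the surrounding lowering context still provides `lowerings`, `mul_delta`, `grad_lse_exp2`, `upcast_from_fp8`, and `orig_fp8_dtype`): keep the reduction and subtraction in bf16 and only convert back to the original fp8 dtype at the end.

```python
# Sketch only: run sum/sub while still in bf16, then downcast the result.
delta = lowerings[aten.sum](mul_delta, axis=-1)
delta = lowerings[aten.sub](delta, grad_lse_exp2)
if upcast_from_fp8:
    delta = lowerings[prims.convert_element_type](delta, orig_fp8_dtype)
```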
I feel like this is something we could do elsewhere, potentially similar to how we upcast fp16 for unsupported operators. See pytorch/torch/_inductor/codegen/triton.py, line 775 in a1a268a:

def maybe_upcast_float32(convert_output: bool = True) -> Callable[[_T], _T]:

although that's just pointwise, not for sum.
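For illustration, here is the general shape of such an upcast wrapper as a standalone eager-mode sketch; this is not the actual maybe_upcast_float32 implementation, and the helper name, dtype list, and defaults are assumptions made for the example.

```python
import torch
from typing import Callable

# Assumed set of fp8 dtypes to intercept (illustrative, not exhaustive).
FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

def upcast_fp8_inputs(
    fn: Callable,
    compute_dtype: torch.dtype = torch.bfloat16,
    convert_output: bool = True,
) -> Callable:
    """Hypothetical helper: upcast fp8 tensor args before `fn`, optionally downcast the result."""
    def wrapper(*args, **kwargs):
        orig_dtype = None
        new_args = []
        for a in args:
            if isinstance(a, torch.Tensor) and a.dtype in FP8_DTYPES:
                orig_dtype = a.dtype
                a = a.to(compute_dtype)
            new_args.append(a)
        out = fn(*new_args, **kwargs)
        # Downcast only if something was upcast and the caller asked for it.
        if convert_output and orig_dtype is not None and isinstance(out, torch.Tensor):
            out = out.to(orig_dtype)
        return out
    return wrapper

# Example: wrap an elementwise op that lacks an fp8 kernel.
mul_with_fp8 = upcast_fp8_inputs(torch.mul)
```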
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as Stale.
Issue: #154750
fp8 flex attention works in the forward pass but errors in the backward pass because the mul op is not supported for fp8 dtypes.
This PR is a WIP to figure out the best way to address this.
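As a plain eager-mode illustration of the idea in the title (the PR itself changes the inductor lowering, not eager code; the shapes and dtype choice here are arbitrary): upcast the fp8 operands to bf16 for the multiply, then cast the product back to fp8.

```python
import torch

# Illustration only: elementwise mul is generally unavailable for fp8 tensors,
# so compute the product in bf16 and convert the result back to the original
# fp8 dtype afterwards.
a = torch.randn(8, 8).to(torch.float8_e4m3fn)
b = torch.randn(8, 8).to(torch.float8_e4m3fn)

prod = (a.to(torch.bfloat16) * b.to(torch.bfloat16)).to(a.dtype)
```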
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben