[FSDP2] cast unsharded_param_grad to correct reduce dtype #160279


Open · wants to merge 1 commit into main

Conversation

@tonyf commented Aug 10, 2025

This is an edge case when using gradient accumulation steps with FSDP2.

If a parameter within a parameter group doesn't have gradients for all but the final backward pass, its grad is not cast to the reduce dtype, resulting in inconsistent dtypes across the gradients passed to reduce_scatter.

Specifically, when using a MixedPrecisionPolicy with param_dtype=torch.bfloat16 and reduce_dtype=torch.float32, parameters that had grads on every step will have gradients of dtype torch.float32, but those that receive a gradient only on the final pass will have gradients of dtype torch.bfloat16. This raises:

AssertionError: FSDP reduce-scatter expects uniform gradient dtype but got {torch.bfloat16, torch.float32}

This PR explicitly casts the grad to the correct reduce dtype in this case.
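The failure mode described above can be sketched with a few lines of plain PyTorch. This is an illustrative repro of the dtype mismatch and of the cast the PR applies, not the actual FSDP2 internals; the names `param_dtype`, `reduce_dtype`, and `grads` are assumptions standing in for the policy fields and the per-parameter unsharded gradients.

```python
import torch

# Mirrors MixedPrecisionPolicy(param_dtype=..., reduce_dtype=...)
param_dtype = torch.bfloat16
reduce_dtype = torch.float32

# Simulate gradient accumulation: the first parameter accumulated grads
# across microbatches and was already upcast to reduce_dtype, while the
# second produced its first grad only on the final backward pass, so it
# is still in param_dtype.
grads = [
    torch.ones(4, dtype=reduce_dtype),  # had grads on every step
    torch.ones(4, dtype=param_dtype),   # grad only on the final pass
]

# Without the fix, the uniform-dtype check before reduce-scatter fails:
assert {g.dtype for g in grads} == {torch.bfloat16, torch.float32}

# The fix: explicitly cast each unsharded grad to reduce_dtype before
# the gradients are flattened together for reduce_scatter.
grads = [g if g.dtype == reduce_dtype else g.to(reduce_dtype) for g in grads]
assert {g.dtype for g in grads} == {reduce_dtype}
```

The cast is a no-op for gradients already in `reduce_dtype`, so it only touches the edge case where a parameter skipped the earlier accumulation steps.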

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta

pytorch-bot bot commented Aug 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160279

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 12 Pending, 2 Unrelated Failures

As of commit 42b057a with merge base 05c19d1:

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the labels "oncall: distributed" and "release notes: distributed (fsdp)" on Aug 10, 2025
linux-foundation-easycla bot commented Aug 10, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

@tonyf tonyf changed the title [FSDP2] cast unsharded param to reduce dtype [FSDP2] cast unsharded_param_grad to reduce dtype Aug 10, 2025
@tonyf tonyf changed the title [FSDP2] cast unsharded_param_grad to reduce dtype [FSDP2] cast unsharded_param_grad to correct reduce dtype Aug 10, 2025