[FSDP] Reshard frozen params in backward #101982
Conversation
This PR makes a first attempt at improving FSDP's fine-tuning support by adding hooks to reshard frozen parameters in the backward pass.
- Without this, frozen parameters involved in gradient computation are kept unsharded through the entire backward pass.
- The approach is to register a multi-grad post-hook on the _input_ activations to the FSDP module, where the hook performs the resharding after all gradients for the FSDP module must have been computed (meaning that we are safe to reshard).

This PR relies on adding a "multi-grad post-hook" that differs from the existing "multi-grad hook" from `register_multi_grad_hook()`. I find that with `register_multi_grad_hook()`, sometimes the unit test counting the number of times `_post_backward_reshard()` is called fails (due to it not being called).
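To make the idea concrete, here is a minimal, standalone sketch of the general pattern using the public `torch.autograd.graph.register_multi_grad_hook` API. It is not the PR's implementation (which initially adds a separate "multi-grad post-hook"), and `_reshard_frozen_params` / `register_reshard_on_inputs` are hypothetical stand-ins for FSDP's internal reshard logic:

```python
import torch
import torch.nn as nn


def _reshard_frozen_params(module: nn.Module) -> None:
    # Hypothetical stand-in for FSDP's post-backward reshard: in real FSDP this
    # would free the module's unsharded flat parameter.
    print(f"resharding frozen params of {type(module).__name__}")


def register_reshard_on_inputs(module: nn.Module, inputs):
    # Hook only the inputs that participate in autograd. The multi-grad hook
    # fires once gradients for *all* hooked inputs have been computed, i.e.,
    # after backward has finished flowing through this module, so it is safe
    # to reshard the module's frozen parameters at that point.
    tensors = tuple(
        t for t in inputs if isinstance(t, torch.Tensor) and t.requires_grad
    )
    if not tensors:
        return None  # nothing to hook; a catch-all reshard must clean up instead
    return torch.autograd.graph.register_multi_grad_hook(
        tensors, lambda grads: _reshard_frozen_params(module)
    )


# Usage: a frozen layer whose input activation still requires grad.
frozen = nn.Linear(4, 4)
frozen.requires_grad_(False)

x = torch.randn(2, 4, requires_grad=True)
handle = register_reshard_on_inputs(frozen, (x,))
frozen(x).sum().backward()  # the reshard callback fires once x's gradient is ready
```

The key point is that gradients are never computed for the frozen parameters themselves, so the only reliable signal that backward is done with the module is that gradients for its inputs exist.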
My understanding might be wrong here, but frozen parameters being resharded late (I'm assuming they're resharded in `_catch_all_reshard`) is similar to how we reshard "unused parameters" in the catch-all reshard as well. Could we use this technique to completely eliminate the `_catch_all_reshard`?
@rohan-varma The catch-all reshard is still useful in the case that the very first input activation does not require gradient. Then, we have no activation on which we can reshard the root FSDP instance's parameters if they are frozen.
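For concreteness, a tiny standalone illustration of that corner case (an assumed example, not FSDP code): integer token ids can never require gradient, so there is no input tensor on which to register a reshard hook and the catch-all reshard still has to clean up:

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)
emb.requires_grad_(False)  # frozen (root) module

token_ids = torch.randint(0, 10, (2, 3))  # integer inputs can never require grad
hookable = [t for t in (token_ids,) if t.requires_grad]
print(hookable)  # [] -> no multi-grad hook can be registered here, so a
                 # catch-all reshard is still needed to free unsharded params
```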
@awgu @rohan-varma Is this PR ready for merge or not? What is the blocker?
@pengyanghua I was waiting for a fix on the autograd side and have not had a chance to rebase it yet. I will do that soon.
awesome work!
```diff
 _p_assert(
-    len(flat_param._post_backward_hook_state) == 2,
+    post_backward_hook_state_len == 1
```
why can it be 1 or 2 now?
When the parameters are frozen (`requires_grad=False`), we do not have an `AccumulateGrad` object anymore, so the state looks like `handle.flat_param._post_backward_hook_state = hook_handle`. Normally, it is both the `hook_handle` and the `acc_grad`.
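For illustration, here is a self-contained sketch of a length check that accepts both layouts. `FakeFlatParam` and `check_post_backward_state` are hypothetical stand-ins rather than FSDP internals, and `_p_assert` is re-implemented locally:

```python
from dataclasses import dataclass


def _p_assert(cond: bool, msg: str) -> None:
    # Local re-implementation of an assert helper; not imported from FSDP.
    if not cond:
        raise AssertionError(msg)


@dataclass
class FakeFlatParam:  # hypothetical stand-in for FSDP's flat parameter
    requires_grad: bool
    _post_backward_hook_state: tuple


def check_post_backward_state(flat_param: FakeFlatParam) -> None:
    # (acc_grad, hook_handle) for trainable flat parameters; (hook_handle,) for frozen ones.
    expected_len = 2 if flat_param.requires_grad else 1
    _p_assert(
        len(flat_param._post_backward_hook_state) == expected_len,
        f"Invalid _post_backward_hook_state: {flat_param._post_backward_hook_state!r}",
    )


check_post_backward_state(FakeFlatParam(True, ("acc_grad", "hook_handle")))  # ok: len 2
check_post_backward_state(FakeFlatParam(False, ("hook_handle",)))            # ok: len 1
```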
```diff
@@ -1360,6 +1373,39 @@ def _register_post_backward_hooks(
         flat_param._post_backward_hook_state = (acc_grad, hook_handle)  # type: ignore[attr-defined]


 def _register_post_backward_reshard_only_hooks(
```
looks like now we have 2 paths where params can be resharded in post backward:
- for flat parameters that don't require grad, do it via a hook on the input activations
- for flat parameters that do require grad, do it with the standard post backward hook
could we unify and do both the reshards with the input activations hook?
Resharding on the input activation can be later than our existing post-backward hook, which may regress memory. We should keep both paths and view this newly added path as the fallback for the `requires_grad=False` case.
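To see why the input-activation hook can fire later, here is a minimal timing demo with plain autograd (assumptions: a recent PyTorch; `register_post_accumulate_grad_hook` is used only as a rough stand-in for the existing per-parameter post-backward hook, which is not how FSDP registers it internally):

```python
import torch
import torch.nn as nn

order = []

net = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
x = torch.randn(2, 4, requires_grad=True)

# Fires as soon as this parameter's gradient has been accumulated, which
# happens early in backward because the last layer is reached first.
net[1].weight.register_post_accumulate_grad_hook(
    lambda p: order.append("param grad ready")
)

# Fires only once the gradient for the module input exists, i.e., after
# backward has passed through the whole stack.
torch.autograd.graph.register_multi_grad_hook(
    (x,), lambda grads: order.append("input grad ready")
)

net(x).sum().backward()
print(order)  # ['param grad ready', 'input grad ready']
```

The parameter hook fires as soon as that parameter's gradient is ready, while the hook on the input only fires after backward has traversed the whole stack; resharding at that later point keeps the unsharded parameters alive longer.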
@pytorchbot merge
Merge failed. Reason: HTTP Error 403: rate limit exceeded.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
@awgu, thanks for the great work, but I just came across a problem that might be related.
Any idea? I am using the latest nightly version that I could find to support cuda117. That is
@jcyk Would it be possible to share a repro?
Stack from ghstack (oldest at bottom):

This PR makes a first attempt at improving FSDP's fine-tuning support by adding hooks to reshard frozen parameters in the backward pass.
- Without this, frozen parameters involved in gradient computation are kept unsharded through the entire backward pass.
- The approach is to register a multi-grad post-hook on the _input_ activations to the FSDP module, where the hook performs the resharding after all gradients for the FSDP module must have been computed (meaning that we are safe to reshard).

~~This PR relies on adding a "multi-grad post-hook" that differs from the existing "multi-grad hook" from `register_multi_grad_hook()`. I find that with `register_multi_grad_hook()`, sometimes the unit test counting the number of times `_post_backward_reshard()` is called fails (due to it not being called).~~ This was resolved in #102859.