[FSDP] Reshard frozen params in backward #101982


Closed
awgu wants to merge 3 commits

Conversation

@awgu (Collaborator) commented May 22, 2023

Stack from ghstack (oldest at bottom):

This PR makes a first attempt at improving FSDP's fine-tuning support by adding hooks to reshard frozen parameters in the backward pass.

  • Without this, frozen parameters involved in gradient computation are kept as unsharded through the entire backward pass.
  • The approach is to register a multi-grad post-hook on the input activations to the FSDP module; by the time the hook runs, all gradients for the FSDP module are guaranteed to have been computed, so it is safe to reshard.

This PR relies on adding a "multi-grad post-hook" that differs from the existing "multi-grad hook" from `register_multi_grad_hook()`. With `register_multi_grad_hook()`, the unit test that counts the number of times `_post_backward_reshard()` is called sometimes fails because the hook is not called. This was resolved in #102859.
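For intuition, here is a minimal sketch of the mechanism described above (not the actual FSDP code), using the public `torch.autograd.graph.register_multi_grad_hook()` API; `reshard_frozen_params` is a hypothetical callback standing in for `_post_backward_reshard()`, and the PR itself adds a "post" variant of this hook rather than using this exact API:

```python
import torch

def register_reshard_on_inputs(inputs, reshard_frozen_params):
    # Hook only the input activations that participate in autograd.
    tensors = [t for t in inputs if torch.is_tensor(t) and t.requires_grad]
    if not tensors:
        # No differentiable inputs: fall back to the catch-all reshard.
        return None

    def hook(grads):
        # Once every gradient w.r.t. the module inputs exists, all gradients
        # inside the module have been computed, so resharding is safe.
        reshard_frozen_params()

    return torch.autograd.graph.register_multi_grad_hook(tensors, hook)
```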

@pytorch-bot (bot) commented May 22, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/101982

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ed8b8fc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

awgu added the topic: improvements label (topic category) on May 22, 2023
awgu pushed a commit that referenced this pull request May 22, 2023
ghstack-source-id: 71b0854
Pull Request resolved: #101982
@rohan-varma (Member) commented

My understanding might be wrong here, but frozen parameters being resharded late (I'm assuming they're resharded in _catch_all_reshard) sounds similar to how we reshard "unused parameters" in the catch-all reshard. Could we use this technique to eliminate _catch_all_reshard entirely?

@awgu (Collaborator Author) commented May 22, 2023

@rohan-varma The catch-all reshard is still useful in the case that the very first input activation does not require gradient. Then, we have no activation on which we can reshard the root FSDP instance's parameters if they are frozen.
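For example (an illustrative sketch, not from the PR): if the root module's only input is integer token IDs, no input tensor requires grad, so there is nothing to attach the multi-grad hook to, and the catch-all reshard still has to free the root's frozen parameters.

```python
import torch
import torch.nn as nn

# Hypothetical root module whose only input is token IDs.
model = nn.Sequential(nn.Embedding(1000, 32), nn.Linear(32, 10))
for p in model[0].parameters():
    p.requires_grad = False  # frozen embedding

input_ids = torch.randint(0, 1000, (4, 16))  # integer inputs never require grad
out = model(input_ids)
# No differentiable input activation exists at the root, so an input-activation
# hook cannot be registered there; the catch-all reshard covers this case.
```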

@pengyanghua commented

@awgu @rohan-varma Is this PR ready to merge? If not, what is the blocker?

@awgu (Collaborator Author) commented Jun 8, 2023

@pengyanghua I was waiting for a fix on the autograd side and have not had a chance to rebase yet. I will do that soon.

awgu pushed a commit that referenced this pull request Jun 8, 2023
ghstack-source-id: 378e72f
Pull Request resolved: #101982
awgu marked this pull request as ready for review on June 8, 2023 at 17:47
awgu added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Jun 8, 2023
@rohan-varma (Member) left a comment:
awesome work!

# Truncated diff context: the assertion on _post_backward_hook_state now accepts length 1 or 2.
_p_assert(
    len(flat_param._post_backward_hook_state) == 2,  # previous invariant
post_backward_hook_state_len == 1  # newly allowed case
Review comment (Member):

why can it be 1 or 2 now?

Reply (Collaborator Author):

When the parameters are frozen (requires_grad=False), we do not have an AccumulateGrad object anymore, so the state looks like:

handle.flat_param._post_backward_hook_state = hook_handle

Normally, it is both the hook_handle and the acc_grad.
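A hedged sketch of the invariant being discussed (names follow the diff context above; this mirrors the explanation, not the exact FSDP source):

```python
from torch.distributed.utils import _p_assert

# `flat_param` comes from the enclosing FSDP code; this only illustrates the check.
state = flat_param._post_backward_hook_state
expected_len = 2 if flat_param.requires_grad else 1  # (acc_grad, hook_handle) vs. (hook_handle,)
_p_assert(
    len(state) == expected_len,
    f"Unexpected _post_backward_hook_state: {state}",
)
```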

@@ -1360,6 +1373,39 @@ def _register_post_backward_hooks(
flat_param._post_backward_hook_state = (acc_grad, hook_handle) # type: ignore[attr-defined]


def _register_post_backward_reshard_only_hooks(
Review comment (Member):

looks like now we have 2 paths where params can be resharded in post backward:

  • for flat parameters that don't require grad, do it via a hook on the input activations
  • for flat parameters that do require grad, do it with the standard post backward hook

could we unify and do both the reshards with the input activations hook?

Reply (Collaborator Author):

Resharding on the input activation can be later than our existing post-backward hook, which may regress memory. We should keep both paths and view this newly added path as the fallback for the requires_grad=False case.
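A rough sketch of the split described above (hypothetical per-handle helpers, not the PR's exact registration code):

```python
def register_post_backward_logic(state, handle):
    if handle.flat_param.requires_grad:
        # Parameters that receive gradients keep the standard post-backward hook,
        # which reshards eagerly as each gradient is produced.
        register_standard_post_backward_hook(state, handle)  # hypothetical helper
    else:
        # Frozen parameters fall back to the reshard-only hook registered on the
        # module's input activations, which can fire later in backward.
        register_reshard_only_hook_on_inputs(state, handle)  # hypothetical helper
```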

@awgu (Collaborator Author) commented Jun 8, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: HTTP Error 403: rate limit exceeded

Details for Dev Infra team: raised by workflow job.

@awgu (Collaborator Author) commented Jun 8, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@jcyk commented Aug 5, 2023

@awgu, thanks for the great work, but I just came across a problem that might be related. When I freeze some parameters (by setting requires_grad = False), I encounter the following error:

File "/usr/local/python/lib/python3.8/site-packages/torch/distributed/fsdp/flat_param.py", line 2400, in _che
ck_storage_allocated
    _reshard(state, [handle], [free_unsharded_flat_param])
  File "/usr/local/python/lib/python3.8/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 389, in _
reshard
    self._free_unsharded_flat_param()
  File "/usr/local/python/lib/python3.8/site-packages/torch/distributed/fsdp/flat_param.py", line 1629, in _fre
e_unsharded_flat_param
    handle.reshard(free_unsharded_flat_param)
  File "/usr/local/python/lib/python3.8/site-packages/torch/distributed/fsdp/flat_param.py", line 1600, in resh
ard
    _reshard(state, [handle], [free_unsharded_flat_param])
  File "/usr/local/python/lib/python3.8/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 389, in _
reshard
    _p_assert(storage_size > 0, "Expects storage to be allocated")
  File "/usr/local/python/lib/python3.8/site-packages/torch/distributed/utils.py", line 147, in _p_assert
    self._free_unsharded_flat_param()
  File "/usr/local/python/lib/python3.8/site-packages/torch/distributed/fsdp/flat_param.py", line 1629, in _fre
e_unsharded_flat_param
    handle.reshard(free_unsharded_flat_param)
  File "/usr/local/python/lib/python3.8/site-packages/torch/distributed/fsdp/flat_param.py", line 1600, in resh
ard
    raise AssertionError(s)
AssertionError: Expects storage to be allocated

any idea?

I am using the latest nightly version I could find with CUDA 11.7 support: torch-2.1.0.dev20230621+cu117-cp38-cp38-linux_x86_64.whl.

@awgu (Collaborator Author) commented Aug 13, 2023

@jcyk Would it be possible to share a repro?
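For anyone putting one together, here is a minimal sketch of the kind of self-contained repro being requested: a single-process FSDP setup with a partially frozen model. It is illustrative only and not a confirmed reproduction of the error above.

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def main():
    # Single-rank process group just to exercise FSDP.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=0, world_size=1)
    torch.cuda.set_device(0)

    model = torch.nn.Sequential(
        torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 8)
    ).cuda()
    for p in model[0].parameters():
        p.requires_grad = False  # freeze a subset of parameters

    fsdp_model = FSDP(model, use_orig_params=True)
    fsdp_model(torch.randn(4, 8, device="cuda")).sum().backward()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```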

Labels: ciflow/trunk (Trigger trunk jobs on your pull request), Merged, release notes: distributed (fsdp), topic: improvements
Projects: None yet
5 participants