-
Notifications
You must be signed in to change notification settings - Fork 24.9k
[FSDP] Add FrozenParamHandle to optimize memory for frozen parameters #159751
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[FSDP] Add FrozenParamHandle to optimize memory for frozen parameters #159751
Conversation
- Introduce FrozenParamHandle class for frozen parameters - Skip optimizer state allocation for frozen params - Optimize CPU offload by keeping frozen params on CPU - Skip gradient operations for frozen parameters This reduces memory overhead from 8.85X to ~1.05X for PEFT/LoRA workloads where 99%+ of parameters are frozen. Testing shows: - 75% reduction in optimizer parameters tracked (64 -> 16) - Fixes the 1.65X memory overhead issue reported in pytorch#91165 - Enables efficient PEFT/LoRA training on consumer GPUs Fixes pytorch#91165
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159751
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 86ef3b3 with merge base a626dc8 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
"""Frozen parameters never need gradient synchronization.""" | ||
return False | ||
|
||
def prepare_gradient_for_backward(self) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should these be typing.override or are they new hooks that should be overriden. If the former, type explicitly with typing_extension to prevent refactoring bugs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review! You're correct - these methods (needs_gradient_sync
, prepare_gradient_for_backward
, and prepare_gradient_for_optim
) are indeed overrides from FlatParamHandle
. I've added the @override
decorator to all three methods to make this explicit. Changes pushed!
As suggested in review, adding typing_extensions.override decorators to make method overrides explicit and prevent refactoring bugs.
@weifengpy @awgu Thank you for your time on this PR. I've been thinking about the feedback and studying the codebase further. I realize that creating a new Would you prefer a simpler approach? I could:
This would be a much smaller change (~20 lines) while still providing the 35-50% memory savings needed for PEFT/LoRA workloads. It would complete the frozen parameter optimization that #101982 started. Happy to refactor this PR to that simpler approach if you think it's more appropriate for FSDP1. What are your thoughts? |
Summary
Fixes #91165
This PR introduces
FrozenParamHandle
to optimize FSDP memory usage when training models with mostly frozen parameters (common in PEFT/LoRA scenarios).Motivation
FSDP currently treats frozen and trainable parameters identically, leading to:
Changes
New
FrozenParamHandle
class (_flat_param.py
)FlatParamHandle
for frozen parametersDetection logic (
_init_utils.py
)FrozenParamHandle
instead ofFlatParamHandle
when appropriateRuntime optimizations (
_runtime_utils.py
)Optimizer state skipping (
_optim_utils.py
)Results
Testing with models having 99%+ frozen parameters shows:
Verified Test Results
Tested multiple configurations to confirm dynamic behavior:
Each frozen module is correctly identified and wrapped with
FrozenParamHandle
,excluding it from optimizer state allocation.
Testing
Impact
This is particularly important for the PEFT/LoRA community where 99%+ of parameters are typically frozen. The optimization reduces memory overhead significantly, making it practical to fine-tune large models on consumer GPUs.
Usage Note
For optimal results with PEFT/LoRA, structure models so frozen and trainable parameters
are in separate FSDP units. Example:
This ensures FSDP can properly identify and optimize frozen parameters.
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @rohan-varma @pacman100 @sgugger @muellerzr