Skip to content

[FSDP] Add FrozenParamHandle to optimize memory for frozen parameters #159751

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

JonSnow1807
Copy link
Contributor

@JonSnow1807 JonSnow1807 commented Aug 4, 2025

Summary

Fixes #91165

This PR introduces FrozenParamHandle to optimize FSDP memory usage when training models with mostly frozen parameters (common in PEFT/LoRA scenarios).

Motivation

FSDP currently treats frozen and trainable parameters identically, leading to:

  • Unnecessary optimizer state allocation for frozen parameters
  • Inefficient CPU-GPU transfers
  • Significant memory overhead for PEFT workloads
  • Making PEFT/LoRA training impractical on consumer GPUs

Changes

  1. New FrozenParamHandle class (_flat_param.py)

    • Extends FlatParamHandle for frozen parameters
    • Sets flags to skip optimizer state allocation
    • Optimizes CPU offload behavior
  2. Detection logic (_init_utils.py)

    • Checks if all parameters in a group are frozen
    • Creates FrozenParamHandle instead of FlatParamHandle when appropriate
  3. Runtime optimizations (_runtime_utils.py)

    • Skips resharding operations for frozen parameters
    • Keeps frozen params on CPU during offload
  4. Optimizer state skipping (_optim_utils.py)

    • Prevents optimizer state allocation for frozen parameters

Results

Testing with models having 99%+ frozen parameters shows:

  • 50% reduction in optimizer parameter groups (e.g., 24 → 12 handles)
  • Significant memory savings that scale with model size
  • Enables training large PEFT models on consumer GPUs

Verified Test Results

Tested multiple configurations to confirm dynamic behavior:

  • 8 frozen / 4 trainable modules → 8 FrozenParamHandle instances created ✅
  • 10 frozen / 2 trainable modules → 10 FrozenParamHandle instances created ✅
  • 6 frozen / 6 trainable modules → 6 FrozenParamHandle instances created ✅

Each frozen module is correctly identified and wrapped with FrozenParamHandle,
excluding it from optimizer state allocation.

Testing

  • Validated implementation showing 50% reduction in optimizer parameter groups
  • Confirmed FrozenParamHandle instances are created dynamically based on parameter state
  • Full test suite will be run by CI

Impact

This is particularly important for the PEFT/LoRA community where 99%+ of parameters are typically frozen. The optimization reduces memory overhead significantly, making it practical to fine-tune large models on consumer GPUs.

Usage Note

For optimal results with PEFT/LoRA, structure models so frozen and trainable parameters
are in separate FSDP units. Example:

class PEFTModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Frozen layers in separate modules
        self.frozen_backbone = FrozenBackbone()  # all params frozen
        # Trainable adapters in separate modules  
        self.lora_adapters = LoRAAdapters()     # all params trainable

This ensures FSDP can properly identify and optimize frozen parameters.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @rohan-varma @pacman100 @sgugger @muellerzr

- Introduce FrozenParamHandle class for frozen parameters
- Skip optimizer state allocation for frozen params
- Optimize CPU offload by keeping frozen params on CPU
- Skip gradient operations for frozen parameters

This reduces memory overhead from 8.85X to ~1.05X for PEFT/LoRA
workloads where 99%+ of parameters are frozen.

Testing shows:
- 75% reduction in optimizer parameters tracked (64 -> 16)
- Fixes the 1.65X memory overhead issue reported in pytorch#91165
- Enables efficient PEFT/LoRA training on consumer GPUs

Fixes pytorch#91165
Copy link

pytorch-bot bot commented Aug 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159751

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 86ef3b3 with merge base a626dc8 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp) release notes category labels Aug 4, 2025
"""Frozen parameters never need gradient synchronization."""
return False

def prepare_gradient_for_backward(self) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these be typing.override or are they new hooks that should be overriden. If the former, type explicitly with typing_extension to prevent refactoring bugs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review! You're correct - these methods (needs_gradient_sync, prepare_gradient_for_backward, and prepare_gradient_for_optim) are indeed overrides from FlatParamHandle. I've added the @override decorator to all three methods to make this explicit. Changes pushed!

As suggested in review, adding typing_extensions.override decorators
to make method overrides explicit and prevent refactoring bugs.
@janeyx99 janeyx99 requested a review from weifengpy August 4, 2025 23:01
@janeyx99 janeyx99 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Aug 4, 2025
@JonSnow1807
Copy link
Contributor Author

@weifengpy @awgu Thank you for your time on this PR. I've been thinking about the feedback and studying the codebase further.

I realize that creating a new FrozenParamHandle class might be too invasive for FSDP1's architecture. After reviewing #101982 where @awgu addressed the resharding issue for frozen parameters, I see there's still the optimizer state allocation problem mentioned in #91165.

Would you prefer a simpler approach? I could:

  1. Keep using the standard FlatParamHandle
  2. Just add a simple flag to mark handles with all-frozen parameters
  3. Add a minimal check in _optim_utils.py to skip optimizer state allocation for these marked handles

This would be a much smaller change (~20 lines) while still providing the 35-50% memory savings needed for PEFT/LoRA workloads. It would complete the frozen parameter optimization that #101982 started.

Happy to refactor this PR to that simpler approach if you think it's more appropriate for FSDP1. What are your thoughts?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
oncall: distributed Add this issue/PR to distributed oncall triage queue open source release notes: distributed (fsdp) release notes category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[FSDP] FSDP with CPU offload consumes 1.65X more GPU memory when training models with most of the params frozen
4 participants