Fix lr_scheduler unexpectedly calls step() when init argument last_epoch is larger than -1 #149312


Closed
zeshengzong wants to merge 6 commits into main from fix/optim/step

Conversation


@zeshengzong zeshengzong commented Mar 17, 2025

Fixes #102261

Changes

  • Use a new flag _is_initial, instead of the self.last_epoch == 0 condition, to decide whether the lr should keep its initial value
  • Add a test for the ExponentialLR checkpoint use case
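To illustrate the idea, here is a minimal pure-Python sketch (not the actual PyTorch implementation; the class and names are illustrative) of why self.last_epoch == 0 is ambiguous and how a dedicated flag resolves it:

```python
# Sketch: after loading a checkpoint, a scheduler constructed with
# last_epoch > -1 already starts at a later epoch, so "is this the
# constructor's initial step?" cannot be answered from last_epoch alone.
# A dedicated _is_initial flag answers it directly.

class TinyExponentialLR:
    def __init__(self, base_lr, gamma, last_epoch=-1):
        self.base_lr = base_lr
        self.gamma = gamma
        self.last_epoch = last_epoch
        self._is_initial = True  # the flag this PR introduces (sketched)
        self.step()              # mirrors LRScheduler calling step() in __init__

    def get_lr(self):
        if self._is_initial:
            # Initial step: report the lr for the current epoch
            # without applying another decay factor.
            return self.base_lr * self.gamma ** self.last_epoch
        return self.lr * self.gamma

    def step(self):
        if self._is_initial:
            self.last_epoch = max(self.last_epoch, 0)
        else:
            self.last_epoch += 1
        self.lr = self.get_lr()
        self._is_initial = False


# Fresh run: lr starts at base_lr.
fresh = TinyExponentialLR(base_lr=1.0, gamma=0.5)
# Resumed run at epoch 2: lr is base_lr * gamma**2, with no extra decay.
resumed = TinyExponentialLR(base_lr=1.0, gamma=0.5, last_epoch=2)
```

With the old last_epoch == 0 check, the resumed scheduler would take the "calculate" branch in its constructor and apply one extra factor of gamma.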

Test Result

pytest -s test/optim/test_lrscheduler.py  -vv

[screenshot: pytest output]


pytorch-bot bot commented Mar 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149312

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0208b8f with merge base a264af8:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zeshengzong zeshengzong marked this pull request as ready for review March 18, 2025 08:06
@zeshengzong (Contributor Author)

Hello @albanD @janeyx99, please check whether this fix is feasible. If it works, I would like to continue fixing other schedulers that have the same problem, such as MultiplicativeLR and LinearLR. Thanks!

@janeyx99 janeyx99 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Mar 20, 2025
@albanD albanD removed their request for review April 9, 2025 19:37
@zeshengzong (Contributor Author)

@pytorchbot rebase -b main

@pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here

@pytorchmergebot (Collaborator)

Successfully rebased fix/optim/step onto refs/remotes/origin/main, please pull locally before adding more changes (for example, via git checkout fix/optim/step && git pull --rebase)


@janeyx99 janeyx99 left a comment


This does not look like the right approach. If the discrepancy is for ExponentialLR between get_lr and _get_closed_form_lr, I'd expect the fix to be local there. Could you explain your approach a little bit?

optim2 = torch.optim.AdamW(model.parameters())
optim2.load_state_dict(optim.state_dict())
sch2 = LRClass(optim2, last_epoch=1)
self.assertEqual(optim.param_groups[0]["lr"], optim2.param_groups[0]["lr"])
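The discrepancy under discussion can be sketched with plain arithmetic (assumed numbers, not the real test values): ExponentialLR's closed form is base_lr * gamma**epoch, while its recursive form multiplies the previous lr by gamma on every step. Before this fix, constructing a scheduler with last_epoch > -1 ran the recursive form once in __init__, producing one extra factor of gamma.

```python
# Illustrative numbers only: compare the closed-form lr at a resume
# epoch with the value the pre-fix constructor would have produced.
base_lr, gamma, last_epoch = 0.1, 0.999, 1

closed_form = base_lr * gamma ** last_epoch  # what resume should give
buggy_initial = closed_form * gamma          # extra decay applied pre-fix

assert abs(closed_form - 0.1 * 0.999) < 1e-12
assert buggy_initial < closed_form
```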
Contributor

This is not the same comparison as the repro: shouldn't we be comparing that the closed-form lr is the same as the param group's lr?

Contributor Author

Changed, thanks!


@janeyx99 janeyx99 left a comment


Oh actually, I see what you're doing now. Sorry I was confused yesterday. I'm willing to accept this fix if you update the test case.

It would also be good to include a comment about why we prefer the _is_initial.

@janeyx99 janeyx99 added the topic: bug fixes topic category label May 6, 2025
@janeyx99 janeyx99 dismissed their stale review May 6, 2025 17:46

left newer review

@@ -134,7 +135,8 @@ def wrapper(*args, **kwargs):
def _initial_step(self):
"""Initialize step counts and perform a step."""
Contributor

As someone who has looked into LRScheduler more than I've been able to, have you seen a good reason why we need to call .step() from the constructor?

Contributor Author

I think one of its key effects is initializing the optimizer's lr to match the scheduler's lr at construction time, reusing this part of the code:

with _enable_get_lr_call(self):
    if epoch is None:
        self.last_epoch += 1
        values = self.get_lr()
    else:
        warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
        self.last_epoch = epoch
        if hasattr(self, "_get_closed_form_lr"):
            values = cast(list[float], self._get_closed_form_lr())
        else:
            values = self.get_lr()

for param_group, lr in zip(self.optimizer.param_groups, values):
    if isinstance(param_group["lr"], Tensor):
        param_group["lr"].fill_(_to_scalar(lr))
    else:
        param_group["lr"] = lr

self._last_lr: list[float] = [
    group["lr"] for group in self.optimizer.param_groups
]

One improvement would be to extract the internal lr-update logic out of the public step method; please check PR #149392 and the issue it fixes. Thanks!

@joecummings (Member)

I'd love to see this expanded to ensure it works for all LRSchedulers! I have confirmed the same issue when testing with StepLR (when I try to resume training and set up a new LRScheduler, it is always one step off because of the initial step taken in the LRScheduler's __init__).
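The StepLR off-by-one described above can be sketched in plain Python (assumed numbers; this simplification also assumes the resume epoch falls on a decay boundary, where StepLR's recursive form applies gamma):

```python
# StepLR's closed form is base_lr * gamma ** (epoch // step_size).
# If the constructor's initial step applies the recursive decay when
# last_epoch > -1, a resumed scheduler starts one decay ahead.

def steplr_closed_form(base_lr, gamma, step_size, epoch):
    return base_lr * gamma ** (epoch // step_size)

base_lr, gamma, step_size = 0.1, 0.5, 2

# lr after 4 epochs of the original run (epoch 4 is a decay boundary):
original = steplr_closed_form(base_lr, gamma, step_size, 4)

# a buggy resume at last_epoch=4 applied one extra decay on construction:
buggy_resume = original * gamma

assert abs(original - 0.025) < 1e-12
assert abs(buggy_resume - 0.0125) < 1e-12
```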

@janeyx99 (Contributor)

@zeshengzong lmk if you can bring this PR over the finish line with expanding it to all LRSchedulers!

@zeshengzong (Contributor Author)

zeshengzong commented May 14, 2025

@zeshengzong lmk if you can bring this PR over the finish line with expanding it to all LRSchedulers!

Hi @janeyx99, sorry for the late reply; I was busy with something else. I'd like to fix all of them and hope to clean up all the issues related to lr_scheduler. Thanks for the help!

@zeshengzong (Contributor Author)

Oh actually, I see what you're doing now. Sorry I was confused yesterday. I'm willing to accept this fix if you update the test case.

It would also be good to include a comment about why we prefer the _is_initial.

Yes, this adds context to better distinguish the initial lr from a calculated lr; self.last_epoch == 0 is not sufficient in this case.

[
partial(ExponentialLR, gamma=0.999),
],
)
Contributor

It'd be great to expand this to more than ExponentialLR!

Contributor Author

I'm attending a PyTorch meetup; I'll do it next week, thanks! :D

Contributor Author

Hi @janeyx99, I've added more schedulers here. ReduceLROnPlateau has a different pattern, so I separated it into another test.

optim2 = torch.optim.AdamW(model.parameters())
optim2.load_state_dict(optim.state_dict())
sch2 = LRClass(optim2, last_epoch=0)
self.assertEqual(sch2.get_last_lr()[0], optim.param_groups[0]["lr"])
Contributor Author

Replaced with get_last_lr, since some schedulers don't implement the _get_closed_form_lr method.

Contributor

Can we use _get_closed_form_lr whenever it is available (using hasattr)?

Contributor Author

Changed, thanks!

@zeshengzong (Contributor Author)

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 20, 2025
@zeshengzong (Contributor Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 job has failed; the first few: Apply lint suggestions

Details for Dev Infra team. Raised by workflow job.

@janeyx99 (Contributor)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 job has failed; the first few: Apply lint suggestions

Details for Dev Infra team. Raised by workflow job.

@zeshengzong (Contributor Author)

@pytorchbot rebase -b main

@pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here

zeshengzong and others added 6 commits May 22, 2025 01:22
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
@pytorchmergebot (Collaborator)

Successfully rebased fix/optim/step onto refs/remotes/origin/main, please pull locally before adding more changes (for example, via git checkout fix/optim/step && git pull --rebase)

@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label May 22, 2025
@zeshengzong (Contributor Author)

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 22, 2025
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

Labels
ciflow/trunk Trigger trunk jobs on your pull request · Merged · open source · release notes: optim · topic: bug fixes topic category · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Development

Successfully merging this pull request may close these issues.

ExponentialLR unexpectedly calls step() when init argument last_epoch is larger than -1
5 participants