From c1accddb2e2f15550e4138bd818755a4f36e0420 Mon Sep 17 00:00:00 2001 From: zeshengzong Date: Tue, 18 Mar 2025 17:30:19 +0800 Subject: [PATCH] Fix `SequentialLR` deprecation warning when invoking `step(epoch)` --- test/optim/test_lrscheduler.py | 13 +++++++++++++ torch/optim/lr_scheduler.py | 7 +++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/test/optim/test_lrscheduler.py b/test/optim/test_lrscheduler.py index a6e448173f9e..c36e7b2e21d6 100644 --- a/test/optim/test_lrscheduler.py +++ b/test/optim/test_lrscheduler.py @@ -784,6 +784,19 @@ def test_sequentiallr5(self): scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) self._test(scheduler, targets, epochs) + def test_sequentiallr_no_warnings(self): + scheduler1 = LinearLR(self.opt, start_factor=0.5, end_factor=0.1, total_iters=5) + scheduler2 = ExponentialLR(self.opt, gamma=0.9) + scheduler = SequentialLR( + self.opt, schedulers=[scheduler1, scheduler2], milestones=[5] + ) + + for _ in range(10): + self.opt.step() + with warnings.catch_warnings(record=True) as ws: + scheduler.step() + self.assertTrue(len(ws) == 0, "No warning should be raised") + def test_get_last_lr_sequentiallr(self): epochs = 12 milestones = [3, 6] diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 6f9f6f1a3cf0..94dd5443aedf 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -200,13 +200,16 @@ def step(self, epoch: Optional[int] = None) -> None: ) self._step_count += 1 + if epoch is not None: + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) + self._update_lr(epoch) + + def _update_lr(self, epoch: Optional[int] = None): with _enable_get_lr_call(self): if epoch is None: self.last_epoch += 1 values = self.get_lr() else: - warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) self.last_epoch = epoch if hasattr(self, "_get_closed_form_lr"): values = cast(list[float], self._get_closed_form_lr()) @@ -913,7 +916,7 @@ def step(self) -> None: # type: 
ignore[override] idx = bisect_right(self._milestones, self.last_epoch) scheduler = self._schedulers[idx] if idx > 0 and self._milestones[idx - 1] == self.last_epoch: - scheduler.step(0) + scheduler._update_lr(0) else: scheduler.step()