Commit 3437d12

[Trainer/Deepspeed] handle get_last_lr() before first step() (huggingface#10362)
* handle get_last_lr() before first step()
* abstract away the lr getting logic
* cleanup
* add test
* move to utils
1 parent 4a1ab7c commit 3437d12
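
For context, a minimal sketch of the guard pattern this commit introduces (placeholder function and variable names, not the Trainer's actual internals): under DeepSpeed fp16 with dynamic loss scaling, the first optimizer/scheduler steps may be skipped while the loss scale is too large, and asking the scheduler for its last learning rate in that window raises an AssertionError mentioning "need to call step".

    # Sketch only, with placeholder names; the real helper added below is get_learning_rate().
    def last_lr_or_zero(lr_scheduler):
        try:
            return lr_scheduler.get_last_lr()[0]
        except AssertionError as e:
            # raised when no optimizer/scheduler step has run yet
            if "need to call step" in str(e):
                return 0
            raise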

File tree

3 files changed (+52 lines, -6 lines)
examples/tests/deepspeed/test_deepspeed.py

Lines changed: 25 additions & 0 deletions

@@ -78,6 +78,31 @@ def test_fake_notebook_no_launcher(self):
                 trainer.train()
             assert "DeepSpeed info" in cs.out, "expected DeepSpeed logger output but got none"
 
+    def test_early_get_last_lr(self):
+        # with deepspeed's fp16 and dynamic loss scale enabled, the optimizer/scheduler steps may
+        # not run for the first few dozen steps while the loss scale is too large, and thus
+        # `get_last_lr` will fail if called during that warm-up stage.
+        #
+        # setting `logging_steps=1` forces an early `trainer._maybe_log_save_evaluate()`, which calls
+        # `self.lr_scheduler.get_last_lr()` and originally would fail on the very first step.
+        with mockenv_context(**self.dist_env_1_gpu):
+            a = b = 0.0
+            trainer = get_regression_trainer(
+                a=a,
+                b=b,
+                local_rank=0,
+                train_len=8,
+                deepspeed=self.ds_config_file,
+                per_device_train_batch_size=8,
+                logging_steps=1,
+            )
+            trainer.train()
+            no_grad_accum_a = trainer.model.a.item()
+
+            # it's enough for this test that train() didn't fail, but we must also check that the
+            # optimizer/scheduler didn't actually step (otherwise this test isn't testing the right thing)
+            self.assertEqual(no_grad_accum_a, a)
+
     def test_gradient_accumulation(self):
 
         # this test measures that we get identical weights and similar loss with:

src/transformers/trainer.py

Lines changed: 3 additions & 6 deletions

@@ -82,6 +82,7 @@
     SequentialDistributedSampler,
     distributed_broadcast_scalars,
     distributed_concat,
+    get_learning_rate,
     nested_concat,
     nested_detach,
     nested_numpify,
@@ -1129,12 +1130,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
             tr_loss -= tr_loss
 
             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
-            # backward compatibility for pytorch schedulers
-            logs["learning_rate"] = (
-                self.lr_scheduler.get_last_lr()[0]
-                if version.parse(torch.__version__) >= version.parse("1.4")
-                else self.lr_scheduler.get_lr()[0]
-            )
+            logs["learning_rate"] = get_learning_rate(self)
+
             self._total_loss_scalar += tr_loss_scalar
             self._globalstep_last_logged = self.state.global_step

src/transformers/trainer_pt_utils.py

Lines changed: 24 additions & 0 deletions

@@ -24,6 +24,7 @@
 
 import numpy as np
 import torch
+from packaging import version
 from torch.utils.data.dataset import Dataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler
@@ -262,6 +263,29 @@ def _get_first_shape(arrays):
     return arrays.shape
 
 
+def get_learning_rate(trainer):
+    if trainer.deepspeed:
+        # with deepspeed's fp16 and dynamic loss scale enabled, the optimizer/scheduler steps may
+        # not run for the first few dozen steps while the loss scale is too large, and thus
+        # `get_last_lr` will fail if called during that warm-up stage, so work around it:
+        try:
+            last_lr = trainer.lr_scheduler.get_last_lr()[0]
+        except AssertionError as e:
+            if "need to call step" in str(e):
+                logger.warn("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
+                last_lr = 0
+            else:
+                raise
+    else:
+        last_lr = (
+            # backward compatibility for pytorch schedulers
+            trainer.lr_scheduler.get_last_lr()[0]
+            if version.parse(torch.__version__) >= version.parse("1.4")
+            else trainer.lr_scheduler.get_lr()[0]
+        )
+    return last_lr
+
+
 class DistributedTensorGatherer:
     """
     A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
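
A hedged usage sketch of the new helper (assumes an already-constructed Trainer instance named `trainer`; as of this commit the function lives in trainer_pt_utils):

    from transformers.trainer_pt_utils import get_learning_rate

    # Returns the scheduler's last lr, or 0 if the deepspeed scheduler has not stepped yet.
    current_lr = get_learning_rate(trainer)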
