
Commit 42f359d

tjruwase and stas00 authored
Use DS callable API to allow hf_scheduler + ds_optimizer (huggingface#13216)
* Use DS callable API to allow hf_scheduler + ds_optimizer
* Preserve backward-compatibility
* Restore backward compatibility
* Tweak arg positioning
* Tweak arg positioning
* bump the required version
* Undo indent
* Update src/transformers/trainer.py
* style

Co-authored-by: Stas Bekman <stas@stason.org>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
1 parent 35236b8 commit 42f359d

File tree

5 files changed (+24, -31 lines)
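For context, the "DS callable API" in the title refers to deepspeed.initialize accepting a callable for its lr_scheduler argument: DeepSpeed first builds its own optimizer from the config, then invokes the callable with that optimizer to obtain the scheduler, which is what lets an HF scheduler sit on top of a DS optimizer. A minimal sketch of that pattern follows; the model, config values, and step count are placeholders, not part of this commit, and actually running it requires a distributed launcher such as the deepspeed CLI.

import deepspeed
import torch
from transformers import get_linear_schedule_with_warmup

model = torch.nn.Linear(4, 2)  # placeholder model
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "AdamW", "params": {"lr": 3e-5}},  # DeepSpeed builds this optimizer
    # no "scheduler" block: the HF scheduler is supplied via the callable below
}

def lr_scheduler_callable(optimizer):
    # invoked by DeepSpeed only after its own optimizer exists
    return get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=1000)

engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config_params=ds_config,
    lr_scheduler=lr_scheduler_callable,
)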

setup.py

Lines changed: 1 addition & 1 deletion

@@ -91,7 +91,7 @@
     "cookiecutter==1.7.2",
     "dataclasses",
     "datasets",
-    "deepspeed>=0.4.3",
+    "deepspeed>=0.5.1",
     "docutils==0.16.0",
     "fairscale>0.3",
     "faiss-cpu",

src/transformers/deepspeed.py

Lines changed: 10 additions & 18 deletions

@@ -311,13 +311,13 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
    # 1. DS scheduler + DS optimizer: Yes
    # 2. HF scheduler + HF optimizer: Yes
    # 3. DS scheduler + HF optimizer: Yes
-    # 4. HF scheduler + DS optimizer: No
+    # 4. HF scheduler + DS optimizer: Yes
    #
    # Unless Offload is enabled in which case it's:
    # 1. DS scheduler + DS optimizer: Yes
    # 2. HF scheduler + HF optimizer: Mostly*
    # 3. DS scheduler + HF optimizer: Mostly*
-    # 4. HF scheduler + DS optimizer: No
+    # 4. HF scheduler + DS optimizer: Yes
    #
    # Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB)

@@ -336,28 +336,20 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):

        # ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch.
        # But trainer uses AdamW by default.
-        trainer.create_optimizer()
-        optimizer = trainer.optimizer
+        optimizer = trainer.create_optimizer()
        # To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer`
        config["zero_allow_untested_optimizer"] = True

-    # DS schedulers (deepspeed/runtime/lr_schedules.py):
-    #
-    # DS name      | --lr_scheduler_type  | HF func                           | Notes
-    # -------------| ---------------------|-----------------------------------|--------------------
-    # LRRangeTest  | na                   | na                                | LRRT
-    # OneCycle     | na                   | na                                | 1CLR
-    # WarmupLR     | constant_with_warmup | get_constant_schedule_with_warmup | w/ warmup_min_lr=0
-    # WarmupDecayLR| linear               | get_linear_schedule_with_warmup   |
+    def _lr_scheduler_callable(optimizer):
+        return trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
+
    lr_scheduler = None
    if "scheduler" not in config:
-        if "optimizer" in config:
-            # to make this option work, we need to init DS optimizer first, then init HS scheduler,
-            # then pass the HS scheduler to DS init, which is not possible at the moment
-            raise ValueError("At the moment HF scheduler + DeepSpeed optimizer combination is not possible")
+        if optimizer is None:
+            # Optimizer is not available, so use callable to defer lr_scheduler creation to DS init
+            lr_scheduler = _lr_scheduler_callable
        else:
-            trainer.create_scheduler(num_training_steps=num_training_steps)
-            lr_scheduler = trainer.lr_scheduler
+            lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

    # keep for quick debug:
    # from pprint import pprint; pprint(config)
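From the Trainer user's side, case 4 in the table above is selected purely by the shape of the DeepSpeed config: an "optimizer" block with no "scheduler" block. A hedged illustration with placeholder values; the "auto" fill-ins follow the HF DeepSpeed integration's convention and nothing here is taken verbatim from the commit.

# Case 4: DS optimizer + HF scheduler. The "scheduler" block is intentionally absent,
# so deepspeed_init() hands _lr_scheduler_callable to DeepSpeed, which calls it back
# once its own optimizer exists. Values below are placeholders.
ds_config_hf_sched_ds_optim = {
    "optimizer": {"type": "AdamW", "params": {"lr": "auto"}},
    "zero_optimization": {"stage": 2},
    "fp16": {"enabled": True},
    # no "scheduler" key -> the scheduler comes from the Trainer (--lr_scheduler_type)
}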

src/transformers/dependency_versions_table.py

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@
     "cookiecutter": "cookiecutter==1.7.2",
     "dataclasses": "dataclasses",
     "datasets": "datasets",
-    "deepspeed": "deepspeed>=0.4.3",
+    "deepspeed": "deepspeed>=0.5.1",
     "docutils": "docutils==0.16.0",
     "fairscale": "fairscale>0.3",
     "faiss-cpu": "faiss-cpu",

src/transformers/trainer.py

Lines changed: 8 additions & 4 deletions

@@ -768,7 +768,7 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
        and/or :obj:`create_scheduler`) in a subclass.
        """
        self.create_optimizer()
-        self.create_scheduler(num_training_steps)
+        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)

    def create_optimizer(self):
        """

@@ -813,20 +813,24 @@ def create_optimizer(self):
        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

-    def create_scheduler(self, num_training_steps: int):
+        return self.optimizer
+
+    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
-        Setup the scheduler. The optimizer of the trainer must have been set up before this method is called.
+        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
+        passed as an argument.

        Args:
            num_training_steps (int): The number of training steps to do.
        """
        if self.lr_scheduler is None:
            self.lr_scheduler = get_scheduler(
                self.args.lr_scheduler_type,
-                self.optimizer,
+                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
            )
+        return self.lr_scheduler

    def num_examples(self, dataloader: DataLoader) -> int:
        """

tests/deepspeed/test_deepspeed.py

Lines changed: 4 additions & 7 deletions

@@ -292,19 +292,16 @@ def test_ds_scheduler_hf_optimizer(self):
        self.assertNotEqual(new_a, a)

    def test_hf_scheduler_ds_optimizer(self):
-        # this combo is not possible at the moment
+        a = 0
        with mockenv_context(**self.dist_env_1_gpu):
            ds_config_zero2_dict = self.get_config_dict(ZERO2)
            del ds_config_zero2_dict["scheduler"]  # force default HF Trainer scheduler
            ds_config_zero2_dict["zero_optimization"]["offload_optimizer"]["device"] = "none"
            ds_config_zero2_dict["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
            trainer = get_regression_trainer(local_rank=0, fp16=True, deepspeed=ds_config_zero2_dict)
-            with self.assertRaises(Exception) as context:
-                trainer.train()
-            self.assertTrue(
-                "HF scheduler + DeepSpeed optimizer combination is not possible" in str(context.exception),
-                f"got exception: {context.exception}",
-            )
+            trainer.train()
+            new_a = trainer.model.a.item()
+            self.assertNotEqual(new_a, a)

    @require_deepspeed_aio
    def test_stage3_nvme_offload(self):
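With the combination now expected to work, the test switches from asserting an exception to the file's usual regression check: read a learnable parameter before training and assert it moved afterwards. A self-contained sketch of that pattern in plain torch; the model, optimizer, and step count are stand-ins for the repo's get_regression_trainer helper, not code from the commit.

import torch

# toy stand-in for the regression model whose parameter `a` the test inspects
model = torch.nn.Linear(1, 1, bias=False)
a = model.weight.item()                      # value before training

optimizer = torch.optim.AdamW(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

for _ in range(5):                           # a few dummy training steps
    loss = (model(torch.ones(1)) - 2.0).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

new_a = model.weight.item()
assert new_a != a                            # the parameter actually moved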
