
Commit 9d14be5

sgugger and stas00 authored
Add support for ZeRO-2/3 and ZeRO-offload in fairscale (huggingface#10354)
* Add support for ZeRO-2/3 and ZeRO-offload in fairscale
* Quality
* Rework from review comments
* Add doc
* Apply suggestions from code review
* Address review comments

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
1 parent 88cc26d commit 9d14be5

File tree: 5 files changed, +193 −46 lines

docs/source/main_classes/trainer.rst

Lines changed: 46 additions & 5 deletions
@@ -241,6 +241,8 @@ provides support for the following features from `the ZeRO paper <https://arxiv.
 
 1. Optimizer State Sharding
 2. Gradient Sharding
+3. Model Parameters Sharding (new and very experimental)
+4. CPU offload (new and very experimental)
 
 You will need at least two GPUs to use this feature.
 
@@ -255,8 +257,9 @@ To deploy this feature:
    or find more details on `the FairScale's GitHub page
    <https://github.com/facebookresearch/fairscale/#installation>`__.
 
-2. Add ``--sharded_ddp`` to the command line arguments, and make sure you have added the distributed launcher ``-m
-   torch.distributed.launch --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
+2. To use the first version of Sharded data-parallelism, add ``--sharded_ddp simple`` to the command line arguments,
+   and make sure you have added the distributed launcher ``-m torch.distributed.launch
+   --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
 
 For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:
 
@@ -268,17 +271,55 @@ For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:
     --do_train --max_train_samples 500 --num_train_epochs 1 \
     --dataset_name wmt16 --dataset_config "ro-en" \
     --task translation_en_to_ro --source_prefix "translate English to Romanian: " \
-    --fp16 --sharded_ddp
+    --fp16 --sharded_ddp simple
 
 Notes:
 
 - This feature requires distributed training (so multiple GPUs).
 - It is not implemented for TPUs.
 - It works with ``--fp16`` too, to make things even faster.
-- One of the main benefits of enabling ``--sharded_ddp`` is that it uses a lot less GPU memory, so you should be able
-  to use significantly larger batch sizes using the same hardware (e.g. 3x and even bigger) which should lead to
+- One of the main benefits of enabling ``--sharded_ddp simple`` is that it uses a lot less GPU memory, so you should be
+  able to use significantly larger batch sizes using the same hardware (e.g. 3x and even bigger) which should lead to
   significantly shorter training time.
 
+3. To use the second version of Sharded data-parallelism, add ``--sharded_ddp zero_dp_2`` or ``--sharded_ddp zero_dp_3``
+   to the command line arguments, and make sure you have added the distributed launcher ``-m torch.distributed.launch
+   --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
+
+For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:
+
+.. code-block:: bash
+
+    python -m torch.distributed.launch --nproc_per_node=2 examples/seq2seq/run_seq2seq.py \
+    --model_name_or_path t5-small --per_device_train_batch_size 1 \
+    --output_dir output_dir --overwrite_output_dir \
+    --do_train --max_train_samples 500 --num_train_epochs 1 \
+    --dataset_name wmt16 --dataset_config "ro-en" \
+    --task translation_en_to_ro --source_prefix "translate English to Romanian: " \
+    --fp16 --sharded_ddp zero_dp_2
+
+:obj:`zero_dp_2` is an optimized version of the simple wrapper, while :obj:`zero_dp_3` fully shards model weights,
+gradients and optimizer states.
+
+Both are compatible with adding :obj:`cpu_offload` to enable ZeRO-offload (activate it like this: :obj:`--sharded_ddp
+"zero_dp_2 cpu_offload"`).
+
+Notes:
+
+- This feature requires distributed training (so multiple GPUs).
+- It is not implemented for TPUs.
+- It works with ``--fp16`` too, to make things even faster.
+- The ``cpu_offload`` additional option requires ``--fp16``.
+- This is an area of active development, so make sure you have a source install of fairscale to use this feature as
+  some bugs you encounter may have been fixed there already.
+
+Known caveats:
+
+- This feature is incompatible with :obj:`--predict_with_generate` in the `run_seq2seq.py` script.
+- Using :obj:`--sharded_ddp zero_dp_3` requires wrapping each layer of the model in the special container
+  :obj:`FullyShardedDataParallelism` of fairscale. This is not done automatically by any of the example scripts of the
+  :class:`~transformers.Trainer`.
+
 
 DeepSpeed
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
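
The last known caveat in the documentation above deserves a concrete illustration. Below is a minimal sketch, not part of this commit, of what wrapping each layer of a model in fairscale's fully sharded container could look like; it assumes fairscale >= 0.3 and an already initialized torch.distributed process group, and the helper name and the toy nn.Sequential model are made up for the example:

    # Sketch only: shard each layer individually (ZeRO-3 style), then the whole model.
    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP

    def shard_each_layer(model: nn.Sequential) -> FullyShardedDDP:
        # Wrap every child module so its parameters, gradients and optimizer state
        # are sharded per layer ...
        sharded = nn.Sequential(*(FullyShardedDDP(layer) for layer in model))
        # ... then wrap the whole model; reshard_after_forward=True matches zero_dp_3.
        return FullyShardedDDP(sharded, reshard_after_forward=True)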

examples/tests/trainer/test_trainer_ext.py

Lines changed: 28 additions & 9 deletions
@@ -64,12 +64,13 @@ def require_apex(test_case):
 
 
 class TestTrainerExt(TestCasePlus):
-    def run_seq2seq_quick(self, distributed=False, extra_args_str=None):
-        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
+    def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True):
+        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str, predict_with_generate)
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
         eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
         first_step_stats = eval_metrics[0]
-        assert "eval_bleu" in first_step_stats
+        if predict_with_generate:
+            assert "eval_bleu" in first_step_stats
 
     @require_torch_non_multi_gpu
     def test_run_seq2seq_no_dist(self):
@@ -88,14 +89,28 @@ def test_run_seq2seq_ddp(self):
     # test --sharded_ddp w/o --fp16
     @require_torch_multi_gpu
     @require_fairscale
-    def test_run_seq2seq_ddp_sharded_ddp(self):
-        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp")
+    def test_run_seq2seq_sharded_ddp(self):
+        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")
 
     # test --sharded_ddp w/ --fp16
     @require_torch_multi_gpu
     @require_fairscale
-    def test_run_seq2seq_ddp_sharded_ddp_fp16(self):
-        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
+    def test_run_seq2seq_sharded_ddp_fp16(self):
+        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")
+
+    # test --sharded_ddp zero2 w/o --fp16
+    @require_torch_multi_gpu
+    @require_fairscale
+    def test_run_seq2seq_fully_sharded_ddp(self):
+        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero2", predict_with_generate=False)
+
+    # test --sharded_ddp zero2 w/ --fp16
+    @require_torch_multi_gpu
+    @require_fairscale
+    def test_run_seq2seq_fully_sharded_ddp_fp16(self):
+        self.run_seq2seq_quick(
+            distributed=True, extra_args_str="--sharded_ddp zero2 --fp16", predict_with_generate=False
+        )
 
     @require_apex
     def test_run_seq2seq_apex(self):
@@ -131,6 +146,7 @@ def run_trainer(
         num_train_epochs: int,
         distributed: bool = False,
         extra_args_str: str = None,
+        predict_with_generate: bool = True,
     ):
         data_dir = self.examples_dir / "test_data/wmt_en_ro"
         output_dir = self.get_auto_remove_tmp_dir()
@@ -155,7 +171,6 @@ def run_trainer(
             --learning_rate 3e-3
            --warmup_steps 8
            --evaluation_strategy steps
-            --predict_with_generate
            --logging_steps 0
            --save_steps {str(eval_steps)}
            --eval_steps {str(eval_steps)}
@@ -165,7 +180,11 @@ def run_trainer(
            --task translation
            --target_lang ro_RO
            --source_lang en_XX
-        """.split()
+        """
+        if predict_with_generate:
+            args += "--predict_with_generate"
+
+        args = args.split()
 
         if extra_args_str is not None:
             args.extend(extra_args_str.split())
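
To run just the sharded-DDP tests added here, a sketch along these lines should work, assuming pytest is installed and the hardware requirements of the decorators above (multiple GPUs, fairscale) are met:

    # Sketch: select the sharded_ddp tests in this file by name substring.
    import pytest

    exit_code = pytest.main(["-k", "sharded_ddp", "examples/tests/trainer/test_trainer_ext.py"])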

src/transformers/trainer.py

Lines changed: 72 additions & 26 deletions
@@ -93,6 +93,7 @@
     EvalPrediction,
     HPSearchBackend,
     PredictionOutput,
+    ShardedDDPOption,
     TrainerMemoryTracker,
     TrainOutput,
     default_compute_objective,
@@ -131,10 +132,16 @@
     import torch_xla.distributed.parallel_loader as pl
 
 if is_fairscale_available():
+    import fairscale
     from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
     from fairscale.optim import OSS
     from fairscale.optim.grad_scaler import ShardedGradScaler
 
+    if version.parse(fairscale.__version__) >= version.parse("0.3"):
+        from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
+    else:
+        FullyShardedDDP = None
+
 if is_sagemaker_distributed_available():
     import smdistributed.dataparallel.torch.distributed as dist
     from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
@@ -277,9 +284,38 @@ def __init__(
         else:
             self.is_model_parallel = False
 
+        # Setup Sharded DDP training
+        self.sharded_ddp = None
+        if len(args.sharded_ddp) > 0:
+            if args.deepspeed:
+                raise ValueError(
+                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
+                )
+
+            if args.local_rank == -1:
+                raise ValueError("Using sharded DDP only works in distributed training.")
+            elif not is_fairscale_available():
+                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
+            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
+                raise ImportError(
+                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
+                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
+                )
+            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
+                self.sharded_ddp = ShardedDDPOption.SIMPLE
+            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
+                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
+            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
+                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
+
         # one place to sort out whether to place the model on device or not
         self.place_model_on_device = args.place_model_on_device
-        if self.is_model_parallel or (args.deepspeed and args.do_train) or (args.fp16_full_eval and not args.do_train):
+        if (
+            self.is_model_parallel
+            or (args.deepspeed and args.do_train)
+            or (args.fp16_full_eval and not args.do_train)
+            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
+        ):
            self.place_model_on_device = False
 
         default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
@@ -346,21 +382,6 @@ def __init__(
         if isinstance(eval_dataset, datasets.Dataset):
             self._remove_unused_columns(self.eval_dataset, description="evaluation")
 
-        # Setup Sharded DDP training
-        self.sharded_dpp = False
-        if args.sharded_ddp:
-            if args.deepspeed:
-                raise ValueError(
-                    "Using --sharded_ddp together with --deepspeed is not possible, deactivate one of those flags."
-                )
-
-            if args.local_rank == -1:
-                raise ValueError("Using sharded DDP only works in distributed training.")
-            elif not is_fairscale_available():
-                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
-            else:
-                self.sharded_dpp = True
-
         # Mixed precision setup
         self.use_apex = False
         self.use_amp = False
@@ -376,7 +397,7 @@ def __init__(
         if args.fp16 and not args.deepspeed:  # deepspeed manages its own fp16
             if self.fp16_backend == "amp":
                 self.use_amp = True
-                self.scaler = ShardedGradScaler() if self.sharded_dpp else torch.cuda.amp.GradScaler()
+                self.scaler = ShardedGradScaler() if self.sharded_ddp is not None else torch.cuda.amp.GradScaler()
             else:
                 if not is_apex_available():
                     raise ImportError(
@@ -619,7 +640,7 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
                "eps": self.args.adam_epsilon,
            }
            optimizer_kwargs["lr"] = self.args.learning_rate
-            if self.sharded_dpp:
+            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
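
For context, here is a standalone sketch of what the ``simple`` branch above constructs; it assumes fairscale is installed and torch.distributed is initialized, and the helper name, model and learning rate are placeholders:

    # Sketch: OSS shards optimizer state across ranks behind the usual Optimizer API.
    import torch
    from fairscale.optim import OSS

    def build_sharded_optimizer(model: torch.nn.Module, lr: float = 3e-3) -> OSS:
        # OSS takes the parameters, the optimizer class to wrap, and that optimizer's kwargs.
        return OSS(params=model.parameters(), optim=torch.optim.AdamW, lr=lr)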
@@ -737,8 +758,19 @@ def _wrap_model(self, model, training=True):
             return model
 
         # Distributed training (should be after apex fp16 initialization)
-        if self.sharded_dpp:
-            model = ShardedDDP(model, self.optimizer)
+        if self.sharded_ddp is not None:
+            # Sharded DDP!
+            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+                model = ShardedDDP(model, self.optimizer)
+            else:
+                mixed_precision = self.args.fp16
+                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
+                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
+                # XXX: Breaking the self.model convention but I see no way around it for now.
+                self.model = model = FullyShardedDDP(
+                    model, mixed_precision=mixed_precision, reshard_after_forward=zero_3, cpu_offload=cpu_offload
+                ).to(self.args.device)
+
         elif is_sagemaker_distributed_available():
             model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
         elif self.args.local_rank != -1:
@@ -855,14 +887,15 @@ def train(
             num_train_epochs = 1
             num_update_steps_per_epoch = max_steps
 
+        delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
         if self.args.deepspeed:
             model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
             self.model = model.module
             self.model_wrapped = model  # will get further wrapped in DDP
             self.deepspeed = model  # DeepSpeedEngine object
             self.optimizer = optimizer
             self.lr_scheduler = lr_scheduler
-        else:
+        elif not delay_optimizer_creation:
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)
 
         self.state = TrainerState()
@@ -877,6 +910,9 @@ def train(
         if model is not self.model:
             self.model_wrapped = model
 
+        if delay_optimizer_creation:
+            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
         # important: at this point:
         # self.model is the Transformers Model
         # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
@@ -1026,6 +1062,9 @@ def train(
                    if hasattr(self.optimizer, "clip_grad_norm"):
                        # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
                        self.optimizer.clip_grad_norm(self.args.max_grad_norm)
+                    elif hasattr(model, "clip_grad_norm_"):
+                        # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
+                        model.clip_grad_norm_(self.args.max_grad_norm)
                    else:
                        # Revert to normal clipping otherwise, handling Apex or full precision
                        torch.nn.utils.clip_grad_norm_(
@@ -1148,8 +1187,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
 
     def _save_checkpoint(self, model, trial, metrics=None):
         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
-        # want to save.
-        assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"
+        # want to save except FullyShardedDDP.
+        # assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"
 
         # Save model checkpoint
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
@@ -1173,7 +1212,7 @@ def _save_checkpoint(self, model, trial, metrics=None):
             self.deepspeed.save_checkpoint(output_dir)
 
         # Save optimizer and scheduler
-        if self.sharded_dpp:
+        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
             self.optimizer.consolidate_state_dict()
 
         if is_torch_tpu_available():
@@ -1479,7 +1518,11 @@ def _save_tpu(self, output_dir: Optional[str] = None):
         # They can then be reloaded using `from_pretrained()`
         xm.rendezvous("saving_checkpoint")
         if not isinstance(self.model, PreTrainedModel):
-            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+            if isinstance(_model_unwrap(self.model), PreTrainedModel):
+                if xm.is_master_ordinal():
+                    _model_unwrap(self.model).config.save_pretrained(output_dir)
+            else:
+                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
             state_dict = self.model.state_dict()
             xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
@@ -1494,7 +1537,10 @@ def _save(self, output_dir: Optional[str] = None):
         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
         if not isinstance(self.model, PreTrainedModel):
-            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+            if isinstance(_model_unwrap(self.model), PreTrainedModel):
+                _model_unwrap(self.model).config.save_pretrained(output_dir)
+            else:
+                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
             state_dict = self.model.state_dict()
             torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:

src/transformers/trainer_utils.py

Lines changed: 7 additions & 0 deletions
@@ -421,3 +421,10 @@ def stop_and_update_metrics(self, metrics=None):
         # init doesn't have metrics to update so we just save that data for later stages to retrieve
         if metrics is not None:
             self.update_metrics(stage, metrics)
+
+
+class ShardedDDPOption(ExplicitEnum):
+    SIMPLE = "simple"
+    ZERO_DP_2 = "zero2"
+    ZERO_DP_3 = "zero3"
+    OFFLOAD = "offload"
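
The Trainer consumes this enum through membership checks on ``args.sharded_ddp`` (see the ``__init__`` and ``_wrap_model`` hunks above). The parsing of the ``--sharded_ddp`` command-line value into a list of these options presumably lives in ``training_args.py``, the fifth changed file, which is not shown in this view; under that assumption, a resolved option list would drive the checks roughly like this sketch:

    # Sketch: how a resolved list of options maps onto the Trainer's checks above.
    from transformers.trainer_utils import ShardedDDPOption

    sharded_ddp = [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.OFFLOAD]

    # Mirrors the precedence in Trainer.__init__: simple, then zero_dp_2, then zero_dp_3.
    if ShardedDDPOption.SIMPLE in sharded_ddp:
        mode = ShardedDDPOption.SIMPLE
    elif ShardedDDPOption.ZERO_DP_2 in sharded_ddp:
        mode = ShardedDDPOption.ZERO_DP_2
    else:
        mode = ShardedDDPOption.ZERO_DP_3

    cpu_offload = ShardedDDPOption.OFFLOAD in sharded_ddp  # checked separately in _wrap_model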
