     EvalPrediction,
     HPSearchBackend,
     PredictionOutput,
+    ShardedDDPOption,
     TrainerMemoryTracker,
     TrainOutput,
     default_compute_objective,

     import torch_xla.distributed.parallel_loader as pl

 if is_fairscale_available():
+    import fairscale
     from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
     from fairscale.optim import OSS
     from fairscale.optim.grad_scaler import ShardedGradScaler

+    if version.parse(fairscale.__version__) >= version.parse("0.3"):
+        from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
+    else:
+        FullyShardedDDP = None
+
 if is_sagemaker_distributed_available():
     import smdistributed.dataparallel.torch.distributed as dist
     from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
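Note: the added imports gate FullyShardedDataParallel behind a fairscale version check, so the rest of the file can simply test `FullyShardedDDP is None` instead of re-checking versions. A minimal standalone sketch of the same guard (the helper name `fsdp_available` is illustrative, not from the diff; assumes the `packaging` package is installed):

import importlib.util

from packaging import version


def fsdp_available() -> bool:
    """True only if fairscale is installed and new enough to ship FullyShardedDataParallel."""
    if importlib.util.find_spec("fairscale") is None:
        return False
    import fairscale

    return version.parse(fairscale.__version__) >= version.parse("0.3")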
@@ -277,9 +284,38 @@ def __init__(
         else:
             self.is_model_parallel = False

+        # Setup Sharded DDP training
+        self.sharded_ddp = None
+        if len(args.sharded_ddp) > 0:
+            if args.deepspeed:
+                raise ValueError(
+                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
+                )
+
+            if args.local_rank == -1:
+                raise ValueError("Using sharded DDP only works in distributed training.")
+            elif not is_fairscale_available():
+                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
+            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
+                raise ImportError(
+                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
+                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
+                )
+            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
+                self.sharded_ddp = ShardedDDPOption.SIMPLE
+            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
+                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
+            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
+                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
+
         # one place to sort out whether to place the model on device or not
         self.place_model_on_device = args.place_model_on_device
-        if self.is_model_parallel or (args.deepspeed and args.do_train) or (args.fp16_full_eval and not args.do_train):
+        if (
+            self.is_model_parallel
+            or (args.deepspeed and args.do_train)
+            or (args.fp16_full_eval and not args.do_train)
+            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
+        ):
             self.place_model_on_device = False

         default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
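For context, `args.sharded_ddp` is now a list of options, and the cascade above picks a single primary mode (SIMPLE wins over ZERO_DP_2, which wins over ZERO_DP_3), leaving OFFLOAD to be read later in `_wrap_model`. A rough sketch of that resolution using a stand-in enum (not the real `ShardedDDPOption`):

from enum import Enum
from typing import List, Optional


class Mode(Enum):  # stand-in for ShardedDDPOption
    SIMPLE = "simple"
    ZERO_DP_2 = "zero_dp_2"
    ZERO_DP_3 = "zero_dp_3"
    OFFLOAD = "offload"


def resolve(flags: List[Mode]) -> Optional[Mode]:
    """Mirror the precedence in __init__: simple > zero_dp_2 > zero_dp_3; offload is only a modifier."""
    for mode in (Mode.SIMPLE, Mode.ZERO_DP_2, Mode.ZERO_DP_3):
        if mode in flags:
            return mode
    return None


print(resolve([Mode.ZERO_DP_3, Mode.OFFLOAD]))  # Mode.ZERO_DP_3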
@@ -346,21 +382,6 @@ def __init__(
             if isinstance(eval_dataset, datasets.Dataset):
                 self._remove_unused_columns(self.eval_dataset, description="evaluation")

-        # Setup Sharded DDP training
-        self.sharded_dpp = False
-        if args.sharded_ddp:
-            if args.deepspeed:
-                raise ValueError(
-                    "Using --sharded_ddp together with --deepspeed is not possible, deactivate one of those flags."
-                )
-
-            if args.local_rank == -1:
-                raise ValueError("Using sharded DDP only works in distributed training.")
-            elif not is_fairscale_available():
-                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
-            else:
-                self.sharded_dpp = True
-
         # Mixed precision setup
         self.use_apex = False
         self.use_amp = False
@@ -376,7 +397,7 @@ def __init__(
         if args.fp16 and not args.deepspeed:  # deepspeed manages its own fp16
             if self.fp16_backend == "amp":
                 self.use_amp = True
-                self.scaler = ShardedGradScaler() if self.sharded_dpp else torch.cuda.amp.GradScaler()
+                self.scaler = ShardedGradScaler() if self.sharded_ddp is not None else torch.cuda.amp.GradScaler()
             else:
                 if not is_apex_available():
                     raise ImportError(
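A small sketch of the scaler choice above: when gradients are sharded, fairscale's `ShardedGradScaler` plays the role of `torch.cuda.amp.GradScaler` for fp16 loss scaling (assumes fairscale and a CUDA build of PyTorch; the helper name is illustrative):

import torch


def make_scaler(sharded: bool):
    """Return the fp16 loss scaler matching the data-parallel setup."""
    if sharded:
        from fairscale.optim.grad_scaler import ShardedGradScaler

        return ShardedGradScaler()
    return torch.cuda.amp.GradScaler()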
@@ -619,7 +640,7 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
                 "eps": self.args.adam_epsilon,
             }
             optimizer_kwargs["lr"] = self.args.learning_rate
-            if self.sharded_dpp:
+            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                 self.optimizer = OSS(
                     params=optimizer_grouped_parameters,
                     optim=optimizer_cls,
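For reference, a self-contained sketch of building a fairscale `OSS` optimizer as done above (assumes `torch.distributed.init_process_group` has already run, e.g. under `torchrun`; the toy model and hyperparameters are arbitrary):

import torch
from fairscale.optim import OSS

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 2))
# OSS shards optimizer state across ranks; it wraps an optimizer class plus its keyword arguments.
optimizer = OSS(params=model.parameters(), optim=torch.optim.AdamW, lr=5e-5, eps=1e-8)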
@@ -737,8 +758,19 @@ def _wrap_model(self, model, training=True):
             return model

         # Distributed training (should be after apex fp16 initialization)
-        if self.sharded_dpp:
-            model = ShardedDDP(model, self.optimizer)
+        if self.sharded_ddp is not None:
+            # Sharded DDP!
+            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+                model = ShardedDDP(model, self.optimizer)
+            else:
+                mixed_precision = self.args.fp16
+                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
+                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
+                # XXX: Breaking the self.model convention but I see no way around it for now.
+                self.model = model = FullyShardedDDP(
+                    model, mixed_precision=mixed_precision, reshard_after_forward=zero_3, cpu_offload=cpu_offload
+                ).to(self.args.device)
+
         elif is_sagemaker_distributed_available():
             model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
         elif self.args.local_rank != -1:
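A rough standalone sketch of the ZeRO-DP branch above (assumes fairscale >= 0.3 and an initialized distributed process group; the helper `wrap_fsdp` is illustrative, not part of the Trainer):

import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP


def wrap_fsdp(model: torch.nn.Module, fp16: bool, zero_3: bool, offload: bool, device) -> torch.nn.Module:
    # reshard_after_forward=True gives the ZeRO stage 3 behaviour; False keeps gathered params (stage 2-like).
    wrapped = FullyShardedDDP(
        model,
        mixed_precision=fp16,
        reshard_after_forward=zero_3,
        cpu_offload=offload,
    )
    return wrapped.to(device)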
@@ -855,14 +887,15 @@ def train(
             num_train_epochs = 1
             num_update_steps_per_epoch = max_steps

+        delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
         if self.args.deepspeed:
             model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
             self.model = model.module
             self.model_wrapped = model  # will get further wrapped in DDP
             self.deepspeed = model  # DeepSpeedEngine object
             self.optimizer = optimizer
             self.lr_scheduler = lr_scheduler
-        else:
+        elif not delay_optimizer_creation:
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)

         self.state = TrainerState()
@@ -877,6 +910,9 @@ def train(
         if model is not self.model:
             self.model_wrapped = model

+        if delay_optimizer_creation:
+            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
         # important: at this point:
         # self.model is the Transformers Model
         # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
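The reason for `delay_optimizer_creation`: in ZeRO-DP 2/3 mode, `FullyShardedDDP` flattens and shards the module's parameters, so an optimizer built beforehand would hold references to the old, unsharded tensors. A tiny sketch of the required ordering (assumes an initialized process group; the toy model is arbitrary):

import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP

model = torch.nn.Linear(16, 2)
wrapped = FullyShardedDDP(model)                              # parameters are flattened/sharded here
optimizer = torch.optim.AdamW(wrapped.parameters(), lr=5e-5)  # hence the optimizer is created afterwards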
@@ -1026,6 +1062,9 @@ def train(
                     if hasattr(self.optimizer, "clip_grad_norm"):
                         # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
                         self.optimizer.clip_grad_norm(self.args.max_grad_norm)
+                    elif hasattr(model, "clip_grad_norm_"):
+                        # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
+                        model.clip_grad_norm_(self.args.max_grad_norm)
                     else:
                         # Revert to normal clipping otherwise, handling Apex or full precision
                         torch.nn.utils.clip_grad_norm_(
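The three-way gradient-clipping dispatch above can be summarized by a small helper (a sketch; `clip_gradients` is not a Trainer method):

import torch


def clip_gradients(model: torch.nn.Module, optimizer: torch.optim.Optimizer, max_grad_norm: float) -> None:
    if hasattr(optimizer, "clip_grad_norm"):
        optimizer.clip_grad_norm(max_grad_norm)    # e.g. fairscale OSS
    elif hasattr(model, "clip_grad_norm_"):
        model.clip_grad_norm_(max_grad_norm)       # e.g. fairscale FullyShardedDDP
    else:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)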
@@ -1148,8 +1187,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):

     def _save_checkpoint(self, model, trial, metrics=None):
         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
-        # want to save.
-        assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"
+        # want to save, except FullyShardedDDP.
+        # assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"

         # Save model checkpoint
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
@@ -1173,7 +1212,7 @@ def _save_checkpoint(self, model, trial, metrics=None):
             self.deepspeed.save_checkpoint(output_dir)

         # Save optimizer and scheduler
-        if self.sharded_dpp:
+        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
             self.optimizer.consolidate_state_dict()

         if is_torch_tpu_available():
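Why only the SIMPLE branch calls `consolidate_state_dict()`: fairscale's `OSS` keeps each rank's shard of optimizer state, so it has to be gathered before saving. A sketch of the pattern (assumes `optimizer` is an `OSS` instance inside an initialized process group):

import torch
import torch.distributed as dist

optimizer.consolidate_state_dict()  # gathers the sharded optimizer state onto rank 0 by default
if dist.get_rank() == 0:
    torch.save(optimizer.state_dict(), "optimizer.pt")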
@@ -1479,7 +1518,11 @@ def _save_tpu(self, output_dir: Optional[str] = None):
         # They can then be reloaded using `from_pretrained()`
         xm.rendezvous("saving_checkpoint")
         if not isinstance(self.model, PreTrainedModel):
-            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+            if isinstance(_model_unwrap(self.model), PreTrainedModel):
+                if xm.is_master_ordinal():
+                    _model_unwrap(self.model).config.save_pretrained(output_dir)
+            else:
+                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
             state_dict = self.model.state_dict()
             xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
@@ -1494,7 +1537,10 @@ def _save(self, output_dir: Optional[str] = None):
         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
         if not isinstance(self.model, PreTrainedModel):
-            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+            if isinstance(_model_unwrap(self.model), PreTrainedModel):
+                _model_unwrap(self.model).config.save_pretrained(output_dir)
+            else:
+                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
             state_dict = self.model.state_dict()
             torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
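Both `_save` variants now look through the wrapper with `_model_unwrap` so the config can still be saved when the model is wrapped (e.g. by FullyShardedDDP). Conceptually the unwrapping is a recursive descent through `.module` attributes; a paraphrased sketch of that idea, not the library's exact code:

import torch


def unwrap(model: torch.nn.Module) -> torch.nn.Module:
    """Peel off nested wrappers (DDP, FSDP, ...) that expose the inner model as `.module`."""
    return unwrap(model.module) if hasattr(model, "module") else model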