try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as DDP
+    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False

from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
-from models import create_model, resume_checkpoint
+from models import create_model, resume_checkpoint, load_checkpoint
from utils import *
from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from optim import create_optimizer
                    help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
                    help='BatchNorm epsilon override (if not None)')
+parser.add_argument('--model-ema', action='store_true', default=False,
+                    help='Enable tracking moving average of model weights')
+parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
+                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
+parser.add_argument('--model-ema-decay', type=float, default=0.9998,
+                    help='decay factor for model weights moving average (default: 0.9998)')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
-parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
+parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many training processes to use (default: 1)')
                    help='save images of input batches every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA amp for mixed precision training')
+parser.add_argument('--sync-bn', action='store_true',
+                    help='enabling apex sync BN.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
parser.add_argument('--output', default='', type=str, metavar='PATH',
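The new --model-ema-decay default of 0.9998 keeps 99.98% of the running average on every update, so the EMA effectively spans roughly 1 / (1 - decay) = 5000 recent updates. A minimal sketch of the update rule these flags control, using stand-in tensors rather than this repo's ModelEma code:

    import torch

    decay = 0.9998                # --model-ema-decay default
    ema_param = torch.zeros(10)   # running average of one weight tensor
    param = torch.randn(10)       # current weights after an optimizer step

    # one EMA step: keep `decay` of the average, blend in (1 - decay) of the new value
    ema_param.mul_(decay).add_(param, alpha=1.0 - decay)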
@@ -131,31 +140,28 @@ def main():

    args.device = 'cuda:0'
    args.world_size = 1
-    r = -1
+    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
-        torch.distributed.init_process_group(backend='nccl',
-                                             init_method='env://')
+        torch.distributed.init_process_group(
+            backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
-        r = torch.distributed.get_rank()
+        args.rank = torch.distributed.get_rank()
+        assert args.rank >= 0

    if args.distributed:
        print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
-              % (r, args.world_size))
+              % (args.rank, args.world_size))
    else:
        print('Training with a single process on %d GPUs.' % args.num_gpu)

-    # FIXME seed handling for multi-process distributed?
-    torch.manual_seed(args.seed)
+    torch.manual_seed(args.seed + args.rank)

    output_dir = ''
    if args.local_rank == 0:
-        if args.output:
-            output_base = args.output
-        else:
-            output_base = './output'
+        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
@@ -191,6 +197,8 @@ def main():
        args.amp = False
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
+        if args.distributed and args.sync_bn and has_apex:
+            model = convert_syncbn_model(model)
        model.cuda()

    optimizer = create_optimizer(args, model)
@@ -205,8 +213,20 @@ def main():
        use_amp = False
        print('AMP disabled')

+    model_ema = None
+    if args.model_ema:
+        model_ema = ModelEma(
+            model,
+            decay=args.model_ema_decay,
+            device='cpu' if args.model_ema_force_cpu else '',
+            resume=args.resume)
+
    if args.distributed:
        model = DDP(model, delay_allreduce=True)
+        if model_ema is not None and not args.model_ema_force_cpu:
+            # must also distribute EMA model to allow validation
+            model_ema.ema = DDP(model_ema.ema, delay_allreduce=True)
+            model_ema.ema_has_module = True

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    if start_epoch > 0:
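For context, an EMA wrapper of this kind typically keeps a detached copy of the model and blends that copy's weights toward the live model after every optimizer step; the copy is what gets validated and checkpointed. A minimal sketch under that assumption (illustrative only, not the actual ModelEma helper used above):

    from copy import deepcopy
    import torch

    class SimpleEma:
        """Illustrative EMA wrapper: a frozen copy of the model, updated in place."""
        def __init__(self, model, decay=0.9998):
            self.ema = deepcopy(model)   # averaged copy used for the '(EMA)' validation pass
            self.ema.eval()
            self.decay = decay
            for p in self.ema.parameters():
                p.requires_grad_(False)

        def update(self, model):
            # blend every floating-point tensor toward the current model weights
            with torch.no_grad():
                msd = model.state_dict()
                for k, v in self.ema.state_dict().items():
                    if v.dtype.is_floating_point:
                        v.mul_(self.decay).add_(msd[k].detach(), alpha=1.0 - self.decay)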
@@ -273,6 +293,7 @@ def main():
    eval_metric = args.eval_metric
    saver = None
    if output_dir:
+        # only set if process is rank 0
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
    best_metric = None
@@ -284,10 +305,15 @@ def main():

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
-                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)
+                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
+                use_amp=use_amp, model_ema=model_ema)
+
+            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

-            eval_metrics = validate(
-                model, loader_eval, validate_loss_fn, args)
+            if model_ema is not None and not args.model_ema_force_cpu:
+                ema_eval_metrics = validate(
+                    model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
+                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                lr_scheduler.step(epoch, eval_metrics[eval_metric])
@@ -298,15 +324,12 @@ def main():

            if saver is not None:
                # save proper checkpoint with eval metric
-                best_metric, best_epoch = saver.save_checkpoint({
-                    'epoch': epoch + 1,
-                    'arch': args.model,
-                    'state_dict': model.state_dict(),
-                    'optimizer': optimizer.state_dict(),
-                    'args': args,
-                    },
+                save_metric = eval_metrics[eval_metric]
+                best_metric, best_epoch = saver.save_checkpoint(
+                    model, optimizer, args,
                    epoch=epoch + 1,
-                    metric=eval_metrics[eval_metric])
+                    model_ema=model_ema,
+                    metric=save_metric)

    except KeyboardInterrupt:
        pass
@@ -316,7 +339,7 @@ def main():

def train_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir='', use_amp=False):
+        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):

    if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
        if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
@@ -359,6 +382,8 @@ def train_epoch(
        optimizer.step()

        torch.cuda.synchronize()
+        if model_ema is not None:
+            model_ema.update(model)
        num_updates += 1

        batch_time_m.update(time.time() - end)
@@ -394,18 +419,11 @@ def train_epoch(
                        padding=0,
                        normalize=True)

-        if args.local_rank == 0 and (
-                saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0):
+        if saver is not None and args.recovery_interval and (
+                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            save_epoch = epoch + 1 if last_batch else epoch
-            saver.save_recovery({
-                'epoch': save_epoch,
-                'arch': args.model,
-                'state_dict': model.state_dict(),
-                'optimizer': optimizer.state_dict(),
-                'args': args,
-                },
-                epoch=save_epoch,
-                batch_idx=batch_idx)
+            saver.save_recovery(
+                model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
@@ -415,7 +433,7 @@ def train_epoch(
    return OrderedDict([('loss', losses_m.avg)])


-def validate(model, loader, loss_fn, args):
+def validate(model, loader, loss_fn, args, log_suffix=''):
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    prec1_m = AverageMeter()
@@ -461,12 +479,13 @@ def validate(model, loader, loss_fn, args):
            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
-                print('Test: [{0}/{1}]\t'
+                log_name = 'Test' + log_suffix
+                print('{0}: [{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                      'Loss {loss.val:.4f} ({loss.avg:.4f}) '
                      'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) '
                      'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
-                          batch_idx, last_idx,
+                          log_name, batch_idx, last_idx,
                          batch_time=batch_time_m, loss=losses_m,
                          top1=prec1_m, top5=prec5_m))

@@ -475,12 +494,5 @@ def validate(model, loader, loss_fn, args):
    return metrics


-def reduce_tensor(tensor, n):
-    rt = tensor.clone()
-    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
-    rt /= n
-    return rt
-
-
if __name__ == '__main__':
    main()
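With these changes applied, EMA tracking is opt-in from the command line. A hypothetical invocation (the dataset path, model name, and process count are placeholders, assuming the script's existing positional data argument):

    python -m torch.distributed.launch --nproc_per_node=2 train.py /path/to/dataset \
        --model seresnet34 --sync-bn --amp --model-ema --model-ema-decay 0.9998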