Commit 9bcd651

Add exponential moving average for model weights + a few other additions and cleanup
* ModelEma class added to track an EMA set of weights for the model being trained
* EMA handling added to the train, validation and clean_checkpoint scripts
* Add multi-checkpoint / multi-model validation support to validate.py
* Add sync-bn option (APEX) to the train script for experimentation
* Clean up the CheckpointSaver interface while adding EMA functionality
1 parent ff99625 commit 9bcd651
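
The ModelEma implementation itself is not in the hunks shown below; as a rough illustration of the technique (a sketch, not the class this commit adds): an EMA tracker keeps a detached copy of the model and, after each optimizer step, blends every copied weight toward the live weight as ema = decay * ema + (1 - decay) * weight.

import copy
import torch

class SimpleEma:
    """Illustrative EMA weight tracker; a sketch, not the ModelEma added by this commit."""

    def __init__(self, model, decay=0.9998):
        self.decay = decay
        # frozen, detached copy of the model used for evaluation / export
        self.ema = copy.deepcopy(model).eval()
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        # assumes `model` is not wrapped in DataParallel/DDP; the real class also handles that
        ema_state = self.ema.state_dict()
        for k, v in model.state_dict().items():
            if v.is_floating_point():
                # ema = decay * ema + (1 - decay) * current
                ema_state[k].mul_(self.decay).add_(v.detach(), alpha=1. - self.decay)
            else:
                ema_state[k].copy_(v)  # e.g. BatchNorm's num_batches_tracked

Evaluating the smoothed copy typically gives a small but consistent accuracy gain over the raw weights, which is why the train script below validates both and checkpoints carry the EMA weights under a separate state_dict_ema key.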

6 files changed, +258 -84 lines

clean_checkpoint.py

Lines changed: 9 additions & 2 deletions
@@ -9,6 +9,8 @@
                     help='path to latest checkpoint (default: none)')
 parser.add_argument('--output', default='./cleaned.pth', type=str, metavar='PATH',
                     help='output path')
+parser.add_argument('--use-ema', dest='use_ema', action='store_true',
+                    help='use ema version of weights if present')


 def main():
@@ -24,8 +26,13 @@ def main():
     checkpoint = torch.load(args.checkpoint, map_location='cpu')

     new_state_dict = OrderedDict()
-    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
-        state_dict = checkpoint['state_dict']
+    if isinstance(checkpoint, dict):
+        state_dict_key = 'state_dict_ema' if args.use_ema else 'state_dict'
+        if state_dict_key in checkpoint:
+            state_dict = checkpoint[state_dict_key]
+        else:
+            print("Error: No state_dict found in checkpoint {}.".format(args.checkpoint))
+            exit(1)
     else:
         state_dict = checkpoint
     for k, v in state_dict.items():
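
Condensed into a self-contained toy, the flag's effect is just choosing which key to read before re-keying and saving (the nn.Linear checkpoint, the prefix-stripping and the save call here are illustrative; the hunk above ends before the loop body):

from collections import OrderedDict
import torch
import torch.nn as nn

# toy stand-in for a training checkpoint that carries both raw and EMA weights
net = nn.Linear(4, 2)
checkpoint = {'epoch': 1, 'state_dict': net.state_dict(), 'state_dict_ema': net.state_dict()}

use_ema = True  # mirrors the new --use-ema flag
key = 'state_dict_ema' if use_ema else 'state_dict'
cleaned = OrderedDict(
    (k[7:] if k.startswith('module') else k, v)  # drop any `module.` prefix
    for k, v in checkpoint[key].items())
torch.save(cleaned, './cleaned.pth')  # analogous to the script's --output default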

models/helpers.py

Lines changed: 15 additions & 16 deletions
@@ -4,22 +4,24 @@
 from collections import OrderedDict


-def load_checkpoint(model, checkpoint_path):
+def load_checkpoint(model, checkpoint_path, use_ema=False):
     if checkpoint_path and os.path.isfile(checkpoint_path):
-        print("=> Loading checkpoint '{}'".format(checkpoint_path))
         checkpoint = torch.load(checkpoint_path)
-        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+        state_dict_key = ''
+        if isinstance(checkpoint, dict):
+            state_dict_key = 'state_dict'
+            if use_ema and 'state_dict_ema' in checkpoint:
+                state_dict_key = 'state_dict_ema'
+        if state_dict_key and state_dict_key in checkpoint:
             new_state_dict = OrderedDict()
-            for k, v in checkpoint['state_dict'].items():
-                if k.startswith('module'):
-                    name = k[7:]  # remove `module.`
-                else:
-                    name = k
+            for k, v in checkpoint[state_dict_key].items():
+                # strip `module.` prefix
+                name = k[7:] if k.startswith('module') else k
                 new_state_dict[name] = v
             model.load_state_dict(new_state_dict)
         else:
             model.load_state_dict(checkpoint)
-        print("=> Loaded checkpoint '{}'".format(checkpoint_path))
+        print("=> Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path))
     else:
         print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()
@@ -28,27 +30,24 @@ def load_checkpoint(model, checkpoint_path):
 def resume_checkpoint(model, checkpoint_path, start_epoch=None):
     optimizer_state = None
     if os.path.isfile(checkpoint_path):
-        print("=> loading checkpoint '{}'".format(checkpoint_path))
         checkpoint = torch.load(checkpoint_path)
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
             new_state_dict = OrderedDict()
             for k, v in checkpoint['state_dict'].items():
-                if k.startswith('module'):
-                    name = k[7:]  # remove `module.`
-                else:
-                    name = k
+                name = k[7:] if k.startswith('module') else k
                 new_state_dict[name] = v
             model.load_state_dict(new_state_dict)
             if 'optimizer' in checkpoint:
                 optimizer_state = checkpoint['optimizer']
-            print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
             start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
+            print("=> Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
         else:
             model.load_state_dict(checkpoint)
             start_epoch = 0 if start_epoch is None else start_epoch
+            print("=> Loaded checkpoint '{}'".format(checkpoint_path))
         return optimizer_state, start_epoch
     else:
-        print("=> No checkpoint found at '{}'".format(checkpoint_path))
+        print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()
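
The `module.` handling that both functions above condense into a one-liner exists because saving a model wrapped in nn.DataParallel (or DistributedDataParallel) prefixes every state_dict key with `module.`; a quick CPU-only illustration:

import torch.nn as nn

net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)  # DistributedDataParallel prefixes keys the same way

print(list(net.state_dict()))      # ['weight', 'bias']
print(list(wrapped.state_dict()))  # ['module.weight', 'module.bias']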
optim/rmsprop_tf.py

Lines changed: 2 additions & 2 deletions
@@ -89,7 +89,7 @@ def step(self, closure=None):
                 state['step'] += 1

                 if group['weight_decay'] != 0:
-                    if group['decoupled_decay']:
+                    if 'decoupled_decay' in group and group['decoupled_decay']:
                         p.data.add_(-group['weight_decay'], p.data)
                     else:
                         grad = grad.add(group['weight_decay'], p.data)
@@ -109,7 +109,7 @@ def step(self, closure=None):
                 if group['momentum'] > 0:
                     buf = state['momentum_buffer']
                     # Tensorflow accumulates the LR scaling in the momentum buffer
-                    if group['lr_in_momentum']:
+                    if 'lr_in_momentum' in group and group['lr_in_momentum']:
                         buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg)
                         p.data.add_(-buf)
                     else:
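
Both one-line changes above make the optional flags safe to omit from an optimizer param group: a group that never set the key no longer raises a KeyError and simply behaves as if the option were off. The guard is equivalent to a dict.get with a False default:

group = {'lr': 0.01, 'weight_decay': 1e-5}  # param group without the optional flag

decoupled = 'decoupled_decay' in group and group['decoupled_decay']  # guard used above
decoupled = group.get('decoupled_decay', False)                      # equivalent spelling
assert decoupled is False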

train.py

Lines changed: 58 additions & 46 deletions
@@ -6,12 +6,13 @@
 try:
     from apex import amp
     from apex.parallel import DistributedDataParallel as DDP
+    from apex.parallel import convert_syncbn_model
     has_apex = True
 except ImportError:
     has_apex = False

 from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
-from models import create_model, resume_checkpoint
+from models import create_model, resume_checkpoint, load_checkpoint
 from utils import *
 from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
 from optim import create_optimizer
@@ -91,11 +92,17 @@
                     help='BatchNorm momentum override (if not None)')
 parser.add_argument('--bn-eps', type=float, default=None,
                     help='BatchNorm epsilon override (if not None)')
+parser.add_argument('--model-ema', action='store_true', default=False,
+                    help='Enable tracking moving average of model weights')
+parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
+                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
+parser.add_argument('--model-ema-decay', type=float, default=0.9998,
+                    help='decay factor for model weights moving average (default: 0.9998)')
 parser.add_argument('--seed', type=int, default=42, metavar='S',
                     help='random seed (default: 42)')
 parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                     help='how many batches to wait before logging training status')
-parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
+parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                     help='how many batches to wait before writing recovery checkpoint')
 parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 1)')
@@ -109,6 +116,8 @@
                     help='save images of input bathes every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
                     help='use NVIDIA amp for mixed precision training')
+parser.add_argument('--sync-bn', action='store_true',
+                    help='enabling apex sync BN.')
 parser.add_argument('--no-prefetcher', action='store_true', default=False,
                     help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
@@ -131,31 +140,28 @@ def main():

     args.device = 'cuda:0'
     args.world_size = 1
-    r = -1
+    args.rank = 0  # global rank
     if args.distributed:
         args.num_gpu = 1
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
-        torch.distributed.init_process_group(backend='nccl',
-                                             init_method='env://')
+        torch.distributed.init_process_group(
+            backend='nccl', init_method='env://')
         args.world_size = torch.distributed.get_world_size()
-        r = torch.distributed.get_rank()
+        args.rank = torch.distributed.get_rank()
+        assert args.rank >= 0

     if args.distributed:
         print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
-              % (r, args.world_size))
+              % (args.rank, args.world_size))
     else:
         print('Training with a single process on %d GPUs.' % args.num_gpu)

-    # FIXME seed handling for multi-process distributed?
-    torch.manual_seed(args.seed)
+    torch.manual_seed(args.seed + args.rank)

     output_dir = ''
     if args.local_rank == 0:
-        if args.output:
-            output_base = args.output
-        else:
-            output_base = './output'
+        output_base = args.output if args.output else './output'
         exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
@@ -191,6 +197,8 @@ def main():
             args.amp = False
         model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
+        if args.distributed and args.sync_bn and has_apex:
+            model = convert_syncbn_model(model)
         model.cuda()

     optimizer = create_optimizer(args, model)
@@ -205,8 +213,20 @@ def main():
         use_amp = False
         print('AMP disabled')

+    model_ema = None
+    if args.model_ema:
+        model_ema = ModelEma(
+            model,
+            decay=args.model_ema_decay,
+            device='cpu' if args.model_ema_force_cpu else '',
+            resume=args.resume)
+
     if args.distributed:
         model = DDP(model, delay_allreduce=True)
+        if model_ema is not None and not args.model_ema_force_cpu:
+            # must also distribute EMA model to allow validation
+            model_ema.ema = DDP(model_ema.ema, delay_allreduce=True)
+            model_ema.ema_has_module = True

     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
     if start_epoch > 0:
273293
eval_metric = args.eval_metric
274294
saver = None
275295
if output_dir:
296+
# only set if process is rank 0
276297
decreasing = True if eval_metric == 'loss' else False
277298
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
278299
best_metric = None
@@ -284,10 +305,15 @@ def main():
284305

285306
train_metrics = train_epoch(
286307
epoch, model, loader_train, optimizer, train_loss_fn, args,
287-
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)
308+
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
309+
use_amp=use_amp, model_ema=model_ema)
310+
311+
eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
288312

289-
eval_metrics = validate(
290-
model, loader_eval, validate_loss_fn, args)
313+
if model_ema is not None and not args.model_ema_force_cpu:
314+
ema_eval_metrics = validate(
315+
model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
316+
eval_metrics = ema_eval_metrics
291317

292318
if lr_scheduler is not None:
293319
lr_scheduler.step(epoch, eval_metrics[eval_metric])
@@ -298,15 +324,12 @@ def main():
298324

299325
if saver is not None:
300326
# save proper checkpoint with eval metric
301-
best_metric, best_epoch = saver.save_checkpoint({
302-
'epoch': epoch + 1,
303-
'arch': args.model,
304-
'state_dict': model.state_dict(),
305-
'optimizer': optimizer.state_dict(),
306-
'args': args,
307-
},
327+
save_metric = eval_metrics[eval_metric]
328+
best_metric, best_epoch = saver.save_checkpoint(
329+
model, optimizer, args,
308330
epoch=epoch + 1,
309-
metric=eval_metrics[eval_metric])
331+
model_ema=model_ema,
332+
metric=save_metric)
310333

311334
except KeyboardInterrupt:
312335
pass
@@ -316,7 +339,7 @@ def main():
316339

317340
def train_epoch(
318341
epoch, model, loader, optimizer, loss_fn, args,
319-
lr_scheduler=None, saver=None, output_dir='', use_amp=False):
342+
lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):
320343

321344
if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
322345
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
@@ -359,6 +382,8 @@ def train_epoch(
359382
optimizer.step()
360383

361384
torch.cuda.synchronize()
385+
if model_ema is not None:
386+
model_ema.update(model)
362387
num_updates += 1
363388

364389
batch_time_m.update(time.time() - end)
@@ -394,18 +419,11 @@ def train_epoch(
394419
padding=0,
395420
normalize=True)
396421

397-
if args.local_rank == 0 and (
398-
saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0):
422+
if saver is not None and args.recovery_interval and (
423+
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
399424
save_epoch = epoch + 1 if last_batch else epoch
400-
saver.save_recovery({
401-
'epoch': save_epoch,
402-
'arch': args.model,
403-
'state_dict': model.state_dict(),
404-
'optimizer': optimizer.state_dict(),
405-
'args': args,
406-
},
407-
epoch=save_epoch,
408-
batch_idx=batch_idx)
425+
saver.save_recovery(
426+
model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)
409427

410428
if lr_scheduler is not None:
411429
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
@@ -415,7 +433,7 @@ def train_epoch(
415433
return OrderedDict([('loss', losses_m.avg)])
416434

417435

418-
def validate(model, loader, loss_fn, args):
436+
def validate(model, loader, loss_fn, args, log_suffix=''):
419437
batch_time_m = AverageMeter()
420438
losses_m = AverageMeter()
421439
prec1_m = AverageMeter()
@@ -461,12 +479,13 @@ def validate(model, loader, loss_fn, args):
461479
batch_time_m.update(time.time() - end)
462480
end = time.time()
463481
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
464-
print('Test: [{0}/{1}]\t'
482+
log_name = 'Test' + log_suffix
483+
print('{0}: [{1}/{2}]\t'
465484
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
466485
'Loss {loss.val:.4f} ({loss.avg:.4f}) '
467486
'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) '
468487
'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
469-
batch_idx, last_idx,
488+
log_name, batch_idx, last_idx,
470489
batch_time=batch_time_m, loss=losses_m,
471490
top1=prec1_m, top5=prec5_m))
472491

@@ -475,12 +494,5 @@ def validate(model, loader, loss_fn, args):
475494
return metrics
476495

477496

478-
def reduce_tensor(tensor, n):
479-
rt = tensor.clone()
480-
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
481-
rt /= n
482-
return rt
483-
484-
485497
if __name__ == '__main__':
486498
main()
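
A closing note on the new default: --model-ema-decay of 0.9998 gives the moving average an effective memory of roughly 1 / (1 - decay) = 5000 optimizer updates, so the EMA weights trail the live weights by a few thousand steps; a quick back-of-the-envelope check:

import math

decay = 0.9998
window = 1.0 / (1.0 - decay)                 # ~5000 updates contribute appreciably
half_life = math.log(0.5) / math.log(decay)  # ~3465 updates for an old update's weight to halve
print(round(window), round(half_life))       # 5000 3465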
