Skip to content

Commit fac58f6

Browse files
committed
Add RAdam, NovoGrad, Lookahead, and AdamW optimizers, a few ResNet tweaks and scheduler factory tweak.
* Add some of the trendy new optimizers (RAdam, NovoGrad, Lookahead, AdamW). Decent results, but not clearly better than the standard optimizers.
* Allow creating a None (constant-LR) scheduler.
* ResNet now defaults to zero-initializing the last BN weight in each residual block.
* Add a resnet50d model config.
1 parent f37e633 commit fac58f6

File tree

10 files changed

+507
-18
lines changed

10 files changed

+507
-18
lines changed

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ The work of many others is present here. I've tried to make sure all source mate
1313
* [Myself](https://github.com/rwightman/pytorch-dpn-pretrained)
1414
* LR scheduler ideas from [AllenNLP](https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers), [FAIRseq](https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler), and SGDR: Stochastic Gradient Descent with Warm Restarts (https://arxiv.org/abs/1608.03983)
1515
* Random Erasing from [Zhun Zhong](https://github.com/zhunzhong07/Random-Erasing/blob/master/transforms.py) (https://arxiv.org/abs/1708.04896)
16-
16+
* Optimizers:
17+
* RAdam by [Liyuan Liu](https://github.com/LiyuanLucasLiu/RAdam) (https://arxiv.org/abs/1908.03265)
18+
* NovoGrad by [Masashi Kimura](https://github.com/convergence-lab/novograd) (https://arxiv.org/abs/1905.11286)
19+
* Lookahead adapted from impl by [Liam](https://github.com/alphadl/lookahead.pytorch) (https://arxiv.org/abs/1907.08610)
1720
## Models
1821

1922
I've included a few of my favourite models, but this is not an exhaustive collection. You can't do better than Cadene's collection in that regard. Most models do have pretrained weights from their respective sources or original authors.

timm/models/resnet.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ def _cfg(url='', **kwargs):
4444
'resnet50': _cfg(
4545
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth',
4646
interpolation='bicubic'),
47+
'resnet50d': _cfg(
48+
url='',
49+
interpolation='bicubic'),
4750
'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
4851
'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'),
4952
'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'),
@@ -259,7 +262,7 @@ class ResNet(nn.Module):
259262
def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
260263
cardinality=1, base_width=64, stem_width=64, deep_stem=False,
261264
block_reduce_first=1, down_kernel_size=1, avg_down=False, dilated=False,
262-
norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg'):
265+
norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg', zero_init_last_bn=True):
263266
self.num_classes = num_classes
264267
self.inplanes = stem_width * 2 if deep_stem else 64
265268
self.cardinality = cardinality
@@ -296,11 +299,16 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
296299
self.num_features = 512 * block.expansion
297300
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
298301

299-
for m in self.modules():
302+
last_bn_name = 'bn3' if 'Bottleneck' in block.__name__ else 'bn2'
303+
for n, m in self.named_modules():
300304
if isinstance(m, nn.Conv2d):
301305
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
302306
elif isinstance(m, nn.BatchNorm2d):
303-
nn.init.constant_(m.weight, 1.)
307+
if zero_init_last_bn and 'layer' in n and last_bn_name in n:
308+
# Initialize weight/gamma of last BN in each residual block to zero
309+
nn.init.constant_(m.weight, 0.)
310+
else:
311+
nn.init.constant_(m.weight, 1.)
304312
nn.init.constant_(m.bias, 0.)
305313

306314
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
@@ -434,6 +442,20 @@ def resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
434442
return model
435443

436444

445+
@register_model
def resnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a ResNet-50-D model.

    Variant of ResNet-50 built with a 32-channel deep stem and average-pool
    downsampling in the shortcut paths (``deep_stem=True``, ``avg_down=True``).

    Args:
        pretrained (bool): load pretrained weights if available (default: False)
        num_classes (int): number of classifier outputs (default: 1000)
        in_chans (int): number of input image channels (default: 3)
        **kwargs: forwarded to the ``ResNet`` constructor
    """
    cfg = default_cfgs['resnet50d']
    net = ResNet(
        Bottleneck, [3, 4, 6, 3], stem_width=32, deep_stem=True, avg_down=True,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    net.default_cfg = cfg
    if pretrained:
        load_pretrained(net, cfg, num_classes, in_chans)
    return net
457+
458+
437459
@register_model
438460
def resnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
439461
"""Constructs a ResNet-101 model.

timm/optim/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
from .nadam import Nadam
22
from .rmsprop_tf import RMSpropTF
3+
from .adamw import AdamW
4+
from .radam import RAdam
5+
from .novograd import NovoGrad
6+
from .lookahead import Lookahead
37
from .optim_factory import create_optimizer

timm/optim/adamw.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
""" AdamW Optimizer
2+
Impl copied from PyTorch master
3+
"""
4+
import math
5+
import torch
6+
from torch.optim.optimizer import Optimizer
7+
8+
9+
class AdamW(Optimizer):
    r"""Implements AdamW algorithm.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)
        for group in self.param_groups:
            # Checkpoints saved before the amsgrad option existed lack the key.
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # Perform stepweight decay: decoupled from the gradient-based
                # update, the weight is shrunk directly by lr * weight_decay.
                p.data.mul_(1 - group['lr'] * group['weight_decay'])

                # Perform optimization step
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Decay the first and second moment running average coefficient.
                # FIX: use the keyword alpha/value forms; the positional
                # add_(Number, Tensor) / addcmul_(Number, t1, t2) overloads are
                # deprecated and removed in current PyTorch.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                step_size = group['lr'] / bias_correction1

                # FIX: deprecated addcdiv_(Number, t1, t2) -> value keyword.
                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss

timm/optim/lookahead.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
""" Lookahead Optimizer Wrapper.
2+
Implementation modified from: https://github.com/alphadl/lookahead.pytorch
3+
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
4+
"""
5+
import torch
6+
from torch.optim.optimizer import Optimizer
7+
from collections import defaultdict
8+
9+
10+
class Lookahead(Optimizer):
    """Lookahead optimizer wrapper.

    Maintains a copy of "slow" weights; after every ``k`` steps of the wrapped
    ("fast") optimizer, the slow weights move a fraction ``alpha`` toward the
    fast weights, and the fast weights are reset to the slow ones.

    Args:
        base_optimizer (Optimizer): inner optimizer performing the fast updates
        alpha (float): slow update rate in [0, 1] (default: 0.5)
        k (int): number of fast steps between slow updates (default: 6)
    """

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        self.alpha = alpha
        self.k = k
        self.base_optimizer = base_optimizer
        # NOTE: Optimizer.__init__ is deliberately not called; param_groups is
        # aliased to the base optimizer's groups so both see the same params.
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.state = defaultdict(dict)
        for group in self.param_groups:
            group["step_counter"] = 0

    def update_slow_weights(self, group):
        # Move slow weights toward fast ones, then reset fast := slow.
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if "slow_buffer" not in param_state:
                # Lazily initialize the slow buffer from the current fast weights.
                param_state["slow_buffer"] = torch.empty_like(fast_p.data)
                param_state["slow_buffer"].copy_(fast_p.data)
            slow = param_state["slow_buffer"]
            # FIX: the positional add_(Number, Tensor) overload is deprecated
            # and removed in current PyTorch; use the alpha keyword form.
            slow.add_(fast_p.data - slow, alpha=self.alpha)
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        # Force a slow-weight update on every group (e.g. before eval/save).
        for group in self.param_groups:
            self.update_slow_weights(group)

    def step(self, closure=None):
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group['step_counter'] += 1
            if group['step_counter'] % self.k == 0:
                self.update_slow_weights(group)
        return loss

    def state_dict(self):
        # Merge the base optimizer's state with the slow buffers; tensor keys
        # are converted to ids so the dict is serializable.
        fast_state_dict = self.base_optimizer.state_dict()
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict["state"]
        param_groups = fast_state_dict["param_groups"]
        return {
            "state": fast_state,
            "slow_state": slow_state,
            "param_groups": param_groups,
        }

    def load_state_dict(self, state_dict):
        # Accept a plain (non-Lookahead) checkpoint by starting slow state empty.
        if 'slow_state' not in state_dict:
            print('Loading state_dict from optimizer without Lookahead applied')
            state_dict['slow_state'] = defaultdict(dict)
        slow_state_dict = {
            "state": state_dict["slow_state"],
            "param_groups": state_dict["param_groups"],
        }
        fast_state_dict = {
            "state": state_dict["state"],
            "param_groups": state_dict["param_groups"],
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.base_optimizer.load_state_dict(fast_state_dict)

    def add_param_group(self, param_group):
        r"""Add a param group to the :class:`Optimizer` s `param_groups`.
        This can be useful when fine tuning a pre-trained network as frozen
        layers can be made trainable and added to the :class:`Optimizer` as
        training progresses.
        Args:
            param_group (dict): Specifies what Tensors should be optimized along
                with group specific optimization options.
        """
        param_group['step_counter'] = 0
        self.base_optimizer.add_param_group(param_group)

timm/optim/novograd.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""NovoGrad Optimizer.
2+
Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd
3+
Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
4+
- https://arxiv.org/abs/1905.11286
5+
"""
6+
7+
import torch
8+
from torch.optim.optimizer import Optimizer
9+
import math
10+
11+
12+
class NovoGrad(Optimizer):
    """NovoGrad optimizer.

    Normalizes each parameter's gradient by an EMA of its squared norm before
    applying a momentum update, as described in
    `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training
    of Deep Networks` - https://arxiv.org/abs/1905.11286

    Args:
        params (iterable): parameters to optimize or dicts defining groups
        grad_averaging (bool): scale normalized grad by (1 - beta1) (default: False)
        lr (float): learning rate (default: 0.1)
        betas (Tuple[float, float]): momentum / norm-EMA coefficients (default: (0.95, 0.98))
        eps (float): denominator term for numerical stability (default: 1e-8)
        weight_decay (float): L2 penalty coefficient (default: 0)
    """

    def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(NovoGrad, self).__init__(params, defaults)
        # NOTE(review): step() reads betas/eps/weight_decay from these attributes,
        # so per-param-group overrides of them are ignored (only lr is honored
        # per-group). Preserved as-is from the original implementation.
        self._lr = lr
        self._beta1 = betas[0]
        self._beta2 = betas[1]
        self._eps = eps
        self._wd = weight_decay
        self._grad_averaging = grad_averaging

        self._momentum_initialized = False

    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): re-evaluates the model, returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # Lazy first-call initialization of per-parameter moments from the
        # current gradients (v = ||g||^2, m = normalized grad + weight decay).
        if not self._momentum_initialized:
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    grad = p.grad.data
                    if grad.is_sparse:
                        raise RuntimeError('NovoGrad does not support sparse gradients')

                    v = torch.norm(grad)**2
                    m = grad / (torch.sqrt(v) + self._eps) + self._wd * p.data
                    state['step'] = 0
                    state['v'] = v
                    state['m'] = m
                    state['grad_ema'] = None
            self._momentum_initialized = True

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                state['step'] += 1

                step, v, m = state['step'], state['v'], state['m']
                grad_ema = state['grad_ema']

                grad = p.grad.data
                g2 = torch.norm(grad)**2
                grad_ema = g2 if grad_ema is None else grad_ema * \
                    self._beta2 + g2 * (1. - self._beta2)
                # FIX: normalize into a new tensor; the original mutated p.grad
                # in place (grad *= ...), corrupting the caller's gradient buffer.
                grad = grad / (torch.sqrt(grad_ema) + self._eps)

                if self._grad_averaging:
                    grad = grad * (1. - self._beta1)

                g2 = torch.norm(grad)**2
                v = self._beta2 * v + (1. - self._beta2) * g2
                m = self._beta1 * m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data)
                bias_correction1 = 1 - self._beta1 ** step
                bias_correction2 = 1 - self._beta2 ** step
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                state['v'], state['m'] = v, m
                state['grad_ema'] = grad_ema
                # FIX: deprecated positional add_(Number, Tensor) overload ->
                # alpha keyword form (removed in current PyTorch).
                p.data.add_(m, alpha=-step_size)
        return loss

0 commit comments

Comments
 (0)