Skip to content

Commit 2681a8d

Browse files
committed
Final blurpool2d cleanup and add resnetblur50 weights, match tresnet Downsample arg order to BlurPool2d for interop
1 parent 9590f30 commit 2681a8d

File tree

6 files changed

+101
-99
lines changed

6 files changed

+101
-99
lines changed

timm/models/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ def load_state_dict(checkpoint_path, use_ema=False):
3131
raise FileNotFoundError()
3232

3333

34-
def load_checkpoint(model, checkpoint_path, use_ema=False):
34+
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
3535
state_dict = load_state_dict(checkpoint_path, use_ema)
36-
model.load_state_dict(state_dict)
36+
model.load_state_dict(state_dict, strict=strict)
3737

3838

3939
def resume_checkpoint(model, checkpoint_path):

timm/models/layers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@
1818
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
1919
from .anti_aliasing import AntiAliasDownsampleLayer
2020
from .space_to_depth import SpaceToDepthModule
21-
from .blurpool import BlurPool2d
21+
from .blur_pool import BlurPool2d

timm/models/layers/anti_aliasing.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66

77
class AntiAliasDownsampleLayer(nn.Module):
8-
def __init__(self, no_jit: bool = False, filt_size: int = 3, stride: int = 2, channels: int = 0):
8+
def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False):
99
super(AntiAliasDownsampleLayer, self).__init__()
1010
if no_jit:
11-
self.op = Downsample(filt_size, stride, channels)
11+
self.op = Downsample(channels, filt_size, stride)
1212
else:
13-
self.op = DownsampleJIT(filt_size, stride, channels)
13+
self.op = DownsampleJIT(channels, filt_size, stride)
1414

1515
# FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls
1616

@@ -20,10 +20,10 @@ def forward(self, x):
2020

2121
@torch.jit.script
2222
class DownsampleJIT(object):
23-
def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0):
23+
def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2):
24+
self.channels = channels
2425
self.stride = stride
2526
self.filt_size = filt_size
26-
self.channels = channels
2727
assert self.filt_size == 3
2828
assert stride == 2
2929
self.filt = {} # lazy init by device for DataParallel compat
@@ -32,8 +32,7 @@ def _create_filter(self, like: torch.Tensor):
3232
filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device)
3333
filt = filt[:, None] * filt[None, :]
3434
filt = filt / torch.sum(filt)
35-
filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
36-
return filt
35+
return filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
3736

3837
def __call__(self, input: torch.Tensor):
3938
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
@@ -42,11 +41,11 @@ def __call__(self, input: torch.Tensor):
4241

4342

4443
class Downsample(nn.Module):
45-
def __init__(self, filt_size=3, stride=2, channels=None):
44+
def __init__(self, channels=None, filt_size=3, stride=2):
4645
super(Downsample, self).__init__()
46+
self.channels = channels
4747
self.filt_size = filt_size
4848
self.stride = stride
49-
self.channels = channels
5049

5150
assert self.filt_size == 3
5251
filt = torch.tensor([1., 2., 1.])

timm/models/layers/blur_pool.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""
2+
BlurPool layer inspired by
3+
- Kornia's Max_BlurPool2d
4+
- Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
5+
6+
FIXME merge this impl with those in `anti_aliasing.py`
7+
8+
Hacked together by Chris Ha and Ross Wightman
9+
"""
10+
11+
import torch
12+
import torch.nn as nn
13+
import torch.nn.functional as F
14+
import numpy as np
15+
from typing import Dict
16+
from .padding import get_padding
17+
18+
19+
class BlurPool2d(nn.Module):
20+
r"""Creates a module that computes blurs and downsample a given feature map.
21+
See :cite:`zhang2019shiftinvar` for more details.
22+
Corresponds to the Downsample class, which does blurring and subsampling
23+
24+
Args:
25+
channels = Number of input channels
26+
filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
27+
stride (int): downsampling filter stride
28+
29+
Returns:
30+
torch.Tensor: the transformed tensor.
31+
"""
32+
filt: Dict[str, torch.Tensor]
33+
34+
def __init__(self, channels, filt_size=3, stride=2) -> None:
35+
super(BlurPool2d, self).__init__()
36+
assert filt_size > 1
37+
self.channels = channels
38+
self.filt_size = filt_size
39+
self.stride = stride
40+
pad_size = [get_padding(filt_size, stride, dilation=1)] * 4
41+
self.padding = nn.ReflectionPad2d(pad_size)
42+
self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs) # for torchscript compat
43+
self.filt = {} # lazy init by device for DataParallel compat
44+
45+
def _create_filter(self, like: torch.Tensor):
46+
blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device)
47+
return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1)
48+
49+
def _apply(self, fn):
50+
# override nn.Module _apply, reset filter cache if used
51+
self.filt = {}
52+
super(BlurPool2d, self)._apply(fn)
53+
54+
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
55+
C = input_tensor.shape[1]
56+
blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor))
57+
return F.conv2d(
58+
self.padding(input_tensor), blur_filt, stride=self.stride, groups=C)

timm/models/layers/blurpool.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

timm/models/resnet.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,11 @@ def _cfg(url='', **kwargs):
118118
'ecaresnet101d_pruned': _cfg(
119119
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
120120
interpolation='bicubic'),
121-
'resnetblur18': _cfg(),
122-
'resnetblur50': _cfg()
121+
'resnetblur18': _cfg(
122+
interpolation='bicubic'),
123+
'resnetblur50': _cfg(
124+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth',
125+
interpolation='bicubic')
123126
}
124127

125128

@@ -133,21 +136,22 @@ class BasicBlock(nn.Module):
133136

134137
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
135138
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
136-
attn_layer=None, drop_block=None, drop_path=None, blur=False):
139+
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
137140
super(BasicBlock, self).__init__()
138141

139142
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
140143
assert base_width == 64, 'BasicBlock does not support changing base width'
141144
first_planes = planes // reduce_first
142145
outplanes = planes * self.expansion
143146
first_dilation = first_dilation or dilation
147+
use_aa = aa_layer is not None
144148

145149
self.conv1 = nn.Conv2d(
146-
inplanes, first_planes, kernel_size=3, stride=1 if blur else stride, padding=first_dilation,
150+
inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation,
147151
dilation=first_dilation, bias=False)
148152
self.bn1 = norm_layer(first_planes)
149153
self.act1 = act_layer(inplace=True)
150-
self.blurpool = BlurPool2d(channels=first_planes) if stride == 2 and blur else None
154+
self.aa = aa_layer(channels=first_planes) if stride == 2 and use_aa else None
151155

152156
self.conv2 = nn.Conv2d(
153157
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
@@ -173,8 +177,8 @@ def forward(self, x):
173177
if self.drop_block is not None:
174178
x = self.drop_block(x)
175179
x = self.act1(x)
176-
if self.blurpool is not None:
177-
x = self.blurpool(x)
180+
if self.aa is not None:
181+
x = self.aa(x)
178182

179183
x = self.conv2(x)
180184
x = self.bn2(x)
@@ -201,25 +205,25 @@ class Bottleneck(nn.Module):
201205

202206
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
203207
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
204-
attn_layer=None, drop_block=None, drop_path=None, blur=False):
208+
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
205209
super(Bottleneck, self).__init__()
206210

207211
width = int(math.floor(planes * (base_width / 64)) * cardinality)
208212
first_planes = width // reduce_first
209213
outplanes = planes * self.expansion
210214
first_dilation = first_dilation or dilation
211-
self.blur = blur
215+
use_aa = aa_layer is not None
212216

213217
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
214218
self.bn1 = norm_layer(first_planes)
215219
self.act1 = act_layer(inplace=True)
216220

217221
self.conv2 = nn.Conv2d(
218-
first_planes, width, kernel_size=3, stride=1 if blur else stride,
222+
first_planes, width, kernel_size=3, stride=1 if use_aa else stride,
219223
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
220224
self.bn2 = norm_layer(width)
221225
self.act2 = act_layer(inplace=True)
222-
self.blurpool = BlurPool2d(channels=width) if stride == 2 and blur else None
226+
self.aa = aa_layer(channels=width) if stride == 2 and use_aa else None
223227

224228
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
225229
self.bn3 = norm_layer(outplanes)
@@ -250,8 +254,8 @@ def forward(self, x):
250254
if self.drop_block is not None:
251255
x = self.drop_block(x)
252256
x = self.act2(x)
253-
if self.blurpool is not None:
254-
x = self.blurpool(x)
257+
if self.aa is not None:
258+
x = self.aa(x)
255259

256260
x = self.conv3(x)
257261
x = self.bn3(x)
@@ -365,25 +369,19 @@ class ResNet(nn.Module):
365369
Whether to use average pooling for projection skip connection between stages/downsample.
366370
output_stride : int, default 32
367371
Set the output stride of the network, 32, 16, or 8. Typically used in segmentation.
368-
act_layer : class, activation layer
369-
norm_layer : class, normalization layer
372+
act_layer : nn.Module, activation layer
373+
norm_layer : nn.Module, normalization layer
374+
aa_layer : nn.Module, anti-aliasing layer
370375
drop_rate : float, default 0.
371376
Dropout probability before classifier, for training
372377
global_pool : str, default 'avg'
373378
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
374-
blur : str, default ''
375-
Location of Blurring:
376-
* '', default - Not applied
377-
* 'max' - only stem layer MaxPool will be blurred
378-
* 'strided' - only strided convolutions in the downsampling blocks (assembled-cnn style)
379-
* 'max_strided' - on both stem MaxPool and strided convolutions (zhang2019shiftinvar style for ResNets)
380-
381379
"""
382380
def __init__(self, block, layers, num_classes=1000, in_chans=3,
383381
cardinality=1, base_width=64, stem_width=64, stem_type='',
384382
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
385-
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0.,
386-
drop_block_rate=0., global_pool='avg', blur='', zero_init_last_bn=True, block_args=None):
383+
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
384+
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
387385
block_args = block_args or dict()
388386
self.num_classes = num_classes
389387
deep_stem = 'deep' in stem_type
@@ -392,7 +390,6 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
392390
self.base_width = base_width
393391
self.drop_rate = drop_rate
394392
self.expansion = block.expansion
395-
self.blur = 'strided' in blur
396393
super(ResNet, self).__init__()
397394

398395
# Stem
@@ -414,12 +411,12 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
414411
self.bn1 = norm_layer(self.inplanes)
415412
self.act1 = act_layer(inplace=True)
416413
# Stem Pooling
417-
if 'max' in blur :
414+
if aa_layer is not None:
418415
self.maxpool = nn.Sequential(*[
419416
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
420-
BlurPool2d(channels=self.inplanes, stride=2)
417+
aa_layer(channels=self.inplanes, stride=2)
421418
])
422-
else :
419+
else:
423420
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
424421

425422
# Feature Blocks
@@ -437,7 +434,7 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3,
437434
assert output_stride == 32
438435
layer_args = list(zip(channels, layers, strides, dilations))
439436
layer_kwargs = dict(
440-
reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
437+
reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
441438
avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
442439
self.layer1 = self._make_layer(block, *layer_args[0], **layer_kwargs)
443440
self.layer2 = self._make_layer(block, *layer_args[1], **layer_kwargs)
@@ -472,7 +469,7 @@ def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=
472469

473470
block_kwargs = dict(
474471
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
475-
dilation=dilation, blur=self.blur, **kwargs)
472+
dilation=dilation, **kwargs)
476473
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
477474
self.inplanes = planes * block.expansion
478475
layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
@@ -1148,18 +1145,21 @@ def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
11481145
"""Constructs a ResNet-18 model with blur anti-aliasing
11491146
"""
11501147
default_cfg = default_cfgs['resnetblur18']
1151-
model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, blur='max_strided',**kwargs)
1148+
model = ResNet(
1149+
BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
11521150
model.default_cfg = default_cfg
11531151
if pretrained:
11541152
load_pretrained(model, default_cfg, num_classes, in_chans)
11551153
return model
11561154

1155+
11571156
@register_model
11581157
def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
11591158
"""Constructs a ResNet-50 model with blur anti-aliasing
11601159
"""
11611160
default_cfg = default_cfgs['resnetblur50']
1162-
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='max_strided', **kwargs)
1161+
model = ResNet(
1162+
Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
11631163
model.default_cfg = default_cfg
11641164
if pretrained:
11651165
load_pretrained(model, default_cfg, num_classes, in_chans)

0 commit comments

Comments
 (0)