Skip to content

Commit a517b82

Browse files
committed
Merge branch 'jameslahm-main'
2 parents d3ebdcf + 462fb3e commit a517b82

File tree

1 file changed: +117 additions, −25 deletions

timm/models/repvit.py

Lines changed: 117 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -82,19 +82,30 @@ def fuse(self):
8282

8383

8484
class RepVggDw(nn.Module):
85-
def __init__(self, ed, kernel_size):
85+
def __init__(self, ed, kernel_size, legacy=False):
8686
super().__init__()
8787
self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed)
88-
self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed)
88+
if legacy:
89+
self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed)
90+
# Make torchscript happy.
91+
self.bn = nn.Identity()
92+
else:
93+
self.conv1 = nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
94+
self.bn = nn.BatchNorm2d(ed)
8995
self.dim = ed
96+
self.legacy = legacy
9097

9198
def forward(self, x):
92-
return self.conv(x) + self.conv1(x) + x
99+
return self.bn(self.conv(x) + self.conv1(x) + x)
93100

94101
@torch.no_grad()
95102
def fuse(self):
96103
conv = self.conv.fuse()
97-
conv1 = self.conv1.fuse()
104+
105+
if self.legacy:
106+
conv1 = self.conv1.fuse()
107+
else:
108+
conv1 = self.conv1
98109

99110
conv_w = conv.weight
100111
conv_b = conv.bias
@@ -112,6 +123,14 @@ def fuse(self):
112123

113124
conv.weight.data.copy_(final_conv_w)
114125
conv.bias.data.copy_(final_conv_b)
126+
127+
if not self.legacy:
128+
bn = self.bn
129+
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
130+
w = conv.weight * w[:, None, None, None]
131+
b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / (bn.running_var + bn.eps) ** 0.5
132+
conv.weight.data.copy_(w)
133+
conv.bias.data.copy_(b)
115134
return conv
116135

117136

@@ -127,10 +146,10 @@ def forward(self, x):
127146

128147

129148
class RepViTBlock(nn.Module):
130-
def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer):
149+
def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer, legacy=False):
131150
super(RepViTBlock, self).__init__()
132151

133-
self.token_mixer = RepVggDw(in_dim, kernel_size)
152+
self.token_mixer = RepVggDw(in_dim, kernel_size, legacy)
134153
self.se = SqueezeExcite(in_dim, 0.25) if use_se else nn.Identity()
135154
self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer)
136155

@@ -155,9 +174,9 @@ def forward(self, x):
155174

156175

157176
class RepVitDownsample(nn.Module):
158-
def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer):
177+
def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer, legacy=False):
159178
super().__init__()
160-
self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer)
179+
self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer, legacy=legacy)
161180
self.spatial_downsample = ConvNorm(in_dim, in_dim, kernel_size, 2, (kernel_size - 1) // 2, groups=in_dim)
162181
self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1)
163182
self.ffn = RepVitMlp(out_dim, out_dim * mlp_ratio, act_layer)
@@ -172,7 +191,7 @@ def forward(self, x):
172191

173192

174193
class RepVitClassifier(nn.Module):
175-
def __init__(self, dim, num_classes, distillation=False, drop=0.):
194+
def __init__(self, dim, num_classes, distillation=False, drop=0.0):
176195
super().__init__()
177196
self.head_drop = nn.Dropout(drop)
178197
self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -211,18 +230,18 @@ def fuse(self):
211230

212231

213232
class RepVitStage(nn.Module):
214-
def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True):
233+
def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True, legacy=False):
215234
super().__init__()
216235
if downsample:
217-
self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
236+
self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer, legacy)
218237
else:
219238
assert in_dim == out_dim
220239
self.downsample = nn.Identity()
221240

222241
blocks = []
223242
use_se = True
224243
for _ in range(depth):
225-
blocks.append(RepViTBlock(out_dim, mlp_ratio, kernel_size, use_se, act_layer))
244+
blocks.append(RepViTBlock(out_dim, mlp_ratio, kernel_size, use_se, act_layer, legacy))
226245
use_se = not use_se
227246

228247
self.blocks = nn.Sequential(*blocks)
@@ -246,7 +265,8 @@ def __init__(
246265
num_classes=1000,
247266
act_layer=nn.GELU,
248267
distillation=True,
249-
drop_rate=0.,
268+
drop_rate=0.0,
269+
legacy=False,
250270
):
251271
super(RepVit, self).__init__()
252272
self.grad_checkpointing = False
@@ -275,6 +295,7 @@ def __init__(
275295
act_layer=act_layer,
276296
kernel_size=kernel_size,
277297
downsample=downsample,
298+
legacy=legacy,
278299
)
279300
)
280301
stage_stride = 2 if downsample else 1
@@ -290,12 +311,9 @@ def __init__(
290311

291312
@torch.jit.ignore
292313
def group_matcher(self, coarse=False):
293-
matcher = dict(
294-
stem=r'^stem', # stem and embed
295-
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
296-
)
314+
matcher = dict(stem=r'^stem', blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]) # stem and embed
297315
return matcher
298-
316+
299317
@torch.jit.ignore
300318
def set_grad_checkpointing(self, enable=True):
301319
self.grad_checkpointing = enable
@@ -369,15 +387,42 @@ def _cfg(url='', **kwargs):
369387
{
370388
'repvit_m1.dist_in1k': _cfg(
371389
hf_hub_id='timm/',
372-
# url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_distill_300_timm.pth'
373390
),
374391
'repvit_m2.dist_in1k': _cfg(
375392
hf_hub_id='timm/',
376-
# url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_distill_300_timm.pth'
377393
),
378394
'repvit_m3.dist_in1k': _cfg(
379395
hf_hub_id='timm/',
380-
# url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m3_distill_300_timm.pth'
396+
),
397+
'repvit_m0_9.dist_300e_in1k': _cfg(
398+
hf_hub_id='timm/',
399+
),
400+
'repvit_m0_9.dist_450e_in1k': _cfg(
401+
hf_hub_id='timm/',
402+
),
403+
'repvit_m1_0.dist_300e_in1k': _cfg(
404+
hf_hub_id='timm/',
405+
),
406+
'repvit_m1_0.dist_450e_in1k': _cfg(
407+
hf_hub_id='timm/',
408+
),
409+
'repvit_m1_1.dist_300e_in1k': _cfg(
410+
hf_hub_id='timm/',
411+
),
412+
'repvit_m1_1.dist_450e_in1k': _cfg(
413+
hf_hub_id='timm/',
414+
),
415+
'repvit_m1_5.dist_300e_in1k': _cfg(
416+
hf_hub_id='timm/',
417+
),
418+
'repvit_m1_5.dist_450e_in1k': _cfg(
419+
hf_hub_id='timm/',
420+
),
421+
'repvit_m2_3.dist_300e_in1k': _cfg(
422+
hf_hub_id='timm/',
423+
),
424+
'repvit_m2_3.dist_450e_in1k': _cfg(
425+
hf_hub_id='timm/',
381426
),
382427
}
383428
)
@@ -386,7 +431,9 @@ def _cfg(url='', **kwargs):
386431
def _create_repvit(variant, pretrained=False, **kwargs):
387432
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
388433
model = build_model_with_cfg(
389-
RepVit, variant, pretrained,
434+
RepVit,
435+
variant,
436+
pretrained,
390437
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
391438
**kwargs,
392439
)
@@ -398,7 +445,7 @@ def repvit_m1(pretrained=False, **kwargs):
398445
"""
399446
Constructs a RepViT-M1 model
400447
"""
401-
model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2))
448+
model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2), legacy=True)
402449
return _create_repvit('repvit_m1', pretrained=pretrained, **dict(model_args, **kwargs))
403450

404451

@@ -407,7 +454,7 @@ def repvit_m2(pretrained=False, **kwargs):
407454
"""
408455
Constructs a RepViT-M2 model
409456
"""
410-
model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2))
457+
model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2), legacy=True)
411458
return _create_repvit('repvit_m2', pretrained=pretrained, **dict(model_args, **kwargs))
412459

413460

@@ -416,5 +463,50 @@ def repvit_m3(pretrained=False, **kwargs):
416463
"""
417464
Constructs a RepViT-M3 model
418465
"""
419-
model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 18, 2))
466+
model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 18, 2), legacy=True)
420467
return _create_repvit('repvit_m3', pretrained=pretrained, **dict(model_args, **kwargs))
468+
469+
470+
@register_model
def repvit_m0_9(pretrained=False, **kwargs):
    """Construct a RepViT-M0.9 model."""
    cfg = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2))
    # Caller-supplied kwargs take precedence over the defaults above.
    cfg.update(kwargs)
    return _create_repvit('repvit_m0_9', pretrained=pretrained, **cfg)
477+
478+
479+
@register_model
def repvit_m1_0(pretrained=False, **kwargs):
    """Construct a RepViT-M1.0 model."""
    cfg = dict(embed_dim=(56, 112, 224, 448), depth=(2, 2, 14, 2))
    # Caller-supplied kwargs take precedence over the defaults above.
    cfg.update(kwargs)
    return _create_repvit('repvit_m1_0', pretrained=pretrained, **cfg)
486+
487+
488+
@register_model
def repvit_m1_1(pretrained=False, **kwargs):
    """Construct a RepViT-M1.1 model."""
    cfg = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2))
    # Caller-supplied kwargs take precedence over the defaults above.
    cfg.update(kwargs)
    return _create_repvit('repvit_m1_1', pretrained=pretrained, **cfg)
495+
496+
497+
@register_model
def repvit_m1_5(pretrained=False, **kwargs):
    """Construct a RepViT-M1.5 model."""
    cfg = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 24, 4))
    # Caller-supplied kwargs take precedence over the defaults above.
    cfg.update(kwargs)
    return _create_repvit('repvit_m1_5', pretrained=pretrained, **cfg)
504+
505+
506+
@register_model
def repvit_m2_3(pretrained=False, **kwargs):
    """Construct a RepViT-M2.3 model."""
    cfg = dict(embed_dim=(80, 160, 320, 640), depth=(6, 6, 34, 2))
    # Caller-supplied kwargs take precedence over the defaults above.
    cfg.update(kwargs)
    return _create_repvit('repvit_m2_3', pretrained=pretrained, **cfg)

0 commit comments

Comments
 (0)