@@ -82,19 +82,30 @@ def fuse(self):
 
 
 class RepVggDw(nn.Module):
-    def __init__(self, ed, kernel_size):
+    def __init__(self, ed, kernel_size, legacy=False):
         super().__init__()
         self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed)
-        self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed)
+        if legacy:
+            self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed)
+            # Make torchscript happy.
+            self.bn = nn.Identity()
+        else:
+            self.conv1 = nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
+            self.bn = nn.BatchNorm2d(ed)
         self.dim = ed
+        self.legacy = legacy
 
     def forward(self, x):
-        return self.conv(x) + self.conv1(x) + x
+        return self.bn(self.conv(x) + self.conv1(x) + x)
 
     @torch.no_grad()
     def fuse(self):
         conv = self.conv.fuse()
-        conv1 = self.conv1.fuse()
+
+        if self.legacy:
+            conv1 = self.conv1.fuse()
+        else:
+            conv1 = self.conv1
 
         conv_w = conv.weight
         conv_b = conv.bias
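Reviewer note (illustrative, not part of the diff): the elided remainder of `RepVggDw.fuse()` merges the k×k depthwise branch, the 1×1 depthwise branch, and the identity shortcut into a single kernel. Below is a minimal generic RepVGG-style sketch of that merge using plain `nn.Conv2d` branches; it is an assumption-laden stand-in, not a copy of the elided code, which operates on the already-fused `ConvNorm` weights.

```python
# Generic RepVGG-style branch merging sketch (illustrative only).
import torch
import torch.nn as nn
import torch.nn.functional as F

ed, k = 8, 3
conv = nn.Conv2d(ed, ed, k, 1, k // 2, groups=ed)   # k x k depthwise branch
conv1 = nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)       # 1 x 1 depthwise branch

# Pad the 1x1 kernel to k x k, build an identity kernel, then sum the three branches.
conv1_w = F.pad(conv1.weight, [k // 2] * 4)
identity_w = F.pad(torch.ones(ed, 1, 1, 1), [k // 2] * 4)
final_w = conv.weight + conv1_w + identity_w
final_b = conv.bias + conv1.bias

merged = nn.Conv2d(ed, ed, k, 1, k // 2, groups=ed)
merged.weight.data.copy_(final_w)
merged.bias.data.copy_(final_b)

x = torch.randn(1, ed, 14, 14)
print(torch.allclose(conv(x) + conv1(x) + x, merged(x), atol=1e-5))  # True
```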
@@ -112,6 +123,14 @@ def fuse(self):
 
         conv.weight.data.copy_(final_conv_w)
         conv.bias.data.copy_(final_conv_b)
+
+        if not self.legacy:
+            bn = self.bn
+            w = bn.weight / (bn.running_var + bn.eps) ** 0.5
+            w = conv.weight * w[:, None, None, None]
+            b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / (bn.running_var + bn.eps) ** 0.5
+            conv.weight.data.copy_(w)
+            conv.bias.data.copy_(b)
 
         return conv
 
 
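Reviewer note (illustrative, not part of the diff): the non-legacy branch above folds the trailing `BatchNorm2d` into the already-merged conv via w' = w·γ/√(var+ε) and b' = β + (b − mean)·γ/√(var+ε). A standalone sketch that checks this folding numerically; the helper name `fold_bn_into_conv` is hypothetical.

```python
# Verify that folding a trailing BatchNorm2d into the preceding depthwise conv
# preserves the output in eval mode (same algebra as the non-legacy fuse() branch).
import torch
import torch.nn as nn

def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
    fused = nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        conv.stride, conv.padding, groups=conv.groups, bias=True,
    )
    fused.weight.data.copy_(conv.weight * scale[:, None, None, None])
    fused.bias.data.copy_(bn.bias + (conv.bias - bn.running_mean) * scale)
    return fused

ed = 8
conv = nn.Conv2d(ed, ed, 3, 1, 1, groups=ed)
bn = nn.BatchNorm2d(ed)
with torch.no_grad():  # pretend-trained BN statistics and affine params
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.1, 0.1)
    bn.running_mean.uniform_(-0.5, 0.5)
    bn.running_var.uniform_(0.5, 1.5)
conv.eval(); bn.eval()

x = torch.randn(1, ed, 14, 14)
print(torch.allclose(bn(conv(x)), fold_bn_into_conv(conv, bn)(x), atol=1e-5))  # True
```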
@@ -127,10 +146,10 @@ def forward(self, x):
 
 
 class RepViTBlock(nn.Module):
-    def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer):
+    def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer, legacy=False):
         super(RepViTBlock, self).__init__()
 
-        self.token_mixer = RepVggDw(in_dim, kernel_size)
+        self.token_mixer = RepVggDw(in_dim, kernel_size, legacy)
         self.se = SqueezeExcite(in_dim, 0.25) if use_se else nn.Identity()
         self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer)
 
@@ -155,9 +174,9 @@ def forward(self, x):
 
 
 class RepVitDownsample(nn.Module):
-    def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer):
+    def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer, legacy=False):
         super().__init__()
-        self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer)
+        self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer, legacy=legacy)
         self.spatial_downsample = ConvNorm(in_dim, in_dim, kernel_size, 2, (kernel_size - 1) // 2, groups=in_dim)
         self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1)
         self.ffn = RepVitMlp(out_dim, out_dim * mlp_ratio, act_layer)
@@ -172,7 +191,7 @@ def forward(self, x):
 
 
 class RepVitClassifier(nn.Module):
-    def __init__(self, dim, num_classes, distillation=False, drop=0.):
+    def __init__(self, dim, num_classes, distillation=False, drop=0.0):
         super().__init__()
         self.head_drop = nn.Dropout(drop)
         self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -211,18 +230,18 @@ def fuse(self):
 
 
 class RepVitStage(nn.Module):
-    def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True):
+    def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True, legacy=False):
         super().__init__()
         if downsample:
-            self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
+            self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer, legacy)
         else:
             assert in_dim == out_dim
             self.downsample = nn.Identity()
 
         blocks = []
         use_se = True
         for _ in range(depth):
-            blocks.append(RepViTBlock(out_dim, mlp_ratio, kernel_size, use_se, act_layer))
+            blocks.append(RepViTBlock(out_dim, mlp_ratio, kernel_size, use_se, act_layer, legacy))
             use_se = not use_se
 
         self.blocks = nn.Sequential(*blocks)
@@ -246,7 +265,8 @@ def __init__(
         num_classes=1000,
         act_layer=nn.GELU,
         distillation=True,
-        drop_rate=0.,
+        drop_rate=0.0,
+        legacy=False,
     ):
         super(RepVit, self).__init__()
         self.grad_checkpointing = False
@@ -275,6 +295,7 @@ def __init__(
                     act_layer=act_layer,
                     kernel_size=kernel_size,
                     downsample=downsample,
+                    legacy=legacy,
                 )
             )
             stage_stride = 2 if downsample else 1
@@ -290,12 +311,9 @@ def __init__(
 
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
-        matcher = dict(
-            stem=r'^stem',  # stem and embed
-            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
-        )
+        matcher = dict(stem=r'^stem', blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))])  # stem and embed
         return matcher
-
+
     @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
         self.grad_checkpointing = enable
@@ -369,15 +387,42 @@ def _cfg(url='', **kwargs):
     {
         'repvit_m1.dist_in1k': _cfg(
             hf_hub_id='timm/',
-            # url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_distill_300_timm.pth'
         ),
         'repvit_m2.dist_in1k': _cfg(
             hf_hub_id='timm/',
-            # url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_distill_300_timm.pth'
         ),
         'repvit_m3.dist_in1k': _cfg(
             hf_hub_id='timm/',
-            # url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m3_distill_300_timm.pth'
+        ),
+        'repvit_m0_9.dist_300e_in1k': _cfg(
+            hf_hub_id='timm/',
+        ),
+        'repvit_m0_9.dist_450e_in1k': _cfg(
+            hf_hub_id='timm/',
+        ),
+        'repvit_m1_0.dist_300e_in1k': _cfg(
+            hf_hub_id='timm/',
+        ),
+        'repvit_m1_0.dist_450e_in1k': _cfg(
+            hf_hub_id='timm/',
+        ),
+        'repvit_m1_1.dist_300e_in1k': _cfg(
+            hf_hub_id='timm/',
+        ),
+        'repvit_m1_1.dist_450e_in1k': _cfg(
+            hf_hub_id='timm/',
+        ),
+        'repvit_m1_5.dist_300e_in1k': _cfg(
+            hf_hub_id='timm/',
+        ),
+        'repvit_m1_5.dist_450e_in1k': _cfg(
+            hf_hub_id='timm/',
+        ),
+        'repvit_m2_3.dist_300e_in1k': _cfg(
+            hf_hub_id='timm/',
+        ),
+        'repvit_m2_3.dist_450e_in1k': _cfg(
+            hf_hub_id='timm/',
         ),
     }
 )
@@ -386,7 +431,9 @@ def _cfg(url='', **kwargs):
 def _create_repvit(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
     model = build_model_with_cfg(
-        RepVit, variant, pretrained,
+        RepVit,
+        variant,
+        pretrained,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs,
     )
@@ -398,7 +445,7 @@ def repvit_m1(pretrained=False, **kwargs):
     """
     Constructs a RepViT-M1 model
     """
-    model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2))
+    model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2), legacy=True)
     return _create_repvit('repvit_m1', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
@@ -407,7 +454,7 @@ def repvit_m2(pretrained=False, **kwargs):
     """
     Constructs a RepViT-M2 model
    """
-    model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2))
+    model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2), legacy=True)
     return _create_repvit('repvit_m2', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
@@ -416,5 +463,50 @@ def repvit_m3(pretrained=False, **kwargs):
     """
     Constructs a RepViT-M3 model
     """
-    model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 18, 2))
+    model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 18, 2), legacy=True)
     return _create_repvit('repvit_m3', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def repvit_m0_9(pretrained=False, **kwargs):
+    """
+    Constructs a RepViT-M0.9 model
+    """
+    model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2))
+    return _create_repvit('repvit_m0_9', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def repvit_m1_0(pretrained=False, **kwargs):
+    """
+    Constructs a RepViT-M1.0 model
+    """
+    model_args = dict(embed_dim=(56, 112, 224, 448), depth=(2, 2, 14, 2))
+    return _create_repvit('repvit_m1_0', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def repvit_m1_1(pretrained=False, **kwargs):
+    """
+    Constructs a RepViT-M1.1 model
+    """
+    model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2))
+    return _create_repvit('repvit_m1_1', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def repvit_m1_5(pretrained=False, **kwargs):
+    """
+    Constructs a RepViT-M1.5 model
+    """
+    model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 24, 4))
+    return _create_repvit('repvit_m1_5', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def repvit_m2_3(pretrained=False, **kwargs):
+    """
+    Constructs a RepViT-M2.3 model
+    """
+    model_args = dict(embed_dim=(80, 160, 320, 640), depth=(6, 6, 34, 2))
+    return _create_repvit('repvit_m2_3', pretrained=pretrained, **dict(model_args, **kwargs))
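Reviewer note (illustrative, not part of the diff): the newly registered variants are created like any other timm model; a minimal usage sketch, assuming a timm build that contains this change. Model names and pretrained tags are taken from the registrations and configs added above.

```python
# Create one of the new non-legacy RepViT variants and run a forward pass.
import timm
import torch

model = timm.create_model('repvit_m1_5', pretrained=False)  # or 'repvit_m0_9', 'repvit_m2_3', ...
# With weights, a tag from the new configs would be used instead, e.g.:
#   timm.create_model('repvit_m1_5.dist_450e_in1k', pretrained=True)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])
```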