@@ -100,10 +100,10 @@ def _cfg(url='', **kwargs):
100
100
# hybrid models (weights ported from official Google JAX impl)
101
101
'vit_base_resnet50_224_in21k' : _cfg (
102
102
url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth' ,
103
- num_classes = 21843 , mean = (0.5 , 0.5 , 0.5 ), std = (0.5 , 0.5 , 0.5 ), crop_pct = 0.9 ),
103
+ num_classes = 21843 , mean = (0.5 , 0.5 , 0.5 ), std = (0.5 , 0.5 , 0.5 ), crop_pct = 0.9 , first_conv = 'patch_embed.backbone.stem.conv' ),
104
104
'vit_base_resnet50_384' : _cfg (
105
105
url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth' ,
106
- input_size = (3 , 384 , 384 ), mean = (0.5 , 0.5 , 0.5 ), std = (0.5 , 0.5 , 0.5 ), crop_pct = 1.0 ),
106
+ input_size = (3 , 384 , 384 ), mean = (0.5 , 0.5 , 0.5 ), std = (0.5 , 0.5 , 0.5 ), crop_pct = 1.0 , first_conv = 'patch_embed.backbone.stem.conv' ),
107
107
108
108
# hybrid models (my experiments)
109
109
'vit_small_resnet26d_224' : _cfg (),
@@ -256,11 +256,33 @@ def forward(self, x):
256
256
257
257
258
258
class VisionTransformer (nn .Module ):
259
- """ Vision Transformer with support for patch or hybrid CNN input stage
259
+ """ Vision Transformer
260
+
261
+ A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
262
+ https://arxiv.org/abs/2010.11929
260
263
"""
261
264
def __init__ (self , img_size = 224 , patch_size = 16 , in_chans = 3 , num_classes = 1000 , embed_dim = 768 , depth = 12 ,
262
265
num_heads = 12 , mlp_ratio = 4. , qkv_bias = True , qk_scale = None , representation_size = None ,
263
266
drop_rate = 0. , attn_drop_rate = 0. , drop_path_rate = 0. , hybrid_backbone = None , norm_layer = None ):
267
+ """
268
+ Args:
269
+ img_size (int, tuple): input image size
270
+ patch_size (int, tuple): patch size
271
+ in_chans (int): number of input channels
272
+ num_classes (int): number of classes for classification head
273
+ embed_dim (int): embedding dimension
274
+ depth (int): depth of transformer
275
+ num_heads (int): number of attention heads
276
+ mlp_ratio (float): ratio of mlp hidden dim to embedding dim
277
+ qkv_bias (bool): enable bias for qkv if True
278
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
279
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
280
+ drop_rate (float): dropout rate
281
+ attn_drop_rate (float): attention dropout rate
282
+ drop_path_rate (float): stochastic depth rate
283
+ hybrid_backbone (nn.Module): CNN backbone to use in place of PatchEmbed module
284
+ norm_layer (nn.Module): normalization layer
285
+ """
264
286
super ().__init__ ()
265
287
self .num_classes = num_classes
266
288
self .num_features = self .embed_dim = embed_dim # num_features for consistency with other models
@@ -346,8 +368,7 @@ def forward(self, x):
346
368
347
369
348
370
def resize_pos_embed (posemb , posemb_new ):
349
- # Rescale the grid of position embeddings when loading from state_dict
350
- # Adapted from
371
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
351
372
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
352
373
_logger .info ('Resized position embedding: %s to %s' , posemb .shape , posemb_new .shape )
353
374
ntok_new = posemb_new .shape [1 ]
@@ -363,22 +384,21 @@ def resize_pos_embed(posemb, posemb_new):
363
384
posemb_grid = F .interpolate (posemb_grid , size = (gs_new , gs_new ), mode = 'bilinear' )
364
385
posemb_grid = posemb_grid .permute (0 , 2 , 3 , 1 ).reshape (1 , gs_new * gs_new , - 1 )
365
386
posemb = torch .cat ([posemb_tok , posemb_grid ], dim = 1 )
366
- state_dict ['pos_embed' ] = posemb
367
- return state_dict
387
+ return posemb
368
388
369
389
370
390
def checkpoint_filter_fn (state_dict , model ):
371
391
""" convert patch embedding weight from manual patchify + linear proj to conv"""
372
392
out_dict = {}
373
393
if 'model' in state_dict :
374
- # for deit models
394
+ # For deit models
375
395
state_dict = state_dict ['model' ]
376
396
for k , v in state_dict .items ():
377
397
if 'patch_embed.proj.weight' in k and len (v .shape ) < 4 :
378
- # for old models that I trained prior to conv based patchification
398
+ # For old models that I trained prior to conv based patchification
379
399
v = v .reshape (model .patch_embed .proj .weight .shape )
380
400
elif k == 'pos_embed' and v .shape != model .pos_embed .shape :
381
- # to resize pos embedding when using model at different size from pretrained weights
401
+ # To resize pos embedding when using model at different size from pretrained weights
382
402
v = resize_pos_embed (v , model .pos_embed )
383
403
out_dict [k ] = v
384
404
return out_dict
@@ -393,8 +413,9 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
393
413
img_size = kwargs .pop ('img_size' , default_img_size )
394
414
repr_size = kwargs .pop ('representation_size' , None )
395
415
if repr_size is not None and num_classes != default_num_classes :
396
- # remove representation layer if fine-tuning
397
- _logger .info ("Removing representation layer for fine-tuning." )
416
+ # Remove representation layer if fine-tuning. This may not always be the desired action,
417
+ # but I feel it is better than doing nothing by default for fine-tuning. Perhaps a better interface?
418
+ _logger .warning ("Removing representation layer for fine-tuning." )
398
419
repr_size = None
399
420
400
421
model = VisionTransformer (img_size = img_size , num_classes = num_classes , representation_size = repr_size , ** kwargs )
@@ -409,6 +430,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
409
430
410
431
@register_model
411
432
def vit_small_patch16_224 (pretrained = False , ** kwargs ):
433
+ """ My custom 'small' ViT model. Depth=8, heads=8, mlp_ratio=3."""
412
434
model_kwargs = dict (
413
435
patch_size = 16 , embed_dim = 768 , depth = 8 , num_heads = 8 , mlp_ratio = 3. ,
414
436
qkv_bias = False , norm_layer = nn .LayerNorm , ** kwargs )
@@ -421,27 +443,38 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
421
443
422
444
@register_model
423
445
def vit_base_patch16_224 (pretrained = False , ** kwargs ):
446
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
447
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
448
+ """
424
449
model_kwargs = dict (patch_size = 16 , embed_dim = 768 , depth = 12 , num_heads = 12 , mlp_ratio = 4 , ** kwargs )
425
450
model = _create_vision_transformer ('vit_base_patch16_224' , pretrained = pretrained , ** model_kwargs )
426
451
return model
427
452
428
453
429
454
@register_model
430
455
def vit_base_patch32_224 (pretrained = False , ** kwargs ):
456
+ """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
457
+ """
431
458
model_kwargs = dict (patch_size = 32 , embed_dim = 768 , depth = 12 , num_heads = 12 , mlp_ratio = 4 , ** kwargs )
432
459
model = _create_vision_transformer ('vit_base_patch32_224' , pretrained = pretrained , ** model_kwargs )
433
460
return model
434
461
435
462
436
463
@register_model
437
464
def vit_base_patch16_384 (pretrained = False , ** kwargs ):
465
+ """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
466
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
467
+ """
438
468
model_kwargs = dict (patch_size = 16 , embed_dim = 768 , depth = 12 , num_heads = 12 , mlp_ratio = 4 , ** kwargs )
439
469
model = _create_vision_transformer ('vit_base_patch16_384' , pretrained = pretrained , ** model_kwargs )
440
470
return model
441
471
442
472
443
473
@register_model
444
474
def vit_base_patch32_384 (pretrained = False , ** kwargs ):
475
+ """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
476
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
477
+ """
445
478
model_kwargs = dict (
446
479
patch_size = 32 , embed_dim = 768 , depth = 12 , num_heads = 12 , mlp_ratio = 4 ,
447
480
norm_layer = partial (nn .LayerNorm , eps = 1e-6 ), ** kwargs )
@@ -451,35 +484,48 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
451
484
452
485
@register_model
453
486
def vit_large_patch16_224 (pretrained = False , ** kwargs ):
487
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
488
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
489
+ """
454
490
model_kwargs = dict (patch_size = 16 , embed_dim = 1024 , depth = 24 , num_heads = 16 , mlp_ratio = 4 , ** kwargs )
455
491
model = _create_vision_transformer ('vit_large_patch16_224' , pretrained = pretrained , ** model_kwargs )
456
492
return model
457
493
458
494
459
495
@register_model
460
496
def vit_large_patch32_224 (pretrained = False , ** kwargs ):
497
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
498
+ """
461
499
model_kwargs = dict (patch_size = 32 , embed_dim = 1024 , depth = 24 , num_heads = 16 , mlp_ratio = 4 , ** kwargs )
462
500
model = _create_vision_transformer ('vit_large_patch32_224' , pretrained = pretrained , ** model_kwargs )
463
501
return model
464
502
465
503
466
504
@register_model
467
505
def vit_large_patch16_384 (pretrained = False , ** kwargs ):
506
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
507
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
508
+ """
468
509
model_kwargs = dict (patch_size = 16 , embed_dim = 1024 , depth = 24 , num_heads = 16 , mlp_ratio = 4 , ** kwargs )
469
510
model = _create_vision_transformer ('vit_large_patch16_384' , pretrained = pretrained , ** model_kwargs )
470
511
return model
471
512
472
513
473
514
@register_model
474
- def vit_base_patch16_224_in21k (pretrained = False , ** kwargs ):
475
- model_kwargs = dict (
476
- patch_size = 16 , embed_dim = 768 , depth = 12 , num_heads = 12 , mlp_ratio = 4 , representation_size = 768 , ** kwargs )
477
- model = _create_vision_transformer ('vit_base_patch16_224_in21k' , pretrained = pretrained , ** model_kwargs )
515
+ def vit_large_patch32_384 (pretrained = False , ** kwargs ):
516
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
517
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
518
+ """
519
+ model_kwargs = dict (patch_size = 32 , embed_dim = 1024 , depth = 24 , num_heads = 16 , mlp_ratio = 4 , ** kwargs )
520
+ model = _create_vision_transformer ('vit_large_patch32_384' , pretrained = pretrained , ** model_kwargs )
478
521
return model
479
522
480
523
481
524
@register_model
482
- def vit_base_patch16_384_in21k (pretrained = False , ** kwargs ):
525
+ def vit_base_patch16_224_in21k (pretrained = False , ** kwargs ):
526
+ """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
527
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
528
+ """
483
529
model_kwargs = dict (
484
530
patch_size = 16 , embed_dim = 768 , depth = 12 , num_heads = 12 , mlp_ratio = 4 , representation_size = 768 , ** kwargs )
485
531
model = _create_vision_transformer ('vit_base_patch16_224_in21k' , pretrained = pretrained , ** model_kwargs )
@@ -488,6 +534,9 @@ def vit_base_patch16_384_in21k(pretrained=False, **kwargs):
488
534
489
535
@register_model
490
536
def vit_base_patch32_224_in21k (pretrained = False , ** kwargs ):
537
+ """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
538
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
539
+ """
491
540
model_kwargs = dict (
492
541
patch_size = 32 , embed_dim = 768 , depth = 12 , num_heads = 12 , mlp_ratio = 4 , representation_size = 768 , ** kwargs )
493
542
model = _create_vision_transformer ('vit_base_patch32_224_in21k' , pretrained = pretrained , ** model_kwargs )
@@ -496,22 +545,20 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
496
545
497
546
@register_model
498
547
def vit_large_patch16_224_in21k (pretrained = False , ** kwargs ):
548
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
549
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
550
+ """
499
551
model_kwargs = dict (
500
552
patch_size = 16 , embed_dim = 1024 , depth = 24 , num_heads = 16 , mlp_ratio = 4 , representation_size = 1024 , ** kwargs )
501
553
model = _create_vision_transformer ('vit_large_patch16_224_in21k' , pretrained = pretrained , ** model_kwargs )
502
554
return model
503
555
504
556
505
- # @register_model
506
- # def vit_large_patch16_384_in21k(pretrained=False, **kwargs):
507
- # model_kwargs = dict(
508
- # patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, representation_size=1024, **kwargs)
509
- # model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
510
- # return model
511
-
512
-
513
557
@register_model
514
558
def vit_large_patch32_224_in21k (pretrained = False , ** kwargs ):
559
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
560
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
561
+ """
515
562
model_kwargs = dict (
516
563
patch_size = 32 , embed_dim = 1024 , depth = 24 , num_heads = 16 , mlp_ratio = 4 , representation_size = 1024 , ** kwargs )
517
564
model = _create_vision_transformer ('vit_large_patch32_224_in21k' , pretrained = pretrained , ** model_kwargs )
@@ -520,6 +567,10 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
520
567
521
568
@register_model
522
569
def vit_huge_patch14_224_in21k (pretrained = False , ** kwargs ):
570
+ """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
571
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
572
+ NOTE: converted weights not currently available, too large for github release hosting.
573
+ """
523
574
model_kwargs = dict (
524
575
patch_size = 14 , embed_dim = 1280 , depth = 32 , num_heads = 16 , mlp_ratio = 4 , representation_size = 1280 , ** kwargs )
525
576
model = _create_vision_transformer ('vit_huge_patch14_224_in21k' , pretrained = pretrained , ** model_kwargs )
@@ -528,9 +579,13 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
528
579
529
580
@register_model
530
581
def vit_base_resnet50_224_in21k (pretrained = False , ** kwargs ):
582
+ """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
583
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
584
+ """
531
585
# create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
532
586
backbone = ResNetV2 (
533
- layers = (3 , 4 , 9 ), preact = False , stem_type = 'same' , conv_layer = StdConv2dSame , num_classes = 0 , global_pool = '' )
587
+ layers = (3 , 4 , 9 ), num_classes = 0 , global_pool = '' , in_chans = kwargs .get ('in_chans' , 3 ),
588
+ preact = False , stem_type = 'same' , conv_layer = StdConv2dSame )
534
589
model_kwargs = dict (
535
590
embed_dim = 768 , depth = 12 , num_heads = 12 , mlp_ratio = 4 , hybrid_backbone = backbone ,
536
591
representation_size = 768 , ** kwargs )
@@ -540,73 +595,93 @@ def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
540
595
541
596
@register_model
542
597
def vit_base_resnet50_384 (pretrained = False , ** kwargs ):
598
+ """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
599
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
600
+ """
543
601
# create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
544
602
backbone = ResNetV2 (
545
- layers = (3 , 4 , 9 ), preact = False , stem_type = 'same' , conv_layer = StdConv2dSame , num_classes = 0 , global_pool = '' )
603
+ layers = (3 , 4 , 9 ), num_classes = 0 , global_pool = '' , in_chans = kwargs .get ('in_chans' , 3 ),
604
+ preact = False , stem_type = 'same' , conv_layer = StdConv2dSame )
546
605
model_kwargs = dict (embed_dim = 768 , depth = 12 , num_heads = 12 , mlp_ratio = 4 , hybrid_backbone = backbone , ** kwargs )
547
606
model = _create_vision_transformer ('vit_base_resnet50_384' , pretrained = pretrained , ** model_kwargs )
548
607
return model
549
608
550
609
551
610
@register_model
552
611
def vit_small_resnet26d_224 (pretrained = False , ** kwargs ):
553
- pretrained_backbone = kwargs .get ('pretrained_backbone' , True ) # default to True for now, for testing
554
- backbone = resnet26d (pretrained = pretrained_backbone , features_only = True , out_indices = [4 ])
612
+ """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
613
+ """
614
+ backbone = resnet26d (pretrained = pretrained , features_only = True , out_indices = [4 ])
555
615
model_kwargs = dict (embed_dim = 768 , depth = 8 , num_heads = 8 , mlp_ratio = 3 , hybrid_backbone = backbone , ** kwargs )
556
616
model = _create_vision_transformer ('vit_small_resnet26d_224' , pretrained = pretrained , ** model_kwargs )
557
617
return model
558
618
559
619
560
620
@register_model
561
621
def vit_small_resnet50d_s3_224 (pretrained = False , ** kwargs ):
562
- pretrained_backbone = kwargs .get ('pretrained_backbone' , True ) # default to True for now, for testing
563
- backbone = resnet50d (pretrained = pretrained_backbone , features_only = True , out_indices = [3 ])
622
+ """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
623
+ """
624
+ backbone = resnet50d (pretrained = pretrained , features_only = True , out_indices = [3 ])
564
625
model_kwargs = dict (embed_dim = 768 , depth = 8 , num_heads = 8 , mlp_ratio = 3 , hybrid_backbone = backbone , ** kwargs )
565
626
model = _create_vision_transformer ('vit_small_resnet50d_s3_224' , pretrained = pretrained , ** model_kwargs )
566
627
return model
567
628
568
629
569
630
@register_model
570
631
def vit_base_resnet26d_224 (pretrained = False , ** kwargs ):
571
- pretrained_backbone = kwargs .get ('pretrained_backbone' , True ) # default to True for now, for testing
572
- backbone = resnet26d (pretrained = pretrained_backbone , features_only = True , out_indices = [4 ])
632
+ """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.
633
+ """
634
+ backbone = resnet26d (pretrained = pretrained , features_only = True , out_indices = [4 ])
573
635
model_kwargs = dict (embed_dim = 768 , depth = 12 , num_heads = 12 , mlp_ratio = 4 , hybrid_backbone = backbone , ** kwargs )
574
636
model = _create_vision_transformer ('vit_base_resnet26d_224' , pretrained = pretrained , ** model_kwargs )
575
637
return model
576
638
577
639
578
640
@register_model
579
641
def vit_base_resnet50d_224 (pretrained = False , ** kwargs ):
580
- pretrained_backbone = kwargs .get ('pretrained_backbone' , True ) # default to True for now, for testing
581
- backbone = resnet50d (pretrained = pretrained_backbone , features_only = True , out_indices = [4 ])
642
+ """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
643
+ """
644
+ backbone = resnet50d (pretrained = pretrained , features_only = True , out_indices = [4 ])
582
645
model_kwargs = dict (embed_dim = 768 , depth = 12 , num_heads = 12 , mlp_ratio = 4 , hybrid_backbone = backbone , ** kwargs )
583
646
model = _create_vision_transformer ('vit_base_resnet50d_224' , pretrained = pretrained , ** model_kwargs )
584
647
return model
585
648
586
649
587
650
@register_model
588
651
def vit_deit_tiny_patch16_224 (pretrained = False , ** kwargs ):
652
+ """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
653
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
654
+ """
589
655
model_kwargs = dict (patch_size = 16 , embed_dim = 192 , depth = 12 , num_heads = 3 , mlp_ratio = 4 , ** kwargs )
590
656
model = _create_vision_transformer ('vit_deit_tiny_patch16_224' , pretrained = pretrained , ** model_kwargs )
591
657
return model
592
658
593
659
594
660
@register_model
595
661
def vit_deit_small_patch16_224 (pretrained = False , ** kwargs ):
662
+ """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
663
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
664
+ """
596
665
model_kwargs = dict (patch_size = 16 , embed_dim = 384 , depth = 12 , num_heads = 6 , mlp_ratio = 4 , ** kwargs )
597
666
model = _create_vision_transformer ('vit_deit_small_patch16_224' , pretrained = pretrained , ** model_kwargs )
598
667
return model
599
668
600
669
601
670
@register_model
602
671
def vit_deit_base_patch16_224 (pretrained = False , ** kwargs ):
672
+ """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
673
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
674
+ """
603
675
model_kwargs = dict (patch_size = 16 , embed_dim = 768 , depth = 12 , num_heads = 12 , mlp_ratio = 4 , ** kwargs )
604
676
model = _create_vision_transformer ('vit_deit_base_patch16_224' , pretrained = pretrained , ** model_kwargs )
605
677
return model
606
678
607
679
608
680
@register_model
609
681
def vit_deit_base_patch16_384 (pretrained = False , ** kwargs ):
682
+ """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
683
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
684
+ """
610
685
model_kwargs = dict (patch_size = 16 , embed_dim = 768 , depth = 12 , num_heads = 12 , mlp_ratio = 4 , ** kwargs )
611
686
model = _create_vision_transformer ('vit_deit_base_patch16_384' , pretrained = pretrained , ** model_kwargs )
612
687
return model
0 commit comments