1
1
""" BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
2
2
3
3
Model from official source: https://github.com/microsoft/unilm/tree/master/beit
4
- and
5
- https://github.com/microsoft/unilm/tree/master/beit2
6
4
7
5
@inproceedings{beit,
8
6
title={{BEiT}: {BERT} Pre-Training of Image Transformers},
12
10
url={https://openreview.net/forum?id=p-BhZSz59o4}
13
11
}
14
12
13
+ BEiT-v2 from https://github.com/microsoft/unilm/tree/master/beit2
14
+
15
15
@article{beitv2,
16
16
title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
17
17
author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
21
21
primaryClass={cs.CV}
22
22
}
23
23
24
+ EVA from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636
25
+
26
+ @article{EVA,
27
+ title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale},
28
+ author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang,
29
+ Tiejun and Wang, Xinlong and Cao, Yue},
30
+ journal={arXiv preprint arXiv:2211.07636},
31
+ year={2022}
32
+ }
33
+
34
+
24
35
At this point only the 1k fine-tuned classification weights and model configs have been added,
25
36
see original source above for pre-training models and procedure.
26
37
37
48
# https://github.com/facebookresearch/deit/
38
49
# https://github.com/facebookresearch/dino
39
50
# --------------------------------------------------------'
51
+
52
+ # EVA models Copyright (c) 2022 BAAI-Vision
53
+
40
54
import math
41
55
from functools import partial
42
56
from typing import Optional , Tuple
46
60
import torch .nn .functional as F
47
61
from torch .utils .checkpoint import checkpoint
48
62
49
- from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
63
+ from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD , OPENAI_CLIP_MEAN , OPENAI_CLIP_STD
50
64
from .helpers import build_model_with_cfg
51
65
from .layers import PatchEmbed , Mlp , DropPath , trunc_normal_
66
+ from .pretrained import generate_default_cfgs
52
67
from .registry import register_model
53
68
from .vision_transformer import checkpoint_filter_fn
54
69
@@ -64,52 +79,72 @@ def _cfg(url='', **kwargs):
64
79
}
65
80
66
81
67
# Pretrained configurations, wrapped by generate_default_cfgs.
# Keys have the shape '<model name>.<tag>'; the part before the dot matches the
# name passed to _create_beit by the corresponding entry-point function.
# NOTE(review): the exact semantics of generate_default_cfgs (e.g. which tag
# becomes the default for a model name) are defined elsewhere -- confirm there.
default_cfgs = generate_default_cfgs({
    # BEiT (v1) weights hosted by the official source
    'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
    'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
        input_size=(3, 384, 384), crop_pct=1.0,
    ),
    'beit_base_patch16_224.in22k_ft_in22k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
        num_classes=21841,
    ),
    'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
    'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
        input_size=(3, 384, 384), crop_pct=1.0,
    ),
    'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
        input_size=(3, 512, 512), crop_pct=1.0,
    ),
    'beit_large_patch16_224.in22k_ft_in22k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
        num_classes=21841,
    ),

    # BEiT-v2 weights; these use the ImageNet default mean/std explicitly
    'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
    ),
    'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
        num_classes=21841,
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
    ),
    'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
        crop_pct=0.95,
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
    ),
    'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
        num_classes=21841,
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
    ),

    # EVA weights, referenced by Hugging Face hub id + filename instead of a URL.
    # The clip_ft_in1k entries use the OpenAI CLIP mean/std; the m30m entries
    # use the ImageNet default mean/std.
    'eva_giant_patch14_224.clip_ft_in1k': _cfg(
        hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
    ),
    'eva_giant_patch14_336.clip_ft_in1k': _cfg(
        hf_hub_id='BAAI/EVA',
        hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 336, 336)),
    'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
        hf_hub_id='BAAI/EVA',
        hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        input_size=(3, 336, 336)),
    'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
        hf_hub_id='BAAI/EVA',
        hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        input_size=(3, 560, 560)),
})
113
148
114
149
115
150
def gen_relative_position_index (window_size : Tuple [int , int ]) -> torch .Tensor :
@@ -415,7 +450,7 @@ def beit_base_patch16_224(pretrained=False, **kwargs):
415
450
@register_model
def beit_base_patch16_384(pretrained=False, **kwargs):
    """Create the `beit_base_patch16_384` model through `_create_beit`.

    Args:
        pretrained: forwarded unchanged to `_create_beit`.
        **kwargs: extra keyword arguments merged with the fixed settings below;
            repeating one of the fixed names raises TypeError from `dict()`.

    Returns:
        Whatever `_create_beit` returns for this configuration.
    """
    # NOTE(review): mlp_ratio is not set here, so it falls back to whatever
    # default applies downstream -- confirm that default is the intended 4.
    model_kwargs = dict(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
    model = _create_beit('beit_base_patch16_384', pretrained=pretrained, **model_kwargs)
    return model
@@ -424,7 +459,7 @@ def beit_base_patch16_384(pretrained=False, **kwargs):
424
459
@register_model
def beit_base_patch16_224_in22k(pretrained=False, **kwargs):
    """Create the `beit_base_patch16_224_in22k` model through `_create_beit`.

    Args:
        pretrained: forwarded unchanged to `_create_beit`.
        **kwargs: extra keyword arguments merged with the fixed settings below;
            repeating one of the fixed names raises TypeError from `dict()`.

    Returns:
        Whatever `_create_beit` returns for this configuration.
    """
    # NOTE(review): mlp_ratio is not set here, so it falls back to whatever
    # default applies downstream -- confirm that default is the intended 4.
    model_kwargs = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
    model = _create_beit('beit_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
    return model
@@ -433,7 +468,7 @@ def beit_base_patch16_224_in22k(pretrained=False, **kwargs):
433
468
@register_model
def beit_large_patch16_224(pretrained=False, **kwargs):
    """Create the `beit_large_patch16_224` model through `_create_beit`.

    Args:
        pretrained: forwarded unchanged to `_create_beit`.
        **kwargs: extra keyword arguments merged with the fixed settings below;
            repeating one of the fixed names raises TypeError from `dict()`.

    Returns:
        Whatever `_create_beit` returns for this configuration.
    """
    # NOTE(review): mlp_ratio and qkv_bias are not set here, so they fall back
    # to downstream defaults -- confirm those resolve to 4 and True.
    model_kwargs = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
    model = _create_beit('beit_large_patch16_224', pretrained=pretrained, **model_kwargs)
    return model
@@ -442,7 +477,7 @@ def beit_large_patch16_224(pretrained=False, **kwargs):
442
477
@register_model
def beit_large_patch16_384(pretrained=False, **kwargs):
    """Create the `beit_large_patch16_384` model through `_create_beit`.

    Args:
        pretrained: forwarded unchanged to `_create_beit`.
        **kwargs: extra keyword arguments merged with the fixed settings below;
            repeating one of the fixed names raises TypeError from `dict()`.

    Returns:
        Whatever `_create_beit` returns for this configuration.
    """
    # NOTE(review): mlp_ratio and qkv_bias are not set here, so they fall back
    # to downstream defaults -- confirm those resolve to 4 and True.
    model_kwargs = dict(
        img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
    model = _create_beit('beit_large_patch16_384', pretrained=pretrained, **model_kwargs)
    return model
@@ -451,7 +486,7 @@ def beit_large_patch16_384(pretrained=False, **kwargs):
451
486
@register_model
def beit_large_patch16_512(pretrained=False, **kwargs):
    """Create the `beit_large_patch16_512` model through `_create_beit`.

    Args:
        pretrained: forwarded unchanged to `_create_beit`.
        **kwargs: extra keyword arguments merged with the fixed settings below;
            repeating one of the fixed names raises TypeError from `dict()`.

    Returns:
        Whatever `_create_beit` returns for this configuration.
    """
    # NOTE(review): mlp_ratio and qkv_bias are not set here, so they fall back
    # to downstream defaults -- confirm those resolve to 4 and True.
    model_kwargs = dict(
        img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
    model = _create_beit('beit_large_patch16_512', pretrained=pretrained, **model_kwargs)
    return model
@@ -460,7 +495,7 @@ def beit_large_patch16_512(pretrained=False, **kwargs):
460
495
@register_model
def beit_large_patch16_224_in22k(pretrained=False, **kwargs):
    """Create the `beit_large_patch16_224_in22k` model through `_create_beit`.

    Args:
        pretrained: forwarded unchanged to `_create_beit`.
        **kwargs: extra keyword arguments merged with the fixed settings below;
            repeating one of the fixed names raises TypeError from `dict()`.

    Returns:
        Whatever `_create_beit` returns for this configuration.
    """
    # NOTE(review): mlp_ratio and qkv_bias are not set here, so they fall back
    # to downstream defaults -- confirm those resolve to 4 and True.
    model_kwargs = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
    model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
    return model
@@ -487,7 +522,7 @@ def beitv2_base_patch16_224_in22k(pretrained=False, **kwargs):
487
522
@register_model
def beitv2_large_patch16_224(pretrained=False, **kwargs):
    """Create the `beitv2_large_patch16_224` model through `_create_beit`.

    Args:
        pretrained: forwarded unchanged to `_create_beit`.
        **kwargs: extra keyword arguments merged with the fixed settings below;
            repeating one of the fixed names raises TypeError from `dict()`.

    Returns:
        Whatever `_create_beit` returns for this configuration.
    """
    # NOTE(review): mlp_ratio and qkv_bias are not set here, so they fall back
    # to downstream defaults -- confirm those resolve to 4 and True.
    model_kwargs = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
    model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs)
    return model
@@ -496,7 +531,33 @@ def beitv2_large_patch16_224(pretrained=False, **kwargs):
496
531
@register_model
def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs):
    """Create the `beitv2_large_patch16_224_in22k` model through `_create_beit`.

    Args:
        pretrained: forwarded unchanged to `_create_beit`.
        **kwargs: extra keyword arguments merged with the fixed settings below;
            repeating one of the fixed names raises TypeError from `dict()`.

    Returns:
        Whatever `_create_beit` returns for this configuration.
    """
    # NOTE(review): mlp_ratio and qkv_bias are not set here, so they fall back
    # to downstream defaults -- confirm those resolve to 4 and True.
    model_kwargs = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
    model = _create_beit('beitv2_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
    return model
538
+
539
+
540
@register_model
def eva_giant_patch14_224(pretrained=False, **kwargs):
    """ EVA-g model https://arxiv.org/abs/2211.07636

    Registered with `register_model` like the other entry points in this file,
    including the two sibling `eva_giant_*` functions.

    Args:
        pretrained: forwarded unchanged to `_create_beit`.
        **kwargs: extra keyword arguments merged with the fixed settings below;
            repeating one of the fixed names raises TypeError from `dict()`.

    Returns:
        Whatever `_create_beit` returns for this configuration.
    """
    # mlp_ratio is written as 6144 / 1408, i.e. relative to embed_dim=1408
    # (presumably an MLP hidden width of 6144 -- confirm against the model class).
    model_kwargs = dict(
        patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
    model = _create_beit('eva_giant_patch14_224', pretrained=pretrained, **model_kwargs)
    return model
546
+
547
+
548
@register_model
def eva_giant_patch14_336(pretrained=False, **kwargs):
    """ EVA-g model https://arxiv.org/abs/2211.07636

    Args:
        pretrained: forwarded unchanged to `_create_beit`.
        **kwargs: extra keyword arguments merged with the fixed settings below;
            repeating one of the fixed names raises TypeError from `dict()`.

    Returns:
        Whatever `_create_beit` returns for this configuration.
    """
    # NOTE(review): img_size is not passed here; presumably the 336px input
    # size is picked up from the pretrained config -- confirm in _create_beit.
    # mlp_ratio is written as 6144 / 1408, i.e. relative to embed_dim=1408.
    model_kwargs = dict(
        patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
    model = _create_beit('eva_giant_patch14_336', pretrained=pretrained, **model_kwargs)
    return model
555
+
556
+
557
@register_model
def eva_giant_patch14_560(pretrained=False, **kwargs):
    """ EVA-g model https://arxiv.org/abs/2211.07636

    Args:
        pretrained: forwarded unchanged to `_create_beit`.
        **kwargs: extra keyword arguments merged with the fixed settings below;
            repeating one of the fixed names raises TypeError from `dict()`.

    Returns:
        Whatever `_create_beit` returns for this configuration.
    """
    # NOTE(review): img_size is not passed here; presumably the 560px input
    # size is picked up from the pretrained config -- confirm in _create_beit.
    # mlp_ratio is written as 6144 / 1408, i.e. relative to embed_dim=1408.
    model_kwargs = dict(
        patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408, **kwargs)
    model = _create_beit('eva_giant_patch14_560', pretrained=pretrained, **model_kwargs)
    return model
0 commit comments