diff --git a/README.md b/README.md index d254f32100..8f4460754b 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,4 @@ # PyTorch Image Models -- [Sponsors](#sponsors) - [What's New](#whats-new) - [Introduction](#introduction) - [Models](#models) @@ -11,14 +10,6 @@ - [Licenses](#licenses) - [Citing](#citing) -## Sponsors - -Thanks to the following for hardware support: -* TPU Research Cloud (TRC) (https://sites.research.google/trc/about/) -* Nvidia (https://www.nvidia.com/en-us/) - -And a big thanks to all GitHub sponsors who helped with some of my costs before I joined Hugging Face. - ## What's New ❗Updates after Oct 10, 2022 are available in version >= 0.9❗ @@ -35,6 +26,37 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before * The Hugging Face Hub (https://huggingface.co/timm) is now the primary source for `timm` weights. Model cards include link to papers, original source, license. * Previous 0.6.x can be cloned from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch or installed via pip with version. +### Sep 1, 2023 +* TinyViT added by [SeeFun](https://github.com/seefun) +* Fix EfficientViT (MIT) to use torch.autocast so it works back to PT 1.10 +* 0.9.7 release + +### Aug 28, 2023 +* Add dynamic img size support to models in `vision_transformer.py`, `vision_transformer_hybrid.py`, `deit.py`, and `eva.py` w/o breaking backward compat. + * Add `dynamic_img_size=True` to args at model creation time to allow changing the grid size (interpolate abs and/or ROPE pos embed each forward pass). + * Add `dynamic_img_pad=True` to allow image sizes that aren't divisible by patch size (pad bottom right to patch size each forward pass). + * Enabling either dynamic mode will break FX tracing unless PatchEmbed module added as leaf. + * Existing method of resizing position embedding by passing different `img_size` (interpolate pretrained embed weights once) on creation still works. + * Existing method of changing `patch_size` (resize pretrained patch_embed weights once) on creation still works. 
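A minimal Python sketch of the two dynamic-size flags described above (the validation command in the next bullet exercises the same options via `--model-kwargs`); `pretrained=False` is used here only to avoid a weight download:

```python
import torch
import timm

# dynamic_img_size: interpolate the position embedding each forward pass
# dynamic_img_pad: pad inputs whose H/W are not divisible by the patch size
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=False,
    dynamic_img_size=True,
    dynamic_img_pad=True,
).eval()

with torch.no_grad():
    for size in (224, 255, 320):   # 255 is deliberately not divisible by the 16px patch size
        out = model(torch.randn(1, 3, size, size))
        print(size, out.shape)     # logits stay (1, 1000) regardless of input size
```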
+ * Example validation cmd `python validate.py /imagenet --model vit_base_patch16_224 --amp --amp-dtype bfloat16 --img-size 255 --crop-pct 1.0 --model-kwargs dynamic_img_size=True dynamic_img_pad=True` + +### Aug 25, 2023 +* Many new models since last release + * FastViT - https://arxiv.org/abs/2303.14189 + * MobileOne - https://arxiv.org/abs/2206.04040 + * InceptionNeXt - https://arxiv.org/abs/2303.16900 + * RepGhostNet - https://arxiv.org/abs/2211.06088 (thanks https://github.com/ChengpengChen) + * GhostNetV2 - https://arxiv.org/abs/2211.12905 (thanks https://github.com/yehuitang) + * EfficientViT (MSRA) - https://arxiv.org/abs/2305.07027 (thanks https://github.com/seefun) + * EfficientViT (MIT) - https://arxiv.org/abs/2205.14756 (thanks https://github.com/seefun) +* Add `--reparam` arg to `benchmark.py`, `onnx_export.py`, and `validate.py` to trigger layer reparameterization / fusion for models with any one of `reparameterize()`, `switch_to_deploy()` or `fuse()` + * Including FastViT, MobileOne, RepGhostNet, EfficientViT (MSRA), RepViT, RepVGG, and LeViT +* Preparing 0.9.6 'back to school' release + +### Aug 11, 2023 +* Swin, MaxViT, CoAtNet, and BEiT models support resizing of image/window size on creation with adaptation of pretrained weights +* Example validation cmd to test w/ non-square resize `python validate.py /imagenet --model swin_base_patch4_window7_224.ms_in22k_ft_in1k --amp --amp-dtype bfloat16 --input-size 3 256 320 --model-kwargs window_size=8,10 img_size=256,320` + ### Aug 3, 2023 * Add GluonCV weights for HRNet w18_small and w18_small_v2. Converted by [SeeFun](https://github.com/seefun) * Fix `selecsls*` model naming regression @@ -380,179 +402,6 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before * `maxvit_tiny_rw_224` - 83.5 @ 224 (G) * `maxvit_rmlp_tiny_rw_256` - 84.2 @ 256, 84.8 @ 320 (T) -### Aug 29, 2022 -* MaxVit window size scales with img_size by default.
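Returning to the Aug 11, 2023 note above on resizing image/window size at creation, a rough Python equivalent of that non-square validation command (a sketch; with `pretrained=True` the original 224 / window-7 weights would be adapted on load):

```python
import torch
import timm

# Resize both the input grid and the attention window at model creation time.
model = timm.create_model(
    'swin_base_patch4_window7_224',
    pretrained=False,
    img_size=(256, 320),
    window_size=(8, 10),
).eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 320))
print(out.shape)  # torch.Size([1, 1000])
```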
Add new RelPosMlp MaxViT weight that leverages this: - * `maxvit_rmlp_nano_rw_256` - 83.0 @ 256, 83.6 @ 320 (T) - -### Aug 26, 2022 -* CoAtNet (https://arxiv.org/abs/2106.04803) and MaxVit (https://arxiv.org/abs/2204.01697) `timm` original models - * both found in [`maxxvit.py`](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/maxxvit.py) model def, contains numerous experiments outside scope of original papers - * an unfinished Tensorflow version from MaxVit authors can be found https://github.com/google-research/maxvit -* Initial CoAtNet and MaxVit timm pretrained weights (working on more): - * `coatnet_nano_rw_224` - 81.7 @ 224 (T) - * `coatnet_rmlp_nano_rw_224` - 82.0 @ 224, 82.8 @ 320 (T) - * `coatnet_0_rw_224` - 82.4 (T) -- NOTE timm '0' coatnets have 2 more 3rd stage blocks - * `coatnet_bn_0_rw_224` - 82.4 (T) - * `maxvit_nano_rw_256` - 82.9 @ 256 (T) - * `coatnet_rmlp_1_rw_224` - 83.4 @ 224, 84 @ 320 (T) - * `coatnet_1_rw_224` - 83.6 @ 224 (G) - * (T) = TPU trained with `bits_and_tpu` branch training code, (G) = GPU trained -* GCVit (weights adapted from https://github.com/NVlabs/GCVit, code 100% `timm` re-write for license purposes) -* MViT-V2 (multi-scale vit, adapted from https://github.com/facebookresearch/mvit) -* EfficientFormer (adapted from https://github.com/snap-research/EfficientFormer) -* PyramidVisionTransformer-V2 (adapted from https://github.com/whai362/PVT) -* 'Fast Norm' support for LayerNorm and GroupNorm that avoids float32 upcast w/ AMP (uses APEX LN if available for further boost) - - -### Aug 15, 2022 -* ConvNeXt atto weights added - * `convnext_atto` - 75.7 @ 224, 77.0 @ 288 - * `convnext_atto_ols` - 75.9 @ 224, 77.2 @ 288 - -### Aug 5, 2022 -* More custom ConvNeXt smaller model defs with weights - * `convnext_femto` - 77.5 @ 224, 78.7 @ 288 - * `convnext_femto_ols` - 77.9 @ 224, 78.9 @ 288 - * `convnext_pico` - 79.5 @ 224, 80.4 @ 288 - * `convnext_pico_ols` - 79.5 @ 224, 80.5 @ 288 - * `convnext_nano_ols` - 80.9 @ 224, 81.6 @ 288 -* Updated EdgeNeXt to improve ONNX export, add new base variant and weights from original (https://github.com/mmaaz60/EdgeNeXt) - -### July 28, 2022 -* Add freshly minted DeiT-III Medium (width=512, depth=12, num_heads=8) model weights. Thanks [Hugo Touvron](https://github.com/TouvronHugo)! - -### July 27, 2022 -* All runtime benchmark and validation result csv files are finally up-to-date! -* A few more weights & model defs added: - * `darknetaa53` - 79.8 @ 256, 80.5 @ 288 - * `convnext_nano` - 80.8 @ 224, 81.5 @ 288 - * `cs3sedarknet_l` - 81.2 @ 256, 81.8 @ 288 - * `cs3darknet_x` - 81.8 @ 256, 82.2 @ 288 - * `cs3sedarknet_x` - 82.2 @ 256, 82.7 @ 288 - * `cs3edgenet_x` - 82.2 @ 256, 82.7 @ 288 - * `cs3se_edgenet_x` - 82.8 @ 256, 83.5 @ 320 -* `cs3*` weights above all trained on TPU w/ `bits_and_tpu` branch. Thanks to TRC program! 
-* Add output_stride=8 and 16 support to ConvNeXt (dilation) -* deit3 models not being able to resize pos_emb fixed -* Version 0.6.7 PyPi release (/w above bug fixes and new weighs since 0.6.5) - -### July 8, 2022 -More models, more fixes -* Official research models (w/ weights) added: - * EdgeNeXt from (https://github.com/mmaaz60/EdgeNeXt) - * MobileViT-V2 from (https://github.com/apple/ml-cvnets) - * DeiT III (Revenge of the ViT) from (https://github.com/facebookresearch/deit) -* My own models: - * Small `ResNet` defs added by request with 1 block repeats for both basic and bottleneck (resnet10 and resnet14) - * `CspNet` refactored with dataclass config, simplified CrossStage3 (`cs3`) option. These are closer to YOLO-v5+ backbone defs. - * More relative position vit fiddling. Two `srelpos` (shared relative position) models trained, and a medium w/ class token. - * Add an alternate downsample mode to EdgeNeXt and train a `small` model. Better than original small, but not their new USI trained weights. -* My own model weight results (all ImageNet-1k training) - * `resnet10t` - 66.5 @ 176, 68.3 @ 224 - * `resnet14t` - 71.3 @ 176, 72.3 @ 224 - * `resnetaa50` - 80.6 @ 224 , 81.6 @ 288 - * `darknet53` - 80.0 @ 256, 80.5 @ 288 - * `cs3darknet_m` - 77.0 @ 256, 77.6 @ 288 - * `cs3darknet_focus_m` - 76.7 @ 256, 77.3 @ 288 - * `cs3darknet_l` - 80.4 @ 256, 80.9 @ 288 - * `cs3darknet_focus_l` - 80.3 @ 256, 80.9 @ 288 - * `vit_srelpos_small_patch16_224` - 81.1 @ 224, 82.1 @ 320 - * `vit_srelpos_medium_patch16_224` - 82.3 @ 224, 83.1 @ 320 - * `vit_relpos_small_patch16_cls_224` - 82.6 @ 224, 83.6 @ 320 - * `edgnext_small_rw` - 79.6 @ 224, 80.4 @ 320 -* `cs3`, `darknet`, and `vit_*relpos` weights above all trained on TPU thanks to TRC program! Rest trained on overheating GPUs. -* Hugging Face Hub support fixes verified, demo notebook TBA -* Pretrained weights / configs can be loaded externally (ie from local disk) w/ support for head adaptation. -* Add support to change image extensions scanned by `timm` datasets/readers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103) -* Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases. - * a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges. - * previous impl exists as `LayerNormExp2d` in `models/layers/norm.py` -* Numerous bug fixes -* Currently testing for imminent PyPi 0.6.x release -* LeViT pretraining of larger models still a WIP, they don't train well / easily without distillation. Time to add distill support (finally)? -* ImageNet-22k weight training + finetune ongoing, work on multi-weight support (slowly) chugging along (there are a LOT of weights, sigh) ... - -### May 13, 2022 -* Official Swin-V2 models and weights added from (https://github.com/microsoft/Swin-Transformer). Cleaned up to support torchscript. -* Some refactoring for existing `timm` Swin-V2-CR impl, will likely do a bit more to bring parts closer to official and decide whether to merge some aspects. 
-* More Vision Transformer relative position / residual post-norm experiments (all trained on TPU thanks to TRC program) - * `vit_relpos_small_patch16_224` - 81.5 @ 224, 82.5 @ 320 -- rel pos, layer scale, no class token, avg pool - * `vit_relpos_medium_patch16_rpn_224` - 82.3 @ 224, 83.1 @ 320 -- rel pos + res-post-norm, no class token, avg pool - * `vit_relpos_medium_patch16_224` - 82.5 @ 224, 83.3 @ 320 -- rel pos, layer scale, no class token, avg pool - * `vit_relpos_base_patch16_gapcls_224` - 82.8 @ 224, 83.9 @ 320 -- rel pos, layer scale, class token, avg pool (by mistake) -* Bring 512 dim, 8-head 'medium' ViT model variant back to life (after using in a pre DeiT 'small' model for first ViT impl back in 2020) -* Add ViT relative position support for switching btw existing impl and some additions in official Swin-V2 impl for future trials -* Sequencer2D impl (https://arxiv.org/abs/2205.01972), added via PR from author (https://github.com/okojoalg) - -### May 2, 2022 -* Vision Transformer experiments adding Relative Position (Swin-V2 log-coord) (`vision_transformer_relpos.py`) and Residual Post-Norm branches (from Swin-V2) (`vision_transformer*.py`) - * `vit_relpos_base_patch32_plus_rpn_256` - 79.5 @ 256, 80.6 @ 320 -- rel pos + extended width + res-post-norm, no class token, avg pool - * `vit_relpos_base_patch16_224` - 82.5 @ 224, 83.6 @ 320 -- rel pos, layer scale, no class token, avg pool - * `vit_base_patch16_rpn_224` - 82.3 @ 224 -- rel pos + res-post-norm, no class token, avg pool -* Vision Transformer refactor to remove representation layer that was only used in initial vit and rarely used since with newer pretrain (ie `How to Train Your ViT`) -* `vit_*` models support removal of class token, use of global average pool, use of fc_norm (ala beit, mae). - -### April 22, 2022 -* `timm` models are now officially supported in [fast.ai](https://www.fast.ai/)! Just in time for the new Practical Deep Learning course. `timmdocs` documentation link updated to [timm.fast.ai](http://timm.fast.ai/). -* Two more model weights added in the TPU trained [series](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights). Some In22k pretrain still in progress. - * `seresnext101d_32x8d` - 83.69 @ 224, 84.35 @ 288 - * `seresnextaa101d_32x8d` (anti-aliased w/ AvgPool2d) - 83.85 @ 224, 84.57 @ 288 - -### March 23, 2022 -* Add `ParallelBlock` and `LayerScale` option to base vit models to support model configs in [Three things everyone should know about ViT](https://arxiv.org/abs/2203.09795) -* `convnext_tiny_hnf` (head norm first) weights trained with (close to) A2 recipe, 82.2% top-1, could do better with more epochs. - -### March 21, 2022 -* Merge `norm_norm_norm`. **IMPORTANT** this update for a coming 0.6.x release will likely de-stabilize the master branch for a while. Branch [`0.5.x`](https://github.com/rwightman/pytorch-image-models/tree/0.5.x) or a previous 0.5.x release can be used if stability is required. 
-* Significant weights update (all TPU trained) as described in this [release](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights) - * `regnety_040` - 82.3 @ 224, 82.96 @ 288 - * `regnety_064` - 83.0 @ 224, 83.65 @ 288 - * `regnety_080` - 83.17 @ 224, 83.86 @ 288 - * `regnetv_040` - 82.44 @ 224, 83.18 @ 288 (timm pre-act) - * `regnetv_064` - 83.1 @ 224, 83.71 @ 288 (timm pre-act) - * `regnetz_040` - 83.67 @ 256, 84.25 @ 320 - * `regnetz_040h` - 83.77 @ 256, 84.5 @ 320 (w/ extra fc in head) - * `resnetv2_50d_gn` - 80.8 @ 224, 81.96 @ 288 (pre-act GroupNorm) - * `resnetv2_50d_evos` 80.77 @ 224, 82.04 @ 288 (pre-act EvoNormS) - * `regnetz_c16_evos` - 81.9 @ 256, 82.64 @ 320 (EvoNormS) - * `regnetz_d8_evos` - 83.42 @ 256, 84.04 @ 320 (EvoNormS) - * `xception41p` - 82 @ 299 (timm pre-act) - * `xception65` - 83.17 @ 299 - * `xception65p` - 83.14 @ 299 (timm pre-act) - * `resnext101_64x4d` - 82.46 @ 224, 83.16 @ 288 - * `seresnext101_32x8d` - 83.57 @ 224, 84.270 @ 288 - * `resnetrs200` - 83.85 @ 256, 84.44 @ 320 -* HuggingFace hub support fixed w/ initial groundwork for allowing alternative 'config sources' for pretrained model definitions and weights (generic local file / remote url support soon) -* SwinTransformer-V2 implementation added. Submitted by [Christoph Reich](https://github.com/ChristophReich1996). Training experiments and model changes by myself are ongoing so expect compat breaks. -* Swin-S3 (AutoFormerV2) models / weights added from https://github.com/microsoft/Cream/tree/main/AutoFormerV2 -* MobileViT models w/ weights adapted from https://github.com/apple/ml-cvnets -* PoolFormer models w/ weights adapted from https://github.com/sail-sg/poolformer -* VOLO models w/ weights adapted from https://github.com/sail-sg/volo -* Significant work experimenting with non-BatchNorm norm layers such as EvoNorm, FilterResponseNorm, GroupNorm, etc -* Enhance support for alternate norm + act ('NormAct') layers added to a number of models, esp EfficientNet/MobileNetV3, RegNet, and aligned Xception -* Grouped conv support added to EfficientNet family -* Add 'group matching' API to all models to allow grouping model parameters for application of 'layer-wise' LR decay, lr scale added to LR scheduler -* Gradient checkpointing support added to many models -* `forward_head(x, pre_logits=False)` fn added to all models to allow separate calls of `forward_features` + `forward_head` -* All vision transformer and vision MLP models update to return non-pooled / non-token selected features from `foward_features`, for consistency with CNN models, token selection or pooling now applied in `forward_head` - -### Feb 2, 2022 -* [Chris Hughes](https://github.com/Chris-hughes10) posted an exhaustive run through of `timm` on his blog yesterday. Well worth a read. [Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) -* I'm currently prepping to merge the `norm_norm_norm` branch back to master (ver 0.6.x) in next week or so. - * The changes are more extensive than usual and may destabilize and break some model API use (aiming for full backwards compat). So, beware `pip install git+https://github.com/rwightman/pytorch-image-models` installs! - * `0.5.x` releases and a `0.5.x` branch will remain stable with a cherry pick or two until dust clears. Recommend sticking to pypi install for a bit if you want stable. 
- -### Jan 14, 2022 -* Version 0.5.4 w/ release to be pushed to pypi. It's been a while since last pypi update and riskier changes will be merged to main branch soon.... -* Add ConvNeXT models /w weights from official impl (https://github.com/facebookresearch/ConvNeXt), a few perf tweaks, compatible with timm features -* Tried training a few small (~1.8-3M param) / mobile optimized models, a few are good so far, more on the way... - * `mnasnet_small` - 65.6 top-1 - * `mobilenetv2_050` - 65.9 - * `lcnet_100/075/050` - 72.1 / 68.8 / 63.1 - * `semnasnet_075` - 73 - * `fbnetv3_b/d/g` - 79.1 / 79.7 / 82.0 -* TinyNet models added by [rsomani95](https://github.com/rsomani95) -* LCNet added via MobileNetV3 architecture ## Introduction @@ -594,26 +443,33 @@ All model architecture families include variants with pretrained weights. There * MobileNet-V2 - https://arxiv.org/abs/1801.04381 * Single-Path NAS - https://arxiv.org/abs/1904.02877 * TinyNet - https://arxiv.org/abs/2010.14819 +* EfficientViT (MIT) - https://arxiv.org/abs/2205.14756 +* EfficientViT (MSRA) - https://arxiv.org/abs/2305.07027 * EVA - https://arxiv.org/abs/2211.07636 * EVA-02 - https://arxiv.org/abs/2303.11331 +* FastViT - https://arxiv.org/abs/2303.14189 * FlexiViT - https://arxiv.org/abs/2212.08013 * FocalNet (Focal Modulation Networks) - https://arxiv.org/abs/2203.11926 * GCViT (Global Context Vision Transformer) - https://arxiv.org/abs/2206.09959 * GhostNet - https://arxiv.org/abs/1911.11907 +* GhostNet-V2 - https://arxiv.org/abs/2211.12905 * gMLP - https://arxiv.org/abs/2105.08050 * GPU-Efficient Networks - https://arxiv.org/abs/2006.14090 * Halo Nets - https://arxiv.org/abs/2103.12731 * HRNet - https://arxiv.org/abs/1908.07919 +* InceptionNeXt - https://arxiv.org/abs/2303.16900 * Inception-V3 - https://arxiv.org/abs/1512.00567 * Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261 * Lambda Networks - https://arxiv.org/abs/2102.08602 * LeViT (Vision Transformer in ConvNet's Clothing) - https://arxiv.org/abs/2104.01136 * MaxViT (Multi-Axis Vision Transformer) - https://arxiv.org/abs/2204.01697 +* MetaFormer (PoolFormer-v2, ConvFormer, CAFormer) - https://arxiv.org/abs/2210.13452 * MLP-Mixer - https://arxiv.org/abs/2105.01601 * MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244 * FBNet-V3 - https://arxiv.org/abs/2006.02049 * HardCoRe-NAS - https://arxiv.org/abs/2102.11646 * LCNet - https://arxiv.org/abs/2109.15099 +* MobileOne - https://arxiv.org/abs/2206.04040 * MobileViT - https://arxiv.org/abs/2110.02178 * MobileViT-V2 - https://arxiv.org/abs/2206.02680 * MViT-V2 (Improved Multiscale Vision Transformer) - https://arxiv.org/abs/2112.01526 @@ -628,6 +484,8 @@ All model architecture families include variants with pretrained weights. 
There * RegNet - https://arxiv.org/abs/2003.13678 * RegNetZ - https://arxiv.org/abs/2103.06877 * RepVGG - https://arxiv.org/abs/2101.03697 +* RepGhostNet - https://arxiv.org/abs/2211.06088 +* RepViT - https://arxiv.org/abs/2307.09283 * ResMLP - https://arxiv.org/abs/2105.03404 * ResNet/ResNeXt * ResNet (v1b/v1.5) - https://arxiv.org/abs/1512.03385 diff --git a/benchmark.py b/benchmark.py index 2cce3e2c11..c31708f513 100755 --- a/benchmark.py +++ b/benchmark.py @@ -22,7 +22,8 @@ from timm.layers import set_fast_norm from timm.models import create_model, is_model, list_models from timm.optim import create_optimizer_v2 -from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs +from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs,\ + reparameterize_model has_apex = False try: @@ -116,6 +117,8 @@ help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") parser.add_argument('--fast-norm', default=False, action='store_true', help='enable experimental fast-norm') +parser.add_argument('--reparam', default=False, action='store_true', + help='Reparameterize model') parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs) # codegen (model compilation) options @@ -222,6 +225,7 @@ def __init__( torchscript=False, torchcompile=None, aot_autograd=False, + reparam=False, precision='float32', fuser='', num_warm_iter=10, @@ -252,10 +256,13 @@ def __init__( drop_block_rate=kwargs.pop('drop_block', None), **kwargs.pop('model_kwargs', {}), ) + if reparam: + self.model = reparameterize_model(self.model) self.model.to( device=self.device, dtype=self.model_dtype, - memory_format=torch.channels_last if self.channels_last else None) + memory_format=torch.channels_last if self.channels_last else None, + ) self.num_classes = self.model.num_classes self.param_count = count_params(self.model) _logger.info('Model %s created, param count: %d' % (model_name, self.param_count)) diff --git a/docs/changes.md b/docs/changes.md index edf88c6217..e28a4ff3c4 100644 --- a/docs/changes.md +++ b/docs/changes.md @@ -1,4 +1,178 @@ # Recent Changes + +### Aug 29, 2022 +* MaxVit window size scales with img_size by default. 
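The `--reparam` flag added to `benchmark.py` above simply calls the `timm.utils.reparameterize_model` helper on the freshly created model; a minimal standalone sketch of the same fusion step:

```python
import torch
import timm
from timm.utils import reparameterize_model  # same helper imported by benchmark.py in this diff

# Works for models exposing reparameterize(), switch_to_deploy() or fuse(),
# e.g. FastViT, MobileOne, RepGhostNet, RepVGG, LeViT, EfficientViT (MSRA).
model = timm.create_model('mobileone_s1', pretrained=False).eval()
model = reparameterize_model(model)   # fuse multi-branch blocks into plain convs for inference

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```

The CLI equivalent would be along the lines of `python validate.py /imagenet --model mobileone_s1 --reparam`.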
Add new RelPosMlp MaxViT weight that leverages this: + * `maxvit_rmlp_nano_rw_256` - 83.0 @ 256, 83.6 @ 320 (T) + +### Aug 26, 2022 +* CoAtNet (https://arxiv.org/abs/2106.04803) and MaxVit (https://arxiv.org/abs/2204.01697) `timm` original models + * both found in [`maxxvit.py`](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/maxxvit.py) model def, contains numerous experiments outside scope of original papers + * an unfinished Tensorflow version from MaxVit authors can be found https://github.com/google-research/maxvit +* Initial CoAtNet and MaxVit timm pretrained weights (working on more): + * `coatnet_nano_rw_224` - 81.7 @ 224 (T) + * `coatnet_rmlp_nano_rw_224` - 82.0 @ 224, 82.8 @ 320 (T) + * `coatnet_0_rw_224` - 82.4 (T) -- NOTE timm '0' coatnets have 2 more 3rd stage blocks + * `coatnet_bn_0_rw_224` - 82.4 (T) + * `maxvit_nano_rw_256` - 82.9 @ 256 (T) + * `coatnet_rmlp_1_rw_224` - 83.4 @ 224, 84 @ 320 (T) + * `coatnet_1_rw_224` - 83.6 @ 224 (G) + * (T) = TPU trained with `bits_and_tpu` branch training code, (G) = GPU trained +* GCVit (weights adapted from https://github.com/NVlabs/GCVit, code 100% `timm` re-write for license purposes) +* MViT-V2 (multi-scale vit, adapted from https://github.com/facebookresearch/mvit) +* EfficientFormer (adapted from https://github.com/snap-research/EfficientFormer) +* PyramidVisionTransformer-V2 (adapted from https://github.com/whai362/PVT) +* 'Fast Norm' support for LayerNorm and GroupNorm that avoids float32 upcast w/ AMP (uses APEX LN if available for further boost) + +### Aug 15, 2022 +* ConvNeXt atto weights added + * `convnext_atto` - 75.7 @ 224, 77.0 @ 288 + * `convnext_atto_ols` - 75.9 @ 224, 77.2 @ 288 + +### Aug 5, 2022 +* More custom ConvNeXt smaller model defs with weights + * `convnext_femto` - 77.5 @ 224, 78.7 @ 288 + * `convnext_femto_ols` - 77.9 @ 224, 78.9 @ 288 + * `convnext_pico` - 79.5 @ 224, 80.4 @ 288 + * `convnext_pico_ols` - 79.5 @ 224, 80.5 @ 288 + * `convnext_nano_ols` - 80.9 @ 224, 81.6 @ 288 +* Updated EdgeNeXt to improve ONNX export, add new base variant and weights from original (https://github.com/mmaaz60/EdgeNeXt) + +### July 28, 2022 +* Add freshly minted DeiT-III Medium (width=512, depth=12, num_heads=8) model weights. Thanks [Hugo Touvron](https://github.com/TouvronHugo)! + +### July 27, 2022 +* All runtime benchmark and validation result csv files are finally up-to-date! +* A few more weights & model defs added: + * `darknetaa53` - 79.8 @ 256, 80.5 @ 288 + * `convnext_nano` - 80.8 @ 224, 81.5 @ 288 + * `cs3sedarknet_l` - 81.2 @ 256, 81.8 @ 288 + * `cs3darknet_x` - 81.8 @ 256, 82.2 @ 288 + * `cs3sedarknet_x` - 82.2 @ 256, 82.7 @ 288 + * `cs3edgenet_x` - 82.2 @ 256, 82.7 @ 288 + * `cs3se_edgenet_x` - 82.8 @ 256, 83.5 @ 320 +* `cs3*` weights above all trained on TPU w/ `bits_and_tpu` branch. Thanks to TRC program! 
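For the CoAtNet / MaxViT families mentioned in the Aug 26, 2022 entry above, the registry can be queried for the variants that ship weights (a small sketch using the `timm.list_models` API):

```python
import timm

# Wildcard queries against the model registry; pretrained=True keeps only entries with weights.
print(timm.list_models('coatnet*', pretrained=True))
print(timm.list_models('maxvit*', pretrained=True))
```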
+* Add output_stride=8 and 16 support to ConvNeXt (dilation) +* deit3 models not being able to resize pos_emb fixed +* Version 0.6.7 PyPi release (/w above bug fixes and new weighs since 0.6.5) + +### July 8, 2022 +More models, more fixes +* Official research models (w/ weights) added: + * EdgeNeXt from (https://github.com/mmaaz60/EdgeNeXt) + * MobileViT-V2 from (https://github.com/apple/ml-cvnets) + * DeiT III (Revenge of the ViT) from (https://github.com/facebookresearch/deit) +* My own models: + * Small `ResNet` defs added by request with 1 block repeats for both basic and bottleneck (resnet10 and resnet14) + * `CspNet` refactored with dataclass config, simplified CrossStage3 (`cs3`) option. These are closer to YOLO-v5+ backbone defs. + * More relative position vit fiddling. Two `srelpos` (shared relative position) models trained, and a medium w/ class token. + * Add an alternate downsample mode to EdgeNeXt and train a `small` model. Better than original small, but not their new USI trained weights. +* My own model weight results (all ImageNet-1k training) + * `resnet10t` - 66.5 @ 176, 68.3 @ 224 + * `resnet14t` - 71.3 @ 176, 72.3 @ 224 + * `resnetaa50` - 80.6 @ 224 , 81.6 @ 288 + * `darknet53` - 80.0 @ 256, 80.5 @ 288 + * `cs3darknet_m` - 77.0 @ 256, 77.6 @ 288 + * `cs3darknet_focus_m` - 76.7 @ 256, 77.3 @ 288 + * `cs3darknet_l` - 80.4 @ 256, 80.9 @ 288 + * `cs3darknet_focus_l` - 80.3 @ 256, 80.9 @ 288 + * `vit_srelpos_small_patch16_224` - 81.1 @ 224, 82.1 @ 320 + * `vit_srelpos_medium_patch16_224` - 82.3 @ 224, 83.1 @ 320 + * `vit_relpos_small_patch16_cls_224` - 82.6 @ 224, 83.6 @ 320 + * `edgnext_small_rw` - 79.6 @ 224, 80.4 @ 320 +* `cs3`, `darknet`, and `vit_*relpos` weights above all trained on TPU thanks to TRC program! Rest trained on overheating GPUs. +* Hugging Face Hub support fixes verified, demo notebook TBA +* Pretrained weights / configs can be loaded externally (ie from local disk) w/ support for head adaptation. +* Add support to change image extensions scanned by `timm` datasets/readers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103) +* Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases. + * a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges. + * previous impl exists as `LayerNormExp2d` in `models/layers/norm.py` +* Numerous bug fixes +* Currently testing for imminent PyPi 0.6.x release +* LeViT pretraining of larger models still a WIP, they don't train well / easily without distillation. Time to add distill support (finally)? +* ImageNet-22k weight training + finetune ongoing, work on multi-weight support (slowly) chugging along (there are a LOT of weights, sigh) ... + +### May 13, 2022 +* Official Swin-V2 models and weights added from (https://github.com/microsoft/Swin-Transformer). Cleaned up to support torchscript. +* Some refactoring for existing `timm` Swin-V2-CR impl, will likely do a bit more to bring parts closer to official and decide whether to merge some aspects. 
+* More Vision Transformer relative position / residual post-norm experiments (all trained on TPU thanks to TRC program) + * `vit_relpos_small_patch16_224` - 81.5 @ 224, 82.5 @ 320 -- rel pos, layer scale, no class token, avg pool + * `vit_relpos_medium_patch16_rpn_224` - 82.3 @ 224, 83.1 @ 320 -- rel pos + res-post-norm, no class token, avg pool + * `vit_relpos_medium_patch16_224` - 82.5 @ 224, 83.3 @ 320 -- rel pos, layer scale, no class token, avg pool + * `vit_relpos_base_patch16_gapcls_224` - 82.8 @ 224, 83.9 @ 320 -- rel pos, layer scale, class token, avg pool (by mistake) +* Bring 512 dim, 8-head 'medium' ViT model variant back to life (after using in a pre DeiT 'small' model for first ViT impl back in 2020) +* Add ViT relative position support for switching btw existing impl and some additions in official Swin-V2 impl for future trials +* Sequencer2D impl (https://arxiv.org/abs/2205.01972), added via PR from author (https://github.com/okojoalg) + +### May 2, 2022 +* Vision Transformer experiments adding Relative Position (Swin-V2 log-coord) (`vision_transformer_relpos.py`) and Residual Post-Norm branches (from Swin-V2) (`vision_transformer*.py`) + * `vit_relpos_base_patch32_plus_rpn_256` - 79.5 @ 256, 80.6 @ 320 -- rel pos + extended width + res-post-norm, no class token, avg pool + * `vit_relpos_base_patch16_224` - 82.5 @ 224, 83.6 @ 320 -- rel pos, layer scale, no class token, avg pool + * `vit_base_patch16_rpn_224` - 82.3 @ 224 -- rel pos + res-post-norm, no class token, avg pool +* Vision Transformer refactor to remove representation layer that was only used in initial vit and rarely used since with newer pretrain (ie `How to Train Your ViT`) +* `vit_*` models support removal of class token, use of global average pool, use of fc_norm (ala beit, mae). + +### April 22, 2022 +* `timm` models are now officially supported in [fast.ai](https://www.fast.ai/)! Just in time for the new Practical Deep Learning course. `timmdocs` documentation link updated to [timm.fast.ai](http://timm.fast.ai/). +* Two more model weights added in the TPU trained [series](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights). Some In22k pretrain still in progress. + * `seresnext101d_32x8d` - 83.69 @ 224, 84.35 @ 288 + * `seresnextaa101d_32x8d` (anti-aliased w/ AvgPool2d) - 83.85 @ 224, 84.57 @ 288 + +### March 23, 2022 +* Add `ParallelBlock` and `LayerScale` option to base vit models to support model configs in [Three things everyone should know about ViT](https://arxiv.org/abs/2203.09795) +* `convnext_tiny_hnf` (head norm first) weights trained with (close to) A2 recipe, 82.2% top-1, could do better with more epochs. + +### March 21, 2022 +* Merge `norm_norm_norm`. **IMPORTANT** this update for a coming 0.6.x release will likely de-stabilize the master branch for a while. Branch [`0.5.x`](https://github.com/rwightman/pytorch-image-models/tree/0.5.x) or a previous 0.5.x release can be used if stability is required. 
+* Significant weights update (all TPU trained) as described in this [release](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights) + * `regnety_040` - 82.3 @ 224, 82.96 @ 288 + * `regnety_064` - 83.0 @ 224, 83.65 @ 288 + * `regnety_080` - 83.17 @ 224, 83.86 @ 288 + * `regnetv_040` - 82.44 @ 224, 83.18 @ 288 (timm pre-act) + * `regnetv_064` - 83.1 @ 224, 83.71 @ 288 (timm pre-act) + * `regnetz_040` - 83.67 @ 256, 84.25 @ 320 + * `regnetz_040h` - 83.77 @ 256, 84.5 @ 320 (w/ extra fc in head) + * `resnetv2_50d_gn` - 80.8 @ 224, 81.96 @ 288 (pre-act GroupNorm) + * `resnetv2_50d_evos` 80.77 @ 224, 82.04 @ 288 (pre-act EvoNormS) + * `regnetz_c16_evos` - 81.9 @ 256, 82.64 @ 320 (EvoNormS) + * `regnetz_d8_evos` - 83.42 @ 256, 84.04 @ 320 (EvoNormS) + * `xception41p` - 82 @ 299 (timm pre-act) + * `xception65` - 83.17 @ 299 + * `xception65p` - 83.14 @ 299 (timm pre-act) + * `resnext101_64x4d` - 82.46 @ 224, 83.16 @ 288 + * `seresnext101_32x8d` - 83.57 @ 224, 84.270 @ 288 + * `resnetrs200` - 83.85 @ 256, 84.44 @ 320 +* HuggingFace hub support fixed w/ initial groundwork for allowing alternative 'config sources' for pretrained model definitions and weights (generic local file / remote url support soon) +* SwinTransformer-V2 implementation added. Submitted by [Christoph Reich](https://github.com/ChristophReich1996). Training experiments and model changes by myself are ongoing so expect compat breaks. +* Swin-S3 (AutoFormerV2) models / weights added from https://github.com/microsoft/Cream/tree/main/AutoFormerV2 +* MobileViT models w/ weights adapted from https://github.com/apple/ml-cvnets +* PoolFormer models w/ weights adapted from https://github.com/sail-sg/poolformer +* VOLO models w/ weights adapted from https://github.com/sail-sg/volo +* Significant work experimenting with non-BatchNorm norm layers such as EvoNorm, FilterResponseNorm, GroupNorm, etc +* Enhance support for alternate norm + act ('NormAct') layers added to a number of models, esp EfficientNet/MobileNetV3, RegNet, and aligned Xception +* Grouped conv support added to EfficientNet family +* Add 'group matching' API to all models to allow grouping model parameters for application of 'layer-wise' LR decay, lr scale added to LR scheduler +* Gradient checkpointing support added to many models +* `forward_head(x, pre_logits=False)` fn added to all models to allow separate calls of `forward_features` + `forward_head` +* All vision transformer and vision MLP models update to return non-pooled / non-token selected features from `foward_features`, for consistency with CNN models, token selection or pooling now applied in `forward_head` + +### Feb 2, 2022 +* [Chris Hughes](https://github.com/Chris-hughes10) posted an exhaustive run through of `timm` on his blog yesterday. Well worth a read. [Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) +* I'm currently prepping to merge the `norm_norm_norm` branch back to master (ver 0.6.x) in next week or so. + * The changes are more extensive than usual and may destabilize and break some model API use (aiming for full backwards compat). So, beware `pip install git+https://github.com/rwightman/pytorch-image-models` installs! + * `0.5.x` releases and a `0.5.x` branch will remain stable with a cherry pick or two until dust clears. Recommend sticking to pypi install for a bit if you want stable. 
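As a quick illustration of the `forward_features` / `forward_head` split noted in the March 21, 2022 entry above (a sketch against the current `timm` API):

```python
import torch
import timm

model = timm.create_model('resnet50', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    feats = model.forward_features(x)                    # unpooled features, (1, 2048, 7, 7) here
    pooled = model.forward_head(feats, pre_logits=True)  # pooled features just before the classifier
    logits = model.forward_head(feats)                   # classifier output, (1, 1000)
print(feats.shape, pooled.shape, logits.shape)
```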
+ +### Jan 14, 2022 +* Version 0.5.4 w/ release to be pushed to pypi. It's been a while since last pypi update and riskier changes will be merged to main branch soon.... +* Add ConvNeXT models /w weights from official impl (https://github.com/facebookresearch/ConvNeXt), a few perf tweaks, compatible with timm features +* Tried training a few small (~1.8-3M param) / mobile optimized models, a few are good so far, more on the way... + * `mnasnet_small` - 65.6 top-1 + * `mobilenetv2_050` - 65.9 + * `lcnet_100/075/050` - 72.1 / 68.8 / 63.1 + * `semnasnet_075` - 73 + * `fbnetv3_b/d/g` - 79.1 / 79.7 / 82.0 +* TinyNet models added by [rsomani95](https://github.com/rsomani95) +* LCNet added via MobileNetV3 architecture + ### Jan 5, 2023 * ConvNeXt-V2 models and weights added to existing `convnext.py` * Paper: [ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](http://arxiv.org/abs/2301.00808) diff --git a/onnx_export.py b/onnx_export.py index 54f8f352e5..3baab369ca 100644 --- a/onnx_export.py +++ b/onnx_export.py @@ -21,6 +21,7 @@ import argparse import timm +from timm.utils.model import reparameterize_model from timm.utils.onnx import onnx_export parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') @@ -50,7 +51,12 @@ help='Number classes in dataset') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', help='path to checkpoint (default: none)') - +parser.add_argument('--reparam', default=False, action='store_true', + help='Reparameterize model') +parser.add_argument('--training', default=False, action='store_true', + help='Export in training mode (default is eval)') +parser.add_argument('--verbose', default=False, action='store_true', + help='Extra stdout output') def main(): args = parser.parse_args() @@ -71,6 +77,9 @@ def main(): exportable=True, ) + if args.reparam: + model = reparameterize_model(model) + onnx_export( model, args.output, @@ -79,6 +88,8 @@ def main(): aten_fallback=args.aten_fallback, keep_initializers=args.keep_init, check_forward=args.check_forward, + training=args.training, + verbose=args.verbose, ) diff --git a/tests/test_models.py b/tests/test_models.py index d8ac8d6438..b1b2bf195a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -41,7 +41,7 @@ 'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*', 'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*', - 'eva_*', 'flexivit*', 'eva02*', 'samvit_*' + 'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*' ] NUM_NON_STD = len(NON_STD_FILTERS) @@ -175,7 +175,7 @@ def test_model_default_cfgs(model_name, batch_size): outputs = model.forward_features(input_tensor) assert outputs.shape[spatial_axis[0]] == pool_size[0], 'unpooled feature shape != config' assert outputs.shape[spatial_axis[1]] == pool_size[1], 'unpooled feature shape != config' - if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)): + if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)): assert outputs.shape[feat_axis] == model.num_features # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features @@ -188,8 +188,8 @@ def test_model_default_cfgs(model_name, batch_size): model.reset_classifier(0, '') # reset classifier and set global pooling to 
pass-through outputs = model.forward(input_tensor) assert len(outputs.shape) == 4 - if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)): - # mobilenetv3/ghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP + if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)): + # mobilenetv3/ghostnet/repghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1] if 'pruned' not in model_name: # FIXME better pruned model handling @@ -197,7 +197,7 @@ def test_model_default_cfgs(model_name, batch_size): model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval() outputs = model.forward(input_tensor) assert len(outputs.shape) == 4 - if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)): + if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)): assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1] # check classifier name matches default_cfg diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index caec5e696e..5a610da601 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -37,7 +37,8 @@ from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed from .pool2d_same import AvgPool2dSame, create_pool2d from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc -from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords +from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \ + resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple, resize_rel_pos_bias_table_levit from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \ build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \ FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat diff --git a/timm/layers/interpolate.py b/timm/layers/interpolate.py new file mode 100644 index 0000000000..adba9342ec --- /dev/null +++ b/timm/layers/interpolate.py @@ -0,0 +1,68 @@ +""" Interpolation helpers for timm layers + +RegularGridInterpolator from https://github.com/sbarratt/torch_interpolations +Copyright Shane Barratt, Apache 2.0 license +""" +import torch +from itertools import product + + +class RegularGridInterpolator: + """ Interpolate data defined on a rectilinear grid with even or uneven spacing. + Produces similar results to scipy RegularGridInterpolator or interp2d + in 'linear' mode. 
+ + Taken from https://github.com/sbarratt/torch_interpolations + """ + + def __init__(self, points, values): + self.points = points + self.values = values + + assert isinstance(self.points, tuple) or isinstance(self.points, list) + assert isinstance(self.values, torch.Tensor) + + self.ms = list(self.values.shape) + self.n = len(self.points) + + assert len(self.ms) == self.n + + for i, p in enumerate(self.points): + assert isinstance(p, torch.Tensor) + assert p.shape[0] == self.values.shape[i] + + def __call__(self, points_to_interp): + assert self.points is not None + assert self.values is not None + + assert len(points_to_interp) == len(self.points) + K = points_to_interp[0].shape[0] + for x in points_to_interp: + assert x.shape[0] == K + + idxs = [] + dists = [] + overalls = [] + for p, x in zip(self.points, points_to_interp): + idx_right = torch.bucketize(x, p) + idx_right[idx_right >= p.shape[0]] = p.shape[0] - 1 + idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1) + dist_left = x - p[idx_left] + dist_right = p[idx_right] - x + dist_left[dist_left < 0] = 0. + dist_right[dist_right < 0] = 0. + both_zero = (dist_left == 0) & (dist_right == 0) + dist_left[both_zero] = dist_right[both_zero] = 1. + + idxs.append((idx_left, idx_right)) + dists.append((dist_left, dist_right)) + overalls.append(dist_left + dist_right) + + numerator = 0. + for indexer in product([0, 1], repeat=self.n): + as_s = [idx[onoff] for onoff, idx in zip(indexer, idxs)] + bs_s = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)] + numerator += self.values[as_s] * \ + torch.prod(torch.stack(bs_s), dim=0) + denominator = torch.prod(torch.stack(overalls), dim=0) + return numerator / denominator diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py index 473b095a46..ec8986d33e 100644 --- a/timm/layers/patch_embed.py +++ b/timm/layers/patch_embed.py @@ -26,6 +26,7 @@ class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ output_fmt: Format + dynamic_img_pad: torch.jit.Final[bool] def __init__( self, @@ -38,6 +39,7 @@ def __init__( output_fmt: Optional[str] = None, bias: bool = True, strict_img_size: bool = True, + dynamic_img_pad: bool = False, ): super().__init__() self.patch_size = to_2tuple(patch_size) @@ -58,6 +60,7 @@ def __init__( self.flatten = flatten self.output_fmt = Format.NCHW self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() @@ -68,7 +71,7 @@ def forward(self, x): if self.strict_img_size: _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).") _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).") - else: + elif not self.dynamic_img_pad: _assert( H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})." @@ -77,7 +80,10 @@ def forward(self, x): W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." 
) - + if self.dynamic_img_pad: + pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] + pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] + x = F.pad(x, (0, pad_w, 0, pad_h)) x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # NCHW -> NLC diff --git a/timm/layers/pos_embed.py b/timm/layers/pos_embed.py index 6be0017f05..3e67be0080 100644 --- a/timm/layers/pos_embed.py +++ b/timm/layers/pos_embed.py @@ -29,7 +29,7 @@ def resample_abs_pos_embed( if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]: return posemb - if not old_size: + if old_size is None: hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens)) old_size = hw, hw diff --git a/timm/layers/pos_embed_rel.py b/timm/layers/pos_embed_rel.py index 5cb3d0f4dd..4620e81deb 100644 --- a/timm/layers/pos_embed_rel.py +++ b/timm/layers/pos_embed_rel.py @@ -3,15 +3,19 @@ Hacked together by / Copyright 2022 Ross Wightman """ import math +import os from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F +from .interpolate import RegularGridInterpolator from .mlp import Mlp from .weight_init import trunc_normal_ +_USE_SCIPY = int(os.environ.get('TIMM_USE_SCIPY_INTERP', 0)) > 0 + def gen_relative_position_index( q_size: Tuple[int, int], @@ -20,51 +24,251 @@ def gen_relative_position_index( ) -> torch.Tensor: # Adapted with significant modifications from Swin / BeiT codebases # get pair-wise relative position index for each token inside the window - if k_size is None: - coords = torch.stack( - torch.meshgrid([ - torch.arange(q_size[0]), - torch.arange(q_size[1]) - ]) - ).flatten(1) # 2, Wh, Ww - relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2 - num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1) + 3 - else: - # FIXME different q vs k sizes is a WIP, need to better offset the two grids? - q_coords = torch.stack( - torch.meshgrid([ - torch.arange(q_size[0]), - torch.arange(q_size[1]) - ]) - ).flatten(1) # 2, Wh, Ww - k_coords = torch.stack( - torch.meshgrid([ - torch.arange(k_size[0]), - torch.arange(k_size[1]) - ]) - ).flatten(1) - relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2 - # relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1 # shift to start from 0 - # relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1 - # relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1 - # relative_position_index = relative_coords.sum(-1) # Qh*Qw, Kh*Kw - num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + q_size[1] - 1) + 3 - - _, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0) + assert k_size is None, 'Different q & k sizes not currently supported' # FIXME + + coords = torch.stack( + torch.meshgrid([ + torch.arange(q_size[0]), + torch.arange(q_size[1]) + ]) + ).flatten(1) # 2, Wh, Ww + relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2 + relative_coords[:, :, 0] += q_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += q_size[1] - 1 + relative_coords[:, :, 0] *= 2 * q_size[1] - 1 + num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1) + + # else: + # # FIXME different q vs k sizes is a WIP, need to better offset the two grids? 
+ # q_coords = torch.stack( + # torch.meshgrid([ + # torch.arange(q_size[0]), + # torch.arange(q_size[1]) + # ]) + # ).flatten(1) # 2, Wh, Ww + # k_coords = torch.stack( + # torch.meshgrid([ + # torch.arange(k_size[0]), + # torch.arange(k_size[1]) + # ]) + # ).flatten(1) + # relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww + # relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2 + # relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1 # shift to start from 0 + # relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1 + # relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1 + # relative_position_index = relative_coords.sum(-1) # Qh*Qw, Kh*Kw + # num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + k_size[1] - 1) + 3 + + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww if class_token: # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias # NOTE not intended or tested with MLP log-coords relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0]) - relative_position_index[0, 0:] = num_relative_distance - 3 - relative_position_index[0:, 0] = num_relative_distance - 2 - relative_position_index[0, 0] = num_relative_distance - 1 + relative_position_index[0, 0:] = num_relative_distance + relative_position_index[0:, 0] = num_relative_distance + 1 + relative_position_index[0, 0] = num_relative_distance + 2 return relative_position_index.contiguous() +def resize_rel_pos_bias_table_simple( + rel_pos_bias, + new_window_size: Tuple[int, int], + new_bias_shape: Tuple[int, ...], +): + dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1) + if rel_pos_bias.ndim == 3: + # TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported + _, dst_h, dst_w = new_bias_shape + num_attn_heads, src_h, src_w = rel_pos_bias.shape + assert dst_h == dst_size[0] and dst_w == dst_size[1] + if src_h != dst_h or src_w != dst_w: + rel_pos_bias = torch.nn.functional.interpolate( + rel_pos_bias.unsqueeze(0), + size=dst_size, + mode="bicubic", + align_corners=False, + ).squeeze(0) + else: + assert rel_pos_bias.ndim == 2 + # (num_pos, num_heads) (aka flat) bias shape + dst_num_pos, _ = new_bias_shape + src_num_pos, num_attn_heads = rel_pos_bias.shape + num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1]) + src_size = int((src_num_pos - num_extra_tokens) ** 0.5) + src_size = (src_size, src_size) # FIXME could support non-equal src if argument passed + + if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]: + if num_extra_tokens: + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + else: + extra_tokens = None + + rel_pos_bias = torch.nn.functional.interpolate( + rel_pos_bias.transpose(1, 0).reshape((1, -1, src_size[0], src_size[1])), + size=dst_size, + mode="bicubic", + align_corners=False, + ).view(-1, dst_num_pos - num_extra_tokens).transpose(0, 1) + + if extra_tokens is not None: + rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) + + return rel_pos_bias + + +def resize_rel_pos_bias_table_levit( + position_bias_table, + new_size, + interpolation: str = 'bicubic', + antialias: bool = True, +): + """ + Resample relative position bias table suggested in LeVit + Adapted from: https://github.com/microsoft/Cream/blob/main/TinyViT/utils.py + """ + L1, nH1 = position_bias_table.size() + L2, nH2 = new_size + assert nH1 == nH2 + if L1 != L2: + orig_dtype = position_bias_table.dtype + position_bias_table = 
position_bias_table.float() + # bicubic interpolate relative_position_bias_table if not match + S1 = int(L1 ** 0.5) + S2 = int(L2 ** 0.5) + relative_position_bias_table_resized = F.interpolate( + position_bias_table.permute(1, 0).view(1, nH1, S1, S1), + size=(S2, S2), + mode=interpolation, + antialias=antialias) + relative_position_bias_table_resized = \ + relative_position_bias_table_resized.view(nH2, L2).permute(1, 0) + relative_position_bias_table_resized.to(orig_dtype) + return relative_position_bias_table_resized + else: + return position_bias_table + + +def resize_rel_pos_bias_table( + rel_pos_bias, + new_window_size: Tuple[int, int], + new_bias_shape: Tuple[int, ...], +): + """ Resize relative position bias table using more advanced interpolation. + + Modified from code in Microsoft Unilm (https://github.com/microsoft/unilm) repo (BeiT, BeiT-v2, etc). + + https://github.com/microsoft/unilm/blob/5255d52de86dad642810f5849dd357769346c1d7/beit/run_class_finetuning.py#L351 + + Args: + rel_pos_bias: + new_window_size: + new_bias_shape: + + Returns: + + """ + if _USE_SCIPY: + from scipy import interpolate + + dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1) + if rel_pos_bias.ndim == 3: + # TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported + num_extra_tokens = 0 + _, dst_h, dst_w = new_bias_shape + assert dst_h == dst_size[0] and dst_w == dst_size[1] + num_attn_heads, src_h, src_w = rel_pos_bias.shape + src_size = (src_h, src_w) + has_flat_shape = False + else: + assert rel_pos_bias.ndim == 2 + # (num_pos, num_heads) (aka flat) bias shape + dst_num_pos, _ = new_bias_shape + src_num_pos, num_attn_heads = rel_pos_bias.shape + num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1]) + src_size = int((src_num_pos - num_extra_tokens) ** 0.5) + src_size = (src_size, src_size) + has_flat_shape = True + + if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]: + # print("Interpolating position from %dx%d to %dx%d" % (src_size[0], src_size[1], dst_size[0], dst_size[1])) + if num_extra_tokens: + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + else: + extra_tokens = None + + def geometric_progression(a, r, n): + return a * (1.0 - r ** n) / (1.0 - r) + + def _calc(src, dst): + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src // 2) + if gp > dst // 2: + right = q + else: + left = q + + dis = [] + cur = 1 + for i in range(src // 2): + dis.append(cur) + cur += q ** (i + 1) + r_ids = [-_ for _ in reversed(dis)] + return r_ids + [0] + dis + + y = _calc(src_size[0], dst_size[0]) + x = _calc(src_size[1], dst_size[1]) + yx = [torch.tensor(y), torch.tensor(x)] + # print("Original positions = %s" % str(x)) + + ty = dst_size[0] // 2.0 + tx = dst_size[1] // 2.0 + dy = torch.arange(-ty, ty + 0.1, 1.0) + dx = torch.arange(-tx, tx + 0.1, 1.0) + dyx = torch.meshgrid([dy, dx]) + # print("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + for i in range(num_attn_heads): + if has_flat_shape: + z = rel_pos_bias[:, i].view(src_size[0], src_size[1]).float() + else: + z = rel_pos_bias[i, :, :].float() + + if _USE_SCIPY: + # Original beit code uses scipy w/ cubic interpolation + f = interpolate.interp2d(x, y, z.numpy(), kind='cubic') + r = torch.Tensor(f(dx, dy)).contiguous().to(rel_pos_bias.device) + else: + # Without scipy dependency, I've found a reasonably simple impl + # that supports uneven spaced interpolation pts with 
'linear' interp. + # Results are comparable to scipy for model accuracy in most cases. + f = RegularGridInterpolator(yx, z) + r = f(dyx).contiguous().to(rel_pos_bias.device) + + if has_flat_shape: + r = r.view(-1, 1) + all_rel_pos_bias.append(r) + + if has_flat_shape: + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + else: + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=0) + + if extra_tokens is not None: + assert has_flat_shape + rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) + + return rel_pos_bias + + class RelPosBias(nn.Module): """ Relative Position Bias Adapted from Swin-V1 relative position bias impl, modularized. diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index c7beb3d6f3..e850c03409 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -400,13 +400,12 @@ def __init__( temperature=temperature, step=1, ) - print(bands) self.register_buffer( 'bands', bands, persistent=False, ) - self.embed = None + self.pos_embed = None else: # cache full sin/cos embeddings if shape provided up front embeds = build_rotary_pos_embed( @@ -425,17 +424,19 @@ def __init__( ) def get_embed(self, shape: Optional[List[int]] = None): - if self.bands is not None: + if self.bands is not None and shape is not None: # rebuild embeddings every call, use if target shape changes - _assert(shape is not None, 'valid shape needed') embeds = build_rotary_pos_embed( shape, self.bands, in_pixels=self.in_pixels, + ref_feat_shape=self.ref_feat_shape, ) return torch.cat(embeds, -1) - else: + elif self.pos_embed is not None: return self.pos_embed + else: + assert False, "get_embed() requires pre-computed pos_embed or valid shape w/ pre-computed bands" def forward(self, x): # assuming channel-first tensor where spatial dim are >= 2 diff --git a/timm/models/__init__.py b/timm/models/__init__.py index f308a580b8..0eb9561d54 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -17,12 +17,16 @@ from .efficientformer import * from .efficientformer_v2 import * from .efficientnet import * +from .efficientvit_mit import * +from .efficientvit_msra import * from .eva import * +from .fastvit import * from .focalnet import * from .gcvit import * from .ghostnet import * from .hardcorenas import * from .hrnet import * +from .inception_next import * from .inception_resnet_v2 import * from .inception_v3 import * from .inception_v4 import * @@ -40,6 +44,7 @@ from .pnasnet import * from .pvt_v2 import * from .regnet import * +from .repghost import * from .repvit import * from .res2net import * from .resnest import * @@ -53,6 +58,7 @@ from .swin_transformer import * from .swin_transformer_v2 import * from .swin_transformer_v2_cr import * +from .tiny_vit import * from .tnt import * from .tresnet import * from .twins import * diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py index b84312a739..c48c13b7fc 100644 --- a/timm/models/_features_fx.py +++ b/timm/models/_features_fx.py @@ -18,6 +18,15 @@ from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame from timm.layers.non_local_attn import BilinearAttnTransform from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame +from timm.layers.norm_act import ( + BatchNormAct2d, + SyncBatchNormAct, + FrozenBatchNormAct2d, + GroupNormAct, + GroupNorm1Act, + LayerNormAct, + LayerNormAct2d +) __all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules', 'register_notrace_function', 'is_notrace_function', 'get_notrace_functions', @@ 
-30,7 +39,14 @@ BilinearAttnTransform, # reason: flow control t <= 1 # Reason: get_same_padding has a max which raises a control flow error Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, - CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) + CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]), + BatchNormAct2d, + SyncBatchNormAct, + FrozenBatchNormAct2d, + GroupNormAct, + GroupNorm1Act, + LayerNormAct, + LayerNormAct2d, } try: diff --git a/timm/models/beit.py b/timm/models/beit.py index 3472c0dca2..3863198f12 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -48,6 +48,8 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn +from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table + from ._builder import build_model_with_cfg from ._registry import generate_default_cfgs, register_model @@ -115,7 +117,7 @@ def __init__( self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH - self.register_buffer("relative_position_index", gen_relative_position_index(window_size)) + self.register_buffer("relative_position_index", gen_relative_position_index(window_size), persistent=False) else: self.window_size = None self.relative_position_bias_table = None @@ -504,11 +506,46 @@ def _cfg(url='', **kwargs): }) -def _beit_checkpoint_filter_fn(state_dict, model): - if 'module' in state_dict: - # beit v2 didn't strip module - state_dict = state_dict['module'] - return checkpoint_filter_fn(state_dict, model) +def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True): + state_dict = state_dict.get('model', state_dict) + state_dict = state_dict.get('module', state_dict) + # beit v2 didn't strip module + + out_dict = {} + for k, v in state_dict.items(): + if 'relative_position_index' in k: + continue + if 'patch_embed.proj.weight' in k: + O, I, H, W = model.patch_embed.proj.weight.shape + if v.shape[-1] != W or v.shape[-2] != H: + v = resample_patch_embed( + v, + (H, W), + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: + # To resize pos embedding when using model at different size from pretrained weights + num_prefix_tokens = 1 + v = resample_abs_pos_embed( + v, + new_size=model.patch_embed.grid_size, + num_prefix_tokens=num_prefix_tokens, + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + elif k.endswith('relative_position_bias_table'): + m = model.get_submodule(k[:-29]) + if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]: + v = resize_rel_pos_bias_table( + v, + new_window_size=m.window_size, + new_bias_shape=m.relative_position_bias_table.shape, + ) + out_dict[k] = v + return out_dict def _create_beit(variant, pretrained=False, **kwargs): diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index afb5df4282..a504b7262b 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -12,6 +12,10 @@ Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 Code and weights: https://github.com/DingXiaoH/RepVGG, licensed MIT +MobileOne - 
mobileone_* +Paper: `MobileOne: An Improved One millisecond Mobile Backbone` - https://arxiv.org/abs/2206.04040 +Code and weights: https://github.com/apple/ml-mobileone, licensed MIT + In all cases the models have been modified to fit within the design of ByobNet. I've remapped the original weights and verified accuracies. @@ -91,6 +95,27 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0): return bcfg +def _mobileone_bcfg(d=(2, 8, 10, 1), wf=(1., 1., 1., 1.), se_blocks=(), num_conv_branches=1): + c = (64, 128, 256, 512) + prev_c = min(64, c[0] * wf[0]) + se_blocks = se_blocks or (0,) * len(d) + bcfg = [] + for d, c, w, se in zip(d, c, wf, se_blocks): + scfg = [] + for i in range(d): + out_c = c * w + bk = dict(num_conv_branches=num_conv_branches) + ak = {} + if i >= d - se: + ak['attn_layer'] = 'se' + scfg += [ByoBlockCfg(type='one', d=1, c=prev_c, gs=1, block_kwargs=bk, **ak)] # depthwise block + scfg += [ByoBlockCfg( + type='one', d=1, c=out_c, gs=0, block_kwargs=dict(kernel_size=1, **bk), **ak)] # pointwise block + prev_c = out_c + bcfg += [scfg] + return bcfg + + def interleave_blocks( types: Tuple[str, str], d, every: Union[int, List[int]] = 1, @@ -447,8 +472,6 @@ class RepVggBlock(nn.Module): """ RepVGG Block. Adapted from impl at https://github.com/DingXiaoH/RepVGG - - This version does not currently support the deploy optimization. It is currently fixed in 'train' mode. """ def __init__( @@ -464,20 +487,34 @@ def __init__( layers: LayerFn = None, drop_block: Callable = None, drop_path_rate: float = 0., + inference_mode: bool = False ): super(RepVggBlock, self).__init__() + self.groups = groups = num_groups(group_size, in_chs) layers = layers or LayerFn() - groups = num_groups(group_size, in_chs) - use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1] - self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None - self.conv_kxk = layers.conv_norm_act( - in_chs, out_chs, kernel_size, - stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, apply_act=False, - ) - self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False) + if inference_mode: + self.reparam_conv = nn.Conv2d( + in_channels=in_chs, + out_channels=out_chs, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + bias=True, + ) + else: + self.reparam_conv = None + use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1] + self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None + self.conv_kxk = layers.conv_norm_act( + in_chs, out_chs, kernel_size, + stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, apply_act=False, + ) + self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity() + self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) - self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
and use_ident else nn.Identity() self.act = layers.act(inplace=True) def init_weights(self, zero_init_last: bool = False): @@ -490,16 +527,298 @@ def init_weights(self, zero_init_last: bool = False): self.attn.reset_parameters() def forward(self, x): + if self.reparam_conv is not None: + return self.act(self.attn(self.reparam_conv(x))) + if self.identity is None: x = self.conv_1x1(x) + self.conv_kxk(x) else: identity = self.identity(x) x = self.conv_1x1(x) + self.conv_kxk(x) x = self.drop_path(x) # not in the paper / official impl, experimental - x = x + identity + x += identity x = self.attn(x) # no attn in the paper / official impl, experimental return self.act(x) + def reparameterize(self): + """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` - + https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched + architecture used at training time to obtain a plain CNN-like structure + for inference. + """ + if self.reparam_conv is not None: + return + + kernel, bias = self._get_kernel_bias() + self.reparam_conv = nn.Conv2d( + in_channels=self.conv_kxk.conv.in_channels, + out_channels=self.conv_kxk.conv.out_channels, + kernel_size=self.conv_kxk.conv.kernel_size, + stride=self.conv_kxk.conv.stride, + padding=self.conv_kxk.conv.padding, + dilation=self.conv_kxk.conv.dilation, + groups=self.conv_kxk.conv.groups, + bias=True, + ) + self.reparam_conv.weight.data = kernel + self.reparam_conv.bias.data = bias + + # Delete un-used branches + for name, para in self.named_parameters(): + if 'reparam_conv' in name: + continue + para.detach_() + self.__delattr__('conv_kxk') + self.__delattr__('conv_1x1') + self.__delattr__('identity') + self.__delattr__('drop_path') + + def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: + """ Method to obtain re-parameterized kernel and bias. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 + """ + # get weights and bias of scale branch + kernel_1x1 = 0 + bias_1x1 = 0 + if self.conv_1x1 is not None: + kernel_1x1, bias_1x1 = self._fuse_bn_tensor(self.conv_1x1) + # Pad scale branch kernel to match conv branch kernel size. + pad = self.conv_kxk.conv.kernel_size[0] // 2 + kernel_1x1 = torch.nn.functional.pad(kernel_1x1, [pad, pad, pad, pad]) + + # get weights and bias of skip branch + kernel_identity = 0 + bias_identity = 0 + if self.identity is not None: + kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity) + + # get weights and bias of conv branches + kernel_conv, bias_conv = self._fuse_bn_tensor(self.conv_kxk) + + kernel_final = kernel_conv + kernel_1x1 + kernel_identity + bias_final = bias_conv + bias_1x1 + bias_identity + return kernel_final, bias_final + + def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]: + """ Method to fuse batchnorm layer with preceeding conv layer. 
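+        Folds the BN statistics into the conv weights: w' = w * gamma / sqrt(var + eps)
+        and b' = beta - mean * gamma / sqrt(var + eps). A BN-only identity branch is
+        treated as a conv whose weight is a per-group identity kernel.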
+ Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 + """ + if isinstance(branch, ConvNormAct): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + else: + assert isinstance(branch, nn.BatchNorm2d) + if not hasattr(self, 'id_tensor'): + in_chs = self.conv_kxk.conv.in_channels + input_dim = in_chs // self.groups + kernel_size = self.conv_kxk.conv.kernel_size + kernel_value = torch.zeros_like(self.conv_kxk.conv.weight) + for i in range(in_chs): + kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1 + self.id_tensor = kernel_value + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class MobileOneBlock(nn.Module): + """ MobileOne building block. + + This block has a multi-branched architecture at train-time + and plain-CNN style architecture at inference time + For more details, please refer to our paper: + `An Improved One millisecond Mobile Backbone` - + https://arxiv.org/pdf/2206.04040.pdf + """ + + def __init__( + self, + in_chs: int, + out_chs: int, + kernel_size: int = 3, + stride: int = 1, + dilation: Tuple[int, int] = (1, 1), + bottle_ratio: float = 1.0, # unused + group_size: Optional[int] = None, + downsample: str = '', # unused + inference_mode: bool = False, + num_conv_branches: int = 1, + layers: LayerFn = None, + drop_block: Callable = None, + drop_path_rate: float = 0., + ) -> None: + """ Construct a MobileOneBlock module. + """ + super(MobileOneBlock, self).__init__() + self.num_conv_branches = num_conv_branches + self.groups = groups = num_groups(group_size, in_chs) + layers = layers or LayerFn() + + if inference_mode: + self.reparam_conv = nn.Conv2d( + in_channels=in_chs, + out_channels=out_chs, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + bias=True) + else: + self.reparam_conv = None + + # Re-parameterizable skip connection + use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1] + self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None + + # Re-parameterizable conv branches + convs = [] + for _ in range(self.num_conv_branches): + convs.append(layers.conv_norm_act( + in_chs, out_chs, kernel_size=kernel_size, + stride=stride, groups=groups, apply_act=False)) + self.conv_kxk = nn.ModuleList(convs) + + # Re-parameterizable scale branch + self.conv_scale = None + if kernel_size > 1: + self.conv_scale = layers.conv_norm_act( + in_chs, out_chs, kernel_size=1, + stride=stride, groups=groups, apply_act=False) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity() + + self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) + self.act = layers.act(inplace=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Apply forward pass. """ + # Inference mode forward pass. + if self.reparam_conv is not None: + return self.act(self.attn(self.reparam_conv(x))) + + # Multi-branched train-time forward pass. 
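+        # Sum of three branch types: the BN-only identity branch, the 1x1 'scale'
+        # branch (present when kernel_size > 1), and num_conv_branches kxk conv
+        # branches; reparameterize() later folds them into a single reparam_conv.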
+ # Skip branch output + identity_out = 0 + if self.identity is not None: + identity_out = self.identity(x) + + # Scale branch output + scale_out = 0 + if self.conv_scale is not None: + scale_out = self.conv_scale(x) + + # Other branches + out = scale_out + for ck in self.conv_kxk: + out += ck(x) + out = self.drop_path(out) + out += identity_out + + return self.act(self.attn(out)) + + def reparameterize(self): + """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` - + https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched + architecture used at training time to obtain a plain CNN-like structure + for inference. + """ + if self.reparam_conv is not None: + return + + kernel, bias = self._get_kernel_bias() + self.reparam_conv = nn.Conv2d( + in_channels=self.conv_kxk[0].conv.in_channels, + out_channels=self.conv_kxk[0].conv.out_channels, + kernel_size=self.conv_kxk[0].conv.kernel_size, + stride=self.conv_kxk[0].conv.stride, + padding=self.conv_kxk[0].conv.padding, + dilation=self.conv_kxk[0].conv.dilation, + groups=self.conv_kxk[0].conv.groups, + bias=True) + self.reparam_conv.weight.data = kernel + self.reparam_conv.bias.data = bias + + # Delete un-used branches + for name, para in self.named_parameters(): + if 'reparam_conv' in name: + continue + para.detach_() + self.__delattr__('conv_kxk') + self.__delattr__('conv_scale') + self.__delattr__('identity') + self.__delattr__('drop_path') + + def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: + """ Method to obtain re-parameterized kernel and bias. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 + """ + # get weights and bias of scale branch + kernel_scale = 0 + bias_scale = 0 + if self.conv_scale is not None: + kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale) + # Pad scale branch kernel to match conv branch kernel size. + pad = self.conv_kxk[0].conv.kernel_size[0] // 2 + kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad]) + + # get weights and bias of skip branch + kernel_identity = 0 + bias_identity = 0 + if self.identity is not None: + kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity) + + # get weights and bias of conv branches + kernel_conv = 0 + bias_conv = 0 + for ix in range(self.num_conv_branches): + _kernel, _bias = self._fuse_bn_tensor(self.conv_kxk[ix]) + kernel_conv += _kernel + bias_conv += _bias + + kernel_final = kernel_conv + kernel_scale + kernel_identity + bias_final = bias_conv + bias_scale + bias_identity + return kernel_final, bias_final + + def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]: + """ Method to fuse batchnorm layer with preceeding conv layer. 
+ Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 + """ + if isinstance(branch, ConvNormAct): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + else: + assert isinstance(branch, nn.BatchNorm2d) + if not hasattr(self, 'id_tensor'): + in_chs = self.conv_kxk[0].conv.in_channels + input_dim = in_chs // self.groups + kernel_size = self.conv_kxk[0].conv.kernel_size + kernel_value = torch.zeros_like(self.conv_kxk[0].conv.weight) + for i in range(in_chs): + kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1 + self.id_tensor = kernel_value + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + class SelfAttnBlock(nn.Module): """ ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1 @@ -576,6 +895,7 @@ def forward(self, x): dark=DarkBlock, edge=EdgeBlock, rep=RepVggBlock, + one=MobileOneBlock, self_attn=SelfAttnBlock, ) @@ -657,7 +977,7 @@ def create_byob_stem( layers: LayerFn = None, ): layers = layers or LayerFn() - assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', '7x7', '3x3') + assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', 'one', '7x7', '3x3') if 'quad' in stem_type: # based on NFNet stem, stack of 4 3x3 convs num_act = 2 if 'quad2' in stem_type else None @@ -670,6 +990,8 @@ def create_byob_stem( stem = Stem(in_chs, out_chs, num_rep=3, chs_decay=1.0, pool=pool_type, layers=layers) elif 'rep' in stem_type: stem = RepVggBlock(in_chs, out_chs, stride=2, layers=layers) + elif 'one' in stem_type: + stem = MobileOneBlock(in_chs, out_chs, kernel_size=3, stride=2, layers=layers) elif '7x7' in stem_type: # 7x7 stem conv as in ResNet if pool_type: @@ -993,6 +1315,16 @@ def _init_weights(module, name='', zero_init_last=False): num_features=1920, ), + repvgg_a0=ByoModelCfg( + blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(0.75, 0.75, 0.75, 2.5)), + stem_type='rep', + stem_chs=48, + ), + repvgg_a1=ByoModelCfg( + blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1, 1, 1, 2.5)), + stem_type='rep', + stem_chs=64, + ), repvgg_a2=ByoModelCfg( blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1.5, 1.5, 1.5, 2.75)), stem_type='rep', @@ -1033,6 +1365,13 @@ def _init_weights(module, name='', zero_init_last=False): stem_type='rep', stem_chs=64, ), + repvgg_d2se=ByoModelCfg( + blocks=_rep_vgg_bcfg(d=(8, 14, 24, 1), wf=(2.5, 2.5, 2.5, 5.)), + stem_type='rep', + stem_chs=64, + attn_layer='se', + attn_kwargs=dict(rd_ratio=0.0625, rd_divisor=1), + ), # 4 x conv stem w/ 2 act, no maxpool, 2,4,6,4 repeats, group size 32 in first 3 blocks # DW convs in last block, 2048 pre-FC, silu act @@ -1375,6 +1714,32 @@ def _init_weights(module, name='', zero_init_last=False): attn_kwargs=dict(rd_ratio=0.25), block_kwargs=dict(bottle_in=True, linear_out=True), ), + + mobileone_s0=ByoModelCfg( + blocks=_mobileone_bcfg(wf=(0.75, 1.0, 1.0, 2.), num_conv_branches=4), + stem_type='one', + stem_chs=48, + ), + mobileone_s1=ByoModelCfg( + blocks=_mobileone_bcfg(wf=(1.5, 1.5, 2.0, 2.5)), + stem_type='one', + stem_chs=64, + ), + mobileone_s2=ByoModelCfg( + blocks=_mobileone_bcfg(wf=(1.5, 2.0, 2.5, 4.0)), + stem_type='one', + stem_chs=64, + ), + mobileone_s3=ByoModelCfg( + 
blocks=_mobileone_bcfg(wf=(2.0, 2.5, 3.0, 4.0)), + stem_type='one', + stem_chs=64, + ), + mobileone_s4=ByoModelCfg( + blocks=_mobileone_bcfg(wf=(3.0, 3.5, 3.5, 4.0), se_blocks=(0, 0, 5, 1)), + stem_type='one', + stem_chs=64, + ), ) @@ -1413,6 +1778,12 @@ def _cfgr(url='', **kwargs): 'gernet_l.idstcv_in1k': _cfg(hf_hub_id='timm/', input_size=(3, 256, 256), pool_size=(8, 8)), # RepVGG weights + 'repvgg_a0.rvgg_in1k': _cfg( + hf_hub_id='timm/', + first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'), + 'repvgg_a1.rvgg_in1k': _cfg( + hf_hub_id='timm/', + first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'), 'repvgg_a2.rvgg_in1k': _cfg( hf_hub_id='timm/', first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'), @@ -1437,6 +1808,11 @@ def _cfgr(url='', **kwargs): 'repvgg_b3g4.rvgg_in1k': _cfg( hf_hub_id='timm/', first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'), + 'repvgg_d2se.rvgg_in1k': _cfg( + hf_hub_id='timm/', + first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit', + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, + ), # experimental ResNet configs 'resnet51q.ra2_in1k': _cfg( @@ -1539,6 +1915,32 @@ def _cfgr(url='', **kwargs): hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_d8_evos_ch-2bc12646.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0), + + 'mobileone_s0.apple_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.875, + first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'), + ), + 'mobileone_s1.apple_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.9, + first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'), + ), + 'mobileone_s2.apple_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.9, + first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'), + ), + 'mobileone_s3.apple_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.9, + first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'), + ), + 'mobileone_s4.apple_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.9, + first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'), + ), }) @@ -1566,6 +1968,22 @@ def gernet_s(pretrained=False, **kwargs) -> ByobNet: return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs) +@register_model +def repvgg_a0(pretrained=False, **kwargs) -> ByobNet: + """ RepVGG-A0 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_a0', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_a1(pretrained=False, **kwargs) -> ByobNet: + """ RepVGG-A1 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_a1', pretrained=pretrained, **kwargs) + + @register_model def repvgg_a2(pretrained=False, **kwargs) -> ByobNet: """ RepVGG-A2 @@ -1630,6 +2048,14 @@ def repvgg_b3g4(pretrained=False, **kwargs) -> ByobNet: return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs) +@register_model +def repvgg_d2se(pretrained=False, **kwargs) -> ByobNet: + """ RepVGG-D2se + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_d2se', pretrained=pretrained, **kwargs) + + @register_model def resnet51q(pretrained=False, **kwargs) -> ByobNet: """ @@ -1782,3 +2208,38 @@ def regnetz_d8_evos(pretrained=False, **kwargs) -> ByobNet: """ """ return _create_byobnet('regnetz_d8_evos', 
pretrained=pretrained, **kwargs) + + +@register_model +def mobileone_s0(pretrained=False, **kwargs) -> ByobNet: + """ + """ + return _create_byobnet('mobileone_s0', pretrained=pretrained, **kwargs) + + +@register_model +def mobileone_s1(pretrained=False, **kwargs) -> ByobNet: + """ + """ + return _create_byobnet('mobileone_s1', pretrained=pretrained, **kwargs) + + +@register_model +def mobileone_s2(pretrained=False, **kwargs) -> ByobNet: + """ + """ + return _create_byobnet('mobileone_s2', pretrained=pretrained, **kwargs) + + +@register_model +def mobileone_s3(pretrained=False, **kwargs) -> ByobNet: + """ + """ + return _create_byobnet('mobileone_s3', pretrained=pretrained, **kwargs) + + +@register_model +def mobileone_s4(pretrained=False, **kwargs) -> ByobNet: + """ + """ + return _create_byobnet('mobileone_s4', pretrained=pretrained, **kwargs) diff --git a/timm/models/coat.py b/timm/models/coat.py index f6c89af24d..68358b3d6b 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -690,8 +690,11 @@ def checkpoint_filter_fn(state_dict, model): for k, v in state_dict.items(): # original model had unused norm layers, removing them requires filtering pretrained checkpoints if k.startswith('norm1') or \ - (model.norm2 is None and k.startswith('norm2')) or \ - (model.norm3 is None and k.startswith('norm3')): + (k.startswith('norm2') and getattr(model, 'norm2', None) is None) or \ + (k.startswith('norm3') and getattr(model, 'norm3', None) is None) or \ + (k.startswith('norm4') and getattr(model, 'norm4', None) is None) or \ + (k.startswith('aggregate') and getattr(model, 'aggregate', None) is None) or \ + (k.startswith('head') and getattr(model, 'head', None) is None): continue out_dict[k] = v return out_dict diff --git a/timm/models/deit.py b/timm/models/deit.py index 650ab6796e..f80087e80d 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -73,45 +73,36 @@ def reset_classifier(self, num_classes, global_pool=None): def set_distilled_training(self, enable=True): self.distilled_training = enable - def _intermediate_layers( - self, - x: torch.Tensor, - n: Union[int, Sequence] = 1, - ): - outputs, num_blocks = [], len(self.blocks) - take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n) - - # forward pass - x = self.patch_embed(x) - x = torch.cat(( - self.cls_token.expand(x.shape[0], -1, -1), - self.dist_token.expand(x.shape[0], -1, -1), - x), - dim=1) - x = self.pos_drop(x + self.pos_embed) - x = self.patch_drop(x) - x = self.norm_pre(x) - for i, blk in enumerate(self.blocks): - x = blk(x) - if i in take_indices: - outputs.append(x) - - return outputs - - def forward_features(self, x) -> torch.Tensor: - x = self.patch_embed(x) - x = torch.cat(( - self.cls_token.expand(x.shape[0], -1, -1), - self.dist_token.expand(x.shape[0], -1, -1), - x), - dim=1) - x = self.pos_drop(x + self.pos_embed) - if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.blocks, x) + def _pos_embed(self, x): + if self.dynamic_img_size: + B, H, W, C = x.shape + pos_embed = resample_abs_pos_embed( + self.pos_embed, + (H, W), + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, + ) + x = x.view(B, -1, C) else: - x = self.blocks(x) - x = self.norm(x) - return x + pos_embed = self.pos_embed + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + pos_embed + x = torch.cat(( + self.cls_token.expand(x.shape[0], -1, -1), + 
self.dist_token.expand(x.shape[0], -1, -1), + x), + dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + x = torch.cat(( + self.cls_token.expand(x.shape[0], -1, -1), + self.dist_token.expand(x.shape[0], -1, -1), + x), + dim=1) + x = x + pos_embed + return self.pos_drop(x) def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor: x, x_dist = x[:, 0], x[:, 1] diff --git a/timm/models/efficientformer_v2.py b/timm/models/efficientformer_v2.py index 9388131ed6..357b258dec 100644 --- a/timm/models/efficientformer_v2.py +++ b/timm/models/efficientformer_v2.py @@ -232,7 +232,7 @@ def __init__( self.attention_biases = nn.Parameter(torch.zeros(num_heads, self.N)) k_pos = torch.stack(torch.meshgrid(torch.arange( - self.resolution[1]), + self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1) q_pos = torch.stack(torch.meshgrid( torch.arange(0, self.resolution[0], step=2), diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py new file mode 100644 index 0000000000..6fe444ada1 --- /dev/null +++ b/timm/models/efficientvit_mit.py @@ -0,0 +1,691 @@ +""" EfficientViT (by MIT Song Han's Lab) + +Paper: `Efficientvit: Enhanced linear attention for high-resolution low-computation visual recognition` + - https://arxiv.org/abs/2205.14756 + +Adapted from official impl at https://github.com/mit-han-lab/efficientvit +""" + +__all__ = ['EfficientVit'] +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import SelectAdaptivePool2d, create_conv2d +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._manipulate import checkpoint_seq +from ._registry import register_model, generate_default_cfgs + + +def val2list(x: list or tuple or any, repeat_time=1): + if isinstance(x, (list, tuple)): + return list(x) + return [x for _ in range(repeat_time)] + + +def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1): + # repeat elements if necessary + x = val2list(x) + if len(x) > 0: + x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))] + + return tuple(x) + + +def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]: + if isinstance(kernel_size, tuple): + return tuple([get_same_padding(ks) for ks in kernel_size]) + else: + assert kernel_size % 2 > 0, "kernel size should be odd number" + return kernel_size // 2 + + +class ConvNormAct(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + dilation=1, + groups=1, + bias=False, + dropout=0., + norm_layer=nn.BatchNorm2d, + act_layer=nn.ReLU, + ): + super(ConvNormAct, self).__init__() + self.dropout = nn.Dropout(dropout, inplace=False) + self.conv = create_conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + bias=bias, + ) + self.norm = norm_layer(num_features=out_channels) if norm_layer else nn.Identity() + self.act = act_layer(inplace=True) if act_layer else nn.Identity() + + def forward(self, x): + x = self.dropout(x) + x = self.conv(x) + x = self.norm(x) + x = self.act(x) + return x + + +class DSConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + use_bias=False, + norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d), + act_layer=(nn.ReLU6, None), + ): 
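+        # Depthwise-separable conv: a kernel_size depthwise conv followed by a 1x1
+        # pointwise conv; bias / norm_layer / act_layer are expanded to per-conv
+        # 2-tuples via val2tuple below.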
+ super(DSConv, self).__init__() + use_bias = val2tuple(use_bias, 2) + norm_layer = val2tuple(norm_layer, 2) + act_layer = val2tuple(act_layer, 2) + + self.depth_conv = ConvNormAct( + in_channels, + in_channels, + kernel_size, + stride, + groups=in_channels, + norm_layer=norm_layer[0], + act_layer=act_layer[0], + bias=use_bias[0], + ) + self.point_conv = ConvNormAct( + in_channels, + out_channels, + 1, + norm_layer=norm_layer[1], + act_layer=act_layer[1], + bias=use_bias[1], + ) + + def forward(self, x): + x = self.depth_conv(x) + x = self.point_conv(x) + return x + + +class MBConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + mid_channels=None, + expand_ratio=6, + use_bias=False, + norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d), + act_layer=(nn.ReLU6, nn.ReLU6, None), + ): + super(MBConv, self).__init__() + use_bias = val2tuple(use_bias, 3) + norm_layer = val2tuple(norm_layer, 3) + act_layer = val2tuple(act_layer, 3) + mid_channels = mid_channels or round(in_channels * expand_ratio) + + self.inverted_conv = ConvNormAct( + in_channels, + mid_channels, + 1, + stride=1, + norm_layer=norm_layer[0], + act_layer=act_layer[0], + bias=use_bias[0], + ) + self.depth_conv = ConvNormAct( + mid_channels, + mid_channels, + kernel_size, + stride=stride, + groups=mid_channels, + norm_layer=norm_layer[1], + act_layer=act_layer[1], + bias=use_bias[1], + ) + self.point_conv = ConvNormAct( + mid_channels, + out_channels, + 1, + norm_layer=norm_layer[2], + act_layer=act_layer[2], + bias=use_bias[2], + ) + + def forward(self, x): + x = self.inverted_conv(x) + x = self.depth_conv(x) + x = self.point_conv(x) + return x + + +class LiteMSA(nn.Module): + """Lightweight multi-scale attention""" + + def __init__( + self, + in_channels: int, + out_channels: int, + heads: int or None = None, + heads_ratio: float = 1.0, + dim=8, + use_bias=False, + norm_layer=(None, nn.BatchNorm2d), + act_layer=(None, None), + kernel_func=nn.ReLU, + scales=(5,), + eps=1e-5, + ): + super(LiteMSA, self).__init__() + self.eps = eps + heads = heads or int(in_channels // dim * heads_ratio) + total_dim = heads * dim + use_bias = val2tuple(use_bias, 2) + norm_layer = val2tuple(norm_layer, 2) + act_layer = val2tuple(act_layer, 2) + + self.dim = dim + self.qkv = ConvNormAct( + in_channels, + 3 * total_dim, + 1, + bias=use_bias[0], + norm_layer=norm_layer[0], + act_layer=act_layer[0], + ) + self.aggreg = nn.ModuleList([ + nn.Sequential( + nn.Conv2d( + 3 * total_dim, + 3 * total_dim, + scale, + padding=get_same_padding(scale), + groups=3 * total_dim, + bias=use_bias[0], + ), + nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]), + ) + for scale in scales + ]) + self.kernel_func = kernel_func(inplace=False) + + self.proj = ConvNormAct( + total_dim * (1 + len(scales)), + out_channels, + 1, + bias=use_bias[1], + norm_layer=norm_layer[1], + act_layer=act_layer[1], + ) + + def _attn(self, q, k, v): + dtype = v.dtype + q, k, v = q.float(), k.float(), v.float() + kv = k.transpose(-1, -2) @ v + out = q @ kv + out = out[..., :-1] / (out[..., -1:] + self.eps) + return out.to(dtype) + + def forward(self, x): + B, _, H, W = x.shape + + # generate multi-scale q, k, v + qkv = self.qkv(x) + multi_scale_qkv = [qkv] + for op in self.aggreg: + multi_scale_qkv.append(op(qkv)) + multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1) + multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(-1, -2) + q, k, v = multi_scale_qkv.chunk(3, dim=-1) + + 
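+        # ReLU-kernel linear attention: q and k pass through kernel_func, and v is
+        # padded with a ones column so q @ (k^T v) accumulates the normalizer in its
+        # last channel, which _attn() divides out in place of a softmax.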
# lightweight global attention + q = self.kernel_func(q) + k = self.kernel_func(k) + v = F.pad(v, (0, 1), mode="constant", value=1.) + + if not torch.jit.is_scripting(): + with torch.autocast(device_type=v.device.type, enabled=False): + out = self._attn(q, k, v) + else: + out = self._attn(q, k, v) + + # final projection + out = out.transpose(-1, -2).reshape(B, -1, H, W) + out = self.proj(out) + return out + + +register_notrace_module(LiteMSA) + + +class EfficientVitBlock(nn.Module): + def __init__( + self, + in_channels, + heads_ratio=1.0, + head_dim=32, + expand_ratio=4, + norm_layer=nn.BatchNorm2d, + act_layer=nn.Hardswish, + ): + super(EfficientVitBlock, self).__init__() + self.context_module = ResidualBlock( + LiteMSA( + in_channels=in_channels, + out_channels=in_channels, + heads_ratio=heads_ratio, + dim=head_dim, + norm_layer=(None, norm_layer), + ), + nn.Identity(), + ) + self.local_module = ResidualBlock( + MBConv( + in_channels=in_channels, + out_channels=in_channels, + expand_ratio=expand_ratio, + use_bias=(True, True, False), + norm_layer=(None, None, norm_layer), + act_layer=(act_layer, act_layer, None), + ), + nn.Identity(), + ) + + def forward(self, x): + x = self.context_module(x) + x = self.local_module(x) + return x + + +class ResidualBlock(nn.Module): + def __init__( + self, + main: Optional[nn.Module], + shortcut: Optional[nn.Module] = None, + pre_norm: Optional[nn.Module] = None, + ): + super(ResidualBlock, self).__init__() + self.pre_norm = pre_norm if pre_norm is not None else nn.Identity() + self.main = main + self.shortcut = shortcut + + def forward(self, x): + res = self.main(self.pre_norm(x)) + if self.shortcut is not None: + res = res + self.shortcut(x) + return res + + +def build_local_block( + in_channels: int, + out_channels: int, + stride: int, + expand_ratio: float, + norm_layer: str, + act_layer: str, + fewer_norm: bool = False, +): + if expand_ratio == 1: + block = DSConv( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + use_bias=(True, False) if fewer_norm else False, + norm_layer=(None, norm_layer) if fewer_norm else norm_layer, + act_layer=(act_layer, None), + ) + else: + block = MBConv( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + expand_ratio=expand_ratio, + use_bias=(True, True, False) if fewer_norm else False, + norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer, + act_layer=(act_layer, act_layer, None), + ) + return block + + +class Stem(nn.Sequential): + def __init__(self, in_chs, out_chs, depth, norm_layer, act_layer): + super().__init__() + self.stride = 2 + + self.add_module( + 'in_conv', + ConvNormAct( + in_chs, out_chs, + kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer, + ) + ) + stem_block = 0 + for _ in range(depth): + self.add_module(f'res{stem_block}', ResidualBlock( + build_local_block( + in_channels=out_chs, + out_channels=out_chs, + stride=1, + expand_ratio=1, + norm_layer=norm_layer, + act_layer=act_layer, + ), + nn.Identity(), + )) + stem_block += 1 + + +class EfficientVitStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + depth, + norm_layer, + act_layer, + expand_ratio, + head_dim, + vit_stage=False, + ): + super(EfficientVitStage, self).__init__() + blocks = [ResidualBlock( + build_local_block( + in_channels=in_chs, + out_channels=out_chs, + stride=2, + expand_ratio=expand_ratio, + norm_layer=norm_layer, + act_layer=act_layer, + fewer_norm=vit_stage, + ), + None, + )] + in_chs = out_chs + + if vit_stage: + # for stage 3, 4 + for 
_ in range(depth): + blocks.append( + EfficientVitBlock( + in_channels=in_chs, + head_dim=head_dim, + expand_ratio=expand_ratio, + norm_layer=norm_layer, + act_layer=act_layer, + ) + ) + else: + # for stage 1, 2 + for i in range(1, depth): + blocks.append(ResidualBlock( + build_local_block( + in_channels=in_chs, + out_channels=out_chs, + stride=1, + expand_ratio=expand_ratio, + norm_layer=norm_layer, + act_layer=act_layer + ), + nn.Identity(), + )) + + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + return self.blocks(x) + + +class ClassifierHead(nn.Module): + def __init__( + self, + in_channels, + widths, + n_classes=1000, + dropout=0., + norm_layer=nn.BatchNorm2d, + act_layer=nn.Hardswish, + global_pool='avg', + ): + super(ClassifierHead, self).__init__() + self.in_conv = ConvNormAct(in_channels, widths[0], 1, norm_layer=norm_layer, act_layer=act_layer) + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW') + self.classifier = nn.Sequential( + nn.Linear(widths[0], widths[1], bias=False), + nn.LayerNorm(widths[1]), + act_layer(inplace=True), + nn.Dropout(dropout, inplace=False), + nn.Linear(widths[1], n_classes, bias=True), + ) + + def forward(self, x, pre_logits: bool = False): + x = self.in_conv(x) + x = self.global_pool(x) + if pre_logits: + return x + x = self.classifier(x) + return x + + +class EfficientVit(nn.Module): + def __init__( + self, + in_chans=3, + widths=(), + depths=(), + head_dim=32, + expand_ratio=4, + norm_layer=nn.BatchNorm2d, + act_layer=nn.Hardswish, + global_pool='avg', + head_widths=(), + drop_rate=0.0, + num_classes=1000, + ): + super(EfficientVit, self).__init__() + self.grad_checkpointing = False + self.global_pool = global_pool + self.num_classes = num_classes + + # input stem + self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer) + stride = self.stem.stride + + # stages + self.feature_info = [] + self.stages = nn.Sequential() + in_channels = widths[0] + for i, (w, d) in enumerate(zip(widths[1:], depths[1:])): + self.stages.append(EfficientVitStage( + in_channels, + w, + depth=d, + norm_layer=norm_layer, + act_layer=act_layer, + expand_ratio=expand_ratio, + head_dim=head_dim, + vit_stage=i >= 2, + )) + stride *= 2 + in_channels = w + self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{i}')] + + self.num_features = in_channels + self.head_widths = head_widths + self.head_dropout = drop_rate + if num_classes > 0: + self.head = ClassifierHead( + self.num_features, + self.head_widths, + n_classes=num_classes, + dropout=self.head_dropout, + global_pool=self.global_pool, + ) + else: + if self.global_pool == 'avg': + self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) + else: + self.head = nn.Identity() + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+).downsample', (0,)), + (r'^stages\.(\d+)\.\w+\.(\d+)', None), + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.classifier[-1] + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + if num_classes > 0: + self.head = ClassifierHead( + self.num_features, + self.head_widths, + n_classes=num_classes, + dropout=self.head_dropout, + 
global_pool=self.global_pool, + ) + else: + if self.global_pool == 'avg': + self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True) + else: + self.head = nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, + 'mean': IMAGENET_DEFAULT_MEAN, + 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.in_conv.conv', + 'classifier': 'head.classifier.4', + 'crop_pct': 0.95, + 'input_size': (3, 224, 224), + 'pool_size': (7, 7), + **kwargs, + } + + +default_cfgs = generate_default_cfgs({ + 'efficientvit_b0.r224_in1k': _cfg( + hf_hub_id='timm/', + ), + 'efficientvit_b1.r224_in1k': _cfg( + hf_hub_id='timm/', + ), + 'efficientvit_b1.r256_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, + ), + 'efficientvit_b1.r288_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, + ), + 'efficientvit_b2.r224_in1k': _cfg( + hf_hub_id='timm/', + ), + 'efficientvit_b2.r256_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, + ), + 'efficientvit_b2.r288_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, + ), + 'efficientvit_b3.r224_in1k': _cfg( + hf_hub_id='timm/', + ), + 'efficientvit_b3.r256_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, + ), + 'efficientvit_b3.r288_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, + ), +}) + + +def _create_efficientvit(variant, pretrained=False, **kwargs): + out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) + model = build_model_with_cfg( + EfficientVit, + variant, + pretrained, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs + ) + return model + + +@register_model +def efficientvit_b0(pretrained=False, **kwargs): + model_args = dict( + widths=(8, 16, 32, 64, 128), depths=(1, 2, 2, 2, 2), head_dim=16, head_widths=(1024, 1280)) + return _create_efficientvit('efficientvit_b0', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_b1(pretrained=False, **kwargs): + model_args = dict( + widths=(16, 32, 64, 128, 256), depths=(1, 2, 3, 3, 4), head_dim=16, head_widths=(1536, 1600)) + return _create_efficientvit('efficientvit_b1', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_b2(pretrained=False, **kwargs): + model_args = dict( + widths=(24, 48, 96, 192, 384), depths=(1, 3, 4, 4, 6), head_dim=32, head_widths=(2304, 2560)) + return _create_efficientvit('efficientvit_b2', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_b3(pretrained=False, **kwargs): + model_args = dict( + widths=(32, 64, 128, 256, 512), depths=(1, 4, 6, 6, 9), head_dim=32, head_widths=(2304, 2560)) + return _create_efficientvit('efficientvit_b3', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py new file mode 100644 index 0000000000..1b7f52a02f --- 
/dev/null +++ b/timm/models/efficientvit_msra.py @@ -0,0 +1,659 @@ +""" EfficientViT (by MSRA) + +Paper: `EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention` + - https://arxiv.org/abs/2305.07027 + +Adapted from official impl at https://github.com/microsoft/Cream/tree/main/EfficientViT +""" + +__all__ = ['EfficientVitMsra'] +import itertools +from collections import OrderedDict +from typing import Dict + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model, generate_default_cfgs + + +class ConvNorm(torch.nn.Sequential): + def __init__(self, in_chs, out_chs, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): + super().__init__() + self.conv = nn.Conv2d(in_chs, out_chs, ks, stride, pad, dilation, groups, bias=False) + self.bn = nn.BatchNorm2d(out_chs) + torch.nn.init.constant_(self.bn.weight, bn_weight_init) + torch.nn.init.constant_(self.bn.bias, 0) + + @torch.no_grad() + def fuse(self): + c, bn = self.conv, self.bn + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = torch.nn.Conv2d( + w.size(1) * self.conv.groups, w.size(0), w.shape[2:], + stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class NormLinear(torch.nn.Sequential): + def __init__(self, in_features, out_features, bias=True, std=0.02, drop=0.): + super().__init__() + self.bn = nn.BatchNorm1d(in_features) + self.drop = nn.Dropout(drop) + self.linear = nn.Linear(in_features, out_features, bias=bias) + + trunc_normal_(self.linear.weight, std=std) + if self.linear.bias is not None: + nn.init.constant_(self.linear.bias, 0) + + @torch.no_grad() + def fuse(self): + bn, linear = self.bn, self.linear + w = bn.weight / (bn.running_var + bn.eps)**0.5 + b = bn.bias - self.bn.running_mean * \ + self.bn.weight / (bn.running_var + bn.eps)**0.5 + w = linear.weight * w[None, :] + if linear.bias is None: + b = b @ self.linear.weight.T + else: + b = (linear.weight @ b[:, None]).view(-1) + self.linear.bias + m = torch.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class PatchMerging(torch.nn.Module): + def __init__(self, dim, out_dim): + super().__init__() + hid_dim = int(dim * 4) + self.conv1 = ConvNorm(dim, hid_dim, 1, 1, 0) + self.act = torch.nn.ReLU() + self.conv2 = ConvNorm(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim) + self.se = SqueezeExcite(hid_dim, .25) + self.conv3 = ConvNorm(hid_dim, out_dim, 1, 1, 0) + + def forward(self, x): + x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x)))))) + return x + + +class ResidualDrop(torch.nn.Module): + def __init__(self, m, drop=0.): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * torch.rand( + x.size(0), 1, 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + +class ConvMlp(torch.nn.Module): + def __init__(self, ed, h): + super().__init__() + self.pw1 = ConvNorm(ed, h) + self.act = torch.nn.ReLU() + self.pw2 = ConvNorm(h, ed, bn_weight_init=0) + + def forward(self, 
x): + x = self.pw2(self.act(self.pw1(x))) + return x + + +class CascadedGroupAttention(torch.nn.Module): + attention_bias_cache: Dict[str, torch.Tensor] + + r""" Cascaded Group Attention. + + Args: + dim (int): Number of input channels. + key_dim (int): The dimension for query and key. + num_heads (int): Number of attention heads. + attn_ratio (int): Multiplier for the query dim for value dimension. + resolution (int): Input resolution, correspond to the window size. + kernels (List[int]): The kernel size of the dw conv on query. + """ + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=14, + kernels=(5, 5, 5, 5), + ): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.val_dim = int(attn_ratio * key_dim) + self.attn_ratio = attn_ratio + + qkvs = [] + dws = [] + for i in range(num_heads): + qkvs.append(ConvNorm(dim // (num_heads), self.key_dim * 2 + self.val_dim)) + dws.append(ConvNorm(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim)) + self.qkvs = torch.nn.ModuleList(qkvs) + self.dws = torch.nn.ModuleList(dws) + self.proj = torch.nn.Sequential( + torch.nn.ReLU(), + ConvNorm(self.val_dim * num_heads, dim, bn_weight_init=0) + ) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False) + self.attention_bias_cache = {} + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.attention_bias_cache: + self.attention_bias_cache = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if torch.jit.is_tracing() or self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.attention_bias_cache: + self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.attention_bias_cache[device_key] + + def forward(self, x): + B, C, H, W = x.shape + feats_in = x.chunk(len(self.qkvs), dim=1) + feats_out = [] + feat = feats_in[0] + attn_bias = self.get_attention_biases(x.device) + for head_idx, (qkv, dws) in enumerate(zip(self.qkvs, self.dws)): + if head_idx > 0: + feat = feat + feats_in[head_idx] + feat = qkv(feat) + q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.val_dim], dim=1) + q = dws(q) + q, k, v = q.flatten(2), k.flatten(2), v.flatten(2) + q = q * self.scale + attn = q.transpose(-2, -1) @ k + attn = attn + attn_bias[head_idx] + attn = attn.softmax(dim=-1) + feat = v @ attn.transpose(-2, -1) + feat = feat.view(B, self.val_dim, H, W) + feats_out.append(feat) + x = self.proj(torch.cat(feats_out, 1)) + return x + + +class LocalWindowAttention(torch.nn.Module): + r""" Local Window Attention. + + Args: + dim (int): Number of input channels. + key_dim (int): The dimension for query and key. + num_heads (int): Number of attention heads. + attn_ratio (int): Multiplier for the query dim for value dimension. + resolution (int): Input resolution. 
+ window_resolution (int): Local window resolution. + kernels (List[int]): The kernel size of the dw conv on query. + """ + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=14, + window_resolution=7, + kernels=(5, 5, 5, 5), + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.resolution = resolution + assert window_resolution > 0, 'window_size must be greater than 0' + self.window_resolution = window_resolution + window_resolution = min(window_resolution, resolution) + self.attn = CascadedGroupAttention( + dim, key_dim, num_heads, + attn_ratio=attn_ratio, + resolution=window_resolution, + kernels=kernels, + ) + + def forward(self, x): + H = W = self.resolution + B, C, H_, W_ = x.shape + # Only check this for classifcation models + _assert(H == H_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}') + _assert(W == W_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}') + if H <= self.window_resolution and W <= self.window_resolution: + x = self.attn(x) + else: + x = x.permute(0, 2, 3, 1) + pad_b = (self.window_resolution - H % self.window_resolution) % self.window_resolution + pad_r = (self.window_resolution - W % self.window_resolution) % self.window_resolution + x = torch.nn.functional.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_resolution + nW = pW // self.window_resolution + # window partition, BHWC -> B(nHh)(nWw)C -> BnHnWhwC -> (BnHnW)hwC -> (BnHnW)Chw + x = x.view(B, nH, self.window_resolution, nW, self.window_resolution, C).transpose(2, 3) + x = x.reshape(B * nH * nW, self.window_resolution, self.window_resolution, C).permute(0, 3, 1, 2) + x = self.attn(x) + # window reverse, (BnHnW)Chw -> (BnHnW)hwC -> BnHnWhwC -> B(nHh)(nWw)C -> BHWC + x = x.permute(0, 2, 3, 1).view(B, nH, nW, self.window_resolution, self.window_resolution, C) + x = x.transpose(2, 3).reshape(B, pH, pW, C) + x = x[:, :H, :W].contiguous() + x = x.permute(0, 3, 1, 2) + return x + + +class EfficientVitBlock(torch.nn.Module): + """ A basic EfficientVit building block. + + Args: + dim (int): Number of input channels. + key_dim (int): Dimension for query and key in the token mixer. + num_heads (int): Number of attention heads. + attn_ratio (int): Multiplier for the query dim for value dimension. + resolution (int): Input resolution. + window_resolution (int): Local window resolution. + kernels (List[int]): The kernel size of the dw conv on query. 
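+
+    The block applies dw-conv -> ConvMlp -> LocalWindowAttention -> dw-conv -> ConvMlp,
+    each wrapped in ResidualDrop to add a residual connection with optional drop.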
+ """ + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=14, + window_resolution=7, + kernels=[5, 5, 5, 5], + ): + super().__init__() + + self.dw0 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0.)) + self.ffn0 = ResidualDrop(ConvMlp(dim, int(dim * 2))) + + self.mixer = ResidualDrop( + LocalWindowAttention( + dim, key_dim, num_heads, + attn_ratio=attn_ratio, + resolution=resolution, + window_resolution=window_resolution, + kernels=kernels, + ) + ) + + self.dw1 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0.)) + self.ffn1 = ResidualDrop(ConvMlp(dim, int(dim * 2))) + + def forward(self, x): + return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x))))) + + +class EfficientVitStage(torch.nn.Module): + def __init__( + self, + in_dim, + out_dim, + key_dim, + downsample=('', 1), + num_heads=8, + attn_ratio=4, + resolution=14, + window_resolution=7, + kernels=[5, 5, 5, 5], + depth=1, + ): + super().__init__() + if downsample[0] == 'subsample': + self.resolution = (resolution - 1) // downsample[1] + 1 + down_blocks = [] + down_blocks.append(( + 'res1', + torch.nn.Sequential( + ResidualDrop(ConvNorm(in_dim, in_dim, 3, 1, 1, groups=in_dim)), + ResidualDrop(ConvMlp(in_dim, int(in_dim * 2))), + ) + )) + down_blocks.append(('patchmerge', PatchMerging(in_dim, out_dim))) + down_blocks.append(( + 'res2', + torch.nn.Sequential( + ResidualDrop(ConvNorm(out_dim, out_dim, 3, 1, 1, groups=out_dim)), + ResidualDrop(ConvMlp(out_dim, int(out_dim * 2))), + ) + )) + self.downsample = nn.Sequential(OrderedDict(down_blocks)) + else: + assert in_dim == out_dim + self.downsample = nn.Identity() + self.resolution = resolution + + blocks = [] + for d in range(depth): + blocks.append(EfficientVitBlock(out_dim, key_dim, num_heads, attn_ratio, self.resolution, window_resolution, kernels)) + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + x = self.downsample(x) + x = self.blocks(x) + return x + + +class PatchEmbedding(torch.nn.Sequential): + def __init__(self, in_chans, dim): + super().__init__() + self.add_module('conv1', ConvNorm(in_chans, dim // 8, 3, 2, 1)) + self.add_module('relu1', torch.nn.ReLU()) + self.add_module('conv2', ConvNorm(dim // 8, dim // 4, 3, 2, 1)) + self.add_module('relu2', torch.nn.ReLU()) + self.add_module('conv3', ConvNorm(dim // 4, dim // 2, 3, 2, 1)) + self.add_module('relu3', torch.nn.ReLU()) + self.add_module('conv4', ConvNorm(dim // 2, dim, 3, 2, 1)) + self.patch_size = 16 + + +class EfficientVitMsra(nn.Module): + def __init__( + self, + img_size=224, + in_chans=3, + num_classes=1000, + embed_dim=(64, 128, 192), + key_dim=(16, 16, 16), + depth=(1, 2, 3), + num_heads=(4, 4, 4), + window_size=(7, 7, 7), + kernels=(5, 5, 5, 5), + down_ops=(('', 1), ('subsample', 2), ('subsample', 2)), + global_pool='avg', + drop_rate=0., + ): + super(EfficientVitMsra, self).__init__() + self.grad_checkpointing = False + self.num_classes = num_classes + self.drop_rate = drop_rate + + # Patch embedding + self.patch_embed = PatchEmbedding(in_chans, embed_dim[0]) + stride = self.patch_embed.patch_size + resolution = img_size // self.patch_embed.patch_size + attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))] + + # Build EfficientVit blocks + self.feature_info = [] + stages = [] + pre_ed = embed_dim[0] + for i, (ed, kd, dpth, nh, ar, wd, do) in enumerate( + zip(embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)): + stage = EfficientVitStage( + in_dim=pre_ed, + 
out_dim=ed, + key_dim=kd, + downsample=do, + num_heads=nh, + attn_ratio=ar, + resolution=resolution, + window_resolution=wd, + kernels=kernels, + depth=dpth, + ) + pre_ed = ed + if do[0] == 'subsample' and i != 0: + stride *= do[1] + resolution = stage.resolution + stages.append(stage) + self.feature_info += [dict(num_chs=ed, reduction=stride, module=f'stages.{i}')] + self.stages = nn.Sequential(*stages) + + if global_pool == 'avg': + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) + else: + assert num_classes == 0 + self.global_pool = nn.Identity() + self.num_features = embed_dim[-1] + self.head = NormLinear( + self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity() + + @torch.jit.ignore + def no_weight_decay(self): + return {x for x in self.state_dict().keys() if 'attention_biases' in x} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^patch_embed', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+).downsample', (0,)), + (r'^stages\.(\d+)\.\w+\.(\d+)', None), + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.linear + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + if global_pool == 'avg': + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) + else: + assert num_classes == 0 + self.global_pool = nn.Identity() + self.head = NormLinear( + self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + x = self.global_pool(x) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +# def checkpoint_filter_fn(state_dict, model): +# if 'model' in state_dict.keys(): +# state_dict = state_dict['model'] +# tmp_dict = {} +# out_dict = {} +# target_keys = model.state_dict().keys() +# target_keys = [k for k in target_keys if k.startswith('stages.')] +# +# for k, v in state_dict.items(): +# if 'attention_bias_idxs' in k: +# continue +# k = k.split('.') +# if k[-2] == 'c': +# k[-2] = 'conv' +# if k[-2] == 'l': +# k[-2] = 'linear' +# k = '.'.join(k) +# tmp_dict[k] = v +# +# for k, v in tmp_dict.items(): +# if k.startswith('patch_embed'): +# k = k.split('.') +# k[1] = 'conv' + str(int(k[1]) // 2 + 1) +# k = '.'.join(k) +# elif k.startswith('blocks'): +# kw = '.'.join(k.split('.')[2:]) +# find_kw = [a for a in list(sorted(tmp_dict.keys())) if kw in a] +# idx = find_kw.index(k) +# k = [a for a in target_keys if kw in a][idx] +# out_dict[k] = v +# +# return out_dict + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, + 'mean': IMAGENET_DEFAULT_MEAN, + 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.conv1.conv', + 'classifier': 'head.linear', + 'fixed_input_size': True, + 'pool_size': (4, 4), + **kwargs, + } + + +default_cfgs = generate_default_cfgs({ + 'efficientvit_m0.r224_in1k': _cfg( + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth' + ), + 
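For the defaults above, `PatchEmbedding` reduces the input by a factor of 16 (four stride-2 convs), and each `'subsample'` stage then shrinks the grid via `(resolution - 1) // 2 + 1` while the recorded feature reduction doubles. A small sketch tracing those numbers for a 224 px input (the values are simply the defaults in this file):

```python
img_size, patch_stride = 224, 16            # PatchEmbedding: four stride-2 convs
resolution = img_size // patch_stride       # 14
stride = patch_stride

down_ops = (('', 1), ('subsample', 2), ('subsample', 2))    # EfficientVitMsra defaults
for i, (op, s) in enumerate(down_ops):
    if op == 'subsample':
        resolution = (resolution - 1) // s + 1              # as in EfficientVitStage
        if i != 0:
            stride *= s                                     # reduction stored in feature_info
    print(f'stage {i}: resolution={resolution}, reduction={stride}')
# stage 0: resolution=14, reduction=16
# stage 1: resolution=7, reduction=32
# stage 2: resolution=4, reduction=64
```

The variants registered a little further down are then created through the usual `timm` factory; a minimal usage sketch with random weights (`pretrained=True` would pull the `r224_in1k` checkpoints referenced in `default_cfgs`):

```python
import timm
import torch

model = timm.create_model('efficientvit_m0', pretrained=False).eval()
with torch.no_grad():
    print(model(torch.randn(1, 3, 224, 224)).shape)     # torch.Size([1, 1000])

# Backbone mode returns the three per-stage maps at reductions 16 / 32 / 64
backbone = timm.create_model('efficientvit_m0', pretrained=False, features_only=True)
with torch.no_grad():
    print([f.shape for f in backbone(torch.randn(1, 3, 224, 224))])
```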
'efficientvit_m1.r224_in1k': _cfg( + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth' + ), + 'efficientvit_m2.r224_in1k': _cfg( + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth' + ), + 'efficientvit_m3.r224_in1k': _cfg( + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth' + ), + 'efficientvit_m4.r224_in1k': _cfg( + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth' + ), + 'efficientvit_m5.r224_in1k': _cfg( + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth' + ), +}) + + +def _create_efficientvit_msra(variant, pretrained=False, **kwargs): + out_indices = kwargs.pop('out_indices', (0, 1, 2)) + model = build_model_with_cfg( + EfficientVitMsra, + variant, + pretrained, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs + ) + return model + + +@register_model +def efficientvit_m0(pretrained=False, **kwargs): + model_args = dict( + img_size=224, + embed_dim=[64, 128, 192], + depth=[1, 2, 3], + num_heads=[4, 4, 4], + window_size=[7, 7, 7], + kernels=[5, 5, 5, 5] + ) + return _create_efficientvit_msra('efficientvit_m0', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_m1(pretrained=False, **kwargs): + model_args = dict( + img_size=224, + embed_dim=[128, 144, 192], + depth=[1, 2, 3], + num_heads=[2, 3, 3], + window_size=[7, 7, 7], + kernels=[7, 5, 3, 3] + ) + return _create_efficientvit_msra('efficientvit_m1', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_m2(pretrained=False, **kwargs): + model_args = dict( + img_size=224, + embed_dim=[128, 192, 224], + depth=[1, 2, 3], + num_heads=[4, 3, 2], + window_size=[7, 7, 7], + kernels=[7, 5, 3, 3] + ) + return _create_efficientvit_msra('efficientvit_m2', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_m3(pretrained=False, **kwargs): + model_args = dict( + img_size=224, + embed_dim=[128, 240, 320], + depth=[1, 2, 3], + num_heads=[4, 3, 4], + window_size=[7, 7, 7], + kernels=[5, 5, 5, 5] + ) + return _create_efficientvit_msra('efficientvit_m3', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_m4(pretrained=False, **kwargs): + model_args = dict( + img_size=224, + embed_dim=[128, 256, 384], + depth=[1, 2, 3], + num_heads=[4, 4, 4], + window_size=[7, 7, 7], + kernels=[7, 5, 3, 3] + ) + return _create_efficientvit_msra('efficientvit_m4', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def efficientvit_m5(pretrained=False, **kwargs): + model_args = dict( + img_size=224, + embed_dim=[192, 288, 384], + depth=[1, 3, 4], + num_heads=[3, 3, 4], + window_size=[7, 7, 7], + kernels=[7, 5, 3, 3] + ) + return _create_efficientvit_msra('efficientvit_m5', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/timm/models/eva.py b/timm/models/eva.py index f0ab9c7224..81bcce525d 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -367,6 +367,8 @@ def __init__( use_abs_pos_emb: bool = True, use_rot_pos_emb: bool = False, use_post_norm: bool = False, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, ref_feat_shape: 
Optional[Union[Tuple[int, int], int]] = None, head_init_scale: float = 0.001, ): @@ -406,13 +408,20 @@ def __init__( self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_prefix_tokens = 1 if class_token else 0 + self.dynamic_img_size = dynamic_img_size self.grad_checkpointing = False + embed_args = {} + if dynamic_img_size: + # flatten deferred until after pos embed + embed_args.update(dict(strict_img_size=False, output_fmt='NHWC')) self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + dynamic_img_pad=dynamic_img_pad, + **embed_args, ) num_patches = self.patch_embed.num_patches @@ -435,7 +444,7 @@ def __init__( self.rope = RotaryEmbeddingCat( embed_dim // num_heads, in_pixels=False, - feat_shape=self.patch_embed.grid_size, + feat_shape=None if dynamic_img_size else self.patch_embed.grid_size, ref_feat_shape=ref_feat_shape, ) else: @@ -519,30 +528,44 @@ def reset_classifier(self, num_classes, global_pool=None): self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() - def forward_features(self, x): - x = self.patch_embed(x) + def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if self.dynamic_img_size: + B, H, W, C = x.shape + if self.pos_embed is not None: + pos_embed = resample_abs_pos_embed( + self.pos_embed, + (H, W), + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + pos_embed = None + x = x.view(B, -1, C) + rot_pos_embed = self.rope.get_embed(shape=(H, W)) if self.rope is not None else None + else: + pos_embed = self.pos_embed + rot_pos_embed = self.rope.get_embed() if self.rope is not None else None if self.cls_token is not None: x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - - # apply abs position embedding - if self.pos_embed is not None: - x = x + self.pos_embed + if pos_embed is not None: + x = x + pos_embed x = self.pos_drop(x) # obtain shared rotary position embedding and apply patch dropout - rot_pos_embed = self.rope.get_embed() if self.rope is not None else None if self.patch_drop is not None: x, keep_indices = self.patch_drop(x) if rot_pos_embed is not None and keep_indices is not None: rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices) + return x, rot_pos_embed + def forward_features(self, x): + x = self.patch_embed(x) + x, rot_pos_embed = self._pos_embed(x) for blk in self.blocks: if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint(blk, x, rope=rot_pos_embed) else: x = blk(x, rope=rot_pos_embed) - x = self.norm(x) return x diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py new file mode 100644 index 0000000000..f61b54e5ff --- /dev/null +++ b/timm/models/fastvit.py @@ -0,0 +1,1413 @@ +# FastViT for PyTorch +# +# Original implementation and weights from https://github.com/apple/ml-fastvit +# +# For licensing see accompanying LICENSE file at https://github.com/apple/ml-fastvit/tree/main +# Original work is copyright (C) 2023 Apple Inc. All Rights Reserved. 
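The `_pos_embed` hook added to `eva.py` above resamples the stored absolute position embedding to whatever grid the current input produced (via `resample_abs_pos_embed`) and asks the rotary embedding for a table matching that grid. A rough standalone sketch of what resampling an absolute position embedding involves, using plain bicubic interpolation; this is only an illustration of the idea, not the `timm.layers` implementation, and it ignores class/prefix tokens:

```python
import torch
import torch.nn.functional as F

def naive_resample_abs_pos_embed(pos_embed: torch.Tensor, new_hw, old_hw):
    """pos_embed: (1, old_h*old_w, C) -> (1, new_h*new_w, C), no prefix tokens."""
    old_h, old_w = old_hw
    new_h, new_w = new_hw
    _, n, c = pos_embed.shape
    assert n == old_h * old_w
    pe = pos_embed.reshape(1, old_h, old_w, c).permute(0, 3, 1, 2)   # 1, C, H, W
    pe = F.interpolate(pe, size=(new_h, new_w), mode='bicubic', align_corners=False)
    return pe.permute(0, 2, 3, 1).reshape(1, new_h * new_w, c)

pe = torch.randn(1, 14 * 14, 768)                  # e.g. a 14x14 token grid
print(naive_resample_abs_pos_embed(pe, (16, 20), (14, 14)).shape)   # torch.Size([1, 320, 768])
```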
+# +import os +from functools import partial +from typing import Tuple, Optional, Union + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \ + ClassifierHead +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model, generate_default_cfgs + + +def num_groups(group_size, channels): + if not group_size: # 0 or None + return 1 # normal conv with 1 group + else: + # NOTE group_size == 1 -> depthwise conv + assert channels % group_size == 0 + return channels // group_size + + +class MobileOneBlock(nn.Module): + """MobileOne building block. + + This block has a multi-branched architecture at train-time + and plain-CNN style architecture at inference time + For more details, please refer to our paper: + `An Improved One millisecond Mobile Backbone` - + https://arxiv.org/pdf/2206.04040.pdf + """ + + def __init__( + self, + in_chs: int, + out_chs: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + group_size: int = 0, + inference_mode: bool = False, + use_se: bool = False, + use_act: bool = True, + use_scale_branch: bool = True, + num_conv_branches: int = 1, + act_layer: nn.Module = nn.GELU, + ) -> None: + """Construct a MobileOneBlock module. + + Args: + in_chs: Number of channels in the input. + out_chs: Number of channels produced by the block. + kernel_size: Size of the convolution kernel. + stride: Stride size. + dilation: Kernel dilation factor. + group_size: Convolution group size. + inference_mode: If True, instantiates model in inference mode. + use_se: Whether to use SE-ReLU activations. + use_act: Whether to use activation. Default: ``True`` + use_scale_branch: Whether to use scale branch. Default: ``True`` + num_conv_branches: Number of linear conv branches. + """ + super(MobileOneBlock, self).__init__() + self.inference_mode = inference_mode + self.groups = num_groups(group_size, in_chs) + self.stride = stride + self.dilation = dilation + self.kernel_size = kernel_size + self.in_chs = in_chs + self.out_chs = out_chs + self.num_conv_branches = num_conv_branches + + # Check if SE-ReLU is requested + self.se = SqueezeExcite(out_chs, rd_divisor=1) if use_se else nn.Identity() + + if inference_mode: + self.reparam_conv = create_conv2d( + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=self.groups, + bias=True, + ) + else: + # Re-parameterizable skip connection + self.reparam_conv = None + + self.identity = ( + nn.BatchNorm2d(num_features=in_chs) + if out_chs == in_chs and stride == 1 + else None + ) + + # Re-parameterizable conv branches + if num_conv_branches > 0: + self.conv_kxk = nn.ModuleList([ + ConvNormAct( + self.in_chs, + self.out_chs, + kernel_size=kernel_size, + stride=self.stride, + groups=self.groups, + apply_act=False, + ) for _ in range(self.num_conv_branches) + ]) + else: + self.conv_kxk = None + + # Re-parameterizable scale branch + self.conv_scale = None + if kernel_size > 1 and use_scale_branch: + self.conv_scale = ConvNormAct( + self.in_chs, + self.out_chs, + kernel_size=1, + stride=self.stride, + groups=self.groups, + apply_act=False + ) + + self.act = act_layer() if use_act else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply forward pass.""" + # Inference mode forward pass. 
+ if self.reparam_conv is not None: + return self.act(self.se(self.reparam_conv(x))) + + # Multi-branched train-time forward pass. + # Identity branch output + identity_out = 0 + if self.identity is not None: + identity_out = self.identity(x) + + # Scale branch output + scale_out = 0 + if self.conv_scale is not None: + scale_out = self.conv_scale(x) + + # Other kxk conv branches + out = scale_out + identity_out + if self.conv_kxk is not None: + for rc in self.conv_kxk: + out += rc(x) + + return self.act(self.se(out)) + + def reparameterize(self): + """Following works like `RepVGG: Making VGG-style ConvNets Great Again` - + https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched + architecture used at training time to obtain a plain CNN-like structure + for inference. + """ + if self.reparam_conv is not None: + return + + kernel, bias = self._get_kernel_bias() + self.reparam_conv = create_conv2d( + in_channels=self.in_chs, + out_channels=self.out_chs, + kernel_size=self.kernel_size, + stride=self.stride, + dilation=self.dilation, + groups=self.groups, + bias=True, + ) + self.reparam_conv.weight.data = kernel + self.reparam_conv.bias.data = bias + + # Delete un-used branches + for name, para in self.named_parameters(): + if 'reparam_conv' in name: + continue + para.detach_() + + self.__delattr__("conv_kxk") + self.__delattr__("conv_scale") + if hasattr(self, "identity"): + self.__delattr__("identity") + + self.inference_mode = True + + def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Method to obtain re-parameterized kernel and bias. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 + + Returns: + Tuple of (kernel, bias) after fusing branches. + """ + # get weights and bias of scale branch + kernel_scale = 0 + bias_scale = 0 + if self.conv_scale is not None: + kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale) + # Pad scale branch kernel to match conv branch kernel size. + pad = self.kernel_size // 2 + kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad]) + + # get weights and bias of skip branch + kernel_identity = 0 + bias_identity = 0 + if self.identity is not None: + kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity) + + # get weights and bias of conv branches + kernel_conv = 0 + bias_conv = 0 + if self.conv_kxk is not None: + for ix in range(self.num_conv_branches): + _kernel, _bias = self._fuse_bn_tensor(self.conv_kxk[ix]) + kernel_conv += _kernel + bias_conv += _bias + + kernel_final = kernel_conv + kernel_scale + kernel_identity + bias_final = bias_conv + bias_scale + bias_identity + return kernel_final, bias_final + + def _fuse_bn_tensor( + self, branch: Union[nn.Sequential, nn.BatchNorm2d] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Method to fuse batchnorm layer with preceeding conv layer. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 + + Args: + branch: Sequence of ops to be fused. + + Returns: + Tuple of (kernel, bias) after fusing batchnorm. 
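Each MobileOne branch (identity BatchNorm, 1x1 scale conv, the kxk convs) is linear up to the shared activation, so `reparameterize()` should leave the block's input-output mapping unchanged in eval mode. A quick equivalence check, assuming the class is importable from `timm.models.fastvit` once this file is in place:

```python
import torch
from timm.models.fastvit import MobileOneBlock  # new module added in this diff

blk = MobileOneBlock(in_chs=32, out_chs=32, kernel_size=3, stride=1, group_size=1).eval()
x = torch.randn(2, 32, 16, 16)

with torch.no_grad():
    y_branched = blk(x)         # multi-branch train-time graph
    blk.reparameterize()        # fuse identity + scale + kxk into blk.reparam_conv
    y_fused = blk(x)            # single depthwise conv + activation

print(torch.allclose(y_branched, y_fused, atol=1e-5))   # True up to float error
```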
+ """ + if isinstance(branch, ConvNormAct): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + else: + assert isinstance(branch, nn.BatchNorm2d) + if not hasattr(self, "id_tensor"): + input_dim = self.in_chs // self.groups + kernel_value = torch.zeros( + (self.in_chs, input_dim, self.kernel_size, self.kernel_size), + dtype=branch.weight.dtype, + device=branch.weight.device, + ) + for i in range(self.in_chs): + kernel_value[ + i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2 + ] = 1 + self.id_tensor = kernel_value + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class ReparamLargeKernelConv(nn.Module): + """Building Block of RepLKNet + + This class defines overparameterized large kernel conv block + introduced in `RepLKNet `_ + + Reference: https://github.com/DingXiaoH/RepLKNet-pytorch + """ + + def __init__( + self, + in_chs: int, + out_chs: int, + kernel_size: int, + stride: int, + group_size: int, + small_kernel: Optional[int] = None, + inference_mode: bool = False, + act_layer: Optional[nn.Module] = None, + ) -> None: + """Construct a ReparamLargeKernelConv module. + + Args: + in_chs: Number of input channels. + out_chs: Number of output channels. + kernel_size: Kernel size of the large kernel conv branch. + stride: Stride size. Default: 1 + group_size: Group size. Default: 1 + small_kernel: Kernel size of small kernel conv branch. + inference_mode: If True, instantiates model in inference mode. Default: ``False`` + act_layer: Activation module. Default: ``nn.GELU`` + """ + super(ReparamLargeKernelConv, self).__init__() + self.stride = stride + self.groups = num_groups(group_size, in_chs) + self.in_chs = in_chs + self.out_chs = out_chs + + self.kernel_size = kernel_size + self.small_kernel = small_kernel + if inference_mode: + self.reparam_conv = create_conv2d( + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride, + dilation=1, + groups=self.groups, + bias=True, + ) + else: + self.reparam_conv = None + self.large_conv = ConvNormAct( + in_chs, + out_chs, + kernel_size=kernel_size, + stride=self.stride, + groups=self.groups, + apply_act=False, + ) + if small_kernel is not None: + assert ( + small_kernel <= kernel_size + ), "The kernel size for re-param cannot be larger than the large kernel!" + self.small_conv = ConvNormAct( + in_chs, + out_chs, + kernel_size=small_kernel, + stride=self.stride, + groups=self.groups, + apply_act=False, + ) + # FIXME output of this act was not used in original impl, likely due to bug + self.act = act_layer() if act_layer is not None else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.reparam_conv is not None: + out = self.reparam_conv(x) + else: + out = self.large_conv(x) + if self.small_conv is not None: + out = out + self.small_conv(x) + out = self.act(out) + return out + + def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Method to obtain re-parameterized kernel and bias. + Reference: https://github.com/DingXiaoH/RepLKNet-pytorch + + Returns: + Tuple of (kernel, bias) after fusing branches. 
+ """ + eq_k, eq_b = self._fuse_bn(self.large_conv.conv, self.large_conv.bn) + if hasattr(self, "small_conv"): + small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn) + eq_b += small_b + eq_k += nn.functional.pad( + small_k, [(self.kernel_size - self.small_kernel) // 2] * 4 + ) + return eq_k, eq_b + + def reparameterize(self) -> None: + """ + Following works like `RepVGG: Making VGG-style ConvNets Great Again` - + https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched + architecture used at training time to obtain a plain CNN-like structure + for inference. + """ + eq_k, eq_b = self.get_kernel_bias() + self.reparam_conv = create_conv2d( + self.in_chs, + self.out_chs, + kernel_size=self.kernel_size, + stride=self.stride, + groups=self.groups, + bias=True, + ) + + self.reparam_conv.weight.data = eq_k + self.reparam_conv.bias.data = eq_b + self.__delattr__("large_conv") + if hasattr(self, "small_conv"): + self.__delattr__("small_conv") + + @staticmethod + def _fuse_bn( + conv: torch.Tensor, bn: nn.BatchNorm2d + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Method to fuse batchnorm layer with conv layer. + + Args: + conv: Convolutional kernel weights. + bn: Batchnorm 2d layer. + + Returns: + Tuple of (kernel, bias) after fusing batchnorm. + """ + kernel = conv.weight + running_mean = bn.running_mean + running_var = bn.running_var + gamma = bn.weight + beta = bn.bias + eps = bn.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +def convolutional_stem( + in_chs: int, + out_chs: int, + inference_mode: bool = False +) -> nn.Sequential: + """Build convolutional stem with MobileOne blocks. + + Args: + in_chs: Number of input channels. + out_chs: Number of output channels. + inference_mode: Flag to instantiate model in inference mode. Default: ``False`` + + Returns: + nn.Sequential object with stem elements. + """ + return nn.Sequential( + MobileOneBlock( + in_chs=in_chs, + out_chs=out_chs, + kernel_size=3, + stride=2, + inference_mode=inference_mode, + ), + MobileOneBlock( + in_chs=out_chs, + out_chs=out_chs, + kernel_size=3, + stride=2, + group_size=1, + inference_mode=inference_mode, + ), + MobileOneBlock( + in_chs=out_chs, + out_chs=out_chs, + kernel_size=1, + stride=1, + inference_mode=inference_mode, + ), + ) + + +class Attention(nn.Module): + """Multi-headed Self Attention module. + + Source modified from: + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + """ + fused_attn: torch.jit.Final[bool] + + def __init__( + self, + dim: int, + head_dim: int = 32, + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + """Build MHSA module that can handle 3D or 4D input tensors. + + Args: + dim: Number of embedding dimensions. + head_dim: Number of hidden dimensions per head. Default: ``32`` + qkv_bias: Use bias or not. Default: ``False`` + attn_drop: Dropout rate for attention tensor. + proj_drop: Dropout rate for projection tensor. 
+ """ + super().__init__() + assert dim % head_dim == 0, "dim should be divisible by head_dim" + self.head_dim = head_dim + self.num_heads = dim // head_dim + self.scale = head_dim ** -0.5 + self.fused_attn = use_fused_attn() + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, H, W = x.shape + N = H * W + x = x.flatten(2).transpose(-2, -1) # (B, N, C) + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + if self.fused_attn: + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + x = x.transpose(-2, -1).reshape(B, C, H, W) + + return x + + +class PatchEmbed(nn.Module): + """Convolutional patch embedding layer.""" + + def __init__( + self, + patch_size: int, + stride: int, + in_chs: int, + embed_dim: int, + act_layer: nn.Module = nn.GELU, + lkc_use_act: bool = False, + inference_mode: bool = False, + ) -> None: + """Build patch embedding layer. + + Args: + patch_size: Patch size for embedding computation. + stride: Stride for convolutional embedding layer. + in_chs: Number of channels of input tensor. + embed_dim: Number of embedding dimensions. + inference_mode: Flag to instantiate model in inference mode. Default: ``False`` + """ + super().__init__() + self.proj = nn.Sequential( + ReparamLargeKernelConv( + in_chs=in_chs, + out_chs=embed_dim, + kernel_size=patch_size, + stride=stride, + group_size=1, + small_kernel=3, + inference_mode=inference_mode, + act_layer=act_layer if lkc_use_act else None, # NOTE original weights didn't use this act + ), + MobileOneBlock( + in_chs=embed_dim, + out_chs=embed_dim, + kernel_size=1, + stride=1, + act_layer=act_layer, + inference_mode=inference_mode, + ) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x + + +class LayerScale2d(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim, 1, 1)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class RepMixer(nn.Module): + """Reparameterizable token mixer. + + For more details, please refer to our paper: + `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization `_ + """ + + def __init__( + self, + dim, + kernel_size=3, + layer_scale_init_value=1e-5, + inference_mode: bool = False, + ): + """Build RepMixer Module. + + Args: + dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`. + kernel_size: Kernel size for spatial mixing. Default: 3 + layer_scale_init_value: Initial value for layer scale. Default: 1e-5 + inference_mode: If True, instantiates model in inference mode. 
Default: ``False`` + """ + super().__init__() + self.dim = dim + self.kernel_size = kernel_size + self.inference_mode = inference_mode + + if inference_mode: + self.reparam_conv = nn.Conv2d( + self.dim, + self.dim, + kernel_size=self.kernel_size, + stride=1, + padding=self.kernel_size // 2, + groups=self.dim, + bias=True, + ) + else: + self.reparam_conv = None + self.norm = MobileOneBlock( + dim, + dim, + kernel_size, + group_size=1, + use_act=False, + use_scale_branch=False, + num_conv_branches=0, + ) + self.mixer = MobileOneBlock( + dim, + dim, + kernel_size, + group_size=1, + use_act=False, + ) + if layer_scale_init_value is not None: + self.layer_scale = LayerScale2d(dim, layer_scale_init_value) + else: + self.layer_scale = nn.Identity + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.reparam_conv is not None: + x = self.reparam_conv(x) + else: + x = x + self.layer_scale(self.mixer(x) - self.norm(x)) + return x + + def reparameterize(self) -> None: + """Reparameterize mixer and norm into a single + convolutional layer for efficient inference. + """ + if self.inference_mode: + return + + self.mixer.reparameterize() + self.norm.reparameterize() + + if isinstance(self.layer_scale, LayerScale2d): + w = self.mixer.id_tensor + self.layer_scale.gamma.unsqueeze(-1) * ( + self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight + ) + b = torch.squeeze(self.layer_scale.gamma) * ( + self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias + ) + else: + w = ( + self.mixer.id_tensor + + self.mixer.reparam_conv.weight + - self.norm.reparam_conv.weight + ) + b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias + + self.reparam_conv = create_conv2d( + self.dim, + self.dim, + kernel_size=self.kernel_size, + stride=1, + groups=self.dim, + bias=True, + ) + self.reparam_conv.weight.data = w + self.reparam_conv.bias.data = b + + for name, para in self.named_parameters(): + if 'reparam_conv' in name: + continue + para.detach_() + self.__delattr__("mixer") + self.__delattr__("norm") + self.__delattr__("layer_scale") + + +class ConvMlp(nn.Module): + """Convolutional FFN Module.""" + + def __init__( + self, + in_chs: int, + hidden_channels: Optional[int] = None, + out_chs: Optional[int] = None, + act_layer: nn.Module = nn.GELU, + drop: float = 0.0, + ) -> None: + """Build convolutional FFN module. + + Args: + in_chs: Number of input channels. + hidden_channels: Number of channels after expansion. Default: None + out_chs: Number of output channels. Default: None + act_layer: Activation layer. Default: ``GELU`` + drop: Dropout rate. Default: ``0.0``. + """ + super().__init__() + out_chs = out_chs or in_chs + hidden_channels = hidden_channels or in_chs + self.conv = ConvNormAct( + in_chs, + out_chs, + kernel_size=7, + groups=in_chs, + apply_act=False, + ) + self.fc1 = nn.Conv2d(in_chs, hidden_channels, kernel_size=1) + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_channels, out_chs, kernel_size=1) + self.drop = nn.Dropout(drop) + self.apply(self._init_weights) + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class RepConditionalPosEnc(nn.Module): + """Implementation of conditional positional encoding. 
+ + For more details refer to paper: + `Conditional Positional Encodings for Vision Transformers `_ + + In our implementation, we can reparameterize this module to eliminate a skip connection. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + spatial_shape: Union[int, Tuple[int, int]] = (7, 7), + inference_mode=False, + ) -> None: + """Build reparameterizable conditional positional encoding + + Args: + dim: Number of input channels. + dim_out: Number of embedding dimensions. Default: 768 + spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7) + inference_mode: Flag to instantiate block in inference mode. Default: ``False`` + """ + super(RepConditionalPosEnc, self).__init__() + if isinstance(spatial_shape, int): + spatial_shape = tuple([spatial_shape] * 2) + assert isinstance(spatial_shape, Tuple), ( + f'"spatial_shape" must by a sequence or int, ' + f"get {type(spatial_shape)} instead." + ) + assert len(spatial_shape) == 2, ( + f'Length of "spatial_shape" should be 2, ' + f"got {len(spatial_shape)} instead." + ) + + self.spatial_shape = spatial_shape + self.dim = dim + self.dim_out = dim_out or dim + self.groups = dim + + if inference_mode: + self.reparam_conv = nn.Conv2d( + self.dim, + self.dim_out, + kernel_size=self.spatial_shape, + stride=1, + padding=spatial_shape[0] // 2, + groups=self.groups, + bias=True, + ) + else: + self.reparam_conv = None + self.pos_enc = nn.Conv2d( + self.dim, + self.dim_out, + spatial_shape, + 1, + int(spatial_shape[0] // 2), + groups=self.groups, + bias=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.reparam_conv is not None: + x = self.reparam_conv(x) + else: + x = self.pos_enc(x) + x + return x + + def reparameterize(self) -> None: + # Build equivalent Id tensor + input_dim = self.dim // self.groups + kernel_value = torch.zeros( + ( + self.dim, + input_dim, + self.spatial_shape[0], + self.spatial_shape[1], + ), + dtype=self.pos_enc.weight.dtype, + device=self.pos_enc.weight.device, + ) + for i in range(self.dim): + kernel_value[ + i, + i % input_dim, + self.spatial_shape[0] // 2, + self.spatial_shape[1] // 2, + ] = 1 + id_tensor = kernel_value + + # Reparameterize Id tensor and conv + w_final = id_tensor + self.pos_enc.weight + b_final = self.pos_enc.bias + + # Introduce reparam conv + self.reparam_conv = nn.Conv2d( + self.dim, + self.dim_out, + kernel_size=self.spatial_shape, + stride=1, + padding=int(self.spatial_shape[0] // 2), + groups=self.groups, + bias=True, + ) + self.reparam_conv.weight.data = w_final + self.reparam_conv.bias.data = b_final + + for name, para in self.named_parameters(): + if 'reparam_conv' in name: + continue + para.detach_() + self.__delattr__("pos_enc") + + +class RepMixerBlock(nn.Module): + """Implementation of Metaformer block with RepMixer as token mixer. + + For more details on Metaformer structure, please refer to: + `MetaFormer Is Actually What You Need for Vision `_ + """ + + def __init__( + self, + dim: int, + kernel_size: int = 3, + mlp_ratio: float = 4.0, + act_layer: nn.Module = nn.GELU, + proj_drop: float = 0.0, + drop_path: float = 0.0, + layer_scale_init_value: float = 1e-5, + inference_mode: bool = False, + ): + """Build RepMixer Block. + + Args: + dim: Number of embedding dimensions. + kernel_size: Kernel size for repmixer. Default: 3 + mlp_ratio: MLP expansion ratio. Default: 4.0 + act_layer: Activation layer. Default: ``nn.GELU`` + proj_drop: Dropout rate. Default: 0.0 + drop_path: Drop path rate. 
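`RepConditionalPosEnc.reparameterize` removes the skip connection by adding a depthwise identity kernel to the positional conv's weight, turning `pos_enc(x) + x` into a single conv. A small numerical check of that identity-kernel trick with a plain depthwise conv:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, ks = 16, 7
pos_enc = nn.Conv2d(dim, dim, ks, padding=ks // 2, groups=dim, bias=True)

# depthwise identity kernel: a 1 at the spatial centre of each channel's filter
id_kernel = torch.zeros(dim, 1, ks, ks)
id_kernel[:, 0, ks // 2, ks // 2] = 1.0

x = torch.randn(2, dim, 14, 14)
with torch.no_grad():
    y_skip = pos_enc(x) + x
    y_fused = F.conv2d(x, pos_enc.weight + id_kernel, pos_enc.bias,
                       padding=ks // 2, groups=dim)
print(torch.allclose(y_skip, y_fused, atol=1e-5))   # True
```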
Default: 0.0 + layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 + inference_mode: Flag to instantiate block in inference mode. Default: ``False`` + """ + + super().__init__() + + self.token_mixer = RepMixer( + dim, + kernel_size=kernel_size, + layer_scale_init_value=layer_scale_init_value, + inference_mode=inference_mode, + ) + + self.mlp = ConvMlp( + in_chs=dim, + hidden_channels=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + if layer_scale_init_value is not None: + self.layer_scale = LayerScale2d(dim, layer_scale_init_value) + else: + self.layer_scale = nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + x = self.token_mixer(x) + x = x + self.drop_path(self.layer_scale(self.mlp(x))) + return x + + +class AttentionBlock(nn.Module): + """Implementation of metaformer block with MHSA as token mixer. + + For more details on Metaformer structure, please refer to: + `MetaFormer Is Actually What You Need for Vision `_ + """ + + def __init__( + self, + dim: int, + mlp_ratio: float = 4.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.BatchNorm2d, + proj_drop: float = 0.0, + drop_path: float = 0.0, + layer_scale_init_value: float = 1e-5, + ): + """Build Attention Block. + + Args: + dim: Number of embedding dimensions. + mlp_ratio: MLP expansion ratio. Default: 4.0 + act_layer: Activation layer. Default: ``nn.GELU`` + norm_layer: Normalization layer. Default: ``nn.BatchNorm2d`` + proj_drop: Dropout rate. Default: 0.0 + drop_path: Drop path rate. Default: 0.0 + layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 + """ + + super().__init__() + + self.norm = norm_layer(dim) + self.token_mixer = Attention(dim=dim) + if layer_scale_init_value is not None: + self.layer_scale_1 = LayerScale2d(dim, layer_scale_init_value) + else: + self.layer_scale_1 = nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.mlp = ConvMlp( + in_chs=dim, + hidden_channels=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + if layer_scale_init_value is not None: + self.layer_scale_2 = LayerScale2d(dim, layer_scale_init_value) + else: + self.layer_scale_2 = nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + x = x + self.drop_path1(self.layer_scale_1(self.token_mixer(self.norm(x)))) + x = x + self.drop_path2(self.layer_scale_2(self.mlp(x))) + return x + + +class FastVitStage(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + depth: int, + token_mixer_type: str, + downsample: bool = True, + down_patch_size: int = 7, + down_stride: int = 2, + pos_emb_layer: Optional[nn.Module] = None, + kernel_size: int = 3, + mlp_ratio: float = 4.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.BatchNorm2d, + proj_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + layer_scale_init_value: Optional[float] = 1e-5, + lkc_use_act=False, + inference_mode=False, + ): + """FastViT stage. + + Args: + dim: Number of embedding dimensions. + depth: Number of blocks in stage + token_mixer_type: Token mixer type. + kernel_size: Kernel size for repmixer. + mlp_ratio: MLP expansion ratio. + act_layer: Activation layer. + norm_layer: Normalization layer. + proj_drop_rate: Dropout rate. + drop_path_rate: Drop path rate. + layer_scale_init_value: Layer scale value at initialization. + inference_mode: Flag to instantiate block in inference mode. 
+ """ + super().__init__() + self.grad_checkpointing = False + + if downsample: + self.downsample = PatchEmbed( + patch_size=down_patch_size, + stride=down_stride, + in_chs=dim, + embed_dim=dim_out, + act_layer=act_layer, + lkc_use_act=lkc_use_act, + inference_mode=inference_mode, + ) + else: + assert dim == dim_out + self.downsample = nn.Identity() + + if pos_emb_layer is not None: + self.pos_emb = pos_emb_layer(dim_out, inference_mode=inference_mode) + else: + self.pos_emb = nn.Identity() + + blocks = [] + for block_idx in range(depth): + if token_mixer_type == "repmixer": + blocks.append(RepMixerBlock( + dim_out, + kernel_size=kernel_size, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + proj_drop=proj_drop_rate, + drop_path=drop_path_rate[block_idx], + layer_scale_init_value=layer_scale_init_value, + inference_mode=inference_mode, + )) + elif token_mixer_type == "attention": + blocks.append(AttentionBlock( + dim_out, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer, + proj_drop=proj_drop_rate, + drop_path=drop_path_rate[block_idx], + layer_scale_init_value=layer_scale_init_value, + )) + else: + raise ValueError( + "Token mixer type: {} not supported".format(token_mixer_type) + ) + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + x = self.downsample(x) + x = self.pos_emb(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class FastVit(nn.Module): + fork_feat: torch.jit.Final[bool] + + """ + This class implements `FastViT architecture `_ + """ + + def __init__( + self, + in_chans: int = 3, + layers: Tuple[int, ...] = (2, 2, 6, 2), + token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"), + embed_dims: Tuple[int, ...] = (64, 128, 256, 512), + mlp_ratios: Tuple[float, ...] = (4,) * 4, + downsamples: Tuple[bool, ...] = (False, True, True, True), + repmixer_kernel_size: int = 3, + num_classes: int = 1000, + pos_embs: Tuple[Optional[nn.Module], ...] 
= (None,) * 4, + down_patch_size: int = 7, + down_stride: int = 2, + drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + layer_scale_init_value: float = 1e-5, + fork_feat: bool = False, + cls_ratio: float = 2.0, + global_pool: str = 'avg', + norm_layer: nn.Module = nn.BatchNorm2d, + act_layer: nn.Module = nn.GELU, + lkc_use_act: bool = False, + inference_mode: bool = False, + ) -> None: + super().__init__() + self.num_classes = 0 if fork_feat else num_classes + self.fork_feat = fork_feat + self.global_pool = global_pool + self.feature_info = [] + + # Convolutional stem + self.stem = convolutional_stem( + in_chans, + embed_dims[0], + inference_mode, + ) + + # Build the main stages of the network architecture + prev_dim = embed_dims[0] + scale = 1 + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)] + stages = [] + for i in range(len(layers)): + downsample = downsamples[i] or prev_dim != embed_dims[i] + stage = FastVitStage( + dim=prev_dim, + dim_out=embed_dims[i], + depth=layers[i], + downsample=downsample, + down_patch_size=down_patch_size, + down_stride=down_stride, + pos_emb_layer=pos_embs[i], + token_mixer_type=token_mixers[i], + kernel_size=repmixer_kernel_size, + mlp_ratio=mlp_ratios[i], + act_layer=act_layer, + norm_layer=norm_layer, + proj_drop_rate=proj_drop_rate, + drop_path_rate=dpr[i], + layer_scale_init_value=layer_scale_init_value, + lkc_use_act=lkc_use_act, + inference_mode=inference_mode, + ) + stages.append(stage) + prev_dim = embed_dims[i] + if downsample: + scale *= 2 + self.feature_info += [dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}')] + self.stages = nn.Sequential(*stages) + self.num_features = prev_dim + + # For segmentation and detection, extract intermediate output + if self.fork_feat: + # add a norm layer for each output + self.out_indices = [0, 2, 4, 6] + for i_emb, i_layer in enumerate(self.out_indices): + if i_emb == 0 and os.environ.get("FORK_LAST3", None): + """For RetinaNet, `start_level=1`. The first norm layer will not used. + cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...` + """ + layer = nn.Identity() + else: + layer = norm_layer(embed_dims[i_emb]) + layer_name = f"norm{i_layer}" + self.add_module(layer_name, layer) + else: + # Classifier head + self.num_features = final_features = int(embed_dims[-1] * cls_ratio) + self.final_conv = MobileOneBlock( + in_chs=embed_dims[-1], + out_chs=final_features, + kernel_size=3, + stride=1, + group_size=1, + inference_mode=inference_mode, + use_se=True, + num_conv_branches=1, + ) + self.head = ClassifierHead( + final_features, + num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + ) + + self.apply(self._init_weights) + + def _init_weights(self, m: nn.Module) -> None: + """Init. 
for classification""" + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def no_weight_decay(self): + return set() + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', # stem and embed + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+).downsample', (0,)), + (r'^stages\.(\d+).pos_emb', (0,)), + (r'^stages\.(\d+)\.\w+\.(\d+)', None), + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + self.head.reset(num_classes, global_pool) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + # input embedding + x = self.stem(x) + outs = [] + for idx, block in enumerate(self.stages): + x = block(x) + if self.fork_feat: + if idx in self.out_indices: + norm_layer = getattr(self, f"norm{idx}") + x_out = norm_layer(x) + outs.append(x_out) + if self.fork_feat: + # output the features of four stages for dense prediction + return outs + x = self.final_conv(x) + return x + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False): + return self.head(x, pre_logits=True) if pre_logits else self.head(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + if self.fork_feat: + return x + x = self.forward_head(x) + return x + + +def _cfg(url="", **kwargs): + return { + "url": url, + "num_classes": 1000, + "input_size": (3, 256, 256), + "pool_size": (8, 8), + "crop_pct": 0.9, + "interpolation": "bicubic", + "mean": IMAGENET_DEFAULT_MEAN, + "std": IMAGENET_DEFAULT_STD, + 'first_conv': ('stem.0.conv_kxk.0.conv', 'stem.0.conv_scale.conv'), + "classifier": "head.fc", + **kwargs, + } + + +default_cfgs = generate_default_cfgs({ + "fastvit_t8.apple_in1k": _cfg( + hf_hub_id='timm/'), + "fastvit_t12.apple_in1k": _cfg( + hf_hub_id='timm/'), + + "fastvit_s12.apple_in1k": _cfg( + hf_hub_id='timm/'), + "fastvit_sa12.apple_in1k": _cfg( + hf_hub_id='timm/'), + "fastvit_sa24.apple_in1k": _cfg( + hf_hub_id='timm/'), + "fastvit_sa36.apple_in1k": _cfg( + hf_hub_id='timm/'), + + "fastvit_ma36.apple_in1k": _cfg( + hf_hub_id='timm/', + crop_pct=0.95 + ), + + "fastvit_t8.apple_dist_in1k": _cfg( + hf_hub_id='timm/'), + "fastvit_t12.apple_dist_in1k": _cfg( + hf_hub_id='timm/'), + + "fastvit_s12.apple_dist_in1k": _cfg( + hf_hub_id='timm/',), + "fastvit_sa12.apple_dist_in1k": _cfg( + hf_hub_id='timm/',), + "fastvit_sa24.apple_dist_in1k": _cfg( + hf_hub_id='timm/',), + "fastvit_sa36.apple_dist_in1k": _cfg( + hf_hub_id='timm/',), + + "fastvit_ma36.apple_dist_in1k": _cfg( + hf_hub_id='timm/', + crop_pct=0.95 + ), +}) + + +def _create_fastvit(variant, pretrained=False, **kwargs): + out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) + model = build_model_with_cfg( + FastVit, + variant, + pretrained, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs + ) + return model + + +@register_model +def fastvit_t8(pretrained=False, **kwargs): + """Instantiate FastViT-T8 model variant.""" + model_args = dict( + layers=(2, 2, 4, 2), + embed_dims=(48, 96, 192, 384), + mlp_ratios=(3, 3, 3, 3), + token_mixers=("repmixer", "repmixer", "repmixer", "repmixer") + ) + return _create_fastvit('fastvit_t8', pretrained=pretrained, **dict(model_args, 
**kwargs)) + + +@register_model +def fastvit_t12(pretrained=False, **kwargs): + """Instantiate FastViT-T12 model variant.""" + model_args = dict( + layers=(2, 2, 6, 2), + embed_dims=(64, 128, 256, 512), + mlp_ratios=(3, 3, 3, 3), + token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"), + ) + return _create_fastvit('fastvit_t12', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def fastvit_s12(pretrained=False, **kwargs): + """Instantiate FastViT-S12 model variant.""" + model_args = dict( + layers=(2, 2, 6, 2), + embed_dims=(64, 128, 256, 512), + mlp_ratios=(4, 4, 4, 4), + token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"), + ) + return _create_fastvit('fastvit_s12', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def fastvit_sa12(pretrained=False, **kwargs): + """Instantiate FastViT-SA12 model variant.""" + model_args = dict( + layers=(2, 2, 6, 2), + embed_dims=(64, 128, 256, 512), + mlp_ratios=(4, 4, 4, 4), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), + token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + ) + return _create_fastvit('fastvit_sa12', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def fastvit_sa24(pretrained=False, **kwargs): + """Instantiate FastViT-SA24 model variant.""" + model_args = dict( + layers=(4, 4, 12, 4), + embed_dims=(64, 128, 256, 512), + mlp_ratios=(4, 4, 4, 4), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), + token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + ) + return _create_fastvit('fastvit_sa24', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def fastvit_sa36(pretrained=False, **kwargs): + """Instantiate FastViT-SA36 model variant.""" + model_args = dict( + layers=(6, 6, 18, 6), + embed_dims=(64, 128, 256, 512), + mlp_ratios=(4, 4, 4, 4), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), + token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + ) + return _create_fastvit('fastvit_sa36', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def fastvit_ma36(pretrained=False, **kwargs): + """Instantiate FastViT-MA36 model variant.""" + model_args = dict( + layers=(6, 6, 18, 6), + embed_dims=(76, 152, 304, 608), + mlp_ratios=(4, 4, 4, 4), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), + token_mixers=("repmixer", "repmixer", "repmixer", "attention") + ) + return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs)) \ No newline at end of file diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index f5b5123f54..b7c0f5ddcc 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -1,8 +1,11 @@ """ -An implementation of GhostNet Model as defined in: +An implementation of GhostNet & GhostNetV2 Models as defined in: GhostNet: More Features from Cheap Operations. https://arxiv.org/abs/1911.11907 -The train script of the model is similar to that of MobileNetV3 +GhostNetV2: Enhance Cheap Operation with Long-Range Attention. 
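With the FastViT variants registered, model creation goes through the standard factory; a minimal usage sketch with random weights (the `apple_in1k` / `apple_dist_in1k` tags above select pretrained checkpoints, and the default config input size is 256):

```python
import timm
import torch

model = timm.create_model('fastvit_sa12', pretrained=False).eval()
with torch.no_grad():
    print(model(torch.randn(1, 3, 256, 256)).shape)   # torch.Size([1, 1000])

# Backbone mode: four stage outputs at reductions 4 / 8 / 16 / 32 (see feature_info)
backbone = timm.create_model('fastvit_sa12', pretrained=False, features_only=True)
with torch.no_grad():
    print([f.shape for f in backbone(torch.randn(1, 3, 256, 256))])
```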
https://proceedings.neurips.cc/paper_files/paper/2022/file/40b60852a4abdaa696b5a1a78da34635-Paper-Conference.pdf + +The train script & code of models at: Original model: https://github.com/huawei-noah/CV-backbones/tree/master/ghostnet_pytorch +Original model: https://github.com/huawei-noah/Efficient-AI-Backbones/blob/master/ghostnetv2_pytorch/model/ghostnetv2_torch.py """ import math from functools import partial @@ -33,7 +36,8 @@ def __init__( ratio=2, dw_size=3, stride=1, - relu=True, + use_act=True, + act_layer=nn.ReLU, ): super(GhostModule, self).__init__() self.out_chs = out_chs @@ -43,13 +47,13 @@ def __init__( self.primary_conv = nn.Sequential( nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False), nn.BatchNorm2d(init_chs), - nn.ReLU(inplace=True) if relu else nn.Identity(), + act_layer(inplace=True) if use_act else nn.Identity(), ) self.cheap_operation = nn.Sequential( nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size//2, groups=init_chs, bias=False), nn.BatchNorm2d(new_chs), - nn.ReLU(inplace=True) if relu else nn.Identity(), + act_layer(inplace=True) if use_act else nn.Identity(), ) def forward(self, x): @@ -59,6 +63,51 @@ def forward(self, x): return out[:, :self.out_chs, :, :] +class GhostModuleV2(nn.Module): + def __init__( + self, + in_chs, + out_chs, + kernel_size=1, + ratio=2, + dw_size=3, + stride=1, + use_act=True, + act_layer=nn.ReLU, + ): + super().__init__() + self.gate_fn = nn.Sigmoid() + self.out_chs = out_chs + init_chs = math.ceil(out_chs / ratio) + new_chs = init_chs * (ratio - 1) + self.primary_conv = nn.Sequential( + nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False), + nn.BatchNorm2d(init_chs), + act_layer(inplace=True) if use_act else nn.Identity(), + ) + self.cheap_operation = nn.Sequential( + nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size // 2, groups=init_chs, bias=False), + nn.BatchNorm2d(new_chs), + act_layer(inplace=True) if use_act else nn.Identity(), + ) + self.short_conv = nn.Sequential( + nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size // 2, bias=False), + nn.BatchNorm2d(out_chs), + nn.Conv2d(out_chs, out_chs, kernel_size=(1, 5), stride=1, padding=(0, 2), groups=out_chs, bias=False), + nn.BatchNorm2d(out_chs), + nn.Conv2d(out_chs, out_chs, kernel_size=(5, 1), stride=1, padding=(2, 0), groups=out_chs, bias=False), + nn.BatchNorm2d(out_chs), + ) + + def forward(self, x): + res = self.short_conv(F.avg_pool2d(x, kernel_size=2, stride=2)) + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + out = torch.cat([x1, x2], dim=1) + return out[:, :self.out_chs, :, :] * F.interpolate( + self.gate_fn(res), size=(out.shape[-2], out.shape[-1]), mode='nearest') + + class GhostBottleneck(nn.Module): """ Ghost bottleneck w/ optional SE""" @@ -71,13 +120,17 @@ def __init__( stride=1, act_layer=nn.ReLU, se_ratio=0., + mode='original', ): super(GhostBottleneck, self).__init__() has_se = se_ratio is not None and se_ratio > 0. 
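`GhostModuleV2` augments the ghost features with the GhostNetV2 long-range (DFC) attention branch: a gate is computed on a 2x-downsampled copy of the input with a 1x5 and a 5x1 depthwise conv, passed through a sigmoid, upsampled back to the feature size and multiplied in. The gating pattern in isolation, with arbitrary small dims and the BatchNorms dropped for brevity:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

chs = 16
short_conv = nn.Sequential(                    # cheap long-range branch at half resolution
    nn.Conv2d(chs, chs, 1, bias=False),
    nn.Conv2d(chs, chs, (1, 5), padding=(0, 2), groups=chs, bias=False),
    nn.Conv2d(chs, chs, (5, 1), padding=(2, 0), groups=chs, bias=False),
)

x = torch.randn(2, chs, 32, 32)
ghost_features = torch.randn(2, chs, 32, 32)   # stand-in for cat([primary, cheap])

gate = torch.sigmoid(short_conv(F.avg_pool2d(x, kernel_size=2, stride=2)))    # 2, 16, 16, 16
gate = F.interpolate(gate, size=ghost_features.shape[-2:], mode='nearest')    # back to 32x32
print((ghost_features * gate).shape)           # torch.Size([2, 16, 32, 32])
```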
self.stride = stride # Point-wise expansion - self.ghost1 = GhostModule(in_chs, mid_chs, relu=True) + if mode == 'original': + self.ghost1 = GhostModule(in_chs, mid_chs, use_act=True, act_layer=act_layer) + else: + self.ghost1 = GhostModuleV2(in_chs, mid_chs, use_act=True, act_layer=act_layer) # Depth-wise convolution if self.stride > 1: @@ -93,7 +146,7 @@ def __init__( self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if has_se else None # Point-wise linear projection - self.ghost2 = GhostModule(mid_chs, out_chs, relu=False) + self.ghost2 = GhostModule(mid_chs, out_chs, use_act=False) # shortcut if in_chs == out_chs and self.stride == 1: @@ -140,6 +193,7 @@ def __init__( output_stride=32, global_pool='avg', drop_rate=0.2, + version='v1', ): super(GhostNet, self).__init__() # setting of inverted residual blocks @@ -160,8 +214,8 @@ def __init__( # building inverted residual blocks stages = nn.ModuleList([]) - block = GhostBottleneck stage_idx = 0 + layer_idx = 0 net_stride = 2 for cfg in self.cfgs: layers = [] @@ -169,8 +223,12 @@ def __init__( for k, exp_size, c, se_ratio, s in cfg: out_chs = make_divisible(c * width, 4) mid_chs = make_divisible(exp_size * width, 4) - layers.append(block(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio)) + layer_kwargs = {} + if version == 'v2' and layer_idx > 1: + layer_kwargs['mode'] = 'attn' + layers.append(GhostBottleneck(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio, **layer_kwargs)) prev_chs = out_chs + layer_idx += 1 if s > 1: net_stride *= 2 self.feature_info.append(dict( @@ -246,6 +304,15 @@ def forward(self, x): return x +def checkpoint_filter_fn(state_dict, model: nn.Module): + out_dict = {} + for k, v in state_dict.items(): + if 'total' in k: + continue + out_dict[k] = v + return out_dict + + def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs): """ Constructs a GhostNet model @@ -285,6 +352,7 @@ def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs): GhostNet, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True), **model_kwargs, ) @@ -293,7 +361,7 @@ def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs): def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'conv_stem', 'classifier': 'classifier', **kwargs @@ -303,8 +371,22 @@ def _cfg(url='', **kwargs): default_cfgs = generate_default_cfgs({ 'ghostnet_050.untrained': _cfg(), 'ghostnet_100.in1k': _cfg( - url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'), + hf_hub_id='timm/', + # url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth' + ), 'ghostnet_130.untrained': _cfg(), + 'ghostnetv2_100.in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_10.pth.tar' + ), + 'ghostnetv2_130.in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_13.pth.tar' + ), + 'ghostnetv2_160.in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_16.pth.tar' + ), }) @@ -327,3 +409,24 @@ def ghostnet_130(pretrained=False, **kwargs) -> GhostNet: 
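With `version='v2'`, every bottleneck after the first two (`layer_idx > 1`) swaps its expansion `GhostModule` for the attention-augmented `GhostModuleV2`, while the linear projection ghost module stays unchanged. The v2 variants registered at the end of this file are then created through the usual factory; a minimal sketch (random init shown, `pretrained=True` would fetch the `in1k` weights listed above):

```python
import timm
import torch

model = timm.create_model('ghostnetv2_100', pretrained=False).eval()
with torch.no_grad():
    print(model(torch.randn(1, 3, 224, 224)).shape)   # torch.Size([1, 1000])
```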
""" GhostNet-1.3x """ model = _create_ghostnet('ghostnet_130', width=1.3, pretrained=pretrained, **kwargs) return model + + +@register_model +def ghostnetv2_100(pretrained=False, **kwargs) -> GhostNet: + """ GhostNetV2-1.0x """ + model = _create_ghostnet('ghostnetv2_100', width=1.0, pretrained=pretrained, version='v2', **kwargs) + return model + + +@register_model +def ghostnetv2_130(pretrained=False, **kwargs) -> GhostNet: + """ GhostNetV2-1.3x """ + model = _create_ghostnet('ghostnetv2_130', width=1.3, pretrained=pretrained, version='v2', **kwargs) + return model + + +@register_model +def ghostnetv2_160(pretrained=False, **kwargs) -> GhostNet: + """ GhostNetV2-1.6x """ + model = _create_ghostnet('ghostnetv2_160', width=1.6, pretrained=pretrained, version='v2', **kwargs) + return model diff --git a/timm/models/inception_next.py b/timm/models/inception_next.py new file mode 100644 index 0000000000..f5d37db981 --- /dev/null +++ b/timm/models/inception_next.py @@ -0,0 +1,441 @@ +""" +InceptionNeXt paper: https://arxiv.org/abs/2303.16900 +Original implementation & weights from: https://github.com/sail-sg/inceptionnext +""" + +from functools import partial + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import trunc_normal_, DropPath, to_2tuple, get_padding, SelectAdaptivePool2d +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model, generate_default_cfgs + + +class InceptionDWConv2d(nn.Module): + """ Inception depthwise convolution + """ + + def __init__( + self, + in_chs, + square_kernel_size=3, + band_kernel_size=11, + branch_ratio=0.125, + dilation=1, + ): + super().__init__() + + gc = int(in_chs * branch_ratio) # channel numbers of a convolution branch + square_padding = get_padding(square_kernel_size, dilation=dilation) + band_padding = get_padding(band_kernel_size, dilation=dilation) + self.dwconv_hw = nn.Conv2d( + gc, gc, square_kernel_size, + padding=square_padding, dilation=dilation, groups=gc) + self.dwconv_w = nn.Conv2d( + gc, gc, (1, band_kernel_size), + padding=(0, band_padding), dilation=(1, dilation), groups=gc) + self.dwconv_h = nn.Conv2d( + gc, gc, (band_kernel_size, 1), + padding=(band_padding, 0), dilation=(dilation, 1), groups=gc) + self.split_indexes = (in_chs - 3 * gc, gc, gc, gc) + + def forward(self, x): + x_id, x_hw, x_w, x_h = torch.split(x, self.split_indexes, dim=1) + return torch.cat(( + x_id, + self.dwconv_hw(x_hw), + self.dwconv_w(x_w), + self.dwconv_h(x_h) + ), dim=1, + ) + + +class ConvMlp(nn.Module): + """ MLP using 1x1 convs that keeps spatial dims + copied from timm: https://github.com/huggingface/pytorch-image-models/blob/v0.6.11/timm/models/layers/mlp.py + """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.ReLU, + norm_layer=None, + bias=True, + drop=0., + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0]) + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + self.act = act_layer() + self.drop = nn.Dropout(drop) + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.norm(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + return x + + +class MlpClassifierHead(nn.Module): 
+ """ MLP classification head + """ + + def __init__( + self, + dim, + num_classes=1000, + pool_type='avg', + mlp_ratio=3, + act_layer=nn.GELU, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0., + bias=True + ): + super().__init__() + self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True) + in_features = dim * self.global_pool.feat_mult() + hidden_features = int(mlp_ratio * in_features) + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.norm = norm_layer(hidden_features) + self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.global_pool(x) + x = self.fc1(x) + x = self.act(x) + x = self.norm(x) + x = self.drop(x) + x = self.fc2(x) + return x + + +class MetaNeXtBlock(nn.Module): + """ MetaNeXtBlock Block + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + ls_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__( + self, + dim, + dilation=1, + token_mixer=InceptionDWConv2d, + norm_layer=nn.BatchNorm2d, + mlp_layer=ConvMlp, + mlp_ratio=4, + act_layer=nn.GELU, + ls_init_value=1e-6, + drop_path=0., + + ): + super().__init__() + self.token_mixer = token_mixer(dim, dilation=dilation) + self.norm = norm_layer(dim) + self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value else None + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + x = self.token_mixer(x) + x = self.norm(x) + x = self.mlp(x) + if self.gamma is not None: + x = x.mul(self.gamma.reshape(1, -1, 1, 1)) + x = self.drop_path(x) + shortcut + return x + + +class MetaNeXtStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + stride=2, + depth=2, + dilation=(1, 1), + drop_path_rates=None, + ls_init_value=1.0, + token_mixer=InceptionDWConv2d, + act_layer=nn.GELU, + norm_layer=None, + mlp_ratio=4, + ): + super().__init__() + self.grad_checkpointing = False + if stride > 1 or dilation[0] != dilation[1]: + self.downsample = nn.Sequential( + norm_layer(in_chs), + nn.Conv2d( + in_chs, + out_chs, + kernel_size=2, + stride=stride, + dilation=dilation[0], + ), + ) + else: + self.downsample = nn.Identity() + + drop_path_rates = drop_path_rates or [0.] * depth + stage_blocks = [] + for i in range(depth): + stage_blocks.append(MetaNeXtBlock( + dim=out_chs, + dilation=dilation[1], + drop_path=drop_path_rates[i], + ls_init_value=ls_init_value, + token_mixer=token_mixer, + act_layer=act_layer, + norm_layer=norm_layer, + mlp_ratio=mlp_ratio, + )) + self.blocks = nn.Sequential(*stage_blocks) + + def forward(self, x): + x = self.downsample(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class MetaNeXt(nn.Module): + r""" MetaNeXt + A PyTorch impl of : `InceptionNeXt: When Inception Meets ConvNeXt` - https://arxiv.org/abs/2303.16900 + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 9, 3) + dims (tuple(int)): Feature dimension at each stage. Default: (96, 192, 384, 768) + token_mixers: Token mixer function. Default: nn.Identity + norm_layer: Normalization layer. 
Default: nn.BatchNorm2d + act_layer: Activation function for MLP. Default: nn.GELU + mlp_ratios (int or tuple(int)): MLP ratios. Default: (4, 4, 4, 3) + head_fn: classifier head + drop_rate (float): Head dropout rate + drop_path_rate (float): Stochastic depth rate. Default: 0. + ls_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__( + self, + in_chans=3, + num_classes=1000, + global_pool='avg', + output_stride=32, + depths=(3, 3, 9, 3), + dims=(96, 192, 384, 768), + token_mixers=InceptionDWConv2d, + norm_layer=nn.BatchNorm2d, + act_layer=nn.GELU, + mlp_ratios=(4, 4, 4, 3), + head_fn=MlpClassifierHead, + drop_rate=0., + drop_path_rate=0., + ls_init_value=1e-6, + ): + super().__init__() + + num_stage = len(depths) + if not isinstance(token_mixers, (list, tuple)): + token_mixers = [token_mixers] * num_stage + if not isinstance(mlp_ratios, (list, tuple)): + mlp_ratios = [mlp_ratios] * num_stage + self.num_classes = num_classes + self.global_pool = global_pool + self.drop_rate = drop_rate + self.feature_info = [] + + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + norm_layer(dims[0]) + ) + + dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + prev_chs = dims[0] + curr_stride = 4 + dilation = 1 + # feature resolution stages, each consisting of multiple residual blocks + self.stages = nn.Sequential() + for i in range(num_stage): + stride = 2 if curr_stride == 2 or i > 0 else 1 + if curr_stride >= output_stride and stride > 1: + dilation *= stride + stride = 1 + curr_stride *= stride + first_dilation = 1 if dilation in (1, 2) else 2 + out_chs = dims[i] + self.stages.append(MetaNeXtStage( + prev_chs, + out_chs, + stride=stride if i > 0 else 1, + dilation=(first_dilation, dilation), + depth=depths[i], + drop_path_rates=dp_rates[i], + ls_init_value=ls_init_value, + act_layer=act_layer, + token_mixer=token_mixers[i], + norm_layer=norm_layer, + mlp_ratio=mlp_ratios[i], + )) + prev_chs = out_chs + self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] + self.num_features = prev_chs + if self.num_classes > 0: + if issubclass(head_fn, MlpClassifierHead): + assert self.global_pool, 'Cannot disable global pooling with MLP head present.' + self.head = head_fn(self.num_features, num_classes, pool_type=self.global_pool, drop=drop_rate) + else: + if self.global_pool: + self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True) + else: + self.head = nn.Identity() + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+)\.downsample', (0,)), # blocks + (r'^stages\.(\d+)\.blocks\.(\d+)', None), + ] + ) + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc2 + + def reset_classifier(self, num_classes=0, global_pool=None, head_fn=MlpClassifierHead): + if global_pool is not None: + self.global_pool = global_pool + if num_classes > 0: + if issubclass(head_fn, MlpClassifierHead): + assert self.global_pool, 'Cannot disable global pooling with MLP head present.' 
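+ # rebuild the head with the current global pool type and the drop rate captured at construction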
+ self.head = head_fn(self.num_features, num_classes, pool_type=self.global_pool, drop=self.drop_rate) + else: + if self.global_pool: + self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True) + else: + self.head = nn.Identity() + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + return set() + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if pre_logits: + if hasattr(self.head, 'global_pool'): + x = self.head.global_pool(x) + return x + return self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head.fc2', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'inception_next_tiny.sail_in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth', + ), + 'inception_next_small.sail_in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth', + ), + 'inception_next_base.sail_in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth', + crop_pct=0.95, + ), + 'inception_next_base.sail_in1k_384': _cfg( + hf_hub_id='timm/', + # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), +}) + + +def _create_inception_next(variant, pretrained=False, **kwargs): + model = build_model_with_cfg( + MetaNeXt, variant, pretrained, + feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), + **kwargs, + ) + return model + + +@register_model +def inception_next_tiny(pretrained=False, **kwargs): + model_args = dict( + depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), + token_mixers=InceptionDWConv2d, + ) + return _create_inception_next('inception_next_tiny', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def inception_next_small(pretrained=False, **kwargs): + model_args = dict( + depths=(3, 3, 27, 3), dims=(96, 192, 384, 768), + token_mixers=InceptionDWConv2d, + ) + return _create_inception_next('inception_next_small', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def inception_next_base(pretrained=False, **kwargs): + model_args = dict( + depths=(3, 3, 27, 3), dims=(128, 256, 512, 1024), + token_mixers=InceptionDWConv2d, + ) + return _create_inception_next('inception_next_base', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index e3cf7adde4..12709f5818 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -48,7 +48,7 @@ from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert -from timm.layers import RelPosMlp, RelPosBias, 
RelPosBiasTf, use_fused_attn +from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn, resize_rel_pos_bias_table from ._builder import build_model_with_cfg from ._features_fx import register_notrace_function from ._manipulate import named_apply, checkpoint_seq @@ -186,9 +186,9 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): attn_bias = shared_rel_pos x = torch.nn.functional.scaled_dot_product_attention( - q.transpose(-1, -2), - k.transpose(-1, -2), - v.transpose(-1, -2), + q.transpose(-1, -2).contiguous(), + k.transpose(-1, -2).contiguous(), + v.transpose(-1, -2).contiguous(), attn_mask=attn_bias, dropout_p=self.attn_drop.p, ).transpose(-1, -2).reshape(B, -1, H, W) @@ -1790,6 +1790,15 @@ def checkpoint_filter_fn(state_dict, model: nn.Module): model_state_dict = model.state_dict() out_dict = {} for k, v in state_dict.items(): + if k.endswith('relative_position_bias_table'): + m = model.get_submodule(k[:-29]) + if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]: + v = resize_rel_pos_bias_table( + v, + new_window_size=m.window_size, + new_bias_shape=m.relative_position_bias_table.shape, + ) + if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel(): # adapt between conv2d / linear layers assert v.ndim in (2, 4) diff --git a/timm/models/repghost.py b/timm/models/repghost.py new file mode 100644 index 0000000000..da697b705c --- /dev/null +++ b/timm/models/repghost.py @@ -0,0 +1,479 @@ +""" +An implementation of RepGhostNet Model as defined in: +RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization. https://arxiv.org/abs/2211.06088 + +Original implementation: https://github.com/ChengpengChen/RepGhost +""" +import copy +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import SelectAdaptivePool2d, Linear, make_divisible +from ._builder import build_model_with_cfg +from ._efficientnet_blocks import SqueezeExcite, ConvBnAct +from ._manipulate import checkpoint_seq +from ._registry import register_model, generate_default_cfgs + +__all__ = ['RepGhostNet'] + + +_SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4)) + + +class RepGhostModule(nn.Module): + def __init__( + self, + in_chs, + out_chs, + kernel_size=1, + dw_size=3, + stride=1, + relu=True, + reparam=True, + ): + super(RepGhostModule, self).__init__() + self.out_chs = out_chs + init_chs = out_chs + new_chs = out_chs + + self.primary_conv = nn.Sequential( + nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False), + nn.BatchNorm2d(init_chs), + nn.ReLU(inplace=True) if relu else nn.Identity(), + ) + + fusion_conv = [] + fusion_bn = [] + if reparam: + fusion_conv.append(nn.Identity()) + fusion_bn.append(nn.BatchNorm2d(init_chs)) + + self.fusion_conv = nn.Sequential(*fusion_conv) + self.fusion_bn = nn.Sequential(*fusion_bn) + + self.cheap_operation = nn.Sequential( + nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size//2, groups=init_chs, bias=False), + nn.BatchNorm2d(new_chs), + # nn.ReLU(inplace=True) if relu else nn.Identity(), + ) + self.relu = nn.ReLU(inplace=False) if relu else nn.Identity() + + def forward(self, x): + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + for conv, bn in zip(self.fusion_conv, self.fusion_bn): + x2 = x2 + bn(conv(x1)) + return self.relu(x2) + + 
def get_equivalent_kernel_bias(self): + kernel3x3, bias3x3 = self._fuse_bn_tensor(self.cheap_operation[0], self.cheap_operation[1]) + for conv, bn in zip(self.fusion_conv, self.fusion_bn): + kernel, bias = self._fuse_bn_tensor(conv, bn, kernel3x3.shape[0], kernel3x3.device) + kernel3x3 += self._pad_1x1_to_3x3_tensor(kernel) + bias3x3 += bias + return kernel3x3, bias3x3 + + @staticmethod + def _pad_1x1_to_3x3_tensor(kernel1x1): + if kernel1x1 is None: + return 0 + else: + return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1]) + + @staticmethod + def _fuse_bn_tensor(conv, bn, in_channels=None, device=None): + in_channels = in_channels if in_channels else bn.running_mean.shape[0] + device = device if device else bn.weight.device + if isinstance(conv, nn.Conv2d): + kernel = conv.weight + assert conv.bias is None + else: + assert isinstance(conv, nn.Identity) + kernel = torch.ones(in_channels, 1, 1, 1, device=device) + + if isinstance(bn, nn.BatchNorm2d): + running_mean = bn.running_mean + running_var = bn.running_var + gamma = bn.weight + beta = bn.bias + eps = bn.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + assert isinstance(bn, nn.Identity) + return kernel, torch.zeros(in_channels).to(kernel.device) + + def switch_to_deploy(self): + if len(self.fusion_conv) == 0 and len(self.fusion_bn) == 0: + return + kernel, bias = self.get_equivalent_kernel_bias() + self.cheap_operation = nn.Conv2d( + in_channels=self.cheap_operation[0].in_channels, + out_channels=self.cheap_operation[0].out_channels, + kernel_size=self.cheap_operation[0].kernel_size, + padding=self.cheap_operation[0].padding, + dilation=self.cheap_operation[0].dilation, + groups=self.cheap_operation[0].groups, + bias=True) + self.cheap_operation.weight.data = kernel + self.cheap_operation.bias.data = bias + self.__delattr__('fusion_conv') + self.__delattr__('fusion_bn') + self.fusion_conv = [] + self.fusion_bn = [] + + def reparameterize(self): + self.switch_to_deploy() + + +class RepGhostBottleneck(nn.Module): + """ RepGhost bottleneck w/ optional SE""" + + def __init__( + self, + in_chs, + mid_chs, + out_chs, + dw_kernel_size=3, + stride=1, + act_layer=nn.ReLU, + se_ratio=0., + reparam=True, + ): + super(RepGhostBottleneck, self).__init__() + has_se = se_ratio is not None and se_ratio > 0. 
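+ # has_se gates the optional squeeze-excite layer built further down in __init__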
+ self.stride = stride + + # Point-wise expansion + self.ghost1 = RepGhostModule(in_chs, mid_chs, relu=True, reparam=reparam) + + # Depth-wise convolution + if self.stride > 1: + self.conv_dw = nn.Conv2d( + mid_chs, mid_chs, dw_kernel_size, stride=stride, + padding=(dw_kernel_size-1)//2, groups=mid_chs, bias=False) + self.bn_dw = nn.BatchNorm2d(mid_chs) + else: + self.conv_dw = None + self.bn_dw = None + + # Squeeze-and-excitation + self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if has_se else None + + # Point-wise linear projection + self.ghost2 = RepGhostModule(mid_chs, out_chs, relu=False, reparam=reparam) + + # shortcut + if in_chs == out_chs and self.stride == 1: + self.shortcut = nn.Sequential() + else: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_chs, in_chs, dw_kernel_size, stride=stride, + padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False), + nn.BatchNorm2d(in_chs), + nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(out_chs), + ) + + def forward(self, x): + shortcut = x + + # 1st ghost bottleneck + x = self.ghost1(x) + + # Depth-wise convolution + if self.conv_dw is not None: + x = self.conv_dw(x) + x = self.bn_dw(x) + + # Squeeze-and-excitation + if self.se is not None: + x = self.se(x) + + # 2nd ghost bottleneck + x = self.ghost2(x) + + x += self.shortcut(shortcut) + return x + + +class RepGhostNet(nn.Module): + def __init__( + self, + cfgs, + num_classes=1000, + width=1.0, + in_chans=3, + output_stride=32, + global_pool='avg', + drop_rate=0.2, + reparam=True, + ): + super(RepGhostNet, self).__init__() + # setting of inverted residual blocks + assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported' + self.cfgs = cfgs + self.num_classes = num_classes + self.drop_rate = drop_rate + self.grad_checkpointing = False + self.feature_info = [] + + # building first layer + stem_chs = make_divisible(16 * width, 4) + self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False) + self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=f'conv_stem')) + self.bn1 = nn.BatchNorm2d(stem_chs) + self.act1 = nn.ReLU(inplace=True) + prev_chs = stem_chs + + # building inverted residual blocks + stages = nn.ModuleList([]) + block = RepGhostBottleneck + stage_idx = 0 + net_stride = 2 + for cfg in self.cfgs: + layers = [] + s = 1 + for k, exp_size, c, se_ratio, s in cfg: + out_chs = make_divisible(c * width, 4) + mid_chs = make_divisible(exp_size * width, 4) + layers.append(block(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio, reparam=reparam)) + prev_chs = out_chs + if s > 1: + net_stride *= 2 + self.feature_info.append(dict( + num_chs=prev_chs, reduction=net_stride, module=f'blocks.{stage_idx}')) + stages.append(nn.Sequential(*layers)) + stage_idx += 1 + + out_chs = make_divisible(exp_size * width * 2, 4) + stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1))) + self.pool_dim = prev_chs = out_chs + + self.blocks = nn.Sequential(*stages) + + # building last several layers + self.num_features = out_chs = 1280 + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True) + self.act2 = nn.ReLU(inplace=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled + self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity() + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^conv_stem|bn1', + blocks=[ + (r'^blocks\.(\d+)' 
if coarse else r'^blocks\.(\d+)\.(\d+)', None), + (r'conv_head', (99999,)) + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + # cannot meaningfully change pooling of efficient head after creation + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled + self.classifier = Linear(self.pool_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x, flatten=True) + else: + x = self.blocks(x) + return x + + def forward_head(self, x): + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + x = self.flatten(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.classifier(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + def convert_to_deploy(self): + repghost_model_convert(self, do_copy=False) + + +def repghost_model_convert(model: torch.nn.Module, save_path=None, do_copy=True): + """ + taken from from https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py + """ + if do_copy: + model = copy.deepcopy(model) + for module in model.modules(): + if hasattr(module, 'switch_to_deploy'): + module.switch_to_deploy() + if save_path is not None: + torch.save(model.state_dict(), save_path) + return model + + +def _create_repghostnet(variant, width=1.0, pretrained=False, **kwargs): + """ + Constructs a RepGhostNet model + """ + cfgs = [ + # k, t, c, SE, s + # stage1 + [[3, 8, 16, 0, 1]], + # stage2 + [[3, 24, 24, 0, 2]], + [[3, 36, 24, 0, 1]], + # stage3 + [[5, 36, 40, 0.25, 2]], + [[5, 60, 40, 0.25, 1]], + # stage4 + [[3, 120, 80, 0, 2]], + [[3, 100, 80, 0, 1], + [3, 120, 80, 0, 1], + [3, 120, 80, 0, 1], + [3, 240, 112, 0.25, 1], + [3, 336, 112, 0.25, 1] + ], + # stage5 + [[5, 336, 160, 0.25, 2]], + [[5, 480, 160, 0, 1], + [5, 480, 160, 0.25, 1], + [5, 480, 160, 0, 1], + [5, 480, 160, 0.25, 1] + ] + ] + model_kwargs = dict( + cfgs=cfgs, + width=width, + **kwargs, + ) + return build_model_with_cfg( + RepGhostNet, + variant, + pretrained, + feature_cfg=dict(flatten_sequential=True), + **model_kwargs, + ) + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'repghostnet_050.in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_0_5x_43M_66.95.pth.tar' + ), + 'repghostnet_058.in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_0_58x_60M_68.94.pth.tar' + ), + 'repghostnet_080.in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_0_8x_96M_72.24.pth.tar' + ), + 'repghostnet_100.in1k': _cfg( + hf_hub_id='timm/', + # 
url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_0x_142M_74.22.pth.tar' + ), + 'repghostnet_111.in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_11x_170M_75.07.pth.tar' + ), + 'repghostnet_130.in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_3x_231M_76.37.pth.tar' + ), + 'repghostnet_150.in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_5x_301M_77.45.pth.tar' + ), + 'repghostnet_200.in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_2_0x_516M_78.81.pth.tar' + ), +}) + + +@register_model +def repghostnet_050(pretrained=False, **kwargs) -> RepGhostNet: + """ RepGhostNet-0.5x """ + model = _create_repghostnet('repghostnet_050', width=0.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def repghostnet_058(pretrained=False, **kwargs) -> RepGhostNet: + """ RepGhostNet-0.58x """ + model = _create_repghostnet('repghostnet_058', width=0.58, pretrained=pretrained, **kwargs) + return model + + +@register_model +def repghostnet_080(pretrained=False, **kwargs) -> RepGhostNet: + """ RepGhostNet-0.8x """ + model = _create_repghostnet('repghostnet_080', width=0.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def repghostnet_100(pretrained=False, **kwargs) -> RepGhostNet: + """ RepGhostNet-1.0x """ + model = _create_repghostnet('repghostnet_100', width=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def repghostnet_111(pretrained=False, **kwargs) -> RepGhostNet: + """ RepGhostNet-1.11x """ + model = _create_repghostnet('repghostnet_111', width=1.11, pretrained=pretrained, **kwargs) + return model + +@register_model +def repghostnet_130(pretrained=False, **kwargs) -> RepGhostNet: + """ RepGhostNet-1.3x """ + model = _create_repghostnet('repghostnet_130', width=1.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def repghostnet_150(pretrained=False, **kwargs) -> RepGhostNet: + """ RepGhostNet-1.5x """ + model = _create_repghostnet('repghostnet_150', width=1.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def repghostnet_200(pretrained=False, **kwargs) -> RepGhostNet: + """ RepGhostNet-2.0x """ + model = _create_repghostnet('repghostnet_200', width=2.0, pretrained=pretrained, **kwargs) + return model diff --git a/timm/models/repvit.py b/timm/models/repvit.py index b0199b8986..a0def2f41c 100644 --- a/timm/models/repvit.py +++ b/timm/models/repvit.py @@ -15,7 +15,7 @@ Adapted from official impl at https://github.com/jameslahm/RepViT """ -__all__ = ['RepViT'] +__all__ = ['RepVit'] import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -81,7 +81,7 @@ def fuse(self): return m -class RepVGGDW(nn.Module): +class RepVggDw(nn.Module): def __init__(self, ed, kernel_size): super().__init__() self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed) @@ -115,7 +115,7 @@ def fuse(self): return conv -class RepViTMlp(nn.Module): +class RepVitMlp(nn.Module): def __init__(self, in_dim, hidden_dim, act_layer): super().__init__() self.conv1 = ConvNorm(in_dim, hidden_dim, 1, 1, 0) @@ -130,9 +130,9 @@ class RepViTBlock(nn.Module): def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer): 
super(RepViTBlock, self).__init__() - self.token_mixer = RepVGGDW(in_dim, kernel_size) + self.token_mixer = RepVggDw(in_dim, kernel_size) self.se = SqueezeExcite(in_dim, 0.25) if use_se else nn.Identity() - self.channel_mixer = RepViTMlp(in_dim, in_dim * mlp_ratio, act_layer) + self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer) def forward(self, x): x = self.token_mixer(x) @@ -142,7 +142,7 @@ def forward(self, x): return identity + x -class RepViTStem(nn.Module): +class RepVitStem(nn.Module): def __init__(self, in_chs, out_chs, act_layer): super().__init__() self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1) @@ -154,13 +154,13 @@ def forward(self, x): return self.conv2(self.act1(self.conv1(x))) -class RepViTDownsample(nn.Module): +class RepVitDownsample(nn.Module): def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer): super().__init__() self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer) self.spatial_downsample = ConvNorm(in_dim, in_dim, kernel_size, 2, (kernel_size - 1) // 2, groups=in_dim) self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1) - self.ffn = RepViTMlp(out_dim, out_dim * mlp_ratio, act_layer) + self.ffn = RepVitMlp(out_dim, out_dim * mlp_ratio, act_layer) def forward(self, x): x = self.pre_block(x) @@ -171,21 +171,25 @@ def forward(self, x): return x + identity -class RepViTClassifier(nn.Module): - def __init__(self, dim, num_classes, distillation=False): +class RepVitClassifier(nn.Module): + def __init__(self, dim, num_classes, distillation=False, drop=0.): super().__init__() + self.head_drop = nn.Dropout(drop) self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity() self.distillation = distillation + self.distilled_training = False + self.num_classes = num_classes if distillation: self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity() def forward(self, x): + x = self.head_drop(x) if self.distillation: x1, x2 = self.head(x), self.head_dist(x) - if (not self.training) or torch.jit.is_scripting(): - return (x1 + x2) / 2 - else: + if self.training and self.distilled_training and not torch.jit.is_scripting(): return x1, x2 + else: + return (x1 + x2) / 2 else: x = self.head(x) return x @@ -206,11 +210,11 @@ def fuse(self): return head -class RepViTStage(nn.Module): +class RepVitStage(nn.Module): def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True): super().__init__() if downsample: - self.downsample = RepViTDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer) + self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer) else: assert in_dim == out_dim self.downsample = nn.Identity() @@ -229,7 +233,7 @@ def forward(self, x): return x -class RepViT(nn.Module): +class RepVit(nn.Module): def __init__( self, in_chans=3, @@ -242,15 +246,16 @@ def __init__( num_classes=1000, act_layer=nn.GELU, distillation=True, + drop_rate=0., ): - super(RepViT, self).__init__() + super(RepVit, self).__init__() self.grad_checkpointing = False self.global_pool = global_pool self.embed_dim = embed_dim self.num_classes = num_classes in_dim = embed_dim[0] - self.stem = RepViTStem(in_chans, in_dim, act_layer) + self.stem = RepVitStem(in_chans, in_dim, act_layer) stride = self.stem.stride resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))]) @@ -262,7 +267,7 @@ def __init__( for i in range(num_stages): downsample = True if i != 0 else False stages.append( - 
RepViTStage( + RepVitStage( in_dim, embed_dim[i], depth[i], @@ -280,7 +285,8 @@ def __init__( self.stages = nn.Sequential(*stages) self.num_features = embed_dim[-1] - self.head = RepViTClassifier(embed_dim[-1], num_classes, distillation) + self.head_drop = nn.Dropout(drop_rate) + self.head = RepVitClassifier(embed_dim[-1], num_classes, distillation) @torch.jit.ignore def group_matcher(self, coarse=False): @@ -303,9 +309,13 @@ def reset_classifier(self, num_classes, global_pool=None, distillation=False): if global_pool is not None: self.global_pool = global_pool self.head = ( - RepViTClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity() + RepVitClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity() ) + @torch.jit.ignore + def set_distilled_training(self, enable=True): + self.head.distilled_training = enable + def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): @@ -316,8 +326,9 @@ def forward_features(self, x): def forward_head(self, x, pre_logits: bool = False): if self.global_pool == 'avg': - x = nn.functional.adaptive_avg_pool2d(x, 1).flatten(1) - return x if pre_logits else self.head(x) + x = x.mean((2, 3), keepdim=False) + x = self.head_drop(x) + return self.head(x) def forward(self, x): x = self.forward_features(x) @@ -357,13 +368,16 @@ def _cfg(url='', **kwargs): default_cfgs = generate_default_cfgs( { 'repvit_m1.dist_in1k': _cfg( - url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_distill_300_timm.pth' + hf_hub_id='timm/', + # url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_distill_300_timm.pth' ), 'repvit_m2.dist_in1k': _cfg( - url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_distill_300_timm.pth' + hf_hub_id='timm/', + # url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_distill_300_timm.pth' ), 'repvit_m3.dist_in1k': _cfg( - url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m3_distill_300_timm.pth' + hf_hub_id='timm/', + # url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m3_distill_300_timm.pth' ), } ) @@ -372,7 +386,9 @@ def _cfg(url='', **kwargs): def _create_repvit(variant, pretrained=False, **kwargs): out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) model = build_model_with_cfg( - RepViT, variant, pretrained, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs + RepVit, variant, pretrained, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs, ) return model diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index 2cca973673..a96c69548c 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -24,7 +24,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \ - _assert, use_fused_attn + _assert, use_fused_attn, resize_rel_pos_bias_table from ._builder import build_model_with_cfg from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq, named_apply @@ -38,23 +38,28 @@ _int_or_tuple_2_t = Union[int, Tuple[int, int]] -def window_partition(x, window_size: int): +def window_partition( + x: torch.Tensor, + window_size: Tuple[int, int], +) -> torch.Tensor: """ + Partition into non-overlapping windows with padding if needed. 
Args: - x: (B, H, W, C) - window_size (int): window size + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. Returns: - windows: (num_windows*B, window_size, window_size, C) + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition """ B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) return windows @register_notrace_function # reason: int argument is a Proxy -def window_reverse(windows, window_size: int, H: int, W: int): +def window_reverse(windows, window_size: Tuple[int, int], H: int, W: int): """ Args: windows: (num_windows*B, window_size, window_size, C) @@ -66,7 +71,7 @@ def window_reverse(windows, window_size: int, H: int, W: int): x: (B, H, W, C) """ C = windows.shape[-1] - x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C) + x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) return x @@ -124,7 +129,7 @@ def __init__( self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads)) # get pair-wise relative position index for each token inside the window - self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w)) + self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False) self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) @@ -218,14 +223,11 @@ def __init__( super().__init__() self.dim = dim self.input_resolution = input_resolution - self.window_size = window_size - self.shift_size = shift_size + ws, ss = self._calc_window_shift(window_size, shift_size) + self.window_size: Tuple[int, int] = ws + self.shift_size: Tuple[int, int] = ss + self.window_area = self.window_size[0] * self.window_size[1] self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( @@ -237,8 +239,8 @@ def __init__( attn_drop=attn_drop, proj_drop=proj_drop, ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) self.mlp = Mlp( in_features=dim, @@ -246,66 +248,81 @@ def __init__( act_layer=act_layer, drop=proj_drop, ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - if self.shift_size > 0: + if any(self.shift_size): # calculate attention mask for SW-MSA H, W = self.input_resolution + H = math.ceil(H / self.window_size[0]) * self.window_size[0] + W = math.ceil(W / self.window_size[1]) * self.window_size[1] img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 cnt = 0 for h in ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)): + slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)): for w in ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)): + slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)): img_mask[:, h, w, :] = cnt cnt += 1 - mask_windows = window_partition(img_mask, self.window_size) # num_win, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_area) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None - self.register_buffer("attn_mask", attn_mask) - def forward(self, x): - B, H, W, C = x.shape - _assert(H == self.input_resolution[0], "input feature has wrong size") - _assert(W == self.input_resolution[1], "input feature has wrong size") + self.register_buffer("attn_mask", attn_mask, persistent=False) + + def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]: + target_window_size = to_2tuple(target_window_size) + target_shift_size = to_2tuple(target_shift_size) + window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)] + shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)] + return tuple(window_size), tuple(shift_size) - shortcut = x - x = self.norm1(x) + def _attn(self, x): + B, H, W, C = x.shape # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + has_shift = any(self.shift_size) + if has_shift: + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2)) else: shifted_x = x + # pad for resolution not divisible by window size + pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + # partition windows - x_windows = window_partition(shifted_x, self.window_size) # num_win*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # num_win*B, window_size*window_size, C + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_area, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask) # num_win*B, window_size*window_size, C + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C # merge windows - attn_windows = attn_windows.view(-1, self.window_size, 
self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + shifted_x = shifted_x[:, :H, :W, :].contiguous() # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + if has_shift: + x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2)) else: x = shifted_x + return x - # FFN - x = shortcut + self.drop_path(x) - + def forward(self, x): + B, H, W, C = x.shape + x = x + self.drop_path1(self._attn(self.norm1(x))) x = x.reshape(B, -1, C) - x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x + self.drop_path2(self.mlp(self.norm2(x))) x = x.reshape(B, H, W, C) return x @@ -385,6 +402,8 @@ def __init__( self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution self.depth = depth self.grad_checkpointing = False + window_size = to_2tuple(window_size) + shift_size = tuple([w // 2 for w in window_size]) # patch merging layer if downsample: @@ -405,7 +424,7 @@ def __init__( num_heads=num_heads, head_dim=head_dim, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, + shift_size=0 if (i % 2 == 0) else shift_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_drop=proj_drop, @@ -499,7 +518,11 @@ def __init__( # build layers head_dim = to_ntuple(self.num_layers)(head_dim) - window_size = to_ntuple(self.num_layers)(window_size) + if not isinstance(window_size, (list, tuple)): + window_size = to_ntuple(self.num_layers)(window_size) + elif len(window_size) == 2: + window_size = (window_size,) * self.num_layers + assert len(window_size) == self.num_layers mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio) dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] layers = [] @@ -598,15 +621,30 @@ def forward(self, x): def checkpoint_filter_fn(state_dict, model): """ convert patch embedding weight from manual patchify + linear proj to conv""" + old_weights = True if 'head.fc.weight' in state_dict: - return state_dict + old_weights = False import re out_dict = {} state_dict = state_dict.get('model', state_dict) state_dict = state_dict.get('state_dict', state_dict) for k, v in state_dict.items(): - k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k) - k = k.replace('head.', 'head.fc.') + if any([n in k for n in ('relative_position_index', 'attn_mask')]): + continue # skip buffers that should not be persistent + + if k.endswith('relative_position_bias_table'): + m = model.get_submodule(k[:-29]) + if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]: + v = resize_rel_pos_bias_table( + v, + new_window_size=m.window_size, + new_bias_shape=m.relative_position_bias_table.shape, + ) + + if old_weights: + k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k) + k = k.replace('head.', 'head.fc.') + out_dict[k] = v return out_dict diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index dba74a9a38..eca2ae7939 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -398,6 +398,8 @@ def __init__( self.depth = depth self.output_nchw = output_nchw self.grad_checkpointing = False + window_size = to_2tuple(window_size) + shift_size = tuple([w // 2 for 
w in window_size]) # patch merging / downsample layer if downsample: @@ -413,7 +415,7 @@ def __init__( input_resolution=self.output_resolution, num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, + shift_size=0 if (i % 2 == 0) else shift_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_drop=proj_drop, @@ -568,7 +570,7 @@ def _init_weights(self, m): def no_weight_decay(self): nod = set() for n, m in self.named_modules(): - if any([kw in n for kw in ("cpb_mlp", "logit_scale", 'relative_position_bias_table')]): + if any([kw in n for kw in ("cpb_mlp", "logit_scale")]): nod.add(n) return nod diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py new file mode 100644 index 0000000000..4b5836584c --- /dev/null +++ b/timm/models/tiny_vit.py @@ -0,0 +1,715 @@ +""" TinyViT + +Paper: `TinyViT: Fast Pretraining Distillation for Small Vision Transformers` + - https://arxiv.org/abs/2207.10666 + +Adapted from official impl at https://github.com/microsoft/Cream/tree/main/TinyViT +""" + +__all__ = ['TinyVit'] + +import math +import itertools +from functools import partial +from typing import Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\ + trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module +from ._manipulate import checkpoint_seq +from ._registry import register_model, generate_default_cfgs + + +class ConvNorm(torch.nn.Sequential): + def __init__(self, in_chs, out_chs, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): + super().__init__() + self.conv = nn.Conv2d(in_chs, out_chs, ks, stride, pad, dilation, groups, bias=False) + self.bn = nn.BatchNorm2d(out_chs) + torch.nn.init.constant_(self.bn.weight, bn_weight_init) + torch.nn.init.constant_(self.bn.bias, 0) + + @torch.no_grad() + def fuse(self): + c, bn = self.conv, self.bn + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps) ** 0.5 + m = torch.nn.Conv2d( + w.size(1) * self.conv.groups, w.size(0), w.shape[2:], + stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class PatchEmbed(nn.Module): + def __init__(self, in_chs, out_chs, act_layer): + super().__init__() + self.stride = 4 + self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1) + self.act = act_layer() + self.conv2 = ConvNorm(out_chs // 2, out_chs, 3, 2, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.act(x) + x = self.conv2(x) + return x + + +class MBConv(nn.Module): + def __init__(self, in_chs, out_chs, expand_ratio, act_layer, drop_path): + super().__init__() + mid_chs = int(in_chs * expand_ratio) + self.conv1 = ConvNorm(in_chs, mid_chs, ks=1) + self.act1 = act_layer() + self.conv2 = ConvNorm(mid_chs, mid_chs, ks=3, stride=1, pad=1, groups=mid_chs) + self.act2 = act_layer() + self.conv3 = ConvNorm(mid_chs, out_chs, ks=1, bn_weight_init=0.0) + self.act3 = act_layer() + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + shortcut = x + x = self.conv1(x) + x = self.act1(x) + x = self.conv2(x) + x = self.act2(x) + x = self.conv3(x) + x = self.drop_path(x) + x += shortcut + x = self.act3(x) + return x + + +class PatchMerging(nn.Module): + def __init__(self, dim, out_dim, act_layer): + super().__init__() + self.conv1 = ConvNorm(dim, out_dim, 1, 1, 0) + self.act1 = act_layer() + self.conv2 = ConvNorm(out_dim, out_dim, 3, 2, 1, groups=out_dim) + self.act2 = act_layer() + self.conv3 = ConvNorm(out_dim, out_dim, 1, 1, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.act1(x) + x = self.conv2(x) + x = self.act2(x) + x = self.conv3(x) + return x + + +class ConvLayer(nn.Module): + def __init__( + self, + dim, + depth, + act_layer, + drop_path=0., + conv_expand_ratio=4., + ): + super().__init__() + self.dim = dim + self.depth = depth + self.blocks = nn.Sequential(*[ + MBConv( + dim, dim, conv_expand_ratio, act_layer, + drop_path[i] if isinstance(drop_path, list) else drop_path, + ) + for i in range(depth) + ]) + + def forward(self, x): + x = self.blocks(x) + return x + + +class NormMlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + drop=0., + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.norm = norm_layer(in_features) + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.drop1 = nn.Dropout(drop) + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop2 = nn.Dropout(drop) + + def forward(self, x): + x = self.norm(x) + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class Attention(torch.nn.Module): + fused_attn: torch.jit.Final[bool] + attention_bias_cache: Dict[str, torch.Tensor] + + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=(14, 14), + ): + super().__init__() + assert isinstance(resolution, tuple) and len(resolution) == 2 + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.val_dim = int(attn_ratio * key_dim) + self.out_dim = self.val_dim * num_heads + self.attn_ratio = attn_ratio + self.resolution = resolution + self.fused_attn = use_fused_attn() + + self.norm = nn.LayerNorm(dim) + self.qkv = nn.Linear(dim, num_heads * (self.val_dim + 2 * key_dim)) + self.proj = nn.Linear(self.out_dim, dim) + + points = list(itertools.product(range(resolution[0]), range(resolution[1]))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False) + self.attention_bias_cache = {} + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.attention_bias_cache: + self.attention_bias_cache = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if torch.jit.is_tracing() or self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.attention_bias_cache: + 
self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.attention_bias_cache[device_key] + + def forward(self, x): + attn_bias = self.get_attention_biases(x.device) + B, N, _ = x.shape + # Normalization + x = self.norm(x) + qkv = self.qkv(x) + # (B, N, num_heads, d) + q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3) + # (B, num_heads, N, d) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + if self.fused_attn: + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn + attn_bias + attn = attn.softmax(dim=-1) + x = attn @ v + x = x.transpose(1, 2).reshape(B, N, self.out_dim) + x = self.proj(x) + return x + + +class TinyVitBlock(nn.Module): + """ TinyViT Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + local_conv_size (int): the kernel size of the convolution between + Attention and MLP. Default: 3 + act_layer: the activation function. Default: nn.GELU + """ + + def __init__( + self, + dim, + num_heads, + window_size=7, + mlp_ratio=4., + drop=0., + drop_path=0., + local_conv_size=3, + act_layer=nn.GELU + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + assert window_size > 0, 'window_size must be greater than 0' + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + assert dim % num_heads == 0, 'dim must be divisible by num_heads' + head_dim = dim // num_heads + + window_resolution = (window_size, window_size) + self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + + self.mlp = NormMlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + pad = local_conv_size // 2 + self.local_conv = ConvNorm(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) + + def forward(self, x): + B, H, W, C = x.shape + L = H * W + + shortcut = x + if H == self.window_size and W == self.window_size: + x = x.reshape(B, L, C) + x = self.attn(x) + x = x.view(B, H, W, C) + else: + pad_b = (self.window_size - H % self.window_size) % self.window_size + pad_r = (self.window_size - W % self.window_size) % self.window_size + padding = pad_b > 0 or pad_r > 0 + if padding: + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + # window partition + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_size + nW = pW // self.window_size + x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape( + B * nH * nW, self.window_size * self.window_size, C + ) + + x = self.attn(x) + + # window reverse + x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C) + + if padding: + x = x[:, :H, :W].contiguous() + x = shortcut + self.drop_path1(x) + + x = x.permute(0, 3, 1, 2) + x = self.local_conv(x) + x = x.reshape(B, C, L).transpose(1, 2) + + x = x + self.drop_path2(self.mlp(x)) + return x.view(B, H, W, C) + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" + + +register_notrace_module(TinyVitBlock) + + +class TinyVitStage(nn.Module): + """ A basic TinyViT layer for one stage. + + Args: + dim (int): Number of input channels. + out_dim: the output dimension of the layer + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3 + act_layer: the activation function. 
Default: nn.GELU + """ + + def __init__( + self, + dim, + out_dim, + depth, + num_heads, + window_size, + mlp_ratio=4., + drop=0., + drop_path=0., + downsample=None, + local_conv_size=3, + act_layer=nn.GELU, + ): + + super().__init__() + self.depth = depth + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + dim=dim, + out_dim=out_dim, + act_layer=act_layer, + ) + else: + self.downsample = nn.Identity() + assert dim == out_dim + + # build blocks + self.blocks = nn.Sequential(*[ + TinyVitBlock( + dim=out_dim, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + local_conv_size=local_conv_size, + act_layer=act_layer, + ) + for i in range(depth)]) + + def forward(self, x): + x = self.downsample(x) + x = x.permute(0, 2, 3, 1) # BCHW -> BHWC + x = self.blocks(x) + x = x.permute(0, 3, 1, 2) # BHWC -> BCHW + return x + + def extra_repr(self) -> str: + return f"dim={self.out_dim}, depth={self.depth}" + + +class TinyVit(nn.Module): + def __init__( + self, + in_chans=3, + num_classes=1000, + global_pool='avg', + embed_dims=(96, 192, 384, 768), + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + window_sizes=(7, 7, 14, 7), + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.1, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + act_layer=nn.GELU, + ): + super().__init__() + + self.num_classes = num_classes + self.depths = depths + self.num_stages = len(depths) + self.mlp_ratio = mlp_ratio + self.grad_checkpointing = use_checkpoint + + self.patch_embed = PatchEmbed( + in_chs=in_chans, + out_chs=embed_dims[0], + act_layer=act_layer, + ) + + # stochastic depth rate rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + # build stages + self.stages = nn.Sequential() + stride = self.patch_embed.stride + prev_dim = embed_dims[0] + self.feature_info = [] + for stage_idx in range(self.num_stages): + if stage_idx == 0: + stage = ConvLayer( + dim=prev_dim, + depth=depths[stage_idx], + act_layer=act_layer, + drop_path=dpr[:depths[stage_idx]], + conv_expand_ratio=mbconv_expand_ratio, + ) + else: + out_dim = embed_dims[stage_idx] + drop_path_rate = dpr[sum(depths[:stage_idx]):sum(depths[:stage_idx + 1])] + stage = TinyVitStage( + dim=embed_dims[stage_idx - 1], + out_dim=out_dim, + depth=depths[stage_idx], + num_heads=num_heads[stage_idx], + window_size=window_sizes[stage_idx], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + local_conv_size=local_conv_size, + drop_path=drop_path_rate, + downsample=PatchMerging, + act_layer=act_layer, + ) + prev_dim = out_dim + stride *= 2 + self.stages.append(stage) + self.feature_info += [dict(num_chs=prev_dim, reduction=stride, module=f'stages.{stage_idx}')] + + # Classifier head + self.num_features = embed_dims[-1] + + norm_layer_cf = partial(LayerNorm2d, eps=1e-5) + self.head = NormMlpClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + norm_layer=norm_layer_cf, + ) + + # init weights + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'attention_biases'} + + @torch.jit.ignore + def no_weight_decay(self): + return {x for x in self.state_dict().keys() if 'attention_biases' in x} + + @torch.jit.ignore + def group_matcher(self, 
coarse=False): + matcher = dict( + stem=r'^patch_embed', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+).downsample', (0,)), + (r'^stages\.(\d+)\.\w+\.(\d+)', None), + ] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + self.head.reset(num_classes, global_pool=global_pool) + + def forward_features(self, x): + x = self.patch_embed(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) + return x + + def forward_head(self, x): + x = self.head(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + if 'model' in state_dict.keys(): + state_dict = state_dict['model'] + target_sd = model.state_dict() + out_dict = {} + for k, v in state_dict.items(): + if k.endswith('attention_bias_idxs'): + continue + if 'attention_biases' in k: + # TODO: whether move this func into model for dynamic input resolution? (high risk) + v = resize_rel_pos_bias_table_levit(v.T, target_sd[k].shape[::-1]).T + out_dict[k] = v + return out_dict + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, + 'mean': IMAGENET_DEFAULT_MEAN, + 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.conv1.conv', + 'classifier': 'head.fc', + 'pool_size': (7, 7), + 'input_size': (3, 224, 224), + 'crop_pct': 0.95, + **kwargs, + } + + +default_cfgs = generate_default_cfgs({ + 'tiny_vit_5m_224.dist_in22k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22k_distill.pth', + num_classes=21841 + ), + 'tiny_vit_5m_224.dist_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22kto1k_distill.pth' + ), + 'tiny_vit_5m_224.in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_1k.pth' + ), + 'tiny_vit_11m_224.dist_in22k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22k_distill.pth', + num_classes=21841 + ), + 'tiny_vit_11m_224.dist_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22kto1k_distill.pth' + ), + 'tiny_vit_11m_224.in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_1k.pth' + ), + 'tiny_vit_21m_224.dist_in22k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22k_distill.pth', + num_classes=21841 + ), + 'tiny_vit_21m_224.dist_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_distill.pth' + ), + 'tiny_vit_21m_224.in1k': _cfg( + hf_hub_id='timm/', + #url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_1k.pth' + ), + 'tiny_vit_21m_384.dist_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_384_distill.pth', + input_size=(3, 
384, 384), pool_size=(12, 12), crop_pct=1.0, + ), + 'tiny_vit_21m_512.dist_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_512_distill.pth', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash', + ), +}) + + +def _create_tiny_vit(variant, pretrained=False, **kwargs): + out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) + model = build_model_with_cfg( + TinyVit, + variant, + pretrained, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs + ) + return model + + +@register_model +def tiny_vit_5m_224(pretrained=False, **kwargs): + model_kwargs = dict( + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + drop_path_rate=0.0, + ) + model_kwargs.update(kwargs) + return _create_tiny_vit('tiny_vit_5m_224', pretrained, **model_kwargs) + + +@register_model +def tiny_vit_11m_224(pretrained=False, **kwargs): + model_kwargs = dict( + embed_dims=[64, 128, 256, 448], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 8, 14], + window_sizes=[7, 7, 14, 7], + drop_path_rate=0.1, + ) + model_kwargs.update(kwargs) + return _create_tiny_vit('tiny_vit_11m_224', pretrained, **model_kwargs) + + +@register_model +def tiny_vit_21m_224(pretrained=False, **kwargs): + model_kwargs = dict( + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[7, 7, 14, 7], + drop_path_rate=0.2, + ) + model_kwargs.update(kwargs) + return _create_tiny_vit('tiny_vit_21m_224', pretrained, **model_kwargs) + + +@register_model +def tiny_vit_21m_384(pretrained=False, **kwargs): + model_kwargs = dict( + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[12, 12, 24, 12], + drop_path_rate=0.1, + ) + model_kwargs.update(kwargs) + return _create_tiny_vit('tiny_vit_21m_384', pretrained, **model_kwargs) + + +@register_model +def tiny_vit_21m_512(pretrained=False, **kwargs): + model_kwargs = dict( + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[16, 16, 32, 16], + drop_path_rate=0.1, + ) + model_kwargs.update(kwargs) + return _create_tiny_vit('tiny_vit_21m_512', pretrained, **model_kwargs) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 025a01a84a..10b9296b49 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -383,6 +383,7 @@ class VisionTransformer(nn.Module): A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - https://arxiv.org/abs/2010.11929 """ + dynamic_img_size: Final[bool] def __init__( self, @@ -402,6 +403,8 @@ def __init__( no_embed_class: bool = False, pre_norm: bool = False, fc_norm: Optional[bool] = None, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, drop_rate: float = 0., pos_drop_rate: float = 0., patch_drop_rate: float = 0., @@ -452,14 +455,21 @@ def __init__( self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_prefix_tokens = 1 if class_token else 0 self.no_embed_class = no_embed_class + self.dynamic_img_size = dynamic_img_size self.grad_checkpointing = False + embed_args = {} + if dynamic_img_size: + # flatten deferred until after pos embed + embed_args.update(dict(strict_img_size=False, output_fmt='NHWC')) self.patch_embed = embed_layer( 
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) + dynamic_img_pad=dynamic_img_pad, + **embed_args, ) num_patches = self.patch_embed.num_patches @@ -546,10 +556,20 @@ def reset_classifier(self, num_classes: int, global_pool=None): self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def _pos_embed(self, x): + if self.dynamic_img_size: + B, H, W, C = x.shape + pos_embed = resample_abs_pos_embed( + self.pos_embed, + (H, W), + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, + ) + x = x.view(B, -1, C) + else: + pos_embed = self.pos_embed if self.no_embed_class: # deit-3, updated JAX (big vision) # position embedding does not overlap with class token, add then concat - x = x + self.pos_embed + x = x + pos_embed if self.cls_token is not None: x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) else: @@ -557,7 +577,7 @@ def _pos_embed(self, x): # pos_embed has entry for class token, concat then add if self.cls_token is not None: x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - x = x + self.pos_embed + x = x + pos_embed return self.pos_drop(x) def _intermediate_layers( diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 8cf7bec1e6..e29bf73feb 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -14,13 +14,14 @@ Hacked together by / Copyright 2020, Ross Wightman """ from functools import partial -from typing import List, Tuple +from typing import List, Optional, Tuple import torch import torch.nn as nn +import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import StdConv2dSame, StdConv2d, to_2tuple +from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to from ._registry import generate_default_cfgs, register_model, register_model_deprecations from .resnet import resnet26d, resnet50d from .resnetv2 import ResNetV2, create_resnetv2_stem @@ -31,6 +32,9 @@ class HybridEmbed(nn.Module): """ CNN Feature Map Embedding Extract feature map from CNN, flatten, project to embedding dim. 
""" + output_fmt: Format + dynamic_img_pad: torch.jit.Final[bool] + def __init__( self, backbone, @@ -40,6 +44,10 @@ def __init__( in_chans=3, embed_dim=768, bias=True, + flatten: bool = True, + output_fmt: Optional[str] = None, + strict_img_size: bool = True, + dynamic_img_pad: bool = False, ): super().__init__() assert isinstance(backbone, nn.Module) @@ -66,17 +74,36 @@ def __init__( feature_dim = self.backbone.feature_info.channels()[-1] else: feature_dim = self.backbone.num_features - assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 + if not dynamic_img_pad: + assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] + if output_fmt is not None: + self.flatten = False + self.output_fmt = Format(output_fmt) + else: + # flatten spatial dim and transpose to channels last, kept for bwd compat + self.flatten = flatten + self.output_fmt = Format.NCHW + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad + self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) def forward(self, x): x = self.backbone(x) if isinstance(x, (list, tuple)): x = x[-1] # last feature if backbone outputs list/tuple of features + _, _, H, W = x.shape + if self.dynamic_img_pad: + pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] + pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] + x = F.pad(x, (0, pad_w, 0, pad_h)) x = self.proj(x) - x = x.flatten(2).transpose(1, 2) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # NCHW -> NLC + elif self.output_fmt != Format.NCHW: + x = nchw_to(x, self.output_fmt) return x diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py index c561ea1b22..53c49b071e 100644 --- a/timm/models/vision_transformer_sam.py +++ b/timm/models/vision_transformer_sam.py @@ -17,13 +17,15 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint +from torch.jit import Final from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead,\ - Format, resample_abs_pos_embed_nhwc + Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model +from ._features_fx import register_notrace_function # model_registry will add each entrypoint fn to this __all__ = ['VisionTransformerSAM'] @@ -32,7 +34,77 @@ _logger = logging.getLogger(__name__) +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. 
+ rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + +register_notrace_function(get_rel_pos) + + +def get_decomposed_rel_pos_bias( + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + Args: + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + bias (Tensor): attention bias to add to attention map + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn_bias = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + return attn_bias.reshape(-1, q_h * q_w, k_h * k_w) + + class Attention(nn.Module): + fused_attn: Final[bool] def __init__( self, @@ -44,14 +116,15 @@ def __init__( proj_drop=0., norm_layer=nn.LayerNorm, use_rel_pos: bool = False, - rel_pos_zero_init: bool = True, input_size: Optional[Tuple[int, int]] = None, + rope: Optional[nn.Module] = None, ): super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 + self.fused_attn = use_fused_attn() self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() @@ -61,6 +134,7 @@ def __init__( self.proj_drop = nn.Dropout(proj_drop) self.use_rel_pos = use_rel_pos if self.use_rel_pos: + assert rope is None assert ( input_size is not None ), "Input size must be provided if using relative positional encoding." 
@@ -69,26 +143,45 @@ def __init__( 2 * input_size[0] - 1, self.head_dim)) self.rel_pos_w = nn.Parameter(torch.zeros( 2 * input_size[1] - 1, self.head_dim)) + self.rope = rope def forward(self, x): B, H, W, _ = x.shape - qkv = self.qkv(x).reshape( - B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + N = H * W + x = x.reshape(B, N, -1) + qkv = self.qkv(x).view(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # qkv with shape (3, B, nHead, H * W, C) - q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + q, k, v = qkv.reshape(3, B * self.num_heads, N, -1).unbind(0) # q, k, v with shape (B * nHead, H * W, C) q, k = self.q_norm(q), self.k_norm(k) - q = q * self.scale - attn = q @ k.transpose(-2, -1) if self.use_rel_pos: - attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) - - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + attn_bias = get_decomposed_rel_pos_bias(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + else: + attn_bias = None + if self.rope is not None: + rope = self.rope.get_embed() + q = apply_rot_embed_cat(q, rope).type_as(v) + k = apply_rot_embed_cat(k, rope).type_as(v) + + if self.fused_attn: + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_bias, + dropout_p=self.attn_drop.p, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + if attn_bias is not None: + attn = attn + attn_bias + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.view(B, self.num_heads, N, -1).transpose(1, 2).reshape(B, N, -1) x = self.proj(x) - + x = x.view(B, H, W, -1) return x @@ -121,6 +214,7 @@ def __init__( use_rel_pos=False, window_size=0, input_size=None, + rope=None, ): super().__init__() self.window_size = window_size @@ -135,6 +229,7 @@ def __init__( norm_layer=norm_layer, use_rel_pos=use_rel_pos, input_size=input_size if window_size == 0 else (window_size, window_size), + rope=rope, ) self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -150,20 +245,26 @@ def __init__( self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def forward(self, x): + B, H, W, _ = x.shape + shortcut = x x = self.norm1(x) # Window partition + pad_hw: Optional[Tuple[int, int]] = None if self.window_size > 0: - H, W = x.shape[1], x.shape[2] x, pad_hw = window_partition(x, self.window_size) x = self.drop_path1(self.ls1(self.attn(x))) + # Reverse window partition if self.window_size > 0: - x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + x = window_unpartition(x, self.window_size, (H, W), pad_hw) x = shortcut + x + + x = x.reshape(B, H * W, -1) # MLP is faster for N, L, C tensor x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + x = x.reshape(B, H, W, -1) return x @@ -183,8 +284,7 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T pad_h = (window_size - H % window_size) % window_size pad_w = (window_size - W % window_size) % window_size - if pad_h > 0 or pad_w > 0: - x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) Hp, Wp = H + pad_h, W + pad_w x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) @@ -193,7 +293,7 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T def window_unpartition( - windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] + windows: torch.Tensor, window_size: int, hw: Tuple[int, int], pad_hw: Optional[Tuple[int, int]] = None, ) -> torch.Tensor: """ Window unpartition into original sequences and removing padding. @@ -206,90 +306,15 @@ def window_unpartition( Returns: x: unpartitioned sequences with [B, H, W, C]. """ - Hp, Wp = pad_hw + Hp, Wp = pad_hw if pad_hw is not None else hw H, W = hw B = windows.shape[0] // (Hp * Wp // window_size // window_size) x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) - - if Hp > H or Wp > W: - x = x[:, :H, :W, :].contiguous() + x = x[:, :H, :W, :].contiguous() return x -def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: - """ - Get relative positional embeddings according to the relative positions of - query and key sizes. - Args: - q_size (int): size of query q. - k_size (int): size of key k. - rel_pos (Tensor): relative position embeddings (L, C). - - Returns: - Extracted positional embeddings according to relative positions. - """ - max_rel_dist = int(2 * max(q_size, k_size) - 1) - # Interpolate rel pos if needed. - if rel_pos.shape[0] != max_rel_dist: - # Interpolate rel pos. - rel_pos_resized = F.interpolate( - rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), - size=max_rel_dist, - mode="linear", - ) - rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) - else: - rel_pos_resized = rel_pos - - # Scale the coords with short length if shapes for q and k are different. - q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) - k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) - relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) - - return rel_pos_resized[relative_coords.long()] - - -def add_decomposed_rel_pos( - attn: torch.Tensor, - q: torch.Tensor, - rel_pos_h: torch.Tensor, - rel_pos_w: torch.Tensor, - q_size: Tuple[int, int], - k_size: Tuple[int, int], -) -> torch.Tensor: - """ - Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 
- https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py - Args: - attn (Tensor): attention map. - q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). - rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. - rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. - q_size (Tuple): spatial sequence size of query q with (q_h, q_w). - k_size (Tuple): spatial sequence size of key k with (k_h, k_w). - - Returns: - attn (Tensor): attention map with added relative positional embeddings. - """ - q_h, q_w = q_size - k_h, k_w = k_size - Rh = get_rel_pos(q_h, k_h, rel_pos_h) - Rw = get_rel_pos(q_w, k_w, rel_pos_w) - - B, _, dim = q.shape - r_q = q.reshape(B, q_h, q_w, dim) - rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) - rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) - - attn = ( - attn.view(B, q_h, q_w, k_h, k_w) + - rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] - ).view(B, q_h * q_w, k_h * k_w) - - return attn - - class VisionTransformerSAM(nn.Module): """ Vision Transformer for Segment-Anything Model(SAM) @@ -326,11 +351,13 @@ def __init__( mlp_layer: Callable = Mlp, use_abs_pos: bool = True, use_rel_pos: bool = False, + use_rope: bool = False, window_size: int = 14, global_attn_indexes: Tuple[int, ...] = (), neck_chans: int = 256, global_pool: str = 'avg', - head_hidden_size: Optional[int] = None + head_hidden_size: Optional[int] = None, + ref_feat_shape: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None ): """ Args: @@ -356,10 +383,12 @@ def __init__( block_fn: Transformer block layer. use_abs_pos: If True, use absolute positional embeddings. use_rel_pos: If True, add relative positional embeddings to the attention map. + use_rope: If True, add rotary position embeddings to q/k in attention block. window_size: Window size for window attention blocks. If 0, not use window attention. global_attn_indexes: Indexes for blocks using global attention. Used when window_size > 0. global_pool: Global pooling type. 
head_hidden_size: If set, use NormMlpHead + ref_feat_shape: Tuple of reference feature shapes for ROPE, (global, local) """ super().__init__() norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) @@ -394,6 +423,30 @@ def __init__( self.patch_drop = nn.Identity() self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + if use_rope: + assert not use_rel_pos, "ROPE and relative pos embeddings should not be enabled at same time" + if ref_feat_shape is not None: + assert len(ref_feat_shape) == 2 + ref_feat_shape_global = to_2tuple(ref_feat_shape[0]) + ref_feat_shape_window = to_2tuple(ref_feat_shape[1]) + else: + ref_feat_shape_global = ref_feat_shape_window = None + self.rope_global = RotaryEmbeddingCat( + embed_dim // num_heads, + in_pixels=False, + feat_shape=grid_size, + ref_feat_shape=ref_feat_shape_global, + ) + self.rope_window = RotaryEmbeddingCat( + embed_dim // num_heads, + in_pixels=False, + feat_shape=to_2tuple(window_size), + ref_feat_shape=ref_feat_shape_window, + ) + else: + self.rope_global = None + self.rope_window = None + # stochastic depth decay rule dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] self.blocks = nn.Sequential(*[ @@ -413,6 +466,7 @@ def __init__( use_rel_pos=use_rel_pos, window_size=window_size if i not in global_attn_indexes else 0, input_size=grid_size, + rope=self.rope_window if i not in global_attn_indexes else self.rope_global, ) for i in range(depth)]) @@ -434,8 +488,13 @@ def __init__( ), LayerNorm2d(neck_chans), ) + self.num_features = neck_chans else: - self.neck = nn.Identity() + if head_hidden_size: + self.neck = nn.Identity() + else: + # should have a final norm with standard ClassifierHead + self.neck = LayerNorm2d(embed_dim) neck_chans = embed_dim # Classifier Head @@ -525,7 +584,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 1024, 1024), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', + 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc', **kwargs } @@ -551,6 +610,10 @@ def _cfg(url='', **kwargs): license='apache-2.0', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 1024, 1024), crop_pct=1.0), + + 'samvit_base_patch16_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=1000, + input_size=(3, 224, 224), crop_pct=0.9), }) @@ -605,3 +668,17 @@ def samvit_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM: model = _create_vision_transformer( 'samvit_huge_patch16', pretrained=pretrained, **dict(model_args, **kwargs)) return model + + +@register_model +def samvit_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformerSAM: + """ ViT-B/16 based on samvit arch + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, global_attn_indexes=[2, 5, 8, 11], + window_size=14, use_rel_pos=True, use_abs_pos=False, img_size=224, neck_chans=None, + ) + model = _create_vision_transformer( + 'samvit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index 7727adff40..63fcf4c5b4 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -9,7 +9,7 @@ from .log import setup_default_logging, FormatterNoInfo from .metrics import AverageMeter, accuracy from .misc import natural_key, add_bool_arg, ParseKwargs -from .model import 
unwrap_model, get_state_dict, freeze, unfreeze
+from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model
 from .model_ema import ModelEma, ModelEmaV2
 from .random import random_seed
 from .summary import update_summary, get_outdir
diff --git a/timm/utils/model.py b/timm/utils/model.py
index d74ee5b76b..894453a856 100644
--- a/timm/utils/model.py
+++ b/timm/utils/model.py
@@ -3,6 +3,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import fnmatch
+from copy import deepcopy
 
 import torch
 from torchvision.ops.misc import FrozenBatchNorm2d
@@ -219,3 +220,21 @@ def unfreeze(root_module, submodules=[], include_bn_running_stats=True):
     See example in docstring for `freeze`.
     """
     _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze")
+
+
+def reparameterize_model(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
+    if not inplace:
+        model = deepcopy(model)
+
+    def _fuse(m):
+        for child_name, child in m.named_children():
+            if hasattr(child, 'fuse'):
+                setattr(m, child_name, child.fuse())
+            elif hasattr(child, "reparameterize"):
+                child.reparameterize()
+            elif hasattr(child, "switch_to_deploy"):
+                child.switch_to_deploy()
+            _fuse(child)
+
+    _fuse(model)
+    return model
diff --git a/timm/version.py b/timm/version.py
index 9272695b35..d0564e59af 100644
--- a/timm/version.py
+++ b/timm/version.py
@@ -1 +1 @@
-__version__ = '0.9.5'
+__version__ = '0.9.7'
diff --git a/validate.py b/validate.py
index 794d1ae8f9..8798f80e89 100755
--- a/validate.py
+++ b/validate.py
@@ -26,7 +26,7 @@
 from timm.layers import apply_test_time_pool, set_fast_norm
 from timm.models import create_model, load_checkpoint, is_model, list_models
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \
-    decay_batch_step, check_batch_size_retry, ParseKwargs
+    decay_batch_step, check_batch_size_retry, ParseKwargs, reparameterize_model
 
 try:
     from apex import amp
@@ -125,6 +125,8 @@
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--fast-norm', default=False, action='store_true',
                     help='enable experimental fast-norm')
+parser.add_argument('--reparam', default=False, action='store_true',
+                    help='Reparameterize model')
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
@@ -207,6 +209,9 @@ def validate(args):
     if args.checkpoint:
         load_checkpoint(model, args.checkpoint, args.use_ema)
 
+    if args.reparam:
+        model = reparameterize_model(model)
+
     param_count = sum([m.numel() for m in model.parameters()])
     _logger.info('Model %s created, param count: %d' % (args.model, param_count))
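
A rough usage sketch for the `reparameterize_model` helper and the `--reparam` flag added above. The model name, input size, and tolerance here are illustrative assumptions, not part of the diff; any model whose blocks expose `fuse()`, `reparameterize()` or `switch_to_deploy()` should behave the same way.

# Illustrative only: fuse a rep-style model before timing or export.
import torch
import timm
from timm.utils import reparameterize_model

model = timm.create_model('repvgg_a2', pretrained=False).eval()  # hypothetical choice of rep-style model
deploy = reparameterize_model(model)  # returns a fused deepcopy; pass inplace=True to convert in place

with torch.no_grad():
    x = torch.randn(1, 3, 224, 224)
    # outputs of the fused and unfused graphs are expected to agree up to small numerical error
    print((model(x) - deploy(x)).abs().max())

On the CLI, the same fusion path is triggered by passing `--reparam` to `validate.py` as wired up in the hunk above.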
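
A minimal sketch of the `dynamic_img_size` / `dynamic_img_pad` options added to `VisionTransformer` in the hunks above, assuming creation through `timm.create_model`; the model name and input sizes are chosen for illustration only.

# Illustrative only: one ViT instance accepting several input sizes at inference.
import torch
import timm

model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=False,
    dynamic_img_size=True,  # resample the abs pos embed to the actual grid each forward pass
    dynamic_img_pad=True,   # pad bottom/right so H, W need not be multiples of patch_size
).eval()

with torch.no_grad():
    for size in ((224, 224), (240, 320), (250, 250)):  # 250 is not divisible by 16 -> padded
        tokens = model.forward_features(torch.randn(1, 3, *size))
        print(size, tuple(tokens.shape))  # e.g. (1, 197, 768) at 224x224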
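
For reference, a self-contained round trip of the pad / window partition / window reverse reshapes used in `TinyVitBlock.forward` above; the shapes are arbitrary and attention is elided.

# Illustrative only: NHWC window partition and reverse, as in TinyVitBlock.forward.
import torch
import torch.nn.functional as F

B, H, W, C, ws = 2, 13, 17, 32, 7
x = torch.randn(B, H, W, C)

pad_b = (ws - H % ws) % ws
pad_r = (ws - W % ws) % ws
x_p = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))  # pad W then H on the bottom/right of an NHWC tensor
pH, pW = H + pad_b, W + pad_r
nH, nW = pH // ws, pW // ws

windows = x_p.view(B, nH, ws, nW, ws, C).transpose(2, 3).reshape(B * nH * nW, ws * ws, C)
# ... windowed attention would run on `windows` here ...
merged = windows.view(B, nH, nW, ws, ws, C).transpose(2, 3).reshape(B, pH, pW, C)[:, :H, :W]

assert torch.equal(merged, x)  # the round trip restores the original, unpadded tensor exactly
print(windows.shape)           # (B * nH * nW, ws * ws, C)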