From cf26a12d4275f52e8912ad0c9f03c84ea2eaf424 Mon Sep 17 00:00:00 2001 From: qubvel Date: Sun, 7 Aug 2022 23:52:41 +0300 Subject: [PATCH 1/3] Segformer backbone --- README.md | 20 +- docs/encoders.rst | 15 + .../encoders/__init__.py | 2 + .../encoders/mix_transformer.py | 611 ++++++++++++++++++ tests/test_models.py | 4 +- 5 files changed, 649 insertions(+), 3 deletions(-) create mode 100644 segmentation_models_pytorch/encoders/mix_transformer.py diff --git a/README.md b/README.md index a51109bf..67fc56c2 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ The main features of this library are: - High level API (just two lines to create a neural network) - 9 models architectures for binary and multi class segmentation (including legendary Unet) - - 113 available encoders (and 400+ encoders from [timm](https://github.com/rwightman/pytorch-image-models)) + - 119 available encoders (and 400+ encoders from [timm](https://github.com/rwightman/pytorch-image-models)) - All encoders have pre-trained weights for faster and better convergence - Popular metrics and losses for training routines @@ -352,6 +352,24 @@ The following is a list of supported encoders in the SMP. Select the appropriate +
+Mix Vision Transformer
+
+Backbone from SegFormer pretrained on ImageNet! It can be used with all other decoders in the package, so you can combine the Mix Vision Transformer with Unet, FPN, and others!
+
+|Encoder |Weights |Params, M |
+|--------------------------------|:------------------------------:|:------------------------------:|
+|mit_b0 |imagenet |3M |
+|mit_b1 |imagenet |13M |
+|mit_b2 |imagenet |24M |
+|mit_b3 |imagenet |44M |
+|mit_b4 |imagenet |60M |
+|mit_b5 |imagenet |81M |
+
+
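As a quick illustration of the claim above, here is a minimal usage sketch (standard `segmentation_models_pytorch` high-level API; nothing here is specific to this patch beyond the encoder name) combining a MiT backbone with a Unet decoder:

```python
import torch
import segmentation_models_pytorch as smp

# Mix Vision Transformer backbone combined with a Unet decoder,
# created the same way as any other encoder in the package.
model = smp.Unet(
    encoder_name="mit_b0",       # SegFormer (MiT) backbone registered by this patch
    encoder_weights="imagenet",  # pretrained weights are provided for mit_b0..mit_b5
    in_channels=3,               # MiT encoders currently accept 3-channel input only
    classes=1,
)

model.eval()
with torch.no_grad():
    mask = model(torch.rand(1, 3, 224, 224))  # -> torch.Size([1, 1, 224, 224])
```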
+ \* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)). diff --git a/docs/encoders.rst b/docs/encoders.rst index e14f9546..3087d741 100644 --- a/docs/encoders.rst +++ b/docs/encoders.rst @@ -324,3 +324,18 @@ VGG +-------------+------------+-------------+ | vgg19\_bn | imagenet | 20M | +-------------+------------+-------------+ + + +Mix Visual Transformer +~~~~~~~~~~~~~~~~~~~~~ + ++-----------+----------+------------+ +| Encoder | Weights | Params, M | ++===========+==========+============+ +| mit\_b0 | imagenet | 3M | +| mit\_b1 | imagenet | 13M | +| mit\_b2 | imagenet | 24M | +| mit\_b3 | imagenet | 44M | +| mit\_b4 | imagenet | 60M | +| mit\_b5 | imagenet | 81M | ++-----------+----------+------------+ diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py index 93708aed..3a40be56 100644 --- a/segmentation_models_pytorch/encoders/__init__.py +++ b/segmentation_models_pytorch/encoders/__init__.py @@ -19,6 +19,7 @@ from .timm_sknet import timm_sknet_encoders from .timm_mobilenetv3 import timm_mobilenetv3_encoders from .timm_gernet import timm_gernet_encoders +from .mix_transformer import mix_transformer_encoders from .timm_universal import TimmUniversalEncoder @@ -42,6 +43,7 @@ encoders.update(timm_sknet_encoders) encoders.update(timm_mobilenetv3_encoders) encoders.update(timm_gernet_encoders) +encoders.update(mix_transformer_encoders) def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs): diff --git a/segmentation_models_pytorch/encoders/mix_transformer.py b/segmentation_models_pytorch/encoders/mix_transformer.py new file mode 100644 index 00000000..d211e2ca --- /dev/null +++ b/segmentation_models_pytorch/encoders/mix_transformer.py @@ -0,0 +1,611 @@ +# --------------------------------------------------------------- +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 
+# +# This work is licensed under the NVIDIA Source Code License +# --------------------------------------------------------------- +import math +import torch +import torch.nn as nn +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr_ratio > 1: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + 
qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + sr_ratio=1, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + + return x + + +class OverlapPatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2), + ) + self.norm = nn.LayerNorm(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + + return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + + # patch_embed + self.patch_embed1 = OverlapPatchEmbed( + img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, embed_dim=embed_dims[0] + ) + self.patch_embed2 = OverlapPatchEmbed( + img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1] + ) + self.patch_embed3 = OverlapPatchEmbed( + img_size=img_size // 8, patch_size=3, 
stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2] + ) + self.patch_embed4 = OverlapPatchEmbed( + img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3] + ) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList( + [ + Block( + dim=embed_dims[0], + num_heads=num_heads[0], + mlp_ratio=mlp_ratios[0], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[0], + ) + for i in range(depths[0]) + ] + ) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList( + [ + Block( + dim=embed_dims[1], + num_heads=num_heads[1], + mlp_ratio=mlp_ratios[1], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[1], + ) + for i in range(depths[1]) + ] + ) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList( + [ + Block( + dim=embed_dims[2], + num_heads=num_heads[2], + mlp_ratio=mlp_ratios[2], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[2], + ) + for i in range(depths[2]) + ] + ) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList( + [ + Block( + dim=embed_dims[3], + num_heads=num_heads[3], + mlp_ratio=mlp_ratios[3], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[3], + ) + for i in range(depths[3]) + ] + ) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def init_weights(self, pretrained=None): + pass + + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed1", "pos_embed2", "pos_embed3", "pos_embed4", "cls_token"} # has pos_embed may be better + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=""): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, 
num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + outs = [] + + # stage 1 + x, H, W = self.patch_embed1(x) + for i, blk in enumerate(self.block1): + x = blk(x, H, W) + x = self.norm1(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 2 + x, H, W = self.patch_embed2(x) + for i, blk in enumerate(self.block2): + x = blk(x, H, W) + x = self.norm2(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 3 + x, H, W = self.patch_embed3(x) + for i, blk in enumerate(self.block3): + x = blk(x, H, W) + x = self.norm3(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 4 + x, H, W = self.patch_embed4(x) + for i, blk in enumerate(self.block4): + x = blk(x, H, W) + x = self.norm4(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + return outs + + def forward(self, x): + x = self.forward_features(x) + # x = self.head(x) + + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2) + + return x + + +# --------------------------------------------------------------- +# End of NVIDIA code +# --------------------------------------------------------------- + +from ._base import EncoderMixin # noqa E402 + + +class MixVisionTransformerEncoder(MixVisionTransformer, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._out_channels = out_channels + self._depth = depth + self._in_channels = 3 + + def make_dilated(self, *args, **kwargs): + raise ValueError("MixVisionTransformer encoder does not support dilated mode") + + def set_in_channels(self, in_channels, *args, **kwargs): + if in_channels != 3: + raise ValueError("MixVisionTransformer encoder does not support in_channels setting other than 3") + + def forward(self, x): + + # create dummy output for the first block + B, C, H, W = x.shape + dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device) + + return [x, dummy] + self.forward_features(x)[: self._depth - 1] + + def load_state_dict(self, state_dict): + state_dict.pop("head.weight", None) + state_dict.pop("head.bias", None) + return super().load_state_dict(state_dict) + + +def get_pretrained_cfg(name): + return { + "url": "https://github.com/qubvel/segmentation_models.pytorch/releases/download/v0.0.2/{}.pth".format(name), + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + } + + +mix_transformer_encoders = { + "mit_b0": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": { + "imagenet": get_pretrained_cfg("mit_b0"), + }, + "params": dict( + out_channels=(3, 0, 32, 64, 160, 256), + patch_size=4, + embed_dims=[32, 64, 160, 256], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[2, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ), + }, + "mit_b1": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": { + "imagenet": get_pretrained_cfg("mit_b1"), + }, + "params": dict( + out_channels=(3, 0, 64, 128, 320, 512), + patch_size=4, + embed_dims=[64, 128, 
320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[2, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ), + }, + "mit_b2": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": { + "imagenet": get_pretrained_cfg("mit_b2"), + }, + "params": dict( + out_channels=(3, 0, 64, 128, 320, 512), + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ), + }, + "mit_b3": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": { + "imagenet": get_pretrained_cfg("mit_b3"), + }, + "params": dict( + out_channels=(3, 0, 64, 128, 320, 512), + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 4, 18, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ), + }, + "mit_b4": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": { + "imagenet": get_pretrained_cfg("mit_b4"), + }, + "params": dict( + out_channels=(3, 0, 64, 128, 320, 512), + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 8, 27, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ), + }, + "mit_b5": { + "encoder": MixVisionTransformerEncoder, + "pretrained_settings": { + "imagenet": get_pretrained_cfg("mit_b5"), + }, + "params": dict( + out_channels=(3, 0, 64, 128, 320, 512), + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 6, 40, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ), + }, +} diff --git a/tests/test_models.py b/tests/test_models.py index ca1756a4..b94ed802 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -99,9 +99,8 @@ def test_upsample(model_class, upsampling): @pytest.mark.parametrize("model_class", [smp.FPN]) -@pytest.mark.parametrize("encoder_name", ENCODERS) @pytest.mark.parametrize("in_channels", [1, 2, 4]) -def test_in_channels(model_class, encoder_name, in_channels): +def test_in_channels(model_class, in_channels): sample = torch.ones([1, in_channels, 64, 64]) model = model_class(DEFAULT_ENCODER, encoder_weights=None, in_channels=in_channels) model.eval() @@ -118,6 +117,7 @@ def test_dilation(encoder_name): or encoder_name.startswith("vgg") or encoder_name.startswith("densenet") or encoder_name.startswith("timm-res") + or encoder_name.startswith("mit_b") ): return From 220c2b312c552cf8338b1a6e8fb87b67fc7ce572 Mon Sep 17 00:00:00 2001 From: qubvel Date: Mon, 8 Aug 2022 08:57:12 +0300 Subject: [PATCH 2/3] Add limitations for FPN, Unet++, Linknet --- README.md | 7 ++++++- segmentation_models_pytorch/decoders/fpn/model.py | 4 ++++ segmentation_models_pytorch/decoders/linknet/model.py | 3 +++ segmentation_models_pytorch/decoders/unetplusplus/model.py | 3 +++ tests/test_models.py | 7 ++++++- 5 files changed, 22 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 67fc56c2..b6718201 100644 --- a/README.md +++ b/README.md @@ -356,7 +356,12 @@ The following is a list of supported encoders in the SMP. 
Select the appropriate
 
 Mix Vision Transformer
 
-Backbone from SegFormer pretrained on ImageNet! It can be used with all other decoders in the package, so you can combine the Mix Vision Transformer with Unet, FPN, and others!
+Backbone from SegFormer pretrained on ImageNet! It can be used with other decoders in the package, so you can combine the Mix Vision Transformer with Unet, FPN, and others!
+
+Limitations:
+
+ - encoder is not supported by Linknet and Unet++
+ - encoder is supported by FPN only for encoder_depth=5
 
 |Encoder |Weights |Params, M |
 |--------------------------------|:------------------------------:|:------------------------------:|
 
diff --git a/segmentation_models_pytorch/decoders/fpn/model.py b/segmentation_models_pytorch/decoders/fpn/model.py
index 555be529..7990b195 100644
--- a/segmentation_models_pytorch/decoders/fpn/model.py
+++ b/segmentation_models_pytorch/decoders/fpn/model.py
@@ -66,6 +66,10 @@ def __init__(
     ):
         super().__init__()
 
+        # validate input params
+        if encoder_name.startswith("mit_b") and encoder_depth != 5:
+            raise ValueError("Encoder {} supports only encoder_depth=5".format(encoder_name))
+
         self.encoder = get_encoder(
             encoder_name,
             in_channels=in_channels,
diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py
index 0d47fada..509a8abf 100644
--- a/segmentation_models_pytorch/decoders/linknet/model.py
+++ b/segmentation_models_pytorch/decoders/linknet/model.py
@@ -64,6 +64,9 @@ def __init__(
     ):
         super().__init__()
 
+        if encoder_name.startswith("mit_b"):
+            raise ValueError("Encoder `{}` is not supported for Linknet".format(encoder_name))
+
         self.encoder = get_encoder(
             encoder_name,
             in_channels=in_channels,
diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py
index 4bd3c930..e5a38718 100644
--- a/segmentation_models_pytorch/decoders/unetplusplus/model.py
+++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py
@@ -68,6 +68,9 @@ def __init__(
     ):
         super().__init__()
 
+        if encoder_name.startswith("mit_b"):
+            raise ValueError("Encoder `{}` is not supported for UnetPlusPlus".format(encoder_name))
+
         self.encoder = get_encoder(
             encoder_name,
             in_channels=in_channels,
diff --git a/tests/test_models.py b/tests/test_models.py
index b94ed802..a8fb6368 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -18,7 +18,8 @@ def get_encoders():
     encoders = smp.encoders.get_encoder_names()
     encoders = [e for e in encoders if e not in exclude_encoders]
     encoders.append("tu-resnet34")  # for timm universal encoder
-    return encoders
+    # return encoders
+    return ["mit_b0"]
 
 
 ENCODERS = get_encoders()
@@ -57,6 +58,10 @@ def _test_forward_backward(model, sample, test_shape=False):
 def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
     if model_class is smp.Unet or model_class is smp.UnetPlusPlus or model_class is smp.MAnet:
         kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:]
+    if model_class in [smp.UnetPlusPlus, smp.Linknet] and encoder_name.startswith("mit_b"):
+        return  # skip mit_b*
+    if model_class is smp.FPN and encoder_name.startswith("mit_b") and encoder_depth != 5:
+        return  # skip mit_b*
     model = model_class(encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs)
     sample = get_sample(model_class)
     model.eval()
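For clarity, a short sketch of the validation behavior PATCH 2 introduces (assuming a build with these patches applied; the printed messages correspond to the `ValueError`s added in the decoder constructors above):

```python
import segmentation_models_pytorch as smp

# FPN accepts MiT encoders only at the default encoder_depth=5.
fpn = smp.FPN("mit_b0", encoder_weights=None)  # OK

try:
    smp.FPN("mit_b0", encoder_weights=None, encoder_depth=4)
except ValueError as err:
    print(err)  # Encoder mit_b0 supports only encoder_depth=5

# Linknet and Unet++ reject MiT encoders entirely.
for model_class in (smp.Linknet, smp.UnetPlusPlus):
    try:
        model_class("mit_b0", encoder_weights=None)
    except ValueError as err:
        print(err)  # Encoder `mit_b0` is not supported for ...
```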
From e60d17c569769a383753359c8ab64803a610e07d Mon Sep 17 00:00:00 2001
From: qubvel
Date: Mon, 8 Aug 2022 08:59:03 +0300
Subject: [PATCH 3/3] fix tests

---
 tests/test_models.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index a8fb6368..c2e6d941 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -18,8 +18,7 @@ def get_encoders():
     encoders = smp.encoders.get_encoder_names()
     encoders = [e for e in encoders if e not in exclude_encoders]
     encoders.append("tu-resnet34")  # for timm universal encoder
-    # return encoders
-    return ["mit_b0"]
+    return encoders
 
 
 ENCODERS = get_encoders()
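Finally, a small sketch (derived from `MixVisionTransformerEncoder.forward` in PATCH 1) of the feature pyramid the new encoder returns; the zero-channel tensor in the stride-2 slot is the dummy placeholder created in `forward`, which is presumably why the decoders listed under Limitations are excluded:

```python
import torch
from segmentation_models_pytorch.encoders import get_encoder

encoder = get_encoder("mit_b0", in_channels=3, depth=5, weights=None)
features = encoder(torch.rand(2, 3, 224, 224))

for f in features:
    print(tuple(f.shape))
# (2, 3, 224, 224)   identity features, stride 1
# (2, 0, 112, 112)   dummy stride-2 features (MiT has no stride-2 stage)
# (2, 32, 56, 56)    stage 1, stride 4
# (2, 64, 28, 28)    stage 2, stride 8
# (2, 160, 14, 14)   stage 3, stride 16
# (2, 256, 7, 7)     stage 4, stride 32
```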