From ab8f9c17cbf60fe959046949ec7efa6181646ce9 Mon Sep 17 00:00:00 2001 From: qubvel Date: Sun, 30 Jan 2022 20:25:56 +0300 Subject: [PATCH] Add check for input shape --- segmentation_models_pytorch/base/model.py | 15 +++++++++++++++ segmentation_models_pytorch/encoders/_base.py | 8 ++++++++ .../encoders/timm_universal.py | 5 +++++ 3 files changed, 28 insertions(+) diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py index 5410b4d3..806f2740 100644 --- a/segmentation_models_pytorch/base/model.py +++ b/segmentation_models_pytorch/base/model.py @@ -9,8 +9,23 @@ def initialize(self): if self.classification_head is not None: init.initialize_head(self.classification_head) + def check_input_shape(self, x): + + h, w = x.shape[-2:] + output_stride = self.encoder.output_stride + if h % output_stride != 0 or w % output_stride != 0: + new_h = (h // output_stride + 1) * output_stride if h % output_stride != 0 else h + new_w = (w // output_stride + 1) * output_stride if w % output_stride != 0 else w + raise RuntimeError( + f"Wrong input shape height={h}, width={w}. Expected image height and width " + f"divisible by {output_stride}. Consider padding your images to shape ({new_h}, {new_w})." 
+ ) + def forward(self, x): """Sequentially pass `x` trough model`s encoder, decoder and heads""" + + self.check_input_shape(x) + features = self.encoder(x) decoder_output = self.decoder(*features) diff --git a/segmentation_models_pytorch/encoders/_base.py b/segmentation_models_pytorch/encoders/_base.py index 0aee1b83..94f43429 100644 --- a/segmentation_models_pytorch/encoders/_base.py +++ b/segmentation_models_pytorch/encoders/_base.py @@ -12,11 +12,17 @@ class EncoderMixin: - patching first convolution for arbitrary input channels """ + _output_stride = 32 + @property def out_channels(self): """Return channels dimensions for each tensor of forward output of encoder""" return self._out_channels[: self._depth + 1] + @property + def output_stride(self): + return min(self._output_stride, 2 ** self._depth) + def set_in_channels(self, in_channels, pretrained=True): """Change first convolution channels""" if in_channels == 3: @@ -49,6 +55,8 @@ def make_dilated(self, output_stride): else: raise ValueError("Output stride should be 16 or 8, got {}.".format(output_stride)) + self._output_stride = output_stride + stages = self.get_stages() for stage_indx, dilation_rate in zip(stage_list, dilation_list): utils.replace_strides_with_dilation( diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py index fbfb8c73..8a27da9a 100644 --- a/segmentation_models_pytorch/encoders/timm_universal.py +++ b/segmentation_models_pytorch/encoders/timm_universal.py @@ -24,6 +24,7 @@ def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride= in_channels, ] + self.model.feature_info.channels() self._depth = depth + self._output_stride = output_stride def forward(self, x): features = self.model(x) @@ -35,3 +36,7 @@ def forward(self, x): @property def out_channels(self): return self._out_channels + + @property + def output_stride(self): + return min(self._output_stride, 2 ** self._depth)