Skip to content

Add check for input shape #549

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions segmentation_models_pytorch/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,23 @@ def initialize(self):
if self.classification_head is not None:
init.initialize_head(self.classification_head)

def check_input_shape(self, x):
    """Validate that the spatial size of ``x`` fits the encoder's output stride.

    Args:
        x: input batch; only the last two dimensions (height, width) are inspected.

    Raises:
        RuntimeError: if height or width is not divisible by
            ``self.encoder.output_stride``; the message suggests the nearest
            valid (padded) shape.
    """
    h, w = x.shape[-2:]
    output_stride = self.encoder.output_stride
    if h % output_stride != 0 or w % output_stride != 0:
        # Round each dimension up to the nearest multiple of the stride
        # (a dimension already divisible is left unchanged by ceiling division).
        new_h = (h + output_stride - 1) // output_stride * output_stride
        new_w = (w + output_stride - 1) // output_stride * output_stride
        raise RuntimeError(
            f"Wrong input shape height={h}, width={w}. Expected image height and width "
            f"divisible by {output_stride}. Consider padding your images to shape ({new_h}, {new_w})."
        )

def forward(self, x):
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""

self.check_input_shape(x)

features = self.encoder(x)
decoder_output = self.decoder(*features)

Expand Down
8 changes: 8 additions & 0 deletions segmentation_models_pytorch/encoders/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,17 @@ class EncoderMixin:
- patching first convolution for arbitrary input channels
"""

# Default encoder output stride; make_dilated() overwrites this with the
# requested stride when dilation is applied.
_output_stride = 32

@property
def out_channels(self):
    """Channel dimensions of each feature tensor returned by the encoder's forward pass."""
    last_stage = self._depth + 1
    return self._out_channels[:last_stage]

@property
def output_stride(self):
    """Effective output stride: the configured stride, capped by what `_depth` stages allow."""
    depth_limit = 2 ** self._depth
    stride = self._output_stride
    return stride if stride < depth_limit else depth_limit

def set_in_channels(self, in_channels, pretrained=True):
"""Change first convolution channels"""
if in_channels == 3:
Expand Down Expand Up @@ -49,6 +55,8 @@ def make_dilated(self, output_stride):
else:
raise ValueError("Output stride should be 16 or 8, got {}.".format(output_stride))

self._output_stride = output_stride

stages = self.get_stages()
for stage_indx, dilation_rate in zip(stage_list, dilation_list):
utils.replace_strides_with_dilation(
Expand Down
5 changes: 5 additions & 0 deletions segmentation_models_pytorch/encoders/timm_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=
in_channels,
] + self.model.feature_info.channels()
self._depth = depth
self._output_stride = output_stride

def forward(self, x):
features = self.model(x)
Expand All @@ -35,3 +36,7 @@ def forward(self, x):
@property
def out_channels(self):
    """Per-stage channel dimensions of the encoder's output feature tensors."""
    channels = self._out_channels
    return channels

@property
def output_stride(self):
    """Effective output stride of the encoder.

    The configured stride is capped at ``2 ** depth``, since the encoder
    cannot downsample further than its number of stages allows.
    """
    depth_cap = 2 ** self._depth
    return min(depth_cap, self._output_stride)