diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py
index 4b1725e1..93708aed 100644
--- a/segmentation_models_pytorch/encoders/__init__.py
+++ b/segmentation_models_pytorch/encoders/__init__.py
@@ -1,3 +1,4 @@
+import timm
 import functools
 import torch.utils.model_zoo as model_zoo
 
@@ -91,16 +92,24 @@ def get_encoder_names():
 
 
 def get_preprocessing_params(encoder_name, pretrained="imagenet"):
-    settings = encoders[encoder_name]["pretrained_settings"]
 
-    if pretrained not in settings.keys():
-        raise ValueError("Available pretrained options {}".format(settings.keys()))
+    if encoder_name.startswith("tu-"):
+        encoder_name = encoder_name[3:]
+        if encoder_name not in timm.models.registry._model_has_pretrained:
+            raise ValueError(f"{encoder_name} does not have pretrained weights and preprocessing parameters")
+        settings = timm.models.registry._model_default_cfgs[encoder_name]
+    else:
+        all_settings = encoders[encoder_name]["pretrained_settings"]
+        if pretrained not in all_settings.keys():
+            raise ValueError("Available pretrained options {}".format(all_settings.keys()))
+        settings = all_settings[pretrained]
 
     formatted_settings = {}
-    formatted_settings["input_space"] = settings[pretrained].get("input_space")
-    formatted_settings["input_range"] = settings[pretrained].get("input_range")
-    formatted_settings["mean"] = settings[pretrained].get("mean")
-    formatted_settings["std"] = settings[pretrained].get("std")
+    formatted_settings["input_space"] = settings.get("input_space", "RGB")
+    formatted_settings["input_range"] = list(settings.get("input_range", [0, 1]))
+    formatted_settings["mean"] = list(settings.get("mean"))
+    formatted_settings["std"] = list(settings.get("std"))
+
     return formatted_settings
 
 
diff --git a/tests/test_models.py b/tests/test_models.py
index d3f3ad6b..ca1756a4 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -1,4 +1,3 @@
-import os
 import sys
 import mock
 import pytest
@@ -6,7 +5,7 @@
 # mock detection module
 sys.modules["torchvision._C"] = mock.Mock()
 
-import segmentation_models_pytorch as smp
+import segmentation_models_pytorch as smp  # noqa
 
 
 def get_encoders():
diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py
index a97ed7bc..282232d3 100644
--- a/tests/test_preprocessing.py
+++ b/tests/test_preprocessing.py
@@ -1,13 +1,10 @@
-import os
 import sys
 import mock
-import pytest
 import numpy as np
 
 # mock detection module
 sys.modules["torchvision._C"] = mock.Mock()
-
-import segmentation_models_pytorch as smp
+import segmentation_models_pytorch as smp  # noqa
 
 
 def _test_preprocessing(inp, out, **params):
@@ -41,3 +38,19 @@ def test_input_space():
     inp = np.stack([np.ones((32, 32)), np.zeros((32, 32))], axis=-1)
     out = np.stack([np.zeros((32, 32)), np.ones((32, 32))], axis=-1)
     _test_preprocessing(inp, out, input_space="BGR")
+
+
+def test_preprocessing_params():
+    # check default encoder params
+    params = smp.encoders.get_preprocessing_params("resnet18")
+    assert params["mean"] == [0.485, 0.456, 0.406]
+    assert params["std"] == [0.229, 0.224, 0.225]
+    assert params["input_range"] == [0, 1]
+    assert params["input_space"] == "RGB"
+
+    # check timm params
+    params = smp.encoders.get_preprocessing_params("tu-resnet18")
+    assert params["mean"] == [0.485, 0.456, 0.406]
+    assert params["std"] == [0.229, 0.224, 0.225]
+    assert params["input_range"] == [0, 1]
+    assert params["input_space"] == "RGB"
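
Below is a minimal usage sketch (not part of the diff) of the behaviour this change enables, assuming segmentation_models_pytorch with this patch and timm are installed; the chosen encoder names and the random test image are illustrative only:

# Usage sketch: get_preprocessing_params now also resolves "tu-" (timm) encoder
# names by falling back to timm's default pretrained config, so both native and
# timm-backed encoders can share the same preprocessing pipeline.
import numpy as np
import segmentation_models_pytorch as smp

# Native encoder: parameters come from smp's own pretrained_settings registry.
print(smp.encoders.get_preprocessing_params("resnet18"))

# timm encoder: the "tu-" prefix is stripped and timm's default cfg is used.
print(smp.encoders.get_preprocessing_params("tu-resnet18"))

# get_preprocessing_fn builds a normalization callable from those parameters.
preprocess = smp.encoders.get_preprocessing_fn("tu-resnet18", pretrained="imagenet")
image = np.random.randint(0, 255, size=(256, 256, 3)).astype("float32")
model_input = preprocess(image)  # rescaled to input_range, then mean/std normalized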