Skip to content

Commit 4ecf826

Browse files
authored
Fix pickling preprocess_input function
1 parent 0dadc48 commit 4ecf826

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

segmentation_models_pytorch/encoders/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
import torch.utils.model_zoo as model_zoo
23

34
from .resnet import resnet_encoders
@@ -44,8 +45,5 @@ def get_preprocessing_fn(encoder_name, pretrained='imagenet'):
4445
input_range = settings[pretrained].get('input_range')
4546
mean = settings[pretrained].get('mean')
4647
std = settings[pretrained].get('std')
47-
48-
def _preprocess_input(x, **kwargs):
49-
return preprocess_input(x, mean=mean, std=std, input_space=input_space, input_range=input_range, **kwargs)
50-
51-
return _preprocess_input
48+
49+
return functools.partial(preprocess_input, mean=mean, std=std, input_space=input_space, input_range=input_range)

0 commit comments

Comments
 (0)