Skip to content

Commit 53d7a91

Browse files
committed
FIX introduce _check_fit_params to validate parameters
1 parent 341f8e0 commit 53d7a91

File tree

2 files changed

+54
-21
lines changed

2 files changed

+54
-21
lines changed

sklearn/model_selection/_search.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from ..utils import check_random_state
3434
from ..utils.fixes import MaskedArray
3535
from ..utils.random import sample_without_replacement
36-
from ..utils.validation import indexable, check_is_fitted
36+
from ..utils.validation import indexable, check_is_fitted, _check_fit_params
3737
from ..utils.metaestimators import if_delegate_has_method
3838
from ..metrics._scorer import _check_multimetric_scoring
3939
from ..metrics import check_scoring
@@ -648,15 +648,7 @@ def fit(self, X, y=None, groups=None, **fit_params):
648648
refit_metric = 'score'
649649

650650
X, y, groups = indexable(X, y, groups)
651-
# make sure fit_params are sliceable
652-
# TODO: remove in 0.24
653-
try:
654-
fit_params_values = indexable(*fit_params.values())
655-
fit_params = dict(zip(fit_params.keys(), fit_params_values))
656-
except TypeError:
657-
warnings.warn("Support for scaler fit params is deprecated. "
658-
"Passing scalar values as a fit parameter will "
659-
"raise an error in v0.24", FutureWarning)
651+
fit_params = _check_fit_params(fit_params)
660652

661653
n_splits = cv.get_n_splits(X, y, groups)
662654

sklearn/utils/validation.py

+52-11
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
import numpy as np
1717
import scipy.sparse as sp
1818
from distutils.version import LooseVersion
19-
from inspect import signature, isclass, Parameter
19+
from inspect import signature, Parameter
20+
from inspect import isclass, isfunction, ismethod, ismodule
2021

2122
from numpy.core.numeric import ComplexWarning
2223
import joblib
@@ -212,6 +213,26 @@ def check_consistent_length(*arrays):
212213
" samples: %r" % [int(l) for l in lengths])
213214

214215

216+
def _convert_iterable(iterable):
217+
"""Helper convert iterable to arrays of sparse matrices.
218+
219+
Convert sparse matrices to csr and non-interable objects to arrays.
220+
Let passes `None`.
221+
222+
Parameters
223+
----------
224+
iterable : {list, dataframe, array, sparse} or None
225+
Object to be converted to a sliceable iterable.
226+
"""
227+
if sp.issparse(iterable):
228+
return iterable.tocsr()
229+
elif hasattr(iterable, "__getitem__") or hasattr(iterable, "iloc"):
230+
return iterable
231+
elif iterable is None:
232+
return iterable
233+
return np.array(iterable)
234+
235+
215236
def indexable(*iterables):
216237
"""Make arrays indexable for cross-validation.
217238
@@ -224,16 +245,7 @@ def indexable(*iterables):
224245
*iterables : lists, dataframes, arrays, sparse matrices
225246
List of objects to ensure sliceability.
226247
"""
227-
result = []
228-
for X in iterables:
229-
if sp.issparse(X):
230-
result.append(X.tocsr())
231-
elif hasattr(X, "__getitem__") or hasattr(X, "iloc"):
232-
result.append(X)
233-
elif X is None:
234-
result.append(X)
235-
else:
236-
result.append(np.array(X))
248+
result = [_convert_iterable(X) for X in iterables]
237249
check_consistent_length(*result)
238250
return result
239251

@@ -1257,3 +1269,32 @@ def inner_f(*args, **kwargs):
12571269
kwargs.update({k: arg for k, arg in zip(all_args, args)})
12581270
return f(**kwargs)
12591271
return inner_f
1272+
1273+
1274+
def _check_fit_params(fit_params):
1275+
"""Check and validate the parameters passed during `fit`.
1276+
1277+
Parameters
1278+
----------
1279+
fit_params : dict
1280+
Dictionary containing the parameters passed at fit.
1281+
1282+
Returns
1283+
-------
1284+
fit_params_validated : dict
1285+
Validated parameters. We ensure that the values are iterable.
1286+
"""
1287+
fit_params_validated = {}
1288+
for param_key, param_value in fit_params.items():
1289+
is_scalar = [
1290+
check(param_value)
1291+
for check in [np.isscalar, ismodule, isclass, ismethod, isfunction]
1292+
]
1293+
if any(is_scalar):
1294+
# keep scalar as is for backward-compatibility
1295+
# https://github.com/scikit-learn/scikit-learn/issues/15805
1296+
fit_params_validated[param_key] = param_value
1297+
else:
1298+
# ensure iterable will be sliceable
1299+
fit_params_validated[param_key] = _convert_iterable(param_value)
1300+
return fit_params_validated

0 commit comments

Comments
 (0)