16
16
import numpy as np
17
17
import scipy .sparse as sp
18
18
from distutils .version import LooseVersion
19
- from inspect import signature , isclass , Parameter
19
+ from inspect import signature , Parameter
20
+ from inspect import isclass , isfunction , ismethod , ismodule
20
21
21
22
from numpy .core .numeric import ComplexWarning
22
23
import joblib
@@ -212,6 +213,26 @@ def check_consistent_length(*arrays):
212
213
" samples: %r" % [int (l ) for l in lengths ])
213
214
214
215
216
+ def _convert_iterable (iterable ):
217
+ """Helper convert iterable to arrays of sparse matrices.
218
+
219
+ Convert sparse matrices to csr and non-interable objects to arrays.
220
+ Let passes `None`.
221
+
222
+ Parameters
223
+ ----------
224
+ iterable : {list, dataframe, array, sparse} or None
225
+ Object to be converted to a sliceable iterable.
226
+ """
227
+ if sp .issparse (iterable ):
228
+ return iterable .tocsr ()
229
+ elif hasattr (iterable , "__getitem__" ) or hasattr (iterable , "iloc" ):
230
+ return iterable
231
+ elif iterable is None :
232
+ return iterable
233
+ return np .array (iterable )
234
+
235
+
215
236
def indexable (* iterables ):
216
237
"""Make arrays indexable for cross-validation.
217
238
@@ -224,16 +245,7 @@ def indexable(*iterables):
224
245
*iterables : lists, dataframes, arrays, sparse matrices
225
246
List of objects to ensure sliceability.
226
247
"""
227
- result = []
228
- for X in iterables :
229
- if sp .issparse (X ):
230
- result .append (X .tocsr ())
231
- elif hasattr (X , "__getitem__" ) or hasattr (X , "iloc" ):
232
- result .append (X )
233
- elif X is None :
234
- result .append (X )
235
- else :
236
- result .append (np .array (X ))
248
+ result = [_convert_iterable (X ) for X in iterables ]
237
249
check_consistent_length (* result )
238
250
return result
239
251
@@ -1257,3 +1269,32 @@ def inner_f(*args, **kwargs):
1257
1269
kwargs .update ({k : arg for k , arg in zip (all_args , args )})
1258
1270
return f (** kwargs )
1259
1271
return inner_f
1272
+
1273
+
1274
+ def _check_fit_params (fit_params ):
1275
+ """Check and validate the parameters passed during `fit`.
1276
+
1277
+ Parameters
1278
+ ----------
1279
+ fit_params : dict
1280
+ Dictionary containing the parameters passed at fit.
1281
+
1282
+ Returns
1283
+ -------
1284
+ fit_params_validated : dict
1285
+ Validated parameters. We ensure that the values are iterable.
1286
+ """
1287
+ fit_params_validated = {}
1288
+ for param_key , param_value in fit_params .items ():
1289
+ is_scalar = [
1290
+ check (param_value )
1291
+ for check in [np .isscalar , ismodule , isclass , ismethod , isfunction ]
1292
+ ]
1293
+ if any (is_scalar ):
1294
+ # keep scalar as is for backward-compatibility
1295
+ # https://github.com/scikit-learn/scikit-learn/issues/15805
1296
+ fit_params_validated [param_key ] = param_value
1297
+ else :
1298
+ # ensure iterable will be sliceable
1299
+ fit_params_validated [param_key ] = _convert_iterable (param_value )
1300
+ return fit_params_validated
0 commit comments