diff --git a/sklearn/base.py b/sklearn/base.py index 5c7168adabc5e..fa54e49e47bb4 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -21,7 +21,7 @@ ) from .exceptions import InconsistentVersionWarning from .utils.validation import check_X_y -from .utils.validation import check_array +from .utils.validation import _check_array from .utils.validation import _check_y from .utils.validation import _num_features from .utils.validation import _check_feature_names_in @@ -574,7 +574,7 @@ def _validate_data( if no_val_X and no_val_y: raise ValueError("Validation should be done on X, y or both.") elif not no_val_X and no_val_y: - X = check_array(X, input_name="X", **check_params) + X = _check_array(X, input_name="X", **check_params) out = X elif no_val_X and not no_val_y: y = _check_y(y, **check_params) @@ -588,10 +588,10 @@ def _validate_data( check_X_params, check_y_params = validate_separately if "estimator" not in check_X_params: check_X_params = {**default_check_params, **check_X_params} - X = check_array(X, input_name="X", **check_X_params) + X = _check_array(X, input_name="X", **check_X_params) if "estimator" not in check_y_params: check_y_params = {**default_check_params, **check_y_params} - y = check_array(y, input_name="y", **check_y_params) + y = _check_array(y, input_name="y", **check_y_params) else: X, y = check_X_y(X, y, **check_params) out = X, y diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index e49cb0694b9ad..6be297d341913 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1759,3 +1759,28 @@ def test_boolean_series_remains_boolean(): assert res.dtype == expected.dtype assert_array_equal(res, expected) + + +def test_custom_asarray(): + """Check that a custom `asarray` function can be used during validation""" + est = BaseEstimator() + + # This "special" asarray method converts elements of the input array + # to ints before creating a numpy array. 
Mostly so we can be sure it + # and not the standard asarray, was used for the data validation. + def my_asarray(array, copy=False, **kwargs): + converted = [] + for row in array: + r = [] + for element in row: + r.append(int(element)) + converted.append(r) + + if copy: + return np.array(converted, **kwargs) + else: + return np.asarray(converted, **kwargs) + + x = est._validate_data([["1", "2", "3"], ["4", "5", "6"]], asarray=my_asarray) + + assert x.dtype == int diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index dd0c007602654..5becbce25bf13 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -9,7 +9,7 @@ # Sylvain Marie # License: BSD 3 clause -from functools import wraps +from functools import partial, wraps import warnings import numbers import operator @@ -642,6 +642,130 @@ def check_array( estimator=None, input_name="", ): + """Input validation on an array, list, sparse matrix or similar. + + By default, the input is checked to be a non-empty 2D array containing + only finite values. If the dtype of the array is object, attempt + converting to float, raising on failure. + + Parameters + ---------- + array : object + Input object to check / convert. + + accept_sparse : str, bool or list/tuple of str, default=False + String[s] representing allowed sparse matrix formats, such as 'csc', + 'csr', etc. If the input is sparse but not in the allowed format, + it will be converted to the first listed format. True allows the input + to be any format. False means that a sparse matrix input will + raise an error. + + accept_large_sparse : bool, default=True + If a CSR, CSC, COO or BSR sparse matrix is supplied and accepted by + accept_sparse, accept_large_sparse=False will cause it to be accepted + only if its indices are stored with a 32-bit dtype. + + .. versionadded:: 0.20 + + dtype : 'numeric', type, list of type or None, default='numeric' + Data type of result. If None, the dtype of the input is preserved. 
+ If "numeric", dtype is preserved unless array.dtype is object. + If dtype is a list of types, conversion on the first type is only + performed if the dtype of the input is not in the list. + + order : {'F', 'C'} or None, default=None + Whether an array will be forced to be fortran or c-style. + When order is None (default), then if copy=False, nothing is ensured + about the memory layout of the output array; otherwise (copy=True) + the memory layout of the returned array is kept as close as possible + to the original array. + + copy : bool, default=False + Whether a forced copy will be triggered. If copy=False, a copy might + be triggered by a conversion. + + force_all_finite : bool or 'allow-nan', default=True + Whether to raise an error on np.inf, np.nan, pd.NA in array. The + possibilities are: + + - True: Force all values of array to be finite. + - False: accepts np.inf, np.nan, pd.NA in array. + - 'allow-nan': accepts only np.nan and pd.NA values in array. Values + cannot be infinite. + + .. versionadded:: 0.20 + ``force_all_finite`` accepts the string ``'allow-nan'``. + + .. versionchanged:: 0.23 + Accepts `pd.NA` and converts it into `np.nan` + + ensure_2d : bool, default=True + Whether to raise a value error if array is not 2D. + + allow_nd : bool, default=False + Whether to allow array.ndim > 2. + + ensure_min_samples : int, default=1 + Make sure that the array has a minimum number of samples in its first + axis (rows for a 2D array). Setting to 0 disables this check. + + ensure_min_features : int, default=1 + Make sure that the 2D array has some minimum number of features + (columns). The default value of 1 rejects empty datasets. + This check is only enforced when the input data has effectively 2 + dimensions or is originally 1D and ``ensure_2d`` is True. Setting to 0 + disables this check. + + estimator : str or estimator instance, default=None + If passed, include the name of the estimator in warning messages. 
+ + input_name : str, default="" + The data name used to construct the error message. In particular + if `input_name` is "X" and the data has NaN values and + allow_nan is False, the error message will link to the imputer + documentation. + + .. versionadded:: 1.1.0 + + Returns + ------- + array_converted : object + The converted and validated array. + """ + return _check_array( + array, + accept_sparse=accept_sparse, + accept_large_sparse=accept_large_sparse, + dtype=dtype, + order=order, + copy=copy, + force_all_finite=force_all_finite, + ensure_2d=ensure_2d, + allow_nd=allow_nd, + ensure_min_samples=ensure_min_samples, + ensure_min_features=ensure_min_features, + estimator=estimator, + input_name=input_name, + ) + + +def _check_array( + array, + accept_sparse=False, + *, + accept_large_sparse=True, + dtype="numeric", + order=None, + copy=False, + force_all_finite=True, + ensure_2d=True, + allow_nd=False, + ensure_min_samples=1, + ensure_min_features=1, + estimator=None, + input_name="", + asarray=None, +): """Input validation on an array, list, sparse matrix or similar. @@ -728,6 +852,15 @@ def check_array( .. versionadded:: 1.1.0 + asarray : callable, default=None + A callable to use instead of `np.asarray` when converting the input + array. Useful when the input array is not a NumPy array or when the + converted array should be an ndarray from a different library. The callable + should have the same signature as `np.asarray` and in addition support + the `copy` keyword argument. + + .. versionadded:: 1.3.0 + Returns ------- array_converted : object @@ -742,6 +875,9 @@ def check_array( xp, is_array_api = get_namespace(array) + if asarray is None: + asarray = partial(_asarray_with_order, xp=xp) + # store reference to original array to check if copy is needed when # function returns array_orig = array @@ -865,7 +1001,7 @@ def check_array( # Conversion float -> int should not contain NaN or # inf (numpy#14412). 
We cannot use casting='safe' because # then conversion float -> int would be disallowed. - array = _asarray_with_order(array, order=order, xp=xp) + array = asarray(array, order=order) if array.dtype.kind == "f": _assert_all_finite( array, @@ -876,7 +1012,7 @@ def check_array( ) array = xp.astype(array, dtype, copy=False) else: - array = _asarray_with_order(array, order=order, dtype=dtype, xp=xp) + array = asarray(array, order=order, dtype=dtype) except ComplexWarning as complex_warning: raise ValueError( "Complex data not supported\n{}\n".format(array) @@ -947,13 +1083,19 @@ def check_array( if xp.__name__ in {"numpy", "numpy.array_api"}: # only make a copy if `array` and `array_orig` may share memory` if np.may_share_memory(array, array_orig): - array = _asarray_with_order( - array, dtype=dtype, order=order, copy=True, xp=xp + array = asarray( + array, + dtype=dtype, + order=order, + copy=True, ) else: # always make a copy for non-numpy arrays - array = _asarray_with_order( - array, dtype=dtype, order=order, copy=True, xp=xp + array = asarray( + array, + dtype=dtype, + order=order, + copy=True, ) return array