Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from .utils.validation import check_X_y
from .utils.validation import check_array
from .utils.validation import _check_y
from .utils.validation import _num_features
from .utils._estimator_html_repr import estimator_html_repr

Expand Down Expand Up @@ -376,25 +377,33 @@ def _check_n_features(self, X, reset):
f"X has {n_features} features, but {self.__class__.__name__} "
f"is expecting {self.n_features_in_} features as input.")

def _validate_data(self, X, y='no_validation', reset=True,
def _validate_data(self, X='no_validation', y='no_validation', reset=True,
validate_separately=False, **check_params):
"""Validate input data and set or check the `n_features_in_` attribute.

Parameters
----------
X : {array-like, sparse matrix, dataframe} of shape \
(n_samples, n_features)
(n_samples, n_features), default='no_validation'
The input samples.
If `'no_validation'`, no validation is performed on `X`. This is
useful for meta-estimators, which can delegate input validation to
their underlying estimator(s). In that case `y` must be passed and
the only accepted `check_params` are `multi_output` and
`y_numeric`.

y : array-like of shape (n_samples,), default='no_validation'
The targets.

- If `None`, `check_array` is called on `X`. If the estimator's
requires_y tag is True, then an error will be raised.
- If `'no_validation'`, `check_array` is called on `X` and the
estimator's requires_y tag is ignored. This is a default
placeholder and is never meant to be explicitly set.
- Otherwise, both `X` and `y` are checked with either `check_array`
or `check_X_y` depending on `validate_separately`.
placeholder and is never meant to be explicitly set. In that case
`X` must be passed.
- Otherwise, if `X` is `'no_validation'`, only `y` is checked, with
`_check_y`; if not, both `X` and `y` are checked with either
`check_array` or `check_X_y` depending on `validate_separately`.

reset : bool, default=True
Whether to reset the `n_features_in_` attribute.
Expand All @@ -416,20 +425,26 @@ def _validate_data(self, X, y='no_validation', reset=True,
Returns
-------
out : {ndarray, sparse matrix} or tuple of these
The validated input. A tuple is returned if `y` is not None.
The validated input. A tuple is returned if both `X` and `y` are
validated.
"""
if y is None and self._get_tags()['requires_y']:
raise ValueError(
f"This {self.__class__.__name__} estimator "
f"requires y to be passed, but the target y is None."
)

if y is None:
if self._get_tags()['requires_y']:
raise ValueError(
f"This {self.__class__.__name__} estimator "
f"requires y to be passed, but the target y is None."
)
X = check_array(X, **check_params)
out = X
elif isinstance(y, str) and y == 'no_validation':
no_val_X = isinstance(X, str) and X == 'no_validation'
no_val_y = y is None or isinstance(y, str) and y == 'no_validation'

if no_val_X and no_val_y:
raise ValueError("Validation should be done on X, y or both.")
elif not no_val_X and no_val_y:
X = check_array(X, **check_params)
out = X
elif no_val_X and not no_val_y:
y = _check_y(y, **check_params)
out = y
else:
if validate_separately:
# We need this because some estimators validate X and y
Expand All @@ -443,7 +458,7 @@ def _validate_data(self, X, y='no_validation', reset=True,
X, y = check_X_y(X, y, **check_params)
out = X, y

if check_params.get('ensure_2d', True):
if not no_val_X and check_params.get('ensure_2d', True):
self._check_n_features(X, reset=reset)

return out
Expand Down
5 changes: 1 addition & 4 deletions sklearn/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@
from .utils._tags import _safe_tags
from .utils.validation import _num_samples
from .utils.validation import check_is_fitted
from .utils.validation import column_or_1d
from .utils.validation import _assert_all_finite
from .utils.multiclass import (_check_partial_fit_first_call,
check_classification_targets,
_ovr_decision_function)
Expand Down Expand Up @@ -909,8 +907,7 @@ def fit(self, X, y):
-------
self
"""
y = column_or_1d(y, warn=True)
_assert_all_finite(y)
y = self._validate_data(X='no_validation', y=y)

if self.code_size <= 0:
raise ValueError("code_size should be greater than 0, got {0}"
Expand Down
7 changes: 7 additions & 0 deletions sklearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
_num_samples,
check_scalar,
_check_psd_eigenvalues,
_check_y,
_deprecate_positional_args,
_check_sample_weight,
_allclose_dense_sparse,
Expand Down Expand Up @@ -679,6 +680,12 @@ def test_check_array_complex_data_error():
with pytest.raises(ValueError, match="Complex data not supported"):
check_array(X)

# target variable does not always go through check_array but should
# never accept complex data either.
y = np.array([1 + 2j, 3 + 4j, 5 + 7j, 2 + 3j, 4 + 5j, 6 + 7j])
with pytest.raises(ValueError, match="Complex data not supported"):
_check_y(y)


def test_has_fit_parameter():
assert not has_fit_parameter(KNeighborsClassifier, "sample_weight")
Expand Down
15 changes: 12 additions & 3 deletions sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,18 +880,27 @@ def check_X_y(X, y, accept_sparse=False, *, accept_large_sparse=True,
ensure_min_samples=ensure_min_samples,
ensure_min_features=ensure_min_features,
estimator=estimator)

y = _check_y(y, multi_output=multi_output, y_numeric=y_numeric)

check_consistent_length(X, y)

return X, y


def _check_y(y, multi_output=False, y_numeric=False):
    """Validate the target `y` on its own, without touching `X`.

    Isolated part of `check_X_y` dedicated to `y` validation, so that
    callers that do not validate `X` can still validate `y` the same way.

    Parameters
    ----------
    y : array-like
        The targets to validate.

    multi_output : bool, default=False
        If True, `y` is validated with `check_array`, called with
        `accept_sparse='csr'`, `force_all_finite=True`, `ensure_2d=False`
        and `dtype=None`. If False, `y` is passed through
        `column_or_1d` (with `warn=True`) and then checked with
        `_assert_all_finite` and `_ensure_no_complex_data`.

    y_numeric : bool, default=False
        If True and the validated `y` has object dtype, it is cast to
        `np.float64`.

    Returns
    -------
    y : ndarray or sparse matrix
        The validated target. Only `y` is returned; checking that `X` and
        `y` have consistent lengths is left to the caller, since this
        helper never sees `X`.
    """
    if multi_output:
        y = check_array(y, accept_sparse='csr', force_all_finite=True,
                        ensure_2d=False, dtype=None)
    else:
        y = column_or_1d(y, warn=True)
        _assert_all_finite(y)
        # The 1d path does not go through check_array, so complex data has
        # to be rejected explicitly here.
        _ensure_no_complex_data(y)

    if y_numeric and y.dtype.kind == 'O':
        # Object-dtype targets are cast to float when a numeric y is
        # requested.
        y = y.astype(np.float64)

    return y


def column_or_1d(y, *, warn=False):
Expand Down