27 changes: 23 additions & 4 deletions sklearn/base.py
@@ -360,6 +360,10 @@ def _check_n_features(self, X, reset):
If True, the `n_features_in_` attribute is set to `X.shape[1]`.
Else, the attribute must already exist and the function checks
that it is equal to `X.shape[1]`.
.. note::
It is recommended to call reset=True in `fit` and in the first
call to `partial_fit`. All other methods that validate `X`
should set `reset=False`.
"""
n_features = X.shape[1]

@@ -378,7 +382,7 @@ def _check_n_features(self, X, reset):
self.n_features_in_)
)

def _validate_data(self, X, y=None, reset=True,
def _validate_data(self, X, y='no_validation', reset=True,
validate_separately=False, **check_params):
"""Validate input data and set or check the `n_features_in_` attribute.

@@ -387,13 +391,25 @@ def _validate_data(self, X, y=None, reset=True,
X : {array-like, sparse matrix, dataframe} of shape \
(n_samples, n_features)
The input samples.
y : array-like of shape (n_samples,), default=None
The targets. If None, `check_array` is called on `X` and
`check_X_y` is called otherwise.
y : array-like of shape (n_samples,), default='no_validation'
The targets.

- If `None`, `check_array` is called on `X`. If the estimator's
requires_y tag is True, then an error will be raised.
- If `'no_validation'`, `check_array` is called on `X` and the
estimator's requires_y tag is ignored. This is a default
placeholder and is never meant to be explicitly set.
- Otherwise, both `X` and `y` are checked with either `check_array`
or `check_X_y` depending on `validate_separately`.

reset : bool, default=True
Whether to reset the `n_features_in_` attribute.
If False, the input will be checked for consistency with data
provided when reset was last True.
.. note::
It is recommended to call reset=True in `fit` and in the first
call to `partial_fit`. All other methods that validate `X`
should set `reset=False`.
validate_separately : False or tuple of dicts, default=False
Only used if y is not None.
If False, call validate_X_y(). Else, it must be a tuple of kwargs
@@ -417,6 +433,9 @@
)
X = check_array(X, **check_params)
out = X
elif isinstance(y, str) and y == 'no_validation':
X = check_array(X, **check_params)
out = X
else:
if validate_separately:
# We need this because some estimators validate X and y
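To make the new three-way dispatch on `y` concrete, here is a minimal standalone sketch of the sentinel logic. This is not scikit-learn code: the helper name, the `requires_y` flag, and the error wording are illustrative stand-ins, and only the branch structure mirrors `_validate_data` above.

    import numpy as np

    _NO_VALIDATION = "no_validation"  # illustrative stand-in for the 'no_validation' default


    def validate_data(X, y=_NO_VALIDATION, requires_y=False):
        """Simplified three-way dispatch mirroring the diff above (sketch only)."""
        X = np.asarray(X)  # stand-in for check_array(X, **check_params)
        if y is None:
            # y=None was passed explicitly: only X is validated, and an
            # estimator whose requires_y tag is True should raise.
            if requires_y:
                raise ValueError("This estimator requires y to be passed.")
            return X
        if isinstance(y, str) and y == _NO_VALIDATION:
            # Default placeholder: the caller never passed y, so the
            # requires_y tag is ignored (e.g. transform/predict paths).
            return X
        # Otherwise both X and y are validated (check_array/check_X_y in the
        # real implementation, depending on validate_separately).
        return X, np.asarray(y)

Using a string sentinel rather than `None` lets `_validate_data` distinguish "y was never passed" from "y=None was passed on purpose", which is what allows the new requires_y error path without breaking transformers that legitimately receive `y=None`.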
21 changes: 12 additions & 9 deletions sklearn/neural_network/_multilayer_perceptron.py
@@ -22,7 +22,7 @@
from ..utils import gen_batches, check_random_state
from ..utils import shuffle
from ..utils import _safe_indexing
from ..utils import check_array, column_or_1d
from ..utils import column_or_1d
from ..exceptions import ConvergenceWarning
from ..utils.extmath import safe_sparse_dot
from ..utils.validation import check_is_fitted, _deprecate_positional_args
@@ -131,7 +131,7 @@ def _forward_pass_fast(self, X):
y_pred : ndarray of shape (n_samples,) or (n_samples, n_outputs)
The decision function of the samples for each class in the model.
"""
X = check_array(X, accept_sparse=['csr', 'csc'])
X = self._validate_data(X, accept_sparse=['csr', 'csc'], reset=False)

# Initialize first layer
activation = X
@@ -358,8 +358,10 @@ def _fit(self, X, y, incremental=False):
if np.any(np.array(hidden_layer_sizes) <= 0):
raise ValueError("hidden_layer_sizes must be > 0, got %s." %
hidden_layer_sizes)
first_pass = (not hasattr(self, 'coefs_') or
(not self.warm_start and not incremental))

X, y = self._validate_input(X, y, incremental)
X, y = self._validate_input(X, y, incremental, reset=first_pass)

n_samples, n_features = X.shape

@@ -375,8 +377,7 @@
# check random state
self._random_state = check_random_state(self.random_state)

if not hasattr(self, 'coefs_') or (not self.warm_start and not
incremental):
if first_pass:
# First time training the model
self._initialize(y, layer_units, X.dtype)

@@ -963,10 +964,11 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu", *,
beta_1=beta_1, beta_2=beta_2, epsilon=epsilon,
n_iter_no_change=n_iter_no_change, max_fun=max_fun)

def _validate_input(self, X, y, incremental):
def _validate_input(self, X, y, incremental, reset):
X, y = self._validate_data(X, y, accept_sparse=['csr', 'csc'],
multi_output=True,
dtype=(np.float64, np.float32))
dtype=(np.float64, np.float32),
reset=reset)
if y.ndim == 2 and y.shape[1] == 1:
y = column_or_1d(y, warn=True)

@@ -1409,10 +1411,11 @@ def predict(self, X):
return y_pred.ravel()
return y_pred

def _validate_input(self, X, y, incremental):
def _validate_input(self, X, y, incremental, reset):
X, y = self._validate_data(X, y, accept_sparse=['csr', 'csc'],
multi_output=True, y_numeric=True,
dtype=(np.float64, np.float32))
dtype=(np.float64, np.float32),
reset=reset)
if y.ndim == 2 and y.shape[1] == 1:
y = column_or_1d(y, warn=True)
return X, y
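The `reset=first_pass` logic above generalizes to any estimator with `partial_fit` or `warm_start`: `n_features_in_` should only be (re)set when the model is actually built from scratch, and every later validation should use `reset=False`. A minimal sketch of the pattern with a toy estimator (a hypothetical class, not part of scikit-learn; it relies on the private `_validate_data` helper changed in this PR):

    import numpy as np
    from sklearn.base import BaseEstimator


    class RunningMean(BaseEstimator):
        """Toy incremental estimator illustrating the reset=first_pass pattern."""

        def partial_fit(self, X, y=None):
            # Reset n_features_in_ only on the very first call; later calls
            # check that X still has the same number of features.
            first_pass = not hasattr(self, "mean_")
            X = self._validate_data(X, reset=first_pass)
            if first_pass:
                self.mean_ = np.zeros(X.shape[1])
                self.n_samples_seen_ = 0
            self.n_samples_seen_ += X.shape[0]
            self.mean_ += (X.sum(axis=0) - X.shape[0] * self.mean_) / self.n_samples_seen_
            return self

        def transform(self, X):
            # Methods other than fit/partial_fit validate with reset=False.
            X = self._validate_data(X, reset=False)
            return X - self.mean_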
7 changes: 5 additions & 2 deletions sklearn/neural_network/_rbm.py
@@ -131,7 +131,8 @@ def transform(self, X):
"""
check_is_fitted(self)

X = check_array(X, accept_sparse='csr', dtype=(np.float64, np.float32))
X = self._validate_data(X, accept_sparse='csr', reset=False,
dtype=(np.float64, np.float32))
return self._mean_hiddens(X)

def _mean_hiddens(self, v):
@@ -243,7 +244,9 @@ def partial_fit(self, X, y=None):
self : BernoulliRBM
The fitted model.
"""
X = check_array(X, accept_sparse='csr', dtype=np.float64)
first_pass = not hasattr(self, 'components_')
X = self._validate_data(X, accept_sparse='csr', dtype=np.float64,
reset=first_pass)
if not hasattr(self, 'random_state_'):
self.random_state_ = check_random_state(self.random_state)
if not hasattr(self, 'components_'):
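As a usage sketch (assuming this branch of scikit-learn), the effect of the RBM change is that `n_features_in_` is now set by the first `partial_fit`, and later calls with a different number of columns raise the error the new common check looks for. The message wording in the comment below is taken from the check added in estimator_checks.py and may differ slightly in practice.

    import numpy as np
    from sklearn.neural_network import BernoulliRBM

    rng = np.random.RandomState(0)
    X = rng.uniform(size=(50, 3))

    rbm = BernoulliRBM(n_components=2, random_state=0)
    rbm.partial_fit(X)                 # first call: n_features_in_ is set to 3
    assert rbm.n_features_in_ == 3

    try:
        rbm.transform(X[:, :2])        # wrong number of features
    except ValueError as exc:
        # Expected something like:
        # "X has 2 features, but BernoulliRBM is expecting 3 features as input"
        print(exc)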
1 change: 1 addition & 0 deletions sklearn/neural_network/tests/test_mlp.py
@@ -94,6 +94,7 @@ def test_fit():
mlp.intercepts_[1] = np.array([1.0])
mlp._coef_grads = [] * 2
mlp._intercept_grads = [] * 2
mlp.n_features_in_ = 3

# Initialize parameters
mlp.n_iter_ = 0
57 changes: 56 additions & 1 deletion sklearn/tests/test_common.py
@@ -37,7 +37,8 @@
_set_checking_parameters,
_get_check_estimator_ids,
check_class_weight_balanced_linear_classifier,
parametrize_with_checks)
parametrize_with_checks,
check_n_features_in_after_fitting)


def test_all_estimator_no_base_class():
@@ -270,3 +271,57 @@ def test_strict_mode_check_estimator():
def test_strict_mode_parametrize_with_checks(estimator, check):
# Ideally we should assert that the strict checks are Xfailed...
check(estimator)


# TODO: When more modules get added, we can remove it from this list to make
# sure it gets tested. After we finish each module we can move the checks
# into sklearn.utils.estimator_checks.check_n_features_in.
#
# check_estimators_partial_fit_n_features can either be removed or updated
# with two more assertions:
# 1. `n_features_in_` is set during the first call to `partial_fit`.
# 2. More strict when it comes to the error message.
#
# check_classifiers_train would need to be updated with the error message
N_FEATURES_IN_AFTER_FIT_MODULES_TO_IGNORE = {
'calibration',
'cluster',
'compose',
'covariance',
'cross_decomposition',
'decomposition',
'discriminant_analysis',
'ensemble',
'feature_extraction',
'feature_selection',
'gaussian_process',
'impute',
'isotonic',
'kernel_approximation',
'kernel_ridge',
'linear_model',
'manifold',
'mixture',
'model_selection',
'multiclass',
'multioutput',
'naive_bayes',
'neighbors',
'pipeline',
'preprocessing',
'random_projection',
'semi_supervised',
'svm',
'tree',
}

N_FEATURES_IN_AFTER_FIT_ESTIMATORS = [
est for est in _tested_estimators() if est.__module__.split('.')[1] not in
N_FEATURES_IN_AFTER_FIT_MODULES_TO_IGNORE
]


@pytest.mark.parametrize("estimator", N_FEATURES_IN_AFTER_FIT_ESTIMATORS,
ids=_get_check_estimator_ids)
def test_check_n_features_in_after_fitting(estimator):
check_n_features_in_after_fitting(estimator.__class__.__name__, estimator)
54 changes: 54 additions & 0 deletions sklearn/utils/estimator_checks.py
@@ -3111,6 +3111,60 @@ def check_requires_y_none(name, estimator_orig, strict_mode=True):
warnings.warn(warning_msg, FutureWarning)


def check_n_features_in_after_fitting(name, estimator_orig, strict_mode=True):
Review comment (Member):

    Once all modules are supported, should this be merged with the
    already-existing check_n_features_in check?

    We also have another check that specifically checks for an error when the
    number of features is inconsistent (I don't remember the name). Should that
    one be removed then? (If yes, let's document it here and next to
    N_FEATURES_IN_AFTER_FIT_MODULES_TO_IGNORE please.)

Reply (Member Author):

    Yeah, we should merge it into check_n_features_in.

    > We also have another check that specifically checks for an error when the
    > number of features is inconsistent. Should that one be removed then?

    I think you are referring to check_estimators_partial_fit_n_features. This
    new check adds two new requirements on top of
    check_estimators_partial_fit_n_features:

    1. n_features_in_ is set during the first call to partial_fit.
    2. It is stricter about the error message.

    I updated the comment with the above message.

Reply (Member):

    I was more referring to e.g. check_classifiers_train:

        msg = ("The classifier {} does not raise an error when the number of "
               "features in {} is different from the number of features in "
               "fit.")

    but we can keep it as-is and remove later (or not, as long as it passes).

# Make sure that n_features_in_ is checked after fitting
tags = estimator_orig._get_tags()

if "2darray" not in tags["X_types"] or tags["no_validation"]:
return

rng = np.random.RandomState(0)

estimator = clone(estimator_orig)
set_random_state(estimator)
if 'warm_start' in estimator.get_params():
estimator.set_params(warm_start=False)

n_samples = 100
X = rng.normal(loc=100, size=(n_samples, 2))
X = _pairwise_estimator_convert_X(X, estimator)
if is_regressor(estimator):
y = rng.normal(size=n_samples)
else:
y = rng.randint(low=0, high=2, size=n_samples)
y = _enforce_estimator_tags_y(estimator, y)

estimator.fit(X, y)
assert estimator.n_features_in_ == X.shape[1]

# check methods will check n_features_in_
check_methods = ["predict", "transform", "decision_function",
"predict_proba"]
X_bad = X[:, [1]]

msg = (f"X has 1 features, but {name} is expecting {X.shape[1]} "
"features as input")
for method in check_methods:
if not hasattr(estimator, method):
continue
with raises(ValueError, match=msg):
getattr(estimator, method)(X_bad)

# partial_fit will check in the second call
if not hasattr(estimator, "partial_fit"):
return

estimator = clone(estimator_orig)
if is_classifier(estimator):
estimator.partial_fit(X, y, classes=np.unique(y))
else:
estimator.partial_fit(X, y)
assert estimator.n_features_in_ == X.shape[1]

with raises(ValueError, match=msg):
estimator.partial_fit(X_bad, y)


# set of checks that are completely strict, i.e. they have no non-strict part
_FULLY_STRICT_CHECKS = set([
'check_n_features_in',
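For reference, the new check can also be run directly against a single estimator, mirroring the parametrized test added in test_common.py (a sketch assuming this branch; max_iter is lowered only to keep the run short):

    from sklearn.neural_network import MLPClassifier
    from sklearn.utils.estimator_checks import check_n_features_in_after_fitting

    # Fits on a small 2-feature dataset, then asserts that n_features_in_ is set
    # and that predict/predict_proba/partial_fit reject X with a different
    # number of features.
    check_n_features_in_after_fitting("MLPClassifier", MLPClassifier(max_iter=20))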