Skip to content

Commit 74a3375

Browse files
TST move check_n_features_in_after_fitting to common tests (scikit-learn#29844)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent 13de540 commit 74a3375

File tree

6 files changed

+105
-28
lines changed

6 files changed

+105
-28
lines changed

sklearn/neighbors/_classification.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,14 @@ def predict_proba(self, X):
411411
def __sklearn_tags__(self):
412412
tags = super().__sklearn_tags__()
413413
tags.classifier_tags.multi_label = True
414+
tags.input_tags.pairwise = self.metric == "precomputed"
415+
if tags.input_tags.pairwise:
416+
tags._xfail_checks.update(
417+
{
418+
"check_n_features_in_after_fitting": "FIXME",
419+
"check_dataframe_column_names_consistency": "FIXME",
420+
}
421+
)
414422
return tags
415423

416424

sklearn/neighbors/_regression.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,13 @@ def __sklearn_tags__(self):
195195
tags = super().__sklearn_tags__()
196196
# For cross-validation routines to split data correctly
197197
tags.input_tags.pairwise = self.metric == "precomputed"
198+
if tags.input_tags.pairwise:
199+
tags._xfail_checks.update(
200+
{
201+
"check_n_features_in_after_fitting": "FIXME",
202+
"check_dataframe_column_names_consistency": "FIXME",
203+
}
204+
)
198205
return tags
199206

200207
@_fit_context(

sklearn/tests/test_common.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
check_global_output_transform_pandas,
5252
check_global_set_output_transform_polars,
5353
check_inplace_ensure_writeable,
54-
check_n_features_in_after_fitting,
5554
check_param_validation,
5655
check_set_output_transform,
5756
check_set_output_transform_pandas,
@@ -243,13 +242,6 @@ def check_field_types(tags, defaults):
243242
check_field_types(tags.transformer_tags, defaults.transformer_tags)
244243

245244

246-
@pytest.mark.parametrize(
247-
"estimator", _tested_estimators(), ids=_get_check_estimator_ids
248-
)
249-
def test_check_n_features_in_after_fitting(estimator):
250-
check_n_features_in_after_fitting(estimator.__class__.__name__, estimator)
251-
252-
253245
def _estimators_that_predict_in_fit():
254246
for estimator in _tested_estimators():
255247
est_params = set(estimator.get_params())
@@ -286,6 +278,11 @@ def _estimators_that_predict_in_fit():
286278
def test_pandas_column_name_consistency(estimator):
287279
if isinstance(estimator, ColumnTransformer):
288280
pytest.skip("ColumnTransformer is not tested here")
281+
tags = get_tags(estimator)
282+
if "check_dataframe_column_names_consistency" in tags._xfail_checks:
283+
pytest.skip(
284+
"Estimator does not support check_dataframe_column_names_consistency"
285+
)
289286
with ignore_warnings(category=(FutureWarning)):
290287
with warnings.catch_warnings(record=True) as record:
291288
check_dataframe_column_names_consistency(

sklearn/utils/_test_common/instance_generator.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,11 @@
140140
MultiOutputRegressor,
141141
RegressorChain,
142142
)
143-
from sklearn.neighbors import NeighborhoodComponentsAnalysis
143+
from sklearn.neighbors import (
144+
KNeighborsClassifier,
145+
KNeighborsRegressor,
146+
NeighborhoodComponentsAnalysis,
147+
)
144148
from sklearn.neural_network import BernoulliRBM, MLPClassifier, MLPRegressor
145149
from sklearn.pipeline import FeatureUnion, Pipeline
146150
from sklearn.preprocessing import OneHotEncoder, StandardScaler, TargetEncoder
@@ -308,6 +312,8 @@
308312
IncrementalPCA: dict(batch_size=10),
309313
IsolationForest: dict(n_estimators=5),
310314
KMeans: dict(n_init=2, n_clusters=2, max_iter=5),
315+
KNeighborsClassifier: [dict(n_neighbors=2), dict(metric="precomputed")],
316+
KNeighborsRegressor: [dict(n_neighbors=2), dict(metric="precomputed")],
311317
LabelPropagation: dict(max_iter=5),
312318
LabelSpreading: dict(max_iter=5),
313319
LarsCV: dict(max_iter=5, cv=3),
@@ -448,8 +454,8 @@
448454
],
449455
cv=3,
450456
),
451-
SVC: dict(max_iter=-1),
452-
SVR: dict(max_iter=-1),
457+
SVC: [dict(max_iter=-1), dict(kernel="precomputed")],
458+
SVR: [dict(max_iter=-1), dict(kernel="precomputed")],
453459
TargetEncoder: dict(cv=3),
454460
TheilSenRegressor: dict(max_iter=5, max_subpopulation=100),
455461
# TruncatedSVD doesn't run with n_components = n_features

sklearn/utils/estimator_checks.py

Lines changed: 71 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import pickle
77
import re
8+
import textwrap
89
import warnings
910
from contextlib import nullcontext
1011
from copy import deepcopy
@@ -95,6 +96,7 @@ def _yield_api_checks(estimator):
9596
if tags.requires_fit:
9697
yield check_estimators_unfitted
9798
yield check_do_not_raise_errors_in_init_or_set_params
99+
yield check_n_features_in_after_fitting
98100

99101

100102
def _yield_checks(estimator):
@@ -441,7 +443,7 @@ def _should_be_skipped_or_marked(estimator, check):
441443
return False, "placeholder reason that will never be used"
442444

443445

444-
def parametrize_with_checks(estimators, *, legacy=True):
446+
def parametrize_with_checks(estimators, *, legacy: bool = True):
445447
"""Pytest specific decorator for parametrizing estimator checks.
446448
447449
Checks are categorised into the following groups:
@@ -468,6 +470,7 @@ def parametrize_with_checks(estimators, *, legacy=True):
468470
469471
.. versionadded:: 0.24
470472
473+
471474
legacy : bool, default=True
472475
Whether to include legacy checks. Over time we remove checks from this category
473476
and move them into their specific category.
@@ -520,7 +523,7 @@ def checks_generator():
520523
)
521524

522525

523-
def check_estimator(estimator=None, generate_only=False, *, legacy=True):
526+
def check_estimator(estimator=None, generate_only=False, *, legacy: bool = True):
524527
"""Check if estimator adheres to scikit-learn conventions.
525528
526529
This function will run an extensive test-suite for input validation,
@@ -2009,13 +2012,14 @@ def check_estimators_partial_fit_n_features(name, estimator_orig):
20092012

20102013

20112014
@ignore_warnings(category=FutureWarning)
2012-
def check_classifier_multioutput(name, estimator):
2015+
def check_classifier_multioutput(name, estimator_orig):
20132016
n_samples, n_labels, n_classes = 42, 5, 3
2014-
tags = get_tags(estimator)
2015-
estimator = clone(estimator)
2017+
tags = get_tags(estimator_orig)
2018+
estimator = clone(estimator_orig)
20162019
X, y = make_multilabel_classification(
20172020
random_state=42, n_samples=n_samples, n_labels=n_labels, n_classes=n_classes
20182021
)
2022+
X = _enforce_estimator_tags_X(estimator, X)
20192023
estimator.fit(X, y)
20202024
y_pred = estimator.predict(X)
20212025

@@ -2174,13 +2178,14 @@ def check_clusterer_compute_labels_predict(name, clusterer_orig):
21742178
def check_classifiers_one_label(name, classifier_orig):
21752179
error_string_fit = "Classifier can't train when only one class is present."
21762180
error_string_predict = "Classifier can't predict when only one class is present."
2181+
classifier = clone(classifier_orig)
21772182
rnd = np.random.RandomState(0)
21782183
X_train = rnd.uniform(size=(10, 3))
21792184
X_test = rnd.uniform(size=(10, 3))
2185+
X_train, X_test = _enforce_estimator_tags_X(classifier, X_train, X_test=X_test)
21802186
y = np.ones(10)
21812187
# catch deprecation warnings
21822188
with ignore_warnings(category=FutureWarning):
2183-
classifier = clone(classifier_orig)
21842189
with raises(
21852190
ValueError, match="class", may_pass=True, err_msg=error_string_fit
21862191
) as cm:
@@ -2505,6 +2510,7 @@ def check_classifiers_multilabel_representation_invariance(name, classifier_orig
25052510

25062511
X_train, y_train = X[:80], y[:80]
25072512
X_test = X[80:]
2513+
X_train, X_test = _enforce_estimator_tags_X(classifier_orig, X_train, X_test=X_test)
25082514

25092515
y_train_list_of_lists = y_train.tolist()
25102516
y_train_list_of_arrays = list(y_train)
@@ -2552,6 +2558,7 @@ def check_classifiers_multilabel_output_format_predict(name, classifier_orig):
25522558

25532559
X_train, X_test = X[:-test_size], X[-test_size:]
25542560
y_train, y_test = y[:-test_size], y[-test_size:]
2561+
X_train, X_test = _enforce_estimator_tags_X(classifier_orig, X_train, X_test=X_test)
25552562
classifier.fit(X_train, y_train)
25562563

25572564
response_method_name = "predict"
@@ -2597,6 +2604,7 @@ def check_classifiers_multilabel_output_format_predict_proba(name, classifier_or
25972604

25982605
X_train, X_test = X[:-test_size], X[-test_size:]
25992606
y_train = y[:-test_size]
2607+
X_train, X_test = _enforce_estimator_tags_X(classifier_orig, X_train, X_test=X_test)
26002608
classifier.fit(X_train, y_train)
26012609

26022610
response_method_name = "predict_proba"
@@ -2681,6 +2689,7 @@ def check_classifiers_multilabel_output_format_decision_function(name, classifie
26812689

26822690
X_train, X_test = X[:-test_size], X[-test_size:]
26832691
y_train = y[:-test_size]
2692+
X_train, X_test = _enforce_estimator_tags_X(classifier_orig, X_train, X_test=X_test)
26842693
classifier.fit(X_train, y_train)
26852694

26862695
response_method_name = "decision_function"
@@ -3474,30 +3483,48 @@ def _enforce_estimator_tags_y(estimator, y):
34743483
return y
34753484

34763485

3477-
def _enforce_estimator_tags_X(estimator, X, kernel=linear_kernel):
3486+
def _enforce_estimator_tags_X(estimator, X, X_test=None, kernel=linear_kernel):
34783487
# Estimators with `1darray` in `X_types` tag only accept
34793488
# X of shape (`n_samples`,)
34803489
if get_tags(estimator).input_tags.one_d_array:
34813490
X = X[:, 0]
3491+
if X_test is not None:
3492+
X_test = X_test[:, 0] # pragma: no cover
34823493
# Estimators with a `requires_positive_X` tag only accept
34833494
# strictly positive data
34843495
if get_tags(estimator).input_tags.positive_only:
34853496
X = X - X.min()
3497+
if X_test is not None:
3498+
X_test = X_test - X_test.min() # pragma: no cover
34863499
if get_tags(estimator).input_tags.categorical:
34873500
dtype = np.float64 if get_tags(estimator).input_tags.allow_nan else np.int32
34883501
X = np.round((X - X.min())).astype(dtype)
3502+
if X_test is not None:
3503+
X_test = np.round((X_test - X_test.min())).astype(dtype) # pragma: no cover
34893504

34903505
if estimator.__class__.__name__ == "SkewedChi2Sampler":
34913506
# SkewedChi2Sampler requires X > -skewdness in transform
34923507
X = X - X.min()
3508+
if X_test is not None:
3509+
X_test = X_test - X_test.min() # pragma: no cover
3510+
3511+
X_res = X
34933512

34943513
# Pairwise estimators only accept
34953514
# X of shape (`n_samples`, `n_samples`)
34963515
if _is_pairwise_metric(estimator):
3497-
X = pairwise_distances(X, metric="euclidean")
3516+
X_res = pairwise_distances(X, metric="euclidean")
3517+
if X_test is not None:
3518+
X_test = pairwise_distances(
3519+
X_test, X, metric="euclidean"
3520+
) # pragma: no cover
34983521
elif get_tags(estimator).input_tags.pairwise:
3499-
X = kernel(X, X)
3500-
return X
3522+
X_res = kernel(X, X)
3523+
if X_test is not None:
3524+
X_test = kernel(X_test, X) # pragma: no cover
3525+
if X_test is not None:
3526+
return X_res, X_test
3527+
return X_res
35013528

35023529

35033530
@ignore_warnings(category=FutureWarning)
@@ -3913,8 +3940,16 @@ def check_n_features_in_after_fitting(name, estimator_orig):
39133940
y = rng.randint(low=0, high=2, size=n_samples)
39143941
y = _enforce_estimator_tags_y(estimator, y)
39153942

3943+
err_msg = (
3944+
"`{name}.fit()` does not set the `n_features_in_` attribute. "
3945+
"You might want to use `sklearn.utils.validation.validate_data` instead "
3946+
"of `check_array` in `{name}.fit()` which takes care of setting the "
3947+
"attribute.".format(name=name)
3948+
)
3949+
39163950
estimator.fit(X, y)
3917-
assert estimator.n_features_in_ == X.shape[1]
3951+
assert hasattr(estimator, "n_features_in_"), err_msg
3952+
assert estimator.n_features_in_ == X.shape[1], err_msg
39183953

39193954
# check methods will check n_features_in_
39203955
check_methods = [
@@ -3926,6 +3961,28 @@ def check_n_features_in_after_fitting(name, estimator_orig):
39263961
]
39273962
X_bad = X[:, [1]]
39283963

3964+
err_msg = """\
3965+
`{name}.{method}()` does not check for consistency between input number
3966+
of features with {name}.fit(), via the `n_features_in_` attribute.
3967+
You might want to use `sklearn.utils.validation.validate_data` instead
3968+
of `check_array` in `{name}.fit()` and {name}.{method}()`. This can be done
3969+
like the following:
3970+
from sklearn.utils.validation import validate_data
3971+
...
3972+
class MyEstimator(BaseEstimator):
3973+
...
3974+
def fit(self, X, y):
3975+
X, y = validate_data(self, X, y, ...)
3976+
...
3977+
return self
3978+
...
3979+
def {method}(self, X):
3980+
X = validate_data(self, X, ..., reset=False)
3981+
...
3982+
return X
3983+
"""
3984+
err_msg = textwrap.dedent(err_msg)
3985+
39293986
msg = f"X has 1 features, but \\w+ is expecting {X.shape[1]} features as input"
39303987
for method in check_methods:
39313988
if not hasattr(estimator, method):
@@ -3935,7 +3992,9 @@ def check_n_features_in_after_fitting(name, estimator_orig):
39353992
if method == "score":
39363993
callable_method = partial(callable_method, y=y)
39373994

3938-
with raises(ValueError, match=msg):
3995+
with raises(
3996+
ValueError, match=msg, err_msg=err_msg.format(name=name, method=method)
3997+
):
39393998
callable_method(X_bad)
39403999

39414000
# partial_fit will check in the second call

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,8 @@ def fit(self, X, y=None):
304304
return self
305305

306306
def transform(self, X):
307-
X = check_array(X)
307+
check_is_fitted(self)
308+
X = validate_data(self, X, reset=False)
308309
return X
309310

310311

@@ -422,16 +423,15 @@ def __init__(self, sparse_container=None):
422423
self.sparse_container = sparse_container
423424

424425
def fit(self, X, y=None):
425-
self.X_shape_ = validate_data(self, X).shape
426+
validate_data(self, X)
426427
return self
427428

428429
def fit_transform(self, X, y=None):
429430
return self.fit(X, y).transform(X)
430431

431432
def transform(self, X):
432-
X = check_array(X)
433-
if X.shape[1] != self.X_shape_[1]:
434-
raise ValueError("Bad number of features")
433+
check_is_fitted(self)
434+
X = validate_data(self, X, accept_sparse=True, reset=False)
435435
return self.sparse_container(X)
436436

437437

0 commit comments

Comments
 (0)