Skip to content

Commit 96a96f1

Browse files
authored
ENH Adds n_features_in_ checking in cross_decomposition (#18741)
1 parent 71f7085 commit 96a96f1

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

sklearn/cross_decomposition/_pls.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def transform(self, X, Y=None, copy=True):
317317
`x_scores` if `Y` is not given, `(x_scores, y_scores)` otherwise.
318318
"""
319319
check_is_fitted(self)
320-
X = check_array(X, copy=copy, dtype=FLOAT_DTYPES)
320+
X = self._validate_data(X, copy=copy, dtype=FLOAT_DTYPES, reset=False)
321321
# Normalize
322322
X -= self._x_mean
323323
X /= self._x_std
@@ -379,7 +379,7 @@ def predict(self, X, copy=True):
379379
space.
380380
"""
381381
check_is_fitted(self)
382-
X = check_array(X, copy=copy, dtype=FLOAT_DTYPES)
382+
X = self._validate_data(X, copy=copy, dtype=FLOAT_DTYPES, reset=False)
383383
# Normalize
384384
X -= self._x_mean
385385
X /= self._x_std
@@ -984,7 +984,7 @@ def transform(self, X, Y=None):
984984
`(X_transformed, Y_transformed)` otherwise.
985985
"""
986986
check_is_fitted(self)
987-
X = check_array(X, dtype=np.float64)
987+
X = self._validate_data(X, dtype=np.float64, reset=False)
988988
Xr = (X - self._x_mean) / self._x_std
989989
x_scores = np.dot(Xr, self.x_weights_)
990990
if Y is not None:

sklearn/tests/test_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,6 @@ def test_search_cv(estimator, check, request):
267267
'calibration',
268268
'compose',
269269
'covariance',
270-
'cross_decomposition',
271270
'discriminant_analysis',
272271
'ensemble',
273272
'feature_extraction',

sklearn/utils/estimator_checks.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
load_iris,
6464
make_blobs,
6565
make_multilabel_classification,
66-
make_regression,
66+
make_regression
6767
)
6868

6969
REGRESSION_DATASET = None
@@ -646,6 +646,9 @@ def _set_checking_parameters(estimator):
646646
if name == 'OneHotEncoder':
647647
estimator.set_params(handle_unknown='ignore')
648648

649+
if name in CROSS_DECOMPOSITION:
650+
estimator.set_params(n_components=1)
651+
649652

650653
class _NotAnArray:
651654
"""An object that is convertible to an array.
@@ -3122,9 +3125,11 @@ def check_n_features_in_after_fitting(name, estimator_orig):
31223125
if 'warm_start' in estimator.get_params():
31233126
estimator.set_params(warm_start=False)
31243127

3125-
n_samples = 100
3126-
X = rng.normal(loc=100, size=(n_samples, 2))
3128+
n_samples = 150
3129+
X = rng.normal(size=(n_samples, 8))
3130+
X = _enforce_estimator_tags_x(estimator, X)
31273131
X = _pairwise_estimator_convert_X(X, estimator)
3132+
31283133
if is_regressor(estimator):
31293134
y = rng.normal(size=n_samples)
31303135
else:

0 commit comments

Comments
 (0)