Skip to content

Commit 13de540

Browse files
authored
TST fix estimator checks when set_output is called on the instance (scikit-learn#29869)
1 parent c3d4e57 commit 13de540

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

doc/whats_new/v1.6.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,11 @@ Changelog
367367
calling :func:`utils.validation.check_non_negative`.
368368
:pr:`29540` by :user:`Tamara Atanasoska <tamaraatanasoska>`.
369369

370+
- |FIX| :func:`utils.estimator_checks.parametrize_with_checks` and
371+
:func:`utils.estimator_checks.check_estimator` now support estimators that
372+
have `set_output` called on them.
373+
:pr:`29869` by `Adrin Jalali`_.
374+
370375
- |Enhancement| :func:`utils.validation.check_is_fitted` now passes on stateless
371376
estimators. An estimator can indicate it's stateless by setting the `requires_fit`
372377
tag. See :ref:`estimator_tags` for more information.

sklearn/utils/estimator_checks.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from ..model_selection._validation import _safe_split
4141
from ..pipeline import make_pipeline
4242
from ..preprocessing import StandardScaler, scale
43+
from ..utils import _safe_indexing
4344
from ..utils._array_api import (
4445
_atol_for_type,
4546
_convert_to_numpy,
@@ -1458,8 +1459,8 @@ def check_methods_sample_order_invariance(name, estimator_orig):
14581459

14591460
if hasattr(estimator, method):
14601461
assert_allclose_dense_sparse(
1461-
getattr(estimator, method)(X)[idx],
1462-
getattr(estimator, method)(X[idx]),
1462+
_safe_indexing(getattr(estimator, method)(X), idx),
1463+
getattr(estimator, method)(_safe_indexing(X, idx)),
14631464
atol=1e-9,
14641465
err_msg=msg,
14651466
)
@@ -1804,6 +1805,9 @@ def check_estimators_dtypes(name, estimator_orig):
18041805
def check_transformer_preserve_dtypes(name, transformer_orig):
18051806
# check that dtype are preserved meaning if input X is of some dtype
18061807
# X_transformed should be from the same dtype.
1808+
transformer = clone(transformer_orig)
1809+
if hasattr(transformer, "set_output"):
1810+
transformer.set_output(transform="default")
18071811
X, y = make_blobs(
18081812
n_samples=30,
18091813
centers=[[0, 0, 0], [1, 1, 1]],
@@ -1815,7 +1819,6 @@ def check_transformer_preserve_dtypes(name, transformer_orig):
18151819

18161820
for dtype in get_tags(transformer_orig).transformer_tags.preserves_dtype:
18171821
X_cast = X.astype(dtype)
1818-
transformer = clone(transformer_orig)
18191822
set_random_state(transformer)
18201823
X_trans1 = transformer.fit_transform(X_cast, y)
18211824
X_trans2 = transformer.fit(X_cast, y).transform(X_cast)
@@ -3777,7 +3780,9 @@ def check_fit_idempotent(name, estimator_orig):
37773780
for method in check_methods:
37783781
if hasattr(estimator, method):
37793782
new_result = getattr(estimator, method)(X_test)
3780-
if np.issubdtype(new_result.dtype, np.floating):
3783+
if hasattr(new_result, "dtype") and np.issubdtype(
3784+
new_result.dtype, np.floating
3785+
):
37813786
tol = 2 * np.finfo(new_result.dtype).eps
37823787
else:
37833788
tol = 2 * np.finfo(np.float64).eps

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
)
2828
from sklearn.mixture import GaussianMixture
2929
from sklearn.neighbors import KNeighborsRegressor
30+
from sklearn.preprocessing import StandardScaler
3031
from sklearn.svm import SVC, NuSVC
3132
from sklearn.utils import _array_api, all_estimators, deprecated
3233
from sklearn.utils._param_validation import Interval, StrOptions
@@ -1417,3 +1418,16 @@ def _more_tags(self):
14171418
# to exist so that third party estimators can easily support multiple sklearn
14181419
# versions.
14191420
check_estimator_tags_renamed("OkayEstimator", OkayEstimator())
1421+
1422+
1423+
# Test that set_output doesn't make the tests to fail.
1424+
def test_estimator_with_set_output():
1425+
# Doing this since pytest is not available for this file.
1426+
for lib in ["pandas", "polars"]:
1427+
try:
1428+
importlib.__import__(lib)
1429+
except ImportError:
1430+
raise SkipTest(f"Library {lib} is not installed")
1431+
1432+
estimator = StandardScaler().set_output(transform=lib)
1433+
check_estimator(estimator)

0 commit comments

Comments
 (0)