
Regression in SelectorMixin in 1.6.0rc1 #30324

Closed
@bmreiniger

Description


Describe the bug

Setting the estimator tag allow_nan has no effect on estimators built on SelectorMixin in the 1.6.0rc1 release candidate: input containing NaN is still rejected.

From a first skim, could ensure_all_finite be expected to be False in some places and "allow-nan" in others, inconsistently? See in particular https://github.com/scikit-learn/scikit-learn/blame/439ea045ad44e6a09115dc23e9bf23db00ff41de/sklearn/utils/validation.py#L1110
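To make the question concrete, here is a simplified re-implementation of how check_array appears to dispatch on ensure_all_finite. It is based on the check_array frame in the traceback (the element-wise check only runs when the value is truthy) and assumes the real code passes allow_nan=(ensure_all_finite == "allow-nan") to _assert_all_finite. The function name check_finite is hypothetical and made up for illustration; it is not scikit-learn code.

```python
import numpy as np


def check_finite(X, ensure_all_finite=True):
    """Hypothetical, simplified mirror of check_array's finiteness check.

    ensure_all_finite=True        -> reject NaN and +/-inf
    ensure_all_finite="allow-nan" -> reject +/-inf, accept NaN
    ensure_all_finite=False       -> skip the check entirely
    """
    X = np.asarray(X, dtype=float)
    if ensure_all_finite:  # True and "allow-nan" are both truthy
        allow_nan = ensure_all_finite == "allow-nan"
        if not allow_nan and np.isnan(X).any():
            raise ValueError("Input X contains NaN.")
        if np.isinf(X).any():
            raise ValueError("Input X contains infinity.")
    return X


X = np.array([5, 7, np.nan, 9]).reshape(2, 2)

check_finite(X, ensure_all_finite=False)        # passes: check skipped
check_finite(X, ensure_all_finite="allow-nan")  # passes: NaN tolerated
try:
    check_finite(X, ensure_all_finite=True)
except ValueError as exc:
    print(exc)  # prints: Input X contains NaN.
```

Under that assumption, False and "allow-nan" both let NaN through and only True raises, so the error in the traceback suggests the check ran with ensure_all_finite=True.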

Steps/Code to Reproduce

from sklearn.feature_selection import SelectorMixin
from sklearn.base import BaseEstimator
import numpy as np

class MyEstimator(SelectorMixin, BaseEstimator):
    def __init__(self):
        pass
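    def fit(self, X, y=None):
        # _get_support_mask relies on n_features_in_, so record it here
        self.n_features_in_ = np.asarray(X).shape[1]
        return self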
    def _get_support_mask(self):
        mask = np.ones(self.n_features_in_, dtype=bool)
        return mask
    def _more_tags(self):
        return {'allow_nan': True}

my_est = MyEstimator()
my_est.fit_transform(np.array([5, 7, np.nan, 9]).reshape(2, 2))

Expected Results

No error is thrown, and the numpy array is returned unchanged.

Actual Results

ValueError                                Traceback (most recent call last)
<ipython-input-2-d8e360602655> in <cell line: 20>()
     18 
     19 my_est = MyEstimator()
---> 20 my_est.fit_transform(np.array([5, 7, np.nan, 9]).reshape(2, 2))

/usr/local/lib/python3.10/dist-packages/sklearn/utils/_set_output.py in wrapped(self, X, *args, **kwargs)
    317     @wraps(f)
    318     def wrapped(self, X, *args, **kwargs):
--> 319         data_to_wrap = f(self, X, *args, **kwargs)
    320         if isinstance(data_to_wrap, tuple):
    321             # only wrap the first output for cross decomposition

/usr/local/lib/python3.10/dist-packages/sklearn/base.py in fit_transform(self, X, y, **fit_params)
    857         if y is None:
    858             # fit method of arity 1 (unsupervised transformation)
--> 859             return self.fit(X, **fit_params).transform(X)
    860         else:
    861             # fit method of arity 2 (supervised transformation)

/usr/local/lib/python3.10/dist-packages/sklearn/utils/_set_output.py in wrapped(self, X, *args, **kwargs)
    317     @wraps(f)
    318     def wrapped(self, X, *args, **kwargs):
--> 319         data_to_wrap = f(self, X, *args, **kwargs)
    320         if isinstance(data_to_wrap, tuple):
    321             # only wrap the first output for cross decomposition

/usr/local/lib/python3.10/dist-packages/sklearn/feature_selection/_base.py in transform(self, X)
    105         # note: we use get_tags instead of __sklearn_tags__ because this is a
    106         # public Mixin.
--> 107         X = validate_data(
    108             self,
    109             X,

/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py in validate_data(_estimator, X, y, reset, validate_separately, skip_check_array, **check_params)
   2931             out = X, y
   2932     elif not no_val_X and no_val_y:
-> 2933         out = check_array(X, input_name="X", **check_params)
   2934     elif no_val_X and not no_val_y:
   2935         out = _check_y(y, **check_params)

/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_writeable, force_all_finite, ensure_all_finite, ensure_non_negative, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)
   1104 
   1105         if ensure_all_finite:
-> 1106             _assert_all_finite(
   1107                 array,
   1108                 input_name=input_name,

/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py in _assert_all_finite(X, allow_nan, msg_dtype, estimator_name, input_name)
    118         return
    119 
--> 120     _assert_all_finite_element_wise(
    121         X,
    122         xp=xp,

/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py in _assert_all_finite_element_wise(X, xp, allow_nan, msg_dtype, estimator_name, input_name)
    167                 "#estimators-that-handle-nan-values"
    168             )
--> 169         raise ValueError(msg_err)
    170 
    171 

ValueError: Input X contains NaN.
MyEstimator does not accept missing values encoded as NaN natively. For supervised learning, you might want to consider sklearn.ensemble.HistGradientBoostingClassifier and Regressor which accept missing values encoded as NaNs natively. Alternatively, it is possible to preprocess the data, for instance by using an imputer transformer in a pipeline or drop samples with missing values. See https://scikit-learn.org/stable/modules/impute.html You can find a list of all estimators that handle NaN values at the following page: https://scikit-learn.org/stable/modules/impute.html#estimators-that-handle-nan-values
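A possible workaround, sketched under the assumption of scikit-learn 1.6 or later, is to declare the tag through the __sklearn_tags__ hook, which returns a tags object with an input_tags.allow_nan field. The transform frame above notes that SelectorMixin reads tags with get_tags; this sketch assumes get_tags honours __sklearn_tags__ even where the _more_tags path is broken. _more_tags is kept on the assumption that releases before 1.6 still read it, and fit records n_features_in_ because _get_support_mask needs it.

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.feature_selection import SelectorMixin


class MyEstimator(SelectorMixin, BaseEstimator):
    def fit(self, X, y=None):
        # Record the input width; _get_support_mask depends on it.
        self.n_features_in_ = np.asarray(X).shape[1]
        return self

    def _get_support_mask(self):
        # Keep every feature.
        return np.ones(self.n_features_in_, dtype=bool)

    def __sklearn_tags__(self):
        # Tag hook used by scikit-learn 1.6+: start from the inherited
        # tags and switch on NaN support for the input.
        tags = super().__sklearn_tags__()
        tags.input_tags.allow_nan = True
        return tags

    def _more_tags(self):
        # Legacy tag hook, kept on the assumption that pre-1.6
        # releases still read it.
        return {"allow_nan": True}


X = np.array([5, 7, np.nan, 9]).reshape(2, 2)
print(MyEstimator().fit_transform(X))
```

Declaring the tag in both hooks is meant to keep one estimator class usable across the old and new tag systems; whether a given release warns about the legacy hook is not covered here.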

Versions

System:
    python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
executable: /usr/bin/python3
   machine: Linux-6.1.85+-x86_64-with-glibc2.35

Python dependencies:
      sklearn: 1.6.0rc1
          pip: 24.1.2
   setuptools: 75.1.0
        numpy: 1.26.4
        scipy: 1.13.1
       Cython: 3.0.11
       pandas: 2.2.2
   matplotlib: 3.8.0
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 2
         prefix: libopenblas
       filepath: /usr/local/lib/python3.10/dist-packages/numpy.libs/libopenblas64_p-r0-0cf96a72.3.23.dev.so
        version: 0.3.23.dev
threading_layer: pthreads
   architecture: Haswell

       user_api: blas
   internal_api: openblas
    num_threads: 2
         prefix: libopenblas
       filepath: /usr/local/lib/python3.10/dist-packages/scipy.libs/libopenblasp-r0-01191904.3.27.so
        version: 0.3.27
threading_layer: pthreads
   architecture: Haswell

       user_api: openmp
   internal_api: openmp
    num_threads: 2
         prefix: libgomp
       filepath: /usr/local/lib/python3.10/dist-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
        version: None
