SimpleImputer's fill_value validation seems too strict #29381

@buhrmann

Describe the bug

SimpleImputer checks whether the type of fill_value can be cast by numpy to the dtype of the input data (X), using np.can_cast(fill_value_dtype, X.dtype, casting="same_kind"):

if not np.can_cast(fill_value_dtype, X.dtype, casting="same_kind"):
    raise ValueError(err_msg)

This seems too strict to me: it means one cannot impute a uint8 array with fill_value=0 (a Python int). Replacing the validation with something like np.can_cast(fill_value, X.dtype, casting="safe") would be more permissive and, without knowing all the details, looks safe enough, but perhaps I'm missing something.
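To illustrate the difference between the two kinds of check, here is a minimal sketch using only numpy. It is not the scikit-learn implementation: `value_fits` is a hypothetical helper name, and using `np.min_scalar_type` to get a value-based dtype is an assumption on my side (chosen because it does not rely on passing a Python scalar directly to `np.can_cast`).

```python
import numpy as np

target = np.dtype("uint8")

# Type-based check, mirroring the line in the traceback: the Python type
# `int` maps to a signed integer dtype, and signed -> unsigned is not
# allowed under "same_kind" casting.
type_based_ok = np.can_cast(np.dtype(int), target, casting="same_kind")


def value_fits(fill_value, dtype):
    """Hypothetical helper: value-based check.

    Finds the smallest dtype that can hold `fill_value` and asks whether
    that dtype can be cast safely to `dtype`.
    """
    return np.can_cast(np.min_scalar_type(fill_value), dtype, casting="safe")


print(type_based_ok)            # type-based: rejects any Python int
print(value_fits(0, target))    # 0 fits in uint8
print(value_fits(-1, target))   # negative value does not fit
print(value_fits(300, target))  # out of range for uint8
```

With a value-based check along these lines, fill_value=0 would be accepted for uint8 data while values that cannot be represented would still be rejected.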

Steps/Code to Reproduce

Though the example in isolation doesn't make a lot of sense (imputing data that doesn't contain missing values), this case might occur with generically defined transformation pipelines applied to unknown datasets.

import pandas as pd
from sklearn import impute

df = pd.Series([0, 1, 2], dtype="uint8").to_frame()
impute.SimpleImputer(strategy="constant", fill_value=0).fit_transform(df)

Expected Results

array([[0],
       [1],
       [2]], dtype=uint8)

Actual Results

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[19], line 5
      2 from sklearn import impute
      4 df = pd.Series([0, 1, 2], dtype="uint8").to_frame()
----> 5 impute.SimpleImputer(strategy="constant", fill_value=0).fit_transform(df)

File ~/micromamba/envs/grapy/lib/python3.9/site-packages/sklearn/utils/_set_output.py:295, in _wrap_method_output.<locals>.wrapped(self, X, *args, **kwargs)
    293 @wraps(f)
    294 def wrapped(self, X, *args, **kwargs):
--> 295     data_to_wrap = f(self, X, *args, **kwargs)
    296     if isinstance(data_to_wrap, tuple):
    297         # only wrap the first output for cross decomposition
    298         return_tuple = (
    299             _wrap_data_with_container(method, data_to_wrap[0], X, self),
    300             *data_to_wrap[1:],
    301         )

File ~/micromamba/envs/grapy/lib/python3.9/site-packages/sklearn/base.py:1098, in TransformerMixin.fit_transform(self, X, y, **fit_params)
   1083         warnings.warn(
   1084             (
   1085                 f"This object ({self.__class__.__name__}) has a `transform`"
   (...)
   1093             UserWarning,
   1094         )
   1096 if y is None:
   1097     # fit method of arity 1 (unsupervised transformation)
-> 1098     return self.fit(X, **fit_params).transform(X)
   1099 else:
   1100     # fit method of arity 2 (supervised transformation)
   1101     return self.fit(X, y, **fit_params).transform(X)

File ~/micromamba/envs/grapy/lib/python3.9/site-packages/sklearn/base.py:1474, in _fit_context.<locals>.decorator.<locals>.wrapper(estimator, *args, **kwargs)
   1467     estimator._validate_params()
   1469 with config_context(
   1470     skip_parameter_validation=(
   1471         prefer_skip_nested_validation or global_skip_validation
   1472     )
   1473 ):
-> 1474     return fit_method(estimator, *args, **kwargs)

File ~/micromamba/envs/grapy/lib/python3.9/site-packages/sklearn/impute/_base.py:410, in SimpleImputer.fit(self, X, y)
    392 @_fit_context(prefer_skip_nested_validation=True)
    393 def fit(self, X, y=None):
    394     """Fit the imputer on `X`.
    395 
    396     Parameters
   (...)
    408         Fitted estimator.
    409     """
--> 410     X = self._validate_input(X, in_fit=True)
    412     # default fill_value is 0 for numerical input and "missing_value"
    413     # otherwise
    414     if self.fill_value is None:

File ~/micromamba/envs/grapy/lib/python3.9/site-packages/sklearn/impute/_base.py:388, in SimpleImputer._validate_input(self, X, in_fit)
    386     # Make sure we can safely cast fill_value dtype to the input data dtype
    387     if not np.can_cast(fill_value_dtype, X.dtype, casting="same_kind"):
--> 388         raise ValueError(err_msg)
    390 return X

ValueError: fill_value=0 (of type <class 'int'>) cannot be cast to the input data that is dtype('uint8'). Make sure that both dtypes are of the same kind.

Versions

System:
    python: 3.9.19 | packaged by conda-forge | (main, Mar 20 2024, 12:55:20)  [Clang 16.0.6 ]
executable: /Users/thomas/micromamba/envs/grapy/bin/python
   machine: macOS-14.4.1-arm64-arm-64bit

Python dependencies:
      sklearn: 1.4.2
          pip: 24.0
   setuptools: 65.5.1
        numpy: 1.26.4
        scipy: 1.11.3
       Cython: 0.29.37
       pandas: 1.5.3
   matplotlib: 3.8.4
       joblib: 1.4.0
threadpoolctl: 3.4.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 8
         prefix: libopenblas
       filepath: /Users/thomas/micromamba/envs/grapy/lib/libopenblas.0.dylib
        version: 0.3.25
threading_layer: openmp
   architecture: VORTEX

       user_api: openmp
   internal_api: openmp
    num_threads: 8
         prefix: libomp
       filepath: /Users/thomas/micromamba/envs/grapy/lib/libomp.dylib
        version: None

Labels: Easy, Enhancement
