Description
Describe the bug
The SimpleImputer checks whether the type of the fill_value can be cast with numpy to the dtype of the input data (X) using np.can_cast(fill_value_dtype, X.dtype, casting="same_kind"):
scikit-learn/sklearn/impute/_base.py
Line 397 in 0ad90d5
This seems too strict to me: it means one cannot impute a uint8 array with fill_value=0 (a Python int). Replacing the validation with something like np.can_cast(fill_value, X.dtype, casting="safe") would be more permissive and, without knowing all the details, looks safe enough, but perhaps I'm missing something.
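To illustrate the difference between a dtype-based check and a value-based one, here is a minimal sketch. It is not scikit-learn's code and not a proposed patch: `fill_value_fits` is a hypothetical helper name, and it uses np.min_scalar_type rather than passing the Python scalar to np.can_cast directly, since as far as I know NumPy 2.x no longer accepts Python scalars there.

```python
import numpy as np


def fill_value_fits(fill_value, dtype):
    """Hypothetical helper: can this particular value be stored in `dtype`?"""
    # np.min_scalar_type returns the smallest dtype able to hold the value,
    # e.g. uint8 for 0 and int8 for -1.
    smallest = np.min_scalar_type(fill_value)
    return bool(np.can_cast(smallest, dtype, casting="safe"))


# Dtype-based check, as in the validation quoted above: a Python int maps to
# a signed integer dtype, which is rejected for an unsigned target.
print(np.can_cast(np.dtype(int), np.uint8, casting="same_kind"))

# Value-based check: only the concrete value matters.
for value in (0, 255, -1, 300, 0.5):
    print(value, fill_value_fits(value, np.uint8))
```

With this approach 0 and 255 are accepted for uint8, while -1, 300 and 0.5 are still rejected, so values that cannot be represented in the input dtype would still raise.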
Steps/Code to Reproduce
Though the example in isolation doesn't make a lot of sense (imputing data that doesn't contain missing values), this case might occur with generically defined transformation pipelines applied to unknown datasets.
import pandas as pd
from sklearn import impute
df = pd.Series([0, 1, 2], dtype="uint8").to_frame()
impute.SimpleImputer(strategy="constant", fill_value=0).fit_transform(df)
Expected Results
array([[0],
       [1],
       [2]], dtype=uint8)
Actual Results
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[19], line 5
      2 from sklearn import impute
      4 df = pd.Series([0, 1, 2], dtype="uint8").to_frame()
----> 5 impute.SimpleImputer(strategy="constant", fill_value=0).fit_transform(df)

File ~/micromamba/envs/grapy/lib/python3.9/site-packages/sklearn/utils/_set_output.py:295, in _wrap_method_output.<locals>.wrapped(self, X, *args, **kwargs)
    293 @wraps(f)
    294 def wrapped(self, X, *args, **kwargs):
--> 295     data_to_wrap = f(self, X, *args, **kwargs)
    296     if isinstance(data_to_wrap, tuple):
    297         # only wrap the first output for cross decomposition
    298         return_tuple = (
    299             _wrap_data_with_container(method, data_to_wrap[0], X, self),
    300             *data_to_wrap[1:],
    301         )

File ~/micromamba/envs/grapy/lib/python3.9/site-packages/sklearn/base.py:1098, in TransformerMixin.fit_transform(self, X, y, **fit_params)
   1083     warnings.warn(
   1084         (
   1085             f"This object ({self.__class__.__name__}) has a `transform`"
   (...)
   1093         UserWarning,
   1094     )
   1096 if y is None:
   1097     # fit method of arity 1 (unsupervised transformation)
-> 1098     return self.fit(X, **fit_params).transform(X)
   1099 else:
   1100     # fit method of arity 2 (supervised transformation)
   1101     return self.fit(X, y, **fit_params).transform(X)

File ~/micromamba/envs/grapy/lib/python3.9/site-packages/sklearn/base.py:1474, in _fit_context.<locals>.decorator.<locals>.wrapper(estimator, *args, **kwargs)
   1467     estimator._validate_params()
   1469 with config_context(
   1470     skip_parameter_validation=(
   1471         prefer_skip_nested_validation or global_skip_validation
   1472     )
   1473 ):
-> 1474     return fit_method(estimator, *args, **kwargs)

File ~/micromamba/envs/grapy/lib/python3.9/site-packages/sklearn/impute/_base.py:410, in SimpleImputer.fit(self, X, y)
    392 @_fit_context(prefer_skip_nested_validation=True)
    393 def fit(self, X, y=None):
    394     """Fit the imputer on `X`.
    395
    396     Parameters
   (...)
    408         Fitted estimator.
    409     """
--> 410     X = self._validate_input(X, in_fit=True)
    412     # default fill_value is 0 for numerical input and "missing_value"
    413     # otherwise
    414     if self.fill_value is None:

File ~/micromamba/envs/grapy/lib/python3.9/site-packages/sklearn/impute/_base.py:388, in SimpleImputer._validate_input(self, X, in_fit)
    386 # Make sure we can safely cast fill_value dtype to the input data dtype
    387 if not np.can_cast(fill_value_dtype, X.dtype, casting="same_kind"):
--> 388     raise ValueError(err_msg)
    390 return X

ValueError: fill_value=0 (of type <class 'int'>) cannot be cast to the input data that is dtype('uint8'). Make sure that both dtypes are of the same kind.
Versions
System:
python: 3.9.19 | packaged by conda-forge | (main, Mar 20 2024, 12:55:20) [Clang 16.0.6 ]
executable: /Users/thomas/micromamba/envs/grapy/bin/python
machine: macOS-14.4.1-arm64-arm-64bit
Python dependencies:
sklearn: 1.4.2
pip: 24.0
setuptools: 65.5.1
numpy: 1.26.4
scipy: 1.11.3
Cython: 0.29.37
pandas: 1.5.3
matplotlib: 3.8.4
joblib: 1.4.0
threadpoolctl: 3.4.0
Built with OpenMP: True
threadpoolctl info:
user_api: blas
internal_api: openblas
num_threads: 8
prefix: libopenblas
filepath: /Users/thomas/micromamba/envs/grapy/lib/libopenblas.0.dylib
version: 0.3.25
threading_layer: openmp
architecture: VORTEX
user_api: openmp
internal_api: openmp
num_threads: 8
prefix: libomp
filepath: /Users/thomas/micromamba/envs/grapy/lib/libomp.dylib
version: None