Describe the bug
When using the `SimpleImputer` with a pyarrow-backed pandas DataFrame, integer and float data is converted with a dtype of `None`, which results in `float64` instead of the original dtype.
This causes the imputer to be fitted with `float64`, and it then crashes on a dtype assertion when it is passed a NumPy-backed `int32` DataFrame after fitting.
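The dtype change at the heart of this can be reproduced with pandas and NumPy alone: `np.dtype(None)` is `float64`, so casting a frame with `astype(None)` turns its columns into `float64`. This minimal sketch uses pandas' nullable `Int32` as a stand-in for `int32[pyarrow]` so that it runs without pyarrow; the assumption is that both behave the same here because both are integer extension dtypes:

```python
import numpy as np
import pandas as pd

# NumPy resolves a dtype of None to its default dtype, float64.
print(np.dtype(None))  # float64

# An integer column backed by a pandas extension array
# (nullable Int32 stands in for int32[pyarrow] in this sketch).
df = pd.DataFrame({"a": pd.array([10], dtype="Int32")})

# The early conversion in check_array boils down to array.astype(new_dtype)
# with new_dtype=None, which pandas resolves the same way NumPy does.
converted = df.astype(None)
print(converted.dtypes["a"])  # float64
```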
The flow is the following:

- The imputer calls `_validate_input` (`scikit-learn/sklearn/impute/_base.py`, line 319 in 675736a):
  ```python
  def _validate_input(self, X, in_fit):
  ```
- This calls `validate_data` (`scikit-learn/sklearn/impute/_base.py`, lines 344 to 353 in 675736a):
  ```python
  X = validate_data(
      self,
      X,
      reset=in_fit,
      accept_sparse="csc",
      dtype=dtype,
      force_writeable=True if not in_fit else None,
      ensure_all_finite=ensure_all_finite,
      copy=self.copy,
  )
  ```
- This calls `check_array` (`scikit-learn/sklearn/utils/validation.py`, lines 2951 to 2952 in 675736a):
  ```python
  elif not no_val_X and no_val_y:
      out = check_array(X, input_name="X", **check_params)
  ```
- Our input is a pandas DataFrame, so this branch is taken (`scikit-learn/sklearn/utils/validation.py`, line 909 in 675736a):
  ```python
  if hasattr(array, "dtypes") and hasattr(array.dtypes, "__array__"):
  ```
- It then checks whether any of the column dtypes need to be converted early (`scikit-learn/sklearn/utils/validation.py`, lines 925 to 927 in 675736a):
  ```python
  pandas_requires_conversion = any(
      _pandas_dtype_needs_early_conversion(i) for i in dtypes_orig
  )
  ```
- Our input is backed by an extension array, and `int32[pyarrow]` is an integer dtype, so `True` is returned here (`scikit-learn/sklearn/utils/validation.py`, lines 714 to 724 in 675736a):
  ```python
  if isinstance(pd_dtype, SparseDtype) or not is_extension_array_dtype(pd_dtype):
      # Sparse arrays will be converted later in `check_array`
      # Only handle extension arrays for integer and floats
      return False
  elif is_float_dtype(pd_dtype):
      # Float ndarrays can normally support nans. They need to be converted
      # first to map pd.NA to np.nan
      return True
  elif is_integer_dtype(pd_dtype):
      # XXX: Warn when converting from a high integer to a float
      return True
  ```
- Finally, the "needs conversion" check passes and the DataFrame is converted with `array.astype(new_dtype)`. Here `new_dtype` is `None`, which resolves to NumPy's default dtype (`np.dtype(None)` is `float64`) (`scikit-learn/sklearn/utils/validation.py`, lines 966 to 971 in 675736a):
  ```python
  if pandas_requires_conversion:
      # pandas dataframe requires conversion earlier to handle extension dtypes with
      # nans
      # Use the original dtype for conversion if dtype is None
      new_dtype = dtype_orig if dtype is None else dtype
      array = array.astype(new_dtype)
  ```
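The steps above can be condensed into a small model of the decision. This is a simplified sketch, not scikit-learn's code: `needs_early_conversion` copies only the branches of `_pandas_dtype_needs_early_conversion` quoted above, and `early_conversion_dtype` is a hypothetical helper. It assumes `dtype=None` (what `strategy="constant"` passes) and `dtype_orig=None` (what the issue reports for the pyarrow-backed frame), and again uses nullable `Int32` as a stand-in for `int32[pyarrow]`:

```python
import numpy as np
import pandas as pd
from pandas.api.types import (
    is_extension_array_dtype,
    is_float_dtype,
    is_integer_dtype,
)


def needs_early_conversion(pd_dtype):
    # Simplified copy of the quoted branches of
    # _pandas_dtype_needs_early_conversion; other branches are omitted.
    if isinstance(pd_dtype, pd.SparseDtype) or not is_extension_array_dtype(pd_dtype):
        return False  # plain NumPy dtypes are not converted early
    # Integer and float extension dtypes are converted early.
    return bool(is_float_dtype(pd_dtype) or is_integer_dtype(pd_dtype))


def early_conversion_dtype(df, dtype=None, dtype_orig=None):
    # Hypothetical helper: the dtype the early astype() ends up with,
    # or None when no early conversion happens at all.
    if not any(needs_early_conversion(d) for d in df.dtypes):
        return None
    new_dtype = dtype_orig if dtype is None else dtype
    # A None dtype resolves to np.dtype(None), i.e. float64.
    return np.dtype(new_dtype)


numpy_df = pd.DataFrame({"a": np.array([10], dtype=np.int32)})
ext_df = pd.DataFrame({"a": pd.array([10], dtype="Int32")})  # stand-in for int32[pyarrow]

print(early_conversion_dtype(numpy_df))  # None
print(early_conversion_dtype(ext_df))    # float64
```

The model makes the asymmetry visible: the NumPy-backed frame never reaches the `astype` call and keeps `int32`, while the extension-backed frame is cast to `float64`.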
Steps/Code to Reproduce
```python
import numpy as np
import polars as pl
from sklearn.impute import SimpleImputer

print(
    SimpleImputer(missing_values=np.nan, strategy="constant", fill_value=0)
    .fit(pl.DataFrame({"a": [10]}, schema={"a": pl.Int32}).to_pandas(use_pyarrow_extension_array=False))
    ._fit_dtype
)
# prints dtype('int32'), as expected

print(
    SimpleImputer(missing_values=np.nan, strategy="constant", fill_value=0)
    .fit(pl.DataFrame({"a": [10]}, schema={"a": pl.Int32}).to_pandas(use_pyarrow_extension_array=True))
    ._fit_dtype
)
# prints dtype('float64') (!!)
```
Expected Results
Both imputers should be fitted with `int32` values.
Actual Results
The imputer fitted on the pyarrow extension array is fitted with `float64`.
This causes crashes when the imputer is later used with regular NumPy-backed `int32` columns: those are not converted early, so their dtype differs from the fitted one.
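One possible workaround, offered only as a sketch and not as an official recommendation, is to cast integer and float extension columns to their NumPy dtypes before fitting, so the imputer records the same dtype as later NumPy-backed input. `to_numpy_backed` is a hypothetical helper that relies on the `numpy_dtype` attribute exposed by pandas' masked dtypes and by `ArrowDtype`. It assumes the integer columns contain no missing values, since `pd.NA` cannot be stored in a NumPy integer column. The example uses nullable `Int32` so it runs without pyarrow:

```python
import pandas as pd
from pandas.api.types import (
    is_extension_array_dtype,
    is_float_dtype,
    is_integer_dtype,
)


def to_numpy_backed(df):
    # Hypothetical helper: map integer/float extension columns
    # (e.g. int32[pyarrow], Int32) to the matching NumPy dtype.
    # Assumes integer columns hold no missing values; astype raises otherwise.
    casts = {}
    for col, dt in df.dtypes.items():
        if is_extension_array_dtype(dt) and (is_integer_dtype(dt) or is_float_dtype(dt)):
            casts[col] = dt.numpy_dtype
    # Columns not listed in `casts` keep their current dtype.
    return df.astype(casts)


df = pd.DataFrame({"a": pd.array([10], dtype="Int32"), "b": [1.5]})
print(to_numpy_backed(df).dtypes["a"])  # int32
```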
Versions
```
System:
    python: 3.12.9 (main, Mar 11 2025, 17:26:57) [Clang 20.1.0 ]
    executable: /tmp/scikit/.venv/bin/python3
    machine: Linux-6.14.4-arch1-2-x86_64-with-glibc2.41

Python dependencies:
    sklearn: 1.6.1
    pip: None
    setuptools: None
    numpy: 2.2.5
    scipy: 1.15.3
    Cython: None
    pandas: 2.2.3
    matplotlib: None
    joblib: 1.5.0
    threadpoolctl: 3.6.0

Built with OpenMP: True

threadpoolctl info:
    user_api: blas
    internal_api: openblas
    num_threads: 16
    prefix: libscipy_openblas
    filepath: /tmp/scikit/.venv/lib/python3.12/site-packages/numpy.libs/libscipy_openblas64_-6bb31eeb.so
    version: 0.3.28
    threading_layer: pthreads
    architecture: Haswell

    user_api: blas
    internal_api: openblas
    num_threads: 16
    prefix: libscipy_openblas
    filepath: /tmp/scikit/.venv/lib/python3.12/site-packages/scipy.libs/libscipy_openblas-68440149.so
    version: 0.3.28
    threading_layer: pthreads
    architecture: Haswell

    user_api: openmp
    internal_api: openmp
    num_threads: 16
    prefix: libgomp
    filepath: /tmp/scikit/.venv/lib/python3.12/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
    version: None
```