SimpleImputer converts int32[pyarrow] extension array to float64, subsequently crashing with numpy int32 values #31412

@intrigus-lgtm

Describe the bug

When SimpleImputer is fitted on a pyarrow-backed pandas DataFrame, float and integer columns are converted with dtype None, which yields float64, instead of keeping their original dtype.
As a result the imputer is fitted with float64 and crashes on a dtype assertion when it is later passed a numpy-backed int32 DataFrame.

The flow is the following:

  1. The imputer calls _validate_input:
    def _validate_input(self, X, in_fit):
  2. This calls validate_data:
    X = validate_data(
        self,
        X,
        reset=in_fit,
        accept_sparse="csc",
        dtype=dtype,
        force_writeable=True if not in_fit else None,
        ensure_all_finite=ensure_all_finite,
        copy=self.copy,
    )
  3. This calls check_array:
    elif not no_val_X and no_val_y:
        out = check_array(X, input_name="X", **check_params)
  4. Our input is a pandas dataframe:
    if hasattr(array, "dtypes") and hasattr(array.dtypes, "__array__"):
  5. This now checks if the dtypes need to be converted:
    pandas_requires_conversion = any(
        _pandas_dtype_needs_early_conversion(i) for i in dtypes_orig
    )
  6. Our input is backed by an extension array and int32[pyarrow] is an integer datatype, so we return True here:
    if isinstance(pd_dtype, SparseDtype) or not is_extension_array_dtype(pd_dtype):
        # Sparse arrays will be converted later in `check_array`
        # Only handle extension arrays for integer and floats
        return False
    elif is_float_dtype(pd_dtype):
        # Float ndarrays can normally support nans. They need to be converted
        # first to map pd.NA to np.nan
        return True
    elif is_integer_dtype(pd_dtype):
        # XXX: Warn when converting from a high integer to a float
        return True
  7. Finally, because the "needs conversion" check passes, the dataframe is converted to new_dtype (dtype is None here, and the result is float64, apparently because None is treated as float64):
    if pandas_requires_conversion:
        # pandas dataframe requires conversion earlier to handle extension dtypes with
        # nans
        # Use the original dtype for conversion if dtype is None
        new_dtype = dtype_orig if dtype is None else dtype
        array = array.astype(new_dtype)
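
The fallback in step 7 can be sketched in isolation. This is a minimal illustration, not the scikit-learn code: `resolve_conversion_dtype` is a hypothetical helper that only mirrors the quoted expression, and it assumes that `dtype_orig` is also None for this DataFrame, which is what the observed float64 result suggests. NumPy documents None as a request for its default dtype, float64.

```python
import numpy as np


def resolve_conversion_dtype(dtype, dtype_orig):
    """Hypothetical helper mirroring the fallback quoted in step 7."""
    return dtype_orig if dtype is None else dtype


# Assumption: neither an explicit dtype nor an original dtype is available.
new_dtype = resolve_conversion_dtype(dtype=None, dtype_orig=None)

# NumPy interprets None as its default dtype.
resolved = np.dtype(new_dtype)
print(new_dtype, resolved)
```

If either argument carried the column's integer dtype, the same expression would return it unchanged, so the original dtype would survive the conversion.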

Steps/Code to Reproduce

import polars as pl
from sklearn.impute import SimpleImputer
import numpy as np

print(
    SimpleImputer(missing_values=np.nan, strategy='constant', fill_value=0)
      .fit(pl.DataFrame({"a": [10]}, schema={"a": pl.Int32}).to_pandas(use_pyarrow_extension_array=False))
      ._fit_dtype
)
# prints dtype('int32'), as expected

print(
    SimpleImputer(missing_values=np.nan, strategy='constant', fill_value=0)
      .fit(pl.DataFrame({"a": [10]}, schema={"a": pl.Int32}).to_pandas(use_pyarrow_extension_array=True))
      ._fit_dtype
)
# prints dtype('float64') (!!)

Expected Results

Both imputers should be fitted with int32 values.

Actual Results

The imputer using the pyarrow extension array is fitted with float64.

This causes crashes when the fitted imputer is later given ordinary int32 columns backed by numpy: those columns are not converted, so their dtype differs from the float64 dtype the imputer was fitted with.
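
Until the conversion is fixed, one possible workaround is to cast extension-typed integer columns to their NumPy dtype before fitting, so that fit and transform see the same dtype. The sketch below is an assumption-laden illustration: `to_numpy_backed_ints` is a hypothetical helper, it uses pandas' own nullable `Int32` dtype as a stand-in for `int32[pyarrow]` so that it runs without pyarrow, and it has not been checked against SimpleImputer itself.

```python
import numpy as np
import pandas as pd
from pandas.api.types import is_extension_array_dtype, is_integer_dtype


def to_numpy_backed_ints(df):
    """Hypothetical helper: cast extension integer columns to NumPy dtypes.

    Columns containing missing values are left alone, because a plain
    NumPy integer column cannot represent them.
    """
    casts = {}
    for name, dtype in df.dtypes.items():
        numpy_dtype = getattr(dtype, "numpy_dtype", None)
        if (
            numpy_dtype is not None
            and is_extension_array_dtype(dtype)
            and is_integer_dtype(dtype)
            and not df[name].isna().any()
        ):
            casts[name] = numpy_dtype
    return df.astype(casts)


# Stand-in for a pyarrow-backed frame: a nullable extension integer column.
df = pd.DataFrame({"a": pd.array([10], dtype="Int32")})
converted = to_numpy_backed_ints(df)
print(converted.dtypes["a"])
```

Columns that are already NumPy-backed pass through untouched, so the helper can be applied to both the fitting and the transform inputs.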

Versions

System:
    python: 3.12.9 (main, Mar 11 2025, 17:26:57) [Clang 20.1.0 ]
executable: /tmp/scikit/.venv/bin/python3
   machine: Linux-6.14.4-arch1-2-x86_64-with-glibc2.41

Python dependencies:
      sklearn: 1.6.1
          pip: None
   setuptools: None
        numpy: 2.2.5
        scipy: 1.15.3
       Cython: None
       pandas: 2.2.3
   matplotlib: None
       joblib: 1.5.0
threadpoolctl: 3.6.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 16
         prefix: libscipy_openblas
       filepath: /tmp/scikit/.venv/lib/python3.12/site-packages/numpy.libs/libscipy_openblas64_-6bb31eeb.so
        version: 0.3.28
threading_layer: pthreads
   architecture: Haswell

       user_api: blas
   internal_api: openblas
    num_threads: 16
         prefix: libscipy_openblas
       filepath: /tmp/scikit/.venv/lib/python3.12/site-packages/scipy.libs/libscipy_openblas-68440149.so
        version: 0.3.28
threading_layer: pthreads
   architecture: Haswell

       user_api: openmp
   internal_api: openmp
    num_threads: 16
         prefix: libgomp
       filepath: /tmp/scikit/.venv/lib/python3.12/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
        version: None
