
SimpleImputer converts int32[pyarrow] extension array to float64, subsequently crashing with numpy int32 values #31412


Closed
intrigus-lgtm opened this issue May 21, 2025 · 1 comment
Labels
Needs Triage Issue requires triage


@intrigus-lgtm

Describe the bug

When using the SimpleImputer with a pyarrow-backed pandas DataFrame, integer and float columns are converted with a dtype of None, which results in float64.
This causes the imputer to be fitted with float64, and it then crashes on a dtype assertion when it is given a numpy-backed int32 DataFrame after fitting.

The flow is the following:

  1. The imputer calls _validate_input:
    def _validate_input(self, X, in_fit):
  2. This calls validate_data:
    X = validate_data(
        self,
        X,
        reset=in_fit,
        accept_sparse="csc",
        dtype=dtype,
        force_writeable=True if not in_fit else None,
        ensure_all_finite=ensure_all_finite,
        copy=self.copy,
    )
  3. This calls check_array:
    elif not no_val_X and no_val_y:
        out = check_array(X, input_name="X", **check_params)
  4. Our input is a pandas dataframe:
    if hasattr(array, "dtypes") and hasattr(array.dtypes, "__array__"):
  5. This now checks if the dtypes need to be converted:
    pandas_requires_conversion = any(
        _pandas_dtype_needs_early_conversion(i) for i in dtypes_orig
    )
  6. Our input is backed by an extension array and int32[pyarrow] is an integer datatype, so we return True here:
    if isinstance(pd_dtype, SparseDtype) or not is_extension_array_dtype(pd_dtype):
        # Sparse arrays will be converted later in `check_array`
        # Only handle extension arrays for integer and floats
        return False
    elif is_float_dtype(pd_dtype):
        # Float ndarrays can normally support nans. They need to be converted
        # first to map pd.NA to np.nan
        return True
    elif is_integer_dtype(pd_dtype):
        # XXX: Warn when converting from a high integer to a float
        return True
  7. Finally, the "needs conversion" check passes and the DataFrame is cast to new_dtype. Both dtype and dtype_orig are None here, so the cast target is None, which apparently means float64:
    if pandas_requires_conversion:
        # pandas dataframe requires conversion earlier to handle extension dtypes with
        # nans
        # Use the original dtype for conversion if dtype is None
        new_dtype = dtype_orig if dtype is None else dtype
        array = array.astype(new_dtype)
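The dtype resolution in steps 5 to 7 can be condensed into a small model. This is a simplified sketch, not scikit-learn's code: `is_extension_int_or_float` and `early_cast_target` are hypothetical names, column dtypes are written as strings (a `[pyarrow]` suffix stands in for an integer/float extension dtype), and the object-dtype branch of the real code is ignored. It only illustrates why the cast target collapses to float64 when both `dtype` and `dtype_orig` are None, which matches what `np.dtype(None)` returns.

```python
import numpy as np


# Hypothetical model: a "[pyarrow]" suffix stands in for a pandas integer/float
# extension dtype; any other string is treated as a plain NumPy dtype name.
def is_extension_int_or_float(col_dtype: str) -> bool:
    # Step 6 (simplified): integer/float extension dtypes need early conversion.
    return col_dtype.endswith("[pyarrow]")


def early_cast_target(dtype, column_dtypes):
    """Return the dtype the frame is cast to early, or None if no early cast.

    `dtype` is what the estimator requests; SimpleImputer with
    strategy="constant" requests None.
    """
    # Step 5: does any column need early conversion?
    if not any(is_extension_int_or_float(d) for d in column_dtypes):
        return None
    # In this model dtype_orig stays None whenever an extension column is present,
    # because it is only derived when every column is a NumPy dtype.
    dtype_orig = None
    # Step 7: fall back to dtype_orig when the caller requested no dtype.
    new_dtype = dtype_orig if dtype is None else dtype
    # np.dtype(None) resolves to float64, matching the observed _fit_dtype.
    return np.dtype(new_dtype)


print(early_cast_target(None, ["int32"]))               # None (no early cast)
print(early_cast_target(None, ["int32[pyarrow]"]))      # float64
print(early_cast_target(np.int32, ["int32[pyarrow]"]))  # int32
```

In this model, an explicitly requested dtype (third call) avoids the float64 fallback, while the None request used by the constant strategy does not.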

Steps/Code to Reproduce

import polars as pl
from sklearn.impute import SimpleImputer
import numpy as np

print(
    SimpleImputer(missing_values=np.nan, strategy='constant', fill_value=0)
      .fit(pl.DataFrame({"a": [10]}, schema={"a": pl.Int32}).to_pandas(use_pyarrow_extension_array=False))
      ._fit_dtype
)
# prints dtype('int32'), as expected

print(
    SimpleImputer(missing_values=np.nan, strategy='constant', fill_value=0)
      .fit(pl.DataFrame({"a": [10]}, schema={"a": pl.Int32}).to_pandas(use_pyarrow_extension_array=True))
      ._fit_dtype
)
# prints dtype('float64') (!!)

Expected Results

Both imputers should be fitted with int32 values.

Actual Results

The imputer using the pyarrow extension array is fitted with float64.

This causes a crash when the fitted imputer is then used with regular numpy-backed int32 columns: those are not converted, so their dtype (int32) differs from the fitted dtype (float64).
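The mismatch can be illustrated with NumPy alone. This is a sketch under assumptions: the two fitted dtypes mirror what the reproducer prints for the private `_fit_dtype` attribute, and `check_same_dtype` is a hypothetical stand-in for the dtype assertion described in the report, not scikit-learn's actual check.

```python
import numpy as np

# Assumed to mirror SimpleImputer._fit_dtype from the two reproducer calls.
fit_dtype_numpy_input = np.dtype("int32")  # fitted on a numpy-backed int32 column
fit_dtype_arrow_input = np.dtype(None)     # fitted on int32[pyarrow]: float64


def check_same_dtype(fit_dtype, X):
    """Hypothetical fit/transform dtype consistency check."""
    if X.dtype != fit_dtype:
        raise ValueError(f"fitted on {fit_dtype}, got {X.dtype}")
    return X


# Regular numpy-backed int32 data, which is not converted before the check.
X_transform = np.array([[10], [20]], dtype=np.int32)

check_same_dtype(fit_dtype_numpy_input, X_transform)  # passes
try:
    check_same_dtype(fit_dtype_arrow_input, X_transform)
except ValueError as exc:
    print("mismatch:", exc)  # mismatch: fitted on float64, got int32
```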

Versions

System:
    python: 3.12.9 (main, Mar 11 2025, 17:26:57) [Clang 20.1.0 ]
executable: /tmp/scikit/.venv/bin/python3
   machine: Linux-6.14.4-arch1-2-x86_64-with-glibc2.41

Python dependencies:
      sklearn: 1.6.1
          pip: None
   setuptools: None
        numpy: 2.2.5
        scipy: 1.15.3
       Cython: None
       pandas: 2.2.3
   matplotlib: None
       joblib: 1.5.0
threadpoolctl: 3.6.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 16
         prefix: libscipy_openblas
       filepath: /tmp/scikit/.venv/lib/python3.12/site-packages/numpy.libs/libscipy_openblas64_-6bb31eeb.so
        version: 0.3.28
threading_layer: pthreads
   architecture: Haswell

       user_api: blas
   internal_api: openblas
    num_threads: 16
         prefix: libscipy_openblas
       filepath: /tmp/scikit/.venv/lib/python3.12/site-packages/scipy.libs/libscipy_openblas-68440149.so
        version: 0.3.28
threading_layer: pthreads
   architecture: Haswell

       user_api: openmp
   internal_api: openmp
    num_threads: 16
         prefix: libgomp
       filepath: /tmp/scikit/.venv/lib/python3.12/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
        version: None
@github-actions github-actions bot added the Needs Triage Issue requires triage label May 21, 2025
@intrigus-lgtm
Author

That was the wrong window.
Apologies.
