SimpleImputer converts int32[pyarrow] extension array to float64, subsequently crashing with numpy int32 values #31373
Comments
The core issue stems from the fact that
This leads to unexpected conversion of pyarrow-backed integer columns to float64 inside check_array in scikit-learn, causing inconsistent dtype handling and eventual crashes when mixing pandas and polars dataframes or numpy arrays.
Proposal
I would like to work on making the
Happy to discuss the best approach and how to write tests for this.
/take
I am not passing in a polars type; I am passing in a pandas DataFrame backed by pyarrow. The
>>> import polars as pl
>>> from pandas.api.types import is_integer_dtype
>>> typ = pl.DataFrame({"a": [10]}, schema={"a": pl.Int32}).to_pandas(use_pyarrow_extension_array=True).dtypes["a"]
>>> is_integer_dtype(typ)
True
The issue could be the first check in
It is checking for data types and assigning a default from
@gdacciaro, would adding support for pyarrow dtypes in the if-else block work?
@Astroficboy
@gdacciaro Will try for a pull request.
Describe the bug
When using the SimpleImputer with a pyarrow-backed pandas DataFrame, any float/integer data is converted to None/float64 instead. This causes the imputer to be fitted to float64, crashing on a dtype assertion when passing it a numpy-backed int32 DataFrame after fitting.
The flow is the following:
_validate_input: scikit-learn/sklearn/impute/_base.py, Line 319 in 675736a
validate_data: scikit-learn/sklearn/impute/_base.py, Lines 344 to 353 in 675736a
check_array: scikit-learn/sklearn/utils/validation.py, Lines 2951 to 2952 in 675736a
scikit-learn/sklearn/utils/validation.py, Line 909 in 675736a
scikit-learn/sklearn/utils/validation.py, Lines 925 to 927 in 675736a
int32[pyarrow] is an integer datatype, so we return True here: scikit-learn/sklearn/utils/validation.py, Lines 714 to 724 in 675736a
dtype (which is None here, which apparently means float64): scikit-learn/sklearn/utils/validation.py, Lines 966 to 971 in 675736a
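The "None apparently means float64" remark in the last step is consistent with how NumPy itself resolves a dtype of None. A minimal illustration, independent of scikit-learn and not a copy of its validation code (the variable names are hypothetical):

```python
import numpy as np

# NumPy resolves a dtype of None to its default floating-point type.
resolved = np.dtype(None)

# A NumPy-backed int32 column keeps its own dtype, so the two no longer match.
numpy_backed = np.array([10, 20], dtype=np.int32)
dtypes_match = resolved == numpy_backed.dtype

print(resolved, numpy_backed.dtype, dtypes_match)
```

Under that assumption, a conversion that ends up with a dtype of None silently produces float64 data, which would explain the mismatch against int32 input seen later.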
Steps/Code to Reproduce
Expected Results
Both imputers should be fitted with int32 values.
Actual Results
The imputer using the pyarrow extension array is fitted with float64. This causes crashes when using the imputer with normal int32 columns backed by numpy, as they won't be converted and therefore the dtypes differ.
Versions