SimpleImputer's fill_value validation seems too strict #29381

Closed
buhrmann opened this issue Jul 2, 2024 · 12 comments · Fixed by #30828 · May be fixed by #29564
Labels: Easy (Well-defined and straightforward way to resolve), Enhancement

buhrmann commented Jul 2, 2024

Describe the bug

The SimpleImputer checks whether the type of the fill_value can be cast with numpy to the dtype of the input data (X) using np.can_cast(fill_value_dtype, X.dtype, casting="same_kind"):

if not np.can_cast(fill_value_dtype, X.dtype, casting="same_kind"):

This seems too strict to me: it means one cannot impute a uint8 array with fill_value=0 (a Python int). Replacing the validation with something like np.can_cast(fill_value, X.dtype, casting="safe") would be more permissive and, without knowing all the details, looks safe enough, but perhaps I'm missing something.
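The failing check can be reproduced with NumPy alone. This is a minimal sketch that mirrors the check rather than calling scikit-learn, and it assumes (as the "of type <class 'int'>" wording of the error message suggests) that fill_value_dtype is simply type(fill_value):

```python
import numpy as np

X = np.array([[0], [1], [2]], dtype=np.uint8)

# Python int: type(0) is `int`, which NumPy maps to a signed integer dtype.
# Signed -> unsigned is not allowed under "same_kind", so the check fails.
assert not np.can_cast(type(0), X.dtype, casting="same_kind")

# NumPy scalar with the data's own dtype: the same check passes.
assert np.can_cast(type(np.uint8(0)), X.dtype, casting="same_kind")
```

Only types are passed to np.can_cast here, never Python scalar values, so the sketch does not rely on value-based casting.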

Steps/Code to Reproduce

Though the example in isolation doesn't make a lot of sense (imputing data that doesn't contain missing values), this case might occur with generically defined transformation pipelines applied to unknown datasets.

import pandas as pd
from sklearn import impute

df = pd.Series([0, 1, 2], dtype="uint8").to_frame()
impute.SimpleImputer(strategy="constant", fill_value=0).fit_transform(df)

Expected Results

array([[0],
       [1],
       [2]], dtype=uint8)

Actual Results

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[19], line 5
      2 from sklearn import impute
      4 df = pd.Series([0, 1, 2], dtype="uint8").to_frame()
----> 5 impute.SimpleImputer(strategy="constant", fill_value=0).fit_transform(df)

File ~/micromamba/envs/grapy/lib/python3.9/site-packages/sklearn/utils/_set_output.py:295, in _wrap_method_output.<locals>.wrapped(self, X, *args, **kwargs)
    293 @wraps(f)
    294 def wrapped(self, X, *args, **kwargs):
--> 295     data_to_wrap = f(self, X, *args, **kwargs)
    296     if isinstance(data_to_wrap, tuple):
    297         # only wrap the first output for cross decomposition
    298         return_tuple = (
    299             _wrap_data_with_container(method, data_to_wrap[0], X, self),
    300             *data_to_wrap[1:],
    301         )

File ~/micromamba/envs/grapy/lib/python3.9/site-packages/sklearn/base.py:1098, in TransformerMixin.fit_transform(self, X, y, **fit_params)
   1083         warnings.warn(
   1084             (
   1085                 f"This object ({self.__class__.__name__}) has a `transform`"
   (...)
   1093             UserWarning,
   1094         )
   1096 if y is None:
   1097     # fit method of arity 1 (unsupervised transformation)
-> 1098     return self.fit(X, **fit_params).transform(X)
   1099 else:
   1100     # fit method of arity 2 (supervised transformation)
   1101     return self.fit(X, y, **fit_params).transform(X)

File ~/micromamba/envs/grapy/lib/python3.9/site-packages/sklearn/base.py:1474, in _fit_context.<locals>.decorator.<locals>.wrapper(estimator, *args, **kwargs)
   1467     estimator._validate_params()
   1469 with config_context(
   1470     skip_parameter_validation=(
   1471         prefer_skip_nested_validation or global_skip_validation
   1472     )
   1473 ):
-> 1474     return fit_method(estimator, *args, **kwargs)

File ~/micromamba/envs/grapy/lib/python3.9/site-packages/sklearn/impute/_base.py:410, in SimpleImputer.fit(self, X, y)
    392 @_fit_context(prefer_skip_nested_validation=True)
    393 def fit(self, X, y=None):
    394     """Fit the imputer on `X`.
    395
    396     Parameters
   (...)
    408         Fitted estimator.
    409     """
--> 410     X = self._validate_input(X, in_fit=True)
    412     # default fill_value is 0 for numerical input and "missing_value"
    413     # otherwise
    414     if self.fill_value is None:

File ~/micromamba/envs/grapy/lib/python3.9/site-packages/sklearn/impute/_base.py:388, in SimpleImputer._validate_input(self, X, in_fit)
    386     # Make sure we can safely cast fill_value dtype to the input data dtype
    387     if not np.can_cast(fill_value_dtype, X.dtype, casting="same_kind"):
--> 388         raise ValueError(err_msg)
    390 return X

ValueError: fill_value=0 (of type <class 'int'>) cannot be cast to the input data that is dtype('uint8'). Make sure that both dtypes are of the same kind.

Versions

System:
    python: 3.9.19 | packaged by conda-forge | (main, Mar 20 2024, 12:55:20)  [Clang 16.0.6 ]
executable: /Users/thomas/micromamba/envs/grapy/bin/python
   machine: macOS-14.4.1-arm64-arm-64bit

Python dependencies:
      sklearn: 1.4.2
          pip: 24.0
   setuptools: 65.5.1
        numpy: 1.26.4
        scipy: 1.11.3
       Cython: 0.29.37
       pandas: 1.5.3
   matplotlib: 3.8.4
       joblib: 1.4.0
threadpoolctl: 3.4.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 8
         prefix: libopenblas
       filepath: /Users/thomas/micromamba/envs/grapy/lib/libopenblas.0.dylib
        version: 0.3.25
threading_layer: openmp
   architecture: VORTEX

       user_api: openmp
   internal_api: openmp
    num_threads: 8
         prefix: libomp
       filepath: /Users/thomas/micromamba/envs/grapy/lib/libomp.dylib
        version: None

buhrmann added the Bug and Needs Triage (Issue requires triage) labels on Jul 2, 2024

lesteve (Member) commented Jul 2, 2024

From the numpy.can_cast docstring below, 'same_kind' is actually less strict than 'safe', or am I missing something?

casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
    Controls what kind of data casting may occur.
 
    * 'no' means the data types should not be cast at all.
    * 'equiv' means only byte-order changes are allowed.
    * 'safe' means only casts which can preserve values are allowed.
    * 'same_kind' means only safe casts or casts within a kind,
      like float64 to float32, are allowed.
    * 'unsafe' means any data conversions may be done.

In any case, both return False in your case (Python int to np.uint8):

In [12]: np.can_cast(type(1), np.uint8, casting="same_kind")
Out[12]: False

In [13]: np.can_cast(type(1), np.uint8, casting="safe")
Out[13]: False
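To make the docstring concrete: 'same_kind' accepts every 'safe' cast plus downcasts within a kind, so it is the looser of the two rules. A small sketch using dtypes only (no Python scalars):

```python
import numpy as np

# float64 -> float32 may lose precision, so it is not "safe"...
assert not np.can_cast(np.float64, np.float32, casting="safe")
# ...but it stays within the float kind, so "same_kind" allows it.
assert np.can_cast(np.float64, np.float32, casting="same_kind")

# Signed int64 -> unsigned uint8 is rejected under both rules,
# which is the situation a Python int ends up in here.
assert not np.can_cast(np.int64, np.uint8, casting="safe")
assert not np.can_cast(np.int64, np.uint8, casting="same_kind")

# An identical dtype is trivially allowed under both rules.
assert np.can_cast(np.uint8, np.uint8, casting="safe")
```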

buhrmann (Author) commented Jul 2, 2024

Ah yes, sorry: my suggestion was to use the actual fill_value rather than its type, since that is more precise (an int in general cannot be cast to uint8, but 0 in particular can):

In [3]: np.can_cast(type(0), "uint8", casting="same_kind")
Out[3]: False

In [4]: np.can_cast(0, "uint8", casting="same_kind")
Out[4]: True

In [6]: np.can_cast(-1, "uint8", casting="same_kind")
Out[6]: False

The casting strategy "safe" was a distraction. The NumPy docs say that the from_ parameter should be a dtype, dtype specifier, NumPy scalar, or array, but built-in Python scalars seem to be supported too.

buhrmann (Author) commented Jul 2, 2024

Hm, but I just saw this warning in the docs:

Changed in version 2.0: This function does not support Python scalars anymore and does not apply any value-based logic for 0-D arrays and NumPy scalars.

So this is probably a no-go. One workaround is to use np.uint8(0) as the fill_value instead of a Python 0, which I think should work for all input data types.
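Because NumPy 2 drops the value-based logic, the "an int in general cannot be cast to uint8, but 0 in particular can" idea would have to be written out explicitly. One possible sketch, assuming fill_value is a Python int and the input has an integer dtype, is a range check against np.iinfo; `python_int_fits` is a hypothetical name, not a NumPy or scikit-learn function:

```python
import numpy as np


def python_int_fits(value: int, dtype) -> bool:
    """Return True if the Python int `value` is representable in integer `dtype`.

    Hypothetical helper, for illustration only.
    """
    dtype = np.dtype(dtype)
    if dtype.kind not in "iu":
        raise TypeError(f"expected an integer dtype, got {dtype!r}")
    info = np.iinfo(dtype)  # info.min and info.max are plain Python ints
    return info.min <= value <= info.max


print(python_int_fits(0, np.uint8))    # True: 0 is within [0, 255]
print(python_int_fits(-1, np.uint8))   # False: negative values do not fit
print(python_int_fits(1000, np.int8))  # False: int8 covers [-128, 127]
```

Comparing Python ints against the iinfo bounds never creates a NumPy scalar, so no overflow or wraparound can happen during the check itself.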

lesteve (Member) commented Jul 2, 2024

It seems that, currently, it is up to the user to fix this by using fill_value=np.uint8(0), i.e. by modifying your original snippet:

import numpy as np
import pandas as pd
from sklearn import impute

df = pd.Series([0, 1, 2], dtype="uint8").to_frame()
impute.SimpleImputer(strategy="constant", fill_value=np.uint8(0)).fit_transform(df)
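For the generic-pipeline case from the original report, where the input dtype is only known at run time, the NumPy scalar does not have to be hard-coded: a dtype's .type attribute is the matching scalar class. A minimal sketch, using a plain NumPy array in place of the DataFrame and assuming a single dtype across all columns:

```python
import numpy as np

X = np.array([[0], [1], [2]], dtype=np.uint8)

# X.dtype.type is the scalar class for the dtype (here np.uint8),
# so this builds np.uint8(0) without naming uint8 explicitly.
fill_value = X.dtype.type(0)

assert type(fill_value) is np.uint8
# The type-based check from the traceback now passes.
assert np.can_cast(type(fill_value), X.dtype, casting="same_kind")
```

This only picks the scalar type; it does not check that the value actually fits in that dtype.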

I am wondering whether things are good enough or whether there is room for improvement, maybe in the error message when fill_value is a Python scalar? Personally I lean slightly towards "good enough for now".

For completeness, I also noticed that calling np.can_cast with a Python value instead of a type is an error in NumPy 2:

In [1]: import numpy as np

In [2]: np.can_cast(1, np.uint8)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[2], line 1
----> 1 np.can_cast(1, np.uint8)

TypeError: can_cast() does not support Python ints, floats, and complex because the result used to depend on the value.
This change was part of adopting NEP 50, we may explicitly allow them again in the future.

lesteve removed the Bug and Needs Triage (Issue requires triage) labels on Jul 2, 2024

lesteve (Member) commented Jul 3, 2024

It feels like, in the long term, NumPy may support casting from Python scalars again, as mentioned in the error message above, and that doing nothing in the meantime (even if this takes a while) is a reasonable option.

Two potential alternatives as a stop-gap improvement:

  • add a code suggestion to the error message, when fill_value is a Python scalar, showing how to create the NumPy scalar of the right dtype. There may be edge cases: e.g. if fill_value=1.3 and X.dtype is np.int64, we may recommend something that loses information.
  • add some logic to deal with Python scalars and create a NumPy scalar of the right dtype. One idea would be something like this:
    fill_value = np.array(fill_value).astype(X.dtype, casting='same_kind')
    Some care would be needed for edge cases: for example, if you provide fill_value=1.3 and dtype='int64', should it be an error (which is what the previous snippet does) or a warning?
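A rough sketch of the second option, assuming the desired answer to the edge-case question is "raise an error whenever the conversion would change the value". `to_numpy_fill_value` is a hypothetical helper; it casts with astype's default (unsafe) rule and then compares the result, converted back to a Python scalar, with the original value:

```python
import numpy as np


def to_numpy_fill_value(fill_value, dtype):
    """Convert a Python scalar to a NumPy scalar of `dtype`, refusing lossy casts.

    Hypothetical helper, for illustration only.
    """
    dtype = np.dtype(dtype)
    # astype() defaults to casting="unsafe"; information loss is checked below.
    converted = np.array(fill_value).astype(dtype)[()]
    # .item() gives back a Python scalar, so the comparison is plain Python.
    if converted.item() != fill_value:
        raise ValueError(
            f"fill_value={fill_value!r} cannot be represented exactly as "
            f"{dtype!r}; casting it would give {converted.item()!r}"
        )
    return converted


print(repr(to_numpy_fill_value(0, np.uint8)))
for value, dtype in [(1.3, np.int64), (-1, np.uint8), (1000, np.int8)]:
    try:
        to_numpy_fill_value(value, dtype)
    except ValueError as exc:
        print(exc)
```

Under this rule fill_value=2.0 would be accepted for an int64 input, while fill_value=0.1 would be rejected for a float32 input (0.1 is not exactly representable in float32), and NaN would always be rejected because NaN != NaN. Whether those outcomes are desirable is part of the edge-case question above.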

lesteve added the Easy (Well-defined and straightforward way to resolve), Enhancement and help wanted labels on Jul 3, 2024

daustria commented Jul 4, 2024

I was brainstorming potential solutions for this issue and was trying to use @lesteve's second option, but I think there is no clean way to convert Python scalars to a dtype that is not known in advance; here is a relevant Stack Overflow thread. Perhaps I am missing something?

In any case, I think it would make for a nicer user experience if we took either of the two options presented above, and I would be interested in helping to write something.

Edit: I will probably not look at this after all. Anyone who wants to take this issue, feel free.

lesteve (Member) commented Jul 5, 2024

In any case, I think it would make for a nicer user experience if we took either of the two options presented above, and I would be interested in helping to write something.

@daustria, great that you are interested in contributing! I would start with my suggestion and see where that goes.

There may be some tricky edge cases with my current idea; I am guessing this is why NumPy 2's np.can_cast does not support Python scalars for now. One tricky edge case I can think of (that probably does not happen much in practice, but still...):

import numpy as np
# same_kind allows int -> int8, so 1000 (outside [-128, 127]) silently wraps: 1000 - 4*256 = -24
np.array(1000).astype(np.int8, casting='same_kind')
# output is array(-24, dtype=int8)

Some pieces of advice:

  • have a look at the contributing guide
  • you would need to make sure that this does not break existing tests, and write additional tests to make sure the behaviour is the one we want, in particular checking the error message in failure cases

ojoaomorais (Contributor) commented

I will try to work on this.

ojoaomorais (Contributor) commented

* add a code suggestion in the error message if `fill_value` is a Python scalar to create the numpy scalar of the right dtype. There may be edge cases e.g. if `fill_value=1.3` and `X.dtype` is `np.int64`, we may recommend something that loses information

I believe this is the best way to solve this problem. Since NumPy itself no longer accepts Python scalars in np.can_cast, I think trying to solve the problem on our end would be infeasible. What do you think about stating in the error message that fill_value needs to be a NumPy scalar?

lesteve (Member) commented Jul 23, 2024

@ojoaomorais, feel free to have a go at this and open a PR! If we choose the "better error message" route (as opposed to the "automatic conversion" route), I think the following would be nice:

  • for Python scalar input, mention that the input is a Python scalar of type int. Maybe I read too quickly, but I found <class 'int'> a bit confusing and thought it was a weird dtype; the same goes for Python float
  • add a recommendation about a possible solution, something like "a possible solution could be to use fill_value=np.uint8(0)"

As a reminder, the current error message is like this:

ValueError: fill_value=0 (of type <class 'int'>) cannot be cast to the input data that is dtype('uint8'). Make sure that both dtypes are of the same kind.

I am proposing something like (improvements more than welcome of course!):

ValueError: fill_value=0 is a Python int and cannot be cast to the input data, whose dtype is dtype('uint8'). Make sure that both dtypes are of the same kind. One possible solution could be to use a numpy scalar, i.e. fill_value=np.uint8(0)
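A sketch of how the proposed message could be assembled, assuming it is only used after the cast check has already failed. `fill_value_cast_error_message` is a hypothetical helper, not scikit-learn code; it keeps the current wording for inputs that are not plain Python scalars:

```python
import numpy as np


def fill_value_cast_error_message(fill_value, X_dtype):
    """Build the error message proposed above (hypothetical helper).

    Assumes X_dtype.name is also the name of a NumPy scalar type, which holds
    for the common numeric dtypes such as uint8, int64 or float32.
    """
    X_dtype = np.dtype(X_dtype)
    # type(...) rather than isinstance(): np.float64 subclasses Python float.
    if type(fill_value) in (int, float):
        python_type = type(fill_value).__name__
        return (
            f"fill_value={fill_value!r} is a Python {python_type} and cannot be "
            f"cast to the input data, whose dtype is {X_dtype!r}. Make sure that "
            "both dtypes are of the same kind. One possible solution could be to "
            f"use a numpy scalar, i.e. fill_value=np.{X_dtype.name}({fill_value!r})"
        )
    # Anything else (e.g. a NumPy scalar of another kind) keeps the current wording.
    return (
        f"fill_value={fill_value!r} (of type {type(fill_value)!r}) cannot be cast "
        f"to the input data that is {X_dtype!r}. Make sure that both dtypes are "
        "of the same kind."
    )


print(fill_value_cast_error_message(0, np.uint8))
```

As noted earlier in the thread, the suggestion can be lossy: for fill_value=1.3 and an int64 input it would recommend np.int64(1.3), which truncates to 1.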

hoipranav commented Aug 8, 2024

Hey @lesteve, any update on this? If there is an issue with the PR that was made, can I work on this?

ojoaomorais (Contributor) commented Aug 8, 2024

Hey @lesteve, any update on this? If there is an issue with the PR that was made, can I work on this?

I'm working on this in #29564 and will submit a PR ASAP.
