sklearn.utils._param_validation._InstancesOf is insufficient for numpy data types #23599

Closed
@Diadochokinetic

Describe the bug

NumPy data types can be specified in different ways. Although they describe the same data type, isinstance() yields different results depending on how the data type was specified.

>>> import numpy as np
>>> isinstance(np.float64, type)
True
>>> isinstance(np.float64, np.dtype)
False
>>> isinstance(np.dtype('float64'), np.dtype)
True
>>> isinstance(np.dtype('float64'), type)
False
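
As a small sketch of the point above (the variable names are illustrative and not taken from scikit-learn), both spellings normalise to the same dtype through np.dtype(), and passing a tuple of types to isinstance() accepts either form:

```python
import numpy as np

scalar_type = np.float64         # a Python class, so an instance of `type`
dtype_obj = np.dtype("float64")  # an instance of np.dtype

# Both spellings describe the same data type once normalised.
same_dtype = np.dtype(scalar_type) == dtype_obj

# isinstance() accepts a tuple of types, so one check covers both forms.
accepted_types = (type, np.dtype)
both_accepted = all(
    isinstance(value, accepted_types) for value in (scalar_type, dtype_obj)
)
```

This only illustrates the behaviour of isinstance(); how the constraints are expressed inside the validation code is a separate question.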

Therefore, if an object accepts NumPy data types as a parameter, the parameter constraints need to be [type, np.dtype]. Currently this isn't supported, because sklearn.utils._param_validation._InstancesOf initializes with an argument named after the built-in type:

    def __init__(self, type):
        self.type = type

This results in errors when trying to implement parameter_constraints in objects that take NumPy data types as parameters, e.g. in #23579.

I therefore propose renaming the attribute self.type to self.param_type and initializing it with the actual type given in the list of parameter constraints.
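
A minimal sketch of what this proposal could look like, assuming a simplified stand-in class. InstancesOf, satisfies_any and dtype_constraints are hypothetical names for illustration and are not the scikit-learn implementation:

```python
import numpy as np


class InstancesOf:
    """Hypothetical stand-in for _InstancesOf, not the scikit-learn code."""

    def __init__(self, param_type):
        # Named param_type so the built-in `type` is not shadowed.
        self.param_type = param_type

    def is_satisfied_by(self, val):
        return isinstance(val, self.param_type)


def satisfies_any(val, constraints):
    """Return True if val matches at least one of the constraints."""
    return any(c.is_satisfied_by(val) for c in constraints)


# One constraint per accepted way of specifying a NumPy data type.
dtype_constraints = [InstancesOf(type), InstancesOf(np.dtype)]
```

With these constraints, both np.float64 and np.dtype('float64') pass the check, while a value that is neither a class nor a dtype instance, such as the string 'float64', is rejected.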

Steps/Code to Reproduce

See #23579

Expected Results

No error is thrown.

Actual Results

ValueError: The 'dtype' parameter of OneHotEncoder must be an instance of 'type'. Got dtype('float64') instead.

Versions

System:
    python: 3.8.13 (default, Mar 28 2022, 11:38:47)  [GCC 7.5.0]
executable: /home/fabian/anaconda3/envs/sklearn-dev/bin/python3
   machine: Linux-5.13.0-48-generic-x86_64-with-glibc2.17

Python dependencies:
      sklearn: 1.2.dev0
          pip: 21.2.4
   setuptools: 61.2.0
        numpy: 1.17.3
        scipy: 1.3.2
       Cython: None
       pandas: None
   matplotlib: None
       joblib: 1.0.0
threadpoolctl: 2.0.0

Built with OpenMP: True

threadpoolctl info:
       filepath: /home/fabian/anaconda3/envs/sklearn-dev/lib/libgomp.so.1.0.0
         prefix: libgomp
       user_api: openmp
   internal_api: openmp
        version: None
    num_threads: 12

       filepath: /home/fabian/anaconda3/envs/sklearn-dev/lib/python3.8/site-packages/numpy/.libs/libopenblasp-r0-34a18dc3.3.7.so
         prefix: libopenblas
       user_api: blas
   internal_api: openblas
        version: 0.3.7
    num_threads: 12
threading_layer: pthreads

       filepath: /home/fabian/anaconda3/envs/sklearn-dev/lib/python3.8/site-packages/scipy/.libs/libopenblasp-r0-2ecf47d5.3.7.dev.so
         prefix: libopenblas
       user_api: blas
   internal_api: openblas
        version: 0.3.7.dev
    num_threads: 12
threading_layer: pthreads
