
ValidationCurveDisplay can't handle categorical/string parameters #28536

Open
@sroener

Description


Describe the bug

Hi,

I performed hyperparameter optimization on a few models implemented via the sklearn API. For fine-tuning, I want to visualize the effect of certain hyperparameters using the ValidationCurveDisplay implementation. For numerical parameters everything works fine, but as soon as categorical parameters (passed as strings) are used, an error is raised.

Steps/Code to Reproduce

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import ValidationCurveDisplay, validation_curve
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=1_000, random_state=0)
logistic_regression = LogisticRegression()

# put categorical values in param space
param_name, param_range = "penalty", ['elasticnet', 'l1', 'l2']
train_scores, test_scores = validation_curve(
    logistic_regression, X, y, param_name=param_name, param_range=param_range
)
display = ValidationCurveDisplay(
    param_name=param_name, param_range=param_range,
    train_scores=train_scores, test_scores=test_scores, score_name="Score"
)
display.plot()
plt.show()

Expected Results

The expected result is a validation curve that places each categorical value at its own position on the x-axis, with no error raised.
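Until categorical parameter ranges are supported, one possible workaround (not part of the report above, just a sketch) is to compute the scores with validation_curve, plot them against integer positions, and relabel the ticks with the category names so the x-axis logic never sees strings. This sketch restricts the range to ["l1", "l2"] with solver="liblinear" so that every fit succeeds:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, safe for headless runs
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import validation_curve

X, y = make_classification(n_samples=1_000, random_state=0)

# Categorical parameter values, restricted to combos liblinear supports.
param_name, param_range = "penalty", ["l1", "l2"]
train_scores, test_scores = validation_curve(
    LogisticRegression(solver="liblinear"), X, y,
    param_name=param_name, param_range=param_range,
)

# Plot against integer stand-ins, then restore the categorical labels.
positions = np.arange(len(param_range))
fig, ax = plt.subplots()
for scores, label in [(train_scores, "Train"), (test_scores, "Test")]:
    ax.errorbar(positions, scores.mean(axis=1), yerr=scores.std(axis=1), label=label)
ax.set_xticks(positions)
ax.set_xticklabels(param_range)
ax.set_xlabel(param_name)
ax.set_ylabel("Score")
ax.legend()
```

This sidesteps the automatic log/linear scale detection entirely, since the x data handed to matplotlib is numeric.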

Actual Results

---------------------------------------------------------------------------
UFuncTypeError                            Traceback (most recent call last)
Cell In [12], line 19
     12 train_scores, test_scores = validation_curve(
     13      logistic_regression, X, y, param_name=param_name, param_range=param_range
     14  )
     15 display = ValidationCurveDisplay(
     16      param_name=param_name, param_range=param_range,
     17      train_scores=train_scores, test_scores=test_scores, score_name="Score"
     18  )
---> 19 display.plot()
     20 plt.show()

File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/sklearn/model_selection/_plot.py:691, in ValidationCurveDisplay.plot(self, ax, negate_score, score_name, score_type, std_display_style, line_kw, fill_between_kw, errorbar_kw)
    630 def plot(
    631     self,
    632     ax=None,
   (...)
    640     errorbar_kw=None,
    641 ):
    642     """Plot visualization.
    643 
    644     Parameters
   (...)
    689         Object that stores computed values.
    690     """
--> 691     self._plot_curve(
    692         self.param_range,
    693         ax=ax,
    694         negate_score=negate_score,
    695         score_name=score_name,
    696         score_type=score_type,
    697         log_scale="deprecated",
    698         std_display_style=std_display_style,
    699         line_kw=line_kw,
    700         fill_between_kw=fill_between_kw,
    701         errorbar_kw=errorbar_kw,
    702     )
    703     self.ax_.set_xlabel(f"{self.param_name}")
    704     return self

File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/sklearn/model_selection/_plot.py:126, in _BaseCurveDisplay._plot_curve(self, x_data, ax, negate_score, score_name, score_type, log_scale, std_display_style, line_kw, fill_between_kw, errorbar_kw)
    121     xscale = "log" if log_scale else "linear"
    122 else:
    123     # We found that a ratio, smaller or bigger than 5, between the largest and
    124     # smallest gap of the x values is a good indicator to choose between linear
    125     # and log scale.
--> 126     if _interval_max_min_ratio(x_data) > 5:
    127         xscale = "symlog" if x_data.min() <= 0 else "log"
    128     else:

File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/sklearn/utils/_plotting.py:97, in _interval_max_min_ratio(data)
     90 def _interval_max_min_ratio(data):
     91     """Compute the ratio between the largest and smallest inter-point distances.
     92 
     93     A value larger than 5 typically indicates that the parameter range would
     94     better be displayed with a log scale while a linear scale would be more
     95     suitable otherwise.
     96     """
---> 97     diff = np.diff(np.sort(data))
     98     return diff.max() / diff.min()

File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/numpy/lib/function_base.py:1452, in diff(a, n, axis, prepend, append)
   1450 op = not_equal if a.dtype == np.bool_ else subtract
   1451 for _ in range(n):
-> 1452     a = op(a[slice1], a[slice2])
   1454 return a

UFuncTypeError: ufunc 'subtract' did not contain a loop with signature matching types (dtype('<U10'), dtype('<U10')) -> None
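The failure reduces to the np.diff call in _interval_max_min_ratio: differencing an array requires np.subtract, which has no loop for unicode dtypes. A minimal numpy-only reproduction:

```python
import numpy as np

# Same dtype ('<U10') as the param_range in the traceback above.
param_range = np.asarray(["elasticnet", "l1", "l2"])

try:
    np.diff(np.sort(param_range))  # the call made by _interval_max_min_ratio
    failed = False
except TypeError:  # UFuncTypeError is a subclass of TypeError
    failed = True

print(failed)  # True: 'subtract' has no loop for string dtypes
```

A fix would presumably need ValidationCurveDisplay to detect non-numeric param ranges before attempting the scale heuristic.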

Versions

System:
    python: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:35:26) [GCC 10.4.0]
executable: ~/work/miniconda/envs/GC_overhaul_nb/bin/python
   machine: Linux-4.18.0-348.2.1.el8_5.x86_64-x86_64-with-glibc2.28

Python dependencies:
      sklearn: 1.4.1.post1
          pip: 22.2.2
   setuptools: 65.4.1
        numpy: 1.25.2
        scipy: 1.9.2
       Cython: 3.0.0
       pandas: 1.5.0
   matplotlib: 3.8.0
       joblib: 1.2.0
threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: ~/work/miniconda/envs/GC_overhaul_nb/lib/libopenblasp-r0.3.21.so
        version: 0.3.21
threading_layer: pthreads
   architecture: Haswell
    num_threads: 6

       user_api: openmp
   internal_api: openmp
         prefix: libgomp
       filepath: ~/work/miniconda/envs/GC_overhaul_nb/lib/libgomp.so.1.0.0
        version: None
    num_threads: 6

Status: Needs decision