ValidationCurveDisplay can't handle categorical/string parameters #28536

Open

sroener opened this issue Feb 26, 2024 · 6 comments · May be fixed by #31043

Comments


sroener commented Feb 26, 2024

Describe the bug

Hi,

I performed some optimization on a few models implemented via the sklearn API. For fine-tuning, I want to visualize the effect of certain hyperparameters using the ValidationCurveDisplay implementation. For numerical parameters everything works fine, but as soon as categorical parameters (passed as strings) are used, an error is raised.

Steps/Code to Reproduce

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import ValidationCurveDisplay, validation_curve
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=1_000, random_state=0)
logistic_regression = LogisticRegression()

# put categorical values in param space
param_name, param_range = "penalty", ['elasticnet', 'l1', 'l2']
train_scores, test_scores = validation_curve(
     logistic_regression, X, y, param_name=param_name, param_range=param_range
 )
display = ValidationCurveDisplay(
     param_name=param_name, param_range=param_range,
     train_scores=train_scores, test_scores=test_scores, score_name="Score"
 )
display.plot()
plt.show()

Expected Results

The expected result is a validation curve display that separates the values by category, with no errors raised.

Actual Results

---------------------------------------------------------------------------
UFuncTypeError                            Traceback (most recent call last)
Cell In [12], line 19
     12 train_scores, test_scores = validation_curve(
     13      logistic_regression, X, y, param_name=param_name, param_range=param_range
     14  )
     15 display = ValidationCurveDisplay(
     16      param_name=param_name, param_range=param_range,
     17      train_scores=train_scores, test_scores=test_scores, score_name="Score"
     18  )
---> 19 display.plot()
     20 plt.show()

File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/sklearn/model_selection/_plot.py:691, in ValidationCurveDisplay.plot(self, ax, negate_score, score_name, score_type, std_display_style, line_kw, fill_between_kw, errorbar_kw)
    630 def plot(
    631     self,
    632     ax=None,
   (...)
    640     errorbar_kw=None,
    641 ):
    642     """Plot visualization.
    643 
    644     Parameters
   (...)
    689         Object that stores computed values.
    690     """
--> 691     self._plot_curve(
    692         self.param_range,
    693         ax=ax,
    694         negate_score=negate_score,
    695         score_name=score_name,
    696         score_type=score_type,
    697         log_scale="deprecated",
    698         std_display_style=std_display_style,
    699         line_kw=line_kw,
    700         fill_between_kw=fill_between_kw,
    701         errorbar_kw=errorbar_kw,
    702     )
    703     self.ax_.set_xlabel(f"{self.param_name}")
    704     return self

File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/sklearn/model_selection/_plot.py:126, in _BaseCurveDisplay._plot_curve(self, x_data, ax, negate_score, score_name, score_type, log_scale, std_display_style, line_kw, fill_between_kw, errorbar_kw)
    121     xscale = "log" if log_scale else "linear"
    122 else:
    123     # We found that a ratio, smaller or bigger than 5, between the largest and
    124     # smallest gap of the x values is a good indicator to choose between linear
    125     # and log scale.
--> 126     if _interval_max_min_ratio(x_data) > 5:
    127         xscale = "symlog" if x_data.min() <= 0 else "log"
    128     else:

File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/sklearn/utils/_plotting.py:97, in _interval_max_min_ratio(data)
     90 def _interval_max_min_ratio(data):
     91     """Compute the ratio between the largest and smallest inter-point distances.
     92 
     93     A value larger than 5 typically indicates that the parameter range would
     94     better be displayed with a log scale while a linear scale would be more
     95     suitable otherwise.
     96     """
---> 97     diff = np.diff(np.sort(data))
     98     return diff.max() / diff.min()

File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/numpy/lib/function_base.py:1452, in diff(a, n, axis, prepend, append)
   1450 op = not_equal if a.dtype == np.bool_ else subtract
   1451 for _ in range(n):
-> 1452     a = op(a[slice1], a[slice2])
   1454 return a

UFuncTypeError: ufunc 'subtract' did not contain a loop with signature matching types (dtype('<U10'), dtype('<U10')) -> None

Versions

System:
    python: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:35:26) [GCC 10.4.0]
executable: ~/work/miniconda/envs/GC_overhaul_nb/bin/python
   machine: Linux-4.18.0-348.2.1.el8_5.x86_64-x86_64-with-glibc2.28

Python dependencies:
      sklearn: 1.4.1.post1
          pip: 22.2.2
   setuptools: 65.4.1
        numpy: 1.25.2
        scipy: 1.9.2
       Cython: 3.0.0
       pandas: 1.5.0
   matplotlib: 3.8.0
       joblib: 1.2.0
threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: ~/work/miniconda/envs/GC_overhaul_nb/lib/libopenblasp-r0.3.21.so
        version: 0.3.21
threading_layer: pthreads
   architecture: Haswell
    num_threads: 6

       user_api: openmp
   internal_api: openmp
         prefix: libgomp
       filepath: ~/work/miniconda/envs/GC_overhaul_nb/lib/libgomp.so.1.0.0
        version: None
    num_threads: 6
@sroener added the Bug and Needs Triage labels on Feb 26, 2024

ogrisel commented Feb 26, 2024

The error message should be improved to state explicitly that ValidationCurveDisplay is not meant for non-numerical hyper-parameters.

For categorical parameters, I assume we could do a (horizontal?) bar plot instead of line plots (similarly to what we do for categorical features in the PDP display class).

But then the word "curve" in ValidationCurveDisplay would be confusing anyhow. I am not sure we want to support this in scikit-learn. Maybe we should instead provide better tooling to inspect the cv_results_ attribute of a fitted grid search object.
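
For illustration, something along these lines is already possible today by turning cv_results_ into a DataFrame (just a sketch, not an API proposal; the estimator, parameter grid and plotting choices are arbitrary):

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

X, y = make_classification(n_samples=1_000, random_state=0)

# grid search over a categorical parameter; liblinear supports both penalties
grid = GridSearchCV(
    LogisticRegression(solver="liblinear"),
    param_grid={"penalty": ["l1", "l2"]},
    return_train_score=True,
).fit(X, y)

# cv_results_ already contains everything needed for a categorical comparison
results = pd.DataFrame(grid.cv_results_)
results.plot.barh(x="param_penalty", y=["mean_train_score", "mean_test_score"])
plt.xlabel("Score")
plt.show()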

@ogrisel removed the Needs Triage label on Feb 26, 2024

sroener commented Feb 26, 2024

I can see the point that a curve over a small number of categorical values might be suboptimal, but it still conveys information.

I am not sure if (grouped) bar plots showing train and test scores give the same visual effect as the "curves".

A workaround for categorical values seems to be encoding the variables as numerical values (e.g. with LabelEncoder), plotting against the numerical values and then mapping the tick labels back to the original categories.
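
Something along these lines (an untested sketch; I use solver="saga" with an explicit l1_ratio so that all three penalties can actually be fitted):

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import ValidationCurveDisplay, validation_curve

X, y = make_classification(n_samples=1_000, random_state=0)

param_name, param_range = "penalty", ["elasticnet", "l1", "l2"]
train_scores, test_scores = validation_curve(
    LogisticRegression(solver="saga", l1_ratio=0.5, max_iter=5_000),
    X, y, param_name=param_name, param_range=param_range,
)

# plot against integer positions instead of the string categories
codes = np.arange(len(param_range))
display = ValidationCurveDisplay(
    param_name=param_name, param_range=codes,
    train_scores=train_scores, test_scores=test_scores, score_name="Score",
)
display.plot()

# map the tick labels back to the original categories
display.ax_.set_xticks(codes)
display.ax_.set_xticklabels(param_range)
plt.show()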

Additionally, parameters with mixed types also don't seem to work as expected, for example the class_weight parameter:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import ValidationCurveDisplay, validation_curve
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=1_000, random_state=0)
logistic_regression = LogisticRegression()

# put categorical values in param space
param_name, param_range = "class_weight", [{0:0.9, 1:1.1}, 'balanced']
train_scores, test_scores = validation_curve(
     logistic_regression, X, y, param_name=param_name, param_range=param_range
 )
display = ValidationCurveDisplay(
     param_name=param_name, param_range=param_range,
     train_scores=train_scores, test_scores=test_scores, score_name="Score"
 )
display.plot()
plt.show()

which returns:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In [1], line 19
     12 train_scores, test_scores = validation_curve(
     13      logistic_regression, X, y, param_name=param_name, param_range=param_range
     14  )
     15 display = ValidationCurveDisplay(
     16      param_name=param_name, param_range=param_range,
     17      train_scores=train_scores, test_scores=test_scores, score_name="Score"
     18  )
---> 19 display.plot()
     20 plt.show()

File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/sklearn/model_selection/_plot.py:691, in ValidationCurveDisplay.plot(self, ax, negate_score, score_name, score_type, std_display_style, line_kw, fill_between_kw, errorbar_kw)
    630 def plot(
    631     self,
    632     ax=None,
   (...)
    640     errorbar_kw=None,
    641 ):
    642     """Plot visualization.
    643 
    644     Parameters
   (...)
    689         Object that stores computed values.
    690     """
--> 691     self._plot_curve(
    692         self.param_range,
    693         ax=ax,
    694         negate_score=negate_score,
    695         score_name=score_name,
    696         score_type=score_type,
    697         log_scale="deprecated",
    698         std_display_style=std_display_style,
    699         line_kw=line_kw,
    700         fill_between_kw=fill_between_kw,
    701         errorbar_kw=errorbar_kw,
    702     )
    703     self.ax_.set_xlabel(f"{self.param_name}")
    704     return self

File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/sklearn/model_selection/_plot.py:64, in _BaseCurveDisplay._plot_curve(self, x_data, ax, negate_score, score_name, score_type, log_scale, std_display_style, line_kw, fill_between_kw, errorbar_kw)
     61 self.lines_ = []
     62 for line_label, score in scores.items():
     63     self.lines_.append(
---> 64         *ax.plot(
     65             x_data,
     66             score.mean(axis=1),
     67             label=line_label,
     68             **line_kw,
     69         )
     70     )
     71 self.errorbar_ = None
     72 self.fill_between_ = None  # overwritten below by fill_between

File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/matplotlib/axes/_axes.py:1723, in Axes.plot(self, scalex, scaley, data, *args, **kwargs)
   1721 lines = [*self._get_lines(self, *args, data=data, **kwargs)]
   1722 for line in lines:
-> 1723     self.add_line(line)
   1724 if scalex:
   1725     self._request_autoscale_view("x")

File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/matplotlib/axes/_base.py:2309, in _AxesBase.add_line(self, line)
   2306 if line.get_clip_path() is None:
   2307     line.set_clip_path(self.patch)
-> 2309 self._update_line_limits(line)
   2310 if not line.get_label():
   2311     line.set_label(f'_child{len(self._children)}')

File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/matplotlib/axes/_base.py:2332, in _AxesBase._update_line_limits(self, line)
   2328 def _update_line_limits(self, line):
   2329     """
   2330     Figures out the data limit of the given line, updating self.dataLim.
   2331     """
-> 2332     path = line.get_path()
   2333     if path.vertices.size == 0:
   2334         return

File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/matplotlib/lines.py:1032, in Line2D.get_path(self)
   1030 """Return the `~matplotlib.path.Path` associated with this line."""
   1031 if self._invalidy or self._invalidx:
-> 1032     self.recache()
   1033 return self._path

File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/matplotlib/lines.py:669, in Line2D.recache(self, always)
    667 if always or self._invalidx:
    668     xconv = self.convert_xunits(self._xorig)
--> 669     x = _to_unmasked_float_array(xconv).ravel()
    670 else:
    671     x = self._x

File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/matplotlib/cbook/__init__.py:1369, in _to_unmasked_float_array(x)
   1367     return np.ma.asarray(x, float).filled(np.nan)
   1368 else:
-> 1369     return np.asarray(x, float)

TypeError: float() argument must be a string or a real number, not 'dict'

@glemaitre

I am not sure if (grouped) bar plots showing train and test scores give the same visual effect as the "curves"

A curve would be a bad way of visualizing this information, since the line creates a relationship between a category and the two closest encoded categories, while there is no reason to assume any dependence between the categories. On top of that, the ordering of the categories would then also matter.

It is a point that I mentioned here for the PDP. There are alternative plots for that.


ogrisel commented Feb 28, 2024

In particular, curves typically interpolate linearly between observed values. For categorical values, this interpolation is devoid of any meaning.

@sroener Do you have particular categorical hyperparameters with an ordinal structure in mind? If so, can you manually create the plot you are looking for and post it in this discussion thread, so that we can have a more grounded design process w.r.t. whether or not to accept such a feature request?


sroener commented Feb 29, 2024

@glemaitre I see your point and agree that curves are probably not the best way to visualize categories. Nevertheless, connecting categories might be helpful in the context of comparing parameters, because the human eye is quite good at picking up differences in slopes.

@ogrisel I agree that from a data perspective there is not much information in the interpolation, but I think the visual effect can help in interpreting the differences (e.g. diverging trends between train and test scores) more easily. I don't have a particular categorical hyperparameter in mind, especially because it feels like imposing an ordinal structure on one could be very subjective (e.g. unbalanced < balanced class weights). Nevertheless, I would suggest a point plot, which would capture the categorical structure of the hyperparameter and give some visual cues via the connecting lines for test and train performance.

As an example, see the Seaborn tutorial on categorical data and the corresponding documentation of pointplot.
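
An untested sketch of such a point plot, reusing train_scores, test_scores and param_range from the first reproducer above (the reshaping into a long-format DataFrame is just one way to do it):

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# reshape the (n_params, n_cv_folds) score arrays into long format
records = [
    {"penalty": value, "split": split, "score": score}
    for split, scores in [("train", train_scores), ("test", test_scores)]
    for value, fold_scores in zip(param_range, scores)
    for score in fold_scores
]
df = pd.DataFrame(records)

# one point (mean and confidence interval) per category, connected per split
sns.pointplot(data=df, x="penalty", y="score", hue="split")
plt.show()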


kaekkr commented Feb 25, 2025

Hello @glemaitre,

I would like to take ownership of this enhancement to support categorical parameters in ValidationCurveDisplay. I have started working on this issue and plan to modify the constructor and plotting methods to ensure compatibility with both numerical and categorical parameters.

I will follow the contribution guidelines and submit a pull request once the implementation is complete and tested.

Thank you, and I’m looking forward to contributing!
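
A very rough sketch of the kind of check I have in mind (purely hypothetical, not the actual scikit-learn code): detect whether the parameter range is numeric, and otherwise fall back to plotting against positions with the categories as tick labels.

import numbers

import numpy as np

def _is_numeric_param_range(param_range):
    """Return True if every value in the range is a real number (hypothetical helper)."""
    return all(isinstance(value, numbers.Real) for value in param_range)

param_range = ["elasticnet", "l1", "l2"]  # example categorical range
if _is_numeric_param_range(param_range):
    x_data = np.asarray(param_range)      # keep today's numeric behaviour
else:
    # plot against positions; the original categories would become tick labels
    x_data = np.arange(len(param_range))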

Projects
Status: Needs decision
