ValidationCurveDisplay can't handle categorical/string parameters #28536
Comments
The error message should be improved to state explicitly that
For categorical parameters, I assume we could do a (horizontal?) bar plot instead of line plots (similarly to what we do for categorical features in the PDP display class). But then the word "curve" in
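The horizontal bar plot idea can be sketched directly on the arrays returned by `validation_curve`, bypassing `ValidationCurveDisplay`. This is a minimal sketch, not a scikit-learn API: `plot_categorical_validation_bars` is a hypothetical helper, and the three `class_weight` values are chosen only for illustration. Error bars show the standard deviation across CV folds.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import validation_curve


def plot_categorical_validation_bars(param_name, param_range, train_scores,
                                     test_scores, ax=None):
    """Hypothetical helper: one group of horizontal bars per parameter value."""
    if ax is None:
        _, ax = plt.subplots()
    labels = [str(value) for value in param_range]  # works for dicts, strings, None
    positions = np.arange(len(labels))
    bar_height = 0.4
    offsets = (-bar_height / 2, bar_height / 2)
    series = (("Train", train_scores), ("Test", test_scores))
    for offset, (label, scores) in zip(offsets, series):
        # Mean score across CV folds, with the fold-to-fold std as error bar.
        ax.barh(
            positions + offset,
            scores.mean(axis=1),
            height=bar_height,
            xerr=scores.std(axis=1),
            label=label,
        )
    ax.set_yticks(positions)
    ax.set_yticklabels(labels)
    ax.set_xlabel("Score")
    ax.set_ylabel(param_name)
    ax.legend()
    return ax


X, y = make_classification(n_samples=1_000, random_state=0)
param_name = "class_weight"
param_range = [{0: 0.9, 1: 1.1}, "balanced", None]  # illustrative values
train_scores, test_scores = validation_curve(
    LogisticRegression(), X, y, param_name=param_name, param_range=param_range
)
ax = plot_categorical_validation_bars(param_name, param_range, train_scores, test_scores)
```

Each parameter value gets its own group of bars, so no ordering or interpolation between categories is implied.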
I can see the point that a curve for a small number of categorical values might be suboptimal, but it still conveys information. I am not sure if (grouped) bar plots showing train and test scores give the same visual effect as the "curves". A workaround for categorical values seems to be encoding the values numerically (e.g. with `LabelEncoder`), plotting the numerical values and remapping the ticks to the original categories. Additionally, mixed-type params also do not seem to work as expected. For example:

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import ValidationCurveDisplay, validation_curve
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=1_000, random_state=0)
logistic_regression = LogisticRegression()

# put categorical values in param space
param_name, param_range = "class_weight", [{0: 0.9, 1: 1.1}, 'balanced']
train_scores, test_scores = validation_curve(
    logistic_regression, X, y, param_name=param_name, param_range=param_range
)
display = ValidationCurveDisplay(
    param_name=param_name, param_range=param_range,
    train_scores=train_scores, test_scores=test_scores, score_name="Score"
)
display.plot()
plt.show()
```

which returns:

---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In [1], line 19
12 train_scores, test_scores = validation_curve(
13 logistic_regression, X, y, param_name=param_name, param_range=param_range
14 )
15 display = ValidationCurveDisplay(
16 param_name=param_name, param_range=param_range,
17 train_scores=train_scores, test_scores=test_scores, score_name="Score"
18 )
---> 19 display.plot()
20 plt.show()
File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/sklearn/model_selection/_plot.py:691, in ValidationCurveDisplay.plot(self, ax, negate_score, score_name, score_type, std_display_style, line_kw, fill_between_kw, errorbar_kw)
630 def plot(
631 self,
632 ax=None,
(...)
640 errorbar_kw=None,
641 ):
642 """Plot visualization.
643
644 Parameters
(...)
689 Object that stores computed values.
690 """
--> 691 self._plot_curve(
692 self.param_range,
693 ax=ax,
694 negate_score=negate_score,
695 score_name=score_name,
696 score_type=score_type,
697 log_scale="deprecated",
698 std_display_style=std_display_style,
699 line_kw=line_kw,
700 fill_between_kw=fill_between_kw,
701 errorbar_kw=errorbar_kw,
702 )
703 self.ax_.set_xlabel(f"{self.param_name}")
704 return self
File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/sklearn/model_selection/_plot.py:64, in _BaseCurveDisplay._plot_curve(self, x_data, ax, negate_score, score_name, score_type, log_scale, std_display_style, line_kw, fill_between_kw, errorbar_kw)
61 self.lines_ = []
62 for line_label, score in scores.items():
63 self.lines_.append(
---> 64 *ax.plot(
65 x_data,
66 score.mean(axis=1),
67 label=line_label,
68 **line_kw,
69 )
70 )
71 self.errorbar_ = None
72 self.fill_between_ = None # overwritten below by fill_between
File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/matplotlib/axes/_axes.py:1723, in Axes.plot(self, scalex, scaley, data, *args, **kwargs)
1721 lines = [*self._get_lines(self, *args, data=data, **kwargs)]
1722 for line in lines:
-> 1723 self.add_line(line)
1724 if scalex:
1725 self._request_autoscale_view("x")
File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/matplotlib/axes/_base.py:2309, in _AxesBase.add_line(self, line)
2306 if line.get_clip_path() is None:
2307 line.set_clip_path(self.patch)
-> 2309 self._update_line_limits(line)
2310 if not line.get_label():
2311 line.set_label(f'_child{len(self._children)}')
File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/matplotlib/axes/_base.py:2332, in _AxesBase._update_line_limits(self, line)
2328 def _update_line_limits(self, line):
2329 """
2330 Figures out the data limit of the given line, updating self.dataLim.
2331 """
-> 2332 path = line.get_path()
2333 if path.vertices.size == 0:
2334 return
File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/matplotlib/lines.py:1032, in Line2D.get_path(self)
1030 """Return the `~matplotlib.path.Path` associated with this line."""
1031 if self._invalidy or self._invalidx:
-> 1032 self.recache()
1033 return self._path
File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/matplotlib/lines.py:669, in Line2D.recache(self, always)
667 if always or self._invalidx:
668 xconv = self.convert_xunits(self._xorig)
--> 669 x = _to_unmasked_float_array(xconv).ravel()
670 else:
671 x = self._x
File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/matplotlib/cbook/__init__.py:1369, in _to_unmasked_float_array(x)
1367 return np.ma.asarray(x, float).filled(np.nan)
1368 else:
-> 1369 return np.asarray(x, float)
TypeError: float() argument must be a string or a real number, not 'dict'
A curve would be a poor way to visualize this information, since the line creates a relationship between a category and the two closest encoded categories, while there is no reason for the categories to depend on each other. On top of that, the ordering would then matter as well. This is a point that I mentioned here for the PDP. There are alternative plots for that.
In particular, curves typically interpolate linearly between observed values. For categorical values, this interpolation is devoid of any meaning. @sroener Do you have particular categorical hyperparameters with an ordinal structure in mind? If so, could you manually make the plot you are looking for and post it in this discussion thread, so that we can ground the design discussion about whether to accept this feature request?
@glemaitre I see your point and agree that curves are probably not the best way to visualize categories. Nevertheless, connecting categories might be helpful when comparing parameters, because the human eye is quite good at picking up differences in slopes. @ogrisel I agree that from a data perspective the interpolation carries little information, but I think the visual effect can make differences (e.g. diverging trends between train and test scores) easier to interpret. I don't have a particular categorical hyperparameter in mind, especially because imposing an ordinal structure on it could be very subjective (e.g. unbalanced < balanced class weights). Nevertheless, I would suggest a point plot, which captures the categorical structure of the hyperparameter while giving visual cues for train and test performance through the connecting lines. As an example, see the Seaborn tutorial on categorical data and the corresponding documentation for point plots.
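A point plot like the one suggested here can be drawn with plain matplotlib via `Axes.errorbar`, which avoids a seaborn dependency. This is a hedged sketch: `categorical_point_plot` is a hypothetical helper, and the `solver` values are only an example of a string-valued hyperparameter. Categories are placed at integer positions in the order given, so the connecting lines depend on that (arbitrary) order.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import validation_curve


def categorical_point_plot(param_name, param_range, train_scores, test_scores, ax=None):
    """Hypothetical helper: mean score per category with std error bars, joined by lines."""
    if ax is None:
        _, ax = plt.subplots()
    positions = np.arange(len(param_range))
    for label, scores in (("Train", train_scores), ("Test", test_scores)):
        # Markers show the point estimate; the connecting line is only a visual cue.
        ax.errorbar(
            positions,
            scores.mean(axis=1),
            yerr=scores.std(axis=1),
            marker="o",
            capsize=4,
            label=label,
        )
    ax.set_xticks(positions)
    ax.set_xticklabels([str(value) for value in param_range])
    ax.set_xlabel(param_name)
    ax.set_ylabel("Score")
    ax.legend()
    return ax


X, y = make_classification(n_samples=1_000, random_state=0)
solvers = ["lbfgs", "liblinear", "newton-cg"]  # example string-valued parameter
train_scores, test_scores = validation_curve(
    LogisticRegression(), X, y, param_name="solver", param_range=solvers
)
ax = categorical_point_plot("solver", solvers, train_scores, test_scores)
```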
Hello @glemaitre, I would like to take ownership of this enhancement to support categorical parameters in ValidationCurveDisplay. I have started working on this issue and plan to modify the constructor and plotting methods to ensure compatibility with both numerical and categorical parameters. I will follow the contribution guidelines and submit a pull request once the implementation is complete and tested. Thank you, and I’m looking forward to contributing!
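One possible shape for such a change, sketched under assumptions: a helper that leaves numeric ranges untouched and maps any other values (strings, dicts, `None`, or a mix) to integer positions with string tick labels. `encode_param_range` is hypothetical and not part of scikit-learn; treating booleans as categorical is a design choice of this sketch.

```python
import numbers

import numpy as np


def encode_param_range(param_range):
    """Hypothetical helper: return (x_positions, tick_labels) for plotting.

    Purely numeric ranges are returned as floats with tick_labels=None, so
    existing numeric behaviour is unchanged. Anything else is mapped to
    positions 0..n-1 and labelled with str(value).
    """
    values = list(param_range)
    # bool is a subclass of int, so exclude it explicitly.
    is_numeric = all(
        isinstance(v, numbers.Real) and not isinstance(v, bool) for v in values
    )
    if is_numeric:
        return np.asarray(values, dtype=float), None
    return np.arange(len(values)), [str(v) for v in values]


x, labels = encode_param_range([{0: 0.9, 1: 1.1}, "balanced"])
```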
Describe the bug
Hi,
I performed some optimization on a few models implemented via the sklearn API. For fine-tuning, I want to visualize the effect of certain hyperparameters using the `ValidationCurveDisplay` implementation. For numerical parameters, everything works fine. Unfortunately, as soon as categorical parameters (passed as strings) are used, an error is raised.
Steps/Code to Reproduce
Expected Results
The expected result is a validation curve display that separates the values by category, without raising an error.
Actual Results
Versions