Describe the bug
Hi,
I performed hyperparameter optimization on a few models implemented with the scikit-learn API. For fine-tuning, I want to visualize the effect of individual hyperparameters using `ValidationCurveDisplay`. This works fine for numerical parameters, but as soon as a categorical parameter (values passed as strings) is used, `ValidationCurveDisplay.plot` raises an error.
Steps/Code to Reproduce
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import ValidationCurveDisplay, validation_curve
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=1_000, random_state=0)
# "saga" supports all three penalties below. With the default "lbfgs" solver,
# the "elasticnet" and "l1" fits fail and their scores become NaN.
# l1_ratio is only used when penalty="elasticnet".
logistic_regression = LogisticRegression(solver="saga", l1_ratio=0.5, max_iter=1_000)

# put categorical values in param space
param_name, param_range = "penalty", ["elasticnet", "l1", "l2"]
train_scores, test_scores = validation_curve(
    logistic_regression, X, y, param_name=param_name, param_range=param_range
)
display = ValidationCurveDisplay(
    param_name=param_name, param_range=param_range,
    train_scores=train_scores, test_scores=test_scores, score_name="Score"
)
display.plot()  # raises UFuncTypeError
plt.show()
```
Expected Results
A validation curve is plotted with one x-axis position per category (`elasticnet`, `l1`, `l2`), and no error is raised.
Actual Results
```
---------------------------------------------------------------------------
UFuncTypeError Traceback (most recent call last)
Cell In [12], line 19
12 train_scores, test_scores = validation_curve(
13 logistic_regression, X, y, param_name=param_name, param_range=param_range
14 )
15 display = ValidationCurveDisplay(
16 param_name=param_name, param_range=param_range,
17 train_scores=train_scores, test_scores=test_scores, score_name="Score"
18 )
---> 19 display.plot()
20 plt.show()
File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/sklearn/model_selection/_plot.py:691, in ValidationCurveDisplay.plot(self, ax, negate_score, score_name, score_type, std_display_style, line_kw, fill_between_kw, errorbar_kw)
630 def plot(
631 self,
632 ax=None,
(...)
640 errorbar_kw=None,
641 ):
642 """Plot visualization.
643
644 Parameters
(...)
689 Object that stores computed values.
690 """
--> 691 self._plot_curve(
692 self.param_range,
693 ax=ax,
694 negate_score=negate_score,
695 score_name=score_name,
696 score_type=score_type,
697 log_scale="deprecated",
698 std_display_style=std_display_style,
699 line_kw=line_kw,
700 fill_between_kw=fill_between_kw,
701 errorbar_kw=errorbar_kw,
702 )
703 self.ax_.set_xlabel(f"{self.param_name}")
704 return self
File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/sklearn/model_selection/_plot.py:126, in _BaseCurveDisplay._plot_curve(self, x_data, ax, negate_score, score_name, score_type, log_scale, std_display_style, line_kw, fill_between_kw, errorbar_kw)
121 xscale = "log" if log_scale else "linear"
122 else:
123 # We found that a ratio, smaller or bigger than 5, between the largest and
124 # smallest gap of the x values is a good indicator to choose between linear
125 # and log scale.
--> 126 if _interval_max_min_ratio(x_data) > 5:
127 xscale = "symlog" if x_data.min() <= 0 else "log"
128 else:
File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/sklearn/utils/_plotting.py:97, in _interval_max_min_ratio(data)
90 def _interval_max_min_ratio(data):
91 """Compute the ratio between the largest and smallest inter-point distances.
92
93 A value larger than 5 typically indicates that the parameter range would
94 better be displayed with a log scale while a linear scale would be more
95 suitable otherwise.
96 """
---> 97 diff = np.diff(np.sort(data))
98 return diff.max() / diff.min()
File ~/work/miniconda/envs/GC_overhaul_nb/lib/python3.10/site-packages/numpy/lib/function_base.py:1452, in diff(a, n, axis, prepend, append)
1450 op = not_equal if a.dtype == np.bool_ else subtract
1451 for _ in range(n):
-> 1452 a = op(a[slice1], a[slice2])
1454 return a
UFuncTypeError: ufunc 'subtract' did not contain a loop with signature matching types (dtype('<U10'), dtype('<U10')) -> None
```
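The last frames point to the root cause: `_interval_max_min_ratio` calls `np.diff` on the sorted `param_range`, and NumPy has no `subtract` loop for string dtypes such as `<U10`. The snippet below reproduces this with NumPy alone and sketches one possible guard. `interval_max_min_ratio` is a copy of the helper shown in the traceback; `choose_xscale` is a hypothetical function (not part of scikit-learn) that mirrors the scale choice in `_plot_curve` but only applies the numeric heuristic when the range is numeric.

```python
import numpy as np


def interval_max_min_ratio(data):
    """Copy of the helper in the traceback: largest / smallest gap of sorted values."""
    diff = np.diff(np.sort(data))
    return diff.max() / diff.min()


def choose_xscale(param_range):
    """Hypothetical guard (not scikit-learn code) around the scale heuristic."""
    x = np.asarray(param_range)
    if not np.issubdtype(x.dtype, np.number):
        # Strings cannot be subtracted. Matplotlib places string x-values on a
        # categorical axis, which uses a linear scale, so skip the heuristic.
        return "linear"
    if interval_max_min_ratio(x) > 5:
        return "symlog" if x.min() <= 0 else "log"
    return "linear"


# Root cause: subtracting string entries fails inside np.diff.
root_cause_error = None
try:
    interval_max_min_ratio(np.array(["elasticnet", "l1", "l2"]))
except TypeError as exc:  # NumPy's UFuncTypeError subclasses TypeError
    root_cause_error = exc

print(type(root_cause_error).__name__)
print(choose_xscale(["elasticnet", "l1", "l2"]))  # linear
print(choose_xscale([0.001, 0.01, 0.1, 1]))  # log
```

A real fix would also need to confirm that the display's `fill_between` and `errorbar` calls accept string x-values; this sketch only covers the scale selection.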
Versions
```
System:
python: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:35:26) [GCC 10.4.0]
executable: ~/work/miniconda/envs/GC_overhaul_nb/bin/python
machine: Linux-4.18.0-348.2.1.el8_5.x86_64-x86_64-with-glibc2.28
Python dependencies:
sklearn: 1.4.1.post1
pip: 22.2.2
setuptools: 65.4.1
numpy: 1.25.2
scipy: 1.9.2
Cython: 3.0.0
pandas: 1.5.0
matplotlib: 3.8.0
joblib: 1.2.0
threadpoolctl: 3.1.0
Built with OpenMP: True
threadpoolctl info:
user_api: blas
internal_api: openblas
prefix: libopenblas
filepath: ~/work/miniconda/envs/GC_overhaul_nb/lib/libopenblasp-r0.3.21.so
version: 0.3.21
threading_layer: pthreads
architecture: Haswell
num_threads: 6
user_api: openmp
internal_api: openmp
prefix: libgomp
filepath: ~/work/miniconda/envs/GC_overhaul_nb/lib/libgomp.so.1.0.0
version: None
num_threads: 6
```