Skip to content

ENH Improve ROC curves visualization and add option to plot chance level #25972

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,11 @@ Changelog
curves.
:pr:`24668` by :user:`dberenbaum`.

- |Enhancement| :class:`RocCurveDisplay` now plots the ROC curve with both axes
limited to [0, 1] and a loosely dotted frame. There is also an additional
parameter `plot_chance_level` to determine whether to plot the chance level.
:pr:`25972` by :user:`Yao Xiao <Charlie-XIAO>`.

- |Fix| :func:`log_loss` raises a warning if the values of the parameter `y_pred` are
not normalized, instead of actually normalizing them in the metric. Starting from
1.5 this will raise an error. :pr:`25299` by :user:`Omar Salman <OmarManzoor`.
Expand Down
37 changes: 35 additions & 2 deletions sklearn/metrics/_plot/roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class RocCurveDisplay:
line_ : matplotlib Artist
ROC Curve.

chance_level_ : matplotlib Artist
The chance level line or None if the chance level is not plotted.

ax_ : matplotlib Axes
Axes with ROC Curve.

Expand Down Expand Up @@ -81,7 +84,7 @@ def __init__(self, *, fpr, tpr, roc_auc=None, estimator_name=None, pos_label=Non
self.roc_auc = roc_auc
self.pos_label = pos_label

def plot(self, ax=None, *, name=None, **kwargs):
def plot(self, ax=None, *, name=None, plot_chance_level=True, **kwargs):
"""Plot visualization.

Extra keyword arguments will be passed to matplotlib's ``plot``.
Expand All @@ -96,6 +99,9 @@ def plot(self, ax=None, *, name=None, **kwargs):
Name of ROC Curve for labeling. If `None`, use `estimator_name` if
not `None`, otherwise no labeling is shown.

plot_chance_level : bool, default=True
Whether to plot the chance level.

**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.

Expand Down Expand Up @@ -123,6 +129,24 @@ def plot(self, ax=None, *, name=None, **kwargs):
if ax is None:
fig, ax = plt.subplots()

# Set limits of axes to [0, 1] and fix aspect ratio to squared
ax.set_xlim((0, 1))
ax.set_ylim((0, 1))
ax.set_aspect(1)
Comment on lines +132 to +135
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's move this change in another PR. We will need an additional entry in the changelog since we are fixing/improving the rendering.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be shared with PR and ROC curve


# Plot the frame in dotted line, so that the curve can be
# seen better when values are close to 0 or 1
for s in ["right", "left", "top", "bottom"]:
ax.spines[s].set_linestyle((0, (1, 5)))
ax.spines[s].set_linewidth(0.5)
Comment on lines +137 to +141
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here. We can postpone the despining. It could also be shared between different type of plots. And we should be able to control it via some keywords.


if plot_chance_level:
(self.chance_level_,) = ax.plot(
(0, 1), (0, 1), linestyle="dotted", label="Chance level"
)
else:
self.chance_level_ = None

(self.line_,) = ax.plot(self.fpr, self.tpr, **line_kwargs)
info_pos_label = (
f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
Expand Down Expand Up @@ -152,6 +176,7 @@ def from_estimator(
pos_label=None,
name=None,
ax=None,
plot_chance_level=True,
**kwargs,
):
"""Create a ROC Curve display from an estimator.
Expand Down Expand Up @@ -195,6 +220,9 @@ def from_estimator(
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is created.

plot_chance_level : bool, default=True
Whether to plot the chance level.

**kwargs : dict
Comment on lines +225 to 226
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add the documnetation for chance_level_kwargs

Keyword arguments to be passed to matplotlib's `plot`.

Expand Down Expand Up @@ -245,6 +273,7 @@ def from_estimator(
name=name,
ax=ax,
pos_label=pos_label,
plot_chance_level=plot_chance_level,
**kwargs,
)

Expand All @@ -259,6 +288,7 @@ def from_predictions(
pos_label=None,
name=None,
ax=None,
plot_chance_level=True,
**kwargs,
):
"""Plot ROC curve given the true and predicted values.
Expand Down Expand Up @@ -298,6 +328,9 @@ def from_predictions(
Axes object to plot on. If `None`, a new figure and axes is
created.

plot_chance_level : bool, default=True
Whether to plot the chance level.

**kwargs : dict
Additional keywords arguments passed to matplotlib `plot` function.

Expand Down Expand Up @@ -348,4 +381,4 @@ def from_predictions(
fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=name, pos_label=pos_label
)

return viz.plot(ax=ax, name=name, **kwargs)
return viz.plot(ax=ax, name=name, plot_chance_level=plot_chance_level, **kwargs)
20 changes: 20 additions & 0 deletions sklearn/metrics/_plot/tests/test_roc_curve_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def data_binary(data):
@pytest.mark.parametrize("with_sample_weight", [True, False])
@pytest.mark.parametrize("drop_intermediate", [True, False])
@pytest.mark.parametrize("with_strings", [True, False])
@pytest.mark.parametrize("plot_chance_level", [True, False])
@pytest.mark.parametrize(
"constructor_name, default_name",
[
Expand All @@ -50,6 +51,7 @@ def test_roc_curve_display_plotting(
with_sample_weight,
drop_intermediate,
with_strings,
plot_chance_level,
constructor_name,
default_name,
):
Expand Down Expand Up @@ -82,6 +84,7 @@ def test_roc_curve_display_plotting(
drop_intermediate=drop_intermediate,
pos_label=pos_label,
alpha=0.8,
plot_chance_level=plot_chance_level,
)
else:
display = RocCurveDisplay.from_predictions(
Expand All @@ -91,6 +94,7 @@ def test_roc_curve_display_plotting(
drop_intermediate=drop_intermediate,
pos_label=pos_label,
alpha=0.8,
plot_chance_level=plot_chance_level,
)

fpr, tpr, _ = roc_curve(
Expand All @@ -114,6 +118,13 @@ def test_roc_curve_display_plotting(
assert isinstance(display.ax_, mpl.axes.Axes)
assert isinstance(display.figure_, mpl.figure.Figure)

if plot_chance_level:
assert isinstance(display.chance_level_, mpl.lines.Line2D)
assert tuple(display.chance_level_.get_xdata()) == (0, 1)
assert tuple(display.chance_level_.get_ydata()) == (0, 1)
else:
assert display.chance_level_ is None

expected_label = f"{default_name} (AUC = {display.roc_auc:.2f})"
assert display.line_.get_label() == expected_label

Expand All @@ -124,6 +135,15 @@ def test_roc_curve_display_plotting(
assert display.ax_.get_ylabel() == expected_ylabel
assert display.ax_.get_xlabel() == expected_xlabel

assert display.ax_.get_xlim() == (0, 1)
assert display.ax_.get_ylim() == (0, 1)
assert display.ax_.get_aspect() == 1

# Check frame styles
for s in ["right", "left", "top", "bottom"]:
assert display.ax_.spines[s].get_linestyle() == (0, (1, 5))
assert display.ax_.spines[s].get_linewidth() <= 0.5


@pytest.mark.parametrize(
"clf",
Expand Down