
API add from_estimator and from_predictions to PrecisionRecallDisplay #20552


Merged

31 commits
21832af  ENH add from_estimator and from_preditions to PredictionRecallDisplay (glemaitre, Jul 17, 2021)
4468c7a  iter (glemaitre, Jul 19, 2021)
7ea1768  TST add some common test and deprecation check (glemaitre, Jul 19, 2021)
b57508f  iter (glemaitre, Jul 19, 2021)
bc24204  more tests (glemaitre, Jul 19, 2021)
a74c7b1  more tests (glemaitre, Jul 19, 2021)
abfdb1a  TST check strings (glemaitre, Jul 19, 2021)
1c154ff  TST add test to check average precision computed (glemaitre, Jul 19, 2021)
73a6ab0  iter (glemaitre, Jul 19, 2021)
1529e66  DOC some update (glemaitre, Jul 19, 2021)
6679579  DOC some update (glemaitre, Jul 19, 2021)
9aa4af8  iter (glemaitre, Jul 19, 2021)
2c9f6b1  iter (glemaitre, Jul 19, 2021)
d7527cd  iter (glemaitre, Jul 19, 2021)
8ab59f0  DOC update user guide (glemaitre, Jul 19, 2021)
7240dd8  FIX order parameters (glemaitre, Jul 19, 2021)
1a592f9  TST add common test for future curve (glemaitre, Jul 19, 2021)
36602d9  revert setup.cfg (glemaitre, Jul 19, 2021)
0225198  simplify shape (glemaitre, Jul 19, 2021)
8f44996  consistency (glemaitre, Jul 19, 2021)
6a67f39  iter (glemaitre, Jul 19, 2021)
cd9a82c  add comment tweek (glemaitre, Jul 19, 2021)
97398cb  Merge remote-tracking branch 'origin/main' into class_methods_Precisi… (glemaitre, Aug 6, 2021)
e764bf2  Apply suggestions from code review (glemaitre, Aug 6, 2021)
57199ed  Merge remote-tracking branch 'origin/main' into class_methods_Precisi… (glemaitre, Aug 6, 2021)
1f38f5c  fix doc (glemaitre, Aug 6, 2021)
34481d7  Merge remote-tracking branch 'glemaitre/class_methods_PrecisionRecall… (glemaitre, Aug 6, 2021)
6460069  Merge remote-tracking branch 'origin/main' into class_methods_Precisi… (glemaitre, Aug 7, 2021)
76f0c79  Merge remote-tracking branch 'origin/main' into class_methods_Precisi… (glemaitre, Aug 9, 2021)
6efe596  Merge remote-tracking branch 'origin/main' into class_methods_Precisi… (glemaitre, Aug 9, 2021)
7d013ff  update error message (glemaitre, Aug 9, 2021)
doc/modules/model_evaluation.rst (4 additions, 3 deletions)

@@ -796,9 +796,10 @@ score:

 Note that the :func:`precision_recall_curve` function is restricted to the
 binary case. The :func:`average_precision_score` function works only in
-binary classification and multilabel indicator format. The
-:func:`plot_precision_recall_curve` function plots the precision recall as
-follows.
+binary classification and multilabel indicator format.
+The :func:`PrecisionRecallDisplay.from_estimator` and
+:func:`PrecisionRecallDisplay.from_predictions` functions will plot the
+precision-recall curve as follows.

 .. image:: ../auto_examples/model_selection/images/sphx_glr_plot_precision_recall_001.png
    :target: ../auto_examples/model_selection/plot_precision_recall.html#plot-the-precision-recall-curve
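
For reference, a minimal, self-contained sketch of the two new entry points described above (the dataset and the LogisticRegression estimator are illustrative, not part of this PR; assumes scikit-learn >= 1.0):

import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import PrecisionRecallDisplay
from sklearn.model_selection import train_test_split

# Illustrative binary classification problem and estimator.
X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression().fit(X_train, y_train)

# `from_estimator` computes the decision scores itself before plotting.
PrecisionRecallDisplay.from_estimator(clf, X_test, y_test)

# `from_predictions` reuses scores that were computed beforehand.
y_score = clf.decision_function(X_test)
PrecisionRecallDisplay.from_predictions(y_test, y_score)
plt.show()
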
doc/whats_new/v1.0.rst (8 additions, 0 deletions)

@@ -552,6 +552,14 @@ Changelog
   class methods and will be removed in 1.2.
   :pr:`18543` by `Guillaume Lemaitre`_.

+- |API| :class:`metrics.PrecisionRecallDisplay` exposes two class methods,
+  :func:`~metrics.PrecisionRecallDisplay.from_estimator` and
+  :func:`~metrics.PrecisionRecallDisplay.from_predictions`, that allow creating
+  a precision-recall curve from an estimator or from predictions.
+  :func:`metrics.plot_precision_recall_curve` is deprecated in favor of these
+  two class methods and will be removed in 1.2.
+  :pr:`20552` by `Guillaume Lemaitre`_.
+
 - |API| :class:`metrics.DetCurveDisplay` exposes two class methods
   :func:`~metrics.DetCurveDisplay.from_estimator` and
   :func:`~metrics.DetCurveDisplay.from_predictions` allowing to create
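
The deprecation described in the entry above amounts to a one-line migration in downstream code. A sketch, reusing the illustrative `clf`, `X_test`, and `y_test` from the previous snippet:

from sklearn.metrics import PrecisionRecallDisplay

# Before (deprecated in 1.0, scheduled for removal in 1.2):
#   from sklearn.metrics import plot_precision_recall_curve
#   disp = plot_precision_recall_curve(clf, X_test, y_test)

# After:
disp = PrecisionRecallDisplay.from_estimator(clf, X_test, y_test)
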
examples/model_selection/plot_precision_recall.py (104 additions, 88 deletions)

@@ -92,64 +92,80 @@
 """
 # %%
 # In binary classification settings
-# --------------------------------------------------------
+# ---------------------------------
 #
-# Create simple data
-# ..................
+# Dataset and model
+# .................
 #
-# Try to differentiate the two first classes of the iris data
-from sklearn import svm, datasets
-from sklearn.model_selection import train_test_split
+# We will use a Linear SVC classifier to differentiate two types of irises.
+import numpy as np
+from sklearn.datasets import load_iris
+from sklearn.model_selection import train_test_split

-iris = datasets.load_iris()
-X = iris.data
-y = iris.target
+X, y = load_iris(return_X_y=True)

 # Add noisy features
 random_state = np.random.RandomState(0)
 n_samples, n_features = X.shape
-X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]
+X = np.concatenate([X, random_state.randn(n_samples, 200 * n_features)], axis=1)

 # Limit to the two first classes, and split into training and test
-X_train, X_test, y_train, y_test = train_test_split(X[y < 2], y[y < 2],
-                                                    test_size=.5,
-                                                    random_state=random_state)
+X_train, X_test, y_train, y_test = train_test_split(
+    X[y < 2], y[y < 2], test_size=0.5, random_state=random_state
+)

-# Create a simple classifier
-classifier = svm.LinearSVC(random_state=random_state)
+# %%
+# Linear SVC will expect each feature to have a similar range of values. Thus,
+# we will first scale the data using a
+# :class:`~sklearn.preprocessing.StandardScaler`.
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import StandardScaler
+from sklearn.svm import LinearSVC
+
+classifier = make_pipeline(StandardScaler(), LinearSVC(random_state=random_state))
 classifier.fit(X_train, y_train)
-y_score = classifier.decision_function(X_test)

 # %%
-# Compute the average precision score
-# ...................................
-from sklearn.metrics import average_precision_score
-average_precision = average_precision_score(y_test, y_score)
+# Plot the Precision-Recall curve
+# ...............................
+#
+# To plot the precision-recall curve, you should use
+# :class:`~sklearn.metrics.PrecisionRecallDisplay`. There are two
+# methods available, depending on whether you have already computed the
+# predictions of the classifier or not.
+#
+# Let's first plot the precision-recall curve without precomputed
+# predictions. We use
+# :func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator` that
+# computes the predictions for us before plotting the curve.
+from sklearn.metrics import PrecisionRecallDisplay

-print('Average precision-recall score: {0:0.2f}'.format(
-      average_precision))
+display = PrecisionRecallDisplay.from_estimator(
+    classifier, X_test, y_test, name="LinearSVC"
+)
+_ = display.ax_.set_title("2-class Precision-Recall curve")

 # %%
-# Plot the Precision-Recall curve
-# ................................
-from sklearn.metrics import precision_recall_curve
-from sklearn.metrics import plot_precision_recall_curve
-import matplotlib.pyplot as plt
+# If we already have the estimated probabilities or scores for
+# our model, then we can use
+# :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions`.
+y_score = classifier.decision_function(X_test)

-disp = plot_precision_recall_curve(classifier, X_test, y_test)
-disp.ax_.set_title('2-class Precision-Recall curve: '
-                   'AP={0:0.2f}'.format(average_precision))
+display = PrecisionRecallDisplay.from_predictions(y_test, y_score, name="LinearSVC")
+_ = display.ax_.set_title("2-class Precision-Recall curve")

 # %%
 # In multi-label settings
-# ------------------------
+# -----------------------
 #
+# The precision-recall curve does not support the multilabel setting. However,
+# one can decide how to handle this case. We show such an example below.
+#
 # Create multi-label data, fit, and predict
-# ...........................................
+# .........................................
 #
 # We create a multi-label dataset, to illustrate the precision-recall in
-# multi-label settings
+# multi-label settings.

 from sklearn.preprocessing import label_binarize

@@ -158,95 +174,95 @@
 n_classes = Y.shape[1]

 # Split into training and test
-X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.5,
-                                                    random_state=random_state)
+X_train, X_test, Y_train, Y_test = train_test_split(
+    X, Y, test_size=0.5, random_state=random_state
+)

-# We use OneVsRestClassifier for multi-label prediction
+# %%
+# We use :class:`~sklearn.multiclass.OneVsRestClassifier` for multi-label
+# prediction.
 from sklearn.multiclass import OneVsRestClassifier

-# Run classifier
-classifier = OneVsRestClassifier(svm.LinearSVC(random_state=random_state))
+classifier = OneVsRestClassifier(
+    make_pipeline(StandardScaler(), LinearSVC(random_state=random_state))
+)
 classifier.fit(X_train, Y_train)
 y_score = classifier.decision_function(X_test)


 # %%
 # The average precision score in multi-label settings
-# ....................................................
+# ...................................................
 from sklearn.metrics import precision_recall_curve
 from sklearn.metrics import average_precision_score

 # For each class
 precision = dict()
 recall = dict()
 average_precision = dict()
 for i in range(n_classes):
-    precision[i], recall[i], _ = precision_recall_curve(Y_test[:, i],
-                                                        y_score[:, i])
+    precision[i], recall[i], _ = precision_recall_curve(Y_test[:, i], y_score[:, i])
     average_precision[i] = average_precision_score(Y_test[:, i], y_score[:, i])

 # A "micro-average": quantifying score on all classes jointly
-precision["micro"], recall["micro"], _ = precision_recall_curve(Y_test.ravel(),
-                                                                y_score.ravel())
-average_precision["micro"] = average_precision_score(Y_test, y_score,
-                                                     average="micro")
-print('Average precision score, micro-averaged over all classes: {0:0.2f}'
-      .format(average_precision["micro"]))
+precision["micro"], recall["micro"], _ = precision_recall_curve(
+    Y_test.ravel(), y_score.ravel()
+)
+average_precision["micro"] = average_precision_score(Y_test, y_score, average="micro")

 # %%
 # Plot the micro-averaged Precision-Recall curve
-# ...............................................
-#
-
-plt.figure()
-plt.step(recall['micro'], precision['micro'], where='post')
-
-plt.xlabel('Recall')
-plt.ylabel('Precision')
-plt.ylim([0.0, 1.05])
-plt.xlim([0.0, 1.0])
-plt.title(
-    'Average precision score, micro-averaged over all classes: AP={0:0.2f}'
-    .format(average_precision["micro"]))
+# ..............................................
+display = PrecisionRecallDisplay(
+    recall=recall["micro"],
+    precision=precision["micro"],
+    average_precision=average_precision["micro"],
+)
+display.plot()
+_ = display.ax_.set_title("Micro-averaged over all classes")

 # %%
 # Plot Precision-Recall curve for each class and iso-f1 curves
-# .............................................................
-#
+# ............................................................
+import matplotlib.pyplot as plt
 from itertools import cycle

 # setup plot details
-colors = cycle(['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal'])
+colors = cycle(["navy", "turquoise", "darkorange", "cornflowerblue", "teal"])

-plt.figure(figsize=(7, 8))
+_, ax = plt.subplots(figsize=(7, 8))
+
 f_scores = np.linspace(0.2, 0.8, num=4)
-lines = []
-labels = []
+lines, labels = [], []
 for f_score in f_scores:
     x = np.linspace(0.01, 1)
     y = f_score * x / (2 * x - f_score)
-    l, = plt.plot(x[y >= 0], y[y >= 0], color='gray', alpha=0.2)
-    plt.annotate('f1={0:0.1f}'.format(f_score), xy=(0.9, y[45] + 0.02))
+    (l,) = plt.plot(x[y >= 0], y[y >= 0], color="gray", alpha=0.2)
+    plt.annotate("f1={0:0.1f}".format(f_score), xy=(0.9, y[45] + 0.02))

-lines.append(l)
-labels.append('iso-f1 curves')
-l, = plt.plot(recall["micro"], precision["micro"], color='gold', lw=2)
-lines.append(l)
-labels.append('micro-average Precision-recall (area = {0:0.2f})'
-              ''.format(average_precision["micro"]))
+display = PrecisionRecallDisplay(
+    recall=recall["micro"],
+    precision=precision["micro"],
+    average_precision=average_precision["micro"],
+)
+display.plot(ax=ax, name="Micro-average precision-recall", color="gold")

 for i, color in zip(range(n_classes), colors):
-    l, = plt.plot(recall[i], precision[i], color=color, lw=2)
-    lines.append(l)
-    labels.append('Precision-recall for class {0} (area = {1:0.2f})'
-                  ''.format(i, average_precision[i]))
-
-fig = plt.gcf()
-fig.subplots_adjust(bottom=0.25)
-plt.xlim([0.0, 1.0])
-plt.ylim([0.0, 1.05])
-plt.xlabel('Recall')
-plt.ylabel('Precision')
-plt.title('Extension of Precision-Recall curve to multi-class')
-plt.legend(lines, labels, loc=(0, -.38), prop=dict(size=14))
+    display = PrecisionRecallDisplay(
+        recall=recall[i],
+        precision=precision[i],
+        average_precision=average_precision[i],
+    )
+    display.plot(ax=ax, name=f"Precision-recall for class {i}", color=color)
+
+# add the legend for the iso-f1 curves
+handles, labels = display.ax_.get_legend_handles_labels()
+handles.extend([l])
+labels.extend(["iso-f1 curves"])
+# set the legend and the axes
+ax.set_xlim([0.0, 1.0])
+ax.set_ylim([0.0, 1.05])
+ax.legend(handles=handles, labels=labels, loc="best")
+ax.set_title("Extension of Precision-Recall curve to multi-class")

 plt.show()
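
A note on the iso-f1 curves drawn in the first loop of this example: they follow directly from the definition of the F1 score,

    F_1 = \frac{2 P R}{P + R}.

Fixing F_1 = f and solving for precision P as a function of recall R gives

    P = \frac{f R}{2 R - f},

which is exactly the `y = f_score * x / (2 * x - f_score)` line, with `x` standing in for recall; points where the denominator would make precision negative are masked by the `y >= 0` filter.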