
Commit f1a03ac

DOC: More readable precision-recall example
Also: better tests, and a better reference to the Pascal VOC challenge (which claims that it uses the 11pt average precision, but doesn't)
1 parent d969100 commit f1a03ac

5 files changed: +392 -158 lines changed

doc/modules/model_evaluation.rst

Lines changed: 1 addition & 1 deletion
@@ -681,7 +681,7 @@ Here are some small examples in binary classification::
     >>> threshold
     array([ 0.35,  0.4 ,  0.8 ])
     >>> average_precision_score(y_true, y_scores)  # doctest: +ELLIPSIS
-    0.79...
+    0.83...
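
For reference, the new doctest value can be reproduced outside the docs. The hunk does not show the input arrays; the sketch below assumes y_true = [0, 0, 1, 1] and y_scores = [0.1, 0.4, 0.35, 0.8], which is consistent with the threshold array printed in the context lines, and applies the step-weighted definition this commit introduces:

    import numpy as np
    from sklearn.metrics import precision_recall_curve

    # Assumed example inputs, consistent with thresholds [0.35, 0.4, 0.8]
    y_true = np.array([0, 0, 1, 1])
    y_scores = np.array([0.1, 0.4, 0.35, 0.8])

    precision, recall, threshold = precision_recall_curve(y_true, y_scores)

    # Weight each precision by the drop in recall at its operating point
    # (recall is returned in decreasing order, hence the minus sign).
    ap = -np.sum(np.diff(recall) * precision[:-1])
    print(ap)   # ~0.833..., matching the updated doctest

The previous linearly-interpolated (trapezoidal) behaviour gives roughly 0.79 for the same data, which is why the doctest output changes.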
doc/whats_new.rst

Lines changed: 16 additions & 1 deletion
@@ -6,7 +6,7 @@ Release history
 ===============
 
 Version 0.19
-============
+==============
 
 **In Development**
 
@@ -65,6 +65,14 @@ New features
 Enhancements
 ............
 
+- Added a `'eleven-point'` interpolated average precision option to
+  :func:`metrics.ranking.average_precision_score` as described in the
+  `PASCAL Visual Object Classes (VOC) Challenge
+  <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.157.5766&rep=rep1&type=pdf>`_.
+  (`#7356 <https://github.com/scikit-learn/scikit-learn/pull/7356>`_). By
+  `Nick Dingwall`_ and `Gael Varoquaux`_
+
 - Update Sphinx-Gallery from 0.1.4 to 0.1.7 for resolving links in
   documentation build with Sphinx>1.5 :issue:`8010`, :issue:`7986`
   :user:`Oscar Najera <Titan-C>`
@@ -193,6 +201,13 @@ Enhancements
 Bug fixes
 .........
 
+- :func:`metrics.ranking.average_precision_score` no longer linearly
+  interpolates between operating points, and instead weights precisions
+  by the change in recall since the last operating point, as per the
+  `Wikipedia entry <http://en.wikipedia.org/wiki/Average_precision>`_.
+  (`#7356 <https://github.com/scikit-learn/scikit-learn/pull/7356>`_). By
+  `Nick Dingwall`_ and `Gael Varoquaux`_.
+
 - Fixed a bug in :class:`sklearn.covariance.MinCovDet` where inputting data
   that produced a singular covariance matrix would cause the helper method
   `_c_step` to throw an exception.
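
The enhancement entry above does not show the keyword that selects the new 'eleven-point' behaviour, so rather than guess the API, this minimal sketch computes the eleven-point interpolated value directly from precision_recall_curve output, following the VOC definition cited there, and compares it with the default score (which, per the bug-fix entry, no longer interpolates linearly). The data is made up for illustration only:

    import numpy as np
    from sklearn.metrics import average_precision_score, precision_recall_curve

    # Made-up scores for a small binary problem (illustration only)
    y_true = np.array([0, 1, 1, 0, 1, 1, 0, 1])
    y_score = np.array([0.9, 0.8, 0.7, 0.6, 0.55, 0.4, 0.3, 0.2])

    precision, recall, _ = precision_recall_curve(y_true, y_score)

    # Eleven-point interpolation: for each target recall in [0, 0.1, ..., 1.0],
    # take the best precision among points with recall >= target, then average.
    eleven_point_ap = np.mean([precision[recall >= t].max()
                               for t in np.linspace(0, 1, 11)])

    # Default score: precisions weighted by the change in recall,
    # with no linear interpolation.
    print(eleven_point_ap, average_precision_score(y_true, y_score))

The same eleven-point logic appears as the pick_eleven_points helper added to the example below.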

examples/model_selection/plot_precision_recall.py

Lines changed: 216 additions & 53 deletions
@@ -5,13 +5,18 @@
 
 Example of Precision-Recall metric to evaluate classifier output quality.
 
-In information retrieval, precision is a measure of result relevancy, while
-recall is a measure of how many truly relevant results are returned. A high
-area under the curve represents both high recall and high precision, where high
-precision relates to a low false positive rate, and high recall relates to a
-low false negative rate. High scores for both show that the classifier is
-returning accurate results (high precision), as well as returning a majority of
-all positive results (high recall).
+Precision-Recall is a useful measure of success of prediction when the
+classes are very imbalanced. In information retrieval, precision is a
+measure of result relevancy, while recall is a measure of how many truly
+relevant results are returned.
+
+The precision-recall curve shows the tradeoff between precision and
+recall for different thresholds. A high area under the curve represents
+both high recall and high precision, where high precision relates to a
+low false positive rate, and high recall relates to a low false negative
+rate. High scores for both show that the classifier is returning accurate
+results (high precision), as well as returning a majority of all positive
+results (high recall).
 
 A system with high recall but low precision returns many results, but most of
 its predicted labels are incorrect when compared to the training labels. A
@@ -37,7 +42,7 @@
 
 :math:`F1 = 2\\frac{P \\times R}{P+R}`
 
-It is important to note that the precision may not decrease with recall. The
+Note that the precision may not decrease with recall. The
 definition of precision (:math:`\\frac{T_p}{T_p + F_p}`) shows that lowering
 the threshold of a classifier may increase the denominator, by increasing the
 number of results returned. If the threshold was previously set too high, the
@@ -54,11 +59,20 @@
 The relationship between recall and precision can be observed in the
 stairstep area of the plot - at the edges of these steps a small change
 in the threshold considerably reduces precision, with only a minor gain in
-recall. See the corner at recall = .59, precision = .8 for an example of this
-phenomenon.
+recall.
+
+**Average precision** summarizes such a plot as the weighted mean of precisions
+achieved at each threshold, with the increase in recall from the previous
+threshold used as the weight:
+
+:math:`\\text{AP} = \\sum_n (R_n - R_{n-1}) P_n`
+
+where :math:`P_n` and :math:`R_n` are the precision and recall at the
+nth threshold. A pair :math:`(R_k, P_k)` is referred to as an
+*operating point*.
 
 Precision-recall curves are typically used in binary classification to study
-the output of a classifier. In order to extend Precision-recall curve and
+the output of a classifier. In order to extend the Precision-recall curve and
 average precision to multi-class or multi-label classification, it is necessary
 to binarize the output. One curve can be drawn per label, but one can also draw
 a precision-recall curve by considering each element of the label indicator
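
As a quick numeric reading of the AP formula added above, here is a minimal sketch with made-up operating points (the values are not taken from the example's data):

    import numpy as np

    # Made-up operating points (R_n, P_n), ordered by increasing recall,
    # starting from recall 0
    recall = np.array([0.0, 0.2, 0.6, 1.0])
    precision = np.array([1.0, 1.0, 0.75, 0.5])

    # AP = sum_n (R_n - R_{n-1}) * P_n
    ap = np.sum(np.diff(recall) * precision[1:])
    print(ap)   # 0.2*1.0 + 0.4*0.75 + 0.4*0.5 = 0.7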
@@ -71,76 +85,146 @@
 :func:`sklearn.metrics.precision_score`,
 :func:`sklearn.metrics.f1_score`
 """
-print(__doc__)
-
-import matplotlib.pyplot as plt
-import numpy as np
-from itertools import cycle
+from __future__ import print_function
 
+###############################################################################
+# In binary classification settings
+# --------------------------------------------------------
+#
+# Create simple data
+# ..................
+#
+# Try to differentiate the first two classes of the iris data
 from sklearn import svm, datasets
-from sklearn.metrics import precision_recall_curve
-from sklearn.metrics import average_precision_score
 from sklearn.model_selection import train_test_split
-from sklearn.preprocessing import label_binarize
-from sklearn.multiclass import OneVsRestClassifier
+import numpy as np
 
-# import some data to play with
 iris = datasets.load_iris()
 X = iris.data
 y = iris.target
 
-# setup plot details
-colors = cycle(['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal'])
-lw = 2
-
-# Binarize the output
-y = label_binarize(y, classes=[0, 1, 2])
-n_classes = y.shape[1]
-
 # Add noisy features
 random_state = np.random.RandomState(0)
 n_samples, n_features = X.shape
 X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]
 
+# Limit to the first two classes, and split into training and test
+X_train, X_test, y_train, y_test = train_test_split(X[y < 2], y[y < 2],
+                                                    test_size=.5,
+                                                    random_state=random_state)
+
+# Create a simple classifier
+classifier = svm.LinearSVC(random_state=random_state)
+classifier.fit(X_train, y_train)
+y_score = classifier.decision_function(X_test)
+
+###############################################################################
+# Compute the average precision score
+# ...................................
+from sklearn.metrics import average_precision_score
+average_precision = average_precision_score(y_test, y_score)
+
+print('Average precision-recall score: {0:0.2f}'.format(
+      average_precision))
+
+###############################################################################
+# Plot the Precision-Recall curve
+# ................................
+from sklearn.metrics import precision_recall_curve
+import matplotlib.pyplot as plt
+
+precision, recall, _ = precision_recall_curve(y_test, y_score)
+
+plt.step(recall, precision, color='b', alpha=0.2,
+         where='post')
+plt.fill_between(recall, precision, step='post', alpha=0.2,
+                 color='b')
+
+plt.xlabel('Recall')
+plt.ylabel('Precision')
+plt.ylim([0.0, 1.05])
+plt.xlim([0.0, 1.0])
+plt.title('2-class Precision-Recall curve: AUC={0:0.2f}'.format(
+          average_precision))
+
+###############################################################################
+# In multi-label settings
+# ------------------------
+#
+# Create multi-label data, fit, and predict
+# ...........................................
+#
+# We create a multi-label dataset, to illustrate the precision-recall in
+# multi-label settings
+
+from sklearn.preprocessing import label_binarize
+
+# Use label_binarize to create multi-label-like settings
+Y = label_binarize(y, classes=[0, 1, 2])
+n_classes = Y.shape[1]
+
 # Split into training and test
-X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5,
+X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.5,
                                                     random_state=random_state)
 
+# We use OneVsRestClassifier for multi-label prediction
+from sklearn.multiclass import OneVsRestClassifier
 # Run classifier
-classifier = OneVsRestClassifier(svm.SVC(kernel='linear', probability=True,
-                                         random_state=random_state))
-y_score = classifier.fit(X_train, y_train).decision_function(X_test)
+classifier = OneVsRestClassifier(svm.LinearSVC(random_state=random_state))
+classifier.fit(X_train, Y_train)
+y_score = classifier.decision_function(X_test)
 
-# Compute Precision-Recall and plot curve
+
+###############################################################################
+# The Precision-Recall score in multi-label settings
+# ..................................................
+from sklearn.metrics import precision_recall_curve
+from sklearn.metrics import average_precision_score
+
+# For each class
 precision = dict()
 recall = dict()
 average_precision = dict()
 for i in range(n_classes):
-    precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],
+    precision[i], recall[i], _ = precision_recall_curve(Y_test[:, i],
                                                         y_score[:, i])
-    average_precision[i] = average_precision_score(y_test[:, i], y_score[:, i])
+    average_precision[i] = average_precision_score(Y_test[:, i], y_score[:, i])
 
-# Compute micro-average ROC curve and ROC area
-precision["micro"], recall["micro"], _ = precision_recall_curve(y_test.ravel(),
+# A "micro-average": quantifying score on all classes jointly
+precision["micro"], recall["micro"], _ = precision_recall_curve(Y_test.ravel(),
     y_score.ravel())
-average_precision["micro"] = average_precision_score(y_test, y_score,
+average_precision["micro"] = average_precision_score(Y_test, y_score,
                                                      average="micro")
+print('Precision-Recall micro-averaged over all classes: {0:0.2f}'.format(
+      average_precision["micro"]))
 
+###############################################################################
+# Plot the micro-averaged Precision-Recall curve
+# ...............................................
+#
+
+plt.figure()
+plt.step(recall['micro'], precision['micro'], color='b', alpha=0.2,
+         where='post')
+plt.fill_between(recall["micro"], precision["micro"], step='post', alpha=0.2,
+                 color='b')
 
-# Plot Precision-Recall curve
-plt.clf()
-plt.plot(recall[0], precision[0], lw=lw, color='navy',
-         label='Precision-Recall curve')
 plt.xlabel('Recall')
 plt.ylabel('Precision')
 plt.ylim([0.0, 1.05])
 plt.xlim([0.0, 1.0])
-plt.title('Precision-Recall example: AUC={0:0.2f}'.format(average_precision[0]))
-plt.legend(loc="lower left")
-plt.show()
+plt.title('Precision-Recall micro-averaged over all classes: AUC={0:0.2f}'
+          .format(average_precision["micro"]))
 
+###############################################################################
 # Plot Precision-Recall curve for each class and iso-f1 curves
-plt.clf()
+# .............................................................
+#
+from itertools import cycle
+# setup plot details
+colors = cycle(['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal'])
+
+plt.figure(figsize=(7, 8))
 f_scores = np.linspace(0.2, 0.8, num=4)
 lines = []
 labels = []
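
A side note on the micro-average used in the hunk above: ravel() flattens the label-indicator matrix so that every (sample, class) pair counts as one binary decision. A minimal sketch with made-up arrays:

    import numpy as np

    Y_test_demo = np.array([[1, 0, 0],
                            [0, 1, 0]])
    y_score_demo = np.array([[0.8, 0.1, 0.1],
                             [0.3, 0.5, 0.2]])

    # Six (sample, class) decisions instead of two samples x three classes
    print(Y_test_demo.ravel())    # [1 0 0 0 1 0]
    print(y_score_demo.ravel())   # [0.8 0.1 0.1 0.3 0.5 0.2]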
@@ -152,23 +236,102 @@
 
 lines.append(l)
 labels.append('iso-f1 curves')
-l, = plt.plot(recall["micro"], precision["micro"], color='gold', lw=lw)
+l, = plt.plot(recall["micro"], precision["micro"], color='gold', lw=2)
 lines.append(l)
-labels.append('micro-average Precision-recall curve (area = {0:0.2f})'
+labels.append('micro-average Precision-recall (area = {0:0.2f})'
               ''.format(average_precision["micro"]))
+
 for i, color in zip(range(n_classes), colors):
-    l, = plt.plot(recall[i], precision[i], color=color, lw=lw)
+    l, = plt.plot(recall[i], precision[i], color=color, lw=2)
     lines.append(l)
-    labels.append('Precision-recall curve of class {0} (area = {1:0.2f})'
+    labels.append('Precision-recall for class {0} (area = {1:0.2f})'
                   ''.format(i, average_precision[i]))
 
 fig = plt.gcf()
-fig.set_size_inches(7, 7)
 fig.subplots_adjust(bottom=0.25)
 plt.xlim([0.0, 1.0])
 plt.ylim([0.0, 1.05])
 plt.xlabel('Recall')
 plt.ylabel('Precision')
 plt.title('Extension of Precision-Recall curve to multi-class')
-plt.figlegend(lines, labels, loc='lower center')
+plt.legend(lines, labels, loc=(0, -.38), prop=dict(size=14))
+
+
+###############################################################################
+# Eleven-point average precision
+# ------------------------------
+#
+# In *interpolated* average precision, a set of desired recall values is
+# specified and for each desired value, we average the best precision
+# scores possible with a recall value at least equal to the target value.
+# The most common choice is 'eleven point' interpolated precision, where
+# the desired recall values are [0, 0.1, 0.2, ..., 1.0]. This is the
+# metric referenced in `The PASCAL Visual Object Classes (VOC) Challenge
+# <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.157.5766&rep=rep1&type=pdf>`_
+# (top of page 11, formula 1). In the example below, the eleven precision
+# values are indicated with an arrow pointing to the best precision possible
+# while meeting or exceeding the desired recall. Note that it's possible
+# that the same operating point might correspond to multiple desired
+# recall values.
+
+from operator import itemgetter
+
+
+def pick_eleven_points(recall_, precision_):
+    """Choose the eleven operating points that correspond
+    to the best precision for any ``recall >= r`` for r in
+    [0, 0.1, 0.2, ..., 1.0]
+    """
+    operating_points = list()
+    for target_recall in np.arange(0, 1.1, 0.1):
+        operating_points_to_consider = [pair
+                                        for pair in zip(recall_, precision_)
+                                        if pair[0] >= target_recall]
+        operating_points.append(max(operating_points_to_consider,
+                                    key=itemgetter(1)))
+    return operating_points
+
+# Work on the 2nd class of iris
+iris_cls = 2
+
+eleven_points = pick_eleven_points(recall[iris_cls], precision[iris_cls])
+interpolated_average_precision = np.mean([e[1] for e in eleven_points])
+
+
+print("Target recall   Selected recall   Precision")
+for i in range(11):
+    print("   >= {}    {: >12.3f}    {: >12.3f}".format(i / 10.,
+                                                        *eleven_points[i]))
+
+print("   Average:{: >22.3f}".format(interpolated_average_precision))
+
+###############################################################################
+# Plot illustrating eleven-point average precision
+# .................................................
+
+plt.figure(figsize=(7, 7))
+plt.step(recall[iris_cls], precision[iris_cls], color='g', where='post',
+         alpha=0.5, linewidth=2,
+         label='Precision-recall curve of class {0} (area = {1:0.2f})'
+               ''.format(iris_cls, average_precision[iris_cls]))
+
+plt.fill_between(recall[iris_cls], precision[iris_cls], step='post', alpha=0.1,
+                 color='g')
+for i in range(11):
+    plt.annotate('',
+                 xy=(eleven_points[i][0], eleven_points[i][1]),
+                 xycoords='data', xytext=(i / 10., 0), textcoords='data',
+                 arrowprops=dict(arrowstyle="->", alpha=0.7,
+                                 connectionstyle="angle3,angleA=90,angleB=45"))
+
+
+plt.xlim([0.0, 1.0])
+plt.ylim([0.0, 1.05])
+plt.xticks(np.arange(0, 1.1, 0.1))
+plt.xlabel('Recall')
+plt.ylabel('Precision')
+plt.title('Eleven point Precision Recall for class\n {}'.format(iris_cls))
+plt.legend(loc="upper right")
+
 plt.show()
