
Example of Precision-Recall metric to evaluate classifier output quality.

-In information retrieval, precision is a measure of result relevancy, while
-recall is a measure of how many truly relevant results are returned. A high
-area under the curve represents both high recall and high precision, where high
-precision relates to a low false positive rate, and high recall relates to a
-low false negative rate. High scores for both show that the classifier is
-returning accurate results (high precision), as well as returning a majority of
-all positive results (high recall).
+Precision-Recall is a useful measure of prediction success when the
+classes are very imbalanced. In information retrieval, precision is a
+measure of result relevancy, while recall is a measure of how many truly
+relevant results are returned.
+
+The precision-recall curve shows the tradeoff between precision and
+recall for different thresholds. A high area under the curve represents
+both high recall and high precision, where high precision relates to a
+low false positive rate, and high recall relates to a low false negative
+rate. High scores for both show that the classifier is returning accurate
+results (high precision), as well as returning a majority of all positive
+results (high recall).

A system with high recall but low precision returns many results, but most of
its predicted labels are incorrect when compared to the training labels. A

...

:math:`F1 = 2\\frac{P \\times R}{P+R}`

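+For instance (illustrative counts, not taken from this example), a
+classifier with :math:`T_p = 8`, :math:`F_p = 2` and :math:`F_n = 4` has
+:math:`P = 0.8` and :math:`R = 8/12 \\approx 0.67`, giving
+:math:`F1 \\approx 0.73`.
+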
-It is important to note that the precision may not decrease with recall. The
+Note that the precision may not decrease with recall. The
definition of precision (:math:`\\frac{T_p}{T_p + F_p}`) shows that lowering
the threshold of a classifier may increase the denominator, by increasing the
number of results returned. If the threshold was previously set too high, the

...
The relationship between recall and precision can be observed in the
stairstep area of the plot - at the edges of these steps a small change
in the threshold considerably reduces precision, with only a minor gain in
-recall. See the corner at recall = .59, precision = .8 for an example of this
-phenomenon.
+recall.
+
+**Average precision** summarizes such a plot as the weighted mean of precisions
+achieved at each threshold, with the increase in recall from the previous
+threshold used as the weight:
+
+:math:`\\text{AP} = \\sum_n (R_n - R_{n-1}) P_n`
+
+where :math:`P_n` and :math:`R_n` are the precision and recall at the
+nth threshold. A pair :math:`(R_k, P_k)` is referred to as an
+*operating point*.
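+
+For instance (an illustrative calculation, not one computed by this
+example), operating points :math:`(R, P) = (0.2, 1.0)`, :math:`(0.5, 0.8)`
+and :math:`(1.0, 0.6)` give
+:math:`\\text{AP} = 0.2 \\times 1.0 + 0.3 \\times 0.8 + 0.5 \\times 0.6 = 0.74`.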

Precision-recall curves are typically used in binary classification to study
-the output of a classifier. In order to extend Precision-recall curve and
+the output of a classifier. In order to extend the Precision-recall curve and
average precision to multi-class or multi-label classification, it is necessary
to binarize the output. One curve can be drawn per label, but one can also draw
a precision-recall curve by considering each element of the label indicator

...
:func:`sklearn.metrics.precision_score`,
:func:`sklearn.metrics.f1_score`
"""
-print(__doc__)
-
-import matplotlib.pyplot as plt
-import numpy as np
-from itertools import cycle
+from __future__ import print_function

+###############################################################################
+# In binary classification settings
+# --------------------------------------------------------
+#
+# Create simple data
+# ..................
+#
+# Try to differentiate the first two classes of the iris data
from sklearn import svm, datasets
-from sklearn.metrics import precision_recall_curve
-from sklearn.metrics import average_precision_score
from sklearn.model_selection import train_test_split
-from sklearn.preprocessing import label_binarize
-from sklearn.multiclass import OneVsRestClassifier
+import numpy as np

-# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target

-# setup plot details
-colors = cycle(['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal'])
-lw = 2
-
-# Binarize the output
-y = label_binarize(y, classes=[0, 1, 2])
-n_classes = y.shape[1]
-
# Add noisy features
random_state = np.random.RandomState(0)
n_samples, n_features = X.shape
X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]

+# Limit to the first two classes, and split into training and test
+X_train, X_test, y_train, y_test = train_test_split(X[y < 2], y[y < 2],
+                                                    test_size=.5,
+                                                    random_state=random_state)
+
+# Create a simple classifier
+classifier = svm.LinearSVC(random_state=random_state)
+classifier.fit(X_train, y_train)
+y_score = classifier.decision_function(X_test)
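+# decision_function returns each sample's signed distance to the separating
+# hyperplane; any continuous score that ranks the samples (for example,
+# predict_proba output from a probabilistic classifier) could be used to
+# build the curve instead.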
+
+###############################################################################
+# Compute the average precision score
+# ...................................
+from sklearn.metrics import average_precision_score
+average_precision = average_precision_score(y_test, y_score)
+
+print('Average precision-recall score: {0:0.2f}'.format(
+      average_precision))
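+
+# As an illustrative aside (not part of the original example), the same
+# value can be recomputed from the operating points returned by
+# precision_recall_curve, following the AP definition above:
+#
+#   precision, recall, _ = precision_recall_curve(y_test, y_score)
+#   ap = -np.sum(np.diff(recall) * np.array(precision)[:-1])
+#
+# The minus sign appears because recall is returned in decreasing order.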
+
+###############################################################################
+# Plot the Precision-Recall curve
+# ................................
+from sklearn.metrics import precision_recall_curve
+import matplotlib.pyplot as plt
+
+precision, recall, _ = precision_recall_curve(y_test, y_score)
+
+plt.step(recall, precision, color='b', alpha=0.2,
+         where='post')
+plt.fill_between(recall, precision, step='post', alpha=0.2,
+                 color='b')
+
+plt.xlabel('Recall')
+plt.ylabel('Precision')
+plt.ylim([0.0, 1.05])
+plt.xlim([0.0, 1.0])
+plt.title('2-class Precision-Recall curve: AP={0:0.2f}'.format(
+          average_precision))
+
+###############################################################################
+# In multi-label settings
+# ------------------------
+#
+# Create multi-label data, fit, and predict
+# ...........................................
+#
+# We create a multi-label dataset to illustrate precision-recall in
+# multi-label settings.

+from sklearn.preprocessing import label_binarize
+
+# Use label_binarize to make the problem multi-label-like
+Y = label_binarize(y, classes=[0, 1, 2])
+n_classes = Y.shape[1]
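+# Y is a binary indicator matrix of shape (n_samples, 3), one column per
+# class; each row can then be read as a set of labels.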
+
# Split into training and test
-X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5,
+X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.5,
                                                    random_state=random_state)

+# We use OneVsRestClassifier for multi-label prediction
+from sklearn.multiclass import OneVsRestClassifier
# Run classifier
-classifier = OneVsRestClassifier(svm.SVC(kernel='linear', probability=True,
-                                 random_state=random_state))
-y_score = classifier.fit(X_train, y_train).decision_function(X_test)
+classifier = OneVsRestClassifier(svm.LinearSVC(random_state=random_state))
+classifier.fit(X_train, Y_train)
+y_score = classifier.decision_function(X_test)

-# Compute Precision-Recall and plot curve
+
+###############################################################################
+# The Precision-Recall score in multi-label settings
+# ..................................................
+from sklearn.metrics import precision_recall_curve
+from sklearn.metrics import average_precision_score
+
+# For each class
precision = dict()
recall = dict()
average_precision = dict()
for i in range(n_classes):
-    precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],
+    precision[i], recall[i], _ = precision_recall_curve(Y_test[:, i],
                                                        y_score[:, i])
-    average_precision[i] = average_precision_score(y_test[:, i], y_score[:, i])
+    average_precision[i] = average_precision_score(Y_test[:, i], y_score[:, i])

-# Compute micro-average ROC curve and ROC area
-precision["micro"], recall["micro"], _ = precision_recall_curve(y_test.ravel(),
+# A "micro-average": quantifying score on all classes jointly
+precision["micro"], recall["micro"], _ = precision_recall_curve(Y_test.ravel(),
                                                                y_score.ravel())
-average_precision["micro"] = average_precision_score(y_test, y_score,
+average_precision["micro"] = average_precision_score(Y_test, y_score,
                                                     average="micro")
+print('Precision-Recall micro-averaged over all classes: {0:0.2f}'.format(
+      average_precision["micro"]))
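+# Micro-averaging pools every (sample, class) pair into a single binary
+# task, so classes with more positive labels weigh more in the score.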

+###############################################################################
+# Plot the micro-averaged Precision-Recall curve
+# ...............................................
+#
+
+plt.figure()
+plt.step(recall["micro"], precision["micro"], color='b', alpha=0.2,
+         where='post')
+plt.fill_between(recall["micro"], precision["micro"], step='post', alpha=0.2,
+                 color='b')

-# Plot Precision-Recall curve
-plt.clf()
-plt.plot(recall[0], precision[0], lw=lw, color='navy',
-         label='Precision-Recall curve')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.ylim([0.0, 1.05])
plt.xlim([0.0, 1.0])
-plt.title('Precision-Recall example: AUC={0:0.2f}'.format(average_precision[0]))
-plt.legend(loc="lower left")
-plt.show()
+plt.title('Precision-Recall micro-averaged over all classes: AP={0:0.2f}'
+          .format(average_precision["micro"]))

+###############################################################################
# Plot Precision-Recall curve for each class and iso-f1 curves
-plt.clf()
+# .............................................................
+#
+from itertools import cycle
+# setup plot details
+colors = cycle(['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal'])
+
+plt.figure(figsize=(7, 8))
f_scores = np.linspace(0.2, 0.8, num=4)
lines = []
labels = []
...

lines.append(l)
labels.append('iso-f1 curves')
-l, = plt.plot(recall["micro"], precision["micro"], color='gold', lw=lw)
+l, = plt.plot(recall["micro"], precision["micro"], color='gold', lw=2)
lines.append(l)
-labels.append('micro-average Precision-recall curve (area = {0:0.2f})'
+labels.append('micro-average Precision-recall (area = {0:0.2f})'
              ''.format(average_precision["micro"]))
+
for i, color in zip(range(n_classes), colors):
-    l, = plt.plot(recall[i], precision[i], color=color, lw=lw)
+    l, = plt.plot(recall[i], precision[i], color=color, lw=2)
    lines.append(l)
-    labels.append('Precision-recall curve of class {0} (area = {1:0.2f})'
+    labels.append('Precision-recall for class {0} (area = {1:0.2f})'
                  ''.format(i, average_precision[i]))

fig = plt.gcf()
-fig.set_size_inches(7, 7)
fig.subplots_adjust(bottom=0.25)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Extension of Precision-Recall curve to multi-class')
-plt.figlegend(lines, labels, loc='lower center')
+plt.legend(lines, labels, loc=(0, -.38), prop=dict(size=14))
+
+
+###############################################################################
+# Eleven-point average precision
+# ------------------------------
+#
+# In *interpolated* average precision, a set of desired recall values is
+# specified and for each desired value, we average the best precision
+# scores possible with a recall value at least equal to the target value.
+# The most common choice is 'eleven point' interpolated precision, where
+# the desired recall values are [0, 0.1, 0.2, ..., 1.0]. This is the
+# metric referenced in `The PASCAL Visual Object Classes (VOC) Challenge
+# <http://citeseerx.ist.psu.edu
+# /viewdoc/download?doi=10.1.1.157.5766&rep=rep1&type=pdf>`_ (top of page
+# 11, formula 1). In the example below, the eleven precision values are
+# indicated with an arrow pointing to the best precision possible
+# while meeting or exceeding the desired recall. Note that it is possible
+# for the same operating point to correspond to multiple desired
+# recall values.
+
+from operator import itemgetter
+
+
+def pick_eleven_points(recall_, precision_):
+    """Choose the eleven operating points that correspond
+    to the best precision for any ``recall >= r`` for r in
+    [0, 0.1, 0.2, ..., 1.0]
+    """
+    operating_points = list()
+    for target_recall in np.arange(0, 1.1, 0.1):
+        operating_points_to_consider = [pair
+                                        for pair in zip(recall_, precision_)
+                                        if pair[0] >= target_recall]
+        operating_points.append(max(operating_points_to_consider,
+                                    key=itemgetter(1)))
+    return operating_points
+
+# Work on iris class 2 (the third of the three classes)
+iris_cls = 2
+
+eleven_points = pick_eleven_points(recall[iris_cls], precision[iris_cls])
+interpolated_average_precision = np.mean([e[1] for e in eleven_points])
+
+
+print("Target recall  Selected recall  Precision")
+for i in range(11):
+    # float division so the target recall prints correctly under Python 2
+    print("   >= {}   {: >12.3f} {: >12.3f}".format(i / 10.,
+                                                    *eleven_points[i]))
+
+print("   Average:{: >22.3f}".format(interpolated_average_precision))
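+
+# Note that this eleven-point interpolated value will generally differ
+# from average_precision_score, which averages over all operating points
+# without interpolation.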
+
+###############################################################################
+# Plot illustrating eleven-point average precision
+# .................................................
+
+plt.figure(figsize=(7, 7))
+plt.step(recall[iris_cls], precision[iris_cls], color='g', where='post',
+         alpha=0.5, linewidth=2,
+         label='Precision-recall curve of class {0} (area = {1:0.2f})'
+               ''.format(iris_cls, average_precision[iris_cls]))
+
+plt.fill_between(recall[iris_cls], precision[iris_cls], step='post', alpha=0.1,
+                 color='g')
+for i in range(11):
+    plt.annotate('',
+                 xy=(eleven_points[i][0], eleven_points[i][1]),
+                 xycoords='data', xytext=(i / 10., 0), textcoords='data',
+                 arrowprops=dict(arrowstyle="->", alpha=0.7,
+                                 connectionstyle="angle3,angleA=90,angleB=45"))
+
+
+plt.xlim([0.0, 1.0])
+plt.ylim([0.0, 1.05])
+plt.xticks(np.arange(0, 1.1, 0.1))
+plt.xlabel('Recall')
+plt.ylabel('Precision')
+plt.title('Eleven point Precision Recall for class\n {}'.format(iris_cls))
+plt.legend(loc="upper right")
+
plt.show()