26
26
27
27
print (__doc__ )
28
28
29
- import itertools
30
29
import numpy as np
31
30
import matplotlib .pyplot as plt
32
31
33
32
from sklearn import svm , datasets
34
33
from sklearn .model_selection import train_test_split
35
34
from sklearn .metrics import confusion_matrix
35
+ from sklearn .utils .multiclass import unique_labels
36
36
37
37
# import some data to play with
38
38
iris = datasets .load_iris ()
49
49
y_pred = classifier .fit (X_train , y_train ).predict (X_test )
50
50
51
51
52
- def plot_confusion_matrix (cm , classes ,
52
+ def plot_confusion_matrix (y_true , y_pred , classes ,
53
53
normalize = False ,
54
- title = 'Confusion matrix' ,
54
+ title = None ,
55
55
cmap = plt .cm .Blues ):
56
56
"""
57
57
This function prints and plots the confusion matrix.
58
58
Normalization can be applied by setting `normalize=True`.
59
59
"""
60
+ if not title :
61
+ if normalize :
62
+ title = 'Normalized confusion matrix'
63
+ else :
64
+ title = 'Confusion matrix, without normalization'
65
+
66
+ # Compute confusion matrix
67
+ cm = confusion_matrix (y_true , y_pred )
68
+ # Only use the labels that appear in the data
69
+ classes = classes [unique_labels (y_true , y_pred )]
60
70
if normalize :
61
71
cm = cm .astype ('float' ) / cm .sum (axis = 1 )[:, np .newaxis ]
62
72
print ("Normalized confusion matrix" )
@@ -65,37 +75,42 @@ def plot_confusion_matrix(cm, classes,
65
75
66
76
print (cm )
67
77
68
- plt .imshow (cm , interpolation = 'nearest' , cmap = cmap )
69
- plt .title (title )
70
- plt .colorbar ()
71
- tick_marks = np .arange (len (classes ))
72
- plt .xticks (tick_marks , classes , rotation = 45 )
73
- plt .yticks (tick_marks , classes )
74
-
78
+ fig , ax = plt .subplots ()
79
+ im = ax .imshow (cm , interpolation = 'nearest' , cmap = cmap )
80
+ ax .figure .colorbar (im , ax = ax )
81
+ # We want to show all ticks...
82
+ ax .set (xticks = np .arange (cm .shape [1 ]),
83
+ yticks = np .arange (cm .shape [0 ]),
84
+ # ... and label them with the respective list entries
85
+ xticklabels = classes , yticklabels = classes ,
86
+ title = title ,
87
+ ylabel = 'True label' ,
88
+ xlabel = 'Predicted label' )
89
+
90
+ # Rotate the tick labels and set their alignment.
91
+ plt .setp (ax .get_xticklabels (), rotation = 45 , ha = "right" ,
92
+ rotation_mode = "anchor" )
93
+
94
+ # Loop over data dimensions and create text annotations.
75
95
fmt = '.2f' if normalize else 'd'
76
96
thresh = cm .max () / 2.
77
- for i , j in itertools .product (range (cm .shape [0 ]), range (cm .shape [1 ])):
78
- plt .text (j , i , format (cm [i , j ], fmt ),
79
- horizontalalignment = "center" ,
80
- color = "white" if cm [i , j ] > thresh else "black" )
81
-
82
- plt .ylabel ('True label' )
83
- plt .xlabel ('Predicted label' )
84
- plt .tight_layout ()
97
+ for i in range (cm .shape [0 ]):
98
+ for j in range (cm .shape [1 ]):
99
+ ax .text (j , i , format (cm [i , j ], fmt ),
100
+ ha = "center" , va = "center" ,
101
+ color = "white" if cm [i , j ] > thresh else "black" )
102
+ fig .tight_layout ()
103
+ return ax
85
104
86
105
87
- # Compute confusion matrix
88
- cnf_matrix = confusion_matrix (y_test , y_pred )
89
106
np .set_printoptions (precision = 2 )
90
107
91
108
# Plot non-normalized confusion matrix
92
- plt .figure ()
93
- plot_confusion_matrix (cnf_matrix , classes = class_names ,
109
+ plot_confusion_matrix (y_test , y_pred , classes = class_names ,
94
110
title = 'Confusion matrix, without normalization' )
95
111
96
112
# Plot normalized confusion matrix
97
- plt .figure ()
98
- plot_confusion_matrix (cnf_matrix , classes = class_names , normalize = True ,
113
+ plot_confusion_matrix (y_test , y_pred , classes = class_names , normalize = True ,
99
114
title = 'Normalized confusion matrix' )
100
115
101
116
plt .show ()
0 commit comments