Skip to content

Commit a1a29c2

Browse files
MLopez-Ibanez authored and jnothman committed
EXA plot_confusion_matrix example breaks down if not all classes present (scikit-learn#13126)
* fix scikit-learn#12700 plot_confusion_matrix example breaks down if not all classes are present in the test data
* plot_confusion_matrix: update function call, fix style issues
* remove redundant confusion_matrix call
1 parent c62a0e9 commit a1a29c2

File tree

1 file changed

+39
-24
lines changed

1 file changed

+39
-24
lines changed

examples/model_selection/plot_confusion_matrix.py

Lines changed: 39 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -26,13 +26,13 @@
2626

2727
print(__doc__)
2828

29-
import itertools
3029
import numpy as np
3130
import matplotlib.pyplot as plt
3231

3332
from sklearn import svm, datasets
3433
from sklearn.model_selection import train_test_split
3534
from sklearn.metrics import confusion_matrix
35+
from sklearn.utils.multiclass import unique_labels
3636

3737
# import some data to play with
3838
iris = datasets.load_iris()
@@ -49,14 +49,24 @@
4949
y_pred = classifier.fit(X_train, y_train).predict(X_test)
5050

5151

52-
def plot_confusion_matrix(cm, classes,
52+
def plot_confusion_matrix(y_true, y_pred, classes,
5353
normalize=False,
54-
title='Confusion matrix',
54+
title=None,
5555
cmap=plt.cm.Blues):
5656
"""
5757
This function prints and plots the confusion matrix.
5858
Normalization can be applied by setting `normalize=True`.
5959
"""
60+
if not title:
61+
if normalize:
62+
title = 'Normalized confusion matrix'
63+
else:
64+
title = 'Confusion matrix, without normalization'
65+
66+
# Compute confusion matrix
67+
cm = confusion_matrix(y_true, y_pred)
68+
# Only use the labels that appear in the data
69+
classes = classes[unique_labels(y_true, y_pred)]
6070
if normalize:
6171
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
6272
print("Normalized confusion matrix")
@@ -65,37 +75,42 @@ def plot_confusion_matrix(cm, classes,
6575

6676
print(cm)
6777

68-
plt.imshow(cm, interpolation='nearest', cmap=cmap)
69-
plt.title(title)
70-
plt.colorbar()
71-
tick_marks = np.arange(len(classes))
72-
plt.xticks(tick_marks, classes, rotation=45)
73-
plt.yticks(tick_marks, classes)
74-
78+
fig, ax = plt.subplots()
79+
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
80+
ax.figure.colorbar(im, ax=ax)
81+
# We want to show all ticks...
82+
ax.set(xticks=np.arange(cm.shape[1]),
83+
yticks=np.arange(cm.shape[0]),
84+
# ... and label them with the respective list entries
85+
xticklabels=classes, yticklabels=classes,
86+
title=title,
87+
ylabel='True label',
88+
xlabel='Predicted label')
89+
90+
# Rotate the tick labels and set their alignment.
91+
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
92+
rotation_mode="anchor")
93+
94+
# Loop over data dimensions and create text annotations.
7595
fmt = '.2f' if normalize else 'd'
7696
thresh = cm.max() / 2.
77-
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
78-
plt.text(j, i, format(cm[i, j], fmt),
79-
horizontalalignment="center",
80-
color="white" if cm[i, j] > thresh else "black")
81-
82-
plt.ylabel('True label')
83-
plt.xlabel('Predicted label')
84-
plt.tight_layout()
97+
for i in range(cm.shape[0]):
98+
for j in range(cm.shape[1]):
99+
ax.text(j, i, format(cm[i, j], fmt),
100+
ha="center", va="center",
101+
color="white" if cm[i, j] > thresh else "black")
102+
fig.tight_layout()
103+
return ax
85104

86105

87-
# Compute confusion matrix
88-
cnf_matrix = confusion_matrix(y_test, y_pred)
89106
np.set_printoptions(precision=2)
90107

91108
# Plot non-normalized confusion matrix
92-
plt.figure()
93-
plot_confusion_matrix(cnf_matrix, classes=class_names,
109+
plot_confusion_matrix(y_test, y_pred, classes=class_names,
94110
title='Confusion matrix, without normalization')
95111

96112
# Plot normalized confusion matrix
97-
plt.figure()
98-
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
113+
plot_confusion_matrix(y_test, y_pred, classes=class_names, normalize=True,
99114
title='Normalized confusion matrix')
100115

101116
plt.show()

0 commit comments

Comments
 (0)