Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
939df89
ENH Adds plot_confusion matrix
thomasjpfan Aug 22, 2019
f106d48
DOC Adds attributes
thomasjpfan Sep 24, 2019
eb36a09
CLN Removes unneeded tests
thomasjpfan Sep 25, 2019
10d70e3
Merge remote-tracking branch 'upstream/master' into plot_confusion_ma…
thomasjpfan Sep 25, 2019
7f5f029
ENH Colormap dependent text labels
thomasjpfan Sep 25, 2019
4b29bb4
DOC Use better name
thomasjpfan Sep 25, 2019
835b889
ENH Adds format_values
thomasjpfan Oct 10, 2019
e1ed771
ENH Adds text formatting
thomasjpfan Oct 10, 2019
ed6d2fa
Merge remote-tracking branch 'upstream/master' into plot_confusion_ma…
thomasjpfan Oct 10, 2019
3671311
API Changes normalize default
thomasjpfan Oct 10, 2019
6a21e80
DOC Adds confusion matrix to another example
thomasjpfan Oct 10, 2019
48c1281
TST Adds contrast test
thomasjpfan Oct 10, 2019
04df25d
CLN Uses display_labels
thomasjpfan Oct 23, 2019
511d60c
DOC Fix function call
thomasjpfan Oct 25, 2019
49e5afc
Merge remote-tracking branch 'upstream/master' into plot_confusion_ma…
thomasjpfan Oct 28, 2019
7db73bb
WIP
thomasjpfan Oct 28, 2019
7d3a802
WIP
thomasjpfan Oct 28, 2019
4eea3b1
BUG Address some comments
thomasjpfan Oct 29, 2019
24da8db
DOC Fixes confusion matrix style
thomasjpfan Oct 29, 2019
1407e75
Merge remote-tracking branch 'upstream/master' into plot_confusion_ma…
thomasjpfan Oct 30, 2019
ba39d07
CLN Address glemaitres comments
thomasjpfan Oct 30, 2019
fe8a572
DOC Adds plot_confusion_matrix to user guide
thomasjpfan Oct 30, 2019
9a9c24b
STY Flake8
thomasjpfan Oct 30, 2019
8272cbb
ENH Adds normalization options
thomasjpfan Nov 6, 2019
af7366b
Merge remote-tracking branch 'upstream/master' into plot_confusion_ma…
thomasjpfan Nov 6, 2019
449c0a1
DOC Fix
thomasjpfan Nov 6, 2019
8265856
CLN Updates options for normalization
thomasjpfan Nov 6, 2019
3fcecf6
CLN Reduce number of tests
thomasjpfan Nov 7, 2019
a89f662
STY Flake8
thomasjpfan Nov 7, 2019
c06843d
CLN Address comments
thomasjpfan Nov 8, 2019
9fcdecc
Merge remote-tracking branch 'origin/master' into pr/thomasjpfan/15083
glemaitre Nov 14, 2019
c13b84b
TST check fitted error with pipeline
glemaitre Nov 14, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/modules/classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1082,13 +1082,15 @@ See the :ref:`visualizations` section of the user guide for further details.
:toctree: generated/
:template: function.rst

metrics.plot_confusion_matrix
metrics.plot_precision_recall_curve
metrics.plot_roc_curve

.. autosummary::
:toctree: generated/
:template: class.rst

metrics.ConfusionMatrixDisplay
metrics.PrecisionRecallDisplay
metrics.RocCurveDisplay

Expand Down
6 changes: 4 additions & 2 deletions doc/modules/model_evaluation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -573,8 +573,10 @@ predicted to be in group :math:`j`. Here is an example::
[0, 0, 1],
[1, 0, 2]])

Here is a visual representation of such a confusion matrix (this figure comes
from the :ref:`sphx_glr_auto_examples_model_selection_plot_confusion_matrix.py` example):
:func:`plot_confusion_matrix` can be used to visually represent a confusion
matrix as shown in the
:ref:`sphx_glr_auto_examples_model_selection_plot_confusion_matrix.py`
example, which creates the following figure:

.. image:: ../auto_examples/model_selection/images/sphx_glr_plot_confusion_matrix_001.png
:target: ../auto_examples/model_selection/plot_confusion_matrix.html
Expand Down
2 changes: 2 additions & 0 deletions doc/visualizations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Functions
.. autosummary::

inspection.plot_partial_dependence
metrics.plot_confusion_matrix
metrics.plot_precision_recall_curve
metrics.plot_roc_curve

Expand All @@ -84,5 +85,6 @@ Display Objects
.. autosummary::

inspection.PartialDependenceDisplay
metrics.ConfusionMatrixDisplay
metrics.PrecisionRecallDisplay
metrics.RocCurveDisplay
27 changes: 14 additions & 13 deletions examples/classification/plot_digits_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
# matplotlib.pyplot.imread. Note that each image must have the same size. For these
# images, we know which digit they represent: it is given in the 'target' of
# the dataset.
_, axes = plt.subplots(2, 4)
images_and_labels = list(zip(digits.images, digits.target))
for index, (image, label) in enumerate(images_and_labels[:4]):
plt.subplot(2, 4, index + 1)
plt.axis('off')
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
plt.title('Training: %i' % label)
for ax, (image, label) in zip(axes[0, :], images_and_labels[:4]):
ax.set_axis_off()
ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
ax.set_title('Training: %i' % label)

# To apply a classifier on this data, we need to flatten the image, to
# turn the data into a (samples, feature) matrix:
Expand All @@ -56,15 +56,16 @@
# Now predict the value of the digit on the second half:
predicted = classifier.predict(X_test)

images_and_predictions = list(zip(digits.images[n_samples // 2:], predicted))
for ax, (image, prediction) in zip(axes[1, :], images_and_predictions[:4]):
ax.set_axis_off()
ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
ax.set_title('Prediction: %i' % prediction)

print("Classification report for classifier %s:\n%s\n"
% (classifier, metrics.classification_report(y_test, predicted)))
print("Confusion matrix:\n%s" % metrics.confusion_matrix(y_test, predicted))

images_and_predictions = list(zip(digits.images[n_samples // 2:], predicted))
for index, (image, prediction) in enumerate(images_and_predictions[:4]):
plt.subplot(2, 4, index + 5)
plt.axis('off')
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
plt.title('Prediction: %i' % prediction)
disp = metrics.plot_confusion_matrix(classifier, X_test, y_test)
disp.figure_.suptitle("Confusion Matrix")
print("Confusion matrix:\n%s" % disp.confusion_matrix)

plt.show()
78 changes: 13 additions & 65 deletions examples/model_selection/plot_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@

from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels
from sklearn.metrics import plot_confusion_matrix

# import some data to play with
iris = datasets.load_iris()
Expand All @@ -45,72 +44,21 @@

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)


def plot_confusion_matrix(y_true, y_pred, classes,
normalize=False,
title=None,
cmap=plt.cm.Blues):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
"""
if not title:
if normalize:
title = 'Normalized confusion matrix'
else:
title = 'Confusion matrix, without normalization'

# Compute confusion matrix
cm = confusion_matrix(y_true, y_pred)
# Only use the labels that appear in the data
classes = classes[unique_labels(y_true, y_pred)]
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print('Confusion matrix, without normalization')

print(cm)

fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
ax.figure.colorbar(im, ax=ax)
# We want to show all ticks...
ax.set(xticks=np.arange(cm.shape[1]),
yticks=np.arange(cm.shape[0]),
# ... and label them with the respective list entries
xticklabels=classes, yticklabels=classes,
title=title,
ylabel='True label',
xlabel='Predicted label')

# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")

# Loop over data dimensions and create text annotations.
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
ax.text(j, i, format(cm[i, j], fmt),
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black")
fig.tight_layout()
return ax

classifier = svm.SVC(kernel='linear', C=0.01).fit(X_train, y_train)

np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
plot_confusion_matrix(y_test, y_pred, classes=class_names,
title='Confusion matrix, without normalization')

# Plot normalized confusion matrix
plot_confusion_matrix(y_test, y_pred, classes=class_names, normalize=True,
title='Normalized confusion matrix')
titles_options = [("Confusion matrix, without normalization", None),
("Normalized confusion matrix", 'true')]
for title, normalize in titles_options:
disp = plot_confusion_matrix(classifier, X_test, y_test,
display_labels=class_names,
cmap=plt.cm.Blues,
normalize=normalize)
disp.ax_.set_title(title)

print(title)
print(disp.confusion_matrix)

plt.show()
5 changes: 5 additions & 0 deletions sklearn/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@
from ._plot.precision_recall_curve import plot_precision_recall_curve
from ._plot.precision_recall_curve import PrecisionRecallDisplay

from ._plot.confusion_matrix import plot_confusion_matrix
from ._plot.confusion_matrix import ConfusionMatrixDisplay


__all__ = [
'accuracy_score',
Expand All @@ -97,6 +100,7 @@
'cluster',
'cohen_kappa_score',
'completeness_score',
'ConfusionMatrixDisplay',
'confusion_matrix',
'consensus_score',
'coverage_error',
Expand Down Expand Up @@ -137,6 +141,7 @@
'pairwise_distances_argmin_min',
'pairwise_distances_chunked',
'pairwise_kernels',
'plot_confusion_matrix',
'plot_precision_recall_curve',
'plot_roc_curve',
'PrecisionRecallDisplay',
Expand Down
Loading