Skip to content

Commit 8c6a045

Browse files
glemaitre, thomasjpfan, and adrinjalali
authored
ENH/DEP add class method and deprecate plot function for confusion matrix (scikit-learn#18543)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
1 parent 7404a82 commit 8c6a045

File tree

7 files changed

+719
-11
lines changed

7 files changed

+719
-11
lines changed

doc/modules/model_evaluation.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ predicted to be in group :math:`j`. Here is an example::
613613
[0, 0, 1],
614614
[1, 0, 2]])
615615

616-
:func:`plot_confusion_matrix` can be used to visually represent a confusion
616+
:class:`ConfusionMatrixDisplay` can be used to visually represent a confusion
617617
matrix as shown in the
618618
:ref:`sphx_glr_auto_examples_model_selection_plot_confusion_matrix.py`
619619
example, which creates the following figure:

doc/whats_new/v1.0.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,17 @@ Changelog
9191
:pr:`17743` by :user:`Maria Telenczuk <maikia>` and
9292
:user:`Alexandre Gramfort <agramfort>`.
9393

94+
:mod:`sklearn.metrics`
95+
......................
96+
97+
- |API| :class:`metrics.ConfusionMatrixDisplay` exposes two class methods
98+
:func:`~metrics.ConfusionMatrixDisplay.from_estimator` and
99+
:func:`~metrics.ConfusionMatrixDisplay.from_predictions` allowing users to create
100+
a confusion matrix plot using an estimator or the predictions.
101+
:func:`metrics.plot_confusion_matrix` is deprecated in favor of these two
102+
class methods and will be removed in 1.2.
103+
:pr:`18543` by `Guillaume Lemaitre`_.
104+
94105
:mod:`sklearn.naive_bayes`
95106
..........................
96107

examples/classification/plot_digits_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@
9595
# We can also plot a :ref:`confusion matrix <confusion_matrix>` of the
9696
# true digit values and the predicted digit values.
9797

98-
disp = metrics.plot_confusion_matrix(clf, X_test, y_test)
98+
disp = metrics.ConfusionMatrixDisplay.from_predictions(y_test, predicted)
9999
disp.figure_.suptitle("Confusion Matrix")
100100
print(f"Confusion matrix:\n{disp.confusion_matrix}")
101101

examples/model_selection/plot_confusion_matrix.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
from sklearn import svm, datasets
3333
from sklearn.model_selection import train_test_split
34-
from sklearn.metrics import plot_confusion_matrix
34+
from sklearn.metrics import ConfusionMatrixDisplay
3535

3636
# import some data to play with
3737
iris = datasets.load_iris()
@@ -52,10 +52,10 @@
5252
titles_options = [("Confusion matrix, without normalization", None),
5353
("Normalized confusion matrix", 'true')]
5454
for title, normalize in titles_options:
55-
disp = plot_confusion_matrix(classifier, X_test, y_test,
56-
display_labels=class_names,
57-
cmap=plt.cm.Blues,
58-
normalize=normalize)
55+
disp = ConfusionMatrixDisplay.from_estimator(
56+
classifier, X_test, y_test, display_labels=class_names,
57+
cmap=plt.cm.Blues, normalize=normalize
58+
)
5959
disp.ax_.set_title(title)
6060

6161
print(title)

sklearn/metrics/_plot/confusion_matrix.py

Lines changed: 286 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from .. import confusion_matrix
66
from ...utils import check_matplotlib_support
7+
from ...utils import deprecated
78
from ...utils.multiclass import unique_labels
89
from ...utils.validation import _deprecate_positional_args
910
from ...base import is_classifier
@@ -12,7 +13,9 @@
1213
class ConfusionMatrixDisplay:
1314
"""Confusion Matrix visualization.
1415
15-
It is recommend to use :func:`~sklearn.metrics.plot_confusion_matrix` to
16+
It is recommended to use
17+
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_estimator` or
18+
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_predictions` to
1619
create a :class:`ConfusionMatrixDisplay`. All parameters are stored as
1720
attributes.
1821
@@ -161,7 +164,274 @@ def plot(self, *, include_values=True, cmap='viridis',
161164
self.ax_ = ax
162165
return self
163166

167+
@classmethod
168+
def from_estimator(
169+
cls,
170+
estimator,
171+
X,
172+
y,
173+
*,
174+
labels=None,
175+
sample_weight=None,
176+
normalize=None,
177+
display_labels=None,
178+
include_values=True,
179+
xticks_rotation="horizontal",
180+
values_format=None,
181+
cmap="viridis",
182+
ax=None,
183+
colorbar=True,
184+
):
185+
"""Plot Confusion Matrix given an estimator and some data.
186+
187+
Read more in the :ref:`User Guide <confusion_matrix>`.
188+
189+
.. versionadded:: 1.0
164190
191+
Parameters
192+
----------
193+
estimator : estimator instance
194+
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
195+
in which the last estimator is a classifier.
196+
197+
X : {array-like, sparse matrix} of shape (n_samples, n_features)
198+
Input values.
199+
200+
y : array-like of shape (n_samples,)
201+
Target values.
202+
203+
labels : array-like of shape (n_classes,), default=None
204+
List of labels to index the confusion matrix. This may be used to
205+
reorder or select a subset of labels. If `None` is given, those
206+
that appear at least once in `y` or in the estimator's predictions are used in
207+
sorted order.
208+
209+
sample_weight : array-like of shape (n_samples,), default=None
210+
Sample weights.
211+
212+
normalize : {'true', 'pred', 'all'}, default=None
213+
Whether and how to normalize the counts displayed in the matrix:
214+
215+
- if `'true'`, the confusion matrix is normalized over the true
216+
conditions (i.e. rows);
217+
- if `'pred'`, the confusion matrix is normalized over the
218+
predicted conditions (i.e. columns);
219+
- if `'all'`, the confusion matrix is normalized by the total
220+
number of samples;
221+
- if `None` (default), the confusion matrix will not be normalized.
222+
223+
display_labels : array-like of shape (n_classes,), default=None
224+
Target names used for plotting. By default, `labels` will be used
225+
if it is defined, otherwise the unique labels of `y` and the predicted
226+
`y_pred` will be used.
227+
228+
include_values : bool, default=True
229+
Includes values in confusion matrix.
230+
231+
xticks_rotation : {'vertical', 'horizontal'} or float, \
232+
default='horizontal'
233+
Rotation of xtick labels.
234+
235+
values_format : str, default=None
236+
Format specification for values in confusion matrix. If `None`, the
237+
format specification is 'd' or '.2g', whichever is shorter.
238+
239+
cmap : str or matplotlib Colormap, default='viridis'
240+
Colormap recognized by matplotlib.
241+
242+
ax : matplotlib Axes, default=None
243+
Axes object to plot on. If `None`, a new figure and axes is
244+
created.
245+
246+
colorbar : bool, default=True
247+
Whether or not to add a colorbar to the plot.
248+
249+
Returns
250+
-------
251+
display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
252+
253+
See Also
254+
--------
255+
ConfusionMatrixDisplay.from_predictions : Plot the confusion matrix
256+
given the true and predicted labels.
257+
258+
Examples
259+
--------
260+
>>> import matplotlib.pyplot as plt # doctest: +SKIP
261+
>>> from sklearn.datasets import make_classification
262+
>>> from sklearn.metrics import ConfusionMatrixDisplay
263+
>>> from sklearn.model_selection import train_test_split
264+
>>> from sklearn.svm import SVC
265+
>>> X, y = make_classification(random_state=0)
266+
>>> X_train, X_test, y_train, y_test = train_test_split(
267+
... X, y, random_state=0)
268+
>>> clf = SVC(random_state=0)
269+
>>> clf.fit(X_train, y_train)
270+
SVC(random_state=0)
271+
>>> ConfusionMatrixDisplay.from_estimator(
272+
... clf, X_test, y_test) # doctest: +SKIP
273+
>>> plt.show() # doctest: +SKIP
274+
"""
275+
method_name = f"{cls.__name__}.from_estimator"
276+
check_matplotlib_support(method_name)
277+
if not is_classifier(estimator):
278+
raise ValueError(f"{method_name} only supports classifiers")
279+
y_pred = estimator.predict(X)
280+
281+
return cls.from_predictions(
282+
y,
283+
y_pred,
284+
sample_weight=sample_weight,
285+
labels=labels,
286+
normalize=normalize,
287+
display_labels=display_labels,
288+
include_values=include_values,
289+
cmap=cmap,
290+
ax=ax,
291+
xticks_rotation=xticks_rotation,
292+
values_format=values_format,
293+
colorbar=colorbar,
294+
)
295+
296+
@classmethod
297+
def from_predictions(
298+
cls,
299+
y_true,
300+
y_pred,
301+
*,
302+
labels=None,
303+
sample_weight=None,
304+
normalize=None,
305+
display_labels=None,
306+
include_values=True,
307+
xticks_rotation="horizontal",
308+
values_format=None,
309+
cmap="viridis",
310+
ax=None,
311+
colorbar=True,
312+
):
313+
"""Plot Confusion Matrix given true and predicted labels.
314+
315+
Read more in the :ref:`User Guide <confusion_matrix>`.
316+
317+
.. versionadded:: 1.0
318+
319+
Parameters
320+
----------
321+
y_true : array-like of shape (n_samples,)
322+
True labels.
323+
324+
y_pred : array-like of shape (n_samples,)
325+
The predicted labels given by the method `predict` of a
326+
classifier.
327+
328+
labels : array-like of shape (n_classes,), default=None
329+
List of labels to index the confusion matrix. This may be used to
330+
reorder or select a subset of labels. If `None` is given, those
331+
that appear at least once in `y_true` or `y_pred` are used in
332+
sorted order.
333+
334+
sample_weight : array-like of shape (n_samples,), default=None
335+
Sample weights.
336+
337+
normalize : {'true', 'pred', 'all'}, default=None
338+
Whether and how to normalize the counts displayed in the matrix:
339+
340+
- if `'true'`, the confusion matrix is normalized over the true
341+
conditions (i.e. rows);
342+
- if `'pred'`, the confusion matrix is normalized over the
343+
predicted conditions (i.e. columns);
344+
- if `'all'`, the confusion matrix is normalized by the total
345+
number of samples;
346+
- if `None` (default), the confusion matrix will not be normalized.
347+
348+
display_labels : array-like of shape (n_classes,), default=None
349+
Target names used for plotting. By default, `labels` will be used
350+
if it is defined, otherwise the unique labels of `y_true` and
351+
`y_pred` will be used.
352+
353+
include_values : bool, default=True
354+
Includes values in confusion matrix.
355+
356+
xticks_rotation : {'vertical', 'horizontal'} or float, \
357+
default='horizontal'
358+
Rotation of xtick labels.
359+
360+
values_format : str, default=None
361+
Format specification for values in confusion matrix. If `None`, the
362+
format specification is 'd' or '.2g', whichever is shorter.
363+
364+
cmap : str or matplotlib Colormap, default='viridis'
365+
Colormap recognized by matplotlib.
366+
367+
ax : matplotlib Axes, default=None
368+
Axes object to plot on. If `None`, a new figure and axes is
369+
created.
370+
371+
colorbar : bool, default=True
372+
Whether or not to add a colorbar to the plot.
373+
374+
Returns
375+
-------
376+
display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
377+
378+
See Also
379+
--------
380+
ConfusionMatrixDisplay.from_estimator : Plot the confusion matrix
381+
given an estimator, the data, and the label.
382+
383+
Examples
384+
--------
385+
>>> import matplotlib.pyplot as plt # doctest: +SKIP
386+
>>> from sklearn.datasets import make_classification
387+
>>> from sklearn.metrics import ConfusionMatrixDisplay
388+
>>> from sklearn.model_selection import train_test_split
389+
>>> from sklearn.svm import SVC
390+
>>> X, y = make_classification(random_state=0)
391+
>>> X_train, X_test, y_train, y_test = train_test_split(
392+
... X, y, random_state=0)
393+
>>> clf = SVC(random_state=0)
394+
>>> clf.fit(X_train, y_train)
395+
SVC(random_state=0)
396+
>>> y_pred = clf.predict(X_test)
397+
>>> ConfusionMatrixDisplay.from_predictions(
398+
... y_test, y_pred) # doctest: +SKIP
399+
>>> plt.show() # doctest: +SKIP
400+
"""
401+
check_matplotlib_support(f"{cls.__name__}.from_predictions")
402+
403+
if display_labels is None:
404+
if labels is None:
405+
display_labels = unique_labels(y_true, y_pred)
406+
else:
407+
display_labels = labels
408+
409+
cm = confusion_matrix(
410+
y_true,
411+
y_pred,
412+
sample_weight=sample_weight,
413+
labels=labels,
414+
normalize=normalize,
415+
)
416+
417+
disp = cls(confusion_matrix=cm, display_labels=display_labels)
418+
419+
return disp.plot(
420+
include_values=include_values,
421+
cmap=cmap,
422+
ax=ax,
423+
xticks_rotation=xticks_rotation,
424+
values_format=values_format,
425+
colorbar=colorbar,
426+
)
427+
428+
429+
@deprecated(
430+
"Function plot_confusion_matrix is deprecated in 1.0 and will be "
431+
"removed in 1.2. Use one of the class methods: "
432+
"ConfusionMatrixDisplay.from_predictions or "
433+
"ConfusionMatrixDisplay.from_estimator."
434+
)
165435
@_deprecate_positional_args
166436
def plot_confusion_matrix(estimator, X, y_true, *, labels=None,
167437
sample_weight=None, normalize=None,
@@ -173,6 +443,12 @@ def plot_confusion_matrix(estimator, X, y_true, *, labels=None,
173443
174444
Read more in the :ref:`User Guide <confusion_matrix>`.
175445
446+
.. deprecated:: 1.0
447+
`plot_confusion_matrix` is deprecated in 1.0 and will be removed in
448+
1.2. Use one of the following class methods:
449+
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_predictions` or
450+
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_estimator`.
451+
176452
Parameters
177453
----------
178454
estimator : estimator instance
@@ -194,9 +470,15 @@ def plot_confusion_matrix(estimator, X, y_true, *, labels=None,
194470
Sample weights.
195471
196472
normalize : {'true', 'pred', 'all'}, default=None
197-
Normalizes confusion matrix over the true (rows), predicted (columns)
198-
conditions or all the population. If None, confusion matrix will not be
199-
normalized.
473+
Whether and how to normalize the counts displayed in the matrix:
474+
475+
- if `'true'`, the confusion matrix is normalized over the true
476+
conditions (i.e. rows);
477+
- if `'pred'`, the confusion matrix is normalized over the
478+
predicted conditions (i.e. columns);
479+
- if `'all'`, the confusion matrix is normalized by the total
480+
number of samples;
481+
- if `None` (default), the confusion matrix will not be normalized.
200482
201483
display_labels : array-like of shape (n_classes,), default=None
202484
Target names used for plotting. By default, `labels` will be used if

0 commit comments

Comments
 (0)