
ENH add pos_label to confusion_matrix #26839

Open · wants to merge 13 commits into main
9 changes: 9 additions & 0 deletions doc/whats_new/v1.4.rst
@@ -377,6 +377,15 @@ Changelog
:func:`sklearn.metrics.zero_one_loss` now support Array API compatible inputs.
:pr:`27137` by :user:`Edoardo Abati <EdAbati>`.

- |Enhancement| |Fix| Added a `pos_label` parameter to :func:`metrics.confusion_matrix`
to avoid ambiguity regarding the position of the positive class label in the
matrix. An error is raised if the positive class label cannot be set to `1`
(i.e. when `pos_label=None` and the labels are not `{-1, 1}` or `{0, 1}`), and
also if `pos_label` is set on classification problems other than binary.
`pos_label` is also added to
:meth:`metrics.ConfusionMatrixDisplay.from_estimator` and
:meth:`metrics.ConfusionMatrixDisplay.from_predictions`.
:pr:`26839` by :user:`Guillaume Lemaitre <glemaitre>`.
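
A minimal usage sketch of the behaviour described above (illustration only, assuming the semantics proposed in this PR):

import numpy as np
from sklearn.metrics import confusion_matrix

y_true = np.array([0, 0, 1, 1, 1])
y_pred = np.array([0, 1, 1, 1, 0])

# Default: with classes {0, 1}, the positive class is 1, so TP sits at cm[1, 1].
confusion_matrix(y_true, y_pred)
# array([[1, 1],
#        [1, 2]])

# With pos_label=0 the matrix is reordered so that class 0 plays the positive
# role, i.e. "true 0 predicted 0" is counted at cm[1, 1].
confusion_matrix(y_true, y_pred, pos_label=0)
# array([[2, 1],
#        [1, 1]])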

- |API| Deprecated `needs_threshold` and `needs_proba` from :func:`metrics.make_scorer`.
These parameters will be removed in version 1.6. Instead, use `response_method` that
accepts `"predict"`, `"predict_proba"` or `"decision_function"` or a list of such
72 changes: 68 additions & 4 deletions sklearn/metrics/_classification.py
@@ -226,13 +226,14 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
"y_true": ["array-like"],
"y_pred": ["array-like"],
"labels": ["array-like", None],
"pos_label": [Real, str, "boolean", None],
"sample_weight": ["array-like", None],
"normalize": [StrOptions({"true", "pred", "all"}), None],
},
prefer_skip_nested_validation=True,
)
def confusion_matrix(
y_true, y_pred, *, labels=None, sample_weight=None, normalize=None
y_true, y_pred, *, labels=None, pos_label=None, sample_weight=None, normalize=None
):
"""Compute confusion matrix to evaluate the accuracy of a classification.

@@ -260,6 +261,15 @@ def confusion_matrix(
If ``None`` is given, those that appear at least once
in ``y_true`` or ``y_pred`` are used in sorted order.

pos_label : int, float, bool or str, default=None
The label of the positive class for binary classification.
When `pos_label=None`, if `y_true` is in `{-1, 1}` or `{0, 1}`,
`pos_label` is set to 1, otherwise an error will be raised.
An error is also raised if `pos_label` is set and `y_true` is not a binary
classification problem.

.. versionadded:: 1.4

sample_weight : array-like of shape (n_samples,), default=None
Sample weights.

@@ -320,6 +330,19 @@ def confusion_matrix(
if y_type not in ("binary", "multiclass"):
raise ValueError("%s is not supported" % y_type)

if y_true.size == 0 and y_pred.size == 0:
# early return for empty arrays avoiding all checks
n_classes = 0 if labels is None else len(labels)
return np.zeros((n_classes, n_classes), dtype=int)

if y_type == "binary":
pos_label = _check_pos_label_consistency(pos_label, y_true)
elif pos_label is not None:
Member Author:

I am wondering if we are not shooting ourselves in the foot with this ValueError.

If we change the default to pos_label=1, then we will by default raise this error whenever y_type != "binary". If we choose pos_label=1 as a default, then we need to adopt the same strategy as precision/recall, meaning that we need to ignore pos_label.

Am I right @lucyleeow, or did I miss something?
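
To make the concern concrete, an illustration-only sketch of the behaviour as the diff currently stands (with a pos_label=1 default, this is effectively what every multiclass call would hit):

from sklearn.metrics import confusion_matrix

y_multiclass = [0, 1, 2, 2, 1]

# With the ValueError above, setting pos_label on a non-binary target raises:
# ValueError: `pos_label` should only be set when the target is binary. Got
# multiclass type of target instead.
confusion_matrix(y_multiclass, y_multiclass, pos_label=1)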

Member Author:

Another advantage of implicitly ignoring pos_label is that we no longer need to call _check_targets, and thus the unique function, because we can pass the default value most of the time.

Member:

> then we will by default raise this error whenever y_type != "binary"

I think we could set default to 1 and change this to:

Suggested change:
- elif pos_label is not None:
+ elif pos_label != 1:

_check_set_wise_labels does similar checking:

elif pos_label not in (None, 1):
warnings.warn(
"Note that pos_label (set to %r) is ignored when "
"average != 'binary' (got %r). You may use "
"labels=[pos_label] to specify a single positive class."
% (pos_label, average),
UserWarning,
)

precision_recall_fscore_support uses it, and its default is pos_label=1 (the None option is allowed but not documented).

(I may have gotten something wrong, this stuff gets confusing)
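
For reference, a small illustration of the existing warning path that the quoted _check_set_wise_labels snippet takes via precision_score (current scikit-learn behaviour, not something changed by this PR):

import warnings
from sklearn.metrics import precision_score

y_true = [0, 1, 2, 2]
y_pred = [0, 2, 2, 1]

# With a non-binary average, a non-default pos_label is ignored and a
# UserWarning is emitted instead of an error.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    precision_score(y_true, y_pred, average="macro", pos_label=2)

print([str(w.message) for w in caught])
# ["Note that pos_label (set to 2) is ignored when average != 'binary' ..."]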

Member:

I realise you could technically just flip the labels by using the 'reorder' functionality of labels, though that is less clear and not consistent with the other places where we use pos_label for the binary case.
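
As a rough sketch of that existing workaround (current behaviour, no pos_label involved), reordering labels flips which class ends up in the "positive" slot:

from sklearn.metrics import confusion_matrix

y_true = [0, 0, 1, 1, 1]
y_pred = [0, 1, 1, 1, 0]

# Passing labels in reverse order puts class 0 in the last row/column, which is
# equivalent to treating it as the positive class.
confusion_matrix(y_true, y_pred, labels=[1, 0])
# array([[2, 1],
#        [1, 1]])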

Member Author:

> 'reorder' functionality of labels

Indeed, but once I found this out I was surprised, because the API is somewhat inconsistent. I find it a bit better to have consistency across metrics if possible.

> I think we could set default to 1 and change this to:

The problem is that this lets through the case where someone passes pos_label=1 explicitly on a target other than "binary". So if we skip this test, I would rather be lenient, as in the other metrics.

Member Author:

> do you mean you would not raise an error if pos_label is not the default value and it is not the binary case?

Yes, not raising any error.

Member:

Options:

  1. default is None; mention that it is equivalent to 1 in the binary case
  2. default is 1; completely ignore it for targets other than binary classification (do not raise an error/warning)
  3. default is 1; for targets other than binary classification, raise a warning if pos_label is not the default

Now I do not know what my preference would be. I think it is okay to raise a warning when pos_label is not the default but the target is not binary, although that misses the case where the user explicitly passes the default value. (1) has the problem of pos_label=None being ill-defined.

But maybe more opinions are needed here. I will try to summarise in #10010.

Member Author:

Thanks. I will try to ping more devs in #10010 so that we can settle on a solution.

Member:

I would be in favor of defaulting to None everywhere to make it easy to spot cases where users provide a non-default value in case we need to raise an informative error message.

For the specific case of confusion_matrix, I think we should not pass a pos_label, as the matrix is always defined irrespective of what we consider positive (it does not care). What matters is to pass labels=class_labels when calling confusion_matrix indirectly from a public function such as f1_score, which itself asks the user to specify a pos_label, so that there is no ambiguity about the meaning of the rows and columns of the CM.

To make this more efficient, I think type_of_target could be extended to also have an alternative type_of_target_and_class_labels to avoid computing xp.unique many times.
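
A rough sketch of what such a helper could look like; the name type_of_target_and_class_labels is only the suggestion above and not an existing scikit-learn function:

import numpy as np
from sklearn.utils.multiclass import type_of_target

def type_of_target_and_class_labels(y):
    # Hypothetical helper: return the target type together with the class
    # labels so callers do not have to compute the unique values again.
    # A real implementation would share the unique computation inside
    # type_of_target instead of recomputing it here.
    y = np.asarray(y)
    return type_of_target(y), np.unique(y)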

Member:

Also, when y is a dataframe or series we should use y.unique to make this more efficient (pandas does not need to sort the data to compute the unique values).
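
A minimal, hypothetical sketch of that suggestion:

import numpy as np
import pandas as pd

def _class_labels(y):
    # pandas computes unique values with a hash table and does not need to
    # sort the data, so prefer Series.unique when y is a pandas Series.
    if isinstance(y, pd.Series):
        return y.unique()
    return np.unique(np.asarray(y))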

raise ValueError(
"`pos_label` should only be set when the target is binary. Got "
f"{y_type} type of target instead."
)

if labels is None:
labels = unique_labels(y_true, y_pred)
else:
@@ -382,6 +405,11 @@ def confusion_matrix(
cm = cm / cm.sum()
cm = np.nan_to_num(cm)

if pos_label is not None and pos_label != labels[-1]:
# Reorder the confusion matrix such that TP is at index
# [1, 1].
cm = cm[::-1, ::-1]

if cm.shape == (1, 1):
warnings.warn(
(
@@ -680,7 +708,17 @@ class labels [2]_.
.. [3] `Wikipedia entry for the Cohen's kappa
<https://en.wikipedia.org/wiki/Cohen%27s_kappa>`_.
"""
confusion = confusion_matrix(y1, y2, labels=labels, sample_weight=sample_weight)
y_type, y1, y2 = _check_targets(y1, y2)
Member:

Since _check_targets calls unique, this adds more overhead to the computation.

REF: #26820

Member Author:

I don't know how to get around that. Having a private function to call and passing the unique labels could help, but I would not block this PR on it.

if y_type == "binary":
# we can set `pos_label` to any class label because the computation of Cohen's
# kappa is symmetric and invariant to a `pos_label` switch.
pos_label = y1[0]
else:
pos_label = None

confusion = confusion_matrix(
y1, y2, labels=labels, pos_label=pos_label, sample_weight=sample_weight
)
n_classes = confusion.shape[0]
sum0 = np.sum(confusion, axis=0)
sum1 = np.sum(confusion, axis=1)
@@ -966,12 +1004,21 @@ def matthews_corrcoef(y_true, y_pred, *, sample_weight=None):
if y_type not in {"binary", "multiclass"}:
raise ValueError("%s is not supported" % y_type)

if y_type == "binary":
# we can set `pos_label` to any class label because the computation of MCC
# is symmetric and invariant to `pos_label` switch.
pos_label = y_true[0]
else:
pos_label = None

lb = LabelEncoder()
lb.fit(np.hstack([y_true, y_pred]))
y_true = lb.transform(y_true)
y_pred = lb.transform(y_pred)

C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
C = confusion_matrix(
y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
)
t_sum = C.sum(axis=1, dtype=np.float64)
p_sum = C.sum(axis=0, dtype=np.float64)
n_correct = np.trace(C, dtype=np.float64)
@@ -1921,11 +1968,18 @@ class after being classified as negative. This is the case when the
f"problems, got targets of type: {y_type}"
)

if labels is None:
classes = np.unique(y_true)
Member Author:

Here, we should probably call _unique instead just to be sure.
I assume that we could introduce a pos_label as well.

pos_label = 1 if len(classes) < 2 else classes[1]
else:
pos_label = labels[-1]

cm = confusion_matrix(
y_true,
y_pred,
sample_weight=sample_weight,
labels=labels,
pos_label=pos_label,
)

# Case when `y_test` contains a single class and `y_test == y_pred`.
@@ -2396,7 +2450,17 @@ def balanced_accuracy_score(y_true, y_pred, *, sample_weight=None, adjusted=False):
>>> balanced_accuracy_score(y_true, y_pred)
0.625
"""
C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
if y_type == "binary":
# We can set `pos_label` to any value since we are computing per-class
# statistics and averaging them.
pos_label = y_true[0]
else:
pos_label = None

C = confusion_matrix(
y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
)
with np.errstate(divide="ignore", invalid="ignore"):
per_class = np.diag(C) / C.sum(axis=1)
if np.any(np.isnan(per_class)):
22 changes: 22 additions & 0 deletions sklearn/metrics/_plot/confusion_matrix.py
@@ -202,6 +202,7 @@ def from_estimator(
y,
*,
labels=None,
pos_label=None,
sample_weight=None,
normalize=None,
display_labels=None,
@@ -238,6 +239,15 @@ def from_estimator(
that appear at least once in `y_true` or `y_pred` are used in
sorted order.

pos_label : int, float, bool or str, default=None
The label of the positive class for binary classification.
When `pos_label=None`, if `y_true` is in `{-1, 1}` or `{0, 1}`,
`pos_label` is set to 1, otherwise an error will be raised.
An error is also raised if `pos_label` is set and `y_true` is not a binary
classification problem.

.. versionadded:: 1.4

sample_weight : array-like of shape (n_samples,), default=None
Sample weights.

@@ -323,6 +333,7 @@ def from_estimator(
y,
y_pred,
sample_weight=sample_weight,
pos_label=pos_label,
labels=labels,
normalize=normalize,
display_labels=display_labels,
@@ -343,6 +354,7 @@ def from_predictions(
y_pred,
*,
labels=None,
pos_label=None,
sample_weight=None,
normalize=None,
display_labels=None,
@@ -376,6 +388,15 @@ def from_predictions(
that appear at least once in `y_true` or `y_pred` are used in
sorted order.

pos_label : int, float, bool or str, default=None
The label of the positive class for binary classification.
When `pos_label=None`, if `y_true` is in `{-1, 1}` or `{0, 1}`,
`pos_label` is set to 1, otherwise an error will be raised.
An error is also raised if `pos_label` is set and `y_true` is not a binary
classification problem.

.. versionadded:: 1.4

sample_weight : array-like of shape (n_samples,), default=None
Sample weights.

@@ -465,6 +486,7 @@ def from_predictions(
y_pred,
sample_weight=sample_weight,
labels=labels,
pos_label=pos_label,
normalize=normalize,
)

27 changes: 20 additions & 7 deletions sklearn/metrics/tests/test_classification.py
@@ -10,7 +10,7 @@
from scipy.stats import bernoulli

from sklearn import datasets, svm
from sklearn.datasets import make_multilabel_classification
from sklearn.datasets import make_classification, make_multilabel_classification
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import (
accuracy_score,
@@ -457,13 +457,26 @@ def test_precision_recall_f_unused_pos_label():
)


def test_confusion_matrix_binary():
def test_confusion_matrix_pos_label_error():
_, y = make_classification(n_classes=3, n_clusters_per_class=1, random_state=0)
err_msg = "`pos_label` should only be set when the target is binary."
with pytest.raises(ValueError, match=err_msg):
confusion_matrix(y, y, pos_label=1)


@pytest.mark.parametrize("pos_label", [0, 1])
def test_confusion_matrix_binary(pos_label):
# Test confusion matrix - binary classification case
y_true, y_pred, _ = make_prediction(binary=True)

def test(y_true, y_pred):
cm = confusion_matrix(y_true, y_pred)
assert_array_equal(cm, [[22, 3], [8, 17]])
def test(y_true, y_pred, pos_label):
cm = confusion_matrix(y_true, y_pred, pos_label=pos_label)
expected_cm = np.array([[22, 3], [8, 17]])
if pos_label in {"0", 0}:
# we should flip the confusion matrix to respect the documentation
# of tp, fp, fn, tn
expected_cm = expected_cm[::-1, ::-1]
assert_array_equal(cm, expected_cm)

tp, fp, fn, tn = cm.flatten()
num = tp * tn - fp * fn
@@ -474,8 +487,8 @@ def test(y_true, y_pred):
assert_array_almost_equal(mcc, true_mcc, decimal=2)
assert_array_almost_equal(mcc, 0.57, decimal=2)

test(y_true, y_pred)
test([str(y) for y in y_true], [str(y) for y in y_pred])
test(y_true, y_pred, pos_label)
test([str(y) for y in y_true], [str(y) for y in y_pred], str(pos_label))


def test_multilabel_confusion_matrix_binary():
2 changes: 2 additions & 0 deletions sklearn/metrics/tests/test_common.py
@@ -324,6 +324,8 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):

# Metrics with a "pos_label" argument
METRICS_WITH_POS_LABEL = {
"unnormalized_confusion_matrix",
"normalized_confusion_matrix",
"roc_curve",
"precision_recall_curve",
"det_curve",