PERF speed up confusion matrix calculation #26820

Status: Closed (wants to merge 1 commit)
93 changes: 77 additions & 16 deletions sklearn/metrics/_classification.py
@@ -38,6 +38,7 @@
     check_consistent_length,
     column_or_1d,
 )
+from ..utils._array_api import get_namespace
 from ..utils._param_validation import Interval, Options, StrOptions, validate_params
 from ..utils.extmath import _nanaverage
 from ..utils.multiclass import type_of_target, unique_labels
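For context, `get_namespace` is scikit-learn's private array-API helper: it resolves a namespace object for its inputs so the same code path can call, e.g., `xp.unique_values` on NumPy or array-API arrays. A minimal sketch of how it behaves (assuming the private `sklearn.utils._array_api` module as of the 1.3/1.4 development line; this is not part of the diff):

    import numpy as np
    from sklearn.utils._array_api import get_namespace  # private API, may change

    x = np.array([3, 1, 2, 3])
    xp, is_array_api_compliant = get_namespace(x)  # NumPy-backed wrapper by default
    print(xp.unique_values(x))  # [1 2 3], the namespace's spelling of np.unique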
@@ -54,7 +55,7 @@ def _check_zero_division(zero_division):
         return np.nan
 
 
-def _check_targets(y_true, y_pred):
+def _check_targets(y_true, y_pred, unique_y_true=None, unique_y_pred=None):
     """Check that y_true and y_pred belong to the same classification task.
 
     This converts multiclass or binary types to a common shape, and raises a
@@ -71,6 +72,12 @@ def _check_targets(y_true, y_pred):
 
     y_pred : array-like
 
+    unique_y_true : array-like, default=None
+        Inferred from y_true if None.
+
+    unique_y_pred : array-like, default=None
+        Inferred from y_pred if None.
+
     Returns
     -------
     type_true : one of {'multilabel-indicator', 'multiclass', 'binary'}
@@ -82,8 +89,13 @@ def _check_targets(y_true, y_pred):
     y_pred : array or indicator matrix
     """
     check_consistent_length(y_true, y_pred)
-    type_true = type_of_target(y_true, input_name="y_true")
-    type_pred = type_of_target(y_pred, input_name="y_pred")
+    xp, _ = get_namespace(y_true, y_pred)
+    unique_y_true = xp.unique_values(y_true) if unique_y_true is None else unique_y_true
+    unique_y_pred = xp.unique_values(y_pred) if unique_y_pred is None else unique_y_pred
+    n_classes = len(unique_labels(unique_y_true, unique_y_pred))
+
+    type_true = type_of_target(y_true, input_name="y_true", n_classes=n_classes)
+    type_pred = type_of_target(y_pred, input_name="y_pred", n_classes=n_classes)
 
     y_type = {type_true, type_pred}
     if y_type == {"binary", "multiclass"}:
@@ -108,16 +120,16 @@ def _check_targets(y_true, y_pred):
         y_pred = column_or_1d(y_pred)
         if y_type == "binary":
             try:
-                unique_values = np.union1d(y_true, y_pred)
+                unique_values = np.union1d(unique_y_true, unique_y_pred)
             except TypeError as e:
                 # We expect y_true and y_pred to be of the same data type.
                 # If `y_true` was provided to the classifier as strings,
                 # `y_pred` given by the classifier will also be encoded with
                 # strings. So we raise a meaningful error
                 raise TypeError(
                     "Labels in y_true and y_pred should be of the same type. "
-                    f"Got y_true={np.unique(y_true)} and "
-                    f"y_pred={np.unique(y_pred)}. Make sure that the "
+                    f"Got y_true={unique_y_true} and "
+                    f"y_pred={xp.unique_values(y_pred)}. Make sure that the "
                     "predictions provided by the classifier coincides with "
                     "the true labels."
                 ) from e
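To make the motivation concrete: `np.unique` sorts its input, so every redundant call costs O(n log n) on the full target array. A rough timing sketch (not part of the diff; the "three helpers" count is illustrative and the numbers are machine-dependent):

    import numpy as np
    from timeit import timeit

    rng = np.random.default_rng(0)
    y_true = rng.integers(0, 10, size=2_000_000)
    y_pred = rng.integers(0, 10, size=2_000_000)

    def repeated():
        # Before: each helper on the call path re-derives the uniques.
        for _ in range(3):
            np.unique(y_true)
            np.unique(y_pred)

    def once():
        # After: the uniques are computed a single time and passed down.
        np.unique(y_true)
        np.unique(y_pred)

    print("repeated:", timeit(repeated, number=5))
    print("once:    ", timeit(once, number=5))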
@@ -403,7 +415,14 @@ def confusion_matrix(
     prefer_skip_nested_validation=True,
 )
 def multilabel_confusion_matrix(
-    y_true, y_pred, *, sample_weight=None, labels=None, samplewise=False
+    y_true,
+    y_pred,
+    *,
+    sample_weight=None,
+    labels=None,
+    samplewise=False,
+    unique_y_true=None,
+    unique_y_pred=None,
Review comment on lines +424 to +425

thomasjpfan (Member): labels, unique_y_true and unique_y_pred make the public API look really bloated. The only alternative I see is to have a private function.

glemaitre (Member), Nov 3, 2023: I agree with @thomasjpfan. I see only 2 solutions:

  • Make some private function that the public one is a wrapper around.
  • Implement our own efficient unique for NumPy arrays, but it could be rather complex.

(A third solution is to contribute upstream :))
 ):
     """Compute a confusion matrix for each class or sample.

@@ -444,6 +463,12 @@ def multilabel_confusion_matrix(
     samplewise : bool, default=False
         In the multilabel case, this calculates a confusion matrix per sample.
 
+    unique_y_true : array-like, default=None
+        The unique values in y_true. If None, it will be inferred from y_true.
+
+    unique_y_pred : array-like, default=None
+        The unique values in y_pred. If None, it will be inferred from y_pred.
+
     Returns
     -------
     multi_confusion : ndarray of shape (n_outputs, 2, 2)
@@ -502,15 +527,21 @@ def multilabel_confusion_matrix(
            [[2, 1],
             [1, 2]]])
     """
-    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    xp, _ = get_namespace(y_true, y_pred)
+    unique_y_true = xp.unique_values(y_true) if unique_y_true is None else unique_y_true
+    unique_y_pred = xp.unique_values(y_pred) if unique_y_pred is None else unique_y_pred
+
+    y_type, y_true, y_pred = _check_targets(
+        y_true, y_pred, unique_y_true, unique_y_pred
+    )
     if sample_weight is not None:
         sample_weight = column_or_1d(sample_weight)
         check_consistent_length(y_true, y_pred, sample_weight)
 
     if y_type not in ("binary", "multiclass", "multilabel-indicator"):
         raise ValueError("%s is not supported" % y_type)
 
-    present_labels = unique_labels(y_true, y_pred)
+    present_labels = unique_labels(unique_y_true, unique_y_pred)
     if labels is None:
         labels = present_labels
         n_labels = None
@@ -1487,7 +1518,9 @@ def _warn_prf(average, modifier, msg_start, result_size):
     warnings.warn(msg, UndefinedMetricWarning, stacklevel=2)
 
 
-def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
+def _check_set_wise_labels(
+    y_true, y_pred, average, labels, pos_label, unique_y_true, unique_y_pred
+):
     """Validation associated with set-wise metrics.
 
     Returns identified labels.
@@ -1496,10 +1529,12 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
     if average not in average_options and average != "binary":
         raise ValueError("average has to be one of " + str(average_options))
 
-    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    y_type, y_true, y_pred = _check_targets(
+        y_true, y_pred, unique_y_true, unique_y_pred
+    )
     # Convert to Python primitive type to avoid NumPy type / Python str
     # comparison. See https://github.com/numpy/numpy/issues/6784
-    present_labels = unique_labels(y_true, y_pred).tolist()
+    present_labels = unique_labels(unique_y_true, unique_y_pred).tolist()
     if average == "binary":
         if y_type == "binary":
             if pos_label not in present_labels:
@@ -1559,6 +1594,8 @@ def precision_recall_fscore_support(
     warn_for=("precision", "recall", "f-score"),
     sample_weight=None,
     zero_division="warn",
+    unique_y_true=None,
+    unique_y_pred=None,
 ):
     """Compute precision, recall, F-measure and support for each class.
 
@@ -1656,6 +1693,14 @@ def precision_recall_fscore_support(
         .. versionadded:: 1.3
            `np.nan` option was added.
 
+    unique_y_true : array-like, default=None
+        The unique values in ``y_true``. If ``None``, the unique values are
+        determined from the input.
+
+    unique_y_pred : array-like, default=None
+        The unique values in ``y_pred``. If ``None``, the unique values are
+        determined from the input.
+
     Returns
     -------
     precision : float (if average is not None) or array of float, shape =\
@@ -1718,7 +1763,13 @@ def precision_recall_fscore_support(
      array([2, 2, 2]))
     """
     zero_division_value = _check_zero_division(zero_division)
-    labels = _check_set_wise_labels(y_true, y_pred, average, labels, pos_label)
+    xp, _ = get_namespace(y_true, y_pred)
+    unique_y_true = xp.unique_values(y_true) if unique_y_true is None else unique_y_true
+    unique_y_pred = xp.unique_values(y_pred) if unique_y_pred is None else unique_y_pred
+
+    labels = _check_set_wise_labels(
+        y_true, y_pred, average, labels, pos_label, unique_y_true, unique_y_pred
+    )
 
     # Calculate tp_sum, pred_sum, true_sum ###
     samplewise = average == "samples"
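Assuming this branch, a caller that already knows the label sets can pass them in and skip the internal unique scans; a usage sketch with the signature proposed above:

    import numpy as np
    # Requires this PR's branch: unique_y_true/unique_y_pred are not in a release.
    from sklearn.metrics import precision_recall_fscore_support

    y_true = np.array([0, 1, 2, 2, 1, 0])
    y_pred = np.array([0, 2, 2, 1, 1, 0])

    # Derive the unique labels a single time...
    uy_true, uy_pred = np.unique(y_true), np.unique(y_pred)

    # ...and reuse them so the metric skips its own O(n log n) scans.
    p, r, f1, s = precision_recall_fscore_support(
        y_true, y_pred, average=None,
        unique_y_true=uy_true, unique_y_pred=uy_pred,
    )
    print(p, r, f1, s)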
@@ -2535,19 +2586,25 @@ class 2 1.00 0.67 0.80 3
     weighted avg       1.00      0.67      0.80         3
     <BLANKLINE>
     """
-    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    xp, _ = get_namespace(y_true, y_pred)
+    unique_y_true = xp.unique_values(y_true)
+    unique_y_pred = xp.unique_values(y_pred)
+    unique_all = unique_labels(unique_y_true, unique_y_pred)
+
+    y_type, y_true, y_pred = _check_targets(
+        y_true, y_pred, unique_y_true=unique_y_true, unique_y_pred=unique_y_pred
+    )
 
     if labels is None:
-        labels = unique_labels(y_true, y_pred)
+        labels = unique_all
         labels_given = False
     else:
         labels = np.asarray(labels)
         labels_given = True
 
     # labelled micro average
     micro_is_accuracy = (y_type == "multiclass" or y_type == "binary") and (
-        not labels_given or (set(labels) == set(unique_labels(y_true, y_pred)))
+        not labels_given or (set(labels) == set(unique_all))
     )
 
     if target_names is not None and len(labels) != len(target_names):
@@ -2575,6 +2632,8 @@ class 2 1.00 0.67 0.80 3
         average=None,
         sample_weight=sample_weight,
         zero_division=zero_division,
+        unique_y_true=unique_y_true,
+        unique_y_pred=unique_y_pred,
     )
     rows = zip(target_names, p, r, f1, s)
 
@@ -2614,6 +2673,8 @@ class 2 1.00 0.67 0.80 3
             average=average,
             sample_weight=sample_weight,
             zero_division=zero_division,
+            unique_y_true=unique_y_true,
+            unique_y_pred=unique_y_pred,
         )
         avg = [avg_p, avg_r, avg_f1, np.sum(s)]
 
10 changes: 8 additions & 2 deletions sklearn/utils/multiclass.py
@@ -219,7 +219,7 @@ def check_classification_targets(y):
         )
 
 
-def type_of_target(y, input_name=""):
+def type_of_target(y, input_name="", n_classes=None):
     """Determine the type of data indicated by the target.
 
     Note that this type is the most specific type that can be inferred.
@@ -242,6 +242,11 @@ def type_of_target(y, input_name=""):
 
         .. versionadded:: 1.1.0
 
+    n_classes : int, default=None
+        Number of classes. Will be inferred from the data if not provided.
+
+        .. versionadded:: 1.4
+
     Returns
     -------
     target_type : str
@@ -383,8 +388,9 @@ def type_of_target(y, input_name=""):
         return "continuous" + suffix
 
     # Check multiclass
+    n_classes = n_classes or xp.unique_values(y).shape[0]
     first_row = y[0] if not issparse(y) else y.getrow(0).data
-    if xp.unique_values(y).shape[0] > 2 or (y.ndim == 2 and len(first_row) > 1):
+    if n_classes > 2 or (y.ndim == 2 and len(first_row) > 1):
         # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
         return "multiclass" + suffix
     else:
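Why the n_classes hint pays off: without it, the multiclass check must scan and sort the whole target to count classes; with it, the check reduces to an integer comparison. A micro-benchmark sketch (not from the PR; numbers are machine-dependent):

    import numpy as np
    from timeit import timeit

    y = np.random.default_rng(0).integers(0, 5, size=5_000_000)

    # Without the hint: derive the class count from the data.
    t_scan = timeit(lambda: np.unique(y).shape[0] > 2, number=5)

    # With the hint: a constant-time comparison.
    n_classes = 5
    t_hint = timeit(lambda: n_classes > 2, number=5)

    print(f"unique scan: {t_scan:.3f}s  hinted check: {t_hint:.6f}s")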