-
-
Notifications
You must be signed in to change notification settings - Fork 26k
PERF speed up confusion matrix calculation #26820
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,6 +38,7 @@ | |
check_consistent_length, | ||
column_or_1d, | ||
) | ||
from ..utils._array_api import get_namespace | ||
from ..utils._param_validation import Interval, Options, StrOptions, validate_params | ||
from ..utils.extmath import _nanaverage | ||
from ..utils.multiclass import type_of_target, unique_labels | ||
|
@@ -54,7 +55,7 @@ def _check_zero_division(zero_division): | |
return np.nan | ||
|
||
|
||
def _check_targets(y_true, y_pred): | ||
def _check_targets(y_true, y_pred, unique_y_true=None, unique_y_pred=None): | ||
"""Check that y_true and y_pred belong to the same classification task. | ||
|
||
This converts multiclass or binary types to a common shape, and raises a | ||
|
@@ -71,6 +72,12 @@ def _check_targets(y_true, y_pred): | |
|
||
y_pred : array-like | ||
|
||
unique_y_true : array-like, default=None | ||
Inferred from y_true if None. | ||
|
||
unique_y_pred : array-like, default=None | ||
Inferred from y_pred if None. | ||
|
||
Returns | ||
------- | ||
type_true : one of {'multilabel-indicator', 'multiclass', 'binary'} | ||
|
@@ -82,8 +89,13 @@ def _check_targets(y_true, y_pred): | |
y_pred : array or indicator matrix | ||
""" | ||
check_consistent_length(y_true, y_pred) | ||
type_true = type_of_target(y_true, input_name="y_true") | ||
type_pred = type_of_target(y_pred, input_name="y_pred") | ||
xp, _ = get_namespace(y_true, y_pred) | ||
unique_y_true = xp.unique_values(y_true) if unique_y_true is None else unique_y_true | ||
unique_y_pred = xp.unique_values(y_pred) if unique_y_pred is None else unique_y_pred | ||
n_classes = len(unique_labels(unique_y_true, unique_y_pred)) | ||
|
||
type_true = type_of_target(y_true, input_name="y_true", n_classes=n_classes) | ||
type_pred = type_of_target(y_pred, input_name="y_pred", n_classes=n_classes) | ||
|
||
y_type = {type_true, type_pred} | ||
if y_type == {"binary", "multiclass"}: | ||
|
@@ -108,16 +120,16 @@ def _check_targets(y_true, y_pred): | |
y_pred = column_or_1d(y_pred) | ||
if y_type == "binary": | ||
try: | ||
unique_values = np.union1d(y_true, y_pred) | ||
unique_values = np.union1d(unique_y_true, unique_y_pred) | ||
except TypeError as e: | ||
# We expect y_true and y_pred to be of the same data type. | ||
# If `y_true` was provided to the classifier as strings, | ||
# `y_pred` given by the classifier will also be encoded with | ||
# strings. So we raise a meaningful error | ||
raise TypeError( | ||
"Labels in y_true and y_pred should be of the same type. " | ||
f"Got y_true={np.unique(y_true)} and " | ||
f"y_pred={np.unique(y_pred)}. Make sure that the " | ||
f"Got y_true={unique_y_true} and " | ||
f"y_pred={xp.unique_values(y_pred)}. Make sure that the " | ||
"predictions provided by the classifier coincides with " | ||
"the true labels." | ||
) from e | ||
|
@@ -403,7 +415,14 @@ def confusion_matrix( | |
prefer_skip_nested_validation=True, | ||
) | ||
def multilabel_confusion_matrix( | ||
y_true, y_pred, *, sample_weight=None, labels=None, samplewise=False | ||
y_true, | ||
y_pred, | ||
*, | ||
sample_weight=None, | ||
labels=None, | ||
samplewise=False, | ||
unique_y_true=None, | ||
unique_y_pred=None, | ||
Comment on lines
+424
to
+425
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with @thomasjpfan. I see only 2 solutions:
(third solution is to contribute upstream :)) |
||
): | ||
"""Compute a confusion matrix for each class or sample. | ||
|
||
|
@@ -444,6 +463,12 @@ def multilabel_confusion_matrix( | |
samplewise : bool, default=False | ||
In the multilabel case, this calculates a confusion matrix per sample. | ||
|
||
unique_y_true : array-like, default=None | ||
The unique values in y_true. If None, it will be inferred from y_true. | ||
adrinjalali marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
unique_y_pred : array-like, default=None | ||
The unique values in y_pred. If None, it will be inferred from y_pred. | ||
|
||
Returns | ||
------- | ||
multi_confusion : ndarray of shape (n_outputs, 2, 2) | ||
|
@@ -502,15 +527,21 @@ def multilabel_confusion_matrix( | |
[[2, 1], | ||
[1, 2]]]) | ||
""" | ||
y_type, y_true, y_pred = _check_targets(y_true, y_pred) | ||
xp, _ = get_namespace(y_true, y_pred) | ||
unique_y_true = xp.unique_values(y_true) if unique_y_true is None else unique_y_true | ||
unique_y_pred = xp.unique_values(y_pred) if unique_y_pred is None else unique_y_pred | ||
|
||
y_type, y_true, y_pred = _check_targets( | ||
y_true, y_pred, unique_y_true, unique_y_pred | ||
) | ||
if sample_weight is not None: | ||
sample_weight = column_or_1d(sample_weight) | ||
check_consistent_length(y_true, y_pred, sample_weight) | ||
|
||
if y_type not in ("binary", "multiclass", "multilabel-indicator"): | ||
raise ValueError("%s is not supported" % y_type) | ||
|
||
present_labels = unique_labels(y_true, y_pred) | ||
present_labels = unique_labels(unique_y_true, unique_y_pred) | ||
if labels is None: | ||
labels = present_labels | ||
n_labels = None | ||
|
@@ -1487,7 +1518,9 @@ def _warn_prf(average, modifier, msg_start, result_size): | |
warnings.warn(msg, UndefinedMetricWarning, stacklevel=2) | ||
|
||
|
||
def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label): | ||
def _check_set_wise_labels( | ||
y_true, y_pred, average, labels, pos_label, unique_y_true, unique_y_pred | ||
): | ||
"""Validation associated with set-wise metrics. | ||
|
||
Returns identified labels. | ||
|
@@ -1496,10 +1529,12 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label): | |
if average not in average_options and average != "binary": | ||
raise ValueError("average has to be one of " + str(average_options)) | ||
|
||
y_type, y_true, y_pred = _check_targets(y_true, y_pred) | ||
y_type, y_true, y_pred = _check_targets( | ||
y_true, y_pred, unique_y_true, unique_y_pred | ||
) | ||
# Convert to Python primitive type to avoid NumPy type / Python str | ||
# comparison. See https://github.com/numpy/numpy/issues/6784 | ||
present_labels = unique_labels(y_true, y_pred).tolist() | ||
present_labels = unique_labels(unique_y_true, unique_y_pred).tolist() | ||
if average == "binary": | ||
if y_type == "binary": | ||
if pos_label not in present_labels: | ||
|
@@ -1559,6 +1594,8 @@ def precision_recall_fscore_support( | |
warn_for=("precision", "recall", "f-score"), | ||
sample_weight=None, | ||
zero_division="warn", | ||
unique_y_true=None, | ||
unique_y_pred=None, | ||
): | ||
"""Compute precision, recall, F-measure and support for each class. | ||
|
||
|
@@ -1656,6 +1693,14 @@ def precision_recall_fscore_support( | |
.. versionadded:: 1.3 | ||
`np.nan` option was added. | ||
|
||
unique_y_true : array-like, default=None | ||
The unique values in ``y_true``. If ``None``, the unique values are | ||
determined from the input. | ||
|
||
unique_y_pred : array-like, default=None | ||
The unique values in ``y_pred``. If ``None``, the unique values are | ||
determined from the input. | ||
|
||
Returns | ||
------- | ||
precision : float (if average is not None) or array of float, shape =\ | ||
|
@@ -1718,7 +1763,13 @@ def precision_recall_fscore_support( | |
array([2, 2, 2])) | ||
""" | ||
zero_division_value = _check_zero_division(zero_division) | ||
labels = _check_set_wise_labels(y_true, y_pred, average, labels, pos_label) | ||
xp, _ = get_namespace(y_true, y_pred) | ||
unique_y_true = xp.unique_values(y_true) if unique_y_true is None else unique_y_true | ||
unique_y_pred = xp.unique_values(y_pred) if unique_y_pred is None else unique_y_pred | ||
|
||
labels = _check_set_wise_labels( | ||
y_true, y_pred, average, labels, pos_label, unique_y_true, unique_y_pred | ||
) | ||
|
||
# Calculate tp_sum, pred_sum, true_sum ### | ||
samplewise = average == "samples" | ||
|
@@ -2535,19 +2586,25 @@ class 2 1.00 0.67 0.80 3 | |
weighted avg 1.00 0.67 0.80 3 | ||
<BLANKLINE> | ||
""" | ||
xp, _ = get_namespace(y_true, y_pred) | ||
unique_y_true = xp.unique_values(y_true) | ||
unique_y_pred = xp.unique_values(y_pred) | ||
unique_all = unique_labels(unique_y_true, unique_y_pred) | ||
|
||
y_type, y_true, y_pred = _check_targets(y_true, y_pred) | ||
y_type, y_true, y_pred = _check_targets( | ||
y_true, y_pred, unique_y_true=unique_y_true, unique_y_pred=unique_y_pred | ||
) | ||
|
||
if labels is None: | ||
labels = unique_labels(y_true, y_pred) | ||
labels = unique_all | ||
labels_given = False | ||
else: | ||
labels = np.asarray(labels) | ||
labels_given = True | ||
|
||
# labelled micro average | ||
micro_is_accuracy = (y_type == "multiclass" or y_type == "binary") and ( | ||
not labels_given or (set(labels) == set(unique_labels(y_true, y_pred))) | ||
not labels_given or (set(labels) == set(unique_all)) | ||
) | ||
|
||
if target_names is not None and len(labels) != len(target_names): | ||
|
@@ -2575,6 +2632,8 @@ class 2 1.00 0.67 0.80 3 | |
average=None, | ||
sample_weight=sample_weight, | ||
zero_division=zero_division, | ||
unique_y_true=unique_y_true, | ||
unique_y_pred=unique_y_pred, | ||
) | ||
rows = zip(target_names, p, r, f1, s) | ||
|
||
|
@@ -2614,6 +2673,8 @@ class 2 1.00 0.67 0.80 3 | |
average=average, | ||
sample_weight=sample_weight, | ||
zero_division=zero_division, | ||
unique_y_true=unique_y_true, | ||
unique_y_pred=unique_y_pred, | ||
) | ||
avg = [avg_p, avg_r, avg_f1, np.sum(s)] | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.