Skip to content

Commit 96b53ad

Browse files
OmarManzoor and ogrisel authored
ENH Array API support for f1_score and multilabel_confusion_matrix (#27369)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 0d56723 commit 96b53ad

File tree

9 files changed

+284
-120
lines changed

9 files changed

+284
-120
lines changed

doc/modules/array_api.rst

+3
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ Estimators
9494
- :class:`linear_model.Ridge` (with `solver="svd"`)
9595
- :class:`discriminant_analysis.LinearDiscriminantAnalysis` (with `solver="svd"`)
9696
- :class:`preprocessing.KernelCenterer`
97+
- :class:`preprocessing.LabelEncoder`
9798
- :class:`preprocessing.MaxAbsScaler`
9899
- :class:`preprocessing.MinMaxScaler`
99100
- :class:`preprocessing.Normalizer`
@@ -115,6 +116,7 @@ Metrics
115116
- :func:`sklearn.metrics.cluster.entropy`
116117
- :func:`sklearn.metrics.accuracy_score`
117118
- :func:`sklearn.metrics.d2_tweedie_score`
119+
- :func:`sklearn.metrics.f1_score`
118120
- :func:`sklearn.metrics.max_error`
119121
- :func:`sklearn.metrics.mean_absolute_error`
120122
- :func:`sklearn.metrics.mean_absolute_percentage_error`
@@ -123,6 +125,7 @@ Metrics
123125
- :func:`sklearn.metrics.mean_squared_error`
124126
- :func:`sklearn.metrics.mean_squared_log_error`
125127
- :func:`sklearn.metrics.mean_tweedie_deviance`
128+
- :func:`sklearn.metrics.multilabel_confusion_matrix`
126129
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel`
127130
- :func:`sklearn.metrics.pairwise.chi2_kernel`
128131
- :func:`sklearn.metrics.pairwise.cosine_similarity`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
- :func:`sklearn.metrics.f1_score` now supports Array API compatible
2+
inputs.
3+
By :user:`Omar Salman <OmarManzoor>`

sklearn/metrics/_classification.py

+102-64
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from numbers import Integral, Real
1616

1717
import numpy as np
18-
from scipy.sparse import coo_matrix, csr_matrix
18+
from scipy.sparse import coo_matrix, csr_matrix, issparse
1919
from scipy.special import xlogy
2020

2121
from ..exceptions import UndefinedMetricWarning
@@ -28,9 +28,15 @@
2828
)
2929
from ..utils._array_api import (
3030
_average,
31+
_bincount,
3132
_count_nonzero,
33+
_find_matching_floating_dtype,
3234
_is_numpy_namespace,
35+
_searchsorted,
36+
_setdiff1d,
37+
_tolist,
3338
_union1d,
39+
device,
3440
get_namespace,
3541
get_namespace_and_device,
3642
)
@@ -521,9 +527,11 @@ def multilabel_confusion_matrix(
521527
[1, 2]]])
522528
"""
523529
y_true, y_pred = attach_unique(y_true, y_pred)
530+
xp, _ = get_namespace(y_true, y_pred)
531+
device_ = device(y_true, y_pred)
524532
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
525533
if sample_weight is not None:
526-
sample_weight = column_or_1d(sample_weight)
534+
sample_weight = column_or_1d(sample_weight, device=device_)
527535
check_consistent_length(y_true, y_pred, sample_weight)
528536

529537
if y_type not in ("binary", "multiclass", "multilabel-indicator"):
@@ -534,9 +542,11 @@ def multilabel_confusion_matrix(
534542
labels = present_labels
535543
n_labels = None
536544
else:
537-
n_labels = len(labels)
538-
labels = np.hstack(
539-
[labels, np.setdiff1d(present_labels, labels, assume_unique=True)]
545+
labels = xp.asarray(labels, device=device_)
546+
n_labels = labels.shape[0]
547+
labels = xp.concat(
548+
[labels, _setdiff1d(present_labels, labels, assume_unique=True, xp=xp)],
549+
axis=-1,
540550
)
541551

542552
if y_true.ndim == 1:
@@ -556,77 +566,102 @@ def multilabel_confusion_matrix(
556566
tp = y_true == y_pred
557567
tp_bins = y_true[tp]
558568
if sample_weight is not None:
559-
tp_bins_weights = np.asarray(sample_weight)[tp]
569+
tp_bins_weights = sample_weight[tp]
560570
else:
561571
tp_bins_weights = None
562572

563-
if len(tp_bins):
564-
tp_sum = np.bincount(
565-
tp_bins, weights=tp_bins_weights, minlength=len(labels)
573+
if tp_bins.shape[0]:
574+
tp_sum = _bincount(
575+
tp_bins, weights=tp_bins_weights, minlength=labels.shape[0], xp=xp
566576
)
567577
else:
568578
# Pathological case
569-
true_sum = pred_sum = tp_sum = np.zeros(len(labels))
570-
if len(y_pred):
571-
pred_sum = np.bincount(y_pred, weights=sample_weight, minlength=len(labels))
572-
if len(y_true):
573-
true_sum = np.bincount(y_true, weights=sample_weight, minlength=len(labels))
579+
true_sum = pred_sum = tp_sum = xp.zeros(labels.shape[0])
580+
if y_pred.shape[0]:
581+
pred_sum = _bincount(
582+
y_pred, weights=sample_weight, minlength=labels.shape[0], xp=xp
583+
)
584+
if y_true.shape[0]:
585+
true_sum = _bincount(
586+
y_true, weights=sample_weight, minlength=labels.shape[0], xp=xp
587+
)
574588

575589
# Retain only selected labels
576-
indices = np.searchsorted(sorted_labels, labels[:n_labels])
577-
tp_sum = tp_sum[indices]
578-
true_sum = true_sum[indices]
579-
pred_sum = pred_sum[indices]
590+
indices = _searchsorted(sorted_labels, labels[:n_labels], xp=xp)
591+
tp_sum = xp.take(tp_sum, indices, axis=0)
592+
true_sum = xp.take(true_sum, indices, axis=0)
593+
pred_sum = xp.take(pred_sum, indices, axis=0)
580594

581595
else:
582596
sum_axis = 1 if samplewise else 0
583597

584598
# All labels are index integers for multilabel.
585599
# Select labels:
586-
if not np.array_equal(labels, present_labels):
587-
if np.max(labels) > np.max(present_labels):
600+
if labels.shape != present_labels.shape or xp.any(
601+
xp.not_equal(labels, present_labels)
602+
):
603+
if xp.max(labels) > xp.max(present_labels):
588604
raise ValueError(
589605
"All labels must be in [0, n labels) for "
590606
"multilabel targets. "
591-
"Got %d > %d" % (np.max(labels), np.max(present_labels))
607+
"Got %d > %d" % (xp.max(labels), xp.max(present_labels))
592608
)
593-
if np.min(labels) < 0:
609+
if xp.min(labels) < 0:
594610
raise ValueError(
595611
"All labels must be in [0, n labels) for "
596612
"multilabel targets. "
597-
"Got %d < 0" % np.min(labels)
613+
"Got %d < 0" % xp.min(labels)
598614
)
599615

600616
if n_labels is not None:
601617
y_true = y_true[:, labels[:n_labels]]
602618
y_pred = y_pred[:, labels[:n_labels]]
603619

620+
if issparse(y_true) or issparse(y_pred):
621+
true_and_pred = y_true.multiply(y_pred)
622+
else:
623+
true_and_pred = xp.multiply(y_true, y_pred)
624+
604625
# calculate weighted counts
605-
true_and_pred = y_true.multiply(y_pred)
606-
tp_sum = count_nonzero(
607-
true_and_pred, axis=sum_axis, sample_weight=sample_weight
626+
tp_sum = _count_nonzero(
627+
true_and_pred,
628+
axis=sum_axis,
629+
sample_weight=sample_weight,
630+
xp=xp,
631+
device=device_,
632+
)
633+
pred_sum = _count_nonzero(
634+
y_pred,
635+
axis=sum_axis,
636+
sample_weight=sample_weight,
637+
xp=xp,
638+
device=device_,
639+
)
640+
true_sum = _count_nonzero(
641+
y_true,
642+
axis=sum_axis,
643+
sample_weight=sample_weight,
644+
xp=xp,
645+
device=device_,
608646
)
609-
pred_sum = count_nonzero(y_pred, axis=sum_axis, sample_weight=sample_weight)
610-
true_sum = count_nonzero(y_true, axis=sum_axis, sample_weight=sample_weight)
611647

612648
fp = pred_sum - tp_sum
613649
fn = true_sum - tp_sum
614650
tp = tp_sum
615651

616652
if sample_weight is not None and samplewise:
617-
sample_weight = np.array(sample_weight)
618-
tp = np.array(tp)
619-
fp = np.array(fp)
620-
fn = np.array(fn)
653+
tp = xp.asarray(tp)
654+
fp = xp.asarray(fp)
655+
fn = xp.asarray(fn)
621656
tn = sample_weight * y_true.shape[1] - tp - fp - fn
622657
elif sample_weight is not None:
623-
tn = sum(sample_weight) - tp - fp - fn
658+
tn = xp.sum(sample_weight) - tp - fp - fn
624659
elif samplewise:
625660
tn = y_true.shape[1] - tp - fp - fn
626661
else:
627662
tn = y_true.shape[0] - tp - fp - fn
628663

629-
return np.array([tn, fp, fn, tp]).T.reshape(-1, 2, 2)
664+
return xp.reshape(xp.stack([tn, fp, fn, tp]).T, (-1, 2, 2))
630665

631666

632667
@validate_params(
@@ -1262,21 +1297,21 @@ def f1_score(
12621297
>>> y_true = [0, 1, 2, 0, 1, 2]
12631298
>>> y_pred = [0, 2, 1, 0, 0, 1]
12641299
>>> f1_score(y_true, y_pred, average='macro')
1265-
np.float64(0.26...)
1300+
0.26...
12661301
>>> f1_score(y_true, y_pred, average='micro')
1267-
np.float64(0.33...)
1302+
0.33...
12681303
>>> f1_score(y_true, y_pred, average='weighted')
1269-
np.float64(0.26...)
1304+
0.26...
12701305
>>> f1_score(y_true, y_pred, average=None)
12711306
array([0.8, 0. , 0. ])
12721307
12731308
>>> # binary classification
12741309
>>> y_true_empty = [0, 0, 0, 0, 0, 0]
12751310
>>> y_pred_empty = [0, 0, 0, 0, 0, 0]
12761311
>>> f1_score(y_true_empty, y_pred_empty)
1277-
np.float64(0.0...)
1312+
0.0...
12781313
>>> f1_score(y_true_empty, y_pred_empty, zero_division=1.0)
1279-
np.float64(1.0...)
1314+
1.0...
12801315
>>> f1_score(y_true_empty, y_pred_empty, zero_division=np.nan)
12811316
nan...
12821317
@@ -1466,17 +1501,17 @@ def fbeta_score(
14661501
>>> y_true = [0, 1, 2, 0, 1, 2]
14671502
>>> y_pred = [0, 2, 1, 0, 0, 1]
14681503
>>> fbeta_score(y_true, y_pred, average='macro', beta=0.5)
1469-
np.float64(0.23...)
1504+
0.23...
14701505
>>> fbeta_score(y_true, y_pred, average='micro', beta=0.5)
1471-
np.float64(0.33...)
1506+
0.33...
14721507
>>> fbeta_score(y_true, y_pred, average='weighted', beta=0.5)
1473-
np.float64(0.23...)
1508+
0.23...
14741509
>>> fbeta_score(y_true, y_pred, average=None, beta=0.5)
14751510
array([0.71..., 0. , 0. ])
14761511
>>> y_pred_empty = [0, 0, 0, 0, 0, 0]
14771512
>>> fbeta_score(y_true, y_pred_empty,
14781513
... average="macro", zero_division=np.nan, beta=0.5)
1479-
np.float64(0.12...)
1514+
0.12...
14801515
"""
14811516

14821517
_, _, f, _ = precision_recall_fscore_support(
@@ -1505,12 +1540,14 @@ def _prf_divide(
15051540
The metric, modifier and average arguments are used only for determining
15061541
an appropriate warning.
15071542
"""
1508-
mask = denominator == 0.0
1509-
denominator = denominator.copy()
1543+
xp, _ = get_namespace(numerator, denominator)
1544+
dtype_float = _find_matching_floating_dtype(numerator, denominator, xp=xp)
1545+
mask = denominator == 0
1546+
denominator = xp.asarray(denominator, copy=True, dtype=dtype_float)
15101547
denominator[mask] = 1 # avoid infs/nans
1511-
result = numerator / denominator
1548+
result = xp.asarray(numerator, dtype=dtype_float) / denominator
15121549

1513-
if not np.any(mask):
1550+
if not xp.any(mask):
15141551
return result
15151552

15161553
# set those with 0 denominator to `zero_division`, and 0 when "warn"
@@ -1559,7 +1596,7 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
15591596
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
15601597
# Convert to Python primitive type to avoid NumPy type / Python str
15611598
# comparison. See https://github.com/numpy/numpy/issues/6784
1562-
present_labels = unique_labels(y_true, y_pred).tolist()
1599+
present_labels = _tolist(unique_labels(y_true, y_pred))
15631600
if average == "binary":
15641601
if y_type == "binary":
15651602
if pos_label not in present_labels:
@@ -1774,11 +1811,11 @@ def precision_recall_fscore_support(
17741811
>>> y_true = np.array(['cat', 'dog', 'pig', 'cat', 'dog', 'pig'])
17751812
>>> y_pred = np.array(['cat', 'pig', 'dog', 'cat', 'cat', 'dog'])
17761813
>>> precision_recall_fscore_support(y_true, y_pred, average='macro')
1777-
(np.float64(0.22...), np.float64(0.33...), np.float64(0.26...), None)
1814+
(0.22..., 0.33..., 0.26..., None)
17781815
>>> precision_recall_fscore_support(y_true, y_pred, average='micro')
1779-
(np.float64(0.33...), np.float64(0.33...), np.float64(0.33...), None)
1816+
(0.33..., 0.33..., 0.33..., None)
17801817
>>> precision_recall_fscore_support(y_true, y_pred, average='weighted')
1781-
(np.float64(0.22...), np.float64(0.33...), np.float64(0.26...), None)
1818+
(0.22..., 0.33..., 0.26..., None)
17821819
17831820
It is possible to compute per-label precisions, recalls, F1-scores and
17841821
supports instead of averaging:
@@ -1805,10 +1842,11 @@ def precision_recall_fscore_support(
18051842
pred_sum = tp_sum + MCM[:, 0, 1]
18061843
true_sum = tp_sum + MCM[:, 1, 0]
18071844

1845+
xp, _ = get_namespace(y_true, y_pred)
18081846
if average == "micro":
1809-
tp_sum = np.array([tp_sum.sum()])
1810-
pred_sum = np.array([pred_sum.sum()])
1811-
true_sum = np.array([true_sum.sum()])
1847+
tp_sum = xp.reshape(xp.sum(tp_sum), (1,))
1848+
pred_sum = xp.reshape(xp.sum(pred_sum), (1,))
1849+
true_sum = xp.reshape(xp.sum(true_sum), (1,))
18121850

18131851
# Finally, we have all our sufficient statistics. Divide! #
18141852
beta2 = beta**2
@@ -1851,10 +1889,10 @@ def precision_recall_fscore_support(
18511889
weights = None
18521890

18531891
if average is not None:
1854-
assert average != "binary" or len(precision) == 1
1855-
precision = _nanaverage(precision, weights=weights)
1856-
recall = _nanaverage(recall, weights=weights)
1857-
f_score = _nanaverage(f_score, weights=weights)
1892+
assert average != "binary" or precision.shape[0] == 1
1893+
precision = float(_nanaverage(precision, weights=weights))
1894+
recall = float(_nanaverage(recall, weights=weights))
1895+
f_score = float(_nanaverage(f_score, weights=weights))
18581896
true_sum = None # return no support
18591897

18601898
return precision, recall, f_score, true_sum
@@ -2185,11 +2223,11 @@ def precision_score(
21852223
>>> y_true = [0, 1, 2, 0, 1, 2]
21862224
>>> y_pred = [0, 2, 1, 0, 0, 1]
21872225
>>> precision_score(y_true, y_pred, average='macro')
2188-
np.float64(0.22...)
2226+
0.22...
21892227
>>> precision_score(y_true, y_pred, average='micro')
2190-
np.float64(0.33...)
2228+
0.33...
21912229
>>> precision_score(y_true, y_pred, average='weighted')
2192-
np.float64(0.22...)
2230+
0.22...
21932231
>>> precision_score(y_true, y_pred, average=None)
21942232
array([0.66..., 0. , 0. ])
21952233
>>> y_pred = [0, 0, 0, 0, 0, 0]
@@ -2367,11 +2405,11 @@ def recall_score(
23672405
>>> y_true = [0, 1, 2, 0, 1, 2]
23682406
>>> y_pred = [0, 2, 1, 0, 0, 1]
23692407
>>> recall_score(y_true, y_pred, average='macro')
2370-
np.float64(0.33...)
2408+
0.33...
23712409
>>> recall_score(y_true, y_pred, average='micro')
2372-
np.float64(0.33...)
2410+
0.33...
23732411
>>> recall_score(y_true, y_pred, average='weighted')
2374-
np.float64(0.33...)
2412+
0.33...
23752413
>>> recall_score(y_true, y_pred, average=None)
23762414
array([1., 0., 0.])
23772415
>>> y_true = [0, 0, 0, 0, 0, 0]

0 commit comments

Comments
 (0)