15
15
from numbers import Integral , Real
16
16
17
17
import numpy as np
18
- from scipy .sparse import coo_matrix , csr_matrix
18
+ from scipy .sparse import coo_matrix , csr_matrix , issparse
19
19
from scipy .special import xlogy
20
20
21
21
from ..exceptions import UndefinedMetricWarning
28
28
)
29
29
from ..utils ._array_api import (
30
30
_average ,
31
+ _bincount ,
31
32
_count_nonzero ,
33
+ _find_matching_floating_dtype ,
32
34
_is_numpy_namespace ,
35
+ _searchsorted ,
36
+ _setdiff1d ,
37
+ _tolist ,
33
38
_union1d ,
39
+ device ,
34
40
get_namespace ,
35
41
get_namespace_and_device ,
36
42
)
@@ -521,9 +527,11 @@ def multilabel_confusion_matrix(
521
527
[1, 2]]])
522
528
"""
523
529
y_true , y_pred = attach_unique (y_true , y_pred )
530
+ xp , _ = get_namespace (y_true , y_pred )
531
+ device_ = device (y_true , y_pred )
524
532
y_type , y_true , y_pred = _check_targets (y_true , y_pred )
525
533
if sample_weight is not None :
526
- sample_weight = column_or_1d (sample_weight )
534
+ sample_weight = column_or_1d (sample_weight , device = device_ )
527
535
check_consistent_length (y_true , y_pred , sample_weight )
528
536
529
537
if y_type not in ("binary" , "multiclass" , "multilabel-indicator" ):
@@ -534,9 +542,11 @@ def multilabel_confusion_matrix(
534
542
labels = present_labels
535
543
n_labels = None
536
544
else :
537
- n_labels = len (labels )
538
- labels = np .hstack (
539
- [labels , np .setdiff1d (present_labels , labels , assume_unique = True )]
545
+ labels = xp .asarray (labels , device = device_ )
546
+ n_labels = labels .shape [0 ]
547
+ labels = xp .concat (
548
+ [labels , _setdiff1d (present_labels , labels , assume_unique = True , xp = xp )],
549
+ axis = - 1 ,
540
550
)
541
551
542
552
if y_true .ndim == 1 :
@@ -556,77 +566,102 @@ def multilabel_confusion_matrix(
556
566
tp = y_true == y_pred
557
567
tp_bins = y_true [tp ]
558
568
if sample_weight is not None :
559
- tp_bins_weights = np . asarray ( sample_weight ) [tp ]
569
+ tp_bins_weights = sample_weight [tp ]
560
570
else :
561
571
tp_bins_weights = None
562
572
563
- if len ( tp_bins ) :
564
- tp_sum = np . bincount (
565
- tp_bins , weights = tp_bins_weights , minlength = len ( labels )
573
+ if tp_bins . shape [ 0 ] :
574
+ tp_sum = _bincount (
575
+ tp_bins , weights = tp_bins_weights , minlength = labels . shape [ 0 ], xp = xp
566
576
)
567
577
else :
568
578
# Pathological case
569
- true_sum = pred_sum = tp_sum = np .zeros (len (labels ))
570
- if len (y_pred ):
571
- pred_sum = np .bincount (y_pred , weights = sample_weight , minlength = len (labels ))
572
- if len (y_true ):
573
- true_sum = np .bincount (y_true , weights = sample_weight , minlength = len (labels ))
579
+ true_sum = pred_sum = tp_sum = xp .zeros (labels .shape [0 ])
580
+ if y_pred .shape [0 ]:
581
+ pred_sum = _bincount (
582
+ y_pred , weights = sample_weight , minlength = labels .shape [0 ], xp = xp
583
+ )
584
+ if y_true .shape [0 ]:
585
+ true_sum = _bincount (
586
+ y_true , weights = sample_weight , minlength = labels .shape [0 ], xp = xp
587
+ )
574
588
575
589
# Retain only selected labels
576
- indices = np . searchsorted (sorted_labels , labels [:n_labels ])
577
- tp_sum = tp_sum [ indices ]
578
- true_sum = true_sum [ indices ]
579
- pred_sum = pred_sum [ indices ]
590
+ indices = _searchsorted (sorted_labels , labels [:n_labels ], xp = xp )
591
+ tp_sum = xp . take ( tp_sum , indices , axis = 0 )
592
+ true_sum = xp . take ( true_sum , indices , axis = 0 )
593
+ pred_sum = xp . take ( pred_sum , indices , axis = 0 )
580
594
581
595
else :
582
596
sum_axis = 1 if samplewise else 0
583
597
584
598
# All labels are index integers for multilabel.
585
599
# Select labels:
586
- if not np .array_equal (labels , present_labels ):
587
- if np .max (labels ) > np .max (present_labels ):
600
+ if labels .shape != present_labels .shape or xp .any (
601
+ xp .not_equal (labels , present_labels )
602
+ ):
603
+ if xp .max (labels ) > xp .max (present_labels ):
588
604
raise ValueError (
589
605
"All labels must be in [0, n labels) for "
590
606
"multilabel targets. "
591
- "Got %d > %d" % (np .max (labels ), np .max (present_labels ))
607
+ "Got %d > %d" % (xp .max (labels ), xp .max (present_labels ))
592
608
)
593
- if np .min (labels ) < 0 :
609
+ if xp .min (labels ) < 0 :
594
610
raise ValueError (
595
611
"All labels must be in [0, n labels) for "
596
612
"multilabel targets. "
597
- "Got %d < 0" % np .min (labels )
613
+ "Got %d < 0" % xp .min (labels )
598
614
)
599
615
600
616
if n_labels is not None :
601
617
y_true = y_true [:, labels [:n_labels ]]
602
618
y_pred = y_pred [:, labels [:n_labels ]]
603
619
620
+ if issparse (y_true ) or issparse (y_pred ):
621
+ true_and_pred = y_true .multiply (y_pred )
622
+ else :
623
+ true_and_pred = xp .multiply (y_true , y_pred )
624
+
604
625
# calculate weighted counts
605
- true_and_pred = y_true .multiply (y_pred )
606
- tp_sum = count_nonzero (
607
- true_and_pred , axis = sum_axis , sample_weight = sample_weight
626
+ tp_sum = _count_nonzero (
627
+ true_and_pred ,
628
+ axis = sum_axis ,
629
+ sample_weight = sample_weight ,
630
+ xp = xp ,
631
+ device = device_ ,
632
+ )
633
+ pred_sum = _count_nonzero (
634
+ y_pred ,
635
+ axis = sum_axis ,
636
+ sample_weight = sample_weight ,
637
+ xp = xp ,
638
+ device = device_ ,
639
+ )
640
+ true_sum = _count_nonzero (
641
+ y_true ,
642
+ axis = sum_axis ,
643
+ sample_weight = sample_weight ,
644
+ xp = xp ,
645
+ device = device_ ,
608
646
)
609
- pred_sum = count_nonzero (y_pred , axis = sum_axis , sample_weight = sample_weight )
610
- true_sum = count_nonzero (y_true , axis = sum_axis , sample_weight = sample_weight )
611
647
612
648
fp = pred_sum - tp_sum
613
649
fn = true_sum - tp_sum
614
650
tp = tp_sum
615
651
616
652
if sample_weight is not None and samplewise :
617
- sample_weight = np .array (sample_weight )
618
- tp = np .array (tp )
619
- fp = np .array (fp )
620
- fn = np .array (fn )
653
+ tp = xp .asarray (tp )
654
+ fp = xp .asarray (fp )
655
+ fn = xp .asarray (fn )
621
656
tn = sample_weight * y_true .shape [1 ] - tp - fp - fn
622
657
elif sample_weight is not None :
623
- tn = sum (sample_weight ) - tp - fp - fn
658
+ tn = xp . sum (sample_weight ) - tp - fp - fn
624
659
elif samplewise :
625
660
tn = y_true .shape [1 ] - tp - fp - fn
626
661
else :
627
662
tn = y_true .shape [0 ] - tp - fp - fn
628
663
629
- return np . array ( [tn , fp , fn , tp ]).T . reshape (- 1 , 2 , 2 )
664
+ return xp . reshape ( xp . stack ( [tn , fp , fn , tp ]).T , (- 1 , 2 , 2 ) )
630
665
631
666
632
667
@validate_params (
@@ -1262,21 +1297,21 @@ def f1_score(
1262
1297
>>> y_true = [0, 1, 2, 0, 1, 2]
1263
1298
>>> y_pred = [0, 2, 1, 0, 0, 1]
1264
1299
>>> f1_score(y_true, y_pred, average='macro')
1265
- np.float64( 0.26...)
1300
+ 0.26...
1266
1301
>>> f1_score(y_true, y_pred, average='micro')
1267
- np.float64( 0.33...)
1302
+ 0.33...
1268
1303
>>> f1_score(y_true, y_pred, average='weighted')
1269
- np.float64( 0.26...)
1304
+ 0.26...
1270
1305
>>> f1_score(y_true, y_pred, average=None)
1271
1306
array([0.8, 0. , 0. ])
1272
1307
1273
1308
>>> # binary classification
1274
1309
>>> y_true_empty = [0, 0, 0, 0, 0, 0]
1275
1310
>>> y_pred_empty = [0, 0, 0, 0, 0, 0]
1276
1311
>>> f1_score(y_true_empty, y_pred_empty)
1277
- np.float64( 0.0...)
1312
+ 0.0...
1278
1313
>>> f1_score(y_true_empty, y_pred_empty, zero_division=1.0)
1279
- np.float64( 1.0...)
1314
+ 1.0...
1280
1315
>>> f1_score(y_true_empty, y_pred_empty, zero_division=np.nan)
1281
1316
nan...
1282
1317
@@ -1466,17 +1501,17 @@ def fbeta_score(
1466
1501
>>> y_true = [0, 1, 2, 0, 1, 2]
1467
1502
>>> y_pred = [0, 2, 1, 0, 0, 1]
1468
1503
>>> fbeta_score(y_true, y_pred, average='macro', beta=0.5)
1469
- np.float64( 0.23...)
1504
+ 0.23...
1470
1505
>>> fbeta_score(y_true, y_pred, average='micro', beta=0.5)
1471
- np.float64( 0.33...)
1506
+ 0.33...
1472
1507
>>> fbeta_score(y_true, y_pred, average='weighted', beta=0.5)
1473
- np.float64( 0.23...)
1508
+ 0.23...
1474
1509
>>> fbeta_score(y_true, y_pred, average=None, beta=0.5)
1475
1510
array([0.71..., 0. , 0. ])
1476
1511
>>> y_pred_empty = [0, 0, 0, 0, 0, 0]
1477
1512
>>> fbeta_score(y_true, y_pred_empty,
1478
1513
... average="macro", zero_division=np.nan, beta=0.5)
1479
- np.float64( 0.12...)
1514
+ 0.12...
1480
1515
"""
1481
1516
1482
1517
_ , _ , f , _ = precision_recall_fscore_support (
@@ -1505,12 +1540,14 @@ def _prf_divide(
1505
1540
The metric, modifier and average arguments are used only for determining
1506
1541
an appropriate warning.
1507
1542
"""
1508
- mask = denominator == 0.0
1509
- denominator = denominator .copy ()
1543
+ xp , _ = get_namespace (numerator , denominator )
1544
+ dtype_float = _find_matching_floating_dtype (numerator , denominator , xp = xp )
1545
+ mask = denominator == 0
1546
+ denominator = xp .asarray (denominator , copy = True , dtype = dtype_float )
1510
1547
denominator [mask ] = 1 # avoid infs/nans
1511
- result = numerator / denominator
1548
+ result = xp . asarray ( numerator , dtype = dtype_float ) / denominator
1512
1549
1513
- if not np .any (mask ):
1550
+ if not xp .any (mask ):
1514
1551
return result
1515
1552
1516
1553
# set those with 0 denominator to `zero_division`, and 0 when "warn"
@@ -1559,7 +1596,7 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
1559
1596
y_type , y_true , y_pred = _check_targets (y_true , y_pred )
1560
1597
# Convert to Python primitive type to avoid NumPy type / Python str
1561
1598
# comparison. See https://github.com/numpy/numpy/issues/6784
1562
- present_labels = unique_labels (y_true , y_pred ). tolist ( )
1599
+ present_labels = _tolist ( unique_labels (y_true , y_pred ))
1563
1600
if average == "binary" :
1564
1601
if y_type == "binary" :
1565
1602
if pos_label not in present_labels :
@@ -1774,11 +1811,11 @@ def precision_recall_fscore_support(
1774
1811
>>> y_true = np.array(['cat', 'dog', 'pig', 'cat', 'dog', 'pig'])
1775
1812
>>> y_pred = np.array(['cat', 'pig', 'dog', 'cat', 'cat', 'dog'])
1776
1813
>>> precision_recall_fscore_support(y_true, y_pred, average='macro')
1777
- (np.float64( 0.22...), np.float64( 0.33...), np.float64( 0.26...) , None)
1814
+ (0.22..., 0.33..., 0.26..., None)
1778
1815
>>> precision_recall_fscore_support(y_true, y_pred, average='micro')
1779
- (np.float64( 0.33...), np.float64( 0.33...), np.float64( 0.33...) , None)
1816
+ (0.33..., 0.33..., 0.33..., None)
1780
1817
>>> precision_recall_fscore_support(y_true, y_pred, average='weighted')
1781
- (np.float64( 0.22...), np.float64( 0.33...), np.float64( 0.26...) , None)
1818
+ (0.22..., 0.33..., 0.26..., None)
1782
1819
1783
1820
It is possible to compute per-label precisions, recalls, F1-scores and
1784
1821
supports instead of averaging:
@@ -1805,10 +1842,11 @@ def precision_recall_fscore_support(
1805
1842
pred_sum = tp_sum + MCM [:, 0 , 1 ]
1806
1843
true_sum = tp_sum + MCM [:, 1 , 0 ]
1807
1844
1845
+ xp , _ = get_namespace (y_true , y_pred )
1808
1846
if average == "micro" :
1809
- tp_sum = np . array ([ tp_sum .sum ()] )
1810
- pred_sum = np . array ([ pred_sum .sum ()] )
1811
- true_sum = np . array ([ true_sum .sum ()] )
1847
+ tp_sum = xp . reshape ( xp .sum (tp_sum ), ( 1 ,) )
1848
+ pred_sum = xp . reshape ( xp .sum (pred_sum ), ( 1 ,) )
1849
+ true_sum = xp . reshape ( xp .sum (true_sum ), ( 1 ,) )
1812
1850
1813
1851
# Finally, we have all our sufficient statistics. Divide! #
1814
1852
beta2 = beta ** 2
@@ -1851,10 +1889,10 @@ def precision_recall_fscore_support(
1851
1889
weights = None
1852
1890
1853
1891
if average is not None :
1854
- assert average != "binary" or len ( precision ) == 1
1855
- precision = _nanaverage (precision , weights = weights )
1856
- recall = _nanaverage (recall , weights = weights )
1857
- f_score = _nanaverage (f_score , weights = weights )
1892
+ assert average != "binary" or precision . shape [ 0 ] == 1
1893
+ precision = float ( _nanaverage (precision , weights = weights ) )
1894
+ recall = float ( _nanaverage (recall , weights = weights ) )
1895
+ f_score = float ( _nanaverage (f_score , weights = weights ) )
1858
1896
true_sum = None # return no support
1859
1897
1860
1898
return precision , recall , f_score , true_sum
@@ -2185,11 +2223,11 @@ def precision_score(
2185
2223
>>> y_true = [0, 1, 2, 0, 1, 2]
2186
2224
>>> y_pred = [0, 2, 1, 0, 0, 1]
2187
2225
>>> precision_score(y_true, y_pred, average='macro')
2188
- np.float64( 0.22...)
2226
+ 0.22...
2189
2227
>>> precision_score(y_true, y_pred, average='micro')
2190
- np.float64( 0.33...)
2228
+ 0.33...
2191
2229
>>> precision_score(y_true, y_pred, average='weighted')
2192
- np.float64( 0.22...)
2230
+ 0.22...
2193
2231
>>> precision_score(y_true, y_pred, average=None)
2194
2232
array([0.66..., 0. , 0. ])
2195
2233
>>> y_pred = [0, 0, 0, 0, 0, 0]
@@ -2367,11 +2405,11 @@ def recall_score(
2367
2405
>>> y_true = [0, 1, 2, 0, 1, 2]
2368
2406
>>> y_pred = [0, 2, 1, 0, 0, 1]
2369
2407
>>> recall_score(y_true, y_pred, average='macro')
2370
- np.float64( 0.33...)
2408
+ 0.33...
2371
2409
>>> recall_score(y_true, y_pred, average='micro')
2372
- np.float64( 0.33...)
2410
+ 0.33...
2373
2411
>>> recall_score(y_true, y_pred, average='weighted')
2374
- np.float64( 0.33...)
2412
+ 0.33...
2375
2413
>>> recall_score(y_true, y_pred, average=None)
2376
2414
array([1., 0., 0.])
2377
2415
>>> y_true = [0, 0, 0, 0, 0, 0]
0 commit comments