29
29
from sklearn .utils ._array_api import (
30
30
_average ,
31
31
_bincount ,
32
+ _convert_to_numpy ,
32
33
_count_nonzero ,
33
34
_find_matching_floating_dtype ,
34
35
_is_numpy_namespace ,
@@ -413,7 +414,7 @@ def confusion_matrix(
413
414
y_pred : array-like of shape (n_samples,)
414
415
Estimated targets as returned by a classifier.
415
416
416
- labels : array-like of shape (n_classes), default=None
417
+ labels : array-like of shape (n_classes,), default=None
417
418
List of labels to index the matrix. This may be used to reorder
418
419
or select a subset of labels.
419
420
If ``None`` is given, those that appear at least once
@@ -475,28 +476,61 @@ def confusion_matrix(
475
476
>>> (tn, fp, fn, tp)
476
477
(0, 2, 1, 1)
477
478
"""
478
- y_true , y_pred = attach_unique (y_true , y_pred )
479
- y_type , y_true , y_pred , sample_weight = _check_targets (
480
- y_true , y_pred , sample_weight
479
+ xp , _ , device_ = get_namespace_and_device (y_true , y_pred , labels , sample_weight )
480
+ y_true = check_array (
481
+ y_true ,
482
+ dtype = None ,
483
+ ensure_2d = False ,
484
+ ensure_all_finite = False ,
485
+ ensure_min_samples = 0 ,
481
486
)
487
+ y_pred = check_array (
488
+ y_pred ,
489
+ dtype = None ,
490
+ ensure_2d = False ,
491
+ ensure_all_finite = False ,
492
+ ensure_min_samples = 0 ,
493
+ )
494
+ # Convert the input arrays to NumPy (on CPU) irrespective of the original
495
+ # namespace and device so as to be able to leverage the efficient
496
+ # counting operations implemented by SciPy in the coo_matrix constructor.
497
+ # The final results will be converted back to the input namespace and device
498
+ # for the sake of consistency with other metric functions with array API support.
499
+ y_true = _convert_to_numpy (y_true , xp )
500
+ y_pred = _convert_to_numpy (y_pred , xp )
501
+ if sample_weight is None :
502
+ sample_weight = np .ones (y_true .shape [0 ], dtype = np .int64 )
503
+ else :
504
+ sample_weight = _convert_to_numpy (sample_weight , xp )
505
+
506
+ if len (sample_weight ) > 0 :
507
+ y_type , y_true , y_pred , sample_weight = _check_targets (
508
+ y_true , y_pred , sample_weight
509
+ )
510
+ else :
511
+ # This is needed to handle the special case where y_true, y_pred and
512
+ # sample_weight are all empty.
513
+ # In this case we don't pass sample_weight to _check_targets that would
514
+ # check that sample_weight is not empty and we don't reuse the returned
515
+ # sample_weight
516
+ y_type , y_true , y_pred , _ = _check_targets (y_true , y_pred )
517
+
518
+ y_true , y_pred = attach_unique (y_true , y_pred )
482
519
if y_type not in ("binary" , "multiclass" ):
483
520
raise ValueError ("%s is not supported" % y_type )
484
521
485
522
if labels is None :
486
523
labels = unique_labels (y_true , y_pred )
487
524
else :
488
- labels = np . asarray (labels )
525
+ labels = _convert_to_numpy (labels , xp )
489
526
n_labels = labels .size
490
527
if n_labels == 0 :
491
- raise ValueError ("'labels' should contains at least one label." )
528
+ raise ValueError ("'labels' should contain at least one label." )
492
529
elif y_true .size == 0 :
493
530
return np .zeros ((n_labels , n_labels ), dtype = int )
494
531
elif len (np .intersect1d (y_true , labels )) == 0 :
495
532
raise ValueError ("At least one label specified must be in y_true" )
496
533
497
- if sample_weight is None :
498
- sample_weight = np .ones (y_true .shape [0 ], dtype = np .int64 )
499
-
500
534
n_labels = labels .size
501
535
# If labels are not consecutive integers starting from zero, then
502
536
# y_true and y_pred must be converted into index form
@@ -507,9 +541,9 @@ def confusion_matrix(
507
541
and y_pred .min () >= 0
508
542
)
509
543
if need_index_conversion :
510
- label_to_ind = {y : x for x , y in enumerate (labels )}
511
- y_pred = np .array ([label_to_ind .get (x , n_labels + 1 ) for x in y_pred ])
512
- y_true = np .array ([label_to_ind .get (x , n_labels + 1 ) for x in y_true ])
544
+ label_to_ind = {label : index for index , label in enumerate (labels )}
545
+ y_pred = np .array ([label_to_ind .get (label , n_labels + 1 ) for label in y_pred ])
546
+ y_true = np .array ([label_to_ind .get (label , n_labels + 1 ) for label in y_true ])
513
547
514
548
# intersect y_pred, y_true with labels, eliminate items not in labels
515
549
ind = np .logical_and (y_pred < n_labels , y_true < n_labels )
@@ -550,7 +584,7 @@ def confusion_matrix(
550
584
UserWarning ,
551
585
)
552
586
553
- return cm
587
+ return xp . asarray ( cm , device = device_ )
554
588
555
589
556
590
@validate_params (
0 commit comments