@@ -66,7 +66,7 @@ def _check_zero_division(zero_division):
66
66
return np .nan
67
67
68
68
69
- def _check_targets (y_true , y_pred ):
69
+ def _check_targets (y_true , y_pred , sample_weight = None ):
70
70
"""Check that y_true and y_pred belong to the same classification task.
71
71
72
72
This converts multiclass or binary types to a common shape, and raises a
@@ -83,6 +83,8 @@ def _check_targets(y_true, y_pred):
83
83
84
84
y_pred : array-like
85
85
86
+ sample_weight : array-like, default=None
87
+
86
88
Returns
87
89
-------
88
90
type_true : one of {'multilabel-indicator', 'multiclass', 'binary'}
@@ -92,11 +94,17 @@ def _check_targets(y_true, y_pred):
92
94
y_true : array or indicator matrix
93
95
94
96
y_pred : array or indicator matrix
97
+
98
+ sample_weight : array or None
95
99
"""
96
- xp , _ = get_namespace (y_true , y_pred )
97
- check_consistent_length (y_true , y_pred )
100
+ xp , _ = get_namespace (y_true , y_pred , sample_weight )
101
+ check_consistent_length (y_true , y_pred , sample_weight )
98
102
type_true = type_of_target (y_true , input_name = "y_true" )
99
103
type_pred = type_of_target (y_pred , input_name = "y_pred" )
104
+ if sample_weight is not None :
105
+ sample_weight = _check_sample_weight (
106
+ sample_weight , y_true , force_float_dtype = False
107
+ )
100
108
101
109
y_type = {type_true , type_pred }
102
110
if y_type == {"binary" , "multiclass" }:
@@ -148,7 +156,7 @@ def _check_targets(y_true, y_pred):
148
156
y_pred = csr_matrix (y_pred )
149
157
y_type = "multilabel-indicator"
150
158
151
- return y_type , y_true , y_pred
159
+ return y_type , y_true , y_pred , sample_weight
152
160
153
161
154
162
def _validate_multiclass_probabilistic_prediction (
@@ -200,6 +208,9 @@ def _validate_multiclass_probabilistic_prediction(
200
208
raise ValueError (f"y_prob contains values lower than 0: { y_prob .min ()} " )
201
209
202
210
check_consistent_length (y_prob , y_true , sample_weight )
211
+ if sample_weight is not None :
212
+ _check_sample_weight (sample_weight , y_true , force_float_dtype = False )
213
+
203
214
lb = LabelBinarizer ()
204
215
205
216
if labels is not None :
@@ -356,8 +367,9 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
356
367
xp , _ , device = get_namespace_and_device (y_true , y_pred , sample_weight )
357
368
# Compute accuracy for each possible representation
358
369
y_true , y_pred = attach_unique (y_true , y_pred )
359
- y_type , y_true , y_pred = _check_targets (y_true , y_pred )
360
- check_consistent_length (y_true , y_pred , sample_weight )
370
+ y_type , y_true , y_pred , sample_weight = _check_targets (
371
+ y_true , y_pred , sample_weight
372
+ )
361
373
362
374
if y_type .startswith ("multilabel" ):
363
375
differing_labels = _count_nonzero (y_true - y_pred , xp = xp , device = device , axis = 1 )
@@ -464,7 +476,9 @@ def confusion_matrix(
464
476
(0, 2, 1, 1)
465
477
"""
466
478
y_true , y_pred = attach_unique (y_true , y_pred )
467
- y_type , y_true , y_pred = _check_targets (y_true , y_pred )
479
+ y_type , y_true , y_pred , sample_weight = _check_targets (
480
+ y_true , y_pred , sample_weight
481
+ )
468
482
if y_type not in ("binary" , "multiclass" ):
469
483
raise ValueError ("%s is not supported" % y_type )
470
484
@@ -482,10 +496,6 @@ def confusion_matrix(
482
496
483
497
if sample_weight is None :
484
498
sample_weight = np .ones (y_true .shape [0 ], dtype = np .int64 )
485
- else :
486
- sample_weight = np .asarray (sample_weight )
487
-
488
- check_consistent_length (y_true , y_pred , sample_weight )
489
499
490
500
n_labels = labels .size
491
501
# If labels are not consecutive integers starting from zero, then
@@ -654,11 +664,10 @@ def multilabel_confusion_matrix(
654
664
[1, 2]]])
655
665
"""
656
666
y_true , y_pred = attach_unique (y_true , y_pred )
657
- xp , _ , device_ = get_namespace_and_device (y_true , y_pred )
658
- y_type , y_true , y_pred = _check_targets (y_true , y_pred )
659
- if sample_weight is not None :
660
- sample_weight = column_or_1d (sample_weight , device = device_ )
661
- check_consistent_length (y_true , y_pred , sample_weight )
667
+ xp , _ , device_ = get_namespace_and_device (y_true , y_pred , sample_weight )
668
+ y_type , y_true , y_pred , sample_weight = _check_targets (
669
+ y_true , y_pred , sample_weight
670
+ )
662
671
663
672
if y_type not in ("binary" , "multiclass" , "multilabel-indicator" ):
664
673
raise ValueError ("%s is not supported" % y_type )
@@ -1171,8 +1180,9 @@ def matthews_corrcoef(y_true, y_pred, *, sample_weight=None):
1171
1180
-0.33
1172
1181
"""
1173
1182
y_true , y_pred = attach_unique (y_true , y_pred )
1174
- y_type , y_true , y_pred = _check_targets (y_true , y_pred )
1175
- check_consistent_length (y_true , y_pred , sample_weight )
1183
+ y_type , y_true , y_pred , sample_weight = _check_targets (
1184
+ y_true , y_pred , sample_weight
1185
+ )
1176
1186
if y_type not in {"binary" , "multiclass" }:
1177
1187
raise ValueError ("%s is not supported" % y_type )
1178
1188
@@ -1759,7 +1769,7 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
1759
1769
raise ValueError ("average has to be one of " + str (average_options ))
1760
1770
1761
1771
y_true , y_pred = attach_unique (y_true , y_pred )
1762
- y_type , y_true , y_pred = _check_targets (y_true , y_pred )
1772
+ y_type , y_true , y_pred , _ = _check_targets (y_true , y_pred )
1763
1773
# Convert to Python primitive type to avoid NumPy type / Python str
1764
1774
# comparison. See https://github.com/numpy/numpy/issues/6784
1765
1775
present_labels = _tolist (unique_labels (y_true , y_pred ))
@@ -2227,7 +2237,9 @@ class are present in `y_true`): both likelihood ratios are undefined.
2227
2237
# remove `FutureWarning`, and the Warns section in the docstring should not mention
2228
2238
# `raise_warning` anymore.
2229
2239
y_true , y_pred = attach_unique (y_true , y_pred )
2230
- y_type , y_true , y_pred = _check_targets (y_true , y_pred )
2240
+ y_type , y_true , y_pred , sample_weight = _check_targets (
2241
+ y_true , y_pred , sample_weight
2242
+ )
2231
2243
if y_type != "binary" :
2232
2244
raise ValueError (
2233
2245
"class_likelihood_ratios only supports binary classification "
@@ -2945,7 +2957,9 @@ class 2 1.00 0.67 0.80 3
2945
2957
"""
2946
2958
2947
2959
y_true , y_pred = attach_unique (y_true , y_pred )
2948
- y_type , y_true , y_pred = _check_targets (y_true , y_pred )
2960
+ y_type , y_true , y_pred , sample_weight = _check_targets (
2961
+ y_true , y_pred , sample_weight
2962
+ )
2949
2963
2950
2964
if labels is None :
2951
2965
labels = unique_labels (y_true , y_pred )
@@ -3134,15 +3148,15 @@ def hamming_loss(y_true, y_pred, *, sample_weight=None):
3134
3148
0.75
3135
3149
"""
3136
3150
y_true , y_pred = attach_unique (y_true , y_pred )
3137
- y_type , y_true , y_pred = _check_targets (y_true , y_pred )
3138
- check_consistent_length (y_true , y_pred , sample_weight )
3151
+ y_type , y_true , y_pred , sample_weight = _check_targets (
3152
+ y_true , y_pred , sample_weight
3153
+ )
3139
3154
3140
3155
xp , _ , device = get_namespace_and_device (y_true , y_pred , sample_weight )
3141
3156
3142
3157
if sample_weight is None :
3143
3158
weight_average = 1.0
3144
3159
else :
3145
- sample_weight = xp .asarray (sample_weight , device = device )
3146
3160
weight_average = _average (sample_weight , xp = xp )
3147
3161
3148
3162
if y_type .startswith ("multilabel" ):
@@ -3440,6 +3454,8 @@ def _validate_binary_probabilistic_prediction(y_true, y_prob, sample_weight, pos
3440
3454
assert_all_finite (y_prob )
3441
3455
3442
3456
check_consistent_length (y_prob , y_true , sample_weight )
3457
+ if sample_weight is not None :
3458
+ _check_sample_weight (sample_weight , y_true , force_float_dtype = False )
3443
3459
3444
3460
y_type = type_of_target (y_true , input_name = "y_true" )
3445
3461
if y_type != "binary" :
0 commit comments