57
57
]
58
58
59
59
60
- def _check_reg_targets (y_true , y_pred , multioutput , dtype = "numeric" , xp = None ):
61
- """Check that y_true and y_pred belong to the same regression task.
60
+ def _check_reg_targets (
61
+ y_true , y_pred , sample_weight , multioutput , dtype = "numeric" , xp = None
62
+ ):
63
+ """Check that y_true, y_pred and sample_weight belong to the same regression task.
62
64
63
65
To reduce redundancy when calling `_find_matching_floating_dtype`,
64
66
please use `_check_reg_targets_with_floating_dtype` instead.
@@ -71,6 +73,9 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
71
73
y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
72
74
Estimated target values.
73
75
76
+ sample_weight : array-like of shape (n_samples,) or None
77
+ Sample weights.
78
+
74
79
multioutput : array-like or string in ['raw_values', uniform_average',
75
80
'variance_weighted'] or None
76
81
None is accepted due to backward compatibility of r2_score().
@@ -95,6 +100,9 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
95
100
y_pred : array-like of shape (n_samples, n_outputs)
96
101
Estimated target values.
97
102
103
+ sample_weight : array-like of shape (n_samples,) or None
104
+ Sample weights.
105
+
98
106
multioutput : array-like of shape (n_outputs) or string in ['raw_values',
99
107
uniform_average', 'variance_weighted'] or None
100
108
Custom output weights if ``multioutput`` is array-like or
@@ -103,9 +111,11 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
103
111
"""
104
112
xp , _ = get_namespace (y_true , y_pred , multioutput , xp = xp )
105
113
106
- check_consistent_length (y_true , y_pred )
114
+ check_consistent_length (y_true , y_pred , sample_weight )
107
115
y_true = check_array (y_true , ensure_2d = False , dtype = dtype )
108
116
y_pred = check_array (y_pred , ensure_2d = False , dtype = dtype )
117
+ if sample_weight is not None :
118
+ sample_weight = _check_sample_weight (sample_weight , y_true , dtype = dtype )
109
119
110
120
if y_true .ndim == 1 :
111
121
y_true = xp .reshape (y_true , (- 1 , 1 ))
@@ -141,14 +151,13 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
141
151
)
142
152
y_type = "continuous" if n_outputs == 1 else "continuous-multioutput"
143
153
144
- return y_type , y_true , y_pred , multioutput
154
+ return y_type , y_true , y_pred , sample_weight , multioutput
145
155
146
156
147
157
def _check_reg_targets_with_floating_dtype (
148
158
y_true , y_pred , sample_weight , multioutput , xp = None
149
159
):
150
- """Ensures that y_true, y_pred, and sample_weight correspond to the same
151
- regression task.
160
+ """Ensures y_true, y_pred, and sample_weight correspond to same regression task.
152
161
153
162
Extends `_check_reg_targets` by automatically selecting a suitable floating-point
154
163
data type for inputs using `_find_matching_floating_dtype`.
@@ -197,15 +206,10 @@ def _check_reg_targets_with_floating_dtype(
197
206
"""
198
207
dtype_name = _find_matching_floating_dtype (y_true , y_pred , sample_weight , xp = xp )
199
208
200
- y_type , y_true , y_pred , multioutput = _check_reg_targets (
201
- y_true , y_pred , multioutput , dtype = dtype_name , xp = xp
209
+ y_type , y_true , y_pred , sample_weight , multioutput = _check_reg_targets (
210
+ y_true , y_pred , sample_weight , multioutput , dtype = dtype_name , xp = xp
202
211
)
203
212
204
- # _check_reg_targets does not accept sample_weight as input.
205
- # Convert sample_weight's data type separately to match dtype_name.
206
- if sample_weight is not None :
207
- sample_weight = xp .asarray (sample_weight , dtype = dtype_name )
208
-
209
213
return y_type , y_true , y_pred , sample_weight , multioutput
210
214
211
215
@@ -282,8 +286,6 @@ def mean_absolute_error(
282
286
)
283
287
)
284
288
285
- check_consistent_length (y_true , y_pred , sample_weight )
286
-
287
289
output_errors = _average (
288
290
xp .abs (y_pred - y_true ), weights = sample_weight , axis = 0 , xp = xp
289
291
)
@@ -383,7 +385,6 @@ def mean_pinball_loss(
383
385
)
384
386
)
385
387
386
- check_consistent_length (y_true , y_pred , sample_weight )
387
388
diff = y_true - y_pred
388
389
sign = xp .astype (diff >= 0 , diff .dtype )
389
390
loss = alpha * sign * diff - (1 - alpha ) * (1 - sign ) * diff
@@ -489,7 +490,6 @@ def mean_absolute_percentage_error(
489
490
y_true , y_pred , sample_weight , multioutput , xp = xp
490
491
)
491
492
)
492
- check_consistent_length (y_true , y_pred , sample_weight )
493
493
epsilon = xp .asarray (xp .finfo (xp .float64 ).eps , dtype = y_true .dtype , device = device_ )
494
494
y_true_abs = xp .abs (y_true )
495
495
mape = xp .abs (y_pred - y_true ) / xp .maximum (y_true_abs , epsilon )
@@ -581,7 +581,6 @@ def mean_squared_error(
581
581
y_true , y_pred , sample_weight , multioutput , xp = xp
582
582
)
583
583
)
584
- check_consistent_length (y_true , y_pred , sample_weight )
585
584
output_errors = _average ((y_true - y_pred ) ** 2 , axis = 0 , weights = sample_weight )
586
585
587
586
if isinstance (multioutput , str ):
@@ -753,8 +752,10 @@ def mean_squared_log_error(
753
752
"""
754
753
xp , _ = get_namespace (y_true , y_pred )
755
754
756
- _ , y_true , y_pred , _ , _ = _check_reg_targets_with_floating_dtype (
757
- y_true , y_pred , sample_weight , multioutput , xp = xp
755
+ _ , y_true , y_pred , sample_weight , multioutput = (
756
+ _check_reg_targets_with_floating_dtype (
757
+ y_true , y_pred , sample_weight , multioutput , xp = xp
758
+ )
758
759
)
759
760
760
761
if xp .any (y_true <= - 1 ) or xp .any (y_pred <= - 1 ):
@@ -829,8 +830,10 @@ def root_mean_squared_log_error(
829
830
"""
830
831
xp , _ = get_namespace (y_true , y_pred )
831
832
832
- _ , y_true , y_pred , _ , _ = _check_reg_targets_with_floating_dtype (
833
- y_true , y_pred , sample_weight , multioutput , xp = xp
833
+ _ , y_true , y_pred , sample_weight , multioutput = (
834
+ _check_reg_targets_with_floating_dtype (
835
+ y_true , y_pred , sample_weight , multioutput , xp = xp
836
+ )
834
837
)
835
838
836
839
if xp .any (y_true <= - 1 ) or xp .any (y_pred <= - 1 ):
@@ -912,13 +915,12 @@ def median_absolute_error(
912
915
>>> median_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])
913
916
0.85
914
917
"""
915
- y_type , y_true , y_pred , multioutput = _check_reg_targets (
916
- y_true , y_pred , multioutput
918
+ _ , y_true , y_pred , sample_weight , multioutput = _check_reg_targets (
919
+ y_true , y_pred , sample_weight , multioutput
917
920
)
918
921
if sample_weight is None :
919
922
output_errors = np .median (np .abs (y_pred - y_true ), axis = 0 )
920
923
else :
921
- sample_weight = _check_sample_weight (sample_weight , y_pred )
922
924
output_errors = _weighted_percentile (
923
925
np .abs (y_pred - y_true ), sample_weight = sample_weight
924
926
)
@@ -1106,8 +1108,6 @@ def explained_variance_score(
1106
1108
)
1107
1109
)
1108
1110
1109
- check_consistent_length (y_true , y_pred , sample_weight )
1110
-
1111
1111
y_diff_avg = _average (y_true - y_pred , weights = sample_weight , axis = 0 )
1112
1112
numerator = _average (
1113
1113
(y_true - y_pred - y_diff_avg ) ** 2 , weights = sample_weight , axis = 0
@@ -1278,8 +1278,6 @@ def r2_score(
1278
1278
)
1279
1279
)
1280
1280
1281
- check_consistent_length (y_true , y_pred , sample_weight )
1282
-
1283
1281
if _num_samples (y_pred ) < 2 :
1284
1282
msg = "R^2 score is not well-defined with less than two samples."
1285
1283
warnings .warn (msg , UndefinedMetricWarning )
@@ -1343,7 +1341,9 @@ def max_error(y_true, y_pred):
1343
1341
1.0
1344
1342
"""
1345
1343
xp , _ = get_namespace (y_true , y_pred )
1346
- y_type , y_true , y_pred , _ = _check_reg_targets (y_true , y_pred , None , xp = xp )
1344
+ y_type , y_true , y_pred , _ , _ = _check_reg_targets (
1345
+ y_true , y_pred , sample_weight = None , multioutput = None , xp = xp
1346
+ )
1347
1347
if y_type == "continuous-multioutput" :
1348
1348
raise ValueError ("Multioutput not supported in max_error" )
1349
1349
return float (xp .max (xp .abs (y_true - y_pred )))
@@ -1448,7 +1448,6 @@ def mean_tweedie_deviance(y_true, y_pred, *, sample_weight=None, power=0):
1448
1448
)
1449
1449
if y_type == "continuous-multioutput" :
1450
1450
raise ValueError ("Multioutput not supported in mean_tweedie_deviance" )
1451
- check_consistent_length (y_true , y_pred , sample_weight )
1452
1451
1453
1452
if sample_weight is not None :
1454
1453
sample_weight = column_or_1d (sample_weight )
@@ -1773,10 +1772,9 @@ def d2_pinball_score(
1773
1772
>>> d2_pinball_score(y_true, y_true, alpha=0.1)
1774
1773
1.0
1775
1774
"""
1776
- y_type , y_true , y_pred , multioutput = _check_reg_targets (
1777
- y_true , y_pred , multioutput
1775
+ _ , y_true , y_pred , sample_weight , multioutput = _check_reg_targets (
1776
+ y_true , y_pred , sample_weight , multioutput
1778
1777
)
1779
- check_consistent_length (y_true , y_pred , sample_weight )
1780
1778
1781
1779
if _num_samples (y_pred ) < 2 :
1782
1780
msg = "D^2 score is not well-defined with less than two samples."
@@ -1796,7 +1794,6 @@ def d2_pinball_score(
1796
1794
np .percentile (y_true , q = alpha * 100 , axis = 0 ), (len (y_true ), 1 )
1797
1795
)
1798
1796
else :
1799
- sample_weight = _check_sample_weight (sample_weight , y_true )
1800
1797
y_quantile = np .tile (
1801
1798
_weighted_percentile (
1802
1799
y_true , sample_weight = sample_weight , percentile_rank = alpha * 100
0 commit comments