34
34
from scipy .special import xlogy
35
35
36
36
from ..exceptions import UndefinedMetricWarning
37
+ from ..utils ._array_api import (
38
+ _average ,
39
+ _find_matching_floating_dtype ,
40
+ device ,
41
+ get_namespace ,
42
+ )
37
43
from ..utils ._param_validation import Hidden , Interval , StrOptions , validate_params
38
44
from ..utils .stats import _weighted_percentile
39
45
from ..utils .validation import (
65
71
]
66
72
67
73
68
- def _check_reg_targets (y_true , y_pred , multioutput , dtype = "numeric" ):
74
+ def _check_reg_targets (y_true , y_pred , multioutput , dtype = "numeric" , xp = None ):
69
75
"""Check that y_true and y_pred belong to the same regression task.
70
76
71
77
Parameters
@@ -99,15 +105,17 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric"):
99
105
just the corresponding argument if ``multioutput`` is a
100
106
correct keyword.
101
107
"""
108
+ xp , _ = get_namespace (y_true , y_pred , multioutput , xp = xp )
109
+
102
110
check_consistent_length (y_true , y_pred )
103
111
y_true = check_array (y_true , ensure_2d = False , dtype = dtype )
104
112
y_pred = check_array (y_pred , ensure_2d = False , dtype = dtype )
105
113
106
114
if y_true .ndim == 1 :
107
- y_true = y_true .reshape ((- 1 , 1 ))
115
+ y_true = xp .reshape (y_true , (- 1 , 1 ))
108
116
109
117
if y_pred .ndim == 1 :
110
- y_pred = y_pred .reshape ((- 1 , 1 ))
118
+ y_pred = xp .reshape (y_pred , (- 1 , 1 ))
111
119
112
120
if y_true .shape [1 ] != y_pred .shape [1 ]:
113
121
raise ValueError (
@@ -855,9 +863,10 @@ def median_absolute_error(
855
863
856
864
857
865
def _assemble_r2_explained_variance (
858
- numerator , denominator , n_outputs , multioutput , force_finite
866
+ numerator , denominator , n_outputs , multioutput , force_finite , xp , device
859
867
):
860
868
"""Common part used by explained variance score and :math:`R^2` score."""
869
+ dtype = numerator .dtype
861
870
862
871
nonzero_denominator = denominator != 0
863
872
@@ -868,12 +877,14 @@ def _assemble_r2_explained_variance(
868
877
nonzero_numerator = numerator != 0
869
878
# Default = Zero Numerator = perfect predictions. Set to 1.0
870
879
# (note: even if denominator is zero, thus avoiding NaN scores)
871
- output_scores = np .ones ([n_outputs ])
880
+ output_scores = xp .ones ([n_outputs ], device = device , dtype = dtype )
872
881
# Non-zero Numerator and Non-zero Denominator: use the formula
873
882
valid_score = nonzero_denominator & nonzero_numerator
883
+
874
884
output_scores [valid_score ] = 1 - (
875
885
numerator [valid_score ] / denominator [valid_score ]
876
886
)
887
+
877
888
# Non-zero Numerator and Zero Denominator:
878
889
# arbitrary set to 0.0 to avoid -inf scores
879
890
output_scores [nonzero_numerator & ~ nonzero_denominator ] = 0.0
@@ -887,15 +898,18 @@ def _assemble_r2_explained_variance(
887
898
avg_weights = None
888
899
elif multioutput == "variance_weighted" :
889
900
avg_weights = denominator
890
- if not np .any (nonzero_denominator ):
901
+ if not xp .any (nonzero_denominator ):
891
902
# All weights are zero, np.average would raise a ZeroDiv error.
892
903
# This only happens when all y are constant (or 1-element long)
893
904
# Since weights are all equal, fall back to uniform weights.
894
905
avg_weights = None
895
906
else :
896
907
avg_weights = multioutput
897
908
898
- return np .average (output_scores , weights = avg_weights )
909
+ result = _average (output_scores , weights = avg_weights )
910
+ if result .size == 1 :
911
+ return float (result )
912
+ return result
899
913
900
914
901
915
@validate_params (
@@ -1033,6 +1047,9 @@ def explained_variance_score(
1033
1047
n_outputs = y_true .shape [1 ],
1034
1048
multioutput = multioutput ,
1035
1049
force_finite = force_finite ,
1050
+ xp = get_namespace (y_true )[0 ],
1051
+ # TODO: update once Array API support is added to explained_variance_score.
1052
+ device = None ,
1036
1053
)
1037
1054
1038
1055
@@ -1177,8 +1194,14 @@ def r2_score(
1177
1194
>>> r2_score(y_true, y_pred, force_finite=False)
1178
1195
-inf
1179
1196
"""
1180
- y_type , y_true , y_pred , multioutput = _check_reg_targets (
1181
- y_true , y_pred , multioutput
1197
+ input_arrays = [y_true , y_pred , sample_weight , multioutput ]
1198
+ xp , _ = get_namespace (* input_arrays )
1199
+ device_ = device (* input_arrays )
1200
+
1201
+ dtype = _find_matching_floating_dtype (y_true , y_pred , sample_weight , xp = xp )
1202
+
1203
+ _ , y_true , y_pred , multioutput = _check_reg_targets (
1204
+ y_true , y_pred , multioutput , dtype = dtype , xp = xp
1182
1205
)
1183
1206
check_consistent_length (y_true , y_pred , sample_weight )
1184
1207
@@ -1188,22 +1211,25 @@ def r2_score(
1188
1211
return float ("nan" )
1189
1212
1190
1213
if sample_weight is not None :
1191
- sample_weight = column_or_1d (sample_weight )
1192
- weight = sample_weight [:, np . newaxis ]
1214
+ sample_weight = column_or_1d (sample_weight , dtype = dtype )
1215
+ weight = sample_weight [:, None ]
1193
1216
else :
1194
1217
weight = 1.0
1195
1218
1196
- numerator = (weight * (y_true - y_pred ) ** 2 ).sum (axis = 0 , dtype = np .float64 )
1197
- denominator = (
1198
- weight * (y_true - np .average (y_true , axis = 0 , weights = sample_weight )) ** 2
1199
- ).sum (axis = 0 , dtype = np .float64 )
1219
+ numerator = xp .sum (weight * (y_true - y_pred ) ** 2 , axis = 0 )
1220
+ denominator = xp .sum (
1221
+ weight * (y_true - _average (y_true , axis = 0 , weights = sample_weight , xp = xp )) ** 2 ,
1222
+ axis = 0 ,
1223
+ )
1200
1224
1201
1225
return _assemble_r2_explained_variance (
1202
1226
numerator = numerator ,
1203
1227
denominator = denominator ,
1204
1228
n_outputs = y_true .shape [1 ],
1205
1229
multioutput = multioutput ,
1206
1230
force_finite = force_finite ,
1231
+ xp = xp ,
1232
+ device = device_ ,
1207
1233
)
1208
1234
1209
1235
0 commit comments