Skip to content

Commit 612d93d

Browse files
fcharras, elindgren, ogrisel, betatim, lesteve
authored
ENH Use Array API in r2_score (#27904)
Signed-off-by: Julien Jerphanion <git@jjerphan.xyz> Co-authored-by: Eric Lindgren <ericlin@chalmers.se> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Tim Head <betatim@gmail.com> Co-authored-by: Loïc Estève <loic.esteve@ymail.com> Co-authored-by: Oleksii Kachaiev <kachayev@gmail.com> Co-authored-by: Julien Jerphanion <git@jjerphan.xyz> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 9d85052 commit 612d93d

File tree

7 files changed

+492
-98
lines changed

7 files changed

+492
-98
lines changed

doc/modules/array_api.rst

+17
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ Metrics
105105
-------
106106

107107
- :func:`sklearn.metrics.accuracy_score`
108+
- :func:`sklearn.metrics.r2_score`
108109
- :func:`sklearn.metrics.zero_one_loss`
109110

110111
Tools
@@ -115,6 +116,22 @@ Tools
115116
Coverage is expected to grow over time. Please follow the dedicated `meta-issue on GitHub
116117
<https://github.com/scikit-learn/scikit-learn/issues/22352>`_ to track progress.
117118

119+
Type of return values and fitted attributes
120+
-------------------------------------------
121+
122+
When calling functions or methods with Array API compatible inputs, the
123+
convention is to return array values of the same array container type and
124+
device as the input data.
125+
126+
Similarly, when an estimator is fitted with Array API compatible inputs, the
127+
fitted attributes will be arrays from the same library as the input and stored
128+
on the same device. The `predict` and `transform` methods subsequently expect
129+
inputs from the same array library and device as the data passed to the `fit`
130+
method.
131+
132+
Note however that scoring functions that return scalar values return Python
133+
scalars (typically a `float` instance) instead of an array scalar value.
134+
118135
Common estimator checks
119136
=======================
120137

doc/whats_new/v1.5.rst

+16
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,22 @@ Version 1.5.0
2222

2323
**In Development**
2424

25+
Support for Array API
26+
---------------------
27+
28+
Additional estimators and functions have been updated to include support for all
29+
`Array API <https://data-apis.org/array-api/latest/>`_ compliant inputs.
30+
31+
See :ref:`array_api` for more details.
32+
33+
**Functions:**
34+
35+
- :func:`sklearn.metrics.r2_score` now supports Array API compliant inputs.
36+
:pr:`27904` by :user:`Eric Lindgren <elindgren>`, :user:`Franck Charras <fcharras>`,
37+
:user:`Olivier Grisel <ogrisel>` and :user:`Tim Head <betatim>`.
38+
39+
**Classes:**
40+
2541
Support for building with Meson
2642
-------------------------------
2743

sklearn/metrics/_classification.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@
3838
check_consistent_length,
3939
column_or_1d,
4040
)
41-
from ..utils._array_api import _union1d, _weighted_sum, get_namespace
41+
from ..utils._array_api import (
42+
_average,
43+
_union1d,
44+
get_namespace,
45+
)
4246
from ..utils._param_validation import (
4347
Hidden,
4448
Interval,
@@ -224,7 +228,7 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
224228
else:
225229
score = y_true == y_pred
226230

227-
return _weighted_sum(score, sample_weight, normalize)
231+
return float(_average(score, weights=sample_weight, normalize=normalize))
228232

229233

230234
@validate_params(
@@ -2809,7 +2813,7 @@ def hamming_loss(y_true, y_pred, *, sample_weight=None):
28092813
return n_differences / (y_true.shape[0] * y_true.shape[1] * weight_average)
28102814

28112815
elif y_type in ["binary", "multiclass"]:
2812-
return _weighted_sum(y_true != y_pred, sample_weight, normalize=True)
2816+
return float(_average(y_true != y_pred, weights=sample_weight, normalize=True))
28132817
else:
28142818
raise ValueError("{0} is not supported".format(y_type))
28152819

@@ -2994,7 +2998,7 @@ def log_loss(
29942998
y_pred = y_pred / y_pred_sum[:, np.newaxis]
29952999
loss = -xlogy(transformed_labels, y_pred).sum(axis=1)
29963000

2997-
return _weighted_sum(loss, sample_weight, normalize)
3001+
return float(_average(loss, weights=sample_weight, normalize=normalize))
29983002

29993003

30003004
@validate_params(

sklearn/metrics/_regression.py

+41-15
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@
3434
from scipy.special import xlogy
3535

3636
from ..exceptions import UndefinedMetricWarning
37+
from ..utils._array_api import (
38+
_average,
39+
_find_matching_floating_dtype,
40+
device,
41+
get_namespace,
42+
)
3743
from ..utils._param_validation import Hidden, Interval, StrOptions, validate_params
3844
from ..utils.stats import _weighted_percentile
3945
from ..utils.validation import (
@@ -65,7 +71,7 @@
6571
]
6672

6773

68-
def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric"):
74+
def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
6975
"""Check that y_true and y_pred belong to the same regression task.
7076
7177
Parameters
@@ -99,15 +105,17 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric"):
99105
just the corresponding argument if ``multioutput`` is a
100106
correct keyword.
101107
"""
108+
xp, _ = get_namespace(y_true, y_pred, multioutput, xp=xp)
109+
102110
check_consistent_length(y_true, y_pred)
103111
y_true = check_array(y_true, ensure_2d=False, dtype=dtype)
104112
y_pred = check_array(y_pred, ensure_2d=False, dtype=dtype)
105113

106114
if y_true.ndim == 1:
107-
y_true = y_true.reshape((-1, 1))
115+
y_true = xp.reshape(y_true, (-1, 1))
108116

109117
if y_pred.ndim == 1:
110-
y_pred = y_pred.reshape((-1, 1))
118+
y_pred = xp.reshape(y_pred, (-1, 1))
111119

112120
if y_true.shape[1] != y_pred.shape[1]:
113121
raise ValueError(
@@ -855,9 +863,10 @@ def median_absolute_error(
855863

856864

857865
def _assemble_r2_explained_variance(
858-
numerator, denominator, n_outputs, multioutput, force_finite
866+
numerator, denominator, n_outputs, multioutput, force_finite, xp, device
859867
):
860868
"""Common part used by explained variance score and :math:`R^2` score."""
869+
dtype = numerator.dtype
861870

862871
nonzero_denominator = denominator != 0
863872

@@ -868,12 +877,14 @@ def _assemble_r2_explained_variance(
868877
nonzero_numerator = numerator != 0
869878
# Default = Zero Numerator = perfect predictions. Set to 1.0
870879
# (note: even if denominator is zero, thus avoiding NaN scores)
871-
output_scores = np.ones([n_outputs])
880+
output_scores = xp.ones([n_outputs], device=device, dtype=dtype)
872881
# Non-zero Numerator and Non-zero Denominator: use the formula
873882
valid_score = nonzero_denominator & nonzero_numerator
883+
874884
output_scores[valid_score] = 1 - (
875885
numerator[valid_score] / denominator[valid_score]
876886
)
887+
877888
# Non-zero Numerator and Zero Denominator:
878889
# arbitrarily set to 0.0 to avoid -inf scores
879890
output_scores[nonzero_numerator & ~nonzero_denominator] = 0.0
@@ -887,15 +898,18 @@ def _assemble_r2_explained_variance(
887898
avg_weights = None
888899
elif multioutput == "variance_weighted":
889900
avg_weights = denominator
890-
if not np.any(nonzero_denominator):
901+
if not xp.any(nonzero_denominator):
891902
# All weights are zero, np.average would raise a ZeroDiv error.
892903
# This only happens when all y are constant (or 1-element long)
893904
# Since weights are all equal, fall back to uniform weights.
894905
avg_weights = None
895906
else:
896907
avg_weights = multioutput
897908

898-
return np.average(output_scores, weights=avg_weights)
909+
result = _average(output_scores, weights=avg_weights)
910+
if result.size == 1:
911+
return float(result)
912+
return result
899913

900914

901915
@validate_params(
@@ -1033,6 +1047,9 @@ def explained_variance_score(
10331047
n_outputs=y_true.shape[1],
10341048
multioutput=multioutput,
10351049
force_finite=force_finite,
1050+
xp=get_namespace(y_true)[0],
1051+
# TODO: update once Array API support is added to explained_variance_score.
1052+
device=None,
10361053
)
10371054

10381055

@@ -1177,8 +1194,14 @@ def r2_score(
11771194
>>> r2_score(y_true, y_pred, force_finite=False)
11781195
-inf
11791196
"""
1180-
y_type, y_true, y_pred, multioutput = _check_reg_targets(
1181-
y_true, y_pred, multioutput
1197+
input_arrays = [y_true, y_pred, sample_weight, multioutput]
1198+
xp, _ = get_namespace(*input_arrays)
1199+
device_ = device(*input_arrays)
1200+
1201+
dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)
1202+
1203+
_, y_true, y_pred, multioutput = _check_reg_targets(
1204+
y_true, y_pred, multioutput, dtype=dtype, xp=xp
11821205
)
11831206
check_consistent_length(y_true, y_pred, sample_weight)
11841207

@@ -1188,22 +1211,25 @@ def r2_score(
11881211
return float("nan")
11891212

11901213
if sample_weight is not None:
1191-
sample_weight = column_or_1d(sample_weight)
1192-
weight = sample_weight[:, np.newaxis]
1214+
sample_weight = column_or_1d(sample_weight, dtype=dtype)
1215+
weight = sample_weight[:, None]
11931216
else:
11941217
weight = 1.0
11951218

1196-
numerator = (weight * (y_true - y_pred) ** 2).sum(axis=0, dtype=np.float64)
1197-
denominator = (
1198-
weight * (y_true - np.average(y_true, axis=0, weights=sample_weight)) ** 2
1199-
).sum(axis=0, dtype=np.float64)
1219+
numerator = xp.sum(weight * (y_true - y_pred) ** 2, axis=0)
1220+
denominator = xp.sum(
1221+
weight * (y_true - _average(y_true, axis=0, weights=sample_weight, xp=xp)) ** 2,
1222+
axis=0,
1223+
)
12001224

12011225
return _assemble_r2_explained_variance(
12021226
numerator=numerator,
12031227
denominator=denominator,
12041228
n_outputs=y_true.shape[1],
12051229
multioutput=multioutput,
12061230
force_finite=force_finite,
1231+
xp=xp,
1232+
device=device_,
12071233
)
12081234

12091235

sklearn/metrics/tests/test_common.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
from sklearn.utils import shuffle
5656
from sklearn.utils._array_api import (
5757
_atol_for_type,
58+
_convert_to_numpy,
5859
yield_namespace_device_dtype_combinations,
5960
)
6061
from sklearn.utils._testing import (
@@ -1749,7 +1750,7 @@ def check_array_api_metric(
17491750
metric_xp = metric(y_true_xp, y_pred_xp, sample_weight=sample_weight)
17501751

17511752
assert_allclose(
1752-
metric_xp,
1753+
_convert_to_numpy(xp.asarray(metric_xp), xp),
17531754
metric_np,
17541755
atol=_atol_for_type(dtype_name),
17551756
)
@@ -1813,6 +1814,33 @@ def check_array_api_multiclass_classification_metric(
18131814
)
18141815

18151816

1817+
def check_array_api_regression_metric(metric, array_namespace, device, dtype_name):
1818+
y_true_np = np.array([[1, 3], [1, 2]], dtype=dtype_name)
1819+
y_pred_np = np.array([[1, 4], [1, 1]], dtype=dtype_name)
1820+
1821+
check_array_api_metric(
1822+
metric,
1823+
array_namespace,
1824+
device,
1825+
dtype_name,
1826+
y_true_np=y_true_np,
1827+
y_pred_np=y_pred_np,
1828+
sample_weight=None,
1829+
)
1830+
1831+
sample_weight = np.array([0.1, 2.0], dtype=dtype_name)
1832+
1833+
check_array_api_metric(
1834+
metric,
1835+
array_namespace,
1836+
device,
1837+
dtype_name,
1838+
y_true_np=y_true_np,
1839+
y_pred_np=y_pred_np,
1840+
sample_weight=sample_weight,
1841+
)
1842+
1843+
18161844
array_api_metric_checkers = {
18171845
accuracy_score: [
18181846
check_array_api_binary_classification_metric,
@@ -1822,6 +1850,7 @@ def check_array_api_multiclass_classification_metric(
18221850
check_array_api_binary_classification_metric,
18231851
check_array_api_multiclass_classification_metric,
18241852
],
1853+
r2_score: [check_array_api_regression_metric],
18251854
}
18261855

18271856

0 commit comments

Comments
 (0)