
Commit f3094eb

lucyleeow authored and jeremiedbb committed
MNT Make sample_weight checking more consistent in regression metrics (#30886)
1 parent 3994e12 commit f3094eb

5 files changed (+87, -43 lines)
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+- Additional `sample_weight` checking has been added to
+  :func:`metrics.mean_absolute_error`,
+  :func:`metrics.mean_pinball_loss`,
+  :func:`metrics.mean_absolute_percentage_error`,
+  :func:`metrics.mean_squared_error`,
+  :func:`metrics.root_mean_squared_error`,
+  :func:`metrics.mean_squared_log_error`,
+  :func:`metrics.root_mean_squared_log_error`,
+  :func:`metrics.explained_variance_score`,
+  :func:`metrics.r2_score`,
+  :func:`metrics.mean_tweedie_deviance`,
+  :func:`metrics.mean_poisson_deviance`,
+  :func:`metrics.mean_gamma_deviance` and
+  :func:`metrics.d2_tweedie_score`.
+  `sample_weight` can only be 1D, consistent in length with `y_true` and `y_pred`,
+  or a scalar.
+  By :user:`Lucy Liu <lucyleeow>`.
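For orientation, a minimal sketch of the user-facing behavior this entry describes, using mean_absolute_error as a representative metric; the data values are made up, and the error messages are the ones asserted in the new test further below:

import numpy as np
from sklearn.metrics import mean_absolute_error

y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5, 0.0, 2.0, 8.0])

# 1D weights matching y_true/y_pred in length are accepted as before.
print(mean_absolute_error(y_true, y_pred, sample_weight=[1, 2, 1, 2]))

# A length mismatch now raises the same ValueError across these metrics
# ("Found input variables with inconsistent numbers of samples").
try:
    mean_absolute_error(y_true, y_pred, sample_weight=[1, 2, 1])
except ValueError as exc:
    print(exc)

# A 2D weight array is rejected ("Sample weights must be 1D array or scalar").
try:
    mean_absolute_error(y_true, y_pred, sample_weight=np.ones((4, 2)))
except ValueError as exc:
    print(exc)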

sklearn/metrics/_regression.py

Lines changed: 32 additions & 35 deletions
@@ -57,8 +57,10 @@
 ]
 
 
-def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
-    """Check that y_true and y_pred belong to the same regression task.
+def _check_reg_targets(
+    y_true, y_pred, sample_weight, multioutput, dtype="numeric", xp=None
+):
+    """Check that y_true, y_pred and sample_weight belong to the same regression task.
 
     To reduce redundancy when calling `_find_matching_floating_dtype`,
     please use `_check_reg_targets_with_floating_dtype` instead.
@@ -71,6 +73,9 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
     y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
         Estimated target values.
 
+    sample_weight : array-like of shape (n_samples,) or None
+        Sample weights.
+
     multioutput : array-like or string in ['raw_values', uniform_average',
         'variance_weighted'] or None
         None is accepted due to backward compatibility of r2_score().
@@ -95,6 +100,9 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
     y_pred : array-like of shape (n_samples, n_outputs)
         Estimated target values.
 
+    sample_weight : array-like of shape (n_samples,) or None
+        Sample weights.
+
     multioutput : array-like of shape (n_outputs) or string in ['raw_values',
         uniform_average', 'variance_weighted'] or None
         Custom output weights if ``multioutput`` is array-like or
@@ -103,9 +111,11 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
     """
     xp, _ = get_namespace(y_true, y_pred, multioutput, xp=xp)
 
-    check_consistent_length(y_true, y_pred)
+    check_consistent_length(y_true, y_pred, sample_weight)
     y_true = check_array(y_true, ensure_2d=False, dtype=dtype)
     y_pred = check_array(y_pred, ensure_2d=False, dtype=dtype)
+    if sample_weight is not None:
+        sample_weight = _check_sample_weight(sample_weight, y_true, dtype=dtype)
 
     if y_true.ndim == 1:
         y_true = xp.reshape(y_true, (-1, 1))
@@ -141,14 +151,13 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
             )
     y_type = "continuous" if n_outputs == 1 else "continuous-multioutput"
 
-    return y_type, y_true, y_pred, multioutput
+    return y_type, y_true, y_pred, sample_weight, multioutput
 
 
 def _check_reg_targets_with_floating_dtype(
     y_true, y_pred, sample_weight, multioutput, xp=None
 ):
-    """Ensures that y_true, y_pred, and sample_weight correspond to the same
-    regression task.
+    """Ensures y_true, y_pred, and sample_weight correspond to same regression task.
 
     Extends `_check_reg_targets` by automatically selecting a suitable floating-point
     data type for inputs using `_find_matching_floating_dtype`.
@@ -197,15 +206,10 @@ def _check_reg_targets_with_floating_dtype(
     """
     dtype_name = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)
 
-    y_type, y_true, y_pred, multioutput = _check_reg_targets(
-        y_true, y_pred, multioutput, dtype=dtype_name, xp=xp
+    y_type, y_true, y_pred, sample_weight, multioutput = _check_reg_targets(
+        y_true, y_pred, sample_weight, multioutput, dtype=dtype_name, xp=xp
     )
 
-    # _check_reg_targets does not accept sample_weight as input.
-    # Convert sample_weight's data type separately to match dtype_name.
-    if sample_weight is not None:
-        sample_weight = xp.asarray(sample_weight, dtype=dtype_name)
-
     return y_type, y_true, y_pred, sample_weight, multioutput
 
 
@@ -282,8 +286,6 @@ def mean_absolute_error(
         )
     )
 
-    check_consistent_length(y_true, y_pred, sample_weight)
-
     output_errors = _average(
         xp.abs(y_pred - y_true), weights=sample_weight, axis=0, xp=xp
     )
@@ -383,7 +385,6 @@ def mean_pinball_loss(
         )
     )
 
-    check_consistent_length(y_true, y_pred, sample_weight)
     diff = y_true - y_pred
     sign = xp.astype(diff >= 0, diff.dtype)
     loss = alpha * sign * diff - (1 - alpha) * (1 - sign) * diff
@@ -489,7 +490,6 @@ def mean_absolute_percentage_error(
             y_true, y_pred, sample_weight, multioutput, xp=xp
         )
     )
-    check_consistent_length(y_true, y_pred, sample_weight)
     epsilon = xp.asarray(xp.finfo(xp.float64).eps, dtype=y_true.dtype, device=device_)
     y_true_abs = xp.abs(y_true)
     mape = xp.abs(y_pred - y_true) / xp.maximum(y_true_abs, epsilon)
@@ -581,7 +581,6 @@ def mean_squared_error(
             y_true, y_pred, sample_weight, multioutput, xp=xp
         )
     )
-    check_consistent_length(y_true, y_pred, sample_weight)
     output_errors = _average((y_true - y_pred) ** 2, axis=0, weights=sample_weight)
 
     if isinstance(multioutput, str):
@@ -753,8 +752,10 @@ def mean_squared_log_error(
     """
     xp, _ = get_namespace(y_true, y_pred)
 
-    _, y_true, y_pred, _, _ = _check_reg_targets_with_floating_dtype(
-        y_true, y_pred, sample_weight, multioutput, xp=xp
+    _, y_true, y_pred, sample_weight, multioutput = (
+        _check_reg_targets_with_floating_dtype(
+            y_true, y_pred, sample_weight, multioutput, xp=xp
+        )
     )
 
     if xp.any(y_true <= -1) or xp.any(y_pred <= -1):
@@ -829,8 +830,10 @@ def root_mean_squared_log_error(
     """
     xp, _ = get_namespace(y_true, y_pred)
 
-    _, y_true, y_pred, _, _ = _check_reg_targets_with_floating_dtype(
-        y_true, y_pred, sample_weight, multioutput, xp=xp
+    _, y_true, y_pred, sample_weight, multioutput = (
+        _check_reg_targets_with_floating_dtype(
+            y_true, y_pred, sample_weight, multioutput, xp=xp
+        )
     )
 
     if xp.any(y_true <= -1) or xp.any(y_pred <= -1):
@@ -912,13 +915,12 @@ def median_absolute_error(
     >>> median_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])
     0.85
     """
-    y_type, y_true, y_pred, multioutput = _check_reg_targets(
-        y_true, y_pred, multioutput
+    _, y_true, y_pred, sample_weight, multioutput = _check_reg_targets(
+        y_true, y_pred, sample_weight, multioutput
     )
     if sample_weight is None:
         output_errors = np.median(np.abs(y_pred - y_true), axis=0)
     else:
-        sample_weight = _check_sample_weight(sample_weight, y_pred)
         output_errors = _weighted_percentile(
             np.abs(y_pred - y_true), sample_weight=sample_weight
         )
@@ -1106,8 +1108,6 @@ def explained_variance_score(
         )
     )
 
-    check_consistent_length(y_true, y_pred, sample_weight)
-
     y_diff_avg = _average(y_true - y_pred, weights=sample_weight, axis=0)
     numerator = _average(
         (y_true - y_pred - y_diff_avg) ** 2, weights=sample_weight, axis=0
@@ -1278,8 +1278,6 @@ def r2_score(
         )
     )
 
-    check_consistent_length(y_true, y_pred, sample_weight)
-
     if _num_samples(y_pred) < 2:
         msg = "R^2 score is not well-defined with less than two samples."
         warnings.warn(msg, UndefinedMetricWarning)
@@ -1343,7 +1341,9 @@ def max_error(y_true, y_pred):
     1.0
     """
     xp, _ = get_namespace(y_true, y_pred)
-    y_type, y_true, y_pred, _ = _check_reg_targets(y_true, y_pred, None, xp=xp)
+    y_type, y_true, y_pred, _, _ = _check_reg_targets(
+        y_true, y_pred, sample_weight=None, multioutput=None, xp=xp
+    )
     if y_type == "continuous-multioutput":
         raise ValueError("Multioutput not supported in max_error")
     return float(xp.max(xp.abs(y_true - y_pred)))
@@ -1448,7 +1448,6 @@ def mean_tweedie_deviance(y_true, y_pred, *, sample_weight=None, power=0):
     )
     if y_type == "continuous-multioutput":
        raise ValueError("Multioutput not supported in mean_tweedie_deviance")
-    check_consistent_length(y_true, y_pred, sample_weight)
 
     if sample_weight is not None:
         sample_weight = column_or_1d(sample_weight)
@@ -1773,10 +1772,9 @@ def d2_pinball_score(
     >>> d2_pinball_score(y_true, y_true, alpha=0.1)
     1.0
     """
-    y_type, y_true, y_pred, multioutput = _check_reg_targets(
-        y_true, y_pred, multioutput
+    _, y_true, y_pred, sample_weight, multioutput = _check_reg_targets(
+        y_true, y_pred, sample_weight, multioutput
     )
-    check_consistent_length(y_true, y_pred, sample_weight)
 
     if _num_samples(y_pred) < 2:
         msg = "D^2 score is not well-defined with less than two samples."
@@ -1796,7 +1794,6 @@ def d2_pinball_score(
             np.percentile(y_true, q=alpha * 100, axis=0), (len(y_true), 1)
         )
     else:
-        sample_weight = _check_sample_weight(sample_weight, y_true)
         y_quantile = np.tile(
             _weighted_percentile(
                 y_true, sample_weight=sample_weight, percentile_rank=alpha * 100
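For readers tracing the refactor above, a small sketch of the updated contract of the private helper _check_reg_targets (internal API in sklearn.metrics._regression, subject to change; the input values here are illustrative):

import numpy as np
from sklearn.metrics._regression import _check_reg_targets

y_true = [3.0, 0.5, 2.0]
y_pred = [2.5, 0.0, 2.0]

# The helper now takes sample_weight, validates it alongside y_true/y_pred,
# and returns it as part of a 5-tuple (previously a 4-tuple without weights).
y_type, yt, yp, sw, multioutput = _check_reg_targets(
    y_true, y_pred, sample_weight=[1.0, 2.0, 1.0], multioutput="uniform_average"
)
print(y_type)  # continuous
print(sw)      # validated 1D float array, same length as y_true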

sklearn/metrics/tests/test_common.py

Lines changed: 26 additions & 0 deletions
@@ -1588,6 +1588,32 @@ def test_regression_sample_weight_invariance(name):
     check_sample_weight_invariance(name, metric, y_true, y_pred)
 
 
+@pytest.mark.parametrize(
+    "name",
+    sorted(
+        set(ALL_METRICS).intersection(set(REGRESSION_METRICS))
+        - METRICS_WITHOUT_SAMPLE_WEIGHT
+    ),
+)
+def test_regression_with_invalid_sample_weight(name):
+    # Check that `sample_weight` with incorrect length raises error
+    n_samples = 50
+    random_state = check_random_state(0)
+    y_true = random_state.random_sample(size=(n_samples,))
+    y_pred = random_state.random_sample(size=(n_samples,))
+    metric = ALL_METRICS[name]
+
+    sample_weight = random_state.random_sample(size=(n_samples - 1,))
+    with pytest.raises(ValueError, match="Found input variables with inconsistent"):
+        metric(y_true, y_pred, sample_weight=sample_weight)
+
+    sample_weight = random_state.random_sample(size=(n_samples * 2,)).reshape(
+        (n_samples, 2)
+    )
+    with pytest.raises(ValueError, match="Sample weights must be 1D array or scalar"):
+        metric(y_true, y_pred, sample_weight=sample_weight)
+
+
 @pytest.mark.parametrize(
     "name",
     sorted(

sklearn/metrics/tests/test_regression.py

Lines changed: 5 additions & 3 deletions
@@ -330,7 +330,9 @@ def test__check_reg_targets():
 
     for (type1, y1, n_out1), (type2, y2, n_out2) in product(EXAMPLES, repeat=2):
         if type1 == type2 and n_out1 == n_out2:
-            y_type, y_check1, y_check2, multioutput = _check_reg_targets(y1, y2, None)
+            y_type, y_check1, y_check2, _, _ = _check_reg_targets(
+                y1, y2, sample_weight=None, multioutput=None
+            )
             assert type1 == y_type
             if type1 == "continuous":
                 assert_array_equal(y_check1, np.reshape(y1, (-1, 1)))
@@ -340,7 +342,7 @@ def test__check_reg_targets():
             assert_array_equal(y_check2, y2)
         else:
             with pytest.raises(ValueError):
-                _check_reg_targets(y1, y2, None)
+                _check_reg_targets(y1, y2, sample_weight=None, multioutput=None)
 
 
 def test__check_reg_targets_exception():
@@ -351,7 +353,7 @@
         )
     )
     with pytest.raises(ValueError, match=expected_message):
-        _check_reg_targets([1, 2, 3], [[1], [2], [3]], invalid_multioutput)
+        _check_reg_targets([1, 2, 3], [[1], [2], [3]], None, invalid_multioutput)
 
 
 def test_regression_multioutput_array():

sklearn/utils/validation.py

Lines changed: 7 additions & 5 deletions
@@ -2169,16 +2169,18 @@ def _check_sample_weight(
     """
     n_samples = _num_samples(X)
 
-    if dtype is not None and dtype not in [np.float32, np.float64]:
-        dtype = np.float64
+    xp, _ = get_namespace(X)
+
+    if dtype is not None and dtype not in [xp.float32, xp.float64]:
+        dtype = xp.float64
 
     if sample_weight is None:
-        sample_weight = np.ones(n_samples, dtype=dtype)
+        sample_weight = xp.ones(n_samples, dtype=dtype)
     elif isinstance(sample_weight, numbers.Number):
-        sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
+        sample_weight = xp.full(n_samples, sample_weight, dtype=dtype)
     else:
         if dtype is None:
-            dtype = [np.float64, np.float32]
+            dtype = [xp.float64, xp.float32]
         sample_weight = check_array(
             sample_weight,
             accept_sparse=False,
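As a rough illustration of the _check_sample_weight change above (a private helper in sklearn.utils.validation; in this sketch the namespace of X is plain NumPy, but with array API dispatch the same calls go through the xp namespace of X):

import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.arange(6, dtype=np.float64).reshape(3, 2)

# None expands to one weight of 1.0 per sample, built in X's namespace.
print(_check_sample_weight(None, X))   # [1. 1. 1.]

# A scalar is broadcast to one weight per sample.
print(_check_sample_weight(2.0, X))    # [2. 2. 2.]

# Anything that is not 1D (or a scalar) is rejected.
try:
    _check_sample_weight(np.ones((3, 2)), X)
except ValueError as exc:
    print(exc)  # Sample weights must be 1D array or scalar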
