Commit 27e5256

MNT Add _check_sample_weights to classification metrics (#31701)

1 parent ed5f530 commit 27e5256
5 files changed: +109 −29 lines changed
Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+- Additional `sample_weight` checking has been added to
+  :func:`metrics.accuracy_score`,
+  :func:`metrics.balanced_accuracy_score`,
+  :func:`metrics.brier_score_loss`,
+  :func:`metrics.class_likelihood_ratios`,
+  :func:`metrics.classification_report`,
+  :func:`metrics.cohen_kappa_score`,
+  :func:`metrics.confusion_matrix`,
+  :func:`metrics.f1_score`,
+  :func:`metrics.fbeta_score`,
+  :func:`metrics.hamming_loss`,
+  :func:`metrics.jaccard_score`,
+  :func:`metrics.matthews_corrcoef`,
+  :func:`metrics.multilabel_confusion_matrix`,
+  :func:`metrics.precision_recall_fscore_support`,
+  :func:`metrics.precision_score`,
+  :func:`metrics.recall_score` and
+  :func:`metrics.zero_one_loss`.
+  `sample_weight` can only be 1D, consistent with `y_true` and `y_pred` in length,
+  and all values must be finite and not complex.
+  By :user:`Lucy Liu <lucyleeow>`.

sklearn/metrics/_classification.py

Lines changed: 40 additions & 24 deletions

@@ -66,7 +66,7 @@ def _check_zero_division(zero_division):
         return np.nan


-def _check_targets(y_true, y_pred):
+def _check_targets(y_true, y_pred, sample_weight=None):
     """Check that y_true and y_pred belong to the same classification task.

     This converts multiclass or binary types to a common shape, and raises a
@@ -83,6 +83,8 @@ def _check_targets(y_true, y_pred):

     y_pred : array-like

+    sample_weight : array-like, default=None
+
     Returns
     -------
     type_true : one of {'multilabel-indicator', 'multiclass', 'binary'}
@@ -92,11 +94,17 @@ def _check_targets(y_true, y_pred):
     y_true : array or indicator matrix

     y_pred : array or indicator matrix
+
+    sample_weight : array or None
     """
-    xp, _ = get_namespace(y_true, y_pred)
-    check_consistent_length(y_true, y_pred)
+    xp, _ = get_namespace(y_true, y_pred, sample_weight)
+    check_consistent_length(y_true, y_pred, sample_weight)
     type_true = type_of_target(y_true, input_name="y_true")
     type_pred = type_of_target(y_pred, input_name="y_pred")
+    if sample_weight is not None:
+        sample_weight = _check_sample_weight(
+            sample_weight, y_true, force_float_dtype=False
+        )

     y_type = {type_true, type_pred}
     if y_type == {"binary", "multiclass"}:
@@ -148,7 +156,7 @@ def _check_targets(y_true, y_pred):
         y_pred = csr_matrix(y_pred)
         y_type = "multilabel-indicator"

-    return y_type, y_true, y_pred
+    return y_type, y_true, y_pred, sample_weight


 def _validate_multiclass_probabilistic_prediction(
@@ -200,6 +208,9 @@ def _validate_multiclass_probabilistic_prediction(
         raise ValueError(f"y_prob contains values lower than 0: {y_prob.min()}")

     check_consistent_length(y_prob, y_true, sample_weight)
+    if sample_weight is not None:
+        _check_sample_weight(sample_weight, y_true, force_float_dtype=False)
+
     lb = LabelBinarizer()

     if labels is not None:
@@ -356,8 +367,9 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
     xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight)
     # Compute accuracy for each possible representation
     y_true, y_pred = attach_unique(y_true, y_pred)
-    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
-    check_consistent_length(y_true, y_pred, sample_weight)
+    y_type, y_true, y_pred, sample_weight = _check_targets(
+        y_true, y_pred, sample_weight
+    )

     if y_type.startswith("multilabel"):
         differing_labels = _count_nonzero(y_true - y_pred, xp=xp, device=device, axis=1)
@@ -464,7 +476,9 @@ def confusion_matrix(
     (0, 2, 1, 1)
     """
     y_true, y_pred = attach_unique(y_true, y_pred)
-    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    y_type, y_true, y_pred, sample_weight = _check_targets(
+        y_true, y_pred, sample_weight
+    )
     if y_type not in ("binary", "multiclass"):
         raise ValueError("%s is not supported" % y_type)

@@ -482,10 +496,6 @@ def confusion_matrix(

     if sample_weight is None:
         sample_weight = np.ones(y_true.shape[0], dtype=np.int64)
-    else:
-        sample_weight = np.asarray(sample_weight)
-
-    check_consistent_length(y_true, y_pred, sample_weight)

     n_labels = labels.size
     # If labels are not consecutive integers starting from zero, then
@@ -654,11 +664,10 @@ def multilabel_confusion_matrix(
            [1, 2]]])
     """
     y_true, y_pred = attach_unique(y_true, y_pred)
-    xp, _, device_ = get_namespace_and_device(y_true, y_pred)
-    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
-    if sample_weight is not None:
-        sample_weight = column_or_1d(sample_weight, device=device_)
-    check_consistent_length(y_true, y_pred, sample_weight)
+    xp, _, device_ = get_namespace_and_device(y_true, y_pred, sample_weight)
+    y_type, y_true, y_pred, sample_weight = _check_targets(
+        y_true, y_pred, sample_weight
+    )

     if y_type not in ("binary", "multiclass", "multilabel-indicator"):
         raise ValueError("%s is not supported" % y_type)
@@ -1171,8 +1180,9 @@ def matthews_corrcoef(y_true, y_pred, *, sample_weight=None):
     -0.33
     """
     y_true, y_pred = attach_unique(y_true, y_pred)
-    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
-    check_consistent_length(y_true, y_pred, sample_weight)
+    y_type, y_true, y_pred, sample_weight = _check_targets(
+        y_true, y_pred, sample_weight
+    )
     if y_type not in {"binary", "multiclass"}:
         raise ValueError("%s is not supported" % y_type)

@@ -1759,7 +1769,7 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
         raise ValueError("average has to be one of " + str(average_options))

     y_true, y_pred = attach_unique(y_true, y_pred)
-    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    y_type, y_true, y_pred, _ = _check_targets(y_true, y_pred)
     # Convert to Python primitive type to avoid NumPy type / Python str
     # comparison. See https://github.com/numpy/numpy/issues/6784
     present_labels = _tolist(unique_labels(y_true, y_pred))
@@ -2227,7 +2237,9 @@ class are present in `y_true`): both likelihood ratios are undefined.
     # remove `FutureWarning`, and the Warns section in the docstring should not mention
     # `raise_warning` anymore.
     y_true, y_pred = attach_unique(y_true, y_pred)
-    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    y_type, y_true, y_pred, sample_weight = _check_targets(
+        y_true, y_pred, sample_weight
+    )
     if y_type != "binary":
         raise ValueError(
             "class_likelihood_ratios only supports binary classification "
@@ -2945,7 +2957,9 @@ class 2 1.00 0.67 0.80 3
     """

     y_true, y_pred = attach_unique(y_true, y_pred)
-    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    y_type, y_true, y_pred, sample_weight = _check_targets(
+        y_true, y_pred, sample_weight
+    )

     if labels is None:
         labels = unique_labels(y_true, y_pred)
@@ -3134,15 +3148,15 @@ def hamming_loss(y_true, y_pred, *, sample_weight=None):
     0.75
     """
     y_true, y_pred = attach_unique(y_true, y_pred)
-    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
-    check_consistent_length(y_true, y_pred, sample_weight)
+    y_type, y_true, y_pred, sample_weight = _check_targets(
+        y_true, y_pred, sample_weight
+    )

     xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight)

     if sample_weight is None:
         weight_average = 1.0
     else:
-        sample_weight = xp.asarray(sample_weight, device=device)
         weight_average = _average(sample_weight, xp=xp)

     if y_type.startswith("multilabel"):
@@ -3440,6 +3454,8 @@ def _validate_binary_probabilistic_prediction(y_true, y_prob, sample_weight, pos
     assert_all_finite(y_prob)

     check_consistent_length(y_prob, y_true, sample_weight)
+    if sample_weight is not None:
+        _check_sample_weight(sample_weight, y_true, force_float_dtype=False)

     y_type = type_of_target(y_true, input_name="y_true")
     if y_type != "binary":

sklearn/metrics/tests/test_classification.py

Lines changed: 2 additions & 2 deletions

@@ -596,7 +596,7 @@ def test_multilabel_confusion_matrix_errors():
     # Bad sample_weight
     with pytest.raises(ValueError, match="inconsistent numbers of samples"):
         multilabel_confusion_matrix(y_true, y_pred, sample_weight=[1, 2])
-    with pytest.raises(ValueError, match="should be a 1d array"):
+    with pytest.raises(ValueError, match="Sample weights must be 1D array or scalar"):
         multilabel_confusion_matrix(
             y_true, y_pred, sample_weight=[[1, 2, 3], [2, 3, 4], [3, 4, 5]]
         )
@@ -2541,7 +2541,7 @@ def test__check_targets():
             _check_targets(y1, y2)

     else:
-        merged_type, y1out, y2out = _check_targets(y1, y2)
+        merged_type, y1out, y2out, _ = _check_targets(y1, y2)
         assert merged_type == expected
         if merged_type.startswith("multilabel"):
             assert y1out.format == "csr"

sklearn/metrics/tests/test_common.py

Lines changed: 32 additions & 0 deletions

@@ -881,6 +881,38 @@ def test_format_invariance_with_1d_vectors(name):
         metric(y1_row, y2_row)


+@pytest.mark.parametrize("metric", CLASSIFICATION_METRICS.values())
+def test_classification_with_invalid_sample_weight(metric):
+    # Check invalid `sample_weight` raises correct error
+    random_state = check_random_state(0)
+    n_samples = 20
+    y1 = random_state.randint(0, 2, size=(n_samples,))
+    y2 = random_state.randint(0, 2, size=(n_samples,))
+
+    sample_weight = random_state.random_sample(size=(n_samples - 1,))
+    with pytest.raises(ValueError, match="Found input variables with inconsistent"):
+        metric(y1, y2, sample_weight=sample_weight)
+
+    sample_weight = random_state.random_sample(size=(n_samples,))
+    sample_weight[0] = np.inf
+    with pytest.raises(ValueError, match="Input sample_weight contains infinity"):
+        metric(y1, y2, sample_weight=sample_weight)
+
+    sample_weight[0] = np.nan
+    with pytest.raises(ValueError, match="Input sample_weight contains NaN"):
+        metric(y1, y2, sample_weight=sample_weight)
+
+    sample_weight = np.array([1 + 2j, 3 + 4j, 5 + 7j])
+    with pytest.raises(ValueError, match="Complex data not supported"):
+        metric(y1[:3], y2[:3], sample_weight=sample_weight)
+
+    sample_weight = random_state.random_sample(size=(n_samples * 2,)).reshape(
+        (n_samples, 2)
+    )
+    with pytest.raises(ValueError, match="Sample weights must be 1D array or scalar"):
+        metric(y1, y2, sample_weight=sample_weight)
+
+
 @pytest.mark.parametrize(
     "name", sorted(set(CLASSIFICATION_METRICS) - METRIC_UNDEFINED_BINARY_MULTICLASS)
 )

sklearn/utils/validation.py

Lines changed: 13 additions & 3 deletions

@@ -2134,7 +2134,13 @@ def _check_psd_eigenvalues(lambdas, enable_warnings=False):


 def _check_sample_weight(
-    sample_weight, X, *, dtype=None, ensure_non_negative=False, copy=False
+    sample_weight,
+    X,
+    *,
+    dtype=None,
+    force_float_dtype=True,
+    ensure_non_negative=False,
+    copy=False,
 ):
     """Validate sample weights.

@@ -2162,6 +2168,10 @@ def _check_sample_weight(
         If `dtype` is not `{np.float32, np.float64, None}`, then output will
         be `np.float64`.

+    force_float_dtype : bool, default=True
+        Whether `sample_weight` should be forced to be float dtype, when `dtype` is
+        a non-float dtype or None.
+
     ensure_non_negative : bool, default=False,
         Whether or not the weights are expected to be non-negative.

@@ -2185,15 +2195,15 @@ def _check_sample_weight(
     float_dtypes = (
         [xp.float32] if max_float_type == xp.float32 else [xp.float64, xp.float32]
     )
-    if dtype is not None and dtype not in float_dtypes:
+    if force_float_dtype and dtype is not None and dtype not in float_dtypes:
         dtype = max_float_type

     if sample_weight is None:
         sample_weight = xp.ones(n_samples, dtype=dtype, device=device)
     elif isinstance(sample_weight, numbers.Number):
         sample_weight = xp.full(n_samples, sample_weight, dtype=dtype, device=device)
     else:
-        if force_float_dtype and dtype is None:
+        if force_float_dtype and dtype is None:
             dtype = float_dtypes
         sample_weight = check_array(
             sample_weight,
