From 979d67862b9db324d1435ff7d532972d00873428 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 9 Jul 2025 14:55:04 +1000 Subject: [PATCH] add more sample weight checks --- sklearn/metrics/tests/test_common.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 74bdb46d8258f..5cdc2ead54740 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -1614,6 +1614,19 @@ def test_regression_with_invalid_sample_weight(name): with pytest.raises(ValueError, match="Found input variables with inconsistent"): metric(y_true, y_pred, sample_weight=sample_weight) + sample_weight = random_state.random_sample(size=(n_samples,)) + sample_weight[0] = np.inf + with pytest.raises(ValueError, match="Input sample_weight contains infinity"): + metric(y_true, y_pred, sample_weight=sample_weight) + + sample_weight[0] = np.nan + with pytest.raises(ValueError, match="Input sample_weight contains NaN"): + metric(y_true, y_pred, sample_weight=sample_weight) + + sample_weight = np.array([1 + 2j, 3 + 4j, 5 + 7j]) + with pytest.raises(ValueError, match="Complex data not supported"): + metric(y_true[:3], y_pred[:3], sample_weight=sample_weight) + sample_weight = random_state.random_sample(size=(n_samples * 2,)).reshape( (n_samples, 2) )