Skip to content

Commit 6fd23fc

Browse files
ENH/DOC clearer sample weight validation error msg (#31873)
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
1 parent 8525ba5 commit 6fd23fc

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
``sklearn.utils._check_sample_weight`` now raises a clearer error message when the
2+
provided weights are neither a scalar nor a 1-D array-like of the same size as the
3+
input data.
4+
:issue:`31712` by :user:`Kapil Parekh <kapslock123>`.

sklearn/utils/tests/test_validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1608,7 +1608,7 @@ def _check_sample_weight_common(xp):
16081608
assert_allclose(_convert_to_numpy(sample_weight, xp), 2 * np.ones(5))
16091609

16101610
# check wrong number of dimensions
1611-
with pytest.raises(ValueError, match="Sample weights must be 1D array or scalar"):
1611+
with pytest.raises(ValueError, match=r"Sample weights must be 1D array or scalar"):
16121612
_check_sample_weight(xp.ones((2, 4)), X=xp.ones((2, 2)))
16131613

16141614
# check incorrect n_samples

sklearn/utils/validation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2222,7 +2222,11 @@ def _check_sample_weight(
22222222
input_name="sample_weight",
22232223
)
22242224
if sample_weight.ndim != 1:
2225-
raise ValueError("Sample weights must be 1D array or scalar")
2225+
raise ValueError(
2226+
f"Sample weights must be 1D array or scalar, got "
2227+
f"{sample_weight.ndim}D array. Expected either a scalar value "
2228+
f"or a 1D array of length {n_samples}."
2229+
)
22262230

22272231
if sample_weight.shape != (n_samples,):
22282232
raise ValueError(

0 commit comments

Comments
 (0)