Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
f679bff
raise warning when unique classes > 50% of n_samples
rprkh May 5, 2023
a4b6675
changelog
rprkh May 5, 2023
3513e91
move warning to type_of_target and modify test accordingly
rprkh May 6, 2023
384c816
modify changelog for type_of_target
rprkh May 6, 2023
0d00435
update test for lint
rprkh May 6, 2023
7c69259
reduce failing tests
rprkh May 7, 2023
bcd08b0
explicitly raise warning to prevent errors
rprkh May 8, 2023
b731553
fix lint - y
rprkh May 8, 2023
8cbe6ac
reduce dependence on lineno and use line instead
rprkh May 8, 2023
ca39690
lint again
rprkh May 8, 2023
d36a711
modified tests
rprkh Jun 15, 2023
3988f80
remove import
rprkh Jun 15, 2023
02d820c
Merge branch 'main' into enh_check_classification_targets
rprkh Jun 16, 2023
f70e2b5
update tests
rprkh Jun 17, 2023
743bed0
fix merge conflicts in changelog
rprkh Jun 17, 2023
8dcea63
remove extra line
rprkh Jun 17, 2023
7bdda2f
make seperate try/except blocks
rprkh Jun 18, 2023
89af61d
Merge remote-tracking branch 'origin/main' into pr/rprkh/26335-1
glemaitre Nov 16, 2023
d85131a
avoid raising if the number of samples is low
glemaitre Nov 16, 2023
f7bb57f
revert missing blank line
glemaitre Nov 16, 2023
1881322
Merge remote-tracking branch 'upstream/main' into enh_check_classific…
adrinjalali Apr 30, 2025
6c225f6
No need for raw strings
betatim Apr 30, 2025
a3b21a2
review
adrinjalali Apr 30, 2025
8b3ddd7
Merge branch 'enh_check_classification_targets' of github.com:rprkh/s…
adrinjalali Apr 30, 2025
31415d7
Merge remote-tracking branch 'upstream/main' into enh_check_classific…
adrinjalali Apr 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- |Enhancement| :func:`utils.multiclass.type_of_target` raises a warning when the number
of unique classes is greater than 50% of the number of samples. This warning is raised
only if `y` has more than 20 samples.
By :user:`Rahil Parikh <rprkh>`.
11 changes: 10 additions & 1 deletion sklearn/utils/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,16 @@ def _raise_or_return():
# Check multiclass
if issparse(first_row_or_val):
first_row_or_val = first_row_or_val.data
if cached_unique(y).shape[0] > 2 or (y.ndim == 2 and len(first_row_or_val) > 1):
classes = cached_unique(y)
if y.shape[0] > 20 and classes.shape[0] > round(0.5 * y.shape[0]):
# Only raise the warning when we have at least 20 samples.
warnings.warn(
"The number of unique classes is greater than 50% of the number "
"of samples.",
UserWarning,
stacklevel=2,
)
if classes.shape[0] > 2 or (y.ndim == 2 and len(first_row_or_val) > 1):
# [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
return "multiclass" + suffix
else:
Expand Down
20 changes: 20 additions & 0 deletions sklearn/utils/tests/test_multiclass.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from itertools import product

import numpy as np
Expand Down Expand Up @@ -294,6 +295,25 @@ def test_unique_labels():
assert_array_equal(unique_labels(np.ones((4, 5)), np.ones((5, 5))), np.arange(5))


def test_type_of_target_too_many_unique_classes():
    """Check the warning raised when unique classes exceed 50% of the samples.

    The warning is only expected when `y` has strictly more than 20 samples;
    with 20 samples or fewer, `type_of_target` must stay silent.
    """
    msg = r"The number of unique classes is greater than 50% of the number of samples."

    # 21 is the smallest sample size above the threshold; 25 is a value
    # comfortably above it. Every sample is its own class in both cases.
    for n_samples in (21, 25):
        with pytest.warns(UserWarning, match=msg):
            type_of_target(np.arange(n_samples))

    # 20 samples or fewer: no warning should be raised, even though every
    # sample is its own class. Turning warnings into errors makes any
    # unexpected warning fail the test.
    for n_samples in (10, 20):
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            type_of_target(np.arange(n_samples))


def test_unique_labels_non_specific():
# Test unique_labels with a variety of collected examples

Expand Down