diff --git a/doc/whats_new/upcoming_changes/sklearn.utils/26335.enhancement.rst b/doc/whats_new/upcoming_changes/sklearn.utils/26335.enhancement.rst new file mode 100644 index 0000000000000..e5bf047cd5db9 --- /dev/null +++ b/doc/whats_new/upcoming_changes/sklearn.utils/26335.enhancement.rst @@ -0,0 +1,4 @@ +- |Enhancement| :func:`utils.multiclass.type_of_target` raises a warning when the number + of unique classes is greater than 50% of the number of samples. This warning is raised + only if `y` has more than 20 samples. + By :user:`Rahil Parikh `. diff --git a/sklearn/utils/multiclass.py b/sklearn/utils/multiclass.py index 15d1428ce2ad7..3a81e2b9eb6fe 100644 --- a/sklearn/utils/multiclass.py +++ b/sklearn/utils/multiclass.py @@ -413,7 +413,16 @@ def _raise_or_return(): # Check multiclass if issparse(first_row_or_val): first_row_or_val = first_row_or_val.data - if cached_unique(y).shape[0] > 2 or (y.ndim == 2 and len(first_row_or_val) > 1): + classes = cached_unique(y) + if y.shape[0] > 20 and classes.shape[0] > round(0.5 * y.shape[0]): + # Only raise the warning when we have at least 20 samples. + warnings.warn( + "The number of unique classes is greater than 50% of the number " + "of samples.", + UserWarning, + stacklevel=2, + ) + if classes.shape[0] > 2 or (y.ndim == 2 and len(first_row_or_val) > 1): # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]] return "multiclass" + suffix else: diff --git a/sklearn/utils/tests/test_multiclass.py b/sklearn/utils/tests/test_multiclass.py index b400d675e5687..433e8118923fb 100644 --- a/sklearn/utils/tests/test_multiclass.py +++ b/sklearn/utils/tests/test_multiclass.py @@ -1,3 +1,4 @@ +import warnings from itertools import product import numpy as np @@ -294,6 +295,25 @@ def test_unique_labels(): assert_array_equal(unique_labels(np.ones((4, 5)), np.ones((5, 5))), np.arange(5)) +def test_type_of_target_too_many_unique_classes(): + """Check that we raise a warning when the number of unique classes is greater than + 50% of the number of samples. + + We need to check that we don't raise if we have less than 20 samples. + """ + + y = np.arange(25) + msg = r"The number of unique classes is greater than 50% of the number of samples." + with pytest.warns(UserWarning, match=msg): + type_of_target(y) + + # less than 20 samples, no warning should be raised + y = np.arange(10) + with warnings.catch_warnings(): + warnings.simplefilter("error") + type_of_target(y) + + def test_unique_labels_non_specific(): # Test unique_labels with a variety of collected examples