ENH check_classification_targets raises a warning when unique classes > 50% of n_samples #26335


Merged
merged 25 commits into scikit-learn:main on Apr 30, 2025

Conversation

@rprkh (Contributor) commented May 5, 2023

Reference Issues/PRs

Towards #16399

What does this implement/fix? Explain your changes.

Addresses the first point of #16399 (comment)

Any other comments?

@thomasjpfan (Member) left a comment

I am -1 on this change in check_classification_targets. If we want a quick fix, I'd include the warning in type_of_target instead:

diff --git a/sklearn/utils/multiclass.py b/sklearn/utils/multiclass.py
index 24e528d10a..35e2752212 100644
--- a/sklearn/utils/multiclass.py
+++ b/sklearn/utils/multiclass.py
@@ -387,7 +387,13 @@ def type_of_target(y, input_name=""):
 
     # Check multiclass
     first_row = y[0] if not issparse(y) else y.getrow(0).data
-    if xp.unique_values(y).shape[0] > 2 or (y.ndim == 2 and len(first_row) > 1):
+    classes = xp.unique_values(y)
+    if classes.shape[0] > round(0.5 * y.shape[0]):
+        warnings.warn(
+            r"The number of unique classes is greater than 50% of the samples."
+        )
+
+    if classes.shape[0] > 2 or (y.ndim == 2 and len(first_row) > 1):
         # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
         return "multiclass" + suffix
     else:

This way the classes are reused for the check and do not need to be recomputed.
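
For illustration only, here is a minimal sketch (not part of the PR) of how the warning proposed in the diff would surface to a caller, assuming the patch above is applied; on an unpatched scikit-learn the call simply returns the target type and records nothing:

import warnings

import numpy as np
from sklearn.utils.multiclass import type_of_target

# 6 unique classes out of 7 samples exceeds the 50% threshold in the
# proposed patch, so type_of_target would emit a UserWarning.
y = np.array([0, 1, 1, 2, 3, 4, 5])

with warnings.catch_warnings(record=True) as record:
    warnings.simplefilter("always")
    target_type = type_of_target(y)

print(target_type)                       # "multiclass"
print([str(w.message) for w in record])  # warning text, only with the patch applied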

@@ -221,6 +224,13 @@ def check_classification_targets(y):
            "regression target with continuous values."
        )

    if n_samples is not None:
        if len(np.unique(y)) > round(0.5 * n_samples):
Member

This runs counter to the second point in #16399:

check_classification_targets which calls type_of_target triggers a call to _assert_all_finite and np.unique(y) which are redundant with checks done elsewhere

In general, including another np.unique(y) call adds another O(n log n) operation (np.unique sorts the data).
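
As a rough, hypothetical illustration of the cost being discussed (not code from the PR): np.unique sorts its input, so calling it a second time on the same y repeats an O(n log n) step, whereas computing the classes once and reusing them avoids the extra pass:

import numpy as np

rng = np.random.default_rng(0)
y = rng.integers(0, 1_000, size=1_000_000)

# Redundant: two separate calls mean the same data is sorted twice.
n_classes = np.unique(y).shape[0]
is_high_cardinality = np.unique(y).shape[0] > round(0.5 * y.shape[0])

# Reused: compute the classes once and feed them to every later check.
classes = np.unique(y)
n_classes = classes.shape[0]
is_high_cardinality = n_classes > round(0.5 * y.shape[0])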

@rprkh (Contributor, Author) commented May 8, 2023

Hey, sorry for the late reply. I made the changes. Using warnings.formatwarning was the only way I could find to prevent test failures in sklearn/metrics/tests/test_classification.py, while at the same time successfully executing the tests in sklearn/utils/tests/test_multiclass.py.

@rprkh (Contributor, Author) commented May 15, 2023

Apologies once again for the late reply; my college exams are going on and I might not get a lot of time this month.

Could you explain a bit why warnings.formatwarning is needed instead of warnings.warn?

Yes, sure. I had tried the following code earlier:

classes = xp.unique_values(y)
if classes.shape[0] > round(0.5 * y.shape[0]) and classes.shape[0] > 1:
    warnings.warn(
        r"The number of unique classes is greater than 50% of the samples.",
        UserWarning,
        stacklevel=2,
    )

along with this as the test:

def test_unique_labels_and_samples():
    """Raises a warning when the number of unique classes is
    larger than 50% of `n_samples`"""

    y = np.array([0, 1, 1, 2, 3, 4, 5])

    warning_message = r"The number of unique classes is greater than 50% of the samples."

    with pytest.warns(UserWarning, match=warning_message):
        type_of_target(y)

Running pytest sklearn/utils/tests/test_multiclass.py passes all the checks; however, running pytest sklearn/metrics/tests/test_classification.py fails some tests because the warning ("The number of unique classes is greater than 50% of the samples.") is raised while checking for zero division:

Details
================================================= test session starts =================================================
platform win32 -- Python 3.9.13, pytest-7.1.2, pluggy-1.0.0
rootdir: C:\Users\rahil\Rahil\KJ Somaiya\scikit-learn, configfile: setup.cfg
plugins: cov-3.0.0
collected 199 items

sklearn\metrics\tests\test_classification.py ..FFFF............................................................. [ 33%]
.....................................F...............................FFF.......................s................ [ 89%]
....................                                                                                             [100%]

====================================================== FAILURES =======================================================
_______________________________ test_classification_report_zero_division_warning[warn] ________________________________

zero_division = 'warn'

    @pytest.mark.parametrize("zero_division", ["warn", 0, 1, np.nan])
    def test_classification_report_zero_division_warning(zero_division):
        y_true, y_pred = ["a", "b", "c"], ["a", "b", "d"]
        with warnings.catch_warnings(record=True) as record:
            classification_report(
                y_true, y_pred, zero_division=zero_division, output_dict=True
            )
            if zero_division == "warn":
                assert len(record) > 1
                for item in record:
                    msg = "Use `zero_division` parameter to control this behavior."
>                   assert msg in str(item.message)
E                   AssertionError: assert 'Use `zero_division` parameter to control this behavior.' in 'The number of unique classes is greater than 50% of the samples.'
E                    +  where 'The number of unique classes is greater than 50% of the samples.' = str(UserWarning('The number of unique classes is greater than 50% of the samples.'))
E                    +    where UserWarning('The number of unique classes is greater than 50% of the samples.') = <warnings.WarningMessage object at 0x0000012CE4709910>.message

sklearn\metrics\tests\test_classification.py:212: AssertionError
_________________________________ test_classification_report_zero_division_warning[0] _________________________________

zero_division = 0

    @pytest.mark.parametrize("zero_division", ["warn", 0, 1, np.nan])
    def test_classification_report_zero_division_warning(zero_division):
        y_true, y_pred = ["a", "b", "c"], ["a", "b", "d"]
        with warnings.catch_warnings(record=True) as record:
            classification_report(
                y_true, y_pred, zero_division=zero_division, output_dict=True
            )
            if zero_division == "warn":
                assert len(record) > 1
                for item in record:
                    msg = "Use `zero_division` parameter to control this behavior."
                    assert msg in str(item.message)
            else:
>               assert not record
E               assert not [<warnings.WarningMessage object at 0x0000012CE4855A30>, <warnings.WarningMessage object at 0x0000012CE4855C10>, <warn...>, <warnings.WarningMessage object at 0x0000012CE4855550>, <warnings.WarningMessage object at 0x0000012CE48555E0>, ...]

sklearn\metrics\tests\test_classification.py:214: AssertionError
_________________________________ test_classification_report_zero_division_warning[1] _________________________________

zero_division = 1

    @pytest.mark.parametrize("zero_division", ["warn", 0, 1, np.nan])
    def test_classification_report_zero_division_warning(zero_division):
        y_true, y_pred = ["a", "b", "c"], ["a", "b", "d"]
        with warnings.catch_warnings(record=True) as record:
            classification_report(
                y_true, y_pred, zero_division=zero_division, output_dict=True
            )
            if zero_division == "warn":
                assert len(record) > 1
                for item in record:
                    msg = "Use `zero_division` parameter to control this behavior."
                    assert msg in str(item.message)
            else:
>               assert not record
E               assert not [<warnings.WarningMessage object at 0x0000012CE49438B0>, <warnings.WarningMessage object at 0x0000012CE4943760>, <warn...>, <warnings.WarningMessage object at 0x0000012CE4943520>, <warnings.WarningMessage object at 0x0000012CE49437C0>, ...]

sklearn\metrics\tests\test_classification.py:214: AssertionError
________________________________ test_classification_report_zero_division_warning[nan] ________________________________

zero_division = nan

    @pytest.mark.parametrize("zero_division", ["warn", 0, 1, np.nan])
    def test_classification_report_zero_division_warning(zero_division):
        y_true, y_pred = ["a", "b", "c"], ["a", "b", "d"]
        with warnings.catch_warnings(record=True) as record:
            classification_report(
                y_true, y_pred, zero_division=zero_division, output_dict=True
            )
            if zero_division == "warn":
                assert len(record) > 1
                for item in record:
                    msg = "Use `zero_division` parameter to control this behavior."
                    assert msg in str(item.message)
            else:
>               assert not record
E               assert not [<warnings.WarningMessage object at 0x0000012CE4856490>, <warnings.WarningMessage object at 0x0000012CE4856130>, <warn...>, <warnings.WarningMessage object at 0x0000012CE4856C70>, <warnings.WarningMessage object at 0x0000012CE4856220>, ...]

sklearn\metrics\tests\test_classification.py:214: AssertionError
____________________________________________ test_multiclass_jaccard_score ____________________________________________

recwarn = WarningsRecorder(record=True)

    def test_multiclass_jaccard_score(recwarn):
        y_true = ["ant", "ant", "cat", "cat", "ant", "cat", "bird", "bird"]
        y_pred = ["cat", "ant", "cat", "cat", "ant", "bird", "bird", "cat"]
        labels = ["ant", "bird", "cat"]
        lb = LabelBinarizer()
        lb.fit(labels)
        y_true_bin = lb.transform(y_true)
        y_pred_bin = lb.transform(y_pred)
        multi_jaccard_score = partial(jaccard_score, y_true, y_pred)
        bin_jaccard_score = partial(jaccard_score, y_true_bin, y_pred_bin)
        multi_labels_list = [
            ["ant", "bird"],
            ["ant", "cat"],
            ["cat", "bird"],
            ["ant"],
            ["bird"],
            ["cat"],
            None,
        ]
        bin_labels_list = [[0, 1], [0, 2], [2, 1], [0], [1], [2], None]

        # other than average='samples'/'none-samples', test everything else here
        for average in ("macro", "weighted", "micro", None):
            for m_label, b_label in zip(multi_labels_list, bin_labels_list):
                assert_almost_equal(
                    multi_jaccard_score(average=average, labels=m_label),
                    bin_jaccard_score(average=average, labels=b_label),
                )

        y_true = np.array([[0, 0], [0, 0], [0, 0]])
        y_pred = np.array([[0, 0], [0, 0], [0, 0]])
        with ignore_warnings():
            assert jaccard_score(y_true, y_pred, average="weighted") == 0

>       assert not list(recwarn)
E       assert not [<warnings.WarningMessage object at 0x0000012CE6ADA760>, <warnings.WarningMessage object at 0x0000012CE6ADAD30>]
E        +  where [<warnings.WarningMessage object at 0x0000012CE6ADA760>, <warnings.WarningMessage object at 0x0000012CE6ADAD30>] = list(WarningsRecorder(record=True))

sklearn\metrics\tests\test_classification.py:1541: AssertionError
____________________________________ test_prf_no_warnings_if_zero_division_set[0] _____________________________________

zero_division = 0

    @pytest.mark.parametrize("zero_division", [0, 1, np.nan])
    def test_prf_no_warnings_if_zero_division_set(zero_division):
        # average of per-label scores
        f = precision_recall_fscore_support
        for average in [None, "weighted", "macro"]:
>           assert_no_warnings(
                f, [0, 1, 2], [1, 1, 2], average=average, zero_division=zero_division
            )

sklearn\metrics\tests\test_classification.py:2049:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

func = <function precision_recall_fscore_support at 0x0000012CE428C4C0>, args = ([0, 1, 2], [1, 1, 2])
kw = {'average': None, 'zero_division': 0}
w = [<warnings.WarningMessage object at 0x0000012CE6B201C0>, <warnings.WarningMessage object at 0x0000012CE6B20040>, <warnings.WarningMessage object at 0x0000012CE6B202B0>, <warnings.WarningMessage object at 0x0000012CE6B20130>]
result = (array([0. , 0.5, 1. ]), array([0., 1., 1.]), array([0.        , 0.66666667, 1.        ]), array([1, 1, 1], dtype=int64))

    def assert_no_warnings(func, *args, **kw):
        """
        Parameters
        ----------
        func
        *args
        **kw
        """
        # very important to avoid uncontrolled state propagation
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            result = func(*args, **kw)
            if hasattr(np, "FutureWarning"):
                # Filter out numpy-specific warnings in numpy >= 1.9
                w = [e for e in w if e.category is not np.VisibleDeprecationWarning]

            if len(w) > 0:
>               raise AssertionError(
                    "Got warnings when calling %s: [%s]"
                    % (func.__name__, ", ".join(str(warning) for warning in w))
                )
E               AssertionError: Got warnings when calling precision_recall_fscore_support: [{message : UserWarning('The number of unique classes is greater than 50% of the samples.'), category : 'UserWarning', filename : 'c:\\users\\rahil\\rahil\\kj somaiya\\scikit-learn\\sklearn\\metrics\\_classification.py', lineno : 86, line : None}, {message : UserWarning('The number of unique classes is greater than 50% of the samples.'), category : 'UserWarning', filename : 'c:\\users\\rahil\\rahil\\kj somaiya\\scikit-learn\\sklearn\\utils\\multiclass.py', lineno : 81, line : None}, {message : UserWarning('The number of unique classes is greater than 50% of the samples.'), category : 'UserWarning', filename : 'c:\\users\\rahil\\rahil\\kj somaiya\\scikit-learn\\sklearn\\metrics\\_classification.py', lineno : 86, line : None}, {message : UserWarning('The number of unique classes is greater than 50% of the samples.'), category : 'UserWarning', filename : 'c:\\users\\rahil\\rahil\\kj somaiya\\scikit-learn\\sklearn\\utils\\multiclass.py', lineno : 81, line : None}]

sklearn\utils\_testing.py:111: AssertionError
____________________________________ test_prf_no_warnings_if_zero_division_set[1] _____________________________________

zero_division = 1

    @pytest.mark.parametrize("zero_division", [0, 1, np.nan])
    def test_prf_no_warnings_if_zero_division_set(zero_division):
        # average of per-label scores
        f = precision_recall_fscore_support
        for average in [None, "weighted", "macro"]:
>           assert_no_warnings(
                f, [0, 1, 2], [1, 1, 2], average=average, zero_division=zero_division
            )

sklearn\metrics\tests\test_classification.py:2049:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

func = <function precision_recall_fscore_support at 0x0000012CE428C4C0>, args = ([0, 1, 2], [1, 1, 2])
kw = {'average': None, 'zero_division': 1}
w = [<warnings.WarningMessage object at 0x0000012CE6BFD580>, <warnings.WarningMessage object at 0x0000012CE6BFD700>, <warnings.WarningMessage object at 0x0000012CE6BFD610>, <warnings.WarningMessage object at 0x0000012CE6BFD760>]
result = (array([1. , 0.5, 1. ]), array([0., 1., 1.]), array([0.        , 0.66666667, 1.        ]), array([1, 1, 1], dtype=int64))

    def assert_no_warnings(func, *args, **kw):
        """
        Parameters
        ----------
        func
        *args
        **kw
        """
        # very important to avoid uncontrolled state propagation
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            result = func(*args, **kw)
            if hasattr(np, "FutureWarning"):
                # Filter out numpy-specific warnings in numpy >= 1.9
                w = [e for e in w if e.category is not np.VisibleDeprecationWarning]

            if len(w) > 0:
>               raise AssertionError(
                    "Got warnings when calling %s: [%s]"
                    % (func.__name__, ", ".join(str(warning) for warning in w))
                )
E               AssertionError: Got warnings when calling precision_recall_fscore_support: [{message : UserWarning('The number of unique classes is greater than 50% of the samples.'), category : 'UserWarning', filename : 'c:\\users\\rahil\\rahil\\kj somaiya\\scikit-learn\\sklearn\\metrics\\_classification.py', lineno : 86, line : None}, {message : UserWarning('The number of unique classes is greater than 50% of the samples.'), category : 'UserWarning', filename : 'c:\\users\\rahil\\rahil\\kj somaiya\\scikit-learn\\sklearn\\utils\\multiclass.py', lineno : 81, line : None}, {message : UserWarning('The number of unique classes is greater than 50% of the samples.'), category : 'UserWarning', filename : 'c:\\users\\rahil\\rahil\\kj somaiya\\scikit-learn\\sklearn\\metrics\\_classification.py', lineno : 86, line : None}, {message : UserWarning('The number of unique classes is greater than 50% of the samples.'), category : 'UserWarning', filename : 'c:\\users\\rahil\\rahil\\kj somaiya\\scikit-learn\\sklearn\\utils\\multiclass.py', lineno : 81, line : None}]

sklearn\utils\_testing.py:111: AssertionError
___________________________________ test_prf_no_warnings_if_zero_division_set[nan] ____________________________________

zero_division = nan

    @pytest.mark.parametrize("zero_division", [0, 1, np.nan])
    def test_prf_no_warnings_if_zero_division_set(zero_division):
        # average of per-label scores
        f = precision_recall_fscore_support
        for average in [None, "weighted", "macro"]:
>           assert_no_warnings(
                f, [0, 1, 2], [1, 1, 2], average=average, zero_division=zero_division
            )

sklearn\metrics\tests\test_classification.py:2049:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

func = <function precision_recall_fscore_support at 0x0000012CE428C4C0>, args = ([0, 1, 2], [1, 1, 2])
kw = {'average': None, 'zero_division': nan}
w = [<warnings.WarningMessage object at 0x0000012CE6BE6880>, <warnings.WarningMessage object at 0x0000012CE6BE6730>, <warnings.WarningMessage object at 0x0000012CE6BE6760>, <warnings.WarningMessage object at 0x0000012CE6BE6B50>]
result = (array([nan, 0.5, 1. ]), array([0., 1., 1.]), array([       nan, 0.66666667, 1.        ]), array([1, 1, 1], dtype=int64))

    def assert_no_warnings(func, *args, **kw):
        """
        Parameters
        ----------
        func
        *args
        **kw
        """
        # very important to avoid uncontrolled state propagation
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            result = func(*args, **kw)
            if hasattr(np, "FutureWarning"):
                # Filter out numpy-specific warnings in numpy >= 1.9
                w = [e for e in w if e.category is not np.VisibleDeprecationWarning]

            if len(w) > 0:
>               raise AssertionError(
                    "Got warnings when calling %s: [%s]"
                    % (func.__name__, ", ".join(str(warning) for warning in w))
                )
E               AssertionError: Got warnings when calling precision_recall_fscore_support: [{message : UserWarning('The number of unique classes is greater than 50% of the samples.'), category : 'UserWarning', filename : 'c:\\users\\rahil\\rahil\\kj somaiya\\scikit-learn\\sklearn\\metrics\\_classification.py', lineno : 86, line : None}, {message : UserWarning('The number of unique classes is greater than 50% of the samples.'), category : 'UserWarning', filename : 'c:\\users\\rahil\\rahil\\kj somaiya\\scikit-learn\\sklearn\\utils\\multiclass.py', lineno : 81, line : None}, {message : UserWarning('The number of unique classes is greater than 50% of the samples.'), category : 'UserWarning', filename : 'c:\\users\\rahil\\rahil\\kj somaiya\\scikit-learn\\sklearn\\metrics\\_classification.py', lineno : 86, line : None}, {message : UserWarning('The number of unique classes is greater than 50% of the samples.'), category : 'UserWarning', filename : 'c:\\users\\rahil\\rahil\\kj somaiya\\scikit-learn\\sklearn\\utils\\multiclass.py', lineno : 81, line : None}]

sklearn\utils\_testing.py:111: AssertionError
=============================== 8 failed, 190 passed, 1 skipped, 194 warnings in 3.54s ================================

I was taking a look at the documentation to see if there was another way to raise warnings and I came across warnings.formatwarning, which allowed me to execute all the tests without any errors.

I do realize the current implementation is a bit of a hacky workaround, with filename and line having to be explicitly passed to raise the warning. I would be open to suggestions to make the implementation cleaner and less hard-coded.

@betatim (Member) commented May 24, 2023

I was taking a look at the documentation to see if there was another way to raise warnings and I came across warnings.formatwarning, which allowed me to execute all the tests without any errors.

I think this is because formatwarning does not actually raise a warning; it only formats the text that would be shown on the console. The right way to deal with the tests that start failing when you introduce this warning is to see whether we can modify the data used in those tests so it does not trigger the warning, or to explicitly filter out this warning in the tests.
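
To make the distinction concrete, a small standard-library sketch (not PR code): warnings.formatwarning only builds the display string and emits nothing, while warnings.warn emits a real warning that filters, warnings.catch_warnings, and pytest.warns can observe:

import warnings

# formatwarning returns a formatted string; no warning is emitted, so
# filters and test helpers such as pytest.warns never see anything.
text = warnings.formatwarning(
    "The number of unique classes is greater than 50% of the samples.",
    UserWarning,
    filename="multiclass.py",  # illustrative values, not real source locations
    lineno=1,
)
print(text)

# warn emits an actual UserWarning that can be recorded, filtered, or matched.
with warnings.catch_warnings(record=True) as record:
    warnings.simplefilter("always")
    warnings.warn(
        "The number of unique classes is greater than 50% of the samples.",
        UserWarning,
    )
assert len(record) == 1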

@rprkh force-pushed the enh_check_classification_targets branch from 14aa5a8 to 3988f80 on June 16, 2023 04:00
@rprkh (Contributor, Author) commented Jun 16, 2023

I think this is because formatwarning does not actually raise a warning, it just prints out text to the console.

Hey @betatim, I have removed formatwarning and made use of warnings.warn as originally suggested.

@rprkh (Contributor, Author) commented Jun 18, 2023

The right way to deal with the fact that some of the tests start failing when you introduce this warning is to see if we can modify the data used in those tests to not raise a warning or to explicitly filter out this warning in the tests.

Update: I have resolved the merge conflicts and modified the tests as per your suggestions. I used a try/except block to catch any AssertionError in the tests; in the except block we check whether the message ("The number of unique classes is greater than 50% of the samples.") is present. I have also filtered the warning where possible, such as in the case of the recwarn list.
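
For reference, silencing only this new warning inside an existing metric test can look roughly like the following (a sketch; the actual test changes in this PR may differ):

import warnings

from sklearn.metrics import precision_recall_fscore_support

y_true, y_pred = [0, 1, 2], [1, 1, 2]

with warnings.catch_warnings():
    # Ignore only the high-cardinality warning; any other warning still
    # propagates and can be asserted on as before.
    warnings.filterwarnings(
        "ignore",
        message="The number of unique classes is greater than 50%",
        category=UserWarning,
    )
    precision_recall_fscore_support(
        y_true, y_pred, average=None, zero_division=0
    )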

@glemaitre self-requested a review on November 16, 2023 14:27
github-actions bot commented Nov 16, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 31415d7.

@glemaitre (Member) left a comment

I fixed the conflict and modified the logic at the same time. Now, we only raise the warning if there are more than 20 samples in y; this avoids raising it for toy datasets where we test with only a couple of samples.

It also avoids spurious warnings in the documentation.
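
A minimal sketch of the logic described above (the helper name and exact message are illustrative, not the merged code):

import warnings

import numpy as np


def _warn_on_high_class_cardinality(y):
    """Warn only when y has more than 20 samples and the number of
    unique classes exceeds 50% of the number of samples."""
    y = np.asarray(y)
    n_samples = y.shape[0]
    if n_samples <= 20:
        # Toy-sized inputs (e.g. doctest examples) stay silent.
        return
    if np.unique(y).shape[0] > round(0.5 * n_samples):
        warnings.warn(
            "The number of unique classes is greater than 50% of the "
            "number of samples.",
            UserWarning,
        )


_warn_on_high_class_cardinality(np.arange(25))  # warns: 25 classes, 25 samples
_warn_on_high_class_cardinality(np.arange(10))  # silent: only 10 samples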

@glemaitre added the Waiting for Second Reviewer (First reviewer is done, need a second one!) label on Nov 16, 2023
@glemaitre (Member)

@betatim do you want to have an additional look at this PR?

@amueller (Member)

Btw this PR would be great since it addresses the issue of someone using a classifier for a regression task.

# less than 20 samples, no warning should be raised
y = np.arange(10)
with warnings.catch_warnings():
    warnings.simplefilter("error", UserWarning)
Member

Does anyone have wisdom on whether we should use warnings.simplefilter("error", UserWarning) or warnings.simplefilter("error")?

As far as I can tell the test passes with both options, but I don't know which is the "better" way of testing that no warnings are raised. Maybe it doesn't matter.

Member

Removing UserWarning to make sure nothing at all is raised.
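
For context, a standard-library sketch of the difference (not PR code): simplefilter("error", UserWarning) only escalates UserWarning and its subclasses to exceptions, whereas a bare simplefilter("error") escalates every warning category, which is the stricter check that nothing at all is raised:

import warnings

# Only UserWarning is escalated; other categories fall through to the
# pre-existing filters and do not raise.
with warnings.catch_warnings():
    warnings.simplefilter("error", UserWarning)
    warnings.warn("something is deprecated", DeprecationWarning)  # no exception

# Every category is escalated, so any warning at all becomes an error.
with warnings.catch_warnings():
    warnings.simplefilter("error")
    try:
        warnings.warn("something is deprecated", DeprecationWarning)
    except DeprecationWarning:
        print("raised as an error")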

@betatim (Member) left a comment

LGTM. I think @thomasjpfan's concerns were addressed.

@adrinjalali enabled auto-merge (squash) on April 30, 2025 14:38
@adrinjalali merged commit 4985e69 into scikit-learn:main on Apr 30, 2025
36 checks passed
@rprkh (Contributor, Author) commented May 2, 2025

Thanks everyone for updating and merging the PR!

@rprkh deleted the enh_check_classification_targets branch on May 2, 2025 11:24
Labels: module:utils, Waiting for Second Reviewer (First reviewer is done, need a second one!)

6 participants