ENH check_classification_targets raises a warning when unique classes > 50% of n_samples (#26335)
Conversation
I am -1 on making this change in `check_classification_targets`. If we want a quick fix, I would include the warning in `type_of_target` instead:
```diff
diff --git a/sklearn/utils/multiclass.py b/sklearn/utils/multiclass.py
index 24e528d10a..35e2752212 100644
--- a/sklearn/utils/multiclass.py
+++ b/sklearn/utils/multiclass.py
@@ -387,7 +387,13 @@ def type_of_target(y, input_name=""):
     # Check multiclass
     first_row = y[0] if not issparse(y) else y.getrow(0).data
-    if xp.unique_values(y).shape[0] > 2 or (y.ndim == 2 and len(first_row) > 1):
+    classes = xp.unique_values(y)
+    if classes.shape[0] > round(0.5 * y.shape[0]):
+        warnings.warn(
+            r"The number of unique classes is greater than 50% of the samples."
+        )
+
+    if classes.shape[0] > 2 or (y.ndim == 2 and len(first_row) > 1):
         # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
         return "multiclass" + suffix
     else:
```
This way the classes are reused for the check and do not need to be recomputed.
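To make the behavior of the proposed check concrete, here is a minimal standalone sketch using plain NumPy (`warn_many_classes` is a hypothetical helper, not scikit-learn API; the 50% threshold and the warning message are taken from the diff above):

```python
import warnings

import numpy as np


def warn_many_classes(y):
    """Hypothetical helper mirroring the check proposed for type_of_target."""
    classes = np.unique(y)  # computed once; could be reused for the multiclass check
    if classes.shape[0] > round(0.5 * y.shape[0]):
        warnings.warn(
            "The number of unique classes is greater than 50% of the samples."
        )
    return classes


# 10 samples, 10 unique "classes": the warning fires
y = np.arange(10)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_many_classes(y)
assert len(caught) == 1
```

With a plausible classification target such as `[0, 1, 0, 1, 0, 1]`, only 2 of 6 values are unique, so no warning is emitted.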
sklearn/utils/multiclass.py (outdated)

```
@@ -221,6 +224,13 @@ def check_classification_targets(y):
                "regression target with continuous values."
            )

    if n_samples is not None:
        if len(np.unique(y)) > round(0.5 * n_samples):
```
This runs counter to the second point in #16399:

> check_classification_targets, which calls type_of_target, triggers a call to _assert_all_finite and np.unique(y), which are redundant with checks done elsewhere

In general, another `np.unique(y)` call adds another O(n log n) operation (`np.unique` sorts the data).
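To illustrate the cost concern, a rough timing sketch showing that a redundant `np.unique` call roughly doubles the work (absolute numbers will vary by machine; the array size and repeat count here are arbitrary):

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
y = rng.integers(0, 100, size=500_000)

# np.unique sorts its input, so each extra call is another O(n log n) pass.
one_call = timeit.timeit(lambda: np.unique(y), number=10)
two_calls = timeit.timeit(lambda: (np.unique(y), np.unique(y)), number=10)
print(f"one call: {one_call:.3f}s, two calls: {two_calls:.3f}s")
```

This is why the diff above reuses the already-computed `classes` rather than calling `np.unique` a second time.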
Hey, sorry for the late reply. I made the changes.
Apologies once again for the late reply; my college exams are going on and I might not get much time this month.
Yes, sure. I had tried the following code earlier:
along with this as the test:
I was taking a look at the documentation to see if there was another way to raise warnings, and I came across `warnings.formatwarning`, which allowed me to execute all the tests without any errors. I do realize the current implementation is a bit of a hacky workaround.
I think this is because
Force-pushed from 14aa5a8 to 3988f80
Hey @betatim, I have removed
Update: I have removed the merge conflicts and modified the tests as per your suggestions. I have used a try/except block to check for any
I fixed the conflict and modified the logic at the same time. Now we raise the warning only if there are more than 20 samples in `y`. For instance, this avoids raising the warning for toy datasets where we test with only a couple of samples. It also avoids spurious warnings in the documentation.
@betatim do you want to have an additional look at this PR?
Btw, this PR would be great since it addresses the issue of someone using a classifier for a regression task.
```python
# less than 20 samples, no warning should be raised
y = np.arange(10)
with warnings.catch_warnings():
    warnings.simplefilter("error", UserWarning)
```
Does anyone have wisdom on whether we should use `warnings.simplefilter("error", UserWarning)` or `warnings.simplefilter("error")`? As far as I can tell the test passes with both options, but I don't know which is the "better" way of testing that no warnings were raised. Maybe it doesn't matter.
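For context, the two spellings differ in scope: passing a category escalates only that category (and its subclasses) to errors, while the bare form escalates every warning. A quick sketch:

```python
import warnings

# With a category, only UserWarning (and subclasses) become errors;
# a DeprecationWarning slips through and is merely recorded.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("error", UserWarning)
    warnings.warn("slips through", DeprecationWarning)
assert len(caught) == 1

# The bare form turns every warning category into an exception.
raised = False
try:
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        warnings.warn("escalated", DeprecationWarning)
except DeprecationWarning:
    raised = True
assert raised
```

So for "no warnings at all should be raised", the bare `simplefilter("error")` is the stricter check.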
Removing `UserWarning` to make sure nothing's raised.
LGTM. I think @thomasjpfan's concerns were addressed
Thanks everyone for updating and merging the PR!
Reference Issues/PRs
Towards #16399
What does this implement/fix? Explain your changes.
Addresses the first point of #16399 (comment)
Any other comments?