FIX discrete Naive Bayes model fitting for degenerate single-class case #18925
@@ -552,7 +552,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
        if _check_partial_fit_first_call(self, classes):
            # This is the first call to partial_fit:
            # initialize various cumulative counters
            n_effective_classes = len(classes) if len(classes) > 1 else 2
There is a small behavior change here when `len(classes) == 1`.
However, the docstring implies that `len(classes)` should be 2 for binary problems:
classes : array-like of shape (n_classes,), default=None
List of all the classes that can possibly appear in the y vector.
Assuming this is true, then I believe the only behavior change is for the degenerate single-class case, as intended.
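To make the point about `len(classes)` concrete, here is a minimal sketch that evaluates the expression from the commented line for a few class lists. The helper name `effective_class_count` is hypothetical and exists only for illustration; it is not part of scikit-learn.

```python
def effective_class_count(classes):
    """Evaluate the expression from the diff line under review.

    Hypothetical helper for illustration only.
    """
    return len(classes) if len(classes) > 1 else 2


# A single class is counted as two, which is the degenerate case at issue.
print(effective_class_count(["spam"]))            # 2
# Binary and multiclass inputs are counted as-is.
print(effective_class_count(["ham", "spam"]))     # 2
print(effective_class_count(["a", "b", "c"]))     # 3
```

Only the single-element input is affected by the `else 2` branch, which is consistent with the claim that the behavior change is confined to the single-class case.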
A first pass. We will need to add an entry in the what's new (`v1.0.rst`) as a bug fix.
Cool, I've drafted this in 6e51074.
Apart from some style changes, LGTM.
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
@thomasjpfan @lorentzenchr Would you like to have a second look?
…list-classes Centralize lists of naive Bayes classes to streamline test parameterization
@dpoznik Thanks for fixing this. You need to resolve merge conflicts. Beware that we've changed the default branch from master to main. Let me/us know if you need help.
As per @lorentzenchr in code review: "Pickability is tested in `check_estimators_pickle` (see `utils/estimator_checks.py`) in `tests/test_common.py`. I verified it by running `pytest sklearn/tests/test_common.py -v -k BernoulliNB` and so on."
No problem; it was fun :)
This should be all set with 9bbfb22.
Cool; thanks for the heads-up!
LGTM
@glemaitre Before merging, could you briefly skim over the last changes, just to be sure? :smirk:
Thanks @dpoznik. Everything seems good. Merging.
closes #18974
When training multinomial naive Bayes models as part of a larger pipeline, a degenerate case can arise in which there is just one class label. Prior to this change, it was possible to fit a single-class `MultinomialNB` model. One would expect that using such a model for prediction would deterministically yield the one class label. However, sometimes an `IndexError` would arise.

Ultimately, this traced to the fact that `LabelBinarizer.transform` returns an array of shape `(n_samples, 1)` when there are one or two classes, but `naive_bayes._BaseDiscreteNB.fit` and `naive_bayes._BaseDiscreteNB.partial_fit` (reasonably) assumed that if the return value had one column, that meant there were exactly two classes.

With this fix, the single-class case is handled differently from the two-class case, thereby ensuring the expected behavior and eliminating the source of the `IndexError`.

On this branch, all tests run via these three commands pass:

The last one includes a new regression test that fails on `master`.
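The one-column ambiguity described above can be sketched in plain Python. This is a simplified stand-in, not the scikit-learn implementation: `one_column_binarize` and `expand_columns` are hypothetical names, and the handling in `expand_columns` is only one plausible way to tell the single-class case apart from the two-class case.

```python
def one_column_binarize(y, classes):
    """Hypothetical stand-in for the shape behavior described in this PR:
    one or two classes both produce a single indicator column."""
    classes = sorted(classes)
    if len(classes) == 1:
        return [[0] for _ in y]
    if len(classes) == 2:
        return [[int(label == classes[1])] for label in y]
    return [[int(label == c) for c in classes] for label in y]


def expand_columns(Y, n_classes):
    """Turn a one-column indicator into one column per class.

    The column count alone cannot distinguish one class from two,
    so the known number of classes decides (an assumed approach).
    """
    if len(Y[0]) != 1:
        return Y
    if n_classes == 1:
        # Degenerate case: every sample belongs to the only class.
        return [[1] for _ in Y]
    # Two classes: prepend the complementary column.
    return [[1 - row[0], row[0]] for row in Y]


single = one_column_binarize(["a", "a"], ["a"])
binary = one_column_binarize(["a", "b", "a"], ["a", "b"])
print(len(single[0]), len(binary[0]))    # both have one column
print(expand_columns(single, 1))         # [[1], [1]]
print(expand_columns(binary, 2))         # [[1, 0], [0, 1], [1, 0]]
```

Treating the single-column output as two classes regardless of `n_classes` would give the single-class model a second, spurious class column, which matches the kind of mismatch the description attributes the `IndexError` to.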