Fix for OvR partial_fit in various edge cases #6239
The first issue is because of the If you want to make sure in the |
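hmm.. Okay. Looks like the issue was with
import numpy as np
from sklearn.multiclass import OneVsRestClassifier
from sklearn.naive_bayes import MultinomialNB
X = np.random.rand(14, 2)
y = [0, 0, 0, 1, 2, 2, 2, 1, 1, 2, 1, 3, 3, 3]
# Two mini-batches, each missing one class; all classes are passed on the first call
ovr = OneVsRestClassifier(MultinomialNB())
ovr.partial_fit(X[:7], y[:7], np.unique(y))
ovr.partial_fit(X[7:], y[7:])
pred = ovr.predict(X)
# Reference: a single fit on the full data
ovr1 = OneVsRestClassifier(MultinomialNB())
pred1 = ovr1.fit(X, y).predict(X)
In [96]: np.mean(pred1 == y)
Out[96]: 0.35714285714285715
In [97]: np.mean(pred == y)
Out[97]: 0.35714285714285715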
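To make the "missing classes" situation concrete, here is a small sketch of how each mini-batch from the snippet above looks once it is binarized against the full class list. It uses LabelBinarizer directly as an illustration; it is not taken from the PR's code, and the variable names are only for this example.

```python
import numpy as np
from sklearn.preprocessing import LabelBinarizer

y = np.array([0, 0, 0, 1, 2, 2, 2, 1, 1, 2, 1, 3, 3, 3])
all_classes = np.unique(y)

# Fit the binarizer on the full class list, not on a single mini-batch.
lb = LabelBinarizer().fit(all_classes)

# One indicator column per class; each column is the binary target
# that one one-vs-rest sub-estimator would be trained on.
Y_first = lb.transform(y[:7])   # batch without class 3
Y_second = lb.transform(y[7:])  # batch without class 0

# Columns for classes absent from a batch contain only zeros, so the
# corresponding sub-estimator sees only negative samples in that batch.
missing_in_first = all_classes[Y_first.sum(axis=0) == 0]
missing_in_second = all_classes[Y_second.sum(axis=0) == 0]
print(missing_in_first, missing_in_second)
```

This is why the sub-estimators need to be told about both binary labels up front: in a given batch, some of them never see a positive example.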
Force-pushed from db62a75 to 59c873a.
Added tests; an error is now raised when a mini-batch contains classes that are not in all_classes.
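A minimal sketch of what such a check could look like, assuming the intent is that every label in a mini-batch must belong to the classes passed on the first partial_fit call. The helper name check_batch_classes and the error message are hypothetical, not the PR's actual code.

```python
import numpy as np


def check_batch_classes(y_batch, all_classes):
    """Hypothetical helper: reject labels that are not in all_classes."""
    # Labels present in the batch but absent from the declared class list.
    unexpected = np.setdiff1d(y_batch, all_classes)
    if unexpected.size > 0:
        raise ValueError(
            "Mini-batch contains classes {0} that are not in all_classes {1}"
            .format(unexpected.tolist(), np.asarray(all_classes).tolist())
        )


all_classes = np.array([0, 1, 2, 3])

# A batch that only uses a subset of the declared classes is accepted.
check_batch_classes([0, 0, 1, 2], all_classes)

# A batch with an undeclared label is rejected.
try:
    check_batch_classes([0, 4], all_classes)
except ValueError as exc:
    error_message = str(exc)
    print(error_message)
```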
Force-pushed from 59c873a to cdd8fa5.
What's the status on this?
@amueller I don't remember, give me 2-3 days to get back on this.
I have been busy. If anyone wants to give it a try, please do!
I would like to do this work.
Why
@srivatsan-ramesh I don't think we can change the API of a class so easily. The first issue wasn't really an issue, the second one is solved in this PR, and the third one will work fine now that #6202 has been fixed. This just needs rebasing on master. If you want, you can pull my branch, rebase it on master, and push it to my repo.
@kaichogami I don't think this fix is correct, or maybe I don't understand it?
@srivatsan-ramesh #7786 looks clean and simple. Let's go ahead with that!
There are many corner cases, where not all classes are present in a partial_fit call, that need to be addressed.
partial_fit works perfectly fine when all target classes are present in each mini-batch iteration. However, when some classes are missing from some mini-batch, ovr.predict(X) gives unpredictable results. I think they should give the same result, like they do when all target classes are present in each iteration of mini-batch.
ovr.predict(X) will give an error as self.label_binarizer_.y_type_ is not multiclass. This will be resolved once #6202 is solved.
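As an illustration of how y_type_ can end up not being multiclass, the sketch below shows that LabelBinarizer sets this attribute from the labels it is fitted on. The exact inputs that triggered the error in this PR are not shown in the thread, so the label lists here are assumptions chosen only to demonstrate the attribute.

```python
from sklearn.preprocessing import LabelBinarizer

# Fitted on four distinct labels: the target type is multiclass.
lb_full = LabelBinarizer().fit([0, 1, 2, 3])

# Fitted on only two distinct labels: the target type is binary,
# and transform returns a single column instead of one per class.
lb_two = LabelBinarizer().fit([0, 1])

print(lb_full.y_type_, lb_full.transform([0, 3]).shape)
print(lb_two.y_type_, lb_two.transform([0, 1]).shape)
```

Code that assumes one indicator column per class will not match the single-column binary output.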
Solving the first issue here takes priority, and I am unable to figure out where it goes wrong.
@MechCoder