Fix for OvR partial_fit in various edge cases #6239

Closed
wants to merge 1 commit into from

Conversation

kaichogami
Contributor

There are several corner cases that need to be addressed when not all classes are present in a partial_fit call.

  • partial_fit works fine when every target class is present in each mini-batch. However, when some classes are missing from a mini-batch, ovr.predict(X) gives unpredictable results.
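import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.multiclass import OneVsRestClassifier

X = np.random.randn(14, 2)
y = [0, 0, 1, 1, 2, 3, 3, 0, 1, 2, 3, 0, 2, 3]
# Incremental training on two mini-batches
ovr = OneVsRestClassifier(SGDClassifier())
ovr.partial_fit(X[:7], y[:7], np.unique(y))
ovr.partial_fit(X[7:], y[7:])
pred = ovr.predict(X)
# Reference: a single fit on the full data
ovr1 = OneVsRestClassifier(SGDClassifier())
pred1 = ovr1.fit(X, y).predict(X)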

In [46]: np.mean(pred==y)
Out[46]: 0.35714285714285715

In [47]: np.mean(pred1==y)
Out[47]: 0.5

I think they should give the same result, as they do when all target classes are present in every mini-batch.

  • When a mini-batch contains only two classes but the full set of classes is not binary, ovr.predict(X) raises an error because self.label_binarizer_.y_type_ is not multiclass.
from sklearn.datasets import load_iris

iris = load_iris()
ovr = OneVsRestClassifier(SGDClassifier())
X, y = iris.data, iris.target
ovr.partial_fit(X[:60], y[:60], np.unique(y))
ovr.partial_fit(X[60:], y[60:])
In [54]: pred = ovr.predict(X)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-54-48d087b7f7f3> in <module>()
----> 1 pred = ovr.predict(X)

/home/kaichogami/codes/development_scikit-learn/scikit-learn/sklearn/multiclass.py in predict(self, X)
    307             indicator = sp.csc_matrix((data, indices, indptr),
    308                                       shape=(n_samples, len(self.estimators_)))
--> 309             return self.label_binarizer_.inverse_transform(indicator)
    310 
    311     def predict_proba(self, X):

/home/kaichogami/codes/development_scikit-learn/scikit-learn/sklearn/preprocessing/label.pyc in inverse_transform(self, Y, threshold)
    369         else:
    370             y_inv = _inverse_binarize_thresholding(Y, self.y_type_,
--> 371                                                    self.classes_, threshold)
    372 
    373         if self.sparse_input_:

/home/kaichogami/codes/development_scikit-learn/scikit-learn/sklearn/preprocessing/label.pyc in _inverse_binarize_thresholding(y, output_type, classes, threshold)
    582     if output_type == "binary" and y.ndim == 2 and y.shape[1] > 2:
    583         raise ValueError("output_type='binary', but y.shape = {0}".
--> 584                          format(y.shape))
    585 
    586     if output_type != "binary" and y.shape[1] != len(classes):

ValueError: output_type='binary', but y.shape = (150, 3)
  • When a mini-batch contains only one class.
ovr.partial_fit(X[:20], y[:20], np.unique(y))
/home/kaichogami/codes/development_scikit-learn/scikit-learn/sklearn/multiclass.py in partial_fit(self, X, y, classes)
    253         self.label_binarizer_ = LabelBinarizer(sparse_output=True)
    254         Y = self.label_binarizer_.fit_transform(y)
--> 255         Y = Y.tocsc()
    256         Y = Y.T.toarray()
    257         if len(self.label_binarizer_.classes_) == 2:

AttributeError: 'numpy.ndarray' object has no attribute 'toc

This will be resolved once #6202 is solved.
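
A minimal sketch of what seems to drive the second case, assuming (as the traceback for the third case suggests) that partial_fit refits a fresh LabelBinarizer on each mini-batch. The labels below are synthetic and only imitate the layout of iris.target (50 samples per class, sorted by class); the variable names are illustrative.

```python
import numpy as np
from sklearn.preprocessing import LabelBinarizer
from sklearn.utils.multiclass import type_of_target

# Labels laid out like iris.target: 50 samples per class, sorted by class.
y = np.repeat([0, 1, 2], 50)
batches = [y[:60], y[60:]]  # same split as in the report

# Target type of the full label vector.
full_type = type_of_target(y)

# Refit a binarizer on each batch and record what it infers.
batch_types = []
batch_columns = []
for batch in batches:
    lb = LabelBinarizer().fit(batch)
    batch_types.append(lb.y_type_)
    batch_columns.append(lb.transform(batch).shape[1])

print(full_type, batch_types, batch_columns)
```

Each batch holds only two of the three classes, so a binarizer fitted on a batch alone treats the target as binary and emits a single column, while three estimators were created from the full class list.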

Solving the first issue here takes priority, and I am unable to figure out where it goes wrong.
@MechCoder

@MechCoder
Member

The first issue is because of the SGDClassifier. Note that there are no missing classes in the second batch (at least in your example, np.unique(y[7:]) == np.unique(y))

If you want a single fit on SGDClassifier to be equivalent to calling partial_fit twice, you need to set n_iter=1, shuffle=False and random_state=0 in the SGDClassifier constructor.
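
A rough sketch of that comparison, under some assumptions: n_iter is the parameter name from the scikit-learn version discussed here, and the sketch uses max_iter=1 with tol=None instead, which assumes a later release where n_iter was replaced. The make_ovr helper is hypothetical, and the weights are compared with np.allclose rather than exact equality because floating-point rounding may differ between one pass and two.

```python
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.multiclass import OneVsRestClassifier

rng = np.random.RandomState(0)
X = rng.randn(14, 2)
y = np.array([0, 0, 1, 1, 2, 3, 3, 0, 1, 2, 3, 0, 2, 3])  # labels from the report


def make_ovr():
    # Hypothetical helper: one pass over the data, fixed sample order, fixed seed.
    # tol=None keeps the stopping criterion out of the picture.
    sgd = SGDClassifier(max_iter=1, tol=None, shuffle=False, random_state=0)
    return OneVsRestClassifier(sgd)


# Single fit on the full data.
batch_fit = make_ovr().fit(X, y)

# Two partial_fit calls; both mini-batches contain all four classes.
incremental = make_ovr()
incremental.partial_fit(X[:7], y[:7], classes=np.unique(y))
incremental.partial_fit(X[7:], y[7:])

# Compare the per-class linear models of the two approaches.
coefs_match = all(
    np.allclose(a.coef_, b.coef_) and np.allclose(a.intercept_, b.intercept_)
    for a, b in zip(batch_fit.estimators_, incremental.estimators_)
)
print(coefs_match)
```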

@kaichogami
Contributor Author

Hmm, okay. It looks like the issue was with SGDClassifier and my inadequate knowledge of its parameters. It works perfectly fine with MultinomialNB:

from sklearn.naive_bayes import MultinomialNB

X = np.random.rand(14, 2)
y = [0, 0, 0, 1, 2, 2, 2, 1, 1, 2, 1, 3, 3, 3]
ovr = OneVsRestClassifier(MultinomialNB())
ovr.partial_fit(X[:7], y[:7], np.unique(y))
ovr.partial_fit(X[7:], y[7:])
pred = ovr.predict(X)

ovr1 = OneVsRestClassifier(MultinomialNB())
pred1 = ovr1.fit(X, y).predict(X)

In [96]: np.mean(pred1 == y)
Out[96]: 0.35714285714285715

In [97]: np.mean(pred == y)
Out[97]: 0.35714285714285715

Added tests and raised an error when the mini-batch classes don't contain classes from all_classes.
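
For illustration only, a sketch of what such a check could look like, assuming the intent is that every label in a mini-batch must belong to all_classes. check_batch_classes is a hypothetical helper, not the code from this PR.

```python
import numpy as np


def check_batch_classes(y_batch, all_classes):
    """Raise ValueError if y_batch holds labels that are not in all_classes.

    Hypothetical helper for illustration; not the code from this PR.
    """
    unseen = np.setdiff1d(y_batch, all_classes)
    if unseen.size > 0:
        raise ValueError(
            "Mini-batch contains labels {0} that are not in all_classes {1}".format(
                unseen.tolist(), np.asarray(all_classes).tolist()
            )
        )


all_classes = np.array([0, 1, 2, 3])

# A batch that only uses declared classes passes silently.
check_batch_classes([0, 0, 1, 2], all_classes)

# A batch with an undeclared label (4) is rejected.
try:
    check_batch_classes([0, 4], all_classes)
    rejected = False
except ValueError:
    rejected = True
```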
@amueller
Member

amueller commented Oct 7, 2016

What's the status on this?

@kaichogami
Contributor Author

@amueller I don't remember, give me 2-3 days to get back on this.

@kaichogami
Contributor Author

I have been busy. If anyone wants to give it a try, please do!

@srivatsan-ramesh
Contributor

I would like to do this work.

@srivatsan-ramesh
Contributor

Why doesn't LabelBinarizer have a parameter for the number of classes? I think if it had that parameter, the fix for this would be easy.
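
For reference, a small sketch of one way to get a fixed number of output columns without changing the LabelBinarizer API: the module-level label_binarize function takes the class list explicitly. This is only an illustration with made-up batches, and not necessarily what this PR or #7786 does.

```python
import numpy as np
from sklearn.preprocessing import label_binarize

all_classes = np.array([0, 1, 2])
# Two illustrative mini-batches; neither contains all three classes.
batches = [np.array([0, 0, 1]), np.array([2, 2, 2])]

# Each batch is binarized against the same class list,
# so every indicator matrix has one column per declared class.
indicators = [label_binarize(batch, classes=all_classes) for batch in batches]

for indicator in indicators:
    print(indicator.shape)
```

The same effect can be had by fitting a LabelBinarizer once on the full class list and calling only transform on each mini-batch.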

@kaichogami
Contributor Author

kaichogami commented Oct 29, 2016

@srivatsan-ramesh I don't think we can change the API of a class so easily. The first issue wasn't really an issue, the second one is solved in this PR, and the third one will work fine now that #6202 has been fixed. This just needs rebasing onto master. If you want, you can pull my branch, rebase it onto master, and push it to my repo.

@srivatsan-ramesh
Contributor

@kaichogami I don't think this fix is correct, or maybe I don't understand it?
Check my PR #7786; I think that should work.

@kaichogami
Contributor Author

@srivatsan-ramesh #7786 looks clean and simple. Let's go ahead with that!

@kaichogami kaichogami closed this Oct 30, 2016