Skip to content

Commit 9fd70a8

Browse files
srivatsan-ramesh authored and jnothman committed
[MRG+2] Fix for OvR partial_fit bug (#7786)
* A mini-batch can now contain fewer classes than the full data
* Added tests where mini-batches do not contain all classes
1 parent 02cc6f5 commit 9fd70a8

File tree

2 files changed

+49
-27
lines changed

2 files changed

+49
-27
lines changed

sklearn/multiclass.py

Lines changed: 19 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -244,26 +244,31 @@ def partial_fit(self, X, y, classes=None):
244244
self
245245
"""
246246
if _check_partial_fit_first_call(self, classes):
247-
if (not hasattr(self.estimator, "partial_fit")):
248-
raise ValueError("Base estimator {0}, doesn't have partial_fit"
249-
"method".format(self.estimator))
247+
if not hasattr(self.estimator, "partial_fit"):
248+
raise ValueError(("Base estimator {0}, doesn't have "
249+
"partial_fit method").format(self.estimator))
250250
self.estimators_ = [clone(self.estimator) for _ in range
251251
(self.n_classes_)]
252252

253-
# A sparse LabelBinarizer, with sparse_output=True, has been shown to
254-
# outperform or match a dense label binarizer in all cases and has also
255-
# resulted in less or equal memory consumption in the fit_ovr function
256-
# overall.
257-
self.label_binarizer_ = LabelBinarizer(sparse_output=True)
258-
Y = self.label_binarizer_.fit_transform(y)
253+
# A sparse LabelBinarizer, with sparse_output=True, has been
254+
# shown to outperform or match a dense label binarizer in all
255+
# cases and has also resulted in less or equal memory consumption
256+
# in the fit_ovr function overall.
257+
self.label_binarizer_ = LabelBinarizer(sparse_output=True)
258+
self.label_binarizer_.fit(self.classes_)
259+
260+
if np.setdiff1d(y, self.classes_):
261+
raise ValueError(("Mini-batch contains {0} while classes " +
262+
"must be subset of {1}").format(np.unique(y),
263+
self.classes_))
264+
265+
Y = self.label_binarizer_.transform(y)
259266
Y = Y.tocsc()
260267
columns = (col.toarray().ravel() for col in Y.T)
261268

262-
self.estimators_ = Parallel(n_jobs=self.n_jobs)(delayed(
263-
_partial_fit_binary)(self.estimators_[i],
264-
X, next(columns) if self.classes_[i] in
265-
self.label_binarizer_.classes_ else
266-
np.zeros((1, len(y))))
269+
self.estimators_ = Parallel(n_jobs=self.n_jobs)(
270+
delayed(_partial_fit_binary)(self.estimators_[i], X,
271+
next(columns))
267272
for i in range(self.n_classes_))
268273

269274
return self

sklearn/tests/test_multiclass.py

Lines changed: 30 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import scipy.sparse as sp
33

4-
from sklearn.utils.testing import assert_array_equal
4+
from sklearn.utils.testing import assert_array_equal, assert_raises_regex
55
from sklearn.utils.testing import assert_equal
66
from sklearn.utils.testing import assert_almost_equal
77
from sklearn.utils.testing import assert_true
@@ -22,7 +22,8 @@
2222
from sklearn.svm import LinearSVC, SVC
2323
from sklearn.naive_bayes import MultinomialNB
2424
from sklearn.linear_model import (LinearRegression, Lasso, ElasticNet, Ridge,
25-
Perceptron, LogisticRegression)
25+
Perceptron, LogisticRegression,
26+
SGDClassifier)
2627
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
2728
from sklearn.model_selection import GridSearchCV, cross_val_score
2829
from sklearn.pipeline import Pipeline
@@ -89,20 +90,37 @@ def test_ovr_partial_fit():
8990
assert_greater(np.mean(y == pred), 0.65)
9091

9192
# Test when mini batches doesn't have all classes
92-
ovr = OneVsRestClassifier(MultinomialNB())
93-
ovr.partial_fit(iris.data[:60], iris.target[:60], np.unique(iris.target))
94-
ovr.partial_fit(iris.data[60:], iris.target[60:])
95-
pred = ovr.predict(iris.data)
96-
ovr2 = OneVsRestClassifier(MultinomialNB())
97-
pred2 = ovr2.fit(iris.data, iris.target).predict(iris.data)
93+
# with SGDClassifier
94+
X = np.abs(np.random.randn(14, 2))
95+
y = [1, 1, 1, 1, 2, 3, 3, 0, 0, 2, 3, 1, 2, 3]
96+
97+
ovr = OneVsRestClassifier(SGDClassifier(n_iter=1, shuffle=False,
98+
random_state=0))
99+
ovr.partial_fit(X[:7], y[:7], np.unique(y))
100+
ovr.partial_fit(X[7:], y[7:])
101+
pred = ovr.predict(X)
102+
ovr1 = OneVsRestClassifier(SGDClassifier(n_iter=1, shuffle=False,
103+
random_state=0))
104+
pred1 = ovr1.fit(X, y).predict(X)
105+
assert_equal(np.mean(pred == y), np.mean(pred1 == y))
98106

99-
assert_almost_equal(pred, pred2)
100-
assert_equal(len(ovr.estimators_), len(np.unique(iris.target)))
101-
assert_greater(np.mean(iris.target == pred), 0.65)
107+
108+
def test_ovr_partial_fit_exceptions():
109+
ovr = OneVsRestClassifier(MultinomialNB())
110+
X = np.abs(np.random.randn(14, 2))
111+
y = [1, 1, 1, 1, 2, 3, 3, 0, 0, 2, 3, 1, 2, 3]
112+
ovr.partial_fit(X[:7], y[:7], np.unique(y))
113+
# A new class value which was not in the first call of partial_fit
114+
# It should raise ValueError
115+
y1 = [5] + y[7:-1]
116+
assert_raises_regex(ValueError, "Mini-batch contains \[.+\] while classes"
117+
" must be subset of \[.+\]",
118+
ovr.partial_fit, X=X[7:], y=y1)
102119

103120

104121
def test_ovr_ovo_regressor():
105-
# test that ovr and ovo work on regressors which don't have a decision_function
122+
# test that ovr and ovo work on regressors which don't have a decision_
123+
# function
106124
ovr = OneVsRestClassifier(DecisionTreeRegressor())
107125
pred = ovr.fit(iris.data, iris.target).predict(iris.data)
108126
assert_equal(len(ovr.estimators_), n_classes)
@@ -204,7 +222,6 @@ def test_ovr_multiclass():
204222
for base_clf in (MultinomialNB(), LinearSVC(random_state=0),
205223
LinearRegression(), Ridge(),
206224
ElasticNet()):
207-
208225
clf = OneVsRestClassifier(base_clf).fit(X, y)
209226
assert_equal(set(clf.classes_), classes)
210227
y_pred = clf.predict(np.array([[0, 0, 4]]))[0]

0 commit comments

Comments
 (0)