Skip to content

Commit 59c873a

Browse files
committed
Fixed OvR partial_fit method in various edge cases.
Added tests; raise an error when a mini-batch contains classes that are not in all_classes.
1 parent 21d9ccc commit 59c873a

File tree

2 files changed

+55
-12
lines changed

2 files changed

+55
-12
lines changed

sklearn/multiclass.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@
4747
from .utils.validation import check_consistent_length
4848
from .utils.validation import check_is_fitted
4949
from .utils.multiclass import (_check_partial_fit_first_call,
50-
check_classification_targets)
50+
check_classification_targets,
51+
type_of_target)
5152
from .externals.joblib import Parallel
5253
from .externals.joblib import delayed
5354
from .externals.six.moves import zip as izip
@@ -250,16 +251,25 @@ def partial_fit(self, X, y, classes=None):
250251
# outperform or match a dense label binarizer in all cases and has also
251252
# resulted in less or equal memory consumption in the fit_ovr function
252253
# overall.
254+
if not set(self.classes_).issuperset(y):
255+
raise ValueError("Mini-batch contains {0} while classes "
256+
"must be subset of {1}".format(np.unique(y),
257+
self.classes_))
253258
self.label_binarizer_ = LabelBinarizer(sparse_output=True)
254259
Y = self.label_binarizer_.fit_transform(y)
255260
Y = Y.tocsc()
256-
columns = (col.toarray().ravel() for col in Y.T)
257-
261+
Y = Y.T.toarray()
262+
if len(self.label_binarizer_.classes_) == 2:
263+
# The sparse output gives a single array, but we require
264+
# two, as we have to train two estimators
265+
Y = np.array(((Y[0]-1)*-1, Y[0]))
266+
columns = (col.ravel() for col in Y)
267+
258268
self.estimators_ = Parallel(n_jobs=self.n_jobs)(delayed(
259269
_partial_fit_binary)(self.estimators_[i],
260270
X, next(columns) if self.classes_[i] in
261271
self.label_binarizer_.classes_ else
262-
np.zeros((1, len(y))))
272+
np.zeros(len(y), dtype=np.int))
263273
for i in range(self.n_classes_))
264274

265275
return self
@@ -285,7 +295,8 @@ def predict(self, X):
285295
thresh = .5
286296

287297
n_samples = _num_samples(X)
288-
if self.label_binarizer_.y_type_ == "multiclass":
298+
if ((self.label_binarizer_.y_type_ == 'multiclass' or self.label_binarizer_.y_type_ == 'binary') and
299+
type_of_target(self.classes_) != 'binary'):
289300
maxima = np.empty(n_samples, dtype=float)
290301
maxima.fill(-np.inf)
291302
argmaxima = np.zeros(n_samples, dtype=int)

sklearn/tests/test_multiclass.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,17 +89,49 @@ def test_ovr_partial_fit():
8989
assert_equal(len(ovr.estimators_), len(np.unique(y)))
9090
assert_greater(np.mean(y == pred), 0.65)
9191

92-
# Test when mini batches doesn't have all classes
92+
# Test when classes are more than 2 in each pass
93+
X = np.random.rand(14, 2)
94+
y = [0, 0, 1, 1, 2, 2, 0, 0, 1, 2, 2, 3, 3, 3]
9395
ovr = OneVsRestClassifier(MultinomialNB())
94-
ovr.partial_fit(iris.data[:60], iris.target[:60], np.unique(iris.target))
95-
ovr.partial_fit(iris.data[60:], iris.target[60:])
96-
pred = ovr.predict(iris.data)
96+
ovr.partial_fit(X[:7], y[:7], np.unique(y))
97+
ovr.partial_fit(X[7:], y[7:])
98+
pred = ovr.predict(X)
99+
100+
ovr1 = OneVsRestClassifier(MultinomialNB())
101+
ovr1.fit(X, y)
102+
pred1 = ovr1.predict(X)
103+
assert_almost_equal(np.mean(y == pred), np.mean(pred1 == y))
104+
105+
# Test when mini batches have 2 classes in each
106+
# pass.
107+
temp = datasets.load_iris()
108+
X, y= temp.data, temp.target
109+
ovr = OneVsRestClassifier(MultinomialNB())
110+
ovr.partial_fit(X[:60], y[:60], np.unique(y))
111+
ovr.partial_fit(X[60:], y[60:])
112+
pred = ovr.predict(X)
97113
ovr2 = OneVsRestClassifier(MultinomialNB())
98-
pred2 = ovr2.fit(iris.data, iris.target).predict(iris.data)
114+
pred2 = ovr2.fit(X, y).predict(X)
99115

100116
assert_almost_equal(pred, pred2)
101-
assert_equal(len(ovr.estimators_), len(np.unique(iris.target)))
102-
assert_greater(np.mean(iris.target == pred), 0.65)
117+
assert_equal(len(ovr.estimators_), len(np.unique(y)))
118+
assert_greater(np.mean(y == pred), 0.65)
119+
120+
# Check that an error is raised when a mini-batch contains classes not in all_classes
121+
rnd = np.random.rand(10, 2)
122+
ovr = OneVsRestClassifier(MultinomialNB())
123+
assert_raises(ValueError, ovr.partial_fit, rnd[:5], [0, 1, 2, 3, 4],
124+
[0, 1, 2, 3])
125+
126+
# Test when mini-batches have one class target
127+
ovr = OneVsRestClassifier(MultinomialNB())
128+
ovr.partial_fit(X[:125], y[:125], np.unique(y))
129+
ovr.partial_fit(X[125:], y[125:])
130+
pred = ovr.predict(X)
131+
132+
assert_almost_equal(pred, pred2)
133+
assert_equal(len(ovr.estimators_), len(np.unique(y)))
134+
assert_greater(np.mean(y == pred), 0.65)
103135

104136

105137
def test_ovr_ovo_regressor():

0 commit comments

Comments
 (0)