Skip to content

Commit ed08d38

Browse files
mini-batch can now contain less number of classes than actual data
1 parent 10323f5 commit ed08d38

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

sklearn/multiclass.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -244,26 +244,31 @@ def partial_fit(self, X, y, classes=None):
244244
self
245245
"""
246246
if _check_partial_fit_first_call(self, classes):
247-
if (not hasattr(self.estimator, "partial_fit")):
247+
if not hasattr(self.estimator, "partial_fit"):
248248
raise ValueError("Base estimator {0}, doesn't have partial_fit"
249249
"method".format(self.estimator))
250250
self.estimators_ = [clone(self.estimator) for _ in range
251251
(self.n_classes_)]
252252

253-
# A sparse LabelBinarizer, with sparse_output=True, has been shown to
254-
# outperform or match a dense label binarizer in all cases and has also
255-
# resulted in less or equal memory consumption in the fit_ovr function
256-
# overall.
257-
self.label_binarizer_ = LabelBinarizer(sparse_output=True)
258-
Y = self.label_binarizer_.fit_transform(y)
253+
# A sparse LabelBinarizer, with sparse_output=True, has been shown to
254+
# outperform or match a dense label binarizer in all cases and has also
255+
# resulted in less or equal memory consumption in the fit_ovr function
256+
# overall.
257+
self.label_binarizer_ = LabelBinarizer(sparse_output=True)
258+
self.label_binarizer_.fit(self.classes_)
259+
260+
if not set(self.classes_).issuperset(y):
261+
raise ValueError("Mini-batch contains {0} while classes " +
262+
"must be subset of {1}".format(np.unique(y),
263+
self.classes_))
264+
265+
Y = self.label_binarizer_.transform(y)
259266
Y = Y.tocsc()
260267
columns = (col.toarray().ravel() for col in Y.T)
261268

262-
self.estimators_ = Parallel(n_jobs=self.n_jobs)(delayed(
263-
_partial_fit_binary)(self.estimators_[i],
264-
X, next(columns) if self.classes_[i] in
265-
self.label_binarizer_.classes_ else
266-
np.zeros((1, len(y))))
269+
self.estimators_ = Parallel(n_jobs=self.n_jobs)(
270+
delayed(_partial_fit_binary)(self.estimators_[i], X,
271+
next(columns))
267272
for i in range(self.n_classes_))
268273

269274
return self

0 commit comments

Comments
 (0)