@@ -244,26 +244,31 @@ def partial_fit(self, X, y, classes=None):
244
244
self
245
245
"""
246
246
if _check_partial_fit_first_call (self , classes ):
247
- if ( not hasattr (self .estimator , "partial_fit" ) ):
247
+ if not hasattr (self .estimator , "partial_fit" ):
248
248
raise ValueError ("Base estimator {0}, doesn't have partial_fit"
249
249
"method" .format (self .estimator ))
250
250
self .estimators_ = [clone (self .estimator ) for _ in range
251
251
(self .n_classes_ )]
252
252
253
- # A sparse LabelBinarizer, with sparse_output=True, has been shown to
254
- # outperform or match a dense label binarizer in all cases and has also
255
- # resulted in less or equal memory consumption in the fit_ovr function
256
- # overall.
257
- self .label_binarizer_ = LabelBinarizer (sparse_output = True )
258
- Y = self .label_binarizer_ .fit_transform (y )
253
+ # A sparse LabelBinarizer, with sparse_output=True, has been shown to
254
+ # outperform or match a dense label binarizer in all cases and has also
255
+ # resulted in less or equal memory consumption in the fit_ovr function
256
+ # overall.
257
+ self .label_binarizer_ = LabelBinarizer (sparse_output = True )
258
+ self .label_binarizer_ .fit (self .classes_ )
259
+
260
+ if not set (self .classes_ ).issuperset (y ):
261
+ raise ValueError ("Mini-batch contains {0} while classes " +
262
+ "must be subset of {1}" .format (np .unique (y ),
263
+ self .classes_ ))
264
+
265
+ Y = self .label_binarizer_ .transform (y )
259
266
Y = Y .tocsc ()
260
267
columns = (col .toarray ().ravel () for col in Y .T )
261
268
262
- self .estimators_ = Parallel (n_jobs = self .n_jobs )(delayed (
263
- _partial_fit_binary )(self .estimators_ [i ],
264
- X , next (columns ) if self .classes_ [i ] in
265
- self .label_binarizer_ .classes_ else
266
- np .zeros ((1 , len (y ))))
269
+ self .estimators_ = Parallel (n_jobs = self .n_jobs )(
270
+ delayed (_partial_fit_binary )(self .estimators_ [i ], X ,
271
+ next (columns ))
267
272
for i in range (self .n_classes_ ))
268
273
269
274
return self
0 commit comments