Fix for OvR partial_fit in various edge cases #6239

Closed
wants to merge 1 commit into from
25 changes: 20 additions & 5 deletions sklearn/multiclass.py
@@ -47,7 +47,8 @@
from .utils.validation import check_consistent_length
from .utils.validation import check_is_fitted
from .utils.multiclass import (_check_partial_fit_first_call,
                               check_classification_targets,
                               type_of_target)
from .externals.joblib import Parallel
from .externals.joblib import delayed
from .externals.six.moves import zip as izip
@@ -250,16 +251,25 @@ def partial_fit(self, X, y, classes=None):
        # outperform or match a dense label binarizer in all cases and has also
        # resulted in less or equal memory consumption in the fit_ovr function
        # overall.
        if not set(self.classes_).issuperset(y):
            raise ValueError("Mini-batch contains {0} while classes "
                             "must be subset of {1}".format(np.unique(y),
                                                            self.classes_))
        self.label_binarizer_ = LabelBinarizer(sparse_output=True)
        Y = self.label_binarizer_.fit_transform(y)
        Y = Y.tocsc()
        columns = (col.toarray().ravel() for col in Y.T)

        Y = Y.T.toarray()
        if len(self.label_binarizer_.classes_) == 2:
            # The sparse output gives a single array, but we require
            # two as we have to train two estimators
            Y = np.array(((Y[0]-1)*-1, Y[0]))
        columns = (col.ravel() for col in Y)

        self.estimators_ = Parallel(n_jobs=self.n_jobs)(delayed(
            _partial_fit_binary)(self.estimators_[i],
                                 X, next(columns) if self.classes_[i] in
                                 self.label_binarizer_.classes_ else
                                 np.zeros((1, len(y))))
                                 np.zeros(len(y), dtype=np.int))
            for i in range(self.n_classes_))

        return self
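
For readers skimming the hunk above: the binary branch exists because LabelBinarizer emits a single indicator column when a mini-batch holds only two classes, while one-vs-rest still needs one column per trained estimator. A minimal standalone sketch (not part of the diff; it assumes only NumPy and scikit-learn's public LabelBinarizer) of how the complementary column is reconstructed:

import numpy as np
from sklearn.preprocessing import LabelBinarizer

y = [0, 1, 1, 0]                         # a two-class mini-batch
lb = LabelBinarizer(sparse_output=True)
Y = lb.fit_transform(y)                  # sparse, shape (4, 1): one column only
Y = Y.T.toarray()                        # shape (1, 4)

# Rebuild the complementary indicator so there is one row per class:
# row 0 marks class 0, row 1 marks class 1.
Y = np.array(((Y[0] - 1) * -1, Y[0]))
print(Y)                                 # [[1 0 0 1]
                                         #  [0 1 1 0]]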
@@ -285,7 +295,12 @@ def predict(self, X):
        thresh = .5

        n_samples = _num_samples(X)
        if self.label_binarizer_.y_type_ == "multiclass":
        # In case mini-batches from partial_fit contain only binary classes
        # but `type_of_target(y) == 'multiclass'`, prediction must not fall
        # through to the else branch.
        if ((self.label_binarizer_.y_type_ == 'multiclass' or
             self.label_binarizer_.y_type_ == 'binary') and
                type_of_target(self.classes_) != 'binary'):
            maxima = np.empty(n_samples, dtype=float)
            maxima.fill(-np.inf)
            argmaxima = np.zeros(n_samples, dtype=int)
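
The predict-time condition above hinges on what type_of_target reports for the full class set rather than for the last mini-batch. A short illustration (not part of the diff; it uses only scikit-learn's public type_of_target helper):

from sklearn.utils.multiclass import type_of_target

# A two-class mini-batch looks binary, but the declared class set can still
# be multiclass; the new condition checks self.classes_, not the last batch.
print(type_of_target([0, 1]))       # 'binary'     -> keep the else branch
print(type_of_target([0, 1, 2]))    # 'multiclass' -> take the argmax branch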
46 changes: 39 additions & 7 deletions sklearn/tests/test_multiclass.py
@@ -89,17 +89,49 @@ def test_ovr_partial_fit():
    assert_equal(len(ovr.estimators_), len(np.unique(y)))
    assert_greater(np.mean(y == pred), 0.65)

    # Test when mini-batches don't have all classes
    # Test when there are more than 2 classes in each pass
    X = np.random.rand(14, 2)
    y = [0, 0, 1, 1, 2, 2, 0, 0, 1, 2, 2, 3, 3, 3]
    ovr = OneVsRestClassifier(MultinomialNB())
    ovr.partial_fit(iris.data[:60], iris.target[:60], np.unique(iris.target))
    ovr.partial_fit(iris.data[60:], iris.target[60:])
    pred = ovr.predict(iris.data)
    ovr.partial_fit(X[:7], y[:7], np.unique(y))
    ovr.partial_fit(X[7:], y[7:])
    pred = ovr.predict(X)

    ovr1 = OneVsRestClassifier(MultinomialNB())
    ovr1.fit(X, y)
    pred1 = ovr1.predict(X)
    assert_almost_equal(np.mean(y == pred), np.mean(pred1 == y))

    # Test when mini-batches have 2 classes in each
    # pass.
    temp = datasets.load_iris()
    X, y = temp.data, temp.target
    ovr = OneVsRestClassifier(MultinomialNB())
    ovr.partial_fit(X[:60], y[:60], np.unique(y))
    ovr.partial_fit(X[60:], y[60:])
    pred = ovr.predict(X)
    ovr2 = OneVsRestClassifier(MultinomialNB())
    pred2 = ovr2.fit(iris.data, iris.target).predict(iris.data)
    pred2 = ovr2.fit(X, y).predict(X)

    assert_almost_equal(pred, pred2)
    assert_equal(len(ovr.estimators_), len(np.unique(iris.target)))
    assert_greater(np.mean(iris.target == pred), 0.65)
    assert_equal(len(ovr.estimators_), len(np.unique(y)))
    assert_greater(np.mean(y == pred), 0.65)

    # Check that an error is raised when mini-batch classes are not a
    # subset of all_classes
    rnd = np.random.rand(10, 2)
    ovr = OneVsRestClassifier(MultinomialNB())
    assert_raises(ValueError, ovr.partial_fit, rnd[:5], [0, 1, 2, 3, 4],
                  [0, 1, 2, 3])

    # Test when mini-batches have a single target class
    ovr = OneVsRestClassifier(MultinomialNB())
    ovr.partial_fit(X[:125], y[:125], np.unique(y))
    ovr.partial_fit(X[125:], y[125:])
    pred = ovr.predict(X)

    assert_almost_equal(pred, pred2)
    assert_equal(len(ovr.estimators_), len(np.unique(y)))
    assert_greater(np.mean(y == pred), 0.65)


def test_ovr_ovo_regressor():
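
Taken together, the behaviour these tests exercise can be reproduced in a short standalone script. This is a sketch under the same assumptions as the tests (OneVsRestClassifier wrapping MultinomialNB on iris), not part of the PR, and the last call relies on the validation added in this patch:

import numpy as np
from sklearn import datasets
from sklearn.multiclass import OneVsRestClassifier
from sklearn.naive_bayes import MultinomialNB

iris = datasets.load_iris()
X, y = iris.data, iris.target

# Declare the full class set on the first call; later mini-batches may hold
# only a subset of it (iris[:60] has classes 0 and 1, iris[60:] has 1 and 2).
ovr = OneVsRestClassifier(MultinomialNB())
ovr.partial_fit(X[:60], y[:60], classes=np.unique(y))
ovr.partial_fit(X[60:], y[60:])
print(np.mean(ovr.predict(X) == y))      # accuracy comparable to a plain fit()

# Labels outside the declared class set are rejected by the new check.
try:
    ovr.partial_fit(X[:5], [0, 1, 2, 3, 4])
except ValueError as exc:
    print(exc)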