|
1 | 1 | import numpy as np
|
2 | 2 | import scipy.sparse as sp
|
3 | 3 |
|
4 |
| -from sklearn.utils.testing import assert_array_equal |
| 4 | +from sklearn.utils.testing import assert_array_equal, assert_raises_regex |
5 | 5 | from sklearn.utils.testing import assert_equal
|
6 | 6 | from sklearn.utils.testing import assert_almost_equal
|
7 | 7 | from sklearn.utils.testing import assert_true
|
|
22 | 22 | from sklearn.svm import LinearSVC, SVC
|
23 | 23 | from sklearn.naive_bayes import MultinomialNB
|
24 | 24 | from sklearn.linear_model import (LinearRegression, Lasso, ElasticNet, Ridge,
|
25 |
| - Perceptron, LogisticRegression) |
| 25 | + Perceptron, LogisticRegression, |
| 26 | + SGDClassifier) |
26 | 27 | from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
|
27 | 28 | from sklearn.model_selection import GridSearchCV, cross_val_score
|
28 | 29 | from sklearn.pipeline import Pipeline
|
@@ -89,20 +90,37 @@ def test_ovr_partial_fit():
|
89 | 90 | assert_greater(np.mean(y == pred), 0.65)
|
90 | 91 |
|
91 | 92 | # Test when mini batches don't have all classes
|
92 |
| - ovr = OneVsRestClassifier(MultinomialNB()) |
93 |
| - ovr.partial_fit(iris.data[:60], iris.target[:60], np.unique(iris.target)) |
94 |
| - ovr.partial_fit(iris.data[60:], iris.target[60:]) |
95 |
| - pred = ovr.predict(iris.data) |
96 |
| - ovr2 = OneVsRestClassifier(MultinomialNB()) |
97 |
| - pred2 = ovr2.fit(iris.data, iris.target).predict(iris.data) |
| 93 | + # with SGDClassifier |
| 94 | + X = np.abs(np.random.randn(14, 2)) |
| 95 | + y = [1, 1, 1, 1, 2, 3, 3, 0, 0, 2, 3, 1, 2, 3] |
| 96 | + |
| 97 | + ovr = OneVsRestClassifier(SGDClassifier(n_iter=1, shuffle=False, |
| 98 | + random_state=0)) |
| 99 | + ovr.partial_fit(X[:7], y[:7], np.unique(y)) |
| 100 | + ovr.partial_fit(X[7:], y[7:]) |
| 101 | + pred = ovr.predict(X) |
| 102 | + ovr1 = OneVsRestClassifier(SGDClassifier(n_iter=1, shuffle=False, |
| 103 | + random_state=0)) |
| 104 | + pred1 = ovr1.fit(X, y).predict(X) |
| 105 | + assert_equal(np.mean(pred == y), np.mean(pred1 == y)) |
98 | 106 |
|
99 |
| - assert_almost_equal(pred, pred2) |
100 |
| - assert_equal(len(ovr.estimators_), len(np.unique(iris.target))) |
101 |
| - assert_greater(np.mean(iris.target == pred), 0.65) |
| 107 | + |
| 108 | +def test_ovr_partial_fit_exceptions(): |
| 109 | + ovr = OneVsRestClassifier(MultinomialNB()) |
| 110 | + X = np.abs(np.random.randn(14, 2)) |
| 111 | + y = [1, 1, 1, 1, 2, 3, 3, 0, 0, 2, 3, 1, 2, 3] |
| 112 | + ovr.partial_fit(X[:7], y[:7], np.unique(y)) |
| 113 | + # y1 contains a class value (5) that was not declared in the first call |
| 114 | + # to partial_fit, so this call should raise ValueError |
| 115 | + y1 = [5] + y[7:-1] |
| 116 | + assert_raises_regex(ValueError, "Mini-batch contains \[.+\] while classes" |
| 117 | + " must be subset of \[.+\]", |
| 118 | + ovr.partial_fit, X=X[7:], y=y1) |
102 | 119 |
|
103 | 120 |
|
104 | 121 | def test_ovr_ovo_regressor():
|
105 |
| - # test that ovr and ovo work on regressors which don't have a decision_function |
| 122 | + # test that ovr and ovo work on regressors which don't have a |
| 123 | + # decision_function |
106 | 124 | ovr = OneVsRestClassifier(DecisionTreeRegressor())
|
107 | 125 | pred = ovr.fit(iris.data, iris.target).predict(iris.data)
|
108 | 126 | assert_equal(len(ovr.estimators_), n_classes)
|
@@ -204,7 +222,6 @@ def test_ovr_multiclass():
|
204 | 222 | for base_clf in (MultinomialNB(), LinearSVC(random_state=0),
|
205 | 223 | LinearRegression(), Ridge(),
|
206 | 224 | ElasticNet()):
|
207 |
| - |
208 | 225 | clf = OneVsRestClassifier(base_clf).fit(X, y)
|
209 | 226 | assert_equal(set(clf.classes_), classes)
|
210 | 227 | y_pred = clf.predict(np.array([[0, 0, 4]]))[0]
|
|
0 commit comments