
Commit bc10d8f

ENH use / test multi-class with regressors
1 parent 730c677

File tree: 2 files changed, +26 −7 lines
  sklearn/multiclass.py
  sklearn/tests/test_multiclass.py

sklearn/multiclass.py

Lines changed: 8 additions & 6 deletions
@@ -39,7 +39,7 @@
 import scipy.sparse as sp
 
 from .base import BaseEstimator, ClassifierMixin, clone, is_classifier
-from .base import MetaEstimatorMixin
+from .base import MetaEstimatorMixin, is_regressor
 from .preprocessing import LabelBinarizer
 from .metrics.pairwise import euclidean_distances
 from .utils import check_random_state
@@ -77,6 +77,8 @@ def _fit_binary(estimator, X, y, classes=None):
 
 def _predict_binary(estimator, X):
     """Make predictions using a single binary estimator."""
+    if is_regressor(estimator):
+        return estimator.predict(X)
     try:
         score = np.ravel(estimator.decision_function(X))
     except (AttributeError, NotImplementedError):
@@ -276,11 +278,11 @@ def fit(self, X, y):
         # In cases where individual estimators are very fast to train setting
         # n_jobs > 1 in can results in slower performance due to the overhead
         # of spawning threads. See joblib issue #112.
-        self.estimators_ = Parallel(n_jobs=self.n_jobs)(delayed(_fit_binary)
-            (self.estimator, X, column,
-             classes=["not %s" % self.label_binarizer_.classes_[i],
-                      self.label_binarizer_.classes_[i]])
-            for i, column in enumerate(columns))
+        self.estimators_ = Parallel(n_jobs=self.n_jobs)(delayed(_fit_binary)(
+            self.estimator, X, column, classes=[
+                "not %s" % self.label_binarizer_.classes_[i],
+                self.label_binarizer_.classes_[i]])
+            for i, column in enumerate(columns))
 
         return self
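In short, _predict_binary now short-circuits for regressors: their continuous predict output is used directly as the binary score, while classifiers still go through decision_function. A minimal standalone sketch of that dispatch, written as an illustration rather than a copy of the real function (the predict_proba fallback in the except branch is an assumption, since the body of that branch is not shown in this hunk):

    import numpy as np
    from sklearn.base import is_regressor

    def predict_binary_sketch(estimator, X):
        """Return a 1-d score for the positive class of a fitted binary estimator."""
        if is_regressor(estimator):
            # Regressors already produce a continuous score; use it as-is.
            return estimator.predict(X)
        try:
            # Classifiers with a decision_function expose a signed distance.
            return np.ravel(estimator.decision_function(X))
        except (AttributeError, NotImplementedError):
            # Assumed fallback: probability of the positive class.
            return estimator.predict_proba(X)[:, 1]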

sklearn/tests/test_multiclass.py

Lines changed: 18 additions & 1 deletion
@@ -31,7 +31,7 @@
 from sklearn.naive_bayes import MultinomialNB
 from sklearn.linear_model import (LinearRegression, Lasso, ElasticNet, Ridge,
                                   Perceptron, LogisticRegression)
-from sklearn.tree import DecisionTreeClassifier
+from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
 from sklearn.grid_search import GridSearchCV
 from sklearn.pipeline import Pipeline
 from sklearn import svm
@@ -79,6 +79,23 @@ def test_ovr_fit_predict():
     assert_greater(np.mean(iris.target == pred), 0.65)
 
 
+def test_ovr_ovo_regressor():
+    # test that ovr and ovo work on regressors which don't have a decision_function
+    ovr = OneVsRestClassifier(DecisionTreeRegressor())
+    pred = ovr.fit(iris.data, iris.target).predict(iris.data)
+    assert_equal(len(ovr.estimators_), n_classes)
+    assert_array_equal(np.unique(pred), [0, 1, 2])
+    # we are doing something sensible
+    assert_greater(np.mean(pred == iris.target), .9)
+
+    ovr = OneVsOneClassifier(DecisionTreeRegressor())
+    pred = ovr.fit(iris.data, iris.target).predict(iris.data)
+    assert_equal(len(ovr.estimators_), n_classes * (n_classes - 1) / 2)
+    assert_array_equal(np.unique(pred), [0, 1, 2])
+    # we are doing something sensible
+    assert_greater(np.mean(pred == iris.target), .9)
+
+
 def test_ovr_fit_predict_sparse():
     for sparse in [sp.csr_matrix, sp.csc_matrix, sp.coo_matrix, sp.dok_matrix,
                    sp.lil_matrix]:
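As a usage sketch mirroring the new test: iris has three classes, so OneVsRestClassifier fits 3 binary regressors and OneVsOneClassifier fits 3 * (3 - 1) / 2 = 3 pairwise ones, and both still return discrete class labels from predict. Nothing here goes beyond what the test itself exercises:

    from sklearn.datasets import load_iris
    from sklearn.multiclass import OneVsOneClassifier, OneVsRestClassifier
    from sklearn.tree import DecisionTreeRegressor

    iris = load_iris()

    # One binary regressor per class; its predictions serve as per-class scores.
    ovr = OneVsRestClassifier(DecisionTreeRegressor()).fit(iris.data, iris.target)
    print(len(ovr.estimators_))        # 3
    print(ovr.predict(iris.data[:5]))  # class labels, not regression values

    # One regressor per pair of classes.
    ovo = OneVsOneClassifier(DecisionTreeRegressor()).fit(iris.data, iris.target)
    print(len(ovo.estimators_))        # 3
    print(ovo.predict(iris.data[:5]))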
