Skip to content

Commit 730c677

Browse files
committed
add tag-based branching to thresholding scorer
1 parent 5695350 commit 730c677

File tree

2 files changed

+23
-12
lines changed

2 files changed

+23
-12
lines changed

sklearn/metrics/scorer.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .cluster import adjusted_rand_score
3131
from ..utils.multiclass import type_of_target
3232
from ..externals import six
33+
from ..base import is_regressor
3334

3435

3536
class _BaseScorer(six.with_metaclass(ABCMeta, object)):
@@ -157,20 +158,23 @@ def __call__(self, clf, X, y, sample_weight=None):
157158
if y_type not in ("binary", "multilabel-indicator"):
158159
raise ValueError("{0} format is not supported".format(y_type))
159160

160-
try:
161-
y_pred = clf.decision_function(X)
161+
if is_regressor(clf):
162+
y_pred = clf.predict(X)
163+
else:
164+
try:
165+
y_pred = clf.decision_function(X)
162166

163-
# For multi-output multi-class estimator
164-
if isinstance(y_pred, list):
165-
y_pred = np.vstack(p for p in y_pred).T
167+
# For multi-output multi-class estimator
168+
if isinstance(y_pred, list):
169+
y_pred = np.vstack(p for p in y_pred).T
166170

167-
except (NotImplementedError, AttributeError):
168-
y_pred = clf.predict_proba(X)
171+
except (NotImplementedError, AttributeError):
172+
y_pred = clf.predict_proba(X)
169173

170-
if y_type == "binary":
171-
y_pred = y_pred[:, 1]
172-
elif isinstance(y_pred, list):
173-
y_pred = np.vstack([p[:, -1] for p in y_pred]).T
174+
if y_type == "binary":
175+
y_pred = y_pred[:, 1]
176+
elif isinstance(y_pred, list):
177+
y_pred = np.vstack([p[:, -1] for p in y_pred]).T
174178

175179
if sample_weight is not None:
176180
return self._sign * self._score_func(y, y_pred,

sklearn/metrics/tests/test_score_objects.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from sklearn.cluster import KMeans
2323
from sklearn.dummy import DummyRegressor
2424
from sklearn.linear_model import Ridge, LogisticRegression
25-
from sklearn.tree import DecisionTreeClassifier
25+
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
2626
from sklearn.datasets import make_blobs
2727
from sklearn.datasets import make_classification
2828
from sklearn.datasets import make_multilabel_classification
@@ -219,6 +219,13 @@ def test_thresholded_scorers():
219219
score2 = roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1])
220220
assert_almost_equal(score1, score2)
221221

222+
# test with a regressor (no decision_function)
223+
reg = DecisionTreeRegressor()
224+
reg.fit(X_train, y_train)
225+
score1 = get_scorer('roc_auc')(reg, X_test, y_test)
226+
score2 = roc_auc_score(y_test, reg.predict(X_test))
227+
assert_almost_equal(score1, score2)
228+
222229
# Test that an exception is raised on more than two classes
223230
X, y = make_blobs(random_state=0, centers=3)
224231
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

0 commit comments

Comments
 (0)