Skip to content

Commit 306d788

Browse files
committed
add tag-based branching to thresholding scorer
1 parent 5695350 commit 306d788

File tree

2 files changed

+22
-12
lines changed

2 files changed

+22
-12
lines changed

sklearn/metrics/scorer.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -157,20 +157,23 @@ def __call__(self, clf, X, y, sample_weight=None):
157157
if y_type not in ("binary", "multilabel-indicator"):
158158
raise ValueError("{0} format is not supported".format(y_type))
159159

160-
try:
161-
y_pred = clf.decision_function(X)
160+
if getattr(clf, "estimator_type", None) == "regressor":
161+
y_pred = clf.predict(X)
162+
else:
163+
try:
164+
y_pred = clf.decision_function(X)
162165

163-
# For multi-output multi-class estimator
164-
if isinstance(y_pred, list):
165-
y_pred = np.vstack(p for p in y_pred).T
166+
# For multi-output multi-class estimator
167+
if isinstance(y_pred, list):
168+
y_pred = np.vstack(p for p in y_pred).T
166169

167-
except (NotImplementedError, AttributeError):
168-
y_pred = clf.predict_proba(X)
170+
except (NotImplementedError, AttributeError):
171+
y_pred = clf.predict_proba(X)
169172

170-
if y_type == "binary":
171-
y_pred = y_pred[:, 1]
172-
elif isinstance(y_pred, list):
173-
y_pred = np.vstack([p[:, -1] for p in y_pred]).T
173+
if y_type == "binary":
174+
y_pred = y_pred[:, 1]
175+
elif isinstance(y_pred, list):
176+
y_pred = np.vstack([p[:, -1] for p in y_pred]).T
174177

175178
if sample_weight is not None:
176179
return self._sign * self._score_func(y, y_pred,

sklearn/metrics/tests/test_score_objects.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from sklearn.cluster import KMeans
2323
from sklearn.dummy import DummyRegressor
2424
from sklearn.linear_model import Ridge, LogisticRegression
25-
from sklearn.tree import DecisionTreeClassifier
25+
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
2626
from sklearn.datasets import make_blobs
2727
from sklearn.datasets import make_classification
2828
from sklearn.datasets import make_multilabel_classification
@@ -219,6 +219,13 @@ def test_thresholded_scorers():
219219
score2 = roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1])
220220
assert_almost_equal(score1, score2)
221221

222+
# test with a regressor (no decision_function)
223+
reg = DecisionTreeRegressor()
224+
reg.fit(X_train, y_train)
225+
score1 = get_scorer('roc_auc')(reg, X_test, y_test)
226+
score2 = roc_auc_score(y_test, reg.predict(X_test))
227+
assert_almost_equal(score1, score2)
228+
222229
# Test that an exception is raised on more than two classes
223230
X, y = make_blobs(random_state=0, centers=3)
224231
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

0 commit comments

Comments
 (0)