|
30 | 30 | from .cluster import adjusted_rand_score
|
31 | 31 | from ..utils.multiclass import type_of_target
|
32 | 32 | from ..externals import six
|
| 33 | +from ..base import is_regressor |
33 | 34 |
|
34 | 35 |
|
35 | 36 | class _BaseScorer(six.with_metaclass(ABCMeta, object)):
|
@@ -157,20 +158,23 @@ def __call__(self, clf, X, y, sample_weight=None):
|
157 | 158 | if y_type not in ("binary", "multilabel-indicator"):
|
158 | 159 | raise ValueError("{0} format is not supported".format(y_type))
|
159 | 160 |
|
160 |
| - try: |
161 |
| - y_pred = clf.decision_function(X) |
| 161 | + if is_regressor(clf): |
| 162 | + y_pred = clf.predict(X) |
| 163 | + else: |
| 164 | + try: |
| 165 | + y_pred = clf.decision_function(X) |
162 | 166 |
|
163 |
| - # For multi-output multi-class estimator |
164 |
| - if isinstance(y_pred, list): |
165 |
| - y_pred = np.vstack(p for p in y_pred).T |
| 167 | + # For multi-output multi-class estimator |
| 168 | + if isinstance(y_pred, list): |
| 169 | + y_pred = np.vstack(p for p in y_pred).T |
166 | 170 |
|
167 |
| - except (NotImplementedError, AttributeError): |
168 |
| - y_pred = clf.predict_proba(X) |
| 171 | + except (NotImplementedError, AttributeError): |
| 172 | + y_pred = clf.predict_proba(X) |
169 | 173 |
|
170 |
| - if y_type == "binary": |
171 |
| - y_pred = y_pred[:, 1] |
172 |
| - elif isinstance(y_pred, list): |
173 |
| - y_pred = np.vstack([p[:, -1] for p in y_pred]).T |
| 174 | + if y_type == "binary": |
| 175 | + y_pred = y_pred[:, 1] |
| 176 | + elif isinstance(y_pred, list): |
| 177 | + y_pred = np.vstack([p[:, -1] for p in y_pred]).T |
174 | 178 |
|
175 | 179 | if sample_weight is not None:
|
176 | 180 | return self._sign * self._score_func(y, y_pred,
|
|
0 commit comments