@@ -102,6 +102,15 @@ def _predict_binary(estimator, X):
102
102
return score
103
103
104
104
105
+ def _threshold_for_binary_predict (estimator ):
106
+ """Threshold for predictions from binary estimator."""
107
+ if hasattr (estimator , "decision_function" ) and is_classifier (estimator ):
108
+ return 0.0
109
+ else :
110
+ # predict_proba threshold
111
+ return 0.5
112
+
113
+
105
114
def _check_estimator (estimator ):
106
115
"""Make sure that an estimator implements the necessary methods."""
107
116
if not hasattr (estimator , "decision_function" ) and not hasattr (
@@ -426,12 +435,7 @@ def predict(self, X):
426
435
argmaxima [maxima == pred ] = i
427
436
return self .classes_ [argmaxima ]
428
437
else :
429
- if hasattr (self .estimators_ [0 ], "decision_function" ) and is_classifier (
430
- self .estimators_ [0 ]
431
- ):
432
- thresh = 0
433
- else :
434
- thresh = 0.5
438
+ thresh = _threshold_for_binary_predict (self .estimators_ [0 ])
435
439
indices = array .array ("i" )
436
440
indptr = array .array ("i" , [0 ])
437
441
for e in self .estimators_ :
@@ -770,13 +774,7 @@ def predict(self, X):
770
774
"""
771
775
Y = self .decision_function (X )
772
776
if self .n_classes_ == 2 :
773
- if hasattr (self .estimators_ [0 ], "decision_function" ) and is_classifier (
774
- self .estimators_ [0 ]
775
- ):
776
- thresh = 0
777
- else :
778
- # predict_proba threshold
779
- thresh = 0.5
777
+ thresh = _threshold_for_binary_predict (self .estimators_ [0 ])
780
778
return self .classes_ [(Y > thresh ).astype (int )]
781
779
return self .classes_ [Y .argmax (axis = 1 )]
782
780
0 commit comments