Skip to content

Commit 7372cee

Browse files
committed
CLN Create helper function
1 parent c80debf commit 7372cee

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

sklearn/multiclass.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,15 @@ def _predict_binary(estimator, X):
102102
return score
103103

104104

105+
def _threshold_for_binary_predict(estimator):
106+
"""Threshold for predictions from binary estimator."""
107+
if hasattr(estimator, "decision_function") and is_classifier(estimator):
108+
return 0.0
109+
else:
110+
# predict_proba threshold
111+
return 0.5
112+
113+
105114
def _check_estimator(estimator):
106115
"""Make sure that an estimator implements the necessary methods."""
107116
if not hasattr(estimator, "decision_function") and not hasattr(
@@ -426,12 +435,7 @@ def predict(self, X):
426435
argmaxima[maxima == pred] = i
427436
return self.classes_[argmaxima]
428437
else:
429-
if hasattr(self.estimators_[0], "decision_function") and is_classifier(
430-
self.estimators_[0]
431-
):
432-
thresh = 0
433-
else:
434-
thresh = 0.5
438+
thresh = _threshold_for_binary_predict(self.estimators_[0])
435439
indices = array.array("i")
436440
indptr = array.array("i", [0])
437441
for e in self.estimators_:
@@ -770,13 +774,7 @@ def predict(self, X):
770774
"""
771775
Y = self.decision_function(X)
772776
if self.n_classes_ == 2:
773-
if hasattr(self.estimators_[0], "decision_function") and is_classifier(
774-
self.estimators_[0]
775-
):
776-
thresh = 0
777-
else:
778-
# predict_proba threshold
779-
thresh = 0.5
777+
thresh = _threshold_for_binary_predict(self.estimators_[0])
780778
return self.classes_[(Y > thresh).astype(int)]
781779
return self.classes_[Y.argmax(axis=1)]
782780

0 commit comments

Comments
 (0)