FIX Fixes OneVsOneClassifier.predict for Estimators with only predict_proba #22604
Marking this as quick review. We do a very similar thing here: scikit-learn/sklearn/multiclass.py Lines 429 to 434 in 70eebb9
LGTM
Apart from one suggestion, this LGTM. Thank you, @thomasjpfan.
sklearn/multiclass.py
Outdated
if hasattr(self.estimators_[0], "decision_function") and is_classifier(
    self.estimators_[0]
):
    thresh = 0
else:
    # predict_proba threshold
    thresh = 0.5
return self.classes_[(Y > thresh).astype(int)]
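To illustrate why the two thresholds differ, here is a minimal pure-Python sketch. It is not the scikit-learn code: `sigmoid`, `predict_binary`, and the class labels are hypothetical names introduced only for this example, and the sigmoid is just one assumed way to obtain probabilities from scores. Signed decision scores split at 0, whereas positive-class probabilities split at 0.5; applying 0 to probabilities would assign every sample to the positive class.

```python
import math


def sigmoid(score):
    """Map a signed decision score to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-score))


def predict_binary(values, thresh, classes=("neg", "pos")):
    """Pick classes[1] where value > thresh, else classes[0].

    Pure-Python analogue of classes_[(Y > thresh).astype(int)].
    """
    return [classes[int(v > thresh)] for v in values]


# decision_function-style output: signed scores centred on 0
scores = [-2.0, -0.5, 0.5, 2.0]
# predict_proba-style output for the positive class (assumed sigmoid link)
probas = [sigmoid(s) for s in scores]

by_score = predict_binary(scores, thresh=0)
by_proba = predict_binary(probas, thresh=0.5)
# Using the decision_function threshold on probabilities:
# every probability is > 0, so every sample is labelled "pos".
by_proba_wrong = predict_binary(probas, thresh=0)

print(by_score)
print(by_proba)
print(by_proba_wrong)
```

With the matching threshold, both representations give the same labels; only the mismatched case collapses to a single class.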
Is it worth wrapping this and the lines you mentioned, i.e.:
scikit-learn/sklearn/multiclass.py
Lines 429 to 434 in 70eebb9
if hasattr(self.estimators_[0], "decision_function") and is_classifier(
    self.estimators_[0]
):
    thresh = 0
else:
    thresh = 0.5
in a private `_threshold` property?
I think this can make the logic explicit and ease maintainability.
`OneVsOneClassifier` and `OneVsRestClassifier` do not share a common base class where a `_threshold` property makes sense. Also, the threshold is only used in the binary case.
Inspired by your suggestion, I think a helper function makes the logic clearer: 7372cee (#22604)
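For readers following along, a helper of this kind might look roughly like the sketch below. This is an illustration under stated assumptions, not the code in 7372cee: the name `binary_threshold` is hypothetical, and the `_estimator_type` attribute check is a simplified stand-in for scikit-learn's `is_classifier` so that the example runs without any dependencies.

```python
def binary_threshold(estimator):
    """Hypothetical helper: cut-off separating the two classes.

    Returns 0 for classifiers exposing decision_function (signed scores)
    and 0.5 otherwise (assumed to be predict_proba output).
    """
    # Stand-in for sklearn.base.is_classifier, kept dependency-free.
    is_clf = getattr(estimator, "_estimator_type", None) == "classifier"
    if hasattr(estimator, "decision_function") and is_clf:
        return 0
    return 0.5


class ScoreClassifier:
    """Toy classifier exposing only decision_function."""

    _estimator_type = "classifier"

    def decision_function(self, X):
        return [x - 1.0 for x in X]


class ProbaOnlyClassifier:
    """Toy classifier exposing only predict_proba."""

    _estimator_type = "classifier"

    def predict_proba(self, X):
        return [[0.5, 0.5] for _ in X]


print(binary_threshold(ScoreClassifier()))
print(binary_threshold(ProbaOnlyClassifier()))
```

A free function fits both meta-estimators without requiring a shared base class, which is the concern raised above.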
3290270 to 7372cee
Merging main to get rid of the doc-min-dependencies error about Python 3.7.
Merging, thanks!
…_proba (scikit-learn#22604) Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
* Updating to scipy 1.12
* Scikitlearn 1.0 is incompatible with scipy 1.12, so updating to 1.1
* Sometimes only a 1d array is returned.
* Removing testing kulsinski metric.
* simps is deprecated so switching to simpson
* Handle change from scikit-learn/scikit-learn#22199
* Improving test to classify better.
* Regold due to scikit change scikit-learn/scikit-learn#22604
* Regold PolyExponential files (had rel err of 1e-03 or less)
* Make timestep uniform for scipy update.
* Regolding because of changes in scipy 1.12
* Increase limits to improve convergence.
* Unpinning xarray and updating numpy
* Updating various libraries.
* Fix working with newer tensorflow.
* Values need to be switched to tuples for hstack in numpy 1.26
* Updating to new ray version.
* The deque size can be bigger in python 3.11
* Report difference in row lengths, instead of crashing OrderedCSVDiffer. Also report gold file name.
* Remove Fourier__signal_f__period10.0__phase This was either +pi or -pi semirandomly, so nolonger testing it.
* Regolding changes to ROM/TimeSeries/DMD/BOPDMD because of library changes.
* Support xarray 2024.7 and newer. Pre 2024.7 automatically squeeze()ed groupby results, so now need to explicitly call squeeze().
* Fixing long line.
* Increasing zero threshold because of change in libraries.
* Remove version from setuptools since ray updated.
* Optimizing persistence in BayesianMatyas.
* Switch OVO to use estimator that is not constantly zero.
* Use keepdims instead of try catch block.
* Updating to default using python 3.11
Reference Issues/PRs
Fixes #13617
What does this implement/fix? Explain your changes.
This PR uses the correct threshold if the inner estimator uses `predict_proba`.