Fix predict method for multiclass multioutput ensemble models #12834
sklearn/ensemble/forest.py
Outdated
@@ -547,7 +547,8 @@ def predict(self, X):
         else:
             n_samples = proba[0].shape[0]
-            predictions = np.zeros((n_samples, self.n_outputs_))
+            predictions = np.empty((n_samples, self.n_outputs_),
+                                   dtype='object')
Wouldn't it be better to have `dtype=self.classes_.dtype` or something?
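The two dtype options raised in this comment can be compared with a small NumPy-only sketch. Everything here is illustrative: `y`, `classes_`, `preds_object`, and `preds_typed` are hypothetical stand-ins, not the estimator's actual attributes. It assumes that, for a multioutput problem, the class labels are stored as a list with one array per output, so a label dtype would have to be read from one of those arrays rather than from the list itself.

```python
import numpy as np

# Hypothetical multioutput string targets, one column per output.
y = np.array([["red", "small"],
              ["green", "large"],
              ["red", "large"]])
n_samples, n_outputs = y.shape

# Assumed layout: one array of sorted class labels per output.
classes_ = [np.unique(y[:, k]) for k in range(n_outputs)]

# Option 1: a generic object array, which accepts any label type.
preds_object = np.empty((n_samples, n_outputs), dtype=object)

# Option 2: reuse the label dtype. This assumes every output shares
# one dtype, which holds here because all columns come from one array.
class_dtype = classes_[0].dtype
preds_typed = np.empty((n_samples, n_outputs), dtype=class_dtype)

for k in range(n_outputs):
    # Map each label back to its index in the sorted class array.
    idx = np.searchsorted(classes_[k], y[:, k])
    preds_object[:, k] = classes_[k][idx]
    preds_typed[:, k] = classes_[k][idx]
```

Both arrays hold the same labels; the typed variant keeps the original string dtype, while the object variant is more permissive but loses that information. If the outputs had different string widths, reusing the first array's dtype could truncate longer labels, so the shared-dtype assumption matters.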
     with np.errstate(divide="ignore"):
         proba = est.predict_proba(X_test)
         assert_equal(len(proba), 2)
With the adoption of pytest, we are phasing out use of test helpers `assert_equal`, `assert_true`, etc. Please use bare `assert` statements, e.g. `assert x == y`, `assert not x`, etc.
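As a rough illustration of the requested style, the helper calls map onto bare `assert` statements as sketched below. The test name `test_proba_length` and the `proba` value are hypothetical placeholders for the real estimator output, not code from this PR.

```python
def test_proba_length():
    # Hypothetical stand-in for est.predict_proba(X_test) on a
    # two-output problem: one entry per output.
    proba = [[0.2, 0.8], [0.6, 0.4]]

    # Instead of assert_equal(len(proba), 2):
    assert len(proba) == 2

    # Instead of assert_true(proba):
    assert proba

    # Instead of assert_false(len(proba) == 3):
    assert not len(proba) == 3
```

Under pytest, a failing bare `assert` reports the values of the compared expressions, which is why the dedicated helpers are no longer needed.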
Sorry for the delay! I committed a couple of changes to address the code review comments.
Thanks @elsander, LGTM!
LGTM!
Please add an entry to the change log at `doc/whats_new/v0.21.rst`. Like the other entries there, please reference this pull request with `:issue:` and credit yourself (and other contributors if applicable) with `:user:`.
Reference Issues/PRs
Fixes #12831.
What does this implement/fix? Explain your changes.
This PR fixes a bug where the `predict` method would fail for multiclass multioutput ensemble models if any of the dependent variables were strings. The underlying issue was that the `predict` output was preallocated with `np.zeros`, which creates a float array and therefore raises an error when string predictions are inserted. I replaced that call with a more dtype-agnostic call to `np.empty`.
Any other comments?
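The failure described in the PR description above can be reproduced outside the estimator with a minimal NumPy sketch. The names `classes_`, `proba`, and `fill_predictions` are hypothetical; the loop only approximates the general pattern of picking the most probable class per output and is not the library's actual code.

```python
import numpy as np

# Hypothetical class labels and per-output class probabilities
# for two samples and two outputs.
classes_ = [np.array(["green", "red"]), np.array(["large", "small"])]
proba = [np.array([[0.1, 0.9], [0.7, 0.3]]),
         np.array([[0.8, 0.2], [0.4, 0.6]])]
n_samples, n_outputs = proba[0].shape[0], len(proba)


def fill_predictions(predictions):
    # For each output, write the label of the most probable class.
    for k in range(n_outputs):
        best = np.argmax(proba[k], axis=1)
        predictions[:, k] = classes_[k].take(best, axis=0)
    return predictions


# A float array from np.zeros cannot hold string labels.
try:
    fill_predictions(np.zeros((n_samples, n_outputs)))
    zeros_failed = False
except ValueError:
    zeros_failed = True

# An object array from np.empty accepts them.
fixed = fill_predictions(np.empty((n_samples, n_outputs), dtype=object))
```

With numeric labels the `np.zeros` version would work, which is presumably why the problem only surfaced for string targets.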