ENH fast path for binary confusion matrix #15403
Conversation
@@ -879,12 +879,6 @@ def test_confusion_matrix_dtype():
    assert cm[0, 0] == 4294967295
    assert cm[1, 1] == 8589934590

-    # np.iinfo(np.int64).max should cause an overflow
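The removed assertion depended on overflow behavior, which differs between counting strategies. A minimal sketch of the underlying issue in plain NumPy (illustrative only, not the PR's code): accumulating in a fixed-width integer wraps around, while accumulating through float64 (as `np.bincount` does with weights) rounds instead.

```python
import numpy as np

big = np.iinfo(np.int64).max  # 9223372036854775807

# Path 1: accumulate in int64; overflow wraps (two's complement).
a = np.array([big, 1], dtype=np.int64)
wrapped = a.sum(dtype=np.int64)  # wraps past the maximum to a negative value

# Path 2: accumulate via float64; big cannot be represented exactly,
# so the result is rounded rather than wrapped.
w = np.array([big, 1], dtype=np.float64)
rounded = w.sum()  # a large positive float, not the wrapped value

print(wrapped, rounded)
```

Two implementations that are identical for in-range inputs therefore disagree once the counts overflow, so asserting a specific overflowed value is brittle.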
There might be a neater solution than removing this, but it turns out that different implementations give different results in the case of overflow.
Should the test instead check for the binary type and the exact selection conditions used in the optimized code, and only make these assertions when the fast path is excluded?
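One way such a guard could look (a sketch only; `takes_binary_fast_path` and its condition are illustrative, not the actual selection logic in this PR):

```python
import numpy as np

def takes_binary_fast_path(y_true, y_pred):
    # Illustrative guard: assume the optimized path applies only to
    # binary 0/1 labels; the overflow assertions would be skipped then.
    labels = np.union1d(np.unique(y_true), np.unique(y_pred))
    return labels.size <= 2 and np.isin(labels, [0, 1]).all()
```

The test could then run the overflow assertions only when this returns False.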
Thank you so much for the PR. One question: what about handling of the normalization case? I do not see anything that deals with it. Update: never mind. I must have been looking at another function or version; there is currently no normalize argument for the confusion matrix. One could theoretically be added, since it is easy enough to divide by np.sum.
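For reference, normalizing an already-computed confusion matrix by hand is a one-liner either way (plain NumPy; the matrix values below are made up for illustration):

```python
import numpy as np

cm = np.array([[5, 1],
               [2, 4]])

# Normalize over all entries (each cell is a fraction of all samples):
cm_all = cm / cm.sum()

# Normalize per true class (each row sums to 1):
cm_true = cm / cm.sum(axis=1, keepdims=True)
```

So the fast path itself does not need to know about normalization; it can stay a post-processing step.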
Looks like CI is unhappy anyway.
Perhaps instead of
Nice job, it works now; only the doc system has some sort of issue. I would consider putting
Please feel free to run some benchmarks, @GregoryMorse
Would using timeit between the old and new versions be sufficient?
Yes. Maybe check a few different affected cases.
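A timeit comparison along those lines might look like this. This is a sketch: `binary_fast_path` is an illustrative bincount-based counter standing in for the optimized path, and a real benchmark would time the old and new `confusion_matrix` side by side over the affected cases (binary labels, with and without sample weights, varying sizes).

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=100_000)
y_pred = rng.integers(0, 2, size=100_000)

def binary_fast_path(y_true, y_pred):
    # Count the four cells TN, FP, FN, TP in a single bincount pass.
    return np.bincount(2 * y_true + y_pred, minlength=4).reshape(2, 2)

elapsed = timeit.timeit(lambda: binary_fast_path(y_true, y_pred), number=100)
print(f"100 calls: {elapsed:.3f}s")
```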
sample_weight = np.asarray(sample_weight)
check_consistent_length(y_true, y_pred, sample_weight)
Suggested change:
sample_weight = np.asarray(sample_weight)
check_consistent_length(y_true, y_pred, sample_weight)
sample_weight = _check_sample_weight(sample_weight, X)
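`_check_sample_weight` (a private helper in `sklearn.utils.validation`) replaces the asarray-plus-length-check pair: it defaults to uniform weights when `sample_weight` is None and validates the shape in one step. A rough stand-in for what it does (simplified sketch, not the actual implementation):

```python
import numpy as np

def check_sample_weight_sketch(sample_weight, y, dtype=np.float64):
    # Simplified stand-in for sklearn.utils.validation._check_sample_weight.
    if sample_weight is None:
        # Default: every sample gets unit weight.
        return np.ones(len(y), dtype=dtype)
    sample_weight = np.asarray(sample_weight, dtype=dtype)
    if sample_weight.shape != (len(y),):
        raise ValueError("sample_weight has an incompatible shape")
    return sample_weight
```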
This PR needs a simple benchmark (as a comment/post here in GitHub), e.g. with
Some profiling done in #28578 showed that it's actually all the checks that dominate in
A patch for @GregoryMorse to benchmark.
Fixes #15388.