-
-
Notifications
You must be signed in to change notification settings - Fork 26k
[MRG+1] Solves integer overlow in mutual_info_score #10414
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG+1] Solves integer overlow in mutual_info_score #10414
Conversation
It don't think changing |
…into mutual_info_classif
@@ -602,6 +602,8 @@ def mutual_info_score(labels_true, labels_pred, contingency=None): | |||
contingency_nm = nz_val / contingency_sum | |||
# Don't need to calculate the full outer product, just for non-zeroes | |||
outer = pi.take(nzx) * pj.take(nzy) | |||
if np.any(outer < 0): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a comment that this checks for overflow
@@ -602,6 +602,8 @@ def mutual_info_score(labels_true, labels_pred, contingency=None): | |||
contingency_nm = nz_val / contingency_sum | |||
# Don't need to calculate the full outer product, just for non-zeroes | |||
outer = pi.take(nzx) * pj.take(nzy) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it worthwhile always casting to int64? Usually there should either be a small contingency matrix, or a sparse one, so I don't think memory is a big issue...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO, I don't think that's required. But happy to do that if there's some special advantage over the current proposed method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It merely avoids duplicated code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jnothman do you mean casting the contigency
variable to int64
? IIRC this is what I was thinking as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's certainly an option if contingency is sparse. I just meant removing the if line
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you want the outer
variable to be always converted to int64?
outer = pi.take(nzx).astype(np.int64) * pi.take(nzy).astype(np.int64)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that's my suggestion
…into mutual_info_classif
…into mutual_info_classif
np.repeat(0, 814), np.repeat(1, 39), | ||
np.repeat(0, 316), np.repeat(1, 20))) | ||
|
||
mutual_info_score(x.ravel(), y.ravel()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you forgot to assert the score to be sure that we don't have NaN
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, this check is only to be assured there's no overflow. Why do I check for a NaN?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Previously, it was already possible to call the function and it was resulting to nan (due to the overflow in the log). You should at least check that the output is finite to be sure that they is no overflow. Otherwise, your current test is also passing in the current master branch,
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test does not pass in 32 bit Python in master branch but passes otherwise. I'll soon add the suggested check.
@@ -601,7 +601,7 @@ def mutual_info_score(labels_true, labels_pred, contingency=None): | |||
log_contingency_nm = np.log(nz_val) | |||
contingency_nm = nz_val / contingency_sum | |||
# Don't need to calculate the full outer product, just for non-zeroes | |||
outer = pi.take(nzx) * pj.take(nzy) | |||
outer = pi.take(nzx).astype(np.int64) * pj.take(nzy).astype(np.int64) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would cast before
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pi = np.ravel(contingency.sum(axis=1, dtype=np.int64))
pj = np.ravel(contingency.sum(axis=0, dtype=np.int64))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This does not work for all NumPy versions. And not worth to backport the feature for one line. Hence switched back to casting it later.
def test_int_overflow_mutual_info_score(): | ||
# Test overflow in mutual_info_classif | ||
x = np.concatenate((np.repeat(1, 52632 + 2529), np.repeat(2, 14660+793), | ||
np.repeat(3, 3271+204), np.repeat(4, 814+39), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
put space between the arithmetic signs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that this is easier to read np.array([0] * 10000 + [1] * 10000 + [2] * 10000)
than the call with np.repeat
.
You can also make the sum directly instead of making the addition.
…into mutual_info_classif
…into mutual_info_classif
It is only missing an entry in the what's new |
…into mutual_info_classif
Thanks @thechargedneutron |
Reference Issues/PRs
Fixes #9772