-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
FIX Fixes issue with exactly_zero_info_score #19179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
240de07
637681c
d6fa33a
98e9480
8178a8b
4065ad4
e4d17f1
e311d89
6600968
d261ff1
f67c5cf
ea3094a
d1e3a49
4572831
0e705f6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ def expected_mutual_information(contingency, int n_samples): | |
cdef DOUBLE N, gln_N, emi, term2, term3, gln | ||
cdef np.ndarray[DOUBLE] gln_a, gln_b, gln_Na, gln_Nb, gln_nij, log_Nnij | ||
cdef np.ndarray[DOUBLE] nijs, term1 | ||
cdef np.ndarray[DOUBLE, ndim=2] log_ab_outer | ||
cdef np.ndarray[DOUBLE] log_a, log_b | ||
cdef np.ndarray[np.int32_t] a, b | ||
#cdef np.ndarray[int, ndim=2] start, end | ||
R, C = contingency.shape | ||
|
@@ -37,10 +37,10 @@ def expected_mutual_information(contingency, int n_samples): | |
# term1 is nij / N | ||
term1 = nijs / N | ||
# term2 is log((N*nij) / (a * b)) == log(N * nij) - log(a * b) | ||
# term2 uses the outer product | ||
log_ab_outer = np.log(a)[:, np.newaxis] + np.log(b) | ||
# term2 uses N * nij | ||
log_Nnij = np.log(N * nijs) | ||
log_a = np.log(a) | ||
log_b = np.log(b) | ||
# term2 uses log(N * nij) = log(N) + log(nij) | ||
log_Nnij = np.log(N) + np.log(nijs) | ||
# term3 is large, and involved many factorials. Calculate these in log | ||
# space to stop overflows. | ||
gln_a = gammaln(a + 1) | ||
|
@@ -54,12 +54,12 @@ def expected_mutual_information(contingency, int n_samples): | |
start = np.maximum(start, 1) | ||
end = np.minimum(np.resize(a, (C, R)).T, np.resize(b, (R, C))) + 1 | ||
# emi itself is a summation over the various values. | ||
emi = 0 | ||
emi = 0.0 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. emi is defined as `DOUBLE`. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It is more explicit if you did not read the definition :) |
||
cdef Py_ssize_t i, j, nij | ||
for i in range(R): | ||
for j in range(C): | ||
for nij in range(start[i,j], end[i,j]): | ||
term2 = log_Nnij[nij] - log_ab_outer[i,j] | ||
term2 = log_Nnij[nij] - log_a[i] - log_b[j] | ||
# Numerators are positive, denominators are negative. | ||
gln = (gln_a[i] + gln_b[j] + gln_Na[i] + gln_Nb[j] | ||
- gln_N - gln_nij[nij] - lgamma(a[i] - nij + 1) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -795,6 +795,7 @@ def mutual_info_score(labels_true, labels_pred, *, contingency=None): | |
log_outer = -np.log(outer) + log(pi.sum()) + log(pj.sum()) | ||
mi = (contingency_nm * (log_contingency_nm - log(contingency_sum)) + | ||
contingency_nm * log_outer) | ||
mi = np.where(np.abs(mi) < np.finfo(mi.dtype).eps, 0.0, mi) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fair enough, but I wonder what other places we need to be doing it! |
||
return np.clip(mi.sum(), 0.0, None) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Slightly more memory efficient because we would not need to create the 2d array anymore.