Skip to content

Commit 2462b40

Browse files
Christopher Nguyenmateiz
authored andcommitted
In the current code, when both partitions happen to have zero-length, the return mean will be NaN.
Consequently, the result of mean after reducing over all partitions will also be NaN, which is not correct if there are partitions with non-zero length. This patch fixes this issue.
1 parent 5539549 commit 2462b40

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

core/src/main/scala/spark/util/StatCounter.scala

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,23 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
3737
if (other == this) {
3838
merge(other.copy()) // Avoid overwriting fields in a weird order
3939
} else {
40-
val delta = other.mu - mu
41-
if (other.n * 10 < n) {
42-
mu = mu + (delta * other.n) / (n + other.n)
43-
} else if (n * 10 < other.n) {
44-
mu = other.mu - (delta * n) / (n + other.n)
45-
} else {
46-
mu = (mu * n + other.mu * other.n) / (n + other.n)
40+
if (n == 0) {
41+
mu = other.mu
42+
m2 = other.m2
43+
n = other.n
44+
} else if (other.n != 0) {
45+
val delta = other.mu - mu
46+
if (other.n * 10 < n) {
47+
mu = mu + (delta * other.n) / (n + other.n)
48+
} else if (n * 10 < other.n) {
49+
mu = other.mu - (delta * n) / (n + other.n)
50+
} else {
51+
mu = (mu * n + other.mu * other.n) / (n + other.n)
52+
}
53+
m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
54+
n += other.n
4755
}
48-
m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
49-
n += other.n
50-
this
56+
this
5157
}
5258
}
5359

0 commit comments

Comments
 (0)