Commit 3d24281

Backport sampling fixes from dev (suggested by Henry Milner)

1 parent 915ab97 commit 3d24281

File tree: 5 files changed, +54 -45 lines

core/src/main/scala/spark/Partitioner.scala
Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
       Array()
     } else {
       val rddSize = rdd.count()
-      val maxSampleSize = partitions * 10.0
+      val maxSampleSize = partitions * 20.0
       val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
       val rddSample = rdd.sample(true, frac, 1).map(_._1).collect().sortWith(_ < _)
       if (rddSample.length == 0) {
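
The one functional change above doubles the number of keys RangePartitioner samples per partition (from ~10 to ~20) when estimating range bounds, while still capping the fraction at 1.0 for small RDDs. A minimal standalone sketch of the fraction math (the helper name sampleFraction is ours, not Spark's):

    // Aim for ~20 sampled keys per partition, but never sample more than
    // the whole RDD (fraction 1.0); math.max guards against rddSize == 0.
    def sampleFraction(partitions: Int, rddSize: Long): Double = {
      val maxSampleSize = partitions * 20.0
      math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
    }

    // e.g. sampleFraction(8, 1000000L) == 1.6e-4, about 160 sampled keys.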

core/src/main/scala/spark/RDD.scala
Lines changed: 14 additions & 15 deletions

@@ -97,32 +97,31 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     var multiplier = 3.0
     var initialCount = count()
     var maxSelected = 0
-
-    if (initialCount > Integer.MAX_VALUE) {
-      maxSelected = Integer.MAX_VALUE
+
+    if (initialCount > Integer.MAX_VALUE - 1) {
+      maxSelected = Integer.MAX_VALUE - 1
     } else {
       maxSelected = initialCount.toInt
     }
-
+
     if (num > initialCount) {
       total = maxSelected
-      fraction = Math.min(multiplier * (maxSelected + 1) / initialCount, 1.0)
+      fraction = math.min(multiplier * (maxSelected + 1) / initialCount, 1.0)
     } else if (num < 0) {
       throw(new IllegalArgumentException("Negative number of elements requested"))
     } else {
-      fraction = Math.min(multiplier * (num + 1) / initialCount, 1.0)
-      total = num.toInt
+      fraction = math.min(multiplier * (num + 1) / initialCount, 1.0)
+      total = num
     }
-
-    var samples = this.sample(withReplacement, fraction, seed).collect()
-
+
+    val rand = new Random(seed)
+    var samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
+
     while (samples.length < total) {
-      samples = this.sample(withReplacement, fraction, seed).collect()
+      samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
     }
-
-    val arr = samples.take(total)
-
-    return arr
+
+    Utils.randomizeInPlace(samples, rand).take(total)
   }

   def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
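
Three things change in takeSample: maxSelected is capped at Integer.MAX_VALUE - 1 so that maxSelected + 1 cannot overflow Int; each retry draws a fresh seed from one locally seeded Random (previously the loop reused the same seed, so an undersized sample was redrawn identically forever); and the oversampled result is shuffled before taking, so the kept elements are not biased toward early partitions. A self-contained sketch of that pattern over a plain collection; TakeSampleSketch and attempt are our names, and scala.util.Random.shuffle stands in for Utils.randomizeInPlace:

    import java.util.Random

    object TakeSampleSketch {
      def takeSample[T](data: IndexedSeq[T], num: Int, seed: Int): IndexedSeq[T] = {
        require(num >= 0, "Negative number of elements requested")
        val rand = new Random(seed)            // seeded once, reused below
        val total = math.min(num, data.length) // can't return more than we have
        val fraction = math.min(3.0 * (num + 1) / math.max(data.length, 1), 1.0)

        // Bernoulli-sample each element with probability `fraction`, reseeding
        // from `rand` on every attempt so retries draw different samples.
        def attempt(): IndexedSeq[T] = {
          val r = new Random(rand.nextInt)
          data.filter(_ => r.nextDouble <= fraction)
        }

        var samples = attempt()
        while (samples.length < total) { // undershot: resample with a new seed
          samples = attempt()
        }
        // Shuffle before taking, so we don't keep only the earliest elements.
        new scala.util.Random(rand.nextInt).shuffle(samples).take(total)
      }
    }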
core/src/main/scala/spark/SampledRDD.scala
Lines changed: 16 additions & 12 deletions

@@ -1,9 +1,11 @@
 package spark

 import java.util.Random
+import cern.jet.random.Poisson
+import cern.jet.random.engine.DRand

 class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable {
-  override val index = prev.index
+  override val index: Int = prev.index
 }

 class SampledRDD[T: ClassManifest](
@@ -15,7 +17,7 @@ class SampledRDD[T: ClassManifest](

   @transient
   val splits_ = {
-    val rg = new Random(seed);
+    val rg = new Random(seed)
     prev.splits.map(x => new SampledRDDSplit(x, rg.nextInt))
   }

@@ -28,19 +30,21 @@ class SampledRDD[T: ClassManifest](

   override def compute(splitIn: Split) = {
     val split = splitIn.asInstanceOf[SampledRDDSplit]
-    val rg = new Random(split.seed);
-    // Sampling with replacement (TODO: use reservoir sampling to make this more efficient?)
     if (withReplacement) {
-      val oldData = prev.iterator(split.prev).toArray
-      val sampleSize = (oldData.size * frac).ceil.toInt
-      val sampledData = {
-        // all of oldData's indices are candidates, even if sampleSize < oldData.size
-        for (i <- 1 to sampleSize)
-          yield oldData(rg.nextInt(oldData.size))
+      // For large datasets, the expected number of occurrences of each element in a sample with
+      // replacement is Poisson(frac). We use that to get a count for each element.
+      val poisson = new Poisson(frac, new DRand(split.seed))
+      prev.iterator(split.prev).flatMap { element =>
+        val count = poisson.nextInt()
+        if (count == 0) {
+          Iterator.empty // Avoid object allocation when we return 0 items, which is quite often
+        } else {
+          Iterator.fill(count)(element)
+        }
       }
-      sampledData.iterator
     } else { // Sampling without replacement
-      prev.iterator(split.prev).filter(x => (rg.nextDouble <= frac))
+      val rand = new Random(split.seed)
+      prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac))
     }
   }
 }
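
The rewrite replaces "materialize the whole partition, then draw random indices" with a single streaming pass: in a with-replacement sample at rate frac, the number of times any given element appears is (for large data) approximately Poisson(frac), so the code draws a per-element count and emits the element that many times. A sketch of the same trick, with Knuth's inversion algorithm standing in for colt's cern.jet.random.Poisson so it runs without the colt dependency (PoissonSampleSketch and poissonDraw are our names):

    import java.util.Random

    object PoissonSampleSketch {
      // Knuth's inversion method: multiply uniforms until the product drops
      // below e^(-mean). Fine for small means like a sampling fraction.
      def poissonDraw(mean: Double, rand: Random): Int = {
        val limit = math.exp(-mean)
        var k = 0
        var p = 1.0
        do {
          k += 1
          p *= rand.nextDouble()
        } while (p > limit)
        k - 1
      }

      // One streaming pass: emit each element count ~ Poisson(frac) times.
      def sampleWithReplacement[T](it: Iterator[T], frac: Double, seed: Int): Iterator[T] = {
        val rand = new Random(seed)
        it.flatMap { element =>
          val count = poissonDraw(frac, rand)
          if (count == 0) Iterator.empty else Iterator.fill(count)(element)
        }
      }
    }

The payoff is memory: the old path called prev.iterator(split.prev).toArray, buffering the entire partition, while the new one never holds more than one element at a time.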

core/src/main/scala/spark/Utils.scala
Lines changed: 20 additions & 14 deletions

@@ -5,8 +5,7 @@ import java.net.InetAddress
 import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}

 import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
-import java.util.{Locale, UUID}
+import java.util.{Locale, UUID, Random}

 /**
  * Various utility methods used by Spark.
@@ -104,20 +103,27 @@ object Utils {
     }
   }

-  // Shuffle the elements of a collection into a random order, returning the
-  // result in a new collection. Unlike scala.util.Random.shuffle, this method
-  // uses a local random number generator, avoiding inter-thread contention.
-  def randomize[T](seq: TraversableOnce[T]): Seq[T] = {
-    val buf = new ArrayBuffer[T]()
-    buf ++= seq
-    val rand = new Random()
-    for (i <- (buf.size - 1) to 1 by -1) {
+  /**
+   * Shuffle the elements of a collection into a random order, returning the
+   * result in a new collection. Unlike scala.util.Random.shuffle, this method
+   * uses a local random number generator, avoiding inter-thread contention.
+   */
+  def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = {
+    randomizeInPlace(seq.toArray)
+  }
+
+  /**
+   * Shuffle the elements of an array into a random order, modifying the
+   * original array. Returns the original array.
+   */
+  def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = {
+    for (i <- (arr.length - 1) to 1 by -1) {
       val j = rand.nextInt(i)
-      val tmp = buf(j)
-      buf(j) = buf(i)
-      buf(i) = tmp
+      val tmp = arr(j)
+      arr(j) = arr(i)
+      arr(i) = tmp
     }
-    buf
+    arr
   }

   /**
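
randomize now delegates to the new randomizeInPlace, which takes the caller's Random as an optional parameter (defaulting to a fresh one); that parameter is exactly what lets takeSample above pass in its own seeded generator. A quick hypothetical usage, assuming this Utils object is on the classpath:

    // Copying shuffle: builds an Array from the sequence and shuffles it.
    val shuffledCopy: Seq[Int] = Utils.randomize(1 to 10)

    // In-place shuffle with a caller-supplied seeded generator (reproducible).
    val arr = Array("a", "b", "c", "d")
    val rand = new java.util.Random(42)
    Utils.randomizeInPlace(arr, rand) // returns the same, now reordered, array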

project/SparkBuild.scala
Lines changed: 3 additions & 3 deletions

@@ -58,7 +58,8 @@ object SparkBuild extends Build {
       "com.google.protobuf" % "protobuf-java" % "2.4.1",
       "de.javakaffee" % "kryo-serializers" % "0.9",
       "org.jboss.netty" % "netty" % "3.2.6.Final",
-      "it.unimi.dsi" % "fastutil" % "6.4.2"
+      "it.unimi.dsi" % "fastutil" % "6.4.2",
+      "colt" % "colt" % "1.2.0"
     )
   ) ++ assemblySettings ++ Seq(test in assembly := {})

@@ -68,8 +69,7 @@ object SparkBuild extends Build {
   ) ++ assemblySettings ++ Seq(test in assembly := {})

   def examplesSettings = sharedSettings ++ Seq(
-    name := "spark-examples",
-    libraryDependencies += "colt" % "colt" % "1.2.0"
+    name := "spark-examples"
   )

   def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")
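
Since SampledRDD now imports cern.jet.random.Poisson, the colt artifact must sit on core's classpath rather than only the examples project's, hence the move above. Expressed as a single sbt setting (same coordinates the diff adds to core):

    libraryDependencies += "colt" % "colt" % "1.2.0"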
