Commit 66e5362

Merge pull request apache#175 from squito/collection_accumulators

add accumulators for mutable collections, with correct typing!

2 parents 607b8ff + 1490d09

File tree: 3 files changed (+51, -5 lines)

core/src/main/scala/spark/Accumulators.scala (13 additions, 0 deletions)

@@ -3,6 +3,7 @@ package spark
 import java.io._
 
 import scala.collection.mutable.Map
+import collection.generic.Growable
 
 class Accumulable[T,R] (
     @transient initialValue: T,
@@ -108,6 +109,18 @@ trait AccumulableParam[R,T] extends Serializable {
   def zero(initialValue: R): R
 }
 
+class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T] extends AccumulableParam[R,T] {
+  def addAccumulator(growable: R, elem: T): R = {
+    growable += elem
+    growable
+  }
+  def addInPlace(t1: R, t2: R): R = {
+    t1 ++= t2
+    t1
+  }
+  def zero(initialValue: R) = initialValue
+}
+
 // TODO: The multi-thread support in accumulators is kind of lame; check
 // if there's a more intuitive way of doing it right
 private object Accumulators {
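
The new param relies only on the Growable/TraversableOnce contract: every standard mutable collection mixes in both (and is Serializable), so the view bound is satisfied by the identity conversion, and += / ++= are all that addAccumulator and addInPlace need. A minimal standalone sketch of that contract; the demo object and values are illustrative, not part of this commit:

import scala.collection.mutable
import spark.GrowableAccumulableParam

object GrowableContractDemo {
  def main(args: Array[String]) {
    val param = new GrowableAccumulableParam[mutable.HashSet[Int], Int]
    val acc = param.zero(mutable.HashSet[Int]())    // zero simply hands back the initial value
    param.addAccumulator(acc, 1)                    // += : add one element, as a task would
    param.addInPlace(acc, mutable.HashSet(2, 3))    // ++= : merge another task's partial result
    assert(acc == mutable.HashSet(1, 2, 3))
  }
}

Note the view bound (<%): any R that can be viewed as Growable[T] with TraversableOnce[T] with Serializable qualifies, which the standard mutable Set, Map, and Buffer types satisfy directly.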

core/src/main/scala/spark/SparkContext.scala (11 additions, 1 deletion)

@@ -4,7 +4,6 @@ import java.io._
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.actors.remote.RemoteActor
-import scala.collection.mutable.ArrayBuffer
 
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.conf.Configuration
@@ -31,6 +30,7 @@ import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
 import org.apache.mesos.MesosNativeLibrary
 
 import spark.broadcast._
+import collection.generic.Growable
 
 class SparkContext(
     master: String,
@@ -253,6 +253,16 @@ class SparkContext(
   def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
     new Accumulable(initialValue, param)
 
+  /**
+   * Create an accumulator from a "mutable collection" type.
+   *
+   * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
+   * standard mutable collections. So you can use this with mutable Map, Set, etc.
+   */
+  def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = {
+    val param = new GrowableAccumulableParam[R,T]
+    new Accumulable(initialValue, param)
+  }
 
   // Keep around a weak hash map of values to Cached versions?
   def broadcast[T](value: T) = Broadcast.getBroadcastFactory.newBroadcast[T] (value, isLocal)
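
A hedged usage sketch of the new method (the object name, master string, and numbers below are illustrative, not from the commit). Because R is inferred as the concrete collection type, reading .value back on the driver needs no casts:

import scala.collection.mutable
import spark.SparkContext

object AccumulableCollectionExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "example")
    // Inferred as Accumulable[mutable.HashSet[Int], Int] -- the "correct typing" of the commit title.
    val evens = sc.accumulableCollection(mutable.HashSet[Int]())
    sc.parallelize(1 to 100).foreach { x => if (x % 2 == 0) evens += x }
    println(evens.value.size)   // 50, read on the driver once the job finishes
    sc.stop()
  }
}

Contrast with accumulable just above it, which needs an explicit or implicit AccumulableParam; here the GrowableAccumulableParam is constructed for you.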

core/src/test/scala/spark/AccumulatorSuite.scala (27 additions, 4 deletions)

@@ -3,9 +3,6 @@ package spark
 import org.scalatest.FunSuite
 import org.scalatest.matchers.ShouldMatchers
 import collection.mutable
-import java.util.Random
-import scala.math.exp
-import scala.math.signum
 import spark.SparkContext._
 
 class AccumulatorSuite extends FunSuite with ShouldMatchers {
@@ -79,6 +76,32 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers {
     }
   }
 
+  test ("collection accumulators") {
+    val maxI = 1000
+    for (nThreads <- List(1, 10)) {
+      // test single- & multi-threaded
+      val sc = new SparkContext("local[" + nThreads + "]", "test")
+      val setAcc = sc.accumulableCollection(mutable.HashSet[Int]())
+      val bufferAcc = sc.accumulableCollection(mutable.ArrayBuffer[Int]())
+      val mapAcc = sc.accumulableCollection(mutable.HashMap[Int,String]())
+      val d = sc.parallelize((1 to maxI) ++ (1 to maxI))
+      d.foreach {
+        x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)}
+      }
+
+      // NOTE that this is typed correctly -- no casts necessary
+      setAcc.value.size should be (maxI)
+      bufferAcc.value.size should be (2 * maxI)
+      mapAcc.value.size should be (maxI)
+      for (i <- 1 to maxI) {
+        setAcc.value should contain(i)
+        bufferAcc.value should contain(i)
+        mapAcc.value should contain (i -> i.toString)
+      }
+      sc.stop()
+    }
+  }
+
   test ("localValue readable in tasks") {
     import SetAccum._
     val maxI = 1000
@@ -94,4 +117,4 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers {
   }
 }
 
-}
+}
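
Why the expected sizes in the test differ: the input holds each of 1 to maxI twice, so the ArrayBuffer, whose ++= keeps duplicates, ends with 2 * maxI elements, while the HashSet and HashMap dedupe by value/key and end with maxI. A small illustrative sketch of the merge addInPlace performs on two partial results (values made up for the example):

import scala.collection.mutable

val p1 = mutable.HashSet(1, 2)      // partial result from one task
val p2 = mutable.HashSet(2, 3)      // partial result from another
p1 ++= p2                           // set semantics dedupe: HashSet(1, 2, 3)

val b1 = mutable.ArrayBuffer(1, 2)
val b2 = mutable.ArrayBuffer(2, 3)
b1 ++= b2                           // buffer semantics keep duplicates: ArrayBuffer(1, 2, 2, 3)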
