
Commit 66d7066

Let the reducer retry if a fetch fails before reading all records
1 parent 588120c commit 66d7066

File tree

core/src/main/scala/spark/HttpServer.scala
core/src/main/scala/spark/ShuffleMapTask.scala
core/src/main/scala/spark/SimpleShuffleFetcher.scala

3 files changed: +43 -22 lines changed

core/src/main/scala/spark/HttpServer.scala

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@ class HttpServer(resourceBase: File) extends Logging {
     server = new Server(0)
     val threadPool = new QueuedThreadPool
     threadPool.setDaemon(true)
+    threadPool.setMinThreads(System.getProperty("spark.http.minThreads", "8").toInt)
     server.setThreadPool(threadPool)
     val resHandler = new ResourceHandler
     resHandler.setResourceBase(resourceBase.getAbsolutePath)
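
The new minimum comes from a plain Java system property that is read once, when the QueuedThreadPool is constructed. A minimal usage sketch, assuming a caller that wants a higher floor than the default of 8 (the value 16 and the object name are illustrative, not part of the commit):

    // Hypothetical caller; only the property name "spark.http.minThreads"
    // and its default of "8" come from the commit.
    object HttpThreadsExample {
      def main(args: Array[String]): Unit = {
        // Must be set before HttpServer.start(), because the property is read
        // only once, while the Jetty thread pool is being built.
        System.setProperty("spark.http.minThreads", "16")
        val minThreads = System.getProperty("spark.http.minThreads", "8").toInt
        println("Jetty pool will keep at least " + minThreads + " threads") // 16
      }
    }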

core/src/main/scala/spark/ShuffleMapTask.scala

Lines changed: 1 addition & 0 deletions

@@ -39,6 +39,7 @@ class ShuffleMapTask(
     for (i <- 0 until numOutputSplits) {
       val file = SparkEnv.get.shuffleManager.getOutputFile(dep.shuffleId, partition, i)
       val out = ser.outputStream(new FastBufferedOutputStream(new FileOutputStream(file)))
+      out.writeObject(buckets(i).size)
       val iter = buckets(i).entrySet().iterator()
       while (iter.hasNext()) {
         val entry = iter.next()
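
This one-line change writes the bucket's record count as the first object in each map output file; the fetcher below uses that header to tell a complete stream from a truncated one. A self-contained sketch of the resulting layout, assuming plain JDK object streams as a stand-in for Spark's pluggable serializer (names here are illustrative):

    import java.io._

    // Count-then-records layout: the header says how many (K, V) pairs follow.
    object CountHeaderLayout {
      def write(file: File, records: Seq[(String, Int)]): Unit = {
        val out = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(file)))
        try {
          out.writeObject(records.size)               // header: total record count
          for (pair <- records) out.writeObject(pair) // then the pairs themselves
        } finally out.close()
      }

      def read(file: File): Unit = {
        val in = new ObjectInputStream(new BufferedInputStream(new FileInputStream(file)))
        try {
          val total = in.readObject().asInstanceOf[Int]
          var n = 0
          while (n < total) { // a clean stream yields exactly `total` pairs before EOF
            val (k, v) = in.readObject().asInstanceOf[(String, Int)]
            println(k + " -> " + v)
            n += 1
          }
        } finally in.close()
      }
    }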

core/src/main/scala/spark/SimpleShuffleFetcher.scala

Lines changed: 41 additions & 22 deletions

@@ -19,32 +19,51 @@ class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
     }
     for ((serverUri, inputIds) <- Utils.randomize(splitsByUri)) {
       for (i <- inputIds) {
-        var numRecords = 0
-        try {
-          val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
-          // TODO: multithreaded fetch
-          // TODO: would be nice to retry multiple times
-          val inputStream = ser.inputStream(
-            new FastBufferedInputStream(new URL(url).openStream()))
+        val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
+        var totalRecords = -1
+        var recordsProcessed = 0
+        var tries = 0
+        while (totalRecords == -1 || recordsProcessed < totalRecords) {
+          tries += 1
+          if (tries > 4) {
+            // We've tried four times to get this data but we've had trouble; let's just declare
+            // a failed fetch
+            logError("Failed to fetch " + url + " four times; giving up")
+            throw new FetchFailedException(serverUri, shuffleId, i, reduceId, null)
+          }
+          var recordsRead = 0
           try {
-            while (true) {
-              val pair = inputStream.readObject().asInstanceOf[(K, V)]
-              func(pair._1, pair._2)
-              numRecords += 1
+            val inputStream = ser.inputStream(
+              new FastBufferedInputStream(new URL(url).openStream()))
+            try {
+              totalRecords = inputStream.readObject().asInstanceOf[Int]
+              logInfo("Total records to read from " + url + ": " + totalRecords)
+              while (true) {
+                val pair = inputStream.readObject().asInstanceOf[(K, V)]
+                if (recordsRead <= recordsProcessed) {
+                  func(pair._1, pair._2)
+                  recordsProcessed += 1
+                }
+                recordsRead += 1
+              }
+            } finally {
+              inputStream.close()
+            }
+          } catch {
+            case e: EOFException => {
+              logInfo("Reduce %s got %s records from map %s before EOF".format(
+                reduceId, recordsRead, i))
+              if (recordsRead < totalRecords) {
+                logInfo("Retrying because we needed " + totalRecords + " in total!")
+              }
+            }
+            case other: Exception => {
+              logError("Fetch failed", other)
+              throw new FetchFailedException(serverUri, shuffleId, i, reduceId, other)
             }
-          } finally {
-            inputStream.close()
-          }
-        } catch {
-          case e: EOFException => {
-            // We currently assume EOF means we read the whole thing
-            logInfo("Reduce %s got %s records from map %s".format(reduceId, numRecords, i))
-          }
-          case other: Exception => {
-            logError("Fetch failed", other)
-            throw new FetchFailedException(serverUri, shuffleId, i, reduceId, other)
           }
         }
+        logInfo("Fetched all " + totalRecords + " records successfully")
       }
     }
   }
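
The heart of the change: the fetcher first reads the record count written by ShuffleMapTask, then treats an EOFException before that many records as a recoverable failure. It re-opens the stream, skips the records it has already handed to func, and only throws FetchFailedException after four attempts. A self-contained sketch of that retry-and-skip idea, under simplifying assumptions (the total is passed in rather than read from a stream header, and a flaky in-memory source stands in for the HTTP fetch; all names are illustrative):

    import java.io.EOFException

    // Sketch of retry-with-skip: `process` runs exactly once per record even
    // if the source dies mid-stream and has to be replayed from the start.
    object RetryFetchSketch {
      def fetchWithRetry[T](read: () => Iterator[T], total: Int, process: T => Unit): Unit = {
        var recordsProcessed = 0
        var tries = 0
        while (recordsProcessed < total) {
          tries += 1
          if (tries > 4)
            throw new RuntimeException("failed to fetch after 4 tries; giving up")
          var recordsRead = 0
          try {
            for (record <- read()) {
              // On a replay, skip everything processed during earlier attempts.
              if (recordsRead >= recordsProcessed) {
                process(record)
                recordsProcessed += 1
              }
              recordsRead += 1
            }
          } catch {
            case _: EOFException => () // truncated stream: loop around and retry
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val data = 1 to 10
        var attempts = 0
        // First attempt dies after six records; the second delivers everything.
        def flakyRead(): Iterator[Int] = {
          attempts += 1
          if (attempts == 1)
            data.iterator.map(x => if (x > 6) throw new EOFException("stream died") else x)
          else
            data.iterator
        }
        fetchWithRetry(() => flakyRead(), data.size, (n: Int) => println("processed " + n))
        // Prints each of 1..10 exactly once despite the mid-stream failure.
      }
    }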
