
Commit 64cc62c

JoshRosen authored and marmbrus committed
[SPARK-10403] Allow UnsafeRowSerializer to work with tungsten-sort ShuffleManager
This patch attempts to fix an issue where Spark SQL's UnsafeRowSerializer was incompatible with the `tungsten-sort` ShuffleManager.

Author: Josh Rosen <joshrosen@databricks.com>

Closes apache#8873 from JoshRosen/SPARK-10403.

(cherry picked from commit a182080)
Signed-off-by: Michael Armbrust <michael@databricks.com>
1 parent 6c6cadb commit 64cc62c
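
The substance of the patch: instead of writing a `-1` (EOF) sentinel when a serialization stream is closed and scanning for it on the read side, the deserializer now detects end-of-stream by catching the EOFException that DataInputStream.readInt() throws once the input is exhausted. Below is a minimal self-contained sketch of that pattern, assuming nothing beyond the JDK; the object and helper names are illustrative, not taken from the patch.

import java.io.{ByteArrayInputStream, DataInputStream, EOFException}

object EofDetectionSketch extends App {
  val EOF: Int = -1

  // Read the next record length, translating stream exhaustion into the
  // EOF marker: DataInputStream.readInt() throws EOFException once fewer
  // than four bytes remain.
  def readSize(dIn: DataInputStream): Int =
    try dIn.readInt()
    catch { case _: EOFException => dIn.close(); EOF }

  // Two 4-byte big-endian ints and nothing after them; no sentinel written.
  val dIn = new DataInputStream(
    new ByteArrayInputStream(Array[Byte](0, 0, 0, 5, 0, 0, 0, 9)))
  println(readSize(dIn)) // 5
  println(readSize(dIn)) // 9
  println(readSize(dIn)) // -1: stream exhausted
}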

File tree

2 files changed: +27 −18 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala

Lines changed: 11 additions & 11 deletions
@@ -45,16 +45,9 @@ private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with S
 }
 
 private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
-
-  /**
-   * Marks the end of a stream written with [[serializeStream()]].
-   */
-  private[this] val EOF: Int = -1
-
   /**
    * Serializes a stream of UnsafeRows. Within the stream, each record consists of a record
    * length (stored as a 4-byte integer, written high byte first), followed by the record's bytes.
-   * The end of the stream is denoted by a record with the special length `EOF` (-1).
    */
   override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
     private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
@@ -92,7 +85,6 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
 
     override def close(): Unit = {
       writeBuffer = null
-      dOut.writeInt(EOF)
       dOut.close()
     }
@@ -104,12 +96,20 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
       private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
       private[this] var row: UnsafeRow = new UnsafeRow()
       private[this] var rowTuple: (Int, UnsafeRow) = (0, row)
+      private[this] val EOF: Int = -1
 
       override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = {
         new Iterator[(Int, UnsafeRow)] {
-          private[this] var rowSize: Int = dIn.readInt()
-          if (rowSize == EOF) dIn.close()
 
+          private[this] def readSize(): Int = try {
+            dIn.readInt()
+          } catch {
+            case e: EOFException =>
+              dIn.close()
+              EOF
+          }
+
+          private[this] var rowSize: Int = readSize()
           override def hasNext: Boolean = rowSize != EOF
 
           override def next(): (Int, UnsafeRow) = {
@@ -118,7 +118,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
             }
            ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
            row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize)
-           rowSize = dIn.readInt() // read the next row's size
+           rowSize = readSize()
            if (rowSize == EOF) { // We are returning the last row in this stream
              dIn.close()
              val _rowTuple = rowTuple
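
Note the asymmetry in the fix: `EOF` survives on the read side purely as the return value of `readSize()`, while the write side no longer emits any terminator at all. The commit message doesn't spell out the failure mode, but a plausible reading is that a per-stream sentinel is only safe when each stream is read in isolation; if a shuffle manager concatenates several serialized streams byte-for-byte (as sort-based shuffles may do when merging spills), the first stream's `-1` lands mid-stream and truncates reading. A small sketch of that hazard, using made-up record sizes rather than anything from the patch:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object SentinelCollisionSketch extends App {
  // One serialized stream of record lengths, terminated by a -1 sentinel
  // (the pre-patch framing).
  def streamWithSentinel(sizes: Seq[Int]): Array[Byte] = {
    val baos = new ByteArrayOutputStream()
    val out = new DataOutputStream(baos)
    sizes.foreach(out.writeInt)
    out.writeInt(-1) // per-stream EOF sentinel
    out.flush()
    baos.toByteArray
  }

  // Concatenate two such streams byte-for-byte, as a merging shuffle might.
  val merged = streamWithSentinel(Seq(5)) ++ streamWithSentinel(Seq(9))
  val in = new DataInputStream(new ByteArrayInputStream(merged))

  // A sentinel-based reader stops at the first -1 and never sees the 9.
  Iterator.continually(in.readInt()).takeWhile(_ != -1).foreach(println) // prints: 5
}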

sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala

Lines changed: 16 additions & 7 deletions
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.execution
 
-import java.io.{File, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}
+import java.io.{File, ByteArrayInputStream, ByteArrayOutputStream}
 
 import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.util.collection.ExternalSorter
 import org.apache.spark.util.Utils
@@ -41,7 +42,7 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea
   }
 }
 
-class UnsafeRowSerializerSuite extends SparkFunSuite {
+class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
 
   private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
     val converter = unsafeRowConverter(schema)
@@ -87,11 +88,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
   }
 
   test("close empty input stream") {
-    val baos = new ByteArrayOutputStream()
-    val dout = new DataOutputStream(baos)
-    dout.writeInt(-1) // EOF
-    dout.flush()
-    val input = new ClosableByteArrayInputStream(baos.toByteArray)
+    val input = new ClosableByteArrayInputStream(Array.empty)
     val serializer = new UnsafeRowSerializer(numFields = 2).newInstance()
     val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator
     assert(!deserializerIter.hasNext)
@@ -143,4 +140,16 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
       }
     }
   }
+
+  test("SPARK-10403: unsafe row serializer with UnsafeShuffleManager") {
+    val conf = new SparkConf()
+      .set("spark.shuffle.manager", "tungsten-sort")
+    sc = new SparkContext("local", "test", conf)
+    val row = Row("Hello", 123)
+    val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
+    val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow)))
+      .asInstanceOf[RDD[Product2[Int, InternalRow]]]
+    val shuffled = new ShuffledRowRDD(rowsRDD, new UnsafeRowSerializer(2), 2)
+    shuffled.count()
+  }
 }
