Skip to content

Commit 6b1e5c2

Browse files
committed
[SPARK-10737] [SQL] When using UnsafeRows, SortMergeJoin may return wrong results
https://issues.apache.org/jira/browse/SPARK-10737 Author: Yin Huai <yhuai@databricks.com> Closes apache#8854 from yhuai/SMJBug. (cherry picked from commit 5aea987) Signed-off-by: Yin Huai <yhuai@databricks.com>
1 parent d83dcc9 commit 6b1e5c2

File tree

4 files changed

+59
-5
lines changed

4 files changed

+59
-5
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
171171

172172
@Override
173173
public Object apply(Object r) {
174+
// GenerateProjection does not work with UnsafeRows.
175+
assert(!(r instanceof ${classOf[UnsafeRow].getName}));
174176
return new SpecificRow((InternalRow) r);
175177
}
176178

sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,11 @@ case class Window(
253253

254254
// Get all relevant projections.
255255
val result = createResultProjection(unboundExpressions)
256-
val grouping = newProjection(partitionSpec, child.output)
256+
val grouping = if (child.outputsUnsafeRows) {
257+
UnsafeProjection.create(partitionSpec, child.output)
258+
} else {
259+
newProjection(partitionSpec, child.output)
260+
}
257261

258262
// Manage the stream and the grouping.
259263
var nextRow: InternalRow = EmptyRow
@@ -277,7 +281,8 @@ case class Window(
277281
val numFrames = frames.length
278282
private[this] def fetchNextPartition() {
279283
// Collect all the rows in the current partition.
280-
val currentGroup = nextGroup
284+
// Before we start to fetch new input rows, make a copy of nextGroup.
285+
val currentGroup = nextGroup.copy()
281286
rows = new CompactBuffer
282287
while (nextRowAvailable && nextGroup == currentGroup) {
283288
rows += nextRow.copy()

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,6 @@ case class SortMergeJoin(
5656
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
5757
requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
5858

59-
@transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
60-
@transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
61-
6259
protected[this] def isUnsafeMode: Boolean = {
6360
(codegenEnabled && unsafeEnabled
6461
&& UnsafeProjection.canSupport(leftKeys)
@@ -82,6 +79,28 @@ case class SortMergeJoin(
8279

8380
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
8481
new RowIterator {
82+
// The projection used to extract keys from input rows of the left child.
83+
private[this] val leftKeyGenerator = {
84+
if (isUnsafeMode) {
85+
// It is very important to use UnsafeProjection if input rows are UnsafeRows.
86+
// Otherwise, GenerateProjection will cause wrong results.
87+
UnsafeProjection.create(leftKeys, left.output)
88+
} else {
89+
newProjection(leftKeys, left.output)
90+
}
91+
}
92+
93+
// The projection used to extract keys from input rows of the right child.
94+
private[this] val rightKeyGenerator = {
95+
if (isUnsafeMode) {
96+
// It is very important to use UnsafeProjection if input rows are UnsafeRows.
97+
// Otherwise, GenerateProjection will cause wrong results.
98+
UnsafeProjection.create(rightKeys, right.output)
99+
} else {
100+
newProjection(rightKeys, right.output)
101+
}
102+
}
103+
85104
// An ordering that can be used to compare keys from both sides.
86105
private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
87106
private[this] var currentLeftRow: InternalRow = _

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,4 +1717,32 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
17171717
checkAnswer(
17181718
sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0)))
17191719
}
1720+
1721+
test("SortMergeJoin returns wrong results when using UnsafeRows") {
1722+
// This test is for the fix of https://issues.apache.org/jira/browse/SPARK-10737.
1723+
// This bug will be triggered when Tungsten is enabled and there are multiple
1724+
// SortMergeJoin operators executed in the same task.
1725+
val confs =
1726+
SQLConf.SORTMERGE_JOIN.key -> "true" ::
1727+
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1" ::
1728+
SQLConf.TUNGSTEN_ENABLED.key -> "true" :: Nil
1729+
withSQLConf(confs: _*) {
1730+
val df1 = (1 to 50).map(i => (s"str_$i", i)).toDF("i", "j")
1731+
val df2 =
1732+
df1
1733+
.join(df1.select(df1("i")), "i")
1734+
.select(df1("i"), df1("j"))
1735+
1736+
val df3 = df2.withColumnRenamed("i", "i1").withColumnRenamed("j", "j1")
1737+
val df4 =
1738+
df2
1739+
.join(df3, df2("i") === df3("i1"))
1740+
.withColumn("diff", $"j" - $"j1")
1741+
.select(df2("i"), df2("j"), $"diff")
1742+
1743+
checkAnswer(
1744+
df4,
1745+
df1.withColumn("diff", lit(0)))
1746+
}
1747+
}
17201748
}

0 commit comments

Comments
 (0)