Commit 54334d3

cloud-fan authored and yhuai committed
[SPARK-10740] [SQL] handle nondeterministic expressions correctly for set operations
https://issues.apache.org/jira/browse/SPARK-10740

Author: Wenchen Fan <cloud0fan@163.com>

Closes apache#8858 from cloud-fan/non-deter.

(cherry picked from commit 5017c68)
Signed-off-by: Yin Huai <yhuai@databricks.com>
1 parent c3112a9 commit 54334d3
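
For context, the query shape this commit targets (adapted from the new DataFrameSuite test below) looks roughly like the following sketch. The point is that a nondeterministic predicate built from rand() must be evaluated exactly once, above the set operation, rather than being copied into each child where it would draw different random values. This is a hypothetical usage sketch, not part of the commit; it assumes a SQLContext named `sqlContext` and its implicits are in scope, as in DataFrameSuite.

```scala
import org.apache.spark.sql.functions.rand
import sqlContext.implicits._

val df1 = (1 to 20).map(Tuple1.apply).toDF("i")
val df2 = (1 to 10).map(Tuple1.apply).toDF("i")
val union = df1.unionAll(df2)

// rand(7) is nondeterministic: pushing this filter into both branches of the union
// would evaluate the random sequence twice and change the result, so only the
// deterministic conjuncts of a filter condition may be pushed below the Union.
union.filter('i < rand(7) * 10).show()
```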

3 files changed: +93 -20 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 50 additions & 19 deletions
@@ -95,14 +95,14 @@ object SamplePushDown extends Rule[LogicalPlan] {
  * Intersect:
  * It is not safe to pushdown Projections through it because we need to get the
  * intersect of rows by comparing the entire rows. It is fine to pushdown Filters
- * because we will not have non-deterministic expressions.
+ * with deterministic condition.
  *
  * Except:
  * It is not safe to pushdown Projections through it because we need to get the
  * intersect of rows by comparing the entire rows. It is fine to pushdown Filters
- * because we will not have non-deterministic expressions.
+ * with deterministic condition.
  */
-object SetOperationPushDown extends Rule[LogicalPlan] {
+object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {

   /**
    * Maps Attributes from the left side to the corresponding Attribute on the right side.
@@ -129,34 +129,65 @@ object SetOperationPushDown extends Rule[LogicalPlan] {
     result.asInstanceOf[A]
   }

+  /**
+   * Splits the condition expression into small conditions by `And`, and partition them by
+   * deterministic, and finally recombine them by `And`. It returns an expression containing
+   * all deterministic expressions (the first field of the returned Tuple2) and an expression
+   * containing all non-deterministic expressions (the second field of the returned Tuple2).
+   */
+  private def partitionByDeterministic(condition: Expression): (Expression, Expression) = {
+    val andConditions = splitConjunctivePredicates(condition)
+    andConditions.partition(_.deterministic) match {
+      case (deterministic, nondeterministic) =>
+        deterministic.reduceOption(And).getOrElse(Literal(true)) ->
+          nondeterministic.reduceOption(And).getOrElse(Literal(true))
+    }
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     // Push down filter into union
     case Filter(condition, u @ Union(left, right)) =>
+      val (deterministic, nondeterministic) = partitionByDeterministic(condition)
       val rewrites = buildRewrites(u)
-      Union(
-        Filter(condition, left),
-        Filter(pushToRight(condition, rewrites), right))
-
-    // Push down projection through UNION ALL
-    case Project(projectList, u @ Union(left, right)) =>
-      val rewrites = buildRewrites(u)
-      Union(
-        Project(projectList, left),
-        Project(projectList.map(pushToRight(_, rewrites)), right))
+      Filter(nondeterministic,
+        Union(
+          Filter(deterministic, left),
+          Filter(pushToRight(deterministic, rewrites), right)
+        )
+      )
+
+    // Push down deterministic projection through UNION ALL
+    case p @ Project(projectList, u @ Union(left, right)) =>
+      if (projectList.forall(_.deterministic)) {
+        val rewrites = buildRewrites(u)
+        Union(
+          Project(projectList, left),
+          Project(projectList.map(pushToRight(_, rewrites)), right))
+      } else {
+        p
+      }

     // Push down filter through INTERSECT
     case Filter(condition, i @ Intersect(left, right)) =>
+      val (deterministic, nondeterministic) = partitionByDeterministic(condition)
       val rewrites = buildRewrites(i)
-      Intersect(
-        Filter(condition, left),
-        Filter(pushToRight(condition, rewrites), right))
+      Filter(nondeterministic,
+        Intersect(
+          Filter(deterministic, left),
+          Filter(pushToRight(deterministic, rewrites), right)
+        )
+      )

     // Push down filter through EXCEPT
     case Filter(condition, e @ Except(left, right)) =>
+      val (deterministic, nondeterministic) = partitionByDeterministic(condition)
       val rewrites = buildRewrites(e)
-      Except(
-        Filter(condition, left),
-        Filter(pushToRight(condition, rewrites), right))
+      Filter(nondeterministic,
+        Except(
+          Filter(deterministic, left),
+          Filter(pushToRight(deterministic, rewrites), right)
+        )
+      )
   }
 }
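The heart of the change is partitionByDeterministic. Below is a minimal, self-contained sketch of the same split/partition/recombine idea; the tiny Expr/Pred/And types are made up for illustration and are not Catalyst classes, but they mirror what splitConjunctivePredicates followed by partition(_.deterministic) and reduceOption(And) do to a conjunction.

```scala
// Toy model (illustration only, not Catalyst): split a conjunction into conjuncts,
// partition them by determinism, and recombine each group with And, defaulting to
// a literal true when a group is empty.
sealed trait Expr { def deterministic: Boolean }
case class Pred(sql: String, deterministic: Boolean) extends Expr
case class And(left: Expr, right: Expr) extends Expr {
  val deterministic: Boolean = left.deterministic && right.deterministic
}
case object TrueLiteral extends Expr { val deterministic = true }

def splitConjunctivePredicates(e: Expr): Seq[Expr] = e match {
  case And(l, r) => splitConjunctivePredicates(l) ++ splitConjunctivePredicates(r)
  case other => Seq(other)
}

def partitionByDeterministic(condition: Expr): (Expr, Expr) = {
  val (det, nondet) = splitConjunctivePredicates(condition).partition(_.deterministic)
  (det.reduceOption(And(_, _)).getOrElse(TrueLiteral),
    nondet.reduceOption(And(_, _)).getOrElse(TrueLiteral))
}

// "a > 1 AND rand() < 0.5 AND b < 3" splits into deterministic "a > 1 AND b < 3"
// and nondeterministic "rand() < 0.5".
val condition = And(Pred("a > 1", deterministic = true),
  And(Pred("rand() < 0.5", deterministic = false), Pred("b < 3", deterministic = true)))
println(partitionByDeterministic(condition))
```

The real rule then wraps the set operation in Filter(nondeterministic, ...), so the nondeterministic part is evaluated once, above the Union/Intersect/Except, while only the deterministic part is pushed into the children.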

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala

Lines changed: 2 additions & 1 deletion
@@ -30,7 +30,8 @@ class SetOperationPushDownSuite extends PlanTest {
       Batch("Subqueries", Once,
         EliminateSubQueries) ::
       Batch("Union Pushdown", Once,
-        SetOperationPushDown) :: Nil
+        SetOperationPushDown,
+        SimplifyFilters) :: Nil
   }

   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
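Adding SimplifyFilters to the test batch is presumably needed because the new rewrite always wraps the set operation in Filter(nondeterministic, ...); when the whole condition is deterministic, that wrapper is just Filter(true, ...), and removing it keeps the optimized plans comparable to the expected ones. The sketch below is an assumption about the behaviour relied on here, not a quote of the SimplifyFilters rule itself.

```scala
// Assumed behaviour (sketch, not the actual Catalyst rule): a filter whose
// condition is the literal true can simply be dropped.
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.types.BooleanType

def dropTrivialFilters(plan: LogicalPlan): LogicalPlan = plan transform {
  case Filter(Literal(true, BooleanType), child) => child
}
```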

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 41 additions & 0 deletions
@@ -896,4 +896,45 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(intersect.count() === 30)
     assert(except.count() === 70)
   }
+
+  test("SPARK-10740: handle nondeterministic expressions correctly for set operations") {
+    val df1 = (1 to 20).map(Tuple1.apply).toDF("i")
+    val df2 = (1 to 10).map(Tuple1.apply).toDF("i")
+
+    // When generating expected results at here, we need to follow the implementation of
+    // Rand expression.
+    def expected(df: DataFrame): Seq[Row] = {
+      df.rdd.collectPartitions().zipWithIndex.flatMap {
+        case (data, index) =>
+          val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index)
+          data.filter(_.getInt(0) < rng.nextDouble() * 10)
+      }
+    }
+
+    val union = df1.unionAll(df2)
+    checkAnswer(
+      union.filter('i < rand(7) * 10),
+      expected(union)
+    )
+    checkAnswer(
+      union.select(rand(7)),
+      union.rdd.collectPartitions().zipWithIndex.flatMap {
+        case (data, index) =>
+          val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index)
+          data.map(_ => rng.nextDouble()).map(i => Row(i))
+      }
+    )
+
+    val intersect = df1.intersect(df2)
+    checkAnswer(
+      intersect.filter('i < rand(7) * 10),
+      expected(intersect)
+    )
+
+    val except = df1.except(df2)
+    checkAnswer(
+      except.filter('i < rand(7) * 10),
+      expected(except)
+    )
+  }
 }
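The expected() helper works because, as the test's comment notes, the expected values must follow the implementation of the Rand expression: one random generator per partition, reseeded with the user seed plus the partition index, which is why the helper uses 7 + index. A standalone sketch of that convention is below; XORShiftRandom is a Spark-internal utility, so like the test itself this sketch belongs inside Spark's own sources, and the helper name is made up for illustration.

```scala
// Sketch of the per-partition seeding convention the test mirrors: a fresh
// XORShiftRandom per partition, seeded with seed + partitionIndex.
import org.apache.spark.util.random.XORShiftRandom

def simulatedRandColumn(seed: Long, partitionIndex: Int, rowsInPartition: Int): Seq[Double] = {
  val rng = new XORShiftRandom(seed + partitionIndex)
  Seq.fill(rowsInPartition)(rng.nextDouble())
}
```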
