Skip to content

Commit 8b34fb0

Browse files
Davies Liurxin
authored andcommitted
[SPARK-11864][SQL] Improve performance of max/min
This PR has the following optimization: 1) The greatest/least already does the null-check, so the `If` and `IsNull` are not necessary. 2) In greatest/least, it should initialize the result using the first child (removing one block). 3) For primitive types, the generated greater expression is too complicated (`a > b ? 1 : (a < b) ? -1 : 0) > 0`), should be as simple as `a > b` Combine these optimization, this could improve the performance of `ss_max` query by 30%. Author: Davies Liu <davies@databricks.com> Closes apache#9846 from davies/improve_max. (cherry picked from commit ee21407) Signed-off-by: Reynold Xin <rxin@databricks.com>
1 parent 19ea30d commit 8b34fb0

File tree

5 files changed

+45
-25
lines changed

5 files changed

+45
-25
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,12 @@ case class Max(child: Expression) extends DeclarativeAggregate {
4646
)
4747

4848
override lazy val updateExpressions: Seq[Expression] = Seq(
49-
/* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child))))
49+
/* max = */ Greatest(Seq(max, child))
5050
)
5151

5252
override lazy val mergeExpressions: Seq[Expression] = {
53-
val greatest = Greatest(Seq(max.left, max.right))
5453
Seq(
55-
/* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest))
54+
/* max = */ Greatest(Seq(max.left, max.right))
5655
)
5756
}
5857

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,12 @@ case class Min(child: Expression) extends DeclarativeAggregate {
4747
)
4848

4949
override lazy val updateExpressions: Seq[Expression] = Seq(
50-
/* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child))))
50+
/* min = */ Least(Seq(min, child))
5151
)
5252

5353
override lazy val mergeExpressions: Seq[Expression] = {
54-
val least = Least(Seq(min.left, min.right))
5554
Seq(
56-
/* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least))
55+
/* min = */ Least(Seq(min.left, min.right))
5756
)
5857
}
5958

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,18 @@ class CodeGenContext {
329329
throw new IllegalArgumentException("cannot generate compare code for un-comparable type")
330330
}
331331

332+
/**
333+
* Generates code for greater of two expressions.
334+
*
335+
* @param dataType data type of the expressions
336+
* @param c1 name of the variable of expression 1's output
337+
* @param c2 name of the variable of expression 2's output
338+
*/
339+
def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match {
340+
case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2"
341+
case _ => s"(${genComp(dataType, c1, c2)}) > 0"
342+
}
343+
332344
/**
333345
* List of java data types that have special accessors and setters in [[InternalRow]].
334346
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -348,19 +348,22 @@ case class Least(children: Seq[Expression]) extends Expression {
348348

349349
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
350350
val evalChildren = children.map(_.gen(ctx))
351-
def updateEval(i: Int): String =
351+
val first = evalChildren(0)
352+
val rest = evalChildren.drop(1)
353+
def updateEval(eval: GeneratedExpressionCode): String =
352354
s"""
353-
if (!${evalChildren(i).isNull} && (${ev.isNull} ||
354-
${ctx.genComp(dataType, evalChildren(i).value, ev.value)} < 0)) {
355+
${eval.code}
356+
if (!${eval.isNull} && (${ev.isNull} ||
357+
${ctx.genGreater(dataType, ev.value, eval.value)})) {
355358
${ev.isNull} = false;
356-
${ev.value} = ${evalChildren(i).value};
359+
${ev.value} = ${eval.value};
357360
}
358361
"""
359362
s"""
360-
${evalChildren.map(_.code).mkString("\n")}
361-
boolean ${ev.isNull} = true;
362-
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
363-
${children.indices.map(updateEval).mkString("\n")}
363+
${first.code}
364+
boolean ${ev.isNull} = ${first.isNull};
365+
${ctx.javaType(dataType)} ${ev.value} = ${first.value};
366+
${rest.map(updateEval).mkString("\n")}
364367
"""
365368
}
366369
}
@@ -403,19 +406,22 @@ case class Greatest(children: Seq[Expression]) extends Expression {
403406

404407
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
405408
val evalChildren = children.map(_.gen(ctx))
406-
def updateEval(i: Int): String =
409+
val first = evalChildren(0)
410+
val rest = evalChildren.drop(1)
411+
def updateEval(eval: GeneratedExpressionCode): String =
407412
s"""
408-
if (!${evalChildren(i).isNull} && (${ev.isNull} ||
409-
${ctx.genComp(dataType, evalChildren(i).value, ev.value)} > 0)) {
413+
${eval.code}
414+
if (!${eval.isNull} && (${ev.isNull} ||
415+
${ctx.genGreater(dataType, eval.value, ev.value)})) {
410416
${ev.isNull} = false;
411-
${ev.value} = ${evalChildren(i).value};
417+
${ev.value} = ${eval.value};
412418
}
413419
"""
414420
s"""
415-
${evalChildren.map(_.code).mkString("\n")}
416-
boolean ${ev.isNull} = true;
417-
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
418-
${children.indices.map(updateEval).mkString("\n")}
421+
${first.code}
422+
boolean ${ev.isNull} = ${first.isNull};
423+
${ctx.javaType(dataType)} ${ev.value} = ${first.value};
424+
${rest.map(updateEval).mkString("\n")}
419425
"""
420426
}
421427
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,15 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
6262
}
6363

6464
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
65+
val first = children(0)
66+
val rest = children.drop(1)
67+
val firstEval = first.gen(ctx)
6568
s"""
66-
boolean ${ev.isNull} = true;
67-
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
69+
${firstEval.code}
70+
boolean ${ev.isNull} = ${firstEval.isNull};
71+
${ctx.javaType(dataType)} ${ev.value} = ${firstEval.value};
6872
""" +
69-
children.map { e =>
73+
rest.map { e =>
7074
val eval = e.gen(ctx)
7175
s"""
7276
if (${ev.isNull}) {

0 commit comments

Comments
 (0)