
Commit d084a2d

Davies Liu authored and committed
[SPARK-12541] [SQL] support cube/rollup as function
This PR enables cube/rollup as functions, so they can be used like this:

```
select a, b, sum(c) from t group by rollup(a, b)
```

Author: Davies Liu <davies@databricks.com>

Closes apache#10522 from davies/rollup.
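Semantically, rollup(a, b) aggregates over the grouping sets ((a, b), (a), ()) while cube(a, b) covers all four subsets of {a, b}; a minimal sketch, assuming a registered table t(a, b, c):

```scala
// Hypothetical table t(a, b, c); both forms are accepted after this patch.
sqlContext.sql("select a, b, sum(c) from t group by rollup(a, b)") // 3 grouping sets
sqlContext.sql("select a, b, sum(c) from t group by cube(a, b)")   // 4 grouping sets
```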
1 parent 93ef9b6 commit d084a2d

File tree

8 files changed, +87 −48 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 5 additions & 5 deletions
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreeNodeRef
-import org.apache.spark.sql.catalyst.{ScalaReflection, SimpleCatalystConf, CatalystConf}
+import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf}
 import org.apache.spark.sql.types._

 /**
@@ -208,10 +208,10 @@ class Analyzer(

     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case a if !a.childrenResolved => a // be sure all of the children are resolved.
-      case a: Cube =>
-        GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
-      case a: Rollup =>
-        GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
+      case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
+        GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions)
+      case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
+        GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)
       case x: GroupingSets =>
         val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
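The bitmasks(...) helpers encode which group-by expressions are active in each grouping set (per the removed Rollup scaladoc further down, a bit set to 1 means the expression takes effect). A standalone sketch of the idea, not the Analyzer's actual code; which bit maps to which expression is an illustration choice here:

```scala
// Illustrative only: bitmasks over n group-by expressions, assuming bit i == 1
// means expression i is active in that grouping set.
def rollupBitmasks(n: Int): Seq[Int] =
  (n to 0 by -1).map(i => (1 << i) - 1)  // n = 2: Seq(3, 1, 0) = (a,b), (a), ()

def cubeBitmasks(n: Int): Seq[Int] =
  ((1 << n) - 1) to 0 by -1              // n = 2: Seq(3, 2, 1, 0), every subset
```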

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis

 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, AggregateExpression}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 4 additions & 0 deletions
@@ -285,6 +285,10 @@ object FunctionRegistry {
     expression[InputFileName]("input_file_name"),
     expression[MonotonicallyIncreasingID]("monotonically_increasing_id"),

+    // grouping sets
+    expression[Cube]("cube"),
+    expression[Rollup]("rollup"),
+
     // window functions
     expression[Lead]("lead"),
     expression[Lag]("lag"),
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ (new file)

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.types._
+
+/**
+ * A placeholder expression for cube/rollup, which will be replaced by analyzer
+ */
+trait GroupingSet extends Expression with CodegenFallback {
+
+  def groupByExprs: Seq[Expression]
+  override def children: Seq[Expression] = groupByExprs
+
+  // this should be replaced first
+  override lazy val resolved: Boolean = false
+
+  override def dataType: DataType = throw new UnsupportedOperationException
+  override def foldable: Boolean = false
+  override def nullable: Boolean = true
+  override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
+}
+
+case class Cube(groupByExprs: Seq[Expression]) extends GroupingSet {}
+
+case class Rollup(groupByExprs: Seq[Expression]) extends GroupingSet {}
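Since resolved is hard-wired to false, an unreplaced placeholder can never survive analysis, so cube/rollup used outside a GROUP BY fails during analysis rather than at runtime. A minimal sketch of that invariant, assuming the Catalyst expression DSL is available:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._

// The placeholder never reports itself resolved, forcing an analyzer rewrite.
val c = Cube(Seq('a.string, 'b.int))
assert(!c.resolved)
// c.dataType or c.eval(...) would throw UnsupportedOperationException.
```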

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 0 additions & 37 deletions
@@ -397,43 +397,6 @@ case class GroupingSets(
     this.copy(aggregations = aggs)
 }

-/**
- * Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
- * and eventually will be transformed to Aggregate(.., Expand) in Analyzer
- *
- * @param groupByExprs The Group By expressions candidates.
- * @param child Child operator
- * @param aggregations The Aggregation expressions, those non selected group by expressions
- *                     will be considered as constant null if it appears in the expressions
- */
-case class Cube(
-    groupByExprs: Seq[Expression],
-    child: LogicalPlan,
-    aggregations: Seq[NamedExpression]) extends GroupingAnalytics {
-
-  def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
-    this.copy(aggregations = aggs)
-}
-
-/**
- * Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
- * and eventually will be transformed to Aggregate(.., Expand) in Analyzer
- *
- * @param groupByExprs The Group By expressions candidates, take effective only if the
- *                     associated bit in the bitmask set to 1.
- * @param child Child operator
- * @param aggregations The Aggregation expressions, those non selected group by expressions
- *                     will be considered as constant null if it appears in the expressions
- */
-case class Rollup(
-    groupByExprs: Seq[Expression],
-    child: LogicalPlan,
-    aggregations: Seq[NamedExpression]) extends GroupingAnalytics {
-
-  def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
-    this.copy(aggregations = aggs)
-}
-
 case class Pivot(
     groupByExprs: Seq[NamedExpression],
     pivotColumn: Expression,

sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala

Lines changed: 3 additions & 3 deletions
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate}
+import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Aggregate}
 import org.apache.spark.sql.types.NumericType

@@ -58,10 +58,10 @@ class GroupedData protected[sql](
           df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
       case GroupedData.RollupType =>
         DataFrame(
-          df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg))
+          df.sqlContext, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan))
       case GroupedData.CubeType =>
         DataFrame(
-          df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
+          df.sqlContext, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan))
       case GroupedData.PivotType(pivotCol, values) =>
         val aliasedGrps = groupingExprs.map(alias)
         DataFrame(
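With this change the DataFrame API and the SQL function syntax build the same placeholder-bearing plan. A hedged usage sketch (df is a hypothetical DataFrame with columns a, b, c); each call selects a different GroupType branch above:

```scala
df.groupBy("a", "b").sum("c")  // GroupByType -> Aggregate(groupingExprs, ...)
df.rollup("a", "b").sum("c")   // RollupType  -> Aggregate(Seq(Rollup(...)), ...)
df.cube("a", "b").sum("c")     // CubeType    -> Aggregate(Seq(Cube(...)), ...)
```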

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 29 additions & 0 deletions
@@ -2028,4 +2028,33 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Row(false) :: Row(true) :: Nil)
   }

+  test("rollup") {
+    checkAnswer(
+      sql("select course, year, sum(earnings) from courseSales group by rollup(course, year)" +
+        " order by course, year"),
+      Row(null, null, 113000.0) ::
+        Row("Java", null, 50000.0) ::
+        Row("Java", 2012, 20000.0) ::
+        Row("Java", 2013, 30000.0) ::
+        Row("dotNET", null, 63000.0) ::
+        Row("dotNET", 2012, 15000.0) ::
+        Row("dotNET", 2013, 48000.0) :: Nil
+    )
+  }
+
+  test("cube") {
+    checkAnswer(
+      sql("select course, year, sum(earnings) from courseSales group by cube(course, year)"),
+      Row("Java", 2012, 20000.0) ::
+        Row("Java", 2013, 30000.0) ::
+        Row("Java", null, 50000.0) ::
+        Row("dotNET", 2012, 15000.0) ::
+        Row("dotNET", 2013, 48000.0) ::
+        Row("dotNET", null, 63000.0) ::
+        Row(null, 2012, 35000.0) ::
+        Row(null, 2013, 78000.0) ::
+        Row(null, null, 113000.0) :: Nil
+    )
+  }
+
 }
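As a sanity check on the expected rows: each subtotal is the sum of the matching detail rows (Java: 20000 + 30000 = 50000; grand total 113000). A quick sketch over the same fixture values:

```scala
// Recompute the rollup/cube subtotals from the four detail rows above.
val earnings = Map(
  ("Java", 2012) -> 20000.0, ("Java", 2013) -> 30000.0,
  ("dotNET", 2012) -> 15000.0, ("dotNET", 2013) -> 48000.0)
assert(earnings.collect { case ((c, _), v) if c == "Java" => v }.sum == 50000.0)
assert(earnings.collect { case ((_, y), v) if y == 2013 => v }.sum == 78000.0)
assert(earnings.values.sum == 113000.0)
```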

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 2 additions & 2 deletions
@@ -1121,12 +1121,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
         }),
         rollupGroupByClause.map(e => e match {
           case Token("TOK_ROLLUP_GROUPBY", children) =>
-            Rollup(children.map(nodeToExpr), withLateralView, selectExpressions)
+            Aggregate(Seq(Rollup(children.map(nodeToExpr))), selectExpressions, withLateralView)
           case _ => sys.error("Expect WITH ROLLUP")
         }),
         cubeGroupByClause.map(e => e match {
           case Token("TOK_CUBE_GROUPBY", children) =>
-            Cube(children.map(nodeToExpr), withLateralView, selectExpressions)
+            Aggregate(Seq(Cube(children.map(nodeToExpr))), selectExpressions, withLateralView)
           case _ => sys.error("Expect WITH CUBE")
         }),
         Some(Project(selectExpressions, withLateralView))).flatten.head
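After this change, HiveQL's WITH ROLLUP / WITH CUBE clauses and the new function syntax converge on the same Aggregate-with-placeholder plan. A hedged sketch of the two surface forms (hiveContext is an assumed HiveContext instance):

```scala
// Both queries now yield Aggregate(Seq(Rollup(...)), selectExpressions, child):
hiveContext.sql(
  "select course, year, sum(earnings) from courseSales group by course, year with rollup")
hiveContext.sql(
  "select course, year, sum(earnings) from courseSales group by rollup(course, year)")
```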
