Skip to content

Commit add4e63

Browse files
viirya authored and yhuai committed
[SPARK-11949][SQL] Set field nullable property for GroupingSets to get correct results for null values
JIRA: https://issues.apache.org/jira/browse/SPARK-11949 The result of a cube plan uses an incorrect schema. The schema of the cube result should set the nullable property to true because the grouping expressions will have null values. Author: Liang-Chi Hsieh <viirya@appier.com> Closes apache#10038 from viirya/fix-cube. (cherry picked from commit c87531b) Signed-off-by: Yin Huai <yhuai@databricks.com>
1 parent 1aa39bd commit add4e63

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,11 @@ class Analyzer(
223223
case other => Alias(other, other.toString)()
224224
}
225225

226+
// TODO: We need to use bitmasks to determine which grouping expressions need to be
227+
// set as nullable. For example, if we have GROUPING SETS ((a,b), a), we do not need
228+
// to change the nullability of a.
229+
val attributeMap = groupByAliases.map(a => (a -> a.toAttribute.withNullability(true))).toMap
230+
226231
val aggregations: Seq[NamedExpression] = x.aggregations.map {
227232
// If an expression is an aggregate (contains a AggregateExpression) then we dont change
228233
// it so that the aggregation is computed on the unmodified value of its argument
@@ -231,12 +236,13 @@ class Analyzer(
231236
// If not then its a grouping expression and we need to use the modified (with nulls from
232237
// Expand) value of the expression.
233238
case expr => expr.transformDown {
234-
case e => groupByAliases.find(_.child.semanticEquals(e)).map(_.toAttribute).getOrElse(e)
239+
case e =>
240+
groupByAliases.find(_.child.semanticEquals(e)).map(attributeMap(_)).getOrElse(e)
235241
}.asInstanceOf[NamedExpression]
236242
}
237243

238244
val child = Project(x.child.output ++ groupByAliases, x.child)
239-
val groupByAttributes = groupByAliases.map(_.toAttribute)
245+
val groupByAttributes = groupByAliases.map(attributeMap(_))
240246

241247
Aggregate(
242248
groupByAttributes :+ VirtualColumn.groupingIdAttribute,

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.apache.spark.sql.functions._
2121
import org.apache.spark.sql.test.SharedSQLContext
2222
import org.apache.spark.sql.types.DecimalType
2323

24+
case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)
2425

2526
class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
2627
import testImplicits._
@@ -86,6 +87,15 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
8687
Row(null, 2013, 78000.0) ::
8788
Row(null, null, 113000.0) :: Nil
8889
)
90+
91+
val df0 = sqlContext.sparkContext.parallelize(Seq(
92+
Fact(20151123, 18, 35, "room1", 18.6),
93+
Fact(20151123, 18, 35, "room2", 22.4),
94+
Fact(20151123, 18, 36, "room1", 17.4),
95+
Fact(20151123, 18, 36, "room2", 25.6))).toDF()
96+
97+
val cube0 = df0.cube("date", "hour", "minute", "room_name").agg(Map("temp" -> "avg"))
98+
assert(cube0.where("date IS NULL").count > 0)
8999
}
90100

91101
test("rollup overlapping columns") {

0 commit comments

Comments (0)