Commit 6e3e3c6

cloud-fan authored and marmbrus committed
[SPARK-12068][SQL] use a single column in Dataset.groupBy and count will fail
The reason is that, for a single-column `RowEncoder` (or a single-field product encoder) used as the encoder for the grouping key, we should still combine the grouping attributes into a struct, even though there is only one grouping attribute.

Author: Wenchen Fan <wenchen@databricks.com>

Closes apache#10059 from cloud-fan/bug.

(cherry picked from commit 8ddc55f)
Signed-off-by: Michael Armbrust <michael@databricks.com>
1 parent 9b99b2b · commit 6e3e3c6
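A minimal reproduction of the failure this commit fixes, assuming a session with `sqlContext.implicits._` in scope (the data mirrors the regression tests added below):

```scala
import sqlContext.implicits._

val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS()

// Before this fix, grouping on a single column and counting failed:
// the lone grouping attribute was used directly as the key column
// instead of being wrapped in a struct to match the Row-typed key.
ds.groupBy($"_1").count().collect()

// The same problem hit single-field product keys such as Tuple1:
Seq("abc", "xyz", "hello").toDS()
  .groupBy(s => Tuple1(s.length))
  .count()
  .collect()
```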

4 files changed: +27 −7 lines changed

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -70,7 +70,7 @@ class Dataset[T] private[sql](
    * implicit so that we can use it when constructing new [[Dataset]] objects that have the same
    * object type (that will be possibly resolved to a different schema).
    */
-  private implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
+  private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)

   /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
   private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
```

sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala

Lines changed: 4 additions & 3 deletions
```diff
@@ -228,10 +228,11 @@ class GroupedDataset[K, V] private[sql](
     val namedColumns =
       columns.map(
         _.withInputType(resolvedVEncoder, dataAttributes).named)
-    val keyColumn = if (groupingAttributes.length > 1) {
-      Alias(CreateStruct(groupingAttributes), "key")()
-    } else {
+    val keyColumn = if (resolvedKEncoder.flat) {
+      assert(groupingAttributes.length == 1)
       groupingAttributes.head
+    } else {
+      Alias(CreateStruct(groupingAttributes), "key")()
     }
     val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan)
     val execution = new QueryExecution(sqlContext, aggregate)
```
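The branch now keys off the encoder rather than the attribute count: a flat key encoder (a primitive-like key) always has exactly one grouping attribute and uses it directly, while any non-flat key (a `Row` or a product such as `Tuple1`) is wrapped in a struct even when it has a single field. A rough sketch of the two shapes at the API level, mirroring the new tests:

```scala
val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS()

// Flat key encoder: grouping by a function returning a String.
// The single grouping attribute itself is the key column.
ds.groupBy(_._1).count()    // keys: "a", "b"

// Non-flat key encoder: grouping by a Column yields Row keys, so the
// single grouping attribute must still be wrapped in a struct.
ds.groupBy($"_1").count()   // keys: Row("a"), Row("b")
```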

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 19 additions & 0 deletions
```diff
@@ -272,6 +272,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       3 -> "abcxyz", 5 -> "hello")
   }

+  test("groupBy single field class, count") {
+    val ds = Seq("abc", "xyz", "hello").toDS()
+    val count = ds.groupBy(s => Tuple1(s.length)).count()
+
+    checkAnswer(
+      count,
+      (Tuple1(3), 2L), (Tuple1(5), 1L)
+    )
+  }
+
   test("groupBy columns, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1")
@@ -282,6 +292,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ("a", 30), ("b", 3), ("c", 1))
   }

+  test("groupBy columns, count") {
+    val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS()
+    val count = ds.groupBy($"_1").count()
+
+    checkAnswer(
+      count,
+      (Row("a"), 2L), (Row("b"), 1L))
+  }
+
   test("groupBy columns asKey, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1").keyAs[String]
```

sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -64,12 +64,12 @@ abstract class QueryTest extends PlanTest {
    * for cases where reordering is done on fields. For such tests, user `checkDecoding` instead
    * which performs a subset of the checks done by this function.
    */
-  protected def checkAnswer[T : Encoder](
-      ds: => Dataset[T],
+  protected def checkAnswer[T](
+      ds: Dataset[T],
       expectedAnswer: T*): Unit = {
     checkAnswer(
       ds.toDF(),
-      sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq)
+      sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq)

     checkDecoding(ds, expectedAnswer: _*)
   }
```
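This change is also why `unresolvedTEncoder` was widened to `private[sql]` in `Dataset.scala` above: `checkAnswer` now encodes the expected values with the Dataset's own encoder instead of requiring an implicit `Encoder[T]` at the call site. A sketch of the resulting call pattern, as used in the new `DatasetSuite` test:

```scala
// No implicit Encoder[(Row, Long)] exists, but the check still works
// because the expected values are encoded with ds.unresolvedTEncoder.
checkAnswer(
  ds.groupBy($"_1").count(),
  (Row("a"), 2L), (Row("b"), 1L))
```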
