Skip to content

Commit 594fafc

Browse files
committed
[SPARK-12250][SQL] Allow users to define a UDAF without providing details of its inputSchema
JIRA: https://issues.apache.org/jira/browse/SPARK-12250. Author: Yin Huai <yhuai@databricks.com>. Closes apache#10236 from yhuai/SPARK-12250. (Cherry picked from commit bc5f56a.) Signed-off-by: Yin Huai <yhuai@databricks.com>
1 parent e541f70 commit 594fafc

File tree

2 files changed

+64
-5
lines changed

2 files changed

+64
-5
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -332,11 +332,6 @@ private[sql] case class ScalaUDAF(
332332
/** Returns a copy of this aggregate whose `inputAggBufferOffset` is set to the given value. */
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
333333
copy(inputAggBufferOffset = newInputAggBufferOffset)
334334

335-
require(
336-
children.length == udaf.inputSchema.length,
337-
s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
338-
s"but ${children.length} are provided.")
339-
340335
/** This expression always reports itself as nullable. */
override def nullable: Boolean = true
341336

342337
/** The result type is whatever the wrapped `udaf` declares as its `dataType`. */
override def dataType: DataType = udaf.dataType

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,33 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun
6666
}
6767
}
6868

69+
/**
 * Test UDAF that declares an empty `inputSchema` yet still reads its input.
 *
 * For every input row it takes the `Seq[Row]` at input position 0, sums the Int
 * field "v" of each element, and accumulates that sum into a single Long buffer slot.
 */
class ScalaAggregateFunctionWithoutInputSchema extends UserDefinedAggregateFunction {
70+
71+
// Deliberately empty: no input fields are described, although `update` reads input position 0.
def inputSchema: StructType = StructType(Nil)
72+
73+
// The aggregation buffer holds one Long field named "value".
def bufferSchema: StructType = StructType(StructField("value", LongType) :: Nil)
74+
75+
// The final result is a Long.
def dataType: DataType = LongType
76+
77+
def deterministic: Boolean = true
78+
79+
/** Resets the running total (buffer slot 0) to zero. */
def initialize(buffer: MutableAggregationBuffer): Unit = {
80+
buffer.update(0, 0L)
81+
}
82+
83+
/**
 * Adds the sum of the "v" fields of all rows in the sequence at input position 0
 * to the running total stored in buffer slot 0.
 */
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
84+
buffer.update(0, input.getAs[Seq[Row]](0).map(_.getAs[Int]("v")).sum + buffer.getLong(0))
85+
}
86+
87+
/** Folds the total held in `buffer2` (slot 0) into `buffer1` (slot 0). */
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
88+
buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
89+
}
90+
91+
/** Returns the accumulated total from buffer slot 0. */
def evaluate(buffer: Row): Any = {
92+
buffer.getLong(0)
93+
}
94+
}
95+
6996
class LongProductSum extends UserDefinedAggregateFunction {
7097
def inputSchema: StructType = new StructType()
7198
.add("a", LongType)
@@ -858,6 +885,43 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
858885
)
859886
}
860887
}
888+
889+
// Registers a UDAF under the name "noInputSchema" and checks that it can be called from SQL
// on an array-of-struct column, both with and without GROUP BY.
test("udaf without specifying inputSchema") {
890+
withTempTable("noInputSchemaUDAF") {
891+
sqlContext.udf.register("noInputSchema", new ScalaAggregateFunctionWithoutInputSchema)
892+
893+
// Three rows of (key, array of single-field rows): two rows for key 1, one row for key 2.
val data =
894+
Row(1, Seq(Row(1), Row(2), Row(3))) ::
895+
Row(1, Seq(Row(4), Row(5), Row(6))) ::
896+
Row(2, Seq(Row(-10))) :: Nil
897+
// Columns: "key" (Int) and "myArray" (array of structs with one Int field "v").
val schema =
898+
StructType(
899+
StructField("key", IntegerType) ::
900+
StructField("myArray",
901+
ArrayType(StructType(StructField("v", IntegerType) :: Nil))) :: Nil)
902+
sqlContext.createDataFrame(
903+
sparkContext.parallelize(data, 2),
904+
schema)
905+
.registerTempTable("noInputSchemaUDAF")
906+
907+
// Grouped: key 1 -> 1+2+3+4+5+6 = 21, key 2 -> -10.
checkAnswer(
908+
sqlContext.sql(
909+
"""
910+
|SELECT key, noInputSchema(myArray)
911+
|FROM noInputSchemaUDAF
912+
|GROUP BY key
913+
""".stripMargin),
914+
Row(1, 21) :: Row(2, -10) :: Nil)
915+
916+
// Ungrouped: the total over all rows is 21 + (-10) = 11.
checkAnswer(
917+
sqlContext.sql(
918+
"""
919+
|SELECT noInputSchema(myArray)
920+
|FROM noInputSchemaUDAF
921+
""".stripMargin),
922+
Row(11) :: Nil)
923+
}
924+
}
861925
}
862926

863927

0 commit comments

Comments
 (0)