@@ -66,6 +66,33 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun
66
66
}
67
67
}
68
68
69
+ class ScalaAggregateFunctionWithoutInputSchema extends UserDefinedAggregateFunction {
70
+
71
+ def inputSchema : StructType = StructType (Nil )
72
+
73
+ def bufferSchema : StructType = StructType (StructField (" value" , LongType ) :: Nil )
74
+
75
+ def dataType : DataType = LongType
76
+
77
+ def deterministic : Boolean = true
78
+
79
+ def initialize (buffer : MutableAggregationBuffer ): Unit = {
80
+ buffer.update(0 , 0L )
81
+ }
82
+
83
+ def update (buffer : MutableAggregationBuffer , input : Row ): Unit = {
84
+ buffer.update(0 , input.getAs[Seq [Row ]](0 ).map(_.getAs[Int ](" v" )).sum + buffer.getLong(0 ))
85
+ }
86
+
87
+ def merge (buffer1 : MutableAggregationBuffer , buffer2 : Row ): Unit = {
88
+ buffer1.update(0 , buffer1.getLong(0 ) + buffer2.getLong(0 ))
89
+ }
90
+
91
+ def evaluate (buffer : Row ): Any = {
92
+ buffer.getLong(0 )
93
+ }
94
+ }
95
+
69
96
class LongProductSum extends UserDefinedAggregateFunction {
70
97
def inputSchema : StructType = new StructType ()
71
98
.add(" a" , LongType )
@@ -858,6 +885,43 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
858
885
)
859
886
}
860
887
}
888
+
889
+ test(" udaf without specifying inputSchema" ) {
890
+ withTempTable(" noInputSchemaUDAF" ) {
891
+ sqlContext.udf.register(" noInputSchema" , new ScalaAggregateFunctionWithoutInputSchema )
892
+
893
+ val data =
894
+ Row (1 , Seq (Row (1 ), Row (2 ), Row (3 ))) ::
895
+ Row (1 , Seq (Row (4 ), Row (5 ), Row (6 ))) ::
896
+ Row (2 , Seq (Row (- 10 ))) :: Nil
897
+ val schema =
898
+ StructType (
899
+ StructField (" key" , IntegerType ) ::
900
+ StructField (" myArray" ,
901
+ ArrayType (StructType (StructField (" v" , IntegerType ) :: Nil ))) :: Nil )
902
+ sqlContext.createDataFrame(
903
+ sparkContext.parallelize(data, 2 ),
904
+ schema)
905
+ .registerTempTable(" noInputSchemaUDAF" )
906
+
907
+ checkAnswer(
908
+ sqlContext.sql(
909
+ """
910
+ |SELECT key, noInputSchema(myArray)
911
+ |FROM noInputSchemaUDAF
912
+ |GROUP BY key
913
+ """ .stripMargin),
914
+ Row (1 , 21 ) :: Row (2 , - 10 ) :: Nil )
915
+
916
+ checkAnswer(
917
+ sqlContext.sql(
918
+ """
919
+ |SELECT noInputSchema(myArray)
920
+ |FROM noInputSchemaUDAF
921
+ """ .stripMargin),
922
+ Row (11 ) :: Nil )
923
+ }
924
+ }
861
925
}
862
926
863
927
0 commit comments