17
17
18
18
package org .apache .spark .sql .catalyst
19
19
20
- import java .beans .Introspector
20
+ import java .beans .{ PropertyDescriptor , Introspector }
21
21
import java .lang .{Iterable => JIterable }
22
- import java .util .{Iterator => JIterator , Map => JMap }
22
+ import java .util .{Iterator => JIterator , Map => JMap , List => JList }
23
23
24
24
import scala .language .existentials
25
25
26
26
import com .google .common .reflect .TypeToken
27
+
27
28
import org .apache .spark .sql .types ._
29
+ import org .apache .spark .sql .catalyst .expressions ._
30
+ import org .apache .spark .sql .catalyst .analysis .{UnresolvedAttribute , UnresolvedExtractValue }
31
+ import org .apache .spark .sql .catalyst .util .{GenericArrayData , ArrayBasedMapData , DateTimeUtils }
32
+ import org .apache .spark .unsafe .types .UTF8String
33
+
28
34
29
35
/**
30
36
* Type-inference utilities for POJOs and Java collections.
@@ -33,13 +39,14 @@ object JavaTypeInference {
33
39
34
40
// Type tokens for the raw Java collection interfaces. Candidate types are tested against
// `mapType` and `listType` with `isAssignableFrom` elsewhere in this object.
private val iterableType = TypeToken .of(classOf [JIterable [_]])
35
41
private val mapType = TypeToken .of(classOf [JMap [_, _]])
42
+ private val listType = TypeToken .of(classOf [JList [_]])
36
43
// Generic return types of `Iterable.iterator` and `Iterator.next`; `elementType` resolves
// them against a concrete iterable type to find its item type.
private val iteratorReturnType = classOf [JIterable [_]].getMethod(" iterator" ).getGenericReturnType
37
44
private val nextReturnType = classOf [JIterator [_]].getMethod(" next" ).getGenericReturnType
38
45
// Generic return types of `Map.keySet` and `Map.values`; `mapKeyValueType` resolves them
// against a concrete map type to find its key and value types.
private val keySetReturnType = classOf [JMap [_, _]].getMethod(" keySet" ).getGenericReturnType
39
46
private val valuesReturnType = classOf [JMap [_, _]].getMethod(" values" ).getGenericReturnType
40
47
41
48
/**
42
- * Infers the corresponding SQL data type of a JavaClean class.
49
+ * Infers the corresponding SQL data type of a JavaBean class.
43
50
* @param beanClass Java type
44
51
* @return (SQL data type, nullable)
45
52
*/
@@ -58,6 +65,8 @@ object JavaTypeInference {
58
65
(c.getAnnotation(classOf [SQLUserDefinedType ]).udt().newInstance(), true )
59
66
60
67
case c : Class [_] if c == classOf [java.lang.String ] => (StringType , true )
68
+ case c : Class [_] if c == classOf [Array [Byte ]] => (BinaryType , true )
69
+
61
70
case c : Class [_] if c == java.lang.Short .TYPE => (ShortType , false )
62
71
case c : Class [_] if c == java.lang.Integer .TYPE => (IntegerType , false )
63
72
case c : Class [_] if c == java.lang.Long .TYPE => (LongType , false )
@@ -87,15 +96,14 @@ object JavaTypeInference {
87
96
(ArrayType (dataType, nullable), true )
88
97
89
98
case _ if mapType.isAssignableFrom(typeToken) =>
90
- val typeToken2 = typeToken.asInstanceOf [TypeToken [_ <: JMap [_, _]]]
91
- val mapSupertype = typeToken2.getSupertype(classOf [JMap [_, _]])
92
- val keyType = elementType(mapSupertype.resolveType(keySetReturnType))
93
- val valueType = elementType(mapSupertype.resolveType(valuesReturnType))
99
+ val (keyType, valueType) = mapKeyValueType(typeToken)
94
100
val (keyDataType, _) = inferDataType(keyType)
95
101
val (valueDataType, nullable) = inferDataType(valueType)
96
102
(MapType (keyDataType, valueDataType, nullable), true )
97
103
98
104
case _ =>
105
+ // TODO: we should only collect properties that have getter and setter. However, some tests
106
+ // pass in scala case class as java bean class which doesn't have getter and setter.
99
107
val beanInfo = Introspector .getBeanInfo(typeToken.getRawType)
100
108
val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == " class" )
101
109
val fields = properties.map { property =>
@@ -107,11 +115,294 @@ object JavaTypeInference {
107
115
}
108
116
}
109
117
118
/**
 * Collects the bean properties of `beanClass` that expose both a read method and a write
 * method. Properties missing either accessor are dropped.
 */
private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
  val descriptors = Introspector.getBeanInfo(beanClass).getPropertyDescriptors
  descriptors.filter { descriptor =>
    !(descriptor.getReadMethod == null || descriptor.getWriteMethod == null)
  }
}
123
+
110
124
/**
 * Resolves the item type of an iterable type: the type returned by `next` on the iterator
 * obtained from the type's `java.lang.Iterable` supertype.
 */
private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
  // The cast is needed so that `getSupertype` accepts `Iterable` as a supertype.
  val asIterable = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]]
  asIterable
    .getSupertype(classOf[JIterable[_]])
    .resolveType(iteratorReturnType)
    .resolveType(nextReturnType)
}
130
+
131
/**
 * Resolves the key and value types of a map type, as the item types of the collections
 * returned by its `keySet` and `values` methods.
 *
 * @return (key type, value type)
 */
private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = {
  // The cast is needed so that `getSupertype` accepts `Map` as a supertype.
  val asMap = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
  val resolvedMap = asMap.getSupertype(classOf[JMap[_, _]])

  def itemTypeOf(collectionReturnType: java.lang.reflect.Type): TypeToken[_] =
    elementType(resolvedMap.resolveType(collectionReturnType))

  (itemTypeOf(keySetReturnType), itemTypeOf(valuesReturnType))
}
138
+
139
/**
 * Maps a Java class to the Spark SQL DataType of its external representation.
 *
 * Only the primitive types (boolean, byte, short, int, long, float, double) and `byte[]` map to
 * a native SQL type. Every other class, boxed primitives included, is wrapped in an ObjectType,
 * because unlike `inferDataType` no conversion into the Spark SQL type system is attempted.
 */
private def inferExternalType(cls: Class[_]): DataType = {
  val nativeMappings: Seq[(Class[_], DataType)] = Seq(
    (java.lang.Boolean.TYPE, BooleanType),
    (java.lang.Byte.TYPE, ByteType),
    (java.lang.Short.TYPE, ShortType),
    (java.lang.Integer.TYPE, IntegerType),
    (java.lang.Long.TYPE, LongType),
    (java.lang.Float.TYPE, FloatType),
    (java.lang.Double.TYPE, DoubleType),
    (classOf[Array[Byte]], BinaryType))

  nativeMappings
    .collectFirst { case (javaClass, sqlType) if javaClass == cls => sqlType }
    .getOrElse(ObjectType(cls))
}
157
+
158
/**
 * Builds an expression that constructs an instance of the java bean `beanClass` from an input
 * row with a compatible schema.
 *
 * Top-level bean properties are read through UnresolvedAttributes named after the property.
 * Properties of nested beans are reached through UnresolvedExtractValue on the parent path.
 */
def constructorFor(beanClass: Class[_]): Expression = {
  // Start at the root: no path has been accumulated yet.
  constructorFor(TypeToken.of(beanClass), path = None)
}
167
+
168
/**
 * Builds the expression that converts the Catalyst value found at `path` into an object of the
 * Java type described by `typeToken`.
 *
 * @param typeToken Java type to construct
 * @param path expression locating the input value; `None` at the root of the conversion
 */
private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
  /** Extends the current path with the sub-field `part`; with no path yet, `part` is the root. */
  def addToPath(part: String): Expression = path match {
    case Some(parent) => UnresolvedExtractValue(parent, expressions.Literal(part))
    case None => UnresolvedAttribute(part)
  }

  /** The current path, or a nullable `BoundReference` to ordinal 0 when there is none. */
  def getPath: Expression = path match {
    case Some(current) => current
    case None => BoundReference(0, inferDataType(typeToken)._1, nullable = true)
  }

  /** Converts every item of `input` with `constructorFor`, then invokes `array` on the result. */
  def mapToObjectArray(
      itemType: TypeToken[_],
      input: Expression,
      itemDataType: DataType,
      arrayType: DataType): Expression = {
    val converted = MapObjects(item => constructorFor(itemType, Some(item)), input, itemDataType)
    Invoke(converted, "array", arrayType)
  }

  /** Calls the named static conversion on `DateTimeUtils`, producing an instance of `target`. */
  def viaDateTimeUtils(target: Class[_], method: String): Expression =
    StaticInvoke(DateTimeUtils, ObjectType(target), method, getPath :: Nil, propagateNull = true)

  // Boxed primitives are all built the same way: by passing the value to their constructor.
  val boxedPrimitives = Set[Class[_]](
    classOf[java.lang.Short],
    classOf[java.lang.Integer],
    classOf[java.lang.Long],
    classOf[java.lang.Double],
    classOf[java.lang.Byte],
    classOf[java.lang.Float],
    classOf[java.lang.Boolean])

  // Method invoked on the input array data to obtain a primitive Java array, by component type.
  val primitiveArrayAccessors = Map[Class[_], String](
    (java.lang.Boolean.TYPE, "toBooleanArray"),
    (java.lang.Byte.TYPE, "toByteArray"),
    (java.lang.Short.TYPE, "toShortArray"),
    (java.lang.Integer.TYPE, "toIntArray"),
    (java.lang.Long.TYPE, "toLongArray"),
    (java.lang.Float.TYPE, "toFloatArray"),
    (java.lang.Double.TYPE, "toDoubleArray"))

  typeToken.getRawType match {
    // Types whose external representation is already a native SQL type need no conversion.
    case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath

    case c if boxedPrimitives.contains(c) =>
      NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))

    case c if c == classOf[java.sql.Date] =>
      viaDateTimeUtils(c, "toJavaDate")

    case c if c == classOf[java.sql.Timestamp] =>
      viaDateTimeUtils(c, "toJavaTimestamp")

    case c if c == classOf[java.lang.String] =>
      Invoke(getPath, "toString", ObjectType(classOf[String]))

    case c if c == classOf[java.math.BigDecimal] =>
      Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))

    case c if c.isArray =>
      val componentClass = c.getComponentType
      primitiveArrayAccessors.get(componentClass) match {
        case Some(method) =>
          Invoke(getPath, method, ObjectType(c))
        case None =>
          // Non-primitive components are converted one by one.
          mapToObjectArray(
            typeToken.getComponentType,
            getPath,
            inferDataType(componentClass)._1,
            ObjectType(c))
      }

    case c if listType.isAssignableFrom(typeToken) =>
      val itemType = elementType(typeToken)
      val items = mapToObjectArray(
        itemType,
        getPath,
        inferDataType(itemType)._1,
        ObjectType(classOf[Array[Any]]))
      StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", items :: Nil)

    case _ if mapType.isAssignableFrom(typeToken) =>
      val (keyType, valueType) = mapKeyValueType(typeToken)
      val keyDataType = inferDataType(keyType)._1
      val valueDataType = inferDataType(valueType)._1

      // Keys and values are converted separately, then combined by `toJavaMap`.
      val keys = mapToObjectArray(
        keyType,
        Invoke(getPath, "keyArray", ArrayType(keyDataType)),
        keyDataType,
        ObjectType(classOf[Array[Any]]))
      val values = mapToObjectArray(
        valueType,
        Invoke(getPath, "valueArray", ArrayType(valueDataType)),
        valueDataType,
        ObjectType(classOf[Array[Any]]))

      StaticInvoke(
        ArrayBasedMapData,
        ObjectType(classOf[JMap[_, _]]),
        "toJavaMap",
        keys :: values :: Nil)

    case other =>
      // Anything else is treated as a java bean with at least one readable/writable property.
      val properties = getJavaBeanProperties(other)
      assert(properties.length > 0)

      // Setter name -> expression producing the value to pass to that setter.
      val setters = properties.map { property =>
        val propertyType = typeToken.method(property.getReadMethod).getReturnType
        val propertyValue = constructorFor(propertyType, Some(addToPath(property.getName)))
        (property.getWriteMethod.getName, propertyValue)
      }.toMap

      val emptyBean = NewInstance(other, Nil, propagateNull = false, ObjectType(other))
      val initialized = InitializeJavaBean(emptyBean, setters)

      if (path.isEmpty) {
        initialized
      } else {
        // A nested bean whose input is null yields a null object rather than an empty bean.
        expressions.If(
          IsNull(getPath),
          expressions.Literal.create(null, ObjectType(other)),
          initialized)
      }
  }
}
308
+
309
/**
 * Builds the expressions that extract every field of `beanClass` from an input object bound at
 * ordinal 0. The result of the underlying extractor is cast to a `CreateNamedStruct`.
 */
def extractorsFor(beanClass: Class[_]): CreateNamedStruct = {
  val beanType = TypeToken.of(beanClass)
  val extractor =
    extractorFor(BoundReference(0, ObjectType(beanClass), nullable = true), beanType)
  extractor.asInstanceOf[CreateNamedStruct]
}
316
+
317
/**
 * Builds the expression that converts `inputObject`, a value of the Java type described by
 * `typeToken`, into its Catalyst representation.
 *
 * Inputs whose data type is not an ObjectType are returned unchanged. Map types are rejected
 * with an UnsupportedOperationException, as are bean classes without any property that has
 * both a getter and a setter.
 */
private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {

  /** Converts a Java array or list `input` whose items are of `elementType`. */
  def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
    val (itemDataType, itemNullable) = inferDataType(elementType)
    if (ScalaReflection.isNativeType(itemDataType)) {
      // Items of a native type can be wrapped as they are.
      NewInstance(
        classOf[GenericArrayData],
        input :: Nil,
        dataType = ArrayType(itemDataType, itemNullable))
    } else {
      MapObjects(extractorFor(_, elementType), input, ObjectType(elementType.getRawType))
    }
  }

  /** Applies the static method `method` of `target` to the input object. */
  def staticConversion(target: AnyRef, resultType: DataType, method: String): Expression =
    StaticInvoke(target, resultType, method, inputObject :: Nil)

  /** Handles inputs that are still external Java objects. */
  def extractFromObject: Expression = {
    // Boxed primitives are unwrapped by calling their `xxxValue` accessor.
    val unboxingAccessors = Map[Class[_], (String, DataType)](
      (classOf[java.lang.Boolean], ("booleanValue", BooleanType)),
      (classOf[java.lang.Byte], ("byteValue", ByteType)),
      (classOf[java.lang.Short], ("shortValue", ShortType)),
      (classOf[java.lang.Integer], ("intValue", IntegerType)),
      (classOf[java.lang.Long], ("longValue", LongType)),
      (classOf[java.lang.Float], ("floatValue", FloatType)),
      (classOf[java.lang.Double], ("doubleValue", DoubleType)))

    typeToken.getRawType match {
      case c if c == classOf[String] =>
        staticConversion(classOf[UTF8String], StringType, "fromString")

      case c if c == classOf[java.sql.Timestamp] =>
        staticConversion(DateTimeUtils, TimestampType, "fromJavaTimestamp")

      case c if c == classOf[java.sql.Date] =>
        staticConversion(DateTimeUtils, DateType, "fromJavaDate")

      case c if c == classOf[java.math.BigDecimal] =>
        staticConversion(Decimal, DecimalType.SYSTEM_DEFAULT, "apply")

      case c if unboxingAccessors.contains(c) =>
        val (accessor, resultType) = unboxingAccessors(c)
        Invoke(inputObject, accessor, resultType)

      case _ if typeToken.isArray =>
        toCatalystArray(inputObject, typeToken.getComponentType)

      case _ if listType.isAssignableFrom(typeToken) =>
        toCatalystArray(inputObject, elementType(typeToken))

      case _ if mapType.isAssignableFrom(typeToken) =>
        // TODO: for java map, if we get the keys and values by `keySet` and `values`, we can
        // not guarantee they have same iteration order(which is different from scala map).
        // A possible solution is creating a new `MapObjects` that can iterate a map directly.
        throw new UnsupportedOperationException("map type is not supported currently")

      case other =>
        val properties = getJavaBeanProperties(other)
        if (properties.isEmpty) {
          throw new UnsupportedOperationException(s"no encoder found for ${other.getName}")
        }
        // The struct is built from alternating (name literal, value extractor) children.
        CreateNamedStruct(properties.flatMap { property =>
          val propertyType = typeToken.method(property.getReadMethod).getReturnType
          val propertyValue = Invoke(
            inputObject,
            property.getReadMethod.getName,
            inferExternalType(propertyType.getRawType))
          Seq(expressions.Literal(property.getName), extractorFor(propertyValue, propertyType))
        })
    }
  }

  if (inputObject.dataType.isInstanceOf[ObjectType]) {
    extractFromObject
  } else {
    inputObject
  }
}
117
408
}
0 commit comments