Skip to content

Commit 88bbce0

Browse files
cloud-fanmarmbrus
authored andcommitted
[SPARK-11954][SQL] Encoder for JavaBeans
create java version of `constructorFor` and `extractorFor` in `JavaTypeInference` Author: Wenchen Fan <wenchen@databricks.com> This patch had conflicts when merged, resolved by Committer: Michael Armbrust <michael@databricks.com> Closes apache#9937 from cloud-fan/pojo. (cherry picked from commit fd95eea) Signed-off-by: Michael Armbrust <michael@databricks.com>
1 parent 74a2306 commit 88bbce0

File tree

9 files changed

+608
-20
lines changed

9 files changed

+608
-20
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,24 @@ object Encoders {
9797
*/
9898
def STRING: Encoder[java.lang.String] = ExpressionEncoder()
9999

100+
/**
101+
* Creates an encoder for Java Bean of type T.
102+
*
103+
* T must be publicly accessible.
104+
*
105+
* supported types for java bean field:
106+
* - primitive types: boolean, int, double, etc.
107+
* - boxed types: Boolean, Integer, Double, etc.
108+
* - String
109+
* - java.math.BigDecimal
110+
* - time related: java.sql.Date, java.sql.Timestamp
111+
* - collection types: only array and java.util.List currently, map support is in progress
112+
* - nested java bean.
113+
*
114+
* @since 1.6.0
115+
*/
116+
def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass)
117+
100118
/**
101119
* (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
102120
* This encoder maps T into a single byte array (binary) field.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 302 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,20 @@
1717

1818
package org.apache.spark.sql.catalyst
1919

20-
import java.beans.Introspector
20+
import java.beans.{PropertyDescriptor, Introspector}
2121
import java.lang.{Iterable => JIterable}
22-
import java.util.{Iterator => JIterator, Map => JMap}
22+
import java.util.{Iterator => JIterator, Map => JMap, List => JList}
2323

2424
import scala.language.existentials
2525

2626
import com.google.common.reflect.TypeToken
27+
2728
import org.apache.spark.sql.types._
29+
import org.apache.spark.sql.catalyst.expressions._
30+
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
31+
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils}
32+
import org.apache.spark.unsafe.types.UTF8String
33+
2834

2935
/**
3036
* Type-inference utilities for POJOs and Java collections.
@@ -33,13 +39,14 @@ object JavaTypeInference {
3339

3440
private val iterableType = TypeToken.of(classOf[JIterable[_]])
3541
private val mapType = TypeToken.of(classOf[JMap[_, _]])
42+
private val listType = TypeToken.of(classOf[JList[_]])
3643
private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType
3744
private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType
3845
private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType
3946
private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType
4047

4148
/**
42-
* Infers the corresponding SQL data type of a JavaClean class.
49+
* Infers the corresponding SQL data type of a JavaBean class.
4350
* @param beanClass Java type
4451
* @return (SQL data type, nullable)
4552
*/
@@ -58,6 +65,8 @@ object JavaTypeInference {
5865
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
5966

6067
case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
68+
case c: Class[_] if c == classOf[Array[Byte]] => (BinaryType, true)
69+
6170
case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
6271
case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
6372
case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
@@ -87,15 +96,14 @@ object JavaTypeInference {
8796
(ArrayType(dataType, nullable), true)
8897

8998
case _ if mapType.isAssignableFrom(typeToken) =>
90-
val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
91-
val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]])
92-
val keyType = elementType(mapSupertype.resolveType(keySetReturnType))
93-
val valueType = elementType(mapSupertype.resolveType(valuesReturnType))
99+
val (keyType, valueType) = mapKeyValueType(typeToken)
94100
val (keyDataType, _) = inferDataType(keyType)
95101
val (valueDataType, nullable) = inferDataType(valueType)
96102
(MapType(keyDataType, valueDataType, nullable), true)
97103

98104
case _ =>
105+
// TODO: we should only collect properties that have getter and setter. However, some tests
106+
// pass in scala case class as java bean class which doesn't have getter and setter.
99107
val beanInfo = Introspector.getBeanInfo(typeToken.getRawType)
100108
val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
101109
val fields = properties.map { property =>
@@ -107,11 +115,294 @@ object JavaTypeInference {
107115
}
108116
}
109117

118+
private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
119+
val beanInfo = Introspector.getBeanInfo(beanClass)
120+
beanInfo.getPropertyDescriptors
121+
.filter(p => p.getReadMethod != null && p.getWriteMethod != null)
122+
}
123+
110124
private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
111125
val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]]
112-
val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]])
113-
val iteratorType = iterableSupertype.resolveType(iteratorReturnType)
114-
val itemType = iteratorType.resolveType(nextReturnType)
115-
itemType
126+
val iterableSuperType = typeToken2.getSupertype(classOf[JIterable[_]])
127+
val iteratorType = iterableSuperType.resolveType(iteratorReturnType)
128+
iteratorType.resolveType(nextReturnType)
129+
}
130+
131+
private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = {
132+
val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
133+
val mapSuperType = typeToken2.getSupertype(classOf[JMap[_, _]])
134+
val keyType = elementType(mapSuperType.resolveType(keySetReturnType))
135+
val valueType = elementType(mapSuperType.resolveType(valuesReturnType))
136+
keyType -> valueType
137+
}
138+
139+
/**
140+
* Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping
141+
* to a native type, an ObjectType is returned.
142+
*
143+
* Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type
144+
* system. As a result, ObjectType will be returned for things like boxed Integers.
145+
*/
146+
private def inferExternalType(cls: Class[_]): DataType = cls match {
147+
case c if c == java.lang.Boolean.TYPE => BooleanType
148+
case c if c == java.lang.Byte.TYPE => ByteType
149+
case c if c == java.lang.Short.TYPE => ShortType
150+
case c if c == java.lang.Integer.TYPE => IntegerType
151+
case c if c == java.lang.Long.TYPE => LongType
152+
case c if c == java.lang.Float.TYPE => FloatType
153+
case c if c == java.lang.Double.TYPE => DoubleType
154+
case c if c == classOf[Array[Byte]] => BinaryType
155+
case _ => ObjectType(cls)
156+
}
157+
158+
/**
159+
* Returns an expression that can be used to construct an object of java bean `T` given an input
160+
* row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
161+
* of the same name as the constructor arguments. Nested classes will have their fields accessed
162+
* using UnresolvedExtractValue.
163+
*/
164+
def constructorFor(beanClass: Class[_]): Expression = {
165+
constructorFor(TypeToken.of(beanClass), None)
166+
}
167+
168+
private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
169+
/** Returns the current path with a sub-field extracted. */
170+
def addToPath(part: String): Expression = path
171+
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
172+
.getOrElse(UnresolvedAttribute(part))
173+
174+
/** Returns the current path or `BoundReference`. */
175+
def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true))
176+
177+
typeToken.getRawType match {
178+
case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath
179+
180+
case c if c == classOf[java.lang.Short] =>
181+
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
182+
case c if c == classOf[java.lang.Integer] =>
183+
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
184+
case c if c == classOf[java.lang.Long] =>
185+
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
186+
case c if c == classOf[java.lang.Double] =>
187+
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
188+
case c if c == classOf[java.lang.Byte] =>
189+
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
190+
case c if c == classOf[java.lang.Float] =>
191+
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
192+
case c if c == classOf[java.lang.Boolean] =>
193+
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
194+
195+
case c if c == classOf[java.sql.Date] =>
196+
StaticInvoke(
197+
DateTimeUtils,
198+
ObjectType(c),
199+
"toJavaDate",
200+
getPath :: Nil,
201+
propagateNull = true)
202+
203+
case c if c == classOf[java.sql.Timestamp] =>
204+
StaticInvoke(
205+
DateTimeUtils,
206+
ObjectType(c),
207+
"toJavaTimestamp",
208+
getPath :: Nil,
209+
propagateNull = true)
210+
211+
case c if c == classOf[java.lang.String] =>
212+
Invoke(getPath, "toString", ObjectType(classOf[String]))
213+
214+
case c if c == classOf[java.math.BigDecimal] =>
215+
Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
216+
217+
case c if c.isArray =>
218+
val elementType = c.getComponentType
219+
val primitiveMethod = elementType match {
220+
case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray")
221+
case c if c == java.lang.Byte.TYPE => Some("toByteArray")
222+
case c if c == java.lang.Short.TYPE => Some("toShortArray")
223+
case c if c == java.lang.Integer.TYPE => Some("toIntArray")
224+
case c if c == java.lang.Long.TYPE => Some("toLongArray")
225+
case c if c == java.lang.Float.TYPE => Some("toFloatArray")
226+
case c if c == java.lang.Double.TYPE => Some("toDoubleArray")
227+
case _ => None
228+
}
229+
230+
primitiveMethod.map { method =>
231+
Invoke(getPath, method, ObjectType(c))
232+
}.getOrElse {
233+
Invoke(
234+
MapObjects(
235+
p => constructorFor(typeToken.getComponentType, Some(p)),
236+
getPath,
237+
inferDataType(elementType)._1),
238+
"array",
239+
ObjectType(c))
240+
}
241+
242+
case c if listType.isAssignableFrom(typeToken) =>
243+
val et = elementType(typeToken)
244+
val array =
245+
Invoke(
246+
MapObjects(
247+
p => constructorFor(et, Some(p)),
248+
getPath,
249+
inferDataType(et)._1),
250+
"array",
251+
ObjectType(classOf[Array[Any]]))
252+
253+
StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil)
254+
255+
case _ if mapType.isAssignableFrom(typeToken) =>
256+
val (keyType, valueType) = mapKeyValueType(typeToken)
257+
val keyDataType = inferDataType(keyType)._1
258+
val valueDataType = inferDataType(valueType)._1
259+
260+
val keyData =
261+
Invoke(
262+
MapObjects(
263+
p => constructorFor(keyType, Some(p)),
264+
Invoke(getPath, "keyArray", ArrayType(keyDataType)),
265+
keyDataType),
266+
"array",
267+
ObjectType(classOf[Array[Any]]))
268+
269+
val valueData =
270+
Invoke(
271+
MapObjects(
272+
p => constructorFor(valueType, Some(p)),
273+
Invoke(getPath, "valueArray", ArrayType(valueDataType)),
274+
valueDataType),
275+
"array",
276+
ObjectType(classOf[Array[Any]]))
277+
278+
StaticInvoke(
279+
ArrayBasedMapData,
280+
ObjectType(classOf[JMap[_, _]]),
281+
"toJavaMap",
282+
keyData :: valueData :: Nil)
283+
284+
case other =>
285+
val properties = getJavaBeanProperties(other)
286+
assert(properties.length > 0)
287+
288+
val setters = properties.map { p =>
289+
val fieldName = p.getName
290+
val fieldType = typeToken.method(p.getReadMethod).getReturnType
291+
p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName)))
292+
}.toMap
293+
294+
val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other))
295+
val result = InitializeJavaBean(newInstance, setters)
296+
297+
if (path.nonEmpty) {
298+
expressions.If(
299+
IsNull(getPath),
300+
expressions.Literal.create(null, ObjectType(other)),
301+
result
302+
)
303+
} else {
304+
result
305+
}
306+
}
307+
}
308+
309+
/**
310+
* Returns expressions for extracting all the fields from the given type.
311+
*/
312+
def extractorsFor(beanClass: Class[_]): CreateNamedStruct = {
313+
val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
314+
extractorFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct]
315+
}
316+
317+
private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
318+
319+
def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
320+
val (dataType, nullable) = inferDataType(elementType)
321+
if (ScalaReflection.isNativeType(dataType)) {
322+
NewInstance(
323+
classOf[GenericArrayData],
324+
input :: Nil,
325+
dataType = ArrayType(dataType, nullable))
326+
} else {
327+
MapObjects(extractorFor(_, elementType), input, ObjectType(elementType.getRawType))
328+
}
329+
}
330+
331+
if (!inputObject.dataType.isInstanceOf[ObjectType]) {
332+
inputObject
333+
} else {
334+
typeToken.getRawType match {
335+
case c if c == classOf[String] =>
336+
StaticInvoke(
337+
classOf[UTF8String],
338+
StringType,
339+
"fromString",
340+
inputObject :: Nil)
341+
342+
case c if c == classOf[java.sql.Timestamp] =>
343+
StaticInvoke(
344+
DateTimeUtils,
345+
TimestampType,
346+
"fromJavaTimestamp",
347+
inputObject :: Nil)
348+
349+
case c if c == classOf[java.sql.Date] =>
350+
StaticInvoke(
351+
DateTimeUtils,
352+
DateType,
353+
"fromJavaDate",
354+
inputObject :: Nil)
355+
356+
case c if c == classOf[java.math.BigDecimal] =>
357+
StaticInvoke(
358+
Decimal,
359+
DecimalType.SYSTEM_DEFAULT,
360+
"apply",
361+
inputObject :: Nil)
362+
363+
case c if c == classOf[java.lang.Boolean] =>
364+
Invoke(inputObject, "booleanValue", BooleanType)
365+
case c if c == classOf[java.lang.Byte] =>
366+
Invoke(inputObject, "byteValue", ByteType)
367+
case c if c == classOf[java.lang.Short] =>
368+
Invoke(inputObject, "shortValue", ShortType)
369+
case c if c == classOf[java.lang.Integer] =>
370+
Invoke(inputObject, "intValue", IntegerType)
371+
case c if c == classOf[java.lang.Long] =>
372+
Invoke(inputObject, "longValue", LongType)
373+
case c if c == classOf[java.lang.Float] =>
374+
Invoke(inputObject, "floatValue", FloatType)
375+
case c if c == classOf[java.lang.Double] =>
376+
Invoke(inputObject, "doubleValue", DoubleType)
377+
378+
case _ if typeToken.isArray =>
379+
toCatalystArray(inputObject, typeToken.getComponentType)
380+
381+
case _ if listType.isAssignableFrom(typeToken) =>
382+
toCatalystArray(inputObject, elementType(typeToken))
383+
384+
case _ if mapType.isAssignableFrom(typeToken) =>
385+
// TODO: for java map, if we get the keys and values by `keySet` and `values`, we can
386+
// not guarantee they have same iteration order(which is different from scala map).
387+
// A possible solution is creating a new `MapObjects` that can iterate a map directly.
388+
throw new UnsupportedOperationException("map type is not supported currently")
389+
390+
case other =>
391+
val properties = getJavaBeanProperties(other)
392+
if (properties.length > 0) {
393+
CreateNamedStruct(properties.flatMap { p =>
394+
val fieldName = p.getName
395+
val fieldType = typeToken.method(p.getReadMethod).getReturnType
396+
val fieldValue = Invoke(
397+
inputObject,
398+
p.getReadMethod.getName,
399+
inferExternalType(fieldType.getRawType))
400+
expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
401+
})
402+
} else {
403+
throw new UnsupportedOperationException(s"no encoder found for ${other.getName}")
404+
}
405+
}
406+
}
116407
}
117408
}

0 commit comments

Comments
 (0)