Skip to content

Commit ff156a3

Browse files
cloud-fanmarmbrus
authored andcommitted
[SPARK-11819][SQL] nice error message for missing encoder
before this PR, when users try to get an encoder for an un-supported class, they will only get a very simple error message like `Encoder for type xxx is not supported`. After this PR, the error message become more friendly, for example: ``` No Encoder found for abc.xyz.NonEncodable - array element class: "abc.xyz.NonEncodable" - field (class: "scala.Array", name: "arrayField") - root class: "abc.xyz.AnotherClass" ``` Author: Wenchen Fan <wenchen@databricks.com> Closes apache#9810 from cloud-fan/error-message. (cherry picked from commit 3b9d2a3) Signed-off-by: Michael Armbrust <michael@databricks.com>
1 parent 119f92b commit ff156a3

File tree

2 files changed

+129
-23
lines changed

2 files changed

+129
-23
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 67 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ object ScalaReflection extends ScalaReflection {
6363
case t if t <:< definitions.BooleanTpe => BooleanType
6464
case t if t <:< localTypeOf[Array[Byte]] => BinaryType
6565
case _ =>
66-
val className: String = tpe.erasure.typeSymbol.asClass.fullName
66+
val className = getClassNameFromType(tpe)
6767
className match {
6868
case "scala.Array" =>
6969
val TypeRef(_, _, Seq(elementType)) = tpe
@@ -320,9 +320,23 @@ object ScalaReflection extends ScalaReflection {
320320
}
321321
}
322322

323-
/** Returns expressions for extracting all the fields from the given type. */
323+
/**
324+
* Returns expressions for extracting all the fields from the given type.
325+
*
326+
* If the given type is not supported, i.e. there is no encoder can be built for this type,
327+
* an [[UnsupportedOperationException]] will be thrown with detailed error message to explain
328+
* the type path walked so far and which class we are not supporting.
329+
* There are 4 kinds of type path:
330+
* * the root type: `root class: "abc.xyz.MyClass"`
331+
* * the value type of [[Option]]: `option value class: "abc.xyz.MyClass"`
332+
* * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"`
333+
* * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")`
334+
*/
324335
def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
325-
extractorFor(inputObject, localTypeOf[T]) match {
336+
val tpe = localTypeOf[T]
337+
val clsName = getClassNameFromType(tpe)
338+
val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
339+
extractorFor(inputObject, tpe, walkedTypePath) match {
326340
case s: CreateNamedStruct => s
327341
case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
328342
}
@@ -331,7 +345,28 @@ object ScalaReflection extends ScalaReflection {
331345
/** Helper for extracting internal fields from a case class. */
332346
private def extractorFor(
333347
inputObject: Expression,
334-
tpe: `Type`): Expression = ScalaReflectionLock.synchronized {
348+
tpe: `Type`,
349+
walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
350+
351+
def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
352+
val externalDataType = dataTypeFor(elementType)
353+
val Schema(catalystType, nullable) = silentSchemaFor(elementType)
354+
if (isNativeType(catalystType)) {
355+
NewInstance(
356+
classOf[GenericArrayData],
357+
input :: Nil,
358+
dataType = ArrayType(catalystType, nullable))
359+
} else {
360+
val clsName = getClassNameFromType(elementType)
361+
val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
362+
// `MapObjects` will run `extractorFor` lazily, we need to eagerly call `extractorFor` here
363+
// to trigger the type check.
364+
extractorFor(inputObject, elementType, newPath)
365+
366+
MapObjects(extractorFor(_, elementType, newPath), input, externalDataType)
367+
}
368+
}
369+
335370
if (!inputObject.dataType.isInstanceOf[ObjectType]) {
336371
inputObject
337372
} else {
@@ -378,15 +413,16 @@ object ScalaReflection extends ScalaReflection {
378413

379414
// For non-primitives, we can just extract the object from the Option and then recurse.
380415
case other =>
381-
val className: String = optType.erasure.typeSymbol.asClass.fullName
416+
val className = getClassNameFromType(optType)
382417
val classObj = Utils.classForName(className)
383418
val optionObjectType = ObjectType(classObj)
419+
val newPath = s"""- option value class: "$className"""" +: walkedTypePath
384420

385421
val unwrapped = UnwrapOption(optionObjectType, inputObject)
386422
expressions.If(
387423
IsNull(unwrapped),
388-
expressions.Literal.create(null, schemaFor(optType).dataType),
389-
extractorFor(unwrapped, optType))
424+
expressions.Literal.create(null, silentSchemaFor(optType).dataType),
425+
extractorFor(unwrapped, optType, newPath))
390426
}
391427

392428
case t if t <:< localTypeOf[Product] =>
@@ -412,7 +448,10 @@ object ScalaReflection extends ScalaReflection {
412448
val fieldName = p.name.toString
413449
val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
414450
val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
415-
expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
451+
val clsName = getClassNameFromType(fieldType)
452+
val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
453+
454+
expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil
416455
})
417456

418457
case t if t <:< localTypeOf[Array[_]] =>
@@ -500,23 +539,11 @@ object ScalaReflection extends ScalaReflection {
500539
Invoke(inputObject, "booleanValue", BooleanType)
501540

502541
case other =>
503-
throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
542+
throw new UnsupportedOperationException(
543+
s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
504544
}
505545
}
506546
}
507-
508-
private def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
509-
val externalDataType = dataTypeFor(elementType)
510-
val Schema(catalystType, nullable) = schemaFor(elementType)
511-
if (isNativeType(catalystType)) {
512-
NewInstance(
513-
classOf[GenericArrayData],
514-
input :: Nil,
515-
dataType = ArrayType(catalystType, nullable))
516-
} else {
517-
MapObjects(extractorFor(_, elementType), input, externalDataType)
518-
}
519-
}
520547
}
521548

522549
/**
@@ -561,7 +588,7 @@ trait ScalaReflection {
561588

562589
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
563590
def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
564-
val className: String = tpe.erasure.typeSymbol.asClass.fullName
591+
val className = getClassNameFromType(tpe)
565592
tpe match {
566593
case t if Utils.classIsLoadable(className) &&
567594
Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
@@ -637,6 +664,23 @@ trait ScalaReflection {
637664
}
638665
}
639666

667+
/**
668+
* Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
669+
*
670+
* Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return
671+
* `NullType` silently instead.
672+
*/
673+
private def silentSchemaFor(tpe: `Type`): Schema = try {
674+
schemaFor(tpe)
675+
} catch {
676+
case _: UnsupportedOperationException => Schema(NullType, nullable = true)
677+
}
678+
679+
/** Returns the full class name for a type. */
680+
private def getClassNameFromType(tpe: `Type`): String = {
681+
tpe.erasure.typeSymbol.asClass.fullName
682+
}
683+
640684
/**
641685
* Returns classes of input parameters of scala function object.
642686
*/

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,22 @@
1717

1818
package org.apache.spark.sql.catalyst.encoders
1919

20+
import scala.reflect.ClassTag
21+
2022
import org.apache.spark.SparkFunSuite
2123
import org.apache.spark.sql.Encoders
2224

25+
class NonEncodable(i: Int)
26+
27+
case class ComplexNonEncodable1(name1: NonEncodable)
28+
29+
case class ComplexNonEncodable2(name2: ComplexNonEncodable1)
30+
31+
case class ComplexNonEncodable3(name3: Option[NonEncodable])
32+
33+
case class ComplexNonEncodable4(name4: Array[NonEncodable])
34+
35+
case class ComplexNonEncodable5(name5: Option[Array[NonEncodable]])
2336

2437
class EncoderErrorMessageSuite extends SparkFunSuite {
2538

@@ -37,4 +50,53 @@ class EncoderErrorMessageSuite extends SparkFunSuite {
3750
intercept[UnsupportedOperationException] { Encoders.javaSerialization[Long] }
3851
intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] }
3952
}
53+
54+
test("nice error message for missing encoder") {
55+
val errorMsg1 =
56+
intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable1]).getMessage
57+
assert(errorMsg1.contains(
58+
s"""root class: "${clsName[ComplexNonEncodable1]}""""))
59+
assert(errorMsg1.contains(
60+
s"""field (class: "${clsName[NonEncodable]}", name: "name1")"""))
61+
62+
val errorMsg2 =
63+
intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable2]).getMessage
64+
assert(errorMsg2.contains(
65+
s"""root class: "${clsName[ComplexNonEncodable2]}""""))
66+
assert(errorMsg2.contains(
67+
s"""field (class: "${clsName[ComplexNonEncodable1]}", name: "name2")"""))
68+
assert(errorMsg1.contains(
69+
s"""field (class: "${clsName[NonEncodable]}", name: "name1")"""))
70+
71+
val errorMsg3 =
72+
intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable3]).getMessage
73+
assert(errorMsg3.contains(
74+
s"""root class: "${clsName[ComplexNonEncodable3]}""""))
75+
assert(errorMsg3.contains(
76+
s"""field (class: "scala.Option", name: "name3")"""))
77+
assert(errorMsg3.contains(
78+
s"""option value class: "${clsName[NonEncodable]}""""))
79+
80+
val errorMsg4 =
81+
intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable4]).getMessage
82+
assert(errorMsg4.contains(
83+
s"""root class: "${clsName[ComplexNonEncodable4]}""""))
84+
assert(errorMsg4.contains(
85+
s"""field (class: "scala.Array", name: "name4")"""))
86+
assert(errorMsg4.contains(
87+
s"""array element class: "${clsName[NonEncodable]}""""))
88+
89+
val errorMsg5 =
90+
intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable5]).getMessage
91+
assert(errorMsg5.contains(
92+
s"""root class: "${clsName[ComplexNonEncodable5]}""""))
93+
assert(errorMsg5.contains(
94+
s"""field (class: "scala.Option", name: "name5")"""))
95+
assert(errorMsg5.contains(
96+
s"""option value class: "scala.Array""""))
97+
assert(errorMsg5.contains(
98+
s"""array element class: "${clsName[NonEncodable]}""""))
99+
}
100+
101+
private def clsName[T : ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName
40102
}

0 commit comments

Comments
 (0)