
Commit 74a2306

cloud-fan authored and marmbrus committed
[SPARK-11856][SQL] add type cast if the real type is different but compatible with encoder schema
When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff and lose the required data type, which may lead to a runtime error if the real type doesn't match the encoder's schema. For example, if we build an encoder for `case class Data(a: Int, b: String)` and the real type is `[a: int, b: long]`, we hit a runtime error saying that we can't construct class `Data` with int and long, because we lost the information that `b` should be a string.

Author: Wenchen Fan <wenchen@databricks.com>

Closes apache#9840 from cloud-fan/err-msg.

(cherry picked from commit 9df2462)
Signed-off-by: Michael Armbrust <michael@databricks.com>
1 parent 6e3e3c6 commit 74a2306
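To make the scenario concrete, here is a minimal sketch of the behavior change against the Spark 1.6-era Dataset API. The DataFrame construction is a hypothetical stand-in for any source whose schema is `[a: int, b: long]`:

    import org.apache.spark.sql.SQLContext

    case class Data(a: Int, b: String)

    // Hypothetical input whose schema is [a: int, b: long]; `b` does not
    // match the String field declared in Data.
    val sqlContext: SQLContext = ???
    import sqlContext.implicits._
    val df = sqlContext.range(10).selectExpr("cast(id as int) as a", "id as b")

    // Before this patch: fails at runtime while constructing Data instances.
    // After this patch: the analyzer inserts UpCast(b, StringType, ...), which
    // ResolveUpCast rewrites to a plain Cast (long -> string cannot truncate),
    // so the conversion succeeds. A genuinely lossy mismatch (e.g. a bigint
    // column bound to the Int field `a`) instead fails fast at analysis time
    // with "Cannot up cast ... as it may truncate" plus the walked type path.
    val ds = df.as[Data]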

File tree

10 files changed: +335 −32 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 74 additions & 19 deletions
@@ -18,9 +18,8 @@
 package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, ArrayData, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
@@ -117,31 +116,75 @@ object ScalaReflection extends ScalaReflection {
    * from ordinal 0 (since there are no names to map to). The actual location can be moved by
    * calling resolve/bind with a new schema.
    */
-  def constructorFor[T : TypeTag]: Expression = constructorFor(localTypeOf[T], None)
+  def constructorFor[T : TypeTag]: Expression = {
+    val tpe = localTypeOf[T]
+    val clsName = getClassNameFromType(tpe)
+    val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
+    constructorFor(tpe, None, walkedTypePath)
+  }
 
   private def constructorFor(
       tpe: `Type`,
-      path: Option[Expression]): Expression = ScalaReflectionLock.synchronized {
+      path: Option[Expression],
+      walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
 
     /** Returns the current path with a sub-field extracted. */
-    def addToPath(part: String): Expression = path
-      .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
-      .getOrElse(UnresolvedAttribute(part))
+    def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
+      val newPath = path
+        .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
+        .getOrElse(UnresolvedAttribute(part))
+      upCastToExpectedType(newPath, dataType, walkedTypePath)
+    }
 
     /** Returns the current path with a field at ordinal extracted. */
-    def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
-      .map(p => GetStructField(p, ordinal))
-      .getOrElse(BoundReference(ordinal, dataType, false))
+    def addToPathOrdinal(
+        ordinal: Int,
+        dataType: DataType,
+        walkedTypePath: Seq[String]): Expression = {
+      val newPath = path
+        .map(p => GetStructField(p, ordinal))
+        .getOrElse(BoundReference(ordinal, dataType, false))
+      upCastToExpectedType(newPath, dataType, walkedTypePath)
+    }
 
     /** Returns the current path or `BoundReference`. */
-    def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true))
+    def getPath: Expression = {
+      val dataType = schemaFor(tpe).dataType
+      if (path.isDefined) {
+        path.get
+      } else {
+        upCastToExpectedType(BoundReference(0, dataType, true), dataType, walkedTypePath)
+      }
+    }
+
+    /**
+     * When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff
+     * and lost the required data type, which may lead to runtime error if the real type doesn't
+     * match the encoder's schema.
+     * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type
+     * is [a: int, b: long], then we will hit runtime error and say that we can't construct class
+     * `Data` with int and long, because we lost the information that `b` should be a string.
+     *
+     * This method help us "remember" the required data type by adding a `UpCast`. Note that we
+     * don't need to cast struct type because there must be `UnresolvedExtractValue` or
+     * `GetStructField` wrapping it, thus we only need to handle leaf type.
+     */
+    def upCastToExpectedType(
+        expr: Expression,
+        expected: DataType,
+        walkedTypePath: Seq[String]): Expression = expected match {
+      case _: StructType => expr
+      case _ => UpCast(expr, expected, walkedTypePath)
+    }
 
     tpe match {
       case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
 
       case t if t <:< localTypeOf[Option[_]] =>
         val TypeRef(_, _, Seq(optType)) = t
-        WrapOption(constructorFor(optType, path))
+        val className = getClassNameFromType(optType)
+        val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath
+        WrapOption(constructorFor(optType, path, newTypePath))
 
       case t if t <:< localTypeOf[java.lang.Integer] =>
         val boxedType = classOf[java.lang.Integer]
@@ -219,9 +262,11 @@ object ScalaReflection extends ScalaReflection {
         primitiveMethod.map { method =>
           Invoke(getPath, method, arrayClassFor(elementType))
         }.getOrElse {
+          val className = getClassNameFromType(elementType)
+          val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
           Invoke(
             MapObjects(
-              p => constructorFor(elementType, Some(p)),
+              p => constructorFor(elementType, Some(p), newTypePath),
               getPath,
               schemaFor(elementType).dataType),
             "array",
@@ -230,10 +275,12 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[Seq[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
+        val className = getClassNameFromType(elementType)
+        val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
         val arrayData =
           Invoke(
             MapObjects(
-              p => constructorFor(elementType, Some(p)),
+              p => constructorFor(elementType, Some(p), newTypePath),
               getPath,
               schemaFor(elementType).dataType),
             "array",
@@ -246,12 +293,13 @@ object ScalaReflection extends ScalaReflection {
           arrayData :: Nil)
 
       case t if t <:< localTypeOf[Map[_, _]] =>
+        // TODO: add walked type path for map
         val TypeRef(_, _, Seq(keyType, valueType)) = t
 
         val keyData =
           Invoke(
             MapObjects(
-              p => constructorFor(keyType, Some(p)),
+              p => constructorFor(keyType, Some(p), walkedTypePath),
               Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
               schemaFor(keyType).dataType),
             "array",
@@ -260,7 +308,7 @@ object ScalaReflection extends ScalaReflection {
         val valueData =
           Invoke(
             MapObjects(
-              p => constructorFor(valueType, Some(p)),
+              p => constructorFor(valueType, Some(p), walkedTypePath),
               Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
               schemaFor(valueType).dataType),
             "array",
@@ -297,12 +345,19 @@ object ScalaReflection extends ScalaReflection {
             val fieldName = p.name.toString
             val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
             val dataType = schemaFor(fieldType).dataType
-
+            val clsName = getClassNameFromType(fieldType)
+            val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
            // For tuples, we based grab the inner fields by ordinal instead of name.
             if (cls.getName startsWith "scala.Tuple") {
-              constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
+              constructorFor(
+                fieldType,
+                Some(addToPathOrdinal(i, dataType, newTypePath)),
+                newTypePath)
             } else {
-              constructorFor(fieldType, Some(addToPath(fieldName)))
+              constructorFor(
+                fieldType,
+                Some(addToPath(fieldName, dataType, newTypePath)),
+                newTypePath)
             }
           }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 40 additions & 0 deletions
@@ -72,6 +72,7 @@ class Analyzer(
       ResolveReferences ::
       ResolveGroupingAnalytics ::
       ResolvePivot ::
+      ResolveUpCast ::
       ResolveSortReferences ::
       ResolveGenerate ::
       ResolveFunctions ::
@@ -1182,3 +1183,42 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
     }
   }
 }
+
+/**
+ * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate.
+ */
+object ResolveUpCast extends Rule[LogicalPlan] {
+  private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
+    throw new AnalysisException(s"Cannot up cast `${from.prettyString}` from " +
+      s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
+      "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
+      "You can either add an explicit cast to the input data or choose a higher precision " +
+      "type of the field in the target object")
+  }
+
+  private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
+    val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
+    val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
+    toPrecedence > 0 && fromPrecedence > toPrecedence
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    plan transformAllExpressions {
+      case u @ UpCast(child, _, _) if !child.resolved => u
+
+      case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
+        case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
+          fail(child, to, walkedTypePath)
+        case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
+          fail(child, to, walkedTypePath)
+        case (from, to) if illegalNumericPrecedence(from, to) =>
+          fail(child, to, walkedTypePath)
+        case (TimestampType, DateType) =>
+          fail(child, DateType, walkedTypePath)
+        case (StringType, to: NumericType) =>
+          fail(child, to, walkedTypePath)
+        case _ => Cast(child, dataType)
+      }
+    }
+  }
+}
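To make the rule's case analysis concrete, a few sample outcomes (a sketch derived from the match arms above; types written in their `simpleString` form, and the decimal widths are illustrative):

    // UpCast(e, target) resolves as follows:
    //
    //   int          -> long            Cast(e, LongType)    (widening is safe)
    //   long         -> int             fails: illegalNumericPrecedence
    //   double       -> decimal(38,18)  fails: the decimal is not wider than double
    //   decimal(5,2) -> int             fails: isTighterThan is false (fractional digits lost)
    //   timestamp    -> date            fails: drops the time-of-day component
    //   string       -> int             fails: parsing may truncate or fail
    //   long         -> string          Cast(e, StringType)  (falls through to Cast)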

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ object HiveTypeCoercion {
 
   // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
   // The conversion for integral and floating point types have a linear widening hierarchy:
-  private val numericPrecedence =
+  private[sql] val numericPrecedence =
     IndexedSeq(
       ByteType,
       ShortType,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 3 additions & 1 deletion
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtract
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
@@ -235,12 +236,13 @@ case class ExpressionEncoder[T](
 
     val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
     val analyzedPlan = SimpleAnalyzer.execute(plan)
+    val optimizedPlan = SimplifyCasts(analyzedPlan)
 
     // In order to construct instances of inner classes (for example those declared in a REPL cell),
     // we need an instance of the outer scope. This rule substitues those outer objects into
     // expressions that are missing them by looking up the name in the SQLContexts `outerScopes`
     // registry.
-    copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform {
+    copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform {
       case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
         val outer = outerScopes.get(n.cls.getDeclaringClass.getName)
         if (outer == null) {
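Running `SimplifyCasts` here matters because, after this patch, every leaf access in `fromRowExpression` carries an `UpCast` that the analyzer rewrites into a `Cast`, even when the input type already matches. A sketch of the cleanup, assuming the standard `SimplifyCasts` rule (which drops casts whose child already has the target type):

    // After analysis, a field that already matches its encoder type still
    // carries a redundant cast:
    //   Cast(BoundReference(0, IntegerType, nullable = true), IntegerType)
    // SimplifyCasts rewrites it back to:
    //   BoundReference(0, IntegerType, nullable = true)
    // so well-typed inputs pay no runtime cost for the new UpCasts.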

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 9 additions & 0 deletions
@@ -915,3 +915,12 @@ case class Cast(child: Expression, dataType: DataType)
     """
   }
 }
+
+/**
+ * Cast the child expression to the target data type, but will throw error if the cast might
+ * truncate, e.g. long -> int, timestamp -> data.
+ */
+case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String])
+  extends UnaryExpression with Unevaluable {
+  override lazy val resolved = false
+}
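Marking `UpCast` as `Unevaluable` with `resolved = false` presumably guarantees that it can never survive analysis: an unresolved expression keeps the whole plan unresolved, so either `ResolveUpCast` replaces it with a `Cast` or the analyzer's checks reject the plan.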

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
 case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
 
   /**
-   * Returns Aliased [[Expressions]] that could be used to construct a flattened version of this
+   * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this
    * StructType.
    */
   def flatten: Seq[NamedExpression] = valExprs.zip(names).map {

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala

Lines changed: 12 additions & 0 deletions
@@ -90,6 +90,18 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
     case _ => false
   }
 
+  /**
+   * Returns whether this DecimalType is tighter than `other`. If yes, it means `this`
+   * can be casted into `other` safely without losing any precision or range.
+   */
+  private[sql] def isTighterThan(other: DataType): Boolean = other match {
+    case dt: DecimalType =>
+      (precision - scale) <= (dt.precision - dt.scale) && scale <= dt.scale
+    case dt: IntegralType =>
+      isTighterThan(DecimalType.forType(dt))
+    case _ => false
+  }
+
   /**
    * The default size of a value of the DecimalType is 4096 bytes.
    */
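A few sample evaluations of `isTighterThan` (a sketch; the method is `private[sql]`, so the calls are illustrative, and it assumes `DecimalType.forType` maps an integral type to the smallest decimal that holds it, e.g. `IntegerType` to `DecimalType(10, 0)`):

    // (precision - scale) counts integral digits; `this` is tighter when
    // `other` has at least as many integral digits and at least the scale:
    DecimalType(5, 2).isTighterThan(DecimalType(10, 2))  // true:  3 <= 8 and 2 <= 2
    DecimalType(5, 2).isTighterThan(DecimalType(10, 1))  // false: scale 2 > 1 loses digits
    DecimalType(5, 0).isTighterThan(IntegerType)         // true:  fits in DecimalType(10, 0)
    DecimalType(11, 0).isTighterThan(IntegerType)        // false: 11 digits can overflow an int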
