18
18
package org .apache .spark .sql .catalyst
19
19
20
20
import org .apache .spark .sql .catalyst .analysis .{UnresolvedExtractValue , UnresolvedAttribute }
21
- import org .apache .spark .sql .catalyst .util .{GenericArrayData , ArrayBasedMapData , ArrayData , DateTimeUtils }
21
+ import org .apache .spark .sql .catalyst .util .{GenericArrayData , ArrayBasedMapData , DateTimeUtils }
22
22
import org .apache .spark .sql .catalyst .expressions ._
23
- import org .apache .spark .sql .catalyst .plans .logical .LocalRelation
24
23
import org .apache .spark .sql .types ._
25
24
import org .apache .spark .unsafe .types .UTF8String
26
25
import org .apache .spark .util .Utils
@@ -117,31 +116,75 @@ object ScalaReflection extends ScalaReflection {
117
116
* from ordinal 0 (since there are no names to map to). The actual location can be moved by
118
117
* calling resolve/bind with a new schema.
119
118
*/
120
def constructorFor[T: TypeTag]: Expression = {
  val tpe = localTypeOf[T]
  // Seed the walked type path with the root class name. The path is threaded through the
  // recursive overload below, which extends it as it descends into the type.
  val clsName = getClassNameFromType(tpe)
  val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
  constructorFor(tpe, None, walkedTypePath)
}
121
125
122
126
private def constructorFor (
123
127
tpe : `Type`,
124
- path : Option [Expression ]): Expression = ScalaReflectionLock .synchronized {
128
+ path : Option [Expression ],
129
+ walkedTypePath : Seq [String ]): Expression = ScalaReflectionLock .synchronized {
125
130
126
131
/**
 * Returns the current path with a sub-field extracted by name.
 *
 * When there is an enclosing `path`, the field is read from it with `UnresolvedExtractValue`;
 * otherwise the field is an `UnresolvedAttribute` named `part`. The resulting expression is
 * then passed through `upCastToExpectedType` together with `dataType` and `walkedTypePath`.
 */
def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
  val newPath = path
    .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
    .getOrElse(UnresolvedAttribute(part))
  upCastToExpectedType(newPath, dataType, walkedTypePath)
}
130
138
131
139
/**
 * Returns the current path with a field at ordinal extracted.
 *
 * When there is an enclosing `path`, the field is read from it with `GetStructField`;
 * otherwise a `BoundReference` to `ordinal` is created. The resulting expression is then
 * passed through `upCastToExpectedType` together with `dataType` and `walkedTypePath`.
 */
def addToPathOrdinal(
    ordinal: Int,
    dataType: DataType,
    walkedTypePath: Seq[String]): Expression = {
  val newPath = path
    .map(p => GetStructField(p, ordinal))
    // Third argument is passed as `false` here, unlike the root reference built in `getPath`.
    .getOrElse(BoundReference(ordinal, dataType, false))
  upCastToExpectedType(newPath, dataType, walkedTypePath)
}
135
149
136
150
/**
 * Returns the current path or, when there is none, a `BoundReference` to ordinal 0 passed
 * through `upCastToExpectedType` with the schema type of `tpe`.
 */
def getPath: Expression = path.getOrElse {
  // The schema lookup is only needed for the root reference, so it is evaluated lazily
  // inside the by-name argument rather than on every call.
  val dataType = schemaFor(tpe).dataType
  upCastToExpectedType(BoundReference(0, dataType, true), dataType, walkedTypePath)
}
159
+
160
/**
 * When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff
 * and lose the required data type, which may lead to a runtime error if the real type doesn't
 * match the encoder's schema.
 * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type
 * is [a: int, b: long], then we will hit a runtime error saying that we can't construct class
 * `Data` with int and long, because we lost the information that `b` should be a string.
 *
 * This method helps us "remember" the required data type by adding an `UpCast`. Note that we
 * don't need to cast struct types because there must be an `UnresolvedExtractValue` or
 * `GetStructField` wrapping them, thus we only need to handle leaf types.
 *
 * @param expr the expression to wrap
 * @param expected the data type the encoder requires for `expr`
 * @param walkedTypePath the type path walked so far; handed to `UpCast` unchanged
 * @return `expr` itself when `expected` is a `StructType`, otherwise `expr` wrapped in `UpCast`
 */
def upCastToExpectedType(
    expr: Expression,
    expected: DataType,
    walkedTypePath: Seq[String]): Expression = expected match {
  case _: StructType => expr
  case _ => UpCast(expr, expected, walkedTypePath)
}
138
179
139
180
tpe match {
140
181
case t if ! dataTypeFor(t).isInstanceOf [ObjectType ] => getPath
141
182
142
183
case t if t <:< localTypeOf[Option [_]] =>
143
184
val TypeRef (_, _, Seq (optType)) = t
144
- WrapOption (constructorFor(optType, path))
185
+ val className = getClassNameFromType(optType)
186
+ val newTypePath = s """ - option value class: " $className" """ +: walkedTypePath
187
+ WrapOption (constructorFor(optType, path, newTypePath))
145
188
146
189
case t if t <:< localTypeOf[java.lang.Integer ] =>
147
190
val boxedType = classOf [java.lang.Integer ]
@@ -219,9 +262,11 @@ object ScalaReflection extends ScalaReflection {
219
262
primitiveMethod.map { method =>
220
263
Invoke (getPath, method, arrayClassFor(elementType))
221
264
}.getOrElse {
265
+ val className = getClassNameFromType(elementType)
266
+ val newTypePath = s """ - array element class: " $className" """ +: walkedTypePath
222
267
Invoke (
223
268
MapObjects (
224
- p => constructorFor(elementType, Some (p)),
269
+ p => constructorFor(elementType, Some (p), newTypePath ),
225
270
getPath,
226
271
schemaFor(elementType).dataType),
227
272
" array" ,
@@ -230,10 +275,12 @@ object ScalaReflection extends ScalaReflection {
230
275
231
276
case t if t <:< localTypeOf[Seq [_]] =>
232
277
val TypeRef (_, _, Seq (elementType)) = t
278
+ val className = getClassNameFromType(elementType)
279
+ val newTypePath = s """ - array element class: " $className" """ +: walkedTypePath
233
280
val arrayData =
234
281
Invoke (
235
282
MapObjects (
236
- p => constructorFor(elementType, Some (p)),
283
+ p => constructorFor(elementType, Some (p), newTypePath ),
237
284
getPath,
238
285
schemaFor(elementType).dataType),
239
286
" array" ,
@@ -246,12 +293,13 @@ object ScalaReflection extends ScalaReflection {
246
293
arrayData :: Nil )
247
294
248
295
case t if t <:< localTypeOf[Map [_, _]] =>
296
+ // TODO: add walked type path for map
249
297
val TypeRef (_, _, Seq (keyType, valueType)) = t
250
298
251
299
val keyData =
252
300
Invoke (
253
301
MapObjects (
254
- p => constructorFor(keyType, Some (p)),
302
+ p => constructorFor(keyType, Some (p), walkedTypePath ),
255
303
Invoke (getPath, " keyArray" , ArrayType (schemaFor(keyType).dataType)),
256
304
schemaFor(keyType).dataType),
257
305
" array" ,
@@ -260,7 +308,7 @@ object ScalaReflection extends ScalaReflection {
260
308
val valueData =
261
309
Invoke (
262
310
MapObjects (
263
- p => constructorFor(valueType, Some (p)),
311
+ p => constructorFor(valueType, Some (p), walkedTypePath ),
264
312
Invoke (getPath, " valueArray" , ArrayType (schemaFor(valueType).dataType)),
265
313
schemaFor(valueType).dataType),
266
314
" array" ,
@@ -297,12 +345,19 @@ object ScalaReflection extends ScalaReflection {
297
345
val fieldName = p.name.toString
298
346
val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
299
347
val dataType = schemaFor(fieldType).dataType
300
-
348
+ val clsName = getClassNameFromType(fieldType)
349
+ val newTypePath = s """ - field (class: " $clsName", name: " $fieldName") """ +: walkedTypePath
301
350
// For tuples, we grab the inner fields by ordinal instead of by name.
302
351
if (cls.getName startsWith " scala.Tuple" ) {
303
- constructorFor(fieldType, Some (addToPathOrdinal(i, dataType)))
352
+ constructorFor(
353
+ fieldType,
354
+ Some (addToPathOrdinal(i, dataType, newTypePath)),
355
+ newTypePath)
304
356
} else {
305
- constructorFor(fieldType, Some (addToPath(fieldName)))
357
+ constructorFor(
358
+ fieldType,
359
+ Some (addToPath(fieldName, dataType, newTypePath)),
360
+ newTypePath)
306
361
}
307
362
}
308
363
0 commit comments