Skip to content

Commit 0665fb5

Browse files
committed
[SPARK-11636][SQL] Support classes defined in the REPL with Encoders
#theScaryParts (i.e. changes to the REPL, executor classloaders and codegen).

Author: Michael Armbrust <michael@databricks.com>
Author: Yin Huai <yhuai@databricks.com>

Closes apache#9825 from marmbrus/dataset-replClasses2.

(cherry picked from commit 4b84c72)
Signed-off-by: Michael Armbrust <michael@databricks.com>
1 parent 11a11f0 commit 0665fb5

File tree

4 files changed

+43
-7
lines changed

4 files changed

+43
-7
lines changed

repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,10 +1221,16 @@ import org.apache.spark.annotation.DeveloperApi
12211221
)
12221222
}
12231223

1224-
val preamble = """
1225-
|class %s extends Serializable {
1226-
| %s%s%s
1227-
""".stripMargin.format(lineRep.readName, envLines.map(" " + _ + ";\n").mkString, importsPreamble, indentCode(toCompute))
1224+
val preamble = s"""
1225+
|class ${lineRep.readName} extends Serializable {
1226+
| ${envLines.map(" " + _ + ";\n").mkString}
1227+
| $importsPreamble
1228+
|
1229+
| // If we need to construct any objects defined in the REPL on an executor we will need
1230+
| // to pass the outer scope to the appropriate encoder.
1231+
| org.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this)
1232+
| ${indentCode(toCompute)}
1233+
""".stripMargin
12281234
val postamble = importsTrailer + "\n}" + "\n" +
12291235
"object " + lineRep.readName + " {\n" +
12301236
" val INSTANCE = new " + lineRep.readName + "();\n" +

repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,9 @@ class ReplSuite extends SparkFunSuite {
262262
|import sqlContext.implicits._
263263
|case class TestCaseClass(value: Int)
264264
|sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect()
265+
|
266+
|// Test Dataset Serialization in the REPL
267+
|Seq(TestCaseClass(1)).toDS().collect()
265268
""".stripMargin)
266269
assertDoesNotContain("error:", output)
267270
assertDoesNotContain("Exception", output)
@@ -278,6 +281,27 @@ class ReplSuite extends SparkFunSuite {
278281
assertDoesNotContain("java.lang.ClassNotFoundException", output)
279282
}
280283

284+
test("Datasets and encoders") {
285+
val output = runInterpreter("local",
286+
"""
287+
|import org.apache.spark.sql.functions._
288+
|import org.apache.spark.sql.Encoder
289+
|import org.apache.spark.sql.expressions.Aggregator
290+
|import org.apache.spark.sql.TypedColumn
291+
|val simpleSum = new Aggregator[Int, Int, Int] with Serializable {
292+
| def zero: Int = 0 // The initial value.
293+
| def reduce(b: Int, a: Int) = b + a // Add an element to the running total
294+
| def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
295+
| def finish(b: Int) = b // Return the final result.
296+
|}.toColumn
297+
|
298+
|val ds = Seq(1, 2, 3, 4).toDS()
299+
|ds.select(simpleSum).collect
300+
""".stripMargin)
301+
assertDoesNotContain("error:", output)
302+
assertDoesNotContain("Exception", output)
303+
}
304+
281305
test("SPARK-2632 importing a method from non serializable class and not using it.") {
282306
val output = runInterpreter("local",
283307
"""

repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,13 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
6565
case e: ClassNotFoundException => {
6666
val classOption = findClassLocally(name)
6767
classOption match {
68-
case None => throw new ClassNotFoundException(name, e)
68+
case None =>
69+
// If this class has a cause, it will break the internal assumption of Janino
70+
// (the compiler used for Spark SQL code-gen).
71+
// See org.codehaus.janino.ClassLoaderIClassLoader's findIClass, you will see
72+
// its behavior will be changed if there is a cause and the compilation
73+
// of generated class will fail.
74+
throw new ClassNotFoundException(name)
6975
case Some(a) => a
7076
}
7177
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, ArrayData}
3131
import org.apache.spark.sql.types._
3232
import org.apache.spark.unsafe.Platform
3333
import org.apache.spark.unsafe.types._
34-
34+
import org.apache.spark.util.Utils
3535

3636
/**
3737
* Java source for evaluating an [[Expression]] given a [[InternalRow]] of input.
@@ -536,7 +536,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
536536
*/
537537
private[this] def doCompile(code: String): GeneratedClass = {
538538
val evaluator = new ClassBodyEvaluator()
539-
evaluator.setParentClassLoader(getClass.getClassLoader)
539+
evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader)
540540
// Cannot be under package codegen, or fail with java.lang.InstantiationException
541541
evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass")
542542
evaluator.setDefaultImports(Array(

0 commit comments

Comments (0)