Commit 7564c24

[SPARK-10731] [SQL] Delegate to Scala's DataFrame.take implementation in Python DataFrame.
Python DataFrame.head/take now requires scanning all the partitions. This pull request changes them to delegate the actual implementation to Scala DataFrame (by calling DataFrame.take). This is more of a hack for fixing this issue in 1.5.1. A more proper fix is to change executeCollect and executeTake to return InternalRow rather than Row, and thus eliminate the extra round-trip conversion.

Author: Reynold Xin <rxin@databricks.com>

Closes apache#8876 from rxin/SPARK-10731.

(cherry picked from commit 9952217)
Signed-off-by: Reynold Xin <rxin@databricks.com>
1 parent 64cc62c commit 7564c24
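
As a point of reference, here is roughly the Scala-side call that Python's take() now delegates to. This is a minimal sketch, not part of the commit: the object name TakeSketch, the local-mode setup, and the sample rows are illustrative only (Spark 1.5-era SQLContext API assumed). The relevant behaviour is that take(n) is evaluated on the JVM and fetches only as many rows as requested, instead of going through limit(num).collect() on the Python side.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.{Row, SQLContext}

    object TakeSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("take-sketch").setMaster("local[2]"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        // Sample rows mirroring the doctest in dataframe.py below.
        val df = sc.parallelize(Seq((2, "Alice"), (5, "Bob"))).toDF("age", "name")

        // take(n) runs entirely on the JVM; with this commit, Python's
        // DataFrame.take(n) calls into this path instead of evaluating
        // limit(num).collect() from the Python side.
        val rows: Array[Row] = df.take(2)
        rows.foreach(println)

        sc.stop()
      }
    }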

File tree

4 files changed (+25 −12 lines):

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
python/pyspark/sql/dataframe.py
sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala → sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala (renamed)
sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 1 addition & 1 deletion

@@ -633,7 +633,7 @@ private[spark] object PythonRDD extends Logging {
    *
    * The thread will terminate after all the data are sent or any exceptions happen.
    */
-  private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+  def serveIterator[T](items: Iterator[T], threadName: String): Int = {
     val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
     // Close the socket if no connection in 3 seconds
     serverSocket.setSoTimeout(3000)
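
serveIterator is made public here so that the SQL code added below can reuse it. As a rough standalone sketch of the pattern it implements (not Spark's implementation; serveStrings and writeUTF are stand-ins for the real method and its framed serialization): bind a ServerSocket on an ephemeral localhost port, write the items from a daemon thread, and return the port number so the Python process can connect and read.

    import java.io.{BufferedOutputStream, DataOutputStream}
    import java.net.{InetAddress, ServerSocket}

    object ServeSketch {
      // Standalone sketch of the serve-over-local-socket pattern, not Spark's code.
      def serveStrings(items: Iterator[String], threadName: String): Int = {
        // Bind an ephemeral port on localhost with a backlog of 1.
        val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
        // Give up if no client connects within 3 seconds (mirrors the diff above).
        serverSocket.setSoTimeout(3000)

        val thread = new Thread(threadName) {
          override def run(): Unit = {
            try {
              val socket = serverSocket.accept()
              val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream))
              try {
                items.foreach(out.writeUTF) // stand-in for the real serialization
              } finally {
                out.close()
                socket.close()
              }
            } finally {
              serverSocket.close()
            }
          }
        }
        thread.setDaemon(true)
        thread.start()

        // The caller (e.g. the Python side) connects to this port to read the data.
        serverSocket.getLocalPort
      }
    }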

python/pyspark/sql/dataframe.py

Lines changed: 4 additions & 1 deletion

@@ -300,7 +300,10 @@ def take(self, num):
         >>> df.take(2)
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
-        return self.limit(num).collect()
+        with SCCallSiteSync(self._sc) as css:
+            port = self._sc._jvm.org.apache.spark.sql.execution.EvaluatePython.takeAndServe(
+                self._jdf, num)
+        return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))

     @ignore_unicode_prefix
     @since(1.3)

sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala renamed to sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala

Lines changed: 13 additions & 1 deletion

@@ -28,7 +28,8 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

@@ -118,6 +119,17 @@ object EvaluatePython {
   def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython =
     new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())

+  def takeAndServe(df: DataFrame, n: Int): Int = {
+    registerPicklers()
+    // This is an annoying hack - we should refactor the code so executeCollect and executeTake
+    // returns InternalRow rather than Row.
+    val converter = CatalystTypeConverters.createToCatalystConverter(df.schema)
+    val iter = new SerDeUtil.AutoBatchedPickler(df.take(n).iterator.map { row =>
+      EvaluatePython.toJava(converter(row).asInstanceOf[InternalRow], df.schema)
+    })
+    PythonRDD.serveIterator(iter, s"serve-DataFrame")
+  }
+
   /**
    * Helper for converting from Catalyst type to java type suitable for Pyrolite.
    */
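
The SerDeUtil.AutoBatchedPickler used above pickles the converted rows in batches before they are streamed through PythonRDD.serveIterator. A rough standalone sketch of that step using Pyrolite's Pickler with a fixed batch size (pickleInBatches and the batch size of 100 are illustrative names and values; the real AutoBatchedPickler sizes its batches automatically rather than using a fixed count):

    import java.util.{ArrayList => JArrayList}
    import net.razorvine.pickle.Pickler

    object PickleSketch {
      // Standalone sketch, not Spark's SerDeUtil.AutoBatchedPickler: pickle
      // Java-friendly row values in fixed-size batches so the Python side's
      // BatchedSerializer(PickleSerializer()) can unpickle them.
      def pickleInBatches(rows: Iterator[AnyRef], batchSize: Int = 100): Iterator[Array[Byte]] = {
        val pickler = new Pickler()
        rows.grouped(batchSize).map { batch =>
          // Pyrolite pickles a java.util.ArrayList as a Python list.
          val list = new JArrayList[AnyRef](batch.size)
          batch.foreach(r => list.add(r))
          pickler.dumps(list)
        }
      }
    }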

sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala

Lines changed: 7 additions & 9 deletions

@@ -39,22 +39,20 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {

   override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"

-  override def serialize(obj: Any): Seq[Double] = {
+  override def serialize(obj: Any): GenericArrayData = {
     obj match {
       case p: ExamplePoint =>
-        Seq(p.x, p.y)
+        val output = new Array[Any](2)
+        output(0) = p.x
+        output(1) = p.y
+        new GenericArrayData(output)
     }
   }

   override def deserialize(datum: Any): ExamplePoint = {
     datum match {
-      case values: Seq[_] =>
-        val xy = values.asInstanceOf[Seq[Double]]
-        assert(xy.length == 2)
-        new ExamplePoint(xy(0), xy(1))
-      case values: util.ArrayList[_] =>
-        val xy = values.asInstanceOf[util.ArrayList[Double]].asScala
-        new ExamplePoint(xy(0), xy(1))
+      case values: ArrayData =>
+        new ExamplePoint(values.getDouble(0), values.getDouble(1))
     }
   }
