
Commit 9677b44

Davies Liu authored and JoshRosen committed
[SPARK-6886] [PySpark] fix big closure with shuffle
Currently, the broadcast created for a big closure has the same life cycle as the RDD object in Python. For multistage jobs, a PythonRDD is created in the JVM while the corresponding RDD in Python may be garbage-collected; the broadcast is then destroyed in the JVM before the PythonRDD that still needs it. This PR changes the code to use the PythonRDD to track the lifecycle of the broadcast object. It also refactors getNumPartitions() to avoid unnecessary creation of a PythonRDD, which can be heavy.

cc JoshRosen

Author: Davies Liu <davies@databricks.com>

Closes apache#5496 from davies/big_closure and squashes the following commits:

9a0ea4c [Davies Liu] fix big closure with shuffle

Conflicts:
	python/pyspark/rdd.py
1 parent: 8e9fc27 · commit: 9677b44
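The bug is straightforward to reproduce. A minimal sketch of the failure mode, assuming local mode and a captured object big enough to push the pickled command over the 1 MB broadcast threshold (names and sizes below are illustrative, not from the patch):

from pyspark import SparkContext

sc = SparkContext("local", "spark-6886-sketch")

# `data` is captured by the closure; once the pickled command exceeds
# 1 MB, PySpark ships it to executors through a broadcast variable.
data = [float(i) for i in range(300000)]
rdd = sc.parallelize(range(1), 1).map(lambda x: len(data))

# groupByKey() forces a shuffle, so this is a multistage job. Before the
# fix, the Python-side RDD object could be garbage-collected between
# stages, destroying the broadcast in the JVM while the PythonRDD that
# backs the job still needed it.
print(rdd.map(lambda x: (x, 1)).groupByKey().count())

sc.stop()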

File tree: 2 files changed, +8 −12 lines

python/pyspark/rdd.py
python/pyspark/tests.py


python/pyspark/rdd.py

Lines changed: 6 additions & 8 deletions
@@ -1100,7 +1100,7 @@ def take(self, num):
         [91, 92, 93]
         """
         items = []
-        totalParts = self._jrdd.partitions().size()
+        totalParts = self.getNumPartitions()
         partsScanned = 0

         while len(items) < num and partsScanned < totalParts:
@@ -2105,12 +2105,9 @@ def pipeline_func(split, iterator):
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
         self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
-        self._broadcast = None

-    def __del__(self):
-        if self._broadcast:
-            self._broadcast.unpersist()
-            self._broadcast = None
+    def getNumPartitions(self):
+        return self._prev_jrdd.partitions().size()

     @property
     def _jrdd(self):
@@ -2126,8 +2123,9 @@ def _jrdd(self):
         ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
         if len(pickled_command) > (1 << 20):  # 1M
-            self._broadcast = self.ctx.broadcast(pickled_command)
-            pickled_command = ser.dumps(self._broadcast)
+            # The broadcast will have same life cycle as created PythonRDD
+            broadcast = self.ctx.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)
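
As a side effect of the refactor, getNumPartitions() on a pipelined RDD now reads the partition count from the parent JVM RDD, so callers such as take() no longer force creation of the PythonRDD just to count partitions. A quick sketch, assuming a running SparkContext `sc`:

# map() yields a PipelinedRDD; asking for its partition count no longer
# materializes the (heavy) PythonRDD wrapper in the JVM.
rdd = sc.parallelize(range(100), 4).map(lambda x: x * 2)
print(rdd.getNumPartitions())  # 4
print(rdd.take(3))             # take() itself now relies on getNumPartitions()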

python/pyspark/tests.py

Lines changed: 2 additions & 4 deletions
@@ -521,10 +521,8 @@ def test_large_closure(self):
         data = [float(i) for i in xrange(N)]
         rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
         self.assertEquals(N, rdd.first())
-        self.assertTrue(rdd._broadcast is not None)
-        rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1)
-        self.assertEqual(1, rdd.first())
-        self.assertTrue(rdd._broadcast is None)
+        # regression test for SPARK-6886
+        self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count())

     def test_zip_with_different_serializers(self):
         a = self.sc.parallelize(range(5))
