Commit 012de2c

jerryshao authored and zsxwing committed
[SPARK-12002][STREAMING][PYSPARK] Fix python direct stream checkpoint recovery issue
Fixed a minor race condition in apache#10017

Closes apache#10017

Author: jerryshao <sshao@hortonworks.com>
Author: Shixiong Zhu <shixiong@databricks.com>

Closes apache#10074 from zsxwing/review-pr10017.

(cherry picked from commit f292018)
Signed-off-by: Shixiong Zhu <shixiong@databricks.com>
1 parent 5647774 commit 012de2c

File tree

2 files changed: +56 −6 lines changed
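Background note (not part of the commit message): before this change, the checkpoint serializer wrote out only a TransformFunction's func and deserializers, so the custom RDD wrapper that direct Kafka streams install via rdd_wrapper() was lost when the function was checkpointed; after recovery the wrapper silently fell back to the plain-RDD default, and rdd.offsetRanges() inside transform() no longer worked. The following standalone sketch illustrates that failure mode and the shape of the fix; it uses toy names (Transform, default_wrap, kafka_wrap) and plain pickle, not Spark code.

import pickle


def default_wrap(x):
    """Stand-in for wrapping a JavaRDD as a plain RDD."""
    return ("plain", x)


def kafka_wrap(x):
    """Stand-in for wrapping a JavaRDD as a KafkaRDD with offsetRanges()."""
    return ("kafka", x)


class Transform(object):
    def __init__(self, func):
        self.func = func
        self.wrap = default_wrap  # replaced via rdd_wrapper() for Kafka streams

    def rdd_wrapper(self, wrap):
        self.wrap = wrap
        return self


t = Transform(len).rdd_wrapper(kafka_wrap)

# Before the fix: only the function was serialized, so the custom wrapper
# reverted to the default after checkpoint recovery.
recovered = Transform(pickle.loads(pickle.dumps(t.func)))
assert recovered.wrap is default_wrap  # offsetRanges() would fail here

# After the fix: the wrapper travels with the function and survives recovery.
f, w = pickle.loads(pickle.dumps((t.func, t.wrap)))
recovered = Transform(f).rdd_wrapper(w)
assert recovered.wrap is kafka_wrap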

python/pyspark/streaming/tests.py

Lines changed: 49 additions & 0 deletions
@@ -1149,6 +1149,55 @@ def test_topic_and_partition_equality(self):
         self.assertNotEqual(topic_and_partition_a, topic_and_partition_c)
         self.assertNotEqual(topic_and_partition_a, topic_and_partition_d)

+    @unittest.skipIf(sys.version >= "3", "long type not support")
+    def test_kafka_direct_stream_transform_with_checkpoint(self):
+        """Test the Python direct Kafka stream transform with checkpoint correctly recovered."""
+        topic = self._randomTopic()
+        sendData = {"a": 1, "b": 2, "c": 3}
+        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
+                       "auto.offset.reset": "smallest"}
+
+        self._kafkaTestUtils.createTopic(topic)
+        self._kafkaTestUtils.sendMessages(topic, sendData)
+
+        offsetRanges = []
+
+        def transformWithOffsetRanges(rdd):
+            for o in rdd.offsetRanges():
+                offsetRanges.append(o)
+            return rdd
+
+        self.ssc.stop(False)
+        self.ssc = None
+        tmpdir = "checkpoint-test-%d" % random.randint(0, 10000)
+
+        def setup():
+            ssc = StreamingContext(self.sc, 0.5)
+            ssc.checkpoint(tmpdir)
+            stream = KafkaUtils.createDirectStream(ssc, [topic], kafkaParams)
+            stream.transform(transformWithOffsetRanges).count().pprint()
+            return ssc
+
+        try:
+            ssc1 = StreamingContext.getOrCreate(tmpdir, setup)
+            ssc1.start()
+            self.wait_for(offsetRanges, 1)
+            self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
+
+            # To make sure some checkpoint is written
+            time.sleep(3)
+            ssc1.stop(False)
+            ssc1 = None
+
+            # Restart again to make sure the checkpoint is recovered correctly
+            ssc2 = StreamingContext.getOrCreate(tmpdir, setup)
+            ssc2.start()
+            ssc2.awaitTermination(3)
+            ssc2.stop(stopSparkContext=False, stopGraceFully=True)
+            ssc2 = None
+        finally:
+            shutil.rmtree(tmpdir)
+
     @unittest.skipIf(sys.version >= "3", "long type not support")
     def test_kafka_rdd_message_handler(self):
         """Test Python direct Kafka RDD MessageHandler."""

python/pyspark/streaming/util.py

Lines changed: 7 additions & 6 deletions
@@ -37,11 +37,11 @@ def __init__(self, ctx, func, *deserializers):
         self.ctx = ctx
         self.func = func
         self.deserializers = deserializers
-        self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
+        self.rdd_wrap_func = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
         self.failure = None

     def rdd_wrapper(self, func):
-        self._rdd_wrapper = func
+        self.rdd_wrap_func = func
         return self

     def call(self, milliseconds, jrdds):
@@ -59,7 +59,7 @@ def call(self, milliseconds, jrdds):
         if len(sers) < len(jrdds):
             sers += (sers[0],) * (len(jrdds) - len(sers))

-        rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None
+        rdds = [self.rdd_wrap_func(jrdd, self.ctx, ser) if jrdd else None
                 for jrdd, ser in zip(jrdds, sers)]
         t = datetime.fromtimestamp(milliseconds / 1000.0)
         r = self.func(t, *rdds)
@@ -101,16 +101,17 @@ def dumps(self, id):
         self.failure = None
         try:
             func = self.gateway.gateway_property.pool[id]
-            return bytearray(self.serializer.dumps((func.func, func.deserializers)))
+            return bytearray(self.serializer.dumps((
+                func.func, func.rdd_wrap_func, func.deserializers)))
         except:
             self.failure = traceback.format_exc()

     def loads(self, data):
         # Clear the failure
         self.failure = None
         try:
-            f, deserializers = self.serializer.loads(bytes(data))
-            return TransformFunction(self.ctx, f, *deserializers)
+            f, wrap_func, deserializers = self.serializer.loads(bytes(data))
+            return TransformFunction(self.ctx, f, *deserializers).rdd_wrapper(wrap_func)
         except:
             self.failure = traceback.format_exc()
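For reference, a minimal sketch of the new dumps()/loads() wire format after this change, assuming a local pyspark installation: the checkpoint now carries a 3-tuple (func, rdd_wrap_func, deserializers) instead of a 2-tuple, and loads() reattaches the wrapper via rdd_wrapper(). func and wrap below are toy stand-ins; CloudPickleSerializer is used because, unlike plain pickle, it can serialize lambdas.

from pyspark.serializers import CloudPickleSerializer

ser = CloudPickleSerializer()
func = lambda t, rdd: rdd                      # toy transform function
wrap = lambda jrdd, ctx, s: ("kafka", jrdd)    # toy stand-in for the KafkaRDD wrapper

# What dumps() now writes into the checkpoint: a 3-tuple, not a 2-tuple.
data = bytearray(ser.dumps((func, wrap, ())))

# What loads() now unpacks before calling .rdd_wrapper(wrap_func),
# so the recovered TransformFunction keeps producing Kafka-aware RDDs.
f, wrap_func, deserializers = ser.loads(bytes(data))
assert wrap_func("jrdd", None, None) == ("kafka", "jrdd")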