Skip to content

Commit fa53e38

Browse files
handle keyboard interrupt during reconnect (Fixes miguelgrinberg#301)
1 parent 8a4e5ff commit fa53e38

File tree

4 files changed

+107
-34
lines changed

4 files changed

+107
-34
lines changed

socketio/asyncio_client.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,8 @@ async def _trigger_event(self, event, namespace, *args):
355355
event, *args)
356356

357357
async def _handle_reconnect(self):
358+
self._reconnect_abort.clear()
359+
client.reconnecting_clients.append(self)
358360
attempt_count = 0
359361
current_delay = self.reconnection_delay
360362
while True:
@@ -366,7 +368,12 @@ async def _handle_reconnect(self):
366368
self.logger.info(
367369
'Connection failed, new attempt in {:.02f} seconds'.format(
368370
delay))
369-
await self.sleep(delay)
371+
try:
372+
await asyncio.wait_for(self._reconnect_abort.wait(), delay)
373+
self.logger.info('Reconnect task aborted')
374+
break
375+
except (asyncio.TimeoutError, asyncio.CancelledError):
376+
pass
370377
attempt_count += 1
371378
try:
372379
await self.connect(self.connection_url,
@@ -385,6 +392,7 @@ async def _handle_reconnect(self):
385392
self.logger.info(
386393
'Maximum reconnection attempts reached, giving up')
387394
break
395+
client.reconnecting_clients.remove(self)
388396

389397
def _handle_eio_connect(self):
390398
"""Handle the Engine.IO connection event."""
@@ -422,6 +430,7 @@ async def _handle_eio_message(self, data):
422430
async def _handle_eio_disconnect(self):
423431
"""Handle the Engine.IO disconnection event."""
424432
self.logger.info('Engine.IO connection dropped')
433+
self._reconnect_abort.set()
425434
for n in self.namespaces:
426435
await self._trigger_event('disconnect', namespace=n)
427436
await self._trigger_event('disconnect', namespace='/')

socketio/client.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import itertools
22
import logging
33
import random
4+
import signal
45

56
import engineio
67
import six
@@ -10,6 +11,21 @@
1011
from . import packet
1112

1213
default_logger = logging.getLogger('socketio.client')
14+
reconnecting_clients = []
15+
16+
17+
def signal_handler(sig, frame): # pragma: no cover
18+
"""SIGINT handler.
19+
20+
Notify any clients that are in a reconnect loop to abort. Other
21+
disconnection tasks are handled at the engine.io level.
22+
"""
23+
for client in reconnecting_clients[:]:
24+
client._reconnect_abort.set()
25+
return original_signal_handler(sig, frame)
26+
27+
28+
original_signal_handler = signal.signal(signal.SIGINT, signal_handler)
1329

1430

1531
class Client(object):
@@ -102,6 +118,7 @@ def __init__(self, reconnection=True, reconnection_attempts=0,
102118
self.callbacks = {}
103119
self._binary_packet = None
104120
self._reconnect_task = None
121+
self._reconnect_abort = self.eio.create_event()
105122

106123
def is_asyncio_based(self):
107124
return False
@@ -486,6 +503,8 @@ def _trigger_event(self, event, namespace, *args):
486503
event, *args)
487504

488505
def _handle_reconnect(self):
506+
self._reconnect_abort.clear()
507+
reconnecting_clients.append(self)
489508
attempt_count = 0
490509
current_delay = self.reconnection_delay
491510
while True:
@@ -497,7 +516,10 @@ def _handle_reconnect(self):
497516
self.logger.info(
498517
'Connection failed, new attempt in {:.02f} seconds'.format(
499518
delay))
500-
self.sleep(delay)
519+
print('***', self._reconnect_abort.wait)
520+
if self._reconnect_abort.wait(delay):
521+
self.logger.info('Reconnect task aborted')
522+
break
501523
attempt_count += 1
502524
try:
503525
self.connect(self.connection_url,
@@ -516,6 +538,7 @@ def _handle_reconnect(self):
516538
self.logger.info(
517539
'Maximum reconnection attempts reached, giving up')
518540
break
541+
reconnecting_clients.remove(self)
519542

520543
def _handle_eio_connect(self):
521544
"""Handle the Engine.IO connection event."""

tests/asyncio/test_asyncio_client.py

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
from contextlib import contextmanager
23
import sys
34
import unittest
45

@@ -26,6 +27,19 @@ async def mock_coro(*args, **kwargs):
2627
return mock_coro
2728

2829

30+
@contextmanager
31+
def mock_wait_for():
32+
async def fake_wait_for(coro, timeout):
33+
await coro
34+
await fake_wait_for._mock(timeout)
35+
36+
original_wait_for = asyncio.wait_for
37+
asyncio.wait_for = fake_wait_for
38+
fake_wait_for._mock = AsyncMock()
39+
yield
40+
asyncio.wait_for = original_wait_for
41+
42+
2943
def _run(coro):
3044
"""Run the given coroutine."""
3145
return asyncio.get_event_loop().run_until_complete(coro)
@@ -542,51 +556,64 @@ def on_foo(self, a, b):
542556
_run(c._trigger_event('foo', '/', 1, '2'))
543557
self.assertEqual(result, [1, '2'])
544558

559+
@mock.patch('asyncio.wait_for', new_callable=AsyncMock,
560+
side_effect=asyncio.TimeoutError)
545561
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
546-
def test_handle_reconnect(self, random):
562+
def test_handle_reconnect(self, random, wait_for):
547563
c = asyncio_client.AsyncClient()
548564
c._reconnect_task = 'foo'
549-
c.sleep = AsyncMock()
550565
c.connect = AsyncMock(
551566
side_effect=[ValueError, exceptions.ConnectionError, None])
552567
_run(c._handle_reconnect())
553-
self.assertEqual(c.sleep.mock.call_count, 3)
554-
self.assertEqual(c.sleep.mock.call_args_list, [
555-
mock.call(1.5),
556-
mock.call(1.5),
557-
mock.call(4.0)
558-
])
568+
self.assertEqual(wait_for.mock.call_count, 3)
569+
self.assertEqual(
570+
[x[0][1] for x in asyncio.wait_for.mock.call_args_list],
571+
[1.5, 1.5, 4.0])
559572
self.assertEqual(c._reconnect_task, None)
560573

574+
@mock.patch('asyncio.wait_for', new_callable=AsyncMock,
575+
side_effect=asyncio.TimeoutError)
561576
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
562-
def test_handle_reconnect_max_delay(self, random):
577+
def test_handle_reconnect_max_delay(self, random, wait_for):
563578
c = asyncio_client.AsyncClient(reconnection_delay_max=3)
564579
c._reconnect_task = 'foo'
565-
c.sleep = AsyncMock()
566580
c.connect = AsyncMock(
567581
side_effect=[ValueError, exceptions.ConnectionError, None])
568582
_run(c._handle_reconnect())
569-
self.assertEqual(c.sleep.mock.call_count, 3)
570-
self.assertEqual(c.sleep.mock.call_args_list, [
571-
mock.call(1.5),
572-
mock.call(1.5),
573-
mock.call(3.0)
574-
])
583+
self.assertEqual(wait_for.mock.call_count, 3)
584+
self.assertEqual(
585+
[x[0][1] for x in asyncio.wait_for.mock.call_args_list],
586+
[1.5, 1.5, 3.0])
575587
self.assertEqual(c._reconnect_task, None)
576588

589+
@mock.patch('asyncio.wait_for', new_callable=AsyncMock,
590+
side_effect=asyncio.TimeoutError)
577591
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
578-
def test_handle_reconnect_max_attempts(self, random):
592+
def test_handle_reconnect_max_attempts(self, random, wait_for):
579593
c = asyncio_client.AsyncClient(reconnection_attempts=2)
580594
c._reconnect_task = 'foo'
581-
c.sleep = AsyncMock()
582595
c.connect = AsyncMock(
583596
side_effect=[ValueError, exceptions.ConnectionError, None])
584597
_run(c._handle_reconnect())
585-
self.assertEqual(c.sleep.mock.call_count, 2)
586-
self.assertEqual(c.sleep.mock.call_args_list, [
587-
mock.call(1.5),
588-
mock.call(1.5)
589-
])
598+
self.assertEqual(wait_for.mock.call_count, 2)
599+
self.assertEqual(
600+
[x[0][1] for x in asyncio.wait_for.mock.call_args_list],
601+
[1.5, 1.5])
602+
self.assertEqual(c._reconnect_task, 'foo')
603+
604+
@mock.patch('asyncio.wait_for', new_callable=AsyncMock,
605+
side_effect=[asyncio.TimeoutError, None])
606+
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
607+
def test_handle_reconnect_aborted(self, random, wait_for):
608+
c = asyncio_client.AsyncClient()
609+
c._reconnect_task = 'foo'
610+
c.connect = AsyncMock(
611+
side_effect=[ValueError, exceptions.ConnectionError, None])
612+
_run(c._handle_reconnect())
613+
self.assertEqual(wait_for.mock.call_count, 2)
614+
self.assertEqual(
615+
[x[0][1] for x in asyncio.wait_for.mock.call_args_list],
616+
[1.5, 1.5])
590617
self.assertEqual(c._reconnect_task, 'foo')
591618

592619
def test_eio_connect(self):

tests/common/test_client.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -671,12 +671,12 @@ def on_foo(self, a, b):
671671
def test_handle_reconnect(self, random):
672672
c = client.Client()
673673
c._reconnect_task = 'foo'
674-
c.sleep = mock.MagicMock()
674+
c._reconnect_abort.wait = mock.MagicMock(return_value=False)
675675
c.connect = mock.MagicMock(
676676
side_effect=[ValueError, exceptions.ConnectionError, None])
677677
c._handle_reconnect()
678-
self.assertEqual(c.sleep.call_count, 3)
679-
self.assertEqual(c.sleep.call_args_list, [
678+
self.assertEqual(c._reconnect_abort.wait.call_count, 3)
679+
self.assertEqual(c._reconnect_abort.wait.call_args_list, [
680680
mock.call(1.5),
681681
mock.call(1.5),
682682
mock.call(4.0)
@@ -687,12 +687,12 @@ def test_handle_reconnect(self, random):
687687
def test_handle_reconnect_max_delay(self, random):
688688
c = client.Client(reconnection_delay_max=3)
689689
c._reconnect_task = 'foo'
690-
c.sleep = mock.MagicMock()
690+
c._reconnect_abort.wait = mock.MagicMock(return_value=False)
691691
c.connect = mock.MagicMock(
692692
side_effect=[ValueError, exceptions.ConnectionError, None])
693693
c._handle_reconnect()
694-
self.assertEqual(c.sleep.call_count, 3)
695-
self.assertEqual(c.sleep.call_args_list, [
694+
self.assertEqual(c._reconnect_abort.wait.call_count, 3)
695+
self.assertEqual(c._reconnect_abort.wait.call_args_list, [
696696
mock.call(1.5),
697697
mock.call(1.5),
698698
mock.call(3.0)
@@ -703,12 +703,26 @@ def test_handle_reconnect_max_delay(self, random):
703703
def test_handle_reconnect_max_attempts(self, random):
704704
c = client.Client(reconnection_attempts=2)
705705
c._reconnect_task = 'foo'
706-
c.sleep = mock.MagicMock()
706+
c._reconnect_abort.wait = mock.MagicMock(return_value=False)
707707
c.connect = mock.MagicMock(
708708
side_effect=[ValueError, exceptions.ConnectionError, None])
709709
c._handle_reconnect()
710-
self.assertEqual(c.sleep.call_count, 2)
711-
self.assertEqual(c.sleep.call_args_list, [
710+
self.assertEqual(c._reconnect_abort.wait.call_count, 2)
711+
self.assertEqual(c._reconnect_abort.wait.call_args_list, [
712+
mock.call(1.5),
713+
mock.call(1.5)
714+
])
715+
self.assertEqual(c._reconnect_task, 'foo')
716+
717+
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
718+
def test_handle_reconnect_aborted(self, random):
719+
c = client.Client()
720+
c._reconnect_task = 'foo'
721+
c._reconnect_abort.wait = mock.MagicMock(side_effect=[False, True])
722+
c.connect = mock.MagicMock(side_effect=exceptions.ConnectionError)
723+
c._handle_reconnect()
724+
self.assertEqual(c._reconnect_abort.wait.call_count, 2)
725+
self.assertEqual(c._reconnect_abort.wait.call_args_list, [
712726
mock.call(1.5),
713727
mock.call(1.5)
714728
])

0 commit comments

Comments
 (0)