Skip to content

Commit 811e044

Browse files
New shutdown() method added to the client (Fixes miguelgrinberg#1333)
1 parent 82ceaf7 commit 811e044

File tree

4 files changed

+153
-2
lines changed

4 files changed

+153
-2
lines changed

src/socketio/async_client.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,21 @@ async def disconnect(self):
318318
namespace=n))
319319
await self.eio.disconnect(abort=True)
320320

321+
async def shutdown(self):
322+
"""Stop the client.
323+
324+
If the client is connected to a server, it is disconnected. If the
325+
client is attempting to reconnect to server, the reconnection attempts
326+
are stopped. If the client is not connected to a server and is not
327+
attempting to reconnect, then this function does nothing.
328+
"""
329+
if self.connected:
330+
await self.disconnect()
331+
elif self._reconnect_task: # pragma: no branch
332+
self._reconnect_abort.set()
333+
print(self._reconnect_task)
334+
await self._reconnect_task
335+
321336
def start_background_task(self, target, *args, **kwargs):
322337
"""Start a background task using the appropriate async model.
323338
@@ -467,15 +482,20 @@ async def _handle_reconnect(self):
467482
self.logger.info(
468483
'Connection failed, new attempt in {:.02f} seconds'.format(
469484
delay))
485+
abort = False
470486
try:
471487
await asyncio.wait_for(self._reconnect_abort.wait(), delay)
488+
abort = True
489+
except asyncio.TimeoutError:
490+
pass
491+
except asyncio.CancelledError: # pragma: no cover
492+
abort = True
493+
if abort:
472494
self.logger.info('Reconnect task aborted')
473495
for n in self.connection_namespaces:
474496
await self._trigger_event('__disconnect_final',
475497
namespace=n)
476498
break
477-
except (asyncio.TimeoutError, asyncio.CancelledError):
478-
pass
479499
attempt_count += 1
480500
try:
481501
await self.connect(self.connection_url,

src/socketio/client.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,20 @@ def disconnect(self):
298298
packet.DISCONNECT, namespace=n))
299299
self.eio.disconnect(abort=True)
300300

301+
def shutdown(self):
302+
"""Stop the client.
303+
304+
If the client is connected to a server, it is disconnected. If the
305+
client is attempting to reconnect to server, the reconnection attempts
306+
are stopped. If the client is not connected to a server and is not
307+
attempting to reconnect, then this function does nothing.
308+
"""
309+
if self.connected:
310+
self.disconnect()
311+
elif self._reconnect_task: # pragma: no branch
312+
self._reconnect_abort.set()
313+
self._reconnect_task.join()
314+
301315
def start_background_task(self, target, *args, **kwargs):
302316
"""Start a background task using the appropriate async model.
303317

tests/async/test_client.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,66 @@ def test_handle_reconnect_aborted(self, random, wait_for):
990990
c._trigger_event.mock.assert_called_once_with('__disconnect_final',
991991
namespace='/')
992992

993+
def test_shutdown_disconnect(self):
994+
c = async_client.AsyncClient()
995+
c.connected = True
996+
c.namespaces = {'/': '1'}
997+
c._trigger_event = AsyncMock()
998+
c._send_packet = AsyncMock()
999+
c.eio = mock.MagicMock()
1000+
c.eio.disconnect = AsyncMock()
1001+
c.eio.state = 'connected'
1002+
_run(c.shutdown())
1003+
assert c._trigger_event.mock.call_count == 0
1004+
assert c._send_packet.mock.call_count == 1
1005+
expected_packet = packet.Packet(packet.DISCONNECT, namespace='/')
1006+
assert (
1007+
c._send_packet.mock.call_args_list[0][0][0].encode()
1008+
== expected_packet.encode()
1009+
)
1010+
c.eio.disconnect.mock.assert_called_once_with(abort=True)
1011+
1012+
def test_shutdown_disconnect_namespaces(self):
1013+
c = async_client.AsyncClient()
1014+
c.connected = True
1015+
c.namespaces = {'/foo': '1', '/bar': '2'}
1016+
c._trigger_event = AsyncMock()
1017+
c._send_packet = AsyncMock()
1018+
c.eio = mock.MagicMock()
1019+
c.eio.disconnect = AsyncMock()
1020+
c.eio.state = 'connected'
1021+
_run(c.shutdown())
1022+
assert c._trigger_event.mock.call_count == 0
1023+
assert c._send_packet.mock.call_count == 2
1024+
expected_packet = packet.Packet(packet.DISCONNECT, namespace='/foo')
1025+
assert (
1026+
c._send_packet.mock.call_args_list[0][0][0].encode()
1027+
== expected_packet.encode()
1028+
)
1029+
expected_packet = packet.Packet(packet.DISCONNECT, namespace='/bar')
1030+
assert (
1031+
c._send_packet.mock.call_args_list[1][0][0].encode()
1032+
== expected_packet.encode()
1033+
)
1034+
1035+
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
1036+
def test_shutdown_reconnect(self, random):
1037+
c = async_client.AsyncClient()
1038+
c.connection_namespaces = ['/']
1039+
c._reconnect_task = AsyncMock()()
1040+
c._trigger_event = AsyncMock()
1041+
c.connect = AsyncMock(side_effect=exceptions.ConnectionError)
1042+
1043+
async def r():
1044+
task = c.start_background_task(c._handle_reconnect)
1045+
await asyncio.sleep(0.1)
1046+
await c.shutdown()
1047+
await task
1048+
1049+
_run(r())
1050+
c._trigger_event.mock.assert_called_once_with('__disconnect_final',
1051+
namespace='/')
1052+
9931053
def test_handle_eio_connect(self):
9941054
c = async_client.AsyncClient()
9951055
c.connection_namespaces = ['/', '/foo']

tests/common/test_client.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import time
23
import unittest
34
from unittest import mock
45

@@ -636,6 +637,7 @@ def test_disconnect(self):
636637

637638
def test_disconnect_namespaces(self):
638639
c = client.Client()
640+
c.connected = True
639641
c.namespaces = {'/foo': '1', '/bar': '2'}
640642
c._trigger_event = mock.MagicMock()
641643
c._send_packet = mock.MagicMock()
@@ -1128,6 +1130,61 @@ def test_handle_reconnect_aborted(self, random):
11281130
c._trigger_event.assert_called_once_with('__disconnect_final',
11291131
namespace='/')
11301132

1133+
def test_shutdown_disconnect(self):
1134+
c = client.Client()
1135+
c.connected = True
1136+
c.namespaces = {'/': '1'}
1137+
c._trigger_event = mock.MagicMock()
1138+
c._send_packet = mock.MagicMock()
1139+
c.eio = mock.MagicMock()
1140+
c.eio.state = 'connected'
1141+
c.shutdown()
1142+
assert c._trigger_event.call_count == 0
1143+
assert c._send_packet.call_count == 1
1144+
expected_packet = packet.Packet(packet.DISCONNECT, namespace='/')
1145+
assert (
1146+
c._send_packet.call_args_list[0][0][0].encode()
1147+
== expected_packet.encode()
1148+
)
1149+
c.eio.disconnect.assert_called_once_with(abort=True)
1150+
1151+
def test_shutdown_disconnect_namespaces(self):
1152+
c = client.Client()
1153+
c.connected = True
1154+
c.namespaces = {'/foo': '1', '/bar': '2'}
1155+
c._trigger_event = mock.MagicMock()
1156+
c._send_packet = mock.MagicMock()
1157+
c.eio = mock.MagicMock()
1158+
c.eio.state = 'connected'
1159+
c.shutdown()
1160+
assert c._trigger_event.call_count == 0
1161+
assert c._send_packet.call_count == 2
1162+
expected_packet = packet.Packet(packet.DISCONNECT, namespace='/foo')
1163+
assert (
1164+
c._send_packet.call_args_list[0][0][0].encode()
1165+
== expected_packet.encode()
1166+
)
1167+
expected_packet = packet.Packet(packet.DISCONNECT, namespace='/bar')
1168+
assert (
1169+
c._send_packet.call_args_list[1][0][0].encode()
1170+
== expected_packet.encode()
1171+
)
1172+
1173+
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
1174+
def test_shutdown_reconnect(self, random):
1175+
c = client.Client()
1176+
c.connection_namespaces = ['/']
1177+
c._reconnect_task = mock.MagicMock()
1178+
c._trigger_event = mock.MagicMock()
1179+
c.connect = mock.MagicMock(side_effect=exceptions.ConnectionError)
1180+
task = c.start_background_task(c._handle_reconnect)
1181+
time.sleep(0.1)
1182+
c.shutdown()
1183+
task.join()
1184+
c._trigger_event.assert_called_once_with('__disconnect_final',
1185+
namespace='/')
1186+
assert c._reconnect_task.join.called_once_with()
1187+
11311188
def test_handle_eio_connect(self):
11321189
c = client.Client()
11331190
c.connection_namespaces = ['/', '/foo']

0 commit comments

Comments
 (0)