Skip to content

Commit ad12b83

Browse files
Move ack functionality into BaseManager class
1 parent ebea5aa commit ad12b83

File tree

5 files changed

+89
-61
lines changed

5 files changed

+89
-61
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
],
3030
tests_require=[
3131
'mock',
32+
'pbr<1.7.0', # temporary, to workaround bug in 1.7.0
3233
],
3334
test_suite='tests',
3435
classifiers=[

socketio/base_manager.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import itertools
2+
13
import six
24

35

@@ -14,6 +16,7 @@ def __init__(self, server):
1416
self.server = server
1517
self.rooms = {}
1618
self.pending_removals = []
19+
self.callbacks = {}
1720

1821
def get_namespaces(self):
1922
"""Return an iterable with the active namespace names."""
@@ -43,6 +46,10 @@ def disconnect(self, sid, namespace):
4346
rooms.append(room_name)
4447
for room in rooms:
4548
self.leave_room(sid, namespace, room)
49+
if sid in self.callbacks and namespace in self.callbacks[sid]:
50+
del self.callbacks[sid][namespace]
51+
if len(self.callbacks[sid]) == 0:
52+
del self.callbacks[sid]
4653

4754
def enter_room(self, sid, namespace, room):
4855
"""Add a client to a room."""
@@ -86,8 +93,31 @@ def emit(self, event, data, namespace, room=None, skip_sid=None,
8693
return
8794
for sid in self.get_participants(namespace, room):
8895
if sid != skip_sid:
89-
self.server._emit_internal(sid, event, data, namespace,
90-
callback)
96+
if callback is not None:
97+
id = self.server._generate_ack_id(sid, namespace, callback)
98+
else:
99+
id = None
100+
self.server._emit_internal(sid, event, data, namespace, id)
101+
102+
def trigger_callback(self, sid, namespace, id, data):
103+
"""Invoke an application callback."""
104+
try:
105+
callback = self.callbacks[sid][namespace][id]
106+
except KeyError:
107+
raise ValueError('Unknown callback')
108+
del self.callbacks[sid][namespace][id]
109+
callback(*data)
110+
111+
def _generate_ack_id(self, sid, namespace, callback):
112+
"""Generate a unique identifier for an ACK packet."""
113+
namespace = namespace or '/'
114+
if sid not in self.callbacks:
115+
self.callbacks[sid] = {}
116+
if namespace not in self.callbacks[sid]:
117+
self.callbacks[sid][namespace] = {0: itertools.count(1)}
118+
id = six.next(self.callbacks[sid][namespace][0])
119+
self.callbacks[sid][namespace][id] = callback
120+
return id
91121

92122
def _clean_rooms(self):
93123
"""Remove all the inactive room participants."""

socketio/server.py

100755100644
Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import itertools
21
import logging
32

43
import engineio
@@ -83,7 +82,6 @@ def __init__(self, client_manager_class=None, logger=False, binary=False,
8382

8483
self.environ = {}
8584
self.handlers = {}
86-
self.callbacks = {}
8785

8886
self._binary_packet = None
8987
self._attachment_count = 0
@@ -304,12 +302,8 @@ def handle_request(self, environ, start_response):
304302
"""
305303
return self.eio.handle_request(environ, start_response)
306304

307-
def _emit_internal(self, sid, event, data, namespace=None, callback=None):
305+
def _emit_internal(self, sid, event, data, namespace=None, id=None):
308306
"""Send a message to a client."""
309-
if callback is not None:
310-
id = self._generate_ack_id(sid, namespace, callback)
311-
else:
312-
id = None
313307
if six.PY2 and not self.binary:
314308
binary = False # pragma: nocover
315309
else:
@@ -353,13 +347,9 @@ def _handle_disconnect(self, sid, namespace):
353347
if n != '/' and self.manager.is_connected(sid, n):
354348
self._trigger_event('disconnect', n, sid)
355349
self.manager.disconnect(sid, n)
356-
if sid in self.callbacks and n in self.callbacks[sid]:
357-
del self.callbacks[sid][n]
358350
if namespace == '/' and self.manager.is_connected(sid, namespace):
359351
self._trigger_event('disconnect', '/', sid)
360352
self.manager.disconnect(sid, '/')
361-
if sid in self.callbacks:
362-
del self.callbacks[sid]
363353
if sid in self.environ:
364354
del self.environ[sid]
365355

@@ -390,34 +380,13 @@ def _handle_ack(self, sid, namespace, id, data):
390380
"""Handle ACK packets from the client."""
391381
namespace = namespace or '/'
392382
self.logger.info('received ack from %s [%s]', sid, namespace)
393-
self._trigger_callback(sid, namespace, id, data)
383+
self.manager.trigger_callback(sid, namespace, id, data)
394384

395385
def _trigger_event(self, event, namespace, *args):
396386
"""Invoke an application event handler."""
397387
if namespace in self.handlers and event in self.handlers[namespace]:
398388
return self.handlers[namespace][event](*args)
399389

400-
def _generate_ack_id(self, sid, namespace, callback):
401-
"""Generate a unique identifier for an ACK packet."""
402-
namespace = namespace or '/'
403-
if sid not in self.callbacks:
404-
self.callbacks[sid] = {}
405-
if namespace not in self.callbacks[sid]:
406-
self.callbacks[sid][namespace] = {0: itertools.count(1)}
407-
id = six.next(self.callbacks[sid][namespace][0])
408-
self.callbacks[sid][namespace][id] = callback
409-
return id
410-
411-
def _trigger_callback(self, sid, namespace, id, data):
412-
"""Invoke an application callback."""
413-
namespace = namespace or '/'
414-
try:
415-
callback = self.callbacks[sid][namespace][id]
416-
except KeyError:
417-
raise ValueError('Unknown callback')
418-
del self.callbacks[sid][namespace][id]
419-
callback(*data)
420-
421390
def _handle_eio_connect(self, sid, environ):
422391
"""Handle the Engine.IO connection event."""
423392
self.environ[sid] = environ

tests/test_base_manager.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
class TestBaseManager(unittest.TestCase):
1313
def setUp(self):
1414
mock_server = mock.MagicMock()
15-
mock_server.rooms = {}
1615
self.bm = base_manager.BaseManager(mock_server)
1716

1817
def test_connect(self):
@@ -78,6 +77,40 @@ def test_disconnect_all(self):
7877
self.bm._clean_rooms()
7978
self.assertEqual(self.bm.rooms, {})
8079

80+
def test_disconnect_with_callbacks(self):
81+
self.bm.connect('123', '/')
82+
self.bm.connect('123', '/foo')
83+
self.bm._generate_ack_id('123', '/', 'f')
84+
self.bm._generate_ack_id('123', '/foo', 'g')
85+
self.bm.disconnect('123', '/foo')
86+
self.assertNotIn('/foo', self.bm.callbacks['123'])
87+
self.bm.disconnect('123', '/')
88+
self.assertNotIn('123', self.bm.callbacks)
89+
90+
def test_trigger_callback(self):
91+
self.bm.connect('123', '/')
92+
self.bm.connect('123', '/foo')
93+
cb = mock.MagicMock()
94+
id1 = self.bm._generate_ack_id('123', '/', cb)
95+
id2 = self.bm._generate_ack_id('123', '/foo', cb)
96+
self.bm.trigger_callback('123', '/', id1, ['foo'])
97+
self.bm.trigger_callback('123', '/foo', id2, ['bar', 'baz'])
98+
self.assertEqual(cb.call_count, 2)
99+
cb.assert_any_call('foo')
100+
cb.assert_any_call('bar', 'baz')
101+
102+
def test_invalid_callback(self):
103+
self.bm.connect('123', '/')
104+
cb = mock.MagicMock()
105+
id = self.bm._generate_ack_id('123', '/', cb)
106+
self.assertRaises(ValueError, self.bm.trigger_callback,
107+
'124', '/', id, ['foo'])
108+
self.assertRaises(ValueError, self.bm.trigger_callback,
109+
'123', '/foo', id, ['foo'])
110+
self.assertRaises(ValueError, self.bm.trigger_callback,
111+
'123', '/', id + 1, ['foo'])
112+
self.assertEqual(cb.call_count, 0)
113+
81114
def test_get_namespaces(self):
82115
self.assertEqual(list(self.bm.get_namespaces()), [])
83116
self.bm.connect('123', '/')
@@ -185,6 +218,16 @@ def test_emit_to_all_skip_one(self):
185218
{'foo': 'bar'}, '/foo',
186219
None)
187220

221+
def test_emit_with_callback(self):
222+
self.bm.connect('123', '/foo')
223+
self.bm.server._generate_ack_id.return_value = 11
224+
self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo',
225+
callback='cb')
226+
self.bm.server._emit_internal.assert_called_once_with('123',
227+
'my event',
228+
{'foo': 'bar'},
229+
'/foo', 11)
230+
188231
def test_emit_to_invalid_room(self):
189232
self.bm.emit('my event', {'foo': 'bar'}, namespace='/', room='123')
190233

tests/test_server.py

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ def test_emit_internal(self, eio):
127127

128128
def test_emit_internal_with_callback(self, eio):
129129
s = server.Server()
130-
s._emit_internal('123', 'my event', 'my data', namespace='/foo',
131-
callback='cb')
130+
id = s.manager._generate_ack_id('123', '/foo', 'cb')
131+
s._emit_internal('123', 'my event', 'my data', namespace='/foo', id=id)
132132
s.eio.send.assert_called_once_with('123',
133133
'2/foo,1["my event","my data"]',
134134
binary=False)
@@ -323,40 +323,25 @@ def test_handle_invalid_packet(self, eio):
323323

324324
def test_send_with_ack(self, eio):
325325
s = server.Server()
326-
cb = mock.MagicMock()
327326
s._handle_eio_connect('123', 'environ')
328-
s._emit_internal('123', 'my event', ['foo'], callback=cb)
329-
s._emit_internal('123', 'my event', ['bar'], callback=cb)
327+
cb = mock.MagicMock()
328+
id1 = s.manager._generate_ack_id('123', '/', cb)
329+
id2 = s.manager._generate_ack_id('123', '/', cb)
330+
s._emit_internal('123', 'my event', ['foo'], id=id1)
331+
s._emit_internal('123', 'my event', ['bar'], id=id2)
330332
s._handle_eio_message('123', '31["foo",2]')
331333
cb.assert_called_once_with('foo', 2)
332-
self.assertIn('123', s.callbacks)
333-
s._handle_disconnect('123', '/')
334-
self.assertNotIn('123', s.callbacks)
335334

336335
def test_send_with_ack_namespace(self, eio):
337336
s = server.Server()
338-
cb = mock.MagicMock()
339337
s._handle_eio_connect('123', 'environ')
340338
s._handle_eio_message('123', '0/foo')
339+
cb = mock.MagicMock()
340+
id = s.manager._generate_ack_id('123', '/foo', cb)
341341
s._emit_internal('123', 'my event', ['foo'], namespace='/foo',
342-
callback=cb)
342+
id=id)
343343
s._handle_eio_message('123', '3/foo,1["foo",2]')
344344
cb.assert_called_once_with('foo', 2)
345-
self.assertIn('/foo', s.callbacks['123'])
346-
s._handle_eio_disconnect('123')
347-
self.assertNotIn('123', s.callbacks)
348-
349-
def test_invalid_callback(self, eio):
350-
s = server.Server()
351-
cb = mock.MagicMock()
352-
s._handle_eio_connect('123', 'environ')
353-
s._emit_internal('123', 'my event', ['foo'], callback=cb)
354-
self.assertRaises(ValueError, s._handle_eio_message, '124',
355-
'31["foo",2]')
356-
self.assertRaises(ValueError, s._handle_eio_message, '123',
357-
'3/foo,1["foo",2]')
358-
self.assertRaises(ValueError, s._handle_eio_message, '123',
359-
'32["foo",2]')
360345

361346
def test_disconnect(self, eio):
362347
s = server.Server()

0 commit comments

Comments
 (0)