Skip to content

Commit 11b6f1a

Browse files
Pass auth information sent by client to the connect handler
1 parent 3349b02 commit 11b6f1a

File tree

5 files changed

+89
-11
lines changed

5 files changed

+89
-11
lines changed

docs/server.rst

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ The ``connect`` and ``disconnect`` events are special; they are invoked
182182
automatically when a client connects or disconnects from the server::
183183

184184
@sio.event
185-
def connect(sid, environ):
185+
def connect(sid, environ, auth):
186186
print('connect ', sid)
187187

188188
@sio.event
@@ -193,8 +193,10 @@ The ``connect`` event is an ideal place to perform user authentication, and
193193
any necessary mapping between user entities in the application and the ``sid``
194194
that was assigned to the client. The ``environ`` argument is a dictionary in
195195
standard WSGI format containing the request information, including HTTP
196-
headers. After inspecting the request, the connect event handler can return
197-
``False`` to reject the connection with the client.
196+
headers. The ``auth`` argument contains any authentication details passed by
197+
the client, or ``None`` if the client did not pass anything. After inspecting
198+
the request, the connect event handler can return ``False`` to reject the
199+
connection with the client.
198200

199201
Sometimes it is useful to pass data back to the client being rejected. In that
200202
case instead of returning ``False``

socketio/asyncio_server.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ async def _send_packet(self, eio_sid, pkt):
433433
else:
434434
await self.eio.send(eio_sid, encoded_packet)
435435

436-
async def _handle_connect(self, eio_sid, namespace):
436+
async def _handle_connect(self, eio_sid, namespace, data):
437437
"""Handle a client connection request."""
438438
namespace = namespace or '/'
439439
sid = self.manager.connect(eio_sid, namespace)
@@ -442,8 +442,16 @@ async def _handle_connect(self, eio_sid, namespace):
442442
packet.CONNECT, {'sid': sid}, namespace=namespace))
443443
fail_reason = exceptions.ConnectionRefusedError().error_args
444444
try:
445-
success = await self._trigger_event('connect', namespace, sid,
446-
self.environ[eio_sid])
445+
if data:
446+
success = await self._trigger_event(
447+
'connect', namespace, sid, self.environ[eio_sid], data)
448+
else:
449+
try:
450+
success = await self._trigger_event(
451+
'connect', namespace, sid, self.environ[eio_sid])
452+
except TypeError:
453+
success = await self._trigger_event(
454+
'connect', namespace, sid, self.environ[eio_sid], None)
447455
except exceptions.ConnectionRefusedError as exc:
448456
fail_reason = exc.error_args
449457
success = False
@@ -552,7 +560,7 @@ async def _handle_eio_message(self, eio_sid, data):
552560
else:
553561
pkt = packet.Packet(encoded_packet=data)
554562
if pkt.packet_type == packet.CONNECT:
555-
await self._handle_connect(eio_sid, pkt.namespace)
563+
await self._handle_connect(eio_sid, pkt.namespace, pkt.data)
556564
elif pkt.packet_type == packet.DISCONNECT:
557565
await self._handle_disconnect(eio_sid, pkt.namespace)
558566
elif pkt.packet_type == packet.EVENT:

socketio/server.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ def _send_packet(self, eio_sid, pkt):
619619
else:
620620
self.eio.send(eio_sid, encoded_packet)
621621

622-
def _handle_connect(self, eio_sid, namespace):
622+
def _handle_connect(self, eio_sid, namespace, data):
623623
"""Handle a client connection request."""
624624
namespace = namespace or '/'
625625
sid = self.manager.connect(eio_sid, namespace)
@@ -628,8 +628,16 @@ def _handle_connect(self, eio_sid, namespace):
628628
packet.CONNECT, {'sid': sid}, namespace=namespace))
629629
fail_reason = exceptions.ConnectionRefusedError().error_args
630630
try:
631-
success = self._trigger_event('connect', namespace, sid,
632-
self.environ[eio_sid])
631+
if data:
632+
success = self._trigger_event(
633+
'connect', namespace, sid, self.environ[eio_sid], data)
634+
else:
635+
try:
636+
success = self._trigger_event(
637+
'connect', namespace, sid, self.environ[eio_sid])
638+
except TypeError:
639+
success = self._trigger_event(
640+
'connect', namespace, sid, self.environ[eio_sid], None)
633641
except exceptions.ConnectionRefusedError as exc:
634642
fail_reason = exc.error_args
635643
success = False
@@ -729,7 +737,7 @@ def _handle_eio_message(self, eio_sid, data):
729737
else:
730738
pkt = packet.Packet(encoded_packet=data)
731739
if pkt.packet_type == packet.CONNECT:
732-
self._handle_connect(eio_sid, pkt.namespace)
740+
self._handle_connect(eio_sid, pkt.namespace, pkt.data)
733741
elif pkt.packet_type == packet.DISCONNECT:
734742
self._handle_disconnect(eio_sid, pkt.namespace)
735743
elif pkt.packet_type == packet.EVENT:

tests/asyncio/test_asyncio_server.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,38 @@ def test_handle_connect(self, eio):
377377
_run(s._handle_eio_message('456', '0'))
378378
assert s.manager.initialize.call_count == 1
379379

380+
def test_handle_connect_with_auth(self, eio):
381+
eio.return_value.send = AsyncMock()
382+
s = asyncio_server.AsyncServer()
383+
s.manager.initialize = mock.MagicMock()
384+
handler = mock.MagicMock()
385+
s.on('connect', handler)
386+
_run(s._handle_eio_connect('123', 'environ'))
387+
_run(s._handle_eio_message('123', '0{"token":"abc"}'))
388+
assert s.manager.is_connected('1', '/')
389+
handler.assert_called_once_with('1', 'environ', {'token': 'abc'})
390+
s.eio.send.mock.assert_called_once_with('123', '0{"sid":"1"}')
391+
assert s.manager.initialize.call_count == 1
392+
_run(s._handle_eio_connect('456', 'environ'))
393+
_run(s._handle_eio_message('456', '0'))
394+
assert s.manager.initialize.call_count == 1
395+
396+
def test_handle_connect_with_auth_none(self, eio):
397+
eio.return_value.send = AsyncMock()
398+
s = asyncio_server.AsyncServer()
399+
s.manager.initialize = mock.MagicMock()
400+
handler = mock.MagicMock(side_effect=[TypeError, None, None])
401+
s.on('connect', handler)
402+
_run(s._handle_eio_connect('123', 'environ'))
403+
_run(s._handle_eio_message('123', '0'))
404+
assert s.manager.is_connected('1', '/')
405+
handler.assert_called_with('1', 'environ', None)
406+
s.eio.send.mock.assert_called_once_with('123', '0{"sid":"1"}')
407+
assert s.manager.initialize.call_count == 1
408+
_run(s._handle_eio_connect('456', 'environ'))
409+
_run(s._handle_eio_message('456', '0'))
410+
assert s.manager.initialize.call_count == 1
411+
380412
def test_handle_connect_async(self, eio):
381413
eio.return_value.send = AsyncMock()
382414
s = asyncio_server.AsyncServer()

tests/common/test_server.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,34 @@ def test_handle_connect(self, eio):
326326
s._handle_eio_connect('456', 'environ')
327327
assert s.manager.initialize.call_count == 1
328328

329+
def test_handle_connect_with_auth(self, eio):
330+
s = server.Server()
331+
s.manager.initialize = mock.MagicMock()
332+
handler = mock.MagicMock()
333+
s.on('connect', handler)
334+
s._handle_eio_connect('123', 'environ')
335+
s._handle_eio_message('123', '0{"token":"abc"}')
336+
assert s.manager.is_connected('1', '/')
337+
handler.assert_called_with('1', 'environ', {'token': 'abc'})
338+
s.eio.send.assert_called_once_with('123', '0{"sid":"1"}')
339+
assert s.manager.initialize.call_count == 1
340+
s._handle_eio_connect('456', 'environ')
341+
assert s.manager.initialize.call_count == 1
342+
343+
def test_handle_connect_with_auth_none(self, eio):
344+
s = server.Server()
345+
s.manager.initialize = mock.MagicMock()
346+
handler = mock.MagicMock(side_effect=[TypeError, None])
347+
s.on('connect', handler)
348+
s._handle_eio_connect('123', 'environ')
349+
s._handle_eio_message('123', '0')
350+
assert s.manager.is_connected('1', '/')
351+
handler.assert_called_with('1', 'environ', None)
352+
s.eio.send.assert_called_once_with('123', '0{"sid":"1"}')
353+
assert s.manager.initialize.call_count == 1
354+
s._handle_eio_connect('456', 'environ')
355+
assert s.manager.initialize.call_count == 1
356+
329357
def test_handle_connect_namespace(self, eio):
330358
s = server.Server()
331359
handler = mock.MagicMock()

0 commit comments

Comments
 (0)