Skip to content

Commit 5b393ac

Browse files
committed
Merge pull request dpkp#671 from dpkp/disconnects
Improve socket disconnect handling
2 parents 161fa6d + fa59d4d commit 5b393ac

File tree

3 files changed

+87
-24
lines changed

3 files changed

+87
-24
lines changed

kafka/client_async.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def _bootstrap(self, hosts):
142142
# Exponential backoff if bootstrap fails
143143
backoff_ms = self.config['reconnect_backoff_ms'] * 2 ** self._bootstrap_fails
144144
next_at = self._last_bootstrap + backoff_ms / 1000.0
145+
self._refresh_on_disconnects = False
145146
now = time.time()
146147
if next_at > now:
147148
log.debug("Sleeping %0.4f before bootstrapping again", next_at - now)
@@ -180,6 +181,7 @@ def _bootstrap(self, hosts):
180181
log.error('Unable to bootstrap from %s', hosts)
181182
# Max exponential backoff is 2^12, x4000 (50ms -> 200s)
182183
self._bootstrap_fails = min(self._bootstrap_fails + 1, 12)
184+
self._refresh_on_disconnects = True
183185

184186
def _can_connect(self, node_id):
185187
if node_id not in self._conns:
@@ -223,7 +225,7 @@ def _conn_state_change(self, node_id, conn):
223225
except KeyError:
224226
pass
225227
if self._refresh_on_disconnects:
226-
log.warning("Node %s connect failed -- refreshing metadata", node_id)
228+
log.warning("Node %s connection failed -- refreshing metadata", node_id)
227229
self.cluster.request_update()
228230

229231
def _maybe_connect(self, node_id):

kafka/conn.py

+22-4
Original file line numberDiff line numberDiff line change
@@ -381,9 +381,17 @@ def recv(self):
381381
# Not receiving is the state of reading the payload header
382382
if not self._receiving:
383383
try:
384-
# An extremely small, but non-zero, probability that there are
385-
# more than 0 but not yet 4 bytes available to read
386-
self._rbuffer.write(self._sock.recv(4 - self._rbuffer.tell()))
384+
bytes_to_read = 4 - self._rbuffer.tell()
385+
data = self._sock.recv(bytes_to_read)
386+
# socket.recv may return fewer than bytes_to_read bytes, and on a
387+
# non-blocking socket it raises if no data is available yet;
388+
# but if the socket is disconnected, we will get empty data
389+
# without an exception raised
390+
if not data:
391+
log.error('%s: socket disconnected', self)
392+
self.close(error=Errors.ConnectionError('socket disconnected'))
393+
return None
394+
self._rbuffer.write(data)
387395
except ssl.SSLWantReadError:
388396
return None
389397
except ConnectionError as e:
@@ -411,7 +419,17 @@ def recv(self):
411419
if self._receiving:
412420
staged_bytes = self._rbuffer.tell()
413421
try:
414-
self._rbuffer.write(self._sock.recv(self._next_payload_bytes - staged_bytes))
422+
bytes_to_read = self._next_payload_bytes - staged_bytes
423+
data = self._sock.recv(bytes_to_read)
424+
# socket.recv may return fewer than bytes_to_read bytes, and on a
425+
# non-blocking socket it raises if no data is available yet;
426+
# but if the socket is disconnected, we will get empty data
427+
# without an exception raised
428+
if not data:
429+
log.error('%s: socket disconnected', self)
430+
self.close(error=Errors.ConnectionError('socket disconnected'))
431+
return None
432+
self._rbuffer.write(data)
415433
except ssl.SSLWantReadError:
416434
return None
417435
except ConnectionError as e:

test/test_conn.py

+62-19
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import absolute_import
33

44
from errno import EALREADY, EINPROGRESS, EISCONN, ECONNRESET
5+
import socket
56
import time
67

78
import pytest
@@ -14,17 +15,16 @@
1415

1516

1617
@pytest.fixture
17-
def socket(mocker):
18+
def _socket(mocker):
1819
socket = mocker.MagicMock()
1920
socket.connect_ex.return_value = 0
2021
mocker.patch('socket.socket', return_value=socket)
2122
return socket
2223

2324

2425
@pytest.fixture
25-
def conn(socket):
26-
from socket import AF_INET
27-
conn = BrokerConnection('localhost', 9092, AF_INET)
26+
def conn(_socket):
27+
conn = BrokerConnection('localhost', 9092, socket.AF_INET)
2828
return conn
2929

3030

@@ -38,23 +38,23 @@ def conn(socket):
3838
([EALREADY], ConnectionStates.CONNECTING),
3939
([EISCONN], ConnectionStates.CONNECTED)),
4040
])
41-
def test_connect(socket, conn, states):
41+
def test_connect(_socket, conn, states):
4242
assert conn.state is ConnectionStates.DISCONNECTED
4343

4444
for errno, state in states:
45-
socket.connect_ex.side_effect = errno
45+
_socket.connect_ex.side_effect = errno
4646
conn.connect()
4747
assert conn.state is state
4848

4949

50-
def test_connect_timeout(socket, conn):
50+
def test_connect_timeout(_socket, conn):
5151
assert conn.state is ConnectionStates.DISCONNECTED
5252

5353
# Initial connect returns EINPROGRESS
5454
# immediate inline connect returns EALREADY
5555
# second explicit connect returns EALREADY
5656
# third explicit connect returns EALREADY and times out via last_attempt
57-
socket.connect_ex.side_effect = [EINPROGRESS, EALREADY, EALREADY, EALREADY]
57+
_socket.connect_ex.side_effect = [EINPROGRESS, EALREADY, EALREADY, EALREADY]
5858
conn.connect()
5959
assert conn.state is ConnectionStates.CONNECTING
6060
conn.connect()
@@ -108,15 +108,15 @@ def test_send_max_ifr(conn):
108108
assert isinstance(f.exception, Errors.TooManyInFlightRequests)
109109

110110

111-
def test_send_no_response(socket, conn):
111+
def test_send_no_response(_socket, conn):
112112
conn.connect()
113113
assert conn.state is ConnectionStates.CONNECTED
114114
req = MetadataRequest[0]([])
115115
header = RequestHeader(req, client_id=conn.config['client_id'])
116116
payload_bytes = len(header.encode()) + len(req.encode())
117117
third = payload_bytes // 3
118118
remainder = payload_bytes % 3
119-
socket.send.side_effect = [4, third, third, third, remainder]
119+
_socket.send.side_effect = [4, third, third, third, remainder]
120120

121121
assert len(conn.in_flight_requests) == 0
122122
f = conn.send(req, expect_response=False)
@@ -125,36 +125,34 @@ def test_send_no_response(socket, conn):
125125
assert len(conn.in_flight_requests) == 0
126126

127127

128-
def test_send_response(socket, conn):
128+
def test_send_response(_socket, conn):
129129
conn.connect()
130130
assert conn.state is ConnectionStates.CONNECTED
131131
req = MetadataRequest[0]([])
132132
header = RequestHeader(req, client_id=conn.config['client_id'])
133133
payload_bytes = len(header.encode()) + len(req.encode())
134134
third = payload_bytes // 3
135135
remainder = payload_bytes % 3
136-
socket.send.side_effect = [4, third, third, third, remainder]
136+
_socket.send.side_effect = [4, third, third, third, remainder]
137137

138138
assert len(conn.in_flight_requests) == 0
139139
f = conn.send(req)
140140
assert f.is_done is False
141141
assert len(conn.in_flight_requests) == 1
142142

143143

144-
def test_send_error(socket, conn):
144+
def test_send_error(_socket, conn):
145145
conn.connect()
146146
assert conn.state is ConnectionStates.CONNECTED
147147
req = MetadataRequest[0]([])
148-
header = RequestHeader(req, client_id=conn.config['client_id'])
149148
try:
150-
error = ConnectionError
149+
_socket.send.side_effect = ConnectionError
151150
except NameError:
152-
from socket import error
153-
socket.send.side_effect = error
151+
_socket.send.side_effect = socket.error
154152
f = conn.send(req)
155153
assert f.failed() is True
156154
assert isinstance(f.exception, Errors.ConnectionError)
157-
assert socket.close.call_count == 1
155+
assert _socket.close.call_count == 1
158156
assert conn.state is ConnectionStates.DISCONNECTED
159157

160158

@@ -167,7 +165,52 @@ def test_can_send_more(conn):
167165
assert conn.can_send_more() is False
168166

169167

170-
def test_recv(socket, conn):
168+
def test_recv_disconnected():
    """Check that recv() marks the connection disconnected once the peer is gone.

    A real listening socket is bound to an OS-chosen port on 127.0.0.1, a
    BrokerConnection connects to it and sends a request, and the listening
    socket is then closed before recv() is called on the connection.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Port 0 lets the OS pick a free port; read the chosen port back.
        sock.bind(('127.0.0.1', 0))
        port = sock.getsockname()[1]
        sock.listen(5)

        conn = BrokerConnection('127.0.0.1', port, socket.AF_INET)
        # Poll connect() until the connection reports connected, for at
        # most one second.
        deadline = time.time() + 1
        while time.time() < deadline:
            conn.connect()
            if conn.connected():
                break
        else:
            # Explicit failure instead of `assert False`, which is removed
            # when Python runs with -O.
            pytest.fail('Connection attempt to local socket timed-out ?')

        conn.send(MetadataRequest[0]([]))

        # Disconnect server socket
        sock.close()

        # Attempt to receive should mark connection as disconnected
        assert conn.connected()
        conn.recv()
        assert conn.disconnected()
    finally:
        # Release the listening socket even when a step above fails;
        # closing an already-closed socket is harmless.
        sock.close()
192+
193+
194+
def test_recv_disconnected_too(_socket, conn):
    """Check that an empty recv() result from the mocked socket disconnects conn.

    Uses the mocked socket fixture: a request is sent first, then the mock's
    recv is set to return b'' and the connection must report disconnected
    after conn.recv().
    """
    conn.connect()
    assert conn.connected()

    request = MetadataRequest[0]([])
    request_header = RequestHeader(request, client_id=conn.config['client_id'])
    total_payload = len(request_header.encode()) + len(request.encode())
    # Two successive send() results: 4, then the full payload length
    # (presumably the 4-byte size prefix followed by the payload -- confirm
    # against BrokerConnection.send).
    _socket.send.side_effect = [4, total_payload]
    conn.send(request)

    # Empty data on recv means the socket is disconnected
    _socket.recv.return_value = b''

    # Still connected until a receive is attempted ...
    assert conn.connected()
    # ... and the receive attempt must flip the state to disconnected.
    conn.recv()
    assert conn.disconnected()
211+
212+
213+
def test_recv(_socket, conn):
171214
pass # TODO
172215

173216

0 commit comments

Comments
 (0)