5
5
import io
6
6
from random import shuffle
7
7
import socket
8
+ import ssl
8
9
import struct
9
10
from threading import local
10
11
import time
29
30
DEFAULT_SOCKET_TIMEOUT_SECONDS = 120
30
31
DEFAULT_KAFKA_PORT = 9092
31
32
33
+ # support older ssl libraries
34
+ try :
35
+ assert ssl .SSLWantReadError
36
+ assert ssl .SSLWantWriteError
37
+ assert ssl .SSLZeroReturnError
38
+ except :
39
+ log .warning ('old ssl module detected.'
40
+ ' ssl error handling may not operate cleanly.'
41
+ ' Consider upgrading to python 3.5 or 2.7' )
42
+ ssl .SSLWantReadError = ssl .SSLError
43
+ ssl .SSLWantWriteError = ssl .SSLError
44
+ ssl .SSLZeroReturnError = ssl .SSLError
45
+
32
46
33
47
class ConnectionStates (object ):
34
48
DISCONNECTING = '<disconnecting>'
35
49
DISCONNECTED = '<disconnected>'
36
50
CONNECTING = '<connecting>'
51
+ HANDSHAKE = '<handshake>'
37
52
CONNECTED = '<connected>'
38
53
39
54
@@ -49,6 +64,12 @@ class BrokerConnection(object):
49
64
'max_in_flight_requests_per_connection' : 5 ,
50
65
'receive_buffer_bytes' : None ,
51
66
'send_buffer_bytes' : None ,
67
+ 'security_protocol' : 'PLAINTEXT' ,
68
+ 'ssl_context' : None ,
69
+ 'ssl_check_hostname' : True ,
70
+ 'ssl_cafile' : None ,
71
+ 'ssl_certfile' : None ,
72
+ 'ssl_keyfile' : None ,
52
73
'api_version' : (0 , 8 , 2 ), # default to most restrictive
53
74
'state_change_callback' : lambda conn : True ,
54
75
}
@@ -66,6 +87,9 @@ def __init__(self, host, port, afi, **configs):
66
87
67
88
self .state = ConnectionStates .DISCONNECTED
68
89
self ._sock = None
90
+ self ._ssl_context = None
91
+ if self .config ['ssl_context' ] is not None :
92
+ self ._ssl_context = self .config ['ssl_context' ]
69
93
self ._rbuffer = io .BytesIO ()
70
94
self ._receiving = False
71
95
self ._next_payload_bytes = 0
@@ -87,6 +111,8 @@ def connect(self):
87
111
self ._sock .setsockopt (socket .SOL_SOCKET , socket .SO_SNDBUF ,
88
112
self .config ['send_buffer_bytes' ])
89
113
self ._sock .setblocking (False )
114
+ if self .config ['security_protocol' ] in ('SSL' , 'SASL_SSL' ):
115
+ self ._wrap_ssl ()
90
116
self .state = ConnectionStates .CONNECTING
91
117
self .last_attempt = time .time ()
92
118
self .config ['state_change_callback' ](self )
@@ -103,7 +129,11 @@ def connect(self):
103
129
# Connection succeeded
104
130
if not ret or ret == errno .EISCONN :
105
131
log .debug ('%s: established TCP connection' , str (self ))
106
- self .state = ConnectionStates .CONNECTED
132
+ if self .config ['security_protocol' ] in ('SSL' , 'SASL_SSL' ):
133
+ log .debug ('%s: initiating SSL handshake' , str (self ))
134
+ self .state = ConnectionStates .HANDSHAKE
135
+ else :
136
+ self .state = ConnectionStates .CONNECTED
107
137
self .config ['state_change_callback' ](self )
108
138
109
139
# Connection failed
@@ -122,8 +152,60 @@ def connect(self):
122
152
else :
123
153
pass
124
154
155
+ if self .state is ConnectionStates .HANDSHAKE :
156
+ if self ._try_handshake ():
157
+ log .debug ('%s: completed SSL handshake.' , str (self ))
158
+ self .state = ConnectionStates .CONNECTED
159
+ self .config ['state_change_callback' ](self )
160
+
125
161
return self .state
126
162
163
+ def _wrap_ssl (self ):
164
+ assert self .config ['security_protocol' ] in ('SSL' , 'SASL_SSL' )
165
+ if self ._ssl_context is None :
166
+ log .debug ('%s: configuring default SSL Context' , str (self ))
167
+ self ._ssl_context = ssl .SSLContext (ssl .PROTOCOL_SSLv23 ) # pylint: disable=no-member
168
+ self ._ssl_context .options |= ssl .OP_NO_SSLv2 # pylint: disable=no-member
169
+ self ._ssl_context .options |= ssl .OP_NO_SSLv3 # pylint: disable=no-member
170
+ self ._ssl_context .verify_mode = ssl .CERT_OPTIONAL
171
+ if self .config ['ssl_check_hostname' ]:
172
+ self ._ssl_context .check_hostname = True
173
+ if self .config ['ssl_cafile' ]:
174
+ log .info ('%s: Loading SSL CA from %s' , str (self ), self .config ['ssl_cafile' ])
175
+ self ._ssl_context .load_verify_locations (self .config ['ssl_cafile' ])
176
+ self ._ssl_context .verify_mode = ssl .CERT_REQUIRED
177
+ if self .config ['ssl_certfile' ] and self .config ['ssl_keyfile' ]:
178
+ log .info ('%s: Loading SSL Cert from %s' , str (self ), self .config ['ssl_certfile' ])
179
+ log .info ('%s: Loading SSL Key from %s' , str (self ), self .config ['ssl_keyfile' ])
180
+ self ._ssl_context .load_cert_chain (
181
+ certfile = self .config ['ssl_certfile' ],
182
+ keyfile = self .config ['ssl_keyfile' ])
183
+ log .debug ('%s: wrapping socket in ssl context' , str (self ))
184
+ try :
185
+ self ._sock = self ._ssl_context .wrap_socket (
186
+ self ._sock ,
187
+ server_hostname = self .host ,
188
+ do_handshake_on_connect = False )
189
+ except ssl .SSLError :
190
+ log .exception ('%s: Failed to wrap socket in SSLContext!' , str (self ))
191
+ self .close ()
192
+ self .last_failure = time .time ()
193
+
194
+ def _try_handshake (self ):
195
+ assert self .config ['security_protocol' ] in ('SSL' , 'SASL_SSL' )
196
+ try :
197
+ self ._sock .do_handshake ()
198
+ return True
199
+ # old ssl in python2.6 will swallow all SSLErrors here...
200
+ except (ssl .SSLWantReadError , ssl .SSLWantWriteError ):
201
+ pass
202
+ except ssl .SSLZeroReturnError :
203
+ log .warning ('SSL connection closed by server during handshake.' )
204
+ self .close ()
205
+ # Other SSLErrors will be raised to user
206
+
207
+ return False
208
+
127
209
def blacked_out (self ):
128
210
"""
129
211
Return true if we are disconnected from the given node and can't
@@ -140,8 +222,10 @@ def connected(self):
140
222
return self .state is ConnectionStates .CONNECTED
141
223
142
224
def connecting (self ):
143
- """Return True iff socket is in intermediate connecting state."""
144
- return self .state is ConnectionStates .CONNECTING
225
+ """Returns True if still connecting (this may encompass several
226
+ different states, such as SSL handshake, authorization, etc)."""
227
+ return self .state in (ConnectionStates .CONNECTING ,
228
+ ConnectionStates .HANDSHAKE )
145
229
146
230
def disconnected (self ):
147
231
"""Return True iff socket is closed"""
@@ -260,6 +344,8 @@ def recv(self):
260
344
# An extremely small, but non-zero, probability that there are
261
345
# more than 0 but not yet 4 bytes available to read
262
346
self ._rbuffer .write (self ._sock .recv (4 - self ._rbuffer .tell ()))
347
+ except ssl .SSLWantReadError :
348
+ return None
263
349
except ConnectionError as e :
264
350
if six .PY2 and e .errno == errno .EWOULDBLOCK :
265
351
return None
@@ -286,6 +372,8 @@ def recv(self):
286
372
staged_bytes = self ._rbuffer .tell ()
287
373
try :
288
374
self ._rbuffer .write (self ._sock .recv (self ._next_payload_bytes - staged_bytes ))
375
+ except ssl .SSLWantReadError :
376
+ return None
289
377
except ConnectionError as e :
290
378
# Extremely small chance that we have exactly 4 bytes for a
291
379
# header, but nothing to read in the body yet
0 commit comments