15
15
import kafka .errors as Errors
16
16
from kafka .future import Future
17
17
from kafka .protocol .api import RequestHeader
18
+ from kafka .protocol .admin import SaslHandShakeRequest , SaslHandShakeResponse
18
19
from kafka .protocol .commit import GroupCoordinatorResponse
19
20
from kafka .protocol .types import Int32
20
21
from kafka .version import __version__
@@ -48,7 +49,7 @@ class ConnectionStates(object):
48
49
CONNECTING = '<connecting>'
49
50
HANDSHAKE = '<handshake>'
50
51
CONNECTED = '<connected>'
51
-
52
+ AUTHENTICATING = '<authenticating>'
52
53
53
54
InFlightRequest = collections .namedtuple ('InFlightRequest' ,
54
55
['request' , 'response_type' , 'correlation_id' , 'future' , 'timestamp' ])
@@ -73,6 +74,9 @@ class BrokerConnection(object):
73
74
'ssl_password' : None ,
74
75
'api_version' : (0 , 8 , 2 ), # default to most restrictive
75
76
'state_change_callback' : lambda conn : True ,
77
+ 'sasl_mechanism' : None ,
78
+ 'sasl_plain_username' : None ,
79
+ 'sasl_plain_password' : None
76
80
}
77
81
78
82
def __init__ (self , host , port , afi , ** configs ):
@@ -188,6 +192,8 @@ def connect(self):
188
192
if self .config ['security_protocol' ] in ('SSL' , 'SASL_SSL' ):
189
193
log .debug ('%s: initiating SSL handshake' , str (self ))
190
194
self .state = ConnectionStates .HANDSHAKE
195
+ elif self .config ['security_protocol' ] == 'SASL_PLAINTEXT' :
196
+ self .state = ConnectionStates .AUTHENTICATING
191
197
else :
192
198
self .state = ConnectionStates .CONNECTED
193
199
self .config ['state_change_callback' ](self )
@@ -211,6 +217,15 @@ def connect(self):
211
217
if self .state is ConnectionStates .HANDSHAKE :
212
218
if self ._try_handshake ():
213
219
log .debug ('%s: completed SSL handshake.' , str (self ))
220
+ if self .config ['security_protocol' ] == 'SASL_SSL' :
221
+ self .state = ConnectionStates .AUTHENTICATING
222
+ else :
223
+ self .state = ConnectionStates .CONNECTED
224
+ self .config ['state_change_callback' ](self )
225
+
226
+ if self .state is ConnectionStates .AUTHENTICATING :
227
+ if self ._try_authenticate ():
228
+ log .debug ('%s: Authenticated as %s' , str (self ), self .config ['sasl_plain_username' ])
214
229
self .state = ConnectionStates .CONNECTED
215
230
self .config ['state_change_callback' ](self )
216
231
@@ -273,6 +288,90 @@ def _try_handshake(self):
273
288
274
289
return False
275
290
291
+ def _try_authenticate (self ):
292
+ assert self .config ['security_protocol' ] in ('SASL_PLAINTEXT' , 'SASL_SSL' )
293
+
294
+ if self .config ['security_protocol' ] == 'SASL_PLAINTEXT' :
295
+ log .warning ('%s: Sending username and password in the clear' , str (self ))
296
+
297
+ # Build a SaslHandShakeRequest message
298
+ correlation_id = self ._next_correlation_id ()
299
+ request = SaslHandShakeRequest [0 ](self .config ['sasl_mechanism' ])
300
+ header = RequestHeader (request ,
301
+ correlation_id = correlation_id ,
302
+ client_id = self .config ['client_id' ])
303
+
304
+ message = b'' .join ([header .encode (), request .encode ()])
305
+ size = Int32 .encode (len (message ))
306
+
307
+ # Attempt to send it over our socket
308
+ try :
309
+ self ._sock .setblocking (True )
310
+ self ._sock .sendall (size + message )
311
+ self ._sock .setblocking (False )
312
+ except (AssertionError , ConnectionError ) as e :
313
+ log .exception ("Error sending %s to %s" , request , self )
314
+ error = Errors .ConnectionError ("%s: %s" % (str (self ), e ))
315
+ self .close (error = error )
316
+ return False
317
+
318
+ future = Future ()
319
+ ifr = InFlightRequest (request = request ,
320
+ correlation_id = correlation_id ,
321
+ response_type = request .RESPONSE_TYPE ,
322
+ future = future ,
323
+ timestamp = time .time ())
324
+ self .in_flight_requests .append (ifr )
325
+
326
+ # Listen for a reply and check that the server supports the PLAIN mechanism
327
+ response = None
328
+ while not response :
329
+ response = self .recv ()
330
+
331
+ if not response .error_code is 0 :
332
+ raise Errors .for_code (response .error_code )
333
+
334
+ if not self .config ['sasl_mechanism' ] in response .enabled_mechanisms :
335
+ raise Errors .AuthenticationMethodNotSupported (self .config ['sasl_mechanism' ] + " is not supported by broker" )
336
+
337
+ return self ._try_authenticate_plain ()
338
+
339
+ def _try_authenticate_plain (self ):
340
+ data = b''
341
+ try :
342
+ self ._sock .setblocking (True )
343
+ # Send our credentials
344
+ msg = bytes ('\0 ' .join ([self .config ['sasl_plain_username' ],
345
+ self .config ['sasl_plain_username' ],
346
+ self .config ['sasl_plain_password' ]]).encode ('utf-8' ))
347
+ size = Int32 .encode (len (msg ))
348
+ self ._sock .sendall (size + msg )
349
+
350
+ # The server will send a zero sized message (that is Int32(0)) on success.
351
+ # The connection is closed on failure
352
+ received_bytes = 0
353
+ while received_bytes < 4 :
354
+ data = data + self ._sock .recv (4 - received_bytes )
355
+ received_bytes = received_bytes + len (data )
356
+ if not data :
357
+ log .error ('%s: Authentication failed for user %s' , self , self .config ['sasl_plain_username' ])
358
+ self .close (error = Errors .ConnectionError ('Authentication failed' ))
359
+ raise Errors .AuthenticationFailedError ('Authentication failed for user {}' .format (self .config ['sasl_plain_username' ]))
360
+ self ._sock .setblocking (False )
361
+ except (AssertionError , ConnectionError ) as e :
362
+ log .exception ("%s: Error receiving reply from server" , self )
363
+ error = Errors .ConnectionError ("%s: %s" % (str (self ), e ))
364
+ self .close (error = error )
365
+ return False
366
+
367
+ with io .BytesIO () as buffer :
368
+ buffer .write (data )
369
+ buffer .seek (0 )
370
+ if not Int32 .decode (buffer ) == 0 :
371
+ raise Errors .KafkaError ('Expected a zero sized reply after sending credentials' )
372
+
373
+ return True
374
+
276
375
def blacked_out (self ):
277
376
"""
278
377
Return true if we are disconnected from the given node and can't
@@ -292,7 +391,8 @@ def connecting(self):
292
391
"""Returns True if still connecting (this may encompass several
293
392
different states, such as SSL handshake, authorization, etc)."""
294
393
return self .state in (ConnectionStates .CONNECTING ,
295
- ConnectionStates .HANDSHAKE )
394
+ ConnectionStates .HANDSHAKE ,
395
+ ConnectionStates .AUTHENTICATING )
296
396
297
397
def disconnected (self ):
298
398
"""Return True iff socket is closed"""
@@ -385,7 +485,7 @@ def recv(self):
385
485
Return response if available
386
486
"""
387
487
assert not self ._processing , 'Recursion not supported'
388
- if not self .connected ():
488
+ if not self .connected () and not self . state is ConnectionStates . AUTHENTICATING :
389
489
log .warning ('%s cannot recv: socket not connected' , self )
390
490
# If requests are pending, we should close the socket and
391
491
# fail all the pending request futures
0 commit comments