Skip to content

SASL SCRAM support #1920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 118 additions & 8 deletions kafka/conn.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from __future__ import absolute_import, division

import base64
import collections
import copy
import errno
import hashlib
import hmac
import io
import logging
from random import shuffle, uniform
from random import shuffle, uniform, SystemRandom
import re

# selectors in stdlib as of py3.4
try:
Expand All @@ -15,6 +19,7 @@
from kafka.vendor import selectors34 as selectors

import socket
import string
import struct
import sys
import threading
Expand Down Expand Up @@ -177,11 +182,11 @@ class BrokerConnection(object):
metric_group_prefix (str): Prefix for metric names. Default: ''
sasl_mechanism (str): Authentication mechanism when security_protocol
is configured for SASL_PLAINTEXT or SASL_SSL. Valid values are:
PLAIN, GSSAPI, OAUTHBEARER.
PLAIN, SCRAM-SHA-256, SCRAM-SHA-512, GSSAPI, OAUTHBEARER.
sasl_plain_username (str): username for sasl PLAIN authentication.
Required if sasl_mechanism is PLAIN.
Required if sasl_mechanism is PLAIN, SCRAM-SHA-256, or SCRAM-SHA-512.
sasl_plain_password (str): password for sasl PLAIN authentication.
Required if sasl_mechanism is PLAIN.
Required if sasl_mechanism is PLAIN, SCRAM-SHA-256, or SCRAM-SHA-512.
sasl_kerberos_service_name (str): Service name to include in GSSAPI
sasl mechanism handshake. Default: 'kafka'
sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI
Expand Down Expand Up @@ -224,7 +229,7 @@ class BrokerConnection(object):
'sasl_oauth_token_provider': None
}
SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL')
SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER')
SASL_MECHANISMS = ('PLAIN', 'SCRAM-SHA-256', 'SCRAM-SHA-512', 'GSSAPI', 'OAUTHBEARER')

def __init__(self, host, port, afi, **configs):
self.host = host
Expand Down Expand Up @@ -259,9 +264,11 @@ def __init__(self, host, port, afi, **configs):
if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'):
assert self.config['sasl_mechanism'] in self.SASL_MECHANISMS, (
'sasl_mechanism must be in ' + ', '.join(self.SASL_MECHANISMS))
if self.config['sasl_mechanism'] == 'PLAIN':
assert self.config['sasl_plain_username'] is not None, 'sasl_plain_username required for PLAIN sasl'
assert self.config['sasl_plain_password'] is not None, 'sasl_plain_password required for PLAIN sasl'
if self.config['sasl_mechanism'] == 'PLAIN' or self.config['sasl_mechanism'].startswith('SCRAM'):
assert self.config['sasl_plain_username'] is not None, \
'sasl_plain_username required for {} sasl'.format(self.config['sasl_mechanism'])
assert self.config['sasl_plain_password'] is not None, \
'sasl_plain_password required for {} sasl'.format(self.config['sasl_mechanism'])
if self.config['sasl_mechanism'] == 'GSSAPI':
assert gssapi is not None, 'GSSAPI lib not available'
assert self.config['sasl_kerberos_service_name'] is not None, 'sasl_kerberos_service_name required for GSSAPI sasl'
Expand Down Expand Up @@ -548,6 +555,8 @@ def _handle_sasl_handshake_response(self, future, response):
% (self.config['sasl_mechanism'], response.enabled_mechanisms)))
elif self.config['sasl_mechanism'] == 'PLAIN':
return self._try_authenticate_plain(future)
elif self.config['sasl_mechanism'].startswith('SCRAM'):
return self._try_authenticate_scram(future)
elif self.config['sasl_mechanism'] == 'GSSAPI':
return self._try_authenticate_gssapi(future)
elif self.config['sasl_mechanism'] == 'OAUTHBEARER':
Expand Down Expand Up @@ -652,6 +661,107 @@ def _try_authenticate_plain(self, future):
log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username'])
return future.success(True)

def _try_authenticate_scram(self, future):
err = None
close = False
with self._lock:
try:
if not self._can_send_recv():
raise Errors.NodeNotReadyError(str(self))

# SCRAM authentication as defined in RFC 5802
rand = SystemRandom()
client_nonce = "".join([rand.choice(string.ascii_letters + string.digits) for i in range(32)])
gs2_header = "n,,"
client_first_msg_bare = "n={},r={}".format(self.config["sasl_plain_username"], client_nonce)
client_first_msg = (gs2_header + client_first_msg_bare).encode("utf-8")
size = Int32.encode(len(client_first_msg))
self._send_bytes_blocking(size + client_first_msg)

size = Int32.decode(io.BytesIO(self._recv_bytes_blocking(4)))
server_first_msg = self._recv_bytes_blocking(size).decode("utf-8")

# ignore extensions
m = re.match("r=(?P<r>[^,]+),s=(?P<s>[^,]+),i=(?P<i>\d+)", server_first_msg)
if m is None:
err_msg = "failed to parse server-first-message: {}".format(server_first_msg)
raise Errors.AuthenticationFailedError(err_msg)

server_nonce = m.group("r")
if server_nonce[0:len(client_nonce)] != client_nonce:
err_msg = "nonce verification failed"
raise Errors.AuthenticationFailedError(err_msg)
salt = base64.b64decode(m.group("s"))
iterations = int(m.group("i"))

if self.config["sasl_mechanism"] == "SCRAM-SHA-256":
hash_name = "sha256"
hash_func = hashlib.sha256
elif self.config["sasl_mechanism"] == "SCRAM-SHA-512":
hash_name = "sha512"
hash_func = hashlib.sha512
else:
err_msg = "unknown sasl_mechanism: {}".format(self.config["sasl_mechanism"])
raise Errors.AuthenticationFailedError(err_msg)

c = base64.standard_b64encode(gs2_header.encode("utf-8")).decode("utf-8")
client_final_msg_without_proof = "c={},r={}".format(c, server_nonce)

salted_password = hashlib.pbkdf2_hmac(hash_name,
self.config["sasl_plain_password"].encode("utf-8"),
salt,
iterations)
client_key = hmac.new(salted_password, "Client Key".encode("utf-8"), hash_func).digest()
stored_key = hash_func(client_key).digest()
auth_msg = (client_first_msg_bare + "," + server_first_msg + "," + \
client_final_msg_without_proof).encode("utf-8")
client_sig = hmac.new(stored_key, auth_msg, hash_func).digest()
if sys.version_info[0] < 3:
client_proof = bytearray([ord(k) ^ ord(s) for (k,s) in zip(client_key, client_sig)])
else:
client_proof = bytes([k ^ s for (k,s) in zip(client_key, client_sig)])
server_key = hmac.new(salted_password, "Server Key".encode("utf-8"), hash_func).digest()
server_sig = hmac.new(server_key, auth_msg, hash_func).digest()

p = base64.standard_b64encode(client_proof).decode("utf-8")
client_final_msg = (client_final_msg_without_proof + ",p={}".format(p)).encode("utf-8")
size = Int32.encode(len(client_final_msg))
self._send_bytes_blocking(size + client_final_msg)

size = Int32.decode(io.BytesIO(self._recv_bytes_blocking(4)))
server_final_msg = self._recv_bytes_blocking(size).decode("utf-8")

m = re.match("v=(?P<v>[^,]+)", server_final_msg)
if m is None:
err_msg = "failed to parse server-final-message: {}".format(server_final_msg)
raise Errors.AuthenticationFailedError(err_msg)

if base64.standard_b64decode(m.group("v")) != server_sig:
err_msg = "failed to validate server signature"
raise Errors.AuthenticationFailedError(err_msg)
except Errors.NodeNotReadyError as e:
err = e
close = False
except Errors.AuthenticationFailedError as e:
err = e
close = True
except (ConnectionError, TimeoutError) as e:
log.exception("%s: Error receiving reply from server", self)
err = Errors.KafkaConnectionError("%s: %s" % (self, e))
close = True
except Exception as e:
err = e
close = True

if err is not None:
if close:
self.close(error=err)
return future.failure(err)

log.info('%s: Authenticated as %s via %s', self, self.config['sasl_plain_username'],
self.config['sasl_mechanism'])
return future.success(True)

def _try_authenticate_gssapi(self, future):
kerberos_damin_name = self.config['sasl_kerberos_domain_name'] or self.host
auth_id = self.config['sasl_kerberos_service_name'] + '@' + kerberos_damin_name
Expand Down