Skip to content

Support AWS_MSK_IAM authentication #147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import kafka.errors as Errors
from kafka.future import Future
from kafka.metrics.stats import Avg, Count, Max, Rate
from kafka.msk import AwsMskIamClient
from kafka.oauth.abstract import AbstractTokenProvider
from kafka.protocol.admin import SaslHandShakeRequest, DescribeAclsRequest_v2, DescribeClientQuotasRequest
from kafka.protocol.commit import OffsetFetchRequest
Expand Down Expand Up @@ -83,6 +84,12 @@ class SSLWantWriteError(Exception):
gssapi = None
GSSError = None

# needed for AWS_MSK_IAM authentication:
try:
from botocore.session import Session as BotoSession
except ImportError:
# no botocore available, will disable AWS_MSK_IAM mechanism
BotoSession = None

AFI_NAMES = {
socket.AF_UNSPEC: "unspecified",
Expand Down Expand Up @@ -227,7 +234,7 @@ class BrokerConnection(object):
'sasl_oauth_token_provider': None
}
SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL')
SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', "SCRAM-SHA-256", "SCRAM-SHA-512")
SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', "SCRAM-SHA-256", "SCRAM-SHA-512", 'AWS_MSK_IAM')

def __init__(self, host, port, afi, **configs):
self.host = host
Expand Down Expand Up @@ -276,6 +283,9 @@ def __init__(self, host, port, afi, **configs):
token_provider = self.config['sasl_oauth_token_provider']
assert token_provider is not None, 'sasl_oauth_token_provider required for OAUTHBEARER sasl'
assert callable(getattr(token_provider, "token", None)), 'sasl_oauth_token_provider must implement method #token()'
if self.config['sasl_mechanism'] == 'AWS_MSK_IAM':
assert BotoSession is not None, 'AWS_MSK_IAM requires the "botocore" package'
assert self.config['security_protocol'] == 'SASL_SSL', 'AWS_MSK_IAM requires SASL_SSL'
# This is not a general lock / this class is not generally thread-safe yet
# However, to avoid pushing responsibility for maintaining
# per-connection locks to the upstream client, we will use this lock to
Expand Down Expand Up @@ -561,6 +571,8 @@ def _handle_sasl_handshake_response(self, future, response):
return self._try_authenticate_oauth(future)
elif self.config['sasl_mechanism'].startswith("SCRAM-SHA-"):
return self._try_authenticate_scram(future)
elif self.config['sasl_mechanism'] == 'AWS_MSK_IAM':
return self._try_authenticate_aws_msk_iam(future)
else:
return future.failure(
Errors.UnsupportedSaslMechanismError(
Expand Down Expand Up @@ -661,6 +673,44 @@ def _try_authenticate_plain(self, future):
log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username'])
return future.success(True)

def _try_authenticate_aws_msk_iam(self, future):
session = BotoSession()
credentials = session.get_credentials().get_frozen_credentials()
client = AwsMskIamClient(
host=self.host,
access_key=credentials.access_key,
secret_key=credentials.secret_key,
region=session.get_config_variable('region'),
token=credentials.token,
)

msg = client.first_message()
size = Int32.encode(len(msg))

err = None
close = False
with self._lock:
if not self._can_send_recv():
err = Errors.NodeNotReadyError(str(self))
close = False
else:
try:
self._send_bytes_blocking(size + msg)
data = self._recv_bytes_blocking(4)
data = self._recv_bytes_blocking(struct.unpack('4B', data)[-1])
except (ConnectionError, TimeoutError) as e:
log.exception("%s: Error receiving reply from server", self)
err = Errors.KafkaConnectionError("%s: %s" % (self, e))
close = True

if err is not None:
if close:
self.close(error=err)
return future.failure(err)

log.info('%s: Authenticated via AWS_MSK_IAM %s', self, data.decode('utf-8'))
return future.success(True)

def _try_authenticate_scram(self, future):
if self.config['security_protocol'] == 'SASL_PLAINTEXT':
log.warning('%s: Exchanging credentials in the clear', self)
Expand Down
184 changes: 184 additions & 0 deletions kafka/msk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
import datetime
import hashlib
import hmac
import json
import string

from kafka.vendor.six.moves import urllib


class AwsMskIamClient:
UNRESERVED_CHARS = string.ascii_letters + string.digits + '-._~'

def __init__(self, host, access_key, secret_key, region, token=None):
"""
Arguments:
host (str): The hostname of the broker.
access_key (str): An AWS_ACCESS_KEY_ID.
secret_key (str): An AWS_SECRET_ACCESS_KEY.
region (str): An AWS_REGION.
token (Optional[str]): An AWS_SESSION_TOKEN if using temporary
credentials.
"""
self.algorithm = 'AWS4-HMAC-SHA256'
self.expires = '900'
self.hashfunc = hashlib.sha256
self.headers = [
('host', host)
]
self.version = '2020_10_22'

self.service = 'kafka-cluster'
self.action = '{}:Connect'.format(self.service)

now = datetime.datetime.utcnow()
self.datestamp = now.strftime('%Y%m%d')
self.timestamp = now.strftime('%Y%m%dT%H%M%SZ')

self.host = host
self.access_key = access_key
self.secret_key = secret_key
self.region = region
self.token = token

@property
def _credential(self):
return '{0.access_key}/{0._scope}'.format(self)

@property
def _scope(self):
return '{0.datestamp}/{0.region}/{0.service}/aws4_request'.format(self)

@property
def _signed_headers(self):
"""
Returns (str):
An alphabetically sorted, semicolon-delimited list of lowercase
request header names.
"""
return ';'.join(sorted(k.lower() for k, _ in self.headers))

@property
def _canonical_headers(self):
"""
Returns (str):
A newline-delited list of header names and values.
Header names are lowercased.
"""
return '\n'.join(map(':'.join, self.headers)) + '\n'

@property
def _canonical_request(self):
"""
Returns (str):
An AWS Signature Version 4 canonical request in the format:
<Method>\n
<Path>\n
<CanonicalQueryString>\n
<CanonicalHeaders>\n
<SignedHeaders>\n
<HashedPayload>
"""
# The hashed_payload is always an empty string for MSK.
hashed_payload = self.hashfunc(b'').hexdigest()
return '\n'.join((
'GET',
'/',
self._canonical_querystring,
self._canonical_headers,
self._signed_headers,
hashed_payload,
))

@property
def _canonical_querystring(self):
"""
Returns (str):
A '&'-separated list of URI-encoded key/value pairs.
"""
params = []
params.append(('Action', self.action))
params.append(('X-Amz-Algorithm', self.algorithm))
params.append(('X-Amz-Credential', self._credential))
params.append(('X-Amz-Date', self.timestamp))
params.append(('X-Amz-Expires', self.expires))
if self.token:
params.append(('X-Amz-Security-Token', self.token))
params.append(('X-Amz-SignedHeaders', self._signed_headers))

return '&'.join(self._uriencode(k) + '=' + self._uriencode(v) for k, v in params)

@property
def _signing_key(self):
"""
Returns (bytes):
An AWS Signature V4 signing key generated from the secret_key, date,
region, service, and request type.
"""
key = self._hmac(('AWS4' + self.secret_key).encode('utf-8'), self.datestamp)
key = self._hmac(key, self.region)
key = self._hmac(key, self.service)
key = self._hmac(key, 'aws4_request')
return key

@property
def _signing_str(self):
"""
Returns (str):
A string used to sign the AWS Signature V4 payload in the format:
<Algorithm>\n
<Timestamp>\n
<Scope>\n
<CanonicalRequestHash>
"""
canonical_request_hash = self.hashfunc(self._canonical_request.encode('utf-8')).hexdigest()
return '\n'.join((self.algorithm, self.timestamp, self._scope, canonical_request_hash))

def _uriencode(self, msg):
"""
Arguments:
msg (str): A string to URI-encode.

Returns (str):
The URI-encoded version of the provided msg, following the encoding
rules specified: https://github.com/aws/aws-msk-iam-auth#uriencode
"""
return urllib.parse.quote(msg, safe=self.UNRESERVED_CHARS)

def _hmac(self, key, msg):
"""
Arguments:
key (bytes): A key to use for the HMAC digest.
msg (str): A value to include in the HMAC digest.
Returns (bytes):
An HMAC digest of the given key and msg.
"""
return hmac.new(key, msg.encode('utf-8'), digestmod=self.hashfunc).digest()

def first_message(self):
"""
Returns (bytes):
An encoded JSON authentication payload that can be sent to the
broker.
"""
signature = hmac.new(
self._signing_key,
self._signing_str.encode('utf-8'),
digestmod=self.hashfunc,
).hexdigest()
msg = {
'version': self.version,
'host': self.host,
'user-agent': 'kafka-python',
'action': self.action,
'x-amz-algorithm': self.algorithm,
'x-amz-credential': self._credential,
'x-amz-date': self.timestamp,
'x-amz-signedheaders': self._signed_headers,
'x-amz-expires': self.expires,
'x-amz-signature': signature,
}
if self.token:
msg['x-amz-security-token'] = self.token

return json.dumps(msg, separators=(',', ':')).encode('utf-8')
67 changes: 67 additions & 0 deletions test/test_msk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import datetime
import json

from kafka.msk import AwsMskIamClient

try:
from unittest import mock
except ImportError:
import mock


def client_factory(token=None):
now = datetime.datetime.utcfromtimestamp(1629321911)
with mock.patch('kafka.msk.datetime') as mock_dt:
mock_dt.datetime.utcnow = mock.Mock(return_value=now)
return AwsMskIamClient(
host='localhost',
access_key='XXXXXXXXXXXXXXXXXXXX',
secret_key='XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX',
region='us-east-1',
token=token,
)


def test_aws_msk_iam_client_permanent_credentials():
client = client_factory(token=None)
msg = client.first_message()
assert msg
assert isinstance(msg, bytes)
actual = json.loads(msg)

expected = {
'version': '2020_10_22',
'host': 'localhost',
'user-agent': 'kafka-python',
'action': 'kafka-cluster:Connect',
'x-amz-algorithm': 'AWS4-HMAC-SHA256',
'x-amz-credential': 'XXXXXXXXXXXXXXXXXXXX/20210818/us-east-1/kafka-cluster/aws4_request',
'x-amz-date': '20210818T212511Z',
'x-amz-signedheaders': 'host',
'x-amz-expires': '900',
'x-amz-signature': '0fa42ae3d5693777942a7a4028b564f0b372bafa2f71c1a19ad60680e6cb994b',
}
assert actual == expected


def test_aws_msk_iam_client_temporary_credentials():
client = client_factory(token='XXXXX')
msg = client.first_message()
assert msg
assert isinstance(msg, bytes)
actual = json.loads(msg)

expected = {
'version': '2020_10_22',
'host': 'localhost',
'user-agent': 'kafka-python',
'action': 'kafka-cluster:Connect',
'x-amz-algorithm': 'AWS4-HMAC-SHA256',
'x-amz-credential': 'XXXXXXXXXXXXXXXXXXXX/20210818/us-east-1/kafka-cluster/aws4_request',
'x-amz-date': '20210818T212511Z',
'x-amz-signedheaders': 'host',
'x-amz-expires': '900',
'x-amz-signature': 'b0619c50b7ecb4a7f6f92bd5f733770df5710e97b25146f97015c0b1db783b05',
'x-amz-security-token': 'XXXXX',
}
assert actual == expected