diff --git a/kafka/conn.py b/kafka/conn.py index d04acce3e..82e6d0cad 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -2,7 +2,6 @@ import copy import errno -import io import logging from random import shuffle, uniform @@ -14,25 +13,26 @@ from kafka.vendor import selectors34 as selectors import socket -import struct import threading import time from kafka.vendor import six +from kafka import sasl import kafka.errors as Errors from kafka.future import Future from kafka.metrics.stats import Avg, Count, Max, Rate -from kafka.oauth.abstract import AbstractTokenProvider -from kafka.protocol.admin import SaslHandShakeRequest, DescribeAclsRequest_v2, DescribeClientQuotasRequest +from kafka.protocol.admin import ( + DescribeAclsRequest_v2, + DescribeClientQuotasRequest, + SaslHandShakeRequest, +) from kafka.protocol.commit import OffsetFetchRequest from kafka.protocol.offset import OffsetRequest from kafka.protocol.produce import ProduceRequest from kafka.protocol.metadata import MetadataRequest from kafka.protocol.fetch import FetchRequest from kafka.protocol.parser import KafkaProtocol -from kafka.protocol.types import Int32, Int8 -from kafka.scram import ScramClient from kafka.version import __version__ @@ -83,6 +83,12 @@ class SSLWantWriteError(Exception): gssapi = None GSSError = None +# needed for AWS_MSK_IAM authentication: +try: + from botocore.session import Session as BotoSession +except ImportError: + # no botocore available, will disable AWS_MSK_IAM mechanism + BotoSession = None AFI_NAMES = { socket.AF_UNSPEC: "unspecified", @@ -227,7 +233,6 @@ class BrokerConnection(object): 'sasl_oauth_token_provider': None } SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL') - SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', "SCRAM-SHA-256", "SCRAM-SHA-512") def __init__(self, host, port, afi, **configs): self.host = host @@ -260,22 +265,10 @@ def __init__(self, host, port, afi, **configs): assert ssl_available, "Python wasn't built with SSL support" if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'): - assert self.config['sasl_mechanism'] in self.SASL_MECHANISMS, ( - 'sasl_mechanism must be in ' + ', '.join(self.SASL_MECHANISMS)) - if self.config['sasl_mechanism'] in ('PLAIN', 'SCRAM-SHA-256', 'SCRAM-SHA-512'): - assert self.config['sasl_plain_username'] is not None, ( - 'sasl_plain_username required for PLAIN or SCRAM sasl' - ) - assert self.config['sasl_plain_password'] is not None, ( - 'sasl_plain_password required for PLAIN or SCRAM sasl' - ) - if self.config['sasl_mechanism'] == 'GSSAPI': - assert gssapi is not None, 'GSSAPI lib not available' - assert self.config['sasl_kerberos_service_name'] is not None, 'sasl_kerberos_service_name required for GSSAPI sasl' - if self.config['sasl_mechanism'] == 'OAUTHBEARER': - token_provider = self.config['sasl_oauth_token_provider'] - assert token_provider is not None, 'sasl_oauth_token_provider required for OAUTHBEARER sasl' - assert callable(getattr(token_provider, "token", None)), 'sasl_oauth_token_provider must implement method #token()' + assert self.config['sasl_mechanism'] in sasl.MECHANISMS, ( + 'sasl_mechanism must be one of {}'.format(', '.join(sasl.MECHANISMS.keys())) + ) + sasl.MECHANISMS[self.config['sasl_mechanism']].validate_config(self) # This is not a general lock / this class is not generally thread-safe yet # However, to avoid pushing responsibility for maintaining # per-connection locks to the upstream client, we will use this lock to @@ -553,19 +546,9 @@ def _handle_sasl_handshake_response(self, future, response): Errors.UnsupportedSaslMechanismError( 'Kafka broker does not support %s sasl mechanism. Enabled mechanisms are: %s' % (self.config['sasl_mechanism'], response.enabled_mechanisms))) - elif self.config['sasl_mechanism'] == 'PLAIN': - return self._try_authenticate_plain(future) - elif self.config['sasl_mechanism'] == 'GSSAPI': - return self._try_authenticate_gssapi(future) - elif self.config['sasl_mechanism'] == 'OAUTHBEARER': - return self._try_authenticate_oauth(future) - elif self.config['sasl_mechanism'].startswith("SCRAM-SHA-"): - return self._try_authenticate_scram(future) - else: - return future.failure( - Errors.UnsupportedSaslMechanismError( - 'kafka-python does not support SASL mechanism %s' % - self.config['sasl_mechanism'])) + + try_authenticate = sasl.MECHANISMS[self.config['sasl_mechanism']].try_authenticate + return try_authenticate(self, future) def _send_bytes(self, data): """Send some data via non-blocking IO @@ -619,225 +602,6 @@ def _recv_bytes_blocking(self, n): finally: self._sock.settimeout(0.0) - def _try_authenticate_plain(self, future): - if self.config['security_protocol'] == 'SASL_PLAINTEXT': - log.warning('%s: Sending username and password in the clear', self) - - data = b'' - # Send PLAIN credentials per RFC-4616 - msg = bytes('\0'.join([self.config['sasl_plain_username'], - self.config['sasl_plain_username'], - self.config['sasl_plain_password']]).encode('utf-8')) - size = Int32.encode(len(msg)) - - err = None - close = False - with self._lock: - if not self._can_send_recv(): - err = Errors.NodeNotReadyError(str(self)) - close = False - else: - try: - self._send_bytes_blocking(size + msg) - - # The server will send a zero sized message (that is Int32(0)) on success. - # The connection is closed on failure - data = self._recv_bytes_blocking(4) - - except (ConnectionError, TimeoutError) as e: - log.exception("%s: Error receiving reply from server", self) - err = Errors.KafkaConnectionError("%s: %s" % (self, e)) - close = True - - if err is not None: - if close: - self.close(error=err) - return future.failure(err) - - if data != b'\x00\x00\x00\x00': - error = Errors.AuthenticationFailedError('Unrecognized response during authentication') - return future.failure(error) - - log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username']) - return future.success(True) - - def _try_authenticate_scram(self, future): - if self.config['security_protocol'] == 'SASL_PLAINTEXT': - log.warning('%s: Exchanging credentials in the clear', self) - - scram_client = ScramClient( - self.config['sasl_plain_username'], self.config['sasl_plain_password'], self.config['sasl_mechanism'] - ) - - err = None - close = False - with self._lock: - if not self._can_send_recv(): - err = Errors.NodeNotReadyError(str(self)) - close = False - else: - try: - client_first = scram_client.first_message().encode('utf-8') - size = Int32.encode(len(client_first)) - self._send_bytes_blocking(size + client_first) - - (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4)) - server_first = self._recv_bytes_blocking(data_len).decode('utf-8') - scram_client.process_server_first_message(server_first) - - client_final = scram_client.final_message().encode('utf-8') - size = Int32.encode(len(client_final)) - self._send_bytes_blocking(size + client_final) - - (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4)) - server_final = self._recv_bytes_blocking(data_len).decode('utf-8') - scram_client.process_server_final_message(server_final) - - except (ConnectionError, TimeoutError) as e: - log.exception("%s: Error receiving reply from server", self) - err = Errors.KafkaConnectionError("%s: %s" % (self, e)) - close = True - - if err is not None: - if close: - self.close(error=err) - return future.failure(err) - - log.info( - '%s: Authenticated as %s via %s', self, self.config['sasl_plain_username'], self.config['sasl_mechanism'] - ) - return future.success(True) - - def _try_authenticate_gssapi(self, future): - kerberos_damin_name = self.config['sasl_kerberos_domain_name'] or self.host - auth_id = self.config['sasl_kerberos_service_name'] + '@' + kerberos_damin_name - gssapi_name = gssapi.Name( - auth_id, - name_type=gssapi.NameType.hostbased_service - ).canonicalize(gssapi.MechType.kerberos) - log.debug('%s: GSSAPI name: %s', self, gssapi_name) - - err = None - close = False - with self._lock: - if not self._can_send_recv(): - err = Errors.NodeNotReadyError(str(self)) - close = False - else: - # Establish security context and negotiate protection level - # For reference RFC 2222, section 7.2.1 - try: - # Exchange tokens until authentication either succeeds or fails - client_ctx = gssapi.SecurityContext(name=gssapi_name, usage='initiate') - received_token = None - while not client_ctx.complete: - # calculate an output token from kafka token (or None if first iteration) - output_token = client_ctx.step(received_token) - - # pass output token to kafka, or send empty response if the security - # context is complete (output token is None in that case) - if output_token is None: - self._send_bytes_blocking(Int32.encode(0)) - else: - msg = output_token - size = Int32.encode(len(msg)) - self._send_bytes_blocking(size + msg) - - # The server will send a token back. Processing of this token either - # establishes a security context, or it needs further token exchange. - # The gssapi will be able to identify the needed next step. - # The connection is closed on failure. - header = self._recv_bytes_blocking(4) - (token_size,) = struct.unpack('>i', header) - received_token = self._recv_bytes_blocking(token_size) - - # Process the security layer negotiation token, sent by the server - # once the security context is established. - - # unwraps message containing supported protection levels and msg size - msg = client_ctx.unwrap(received_token).message - # Kafka currently doesn't support integrity or confidentiality security layers, so we - # simply set QoP to 'auth' only (first octet). We reuse the max message size proposed - # by the server - msg = Int8.encode(SASL_QOP_AUTH & Int8.decode(io.BytesIO(msg[0:1]))) + msg[1:] - # add authorization identity to the response, GSS-wrap and send it - msg = client_ctx.wrap(msg + auth_id.encode(), False).message - size = Int32.encode(len(msg)) - self._send_bytes_blocking(size + msg) - - except (ConnectionError, TimeoutError) as e: - log.exception("%s: Error receiving reply from server", self) - err = Errors.KafkaConnectionError("%s: %s" % (self, e)) - close = True - except Exception as e: - err = e - close = True - - if err is not None: - if close: - self.close(error=err) - return future.failure(err) - - log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name) - return future.success(True) - - def _try_authenticate_oauth(self, future): - data = b'' - - msg = bytes(self._build_oauth_client_request().encode("utf-8")) - size = Int32.encode(len(msg)) - - err = None - close = False - with self._lock: - if not self._can_send_recv(): - err = Errors.NodeNotReadyError(str(self)) - close = False - else: - try: - # Send SASL OAuthBearer request with OAuth token - self._send_bytes_blocking(size + msg) - - # The server will send a zero sized message (that is Int32(0)) on success. - # The connection is closed on failure - data = self._recv_bytes_blocking(4) - - except (ConnectionError, TimeoutError) as e: - log.exception("%s: Error receiving reply from server", self) - err = Errors.KafkaConnectionError("%s: %s" % (self, e)) - close = True - - if err is not None: - if close: - self.close(error=err) - return future.failure(err) - - if data != b'\x00\x00\x00\x00': - error = Errors.AuthenticationFailedError('Unrecognized response during authentication') - return future.failure(error) - - log.info('%s: Authenticated via OAuth', self) - return future.success(True) - - def _build_oauth_client_request(self): - token_provider = self.config['sasl_oauth_token_provider'] - return "n,,\x01auth=Bearer {}{}\x01\x01".format(token_provider.token(), self._token_extensions()) - - def _token_extensions(self): - """ - Return a string representation of the OPTIONAL key-value pairs that can be sent with an OAUTHBEARER - initial request. - """ - token_provider = self.config['sasl_oauth_token_provider'] - - # Only run if the #extensions() method is implemented by the clients Token Provider class - # Builds up a string separated by \x01 via a dict of key value pairs - if callable(getattr(token_provider, "extensions", None)) and len(token_provider.extensions()) > 0: - msg = "\x01".join(["{}={}".format(k, v) for k, v in token_provider.extensions().items()]) - return "\x01" + msg - else: - return "" - def blacked_out(self): """ Return true if we are disconnected from the given node and can't diff --git a/kafka/sasl/__init__.py b/kafka/sasl/__init__.py new file mode 100644 index 000000000..52621830f --- /dev/null +++ b/kafka/sasl/__init__.py @@ -0,0 +1,53 @@ +import logging + +from kafka.sasl import gssapi, oauthbearer, plain, scram + +log = logging.getLogger(__name__) + +MECHANISMS = { + 'GSSAPI': gssapi, + 'OAUTHBEARER': oauthbearer, + 'PLAIN': plain, + 'SCRAM-SHA-256': scram, + 'SCRAM-SHA-512': scram, +} + + +def register_mechanism(key, module): + """ + Registers a custom SASL mechanism that can be used via sasl_mechanism={key}. + + Example: + import kakfa.sasl + from kafka import KafkaProducer + from mymodule import custom_sasl + kafka.sasl.register_mechanism('CUSTOM_SASL', custom_sasl) + + producer = KafkaProducer(sasl_mechanism='CUSTOM_SASL') + + Arguments: + key (str): The name of the mechanism returned by the broker and used + in the sasl_mechanism config value. + module (module): A module that implements the following methods... + + def validate_config(conn: BrokerConnection): -> None: + # Raises an AssertionError for missing or invalid conifg values. + + def try_authenticate(conn: BrokerConncetion, future: -> Future): + # Executes authentication routine and returns a resolved Future. + + Raises: + AssertionError: The registered module does not define a required method. + """ + assert callable(getattr(module, 'validate_config', None)), ( + 'Custom SASL mechanism {} must implement method #validate_config()' + .format(key) + ) + assert callable(getattr(module, 'try_authenticate', None)), ( + 'Custom SASL mechanism {} must implement method #try_authenticate()' + .format(key) + ) + if key in MECHANISMS: + log.warning('Overriding existing SASL mechanism {}'.format(key)) + + MECHANISMS[key] = module diff --git a/kafka/sasl/gssapi.py b/kafka/sasl/gssapi.py new file mode 100644 index 000000000..92a3ed954 --- /dev/null +++ b/kafka/sasl/gssapi.py @@ -0,0 +1,100 @@ +import io +import logging +import struct + +import kafka.errors as Errors +from kafka.protocol.types import Int8, Int32 + +try: + import gssapi + from gssapi.raw.misc import GSSError +except ImportError: + gssapi = None + GSSError = None + +log = logging.getLogger(__name__) + +SASL_QOP_AUTH = 1 + + +def validate_config(conn): + assert gssapi is not None, ( + 'gssapi library required when sasl_mechanism=GSSAPI' + ) + assert conn.config['sasl_kerberos_service_name'] is not None, ( + 'sasl_kerberos_service_name required when sasl_mechanism=GSSAPI' + ) + + +def try_authenticate(conn, future): + kerberos_damin_name = conn.config['sasl_kerberos_domain_name'] or conn.host + auth_id = conn.config['sasl_kerberos_service_name'] + '@' + kerberos_damin_name + gssapi_name = gssapi.Name( + auth_id, + name_type=gssapi.NameType.hostbased_service + ).canonicalize(gssapi.MechType.kerberos) + log.debug('%s: GSSAPI name: %s', conn, gssapi_name) + + err = None + close = False + with conn._lock: + if not conn._can_send_recv(): + err = Errors.NodeNotReadyError(str(conn)) + close = False + else: + # Establish security context and negotiate protection level + # For reference RFC 2222, section 7.2.1 + try: + # Exchange tokens until authentication either succeeds or fails + client_ctx = gssapi.SecurityContext(name=gssapi_name, usage='initiate') + received_token = None + while not client_ctx.complete: + # calculate an output token from kafka token (or None if first iteration) + output_token = client_ctx.step(received_token) + + # pass output token to kafka, or send empty response if the security + # context is complete (output token is None in that case) + if output_token is None: + conn._send_bytes_blocking(Int32.encode(0)) + else: + msg = output_token + size = Int32.encode(len(msg)) + conn._send_bytes_blocking(size + msg) + + # The server will send a token back. Processing of this token either + # establishes a security context, or it needs further token exchange. + # The gssapi will be able to identify the needed next step. + # The connection is closed on failure. + header = conn._recv_bytes_blocking(4) + (token_size,) = struct.unpack('>i', header) + received_token = conn._recv_bytes_blocking(token_size) + + # Process the security layer negotiation token, sent by the server + # once the security context is established. + + # unwraps message containing supported protection levels and msg size + msg = client_ctx.unwrap(received_token).message + # Kafka currently doesn't support integrity or confidentiality + # security layers, so we simply set QoP to 'auth' only (first octet). + # We reuse the max message size proposed by the server + msg = Int8.encode(SASL_QOP_AUTH & Int8.decode(io.BytesIO(msg[0:1]))) + msg[1:] + # add authorization identity to the response, GSS-wrap and send it + msg = client_ctx.wrap(msg + auth_id.encode(), False).message + size = Int32.encode(len(msg)) + conn._send_bytes_blocking(size + msg) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", conn) + err = Errors.KafkaConnectionError("%s: %s" % (conn, e)) + close = True + except Exception as e: + err = e + close = True + + if err is not None: + if close: + conn.close(error=err) + return future.failure(err) + + log.info('%s: Authenticated as %s via GSSAPI', conn, gssapi_name) + return future.success(True) diff --git a/kafka/sasl/oauthbearer.py b/kafka/sasl/oauthbearer.py new file mode 100644 index 000000000..f1427af9a --- /dev/null +++ b/kafka/sasl/oauthbearer.py @@ -0,0 +1,80 @@ +import logging + +import kafka.errors as Errors +from kafka.protocol.types import Int32 + +log = logging.getLogger(__name__) + + +def validate_config(conn): + token_provider = conn.config.get('sasl_oauth_token_provider') + assert token_provider is not None, ( + 'sasl_oauth_token_provider required when sasl_mechanism=OAUTHBEARER' + ) + assert callable(getattr(token_provider, 'token', None)), ( + 'sasl_oauth_token_provider must implement method #token()' + ) + + +def try_authenticate(conn, future): + data = b'' + + msg = bytes(_build_oauth_client_request(conn).encode("utf-8")) + size = Int32.encode(len(msg)) + + err = None + close = False + with conn._lock: + if not conn._can_send_recv(): + err = Errors.NodeNotReadyError(str(conn)) + close = False + else: + try: + # Send SASL OAuthBearer request with OAuth token + conn._send_bytes_blocking(size + msg) + + # The server will send a zero sized message (that is Int32(0)) on success. + # The connection is closed on failure + data = conn._recv_bytes_blocking(4) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", conn) + err = Errors.KafkaConnectionError("%s: %s" % (conn, e)) + close = True + + if err is not None: + if close: + conn.close(error=err) + return future.failure(err) + + if data != b'\x00\x00\x00\x00': + error = Errors.AuthenticationFailedError('Unrecognized response during authentication') + return future.failure(error) + + log.info('%s: Authenticated via OAuth', conn) + return future.success(True) + + +def _build_oauth_client_request(conn): + token_provider = conn.config['sasl_oauth_token_provider'] + return "n,,\x01auth=Bearer {}{}\x01\x01".format( + token_provider.token(), + _token_extensions(conn), + ) + + +def _token_extensions(conn): + """ + Return a string representation of the OPTIONAL key-value pairs that can be + sent with an OAUTHBEARER initial request. + """ + token_provider = conn.config['sasl_oauth_token_provider'] + + # Only run if the #extensions() method is implemented by the clients Token Provider class + # Builds up a string separated by \x01 via a dict of key value pairs + if (callable(getattr(token_provider, "extensions", None)) + and len(token_provider.extensions()) > 0): + msg = "\x01".join(["{}={}".format(k, v) for k, v in token_provider.extensions().items()]) + return "\x01" + msg + else: + return "" diff --git a/kafka/sasl/plain.py b/kafka/sasl/plain.py new file mode 100644 index 000000000..5aedcbbb9 --- /dev/null +++ b/kafka/sasl/plain.py @@ -0,0 +1,58 @@ +import logging + +import kafka.errors as Errors +from kafka.protocol.types import Int32 + +log = logging.getLogger(__name__) + + +def validate_config(conn): + assert conn.config['sasl_plain_username'] is not None, ( + 'sasl_plain_username required when sasl_mechanism=PLAIN' + ) + assert conn.config['sasl_plain_password'] is not None, ( + 'sasl_plain_password required when sasl_mechanism=PLAIN' + ) + + +def try_authenticate(conn, future): + if conn.config['security_protocol'] == 'SASL_PLAINTEXT': + log.warning('%s: Sending username and password in the clear', conn) + + data = b'' + # Send PLAIN credentials per RFC-4616 + msg = bytes('\0'.join([conn.config['sasl_plain_username'], + conn.config['sasl_plain_username'], + conn.config['sasl_plain_password']]).encode('utf-8')) + size = Int32.encode(len(msg)) + + err = None + close = False + with conn._lock: + if not conn._can_send_recv(): + err = Errors.NodeNotReadyError(str(conn)) + close = False + else: + try: + conn._send_bytes_blocking(size + msg) + + # The server will send a zero sized message (that is Int32(0)) on success. + # The connection is closed on failure + data = conn._recv_bytes_blocking(4) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", conn) + err = Errors.KafkaConnectionError("%s: %s" % (conn, e)) + close = True + + if err is not None: + if close: + conn.close(error=err) + return future.failure(err) + + if data != b'\x00\x00\x00\x00': + error = Errors.AuthenticationFailedError('Unrecognized response during authentication') + return future.failure(error) + + log.info('%s: Authenticated as %s via PLAIN', conn, conn.config['sasl_plain_username']) + return future.success(True) diff --git a/kafka/sasl/scram.py b/kafka/sasl/scram.py new file mode 100644 index 000000000..f31c80c1b --- /dev/null +++ b/kafka/sasl/scram.py @@ -0,0 +1,68 @@ +import logging +import struct + +import kafka.errors as Errors +from kafka.protocol.types import Int32 +from kafka.scram import ScramClient + +log = logging.getLogger() + + +def validate_config(conn): + assert conn.config['sasl_plain_username'] is not None, ( + 'sasl_plain_username required when sasl_mechanism=SCRAM-*' + ) + assert conn.config['sasl_plain_password'] is not None, ( + 'sasl_plain_password required when sasl_mechanism=SCRAM-*' + ) + + +def try_authenticate(conn, future): + if conn.config['security_protocol'] == 'SASL_PLAINTEXT': + log.warning('%s: Exchanging credentials in the clear', conn) + + scram_client = ScramClient( + conn.config['sasl_plain_username'], + conn.config['sasl_plain_password'], + conn.config['sasl_mechanism'], + ) + + err = None + close = False + with conn._lock: + if not conn._can_send_recv(): + err = Errors.NodeNotReadyError(str(conn)) + close = False + else: + try: + client_first = scram_client.first_message().encode('utf-8') + size = Int32.encode(len(client_first)) + conn._send_bytes_blocking(size + client_first) + + (data_len,) = struct.unpack('>i', conn._recv_bytes_blocking(4)) + server_first = conn._recv_bytes_blocking(data_len).decode('utf-8') + scram_client.process_server_first_message(server_first) + + client_final = scram_client.final_message().encode('utf-8') + size = Int32.encode(len(client_final)) + conn._send_bytes_blocking(size + client_final) + + (data_len,) = struct.unpack('>i', conn._recv_bytes_blocking(4)) + server_final = conn._recv_bytes_blocking(data_len).decode('utf-8') + scram_client.process_server_final_message(server_final) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", conn) + err = Errors.KafkaConnectionError("%s: %s" % (conn, e)) + close = True + + if err is not None: + if close: + conn.close(error=err) + return future.failure(err) + + log.info( + '%s: Authenticated as %s via %s', + conn, conn.config['sasl_plain_username'], conn.config['sasl_mechanism'] + ) + return future.success(True) diff --git a/test/test_msk.py b/test/test_msk.py new file mode 100644 index 000000000..4d06d4441 --- /dev/null +++ b/test/test_msk.py @@ -0,0 +1,68 @@ +import datetime +import json + +from kafka.msk import AwsMskIamClient + +try: + from unittest import mock +except ImportError: + import mock + + +def client_factory(token=None): + now = datetime.datetime.utcfromtimestamp(1629321911) + with mock.patch('kafka.msk.datetime') as mock_dt: + mock_dt.datetime.utcnow = mock.Mock(return_value=now) + return AwsMskIamClient( + host='localhost', + access_key='XXXXXXXXXXXXXXXXXXXX', + secret_key='XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX', + region='us-east-1', + token=token, + ) + + +def test_aws_msk_iam_client_permanent_credentials(): + client = client_factory(token=None) + msg = client.first_message() + assert msg + assert isinstance(msg, bytes) + actual = json.loads(msg) + + expected = { + 'version': '2020_10_22', + 'host': 'localhost', + 'user-agent': 'kafka-python', + 'action': 'kafka-cluster:Connect', + 'x-amz-algorithm': 'AWS4-HMAC-SHA256', + 'x-amz-credential': 'XXXXXXXXXXXXXXXXXXXX/20210818/us-east-1/kafka-cluster/aws4_request', + 'x-amz-date': '20210818T212511Z', + 'x-amz-signedheaders': 'host', + 'x-amz-expires': '900', + 'x-amz-signature': '0fa42ae3d5693777942a7a4028b564f0b372bafa2f71c1a19ad60680e6cb994b', + } + assert actual == expected + + +def test_aws_msk_iam_client_temporary_credentials(): + client = client_factory(token='XXXXX') + msg = client.first_message() + assert msg + assert isinstance(msg, bytes) + actual = json.loads(msg) + + expected = { + 'version': '2020_10_22', + 'host': 'localhost', + 'user-agent': 'kafka-python', + 'action': 'kafka-cluster:Connect', + 'x-amz-algorithm': 'AWS4-HMAC-SHA256', + 'x-amz-credential': 'XXXXXXXXXXXXXXXXXXXX/20210818/us-east-1/kafka-cluster/aws4_request', + 'x-amz-date': '20210818T212511Z', + 'x-amz-signedheaders': 'host', + 'x-amz-expires': '900', + 'x-amz-signature': 'b0619c50b7ecb4a7f6f92bd5f733770df5710e97b25146f97015c0b1db783b05', + 'x-amz-security-token': 'XXXXX', + } + assert actual == expected +