From a6d0579d3cadd3826dd364b01bc12a2173139abc Mon Sep 17 00:00:00 2001 From: William Barnhart Date: Fri, 8 Mar 2024 18:30:02 -0500 Subject: [PATCH 01/17] Update README.rst --- README.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/README.rst b/README.rst index 78a92a884..64f4fb854 100644 --- a/README.rst +++ b/README.rst @@ -17,6 +17,7 @@ Kafka Python client :target: https://github.com/dpkp/kafka-python/blob/master/setup.py +**DUE TO ISSUES WITH RELEASES, IT IS SUGGESTED TO USE https://github.com/wbarnha/kafka-python-ng FOR THE TIME BEING** Python client for the Apache Kafka distributed stream processing system. kafka-python is designed to function much like the official java client, with a From 2f2ccb135be561501ff02b3f71611583dec9180b Mon Sep 17 00:00:00 2001 From: William Barnhart Date: Tue, 19 Mar 2024 11:05:36 -0400 Subject: [PATCH 02/17] Support Describe log dirs (#145) This implements API key 35 (DescribeLogDirs) from the official Apache Kafka protocol documentation. The functionality was requested in issue #2163, and this patch is an implementation proposal. Co-authored-by: chopatate --- kafka/admin/client.py | 18 +++++++++++++++++- kafka/protocol/admin.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/kafka/admin/client.py b/kafka/admin/client.py index d85935f89..5b01f8fe6 100644 --- a/kafka/admin/client.py +++ b/kafka/admin/client.py @@ -17,7 +17,7 @@ from kafka.protocol.admin import ( CreateTopicsRequest, DeleteTopicsRequest, DescribeConfigsRequest, AlterConfigsRequest, CreatePartitionsRequest, ListGroupsRequest, DescribeGroupsRequest, DescribeAclsRequest, CreateAclsRequest, DeleteAclsRequest, - DeleteGroupsRequest + DeleteGroupsRequest, DescribeLogDirsRequest ) from kafka.protocol.commit import GroupCoordinatorRequest, OffsetFetchRequest from kafka.protocol.metadata import MetadataRequest @@ -1342,3 +1342,19 @@ def _wait_for_futures(self, futures): if future.failed(): raise future.exception # pylint: disable-msg=raising-bad-type + + def describe_log_dirs(self): + """Send a DescribeLogDirsRequest request to a broker. + + :return: A message future + """ + version = self._matching_api_version(DescribeLogDirsRequest) + if version <= 1: + request = DescribeLogDirsRequest[version]() + future = self._send_request_to_node(self._client.least_loaded_node(), request) + self._wait_for_futures([future]) + else: + raise NotImplementedError( + "Support for DescribeLogDirsRequest_v{} has not yet been added to KafkaAdminClient."
+ .format(version)) + return future.value diff --git a/kafka/protocol/admin.py b/kafka/protocol/admin.py index 6109d90f9..bc717fc6b 100644 --- a/kafka/protocol/admin.py +++ b/kafka/protocol/admin.py @@ -788,6 +788,48 @@ class DescribeConfigsRequest_v2(Request): ] +class DescribeLogDirsResponse_v0(Response): + API_KEY = 35 + API_VERSION = 0 + FLEXIBLE_VERSION = True + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('log_dirs', Array( + ('error_code', Int16), + ('log_dir', String('utf-8')), + ('topics', Array( + ('name', String('utf-8')), + ('partitions', Array( + ('partition_index', Int32), + ('partition_size', Int64), + ('offset_lag', Int64), + ('is_future_key', Boolean) + )) + )) + )) + ) + + +class DescribeLogDirsRequest_v0(Request): + API_KEY = 35 + API_VERSION = 0 + RESPONSE_TYPE = DescribeLogDirsResponse_v0 + SCHEMA = Schema( + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Int32) + )) + ) + + +DescribeLogDirsResponse = [ + DescribeLogDirsResponse_v0, +] +DescribeLogDirsRequest = [ + DescribeLogDirsRequest_v0, +] + + class SaslAuthenticateResponse_v0(Response): API_KEY = 36 API_VERSION = 0 From 025950277c9fca16b0581ba0a2e65488c3f4b41d Mon Sep 17 00:00:00 2001 From: William Barnhart Date: Tue, 19 Mar 2024 21:35:38 -0400 Subject: [PATCH 03/17] Update conftest.py to use request.node.originalname instead for legal topic naming (#172) * Update conftest.py to use request.node.originalname instead for legal topic naming Otherwise parametrization doesn't work. * Update test/conftest.py Co-authored-by: code-review-doctor[bot] <72320148+code-review-doctor[bot]@users.noreply.github.com> --------- Co-authored-by: code-review-doctor[bot] <72320148+code-review-doctor[bot]@users.noreply.github.com> --- test/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/conftest.py b/test/conftest.py index 3fa0262fd..824c0fa76 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -119,7 +119,7 @@ def factory(**kafka_admin_client_params): @pytest.fixture def topic(kafka_broker, request): """Return a topic fixture""" - topic_name = '%s_%s' % (request.node.name, random_string(10)) + topic_name = f'{request.node.originalname}_{random_string(10)}' kafka_broker.create_topics([topic_name]) return topic_name From 3c124b2da2e99beec08a10dffd51ae3274b84e7e Mon Sep 17 00:00:00 2001 From: William Barnhart Date: Tue, 19 Mar 2024 22:36:02 -0400 Subject: [PATCH 04/17] KIP-345 Static membership implementation (#137) * KIP-345 Add static consumer membership support * KIP-345 Add examples to docs * KIP-345 Add leave_group_on_close flag https://issues.apache.org/jira/browse/KAFKA-6995 * KIP-345 Add tests for static membership * KIP-345 Update docs for leave_group_on_close option * Update changelog.rst * remove six from base.py * Update base.py * Update base.py * Update base.py * Update changelog.rst * Update README.rst --------- Co-authored-by: Denis Kazakov Co-authored-by: Denis Kazakov --- CHANGES.md | 5 ++ README.rst | 6 +- docs/changelog.rst | 7 ++ docs/usage.rst | 12 +++ kafka/consumer/group.py | 12 ++- kafka/coordinator/base.py | 140 +++++++++++++++++++++++++++------- kafka/coordinator/consumer.py | 17 ++++- kafka/protocol/group.py | 119 +++++++++++++++++++++++++++-- test/test_consumer.py | 5 ++ test/test_consumer_group.py | 20 +++++ 10 files changed, 302 insertions(+), 41 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index ccec6b5c3..ba40007f9 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,8 @@ +# 2.0.3 (under development) + +Consumer +* KIP-345: Implement 
static membership support + # 2.0.2 (Sep 29, 2020) Consumer diff --git a/README.rst b/README.rst index b7acfc8a2..ce82c6d3b 100644 --- a/README.rst +++ b/README.rst @@ -64,8 +64,12 @@ that expose basic message attributes: topic, partition, offset, key, and value: .. code-block:: python + # join a consumer group for dynamic partition assignment and offset commits from kafka import KafkaConsumer - consumer = KafkaConsumer('my_favorite_topic') + consumer = KafkaConsumer('my_favorite_topic', group_id='my_favorite_group') + # or as a static member with a fixed group member name + # consumer = KafkaConsumer('my_favorite_topic', group_id='my_favorite_group', + # group_instance_id='consumer-1', leave_group_on_close=False) for msg in consumer: print (msg) diff --git a/docs/changelog.rst b/docs/changelog.rst index 9d3cb6512..67013247b 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,6 +1,13 @@ Changelog ========= +2.2.0 +#################### + +Consumer +-------- +* KIP-345: Implement static membership support + 2.0.2 (Sep 29, 2020) #################### diff --git a/docs/usage.rst b/docs/usage.rst index 047bbad77..dbc8813f0 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -47,6 +47,18 @@ KafkaConsumer group_id='my-group', bootstrap_servers='my.server.com') + # Use multiple static consumers w/ 2.3.0 kafka brokers + consumer1 = KafkaConsumer('my-topic', + group_id='my-group', + group_instance_id='process-1', + leave_group_on_close=False, + bootstrap_servers='my.server.com') + consumer2 = KafkaConsumer('my-topic', + group_id='my-group', + group_instance_id='process-2', + leave_group_on_close=False, + bootstrap_servers='my.server.com') + There are many configuration options for the consumer class. See :class:`~kafka.KafkaConsumer` API documentation for more details. diff --git a/kafka/consumer/group.py b/kafka/consumer/group.py index 0d613e71e..fc04c4bd6 100644 --- a/kafka/consumer/group.py +++ b/kafka/consumer/group.py @@ -52,6 +52,12 @@ class KafkaConsumer: committing offsets. If None, auto-partition assignment (via group coordinator) and offset commits are disabled. Default: None + group_instance_id (str): the unique identifier to distinguish + each client instance. If set and leave_group_on_close is + False, consumer group rebalancing won't be triggered until + session_timeout_ms is met. Requires 2.3.0+. + leave_group_on_close (bool or None): whether to leave a consumer + group or not on consumer shutdown. key_deserializer (callable): Any callable that takes a raw message key and returns a deserialized key. value_deserializer (callable): Any callable that takes a @@ -241,6 +247,7 @@ class KafkaConsumer: sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider instance. (See kafka.oauth.abstract).
Default: None kafka_client (callable): Custom class / callable for creating KafkaClient instances + coordinator (callable): Custom class / callable for creating ConsumerCoordinator instances Note: Configuration parameters are described in more detail at @@ -250,6 +257,8 @@ class KafkaConsumer: 'bootstrap_servers': 'localhost', 'client_id': 'kafka-python-' + __version__, 'group_id': None, + 'group_instance_id': '', + 'leave_group_on_close': None, 'key_deserializer': None, 'value_deserializer': None, 'fetch_max_wait_ms': 500, @@ -304,6 +313,7 @@ class KafkaConsumer: 'sasl_oauth_token_provider': None, 'legacy_iterator': False, # enable to revert to < 1.4.7 iterator 'kafka_client': KafkaClient, + 'coordinator': ConsumerCoordinator, } DEFAULT_SESSION_TIMEOUT_MS_0_9 = 30000 @@ -379,7 +389,7 @@ def __init__(self, *topics, **configs): self._subscription = SubscriptionState(self.config['auto_offset_reset']) self._fetcher = Fetcher( self._client, self._subscription, self._metrics, **self.config) - self._coordinator = ConsumerCoordinator( + self._coordinator = self.config['coordinator']( self._client, self._subscription, self._metrics, assignors=self.config['partition_assignment_strategy'], **self.config) diff --git a/kafka/coordinator/base.py b/kafka/coordinator/base.py index 62773e330..d5ec4c720 100644 --- a/kafka/coordinator/base.py +++ b/kafka/coordinator/base.py @@ -78,6 +78,8 @@ class BaseCoordinator: DEFAULT_CONFIG = { 'group_id': 'kafka-python-default-group', + 'group_instance_id': '', + 'leave_group_on_close': None, 'session_timeout_ms': 10000, 'heartbeat_interval_ms': 3000, 'max_poll_interval_ms': 300000, @@ -92,6 +94,12 @@ def __init__(self, client, metrics, **configs): group_id (str): name of the consumer group to join for dynamic partition assignment (if enabled), and to use for fetching and committing offsets. Default: 'kafka-python-default-group' + group_instance_id (str): the unique identifier to distinguish + each client instance. If set and leave_group_on_close is + False, consumer group rebalancing won't be triggered until + session_timeout_ms is met. Requires 2.3.0+. + leave_group_on_close (bool or None): whether to leave a consumer + group or not on consumer shutdown. session_timeout_ms (int): The timeout used to detect failures when using Kafka's group management facilities.
Default: 30000 heartbeat_interval_ms (int): The expected time in milliseconds @@ -117,6 +125,11 @@ def __init__(self, client, metrics, **configs): "different values for max_poll_interval_ms " "and session_timeout_ms") + if self.config['group_instance_id'] and self.config['api_version'] < (2, 3, 0): + raise Errors.KafkaConfigurationError( + 'Broker version %s does not support static membership' % (self.config['api_version'],), + ) + self._client = client self.group_id = self.config['group_id'] self.heartbeat = Heartbeat(**self.config) @@ -451,30 +464,48 @@ def _send_join_group_request(self): if self.config['api_version'] < (0, 9): raise Errors.KafkaError('JoinGroupRequest api requires 0.9+ brokers') elif (0, 9) <= self.config['api_version'] < (0, 10, 1): - request = JoinGroupRequest[0]( + version = 0 + args = ( self.group_id, self.config['session_timeout_ms'], self._generation.member_id, self.protocol_type(), - member_metadata) + member_metadata, + ) elif (0, 10, 1) <= self.config['api_version'] < (0, 11, 0): - request = JoinGroupRequest[1]( + version = 1 + args = ( self.group_id, self.config['session_timeout_ms'], self.config['max_poll_interval_ms'], self._generation.member_id, self.protocol_type(), - member_metadata) + member_metadata, + ) + elif self.config['api_version'] >= (2, 3, 0) and self.config['group_instance_id']: + version = 5 + args = ( + self.group_id, + self.config['session_timeout_ms'], + self.config['max_poll_interval_ms'], + self._generation.member_id, + self.config['group_instance_id'], + self.protocol_type(), + member_metadata, + ) else: - request = JoinGroupRequest[2]( + version = 2 + args = ( self.group_id, self.config['session_timeout_ms'], self.config['max_poll_interval_ms'], self._generation.member_id, self.protocol_type(), - member_metadata) + member_metadata, + ) # create the request for the coordinator + request = JoinGroupRequest[version](*args) log.debug("Sending JoinGroup (%s) to coordinator %s", request, self.coordinator_id) future = Future() _f = self._client.send(self.coordinator_id, request) @@ -558,12 +589,25 @@ def _handle_join_group_response(self, future, send_time, response): def _on_join_follower(self): # send follower's sync group with an empty assignment - version = 0 if self.config['api_version'] < (0, 11, 0) else 1 - request = SyncGroupRequest[version]( - self.group_id, - self._generation.generation_id, - self._generation.member_id, - {}) + if self.config['api_version'] >= (2, 3, 0) and self.config['group_instance_id']: + version = 3 + args = ( + self.group_id, + self._generation.generation_id, + self._generation.member_id, + self.config['group_instance_id'], + {}, + ) + else: + version = 0 if self.config['api_version'] < (0, 11, 0) else 1 + args = ( + self.group_id, + self._generation.generation_id, + self._generation.member_id, + {}, + ) + + request = SyncGroupRequest[version](*args) log.debug("Sending follower SyncGroup for group %s to coordinator %s: %s", self.group_id, self.coordinator_id, request) return self._send_sync_group_request(request) @@ -586,15 +630,30 @@ def _on_join_leader(self, response): except Exception as e: return Future().failure(e) - version = 0 if self.config['api_version'] < (0, 11, 0) else 1 - request = SyncGroupRequest[version]( - self.group_id, - self._generation.generation_id, - self._generation.member_id, - [(member_id, - assignment if isinstance(assignment, bytes) else assignment.encode()) - for member_id, assignment in group_assignment.items()]) + group_assignment = [ + (member_id, assignment if 
isinstance(assignment, bytes) else assignment.encode()) + for member_id, assignment in group_assignment.items() + ] + + if self.config['api_version'] >= (2, 3, 0) and self.config['group_instance_id']: + version = 3 + args = ( + self.group_id, + self._generation.generation_id, + self._generation.member_id, + self.config['group_instance_id'], + group_assignment, + ) + else: + version = 0 if self.config['api_version'] < (0, 11, 0) else 1 + args = ( + self.group_id, + self._generation.generation_id, + self._generation.member_id, + group_assignment, + ) + request = SyncGroupRequest[version](*args) log.debug("Sending leader SyncGroup for group %s to coordinator %s: %s", self.group_id, self.coordinator_id, request) return self._send_sync_group_request(request) @@ -760,15 +819,22 @@ def close(self): def maybe_leave_group(self): """Leave the current group and reset local generation/memberId.""" with self._client._lock, self._lock: - if (not self.coordinator_unknown() + if ( + not self.coordinator_unknown() and self.state is not MemberState.UNJOINED - and self._generation is not Generation.NO_GENERATION): - + and self._generation is not Generation.NO_GENERATION + and self._leave_group_on_close() + ): # this is a minimal effort attempt to leave the group. we do not # attempt any resending if the request fails or times out. log.info('Leaving consumer group (%s).', self.group_id) - version = 0 if self.config['api_version'] < (0, 11, 0) else 1 - request = LeaveGroupRequest[version](self.group_id, self._generation.member_id) + if self.config['api_version'] >= (2, 3, 0) and self.config['group_instance_id']: + version = 3 + args = (self.group_id, [(self._generation.member_id, self.config['group_instance_id'])]) + else: + version = 0 if self.config['api_version'] < (0, 11, 0) else 1 + args = self.group_id, self._generation.member_id + request = LeaveGroupRequest[version](*args) future = self._client.send(self.coordinator_id, request) future.add_callback(self._handle_leave_group_response) future.add_errback(log.error, "LeaveGroup request failed: %s") @@ -795,10 +861,23 @@ def _send_heartbeat_request(self): e = Errors.NodeNotReadyError(self.coordinator_id) return Future().failure(e) - version = 0 if self.config['api_version'] < (0, 11, 0) else 1 - request = HeartbeatRequest[version](self.group_id, - self._generation.generation_id, - self._generation.member_id) + if self.config['api_version'] >= (2, 3, 0) and self.config['group_instance_id']: + version = 2 + args = ( + self.group_id, + self._generation.generation_id, + self._generation.member_id, + self.config['group_instance_id'], + ) + else: + version = 0 if self.config['api_version'] < (0, 11, 0) else 1 + args = ( + self.group_id, + self._generation.generation_id, + self._generation.member_id, + ) + + request = HeartbeatRequest[version](*args) log.debug("Heartbeat: %s[%s] %s", request.group, request.generation_id, request.member_id) # pylint: disable-msg=no-member future = Future() _f = self._client.send(self.coordinator_id, request) @@ -845,6 +924,9 @@ def _handle_heartbeat_response(self, future, send_time, response): log.error("Heartbeat failed: Unhandled error: %s", error) future.failure(error) + def _leave_group_on_close(self): + return self.config['leave_group_on_close'] is None or self.config['leave_group_on_close'] + class GroupCoordinatorMetrics: def __init__(self, heartbeat, metrics, prefix, tags=None): diff --git a/kafka/coordinator/consumer.py b/kafka/coordinator/consumer.py index d9a67860b..cf82b69fe 100644 --- a/kafka/coordinator/consumer.py 
+++ b/kafka/coordinator/consumer.py @@ -25,6 +25,8 @@ class ConsumerCoordinator(BaseCoordinator): """This class manages the coordination process with the consumer coordinator.""" DEFAULT_CONFIG = { 'group_id': 'kafka-python-default-group', + 'group_instance_id': '', + 'leave_group_on_close': None, 'enable_auto_commit': True, 'auto_commit_interval_ms': 5000, 'default_offset_commit_callback': None, @@ -45,6 +47,12 @@ def __init__(self, client, subscription, metrics, **configs): group_id (str): name of the consumer group to join for dynamic partition assignment (if enabled), and to use for fetching and committing offsets. Default: 'kafka-python-default-group' + group_instance_id (str): the unique identifier to distinguish + each client instance. If set and leave_group_on_close is + False, consumer group rebalancing won't be triggered until + session_timeout_ms is met. Requires 2.3.0+. + leave_group_on_close (bool or None): whether to leave a consumer + group or not on consumer shutdown. enable_auto_commit (bool): If true the consumer's offset will be periodically committed in the background. Default: True. auto_commit_interval_ms (int): milliseconds between automatic @@ -304,10 +312,15 @@ def _perform_assignment(self, leader_id, assignment_strategy, members): assert assignor, f'Invalid assignment protocol: {assignment_strategy}' member_metadata = {} all_subscribed_topics = set() - for member_id, metadata_bytes in members: + + for member in members: + if len(member) == 3: + member_id, group_instance_id, metadata_bytes = member + else: + member_id, metadata_bytes = member metadata = ConsumerProtocol.METADATA.decode(metadata_bytes) member_metadata[member_id] = metadata - all_subscribed_topics.update(metadata.subscription) # pylint: disable-msg=no-member + all_subscribed_topics.update(metadata.subscription)  # pylint: disable-msg=no-member # the leader will begin watching for changes to any of the topics # the group is interested in, which ensures that all metadata changes diff --git a/kafka/protocol/group.py b/kafka/protocol/group.py index 68efdc8f9..9e698c21f 100644 --- a/kafka/protocol/group.py +++ b/kafka/protocol/group.py @@ -40,6 +40,23 @@ class JoinGroupResponse_v2(Response): ) +class JoinGroupResponse_v5(Response): + API_KEY = 11 + API_VERSION = 5 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16), + ('generation_id', Int32), + ('group_protocol', String('utf-8')), + ('leader_id', String('utf-8')), + ('member_id', String('utf-8')), + ('members', Array( + ('member_id', String('utf-8')), + ('group_instance_id', String('utf-8')), + ('member_metadata', Bytes))), + ) + + class JoinGroupRequest_v0(Request): API_KEY = 11 API_VERSION = 0 @@ -81,11 +98,30 @@ class JoinGroupRequest_v2(Request): UNKNOWN_MEMBER_ID = '' + +class JoinGroupRequest_v5(Request): + API_KEY = 11 + API_VERSION = 5 + RESPONSE_TYPE = JoinGroupResponse_v5 + SCHEMA = Schema( + ('group', String('utf-8')), + ('session_timeout', Int32), + ('rebalance_timeout', Int32), + ('member_id', String('utf-8')), + ('group_instance_id', String('utf-8')), + ('protocol_type', String('utf-8')), + ('group_protocols', Array( + ('protocol_name', String('utf-8')), + ('protocol_metadata', Bytes))), + ) + UNKNOWN_MEMBER_ID = '' + + + JoinGroupRequest = [ - JoinGroupRequest_v0, JoinGroupRequest_v1, JoinGroupRequest_v2 + JoinGroupRequest_v0, JoinGroupRequest_v1, JoinGroupRequest_v2, None, None, JoinGroupRequest_v5, ] JoinGroupResponse = [ - JoinGroupResponse_v0,
JoinGroupResponse_v1, JoinGroupResponse_v2, None, None, JoinGroupResponse_v5, ] @@ -116,6 +152,16 @@ class SyncGroupResponse_v1(Response): ) +class SyncGroupResponse_v3(Response): + API_KEY = 14 + API_VERSION = 3 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16), + ('member_assignment', Bytes) + ) + + class SyncGroupRequest_v0(Request): API_KEY = 14 API_VERSION = 0 @@ -137,8 +183,23 @@ class SyncGroupRequest_v1(Request): SCHEMA = SyncGroupRequest_v0.SCHEMA -SyncGroupRequest = [SyncGroupRequest_v0, SyncGroupRequest_v1] -SyncGroupResponse = [SyncGroupResponse_v0, SyncGroupResponse_v1] +class SyncGroupRequest_v3(Request): + API_KEY = 14 + API_VERSION = 3 + RESPONSE_TYPE = SyncGroupResponse_v3 + SCHEMA = Schema( + ('group', String('utf-8')), + ('generation_id', Int32), + ('member_id', String('utf-8')), + ('group_instance_id', String('utf-8')), + ('group_assignment', Array( + ('member_id', String('utf-8')), + ('member_metadata', Bytes))), + ) + + +SyncGroupRequest = [SyncGroupRequest_v0, SyncGroupRequest_v1, None, SyncGroupRequest_v3] +SyncGroupResponse = [SyncGroupResponse_v0, SyncGroupResponse_v1, None, SyncGroupResponse_v3] class MemberAssignment(Struct): @@ -186,8 +247,29 @@ class HeartbeatRequest_v1(Request): SCHEMA = HeartbeatRequest_v0.SCHEMA -HeartbeatRequest = [HeartbeatRequest_v0, HeartbeatRequest_v1] -HeartbeatResponse = [HeartbeatResponse_v0, HeartbeatResponse_v1] +class HeartbeatResponse_v2(Response): + API_KEY = 12 + API_VERSION = 2 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16) + ) + + +class HeartbeatRequest_v2(Request): + API_KEY = 12 + API_VERSION = 2 + RESPONSE_TYPE = HeartbeatResponse_v2 + SCHEMA = Schema( + ('group', String('utf-8')), + ('generation_id', Int32), + ('member_id', String('utf-8')), + ('group_instance_id', String('utf-8')), + ) + + +HeartbeatRequest = [HeartbeatRequest_v0, HeartbeatRequest_v1, HeartbeatRequest_v2] +HeartbeatResponse = [HeartbeatResponse_v0, HeartbeatResponse_v1, HeartbeatResponse_v2] class LeaveGroupResponse_v0(Response): @@ -207,6 +289,15 @@ class LeaveGroupResponse_v1(Response): ) +class LeaveGroupResponse_v3(Response): + API_KEY = 13 + API_VERSION = 3 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16) + ) + + class LeaveGroupRequest_v0(Request): API_KEY = 13 API_VERSION = 0 @@ -224,5 +315,17 @@ class LeaveGroupRequest_v1(Request): SCHEMA = LeaveGroupRequest_v0.SCHEMA -LeaveGroupRequest = [LeaveGroupRequest_v0, LeaveGroupRequest_v1] -LeaveGroupResponse = [LeaveGroupResponse_v0, LeaveGroupResponse_v1] +class LeaveGroupRequest_v3(Request): + API_KEY = 13 + API_VERSION = 3 + RESPONSE_TYPE = LeaveGroupResponse_v3 + SCHEMA = Schema( + ('group', String('utf-8')), + ('member_identity_list', Array( + ('member_id', String('utf-8')), + ('group_instance_id', String('utf-8')))), + ) + + +LeaveGroupRequest = [LeaveGroupRequest_v0, LeaveGroupRequest_v1, None, LeaveGroupRequest_v3] +LeaveGroupResponse = [LeaveGroupResponse_v0, LeaveGroupResponse_v1, None, LeaveGroupResponse_v3] diff --git a/test/test_consumer.py b/test/test_consumer.py index 436fe55c0..0c6110517 100644 --- a/test/test_consumer.py +++ b/test/test_consumer.py @@ -24,3 +24,8 @@ def test_subscription_copy(self): assert sub == set(['foo']) sub.add('fizz') assert consumer.subscription() == set(['foo']) + + def test_version_for_static_membership(self): + KafkaConsumer(bootstrap_servers='localhost:9092', api_version=(2, 3, 0), group_instance_id='test') + with pytest.raises(KafkaConfigurationError): + 
KafkaConsumer(bootstrap_servers='localhost:9092', api_version=(2, 2, 0), group_instance_id='test') diff --git a/test/test_consumer_group.py b/test/test_consumer_group.py index 53222b6fc..ed6863fa2 100644 --- a/test/test_consumer_group.py +++ b/test/test_consumer_group.py @@ -180,3 +180,23 @@ def test_heartbeat_thread(kafka_broker, topic): consumer.poll(timeout_ms=100) assert consumer._coordinator.heartbeat.last_poll > last_poll consumer.close() + + +@pytest.mark.skipif(env_kafka_version() < (2, 3, 0), reason="Requires KAFKA_VERSION >= 2.3.0") +@pytest.mark.parametrize('leave, result', [ + (False, True), + (True, False), +]) +def test_kafka_consumer_rebalance_for_static_members(kafka_consumer_factory, leave, result): + GROUP_ID = random_string(10) + + consumer1 = kafka_consumer_factory(group_id=GROUP_ID, group_instance_id=GROUP_ID, leave_group_on_close=leave) + consumer1.poll() + generation1 = consumer1._coordinator.generation().generation_id + consumer1.close() + + consumer2 = kafka_consumer_factory(group_id=GROUP_ID, group_instance_id=GROUP_ID, leave_group_on_close=leave) + consumer2.poll() + generation2 = consumer2._coordinator.generation().generation_id + consumer2.close() + assert (generation1 == generation2) is result From 56065dacaade9c921614791133c19f0e9e1adee9 Mon Sep 17 00:00:00 2001 From: William Barnhart Date: Tue, 26 Mar 2024 09:44:18 -0400 Subject: [PATCH 05/17] Use monkeytype to create some semblance of typing (#173) * Add typing * define types as Struct for simplicity's sake --- kafka/coordinator/assignors/abstract.py | 2 +- .../assignors/sticky/sticky_assignor.py | 1 - kafka/errors.py | 9 +- kafka/protocol/api.py | 14 +- kafka/protocol/struct.py | 14 +- kafka/record/_crc32c.py | 6 +- kafka/record/abc.py | 14 +- kafka/record/default_records.py | 87 ++-- kafka/record/legacy_records.py | 70 +-- kafka/record/memory_records.py | 35 +- kafka/record/util.py | 12 +- kafka/sasl/msk.py | 461 +++++++++--------- kafka/util.py | 11 +- 13 files changed, 373 insertions(+), 363 deletions(-) diff --git a/kafka/coordinator/assignors/abstract.py b/kafka/coordinator/assignors/abstract.py index a1fef3840..7c38907ef 100644 --- a/kafka/coordinator/assignors/abstract.py +++ b/kafka/coordinator/assignors/abstract.py @@ -12,7 +12,7 @@ class AbstractPartitionAssignor(object): partition counts which are always needed in assignors). 
""" - @abc.abstractproperty + @abc.abstractmethod def name(self): """.name should be a string identifying the assignor""" pass diff --git a/kafka/coordinator/assignors/sticky/sticky_assignor.py b/kafka/coordinator/assignors/sticky/sticky_assignor.py index 033642425..e75dc2561 100644 --- a/kafka/coordinator/assignors/sticky/sticky_assignor.py +++ b/kafka/coordinator/assignors/sticky/sticky_assignor.py @@ -2,7 +2,6 @@ from collections import defaultdict, namedtuple from copy import deepcopy -from kafka.cluster import ClusterMetadata from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor from kafka.coordinator.assignors.sticky.partition_movements import PartitionMovements from kafka.coordinator.assignors.sticky.sorted_set import SortedSet diff --git a/kafka/errors.py b/kafka/errors.py index cb3ff285f..d2f313c08 100644 --- a/kafka/errors.py +++ b/kafka/errors.py @@ -1,5 +1,6 @@ import inspect import sys +from typing import Any class KafkaError(RuntimeError): @@ -7,7 +8,7 @@ class KafkaError(RuntimeError): # whether metadata should be refreshed on error invalid_metadata = False - def __str__(self): + def __str__(self) -> str: if not self.args: return self.__class__.__name__ return '{}: {}'.format(self.__class__.__name__, @@ -65,7 +66,7 @@ class IncompatibleBrokerVersion(KafkaError): class CommitFailedError(KafkaError): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__( """Commit cannot be completed since the group has already rebalanced and assigned the partitions to another member. @@ -92,7 +93,7 @@ class BrokerResponseError(KafkaError): message = None description = None - def __str__(self): + def __str__(self) -> str: """Add errno to standard KafkaError str""" return '[Error {}] {}'.format( self.errno, @@ -509,7 +510,7 @@ def _iter_broker_errors(): kafka_errors = {x.errno: x for x in _iter_broker_errors()} -def for_code(error_code): +def for_code(error_code: int) -> Any: return kafka_errors.get(error_code, UnknownError) diff --git a/kafka/protocol/api.py b/kafka/protocol/api.py index 24cf61a62..6d6c6edca 100644 --- a/kafka/protocol/api.py +++ b/kafka/protocol/api.py @@ -52,22 +52,22 @@ class Request(Struct): FLEXIBLE_VERSION = False - @abc.abstractproperty + @abc.abstractmethod def API_KEY(self): """Integer identifier for api request""" pass - @abc.abstractproperty + @abc.abstractmethod def API_VERSION(self): """Integer of api request version""" pass - @abc.abstractproperty + @abc.abstractmethod def SCHEMA(self): """An instance of Schema() representing the request structure""" pass - @abc.abstractproperty + @abc.abstractmethod def RESPONSE_TYPE(self): """The Response class associated with the api request""" pass @@ -93,17 +93,17 @@ def parse_response_header(self, read_buffer): class Response(Struct): __metaclass__ = abc.ABCMeta - @abc.abstractproperty + @abc.abstractmethod def API_KEY(self): """Integer identifier for api request/response""" pass - @abc.abstractproperty + @abc.abstractmethod def API_VERSION(self): """Integer of api request/response version""" pass - @abc.abstractproperty + @abc.abstractmethod def SCHEMA(self): """An instance of Schema() representing the response structure""" pass diff --git a/kafka/protocol/struct.py b/kafka/protocol/struct.py index eb08ac8ef..65b3c8c63 100644 --- a/kafka/protocol/struct.py +++ b/kafka/protocol/struct.py @@ -1,15 +1,17 @@ from io import BytesIO +from typing import List, Union from kafka.protocol.abstract import AbstractType from kafka.protocol.types import Schema + from 
kafka.util import WeakMethod class Struct(AbstractType): SCHEMA = Schema() - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: if len(args) == len(self.SCHEMA.fields): for i, name in enumerate(self.SCHEMA.names): self.__dict__[name] = args[i] @@ -36,23 +38,23 @@ def encode(cls, item): # pylint: disable=E0202 bits.append(field.encode(item[i])) return b''.join(bits) - def _encode_self(self): + def _encode_self(self) -> bytes: return self.SCHEMA.encode( [self.__dict__[name] for name in self.SCHEMA.names] ) @classmethod - def decode(cls, data): + def decode(cls, data: Union[BytesIO, bytes]) -> "Struct": if isinstance(data, bytes): data = BytesIO(data) return cls(*[field.decode(data) for field in cls.SCHEMA.fields]) - def get_item(self, name): + def get_item(self, name: str) -> Union[int, List[List[Union[int, str, bool, List[List[Union[int, List[int]]]]]]], str, List[List[Union[int, str]]]]: if name not in self.SCHEMA.names: raise KeyError("%s is not in the schema" % name) return self.__dict__[name] - def __repr__(self): + def __repr__(self) -> str: key_vals = [] for name, field in zip(self.SCHEMA.names, self.SCHEMA.fields): key_vals.append(f'{name}={field.repr(self.__dict__[name])}') @@ -61,7 +63,7 @@ def __repr__(self): def __hash__(self): return hash(self.encode()) - def __eq__(self, other): + def __eq__(self, other: "Struct") -> bool: if self.SCHEMA != other.SCHEMA: return False for attr in self.SCHEMA.names: diff --git a/kafka/record/_crc32c.py b/kafka/record/_crc32c.py index 6642b5bbe..f7743044c 100644 --- a/kafka/record/_crc32c.py +++ b/kafka/record/_crc32c.py @@ -97,7 +97,7 @@ _MASK = 0xFFFFFFFF -def crc_update(crc, data): +def crc_update(crc: int, data: bytes) -> int: """Update CRC-32C checksum with data. Args: crc: 32-bit checksum to update as long. @@ -116,7 +116,7 @@ def crc_update(crc, data): return crc ^ _MASK -def crc_finalize(crc): +def crc_finalize(crc: int) -> int: """Finalize CRC-32C checksum. This function should be called as last step of crc calculation. Args: @@ -127,7 +127,7 @@ def crc_finalize(crc): return crc & _MASK -def crc(data): +def crc(data: bytes) -> int: """Compute CRC-32C checksum of the data. Args: data: byte array, string or iterable over bytes. diff --git a/kafka/record/abc.py b/kafka/record/abc.py index f45176051..4ce5144d9 100644 --- a/kafka/record/abc.py +++ b/kafka/record/abc.py @@ -5,38 +5,38 @@ class ABCRecord: __metaclass__ = abc.ABCMeta __slots__ = () - @abc.abstractproperty + @abc.abstractmethod def offset(self): """ Absolute offset of record """ - @abc.abstractproperty + @abc.abstractmethod def timestamp(self): """ Epoch milliseconds """ - @abc.abstractproperty + @abc.abstractmethod def timestamp_type(self): """ CREATE_TIME(0) or APPEND_TIME(1) """ - @abc.abstractproperty + @abc.abstractmethod def key(self): """ Bytes key or None """ - @abc.abstractproperty + @abc.abstractmethod def value(self): """ Bytes value or None """ - @abc.abstractproperty + @abc.abstractmethod def checksum(self): """ Prior to v2 format CRC was contained in every message. This will be the checksum for v0 and v1 and None for v2 and above. """ - @abc.abstractproperty + @abc.abstractmethod def headers(self): """ If supported by version list of key-value tuples, or empty list if not supported by format. 
diff --git a/kafka/record/default_records.py b/kafka/record/default_records.py index 5045f31ee..91eb5c8a0 100644 --- a/kafka/record/default_records.py +++ b/kafka/record/default_records.py @@ -66,6 +66,7 @@ gzip_decode, snappy_decode, lz4_decode, zstd_decode ) import kafka.codec as codecs +from typing import Any, Callable, List, Optional, Tuple, Type, Union class DefaultRecordBase: @@ -105,7 +106,7 @@ class DefaultRecordBase: LOG_APPEND_TIME = 1 CREATE_TIME = 0 - def _assert_has_codec(self, compression_type): + def _assert_has_codec(self, compression_type: int) -> None: if compression_type == self.CODEC_GZIP: checker, name = codecs.has_gzip, "gzip" elif compression_type == self.CODEC_SNAPPY: @@ -124,7 +125,7 @@ class DefaultRecordBatch(DefaultRecordBase, ABCRecordBatch): __slots__ = ("_buffer", "_header_data", "_pos", "_num_records", "_next_record_index", "_decompressed") - def __init__(self, buffer): + def __init__(self, buffer: Union[memoryview, bytes]) -> None: self._buffer = bytearray(buffer) self._header_data = self.HEADER_STRUCT.unpack_from(self._buffer) self._pos = self.HEADER_STRUCT.size @@ -133,11 +134,11 @@ def __init__(self, buffer): self._decompressed = False @property - def base_offset(self): + def base_offset(self) -> int: return self._header_data[0] @property - def magic(self): + def magic(self) -> int: return self._header_data[3] @property @@ -145,7 +146,7 @@ def crc(self): return self._header_data[4] @property - def attributes(self): + def attributes(self) -> int: return self._header_data[5] @property @@ -153,15 +154,15 @@ def last_offset_delta(self): return self._header_data[6] @property - def compression_type(self): + def compression_type(self) -> int: return self.attributes & self.CODEC_MASK @property - def timestamp_type(self): + def timestamp_type(self) -> int: return int(bool(self.attributes & self.TIMESTAMP_TYPE_MASK)) @property - def is_transactional(self): + def is_transactional(self) -> bool: return bool(self.attributes & self.TRANSACTIONAL_MASK) @property @@ -169,14 +170,14 @@ def is_control_batch(self): return bool(self.attributes & self.CONTROL_MASK) @property - def first_timestamp(self): + def first_timestamp(self) -> int: return self._header_data[7] @property def max_timestamp(self): return self._header_data[8] - def _maybe_uncompress(self): + def _maybe_uncompress(self) -> None: if not self._decompressed: compression_type = self.compression_type if compression_type != self.CODEC_NONE: @@ -196,7 +197,7 @@ def _maybe_uncompress(self): def _read_msg( self, - decode_varint=decode_varint): + decode_varint: Callable=decode_varint) -> "DefaultRecord": # Record => # Length => Varint # Attributes => Int8 @@ -272,11 +273,11 @@ def _read_msg( return DefaultRecord( offset, timestamp, self.timestamp_type, key, value, headers) - def __iter__(self): + def __iter__(self) -> "DefaultRecordBatch": self._maybe_uncompress() return self - def __next__(self): + def __next__(self) -> "DefaultRecord": if self._next_record_index >= self._num_records: if self._pos != len(self._buffer): raise CorruptRecordException( @@ -309,7 +310,7 @@ class DefaultRecord(ABCRecord): __slots__ = ("_offset", "_timestamp", "_timestamp_type", "_key", "_value", "_headers") - def __init__(self, offset, timestamp, timestamp_type, key, value, headers): + def __init__(self, offset: int, timestamp: int, timestamp_type: int, key: Optional[bytes], value: bytes, headers: List[Union[Tuple[str, bytes], Any]]) -> None: self._offset = offset self._timestamp = timestamp self._timestamp_type = timestamp_type @@ 
-318,39 +319,39 @@ def __init__(self, offset, timestamp, timestamp_type, key, value, headers): self._headers = headers @property - def offset(self): + def offset(self) -> int: return self._offset @property - def timestamp(self): + def timestamp(self) -> int: """ Epoch milliseconds """ return self._timestamp @property - def timestamp_type(self): + def timestamp_type(self) -> int: """ CREATE_TIME(0) or APPEND_TIME(1) """ return self._timestamp_type @property - def key(self): + def key(self) -> Optional[bytes]: """ Bytes key or None """ return self._key @property - def value(self): + def value(self) -> bytes: """ Bytes value or None """ return self._value @property - def headers(self): + def headers(self) -> List[Union[Tuple[str, bytes], Any]]: return self._headers @property - def checksum(self): + def checksum(self) -> None: return None def __repr__(self): @@ -374,8 +375,8 @@ class DefaultRecordBatchBuilder(DefaultRecordBase, ABCRecordBatchBuilder): "_buffer") def __init__( - self, magic, compression_type, is_transactional, - producer_id, producer_epoch, base_sequence, batch_size): + self, magic: int, compression_type: int, is_transactional: Union[int, bool], + producer_id: int, producer_epoch: int, base_sequence: int, batch_size: int) -> None: assert magic >= 2 self._magic = magic self._compression_type = compression_type & self.CODEC_MASK @@ -393,7 +394,7 @@ def __init__( self._buffer = bytearray(self.HEADER_STRUCT.size) - def _get_attributes(self, include_compression_type=True): + def _get_attributes(self, include_compression_type: bool=True) -> int: attrs = 0 if include_compression_type: attrs |= self._compression_type @@ -403,13 +404,13 @@ def _get_attributes(self, include_compression_type=True): # Control batches are only created by Broker return attrs - def append(self, offset, timestamp, key, value, headers, + def append(self, offset: Union[int, str], timestamp: Optional[Union[int, str]], key: Optional[Union[str, bytes]], value: Optional[Union[str, bytes]], headers: List[Union[Tuple[str, bytes], Any, Tuple[str, None]]], # Cache for LOAD_FAST opcodes - encode_varint=encode_varint, size_of_varint=size_of_varint, - get_type=type, type_int=int, time_time=time.time, - byte_like=(bytes, bytearray, memoryview), - bytearray_type=bytearray, len_func=len, zero_len_varint=1 - ): + encode_varint: Callable=encode_varint, size_of_varint: Callable=size_of_varint, + get_type: Type[type]=type, type_int: Type[int]=int, time_time: Callable=time.time, + byte_like: Tuple[Type[bytes], Type[bytearray], Type[memoryview]]=(bytes, bytearray, memoryview), + bytearray_type: Type[bytearray]=bytearray, len_func: Callable=len, zero_len_varint: int=1 + ) -> Optional['DefaultRecordMetadata']: """ Write message to messageset buffer with MsgVersion 2 """ # Check types @@ -490,7 +491,7 @@ def append(self, offset, timestamp, key, value, headers, return DefaultRecordMetadata(offset, required_size, timestamp) - def write_header(self, use_compression_type=True): + def write_header(self, use_compression_type: bool=True) -> None: batch_len = len(self._buffer) self.HEADER_STRUCT.pack_into( self._buffer, 0, @@ -511,7 +512,7 @@ def write_header(self, use_compression_type=True): crc = calc_crc32c(self._buffer[self.ATTRIBUTES_OFFSET:]) struct.pack_into(">I", self._buffer, self.CRC_OFFSET, crc) - def _maybe_compress(self): + def _maybe_compress(self) -> bool: if self._compression_type != self.CODEC_NONE: self._assert_has_codec(self._compression_type) header_size = self.HEADER_STRUCT.size @@ -537,17 +538,17 @@ def 
_maybe_compress(self): return True return False - def build(self): + def build(self) -> bytearray: send_compressed = self._maybe_compress() self.write_header(send_compressed) return self._buffer - def size(self): + def size(self) -> int: """ Return current size of data written to buffer """ return len(self._buffer) - def size_in_bytes(self, offset, timestamp, key, value, headers): + def size_in_bytes(self, offset: int, timestamp: int, key: bytes, value: bytes, headers: List[Union[Tuple[str, bytes], Tuple[str, None]]]) -> int: if self._first_timestamp is not None: timestamp_delta = timestamp - self._first_timestamp else: @@ -561,7 +562,7 @@ def size_in_bytes(self, offset, timestamp, key, value, headers): return size_of_body + size_of_varint(size_of_body) @classmethod - def size_of(cls, key, value, headers): + def size_of(cls, key: bytes, value: bytes, headers: List[Union[Tuple[str, bytes], Tuple[str, None]]]) -> int: size = 0 # Key size if key is None: @@ -589,7 +590,7 @@ def size_of(cls, key, value, headers): return size @classmethod - def estimate_size_in_bytes(cls, key, value, headers): + def estimate_size_in_bytes(cls, key: bytes, value: bytes, headers: List[Tuple[str, bytes]]) -> int: """ Get the upper bound estimate on the size of record """ return ( @@ -602,28 +603,28 @@ class DefaultRecordMetadata: __slots__ = ("_size", "_timestamp", "_offset") - def __init__(self, offset, size, timestamp): + def __init__(self, offset: int, size: int, timestamp: int) -> None: self._offset = offset self._size = size self._timestamp = timestamp @property - def offset(self): + def offset(self) -> int: return self._offset @property - def crc(self): + def crc(self) -> None: return None @property - def size(self): + def size(self) -> int: return self._size @property - def timestamp(self): + def timestamp(self) -> int: return self._timestamp - def __repr__(self): + def __repr__(self) -> str: return ( "DefaultRecordMetadata(offset={!r}, size={!r}, timestamp={!r})" .format(self._offset, self._size, self._timestamp) diff --git a/kafka/record/legacy_records.py b/kafka/record/legacy_records.py index 9ab8873ca..b77799f4d 100644 --- a/kafka/record/legacy_records.py +++ b/kafka/record/legacy_records.py @@ -44,6 +44,7 @@ import struct import time + from kafka.record.abc import ABCRecord, ABCRecordBatch, ABCRecordBatchBuilder from kafka.record.util import calc_crc32 @@ -53,6 +54,7 @@ ) import kafka.codec as codecs from kafka.errors import CorruptRecordException, UnsupportedCodecError +from typing import Any, Iterator, List, Optional, Tuple, Union class LegacyRecordBase: @@ -115,7 +117,7 @@ class LegacyRecordBase: NO_TIMESTAMP = -1 - def _assert_has_codec(self, compression_type): + def _assert_has_codec(self, compression_type: int) -> None: if compression_type == self.CODEC_GZIP: checker, name = codecs.has_gzip, "gzip" elif compression_type == self.CODEC_SNAPPY: @@ -132,7 +134,7 @@ class LegacyRecordBatch(ABCRecordBatch, LegacyRecordBase): __slots__ = ("_buffer", "_magic", "_offset", "_crc", "_timestamp", "_attributes", "_decompressed") - def __init__(self, buffer, magic): + def __init__(self, buffer: Union[memoryview, bytes], magic: int) -> None: self._buffer = memoryview(buffer) self._magic = magic @@ -147,7 +149,7 @@ def __init__(self, buffer, magic): self._decompressed = False @property - def timestamp_type(self): + def timestamp_type(self) -> Optional[int]: """0 for CreateTime; 1 for LogAppendTime; None if unsupported. 
Value is determined by broker; produced messages should always set to 0 @@ -161,14 +163,14 @@ def timestamp_type(self): return 0 @property - def compression_type(self): + def compression_type(self) -> int: return self._attributes & self.CODEC_MASK def validate_crc(self): crc = calc_crc32(self._buffer[self.MAGIC_OFFSET:]) return self._crc == crc - def _decompress(self, key_offset): + def _decompress(self, key_offset: int) -> bytes: # Copy of `_read_key_value`, but uses memoryview pos = key_offset key_size = struct.unpack_from(">i", self._buffer, pos)[0] @@ -195,7 +197,7 @@ def _decompress(self, key_offset): uncompressed = lz4_decode(data.tobytes()) return uncompressed - def _read_header(self, pos): + def _read_header(self, pos: int) -> Union[Tuple[int, int, int, int, int, None], Tuple[int, int, int, int, int, int]]: if self._magic == 0: offset, length, crc, magic_read, attrs = \ self.HEADER_STRUCT_V0.unpack_from(self._buffer, pos) @@ -205,7 +207,7 @@ def _read_header(self, pos): self.HEADER_STRUCT_V1.unpack_from(self._buffer, pos) return offset, length, crc, magic_read, attrs, timestamp - def _read_all_headers(self): + def _read_all_headers(self) -> List[Union[Tuple[Tuple[int, int, int, int, int, int], int], Tuple[Tuple[int, int, int, int, int, None], int]]]: pos = 0 msgs = [] buffer_len = len(self._buffer) @@ -215,7 +217,7 @@ def _read_all_headers(self): pos += self.LOG_OVERHEAD + header[1] # length return msgs - def _read_key_value(self, pos): + def _read_key_value(self, pos: int) -> Union[Tuple[None, bytes], Tuple[bytes, bytes]]: key_size = struct.unpack_from(">i", self._buffer, pos)[0] pos += self.KEY_LENGTH if key_size == -1: @@ -232,7 +234,7 @@ def _read_key_value(self, pos): value = self._buffer[pos:pos + value_size].tobytes() return key, value - def __iter__(self): + def __iter__(self) -> Iterator[LegacyRecordBase]: if self._magic == 1: key_offset = self.KEY_OFFSET_V1 else: @@ -286,7 +288,7 @@ class LegacyRecord(ABCRecord): __slots__ = ("_offset", "_timestamp", "_timestamp_type", "_key", "_value", "_crc") - def __init__(self, offset, timestamp, timestamp_type, key, value, crc): + def __init__(self, offset: int, timestamp: Optional[int], timestamp_type: Optional[int], key: Optional[bytes], value: bytes, crc: int) -> None: self._offset = offset self._timestamp = timestamp self._timestamp_type = timestamp_type @@ -295,39 +297,39 @@ def __init__(self, offset, timestamp, timestamp_type, key, value, crc): self._crc = crc @property - def offset(self): + def offset(self) -> int: return self._offset @property - def timestamp(self): + def timestamp(self) -> Optional[int]: """ Epoch milliseconds """ return self._timestamp @property - def timestamp_type(self): + def timestamp_type(self) -> Optional[int]: """ CREATE_TIME(0) or APPEND_TIME(1) """ return self._timestamp_type @property - def key(self): + def key(self) -> Optional[bytes]: """ Bytes key or None """ return self._key @property - def value(self): + def value(self) -> bytes: """ Bytes value or None """ return self._value @property - def headers(self): + def headers(self) -> List[Any]: return [] @property - def checksum(self): + def checksum(self) -> int: return self._crc def __repr__(self): @@ -343,13 +345,13 @@ class LegacyRecordBatchBuilder(ABCRecordBatchBuilder, LegacyRecordBase): __slots__ = ("_magic", "_compression_type", "_batch_size", "_buffer") - def __init__(self, magic, compression_type, batch_size): + def __init__(self, magic: int, compression_type: int, batch_size: int) -> None: self._magic = magic self._compression_type = 
compression_type self._batch_size = batch_size self._buffer = bytearray() - def append(self, offset, timestamp, key, value, headers=None): + def append(self, offset: Union[int, str], timestamp: Optional[Union[int, str]], key: Optional[Union[bytes, str]], value: Optional[Union[str, bytes]], headers: None=None) -> Optional['LegacyRecordMetadata']: """ Append message to batch. """ assert not headers, "Headers not supported in v0/v1" @@ -388,8 +390,8 @@ def append(self, offset, timestamp, key, value, headers=None): return LegacyRecordMetadata(offset, crc, size, timestamp) - def _encode_msg(self, start_pos, offset, timestamp, key, value, - attributes=0): + def _encode_msg(self, start_pos: int, offset: int, timestamp: int, key: Optional[bytes], value: Optional[bytes], + attributes: int=0) -> int: """ Encode msg data into the `msg_buffer`, which should be allocated to at least the size of this message. """ @@ -437,7 +439,7 @@ def _encode_msg(self, start_pos, offset, timestamp, key, value, struct.pack_into(">I", buf, start_pos + self.CRC_OFFSET, crc) return crc - def _maybe_compress(self): + def _maybe_compress(self) -> bool: if self._compression_type: self._assert_has_codec(self._compression_type) data = bytes(self._buffer) @@ -464,19 +466,19 @@ def _maybe_compress(self): return True return False - def build(self): + def build(self) -> bytearray: """Compress batch to be ready for send""" self._maybe_compress() return self._buffer - def size(self): + def size(self) -> int: """ Return current size of data written to buffer """ return len(self._buffer) # Size calculations. Just copied Java's implementation - def size_in_bytes(self, offset, timestamp, key, value, headers=None): + def size_in_bytes(self, offset: int, timestamp: int, key: Optional[bytes], value: Optional[bytes], headers: None=None) -> int: """ Actual size of message to add """ assert not headers, "Headers not supported in v0/v1" @@ -484,7 +486,7 @@ def size_in_bytes(self, offset, timestamp, key, value, headers=None): return self.LOG_OVERHEAD + self.record_size(magic, key, value) @classmethod - def record_size(cls, magic, key, value): + def record_size(cls, magic: int, key: Optional[bytes], value: Optional[bytes]) -> int: message_size = cls.record_overhead(magic) if key is not None: message_size += len(key) @@ -493,7 +495,7 @@ def record_size(cls, magic, key, value): return message_size @classmethod - def record_overhead(cls, magic): + def record_overhead(cls, magic: int) -> int: assert magic in [0, 1], "Not supported magic" if magic == 0: return cls.RECORD_OVERHEAD_V0 @@ -501,7 +503,7 @@ def record_overhead(cls, magic): return cls.RECORD_OVERHEAD_V1 @classmethod - def estimate_size_in_bytes(cls, magic, compression_type, key, value): + def estimate_size_in_bytes(cls, magic: int, compression_type: int, key: bytes, value: bytes) -> int: """ Upper bound estimate of record size. 
""" assert magic in [0, 1], "Not supported magic" @@ -518,29 +520,29 @@ class LegacyRecordMetadata: __slots__ = ("_crc", "_size", "_timestamp", "_offset") - def __init__(self, offset, crc, size, timestamp): + def __init__(self, offset: int, crc: int, size: int, timestamp: int) -> None: self._offset = offset self._crc = crc self._size = size self._timestamp = timestamp @property - def offset(self): + def offset(self) -> int: return self._offset @property - def crc(self): + def crc(self) -> int: return self._crc @property - def size(self): + def size(self) -> int: return self._size @property - def timestamp(self): + def timestamp(self) -> int: return self._timestamp - def __repr__(self): + def __repr__(self) -> str: return ( "LegacyRecordMetadata(offset={!r}, crc={!r}, size={!r}," " timestamp={!r})".format( diff --git a/kafka/record/memory_records.py b/kafka/record/memory_records.py index 7a604887c..a915ed44f 100644 --- a/kafka/record/memory_records.py +++ b/kafka/record/memory_records.py @@ -23,8 +23,9 @@ from kafka.errors import CorruptRecordException from kafka.record.abc import ABCRecords -from kafka.record.legacy_records import LegacyRecordBatch, LegacyRecordBatchBuilder -from kafka.record.default_records import DefaultRecordBatch, DefaultRecordBatchBuilder +from kafka.record.legacy_records import LegacyRecordMetadata, LegacyRecordBatch, LegacyRecordBatchBuilder +from kafka.record.default_records import DefaultRecordMetadata, DefaultRecordBatch, DefaultRecordBatchBuilder +from typing import Any, List, Optional, Union class MemoryRecords(ABCRecords): @@ -38,7 +39,7 @@ class MemoryRecords(ABCRecords): __slots__ = ("_buffer", "_pos", "_next_slice", "_remaining_bytes") - def __init__(self, bytes_data): + def __init__(self, bytes_data: bytes) -> None: self._buffer = bytes_data self._pos = 0 # We keep one slice ahead so `has_next` will return very fast @@ -46,10 +47,10 @@ def __init__(self, bytes_data): self._remaining_bytes = None self._cache_next() - def size_in_bytes(self): + def size_in_bytes(self) -> int: return len(self._buffer) - def valid_bytes(self): + def valid_bytes(self) -> int: # We need to read the whole buffer to get the valid_bytes. 
# NOTE: in Fetcher we do the call after iteration, so should be fast if self._remaining_bytes is None: @@ -64,7 +65,7 @@ def valid_bytes(self): # NOTE: we cache offsets here as kwargs for a bit more speed, as cPython # will use LOAD_FAST opcode in this case - def _cache_next(self, len_offset=LENGTH_OFFSET, log_overhead=LOG_OVERHEAD): + def _cache_next(self, len_offset: int=LENGTH_OFFSET, log_overhead: int=LOG_OVERHEAD) -> None: buffer = self._buffer buffer_len = len(buffer) pos = self._pos @@ -88,12 +89,12 @@ def _cache_next(self, len_offset=LENGTH_OFFSET, log_overhead=LOG_OVERHEAD): self._next_slice = memoryview(buffer)[pos: slice_end] self._pos = slice_end - def has_next(self): + def has_next(self) -> bool: return self._next_slice is not None # NOTE: same cache for LOAD_FAST as above - def next_batch(self, _min_slice=MIN_SLICE, - _magic_offset=MAGIC_OFFSET): + def next_batch(self, _min_slice: int=MIN_SLICE, + _magic_offset: int=MAGIC_OFFSET) -> Optional[Union[DefaultRecordBatch, LegacyRecordBatch]]: next_slice = self._next_slice if next_slice is None: return None @@ -114,7 +115,7 @@ class MemoryRecordsBuilder: __slots__ = ("_builder", "_batch_size", "_buffer", "_next_offset", "_closed", "_bytes_written") - def __init__(self, magic, compression_type, batch_size): + def __init__(self, magic: int, compression_type: int, batch_size: int) -> None: assert magic in [0, 1, 2], "Not supported magic" assert compression_type in [0, 1, 2, 3, 4], "Not valid compression type" if magic >= 2: @@ -133,7 +134,7 @@ def __init__(self, magic, compression_type, batch_size): self._closed = False self._bytes_written = 0 - def append(self, timestamp, key, value, headers=[]): + def append(self, timestamp: Optional[int], key: Optional[Union[str, bytes]], value: Union[str, bytes], headers: List[Any]=[]) -> Optional[Union[DefaultRecordMetadata, LegacyRecordMetadata]]: """ Append a message to the buffer. Returns: RecordMetadata or None if unable to append @@ -150,7 +151,7 @@ def append(self, timestamp, key, value, headers=[]): self._next_offset += 1 return metadata - def close(self): + def close(self) -> None: # This method may be called multiple times on the same batch # i.e., on retries # we need to make sure we only close it out once @@ -162,25 +163,25 @@ def close(self): self._builder = None self._closed = True - def size_in_bytes(self): + def size_in_bytes(self) -> int: if not self._closed: return self._builder.size() else: return len(self._buffer) - def compression_rate(self): + def compression_rate(self) -> float: assert self._closed return self.size_in_bytes() / self._bytes_written - def is_full(self): + def is_full(self) -> bool: if self._closed: return True else: return self._builder.size() >= self._batch_size - def next_offset(self): + def next_offset(self) -> int: return self._next_offset - def buffer(self): + def buffer(self) -> bytes: assert self._closed return self._buffer diff --git a/kafka/record/util.py b/kafka/record/util.py index 3b712005d..d032151f1 100644 --- a/kafka/record/util.py +++ b/kafka/record/util.py @@ -1,13 +1,15 @@ import binascii from kafka.record._crc32c import crc as crc32c_py +from typing import Callable, Tuple + try: from crc32c import crc32c as crc32c_c except ImportError: crc32c_c = None -def encode_varint(value, write): +def encode_varint(value: int, write: Callable) -> int: """ Encode an integer to a varint presentation. See https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints on how those can be produced. 
@@ -60,7 +62,7 @@ def encode_varint(value, write): return i -def size_of_varint(value): +def size_of_varint(value: int) -> int: """ Number of bytes needed to encode an integer in variable-length format. """ value = (value << 1) ^ (value >> 63) @@ -85,7 +87,7 @@ def size_of_varint(value): return 10 -def decode_varint(buffer, pos=0): +def decode_varint(buffer: bytearray, pos: int=0) -> Tuple[int, int]: """ Decode an integer from a varint presentation. See https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints on how those can be produced. @@ -122,13 +124,13 @@ def decode_varint(buffer, pos=0): _crc32c = crc32c_c -def calc_crc32c(memview, _crc32c=_crc32c): +def calc_crc32c(memview: bytearray, _crc32c: Callable=_crc32c) -> int: """ Calculate CRC-32C (Castagnoli) checksum over a memoryview of data """ return _crc32c(memview) -def calc_crc32(memview): +def calc_crc32(memview: memoryview) -> int: """ Calculate simple CRC-32 checksum over a memoryview of data """ crc = binascii.crc32(memview) & 0xffffffff diff --git a/kafka/sasl/msk.py b/kafka/sasl/msk.py index 6d1bb74fb..ebea5dc5a 100644 --- a/kafka/sasl/msk.py +++ b/kafka/sasl/msk.py @@ -1,230 +1,231 @@ -import datetime -import hashlib -import hmac -import json -import string -import struct -import logging -import urllib - -from kafka.protocol.types import Int32 -import kafka.errors as Errors - -from botocore.session import Session as BotoSession # importing it in advance is not an option apparently... - - -def try_authenticate(self, future): - - session = BotoSession() - credentials = session.get_credentials().get_frozen_credentials() - client = AwsMskIamClient( - host=self.host, - access_key=credentials.access_key, - secret_key=credentials.secret_key, - region=session.get_config_variable('region'), - token=credentials.token, - ) - - msg = client.first_message() - size = Int32.encode(len(msg)) - - err = None - close = False - with self._lock: - if not self._can_send_recv(): - err = Errors.NodeNotReadyError(str(self)) - close = False - else: - try: - self._send_bytes_blocking(size + msg) - data = self._recv_bytes_blocking(4) - data = self._recv_bytes_blocking(struct.unpack('4B', data)[-1]) - except (ConnectionError, TimeoutError) as e: - logging.exception("%s: Error receiving reply from server", self) - err = Errors.KafkaConnectionError(f"{self}: {e}") - close = True - - if err is not None: - if close: - self.close(error=err) - return future.failure(err) - - logging.info('%s: Authenticated via AWS_MSK_IAM %s', self, data.decode('utf-8')) - return future.success(True) - - -class AwsMskIamClient: - UNRESERVED_CHARS = string.ascii_letters + string.digits + '-._~' - - def __init__(self, host, access_key, secret_key, region, token=None): - """ - Arguments: - host (str): The hostname of the broker. - access_key (str): An AWS_ACCESS_KEY_ID. - secret_key (str): An AWS_SECRET_ACCESS_KEY. - region (str): An AWS_REGION. - token (Optional[str]): An AWS_SESSION_TOKEN if using temporary - credentials. 
- """ - self.algorithm = 'AWS4-HMAC-SHA256' - self.expires = '900' - self.hashfunc = hashlib.sha256 - self.headers = [ - ('host', host) - ] - self.version = '2020_10_22' - - self.service = 'kafka-cluster' - self.action = f'{self.service}:Connect' - - now = datetime.datetime.utcnow() - self.datestamp = now.strftime('%Y%m%d') - self.timestamp = now.strftime('%Y%m%dT%H%M%SZ') - - self.host = host - self.access_key = access_key - self.secret_key = secret_key - self.region = region - self.token = token - - @property - def _credential(self): - return '{0.access_key}/{0._scope}'.format(self) - - @property - def _scope(self): - return '{0.datestamp}/{0.region}/{0.service}/aws4_request'.format(self) - - @property - def _signed_headers(self): - """ - Returns (str): - An alphabetically sorted, semicolon-delimited list of lowercase - request header names. - """ - return ';'.join(sorted(k.lower() for k, _ in self.headers)) - - @property - def _canonical_headers(self): - """ - Returns (str): - A newline-delited list of header names and values. - Header names are lowercased. - """ - return '\n'.join(map(':'.join, self.headers)) + '\n' - - @property - def _canonical_request(self): - """ - Returns (str): - An AWS Signature Version 4 canonical request in the format: - \n - \n - \n - \n - \n - - """ - # The hashed_payload is always an empty string for MSK. - hashed_payload = self.hashfunc(b'').hexdigest() - return '\n'.join(( - 'GET', - '/', - self._canonical_querystring, - self._canonical_headers, - self._signed_headers, - hashed_payload, - )) - - @property - def _canonical_querystring(self): - """ - Returns (str): - A '&'-separated list of URI-encoded key/value pairs. - """ - params = [] - params.append(('Action', self.action)) - params.append(('X-Amz-Algorithm', self.algorithm)) - params.append(('X-Amz-Credential', self._credential)) - params.append(('X-Amz-Date', self.timestamp)) - params.append(('X-Amz-Expires', self.expires)) - if self.token: - params.append(('X-Amz-Security-Token', self.token)) - params.append(('X-Amz-SignedHeaders', self._signed_headers)) - - return '&'.join(self._uriencode(k) + '=' + self._uriencode(v) for k, v in params) - - @property - def _signing_key(self): - """ - Returns (bytes): - An AWS Signature V4 signing key generated from the secret_key, date, - region, service, and request type. - """ - key = self._hmac(('AWS4' + self.secret_key).encode('utf-8'), self.datestamp) - key = self._hmac(key, self.region) - key = self._hmac(key, self.service) - key = self._hmac(key, 'aws4_request') - return key - - @property - def _signing_str(self): - """ - Returns (str): - A string used to sign the AWS Signature V4 payload in the format: - \n - \n - \n - - """ - canonical_request_hash = self.hashfunc(self._canonical_request.encode('utf-8')).hexdigest() - return '\n'.join((self.algorithm, self.timestamp, self._scope, canonical_request_hash)) - - def _uriencode(self, msg): - """ - Arguments: - msg (str): A string to URI-encode. - - Returns (str): - The URI-encoded version of the provided msg, following the encoding - rules specified: https://github.com/aws/aws-msk-iam-auth#uriencode - """ - return urllib.parse.quote(msg, safe=self.UNRESERVED_CHARS) - - def _hmac(self, key, msg): - """ - Arguments: - key (bytes): A key to use for the HMAC digest. - msg (str): A value to include in the HMAC digest. - Returns (bytes): - An HMAC digest of the given key and msg. 
- """ - return hmac.new(key, msg.encode('utf-8'), digestmod=self.hashfunc).digest() - - def first_message(self): - """ - Returns (bytes): - An encoded JSON authentication payload that can be sent to the - broker. - """ - signature = hmac.new( - self._signing_key, - self._signing_str.encode('utf-8'), - digestmod=self.hashfunc, - ).hexdigest() - msg = { - 'version': self.version, - 'host': self.host, - 'user-agent': 'kafka-python', - 'action': self.action, - 'x-amz-algorithm': self.algorithm, - 'x-amz-credential': self._credential, - 'x-amz-date': self.timestamp, - 'x-amz-signedheaders': self._signed_headers, - 'x-amz-expires': self.expires, - 'x-amz-signature': signature, - } - if self.token: - msg['x-amz-security-token'] = self.token - - return json.dumps(msg, separators=(',', ':')).encode('utf-8') +import datetime +import hashlib +import hmac +import json +import string +import struct +import logging +import urllib + +from kafka.protocol.types import Int32 +import kafka.errors as Errors + +from botocore.session import Session as BotoSession # importing it in advance is not an option apparently... +from typing import Optional + + +def try_authenticate(self, future): + + session = BotoSession() + credentials = session.get_credentials().get_frozen_credentials() + client = AwsMskIamClient( + host=self.host, + access_key=credentials.access_key, + secret_key=credentials.secret_key, + region=session.get_config_variable('region'), + token=credentials.token, + ) + + msg = client.first_message() + size = Int32.encode(len(msg)) + + err = None + close = False + with self._lock: + if not self._can_send_recv(): + err = Errors.NodeNotReadyError(str(self)) + close = False + else: + try: + self._send_bytes_blocking(size + msg) + data = self._recv_bytes_blocking(4) + data = self._recv_bytes_blocking(struct.unpack('4B', data)[-1]) + except (ConnectionError, TimeoutError) as e: + logging.exception("%s: Error receiving reply from server", self) + err = Errors.KafkaConnectionError(f"{self}: {e}") + close = True + + if err is not None: + if close: + self.close(error=err) + return future.failure(err) + + logging.info('%s: Authenticated via AWS_MSK_IAM %s', self, data.decode('utf-8')) + return future.success(True) + + +class AwsMskIamClient: + UNRESERVED_CHARS = string.ascii_letters + string.digits + '-._~' + + def __init__(self, host: str, access_key: str, secret_key: str, region: str, token: Optional[str]=None) -> None: + """ + Arguments: + host (str): The hostname of the broker. + access_key (str): An AWS_ACCESS_KEY_ID. + secret_key (str): An AWS_SECRET_ACCESS_KEY. + region (str): An AWS_REGION. + token (Optional[str]): An AWS_SESSION_TOKEN if using temporary + credentials. 
+ """ + self.algorithm = 'AWS4-HMAC-SHA256' + self.expires = '900' + self.hashfunc = hashlib.sha256 + self.headers = [ + ('host', host) + ] + self.version = '2020_10_22' + + self.service = 'kafka-cluster' + self.action = f'{self.service}:Connect' + + now = datetime.datetime.utcnow() + self.datestamp = now.strftime('%Y%m%d') + self.timestamp = now.strftime('%Y%m%dT%H%M%SZ') + + self.host = host + self.access_key = access_key + self.secret_key = secret_key + self.region = region + self.token = token + + @property + def _credential(self) -> str: + return '{0.access_key}/{0._scope}'.format(self) + + @property + def _scope(self) -> str: + return '{0.datestamp}/{0.region}/{0.service}/aws4_request'.format(self) + + @property + def _signed_headers(self) -> str: + """ + Returns (str): + An alphabetically sorted, semicolon-delimited list of lowercase + request header names. + """ + return ';'.join(sorted(k.lower() for k, _ in self.headers)) + + @property + def _canonical_headers(self) -> str: + """ + Returns (str): + A newline-delited list of header names and values. + Header names are lowercased. + """ + return '\n'.join(map(':'.join, self.headers)) + '\n' + + @property + def _canonical_request(self) -> str: + """ + Returns (str): + An AWS Signature Version 4 canonical request in the format: + \n + \n + \n + \n + \n + + """ + # The hashed_payload is always an empty string for MSK. + hashed_payload = self.hashfunc(b'').hexdigest() + return '\n'.join(( + 'GET', + '/', + self._canonical_querystring, + self._canonical_headers, + self._signed_headers, + hashed_payload, + )) + + @property + def _canonical_querystring(self) -> str: + """ + Returns (str): + A '&'-separated list of URI-encoded key/value pairs. + """ + params = [] + params.append(('Action', self.action)) + params.append(('X-Amz-Algorithm', self.algorithm)) + params.append(('X-Amz-Credential', self._credential)) + params.append(('X-Amz-Date', self.timestamp)) + params.append(('X-Amz-Expires', self.expires)) + if self.token: + params.append(('X-Amz-Security-Token', self.token)) + params.append(('X-Amz-SignedHeaders', self._signed_headers)) + + return '&'.join(self._uriencode(k) + '=' + self._uriencode(v) for k, v in params) + + @property + def _signing_key(self) -> bytes: + """ + Returns (bytes): + An AWS Signature V4 signing key generated from the secret_key, date, + region, service, and request type. + """ + key = self._hmac(('AWS4' + self.secret_key).encode('utf-8'), self.datestamp) + key = self._hmac(key, self.region) + key = self._hmac(key, self.service) + key = self._hmac(key, 'aws4_request') + return key + + @property + def _signing_str(self) -> str: + """ + Returns (str): + A string used to sign the AWS Signature V4 payload in the format: + \n + \n + \n + + """ + canonical_request_hash = self.hashfunc(self._canonical_request.encode('utf-8')).hexdigest() + return '\n'.join((self.algorithm, self.timestamp, self._scope, canonical_request_hash)) + + def _uriencode(self, msg: str) -> str: + """ + Arguments: + msg (str): A string to URI-encode. + + Returns (str): + The URI-encoded version of the provided msg, following the encoding + rules specified: https://github.com/aws/aws-msk-iam-auth#uriencode + """ + return urllib.parse.quote(msg, safe=self.UNRESERVED_CHARS) + + def _hmac(self, key: bytes, msg: str) -> bytes: + """ + Arguments: + key (bytes): A key to use for the HMAC digest. + msg (str): A value to include in the HMAC digest. + Returns (bytes): + An HMAC digest of the given key and msg. 
+ """ + return hmac.new(key, msg.encode('utf-8'), digestmod=self.hashfunc).digest() + + def first_message(self) -> bytes: + """ + Returns (bytes): + An encoded JSON authentication payload that can be sent to the + broker. + """ + signature = hmac.new( + self._signing_key, + self._signing_str.encode('utf-8'), + digestmod=self.hashfunc, + ).hexdigest() + msg = { + 'version': self.version, + 'host': self.host, + 'user-agent': 'kafka-python', + 'action': self.action, + 'x-amz-algorithm': self.algorithm, + 'x-amz-credential': self._credential, + 'x-amz-date': self.timestamp, + 'x-amz-signedheaders': self._signed_headers, + 'x-amz-expires': self.expires, + 'x-amz-signature': signature, + } + if self.token: + msg['x-amz-security-token'] = self.token + + return json.dumps(msg, separators=(',', ':')).encode('utf-8') diff --git a/kafka/util.py b/kafka/util.py index 0c9c5ea62..968787341 100644 --- a/kafka/util.py +++ b/kafka/util.py @@ -1,11 +1,12 @@ import binascii import weakref +from typing import Callable, Optional MAX_INT = 2 ** 31 TO_SIGNED = 2 ** 32 -def crc32(data): +def crc32(data: bytes) -> int: crc = binascii.crc32(data) # py2 and py3 behave a little differently # CRC is encoded as a signed int in kafka protocol @@ -24,7 +25,7 @@ class WeakMethod: object_dot_method: A bound instance method (i.e. 'object.method'). """ - def __init__(self, object_dot_method): + def __init__(self, object_dot_method: Callable) -> None: try: self.target = weakref.ref(object_dot_method.__self__) except AttributeError: @@ -36,16 +37,16 @@ def __init__(self, object_dot_method): self.method = weakref.ref(object_dot_method.im_func) self._method_id = id(self.method()) - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> Optional[bytes]: """ Calls the method on target with args and kwargs. 
""" return self.method()(self.target(), *args, **kwargs) - def __hash__(self): + def __hash__(self) -> int: return hash(self.target) ^ hash(self.method) - def __eq__(self, other): + def __eq__(self, other: "WeakMethod") -> bool: if not isinstance(other, WeakMethod): return False return self._target_id == other._target_id and self._method_id == other._method_id From cbf317bd4c42de4c37f675f31824c0728376acc5 Mon Sep 17 00:00:00 2001 From: William Barnhart Date: Tue, 26 Mar 2024 19:26:20 -0400 Subject: [PATCH 06/17] Add zstd support on legacy record and ensure no variable is referred before definition (#138) * fix if statement logic and add zstd check * fix if statement logic and add zstd uncompress * fix imports * avoid variable be used before definition * Remove unused import from legacy_records.py --------- Co-authored-by: Alexandre Souza --- kafka/record/default_records.py | 4 ++++ kafka/record/legacy_records.py | 13 +++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/kafka/record/default_records.py b/kafka/record/default_records.py index 91eb5c8a0..8b630cc8b 100644 --- a/kafka/record/default_records.py +++ b/kafka/record/default_records.py @@ -115,6 +115,8 @@ def _assert_has_codec(self, compression_type: int) -> None: checker, name = codecs.has_lz4, "lz4" elif compression_type == self.CODEC_ZSTD: checker, name = codecs.has_zstd, "zstd" + else: + checker, name = lambda: False, "Unknown" if not checker(): raise UnsupportedCodecError( f"Libraries for {name} compression codec not found") @@ -525,6 +527,8 @@ def _maybe_compress(self) -> bool: compressed = lz4_encode(data) elif self._compression_type == self.CODEC_ZSTD: compressed = zstd_encode(data) + else: + compressed = '' # unknown compressed_size = len(compressed) if len(data) <= compressed_size: # We did not get any benefit from compression, lets send diff --git a/kafka/record/legacy_records.py b/kafka/record/legacy_records.py index b77799f4d..4439462f6 100644 --- a/kafka/record/legacy_records.py +++ b/kafka/record/legacy_records.py @@ -49,8 +49,8 @@ from kafka.record.util import calc_crc32 from kafka.codec import ( - gzip_encode, snappy_encode, lz4_encode, lz4_encode_old_kafka, - gzip_decode, snappy_decode, lz4_decode, lz4_decode_old_kafka, + gzip_encode, snappy_encode, lz4_encode, lz4_encode_old_kafka, zstd_encode, + gzip_decode, snappy_decode, lz4_decode, lz4_decode_old_kafka, zstd_decode ) import kafka.codec as codecs from kafka.errors import CorruptRecordException, UnsupportedCodecError @@ -110,6 +110,7 @@ class LegacyRecordBase: CODEC_GZIP = 0x01 CODEC_SNAPPY = 0x02 CODEC_LZ4 = 0x03 + CODEC_ZSTD = 0x04 TIMESTAMP_TYPE_MASK = 0x08 LOG_APPEND_TIME = 1 @@ -124,6 +125,10 @@ def _assert_has_codec(self, compression_type: int) -> None: checker, name = codecs.has_snappy, "snappy" elif compression_type == self.CODEC_LZ4: checker, name = codecs.has_lz4, "lz4" + elif compression_type == self.CODEC_ZSTD: + checker, name = codecs.has_zstd, "zstd" + else: + checker, name = lambda: False, "Unknown" if not checker(): raise UnsupportedCodecError( f"Libraries for {name} compression codec not found") @@ -195,6 +200,10 @@ def _decompress(self, key_offset: int) -> bytes: uncompressed = lz4_decode_old_kafka(data.tobytes()) else: uncompressed = lz4_decode(data.tobytes()) + elif compression_type == self.CODEC_ZSTD: + uncompressed = zstd_decode(data) + else: + raise ValueError("Unknown Compression Type - %s" % compression_type) return uncompressed def _read_header(self, pos: int) -> Union[Tuple[int, int, int, int, int, None], 
Tuple[int, int, int, int, int, int]]: From af1a5f04971012e85714010ad2e5c8f64291faca Mon Sep 17 00:00:00 2001 From: William Barnhart Date: Tue, 26 Mar 2024 20:47:40 -0400 Subject: [PATCH 07/17] Update __init__.py of SASL to catch ImportErrors in case botocore is not installed (#175) Closes https://github.com/wbarnha/kafka-python-ng/issues/174. --- kafka/sasl/__init__.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/kafka/sasl/__init__.py b/kafka/sasl/__init__.py index 337c90949..dc9456d5a 100644 --- a/kafka/sasl/__init__.py +++ b/kafka/sasl/__init__.py @@ -1,8 +1,6 @@ import logging -from kafka.sasl import gssapi, oauthbearer, plain, scram, msk - -log = logging.getLogger(__name__) +from kafka.sasl import gssapi, oauthbearer, plain, scram MECHANISMS = { 'GSSAPI': gssapi, @@ -10,9 +8,16 @@ 'PLAIN': plain, 'SCRAM-SHA-256': scram, 'SCRAM-SHA-512': scram, - 'AWS_MSK_IAM': msk, } +try: + from kafka.sasl import msk + MECHANISMS['AWS_MSK_IAM'] = msk +except ImportError: + pass + +log = logging.getLogger(__name__) + def register_mechanism(key, module): """ From aba153f95d8029465a7ee694fc046f13473acd52 Mon Sep 17 00:00:00 2001 From: William Barnhart Date: Tue, 26 Mar 2024 20:48:41 -0400 Subject: [PATCH 08/17] Add botocore to extras in setup.py --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index dd4e5de90..4398b1ced 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ def run(cls): "lz4": ["lz4"], "snappy": ["python-snappy"], "zstd": ["zstandard"], + "boto": ["botocore"], }, cmdclass={"test": Tox}, packages=find_packages(exclude=['test']), From 6c9eb376ff5652182c0a50285ac5dba58f4f69e9 Mon Sep 17 00:00:00 2001 From: William Barnhart Date: Wed, 3 Apr 2024 10:33:53 -0400 Subject: [PATCH 09/17] Add connection_timeout_ms and reset the timeout counter more often (#132) * Add connection_timeout_ms and reset the timeout counter more often * Refactor last_attempt -> last_activity This semantically reflects the new usage of the variable better * Make tests work again * Add unit tests of new BrokerConnection functionality The test mocks parts of BrokerConnection in order to assert that the connection state machine allows long-lasting connections as long as the state progresses often enough * Re-introduce last_attempt to avoid breakage --------- Co-authored-by: Liam S. Crouch --- kafka/client_async.py | 4 ++ kafka/conn.py | 24 ++++++++--- kafka/producer/kafka.py | 4 ++ test/test_conn.py | 88 +++++++++++++++++++++++++++++++++++++---- 4 files changed, 107 insertions(+), 13 deletions(-) diff --git a/kafka/client_async.py b/kafka/client_async.py index b395dc5da..b46b879f9 100644 --- a/kafka/client_async.py +++ b/kafka/client_async.py @@ -59,6 +59,9 @@ class KafkaClient: rate. To avoid connection storms, a randomization factor of 0.2 will be applied to the backoff resulting in a random range between 20% below and 20% above the computed value. Default: 1000. + connection_timeout_ms (int): Connection timeout in milliseconds. + Default: None, which defaults it to the same value as + request_timeout_ms. request_timeout_ms (int): Client request timeout in milliseconds. Default: 30000. 
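Since kafka/sasl/__init__.py now treats AWS_MSK_IAM as an optionally registered mechanism, register_mechanism() is also the hook for plugging in third-party mechanisms. A hypothetical sketch, assuming the same module interface the msk module exposes (validate_config(conn) and try_authenticate(conn, future)); the mechanism name and stand-in module are invented for illustration:

    import kafka.sasl

    class my_mechanism:  # stand-in for a real module providing the assumed interface
        @staticmethod
        def validate_config(conn):
            assert conn.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL')

        @staticmethod
        def try_authenticate(conn, future):
            # a real mechanism would exchange tokens with the broker here
            return future.success(True)

    kafka.sasl.register_mechanism('MY-MECHANISM', my_mechanism)
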
connections_max_idle_ms: Close idle connections after the number of @@ -145,6 +148,7 @@ class KafkaClient: 'bootstrap_servers': 'localhost', 'bootstrap_topics_filter': set(), 'client_id': 'kafka-python-' + __version__, + 'connection_timeout_ms': None, 'request_timeout_ms': 30000, 'wakeup_timeout_ms': 3000, 'connections_max_idle_ms': 9 * 60 * 1000, diff --git a/kafka/conn.py b/kafka/conn.py index 5a73ba429..ebf314bd5 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -112,6 +112,9 @@ class BrokerConnection: rate. To avoid connection storms, a randomization factor of 0.2 will be applied to the backoff resulting in a random range between 20% below and 20% above the computed value. Default: 1000. + connection_timeout_ms (int): Connection timeout in milliseconds. + Default: None, which defaults it to the same value as + request_timeout_ms. request_timeout_ms (int): Client request timeout in milliseconds. Default: 30000. max_in_flight_requests_per_connection (int): Requests are pipelined @@ -188,6 +191,7 @@ class BrokerConnection: 'client_id': 'kafka-python-' + __version__, 'node_id': 0, 'request_timeout_ms': 30000, + 'connection_timeout_ms': None, 'reconnect_backoff_ms': 50, 'reconnect_backoff_max_ms': 1000, 'max_in_flight_requests_per_connection': 5, @@ -231,6 +235,9 @@ def __init__(self, host, port, afi, **configs): for key in self.config: if key in configs: self.config[key] = configs[key] + + if self.config['connection_timeout_ms'] is None: + self.config['connection_timeout_ms'] = self.config['request_timeout_ms'] self.node_id = self.config.pop('node_id') @@ -284,7 +291,10 @@ def __init__(self, host, port, afi, **configs): if self.config['ssl_context'] is not None: self._ssl_context = self.config['ssl_context'] self._sasl_auth_future = None - self.last_attempt = 0 + self.last_activity = 0 + # This value is not used for internal state, but it is left to allow backwards-compatability + # The variable last_activity is now used instead, but is updated more often may therefore break compatability with some hacks. 
+ self.last_attempt= 0 self._gai = [] self._sensors = None if self.config['metrics']: @@ -362,6 +372,7 @@ def connect(self): self.config['state_change_callback'](self.node_id, self._sock, self) log.info('%s: connecting to %s:%d [%s %s]', self, self.host, self.port, self._sock_addr, AFI_NAMES[self._sock_afi]) + self.last_activity = time.time() if self.state is ConnectionStates.CONNECTING: # in non-blocking mode, use repeated calls to socket.connect_ex @@ -394,6 +405,7 @@ def connect(self): self.state = ConnectionStates.CONNECTED self._reset_reconnect_backoff() self.config['state_change_callback'](self.node_id, self._sock, self) + self.last_activity = time.time() # Connection failed # WSAEINVAL == 10022, but errno.WSAEINVAL is not available on non-win systems @@ -419,6 +431,7 @@ def connect(self): self.state = ConnectionStates.CONNECTED self._reset_reconnect_backoff() self.config['state_change_callback'](self.node_id, self._sock, self) + self.last_activity = time.time() if self.state is ConnectionStates.AUTHENTICATING: assert self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL') @@ -429,12 +442,13 @@ def connect(self): self.state = ConnectionStates.CONNECTED self._reset_reconnect_backoff() self.config['state_change_callback'](self.node_id, self._sock, self) + self.last_activity = time.time() if self.state not in (ConnectionStates.CONNECTED, ConnectionStates.DISCONNECTED): # Connection timed out - request_timeout = self.config['request_timeout_ms'] / 1000.0 - if time.time() > request_timeout + self.last_attempt: + request_timeout = self.config['connection_timeout_ms'] / 1000.0 + if time.time() > request_timeout + self.last_activity: log.error('Connection attempt to %s timed out', self) self.close(Errors.KafkaConnectionError('timeout')) return self.state @@ -595,7 +609,7 @@ def blacked_out(self): re-establish a connection yet """ if self.state is ConnectionStates.DISCONNECTED: - if time.time() < self.last_attempt + self._reconnect_backoff: + if time.time() < self.last_activity + self._reconnect_backoff: return True return False @@ -606,7 +620,7 @@ def connection_delay(self): the reconnect backoff time. When connecting or connected, returns a very large number to handle slow/stalled connections. """ - time_waited = time.time() - (self.last_attempt or 0) + time_waited = time.time() - (self.last_activity or 0) if self.state is ConnectionStates.DISCONNECTED: return max(self._reconnect_backoff - time_waited, 0) * 1000 else: diff --git a/kafka/producer/kafka.py b/kafka/producer/kafka.py index f58221372..b9b2433d9 100644 --- a/kafka/producer/kafka.py +++ b/kafka/producer/kafka.py @@ -190,6 +190,9 @@ class KafkaProducer: brokers or partitions. Default: 300000 retry_backoff_ms (int): Milliseconds to backoff when retrying on errors. Default: 100. + connection_timeout_ms (int): Connection timeout in milliseconds. + Default: None, which defaults it to the same value as + request_timeout_ms. request_timeout_ms (int): Client request timeout in milliseconds. Default: 30000. 
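In practice the new connection_timeout_ms knob lets connection establishment fail fast while long-running requests keep the larger deadline. A usage sketch against a placeholder broker address and topic:

    from kafka import KafkaProducer

    producer = KafkaProducer(
        bootstrap_servers='localhost:9092',  # placeholder address
        connection_timeout_ms=5000,          # a stalled TCP/SSL/SASL connect gets 5s to make progress
        request_timeout_ms=30000,            # in-flight requests still get the full 30s
    )
    producer.send('my-topic', b'hello')      # hypothetical topic
    producer.flush()
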
receive_buffer_bytes (int): The size of the TCP receive buffer @@ -300,6 +303,7 @@ class KafkaProducer: 'max_request_size': 1048576, 'metadata_max_age_ms': 300000, 'retry_backoff_ms': 100, + 'connection_timeout_ms': None, 'request_timeout_ms': 30000, 'receive_buffer_bytes': None, 'send_buffer_bytes': None, diff --git a/test/test_conn.py b/test/test_conn.py index d595fac3a..979f25e31 100644 --- a/test/test_conn.py +++ b/test/test_conn.py @@ -6,6 +6,7 @@ import socket import pytest +import time from kafka.conn import BrokerConnection, ConnectionStates, collect_hosts from kafka.protocol.api import RequestHeader @@ -61,28 +62,99 @@ def test_connect_timeout(_socket, conn): # Initial connect returns EINPROGRESS # immediate inline connect returns EALREADY # second explicit connect returns EALREADY - # third explicit connect returns EALREADY and times out via last_attempt + # third explicit connect returns EALREADY and times out via last_activity _socket.connect_ex.side_effect = [EINPROGRESS, EALREADY, EALREADY, EALREADY] conn.connect() assert conn.state is ConnectionStates.CONNECTING conn.connect() assert conn.state is ConnectionStates.CONNECTING + conn.last_activity = 0 conn.last_attempt = 0 conn.connect() assert conn.state is ConnectionStates.DISCONNECTED +def test_connect_timeout_slowconn(_socket, conn, mocker): + # Same as test_connect_timeout, + # but we make the connection run longer than the timeout in order to test that + # BrokerConnection resets the timer whenever things happen during the connection + # See https://github.com/dpkp/kafka-python/issues/2386 + _socket.connect_ex.side_effect = [EINPROGRESS, EISCONN] + + # 0.8 = we guarantee that when testing with three intervals of this we are past the timeout + time_between_connect = (conn.config['connection_timeout_ms']/1000) * 0.8 + start = time.time() + + # Use plaintext auth for simplicity + last_activity = conn.last_activity + last_attempt = conn.last_attempt + conn.config['security_protocol'] = 'SASL_PLAINTEXT' + conn.connect() + assert conn.state is ConnectionStates.CONNECTING + # Ensure the last_activity counter was updated + # Last_attempt should also be updated + assert conn.last_activity > last_activity + assert conn.last_attempt > last_attempt + last_attempt = conn.last_attempt + last_activity = conn.last_activity + + # Simulate time being passed + # This shouldn't be enough time to time out the connection + conn._try_authenticate = mocker.Mock(side_effect=[False, False, True]) + with mock.patch("time.time", return_value=start+time_between_connect): + # This should trigger authentication + # Note that an authentication attempt isn't actually made until now. + # We simulate that authentication does not succeed at this point + # This is technically incorrect, but it lets us see what happens + # to the state machine when the state doesn't change for two function calls + conn.connect() + assert conn.last_activity > last_activity + # Last attempt is kept as a legacy variable, should not update + assert conn.last_attempt == last_attempt + last_activity = conn.last_activity + + assert conn.state is ConnectionStates.AUTHENTICATING + + + # This time around we should be way past timeout. + # Now we care about connect() not terminating the attempt, + # because connection state was progressed in the meantime. + with mock.patch("time.time", return_value=start+time_between_connect*2): + # Simulate this one not succeeding as well. 
This is so we can ensure things don't time out + conn.connect() + + # No state change = no activity change + assert conn.last_activity == last_activity + assert conn.last_attempt == last_attempt + + # If last_activity was not reset when the state transitioned to AUTHENTICATING, + # the connection state would be timed out now. + assert conn.state is ConnectionStates.AUTHENTICATING + + + # This time around, the connection should succeed. + with mock.patch("time.time", return_value=start+time_between_connect*3): + # This should finalize the connection + conn.connect() + + assert conn.last_activity > last_activity + assert conn.last_attempt == last_attempt + last_activity = conn.last_activity + + assert conn.state is ConnectionStates.CONNECTED + + def test_blacked_out(conn): with mock.patch("time.time", return_value=1000): - conn.last_attempt = 0 + conn.last_activity = 0 assert conn.blacked_out() is False - conn.last_attempt = 1000 + conn.last_activity = 1000 assert conn.blacked_out() is True def test_connection_delay(conn): with mock.patch("time.time", return_value=1000): - conn.last_attempt = 1000 + conn.last_activity = 1000 assert conn.connection_delay() == conn.config['reconnect_backoff_ms'] conn.state = ConnectionStates.CONNECTING assert conn.connection_delay() == float('inf') @@ -286,7 +358,7 @@ def test_lookup_on_connect(): ] with mock.patch("socket.getaddrinfo", return_value=mock_return2) as m: - conn.last_attempt = 0 + conn.last_activity = 0 conn.connect() m.assert_called_once_with(hostname, port, 0, socket.SOCK_STREAM) assert conn._sock_afi == afi2 @@ -301,11 +373,10 @@ def test_relookup_on_failure(): assert conn.host == hostname mock_return1 = [] with mock.patch("socket.getaddrinfo", return_value=mock_return1) as m: - last_attempt = conn.last_attempt + last_activity = conn.last_activity conn.connect() m.assert_called_once_with(hostname, port, 0, socket.SOCK_STREAM) assert conn.disconnected() - assert conn.last_attempt > last_attempt afi2 = socket.AF_INET sockaddr2 = ('127.0.0.2', 9092) @@ -314,12 +385,13 @@ def test_relookup_on_failure(): ] with mock.patch("socket.getaddrinfo", return_value=mock_return2) as m: - conn.last_attempt = 0 + conn.last_activity = 0 conn.connect() m.assert_called_once_with(hostname, port, 0, socket.SOCK_STREAM) assert conn._sock_afi == afi2 assert conn._sock_addr == sockaddr2 conn.close() + assert conn.last_activity > last_activity def test_requests_timed_out(conn): From 6756974e7d0a86f6aa67a1262feb0f34b346e4b1 Mon Sep 17 00:00:00 2001 From: Sharu Kulam Date: Thu, 4 Apr 2024 17:55:48 +0200 Subject: [PATCH 10/17] add validate_config function for msk module (#176) --- kafka/conn.py | 17 +++-------------- kafka/sasl/msk.py | 14 ++++++++++++-- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/kafka/conn.py b/kafka/conn.py index ebf314bd5..745e4bca6 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -68,13 +68,6 @@ class SSLWantWriteError(Exception): gssapi = None GSSError = None -# needed for AWS_MSK_IAM authentication: -try: - from botocore.session import Session as BotoSession -except ImportError: - # no botocore available, will disable AWS_MSK_IAM mechanism - BotoSession = None - AFI_NAMES = { socket.AF_UNSPEC: "unspecified", socket.AF_INET: "IPv4", @@ -113,7 +106,7 @@ class BrokerConnection: will be applied to the backoff resulting in a random range between 20% below and 20% above the computed value. Default: 1000. connection_timeout_ms (int): Connection timeout in milliseconds. 
- Default: None, which defaults it to the same value as + Default: None, which defaults it to the same value as request_timeout_ms. request_timeout_ms (int): Client request timeout in milliseconds. Default: 30000. @@ -235,7 +228,7 @@ def __init__(self, host, port, afi, **configs): for key in self.config: if key in configs: self.config[key] = configs[key] - + if self.config['connection_timeout_ms'] is None: self.config['connection_timeout_ms'] = self.config['request_timeout_ms'] @@ -253,19 +246,15 @@ def __init__(self, host, port, afi, **configs): assert self.config['security_protocol'] in self.SECURITY_PROTOCOLS, ( 'security_protocol must be in ' + ', '.join(self.SECURITY_PROTOCOLS)) - if self.config['security_protocol'] in ('SSL', 'SASL_SSL'): assert ssl_available, "Python wasn't built with SSL support" - if self.config['sasl_mechanism'] == 'AWS_MSK_IAM': - assert BotoSession is not None, 'AWS_MSK_IAM requires the "botocore" package' - assert self.config['security_protocol'] == 'SASL_SSL', 'AWS_MSK_IAM requires SASL_SSL' - if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'): assert self.config['sasl_mechanism'] in sasl.MECHANISMS, ( 'sasl_mechanism must be one of {}'.format(', '.join(sasl.MECHANISMS.keys())) ) sasl.MECHANISMS[self.config['sasl_mechanism']].validate_config(self) + # This is not a general lock / this class is not generally thread-safe yet # However, to avoid pushing responsibility for maintaining # per-connection locks to the upstream client, we will use this lock to diff --git a/kafka/sasl/msk.py b/kafka/sasl/msk.py index ebea5dc5a..2ae88d326 100644 --- a/kafka/sasl/msk.py +++ b/kafka/sasl/msk.py @@ -10,10 +10,20 @@ from kafka.protocol.types import Int32 import kafka.errors as Errors -from botocore.session import Session as BotoSession # importing it in advance is not an option apparently... 
+# needed for AWS_MSK_IAM authentication: +try: + from botocore.session import Session as BotoSession +except ImportError: + # no botocore available, will disable AWS_MSK_IAM mechanism + BotoSession = None + from typing import Optional +def validate_config(conn): + assert BotoSession is not None, 'AWS_MSK_IAM requires the "botocore" package' + assert conn.config.get('security_protocol') == 'SASL_SSL', 'AWS_MSK_IAM requires SASL_SSL' + def try_authenticate(self, future): session = BotoSession() @@ -25,7 +35,7 @@ def try_authenticate(self, future): region=session.get_config_variable('region'), token=credentials.token, ) - + msg = client.first_message() size = Int32.encode(len(msg)) From 611471fcfbefe32b610819b33a744da4c1f5849f Mon Sep 17 00:00:00 2001 From: Xiong Ding Date: Wed, 10 Apr 2024 09:57:51 -0700 Subject: [PATCH 11/17] Fix ssl connection (#178) * Fix ssl connection after wrap_ssl * test * refactor * remove global level * test * revert test * address comments --- kafka/client_async.py | 9 ++++- kafka/conn.py | 4 +-- test/fixtures.py | 9 +++-- test/test_ssl_integration.py | 67 ++++++++++++++++++++++++++++++++++++ 4 files changed, 84 insertions(+), 5 deletions(-) create mode 100644 test/test_ssl_integration.py diff --git a/kafka/client_async.py b/kafka/client_async.py index b46b879f9..984cd81fb 100644 --- a/kafka/client_async.py +++ b/kafka/client_async.py @@ -266,7 +266,14 @@ def _conn_state_change(self, node_id, sock, conn): try: self._selector.register(sock, selectors.EVENT_WRITE, conn) except KeyError: - self._selector.modify(sock, selectors.EVENT_WRITE, conn) + # SSL detaches the original socket, and transfers the + # underlying file descriptor to a new SSLSocket. We should + # explicitly unregister the original socket. + if conn.state == ConnectionStates.HANDSHAKE: + self._selector.unregister(sock) + self._selector.register(sock, selectors.EVENT_WRITE, conn) + else: + self._selector.modify(sock, selectors.EVENT_WRITE, conn) if self.cluster.is_bootstrap(node_id): self._last_bootstrap = time.time() diff --git a/kafka/conn.py b/kafka/conn.py index 745e4bca6..b9ef0e2d9 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -378,10 +378,10 @@ def connect(self): if self.config['security_protocol'] in ('SSL', 'SASL_SSL'): log.debug('%s: initiating SSL handshake', self) - self.state = ConnectionStates.HANDSHAKE - self.config['state_change_callback'](self.node_id, self._sock, self) # _wrap_ssl can alter the connection state -- disconnects on failure self._wrap_ssl() + self.state = ConnectionStates.HANDSHAKE + self.config['state_change_callback'](self.node_id, self._sock, self) elif self.config['security_protocol'] == 'SASL_PLAINTEXT': log.debug('%s: initiating SASL authentication', self) diff --git a/test/fixtures.py b/test/fixtures.py index 4ed515da3..998dc429f 100644 --- a/test/fixtures.py +++ b/test/fixtures.py @@ -38,7 +38,7 @@ def gen_ssl_resources(directory): # Step 1 keytool -keystore kafka.server.keystore.jks -alias localhost -validity 1 \ - -genkey -storepass foobar -keypass foobar \ + -genkey -keyalg RSA -storepass foobar -keypass foobar \ -dname "CN=localhost, OU=kafka-python, O=kafka-python, L=SF, ST=CA, C=US" \ -ext SAN=dns:localhost @@ -289,7 +289,7 @@ def __init__(self, host, port, broker_id, zookeeper, zk_chroot, self.sasl_mechanism = sasl_mechanism.upper() else: self.sasl_mechanism = None - self.ssl_dir = self.test_resource('ssl') + self.ssl_dir = None # TODO: checking for port connection would be better than scanning logs # until then, we need the pattern to work across all 
supported broker versions @@ -410,6 +410,8 @@ def start(self): jaas_conf = self.tmp_dir.join("kafka_server_jaas.conf") properties_template = self.test_resource("kafka.properties") jaas_conf_template = self.test_resource("kafka_server_jaas.conf") + self.ssl_dir = self.tmp_dir + gen_ssl_resources(self.ssl_dir.strpath) args = self.kafka_run_class_args("kafka.Kafka", properties.strpath) env = self.kafka_run_class_env() @@ -641,6 +643,9 @@ def _enrich_client_params(self, params, **defaults): if self.sasl_mechanism in ('PLAIN', 'SCRAM-SHA-256', 'SCRAM-SHA-512'): params.setdefault('sasl_plain_username', self.broker_user) params.setdefault('sasl_plain_password', self.broker_password) + if self.transport in ["SASL_SSL", "SSL"]: + params.setdefault("ssl_cafile", self.ssl_dir.join('ca-cert').strpath) + params.setdefault("security_protocol", self.transport) return params @staticmethod diff --git a/test/test_ssl_integration.py b/test/test_ssl_integration.py new file mode 100644 index 000000000..8453e7831 --- /dev/null +++ b/test/test_ssl_integration.py @@ -0,0 +1,67 @@ +import logging +import uuid + +import pytest + +from kafka.admin import NewTopic +from kafka.protocol.metadata import MetadataRequest_v1 +from test.testutil import assert_message_count, env_kafka_version, random_string, special_to_underscore + + +@pytest.fixture(scope="module") +def ssl_kafka(request, kafka_broker_factory): + return kafka_broker_factory(transport="SSL")[0] + + +@pytest.mark.skipif(env_kafka_version() < (0, 10), reason="Inter broker SSL was implemented at version 0.9") +def test_admin(request, ssl_kafka): + topic_name = special_to_underscore(request.node.name + random_string(4)) + admin, = ssl_kafka.get_admin_clients(1) + admin.create_topics([NewTopic(topic_name, 1, 1)]) + assert topic_name in ssl_kafka.get_topic_names() + + +@pytest.mark.skipif(env_kafka_version() < (0, 10), reason="Inter broker SSL was implemented at version 0.9") +def test_produce_and_consume(request, ssl_kafka): + topic_name = special_to_underscore(request.node.name + random_string(4)) + ssl_kafka.create_topics([topic_name], num_partitions=2) + producer, = ssl_kafka.get_producers(1) + + messages_and_futures = [] # [(message, produce_future),] + for i in range(100): + encoded_msg = "{}-{}-{}".format(i, request.node.name, uuid.uuid4()).encode("utf-8") + future = producer.send(topic_name, value=encoded_msg, partition=i % 2) + messages_and_futures.append((encoded_msg, future)) + producer.flush() + + for (msg, f) in messages_and_futures: + assert f.succeeded() + + consumer, = ssl_kafka.get_consumers(1, [topic_name]) + messages = {0: [], 1: []} + for i, message in enumerate(consumer, 1): + logging.debug("Consumed message %s", repr(message)) + messages[message.partition].append(message) + if i >= 100: + break + + assert_message_count(messages[0], 50) + assert_message_count(messages[1], 50) + + +@pytest.mark.skipif(env_kafka_version() < (0, 10), reason="Inter broker SSL was implemented at version 0.9") +def test_client(request, ssl_kafka): + topic_name = special_to_underscore(request.node.name + random_string(4)) + ssl_kafka.create_topics([topic_name], num_partitions=1) + + client, = ssl_kafka.get_clients(1) + request = MetadataRequest_v1(None) + client.send(0, request) + for _ in range(10): + result = client.poll(timeout_ms=10000) + if len(result) > 0: + break + else: + raise RuntimeError("Couldn't fetch topic response from Broker.") + result = result[0] + assert topic_name in [t[1] for t in result.topics] From deebd8f06eaf951b8f44628e917262b08c84da39 Mon 
Sep 17 00:00:00 2001 From: William Barnhart Date: Tue, 23 Apr 2024 09:57:05 -0400 Subject: [PATCH 12/17] Fix badge typo in README.rst --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index ce82c6d3b..794d59a61 100644 --- a/README.rst +++ b/README.rst @@ -11,7 +11,7 @@ Kafka Python client :target: https://github.com/wbarnha/kafka-python-ng/blob/master/LICENSE .. image:: https://img.shields.io/pypi/dw/kafka-python-ng.svg :target: https://pypistats.org/packages/kafka-python-ng -.. image:: https://img.shields.io/pypi/v/kafka-python.svg +.. image:: https://img.shields.io/pypi/v/kafka-python-ng.svg :target: https://pypi.org/project/kafka-python-ng .. image:: https://img.shields.io/pypi/implementation/kafka-python-ng :target: https://github.com/wbarnha/kafka-python-ng/blob/master/setup.py From 5e461a7e017130fb9115add8d64291d6966267e9 Mon Sep 17 00:00:00 2001 From: William Barnhart Date: Fri, 12 Jul 2024 01:37:38 -0400 Subject: [PATCH 13/17] Patch pylint warnings so tests pass again (#184) * stop pylint complaint for uncovered conditional flow * add todo to revisit * formatting makes me happy :) * Fix errors raised by new version of Pylint so tests pass again --- kafka/admin/client.py | 5 +++++ kafka/coordinator/consumer.py | 5 +++++ kafka/record/default_records.py | 8 +++++--- kafka/record/legacy_records.py | 2 ++ 4 files changed, 17 insertions(+), 3 deletions(-) diff --git a/kafka/admin/client.py b/kafka/admin/client.py index 5b01f8fe6..f74e09a80 100644 --- a/kafka/admin/client.py +++ b/kafka/admin/client.py @@ -503,6 +503,8 @@ def _get_cluster_metadata(self, topics=None, auto_topic_creation=False): topics=topics, allow_auto_topic_creation=auto_topic_creation ) + else: + raise IncompatibleBrokerVersion(f"MetadataRequest for {version} is not supported") future = self._send_request_to_node( self._client.least_loaded_node(), @@ -1010,6 +1012,7 @@ def _describe_consumer_groups_send_request(self, group_id, group_coordinator_id, def _describe_consumer_groups_process_response(self, response): """Process a DescribeGroupsResponse into a group description.""" if response.API_VERSION <= 3: + group_description = None assert len(response.groups) == 1 for response_field, response_name in zip(response.SCHEMA.fields, response.SCHEMA.names): if isinstance(response_field, Array): @@ -1045,6 +1048,8 @@ def _describe_consumer_groups_process_response(self, response): if response.API_VERSION <=2: described_group_information_list.append(None) group_description = GroupInformation._make(described_group_information_list) + if group_description is None: + raise Errors.BrokerResponseError("No group description received") error_code = group_description.error_code error_type = Errors.for_code(error_code) # Java has the note: KAFKA-6789, we can retry based on the error code diff --git a/kafka/coordinator/consumer.py b/kafka/coordinator/consumer.py index cf82b69fe..351641981 100644 --- a/kafka/coordinator/consumer.py +++ b/kafka/coordinator/consumer.py @@ -628,10 +628,15 @@ def _send_offset_commit_request(self, offsets): ) for partition, offset in partitions.items()] ) for topic, partitions in offset_data.items()] ) + else: + # TODO: We really shouldn't need this here to begin with, but I'd like to get + # pylint to stop complaining. 
+ raise Exception(f"Unsupported Broker API: {self.config['api_version']}") log.debug("Sending offset-commit request with %s for group %s to %s", offsets, self.group_id, node_id) + future = Future() _f = self._client.send(node_id, request) _f.add_callback(self._handle_offset_commit_response, offsets, future, time.time()) diff --git a/kafka/record/default_records.py b/kafka/record/default_records.py index 8b630cc8b..06be57621 100644 --- a/kafka/record/default_records.py +++ b/kafka/record/default_records.py @@ -187,12 +187,14 @@ def _maybe_uncompress(self) -> None: data = memoryview(self._buffer)[self._pos:] if compression_type == self.CODEC_GZIP: uncompressed = gzip_decode(data) - if compression_type == self.CODEC_SNAPPY: + elif compression_type == self.CODEC_SNAPPY: uncompressed = snappy_decode(data.tobytes()) - if compression_type == self.CODEC_LZ4: + elif compression_type == self.CODEC_LZ4: uncompressed = lz4_decode(data.tobytes()) - if compression_type == self.CODEC_ZSTD: + elif compression_type == self.CODEC_ZSTD: uncompressed = zstd_decode(data.tobytes()) + else: + raise NotImplementedError(f"Compression type {compression_type} is not supported") self._buffer = bytearray(uncompressed) self._pos = 0 self._decompressed = True diff --git a/kafka/record/legacy_records.py b/kafka/record/legacy_records.py index 4439462f6..44b365b06 100644 --- a/kafka/record/legacy_records.py +++ b/kafka/record/legacy_records.py @@ -461,6 +461,8 @@ def _maybe_compress(self) -> bool: compressed = lz4_encode_old_kafka(data) else: compressed = lz4_encode(data) + else: + raise NotImplementedError(f"Compression type {self._compression_type} is not supported") size = self.size_in_bytes( 0, timestamp=0, key=None, value=compressed) # We will try to reuse the same buffer if we have enough space From 401896b42a32c356a5453859ae576d166b051afd Mon Sep 17 00:00:00 2001 From: William Barnhart Date: Wed, 17 Jul 2024 11:35:53 -0400 Subject: [PATCH 14/17] Update README.rst to close #179 --- README.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.rst b/README.rst index c2b3d3e67..1a6a8050a 100644 --- a/README.rst +++ b/README.rst @@ -47,6 +47,11 @@ documentation, please see readthedocs and/or python's inline help. $ pip install kafka-python-ng +For those who are concerned regarding the security of this package: +This project uses https://docs.pypi.org/trusted-publishers/ in GitHub +Actions to publish artifacts in https://github.com/wbarnha/kafka-python-ng/deployments/pypi. +This project was forked to keep the project alive for future versions of +Python and Kafka, since `kafka-python` is unable to publish releases in the meantime. KafkaConsumer ************* From 31a6b92e3ff5265dc1f184250115532a30618cc2 Mon Sep 17 00:00:00 2001 From: Orange Kao Date: Sat, 10 Aug 2024 00:00:51 +1000 Subject: [PATCH 15/17] Avoid busy retry (#192) Test test/test_consumer_group.py::test_group and test/test_admin_integration.py::test_describe_consumer_group_exists busy-retry and this might have caused Java not having enough CPU time on GitHub runner, and result in test failure. 
--- test/test_admin_integration.py | 1 + test/test_consumer_group.py | 1 + 2 files changed, 2 insertions(+) diff --git a/test/test_admin_integration.py b/test/test_admin_integration.py index 283023049..0eb06b18d 100644 --- a/test/test_admin_integration.py +++ b/test/test_admin_integration.py @@ -220,6 +220,7 @@ def consumer_thread(i, group_id): else: sleep(1) assert time() < timeout, "timeout waiting for assignments" + sleep(0.25) info('Group stabilized; verifying assignment') output = kafka_admin_client.describe_consumer_groups(group_id_list) diff --git a/test/test_consumer_group.py b/test/test_consumer_group.py index ed6863fa2..abd0cfe09 100644 --- a/test/test_consumer_group.py +++ b/test/test_consumer_group.py @@ -111,6 +111,7 @@ def consumer_thread(i): logging.info('Rejoining: %s, generations: %s', rejoining, generations) time.sleep(1) assert time.time() < timeout, "timeout waiting for assignments" + time.sleep(0.25) logging.info('Group stabilized; verifying assignment') group_assignment = set() From 9bee9fc599c473437ebec8d90dd22ae7ed7a9bc8 Mon Sep 17 00:00:00 2001 From: debuggings Date: Thu, 15 Aug 2024 10:41:27 +0800 Subject: [PATCH 16/17] fix scram username character escape (#196) According to [rfc5802](https://datatracker.ietf.org/doc/html/rfc5802), username should escape special characters before sending to the server. > The characters ',' or '=' in usernames are sent as '=2C' and '=3D' respectively. If the server receives a username that contains '=' not followed by either '2C' or '3D', then the server MUST fail the authentication. --- kafka/scram.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kafka/scram.py b/kafka/scram.py index 74f4716bd..236ae2149 100644 --- a/kafka/scram.py +++ b/kafka/scram.py @@ -30,7 +30,7 @@ def __init__(self, user, password, mechanism): self.server_signature = None def first_message(self): - client_first_bare = f'n={self.user},r={self.nonce}' + client_first_bare = f'n={self.user.replace("=","=3D").replace(",","=2C")},r={self.nonce}' self.auth_message += client_first_bare return 'n,,' + client_first_bare From 61046232200688ceaba9726ab963b643b223b1d4 Mon Sep 17 00:00:00 2001 From: Orange Kao Date: Thu, 3 Oct 2024 10:51:51 +1000 Subject: [PATCH 17/17] Improve test/test_consumer_integration.py in GitHub runner (#194) test/test_consumer_integration.py::test_kafka_consumer__blocking failed in https://github.com/wbarnha/kafka-python-ng/actions/runs/10361086008/job/28680735389?pr=186 because it took 592ms to finish. Output from the GitHub runner attached This commit increase TIMEOUT_MS so it is less likely to fail on GitHub runner. # Ask for 5 messages, 10 in queue. 
Get 5 back, no blocking messages = [] with Timer() as t: for i in range(5): msg = next(consumer) messages.append(msg) assert_message_count(messages, 5) > assert t.interval < (TIMEOUT_MS / 1000.0) E assert 0.5929090976715088 < (500 / 1000.0) E + where 0.5929090976715088 = .interval Co-authored-by: William Barnhart --- test/test_consumer_integration.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_consumer_integration.py b/test/test_consumer_integration.py index d3165cd63..62aad5f97 100644 --- a/test/test_consumer_integration.py +++ b/test/test_consumer_integration.py @@ -61,7 +61,7 @@ def test_kafka_consumer_unsupported_encoding( @pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") def test_kafka_consumer__blocking(kafka_consumer_factory, topic, send_messages): - TIMEOUT_MS = 500 + TIMEOUT_MS = 1000 consumer = kafka_consumer_factory(auto_offset_reset='earliest', enable_auto_commit=False, consumer_timeout_ms=TIMEOUT_MS) @@ -70,7 +70,7 @@ def test_kafka_consumer__blocking(kafka_consumer_factory, topic, send_messages): consumer.unsubscribe() consumer.assign([TopicPartition(topic, 0)]) - # Ask for 5 messages, nothing in queue, block 500ms + # Ask for 5 messages, nothing in queue, block 1000ms with Timer() as t: with pytest.raises(StopIteration): msg = next(consumer) @@ -87,7 +87,7 @@ def test_kafka_consumer__blocking(kafka_consumer_factory, topic, send_messages): assert_message_count(messages, 5) assert t.interval < (TIMEOUT_MS / 1000.0) - # Ask for 10 messages, get 5 back, block 500ms + # Ask for 10 messages, get 5 back, block 1000ms messages = [] with Timer() as t: with pytest.raises(StopIteration):