
KAFKA-3949: Avoid race condition when subscription changes during rebalance #1364

Merged · 2 commits · Feb 3, 2018
7 changes: 7 additions & 0 deletions kafka/cluster.py
@@ -291,6 +291,13 @@ def update_metadata(self, metadata):
for listener in self._listeners:
listener(self)

+        if self.need_all_topic_metadata:
+            # the listener may change the interested topics,
+            # which could cause another metadata refresh.
+            # If we have already fetched all topics, however,
+            # another fetch should be unnecessary.
+            self._need_update = False

def add_listener(self, listener):
"""Add a callback function to be called on each metadata update"""
self._listeners.add(listener)
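The change above guards against a refresh loop: a metadata listener may react to an update by changing the set of interested topics, which would normally flag the metadata as stale again. Below is a minimal, self-contained sketch of that idea using toy classes (ToyClusterMetadata and the listener are illustrative only, not kafka-python's real API).

```python
# Toy sketch: when metadata for *all* topics is already fetched, a listener
# that requests another update during the callback does not force a redundant
# refresh, because the flag is cleared again afterwards.

class ToyClusterMetadata(object):
    def __init__(self, need_all_topic_metadata=False):
        self.need_all_topic_metadata = need_all_topic_metadata
        self._need_update = False
        self._listeners = set()

    def add_listener(self, listener):
        self._listeners.add(listener)

    def request_update(self):
        self._need_update = True

    def update_metadata(self, metadata):
        for listener in self._listeners:
            listener(self)          # may call request_update() indirectly

        if self.need_all_topic_metadata:
            # a full-metadata fetch already covers any topics the listener
            # just became interested in, so skip the redundant refresh
            self._need_update = False


def listener_that_changes_subscription(cluster):
    # e.g. a pattern subscription matching a newly discovered topic
    cluster.request_update()

cluster = ToyClusterMetadata(need_all_topic_metadata=True)
cluster.add_listener(listener_that_changes_subscription)
cluster.update_metadata({'topics': ['a', 'b']})
assert cluster._need_update is False   # no redundant second refresh
```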
6 changes: 0 additions & 6 deletions kafka/consumer/fetcher.py
@@ -326,9 +326,6 @@ def fetched_records(self, max_records=None):
max_records = self.config['max_poll_records']
assert max_records > 0

-        if self._subscriptions.needs_partition_assignment:
-            return {}, False

drained = collections.defaultdict(list)
records_remaining = max_records

@@ -397,9 +394,6 @@ def _append(self, drained, part, max_records):

def _message_generator(self):
"""Iterate over fetched_records"""
-        if self._subscriptions.needs_partition_assignment:
-            raise StopIteration('Subscription needs partition assignment')

while self._next_partition_records or self._completed_fetches:

if not self._next_partition_records:
10 changes: 10 additions & 0 deletions kafka/consumer/group.py
@@ -644,6 +644,11 @@ def _poll_once(self, timeout_ms, max_records):

timeout_ms = min(timeout_ms, self._coordinator.time_to_next_poll())
self._client.poll(timeout_ms=timeout_ms)
+        # after the long poll, we should check whether the group needs to rebalance
+        # prior to returning data so that the group can stabilize faster
+        if self._coordinator.need_rejoin():
+            return {}

records, _ = self._fetcher.fetched_records(max_records)
return records

@@ -1055,6 +1060,11 @@ def _message_generator(self):
poll_ms = 0
self._client.poll(timeout_ms=poll_ms)

+            # after the long poll, we should check whether the group needs to rebalance
+            # prior to returning data so that the group can stabilize faster
+            if self._coordinator.need_rejoin():
+                continue

# We need to make sure we at least keep up with scheduled tasks,
# like heartbeats, auto-commits, and metadata refreshes
timeout_at = self._next_timeout()
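The two hunks above add the same check in both poll paths: after the long network poll, return no data if a rebalance has become necessary, so the group stabilizes before more records are handed out. The following sketch illustrates the pattern with hypothetical stand-ins (ToyCoordinator, ToyFetcher, poll_once are illustrative, not the consumer's real classes).

```python
# Toy sketch of the "check need_rejoin() after the long poll" pattern.

class ToyCoordinator(object):
    def __init__(self):
        self.rejoin_needed = False

    def need_rejoin(self):
        return self.rejoin_needed


class ToyFetcher(object):
    def fetched_records(self, max_records):
        return {('topic', 0): ['msg1', 'msg2']}


def poll_once(coordinator, fetcher, max_records=500):
    # ... the long poll on the network would happen here ...
    if coordinator.need_rejoin():
        # returning early lets the next poll() drive the rebalance
        return {}
    return fetcher.fetched_records(max_records)


coordinator, fetcher = ToyCoordinator(), ToyFetcher()
print(poll_once(coordinator, fetcher))      # records flow normally
coordinator.rejoin_needed = True
print(poll_once(coordinator, fetcher))      # {} while a rebalance is pending
```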
30 changes: 14 additions & 16 deletions kafka/consumer/subscription_state.py
@@ -68,7 +68,6 @@ def __init__(self, offset_reset_strategy='earliest'):
self._group_subscription = set()
self._user_assignment = set()
self.assignment = dict()
-        self.needs_partition_assignment = False
self.listener = None

# initialize to true for the consumers to fetch offset upon starting up
@@ -172,7 +171,6 @@ def change_subscription(self, topics):
log.info('Updating subscribed topics to: %s', topics)
self.subscription = set(topics)
self._group_subscription.update(topics)
-        self.needs_partition_assignment = True

# Remove any assigned partitions which are no longer subscribed to
for tp in set(self.assignment.keys()):
@@ -192,12 +190,12 @@ def group_subscribe(self, topics):
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
self._group_subscription.update(topics)

-    def mark_for_reassignment(self):
+    def reset_group_subscription(self):
+        """Reset the group's subscription to only contain topics subscribed by this consumer."""
        if self._user_assignment:
            raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
        assert self.subscription is not None, 'Subscription required'
        self._group_subscription.intersection_update(self.subscription)
-        self.needs_partition_assignment = True

def assign_from_user(self, partitions):
"""Manually assign a list of TopicPartitions to this consumer.
@@ -220,18 +218,17 @@ def assign_from_user(self, partitions):
if self.subscription is not None:
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)

-        self._user_assignment.clear()
-        self._user_assignment.update(partitions)
+        if self._user_assignment != set(partitions):
+            self._user_assignment = set(partitions)

-        for partition in partitions:
-            if partition not in self.assignment:
-                self._add_assigned_partition(partition)
+            for partition in partitions:
+                if partition not in self.assignment:
+                    self._add_assigned_partition(partition)

-        for tp in set(self.assignment.keys()) - self._user_assignment:
-            del self.assignment[tp]
+            for tp in set(self.assignment.keys()) - self._user_assignment:
+                del self.assignment[tp]

-        self.needs_partition_assignment = False
-        self.needs_fetch_committed_offsets = True
+            self.needs_fetch_committed_offsets = True

def assign_from_subscribed(self, assignments):
"""Update the assignment to the specified partitions
@@ -245,24 +242,25 @@ def assign_from_subscribed(self, assignments):
assignments (list of TopicPartition): partitions to assign to this
consumer instance.
"""
-        if self.subscription is None:
+        if not self.partitions_auto_assigned():
raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)

for tp in assignments:
if tp.topic not in self.subscription:
raise ValueError("Assigned partition %s for non-subscribed topic." % str(tp))

+        # after rebalancing, we always reinitialize the assignment state
+        self.assignment.clear()
for tp in assignments:
self._add_assigned_partition(tp)
-        self.needs_partition_assignment = False
self.needs_fetch_committed_offsets = True
log.info("Updated partition assignment: %s", assignments)

def unsubscribe(self):
"""Clear all topic subscriptions and partition assignments"""
self.subscription = None
self._user_assignment.clear()
self.assignment.clear()
-        self.needs_partition_assignment = True
self.subscribed_pattern = None

def group_subscription(self):
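A key behavioral change in assign_from_user() above is that state is only rebuilt when the requested partition set actually differs from the current one, so repeated assign() calls with the same partitions no longer reset anything or re-trigger a committed-offset fetch. A hypothetical sketch of that idempotency (ToySubscriptionState and the tuple partitions are illustrative only):

```python
# Toy sketch: assignment state is only reset when the partition set changes.

class ToySubscriptionState(object):
    def __init__(self):
        self._user_assignment = set()
        self.assignment = {}
        self.needs_fetch_committed_offsets = False

    def assign_from_user(self, partitions):
        if self._user_assignment != set(partitions):
            self._user_assignment = set(partitions)
            self.assignment = {tp: object() for tp in partitions}
            self.needs_fetch_committed_offsets = True


state = ToySubscriptionState()
state.assign_from_user([('topic', 0), ('topic', 1)])
assert state.needs_fetch_committed_offsets          # first assignment resets state
state.needs_fetch_committed_offsets = False
state.assign_from_user([('topic', 0), ('topic', 1)])  # same set: no-op
assert not state.needs_fetch_committed_offsets
```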
24 changes: 13 additions & 11 deletions kafka/coordinator/base.py
@@ -344,23 +344,25 @@ def _handle_join_failure(self, _):
def ensure_active_group(self):
"""Ensure that the group is active (i.e. joined and synced)"""
with self._lock:
-            if not self.need_rejoin():
-                return
-
-            # call on_join_prepare if needed. We set a flag to make sure that
-            # we do not call it a second time if the client is woken up before
-            # a pending rebalance completes.
-            if not self.rejoining:
-                self._on_join_prepare(self._generation.generation_id,
-                                      self._generation.member_id)
-                self.rejoining = True
-
if self._heartbeat_thread is None:
self._start_heartbeat_thread()

while self.need_rejoin():
self.ensure_coordinator_ready()

+                # call on_join_prepare if needed. We set a flag
+                # to make sure that we do not call it a second
+                # time if the client is woken up before a pending
+                # rebalance completes. This must be called on each
+                # iteration of the loop because an event requiring
+                # a rebalance (such as a metadata refresh which
+                # changes the matched subscription set) can occur
+                # while another rebalance is still in progress.
+                if not self.rejoining:
+                    self._on_join_prepare(self._generation.generation_id,
+                                          self._generation.member_id)
+                    self.rejoining = True
+
# ensure that there are no pending requests to the coordinator.
# This is important in particular to avoid resending a pending
# JoinGroup request.
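The point of moving the prepare step inside the while-loop above is that a single call to ensure_active_group() may now have to run several rebalance cycles back to back, and the revoke/prepare callback has to run once per cycle. A toy sketch of that loop shape (ToyGroupMember and its counters are hypothetical, not the coordinator's real API):

```python
# Toy sketch: on_join_prepare runs on each iteration of the rejoin loop,
# so a rebalance triggered while another rebalance is in flight still gets
# its own prepare step.

class ToyGroupMember(object):
    def __init__(self, cycles_needed):
        self._cycles_needed = cycles_needed
        self.rejoining = False
        self.prepare_calls = 0

    def need_rejoin(self):
        return self._cycles_needed > 0

    def _on_join_prepare(self):
        self.prepare_calls += 1

    def _join_group(self):
        self._cycles_needed -= 1
        self.rejoining = False      # cleared after each completed join

    def ensure_active_group(self):
        while self.need_rejoin():
            if not self.rejoining:
                self._on_join_prepare()
                self.rejoining = True
            self._join_group()


member = ToyGroupMember(cycles_needed=2)
member.ensure_active_group()
assert member.prepare_calls == 2    # prepare ran once per rebalance cycle
```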
102 changes: 63 additions & 39 deletions kafka/coordinator/consumer.py
@@ -84,6 +84,8 @@ def __init__(self, client, subscription, metrics, **configs):
self.config[key] = configs[key]

self._subscription = subscription
+        self._is_leader = False
+        self._joined_subscription = set()
self._metadata_snapshot = self._build_metadata_snapshot(subscription, client.cluster)
self._assignment_snapshot = None
self._cluster = client.cluster
@@ -132,11 +134,22 @@ def protocol_type(self):

def group_protocols(self):
"""Returns list of preferred (protocols, metadata)"""
-        topics = self._subscription.subscription
-        assert topics is not None, 'Consumer has not subscribed to topics'
+        if self._subscription.subscription is None:
+            raise Errors.IllegalStateError('Consumer has not subscribed to topics')
+        # dpkp note: I really dislike this.
+        # why? because we are using this strange method group_protocols,
+        # which is seemingly innocuous, to set internal state (_joined_subscription)
+        # that is later used to check whether metadata has changed since we joined a group
+        # but there is no guarantee that this method, group_protocols, will get called
+        # in the correct sequence or that it will only be called when we want it to be.
+        # So this really should be moved elsewhere, but I don't have the energy to
+        # work that out right now. If you read this at some later date after the mutable
+        # state has bitten you... I'm sorry! It mimics the java client, and that's the
+        # best I've got for now.
+        self._joined_subscription = set(self._subscription.subscription)
Contributor:
😃 loved this comment

Collaborator:
Yea, I know. Hate it, as the method is basically a getter. They put it there just because it should be in ConsumerCoordinator but not in AbstractCoordinator, and only this method was overridable... Basically a big hack. BTW the method is called metadata() nowadays. https://github.com/apache/kafka/blob/1.0/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java#L148

metadata_list = []
for assignor in self.config['assignors']:
-            metadata = assignor.metadata(topics)
+            metadata = assignor.metadata(self._joined_subscription)
group_protocol = (assignor.name, metadata)
metadata_list.append(group_protocol)
return metadata_list
@@ -158,21 +171,29 @@ def _handle_metadata_update(self, cluster):

# check if there are any changes to the metadata which should trigger
# a rebalance
-        if self._subscription_metadata_changed(cluster):
-
-            if (self.config['api_version'] >= (0, 9)
-                and self.config['group_id'] is not None):
-
-                self._subscription.mark_for_reassignment()
-
-            # If we haven't got group coordinator support,
-            # just assign all partitions locally
-            else:
-                self._subscription.assign_from_subscribed([
-                    TopicPartition(topic, partition)
-                    for topic in self._subscription.subscription
-                    for partition in self._metadata_snapshot[topic]
-                ])
+        if self._subscription.partitions_auto_assigned():
+            metadata_snapshot = self._build_metadata_snapshot(self._subscription, cluster)
+            if self._metadata_snapshot != metadata_snapshot:
+                self._metadata_snapshot = metadata_snapshot
+
+                # If we haven't got group coordinator support,
+                # just assign all partitions locally
+                if self._auto_assign_all_partitions():
+                    self._subscription.assign_from_subscribed([
+                        TopicPartition(topic, partition)
+                        for topic in self._subscription.subscription
+                        for partition in self._metadata_snapshot[topic]
+                    ])
+
+    def _auto_assign_all_partitions(self):
+        # For users that use "subscribe" without group support,
+        # we will simply assign all partitions to this consumer
+        if self.config['api_version'] < (0, 9):
+            return True
+        elif self.config['group_id'] is None:
+            return True
+        else:
+            return False

def _build_metadata_snapshot(self, subscription, cluster):
metadata_snapshot = {}
@@ -181,16 +202,6 @@ def _build_metadata_snapshot(self, subscription, cluster):
metadata_snapshot[topic] = set(partitions)
return metadata_snapshot

-    def _subscription_metadata_changed(self, cluster):
-        if not self._subscription.partitions_auto_assigned():
-            return False
-
-        metadata_snapshot = self._build_metadata_snapshot(self._subscription, cluster)
-        if self._metadata_snapshot != metadata_snapshot:
-            self._metadata_snapshot = metadata_snapshot
-            return True
-        return False

def _lookup_assignor(self, name):
for assignor in self.config['assignors']:
if assignor.name == name:
@@ -199,12 +210,10 @@ def _lookup_assignor(self, name):

def _on_join_complete(self, generation, member_id, protocol,
member_assignment_bytes):
-        # if we were the assignor, then we need to make sure that there have
-        # been no metadata updates since the rebalance begin. Otherwise, we
-        # won't rebalance again until the next metadata change
-        if self._assignment_snapshot is not None and self._assignment_snapshot != self._metadata_snapshot:
-            self._subscription.mark_for_reassignment()
-            return
+        # only the leader is responsible for monitoring for metadata changes
+        # (i.e. partition changes)
+        if not self._is_leader:
+            self._assignment_snapshot = None

assignor = self._lookup_assignor(protocol)
assert assignor, 'Coordinator selected invalid assignment protocol: %s' % protocol
@@ -307,6 +316,7 @@ def _perform_assignment(self, leader_id, assignment_strategy, members):
# keep track of the metadata used for assignment so that we can check
# after rebalance completion whether anything has changed
self._cluster.request_update()
+        self._is_leader = True
self._assignment_snapshot = self._metadata_snapshot

log.debug("Performing assignment for group %s using strategy %s"
@@ -338,18 +348,32 @@ def _on_join_prepare(self, generation, member_id):
" for group %s failed on_partitions_revoked",
self._subscription.listener, self.group_id)

-        self._assignment_snapshot = None
-        self._subscription.mark_for_reassignment()
+        self._is_leader = False
+        self._subscription.reset_group_subscription()

def need_rejoin(self):
"""Check whether the group should be rejoined

Returns:
bool: True if consumer should rejoin group, False otherwise
"""
-        return (self._subscription.partitions_auto_assigned() and
-                (super(ConsumerCoordinator, self).need_rejoin() or
-                 self._subscription.needs_partition_assignment))
+        if not self._subscription.partitions_auto_assigned():
+            return False
+
+        if self._auto_assign_all_partitions():
+            return False
+
+        # we need to rejoin if we performed the assignment and metadata has changed
+        if (self._assignment_snapshot is not None
+            and self._assignment_snapshot != self._metadata_snapshot):
+            return True
+
+        # we need to join if our subscription has changed since the last join
+        if (self._joined_subscription is not None
+            and self._joined_subscription != self._subscription.subscription):
+            return True
+
+        return super(ConsumerCoordinator, self).need_rejoin()

def refresh_committed_offsets_if_needed(self):
"""Fetch committed offsets for assigned partitions."""
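The rewritten need_rejoin() above replaces the shared needs_partition_assignment flag, which a racing rebalance could clear prematurely, with comparisons against snapshots taken at join/assignment time. A hypothetical sketch of that comparison logic (ToyConsumerCoordinator and its plain-set topics are illustrative only, not the coordinator's real API):

```python
# Toy sketch: rejoin when the leader's assignment snapshot no longer matches
# current metadata, or when this consumer's subscription changed since the
# last join.

class ToyConsumerCoordinator(object):
    def __init__(self):
        self._assignment_snapshot = None     # leader-only metadata snapshot
        self._metadata_snapshot = {}
        self._joined_subscription = None     # topics used for the last join
        self.subscription = set()

    def need_rejoin(self):
        # leader saw a partition change since it computed the assignment
        if (self._assignment_snapshot is not None
                and self._assignment_snapshot != self._metadata_snapshot):
            return True
        # this consumer's subscription changed since it last joined
        if (self._joined_subscription is not None
                and self._joined_subscription != self.subscription):
            return True
        return False


coord = ToyConsumerCoordinator()
coord.subscription = {'topic-a'}
coord._joined_subscription = {'topic-a'}
assert not coord.need_rejoin()
coord.subscription = {'topic-a', 'topic-b'}   # subscription changed mid-flight
assert coord.need_rejoin()                    # forces a fresh JoinGroup
```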