Skip to content

Commit c536dd2

Browse files
authored
KIP-54: Implement sticky partition assignment strategy (dpkp#2057)
1 parent cb96a1a commit c536dd2

File tree

9 files changed

+1781
-20
lines changed

9 files changed

+1781
-20
lines changed

kafka/coordinator/assignors/sticky/__init__.py

Whitespace-only changes.
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import logging
2+
from collections import defaultdict, namedtuple
3+
from copy import deepcopy
4+
5+
from kafka.vendor import six
6+
7+
log = logging.getLogger(__name__)
8+
9+
10+
ConsumerPair = namedtuple("ConsumerPair", ["src_member_id", "dst_member_id"])
11+
"""
12+
Represents a pair of Kafka consumer ids involved in a partition reassignment.
13+
Each ConsumerPair corresponds to a particular partition or topic, indicates that the particular partition or some
14+
partition of the particular topic was moved from the source consumer to the destination consumer
15+
during the rebalance. This class helps in determining whether a partition reassignment results in cycles among
16+
the generated graph of consumer pairs.
17+
"""
18+
19+
20+
def is_sublist(source, target):
21+
"""Checks if one list is a sublist of another.
22+
23+
Arguments:
24+
source: the list in which to search for the occurrence of target.
25+
target: the list to search for as a sublist of source
26+
27+
Returns:
28+
true if target is in source; false otherwise
29+
"""
30+
for index in (i for i, e in enumerate(source) if e == target[0]):
31+
if tuple(source[index: index + len(target)]) == target:
32+
return True
33+
return False
34+
35+
36+
class PartitionMovements:
37+
"""
38+
This class maintains some data structures to simplify lookup of partition movements among consumers.
39+
At each point of time during a partition rebalance it keeps track of partition movements
40+
corresponding to each topic, and also possible movement (in form a ConsumerPair object) for each partition.
41+
"""
42+
43+
def __init__(self):
44+
self.partition_movements_by_topic = defaultdict(
45+
lambda: defaultdict(set)
46+
)
47+
self.partition_movements = {}
48+
49+
def move_partition(self, partition, old_consumer, new_consumer):
50+
pair = ConsumerPair(src_member_id=old_consumer, dst_member_id=new_consumer)
51+
if partition in self.partition_movements:
52+
# this partition has previously moved
53+
existing_pair = self._remove_movement_record_of_partition(partition)
54+
assert existing_pair.dst_member_id == old_consumer
55+
if existing_pair.src_member_id != new_consumer:
56+
# the partition is not moving back to its previous consumer
57+
self._add_partition_movement_record(
58+
partition, ConsumerPair(src_member_id=existing_pair.src_member_id, dst_member_id=new_consumer)
59+
)
60+
else:
61+
self._add_partition_movement_record(partition, pair)
62+
63+
def get_partition_to_be_moved(self, partition, old_consumer, new_consumer):
64+
if partition.topic not in self.partition_movements_by_topic:
65+
return partition
66+
if partition in self.partition_movements:
67+
# this partition has previously moved
68+
assert old_consumer == self.partition_movements[partition].dst_member_id
69+
old_consumer = self.partition_movements[partition].src_member_id
70+
reverse_pair = ConsumerPair(src_member_id=new_consumer, dst_member_id=old_consumer)
71+
if reverse_pair not in self.partition_movements_by_topic[partition.topic]:
72+
return partition
73+
74+
return next(iter(self.partition_movements_by_topic[partition.topic][reverse_pair]))
75+
76+
def are_sticky(self):
77+
for topic, movements in six.iteritems(self.partition_movements_by_topic):
78+
movement_pairs = set(movements.keys())
79+
if self._has_cycles(movement_pairs):
80+
log.error(
81+
"Stickiness is violated for topic {}\n"
82+
"Partition movements for this topic occurred among the following consumer pairs:\n"
83+
"{}".format(topic, movement_pairs)
84+
)
85+
return False
86+
return True
87+
88+
def _remove_movement_record_of_partition(self, partition):
89+
pair = self.partition_movements[partition]
90+
del self.partition_movements[partition]
91+
92+
self.partition_movements_by_topic[partition.topic][pair].remove(partition)
93+
if not self.partition_movements_by_topic[partition.topic][pair]:
94+
del self.partition_movements_by_topic[partition.topic][pair]
95+
if not self.partition_movements_by_topic[partition.topic]:
96+
del self.partition_movements_by_topic[partition.topic]
97+
98+
return pair
99+
100+
def _add_partition_movement_record(self, partition, pair):
101+
self.partition_movements[partition] = pair
102+
self.partition_movements_by_topic[partition.topic][pair].add(partition)
103+
104+
def _has_cycles(self, consumer_pairs):
105+
cycles = set()
106+
for pair in consumer_pairs:
107+
reduced_pairs = deepcopy(consumer_pairs)
108+
reduced_pairs.remove(pair)
109+
path = [pair.src_member_id]
110+
if self._is_linked(pair.dst_member_id, pair.src_member_id, reduced_pairs, path) and not self._is_subcycle(
111+
path, cycles
112+
):
113+
cycles.add(tuple(path))
114+
log.error("A cycle of length {} was found: {}".format(len(path) - 1, path))
115+
116+
# for now we want to make sure there is no partition movements of the same topic between a pair of consumers.
117+
# the odds of finding a cycle among more than two consumers seem to be very low (according to various randomized
118+
# tests with the given sticky algorithm) that it should not worth the added complexity of handling those cases.
119+
for cycle in cycles:
120+
if len(cycle) == 3: # indicates a cycle of length 2
121+
return True
122+
return False
123+
124+
@staticmethod
125+
def _is_subcycle(cycle, cycles):
126+
super_cycle = deepcopy(cycle)
127+
super_cycle = super_cycle[:-1]
128+
super_cycle.extend(cycle)
129+
for found_cycle in cycles:
130+
if len(found_cycle) == len(cycle) and is_sublist(super_cycle, found_cycle):
131+
return True
132+
return False
133+
134+
def _is_linked(self, src, dst, pairs, current_path):
135+
if src == dst:
136+
return False
137+
if not pairs:
138+
return False
139+
if ConsumerPair(src, dst) in pairs:
140+
current_path.append(src)
141+
current_path.append(dst)
142+
return True
143+
for pair in pairs:
144+
if pair.src_member_id == src:
145+
reduced_set = deepcopy(pairs)
146+
reduced_set.remove(pair)
147+
current_path.append(pair.src_member_id)
148+
return self._is_linked(pair.dst_member_id, dst, reduced_set, current_path)
149+
return False
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
class SortedSet:
2+
def __init__(self, iterable=None, key=None):
3+
self._key = key if key is not None else lambda x: x
4+
self._set = set(iterable) if iterable is not None else set()
5+
6+
self._cached_last = None
7+
self._cached_first = None
8+
9+
def first(self):
10+
if self._cached_first is not None:
11+
return self._cached_first
12+
13+
first = None
14+
for element in self._set:
15+
if first is None or self._key(first) > self._key(element):
16+
first = element
17+
self._cached_first = first
18+
return first
19+
20+
def last(self):
21+
if self._cached_last is not None:
22+
return self._cached_last
23+
24+
last = None
25+
for element in self._set:
26+
if last is None or self._key(last) < self._key(element):
27+
last = element
28+
self._cached_last = last
29+
return last
30+
31+
def pop_last(self):
32+
value = self.last()
33+
self._set.remove(value)
34+
self._cached_last = None
35+
return value
36+
37+
def add(self, value):
38+
if self._cached_last is not None and self._key(value) > self._key(self._cached_last):
39+
self._cached_last = value
40+
if self._cached_first is not None and self._key(value) < self._key(self._cached_first):
41+
self._cached_first = value
42+
43+
return self._set.add(value)
44+
45+
def remove(self, value):
46+
if self._cached_last is not None and self._cached_last == value:
47+
self._cached_last = None
48+
if self._cached_first is not None and self._cached_first == value:
49+
self._cached_first = None
50+
51+
return self._set.remove(value)
52+
53+
def __contains__(self, value):
54+
return value in self._set
55+
56+
def __iter__(self):
57+
return iter(sorted(self._set, key=self._key))
58+
59+
def _bool(self):
60+
return len(self._set) != 0
61+
62+
__nonzero__ = _bool
63+
__bool__ = _bool

0 commit comments

Comments
 (0)