Skip to content

Commit c6096d8

Browse files
authored
[Utils] Add a batch policy utility (#12430)
1 parent 27fc564 commit c6096d8

File tree

2 files changed

+314
-0
lines changed

2 files changed

+314
-0
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import copy
2+
import time
3+
from typing import Generic, List, Optional, TypeVar, overload
4+
5+
from pydantic import Field
6+
from pydantic.dataclasses import dataclass
7+
8+
T = TypeVar("T")
9+
10+
# alias to signify whether a batch policy has been triggered
11+
BatchPolicyTriggered = bool
12+
13+
14+
# TODO: Add batching on bytes as well.
15+
@dataclass
16+
class Batcher(Generic[T]):
17+
"""
18+
A utility for collecting items into batches and flushing them when one or more batch policy conditions are met.
19+
20+
The batch policy can be created to trigger on:
21+
- max_count: Maximum number of items added
22+
- max_window: Maximum time window (in seconds)
23+
24+
If no limits are specified, the batcher is always in triggered state.
25+
26+
Example usage:
27+
28+
import time
29+
30+
# Triggers when 2 (or more) items are added
31+
batcher = Batcher(max_count=2)
32+
assert batcher.add(["item1", "item2", "item3"])
33+
assert batcher.flush() == ["item1", "item2", "item3"]
34+
35+
# Triggers partially when 2 (or more) items are added
36+
batcher = Batcher(max_count=2)
37+
assert batcher.add(["item1", "item2", "item3"])
38+
assert batcher.flush(partial=True) == ["item1", "item2"]
39+
assert batcher.add("item4")
40+
assert batcher.flush(partial=True) == ["item3", "item4"]
41+
42+
# Trigger 2 seconds after the first add
43+
batcher = Batcher(max_window=2.0)
44+
assert not batcher.add(["item1", "item2", "item3"])
45+
time.sleep(2.1)
46+
assert not batcher.add(["item4"])
47+
assert batcher.flush() == ["item1", "item2", "item3", "item4"]
48+
"""
49+
50+
max_count: Optional[int] = Field(default=None, description="Maximum number of items", ge=0)
51+
max_window: Optional[float] = Field(
52+
default=None, description="Maximum time window in seconds", ge=0
53+
)
54+
55+
_triggered: bool = Field(default=False, init=False)
56+
_last_batch_time: float = Field(default_factory=time.monotonic, init=False)
57+
_batch: list[T] = Field(default_factory=list, init=False)
58+
59+
@property
60+
def period(self) -> float:
61+
return time.monotonic() - self._last_batch_time
62+
63+
def _check_batch_policy(self) -> bool:
64+
"""Check if any batch policy conditions are met"""
65+
if self.max_count is not None and len(self._batch) >= self.max_count:
66+
self._triggered = True
67+
elif self.max_window is not None and self.period >= self.max_window:
68+
self._triggered = True
69+
elif not self.max_count and not self.max_window:
70+
# always return true
71+
self._triggered = True
72+
73+
return self._triggered
74+
75+
@overload
76+
def add(self, item: T, *, deep_copy: bool = False) -> BatchPolicyTriggered: ...
77+
78+
@overload
79+
def add(self, items: List[T], *, deep_copy: bool = False) -> BatchPolicyTriggered: ...
80+
81+
def add(self, item_or_items: T | list[T], *, deep_copy: bool = False) -> BatchPolicyTriggered:
82+
"""
83+
Add an item or list of items to the collected batch.
84+
85+
Returns:
86+
BatchPolicyTriggered: True if the batch policy was triggered during addition, False otherwise.
87+
"""
88+
if deep_copy:
89+
item_or_items = copy.deepcopy(item_or_items)
90+
91+
if isinstance(item_or_items, list):
92+
self._batch.extend(item_or_items)
93+
else:
94+
self._batch.append(item_or_items)
95+
96+
# Check if the last addition triggered the batch policy
97+
return self.is_triggered()
98+
99+
def flush(self, *, partial=False) -> list[T]:
100+
result = []
101+
if not partial or not self.max_count:
102+
result = self._batch.copy()
103+
self._batch.clear()
104+
else:
105+
batch_size = min(self.max_count, len(self._batch))
106+
result = self._batch[:batch_size].copy()
107+
self._batch = self._batch[batch_size:]
108+
109+
self._last_batch_time = time.monotonic()
110+
self._triggered = False
111+
self._check_batch_policy()
112+
113+
return result
114+
115+
def duration_until_next_batch(self) -> float:
116+
if not self.max_window:
117+
return -1
118+
return max(self.max_window - self.period, -1)
119+
120+
def get_current_size(self) -> int:
121+
return len(self._batch)
122+
123+
def is_triggered(self):
124+
return self._triggered or self._check_batch_policy()

tests/unit/utils/test_batch_policy.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import time
2+
3+
import pytest
4+
5+
from localstack.utils.batch_policy import Batcher
6+
7+
8+
class SimpleItem:
9+
def __init__(self, number=10):
10+
self.number = number
11+
12+
13+
class TestBatcher:
14+
def test_add_single_item(self):
15+
batcher = Batcher(max_count=2)
16+
17+
assert not batcher.add("item1")
18+
assert batcher.get_current_size() == 1
19+
assert not batcher.is_triggered()
20+
21+
assert batcher.add("item2")
22+
assert batcher.is_triggered()
23+
24+
result = batcher.flush()
25+
assert result == ["item1", "item2"]
26+
assert batcher.get_current_size() == 0
27+
28+
def test_add_multple_items(self):
29+
batcher = Batcher(max_count=3)
30+
31+
assert not batcher.add(["item1", "item2"])
32+
assert batcher.get_current_size() == 2
33+
assert not batcher.is_triggered()
34+
35+
assert batcher.add(["item3", "item4"]) # exceeds max_count
36+
assert batcher.is_triggered()
37+
assert batcher.get_current_size() == 4
38+
39+
result = batcher.flush()
40+
assert result == ["item1", "item2", "item3", "item4"]
41+
assert batcher.get_current_size() == 0
42+
43+
assert batcher.add(["item1", "item2", "item3", "item4"])
44+
assert batcher.flush() == ["item1", "item2", "item3", "item4"]
45+
assert not batcher.is_triggered()
46+
47+
def test_max_count_limit(self):
48+
batcher = Batcher(max_count=3)
49+
50+
assert not batcher.add("item1")
51+
assert not batcher.add("item2")
52+
assert batcher.add("item3")
53+
54+
assert batcher.is_triggered()
55+
assert batcher.get_current_size() == 3
56+
57+
result = batcher.flush()
58+
assert result == ["item1", "item2", "item3"]
59+
assert batcher.get_current_size() == 0
60+
61+
assert not batcher.add("item4")
62+
assert not batcher.add("item5")
63+
assert batcher.get_current_size() == 2
64+
65+
def test_max_window_limit(self):
66+
max_window = 0.5
67+
batcher = Batcher(max_window=max_window)
68+
69+
assert not batcher.add("item1")
70+
assert batcher.get_current_size() == 1
71+
assert not batcher.is_triggered()
72+
73+
assert not batcher.add("item2")
74+
assert batcher.get_current_size() == 2
75+
assert not batcher.is_triggered()
76+
77+
time.sleep(max_window + 0.1)
78+
79+
assert batcher.add("item3")
80+
assert batcher.is_triggered()
81+
assert batcher.get_current_size() == 3
82+
83+
result = batcher.flush()
84+
assert result == ["item1", "item2", "item3"]
85+
assert batcher.get_current_size() == 0
86+
87+
def test_multiple_policies(self):
88+
batcher = Batcher(max_count=5, max_window=2.0)
89+
90+
item1 = SimpleItem(1)
91+
for _ in range(5):
92+
batcher.add(item1)
93+
assert batcher.is_triggered()
94+
95+
result = batcher.flush()
96+
assert result == [item1, item1, item1, item1, item1]
97+
assert batcher.get_current_size() == 0
98+
99+
batcher.add(item1)
100+
assert not batcher.is_triggered()
101+
102+
item2 = SimpleItem(10)
103+
104+
time.sleep(2.1)
105+
batcher.add(item2)
106+
assert batcher.is_triggered()
107+
108+
result = batcher.flush()
109+
assert result == [item1, item2]
110+
111+
def test_flush(self):
112+
batcher = Batcher(max_count=10)
113+
114+
batcher.add("item1")
115+
batcher.add("item2")
116+
batcher.add("item3")
117+
118+
result = batcher.flush()
119+
assert result == ["item1", "item2", "item3"]
120+
assert batcher.get_current_size() == 0
121+
122+
batcher.add("item4")
123+
result = batcher.flush()
124+
assert result == ["item4"]
125+
assert batcher.get_current_size() == 0
126+
127+
@pytest.mark.parametrize(
128+
"max_count,max_window",
129+
[(0, 10), (10, 0), (None, None)],
130+
)
131+
def test_no_limits(self, max_count, max_window):
132+
if max_count or max_window:
133+
batcher = Batcher(max_count=max_count, max_window=max_window)
134+
else:
135+
batcher = Batcher()
136+
137+
assert batcher.is_triggered() # no limit always returns true
138+
139+
assert batcher.add("item1")
140+
assert batcher.get_current_size() == 1
141+
assert batcher.is_triggered()
142+
143+
assert batcher.add(["item2", "item3"])
144+
assert batcher.get_current_size() == 3
145+
assert batcher.is_triggered()
146+
147+
result = batcher.flush()
148+
assert result == ["item1", "item2", "item3"]
149+
assert batcher.get_current_size() == 0
150+
151+
def test_triggered_state(self):
152+
batcher = Batcher(max_count=2)
153+
154+
assert not batcher.add("item1")
155+
assert not batcher.is_triggered()
156+
157+
assert batcher.add("item2")
158+
assert batcher.is_triggered()
159+
160+
assert batcher.add("item3")
161+
assert batcher.flush() == ["item1", "item2", "item3"]
162+
assert batcher.get_current_size() == 0
163+
assert not batcher.is_triggered()
164+
165+
def test_max_count_partial_flush(self):
166+
batcher = Batcher(max_count=2)
167+
168+
assert batcher.add(["item1", "item2", "item3", "item4"])
169+
assert batcher.is_triggered()
170+
171+
assert batcher.flush(partial=True) == ["item1", "item2"]
172+
assert batcher.get_current_size() == 2
173+
174+
assert batcher.flush(partial=True) == ["item3", "item4"]
175+
assert not batcher.is_triggered() # early flush
176+
177+
assert batcher.flush() == []
178+
assert batcher.get_current_size() == 0
179+
assert not batcher.is_triggered()
180+
181+
def test_deep_copy(self):
182+
original = {"key": "value"}
183+
batcher = Batcher(max_count=2)
184+
185+
batcher.add(original, deep_copy=True)
186+
187+
original["key"] = "modified"
188+
189+
batch = batcher.flush()
190+
assert batch[0]["key"] == "value"

0 commit comments

Comments
 (0)