
[Utils] Add a batch policy utility #12430

Merged
merged 3 commits on Mar 25, 2025
124 changes: 124 additions & 0 deletions localstack-core/localstack/utils/batch_policy.py
@@ -0,0 +1,124 @@
import copy
import time
from typing import Generic, List, Optional, TypeVar, overload

from pydantic import Field
from pydantic.dataclasses import dataclass

T = TypeVar("T")

# alias to signify whether a batch policy has been triggered
BatchPolicyTriggered = bool


# TODO: Add batching on bytes as well.
@dataclass
class Batcher(Generic[T]):
"""
A utility for collecting items into batches and flushing them when one or more batch policy conditions are met.

The batch policy can be created to trigger on:
- max_count: Maximum number of items added
- max_window: Maximum time window (in seconds)

If no limits are specified, the batcher is always in triggered state.

Example usage:

import time

# Triggers when 2 (or more) items are added
batcher = Batcher(max_count=2)
assert batcher.add(["item1", "item2", "item3"])
assert batcher.flush() == ["item1", "item2", "item3"]

# Triggers when 2 (or more) items are added; flush(partial=True) returns at most max_count items
batcher = Batcher(max_count=2)
assert batcher.add(["item1", "item2", "item3"])
assert batcher.flush(partial=True) == ["item1", "item2"]
assert batcher.add("item4")
assert batcher.flush(partial=True) == ["item3", "item4"]

# Triggers once 2 seconds have elapsed since creation (or the last flush)
batcher = Batcher(max_window=2.0)
assert not batcher.add(["item1", "item2", "item3"])
time.sleep(2.1)
assert batcher.add(["item4"])
assert batcher.flush() == ["item1", "item2", "item3", "item4"]
"""

max_count: Optional[int] = Field(default=None, description="Maximum number of items", ge=0)
max_window: Optional[float] = Field(
default=None, description="Maximum time window in seconds", ge=0
)

_triggered: bool = Field(default=False, init=False)
_last_batch_time: float = Field(default_factory=time.monotonic, init=False)
_batch: list[T] = Field(default_factory=list, init=False)

@property
def period(self) -> float:
return time.monotonic() - self._last_batch_time

def _check_batch_policy(self) -> bool:
"""Check if any batch policy conditions are met"""
if self.max_count is not None and len(self._batch) >= self.max_count:
self._triggered = True
elif self.max_window is not None and self.period >= self.max_window:
self._triggered = True
elif not self.max_count and not self.max_window:
# no limits configured, so the batcher is always triggered
self._triggered = True

return self._triggered

@overload
def add(self, item: T, *, deep_copy: bool = False) -> BatchPolicyTriggered: ...

@overload
def add(self, items: List[T], *, deep_copy: bool = False) -> BatchPolicyTriggered: ...

def add(self, item_or_items: T | list[T], *, deep_copy: bool = False) -> BatchPolicyTriggered:
"""
Add an item or list of items to the collected batch.

Returns:
BatchPolicyTriggered: True if the batch policy was triggered during addition, False otherwise.
"""
if deep_copy:
item_or_items = copy.deepcopy(item_or_items)

if isinstance(item_or_items, list):
self._batch.extend(item_or_items)
else:
self._batch.append(item_or_items)

# Check if the last addition triggered the batch policy
return self.is_triggered()

def flush(self, *, partial: bool = False) -> list[T]:
result = []
if not partial or not self.max_count:
result = self._batch.copy()
self._batch.clear()
else:
batch_size = min(self.max_count, len(self._batch))
result = self._batch[:batch_size].copy()
self._batch = self._batch[batch_size:]

self._last_batch_time = time.monotonic()
self._triggered = False
self._check_batch_policy()

return result

def duration_until_next_batch(self) -> float:
if not self.max_window:
return -1
return max(self.max_window - self.period, -1)

def get_current_size(self) -> int:
return len(self._batch)

def is_triggered(self) -> bool:
return self._triggered or self._check_batch_policy()
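
For reviewers, a minimal usage sketch of how the new Batcher might be driven from a polling loop, flushing whenever the count or time policy triggers. The send_downstream sink, the source object, and the loop itself are hypothetical illustrations and not part of this PR; only Batcher, add, flush, get_current_size, and duration_until_next_batch come from the module above.

import time

from localstack.utils.batch_policy import Batcher


def send_downstream(records: list[str]) -> None:
    # Hypothetical sink; a real caller would forward the batch somewhere.
    print(f"sending {len(records)} records")


def poll_forever(source) -> None:
    # Flush after 10 records or 5 seconds, whichever comes first.
    batcher: Batcher[str] = Batcher(max_count=10, max_window=5.0)
    while True:
        record = source.get()  # hypothetical non-blocking source, may return None
        if record is not None and batcher.add(record):
            send_downstream(batcher.flush())
            continue
        # Sleep at most until the time-window policy would trigger.
        wait = batcher.duration_until_next_batch()
        if wait <= 0:
            if batcher.get_current_size() > 0:
                send_downstream(batcher.flush())
            else:
                # Nothing collected; flushing the empty batch resets the window.
                batcher.flush()
        else:
            time.sleep(min(wait, 0.1))
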
190 changes: 190 additions & 0 deletions tests/unit/utils/test_batch_policy.py
@@ -0,0 +1,190 @@
import time

import pytest

from localstack.utils.batch_policy import Batcher


class SimpleItem:
def __init__(self, number=10):
self.number = number


class TestBatcher:
def test_add_single_item(self):
batcher = Batcher(max_count=2)

assert not batcher.add("item1")
assert batcher.get_current_size() == 1
assert not batcher.is_triggered()

assert batcher.add("item2")
assert batcher.is_triggered()

result = batcher.flush()
assert result == ["item1", "item2"]
assert batcher.get_current_size() == 0

def test_add_multiple_items(self):
batcher = Batcher(max_count=3)

assert not batcher.add(["item1", "item2"])
assert batcher.get_current_size() == 2
assert not batcher.is_triggered()

assert batcher.add(["item3", "item4"]) # exceeds max_count
assert batcher.is_triggered()
assert batcher.get_current_size() == 4

result = batcher.flush()
assert result == ["item1", "item2", "item3", "item4"]
assert batcher.get_current_size() == 0

assert batcher.add(["item1", "item2", "item3", "item4"])
assert batcher.flush() == ["item1", "item2", "item3", "item4"]
assert not batcher.is_triggered()

def test_max_count_limit(self):
batcher = Batcher(max_count=3)

assert not batcher.add("item1")
assert not batcher.add("item2")
assert batcher.add("item3")

assert batcher.is_triggered()
assert batcher.get_current_size() == 3

result = batcher.flush()
assert result == ["item1", "item2", "item3"]
assert batcher.get_current_size() == 0

assert not batcher.add("item4")
assert not batcher.add("item5")
assert batcher.get_current_size() == 2

def test_max_window_limit(self):
max_window = 0.5
batcher = Batcher(max_window=max_window)

assert not batcher.add("item1")
assert batcher.get_current_size() == 1
assert not batcher.is_triggered()

assert not batcher.add("item2")
assert batcher.get_current_size() == 2
assert not batcher.is_triggered()

time.sleep(max_window + 0.1)

assert batcher.add("item3")
assert batcher.is_triggered()
assert batcher.get_current_size() == 3

result = batcher.flush()
assert result == ["item1", "item2", "item3"]
assert batcher.get_current_size() == 0

def test_multiple_policies(self):
batcher = Batcher(max_count=5, max_window=2.0)

item1 = SimpleItem(1)
for _ in range(5):
batcher.add(item1)
assert batcher.is_triggered()

result = batcher.flush()
assert result == [item1, item1, item1, item1, item1]
assert batcher.get_current_size() == 0

batcher.add(item1)
assert not batcher.is_triggered()

item2 = SimpleItem(10)

time.sleep(2.1)
batcher.add(item2)
assert batcher.is_triggered()

result = batcher.flush()
assert result == [item1, item2]

def test_flush(self):
batcher = Batcher(max_count=10)

batcher.add("item1")
batcher.add("item2")
batcher.add("item3")

result = batcher.flush()
assert result == ["item1", "item2", "item3"]
assert batcher.get_current_size() == 0

batcher.add("item4")
result = batcher.flush()
assert result == ["item4"]
assert batcher.get_current_size() == 0

@pytest.mark.parametrize(
"max_count,max_window",
[(0, 10), (10, 0), (None, None)],
)
def test_no_limits(self, max_count, max_window):
if max_count or max_window:
batcher = Batcher(max_count=max_count, max_window=max_window)
else:
batcher = Batcher()

assert batcher.is_triggered() # no limit always returns true

assert batcher.add("item1")
assert batcher.get_current_size() == 1
assert batcher.is_triggered()

assert batcher.add(["item2", "item3"])
assert batcher.get_current_size() == 3
assert batcher.is_triggered()

result = batcher.flush()
assert result == ["item1", "item2", "item3"]
assert batcher.get_current_size() == 0

def test_triggered_state(self):
batcher = Batcher(max_count=2)

assert not batcher.add("item1")
assert not batcher.is_triggered()

assert batcher.add("item2")
assert batcher.is_triggered()

assert batcher.add("item3")
assert batcher.flush() == ["item1", "item2", "item3"]
assert batcher.get_current_size() == 0
assert not batcher.is_triggered()

def test_max_count_partial_flush(self):
batcher = Batcher(max_count=2)

assert batcher.add(["item1", "item2", "item3", "item4"])
assert batcher.is_triggered()

assert batcher.flush(partial=True) == ["item1", "item2"]
assert batcher.get_current_size() == 2

assert batcher.flush(partial=True) == ["item3", "item4"]
assert not batcher.is_triggered() # early flush

assert batcher.flush() == []
assert batcher.get_current_size() == 0
assert not batcher.is_triggered()

def test_deep_copy(self):
original = {"key": "value"}
batcher = Batcher(max_count=2)

batcher.add(original, deep_copy=True)

original["key"] = "modified"

batch = batcher.flush()
assert batch[0]["key"] == "value"