
Commit 05becff

[ESM] Add backoff between stream poller retries
1 parent 6a8b01f commit 05becff
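
In outline, the change wraps the per-shard retry loop in an ExponentialBackoff helper and a threading.Event: failed attempts now sleep for an increasing interval before retrying, and the same event that signals shutdown also cuts any in-progress sleep short. A minimal, self-contained sketch of that pattern follows; the SimpleExponentialBackoff class, process_batch callable, and max_retries value are illustrative stand-ins, not LocalStack's actual utilities.

    # Sketch only: toy backoff helper and retry loop, not the LocalStack API.
    import random
    import threading


    class SimpleExponentialBackoff:
        """Doubles the wait on each retry, with a little jitter and an upper bound."""

        def __init__(self, initial: float = 0.5, maximum: float = 30.0):
            self.initial = initial
            self.maximum = maximum
            self._attempt = 0

        def next_backoff(self) -> float:
            delay = min(self.initial * (2 ** self._attempt), self.maximum)
            self._attempt += 1
            return delay + random.uniform(0, delay / 10)  # small jitter

        def reset(self) -> None:
            self._attempt = 0


    def retry_with_backoff(process_batch, events, shutdown: threading.Event, max_retries: int = 5) -> bool:
        boff = SimpleExponentialBackoff()
        attempts = 0
        while attempts <= max_retries and not shutdown.is_set():
            try:
                if attempts > 0:
                    # Event.wait(timeout) sleeps for the backoff interval, but returns
                    # immediately if shutdown.set() is called from another thread.
                    shutdown.wait(boff.next_backoff())
                process_batch(events)
                boff.reset()
                return True
            except Exception:
                attempts += 1
        return False


    shutdown = threading.Event()
    ok = retry_with_backoff(lambda batch: print("processed", batch), ["rec-1"], shutdown)

Calling shutdown.set() from another thread (the poller's close() in the diff below) makes Event.wait() return immediately and the loop condition fail on the next check, so a poller can stop promptly even while it is backing off.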

File tree

1 file changed: +47 -7 lines changed
  • localstack-core/localstack/services/lambda_/event_source_mapping/pollers


localstack-core/localstack/services/lambda_/event_source_mapping/pollers/stream_poller.py

Lines changed: 47 additions & 7 deletions
@@ -1,5 +1,6 @@
 import json
 import logging
+import threading
 from abc import abstractmethod
 from datetime import datetime
 from typing import Iterator
@@ -28,6 +29,7 @@
 )
 from localstack.services.lambda_.event_source_mapping.pollers.sqs_poller import get_queue_url
 from localstack.utils.aws.arns import parse_arn, s3_bucket_name
+from localstack.utils.backoff import ExponentialBackoff
 from localstack.utils.strings import long_uid
 
 LOG = logging.getLogger(__name__)
@@ -47,6 +49,9 @@ class StreamPoller(Poller):
     # The ARN of the processor (e.g., Pipe ARN)
     partner_resource_arn: str | None
 
+    # Used for backing-off between retries and breaking the retry loop
+    _is_shutdown: threading.Event
+
     def __init__(
         self,
         source_arn: str,
@@ -62,6 +67,8 @@ def __init__(
         self.shards = {}
         self.iterator_over_shards = None
 
+        self._is_shutdown = threading.Event()
+
     @abstractmethod
     def transform_into_events(self, records: list[dict], shard_id) -> list[dict]:
         pass
@@ -104,12 +111,29 @@ def format_datetime(self, time: datetime) -> str:
     def get_sequence_number(self, record: dict) -> str:
         pass
 
+    def close(self):
+        self._is_shutdown.set()
+
     def pre_filter(self, events: list[dict]) -> list[dict]:
         return events
 
     def post_filter(self, events: list[dict]) -> list[dict]:
         return events
 
+    def has_record_expired(self, event: dict):
+        # Check MaximumRecordAgeInSeconds
+        if maximum_record_age_in_seconds := self.stream_parameters.get("MaximumRecordAgeInSeconds"):
+            arrival_timestamp_of_last_event = event.get("approximateArrivalTimestamp")
+            if not arrival_timestamp_of_last_event:
+                return False
+
+            now = get_current_time().timestamp()
+            record_age_in_seconds = now - arrival_timestamp_of_last_event
+            if record_age_in_seconds > maximum_record_age_in_seconds:
+                return True
+
+        return False
+
     def poll_events(self):
         """Generalized poller for streams such as Kinesis or DynamoDB
         Examples of Kinesis consumers:
@@ -146,14 +170,13 @@ def poll_events_from_shard(self, shard_id: str, shard_iterator: str):
         abort_condition = None
         get_records_response = self.get_records(shard_iterator)
         records = get_records_response["Records"]
+        if not records:
+            return
+
         polled_events = self.transform_into_events(records, shard_id)
         # Check MaximumRecordAgeInSeconds
-        if maximum_record_age_in_seconds := self.stream_parameters.get("MaximumRecordAgeInSeconds"):
-            arrival_timestamp_of_last_event = polled_events[-1]["approximateArrivalTimestamp"]
-            now = get_current_time().timestamp()
-            record_age_in_seconds = now - arrival_timestamp_of_last_event
-            if record_age_in_seconds > maximum_record_age_in_seconds:
-                abort_condition = "RecordAgeExpired"
+        if self.has_record_expired(polled_events[-1]):
+            abort_condition = "RecordAgeExpired"
 
         # TODO: implement format detection behavior (e.g., for JSON body):
         # https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-pipes-event-filtering.html
@@ -187,9 +210,26 @@ def poll_events_from_shard(self, shard_id: str, shard_iterator: str):
         # TODO: think about how to avoid starvation of other shards if one shard runs into infinite retries
         attempts = 0
         error_payload = {}
-        while not abort_condition and not self.max_retries_exceeded(attempts):
+
+        boff = ExponentialBackoff(max_retries=attempts)
+        while (
+            not abort_condition
+            and not self.max_retries_exceeded(attempts)
+            and not self._is_shutdown.is_set()
+        ):
             try:
+                if self.has_record_expired(polled_events[-1]):
+                    abort_condition = "RecordAgeExpired"
+
+                if attempts > 0:
+                    # TODO: Should we always backoff (with jitter) before processing since we may not want multiple pollers
+                    # all starting up and polling simultaneously
+                    # For example: 500 persisted ESMs starting up and requesting concurrently could flood gateway
+                    self._is_shutdown.wait(boff.next_backoff())
+
                 self.processor.process_events_batch(events)
+                boff.reset()
+
                 # Update shard iterator if execution is successful
                 self.shards[shard_id] = get_records_response["NextShardIterator"]
                 return
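
For reference, the expiry check that the diff factors out into has_record_expired() boils down to comparing the record's approximateArrivalTimestamp (epoch seconds) against MaximumRecordAgeInSeconds. A small standalone sketch, with an illustrative record and a hypothetical 60-second limit:

    # Sketch only: standalone version of the age check; record shape and limit are illustrative.
    import time


    def has_record_expired(event: dict, maximum_record_age_in_seconds: int | None) -> bool:
        if not maximum_record_age_in_seconds:
            return False
        arrival_timestamp = event.get("approximateArrivalTimestamp")
        if not arrival_timestamp:
            # Records without an arrival timestamp are never treated as expired.
            return False
        return time.time() - arrival_timestamp > maximum_record_age_in_seconds


    old_record = {"approximateArrivalTimestamp": time.time() - 120}
    print(has_record_expired(old_record, 60))    # True: a 120s-old record exceeds the 60s limit
    print(has_record_expired(old_record, None))  # False: no MaximumRecordAgeInSeconds configured

Returning False when no limit or no timestamp is present mirrors the diff above: such records are never aborted as RecordAgeExpired.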
