Skip to content

fix SNS FIFO ordering #12285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions localstack-core/localstack/services/sns/executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import itertools
import logging
import os
import queue
import threading

LOG = logging.getLogger(__name__)


def _worker(work_queue: queue.SimpleQueue):
    """
    Run the items of ``work_queue`` one after the other until a ``None`` sentinel is received.

    Items are processed strictly in the order they were put on the queue. A failing item is
    logged and skipped instead of terminating the loop: a queue is only ever drained by a single
    worker, so a dead worker would leave everything enqueued after the failure unprocessed.

    :param work_queue: queue of objects exposing a ``run()`` method; ``None`` stops the worker
    """
    while True:
        work_item = work_queue.get(block=True)
        if work_item is None:
            return
        try:
            work_item.run()
        except Exception:
            # the item is expected to handle its own errors; this is the last line of defense that
            # keeps the worker (and the ordering guarantee of its queue) alive
            LOG.exception("Exception in worker")
        # delete reference to the work item to avoid it being in memory until the next blocking `queue.get` call returns
        del work_item


class _WorkItem:
    """
    A deferred call ``fn(*args, **kwargs)`` that is executed by a worker thread.

    The result of the call is discarded (fire & forget) and any ``Exception`` it raises is logged
    rather than propagated, so that a failing task never takes its worker down.
    """

    def __init__(self, fn, args, kwargs):
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def run(self):
        """Execute the call, logging (and swallowing) any ``Exception`` it raises."""
        try:
            self.fn(*self.args, **self.kwargs)
        except Exception:
            # not every callable has a `__name__` (e.g. `functools.partial` objects or callable
            # instances): fall back to its repr so that logging the failure cannot itself raise
            fn_name = getattr(self.fn, "__name__", None) or repr(self.fn)
            LOG.exception("Unhandled Exception in while running %s", fn_name)


class TopicPartitionedThreadPoolExecutor:
"""
This topic partition the work between workers based on Topics.
It guarantees that each Topic only has one worker assigned, and thus that the tasks will be executed sequentially.

Loosely based on ThreadPoolExecutor for stdlib, but does not return Future as SNS does not need it (fire&forget)
Could be extended if needed to fit other needs.

Currently, we do not re-balance between workers if some of them have more load. This could be investigated.
"""

# Used to assign unique thread names when thread_name_prefix is not supplied.
_counter = itertools.count().__next__

def __init__(self, max_workers: int = None, thread_name_prefix: str = ""):
if max_workers is None:
max_workers = min(32, (os.cpu_count() or 1) + 4)
if max_workers <= 0:
raise ValueError("max_workers must be greater than 0")

self._max_workers = max_workers
self._thread_name_prefix = (
thread_name_prefix or f"TopicThreadPoolExecutor-{self._counter()}"
)

# for now, the pool isn't fair and is not redistributed depending on load
self._pool = {}
self._shutdown = False
self._lock = threading.Lock()
self._threads = set()
self._work_queues = []
self._cycle = itertools.cycle(range(max_workers))

def _add_worker(self):
work_queue = queue.SimpleQueue()
self._work_queues.append(work_queue)
thread_name = f"{self._thread_name_prefix}_{len(self._threads)}"
t = threading.Thread(name=thread_name, target=_worker, args=(work_queue,))
t.daemon = True
t.start()
self._threads.add(t)

def _get_work_queue(self, topic: str) -> queue.SimpleQueue:
if not (work_queue := self._pool.get(topic)):
if len(self._threads) < self._max_workers:
self._add_worker()

# we cycle through the possible indexes for a work queue, in order to distribute the load across
# once we get to the max amount of worker, the cycle will start back at 0
index = next(self._cycle)
work_queue = self._work_queues[index]

# TODO: the pool is not cleaned up at the moment, think about the clean-up interface
self._pool[topic] = work_queue
return work_queue

def submit(self, fn, topic, /, *args, **kwargs) -> None:
with self._lock:
work_queue = self._get_work_queue(topic)

if self._shutdown:
raise RuntimeError("cannot schedule new futures after shutdown")

w = _WorkItem(fn, args, kwargs)
work_queue.put(w)

def shutdown(self, wait=True):
with self._lock:
self._shutdown = True

# Send a wake-up to prevent threads calling
# _work_queue.get(block=True) from permanently blocking.
for work_queue in self._work_queues:
work_queue.put(None)

if wait:
for t in self._threads:
t.join()
17 changes: 15 additions & 2 deletions localstack-core/localstack/services/sns/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from localstack.config import external_service_url
from localstack.services.sns import constants as sns_constants
from localstack.services.sns.certificate import SNS_SERVER_PRIVATE_KEY
from localstack.services.sns.executor import TopicPartitionedThreadPoolExecutor
from localstack.services.sns.filter import SubscriptionFilter
from localstack.services.sns.models import (
SnsApplicationPlatforms,
Expand Down Expand Up @@ -1176,9 +1177,13 @@ class PublishDispatcher:

def __init__(self, num_thread: int = 10):
    """
    Create the two executors used to deliver notifications.

    :param num_thread: maximum number of worker threads of each executor
    """
    # regular pool, without any ordering guarantee between the submitted tasks
    self.executor = ThreadPoolExecutor(num_thread, thread_name_prefix="sns_pub")
    # pool partitioned by topic: the tasks of a given topic are executed sequentially
    fifo_pool = TopicPartitionedThreadPoolExecutor(
        max_workers=num_thread,
        thread_name_prefix="sns_pub_fifo",
    )
    self.topic_partitioned_executor = fifo_pool

def shutdown(self):
    """Shut down both executors without waiting for their workers to finish."""
    for pool in (self.executor, self.topic_partitioned_executor):
        pool.shutdown(wait=False)

def _should_publish(
self,
Expand Down Expand Up @@ -1295,8 +1300,16 @@ def publish_batch_to_topic(self, ctx: SnsBatchPublishContext, topic_arn: str) ->
)
self._submit_notification(notifier, individual_ctx, subscriber)

def _submit_notification(self, notifier, ctx: SnsPublishContext, subscriber: SnsSubscription):
self.executor.submit(notifier.publish, context=ctx, subscriber=subscriber)
def _submit_notification(
    self, notifier, ctx: SnsPublishContext | SnsBatchPublishContext, subscriber: SnsSubscription
):
    """
    Schedule ``notifier.publish`` for ``subscriber`` on the executor matching its topic type.

    Subscriptions whose ``TopicArn`` ends with ``.fifo`` go through the topic partitioned
    executor, keyed by the topic ARN; everything else goes through the regular executor.

    :param notifier: object whose ``publish`` method performs the delivery
    :param ctx: the publish context, passed to ``publish`` as ``context``
    :param subscriber: the subscription to deliver to, passed to ``publish`` as ``subscriber``
    """
    topic_arn = subscriber.get("TopicArn", "")
    if not topic_arn.endswith(".fifo"):
        self.executor.submit(notifier.publish, context=ctx, subscriber=subscriber)
        return

    # TODO: we still need to implement Message deduplication on the topic level with `should_publish` for FIFO
    self.topic_partitioned_executor.submit(
        notifier.publish, topic_arn, context=ctx, subscriber=subscriber
    )

def publish_to_phone_number(self, ctx: SnsPublishContext, phone_number: str) -> None:
LOG.debug(
Expand Down
54 changes: 54 additions & 0 deletions tests/aws/services/sns/test_sns.py
Original file line number Diff line number Diff line change
Expand Up @@ -2997,6 +2997,60 @@ def test_publish_to_fifo_with_target_arn(self, sns_create_topic, aws_client):
)
assert "MessageId" in response

@markers.aws.validated
def test_message_to_fifo_sqs_ordering(
    self,
    sns_create_topic,
    sqs_create_queue,
    sns_create_sqs_subscription,
    snapshot,
    aws_client,
    sqs_collect_messages,
):
    """
    Publish several messages of one message group to a FIFO topic fanned out to several FIFO
    queues, and verify that every queue received the message bodies in the same order.
    """
    queue_amount = 5
    message_amount = 10

    # FIFO topic with multiple raw-delivery FIFO queue subscriptions
    topic_arn = sns_create_topic(
        Name=f"topic-{short_uid()}.fifo",
        Attributes={"FifoTopic": "true", "ContentBasedDeduplication": "true"},
    )["TopicArn"]

    queue_attributes = {"FifoQueue": "true", "ContentBasedDeduplication": "true"}
    queues = []
    for _ in range(queue_amount):
        queue_url = sqs_create_queue(
            QueueName=f"queue-{short_uid()}.fifo",
            Attributes=queue_attributes,
        )
        sns_create_sqs_subscription(
            topic_arn=topic_arn, queue_url=queue_url, Attributes={"RawMessageDelivery": "true"}
        )
        queues.append(queue_url)

    # every message body is its publish index, all within the same message group
    for body in map(str, range(message_amount)):
        aws_client.sns.publish(
            TopicArn=topic_arn, Message=body, MessageGroupId="message-group-id-1"
        )

    # one list of bodies per queue, in the order the queue returned them
    bodies_per_queue = [
        [
            message["Body"]
            for message in sqs_collect_messages(
                queue_url,
                expected=message_amount,
                timeout=10,
                max_number_of_messages=message_amount,
            )
        ]
        for queue_url in queues
    ]

    # we're expecting the order to be the same across all queues
    reference_order, *other_orders = bodies_per_queue
    for received_content in other_orders:
        assert received_content == reference_order


class TestSNSSubscriptionSES:
@markers.aws.only_localstack
Expand Down
4 changes: 4 additions & 0 deletions tests/aws/services/sns/test_sns.snapshot.json
Original file line number Diff line number Diff line change
Expand Up @@ -5078,5 +5078,9 @@
]
}
}
},
"tests/aws/services/sns/test_sns.py::TestSNSSubscriptionSQSFifo::test_message_to_fifo_sqs_ordering": {
"recorded-date": "19-02-2025, 01:29:15",
"recorded-content": {}
Comment on lines +5082 to +5084
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The snapshot is empty because we're using a transformer fixture with autouse, so it creates an entry even if no snapshots are recorded 😅

}
}
3 changes: 3 additions & 0 deletions tests/aws/services/sns/test_sns.validation.json
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@
"tests/aws/services/sns/test_sns.py::TestSNSSubscriptionSQSFifo::test_message_to_fifo_sqs[True]": {
"last_validated_date": "2023-11-09T20:12:03+00:00"
},
"tests/aws/services/sns/test_sns.py::TestSNSSubscriptionSQSFifo::test_message_to_fifo_sqs_ordering": {
"last_validated_date": "2025-02-19T01:29:14+00:00"
},
"tests/aws/services/sns/test_sns.py::TestSNSSubscriptionSQSFifo::test_publish_batch_messages_from_fifo_topic_to_fifo_queue[False]": {
"last_validated_date": "2023-11-09T20:10:33+00:00"
},
Expand Down
Loading