diff --git a/localstack-core/localstack/services/sns/executor.py b/localstack-core/localstack/services/sns/executor.py
new file mode 100644
index 0000000000000..ce4f8850d6e3e
--- /dev/null
+++ b/localstack-core/localstack/services/sns/executor.py
@@ -0,0 +1,114 @@
+import itertools
+import logging
+import os
+import queue
+import threading
+
+LOG = logging.getLogger(__name__)
+
+
+def _worker(work_queue: queue.Queue):
+    try:
+        while True:
+            work_item = work_queue.get(block=True)
+            if work_item is None:
+                return
+            work_item.run()
+            # delete the reference to the work item to avoid keeping it in memory until the next blocking `queue.get` call returns
+            del work_item
+
+    except Exception:
+        LOG.exception("Exception in worker")
+
+
+class _WorkItem:
+    def __init__(self, fn, args, kwargs):
+        self.fn = fn
+        self.args = args
+        self.kwargs = kwargs
+
+    def run(self):
+        try:
+            self.fn(*self.args, **self.kwargs)
+        except Exception:
+            LOG.exception("Unhandled exception while running %s", self.fn.__name__)
+
+
+class TopicPartitionedThreadPoolExecutor:
+    """
+    This executor partitions the work between workers based on the topic.
+    It guarantees that each topic has only one worker assigned, and thus that its tasks are executed sequentially.
+
+    Loosely based on the stdlib's ThreadPoolExecutor, but it does not return a Future, as SNS does not need one (fire & forget).
+    Could be extended if needed to fit other needs.
+
+    Currently, we do not re-balance between workers if some of them have more load. This could be investigated.
+    """
+
+    # Used to assign unique thread names when thread_name_prefix is not supplied.
+    _counter = itertools.count().__next__
+
+    def __init__(self, max_workers: int = None, thread_name_prefix: str = ""):
+        if max_workers is None:
+            max_workers = min(32, (os.cpu_count() or 1) + 4)
+        if max_workers <= 0:
+            raise ValueError("max_workers must be greater than 0")
+
+        self._max_workers = max_workers
+        self._thread_name_prefix = (
+            thread_name_prefix or f"TopicThreadPoolExecutor-{self._counter()}"
+        )
+
+        # for now, the pool isn't fair and is not redistributed depending on load
+        self._pool = {}
+        self._shutdown = False
+        self._lock = threading.Lock()
+        self._threads = set()
+        self._work_queues = []
+        self._cycle = itertools.cycle(range(max_workers))
+
+    def _add_worker(self):
+        work_queue = queue.SimpleQueue()
+        self._work_queues.append(work_queue)
+        thread_name = f"{self._thread_name_prefix}_{len(self._threads)}"
+        t = threading.Thread(name=thread_name, target=_worker, args=(work_queue,))
+        t.daemon = True
+        t.start()
+        self._threads.add(t)
+
+    def _get_work_queue(self, topic: str) -> queue.SimpleQueue:
+        if not (work_queue := self._pool.get(topic)):
+            if len(self._threads) < self._max_workers:
+                self._add_worker()
+
+            # we cycle through the possible indexes for a work queue, in order to distribute the load across workers;
+            # once we reach the maximum number of workers, the cycle starts back at 0
+            index = next(self._cycle)
+            work_queue = self._work_queues[index]
+
+            # TODO: the pool is not cleaned up at the moment, think about the clean-up interface
+            self._pool[topic] = work_queue
+        return work_queue
+
+    def submit(self, fn, topic, /, *args, **kwargs) -> None:
+        with self._lock:
+            work_queue = self._get_work_queue(topic)
+
+            if self._shutdown:
+                raise RuntimeError("cannot schedule new futures after shutdown")
+
+            w = _WorkItem(fn, args, kwargs)
+            work_queue.put(w)
+
+    def shutdown(self, wait=True):
+        with self._lock:
+            self._shutdown = True
+
+            # Send a wake-up to prevent threads calling
+            # _work_queue.get(block=True) from permanently blocking.
+            for work_queue in self._work_queues:
+                work_queue.put(None)
+
+            if wait:
+                for t in self._threads:
+                    t.join()
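As a reading aid, here is a minimal usage sketch of the executor introduced above. It is not part of the patch, and the `deliver` callback and topic ARNs are invented for illustration: tasks submitted under the same topic key land on the same work queue and therefore run sequentially, while other topics may be processed in parallel by other workers.

```python
# Hypothetical usage sketch, not part of the patch.
from localstack.services.sns.executor import TopicPartitionedThreadPoolExecutor


def deliver(message: str) -> None:
    print(f"delivering {message}")


executor = TopicPartitionedThreadPoolExecutor(max_workers=2)
topic_a = "arn:aws:sns:us-east-1:000000000000:orders.fifo"
topic_b = "arn:aws:sns:us-east-1:000000000000:events.fifo"

# "first" always runs before "second": both submissions share topic_a's
# single worker. topic_b is handled by the second worker; a third topic
# would cycle back around and share a work queue with topic_a.
executor.submit(deliver, topic_a, "first")
executor.submit(deliver, topic_a, "second")
executor.submit(deliver, topic_b, "independent")
executor.shutdown(wait=True)
```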
diff --git a/localstack-core/localstack/services/sns/publisher.py b/localstack-core/localstack/services/sns/publisher.py
index 5569f24a98096..9510885f51431 100644
--- a/localstack-core/localstack/services/sns/publisher.py
+++ b/localstack-core/localstack/services/sns/publisher.py
@@ -22,6 +22,7 @@
 from localstack.config import external_service_url
 from localstack.services.sns import constants as sns_constants
 from localstack.services.sns.certificate import SNS_SERVER_PRIVATE_KEY
+from localstack.services.sns.executor import TopicPartitionedThreadPoolExecutor
 from localstack.services.sns.filter import SubscriptionFilter
 from localstack.services.sns.models import (
     SnsApplicationPlatforms,
@@ -1176,9 +1177,13 @@ class PublishDispatcher:
 
     def __init__(self, num_thread: int = 10):
         self.executor = ThreadPoolExecutor(num_thread, thread_name_prefix="sns_pub")
+        self.topic_partitioned_executor = TopicPartitionedThreadPoolExecutor(
+            max_workers=num_thread, thread_name_prefix="sns_pub_fifo"
+        )
 
     def shutdown(self):
         self.executor.shutdown(wait=False)
+        self.topic_partitioned_executor.shutdown(wait=False)
 
     def _should_publish(
         self,
@@ -1295,8 +1300,16 @@ def publish_batch_to_topic(self, ctx: SnsBatchPublishContext, topic_arn: str) -> None:
                 )
                 self._submit_notification(notifier, individual_ctx, subscriber)
 
-    def _submit_notification(self, notifier, ctx: SnsPublishContext, subscriber: SnsSubscription):
-        self.executor.submit(notifier.publish, context=ctx, subscriber=subscriber)
+    def _submit_notification(
+        self, notifier, ctx: SnsPublishContext | SnsBatchPublishContext, subscriber: SnsSubscription
+    ):
+        if (topic_arn := subscriber.get("TopicArn", "")).endswith(".fifo"):
+            # TODO: we still need to implement Message deduplication on the topic level with `should_publish` for FIFO
+            self.topic_partitioned_executor.submit(
+                notifier.publish, topic_arn, context=ctx, subscriber=subscriber
+            )
+        else:
+            self.executor.submit(notifier.publish, context=ctx, subscriber=subscriber)
 
     def publish_to_phone_number(self, ctx: SnsPublishContext, phone_number: str) -> None:
         LOG.debug(
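The routing decision in `_submit_notification` above hinges on the ARN suffix alone, which is safe because AWS requires FIFO topic names to end in `.fifo`. A tiny self-contained sketch of that predicate (a hypothetical helper, not code from the patch; subscriptions are dict-like in the real code):

```python
# Hypothetical helper mirroring the check in _submit_notification.
def is_fifo_subscription(subscriber: dict) -> bool:
    # a missing TopicArn falls back to "" and is treated as non-FIFO
    return subscriber.get("TopicArn", "").endswith(".fifo")


assert is_fifo_subscription({"TopicArn": "arn:aws:sns:us-east-1:000000000000:orders.fifo"})
assert not is_fifo_subscription({"TopicArn": "arn:aws:sns:us-east-1:000000000000:orders"})
```

Note the design choice: partitioning by topic ARN rather than by `MessageGroupId` serializes all groups of a topic, which is stricter than what AWS guarantees but sufficient to keep per-topic delivery ordered.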
diff --git a/tests/aws/services/sns/test_sns.py b/tests/aws/services/sns/test_sns.py
index 6c29891011341..d2aecae511e9b 100644
--- a/tests/aws/services/sns/test_sns.py
+++ b/tests/aws/services/sns/test_sns.py
@@ -2997,6 +2997,60 @@ def test_publish_to_fifo_with_target_arn(self, sns_create_topic, aws_client):
         )
         assert "MessageId" in response
 
+    @markers.aws.validated
+    def test_message_to_fifo_sqs_ordering(
+        self,
+        sns_create_topic,
+        sqs_create_queue,
+        sns_create_sqs_subscription,
+        snapshot,
+        aws_client,
+        sqs_collect_messages,
+    ):
+        topic_name = f"topic-{short_uid()}.fifo"
+        topic_attributes = {"FifoTopic": "true", "ContentBasedDeduplication": "true"}
+        topic_arn = sns_create_topic(
+            Name=topic_name,
+            Attributes=topic_attributes,
+        )["TopicArn"]
+
+        queue_attributes = {"FifoQueue": "true", "ContentBasedDeduplication": "true"}
+        queues = []
+        queue_amount = 5
+        message_amount = 10
+
+        for _ in range(queue_amount):
+            queue_name = f"queue-{short_uid()}.fifo"
+            queue_url = sqs_create_queue(
+                QueueName=queue_name,
+                Attributes=queue_attributes,
+            )
+            sns_create_sqs_subscription(
+                topic_arn=topic_arn, queue_url=queue_url, Attributes={"RawMessageDelivery": "true"}
+            )
+            queues.append(queue_url)
+
+        for i in range(message_amount):
+            aws_client.sns.publish(
+                TopicArn=topic_arn, Message=str(i), MessageGroupId="message-group-id-1"
+            )
+
+        all_messages = []
+        for queue_url in queues:
+            messages = sqs_collect_messages(
+                queue_url,
+                expected=message_amount,
+                timeout=10,
+                max_number_of_messages=message_amount,
+            )
+            contents = [message["Body"] for message in messages]
+            all_messages.append(contents)
+
+        # we're expecting the order to be the same across all queues
+        reference_order = all_messages[0]
+        for received_content in all_messages[1:]:
+            assert received_content == reference_order
+
 
 class TestSNSSubscriptionSES:
     @markers.aws.only_localstack
diff --git a/tests/aws/services/sns/test_sns.snapshot.json b/tests/aws/services/sns/test_sns.snapshot.json
index e45d2502cd39a..d48d98adf7ed7 100644
--- a/tests/aws/services/sns/test_sns.snapshot.json
+++ b/tests/aws/services/sns/test_sns.snapshot.json
@@ -5078,5 +5078,9 @@
         ]
       }
     }
+  },
+  "tests/aws/services/sns/test_sns.py::TestSNSSubscriptionSQSFifo::test_message_to_fifo_sqs_ordering": {
+    "recorded-date": "19-02-2025, 01:29:15",
+    "recorded-content": {}
   }
 }
diff --git a/tests/aws/services/sns/test_sns.validation.json b/tests/aws/services/sns/test_sns.validation.json
index 1d0899ab4e7b0..2897de7db25c1 100644
--- a/tests/aws/services/sns/test_sns.validation.json
+++ b/tests/aws/services/sns/test_sns.validation.json
@@ -185,6 +185,9 @@
   "tests/aws/services/sns/test_sns.py::TestSNSSubscriptionSQSFifo::test_message_to_fifo_sqs[True]": {
     "last_validated_date": "2023-11-09T20:12:03+00:00"
   },
+  "tests/aws/services/sns/test_sns.py::TestSNSSubscriptionSQSFifo::test_message_to_fifo_sqs_ordering": {
+    "last_validated_date": "2025-02-19T01:29:14+00:00"
+  },
   "tests/aws/services/sns/test_sns.py::TestSNSSubscriptionSQSFifo::test_publish_batch_messages_from_fifo_topic_to_fifo_queue[False]": {
     "last_validated_date": "2023-11-09T20:10:33+00:00"
   },
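A final note on the test: FIFO ordering is only defined within a message group, which is why every publish in the test uses the same `MessageGroupId`. If the test were extended to several groups, the cross-queue equality assertion would no longer hold as written, since groups may interleave differently in each queue. A hedged sketch of such a multi-group publish loop, reusing the fixtures from the test above:

```python
# Illustrative only: per-group order is preserved, but "group-a" and
# "group-b" messages may interleave differently in each queue.
for i in range(10):
    group = "group-a" if i % 2 == 0 else "group-b"
    aws_client.sns.publish(
        TopicArn=topic_arn,
        Message=f"{group}-{i}",
        MessageGroupId=group,
    )
```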