Skip to content

Commit 99473c6

Browse files
committed
Convert SQS test util into re-usable fixture
1 parent 0338932 commit 99473c6

File tree

4 files changed

+78
-50
lines changed

4 files changed

+78
-50
lines changed

localstack-core/localstack/testing/pytest/fixtures.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import re
77
import textwrap
88
import time
9-
from typing import Any, Callable, Dict, List, Optional, Tuple
9+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
1010

1111
import botocore.auth
1212
import botocore.config
@@ -62,6 +62,11 @@
6262
WAITER_STACK_DELETE_COMPLETE = "stack_delete_complete"
6363

6464

65+
if TYPE_CHECKING:
66+
from mypy_boto3_sqs import SQSClient
67+
from mypy_boto3_sqs.type_defs import MessageTypeDef
68+
69+
6570
@pytest.fixture(scope="class")
6671
def aws_http_client_factory(aws_session):
6772
"""
@@ -365,6 +370,65 @@ def factory(queue_url: str, expected_messages: int, max_iterations: int = 3):
365370
return factory
366371

367372

373+
@pytest.fixture
374+
def sqs_collect_messages(aws_client):
375+
"""Collects SQS messages from a given queue_url and deletes them by default.
376+
Example usage:
377+
messages = sqs_collect_messages(
378+
my_queue_url,
379+
expected=2,
380+
timeout=10,
381+
attribute_names=["All"],
382+
message_attribute_names=["All"],
383+
)
384+
"""
385+
386+
def factory(
387+
queue_url: str,
388+
expected: int,
389+
timeout: int,
390+
delete: bool = True,
391+
attribute_names: list[str] = None,
392+
message_attribute_names: list[str] = None,
393+
max_number_of_messages: int = 1,
394+
wait_time_seconds: int = 5,
395+
sqs_client: "SQSClient | None" = None,
396+
) -> list["MessageTypeDef"]:
397+
sqs_client = sqs_client or aws_client.sqs
398+
collected = []
399+
400+
def _receive():
401+
response = sqs_client.receive_message(
402+
QueueUrl=queue_url,
403+
# Maximum is 20 seconds. Performs long polling.
404+
WaitTimeSeconds=wait_time_seconds,
405+
# Maximum 10 messages
406+
MaxNumberOfMessages=max_number_of_messages,
407+
AttributeNames=attribute_names or [],
408+
MessageAttributeNames=message_attribute_names or [],
409+
)
410+
411+
if messages := response.get("Messages"):
412+
collected.extend(messages)
413+
414+
if delete:
415+
for m in messages:
416+
sqs_client.delete_message(
417+
QueueUrl=queue_url, ReceiptHandle=m["ReceiptHandle"]
418+
)
419+
420+
return len(collected) >= expected
421+
422+
if not poll_condition(_receive, timeout=timeout):
423+
raise TimeoutError(
424+
f"gave up waiting for messages (expected={expected}, actual={len(collected)}"
425+
)
426+
427+
return collected
428+
429+
yield factory
430+
431+
368432
@pytest.fixture
369433
def sqs_queue(sqs_create_queue):
370434
return sqs_create_queue()

tests/aws/services/sqs/test_sqs.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333
from tests.aws.services.lambda_.functions import lambda_integration
3434
from tests.aws.services.lambda_.test_lambda import TEST_LAMBDA_PYTHON
3535

36-
from .utils import sqs_collect_messages
37-
3836
if TYPE_CHECKING:
3937
from mypy_boto3_sqs import SQSClient
4038

@@ -2936,6 +2934,7 @@ def test_dead_letter_queue_message_attributes(
29362934
sqs_create_queue,
29372935
sqs_get_queue_arn,
29382936
snapshot,
2937+
sqs_collect_messages,
29392938
):
29402939
sqs = aws_client.sqs
29412940

@@ -2990,7 +2989,6 @@ def test_dead_letter_queue_message_attributes(
29902989
snapshot.match("rec-pre-dlq", messages)
29912990

29922991
messages = sqs_collect_messages(
2993-
sqs,
29942992
dl_queue_url,
29952993
expected=2,
29962994
timeout=10,

tests/aws/services/sqs/test_sqs_move_task.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from localstack.utils.aws import arns
1111
from localstack.utils.sync import retry
1212

13-
from .utils import sqs_collect_messages, sqs_wait_queue_size
13+
from .utils import sqs_wait_queue_size
1414

1515
QueueUrl = str
1616

@@ -125,6 +125,7 @@ def test_basic_move_task_workflow(
125125
sqs_create_queue,
126126
sqs_create_dlq_pipe,
127127
sqs_get_queue_arn,
128+
sqs_collect_messages,
128129
aws_client,
129130
snapshot,
130131
):
@@ -161,7 +162,7 @@ def test_basic_move_task_workflow(
161162
assert decoded_source_arn == source_arn
162163

163164
# check that messages arrived in destination queue correctly
164-
messages = sqs_collect_messages(sqs, destination_queue, expected=2, timeout=10)
165+
messages = sqs_collect_messages(destination_queue, expected=2, timeout=10)
165166
assert {message["Body"] for message in messages} == {"message-1", "message-2"}
166167

167168
# check move task completion (in AWS, approximate number of messages may take a while to update)
@@ -184,6 +185,7 @@ def test_move_task_workflow_with_default_destination(
184185
sqs_create_queue,
185186
sqs_create_dlq_pipe,
186187
sqs_get_queue_arn,
188+
sqs_collect_messages,
187189
aws_client,
188190
snapshot,
189191
):
@@ -221,7 +223,7 @@ def test_move_task_workflow_with_default_destination(
221223
assert decoded_source_arn == source_arn
222224

223225
# check that messages arrived in destination queue correctly
224-
messages = sqs_collect_messages(sqs, queue_url, expected=2, timeout=10)
226+
messages = sqs_collect_messages(queue_url, expected=2, timeout=10)
225227
assert {message["Body"] for message in messages} == {"message-1", "message-2"}
226228

227229
# check move task completion (in AWS, approximate number of messages may take a while to update)
@@ -244,6 +246,7 @@ def test_move_task_workflow_with_multiple_sources_as_default_destination(
244246
sqs_create_queue,
245247
sqs_create_dlq_pipe,
246248
sqs_get_queue_arn,
249+
sqs_collect_messages,
247250
aws_client,
248251
snapshot,
249252
):
@@ -295,10 +298,10 @@ def test_move_task_workflow_with_multiple_sources_as_default_destination(
295298
snapshot.match("start-message-move-task-response", response)
296299

297300
# check that messages arrived in destination queue correctly
298-
messages = sqs_collect_messages(sqs, queue1_url, expected=2, timeout=10)
301+
messages = sqs_collect_messages(queue1_url, expected=2, timeout=10)
299302
assert {message["Body"] for message in messages} == {"message-1-1", "message-1-2"}
300303

301-
messages = sqs_collect_messages(sqs, queue2_url, expected=2, timeout=10)
304+
messages = sqs_collect_messages(queue2_url, expected=2, timeout=10)
302305
assert {message["Body"] for message in messages} == {"message-2-1", "message-2-2"}
303306

304307
# check move task completion (in AWS, approximate number of messages may take a while to update)
@@ -321,6 +324,7 @@ def test_move_task_with_throughput_limit(
321324
sqs_create_queue,
322325
sqs_create_dlq_pipe,
323326
sqs_get_queue_arn,
327+
sqs_collect_messages,
324328
aws_client,
325329
snapshot,
326330
):
@@ -353,7 +357,7 @@ def test_move_task_with_throughput_limit(
353357
)
354358
snapshot.match("start-message-move-task-response", response)
355359
started = time.time()
356-
messages = sqs_collect_messages(sqs, destination_queue, n, 60)
360+
messages = sqs_collect_messages(destination_queue, n, 60)
357361
assert {message["Body"] for message in messages} == {
358362
"message-0",
359363
"message-1",
@@ -378,6 +382,7 @@ def test_move_task_cancel(
378382
sqs_create_queue,
379383
sqs_create_dlq_pipe,
380384
sqs_get_queue_arn,
385+
sqs_collect_messages,
381386
aws_client,
382387
snapshot,
383388
):
@@ -411,7 +416,7 @@ def test_move_task_cancel(
411416
task_handle = response["TaskHandle"]
412417

413418
# wait for two messages to arrive, then cancel the task
414-
messages = sqs_collect_messages(sqs, destination_queue, 2, 60)
419+
messages = sqs_collect_messages(destination_queue, 2, 60)
415420
assert len(messages) == 2
416421

417422
response = sqs.list_message_move_tasks(SourceArn=source_arn)

tests/aws/services/sqs/utils.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,45 +4,6 @@
44

55
if TYPE_CHECKING:
66
from mypy_boto3_sqs import SQSClient
7-
from mypy_boto3_sqs.type_defs import MessageTypeDef
8-
9-
10-
def sqs_collect_messages(
11-
sqs_client: "SQSClient",
12-
queue_url: str,
13-
expected: int,
14-
timeout: int,
15-
delete: bool = True,
16-
attribute_names: list[str] = None,
17-
message_attribute_names: list[str] = None,
18-
) -> list["MessageTypeDef"]:
19-
collected = []
20-
21-
def _receive():
22-
response = sqs_client.receive_message(
23-
QueueUrl=queue_url,
24-
# try not to wait too long, but also not poll too often
25-
WaitTimeSeconds=min(max(1, timeout), 5),
26-
MaxNumberOfMessages=1,
27-
AttributeNames=attribute_names or [],
28-
MessageAttributeNames=message_attribute_names or [],
29-
)
30-
31-
if messages := response.get("Messages"):
32-
collected.extend(messages)
33-
34-
if delete:
35-
for m in messages:
36-
sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=m["ReceiptHandle"])
37-
38-
return len(collected) >= expected
39-
40-
if not poll_condition(_receive, timeout=timeout):
41-
raise TimeoutError(
42-
f"gave up waiting for messages (expected={expected}, actual={len(collected)}"
43-
)
44-
45-
return collected
467

478

489
def get_approx_number_of_messages(

0 commit comments

Comments
 (0)