|
| 1 | +import time |
| 2 | +from threading import Timer |
| 3 | + |
1 | 4 | import pytest
|
2 | 5 | import requests
|
3 | 6 | import xmltodict
|
@@ -361,3 +364,128 @@ def test_list_messages_with_queue_url_in_path(
|
361 | 364 | assert response.status_code == 400
|
362 | 365 | doc = response.json()
|
363 | 366 | assert doc["ErrorResponse"]["Error"]["Code"] == "AWS.SimpleQueueService.NonExistentQueue"
|
| 367 | + |
| 368 | + |
| 369 | +class TestSqsOverrideHeaders: |
| 370 | + @markers.aws.only_localstack |
| 371 | + def test_receive_message_override_max_number_of_messages( |
| 372 | + self, sqs_create_queue, aws_client_factory |
| 373 | + ): |
| 374 | + # Create standalone boto3 client since registering hooks to the session-wide |
| 375 | + # aws_client (from the fixture) will have side-effects. |
| 376 | + aws_client = aws_client_factory().sqs |
| 377 | + |
| 378 | + override_max_number_of_messages = 20 |
| 379 | + |
| 380 | + from localstack.services.sqs.constants import HEADER_LOCALSTACK_SQS_OVERRIDE_MESSAGE_COUNT |
| 381 | + from localstack.services.sqs.provider import MAX_NUMBER_OF_MESSAGES |
| 382 | + |
| 383 | + queue_url = sqs_create_queue() |
| 384 | + |
| 385 | + for i in range(override_max_number_of_messages): |
| 386 | + aws_client.send_message(QueueUrl=queue_url, MessageBody=f"message-{i}") |
| 387 | + |
| 388 | + with pytest.raises(ClientError): |
| 389 | + aws_client.receive_message( |
| 390 | + QueueUrl=queue_url, |
| 391 | + VisibilityTimeout=0, |
| 392 | + MaxNumberOfMessages=override_max_number_of_messages, |
| 393 | + AttributeNames=["All"], |
| 394 | + ) |
| 395 | + |
| 396 | + def _handle_receive_message_override(params, context, **kwargs): |
| 397 | + if not (requested_count := params.get("MaxNumberOfMessages")): |
| 398 | + return |
| 399 | + context[HEADER_LOCALSTACK_SQS_OVERRIDE_MESSAGE_COUNT] = str(requested_count) |
| 400 | + params["MaxNumberOfMessages"] = min(MAX_NUMBER_OF_MESSAGES, requested_count) |
| 401 | + |
| 402 | + def _handler_inject_header(params, context, **kwargs): |
| 403 | + if override_message_count := context.pop( |
| 404 | + HEADER_LOCALSTACK_SQS_OVERRIDE_MESSAGE_COUNT, None |
| 405 | + ): |
| 406 | + params["headers"][HEADER_LOCALSTACK_SQS_OVERRIDE_MESSAGE_COUNT] = ( |
| 407 | + override_message_count |
| 408 | + ) |
| 409 | + |
| 410 | + aws_client.meta.events.register( |
| 411 | + "provide-client-params.sqs.ReceiveMessage", _handle_receive_message_override |
| 412 | + ) |
| 413 | + |
| 414 | + aws_client.meta.events.register("before-call.sqs.ReceiveMessage", _handler_inject_header) |
| 415 | + |
| 416 | + response = aws_client.receive_message( |
| 417 | + QueueUrl=queue_url, |
| 418 | + VisibilityTimeout=30, |
| 419 | + MaxNumberOfMessages=override_max_number_of_messages, |
| 420 | + AttributeNames=["All"], |
| 421 | + ) |
| 422 | + |
| 423 | + messages = response.get("Messages", []) |
| 424 | + assert len(messages) == 20 |
| 425 | + |
| 426 | + @markers.aws.only_localstack |
| 427 | + def test_receive_message_override_message_wait_time_seconds( |
| 428 | + self, sqs_create_queue, aws_client_factory |
| 429 | + ): |
| 430 | + aws_client = aws_client_factory().sqs |
| 431 | + |
| 432 | + override_message_wait_time_seconds = 30 |
| 433 | + |
| 434 | + from localstack.services.sqs.constants import ( |
| 435 | + HEADER_LOCALSTACK_SQS_OVERRIDE_WAIT_TIME_SECONDS, |
| 436 | + ) |
| 437 | + from localstack.services.sqs.provider import MAX_NUMBER_OF_MESSAGES |
| 438 | + |
| 439 | + queue_url = sqs_create_queue() |
| 440 | + |
| 441 | + with pytest.raises(ClientError): |
| 442 | + aws_client.receive_message( |
| 443 | + QueueUrl=queue_url, |
| 444 | + VisibilityTimeout=0, |
| 445 | + MaxNumberOfMessages=MAX_NUMBER_OF_MESSAGES, |
| 446 | + WaitTimeSeconds=override_message_wait_time_seconds, |
| 447 | + AttributeNames=["All"], |
| 448 | + ) |
| 449 | + |
| 450 | + def _handle_receive_message_override(params, context, **kwargs): |
| 451 | + if not (requested_wait_time := params.get("WaitTimeSeconds")): |
| 452 | + return |
| 453 | + context[HEADER_LOCALSTACK_SQS_OVERRIDE_WAIT_TIME_SECONDS] = str(requested_wait_time) |
| 454 | + params["WaitTimeSeconds"] = min(20, requested_wait_time) |
| 455 | + |
| 456 | + def _handler_inject_header(params, context, **kwargs): |
| 457 | + if override_wait_time := context.pop( |
| 458 | + HEADER_LOCALSTACK_SQS_OVERRIDE_WAIT_TIME_SECONDS, None |
| 459 | + ): |
| 460 | + params["headers"][HEADER_LOCALSTACK_SQS_OVERRIDE_WAIT_TIME_SECONDS] = ( |
| 461 | + override_wait_time |
| 462 | + ) |
| 463 | + |
| 464 | + aws_client.meta.events.register( |
| 465 | + "provide-client-params.sqs.ReceiveMessage", _handle_receive_message_override |
| 466 | + ) |
| 467 | + |
| 468 | + aws_client.meta.events.register("before-call.sqs.ReceiveMessage", _handler_inject_header) |
| 469 | + |
| 470 | + def _send_message(): |
| 471 | + aws_client.send_message(QueueUrl=queue_url, MessageBody=f"message-{short_uid()}") |
| 472 | + |
| 473 | + # Populate with 9 messages (1 below the MaxNumberOfMessages threshold). |
| 474 | + # This should cause long-polling to exit since MaxNumberOfMessages is met. |
| 475 | + for _ in range(9): |
| 476 | + _send_message() |
| 477 | + |
| 478 | + Timer(25, _send_message).start() # send message asynchronously after 1 second |
| 479 | + |
| 480 | + start_t = time.time() |
| 481 | + response = aws_client.receive_message( |
| 482 | + QueueUrl=queue_url, |
| 483 | + VisibilityTimeout=30, |
| 484 | + MaxNumberOfMessages=MAX_NUMBER_OF_MESSAGES, |
| 485 | + WaitTimeSeconds=override_message_wait_time_seconds, |
| 486 | + AttributeNames=["All"], |
| 487 | + ) |
| 488 | + assert time.time() - start_t >= 25 |
| 489 | + |
| 490 | + messages = response.get("Messages", []) |
| 491 | + assert len(messages) == 10 |
0 commit comments