Skip to content

Commit a3cd603

Browse files
authored
SNS : fix Message Signature typo and add Lambda URL fixture (#12181)
1 parent 7e58c74 commit a3cd603

File tree

7 files changed

+354
-14
lines changed

7 files changed

+354
-14
lines changed

localstack-core/localstack/services/sns/models.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import itertools
22
import time
33
from dataclasses import dataclass, field
4+
from enum import StrEnum
45
from typing import Dict, List, Literal, Optional, TypedDict, Union
56

67
from localstack.aws.api.sns import (
@@ -37,9 +38,15 @@ def get_next_sequence_number():
3738
return next(global_sns_message_sequence())
3839

3940

41+
class SnsMessageType(StrEnum):
42+
Notification = "Notification"
43+
SubscriptionConfirmation = "SubscriptionConfirmation"
44+
UnsubscribeConfirmation = "UnsubscribeConfirmation"
45+
46+
4047
@dataclass
4148
class SnsMessage:
42-
type: str
49+
type: SnsMessageType
4350
message: Union[
4451
str, Dict
4552
] # can be Dict if after being JSON decoded for validation if structure is `json`
@@ -75,7 +82,7 @@ def message_content(self, protocol: SnsMessageProtocols) -> str:
7582
@classmethod
7683
def from_batch_entry(cls, entry: PublishBatchRequestEntry, is_fifo=False) -> "SnsMessage":
7784
return cls(
78-
type="Notification",
85+
type=SnsMessageType.Notification,
7986
message=entry["Message"],
8087
subject=entry.get("Subject"),
8188
message_structure=entry.get("MessageStructure"),

localstack-core/localstack/services/sns/provider.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,13 @@
6464
from localstack.services.sns import usage
6565
from localstack.services.sns.certificate import SNS_SERVER_CERT
6666
from localstack.services.sns.filter import FilterPolicyValidator
67-
from localstack.services.sns.models import SnsMessage, SnsStore, SnsSubscription, sns_stores
67+
from localstack.services.sns.models import (
68+
SnsMessage,
69+
SnsMessageType,
70+
SnsStore,
71+
SnsSubscription,
72+
sns_stores,
73+
)
6874
from localstack.services.sns.publisher import (
6975
PublishDispatcher,
7076
SnsBatchPublishContext,
@@ -429,9 +435,11 @@ def unsubscribe(
429435
if subscription["Protocol"] in ["http", "https"]:
430436
# TODO: actually validate this (re)subscribe behaviour somehow (localhost.run?)
431437
# we might need to save the sub token in the store
438+
# TODO: AWS only sends the UnsubscribeConfirmation if the call is unauthenticated or the requester is not
439+
# the owner
432440
subscription_token = encode_subscription_token_with_region(region=context.region)
433441
message_ctx = SnsMessage(
434-
type="UnsubscribeConfirmation",
442+
type=SnsMessageType.UnsubscribeConfirmation,
435443
token=subscription_token,
436444
message=f"You have chosen to deactivate subscription {subscription_arn}.\nTo cancel this operation and restore the subscription, visit the SubscribeURL included in this message.",
437445
)
@@ -604,7 +612,7 @@ def publish(
604612
store = self.get_store(account_id=context.account_id, region_name=context.region)
605613

606614
message_ctx = SnsMessage(
607-
type="Notification",
615+
type=SnsMessageType.Notification,
608616
message=message,
609617
message_attributes=message_attributes,
610618
message_deduplication_id=message_deduplication_id,
@@ -763,7 +771,7 @@ def subscribe(
763771
# Send out confirmation message for HTTP(S), fix for https://github.com/localstack/localstack/issues/881
764772
if protocol in ["http", "https"]:
765773
message_ctx = SnsMessage(
766-
type="SubscriptionConfirmation",
774+
type=SnsMessageType.SubscriptionConfirmation,
767775
token=subscription_token,
768776
message=f"You have chosen to subscribe to the topic {topic_arn}.\nTo confirm the subscription, visit the SubscribeURL included in this message.",
769777
)

localstack-core/localstack/services/sns/publisher.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from localstack.services.sns.models import (
2727
SnsApplicationPlatforms,
2828
SnsMessage,
29+
SnsMessageType,
2930
SnsStore,
3031
SnsSubscription,
3132
)
@@ -242,7 +243,7 @@ def prepare_message(
242243
message_attributes = prepare_message_attributes(message_context.message_attributes)
243244

244245
event_payload = {
245-
"Type": message_context.type or "Notification",
246+
"Type": message_context.type or SnsMessageType.Notification,
246247
"MessageId": message_context.message_id,
247248
"Subject": message_context.subject,
248249
"TopicArn": subscriber["TopicArn"],
@@ -482,14 +483,14 @@ def _publish(self, context: SnsPublishContext, subscriber: SnsSubscription):
482483
"x-amz-sns-message-id": message_context.message_id,
483484
"x-amz-sns-topic-arn": subscriber["TopicArn"],
484485
}
485-
if message_context.type != "SubscriptionConfirmation":
486+
if message_context.type != SnsMessageType.SubscriptionConfirmation:
486487
# while testing, never had those from AWS but the docs above states it should be there
487488
message_headers["x-amz-sns-subscription-arn"] = subscriber["SubscriptionArn"]
488489

489490
# When raw message delivery is enabled, x-amz-sns-rawdelivery needs to be set to 'true'
490491
# indicating that the message has been published without JSON formatting.
491492
# https://docs.aws.amazon.com/sns/latest/dg/sns-large-payload-raw-message-delivery.html
492-
if message_context.type == "Notification":
493+
if message_context.type == SnsMessageType.Notification:
493494
if is_raw_message_delivery(subscriber):
494495
message_headers["x-amz-sns-rawdelivery"] = "true"
495496
if content_type := self._get_content_type(subscriber, context.topic_attributes):
@@ -526,7 +527,7 @@ def _publish(self, context: SnsPublishContext, subscriber: SnsSubscription):
526527
topic_attributes=context.topic_attributes,
527528
)
528529
# AWS doesn't send to the DLQ if there's an error trying to deliver a UnsubscribeConfirmation msg
529-
if message_context.type != "UnsubscribeConfirmation":
530+
if message_context.type != SnsMessageType.UnsubscribeConfirmation:
530531
sns_error_to_dead_letter_queue(subscriber, message_body, str(exc))
531532

532533
@staticmethod
@@ -922,9 +923,12 @@ def compute_canonical_string(message: dict, notification_type: str) -> str:
922923
See https://docs.aws.amazon.com/sns/latest/dg/sns-verify-signature-of-message.html
923924
"""
924925
# create the canonical string
925-
if notification_type == "Notification":
926+
if notification_type == SnsMessageType.Notification:
926927
fields = ["Message", "MessageId", "Subject", "Timestamp", "TopicArn", "Type"]
927-
elif notification_type in ("SubscriptionConfirmation", "UnsubscriptionConfirmation"):
928+
elif notification_type in (
929+
SnsMessageType.SubscriptionConfirmation,
930+
SnsMessageType.UnsubscribeConfirmation,
931+
):
928932
fields = ["Message", "MessageId", "SubscribeURL", "Timestamp", "Token", "TopicArn", "Type"]
929933
else:
930934
return ""
@@ -968,11 +972,14 @@ def create_sns_message_body(
968972
"Timestamp": timestamp_millis(),
969973
}
970974

971-
if message_type == "Notification":
975+
if message_type == SnsMessageType.Notification:
972976
unsubscribe_url = create_unsubscribe_url(external_url, subscriber["SubscriptionArn"])
973977
data["UnsubscribeURL"] = unsubscribe_url
974978

975-
elif message_type in ("UnsubscribeConfirmation", "SubscriptionConfirmation"):
979+
elif message_type in (
980+
SnsMessageType.SubscriptionConfirmation,
981+
SnsMessageType.UnsubscribeConfirmation,
982+
):
976983
data["Token"] = message_context.token
977984
data["SubscribeURL"] = create_subscribe_url(
978985
external_url, subscriber["TopicArn"], message_context.token

tests/aws/services/sns/conftest.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import pytest
2+
3+
from localstack.utils.strings import short_uid
4+
5+
LAMBDA_FN_SNS_ENDPOINT = """
6+
import boto3, json, os
7+
def handler(event, *args):
8+
if "AWS_ENDPOINT_URL" in os.environ:
9+
sqs_client = boto3.client("sqs", endpoint_url=os.environ["AWS_ENDPOINT_URL"])
10+
else:
11+
sqs_client = boto3.client("sqs")
12+
13+
queue_url = os.environ.get("SQS_QUEUE_URL")
14+
message = {"event": event}
15+
sqs_client.send_message(QueueUrl=queue_url, MessageBody=json.dumps(message), MessageGroupId="1")
16+
return {"statusCode": 200}
17+
"""
18+
19+
20+
@pytest.fixture
21+
def create_sns_http_endpoint_and_queue(
22+
aws_client, account_id, create_lambda_function, sqs_create_queue
23+
):
24+
lambda_client = aws_client.lambda_
25+
26+
def _create_sns_http_endpoint():
27+
function_name = f"lambda_fn_sns_endpoint-{short_uid()}"
28+
29+
# create SQS queue for results
30+
queue_name = f"{function_name}.fifo"
31+
queue_attrs = {"FifoQueue": "true", "ContentBasedDeduplication": "true"}
32+
queue_url = sqs_create_queue(QueueName=queue_name, Attributes=queue_attrs)
33+
aws_client.sqs.add_permission(
34+
QueueUrl=queue_url,
35+
Label=f"lambda-sqs-{short_uid()}",
36+
AWSAccountIds=[account_id],
37+
Actions=["SendMessage"],
38+
)
39+
40+
create_lambda_function(
41+
func_name=function_name,
42+
handler_file=LAMBDA_FN_SNS_ENDPOINT,
43+
envvars={"SQS_QUEUE_URL": queue_url},
44+
)
45+
create_url_response = lambda_client.create_function_url_config(
46+
FunctionName=function_name, AuthType="NONE", InvokeMode="BUFFERED"
47+
)
48+
aws_client.lambda_.add_permission(
49+
FunctionName=function_name,
50+
StatementId="urlPermission",
51+
Action="lambda:InvokeFunctionUrl",
52+
Principal="*",
53+
FunctionUrlAuthType="NONE",
54+
)
55+
return create_url_response["FunctionUrl"], queue_url
56+
57+
return _create_sns_http_endpoint

tests/aws/services/sns/test_sns.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3921,6 +3921,173 @@ def _clean_headers(response_headers: dict):
39213921
snapshot.match("http-message", payload)
39223922
snapshot.match("http-message-headers", _clean_headers(notification_request.headers))
39233923

3924+
@markers.aws.validated
3925+
def test_subscribe_external_http_endpoint_lambda_url_sig_validation(
3926+
self,
3927+
create_sns_http_endpoint_and_queue,
3928+
sns_create_topic,
3929+
sns_subscription,
3930+
aws_client,
3931+
snapshot,
3932+
sqs_collect_messages,
3933+
):
3934+
def _get_snapshot_from_lambda_url_msg(events: list[dict]) -> dict:
3935+
formatted_events = []
3936+
3937+
def _filter_headers(headers: dict) -> dict:
3938+
filtered_headers = {}
3939+
for key, value in headers.items():
3940+
l_key = key.lower()
3941+
if l_key.startswith("x-amz-sns") or key in (
3942+
"content-type",
3943+
"accept-encoding",
3944+
"user-agent",
3945+
):
3946+
filtered_headers[key] = value
3947+
3948+
return filtered_headers
3949+
3950+
for event in events:
3951+
msg = json.loads(event["Body"])["event"]
3952+
formatted_events.append(
3953+
{"headers": _filter_headers(msg["headers"]), "body": json.loads(msg["body"])}
3954+
)
3955+
3956+
return {"events": formatted_events}
3957+
3958+
def validate_message_signature(msg_event: dict, msg_type: str):
3959+
cert_url = msg_event["SigningCertURL"]
3960+
get_cert_req = requests.get(cert_url)
3961+
assert get_cert_req.ok
3962+
3963+
cert = x509.load_pem_x509_certificate(get_cert_req.content)
3964+
message_signature = msg_event["Signature"]
3965+
# create the canonical string
3966+
if msg_type == "Notification":
3967+
fields = ["Message", "MessageId", "Subject", "Timestamp", "TopicArn", "Type"]
3968+
else:
3969+
fields = [
3970+
"Message",
3971+
"MessageId",
3972+
"SubscribeURL",
3973+
"Timestamp",
3974+
"Token",
3975+
"TopicArn",
3976+
"Type",
3977+
]
3978+
3979+
# Build the string to be signed.
3980+
string_to_sign = "".join(
3981+
[f"{field}\n{msg_event[field]}\n" for field in fields if field in msg_event]
3982+
)
3983+
3984+
# decode the signature from base64.
3985+
decoded_signature = base64.b64decode(message_signature)
3986+
3987+
message_sig_version = msg_event["SignatureVersion"]
3988+
# this is a bug on AWS side, assert our behaviour is the same for now, this might get fixed
3989+
assert message_sig_version == "1"
3990+
signature_hash = hashes.SHA1() if message_sig_version == "1" else hashes.SHA256()
3991+
3992+
# calculate signature value with cert
3993+
# if the signature is invalid, this will raise an exception
3994+
cert.public_key().verify(
3995+
decoded_signature,
3996+
to_bytes(string_to_sign),
3997+
padding=padding.PKCS1v15(),
3998+
algorithm=signature_hash,
3999+
)
4000+
4001+
snapshot.add_transformer(
4002+
[
4003+
snapshot.transform.key_value("RequestId"),
4004+
snapshot.transform.key_value("Token"),
4005+
snapshot.transform.key_value("Host"),
4006+
snapshot.transform.regex(
4007+
r"(?i)(?<=SubscribeURL[\"|']:\s[\"|'])(https?.*?)(?=/\?Action=ConfirmSubscription)",
4008+
replacement="<subscribe-domain>",
4009+
),
4010+
]
4011+
)
4012+
http_endpoint_url, queue_url = create_sns_http_endpoint_and_queue()
4013+
topic_arn = sns_create_topic()["TopicArn"]
4014+
sns_protocol = http_endpoint_url.split("://")[0]
4015+
subscription = sns_subscription(
4016+
TopicArn=topic_arn, Protocol=sns_protocol, Endpoint=http_endpoint_url
4017+
)
4018+
subscription_arn = subscription["SubscriptionArn"]
4019+
delivery_policy = {
4020+
"healthyRetryPolicy": {
4021+
"minDelayTarget": 1,
4022+
"maxDelayTarget": 1,
4023+
"numRetries": 0,
4024+
"numNoDelayRetries": 0,
4025+
"numMinDelayRetries": 0,
4026+
"numMaxDelayRetries": 0,
4027+
"backoffFunction": "linear",
4028+
},
4029+
"sicklyRetryPolicy": None,
4030+
"throttlePolicy": {"maxReceivesPerSecond": 1000},
4031+
"guaranteed": False,
4032+
}
4033+
aws_client.sns.set_subscription_attributes(
4034+
SubscriptionArn=subscription_arn,
4035+
AttributeName="DeliveryPolicy",
4036+
AttributeValue=json.dumps(delivery_policy),
4037+
)
4038+
4039+
messages = sqs_collect_messages(queue_url, expected=1, timeout=10)
4040+
subscribe_event = _get_snapshot_from_lambda_url_msg(messages)
4041+
snapshot.match("subscription-confirmation", subscribe_event)
4042+
4043+
subscribe_payload = subscribe_event["events"][0]["body"]
4044+
4045+
validate_message_signature(
4046+
subscribe_payload,
4047+
msg_type=subscribe_event["events"][0]["headers"]["x-amz-sns-message-type"],
4048+
)
4049+
4050+
token = subscribe_payload["Token"]
4051+
subscribe_url = subscribe_payload["SubscribeURL"]
4052+
service_url, subscribe_url_path = subscribe_url.rsplit("/", maxsplit=1)
4053+
# we manually assert here to be sure the format is right, as it hard to verify with snapshots
4054+
assert subscribe_url == (
4055+
f"{service_url}/?Action=ConfirmSubscription&TopicArn={topic_arn}&Token={token}"
4056+
)
4057+
4058+
confirm_subscription = aws_client.sns.confirm_subscription(TopicArn=topic_arn, Token=token)
4059+
snapshot.match("confirm-subscription", confirm_subscription)
4060+
4061+
subscription_attributes = aws_client.sns.get_subscription_attributes(
4062+
SubscriptionArn=subscription_arn
4063+
)
4064+
assert subscription_attributes["Attributes"]["PendingConfirmation"] == "false"
4065+
4066+
message = "test_external_http_endpoint"
4067+
aws_client.sns.publish(TopicArn=topic_arn, Message=message)
4068+
4069+
messages = sqs_collect_messages(queue_url, expected=1, timeout=10)
4070+
publish_event = _get_snapshot_from_lambda_url_msg(messages)
4071+
snapshot.match("publish-event", publish_event)
4072+
publish_payload = publish_event["events"][0]["body"]
4073+
validate_message_signature(
4074+
publish_payload,
4075+
msg_type=publish_event["events"][0]["headers"]["x-amz-sns-message-type"],
4076+
)
4077+
4078+
unsub_request = requests.get(publish_payload["UnsubscribeURL"])
4079+
assert b"UnsubscribeResponse" in unsub_request.content
4080+
4081+
messages = sqs_collect_messages(queue_url, expected=1, timeout=10)
4082+
unsubscribe_event = _get_snapshot_from_lambda_url_msg(messages)
4083+
snapshot.match("unsubscribe-event", unsubscribe_event)
4084+
4085+
unsubscribe_payload = unsubscribe_event["events"][0]["body"]
4086+
validate_message_signature(
4087+
unsubscribe_payload,
4088+
msg_type=unsubscribe_event["events"][0]["headers"]["x-amz-sns-message-type"],
4089+
)
4090+
39244091

39254092
class TestSNSSubscriptionFirehose:
39264093
@markers.aws.validated

0 commit comments

Comments
 (0)