Skip to content

feat: propagate x-ray trace id to event bridge targets #12481

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
9d9be32
feat: propagate x-ray trace id
maxhoheiser Apr 4, 2025
fc45d72
feat: use internal trace_context and auto create segment
maxhoheiser Apr 4, 2025
03dff26
feat: add x-ray trace id propagation event api gateway
maxhoheiser Apr 4, 2025
1865747
feat: update test
maxhoheiser Apr 7, 2025
e4e8250
fix: correctly patch botocore
maxhoheiser Apr 7, 2025
21c0720
feat: add test xray events lambda
maxhoheiser Apr 7, 2025
b29ba1d
feat: add x_ray lambda to test lambda
maxhoheiser Apr 7, 2025
707d108
feat: remove xray sdk patch of already patched boto clients
maxhoheiser Apr 7, 2025
fdbbeab
fix: skip for v1
maxhoheiser Apr 7, 2025
4e5a7aa
feat: template replace parent id
maxhoheiser Apr 7, 2025
bc1abe9
fix: register boto hook instead of patching all clients
maxhoheiser Apr 7, 2025
786f024
fix: instrument lambda boto client call with trace header
maxhoheiser Apr 7, 2025
441408a
feat: use to string from TraceHeader
maxhoheiser Apr 7, 2025
921b10d
feat: validate events lambda xray test
maxhoheiser Apr 8, 2025
4a83ca8
feat: validate events api gateway snapshot
maxhoheiser Apr 8, 2025
cfa0336
feat: switch to using x-ray trace header variable
maxhoheiser Apr 9, 2025
b3d5224
feat: use custom boto hook to inject trace header
maxhoheiser Apr 9, 2025
2449ff5
fix: use custom input parameter and transfer to header during call
maxhoheiser Apr 10, 2025
7dfd36d
feat: add test xray propagation event bridge to event bridge
maxhoheiser Apr 10, 2025
00d6da4
feat: skip flake schedule rate test
maxhoheiser Apr 11, 2025
1a9b072
refactor: remove unnecessary parameter
maxhoheiser Apr 11, 2025
6bf8cc1
Update tests/aws/services/events/test_x_ray_trace_propagation.py
maxhoheiser Apr 11, 2025
915c7d6
feat: create new trace for scheduled event
maxhoheiser Apr 14, 2025
0d4c4e0
feat: unskip schedule rate test
maxhoheiser Apr 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions localstack-core/localstack/services/events/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@
from localstack.utils.event_matcher import matches_event
from localstack.utils.strings import long_uid
from localstack.utils.time import TIMESTAMP_FORMAT_TZ, timestamp
from localstack.utils.xray.trace_header import TraceHeader

from .analytics import InvocationStatus, rule_invocation

Expand Down Expand Up @@ -1541,8 +1542,11 @@ def func(*args, **kwargs):
}
target_unique_id = f"{rule.arn}-{target['Id']}"
target_sender = self._target_sender_store[target_unique_id]
new_trace_header = (
TraceHeader().ensure_root_exists()
) # scheduled events will always start a new trace
try:
target_sender.process_event(event.copy())
target_sender.process_event(event.copy(), trace_header=new_trace_header)
except Exception as e:
LOG.info(
"Unable to send event notification %s to target %s: %s",
Expand Down Expand Up @@ -1814,6 +1818,8 @@ def _process_entry(
return

region, account_id = extract_region_and_account_id(event_bus_name_or_arn, context)

# TODO check interference with x-ray trace header
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: how could those clash? could you quickly explain to me, I'm not familiar with this 😅

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am piggybacking on the trace header for intermediary storing region and account for cross-region cross-account events

if encoded_trace_header := get_trace_header_encoded_region_account(
entry, context.region, context.account_id, region, account_id
):
Expand All @@ -1837,14 +1843,16 @@ def _process_entry(
)
return

self._proxy_capture_input_event(event_formatted)
trace_header = context.trace_context["aws_trace_header"]

self._proxy_capture_input_event(event_formatted, trace_header)

# Always add the successful EventId entry, even if target processing might fail
processed_entries.append({"EventId": event_formatted["id"]})

if configured_rules := list(event_bus.rules.values()):
for rule in configured_rules:
self._process_rules(rule, region, account_id, event_formatted)
self._process_rules(rule, region, account_id, event_formatted, trace_header)
else:
LOG.info(
json.dumps(
Expand All @@ -1855,7 +1863,7 @@ def _process_entry(
)
)

def _proxy_capture_input_event(self, event: FormattedEvent) -> None:
def _proxy_capture_input_event(self, event: FormattedEvent, trace_header: TraceHeader) -> None:
# only required for eventstudio to capture input event if no rule is configured
pass

Expand All @@ -1865,6 +1873,7 @@ def _process_rules(
region: str,
account_id: str,
event_formatted: FormattedEvent,
trace_header: TraceHeader,
) -> None:
"""Process rules for an event. Note that we no longer handle entries here as AWS returns success regardless of target failures."""
event_pattern = rule.event_pattern
Expand Down Expand Up @@ -1894,7 +1903,7 @@ def _process_rules(
target_unique_id = f"{rule.arn}-{target_id}"
target_sender = self._target_sender_store[target_unique_id]
try:
target_sender.process_event(event_formatted.copy())
target_sender.process_event(event_formatted.copy(), trace_header)
rule_invocation.labels(
status=InvocationStatus.success,
service=target_sender.service,
Expand Down
74 changes: 54 additions & 20 deletions localstack-core/localstack/services/events/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from localstack.utils.json import extract_jsonpath
from localstack.utils.strings import to_bytes
from localstack.utils.time import now_utc
from localstack.utils.xray.trace_header import TraceHeader

LOG = logging.getLogger(__name__)

Expand All @@ -63,6 +64,7 @@
)

TRANSFORMER_PLACEHOLDER_PATTERN = re.compile(r"<(.*?)>")
TRACE_HEADER_KEY = "X-Amzn-Trace-Id"


def transform_event_with_target_input_path(
Expand Down Expand Up @@ -193,10 +195,10 @@ def client(self):
return self._client

@abstractmethod
def send_event(self, event: FormattedEvent | TransformedEvent):
def send_event(self, event: FormattedEvent | TransformedEvent, trace_header: TraceHeader):
pass

def process_event(self, event: FormattedEvent):
def process_event(self, event: FormattedEvent, trace_header: TraceHeader):
"""Processes the event and send it to the target."""
if input_ := self.target.get("Input"):
event = json.loads(input_)
Expand All @@ -208,7 +210,7 @@ def process_event(self, event: FormattedEvent):
if input_transformer := self.target.get("InputTransformer"):
event = self.transform_event_with_target_input_transformer(input_transformer, event)
if event:
self.send_event(event)
self.send_event(event, trace_header)
else:
LOG.info("No event to send to target %s", self.target.get("Id"))

Expand Down Expand Up @@ -257,6 +259,7 @@ def _initialize_client(self) -> BaseClient:
client = client.request_metadata(
service_principal=service_principal, source_arn=self.rule_arn
)
self._register_client_hooks(client)
return client

def _validate_input_transformer(self, input_transformer: InputTransformer):
Expand Down Expand Up @@ -287,6 +290,24 @@ def _get_predefined_template_replacements(self, event: FormattedEvent) -> dict[s

return predefined_template_replacements

def _register_client_hooks(self, client: BaseClient):
"""Register client hooks to inject trace header into requests."""

def handle_extract_params(params, context, **kwargs):
trace_header = params.pop("TraceHeader", None)
if trace_header is None:
return
context[TRACE_HEADER_KEY] = trace_header.to_header_str()

def handle_inject_headers(params, context, **kwargs):
if trace_header_str := context.pop(TRACE_HEADER_KEY, None):
params["headers"][TRACE_HEADER_KEY] = trace_header_str

client.meta.events.register(
f"provide-client-params.{self.service}.*", handle_extract_params
)
client.meta.events.register(f"before-call.{self.service}.*", handle_inject_headers)


TargetSenderDict = dict[str, TargetSender] # rule_arn-target_id as global unique id

Expand Down Expand Up @@ -316,7 +337,7 @@ class ApiGatewayTargetSender(TargetSender):

ALLOWED_HTTP_METHODS = {"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}

def send_event(self, event):
def send_event(self, event, trace_header):
# Parse the ARN to extract api_id, stage_name, http_method, and resource path
# Example ARN: arn:{partition}:execute-api:{region}:{account_id}:{api_id}/{stage_name}/{method}/{resource_path}
arn_parts = parse_arn(self.target["Arn"])
Expand Down Expand Up @@ -383,6 +404,9 @@ def send_event(self, event):
# Serialize the event, converting datetime objects to strings
event_json = json.dumps(event, default=str)

# Add trace header
headers[TRACE_HEADER_KEY] = trace_header.to_header_str()

# Send the HTTP request
response = requests.request(
method=http_method, url=url, headers=headers, data=event_json, timeout=5
Expand Down Expand Up @@ -415,12 +439,12 @@ def _get_predefined_template_replacements(self, event: Dict[str, Any]) -> Dict[s


class AppSyncTargetSender(TargetSender):
def send_event(self, event):
def send_event(self, event, trace_header):
raise NotImplementedError("AppSync target is not yet implemented")


class BatchTargetSender(TargetSender):
def send_event(self, event):
def send_event(self, event, trace_header):
raise NotImplementedError("Batch target is not yet implemented")

def _validate_input(self, target: Target):
Expand All @@ -433,7 +457,7 @@ def _validate_input(self, target: Target):


class ECSTargetSender(TargetSender):
def send_event(self, event):
def send_event(self, event, trace_header):
raise NotImplementedError("ECS target is a pro feature, please use LocalStack Pro")

def _validate_input(self, target: Target):
Expand All @@ -444,7 +468,7 @@ def _validate_input(self, target: Target):


class EventsTargetSender(TargetSender):
def send_event(self, event):
def send_event(self, event, trace_header):
# TODO add validation and tests for eventbridge to eventbridge requires Detail, DetailType, and Source
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/events/client/put_events.html
source = self._get_source(event)
Expand All @@ -464,7 +488,8 @@ def send_event(self, event):
event, self.region, self.account_id, self.target_region, self.target_account_id
):
entries[0]["TraceHeader"] = encoded_original_id
self.client.put_events(Entries=entries)

self.client.put_events(Entries=entries, TraceHeader=trace_header)

def _get_source(self, event: FormattedEvent | TransformedEvent) -> str:
if isinstance(event, dict) and (source := event.get("source")):
Expand All @@ -486,7 +511,7 @@ def _get_resources(self, event: FormattedEvent | TransformedEvent) -> list[str]:


class EventsApiDestinationTargetSender(TargetSender):
def send_event(self, event):
def send_event(self, event, trace_header):
"""Send an event to an EventBridge API destination
See https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-api-destinations.html"""
target_arn = self.target["Arn"]
Expand Down Expand Up @@ -520,6 +545,9 @@ def send_event(self, event):
if http_parameters := self.target.get("HttpParameters"):
endpoint = add_target_http_parameters(http_parameters, endpoint, headers, event)

# add trace header
headers[TRACE_HEADER_KEY] = trace_header.to_header_str()

result = requests.request(
method=method, url=endpoint, data=json.dumps(event or {}), headers=headers
)
Expand All @@ -532,23 +560,25 @@ def send_event(self, event):


class FirehoseTargetSender(TargetSender):
def send_event(self, event):
def send_event(self, event, trace_header):
delivery_stream_name = firehose_name(self.target["Arn"])

self.client.put_record(
DeliveryStreamName=delivery_stream_name,
Record={"Data": to_bytes(to_json_str(event))},
)


class KinesisTargetSender(TargetSender):
def send_event(self, event):
def send_event(self, event, trace_header):
partition_key_path = collections.get_safe(
self.target,
"$.KinesisParameters.PartitionKeyPath",
default_value="$.id",
)
stream_name = self.target["Arn"].split("/")[-1]
partition_key = collections.get_safe(event, partition_key_path, event["id"])

self.client.put_record(
StreamName=stream_name,
Data=to_bytes(to_json_str(event)),
Expand All @@ -565,18 +595,20 @@ def _validate_input(self, target: Target):


class LambdaTargetSender(TargetSender):
def send_event(self, event):
def send_event(self, event, trace_header):
self.client.invoke(
FunctionName=self.target["Arn"],
Payload=to_bytes(to_json_str(event)),
InvocationType="Event",
TraceHeader=trace_header,
)


class LogsTargetSender(TargetSender):
def send_event(self, event):
def send_event(self, event, trace_header):
log_group_name = self.target["Arn"].split(":")[6]
log_stream_name = str(uuid.uuid4()) # Unique log stream name

self.client.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name)
self.client.put_log_events(
logGroupName=log_group_name,
Expand All @@ -591,7 +623,7 @@ def send_event(self, event):


class RedshiftTargetSender(TargetSender):
def send_event(self, event):
def send_event(self, event, trace_header):
raise NotImplementedError("Redshift target is not yet implemented")

def _validate_input(self, target: Target):
Expand All @@ -602,20 +634,21 @@ def _validate_input(self, target: Target):


class SagemakerTargetSender(TargetSender):
def send_event(self, event):
def send_event(self, event, trace_header):
raise NotImplementedError("Sagemaker target is not yet implemented")


class SnsTargetSender(TargetSender):
def send_event(self, event):
def send_event(self, event, trace_header):
self.client.publish(TopicArn=self.target["Arn"], Message=to_json_str(event))


class SqsTargetSender(TargetSender):
def send_event(self, event):
def send_event(self, event, trace_header):
queue_url = sqs_queue_url_for_arn(self.target["Arn"])
msg_group_id = self.target.get("SqsParameters", {}).get("MessageGroupId", None)
kwargs = {"MessageGroupId": msg_group_id} if msg_group_id else {}

self.client.send_message(
QueueUrl=queue_url,
MessageBody=to_json_str(event),
Expand All @@ -626,8 +659,9 @@ def send_event(self, event):
class StatesTargetSender(TargetSender):
"""Step Functions Target Sender"""

def send_event(self, event):
def send_event(self, event, trace_header):
self.service = "stepfunctions"

self.client.start_execution(
stateMachineArn=self.target["Arn"], name=event["id"], input=to_json_str(event)
)
Expand All @@ -642,7 +676,7 @@ def _validate_input(self, target: Target):
class SystemsManagerSender(TargetSender):
"""EC2 Run Command Target Sender"""

def send_event(self, event):
def send_event(self, event, trace_header):
raise NotImplementedError("Systems Manager target is not yet implemented")

def _validate_input(self, target: Target):
Expand Down
Loading
Loading