Skip to content

APIGW: migrate TestInvokeMethod to NextGen #12514

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 0 additions & 31 deletions localstack-core/localstack/services/apigateway/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import hashlib
import json
import logging
from datetime import datetime
from typing import List, Optional, TypedDict, Union
from urllib import parse as urlparse

Expand Down Expand Up @@ -61,7 +60,6 @@
{formatted_date} : Method completed with status: {status_code}
"""


EMPTY_MODEL = "Empty"
ERROR_MODEL = "Error"

Expand Down Expand Up @@ -984,35 +982,6 @@ def is_variable_path(path_part: str) -> bool:
return path_part.startswith("{") and path_part.endswith("}")


def log_template(
request_id: str,
date: datetime,
http_method: str,
resource_path: str,
request_path: str,
query_string: str,
request_headers: str,
request_body: str,
response_body: str,
response_headers: str,
status_code: str,
):
formatted_date = date.strftime("%a %b %d %H:%M:%S %Z %Y")
return INVOKE_TEST_LOG_TEMPLATE.format(
request_id=request_id,
formatted_date=formatted_date,
http_method=http_method,
resource_path=resource_path,
request_path=request_path,
query_string=query_string,
request_headers=request_headers,
request_body=request_body,
response_body=response_body,
response_headers=response_headers,
status_code=status_code,
)


def get_domain_name_hash(domain_name: str) -> str:
"""
Return a hash of the given domain name, which help construct regional domain names for APIs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
from localstack.services.apigateway.helpers import (
EMPTY_MODEL,
ERROR_MODEL,
INVOKE_TEST_LOG_TEMPLATE,
OpenAPIExt,
apply_json_patch_safe,
get_apigateway_store,
Expand All @@ -108,7 +109,6 @@
import_api_from_openapi_spec,
is_greedy_path,
is_variable_path,
log_template,
resolve_references,
)
from localstack.services.apigateway.legacy.helpers import multi_value_dict_for_list
Expand Down Expand Up @@ -217,9 +217,10 @@ def test_invoke_method(

# TODO: add the missing fields to the log. Next iteration will add helpers to extract the missing fields
# from the apicontext
log = log_template(
formatted_date = req_start_time.strftime("%a %b %d %H:%M:%S %Z %Y")
log = INVOKE_TEST_LOG_TEMPLATE.format(
request_id=invocation_context.context["requestId"],
date=req_start_time,
formatted_date=formatted_date,
http_method=invocation_context.method,
resource_path=invocation_context.invocation_path,
request_path="",
Expand All @@ -230,6 +231,7 @@ def test_invoke_method(
response_headers=result.headers,
status_code=result.status_code,
)

return TestInvokeMethodResponse(
status=result.status_code,
headers=dict(result.headers),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import datetime
from urllib.parse import parse_qs

from rolo import Request
from rolo.gateway.chain import HandlerChain
from werkzeug.datastructures import Headers

from localstack.aws.api.apigateway import TestInvokeMethodRequest, TestInvokeMethodResponse
from localstack.constants import APPLICATION_JSON
from localstack.http import Response
from localstack.utils.strings import to_bytes, to_str

from ...models import RestApiDeployment
from . import handlers
from .context import InvocationRequest, RestApiInvocationContext
from .handlers.resource_router import RestAPIResourceRouter
from .header_utils import build_multi_value_headers
from .template_mapping import dict_to_string

# TODO: we probably need to write and populate those logs as part of the handler chain itself
# and store it in the InvocationContext. That way, we could also retrieve in when calling TestInvoke

TEST_INVOKE_TEMPLATE = """Execution log for request {request_id}
{formatted_date} : Starting execution for request: {request_id}
{formatted_date} : HTTP Method: {request_method}, Resource Path: {resource_path}
{formatted_date} : Method request path: {method_request_path_parameters}
{formatted_date} : Method request query string: {method_request_query_string}
{formatted_date} : Method request headers: {method_request_headers}
{formatted_date} : Method request body before transformations: {method_request_body}
{formatted_date} : Endpoint request URI: {endpoint_uri}
{formatted_date} : Endpoint request headers: {endpoint_request_headers}
{formatted_date} : Endpoint request body after transformations: {endpoint_request_body}
{formatted_date} : Sending request to {endpoint_uri}
{formatted_date} : Received response. Status: {endpoint_response_status_code}, Integration latency: {endpoint_response_latency} ms
{formatted_date} : Endpoint response headers: {endpoint_response_headers}
{formatted_date} : Endpoint response body before transformations: {endpoint_response_body}
{formatted_date} : Method response body after transformations: {method_response_body}
{formatted_date} : Method response headers: {method_response_headers}
{formatted_date} : Successfully completed execution
{formatted_date} : Method completed with status: {method_response_status}
"""


def _dump_headers(headers: Headers) -> str:
if not headers:
return "{}"
multi_headers = {key: ",".join(headers.getlist(key)) for key in headers.keys()}
string_headers = dict_to_string(multi_headers)
if len(string_headers) > 998:
return f"{string_headers[:998]} [TRUNCATED]"

return string_headers


def log_template(invocation_context: RestApiInvocationContext, response_headers: Headers) -> str:
# TODO: funny enough, in AWS for the `endpoint_response_headers` in AWS_PROXY, they log the response headers from
# lambda HTTP Invoke call even though we use the headers from the lambda response itself
Comment on lines +56 to +57
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😢 Never stop being surprised!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this is kinda weird 😄 somewhat leaking implementation details

formatted_date = datetime.datetime.now(tz=datetime.UTC).strftime("%a %b %d %H:%M:%S %Z %Y")
request = invocation_context.invocation_request
context_var = invocation_context.context_variables
integration_req = invocation_context.integration_request
endpoint_resp = invocation_context.endpoint_response
method_resp = invocation_context.invocation_response
# TODO: if endpoint_uri is an ARN, it means it's an AWS_PROXY integration
# this should be transformed to the true URL of a lambda invoke call
endpoint_uri = integration_req.get("uri", "")

return TEST_INVOKE_TEMPLATE.format(
formatted_date=formatted_date,
request_id=context_var["requestId"],
resource_path=request["path"],
request_method=request["http_method"],
method_request_path_parameters=dict_to_string(request["path_parameters"]),
method_request_query_string=dict_to_string(request["query_string_parameters"]),
method_request_headers=_dump_headers(request.get("headers")),
method_request_body=to_str(request.get("body", "")),
endpoint_uri=endpoint_uri,
endpoint_request_headers=_dump_headers(integration_req.get("headers")),
endpoint_request_body=to_str(integration_req.get("body", "")),
# TODO: measure integration latency
endpoint_response_latency=150,
endpoint_response_status_code=endpoint_resp.get("status_code"),
endpoint_response_body=to_str(endpoint_resp.get("body", "")),
endpoint_response_headers=_dump_headers(endpoint_resp.get("headers")),
method_response_status=method_resp.get("status_code"),
method_response_body=to_str(method_resp.get("body", "")),
method_response_headers=_dump_headers(response_headers),
)


def create_test_chain() -> HandlerChain[RestApiInvocationContext]:
return HandlerChain(
request_handlers=[
handlers.method_request_handler,
handlers.integration_request_handler,
handlers.integration_handler,
handlers.integration_response_handler,
handlers.method_response_handler,
],
exception_handlers=[
handlers.gateway_exception_handler,
],
)
Comment on lines +91 to +103
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So clean! This really makes that work we did shine bright!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And it was so nice to actually be able to only pick a few handlers that we needed, skipping the parsing and routing of the request 👌

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean you could add them all, and have a if is_test_invoke in them to produce a different result! 🤣



def create_test_invocation_context(
test_request: TestInvokeMethodRequest,
deployment: RestApiDeployment,
) -> RestApiInvocationContext:
parse_handler = handlers.parse_request
http_method = test_request["httpMethod"]

# we do not need a true HTTP request for the context, as we are skipping all the parsing steps and using the
# provider data
invocation_context = RestApiInvocationContext(
request=Request(method=http_method),
)
path_query = test_request.get("pathWithQueryString", "/").split("?")
path = path_query[0]
multi_query_args: dict[str, list[str]] = {}

if len(path_query) > 1:
multi_query_args = parse_qs(path_query[1])

# for the single value parameters, AWS only keeps the last value of the list
single_query_args = {k: v[-1] for k, v in multi_query_args.items()}

invocation_request = InvocationRequest(
http_method=http_method,
path=path,
raw_path=path,
query_string_parameters=single_query_args,
multi_value_query_string_parameters=multi_query_args,
headers=Headers(test_request.get("headers")),
# TODO: handle multiValueHeaders
body=to_bytes(test_request.get("body") or ""),
)
invocation_context.invocation_request = invocation_request

_, path_parameters = RestAPIResourceRouter(deployment).match(invocation_context)
invocation_request["path_parameters"] = path_parameters

invocation_context.deployment = deployment
invocation_context.api_id = test_request["restApiId"]
invocation_context.stage = None
invocation_context.deployment_id = ""
invocation_context.account_id = deployment.account_id
invocation_context.region = deployment.region
invocation_context.stage_variables = test_request.get("stageVariables", {})
invocation_context.context_variables = parse_handler.create_context_variables(
invocation_context
)
invocation_context.trace_id = parse_handler.populate_trace_id({})

resource = deployment.rest_api.resources[test_request["resourceId"]]
resource_method = resource["resourceMethods"][http_method]
invocation_context.resource = resource
invocation_context.resource_method = resource_method
invocation_context.integration = resource_method["methodIntegration"]
handlers.route_request.update_context_variables_with_resource(
invocation_context.context_variables, resource
)

return invocation_context


def run_test_invocation(
test_request: TestInvokeMethodRequest, deployment: RestApiDeployment
) -> TestInvokeMethodResponse:
# validate resource exists in deployment
invocation_context = create_test_invocation_context(test_request, deployment)

test_chain = create_test_chain()
# header order is important
if invocation_context.integration["type"] == "MOCK":
base_headers = {"Content-Type": APPLICATION_JSON}
else:
# we manually add the trace-id, as it is normally added by handlers.response_enricher which adds to much data
# for the TestInvoke. It needs to be first
base_headers = {
"X-Amzn-Trace-Id": invocation_context.trace_id,
"Content-Type": APPLICATION_JSON,
}

test_response = Response(headers=base_headers)
start_time = datetime.datetime.now()
test_chain.handle(context=invocation_context, response=test_response)
end_time = datetime.datetime.now()

response_headers = test_response.headers.copy()
# AWS does not return the Content-Length for TestInvokeMethod
response_headers.remove("Content-Length")

log = log_template(invocation_context, response_headers)

headers = dict(response_headers)
multi_value_headers = build_multi_value_headers(response_headers)

return TestInvokeMethodResponse(
log=log,
status=test_response.status_code,
body=test_response.get_data(as_text=True),
headers=headers,
multiValueHeaders=multi_value_headers,
latency=int((end_time - start_time).total_seconds()),
)
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)
from .execute_api.helpers import freeze_rest_api
from .execute_api.router import ApiGatewayEndpoint, ApiGatewayRouter
from .execute_api.test_invoke import run_test_invocation


class ApigatewayNextGenProvider(ApigatewayProvider):
Expand Down Expand Up @@ -242,8 +243,28 @@ def get_gateway_responses(
def test_invoke_method(
self, context: RequestContext, request: TestInvokeMethodRequest
) -> TestInvokeMethodResponse:
# TODO: rewrite and migrate to NextGen
return super().test_invoke_method(context, request)
rest_api_id = request["restApiId"]
moto_rest_api = get_moto_rest_api(context=context, rest_api_id=rest_api_id)
resource = moto_rest_api.resources.get(request["resourceId"])
if not resource:
raise NotFoundException("Invalid Resource identifier specified")

# test httpMethod

rest_api_container = get_rest_api_container(context, rest_api_id=rest_api_id)
frozen_deployment = freeze_rest_api(
account_id=context.account_id,
region=context.region,
moto_rest_api=moto_rest_api,
localstack_rest_api=rest_api_container,
)

response = run_test_invocation(
test_request=request,
deployment=frozen_deployment,
)

return response


def _get_gateway_response_or_default(
Expand Down
Loading
Loading