Skip to content

APIGW NG fix routing with ANY request and remove stage from path #11129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class InvocationRequest(TypedDict, total=False):
http_method: Optional[HTTPMethod]
"""HTTP Method of the incoming request"""
raw_path: Optional[str]
# TODO: verify if raw_path is needed
"""Raw path of the incoming request with no modification, needed to keep double forward slashes"""
path: Optional[str]
"""Path of the request with no URL decoding"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __call__(

def parse_and_enrich(self, context: RestApiInvocationContext):
# first, create the InvocationRequest with the incoming request
context.invocation_request = self.create_invocation_request(context.request)
context.invocation_request = self.create_invocation_request(context)
# then we can create the ContextVariables, used throughout the invocation as payload and to render authorizer
# payload, mapping templates and such.
context.context_variables = self.create_context_variables(context)
Expand All @@ -43,7 +43,8 @@ def parse_and_enrich(self, context: RestApiInvocationContext):
context.stage_variables = self.fetch_stage_variables(context)
LOG.debug("Initializing $stageVariables='%s'", context.stage_variables)

def create_invocation_request(self, request: Request) -> InvocationRequest:
def create_invocation_request(self, context: RestApiInvocationContext) -> InvocationRequest:
request = context.request
params, multi_value_params = self._get_single_and_multi_values_from_multidict(request.args)
headers, multi_value_headers = self._get_single_and_multi_values_from_headers(
request.headers
Expand All @@ -58,12 +59,14 @@ def create_invocation_request(self, request: Request) -> InvocationRequest:
body=restore_payload(request),
)

self._enrich_with_raw_path(request, invocation_request)
self._enrich_with_raw_path(request, invocation_request, stage_name=context.stage)

return invocation_request

@staticmethod
def _enrich_with_raw_path(request: Request, invocation_request: InvocationRequest):
def _enrich_with_raw_path(
request: Request, invocation_request: InvocationRequest, stage_name: str
):
# Base path is not URL-decoded, so we need to get the `RAW_URI` from the request
raw_uri = request.environ.get("RAW_URI") or request.path

Expand All @@ -72,7 +75,11 @@ def _enrich_with_raw_path(request: Request, invocation_request: InvocationReques
if "_user_request_" in raw_uri:
raw_uri = raw_uri.partition("_user_request_")[2]

# remove the stage from the path
raw_uri = raw_uri.replace(f"/{stage_name}", "")

if raw_uri.startswith("//"):
# TODO: AWS validate this assumption
# if the RAW_URI starts with double slashes, `urlparse` will fail to decode it as path only
# it also means that we already only have the path, so we just need to remove the query string
raw_uri = raw_uri.split("?")[0]
Expand Down Expand Up @@ -131,9 +138,8 @@ def create_context_variables(context: RestApiInvocationContext) -> ContextVariab
domainPrefix=domain_prefix,
extendedRequestId=short_uid(), # TODO: use snapshot tests to verify format
httpMethod=invocation_request["http_method"],
path=invocation_request[
"path"
], # TODO: check if we need the raw path? with forward slashes
# TODO: check if we need the raw path? with forward slashes
path=f"/{context.stage}{invocation_request['path']}",
protocol="HTTP/1.1",
requestId=short_uid(), # TODO: use snapshot tests to verify format
requestTime=timestamp(time=now, format=REQUEST_TIME_DATE_FORMAT),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import logging
from functools import cache
from http import HTTPMethod
from typing import Iterable

from werkzeug.exceptions import MethodNotAllowed, NotFound
from werkzeug.routing import Map, MapAdapter
from werkzeug.routing import Map, MapAdapter, Rule

from localstack.aws.api.apigateway import Resource
from localstack.aws.protocol.routing import (
GreedyPathConverter,
StrictMethodRule,
path_param_regex,
post_process_arg_name,
transform_path_params_to_rule_vars,
Expand All @@ -23,6 +23,35 @@
LOG = logging.getLogger(__name__)


class ApiGatewayMethodRule(Rule):
"""
Small extension to Werkzeug's Rule class which reverts unwanted assumptions made by Werkzeug.
Reverted assumptions:
- Werkzeug automatically matches HEAD requests to the corresponding GET request (i.e. Werkzeug's rule
automatically adds the HEAD HTTP method to a rule which should only match GET requests).
Added behavior:
- ANY is equivalent to 7 HTTP methods listed. We manually set them to the rule's methods
"""

def __init__(self, string: str, method: str, **kwargs) -> None:
super().__init__(string=string, methods=[method], **kwargs)

if method == "ANY":
self.methods = {
HTTPMethod.DELETE,
HTTPMethod.GET,
HTTPMethod.HEAD,
HTTPMethod.OPTIONS,
HTTPMethod.PATCH,
HTTPMethod.POST,
HTTPMethod.PUT,
}
else:
# Make sure Werkzeug's Rule does not add any other methods
# (f.e. the HEAD method even though the rule should only match GET)
self.methods = {method.upper()}
Comment on lines +50 to +52
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice Catch!



class RestAPIResourceRouter:
"""
A router implementation which abstracts the routing of incoming REST API Context to a specific
Expand Down Expand Up @@ -121,7 +150,7 @@ def get_rule_map_for_resources(resources: Iterable[Resource]) -> Map:
# translate the requestUri to a Werkzeug rule string
rule_string = path_param_regex.sub(transform_path_params_to_rule_vars, path)
rules.append(
StrictMethodRule(string=rule_string, method=method, endpoint=resource["id"])
ApiGatewayMethodRule(string=rule_string, method=method, endpoint=resource["id"])
) # type: ignore

return Map(
Expand Down
125 changes: 119 additions & 6 deletions tests/unit/services/apigateway/test_handler_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_parse_request(self, dummy_deployment, parse_handler_chain, get_invocati
body=body,
headers=headers,
query_string="test-param=1&test-param-2=2&test-multi=val1&test-multi=val2",
path="/normal-path",
path=f"/{TEST_API_STAGE}/normal-path",
)
context = get_invocation_context(request)
context.deployment = dummy_deployment
Expand Down Expand Up @@ -125,9 +125,14 @@ def test_parse_request(self, dummy_deployment, parse_handler_chain, get_invocati

assert context.context_variables["domainName"] == host_header
assert context.context_variables["domainPrefix"] == TEST_API_ID
assert context.context_variables["path"] == f"/{TEST_API_STAGE}/normal-path"

def test_parse_raw_path(self, dummy_deployment, parse_handler_chain, get_invocation_context):
request = Request("GET", "/foo/bar/ed", raw_path="//foo%2Fbar/ed")
request = Request(
"GET",
path=f"/{TEST_API_STAGE}/foo/bar/ed",
raw_path=f"/{TEST_API_STAGE}//foo%2Fbar/ed",
)

context = get_invocation_context(request)
context.deployment = dummy_deployment
Expand All @@ -146,8 +151,8 @@ def test_parse_user_request_path(
# simulate a path request
request = Request(
"GET",
path=f"/restapis/{TEST_API_ID}/_user_request_/foo/bar/ed",
raw_path=f"/restapis/{TEST_API_ID}/_user_request_//foo%2Fbar/ed",
path=f"/restapis/{TEST_API_ID}/_user_request_/{TEST_API_STAGE}/foo/bar/ed",
raw_path=f"/restapis/{TEST_API_ID}/_user_request_/{TEST_API_STAGE}//foo%2Fbar/ed",
)

context = get_invocation_context(request)
Expand Down Expand Up @@ -234,12 +239,69 @@ def deployment_with_routes(self, dummy_deployment):
localstack_rest_api=dummy_deployment.rest_api,
)

@pytest.fixture
def deployment_with_any_routes(self, dummy_deployment):
"""
This can be represented by the following routes:
- (No method) - /
- GET - /foo
- ANY - /foo
- PUT - /foo/{param}
- ANY - /foo/{param}
"""
moto_backend: APIGatewayBackend = apigateway_backends[TEST_AWS_ACCOUNT_ID][
TEST_AWS_REGION_NAME
]
moto_rest_api = moto_backend.apis[TEST_API_ID]

# path: /
root_resource = moto_rest_api.default
# path: /foo
hard_coded_resource = moto_rest_api.add_child(path="foo", parent_id=root_resource.id)
# path: /foo/{param}
param_resource = moto_rest_api.add_child(
path="{param}",
parent_id=hard_coded_resource.id,
)

hard_coded_resource.add_method(
method_type="GET",
authorization_type="NONE",
api_key_required=False,
)
hard_coded_resource.add_method(
method_type="ANY",
authorization_type="NONE",
api_key_required=False,
)
# we test different order of setting the Method, to make sure ANY is always matched last
# because this will influence the original order of the Werkzeug Rules in the Map
# Because we only return the `Resource` as the endpoint, we always fetch manually the right
# `resourceMethod` from the request method.
param_resource.add_method(
method_type="ANY",
authorization_type="NONE",
api_key_required=False,
)
param_resource.add_method(
method_type="PUT",
authorization_type="NONE",
api_key_required=False,
)

return freeze_rest_api(
account_id=dummy_deployment.account_id,
region=dummy_deployment.region,
moto_rest_api=moto_rest_api,
localstack_rest_api=dummy_deployment.rest_api,
)

@staticmethod
def get_path_from_addressing(path: str, addressing: str) -> str:
if addressing == "host":
return path
return f"/{TEST_API_STAGE}{path}"
else:
return f"/restapis/{TEST_API_ID}/_user_request_{path}"
return f"/restapis/{TEST_API_ID}/_user_request_/{TEST_API_STAGE}{path}"

@pytest.mark.parametrize("addressing", ["host", "user_request"])
def test_route_request_no_param(
Expand Down Expand Up @@ -405,3 +467,54 @@ def test_route_request_with_double_slash_and_trailing_and_encoded(

assert context.resource["path"] == "/foo/{param}"
assert context.invocation_request["path_parameters"] == {"param": "foo%2Fbar"}

@pytest.mark.parametrize("addressing", ["host", "user_request"])
def test_route_request_any_is_last(
self, deployment_with_any_routes, parse_handler_chain, get_invocation_context, addressing
):
handler = InvocationRequestRouter()

def handle(_request: Request) -> RestApiInvocationContext:
_context = get_invocation_context(_request)
_context.deployment = deployment_with_any_routes
parse_handler_chain.handle(_context, Response())
handler(parse_handler_chain, _context, Response())
return _context

request = Request(
"GET",
path=self.get_path_from_addressing("/foo", addressing),
)
context = handle(request)

assert context.resource["path"] == "/foo"
assert context.resource["resourceMethods"]["GET"]

request = Request(
"DELETE",
path=self.get_path_from_addressing("/foo", addressing),
)
context = handle(request)

assert context.resource["path"] == "/foo"
assert context.resource["resourceMethods"]["ANY"]

request = Request(
"PUT",
path=self.get_path_from_addressing("/foo/random-value", addressing),
)

context = handle(request)

assert context.resource["path"] == "/foo/{param}"
assert context.resource_method == context.resource["resourceMethods"]["PUT"]

request = Request(
"GET",
path=self.get_path_from_addressing("/foo/random-value", addressing),
)

context = handle(request)

assert context.resource["path"] == "/foo/{param}"
assert context.resource_method == context.resource["resourceMethods"]["ANY"]
Loading