From 2aea6aa5c2d8eefe8f09fb858b9f29fe6beabc18 Mon Sep 17 00:00:00 2001 From: Callum Bodels <111014555+bodelsc@users.noreply.github.com> Date: Wed, 4 Jan 2023 18:19:07 +0000 Subject: [PATCH 01/41] docs: removing trailing whitespace (#76) Removing the trailing whitespace in the README.md to improve the formatting and readability of the file. --- README.md | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 0eec533..45674c9 100644 --- a/README.md +++ b/README.md @@ -26,26 +26,26 @@ Instructions for installing AWS Lambda Runtime Interface Emulator for your platf ## Getting started -There are a few ways you use the Runtime Interface Emulator (RIE) to locally test your function depending on the base image used. +There are a few ways you use the Runtime Interface Emulator (RIE) to locally test your function depending on the base image used. ### Test an image with RIE included in the image -The AWS base images for Lambda include the runtime interface emulator. You can also follow these steps if you built the RIE into your alternative base image. +The AWS base images for Lambda include the runtime interface emulator. You can also follow these steps if you built the RIE into your alternative base image. #### To test your Lambda function with the emulator -1. Build your image locally using the docker build command. +1. Build your image locally using the docker build command. `docker build -t myfunction:latest .` -2. Run your container image locally using the docker run command. +2. Run your container image locally using the docker run command. `docker run -p 9000:8080 myfunction:latest` - This command runs the image as a container and starts up an endpoint locally at `localhost:9000/2015-03-31/functions/function/invocations`. + This command runs the image as a container and starts up an endpoint locally at `localhost:9000/2015-03-31/functions/function/invocations`. -3. Post an event to the following endpoint using a curl command: +3. Post an event to the following endpoint using a curl command: `curl -XPOST "http://localhost:9000/2015-03-31/functions/function/invocations" -d '{}'` @@ -71,7 +71,7 @@ The following example shows a typical script for a Node.js function. fi ``` -2. Download the [runtime interface emulator](https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest) for your target architecture (`aws-lambda-rie` for x86\_64 or `aws-lambda-rie-arm64` for arm64) from GitHub into your project directory. +2. Download the [runtime interface emulator](https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest) for your target architecture (`aws-lambda-rie` for x86\_64 or `aws-lambda-rie-arm64` for arm64) from GitHub into your project directory. 3. Install the emulator package and change `ENTRYPOINT` to run the new script by adding the following lines to your Dockerfile: @@ -108,30 +108,30 @@ You install the runtime interface emulator to your local machine. When you run t mkdir -p ~/.aws-lambda-rie && curl -Lo ~/.aws-lambda-rie/aws-lambda-rie \ https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest/download/aws-lambda-rie \ && chmod +x ~/.aws-lambda-rie/aws-lambda-rie - ``` + ``` To download the RIE for arm64 architecture, use the previous command with a different GitHub download url. ``` https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest/download/aws-lambda-rie-arm64 \ ``` -2. Run your Lambda image function using the docker run command. +2. Run your Lambda image function using the docker run command. ``` - docker run -d -v ~/.aws-lambda-rie:/aws-lambda -p 9000:8080 myfunction:latest + docker run -d -v ~/.aws-lambda-rie:/aws-lambda -p 9000:8080 myfunction:latest --entrypoint /aws-lambda/aws-lambda-rie <(optional) image command>` ``` - This runs the image as a container and starts up an endpoint locally at `localhost:9000/2015-03-31/functions/function/invocations`. + This runs the image as a container and starts up an endpoint locally at `localhost:9000/2015-03-31/functions/function/invocations`. -3. Post an event to the following endpoint using a curl command: +3. Post an event to the following endpoint using a curl command: `curl -XPOST "http://localhost:9000/2015-03-31/functions/function/invocations" -d '{}'` This command invokes the function running in the container image and returns a response. -## How to configure +## How to configure -`aws-lambda-rie` can be configured through Environment Variables within the local running Image. +`aws-lambda-rie` can be configured through Environment Variables within the local running Image. You can configure your credentials by setting: * `AWS_ACCESS_KEY_ID` * `AWS_SECRET_ACCESS_KEY` @@ -147,17 +147,17 @@ The rest of these Environment Variables can be set to match AWS Lambda's environ ## Level of support -You can use the emulator to test if your function code is compatible with the Lambda environment, executes successfully -and provides the expected output. For example, you can mock test events from different event sources. You can also use -it to test extensions and agents built into the container image against the Lambda Extensions API. This component -does *not *emulate* *the orchestration behavior of AWS Lambda. For example, Lambda has a network and security -configurations that will not be emulated by this component. +You can use the emulator to test if your function code is compatible with the Lambda environment, executes successfully +and provides the expected output. For example, you can mock test events from different event sources. You can also use +it to test extensions and agents built into the container image against the Lambda Extensions API. This component +does *not *emulate* *the orchestration behavior of AWS Lambda. For example, Lambda has a network and security +configurations that will not be emulated by this component. * You can use the emulator to test if your function code is compatible with the Lambda environment, runs successfully and provides the expected output. * You can also use it to test extensions and agents built into the container image against the Lambda Extensions API. -* This component does _not_ emulate Lambda’s orchestration, or security and authentication configurations. -* The component does _not_ support X-ray and other Lambda integrations locally. +* This component does _not_ emulate Lambda’s orchestration, or security and authentication configurations. +* The component does _not_ support X-ray and other Lambda integrations locally. * The component supports only Linux, for x86-64 and arm64 architectures. ## Security From de2c8504eb0cec626d0d04b84e0706f925fb5154 Mon Sep 17 00:00:00 2001 From: Callum Bodels <111014555+bodelsc@users.noreply.github.com> Date: Thu, 5 Jan 2023 17:48:01 +0000 Subject: [PATCH 02/41] docs: adding Go version/release badges (#77) Add Shield.io badges for Go version and GitHub release. Useful for displaying important information in a visually appealing and easily accessible way at a glance. Updated the license badge to standard Shield.io format. As it appears it was incorrect pointing to the badge for 'aws-sam-local'. --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 45674c9..0e5a737 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ ## AWS Lambda Runtime Interface Emulator +![GitHub release (latest by date)](https://img.shields.io/github/v/release/aws/aws-lambda-runtime-interface-emulator) +![GitHub go.mod Go version](https://img.shields.io/github/go-mod/go-version/aws/aws-lambda-runtime-interface-emulator) +![GitHub](https://img.shields.io/github/license/aws/aws-lambda-runtime-interface-emulator) -![Apache-2.0](https://img.shields.io/npm/l/aws-sam-local.svg) The Lambda Runtime Interface Emulator is a proxy for Lambda’s Runtime and Extensions APIs, which allows customers to locally test their Lambda function packaged as a container image. It is a lightweight web-server that converts From 4c56a131d9d883c80c74ac893509c38a95042ded Mon Sep 17 00:00:00 2001 From: Callum Bodels <111014555+bodelsc@users.noreply.github.com> Date: Thu, 5 Jan 2023 17:54:39 +0000 Subject: [PATCH 03/41] docs: adding content table (#78) Add content table to README.md, making it more convenient for users to find the information they need and enabling them to quickly jump to specific sections. --- README.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/README.md b/README.md index 0e5a737..daec1b8 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,20 @@ requests instead of the JSON events required for deployment to Lambda. This comp Lambda’s orchestrator, or security and authentication configurations. You can get started by downloading and installing it on your local machine. When the Lambda Runtime API emulator is executed, a `/2015-03-31/functions/function/invocations` endpoint will be stood up within the container that you post data to it in order to invoke your function for testing. +## Content +* [Installing](#installing) +* [Getting started](#getting-started) + * [Test an image with RIE included in the image](#test-an-image-with-rie-included-in-the-image) + * [To test your Lambda function with the emulator](#to-test-your-lambda-function-with-the-emulator) + * [Build RIE into your base image](#build-rie-into-your-base-image) + * [To build the emulator into your image](#to-build-the-emulator-into-your-image) + * [Test an image without adding RIE to the image](#test-an-image-without-adding-rie-to-the-image) +* [How to configure](#how-to-configure) +* [Level of support](#level-of-support) +* [Security](#security) +* [License](#license) + + ## Installing Instructions for installing AWS Lambda Runtime Interface Emulator for your platform From ee52c16e23d6a9c6c201495fb76ada13aeafd9cb Mon Sep 17 00:00:00 2001 From: Yeongrok Gim <35251295+yeongrokgim@users.noreply.github.com> Date: Fri, 6 Jan 2023 02:56:37 +0900 Subject: [PATCH 04/41] docs: Display environment variable as code snippet Add markup to environment variable for better browsing --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index daec1b8..9d97794 100644 --- a/README.md +++ b/README.md @@ -154,7 +154,7 @@ You can configure your credentials by setting: * `AWS_SESSION_TOKEN` * `AWS_REGION` -You can configure timeout by setting AWS_LAMBDA_FUNCTION_TIMEOUT to the number of seconds you want your function to timeout in. +You can configure timeout by setting `AWS_LAMBDA_FUNCTION_TIMEOUT` to the number of seconds you want your function to timeout in. The rest of these Environment Variables can be set to match AWS Lambda's environment but are not required. * `AWS_LAMBDA_FUNCTION_VERSION` From 6c827eac2139cbdfdd63fbac40a0b5a6808f6695 Mon Sep 17 00:00:00 2001 From: Callum Bodels <111014555+bodelsc@users.noreply.github.com> Date: Thu, 5 Jan 2023 17:58:20 +0000 Subject: [PATCH 05/41] docs: corrected bullet point's content spacing/indentation (#81) The content for some bullet points are not correctly indented or spaced. This causes text and code blocks to not be displayed properly, causing confusion for customers. This change indents the content of the bullet points so it is now correctly displayed increasing readability. --- README.md | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 9d97794..4782a65 100644 --- a/README.md +++ b/README.md @@ -75,9 +75,10 @@ You can build RIE into a base image. Download the RIE from GitHub to your local 1. Create a script and save it in your project directory. Set execution permissions for the script file. -The script checks for the presence of the `AWS_LAMBDA_RUNTIME_API` environment variable, which indicates the presence of the runtime API. If the runtime API is present, the script runs [the runtime interface client](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-images.html#runtimes-api-client). Otherwise, the script runs the runtime interface emulator. + The script checks for the presence of the `AWS_LAMBDA_RUNTIME_API` environment variable, which indicates the presence of the runtime API. If the runtime API is present, the script runs [the runtime interface client](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-images.html#runtimes-api-client). Otherwise, the script runs the runtime interface emulator. + + The following example shows a typical script for a Node.js function. -The following example shows a typical script for a Node.js function. ``` #!/bin/sh if [ -z "${AWS_LAMBDA_RUNTIME_API}" ]; then @@ -91,24 +92,28 @@ The following example shows a typical script for a Node.js function. 3. Install the emulator package and change `ENTRYPOINT` to run the new script by adding the following lines to your Dockerfile: -To use the default x86\_64 architecture + To use the default x86\_64 architecture + ``` ADD aws-lambda-rie /usr/local/bin/aws-lambda-rie ENTRYPOINT [ "/entry_script.sh" ] ``` -To use the arm64 architecture: + To use the arm64 architecture: + ``` ADD aws-lambda-rie-arm64 /usr/local/bin/aws-lambda-rie ENTRYPOINT [ "/entry_script.sh" ] ``` 4. Build your image locally using the docker build command. + ``` docker build -t myfunction:latest . ``` 5. Run your image locally using the docker run command. + ``` docker run -p 9000:8080 myfunction:latest ``` @@ -126,12 +131,14 @@ You install the runtime interface emulator to your local machine. When you run t && chmod +x ~/.aws-lambda-rie/aws-lambda-rie ``` -To download the RIE for arm64 architecture, use the previous command with a different GitHub download url. + To download the RIE for arm64 architecture, use the previous command with a different GitHub download url. + ``` https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest/download/aws-lambda-rie-arm64 \ ``` 2. Run your Lambda image function using the docker run command. + ``` docker run -d -v ~/.aws-lambda-rie:/aws-lambda -p 9000:8080 myfunction:latest --entrypoint /aws-lambda/aws-lambda-rie <(optional) image command>` From 559b1159e7b1cb0c771b5bb54ac86e08f8b5eef6 Mon Sep 17 00:00:00 2001 From: Jordan Brough Date: Tue, 17 Jan 2023 09:51:29 -0700 Subject: [PATCH 06/41] docs: Add syntax highlighting, fix typos, other cleanups (#82) * Add missing backslash in docker command * Remove stray "`" from command * Covert code snippet to code block for consistency * Remove extraneous newline * Fix invalid italics and use "_" for consistency * Change (invalid) italics to header Make it consistent with the preceding section that says: "To build the emulator into your image" Also, the stray whitespace before the final asterisk was incorrect. * Add syntax highlighting to code blocks * Add header to content table --- README.md | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 4782a65..3599cc0 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ Lambda’s orchestrator, or security and authentication configurations. You can * [Build RIE into your base image](#build-rie-into-your-base-image) * [To build the emulator into your image](#to-build-the-emulator-into-your-image) * [Test an image without adding RIE to the image](#test-an-image-without-adding-rie-to-the-image) + * [To test an image without adding RIE to the image](#to-test-an-image-without-adding-rie-to-the-image) * [How to configure](#how-to-configure) * [Level of support](#level-of-support) * [Security](#security) @@ -79,7 +80,7 @@ You can build RIE into a base image. Download the RIE from GitHub to your local The following example shows a typical script for a Node.js function. - ``` + ```sh #!/bin/sh if [ -z "${AWS_LAMBDA_RUNTIME_API}" ]; then exec /usr/local/bin/aws-lambda-rie /usr/bin/npx aws-lambda-ric @@ -94,38 +95,39 @@ You can build RIE into a base image. Download the RIE from GitHub to your local To use the default x86\_64 architecture - ``` + ```dockerfile ADD aws-lambda-rie /usr/local/bin/aws-lambda-rie ENTRYPOINT [ "/entry_script.sh" ] ``` To use the arm64 architecture: - ``` + ```dockerfile ADD aws-lambda-rie-arm64 /usr/local/bin/aws-lambda-rie ENTRYPOINT [ "/entry_script.sh" ] ``` 4. Build your image locally using the docker build command. - ``` + ```sh docker build -t myfunction:latest . ``` 5. Run your image locally using the docker run command. - ``` + ```sh docker run -p 9000:8080 myfunction:latest ``` ### Test an image without adding RIE to the image You install the runtime interface emulator to your local machine. When you run the container image, you set the entry point to be the emulator. -*To test an image without adding RIE to the image * + +#### To test an image without adding RIE to the image 1. From your project directory, run the following command to download the RIE (x86-64 architecture) from GitHub and install it on your local machine. - ``` + ```sh mkdir -p ~/.aws-lambda-rie && curl -Lo ~/.aws-lambda-rie/aws-lambda-rie \ https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest/download/aws-lambda-rie \ && chmod +x ~/.aws-lambda-rie/aws-lambda-rie @@ -139,16 +141,18 @@ You install the runtime interface emulator to your local machine. When you run t 2. Run your Lambda image function using the docker run command. - ``` - docker run -d -v ~/.aws-lambda-rie:/aws-lambda -p 9000:8080 myfunction:latest - --entrypoint /aws-lambda/aws-lambda-rie <(optional) image command>` + ```sh + docker run -d -v ~/.aws-lambda-rie:/aws-lambda -p 9000:8080 myfunction:latest \ + --entrypoint /aws-lambda/aws-lambda-rie <(optional) image command> ``` This runs the image as a container and starts up an endpoint locally at `localhost:9000/2015-03-31/functions/function/invocations`. 3. Post an event to the following endpoint using a curl command: - `curl -XPOST "http://localhost:9000/2015-03-31/functions/function/invocations" -d '{}'` + ```sh + curl -XPOST "http://localhost:9000/2015-03-31/functions/function/invocations" -d '{}' + ``` This command invokes the function running in the container image and returns a response. @@ -173,10 +177,9 @@ The rest of these Environment Variables can be set to match AWS Lambda's environ You can use the emulator to test if your function code is compatible with the Lambda environment, executes successfully and provides the expected output. For example, you can mock test events from different event sources. You can also use it to test extensions and agents built into the container image against the Lambda Extensions API. This component -does *not *emulate* *the orchestration behavior of AWS Lambda. For example, Lambda has a network and security +does _not_ emulate the orchestration behavior of AWS Lambda. For example, Lambda has a network and security configurations that will not be emulated by this component. - * You can use the emulator to test if your function code is compatible with the Lambda environment, runs successfully and provides the expected output. * You can also use it to test extensions and agents built into the container image against the Lambda Extensions API. * This component does _not_ emulate Lambda’s orchestration, or security and authentication configurations. From a08886c27be3214fe7a1b72bd419e51eca01afa2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 3 Mar 2023 14:01:54 -0800 Subject: [PATCH 07/41] chore(deps): bump golang.org/x/net from 0.1.0 to 0.7.0 (#85) Bumps [golang.org/x/net](https://github.com/golang/net) from 0.1.0 to 0.7.0. - [Release notes](https://github.com/golang/net/releases) - [Commits](https://github.com/golang/net/compare/v0.1.0...v0.7.0) --- updated-dependencies: - dependency-name: golang.org/x/net dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 4 ++-- go.sum | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index 871b812..278c63a 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/konsorten/go-windows-terminal-sequences v1.0.3 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.1.0 // indirect - golang.org/x/net v0.1.0 // indirect - golang.org/x/sys v0.1.0 // indirect + golang.org/x/net v0.7.0 // indirect + golang.org/x/sys v0.5.0 // indirect gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 // indirect ) diff --git a/go.sum b/go.sum index daa8fe3..905e315 100644 --- a/go.sum +++ b/go.sum @@ -28,14 +28,14 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/urfave/cli/v2 v2.2.0/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= -golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0= -golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= +golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= From c1cf0c518004962baa79bd82d5c4c2776e15073f Mon Sep 17 00:00:00 2001 From: Renato Valenzuela <37676028+valerena@users.noreply.github.com> Date: Thu, 27 Apr 2023 10:46:51 -0700 Subject: [PATCH 08/41] feat: Pull upstream changes 2023/04 (#87) --- cmd/aws-lambda-rie/handlers.go | 31 +- cmd/aws-lambda-rie/http.go | 6 +- cmd/aws-lambda-rie/main.go | 51 +- lambda/agents/agent.go | 75 +- lambda/agents/agent_test.go | 221 ++--- lambda/agents/log_line_splitter.go | 40 - lambda/appctx/appctx.go | 14 + lambda/appctx/appctxutil.go | 26 +- lambda/appctx/appctxutil_test.go | 25 + .../core/bandwidthlimiter/bandwidthlimiter.go | 61 ++ .../bandwidthlimiter/bandwidthlimiter_test.go | 106 +++ lambda/core/bandwidthlimiter/throttler.go | 154 ++++ .../core/bandwidthlimiter/throttler_test.go | 215 +++++ lambda/core/bandwidthlimiter/util.go | 46 ++ lambda/core/bandwidthlimiter/util_test.go | 45 + lambda/core/credentials.go | 42 +- lambda/core/credentials_test.go | 54 +- lambda/core/directinvoke/directinvoke.go | 288 ++++++- lambda/core/directinvoke/directinvoke_test.go | 358 ++++++++ lambda/core/directinvoke/util.go | 84 ++ lambda/core/doc.go | 38 +- lambda/core/externalagent.go | 1 - lambda/core/flow.go | 17 + lambda/core/registrations.go | 21 + lambda/core/registrations_test.go | 2 +- lambda/core/runtime_state_names.go | 12 +- lambda/core/states.go | 78 +- lambda/core/states_test.go | 134 ++- lambda/core/watchdog.go | 102 --- lambda/core/watchdog_test.go | 50 -- lambda/fatalerror/fatalerror.go | 5 +- lambda/interop/bootstrap.go | 18 + lambda/interop/cancellable_request.go | 27 + lambda/interop/environment_variables.go | 14 + lambda/interop/model.go | 331 +++++--- lambda/interop/model_test.go | 27 + lambda/interop/sandbox_model.go | 171 +++- lambda/logging/doc.go | 23 +- lambda/logging/internal_log_test.go | 26 +- lambda/logging/platform_log.go | 65 -- lambda/logging/platform_log_test.go | 42 - lambda/logging/taillog.go | 52 -- lambda/logging/taillog_test.go | 29 - lambda/rapi/handler/agentiniterror_test.go | 10 +- lambda/rapi/handler/agentnext_test.go | 16 +- lambda/rapi/handler/agentregister.go | 11 +- lambda/rapi/handler/agentregister_test.go | 9 +- lambda/rapi/handler/constants.go | 1 - lambda/rapi/handler/credentials_test.go | 37 +- lambda/rapi/handler/initerror.go | 42 +- lambda/rapi/handler/initerror_test.go | 4 +- lambda/rapi/handler/invocationerror.go | 25 +- lambda/rapi/handler/invocationerror_test.go | 6 +- lambda/rapi/handler/invocationnext_test.go | 2 +- lambda/rapi/handler/invocationresponse.go | 36 +- .../rapi/handler/invocationresponse_test.go | 95 ++- lambda/rapi/handler/restorenext.go | 40 + lambda/rapi/handler/restorenext_test.go | 87 ++ lambda/rapi/handler/runtimelogs.go | 36 +- lambda/rapi/handler/runtimelogs_stub.go | 45 +- lambda/rapi/handler/runtimelogs_stub_test.go | 16 +- lambda/rapi/handler/runtimelogs_test.go | 130 ++- lambda/rapi/middleware/middleware_test.go | 6 +- lambda/rapi/model/tracing.go | 11 +- lambda/rapi/rendering/doc.go | 2 - lambda/rapi/rendering/render_json.go | 33 + lambda/rapi/rendering/rendering.go | 84 +- lambda/rapi/router.go | 39 +- lambda/rapi/router_test.go | 47 +- lambda/rapi/security_test.go | 6 +- lambda/rapi/server.go | 16 +- lambda/rapi/server_test.go | 3 +- lambda/rapid/bootstrap.go | 18 - lambda/rapid/exit.go | 147 ++-- lambda/rapid/graceful_shutdown.go | 198 ----- lambda/rapid/sandbox.go | 165 ++-- lambda/rapid/shutdown.go | 366 +++++++++ lambda/rapid/start.go | 776 +++++++++++------- lambda/rapid/start_test.go | 34 +- lambda/rapidcore/bootstrap.go | 79 +- lambda/rapidcore/bootstrap_test.go | 115 ++- lambda/rapidcore/env/environment.go | 18 +- lambda/rapidcore/env/environment_test.go | 41 +- lambda/rapidcore/errors.go | 6 +- lambda/rapidcore/sandbox.go | 259 ------ lambda/rapidcore/sandbox_api.go | 147 ++++ lambda/rapidcore/sandbox_builder.go | 217 +++++ lambda/rapidcore/sandbox_emulator_api.go | 52 ++ lambda/rapidcore/server.go | 621 +++++++++----- lambda/rapidcore/server_test.go | 370 ++++++--- .../standalone/directInvokeHandler.go | 16 +- lambda/rapidcore/standalone/executeHandler.go | 17 +- lambda/rapidcore/standalone/initHandler.go | 70 +- .../standalone/internalStateHandler.go | 4 +- lambda/rapidcore/standalone/invokeHandler.go | 34 +- lambda/rapidcore/standalone/pingHandler.go | 12 + lambda/rapidcore/standalone/reserveHandler.go | 12 +- lambda/rapidcore/standalone/resetHandler.go | 4 +- lambda/rapidcore/standalone/restoreHandler.go | 41 + lambda/rapidcore/standalone/router.go | 27 +- .../rapidcore/standalone/shutdownHandler.go | 6 +- lambda/rapidcore/standalone/util.go | 4 +- .../standalone/waitUntilInitializedHandler.go | 23 + .../standalone/waitUntilReleaseHandler.go | 4 +- lambda/rapidcore/telemetry/eventLog.go | 25 +- lambda/rapidcore/telemetry/events_api.go | 97 +++ lambda/runtimecmd/runtime_command.go | 57 -- lambda/runtimecmd/runtime_command_test.go | 51 -- lambda/supervisor/local_supervisor.go | 302 +++++++ lambda/supervisor/local_supervisor_test.go | 215 +++++ lambda/supervisor/model/model.go | 269 ++++++ lambda/telemetry/events_api.go | 128 ++- lambda/telemetry/events_api_test.go | 139 ++++ lambda/telemetry/logs_egress_api.go | 7 +- lambda/telemetry/logs_subscription_api.go | 29 +- lambda/telemetry/tracer.go | 22 + lambda/telemetry/tracer_test.go | 46 ++ lambda/testdata/agents/bash_stderr.sh | 5 - lambda/testdata/agents/bash_stdout.sh | 5 - .../testdata/agents/bash_stdout_and_stderr.sh | 8 - lambda/testdata/flowtesting.go | 123 ++- .../local_lambda/test_end_to_end.py | 18 + 122 files changed, 6845 insertions(+), 2726 deletions(-) delete mode 100644 lambda/agents/log_line_splitter.go create mode 100644 lambda/core/bandwidthlimiter/bandwidthlimiter.go create mode 100644 lambda/core/bandwidthlimiter/bandwidthlimiter_test.go create mode 100644 lambda/core/bandwidthlimiter/throttler.go create mode 100644 lambda/core/bandwidthlimiter/throttler_test.go create mode 100644 lambda/core/bandwidthlimiter/util.go create mode 100644 lambda/core/bandwidthlimiter/util_test.go create mode 100644 lambda/core/directinvoke/directinvoke_test.go create mode 100644 lambda/core/directinvoke/util.go delete mode 100644 lambda/core/watchdog.go delete mode 100644 lambda/core/watchdog_test.go create mode 100644 lambda/interop/bootstrap.go create mode 100644 lambda/interop/cancellable_request.go create mode 100644 lambda/interop/environment_variables.go create mode 100644 lambda/interop/model_test.go delete mode 100644 lambda/logging/platform_log.go delete mode 100644 lambda/logging/platform_log_test.go delete mode 100644 lambda/logging/taillog.go delete mode 100644 lambda/logging/taillog_test.go create mode 100644 lambda/rapi/handler/restorenext.go create mode 100644 lambda/rapi/handler/restorenext_test.go create mode 100644 lambda/rapi/rendering/render_json.go delete mode 100644 lambda/rapid/bootstrap.go delete mode 100644 lambda/rapid/graceful_shutdown.go create mode 100644 lambda/rapid/shutdown.go delete mode 100644 lambda/rapidcore/sandbox.go create mode 100644 lambda/rapidcore/sandbox_api.go create mode 100644 lambda/rapidcore/sandbox_builder.go create mode 100644 lambda/rapidcore/sandbox_emulator_api.go create mode 100644 lambda/rapidcore/standalone/pingHandler.go create mode 100644 lambda/rapidcore/standalone/restoreHandler.go create mode 100644 lambda/rapidcore/standalone/waitUntilInitializedHandler.go create mode 100644 lambda/rapidcore/telemetry/events_api.go delete mode 100644 lambda/runtimecmd/runtime_command.go delete mode 100644 lambda/runtimecmd/runtime_command_test.go create mode 100644 lambda/supervisor/local_supervisor.go create mode 100644 lambda/supervisor/local_supervisor_test.go create mode 100644 lambda/supervisor/model/model.go create mode 100644 lambda/telemetry/events_api_test.go delete mode 100755 lambda/testdata/agents/bash_stderr.sh delete mode 100755 lambda/testdata/agents/bash_stdout.sh delete mode 100755 lambda/testdata/agents/bash_stdout_and_stderr.sh diff --git a/cmd/aws-lambda-rie/handlers.go b/cmd/aws-lambda-rie/handlers.go index 39097fc..42032cf 100644 --- a/cmd/aws-lambda-rie/handlers.go +++ b/cmd/aws-lambda-rie/handlers.go @@ -14,8 +14,10 @@ import ( "strings" "time" + "go.amzn.com/lambda/core/statejson" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapidcore" + "go.amzn.com/lambda/rapidcore/env" "github.com/google/uuid" @@ -27,6 +29,19 @@ type Sandbox interface { Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error } +type InteropServer interface { + Init(i *interop.Init, invokeTimeoutMs int64) error + AwaitInitialized() error + FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error + Reserve(id string, traceID, lambdaSegmentID string) (*rapidcore.ReserveResponse, error) + Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) + AwaitRelease() (*statejson.InternalStateDescription, error) + Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription + InternalState() (*statejson.InternalStateDescription, error) + CurrentToken() *interop.Token + Restore(restore *interop.Restore) error +} + var initDone bool func GetenvWithDefault(key string, defaultValue string) string { @@ -57,7 +72,7 @@ func printEndReports(invokeId string, initDuration string, memorySize string, in invokeId, invokeDuration, math.Ceil(invokeDuration), memorySize, memorySize) } -func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox) { +func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox, bs interop.Bootstrap) { log.Debugf("invoke: -> %s %s %v", r.Method, r.URL, r.Header) bodyBytes, err := ioutil.ReadAll(r.Body) if err != nil { @@ -80,7 +95,7 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox) { if !initDone { - initStart, initEnd := InitHandler(sandbox, functionVersion, timeout) + initStart, initEnd := InitHandler(sandbox, functionVersion, timeout, bs) // Calculate InitDuration initTimeMS := math.Min(float64(initEnd.Sub(initStart).Nanoseconds()), @@ -99,7 +114,6 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox) { TraceID: r.Header.Get("X-Amzn-Trace-Id"), LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), Payload: bytes.NewReader(bodyBytes), - CorrelationID: "invokeCorrelationID", } fmt.Println("START RequestId: " + invokePayload.ID + " Version: " + functionVersion) @@ -166,7 +180,7 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox) { w.Write(invokeResp.Body) } -func InitHandler(sandbox Sandbox, functionVersion string, timeout int64) (time.Time, time.Time) { +func InitHandler(sandbox Sandbox, functionVersion string, timeout int64, bs interop.Bootstrap) (time.Time, time.Time) { additionalFunctionEnvironmentVariables := map[string]string{} // Add default Env Vars if they were not defined. This is a required otherwise 1p Python2.7, Python3.6, and @@ -189,15 +203,20 @@ func InitHandler(sandbox Sandbox, functionVersion string, timeout int64) (time.T // pass to rapid sandbox.Init(&interop.Init{ Handler: GetenvWithDefault("AWS_LAMBDA_FUNCTION_HANDLER", os.Getenv("_HANDLER")), - CorrelationID: "initCorrelationID", AwsKey: os.Getenv("AWS_ACCESS_KEY_ID"), AwsSecret: os.Getenv("AWS_SECRET_ACCESS_KEY"), AwsSession: os.Getenv("AWS_SESSION_TOKEN"), XRayDaemonAddress: "0.0.0.0:0", // TODO FunctionName: GetenvWithDefault("AWS_LAMBDA_FUNCTION_NAME", "test_function"), FunctionVersion: functionVersion, - + RuntimeInfo: interop.RuntimeInfo{ + ImageJSON: "{}", + Arn: "", + Version: ""}, CustomerEnvironmentVariables: additionalFunctionEnvironmentVariables, + SandboxType: interop.SandboxClassic, + Bootstrap: bs, + EnvironmentVariables: env.NewEnvironment(), }, timeout*1000) initEnd := time.Now() return initStart, initEnd diff --git a/cmd/aws-lambda-rie/http.go b/cmd/aws-lambda-rie/http.go index be4002d..88bd39b 100644 --- a/cmd/aws-lambda-rie/http.go +++ b/cmd/aws-lambda-rie/http.go @@ -7,16 +7,18 @@ import ( "net/http" log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapidcore" ) -func startHTTPServer(ipport string, sandbox Sandbox) { +func startHTTPServer(ipport string, sandbox *rapidcore.SandboxBuilder, bs interop.Bootstrap) { srv := &http.Server{ Addr: ipport, } // Pass a channel http.HandleFunc("/2015-03-31/functions/function/invocations", func(w http.ResponseWriter, r *http.Request) { - InvokeHandler(w, r, sandbox) + InvokeHandler(w, r, sandbox.LambdaInvokeAPI(), bs) }) // go routine (main thread waits) diff --git a/cmd/aws-lambda-rie/main.go b/cmd/aws-lambda-rie/main.go index 3a87e46..65879c0 100644 --- a/cmd/aws-lambda-rie/main.go +++ b/cmd/aws-lambda-rie/main.go @@ -6,6 +6,7 @@ package main import ( "context" "fmt" + "net" "os" "runtime/debug" @@ -21,8 +22,11 @@ const ( ) type options struct { - LogLevel string `long:"log-level" default:"info" description:"log level"` + LogLevel string `long:"log-level" description:"The level of AWS Lambda Runtime Interface Emulator logs to display. Can also be set by the environment variable 'LOG_LEVEL'. Defaults to the value 'info'."` InitCachingEnabled bool `long:"enable-init-caching" description:"Enable support for Init Caching"` + // Do not have a default value so we do not need to keep it in sync with the default value in lambda/rapidcore/sandbox_builder.go + RuntimeAPIAddress string `long:"runtime-api-address" description:"The address of the AWS Lambda Runtime API to communicate with the Lambda execution environment."` + RuntimeInterfaceEmulatorAddress string `long:"runtime-interface-emulator-address" default:"0.0.0.0:8080" description:"The address for the AWS Lambda Runtime Interface Emulator to accept HTTP request upon."` } func main() { @@ -30,11 +34,37 @@ func main() { debug.SetGCPercent(33) opts, args := getCLIArgs() - rapidcore.SetLogLevel(opts.LogLevel) + + logLevel := "info" + + // If you specify an option by using a parameter on the CLI command line, it overrides any value from either the corresponding environment variable. + // + // https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html + if opts.LogLevel != "" { + logLevel = opts.LogLevel + } else if envLogLevel, envLogLevelSet := os.LookupEnv("LOG_LEVEL"); envLogLevelSet { + logLevel = envLogLevel + } + + rapidcore.SetLogLevel(logLevel) + + if opts.RuntimeAPIAddress != "" { + _, _, err := net.SplitHostPort(opts.RuntimeAPIAddress) + + if err != nil { + log.WithError(err).Fatalf("The command line value for \"--runtime-api-address\" is not a valid network address %q.", opts.RuntimeAPIAddress) + } + } + + _, _, err := net.SplitHostPort(opts.RuntimeInterfaceEmulatorAddress) + + if err != nil { + log.WithError(err).Fatalf("The command line value for \"--runtime-interface-emulator-address\" is not a valid network address %q.", opts.RuntimeInterfaceEmulatorAddress) + } bootstrap, handler := getBootstrap(args, opts) sandbox := rapidcore. - NewSandboxBuilder(bootstrap). + NewSandboxBuilder(). AddShutdownFunc(context.CancelFunc(func() { os.Exit(0) })). SetExtensionsFlag(true). SetInitCachingFlag(opts.InitCachingEnabled) @@ -43,10 +73,17 @@ func main() { sandbox.SetHandler(handler) } - go sandbox.Create() + if opts.RuntimeAPIAddress != "" { + sandbox.SetRuntimeAPIAddress(opts.RuntimeAPIAddress) + } + + sandboxContext, internalStateFn := sandbox.Create() + // Since we have not specified a custom interop server for standalone, we can + // directly reference the default interop server, which is a concrete type + sandbox.DefaultInteropServer().SetSandboxContext(sandboxContext) + sandbox.DefaultInteropServer().SetInternalStateGetter(internalStateFn) - testAPIipport := "0.0.0.0:8080" - startHTTPServer(testAPIipport, sandbox) + startHTTPServer(opts.RuntimeInterfaceEmulatorAddress, sandbox, bootstrap) } func getCLIArgs() (options, []string) { @@ -112,5 +149,5 @@ func getBootstrap(args []string, opts options) (*rapidcore.Bootstrap, string) { log.Panic("insufficient arguments: bootstrap not provided") } - return rapidcore.NewBootstrapSingleCmd(bootstrapLookupCmd, currentWorkingDir), handler + return rapidcore.NewBootstrapSingleCmd(bootstrapLookupCmd, currentWorkingDir, ""), handler } diff --git a/lambda/agents/agent.go b/lambda/agents/agent.go index 16625c2..b1f8563 100644 --- a/lambda/agents/agent.go +++ b/lambda/agents/agent.go @@ -4,77 +4,38 @@ package agents import ( - "fmt" - "io" - "io/ioutil" - "os/exec" + "os" "path" - "syscall" + "path/filepath" log "github.com/sirupsen/logrus" ) -// AgentProcess is the common interface exposed by both internal and external agent processes -type AgentProcess interface { - Name() string -} - -// ExternalAgentProcess represents an external agent process -type ExternalAgentProcess struct { - cmd *exec.Cmd -} - -// NewExternalAgentProcess returns a new external agent process -func NewExternalAgentProcess(path string, env []string, stdoutWriter io.Writer, stderrWriter io.Writer) ExternalAgentProcess { - command := exec.Command(path) - command.Env = env - - command.Stdout = NewNewlineSplitWriter(stdoutWriter) - command.Stderr = NewNewlineSplitWriter(stderrWriter) - command.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} - - return ExternalAgentProcess{ - cmd: command, - } -} - -// Name returns the name of the agent -// For external agents is the executable name -func (a *ExternalAgentProcess) Name() string { - return path.Base(a.cmd.Path) -} - -func (a *ExternalAgentProcess) Pid() int { - return a.cmd.Process.Pid -} - -// Start starts an external agent process -func (a *ExternalAgentProcess) Start() error { - return a.cmd.Start() -} - -// Wait waits for the external agent process to exit -func (a *ExternalAgentProcess) Wait() error { - return a.cmd.Wait() -} - -// String is used to print values passed as an operand to any format that accepts a string or to an unformatted printer such as Print. -func (a *ExternalAgentProcess) String() string { - return fmt.Sprintf("%s (%s)", a.Name(), a.cmd.Path) -} - // ListExternalAgentPaths return a list of external agents found in a given directory -func ListExternalAgentPaths(root string) []string { +func ListExternalAgentPaths(dir string, root string) []string { var agentPaths []string - files, err := ioutil.ReadDir(root) + if !isCanonical(dir) || !isCanonical(root) { + log.Warningf("Agents base paths are not absolute and in canonical form: %s, %s", dir, root) + return agentPaths + } + fullDir := path.Join(root, dir) + files, err := os.ReadDir(fullDir) if err != nil { log.WithError(err).Warning("Cannot list external agents") return agentPaths } for _, file := range files { if !file.IsDir() { - agentPaths = append(agentPaths, path.Join(root, file.Name())) + // The returned path is absolute wrt to `root`. This allows + // to exec the agents in their own mount namespace + p := path.Join("/", dir, file.Name()) + agentPaths = append(agentPaths, p) } } return agentPaths } + +func isCanonical(path string) bool { + absPath, err := filepath.Abs(path) + return err == nil && absPath == path +} diff --git a/lambda/agents/agent_test.go b/lambda/agents/agent_test.go index d314a76..e6732ff 100644 --- a/lambda/agents/agent_test.go +++ b/lambda/agents/agent_test.go @@ -4,13 +4,12 @@ package agents import ( - "bytes" - "io/ioutil" "os" "path" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // - Test utilities @@ -50,14 +49,8 @@ func mkLink(name, target string) fileInfo { } } -// populate a temporary directory with a list of files and directories -// returns the name of the temporary root directory -func createFileTree(fs []fileInfo) (string, error) { - - root, err := ioutil.TempDir(os.TempDir(), "tmp-") - if err != nil { - return "", err - } +// populate a directory with a list of files and directories +func createFileTree(root string, fs []fileInfo) error { for _, info := range fs { filename := info.name @@ -65,67 +58,40 @@ func createFileTree(fs []fileInfo) (string, error) { name := path.Base(filename) err := os.MkdirAll(dir, 0775) if err != nil && !os.IsExist(err) { - return "", err + return err } if os.ModeDir == info.mode&os.ModeDir { err := os.Mkdir(path.Join(dir, name), info.mode&os.ModePerm) if err != nil { - return "", err + return err } } else if os.ModeSymlink == info.mode&os.ModeSymlink { target := path.Join(root, info.target) _, err = os.Stat(target) if err != nil { - return "", err + return err } err := os.Symlink(target, path.Join(dir, name)) if err != nil { - return "", err + return err } } else { file, err := os.OpenFile(path.Join(dir, name), os.O_RDWR|os.O_CREATE, info.mode&os.ModePerm) if err != nil { - return "", err + return err } file.Truncate(info.size) file.Close() } } - return root, nil -} - -// executes a given closure inside a temporary directory populated with the given FS tree -func within(fs []fileInfo, closure func()) error { - - var root string - var cwd string - var err error - - if root, err = createFileTree(fs); err != nil { - return err - } - - defer os.RemoveAll(root) - - if cwd, err = os.Getwd(); err != nil { - return err - } - - if err = os.Chdir(root); err != nil { - return err - } - - defer os.Chdir(cwd) - - closure() return nil } // - Actual tests // If the agents folder is empty it is not an error -func TestRootEmpty(t *testing.T) { +func TestBaseEmpty(t *testing.T) { assert := assert.New(t) @@ -133,34 +99,51 @@ func TestRootEmpty(t *testing.T) { mkDir("/opt/extensions", 0777), } - assert.NoError(within(fs, func() { - agents := ListExternalAgentPaths("opt/extensions") - assert.Equal(0, len(agents)) - })) + tmpDir, err := os.MkdirTemp("", "ext-") + require.NoError(t, err) + + createFileTree(tmpDir, fs) + defer os.RemoveAll(tmpDir) + + agents := ListExternalAgentPaths(path.Join(tmpDir, "/opt/extensions"), "/") + assert.Equal(0, len(agents)) } // Test that non-existant /opt/extensions is treated as if no agents were found -func TestRootNotExist(t *testing.T) { +func TestBaseNotExist(t *testing.T) { assert := assert.New(t) - agents := ListExternalAgentPaths("/path/which/does/not/exist") + agents := ListExternalAgentPaths("/path/which/does/not/exist", "/") + assert.Equal(0, len(agents)) +} + +// Test that non-existant root dir is teaded as if no agents were found +func TestChrootNotExist(t *testing.T) { + + assert := assert.New(t) + + agents := ListExternalAgentPaths("/bin", "/does/not/exist") assert.Equal(0, len(agents)) } // Test that non-directory /opt/extensions is treated as if no agents were found -func TestRootNotDir(t *testing.T) { +func TestBaseNotDir(t *testing.T) { assert := assert.New(t) fs := []fileInfo{ mkFile("/opt/extensions", 1, 0777), } + tmpDir, err := os.MkdirTemp("", "ext-") + require.NoError(t, err) + + createFileTree(tmpDir, fs) + defer os.RemoveAll(tmpDir) - assert.NoError(within(fs, func() { - agents := ListExternalAgentPaths("opt/extensions") - assert.Equal(0, len(agents)) - })) + path := path.Join(tmpDir, "/opt/extensions") + agents := ListExternalAgentPaths(path, "/") + assert.Equal(0, len(agents)) } // Test our ability to find agent bootstraps in the FS and return them sorted. @@ -188,99 +171,63 @@ func TestFindAgentMixed(t *testing.T) { fs := append([]fileInfo{}, listed...) fs = append(fs, unlisted...) - assert.NoError(within(fs, func() { - agentPaths := ListExternalAgentPaths("opt/extensions") - assert.Equal(len(listed), len(agentPaths)) - last := "" - for index := range listed { - if len(last) > 0 { - assert.GreaterOrEqual(agentPaths[index], last) - } - last = agentPaths[index] - } - })) -} - -// Test our ability to start agents -func TestAgentStart(t *testing.T) { - assert := assert.New(t) - agent := NewExternalAgentProcess("../testdata/agents/bash_true.sh", []string{}, &mockWriter{}, &mockWriter{}) - assert.Nil(agent.Start()) - assert.Nil(agent.Wait()) -} + tmpDir, err := os.MkdirTemp("", "ext-") + require.NoError(t, err) -// Test that execution of invalid agents is correctly reported -func TestInvalidAgentStart(t *testing.T) { - assert := assert.New(t) - agent := NewExternalAgentProcess("/bin/none", []string{}, &mockWriter{}, &mockWriter{}) - assert.True(os.IsNotExist(agent.Start())) -} + createFileTree(tmpDir, fs) + defer os.RemoveAll(tmpDir) -func TestAgentStdoutWriter(t *testing.T) { - // Given - assert := assert.New(t) - - stdout := &mockWriter{} - stderr := &mockWriter{} - expectedStdout := "stdout line 1\nstdout line 2\nstdout line 3\n" - expectedStderr := "" - - agent := NewExternalAgentProcess("../testdata/agents/bash_stdout.sh", []string{}, stdout, stderr) - - // When - assert.NoError(agent.Start()) - assert.NoError(agent.Wait()) - - // Then - assert.Equal(expectedStdout, string(bytes.Join(stdout.bytesReceived, []byte("")))) - assert.Equal(expectedStderr, string(bytes.Join(stderr.bytesReceived, []byte("")))) + path := path.Join(tmpDir, "/opt/extensions") + agentPaths := ListExternalAgentPaths(path, "/") + assert.Equal(len(listed), len(agentPaths)) + last := "" + for index := range listed { + if len(last) > 0 { + assert.GreaterOrEqual(agentPaths[index], last) + } + last = agentPaths[index] + } } -func TestAgentStderrWriter(t *testing.T) { - // Given - assert := assert.New(t) - - stdout := &mockWriter{} - stderr := &mockWriter{} - expectedStdout := "" - expectedStderr := "stderr line 1\nstderr line 2\nstderr line 3\n" - - agent := NewExternalAgentProcess("../testdata/agents/bash_stderr.sh", []string{}, stdout, stderr) - - // When - assert.NoError(agent.Start()) - assert.NoError(agent.Wait()) - - // Then - assert.Equal(expectedStdout, string(bytes.Join(stdout.bytesReceived, []byte("")))) - assert.Equal(expectedStderr, string(bytes.Join(stderr.bytesReceived, []byte("")))) -} +// Test our ability to find agent bootstraps in the FS and return them sorted, +// when using a different mount namespace root for the extensiosn domain. +// Even if not all files are valid as executable agents, +// ListExternalAgentPaths() is expected to return all of them. +func TestFindAgentMixedInChroot(t *testing.T) { -func TestAgentStdoutAndStderrSeperateWriters(t *testing.T) { - // Given assert := assert.New(t) - stdout := &mockWriter{} - stderr := &mockWriter{} - expectedStdout := "stdout line 1\nstdout line 2\nstdout line 3\n" - expectedStderr := "stderr line 1\nstderr line 2\nstderr line 3\n" + listed := []fileInfo{ + mkFile("/opt/extensions/ok2", 1, 0777), // this is ok + mkFile("/opt/extensions/ok1", 1, 0777), // this is ok as well + mkFile("/opt/extensions/not_exec", 1, 0666), // this is not executable + mkFile("/opt/extensions/not_read", 1, 0333), // this is not readable + mkFile("/opt/extensions/empty_file", 0, 0777), // this is empty + mkLink("/opt/extensions/link", "/opt/extensions/ok1"), // symlink must be ignored + } - agent := NewExternalAgentProcess("../testdata/agents/bash_stdout_and_stderr.sh", []string{}, stdout, stderr) + unlisted := []fileInfo{ + mkDir("/opt/extensions/empty_dir", 0777), // this is an empty directory + mkDir("/opt/extensions/nonempty_dir", 0777), // subdirs should not be listed + mkFile("/opt/extensions/nonempty_dir/notok", 1, 0777), // files in subdirs should not be listed + } - // When - assert.NoError(agent.Start()) - assert.NoError(agent.Wait()) + fs := append([]fileInfo{}, listed...) + fs = append(fs, unlisted...) - // Then - assert.Equal(expectedStdout, string(bytes.Join(stdout.bytesReceived, []byte("")))) - assert.Equal(expectedStderr, string(bytes.Join(stderr.bytesReceived, []byte("")))) -} + rootDir, err := os.MkdirTemp("", "rootfs") + require.NoError(t, err) -type mockWriter struct { - bytesReceived [][]byte -} + createFileTree(rootDir, fs) + defer os.RemoveAll(rootDir) -func (m *mockWriter) Write(bytes []byte) (int, error) { - m.bytesReceived = append(m.bytesReceived, bytes) - return len(bytes), nil + agentPaths := ListExternalAgentPaths("/opt/extensions", rootDir) + assert.Equal(len(listed), len(agentPaths)) + last := "" + for index := range listed { + if len(last) > 0 { + assert.GreaterOrEqual(agentPaths[index], last) + } + last = agentPaths[index] + } } diff --git a/lambda/agents/log_line_splitter.go b/lambda/agents/log_line_splitter.go deleted file mode 100644 index ac2c134..0000000 --- a/lambda/agents/log_line_splitter.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package agents - -import ( - "bytes" - "io" -) - -// NewlineSplitWriter wraps an io.Writer and calls the underlying writer for each newline separated line -type NewlineSplitWriter struct { - writer io.Writer -} - -// NewNewlineSplitWriter returns an instance of NewlineSplitWriter -func NewNewlineSplitWriter(w io.Writer) *NewlineSplitWriter { - return &NewlineSplitWriter{ - writer: w, - } -} - -// Write splits the byte buffer by newline and calls the underlying writer for each line -func (nsw *NewlineSplitWriter) Write(buf []byte) (int, error) { - newBuf := make([]byte, len(buf)) - copy(newBuf, buf) - lines := bytes.SplitAfter(newBuf, []byte("\n")) - var bytesWritten int - for _, line := range lines { - if len(line) > 0 { - n, err := nsw.writer.Write(line) - bytesWritten += n - if err != nil { - return bytesWritten, err - } - } - } - - return bytesWritten, nil -} diff --git a/lambda/appctx/appctx.go b/lambda/appctx/appctx.go index 44776ab..6c81653 100644 --- a/lambda/appctx/appctx.go +++ b/lambda/appctx/appctx.go @@ -10,6 +10,8 @@ import ( // A Key type is used as a key for storing values in the application context. type Key int +type InitType int + const ( // AppCtxInvokeErrorResponseKey is used for storing deferred invoke error response. // Only used by xray. TODO refactor xray interface so it doesn't use appctx @@ -23,6 +25,18 @@ const ( // AppCtxFirstFatalErrorKey is used to store first unrecoverable error message encountered to propagate it to slicer with DONE(errortype) or DONEFAIL(errortype) AppCtxFirstFatalErrorKey + + // AppCtxInitType is used to store the init type (init caching or plain INIT) + AppCtxInitType + + // AppCtxSandbox type is used to store the sandbox type (SandboxClassic or SandboxPreWarmed) + AppCtxSandboxType +) + +// Possible values for AppCtxInitType key +const ( + Init InitType = iota + InitCaching ) // ApplicationContext is an application scope context. diff --git a/lambda/appctx/appctxutil.go b/lambda/appctx/appctxutil.go index a3e652f..a30677f 100644 --- a/lambda/appctx/appctxutil.go +++ b/lambda/appctx/appctxutil.go @@ -5,11 +5,12 @@ package appctx import ( "context" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" "net/http" "strings" + "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/interop" + log "github.com/sirupsen/logrus" ) @@ -164,3 +165,24 @@ func LoadFirstFatalError(appCtx ApplicationContext) (errorType fatalerror.ErrorT } return v.(fatalerror.ErrorType), true } + +func StoreInitType(appCtx ApplicationContext, initCachingEnabled bool) { + if initCachingEnabled { + appCtx.Store(AppCtxInitType, InitCaching) + } else { + appCtx.Store(AppCtxInitType, Init) + } +} + +// Default Init Type is Init unless it's explicitly stored in ApplicationContext +func LoadInitType(appCtx ApplicationContext) InitType { + return appCtx.GetOrDefault(AppCtxInitType, Init).(InitType) +} + +func StoreSandboxType(appCtx ApplicationContext, sandboxType interop.SandboxType) { + appCtx.Store(AppCtxSandboxType, sandboxType) +} + +func LoadSandboxType(appCtx ApplicationContext) interop.SandboxType { + return appCtx.GetOrDefault(AppCtxSandboxType, interop.SandboxClassic).(interop.SandboxType) +} diff --git a/lambda/appctx/appctxutil_test.go b/lambda/appctx/appctxutil_test.go index a8a4761..b6df9aa 100644 --- a/lambda/appctx/appctxutil_test.go +++ b/lambda/appctx/appctxutil_test.go @@ -11,6 +11,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.amzn.com/lambda/fatalerror" + + "go.amzn.com/lambda/interop" ) func runTestRequestWithUserAgent(t *testing.T, userAgent string, expectedRuntimeRelease string) { @@ -200,3 +202,26 @@ func TestFirstFatalError(t *testing.T) { require.True(t, found) require.Equal(t, fatalerror.AgentCrash, v) } + +func TestStoreLoadInitType(t *testing.T) { + appCtx := NewApplicationContext() + + initType := LoadInitType(appCtx) + assert.Equal(t, Init, initType) + + StoreInitType(appCtx, true) + initType = LoadInitType(appCtx) + assert.Equal(t, InitCaching, initType) +} + +func TestStoreLoadSandboxType(t *testing.T) { + appCtx := NewApplicationContext() + + sandboxType := LoadSandboxType(appCtx) + assert.Equal(t, interop.SandboxClassic, sandboxType) + + StoreSandboxType(appCtx, interop.SandboxPreWarmed) + + sandboxType = LoadSandboxType(appCtx) + assert.Equal(t, interop.SandboxPreWarmed, sandboxType) +} diff --git a/lambda/core/bandwidthlimiter/bandwidthlimiter.go b/lambda/core/bandwidthlimiter/bandwidthlimiter.go new file mode 100644 index 0000000..05c600a --- /dev/null +++ b/lambda/core/bandwidthlimiter/bandwidthlimiter.go @@ -0,0 +1,61 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "io" + + "go.amzn.com/lambda/interop" +) + +func BandwidthLimitingCopy(dst *BandwidthLimitingWriter, src io.Reader) (written int64, err error) { + written, err = io.Copy(dst, src) + _ = dst.Close() + return +} + +func NewBandwidthLimitingWriter(w io.Writer, bucket *Bucket) (*BandwidthLimitingWriter, error) { + throttler, err := NewThrottler(bucket) + if err != nil { + return nil, err + } + return &BandwidthLimitingWriter{w: w, th: throttler}, nil +} + +type BandwidthLimitingWriter struct { + w io.Writer + th *Throttler +} + +func (w *BandwidthLimitingWriter) ChunkedWrite(p []byte) (n int, err error) { + i := NewChunkIterator(p, int(w.th.b.capacity)) + for { + buf := i.Next() + if buf == nil { + return + } + written, writeErr := w.th.bandwidthLimitingWrite(w.w, buf) + n += written + if writeErr != nil { + return n, writeErr + } + } +} + +func (w *BandwidthLimitingWriter) Write(p []byte) (n int, err error) { + w.th.start() + if int64(len(p)) > w.th.b.capacity { + return w.ChunkedWrite(p) + } + return w.th.bandwidthLimitingWrite(w.w, p) +} + +func (w *BandwidthLimitingWriter) Close() (err error) { + w.th.stop() + return +} + +func (w *BandwidthLimitingWriter) GetMetrics() (metrics *interop.InvokeResponseMetrics) { + return w.th.metrics +} diff --git a/lambda/core/bandwidthlimiter/bandwidthlimiter_test.go b/lambda/core/bandwidthlimiter/bandwidthlimiter_test.go new file mode 100644 index 0000000..7ede24b --- /dev/null +++ b/lambda/core/bandwidthlimiter/bandwidthlimiter_test.go @@ -0,0 +1,106 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "bytes" + "io" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBandwidthLimitingCopy(t *testing.T) { + var size10mb int64 = 10 * 1024 * 1024 + + inputBuffer := []byte(strings.Repeat("a", int(size10mb))) + reader := bytes.NewReader(inputBuffer) + + bucket, err := NewBucket(size10mb/2, size10mb/4, size10mb/2, time.Millisecond/2) + assert.NoError(t, err) + + internalWriter := bytes.NewBuffer(make([]byte, 0, size10mb)) + writer, err := NewBandwidthLimitingWriter(internalWriter, bucket) + assert.NoError(t, err) + + n, err := BandwidthLimitingCopy(writer, reader) + assert.Equal(t, size10mb, n) + assert.Equal(t, nil, err) + assert.Equal(t, inputBuffer, internalWriter.Bytes()) +} + +type ErrorBufferWriter struct { + w ByteBufferWriter + failAfter int +} + +func (w *ErrorBufferWriter) Write(p []byte) (n int, err error) { + if w.failAfter >= 1 { + w.failAfter-- + } + n, err = w.w.Write(p) + if w.failAfter == 0 { + return n, io.ErrUnexpectedEOF + } + return n, err +} + +func (w *ErrorBufferWriter) Bytes() []byte { + return w.w.Bytes() +} + +func TestNewBandwidthLimitingWriter(t *testing.T) { + type testCase struct { + refillNumber int64 + internalWriter ByteBufferWriter + inputBuffer []byte + expectedN int + expectedError error + } + testCases := []testCase{ + { + refillNumber: 2, + internalWriter: bytes.NewBuffer(make([]byte, 0, 36)), // buffer size greater than bucket size + inputBuffer: []byte(strings.Repeat("a", 36)), + expectedN: 36, + expectedError: nil, + }, + { + refillNumber: 2, + internalWriter: bytes.NewBuffer(make([]byte, 0, 12)), // buffer size lesser than bucket size + inputBuffer: []byte(strings.Repeat("a", 12)), + expectedN: 12, + expectedError: nil, + }, + { + // buffer size greater than bucket size and error after two Write() invocations + refillNumber: 2, + internalWriter: &ErrorBufferWriter{w: bytes.NewBuffer(make([]byte, 0, 36)), failAfter: 2}, + inputBuffer: []byte(strings.Repeat("a", 36)), + expectedN: 32, + expectedError: io.ErrUnexpectedEOF, + }, + } + + for _, test := range testCases { + bucket, err := NewBucket(16, 8, test.refillNumber, 100*time.Millisecond) + assert.NoError(t, err) + + writer, err := NewBandwidthLimitingWriter(test.internalWriter, bucket) + assert.NoError(t, err) + assert.False(t, writer.th.running) + + n, err := writer.Write(test.inputBuffer) + assert.True(t, writer.th.running) + assert.Equal(t, test.expectedN, n) + assert.Equal(t, test.expectedError, err) + assert.Equal(t, test.inputBuffer[:n], test.internalWriter.Bytes()) + + err = writer.Close() + assert.Nil(t, err) + assert.False(t, writer.th.running) + } +} diff --git a/lambda/core/bandwidthlimiter/throttler.go b/lambda/core/bandwidthlimiter/throttler.go new file mode 100644 index 0000000..b3b57dd --- /dev/null +++ b/lambda/core/bandwidthlimiter/throttler.go @@ -0,0 +1,154 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "errors" + "fmt" + "io" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" +) + +var ErrBufferSizeTooLarge = errors.New("buffer size cannot be greater than bucket size") + +func NewBucket(capacity int64, initialTokenCount int64, refillNumber int64, refillInterval time.Duration) (*Bucket, error) { + if capacity <= 0 || initialTokenCount < 0 || refillNumber <= 0 || refillInterval <= 0 || + capacity < initialTokenCount { + errorMsg := fmt.Sprintf("invalid bucket parameters (capacity: %d, initialTokenCount: %d, refillNumber: %d,"+ + "refillInterval: %d)", capacity, initialTokenCount, refillInterval, refillInterval) + log.Error(errorMsg) + return nil, errors.New(errorMsg) + } + return &Bucket{ + capacity: capacity, + tokenCount: initialTokenCount, + refillNumber: refillNumber, + refillInterval: refillInterval, + mutex: sync.Mutex{}, + }, nil +} + +type Bucket struct { + capacity int64 + tokenCount int64 + refillNumber int64 + refillInterval time.Duration + mutex sync.Mutex +} + +func (b *Bucket) produceTokens() { + b.mutex.Lock() + defer b.mutex.Unlock() + if b.tokenCount < b.capacity { + b.tokenCount = min64(b.tokenCount+b.refillNumber, b.capacity) + } +} + +func (b *Bucket) consumeTokens(n int64) bool { + b.mutex.Lock() + defer b.mutex.Unlock() + if n <= b.tokenCount { + b.tokenCount -= n + return true + } + return false +} + +func (b *Bucket) getTokenCount() int64 { + b.mutex.Lock() + defer b.mutex.Unlock() + return b.tokenCount +} + +func NewThrottler(bucket *Bucket) (*Throttler, error) { + if bucket == nil { + errorMsg := "cannot create a throttler with nil bucket" + log.Error(errorMsg) + return nil, errors.New(errorMsg) + } + return &Throttler{ + b: bucket, + running: false, + produced: make(chan int64), + done: make(chan struct{}), + // FIXME: + // The runtime tells whether the function response mode is streaming or not. + // Ideally, we would want to use that value here. Since I'm just rebasing, I will leave + // as-is, but we should use that instead of relying on our memory to set this here + // because we "know" it's a streaming code path. + metrics: &interop.InvokeResponseMetrics{FunctionResponseMode: interop.FunctionResponseModeStreaming}, + }, nil +} + +type Throttler struct { + b *Bucket + running bool + produced chan int64 + done chan struct{} + metrics *interop.InvokeResponseMetrics +} + +func (th *Throttler) start() { + if th.running { + return + } + th.running = true + th.metrics.StartReadingResponseMonoTimeMs = metering.Monotime() + go func() { + ticker := time.NewTicker(th.b.refillInterval) + for { + select { + case <-ticker.C: + th.b.produceTokens() + select { + case th.produced <- metering.Monotime(): + default: + } + case <-th.done: + ticker.Stop() + return + } + } + }() +} + +func (th *Throttler) stop() { + if !th.running { + return + } + th.running = false + th.metrics.FinishReadingResponseMonoTimeMs = metering.Monotime() + durationMs := (th.metrics.FinishReadingResponseMonoTimeMs - th.metrics.StartReadingResponseMonoTimeMs) / int64(time.Millisecond) + if durationMs > 0 { + th.metrics.OutboundThroughputBps = (th.metrics.ProducedBytes / durationMs) * int64(time.Second/time.Millisecond) + } else { + th.metrics.OutboundThroughputBps = -1 + } + th.done <- struct{}{} +} + +func (th *Throttler) bandwidthLimitingWrite(w io.Writer, p []byte) (written int, err error) { + n := int64(len(p)) + if n > th.b.capacity { + return 0, ErrBufferSizeTooLarge + } + for { + if th.b.consumeTokens(n) { + written, err = w.Write(p) + th.metrics.ProducedBytes += int64(written) + return + } + waitStart := metering.Monotime() + elapsed := <-th.produced - waitStart + if elapsed > 0 { + th.metrics.TimeShapedNs += elapsed + } + } +} diff --git a/lambda/core/bandwidthlimiter/throttler_test.go b/lambda/core/bandwidthlimiter/throttler_test.go new file mode 100644 index 0000000..a88a14d --- /dev/null +++ b/lambda/core/bandwidthlimiter/throttler_test.go @@ -0,0 +1,215 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "bytes" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewBucket(t *testing.T) { + type testCase struct { + capacity int64 + initialTokenCount int64 + refillNumber int64 + refillInterval time.Duration + bucketCreated bool + } + testCases := []testCase{ + {capacity: 8, initialTokenCount: 6, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: true}, + {capacity: 8, initialTokenCount: 6, refillNumber: 2, refillInterval: -100 * time.Millisecond, bucketCreated: false}, + {capacity: 8, initialTokenCount: 6, refillNumber: -5, refillInterval: 100 * time.Millisecond, bucketCreated: false}, + {capacity: 8, initialTokenCount: -2, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: false}, + {capacity: -2, initialTokenCount: 6, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: false}, + {capacity: 8, initialTokenCount: 10, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: false}, + } + + for _, test := range testCases { + bucket, err := NewBucket(test.capacity, test.initialTokenCount, test.refillNumber, test.refillInterval) + if test.bucketCreated { + assert.NoError(t, err) + assert.NotNil(t, bucket) + } else { + assert.Error(t, err) + assert.Nil(t, bucket) + } + } +} + +func TestBucket_produceTokens_consumeTokens(t *testing.T) { + var consumed bool + bucket, err := NewBucket(16, 8, 6, 100*time.Millisecond) + assert.NoError(t, err) + assert.Equal(t, int64(8), bucket.getTokenCount()) + + consumed = bucket.consumeTokens(5) + assert.Equal(t, int64(3), bucket.getTokenCount()) + assert.True(t, consumed) + + bucket.produceTokens() + assert.Equal(t, int64(9), bucket.getTokenCount()) + + bucket.produceTokens() + assert.Equal(t, int64(15), bucket.getTokenCount()) + + bucket.produceTokens() + assert.Equal(t, int64(16), bucket.getTokenCount()) + + bucket.produceTokens() + assert.Equal(t, int64(16), bucket.getTokenCount()) + + consumed = bucket.consumeTokens(18) + assert.Equal(t, int64(16), bucket.getTokenCount()) + assert.False(t, consumed) + + consumed = bucket.consumeTokens(16) + assert.Equal(t, int64(0), bucket.getTokenCount()) + assert.True(t, consumed) +} + +func TestNewThrottler(t *testing.T) { + bucket, err := NewBucket(16, 8, 6, 100*time.Millisecond) + assert.NoError(t, err) + + throttler, err := NewThrottler(bucket) + assert.NoError(t, err) + assert.NotNil(t, throttler) + + throttler, err = NewThrottler(nil) + assert.Error(t, err) + assert.Nil(t, throttler) +} + +func TestNewThrottler_start_stop(t *testing.T) { + bucket, err := NewBucket(16, 8, 6, 100*time.Millisecond) + assert.NoError(t, err) + + throttler, err := NewThrottler(bucket) + assert.NoError(t, err) + + assert.False(t, throttler.running) + + throttler.start() + assert.True(t, throttler.running) + + <-time.Tick(2 * throttler.b.refillInterval) + assert.LessOrEqual(t, int64(14), throttler.b.getTokenCount()) + assert.True(t, throttler.running) + + throttler.start() + assert.True(t, throttler.running) + <-time.Tick(2 * throttler.b.refillInterval) + assert.Equal(t, int64(16), throttler.b.getTokenCount()) + assert.True(t, throttler.running) + + throttler.stop() + assert.False(t, throttler.running) + + throttler.stop() + assert.False(t, throttler.running) + + throttler.start() + assert.True(t, throttler.running) + + throttler.stop() + assert.False(t, throttler.running) +} + +type ByteBufferWriter interface { + Write(p []byte) (n int, err error) + Bytes() []byte +} + +type FixedSizeBufferWriter struct { + buf []byte +} + +func (w *FixedSizeBufferWriter) Write(p []byte) (n int, err error) { + n = copy(w.buf, p) + return +} + +func (w *FixedSizeBufferWriter) Bytes() []byte { + return w.buf +} + +func TestNewThrottler_bandwidthLimitingWrite(t *testing.T) { + var size10mb int64 = 10 * 1024 * 1024 + + type testCase struct { + capacity int64 + initialTokenCount int64 + writer ByteBufferWriter + inputBuffer []byte + expectedN int + expectedError error + } + testCases := []testCase{ + { + capacity: 16, + initialTokenCount: 8, + writer: bytes.NewBuffer(make([]byte, 0, 14)), + inputBuffer: []byte(strings.Repeat("a", 12)), + expectedN: 12, + expectedError: nil, + }, + { + capacity: 16, + initialTokenCount: 8, + writer: bytes.NewBuffer(make([]byte, 0, 12)), + inputBuffer: []byte(strings.Repeat("a", 14)), + expectedN: 14, + expectedError: nil, + }, + { + capacity: size10mb, + initialTokenCount: size10mb, + writer: bytes.NewBuffer(make([]byte, 0, size10mb)), + inputBuffer: []byte(strings.Repeat("a", int(size10mb))), + expectedN: int(size10mb), + expectedError: nil, + }, + { + capacity: 16, + initialTokenCount: 8, + writer: bytes.NewBuffer(make([]byte, 0, 18)), + inputBuffer: []byte(strings.Repeat("a", 18)), + expectedN: 0, + expectedError: ErrBufferSizeTooLarge, + }, + { + capacity: 16, + initialTokenCount: 8, + writer: &FixedSizeBufferWriter{buf: make([]byte, 12)}, + inputBuffer: []byte(strings.Repeat("a", 14)), + expectedN: 12, + expectedError: nil, + }, + } + + for _, test := range testCases { + bucket, err := NewBucket(test.capacity, test.initialTokenCount, 2, 100*time.Millisecond) + assert.NoError(t, err) + + throttler, err := NewThrottler(bucket) + assert.NoError(t, err) + + writer := test.writer + throttler.start() + n, err := throttler.bandwidthLimitingWrite(writer, test.inputBuffer) + assert.Equal(t, test.expectedN, n) + assert.Equal(t, test.expectedError, err) + + if test.expectedError == nil { + assert.Equal(t, test.inputBuffer[:n], test.writer.Bytes()) + } else { + assert.Equal(t, []byte{}, test.writer.Bytes()) + } + throttler.stop() + } +} diff --git a/lambda/core/bandwidthlimiter/util.go b/lambda/core/bandwidthlimiter/util.go new file mode 100644 index 0000000..7078d5d --- /dev/null +++ b/lambda/core/bandwidthlimiter/util.go @@ -0,0 +1,46 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func min64(a, b int64) int64 { + if a < b { + return a + } + return b +} + +func NewChunkIterator(buf []byte, chunkSize int) *ChunkIterator { + if buf == nil { + return nil + } + return &ChunkIterator{ + buf: buf, + chunkSize: chunkSize, + offset: 0, + } +} + +type ChunkIterator struct { + buf []byte + chunkSize int + offset int +} + +func (i *ChunkIterator) Next() []byte { + begin := i.offset + end := min(i.offset+i.chunkSize, len(i.buf)) + i.offset = end + + if begin == end { + return nil + } + return i.buf[begin:end] +} diff --git a/lambda/core/bandwidthlimiter/util_test.go b/lambda/core/bandwidthlimiter/util_test.go new file mode 100644 index 0000000..ed93c77 --- /dev/null +++ b/lambda/core/bandwidthlimiter/util_test.go @@ -0,0 +1,45 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package bandwidthlimiter + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewChunkIterator(t *testing.T) { + buf := []byte("abcdefghijk") + + type testCase struct { + buf []byte + chunkSize int + expectedResult [][]byte + } + testCases := []testCase{ + {buf: nil, chunkSize: 0, expectedResult: [][]byte{}}, + {buf: nil, chunkSize: 1, expectedResult: [][]byte{}}, + {buf: buf, chunkSize: 0, expectedResult: [][]byte{}}, + {buf: buf, chunkSize: 1, expectedResult: [][]byte{ + []byte("a"), []byte("b"), []byte("c"), []byte("d"), []byte("e"), []byte("f"), []byte("g"), []byte("h"), + []byte("i"), []byte("j"), []byte("k"), + }}, + {buf: buf, chunkSize: 4, expectedResult: [][]byte{[]byte("abcd"), []byte("efgh"), []byte("ijk")}}, + {buf: buf, chunkSize: 5, expectedResult: [][]byte{[]byte("abcde"), []byte("fghij"), []byte("k")}}, + {buf: buf, chunkSize: 11, expectedResult: [][]byte{[]byte("abcdefghijk")}}, + {buf: buf, chunkSize: 12, expectedResult: [][]byte{[]byte("abcdefghijk")}}, + } + + for _, test := range testCases { + iterator := NewChunkIterator(test.buf, test.chunkSize) + if test.buf == nil { + assert.Nil(t, iterator) + } else { + for _, expectedChunk := range test.expectedResult { + assert.Equal(t, expectedChunk, iterator.Next()) + } + assert.Nil(t, iterator.Next()) + } + } +} diff --git a/lambda/core/credentials.go b/lambda/core/credentials.go index 7b1bf14..ad152d0 100644 --- a/lambda/core/credentials.go +++ b/lambda/core/credentials.go @@ -7,8 +7,6 @@ import ( "fmt" "sync" "time" - - log "github.com/sirupsen/logrus" ) const ( @@ -26,11 +24,9 @@ type Credentials struct { } type CredentialsService interface { - SetCredentials(token, awsKey, awsSecret, awsSession string) + SetCredentials(token, awsKey, awsSecret, awsSession string, expiration time.Time) GetCredentials(token string) (*Credentials, error) - UpdateCredentials(awsKey, awsSecret, awsSession string) error - BlockService() - UnblockService() + UpdateCredentials(awsKey, awsSecret, awsSession string, expiration time.Time) error } type credentialsServiceImpl struct { @@ -51,7 +47,7 @@ func NewCredentialsService() CredentialsService { return credentialsService } -func (c *credentialsServiceImpl) SetCredentials(token, awsKey, awsSecret, awsSession string) { +func (c *credentialsServiceImpl) SetCredentials(token, awsKey, awsSecret, awsSession string, expiration time.Time) { c.contentMutex.Lock() defer c.contentMutex.Unlock() @@ -59,7 +55,7 @@ func (c *credentialsServiceImpl) SetCredentials(token, awsKey, awsSecret, awsSes AwsKey: awsKey, AwsSecret: awsSecret, AwsSession: awsSession, - Expiration: time.Now().Add(16 * time.Minute), + Expiration: expiration, } } @@ -77,33 +73,7 @@ func (c *credentialsServiceImpl) GetCredentials(token string) (*Credentials, err return nil, ErrCredentialsNotFound } -func (c *credentialsServiceImpl) BlockService() { - if c.currentState == BLOCKED { - return - } - log.Info("blocking the credentials service") - c.serviceMutex.Lock() - - c.contentMutex.Lock() - defer c.contentMutex.Unlock() - - c.currentState = BLOCKED -} - -func (c *credentialsServiceImpl) UnblockService() { - if c.currentState == UNBLOCKED { - return - } - log.Info("unblocking the credentials service") - - c.contentMutex.Lock() - defer c.contentMutex.Unlock() - - c.currentState = UNBLOCKED - c.serviceMutex.Unlock() -} - -func (c *credentialsServiceImpl) UpdateCredentials(awsKey, awsSecret, awsSession string) error { +func (c *credentialsServiceImpl) UpdateCredentials(awsKey, awsSecret, awsSession string, expiration time.Time) error { mapSize := len(c.credentials) if mapSize != 1 { return fmt.Errorf("there are %d set of credentials", mapSize) @@ -114,6 +84,6 @@ func (c *credentialsServiceImpl) UpdateCredentials(awsKey, awsSecret, awsSession token = key } - c.SetCredentials(token, awsKey, awsSecret, awsSession) + c.SetCredentials(token, awsKey, awsSecret, awsSession, expiration) return nil } diff --git a/lambda/core/credentials_test.go b/lambda/core/credentials_test.go index ab0b247..625ab8e 100644 --- a/lambda/core/credentials_test.go +++ b/lambda/core/credentials_test.go @@ -19,7 +19,8 @@ const ( func TestGetSetCredentialsHappy(t *testing.T) { credentialsService := NewCredentialsService() - credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession) + credentialsExpiration := time.Now().Add(15 * time.Minute) + credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession, credentialsExpiration) credentials, err := credentialsService.GetCredentials(Token) @@ -40,8 +41,12 @@ func TestGetCredentialsFail(t *testing.T) { func TestUpdateCredentialsHappy(t *testing.T) { credentialsService := NewCredentialsService() - credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession) - err := credentialsService.UpdateCredentials("sampleKey1", "sampleSecret1", "sampleSession1") + credentialsExpiration := time.Now().Add(15 * time.Minute) + credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession, credentialsExpiration) + + restoreCredentialsExpiration := time.Now().Add(10 * time.Hour) + + err := credentialsService.UpdateCredentials("sampleKey1", "sampleSecret1", "sampleSession1", restoreCredentialsExpiration) assert.NoError(t, err) credentials, err := credentialsService.GetCredentials(Token) @@ -50,49 +55,16 @@ func TestUpdateCredentialsHappy(t *testing.T) { assert.Equal(t, "sampleKey1", credentials.AwsKey) assert.Equal(t, "sampleSecret1", credentials.AwsSecret) assert.Equal(t, "sampleSession1", credentials.AwsSession) -} - -func TestUpdateCredentialsFail(t *testing.T) { - credentialsService := NewCredentialsService() - err := credentialsService.UpdateCredentials("unknownKey", "unknownSecret", "unknownSession") - - assert.Error(t, err) -} + nineHoursLater := time.Now().Add(9 * time.Hour) -func TestUpdateCredentialsOfBlockedService(t *testing.T) { - credentialsService := NewCredentialsService() - credentialsService.BlockService() - credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession) - err := credentialsService.UpdateCredentials("sampleKey1", "sampleSecret1", "sampleSession1") - assert.NoError(t, err) + assert.True(t, nineHoursLater.Before(credentials.Expiration)) } -func TestConsecutiveBlockService(t *testing.T) { +func TestUpdateCredentialsFail(t *testing.T) { credentialsService := NewCredentialsService() - timeout := time.After(1 * time.Second) - done := make(chan bool) - - go func() { - for i := 0; i < 10; i++ { - credentialsService.BlockService() - } - done <- true - }() - - select { - case <-timeout: - t.Fatal("BlockService should not block the calling thread.") - case <-done: - } -} - -// unlocking a mutex twice causes panic -// the assertion here is basically not having panic -func TestConsecutiveUnblockService(t *testing.T) { - credentialsService := NewCredentialsService() + err := credentialsService.UpdateCredentials("unknownKey", "unknownSecret", "unknownSession", time.Now()) - credentialsService.UnblockService() - credentialsService.UnblockService() + assert.Error(t, err) } diff --git a/lambda/core/directinvoke/directinvoke.go b/lambda/core/directinvoke/directinvoke.go index 1699121..8ef59ae 100644 --- a/lambda/core/directinvoke/directinvoke.go +++ b/lambda/core/directinvoke/directinvoke.go @@ -4,28 +4,38 @@ package directinvoke import ( + "context" "fmt" "io" "net/http" + "strconv" "github.com/go-chi/chi" + "go.amzn.com/lambda/core/bandwidthlimiter" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" + + log "github.com/sirupsen/logrus" ) const ( - InvokeIDHeader = "Invoke-Id" - InvokedFunctionArnHeader = "Invoked-Function-Arn" - VersionIDHeader = "Invoked-Function-Version" - ReservationTokenHeader = "Reservation-Token" - CustomerHeadersHeader = "Customer-Headers" - ContentTypeHeader = "Content-Type" + InvokeIDHeader = "Invoke-Id" + InvokedFunctionArnHeader = "Invoked-Function-Arn" + VersionIDHeader = "Invoked-Function-Version" + ReservationTokenHeader = "Reservation-Token" + CustomerHeadersHeader = "Customer-Headers" + ContentTypeHeader = "Content-Type" + MaxPayloadSizeHeader = "MaxPayloadSize" + ResponseBandwidthRateHeader = "ResponseBandwidthRate" + ResponseBandwidthBurstSizeHeader = "ResponseBandwidthBurstSize" + FunctionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" ErrorTypeHeader = "Error-Type" - EndOfResponseTrailer = "End-Of-Response" - - SandboxErrorType = "Error.Sandbox" + EndOfResponseTrailer = "End-Of-Response" + FunctionErrorTypeTrailer = "Lambda-Runtime-Function-Error-Type" + FunctionErrorBodyTrailer = "Lambda-Runtime-Function-Error-Body" ) const ( @@ -34,7 +44,14 @@ const ( EndOfResponseOversized = "Oversized" ) +var ResetReasonMap = map[string]fatalerror.ErrorType{ + "failure": fatalerror.SandboxFailure, + "timeout": fatalerror.SandboxTimeout, +} + var MaxDirectResponseSize int64 = interop.MaxPayloadSize // this is intentionally not a constant so we can configure it via CLI +var ResponseBandwidthRate int64 = interop.ResponseBandwidthRate +var ResponseBandwidthBurstSize int64 = interop.ResponseBandwidthBurstSize func renderBadRequest(w http.ResponseWriter, r *http.Request, errorType string) { w.Header().Set(ErrorTypeHeader, errorType) @@ -42,6 +59,12 @@ func renderBadRequest(w http.ResponseWriter, r *http.Request, errorType string) w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) } +func renderInternalServerError(w http.ResponseWriter, errorType string) { + w.Header().Set(ErrorTypeHeader, errorType) + w.WriteHeader(http.StatusInternalServerError) + w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) +} + // ReceiveDirectInvoke parses invoke and verifies it against Token message. Uses deadline provided by Token // Renders BadRequest in case of error func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.Token) (*interop.Invoke, error) { @@ -54,6 +77,47 @@ func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.T } now := metering.Monotime() + + MaxDirectResponseSize = interop.MaxPayloadSize + if maxPayloadSize := r.Header.Get(MaxPayloadSizeHeader); maxPayloadSize != "" { + if n, err := strconv.ParseInt(maxPayloadSize, 10, 64); err == nil && n >= -1 { + MaxDirectResponseSize = n + } else { + log.Error("MaxPayloadSize header is not a valid number") + renderBadRequest(w, r, interop.ErrInvalidMaxPayloadSize.Error()) + return nil, interop.ErrInvalidMaxPayloadSize + } + } + + if MaxDirectResponseSize == -1 { + w.Header().Add("Trailer", FunctionErrorTypeTrailer) + w.Header().Add("Trailer", FunctionErrorBodyTrailer) + + ResponseBandwidthRate = interop.ResponseBandwidthRate + if responseBandwidthRate := r.Header.Get(ResponseBandwidthRateHeader); responseBandwidthRate != "" { + if n, err := strconv.ParseInt(responseBandwidthRate, 10, 64); err == nil && + interop.MinResponseBandwidthRate <= n && n <= interop.MaxResponseBandwidthRate { + ResponseBandwidthRate = n + } else { + log.Error("ResponseBandwidthRate header is not a valid number or is out of the allowed range") + renderBadRequest(w, r, interop.ErrInvalidResponseBandwidthRate.Error()) + return nil, interop.ErrInvalidResponseBandwidthRate + } + } + + ResponseBandwidthBurstSize = interop.ResponseBandwidthBurstSize + if responseBandwidthBurstSize := r.Header.Get(ResponseBandwidthBurstSizeHeader); responseBandwidthBurstSize != "" { + if n, err := strconv.ParseInt(responseBandwidthBurstSize, 10, 64); err == nil && + interop.MinResponseBandwidthBurstSize <= n && n <= interop.MaxResponseBandwidthBurstSize { + ResponseBandwidthBurstSize = n + } else { + log.Error("ResponseBandwidthBurstSize header is not a valid number or is out of the allowed range") + renderBadRequest(w, r, interop.ErrInvalidResponseBandwidthBurstSize.Error()) + return nil, interop.ErrInvalidResponseBandwidthBurstSize + } + } + } + inv := &interop.Invoke{ ID: r.Header.Get(InvokeIDHeader), ReservationToken: chi.URLParam(r, "reservationtoken"), @@ -66,7 +130,6 @@ func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.T LambdaSegmentID: token.LambdaSegmentID, ClientContext: custHeaders.ClientContext, Payload: r.Body, - CorrelationID: "invokeCorrelationID", DeadlineNs: fmt.Sprintf("%d", now+token.FunctionTimeout.Nanoseconds()), NeedDebugLogs: token.NeedDebugLogs, InvokeReceivedTime: now, @@ -99,24 +162,215 @@ func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.T return inv, nil } -func SendDirectInvokeResponse(additionalHeaders map[string]string, payload io.Reader, w http.ResponseWriter) error { - for k, v := range additionalHeaders { - w.Header().Add(k, v) +type CopyDoneResult struct { + Metrics *interop.InvokeResponseMetrics + Error error +} + +func getErrorTypeFromResetReason(resetReason string) fatalerror.ErrorType { + errorTypeTrailer, ok := ResetReasonMap[resetReason] + if !ok { + errorTypeTrailer = fatalerror.Unknown + } + return errorTypeTrailer +} + +func isErrorResponse(additionalHeaders map[string]string) (isErrorResponse bool) { + _, isErrorResponse = additionalHeaders[ErrorTypeHeader] + return +} + +func isStreamingInvoke() bool { + return MaxDirectResponseSize == -1 +} + +func asyncPayloadCopy(w http.ResponseWriter, payload io.Reader) (copyDone chan CopyDoneResult, cancel context.CancelFunc, err error) { + copyDone = make(chan CopyDoneResult) + streamedResponseWriter, cancel, err := NewStreamedResponseWriter(w) + if err != nil { + return nil, nil, &interop.ErrInternalPlatformError{} + } + go func() { // copy payload in a separate go routine + _, copyError := bandwidthlimiter.BandwidthLimitingCopy(streamedResponseWriter, payload) + if copyError != nil { + w.Header().Set(EndOfResponseTrailer, EndOfResponseTruncated) + } else { + w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) + } + copyDoneResult := CopyDoneResult{ + Metrics: streamedResponseWriter.GetMetrics(), + Error: copyError, + } + copyDone <- copyDoneResult + cancel() // free resources + }() + return +} + +func sendStreamingInvokeResponse(payload io.Reader, trailers http.Header, w http.ResponseWriter, + interruptedResponseChan chan *interop.Reset, sendResponseChan chan *interop.InvokeResponseMetrics, + request *interop.CancellableRequest, runtimeCalledResponse bool) (err error) { + /* In case of /response, we copy the payload and, once copied, we attach: + * 1) 'Lambda-Runtime-Function-Error-Type' + * 2) 'Lambda-Runtime-Function-Error-Body' + * trailers. */ + copyDone, cancel, err := asyncPayloadCopy(w, payload) + if err != nil { + renderInternalServerError(w, err.Error()) + return err + } + + var errorTypeTrailer string + var errorBodyTrailer string + var copyDoneResult CopyDoneResult + select { + case copyDoneResult = <-copyDone: // copy finished + errorTypeTrailer = trailers.Get(FunctionErrorTypeTrailer) + errorBodyTrailer = trailers.Get(FunctionErrorBodyTrailer) + if copyDoneResult.Error != nil && errorTypeTrailer == "" { // truncated payload, error type not known + errorTypeTrailer = string(fatalerror.TruncatedResponse) + } + case reset := <-interruptedResponseChan: // reset initiated + cancel() + if request != nil { + // In case of reset: + // * to interrupt copying when runtime called /response (a potential stuck on Body.Read() operation), + // we close the underlying connection using .Close() method on the request object + // * for /error case, the whole body is already read in /error handler, so we don't need additional handling + // when sending streaming invoke error response + connErr := request.Cancel() + if connErr != nil { + log.Warnf("Failed to close underlying connection: %s", connErr) + } + } else { + log.Warn("Cannot close underlying connection. Request object is nil") + } + copyDoneResult = <-copyDone + reset.InvokeResponseMetrics = copyDoneResult.Metrics + interruptedResponseChan <- nil + errorTypeTrailer = string(getErrorTypeFromResetReason(reset.Reason)) + } + w.Header().Set(FunctionErrorTypeTrailer, errorTypeTrailer) + w.Header().Set(FunctionErrorBodyTrailer, errorBodyTrailer) + + copyDoneResult.Metrics.RuntimeCalledResponse = runtimeCalledResponse + sendResponseChan <- copyDoneResult.Metrics + + if copyDoneResult.Error != nil { + log.Errorf("Error while streaming response payload: %s", copyDoneResult.Error) + err = &interop.ErrTruncatedResponse{} + } + return +} + +func sendStreamingInvokeErrorResponse(payload io.Reader, w http.ResponseWriter, + interruptedResponseChan chan *interop.Reset, sendResponseChan chan *interop.InvokeResponseMetrics, runtimeCalledResponse bool) (err error) { + + copyDone, cancel, err := asyncPayloadCopy(w, payload) + if err != nil { + renderInternalServerError(w, err.Error()) + return err + } + + var copyDoneResult CopyDoneResult + select { + case copyDoneResult = <-copyDone: // copy finished + case reset := <-interruptedResponseChan: // reset initiated + cancel() + copyDoneResult = <-copyDone + reset.InvokeResponseMetrics = copyDoneResult.Metrics + interruptedResponseChan <- nil + } + + copyDoneResult.Metrics.RuntimeCalledResponse = runtimeCalledResponse + sendResponseChan <- copyDoneResult.Metrics + + if copyDoneResult.Error != nil { + log.Errorf("Error while streaming error response payload: %s", copyDoneResult.Error) + err = &interop.ErrTruncatedResponse{} + } + return +} + +// parseFunctionResponseMode fetches the mode from the header +// If the header is absent, it returns buffered mode. +func parseFunctionResponseMode(w http.ResponseWriter) (interop.FunctionResponseMode, error) { + headerValue := w.Header().Get(FunctionResponseModeHeader) + // the header is optional, so it's ok to be absent + if headerValue == "" { + return interop.FunctionResponseModeBuffered, nil + } + + return interop.ConvertToFunctionResponseMode(headerValue) +} + +func sendPayloadLimitedResponse(payload io.Reader, trailers http.Header, w http.ResponseWriter, sendResponseChan chan *interop.InvokeResponseMetrics, runtimeCalledResponse bool) (err error) { + functionResponseMode, err := parseFunctionResponseMode(w) + if err != nil { + return err + } + + // non-streaming invoke request but runtime is streaming: predefine Trailer headers + if functionResponseMode == interop.FunctionResponseModeStreaming { + w.Header().Add("Trailer", FunctionErrorTypeTrailer) + w.Header().Add("Trailer", FunctionErrorBodyTrailer) + } + + startReadingResponseMonoTimeMs := metering.Monotime() + written, err := io.Copy(w, io.LimitReader(payload, MaxDirectResponseSize+1)) // +1 because we do allow 10MB but not 10MB + 1 byte + + // non-streaming invoke request but runtime is streaming: set response trailers + if functionResponseMode == interop.FunctionResponseModeStreaming { + w.Header().Set(FunctionErrorTypeTrailer, trailers.Get(FunctionErrorTypeTrailer)) + w.Header().Set(FunctionErrorBodyTrailer, trailers.Get(FunctionErrorBodyTrailer)) } - n, err := io.Copy(w, io.LimitReader(payload, MaxDirectResponseSize+1)) // +1 because we do allow 10MB but not 10MB + 1 byte if err != nil { w.Header().Set(EndOfResponseTrailer, EndOfResponseTruncated) - } else if n == MaxDirectResponseSize+1 { + err = &interop.ErrTruncatedResponse{} + } else if MaxDirectResponseSize != -1 && written == MaxDirectResponseSize+1 { w.Header().Set(EndOfResponseTrailer, EndOfResponseOversized) err = &interop.ErrorResponseTooLargeDI{ ErrorResponseTooLarge: interop.ErrorResponseTooLarge{ - ResponseSize: int(n), + ResponseSize: int(written), MaxResponseSize: int(MaxDirectResponseSize), }, } } else { w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) } - return err + + sendResponseChan <- &interop.InvokeResponseMetrics{ + ProducedBytes: int64(written), + StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, + FinishReadingResponseMonoTimeMs: metering.Monotime(), + TimeShapedNs: int64(-1), + OutboundThroughputBps: int64(-1), + // FIXME: + // We should use InvokeResponseMode here, because only when it's streaming we're interested + // on it. If the invoke is buffered, we don't generate streaming metrics, even if the + // function response mode is streaming. + FunctionResponseMode: interop.FunctionResponseModeBuffered, + RuntimeCalledResponse: runtimeCalledResponse, + } + return +} + +func SendDirectInvokeResponse(additionalHeaders map[string]string, payload io.Reader, trailers http.Header, + w http.ResponseWriter, interruptedResponseChan chan *interop.Reset, + sendResponseChan chan *interop.InvokeResponseMetrics, request *interop.CancellableRequest, runtimeCalledResponse bool) error { + + for k, v := range additionalHeaders { + w.Header().Add(k, v) + } + + if isStreamingInvoke() { // unlimited payload; response streaming mode + if isErrorResponse(additionalHeaders) { // send streamed error response when runtime called /error + return sendStreamingInvokeErrorResponse(payload, w, interruptedResponseChan, sendResponseChan, runtimeCalledResponse) + } + // send streamed response when runtime called /response + return sendStreamingInvokeResponse(payload, trailers, w, interruptedResponseChan, sendResponseChan, request, runtimeCalledResponse) + } + + return sendPayloadLimitedResponse(payload, trailers, w, sendResponseChan, runtimeCalledResponse) } diff --git a/lambda/core/directinvoke/directinvoke_test.go b/lambda/core/directinvoke/directinvoke_test.go new file mode 100644 index 0000000..4e26161 --- /dev/null +++ b/lambda/core/directinvoke/directinvoke_test.go @@ -0,0 +1,358 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package directinvoke + +import ( + "bytes" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.amzn.com/lambda/interop" +) + +func NewResponseWriterWithoutFlushMethod() *ResponseWriterWithoutFlushMethod { + return &ResponseWriterWithoutFlushMethod{} +} + +type ResponseWriterWithoutFlushMethod struct{} + +func (*ResponseWriterWithoutFlushMethod) Header() http.Header { return http.Header{} } +func (*ResponseWriterWithoutFlushMethod) Write([]byte) (n int, err error) { return } +func (*ResponseWriterWithoutFlushMethod) WriteHeader(_ int) {} + +func NewSimpleResponseWriter() *SimpleResponseWriter { + return &SimpleResponseWriter{ + buffer: bytes.NewBuffer(nil), + trailers: make(http.Header), + } +} + +type SimpleResponseWriter struct { + buffer *bytes.Buffer + trailers http.Header +} + +func (w *SimpleResponseWriter) Header() http.Header { return w.trailers } +func (w *SimpleResponseWriter) Write(p []byte) (n int, err error) { return w.buffer.Write(p) } +func (*SimpleResponseWriter) WriteHeader(_ int) {} +func (*SimpleResponseWriter) Flush() {} + +func NewInterruptableResponseWriter(interruptAfter int) (*InterruptableResponseWriter, chan struct{}) { + interruptedTestWriterChan := make(chan struct{}) + return &InterruptableResponseWriter{ + buffer: bytes.NewBuffer(nil), + trailers: make(http.Header), + interruptAfter: interruptAfter, + interruptedTestWriterChan: interruptedTestWriterChan, + }, interruptedTestWriterChan +} + +type InterruptableResponseWriter struct { + buffer *bytes.Buffer + trailers http.Header + interruptAfter int // expect Writer to be interrupted after 'interruptAfter' number of writes + interruptedTestWriterChan chan struct{} +} + +func (w *InterruptableResponseWriter) Header() http.Header { return w.trailers } +func (w *InterruptableResponseWriter) Write(p []byte) (n int, err error) { + if w.interruptAfter >= 1 { + w.interruptAfter-- + } else if w.interruptAfter == 0 { + w.interruptedTestWriterChan <- struct{}{} // ready to be interrupted + <-w.interruptedTestWriterChan // wait until interrupted + } + n, err = w.buffer.Write(p) + return +} +func (*InterruptableResponseWriter) WriteHeader(_ int) {} +func (*InterruptableResponseWriter) Flush() {} + +// This is a simple reader implementing io.Reader interface. It's based on strings.Reader, but it doesn't have extra +// methods that allow faster copying such as .WriteTo() method. +func NewReader(s string) *Reader { return &Reader{s, 0, -1} } + +type Reader struct { + s string + i int64 // current reading index + prevRune int // index of previous rune; or < 0 +} + +func (r *Reader) Read(b []byte) (n int, err error) { + if r.i >= int64(len(r.s)) { + return 0, io.EOF + } + r.prevRune = -1 + n = copy(b, r.s[r.i:]) + r.i += int64(n) + return +} + +func TestSendDirectInvokeWithIncompatibleResponseWriter(t *testing.T) { + MaxDirectResponseSize = -1 + err := SendDirectInvokeResponse(nil, nil, nil, NewResponseWriterWithoutFlushMethod(), nil, nil, nil, false) + require.Error(t, err) + require.Equal(t, "ErrInternalPlatformError", err.Error()) +} + +func TestAsyncPayloadCopySuccess(t *testing.T) { + payloadString := strings.Repeat("a", 10*1024*1024) + writer := NewSimpleResponseWriter() + + expectedPayloadString := payloadString + + copyDone, _, err := asyncPayloadCopy(writer, NewReader(payloadString)) + require.Nil(t, err) + + <-copyDone + require.Equal(t, expectedPayloadString, writer.buffer.String()) +} + +// We use an interruptable response writer which informs on a channel that it's ready to be interrupted after +// 'interruptAfter' number of writes, then it waits for interruption completion to resume the current write operation. +// For this test, after initiating copying, we wait for one chunk of 32 KiB to be copied. Then, we use cancel() to +// interrupt copying. At this point, only ongoing .Write() operations can be performed. We inform the writer about +// interruption completion, and the writer resumes the current .Write() operation, which gives us another 32 KiB chunk +// that is copied. After that, copying returns, and we receive a signal on <-copyDone channel. +func TestAsyncPayloadCopySuccessAfterCancel(t *testing.T) { + payloadString := strings.Repeat("a", 10*1024*1024) // 10 MiB + writer, interruptedTestWriterChan := NewInterruptableResponseWriter(1) + + expectedPayloadString := strings.Repeat("a", 64*1024) // 64 KiB (2 chunks) + + copyDone, cancel, err := asyncPayloadCopy(writer, NewReader(payloadString)) + require.Nil(t, err) + + <-interruptedTestWriterChan // wait for writing 'interruptAfter' number of chunks + cancel() // interrupt copying + interruptedTestWriterChan <- struct{}{} // inform test writer about interruption + + <-copyDone + require.Equal(t, expectedPayloadString, writer.buffer.String()) +} + +func TestAsyncPayloadCopyWithIncompatibleResponseWriter(t *testing.T) { + copyDone, cancel, err := asyncPayloadCopy(&ResponseWriterWithoutFlushMethod{}, nil) + require.Nil(t, copyDone) + require.Nil(t, cancel) + require.Error(t, err) + require.Equal(t, "ErrInternalPlatformError", err.Error()) +} + +func TestSendStreamingInvokeResponseSuccess(t *testing.T) { + payloadString := strings.Repeat("a", 128*1024) // 128 KiB + payload := NewReader(payloadString) + trailers := http.Header{} + writer := NewSimpleResponseWriter() + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + expectedPayloadString := payloadString + + go func() { + err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, false) + require.Nil(t, err) + testFinished <- struct{}{} + }() + + <-sendResponseChan + require.Equal(t, expectedPayloadString, writer.buffer.String()) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Complete", writer.Header().Get("End-Of-Response")) + <-testFinished +} + +func TestSendPayloadLimitedResponseWithinThresholdWithStreamingFunction(t *testing.T) { + payloadSize := 1 + payloadString := strings.Repeat("a", payloadSize) + payload := NewReader(payloadString) + trailers := http.Header{} + writer := NewSimpleResponseWriter() + writer.Header().Set("Lambda-Runtime-Function-Response-Mode", "streaming") + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + MaxDirectResponseSize = int64(payloadSize + 1) + + go func() { + err := sendPayloadLimitedResponse(payload, trailers, writer, sendResponseChan, true) + require.Nil(t, err) + testFinished <- struct{}{} + }() + + metrics := <-sendResponseChan + require.Equal(t, interop.FunctionResponseModeBuffered, metrics.FunctionResponseMode) + require.Equal(t, len(payloadString), len(writer.buffer.String())) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Complete", writer.Header().Get("End-Of-Response")) + <-testFinished + + // Reset to its default value, just in case other tests use them + MaxDirectResponseSize = interop.MaxPayloadSize +} + +func TestSendPayloadLimitedResponseAboveThresholdWithStreamingFunction(t *testing.T) { + payloadSize := 2 + payloadString := strings.Repeat("a", payloadSize) + payload := NewReader(payloadString) + trailers := http.Header{} + writer := NewSimpleResponseWriter() + writer.Header().Set("Lambda-Runtime-Function-Response-Mode", "streaming") + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + MaxDirectResponseSize = int64(payloadSize - 1) + expectedError := &interop.ErrorResponseTooLargeDI{ + ErrorResponseTooLarge: interop.ErrorResponseTooLarge{ + MaxResponseSize: int(MaxDirectResponseSize), + ResponseSize: payloadSize, + }, + } + + go func() { + err := sendPayloadLimitedResponse(payload, trailers, writer, sendResponseChan, true) + require.Equal(t, expectedError, err) + testFinished <- struct{}{} + }() + + metrics := <-sendResponseChan + require.Equal(t, interop.FunctionResponseModeBuffered, metrics.FunctionResponseMode) + require.Equal(t, len(payloadString), len(writer.buffer.String())) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Oversized", writer.Header().Get("End-Of-Response")) + <-testFinished + + // Reset to its default value, just in case other tests use them + MaxDirectResponseSize = interop.MaxPayloadSize +} + +func TestSendStreamingInvokeResponseSuccessWithTrailers(t *testing.T) { + payloadString := strings.Repeat("a", 128*1024) // 128 KiB + payload := NewReader(payloadString) + trailers := http.Header{ + "Lambda-Runtime-Function-Error-Type": []string{"ErrorType"}, + "Lambda-Runtime-Function-Error-Body": []string{"ErrorBody"}, + } + writer := NewSimpleResponseWriter() + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + expectedPayloadString := payloadString + + go func() { + err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, false) + require.Nil(t, err) + testFinished <- struct{}{} + }() + + <-sendResponseChan + require.Equal(t, expectedPayloadString, writer.buffer.String()) + require.Equal(t, "ErrorType", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "ErrorBody", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Complete", writer.Header().Get("End-Of-Response")) + <-testFinished +} + +func TestSendStreamingInvokeResponseReset(t *testing.T) { // Reset initiated after writing two chunks of 32 KiB + payloadString := strings.Repeat("a", 128*1024) // 128 KiB + payload := NewReader(payloadString) + trailers := http.Header{} + writer, interruptedTestWriterChan := NewInterruptableResponseWriter(1) + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + expectedPayloadString := strings.Repeat("a", 64*1024) // 64 KiB + + go func() { + err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, true) + require.Error(t, err) + require.Equal(t, "ErrTruncatedResponse", err.Error()) + testFinished <- struct{}{} + }() + + reset := &interop.Reset{Reason: "timeout"} + require.Nil(t, reset.InvokeResponseMetrics) + + <-interruptedTestWriterChan // wait for writing 'interruptAfter' number of chunks + interruptedResponseChan <- reset // send reset + time.Sleep(10 * time.Millisecond) // wait for cancel() being called (first instruction after getting reset) + interruptedTestWriterChan <- struct{}{} // inform test writer about interruption + <-interruptedResponseChan // wait for copy done after interruption + require.NotNil(t, reset.InvokeResponseMetrics) + + <-sendResponseChan + require.Equal(t, expectedPayloadString, writer.buffer.String()) + require.Equal(t, "Sandbox.Timeout", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Truncated", writer.Header().Get("End-Of-Response")) + <-testFinished +} + +func TestSendStreamingInvokeErrorResponseSuccess(t *testing.T) { + payloadString := strings.Repeat("a", 128*1024) // 128 KiB + payload := NewReader(payloadString) + writer := NewSimpleResponseWriter() + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + expectedPayloadString := payloadString + + go func() { + err := sendStreamingInvokeErrorResponse(payload, writer, interruptedResponseChan, sendResponseChan, false) + require.Nil(t, err) + testFinished <- struct{}{} + }() + + <-sendResponseChan + require.Equal(t, expectedPayloadString, writer.buffer.String()) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Complete", writer.Header().Get("End-Of-Response")) + <-testFinished +} + +func TestSendStreamingInvokeErrorResponseReset(t *testing.T) { // Reset initiated after writing two chunks of 32 KiB + payloadString := strings.Repeat("a", 128*1024) // 128 KiB + payload := NewReader(payloadString) + writer, interruptedTestWriterChan := NewInterruptableResponseWriter(1) + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + expectedPayloadString := strings.Repeat("a", 64*1024) // 64 KiB + + go func() { + err := sendStreamingInvokeErrorResponse(payload, writer, interruptedResponseChan, sendResponseChan, true) + require.Error(t, err) + require.Equal(t, "ErrTruncatedResponse", err.Error()) + testFinished <- struct{}{} + }() + + reset := &interop.Reset{Reason: "timeout"} + require.Nil(t, reset.InvokeResponseMetrics) + + <-interruptedTestWriterChan // wait for writing 'interruptAfter' number of chunks + interruptedResponseChan <- reset // send reset + time.Sleep(10 * time.Millisecond) // wait for cancel() being called (first instruction after getting reset) + interruptedTestWriterChan <- struct{}{} // inform test writer about interruption + <-interruptedResponseChan // wait for copy done after interruption + require.NotNil(t, reset.InvokeResponseMetrics) + + <-sendResponseChan + require.Equal(t, expectedPayloadString, writer.buffer.String()) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) + require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) + require.Equal(t, "Truncated", writer.Header().Get("End-Of-Response")) + <-testFinished +} diff --git a/lambda/core/directinvoke/util.go b/lambda/core/directinvoke/util.go new file mode 100644 index 0000000..511d656 --- /dev/null +++ b/lambda/core/directinvoke/util.go @@ -0,0 +1,84 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package directinvoke + +import ( + "context" + "errors" + "go.amzn.com/lambda/core/bandwidthlimiter" + "io" + "net/http" + "time" + + log "github.com/sirupsen/logrus" +) + +const DefaultRefillIntervalMs = 125 // default refill interval in milliseconds + +func NewStreamedResponseWriter(w http.ResponseWriter) (*bandwidthlimiter.BandwidthLimitingWriter, context.CancelFunc, error) { + flushingWriter, err := NewFlushingWriter(w) // after writing a chunk we have to flush it to avoid additional buffering by ResponseWriter + if err != nil { + return nil, nil, err + } + cancellableWriter, cancel := NewCancellableWriter(flushingWriter) // cancelling prevents next calls to Write() from happening + + refillNumber := ResponseBandwidthRate * DefaultRefillIntervalMs / 1000 // refillNumber is calculated based on 'ResponseBandwidthRate' and bucket refill interval + refillInterval := DefaultRefillIntervalMs * time.Millisecond + + // Initial bucket for token bucket algorithm allows for a burst of up to 6 MiB, and an average transmission rate of 2 MiB/s + bucket, err := bandwidthlimiter.NewBucket(ResponseBandwidthBurstSize, ResponseBandwidthBurstSize, refillNumber, refillInterval) + if err != nil { + cancel() // free resources + return nil, nil, err + } + + bandwidthLimitingWriter, err := bandwidthlimiter.NewBandwidthLimitingWriter(cancellableWriter, bucket) + if err != nil { + cancel() // free resources + return nil, nil, err + } + + return bandwidthLimitingWriter, cancel, nil +} + +func NewFlushingWriter(w io.Writer) (*FlushingWriter, error) { + flusher, ok := w.(http.Flusher) + if !ok { + errorMsg := "expected http.ResponseWriter to be an http.Flusher" + log.Error(errorMsg) + return nil, errors.New(errorMsg) + } + return &FlushingWriter{ + w: w, + flusher: flusher, + }, nil +} + +type FlushingWriter struct { + w io.Writer + flusher http.Flusher +} + +func (w *FlushingWriter) Write(p []byte) (n int, err error) { + n, err = w.w.Write(p) + w.flusher.Flush() + return +} + +func NewCancellableWriter(w io.Writer) (*CancellableWriter, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + return &CancellableWriter{w: w, ctx: ctx}, cancel +} + +type CancellableWriter struct { + w io.Writer + ctx context.Context +} + +func (w *CancellableWriter) Write(p []byte) (int, error) { + if err := w.ctx.Err(); err != nil { + return 0, err + } + return w.w.Write(p) +} diff --git a/lambda/core/doc.go b/lambda/core/doc.go index 23c1539..4a7157f 100644 --- a/lambda/core/doc.go +++ b/lambda/core/doc.go @@ -2,26 +2,23 @@ // SPDX-License-Identifier: Apache-2.0 /* - Package core provides state objects and synchronization primitives for managing data flow in the system. - -States +# States Runtime and Agent implement state object design pattern. Runtime state interface: -type RuntimeState interface { - InitError() error - Ready() error - InvocationResponse() error - InvocationErrorResponse() error -} + type RuntimeState interface { + InitError() error + Ready() error + InvocationResponse() error + InvocationErrorResponse() error + } - -Gates +# Gates Gates provide synchornization primitives for managing data flow in the system. @@ -31,8 +28,9 @@ set of operations being performed in other threads completes. To better understand gates, consider two examples below: Example 1: main thread is awaiting registered threads to walk through the gate, - and after the last registered thread walked through the gate, gate - condition will be satisfied and main thread will proceed: + + and after the last registered thread walked through the gate, gate + condition will be satisfied and main thread will proceed: [main] // register threads with the gate and start threads ... [main] g.AwaitGateCondition() @@ -42,27 +40,25 @@ Example 1: main thread is awaiting registered threads to walk through the gate, [thread] // not blocked Example 2: main thread is awaiting registered threads to arrive at the gate, - and after the last registered thread arrives at the gate, gate - condition will be satisfied and main thread, along with registered - threads will proceed: + + and after the last registered thread arrives at the gate, gate + condition will be satisfied and main thread, along with registered + threads will proceed: [main] // register threads with the gate and start threads ... [main] g.AwaitGateCondition() [main] // blocked until gate condition is satisfied - -Flow +# Flow Flow wraps a set of specific gates required to implement specific data flow in the system. Example flows would be INIT, INVOKE and RESET. - -Registrations +# Registrations Registration service manages registrations, it maintains the mapping between registered parties are events they are registered. Parties not registered in the system will not be issued events. - */ package core diff --git a/lambda/core/externalagent.go b/lambda/core/externalagent.go index 792f356..cd367d2 100644 --- a/lambda/core/externalagent.go +++ b/lambda/core/externalagent.go @@ -22,7 +22,6 @@ type ExternalAgent struct { currentState ExternalAgentState stateLastModified time.Time - Pid int StartedState ExternalAgentState RegisteredState ExternalAgentState diff --git a/lambda/core/flow.go b/lambda/core/flow.go index 3c22b84..b2cb538 100644 --- a/lambda/core/flow.go +++ b/lambda/core/flow.go @@ -19,6 +19,9 @@ type InitFlowSynchronization interface { CancelWithError(error) + RuntimeRestoreReady() error + AwaitRuntimeRestoreReady() error + Clear() } @@ -26,6 +29,7 @@ type initFlowSynchronizationImpl struct { externalAgentsRegisteredGate Gate runtimeReadyGate Gate agentReadyGate Gate + runtimeRestoreReadyGate Gate } // SetExternalAgentsRegisterCount notifies init flow that N /extension/register calls should be done in future by external agents @@ -43,6 +47,11 @@ func (s *initFlowSynchronizationImpl) AwaitRuntimeReady() error { return s.runtimeReadyGate.AwaitGateCondition() } +// AwaitRuntimeRestoreReady awaits runtime restore ready state (/restore/next is called by runtime) +func (s *initFlowSynchronizationImpl) AwaitRuntimeRestoreReady() error { + return s.runtimeRestoreReadyGate.AwaitGateCondition() +} + // AwaitExternalAgentsRegistered awaits for all subscribed agents to report registered func (s *initFlowSynchronizationImpl) AwaitExternalAgentsRegistered() error { return s.externalAgentsRegisteredGate.AwaitGateCondition() @@ -58,6 +67,11 @@ func (s *initFlowSynchronizationImpl) RuntimeReady() error { return s.runtimeReadyGate.WalkThrough() } +// Ready called by runtime when restore is completed (i.e. /next is called after /restore/next) +func (s *initFlowSynchronizationImpl) RuntimeRestoreReady() error { + return s.runtimeRestoreReadyGate.WalkThrough() +} + // Ready called by agent when initialized func (s *initFlowSynchronizationImpl) AgentReady() error { return s.agentReadyGate.WalkThrough() @@ -73,6 +87,7 @@ func (s *initFlowSynchronizationImpl) CancelWithError(err error) { s.externalAgentsRegisteredGate.CancelWithError(err) s.runtimeReadyGate.CancelWithError(err) s.agentReadyGate.CancelWithError(err) + s.runtimeRestoreReadyGate.CancelWithError(err) } // Clear gates state @@ -80,6 +95,7 @@ func (s *initFlowSynchronizationImpl) Clear() { s.externalAgentsRegisteredGate.Clear() s.runtimeReadyGate.Clear() s.agentReadyGate.Clear() + s.runtimeRestoreReadyGate.Clear() } // NewInitFlowSynchronization returns new InitFlowSynchronization instance. @@ -88,6 +104,7 @@ func NewInitFlowSynchronization() InitFlowSynchronization { runtimeReadyGate: NewGate(1), externalAgentsRegisteredGate: NewGate(0), agentReadyGate: NewGate(maxAgentsLimit), + runtimeRestoreReadyGate: NewGate(1), } return initFlow } diff --git a/lambda/core/registrations.go b/lambda/core/registrations.go index dca9d90..f68612c 100644 --- a/lambda/core/registrations.go +++ b/lambda/core/registrations.go @@ -10,8 +10,11 @@ import ( "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/core/statejson" + "go.amzn.com/lambda/interop" "github.com/google/uuid" + + log "github.com/sirupsen/logrus" ) type registrationServiceState int @@ -70,6 +73,7 @@ type FunctionMetadata struct { FunctionName string FunctionVersion string Handler string + RuntimeInfo interop.RuntimeInfo } // RegistrationService keeps track of registered parties, including external agents, threads, and runtime. @@ -94,6 +98,7 @@ type RegistrationService interface { CountAgents() int Clear() AgentsInfo() []AgentInfo + CancelFlows(err error) } type registrationServiceImpl struct { @@ -105,6 +110,7 @@ type registrationServiceImpl struct { initFlow InitFlowSynchronization invokeFlow InvokeFlowSynchronization functionMetadata FunctionMetadata + cancelOnce sync.Once } func (s *registrationServiceImpl) Clear() { @@ -115,6 +121,7 @@ func (s *registrationServiceImpl) Clear() { s.internalAgents.Clear() s.externalAgents.Clear() s.state = registrationServiceOn + s.cancelOnce = sync.Once{} } func (s *registrationServiceImpl) InitFlow() InitFlowSynchronization { @@ -373,6 +380,19 @@ func (s *registrationServiceImpl) TurnOff() { s.state = registrationServiceOff } +// CancelFlows cancels init and invoke flows with error. +func (s *registrationServiceImpl) CancelFlows(err error) { + s.mutex.Lock() + defer s.mutex.Unlock() + // The following block protects us from overwriting the error + // which was first used to cancel flows. + s.cancelOnce.Do(func() { + log.Debugf("Canceling flows: %s", err) + s.initFlow.CancelWithError(err) + s.invokeFlow.CancelWithError(err) + }) +} + // NewRegistrationService returns new RegistrationService instance. func NewRegistrationService(initFlow InitFlowSynchronization, invokeFlow InvokeFlowSynchronization) RegistrationService { return ®istrationServiceImpl{ @@ -382,5 +402,6 @@ func NewRegistrationService(initFlow InitFlowSynchronization, invokeFlow InvokeF externalAgents: NewExternalAgentsMap(), initFlow: initFlow, invokeFlow: invokeFlow, + cancelOnce: sync.Once{}, } } diff --git a/lambda/core/registrations_test.go b/lambda/core/registrations_test.go index d8857a6..5956ac3 100644 --- a/lambda/core/registrations_test.go +++ b/lambda/core/registrations_test.go @@ -63,7 +63,7 @@ func TestRegistrationServiceHappyPathDuringInit(t *testing.T) { assert.NoError(t, runtime.Ready()) }() - assert.NoError(t, initFlow.AwaitRuntimeReady()) + assert.NoError(t, initFlow.AwaitRuntimeRestoreReady()) registrationService.TurnOff() // Agents Ready diff --git a/lambda/core/runtime_state_names.go b/lambda/core/runtime_state_names.go index b20b9f8..b04ba5d 100644 --- a/lambda/core/runtime_state_names.go +++ b/lambda/core/runtime_state_names.go @@ -5,10 +5,14 @@ package core // String values of possibles runtime states const ( - RuntimeStartedStateName = "Started" - RuntimeInitErrorStateName = "InitError" - RuntimeReadyStateName = "Ready" - RuntimeRunningStateName = "Running" + RuntimeStartedStateName = "Started" + RuntimeInitErrorStateName = "InitError" + RuntimeReadyStateName = "Ready" + RuntimeRunningStateName = "Running" + // RuntimeStartedState -> RuntimeRestoreReadyState + RuntimeRestoreReadyStateName = "RestoreReady" + // RuntimeRestoreReadyState -> RuntimeRestoringState + RuntimeRestoringStateName = "Restoring" RuntimeInvocationResponseStateName = "InvocationResponse" RuntimeInvocationErrorResponseStateName = "InvocationErrorResponse" RuntimeResponseSentStateName = "RuntimeResponseSentState" diff --git a/lambda/core/states.go b/lambda/core/states.go index bc7359d..a5e2010 100644 --- a/lambda/core/states.go +++ b/lambda/core/states.go @@ -72,6 +72,7 @@ var ErrConcurrentStateModification = errors.New("Concurrent state modification") type RuntimeState interface { InitError() error Ready() error + RestoreReady() error InvocationResponse() error InvocationErrorResponse() error ResponseSent() error @@ -82,6 +83,7 @@ type disallowEveryTransitionByDefault struct{} func (s *disallowEveryTransitionByDefault) InitError() error { return ErrNotAllowed } func (s *disallowEveryTransitionByDefault) Ready() error { return ErrNotAllowed } +func (s *disallowEveryTransitionByDefault) RestoreReady() error { return ErrNotAllowed } func (s *disallowEveryTransitionByDefault) InvocationResponse() error { return ErrNotAllowed } func (s *disallowEveryTransitionByDefault) InvocationErrorResponse() error { return ErrNotAllowed } func (s *disallowEveryTransitionByDefault) ResponseSent() error { return ErrNotAllowed } @@ -92,13 +94,14 @@ type Runtime struct { currentState RuntimeState stateLastModified time.Time - Pid int responseTime time.Time RuntimeStartedState RuntimeState RuntimeInitErrorState RuntimeState RuntimeReadyState RuntimeState RuntimeRunningState RuntimeState + RuntimeRestoreReadyState RuntimeState + RuntimeRestoringState RuntimeState RuntimeInvocationResponseState RuntimeState RuntimeInvocationErrorResponseState RuntimeState RuntimeResponseSentState RuntimeState @@ -135,6 +138,12 @@ func (s *Runtime) Ready() error { return s.currentState.Ready() } +func (s *Runtime) RestoreReady() error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.RestoreReady() +} + // InvocationResponse delegates to state implementation. func (s *Runtime) InvocationResponse() error { s.ManagedThread.Lock() @@ -196,6 +205,8 @@ func NewRuntime(initFlow InitFlowSynchronization, invokeFlow InvokeFlowSynchroni runtime.RuntimeInvocationResponseState = &RuntimeInvocationResponseState{runtime: runtime, invokeFlow: invokeFlow} runtime.RuntimeInvocationErrorResponseState = &RuntimeInvocationErrorResponseState{runtime: runtime, invokeFlow: invokeFlow} runtime.RuntimeResponseSentState = &RuntimeResponseSentState{runtime: runtime, invokeFlow: invokeFlow} + runtime.RuntimeRestoreReadyState = &RuntimeRestoreReadyState{} + runtime.RuntimeRestoringState = &RuntimeRestoringState{runtime: runtime, initFlow: initFlow} runtime.setStateUnsafe(runtime.RuntimeStartedState) return runtime @@ -211,7 +222,14 @@ type RuntimeStartedState struct { // Ready call when runtime init done. func (s *RuntimeStartedState) Ready() error { s.runtime.setStateUnsafe(s.runtime.RuntimeReadyState) - err := s.initFlow.RuntimeReady() + // runtime called /next without calling /restore/next + // that means it's not interested in restore phase + err := s.initFlow.RuntimeRestoreReady() + if err != nil { + return err + } + + err = s.initFlow.RuntimeReady() if err != nil { return err } @@ -225,6 +243,22 @@ func (s *RuntimeStartedState) Ready() error { return nil } +func (s *RuntimeStartedState) RestoreReady() error { + s.runtime.setStateUnsafe(s.runtime.RuntimeRestoreReadyState) + err := s.initFlow.RuntimeRestoreReady() + if err != nil { + return err + } + + s.runtime.ManagedThread.SuspendUnsafe() + if s.runtime.currentState != s.runtime.RuntimeRestoreReadyState && s.runtime.currentState != s.runtime.RuntimeRestoringState { + return ErrConcurrentStateModification + } + + s.runtime.setStateUnsafe(s.runtime.RuntimeRestoringState) + return nil +} + // InitError move runtime to init error state. func (s *RuntimeStartedState) InitError() error { s.runtime.setStateUnsafe(s.runtime.RuntimeInitErrorState) @@ -236,6 +270,38 @@ func (s *RuntimeStartedState) Name() string { return RuntimeStartedStateName } +type RuntimeRestoringState struct { + disallowEveryTransitionByDefault + runtime *Runtime + initFlow InitFlowSynchronization +} + +// Runtime is healthy after restore and called /next +func (s *RuntimeRestoringState) Ready() error { + s.runtime.setStateUnsafe(s.runtime.RuntimeReadyState) + err := s.initFlow.RuntimeReady() + if err != nil { + return err + } + s.runtime.ManagedThread.SuspendUnsafe() + if s.runtime.currentState != s.runtime.RuntimeReadyState && s.runtime.currentState != s.runtime.RuntimeRunningState { + return ErrConcurrentStateModification + } + + s.runtime.setStateUnsafe(s.runtime.RuntimeRunningState) + return nil +} + +// Runtime has thrown an exception when executing restore hooks and called /init/error +func (s *RuntimeRestoringState) InitError() error { + s.runtime.setStateUnsafe(s.runtime.RuntimeInitErrorState) + return nil +} + +func (s *RuntimeRestoringState) Name() string { + return RuntimeRestoringStateName +} + // RuntimeInitErrorState runtime started state. type RuntimeInitErrorState struct { disallowEveryTransitionByDefault @@ -297,6 +363,14 @@ func (s *RuntimeRunningState) Name() string { return RuntimeRunningStateName } +type RuntimeRestoreReadyState struct { + disallowEveryTransitionByDefault +} + +func (s *RuntimeRestoreReadyState) Name() string { + return RuntimeRestoreReadyStateName +} + // RuntimeInvocationResponseState runtime response is available. // Start state for runtime response submission. type RuntimeInvocationResponseState struct { diff --git a/lambda/core/states_test.go b/lambda/core/states_test.go index 4b01838..37f38e2 100644 --- a/lambda/core/states_test.go +++ b/lambda/core/states_test.go @@ -39,10 +39,7 @@ func TestRuntimeInitErrorAfterReady(t *testing.T) { } func TestRuntimeStateTransitionsFromStartedState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // Started assert.Equal(t, runtime.RuntimeStartedState, runtime.GetState()) // Started -> InitError @@ -53,6 +50,10 @@ func TestRuntimeStateTransitionsFromStartedState(t *testing.T) { runtime.SetState(runtime.RuntimeStartedState) assert.NoError(t, runtime.Ready()) assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + // Started -> RestoreReady + runtime.SetState(runtime.RuntimeStartedState) + assert.NoError(t, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) // Started -> ResponseSent runtime.SetState(runtime.RuntimeStartedState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) @@ -68,10 +69,7 @@ func TestRuntimeStateTransitionsFromStartedState(t *testing.T) { } func TestRuntimeStateTransitionsFromInitErrorState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // InitError -> InitError runtime.SetState(runtime.RuntimeInitErrorState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -80,6 +78,10 @@ func TestRuntimeStateTransitionsFromInitErrorState(t *testing.T) { runtime.SetState(runtime.RuntimeInitErrorState) assert.Equal(t, ErrNotAllowed, runtime.Ready()) assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) + // InitError -> RestoreReady + runtime.SetState(runtime.RuntimeInitErrorState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) // InitError -> ResponseSent runtime.SetState(runtime.RuntimeInitErrorState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) @@ -95,10 +97,7 @@ func TestRuntimeStateTransitionsFromInitErrorState(t *testing.T) { } func TestRuntimeStateTransitionsFromReadyState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // Ready -> InitError runtime.SetState(runtime.RuntimeReadyState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -107,6 +106,10 @@ func TestRuntimeStateTransitionsFromReadyState(t *testing.T) { runtime.SetState(runtime.RuntimeReadyState) assert.NoError(t, runtime.Ready()) assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + // Ready -> RestoreReady + runtime.SetState(runtime.RuntimeReadyState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) // Ready -> ResponseSent runtime.SetState(runtime.RuntimeReadyState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) @@ -122,10 +125,7 @@ func TestRuntimeStateTransitionsFromReadyState(t *testing.T) { } func TestRuntimeStateTransitionsFromRunningState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // Running -> InitError runtime.SetState(runtime.RuntimeRunningState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -134,6 +134,10 @@ func TestRuntimeStateTransitionsFromRunningState(t *testing.T) { runtime.SetState(runtime.RuntimeRunningState) assert.NoError(t, runtime.Ready()) assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + // Running -> RestoreReady + runtime.SetState(runtime.RuntimeRunningState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) // Running -> ResponseSent runtime.SetState(runtime.RuntimeRunningState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) @@ -149,10 +153,7 @@ func TestRuntimeStateTransitionsFromRunningState(t *testing.T) { } func TestRuntimeStateTransitionsFromInvocationResponseState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // InvocationResponse -> InitError runtime.SetState(runtime.RuntimeInvocationResponseState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -161,6 +162,10 @@ func TestRuntimeStateTransitionsFromInvocationResponseState(t *testing.T) { runtime.SetState(runtime.RuntimeInvocationResponseState) assert.Equal(t, ErrNotAllowed, runtime.Ready()) assert.Equal(t, runtime.RuntimeInvocationResponseState, runtime.GetState()) + // InvocationResponse -> RestoreReady + runtime.SetState(runtime.RuntimeInvocationResponseState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeInvocationResponseState, runtime.GetState()) // InvocationResponse -> ResponseSent runtime.SetState(runtime.RuntimeInvocationResponseState) assert.NoError(t, runtime.ResponseSent()) @@ -177,10 +182,7 @@ func TestRuntimeStateTransitionsFromInvocationResponseState(t *testing.T) { } func TestRuntimeStateTransitionsFromInvocationErrorResponseState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // InvocationErrorResponse -> InitError runtime.SetState(runtime.RuntimeInvocationErrorResponseState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -189,6 +191,10 @@ func TestRuntimeStateTransitionsFromInvocationErrorResponseState(t *testing.T) { runtime.SetState(runtime.RuntimeInvocationErrorResponseState) assert.Equal(t, ErrNotAllowed, runtime.Ready()) assert.Equal(t, runtime.RuntimeInvocationErrorResponseState, runtime.GetState()) + // InvocationErrorResponse -> RestoreReady + runtime.SetState(runtime.RuntimeInvocationErrorResponseState) + assert.Equal(t, ErrNotAllowed, runtime.Ready()) + assert.Equal(t, runtime.RuntimeInvocationErrorResponseState, runtime.GetState()) // InvocationErrorResponse -> ResponseSent runtime.SetState(runtime.RuntimeInvocationErrorResponseState) assert.NoError(t, runtime.ResponseSent()) @@ -204,10 +210,7 @@ func TestRuntimeStateTransitionsFromInvocationErrorResponseState(t *testing.T) { } func TestRuntimeStateTransitionsFromResponseSentState(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} + runtime := newRuntime() // ResponseSent -> InitError runtime.SetState(runtime.RuntimeResponseSentState) assert.Equal(t, ErrNotAllowed, runtime.InitError()) @@ -216,6 +219,10 @@ func TestRuntimeStateTransitionsFromResponseSentState(t *testing.T) { runtime.SetState(runtime.RuntimeResponseSentState) assert.NoError(t, runtime.Ready()) assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + // ResponseSent -> RestoreReady + runtime.SetState(runtime.RuntimeResponseSentState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeResponseSentState, runtime.GetState()) // ResponseSent -> ResponseSent runtime.SetState(runtime.RuntimeResponseSentState) assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) @@ -230,6 +237,71 @@ func TestRuntimeStateTransitionsFromResponseSentState(t *testing.T) { assert.Equal(t, runtime.RuntimeResponseSentState, runtime.GetState()) } +func TestRuntimeStateTransitionsFromRestoreReadyState(t *testing.T) { + runtime := newRuntime() + // RestoreReady -> InitError + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.InitError()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) + // RestoreReady -> Ready + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.Ready()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) + // RestoreReady -> RestoreReady() + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) + // RestoreReady -> ResponseSent + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) + // RestoreReady -> InvocationResponse + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) + // RestoreReady -> InvocationErrorResponse + runtime.SetState(runtime.RuntimeRestoreReadyState) + assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) + assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) +} + +func TestRuntimeStateTransitionsFromRestoringState(t *testing.T) { + runtime := newRuntime() + // RestoreRunning -> InitError + runtime.SetState(runtime.RuntimeRestoringState) + assert.NoError(t, runtime.InitError()) + assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) + // RestoreRunning -> Ready + runtime.SetState(runtime.RuntimeRestoringState) + assert.NoError(t, runtime.Ready()) + assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) + // RestoreRunning -> RestoreReady + runtime.SetState(runtime.RuntimeRestoringState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) + // RestoreRunning -> ResponseSent + runtime.SetState(runtime.RuntimeRestoringState) + assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) + assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) + // RestoreRunning -> InvocationResponse + runtime.SetState(runtime.RuntimeRestoringState) + assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) + assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) + // RestoreRunning -> InvocationErrorResponse + runtime.SetState(runtime.RuntimeRestoringState) + assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) + assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) +} + +func newRuntime() *Runtime { + initFlow := &mockInitFlowSynchronization{} + invokeFlow := &mockInvokeFlowSynchronization{} + runtime := NewRuntime(initFlow, invokeFlow) + runtime.ManagedThread = &mockthread.MockManagedThread{} + + return runtime +} + type mockInitFlowSynchronization struct { mock.Mock ReadyCond *sync.Cond @@ -272,6 +344,12 @@ func (s *mockInitFlowSynchronization) CancelWithError(err error) { s.Called(err) } func (s *mockInitFlowSynchronization) Clear() {} +func (s *mockInitFlowSynchronization) RuntimeRestoreReady() error { + return nil +} +func (s *mockInitFlowSynchronization) AwaitRuntimeRestoreReady() error { + return nil +} type mockInvokeFlowSynchronization struct{ mock.Mock } diff --git a/lambda/core/watchdog.go b/lambda/core/watchdog.go deleted file mode 100644 index bf57d01..0000000 --- a/lambda/core/watchdog.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "fmt" - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/fatalerror" - "sync" -) - -type WaitableProcess interface { - // Wait blocks until process exits and returns error in case of non-zero exit code - Wait() error - // Pid returnes process ID - Pid() int - // Name returnes process executable name (for logging) - Name() string -} - -// Watchdog watches started goroutines. -type Watchdog struct { - cancelOnce sync.Once - initFlow InitFlowSynchronization - invokeFlow InvokeFlowSynchronization - exitPidChan chan<- int - appCtx appctx.ApplicationContext - mutedMutex sync.Mutex - muted bool -} - -func (w *Watchdog) Mute() { - w.mutedMutex.Lock() - defer w.mutedMutex.Unlock() - w.muted = true -} - -func (w *Watchdog) Unmute() { - w.mutedMutex.Lock() - defer w.mutedMutex.Unlock() - w.muted = false -} - -func (w *Watchdog) Muted() bool { - w.mutedMutex.Lock() - defer w.mutedMutex.Unlock() - return w.muted -} - -// GoWait waits for process to complete in separate goroutine and handles the process termination -// Returns PID of the process -func (w *Watchdog) GoWait(p WaitableProcess, errorType fatalerror.ErrorType) int { - pid := p.Pid() - name := p.Name() - appCtx := w.appCtx - go func() { - err := p.Wait() - - if !w.Muted() { - appctx.StoreFirstFatalError(appCtx, errorType) - - if err == nil { - err = fmt.Errorf("exit code 0") - } - log.Warnf("Process %d(%s) exited: %s", pid, name, err) - } - - w.CancelFlows(err) - w.exitPidChan <- pid - }() - - return pid -} - -// CancelFlows cancels init and invoke flows with error. -func (w *Watchdog) CancelFlows(err error) { - // The following block protects us from overwriting the error - // which was first used to cancel flows. - w.cancelOnce.Do(func() { - log.Debugf("Canceling flows: %s", err) - w.initFlow.CancelWithError(err) - w.invokeFlow.CancelWithError(err) - }) -} - -// Clear watchdog state -func (w *Watchdog) Clear() { - w.cancelOnce = sync.Once{} -} - -// NewWatchdog returns new instance of a Watchdog. -func NewWatchdog(initFlow InitFlowSynchronization, invokeFlow InvokeFlowSynchronization, exitPidChan chan<- int, appCtx appctx.ApplicationContext) *Watchdog { - return &Watchdog{ - initFlow: initFlow, - invokeFlow: invokeFlow, - exitPidChan: exitPidChan, - appCtx: appCtx, - mutedMutex: sync.Mutex{}, - } -} diff --git a/lambda/core/watchdog_test.go b/lambda/core/watchdog_test.go deleted file mode 100644 index 84f8342..0000000 --- a/lambda/core/watchdog_test.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "errors" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/fatalerror" - "testing" -) - -var errTest = errors.New("ErrTest") - -type MockProcess struct { -} - -func (s *MockProcess) Wait() error { return errTest } -func (s *MockProcess) Pid() int { return 0 } -func (s *MockProcess) Name() string { return "" } - -func TestWatchdogCallback(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - initFlow.On("CancelWithError", mock.Anything) - invokeFlow.On("CancelWithError", mock.Anything) - - pidChan := make(chan int) - appCtx := appctx.NewApplicationContext() - w := NewWatchdog(initFlow, invokeFlow, pidChan, appCtx) - - w.GoWait(&MockProcess{}, fatalerror.AgentExitError) - w.GoWait(&MockProcess{}, fatalerror.AgentExitError) - - <-pidChan - initFlow.AssertCalled(t, "CancelWithError", errTest) - initFlow.AssertNumberOfCalls(t, "CancelWithError", 1) - invokeFlow.AssertCalled(t, "CancelWithError", errTest) - invokeFlow.AssertNumberOfCalls(t, "CancelWithError", 1) - - <-pidChan - initFlow.AssertNumberOfCalls(t, "CancelWithError", 1) - invokeFlow.AssertNumberOfCalls(t, "CancelWithError", 1) - - err, found := appctx.LoadFirstFatalError(appCtx) - require.True(t, found) - require.Equal(t, err, fatalerror.AgentExitError) -} diff --git a/lambda/fatalerror/fatalerror.go b/lambda/fatalerror/fatalerror.go index 7292baf..bb8a86a 100644 --- a/lambda/fatalerror/fatalerror.go +++ b/lambda/fatalerror/fatalerror.go @@ -3,7 +3,7 @@ package fatalerror -// This package defines constant error types returned to slicer with DONE(failure) +// This package defines constant error types returned to slicer with DONE(failure), and also sandbox errors // Separate package for namespacing // ErrorType is returned to slicer inside DONE @@ -18,5 +18,8 @@ const ( InvalidEntrypoint ErrorType = "Runtime.InvalidEntrypoint" InvalidWorkingDir ErrorType = "Runtime.InvalidWorkingDir" InvalidTaskConfig ErrorType = "Runtime.InvalidTaskConfig" + TruncatedResponse ErrorType = "Runtime.TruncatedResponse" + SandboxFailure ErrorType = "Sandbox.Failure" + SandboxTimeout ErrorType = "Sandbox.Timeout" Unknown ErrorType = "Unknown" ) diff --git a/lambda/interop/bootstrap.go b/lambda/interop/bootstrap.go new file mode 100644 index 0000000..4a9b6af --- /dev/null +++ b/lambda/interop/bootstrap.go @@ -0,0 +1,18 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "os" + + "go.amzn.com/lambda/fatalerror" +) + +type Bootstrap interface { + Cmd() ([]string, error) // returns the args of bootstrap, where args[0] is the path to executable + Env(e EnvironmentVariables) map[string]string // returns the environment variables to be passed to the bootstrapped process + Cwd() (string, error) // returns the working directory of the bootstrap process + ExtraFiles() []*os.File // returns the extra file descriptors apart from 1 & 2 to be passed to runtime + CachedFatalError(err error) (fatalerror.ErrorType, string, bool) +} diff --git a/lambda/interop/cancellable_request.go b/lambda/interop/cancellable_request.go new file mode 100644 index 0000000..7e8fca5 --- /dev/null +++ b/lambda/interop/cancellable_request.go @@ -0,0 +1,27 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "net" + "net/http" +) + +type key int + +const ( + HTTPConnKey key = iota +) + +func GetConn(r *http.Request) net.Conn { + return r.Context().Value(HTTPConnKey).(net.Conn) +} + +type CancellableRequest struct { + Request *http.Request +} + +func (c *CancellableRequest) Cancel() error { + return GetConn(c.Request).Close() +} diff --git a/lambda/interop/environment_variables.go b/lambda/interop/environment_variables.go new file mode 100644 index 0000000..46bdf8b --- /dev/null +++ b/lambda/interop/environment_variables.go @@ -0,0 +1,14 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +type EnvironmentVariables interface { + AgentExecEnv() map[string]string + RuntimeExecEnv() map[string]string + SetHandler(handler string) + StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress string) + StoreEnvironmentVariablesFromInit(customerEnv map[string]string, + handler, awsKey, awsSecret, awsSession, funcName, funcVer string) + StoreEnvironmentVariablesFromInitForInitCaching(host string, port int, customerEnv map[string]string, handler, funcName, funcVer, token string) +} diff --git a/lambda/interop/model.go b/lambda/interop/model.go index 5cdf63f..cc9c7d0 100644 --- a/lambda/interop/model.go +++ b/lambda/interop/model.go @@ -8,17 +8,99 @@ import ( "fmt" "io" "net/http" + "strings" "time" "go.amzn.com/lambda/core/statejson" "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/supervisor/model" + + log "github.com/sirupsen/logrus" ) // MaxPayloadSize max event body size declared as LAMBDA_EVENT_BODY_SIZE -const MaxPayloadSize = 6*1024*1024 + 100 // 6 MiB + 100 bytes +const ( + MaxPayloadSize = 6*1024*1024 + 100 // 6 MiB + 100 bytes + + ResponseBandwidthRate = 2 * 1024 * 1024 // default average rate of 2 MiB/s + ResponseBandwidthBurstSize = 6 * 1024 * 1024 // default burst size of 6 MiB + + MinResponseBandwidthRate = 32 * 1024 // 32 KiB/s + MaxResponseBandwidthRate = 64 * 1024 * 1024 // 64 MiB/s + + MinResponseBandwidthBurstSize = 32 * 1024 // 32 KiB + MaxResponseBandwidthBurstSize = 64 * 1024 * 1024 // 64 MiB +) const functionResponseSizeTooLargeType = "Function.ResponseSizeTooLarge" +// ResponseMode are top-level constants used in combination with the various types of +// modes we have for responses, such as invoke's response mode and function's response mode. +// In the future we might have invoke's request mode or similar, so these help set the ground +// for consistency. +type ResponseMode string + +const ResponseModeBuffered = "Buffered" +const ResponseModeStreaming = "Streaming" + +type InvokeResponseMode string + +const InvokeResponseModeBuffered InvokeResponseMode = ResponseModeBuffered +const InvokeResponseModeStreaming InvokeResponseMode = ResponseModeStreaming + +var AllInvokeResponseModes = []string{ + string(InvokeResponseModeBuffered), string(InvokeResponseModeStreaming), +} + +// ConvertToInvokeResponseMode converts the given string to a InvokeResponseMode +// It is case insensitive and if there is no match, an error is thrown. +func ConvertToInvokeResponseMode(value string) (InvokeResponseMode, error) { + // buffered + if strings.EqualFold(value, string(InvokeResponseModeBuffered)) { + return InvokeResponseModeBuffered, nil + } + + // streaming + if strings.EqualFold(value, string(InvokeResponseModeStreaming)) { + return InvokeResponseModeStreaming, nil + } + + // unknown + allowedValues := strings.Join(AllInvokeResponseModes, ", ") + log.Errorf("Unlable to map %s to %s.", value, allowedValues) + return "", ErrInvalidInvokeResponseMode +} + +// FunctionResponseMode is passed by Runtime to tell whether the response should be +// streamed or not. +type FunctionResponseMode string + +const FunctionResponseModeBuffered FunctionResponseMode = ResponseModeBuffered +const FunctionResponseModeStreaming FunctionResponseMode = ResponseModeStreaming + +var AllFunctionResponseModes = []string{ + string(FunctionResponseModeBuffered), string(FunctionResponseModeStreaming), +} + +// ConvertToFunctionResponseMode converts the given string to a FunctionResponseMode +// It is case insensitive and if there is no match, an error is thrown. +func ConvertToFunctionResponseMode(value string) (FunctionResponseMode, error) { + // buffered + if strings.EqualFold(value, string(FunctionResponseModeBuffered)) { + return FunctionResponseModeBuffered, nil + } + + // streaming + if strings.EqualFold(value, string(FunctionResponseModeStreaming)) { + return FunctionResponseModeStreaming, nil + } + + // unknown + allowedValues := strings.Join(AllFunctionResponseModes, ", ") + log.Errorf("Unlable to map %s to %s.", value, allowedValues) + return "", ErrInvalidFunctionResponseMode +} + // Message is a generic interop message. type Message interface{} @@ -37,11 +119,10 @@ type Invoke struct { ContentType string Payload io.Reader NeedDebugLogs bool - CorrelationID string // internal use only ReservationToken string VersionID string InvokeReceivedTime int64 - ResyncState Resync + InvokeResponseMetrics *InvokeResponseMetrics } type Token struct { @@ -54,21 +135,13 @@ type Token struct { LambdaSegmentID string InvokeMetadata string NeedDebugLogs bool - ResyncState Resync -} - -type Resync struct { - IsResyncReceived bool - AwsKey string - AwsSecret string - AwsSession string - ReceivedTime time.Time } type ErrorResponse struct { // Payload sent via shared memory. - Payload []byte `json:"Payload,omitempty"` - ContentType string `json:"-"` + Payload []byte `json:"Payload,omitempty"` + ContentType string `json:"-"` + FunctionResponseMode string `json:"-"` // When error response body (Payload) is not provided, e.g. // not retrievable, error type and error message will be @@ -92,48 +165,80 @@ type SandboxType string const SandboxPreWarmed SandboxType = "PreWarmed" const SandboxClassic SandboxType = "Classic" -// Start message received from the slicer, part of the protocol. -type Start struct { - InvokeID string - Handler string - AwsKey string - AwsSecret string - AwsSession string - SuppressInit bool - XRayDaemonAddress string // only in standalone - FunctionName string // only in standalone - FunctionVersion string // only in standalone - CorrelationID string // internal use only - // TODO: define new Init type that has the Start fields as well as env vars below. - // In standalone mode, these env vars come from test/init but from environment otherwise. - CustomerEnvironmentVariables map[string]string - SandboxType SandboxType +// RuntimeInfo contains metadata about the runtime used by the Sandbox +type RuntimeInfo struct { + ImageJSON string // image config, e.g {\"layers\":[]} + Arn string // runtime ARN, e.g. arn:awstest:lambda:us-west-2::runtime:python3.8::alpha + Version string // human-readable runtime arn equivalent, e.g. python3.8.v999 } -// Running message is sent to the slicer, part of the protocol. -type Running struct { - WaitStartTimeNs int64 - WaitEndTimeNs int64 - PreLoadTimeNs int64 - PostLoadTimeNs int64 - ExtensionsEnabled bool +// Captures configuration of the operator and runtime domain +// that are only known after INIT is received +type DynamicDomainConfig struct { + // extra hooks to execute at domain start. Currently used for filesystem and network hooks. + // It can be empty. + AdditionalStartHooks []model.Hook + Mounts []model.DriveMount + //TODO: other dynamic configurations for the domain go here } // Reset message is sent to rapid to initiate reset sequence type Reset struct { - Reason string - DeadlineNs int64 - CorrelationID string // internal use only + Reason string + DeadlineNs int64 + InvokeResponseMetrics *InvokeResponseMetrics + TraceID string + LambdaSegmentID string +} + +// Restore message is sent to rapid to restore runtime to make it ready for consecutive invokes +type Restore struct { + AwsKey string + AwsSecret string + AwsSession string + CredentialsExpiry time.Time +} + +type Resync struct { } // Shutdown message is sent to rapid to initiate graceful shutdown type Shutdown struct { - DeadlineNs int64 - CorrelationID string // internal use only + DeadlineNs int64 +} + +// Metrics for response status of LogsAPI/TelemetryAPI `/subscribe` calls +type TelemetrySubscriptionMetrics map[string]int + +func MergeSubscriptionMetrics(logsAPIMetrics TelemetrySubscriptionMetrics, telemetryAPIMetrics TelemetrySubscriptionMetrics) TelemetrySubscriptionMetrics { + metrics := make(map[string]int) + for metric, value := range logsAPIMetrics { + metrics[metric] = value + } + + for metric, value := range telemetryAPIMetrics { + metrics[metric] += value + } + return metrics +} + +// InvokeResponseMetrics are produced while sending streaming invoke response to WP +type InvokeResponseMetrics struct { + StartReadingResponseMonoTimeMs int64 + FinishReadingResponseMonoTimeMs int64 + TimeShapedNs int64 + ProducedBytes int64 + OutboundThroughputBps int64 // in bytes per second + FunctionResponseMode FunctionResponseMode + RuntimeCalledResponse bool } -// Metrics for response status of LogsAPI `/subscribe` calls -type LogsAPIMetrics map[string]int +func IsResponseStreamingMetrics(metrics *InvokeResponseMetrics) bool { + if metrics == nil { + return false + } + return metrics.FunctionResponseMode == FunctionResponseModeStreaming +} type DoneMetadata struct { NumActiveExtensions int @@ -141,25 +246,26 @@ type DoneMetadata struct { ExtensionNames string RuntimeRelease string // Metrics for response status of LogsAPI `/subscribe` calls - LogsAPIMetrics LogsAPIMetrics - InvokeRequestReadTimeNs int64 - InvokeRequestSizeBytes int64 - InvokeCompletionTimeNs int64 - InvokeReceivedTime int64 - RuntimeReadyTime int64 + LogsAPIMetrics TelemetrySubscriptionMetrics + InvokeRequestReadTimeNs int64 + InvokeRequestSizeBytes int64 + InvokeCompletionTimeNs int64 + InvokeReceivedTime int64 + RuntimeReadyTime int64 + RuntimeTimeThrottledMs int64 + RuntimeProducedBytes int64 + RuntimeOutboundThroughputBps int64 } type Done struct { - WaitForExit bool - ErrorType fatalerror.ErrorType - CorrelationID string // internal use only - Meta DoneMetadata + WaitForExit bool + ErrorType fatalerror.ErrorType + Meta DoneMetadata } type DoneFail struct { - ErrorType fatalerror.ErrorType - CorrelationID string // internal use only - Meta DoneMetadata + ErrorType fatalerror.ErrorType + Meta DoneMetadata } // ErrInvalidInvokeID is returned when invokeID provided in Invoke2 does not match one provided in Token @@ -171,6 +277,22 @@ var ErrInvalidReservationToken = fmt.Errorf("ErrInvalidReservationToken") // ErrInvalidFunctionVersion is returned when functionVersion provided in Invoke2 does not match one provided in Token var ErrInvalidFunctionVersion = fmt.Errorf("ErrInvalidFunctionVersion") +// ErrInvalidFunctionResponseMode is returned when the value sent by runtime during Invoke2 +// is not a constant of type interop.FunctionResponseMode +var ErrInvalidFunctionResponseMode = fmt.Errorf("ErrInvalidFunctionResponseMode") + +// ErrInvalidInvokeResponseMode is returned when optional InvokeResponseMode header provided in Invoke2 is not a constant of type interop.InvokeResponseMode +var ErrInvalidInvokeResponseMode = fmt.Errorf("ErrInvalidInvokeResponseMode") + +// ErrInvalidMaxPayloadSize is returned when optional MaxPayloadSize header provided in Invoke2 is invalid +var ErrInvalidMaxPayloadSize = fmt.Errorf("ErrInvalidMaxPayloadSize") + +// ErrInvalidResponseBandwidthRate is returned when optional ResponseBandwidthRate header provided in Invoke2 is invalid +var ErrInvalidResponseBandwidthRate = fmt.Errorf("ErrInvalidResponseBandwidthRate") + +// ErrInvalidResponseBandwidthBurstSize is returned when optional ResponseBandwidthBurstSize header provided in Invoke2 is invalid +var ErrInvalidResponseBandwidthBurstSize = fmt.Errorf("ErrInvalidResponseBandwidthBurstSize") + // ErrMalformedCustomerHeaders is returned when customer headers format is invalid var ErrMalformedCustomerHeaders = fmt.Errorf("ErrMalformedCustomerHeaders") @@ -180,6 +302,20 @@ var ErrResponseSent = fmt.Errorf("ErrResponseSent") // ErrReservationExpired is returned when invoke arrived after InvackDeadline var ErrReservationExpired = fmt.Errorf("ErrReservationExpired") +// ErrInternalPlatformError is returned when internal platform error occurred +type ErrInternalPlatformError struct{} + +func (s *ErrInternalPlatformError) Error() string { + return "ErrInternalPlatformError" +} + +// ErrTruncatedResponse is returned when response is truncated +type ErrTruncatedResponse struct{} + +func (s *ErrTruncatedResponse) Error() string { + return "ErrTruncatedResponse" +} + // ErrorResponseTooLarge is returned when response Payload exceeds shared memory buffer size type ErrorResponseTooLarge struct { MaxResponseSize int @@ -211,17 +347,21 @@ func (s *ErrorResponseTooLarge) AsInteropError() *ErrorResponse { return &resp } -// Server implements Slicer communication protocol. +// Server used for sending messages and sharing data between the Runtime API handlers and the +// internal platform facing servers. For example, +// +// responseCtx.SendResponse(...) +// +// will send the response payload and metadata provided by the runtime to the platform, through the internal +// protocol used by the specific implementation +// TODO: rename this to InvokeResponseContext, used to send responses from handlers to platform-facing server type Server interface { - // StartAcceptingDirectInvokes starts accepting on direct invoke socket (if one is available) - StartAcceptingDirectInvokes() error - - // SendErrorResponse sends response. + // SendResponse sends response. // Errors returned: // ErrInvalidInvokeID - validation error indicating that provided invokeID doesn't match current invokeID // ErrResponseSent - validation error indicating that response with given invokeID was already sent // Non-nil error - non-nil error indicating transport failure - SendResponse(invokeID string, contentType string, response io.Reader) error + SendResponse(invokeID string, headers map[string]string, response io.Reader, trailers http.Header, request *CancellableRequest) error // SendErrorResponse sends error response. // Errors returned: @@ -229,61 +369,36 @@ type Server interface { // ErrResponseSent - validation error indicating that response with given invokeID was already sent // Non-nil error - non-nil error indicating transport failure SendErrorResponse(invokeID string, response *ErrorResponse) error + SendInitErrorResponse(invokeID string, response *ErrorResponse) error // GetCurrentInvokeID returns current invokeID. // NOTE, in case of INIT, when invokeID is not known in advance (e.g. provisioned concurrency), // returned invokeID will contain empty value. GetCurrentInvokeID() string - // CommitMessage confirms that the message written through SendResponse and SendErrorResponse is complete. - CommitResponse() error - - // SendRunning sends GIRD RUNNING. - // Returns error on transport failure. - SendRunning(*Running) error - - // SendRuntimeReady sends GIRD RTREADY + // SendRuntimeReady sends a message indicating the runtime has called /invocation/next. + // The checkpoint allows us to compute the overhead due to Extensions by substracting it + // from the time when all extensions have called /next. + // TODO: this method is a lifecycle event used only for metrics, and doesn't belong here SendRuntimeReady() error +} - // SendDone sends GIRD DONE. - // Returns error on transport failure. - SendDone(*Done) error - - // SendDone sends GIRD DONEFAIL. - // Returns error on transport failure. - SendDoneFail(*DoneFail) error - - // StartChan returns Start emitter - StartChan() <-chan *Start - - // InvokeChan returns Invoke emitter - InvokeChan() <-chan *Invoke - - // ResetChan returns Reset emitter - ResetChan() <-chan *Reset - - // ShutdownChan returns Shutdown emitter - ShutdownChan() <-chan *Shutdown - - // TransportErrorChan emits errors if there was parsing/connection issue - TransportErrorChan() <-chan error - - // Clear is called on rapid reset. It should leave server prepared for new invocations - Clear() - - // IsResponseSent exposes is response sent flag - IsResponseSent() bool - - // The following are used by standalone rapid only - // TODO refactor to decouple the interfaces +type InternalStateGetter func() statejson.InternalStateDescription - SetInternalStateGetter(cb InternalStateGetter) +const OnDemandInitTelemetrySource string = "on-demand" +const ProvisionedConcurrencyInitTelemetrySource string = "provisioned-concurrency" +const InitCachingInitTelemetrySource string = "snap-start" - Init(i *Start, invokeTimeoutMs int64) +func InferTelemetryInitSource(initCachingEnabled bool, sandboxType SandboxType) string { + initSource := OnDemandInitTelemetrySource - Invoke(responseWriter http.ResponseWriter, invoke *Invoke) error + // ToDo: Unify this selection of SandboxType by using the START message + // after having a roadmap on the combination of INIT modes + if initCachingEnabled { + initSource = InitCachingInitTelemetrySource + } else if sandboxType == SandboxPreWarmed { + initSource = ProvisionedConcurrencyInitTelemetrySource + } - Shutdown(shutdown *Shutdown) *statejson.InternalStateDescription + return initSource } - -type InternalStateGetter func() statejson.InternalStateDescription diff --git a/lambda/interop/model_test.go b/lambda/interop/model_test.go new file mode 100644 index 0000000..9ad4d17 --- /dev/null +++ b/lambda/interop/model_test.go @@ -0,0 +1,27 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMergeSubscriptionMetrics(t *testing.T) { + logsAPIMetrics := map[string]int{ + "server_error": 1, + "client_error": 2, + } + + telemetryAPIMetrics := map[string]int{ + "server_error": 1, + "success": 5, + } + + metrics := MergeSubscriptionMetrics(logsAPIMetrics, telemetryAPIMetrics) + assert.Equal(t, 5, metrics["success"]) + assert.Equal(t, 2, metrics["server_error"]) + assert.Equal(t, 2, metrics["client_error"]) +} diff --git a/lambda/interop/sandbox_model.go b/lambda/interop/sandbox_model.go index dddfcf2..b5d15b0 100644 --- a/lambda/interop/sandbox_model.go +++ b/lambda/interop/sandbox_model.go @@ -3,20 +3,183 @@ package interop -// Init represents an init message and is currently only used in standalone +import ( + "time" + + "go.amzn.com/lambda/fatalerror" +) + +// Init represents an init message +// In Rapid Shim, this is a START GirD message +// In Rapid Daemon, this is an INIT GirP message type Init struct { InvokeID string Handler string AwsKey string AwsSecret string AwsSession string + CredentialsExpiry time.Time SuppressInit bool + InvokeTimeoutMs int64 // timeout duration of whole invoke + InitTimeoutMs int64 // timeout duration for init only XRayDaemonAddress string // only in standalone FunctionName string // only in standalone FunctionVersion string // only in standalone - CorrelationID string // internal use only - // TODO: define new Init type that has the Start fields as well as env vars below. // In standalone mode, these env vars come from test/init but from environment otherwise. CustomerEnvironmentVariables map[string]string - SandboxType + SandboxType SandboxType + // there is no dynamic config at the moment for the runtime domain + OperatorDomainExtraConfig DynamicDomainConfig + RuntimeInfo RuntimeInfo + Bootstrap Bootstrap + EnvironmentVariables EnvironmentVariables // contains env vars for agents and runtime procs +} + +// InitStarted contains metadata about the initialized sandbox +// In Rapid Shim, this translates to a RUNNING GirD message to Slicer +// In Rapid Daemon, this is followed by a SANDBOX GirP message to MM +type InitStarted struct { + WaitStartTimeNs int64 + WaitEndTimeNs int64 + PreLoadTimeNs int64 + PostLoadTimeNs int64 + ExtensionsEnabled bool + Ack chan struct{} // used by the sending goroutine to wait until ipc message has been sent +} + +// InitSuccess indicates that runtime/extensions initialization completed successfully +// In Rapid Shim, this translates to a DONE GirD message to Slicer +// In Rapid Daemon, this is followed by a DONEDONE GirP message to MM +type InitSuccess struct { + NumActiveExtensions int // indicates number of active extensions + ExtensionNames string // file names of extensions in /opt/extensions + RuntimeRelease string + LogsAPIMetrics TelemetrySubscriptionMetrics // used if telemetry API enabled + Ack chan struct{} // used by the sending goroutine to wait until ipc message has been sent +} + +// InitFailure indicates that runtime/extensions initialization failed due to process exit or /error calls +// In Rapid Shim, this translates to either a DONE or a DONEFAIL GirD message to Slicer (depending on extensions mode) +// However, even on failure, the next invoke is expected to work with a suppressed init - i.e. we init again as aprt of the invoke +type InitFailure struct { + ResetReceived bool // indicates if failure happened due to a reset received + RequestReset bool // Indicates whether reset should be requested on init failure + ErrorType fatalerror.ErrorType + ErrorMessage error + NumActiveExtensions int + RuntimeRelease string // value of the User Agent HTTP header provided by runtime + LogsAPIMetrics TelemetrySubscriptionMetrics + Ack chan struct{} // used by the sending goroutine to wait until ipc message has been sent +} + +// ResponseMetrics groups metrics related to the response stream +type ResponseMetrics struct { + RuntimeTimeThrottledMs int64 + RuntimeProducedBytes int64 + RuntimeOutboundThroughputBps int64 +} + +// InvokeMetrics groups metrics related to the invoke phase +type InvokeMetrics struct { + InvokeRequestReadTimeNs int64 + InvokeRequestSizeBytes int64 + RuntimeReadyTime int64 +} + +// InvokeSuccess is the success response to invoke phase end +type InvokeSuccess struct { + RuntimeRelease string // value of the User Agent HTTP header provided by runtime + NumActiveExtensions int + ExtensionNames string + InvokeCompletionTimeNs int64 + InvokeReceivedTime int64 + LogsAPIMetrics TelemetrySubscriptionMetrics + ResponseMetrics ResponseMetrics + InvokeMetrics InvokeMetrics +} + +// InvokeFailure is the failure response to invoke phase end +type InvokeFailure struct { + ResetReceived bool // indicates if failure happened due to a reset received + RequestReset bool // indicates if reset must be requested after the failure + ErrorType fatalerror.ErrorType + ErrorMessage error + RuntimeRelease string // value of the User Agent HTTP header provided by runtime + NumActiveExtensions int + InvokeReceivedTime int64 + LogsAPIMetrics TelemetrySubscriptionMetrics + ResponseMetrics ResponseMetrics + InvokeMetrics InvokeMetrics + ExtensionNames string + DefaultErrorResponse *ErrorResponse // error resp constructed by platform during fn errors +} + +// ResetSuccess is the success response to reset request +type ResetSuccess struct { + ExtensionsResetMs int64 + ErrorType fatalerror.ErrorType + ResponseMetrics ResponseMetrics +} + +// ResetFailure is the failure response to reset request +type ResetFailure struct { + ExtensionsResetMs int64 + ErrorType fatalerror.ErrorType + ResponseMetrics ResponseMetrics +} + +// ShutdownSuccess is the response to a shutdown request +type ShutdownSuccess struct { + ErrorType fatalerror.ErrorType +} + +// SandboxInfoFromInit captures data from init request that +// is required during invoke (e.g. for suppressed init) +type SandboxInfoFromInit struct { + EnvironmentVariables EnvironmentVariables // contains agent env vars (creds, customer, platform) + SandboxType SandboxType // indicating Pre-Warmed, On-Demand etc + RuntimeBootstrap Bootstrap // contains the runtime bootstrap binary path, Cwd, Args, Env, Cmd +} + +// RapidContext expose methods for functionality of the Rapid Core library +type RapidContext interface { + HandleInit(i *Init, started chan<- InitStarted, success chan<- InitSuccess, failure chan<- InitFailure) + HandleInvoke(i *Invoke, sbMetadata SandboxInfoFromInit) (InvokeSuccess, *InvokeFailure) + HandleReset(reset *Reset, invokeReceivedTime int64, InvokeResponseMetrics *InvokeResponseMetrics) (ResetSuccess, *ResetFailure) + HandleShutdown(shutdown *Shutdown) ShutdownSuccess + HandleRestore(restore *Restore) error + Clear() +} + +// SandboxContext represents the sandbox lifecycle context +type SandboxContext interface { + Init(i *Init, timeoutMs int64) (InitStarted, InitContext) + Reset(reset *Reset) (ResetSuccess, *ResetFailure) + Shutdown(shutdown *Shutdown) ShutdownSuccess + Restore(restore *Restore) error + + // TODO: refactor this + // invokeReceivedTime and InvokeResponseMetrics are needed to compute the runtimeDone metrics + // in case of a Reset during an invoke (reset.reason=failure or reset.reason=timeout). + // Ideally: + // - the InvokeContext will have a Reset method to deal with Reset during an invoke and will hold invokeReceivedTime and InvokeResponseMetrics + // - the SandboxContext will have its own Reset/Spindown method + SetInvokeReceivedTime(invokeReceivedTime int64) + SetInvokeResponseMetrics(metrics *InvokeResponseMetrics) +} + +// InitContext represents the lifecycle of a sandbox initialization +type InitContext interface { + Wait() (InitSuccess, *InitFailure) + Reserve() InvokeContext +} + +// InvokeContext represents the lifecycle of a sandbox reservation +type InvokeContext interface { + SendRequest(i *Invoke) + Wait() (InvokeSuccess, *InvokeFailure) +} + +// Restored message is sent to Slicer to inform Runtime Restore Hook execution was successful +type Restored struct { } diff --git a/lambda/logging/doc.go b/lambda/logging/doc.go index 92637a1..a1f7e95 100644 --- a/lambda/logging/doc.go +++ b/lambda/logging/doc.go @@ -2,24 +2,13 @@ // SPDX-License-Identifier: Apache-2.0 /* - RAPID emits or proxies the following sources of logging: -1. Internal logs: RAPID's own application logs into stderr for operational use, visible only internally -2. Function stream-based logs: Runtime's stdout and stderr, read as newline separated lines -3. Function message-based logs: Stock runtimes communicate using a custom TLV protocol over a Unix pipe -4. Extension stream-based logs: Extension's stdout and stderr, read as newline separated lines -5. Platform logs: Logs that RAPID generates, but is visible in customer's logs. - - -It has the following log sinks, which may further be egressed to other sinks (e.g. CloudWatch) by external telemetry agents: - -1. Internal Log File (stderr): stderr is redirected to a file specified by Sandbox Factory via env-vars, and accessible via StreamQuery -2. Stdout: stream-based function logs are output to RAPID's stdout process, and read by a telemetry agent -3. Telemetry API MSG-verb events: function messages-based logs are written using GirP protocol into the console socket specified by Sandbox Factory env-vars -4. Telemetry API LOGX-verb events: extension stream-based logs are written using GirP protocol into the console socket specified by Sandbox Factory env-vars -5. Telemetry API LOGP-verb events: platform logs are written using GirP protocol into the console socket specified by Sandbox Factory env-vars -6. Tail logs: a truncated version of function stream-based and message-based logs are written along with the invocation response to the frontend when 'debug logging' is enabled - + 1. Internal logs: RAPID's own application logs into stderr for operational use, visible only internally + 2. Function stream-based logs: Runtime's stdout and stderr, read as newline separated lines + 3. Function message-based logs: Stock runtimes communicate using a custom TLV protocol over a Unix pipe + 4. Extension stream-based logs: Extension's stdout and stderr, read as newline separated lines + 5. Platform logs: Logs that RAPID generates, but is visible either in customer's logs or via Logs API + (e.g. EXTENSION, RUNTIME, RUNTIMEDONE, IMAGE) */ package logging diff --git a/lambda/logging/internal_log_test.go b/lambda/logging/internal_log_test.go index b94ac88..3ec537f 100644 --- a/lambda/logging/internal_log_test.go +++ b/lambda/logging/internal_log_test.go @@ -8,7 +8,7 @@ import ( "fmt" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - "io/ioutil" + "io" "log" "testing" ) @@ -67,14 +67,14 @@ func TestInternalFormatter(t *testing.T) { } func BenchmarkLogPrint(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) for n := 0; n < b.N; n++ { log.Print(1, "two", true) } } func BenchmarkLogrusPrint(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) for n := 0; n < b.N; n++ { logrus.Print(1, "two", true) } @@ -83,21 +83,21 @@ func BenchmarkLogrusPrint(b *testing.B) { func BenchmarkLogrusPrintInternalFormatter(b *testing.B) { var l = logrus.New() l.SetFormatter(&InternalFormatter{}) - l.SetOutput(ioutil.Discard) + l.SetOutput(io.Discard) for n := 0; n < b.N; n++ { l.Print(1, "two", true) } } func BenchmarkLogPrintf(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) for n := 0; n < b.N; n++ { log.Printf("field:%v,field:%v,field:%v", 1, "two", true) } } func BenchmarkLogrusPrintf(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) for n := 0; n < b.N; n++ { logrus.Printf("field:%v,field:%v,field:%v", 1, "two", true) } @@ -106,14 +106,14 @@ func BenchmarkLogrusPrintf(b *testing.B) { func BenchmarkLogrusPrintfInternalFormatter(b *testing.B) { var l = logrus.New() l.SetFormatter(&InternalFormatter{}) - l.SetOutput(ioutil.Discard) + l.SetOutput(io.Discard) for n := 0; n < b.N; n++ { l.Printf("field:%v,field:%v,field:%v", 1, "two", true) } } func BenchmarkLogrusDebugLogLevelDisabled(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) logrus.SetLevel(logrus.InfoLevel) for n := 0; n < b.N; n++ { logrus.Debug(1, "two", true) @@ -122,7 +122,7 @@ func BenchmarkLogrusDebugLogLevelDisabled(b *testing.B) { func BenchmarkLogrusDebugLogLevelDisabledInternalFormatter(b *testing.B) { var l = logrus.New() - l.SetOutput(ioutil.Discard) + l.SetOutput(io.Discard) l.SetLevel(logrus.InfoLevel) for n := 0; n < b.N; n++ { l.Debug(1, "two", true) @@ -130,7 +130,7 @@ func BenchmarkLogrusDebugLogLevelDisabledInternalFormatter(b *testing.B) { } func BenchmarkLogrusDebugLogLevelEnabled(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) logrus.SetLevel(logrus.DebugLevel) for n := 0; n < b.N; n++ { logrus.Debug(1, "two", true) @@ -140,7 +140,7 @@ func BenchmarkLogrusDebugLogLevelEnabled(b *testing.B) { func BenchmarkLogrusDebugLogLevelEnabledInternalFormatter(b *testing.B) { var l = logrus.New() l.SetFormatter(&InternalFormatter{}) - l.SetOutput(ioutil.Discard) + l.SetOutput(io.Discard) l.SetLevel(logrus.DebugLevel) for n := 0; n < b.N; n++ { l.Debug(1, "two", true) @@ -148,7 +148,7 @@ func BenchmarkLogrusDebugLogLevelEnabledInternalFormatter(b *testing.B) { } func BenchmarkLogrusDebugWithFieldLogLevelDisabled(b *testing.B) { - SetOutput(ioutil.Discard) + SetOutput(io.Discard) logrus.SetLevel(logrus.InfoLevel) for n := 0; n < b.N; n++ { logrus.WithField("field", "value").Debug(1, "two", true) @@ -158,7 +158,7 @@ func BenchmarkLogrusDebugWithFieldLogLevelDisabled(b *testing.B) { func BenchmarkLogrusDebugWithFieldLogLevelDisabledInternalFormatter(b *testing.B) { var l = logrus.New() l.SetFormatter(&InternalFormatter{}) - l.SetOutput(ioutil.Discard) + l.SetOutput(io.Discard) l.SetLevel(logrus.InfoLevel) for n := 0; n < b.N; n++ { l.WithField("field", "value").Debug(1, "two", true) diff --git a/lambda/logging/platform_log.go b/lambda/logging/platform_log.go deleted file mode 100644 index 5154f93..0000000 --- a/lambda/logging/platform_log.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package logging - -import ( - "fmt" - "io" - "log" - "strings" -) - -// TODO PlatformLogger interface has this LogExtensionInitEvent() method so it's easier to assert against it in standalone tests; -// TODO However, this makes interface harder to maintain (you are supposed to add new method to PlatformLogger for each event type) -// TODO We need to remove those methods and make PlatformLogger just a log.Logger interface - -// PlatformLogger is a logger that logs platform lines to customers' logs -type PlatformLogger interface { - Printf(fmt string, args ...interface{}) - LogExtensionInitEvent(agentName, state, errorType string, subscriptions []string) -} - -// FormattedPlatformLogger formats and logs platform lines to customers' logs via Telemetry API -type FormattedPlatformLogger struct { - logger *log.Logger -} - -// NewPlatformLogger is a logger for logging Platform log lines into customers' logs -func NewPlatformLogger(output, tailLogWriter io.Writer) *FormattedPlatformLogger { - prefix, flags := "", 0 - return &FormattedPlatformLogger{ - logger: log.New(io.MultiWriter(output, tailLogWriter), prefix, flags), - } -} - -// LogExtensionInitEvent formats and logs a line containing agent info -func (l *FormattedPlatformLogger) LogExtensionInitEvent(agentName, state, errorType string, subscriptions []string) { - format := "EXTENSION\tName: %s\tState: %s\tEvents: [%s]" - line := fmt.Sprintf(format, agentName, state, strings.Join(subscriptions, ",")) - if len(errorType) > 0 { - line += fmt.Sprintf("\tError Type: %s", errorType) - } - l.logger.Println(line) -} - -func (l *FormattedPlatformLogger) Printf(fmt string, args ...interface{}) { - fmt += "\n" // we append newline to the logline because that's how they are separated on recepient - l.logger.Printf(fmt, args...) -} - -func SupernovaInvalidTaskConfigRepr(err error) func(error) string { - return func(unused error) string { - return fmt.Sprintf("IMAGE\tInvalid task config: %s", err) - } -} - -func SupernovaLaunchErrorRepr(entrypoint []string, cmd []string, workingDir string) func(error) string { - return func(err error) string { - return fmt.Sprintf("IMAGE\tLaunch error: %s\tEntrypoint: [%s]\tCmd: [%s]\tWorkingDir: [%s]", - err, - strings.Join(entrypoint, ","), - strings.Join(cmd, ","), - workingDir) - } -} diff --git a/lambda/logging/platform_log_test.go b/lambda/logging/platform_log_test.go deleted file mode 100644 index 8b01778..0000000 --- a/lambda/logging/platform_log_test.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package logging - -import ( - "bytes" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestPlatformLogExtensionLine(t *testing.T) { - var buf bytes.Buffer - var tailLogBuf bytes.Buffer - logger := NewPlatformLogger(&buf, &tailLogBuf) - - logger.LogExtensionInitEvent("agentName", "Registered", "", []string{"INVOKE", "SHUTDOWN"}) - require.Equal(t, "EXTENSION\tName: agentName\tState: Registered\tEvents: [INVOKE,SHUTDOWN]\n", buf.String()) - require.Equal(t, "EXTENSION\tName: agentName\tState: Registered\tEvents: [INVOKE,SHUTDOWN]\n", tailLogBuf.String()) -} - -func TestPlatformLogExtensionLineWithError(t *testing.T) { - var buf bytes.Buffer - var tailLogBuf bytes.Buffer - logger := NewPlatformLogger(&buf, &tailLogBuf) - - errorType := "Extension.FooBar" - logger.LogExtensionInitEvent("agentName", "Registered", errorType, []string{"INVOKE", "SHUTDOWN"}) - require.Equal(t, "EXTENSION\tName: agentName\tState: Registered\tEvents: [INVOKE,SHUTDOWN]\tError Type: "+errorType+"\n", buf.String()) - require.Equal(t, "EXTENSION\tName: agentName\tState: Registered\tEvents: [INVOKE,SHUTDOWN]\tError Type: "+errorType+"\n", tailLogBuf.String()) -} - -func TestPlatformLogPrintf(t *testing.T) { - var buf bytes.Buffer - var tailLogBuf bytes.Buffer - logger := NewPlatformLogger(&buf, &tailLogBuf) - - logger.Printf("bebe %s %d", "as", 12) - require.Equal(t, "bebe as 12\n", buf.String()) - require.Equal(t, "bebe as 12\n", tailLogBuf.String()) -} diff --git a/lambda/logging/taillog.go b/lambda/logging/taillog.go deleted file mode 100644 index 9fe5352..0000000 --- a/lambda/logging/taillog.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package logging - -import ( - "io" - "sync" -) - -// TailLogWriter writes tail/debug log to provided io.Writer -type TailLogWriter struct { - out io.Writer - enabled bool - mutex sync.Mutex -} - -// Enable enables log writer. -func (lw *TailLogWriter) Enable() { - lw.mutex.Lock() - defer lw.mutex.Unlock() - - lw.enabled = true -} - -// Disable disables log writer. -func (lw *TailLogWriter) Disable() { - lw.mutex.Lock() - defer lw.mutex.Unlock() - - lw.enabled = false -} - -// Writer wraps the basic io.Write method -func (lw *TailLogWriter) Write(p []byte) (n int, err error) { - lw.mutex.Lock() - defer lw.mutex.Unlock() - - if lw.enabled { - return lw.out.Write(p) - } - // Else returns a successful write so that MultiWriter won't stop - return len(p), nil -} - -// NewTailLogWriter returns a new invoke tail log writer, default output is discarded until output is configured. -func NewTailLogWriter(w io.Writer) *TailLogWriter { - return &TailLogWriter{ - out: w, - enabled: false, - } -} diff --git a/lambda/logging/taillog_test.go b/lambda/logging/taillog_test.go deleted file mode 100644 index 2bc444c..0000000 --- a/lambda/logging/taillog_test.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package logging - -import ( - "bytes" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestDisableDebugLog(t *testing.T) { - buf := new(bytes.Buffer) - tailLogWriter := NewTailLogWriter(buf) - tailLogWriter.Disable() - - tailLogWriter.Write([]byte("hello_world")) - assert.Len(t, buf.String(), 0) -} - -func TestEnableDebugLog(t *testing.T) { - buf := new(bytes.Buffer) - tailLogWriter := NewTailLogWriter(buf) - tailLogWriter.Enable() - - tailLogWriter.Write([]byte("hello_world")) - assert.Equal(t, "hello_world", buf.String()) -} diff --git a/lambda/rapi/handler/agentiniterror_test.go b/lambda/rapi/handler/agentiniterror_test.go index 571031e..50b9143 100644 --- a/lambda/rapi/handler/agentiniterror_test.go +++ b/lambda/rapi/handler/agentiniterror_test.go @@ -6,7 +6,7 @@ package handler import ( "context" "encoding/json" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" @@ -60,7 +60,7 @@ func TestAgentInitErrorMissingErrorHeader(t *testing.T) { assert.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, errAgentMissingHeader, errorResponse.ErrorType) } @@ -77,7 +77,7 @@ func TestAgentInitErrorUnknownAgent(t *testing.T) { assert.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, errAgentIdentifierUnknown, errorResponse.ErrorType) } @@ -97,7 +97,7 @@ func TestAgentInitErrorAgentInvalidState(t *testing.T) { assert.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, errAgentInvalidState, errorResponse.ErrorType) } @@ -118,7 +118,7 @@ func TestAgentInitErrorRequestAccepted(t *testing.T) { assert.Equal(t, http.StatusAccepted, responseRecorder.Code) var response model.StatusResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Equal(t, "OK", response.Status) diff --git a/lambda/rapi/handler/agentnext_test.go b/lambda/rapi/handler/agentnext_test.go index ef14e49..003c4b6 100644 --- a/lambda/rapi/handler/agentnext_test.go +++ b/lambda/rapi/handler/agentnext_test.go @@ -7,7 +7,7 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "net/http/httptest" "strings" @@ -52,7 +52,7 @@ func TestRenderAgentInvokeUnknownAgent(t *testing.T) { assert.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, http.StatusForbidden, responseRecorder.Code) assert.Equal(t, errAgentIdentifierUnknown, errorResponse.ErrorType) @@ -75,7 +75,7 @@ func TestRenderAgentInvokeInvalidAgentState(t *testing.T) { assert.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, http.StatusForbidden, responseRecorder.Code) assert.Equal(t, errAgentInvalidState, errorResponse.ErrorType) @@ -118,7 +118,7 @@ func TestRenderAgentInvokeNextHappy(t *testing.T) { handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusOK, responseRecorder.Code) var response model.AgentInvokeEvent - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Equal(t, agent.RunningState, agent.GetState()) @@ -167,7 +167,7 @@ func TestRenderAgentInternalInvokeNextHappy(t *testing.T) { handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusOK, responseRecorder.Code) var response model.AgentInvokeEvent - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Equal(t, agent.RunningState, agent.GetState()) @@ -212,7 +212,7 @@ func TestRenderAgentInternalShutdownEvent(t *testing.T) { handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusOK, responseRecorder.Code) var response model.AgentShutdownEvent - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Equal(t, agent.RunningState, agent.GetState()) @@ -254,7 +254,7 @@ func TestRenderAgentExternalShutdownEvent(t *testing.T) { handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusOK, responseRecorder.Code) var response model.AgentShutdownEvent - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Equal(t, agent.RunningState, agent.GetState()) @@ -297,7 +297,7 @@ func TestRenderAgentInvokeNextHappyEmptyTraceID(t *testing.T) { handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusOK, responseRecorder.Code) var response model.AgentInvokeEvent - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &response) assert.Nil(t, response.Tracing) diff --git a/lambda/rapi/handler/agentregister.go b/lambda/rapi/handler/agentregister.go index 776ac28..8882965 100644 --- a/lambda/rapi/handler/agentregister.go +++ b/lambda/rapi/handler/agentregister.go @@ -6,10 +6,9 @@ package handler import ( "encoding/json" "errors" - "io/ioutil" + "io" "net/http" - "github.com/go-chi/render" log "github.com/sirupsen/logrus" "go.amzn.com/lambda/core" "go.amzn.com/lambda/rapi/model" @@ -26,7 +25,7 @@ type RegisterRequest struct { } func parseRegister(request *http.Request) (*RegisterRequest, error) { - body, err := ioutil.ReadAll(request.Body) + body, err := io.ReadAll(request.Body) if err != nil { return nil, err } @@ -70,7 +69,6 @@ func (h *agentRegisterHandler) ServeHTTP(writer http.ResponseWriter, request *ht } func (h *agentRegisterHandler) renderResponse(agentID string, writer http.ResponseWriter, request *http.Request) { - render.Status(request, http.StatusOK) writer.Header().Set(LambdaAgentIdentifier, agentID) metadata := h.registrationService.GetFunctionMetadata() @@ -81,7 +79,10 @@ func (h *agentRegisterHandler) renderResponse(agentID string, writer http.Respon Handler: metadata.Handler, } - render.JSON(writer, request, resp) + if err := rendering.RenderJSON(http.StatusOK, writer, request, resp); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(writer, err.Error(), http.StatusInternalServerError) + } } func (h *agentRegisterHandler) registerExternalAgent(agent *core.ExternalAgent, registerRequest *RegisterRequest, writer http.ResponseWriter, request *http.Request) { diff --git a/lambda/rapi/handler/agentregister_test.go b/lambda/rapi/handler/agentregister_test.go index 185f249..35456ee 100644 --- a/lambda/rapi/handler/agentregister_test.go +++ b/lambda/rapi/handler/agentregister_test.go @@ -7,7 +7,6 @@ import ( "bytes" "encoding/json" "io" - "io/ioutil" "net/http" "net/http/httptest" "testing" @@ -41,7 +40,7 @@ func TestRenderAgentRegisterInvalidAgentName(t *testing.T) { require.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) require.Equal(t, http.StatusForbidden, responseRecorder.Code) require.Equal(t, errAgentNameInvalid, errorResponse.ErrorType) @@ -63,7 +62,7 @@ func TestRenderAgentRegisterRegistrationClosed(t *testing.T) { require.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) require.Equal(t, http.StatusForbidden, responseRecorder.Code) require.Equal(t, errAgentRegistrationClosed, errorResponse.ErrorType) @@ -88,7 +87,7 @@ func TestRenderAgentRegisterInvalidAgentState(t *testing.T) { require.Equal(t, http.StatusForbidden, responseRecorder.Code) var errorResponse model.ErrorResponse - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) require.Equal(t, http.StatusForbidden, responseRecorder.Code) require.Equal(t, errAgentInvalidState, errorResponse.ErrorType) @@ -311,7 +310,7 @@ func TestRenderAgentResponse(t *testing.T) { require.Equal(t, http.StatusOK, responseRecorder.Code) registerResponse := ExtensionRegisterResponseWithConfig{} - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, ®isterResponse) assert.Equal(t, tt.expectedRegistrationResponse.FunctionName, registerResponse.FunctionName) assert.Equal(t, tt.expectedRegistrationResponse.FunctionVersion, registerResponse.FunctionVersion) diff --git a/lambda/rapi/handler/constants.go b/lambda/rapi/handler/constants.go index 01553f3..5912d71 100644 --- a/lambda/rapi/handler/constants.go +++ b/lambda/rapi/handler/constants.go @@ -20,7 +20,6 @@ const ( errAgentMissingHeader string = "Extension.MissingHeader" errTooManyExtensions string = "Extension.TooManyExtensions" errInvalidEventType string = "Extension.InvalidEventType" - errLogsSubscriptionClosed string = "Logs.SubscriptionClosed" errInvalidRequestFormat string = "InvalidRequestFormat" StateTransitionFailedForExtensionMessageFormat string = "State transition from %s to %s failed for extension %s. Error: %s" diff --git a/lambda/rapi/handler/credentials_test.go b/lambda/rapi/handler/credentials_test.go index fa4a2bd..d5a1090 100644 --- a/lambda/rapi/handler/credentials_test.go +++ b/lambda/rapi/handler/credentials_test.go @@ -21,13 +21,11 @@ const InitCachingAwsKey = "sampleAwsKey" const InitCachingAwsSecret = "sampleAwsSecret" const InitCachingAwsSessionToken = "sampleAwsSessionToken" -func getRequestContext(isServiceBlocked bool) (http.Handler, *http.Request, *httptest.ResponseRecorder) { +func getRequestContext() (http.Handler, *http.Request, *httptest.ResponseRecorder) { flowTest := testdata.NewFlowTest() - if isServiceBlocked { - flowTest.ConfigureForBlockedInitCaching(InitCachingToken, InitCachingAwsKey, InitCachingAwsSecret, InitCachingAwsSessionToken) - } else { - flowTest.ConfigureForInitCaching(InitCachingToken, InitCachingAwsKey, InitCachingAwsSecret, InitCachingAwsSessionToken) - } + + flowTest.ConfigureForInitCaching(InitCachingToken, InitCachingAwsKey, InitCachingAwsSecret, InitCachingAwsSessionToken) + handler := NewCredentialsHandler(flowTest.CredentialsService) responseRecorder := httptest.NewRecorder() appCtx := flowTest.AppCtx @@ -38,14 +36,14 @@ func getRequestContext(isServiceBlocked bool) (http.Handler, *http.Request, *htt } func TestEmptyAuthorizationHeader(t *testing.T) { - handler, request, responseRecorder := getRequestContext(false) + handler, request, responseRecorder := getRequestContext() handler.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusNotFound, responseRecorder.Code) } func TestArbitraryAuthorizationHeader(t *testing.T) { - handler, request, responseRecorder := getRequestContext(false) + handler, request, responseRecorder := getRequestContext() request.Header.Set("Authorization", "randomAuthToken") handler.ServeHTTP(responseRecorder, request) @@ -53,7 +51,7 @@ func TestArbitraryAuthorizationHeader(t *testing.T) { } func TestSuccessfulGet(t *testing.T) { - handler, request, responseRecorder := getRequestContext(false) + handler, request, responseRecorder := getRequestContext() request.Header.Set("Authorization", InitCachingToken) handler.ServeHTTP(responseRecorder, request) @@ -67,25 +65,6 @@ func TestSuccessfulGet(t *testing.T) { expirationTime, err := time.Parse(time.RFC3339, responseMap["Expiration"]) assert.NoError(t, err) durationUntilExpiration := time.Until(expirationTime) - assert.True(t, durationUntilExpiration.Minutes() <= 16 && durationUntilExpiration.Minutes() > 15 && durationUntilExpiration.Hours() < 1) + assert.True(t, durationUntilExpiration.Minutes() <= 30 && durationUntilExpiration.Minutes() > 29 && durationUntilExpiration.Hours() < 1) log.Println(responseRecorder.Body.String()) } - -func TestBlockedGet(t *testing.T) { - handler, request, responseRecorder := getRequestContext(true) - request.Header.Set("Authorization", InitCachingToken) - - timeout := time.After(1 * time.Second) - done := make(chan bool) - - go func() { - handler.ServeHTTP(responseRecorder, request) - done <- true - }() - - select { - case <-done: - t.Fatal("Endpoint should be blocked!") - case <-timeout: - } -} diff --git a/lambda/rapi/handler/initerror.go b/lambda/rapi/handler/initerror.go index 4015a11..d28e2d4 100644 --- a/lambda/rapi/handler/initerror.go +++ b/lambda/rapi/handler/initerror.go @@ -5,11 +5,12 @@ package handler import ( "encoding/json" - "io/ioutil" + "io" "net/http" "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/telemetry" "go.amzn.com/lambda/core" "go.amzn.com/lambda/rapi/rendering" @@ -19,6 +20,7 @@ import ( type initErrorHandler struct { registrationService core.RegistrationService + eventsAPI telemetry.EventsAPI } func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { @@ -30,6 +32,10 @@ func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.R } runtime := h.registrationService.GetRuntime() + + // the previousStateName is needed to define if the init/error is called for INIT or RESTORE + previousStateName := runtime.GetState().Name() + if err := runtime.InitError(); err != nil { log.Warn(err) rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, @@ -39,18 +45,24 @@ func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.R errorType := request.Header.Get("Lambda-Runtime-Function-Error-Type") - errorBody, err := ioutil.ReadAll(request.Body) + errorBody, err := io.ReadAll(request.Body) if err != nil { log.WithError(err).Warn("Failed to read error body") } + if previousStateName == core.RuntimeRestoringStateName { + h.sendRestoreRuntimeDoneLogEvent() + } else { + h.sendInitRuntimeDoneLogEvent(appCtx) + } + response := &interop.ErrorResponse{ ErrorType: errorType, Payload: errorBody, ContentType: determineJSONContentType(errorBody), } - if err := server.SendErrorResponse(server.GetCurrentInvokeID(), response); err != nil { + if err := server.SendInitErrorResponse(server.GetCurrentInvokeID(), response); err != nil { rendering.RenderInteropError(writer, request, err) return } @@ -62,9 +74,10 @@ func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.R // NewInitErrorHandler returns a new instance of http handler // for serving /runtime/init/error. -func NewInitErrorHandler(registrationService core.RegistrationService) http.Handler { +func NewInitErrorHandler(registrationService core.RegistrationService, eventsAPI telemetry.EventsAPI) http.Handler { return &initErrorHandler{ registrationService: registrationService, + eventsAPI: eventsAPI, } } @@ -74,3 +87,24 @@ func determineJSONContentType(body []byte) string { } return "application/octet-stream" } + +func (h *initErrorHandler) sendInitRuntimeDoneLogEvent(appCtx appctx.ApplicationContext) { + // ToDo: Convert this to an enum for the whole package to increase readability. + initCachingEnabled := appctx.LoadInitType(appCtx) == appctx.InitCaching + + initSource := interop.InferTelemetryInitSource(initCachingEnabled, appctx.LoadSandboxType(appCtx)) + runtimeDoneData := &telemetry.InitRuntimeDoneData{ + InitSource: initSource, + Status: telemetry.RuntimeDoneFailure, + } + + if err := h.eventsAPI.SendInitRuntimeDone(runtimeDoneData); err != nil { + log.Errorf("Failed to send INITRD: %s", err) + } +} + +func (h *initErrorHandler) sendRestoreRuntimeDoneLogEvent() { + if err := h.eventsAPI.SendRestoreRuntimeDone(telemetry.RuntimeDoneFailure); err != nil { + log.Errorf("Failed to send RESTRD: %s", err) + } +} diff --git a/lambda/rapi/handler/initerror_test.go b/lambda/rapi/handler/initerror_test.go index c2d3d89..c9a5a83 100644 --- a/lambda/rapi/handler/initerror_test.go +++ b/lambda/rapi/handler/initerror_test.go @@ -27,7 +27,7 @@ func runTestInitErrorHandler(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - handler := NewInitErrorHandler(flowTest.RegistrationService) + handler := NewInitErrorHandler(flowTest.RegistrationService, flowTest.EventsAPI) responseRecorder := httptest.NewRecorder() appCtx := flowTest.AppCtx @@ -49,7 +49,7 @@ func runTestInitErrorHandler(t *testing.T) { require.Equal(t, http.StatusAccepted, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", responseRecorder.Code, http.StatusAccepted) require.JSONEq(t, fmt.Sprintf("{\"status\":\"%s\"}\n", "OK"), responseRecorder.Body.String()) - require.Equal(t, "application/json; charset=utf-8", responseRecorder.Header().Get("Content-Type")) + require.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) // Validate init error persisted in the application context. errorResponse := flowTest.InteropServer.ErrorResponse diff --git a/lambda/rapi/handler/invocationerror.go b/lambda/rapi/handler/invocationerror.go index d60b5d6..170c0cb 100644 --- a/lambda/rapi/handler/invocationerror.go +++ b/lambda/rapi/handler/invocationerror.go @@ -6,7 +6,7 @@ package handler import ( "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "go.amzn.com/lambda/interop" @@ -25,6 +25,11 @@ const errorWithCauseContentType = "application/vnd.aws.lambda.error.cause+json" const xrayErrorCauseHeaderName = "Lambda-Runtime-Function-XRay-Error-Cause" const invalidErrorBodyMessage = "Invalid error body" +const ( + contentTypeHeader = "Content-Type" + functionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" +) + type invocationErrorHandler struct { registrationService core.RegistrationService } @@ -52,7 +57,7 @@ func (h *invocationErrorHandler) ServeHTTP(writer http.ResponseWriter, request * var contentType string var err error - switch request.Header.Get("Content-Type") { + switch request.Header.Get(contentTypeHeader) { case errorWithCauseContentType: errorBody, errorCause, err = h.getErrorBodyForErrorCauseContentType(request) contentType = "application/json" @@ -62,18 +67,20 @@ func (h *invocationErrorHandler) ServeHTTP(writer http.ResponseWriter, request * default: errorBody, err = h.getErrorBody(request) errorCause = h.getValidatedErrorCause(request.Header) - contentType = request.Header.Get("Content-Type") + contentType = request.Header.Get(contentTypeHeader) } + functionResponseMode := request.Header.Get(functionResponseModeHeader) if err != nil { log.WithError(err).Warn("Failed to parse error body") } response := &interop.ErrorResponse{ - ErrorType: errorType, - Payload: errorBody, - ErrorCause: errorCause, - ContentType: contentType, + ErrorType: errorType, + Payload: errorBody, + ErrorCause: errorCause, + ContentType: contentType, + FunctionResponseMode: functionResponseMode, } if err := server.SendErrorResponse(chi.URLParam(request, "awsrequestid"), response); err != nil { @@ -95,7 +102,7 @@ func (h *invocationErrorHandler) getErrorType(headers http.Header) string { } func (h *invocationErrorHandler) getErrorBody(request *http.Request) ([]byte, error) { - errorBody, err := ioutil.ReadAll(request.Body) + errorBody, err := io.ReadAll(request.Body) if err != nil { return nil, fmt.Errorf("error reading request body: %s", err) } @@ -120,7 +127,7 @@ func (h *invocationErrorHandler) getValidatedErrorCause(headers http.Header) jso } func (h *invocationErrorHandler) getErrorBodyForErrorCauseContentType(request *http.Request) ([]byte, json.RawMessage, error) { - errorBody, err := ioutil.ReadAll(request.Body) + errorBody, err := io.ReadAll(request.Body) if err != nil { return nil, nil, fmt.Errorf("error reading request body: %s", err) } diff --git a/lambda/rapi/handler/invocationerror_test.go b/lambda/rapi/handler/invocationerror_test.go index 6defa14..2f177fe 100644 --- a/lambda/rapi/handler/invocationerror_test.go +++ b/lambda/rapi/handler/invocationerror_test.go @@ -77,7 +77,7 @@ func runTestInvocationErrorHandler(t *testing.T) { assert.Equal(t, http.StatusAccepted, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", responseRecorder.Code, http.StatusAccepted) assert.JSONEq(t, fmt.Sprintf("{\"status\":\"%s\"}\n", "OK"), responseRecorder.Body.String()) - assert.Equal(t, "application/json; charset=utf-8", responseRecorder.Header().Get("Content-Type")) + assert.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) errorResponse := flowTest.InteropServer.ErrorResponse assert.NotNil(t, errorResponse) @@ -268,7 +268,8 @@ func TestInvocationResponsePayloadIsDefaultErrorMessageWhenRequestParsingFailsFo invoke := &interop.Invoke{TraceID: "Root=TraceID;Parent=ParentID;Sampled=1", ID: "InvokeID"} request := httptest.NewRequest("POST", "/", bytes.NewReader(invalidRequestBody)) request = addInvocationID(request, invoke.ID) - request.Header.Set("Content-Type", errorWithCauseContentType) + request.Header.Set(contentTypeHeader, errorWithCauseContentType) + request.Header.Set(functionResponseModeHeader, "function-response-mode") // Corresponding invoke must be placed into appCtx. flowTest.ConfigureForInvoke(context.Background(), invoke) @@ -280,6 +281,7 @@ func TestInvocationResponsePayloadIsDefaultErrorMessageWhenRequestParsingFailsFo assert.NotNil(t, errorResponse) assert.Nil(t, flowTest.InteropServer.Response) assert.Equal(t, "application/octet-stream", flowTest.InteropServer.ResponseContentType) + assert.Equal(t, "function-response-mode", flowTest.InteropServer.FunctionResponseMode) invokeResponsePayload := errorResponse.Payload diff --git a/lambda/rapi/handler/invocationnext_test.go b/lambda/rapi/handler/invocationnext_test.go index beebd97..5bddb86 100644 --- a/lambda/rapi/handler/invocationnext_test.go +++ b/lambda/rapi/handler/invocationnext_test.go @@ -92,7 +92,7 @@ func TestRenderInvoke(t *testing.T) { assert.Equal(t, invokePayload, responseRecorder.Body.String()) } -//Cgo calls removed due to crashes while spawning threads under memory pressure. +// Cgo calls removed due to crashes while spawning threads under memory pressure. func TestRenderInvokeDoesNotCallCgo(t *testing.T) { cgoCallsBefore := runtime.NumCgoCall() TestRenderInvoke(t) diff --git a/lambda/rapi/handler/invocationresponse.go b/lambda/rapi/handler/invocationresponse.go index 7c15342..7e47d2e 100644 --- a/lambda/rapi/handler/invocationresponse.go +++ b/lambda/rapi/handler/invocationresponse.go @@ -15,7 +15,10 @@ import ( log "github.com/sirupsen/logrus" ) -const contentTypeOverrideHeaderName = "Content-Type" +const ( + StreamingFunctionResponseMode = "streaming" + ErrInvalidResponseModeHeader = "Runtime.InvalidResponseModeHeader" +) type invocationResponseHandler struct { registrationService core.RegistrationService @@ -39,9 +42,23 @@ func (h *invocationResponseHandler) ServeHTTP(writer http.ResponseWriter, reques invokeID := chi.URLParam(request, "awsrequestid") - responseContentType := request.Header.Get(contentTypeOverrideHeaderName) + headers := map[string]string{contentTypeHeader: request.Header.Get(contentTypeHeader)} + if functionResponseMode := request.Header.Get(functionResponseModeHeader); functionResponseMode != "" { + switch functionResponseMode { + case StreamingFunctionResponseMode: + headers[functionResponseModeHeader] = functionResponseMode + default: + errorResponse := &interop.ErrorResponse{ + ErrorType: ErrInvalidResponseModeHeader, + ContentType: request.Header.Get(contentTypeHeader), + } + _ = server.SendErrorResponse(chi.URLParam(request, "awsrequestid"), errorResponse) + rendering.RenderInvalidFunctionResponseMode(writer, request) + return + } + } - if err := server.SendResponse(invokeID, responseContentType, request.Body); err != nil { + if err := server.SendResponse(invokeID, headers, request.Body, request.Trailer, &interop.CancellableRequest{Request: request}); err != nil { switch err := err.(type) { case *interop.ErrorResponseTooLarge: if server.SendErrorResponse(invokeID, err.AsInteropError()) != nil { @@ -66,6 +83,19 @@ func (h *invocationResponseHandler) ServeHTTP(writer http.ResponseWriter, reques rendering.RenderRequestEntityTooLarge(writer, request) return + + case *interop.ErrTruncatedResponse: + if err := runtime.ResponseSent(); err != nil { + log.Panic(err) + } + + rendering.RenderTruncatedHTTPRequestError(writer, request) + return + + case *interop.ErrInternalPlatformError: + rendering.RenderInternalServerError(writer, request) + return + default: rendering.RenderInteropError(writer, request, err) return diff --git a/lambda/rapi/handler/invocationresponse_test.go b/lambda/rapi/handler/invocationresponse_test.go index e40a5bf..7c0b220 100644 --- a/lambda/rapi/handler/invocationresponse_test.go +++ b/lambda/rapi/handler/invocationresponse_test.go @@ -8,7 +8,7 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "net/http/httptest" "strings" @@ -55,7 +55,7 @@ func TestResponseTooLarge(t *testing.T) { responseRecorder.Code, http.StatusRequestEntityTooLarge) expectedAPIResponse := fmt.Sprintf("{\"errorMessage\":\"Exceeded maximum allowed payload size (%d bytes).\",\"errorType\":\"RequestEntityTooLarge\"}\n", interop.MaxPayloadSize) - body, err := ioutil.ReadAll(responseRecorder.Body) + body, err := io.ReadAll(responseRecorder.Body) assert.NoError(t, err) test.AssertJsonsEqual(t, []byte(expectedAPIResponse), body) @@ -98,7 +98,8 @@ func TestResponseAccepted(t *testing.T) { request := httptest.NewRequest("", "/", bytes.NewReader(responseBody)) request = addInvocationID(request, invoke.ID) - request.Header.Set(contentTypeOverrideHeaderName, "application/json") + request.Header.Set(contentTypeHeader, "application/json") + request.Header.Set(functionResponseModeHeader, "streaming") handler.ServeHTTP(responseRecorder, appctx.RequestWithAppCtx(request, appCtx)) // Assertions @@ -106,7 +107,7 @@ func TestResponseAccepted(t *testing.T) { responseRecorder.Code, http.StatusAccepted) expectedAPIResponse := "{\"status\":\"OK\"}\n" - body, err := ioutil.ReadAll(responseRecorder.Body) + body, err := io.ReadAll(responseRecorder.Body) assert.NoError(t, err) test.AssertJsonsEqual(t, []byte(expectedAPIResponse), body) @@ -114,6 +115,92 @@ func TestResponseAccepted(t *testing.T) { assert.NotNil(t, response) assert.Nil(t, flowTest.InteropServer.ErrorResponse) assert.Equal(t, "application/json", flowTest.InteropServer.ResponseContentType) + assert.Equal(t, "streaming", flowTest.InteropServer.FunctionResponseMode) assert.Equal(t, responseBody, response, "Persisted response data in app context must match the submitted.") } + +func TestResponseWithDifferentFunctionResponseModes(t *testing.T) { + type testCase struct { + providedFunctionResponseMode string + expectedFunctionResponseMode string + expectedAPIResponse string + expectedStatusCode int + expectedErrorResponse bool + } + testCases := []testCase{ + { + providedFunctionResponseMode: "", + expectedFunctionResponseMode: "", + expectedAPIResponse: "{\"status\":\"OK\"}\n", + expectedStatusCode: http.StatusAccepted, + expectedErrorResponse: false, + }, + { + providedFunctionResponseMode: "streaming", + expectedFunctionResponseMode: "streaming", + expectedAPIResponse: "{\"status\":\"OK\"}\n", + expectedStatusCode: http.StatusAccepted, + expectedErrorResponse: false, + }, + { + providedFunctionResponseMode: "invalid-mode", + expectedFunctionResponseMode: "", + expectedAPIResponse: "{\"errorMessage\":\"Invalid function response mode\", \"errorType\":\"InvalidFunctionResponseMode\"}\n", + expectedStatusCode: http.StatusBadRequest, + expectedErrorResponse: true, + }, + } + + for _, testCase := range testCases { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + flowTest.Runtime.Ready() + handler := NewInvocationResponseHandler(flowTest.RegistrationService) + responseRecorder := httptest.NewRecorder() + appCtx := flowTest.AppCtx + + // Invoke that we are sending response for must be placed into appCtx. + invoke := &interop.Invoke{ + ID: "InvocationID1", + InvokedFunctionArn: "arn::dummy1", + CognitoIdentityID: "CognitoidentityID1", + CognitoIdentityPoolID: "CognitoidentityPollID1", + DeadlineNs: "deadlinens1", + ClientContext: "clientcontext1", + ContentType: "application/json", + Payload: strings.NewReader(`{"message": "hello"}`), + } + + flowTest.ConfigureForInvoke(context.Background(), invoke) + + // Invocation response submitted by runtime. + var responseBody = []byte("{'foo': 'bar'}") + + request := httptest.NewRequest("", "/", bytes.NewReader(responseBody)) + request = addInvocationID(request, invoke.ID) + request.Header.Set(functionResponseModeHeader, testCase.providedFunctionResponseMode) + handler.ServeHTTP(responseRecorder, appctx.RequestWithAppCtx(request, appCtx)) + + // Assertions + assert.Equal(t, testCase.expectedStatusCode, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", + responseRecorder.Code, testCase.expectedStatusCode) + + body, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + test.AssertJsonsEqual(t, []byte(testCase.expectedAPIResponse), body) + + if testCase.expectedErrorResponse { + assert.NotNil(t, flowTest.InteropServer.ErrorResponse) + assert.Nil(t, flowTest.InteropServer.Response) + assert.Equal(t, "Runtime.InvalidResponseModeHeader", flowTest.InteropServer.ErrorResponse.ErrorType) + } else { + assert.NotNil(t, flowTest.InteropServer.Response) + assert.Nil(t, flowTest.InteropServer.ErrorResponse) + assert.Equal(t, responseBody, flowTest.InteropServer.Response, + "Persisted response data in app context must match the submitted.") + } + + assert.Equal(t, testCase.expectedFunctionResponseMode, flowTest.InteropServer.FunctionResponseMode) + } +} diff --git a/lambda/rapi/handler/restorenext.go b/lambda/rapi/handler/restorenext.go new file mode 100644 index 0000000..ecff059 --- /dev/null +++ b/lambda/rapi/handler/restorenext.go @@ -0,0 +1,40 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "net/http" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/core" + "go.amzn.com/lambda/rapi/rendering" +) + +type restoreNextHandler struct { + registrationService core.RegistrationService + renderingService *rendering.EventRenderingService +} + +func (h *restoreNextHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + runtime := h.registrationService.GetRuntime() + err := runtime.RestoreReady() + if err != nil { + log.Warn(err) + rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, runtime.GetState().Name(), core.RuntimeReadyStateName, err) + return + } + err = h.renderingService.RenderRuntimeEvent(writer, request) + if err != nil { + log.Error(err) + rendering.RenderInternalServerError(writer, request) + return + } +} + +func NewRestoreNextHandler(registrationService core.RegistrationService, renderingService *rendering.EventRenderingService) http.Handler { + return &restoreNextHandler{ + registrationService: registrationService, + renderingService: renderingService, + } +} diff --git a/lambda/rapi/handler/restorenext_test.go b/lambda/rapi/handler/restorenext_test.go new file mode 100644 index 0000000..7018d98 --- /dev/null +++ b/lambda/rapi/handler/restorenext_test.go @@ -0,0 +1,87 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "context" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/telemetry" + "go.amzn.com/lambda/testdata" +) + +func TestRenderRestoreNext(t *testing.T) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + handler := NewRestoreNextHandler(flowTest.RegistrationService, flowTest.RenderingService) + responseRecorder := httptest.NewRecorder() + appCtx := flowTest.AppCtx + + flowTest.ConfigureForRestore() + request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) + handler.ServeHTTP(responseRecorder, request) + + assert.Equal(t, http.StatusOK, responseRecorder.Code) +} + +func TestBrokenRenderer(t *testing.T) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + handler := NewRestoreNextHandler(flowTest.RegistrationService, flowTest.RenderingService) + responseRecorder := httptest.NewRecorder() + appCtx := flowTest.AppCtx + + flowTest.ConfigureForRestore() + flowTest.RenderingService.SetRenderer(&mockBrokenRenderer{}) + request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) + handler.ServeHTTP(responseRecorder, request) + + assert.Equal(t, http.StatusInternalServerError, responseRecorder.Code) + + assert.JSONEq(t, `{"errorMessage":"Internal Server Error","errorType":"InternalServerError"}`, responseRecorder.Body.String()) +} + +func TestRenderRestoreAfterInvoke(t *testing.T) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + handler := NewInvocationNextHandler(flowTest.RegistrationService, flowTest.RenderingService) + responseRecorder := httptest.NewRecorder() + appCtx := flowTest.AppCtx + + deadlineNs := 12345 + invokePayload := "Payload" + invoke := &interop.Invoke{ + TraceID: "Root=RootID;Parent=LambdaFrontend;Sampled=1", + ID: "ID", + InvokedFunctionArn: "InvokedFunctionArn", + CognitoIdentityID: "CognitoIdentityId1", + CognitoIdentityPoolID: "CognitoIdentityPoolId1", + ClientContext: "ClientContext", + DeadlineNs: strconv.Itoa(deadlineNs), + ContentType: "image/png", + Payload: strings.NewReader(invokePayload), + } + + ctx := telemetry.NewTraceContext(context.Background(), "RootID", "InvocationSubegmentID") + flowTest.ConfigureForInvoke(ctx, invoke) + + request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) + handler.ServeHTTP(responseRecorder, request) + + assert.Equal(t, http.StatusOK, responseRecorder.Code) + + restoreHandler := NewRestoreNextHandler(flowTest.RegistrationService, flowTest.RenderingService) + restoreRequest := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) + responseRecorder = httptest.NewRecorder() + restoreHandler.ServeHTTP(responseRecorder, restoreRequest) + + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) +} diff --git a/lambda/rapi/handler/runtimelogs.go b/lambda/rapi/handler/runtimelogs.go index 9b4e406..99941b0 100644 --- a/lambda/rapi/handler/runtimelogs.go +++ b/lambda/rapi/handler/runtimelogs.go @@ -7,7 +7,7 @@ import ( "bytes" "errors" "fmt" - "io/ioutil" + "io" "net/http" "go.amzn.com/lambda/core" @@ -20,8 +20,8 @@ import ( ) type runtimeLogsHandler struct { - registrationService core.RegistrationService - logsSubscriptionAPI telemetry.LogsSubscriptionAPI + registrationService core.RegistrationService + telemetrySubscription telemetry.SubscriptionAPI } func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { @@ -31,10 +31,10 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http switch err := err.(type) { case *ErrAgentIdentifierUnknown: rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown extension "+err.agentID.String()) - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeClientErr, 1) default: rendering.RenderInternalServerError(writer, request) - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) } return } @@ -45,21 +45,21 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http if err != nil { log.Error(err) rendering.RenderInternalServerError(writer, request) - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) return } - respBody, status, headers, err := h.logsSubscriptionAPI.Subscribe(agentName, bytes.NewReader(body), request.Header) + respBody, status, headers, err := h.telemetrySubscription.Subscribe(agentName, bytes.NewReader(body), request.Header) if err != nil { log.Errorf("Telemetry API error: %s", err) switch err { case logsapi.ErrTelemetryServiceOff: rendering.RenderForbiddenWithTypeMsg(writer, request, - errLogsSubscriptionClosed, "Logs API subscription is closed already") - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + h.telemetrySubscription.GetServiceClosedErrorType(), h.telemetrySubscription.GetServiceClosedErrorMessage()) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeClientErr, 1) default: rendering.RenderInternalServerError(writer, request) - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) } return } @@ -67,11 +67,11 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http rendering.RenderRuntimeLogsResponse(writer, respBody, status, headers) switch status / 100 { case 2: // 2xx - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeSuccess, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeSuccess, 1) case 4: // 4xx - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeClientErr, 1) case 5: // 5xx - h.logsSubscriptionAPI.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) } } @@ -114,7 +114,7 @@ func (h *runtimeLogsHandler) getAgentName(agentID uuid.UUID) (string, bool) { } func (h *runtimeLogsHandler) getBody(writer http.ResponseWriter, request *http.Request) ([]byte, error) { - body, err := ioutil.ReadAll(request.Body) + body, err := io.ReadAll(request.Body) if err != nil { return nil, fmt.Errorf("Failed to read error body: %s", err) } @@ -122,11 +122,11 @@ func (h *runtimeLogsHandler) getBody(writer http.ResponseWriter, request *http.R return body, nil } -// NewRuntimeLogsHandler returns a new instance of http handler +// NewRuntimeTelemetrySubscriptionHandler returns a new instance of http handler // for serving /runtime/logs -func NewRuntimeLogsHandler(registrationService core.RegistrationService, logsSubscriptionAPI telemetry.LogsSubscriptionAPI) http.Handler { +func NewRuntimeTelemetrySubscriptionHandler(registrationService core.RegistrationService, telemetrySubscription telemetry.SubscriptionAPI) http.Handler { return &runtimeLogsHandler{ - registrationService: registrationService, - logsSubscriptionAPI: logsSubscriptionAPI, + registrationService: registrationService, + telemetrySubscription: telemetrySubscription, } } diff --git a/lambda/rapi/handler/runtimelogs_stub.go b/lambda/rapi/handler/runtimelogs_stub.go index 0ce472e..f540e9b 100644 --- a/lambda/rapi/handler/runtimelogs_stub.go +++ b/lambda/rapi/handler/runtimelogs_stub.go @@ -6,27 +6,48 @@ package handler import ( "net/http" + log "github.com/sirupsen/logrus" "go.amzn.com/lambda/rapi/model" - - "github.com/go-chi/render" + "go.amzn.com/lambda/rapi/rendering" ) const ( - telemetryAPIDisabledErrorType = "Logs.NotSupported" + logsAPIDisabledErrorType = "Logs.NotSupported" + telemetryAPIDisabledErrorType = "Telemetry.NotSupported" ) -type runtimeLogsStubHandler struct{} +type runtimeLogsStubAPIHandler struct{} -func (h *runtimeLogsStubHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - render.Status(request, http.StatusAccepted) - render.JSON(writer, request, &model.ErrorResponse{ - ErrorType: telemetryAPIDisabledErrorType, +func (h *runtimeLogsStubAPIHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + if err := rendering.RenderJSON(http.StatusAccepted, writer, request, &model.ErrorResponse{ + ErrorType: logsAPIDisabledErrorType, ErrorMessage: "Logs API is not supported", - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(writer, err.Error(), http.StatusInternalServerError) + } +} + +// NewRuntimeLogsAPIStubHandler returns a new instance of http handler +// for serving /runtime/logs when a telemetry service implementation is absent +func NewRuntimeLogsAPIStubHandler() http.Handler { + return &runtimeLogsStubAPIHandler{} +} + +type runtimeTelemetryAPIStubHandler struct{} + +func (h *runtimeTelemetryAPIStubHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + if err := rendering.RenderJSON(http.StatusAccepted, writer, request, &model.ErrorResponse{ + ErrorType: telemetryAPIDisabledErrorType, + ErrorMessage: "Telemetry API is not supported", + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(writer, err.Error(), http.StatusInternalServerError) + } } -// NewRuntimeLogsStubHandler returns a new instance of http handler +// NewRuntimeTelemetryAPIStubHandler returns a new instance of http handler // for serving /runtime/logs when a telemetry service implementation is absent -func NewRuntimeLogsStubHandler() http.Handler { - return &runtimeLogsStubHandler{} +func NewRuntimeTelemetryAPIStubHandler() http.Handler { + return &runtimeTelemetryAPIStubHandler{} } diff --git a/lambda/rapi/handler/runtimelogs_stub_test.go b/lambda/rapi/handler/runtimelogs_stub_test.go index 4826d12..5b27983 100644 --- a/lambda/rapi/handler/runtimelogs_stub_test.go +++ b/lambda/rapi/handler/runtimelogs_stub_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/assert" ) -func TestSuccessfulRuntimeLogsStub202Response(t *testing.T) { - handler := NewRuntimeLogsStubHandler() +func TestSuccessfulRuntimeLogsAPIStub202Response(t *testing.T) { + handler := NewRuntimeLogsAPIStubHandler() requestBody := []byte(`foobar`) request := httptest.NewRequest("PUT", "/logs", bytes.NewBuffer(requestBody)) responseRecorder := httptest.NewRecorder() @@ -23,3 +23,15 @@ func TestSuccessfulRuntimeLogsStub202Response(t *testing.T) { assert.Equal(t, http.StatusAccepted, responseRecorder.Code) assert.JSONEq(t, `{"errorMessage":"Logs API is not supported","errorType":"Logs.NotSupported"}`, responseRecorder.Body.String()) } + +func TestSuccessfulRuntimeTelemetryAPIStub202Response(t *testing.T) { + handler := NewRuntimeTelemetryAPIStubHandler() + requestBody := []byte(`foobar`) + request := httptest.NewRequest("PUT", "/telemetry", bytes.NewBuffer(requestBody)) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + assert.Equal(t, http.StatusAccepted, responseRecorder.Code) + assert.JSONEq(t, `{"errorMessage":"Telemetry API is not supported","errorType":"Telemetry.NotSupported"}`, responseRecorder.Body.String()) +} diff --git a/lambda/rapi/handler/runtimelogs_test.go b/lambda/rapi/handler/runtimelogs_test.go index b7db6df..892d61e 100644 --- a/lambda/rapi/handler/runtimelogs_test.go +++ b/lambda/rapi/handler/runtimelogs_test.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "net/http/httptest" "testing" @@ -23,30 +22,45 @@ import ( "go.amzn.com/lambda/rapidcore/telemetry/logsapi" ) -type mockLogsSubscriptionAPI struct{ mock.Mock } +type mockSubscriptionAPI struct{ mock.Mock } -func (s *mockLogsSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { +func (s *mockSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { args := s.Called(agentName, body, headers) return args.Get(0).([]byte), args.Int(1), args.Get(2).(map[string][]string), args.Error(3) } -func (s *mockLogsSubscriptionAPI) RecordCounterMetric(metricName string, count int) { +func (s *mockSubscriptionAPI) RecordCounterMetric(metricName string, count int) { s.Called(metricName, count) } -func (s *mockLogsSubscriptionAPI) FlushMetrics() interop.LogsAPIMetrics { +func (s *mockSubscriptionAPI) FlushMetrics() interop.TelemetrySubscriptionMetrics { args := s.Called() - return args.Get(0).(interop.LogsAPIMetrics) + return args.Get(0).(interop.TelemetrySubscriptionMetrics) } -func (s *mockLogsSubscriptionAPI) Clear() { +func (s *mockSubscriptionAPI) Clear() { s.Called() } -func (s *mockLogsSubscriptionAPI) TurnOff() { +func (s *mockSubscriptionAPI) TurnOff() { s.Called() } +func (s *mockSubscriptionAPI) GetEndpointURL() string { + args := s.Called() + return args.Get(0).(string) +} + +func (s *mockSubscriptionAPI) GetServiceClosedErrorMessage() string { + args := s.Called() + return args.Get(0).(string) +} + +func (s *mockSubscriptionAPI) GetServiceClosedErrorType() string { + args := s.Called() + return args.Get(0).(string) +} + func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} respBody, respStatus, respHeaders := []byte(`barbaz`), http.StatusNotFound, map[string][]string{"K": []string{"V1", "V2"}} @@ -60,11 +74,11 @@ func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { agent, err := registrationService.CreateExternalAgent(agentName) assert.NoError(t, err) - logsSubscriptionAPI := &mockLogsSubscriptionAPI{} - logsSubscriptionAPI.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return(respBody, respStatus, respHeaders, nil) - logsSubscriptionAPI.On("RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return(respBody, respStatus, respHeaders, nil) + telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) - handler := NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI) + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) for k, vals := range reqHeaders { for _, v := range vals { @@ -77,10 +91,10 @@ func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { handler.ServeHTTP(responseRecorder, request) - logsSubscriptionAPI.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders) - logsSubscriptionAPI.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) - recordedBody, err := ioutil.ReadAll(responseRecorder.Body) + recordedBody, err := io.ReadAll(responseRecorder.Body) assert.NoError(t, err) assert.Equal(t, respStatus, responseRecorder.Code) @@ -98,10 +112,10 @@ func TestErrorUnregisteredAgentID(t *testing.T) { core.NewInvokeFlowSynchronization(), ) - logsSubscriptionAPI := &mockLogsSubscriptionAPI{} - logsSubscriptionAPI.On("RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) - handler := NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI) + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) for k, vals := range reqHeaders { for _, v := range vals { @@ -114,16 +128,16 @@ func TestErrorUnregisteredAgentID(t *testing.T) { handler.ServeHTTP(responseRecorder, request) - recordedBody, err := ioutil.ReadAll(responseRecorder.Body) + recordedBody, err := io.ReadAll(responseRecorder.Body) assert.NoError(t, err) expectedErrorBody := fmt.Sprintf(`{"errorMessage":"Unknown extension %s","errorType":"Extension.UnknownExtensionIdentifier"}`+"\n", invalidAgentID.String()) - expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json; charset=utf-8"}}) + expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json"}}) assert.Equal(t, http.StatusForbidden, responseRecorder.Code) assert.Equal(t, expectedErrorBody, string(recordedBody)) assert.Equal(t, expectedHeaders, responseRecorder.Header()) - logsSubscriptionAPI.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) } func TestErrorTelemetryAPICallFailure(t *testing.T) { @@ -139,11 +153,11 @@ func TestErrorTelemetryAPICallFailure(t *testing.T) { agent, err := registrationService.CreateExternalAgent(agentName) assert.NoError(t, err) - logsSubscriptionAPI := &mockLogsSubscriptionAPI{} - logsSubscriptionAPI.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) - logsSubscriptionAPI.On("RecordCounterMetric", serverErrMetric, 1) + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + telemetrySubscription.On("RecordCounterMetric", serverErrMetric, 1) - handler := NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI) + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) for k, vals := range reqHeaders { for _, v := range vals { @@ -156,16 +170,16 @@ func TestErrorTelemetryAPICallFailure(t *testing.T) { handler.ServeHTTP(responseRecorder, request) - recordedBody, err := ioutil.ReadAll(responseRecorder.Body) + recordedBody, err := io.ReadAll(responseRecorder.Body) assert.NoError(t, err) expectedErrorBody := `{"errorMessage":"Internal Server Error","errorType":"InternalServerError"}` + "\n" - expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json; charset=utf-8"}}) + expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json"}}) assert.Equal(t, http.StatusInternalServerError, responseRecorder.Code) assert.Equal(t, expectedErrorBody, string(recordedBody)) assert.Equal(t, expectedHeaders, responseRecorder.Header()) - logsSubscriptionAPI.AssertCalled(t, "RecordCounterMetric", serverErrMetric, 1) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", serverErrMetric, 1) } func TestRenderLogsSubscriptionClosed(t *testing.T) { @@ -181,11 +195,13 @@ func TestRenderLogsSubscriptionClosed(t *testing.T) { agent, err := registrationService.CreateExternalAgent(agentName) assert.NoError(t, err) - logsSubscriptionAPI := &mockLogsSubscriptionAPI{} - logsSubscriptionAPI.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) - logsSubscriptionAPI.On("RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription.On("GetServiceClosedErrorMessage").Return("Logs API subscription is closed already") + telemetrySubscription.On("GetServiceClosedErrorType").Return("Logs.SubscriptionClosed") - handler := NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI) + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) for k, vals := range reqHeaders { for _, v := range vals { @@ -198,14 +214,58 @@ func TestRenderLogsSubscriptionClosed(t *testing.T) { handler.ServeHTTP(responseRecorder, request) - recordedBody, err := ioutil.ReadAll(responseRecorder.Body) + recordedBody, err := io.ReadAll(responseRecorder.Body) assert.NoError(t, err) expectedErrorBody := `{"errorMessage":"Logs API subscription is closed already","errorType":"Logs.SubscriptionClosed"}` + "\n" - expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json; charset=utf-8"}}) + expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json"}}) + + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + assert.Equal(t, expectedErrorBody, string(recordedBody)) + assert.Equal(t, expectedHeaders, responseRecorder.Header()) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) +} + +func TestRenderTelemetrySubscriptionClosed(t *testing.T) { + agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} + apiError := logsapi.ErrTelemetryServiceOff + clientErrMetric := logsapi.SubscribeClientErr + + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization(), + core.NewInvokeFlowSynchronization(), + ) + + agent, err := registrationService.CreateExternalAgent(agentName) + assert.NoError(t, err) + + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription.On("GetServiceClosedErrorMessage").Return("Telemetry API subscription is closed already") + telemetrySubscription.On("GetServiceClosedErrorType").Return("Telemetry.SubscriptionClosed") + + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) + request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) + for k, vals := range reqHeaders { + for _, v := range vals { + request.Header.Add(k, v) + } + } + + request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + recordedBody, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + + expectedErrorBody := `{"errorMessage":"Telemetry API subscription is closed already","errorType":"Telemetry.SubscriptionClosed"}` + "\n" + expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json"}}) assert.Equal(t, http.StatusForbidden, responseRecorder.Code) assert.Equal(t, expectedErrorBody, string(recordedBody)) assert.Equal(t, expectedHeaders, responseRecorder.Header()) - logsSubscriptionAPI.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) } diff --git a/lambda/rapi/middleware/middleware_test.go b/lambda/rapi/middleware/middleware_test.go index 7b37de9..a0b9134 100644 --- a/lambda/rapi/middleware/middleware_test.go +++ b/lambda/rapi/middleware/middleware_test.go @@ -7,7 +7,7 @@ import ( "bytes" "context" "encoding/json" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" @@ -58,7 +58,7 @@ func TestAgentUniqueIdentifierHeaderValidatorForbidden(t *testing.T) { responseRecorder := httptest.NewRecorder() router.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - respBody, _ := ioutil.ReadAll(responseRecorder.Body) + respBody, _ := io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, handler.ErrAgentIdentifierMissing, errorResponse.ErrorType) @@ -66,7 +66,7 @@ func TestAgentUniqueIdentifierHeaderValidatorForbidden(t *testing.T) { request.Header.Set(handler.LambdaAgentIdentifier, "invalid-unique-identifier") router.ServeHTTP(responseRecorder, request) assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - respBody, _ = ioutil.ReadAll(responseRecorder.Body) + respBody, _ = io.ReadAll(responseRecorder.Body) json.Unmarshal(respBody, &errorResponse) assert.Equal(t, handler.ErrAgentIdentifierInvalid, errorResponse.ErrorType) } diff --git a/lambda/rapi/model/tracing.go b/lambda/rapi/model/tracing.go index af90e8f..83f97e8 100644 --- a/lambda/rapi/model/tracing.go +++ b/lambda/rapi/model/tracing.go @@ -3,14 +3,21 @@ package model +type TracingType string + const ( // XRayTracingType represents an X-Ray Tracing object type - XRayTracingType = "X-Amzn-Trace-Id" + XRayTracingType TracingType = "X-Amzn-Trace-Id" +) + +const ( + XRaySampled = "1" + XRayNonSampled = "0" ) // Tracing object returned as part of agent Invoke event type Tracing struct { - Type string `json:"type"` + Type TracingType `json:"type"` XRayTracing } diff --git a/lambda/rapi/rendering/doc.go b/lambda/rapi/rendering/doc.go index 4573638..bc359a1 100644 --- a/lambda/rapi/rendering/doc.go +++ b/lambda/rapi/rendering/doc.go @@ -2,7 +2,6 @@ // SPDX-License-Identifier: Apache-2.0 /* - Package rendering provides stateful event rendering service. State of the rendering service should be set from the main event dispatch thread @@ -17,6 +16,5 @@ Example of INVOKE event: [main] // release threads registered for INVOKE event [thread] // receives INVOKE event - */ package rendering diff --git a/lambda/rapi/rendering/render_json.go b/lambda/rapi/rendering/render_json.go new file mode 100644 index 0000000..8cea816 --- /dev/null +++ b/lambda/rapi/rendering/render_json.go @@ -0,0 +1,33 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rendering + +import ( + "bytes" + "encoding/json" + log "github.com/sirupsen/logrus" + "net/http" +) + +// RenderJSON: +// - marshals 'v' to JSON, automatically escaping HTML +// - sets the Content-Type as application/json +// - sets the HTTP response status code +// - returns an error if it occurred before writing to response +func RenderJSON(status int, w http.ResponseWriter, r *http.Request, v interface{}) error { + buf := &bytes.Buffer{} + enc := json.NewEncoder(buf) + enc.SetEscapeHTML(true) + if err := enc.Encode(v); err != nil { + return err + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if _, err := w.Write(buf.Bytes()); err != nil { + log.WithError(err).Warn("Error while writing response body") + } + + return nil +} diff --git a/lambda/rapi/rendering/rendering.go b/lambda/rapi/rendering/rendering.go index c75d010..0edfb68 100644 --- a/lambda/rapi/rendering/rendering.go +++ b/lambda/rapi/rendering/rendering.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "strconv" "sync" @@ -19,7 +18,6 @@ import ( "go.amzn.com/lambda/metering" "go.amzn.com/lambda/rapi/model" - "github.com/go-chi/render" "github.com/google/uuid" log "github.com/sirupsen/logrus" ) @@ -33,6 +31,8 @@ const ( ErrorTypeInvalidRequestID = "InvalidRequestID" // ErrorTypeRequestEntityTooLarge error type for payload too large ErrorTypeRequestEntityTooLarge = "RequestEntityTooLarge" + // ErrorTypeTruncatedHTTPRequest error type for truncated HTTP request + ErrorTypeTruncatedHTTPRequest = "TruncatedHTTPRequest" ) // ErrRenderingServiceStateNotSet returned when state not set @@ -100,6 +100,9 @@ type InvokeRenderer struct { metrics InvokeRendererMetrics } +type RestoreRenderer struct { +} + // NewAgentInvokeEvent forms a new AgentInvokeEvent from INVOKE request func NewAgentInvokeEvent(req *interop.Invoke) (*model.AgentInvokeEvent, error) { deadlineMono, err := strconv.ParseInt(req.DeadlineNs, 10, 64) @@ -145,7 +148,7 @@ func (s *InvokeRenderer) bufferInvokeRequest() error { if nil == s.requestBuffer { reader := io.LimitReader(s.invoke.Payload, interop.MaxPayloadSize) start := time.Now() - s.requestBuffer, err = ioutil.ReadAll(reader) + s.requestBuffer, err = io.ReadAll(reader) s.metrics = InvokeRendererMetrics{ ReadTime: time.Since(start), SizeBytes: len(s.requestBuffer), @@ -193,6 +196,15 @@ func (s *InvokeRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request return nil } +func (s *RestoreRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request *http.Request) error { + writer.WriteHeader(http.StatusOK) + return nil +} + +func (s *RestoreRenderer) RenderAgentEvent(writer http.ResponseWriter, request *http.Request) error { + return nil +} + // NewInvokeRenderer returns new invoke event renderer func NewInvokeRenderer(ctx context.Context, invoke *interop.Invoke, traceParser func(context.Context, *interop.Invoke) string) *InvokeRenderer { return &InvokeRenderer{ @@ -204,6 +216,10 @@ func NewInvokeRenderer(ctx context.Context, invoke *interop.Invoke, traceParser } } +func NewRestoreRenderer() *RestoreRenderer { + return &RestoreRenderer{} +} + func (s *InvokeRenderer) GetMetrics() InvokeRendererMetrics { s.requestMutex.Lock() defer s.requestMutex.Unlock() @@ -283,46 +299,78 @@ func renderAgentInvokeHeaders(writer http.ResponseWriter, eventID uuid.UUID) { // RenderForbiddenWithTypeMsg method for rendering error response func RenderForbiddenWithTypeMsg(w http.ResponseWriter, r *http.Request, errorType string, format string, args ...interface{}) { - render.Status(r, http.StatusForbidden) - render.JSON(w, r, &model.ErrorResponse{ + if err := RenderJSON(http.StatusForbidden, w, r, &model.ErrorResponse{ ErrorType: errorType, ErrorMessage: fmt.Sprintf(format, args...), - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } } // RenderInternalServerError method for rendering error response func RenderInternalServerError(w http.ResponseWriter, r *http.Request) { - render.Status(r, http.StatusInternalServerError) - render.JSON(w, r, &model.ErrorResponse{ + if err := RenderJSON(http.StatusInternalServerError, w, r, &model.ErrorResponse{ ErrorMessage: "Internal Server Error", ErrorType: ErrorTypeInternalServerError, - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } } // RenderRequestEntityTooLarge method for rendering error response func RenderRequestEntityTooLarge(w http.ResponseWriter, r *http.Request) { - render.Status(r, http.StatusRequestEntityTooLarge) - render.JSON(w, r, &model.ErrorResponse{ + if err := RenderJSON(http.StatusRequestEntityTooLarge, w, r, &model.ErrorResponse{ ErrorMessage: fmt.Sprintf("Exceeded maximum allowed payload size (%d bytes).", interop.MaxPayloadSize), ErrorType: ErrorTypeRequestEntityTooLarge, - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderTruncatedHTTPRequestError method for rendering error response +func RenderTruncatedHTTPRequestError(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ + ErrorMessage: "HTTP request detected as truncated", + ErrorType: ErrorTypeTruncatedHTTPRequest, + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } } // RenderInvalidRequestID renders invalid request ID error response func RenderInvalidRequestID(w http.ResponseWriter, r *http.Request) { - render.Status(r, http.StatusBadRequest) - render.JSON(w, r, &model.ErrorResponse{ + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ ErrorMessage: "Invalid request ID", ErrorType: "InvalidRequestID", - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderInvalidFunctionResponseMode renders invalid function response mode response +func RenderInvalidFunctionResponseMode(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ + ErrorMessage: "Invalid function response mode", + ErrorType: "InvalidFunctionResponseMode", + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } } // RenderAccepted method for rendering accepted status response func RenderAccepted(w http.ResponseWriter, r *http.Request) { - render.Status(r, http.StatusAccepted) - render.JSON(w, r, &model.StatusResponse{ + if err := RenderJSON(http.StatusAccepted, w, r, &model.StatusResponse{ Status: "OK", - }) + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } } // RenderInteropError is a convenience method for interpreting interop errors diff --git a/lambda/rapi/router.go b/lambda/rapi/router.go index 1d2766a..5c2a56d 100644 --- a/lambda/rapi/router.go +++ b/lambda/rapi/router.go @@ -19,7 +19,7 @@ import ( // NewRouter returns a new instance of chi router implementing // Runtime API specification. -func NewRouter(appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService) http.Handler { +func NewRouter(appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService, eventsAPI telemetry.EventsAPI) http.Handler { router := chi.NewRouter() router.Use(middleware.AppCtxMiddleware(appCtx)) @@ -46,7 +46,11 @@ func NewRouter(appCtx appctx.ApplicationContext, registrationService core.Regist handler.NewInvocationErrorHandler(registrationService)).ServeHTTP) router.Post("/runtime/init/error", - handler.NewInitErrorHandler(registrationService).ServeHTTP) + handler.NewInitErrorHandler(registrationService, eventsAPI).ServeHTTP) + + if appctx.LoadInitType(appCtx) == appctx.InitCaching { + router.Get("/runtime/restore/next", handler.NewRestoreNextHandler(registrationService, renderingService).ServeHTTP) + } return router } @@ -80,14 +84,14 @@ func ExtensionsRouter(appCtx appctx.ApplicationContext, registrationService core // LogsAPIRouter returns a new instance of chi router implementing // Logs API specification. -func LogsAPIRouter(registrationService core.RegistrationService, logsSubscriptionAPI telemetry.LogsSubscriptionAPI) http.Handler { +func LogsAPIRouter(registrationService core.RegistrationService, logsSubscriptionAPI telemetry.SubscriptionAPI) http.Handler { router := chi.NewRouter() router.Use(middleware.AccessLogMiddleware()) router.Use(middleware.AllowIfExtensionsEnabled) router.Put("/logs", middleware.AgentUniqueIdentifierHeaderValidator( - handler.NewRuntimeLogsHandler(registrationService, logsSubscriptionAPI)).ServeHTTP) + handler.NewRuntimeTelemetrySubscriptionHandler(registrationService, logsSubscriptionAPI)).ServeHTTP) return router } @@ -98,7 +102,32 @@ func LogsAPIRouter(registrationService core.RegistrationService, logsSubscriptio func LogsAPIStubRouter() http.Handler { router := chi.NewRouter() - router.Put("/logs", handler.NewRuntimeLogsStubHandler().ServeHTTP) + router.Put("/logs", handler.NewRuntimeLogsAPIStubHandler().ServeHTTP) + + return router +} + +// TelemetryRouter returns a new instance of chi router implementing +// Telemetry API specification. +func TelemetryAPIRouter(registrationService core.RegistrationService, telemetrySubscriptionAPI telemetry.SubscriptionAPI) http.Handler { + router := chi.NewRouter() + router.Use(middleware.AccessLogMiddleware()) + router.Use(middleware.AllowIfExtensionsEnabled) + + router.Put("/telemetry", + middleware.AgentUniqueIdentifierHeaderValidator( + handler.NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscriptionAPI)).ServeHTTP) + + return router +} + +// TelemetryStubRouter returns a new instance of chi router implementing +// a stub of Telemetry API that always returns a non-committal response to +// prevent customer code from crashing when Telemetry API is disabled locally +func TelemetryAPIStubRouter() http.Handler { + router := chi.NewRouter() + + router.Put("/telemetry", handler.NewRuntimeTelemetryAPIStubHandler().ServeHTTP) return router } diff --git a/lambda/rapi/router_test.go b/lambda/rapi/router_test.go index f1cbde8..73cbde1 100644 --- a/lambda/rapi/router_test.go +++ b/lambda/rapi/router_test.go @@ -60,7 +60,7 @@ func assertResponseErrorType(t *testing.T, expectedErrorType string, response *h // rendered as JSON, regardless of the value provided // in "Accept" header. // -// When using render.Render(...), chi rendering library +// When using render.Render(...), rendering function // would attempt to render response using content type // specified in the "Accept" header. // @@ -69,7 +69,7 @@ func assertResponseErrorType(t *testing.T, expectedErrorType string, response *h func TestAcceptXML(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) responseRecorder := httptest.NewRecorder() request := httptest.NewRequest("POST", "/runtime/invocation/x-y-z/error", bytes.NewReader([]byte(""))) // Tell server that client side accepts "application/xml". @@ -90,7 +90,7 @@ func TestAcceptXML(t *testing.T) { func Test404PageNotFound(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/unsupported", bytes.NewReader([]byte("")))) assert.Equal(t, http.StatusNotFound, responseRecorder.Code) assert.Equal(t, "404 page not found\n", responseRecorder.Body.String()) @@ -99,7 +99,7 @@ func Test404PageNotFound(t *testing.T) { func Test405MethodNotAllowed(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("DELETE", "/runtime/invocation/ABC/error", bytes.NewReader([]byte("")))) assert.Equal(t, http.StatusMethodNotAllowed, responseRecorder.Code) } @@ -107,7 +107,7 @@ func Test405MethodNotAllowed(t *testing.T) { func TestInitErrorAccepted(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/init/error", bytes.NewReader([]byte("{}")))) assert.Equal(t, http.StatusAccepted, responseRecorder.Code) } @@ -115,7 +115,7 @@ func TestInitErrorAccepted(t *testing.T) { func TestInitErrorForbidden(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -126,7 +126,7 @@ func TestInitErrorForbidden(t *testing.T) { func TestInvokeResponseAccepted(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -137,7 +137,7 @@ func TestInvokeResponseAccepted(t *testing.T) { func TestInvokeErrorResponseAccepted(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -148,7 +148,7 @@ func TestInvokeErrorResponseAccepted(t *testing.T) { func TestInvokeNextTwice(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -159,7 +159,7 @@ func TestInvokeNextTwice(t *testing.T) { func TestInvokeResponseInvalidRequestID(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -171,7 +171,7 @@ func TestInvokeResponseInvalidRequestID(t *testing.T) { func TestInvokeErrorResponseInvalidRequestID(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -183,7 +183,7 @@ func TestInvokeErrorResponseInvalidRequestID(t *testing.T) { func TestInvokeResponseTwice(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -197,7 +197,7 @@ func TestInvokeResponseTwice(t *testing.T) { func TestInvokeErrorResponseTwice(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -211,7 +211,7 @@ func TestInvokeErrorResponseTwice(t *testing.T) { func TestInvokeResponseAfterErrorResponse(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -225,7 +225,7 @@ func TestInvokeResponseAfterErrorResponse(t *testing.T) { func TestInvokeErrorResponseAfterResponse(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -239,7 +239,7 @@ func TestInvokeErrorResponseAfterResponse(t *testing.T) { func TestMoreThanOneInvoke(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) var responseRecorder *httptest.ResponseRecorder for _, id := range []string{"A", "B", "C"} { flowTest.ConfigureForInvoke(context.Background(), createInvoke(id)) @@ -250,12 +250,25 @@ func TestMoreThanOneInvoke(t *testing.T) { } } +func TestInitCachingAPIDisabledForPlainInit(t *testing.T) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + var responseRecorder *httptest.ResponseRecorder + + responseRecorder = makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/restore/next", nil)) + assert.Equal(t, http.StatusNotFound, responseRecorder.Code) + + responseRecorder = makeTestRequest(t, router, httptest.NewRequest("GET", "/credentials", nil)) + assert.Equal(t, http.StatusNotFound, responseRecorder.Code) +} + func benchmarkInvoke(b *testing.B, payload []byte) { b.StopTimer() b.ReportAllocs() flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) for i := 0; i < b.N; i++ { id := uuid.New().String() flowTest.ConfigureForInvoke(context.Background(), createInvoke(id)) diff --git a/lambda/rapi/security_test.go b/lambda/rapi/security_test.go index 3f869d5..5312b43 100644 --- a/lambda/rapi/security_test.go +++ b/lambda/rapi/security_test.go @@ -20,7 +20,7 @@ func TestInvokeValidId(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) @@ -53,7 +53,7 @@ func TestSecurityInvokeResponseBadRequestId(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) @@ -100,7 +100,7 @@ func TestSecurityInvokeErrorBadRequestId(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) diff --git a/lambda/rapi/server.go b/lambda/rapi/server.go index e2c6ad4..dd027f4 100644 --- a/lambda/rapi/server.go +++ b/lambda/rapi/server.go @@ -13,6 +13,7 @@ import ( "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/core" + "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapi/rendering" "go.amzn.com/lambda/telemetry" @@ -23,6 +24,7 @@ const version20180601 = "/2018-06-01" const version20200101 = "/2020-01-01" const version20200815 = "/2020-08-15" const version20210423 = "/2021-04-23" +const version20220701 = "/2022-07-01" // Server is a Runtime API server type Server struct { @@ -33,6 +35,10 @@ type Server struct { exit chan error } +func SaveConnInContext(ctx context.Context, c net.Conn) context.Context { + return context.WithValue(ctx, interop.HTTPConnKey, c) +} + // NewServer creates a new Runtime API Server // // Unlike net/http server's ListenAndServe, we separate Listen() @@ -44,28 +50,30 @@ func NewServer(host string, port int, appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService, telemetryAPIEnabled bool, - logsSubscriptionAPI telemetry.LogsSubscriptionAPI, initCachingEnabled bool, credentialsService core.CredentialsService) *Server { + logsSubscriptionAPI telemetry.SubscriptionAPI, telemetrySubscriptionAPI telemetry.SubscriptionAPI, credentialsService core.CredentialsService, eventsAPI telemetry.EventsAPI) *Server { exitErrors := make(chan error, 1) router := chi.NewRouter() - router.Mount(version20180601, NewRouter(appCtx, registrationService, renderingService)) + router.Mount(version20180601, NewRouter(appCtx, registrationService, renderingService, eventsAPI)) router.Mount(version20200101, ExtensionsRouter(appCtx, registrationService, renderingService)) if telemetryAPIEnabled { router.Mount(version20200815, LogsAPIRouter(registrationService, logsSubscriptionAPI)) + router.Mount(version20220701, TelemetryAPIRouter(registrationService, telemetrySubscriptionAPI)) } else { router.Mount(version20200815, LogsAPIStubRouter()) + router.Mount(version20220701, TelemetryAPIStubRouter()) } - if initCachingEnabled { + if appctx.LoadInitType(appCtx) == appctx.InitCaching { router.Mount(version20210423, CredentialsAPIRouter(credentialsService)) } return &Server{ host: host, port: port, - server: &http.Server{Handler: router}, + server: &http.Server{Handler: router, ConnContext: SaveConnInContext}, listener: nil, exit: exitErrors, } diff --git a/lambda/rapi/server_test.go b/lambda/rapi/server_test.go index ce6e4e6..cf31fab 100644 --- a/lambda/rapi/server_test.go +++ b/lambda/rapi/server_test.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "testing" "time" @@ -52,7 +51,7 @@ func TestServerReturnsSuccessfulResponse(t *testing.T) { if err != nil { assert.FailNowf(t, "Failed to get response", err.Error()) } - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { assert.FailNowf(t, "Failed to read response body", err.Error()) } diff --git a/lambda/rapid/bootstrap.go b/lambda/rapid/bootstrap.go deleted file mode 100644 index e82ec6c..0000000 --- a/lambda/rapid/bootstrap.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapid - -import ( - "os" - - "go.amzn.com/lambda/fatalerror" -) - -type Bootstrap interface { - Cmd() ([]string, error) // returns the args of bootstrap, where args[0] is the path to executable - Env(e EnvironmentVariables) []string // returns the environment variables to be passed to the bootstrapped process - Cwd() (string, error) // returns the working directory of the bootstrap process - ExtraFiles() []*os.File // returns the extra file descriptors apart from 1 & 2 to be passed to runtime - CachedFatalError(err error) (fatalerror.ErrorType, string, bool) -} diff --git a/lambda/rapid/exit.go b/lambda/rapid/exit.go index af5cb72..e45f3a4 100644 --- a/lambda/rapid/exit.go +++ b/lambda/rapid/exit.go @@ -5,8 +5,9 @@ package rapid import ( "fmt" - "os" + "time" + "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/extensions" "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" @@ -15,102 +16,116 @@ import ( log "github.com/sirupsen/logrus" ) -func checkInteropError(format string, err error) { - if err == interop.ErrInvalidInvokeID || err == interop.ErrResponseSent { - log.Warnf(format, err) - } else { - log.Panicf(format, err) +func handleInvokeError(execCtx *rapidContext, invokeRequest *interop.Invoke, invokeMx *invokeMetrics, err error) *interop.InvokeFailure { + invokeFailure := newInvokeFailureMsg(execCtx, invokeRequest, invokeMx, err) + resp := model.ErrorResponse{ + ErrorType: string(invokeFailure.ErrorType), + ErrorMessage: fmt.Sprintf("Error: %v", invokeFailure.ErrorMessage), } -} -func trySendDefaultErrorResponse(interopServer interop.Server, invokeID string, errorType fatalerror.ErrorType, err error) { - resp := model.ErrorResponse{ - ErrorType: string(errorType), - ErrorMessage: fmt.Sprintf("Error: %v", err), + if invokeRequest.ID != "" { + resp.ErrorMessage = fmt.Sprintf("RequestId: %s Error: %v", invokeRequest.ID, invokeFailure.ErrorMessage) } - if invokeID != "" { - resp.ErrorMessage = fmt.Sprintf("RequestId: %s Error: %v", invokeID, err) + // This is the default error response that gets sent back as the function response in failure cases + invokeFailure.DefaultErrorResponse = resp.AsInteropError() + + // Invoke with extensions disabled maintains behaviour parity with pre-extensions rapid + if !extensions.AreEnabled() { + invokeFailure.RequestReset = false + return invokeFailure } - if err := interopServer.SendErrorResponse(invokeID, resp.AsInteropError()); err != nil { - checkInteropError("Failed to send default error response: %s", err) + if err == errResetReceived { + // errResetReceived is returned when execution flow was interrupted by the Reset message, + // hence this error deserves special handling and we yield to main receive loop to handle it + invokeFailure.ResetReceived = true + return invokeFailure } + + invokeFailure.RequestReset = true + return invokeFailure } -func reportErrorAndExit(doneFailMsg *interop.DoneFail, invokeID string, interopServer interop.Server, err error) { - // This function maintains compatibility of exit sequence behaviour - // with Sandbox Factory in the absence of extensions - - // NOTE this check will prevent us from sending FAULT message in case - // response (positive or negative) has already been sent. This is done - // to maintain legacy behavior of RAPID. - // ALSO NOTE, this works in case of positive response because this will - // be followed by RAPID exit. - if !interopServer.IsResponseSent() { - trySendDefaultErrorResponse(interopServer, invokeID, doneFailMsg.ErrorType, err) +func newInvokeFailureMsg(execCtx *rapidContext, invokeRequest *interop.Invoke, invokeMx *invokeMetrics, err error) *interop.InvokeFailure { + errorType, found := appctx.LoadFirstFatalError(execCtx.appCtx) + if !found { + errorType = fatalerror.Unknown } - if err := interopServer.CommitResponse(); err != nil { - checkInteropError("Failed to commit error response: %s", err) + invokeFailure := &interop.InvokeFailure{ + ErrorType: errorType, + ErrorMessage: err, + RequestReset: true, + ResetReceived: false, + RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), + NumActiveExtensions: execCtx.registrationService.CountAgents(), + InvokeReceivedTime: invokeRequest.InvokeReceivedTime, } - // old behavior: no DoneFails - doneMsg := &interop.Done{ - WaitForExit: true, - CorrelationID: doneFailMsg.CorrelationID, // required for standalone mode - Meta: doneFailMsg.Meta, + if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { + invokeFailure.ResponseMetrics.RuntimeTimeThrottledMs = invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) + invokeFailure.ResponseMetrics.RuntimeProducedBytes = invokeRequest.InvokeResponseMetrics.ProducedBytes + invokeFailure.ResponseMetrics.RuntimeOutboundThroughputBps = invokeRequest.InvokeResponseMetrics.OutboundThroughputBps } - if err := interopServer.SendDone(doneMsg); err != nil { - checkInteropError("Failed to send DONE during exit: %s", err) + if invokeMx != nil { + invokeFailure.InvokeMetrics.InvokeRequestReadTimeNs = invokeMx.rendererMetrics.ReadTime.Nanoseconds() + invokeFailure.InvokeMetrics.InvokeRequestSizeBytes = int64(invokeMx.rendererMetrics.SizeBytes) + invokeFailure.InvokeMetrics.RuntimeReadyTime = int64(invokeMx.runtimeReadyTime) + invokeFailure.ExtensionNames = execCtx.GetExtensionNames() } - os.Exit(1) -} - -func reportErrorAndRequestReset(doneFailMsg *interop.DoneFail, invokeID string, interopServer interop.Server, err error) { - if err == errResetReceived { - // errResetReceived is returned when execution flow was interrupted by the Reset message, - // hence this error deserves special handling and we yield to main receive loop to handle it - return + if execCtx.telemetryAPIEnabled { + invokeFailure.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) } - trySendDefaultErrorResponse(interopServer, invokeID, doneFailMsg.ErrorType, err) - - if err := interopServer.CommitResponse(); err != nil { - checkInteropError("Failed to commit error response: %s", err) - } + return invokeFailure +} - if err := interopServer.SendDoneFail(doneFailMsg); err != nil { - checkInteropError("Failed to send DONEFAIL: %s", err) +func generateInitFailureMsg(execCtx *rapidContext, err error) interop.InitFailure { + errorType, found := appctx.LoadFirstFatalError(execCtx.appCtx) + if !found { + errorType = fatalerror.Unknown } -} -func handleInitError(doneFailMsg *interop.DoneFail, execCtx *rapidContext, invokeID string, interopServer interop.Server, err error) { - if execCtx.standaloneMode { - reportErrorAndRequestReset(doneFailMsg, invokeID, interopServer, err) - return + initFailureMsg := interop.InitFailure{ + RequestReset: true, + ErrorType: errorType, + ErrorMessage: err, + RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), + NumActiveExtensions: execCtx.registrationService.CountAgents(), + Ack: make(chan struct{}), } - if !execCtx.HasActiveExtensions() { - // we don't expect Slicer to send RESET during INIT, that's why we Exit here - reportErrorAndExit(doneFailMsg, invokeID, interopServer, err) + if execCtx.telemetryAPIEnabled { + initFailureMsg.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) } - reportErrorAndRequestReset(doneFailMsg, invokeID, interopServer, err) + return initFailureMsg } -func handleInvokeError(doneFailMsg *interop.DoneFail, execCtx *rapidContext, invokeID string, interopServer interop.Server, err error) { - if execCtx.standaloneMode { - reportErrorAndRequestReset(doneFailMsg, invokeID, interopServer, err) +func handleInitError(execCtx *rapidContext, invokeID string, err error, initFailureResponse chan<- interop.InitFailure) { + log.WithError(err).WithField("InvokeID", invokeID).Error("Init failed") + initFailureMsg := generateInitFailureMsg(execCtx, err) + + if err == errResetReceived { + // errResetReceived is returned when execution flow was interrupted by the Reset message, + // hence this error deserves special handling and we yield to main receive loop to handle it + initFailureMsg.ResetReceived = true + initFailureResponse <- initFailureMsg + <-initFailureMsg.Ack return } - // Invoke with extensions disabled maintains behaviour parity with pre-extensions rapid - if !extensions.AreEnabled() { - reportErrorAndExit(doneFailMsg, invokeID, interopServer, err) + if !execCtx.HasActiveExtensions() && !execCtx.standaloneMode { + // different behaviour when no extensions are present, + // for compatibility with previous implementations + initFailureMsg.RequestReset = false + } else { + initFailureMsg.RequestReset = true } - reportErrorAndRequestReset(doneFailMsg, invokeID, interopServer, err) + initFailureResponse <- initFailureMsg + <-initFailureMsg.Ack } diff --git a/lambda/rapid/graceful_shutdown.go b/lambda/rapid/graceful_shutdown.go deleted file mode 100644 index 5ad1326..0000000 --- a/lambda/rapid/graceful_shutdown.go +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapid - -import ( - "syscall" - "time" - - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/metering" - "go.amzn.com/lambda/rapi/model" - "go.amzn.com/lambda/rapi/rendering" - - log "github.com/sirupsen/logrus" -) - -func sigkillProcessGroup(pid int, sigkilledPids map[int]bool) map[int]bool { - pgid, err := syscall.Getpgid(pid) - if err == nil { - syscall.Kill(-pgid, 9) // Negative pid sends signal to all in process group - } else { - syscall.Kill(pid, 9) - } - sigkilledPids[pid] = true - - return sigkilledPids -} - -func awaitSigkilledProcessesToExit(exitPidChan chan int, processesExited, sigkilledPidsToAwait map[int]bool) { - for pid := range processesExited { - delete(sigkilledPidsToAwait, pid) - } - - for len(sigkilledPidsToAwait) != 0 { - pid := <-exitPidChan - _, found := sigkilledPidsToAwait[pid] - if !found { - log.Warnf("Unexpected process %d exited while waiting for sigkilled processes to exit", pid) - } else { - delete(sigkilledPidsToAwait, pid) - } - } -} - -func gracefulShutdown(execCtx *rapidContext, watchdog *core.Watchdog, profiler *metering.ExtensionsResetDurationProfiler, deadlineNs int64, killAgents bool, reason string) { - watchdog.Mute() - defer watchdog.Unmute() - - if execCtx.registrationService.CountAgents() == 0 { - // We do not spend any compute time on runtime graceful shutdown if there are no agents - if runtime := execCtx.registrationService.GetRuntime(); runtime != nil && runtime.Pid != 0 { - sigkilledPids := sigkillProcessGroup(runtime.Pid, map[int]bool{}) - if execCtx.standaloneMode { - processesExited := map[int]bool{} - awaitSigkilledProcessesToExit(execCtx.exitPidChan, processesExited, sigkilledPids) - } - } - return - } - - mono := metering.Monotime() - - availableNs := deadlineNs - mono - - if availableNs < 0 { - log.Warnf("Deadline is in the past: %v, %v, %v", mono, deadlineNs, availableNs) - availableNs = 0 - } - - profiler.AvailableNs = availableNs - - start := time.Now() - profiler.Start() - - runtimeDeadline := start.Add(time.Duration(float64(availableNs) * runtimeDeadlineShare)) - agentsDeadline := start.Add(time.Duration(availableNs)) - - sigkilledPids := make(map[int]bool) // Track process ids that were sent sigkill - processesExited := make(map[int]bool) // Don't send sigkill to processes that exit out of order - - processesExited, sigkilledPids = shutdownRuntime(execCtx, start, runtimeDeadline, processesExited, sigkilledPids) - processesExited, sigkilledPids = shutdownAgents(execCtx, start, profiler, agentsDeadline, killAgents, reason, processesExited, sigkilledPids) - if execCtx.standaloneMode { - awaitSigkilledProcessesToExit(execCtx.exitPidChan, processesExited, sigkilledPids) - } - - profiler.Stop() -} - -func shutdownRuntime(execCtx *rapidContext, start time.Time, deadline time.Time, processesExited, sigkilledPids map[int]bool) (map[int]bool, map[int]bool) { - // If runtime is started: - // 1. SIGTERM and wait until timeout - // 2. SIGKILL on timeout - - log.Debug("shutdown runtime") - runtime := execCtx.registrationService.GetRuntime() - if runtime == nil || runtime.Pid == 0 { - log.Warn("Runtime not started") - return processesExited, sigkilledPids - } - - syscall.Kill(runtime.Pid, syscall.SIGTERM) - - runtimeTimeout := deadline.Sub(start) - runtimeTimer := time.NewTimer(runtimeTimeout) - - for { - select { - case pid := <-execCtx.exitPidChan: - processesExited[pid] = true - if pid == runtime.Pid { - log.Info("runtime exited") - return processesExited, sigkilledPids - } - - log.Warnf("Process %d exited unexpectedly", pid) - case <-runtimeTimer.C: - log.Warnf("Timeout: no SIGCHLD from Runtime after %d ms; dispatching SIGKILL to runtime process group", int64(runtimeTimeout/time.Millisecond)) - sigkilledPids = sigkillProcessGroup(runtime.Pid, sigkilledPids) - return processesExited, sigkilledPids - } - } -} - -func shutdownAgents(execCtx *rapidContext, start time.Time, profiler *metering.ExtensionsResetDurationProfiler, deadline time.Time, killAgents bool, reason string, processesExited, sigkilledPids map[int]bool) (map[int]bool, map[int]bool) { - // For each external agent, if agent is launched: - // 1. Send Shutdown event if subscribed for it, else send SIGKILL to process group - // 2. Wait for all Shutdown-subscribed agents to exit with timeout - // 3. Send SIGKILL to process group for Shutdown-subscribed agents on timeout - - log.Debug("shutdown agents") - execCtx.renderingService.SetRenderer( - &rendering.ShutdownRenderer{ - AgentEvent: model.AgentShutdownEvent{ - AgentEvent: &model.AgentEvent{ - EventType: "SHUTDOWN", - DeadlineMs: deadline.UnixNano() / (1000 * 1000), - }, - ShutdownReason: reason, - }, - }) - - pidsToShutdown := make(map[int]*core.ExternalAgent) - for _, a := range execCtx.registrationService.GetExternalAgents() { - if a.Pid == 0 { - log.Warnf("Agent %s failed not launched; skipping shutdown", a) - continue - } - if a.IsSubscribed(core.ShutdownEvent) { - pidsToShutdown[a.Pid] = a - a.Release() - } else { - if !processesExited[a.Pid] { - sigkilledPids = sigkillProcessGroup(a.Pid, sigkilledPids) - } - } - } - profiler.NumAgentsRegisteredForShutdown = len(pidsToShutdown) - - var timerChan <-chan time.Time // default timerChan - if killAgents { - timerChan = time.NewTimer(deadline.Sub(start)).C // timerChan with deadline - } - - timeoutExceeded := false - for !timeoutExceeded && len(pidsToShutdown) != 0 { - select { - case pid := <-execCtx.exitPidChan: - processesExited[pid] = true - a, found := pidsToShutdown[pid] - if !found { - log.Warnf("Process %d exited unexpectedly", pid) - } else { - if err := a.Exited(); err != nil { - log.Warnf("%s failed to transition to EXITED: %s (current state: %s)", a.String(), err, a.GetState().Name()) - } - delete(pidsToShutdown, pid) - } - case <-timerChan: - timeoutExceeded = true - } - } - - if len(pidsToShutdown) != 0 { - for pid, agent := range pidsToShutdown { - if err := agent.ShutdownFailed(); err != nil { - log.Warnf("%s failed to transition to ShutdownFailed: %s (current state: %s)", agent, err, agent.GetState().Name()) - } - log.Warnf("Killing agent %s which failed to shutdown", agent) - if !processesExited[pid] { - sigkilledPids = sigkillProcessGroup(pid, sigkilledPids) - } - } - } - - return processesExited, sigkilledPids -} diff --git a/lambda/rapid/sandbox.go b/lambda/rapid/sandbox.go index a5614b0..9259514 100644 --- a/lambda/rapid/sandbox.go +++ b/lambda/rapid/sandbox.go @@ -7,49 +7,43 @@ import ( "context" "fmt" "io" + "sync" + "time" "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/core" "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/logging" "go.amzn.com/lambda/metering" "go.amzn.com/lambda/rapi" "go.amzn.com/lambda/rapi/rendering" + supvmodel "go.amzn.com/lambda/supervisor/model" "go.amzn.com/lambda/telemetry" -) -type EnvironmentVariables interface { - AgentExecEnv() []string - RuntimeExecEnv() []string - SetHandler(handler string) - StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress string) - StoreEnvironmentVariablesFromInit(customerEnv map[string]string, - handler, awsKey, awsSecret, awsSession, funcName, funcVer string) - StoreEnvironmentVariablesFromInitForInitCaching(host string, port int, customerEnv map[string]string, handler, funcName, funcVer, token string) -} + log "github.com/sirupsen/logrus" +) type Sandbox struct { - EnableTelemetryAPI bool - StandaloneMode bool - Bootstrap Bootstrap - InteropServer interop.Server - Tracer telemetry.Tracer - LogsSubscriptionAPI telemetry.LogsSubscriptionAPI - LogsEgressAPI telemetry.LogsEgressAPI - Environment EnvironmentVariables - DebugTailLogger *logging.TailLogWriter - PlatformLogger logging.PlatformLogger - RuntimeStdoutWriter io.Writer - RuntimeStderrWriter io.Writer - PreLoadTimeNs int64 - Handler string - SignalCtx context.Context - EventsAPI telemetry.EventsAPI - InitCachingEnabled bool + EnableTelemetryAPI bool + StandaloneMode bool + InteropServer interop.Server + Tracer telemetry.Tracer + LogsSubscriptionAPI telemetry.SubscriptionAPI + TelemetrySubscriptionAPI telemetry.SubscriptionAPI + LogsEgressAPI telemetry.StdLogsEgressAPI + RuntimeStdoutWriter io.Writer + RuntimeStderrWriter io.Writer + PreLoadTimeNs int64 + Handler string + SignalCtx context.Context + EventsAPI telemetry.EventsAPI + InitCachingEnabled bool + Supervisor supvmodel.Supervisor + RuntimeAPIHost string + RuntimeAPIPort int } // Start is a public version of start() that exports only configurable parameters -func Start(s *Sandbox) { +func Start(s *Sandbox) (interop.RapidContext, interop.InternalStateGetter, string) { appCtx := appctx.NewApplicationContext() initFlow := core.NewInitFlowSynchronization() invokeFlow := core.NewInvokeFlowSynchronization() @@ -57,19 +51,18 @@ func Start(s *Sandbox) { renderingService := rendering.NewRenderingService() credentialsService := core.NewCredentialsService() - if s.StandaloneMode { - s.InteropServer.SetInternalStateGetter(registrationService.GetInternalStateDescriptor(appCtx)) - } - server := rapi.NewServer(RuntimeAPIHost, RuntimeAPIPort, appCtx, registrationService, renderingService, s.EnableTelemetryAPI, s.LogsSubscriptionAPI, s.InitCachingEnabled, credentialsService) - - postLoadTimeNs := metering.Monotime() + appctx.StoreInitType(appCtx, s.InitCachingEnabled) + server := rapi.NewServer(s.RuntimeAPIHost, s.RuntimeAPIPort, appCtx, registrationService, renderingService, s.EnableTelemetryAPI, s.LogsSubscriptionAPI, s.TelemetrySubscriptionAPI, credentialsService, s.EventsAPI) runtimeAPIAddr := fmt.Sprintf("%s:%d", server.Host(), server.Port()) - s.Environment.StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddr) + postLoadTimeNs := metering.Monotime() + + // TODO: pass this directly down to HTTP servers and handlers, instead of using + // global state to share the interop server implementation appctx.StoreInteropServer(appCtx, s.InteropServer) - start(s.SignalCtx, &rapidContext{ + execCtx := &rapidContext{ server: server, appCtx: appCtx, postLoadTimeNs: postLoadTimeNs, @@ -78,24 +71,86 @@ func Start(s *Sandbox) { invokeFlow: invokeFlow, registrationService: registrationService, renderingService: renderingService, - exitPidChan: make(chan int), - resetChan: make(chan *interop.Reset), credentialsService: credentialsService, - telemetryAPIEnabled: s.EnableTelemetryAPI, - logsSubscriptionAPI: s.LogsSubscriptionAPI, - logsEgressAPI: s.LogsEgressAPI, - bootstrap: s.Bootstrap, - interopServer: s.InteropServer, - xray: s.Tracer, - environment: s.Environment, - standaloneMode: s.StandaloneMode, - debugTailLogger: s.DebugTailLogger, - platformLogger: s.PlatformLogger, - runtimeStdoutWriter: s.RuntimeStdoutWriter, - runtimeStderrWriter: s.RuntimeStderrWriter, - preLoadTimeNs: s.PreLoadTimeNs, - eventsAPI: s.EventsAPI, - initCachingEnabled: s.InitCachingEnabled, - }) + telemetryAPIEnabled: s.EnableTelemetryAPI, + logsSubscriptionAPI: s.LogsSubscriptionAPI, + telemetrySubscriptionAPI: s.TelemetrySubscriptionAPI, + logsEgressAPI: s.LogsEgressAPI, + interopServer: s.InteropServer, + xray: s.Tracer, + standaloneMode: s.StandaloneMode, + preLoadTimeNs: s.PreLoadTimeNs, + eventsAPI: s.EventsAPI, + initCachingEnabled: s.InitCachingEnabled, + signalCtx: s.SignalCtx, + supervisor: s.Supervisor, + executionMutex: sync.Mutex{}, + shutdownContext: newShutdownContext(), + } + + // We call /ping on Supervisor before starting Rapid, since Rapid + // depends on Supervisor setting up networking dependencies + var startupErr error + for retries := 1; retries <= 5; retries++ { + if startupErr = s.Supervisor.Ping(); startupErr == nil { + break + } + // Retry timeout: 5s, same order-of-mag as test client PING retries + // TODO: revisit retry timeout, identify appropriate value for prod. + time.Sleep(1000 * time.Millisecond) + } + + if startupErr != nil { + log.Panicf("Application ping to Supervisor failed, terminating Rapid Startup: %s", startupErr) + } + + go start(s.SignalCtx, execCtx) + + return execCtx, registrationService.GetInternalStateDescriptor(appCtx), runtimeAPIAddr +} + +func (r *rapidContext) HandleInit(init *interop.Init, initStartedResponseChan chan<- interop.InitStarted, initSuccessResponseChan chan<- interop.InitSuccess, initFailureResponseChan chan<- interop.InitFailure) { + r.executionMutex.Lock() + defer r.executionMutex.Unlock() + handleInit(r, init, initStartedResponseChan, initSuccessResponseChan, initFailureResponseChan) +} + +func (r *rapidContext) HandleInvoke(invoke *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit) (interop.InvokeSuccess, *interop.InvokeFailure) { + r.executionMutex.Lock() + defer r.executionMutex.Unlock() + // Clear the context used by the last invok + r.appCtx.Delete(appctx.AppCtxInvokeErrorResponseKey) + return handleInvoke(r, invoke, sbInfoFromInit) +} + +func (r *rapidContext) HandleReset(reset *interop.Reset, invokeReceivedTime int64, InvokeResponseMetrics *interop.InvokeResponseMetrics) (interop.ResetSuccess, *interop.ResetFailure) { + // In the event of a Reset during init/invoke, CancelFlows cancels execution + // flows and return with the errResetReceived err - this error is special-cased + // and not handled by the init/invoke (unexpected) error handling functions + r.registrationService.CancelFlows(errResetReceived) + + // Wait until invoke error handling has returned before continuing execution + r.executionMutex.Lock() + defer r.executionMutex.Unlock() + + // Clear the context used by the last invoke, i.e. error message etc. + r.appCtx.Delete(appctx.AppCtxInvokeErrorResponseKey) + return handleReset(r, reset, invokeReceivedTime, InvokeResponseMetrics) +} + +func (r *rapidContext) HandleShutdown(shutdown *interop.Shutdown) interop.ShutdownSuccess { + // Wait until invoke error handling has returned before continuing execution + r.executionMutex.Lock() + defer r.executionMutex.Unlock() + // Shutdown doesn't cancel flows, so it can block forever + return handleShutdown(r, shutdown, standaloneShutdownReason) +} + +func (r *rapidContext) HandleRestore(restore *interop.Restore) error { + return handleRestore(r, restore) +} + +func (r *rapidContext) Clear() { + reinitialize(r) } diff --git a/lambda/rapid/shutdown.go b/lambda/rapid/shutdown.go new file mode 100644 index 0000000..fe23a9f --- /dev/null +++ b/lambda/rapid/shutdown.go @@ -0,0 +1,366 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package rapid implements synchronous even dispatch loop. +package rapid + +import ( + "fmt" + "sync" + "time" + + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/core" + "go.amzn.com/lambda/metering" + "go.amzn.com/lambda/rapi/model" + "go.amzn.com/lambda/rapi/rendering" + supvmodel "go.amzn.com/lambda/supervisor/model" + + log "github.com/sirupsen/logrus" +) + +const ( + // supervisor shutdown and kill operations block until the exit status of the + // interested process has been collected, or until the specified timeotuw + // expires (in which case the operation fails). + // Note that this timeout is mainly relevant when any of the domain + // processes are in uninterruptible sleep state (notable examples: syscall + // to read/write a newtorked driver) + // + // We set a non nil value for these timeouts so that RAPID doesn't block + // forever in one of the cases above. + supervisorBlockingMaxMillis = 9000 + runtimeDeadlineShare = 0.3 +) + +type shutdownContext struct { + // Adding a mutex around shuttingDown because there may be concurrent reads/writes. + // Because the code in shutdown() and the seperate go routine created in setupEventsWatcher() + // could be concurrently accessing the field shuttingDown. + shuttingDownMutex sync.Mutex + shuttingDown bool + agentsAwaitingExit map[string]*core.ExternalAgent + // Adding a mutex around runtimeDomainExited because there may be concurrent reads/writes. + // The first reason this can be caused is by different go routines reading/writing different keys. + // The second reason this can be caused is between the code shutting down the runtime/extensions and + // handleProcessExit in a separate go routine, reading and writing to the same key. Caused by + // unexpected exits. + runtimeDomainExitedMutex sync.Mutex + // used to synchronize on processes exits. We create the channel when a + // process is started and we close it upon exit notification from + // supervisor. Closing the channel is basically a persistent broadcast of process exit. + // We never write anything to the channels + runtimeDomainExited map[string]chan struct{} +} + +func newShutdownContext() *shutdownContext { + return &shutdownContext{ + shuttingDownMutex: sync.Mutex{}, + shuttingDown: false, + agentsAwaitingExit: make(map[string]*core.ExternalAgent), + runtimeDomainExited: make(map[string]chan struct{}), + runtimeDomainExitedMutex: sync.Mutex{}, + } +} + +func (s *shutdownContext) isShuttingDown() bool { + s.shuttingDownMutex.Lock() + defer s.shuttingDownMutex.Unlock() + return s.shuttingDown +} + +func (s *shutdownContext) setShuttingDown(value bool) { + s.shuttingDownMutex.Lock() + defer s.shuttingDownMutex.Unlock() + s.shuttingDown = value +} + +func (s *shutdownContext) handleProcessExit(termination supvmodel.ProcessTermination) { + + name := *termination.Name + agent, found := s.agentsAwaitingExit[name] + + // If it is an agent registered to receive a shutdown event. + if found { + log.Debugf("Handling termination for %s", name) + exitStatus := termination.Exited() + if exitStatus != nil && *exitStatus == 0 { + // If the agent exited by itself after receiving the shutdown event. + stateErr := agent.Exited() + if stateErr != nil { + log.Warnf("%s failed to transition to EXITED: %s (current state: %s)", agent.String(), stateErr, agent.GetState().Name()) + } + } else { + // If the agent did not exit by itself, had to be SIGKILLed (only in standalone mode). + stateErr := agent.ShutdownFailed() + if stateErr != nil { + log.Warnf("%s failed to transition to ShutdownFailed: %s (current state: %s)", agent, stateErr, agent.GetState().Name()) + } + } + } + + exitedChannel, found := s.getExitedChannel(name) + + if !found { + log.Panicf("Unable to find an exitedChannel for '%s', it should have been created just after it was execed.", name) + } + // we close the channel so that whoever is blocked on it + // or will try to block on it in the future unblocks immediately + close(exitedChannel) +} + +func (s *shutdownContext) getExitedChannel(name string) (chan struct{}, bool) { + s.runtimeDomainExitedMutex.Lock() + defer s.runtimeDomainExitedMutex.Unlock() + exitedChannel, found := s.runtimeDomainExited[name] + return exitedChannel, found +} + +func (s *shutdownContext) createExitedChannel(name string) { + s.runtimeDomainExitedMutex.Lock() + defer s.runtimeDomainExitedMutex.Unlock() + + _, found := s.runtimeDomainExited[name] + + if found { + log.Panicf("Tried to create an exited channel for '%s' but one already exists.", name) + } + s.runtimeDomainExited[name] = make(chan struct{}) +} + +// Blocks until all the processes in the runtime domain generation have exited. +// This helps us have a nice sync point on Shutdown where we know for sure that +// all the processes have exited and the state has been cleared. +// +// It is OK not to hold the lock because we know that this is called only during +// shutdown and nobody will start a new process during shutdown +func (s *shutdownContext) clearExitedChannel() { + s.runtimeDomainExitedMutex.Lock() + mapLen := len(s.runtimeDomainExited) + channels := make([]chan struct{}, 0, mapLen) + for _, v := range s.runtimeDomainExited { + channels = append(channels, v) + } + s.runtimeDomainExitedMutex.Unlock() + + for _, v := range channels { + <-v + } + + s.runtimeDomainExitedMutex.Lock() + s.runtimeDomainExited = make(map[string]chan struct{}, mapLen) + s.runtimeDomainExitedMutex.Unlock() +} + +func (s *shutdownContext) shutdownRuntime(execCtx *rapidContext, start time.Time, deadline time.Time) { + // If runtime is started: + // 1. SIGTERM and wait until timeout + // 2. SIGKILL on timeout + log.Debug("Shutting down the runtime.") + name := fmt.Sprintf("%s-%d", runtimeProcessName, execCtx.runtimeDomainGeneration) + exitedChannel, found := s.getExitedChannel(name) + + if found { + + err := execCtx.supervisor.Terminate(&supvmodel.TerminateRequest{ + Domain: RuntimeDomain, + Name: name, + }) + if err != nil { + // We are not reporting the error upstream because we will anyway + // shut the domain out at the end of the shutdown sequence + log.WithError(err).Warn("Failed sending Termination signal to runtime") + } + + runtimeTimeout := deadline.Sub(start) + log.Tracef("The runtime timeout is %v.", runtimeTimeout) + runtimeTimer := time.NewTimer(runtimeTimeout) + select { + case <-runtimeTimer.C: + log.Warnf("Timeout: The runtime did not exit after %d ms; Killing it.", int64(runtimeTimeout/time.Millisecond)) + supervisorBlockingMaxMillis := uint64(supervisorBlockingMaxMillis) + err = execCtx.supervisor.Kill(&supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Timeout: &supervisorBlockingMaxMillis, + }) + + if err != nil { + // We are not reporting the error upstream because we will anyway + // shut the domain out at the end of the shutdown sequence + log.WithError(err).Warn("Failed sending Kill signal to runtime") + } + case <-exitedChannel: + } + } else { + log.Warn("The runtime was not started.") + } + log.Debug("Shutdown the runtime.") +} + +func (s *shutdownContext) shutdownAgents(execCtx *rapidContext, start time.Time, deadline time.Time, reason string) { + // For each external agent, if agent is launched: + // 1. Send Shutdown event if subscribed for it, else send SIGKILL to process group + // 2. Wait for all Shutdown-subscribed agents to exit with timeout + // 3. Send SIGKILL to process group for Shutdown-subscribed agents on timeout + + log.Debug("Shutting down the agents.") + execCtx.renderingService.SetRenderer( + &rendering.ShutdownRenderer{ + AgentEvent: model.AgentShutdownEvent{ + AgentEvent: &model.AgentEvent{ + EventType: "SHUTDOWN", + DeadlineMs: deadline.UnixNano() / (1000 * 1000), + }, + ShutdownReason: reason, + }, + }) + + var wg sync.WaitGroup + + // clear agentsAwaitingExit from last shutdownAgents + s.agentsAwaitingExit = make(map[string]*core.ExternalAgent) + + for _, a := range execCtx.registrationService.GetExternalAgents() { + name := fmt.Sprintf("extension-%s-%d", a.Name, execCtx.runtimeDomainGeneration) + exitedChannel, found := s.getExitedChannel(name) + supervisorBlockingMaxMillis := uint64(supervisorBlockingMaxMillis) + + if !found { + log.Warnf("Agent %s failed to launch, therefore skipping shutting it down.", a) + continue + } + + wg.Add(1) + + if a.IsSubscribed(core.ShutdownEvent) { + log.Debugf("Agent %s is registered for the shutdown event.", a) + s.agentsAwaitingExit[name] = a + + go func(name string, agent *core.ExternalAgent) { + defer wg.Done() + + agent.Release() + + agentTimeout := deadline.Sub(start) + var agentTimeoutChan <-chan time.Time + if execCtx.standaloneMode { + agentTimeoutChan = time.NewTimer(agentTimeout).C + } + + select { + case <-agentTimeoutChan: + log.Warnf("Timeout: the agent %s did not exit after %d ms; Killing it.", name, int64(agentTimeout/time.Millisecond)) + err := execCtx.supervisor.Kill(&supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Timeout: &supervisorBlockingMaxMillis, + }) + if err != nil { + // We are not reporting the error upstream because we will anyway + // shut the domain out at the end of the shutdown sequence + log.WithError(err).Warn("Failed sending Kill signal to runtime") + } + case <-exitedChannel: + } + }(name, a) + } else { + log.Debugf("Agent %s is not registered for the shutdown event, so just killing it.", a) + + go func(name string) { + defer wg.Done() + + execCtx.supervisor.Kill(&supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Timeout: &supervisorBlockingMaxMillis, + }) + }(name) + } + } + + // Wait on the agents subscribed to the shutdown event to voluntary shutting down after receiving the shutdown event or be sigkilled. + // In addition to waiting on the agents not subscribed to the shutdown event being sigkilled. + wg.Wait() + log.Debug("Shutdown the agents.") +} + +func (s *shutdownContext) shutdown(execCtx *rapidContext, deadlineNs int64, reason string) (int64, bool, error) { + var err error + s.setShuttingDown(true) + defer s.setShuttingDown(false) + + // Fatal errors such as Runtime exit and Extension.Crash + // are ignored by the events watcher when shutting down + execCtx.appCtx.Delete(appctx.AppCtxFirstFatalErrorKey) + + runtimeDomainProfiler := &metering.ExtensionsResetDurationProfiler{} + supervisorBlockingMaxMillis := uint64(supervisorBlockingMaxMillis) + + // We do not spend any compute time on runtime graceful shutdown if there are no agents + if execCtx.registrationService.CountAgents() == 0 { + name := fmt.Sprintf("%s-%d", runtimeProcessName, execCtx.runtimeDomainGeneration) + + _, found := s.getExitedChannel(name) + + if found { + log.Debug("SIGKILLing the runtime as no agents are registered.") + err = execCtx.supervisor.Kill(&supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Timeout: &supervisorBlockingMaxMillis, + }) + if err != nil { + // We are not reporting the error upstream because we will anyway + // shut the domain out at the end of the shutdown sequence + log.WithError(err).Warn("Failed sending Kill signal to runtime") + } + } else { + log.Debugf("Could not find runtime process %s in processes map. Already exited/never started", name) + } + } else { + mono := metering.Monotime() + availableNs := deadlineNs - mono + + if availableNs < 0 { + log.Warnf("Deadline is in the past: %v, %v, %v", mono, deadlineNs, availableNs) + availableNs = 0 + } + + start := time.Now() + + runtimeDeadline := start.Add(time.Duration(float64(availableNs) * runtimeDeadlineShare)) + agentsDeadline := start.Add(time.Duration(availableNs)) + + runtimeDomainProfiler.AvailableNs = availableNs + runtimeDomainProfiler.Start() + + s.shutdownRuntime(execCtx, start, runtimeDeadline) + s.shutdownAgents(execCtx, start, agentsDeadline, reason) + + runtimeDomainProfiler.NumAgentsRegisteredForShutdown = len(s.agentsAwaitingExit) + } + log.Info("Stopping runtime domain") + err = execCtx.supervisor.Stop(&supvmodel.StopRequest{ + Domain: RuntimeDomain, + Timeout: &supervisorBlockingMaxMillis, + }) + if err != nil { + log.WithError(err).Error("Failed shutting runtime domain down") + } else { + log.Info("Waiting for runtime domain processes termination") + s.clearExitedChannel() + log.Info("Stopping operator domain") + err = execCtx.supervisor.Stop(&supvmodel.StopRequest{ + Domain: OperatorDomain, + Timeout: &supervisorBlockingMaxMillis, + }) + if err != nil { + log.WithError(err).Error("Failed shutting operator domain down") + } + } + + runtimeDomainProfiler.Stop() + extensionsRestMs, timeout := runtimeDomainProfiler.CalculateExtensionsResetMs() + return extensionsRestMs, timeout, err +} diff --git a/lambda/rapid/start.go b/lambda/rapid/start.go index 087ef13..76337af 100644 --- a/lambda/rapid/start.go +++ b/lambda/rapid/start.go @@ -7,9 +7,11 @@ package rapid import ( "context" "errors" - "io" + "fmt" "os" + "path" "strings" + "sync" "time" "go.amzn.com/lambda/agents" @@ -18,11 +20,11 @@ import ( "go.amzn.com/lambda/extensions" "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/logging" "go.amzn.com/lambda/metering" "go.amzn.com/lambda/rapi" + "go.amzn.com/lambda/rapi/model" "go.amzn.com/lambda/rapi/rendering" - "go.amzn.com/lambda/runtimecmd" + supvmodel "go.amzn.com/lambda/supervisor/model" "go.amzn.com/lambda/telemetry" "github.com/google/uuid" @@ -31,11 +33,11 @@ import ( ) const ( - RuntimeAPIHost = "127.0.0.1" - RuntimeAPIPort = 9001 + RuntimeDomain = "runtime" + OperatorDomain = "operator" defaultAgentLocation = "/opt/extensions" - runtimeDeadlineShare = 0.3 disableExtensionsFile = "/opt/disable-extensions-jwigqn8j" + runtimeProcessName = "runtime" ) const ( @@ -47,37 +49,38 @@ const ( var errResetReceived = errors.New("errResetReceived") type rapidContext struct { - bootstrap Bootstrap - interopServer interop.Server - server *rapi.Server - appCtx appctx.ApplicationContext - preLoadTimeNs int64 - postLoadTimeNs int64 - startRequest *interop.Start - initDone bool - initFlow core.InitFlowSynchronization - invokeFlow core.InvokeFlowSynchronization - registrationService core.RegistrationService - renderingService *rendering.EventRenderingService - telemetryAPIEnabled bool - logsSubscriptionAPI telemetry.LogsSubscriptionAPI - logsEgressAPI telemetry.LogsEgressAPI - xray telemetry.Tracer - exitPidChan chan int - resetChan chan *interop.Reset - environment EnvironmentVariables - standaloneMode bool - debugTailLogger *logging.TailLogWriter - platformLogger logging.PlatformLogger - runtimeStdoutWriter io.Writer - runtimeStderrWriter io.Writer - eventsAPI telemetry.EventsAPI - initCachingEnabled bool - credentialsService core.CredentialsService + interopServer interop.Server + server *rapi.Server + appCtx appctx.ApplicationContext + preLoadTimeNs int64 + postLoadTimeNs int64 + initDone bool + supervisor supvmodel.Supervisor + runtimeDomainGeneration uint32 + initFlow core.InitFlowSynchronization + invokeFlow core.InvokeFlowSynchronization + registrationService core.RegistrationService + renderingService *rendering.EventRenderingService + telemetryAPIEnabled bool + logsSubscriptionAPI telemetry.SubscriptionAPI + telemetrySubscriptionAPI telemetry.SubscriptionAPI + logsEgressAPI telemetry.StdLogsEgressAPI + xray telemetry.Tracer + standaloneMode bool + eventsAPI telemetry.EventsAPI + initCachingEnabled bool + credentialsService core.CredentialsService + signalCtx context.Context + executionMutex sync.Mutex + shutdownContext *shutdownContext } +// Validate interface compliance +var _ interop.RapidContext = (*rapidContext)(nil) + type invokeMetrics struct { - rendererMetrics rendering.InvokeRendererMetrics + rendererMetrics rendering.InvokeRendererMetrics + runtimeReadyTime int64 } @@ -102,7 +105,7 @@ func (c *rapidContext) GetExtensionNames() string { func logAgentsInitStatus(execCtx *rapidContext) { for _, agent := range execCtx.registrationService.AgentsInfo() { - execCtx.platformLogger.LogExtensionInitEvent(agent.Name, agent.State, agent.ErrorType, agent.Subscriptions) + execCtx.eventsAPI.SendExtensionInit(agent.Name, agent.State, agent.ErrorType, agent.Subscriptions) } } @@ -113,8 +116,7 @@ func agentLaunchError(agent *core.ExternalAgent, appCtx appctx.ApplicationContex appctx.StoreFirstFatalError(appCtx, fatalerror.AgentLaunchError) } -func doInitExtensions(execCtx *rapidContext, watchdog *core.Watchdog) error { - agentPaths := agents.ListExternalAgentPaths(defaultAgentLocation) +func doInitExtensions(domain string, agentPaths []string, execCtx *rapidContext, env interop.EnvironmentVariables) error { initFlow := execCtx.registrationService.InitFlow() // we don't bring it into the loop below because we don't want unnecessary broadcasts on agent gate @@ -123,38 +125,42 @@ func doInitExtensions(execCtx *rapidContext, watchdog *core.Watchdog) error { } for _, agentPath := range agentPaths { - env := execCtx.environment.AgentExecEnv() - - agentStdoutWriter, agentStderrWriter, err := execCtx.logsEgressAPI.GetExtensionSockets() + // Using path.Base(agentPath) not agentName because the agent name is contact, as standalone can get the internal state. + agent, err := execCtx.registrationService.CreateExternalAgent(path.Base(agentPath)) if err != nil { return err } - // Compose debug log writer with all log sinks. Debug log writer w - // will not write logs when disabled by invoke parameter - agentStdoutWriter = io.MultiWriter(execCtx.debugTailLogger, agentStdoutWriter) - agentStderrWriter = io.MultiWriter(execCtx.debugTailLogger, agentStderrWriter) + if execCtx.registrationService.CountAgents() > core.MaxAgentsAllowed { + agentLaunchError(agent, execCtx.appCtx, core.ErrTooManyExtensions) + return core.ErrTooManyExtensions + } - agentProc := agents.NewExternalAgentProcess(agentPath, env, agentStdoutWriter, agentStderrWriter) + env := env.AgentExecEnv() - agent, err := execCtx.registrationService.CreateExternalAgent(agentProc.Name()) + agentStdoutWriter, agentStderrWriter, err := execCtx.logsEgressAPI.GetExtensionSockets() if err != nil { return err } + agentName := fmt.Sprintf("extension-%s-%d", path.Base(agentPath), execCtx.runtimeDomainGeneration) - if execCtx.registrationService.CountAgents() > core.MaxAgentsAllowed { - agentLaunchError(agent, execCtx.appCtx, core.ErrTooManyExtensions) - return core.ErrTooManyExtensions - } + err = execCtx.supervisor.Exec(&supvmodel.ExecRequest{ + Domain: domain, + Name: agentName, + Path: agentPath, + Env: &env, + StdoutWriter: agentStdoutWriter, + StderrWriter: agentStderrWriter, + }) - if err := agentProc.Start(); err != nil { + if err != nil { agentLaunchError(agent, execCtx.appCtx, err) return err } - agent.Pid = watchdog.GoWait(&agentProc, fatalerror.AgentCrash) + execCtx.shutdownContext.createExitedChannel(agentName) } if err := initFlow.AwaitExternalAgentsRegistered(); err != nil { @@ -164,20 +170,154 @@ func doInitExtensions(execCtx *rapidContext, watchdog *core.Watchdog) error { return nil } -func doInit(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog) error { +func doRuntimeBootstrap(execCtx *rapidContext, sbInfoFromInit interop.SandboxInfoFromInit) ([]string, map[string]string, string, []*os.File, error) { + env := sbInfoFromInit.EnvironmentVariables + runtimeBootstrap := sbInfoFromInit.RuntimeBootstrap + bootstrapCmd, err := runtimeBootstrap.Cmd() + if err != nil { + if fatalError, formattedLog, hasError := runtimeBootstrap.CachedFatalError(err); hasError { + appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) + execCtx.eventsAPI.SendImageErrorLog(formattedLog) + } else { + appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidEntrypoint) + } + return []string{}, map[string]string{}, "", []*os.File{}, err + } + + bootstrapEnv := runtimeBootstrap.Env(env) + bootstrapCwd, err := runtimeBootstrap.Cwd() + if err != nil { + if fatalError, formattedLog, hasError := runtimeBootstrap.CachedFatalError(err); hasError { + appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) + execCtx.eventsAPI.SendImageErrorLog(formattedLog) + } else { + appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidWorkingDir) + } + return []string{}, map[string]string{}, "", []*os.File{}, err + } + + bootstrapExtraFiles := runtimeBootstrap.ExtraFiles() + + return bootstrapCmd, bootstrapEnv, bootstrapCwd, bootstrapExtraFiles, nil +} + +func (c *rapidContext) setupEventsWatcher(events <-chan supvmodel.Event) { + go func() { + for event := range events { + var err error = nil + log.Debugf("The events handler received the event %+v.", event) + if loss := event.Event.EventLoss(); loss != nil { + log.Panicf("Lost %d events from supervisor", *loss) + } + termination := event.Event.ProcessTerminated() + + // If we are not shutting down then we care if an unexpected exit happens. + if !c.shutdownContext.isShuttingDown() { + runtimeProcessName := fmt.Sprintf("%s-%d", runtimeProcessName, c.runtimeDomainGeneration) + + // If event from the runtime. + if *termination.Name == runtimeProcessName { + if termination.Success() { + err = fmt.Errorf("Runtime exited without providing a reason") + } else { + err = fmt.Errorf("Runtime exited with error: %s", termination.String()) + } + appctx.StoreFirstFatalError(c.appCtx, fatalerror.RuntimeExit) + } else { + if termination.Success() { + err = fmt.Errorf("exit code 0") + } else { + err = fmt.Errorf(termination.String()) + } + + appctx.StoreFirstFatalError(c.appCtx, fatalerror.AgentCrash) + } + + log.Warnf("Process %s exited: %+v", *termination.Name, termination) + } + + // At the moment we only get termination events. + // When their are other event types then we would need to be selective, + // about what we send to handleShutdownEvent(). + c.shutdownContext.handleProcessExit(*termination) + c.registrationService.CancelFlows(err) + } + }() +} + +func doOperatorDomainInit(ctx context.Context, execCtx *rapidContext, operatorDomainExtraConfig interop.DynamicDomainConfig) error { + events, err := execCtx.supervisor.Events() + if err != nil { + log.WithError(err).Panic("Could not get events stream from supervsior") + } + execCtx.setupEventsWatcher(events) + + log.Info("Configuring and starting Operator Domain") + conf := operatorDomainExtraConfig + err = execCtx.supervisor.Configure(&supvmodel.ConfigureRequest{ + Domain: OperatorDomain, + AdditionalStartHooks: conf.AdditionalStartHooks, + Mounts: conf.Mounts, + }) + + if err != nil { + log.WithError(err).Error("Failed to configure operator domain") + return err + } + + err = execCtx.supervisor.Start(&supvmodel.StartRequest{ + Domain: OperatorDomain, + }) + + if err != nil { + log.WithError(err).Error("Failed to start operator domain") + return err + } + + // we configure the runtime domain only once and not at + // every init phase (e.g., suppressed or reset). + err = execCtx.supervisor.Configure(&supvmodel.ConfigureRequest{ + Domain: RuntimeDomain, + }) + + if err != nil { + log.WithError(err).Error("Failed to configure operator domain") + return err + } + + return nil + +} + +func doRuntimeDomainInit(ctx context.Context, execCtx *rapidContext, sbInfoFromInit interop.SandboxInfoFromInit) error { execCtx.xray.RecordInitStartTime() defer execCtx.xray.RecordInitEndTime() - if extensions.AreEnabled() { - defer func() { + defer func() { + if extensions.AreEnabled() { logAgentsInitStatus(execCtx) - }() + } + }() + + log.Info("Starting runtime domain") + err := execCtx.supervisor.Start(&supvmodel.StartRequest{ + Domain: RuntimeDomain, + }) + if err != nil { + log.WithError(err).Panic("Failed configuring runtime domain") + } + execCtx.runtimeDomainGeneration++ - if err := doInitExtensions(execCtx, watchdog); err != nil { + if extensions.AreEnabled() { + runtimeExtensions := agents.ListExternalAgentPaths(defaultAgentLocation, + execCtx.supervisor.RuntimeConfig.RootPath) + if err := doInitExtensions(RuntimeDomain, runtimeExtensions, execCtx, sbInfoFromInit.EnvironmentVariables); err != nil { return err } } + appctx.StoreSandboxType(execCtx.appCtx, sbInfoFromInit.SandboxType) + initFlow := execCtx.registrationService.InitFlow() // Runtime state machine @@ -188,56 +328,66 @@ func doInit(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog) // runtime is implicitly subscribed for certain lifecycle events. log.Debug("Preregister runtime") registrationService := execCtx.registrationService - if err := registrationService.PreregisterRuntime(runtime); err != nil { + err = registrationService.PreregisterRuntime(runtime) + + if err != nil { return err } - bootstrap := execCtx.bootstrap - bootstrapCmd, err := bootstrap.Cmd() + bootstrapCmd, bootstrapEnv, bootstrapCwd, bootstrapExtraFiles, err := doRuntimeBootstrap(execCtx, sbInfoFromInit) + if err != nil { - if fatalError, formattedLog, hasError := bootstrap.CachedFatalError(err); hasError { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.platformLogger.Printf("%s", formattedLog) - } else { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidEntrypoint) - } return err } - bootstrapEnv := bootstrap.Env(execCtx.environment) - bootstrapCwd, err := bootstrap.Cwd() + runtimeStdoutWriter, runtimeStderrWriter, err := execCtx.logsEgressAPI.GetRuntimeSockets() + if err != nil { - if fatalError, formattedLog, hasError := bootstrap.CachedFatalError(err); hasError { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.platformLogger.Printf("%s", formattedLog) - } else { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidWorkingDir) - } return err } - bootstrapExtraFiles := bootstrap.ExtraFiles() - runtimeCmd := runtimecmd.NewCustomRuntimeCmd(ctx, bootstrapCmd, bootstrapCwd, bootstrapEnv, execCtx.runtimeStdoutWriter, execCtx.runtimeStderrWriter, bootstrapExtraFiles) - log.Debug("Start runtime") - err = runtimeCmd.Start() + checkCredentials(execCtx, bootstrapEnv) + name := fmt.Sprintf("%s-%d", runtimeProcessName, execCtx.runtimeDomainGeneration) + err = execCtx.supervisor.Exec(&supvmodel.ExecRequest{ + Domain: RuntimeDomain, + Name: name, + Cwd: &bootstrapCwd, + Path: bootstrapCmd[0], + Args: bootstrapCmd[1:], + Env: &bootstrapEnv, + StdoutWriter: runtimeStdoutWriter, + StderrWriter: runtimeStderrWriter, + ExtraFiles: &bootstrapExtraFiles, + }) + + runtimeDoneStatus := telemetry.RuntimeDoneSuccess + + defer func() { + sendInitRuntimeDoneLogEvent(execCtx, sbInfoFromInit.SandboxType, runtimeDoneStatus) + }() + if err != nil { - if fatalError, formattedLog, hasError := bootstrap.CachedFatalError(err); hasError { + if fatalError, formattedLog, hasError := sbInfoFromInit.RuntimeBootstrap.CachedFatalError(err); hasError { appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.platformLogger.Printf("%s", formattedLog) + execCtx.eventsAPI.SendImageErrorLog(formattedLog) } else { appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidEntrypoint) } + runtimeDoneStatus = telemetry.RuntimeDoneFailure return err } - registrationService.GetRuntime().Pid = watchdog.GoWait(runtimeCmd, fatalerror.RuntimeExit) + execCtx.shutdownContext.createExitedChannel(name) - if err := initFlow.AwaitRuntimeReady(); err != nil { + if err := initFlow.AwaitRuntimeRestoreReady(); err != nil { + runtimeDoneStatus = telemetry.RuntimeDoneFailure return err } + runtimeDoneStatus = telemetry.RuntimeDoneSuccess + // Registration phase finished for agents - no more agents can be registered with the system registrationService.TurnOff() if extensions.AreEnabled() { @@ -253,22 +403,17 @@ func doInit(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog) // Logs API subscription phase finished for agents - no more agents can be subscribed to the Logs API if execCtx.telemetryAPIEnabled { execCtx.logsSubscriptionAPI.TurnOff() + execCtx.telemetrySubscriptionAPI.TurnOff() } execCtx.initDone = true + return nil } -func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog, invokeRequest *interop.Invoke, mx *invokeMetrics) error { +func doInvoke(ctx context.Context, execCtx *rapidContext, invokeRequest *interop.Invoke, mx *invokeMetrics, sbInfoFromInit interop.SandboxInfoFromInit) error { execCtx.eventsAPI.SetCurrentRequestID(invokeRequest.ID) appCtx := execCtx.appCtx - appctx.StoreErrorResponse(appCtx, nil) - - if invokeRequest.NeedDebugLogs { - execCtx.debugTailLogger.Enable() - } else { - execCtx.debugTailLogger.Disable() - } xray := execCtx.xray xray.Configure(invokeRequest) @@ -277,11 +422,11 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdo if !execCtx.initDone { // do inline init if err := xray.CaptureInitSubsegment(ctx, func(ctx context.Context) error { - return doInit(ctx, execCtx, watchdog) + return doRuntimeDomainInit(ctx, execCtx, sbInfoFromInit) }); err != nil { return err } - } else if execCtx.startRequest.SandboxType != interop.SandboxPreWarmed { + } else if sbInfoFromInit.SandboxType != interop.SandboxPreWarmed { xray.SendInitSubsegmentWithRecordedTimesOnce(ctx) } @@ -317,16 +462,20 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdo if extensions.AreEnabled() { log.Debug("Release agents conditions") for _, agent := range extAgents { + //TODO handle Supervisors listening channel agent.Release() } for _, agent := range intAgents { + //TODO handle Supervisors listening channel agent.Release() } } log.Debug("Release runtime condition") + //TODO handle Supervisors listening channel runtime.Release() log.Debug("Await runtime response") + //TODO handle Supervisors listening channel return invokeFlow.AwaitRuntimeResponse() })); err != nil { return err @@ -335,12 +484,21 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdo // Runtime overhead if err := xray.CaptureOverheadSubsegment(ctx, func(ctx context.Context) error { log.Debug("Await runtime ready") + //TODO handle Supervisors listening channel return invokeFlow.AwaitRuntimeReady() }); err != nil { return err } mx.runtimeReadyTime = metering.Monotime() - if err := execCtx.eventsAPI.SendRuntimeDone("success"); err != nil { + + runtimeDoneEventData := telemetry.InvokeRuntimeDoneData{ + Status: telemetry.RuntimeDoneSuccess, + Metrics: telemetry.GetRuntimeDoneInvokeMetrics(invokeRequest.InvokeReceivedTime, invokeRequest.InvokeResponseMetrics, mx.runtimeReadyTime), + InternalMetrics: invokeRequest.InvokeResponseMetrics, + Tracing: telemetry.BuildTracingCtx(model.XRayTracingType, invokeRequest.TraceID, invokeRequest.LambdaSegmentID), + Spans: telemetry.GetRuntimeDoneSpans(invokeRequest.InvokeReceivedTime, invokeRequest.InvokeResponseMetrics), + } + if err := execCtx.eventsAPI.SendRuntimeDone(runtimeDoneEventData); err != nil { log.Errorf("Failed to send RUNDONE: %s", err) } @@ -348,6 +506,7 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdo if execCtx.HasActiveExtensions() { execCtx.interopServer.SendRuntimeReady() log.Debug("Await agents ready") + //TODO handle Supervisors listening channel if err := invokeFlow.AwaitAgentsReady(); err != nil { log.Warnf("AwaitAgentsReady() = %s", err) return err @@ -364,177 +523,148 @@ func extensionsDisabledByLayer() bool { return err == nil } -// acceptStartRequest is a second initialization phase, performed after receiving START +// acceptInitRequest is a second initialization phase, performed after receiving START // initialized entities: _HANDLER, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN -func (c *rapidContext) acceptStartRequest(startRequest *interop.Start) { - c.startRequest = startRequest - c.environment.StoreEnvironmentVariablesFromInit( - startRequest.CustomerEnvironmentVariables, - startRequest.Handler, - startRequest.AwsKey, - startRequest.AwsSecret, - startRequest.AwsSession, - startRequest.FunctionName, - startRequest.FunctionVersion) +func (c *rapidContext) acceptInitRequest(initRequest *interop.Init) *interop.Init { + initRequest.EnvironmentVariables.StoreEnvironmentVariablesFromInit( + initRequest.CustomerEnvironmentVariables, + initRequest.Handler, + initRequest.AwsKey, + initRequest.AwsSecret, + initRequest.AwsSession, + initRequest.FunctionName, + initRequest.FunctionVersion) c.registrationService.SetFunctionMetadata(core.FunctionMetadata{ - FunctionName: startRequest.FunctionName, - FunctionVersion: startRequest.FunctionVersion, - Handler: startRequest.Handler, + FunctionName: initRequest.FunctionName, + FunctionVersion: initRequest.FunctionVersion, + Handler: initRequest.Handler, + RuntimeInfo: initRequest.RuntimeInfo, }) if extensionsDisabledByLayer() { extensions.Disable() } + + return initRequest } -func (c *rapidContext) acceptStartRequestForInitCaching(startRequest *interop.Start) error { +func (c *rapidContext) acceptInitRequestForInitCaching(initRequest *interop.Init) (*interop.Init, error) { log.Info("Configure environment for Init Caching.") - c.startRequest = startRequest randomUUID, err := uuid.NewRandom() if err != nil { - return err + return initRequest, err } initCachingToken := randomUUID.String() - c.environment.StoreEnvironmentVariablesFromInitForInitCaching( - RuntimeAPIHost, - RuntimeAPIPort, - startRequest.CustomerEnvironmentVariables, - startRequest.Handler, - startRequest.FunctionName, - startRequest.FunctionVersion, + initRequest.EnvironmentVariables.StoreEnvironmentVariablesFromInitForInitCaching( + c.server.Host(), + c.server.Port(), + initRequest.CustomerEnvironmentVariables, + initRequest.Handler, + initRequest.FunctionName, + initRequest.FunctionVersion, initCachingToken) c.registrationService.SetFunctionMetadata(core.FunctionMetadata{ - FunctionName: startRequest.FunctionName, - FunctionVersion: startRequest.FunctionVersion, - Handler: startRequest.Handler, + FunctionName: initRequest.FunctionName, + FunctionVersion: initRequest.FunctionVersion, + Handler: initRequest.Handler, }) - c.credentialsService.SetCredentials(initCachingToken, startRequest.AwsKey, startRequest.AwsSecret, startRequest.AwsSession) + c.credentialsService.SetCredentials(initCachingToken, initRequest.AwsKey, initRequest.AwsSecret, initRequest.AwsSession, initRequest.CredentialsExpiry) if extensionsDisabledByLayer() { extensions.Disable() } - return nil + return initRequest, nil } -func handleStart(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog, startRequest *interop.Start) { +func handleInit(execCtx *rapidContext, initRequest *interop.Init, + initStartedResponse chan<- interop.InitStarted, + initSuccessResponse chan<- interop.InitSuccess, + initFailureResponse chan<- interop.InitFailure) { + ctx := execCtx.signalCtx + if execCtx.initCachingEnabled { - if err := execCtx.acceptStartRequestForInitCaching(startRequest); err != nil { - handleStartError(execCtx, startRequest.InvokeID, startRequest.CorrelationID, err) + var err error + if initRequest, err = execCtx.acceptInitRequestForInitCaching(initRequest); err != nil { + // TODO: call handleInitError only after sending the RUNNING, since + // Slicer will fail receiving DONEFAIL here as it is expecting RUNNING + handleInitError(execCtx, initRequest.InvokeID, err, initFailureResponse) return } - - execCtx.credentialsService.UnblockService() - defer execCtx.credentialsService.BlockService() } else { - execCtx.acceptStartRequest(startRequest) + initRequest = execCtx.acceptInitRequest(initRequest) } - interopServer, appCtx := execCtx.interopServer, execCtx.appCtx - - if err := interopServer.SendRunning(&interop.Running{ + initStartedMsg := interop.InitStarted{ PreLoadTimeNs: execCtx.preLoadTimeNs, PostLoadTimeNs: execCtx.postLoadTimeNs, WaitStartTimeNs: execCtx.postLoadTimeNs, WaitEndTimeNs: metering.Monotime(), ExtensionsEnabled: extensions.AreEnabled(), - }); err != nil { - log.Panic(err) + Ack: make(chan struct{}), } - if !startRequest.SuppressInit { - if err := doInit(ctx, execCtx, watchdog); err != nil { - handleStartError(execCtx, startRequest.InvokeID, startRequest.CorrelationID, err) - return - } - } + initStartedResponse <- initStartedMsg + <-initStartedMsg.Ack - doneMsg := &interop.Done{ - CorrelationID: startRequest.CorrelationID, - Meta: interop.DoneMetadata{ - RuntimeRelease: appctx.GetRuntimeRelease(appCtx), - NumActiveExtensions: execCtx.registrationService.CountAgents(), - ExtensionNames: execCtx.GetExtensionNames(), - }, - } - if execCtx.telemetryAPIEnabled { - doneMsg.Meta.LogsAPIMetrics = execCtx.logsSubscriptionAPI.FlushMetrics() - } - if err := interopServer.SendDone(doneMsg); err != nil { - log.Panic(err) - } - - if err := interopServer.StartAcceptingDirectInvokes(); err != nil { - log.Panic(err) - } -} - -func handleStartError(execCtx *rapidContext, invokeID string, correlationID string, err error) { - log.WithError(err).WithField("InvokeID", invokeID).Error("Init failed") - doneFailMsg := generateDoneFail(execCtx, correlationID, nil, 0) - handleInitError(doneFailMsg, execCtx, invokeID, execCtx.interopServer, err) -} - -func generateDoneFail(execCtx *rapidContext, correlationID string, invokeMx *invokeMetrics, invokeReceivedTime int64) *interop.DoneFail { - errorType, found := appctx.LoadFirstFatalError(execCtx.appCtx) - if !found { - errorType = fatalerror.Unknown + // Operator domain init happens only once, it's never suppressed, + // and it's terminal in case of failures + if err := doOperatorDomainInit(ctx, execCtx, initRequest.OperatorDomainExtraConfig); err != nil { + // TODO: I believe we need to handle this specially, because we want + // to consider any failure here as terminal + handleInitError(execCtx, initRequest.InvokeID, err, initFailureResponse) + return } - doneFailMsg := &interop.DoneFail{ - ErrorType: errorType, - CorrelationID: correlationID, // required for standalone mode - Meta: interop.DoneMetadata{ - RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), - NumActiveExtensions: execCtx.registrationService.CountAgents(), - InvokeReceivedTime: invokeReceivedTime, - }, + if !initRequest.SuppressInit { + // doRuntimeDomainInit() is used in both init/invoke, so the signature requires sbInfo arg + sbInfo := interop.SandboxInfoFromInit{ + EnvironmentVariables: initRequest.EnvironmentVariables, + SandboxType: initRequest.SandboxType, + RuntimeBootstrap: initRequest.Bootstrap, + } + if err := doRuntimeDomainInit(ctx, execCtx, sbInfo); err != nil { + handleInitError(execCtx, initRequest.InvokeID, err, initFailureResponse) + return + } } - if invokeMx != nil { - doneFailMsg.Meta.InvokeRequestReadTimeNs = invokeMx.rendererMetrics.ReadTime.Nanoseconds() - doneFailMsg.Meta.InvokeRequestSizeBytes = int64(invokeMx.rendererMetrics.SizeBytes) - doneFailMsg.Meta.RuntimeReadyTime = int64(invokeMx.runtimeReadyTime) - doneFailMsg.Meta.ExtensionNames = execCtx.GetExtensionNames() + initSuccessMsg := interop.InitSuccess{ + RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), + NumActiveExtensions: execCtx.registrationService.CountAgents(), + ExtensionNames: execCtx.GetExtensionNames(), + Ack: make(chan struct{}), } if execCtx.telemetryAPIEnabled { - doneFailMsg.Meta.LogsAPIMetrics = execCtx.logsSubscriptionAPI.FlushMetrics() + initSuccessMsg.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) } - return doneFailMsg + initSuccessResponse <- initSuccessMsg + <-initSuccessMsg.Ack } -func handleInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Watchdog, invokeRequest *interop.Invoke) { - interopServer, appCtx := execCtx.interopServer, execCtx.appCtx - +func handleInvoke(execCtx *rapidContext, invokeRequest *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit) (interop.InvokeSuccess, *interop.InvokeFailure) { + ctx := execCtx.signalCtx invokeMx := invokeMetrics{} - if invokeRequest.ResyncState.IsResyncReceived { - err := execCtx.credentialsService.UpdateCredentials(invokeRequest.ResyncState.AwsKey, invokeRequest.ResyncState.AwsSecret, invokeRequest.ResyncState.AwsSession) - execCtx.credentialsService.UnblockService() - - if err != nil { - log.WithError(err).WithField("InvokeID", invokeRequest.ID).Error("Resync for Invoke failed") - doneFailMsg := generateDoneFail(execCtx, invokeRequest.CorrelationID, &invokeMx, invokeRequest.InvokeReceivedTime) - handleInvokeError(doneFailMsg, execCtx, invokeRequest.ID, interopServer, err) - } - } - - if err := doInvoke(ctx, execCtx, watchdog, invokeRequest, &invokeMx); err != nil { + if err := doInvoke(ctx, execCtx, invokeRequest, &invokeMx, sbInfoFromInit); err != nil { log.WithError(err).WithField("InvokeID", invokeRequest.ID).Error("Invoke failed") - doneFailMsg := generateDoneFail(execCtx, invokeRequest.CorrelationID, &invokeMx, invokeRequest.InvokeReceivedTime) - handleInvokeError(doneFailMsg, execCtx, invokeRequest.ID, interopServer, err) - return - } + invokeFailure := handleInvokeError(execCtx, invokeRequest, &invokeMx, err) - if err := execCtx.interopServer.CommitResponse(); err != nil { - log.Panic(err) + if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { + invokeFailure.ResponseMetrics = interop.ResponseMetrics{ + RuntimeTimeThrottledMs: invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond), + RuntimeProducedBytes: invokeRequest.InvokeResponseMetrics.ProducedBytes, + RuntimeOutboundThroughputBps: invokeRequest.InvokeResponseMetrics.OutboundThroughputBps, + } + } + return interop.InvokeSuccess{}, invokeFailure } var invokeCompletionTimeNs int64 @@ -542,30 +672,35 @@ func handleInvoke(ctx context.Context, execCtx *rapidContext, watchdog *core.Wat invokeCompletionTimeNs = time.Now().UnixNano() - responseTimeNs } - doneMsg := &interop.Done{ - CorrelationID: invokeRequest.CorrelationID, - Meta: interop.DoneMetadata{ - RuntimeRelease: appctx.GetRuntimeRelease(appCtx), - NumActiveExtensions: execCtx.registrationService.CountAgents(), - ExtensionNames: execCtx.GetExtensionNames(), + invokeSuccessMsg := interop.InvokeSuccess{ + RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), + NumActiveExtensions: execCtx.registrationService.CountAgents(), + ExtensionNames: execCtx.GetExtensionNames(), + InvokeMetrics: interop.InvokeMetrics{ InvokeRequestReadTimeNs: invokeMx.rendererMetrics.ReadTime.Nanoseconds(), InvokeRequestSizeBytes: int64(invokeMx.rendererMetrics.SizeBytes), - InvokeCompletionTimeNs: invokeCompletionTimeNs, - InvokeReceivedTime: invokeRequest.InvokeReceivedTime, RuntimeReadyTime: invokeMx.runtimeReadyTime, }, + InvokeCompletionTimeNs: invokeCompletionTimeNs, + InvokeReceivedTime: invokeRequest.InvokeReceivedTime, } - if execCtx.telemetryAPIEnabled { - doneMsg.Meta.LogsAPIMetrics = execCtx.logsSubscriptionAPI.FlushMetrics() + + if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { + invokeSuccessMsg.ResponseMetrics = interop.ResponseMetrics{ + RuntimeTimeThrottledMs: invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond), + RuntimeProducedBytes: invokeRequest.InvokeResponseMetrics.ProducedBytes, + RuntimeOutboundThroughputBps: invokeRequest.InvokeResponseMetrics.OutboundThroughputBps, + } } - if err := interopServer.SendDone(doneMsg); err != nil { - log.Panic(err) + if execCtx.telemetryAPIEnabled { + invokeSuccessMsg.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) } + + return invokeSuccessMsg, nil } -func reinitialize(execCtx *rapidContext, watchdog *core.Watchdog) { - execCtx.interopServer.Clear() +func reinitialize(execCtx *rapidContext) { execCtx.appCtx.Delete(appctx.AppCtxInvokeErrorResponseKey) execCtx.appCtx.Delete(appctx.AppCtxRuntimeReleaseKey) execCtx.appCtx.Delete(appctx.AppCtxFirstFatalErrorKey) @@ -576,90 +711,125 @@ func reinitialize(execCtx *rapidContext, watchdog *core.Watchdog) { execCtx.invokeFlow.Clear() if execCtx.telemetryAPIEnabled { execCtx.logsSubscriptionAPI.Clear() + execCtx.telemetrySubscriptionAPI.Clear() } - watchdog.Clear() -} - -func blockForever() { - select {} } // handle notification of reset -func handleReset(execCtx *rapidContext, watchdog *core.Watchdog, reset *interop.Reset) { - log.Warnf("Reset initiated: %s", reset.Reason) - if execCtx.initCachingEnabled { - execCtx.credentialsService.UnblockService() +func handleReset(execCtx *rapidContext, resetEvent *interop.Reset, invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) (interop.ResetSuccess, *interop.ResetFailure) { + log.Warnf("Reset initiated: %s", resetEvent.Reason) + + // Only send RuntimeDone event if we get a reset during an Invoke + if resetEvent.Reason == "failure" || resetEvent.Reason == "timeout" { + runtimeDoneEventData := telemetry.InvokeRuntimeDoneData{ + Status: resetEvent.Reason, + InternalMetrics: invokeResponseMetrics, + Metrics: telemetry.GetRuntimeDoneInvokeMetrics(invokeReceivedTime, invokeResponseMetrics, metering.Monotime()), + Tracing: telemetry.BuildTracingCtx(model.XRayTracingType, resetEvent.TraceID, resetEvent.LambdaSegmentID), + Spans: telemetry.GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics), + } + if err := execCtx.eventsAPI.SendRuntimeDone(runtimeDoneEventData); err != nil { + log.Errorf("Failed to send RUNDONE: %s", err) + } } - if err := execCtx.eventsAPI.SendRuntimeDone(reset.Reason); err != nil { - log.Errorf("Failed to send RUNDONE: %s", err) + extensionsResetMs, resetTimeout, _ := execCtx.shutdownContext.shutdown(execCtx, resetEvent.DeadlineNs, resetEvent.Reason) + + log.Info("Starting runtime domain") + err := execCtx.supervisor.Start(&supvmodel.StartRequest{ + Domain: RuntimeDomain, + }) + if err != nil { + log.WithError(err).Panic("Failed booting runtime domain") } + execCtx.runtimeDomainGeneration++ - profiler := metering.ExtensionsResetDurationProfiler{} - gracefulShutdown(execCtx, watchdog, &profiler, reset.DeadlineNs, execCtx.standaloneMode, reset.Reason) + // Only used by standalone for more indepth assertions. + var fatalErrorType fatalerror.ErrorType - extensionsResetMs, resetTimeout := profiler.CalculateExtensionsResetMs() + if execCtx.standaloneMode { + fatalErrorType, _ = appctx.LoadFirstFatalError(execCtx.appCtx) + } - meta := interop.DoneMetadata{ - ExtensionsResetMs: extensionsResetMs, + var responseMetrics interop.ResponseMetrics + if resetEvent.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(resetEvent.InvokeResponseMetrics) { + responseMetrics.RuntimeTimeThrottledMs = resetEvent.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) + responseMetrics.RuntimeProducedBytes = resetEvent.InvokeResponseMetrics.ProducedBytes + responseMetrics.RuntimeOutboundThroughputBps = resetEvent.InvokeResponseMetrics.OutboundThroughputBps } - if !execCtx.standaloneMode { - // GIRP interopServer implementation sends GIRP RSTFAIL/RSTDONE - if resetTimeout { - // TODO: DoneFail must contain a reset timeout ErrorType for rapid local to distinguish errors - doneFail := &interop.DoneFail{CorrelationID: reset.CorrelationID, Meta: meta} - if err := execCtx.interopServer.SendDoneFail(doneFail); err != nil { - log.Panicf("Failed to SendDoneFail: %s", err) - } - } else { - done := &interop.Done{CorrelationID: reset.CorrelationID, Meta: meta} - if err := execCtx.interopServer.SendDone(done); err != nil { - log.Panicf("Failed to SendDone: %s", err) - } + if resetTimeout { + return interop.ResetSuccess{}, &interop.ResetFailure{ + ExtensionsResetMs: extensionsResetMs, + ErrorType: fatalErrorType, + ResponseMetrics: responseMetrics, } - - os.Exit(0) } - reinitialize(execCtx, watchdog) + return interop.ResetSuccess{ + ExtensionsResetMs: extensionsResetMs, + ErrorType: fatalErrorType, + ResponseMetrics: responseMetrics, + }, nil +} + +// handle notification of shutdown +func handleShutdown(execCtx *rapidContext, shutdownEvent *interop.Shutdown, reason string) interop.ShutdownSuccess { + log.Warnf("Shutdown initiated: %s", reason) + // TODO Handle shutdown error + _, _, _ = execCtx.shutdownContext.shutdown(execCtx, shutdownEvent.DeadlineNs, reason) - fatalErrorType, _ := appctx.LoadFirstFatalError(execCtx.appCtx) + // Only used by standalone for more indepth assertions. + var fatalErrorType fatalerror.ErrorType - if resetTimeout { - doneFail := &interop.DoneFail{CorrelationID: reset.CorrelationID, ErrorType: fatalErrorType, Meta: meta} - if err := execCtx.interopServer.SendDoneFail(doneFail); err != nil { - log.Panicf("Failed to SendDoneFail: %s", err) - } - } else { - done := &interop.Done{CorrelationID: reset.CorrelationID, ErrorType: fatalErrorType, Meta: meta} - if err := execCtx.interopServer.SendDone(done); err != nil { - log.Panicf("Failed to SendDone: %s", err) - } + if execCtx.standaloneMode { + fatalErrorType, _ = appctx.LoadFirstFatalError(execCtx.appCtx) } + + return interop.ShutdownSuccess{ErrorType: fatalErrorType} } -// handle notification of shutdown -func handleShutdown(execCtx *rapidContext, watchdog *core.Watchdog, shutdown *interop.Shutdown, reason string) { - log.Warnf("Shutdown initiated") +func handleRestore(execCtx *rapidContext, restore *interop.Restore) error { + err := execCtx.credentialsService.UpdateCredentials(restore.AwsKey, restore.AwsSecret, restore.AwsSession, restore.CredentialsExpiry) + restoreStatus := telemetry.RuntimeDoneSuccess + + defer func() { + sendRestoreRuntimeDoneLogEvent(execCtx, restoreStatus) + }() + + if err != nil { + return fmt.Errorf("error when updating credentials: %s", err) + } + renderer := rendering.NewRestoreRenderer() + execCtx.renderingService.SetRenderer(renderer) + + registrationService := execCtx.registrationService + runtime := registrationService.GetRuntime() + + // If runtime has not called /restore/next then just return + // instead of releasing the Runtime since there is no need to release. + // Then the runtime should be released only during Invoke + if runtime.GetState() != runtime.RuntimeRestoreReadyState { + restoreStatus = telemetry.RuntimeDoneSuccess + log.Infof("Runtime is in state: %s just returning", runtime.GetState().Name()) + return nil + } - gracefulShutdown(execCtx, watchdog, &metering.ExtensionsResetDurationProfiler{}, shutdown.DeadlineNs, true, reason) + runtime.Release() - fatalErrorType, _ := appctx.LoadFirstFatalError(execCtx.appCtx) + initFlow := execCtx.initFlow + err = initFlow.AwaitRuntimeReady() - if err := execCtx.interopServer.SendDone(&interop.Done{CorrelationID: shutdown.CorrelationID, ErrorType: fatalErrorType}); err != nil { - log.Panicf("Failed to SendDone: %s", err) + if err != nil { + restoreStatus = telemetry.RuntimeDoneFailure + } else { + restoreStatus = telemetry.RuntimeDoneSuccess } - // Shutdown induces a terminal state and no further messages will be processed - blockForever() + return err } func start(signalCtx context.Context, execCtx *rapidContext) { - watchdog := core.NewWatchdog(execCtx.registrationService.InitFlow(), execCtx.invokeFlow, execCtx.exitPidChan, execCtx.appCtx) - - interopServer := execCtx.interopServer - // Start Runtime API Server err := execCtx.server.Listen() if err != nil { @@ -670,30 +840,40 @@ func start(signalCtx context.Context, execCtx *rapidContext) { // Note, most of initialization code should run before blocking to receive START, // code before START runs in parallel with code downloads. +} - go func() { - for { - reset := <-interopServer.ResetChan() - // In the event of a Reset during init/invoke, CancelFlows cancels execution - // flows and return with the errResetReceived err - this error is special-cased - // and not handled by the init/invoke (unexpected) error handling functions - watchdog.CancelFlows(errResetReceived) - execCtx.resetChan <- reset - } - }() +func sendRestoreRuntimeDoneLogEvent(execCtx *rapidContext, status string) { + if err := execCtx.eventsAPI.SendRestoreRuntimeDone(status); err != nil { + log.Errorf("Failed to send RESTRD: %s", err) + } +} + +func sendInitRuntimeDoneLogEvent(execCtx *rapidContext, sandboxType interop.SandboxType, status string) { + initSource := interop.InferTelemetryInitSource(execCtx.initCachingEnabled, sandboxType) + + runtimeDoneData := &telemetry.InitRuntimeDoneData{ + InitSource: initSource, + Status: status, + } - for { - select { - case start := <-interopServer.StartChan(): - handleStart(signalCtx, execCtx, watchdog, start) - case invoke := <-interopServer.InvokeChan(): - handleInvoke(signalCtx, execCtx, watchdog, invoke) - case err := <-interopServer.TransportErrorChan(): - log.Panicf("Transport error emitted by interop server: %s", err) - case reset := <-execCtx.resetChan: - handleReset(execCtx, watchdog, reset) - case shutdown := <-interopServer.ShutdownChan(): // only in standalone mode - handleShutdown(execCtx, watchdog, shutdown, standaloneShutdownReason) + if err := execCtx.eventsAPI.SendInitRuntimeDone(runtimeDoneData); err != nil { + log.Errorf("Failed to send INITRD: %s", err) + } +} + +// This function will log a line if AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, or AWS_SESSION_TOKEN is missing +// This is expected to happen in cases when credentials provider is not needed +func checkCredentials(execCtx *rapidContext, bootstrapEnv map[string]string) { + credentialsKeys := []string{"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"} + missingCreds := []string{} + + for _, credEnvVar := range credentialsKeys { + if val, keyExists := bootstrapEnv[credEnvVar]; !keyExists || val == "" { + missingCreds = append(missingCreds, credEnvVar) } } + + if len(missingCreds) > 0 { + log.Infof("Starting runtime without %s , Expected?: %t", strings.Join(missingCreds[:], ", "), execCtx.initCachingEnabled) + } } diff --git a/lambda/rapid/start_test.go b/lambda/rapid/start_test.go index 2363705..ffb446f 100644 --- a/lambda/rapid/start_test.go +++ b/lambda/rapid/start_test.go @@ -7,7 +7,7 @@ import ( "context" "fmt" "go.amzn.com/lambda/core" - "io/ioutil" + "io" "net/http" "regexp" "strconv" @@ -142,7 +142,7 @@ func TestListen(t *testing.T) { ctx := context.Background() telemetryAPIEnabled := true - server := rapi.NewServer("127.0.0.1", 0, flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, telemetryAPIEnabled, flowTest.LogsSubscriptionAPI, false, flowTest.CredentialsService) + server := rapi.NewServer("127.0.0.1", 0, flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, telemetryAPIEnabled, flowTest.TelemetrySubscription, flowTest.TelemetrySubscription, flowTest.CredentialsService, flowTest.EventsAPI) err := server.Listen() assert.NoError(t, err) @@ -161,7 +161,7 @@ func TestListen(t *testing.T) { resp, err1 := http.Get(fmt.Sprintf("http://%s:%d/2018-06-01/runtime/invocation/next", server.Host(), server.Port())) assert.Nil(t, err1) - body, err2 := ioutil.ReadAll(resp.Body) + body, err2 := io.ReadAll(resp.Body) assert.Nil(t, err2) assert.Equal(t, "MyTest", string(body)) @@ -171,3 +171,31 @@ func TestListen(t *testing.T) { <-done } + +func TestInferSandboxInitTypeOnDemand(t *testing.T) { + initCachingEnabled := false + sandboxType := interop.SandboxClassic + initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) + assert.Equal(t, "on-demand", initSource) +} + +func TestInferSandboxInitTypeProvisionedConcurrency(t *testing.T) { + initCachingEnabled := false + sandboxType := interop.SandboxPreWarmed + initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) + assert.Equal(t, "provisioned-concurrency", initSource) +} + +func TestInferSandboxInitTypeInitCaching(t *testing.T) { + initCachingEnabled := true + sandboxType := interop.SandboxClassic + initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) + assert.Equal(t, "snap-start", initSource) +} + +func TestInferSandboxInitTypeInitCachingWithPC(t *testing.T) { + initCachingEnabled := true + sandboxType := interop.SandboxPreWarmed + initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) + assert.Equal(t, "snap-start", initSource) +} diff --git a/lambda/rapidcore/bootstrap.go b/lambda/rapidcore/bootstrap.go index 9faf518..165f532 100644 --- a/lambda/rapidcore/bootstrap.go +++ b/lambda/rapidcore/bootstrap.go @@ -6,11 +6,12 @@ package rapidcore import ( "fmt" "os" + "path" "path/filepath" + "strings" "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/logging" - "go.amzn.com/lambda/rapid" + "go.amzn.com/lambda/interop" log "github.com/sirupsen/logrus" ) @@ -21,6 +22,7 @@ type BootstrapError func() (fatalerror.ErrorType, LogFormatter) // Bootstrap represents a list of executable bootstrap // candidates in order of priority and exec metadata type Bootstrap struct { + runtimeDomainRoot string orderedLookupPaths []string validCmd []string workingDir string @@ -29,8 +31,11 @@ type Bootstrap struct { bootstrapError BootstrapError } +// Validate interface compliance +var _ interop.Bootstrap = (*Bootstrap)(nil) + // NewBootstrap returns an instance of bootstrap defined by given params -func NewBootstrap(cmdCandidates [][]string, currentWorkingDir string) *Bootstrap { +func NewBootstrap(cmdCandidates [][]string, currentWorkingDir string, runtimeDomainRoot string) *Bootstrap { var orderedLookupBootstrapPaths []string for _, args := range cmdCandidates { // Empty args is an error, but we want to detect it later (in Cmd() call) when we are able to report a descriptive error @@ -44,23 +49,32 @@ func NewBootstrap(cmdCandidates [][]string, currentWorkingDir string) *Bootstrap currentWorkingDir = "/" } + if runtimeDomainRoot == "" { + runtimeDomainRoot = "/" + } + return &Bootstrap{ orderedLookupPaths: orderedLookupBootstrapPaths, workingDir: currentWorkingDir, cmdCandidates: cmdCandidates, + runtimeDomainRoot: runtimeDomainRoot, } } -func NewBootstrapSingleCmd(cmd []string, currentWorkingDir string) *Bootstrap { +func NewBootstrapSingleCmd(cmd []string, currentWorkingDir string, runtimeDomainRoot string) *Bootstrap { if currentWorkingDir == "" { // use the root directory as the default working directory currentWorkingDir = "/" } + if runtimeDomainRoot == "" { + runtimeDomainRoot = "/" + } // a single candidate command makes it automatically valid return &Bootstrap{ - validCmd: cmd, - workingDir: currentWorkingDir, + validCmd: cmd, + workingDir: currentWorkingDir, + runtimeDomainRoot: runtimeDomainRoot, } } @@ -68,16 +82,28 @@ func NewBootstrapSingleCmd(cmd []string, currentWorkingDir string) *Bootstrap { // actual bootstrap, given a list of possible files func (b *Bootstrap) locateBootstrap() error { for i, bootstrapCandidate := range b.orderedLookupPaths { - if file, err := os.Stat(bootstrapCandidate); !os.IsNotExist(err) && !file.IsDir() { - b.validCmd = b.cmdCandidates[i] - return nil + // validate path relatively to the domain's root + candidatPath := path.Join(b.runtimeDomainRoot, bootstrapCandidate) + file, err := os.Stat(candidatPath) + if err != nil { + if !os.IsNotExist(err) { + log.WithError(err).Warnf("Could not validate %s. Ignoring it.", bootstrapCandidate) + } + continue } + if file.IsDir() { + log.Warnf("%s is a directory. Ignoring it", bootstrapCandidate) + continue + } + b.validCmd = b.cmdCandidates[i] + return nil } log.WithField("bootstrapPathsChecked", b.orderedLookupPaths).Warn("Couldn't find valid bootstrap(s)") return fmt.Errorf("Couldn't find valid bootstrap(s): %s", b.orderedLookupPaths) } -// Cmd returns the args of bootstrap, where args[0] +// Cmd returns the args of bootstrap, relative to the +// chroot idenfied by `root`, where args[0] // is the path to executable func (b *Bootstrap) Cmd() ([]string, error) { if len(b.validCmd) > 0 { @@ -94,16 +120,21 @@ func (b *Bootstrap) Cmd() ([]string, error) { // Env returns the environment variables available to // the bootstrap process -func (b *Bootstrap) Env(e rapid.EnvironmentVariables) []string { +func (b *Bootstrap) Env(e interop.EnvironmentVariables) map[string]string { return e.RuntimeExecEnv() } // Cwd returns the working directory of the bootstrap process +// The path is validated against the chroot identified by `root` func (b *Bootstrap) Cwd() (string, error) { if !filepath.IsAbs(b.workingDir) { return "", fmt.Errorf("the working directory '%s' is invalid, it needs to be an absolute path", b.workingDir) - } else if _, err := os.Stat(b.workingDir); os.IsNotExist(err) { - return "", fmt.Errorf("the working directory doesn't exist: %s", b.workingDir) + } + + // evaluate the path relatively to the domain's mnt namespace root + domainPath := path.Join(b.runtimeDomainRoot, b.workingDir) + if _, err := os.Stat(domainPath); os.IsNotExist(err) { + return "", fmt.Errorf("the working directory doesn't exist: %s", domainPath) } return b.workingDir, nil @@ -140,19 +171,35 @@ func (b *Bootstrap) SetCachedFatalError(bootstrapErrFn BootstrapError) { // BootstrapErrInvalidLCISTaskConfig represents an error while parsing LCIS task config func BootstrapErrInvalidLCISTaskConfig(err error) BootstrapError { return func() (fatalerror.ErrorType, LogFormatter) { - return fatalerror.InvalidTaskConfig, logging.SupernovaInvalidTaskConfigRepr(err) + return fatalerror.InvalidTaskConfig, SupernovaInvalidTaskConfigRepr(err) } } // BootstrapErrInvalidLCISEntrypoint represents an invalid LCIS entrypoint error func BootstrapErrInvalidLCISEntrypoint(entrypoint []string, cmd []string, workingdir string) BootstrapError { return func() (fatalerror.ErrorType, LogFormatter) { - return fatalerror.InvalidEntrypoint, logging.SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) + return fatalerror.InvalidEntrypoint, SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) } } func BootstrapErrInvalidLCISWorkingDir(entrypoint []string, cmd []string, workingdir string) BootstrapError { return func() (fatalerror.ErrorType, LogFormatter) { - return fatalerror.InvalidWorkingDir, logging.SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) + return fatalerror.InvalidWorkingDir, SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) + } +} + +func SupernovaInvalidTaskConfigRepr(err error) func(error) string { + return func(unused error) string { + return fmt.Sprintf("IMAGE\tInvalid task config: %s", err) + } +} + +func SupernovaLaunchErrorRepr(entrypoint []string, cmd []string, workingDir string) func(error) string { + return func(err error) string { + return fmt.Sprintf("IMAGE\tLaunch error: %s\tEntrypoint: [%s]\tCmd: [%s]\tWorkingDir: [%s]", + err, + strings.Join(entrypoint, ","), + strings.Join(cmd, ","), + workingDir) } } diff --git a/lambda/rapidcore/bootstrap_test.go b/lambda/rapidcore/bootstrap_test.go index 4700130..b43520d 100644 --- a/lambda/rapidcore/bootstrap_test.go +++ b/lambda/rapidcore/bootstrap_test.go @@ -4,8 +4,10 @@ package rapidcore import ( - "io/ioutil" "os" + "path" + "path/filepath" + "reflect" "testing" "go.amzn.com/lambda/rapidcore/env" @@ -14,11 +16,11 @@ import ( ) func TestBootstrap(t *testing.T) { - tmpDir, err := ioutil.TempDir("", "lcis-test-invalid-bootstrap") + tmpDir, err := os.MkdirTemp("", "lcis-test-invalid-bootstrap") assert.NoError(t, err) defer os.RemoveAll(tmpDir) - tmpFile, err := ioutil.TempFile("", "lcis-test-bootstrap") + tmpFile, err := os.CreateTemp("", "lcis-test-bootstrap") assert.NoError(t, err) defer os.Remove(tmpFile.Name()) @@ -38,11 +40,57 @@ func TestBootstrap(t *testing.T) { environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") // Test - b := NewBootstrap(cmdCandidates, cwd) + b := NewBootstrap(cmdCandidates, cwd, "") bCwd, err := b.Cwd() assert.NoError(t, err) assert.Equal(t, cwd, bCwd) - assert.ElementsMatch(t, environment.RuntimeExecEnv(), b.Env(environment)) + assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) + + cmd, err := b.Cmd() + assert.NoError(t, err) + assert.Equal(t, file, cmd) +} + +// When running bootstraps in separate mount namespaces +// we want to verify and discover paths relative to +// a root different from "/" +func TestBootstrapChroot(t *testing.T) { + tmpRoot, err := os.MkdirTemp(os.TempDir(), "domain-root") + assert.NoError(t, err) + defer os.RemoveAll(tmpRoot) + tmpDir, err := os.MkdirTemp(tmpRoot, "lcis-test-invalid-bootstrap") + assert.NoError(t, err) + defer os.RemoveAll(tmpDir) + + tmpFile, err := os.CreateTemp(tmpRoot, "lcis-test-bootstrap") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + // Setup cmd candidates + nonExistent := []string{"/foo/bar/baz"} + baseName := filepath.Base(tmpDir) + dir := []string{"/" + baseName, "--arg1", "foo"} + baseName = filepath.Base(tmpFile.Name()) + file := []string{"/" + baseName, "--arg1 s", "foo"} + cmdCandidates := [][]string{nonExistent, dir, file} + + // Setup working dir + cwd, err := os.MkdirTemp(tmpRoot, "cwd") + assert.NoError(t, err) + defer os.RemoveAll(cwd) + + // Setup environment + environment := env.NewEnvironment() + environment.StoreRuntimeAPIEnvironmentVariable("host:port") + environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") + + // Test + baseName = filepath.Base(cwd) + b := NewBootstrap(cmdCandidates, "/"+baseName, tmpRoot) + bCwd, err := b.Cwd() + assert.NoError(t, err) + assert.Equal(t, cwd, path.Join(tmpRoot, bCwd)) + assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) cmd, err := b.Cmd() assert.NoError(t, err) @@ -53,17 +101,24 @@ func TestBootstrapEmptyCandidate(t *testing.T) { // we expect newBootstrap to succeed and bootstrap.Cmd() to fail. // We want to postpone the failure to be able to propagate error description to slicer and write it to customer log invalidBootstrapCandidate := []string{} - bs := NewBootstrap([][]string{invalidBootstrapCandidate}, "/") + bs := NewBootstrap([][]string{invalidBootstrapCandidate}, "/", "") + _, err := bs.Cmd() + assert.Error(t, err) +} + +func TestBootstrapChrootNonExistingRoot(t *testing.T) { + invalidBootstrapCandidate := []string{"/bin/bash", "-c"} + bs := NewBootstrap([][]string{invalidBootstrapCandidate}, "/", "/does_not_exist") _, err := bs.Cmd() assert.Error(t, err) } func TestBootstrapSingleCmd(t *testing.T) { - tmpDir, err := ioutil.TempDir("", "lcis-test-invalid-bootstrap") + tmpDir, err := os.MkdirTemp("", "lcis-test-invalid-bootstrap") assert.NoError(t, err) defer os.RemoveAll(tmpDir) - tmpFile, err := ioutil.TempFile("", "lcis-test-bootstrap") + tmpFile, err := os.CreateTemp("", "lcis-test-bootstrap") assert.NoError(t, err) defer os.Remove(tmpFile.Name()) @@ -81,11 +136,11 @@ func TestBootstrapSingleCmd(t *testing.T) { environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") // Test - b := NewBootstrapSingleCmd(cmdCandidate, cwd) + b := NewBootstrapSingleCmd(cmdCandidate, cwd, "") bCwd, err := b.Cwd() assert.NoError(t, err) assert.Equal(t, cwd, bCwd) - assert.ElementsMatch(t, environment.RuntimeExecEnv(), b.Env(environment)) + assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) cmd, err := b.Cmd() assert.NoError(t, err) @@ -93,7 +148,7 @@ func TestBootstrapSingleCmd(t *testing.T) { } func TestBootstrapSingleCmdNonExistingCandidate(t *testing.T) { - tmpDir, err := ioutil.TempDir("", "lcis-test-invalid-bootstrap") + tmpDir, err := os.MkdirTemp("", "lcis-test-invalid-bootstrap") assert.NoError(t, err) defer os.RemoveAll(tmpDir) @@ -111,11 +166,11 @@ func TestBootstrapSingleCmdNonExistingCandidate(t *testing.T) { environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") // Test - b := NewBootstrapSingleCmd(cmdCandidate, cwd) + b := NewBootstrapSingleCmd(cmdCandidate, cwd, "") bCwd, err := b.Cwd() assert.NoError(t, err) assert.Equal(t, cwd, bCwd) - assert.ElementsMatch(t, environment.RuntimeExecEnv(), b.Env(environment)) + assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) // No validations run against single candidates cmd, err := b.Cmd() @@ -125,100 +180,100 @@ func TestBootstrapSingleCmdNonExistingCandidate(t *testing.T) { // Test our ability to locate bootstrap files in the file system func TestFindCustomRuntimeIfExists(t *testing.T) { - tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp-") + tmpFile, err := os.CreateTemp(os.TempDir(), "tmp-") if err != nil { t.Fatal("Cannot create temporary file", err) } defer os.Remove(tmpFile.Name()) - tmpFile2, err := ioutil.TempFile(os.TempDir(), "tmp-") + tmpFile2, err := os.CreateTemp(os.TempDir(), "tmp-") if err != nil { t.Fatal("Cannot create temporary file", err) } defer os.Remove(tmpFile2.Name()) // one bootstrap argument was given and it exists - bootstrap := NewBootstrap([][]string{[]string{tmpFile.Name()}}, "/") + bootstrap := NewBootstrap([][]string{{tmpFile.Name()}}, "/", "") cmd, err := bootstrap.Cmd() assert.NoError(t, err) assert.Equal(t, []string{tmpFile.Name()}, cmd) assert.Nil(t, err) // two bootstrap arguments given, both exist but first one is returned - bootstrap = NewBootstrap([][]string{[]string{tmpFile.Name()}, []string{tmpFile2.Name()}}, "/") + bootstrap = NewBootstrap([][]string{{tmpFile.Name()}, {tmpFile2.Name()}}, "/", "") cmd, err = bootstrap.Cmd() assert.NoError(t, err) assert.Equal(t, []string{tmpFile.Name()}, cmd) assert.Nil(t, err) // two bootstrap arguments given, first one does not exist, second exists and is returned - bootstrap = NewBootstrap([][]string{[]string{"mk"}, []string{tmpFile2.Name()}}, "/") + bootstrap = NewBootstrap([][]string{{"mk"}, {tmpFile2.Name()}}, "/", "") cmd, err = bootstrap.Cmd() assert.NoError(t, err) assert.Equal(t, []string{tmpFile2.Name()}, cmd) assert.Nil(t, err) // two bootstrap arguments given, none exists - bootstrap = NewBootstrap([][]string{[]string{"mk"}, []string{"mk2"}}, "/") + bootstrap = NewBootstrap([][]string{{"mk"}, {"mk2"}}, "/", "") cmd, err = bootstrap.Cmd() assert.EqualError(t, err, "Couldn't find valid bootstrap(s): [mk mk2]") assert.Equal(t, []string{}, cmd) } func TestCwdIsAbsolute(t *testing.T) { - tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp-") + tmpFile, err := os.CreateTemp(os.TempDir(), "tmp-") if err != nil { t.Fatal("Cannot create temporary file", err) } defer os.Remove(tmpFile.Name()) - cmdCandidates := [][]string{[]string{tmpFile.Name()}} + cmdCandidates := [][]string{{tmpFile.Name()}} // no errors when currentWorkingDir is absolute - bootstrap := NewBootstrap(cmdCandidates, "/tmp") + bootstrap := NewBootstrap(cmdCandidates, "/tmp", "") cwd, err := bootstrap.Cwd() assert.Nil(t, err) assert.Equal(t, "/tmp", cwd) - bootstrap = NewBootstrap(cmdCandidates, "tmp") + bootstrap = NewBootstrap(cmdCandidates, "tmp", "") _, err = bootstrap.Cwd() assert.EqualError(t, err, "the working directory 'tmp' is invalid, it needs to be an absolute path") - bootstrap = NewBootstrap(cmdCandidates, "./") + bootstrap = NewBootstrap(cmdCandidates, "./", "") _, err = bootstrap.Cwd() assert.EqualError(t, err, "the working directory './' is invalid, it needs to be an absolute path") } func TestBootstrapMissingWorkingDirectory(t *testing.T) { - tmpFile, err := ioutil.TempFile(os.TempDir(), "cwd-test-bootstrap") + tmpFile, err := os.CreateTemp(os.TempDir(), "cwd-test-bootstrap") assert.NoError(t, err) defer os.Remove(tmpFile.Name()) - tmpDir, err := ioutil.TempDir("", "cwd-test") + tmpDir, err := os.MkdirTemp("", "cwd-test") assert.NoError(t, err) defer os.RemoveAll(tmpDir) // cwd argument exists - bootstrap := NewBootstrap([][]string{[]string{tmpFile.Name()}}, tmpDir) + bootstrap := NewBootstrap([][]string{{tmpFile.Name()}}, tmpDir, "") cwd, err := bootstrap.Cwd() assert.Equal(t, cwd, tmpDir) assert.NoError(t, err) // cwd argument doesn't exist - bootstrap = NewBootstrap([][]string{[]string{tmpFile.Name()}}, "/foo") + bootstrap = NewBootstrap([][]string{{tmpFile.Name()}}, "/foo", "") _, err = bootstrap.Cwd() assert.EqualError(t, err, "the working directory doesn't exist: /foo") } func TestDefaultWorkeringDirectory(t *testing.T) { - bootstrap := NewBootstrap([][]string{[]string{}}, "") + bootstrap := NewBootstrap([][]string{{}}, "", "") cwd, err := bootstrap.Cwd() assert.NoError(t, err) assert.Equal(t, "/", cwd) } func TestBootstrapSingleCmdDefaultWorkingDir(t *testing.T) { - b := NewBootstrapSingleCmd([]string{}, "") + b := NewBootstrapSingleCmd([]string{}, "", "") bCwd, err := b.Cwd() assert.NoError(t, err) assert.Equal(t, "/", bCwd) diff --git a/lambda/rapidcore/env/environment.go b/lambda/rapidcore/env/environment.go index 699abda..be0584c 100644 --- a/lambda/rapidcore/env/environment.go +++ b/lambda/rapidcore/env/environment.go @@ -149,26 +149,24 @@ func (e *Environment) mergeCustomerEnvironmentVariables(envVars map[string]strin // RuntimeExecEnv returns the key=value strings of all environment variables // passed to runtime process on exec() -func (e *Environment) RuntimeExecEnv() []string { +func (e *Environment) RuntimeExecEnv() map[string]string { if !e.initEnvVarsSet || !e.runtimeAPISet { log.Fatal("credentials, customer and runtime API address must be set") } - return asEnviron(mapUnion(e.Customer, e.PlatformUnreserved, e.Credentials, e.Runtime, e.Platform)) + return mapUnion(e.Customer, e.PlatformUnreserved, e.Credentials, e.Runtime, e.Platform) } // AgentExecEnv returns the key=value strings of all environment variables // passed to agent process on exec() -func (e *Environment) AgentExecEnv() []string { +func (e *Environment) AgentExecEnv() map[string]string { if !e.initEnvVarsSet || !e.runtimeAPISet { log.Fatal("credentials, customer and runtime API address must be set") } excludedKeys := extensionExcludedKeys() excludeCondition := func(key string) bool { return excludedKeys[key] || strings.HasPrefix(key, "_") } - environ := asEnviron(mapExclude(mapUnion(e.Customer, e.Credentials, e.Platform), excludeCondition)) - - return environ + return mapExclude(mapUnion(e.Customer, e.Credentials, e.Platform), excludeCondition) } // RAPIDInternalConfig returns the rapid config parsed from environment vars @@ -249,14 +247,6 @@ func mapUnion(maps ...map[string]string) map[string]string { return union } -func asEnviron(m map[string]string) []string { - keySepValArray := []string{} - for key, val := range m { - keySepValArray = append(keySepValArray, key+"="+val) - } - return keySepValArray -} - func mapExclude(m map[string]string, excludeCondition func(string) bool) map[string]string { res := map[string]string{} for key, val := range m { diff --git a/lambda/rapidcore/env/environment_test.go b/lambda/rapidcore/env/environment_test.go index cdfef24..ed3043c 100644 --- a/lambda/rapidcore/env/environment_test.go +++ b/lambda/rapidcore/env/environment_test.go @@ -6,12 +6,21 @@ package env import ( "fmt" "os" - "strings" "testing" "github.com/stretchr/testify/assert" ) +func envToSlice(env map[string]string) []string { + ret := make([]string, len(env)) + i := 0 + for key, val := range env { + ret[i] = key + "=" + val + i++ + } + return ret +} + func TestRAPIDInternalConfig(t *testing.T) { os.Clearenv() os.Setenv("_LAMBDA_SB_ID", "sbid") @@ -121,34 +130,35 @@ func TestRuntimeExecEnvironmentVariables(t *testing.T) { rapidEnvVars := env.RuntimeExecEnv() var rapidEnvKeys []string - for _, keyval := range rapidEnvVars { - key := strings.Split(keyval, "=")[0] + for key := range rapidEnvVars { rapidEnvKeys = append(rapidEnvKeys, key) } + rapidEnvVarsSlice := envToSlice(rapidEnvVars) + for key := range env.RAPID { assert.NotContains(t, rapidEnvKeys, key) } for key, val := range env.Runtime { - assert.Contains(t, rapidEnvVars, key+"="+val) + assert.Contains(t, rapidEnvVarsSlice, key+"="+val) } for key, val := range env.Platform { - assert.Contains(t, rapidEnvVars, key+"="+val) + assert.Contains(t, rapidEnvVarsSlice, key+"="+val) } for key, val := range env.PlatformUnreserved { - assert.Contains(t, rapidEnvVars, key+"="+val) + assert.Contains(t, rapidEnvVarsSlice, key+"="+val) assert.NotContains(t, env.Customer, key) } for key, val := range env.Credentials { - assert.Contains(t, rapidEnvVars, key+"="+val) + assert.Contains(t, rapidEnvVarsSlice, key+"="+val) } for key, val := range env.Customer { - assert.Contains(t, rapidEnvVars, key+"="+val) + assert.Contains(t, rapidEnvVarsSlice, key+"="+val) assert.NotContains(t, env.PlatformUnreserved, key) } } @@ -191,7 +201,7 @@ func TestRuntimeExecEnvironmentVariablesPriority(t *testing.T) { assert.Equal(t, len(predefinedInternalEnvVarKeys()), len(env.RAPID)) assert.Equal(t, len(predefinedRuntimeEnvVarKeys()), len(env.Runtime)) - rapidEnvVars := env.RuntimeExecEnv() + rapidEnvVars := envToSlice(env.RuntimeExecEnv()) // Customer env vars cannot override platform/internal ones assert.NotContains(t, rapidEnvVars, conflictPlatformKeyFromInit+"="+customerEnvVal) @@ -224,7 +234,7 @@ func TestCustomerEnvironmentVariablesFromInitCanOverrideEnvironmentVariablesFrom assert.Equal(t, env.Customer["LCIS_ARG1"], lcisCLIArgEnvVal) assert.Equal(t, env.Customer["MY_FOOBAR_ENV_1"], customerEnvVal) - rapidEnvVars := env.RuntimeExecEnv() + rapidEnvVars := envToSlice(env.RuntimeExecEnv()) assert.Contains(t, rapidEnvVars, "LCIS_ARG1="+lcisCLIArgEnvVal) assert.Contains(t, rapidEnvVars, "MY_FOOBAR_ENV_1="+customerEnvVal) @@ -250,17 +260,18 @@ func TestAgentExecEnvironmentVariables(t *testing.T) { agentEnvVars := env.AgentExecEnv() var agentEnvKeys []string - for _, keyval := range agentEnvVars { - key := strings.Split(keyval, "=")[0] + for key := range agentEnvVars { agentEnvKeys = append(agentEnvKeys, key) } + agentEnvVarsSlice := envToSlice(agentEnvVars) + for key := range env.RAPID { assert.NotContains(t, agentEnvKeys, key) } for key, val := range env.Runtime { - assert.NotContains(t, agentEnvKeys, key+"="+val) + assert.NotContains(t, agentEnvVarsSlice, key+"="+val) } for key := range env.Platform { @@ -272,10 +283,10 @@ func TestAgentExecEnvironmentVariables(t *testing.T) { } for key, val := range env.Credentials { - assert.Contains(t, agentEnvVars, key+"="+val) + assert.Contains(t, agentEnvVarsSlice, key+"="+val) } - assert.Contains(t, agentEnvVars, runtimeAPIAddressKey+"="+env.Platform[runtimeAPIAddressKey]) + assert.Contains(t, agentEnvVarsSlice, runtimeAPIAddressKey+"="+env.Platform[runtimeAPIAddressKey]) } func TestStoreEnvironmentVariablesFromInitCaching(t *testing.T) { diff --git a/lambda/rapidcore/errors.go b/lambda/rapidcore/errors.go index 06a4830..7f35ca8 100644 --- a/lambda/rapidcore/errors.go +++ b/lambda/rapidcore/errors.go @@ -5,9 +5,9 @@ package rapidcore import "errors" -var ErrInitAlreadyDone = errors.New("InitAlreadyDone") var ErrInitDoneFailed = errors.New("InitDoneFailed") -var ErrInitError = errors.New("InitError") +var ErrInitNotStarted = errors.New("InitNotStarted") +var ErrInitResetReceived = errors.New("InitResetReceived") var ErrNotReserved = errors.New("NotReserved") var ErrAlreadyReserved = errors.New("AlreadyReserved") @@ -23,5 +23,3 @@ var ErrReleaseReservationDone = errors.New("ReleaseReservationDone") var ErrInternalServerError = errors.New("InternalServerError") var ErrInvokeTimeout = errors.New("InvokeTimeout") - -var ErrTerminated = errors.New("SandboxTerminated") // sent to signal a process exit diff --git a/lambda/rapidcore/sandbox.go b/lambda/rapidcore/sandbox.go deleted file mode 100644 index 7d5a8a9..0000000 --- a/lambda/rapidcore/sandbox.go +++ /dev/null @@ -1,259 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapidcore - -import ( - "context" - "io" - "io/ioutil" - "net/http" - "os" - "os/signal" - "syscall" - - "go.amzn.com/lambda/core/statejson" - "go.amzn.com/lambda/extensions" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/logging" - "go.amzn.com/lambda/rapid" - "go.amzn.com/lambda/rapidcore/env" - "go.amzn.com/lambda/telemetry" - - log "github.com/sirupsen/logrus" -) - -const ( - defaultSigtermResetTimeoutMs = int64(2000) -) - -type Sandbox interface { - Init(i *interop.Init, invokeTimeoutMs int64) - Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error - InteropServer() InteropServer -} - -type ReserveResponse struct { - Token interop.Token - InternalState *statejson.InternalStateDescription -} - -type InteropServer interface { - FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error - Reserve(id string, traceID, lambdaSegmentID string) (*ReserveResponse, error) - Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) - AwaitRelease() (*statejson.InternalStateDescription, error) - Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription - InternalState() (*statejson.InternalStateDescription, error) - CurrentToken() *interop.Token -} - -type SandboxBuilder struct { - sandbox *rapid.Sandbox - defaultInteropServer *Server - useCustomInteropServer bool - shutdownFuncs []context.CancelFunc - debugTailLogWriter io.Writer - platformLogWriter io.Writer -} - -type logSink int - -const ( - RuntimeLogSink logSink = iota - ExtensionLogSink -) - -func NewSandboxBuilder(bootstrap *Bootstrap) *SandboxBuilder { - defaultInteropServer := NewServer(context.Background()) - signalCtx, cancelSignalCtx := context.WithCancel(context.Background()) - logsEgressAPI := &telemetry.NoOpLogsEgressAPI{} - runtimeStdoutWriter, runtimeStderrWriter, _ := logsEgressAPI.GetRuntimeSockets() - - b := &SandboxBuilder{ - sandbox: &rapid.Sandbox{ - Bootstrap: bootstrap, - PreLoadTimeNs: 0, // TODO - StandaloneMode: true, - RuntimeStdoutWriter: runtimeStdoutWriter, - RuntimeStderrWriter: runtimeStderrWriter, - LogsEgressAPI: logsEgressAPI, - EnableTelemetryAPI: false, - Environment: env.NewEnvironment(), - Tracer: telemetry.NewNoOpTracer(), - SignalCtx: signalCtx, - EventsAPI: &telemetry.NoOpEventsAPI{}, - InitCachingEnabled: false, - }, - defaultInteropServer: defaultInteropServer, - shutdownFuncs: []context.CancelFunc{}, - debugTailLogWriter: ioutil.Discard, - platformLogWriter: ioutil.Discard, - } - - b.AddShutdownFunc(context.CancelFunc(func() { - log.Info("Shutting down...") - defaultInteropServer.Reset("SandboxTerminated", defaultSigtermResetTimeoutMs) - cancelSignalCtx() - })) - - return b -} - -func (b *SandboxBuilder) SetInteropServer(interopServer interop.Server) *SandboxBuilder { - b.sandbox.InteropServer = interopServer - b.useCustomInteropServer = true - return b -} - -func (b *SandboxBuilder) SetEventsAPI(eventsAPI telemetry.EventsAPI) *SandboxBuilder { - b.sandbox.EventsAPI = eventsAPI - return b -} - -func (b *SandboxBuilder) SetTracer(tracer telemetry.Tracer) *SandboxBuilder { - b.sandbox.Tracer = tracer - return b -} - -func (b *SandboxBuilder) DisableStandaloneMode() *SandboxBuilder { - b.sandbox.StandaloneMode = false - return b -} - -func (b *SandboxBuilder) SetExtensionsFlag(extensionsEnabled bool) *SandboxBuilder { - if extensionsEnabled { - extensions.Enable() - } else { - extensions.Disable() - } - return b -} - -func (b *SandboxBuilder) SetInitCachingFlag(initCachingEnabled bool) *SandboxBuilder { - b.sandbox.InitCachingEnabled = initCachingEnabled - return b -} - -func (b *SandboxBuilder) SetPreLoadTimeNs(preLoadTimeNs int64) *SandboxBuilder { - b.sandbox.PreLoadTimeNs = preLoadTimeNs - return b -} - -func (b *SandboxBuilder) SetEnvironmentVariables(environment *env.Environment) *SandboxBuilder { - b.sandbox.Environment = environment - return b -} - -func (b *SandboxBuilder) SetPlatformLogOutput(w io.Writer) *SandboxBuilder { - b.platformLogWriter = w - return b -} - -func (b *SandboxBuilder) SetTailLogOutput(w io.Writer) *SandboxBuilder { - b.debugTailLogWriter = w - return b -} - -func (b *SandboxBuilder) SetLogsSubscriptionAPI(logsSubscriptionAPI telemetry.LogsSubscriptionAPI) *SandboxBuilder { - b.sandbox.EnableTelemetryAPI = true - b.sandbox.LogsSubscriptionAPI = logsSubscriptionAPI - return b -} - -func (b *SandboxBuilder) SetLogsEgressAPI(logsEgressAPI telemetry.LogsEgressAPI) *SandboxBuilder { - runtimeStdoutWriter, runtimeStderrWriter, err := logsEgressAPI.GetRuntimeSockets() - - if err != nil { - log.WithError(err).Fatal("failed to get the Runtime sockets from the logs egress API") - } - - b.sandbox.LogsEgressAPI = logsEgressAPI - b.sandbox.RuntimeStdoutWriter = runtimeStdoutWriter - b.sandbox.RuntimeStderrWriter = runtimeStderrWriter - return b -} - -func (b *SandboxBuilder) SetHandler(handler string) *SandboxBuilder { - b.sandbox.Handler = handler - return b -} - -func (b *SandboxBuilder) AddShutdownFunc(shutdownFunc context.CancelFunc) *SandboxBuilder { - b.shutdownFuncs = append(b.shutdownFuncs, shutdownFunc) - return b -} - -func (b *SandboxBuilder) setupLoggingWithDebugLogs() { - // Compose debug log writer with all log sinks. Debug log writer w - // will not write logs when disabled by invoke parameter - b.sandbox.DebugTailLogger = logging.NewTailLogWriter(b.debugTailLogWriter) - b.sandbox.PlatformLogger = logging.NewPlatformLogger(b.platformLogWriter, b.sandbox.DebugTailLogger) - b.sandbox.RuntimeStdoutWriter = io.MultiWriter(b.sandbox.DebugTailLogger, b.sandbox.RuntimeStdoutWriter) - b.sandbox.RuntimeStderrWriter = io.MultiWriter(b.sandbox.DebugTailLogger, b.sandbox.RuntimeStderrWriter) -} - -func (b *SandboxBuilder) Create() { - if len(b.sandbox.Handler) > 0 { - b.sandbox.Environment.SetHandler(b.sandbox.Handler) - } - - if !b.useCustomInteropServer { - b.sandbox.InteropServer = b.defaultInteropServer - } - - b.setupLoggingWithDebugLogs() - - go signalHandler(b.shutdownFuncs) - - rapid.Start(b.sandbox) -} - -func (b *SandboxBuilder) Init(i *interop.Init, timeoutMs int64) { - b.sandbox.InteropServer.Init(&interop.Start{ - Handler: i.Handler, - CorrelationID: i.CorrelationID, - AwsKey: i.AwsKey, - AwsSecret: i.AwsSecret, - AwsSession: i.AwsSession, - XRayDaemonAddress: i.XRayDaemonAddress, - FunctionName: i.FunctionName, - FunctionVersion: i.FunctionVersion, - CustomerEnvironmentVariables: i.CustomerEnvironmentVariables, - }, timeoutMs) -} - -func (b *SandboxBuilder) Invoke(w http.ResponseWriter, i *interop.Invoke) error { - return b.sandbox.InteropServer.Invoke(w, i) -} - -func (b *SandboxBuilder) InteropServer() InteropServer { - return b.defaultInteropServer -} - -// SetLogLevel sets the log level for internal logging. Needs to be called very -// early during startup to configure logs emitted during initialization -func SetLogLevel(logLevel string) { - level, err := log.ParseLevel(logLevel) - if err != nil { - log.WithError(err).Fatal("Failed to set log level. Valid log levels are:", log.AllLevels) - } - - log.SetLevel(level) - log.SetFormatter(&logging.InternalFormatter{}) -} - -func SetInternalLogOutput(w io.Writer) { - logging.SetOutput(w) -} - -// Trap SIGINT and SIGTERM signals and call shutdown function -func signalHandler(shutdownFuncs []context.CancelFunc) { - sig := make(chan os.Signal, 1) - signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) - sigReceived := <-sig - log.WithField("signal", sigReceived.String()).Info("Received signal") - for _, shutdownFunc := range shutdownFuncs { - shutdownFunc() - } -} diff --git a/lambda/rapidcore/sandbox_api.go b/lambda/rapidcore/sandbox_api.go new file mode 100644 index 0000000..0c7052e --- /dev/null +++ b/lambda/rapidcore/sandbox_api.go @@ -0,0 +1,147 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapidcore + +import ( + "go.amzn.com/lambda/interop" +) + +// SandboxContext and other structs form the implementation of the SandboxAPI +// interface defined in interop/sandbox_model.go, using the implementation of +// Init, Invoke and Reset handlers in rapid/sandbox.go +type SandboxContext struct { + rapidCtx interop.RapidContext + handler string + runtimeAPIAddress string + + InvokeReceivedTime int64 + InvokeResponseMetrics *interop.InvokeResponseMetrics +} + +type initContext struct { + initSuccessChan chan interop.InitSuccess + initFailureChan chan interop.InitFailure + rapidCtx interop.RapidContext + sbInfoFromInit interop.SandboxInfoFromInit // contains data that needs to be persisted from init for suppressed inits during invoke +} + +type invokeContext struct { + rapidCtx interop.RapidContext + invokeRequestChan chan *interop.Invoke + invokeSuccessChan chan interop.InvokeSuccess + invokeFailureChan chan interop.InvokeFailure +} + +// Validate interface compliance +var _ interop.SandboxContext = (*SandboxContext)(nil) +var _ interop.InitContext = (*initContext)(nil) +var _ interop.InvokeContext = (*invokeContext)(nil) + +func (s SandboxContext) Init(init *interop.Init, timeoutMs int64) (interop.InitStarted, interop.InitContext) { + initStartedResponseChan := make(chan interop.InitStarted) + initSuccessResponseChan := make(chan interop.InitSuccess) + initFailureResponseChan := make(chan interop.InitFailure) + + if len(s.handler) > 0 { + init.EnvironmentVariables.SetHandler(s.handler) + } + + init.EnvironmentVariables.StoreRuntimeAPIEnvironmentVariable(s.runtimeAPIAddress) + + go s.rapidCtx.HandleInit(init, initStartedResponseChan, initSuccessResponseChan, initFailureResponseChan) + initStarted := <-initStartedResponseChan + + sbMetadata := interop.SandboxInfoFromInit{ + EnvironmentVariables: init.EnvironmentVariables, + SandboxType: init.SandboxType, + RuntimeBootstrap: init.Bootstrap, + } + return initStarted, newInitContext(s.rapidCtx, sbMetadata, initSuccessResponseChan, initFailureResponseChan) +} + +func (s SandboxContext) Reset(reset *interop.Reset) (interop.ResetSuccess, *interop.ResetFailure) { + defer s.rapidCtx.Clear() + return s.rapidCtx.HandleReset(reset, s.InvokeReceivedTime, s.InvokeResponseMetrics) +} + +func (s SandboxContext) Shutdown(shutdown *interop.Shutdown) interop.ShutdownSuccess { + return s.rapidCtx.HandleShutdown(shutdown) +} + +func (s SandboxContext) Restore(restore *interop.Restore) error { + return s.rapidCtx.HandleRestore(restore) +} + +func (s *SandboxContext) SetInvokeReceivedTime(invokeReceivedTime int64) { + s.InvokeReceivedTime = invokeReceivedTime +} + +func (s *SandboxContext) SetInvokeResponseMetrics(metrics *interop.InvokeResponseMetrics) { + s.InvokeResponseMetrics = metrics +} + +func newInitContext(r interop.RapidContext, sbMetadata interop.SandboxInfoFromInit, + initSuccessChan chan interop.InitSuccess, initFailureChan chan interop.InitFailure) initContext { + return initContext{ + initSuccessChan: initSuccessChan, + initFailureChan: initFailureChan, + rapidCtx: r, + sbInfoFromInit: sbMetadata, + } +} + +func (i initContext) Wait() (interop.InitSuccess, *interop.InitFailure) { + select { + case initSuccess, isOpen := <-i.initSuccessChan: + if !isOpen { + // If init has already suceeded, we return quickly + return interop.InitSuccess{}, nil + } + return initSuccess, nil + case initFailure, isOpen := <-i.initFailureChan: + if !isOpen { + // If init has already failed, we return quickly for init to be suppressed + return interop.InitSuccess{}, &initFailure + } + return interop.InitSuccess{}, &initFailure + } +} + +func (i initContext) Reserve() interop.InvokeContext { + + invokeRequestChan := make(chan *interop.Invoke) + invokeSuccessChan := make(chan interop.InvokeSuccess) + invokeFailureChan := make(chan interop.InvokeFailure) + + go func() { + invoke := <-invokeRequestChan + // For suppressed inits, invoke needs the runtime and agent env vars + invokeSuccess, invokeFailure := i.rapidCtx.HandleInvoke(invoke, i.sbInfoFromInit) + if invokeFailure != nil { + invokeFailureChan <- *invokeFailure + } else { + invokeSuccessChan <- invokeSuccess + } + }() + + return invokeContext{ + rapidCtx: i.rapidCtx, + invokeRequestChan: invokeRequestChan, + invokeSuccessChan: invokeSuccessChan, + invokeFailureChan: invokeFailureChan, + } +} + +func (invCtx invokeContext) SendRequest(i *interop.Invoke) { + invCtx.invokeRequestChan <- i +} + +func (invCtx invokeContext) Wait() (interop.InvokeSuccess, *interop.InvokeFailure) { + select { + case invokeSuccess := <-invCtx.invokeSuccessChan: + return invokeSuccess, nil + case invokeFailure := <-invCtx.invokeFailureChan: + return interop.InvokeSuccess{}, &invokeFailure + } +} diff --git a/lambda/rapidcore/sandbox_builder.go b/lambda/rapidcore/sandbox_builder.go new file mode 100644 index 0000000..ce016a0 --- /dev/null +++ b/lambda/rapidcore/sandbox_builder.go @@ -0,0 +1,217 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapidcore + +import ( + "context" + "io" + "net" + "os" + "os/signal" + "strconv" + "syscall" + + "go.amzn.com/lambda/extensions" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/logging" + "go.amzn.com/lambda/rapid" + "go.amzn.com/lambda/supervisor" + supvmodel "go.amzn.com/lambda/supervisor/model" + "go.amzn.com/lambda/telemetry" + + log "github.com/sirupsen/logrus" +) + +const ( + defaultSigtermResetTimeoutMs = int64(2000) +) + +type SandboxBuilder struct { + sandbox *rapid.Sandbox + sandboxContext interop.SandboxContext + lambdaInvokeAPI LambdaInvokeAPI + defaultInteropServer *Server + useCustomInteropServer bool + shutdownFuncs []context.CancelFunc + handler string +} + +type logSink int + +const ( + RuntimeLogSink logSink = iota + ExtensionLogSink +) + +func NewSandboxBuilder() *SandboxBuilder { + defaultInteropServer := NewServer(context.Background()) + signalCtx, cancelSignalCtx := context.WithCancel(context.Background()) + + b := &SandboxBuilder{ + sandbox: &rapid.Sandbox{ + PreLoadTimeNs: 0, // TODO + StandaloneMode: true, + LogsEgressAPI: &telemetry.NoOpLogsEgressAPI{}, + EnableTelemetryAPI: false, + Tracer: telemetry.NewNoOpTracer(), + SignalCtx: signalCtx, + EventsAPI: &telemetry.NoOpEventsAPI{}, + InitCachingEnabled: false, + Supervisor: supervisor.NewLocalSupervisor(), + RuntimeAPIHost: "127.0.0.1", + RuntimeAPIPort: 9001, + }, + defaultInteropServer: defaultInteropServer, + shutdownFuncs: []context.CancelFunc{}, + lambdaInvokeAPI: NewEmulatorAPI(defaultInteropServer), + } + + b.AddShutdownFunc(context.CancelFunc(func() { + log.Info("Shutting down...") + defaultInteropServer.Reset("SandboxTerminated", defaultSigtermResetTimeoutMs) + cancelSignalCtx() + })) + + return b +} + +func (b *SandboxBuilder) SetSupervisor(supervisor supvmodel.Supervisor) *SandboxBuilder { + b.sandbox.Supervisor = supervisor + return b +} + +func (b *SandboxBuilder) SetRuntimeAPIAddress(runtimeAPIAddress string) *SandboxBuilder { + host, port, err := net.SplitHostPort(runtimeAPIAddress) + if err != nil { + log.WithError(err).Warnf("Failed to parse RuntimeApiAddress: %s:", runtimeAPIAddress) + return b + } + + portInt, err := strconv.Atoi(port) + if err != nil { + log.WithError(err).Warnf("Failed to parse RuntimeApiPort: %s:", port) + return b + } + + b.sandbox.RuntimeAPIHost = host + b.sandbox.RuntimeAPIPort = portInt + return b +} + +func (b *SandboxBuilder) SetInteropServer(interopServer interop.Server) *SandboxBuilder { + b.sandbox.InteropServer = interopServer + b.useCustomInteropServer = true + return b +} + +func (b *SandboxBuilder) SetEventsAPI(eventsAPI telemetry.EventsAPI) *SandboxBuilder { + b.sandbox.EventsAPI = eventsAPI + return b +} + +func (b *SandboxBuilder) SetTracer(tracer telemetry.Tracer) *SandboxBuilder { + b.sandbox.Tracer = tracer + return b +} + +func (b *SandboxBuilder) DisableStandaloneMode() *SandboxBuilder { + b.sandbox.StandaloneMode = false + return b +} + +func (b *SandboxBuilder) SetExtensionsFlag(extensionsEnabled bool) *SandboxBuilder { + if extensionsEnabled { + extensions.Enable() + } else { + extensions.Disable() + } + return b +} + +func (b *SandboxBuilder) SetInitCachingFlag(initCachingEnabled bool) *SandboxBuilder { + b.sandbox.InitCachingEnabled = initCachingEnabled + return b +} + +func (b *SandboxBuilder) SetPreLoadTimeNs(preLoadTimeNs int64) *SandboxBuilder { + b.sandbox.PreLoadTimeNs = preLoadTimeNs + return b +} + +func (b *SandboxBuilder) SetTelemetrySubscription(logsSubscriptionAPI telemetry.SubscriptionAPI, telemetrySubscriptionAPI telemetry.SubscriptionAPI) *SandboxBuilder { + b.sandbox.EnableTelemetryAPI = true + b.sandbox.LogsSubscriptionAPI = logsSubscriptionAPI + b.sandbox.TelemetrySubscriptionAPI = telemetrySubscriptionAPI + return b +} + +func (b *SandboxBuilder) SetLogsEgressAPI(logsEgressAPI telemetry.StdLogsEgressAPI) *SandboxBuilder { + b.sandbox.LogsEgressAPI = logsEgressAPI + return b +} + +func (b *SandboxBuilder) SetHandler(handler string) *SandboxBuilder { + b.handler = handler + return b +} + +func (b *SandboxBuilder) AddShutdownFunc(shutdownFunc context.CancelFunc) *SandboxBuilder { + b.shutdownFuncs = append(b.shutdownFuncs, shutdownFunc) + return b +} + +func (b *SandboxBuilder) Create() (interop.SandboxContext, interop.InternalStateGetter) { + if !b.useCustomInteropServer { + b.sandbox.InteropServer = b.defaultInteropServer + } + + go signalHandler(b.shutdownFuncs) + + rapidCtx, internalStateFn, runtimeAPIAddr := rapid.Start(b.sandbox) + + b.sandboxContext = &SandboxContext{ + rapidCtx: rapidCtx, + handler: b.handler, + runtimeAPIAddress: runtimeAPIAddr, + InvokeReceivedTime: int64(0), + InvokeResponseMetrics: nil, + } + + return b.sandboxContext, internalStateFn +} + +func (b *SandboxBuilder) DefaultInteropServer() *Server { + return b.defaultInteropServer +} + +func (b *SandboxBuilder) LambdaInvokeAPI() LambdaInvokeAPI { + return b.lambdaInvokeAPI +} + +// SetLogLevel sets the log level for internal logging. Needs to be called very +// early during startup to configure logs emitted during initialization +func SetLogLevel(logLevel string) { + level, err := log.ParseLevel(logLevel) + if err != nil { + log.WithError(err).Fatal("Failed to set log level. Valid log levels are:", log.AllLevels) + } + + log.SetLevel(level) + log.SetFormatter(&logging.InternalFormatter{}) +} + +func SetInternalLogOutput(w io.Writer) { + logging.SetOutput(w) +} + +// Trap SIGINT and SIGTERM signals and call shutdown function +func signalHandler(shutdownFuncs []context.CancelFunc) { + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) + sigReceived := <-sig + log.WithField("signal", sigReceived.String()).Info("Received signal") + for _, shutdownFunc := range shutdownFuncs { + shutdownFunc() + } +} diff --git a/lambda/rapidcore/sandbox_emulator_api.go b/lambda/rapidcore/sandbox_emulator_api.go new file mode 100644 index 0000000..6737631 --- /dev/null +++ b/lambda/rapidcore/sandbox_emulator_api.go @@ -0,0 +1,52 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapidcore + +import ( + "go.amzn.com/lambda/interop" + + "net/http" +) + +// LambdaInvokeAPI are the methods used by the Runtime Interface Emulator +type LambdaInvokeAPI interface { + Init(i *interop.Init, invokeTimeoutMs int64) + Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error +} + +// EmulatorAPI wraps the standalone interop server to provide a convenient interface +// for Rapid Standalone +type EmulatorAPI struct { + server *Server +} + +// Validate interface compliance +var _ LambdaInvokeAPI = (*EmulatorAPI)(nil) + +func NewEmulatorAPI(s *Server) *EmulatorAPI { + return &EmulatorAPI{s} +} + +// Init method is only used by the Runtime interface emulator +func (l *EmulatorAPI) Init(i *interop.Init, timeoutMs int64) { + l.server.Init(&interop.Init{ + Handler: i.Handler, + AwsKey: i.AwsKey, + AwsSecret: i.AwsSecret, + AwsSession: i.AwsSession, + XRayDaemonAddress: i.XRayDaemonAddress, + FunctionName: i.FunctionName, + FunctionVersion: i.FunctionVersion, + CustomerEnvironmentVariables: i.CustomerEnvironmentVariables, + RuntimeInfo: i.RuntimeInfo, + SandboxType: i.SandboxType, + Bootstrap: i.Bootstrap, + EnvironmentVariables: i.EnvironmentVariables, + }, timeoutMs) +} + +// Invoke method is only used by the Runtime interface emulator +func (l *EmulatorAPI) Invoke(w http.ResponseWriter, i *interop.Invoke) error { + return l.server.Invoke(w, i) +} diff --git a/lambda/rapidcore/server.go b/lambda/rapidcore/server.go index e3e01b6..e652130 100644 --- a/lambda/rapidcore/server.go +++ b/lambda/rapidcore/server.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "math" "net/http" "sync" @@ -17,6 +16,7 @@ import ( "go.amzn.com/lambda/core/directinvoke" "go.amzn.com/lambda/core/statejson" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" @@ -33,6 +33,12 @@ const ( resetDefaultTimeoutMs = 2000 ) +const ( + contentTypeHeader = "Content-Type" + errorTypeHeader = "Error-Type" + functionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" +) + type rapidPhase int const ( @@ -46,7 +52,6 @@ type runtimeState int const ( runtimeNotStarted = iota - runtimeInitStarted runtimeInitError runtimeInitComplete runtimeInitFailed @@ -76,14 +81,11 @@ type InvokeContext struct { type Server struct { InternalStateGetter interop.InternalStateGetter - invokeChanOut chan *interop.Invoke - startChanOut chan *interop.Start - resetChanOut chan *interop.Reset - shutdownChanOut chan *interop.Shutdown - errorChanOut chan error + initChanOut chan *interop.Init + interruptedResponseChan chan *interop.Reset - sendRunningChan chan *interop.Running - sendResponseChan chan struct{} + sendRunningChan chan *interop.InitStarted + sendResponseChan chan *interop.InvokeResponseMetrics doneChan chan *interop.Done InitDoneChan chan DoneWithState @@ -100,12 +102,17 @@ type Server struct { rapidPhase rapidPhase runtimeState runtimeState -} -func (s *Server) StartAcceptingDirectInvokes() error { - return nil + sandboxContext interop.SandboxContext + initContext interop.InitContext + invoker interop.InvokeContext + initFailures chan interop.InitFailure + cachedInitErrorResponse *interop.ErrorResponse } +// Validate interface compliance +var _ interop.Server = (*Server)(nil) + func (s *Server) setRapidPhase(phase rapidPhase) { s.mutex.Lock() defer s.mutex.Unlock() @@ -185,6 +192,11 @@ func (s *Server) setNewInvokeContext(invokeID string, traceID, lambdaSegmentID s return resp, nil } +type ReserveResponse struct { + Token interop.Token + InternalState *statejson.InternalStateDescription +} + // Reserve allocates invoke context func (s *Server) Reserve(id string, traceID, lambdaSegmentID string) (*ReserveResponse, error) { invokeID := uuid.New().String() @@ -196,37 +208,28 @@ func (s *Server) Reserve(id string, traceID, lambdaSegmentID string) (*ReserveRe return nil, err } - resp.InternalState, err = s.waitInit() + // The two errors reserve returns in standalone mode are INIT timeout + // and INIT failure (two types of failure: runtime exit, /init/error). Both require suppressed + // initialization, so we succeed the reservation. + invCtx := s.initContext.Reserve() + s.invoker = invCtx + resp.InternalState, err = s.InternalState() + return resp, err } -func (s *Server) waitInit() (*statejson.InternalStateDescription, error) { - for { - select { - - case doneWithState, chanOpen := <-s.InitDoneChan: - if !chanOpen { - // init only happens once - return nil, ErrInitAlreadyDone - } - - close(s.InitDoneChan) // this was first call to reserve - - if s.getRuntimeState() == runtimeInitFailed { - return &doneWithState.State, ErrInitError - } - - if len(doneWithState.ErrorType) > 0 { - log.Errorf("INIT DONE failed: %s", doneWithState.ErrorType) - return &doneWithState.State, ErrInitDoneFailed - } - - return &doneWithState.State, nil - - case <-s.reservationContext.Done(): - return nil, ErrReserveReservationDone - } +func (s *Server) awaitInitCompletion() { + initSuccess, initFailure := s.initContext.Wait() + if initFailure != nil { + // In standalone, we don't have to block rapid start() goroutine until init failure is consumed + // because there is no channel back to the invoker until an invoke arrives via a Reserve() + initFailure.Ack <- struct{}{} + s.initFailures <- *initFailure + } else { + initSuccess.Ack <- struct{}{} } + // always closing the channel makes this method idempotent + close(s.initFailures) } func (s *Server) setReplyStream(w http.ResponseWriter, direct bool) (string, error) { @@ -263,6 +266,8 @@ func (s *Server) Release() error { s.reservationCancel() } + s.sandboxContext.SetInvokeReceivedTime(0) + s.sandboxContext.SetInvokeResponseMetrics(nil) s.invokeCtx = nil return nil } @@ -279,37 +284,18 @@ func (s *Server) GetCurrentInvokeID() string { return s.invokeCtx.Token.InvokeID } +// SetSandboxContext is used to set the sandbox context after intiialization of interop server. +// After refactoring all messages, this needs to be removed and made an struct parameter on initialization. +func (s *Server) SetSandboxContext(sbCtx interop.SandboxContext) { + s.sandboxContext = sbCtx +} + // SetInternalStateGetter is used to set callback which returnes internal state for /test/internalState request func (s *Server) SetInternalStateGetter(cb interop.InternalStateGetter) { s.InternalStateGetter = cb } -// StartChan returns Start emitter -func (s *Server) StartChan() <-chan *interop.Start { - return s.startChanOut -} - -// InvokeChan returns Invoke emitter -func (s *Server) InvokeChan() <-chan *interop.Invoke { - return s.invokeChanOut -} - -// ResetChan returns Reset emitter -func (s *Server) ResetChan() <-chan *interop.Reset { - return s.resetChanOut -} - -// ShutdownChan returns Shutdown emitter -func (s *Server) ShutdownChan() <-chan *interop.Shutdown { - return s.shutdownChanOut -} - -// InvalidMessageChan emits errors if there was something we could not parse -func (s *Server) TransportErrorChan() <-chan error { - return s.errorChanOut -} - -func (s *Server) sendResponseUnsafe(invokeID string, contentType string, status int, payload io.Reader) error { +func (s *Server) sendResponseUnsafe(invokeID string, additionalHeaders map[string]string, status int, payload io.Reader, trailers http.Header, request *interop.CancellableRequest, runtimeCalledResponse bool) error { if s.invokeCtx == nil || invokeID != s.invokeCtx.Token.InvokeID { return interop.ErrInvalidInvokeID } @@ -322,26 +308,15 @@ func (s *Server) sendResponseUnsafe(invokeID string, contentType string, status return fmt.Errorf("ReplyStream not available") } - // TODO: earlier, status was set to 500 if runtime called /invocation/error. However, the integration - // tests do not differentiate between /invocation/error and /invocation/response, but they check the error type: - // To identify user-errors, we should also allowlist custom errortypes and propagate them via headers. - - // s.invokeCtx.ReplyStream.WriteHeader(status) - var reportedErr error if s.invokeCtx.Direct { - if err := directinvoke.SendDirectInvokeResponse(map[string]string{"Content-Type": contentType}, payload, s.invokeCtx.ReplyStream); err != nil { + if err := directinvoke.SendDirectInvokeResponse(additionalHeaders, payload, trailers, s.invokeCtx.ReplyStream, s.interruptedResponseChan, s.sendResponseChan, request, runtimeCalledResponse); err != nil { // TODO: Do we need to drain the reader in case of a large payload and connection reuse? log.Errorf("Failed to write response to %s: %s", invokeID, err) - flusher, ok := s.invokeCtx.ReplyStream.(http.Flusher) - if !ok { - log.Error("Failed to flush response") - } - flusher.Flush() reportedErr = err } } else { - data, err := ioutil.ReadAll(payload) + data, err := io.ReadAll(payload) if err != nil { return fmt.Errorf("Failed to read response on %s: %s", invokeID, err) } @@ -352,73 +327,103 @@ func (s *Server) sendResponseUnsafe(invokeID string, contentType string, status } } - s.invokeCtx.ReplyStream.Header().Add("Content-Type", contentType) - if _, err := s.invokeCtx.ReplyStream.Write(data); err != nil { + startReadingResponseMonoTimeMs := metering.Monotime() + s.invokeCtx.ReplyStream.Header().Add(contentTypeHeader, additionalHeaders[contentTypeHeader]) + written, err := s.invokeCtx.ReplyStream.Write(data) + if err != nil { return fmt.Errorf("Failed to write response to %s: %s", invokeID, err) } + + s.sendResponseChan <- &interop.InvokeResponseMetrics{ + ProducedBytes: int64(written), + StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, + FinishReadingResponseMonoTimeMs: metering.Monotime(), + TimeShapedNs: int64(-1), + OutboundThroughputBps: int64(-1), + // FIXME: + // The runtime tells whether the function response mode is streaming or not. + // Ideally, we would want to use that value here. Since I'm just rebasing, I will leave + // as-is, but we should use that instead of relying on our memory to set this here + // because we "know" it's a streaming code path. + FunctionResponseMode: interop.FunctionResponseModeBuffered, + RuntimeCalledResponse: runtimeCalledResponse, + } } - s.sendResponseChan <- struct{}{} s.invokeCtx.ReplySent = true s.invokeCtx.Direct = false return reportedErr } -func (s *Server) SendResponse(invokeID string, contentType string, reader io.Reader) error { +func (s *Server) SendResponse(invokeID string, headers map[string]string, reader io.Reader, trailers http.Header, request *interop.CancellableRequest) error { s.setRuntimeState(runtimeInvokeResponseSent) s.mutex.Lock() defer s.mutex.Unlock() - return s.sendResponseUnsafe(invokeID, contentType, http.StatusOK, reader) -} - -func (s *Server) CommitResponse() error { return nil } - -func (s *Server) SendRunning(run *interop.Running) error { - s.setRuntimeState(runtimeInitStarted) - s.sendRunningChan <- run - return nil + runtimeCalledResponse := true + return s.sendResponseUnsafe(invokeID, headers, http.StatusOK, reader, trailers, request, runtimeCalledResponse) } -func (s *Server) SendErrorResponse(invokeID string, resp *interop.ErrorResponse) error { - switch s.getRapidPhase() { - case phaseInitializing: - s.setRuntimeState(runtimeInitError) - return nil - case phaseInvoking: - // This branch can also occur during a suppressed init error, which is reported as invoke error - s.setRuntimeState(runtimeInvokeError) - s.mutex.Lock() - defer s.mutex.Unlock() - return s.sendResponseUnsafe(invokeID, resp.ContentType, http.StatusInternalServerError, bytes.NewReader(resp.Payload)) - default: - panic("received unexpected error response outside invoke or init phases") +func (s *Server) SendInitErrorResponse(invokeID string, resp *interop.ErrorResponse) error { + log.Debugf("Sending Init Error Response: %s", resp.ErrorType) + if s.getRapidPhase() == phaseInvoking { + // This branch occurs during suppressed init + return s.SendErrorResponse(invokeID, resp) } -} -func (s *Server) SendDone(done *interop.Done) error { - s.doneChan <- done + // Handle an /init/error outside of the invoke phase + s.setCachedInitErrorResponse(resp) + s.setRuntimeState(runtimeInitError) return nil } -func (s *Server) SendDoneFail(doneFail *interop.DoneFail) error { - s.doneChan <- &interop.Done{ - ErrorType: doneFail.ErrorType, - CorrelationID: doneFail.CorrelationID, // filipovi: correlationID is required to dispatch message into correct channel - Meta: doneFail.Meta, +func (s *Server) SendErrorResponse(invokeID string, resp *interop.ErrorResponse) error { + log.Debugf("Sending Error Response: %s", resp.ErrorType) + s.setRuntimeState(runtimeInvokeError) + s.mutex.Lock() + defer s.mutex.Unlock() + additionalHeaders := map[string]string{contentTypeHeader: resp.ContentType, errorTypeHeader: resp.ErrorType} + if functionResponseMode := resp.FunctionResponseMode; functionResponseMode != "" { + additionalHeaders[functionResponseModeHeader] = functionResponseMode } - return nil + runtimeCalledResponse := false // we are sending an error here, so runtime called /error or crashed/timeout + return s.sendResponseUnsafe(invokeID, additionalHeaders, http.StatusInternalServerError, bytes.NewReader(resp.Payload), nil, nil, runtimeCalledResponse) } func (s *Server) Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) { // pass reset to rapid - s.resetChanOut <- &interop.Reset{ - Reason: reason, - DeadlineNs: deadlineNsFromTimeoutMs(timeoutMs), - CorrelationID: "resetCorrelationID", + reset := &interop.Reset{ + Reason: reason, + DeadlineNs: deadlineNsFromTimeoutMs(timeoutMs), } + go func() { + select { + case s.interruptedResponseChan <- reset: + <-s.interruptedResponseChan // wait for response streaming metrics being added to reset struct + s.sandboxContext.SetInvokeResponseMetrics(reset.InvokeResponseMetrics) + default: + } + + resetSuccess, resetFailure := s.sandboxContext.Reset(reset) + s.Clear() // clear server state to prepare for new invokes + s.setRapidPhase(phaseIdle) + s.setRuntimeState(runtimeNotStarted) + + var meta interop.DoneMetadata + if reset.InvokeResponseMetrics != nil { + meta.RuntimeTimeThrottledMs = reset.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) + meta.RuntimeProducedBytes = reset.InvokeResponseMetrics.ProducedBytes + meta.RuntimeOutboundThroughputBps = reset.InvokeResponseMetrics.OutboundThroughputBps + } + + if resetFailure != nil { + meta.ExtensionsResetMs = resetFailure.ExtensionsResetMs + s.ResetDoneChan <- &interop.Done{ErrorType: resetFailure.ErrorType, Meta: meta} + } else { + meta.ExtensionsResetMs = resetSuccess.ExtensionsResetMs + s.ResetDoneChan <- &interop.Done{ErrorType: resetSuccess.ErrorType, Meta: meta} + } + }() - // TODO do not block on reset, instead consume ResetDoneChan in waitForRelease handler, - // this will get us more aligned on async reset notification handling. done := <-s.ResetDoneChan s.Release() @@ -431,14 +436,11 @@ func (s *Server) Reset(reason string, timeoutMs int64) (*statejson.ResetDescript func NewServer(ctx context.Context) *Server { s := &Server{ - startChanOut: make(chan *interop.Start), - invokeChanOut: make(chan *interop.Invoke), - errorChanOut: make(chan error), - resetChanOut: make(chan *interop.Reset), - shutdownChanOut: make(chan *interop.Shutdown), - - sendRunningChan: make(chan *interop.Running), - sendResponseChan: make(chan struct{}), + initChanOut: make(chan *interop.Init), + interruptedResponseChan: make(chan *interop.Reset), + + sendRunningChan: make(chan *interop.InitStarted), + sendResponseChan: make(chan *interop.InvokeResponseMetrics), doneChan: make(chan *interop.Done), // These two channels are buffered, because they are depleted asynchronously (by reserve and waitUntilRelease) and we don't want to block in SendDone until they are called @@ -449,47 +451,9 @@ func NewServer(ctx context.Context) *Server { ShutdownDoneChan: make(chan *interop.Done), } - go s.dispatchDone() - return s } -func (s *Server) setInitDoneRuntimeState(done *interop.Done) { - if len(done.ErrorType) > 0 { - s.setRuntimeState(runtimeInitFailed) // donefail - } else { - s.setRuntimeState(runtimeInitComplete) // done - } -} - -// Note, the dispatch loop below has potential to block, when -// channel is not drained. E.g. if test assumes sandbox init -// completion before dispatching reset, then reset will block -// until init channel is drained. -func (s *Server) dispatchDone() { - for { - done := <-s.doneChan - log.Debug("Dispatching DONE:", done.CorrelationID) - internalState := s.InternalStateGetter() - s.setRapidPhase(phaseIdle) - if done.CorrelationID == "initCorrelationID" { - s.setInitDoneRuntimeState(done) - s.InitDoneChan <- DoneWithState{Done: done, State: internalState} - } else if done.CorrelationID == "invokeCorrelationID" { - s.setRuntimeState(runtimeInvokeComplete) - s.InvokeDoneChan <- DoneWithState{Done: done, State: internalState} - } else if done.CorrelationID == "resetCorrelationID" { - s.setRuntimeState(runtimeNotStarted) - s.ResetDoneChan <- done - } else if done.CorrelationID == "shutdownCorrelationID" { - s.setRuntimeState(runtimeNotStarted) - s.ShutdownDoneChan <- done - } else { - panic("Received DONE without correlation ID") - } - } -} - func drainChannel(c chan DoneWithState) { for { select { @@ -509,10 +473,6 @@ func (s *Server) Clear() { s.Release() } -func (s *Server) IsResponseSent() bool { - panic("unexpected call to unimplemented method in rapidcore: IsResponseSent()") -} - func (s *Server) SendRuntimeReady() error { // only called when extensions are enabled s.setRuntimeState(runtimeReady) @@ -524,16 +484,34 @@ func deadlineNsFromTimeoutMs(timeoutMs int64) int64 { return mono + timeoutMs*1000*1000 } -func (s *Server) Init(i *interop.Start, invokeTimeoutMs int64) { - s.SetInvokeTimeout(time.Duration(invokeTimeoutMs) * time.Millisecond) +func (s *Server) setInitFailuresChan() { + s.mutex.Lock() + defer s.mutex.Unlock() + s.initFailures = make(chan interop.InitFailure) +} + +func (s *Server) getInitFailuresChan() chan interop.InitFailure { + s.mutex.Lock() + defer s.mutex.Unlock() + return s.initFailures +} - s.startChanOut <- i +func (s *Server) Init(i *interop.Init, invokeTimeoutMs int64) error { + s.SetInvokeTimeout(time.Duration(invokeTimeoutMs) * time.Millisecond) s.setRapidPhase(phaseInitializing) - <-s.sendRunningChan - log.Debug("Received RUNNING") + s.setInitFailuresChan() + initStarted, initCtx := s.sandboxContext.Init(i, invokeTimeoutMs) + initStarted.Ack <- struct{}{} + + s.initContext = initCtx + go s.awaitInitCompletion() + + log.Debugf("Received RUNNING: %v", initStarted) + return nil } func (s *Server) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error { + s.sandboxContext.SetInvokeReceivedTime(i.InvokeReceivedTime) invokeID, err := s.setReplyStream(w, direct) if err != nil { return err @@ -544,16 +522,55 @@ func (s *Server) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct boo i.ID = invokeID select { - case s.invokeChanOut <- i: - break case <-s.sendResponseChan: // we didn't pass invoke to rapid yet, but rapid already has written some response // It can happend if runtime/agent crashed even before we passed invoke to it return ErrInvokeResponseAlreadyWritten + default: } + go func() { + if s.invoker == nil { + // Reset occurred, do not send invoke request + s.InvokeDoneChan <- DoneWithState{State: s.InternalStateGetter()} + s.setRuntimeState(runtimeInvokeComplete) + return + } + s.invoker.SendRequest(i) + invokeSuccess, invokeFailure := s.invoker.Wait() + if invokeFailure != nil { + if invokeFailure.ResetReceived { + return + } + + // Rapid constructs a response body itself when invoke fails, with error type. + // These are on the handleInvokeError path, may occur during timeout resets, + // failure reset (proc exit). It is expected to be non-nil on all invoke failures. + if invokeFailure.DefaultErrorResponse == nil { + log.Panicf("default error response was nil for invoke failure, %v", invokeFailure) + } + + if cachedInitError := s.getCachedInitErrorResponse(); cachedInitError != nil { + // /init/error was called + s.trySendDefaultErrorResponse(cachedInitError) + } else { + // sent only if /error and /response not called + s.trySendDefaultErrorResponse(invokeFailure.DefaultErrorResponse) + } + doneFail := doneFailFromInvokeFailure(invokeFailure) + s.InvokeDoneChan <- DoneWithState{ + Done: &interop.Done{ErrorType: doneFail.ErrorType, Meta: doneFail.Meta}, + State: s.InternalStateGetter(), + } + } else { + done := doneFromInvokeSuccess(invokeSuccess) + s.InvokeDoneChan <- DoneWithState{Done: done, State: s.InternalStateGetter()} + } + }() + select { - case <-s.sendResponseChan: + case i.InvokeResponseMetrics = <-s.sendResponseChan: + s.sandboxContext.SetInvokeResponseMetrics(i.InvokeResponseMetrics) break case <-s.reservationContext.Done(): return ErrInvokeReservationDone @@ -562,6 +579,26 @@ func (s *Server) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct boo return nil } +func (s *Server) setCachedInitErrorResponse(errResp *interop.ErrorResponse) { + s.mutex.Lock() + defer s.mutex.Unlock() + s.cachedInitErrorResponse = errResp +} + +func (s *Server) getCachedInitErrorResponse() *interop.ErrorResponse { + s.mutex.Lock() + defer s.mutex.Unlock() + return s.cachedInitErrorResponse +} + +func (s *Server) trySendDefaultErrorResponse(resp *interop.ErrorResponse) { + if err := s.SendErrorResponse(s.GetCurrentInvokeID(), resp); err != nil { + if err != interop.ErrResponseSent { + log.Panicf("Failed to send default error response: %s", err) + } + } +} + func (s *Server) CurrentToken() *interop.Token { s.mutex.Lock() defer s.mutex.Unlock() @@ -582,77 +619,158 @@ func (s *Server) Invoke(responseWriter http.ResponseWriter, invoke *interop.Invo go func() { select { case <-time.After(s.GetInvokeTimeout()): + log.Debug("Invoke() timeout") timeoutChan <- ErrInvokeTimeout - s.Reset(autoresetReasonTimeout, resetDefaultTimeoutMs) case <-resetCtx.Done(): log.Debugf("execute finished, autoreset cancelled") } }() - reserveResp, err := s.Reserve(invoke.ID, "", "") - if err != nil { - switch err { - case ErrInitError: - // Simulate 'Suppressed Init' scenario - s.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) - reserveResp, err = s.Reserve("", "", "") - if err == ErrInitAlreadyDone { - break - } - return err - case ErrInitDoneFailed, ErrTerminated: - s.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) - return err - - case ErrInitAlreadyDone: - // This is a valid response (e.g. for 2nd and 3rd invokes) - // TODO: switch on ReserveResponse status instead of err, - // since these are valid values - if s.InternalStateGetter == nil { - responseWriter.Write([]byte("error: internal state callback not set")) - return ErrInternalServerError - } + initFailures := s.getInitFailuresChan() + if initFailures == nil { + return ErrInitNotStarted + } - default: - return err + releaseErrChan := make(chan error) + releaseSuccessChan := make(chan struct{}) + go func() { + // This thread can block in one of two method calls Reserve() & AwaitRelease(), + // corresponding to Init and Invoke phase. + // FastInvoke is intended to be 'async' response stream copying. + // When a timeout occurs, we send a 'Reset' with the timeout reason + // When a Reset is sent, the reset handler in rapid lib cancels existing flows, + // including init/invoke. This causes either initFailure/invokeFailure, and then + // the Reset is handled and processed. + // TODO: however, ideally Reserve() does not block on init, but FastInvoke does + // The logic would be almost identical, except that init failures could manifest + // through return values of FastInvoke and not Reserve() + + reserveResp, err := s.Reserve("", "", "") + if err != nil { + log.Infof("ReserveFailed: %s", err) } - } - invoke.DeadlineNs = fmt.Sprintf("%d", metering.Monotime()+reserveResp.Token.FunctionTimeout.Nanoseconds()) + invoke.DeadlineNs = fmt.Sprintf("%d", metering.Monotime()+reserveResp.Token.FunctionTimeout.Nanoseconds()) + go func() { + if initCompletionResp, err := s.awaitInitialized(); err != nil { + switch err { + case ErrInitResetReceived, ErrInitDoneFailed: + // For init failures, cache the response so they can be checked later + // We check if they have not already been set by a call to /init/error by runtime + if s.getCachedInitErrorResponse() == nil { + errType, errMsg := string(initCompletionResp.InitErrorType), initCompletionResp.InitErrorMessage.Error() + s.setCachedInitErrorResponse(&interop.ErrorResponse{ErrorType: errType, ErrorMessage: errMsg}) + } + } + } - invokeChan := make(chan error) - go func() { - if err := s.FastInvoke(responseWriter, invoke, false); err != nil { - invokeChan <- err + if err := s.FastInvoke(responseWriter, invoke, false); err != nil { + log.Debugf("FastInvoke() error: %s", err) + } + }() + + _, err = s.AwaitRelease() + if err != nil && err != ErrReleaseReservationDone { + log.Debugf("AwaitRelease() error: %s", err) + switch err { + case ErrReleaseReservationDone: // not an error, expected return value when Reset is called + if s.getCachedInitErrorResponse() != nil { + // For Init failures, AwaitRelease returns ErrReleaseReservationDone + // because the Reset calls Release & cancels the release context + // We rename the error to ErrInitDoneFailed + releaseErrChan <- ErrInitDoneFailed + } + case ErrInitDoneFailed, ErrInvokeDoneFailed: + // Reset when either init or invoke failrues occur, i.e. + // init/error, invocation/error, Runtime.ExitError, Extension.ExitError + s.Reset(autoresetReasonReleaseFail, resetDefaultTimeoutMs) + releaseErrChan <- err + default: + releaseErrChan <- err + } + return } - }() - releaseChan := make(chan error) - go func() { - _, err := s.AwaitRelease() - releaseChan <- err + releaseSuccessChan <- struct{}{} }() - // TODO: verify the order of channel receives. When timeouts happen, Reset() - // is called first, which also does Release() => this may signal a type - // Err<*>ReservationDone error to the non-timeout channels. This is currently - // handled by the http handler, which returns GatewayTimeout for reservation errors - // too. However, Timeouts should ideally be only represented by ErrInvokeTimeout. + var err error select { - case err = <-invokeChan: - case err = <-timeoutChan: - case err = <-releaseChan: - if err != nil { - s.Reset(autoresetReasonReleaseFail, resetDefaultTimeoutMs) + case timeoutErr := <-timeoutChan: + s.Reset(autoresetReasonTimeout, resetDefaultTimeoutMs) + select { + case releaseErr := <-releaseErrChan: // when AwaitRelease() has errors + log.Debugf("Invoke() release error on Execute() timeout: %s", releaseErr) + case <-releaseSuccessChan: // when AwaitRelease() finishes cleanly } + err = timeoutErr + case err = <-releaseErrChan: + log.Debug("Invoke() release error") + case <-releaseSuccessChan: + s.Release() + log.Debug("Invoke() success") } return err } +type initCompletionResponse struct { + InitErrorType fatalerror.ErrorType + InitErrorMessage error +} + +func (s *Server) awaitInitialized() (initCompletionResponse, error) { + initFailure, awaitingInitStatus := <-s.getInitFailuresChan() + resp := initCompletionResponse{} + + if initFailure.ResetReceived { + // Resets during Init are only received in standalone + // during an invoke timeout + s.setRuntimeState(runtimeInitFailed) + resp.InitErrorType = initFailure.ErrorType + resp.InitErrorMessage = initFailure.ErrorMessage + return resp, ErrInitResetReceived + } + + if awaitingInitStatus { + // channel not closed, received init failure + // Sandbox can be reserved even if init failed (due to function errors) + s.setRuntimeState(runtimeInitFailed) + resp.InitErrorType = initFailure.ErrorType + resp.InitErrorMessage = initFailure.ErrorMessage + return resp, ErrInitDoneFailed + } + + // not awaiting init status (channel closed) + return resp, nil +} + +// AwaitInitialized waits until init is complete. It must be idempotent, +// since it can be called twice when a caller wants to wait until init is complete +func (s *Server) AwaitInitialized() error { + if _, err := s.awaitInitialized(); err != nil { + if releaseErr := s.Release(); err != nil { + log.Infof("Error releasing after init failure %s: %s", err, releaseErr) + } + s.setRuntimeState(runtimeInitFailed) + return err + } + s.setRuntimeState(runtimeInitComplete) + return nil +} + func (s *Server) AwaitRelease() (*statejson.InternalStateDescription, error) { + defer func() { + s.setRapidPhase(phaseIdle) + s.setRuntimeState(runtimeInvokeComplete) + }() + select { case doneWithState := <-s.InvokeDoneChan: + if len(doneWithState.ErrorType) > 0 && string(doneWithState.ErrorType) == ErrInitDoneFailed.Error() { + return nil, ErrInitDoneFailed + } + if len(doneWithState.ErrorType) > 0 { log.Errorf("Invoke DONE failed: %s", doneWithState.ErrorType) return nil, ErrInvokeDoneFailed @@ -667,8 +785,13 @@ func (s *Server) AwaitRelease() (*statejson.InternalStateDescription, error) { } func (s *Server) Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription { - s.shutdownChanOut <- shutdown - <-s.ShutdownDoneChan + shutdownSuccess := s.sandboxContext.Shutdown(shutdown) + if len(shutdownSuccess.ErrorType) > 0 { + log.Errorf("Shutdown first fatal error: %s", shutdownSuccess.ErrorType) + } + + s.setRapidPhase(phaseIdle) + s.setRuntimeState(runtimeNotStarted) state := s.InternalStateGetter() return &state @@ -682,3 +805,49 @@ func (s *Server) InternalState() (*statejson.InternalStateDescription, error) { state := s.InternalStateGetter() return &state, nil } + +func (s *Server) Restore(restore *interop.Restore) error { + return s.sandboxContext.Restore(restore) +} + +func doneFromInvokeSuccess(successMsg interop.InvokeSuccess) *interop.Done { + return &interop.Done{ + Meta: interop.DoneMetadata{ + RuntimeRelease: successMsg.RuntimeRelease, + NumActiveExtensions: successMsg.NumActiveExtensions, + ExtensionNames: successMsg.ExtensionNames, + InvokeRequestReadTimeNs: successMsg.InvokeMetrics.InvokeRequestReadTimeNs, + InvokeRequestSizeBytes: successMsg.InvokeMetrics.InvokeRequestSizeBytes, + RuntimeReadyTime: successMsg.InvokeMetrics.RuntimeReadyTime, + + InvokeCompletionTimeNs: successMsg.InvokeCompletionTimeNs, + InvokeReceivedTime: successMsg.InvokeReceivedTime, + RuntimeTimeThrottledMs: successMsg.ResponseMetrics.RuntimeTimeThrottledMs, + RuntimeProducedBytes: successMsg.ResponseMetrics.RuntimeProducedBytes, + RuntimeOutboundThroughputBps: successMsg.ResponseMetrics.RuntimeOutboundThroughputBps, + LogsAPIMetrics: successMsg.LogsAPIMetrics, + }, + } +} + +func doneFailFromInvokeFailure(failureMsg *interop.InvokeFailure) *interop.DoneFail { + return &interop.DoneFail{ + ErrorType: failureMsg.ErrorType, + Meta: interop.DoneMetadata{ + RuntimeRelease: failureMsg.RuntimeRelease, + NumActiveExtensions: failureMsg.NumActiveExtensions, + InvokeReceivedTime: failureMsg.InvokeReceivedTime, + + RuntimeTimeThrottledMs: failureMsg.ResponseMetrics.RuntimeTimeThrottledMs, + RuntimeProducedBytes: failureMsg.ResponseMetrics.RuntimeProducedBytes, + RuntimeOutboundThroughputBps: failureMsg.ResponseMetrics.RuntimeOutboundThroughputBps, + + InvokeRequestReadTimeNs: failureMsg.InvokeMetrics.InvokeRequestReadTimeNs, + InvokeRequestSizeBytes: failureMsg.InvokeMetrics.InvokeRequestSizeBytes, + RuntimeReadyTime: failureMsg.InvokeMetrics.RuntimeReadyTime, + + ExtensionNames: failureMsg.ExtensionNames, + LogsAPIMetrics: failureMsg.LogsAPIMetrics, + }, + } +} diff --git a/lambda/rapidcore/server_test.go b/lambda/rapidcore/server_test.go index 416304c..88eea3f 100644 --- a/lambda/rapidcore/server_test.go +++ b/lambda/rapidcore/server_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/require" "go.amzn.com/lambda/core/statejson" "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapidcore/env" ) func waitForChanWithTimeout(channel <-chan error, timeout time.Duration) error { @@ -26,15 +27,68 @@ func waitForChanWithTimeout(channel <-chan error, timeout time.Duration) error { } } +func sendInitStartedResponse(responseChannel chan<- interop.InitStarted, msg interop.InitStarted) { + msg.Ack = make(chan struct{}) + responseChannel <- msg + <-msg.Ack +} + +func sendInitSuccessResponse(responseChannel chan<- interop.InitSuccess, msg interop.InitSuccess) { + msg.Ack = make(chan struct{}) + responseChannel <- msg + <-msg.Ack +} + +func sendInitFailureResponse(responseChannel chan<- interop.InitFailure, msg interop.InitFailure) { + msg.Ack = make(chan struct{}) + responseChannel <- msg + <-msg.Ack +} + +type mockRapidCtx struct { + initHandler func(start chan<- interop.InitStarted, success chan<- interop.InitSuccess, fail chan<- interop.InitFailure) + invokeHandler func() (interop.InvokeSuccess, *interop.InvokeFailure) + resetHandler func() (interop.ResetSuccess, *interop.ResetFailure) +} + +func (r *mockRapidCtx) HandleInit(init *interop.Init, startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + r.initHandler(startResp, successResp, failureResp) +} + +func (r *mockRapidCtx) HandleInvoke(invoke *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit) (interop.InvokeSuccess, *interop.InvokeFailure) { + return r.invokeHandler() +} + +func (r *mockRapidCtx) HandleReset(reset *interop.Reset, invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) (interop.ResetSuccess, *interop.ResetFailure) { + return r.resetHandler() +} + +func (r *mockRapidCtx) HandleShutdown(shutdown *interop.Shutdown) interop.ShutdownSuccess { + return interop.ShutdownSuccess{} +} + +func (r *mockRapidCtx) HandleRestore(restore *interop.Restore) error { + return nil +} + +func (r *mockRapidCtx) Clear() {} + func TestReserveDoesNotDeadlockWhenCalledMultipleTimes(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { <-srv.StartChan() }() - go srv.SendRunning(&interop.Running{}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + sendInitSuccessResponse(successResp, interop.InitSuccess{}) + } + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ + initHandler, + func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, + func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, + }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) - go srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"}) _, err := srv.Reserve("", "", "") // reserve successfully require.NoError(t, err) @@ -61,89 +115,120 @@ func TestInitSuccess(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"})) - }() + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + sendInitSuccessResponse(successResp, interop.InitSuccess{}) + } + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ + initHandler, + func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, + func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, + }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) _, err := srv.Reserve("", "", "") require.NoError(t, err) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInitComplete), srv.getRuntimeState()) } func TestInitErrorBeforeReserve(t *testing.T) { + // Rapid thread sending init failure should not be blocked even if reserve hasn't arrived srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) initErrorResponseSent := make(chan error) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "foobar"})) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + sendInitFailureResponse(failureResp, interop.InitFailure{}) initErrorResponseSent <- errors.New("initErrorResponseSent") - }() + } + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ + initHandler, + func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, + func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, + }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) if msg := waitForChanWithTimeout(initErrorResponseSent, 1*time.Second); msg == nil { require.Fail(t, "Timed out waiting for init error response to be sent") } resp, err := srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitError.Error()) + require.NoError(t, err) require.True(t, len(resp.Token.InvokeID) > 0) - require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + + awaitInitErr := srv.AwaitInitialized() + require.Error(t, ErrInitDoneFailed, awaitInitErr) + + _, err = srv.AwaitRelease() + require.Error(t, err, ErrReleaseReservationDone) + require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) } func TestInitErrorDuringReserve(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "foobar"})) - }() + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + sendInitFailureResponse(failureResp, interop.InitFailure{}) + } + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ + initHandler, + func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, + func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, + }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) resp, err := srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitError.Error()) + require.NoError(t, err) require.True(t, len(resp.Token.InvokeID) > 0) - require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + + awaitInitErr := srv.AwaitInitialized() + require.Error(t, ErrInitDoneFailed, awaitInitErr) + + _, err = srv.AwaitRelease() + require.Error(t, err, ErrReleaseReservationDone) + require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) } func TestInvokeSuccess(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"})) + releaseRuntimeInit := make(chan struct{}) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + <-releaseRuntimeInit + sendInitSuccessResponse(successResp, interop.InitSuccess{}) + } - <-srv.InvokeChan() - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), "application/json", bytes.NewReader([]byte("response")))) + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), map[string]string{"Content-Type": "application/json"}, bytes.NewReader([]byte("response")), nil, nil)) require.NoError(t, srv.SendRuntimeReady()) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) - }() + return interop.InvokeSuccess{}, nil + } - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil } + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) + releaseRuntimeInit <- struct{}{} _, err := srv.Reserve("", "", "") require.NoError(t, err) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInitComplete), srv.getRuntimeState()) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) // Reserve does not wait for init completion + + awaitInitErr := srv.AwaitInitialized() + require.NoError(t, awaitInitErr) responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) require.NoError(t, invokeErr) require.Equal(t, "response", responseRecorder.Body.String()) require.Equal(t, "application/json", responseRecorder.Result().Header.Get("Content-Type")) @@ -157,28 +242,35 @@ func TestInvokeError(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"})) - - <-srv.InvokeChan() + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + sendInitSuccessResponse(successResp, interop.InitSuccess{}) + } + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }"), ContentType: "application/json"})) require.NoError(t, srv.SendRuntimeReady()) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) - }() + return interop.InvokeSuccess{}, nil + } - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { + return interop.ResetSuccess{}, nil + } + + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) _, err := srv.Reserve("", "", "") require.NoError(t, err) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInitComplete), srv.getRuntimeState()) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) + + awaitInitErr := srv.AwaitInitialized() + require.NoError(t, awaitInitErr) responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) require.NoError(t, invokeErr) require.Equal(t, "{ 'errorType': 'A.B' }", responseRecorder.Body.String()) require.Equal(t, "application/json", responseRecorder.Result().Header.Get("Content-Type")) @@ -203,43 +295,49 @@ func TestInvokeWithSuppressedInitSuccess(t *testing.T) { srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) initErrorCompleted := make(chan error) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "foobar"})) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + sendInitFailureResponse(failureResp, interop.InitFailure{}) initErrorCompleted <- errors.New("initErrorSequenceCompleted") + } - <-srv.ResetChan() - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "resetCorrelationID"})) + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), nil, bytes.NewReader([]byte("response")), nil, nil)) + return interop.InvokeSuccess{}, nil + } - <-srv.InvokeChan() // run only after FastInvoke is called - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), "", bytes.NewReader([]byte("response")))) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) - }() + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { + return interop.ResetSuccess{}, nil + } + + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) if msg := waitForChanWithTimeout(initErrorCompleted, 1*time.Second); msg == nil { require.Fail(t, "Timed out waiting for init error sequence to be called") } - _, err := srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitError.Error()) - require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + resp, err := srv.Reserve("", "", "") + require.NoError(t, err) + require.True(t, len(resp.Token.InvokeID) > 0) + + awaitInitErr := srv.AwaitInitialized() + require.Error(t, ErrInitDoneFailed, awaitInitErr) _, err = srv.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) // prepare for suppressed init require.NoError(t, err) _, err = srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitAlreadyDone.Error()) + require.NoError(t, err) responseRecorder := httptest.NewRecorder() successChan := make(chan error) go func() { directInvoke := false - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, directInvoke) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, directInvoke) require.NoError(t, invokeErr) successChan <- errors.New("invokeResponseWritten") }() @@ -261,39 +359,45 @@ func TestInvokeWithSuppressedInitErrorDueToInitError(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + sendInitFailureResponse(failureResp, interop.InitFailure{}) + } + releaseChan := make(chan error) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "A.B"})) + releaseChan <- nil + return interop.InvokeSuccess{}, &interop.InvokeFailure{ErrorType: "A.B", RequestReset: true, DefaultErrorResponse: &interop.ErrorResponse{}} + } - <-srv.ResetChan() - srv.SendDone(&interop.Done{CorrelationID: "resetCorrelationID"}) + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { + return interop.ResetSuccess{}, nil + } - <-srv.InvokeChan() - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - releaseChan <- nil - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "invokeCorrelationID", ErrorType: "A.B"})) - }() + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) - _, err := srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitError.Error()) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + resp, err := srv.Reserve("", "", "") + require.NoError(t, err) + require.True(t, len(resp.Token.InvokeID) > 0) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) + + awaitInitErr := srv.AwaitInitialized() + require.Error(t, ErrInitDoneFailed, awaitInitErr) _, err = srv.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) // prepare for invoke with suppressed init require.NoError(t, err) require.Equal(t, phaseIdle, srv.getRapidPhase()) _, err = srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitAlreadyDone.Error()) + require.NoError(t, err) require.Equal(t, phaseIdle, srv.getRapidPhase()) responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) require.NoError(t, invokeErr) require.Equal(t, "{ 'errorType': 'A.B' }", responseRecorder.Body.String()) require.Equal(t, phaseInvoking, srv.getRapidPhase()) @@ -310,39 +414,43 @@ func TestInvokeWithSuppressedInitErrorDueToInvokeError(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - require.NoError(t, srv.SendDoneFail(&interop.DoneFail{CorrelationID: "initCorrelationID", ErrorType: "A.B"})) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + sendInitFailureResponse(failureResp, interop.InitFailure{}) + } + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { + require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'B.C' }")})) + require.NoError(t, srv.SendRuntimeReady()) + return interop.InvokeSuccess{}, nil + } - <-srv.ResetChan() - srv.SendDone(&interop.Done{CorrelationID: "resetCorrelationID"}) + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { + return interop.ResetSuccess{}, nil + } - <-srv.InvokeChan() - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'B.C' }")})) - require.NoError(t, srv.SendRuntimeReady()) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) - }() + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) - _, err := srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitError.Error()) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInitFailed), srv.getRuntimeState()) + resp, err := srv.Reserve("", "", "") + require.NoError(t, err) + require.True(t, len(resp.Token.InvokeID) > 0) + + awaitInitErr := srv.AwaitInitialized() + require.Error(t, ErrInitDoneFailed, awaitInitErr) _, err = srv.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) // prepare for invoke with suppressed init require.NoError(t, err) require.Equal(t, phaseIdle, srv.getRapidPhase()) _, err = srv.Reserve("", "", "") - require.EqualError(t, err, ErrInitAlreadyDone.Error()) + require.NoError(t, err) require.Equal(t, phaseIdle, srv.getRapidPhase()) responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) require.NoError(t, invokeErr) require.Equal(t, "{ 'errorType': 'B.C' }", responseRecorder.Body.String()) @@ -356,39 +464,43 @@ func TestMultipleInvokeSuccess(t *testing.T) { srv := NewServer(context.Background()) srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - go func() { - <-srv.StartChan() - require.NoError(t, srv.SendRunning(&interop.Running{})) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "initCorrelationID"})) - }() - - srv.Init(&interop.Start{CorrelationID: "initCorrelationID"}, int64(1*time.Second*time.Millisecond)) - require.Equal(t, phaseInitializing, srv.getRapidPhase()) - - invokeFunc := func(i int) { - <-srv.InvokeChan() - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), "", bytes.NewReader([]byte("response-"+fmt.Sprint(i))))) + initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + sendInitStartedResponse(startResp, interop.InitStarted{}) + sendInitSuccessResponse(successResp, interop.InitSuccess{}) + } + i := 0 + invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), nil, bytes.NewReader([]byte("response-"+fmt.Sprint(i))), nil, nil)) require.NoError(t, srv.SendRuntimeReady()) - require.NoError(t, srv.SendDone(&interop.Done{CorrelationID: "invokeCorrelationID"})) + i++ + return interop.InvokeSuccess{}, nil + } + + resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { + return interop.ResetSuccess{}, nil } - go func() { - for i := 0; i < 3; i++ { - invokeFunc(i) - } - }() + + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + + srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) + require.Equal(t, phaseInitializing, srv.getRapidPhase()) for i := 0; i < 3; i++ { _, err := srv.Reserve("", "", "") - require.Contains(t, []error{nil, ErrInitAlreadyDone}, err) - require.Equal(t, phaseIdle, srv.getRapidPhase()) + require.NoError(t, err) + + awaitInitErr := srv.AwaitInitialized() + require.NoError(t, awaitInitErr) responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{CorrelationID: "invokeCorrelationID"}, false) + invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) require.NoError(t, invokeErr) require.Equal(t, "response-"+fmt.Sprint(i), responseRecorder.Body.String()) + require.Equal(t, phaseInvoking, srv.getRapidPhase()) _, err = srv.AwaitRelease() require.NoError(t, err) + require.Equal(t, phaseIdle, srv.getRapidPhase()) require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) } } diff --git a/lambda/rapidcore/standalone/directInvokeHandler.go b/lambda/rapidcore/standalone/directInvokeHandler.go index a485deb..1c7e7cb 100644 --- a/lambda/rapidcore/standalone/directInvokeHandler.go +++ b/lambda/rapidcore/standalone/directInvokeHandler.go @@ -4,13 +4,15 @@ package standalone import ( - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/core/directinvoke" "go.amzn.com/lambda/rapidcore" + "net/http" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/core/directinvoke" ) -func DirectInvokeHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { +func DirectInvokeHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { tok := s.CurrentToken() if tok == nil { log.Errorf("Attempt to call directInvoke without Reserve") @@ -24,6 +26,14 @@ func DirectInvokeHandler(w http.ResponseWriter, r *http.Request, s rapidcore.Int return } + if err := s.AwaitInitialized(); err != nil { + w.WriteHeader(DoneFailedHTTPCode) + if state, err := s.InternalState(); err == nil { + w.Write(state.AsJSON()) + } + return + } + if err := s.FastInvoke(w, invoke, true); err != nil { switch err { case rapidcore.ErrNotReserved: diff --git a/lambda/rapidcore/standalone/executeHandler.go b/lambda/rapidcore/standalone/executeHandler.go index 36c257a..9bac400 100644 --- a/lambda/rapidcore/standalone/executeHandler.go +++ b/lambda/rapidcore/standalone/executeHandler.go @@ -8,16 +8,17 @@ import ( log "github.com/sirupsen/logrus" "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" "go.amzn.com/lambda/rapidcore" ) -func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.Sandbox) { +func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.LambdaInvokeAPI) { invokePayload := &interop.Invoke{ - TraceID: r.Header.Get("X-Amzn-Trace-Id"), - LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), - Payload: r.Body, - CorrelationID: "invokeCorrelationID", + TraceID: r.Header.Get("X-Amzn-Trace-Id"), + LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), + Payload: r.Body, + InvokeReceivedTime: metering.Monotime(), } // If we write to 'w' directly and waitUntilRelease fails, we won't be able to propagate error anymore @@ -38,17 +39,17 @@ func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.Sandbox) case rapidcore.ErrInvokeResponseAlreadyWritten: return - case rapidcore.ErrInvokeTimeout: + case rapidcore.ErrInvokeTimeout, rapidcore.ErrInitResetReceived: w.WriteHeader(http.StatusGatewayTimeout) // DONE failures: - case rapidcore.ErrTerminated, rapidcore.ErrInitDoneFailed, rapidcore.ErrInvokeDoneFailed: + case rapidcore.ErrInvokeDoneFailed: copyHeaders(invokeResp, w) w.WriteHeader(DoneFailedHTTPCode) w.Write(invokeResp.Body) return // Reservation canceled errors - case rapidcore.ErrReserveReservationDone, rapidcore.ErrInvokeReservationDone, rapidcore.ErrReleaseReservationDone: + case rapidcore.ErrReserveReservationDone, rapidcore.ErrInvokeReservationDone, rapidcore.ErrReleaseReservationDone, rapidcore.ErrInitNotStarted: w.WriteHeader(http.StatusGatewayTimeout) } diff --git a/lambda/rapidcore/standalone/initHandler.go b/lambda/rapidcore/standalone/initHandler.go index d006b81..d60ec6f 100644 --- a/lambda/rapidcore/standalone/initHandler.go +++ b/lambda/rapidcore/standalone/initHandler.go @@ -7,21 +7,33 @@ import ( "fmt" "net/http" "os" + "time" "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapidcore" "go.amzn.com/lambda/rapidcore/env" ) +type RuntimeInfo struct { + ImageJSON string `json:"runtimeImageJSON,omitempty"` + Arn string `json:"runtimeArn,omitempty"` + Version string `json:"runtimeVersion,omitempty"` +} + // TODO: introduce suppress init flag type InitBody struct { - Handler string `json:"handler"` - FunctionName string `json:"functionName"` - FunctionVersion string `json:"functionVersion"` - InvokeTimeoutMs int64 `json:"invokeTimeoutMs"` + Handler string `json:"handler"` + FunctionName string `json:"functionName"` + FunctionVersion string `json:"functionVersion"` + InvokeTimeoutMs int64 `json:"invokeTimeoutMs"` + RuntimeInfo RuntimeInfo `json:"runtimeInfo"` Customer struct { Environment map[string]string `json:"environment"` } `json:"customer"` + AwsKey *string `json:"awskey"` + AwsSecret *string `json:"awssecret"` + AwsSession *string `json:"awssession"` + CredentialsExpiry time.Time `json:"credentialsExpiry"` + Throttled bool `json:"throttled"` } type InitRequest struct { @@ -44,7 +56,7 @@ func (c *InitBody) Validate() error { return nil } -func InitHandler(w http.ResponseWriter, r *http.Request, sandbox rapidcore.Sandbox) { +func InitHandler(w http.ResponseWriter, r *http.Request, sandbox InteropServer, bs interop.Bootstrap) { init := InitBody{} if lerr := readBodyAndUnmarshalJSON(r, &init); lerr != nil { lerr.Send(w, r) @@ -61,20 +73,54 @@ func InitHandler(w http.ResponseWriter, r *http.Request, sandbox rapidcore.Sandb // logic consistent across standalone-mode and girp-mode os.Setenv(envKey, envVal) } - // TODO generate CorrelationID + + awsKey, awsSecret, awsSession := getCredentials(init) + + sandboxType := interop.SandboxClassic + + if init.Throttled { + sandboxType = interop.SandboxPreWarmed + } // pass to rapid sandbox.Init(&interop.Init{ Handler: init.Handler, - CorrelationID: "initCorrelationID", - AwsKey: os.Getenv("AWS_ACCESS_KEY_ID"), - AwsSecret: os.Getenv("AWS_SECRET_ACCESS_KEY"), - AwsSession: os.Getenv("AWS_SESSION_TOKEN"), + AwsKey: awsKey, + AwsSecret: awsSecret, + AwsSession: awsSession, + CredentialsExpiry: init.CredentialsExpiry, XRayDaemonAddress: "0.0.0.0:0", // TODO FunctionName: init.FunctionName, FunctionVersion: init.FunctionVersion, - + RuntimeInfo: interop.RuntimeInfo{ + ImageJSON: init.RuntimeInfo.ImageJSON, + Arn: init.RuntimeInfo.Arn, + Version: init.RuntimeInfo.Version}, CustomerEnvironmentVariables: env.CustomerEnvironmentVariables(), + SandboxType: sandboxType, + Bootstrap: bs, + EnvironmentVariables: env.NewEnvironment(), }, init.InvokeTimeoutMs) +} + +func getCredentials(init InitBody) (string, string, string) { + // ToDo(guvfatih): I think instead of passing and getting these credentials values via environment variables + // we need to make StandaloneTests passing these via the Init request to be compliant with the existing protocol. + awsKey := os.Getenv("AWS_ACCESS_KEY_ID") + awsSecret := os.Getenv("AWS_SECRET_ACCESS_KEY") + awsSession := os.Getenv("AWS_SESSION_TOKEN") + + if init.AwsKey != nil { + awsKey = *init.AwsKey + } + + if init.AwsSecret != nil { + awsSecret = *init.AwsSecret + } + + if init.AwsSession != nil { + awsSession = *init.AwsSession + } + return awsKey, awsSecret, awsSession } diff --git a/lambda/rapidcore/standalone/internalStateHandler.go b/lambda/rapidcore/standalone/internalStateHandler.go index ff1335a..cb40c1c 100644 --- a/lambda/rapidcore/standalone/internalStateHandler.go +++ b/lambda/rapidcore/standalone/internalStateHandler.go @@ -5,11 +5,9 @@ package standalone import ( "net/http" - - "go.amzn.com/lambda/rapidcore" ) -func InternalStateHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { +func InternalStateHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { state, err := s.InternalState() if err != nil { http.Error(w, "internal state callback not set", http.StatusInternalServerError) diff --git a/lambda/rapidcore/standalone/invokeHandler.go b/lambda/rapidcore/standalone/invokeHandler.go index 0d89f1c..3e9768c 100644 --- a/lambda/rapidcore/standalone/invokeHandler.go +++ b/lambda/rapidcore/standalone/invokeHandler.go @@ -14,7 +14,7 @@ import ( log "github.com/sirupsen/logrus" ) -func InvokeHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { +func InvokeHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { tok := s.CurrentToken() if tok == nil { log.Errorf("Attempt to call directInvoke without Reserve") @@ -22,28 +22,20 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropSe return } - isResyncReceivedFlag := false - - awsKey := r.Header.Get("ResyncAwsKey") - awsSecret := r.Header.Get("ResyncAwsSecret") - awsSession := r.Header.Get("ResyncAwsSession") - - if len(awsKey) > 0 && len(awsSecret) > 0 && len(awsSession) > 0 { - isResyncReceivedFlag = true + invokePayload := &interop.Invoke{ + TraceID: r.Header.Get("X-Amzn-Trace-Id"), + LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), + Payload: r.Body, + DeadlineNs: fmt.Sprintf("%d", metering.Monotime()+tok.FunctionTimeout.Nanoseconds()), + InvokeReceivedTime: metering.Monotime(), } - invokePayload := &interop.Invoke{ - TraceID: r.Header.Get("X-Amzn-Trace-Id"), - LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), - Payload: r.Body, - CorrelationID: "invokeCorrelationID", - DeadlineNs: fmt.Sprintf("%d", metering.Monotime()+tok.FunctionTimeout.Nanoseconds()), - ResyncState: interop.Resync{ - IsResyncReceived: isResyncReceivedFlag, - AwsKey: awsKey, - AwsSecret: awsSecret, - AwsSession: awsSession, - }, + if err := s.AwaitInitialized(); err != nil { + w.WriteHeader(DoneFailedHTTPCode) + if state, err := s.InternalState(); err == nil { + w.Write(state.AsJSON()) + } + return } if err := s.FastInvoke(w, invokePayload, false); err != nil { diff --git a/lambda/rapidcore/standalone/pingHandler.go b/lambda/rapidcore/standalone/pingHandler.go new file mode 100644 index 0000000..c6cb021 --- /dev/null +++ b/lambda/rapidcore/standalone/pingHandler.go @@ -0,0 +1,12 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package standalone + +import ( + "net/http" +) + +func PingHandler(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("pong")) +} diff --git a/lambda/rapidcore/standalone/reserveHandler.go b/lambda/rapidcore/standalone/reserveHandler.go index d3e0b9f..52b51cd 100644 --- a/lambda/rapidcore/standalone/reserveHandler.go +++ b/lambda/rapidcore/standalone/reserveHandler.go @@ -24,24 +24,14 @@ func tokenToHeaders(w http.ResponseWriter, token interop.Token) { w.Header().Set(directinvoke.VersionIDHeader, token.VersionID) } -func ReserveHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { +func ReserveHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { reserveResp, err := s.Reserve("", r.Header.Get("X-Amzn-Trace-Id"), r.Header.Get("X-Amzn-Segment-Id")) if err != nil { switch err { - case rapidcore.ErrInitAlreadyDone: - // init already happened before, just provide internal state and return - tokenToHeaders(w, reserveResp.Token) - InternalStateHandler(w, r, s) case rapidcore.ErrReserveReservationDone: // TODO use http.StatusBadGateway w.WriteHeader(http.StatusGatewayTimeout) - case rapidcore.ErrInitDoneFailed, rapidcore.ErrInitError: - w.WriteHeader(DoneFailedHTTPCode) - w.Write(reserveResp.InternalState.AsJSON()) - case rapidcore.ErrTerminated: - w.WriteHeader(DoneFailedHTTPCode) - w.Write(reserveResp.InternalState.AsJSON()) default: log.Errorf("Failed to reserve: %s", err) w.WriteHeader(400) diff --git a/lambda/rapidcore/standalone/resetHandler.go b/lambda/rapidcore/standalone/resetHandler.go index 1a719ff..4f2ca2e 100644 --- a/lambda/rapidcore/standalone/resetHandler.go +++ b/lambda/rapidcore/standalone/resetHandler.go @@ -5,8 +5,6 @@ package standalone import ( "net/http" - - "go.amzn.com/lambda/rapidcore" ) type resetAPIRequest struct { @@ -14,7 +12,7 @@ type resetAPIRequest struct { TimeoutMs int64 `json:"timeoutMs"` } -func ResetHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { +func ResetHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { reset := resetAPIRequest{} if lerr := readBodyAndUnmarshalJSON(r, &reset); lerr != nil { lerr.Send(w, r) diff --git a/lambda/rapidcore/standalone/restoreHandler.go b/lambda/rapidcore/standalone/restoreHandler.go new file mode 100644 index 0000000..190b6d8 --- /dev/null +++ b/lambda/rapidcore/standalone/restoreHandler.go @@ -0,0 +1,41 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package standalone + +import ( + "net/http" + "time" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/interop" +) + +type RestoreBody struct { + AwsKey string `json:"awskey"` + AwsSecret string `json:"awssecret"` + AwsSession string `json:"awssession"` + CredentialsExpiry time.Time `json:"credentialsExpiry"` +} + +func RestoreHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { + restoreRequest := RestoreBody{} + if lerr := readBodyAndUnmarshalJSON(r, &restoreRequest); lerr != nil { + lerr.Send(w, r) + return + } + + restore := &interop.Restore{ + AwsKey: restoreRequest.AwsKey, + AwsSecret: restoreRequest.AwsSecret, + AwsSession: restoreRequest.AwsSession, + CredentialsExpiry: restoreRequest.CredentialsExpiry, + } + + err := s.Restore(restore) + + if err != nil { + log.Errorf("Failed to restore: %s", err) + w.WriteHeader(http.StatusBadGateway) + } +} diff --git a/lambda/rapidcore/standalone/router.go b/lambda/rapidcore/standalone/router.go index 5a4ae7c..f1712ea 100644 --- a/lambda/rapidcore/standalone/router.go +++ b/lambda/rapidcore/standalone/router.go @@ -7,18 +7,35 @@ import ( "context" "net/http" + "go.amzn.com/lambda/core/statejson" + "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapidcore" "go.amzn.com/lambda/rapidcore/telemetry" "github.com/go-chi/chi" ) -func NewHTTPRouter(sandbox rapidcore.Sandbox, eventLog *telemetry.EventLog, shutdownFunc context.CancelFunc) *chi.Mux { - ipcSrv := sandbox.InteropServer() +type InteropServer interface { + Init(i *interop.Init, invokeTimeoutMs int64) error + AwaitInitialized() error + FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error + Reserve(id string, traceID, lambdaSegmentID string) (*rapidcore.ReserveResponse, error) + Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) + AwaitRelease() (*statejson.InternalStateDescription, error) + Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription + InternalState() (*statejson.InternalStateDescription, error) + CurrentToken() *interop.Token + Restore(restore *interop.Restore) error +} + +func NewHTTPRouter(ipcSrv InteropServer, lambdaInvokeAPI rapidcore.LambdaInvokeAPI, eventLog *telemetry.EventLog, shutdownFunc context.CancelFunc, bs interop.Bootstrap) *chi.Mux { r := chi.NewRouter() r.Use(standaloneAccessLogDecorator) - r.Post("/2015-03-31/functions/*/invocations", func(w http.ResponseWriter, r *http.Request) { Execute(w, r, sandbox) }) - r.Post("/test/init", func(w http.ResponseWriter, r *http.Request) { InitHandler(w, r, sandbox) }) + + r.Post("/2015-03-31/functions/*/invocations", func(w http.ResponseWriter, r *http.Request) { Execute(w, r, lambdaInvokeAPI) }) + r.Get("/test/ping", func(w http.ResponseWriter, r *http.Request) { PingHandler(w, r) }) + r.Post("/test/init", func(w http.ResponseWriter, r *http.Request) { InitHandler(w, r, ipcSrv, bs) }) + r.Post("/test/waitUntilInitialized", func(w http.ResponseWriter, r *http.Request) { WaitUntilInitializedHandler(w, r, ipcSrv) }) r.Post("/test/reserve", func(w http.ResponseWriter, r *http.Request) { ReserveHandler(w, r, ipcSrv) }) r.Post("/test/invoke", func(w http.ResponseWriter, r *http.Request) { InvokeHandler(w, r, ipcSrv) }) r.Post("/test/waitUntilRelease", func(w http.ResponseWriter, r *http.Request) { WaitUntilReleaseHandler(w, r, ipcSrv) }) @@ -27,6 +44,6 @@ func NewHTTPRouter(sandbox rapidcore.Sandbox, eventLog *telemetry.EventLog, shut r.Post("/test/directInvoke/{reservationtoken}", func(w http.ResponseWriter, r *http.Request) { DirectInvokeHandler(w, r, ipcSrv) }) r.Get("/test/internalState", func(w http.ResponseWriter, r *http.Request) { InternalStateHandler(w, r, ipcSrv) }) r.Get("/test/eventLog", func(w http.ResponseWriter, r *http.Request) { EventLogHandler(w, r, eventLog) }) - + r.Post("/test/restore", func(w http.ResponseWriter, r *http.Request) { RestoreHandler(w, r, ipcSrv) }) return r } diff --git a/lambda/rapidcore/standalone/shutdownHandler.go b/lambda/rapidcore/standalone/shutdownHandler.go index ee91277..8085541 100644 --- a/lambda/rapidcore/standalone/shutdownHandler.go +++ b/lambda/rapidcore/standalone/shutdownHandler.go @@ -9,14 +9,13 @@ import ( "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" - "go.amzn.com/lambda/rapidcore" ) type shutdownAPIRequest struct { TimeoutMs int64 `json:"timeoutMs"` } -func ShutdownHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer, shutdownFunc context.CancelFunc) { +func ShutdownHandler(w http.ResponseWriter, r *http.Request, s InteropServer, shutdownFunc context.CancelFunc) { shutdown := shutdownAPIRequest{} if lerr := readBodyAndUnmarshalJSON(r, &shutdown); lerr != nil { lerr.Send(w, r) @@ -24,8 +23,7 @@ func ShutdownHandler(w http.ResponseWriter, r *http.Request, s rapidcore.Interop } internalState := s.Shutdown(&interop.Shutdown{ - DeadlineNs: metering.Monotime() + int64(shutdown.TimeoutMs*1000*1000), - CorrelationID: "shutdownCorrelationID", + DeadlineNs: metering.Monotime() + int64(shutdown.TimeoutMs*1000*1000), }) w.Write(internalState.AsJSON()) diff --git a/lambda/rapidcore/standalone/util.go b/lambda/rapidcore/standalone/util.go index 21ee08f..7ba7420 100644 --- a/lambda/rapidcore/standalone/util.go +++ b/lambda/rapidcore/standalone/util.go @@ -6,7 +6,7 @@ package standalone import ( "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" log "github.com/sirupsen/logrus" @@ -58,7 +58,7 @@ func (w *ResponseWriterProxy) IsError() bool { } func readBodyAndUnmarshalJSON(r *http.Request, dst interface{}) *ErrorReply { - bodyBytes, err := ioutil.ReadAll(r.Body) + bodyBytes, err := io.ReadAll(r.Body) if err != nil { return newErrorReply(ClientInvalidRequest, fmt.Sprintf("Failed to read full body: %s", err)) } diff --git a/lambda/rapidcore/standalone/waitUntilInitializedHandler.go b/lambda/rapidcore/standalone/waitUntilInitializedHandler.go new file mode 100644 index 0000000..95d64ac --- /dev/null +++ b/lambda/rapidcore/standalone/waitUntilInitializedHandler.go @@ -0,0 +1,23 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package standalone + +import ( + "net/http" + + "go.amzn.com/lambda/rapidcore" +) + +func WaitUntilInitializedHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { + err := s.AwaitInitialized() + if err != nil { + switch err { + case rapidcore.ErrInitDoneFailed: + w.WriteHeader(DoneFailedHTTPCode) + case rapidcore.ErrInitResetReceived: + w.WriteHeader(DoneFailedHTTPCode) + } + } + w.WriteHeader(http.StatusOK) +} diff --git a/lambda/rapidcore/standalone/waitUntilReleaseHandler.go b/lambda/rapidcore/standalone/waitUntilReleaseHandler.go index 9aab644..0a756dd 100644 --- a/lambda/rapidcore/standalone/waitUntilReleaseHandler.go +++ b/lambda/rapidcore/standalone/waitUntilReleaseHandler.go @@ -9,7 +9,7 @@ import ( "go.amzn.com/lambda/rapidcore" ) -func WaitUntilReleaseHandler(w http.ResponseWriter, r *http.Request, s rapidcore.InteropServer) { +func WaitUntilReleaseHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { internalState, err := s.AwaitRelease() if err != nil { switch err { @@ -20,7 +20,7 @@ func WaitUntilReleaseHandler(w http.ResponseWriter, r *http.Request, s rapidcore // TODO use http.StatusOK w.WriteHeader(http.StatusGatewayTimeout) return - case rapidcore.ErrTerminated: + case rapidcore.ErrInitDoneFailed: w.WriteHeader(DoneFailedHTTPCode) w.Write(internalState.AsJSON()) return diff --git a/lambda/rapidcore/telemetry/eventLog.go b/lambda/rapidcore/telemetry/eventLog.go index c66672c..2f809fa 100644 --- a/lambda/rapidcore/telemetry/eventLog.go +++ b/lambda/rapidcore/telemetry/eventLog.go @@ -9,6 +9,8 @@ import ( "time" ) +// TODO: Refactor to represent event structs below as a form of Events API entity + type XrayEvent struct { Msg string `json:"msg"` TraceID string `json:"traceID"` @@ -32,20 +34,13 @@ type FunctionLogEvent struct{} type ExtensionLogEvent struct{} type EventLog struct { + Events []SandboxEvent `json:"events,omitempty"` // populated by the StandaloneEventLog object Xray []XrayEvent `json:"xray,omitempty"` PlatformLog []PlatformLogEvent `json:"platformLogs,omitempty"` Logs []string `json:"rawLogs,omitempty"` mutex sync.Mutex } -func (p *EventLog) LogXrayEvent(msg string, traceID string, segmentName string, segmentID string) { - p.Xray = append(p.Xray, XrayEvent{Msg: msg, TraceID: traceID, SegmentName: segmentName, SegmentID: segmentID, Timestamp: time.Now().UnixNano() / int64(time.Millisecond)}) -} - -func (p *EventLog) LogExtensionInitEvent(agentName string, state string, subscriptions string, errorType string) { - p.PlatformLog = append(p.PlatformLog, PlatformLogEvent{agentName, state, errorType, strings.Split(subscriptions, ",")}) -} - func parseLogString(s string) []string { elems := strings.Split(s, "\t")[1:] for i, e := range elems { @@ -62,19 +57,7 @@ func (p *EventLog) dispatchLogEvent(logStr string) { if strings.HasPrefix(logStr, "XRAY") { // format: 'XRAY\tMessage: %s\tTraceID: %s\tSegmentName: %s\tSegmentID: %s' msg, traceID, segmentName, segmentID := elems[0], elems[1], elems[2], elems[3] - p.LogXrayEvent(msg, traceID, segmentName, segmentID) - } - - if strings.HasPrefix(logStr, "EXTENSION") && strings.Contains(logStr, "Error Type") { - // format: 'EXTENSION\tName: %s\tState: %s\tEvents: [%s]\tError Type: %s' - agentName, state, subscriptions, errorType := elems[0], elems[1], elems[2], elems[3] - p.LogExtensionInitEvent(agentName, state, subscriptions, errorType) - } - - if strings.HasPrefix(logStr, "EXTENSION") && !strings.Contains(logStr, "Error Type") { - // format: 'EXTENSION\tName: %s\tState: %s\tEvents: [%s]' - agentName, state, subscriptions, errorType := elems[0], elems[1], elems[2], "" - p.LogExtensionInitEvent(agentName, state, subscriptions, errorType) + p.Xray = append(p.Xray, XrayEvent{Msg: msg, TraceID: traceID, SegmentName: segmentName, SegmentID: segmentID, Timestamp: time.Now().UnixNano() / int64(time.Millisecond)}) } } diff --git a/lambda/rapidcore/telemetry/events_api.go b/lambda/rapidcore/telemetry/events_api.go new file mode 100644 index 0000000..7a882fd --- /dev/null +++ b/lambda/rapidcore/telemetry/events_api.go @@ -0,0 +1,97 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "sort" + "time" + + "go.amzn.com/lambda/telemetry" +) + +// EventType indicates the type of SandboxEvent. See full list: +type EventType = string + +const ( + PlatformInitRuntimeDone = EventType("platform.initRuntimeDone") + PlatformRestoreRuntimeDone = EventType("platform.restoreRuntimeDone") + PlatformRuntimeDone = EventType("platform.runtimeDone") + PlatformExtension = EventType("platform.extension") +) + +/* + SandboxEvent represents a generic sandbox event. For example: + {'time': '2021-03-16T13:10:42.358Z', + 'type': 'platform.extension', + 'record': { "name": "foo bar", "state": "Ready", "events": ["INVOKE", "SHUTDOWN"]}} +*/ +type SandboxEvent struct { + Time string `json:"time"` + Type EventType `json:"type"` + Record map[string]interface{} `json:"record"` +} + +type StandaloneEventLog struct { + requestID string + eventLog *EventLog +} + +func (s *StandaloneEventLog) SetCurrentRequestID(requestID string) { + s.requestID = requestID +} + +func (s *StandaloneEventLog) SendInitRuntimeDone(data *telemetry.InitRuntimeDoneData) error { + record := map[string]interface{}{"initializationType": data.InitSource, "status": data.Status} + s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformInitRuntimeDone, record}) + return nil +} + +func (s *StandaloneEventLog) SendRestoreRuntimeDone(status string) error { + record := map[string]interface{}{"status": status} + s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformRestoreRuntimeDone, record}) + return nil +} + +func (s *StandaloneEventLog) SendRuntimeDone(data telemetry.InvokeRuntimeDoneData) error { + // e.g. 'record': {'requestId': '1506eb3053d148f3bb7ec0fabe6f8d91','status': 'success', 'metrics': {...}, 'tracing': {...}} + record := map[string]interface{}{ + "requestId": s.requestID, + "status": data.Status, + "metrics": data.Metrics, + "internalMetrics": data.InternalMetrics, + "spans": data.Spans, + } + + if data.Tracing != nil { + record["tracing"] = map[string]string{ + "spanId": data.Tracing.SpanID, + "type": string(data.Tracing.Type), + "value": data.Tracing.Value, + } + } + + s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformRuntimeDone, record}) + return nil +} + +func (s *StandaloneEventLog) SendExtensionInit(agentName, state, errorType string, subscriptions []string) error { + // e.g. 'record': { "name": "", "state": "", errorType: "", events: [""] } + sort.Strings(subscriptions) + record := map[string]interface{}{"name": agentName, "state": state, "events": subscriptions} + if len(errorType) > 0 { + record["errorType"] = errorType + } + s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformExtension, record}) + return nil +} + +func (s *StandaloneEventLog) SendImageErrorLog(logline string) { + // Called on bootstrap exec errors for OCI error modes, e.g. InvalidEntrypoint etc. +} + +func NewStandaloneEventLog(eventLog *EventLog) *StandaloneEventLog { + return &StandaloneEventLog{ + eventLog: eventLog, + } +} diff --git a/lambda/runtimecmd/runtime_command.go b/lambda/runtimecmd/runtime_command.go deleted file mode 100644 index adf7886..0000000 --- a/lambda/runtimecmd/runtime_command.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package runtimecmd - -import ( - "context" - "fmt" - "io" - "os" - "os/exec" - "path" - "syscall" -) - -// CustomRuntimeCmd wraps exec.Cmd -type CustomRuntimeCmd struct { - *exec.Cmd -} - -// NewCustomRuntimeCmd returns a new CustomRuntimeCmd -func NewCustomRuntimeCmd(ctx context.Context, bootstrapCmd []string, dir string, env []string, stdoutWriter io.Writer, stderrWriter io.Writer, extraFiles []*os.File) *CustomRuntimeCmd { - cmd := exec.CommandContext(ctx, bootstrapCmd[0], bootstrapCmd[1:]...) - cmd.Dir = dir - - cmd.Stdout = stdoutWriter - cmd.Stderr = stderrWriter - - cmd.Env = env - - cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} - - if len(extraFiles) > 0 { - cmd.ExtraFiles = extraFiles - } - - return &CustomRuntimeCmd{cmd} -} - -// Name returnes runtime executable name -func (cmd *CustomRuntimeCmd) Name() string { - return path.Base(cmd.Path) -} - -// Pid returns the pid of a started runtime process -func (cmd *CustomRuntimeCmd) Pid() int { - return cmd.Process.Pid -} - -// Wait waits for the started customer runtime process to exit -func (cmd *CustomRuntimeCmd) Wait() error { - if err := cmd.Cmd.Wait(); err != nil { - return fmt.Errorf("Runtime exited with error: %v", err) - } - - return fmt.Errorf("Runtime exited without providing a reason") -} diff --git a/lambda/runtimecmd/runtime_command_test.go b/lambda/runtimecmd/runtime_command_test.go deleted file mode 100644 index f99599d..0000000 --- a/lambda/runtimecmd/runtime_command_test.go +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package runtimecmd - -import ( - "context" - "errors" - "io/ioutil" - "os" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestRuntimeCommandSetsEnvironmentVariables(t *testing.T) { - envVars := []string{"foo=1", "bar=2", "baz=3"} - - currentDir, err := os.Getwd() - assert.NoError(t, err, errors.New("Failed to get working directory to execute helper process")) - - execCmdArgs := []string{"foobar"} - runtimeCmd := NewCustomRuntimeCmd(context.Background(), execCmdArgs, currentDir, envVars, ioutil.Discard, ioutil.Discard, nil) - - assert.ElementsMatch(t, envVars, runtimeCmd.Env) - assert.Equal(t, execCmdArgs, runtimeCmd.Args) -} - -func TestRuntimeCommandSetsCurrentWorkingDir(t *testing.T) { - envVars := []string{} - - currentDir, err := os.Getwd() - assert.NoError(t, err, errors.New("Failed to get working directory to execute helper process")) - - execCmdArgs := []string{"foobar"} - runtimeCmd := NewCustomRuntimeCmd(context.Background(), execCmdArgs, currentDir, envVars, ioutil.Discard, ioutil.Discard, nil) - - assert.Equal(t, currentDir, runtimeCmd.Dir) -} - -func TestRuntimeCommandSetsMultipleArgs(t *testing.T) { - envVars := []string{} - - currentDir, err := os.Getwd() - assert.NoError(t, err, errors.New("Failed to get working directory to execute helper process")) - - execCmdArgs := []string{"foobar", "--baz", "22"} - runtimeCmd := NewCustomRuntimeCmd(context.Background(), execCmdArgs, currentDir, envVars, ioutil.Discard, ioutil.Discard, nil) - - assert.Equal(t, execCmdArgs, runtimeCmd.Args) -} diff --git a/lambda/supervisor/local_supervisor.go b/lambda/supervisor/local_supervisor.go new file mode 100644 index 0000000..1174089 --- /dev/null +++ b/lambda/supervisor/local_supervisor.go @@ -0,0 +1,302 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package supervisor + +import ( + "errors" + "fmt" + "os/exec" + "sync" + "syscall" + "time" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/supervisor/model" +) + +// typecheck interface compliance +var _ model.SupervisorClient = (*LocalSupervisor)(nil) + +type process struct { + // pid of the running process + pid int + // channel that can be use to block + // while waiting on process termination. + termination chan struct{} +} + +type LocalSupervisor struct { + events chan model.Event + processMapLock sync.Mutex + processMap map[string]process +} + +func NewLocalSupervisor() model.Supervisor { + return model.Supervisor{ + SupervisorClient: &LocalSupervisor{ + events: make(chan model.Event), + processMap: make(map[string]process), + }, + OperatorConfig: model.DomainConfig{ + RootPath: "/", + }, + RuntimeConfig: model.DomainConfig{ + RootPath: "/", + }, + } +} + +func (*LocalSupervisor) Start(req *model.StartRequest) error { + return nil +} +func (*LocalSupervisor) Configure(req *model.ConfigureRequest) error { + return nil +} +func (s *LocalSupervisor) Exec(req *model.ExecRequest) error { + if req.Domain != "runtime" { + log.Debug("Exec is a no op if domain != runtime") + return nil + } + command := exec.Command(req.Path, req.Args...) + + if req.Env != nil { + envStrings := make([]string, 0, len(*req.Env)) + for key, value := range *req.Env { + envStrings = append(envStrings, key+"="+value) + } + command.Env = envStrings + } + + if req.Cwd != nil && *req.Cwd != "" { + command.Dir = *req.Cwd + } + + if req.ExtraFiles != nil { + command.ExtraFiles = *req.ExtraFiles + } + + command.Stdout = req.StdoutWriter + command.Stderr = req.StderrWriter + + command.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + + err := command.Start() + + if err != nil { + return err + // TODO Use supevisor specific error + } + + pid := command.Process.Pid + termination := make(chan struct{}) + s.processMapLock.Lock() + s.processMap[req.Name] = process{ + pid: pid, + termination: termination, + } + s.processMapLock.Unlock() + + go func() { + err = command.Wait() + // close the termination channel to unblock whoever's blocked on + // it (used to implement kill's blocking behaviour) + close(termination) + + var cell int32 + var exitStatus *int32 + var signo *int32 + var exitErr *exec.ExitError + + if err == nil { + exitStatus = &cell + } else if errors.As(err, &exitErr) { + if status, ok := exitErr.Sys().(syscall.WaitStatus); ok { + if code := status.ExitStatus(); code >= 0 { + cell = int32(code) + exitStatus = &cell + } else { + cell = int32(status.Signal()) + signo = &cell + } + } + } + + if signo == nil && exitStatus == nil { + log.Error("Cannot convert process exit status to unix WaitStatus. This is unexpected. Assuming ExitStatus 1") + cell = 1 + exitStatus = &cell + } + s.events <- model.Event{ + Time: uint64(time.Now().UnixMilli()), + Event: model.EventData{ + Domain: &req.Domain, + Name: &req.Name, + Signo: signo, + ExitStatus: exitStatus, + }, + } + }() + + return nil +} + +func kill(p process, name string, timeout *time.Duration) error { + // kill should report success if the process terminated by the time + //supervisor receives the request. + select { + // ifthis case is selected, the channel is closed, + // which means the process is terminated + case <-p.termination: + log.Debugf("Process %s already terminated.", name) + return nil + default: + log.Infof("Sending SIGKILL to %s(%d).", name, p.pid) + } + + if timeout != nil && *timeout <= 0 { + return fmt.Errorf("Timed out while trying to SIGKILL %s", name) + } + + pgid, err := syscall.Getpgid(p.pid) + + if err == nil { + // Negative pid sends signal to all in process group + syscall.Kill(-pgid, syscall.SIGKILL) + } else { + syscall.Kill(p.pid, syscall.SIGKILL) + } + + // the nil channel blocks forever + var timer <-chan time.Time + if timeout != nil { + timer = time.After(*timeout) + } + + // block until the (main) process exits + // or the timeout fires + select { + case <-p.termination: + return nil + case <-timer: + return fmt.Errorf("Timed out while trying to SIGKILL %s", name) + } +} + +func (s *LocalSupervisor) Kill(req *model.KillRequest) error { + if req.Domain != "runtime" { + log.Debug("Kill is a no op if domain != runtime") + return nil + } + s.processMapLock.Lock() + process, ok := s.processMap[req.Name] + s.processMapLock.Unlock() + if !ok { + msg := "Unknown process" + return &model.SupervisorError{ + Kind: model.NoSuchEntity, + Message: &msg, + } + } + timeout := convertTimeout(req.Timeout) + + return kill(process, req.Name, timeout) +} + +func (s *LocalSupervisor) Terminate(req *model.TerminateRequest) error { + if req.Domain != "runtime" { + log.Debug("Terminate is no op if domain != runtime") + return nil + } + s.processMapLock.Lock() + process, ok := s.processMap[req.Name] + pid := process.pid + s.processMapLock.Unlock() + if !ok { + msg := "Unknown process" + err := &model.SupervisorError{ + Kind: model.NoSuchEntity, + Message: &msg, + } + log.WithError(err).Errorf("Process %s not found in local supervisor map", req.Name) + return err + } + + pgid, err := syscall.Getpgid(pid) + + if err == nil { + // Negative pid sends signal to all in process group + // best effort, ignore errors + _ = syscall.Kill(-pgid, syscall.SIGTERM) + } else { + _ = syscall.Kill(pid, syscall.SIGTERM) + } + + return nil +} + +func (s *LocalSupervisor) Stop(req *model.StopRequest) error { + if req.Domain != "runtime" { + log.Debug("Shutdown is no op if domain != runtime") + return nil + } + timeout := convertTimeout(req.Timeout) + + // shut down kills all the processes in the map + s.processMapLock.Lock() + defer s.processMapLock.Unlock() + + nprocs := len(s.processMap) + + successes := make(chan struct{}) + errors := make(chan error) + for name, proc := range s.processMap { + go func(n string, p process) { + log.Debugf("Killing %s", n) + err := kill(p, n, timeout) + if err != nil { + errors <- err + } else { + successes <- struct{}{} + } + + }(name, proc) + } + + var err error + for i := 0; i < nprocs; i++ { + select { + case <-successes: + case e := <-errors: + if err == nil { + err = fmt.Errorf("Shutdown failed: %s", e.Error()) + } + } + + } + + s.processMap = make(map[string]process) + return err +} +func (*LocalSupervisor) Freeze(req *model.FreezeRequest) error { + return nil +} +func (*LocalSupervisor) Thaw(req *model.ThawRequest) error { + return nil +} +func (s *LocalSupervisor) Ping() error { + return nil +} + +func (s *LocalSupervisor) Events() (<-chan model.Event, error) { + return s.events, nil +} + +func convertTimeout(millis *uint64) *time.Duration { + var timeout *time.Duration + if millis != nil { + t := time.Duration(*millis) * time.Millisecond + timeout = &t + } + return timeout +} diff --git a/lambda/supervisor/local_supervisor_test.go b/lambda/supervisor/local_supervisor_test.go new file mode 100644 index 0000000..8b3336b --- /dev/null +++ b/lambda/supervisor/local_supervisor_test.go @@ -0,0 +1,215 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package supervisor + +import ( + "errors" + "fmt" + "syscall" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.amzn.com/lambda/supervisor/model" +) + +func TestRuntimeDomainExec(t *testing.T) { + supv := NewLocalSupervisor() + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + }) + + assert.Nil(t, err) +} + +func TestInvalidRuntimeDomainExec(t *testing.T) { + supv := NewLocalSupervisor() + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/none", + }) + + require.Error(t, err) +} + +func TestEvents(t *testing.T) { + supv := NewLocalSupervisor() + client := supv.SupervisorClient.(*LocalSupervisor) + sync := make(chan struct{}) + go func() { + evt, ok := <-client.events + require.True(t, ok) + termination := evt.Event.ProcessTerminated() + require.NotNil(t, termination) + assert.Equal(t, "runtime", *termination.Domain) + assert.Equal(t, "agent", *termination.Name) + sync <- struct{}{} + }() + + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + }) + require.NoError(t, err) + <-sync +} + +func TestTerminate(t *testing.T) { + supv := NewLocalSupervisor() + client := supv.SupervisorClient.(*LocalSupervisor) + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + Args: []string{"-c", "sleep 10s"}, + }) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + err = supv.Terminate(&model.TerminateRequest{ + Domain: "runtime", + Name: "agent", + }) + require.NoError(t, err) + // wait for process exit notification + ev := <-client.events + require.NotNil(t, ev.Event.ProcessTerminated()) + term := *ev.Event.ProcessTerminated() + require.Nil(t, term.Exited()) + require.NotNil(t, term.Signaled()) + require.EqualValues(t, syscall.SIGTERM, *term.Signo) +} + +// Termiante should not fail if the message is not delivered +func TestTerminateExited(t *testing.T) { + supv := NewLocalSupervisor() + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + }) + require.NoError(t, err) + // wait a short bit for bash to exit + time.Sleep(100 * time.Millisecond) + err = supv.Terminate(&model.TerminateRequest{ + Domain: "runtime", + Name: "agent", + }) + require.NoError(t, err) +} + +func TestKill(t *testing.T) { + supv := NewLocalSupervisor() + client := supv.SupervisorClient.(*LocalSupervisor) + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + Args: []string{"-c", "sleep 10s"}, + }) + require.NoError(t, err) + err = supv.Kill(&model.KillRequest{ + Domain: "runtime", + Name: "agent", + }) + require.NoError(t, err) + timer := time.NewTimer(50 * time.Millisecond) + select { + case _, ok := <-client.events: + assert.True(t, ok) + case <-timer.C: + require.Fail(t, "Process should have exited by the time kill returns") + } +} + +func TestKillExited(t *testing.T) { + supv := NewLocalSupervisor() + client := supv.SupervisorClient.(*LocalSupervisor) + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent", + Path: "/bin/bash", + }) + require.NoError(t, err) + //wait for natural exit event + <-client.events + err = supv.Kill(&model.KillRequest{ + Domain: "runtime", + Name: "agent", + }) + require.NoError(t, err, "Kill should succeed for exited processes") +} + +func TestKillUnknown(t *testing.T) { + supv := NewLocalSupervisor() + err := supv.Kill(&model.KillRequest{ + Domain: "runtime", + Name: "unknown", + }) + require.Error(t, err) + var supvError *model.SupervisorError + assert.True(t, errors.As(err, &supvError)) + assert.Equal(t, supvError.Kind, model.NoSuchEntity) +} + +func TestShutdown(t *testing.T) { + supv := NewLocalSupervisor() + client := supv.SupervisorClient.(*LocalSupervisor) + log.Debug("hello") + // start a bunch of processes, some short running, some longer running + err := supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent-0", + Path: "/bin/bash", + Args: []string{"-c", "sleep 1s"}, + }) + require.NoError(t, err) + + err = supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent-1", + Path: "/bin/bash", + }) + require.NoError(t, err) + + err = supv.Exec(&model.ExecRequest{ + Domain: "runtime", + Name: "agent-2", + Path: "/bin/bash", + Args: []string{"-c", "sleep 2s"}, + }) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + err = supv.Stop(&model.StopRequest{ + Domain: "runtime", + }) + require.NoError(t, err) + // Shutdown is expected to block untill all processes have exited + expected := map[string]struct{}{ + "agent-0": {}, + "agent-1": {}, + "agent-2": {}, + } + done := false + timer := time.NewTimer(200 * time.Millisecond) + for !done { + select { + case ev := <-client.events: + data := ev.Event.ProcessTerminated() + assert.NotNil(t, data) + _, ok := expected[*data.Name] + assert.True(t, ok) + delete(expected, *data.Name) + case <-timer.C: + fmt.Print(expected) + assert.Equal(t, 0, len(expected), "All process should terminate at shutdown") + done = true + } + } +} diff --git a/lambda/supervisor/model/model.go b/lambda/supervisor/model/model.go new file mode 100644 index 0000000..384726d --- /dev/null +++ b/lambda/supervisor/model/model.go @@ -0,0 +1,269 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "encoding/json" + "fmt" + "io" + "os" + "syscall" +) + +type Supervisor struct { + SupervisorClient + OperatorConfig DomainConfig + RuntimeConfig DomainConfig +} + +type DomainConfig struct { + // path to the root of the domain within the root mnt namespace + RootPath string +} + +type SupervisorClient interface { + Start(req *StartRequest) error + Configure(req *ConfigureRequest) error + Exec(req *ExecRequest) error + Terminate(req *TerminateRequest) error + Kill(req *KillRequest) error + Stop(req *StopRequest) error + Freeze(req *FreezeRequest) error + Thaw(req *ThawRequest) error + Ping() error + Events() (<-chan Event, error) +} + +type StartRequest struct { + Domain string `json:"domain"` + // name of the cgroup profile to start the domain in + CgroupProfile *string `json:"cgroup_profile,omitempty"` +} + +// Mount in lockhard::mnt is a Rust enum, an algebraic type, where each case has different set of fields. +// This models only the Mount::Drive case, the only one we need for now. +type DriveMount struct { + Source string `json:"source,omitempty"` + Destination string `json:"destination,omitempty"` + FsType string `json:"fs_type,omitempty"` + Options []string `json:"options,omitempty"` + Chowner []uint32 `json:"chowner,omitempty"` // array of two integers representing a tuple + Chmode uint32 `json:"chmode,omitempty"` + // Lockhard also expects a "type" field here, which in our case is constant, so we provide it upon serialization below +} + +// Adds the "type": "drive" to json +func (m *DriveMount) MarshalJSON() ([]byte, error) { + type driveMountAlias DriveMount + + return json.Marshal(&struct { + Type string `json:"type,omitempty"` + *driveMountAlias + }{ + Type: "drive", + driveMountAlias: (*driveMountAlias)(m), + }) +} + +type Capabilities struct { + Ambient []string `json:"ambient,omitempty"` + Bounding []string `json:"bounding,omitempty"` + Effective []string `json:"effective,omitempty"` + Inheritable []string `json:"inheritable,omitempty"` + Permitted []string `json:"permitted,omitempty"` +} + +type CgroupProfile struct { + Name string `json:"name"` + CPUPct *float64 `json:"cpu_pct,omitempty"` + MemMaxBytes *uint64 `json:"mem_max,omitempty"` +} + +type ExecUser struct { + UID *uint32 `json:"uid"` + GID *uint32 `json:"gid"` +} + +type ConfigureRequest struct { + // domain to configure + Domain string `json:"domain"` + Mounts []DriveMount `json:"mounts,omitempty"` + Capabilities *Capabilities `json:"capabilities,omitempty"` + SeccompFilters []string `json:"seccomp_filters,omitempty"` + // list of cgroup profiles available for the domain + // cgroup profiles are set on boot or thaw requests + CgroupProfiles []CgroupProfile `json:"cgroup_profiles,omitempty"` + // uid and gid of the user the spawned process runs as (w.r.t. the domain user namespace). + // If nil, Supervisor will use the ExecUser specified in the domain configuration file + ExecUser *ExecUser `json:"exec_user,omitempty"` + // additional hooks to execute on domain start + AdditionalStartHooks []Hook `json:"additional_start_hooks,omitempty"` +} + +type Event struct { + Time uint64 `json:"timestamp_millis"` + Event EventData `json:"event"` +} + +// EventData is a union type tagged by the "EventType" +// and "Cause" strings. +// you can use ProcessTermination() or EventLoss() to access +// the correct type of Event. +type EventData struct { + EvType string `json:"type"` + Domain *string `json:"domain"` + Name *string `json:"name"` + Cause *string `json:"cause"` + Signo *int32 `json:"signo"` + ExitStatus *int32 `json:"exit_status"` + Size *uint64 `json:"size"` +} + +// returns nil if the event is not a EventLoss event +// otherwise returns how many events were lost due to +// backpressure (slow reader) +func (d EventData) EventLoss() *uint64 { + return d.Size +} + +// Returns a ProcessTermination struct that describe the process +// which terminated. Use Signaled() or Exited() to check whether +// the process terminated because of a signal or exited on its own +func (d EventData) ProcessTerminated() *ProcessTermination { + if d.Signo != nil || d.ExitStatus != nil { + return &ProcessTermination{ + Domain: d.Domain, + Name: d.Name, + Signo: d.Signo, + ExitStatus: d.ExitStatus, + } + } + return nil +} + +// Event signalling that a process exited +type ProcessTermination struct { + Domain *string + Name *string + Signo *int32 + ExitStatus *int32 +} + +// If not nil, the process was terminated by an unhandled signal. +// The returned value is the number of the signal that terminated the process +func (t ProcessTermination) Signaled() *int32 { + return t.Signo +} + +// It not nil, the process exited (as opposed to killed by a signal). +// The returned value is the exit_status returned by the process +func (t ProcessTermination) Exited() *int32 { + return t.ExitStatus +} + +func (t ProcessTermination) Success() bool { + return t.ExitStatus != nil && *t.ExitStatus == 0 +} + +// Transform the process termination status in a string that +// is equal to what would be returned by golang exec.ExitError.Error() +// We used to rely on this format to report errors to customer (sigh) +// so we keep this for backwards compatibility +func (t ProcessTermination) String() string { + if t.ExitStatus != nil { + return fmt.Sprintf("exit status %d", *t.ExitStatus) + } + sig := syscall.Signal(*t.Signo) + return fmt.Sprintf("signal: %s", sig.String()) +} + +type Hook struct { + // Unique name identifying the hook + Name string `json:"name"` + // Path in the parent domain mount namespace that locates + // the executable to run as the hook + Path string `json:"path"` + // Args for the hook + Args []string `json:"args,omitempty"` + // Map of ENV variables to set when running the hook + Env *map[string]string `json:"envs,omitempty"` + // Maximum time for the hook to run. The hook will be considered failed + // if it takes more than this value (default 10_000) + TimeoutMillis *uint64 `json:"timeout_millis,omitempty"` +} + +type ExecRequest struct { + // Identifier that Supervisor will assign to the spawned process. + // The tuple (Domain,Name) must be unique. It is the caller's responsibility + // to generate the unique name + Name string `json:"name"` + Domain string `json:"domain"` + // Path pointing to the exectuable file within the domain's root filesystem + Path string `json:"path"` + Args []string `json:"args,omitempty"` + // If nil, root of the domain + Cwd *string `json:"cwd,omitempty"` + Env *map[string]string `json:"env,omitempty"` + // If not nil, points to the socket that Supervisor + // uses to get the processes stdout and stderr. + LogsSock *string `json:"logs_sock,omitempty"` + StdoutWriter io.Writer `json:"-"` + StderrWriter io.Writer `json:"-"` + ExtraFiles *[]*os.File `json:"-"` +} + +type ErrorKind string + +const ( + // operation on an unkown entity (e.g., domain process) + NoSuchEntity ErrorKind = "no_such_entity" + // operation not allowed in the current state (e.g., tried to exec a proces in a domain which is not booted) + InvalidState ErrorKind = "invalid_state" + // Serialization or derserialization issue in the communication + Serde ErrorKind = "serde" + // Unhandled Supervisor server error + Failure ErrorKind = "failure" +) + +type SupervisorError struct { + Kind ErrorKind `json:"error_kind"` + Message *string `json:"message"` +} + +func (e *SupervisorError) Error() string { + return string(e.Kind) +} + +// Send SIGETERM asynchrnously to a process +type TerminateRequest struct { + Name string `json:"name"` + Domain string `json:"domain"` +} + +// Force terminate a process (SIGKILL) +// Block until process is exited or timeout +// If timeout is 0 or nil, block forever +type KillRequest struct { + Name string `json:"name"` + Domain string `json:"domain"` + Timeout *uint64 `json:",omitempty"` +} + +// Stop the domain. Supervisor will first try to +// cleanly terminate the domain's init process. If unsuccessful, +// within Timeout seconds, it will send SIGKILL. +type StopRequest struct { + Domain string `json:"domain"` + Timeout *uint64 `json:",omitempty"` +} + +type FreezeRequest struct { + Domain string `json:"domain"` +} + +type ThawRequest struct { + Domain string `json:"domain"` + // if not nil, changes the cgroup profile of the domain upon thawing. + CgroupProfile *string `json:"cgroup_profile,omitempty"` +} diff --git a/lambda/telemetry/events_api.go b/lambda/telemetry/events_api.go index 132977e..e7c5c36 100644 --- a/lambda/telemetry/events_api.go +++ b/lambda/telemetry/events_api.go @@ -3,12 +3,136 @@ package telemetry +import ( + "time" + + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" + "go.amzn.com/lambda/rapi/model" +) + +type RuntimeDoneInvokeMetrics struct { + ProducedBytes int64 + DurationMs float64 +} + +func GetRuntimeDoneInvokeMetrics(invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics, runtimeDoneTime int64) *RuntimeDoneInvokeMetrics { + if invokeResponseMetrics != nil && invokeResponseMetrics.RuntimeCalledResponse && invokeReceivedTime != 0 { + return &RuntimeDoneInvokeMetrics{ + ProducedBytes: invokeResponseMetrics.ProducedBytes, + // time taken from sending the invoke to the sandbox until the runtime calls GET /next + DurationMs: float64((runtimeDoneTime - invokeReceivedTime) / int64(time.Millisecond)), + } + } + + // when we get a reset before runtime called /response + if invokeReceivedTime != 0 { + return &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64((runtimeDoneTime - invokeReceivedTime) / int64(time.Millisecond)), + } + } + + // We didn't have time to register the invokeReceiveTime, which means we crash/reset very early, + // too early for the runtime to actual run. In such case, the runtimeDone event shouldn't be sent + // Not returning Nil even in this improbable case guarantees that we will always have some metrics to send to FluxPump + return &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64(0), + } +} + +type InitRuntimeDoneData struct { + InitSource string + Status string +} + +type InvokeRuntimeDoneData struct { + Status string + Metrics *RuntimeDoneInvokeMetrics + InternalMetrics *interop.InvokeResponseMetrics + Tracing *TracingCtx + Spans []Span +} + +type Span struct { + Name string + Start string + DurationMs float64 +} + +func GetRuntimeDoneSpans(invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) []Span { + if invokeResponseMetrics != nil && invokeResponseMetrics.RuntimeCalledResponse && invokeReceivedTime != 0 { + // time span from when the invoke is received in the sandbox to the moment the runtime calls PUT /response + responseLatencyMsSpan := Span{ + Name: "responseLatency", + Start: getEpochTimeInISO8601FormatFromMonotime(invokeReceivedTime), + DurationMs: float64((invokeResponseMetrics.StartReadingResponseMonoTimeMs - invokeReceivedTime) / int64(time.Millisecond)), + } + + // time span from when the runtime called PUT /response to the moment the body of the response is fully sent + responseDurationMsSpan := Span{ + Name: "responseDuration", + Start: getEpochTimeInISO8601FormatFromMonotime(invokeResponseMetrics.StartReadingResponseMonoTimeMs), + DurationMs: float64((invokeResponseMetrics.FinishReadingResponseMonoTimeMs - invokeResponseMetrics.StartReadingResponseMonoTimeMs) / int64(time.Millisecond)), + } + return []Span{responseLatencyMsSpan, responseDurationMsSpan} + } + + return []Span{} +} + +func getEpochTimeInISO8601FormatFromMonotime(monotime int64) string { + return time.Unix(0, metering.MonoToEpoch(monotime)).Format("2006-01-02T15:04:05.000Z") +} + +type TracingCtx struct { + SpanID string + Type model.TracingType + Value string +} + +func BuildTracingCtx(tracingType model.TracingType, traceID string, lambdaSegmentID string) *TracingCtx { + // it takes current tracing context and change its parent value with the provided lambda segment id + root, currentParent, sample := ParseTraceID(traceID) + if root == "" || sample != model.XRaySampled { + return nil + } + + return &TracingCtx{ + SpanID: currentParent, + Type: tracingType, + Value: BuildFullTraceID(root, lambdaSegmentID, sample), + } +} + +const ( + RuntimeDoneSuccess = "success" + RuntimeDoneFailure = "failure" +) + type EventsAPI interface { SetCurrentRequestID(requestID string) - SendRuntimeDone(status string) error + SendInitRuntimeDone(data *InitRuntimeDoneData) error + SendRestoreRuntimeDone(status string) error + SendRuntimeDone(data InvokeRuntimeDoneData) error + SendExtensionInit(agentName, state, errorType string, subscriptions []string) error + SendImageErrorLog(logline string) } type NoOpEventsAPI struct{} func (s *NoOpEventsAPI) SetCurrentRequestID(requestID string) {} -func (s *NoOpEventsAPI) SendRuntimeDone(status string) error { return nil } +func (s *NoOpEventsAPI) SendInitRuntimeDone(data *InitRuntimeDoneData) error { + return nil +} +func (s *NoOpEventsAPI) SendRestoreRuntimeDone(status string) error { + return nil +} +func (s *NoOpEventsAPI) SendRuntimeDone(data InvokeRuntimeDoneData) error { + return nil +} +func (s *NoOpEventsAPI) SendExtensionInit(agentName, state, errorType string, subscriptions []string) error { + return nil +} +func (s *NoOpEventsAPI) SendImageErrorLog(logline string) {} diff --git a/lambda/telemetry/events_api_test.go b/lambda/telemetry/events_api_test.go new file mode 100644 index 0000000..b943be9 --- /dev/null +++ b/lambda/telemetry/events_api_test.go @@ -0,0 +1,139 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" +) + +func TestGetRuntimeDoneInvokeMetrics(t *testing.T) { + now := metering.Monotime() + + invokeReceivedTime := now + invokeResponseMetrics := &interop.InvokeResponseMetrics{ + ProducedBytes: int64(100), + RuntimeCalledResponse: true, + } + runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) + + expected := &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(100), + DurationMs: float64(10), + } + + assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(invokeReceivedTime, invokeResponseMetrics, runtimeDoneTime)) +} + +func TestGetRuntimeDoneInvokeMetricsWhenRuntimeCalledError(t *testing.T) { + now := metering.Monotime() + + invokeReceivedTime := now + invokeResponseMetrics := &interop.InvokeResponseMetrics{ + ProducedBytes: int64(100), + RuntimeCalledResponse: false, + } + runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) + + expected := &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64(10), + } + + assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(invokeReceivedTime, invokeResponseMetrics, runtimeDoneTime)) +} + +func TestGetRuntimeDoneInvokeMetricsWhenInvokeReceivedTimeIsZero(t *testing.T) { + now := int64(0) // January 1st, 1970 at 00:00:00 UTC + invokeReceivedTime := now + + runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) + + expected := &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64(0), + } + actual := GetRuntimeDoneInvokeMetrics(invokeReceivedTime, nil, runtimeDoneTime) + assert.Equal(t, expected, actual) +} + +func TestGetRuntimeDoneInvokeMetricsWhenInvokeResponseMetricsIsNil(t *testing.T) { + now := metering.Monotime() + invokeReceivedTime := now + + runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) + + expected := &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64(10), + } + + assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(invokeReceivedTime, nil, runtimeDoneTime)) +} + +func TestGetRuntimeDoneSpans(t *testing.T) { + now := metering.Monotime() + startReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(5)) + finishReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(7)) + + invokeReceivedTime := now + invokeResponseMetrics := &interop.InvokeResponseMetrics{ + StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, + FinishReadingResponseMonoTimeMs: finishReadingResponseMonoTimeMs, + RuntimeCalledResponse: true, + } + + expectedResponseLatencyMsStartTime := getEpochTimeInISO8601FormatFromMonotime(now) + expectedResponseDurationMsStartTime := getEpochTimeInISO8601FormatFromMonotime(startReadingResponseMonoTimeMs) + expected := []Span{ + Span{ + Name: "responseLatency", + Start: expectedResponseLatencyMsStartTime, + DurationMs: 5, + }, + Span{ + Name: "responseDuration", + Start: expectedResponseDurationMsStartTime, + DurationMs: 2, + }, + } + + assert.Equal(t, expected, GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics)) +} + +func TestGetRuntimeDoneSpansWhenRuntimeCalledError(t *testing.T) { + now := metering.Monotime() + startReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(5)) + finishReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(7)) + + invokeReceivedTime := now + invokeResponseMetrics := &interop.InvokeResponseMetrics{ + StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, + FinishReadingResponseMonoTimeMs: finishReadingResponseMonoTimeMs, + RuntimeCalledResponse: false, + } + + assert.Equal(t, []Span{}, GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics)) +} + +func TestGetRuntimeDoneSpansWhenInvokeResponseMetricsNil(t *testing.T) { + invokeReceivedTime := metering.Monotime() + + assert.Equal(t, []Span{}, GetRuntimeDoneSpans(invokeReceivedTime, nil)) +} + +func TestGetRuntimeDoneSpansWhenInvokeReceivedTimeIsZero(t *testing.T) { + now := int64(0) // January 1st, 1970 at 00:00:00 UTC + invokeReceivedTime := now + invokeResponseMetrics := &interop.InvokeResponseMetrics{ + StartReadingResponseMonoTimeMs: now + int64(time.Millisecond*time.Duration(5)), + FinishReadingResponseMonoTimeMs: now + int64(time.Millisecond*time.Duration(7)), + } + + assert.Equal(t, []Span{}, GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics)) +} diff --git a/lambda/telemetry/logs_egress_api.go b/lambda/telemetry/logs_egress_api.go index ac9a754..7e84fe2 100644 --- a/lambda/telemetry/logs_egress_api.go +++ b/lambda/telemetry/logs_egress_api.go @@ -8,7 +8,12 @@ import ( "os" ) -type LogsEgressAPI interface { +// StdLogsEgressAPI is the interface that wraps the basic methods required to setup +// logs channels for Runtime's stdout/stderr and Extension's stdout/stderr. +// +// Implementation should return a Writer implementor for stdout and another for +// stderr on success and an error on failure. +type StdLogsEgressAPI interface { GetExtensionSockets() (io.Writer, io.Writer, error) GetRuntimeSockets() (io.Writer, io.Writer, error) } diff --git a/lambda/telemetry/logs_subscription_api.go b/lambda/telemetry/logs_subscription_api.go index 3ea7a20..6ee9490 100644 --- a/lambda/telemetry/logs_subscription_api.go +++ b/lambda/telemetry/logs_subscription_api.go @@ -10,28 +10,37 @@ import ( "go.amzn.com/lambda/interop" ) -// LogsSubscriptionAPI represents interface that implementations of Telemetry API have to satisfy to be RAPID-compatible -type LogsSubscriptionAPI interface { +// SubscriptionAPI represents interface that implementations of Telemetry API have to satisfy to be RAPID-compatible +type SubscriptionAPI interface { Subscribe(agentName string, body io.Reader, headers map[string][]string) (resp []byte, status int, respHeaders map[string][]string, err error) RecordCounterMetric(metricName string, count int) - FlushMetrics() interop.LogsAPIMetrics + FlushMetrics() interop.TelemetrySubscriptionMetrics Clear() TurnOff() + GetEndpointURL() string + GetServiceClosedErrorMessage() string + GetServiceClosedErrorType() string } -type NoOpLogsSubscriptionAPI struct{} +type NoOpSubscriptionAPI struct{} // Subscribe writes response to a shared memory -func (m *NoOpLogsSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { +func (m *NoOpSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { return []byte(`{}`), http.StatusOK, map[string][]string{}, nil } -func (m *NoOpLogsSubscriptionAPI) RecordCounterMetric(metricName string, count int) {} +func (m *NoOpSubscriptionAPI) RecordCounterMetric(metricName string, count int) {} -func (m *NoOpLogsSubscriptionAPI) FlushMetrics() interop.LogsAPIMetrics { - return interop.LogsAPIMetrics(map[string]int{}) +func (m *NoOpSubscriptionAPI) FlushMetrics() interop.TelemetrySubscriptionMetrics { + return interop.TelemetrySubscriptionMetrics(map[string]int{}) } -func (m *NoOpLogsSubscriptionAPI) Clear() {} +func (m *NoOpSubscriptionAPI) Clear() {} -func (m *NoOpLogsSubscriptionAPI) TurnOff() {} +func (m *NoOpSubscriptionAPI) TurnOff() {} + +func (m *NoOpSubscriptionAPI) GetEndpointURL() string { return "" } + +func (m *NoOpSubscriptionAPI) GetServiceClosedErrorMessage() string { return "" } + +func (m *NoOpSubscriptionAPI) GetServiceClosedErrorType() string { return "" } diff --git a/lambda/telemetry/tracer.go b/lambda/telemetry/tracer.go index 1ac8325..affca60 100644 --- a/lambda/telemetry/tracer.go +++ b/lambda/telemetry/tracer.go @@ -11,6 +11,7 @@ import ( "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapi/model" ) type traceContextKey int @@ -129,3 +130,24 @@ func ParseTraceID(fullTraceID string) (rootID, parentID, sample string) { } return } + +// BuildFullTraceID takes individual components of X-Ray trace header +// and puts them together into a formatted trace header. +// If root is empty, returns an empty string. +func BuildFullTraceID(root, parent, sample string) string { + if root == "" { + return "" + } + + parts := make([]string, 0, 3) + parts = append(parts, "Root="+root) + if parent != "" { + parts = append(parts, "Parent="+parent) + } + if sample == "" { + sample = model.XRayNonSampled + } + parts = append(parts, "Sampled="+sample) + + return strings.Join(parts, ";") +} diff --git a/lambda/telemetry/tracer_test.go b/lambda/telemetry/tracer_test.go index 9ac1260..c31653f 100644 --- a/lambda/telemetry/tracer_test.go +++ b/lambda/telemetry/tracer_test.go @@ -5,6 +5,8 @@ package telemetry import ( "testing" + + "go.amzn.com/lambda/rapi/model" ) var parserTests = []struct { @@ -35,3 +37,47 @@ func TestParseTraceID(t *testing.T) { }) } } + +func TestBuildFullTraceID(t *testing.T) { + specs := map[string]struct { + root string + parent string + sample string + expectedTraceID string + }{ + "all non-empty components, sampled": { + root: "1-5b3cc918-939afd635f8891ba6a9e1df6", + parent: "c88d77b0aef840e9", + sample: model.XRaySampled, + expectedTraceID: "Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=1", + }, + "all non-empty components, non-sampled": { + root: "1-5b3cc918-939afd635f8891ba6a9e1df6", + parent: "c88d77b0aef840e9", + sample: model.XRayNonSampled, + expectedTraceID: "Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=0", + }, + "root is non-empty, parent and sample are empty": { + root: "1-5b3cc918-939afd635f8891ba6a9e1df6", + expectedTraceID: "Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Sampled=0", + }, + "root is empty": { + parent: "c88d77b0aef840e9", + expectedTraceID: "", + }, + "sample is empty": { + root: "1-5b3cc918-939afd635f8891ba6a9e1df6", + parent: "c88d77b0aef840e9", + expectedTraceID: "Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=0", + }, + } + + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + actual := BuildFullTraceID(spec.root, spec.parent, spec.sample) + if actual != spec.expectedTraceID { + t.Errorf("got %q, wanted %q", actual, spec.expectedTraceID) + } + }) + } +} diff --git a/lambda/testdata/agents/bash_stderr.sh b/lambda/testdata/agents/bash_stderr.sh deleted file mode 100755 index 65c0ff1..0000000 --- a/lambda/testdata/agents/bash_stderr.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env bash - -printf "stderr line 1\n" >&2 -printf "stderr line 2\n" >&2 -printf "stderr line 3\n" >&2 diff --git a/lambda/testdata/agents/bash_stdout.sh b/lambda/testdata/agents/bash_stdout.sh deleted file mode 100755 index d0cb893..0000000 --- a/lambda/testdata/agents/bash_stdout.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env bash - -printf "stdout line 1\n" -printf "stdout line 2\n" -printf "stdout line 3\n" diff --git a/lambda/testdata/agents/bash_stdout_and_stderr.sh b/lambda/testdata/agents/bash_stdout_and_stderr.sh deleted file mode 100755 index cf87e60..0000000 --- a/lambda/testdata/agents/bash_stdout_and_stderr.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/env bash - -printf "stdout line 1\n" -printf "stderr line 1\n" >&2 -printf "stdout line 2\n" -printf "stderr line 2\n" >&2 -printf "stdout line 3\n" -printf "stderr line 3\n" >&2 diff --git a/lambda/testdata/flowtesting.go b/lambda/testdata/flowtesting.go index ee163bb..c028d7c 100644 --- a/lambda/testdata/flowtesting.go +++ b/lambda/testdata/flowtesting.go @@ -8,28 +8,31 @@ import ( "io" "io/ioutil" "net/http" + "time" "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/core" - "go.amzn.com/lambda/core/statejson" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapi/rendering" "go.amzn.com/lambda/telemetry" "go.amzn.com/lambda/testdata/mockthread" ) +const ( + contentTypeHeader = "Content-Type" + functionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" +) + type MockInteropServer struct { - Response []byte - ErrorResponse *interop.ErrorResponse - ResponseContentType string - ActiveInvokeID string + Response []byte + ErrorResponse *interop.ErrorResponse + ResponseContentType string + FunctionResponseMode string + ActiveInvokeID string } -// StartAcceptingDirectInvokes -func (i *MockInteropServer) StartAcceptingDirectInvokes() error { return nil } - // SendResponse writes response to a shared memory. -func (i *MockInteropServer) SendResponse(invokeID string, contentType string, reader io.Reader) error { +func (i *MockInteropServer) SendResponse(invokeID string, headers map[string]string, reader io.Reader, trailers http.Header, request *interop.CancellableRequest) error { bytes, err := ioutil.ReadAll(reader) if err != nil { return err @@ -41,7 +44,8 @@ func (i *MockInteropServer) SendResponse(invokeID string, contentType string, re } } i.Response = bytes - i.ResponseContentType = contentType + i.ResponseContentType = headers[contentTypeHeader] + i.FunctionResponseMode = headers[functionResponseModeHeader] return nil } @@ -49,68 +53,35 @@ func (i *MockInteropServer) SendResponse(invokeID string, contentType string, re func (i *MockInteropServer) SendErrorResponse(invokeID string, response *interop.ErrorResponse) error { i.ErrorResponse = response i.ResponseContentType = response.ContentType + i.FunctionResponseMode = response.FunctionResponseMode return nil } -func (i *MockInteropServer) GetCurrentInvokeID() string { - return i.ActiveInvokeID +// SendInitErrorResponse writes error response during init to a shared memory and sends GIRD FAULT. +func (i *MockInteropServer) SendInitErrorResponse(invokeID string, response *interop.ErrorResponse) error { + i.ErrorResponse = response + i.ResponseContentType = response.ContentType + return nil } -func (i *MockInteropServer) CommitResponse() error { return nil } - -// SendRunning sends GIRD RUNNING. -func (i *MockInteropServer) SendRunning(*interop.Running) error { return nil } - -// SendDone sends GIRD DONE. -func (i *MockInteropServer) SendDone(*interop.Done) error { return nil } - -// SendDoneFail sends GIRD DONEFAIL. -func (i *MockInteropServer) SendDoneFail(*interop.DoneFail) error { return nil } - -// StartChan returns Start emitter -func (i *MockInteropServer) StartChan() <-chan *interop.Start { return nil } - -// InvokeChan returns Invoke emitter -func (i *MockInteropServer) InvokeChan() <-chan *interop.Invoke { return nil } - -// ResetChan returns Reset emitter -func (i *MockInteropServer) ResetChan() <-chan *interop.Reset { return nil } - -// ShutdownChan returns Shutdown emitter -func (i *MockInteropServer) ShutdownChan() <-chan *interop.Shutdown { return nil } - -// TransportErrorChan emits errors if there was parsing/connection issue -func (i *MockInteropServer) TransportErrorChan() <-chan error { return nil } - -func (i *MockInteropServer) Clear() {} - -func (i *MockInteropServer) IsResponseSent() bool { - return !(i.Response == nil && i.ErrorResponse == nil) +func (i *MockInteropServer) GetCurrentInvokeID() string { + return i.ActiveInvokeID } func (i *MockInteropServer) SendRuntimeReady() error { return nil } -func (i *MockInteropServer) SetInternalStateGetter(isd interop.InternalStateGetter) {} - -func (m *MockInteropServer) Init(i *interop.Start, invokeTimeoutMs int64) {} - -func (m *MockInteropServer) Invoke(w http.ResponseWriter, i *interop.Invoke) error { return nil } - -func (m *MockInteropServer) Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription { - return nil -} - // FlowTest provides configuration for tests that involve synchronization flows. type FlowTest struct { - AppCtx appctx.ApplicationContext - InitFlow core.InitFlowSynchronization - InvokeFlow core.InvokeFlowSynchronization - RegistrationService core.RegistrationService - RenderingService *rendering.EventRenderingService - Runtime *core.Runtime - InteropServer *MockInteropServer - LogsSubscriptionAPI *telemetry.NoOpLogsSubscriptionAPI - CredentialsService core.CredentialsService + AppCtx appctx.ApplicationContext + InitFlow core.InitFlowSynchronization + InvokeFlow core.InvokeFlowSynchronization + RegistrationService core.RegistrationService + RenderingService *rendering.EventRenderingService + Runtime *core.Runtime + InteropServer *MockInteropServer + TelemetrySubscription *telemetry.NoOpSubscriptionAPI + CredentialsService core.CredentialsService + EventsAPI telemetry.EventsAPI } // ConfigureForInit initialize synchronization gates and states for init. @@ -125,13 +96,13 @@ func (s *FlowTest) ConfigureForInvoke(ctx context.Context, invoke *interop.Invok s.RenderingService.SetRenderer(rendering.NewInvokeRenderer(ctx, invoke, telemetry.GetCustomerTracingHeader)) } -func (s *FlowTest) ConfigureForInitCaching(token, awsKey, awsSecret, awsSession string) { - s.CredentialsService.SetCredentials(token, awsKey, awsSecret, awsSession) +func (s *FlowTest) ConfigureForRestore() { + s.RenderingService.SetRenderer(rendering.NewRestoreRenderer()) } -func (s *FlowTest) ConfigureForBlockedInitCaching(token, awsKey, awsSecret, awsSession string) { - s.CredentialsService.SetCredentials(token, awsKey, awsSecret, awsSession) - s.CredentialsService.BlockService() +func (s *FlowTest) ConfigureForInitCaching(token, awsKey, awsSecret, awsSession string) { + credentialsExpiration := time.Now().Add(30 * time.Minute) + s.CredentialsService.SetCredentials(token, awsKey, awsSecret, awsSession, credentialsExpiration) } // NewFlowTest returns new FlowTest configuration. @@ -145,16 +116,18 @@ func NewFlowTest() *FlowTest { runtime := core.NewRuntime(initFlow, invokeFlow) runtime.ManagedThread = &mockthread.MockManagedThread{} interopServer := &MockInteropServer{} + eventsAPI := telemetry.NoOpEventsAPI{} appctx.StoreInteropServer(appCtx, interopServer) return &FlowTest{ - AppCtx: appCtx, - InitFlow: initFlow, - InvokeFlow: invokeFlow, - RegistrationService: registrationService, - RenderingService: renderingService, - LogsSubscriptionAPI: &telemetry.NoOpLogsSubscriptionAPI{}, - Runtime: runtime, - InteropServer: interopServer, - CredentialsService: credentialsService, + AppCtx: appCtx, + InitFlow: initFlow, + InvokeFlow: invokeFlow, + RegistrationService: registrationService, + RenderingService: renderingService, + TelemetrySubscription: &telemetry.NoOpSubscriptionAPI{}, + Runtime: runtime, + InteropServer: interopServer, + CredentialsService: credentialsService, + EventsAPI: &eventsAPI, } } diff --git a/test/integration/local_lambda/test_end_to_end.py b/test/integration/local_lambda/test_end_to_end.py index a85abce..c5c3e63 100644 --- a/test/integration/local_lambda/test_end_to_end.py +++ b/test/integration/local_lambda/test_end_to_end.py @@ -53,6 +53,7 @@ def tearDownClass(cls): "remaining_time_in_default_deadline", "pre-runtime-api", "assert-overwritten", + "port_override" ] for image in images_to_delete: @@ -264,6 +265,23 @@ def test_function_name_is_overriden(self, arch, port): ) self.assertEqual(b'"My lambda ran succesfully"', r.content) + @parameterized.expand([("x86_64", "8011"), ("arm64", "9011"), ("", "9061")]) + def test_port_override(self, arch, port): + image, rie, image_name = self.tagged_name("port_override", arch) + + # Use port 8081 inside the container instead of 8080 + cmd = f"docker run --name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8081 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.success_handler --runtime-interface-emulator-address 0.0.0.0:8081" + + Popen(cmd.split(" ")).communicate() + + # sleep 1s to give enough time for the endpoint to be up to curl + time.sleep(SLEEP_TIME) + + r = requests.post( + f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} + ) + self.assertEqual(b'"My lambda ran succesfully"', r.content) + if __name__ == "__main__": main() From d69fb6aa81eb832a8e6d3a8a5aecd767881ce5eb Mon Sep 17 00:00:00 2001 From: Dominik Schubert Date: Mon, 5 Jun 2023 10:25:33 +0200 Subject: [PATCH 09/41] fix exception propagation during runtime environment create (#20) --- cmd/localstack/custom_interop.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/cmd/localstack/custom_interop.go b/cmd/localstack/custom_interop.go index 96966b3..7bf3e86 100644 --- a/cmd/localstack/custom_interop.go +++ b/cmd/localstack/custom_interop.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/go-chi/chi" log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/core" "go.amzn.com/lambda/core/statejson" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapidcore" @@ -208,11 +209,18 @@ func (c *CustomInteropServer) SendResponse(invokeID string, contentType string, } func (c *CustomInteropServer) SendErrorResponse(invokeID string, response *interop.ErrorResponse) error { - log.Traceln("Function called") - err := c.localStackAdapter.SendStatus(Error, response.Payload) + is, err := c.InternalState() if err != nil { return err } + rs := is.Runtime.State + if rs.Name == core.RuntimeInitErrorStateName { + err = c.localStackAdapter.SendStatus(Error, response.Payload) + if err != nil { + return err + } + } + return c.delegate.SendErrorResponse(invokeID, response) } From bf7e2486034742b84a1a25d28478e147b5e65f06 Mon Sep 17 00:00:00 2001 From: Renato Valenzuela <37676028+valerena@users.noreply.github.com> Date: Wed, 7 Jun 2023 17:24:16 -0700 Subject: [PATCH 10/41] chore(deps): Update to Golang 1.19 (#95) --- Makefile | 2 +- go.mod | 24 +++++++++++------------ go.sum | 60 +++++++++++++++++++++++++------------------------------- 3 files changed, 39 insertions(+), 47 deletions(-) diff --git a/Makefile b/Makefile index 7678eb6..9ff6c1a 100644 --- a/Makefile +++ b/Makefile @@ -21,7 +21,7 @@ compile-lambda-linux-all: make ARCH=old compile-lambda-linux compile-with-docker: - docker run --env GOPROXY=direct -v $(shell pwd):/LambdaRuntimeLocal -w /LambdaRuntimeLocal golang:1.18 make ARCH=${ARCH} compile-lambda-linux + docker run --env GOPROXY=direct -v $(shell pwd):/LambdaRuntimeLocal -w /LambdaRuntimeLocal golang:1.19 make ARCH=${ARCH} compile-lambda-linux compile-lambda-linux: CGO_ENABLED=0 GOOS=linux GOARCH=${GO_ARCH_${ARCH}} go build -ldflags "${RELEASE_BUILD_LINKER_FLAGS}" -o ${DESTINATION_${ARCH}} ./cmd/aws-lambda-rie diff --git a/go.mod b/go.mod index 278c63a..053c7e0 100644 --- a/go.mod +++ b/go.mod @@ -1,24 +1,22 @@ module go.amzn.com -go 1.18 +go 1.19 require ( - github.com/aws/aws-lambda-go v1.20.0 + github.com/aws/aws-lambda-go v1.41.0 github.com/go-chi/chi v4.1.2+incompatible - github.com/go-chi/render v1.0.1 - github.com/google/uuid v1.1.2 - github.com/jessevdk/go-flags v1.4.0 - github.com/sirupsen/logrus v1.6.0 - github.com/stretchr/testify v1.6.1 - golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 + github.com/google/uuid v1.3.0 + github.com/jessevdk/go-flags v1.5.0 + github.com/sirupsen/logrus v1.9.3 + github.com/stretchr/testify v1.8.4 + golang.org/x/sync v0.2.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/konsorten/go-windows-terminal-sequences v1.0.3 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/objx v0.1.0 // indirect - golang.org/x/net v0.7.0 // indirect - golang.org/x/sys v0.5.0 // indirect - gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 // indirect + github.com/stretchr/objx v0.5.0 // indirect + golang.org/x/net v0.10.0 // indirect + golang.org/x/sys v0.8.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 905e315..d8fb9e9 100644 --- a/go.sum +++ b/go.sum @@ -1,44 +1,38 @@ -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/aws/aws-lambda-go v1.20.0 h1:ZSweJx/Hy9BoIDXKBEh16vbHH0t0dehnF8MKpMiOWc0= -github.com/aws/aws-lambda-go v1.20.0/go.mod h1:jJmlefzPfGnckuHdXX7/80O3BvUUi12XOkbv4w9SGLU= -github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= -github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= +github.com/aws/aws-lambda-go v1.41.0 h1:l/5fyVb6Ud9uYd411xdHZzSf2n86TakxzpvIoz7l+3Y= +github.com/aws/aws-lambda-go v1.41.0/go.mod h1:jwFe2KmMsHmffA1X2R09hH6lFzJQxzI8qK17ewzbQMM= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-chi/chi v4.1.2+incompatible h1:fGFk2Gmi/YKXk0OmGfBh0WgmN3XB8lVnEyNz34tQRec= github.com/go-chi/chi v4.1.2+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= -github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8= -github.com/go-chi/render v1.0.1/go.mod h1:pq4Rr7HbnsdaeHagklXub+p6Wd16Af5l9koip1OvJns= -github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= -github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGARJA= -github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= -github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8= -github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc= +github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= -github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= -github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= -github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/urfave/cli/v2 v2.2.0/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= -golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck= -golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= +golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 h1:tQIYjPdBoyREyB9XMu+nnTclpTYkz2zFM+lzLJFO4gQ= -gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From e4d28d80e11f0c85e71081761a796208d03346aa Mon Sep 17 00:00:00 2001 From: Daniel Fangl Date: Thu, 6 Jul 2023 21:00:37 +0200 Subject: [PATCH 11/41] set bootstrap to 755 if not executable (#21) --- cmd/localstack/awsutil.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/cmd/localstack/awsutil.go b/cmd/localstack/awsutil.go index b28d734..2f78acf 100644 --- a/cmd/localstack/awsutil.go +++ b/cmd/localstack/awsutil.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapidcore" + "golang.org/x/sys/unix" "io" "io/fs" "math" @@ -77,6 +78,15 @@ func getBootstrap(args []string) (*rapidcore.Bootstrap, string) { log.Panic("insufficient arguments: bootstrap not provided") } + err := unix.Access(bootstrapLookupCmd[0], unix.X_OK) + if err != nil { + log.Debug("Bootstrap not executable, setting permissions to 0755...", bootstrapLookupCmd[0]) + err = os.Chmod(bootstrapLookupCmd[0], 0755) + if err != nil { + log.Warn("Error setting bootstrap to 0755 permissions: ", bootstrapLookupCmd[0], err) + } + } + return rapidcore.NewBootstrapSingleCmd(bootstrapLookupCmd, currentWorkingDir), handler } From 605e17df31b8a7ef59e5c45d574e506a3958843f Mon Sep 17 00:00:00 2001 From: Joel Scheuner Date: Mon, 16 Oct 2023 15:43:10 +0200 Subject: [PATCH 12/41] Fix filesystem permission parity (#22) --- cmd/localstack/file_utils.go | 22 ++++++++++++++++++++++ cmd/localstack/main.go | 19 ++++++++++++++++--- cmd/localstack/user.go | 4 ++-- 3 files changed, 40 insertions(+), 5 deletions(-) create mode 100644 cmd/localstack/file_utils.go diff --git a/cmd/localstack/file_utils.go b/cmd/localstack/file_utils.go new file mode 100644 index 0000000..69e93de --- /dev/null +++ b/cmd/localstack/file_utils.go @@ -0,0 +1,22 @@ +package main + +import ( + "os" + "path/filepath" +) + +// Inspired by https://stackoverflow.com/questions/73864379/golang-change-permission-os-chmod-and-os-chowm-recursively +// but using the more efficient WalkDir API +func ChmodRecursively(root string, mode os.FileMode) error { + return filepath.WalkDir(root, + func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + err = os.Chmod(path, mode) + if err != nil { + return err + } + return nil + }) +} diff --git a/cmd/localstack/main.go b/cmd/localstack/main.go index e4e096a..917b330 100644 --- a/cmd/localstack/main.go +++ b/cmd/localstack/main.go @@ -132,6 +132,15 @@ func main() { log.Fatal("Failed to download code archives: " + err.Error()) } + // fix permissions of the layers directory for better AWS parity + if err := ChmodRecursively("/opt", 0755); err != nil { + log.Warnln("Could not change file mode recursively of directory /opt:", err) + } + // fix permissions of the tmp directory for better AWS parity + if err := ChmodRecursively("/tmp", 0700); err != nil { + log.Warnln("Could not change file mode recursively of directory /tmp:", err) + } + // parse CLI args bootstrap, handler := getBootstrap(os.Args) @@ -141,11 +150,15 @@ func main() { gid := 990 AddUser(lsOpts.User, uid, gid) if err := os.Chown("/tmp", uid, gid); err != nil { - log.Warnln("Could not change owner of /tmp:", err) + log.Warnln("Could not change owner of directory /tmp:", err) } UserLogger().Debugln("Process running as root user.") - DropPrivileges(lsOpts.User) - UserLogger().Debugln("Process running as non-root user.") + err := DropPrivileges(lsOpts.User) + if err != nil { + log.Warnln("Could not drop root privileges.", err) + } else { + UserLogger().Debugln("Process running as non-root user.") + } } logCollector := NewLogCollector() diff --git a/cmd/localstack/user.go b/cmd/localstack/user.go index 13c5f5d..3e6da42 100644 --- a/cmd/localstack/user.go +++ b/cmd/localstack/user.go @@ -70,12 +70,12 @@ func UserLogger() *log.Entry { } uid := os.Getuid() uidString := strconv.Itoa(uid) - user, err := user.LookupId(uidString) + userObject, err := user.LookupId(uidString) if err != nil { log.Warnln("Could not look up user by uid:", uid, err) } return log.WithFields(log.Fields{ - "username": user.Username, + "username": userObject.Username, "uid": uid, "euid": os.Geteuid(), "gid": os.Getgid(), From 605fa1c0b50d8430fd9030f3b3daf686e3e4c202 Mon Sep 17 00:00:00 2001 From: Joel Scheuner Date: Fri, 6 Oct 2023 17:03:01 +0200 Subject: [PATCH 13/41] Adapt to new init logic (#24) --- cmd/localstack/awsutil.go | 16 +++- cmd/localstack/custom_interop.go | 148 +++++++++--------------------- cmd/localstack/logs_egress_api.go | 31 +++++++ cmd/localstack/main.go | 36 ++++++-- debugging/Makefile | 2 +- debugging/README.md | 4 +- lambda/rapidcore/server.go | 3 +- 7 files changed, 118 insertions(+), 122 deletions(-) create mode 100644 cmd/localstack/logs_egress_api.go diff --git a/cmd/localstack/awsutil.go b/cmd/localstack/awsutil.go index 2f78acf..7a8ba8a 100644 --- a/cmd/localstack/awsutil.go +++ b/cmd/localstack/awsutil.go @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // LOCALSTACK CHANGES 2022-03-10: modified/collected file from /cmd/aws-lambda-rie/* into this util // LOCALSTACK CHANGES 2022-03-10: minor refactoring of PrintEndReports +// LOCALSTACK CHANGES 2023-10-06: reflect getBootstrap and InitHandler API updates package main @@ -11,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapidcore" + "go.amzn.com/lambda/rapidcore/env" "golang.org/x/sys/unix" "io" "io/fs" @@ -87,7 +89,7 @@ func getBootstrap(args []string) (*rapidcore.Bootstrap, string) { } } - return rapidcore.NewBootstrapSingleCmd(bootstrapLookupCmd, currentWorkingDir), handler + return rapidcore.NewBootstrapSingleCmd(bootstrapLookupCmd, currentWorkingDir, ""), handler } func PrintEndReports(invokeId string, initDuration string, memorySize string, invokeStart time.Time, timeoutDuration time.Duration, w io.Writer) { @@ -203,7 +205,7 @@ func getSubFoldersInList(prefix string, pathList []string) (oldFolders []string, return } -func InitHandler(sandbox Sandbox, functionVersion string, timeout int64) (time.Time, time.Time) { +func InitHandler(sandbox Sandbox, functionVersion string, timeout int64, bs interop.Bootstrap) (time.Time, time.Time) { additionalFunctionEnvironmentVariables := map[string]string{} // Add default Env Vars if they were not defined. This is a required otherwise 1p Python2.7, Python3.6, and @@ -226,7 +228,6 @@ func InitHandler(sandbox Sandbox, functionVersion string, timeout int64) (time.T // pass to rapid sandbox.Init(&interop.Init{ Handler: GetenvWithDefault("AWS_LAMBDA_FUNCTION_HANDLER", os.Getenv("_HANDLER")), - CorrelationID: "initCorrelationID", // TODO AwsKey: os.Getenv("AWS_ACCESS_KEY_ID"), AwsSecret: os.Getenv("AWS_SECRET_ACCESS_KEY"), AwsSession: os.Getenv("AWS_SESSION_TOKEN"), @@ -234,7 +235,16 @@ func InitHandler(sandbox Sandbox, functionVersion string, timeout int64) (time.T FunctionName: GetenvWithDefault("AWS_LAMBDA_FUNCTION_NAME", "test_function"), FunctionVersion: functionVersion, + // TODO: Implement runtime management controls + // https://aws.amazon.com/blogs/compute/introducing-aws-lambda-runtime-management-controls/ + RuntimeInfo: interop.RuntimeInfo{ + ImageJSON: "{}", + Arn: "", + Version: ""}, CustomerEnvironmentVariables: additionalFunctionEnvironmentVariables, + SandboxType: interop.SandboxClassic, + Bootstrap: bs, + EnvironmentVariables: env.NewEnvironment(), }, timeout*1000) initEnd := time.Now() return initStart, initEnd diff --git a/cmd/localstack/custom_interop.go b/cmd/localstack/custom_interop.go index 7bf3e86..a8acc58 100644 --- a/cmd/localstack/custom_interop.go +++ b/cmd/localstack/custom_interop.go @@ -1,12 +1,14 @@ package main +// Original implementation: lambda/rapidcore/server.go includes Server struct with state +// Server interface between Runtime API and this init: lambda/interop/model.go:358 + import ( "bytes" "encoding/json" "fmt" "github.com/go-chi/chi" log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/core" "go.amzn.com/lambda/core/statejson" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapidcore" @@ -38,8 +40,8 @@ const ( ) func (l *LocalStackAdapter) SendStatus(status LocalStackStatus, payload []byte) error { - status_url := fmt.Sprintf("%s/status/%s/%s", l.UpstreamEndpoint, l.RuntimeId, status) - _, err := http.Post(status_url, "application/json", bytes.NewReader(payload)) + statusUrl := fmt.Sprintf("%s/status/%s/%s", l.UpstreamEndpoint, l.RuntimeId, status) + _, err := http.Post(statusUrl, "application/json", bytes.NewReader(payload)) if err != nil { return err } @@ -62,7 +64,7 @@ type ErrorResponse struct { StackTrace []string `json:"stackTrace,omitempty"` } -func NewCustomInteropServer(lsOpts *LsOpts, delegate rapidcore.InteropServer, logCollector *LogCollector) (server *CustomInteropServer) { +func NewCustomInteropServer(lsOpts *LsOpts, delegate interop.Server, logCollector *LogCollector) (server *CustomInteropServer) { server = &CustomInteropServer{ delegate: delegate.(*rapidcore.Server), port: lsOpts.InteropPort, @@ -99,9 +101,7 @@ func NewCustomInteropServer(lsOpts *LsOpts, delegate rapidcore.InteropServer, lo InvokedFunctionArn: invokeR.InvokedFunctionArn, Payload: strings.NewReader(invokeR.Payload), // r.Body, NeedDebugLogs: true, - CorrelationID: "invokeCorrelationID", - - TraceID: invokeR.TraceId, + TraceID: invokeR.TraceId, // TODO: set correct segment ID from request //LambdaSegmentID: "LambdaSegmentID", // r.Header.Get("X-Amzn-Segment-Id"), //CognitoIdentityID: "", @@ -194,147 +194,81 @@ func NewCustomInteropServer(lsOpts *LsOpts, delegate rapidcore.InteropServer, lo return server } -func (c *CustomInteropServer) StartAcceptingDirectInvokes() error { - log.Traceln("Function called") - err := c.localStackAdapter.SendStatus(Ready, []byte{}) - if err != nil { - return err - } - return c.delegate.StartAcceptingDirectInvokes() +func (c *CustomInteropServer) SendResponse(invokeID string, headers map[string]string, reader io.Reader, trailers http.Header, request *interop.CancellableRequest) error { + log.Traceln("SendResponse called") + return c.delegate.SendResponse(invokeID, headers, reader, trailers, request) } -func (c *CustomInteropServer) SendResponse(invokeID string, contentType string, response io.Reader) error { - log.Traceln("Function called") - return c.delegate.SendResponse(invokeID, contentType, response) +func (c *CustomInteropServer) SendErrorResponse(invokeID string, response *interop.ErrorResponse) error { + log.Traceln("SendErrorResponse called") + return c.delegate.SendErrorResponse(invokeID, response) } -func (c *CustomInteropServer) SendErrorResponse(invokeID string, response *interop.ErrorResponse) error { - is, err := c.InternalState() - if err != nil { - return err - } - rs := is.Runtime.State - if rs.Name == core.RuntimeInitErrorStateName { - err = c.localStackAdapter.SendStatus(Error, response.Payload) - if err != nil { - return err - } +// SendInitErrorResponse writes error response during init to a shared memory and sends GIRD FAULT. +func (c *CustomInteropServer) SendInitErrorResponse(invokeID string, response *interop.ErrorResponse) error { + log.Traceln("SendInitErrorResponse called") + if err := c.localStackAdapter.SendStatus(Error, response.Payload); err != nil { + log.Fatalln("Failed to send init error to LocalStack " + err.Error() + ". Exiting.") } - - return c.delegate.SendErrorResponse(invokeID, response) + return c.delegate.SendInitErrorResponse(invokeID, response) } func (c *CustomInteropServer) GetCurrentInvokeID() string { - log.Traceln("Function called") + log.Traceln("GetCurrentInvokeID called") return c.delegate.GetCurrentInvokeID() } -func (c *CustomInteropServer) CommitResponse() error { - log.Traceln("Function called") - return c.delegate.CommitResponse() -} - -func (c *CustomInteropServer) SendRunning(running *interop.Running) error { - log.Traceln("Function called") - return c.delegate.SendRunning(running) -} - func (c *CustomInteropServer) SendRuntimeReady() error { - log.Traceln("Function called") + log.Traceln("SendRuntimeReady called") return c.delegate.SendRuntimeReady() } -func (c *CustomInteropServer) SendDone(done *interop.Done) error { - log.Traceln("Function called") - return c.delegate.SendDone(done) -} - -func (c *CustomInteropServer) SendDoneFail(fail *interop.DoneFail) error { - log.Traceln("Function called") - return c.delegate.SendDoneFail(fail) -} - -func (c *CustomInteropServer) StartChan() <-chan *interop.Start { - log.Traceln("Function called") - return c.delegate.StartChan() -} - -func (c *CustomInteropServer) InvokeChan() <-chan *interop.Invoke { - log.Traceln("Function called") - return c.delegate.InvokeChan() -} - -func (c *CustomInteropServer) ResetChan() <-chan *interop.Reset { - log.Traceln("Function called") - return c.delegate.ResetChan() -} - -func (c *CustomInteropServer) ShutdownChan() <-chan *interop.Shutdown { - log.Traceln("Function called") - return c.delegate.ShutdownChan() -} - -func (c *CustomInteropServer) TransportErrorChan() <-chan error { - log.Traceln("Function called") - return c.delegate.TransportErrorChan() -} - -func (c *CustomInteropServer) Clear() { - log.Traceln("Function called") - c.delegate.Clear() -} - -func (c *CustomInteropServer) IsResponseSent() bool { - log.Traceln("Function called") - return c.delegate.IsResponseSent() -} - -func (c *CustomInteropServer) SetInternalStateGetter(cb interop.InternalStateGetter) { - log.Traceln("Function called") - c.delegate.SetInternalStateGetter(cb) -} - -func (c *CustomInteropServer) Init(i *interop.Start, invokeTimeoutMs int64) { - log.Traceln("Function called") - c.delegate.Init(i, invokeTimeoutMs) +func (c *CustomInteropServer) Init(i *interop.Init, invokeTimeoutMs int64) error { + log.Traceln("Init called") + return c.delegate.Init(i, invokeTimeoutMs) } func (c *CustomInteropServer) Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error { - log.Traceln("Function called") + log.Traceln("Invoke called") return c.delegate.Invoke(responseWriter, invoke) } func (c *CustomInteropServer) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error { - log.Traceln("Function called") + log.Traceln("FastInvoke called") return c.delegate.FastInvoke(w, i, direct) } func (c *CustomInteropServer) Reserve(id string, traceID, lambdaSegmentID string) (*rapidcore.ReserveResponse, error) { - log.Traceln("Function called") + log.Traceln("Reserve called") return c.delegate.Reserve(id, traceID, lambdaSegmentID) } func (c *CustomInteropServer) Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) { - log.Traceln("Function called") + log.Traceln("Reset called") return c.delegate.Reset(reason, timeoutMs) } func (c *CustomInteropServer) AwaitRelease() (*statejson.InternalStateDescription, error) { - log.Traceln("Function called") + log.Traceln("AwaitRelease called") return c.delegate.AwaitRelease() } -func (c *CustomInteropServer) Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription { - log.Traceln("Function called") - return c.delegate.Shutdown(shutdown) -} - func (c *CustomInteropServer) InternalState() (*statejson.InternalStateDescription, error) { - log.Traceln("Function called") + log.Traceln("InternalState called") return c.delegate.InternalState() } func (c *CustomInteropServer) CurrentToken() *interop.Token { - log.Traceln("Function called") + log.Traceln("CurrentToken called") return c.delegate.CurrentToken() } + +func (c *CustomInteropServer) SetSandboxContext(sbCtx interop.SandboxContext) { + log.Traceln("SetSandboxContext called") + c.delegate.SetSandboxContext(sbCtx) +} + +func (c *CustomInteropServer) SetInternalStateGetter(cb interop.InternalStateGetter) { + log.Traceln("SetInternalStateGetter called") + c.delegate.InternalStateGetter = cb +} diff --git a/cmd/localstack/logs_egress_api.go b/cmd/localstack/logs_egress_api.go new file mode 100644 index 0000000..ec567d0 --- /dev/null +++ b/cmd/localstack/logs_egress_api.go @@ -0,0 +1,31 @@ +package main + +import ( + "io" + "os" +) + +// This LocalStack LogsEgressAPI builder allows to customize log capturing, in our case using the logCollector. + +type LocalStackLogsEgressAPI struct { + logCollector *LogCollector +} + +func NewLocalStackLogsEgressAPI(logCollector *LogCollector) *LocalStackLogsEgressAPI { + return &LocalStackLogsEgressAPI{ + logCollector: logCollector, + } +} + +// The interface StdLogsEgressAPI for the functions below is defined in the under cmd/localstack/logs_egress_api.go +// The default implementation is a NoOpLogsEgressAPI + +func (s *LocalStackLogsEgressAPI) GetExtensionSockets() (io.Writer, io.Writer, error) { + // os.Stderr can not be used for the stderrWriter because stderr is for internal logging (not customer visible). + return io.MultiWriter(s.logCollector, os.Stdout), io.MultiWriter(s.logCollector, os.Stdout), nil +} + +func (s *LocalStackLogsEgressAPI) GetRuntimeSockets() (io.Writer, io.Writer, error) { + // os.Stderr can not be used for the stderrWriter because stderr is for internal logging (not customer visible). + return io.MultiWriter(s.logCollector, os.Stdout), io.MultiWriter(s.logCollector, os.Stdout), nil +} diff --git a/cmd/localstack/main.go b/cmd/localstack/main.go index 917b330..b519965 100644 --- a/cmd/localstack/main.go +++ b/cmd/localstack/main.go @@ -161,14 +161,15 @@ func main() { } } - logCollector := NewLogCollector() - // file watcher for hot-reloading fileWatcherContext, cancelFileWatcher := context.WithCancel(context.Background()) + logCollector := NewLogCollector() + localStackLogsEgressApi := NewLocalStackLogsEgressAPI(logCollector) + // build sandbox sandbox := rapidcore. - NewSandboxBuilder(bootstrap). + NewSandboxBuilder(). //SetTracer(tracer). AddShutdownFunc(func() { log.Debugln("Stopping file watcher") @@ -178,7 +179,7 @@ func main() { }). SetExtensionsFlag(true). SetInitCachingFlag(true). - SetTailLogOutput(logCollector) + SetLogsEgressAPI(localStackLogsEgressApi) // xray daemon endpoint := "http://" + lsOpts.LocalstackIP + ":" + lsOpts.EdgePort @@ -192,7 +193,7 @@ func main() { }) runDaemon(d) // async - defaultInterop := sandbox.InteropServer() + defaultInterop := sandbox.DefaultInteropServer() interopServer := NewCustomInteropServer(lsOpts, defaultInterop, logCollector) sandbox.SetInteropServer(interopServer) if len(handler) > 0 { @@ -204,7 +205,10 @@ func main() { }) // initialize all flows and start runtime API - go sandbox.Create() + sandboxContext, internalStateFn := sandbox.Create() + // Populate our custom interop server + interopServer.SetSandboxContext(sandboxContext) + interopServer.SetInternalStateGetter(internalStateFn) // get timeout invokeTimeoutEnv := GetEnvOrDie("AWS_LAMBDA_FUNCTION_TIMEOUT") // TODO: collect all AWS_* env parsing @@ -214,8 +218,24 @@ func main() { } go RunHotReloadingListener(interopServer, lsOpts.HotReloadingPaths, fileWatcherContext) - // start runtime init - go InitHandler(sandbox, GetEnvOrDie("AWS_LAMBDA_FUNCTION_VERSION"), int64(invokeTimeoutSeconds)) // TODO: replace this with a custom init + // start runtime init. It is important to start `InitHandler` synchronously because we need to ensure the + // notification channels and status fields are properly initialized before `AwaitInitialized` + log.Debugln("Starting runtime init.") + InitHandler(sandbox.LambdaInvokeAPI(), GetEnvOrDie("AWS_LAMBDA_FUNCTION_VERSION"), int64(invokeTimeoutSeconds), bootstrap) // TODO: replace this with a custom init + + log.Debugln("Awaiting initialization of runtime init.") + if err := interopServer.delegate.AwaitInitialized(); err != nil { + // Error cases: ErrInitDoneFailed or ErrInitResetReceived + log.Errorln("Runtime init failed to initialize: " + err.Error() + ". Exiting.") + // NOTE: Sending the error status to LocalStack is handled beforehand in the custom_interop.go through the + // callback SendInitErrorResponse because it contains the correct error response payload. + return + } + + log.Debugln("Completed initialization of runtime init. Sending status ready to LocalStack.") + if err := interopServer.localStackAdapter.SendStatus(Ready, []byte{}); err != nil { + log.Fatalln("Failed to send status ready to LocalStack " + err.Error() + ". Exiting.") + } <-exitChan } diff --git a/debugging/Makefile b/debugging/Makefile index 03ea056..9bd3e35 100644 --- a/debugging/Makefile +++ b/debugging/Makefile @@ -1,5 +1,5 @@ # Golang EOL overview: https://endoflife.date/go -DOCKER_GOLANG_IMAGE ?= golang:1.18.2 +DOCKER_GOLANG_IMAGE ?= golang:1.19 # On ARM hosts, use: make ARCH=arm64 build-init # Check host architecture: uname -m diff --git a/debugging/README.md b/debugging/README.md index ac62d31..a335f32 100644 --- a/debugging/README.md +++ b/debugging/README.md @@ -7,7 +7,7 @@ Useful if you want more control over the API between the init and LocalStack (e. ## Debugging with LocalStack 1. Build init via `make build` - * On ARM hosts, use `make ARCH=arm64 build` + * On ARM hosts, use `make ARCH=arm64 build` because debugging only works with native containers. 2. Start LocalStack with the following flags: @@ -33,4 +33,4 @@ Useful if you want more control over the API between the init and LocalStack (e. Within `create_lambda_function`: * Increase the `timeout=3600` -* On ARM hosts, use `Architectures=[Architecture.arm64]` +* On ARM hosts, debugging only works with ARM containers. Use `LAMBDA_IGNORE_ARCHITECTURE=1` or explicitly configure the Lambda function with `Architectures=[Architecture.arm64]` diff --git a/lambda/rapidcore/server.go b/lambda/rapidcore/server.go index e652130..ba4a06a 100644 --- a/lambda/rapidcore/server.go +++ b/lambda/rapidcore/server.go @@ -1,5 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 +// LOCALSTACK CHANGES 2023-10-17: pass request metadata into .Reserve(invoke.ID, invoke.TraceID, invoke.LambdaSegmentID) package rapidcore @@ -645,7 +646,7 @@ func (s *Server) Invoke(responseWriter http.ResponseWriter, invoke *interop.Invo // The logic would be almost identical, except that init failures could manifest // through return values of FastInvoke and not Reserve() - reserveResp, err := s.Reserve("", "", "") + reserveResp, err := s.Reserve(invoke.ID, invoke.TraceID, invoke.LambdaSegmentID) if err != nil { log.Infof("ReserveFailed: %s", err) } From 3f022d7d40c7d66e6d2cbe1f1d08221164c7c022 Mon Sep 17 00:00:00 2001 From: Joel Scheuner Date: Wed, 18 Oct 2023 21:22:13 +0200 Subject: [PATCH 14/41] Add readme describing the LocalStack customizations (#25) --- README-LOCALSTACK.md | 47 ++++++++++++++++++++++++++++++++ cmd/localstack/custom_interop.go | 2 +- 2 files changed, 48 insertions(+), 1 deletion(-) create mode 100644 README-LOCALSTACK.md diff --git a/README-LOCALSTACK.md b/README-LOCALSTACK.md new file mode 100644 index 0000000..2074548 --- /dev/null +++ b/README-LOCALSTACK.md @@ -0,0 +1,47 @@ +# Customized lambda-runtime-init for LocalStack + +This customized version of the Lambda Runtime Interface Emulator (RIE) is designed to work with [LocalStack](https://github.com/localstack/localstack). + +Refer to [debugging/README.md](./debugging/README.md) for instructions on how to build and test the customized RIE with LocalStack. + +## Branches + +* `localstack` main branch with the latest custom LocalStack changes +* `develop` and `main` are mirror branches of the upstream AWS repository [lambda-runtime-init](https://github.com/aws/aws-lambda-runtime-interface-emulator) + +## Structure + +| Directory | Description | +|------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `.github` | Build and release workflows | +| `bin/` | Target directory for binary builds (e.g., `aws-lambda-rie-x86_64`) | +| `cmd/localstack` | LocalStack customizations | +| ├── `main.go` | Main entrypoint | +| ├── `custom_interop.go` | Custom server interface between the Lambda runtime API and this Go init. Implements the `Server` interface from `lambda/interop/model.go:Server` but forwards most calls to the original implementation in `lambda/rapidcore/server.go` available as `delegate`. | +| `cmd/ls-api` | Mock LocalStack component for testing (likely outdated) | +| `debugging/` | Debug and test this Go init with LocalStack | +| ├── [`README.md`](./debugging/README.md) | Instructions for building and debugging with LocalStack | +| `lambda` | Original AWS implementation of the runtime emulator ideally kept untouched | + +## Integrate Upstream Changes + +Follow these steps to integrate upstream changes from the official AWS [lambda-runtime-init](https://github.com/aws/aws-lambda-runtime-interface-emulator) repository: + +1. Open the [develop](https://github.com/localstack/lambda-runtime-init/tree/develop) branch on GitHub. +2. Click "🔁Sync fork" to pull the upstream changes from AWS into the develop branch. +3. Create a new branch based on the branch localstack `git checkout localstack && git checkout -b integrate-upstream-changes`. +4. Merge the upstream changes from develop into the new branch `git merge develop` and resolve any potential conflicts. +5. If needed, add a single commit with minimal changes to adjust the localstack customizations to the new changes. +6. Create a PR on Github against `localstack/lambda-runtime-init localstack` (️not against AWS as by default ⚠️). +7. **MERGING:** Manually merge the approved PR using `git checkout localstack && git merge --ff integrate-upstream-changes` and add the PR number as a suffix to the commit message. Example: `(#24)`. Do not squash any upstream commits! +8. Manually push `git push origin localstack` and close the PR on GitHub + +Example PR that integrates upstream changes: https://github.com/localstack/lambda-runtime-init/pull/24 + +## Custom LocalStack Changes + +Document all custom changes with the following comment prefix `# LOCALSTACK CHANGES yyyy-mm-dd:` + +* Everything in `cmd/localstack`, `cmd/ls-api`, and `.github` +* `Makefile` for debugging and building with Docker +* 2023-10-17: `lambda/rapidcore/server.go` pass request metadata into .Reserve(invoke.ID, invoke.TraceID, invoke.LambdaSegmentID) diff --git a/cmd/localstack/custom_interop.go b/cmd/localstack/custom_interop.go index a8acc58..3dcde93 100644 --- a/cmd/localstack/custom_interop.go +++ b/cmd/localstack/custom_interop.go @@ -1,7 +1,7 @@ package main // Original implementation: lambda/rapidcore/server.go includes Server struct with state -// Server interface between Runtime API and this init: lambda/interop/model.go:358 +// Server interface between Runtime API and this init: lambda/interop/model.go:Server import ( "bytes" From bb1394a8383be01359a24d906f2d93053c2171b5 Mon Sep 17 00:00:00 2001 From: Joel Scheuner Date: Mon, 23 Oct 2023 17:04:39 +0200 Subject: [PATCH 15/41] Skip dropping privileges for root user (#26) --- cmd/localstack/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/localstack/main.go b/cmd/localstack/main.go index b519965..67a2bd4 100644 --- a/cmd/localstack/main.go +++ b/cmd/localstack/main.go @@ -145,7 +145,7 @@ func main() { bootstrap, handler := getBootstrap(os.Args) // Switch to non-root user and drop root privileges - if IsRootUser() && lsOpts.User != "" { + if IsRootUser() && lsOpts.User != "" && lsOpts.User != "root" { uid := 993 gid := 990 AddUser(lsOpts.User, uid, gid) From d176f7dd4f8757bf9c11b622f6406bf01b9715ad Mon Sep 17 00:00:00 2001 From: Joel Scheuner Date: Tue, 24 Oct 2023 18:51:19 +0200 Subject: [PATCH 16/41] Set file permissions of code directory when layer present (#27) --- cmd/localstack/file_utils.go | 17 +++++++++++++++++ cmd/localstack/main.go | 17 +++++++++++++---- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/cmd/localstack/file_utils.go b/cmd/localstack/file_utils.go index 69e93de..ed65c70 100644 --- a/cmd/localstack/file_utils.go +++ b/cmd/localstack/file_utils.go @@ -1,6 +1,7 @@ package main import ( + "io" "os" "path/filepath" ) @@ -20,3 +21,19 @@ func ChmodRecursively(root string, mode os.FileMode) error { return nil }) } + +// Check if a directory is empty +// Source: https://stackoverflow.com/questions/30697324/how-to-check-if-directory-on-path-is-empty/30708914#30708914 +func IsDirEmpty(name string) (bool, error) { + f, err := os.Open(name) + if err != nil { + return false, err + } + defer f.Close() + + _, err = f.Readdirnames(1) // faster than f.Readdir(1) + if err == io.EOF { + return true, nil + } + return false, err // Either not empty or error, suits both cases +} diff --git a/cmd/localstack/main.go b/cmd/localstack/main.go index 67a2bd4..08b70d9 100644 --- a/cmd/localstack/main.go +++ b/cmd/localstack/main.go @@ -132,13 +132,22 @@ func main() { log.Fatal("Failed to download code archives: " + err.Error()) } - // fix permissions of the layers directory for better AWS parity + // set file permissions of the tmp directory for better AWS parity + if err := ChmodRecursively("/tmp", 0700); err != nil { + log.Warnln("Could not change file mode recursively of directory /tmp:", err) + } + // set file permissions of the layers directory for better AWS parity if err := ChmodRecursively("/opt", 0755); err != nil { log.Warnln("Could not change file mode recursively of directory /opt:", err) } - // fix permissions of the tmp directory for better AWS parity - if err := ChmodRecursively("/tmp", 0700); err != nil { - log.Warnln("Could not change file mode recursively of directory /tmp:", err) + // set file permissions of the code directory if at least one layer is present for better AWS parity + // Limitation: hot reloading likely breaks file permission parity for /var/task in combination with layers + // Heuristic for detecting the presence of layers. It might fail for an empty layer or image-based Lambda. + if isDirEmpty, _ := IsDirEmpty("/opt"); !isDirEmpty { + log.Debugln("Detected layer present") + if err := ChmodRecursively("/var/task", 0755); err != nil { + log.Warnln("Could not change file mode recursively of directory /var/task:", err) + } } // parse CLI args From 4ef12298f282481adccf69ab20058be91cd0c56f Mon Sep 17 00:00:00 2001 From: Renato Valenzuela Date: Fri, 10 Nov 2023 22:51:07 +0000 Subject: [PATCH 17/41] chore: Don't embed VCS information --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 9ff6c1a..c2d5e55 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ compile-with-docker: docker run --env GOPROXY=direct -v $(shell pwd):/LambdaRuntimeLocal -w /LambdaRuntimeLocal golang:1.19 make ARCH=${ARCH} compile-lambda-linux compile-lambda-linux: - CGO_ENABLED=0 GOOS=linux GOARCH=${GO_ARCH_${ARCH}} go build -ldflags "${RELEASE_BUILD_LINKER_FLAGS}" -o ${DESTINATION_${ARCH}} ./cmd/aws-lambda-rie + CGO_ENABLED=0 GOOS=linux GOARCH=${GO_ARCH_${ARCH}} go build -buildvcs=false -ldflags "${RELEASE_BUILD_LINKER_FLAGS}" -o ${DESTINATION_${ARCH}} ./cmd/aws-lambda-rie tests: go test ./... From d9bbbf13d97ec787d8444f80efb86fa3243a1f04 Mon Sep 17 00:00:00 2001 From: Renato Valenzuela Date: Sat, 11 Nov 2023 00:08:14 +0000 Subject: [PATCH 18/41] chore: Pull upstream changes 2023-11 --- cmd/aws-lambda-rie/main.go | 5 +- cmd/aws-lambda-rie/simple_bootstrap.go | 69 ++ cmd/aws-lambda-rie/simple_bootstrap_test.go | 78 +++ lambda/agents/agent.go | 10 +- lambda/appctx/appctx.go | 7 +- lambda/appctx/appctxutil.go | 28 +- lambda/core/directinvoke/directinvoke.go | 172 ++++- lambda/core/directinvoke/directinvoke_test.go | 400 ++++++++++- lambda/core/flow.go | 27 + lambda/core/registrations.go | 10 +- lambda/core/runtime_state_names.go | 1 + lambda/core/statejson/description.go | 40 +- lambda/core/states.go | 29 +- lambda/core/states_test.go | 57 +- lambda/extensions/extensions.go | 15 + lambda/fatalerror/fatalerror.go | 70 +- lambda/fatalerror/fatalerror_test.go | 51 ++ lambda/interop/bootstrap.go | 9 +- lambda/interop/environment_variables.go | 14 - lambda/interop/events_api.go | 193 ++++++ lambda/interop/events_api_test.go | 656 ++++++++++++++++++ lambda/interop/messages.go | 68 ++ lambda/interop/model.go | 244 ++++--- lambda/interop/model_test.go | 39 ++ lambda/interop/sandbox_model.go | 148 ++-- lambda/metering/time.go | 8 +- lambda/metering/time_test.go | 8 + lambda/rapi/extensions_fuzz_test.go | 344 +++++++++ lambda/rapi/handler/agentnext_test.go | 10 +- lambda/rapi/handler/agentregister.go | 79 ++- lambda/rapi/handler/agentregister_test.go | 197 ++++-- lambda/rapi/handler/initerror.go | 85 +-- lambda/rapi/handler/initerror_test.go | 8 +- lambda/rapi/handler/invocationerror.go | 18 +- lambda/rapi/handler/invocationerror_test.go | 58 +- lambda/rapi/handler/invocationnext_test.go | 69 +- lambda/rapi/handler/invocationresponse.go | 29 +- .../rapi/handler/invocationresponse_test.go | 11 +- lambda/rapi/handler/restoreerror.go | 47 ++ lambda/rapi/handler/restoreerror_test.go | 44 ++ lambda/rapi/handler/runtimelogs.go | 25 +- lambda/rapi/handler/runtimelogs_test.go | 129 +++- lambda/rapi/model/agentregisterresponse.go | 1 + lambda/rapi/model/errorresponse.go | 19 - lambda/rapi/rapi_fuzz_test.go | 391 +++++++++++ lambda/rapi/rendering/render_error.go | 88 +++ lambda/rapi/rendering/render_json.go | 4 +- lambda/rapi/rendering/rendering.go | 194 ++---- lambda/rapi/router.go | 6 +- lambda/rapi/router_test.go | 111 ++- lambda/rapi/security_test.go | 6 +- lambda/rapi/server.go | 12 +- lambda/rapi/telemetry_logs_fuzz_test.go | 185 +++++ lambda/rapid/exit.go | 20 +- lambda/rapid/{start.go => handlers.go} | 572 +++++++++------ lambda/rapid/handlers_test.go | 341 +++++++++ lambda/rapid/sandbox.go | 146 ++-- lambda/rapid/shutdown.go | 126 ++-- lambda/rapid/start_test.go | 201 ------ lambda/rapidcore/bootstrap.go | 205 ------ lambda/rapidcore/bootstrap_test.go | 280 -------- lambda/rapidcore/env/environment.go | 151 +--- lambda/rapidcore/env/environment_test.go | 62 +- lambda/rapidcore/env/rapidenv.go | 96 +++ lambda/rapidcore/runtime_release.go | 68 ++ lambda/rapidcore/runtime_release_test.go | 97 +++ lambda/rapidcore/sandbox_api.go | 116 ++-- lambda/rapidcore/sandbox_builder.go | 58 +- lambda/rapidcore/sandbox_emulator_api.go | 1 + lambda/rapidcore/server.go | 120 ++-- lambda/rapidcore/server_test.go | 158 +++-- .../rapidcore/standalone/eventLogHandler.go | 6 +- lambda/rapidcore/standalone/executeHandler.go | 7 +- lambda/rapidcore/standalone/invokeHandler.go | 29 +- lambda/rapidcore/standalone/restoreHandler.go | 35 +- lambda/rapidcore/standalone/router.go | 10 +- .../standalone/telemetry/agent_writer.go | 30 + .../standalone/telemetry/eventLog.go | 13 + .../standalone/telemetry/events_api.go | 293 ++++++++ .../standalone/telemetry/logs_egress_api.go | 26 + .../standalone/telemetry/structured_logger.go | 21 + .../rapidcore/standalone/telemetry/tracer.go | 216 ++++++ .../standalone/waitUntilReleaseHandler.go | 6 +- lambda/rapidcore/telemetry/eventLog.go | 78 --- lambda/rapidcore/telemetry/events_api.go | 97 --- lambda/rapidcore/telemetry/xray.go | 124 ---- lambda/supervisor/local_supervisor.go | 107 +-- lambda/supervisor/local_supervisor_test.go | 87 ++- lambda/supervisor/model/model.go | 203 ++++-- lambda/supervisor/model/model_test.go | 31 + .../logsapi => telemetry}/constants.go | 10 +- lambda/telemetry/events_api.go | 178 ++--- lambda/telemetry/events_api_test.go | 135 +++- lambda/telemetry/logs_egress_api.go | 2 + lambda/telemetry/logs_subscription_api.go | 4 +- lambda/telemetry/tracer.go | 107 ++- lambda/telemetry/tracer_test.go | 100 ++- lambda/testdata/flowtesting.go | 41 +- lambda/testdata/mocktracer/mocktracer.go | 5 +- 99 files changed, 6728 insertions(+), 2717 deletions(-) create mode 100644 cmd/aws-lambda-rie/simple_bootstrap.go create mode 100644 cmd/aws-lambda-rie/simple_bootstrap_test.go create mode 100644 lambda/fatalerror/fatalerror_test.go delete mode 100644 lambda/interop/environment_variables.go create mode 100644 lambda/interop/events_api.go create mode 100644 lambda/interop/events_api_test.go create mode 100644 lambda/interop/messages.go create mode 100644 lambda/rapi/extensions_fuzz_test.go create mode 100644 lambda/rapi/handler/restoreerror.go create mode 100644 lambda/rapi/handler/restoreerror_test.go create mode 100644 lambda/rapi/rapi_fuzz_test.go create mode 100644 lambda/rapi/rendering/render_error.go create mode 100644 lambda/rapi/telemetry_logs_fuzz_test.go rename lambda/rapid/{start.go => handlers.go} (58%) create mode 100644 lambda/rapid/handlers_test.go delete mode 100644 lambda/rapid/start_test.go delete mode 100644 lambda/rapidcore/bootstrap.go delete mode 100644 lambda/rapidcore/bootstrap_test.go create mode 100644 lambda/rapidcore/env/rapidenv.go create mode 100644 lambda/rapidcore/runtime_release.go create mode 100644 lambda/rapidcore/runtime_release_test.go create mode 100644 lambda/rapidcore/standalone/telemetry/agent_writer.go create mode 100644 lambda/rapidcore/standalone/telemetry/eventLog.go create mode 100644 lambda/rapidcore/standalone/telemetry/events_api.go create mode 100644 lambda/rapidcore/standalone/telemetry/logs_egress_api.go create mode 100644 lambda/rapidcore/standalone/telemetry/structured_logger.go create mode 100644 lambda/rapidcore/standalone/telemetry/tracer.go delete mode 100644 lambda/rapidcore/telemetry/eventLog.go delete mode 100644 lambda/rapidcore/telemetry/events_api.go delete mode 100644 lambda/rapidcore/telemetry/xray.go create mode 100644 lambda/supervisor/model/model_test.go rename lambda/{rapidcore/telemetry/logsapi => telemetry}/constants.go (94%) diff --git a/cmd/aws-lambda-rie/main.go b/cmd/aws-lambda-rie/main.go index 65879c0..bd15402 100644 --- a/cmd/aws-lambda-rie/main.go +++ b/cmd/aws-lambda-rie/main.go @@ -11,6 +11,7 @@ import ( "runtime/debug" "github.com/jessevdk/go-flags" + "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapidcore" log "github.com/sirupsen/logrus" @@ -103,7 +104,7 @@ func isBootstrapFileExist(filePath string) bool { return !os.IsNotExist(err) && !file.IsDir() } -func getBootstrap(args []string, opts options) (*rapidcore.Bootstrap, string) { +func getBootstrap(args []string, opts options) (interop.Bootstrap, string) { var bootstrapLookupCmd []string var handler string currentWorkingDir := "/var/task" // default value @@ -149,5 +150,5 @@ func getBootstrap(args []string, opts options) (*rapidcore.Bootstrap, string) { log.Panic("insufficient arguments: bootstrap not provided") } - return rapidcore.NewBootstrapSingleCmd(bootstrapLookupCmd, currentWorkingDir, ""), handler + return NewSimpleBootstrap(bootstrapLookupCmd, currentWorkingDir), handler } diff --git a/cmd/aws-lambda-rie/simple_bootstrap.go b/cmd/aws-lambda-rie/simple_bootstrap.go new file mode 100644 index 0000000..c9111a2 --- /dev/null +++ b/cmd/aws-lambda-rie/simple_bootstrap.go @@ -0,0 +1,69 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "fmt" + "os" + "path/filepath" + + "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapidcore/env" +) + +// the type implement a simpler version of the Bootstrap +// this is useful in the Standalone Core implementation. +type simpleBootstrap struct { + cmd []string + workingDir string +} + +func NewSimpleBootstrap(cmd []string, currentWorkingDir string) interop.Bootstrap { + if currentWorkingDir == "" { + // use the root directory as the default working directory + currentWorkingDir = "/" + } + + // a single candidate command makes it automatically valid + return &simpleBootstrap{ + cmd: cmd, + workingDir: currentWorkingDir, + } +} + +func (b *simpleBootstrap) Cmd() ([]string, error) { + return b.cmd, nil +} + +// Cwd returns the working directory of the bootstrap process +// The path is validated against the chroot identified by `root` +func (b *simpleBootstrap) Cwd() (string, error) { + if !filepath.IsAbs(b.workingDir) { + return "", fmt.Errorf("the working directory '%s' is invalid, it needs to be an absolute path", b.workingDir) + } + + // evaluate the path relatively to the domain's mnt namespace root + if _, err := os.Stat(b.workingDir); os.IsNotExist(err) { + return "", fmt.Errorf("the working directory doesn't exist: %s", b.workingDir) + } + + return b.workingDir, nil +} + +// Env returns the environment variables available to +// the bootstrap process +func (b *simpleBootstrap) Env(e *env.Environment) map[string]string { + return e.RuntimeExecEnv() +} + +// ExtraFiles returns the extra file descriptors apart from 1 & 2 to be passed to runtime +func (b *simpleBootstrap) ExtraFiles() []*os.File { + return make([]*os.File, 0) +} + +func (b *simpleBootstrap) CachedFatalError(err error) (fatalerror.ErrorType, string, bool) { + // not implemented as it is not needed in Core but we need to fullfil the interface anyway + return fatalerror.ErrorType(""), "", false +} diff --git a/cmd/aws-lambda-rie/simple_bootstrap_test.go b/cmd/aws-lambda-rie/simple_bootstrap_test.go new file mode 100644 index 0000000..de00ee2 --- /dev/null +++ b/cmd/aws-lambda-rie/simple_bootstrap_test.go @@ -0,0 +1,78 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "os" + "reflect" + "testing" + + "go.amzn.com/lambda/rapidcore/env" + + "github.com/stretchr/testify/assert" +) + +func TestSimpleBootstrap(t *testing.T) { + tmpFile, err := os.CreateTemp("", "oci-test-bootstrap") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + // Setup single cmd candidate + file := []string{tmpFile.Name(), "--arg1 s", "foo"} + cmdCandidate := file + + // Setup working dir + cwd, err := os.Getwd() + assert.NoError(t, err) + + // Setup environment + environment := env.NewEnvironment() + environment.StoreRuntimeAPIEnvironmentVariable("host:port") + environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") + + // Test + b := NewSimpleBootstrap(cmdCandidate, cwd) + bCwd, err := b.Cwd() + assert.NoError(t, err) + assert.Equal(t, cwd, bCwd) + assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) + + cmd, err := b.Cmd() + assert.NoError(t, err) + assert.Equal(t, file, cmd) +} + +func TestSimpleBootstrapCmdNonExistingCandidate(t *testing.T) { + // Setup inexistent single cmd candidate + file := []string{"/foo/bar", "--arg1 s", "foo"} + cmdCandidate := file + + // Setup working dir + cwd, err := os.Getwd() + assert.NoError(t, err) + + // Setup environment + environment := env.NewEnvironment() + environment.StoreRuntimeAPIEnvironmentVariable("host:port") + environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") + + // Test + b := NewSimpleBootstrap(cmdCandidate, cwd) + bCwd, err := b.Cwd() + assert.NoError(t, err) + assert.Equal(t, cwd, bCwd) + assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) + + // No validations run against single candidates + cmd, err := b.Cmd() + assert.NoError(t, err) + assert.Equal(t, file, cmd) +} + +func TestSimpleBootstrapCmdDefaultWorkingDir(t *testing.T) { + b := NewSimpleBootstrap([]string{}, "") + bCwd, err := b.Cwd() + assert.NoError(t, err) + assert.Equal(t, "/", bCwd) +} diff --git a/lambda/agents/agent.go b/lambda/agents/agent.go index b1f8563..cabe1fa 100644 --- a/lambda/agents/agent.go +++ b/lambda/agents/agent.go @@ -20,10 +20,18 @@ func ListExternalAgentPaths(dir string, root string) []string { } fullDir := path.Join(root, dir) files, err := os.ReadDir(fullDir) + if err != nil { - log.WithError(err).Warning("Cannot list external agents") + if os.IsNotExist(err) { + log.Infof("The extension's directory %q does not exist, assuming no extensions to be loaded.", fullDir) + } else { + // TODO - Should this return an error rather than ignore failing to load? + log.WithError(err).Error("Cannot list external agents") + } + return agentPaths } + for _, file := range files { if !file.IsDir() { // The returned path is absolute wrt to `root`. This allows diff --git a/lambda/appctx/appctx.go b/lambda/appctx/appctx.go index 6c81653..931a2ec 100644 --- a/lambda/appctx/appctx.go +++ b/lambda/appctx/appctx.go @@ -13,9 +13,9 @@ type Key int type InitType int const ( - // AppCtxInvokeErrorResponseKey is used for storing deferred invoke error response. + // AppCtxInvokeErrorTraceDataKey is used for storing deferred invoke error cause header value. // Only used by xray. TODO refactor xray interface so it doesn't use appctx - AppCtxInvokeErrorResponseKey Key = iota + AppCtxInvokeErrorTraceDataKey Key = iota // AppCtxRuntimeReleaseKey is used for storing runtime release information (parsed from User_Agent Http header string). AppCtxRuntimeReleaseKey @@ -23,6 +23,9 @@ const ( // AppCtxInteropServerKey is used to store a reference to the interop server. AppCtxInteropServerKey + // AppCtxResponseSenderKey is used to store a reference to the response sender + AppCtxResponseSenderKey + // AppCtxFirstFatalErrorKey is used to store first unrecoverable error message encountered to propagate it to slicer with DONE(errortype) or DONEFAIL(errortype) AppCtxFirstFatalErrorKey diff --git a/lambda/appctx/appctxutil.go b/lambda/appctx/appctxutil.go index a30677f..cd6e6d3 100644 --- a/lambda/appctx/appctxutil.go +++ b/lambda/appctx/appctxutil.go @@ -119,16 +119,16 @@ func UpdateAppCtxWithRuntimeRelease(request *http.Request, appCtx ApplicationCon return false } -// StoreErrorResponse stores response in the applicaton context. -func StoreErrorResponse(appCtx ApplicationContext, errorResponse *interop.ErrorResponse) { - appCtx.Store(AppCtxInvokeErrorResponseKey, errorResponse) +// StoreInvokeErrorTraceData stores invocation error x-ray cause header in the applicaton context. +func StoreInvokeErrorTraceData(appCtx ApplicationContext, invokeError *interop.InvokeErrorTraceData) { + appCtx.Store(AppCtxInvokeErrorTraceDataKey, invokeError) } -// LoadErrorResponse retrieves response from the application context. -func LoadErrorResponse(appCtx ApplicationContext) *interop.ErrorResponse { - v, ok := appCtx.Load(AppCtxInvokeErrorResponseKey) +// LoadInvokeErrorTraceData retrieves invocation error x-ray cause header from the application context. +func LoadInvokeErrorTraceData(appCtx ApplicationContext) *interop.InvokeErrorTraceData { + v, ok := appCtx.Load(AppCtxInvokeErrorTraceDataKey) if ok { - return v.(*interop.ErrorResponse) + return v.(*interop.InvokeErrorTraceData) } return nil } @@ -147,6 +147,20 @@ func LoadInteropServer(appCtx ApplicationContext) interop.Server { return nil } +// StoreResponseSender stores a reference to the response sender +func StoreResponseSender(appCtx ApplicationContext, server interop.InvokeResponseSender) { + appCtx.Store(AppCtxResponseSenderKey, server) +} + +// LoadResponseSender retrieves the response sender +func LoadResponseSender(appCtx ApplicationContext) interop.InvokeResponseSender { + v, ok := appCtx.Load(AppCtxResponseSenderKey) + if ok { + return v.(interop.InvokeResponseSender) + } + return nil +} + // StoreFirstFatalError stores unrecoverable error code in appctx once. This error is considered to be the rootcause of failure func StoreFirstFatalError(appCtx ApplicationContext, err fatalerror.ErrorType) { if existing := appCtx.StoreIfNotExists(AppCtxFirstFatalErrorKey, err); existing != nil { diff --git a/lambda/core/directinvoke/directinvoke.go b/lambda/core/directinvoke/directinvoke.go index 8ef59ae..3510132 100644 --- a/lambda/core/directinvoke/directinvoke.go +++ b/lambda/core/directinvoke/directinvoke.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "strconv" + "strings" "github.com/go-chi/chi" "go.amzn.com/lambda/core/bandwidthlimiter" @@ -27,6 +28,7 @@ const ( CustomerHeadersHeader = "Customer-Headers" ContentTypeHeader = "Content-Type" MaxPayloadSizeHeader = "MaxPayloadSize" + InvokeResponseModeHeader = "InvokeResponseMode" ResponseBandwidthRateHeader = "ResponseBandwidthRate" ResponseBandwidthBurstSizeHeader = "ResponseBandwidthBurstSize" FunctionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" @@ -53,6 +55,10 @@ var MaxDirectResponseSize int64 = interop.MaxPayloadSize // this is intentionall var ResponseBandwidthRate int64 = interop.ResponseBandwidthRate var ResponseBandwidthBurstSize int64 = interop.ResponseBandwidthBurstSize +// InvokeResponseMode controls the context in which the invoke is. Since this was introduced +// in Streaming invokes, we default it to Buffered. +var InvokeResponseMode interop.InvokeResponseMode = interop.InvokeResponseModeBuffered + func renderBadRequest(w http.ResponseWriter, r *http.Request, errorType string) { w.Header().Set(ErrorTypeHeader, errorType) w.WriteHeader(http.StatusBadRequest) @@ -65,9 +71,29 @@ func renderInternalServerError(w http.ResponseWriter, errorType string) { w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) } +// convertToInvokeResponseMode converts the given string to a InvokeResponseMode +// It is case insensitive and if there is no match, an error is thrown. +func convertToInvokeResponseMode(value string) (interop.InvokeResponseMode, error) { + // buffered + if strings.EqualFold(value, string(interop.InvokeResponseModeBuffered)) { + return interop.InvokeResponseModeBuffered, nil + } + + // streaming + if strings.EqualFold(value, string(interop.InvokeResponseModeStreaming)) { + return interop.InvokeResponseModeStreaming, nil + } + + // unknown + allowedValues := strings.Join(interop.AllInvokeResponseModes, ", ") + log.Errorf("Unable to map %s to %s.", value, allowedValues) + return "", interop.ErrInvalidInvokeResponseMode +} + // ReceiveDirectInvoke parses invoke and verifies it against Token message. Uses deadline provided by Token // Renders BadRequest in case of error func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.Token) (*interop.Invoke, error) { + log.Infof("Received Invoke(invokeID: %s) Request", token.InvokeID) w.Header().Set("Trailer", EndOfResponseTrailer) custHeaders := CustomerHeaders{} @@ -89,10 +115,30 @@ func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.T } } - if MaxDirectResponseSize == -1 { + if valueFromHeader := r.Header.Get(InvokeResponseModeHeader); valueFromHeader != "" { + invokeResponseMode, err := convertToInvokeResponseMode(valueFromHeader) + if err != nil { + log.Errorf( + "InvokeResponseMode header is not a valid string. Was: %#v, Allowed: %#v.", + valueFromHeader, + strings.Join(interop.AllInvokeResponseModes, ", "), + ) + renderBadRequest(w, r, err.Error()) + return nil, err + } + InvokeResponseMode = invokeResponseMode + } + + // TODO: stop using `MaxDirectResponseSize` + if isStreamingInvoke(int(MaxDirectResponseSize), InvokeResponseMode) { w.Header().Add("Trailer", FunctionErrorTypeTrailer) w.Header().Add("Trailer", FunctionErrorBodyTrailer) + // FIXME + // Until WorkerProxy stops sending MaxDirectResponseSize == -1 to identify streaming + // invokes, we need to override InvokeResponseMode to avoid setting InvokeResponseMode to buffered (default) for a streaming invoke (MaxDirectResponseSize == -1). + InvokeResponseMode = interop.InvokeResponseModeStreaming + ResponseBandwidthRate = interop.ResponseBandwidthRate if responseBandwidthRate := r.Header.Get(ResponseBandwidthRateHeader); responseBandwidthRate != "" { if n, err := strconv.ParseInt(responseBandwidthRate, 10, 64); err == nil && @@ -119,20 +165,23 @@ func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.T } inv := &interop.Invoke{ - ID: r.Header.Get(InvokeIDHeader), - ReservationToken: chi.URLParam(r, "reservationtoken"), - InvokedFunctionArn: r.Header.Get(InvokedFunctionArnHeader), - VersionID: r.Header.Get(VersionIDHeader), - ContentType: r.Header.Get(ContentTypeHeader), - CognitoIdentityID: custHeaders.CognitoIdentityID, - CognitoIdentityPoolID: custHeaders.CognitoIdentityPoolID, - TraceID: token.TraceID, - LambdaSegmentID: token.LambdaSegmentID, - ClientContext: custHeaders.ClientContext, - Payload: r.Body, - DeadlineNs: fmt.Sprintf("%d", now+token.FunctionTimeout.Nanoseconds()), - NeedDebugLogs: token.NeedDebugLogs, - InvokeReceivedTime: now, + ID: r.Header.Get(InvokeIDHeader), + ReservationToken: chi.URLParam(r, "reservationtoken"), + InvokedFunctionArn: r.Header.Get(InvokedFunctionArnHeader), + VersionID: r.Header.Get(VersionIDHeader), + ContentType: r.Header.Get(ContentTypeHeader), + CognitoIdentityID: custHeaders.CognitoIdentityID, + CognitoIdentityPoolID: custHeaders.CognitoIdentityPoolID, + TraceID: token.TraceID, + LambdaSegmentID: token.LambdaSegmentID, + ClientContext: custHeaders.ClientContext, + Payload: r.Body, + DeadlineNs: fmt.Sprintf("%d", now+token.FunctionTimeout.Nanoseconds()), + NeedDebugLogs: token.NeedDebugLogs, + InvokeReceivedTime: now, + InvokeResponseMode: InvokeResponseMode, + RestoreDurationNs: token.RestoreDurationNs, + RestoreStartTimeMonotime: token.RestoreStartTimeMonotime, } if inv.ID != token.InvokeID { @@ -170,7 +219,7 @@ type CopyDoneResult struct { func getErrorTypeFromResetReason(resetReason string) fatalerror.ErrorType { errorTypeTrailer, ok := ResetReasonMap[resetReason] if !ok { - errorTypeTrailer = fatalerror.Unknown + errorTypeTrailer = fatalerror.SandboxFailure } return errorTypeTrailer } @@ -180,8 +229,11 @@ func isErrorResponse(additionalHeaders map[string]string) (isErrorResponse bool) return } -func isStreamingInvoke() bool { - return MaxDirectResponseSize == -1 +// isStreamingInvoke checks whether the invoke mode is streaming or not. +// `maxDirectResponseSize == -1` is used as it was the first check we did when we released +// streaming invokes. +func isStreamingInvoke(maxDirectResponseSize int, invokeResponseMode interop.InvokeResponseMode) bool { + return maxDirectResponseSize == -1 || invokeResponseMode == interop.InvokeResponseModeStreaming } func asyncPayloadCopy(w http.ResponseWriter, payload io.Reader) (copyDone chan CopyDoneResult, cancel context.CancelFunc, err error) { @@ -190,10 +242,34 @@ func asyncPayloadCopy(w http.ResponseWriter, payload io.Reader) (copyDone chan C if err != nil { return nil, nil, &interop.ErrInternalPlatformError{} } + go func() { // copy payload in a separate go routine - _, copyError := bandwidthlimiter.BandwidthLimitingCopy(streamedResponseWriter, payload) + // -1 size indicates the payload size is unlimited. + isPayloadsSizeRestricted := MaxDirectResponseSize != -1 + + if isPayloadsSizeRestricted { + // Setting the limit to MaxDirectResponseSize + 1 so we can do + // readBytes > MaxDirectResponseSize to check if the response is oversized. + // As the response is allowed to be of the size MaxDirectResponseSize but not larger than it. + payload = io.LimitReader(payload, MaxDirectResponseSize+1) + } + + // FIXME: inject bandwidthlimiter as a dependency, so that we can mock it in tests + copiedBytes, copyError := bandwidthlimiter.BandwidthLimitingCopy(streamedResponseWriter, payload) + + isPayloadsSizeOversized := copiedBytes > MaxDirectResponseSize + if copyError != nil { w.Header().Set(EndOfResponseTrailer, EndOfResponseTruncated) + copyError = &interop.ErrTruncatedResponse{} + } else if isPayloadsSizeRestricted && isPayloadsSizeOversized { + w.Header().Set(EndOfResponseTrailer, EndOfResponseOversized) + copyError = &interop.ErrorResponseTooLargeDI{ + ErrorResponseTooLarge: interop.ErrorResponseTooLarge{ + ResponseSize: int(copiedBytes), + MaxResponseSize: int(MaxDirectResponseSize), + }, + } } else { w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) } @@ -227,8 +303,8 @@ func sendStreamingInvokeResponse(payload io.Reader, trailers http.Header, w http case copyDoneResult = <-copyDone: // copy finished errorTypeTrailer = trailers.Get(FunctionErrorTypeTrailer) errorBodyTrailer = trailers.Get(FunctionErrorBodyTrailer) - if copyDoneResult.Error != nil && errorTypeTrailer == "" { // truncated payload, error type not known - errorTypeTrailer = string(fatalerror.TruncatedResponse) + if copyDoneResult.Error != nil && errorTypeTrailer == "" { + errorTypeTrailer = string(mapCopyDoneResultErrorToErrorType(copyDoneResult.Error)) } case reset := <-interruptedResponseChan: // reset initiated cancel() @@ -247,6 +323,7 @@ func sendStreamingInvokeResponse(payload io.Reader, trailers http.Header, w http } copyDoneResult = <-copyDone reset.InvokeResponseMetrics = copyDoneResult.Metrics + reset.InvokeResponseMode = InvokeResponseMode interruptedResponseChan <- nil errorTypeTrailer = string(getErrorTypeFromResetReason(reset.Reason)) } @@ -258,11 +335,23 @@ func sendStreamingInvokeResponse(payload io.Reader, trailers http.Header, w http if copyDoneResult.Error != nil { log.Errorf("Error while streaming response payload: %s", copyDoneResult.Error) - err = &interop.ErrTruncatedResponse{} + err = copyDoneResult.Error } return } +// mapCopyDoneResultErrorToErrorType map a copyDoneResult error into a fatalerror +func mapCopyDoneResultErrorToErrorType(err interface{}) fatalerror.ErrorType { + switch err.(type) { + case *interop.ErrTruncatedResponse: + return fatalerror.TruncatedResponse + case *interop.ErrorResponseTooLargeDI: + return fatalerror.FunctionOversizedResponse + default: + return fatalerror.SandboxFailure + } +} + func sendStreamingInvokeErrorResponse(payload io.Reader, w http.ResponseWriter, interruptedResponseChan chan *interop.Reset, sendResponseChan chan *interop.InvokeResponseMetrics, runtimeCalledResponse bool) (err error) { @@ -279,6 +368,7 @@ func sendStreamingInvokeErrorResponse(payload io.Reader, w http.ResponseWriter, cancel() copyDoneResult = <-copyDone reset.InvokeResponseMetrics = copyDoneResult.Metrics + reset.InvokeResponseMode = InvokeResponseMode interruptedResponseChan <- nil } @@ -287,8 +377,9 @@ func sendStreamingInvokeErrorResponse(payload io.Reader, w http.ResponseWriter, if copyDoneResult.Error != nil { log.Errorf("Error while streaming error response payload: %s", copyDoneResult.Error) - err = &interop.ErrTruncatedResponse{} + err = copyDoneResult.Error } + return } @@ -317,7 +408,10 @@ func sendPayloadLimitedResponse(payload io.Reader, trailers http.Header, w http. } startReadingResponseMonoTimeMs := metering.Monotime() - written, err := io.Copy(w, io.LimitReader(payload, MaxDirectResponseSize+1)) // +1 because we do allow 10MB but not 10MB + 1 byte + // Setting the limit to MaxDirectResponseSize + 1 so we can do + // readBytes > MaxDirectResponseSize to check if the response is oversized. + // As the response is allowed to be of the size MaxDirectResponseSize but not larger than it. + written, err := io.Copy(w, io.LimitReader(payload, MaxDirectResponseSize+1)) // non-streaming invoke request but runtime is streaming: set response trailers if functionResponseMode == interop.FunctionResponseModeStreaming { @@ -325,10 +419,12 @@ func sendPayloadLimitedResponse(payload io.Reader, trailers http.Header, w http. w.Header().Set(FunctionErrorBodyTrailer, trailers.Get(FunctionErrorBodyTrailer)) } + isNotStreamingInvoke := InvokeResponseMode != interop.InvokeResponseModeStreaming + if err != nil { w.Header().Set(EndOfResponseTrailer, EndOfResponseTruncated) err = &interop.ErrTruncatedResponse{} - } else if MaxDirectResponseSize != -1 && written == MaxDirectResponseSize+1 { + } else if isNotStreamingInvoke && written == MaxDirectResponseSize+1 { w.Header().Set(EndOfResponseTrailer, EndOfResponseOversized) err = &interop.ErrorResponseTooLargeDI{ ErrorResponseTooLarge: interop.ErrorResponseTooLarge{ @@ -358,19 +454,33 @@ func sendPayloadLimitedResponse(payload io.Reader, trailers http.Header, w http. func SendDirectInvokeResponse(additionalHeaders map[string]string, payload io.Reader, trailers http.Header, w http.ResponseWriter, interruptedResponseChan chan *interop.Reset, - sendResponseChan chan *interop.InvokeResponseMetrics, request *interop.CancellableRequest, runtimeCalledResponse bool) error { + sendResponseChan chan *interop.InvokeResponseMetrics, request *interop.CancellableRequest, runtimeCalledResponse bool, invokeID string) error { for k, v := range additionalHeaders { w.Header().Add(k, v) } - if isStreamingInvoke() { // unlimited payload; response streaming mode - if isErrorResponse(additionalHeaders) { // send streamed error response when runtime called /error - return sendStreamingInvokeErrorResponse(payload, w, interruptedResponseChan, sendResponseChan, runtimeCalledResponse) + var err error + log.Infof("Started sending response (mode: %s, requestID: %s)", InvokeResponseMode, invokeID) + if InvokeResponseMode == interop.InvokeResponseModeStreaming { + // send streamed error response when runtime called /error + if isErrorResponse(additionalHeaders) { + err = sendStreamingInvokeErrorResponse(payload, w, interruptedResponseChan, sendResponseChan, runtimeCalledResponse) + if err != nil { + log.Infof("Error in sending error response (mode: %s, requestID: %s, error: %v)", InvokeResponseMode, invokeID, err) + } + return err } // send streamed response when runtime called /response - return sendStreamingInvokeResponse(payload, trailers, w, interruptedResponseChan, sendResponseChan, request, runtimeCalledResponse) + err = sendStreamingInvokeResponse(payload, trailers, w, interruptedResponseChan, sendResponseChan, request, runtimeCalledResponse) + } else { + err = sendPayloadLimitedResponse(payload, trailers, w, sendResponseChan, runtimeCalledResponse) } - return sendPayloadLimitedResponse(payload, trailers, w, sendResponseChan, runtimeCalledResponse) + if err != nil { + log.Infof("Error in sending response (mode: %s, requestID: %s, error: %v)", InvokeResponseMode, invokeID, err) + } else { + log.Infof("Completed sending response (mode: %s, requestID: %s)", InvokeResponseMode, invokeID) + } + return err } diff --git a/lambda/core/directinvoke/directinvoke_test.go b/lambda/core/directinvoke/directinvoke_test.go index 4e26161..94b6323 100644 --- a/lambda/core/directinvoke/directinvoke_test.go +++ b/lambda/core/directinvoke/directinvoke_test.go @@ -5,14 +5,24 @@ package directinvoke import ( "bytes" + "context" + "errors" + "fmt" "io" + "math" "net/http" + "net/http/httptest" + "strconv" "strings" "testing" "time" + "github.com/go-chi/chi" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" ) func NewResponseWriterWithoutFlushMethod() *ResponseWriterWithoutFlushMethod { @@ -93,24 +103,87 @@ func (r *Reader) Read(b []byte) (n int, err error) { return } -func TestSendDirectInvokeWithIncompatibleResponseWriter(t *testing.T) { - MaxDirectResponseSize = -1 - err := SendDirectInvokeResponse(nil, nil, nil, NewResponseWriterWithoutFlushMethod(), nil, nil, nil, false) - require.Error(t, err) - require.Equal(t, "ErrInternalPlatformError", err.Error()) +func TestAsyncPayloadCopyWhenPayloadSizeBelowMaxAllowed(t *testing.T) { + MaxDirectResponseSize = 2 + payloadSize := int(MaxDirectResponseSize - 1) + payloadString := strings.Repeat("a", payloadSize) + writer := NewSimpleResponseWriter() + + copyDone, _, err := asyncPayloadCopy(writer, NewReader(payloadString)) + require.Nil(t, err) + + copyDoneResult := <-copyDone + require.Nil(t, copyDoneResult.Error) + + require.Equal(t, payloadString, writer.buffer.String()) + require.Equal(t, EndOfResponseComplete, writer.Header().Get(EndOfResponseTrailer)) + + // reset it to its original value + MaxDirectResponseSize = interop.MaxPayloadSize } -func TestAsyncPayloadCopySuccess(t *testing.T) { - payloadString := strings.Repeat("a", 10*1024*1024) +func TestAsyncPayloadCopyWhenPayloadSizeEqualMaxAllowed(t *testing.T) { + MaxDirectResponseSize = 2 + payloadSize := int(MaxDirectResponseSize) + payloadString := strings.Repeat("a", payloadSize) writer := NewSimpleResponseWriter() - expectedPayloadString := payloadString + copyDone, _, err := asyncPayloadCopy(writer, NewReader(payloadString)) + require.Nil(t, err) + + copyDoneResult := <-copyDone + require.Nil(t, copyDoneResult.Error) + + require.Equal(t, payloadString, writer.buffer.String()) + require.Equal(t, EndOfResponseComplete, writer.Header().Get(EndOfResponseTrailer)) + + // reset it to its original value + MaxDirectResponseSize = interop.MaxPayloadSize +} + +func TestAsyncPayloadCopyWhenPayloadSizeAboveMaxAllowed(t *testing.T) { + MaxDirectResponseSize = 2 + payloadSize := int(MaxDirectResponseSize) + 1 + payloadString := strings.Repeat("a", payloadSize) + writer := NewSimpleResponseWriter() + expectedCopyDoneResultError := &interop.ErrorResponseTooLargeDI{ + ErrorResponseTooLarge: interop.ErrorResponseTooLarge{ + ResponseSize: payloadSize, + MaxResponseSize: int(MaxDirectResponseSize), + }, + } copyDone, _, err := asyncPayloadCopy(writer, NewReader(payloadString)) require.Nil(t, err) - <-copyDone - require.Equal(t, expectedPayloadString, writer.buffer.String()) + copyDoneResult := <-copyDone + require.Equal(t, expectedCopyDoneResultError, copyDoneResult.Error) + + require.Equal(t, payloadString, writer.buffer.String()) + require.Equal(t, EndOfResponseOversized, writer.Header().Get(EndOfResponseTrailer)) + + // reset it to its original value + MaxDirectResponseSize = interop.MaxPayloadSize +} + +// This is only allowed in streaming mode, currently. +func TestAsyncPayloadCopyWhenUnlimitedPayloadSizeAllowed(t *testing.T) { + MaxDirectResponseSize = -1 + payloadSize := int(interop.MaxPayloadSize + 1) + payloadString := strings.Repeat("a", payloadSize) + writer := NewSimpleResponseWriter() + + copyDone, _, err := asyncPayloadCopy(writer, NewReader(payloadString)) + require.Nil(t, err) + + copyDoneResult := <-copyDone + require.Nil(t, copyDoneResult.Error) + + require.Equal(t, payloadString, writer.buffer.String()) + require.Equal(t, EndOfResponseComplete, writer.Header().Get(EndOfResponseTrailer)) + + // reset it to its original value + MaxDirectResponseSize = interop.MaxPayloadSize } // We use an interruptable response writer which informs on a channel that it's ready to be interrupted after @@ -135,7 +208,6 @@ func TestAsyncPayloadCopySuccessAfterCancel(t *testing.T) { <-copyDone require.Equal(t, expectedPayloadString, writer.buffer.String()) } - func TestAsyncPayloadCopyWithIncompatibleResponseWriter(t *testing.T) { copyDone, cancel, err := asyncPayloadCopy(&ResponseWriterWithoutFlushMethod{}, nil) require.Nil(t, copyDone) @@ -144,6 +216,13 @@ func TestAsyncPayloadCopyWithIncompatibleResponseWriter(t *testing.T) { require.Equal(t, "ErrInternalPlatformError", err.Error()) } +// TODO: in order to implement this test we need bandwidthlimiter to be received by asyncPayloadCopy +// as an argument. Otherwise, this test will need to know how to force bandwidthlimiter to fail, +// which isn't a good practice. +func TestAsyncPayloadCopyWhenResponseIsTruncated(t *testing.T) { + t.Skip("Pending injection of bandwidthlimiter as a dependency of asyncPayloadCopy.") +} + func TestSendStreamingInvokeResponseSuccess(t *testing.T) { payloadString := strings.Repeat("a", 128*1024) // 128 KiB payload := NewReader(payloadString) @@ -289,6 +368,7 @@ func TestSendStreamingInvokeResponseReset(t *testing.T) { // Reset initiated aft interruptedTestWriterChan <- struct{}{} // inform test writer about interruption <-interruptedResponseChan // wait for copy done after interruption require.NotNil(t, reset.InvokeResponseMetrics) + require.Equal(t, interop.InvokeResponseMode("Buffered"), reset.InvokeResponseMode) <-sendResponseChan require.Equal(t, expectedPayloadString, writer.buffer.String()) @@ -298,6 +378,60 @@ func TestSendStreamingInvokeResponseReset(t *testing.T) { // Reset initiated aft <-testFinished } +// TODO: mock asyncPayloadCopy and force it to return Oversized in copyDone +func TestSendStreamingInvokeResponseOversizedRuntimesWithTrailers(t *testing.T) { + oversizedPayloadString := strings.Repeat("a", int(MaxDirectResponseSize)+1) + payload := NewReader(oversizedPayloadString) + trailers := http.Header{ + FunctionErrorTypeTrailer: []string{"RuntimesErrorType"}, + FunctionErrorBodyTrailer: []string{"RuntimesBody"}, + } + writer := NewSimpleResponseWriter() + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + go func() { + err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, false) + require.Error(t, err) + require.IsType(t, &interop.ErrorResponseTooLargeDI{}, err) + testFinished <- struct{}{} + }() + + <-sendResponseChan + require.Equal(t, trailers.Get(FunctionErrorTypeTrailer), writer.Header().Get(FunctionErrorTypeTrailer)) + require.Equal(t, trailers.Get(FunctionErrorBodyTrailer), writer.Header().Get(FunctionErrorBodyTrailer)) + require.Equal(t, EndOfResponseOversized, writer.Header().Get(EndOfResponseTrailer)) + <-testFinished +} + +// TODO: mock asyncPayloadCopy and force it to return Oversized in copyDone +func TestSendStreamingInvokeResponseOversizedRuntimesWithoutErrorTypeTrailer(t *testing.T) { + oversizedPayloadString := strings.Repeat("a", int(MaxDirectResponseSize)+1) + payload := NewReader(oversizedPayloadString) + trailers := http.Header{ + FunctionErrorTypeTrailer: []string{""}, + FunctionErrorBodyTrailer: []string{"RuntimesErrorBody"}, + } + writer := NewSimpleResponseWriter() + interruptedResponseChan := make(chan *interop.Reset) + sendResponseChan := make(chan *interop.InvokeResponseMetrics) + testFinished := make(chan struct{}) + + go func() { + err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, false) + require.Error(t, err) + require.IsType(t, &interop.ErrorResponseTooLargeDI{}, err) + testFinished <- struct{}{} + }() + + <-sendResponseChan + require.Equal(t, "Function.ResponseSizeTooLarge", writer.Header().Get(FunctionErrorTypeTrailer)) + require.Equal(t, trailers.Get(FunctionErrorBodyTrailer), writer.Header().Get(FunctionErrorBodyTrailer)) + require.Equal(t, EndOfResponseOversized, writer.Header().Get(EndOfResponseTrailer)) + <-testFinished +} + func TestSendStreamingInvokeErrorResponseSuccess(t *testing.T) { payloadString := strings.Repeat("a", 128*1024) // 128 KiB payload := NewReader(payloadString) @@ -356,3 +490,247 @@ func TestSendStreamingInvokeErrorResponseReset(t *testing.T) { // Reset initiate require.Equal(t, "Truncated", writer.Header().Get("End-Of-Response")) <-testFinished } + +func TestIsStreamingInvokeTrue(t *testing.T) { + fallbackFlag := -1 + reponseForFallback := isStreamingInvoke(fallbackFlag, interop.InvokeResponseModeBuffered) + + require.True(t, reponseForFallback) + + nonFallbackFlag := 1 + reponseForResponseMode := isStreamingInvoke(nonFallbackFlag, interop.InvokeResponseModeStreaming) + + require.True(t, reponseForResponseMode) +} + +func TestIsStreamingInvokeFalse(t *testing.T) { + nonFallbackFlag := 1 + response := isStreamingInvoke(nonFallbackFlag, interop.InvokeResponseModeBuffered) + + require.False(t, response) +} + +func TestMapCopyDoneResultErrorToErrorType(t *testing.T) { + require.Equal(t, fatalerror.TruncatedResponse, mapCopyDoneResultErrorToErrorType(&interop.ErrTruncatedResponse{})) + require.Equal(t, fatalerror.FunctionOversizedResponse, mapCopyDoneResultErrorToErrorType(&interop.ErrorResponseTooLargeDI{})) + require.Equal(t, fatalerror.SandboxFailure, mapCopyDoneResultErrorToErrorType(errors.New(""))) +} + +func TestConvertToInvokeResponseMode(t *testing.T) { + response, err := convertToInvokeResponseMode("buffered") + require.Equal(t, interop.InvokeResponseModeBuffered, response) + require.Nil(t, err) + + response, err = convertToInvokeResponseMode("streaming") + require.Equal(t, interop.InvokeResponseModeStreaming, response) + require.Nil(t, err) + + response, err = convertToInvokeResponseMode("foo-bar") + require.Equal(t, interop.InvokeResponseMode(""), response) + require.Equal(t, interop.ErrInvalidInvokeResponseMode, err) +} + +func FuzzReceiveDirectInvoke(f *testing.F) { + testCustHeaders := CustomerHeaders{ + CognitoIdentityID: "id1", + CognitoIdentityPoolID: "id2", + ClientContext: "clientcontext1", + } + custHeadersJSON := testCustHeaders.Dump() + + f.Add([]byte{'a'}, "res-token", "invokeid", "functionarn", "versionid", "contenttype", + custHeadersJSON, "1000", + "Streaming", fmt.Sprint(interop.MinResponseBandwidthRate), fmt.Sprint(interop.MinResponseBandwidthBurstSize)) + f.Add([]byte{'b'}, "res-token", "invokeid", "functionarn", "versionid", "contenttype", + custHeadersJSON, "2000", "Buffered", + "0", "0") + f.Add([]byte{'0'}, "0", "0", "0", "0", "0", + "", "", "0", + "0", "0") + + f.Fuzz(func( + t *testing.T, + payload []byte, + reservationToken string, + invokeID string, + invokedFunctionArn string, + versionID string, + contentType string, + custHeadersStr string, + maxPayloadSizeStr string, + invokeResponseModeStr string, + responseBandwidthRateStr string, + responseBandwidthBurstSizeStr string, + ) { + request := makeDirectInvokeRequest(payload, reservationToken, invokeID, + invokedFunctionArn, versionID, contentType, custHeadersStr, maxPayloadSizeStr, + invokeResponseModeStr, responseBandwidthRateStr, responseBandwidthBurstSizeStr) + + token := createDummyToken() + responseRecorder := httptest.NewRecorder() + + receivedInvoke, err := ReceiveDirectInvoke(responseRecorder, request, token) + + // default values used if header values are empty + responseMode := interop.InvokeResponseModeBuffered + maxDirectResponseSize := interop.MaxPayloadSize + + custHeaders := CustomerHeaders{} + + if err != nil { + if err = custHeaders.Load(custHeadersStr); err != nil { + assertBadRequestErrorType(t, responseRecorder, interop.ErrMalformedCustomerHeaders) + return + } + + if !isValidMaxPayloadSize(maxPayloadSizeStr) { + assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidMaxPayloadSize) + return + } + + n, _ := strconv.ParseInt(maxPayloadSizeStr, 10, 64) + maxDirectResponseSize = int(n) + + if invokeResponseModeStr != "" { + if responseMode, err = convertToInvokeResponseMode(invokeResponseModeStr); err != nil { + assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidInvokeResponseMode) + return + } + } + + if isStreamingInvoke(maxDirectResponseSize, responseMode) { + if !isValidResponseBandwidthRate(responseBandwidthRateStr) { + assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidResponseBandwidthRate) + return + } + + if !isValidResponseBandwidthBurstSize(responseBandwidthBurstSizeStr) { + assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidResponseBandwidthBurstSize) + return + } + } + + } else { + if isStreamingInvoke(maxDirectResponseSize, responseMode) { + // FIXME + // Until WorkerProxy stops sending MaxDirectResponseSize == -1 to identify streaming + // invokes, the ReceiveDirectInvoke() implementation overrides InvokeResponseMode + // to avoid setting InvokeResponseMode to buffered (default) for a streaming invoke (MaxDirectResponseSize == -1). + responseMode = interop.InvokeResponseModeStreaming + + assert.Equal(t, responseRecorder.Header().Values("Trailer"), []string{FunctionErrorTypeTrailer, FunctionErrorBodyTrailer}) + } + + if receivedInvoke.ID != token.InvokeID { + assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidInvokeID) + return + } + + if receivedInvoke.ReservationToken != token.ReservationToken { + assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidReservationToken) + return + } + + if receivedInvoke.VersionID != token.VersionID { + assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidFunctionVersion) + return + } + + if now := metering.Monotime(); now > token.InvackDeadlineNs { + assertBadRequestErrorType(t, responseRecorder, interop.ErrReservationExpired) + return + } + + assert.Equal(t, responseRecorder.Header().Get(VersionIDHeader), token.VersionID) + assert.Equal(t, responseRecorder.Header().Get(ReservationTokenHeader), token.ReservationToken) + assert.Equal(t, responseRecorder.Header().Get(InvokeIDHeader), token.InvokeID) + + expectedInvoke := &interop.Invoke{ + ID: invokeID, + ReservationToken: reservationToken, + InvokedFunctionArn: invokedFunctionArn, + VersionID: versionID, + ContentType: contentType, + CognitoIdentityID: custHeaders.CognitoIdentityID, + CognitoIdentityPoolID: custHeaders.CognitoIdentityPoolID, + TraceID: token.TraceID, + LambdaSegmentID: token.LambdaSegmentID, + ClientContext: custHeaders.ClientContext, + Payload: request.Body, + DeadlineNs: receivedInvoke.DeadlineNs, + NeedDebugLogs: token.NeedDebugLogs, + InvokeReceivedTime: receivedInvoke.InvokeReceivedTime, + InvokeResponseMode: responseMode, + RestoreDurationNs: token.RestoreDurationNs, + RestoreStartTimeMonotime: token.RestoreStartTimeMonotime, + } + + assert.Equal(t, expectedInvoke, receivedInvoke) + } + }) +} + +func createDummyToken() interop.Token { + return interop.Token{ + ReservationToken: "reservation_token", + TraceID: "trace_id", + InvokeID: "invoke_id", + InvackDeadlineNs: math.MaxInt64, + VersionID: "version_id", + } +} + +func assertBadRequestErrorType(t *testing.T, responseRecorder *httptest.ResponseRecorder, expectedErrType error) { + assert.Equal(t, http.StatusBadRequest, responseRecorder.Code) + + assert.Equal(t, expectedErrType.Error(), responseRecorder.Header().Get(ErrorTypeHeader)) + assert.Equal(t, EndOfResponseComplete, responseRecorder.Header().Get(EndOfResponseTrailer)) +} + +func isValidResponseBandwidthBurstSize(sizeStr string) bool { + size, err := strconv.ParseInt(sizeStr, 10, 64) + return err == nil && + interop.MinResponseBandwidthBurstSize <= size && size <= interop.MaxResponseBandwidthBurstSize +} + +func isValidResponseBandwidthRate(rateStr string) bool { + rate, err := strconv.ParseInt(rateStr, 10, 64) + return err == nil && + interop.MinResponseBandwidthRate <= rate && rate <= interop.MaxResponseBandwidthRate +} + +func isValidMaxPayloadSize(maxPayloadSizeStr string) bool { + if maxPayloadSizeStr != "" { + maxPayloadSize, err := strconv.ParseInt(maxPayloadSizeStr, 10, 64) + return err == nil && maxPayloadSize >= -1 + } + + return true +} + +func makeDirectInvokeRequest( + payload []byte, reservationToken string, invokeID string, invokedFunctionArn string, + versionID string, contentType string, custHeadersStr string, maxPayloadSize string, + invokeResponseModeStr string, responseBandwidthRate string, responseBandwidthBurstSize string, +) *http.Request { + request := httptest.NewRequest("POST", "http://example.com/", bytes.NewReader(payload)) + request = addReservationToken(request, reservationToken) + + request.Header.Set(InvokeIDHeader, invokeID) + request.Header.Set(InvokedFunctionArnHeader, invokedFunctionArn) + request.Header.Set(VersionIDHeader, versionID) + request.Header.Set(ContentTypeHeader, contentType) + request.Header.Set(CustomerHeadersHeader, custHeadersStr) + request.Header.Set(MaxPayloadSizeHeader, maxPayloadSize) + request.Header.Set(InvokeResponseModeHeader, invokeResponseModeStr) + request.Header.Set(ResponseBandwidthRateHeader, responseBandwidthRate) + request.Header.Set(ResponseBandwidthBurstSizeHeader, responseBandwidthBurstSize) + + return request +} + +func addReservationToken(r *http.Request, reservationToken string) *http.Request { + rctx := chi.NewRouteContext() + rctx.URLParams.Add("reservationtoken", reservationToken) + return r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) +} diff --git a/lambda/core/flow.go b/lambda/core/flow.go index b2cb538..08d5e4b 100644 --- a/lambda/core/flow.go +++ b/lambda/core/flow.go @@ -3,6 +3,12 @@ package core +import ( + "context" + + "go.amzn.com/lambda/interop" +) + // InitFlowSynchronization wraps init flow barriers. type InitFlowSynchronization interface { SetExternalAgentsRegisterCount(uint16) error @@ -13,6 +19,7 @@ type InitFlowSynchronization interface { RuntimeReady() error AwaitRuntimeReady() error + AwaitRuntimeReadyWithDeadline(context.Context) error AgentReady() error AwaitAgentsReady() error @@ -47,6 +54,26 @@ func (s *initFlowSynchronizationImpl) AwaitRuntimeReady() error { return s.runtimeReadyGate.AwaitGateCondition() } +func (s *initFlowSynchronizationImpl) AwaitRuntimeReadyWithDeadline(ctx context.Context) error { + var err error + errorChan := make(chan error) + + go func() { + errorChan <- s.runtimeReadyGate.AwaitGateCondition() + }() + + select { + case err = <-errorChan: + break + case <-ctx.Done(): + err = interop.ErrRestoreHookTimeout + s.CancelWithError(err) + break + } + + return err +} + // AwaitRuntimeRestoreReady awaits runtime restore ready state (/restore/next is called by runtime) func (s *initFlowSynchronizationImpl) AwaitRuntimeRestoreReady() error { return s.runtimeRestoreReadyGate.AwaitGateCondition() diff --git a/lambda/core/registrations.go b/lambda/core/registrations.go index f68612c..26f6f2f 100644 --- a/lambda/core/registrations.go +++ b/lambda/core/registrations.go @@ -70,10 +70,12 @@ type AgentInfo struct { // FunctionMetadata holds static information regarding the function (Name, Version, Handler) type FunctionMetadata struct { - FunctionName string - FunctionVersion string - Handler string - RuntimeInfo interop.RuntimeInfo + AccountID string + FunctionName string + FunctionVersion string + InstanceMaxMemory uint64 + Handler string + RuntimeInfo interop.RuntimeInfo } // RegistrationService keeps track of registered parties, including external agents, threads, and runtime. diff --git a/lambda/core/runtime_state_names.go b/lambda/core/runtime_state_names.go index b04ba5d..4a2184d 100644 --- a/lambda/core/runtime_state_names.go +++ b/lambda/core/runtime_state_names.go @@ -16,4 +16,5 @@ const ( RuntimeInvocationResponseStateName = "InvocationResponse" RuntimeInvocationErrorResponseStateName = "InvocationErrorResponse" RuntimeResponseSentStateName = "RuntimeResponseSentState" + RuntimeRestoreErrorStateName = "RuntimeRestoreErrorState" ) diff --git a/lambda/core/statejson/description.go b/lambda/core/statejson/description.go index eb46946..a614d20 100644 --- a/lambda/core/statejson/description.go +++ b/lambda/core/statejson/description.go @@ -5,9 +5,24 @@ package statejson import ( "encoding/json" + log "github.com/sirupsen/logrus" ) +// ResponseMode are top-level constants used in combination with the various types of +// modes we have for responses, such as invoke's response mode and function's response mode. +// In the future we might have invoke's request mode or similar, so these help set the ground +// for consistency. +type ResponseMode string + +const ResponseModeBuffered = "Buffered" +const ResponseModeStreaming = "Streaming" + +type InvokeResponseMode string + +const InvokeResponseModeBuffered InvokeResponseMode = ResponseModeBuffered +const InvokeResponseModeStreaming InvokeResponseMode = ResponseModeStreaming + // StateDescription ... type StateDescription struct { Name string `json:"name"` @@ -35,9 +50,24 @@ type InternalStateDescription struct { FirstFatalError string `json:"firstFatalError"` } +type ResponseMetricsDimensions struct { + InvokeResponseMode InvokeResponseMode `json:"invokeResponseMode"` +} + +type ResponseMetrics struct { + RuntimeResponseLatencyMs float64 `json:"runtimeResponseLatencyMs"` + Dimensions ResponseMetricsDimensions `json:"dimensions"` +} + +type ReleaseResponse struct { + *InternalStateDescription + ResponseMetrics ResponseMetrics `json:"responseMetrics"` +} + // ResetDescription describes fields of the response to an INVOKE API request type ResetDescription struct { - ExtensionsResetMs int64 `json:"extensionsResetMs"` + ExtensionsResetMs int64 `json:"extensionsResetMs"` + ResponseMetrics ResponseMetrics `json:"responseMetrics"` } func (s *InternalStateDescription) AsJSON() []byte { @@ -55,3 +85,11 @@ func (s *ResetDescription) AsJSON() []byte { } return bytes } + +func (s *ReleaseResponse) AsJSON() []byte { + bytes, err := json.Marshal(s) + if err != nil { + log.Panicf("Failed to marshall release response: %s", err) + } + return bytes +} diff --git a/lambda/core/states.go b/lambda/core/states.go index a5e2010..0de88ec 100644 --- a/lambda/core/states.go +++ b/lambda/core/states.go @@ -9,6 +9,7 @@ import ( "time" "go.amzn.com/lambda/core/statejson" + "go.amzn.com/lambda/interop" ) // Suspendable on operator condition. @@ -76,6 +77,7 @@ type RuntimeState interface { InvocationResponse() error InvocationErrorResponse() error ResponseSent() error + RestoreError(interop.FunctionError) error Name() string } @@ -87,6 +89,9 @@ func (s *disallowEveryTransitionByDefault) RestoreReady() error { ret func (s *disallowEveryTransitionByDefault) InvocationResponse() error { return ErrNotAllowed } func (s *disallowEveryTransitionByDefault) InvocationErrorResponse() error { return ErrNotAllowed } func (s *disallowEveryTransitionByDefault) ResponseSent() error { return ErrNotAllowed } +func (s *disallowEveryTransitionByDefault) RestoreError(interop.FunctionError) error { + return ErrNotAllowed +} // Runtime is runtime object. type Runtime struct { @@ -105,6 +110,7 @@ type Runtime struct { RuntimeInvocationResponseState RuntimeState RuntimeInvocationErrorResponseState RuntimeState RuntimeResponseSentState RuntimeState + RuntimeRestoreErrorState RuntimeState } // Release ... @@ -176,6 +182,12 @@ func (s *Runtime) ResponseSent() error { return err } +func (s *Runtime) RestoreError(UserError interop.FunctionError) error { + s.ManagedThread.Lock() + defer s.ManagedThread.Unlock() + return s.currentState.RestoreError(UserError) +} + // GetRuntimeDescription returns runtime description object for debugging purposes func (s *Runtime) GetRuntimeDescription() statejson.RuntimeDescription { s.ManagedThread.Lock() @@ -207,6 +219,7 @@ func NewRuntime(initFlow InitFlowSynchronization, invokeFlow InvokeFlowSynchroni runtime.RuntimeResponseSentState = &RuntimeResponseSentState{runtime: runtime, invokeFlow: invokeFlow} runtime.RuntimeRestoreReadyState = &RuntimeRestoreReadyState{} runtime.RuntimeRestoringState = &RuntimeRestoringState{runtime: runtime, initFlow: initFlow} + runtime.RuntimeRestoreErrorState = &RuntimeRestoreErrorState{runtime: runtime, initFlow: initFlow} runtime.setStateUnsafe(runtime.RuntimeStartedState) return runtime @@ -292,9 +305,9 @@ func (s *RuntimeRestoringState) Ready() error { return nil } -// Runtime has thrown an exception when executing restore hooks and called /init/error -func (s *RuntimeRestoringState) InitError() error { - s.runtime.setStateUnsafe(s.runtime.RuntimeInitErrorState) +func (s *RuntimeRestoringState) RestoreError(userError interop.FunctionError) error { + s.runtime.setStateUnsafe(s.runtime.RuntimeRestoreErrorState) + s.initFlow.CancelWithError(interop.ErrRestoreHookUserError{UserError: userError}) return nil } @@ -436,3 +449,13 @@ func (s *RuntimeResponseSentState) Ready() error { func (s *RuntimeResponseSentState) Name() string { return RuntimeResponseSentStateName } + +type RuntimeRestoreErrorState struct { + disallowEveryTransitionByDefault + runtime *Runtime + initFlow InitFlowSynchronization +} + +func (s *RuntimeRestoreErrorState) Name() string { + return RuntimeRestoreErrorStateName +} diff --git a/lambda/core/states_test.go b/lambda/core/states_test.go index 37f38e2..b6d2955 100644 --- a/lambda/core/states_test.go +++ b/lambda/core/states_test.go @@ -4,11 +4,14 @@ package core import ( + "context" + "sync" + "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "go.amzn.com/lambda/interop" "go.amzn.com/lambda/testdata/mockthread" - "sync" - "testing" ) func TestRuntimeInitErrorAfterReady(t *testing.T) { @@ -96,6 +99,34 @@ func TestRuntimeStateTransitionsFromInitErrorState(t *testing.T) { assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) } +func TestRuntimeStateTransitionsFromRestoreErrorState(t *testing.T) { + runtime := newRuntime() + // RestoreError -> InitError + runtime.SetState(runtime.RuntimeRestoreErrorState) + assert.Equal(t, ErrNotAllowed, runtime.InitError()) + assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) + // RestoreError -> Ready + runtime.SetState(runtime.RuntimeRestoreErrorState) + assert.Equal(t, ErrNotAllowed, runtime.Ready()) + assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) + // RestoreError -> RestoreReady + runtime.SetState(runtime.RuntimeRestoreErrorState) + assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) + assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) + // RestoreError -> ResponseSent + runtime.SetState(runtime.RuntimeRestoreErrorState) + assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) + assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) + // RestoreError -> InvocationResponse + runtime.SetState(runtime.RuntimeRestoreErrorState) + assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) + assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) + // RestoreError -> InvocationErrorResponse + runtime.SetState(runtime.RuntimeRestoreErrorState) + assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) + assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) +} + func TestRuntimeStateTransitionsFromReadyState(t *testing.T) { runtime := newRuntime() // Ready -> InitError @@ -266,11 +297,9 @@ func TestRuntimeStateTransitionsFromRestoreReadyState(t *testing.T) { } func TestRuntimeStateTransitionsFromRestoringState(t *testing.T) { - runtime := newRuntime() - // RestoreRunning -> InitError + runtime, mockInitFlow, _ := newRuntimeGetMockFlows() runtime.SetState(runtime.RuntimeRestoringState) - assert.NoError(t, runtime.InitError()) - assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) + mockInitFlow.On("CancelWithError", interop.ErrRestoreHookUserError{UserError: interop.FunctionError{}}).Return() // RestoreRunning -> Ready runtime.SetState(runtime.RuntimeRestoringState) assert.NoError(t, runtime.Ready()) @@ -291,6 +320,10 @@ func TestRuntimeStateTransitionsFromRestoringState(t *testing.T) { runtime.SetState(runtime.RuntimeRestoringState) assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) + // RestoreRunning -> RestoreError + runtime.SetState(runtime.RuntimeRestoringState) + assert.NoError(t, runtime.RestoreError(interop.FunctionError{})) + assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) } func newRuntime() *Runtime { @@ -302,6 +335,15 @@ func newRuntime() *Runtime { return runtime } +func newRuntimeGetMockFlows() (*Runtime, *mockInitFlowSynchronization, *mockInvokeFlowSynchronization) { + initFlow := &mockInitFlowSynchronization{} + invokeFlow := &mockInvokeFlowSynchronization{} + runtime := NewRuntime(initFlow, invokeFlow) + runtime.ManagedThread = &mockthread.MockManagedThread{} + + return runtime, initFlow, invokeFlow +} + type mockInitFlowSynchronization struct { mock.Mock ReadyCond *sync.Cond @@ -325,6 +367,9 @@ func (s *mockInitFlowSynchronization) ExternalAgentRegistered() error { func (s *mockInitFlowSynchronization) AwaitRuntimeReady() error { return nil } +func (s *mockInitFlowSynchronization) AwaitRuntimeReadyWithDeadline(ctx context.Context) error { + return nil +} func (s *mockInitFlowSynchronization) AwaitAgentsReady() error { return nil } diff --git a/lambda/extensions/extensions.go b/lambda/extensions/extensions.go index b55dc51..abe0c87 100644 --- a/lambda/extensions/extensions.go +++ b/lambda/extensions/extensions.go @@ -4,7 +4,14 @@ package extensions import ( + "os" "sync/atomic" + + log "github.com/sirupsen/logrus" +) + +const ( + disableExtensionsFile = "/opt/disable-extensions-jwigqn8j" ) var enabled atomic.Value @@ -27,3 +34,11 @@ func AreEnabled() bool { } return val.(bool) } + +func DisableViaMagicLayer() { + _, err := os.Stat(disableExtensionsFile) + if err == nil { + log.Infof("Extensions disabled by attached layer (%s)", disableExtensionsFile) + Disable() + } +} diff --git a/lambda/fatalerror/fatalerror.go b/lambda/fatalerror/fatalerror.go index bb8a86a..665627d 100644 --- a/lambda/fatalerror/fatalerror.go +++ b/lambda/fatalerror/fatalerror.go @@ -3,23 +3,69 @@ package fatalerror +import ( + "regexp" + "strings" +) + // This package defines constant error types returned to slicer with DONE(failure), and also sandbox errors // Separate package for namespacing // ErrorType is returned to slicer inside DONE type ErrorType string +// TODO: Find another name than "fatalerror" +// TODO: Rename all const so that they always begin with Agent/Runtime/Sandbox/Function +// TODO: Add filtering for extensions as well const ( - AgentInitError ErrorType = "Extension.InitError" // agent exited after calling /extension/init/error - AgentExitError ErrorType = "Extension.ExitError" // agent exited after calling /extension/exit/error - AgentCrash ErrorType = "Extension.Crash" // agent crashed unexpectedly - AgentLaunchError ErrorType = "Extension.LaunchError" // agent could not be launched - RuntimeExit ErrorType = "Runtime.ExitError" - InvalidEntrypoint ErrorType = "Runtime.InvalidEntrypoint" - InvalidWorkingDir ErrorType = "Runtime.InvalidWorkingDir" - InvalidTaskConfig ErrorType = "Runtime.InvalidTaskConfig" - TruncatedResponse ErrorType = "Runtime.TruncatedResponse" - SandboxFailure ErrorType = "Sandbox.Failure" - SandboxTimeout ErrorType = "Sandbox.Timeout" - Unknown ErrorType = "Unknown" + // Extension errors + AgentInitError ErrorType = "Extension.InitError" // agent exited after calling /extension/init/error + AgentExitError ErrorType = "Extension.ExitError" // agent exited after calling /extension/exit/error + AgentCrash ErrorType = "Extension.Crash" // agent crashed unexpectedly + AgentLaunchError ErrorType = "Extension.LaunchError" // agent could not be launched + + // Runtime errors + RuntimeExit ErrorType = "Runtime.ExitError" + InvalidEntrypoint ErrorType = "Runtime.InvalidEntrypoint" + InvalidWorkingDir ErrorType = "Runtime.InvalidWorkingDir" + InvalidTaskConfig ErrorType = "Runtime.InvalidTaskConfig" + TruncatedResponse ErrorType = "Runtime.TruncatedResponse" + RuntimeInvalidResponseModeHeader ErrorType = "Runtime.InvalidResponseModeHeader" + RuntimeUnknown ErrorType = "Runtime.Unknown" + + // Function errors + FunctionOversizedResponse ErrorType = "Function.ResponseSizeTooLarge" + FunctionUnknown ErrorType = "Function.Unknown" + + // Sandbox errors + SandboxFailure ErrorType = "Sandbox.Failure" + SandboxTimeout ErrorType = "Sandbox.Timeout" ) + +var validRuntimeAndFunctionErrors = map[ErrorType]struct{}{ + // Runtime errors + RuntimeExit: {}, + InvalidEntrypoint: {}, + InvalidWorkingDir: {}, + InvalidTaskConfig: {}, + TruncatedResponse: {}, + RuntimeInvalidResponseModeHeader: {}, + RuntimeUnknown: {}, + + // Function errors + FunctionOversizedResponse: {}, + FunctionUnknown: {}, +} + +func GetValidRuntimeOrFunctionErrorType(errorType string) ErrorType { + match, _ := regexp.MatchString("(Runtime|Function)\\.[A-Z][a-zA-Z]+", errorType) + if match { + return ErrorType(errorType) + } + + if strings.HasPrefix(errorType, "Function.") { + return FunctionUnknown + } + + return RuntimeUnknown +} diff --git a/lambda/fatalerror/fatalerror_test.go b/lambda/fatalerror/fatalerror_test.go new file mode 100644 index 0000000..72c34aa --- /dev/null +++ b/lambda/fatalerror/fatalerror_test.go @@ -0,0 +1,51 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package fatalerror + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidRuntimeAndFunctionErrors(t *testing.T) { + type test struct { + input string + expected ErrorType + } + + var tests = []test{} + for validError := range validRuntimeAndFunctionErrors { + tests = append(tests, test{input: string(validError), expected: validError}) + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + assert.Equal(t, GetValidRuntimeOrFunctionErrorType(tt.input), tt.expected) + }) + } +} + +func TestGetValidRuntimeOrFunctionErrorType(t *testing.T) { + type test struct { + input string + expected ErrorType + } + + var tests = []test{ + {"", RuntimeUnknown}, + {"MyCustomError", RuntimeUnknown}, + {"MyCustomError.Error", RuntimeUnknown}, + {"Runtime.MyCustomErrorTypeHere", ErrorType("Runtime.MyCustomErrorTypeHere")}, + {"Function.MyCustomErrorTypeHere", ErrorType("Function.MyCustomErrorTypeHere")}, + } + + for _, tt := range tests { + testname := fmt.Sprintf("TestGetValidRuntimeOrFunctionErrorType with %s", tt.input) + t.Run(testname, func(t *testing.T) { + assert.Equal(t, GetValidRuntimeOrFunctionErrorType(tt.input), tt.expected) + }) + } +} diff --git a/lambda/interop/bootstrap.go b/lambda/interop/bootstrap.go index 4a9b6af..d3f4500 100644 --- a/lambda/interop/bootstrap.go +++ b/lambda/interop/bootstrap.go @@ -7,12 +7,13 @@ import ( "os" "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/rapidcore/env" ) type Bootstrap interface { - Cmd() ([]string, error) // returns the args of bootstrap, where args[0] is the path to executable - Env(e EnvironmentVariables) map[string]string // returns the environment variables to be passed to the bootstrapped process - Cwd() (string, error) // returns the working directory of the bootstrap process - ExtraFiles() []*os.File // returns the extra file descriptors apart from 1 & 2 to be passed to runtime + Cmd() ([]string, error) // returns the args of bootstrap, where args[0] is the path to executable + Env(e *env.Environment) map[string]string // returns the environment variables to be passed to the bootstrapped process + Cwd() (string, error) // returns the working directory of the bootstrap process + ExtraFiles() []*os.File // returns the extra file descriptors apart from 1 & 2 to be passed to runtime CachedFatalError(err error) (fatalerror.ErrorType, string, bool) } diff --git a/lambda/interop/environment_variables.go b/lambda/interop/environment_variables.go deleted file mode 100644 index 46bdf8b..0000000 --- a/lambda/interop/environment_variables.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package interop - -type EnvironmentVariables interface { - AgentExecEnv() map[string]string - RuntimeExecEnv() map[string]string - SetHandler(handler string) - StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress string) - StoreEnvironmentVariablesFromInit(customerEnv map[string]string, - handler, awsKey, awsSecret, awsSession, funcName, funcVer string) - StoreEnvironmentVariablesFromInitForInitCaching(host string, port int, customerEnv map[string]string, handler, funcName, funcVer, token string) -} diff --git a/lambda/interop/events_api.go b/lambda/interop/events_api.go new file mode 100644 index 0000000..a0e9967 --- /dev/null +++ b/lambda/interop/events_api.go @@ -0,0 +1,193 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "fmt" + + "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/rapi/model" +) + +type InitPhase string + +// InitializationType describes possible types of INIT phase +type InitType string + +type InitStartData struct { + InitializationType InitType `json:"initializationType"` + RuntimeVersion string `json:"runtimeVersion"` + RuntimeVersionArn string `json:"runtimeVersionArn"` + FunctionName string `json:"functionName"` + FunctionArn string `json:"functionArn"` + FunctionVersion string `json:"functionVersion"` + InstanceID string `json:"instanceId"` + InstanceMaxMemory uint64 `json:"instanceMaxMemory"` + Phase InitPhase `json:"phase"` + Tracing *TracingCtx `json:"tracing,omitempty"` +} + +func (d *InitStartData) String() string { + return fmt.Sprintf("INIT START(type: %s, phase: %s)", d.InitializationType, d.Phase) +} + +type InitRuntimeDoneData struct { + InitializationType InitType `json:"initializationType"` + Status string `json:"status"` + Phase InitPhase `json:"phase"` + ErrorType *string `json:"errorType,omitempty"` + Tracing *TracingCtx `json:"tracing,omitempty"` +} + +func (d *InitRuntimeDoneData) String() string { + return fmt.Sprintf("INIT RTDONE(status: %s)", d.Status) +} + +type InitReportMetrics struct { + DurationMs float64 `json:"durationMs"` +} + +type InitReportData struct { + InitializationType InitType `json:"initializationType"` + Metrics InitReportMetrics `json:"metrics"` + Phase InitPhase `json:"phase"` + Tracing *TracingCtx `json:"tracing,omitempty"` +} + +func (d *InitReportData) String() string { + return fmt.Sprintf("INIT REPORT(durationMs: %f)", d.Metrics.DurationMs) +} + +type RestoreRuntimeDoneData struct { + Status string `json:"status"` + ErrorType *string `json:"errorType,omitempty"` + Tracing *TracingCtx `json:"tracing,omitempty"` +} + +func (d *RestoreRuntimeDoneData) String() string { + return fmt.Sprintf("RESTORE RTDONE(status: %s)", d.Status) +} + +type TracingCtx struct { + SpanID string `json:"spanId,omitempty"` + Type model.TracingType `json:"type"` + Value string `json:"value"` +} + +type InvokeStartData struct { + RequestID string `json:"requestId"` + Version string `json:"version,omitempty"` + Tracing *TracingCtx `json:"tracing,omitempty"` +} + +func (d *InvokeStartData) String() string { + return fmt.Sprintf("INVOKE START(requestId: %s)", d.RequestID) +} + +type RuntimeDoneInvokeMetrics struct { + ProducedBytes int64 `json:"producedBytes"` + DurationMs float64 `json:"durationMs"` +} + +type Span struct { + Name string `json:"name"` + Start string `json:"start"` + DurationMs float64 `json:"durationMs"` +} + +func (s *Span) String() string { + return fmt.Sprintf("SPAN(name: %s)", s.Name) +} + +type InvokeRuntimeDoneData struct { + RequestID RequestID `json:"requestId"` + Status string `json:"status"` + Metrics *RuntimeDoneInvokeMetrics `json:"metrics,omitempty"` + Tracing *TracingCtx `json:"tracing,omitempty"` + Spans []Span `json:"spans,omitempty"` + ErrorType *string `json:"errorType,omitempty"` + InternalMetrics *InvokeResponseMetrics `json:"-"` +} + +func (d *InvokeRuntimeDoneData) String() string { + return fmt.Sprintf("INVOKE RTDONE(status: %s, produced bytes: %d, duration: %fms)", d.Status, d.Metrics.ProducedBytes, d.Metrics.DurationMs) +} + +type ExtensionInitData struct { + AgentName string `json:"name"` + State string `json:"state"` + Subscriptions []string `json:"events"` + ErrorType string `json:"errorType,omitempty"` +} + +func (d *ExtensionInitData) String() string { + return fmt.Sprintf("EXTENSION INIT(agent name: %s, state: %s, error type: %s)", d.AgentName, d.State, d.ErrorType) +} + +type ReportMetrics struct { + DurationMs float64 `json:"durationMs"` + BilledDurationMs float64 `json:"billedDurationMs"` + MemorySizeMB uint64 `json:"memorySizeMB"` + MaxMemoryUsedMB uint64 `json:"maxMemoryUsedMB"` + InitDurationMs float64 `json:"initDurationMs,omitempty"` +} + +type ReportData struct { + RequestID RequestID `json:"requestId"` + Status string `json:"status"` + Metrics ReportMetrics `json:"metrics"` + Tracing *TracingCtx `json:"tracing,omitempty"` + Spans []Span `json:"spans,omitempty"` + ErrorType *string `json:"errorType,omitempty"` +} + +func (d *ReportData) String() string { + return fmt.Sprintf("REPORT(status: %s, durationMs: %f)", d.Status, d.Metrics.DurationMs) +} + +type EndData struct { + RequestID RequestID `json:"requestId"` +} + +func (d *EndData) String() string { + return "END" +} + +type RequestID string + +type FaultData struct { + RequestID RequestID + ErrorMessage error + ErrorType fatalerror.ErrorType +} + +func (d *FaultData) String() string { + return fmt.Sprintf("RequestId: %s Error: %s\n%s\n", d.RequestID, d.ErrorMessage, d.ErrorType) +} + +type ImageErrorLogData string + +type EventsAPI interface { + SetCurrentRequestID(RequestID) + SendInitStart(InitStartData) error + SendInitRuntimeDone(InitRuntimeDoneData) error + SendInitReport(InitReportData) error + SendRestoreRuntimeDone(RestoreRuntimeDoneData) error + SendInvokeStart(InvokeStartData) error + SendInvokeRuntimeDone(InvokeRuntimeDoneData) error + SendExtensionInit(ExtensionInitData) error + SendReportSpan(Span) error + SendReport(ReportData) error + SendEnd(EndData) error + SendFault(FaultData) error + SendImageErrorLog(ImageErrorLogData) + + FetchTailLogs(string) (string, error) + GetRuntimeDoneSpans( + runtimeStartedTime int64, + invokeResponseMetrics *InvokeResponseMetrics, + runtimeOverheadStartedTime int64, + runtimeReadyTime int64, + ) []Span +} diff --git a/lambda/interop/events_api_test.go b/lambda/interop/events_api_test.go new file mode 100644 index 0000000..d3a7dc1 --- /dev/null +++ b/lambda/interop/events_api_test.go @@ -0,0 +1,656 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.amzn.com/lambda/rapi/model" +) + +const requestID RequestID = "REQUEST_ID" + +func TestJsonMarshalInvokeRuntimeDone(t *testing.T) { + data := InvokeRuntimeDoneData{ + RequestID: requestID, + Status: "success", + Metrics: &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(100), + DurationMs: float64(52.56), + }, + Spans: []Span{ + { + Name: "responseLatency", + Start: "2022-04-11T15:01:28.543Z", + DurationMs: float64(23.02), + }, + { + Name: "responseDuration", + Start: "2022-04-11T15:00:00.000Z", + DurationMs: float64(20), + }, + }, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "spans": [ + { + "name": "responseLatency", + "start": "2022-04-11T15:01:28.543Z", + "durationMs": 23.02 + }, + { + "name": "responseDuration", + "start": "2022-04-11T15:00:00.000Z", + "durationMs": 20 + } + ], + "metrics": { + "producedBytes": 100, + "durationMs": 52.56 + } + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneNoTracing(t *testing.T) { + data := InvokeRuntimeDoneData{ + RequestID: requestID, + Status: "success", + Metrics: &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(100), + DurationMs: float64(52.56), + }, + Spans: []Span{ + { + Name: "responseLatency", + Start: "2022-04-11T15:01:28.543Z", + DurationMs: float64(23.02), + }, + { + Name: "responseDuration", + Start: "2022-04-11T15:00:00.000Z", + DurationMs: float64(20), + }, + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "spans": [ + { + "name": "responseLatency", + "start": "2022-04-11T15:01:28.543Z", + "durationMs": 23.02 + }, + { + "name": "responseDuration", + "start": "2022-04-11T15:00:00.000Z", + "durationMs": 20 + } + ], + "metrics": { + "producedBytes": 100, + "durationMs": 52.56 + } + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneNoMetrics(t *testing.T) { + data := InvokeRuntimeDoneData{ + RequestID: requestID, + Status: "success", + Spans: []Span{ + { + Name: "responseLatency", + Start: "2022-04-11T15:01:28.543Z", + DurationMs: float64(23.02), + }, + { + Name: "responseDuration", + Start: "2022-04-11T15:00:00.000Z", + DurationMs: float64(20), + }, + }, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "spans": [ + { + "name": "responseLatency", + "start": "2022-04-11T15:01:28.543Z", + "durationMs": 23.02 + }, + { + "name": "responseDuration", + "start": "2022-04-11T15:00:00.000Z", + "durationMs": 20 + } + ] + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneWithProducedBytesEqualToZero(t *testing.T) { + data := InvokeRuntimeDoneData{ + RequestID: requestID, + Status: "success", + Metrics: &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(0), + DurationMs: float64(52.56), + }, + Spans: []Span{ + { + Name: "responseLatency", + Start: "2022-04-11T15:01:28.543Z", + DurationMs: float64(23.02), + }, + { + Name: "responseDuration", + Start: "2022-04-11T15:00:00.000Z", + DurationMs: float64(20), + }, + }, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "spans": [ + { + "name": "responseLatency", + "start": "2022-04-11T15:01:28.543Z", + "durationMs": 23.02 + }, + { + "name": "responseDuration", + "start": "2022-04-11T15:00:00.000Z", + "durationMs": 20 + } + ], + "metrics": { + "producedBytes": 0, + "durationMs": 52.56 + } + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneWithNoSpans(t *testing.T) { + data := InvokeRuntimeDoneData{ + RequestID: requestID, + Status: "success", + Metrics: &RuntimeDoneInvokeMetrics{ + ProducedBytes: int64(100), + DurationMs: float64(52.56), + }, + Spans: []Span{}, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "metrics": { + "producedBytes": 100, + "durationMs": 52.56 + } + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneTimeout(t *testing.T) { + data := InvokeRuntimeDoneData{ + RequestID: requestID, + Status: "timeout", + Metrics: &RuntimeDoneInvokeMetrics{ + DurationMs: float64(52.56), + }, + Spans: []Span{}, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "timeout", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "metrics": { + "producedBytes": 0, + "durationMs": 52.56 + } + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneFailure(t *testing.T) { + errorType := "Runtime.ExitError" + data := InvokeRuntimeDoneData{ + RequestID: requestID, + Status: "failure", + ErrorType: &errorType, + Metrics: &RuntimeDoneInvokeMetrics{ + DurationMs: float64(52.56), + }, + Spans: []Span{}, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "failure", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "metrics": { + "producedBytes": 0, + "durationMs": 52.56 + }, + "errorType": "Runtime.ExitError" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInvokeRuntimeDoneWithEmptyErrorType(t *testing.T) { + errorType := "" + data := InvokeRuntimeDoneData{ + RequestID: requestID, + Status: "failure", + ErrorType: &errorType, + Metrics: &RuntimeDoneInvokeMetrics{ + DurationMs: float64(52.56), + }, + Spans: []Span{}, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "failure", + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + }, + "metrics": { + "producedBytes": 0, + "durationMs": 52.56 + }, + "errorType": "" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInitRuntimeDoneSuccess(t *testing.T) { + var errorType *string + data := InitRuntimeDoneData{ + InitializationType: "snap-start", + Phase: "init", + Status: "success", + ErrorType: errorType, + } + + expected := ` + { + "initializationType": "snap-start", + "phase": "init", + "status": "success" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInitRuntimeDoneError(t *testing.T) { + errorType := "Runtime.ExitError" + data := InitRuntimeDoneData{ + InitializationType: "snap-start", + Phase: "init", + Status: "error", + ErrorType: &errorType, + } + + expected := ` + { + "initializationType": "snap-start", + "phase": "init", + "status": "error", + "errorType": "Runtime.ExitError" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalInitRuntimeDoneFailureWithEmptyErrorType(t *testing.T) { + errorType := "" + data := InitRuntimeDoneData{ + InitializationType: "snap-start", + Phase: "init", + Status: "error", + ErrorType: &errorType, + } + + expected := ` + { + "initializationType": "snap-start", + "phase": "init", + "status": "error", + "errorType": "" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalRestoreRuntimeDoneSuccess(t *testing.T) { + var errorType *string + data := RestoreRuntimeDoneData{ + Status: "success", + ErrorType: errorType, + } + + expected := ` + { + "status": "success" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalRestoreRuntimeDoneError(t *testing.T) { + errorType := "Runtime.ExitError" + data := RestoreRuntimeDoneData{ + Status: "error", + ErrorType: &errorType, + } + + expected := ` + { + "status": "error", + "errorType": "Runtime.ExitError" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalRestoreRuntimeDoneErrorWithEmptyErrorType(t *testing.T) { + errorType := "" + data := RestoreRuntimeDoneData{ + Status: "error", + ErrorType: &errorType, + } + + expected := ` + { + "status": "error", + "errorType": "" + } + ` + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalExtensionInit(t *testing.T) { + data := ExtensionInitData{ + AgentName: "agentName", + State: "Registered", + ErrorType: "", + Subscriptions: []string{"INVOKE", "SHUTDOWN"}, + } + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, `{"name":"agentName","state":"Registered","events":["INVOKE","SHUTDOWN"]}`, string(actual)) +} + +func TestJsonMarshalExtensionInitWithError(t *testing.T) { + data := ExtensionInitData{ + AgentName: "agentName", + State: "Registered", + ErrorType: "Extension.FooBar", + Subscriptions: []string{"INVOKE", "SHUTDOWN"}, + } + + actual, err := json.Marshal(data) + assert.NoError(t, err) + assert.JSONEq(t, `{"name":"agentName","state":"Registered","events":["INVOKE","SHUTDOWN"],"errorType":"Extension.FooBar"}`, string(actual)) +} + +func TestJsonMarshalExtensionInitEmptyEvents(t *testing.T) { + data := ExtensionInitData{ + AgentName: "agentName", + State: "Registered", + ErrorType: "Extension.FooBar", + Subscriptions: []string{}, + } + + actual, err := json.Marshal(data) + require.NoError(t, err) + require.JSONEq(t, `{"name":"agentName","state":"Registered","events":[],"errorType":"Extension.FooBar"}`, string(actual)) +} + +func TestJsonMarshalReportWithTracing(t *testing.T) { + errorType := "Runtime.ExitError" + data := ReportData{ + RequestID: requestID, + Status: "error", + ErrorType: &errorType, + Metrics: ReportMetrics{ + DurationMs: float64(52.56), + BilledDurationMs: float64(52.40), + MemorySizeMB: uint64(1024), + MaxMemoryUsedMB: uint64(512), + }, + Tracing: &TracingCtx{ + SpanID: "spanid", + Type: model.XRayTracingType, + Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "error", + "errorType": "Runtime.ExitError", + "metrics": { + "durationMs": 52.56, + "billedDurationMs": 52.40, + "memorySizeMB": 1024, + "maxMemoryUsedMB": 512 + }, + "tracing": { + "spanId": "spanid", + "type": "X-Amzn-Trace-Id", + "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + } + } + ` + + actual, err := json.Marshal(data) + require.NoError(t, err) + require.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalReportWithoutErrorSpansAndTracing(t *testing.T) { + data := ReportData{ + RequestID: requestID, + Status: "timeout", + Metrics: ReportMetrics{ + DurationMs: float64(52.56), + BilledDurationMs: float64(52.40), + MemorySizeMB: uint64(1024), + MaxMemoryUsedMB: uint64(512), + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "timeout", + "metrics": { + "durationMs": 52.56, + "billedDurationMs": 52.40, + "memorySizeMB": 1024, + "maxMemoryUsedMB": 512 + } + } + ` + + actual, err := json.Marshal(data) + require.NoError(t, err) + require.JSONEq(t, expected, string(actual)) +} + +func TestJsonMarshalReportWithInit(t *testing.T) { + data := ReportData{ + RequestID: requestID, + Status: "success", + Metrics: ReportMetrics{ + DurationMs: float64(52.56), + BilledDurationMs: float64(52.40), + MemorySizeMB: uint64(1024), + MaxMemoryUsedMB: uint64(512), + InitDurationMs: float64(3.15), + }, + } + + expected := ` + { + "requestId": "REQUEST_ID", + "status": "success", + "metrics": { + "durationMs": 52.56, + "billedDurationMs": 52.40, + "memorySizeMB": 1024, + "maxMemoryUsedMB": 512, + "initDurationMs": 3.15 + } + } + ` + + actual, err := json.Marshal(data) + require.NoError(t, err) + require.JSONEq(t, expected, string(actual)) +} diff --git a/lambda/interop/messages.go b/lambda/interop/messages.go new file mode 100644 index 0000000..ee1c783 --- /dev/null +++ b/lambda/interop/messages.go @@ -0,0 +1,68 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package interop + +// conversion from internal data structure into well defined messages + +func DoneFromInvokeSuccess(successMsg InvokeSuccess) *Done { + return &Done{ + Meta: DoneMetadata{ + RuntimeRelease: successMsg.RuntimeRelease, + NumActiveExtensions: successMsg.NumActiveExtensions, + ExtensionNames: successMsg.ExtensionNames, + InvokeRequestReadTimeNs: successMsg.InvokeMetrics.InvokeRequestReadTimeNs, + InvokeRequestSizeBytes: successMsg.InvokeMetrics.InvokeRequestSizeBytes, + RuntimeReadyTime: successMsg.InvokeMetrics.RuntimeReadyTime, + + InvokeCompletionTimeNs: successMsg.InvokeCompletionTimeNs, + InvokeReceivedTime: successMsg.InvokeReceivedTime, + RuntimeResponseLatencyMs: successMsg.ResponseMetrics.RuntimeResponseLatencyMs, + RuntimeTimeThrottledMs: successMsg.ResponseMetrics.RuntimeTimeThrottledMs, + RuntimeProducedBytes: successMsg.ResponseMetrics.RuntimeProducedBytes, + RuntimeOutboundThroughputBps: successMsg.ResponseMetrics.RuntimeOutboundThroughputBps, + LogsAPIMetrics: successMsg.LogsAPIMetrics, + MetricsDimensions: DoneMetadataMetricsDimensions{ + InvokeResponseMode: successMsg.InvokeResponseMode, + }, + }, + } +} + +func DoneFailFromInvokeFailure(failureMsg *InvokeFailure) *DoneFail { + return &DoneFail{ + ErrorType: failureMsg.ErrorType, + Meta: DoneMetadata{ + RuntimeRelease: failureMsg.RuntimeRelease, + NumActiveExtensions: failureMsg.NumActiveExtensions, + InvokeReceivedTime: failureMsg.InvokeReceivedTime, + + RuntimeResponseLatencyMs: failureMsg.ResponseMetrics.RuntimeResponseLatencyMs, + RuntimeTimeThrottledMs: failureMsg.ResponseMetrics.RuntimeTimeThrottledMs, + RuntimeProducedBytes: failureMsg.ResponseMetrics.RuntimeProducedBytes, + RuntimeOutboundThroughputBps: failureMsg.ResponseMetrics.RuntimeOutboundThroughputBps, + + InvokeRequestReadTimeNs: failureMsg.InvokeMetrics.InvokeRequestReadTimeNs, + InvokeRequestSizeBytes: failureMsg.InvokeMetrics.InvokeRequestSizeBytes, + RuntimeReadyTime: failureMsg.InvokeMetrics.RuntimeReadyTime, + + ExtensionNames: failureMsg.ExtensionNames, + LogsAPIMetrics: failureMsg.LogsAPIMetrics, + + MetricsDimensions: DoneMetadataMetricsDimensions{ + InvokeResponseMode: failureMsg.InvokeResponseMode, + }, + }, + } +} + +func DoneFailFromInitFailure(initFailure *InitFailure) *DoneFail { + return &DoneFail{ + ErrorType: initFailure.ErrorType, + Meta: DoneMetadata{ + RuntimeRelease: initFailure.RuntimeRelease, + NumActiveExtensions: initFailure.NumActiveExtensions, + LogsAPIMetrics: initFailure.LogsAPIMetrics, + }, + } +} diff --git a/lambda/interop/model.go b/lambda/interop/model.go index cc9c7d0..a4bdbf4 100644 --- a/lambda/interop/model.go +++ b/lambda/interop/model.go @@ -5,9 +5,9 @@ package interop import ( "encoding/json" + "errors" "fmt" "io" - "net/http" "strings" "time" @@ -32,8 +32,6 @@ const ( MaxResponseBandwidthBurstSize = 64 * 1024 * 1024 // 64 MiB ) -const functionResponseSizeTooLargeType = "Function.ResponseSizeTooLarge" - // ResponseMode are top-level constants used in combination with the various types of // modes we have for responses, such as invoke's response mode and function's response mode. // In the future we might have invoke's request mode or similar, so these help set the ground @@ -52,25 +50,6 @@ var AllInvokeResponseModes = []string{ string(InvokeResponseModeBuffered), string(InvokeResponseModeStreaming), } -// ConvertToInvokeResponseMode converts the given string to a InvokeResponseMode -// It is case insensitive and if there is no match, an error is thrown. -func ConvertToInvokeResponseMode(value string) (InvokeResponseMode, error) { - // buffered - if strings.EqualFold(value, string(InvokeResponseModeBuffered)) { - return InvokeResponseModeBuffered, nil - } - - // streaming - if strings.EqualFold(value, string(InvokeResponseModeStreaming)) { - return InvokeResponseModeStreaming, nil - } - - // unknown - allowedValues := strings.Join(AllInvokeResponseModes, ", ") - log.Errorf("Unlable to map %s to %s.", value, allowedValues) - return "", ErrInvalidInvokeResponseMode -} - // FunctionResponseMode is passed by Runtime to tell whether the response should be // streamed or not. type FunctionResponseMode string @@ -82,6 +61,7 @@ var AllFunctionResponseModes = []string{ string(FunctionResponseModeBuffered), string(FunctionResponseModeStreaming), } +// TODO: move to directinvoke.go as we're trying to deprecate interop.* package // ConvertToFunctionResponseMode converts the given string to a FunctionResponseMode // It is case insensitive and if there is no match, an error is thrown. func ConvertToFunctionResponseMode(value string) (FunctionResponseMode, error) { @@ -108,57 +88,79 @@ type Message interface{} type Invoke struct { // Tracing header. // https://docs.aws.amazon.com/xray/latest/devguide/xray-concepts.html#xray-concepts-tracingheader - TraceID string - LambdaSegmentID string - ID string - InvokedFunctionArn string - CognitoIdentityID string - CognitoIdentityPoolID string - DeadlineNs string - ClientContext string - ContentType string - Payload io.Reader - NeedDebugLogs bool - ReservationToken string - VersionID string - InvokeReceivedTime int64 - InvokeResponseMetrics *InvokeResponseMetrics + TraceID string + LambdaSegmentID string + ID string + InvokedFunctionArn string + CognitoIdentityID string + CognitoIdentityPoolID string + DeadlineNs string + ClientContext string + ContentType string + Payload io.Reader + NeedDebugLogs bool + ReservationToken string + VersionID string + InvokeReceivedTime int64 + InvokeResponseMetrics *InvokeResponseMetrics + InvokeResponseMode InvokeResponseMode + RestoreDurationNs int64 // equals 0 for non-snapstart functions + RestoreStartTimeMonotime int64 // equals 0 for non-snapstart functions } type Token struct { - ReservationToken string - InvokeID string - VersionID string - FunctionTimeout time.Duration - InvackDeadlineNs int64 - TraceID string - LambdaSegmentID string - InvokeMetadata string - NeedDebugLogs bool + ReservationToken string + InvokeID string + VersionID string + FunctionTimeout time.Duration + InvackDeadlineNs int64 + TraceID string + LambdaSegmentID string + InvokeMetadata string + NeedDebugLogs bool + RestoreDurationNs int64 + RestoreStartTimeMonotime int64 } -type ErrorResponse struct { - // Payload sent via shared memory. - Payload []byte `json:"Payload,omitempty"` - ContentType string `json:"-"` - FunctionResponseMode string `json:"-"` - - // When error response body (Payload) is not provided, e.g. - // not retrievable, error type and error message will be - // used by the Slicer to construct a response json, e.g: - // - // default error response produced by the Slicer: - // '{"errorMessage":"Unknown application error occurred"}', - // - // when error type is provided, error response becomes: - // '{"errorMessage":"Unknown application error occurred","errorType":"ErrorType"}' - ErrorType string `json:"errorType,omitempty"` - ErrorMessage string `json:"errorMessage,omitempty"` - +// InvokeErrorTraceData is used by the tracer to mark segments as being invocation error +type InvokeErrorTraceData struct { // Attached to invoke segment ErrorCause json.RawMessage `json:"ErrorCause,omitempty"` } +func GetErrorResponseWithFormattedErrorMessage(errorType fatalerror.ErrorType, err error, invokeRequestID string) *ErrorInvokeResponse { + var errorMessage string + if invokeRequestID != "" { + errorMessage = fmt.Sprintf("RequestId: %s Error: %v", invokeRequestID, err) + } else { + errorMessage = fmt.Sprintf("Error: %v", err) + } + + jsonPayload, err := json.Marshal(FunctionError{ + Type: errorType, + Message: errorMessage, + }) + + if err != nil { + return &ErrorInvokeResponse{ + Headers: InvokeResponseHeaders{}, + FunctionError: FunctionError{ + Type: fatalerror.SandboxFailure, + Message: errorMessage, + }, + Payload: []byte{}, + } + } + + headers := InvokeResponseHeaders{} + functionError := FunctionError{ + Type: errorType, + Message: errorMessage, + } + + return &ErrorInvokeResponse{Headers: headers, FunctionError: functionError, Payload: jsonPayload} +} + // SandboxType identifies sandbox type (PreWarmed vs Classic) type SandboxType string @@ -178,7 +180,7 @@ type DynamicDomainConfig struct { // extra hooks to execute at domain start. Currently used for filesystem and network hooks. // It can be empty. AdditionalStartHooks []model.Hook - Mounts []model.DriveMount + Mounts []model.Mount //TODO: other dynamic configurations for the domain go here } @@ -189,14 +191,17 @@ type Reset struct { InvokeResponseMetrics *InvokeResponseMetrics TraceID string LambdaSegmentID string + InvokeResponseMode InvokeResponseMode } // Restore message is sent to rapid to restore runtime to make it ready for consecutive invokes type Restore struct { - AwsKey string - AwsSecret string - AwsSession string - CredentialsExpiry time.Time + AwsKey string + AwsSecret string + AwsSession string + CredentialsExpiry time.Time + RestoreHookTimeoutMs int64 + LogStreamName string } type Resync struct { @@ -224,7 +229,10 @@ func MergeSubscriptionMetrics(logsAPIMetrics TelemetrySubscriptionMetrics, telem // InvokeResponseMetrics are produced while sending streaming invoke response to WP type InvokeResponseMetrics struct { - StartReadingResponseMonoTimeMs int64 + // FIXME: this assumes a value in nanoseconds, let's rename it + // to StartReadingResponseMonoTimeNs + StartReadingResponseMonoTimeMs int64 + // Same as the one above FinishReadingResponseMonoTimeMs int64 TimeShapedNs int64 ProducedBytes int64 @@ -240,6 +248,22 @@ func IsResponseStreamingMetrics(metrics *InvokeResponseMetrics) bool { return metrics.FunctionResponseMode == FunctionResponseModeStreaming } +type DoneMetadataMetricsDimensions struct { + InvokeResponseMode InvokeResponseMode +} + +func (dimensions DoneMetadataMetricsDimensions) String() string { + var stringDimensions []string + + if dimensions.InvokeResponseMode != "" { + dimension := string("invoke_response_mode=" + dimensions.InvokeResponseMode) + stringDimensions = append(stringDimensions, dimension) + } + return strings.ToLower( + strings.Join(stringDimensions, ","), + ) +} + type DoneMetadata struct { NumActiveExtensions int ExtensionsResetMs int64 @@ -252,9 +276,11 @@ type DoneMetadata struct { InvokeCompletionTimeNs int64 InvokeReceivedTime int64 RuntimeReadyTime int64 + RuntimeResponseLatencyMs float64 RuntimeTimeThrottledMs int64 RuntimeProducedBytes int64 RuntimeOutboundThroughputBps int64 + MetricsDimensions DoneMetadataMetricsDimensions } type Done struct { @@ -332,19 +358,18 @@ func (s *ErrorResponseTooLarge) Error() string { return fmt.Sprintf("Response payload size (%d bytes) exceeded maximum allowed payload size (%d bytes).", s.ResponseSize, s.MaxResponseSize) } -// AsErrorResponse generates ErrorResponse from ErrorResponseTooLarge -func (s *ErrorResponseTooLarge) AsInteropError() *ErrorResponse { - resp := ErrorResponse{ - ErrorType: functionResponseSizeTooLargeType, - ErrorMessage: s.Error(), +// AsErrorResponse generates ErrorInvokeResponse from ErrorResponseTooLarge +func (s *ErrorResponseTooLarge) AsErrorResponse() *ErrorInvokeResponse { + functionError := FunctionError{ + Type: fatalerror.FunctionOversizedResponse, + Message: s.Error(), } - respJSON, err := json.Marshal(resp) + jsonPayload, err := json.Marshal(functionError) if err != nil { - panic("Failed to marshal interop.ErrorResponse") + panic("Failed to marshal interop.FunctionError") } - resp.Payload = respJSON - resp.ContentType = "application/json" - return &resp + headers := InvokeResponseHeaders{ContentType: "application/json"} + return &ErrorInvokeResponse{Headers: headers, FunctionError: functionError, Payload: jsonPayload} } // Server used for sending messages and sharing data between the Runtime API handlers and the @@ -356,21 +381,6 @@ func (s *ErrorResponseTooLarge) AsInteropError() *ErrorResponse { // protocol used by the specific implementation // TODO: rename this to InvokeResponseContext, used to send responses from handlers to platform-facing server type Server interface { - // SendResponse sends response. - // Errors returned: - // ErrInvalidInvokeID - validation error indicating that provided invokeID doesn't match current invokeID - // ErrResponseSent - validation error indicating that response with given invokeID was already sent - // Non-nil error - non-nil error indicating transport failure - SendResponse(invokeID string, headers map[string]string, response io.Reader, trailers http.Header, request *CancellableRequest) error - - // SendErrorResponse sends error response. - // Errors returned: - // ErrInvalidInvokeID - validation error indicating that provided invokeID doesn't match current invokeID - // ErrResponseSent - validation error indicating that response with given invokeID was already sent - // Non-nil error - non-nil error indicating transport failure - SendErrorResponse(invokeID string, response *ErrorResponse) error - SendInitErrorResponse(invokeID string, response *ErrorResponse) error - // GetCurrentInvokeID returns current invokeID. // NOTE, in case of INIT, when invokeID is not known in advance (e.g. provisioned concurrency), // returned invokeID will contain empty value. @@ -381,24 +391,40 @@ type Server interface { // from the time when all extensions have called /next. // TODO: this method is a lifecycle event used only for metrics, and doesn't belong here SendRuntimeReady() error + + // SendInitErrorResponse does two separate things when init/error is called: + // a) sends the init error response if called during invoke, and + // b) notifies platform of a user fault if called, during both init or invoke + // TODO: + // separate the two concerns & unify with SendErrorResponse in response sender + SendInitErrorResponse(response *ErrorInvokeResponse) error } type InternalStateGetter func() statejson.InternalStateDescription -const OnDemandInitTelemetrySource string = "on-demand" -const ProvisionedConcurrencyInitTelemetrySource string = "provisioned-concurrency" -const InitCachingInitTelemetrySource string = "snap-start" +// ErrRestoreHookTimeout is returned as a response to `RESTORE` message +// when function's restore hook takes more time to execute thatn +// the timeout value. +var ErrRestoreHookTimeout = errors.New("Runtime.RestoreHookUserTimeout") -func InferTelemetryInitSource(initCachingEnabled bool, sandboxType SandboxType) string { - initSource := OnDemandInitTelemetrySource - - // ToDo: Unify this selection of SandboxType by using the START message - // after having a roadmap on the combination of INIT modes - if initCachingEnabled { - initSource = InitCachingInitTelemetrySource - } else if sandboxType == SandboxPreWarmed { - initSource = ProvisionedConcurrencyInitTelemetrySource - } +// ErrRestoreHookUserError is returned as a response to `RESTORE` message +// when function's restore hook faces with an error on throws an exception. +// UserError contains the error type that the runtime encountered. +type ErrRestoreHookUserError struct { + UserError FunctionError +} - return initSource +func (err ErrRestoreHookUserError) Error() string { + return "errRestoreHookUserError" } + +// ErrRestoreUpdateCredentials is returned as a response to `RESTORE` message +// if RAPID cannot update the credentials served by credentials API +// during the RESTORE phase. +var ErrRestoreUpdateCredentials = errors.New("errRestoreUpdateCredentials") + +var ErrCannotParseCredentialsExpiry = errors.New("errCannotParseCredentialsExpiry") + +var ErrCannotParseRestoreHookTimeoutMs = errors.New("errCannotParseRestoreHookTimeoutMs") + +var ErrMissingRestoreCredentials = errors.New("errMissingRestoreCredentials") diff --git a/lambda/interop/model_test.go b/lambda/interop/model_test.go index 9ad4d17..d9ba36a 100644 --- a/lambda/interop/model_test.go +++ b/lambda/interop/model_test.go @@ -4,8 +4,11 @@ package interop import ( + "fmt" "testing" + "go.amzn.com/lambda/fatalerror" + "github.com/stretchr/testify/assert" ) @@ -25,3 +28,39 @@ func TestMergeSubscriptionMetrics(t *testing.T) { assert.Equal(t, 2, metrics["server_error"]) assert.Equal(t, 2, metrics["client_error"]) } + +func TestGetErrorResponseWithFormattedErrorMessageWithoutInvokeRequestId(t *testing.T) { + errorType := fatalerror.RuntimeExit + errorMessage := fmt.Errorf("Divided by 0") + expectedMsg := fmt.Sprintf(`Error: %s`, errorMessage) + expectedJSON := fmt.Sprintf(`{"errorType": "%s", "errorMessage": "%s"}`, string(errorType), expectedMsg) + + actual := GetErrorResponseWithFormattedErrorMessage(errorType, errorMessage, "") + assert.Equal(t, errorType, actual.FunctionError.Type) + assert.Equal(t, expectedMsg, actual.FunctionError.Message) + assert.JSONEq(t, expectedJSON, string(actual.Payload)) +} + +func TestGetErrorResponseWithFormattedErrorMessageWithInvokeRequestId(t *testing.T) { + errorType := fatalerror.RuntimeExit + errorMessage := fmt.Errorf("Divided by 0") + invokeID := "invoke-id" + expectedMsg := fmt.Sprintf(`RequestId: %s Error: %s`, invokeID, errorMessage) + expectedJSON := fmt.Sprintf(`{"errorType": "%s", "errorMessage": "%s"}`, string(errorType), expectedMsg) + + actual := GetErrorResponseWithFormattedErrorMessage(errorType, errorMessage, invokeID) + assert.Equal(t, errorType, actual.FunctionError.Type) + assert.Equal(t, expectedMsg, actual.FunctionError.Message) + assert.JSONEq(t, expectedJSON, string(actual.Payload)) +} + +func TestDoneMetadataMetricsDimensionsStringWhenInvokeResponseModeIsPresent(t *testing.T) { + dimensions := DoneMetadataMetricsDimensions{ + InvokeResponseMode: InvokeResponseModeStreaming, + } + assert.Equal(t, "invoke_response_mode=streaming", dimensions.String()) +} +func TestDoneMetadataMetricsDimensionsStringWhenEmpty(t *testing.T) { + dimensions := DoneMetadataMetricsDimensions{} + assert.Equal(t, "", dimensions.String()) +} diff --git a/lambda/interop/sandbox_model.go b/lambda/interop/sandbox_model.go index b5d15b0..3011c48 100644 --- a/lambda/interop/sandbox_model.go +++ b/lambda/interop/sandbox_model.go @@ -4,9 +4,13 @@ package interop import ( + "bytes" + "io" + "net/http" "time" "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/rapidcore/env" ) // Init represents an init message @@ -15,6 +19,7 @@ import ( type Init struct { InvokeID string Handler string + AccountID string AwsKey string AwsSecret string AwsSession string @@ -28,23 +33,13 @@ type Init struct { // In standalone mode, these env vars come from test/init but from environment otherwise. CustomerEnvironmentVariables map[string]string SandboxType SandboxType - // there is no dynamic config at the moment for the runtime domain - OperatorDomainExtraConfig DynamicDomainConfig - RuntimeInfo RuntimeInfo - Bootstrap Bootstrap - EnvironmentVariables EnvironmentVariables // contains env vars for agents and runtime procs -} - -// InitStarted contains metadata about the initialized sandbox -// In Rapid Shim, this translates to a RUNNING GirD message to Slicer -// In Rapid Daemon, this is followed by a SANDBOX GirP message to MM -type InitStarted struct { - WaitStartTimeNs int64 - WaitEndTimeNs int64 - PreLoadTimeNs int64 - PostLoadTimeNs int64 - ExtensionsEnabled bool - Ack chan struct{} // used by the sending goroutine to wait until ipc message has been sent + LogStreamName string + InstanceMaxMemory uint64 + OperatorDomainExtraConfig DynamicDomainConfig + RuntimeDomainExtraConfig DynamicDomainConfig + RuntimeInfo RuntimeInfo + Bootstrap Bootstrap + EnvironmentVariables *env.Environment // contains env vars for agents and runtime procs } // InitSuccess indicates that runtime/extensions initialization completed successfully @@ -72,11 +67,61 @@ type InitFailure struct { Ack chan struct{} // used by the sending goroutine to wait until ipc message has been sent } +// ErrorInvokeResponse represents a buffered response received via Runtime API +// for error responses. When body (Payload) is not provided, e.g. +// not retrievable, error type and error message headers will be +// used by the platform to construct a response json, e.g: +// +// default error response produced by the Slicer: +// '{"errorMessage":"Unknown application error occurred"}', +// +// when error type is provided, error response becomes: +// '{"errorMessage":"Unknown application error occurred","errorType":"ErrorType"}' +type ErrorInvokeResponse struct { + Headers InvokeResponseHeaders + Payload []byte + FunctionError FunctionError +} + +// StreamableInvokeResponse represents a response received via Runtime API that can be streamed +type StreamableInvokeResponse struct { + Headers map[string]string + Payload io.Reader + Trailers http.Header + Request *CancellableRequest // streaming request may need to gracefully terminate request streams +} + +// InvokeResponseHeaders contains the headers received via Runtime API /invocation/response +type InvokeResponseHeaders struct { + ContentType string + FunctionResponseMode string +} + +// FunctionError represents information about function errors or 'user errors' +// These are not platform errors and hence are returned as 200 by Lambda +// In the absence of a response payload, the Function Error is serialized and sent +type FunctionError struct { + // Type of error is derived from the Lambda-Runtime-Function-Error-Type set by the Runtime + // This is customer data, so RAPID scrubs this error type to contain only allowlisted values + Type fatalerror.ErrorType `json:"errorType,omitempty"` + // ErrorMessage is generated by RAPID and can never be specified by runtime + Message string `json:"errorMessage,omitempty"` +} + +type InvokeResponseSender interface { + // SendResponse sends invocation response received from Runtime to platform + // This is response may be streamed based on function and invoke response mode + SendResponse(invokeID string, response *StreamableInvokeResponse) error + // SendErrorResponse sends error response in the case of function errors, which are always buffered + SendErrorResponse(invokeID string, response *ErrorInvokeResponse) error +} + // ResponseMetrics groups metrics related to the response stream type ResponseMetrics struct { - RuntimeTimeThrottledMs int64 - RuntimeProducedBytes int64 RuntimeOutboundThroughputBps int64 + RuntimeProducedBytes int64 + RuntimeResponseLatencyMs float64 + RuntimeTimeThrottledMs int64 } // InvokeMetrics groups metrics related to the invoke phase @@ -96,6 +141,7 @@ type InvokeSuccess struct { LogsAPIMetrics TelemetrySubscriptionMetrics ResponseMetrics ResponseMetrics InvokeMetrics InvokeMetrics + InvokeResponseMode InvokeResponseMode } // InvokeFailure is the failure response to invoke phase end @@ -111,21 +157,24 @@ type InvokeFailure struct { ResponseMetrics ResponseMetrics InvokeMetrics InvokeMetrics ExtensionNames string - DefaultErrorResponse *ErrorResponse // error resp constructed by platform during fn errors + DefaultErrorResponse *ErrorInvokeResponse // error resp constructed by platform during fn errors + InvokeResponseMode InvokeResponseMode } // ResetSuccess is the success response to reset request type ResetSuccess struct { - ExtensionsResetMs int64 - ErrorType fatalerror.ErrorType - ResponseMetrics ResponseMetrics + ExtensionsResetMs int64 + ErrorType fatalerror.ErrorType + ResponseMetrics ResponseMetrics + InvokeResponseMode InvokeResponseMode } // ResetFailure is the failure response to reset request type ResetFailure struct { - ExtensionsResetMs int64 - ErrorType fatalerror.ErrorType - ResponseMetrics ResponseMetrics + ExtensionsResetMs int64 + ErrorType fatalerror.ErrorType + ResponseMetrics ResponseMetrics + InvokeResponseMode InvokeResponseMode } // ShutdownSuccess is the response to a shutdown request @@ -136,35 +185,46 @@ type ShutdownSuccess struct { // SandboxInfoFromInit captures data from init request that // is required during invoke (e.g. for suppressed init) type SandboxInfoFromInit struct { - EnvironmentVariables EnvironmentVariables // contains agent env vars (creds, customer, platform) - SandboxType SandboxType // indicating Pre-Warmed, On-Demand etc - RuntimeBootstrap Bootstrap // contains the runtime bootstrap binary path, Cwd, Args, Env, Cmd + EnvironmentVariables *env.Environment // contains agent env vars (creds, customer, platform) + SandboxType SandboxType // indicating Pre-Warmed, On-Demand etc + RuntimeBootstrap Bootstrap // contains the runtime bootstrap binary path, Cwd, Args, Env, Cmd +} + +// RestoreResult represents the result of `HandleRestore` function +// in RapidCore +type RestoreResult struct { + RestoreMs int64 } // RapidContext expose methods for functionality of the Rapid Core library type RapidContext interface { - HandleInit(i *Init, started chan<- InitStarted, success chan<- InitSuccess, failure chan<- InitFailure) - HandleInvoke(i *Invoke, sbMetadata SandboxInfoFromInit) (InvokeSuccess, *InvokeFailure) - HandleReset(reset *Reset, invokeReceivedTime int64, InvokeResponseMetrics *InvokeResponseMetrics) (ResetSuccess, *ResetFailure) + HandleInit(i *Init, success chan<- InitSuccess, failure chan<- InitFailure) + HandleInvoke(i *Invoke, sbMetadata SandboxInfoFromInit, requestBuf *bytes.Buffer, responseSender InvokeResponseSender) (InvokeSuccess, *InvokeFailure) + HandleReset(reset *Reset) (ResetSuccess, *ResetFailure) HandleShutdown(shutdown *Shutdown) ShutdownSuccess - HandleRestore(restore *Restore) error + HandleRestore(restore *Restore) (RestoreResult, error) Clear() + + SetRuntimeStartedTime(runtimeStartedTime int64) + SetInvokeResponseMetrics(metrics *InvokeResponseMetrics) + + SetEventsAPI(eventsAPI EventsAPI) } // SandboxContext represents the sandbox lifecycle context type SandboxContext interface { - Init(i *Init, timeoutMs int64) (InitStarted, InitContext) + Init(i *Init, timeoutMs int64) InitContext Reset(reset *Reset) (ResetSuccess, *ResetFailure) Shutdown(shutdown *Shutdown) ShutdownSuccess - Restore(restore *Restore) error + Restore(restore *Restore) (RestoreResult, error) // TODO: refactor this - // invokeReceivedTime and InvokeResponseMetrics are needed to compute the runtimeDone metrics + // runtimeStartedTime and InvokeResponseMetrics are needed to compute the runtimeDone metrics // in case of a Reset during an invoke (reset.reason=failure or reset.reason=timeout). // Ideally: - // - the InvokeContext will have a Reset method to deal with Reset during an invoke and will hold invokeReceivedTime and InvokeResponseMetrics + // - the InvokeContext will have a Reset method to deal with Reset during an invoke and will hold runtimeStartedTime and InvokeResponseMetrics // - the SandboxContext will have its own Reset/Spindown method - SetInvokeReceivedTime(invokeReceivedTime int64) + SetRuntimeStartedTime(invokeReceivedTime int64) SetInvokeResponseMetrics(metrics *InvokeResponseMetrics) } @@ -176,10 +236,14 @@ type InitContext interface { // InvokeContext represents the lifecycle of a sandbox reservation type InvokeContext interface { - SendRequest(i *Invoke) + SendRequest(i *Invoke, r InvokeResponseSender) Wait() (InvokeSuccess, *InvokeFailure) } -// Restored message is sent to Slicer to inform Runtime Restore Hook execution was successful -type Restored struct { -} +// LifecyclePhase represents enum for possible Sandbox lifecycle phases, like init, invoke, etc. +type LifecyclePhase int + +const ( + LifecyclePhaseInit LifecyclePhase = iota + 1 + LifecyclePhaseInvoke +) diff --git a/lambda/metering/time.go b/lambda/metering/time.go index cf3ad1d..9e0fa01 100644 --- a/lambda/metering/time.go +++ b/lambda/metering/time.go @@ -12,15 +12,19 @@ import ( //go:linkname Monotime runtime.nanotime func Monotime() int64 -// MonoToEpoch converts monotonic time nanos to epoch time nanos. +// MonoToEpoch converts monotonic time nanos to unix epoch time nanos. func MonoToEpoch(t int64) int64 { monoNsec := Monotime() wallNsec := time.Now().UnixNano() - clockOffset := wallNsec - monoNsec return t + clockOffset } +func TimeToMono(t time.Time) int64 { + durNs := time.Since(t).Nanoseconds() + return Monotime() - durNs +} + type ExtensionsResetDurationProfiler struct { NumAgentsRegisteredForShutdown int AvailableNs int64 diff --git a/lambda/metering/time_test.go b/lambda/metering/time_test.go index 0088f9f..5c37a87 100644 --- a/lambda/metering/time_test.go +++ b/lambda/metering/time_test.go @@ -19,6 +19,14 @@ func TestMonoToEpochPrecision(t *testing.T) { assert.True(t, math.Abs(float64(a-b)) < float64(time.Millisecond)) } +func TestEpochToMonoPrecision(t *testing.T) { + a := Monotime() + b := TimeToMono(time.Now()) + + // Conversion error is less than a millisecond. + assert.Less(t, math.Abs(float64(b-a)), float64(1*time.Millisecond)) +} + func TestExtensionsResetDurationProfilerForExtensionsResetWithNoExtensions(t *testing.T) { mono := Monotime() profiler := ExtensionsResetDurationProfiler{} diff --git a/lambda/rapi/extensions_fuzz_test.go b/lambda/rapi/extensions_fuzz_test.go new file mode 100644 index 0000000..c223859 --- /dev/null +++ b/lambda/rapi/extensions_fuzz_test.go @@ -0,0 +1,344 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapi + +import ( + "bytes" + "context" + "encoding/json" + "io" + "log" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/core" + "go.amzn.com/lambda/extensions" + "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/rapi/handler" + "go.amzn.com/lambda/rapi/model" + "go.amzn.com/lambda/rapi/rendering" + "go.amzn.com/lambda/telemetry" + "go.amzn.com/lambda/testdata" +) + +func FuzzAgentRegisterHandler(f *testing.F) { + extensions.Enable() + defer extensions.Disable() + + registerReq := handler.RegisterRequest{ + Events: []core.Event{core.InvokeEvent, core.ShutdownEvent}, + } + regReqBytes, err := json.Marshal(®isterReq) + if err != nil { + f.Errorf("failed to marshal register request: %v", err) + } + f.Add("agent", "accountId", true, regReqBytes) + f.Add("agent", "accountId", false, regReqBytes) + + f.Fuzz(func(t *testing.T, + agentName string, + featuresHeader string, + external bool, + payload []byte, + ) { + flowTest := testdata.NewFlowTest() + + if external { + flowTest.RegistrationService.CreateExternalAgent(agentName) + } + + functionMetadata := createDummyFunctionMetadata() + flowTest.RegistrationService.SetFunctionMetadata(functionMetadata) + + rapiServer := makeRapiServer(flowTest) + + target := makeTargetURL("/extension/register", version20200101) + request := httptest.NewRequest("POST", target, bytes.NewReader(payload)) + request.Header.Add(handler.LambdaAgentName, agentName) + request.Header.Add("Lambda-Extension-Accept-Feature", featuresHeader) + + responseRecorder := serveTestRequest(rapiServer, request) + + if agentName == "" { + assertForbiddenErrorType(t, responseRecorder, "Extension.InvalidExtensionName") + return + } + + regReqStruct := struct { + handler.RegisterRequest + ConfigurationKeys []string `json:"configurationKeys"` + }{} + if err := json.Unmarshal(payload, ®ReqStruct); err != nil { + assertForbiddenErrorType(t, responseRecorder, "InvalidRequestFormat") + return + } + + if containsInvalidEvent(external, regReqStruct.Events) { + assertForbiddenErrorType(t, responseRecorder, "Extension.InvalidEventType") + return + } + + assert.Equal(t, http.StatusOK, responseRecorder.Code) + + respBody, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + + expectedResponse := map[string]interface{}{ + "functionName": functionMetadata.FunctionName, + "functionVersion": functionMetadata.FunctionVersion, + "handler": functionMetadata.Handler, + } + if featuresHeader == "accountId" && functionMetadata.AccountID != "" { + expectedResponse["accountId"] = functionMetadata.AccountID + } + + expectedRespBytes, err := json.Marshal(expectedResponse) + assert.NoError(t, err) + assert.JSONEq(t, string(expectedRespBytes), string(respBody)) + + if external { + agent, found := flowTest.RegistrationService.FindExternalAgentByName(agentName) + assert.True(t, found) + assert.Equal(t, agent.RegisteredState, agent.GetState()) + } else { + agent, found := flowTest.RegistrationService.FindInternalAgentByName(agentName) + assert.True(t, found) + assert.Equal(t, agent.RegisteredState, agent.GetState()) + } + }) +} + +func FuzzAgentNextHandler(f *testing.F) { + extensions.Enable() + defer extensions.Disable() + + regService := core.NewRegistrationService(core.NewInitFlowSynchronization(), core.NewInvokeFlowSynchronization()) + testAgent := makeExternalAgent(regService) + f.Add(testAgent.ID.String(), true, true) + f.Add(testAgent.ID.String(), true, false) + + f.Fuzz(func(t *testing.T, + agentIdentifierHeader string, + registered bool, + isInvokeEvent bool, + ) { + flowTest := testdata.NewFlowTest() + agent := makeExternalAgent(flowTest.RegistrationService) + + if registered { + agent.SetState(agent.RegisteredState) + agent.Release() + } + + configureRendererForEvent(flowTest, isInvokeEvent) + + rapiServer := makeRapiServer(flowTest) + + target := makeTargetURL("/extension/event/next", version20200101) + request := httptest.NewRequest("GET", target, nil) + request.Header.Set(handler.LambdaAgentIdentifier, agentIdentifierHeader) + + responseRecorder := serveTestRequest(rapiServer, request) + + if agentIdentifierHeader == "" { + assertForbiddenErrorType(t, responseRecorder, handler.ErrAgentIdentifierMissing) + return + } + if _, err := uuid.Parse(agentIdentifierHeader); err != nil { + assertForbiddenErrorType(t, responseRecorder, handler.ErrAgentIdentifierInvalid) + return + } + if agentIdentifierHeader != agent.ID.String() { + assertForbiddenErrorType(t, responseRecorder, "Extension.UnknownExtensionIdentifier") + return + } + if !registered { + assertForbiddenErrorType(t, responseRecorder, "Extension.InvalidExtensionState") + return + } + + assert.Equal(t, http.StatusOK, responseRecorder.Code) + + assertResponseEventType(t, isInvokeEvent, responseRecorder) + + assert.Equal(t, agent.RunningState, agent.GetState()) + }) +} + +func FuzzAgentInitErrorHandler(f *testing.F) { + fuzzErrorHandler(f, "/extension/init/error", fatalerror.AgentInitError) +} + +func FuzzAgentExitErrorHandler(f *testing.F) { + fuzzErrorHandler(f, "/extension/exit/error", fatalerror.AgentExitError) +} + +func fuzzErrorHandler(f *testing.F, handlerPath string, fatalErrorType fatalerror.ErrorType) { + extensions.Enable() + defer extensions.Disable() + + regService := core.NewRegistrationService(core.NewInitFlowSynchronization(), core.NewInvokeFlowSynchronization()) + testAgent := makeExternalAgent(regService) + f.Add(true, testAgent.ID.String(), "Extension.SomeError") + f.Add(false, testAgent.ID.String(), "Extension.SomeError") + + f.Fuzz(func(t *testing.T, + agentRegistered bool, + agentIdentifierHeader string, + errorType string, + ) { + flowTest := testdata.NewFlowTest() + + agent := makeExternalAgent(flowTest.RegistrationService) + + if agentRegistered { + agent.SetState(agent.RegisteredState) + } + + rapiServer := makeRapiServer(flowTest) + + target := makeTargetURL(handlerPath, version20200101) + + request := httptest.NewRequest("POST", target, nil) + request = appctx.RequestWithAppCtx(request, flowTest.AppCtx) + request.Header.Set(handler.LambdaAgentIdentifier, agentIdentifierHeader) + request.Header.Set(handler.LambdaAgentFunctionErrorType, errorType) + + responseRecorder := serveTestRequest(rapiServer, request) + + if agentIdentifierHeader == "" { + assertForbiddenErrorType(t, responseRecorder, handler.ErrAgentIdentifierMissing) + return + } + + if _, e := uuid.Parse(agentIdentifierHeader); e != nil { + assertForbiddenErrorType(t, responseRecorder, handler.ErrAgentIdentifierInvalid) + return + } + + if errorType == "" { + assertForbiddenErrorType(t, responseRecorder, "Extension.MissingHeader") + return + } + if agentIdentifierHeader != agent.ID.String() { + assertForbiddenErrorType(t, responseRecorder, "Extension.UnknownExtensionIdentifier") + return + } + if !agentRegistered { + assertForbiddenErrorType(t, responseRecorder, "Extension.InvalidExtensionState") + } else { + assertErrorAgentRegistered(t, responseRecorder, flowTest, fatalErrorType) + } + }) +} + +func assertErrorAgentRegistered(t *testing.T, responseRecorder *httptest.ResponseRecorder, flowTest *testdata.FlowTest, expectedErrType fatalerror.ErrorType) { + var response model.StatusResponse + + respBody, _ := io.ReadAll(responseRecorder.Body) + err := json.Unmarshal(respBody, &response) + assert.NoError(t, err) + + assert.Equal(t, http.StatusAccepted, responseRecorder.Code) + assert.Equal(t, "OK", response.Status) + + v, found := appctx.LoadFirstFatalError(flowTest.AppCtx) + assert.True(t, found) + assert.Equal(t, expectedErrType, v) +} + +func assertForbiddenErrorType(t *testing.T, responseRecorder *httptest.ResponseRecorder, errType string) { + assert.Equal(t, http.StatusForbidden, responseRecorder.Code) + + var errorResponse model.ErrorResponse + + respBody, _ := io.ReadAll(responseRecorder.Body) + err := json.Unmarshal(respBody, &errorResponse) + assert.NoError(t, err) + + assert.Equal(t, errType, errorResponse.ErrorType) +} + +func createDummyFunctionMetadata() core.FunctionMetadata { + return core.FunctionMetadata{ + AccountID: "accID", + FunctionName: "myFunc", + FunctionVersion: "1.0", + Handler: "myHandler", + } +} + +func makeExternalAgent(registrationService core.RegistrationService) *core.ExternalAgent { + agent, err := registrationService.CreateExternalAgent("agent") + if err != nil { + log.Fatalf("failed to create external agent: %v", err) + return nil + } + + return agent +} + +func configureRendererForEvent(flowTest *testdata.FlowTest, isInvokeEvent bool) { + if isInvokeEvent { + invoke := createDummyInvoke() + + var buf bytes.Buffer + flowTest.RenderingService.SetRenderer( + rendering.NewInvokeRenderer( + context.Background(), + invoke, + &buf, + telemetry.NewNoOpTracer().BuildTracingHeader(), + )) + } else { + flowTest.RenderingService.SetRenderer( + &rendering.ShutdownRenderer{ + AgentEvent: model.AgentShutdownEvent{ + AgentEvent: &model.AgentEvent{ + EventType: "SHUTDOWN", + DeadlineMs: int64(10000), + }, + ShutdownReason: "spindown", + }, + }) + } +} + +func assertResponseEventType(t *testing.T, isInvokeEvent bool, responseRecorder *httptest.ResponseRecorder) { + if isInvokeEvent { + var response model.AgentInvokeEvent + + respBody, _ := io.ReadAll(responseRecorder.Body) + err := json.Unmarshal(respBody, &response) + assert.NoError(t, err) + + assert.Equal(t, "INVOKE", response.AgentEvent.EventType) + } else { + var response model.AgentShutdownEvent + + respBody, _ := io.ReadAll(responseRecorder.Body) + err := json.Unmarshal(respBody, &response) + assert.NoError(t, err) + + assert.Equal(t, "SHUTDOWN", response.AgentEvent.EventType) + } +} + +func containsInvalidEvent(external bool, events []core.Event) bool { + for _, e := range events { + if external { + if err := core.ValidateExternalAgentEvent(e); err != nil { + return true + } + } else if err := core.ValidateInternalAgentEvent(e); err != nil { + return true + } + } + + return false +} diff --git a/lambda/rapi/handler/agentnext_test.go b/lambda/rapi/handler/agentnext_test.go index 003c4b6..417633e 100644 --- a/lambda/rapi/handler/agentnext_test.go +++ b/lambda/rapi/handler/agentnext_test.go @@ -4,6 +4,7 @@ package handler import ( + "bytes" "context" "encoding/json" "fmt" @@ -108,7 +109,8 @@ func TestRenderAgentInvokeNextHappy(t *testing.T) { } renderingService := rendering.NewRenderingService() - renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, telemetry.GetCustomerTracingHeader)) + var buf bytes.Buffer + renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, &buf, telemetry.NewNoOpTracer().BuildTracingHeader())) handler := NewAgentNextHandler(registrationService, renderingService) request := httptest.NewRequest("GET", "/", nil) @@ -157,7 +159,8 @@ func TestRenderAgentInternalInvokeNextHappy(t *testing.T) { } renderingService := rendering.NewRenderingService() - renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, telemetry.GetCustomerTracingHeader)) + var buf bytes.Buffer + renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, &buf, telemetry.NewNoOpTracer().BuildTracingHeader())) handler := NewAgentNextHandler(registrationService, renderingService) request := httptest.NewRequest("GET", "/", nil) @@ -287,7 +290,8 @@ func TestRenderAgentInvokeNextHappyEmptyTraceID(t *testing.T) { } renderingService := rendering.NewRenderingService() - renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, telemetry.GetCustomerTracingHeader)) + var buf bytes.Buffer + renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, &buf, telemetry.NewNoOpTracer().BuildTracingHeader())) handler := NewAgentNextHandler(registrationService, renderingService) request := httptest.NewRequest("GET", "/", nil) diff --git a/lambda/rapi/handler/agentregister.go b/lambda/rapi/handler/agentregister.go index 8882965..8da9e4c 100644 --- a/lambda/rapi/handler/agentregister.go +++ b/lambda/rapi/handler/agentregister.go @@ -8,6 +8,7 @@ import ( "errors" "io" "net/http" + "strings" log "github.com/sirupsen/logrus" "go.amzn.com/lambda/core" @@ -24,6 +25,20 @@ type RegisterRequest struct { Events []core.Event `json:"events"` } +const featuresHeader = "Lambda-Extension-Accept-Feature" + +type registrationFeature int + +const ( + accountFeature registrationFeature = iota + 1 +) + +var allowedFeatures = map[string]registrationFeature{ + "accountId": accountFeature, +} + +type responseModifier func(*model.ExtensionRegisterResponse) + func parseRegister(request *http.Request) (*RegisterRequest, error) { body, err := io.ReadAll(request.Body) if err != nil { @@ -53,6 +68,13 @@ func (h *agentRegisterHandler) ServeHTTP(writer http.ResponseWriter, request *ht return } + var responseModifiers []responseModifier + for _, f := range parseRegistrationFeatures(request) { + if f == accountFeature { + responseModifiers = append(responseModifiers, h.respondWithAccountID()) + } + } + registerRequest, err := parseRegister(request) if err != nil { rendering.RenderForbiddenWithTypeMsg(writer, request, errInvalidRequestFormat, err.Error()) @@ -60,32 +82,65 @@ func (h *agentRegisterHandler) ServeHTTP(writer http.ResponseWriter, request *ht } agent, found := h.registrationService.FindExternalAgentByName(agentName) - if found { - h.registerExternalAgent(agent, registerRequest, writer, request) + h.registerExternalAgent(agent, registerRequest, writer, request, responseModifiers...) } else { - h.registerInternalAgent(agentName, registerRequest, writer, request) + h.registerInternalAgent(agentName, registerRequest, writer, request, responseModifiers...) } } -func (h *agentRegisterHandler) renderResponse(agentID string, writer http.ResponseWriter, request *http.Request) { +func (h *agentRegisterHandler) respondWithAccountID() responseModifier { + return func(resp *model.ExtensionRegisterResponse) { + resp.AccountID = h.registrationService.GetFunctionMetadata().AccountID + } +} + +func parseRegistrationFeatures(request *http.Request) []registrationFeature { + rawFeatures := strings.Split(request.Header.Get(featuresHeader), ",") + + var features []registrationFeature + for _, feature := range rawFeatures { + feature = strings.TrimSpace(feature) + if v, found := allowedFeatures[feature]; found { + features = append(features, v) + } + } + + return features +} + +func (h *agentRegisterHandler) renderResponse( + agentID string, + writer http.ResponseWriter, + request *http.Request, + respModifiers ...responseModifier, +) { writer.Header().Set(LambdaAgentIdentifier, agentID) metadata := h.registrationService.GetFunctionMetadata() - resp := &model.ExtensionRegisterResponse{ FunctionVersion: metadata.FunctionVersion, FunctionName: metadata.FunctionName, Handler: metadata.Handler, } + for _, mod := range respModifiers { + mod(resp) + } + if err := rendering.RenderJSON(http.StatusOK, writer, request, resp); err != nil { log.WithError(err).Warn("Error while rendering response") http.Error(writer, err.Error(), http.StatusInternalServerError) } } -func (h *agentRegisterHandler) registerExternalAgent(agent *core.ExternalAgent, registerRequest *RegisterRequest, writer http.ResponseWriter, request *http.Request) { +func (h *agentRegisterHandler) registerExternalAgent( + agent *core.ExternalAgent, + registerRequest *RegisterRequest, + writer http.ResponseWriter, + request *http.Request, + respModifiers ...responseModifier, +) { for _, e := range registerRequest.Events { if err := core.ValidateExternalAgentEvent(e); err != nil { log.Warnf("Failed to register %s: event %s: %s", agent.Name, e, err) @@ -101,11 +156,17 @@ func (h *agentRegisterHandler) registerExternalAgent(agent *core.ExternalAgent, return } - h.renderResponse(agent.ID.String(), writer, request) + h.renderResponse(agent.ID.String(), writer, request, respModifiers...) log.Infof("External agent %s registered, subscribed to %v", agent.String(), registerRequest.Events) } -func (h *agentRegisterHandler) registerInternalAgent(agentName string, registerRequest *RegisterRequest, writer http.ResponseWriter, request *http.Request) { +func (h *agentRegisterHandler) registerInternalAgent( + agentName string, + registerRequest *RegisterRequest, + writer http.ResponseWriter, + request *http.Request, + respModifiers ...responseModifier, +) { for _, e := range registerRequest.Events { if err := core.ValidateInternalAgentEvent(e); err != nil { log.Warnf("Failed to register %s: event %s: %s", agentName, e, err) @@ -142,7 +203,7 @@ func (h *agentRegisterHandler) registerInternalAgent(agentName string, registerR return } - h.renderResponse(agent.ID.String(), writer, request) + h.renderResponse(agent.ID.String(), writer, request, respModifiers...) log.Infof("Internal agent %s registered, subscribed to %v", agent.String(), registerRequest.Events) } diff --git a/lambda/rapi/handler/agentregister_test.go b/lambda/rapi/handler/agentregister_test.go index 35456ee..7370c42 100644 --- a/lambda/rapi/handler/agentregister_test.go +++ b/lambda/rapi/handler/agentregister_test.go @@ -230,102 +230,167 @@ type ExtensionRegisterResponseWithConfig struct { Configuration map[string]string `json:"configuration"` } -var happyPathTests = []struct { - testName string - agentName string - external bool - registrationRequest RegisterRequest - functionMetadata *core.FunctionMetadata - expectedRegistrationResponse ExtensionRegisterResponseWithConfig -}{ - { - testName: "no-config-internal", - agentName: "internal", - external: false, - registrationRequest: RegisterRequest{}, - expectedRegistrationResponse: ExtensionRegisterResponseWithConfig{ - ExtensionRegisterResponse: model.ExtensionRegisterResponse{ - FunctionName: "my-func", - FunctionVersion: "$LATEST", - Handler: "lambda_handler", +func TestRenderAgentResponse(t *testing.T) { + defaultFunctionMetadata := core.FunctionMetadata{ + FunctionVersion: "$LATEST", + FunctionName: "my-func", + Handler: "lambda_handler", + } + + happyPathTests := map[string]struct { + agentName string + external bool + registrationRequest RegisterRequest + featuresHeader string + functionMetadata core.FunctionMetadata + expectedResponse string + }{ + "no-config-internal": { + agentName: "internal", + external: false, + functionMetadata: defaultFunctionMetadata, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "my-func", + "functionVersion": "$LATEST", + "handler": "lambda_handler" + }`, + }, + "no-config-external": { + agentName: "external", + external: true, + functionMetadata: defaultFunctionMetadata, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "my-func", + "functionVersion": "$LATEST", + "handler": "lambda_handler" + }`, + }, + "function-md-override": { + agentName: "external", + external: true, + functionMetadata: core.FunctionMetadata{FunctionName: "function-name", FunctionVersion: "1", Handler: "myHandler"}, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "function-name", + "functionVersion": "1", + "handler": "myHandler" + }`, + }, + "internal with account id feature": { + agentName: "internal", + external: false, + functionMetadata: core.FunctionMetadata{ + FunctionName: "function-name", + FunctionVersion: "1", + Handler: "myHandler", + AccountID: "0123", }, + featuresHeader: "accountId", + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "function-name", + "functionVersion": "1", + "handler": "myHandler", + "accountId": "0123" + }`, }, - }, - { - testName: "no-config-external", - agentName: "external", - external: true, - registrationRequest: RegisterRequest{}, - expectedRegistrationResponse: ExtensionRegisterResponseWithConfig{ - ExtensionRegisterResponse: model.ExtensionRegisterResponse{ - FunctionName: "my-func", - FunctionVersion: "$LATEST", - Handler: "lambda_handler", + "external with account id feature": { + agentName: "external", + external: true, + functionMetadata: core.FunctionMetadata{ + FunctionName: "function-name", + FunctionVersion: "1", + Handler: "myHandler", + AccountID: "0123", }, + featuresHeader: "accountId", + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "function-name", + "functionVersion": "1", + "handler": "myHandler", + "accountId": "0123" + }`, + }, + "with non-existing accept feature": { + agentName: "external", + external: true, + featuresHeader: "some_non_existing_feature,", + functionMetadata: defaultFunctionMetadata, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "my-func", + "functionVersion": "$LATEST", + "handler": "lambda_handler" + }`, }, - }, - { - testName: "function-md-override", - agentName: "external", - external: true, - functionMetadata: &core.FunctionMetadata{FunctionName: "function-name", FunctionVersion: "1", Handler: "myHandler"}, - registrationRequest: RegisterRequest{}, - expectedRegistrationResponse: ExtensionRegisterResponseWithConfig{ - ExtensionRegisterResponse: model.ExtensionRegisterResponse{ + "account id feature and some non-existing feature": { + agentName: "external", + external: true, + featuresHeader: "some_non_existing_feature,accountId,", + functionMetadata: core.FunctionMetadata{ FunctionName: "function-name", FunctionVersion: "1", Handler: "myHandler", + AccountID: "0123", }, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "function-name", + "functionVersion": "1", + "handler": "myHandler", + "accountId": "0123" + }`, + }, + "with empty account id data": { + agentName: "external", + external: true, + featuresHeader: "accountId", + functionMetadata: defaultFunctionMetadata, + registrationRequest: RegisterRequest{}, + expectedResponse: `{ + "functionName": "my-func", + "functionVersion": "$LATEST", + "handler": "lambda_handler" + }`, }, - }, -} - -func TestRenderAgentResponse(t *testing.T) { - defaultFunctionMetadata := core.FunctionMetadata{ - FunctionVersion: "$LATEST", - FunctionName: "my-func", - Handler: "lambda_handler", } - for _, tt := range happyPathTests { - t.Run(tt.testName, func(t *testing.T) { + for name, tt := range happyPathTests { + t.Run(name, func(t *testing.T) { registrationService := core.NewRegistrationService( core.NewInitFlowSynchronization(), core.NewInvokeFlowSynchronization(), ) registrationService.CreateExternalAgent("external") // external agent has to be pre-registered - if tt.functionMetadata != nil { - registrationService.SetFunctionMetadata(*tt.functionMetadata) - } else { - registrationService.SetFunctionMetadata(defaultFunctionMetadata) - } + registrationService.SetFunctionMetadata(tt.functionMetadata) handler := NewAgentRegisterHandler(registrationService) request := httptest.NewRequest("POST", "/extension/register", registerRequestReader(tt.registrationRequest)) request.Header.Add(LambdaAgentName, tt.agentName) + if tt.featuresHeader != "" { + request.Header.Add(featuresHeader, tt.featuresHeader) + } responseRecorder := httptest.NewRecorder() handler.ServeHTTP(responseRecorder, request) - require.Equal(t, http.StatusOK, responseRecorder.Code) - - registerResponse := ExtensionRegisterResponseWithConfig{} - respBody, _ := io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, ®isterResponse) - assert.Equal(t, tt.expectedRegistrationResponse.FunctionName, registerResponse.FunctionName) - assert.Equal(t, tt.expectedRegistrationResponse.FunctionVersion, registerResponse.FunctionVersion) - assert.Equal(t, tt.expectedRegistrationResponse.Handler, registerResponse.Handler) + assert.Equal(t, http.StatusOK, responseRecorder.Code) - require.Len(t, registerResponse.Configuration, 0) + respBody, err := io.ReadAll(responseRecorder.Body) + require.NoError(t, err) + assert.JSONEq(t, tt.expectedResponse, string(respBody)) if tt.external { agent, found := registrationService.FindExternalAgentByName(tt.agentName) - require.True(t, found) - require.Equal(t, agent.RegisteredState, agent.GetState()) + assert.True(t, found) + assert.Equal(t, agent.RegisteredState, agent.GetState()) } else { agent, found := registrationService.FindInternalAgentByName(tt.agentName) - require.True(t, found) - require.Equal(t, agent.RegisteredState, agent.GetState()) + assert.True(t, found) + assert.Equal(t, agent.RegisteredState, agent.GetState()) } }) } diff --git a/lambda/rapi/handler/initerror.go b/lambda/rapi/handler/initerror.go index d28e2d4..79daa1f 100644 --- a/lambda/rapi/handler/initerror.go +++ b/lambda/rapi/handler/initerror.go @@ -9,8 +9,8 @@ import ( "net/http" "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/telemetry" "go.amzn.com/lambda/core" "go.amzn.com/lambda/rapi/rendering" @@ -20,21 +20,40 @@ import ( type initErrorHandler struct { registrationService core.RegistrationService - eventsAPI telemetry.EventsAPI } func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { appCtx := appctx.FromRequest(request) - - server := appctx.LoadInteropServer(appCtx) - if server == nil { + interopServer := appctx.LoadInteropServer(appCtx) + if interopServer == nil { log.Panic("Invalid state, cannot access interop server") } + errorType := fatalerror.GetValidRuntimeOrFunctionErrorType(request.Header.Get("Lambda-Runtime-Function-Error-Type")) + fnError := interop.FunctionError{Type: errorType} + errorBody, err := io.ReadAll(request.Body) + if err != nil { + log.WithError(err).Warn("Failed to read error body") + } + headers := interop.InvokeResponseHeaders{ContentType: determineJSONContentType(errorBody)} + response := &interop.ErrorInvokeResponse{Headers: headers, FunctionError: fnError, Payload: errorBody} + runtime := h.registrationService.GetRuntime() - // the previousStateName is needed to define if the init/error is called for INIT or RESTORE - previousStateName := runtime.GetState().Name() + // remove once Languages team change the endpoint to /restore/error + // when an exception is throw while executing the restore hooks + if runtime.GetState() == runtime.RuntimeRestoringState { + if err := runtime.RestoreError(fnError); err != nil { + log.Warn(err) + rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, + runtime.GetState().Name(), core.RuntimeRestoreErrorStateName, err) + return + } + + appctx.StoreInvokeErrorTraceData(appCtx, &interop.InvokeErrorTraceData{}) + rendering.RenderAccepted(writer, request) + return + } if err := runtime.InitError(); err != nil { log.Warn(err) @@ -43,42 +62,19 @@ func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.R return } - errorType := request.Header.Get("Lambda-Runtime-Function-Error-Type") - - errorBody, err := io.ReadAll(request.Body) - if err != nil { - log.WithError(err).Warn("Failed to read error body") - } - - if previousStateName == core.RuntimeRestoringStateName { - h.sendRestoreRuntimeDoneLogEvent() - } else { - h.sendInitRuntimeDoneLogEvent(appCtx) - } - - response := &interop.ErrorResponse{ - ErrorType: errorType, - Payload: errorBody, - ContentType: determineJSONContentType(errorBody), - } - - if err := server.SendInitErrorResponse(server.GetCurrentInvokeID(), response); err != nil { + if err := interopServer.SendInitErrorResponse(response); err != nil { rendering.RenderInteropError(writer, request, err) return } - appctx.StoreErrorResponse(appCtx, response) - + appctx.StoreInvokeErrorTraceData(appCtx, &interop.InvokeErrorTraceData{}) rendering.RenderAccepted(writer, request) } // NewInitErrorHandler returns a new instance of http handler // for serving /runtime/init/error. -func NewInitErrorHandler(registrationService core.RegistrationService, eventsAPI telemetry.EventsAPI) http.Handler { - return &initErrorHandler{ - registrationService: registrationService, - eventsAPI: eventsAPI, - } +func NewInitErrorHandler(registrationService core.RegistrationService) http.Handler { + return &initErrorHandler{registrationService: registrationService} } func determineJSONContentType(body []byte) string { @@ -87,24 +83,3 @@ func determineJSONContentType(body []byte) string { } return "application/octet-stream" } - -func (h *initErrorHandler) sendInitRuntimeDoneLogEvent(appCtx appctx.ApplicationContext) { - // ToDo: Convert this to an enum for the whole package to increase readability. - initCachingEnabled := appctx.LoadInitType(appCtx) == appctx.InitCaching - - initSource := interop.InferTelemetryInitSource(initCachingEnabled, appctx.LoadSandboxType(appCtx)) - runtimeDoneData := &telemetry.InitRuntimeDoneData{ - InitSource: initSource, - Status: telemetry.RuntimeDoneFailure, - } - - if err := h.eventsAPI.SendInitRuntimeDone(runtimeDoneData); err != nil { - log.Errorf("Failed to send INITRD: %s", err) - } -} - -func (h *initErrorHandler) sendRestoreRuntimeDoneLogEvent() { - if err := h.eventsAPI.SendRestoreRuntimeDone(telemetry.RuntimeDoneFailure); err != nil { - log.Errorf("Failed to send RESTRD: %s", err) - } -} diff --git a/lambda/rapi/handler/initerror_test.go b/lambda/rapi/handler/initerror_test.go index c9a5a83..a9c4b94 100644 --- a/lambda/rapi/handler/initerror_test.go +++ b/lambda/rapi/handler/initerror_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" "go.amzn.com/lambda/appctx" - + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/testdata" ) @@ -27,7 +27,7 @@ func runTestInitErrorHandler(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - handler := NewInitErrorHandler(flowTest.RegistrationService, flowTest.EventsAPI) + handler := NewInitErrorHandler(flowTest.RegistrationService) responseRecorder := httptest.NewRecorder() appCtx := flowTest.AppCtx @@ -60,12 +60,12 @@ func runTestInitErrorHandler(t *testing.T) { // payload is not provided. This fallback is not part // of the RAPID API spec and is not available to // customers. - require.Equal(t, "", errorResponse.ErrorMessage) + require.Equal(t, "", errorResponse.FunctionError.Message) // Slicer falls back to using ErrorType when error // payload is not provided. Customers can set error // type via header to use this fallback. - require.Equal(t, errorType, errorResponse.ErrorType) + require.Equal(t, fatalerror.RuntimeUnknown, errorResponse.FunctionError.Type) // Payload is arbitrary data that customers submit - it's error response body. require.Equal(t, errorBody, errorResponse.Payload) diff --git a/lambda/rapi/handler/invocationerror.go b/lambda/rapi/handler/invocationerror.go index 170c0cb..d434461 100644 --- a/lambda/rapi/handler/invocationerror.go +++ b/lambda/rapi/handler/invocationerror.go @@ -9,6 +9,7 @@ import ( "io" "net/http" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapi/model" @@ -37,7 +38,7 @@ type invocationErrorHandler struct { func (h *invocationErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { appCtx := appctx.FromRequest(request) - server := appctx.LoadInteropServer(appCtx) + server := appctx.LoadResponseSender(appCtx) if server == nil { log.Panic("Invalid state, cannot access interop server") } @@ -50,7 +51,7 @@ func (h *invocationErrorHandler) ServeHTTP(writer http.ResponseWriter, request * return } - errorType := h.getErrorType(request.Header) + errorType := fatalerror.GetValidRuntimeOrFunctionErrorType(h.getErrorType(request.Header)) var errorCause json.RawMessage var errorBody []byte @@ -75,20 +76,23 @@ func (h *invocationErrorHandler) ServeHTTP(writer http.ResponseWriter, request * log.WithError(err).Warn("Failed to parse error body") } - response := &interop.ErrorResponse{ - ErrorType: errorType, - Payload: errorBody, - ErrorCause: errorCause, + headers := interop.InvokeResponseHeaders{ ContentType: contentType, FunctionResponseMode: functionResponseMode, } + response := &interop.ErrorInvokeResponse{ + Headers: headers, + FunctionError: interop.FunctionError{Type: errorType}, + Payload: errorBody, + } + if err := server.SendErrorResponse(chi.URLParam(request, "awsrequestid"), response); err != nil { rendering.RenderInteropError(writer, request, err) return } - appctx.StoreErrorResponse(appCtx, response) + appctx.StoreInvokeErrorTraceData(appCtx, &interop.InvokeErrorTraceData{ErrorCause: errorCause}) if err := runtime.ResponseSent(); err != nil { log.Panic(err) diff --git a/lambda/rapi/handler/invocationerror_test.go b/lambda/rapi/handler/invocationerror_test.go index 2f177fe..72e6719 100644 --- a/lambda/rapi/handler/invocationerror_test.go +++ b/lambda/rapi/handler/invocationerror_test.go @@ -14,6 +14,7 @@ import ( "testing" "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapi/model" "go.amzn.com/lambda/testdata" @@ -87,12 +88,12 @@ func runTestInvocationErrorHandler(t *testing.T) { // payload is not provided. This fallback is not part // of the RAPID API spec and is not available to // customers. - assert.Equal(t, "", errorResponse.ErrorMessage) + assert.Equal(t, "", errorResponse.FunctionError.Message) // Slicer falls back to using ErrorType when error // payload is not provided. Customers can set error // type header to use this fallback. - assert.Equal(t, errorType, errorResponse.ErrorType) + assert.Equal(t, fatalerror.RuntimeUnknown, errorResponse.FunctionError.Type) // Payload is arbitrary data that customers submit - it's error response body. assert.Equal(t, errorBody, errorResponse.Payload) @@ -176,10 +177,10 @@ func TestInvocationErrorHandlerSendsErrorCauseToXRayForContentTypeErrorCause(t * handler.ServeHTTP(responseRecorder, appctx.RequestWithAppCtx(request, appCtx)) // Assert error response contains error cause - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) - assert.JSONEq(t, string(errorCause), string(errorResponse.ErrorCause)) + assert.JSONEq(t, string(errorCause), string(invokeErrorTraceData.ErrorCause)) } func TestInvocationErrorHandlerSendsNullErrorCauseWhenErrorCauseFormatIsInvalidOrEmptyForContentTypeErrorCause(t *testing.T) { @@ -213,10 +214,10 @@ func TestInvocationErrorHandlerSendsNullErrorCauseWhenErrorCauseFormatIsInvalidO // Run NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) - assert.Equal(t, json.RawMessage(nil), errorResponse.ErrorCause) + assert.Equal(t, json.RawMessage(nil), invokeErrorTraceData.ErrorCause) } } @@ -248,11 +249,11 @@ func TestInvocationErrorHandlerSendsCompactedErrorCauseWhenErrorCauseIsTooLargeF // Run NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) - errorCauseJSON, err := model.ValidatedErrorCauseJSON(errorResponse.ErrorCause) + errorCauseJSON, err := model.ValidatedErrorCauseJSON(invokeErrorTraceData.ErrorCause) assert.NoError(t, err, "expected cause sent x-ray to be valid") assert.True(t, len(errorCauseJSON) < model.MaxErrorCauseSizeBytes, "expected cause to be compacted to size") } @@ -277,12 +278,13 @@ func TestInvocationResponsePayloadIsDefaultErrorMessageWhenRequestParsingFailsFo // Run NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) assert.Equal(t, "application/octet-stream", flowTest.InteropServer.ResponseContentType) assert.Equal(t, "function-response-mode", flowTest.InteropServer.FunctionResponseMode) + errorResponse := flowTest.InteropServer.ErrorResponse invokeResponsePayload := errorResponse.Payload expectedResponse, _ := json.Marshal(invalidErrorBodyMessage) @@ -311,10 +313,10 @@ func TestInvocationErrorHandlerSendsErrorCauseToXRayWhenXRayErrorCauseHeaderIsSe // Run NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) - assert.JSONEq(t, string(errorCause), string(errorResponse.ErrorCause)) + assert.JSONEq(t, string(errorCause), string(invokeErrorTraceData.ErrorCause)) } func TestInvocationErrorHandlerSendsNilCauseToXRayWhenXRayErrorCauseHeaderContainsInvalidCause(t *testing.T) { @@ -340,10 +342,10 @@ func TestInvocationErrorHandlerSendsNilCauseToXRayWhenXRayErrorCauseHeaderContai // Run NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) - assert.Equal(t, json.RawMessage(nil), errorResponse.ErrorCause) + assert.Equal(t, json.RawMessage(nil), invokeErrorTraceData.ErrorCause) } } @@ -366,11 +368,11 @@ func TestInvocationErrorHandlerSendsCompactedErrorCauseToXRayWhenXRayErrorCauseI // Run NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) - errorCauseJSON, err := model.ValidatedErrorCauseJSON(errorResponse.ErrorCause) + errorCauseJSON, err := model.ValidatedErrorCauseJSON(invokeErrorTraceData.ErrorCause) assert.NoError(t, err, "expected cause sent x-ray to be valid") assert.True(t, len(errorCauseJSON) < model.MaxErrorCauseSizeBytes, "expected cause to be compacted to size") } @@ -391,10 +393,10 @@ func TestInvocationErrorHandlerSendsNilToXRayWhenXRayErrorCauseHeaderIsNotSet(t // Run NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) - assert.Nil(t, errorResponse.ErrorCause) + assert.Nil(t, invokeErrorTraceData.ErrorCause) } func TestInvocationErrorHandlerSendsErrorCauseToXRayWhenXRayErrorCauseContainsUTF8Characters(t *testing.T) { @@ -416,8 +418,8 @@ func TestInvocationErrorHandlerSendsErrorCauseToXRayWhenXRayErrorCauseContainsUT // Run NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) + invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) + assert.NotNil(t, invokeErrorTraceData) assert.Nil(t, flowTest.InteropServer.Response) - assert.JSONEq(t, string(errorCause), string(errorResponse.ErrorCause)) + assert.JSONEq(t, string(errorCause), string(invokeErrorTraceData.ErrorCause)) } diff --git a/lambda/rapi/handler/invocationnext_test.go b/lambda/rapi/handler/invocationnext_test.go index 5bddb86..64ae057 100644 --- a/lambda/rapi/handler/invocationnext_test.go +++ b/lambda/rapi/handler/invocationnext_test.go @@ -4,6 +4,7 @@ package handler import ( + "bytes" "context" "errors" "fmt" @@ -19,6 +20,8 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" @@ -45,57 +48,65 @@ func TestRenderInvokeEmptyHeaders(t *testing.T) { assert.Equal(t, http.StatusOK, responseRecorder.Code) } -func TestRenderInvoke(t *testing.T) { +func TestRenderInvokeHappy(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - handler := NewInvocationNextHandler(flowTest.RegistrationService, flowTest.RenderingService) - responseRecorder := httptest.NewRecorder() appCtx := flowTest.AppCtx deadlineNs := 12345 - invokePayload := "Payload" invoke := &interop.Invoke{ TraceID: "Root=RootID;Parent=LambdaFrontend;Sampled=1", - ID: "ID", + ID: "", // updated in loop InvokedFunctionArn: "InvokedFunctionArn", CognitoIdentityID: "CognitoIdentityId1", CognitoIdentityPoolID: "CognitoIdentityPoolId1", ClientContext: "ClientContext", DeadlineNs: strconv.Itoa(deadlineNs), ContentType: "image/png", - Payload: strings.NewReader(invokePayload), + Payload: strings.NewReader(""), // updated in loop } ctx := telemetry.NewTraceContext(context.Background(), "RootID", "InvocationSubegmentID") - flowTest.ConfigureForInvoke(ctx, invoke) - - request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) - handler.ServeHTTP(responseRecorder, request) + var requestBuffer bytes.Buffer + for i := 0; i < 6; i++ { + handler := NewInvocationNextHandler(flowTest.RegistrationService, flowTest.RenderingService) + responseRecorder := httptest.NewRecorder() + invoke.ID = fmt.Sprintf("ID-%d", i) + invokePayload := string(bytes.Repeat([]byte("a"), (i%3)*128*1024)) // vary payload size up and down across invokes + invoke.Payload = strings.NewReader(invokePayload) + + flowTest.ConfigureForInvoke(ctx, invoke) + flowTest.ConfigureInvokeRenderer(ctx, invoke, &requestBuffer) // reuse request buffer on each invoke + request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) + handler.ServeHTTP(responseRecorder, request) - headers := responseRecorder.Header() - assert.Equal(t, invoke.InvokedFunctionArn, headers.Get("Lambda-Runtime-Invoked-Function-Arn")) - assert.Equal(t, invoke.ID, headers.Get("Lambda-Runtime-Aws-Request-Id")) - assert.Equal(t, invoke.ClientContext, headers.Get("Lambda-Runtime-Client-Context")) - expectedCognitoIdentityHeader := fmt.Sprintf("{\"cognitoIdentityId\":\"%s\",\"cognitoIdentityPoolId\":\"%s\"}", invoke.CognitoIdentityID, invoke.CognitoIdentityPoolID) - assert.JSONEq(t, expectedCognitoIdentityHeader, headers.Get("Lambda-Runtime-Cognito-Identity")) - assert.Equal(t, "Root=RootID;Parent=InvocationSubegmentID;Sampled=1", headers.Get("Lambda-Runtime-Trace-Id")) - - // Assert deadline precision. E.g. 1999 ns and 2001 ns having diff of 2 ns - // would result in 1ms and 2ms deadline correspondingly. - expectedDeadline := metering.MonoToEpoch(int64(deadlineNs)) / int64(time.Millisecond) - receivedDeadline, _ := strconv.ParseInt(headers.Get("Lambda-Runtime-Deadline-Ms"), 10, 64) - assert.True(t, math.Abs(float64(expectedDeadline-receivedDeadline)) <= float64(1), - fmt.Sprintf("Expected: %v, received: %v", expectedDeadline, receivedDeadline)) - - assert.Equal(t, "image/png", headers.Get("Content-Type")) - assert.Len(t, headers, 7) - assert.Equal(t, invokePayload, responseRecorder.Body.String()) + headers := responseRecorder.Header() + assert.Equal(t, invoke.InvokedFunctionArn, headers.Get("Lambda-Runtime-Invoked-Function-Arn")) + assert.Equal(t, invoke.ID, headers.Get("Lambda-Runtime-Aws-Request-Id")) + assert.Equal(t, invoke.ClientContext, headers.Get("Lambda-Runtime-Client-Context")) + expectedCognitoIdentityHeader := fmt.Sprintf("{\"cognitoIdentityId\":\"%s\",\"cognitoIdentityPoolId\":\"%s\"}", invoke.CognitoIdentityID, invoke.CognitoIdentityPoolID) + assert.JSONEq(t, expectedCognitoIdentityHeader, headers.Get("Lambda-Runtime-Cognito-Identity")) + assert.Equal(t, "Root=RootID;Parent=InvocationSubegmentID;Sampled=1", headers.Get("Lambda-Runtime-Trace-Id")) + + // Assert deadline precision. E.g. 1999 ns and 2001 ns having diff of 2 ns + // would result in 1ms and 2ms deadline correspondingly. + expectedDeadline := metering.MonoToEpoch(int64(deadlineNs)) / int64(time.Millisecond) + receivedDeadline, _ := strconv.ParseInt(headers.Get("Lambda-Runtime-Deadline-Ms"), 10, 64) + assert.True(t, math.Abs(float64(expectedDeadline-receivedDeadline)) <= float64(1), + fmt.Sprintf("Expected: %v, received: %v", expectedDeadline, receivedDeadline)) + + assert.Equal(t, "image/png", headers.Get("Content-Type")) + assert.Len(t, headers, 7) + responsePayload := responseRecorder.Body.String() + require.Equalf(t, len(invokePayload), len(responsePayload), "Unexpected payload for request %d", i) + assert.Equal(t, invokePayload, responsePayload) + } } // Cgo calls removed due to crashes while spawning threads under memory pressure. func TestRenderInvokeDoesNotCallCgo(t *testing.T) { cgoCallsBefore := runtime.NumCgoCall() - TestRenderInvoke(t) + TestRenderInvokeHappy(t) cgoCallsAfter := runtime.NumCgoCall() assert.Equal(t, cgoCallsBefore, cgoCallsAfter) } diff --git a/lambda/rapi/handler/invocationresponse.go b/lambda/rapi/handler/invocationresponse.go index 7e47d2e..d267775 100644 --- a/lambda/rapi/handler/invocationresponse.go +++ b/lambda/rapi/handler/invocationresponse.go @@ -8,6 +8,7 @@ import ( "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/core" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapi/rendering" @@ -17,7 +18,6 @@ import ( const ( StreamingFunctionResponseMode = "streaming" - ErrInvalidResponseModeHeader = "Runtime.InvalidResponseModeHeader" ) type invocationResponseHandler struct { @@ -27,7 +27,7 @@ type invocationResponseHandler struct { func (h *invocationResponseHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { appCtx := appctx.FromRequest(request) - server := appctx.LoadInteropServer(appCtx) + server := appctx.LoadResponseSender(appCtx) if server == nil { log.Panic("Invalid state, cannot access interop server") } @@ -48,25 +48,38 @@ func (h *invocationResponseHandler) ServeHTTP(writer http.ResponseWriter, reques case StreamingFunctionResponseMode: headers[functionResponseModeHeader] = functionResponseMode default: - errorResponse := &interop.ErrorResponse{ - ErrorType: ErrInvalidResponseModeHeader, + errHeaders := interop.InvokeResponseHeaders{ ContentType: request.Header.Get(contentTypeHeader), } - _ = server.SendErrorResponse(chi.URLParam(request, "awsrequestid"), errorResponse) + fnError := interop.FunctionError{Type: fatalerror.RuntimeInvalidResponseModeHeader} + response := &interop.ErrorInvokeResponse{ + Headers: errHeaders, + FunctionError: fnError, + Payload: []byte{}, + } + + _ = server.SendErrorResponse(chi.URLParam(request, "awsrequestid"), response) rendering.RenderInvalidFunctionResponseMode(writer, request) return } } - if err := server.SendResponse(invokeID, headers, request.Body, request.Trailer, &interop.CancellableRequest{Request: request}); err != nil { + response := &interop.StreamableInvokeResponse{ + Headers: headers, + Payload: request.Body, + Trailers: request.Trailer, + Request: &interop.CancellableRequest{Request: request}, + } + + if err := server.SendResponse(invokeID, response); err != nil { switch err := err.(type) { case *interop.ErrorResponseTooLarge: - if server.SendErrorResponse(invokeID, err.AsInteropError()) != nil { + if server.SendErrorResponse(invokeID, err.AsErrorResponse()) != nil { rendering.RenderInteropError(writer, request, err) return } - appctx.StoreErrorResponse(appCtx, err.AsInteropError()) + appctx.StoreInvokeErrorTraceData(appCtx, &interop.InvokeErrorTraceData{}) if err := runtime.ResponseSent(); err != nil { log.Panic(err) diff --git a/lambda/rapi/handler/invocationresponse_test.go b/lambda/rapi/handler/invocationresponse_test.go index 7c0b220..dc29c10 100644 --- a/lambda/rapi/handler/invocationresponse_test.go +++ b/lambda/rapi/handler/invocationresponse_test.go @@ -17,6 +17,7 @@ import ( "github.com/aws/aws-lambda-go/events/test" "github.com/stretchr/testify/assert" "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/testdata" ) @@ -62,13 +63,13 @@ func TestResponseTooLarge(t *testing.T) { errorResponse := flowTest.InteropServer.ErrorResponse assert.NotNil(t, errorResponse) assert.Nil(t, flowTest.InteropServer.Response) - assert.Equal(t, "Function.ResponseSizeTooLarge", errorResponse.ErrorType) - assert.Equal(t, "Response payload size (6291557 bytes) exceeded maximum allowed payload size (6291556 bytes).", errorResponse.ErrorMessage) + assert.Equal(t, fatalerror.FunctionOversizedResponse, errorResponse.FunctionError.Type) + assert.Equal(t, "Response payload size (6291557 bytes) exceeded maximum allowed payload size (6291556 bytes).", errorResponse.FunctionError.Message) var errorPayload map[string]interface{} assert.NoError(t, json.Unmarshal(errorResponse.Payload, &errorPayload)) - assert.Equal(t, errorResponse.ErrorType, errorPayload["errorType"]) - assert.Equal(t, errorResponse.ErrorMessage, errorPayload["errorMessage"]) + assert.Equal(t, string(errorResponse.FunctionError.Type), errorPayload["errorType"]) + assert.Equal(t, errorResponse.FunctionError.Message, errorPayload["errorMessage"]) } func TestResponseAccepted(t *testing.T) { @@ -193,7 +194,7 @@ func TestResponseWithDifferentFunctionResponseModes(t *testing.T) { if testCase.expectedErrorResponse { assert.NotNil(t, flowTest.InteropServer.ErrorResponse) assert.Nil(t, flowTest.InteropServer.Response) - assert.Equal(t, "Runtime.InvalidResponseModeHeader", flowTest.InteropServer.ErrorResponse.ErrorType) + assert.Equal(t, fatalerror.RuntimeInvalidResponseModeHeader, flowTest.InteropServer.ErrorResponse.FunctionError.Type) } else { assert.NotNil(t, flowTest.InteropServer.Response) assert.Nil(t, flowTest.InteropServer.ErrorResponse) diff --git a/lambda/rapi/handler/restoreerror.go b/lambda/rapi/handler/restoreerror.go new file mode 100644 index 0000000..eed97b2 --- /dev/null +++ b/lambda/rapi/handler/restoreerror.go @@ -0,0 +1,47 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "net/http" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/core" + "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapi/rendering" +) + +type restoreErrorHandler struct { + registrationService core.RegistrationService +} + +func (h *restoreErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + appCtx := appctx.FromRequest(request) + server := appctx.LoadInteropServer(appCtx) + if server == nil { + log.Panic("Invalid state, cannot access interop server") + } + + errorType := fatalerror.GetValidRuntimeOrFunctionErrorType(request.Header.Get("Lambda-Runtime-Function-Error-Type")) + fnError := interop.FunctionError{Type: errorType} + + runtime := h.registrationService.GetRuntime() + + if err := runtime.RestoreError(fnError); err != nil { + log.Warn(err) + rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, + runtime.GetState().Name(), core.RuntimeRestoreErrorStateName, err) + return + } + + appctx.StoreInvokeErrorTraceData(appCtx, &interop.InvokeErrorTraceData{}) + + rendering.RenderAccepted(writer, request) +} + +func NewRestoreErrorHandler(registrationService core.RegistrationService) http.Handler { + return &restoreErrorHandler{registrationService: registrationService} +} diff --git a/lambda/rapi/handler/restoreerror_test.go b/lambda/rapi/handler/restoreerror_test.go new file mode 100644 index 0000000..57226fa --- /dev/null +++ b/lambda/rapi/handler/restoreerror_test.go @@ -0,0 +1,44 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package handler + +import ( + "bytes" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/testdata" +) + +func TestRestoreErrorHandler(t *testing.T) { + t.Run("GA", func(t *testing.T) { runTestRestoreErrorHandler(t) }) +} + +func runTestRestoreErrorHandler(t *testing.T) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForRestoring() + + handler := NewRestoreErrorHandler(flowTest.RegistrationService) + responseRecorder := httptest.NewRecorder() + appCtx := flowTest.AppCtx + + errorBody := []byte("My byte array is yours") + errorType := "ErrorType" + errorContentType := "application/MyBinaryType" + + request := appctx.RequestWithAppCtx(httptest.NewRequest("POST", "/", bytes.NewReader(errorBody)), appCtx) + + request.Header.Set("Content-Type", errorContentType) + request.Header.Set("Lambda-Runtime-Function-Error-Type", errorType) + + handler.ServeHTTP(responseRecorder, request) + + require.Equal(t, http.StatusAccepted, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", responseRecorder.Code, http.StatusAccepted) + require.JSONEq(t, fmt.Sprintf("{\"status\":\"%s\"}\n", "OK"), responseRecorder.Body.String()) + require.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) +} diff --git a/lambda/rapi/handler/runtimelogs.go b/lambda/rapi/handler/runtimelogs.go index 99941b0..6b8a67e 100644 --- a/lambda/rapi/handler/runtimelogs.go +++ b/lambda/rapi/handler/runtimelogs.go @@ -9,10 +9,10 @@ import ( "fmt" "io" "net/http" + "strings" "go.amzn.com/lambda/core" "go.amzn.com/lambda/rapi/rendering" - "go.amzn.com/lambda/rapidcore/telemetry/logsapi" "go.amzn.com/lambda/telemetry" "github.com/google/uuid" @@ -31,10 +31,10 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http switch err := err.(type) { case *ErrAgentIdentifierUnknown: rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown extension "+err.agentID.String()) - h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeClientErr, 1) default: rendering.RenderInternalServerError(writer, request) - h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeServerErr, 1) } return } @@ -45,21 +45,21 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http if err != nil { log.Error(err) rendering.RenderInternalServerError(writer, request) - h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeServerErr, 1) return } - respBody, status, headers, err := h.telemetrySubscription.Subscribe(agentName, bytes.NewReader(body), request.Header) + respBody, status, headers, err := h.telemetrySubscription.Subscribe(agentName, bytes.NewReader(body), request.Header, request.RemoteAddr) if err != nil { log.Errorf("Telemetry API error: %s", err) switch err { - case logsapi.ErrTelemetryServiceOff: + case telemetry.ErrTelemetryServiceOff: rendering.RenderForbiddenWithTypeMsg(writer, request, h.telemetrySubscription.GetServiceClosedErrorType(), h.telemetrySubscription.GetServiceClosedErrorMessage()) - h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeClientErr, 1) default: rendering.RenderInternalServerError(writer, request) - h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeServerErr, 1) } return } @@ -67,11 +67,14 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http rendering.RenderRuntimeLogsResponse(writer, respBody, status, headers) switch status / 100 { case 2: // 2xx - h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeSuccess, 1) + if strings.Contains(string(respBody), "OK") { + h.telemetrySubscription.RecordCounterMetric(telemetry.NumSubscribers, 1) + } + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeSuccess, 1) case 4: // 4xx - h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeClientErr, 1) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeClientErr, 1) case 5: // 5xx - h.telemetrySubscription.RecordCounterMetric(logsapi.SubscribeServerErr, 1) + h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeServerErr, 1) } } diff --git a/lambda/rapi/handler/runtimelogs_test.go b/lambda/rapi/handler/runtimelogs_test.go index 892d61e..cbb8b0b 100644 --- a/lambda/rapi/handler/runtimelogs_test.go +++ b/lambda/rapi/handler/runtimelogs_test.go @@ -9,23 +9,24 @@ import ( "errors" "fmt" "io" + "net" "net/http" "net/http/httptest" "testing" + "go.amzn.com/lambda/core" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/telemetry" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapidcore/telemetry/logsapi" ) type mockSubscriptionAPI struct{ mock.Mock } -func (s *mockSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { - args := s.Called(agentName, body, headers) +func (s *mockSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string, remoteAddr string) ([]byte, int, map[string][]string, error) { + args := s.Called(agentName, body, headers, remoteAddr) return args.Get(0).([]byte), args.Int(1), args.Get(2).(map[string][]string), args.Error(3) } @@ -61,10 +62,15 @@ func (s *mockSubscriptionAPI) GetServiceClosedErrorType() string { return args.Get(0).(string) } +func validIPPort(addr string) bool { + ip, _, err := net.SplitHostPort(addr) + return err == nil && net.ParseIP(ip) != nil +} + func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} respBody, respStatus, respHeaders := []byte(`barbaz`), http.StatusNotFound, map[string][]string{"K": []string{"V1", "V2"}} - clientErrMetric := logsapi.SubscribeClientErr + clientErrMetric := telemetry.SubscribeClientErr registrationService := core.NewRegistrationService( core.NewInitFlowSynchronization(), @@ -75,7 +81,7 @@ func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { assert.NoError(t, err) telemetrySubscription := &mockSubscriptionAPI{} - telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return(respBody, respStatus, respHeaders, nil) + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return(respBody, respStatus, respHeaders, nil) telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) @@ -91,7 +97,7 @@ func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { handler.ServeHTTP(responseRecorder, request) - telemetrySubscription.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders) + telemetrySubscription.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)) telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) recordedBody, err := io.ReadAll(responseRecorder.Body) @@ -102,10 +108,97 @@ func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { assert.Equal(t, http.Header(respHeaders), responseRecorder.Header()) } +func TestSuccessfulTelemetryAPIPutRequest(t *testing.T) { + agentName, reqBody, reqHeaders := "extensionName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} + respBody, respStatus, respHeaders := []byte(`"OK"`), http.StatusOK, map[string][]string{"K": []string{"V1", "V2"}} + numSubscribersMetric := telemetry.NumSubscribers + subscribeSuccessMetric := telemetry.SubscribeSuccess + + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization(), + core.NewInvokeFlowSynchronization(), + ) + + agent, err := registrationService.CreateExternalAgent(agentName) + assert.NoError(t, err) + + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return(respBody, respStatus, respHeaders, nil) + telemetrySubscription.On("RecordCounterMetric", numSubscribersMetric, 1) + telemetrySubscription.On("RecordCounterMetric", subscribeSuccessMetric, 1) + + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) + request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) + for k, vals := range reqHeaders { + for _, v := range vals { + request.Header.Add(k, v) + } + } + + request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + telemetrySubscription.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", numSubscribersMetric, 1) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", subscribeSuccessMetric, 1) + + recordedBody, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + + assert.Equal(t, respStatus, responseRecorder.Code) + assert.Equal(t, respBody, recordedBody) + assert.Equal(t, http.Header(respHeaders), responseRecorder.Header()) +} + +func TestNumberOfSubscribersWhenAnExtensionIsAlreadySubscribed(t *testing.T) { + agentName, reqBody, reqHeaders := "extensionName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} + respBody, respStatus, respHeaders := []byte(`"AlreadySubcribed"`), http.StatusOK, map[string][]string{"K": []string{"V1", "V2"}} + numSubscribersMetric := telemetry.NumSubscribers + subscribeSuccessMetric := telemetry.SubscribeSuccess + + registrationService := core.NewRegistrationService( + core.NewInitFlowSynchronization(), + core.NewInvokeFlowSynchronization(), + ) + + agent, err := registrationService.CreateExternalAgent(agentName) + assert.NoError(t, err) + + telemetrySubscription := &mockSubscriptionAPI{} + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return(respBody, respStatus, respHeaders, nil) + telemetrySubscription.On("RecordCounterMetric", subscribeSuccessMetric, 1) + + handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) + request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) + for k, vals := range reqHeaders { + for _, v := range vals { + request.Header.Add(k, v) + } + } + + request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) + responseRecorder := httptest.NewRecorder() + + handler.ServeHTTP(responseRecorder, request) + + telemetrySubscription.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)) + telemetrySubscription.AssertCalled(t, "RecordCounterMetric", subscribeSuccessMetric, 1) + telemetrySubscription.AssertNotCalled(t, "RecordCounterMetric", numSubscribersMetric, mock.Anything) + + recordedBody, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + + assert.Equal(t, respStatus, responseRecorder.Code) + assert.Equal(t, respBody, recordedBody) + assert.Equal(t, http.Header(respHeaders), responseRecorder.Header()) +} + func TestErrorUnregisteredAgentID(t *testing.T) { invalidAgentID := uuid.New() reqBody, reqHeaders := []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} - clientErrMetric := logsapi.SubscribeClientErr + clientErrMetric := telemetry.SubscribeClientErr registrationService := core.NewRegistrationService( core.NewInitFlowSynchronization(), @@ -143,7 +236,7 @@ func TestErrorUnregisteredAgentID(t *testing.T) { func TestErrorTelemetryAPICallFailure(t *testing.T) { agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} apiError := errors.New("Error calling Telemetry API: connection refused") - serverErrMetric := logsapi.SubscribeServerErr + serverErrMetric := telemetry.SubscribeServerErr registrationService := core.NewRegistrationService( core.NewInitFlowSynchronization(), @@ -154,7 +247,7 @@ func TestErrorTelemetryAPICallFailure(t *testing.T) { assert.NoError(t, err) telemetrySubscription := &mockSubscriptionAPI{} - telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) telemetrySubscription.On("RecordCounterMetric", serverErrMetric, 1) handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) @@ -184,8 +277,8 @@ func TestErrorTelemetryAPICallFailure(t *testing.T) { func TestRenderLogsSubscriptionClosed(t *testing.T) { agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} - apiError := logsapi.ErrTelemetryServiceOff - clientErrMetric := logsapi.SubscribeClientErr + apiError := telemetry.ErrTelemetryServiceOff + clientErrMetric := telemetry.SubscribeClientErr registrationService := core.NewRegistrationService( core.NewInitFlowSynchronization(), @@ -196,7 +289,7 @@ func TestRenderLogsSubscriptionClosed(t *testing.T) { assert.NoError(t, err) telemetrySubscription := &mockSubscriptionAPI{} - telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) telemetrySubscription.On("GetServiceClosedErrorMessage").Return("Logs API subscription is closed already") telemetrySubscription.On("GetServiceClosedErrorType").Return("Logs.SubscriptionClosed") @@ -228,8 +321,8 @@ func TestRenderLogsSubscriptionClosed(t *testing.T) { func TestRenderTelemetrySubscriptionClosed(t *testing.T) { agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} - apiError := logsapi.ErrTelemetryServiceOff - clientErrMetric := logsapi.SubscribeClientErr + apiError := telemetry.ErrTelemetryServiceOff + clientErrMetric := telemetry.SubscribeClientErr registrationService := core.NewRegistrationService( core.NewInitFlowSynchronization(), @@ -240,7 +333,7 @@ func TestRenderTelemetrySubscriptionClosed(t *testing.T) { assert.NoError(t, err) telemetrySubscription := &mockSubscriptionAPI{} - telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) + telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) telemetrySubscription.On("GetServiceClosedErrorMessage").Return("Telemetry API subscription is closed already") telemetrySubscription.On("GetServiceClosedErrorType").Return("Telemetry.SubscriptionClosed") diff --git a/lambda/rapi/model/agentregisterresponse.go b/lambda/rapi/model/agentregisterresponse.go index 7e2eb86..fb9cacc 100644 --- a/lambda/rapi/model/agentregisterresponse.go +++ b/lambda/rapi/model/agentregisterresponse.go @@ -5,6 +5,7 @@ package model // ExtensionRegisterResponse is a response returned by the API server on extension/register post request type ExtensionRegisterResponse struct { + AccountID string `json:"accountId,omitempty"` FunctionName string `json:"functionName"` FunctionVersion string `json:"functionVersion"` Handler string `json:"handler"` diff --git a/lambda/rapi/model/errorresponse.go b/lambda/rapi/model/errorresponse.go index 621811c..4c95e6c 100644 --- a/lambda/rapi/model/errorresponse.go +++ b/lambda/rapi/model/errorresponse.go @@ -3,12 +3,6 @@ package model -import ( - "encoding/json" - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/interop" -) - // ErrorResponse is a standard invoke error response, // providing information about the error. type ErrorResponse struct { @@ -16,16 +10,3 @@ type ErrorResponse struct { ErrorType string `json:"errorType"` StackTrace []string `json:"stackTrace,omitempty"` } - -func (s *ErrorResponse) AsInteropError() *interop.ErrorResponse { - respJSON, err := json.Marshal(s) - if err != nil { - log.Panicf("Failed to marshal %#v: %s", *s, err) - } - - return &interop.ErrorResponse{ - ErrorType: s.ErrorType, - ErrorMessage: s.ErrorMessage, - Payload: respJSON, - } -} diff --git a/lambda/rapi/rapi_fuzz_test.go b/lambda/rapi/rapi_fuzz_test.go new file mode 100644 index 0000000..f1df47f --- /dev/null +++ b/lambda/rapi/rapi_fuzz_test.go @@ -0,0 +1,391 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapi + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "net/url" + "os" + "regexp" + "strings" + "testing" + "unicode" + + "github.com/stretchr/testify/assert" + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/extensions" + "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/telemetry" + "go.amzn.com/lambda/testdata" +) + +type runtimeFunctionErrStruct struct { + ErrorMessage string + ErrorType string + StackTrace []string +} + +func FuzzRuntimeAPIRouter(f *testing.F) { + extensions.Enable() + defer extensions.Disable() + + addSeedCorpusURLTargets(f) + + f.Fuzz(func(t *testing.T, rawPath string, payload []byte, isGetMethod bool) { + u, err := parseToURLStruct(rawPath) + if err != nil { + t.Skipf("error parsing url: %v. Skipping test.", err) + } + + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + + invoke := createDummyInvoke() + flowTest.ConfigureForInvoke(context.Background(), invoke) + + appctx.StoreInitType(flowTest.AppCtx, true) + + rapiServer := makeRapiServer(flowTest) + + method := "GET" + if !isGetMethod { + method = "POST" + } + + request := httptest.NewRequest(method, rawPath, bytes.NewReader(payload)) + responseRecorder := serveTestRequest(rapiServer, request) + + if isExpectedPath(u.Path, invoke.ID, isGetMethod) { + assertExpectedPathResponseCode(t, responseRecorder.Code, rawPath) + } else { + assertUnexpectedPathResponseCode(t, responseRecorder.Code, rawPath) + } + }) +} + +func FuzzInitErrorHandler(f *testing.F) { + addRuntimeFunctionErrorJSONCorpus(f) + + f.Fuzz(func(t *testing.T, errorBody []byte, errTypeHeader []byte) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + + rapiServer := makeRapiServer(flowTest) + + target := makeTargetURL("/runtime/init/error", version20180601) + request := httptest.NewRequest("POST", target, bytes.NewReader(errorBody)) + request = appctx.RequestWithAppCtx(request, flowTest.AppCtx) + request.Header.Set("Lambda-Runtime-Function-Error-Type", string(errTypeHeader)) + + responseRecorder := serveTestRequest(rapiServer, request) + + assert.Equal(t, http.StatusAccepted, responseRecorder.Code) + assert.JSONEq(t, "{\"status\":\"OK\"}\n", responseRecorder.Body.String()) + assert.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) + + assertErrorResponsePersists(t, errorBody, errTypeHeader, flowTest) + }) +} + +func FuzzInvocationResponseHandler(f *testing.F) { + f.Add([]byte("SUCCESS"), []byte("application/json"), []byte("streaming")) + f.Add([]byte(strings.Repeat("a", interop.MaxPayloadSize+1)), []byte("application/json"), []byte("streaming")) + + f.Fuzz(func(t *testing.T, responseBody []byte, contentType []byte, responseMode []byte) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + flowTest.Runtime.Ready() + + invoke := createDummyInvoke() + flowTest.ConfigureForInvoke(context.Background(), invoke) + + rapiServer := makeRapiServer(flowTest) + + target := makeTargetURL(fmt.Sprintf("/runtime/invocation/%s/response", invoke.ID), version20180601) + request := httptest.NewRequest("POST", target, bytes.NewReader(responseBody)) + request.Header.Set("Content-Type", string(contentType)) + request.Header.Set("Lambda-Runtime-Function-Response-Mode", string(responseMode)) + + request = appctx.RequestWithAppCtx(request, flowTest.AppCtx) + + responseRecorder := serveTestRequest(rapiServer, request) + + if !isValidResponseMode(responseMode) { + assert.Equal(t, http.StatusBadRequest, responseRecorder.Code) + return + } + + if len(responseBody) > interop.MaxPayloadSize { + assertInvocationResponseTooLarge(t, responseRecorder, flowTest, responseBody) + } else { + assertInvocationResponseAccepted(t, responseRecorder, flowTest, responseBody, contentType) + } + }) +} + +func FuzzInvocationErrorHandler(f *testing.F) { + addRuntimeFunctionErrorJSONCorpus(f) + + f.Fuzz(func(t *testing.T, errorBody []byte, errTypeHeader []byte) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + flowTest.Runtime.Ready() + appCtx := flowTest.AppCtx + + invoke := createDummyInvoke() + flowTest.ConfigureForInvoke(context.Background(), invoke) + + rapiServer := makeRapiServer(flowTest) + + target := makeTargetURL(fmt.Sprintf("/runtime/invocation/%s/error", invoke.ID), version20180601) + request := httptest.NewRequest("POST", target, bytes.NewReader(errorBody)) + request = appctx.RequestWithAppCtx(request, appCtx) + + request.Header.Set("Lambda-Runtime-Function-Error-Type", string(errTypeHeader)) + + responseRecorder := serveTestRequest(rapiServer, request) + + assert.Equal(t, http.StatusAccepted, responseRecorder.Code) + assert.JSONEq(t, "{\"status\":\"OK\"}\n", responseRecorder.Body.String()) + assert.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) + + assertErrorResponsePersists(t, errorBody, errTypeHeader, flowTest) + }) +} + +func FuzzRestoreErrorHandler(f *testing.F) { + f.Fuzz(func(t *testing.T, errorBody []byte, errTypeHeader []byte) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForRestoring() + + appctx.StoreInitType(flowTest.AppCtx, true) + + rapiServer := makeRapiServer(flowTest) + + target := makeTargetURL("/runtime/restore/error", version20180601) + request := httptest.NewRequest("POST", target, bytes.NewReader(errorBody)) + request = appctx.RequestWithAppCtx(request, flowTest.AppCtx) + + request.Header.Set("Lambda-Runtime-Function-Error-Type", string(errTypeHeader)) + + responseRecorder := serveTestRequest(rapiServer, request) + + assert.Equal(t, http.StatusAccepted, responseRecorder.Code) + assert.JSONEq(t, "{\"status\":\"OK\"}\n", responseRecorder.Body.String()) + assert.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) + }) +} + +func makeRapiServer(flowTest *testdata.FlowTest) *Server { + return NewServer( + "127.0.0.1", + 0, + flowTest.AppCtx, + flowTest.RegistrationService, + flowTest.RenderingService, + true, + &telemetry.NoOpSubscriptionAPI{}, + flowTest.TelemetrySubscription, + flowTest.CredentialsService, + ) +} + +func createDummyInvoke() *interop.Invoke { + return &interop.Invoke{ + ID: "InvocationID1", + Payload: strings.NewReader("Payload1"), + } +} + +func makeTargetURL(path string, apiVersion string) string { + protocol := "http" + endpoint := os.Getenv("AWS_LAMBDA_RUNTIME_API") + baseurl := fmt.Sprintf("%s://%s%s", protocol, endpoint, apiVersion) + + return fmt.Sprintf("%s%s", baseurl, path) +} + +func serveTestRequest(rapiServer *Server, request *http.Request) *httptest.ResponseRecorder { + responseRecorder := httptest.NewRecorder() + rapiServer.server.Handler.ServeHTTP(responseRecorder, request) + log.Printf("test(%v) = %v", request.URL, responseRecorder.Code) + + return responseRecorder +} + +func addSeedCorpusURLTargets(f *testing.F) { + invoke := createDummyInvoke() + errStruct := runtimeFunctionErrStruct{ + ErrorMessage: "error occurred", + ErrorType: "Runtime.UnknownReason", + StackTrace: []string{}, + } + errJSON, _ := json.Marshal(errStruct) + f.Add(makeTargetURL("/runtime/init/error", version20180601), errJSON, false) + f.Add(makeTargetURL("/runtime/invocation/next", version20180601), []byte{}, true) + f.Add(makeTargetURL(fmt.Sprintf("/runtime/invocation/%s/response", invoke.ID), version20180601), []byte("SUCCESS"), false) + f.Add(makeTargetURL(fmt.Sprintf("/runtime/invocation/%s/error", invoke.ID), version20180601), errJSON, false) + f.Add(makeTargetURL("/runtime/restore/next", version20180601), []byte{}, true) + f.Add(makeTargetURL("/runtime/restore/error", version20180601), errJSON, false) + + f.Add(makeTargetURL("/extension/register", version20200101), []byte("register"), false) + f.Add(makeTargetURL("/extension/event/next", version20200101), []byte("next"), true) + f.Add(makeTargetURL("/extension/init/error", version20200101), []byte("init error"), false) + f.Add(makeTargetURL("/extension/exit/error", version20200101), []byte("exit error"), false) +} + +func addRuntimeFunctionErrorJSONCorpus(f *testing.F) { + runtimeFuncErr := runtimeFunctionErrStruct{ + ErrorMessage: "error", + ErrorType: "Runtime.Unknown", + StackTrace: []string{}, + } + data, _ := json.Marshal(runtimeFuncErr) + + f.Add(data, []byte("Runtime.Unknown")) +} + +func isExpectedPath(path string, invokeID string, isGetMethod bool) bool { + expectedPaths := make(map[string]bool) + + expectedPaths[fmt.Sprintf("%s/runtime/init/error", version20180601)] = false + expectedPaths[fmt.Sprintf("%s/runtime/invocation/next", version20180601)] = true + expectedPaths[fmt.Sprintf("%s/runtime/invocation/%s/response", version20180601, invokeID)] = false + expectedPaths[fmt.Sprintf("%s/runtime/invocation/%s/error", version20180601, invokeID)] = false + expectedPaths[fmt.Sprintf("%s/runtime/restore/next", version20180601)] = true + expectedPaths[fmt.Sprintf("%s/runtime/restore/error", version20180601)] = false + + expectedPaths[fmt.Sprintf("%s/extension/register", version20200101)] = false + expectedPaths[fmt.Sprintf("%s/extension/event/next", version20200101)] = true + expectedPaths[fmt.Sprintf("%s/extension/init/error", version20200101)] = false + expectedPaths[fmt.Sprintf("%s/extension/exit/error", version20200101)] = false + + val, found := expectedPaths[path] + return found && (val == isGetMethod) +} + +func parseToURLStruct(rawPath string) (*url.URL, error) { + invalidChars := regexp.MustCompile(`[ %]+`) + if invalidChars.MatchString(rawPath) { + return nil, errors.New("url must not contain spaces or %") + } + + for _, r := range rawPath { + if !unicode.IsGraphic(r) { + return nil, errors.New("url contains non-graphic runes") + } + } + + if _, err := url.ParseRequestURI(rawPath); err != nil { + return nil, err + } + + u, err := url.Parse(rawPath) + if err != nil { + return nil, err + } + + if u.Scheme == "" { + return nil, errors.New("blank url scheme") + } + + return u, nil +} + +func assertInvocationResponseAccepted(t *testing.T, responseRecorder *httptest.ResponseRecorder, + flowTest *testdata.FlowTest, responseBody []byte, contentType []byte) { + assert.Equal(t, http.StatusAccepted, responseRecorder.Code, + "Handler returned wrong status code: got %v expected %v", + responseRecorder.Code, http.StatusAccepted) + + expectedAPIResponse := "{\"status\":\"OK\"}\n" + body, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + assert.JSONEq(t, expectedAPIResponse, string(body)) + + response := flowTest.InteropServer.Response + assert.NotNil(t, response) + assert.Nil(t, flowTest.InteropServer.ErrorResponse) + + assert.Equal(t, string(contentType), flowTest.InteropServer.ResponseContentType) + + assert.Equal(t, responseBody, response, + "Persisted response data in app context must match the submitted.") +} + +func assertInvocationResponseTooLarge(t *testing.T, responseRecorder *httptest.ResponseRecorder, flowTest *testdata.FlowTest, responseBody []byte) { + assert.Equal(t, http.StatusRequestEntityTooLarge, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", + responseRecorder.Code, http.StatusRequestEntityTooLarge) + + expectedAPIResponse := fmt.Sprintf("{\"errorMessage\":\"Exceeded maximum allowed payload size (%d bytes).\",\"errorType\":\"RequestEntityTooLarge\"}\n", interop.MaxPayloadSize) + body, err := io.ReadAll(responseRecorder.Body) + assert.NoError(t, err) + assert.JSONEq(t, expectedAPIResponse, string(body)) + + errorResponse := flowTest.InteropServer.ErrorResponse + assert.NotNil(t, errorResponse) + assert.Nil(t, flowTest.InteropServer.Response) + assert.Equal(t, fatalerror.FunctionOversizedResponse, errorResponse.FunctionError.Type) + assert.Equal(t, fmt.Sprintf("Response payload size (%v bytes) exceeded maximum allowed payload size (6291556 bytes).", len(responseBody)), errorResponse.FunctionError.Message) + + var errorPayload map[string]interface{} + assert.NoError(t, json.Unmarshal(errorResponse.Payload, &errorPayload)) + assert.Equal(t, string(errorResponse.FunctionError.Type), errorPayload["errorType"]) + assert.Equal(t, errorResponse.FunctionError.Message, errorPayload["errorMessage"]) +} + +func assertErrorResponsePersists(t *testing.T, errorBody []byte, errTypeHeader []byte, flowTest *testdata.FlowTest) { + errorResponse := flowTest.InteropServer.ErrorResponse + assert.NotNil(t, errorResponse) + assert.Nil(t, flowTest.InteropServer.Response) + + var runtimeFunctionErr runtimeFunctionErrStruct + var expectedErrMsg string + + // If input payload is a valid function error json object, + // assert that the error message persisted in the response + err := json.Unmarshal(errorBody, &runtimeFunctionErr) + if err != nil { + expectedErrMsg = runtimeFunctionErr.ErrorMessage + } + assert.Equal(t, expectedErrMsg, errorResponse.FunctionError.Message) + + // If input error type is valid (within the allow-listed value, + // assert that the error type persisted in the response + expectedErrType := fatalerror.GetValidRuntimeOrFunctionErrorType(string(errTypeHeader)) + assert.Equal(t, expectedErrType, errorResponse.FunctionError.Type) + + assert.Equal(t, errorBody, errorResponse.Payload) +} + +func isValidResponseMode(responseMode []byte) bool { + responseModeStr := string(responseMode) + return responseModeStr == "streaming" || + responseModeStr == "" +} + +func assertExpectedPathResponseCode(t *testing.T, code int, target string) { + if !(code == http.StatusOK || + code == http.StatusAccepted || + code == http.StatusForbidden) { + t.Errorf("Unexpected status code (%v) for target (%v)", code, target) + } +} + +func assertUnexpectedPathResponseCode(t *testing.T, code int, target string) { + if !(code == http.StatusNotFound || + code == http.StatusMethodNotAllowed || + code == http.StatusBadRequest) { + t.Errorf("Unexpected status code (%v) for target (%v)", code, target) + } +} diff --git a/lambda/rapi/rendering/render_error.go b/lambda/rapi/rendering/render_error.go new file mode 100644 index 0000000..151e606 --- /dev/null +++ b/lambda/rapi/rendering/render_error.go @@ -0,0 +1,88 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rendering + +import ( + "fmt" + "net/http" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapi/model" +) + +// RenderForbiddenWithTypeMsg method for rendering error response +func RenderForbiddenWithTypeMsg(w http.ResponseWriter, r *http.Request, errorType string, format string, args ...interface{}) { + if err := RenderJSON(http.StatusForbidden, w, r, &model.ErrorResponse{ + ErrorType: errorType, + ErrorMessage: fmt.Sprintf(format, args...), + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderInternalServerError method for rendering error response +func RenderInternalServerError(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusInternalServerError, w, r, &model.ErrorResponse{ + ErrorMessage: "Internal Server Error", + ErrorType: ErrorTypeInternalServerError, + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderRequestEntityTooLarge method for rendering error response +func RenderRequestEntityTooLarge(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusRequestEntityTooLarge, w, r, &model.ErrorResponse{ + ErrorMessage: fmt.Sprintf("Exceeded maximum allowed payload size (%d bytes).", interop.MaxPayloadSize), + ErrorType: ErrorTypeRequestEntityTooLarge, + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderTruncatedHTTPRequestError method for rendering error response +func RenderTruncatedHTTPRequestError(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ + ErrorMessage: "HTTP request detected as truncated", + ErrorType: ErrorTypeTruncatedHTTPRequest, + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderInvalidRequestID renders invalid request ID error response +func RenderInvalidRequestID(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ + ErrorMessage: "Invalid request ID", + ErrorType: "InvalidRequestID", + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderInvalidFunctionResponseMode renders invalid function response mode response +func RenderInvalidFunctionResponseMode(w http.ResponseWriter, r *http.Request) { + if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ + ErrorMessage: "Invalid function response mode", + ErrorType: "InvalidFunctionResponseMode", + }); err != nil { + log.WithError(err).Warn("Error while rendering response") + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// RenderInteropError is a convenience method for interpreting interop errors +func RenderInteropError(writer http.ResponseWriter, request *http.Request, err error) { + if err == interop.ErrInvalidInvokeID || err == interop.ErrResponseSent { + RenderInvalidRequestID(writer, request) + } else { + log.Panic(err) + } +} diff --git a/lambda/rapi/rendering/render_json.go b/lambda/rapi/rendering/render_json.go index 8cea816..1afbfe8 100644 --- a/lambda/rapi/rendering/render_json.go +++ b/lambda/rapi/rendering/render_json.go @@ -6,8 +6,9 @@ package rendering import ( "bytes" "encoding/json" - log "github.com/sirupsen/logrus" "net/http" + + log "github.com/sirupsen/logrus" ) // RenderJSON: @@ -15,6 +16,7 @@ import ( // - sets the Content-Type as application/json // - sets the HTTP response status code // - returns an error if it occurred before writing to response +// TODO: r *http.Request is not used, remove it func RenderJSON(status int, w http.ResponseWriter, r *http.Request, v interface{}) error { buf := &bytes.Buffer{} enc := json.NewEncoder(buf) diff --git a/lambda/rapi/rendering/rendering.go b/lambda/rapi/rendering/rendering.go index 0edfb68..9a9d77b 100644 --- a/lambda/rapi/rendering/rendering.go +++ b/lambda/rapi/rendering/rendering.go @@ -4,10 +4,10 @@ package rendering import ( + "bytes" "context" "encoding/json" "errors" - "fmt" "io" "net/http" "strconv" @@ -50,6 +50,13 @@ type EventRenderingService struct { currentState RendererState } +// NewRenderingService returns new EventRenderingService. +func NewRenderingService() *EventRenderingService { + return &EventRenderingService{ + mutex: &sync.RWMutex{}, + } +} + // SetRenderer set current state func (s *EventRenderingService) SetRenderer(state RendererState) { s.mutex.Lock() @@ -77,11 +84,19 @@ func (s *EventRenderingService) RenderRuntimeEvent(w http.ResponseWriter, r *htt return s.currentState.RenderRuntimeEvent(w, r) } -// NewRenderingService returns new EventRenderingService. -func NewRenderingService() *EventRenderingService { - return &EventRenderingService{ - mutex: &sync.RWMutex{}, - } +type RestoreRenderer struct{} + +func NewRestoreRenderer() *RestoreRenderer { + return &RestoreRenderer{} +} + +func (s *RestoreRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request *http.Request) error { + writer.WriteHeader(http.StatusOK) + return nil +} + +func (s *RestoreRenderer) RenderAgentEvent(writer http.ResponseWriter, request *http.Request) error { + return nil } // InvokeRendererMetrics contains metrics of invoke request @@ -94,17 +109,26 @@ type InvokeRendererMetrics struct { type InvokeRenderer struct { ctx context.Context invoke *interop.Invoke - tracingHeaderParser func(context.Context, *interop.Invoke) string - requestBuffer []byte + tracingHeaderParser func(context.Context) string + requestBuffer *bytes.Buffer requestMutex sync.Mutex metrics InvokeRendererMetrics } -type RestoreRenderer struct { +// NewInvokeRenderer returns new invoke event renderer +func NewInvokeRenderer(ctx context.Context, invoke *interop.Invoke, requestBuffer *bytes.Buffer, traceParser func(context.Context) string) *InvokeRenderer { + requestBuffer.Reset() // clear request buffer, since this can be reused across invokes + return &InvokeRenderer{ + invoke: invoke, + ctx: ctx, + tracingHeaderParser: traceParser, + requestBuffer: requestBuffer, + requestMutex: sync.Mutex{}, + } } -// NewAgentInvokeEvent forms a new AgentInvokeEvent from INVOKE request -func NewAgentInvokeEvent(req *interop.Invoke) (*model.AgentInvokeEvent, error) { +// newAgentInvokeEvent forms a new AgentInvokeEvent from INVOKE request +func newAgentInvokeEvent(req *interop.Invoke) (*model.AgentInvokeEvent, error) { deadlineMono, err := strconv.ParseInt(req.DeadlineNs, 10, 64) if err != nil { return nil, err @@ -123,7 +147,7 @@ func NewAgentInvokeEvent(req *interop.Invoke) (*model.AgentInvokeEvent, error) { // RenderAgentEvent renders invoke event json for agent. func (s *InvokeRenderer) RenderAgentEvent(writer http.ResponseWriter, request *http.Request) error { - event, err := NewAgentInvokeEvent(s.invoke) + event, err := newAgentInvokeEvent(s.invoke) if err != nil { return err } @@ -133,7 +157,11 @@ func (s *InvokeRenderer) RenderAgentEvent(writer http.ResponseWriter, request *h return err } - renderAgentInvokeHeaders(writer, uuid.New()) // TODO: check this thing + eventID := uuid.New() + headers := writer.Header() + headers.Set("Lambda-Extension-Event-Identifier", eventID.String()) + headers.Set("Content-Type", "application/json") + writer.WriteHeader(http.StatusOK) if _, err := writer.Write(bytes); err != nil { return err @@ -145,13 +173,13 @@ func (s *InvokeRenderer) bufferInvokeRequest() error { s.requestMutex.Lock() defer s.requestMutex.Unlock() var err error = nil - if nil == s.requestBuffer { + if s.requestBuffer.Len() == 0 { reader := io.LimitReader(s.invoke.Payload, interop.MaxPayloadSize) start := time.Now() - s.requestBuffer, err = io.ReadAll(reader) + _, err = s.requestBuffer.ReadFrom(reader) s.metrics = InvokeRendererMetrics{ ReadTime: time.Since(start), - SizeBytes: len(s.requestBuffer), + SizeBytes: s.requestBuffer.Len(), } } return err @@ -160,7 +188,7 @@ func (s *InvokeRenderer) bufferInvokeRequest() error { // RenderRuntimeEvent renders invoke payload for runtime. func (s *InvokeRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request *http.Request) error { invoke := s.invoke - customerTraceID := s.tracingHeaderParser(s.ctx, s.invoke) + customerTraceID := s.tracingHeaderParser(s.ctx) cognitoIdentityJSON := "" if len(invoke.CognitoIdentityID) != 0 || len(invoke.CognitoIdentityPoolID) != 0 { @@ -189,37 +217,13 @@ func (s *InvokeRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request if err := s.bufferInvokeRequest(); err != nil { return err } - _, err := writer.Write(s.requestBuffer) + _, err := writer.Write(s.requestBuffer.Bytes()) return err } return nil } -func (s *RestoreRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request *http.Request) error { - writer.WriteHeader(http.StatusOK) - return nil -} - -func (s *RestoreRenderer) RenderAgentEvent(writer http.ResponseWriter, request *http.Request) error { - return nil -} - -// NewInvokeRenderer returns new invoke event renderer -func NewInvokeRenderer(ctx context.Context, invoke *interop.Invoke, traceParser func(context.Context, *interop.Invoke) string) *InvokeRenderer { - return &InvokeRenderer{ - invoke: invoke, - ctx: ctx, - tracingHeaderParser: traceParser, - requestBuffer: nil, - requestMutex: sync.Mutex{}, - } -} - -func NewRestoreRenderer() *RestoreRenderer { - return &RestoreRenderer{} -} - func (s *InvokeRenderer) GetMetrics() InvokeRendererMetrics { s.requestMutex.Lock() defer s.requestMutex.Unlock() @@ -248,22 +252,15 @@ func (s *ShutdownRenderer) RenderRuntimeEvent(w http.ResponseWriter, r *http.Req panic("We should SIGTERM runtime") } -func setHeaderIfNotEmpty(headers http.Header, key string, value string) { - if len(value) != 0 { - headers.Set(key, value) - } -} +func renderInvokeHeaders(writer http.ResponseWriter, invokeID string, customerTraceID string, clientContext string, + cognitoIdentity string, invokedFunctionArn string, deadlineMs string, contentType string) { -func setHeaderOrDefault(headers http.Header, key, val, defaultVal string) { - if val == "" { - headers.Set(key, defaultVal) - return + setHeaderIfNotEmpty := func(headers http.Header, key string, value string) { + if value != "" { + headers.Set(key, value) + } } - headers.Set(key, val) -} -func renderInvokeHeaders(writer http.ResponseWriter, invokeID string, customerTraceID string, clientContext string, - cognitoIdentity string, invokedFunctionArn string, deadlineMs string, contentType string) { headers := writer.Header() setHeaderIfNotEmpty(headers, "Lambda-Runtime-Aws-Request-Id", invokeID) setHeaderIfNotEmpty(headers, "Lambda-Runtime-Trace-Id", customerTraceID) @@ -271,7 +268,10 @@ func renderInvokeHeaders(writer http.ResponseWriter, invokeID string, customerTr setHeaderIfNotEmpty(headers, "Lambda-Runtime-Cognito-Identity", cognitoIdentity) setHeaderIfNotEmpty(headers, "Lambda-Runtime-Invoked-Function-Arn", invokedFunctionArn) setHeaderIfNotEmpty(headers, "Lambda-Runtime-Deadline-Ms", deadlineMs) - setHeaderOrDefault(headers, "Content-Type", contentType, "application/json") + if contentType == "" { + contentType = "application/json" + } + headers.Set("Content-Type", contentType) writer.WriteHeader(http.StatusOK) } @@ -290,79 +290,6 @@ func RenderRuntimeLogsResponse(w http.ResponseWriter, respBody []byte, status in return err } -func renderAgentInvokeHeaders(writer http.ResponseWriter, eventID uuid.UUID) { - headers := writer.Header() - headers.Set("Lambda-Extension-Event-Identifier", eventID.String()) - headers.Set("Content-Type", "application/json") - writer.WriteHeader(http.StatusOK) -} - -// RenderForbiddenWithTypeMsg method for rendering error response -func RenderForbiddenWithTypeMsg(w http.ResponseWriter, r *http.Request, errorType string, format string, args ...interface{}) { - if err := RenderJSON(http.StatusForbidden, w, r, &model.ErrorResponse{ - ErrorType: errorType, - ErrorMessage: fmt.Sprintf(format, args...), - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -// RenderInternalServerError method for rendering error response -func RenderInternalServerError(w http.ResponseWriter, r *http.Request) { - if err := RenderJSON(http.StatusInternalServerError, w, r, &model.ErrorResponse{ - ErrorMessage: "Internal Server Error", - ErrorType: ErrorTypeInternalServerError, - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -// RenderRequestEntityTooLarge method for rendering error response -func RenderRequestEntityTooLarge(w http.ResponseWriter, r *http.Request) { - if err := RenderJSON(http.StatusRequestEntityTooLarge, w, r, &model.ErrorResponse{ - ErrorMessage: fmt.Sprintf("Exceeded maximum allowed payload size (%d bytes).", interop.MaxPayloadSize), - ErrorType: ErrorTypeRequestEntityTooLarge, - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -// RenderTruncatedHTTPRequestError method for rendering error response -func RenderTruncatedHTTPRequestError(w http.ResponseWriter, r *http.Request) { - if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ - ErrorMessage: "HTTP request detected as truncated", - ErrorType: ErrorTypeTruncatedHTTPRequest, - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -// RenderInvalidRequestID renders invalid request ID error response -func RenderInvalidRequestID(w http.ResponseWriter, r *http.Request) { - if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ - ErrorMessage: "Invalid request ID", - ErrorType: "InvalidRequestID", - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -// RenderInvalidFunctionResponseMode renders invalid function response mode response -func RenderInvalidFunctionResponseMode(w http.ResponseWriter, r *http.Request) { - if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ - ErrorMessage: "Invalid function response mode", - ErrorType: "InvalidFunctionResponseMode", - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - // RenderAccepted method for rendering accepted status response func RenderAccepted(w http.ResponseWriter, r *http.Request) { if err := RenderJSON(http.StatusAccepted, w, r, &model.StatusResponse{ @@ -372,12 +299,3 @@ func RenderAccepted(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) } } - -// RenderInteropError is a convenience method for interpreting interop errors -func RenderInteropError(writer http.ResponseWriter, request *http.Request, err error) { - if err == interop.ErrInvalidInvokeID || err == interop.ErrResponseSent { - RenderInvalidRequestID(writer, request) - } else { - log.Panic(err) - } -} diff --git a/lambda/rapi/router.go b/lambda/rapi/router.go index 5c2a56d..dc036bc 100644 --- a/lambda/rapi/router.go +++ b/lambda/rapi/router.go @@ -19,7 +19,7 @@ import ( // NewRouter returns a new instance of chi router implementing // Runtime API specification. -func NewRouter(appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService, eventsAPI telemetry.EventsAPI) http.Handler { +func NewRouter(appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService) http.Handler { router := chi.NewRouter() router.Use(middleware.AppCtxMiddleware(appCtx)) @@ -45,11 +45,11 @@ func NewRouter(appCtx appctx.ApplicationContext, registrationService core.Regist middleware.AwsRequestIDValidator( handler.NewInvocationErrorHandler(registrationService)).ServeHTTP) - router.Post("/runtime/init/error", - handler.NewInitErrorHandler(registrationService, eventsAPI).ServeHTTP) + router.Post("/runtime/init/error", handler.NewInitErrorHandler(registrationService).ServeHTTP) if appctx.LoadInitType(appCtx) == appctx.InitCaching { router.Get("/runtime/restore/next", handler.NewRestoreNextHandler(registrationService, renderingService).ServeHTTP) + router.Post("/runtime/restore/error", handler.NewRestoreErrorHandler(registrationService).ServeHTTP) } return router diff --git a/lambda/rapi/router_test.go b/lambda/rapi/router_test.go index 73cbde1..276fa53 100644 --- a/lambda/rapi/router_test.go +++ b/lambda/rapi/router_test.go @@ -69,7 +69,7 @@ func assertResponseErrorType(t *testing.T, expectedErrorType string, response *h func TestAcceptXML(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) responseRecorder := httptest.NewRecorder() request := httptest.NewRequest("POST", "/runtime/invocation/x-y-z/error", bytes.NewReader([]byte(""))) // Tell server that client side accepts "application/xml". @@ -90,7 +90,7 @@ func TestAcceptXML(t *testing.T) { func Test404PageNotFound(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/unsupported", bytes.NewReader([]byte("")))) assert.Equal(t, http.StatusNotFound, responseRecorder.Code) assert.Equal(t, "404 page not found\n", responseRecorder.Body.String()) @@ -99,7 +99,7 @@ func Test404PageNotFound(t *testing.T) { func Test405MethodNotAllowed(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("DELETE", "/runtime/invocation/ABC/error", bytes.NewReader([]byte("")))) assert.Equal(t, http.StatusMethodNotAllowed, responseRecorder.Code) } @@ -107,7 +107,7 @@ func Test405MethodNotAllowed(t *testing.T) { func TestInitErrorAccepted(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/init/error", bytes.NewReader([]byte("{}")))) assert.Equal(t, http.StatusAccepted, responseRecorder.Code) } @@ -115,7 +115,7 @@ func TestInitErrorAccepted(t *testing.T) { func TestInitErrorForbidden(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -126,7 +126,7 @@ func TestInitErrorForbidden(t *testing.T) { func TestInvokeResponseAccepted(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -137,7 +137,7 @@ func TestInvokeResponseAccepted(t *testing.T) { func TestInvokeErrorResponseAccepted(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -148,7 +148,7 @@ func TestInvokeErrorResponseAccepted(t *testing.T) { func TestInvokeNextTwice(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -159,7 +159,7 @@ func TestInvokeNextTwice(t *testing.T) { func TestInvokeResponseInvalidRequestID(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -171,7 +171,7 @@ func TestInvokeResponseInvalidRequestID(t *testing.T) { func TestInvokeErrorResponseInvalidRequestID(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -183,7 +183,7 @@ func TestInvokeErrorResponseInvalidRequestID(t *testing.T) { func TestInvokeResponseTwice(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -197,7 +197,7 @@ func TestInvokeResponseTwice(t *testing.T) { func TestInvokeErrorResponseTwice(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -211,7 +211,7 @@ func TestInvokeErrorResponseTwice(t *testing.T) { func TestInvokeResponseAfterErrorResponse(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -225,7 +225,7 @@ func TestInvokeResponseAfterErrorResponse(t *testing.T) { func TestInvokeErrorResponseAfterResponse(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) assert.Equal(t, http.StatusOK, responseRecorder.Code) @@ -239,7 +239,7 @@ func TestInvokeErrorResponseAfterResponse(t *testing.T) { func TestMoreThanOneInvoke(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) var responseRecorder *httptest.ResponseRecorder for _, id := range []string{"A", "B", "C"} { flowTest.ConfigureForInvoke(context.Background(), createInvoke(id)) @@ -253,7 +253,7 @@ func TestMoreThanOneInvoke(t *testing.T) { func TestInitCachingAPIDisabledForPlainInit(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) var responseRecorder *httptest.ResponseRecorder responseRecorder = makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/restore/next", nil)) @@ -263,12 +263,13 @@ func TestInitCachingAPIDisabledForPlainInit(t *testing.T) { assert.Equal(t, http.StatusNotFound, responseRecorder.Code) } -func benchmarkInvoke(b *testing.B, payload []byte) { +func benchmarkInvokeResponse(b *testing.B, payload []byte) { b.StopTimer() + b.ResetTimer() // does not restart timer, only resets state b.ReportAllocs() flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) for i := 0; i < b.N; i++ { id := uuid.New().String() flowTest.ConfigureForInvoke(context.Background(), createInvoke(id)) @@ -277,30 +278,76 @@ func benchmarkInvoke(b *testing.B, payload []byte) { } } -func BenchmarkInvokeWithEmptyPayload(b *testing.B) { - benchmarkInvoke(b, []byte("")) +func BenchmarkInvokeResponseWithEmptyPayload(b *testing.B) { + benchmarkInvokeResponse(b, []byte("")) } -func BenchmarkInvokeWith4KBPayload(b *testing.B) { - benchmarkInvoke(b, bytes.Repeat([]byte("a"), 4*1024)) +func BenchmarkInvokeResponseWith4KBPayload(b *testing.B) { + benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 4*1024)) } -func BenchmarkInvokeWith512KBPayload(b *testing.B) { - benchmarkInvoke(b, bytes.Repeat([]byte("a"), 512*1024)) +func BenchmarkInvokeResponseWith512KBPayload(b *testing.B) { + benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 512*1024)) } -func BenchmarkInvokeWith1MBPayload(b *testing.B) { - benchmarkInvoke(b, bytes.Repeat([]byte("a"), 1*1024*1024)) +func BenchmarkInvokeResponseWith1MBPayload(b *testing.B) { + benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 1*1024*1024)) } -func BenchmarkInvokeWith2MBPayload(b *testing.B) { - benchmarkInvoke(b, bytes.Repeat([]byte("a"), 2*1024*1024)) +func BenchmarkInvokeResponseWith2MBPayload(b *testing.B) { + benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 2*1024*1024)) } -func BenchmarkInvokeWith4MBPayload(b *testing.B) { - benchmarkInvoke(b, bytes.Repeat([]byte("a"), 4*1024*1024)) +func BenchmarkInvokeResponseWith4MBPayload(b *testing.B) { + benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 4*1024*1024)) } -func BenchmarkInvokeWith6MBPayload(b *testing.B) { - benchmarkInvoke(b, bytes.Repeat([]byte("a"), 6*1024*1024)) +func BenchmarkInvokeResponseWith6MBPayload(b *testing.B) { + benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 6*1024*1024)) +} + +func benchmarkInvokeRequest(b *testing.B, payload []byte) { + b.StopTimer() + b.ResetTimer() // does not restart timer, only resets state + b.ReportAllocs() + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) + var requestBuf bytes.Buffer + for i := 0; i < b.N; i++ { + id := uuid.New().String() + ctx, invoke := context.Background(), createInvoke(id) + flowTest.ConfigureForInvoke(ctx, invoke) // set invoke ID and initialize barriers + flowTest.ConfigureInvokeRenderer(ctx, invoke, &requestBuf) // override invoke renderer to reuse buffer + makeBenchRequest(b, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) + makeBenchRequest(b, router, httptest.NewRequest("POST", fmt.Sprintf("/runtime/invocation/%s/response", id), bytes.NewReader(payload))) + } +} + +func BenchmarkInvokeRequestWithEmptyPayload(b *testing.B) { + benchmarkInvokeRequest(b, []byte("")) +} + +func BenchmarkInvokeRequestWith4KBPayload(b *testing.B) { + benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 4*1024)) +} + +func BenchmarkInvokeRequestWith512KBPayload(b *testing.B) { + benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 512*1024)) +} + +func BenchmarkInvokeRequestWith1MBPayload(b *testing.B) { + benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 1*1024*1024)) +} + +func BenchmarkInvokeRequestWith2MBPayload(b *testing.B) { + benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 2*1024*1024)) +} + +func BenchmarkInvokeRequestWith4MBPayload(b *testing.B) { + benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 4*1024*1024)) +} + +func BenchmarkInvokeRequestWith6MBPayload(b *testing.B) { + benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 6*1024*1024)) } diff --git a/lambda/rapi/security_test.go b/lambda/rapi/security_test.go index 5312b43..3f869d5 100644 --- a/lambda/rapi/security_test.go +++ b/lambda/rapi/security_test.go @@ -20,7 +20,7 @@ func TestInvokeValidId(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) @@ -53,7 +53,7 @@ func TestSecurityInvokeResponseBadRequestId(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) @@ -100,7 +100,7 @@ func TestSecurityInvokeErrorBadRequestId(t *testing.T) { flowTest := testdata.NewFlowTest() flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, flowTest.EventsAPI) + router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) diff --git a/lambda/rapi/server.go b/lambda/rapi/server.go index dd027f4..d17270a 100644 --- a/lambda/rapi/server.go +++ b/lambda/rapi/server.go @@ -46,16 +46,22 @@ func SaveConnInContext(ctx context.Context, c net.Conn) context.Context { // should happen before provided runtime is started. // // When port is 0, OS will dynamically allocate the listening port. -func NewServer(host string, port int, appCtx appctx.ApplicationContext, +func NewServer( + host string, + port int, + appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService, telemetryAPIEnabled bool, - logsSubscriptionAPI telemetry.SubscriptionAPI, telemetrySubscriptionAPI telemetry.SubscriptionAPI, credentialsService core.CredentialsService, eventsAPI telemetry.EventsAPI) *Server { + logsSubscriptionAPI telemetry.SubscriptionAPI, + telemetrySubscriptionAPI telemetry.SubscriptionAPI, + credentialsService core.CredentialsService, +) *Server { exitErrors := make(chan error, 1) router := chi.NewRouter() - router.Mount(version20180601, NewRouter(appCtx, registrationService, renderingService, eventsAPI)) + router.Mount(version20180601, NewRouter(appCtx, registrationService, renderingService)) router.Mount(version20200101, ExtensionsRouter(appCtx, registrationService, renderingService)) if telemetryAPIEnabled { diff --git a/lambda/rapi/telemetry_logs_fuzz_test.go b/lambda/rapi/telemetry_logs_fuzz_test.go new file mode 100644 index 0000000..89adbd1 --- /dev/null +++ b/lambda/rapi/telemetry_logs_fuzz_test.go @@ -0,0 +1,185 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapi + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "go.amzn.com/lambda/extensions" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapi/handler" + "go.amzn.com/lambda/telemetry" + "go.amzn.com/lambda/testdata" +) + +const ( + logsHandlerPath = "/logs" + telemetryHandlerPath = "/telemetry" + + samplePayload = `{"foo" : "bar"}` +) + +func FuzzTelemetryLogRouters(f *testing.F) { + extensions.Enable() + defer extensions.Disable() + + f.Add(makeTargetURL(logsHandlerPath, version20200815), []byte(samplePayload)) + f.Add(makeTargetURL(telemetryHandlerPath, version20220701), []byte(samplePayload)) + + logsPath := fmt.Sprintf("%s%s", version20200815, logsHandlerPath) + telemetryPath := fmt.Sprintf("%s%s", version20220701, telemetryHandlerPath) + + f.Fuzz(func(t *testing.T, rawPath string, payload []byte) { + u, err := parseToURLStruct(rawPath) + if err != nil { + t.Skipf("error parsing url: %v. Skipping test.", err) + } + + flowTest := testdata.NewFlowTest() + + rapiServer := makeRapiServerWithMockSubscriptionAPI(flowTest, newMockSubscriptionAPI(true), newMockSubscriptionAPI(true)) + + request := httptest.NewRequest("PUT", rawPath, bytes.NewReader(payload)) + responseRecorder := serveTestRequest(rapiServer, request) + + if u.Path == logsPath || u.Path == telemetryPath { + assertExpectedPathResponseCode(t, responseRecorder.Code, rawPath) + } else { + assertUnexpectedPathResponseCode(t, responseRecorder.Code, rawPath) + } + }) +} + +func FuzzLogsHandler(f *testing.F) { + extensions.Enable() + defer extensions.Disable() + + fuzzSubscriptionAPIHandler(f, logsHandlerPath, version20200815) +} + +func FuzzTelemetryHandler(f *testing.F) { + extensions.Enable() + defer extensions.Disable() + + fuzzSubscriptionAPIHandler(f, telemetryHandlerPath, version20220701) +} + +func fuzzSubscriptionAPIHandler(f *testing.F, handlerPath string, apiVersion string) { + flowTest := testdata.NewFlowTest() + agent := makeExternalAgent(flowTest.RegistrationService) + f.Add([]byte(samplePayload), agent.ID.String(), true) + f.Add([]byte(samplePayload), agent.ID.String(), false) + + f.Fuzz(func(t *testing.T, payload []byte, agentIdentifierHeader string, serviceOn bool) { + telemetrySubscriptionAPI := newMockSubscriptionAPI(serviceOn) + logsSubscriptionAPI := newMockSubscriptionAPI(serviceOn) + rapiServer := makeRapiServerWithMockSubscriptionAPI(flowTest, logsSubscriptionAPI, telemetrySubscriptionAPI) + + apiUnderTest := telemetrySubscriptionAPI + if handlerPath == logsHandlerPath { + apiUnderTest = logsSubscriptionAPI + } + + target := makeTargetURL(handlerPath, apiVersion) + request := httptest.NewRequest("PUT", target, bytes.NewReader(payload)) + request.Header.Set(handler.LambdaAgentIdentifier, agentIdentifierHeader) + + responseRecorder := serveTestRequest(rapiServer, request) + + if agentIdentifierHeader == "" { + assertForbiddenErrorType(t, responseRecorder, handler.ErrAgentIdentifierMissing) + return + } + + if _, err := uuid.Parse(agentIdentifierHeader); err != nil { + assertForbiddenErrorType(t, responseRecorder, handler.ErrAgentIdentifierInvalid) + return + } + + if agentIdentifierHeader != agent.ID.String() { + assertForbiddenErrorType(t, responseRecorder, "Extension.UnknownExtensionIdentifier") + return + } + + if !serviceOn { + assertForbiddenErrorType(t, responseRecorder, apiUnderTest.GetServiceClosedErrorType()) + return + } + + // assert that payload has been stored in the mock subscription api after the handler calls Subscribe() + assert.Equal(t, payload, apiUnderTest.receivedPayload) + }) +} + +func makeRapiServerWithMockSubscriptionAPI( + flowTest *testdata.FlowTest, + logsSubscription telemetry.SubscriptionAPI, + telemetrySubscription telemetry.SubscriptionAPI) *Server { + return NewServer( + "127.0.0.1", + 0, + flowTest.AppCtx, + flowTest.RegistrationService, + flowTest.RenderingService, + true, + logsSubscription, + telemetrySubscription, + flowTest.CredentialsService, + ) +} + +type mockSubscriptionAPI struct { + serviceOn bool + receivedPayload []byte +} + +func newMockSubscriptionAPI(serviceOn bool) *mockSubscriptionAPI { + return &mockSubscriptionAPI{ + serviceOn: serviceOn, + } +} + +func (m *mockSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string, remoteAddr string) ([]byte, int, map[string][]string, error) { + if !m.serviceOn { + return nil, 0, map[string][]string{}, telemetry.ErrTelemetryServiceOff + } + + bodyBytes, err := io.ReadAll(body) + if err != nil { + return nil, 0, map[string][]string{}, fmt.Errorf("error Reading the body of subscription request: %s", err) + } + + m.receivedPayload = bodyBytes + + return []byte("OK"), http.StatusOK, map[string][]string{}, nil +} + +func (m *mockSubscriptionAPI) RecordCounterMetric(metricName string, count int) {} + +func (m *mockSubscriptionAPI) FlushMetrics() interop.TelemetrySubscriptionMetrics { + return nil +} + +func (m *mockSubscriptionAPI) Clear() {} + +func (m *mockSubscriptionAPI) TurnOff() {} + +func (m *mockSubscriptionAPI) GetEndpointURL() string { + return "/subscribe" +} + +func (m *mockSubscriptionAPI) GetServiceClosedErrorMessage() string { + return "Subscription API is closed" +} + +func (m *mockSubscriptionAPI) GetServiceClosedErrorType() string { + return "SubscriptionClosed" +} diff --git a/lambda/rapid/exit.go b/lambda/rapid/exit.go index e45f3a4..a601efc 100644 --- a/lambda/rapid/exit.go +++ b/lambda/rapid/exit.go @@ -4,31 +4,22 @@ package rapid import ( - "fmt" "time" "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/extensions" "go.amzn.com/lambda/fatalerror" "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapi/model" + "go.amzn.com/lambda/telemetry" log "github.com/sirupsen/logrus" ) func handleInvokeError(execCtx *rapidContext, invokeRequest *interop.Invoke, invokeMx *invokeMetrics, err error) *interop.InvokeFailure { invokeFailure := newInvokeFailureMsg(execCtx, invokeRequest, invokeMx, err) - resp := model.ErrorResponse{ - ErrorType: string(invokeFailure.ErrorType), - ErrorMessage: fmt.Sprintf("Error: %v", invokeFailure.ErrorMessage), - } - - if invokeRequest.ID != "" { - resp.ErrorMessage = fmt.Sprintf("RequestId: %s Error: %v", invokeRequest.ID, invokeFailure.ErrorMessage) - } // This is the default error response that gets sent back as the function response in failure cases - invokeFailure.DefaultErrorResponse = resp.AsInteropError() + invokeFailure.DefaultErrorResponse = interop.GetErrorResponseWithFormattedErrorMessage(invokeFailure.ErrorType, invokeFailure.ErrorMessage, invokeRequest.ID) // Invoke with extensions disabled maintains behaviour parity with pre-extensions rapid if !extensions.AreEnabled() { @@ -50,7 +41,7 @@ func handleInvokeError(execCtx *rapidContext, invokeRequest *interop.Invoke, inv func newInvokeFailureMsg(execCtx *rapidContext, invokeRequest *interop.Invoke, invokeMx *invokeMetrics, err error) *interop.InvokeFailure { errorType, found := appctx.LoadFirstFatalError(execCtx.appCtx) if !found { - errorType = fatalerror.Unknown + errorType = fatalerror.SandboxFailure } invokeFailure := &interop.InvokeFailure{ @@ -64,6 +55,7 @@ func newInvokeFailureMsg(execCtx *rapidContext, invokeRequest *interop.Invoke, i } if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { + invokeFailure.ResponseMetrics.RuntimeResponseLatencyMs = telemetry.CalculateDuration(execCtx.RuntimeStartedTime, invokeRequest.InvokeResponseMetrics.StartReadingResponseMonoTimeMs) invokeFailure.ResponseMetrics.RuntimeTimeThrottledMs = invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) invokeFailure.ResponseMetrics.RuntimeProducedBytes = invokeRequest.InvokeResponseMetrics.ProducedBytes invokeFailure.ResponseMetrics.RuntimeOutboundThroughputBps = invokeRequest.InvokeResponseMetrics.OutboundThroughputBps @@ -80,13 +72,15 @@ func newInvokeFailureMsg(execCtx *rapidContext, invokeRequest *interop.Invoke, i invokeFailure.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) } + invokeFailure.InvokeResponseMode = invokeRequest.InvokeResponseMode + return invokeFailure } func generateInitFailureMsg(execCtx *rapidContext, err error) interop.InitFailure { errorType, found := appctx.LoadFirstFatalError(execCtx.appCtx) if !found { - errorType = fatalerror.Unknown + errorType = fatalerror.SandboxFailure } initFailureMsg := interop.InitFailure{ diff --git a/lambda/rapid/start.go b/lambda/rapid/handlers.go similarity index 58% rename from lambda/rapid/start.go rename to lambda/rapid/handlers.go index 76337af..f379c4c 100644 --- a/lambda/rapid/start.go +++ b/lambda/rapid/handlers.go @@ -5,6 +5,7 @@ package rapid import ( + "bytes" "context" "errors" "fmt" @@ -22,22 +23,20 @@ import ( "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" "go.amzn.com/lambda/rapi" - "go.amzn.com/lambda/rapi/model" "go.amzn.com/lambda/rapi/rendering" + "go.amzn.com/lambda/rapidcore/env" supvmodel "go.amzn.com/lambda/supervisor/model" "go.amzn.com/lambda/telemetry" "github.com/google/uuid" - log "github.com/sirupsen/logrus" ) const ( - RuntimeDomain = "runtime" - OperatorDomain = "operator" - defaultAgentLocation = "/opt/extensions" - disableExtensionsFile = "/opt/disable-extensions-jwigqn8j" - runtimeProcessName = "runtime" + RuntimeDomain = "runtime" + OperatorDomain = "operator" + defaultAgentLocation = "/opt/extensions" + runtimeProcessName = "runtime" ) const ( @@ -48,14 +47,17 @@ const ( var errResetReceived = errors.New("errResetReceived") +type processSupervisor struct { + supvmodel.ProcessSupervisor + RootPath string +} + type rapidContext struct { interopServer interop.Server server *rapi.Server appCtx appctx.ApplicationContext - preLoadTimeNs int64 - postLoadTimeNs int64 initDone bool - supervisor supvmodel.Supervisor + supervisor processSupervisor runtimeDomainGeneration uint32 initFlow core.InitFlowSynchronization invokeFlow core.InvokeFlowSynchronization @@ -67,12 +69,16 @@ type rapidContext struct { logsEgressAPI telemetry.StdLogsEgressAPI xray telemetry.Tracer standaloneMode bool - eventsAPI telemetry.EventsAPI + eventsAPI interop.EventsAPI initCachingEnabled bool credentialsService core.CredentialsService - signalCtx context.Context - executionMutex sync.Mutex + handlerExecutionMutex sync.Mutex shutdownContext *shutdownContext + logStreamName string + + RuntimeStartedTime int64 + RuntimeOverheadStartedTime int64 + InvokeResponseMetrics *interop.InvokeResponseMetrics } // Validate interface compliance @@ -105,7 +111,13 @@ func (c *rapidContext) GetExtensionNames() string { func logAgentsInitStatus(execCtx *rapidContext) { for _, agent := range execCtx.registrationService.AgentsInfo() { - execCtx.eventsAPI.SendExtensionInit(agent.Name, agent.State, agent.ErrorType, agent.Subscriptions) + extensionInitData := interop.ExtensionInitData{ + AgentName: agent.Name, + State: agent.State, + ErrorType: agent.ErrorType, + Subscriptions: agent.Subscriptions, + } + execCtx.eventsAPI.SendExtensionInit(extensionInitData) } } @@ -116,7 +128,7 @@ func agentLaunchError(agent *core.ExternalAgent, appCtx appctx.ApplicationContex appctx.StoreFirstFatalError(appCtx, fatalerror.AgentLaunchError) } -func doInitExtensions(domain string, agentPaths []string, execCtx *rapidContext, env interop.EnvironmentVariables) error { +func doInitExtensions(domain string, agentPaths []string, execCtx *rapidContext, env *env.Environment) error { initFlow := execCtx.registrationService.InitFlow() // we don't bring it into the loop below because we don't want unnecessary broadcasts on agent gate @@ -127,7 +139,6 @@ func doInitExtensions(domain string, agentPaths []string, execCtx *rapidContext, for _, agentPath := range agentPaths { // Using path.Base(agentPath) not agentName because the agent name is contact, as standalone can get the internal state. agent, err := execCtx.registrationService.CreateExternalAgent(path.Base(agentPath)) - if err != nil { return err } @@ -140,21 +151,27 @@ func doInitExtensions(domain string, agentPaths []string, execCtx *rapidContext, env := env.AgentExecEnv() agentStdoutWriter, agentStderrWriter, err := execCtx.logsEgressAPI.GetExtensionSockets() - if err != nil { return err } agentName := fmt.Sprintf("extension-%s-%d", path.Base(agentPath), execCtx.runtimeDomainGeneration) - err = execCtx.supervisor.Exec(&supvmodel.ExecRequest{ - Domain: domain, - Name: agentName, - Path: agentPath, - Env: &env, + err = execCtx.supervisor.Exec(context.Background(), &supvmodel.ExecRequest{ + Domain: domain, + Name: agentName, + Path: agentPath, + Env: &env, + Logging: supvmodel.Logging{ + Managed: supvmodel.ManagedLogging{ + Topic: supvmodel.RtExtensionManagedLoggingTopic, + Formats: []supvmodel.ManagedLoggingFormat{ + supvmodel.LineBasedManagedLogging, + }, + }, + }, StdoutWriter: agentStdoutWriter, StderrWriter: agentStderrWriter, }) - if err != nil { agentLaunchError(agent, execCtx.appCtx, err) return err @@ -177,7 +194,7 @@ func doRuntimeBootstrap(execCtx *rapidContext, sbInfoFromInit interop.SandboxInf if err != nil { if fatalError, formattedLog, hasError := runtimeBootstrap.CachedFatalError(err); hasError { appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.eventsAPI.SendImageErrorLog(formattedLog) + execCtx.eventsAPI.SendImageErrorLog(interop.ImageErrorLogData(formattedLog)) } else { appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidEntrypoint) } @@ -189,7 +206,7 @@ func doRuntimeBootstrap(execCtx *rapidContext, sbInfoFromInit interop.SandboxInf if err != nil { if fatalError, formattedLog, hasError := runtimeBootstrap.CachedFatalError(err); hasError { appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.eventsAPI.SendImageErrorLog(formattedLog) + execCtx.eventsAPI.SendImageErrorLog(interop.ImageErrorLogData(formattedLog)) } else { appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidWorkingDir) } @@ -201,95 +218,69 @@ func doRuntimeBootstrap(execCtx *rapidContext, sbInfoFromInit interop.SandboxInf return bootstrapCmd, bootstrapEnv, bootstrapCwd, bootstrapExtraFiles, nil } -func (c *rapidContext) setupEventsWatcher(events <-chan supvmodel.Event) { - go func() { - for event := range events { - var err error = nil - log.Debugf("The events handler received the event %+v.", event) - if loss := event.Event.EventLoss(); loss != nil { - log.Panicf("Lost %d events from supervisor", *loss) - } - termination := event.Event.ProcessTerminated() - - // If we are not shutting down then we care if an unexpected exit happens. - if !c.shutdownContext.isShuttingDown() { - runtimeProcessName := fmt.Sprintf("%s-%d", runtimeProcessName, c.runtimeDomainGeneration) - - // If event from the runtime. - if *termination.Name == runtimeProcessName { - if termination.Success() { - err = fmt.Errorf("Runtime exited without providing a reason") - } else { - err = fmt.Errorf("Runtime exited with error: %s", termination.String()) - } - appctx.StoreFirstFatalError(c.appCtx, fatalerror.RuntimeExit) - } else { - if termination.Success() { - err = fmt.Errorf("exit code 0") - } else { - err = fmt.Errorf(termination.String()) - } +func (c *rapidContext) watchEvents(events <-chan supvmodel.Event) { + for event := range events { + var err error + log.Debugf("The events handler received the event %+v.", event) + if loss := event.Event.EventLoss(); loss != nil { + log.Panicf("Lost %d events from supervisor", *loss) + } + termination := event.Event.ProcessTerminated() - appctx.StoreFirstFatalError(c.appCtx, fatalerror.AgentCrash) + // If we are not shutting down then we care if an unexpected exit happens. + if !c.shutdownContext.isShuttingDown() { + runtimeProcessName := fmt.Sprintf("%s-%d", runtimeProcessName, c.runtimeDomainGeneration) + + // If event from the runtime. + if *termination.Name == runtimeProcessName { + if termination.Success() { + err = fmt.Errorf("Runtime exited without providing a reason") + } else { + err = fmt.Errorf("Runtime exited with error: %s", termination.String()) + } + appctx.StoreFirstFatalError(c.appCtx, fatalerror.RuntimeExit) + } else { + if termination.Success() { + err = fmt.Errorf("exit code 0") + } else { + err = fmt.Errorf(termination.String()) } - log.Warnf("Process %s exited: %+v", *termination.Name, termination) + appctx.StoreFirstFatalError(c.appCtx, fatalerror.AgentCrash) } - // At the moment we only get termination events. - // When their are other event types then we would need to be selective, - // about what we send to handleShutdownEvent(). - c.shutdownContext.handleProcessExit(*termination) - c.registrationService.CancelFlows(err) + log.Warnf("Process %s exited: %+v", *termination.Name, termination) } - }() -} - -func doOperatorDomainInit(ctx context.Context, execCtx *rapidContext, operatorDomainExtraConfig interop.DynamicDomainConfig) error { - events, err := execCtx.supervisor.Events() - if err != nil { - log.WithError(err).Panic("Could not get events stream from supervsior") - } - execCtx.setupEventsWatcher(events) - - log.Info("Configuring and starting Operator Domain") - conf := operatorDomainExtraConfig - err = execCtx.supervisor.Configure(&supvmodel.ConfigureRequest{ - Domain: OperatorDomain, - AdditionalStartHooks: conf.AdditionalStartHooks, - Mounts: conf.Mounts, - }) - - if err != nil { - log.WithError(err).Error("Failed to configure operator domain") - return err - } - - err = execCtx.supervisor.Start(&supvmodel.StartRequest{ - Domain: OperatorDomain, - }) - if err != nil { - log.WithError(err).Error("Failed to start operator domain") - return err + // At the moment we only get termination events. + // When their are other event types then we would need to be selective, + // about what we send to handleShutdownEvent(). + c.shutdownContext.handleProcessExit(*termination) + c.registrationService.CancelFlows(err) } +} - // we configure the runtime domain only once and not at - // every init phase (e.g., suppressed or reset). - err = execCtx.supervisor.Configure(&supvmodel.ConfigureRequest{ +// subscribe to /events for runtime domain in supervisor +func setupEventsWatcher(execCtx *rapidContext) error { + eventsRequest := supvmodel.EventsRequest{ Domain: RuntimeDomain, - }) + } + events, err := execCtx.supervisor.Events(context.Background(), &eventsRequest) if err != nil { - log.WithError(err).Error("Failed to configure operator domain") + log.Errorf("Could not get events stream from supervisor: %s", err) return err } + go execCtx.watchEvents(events) return nil - } -func doRuntimeDomainInit(ctx context.Context, execCtx *rapidContext, sbInfoFromInit interop.SandboxInfoFromInit) error { +func doRuntimeDomainInit(execCtx *rapidContext, sbInfoFromInit interop.SandboxInfoFromInit, phase interop.LifecyclePhase) error { + initStartTime := metering.Monotime() + sendInitStartLogEvent(execCtx, sbInfoFromInit.SandboxType, phase) + defer sendInitReportLogEvent(execCtx, sbInfoFromInit.SandboxType, initStartTime, phase) + execCtx.xray.RecordInitStartTime() defer execCtx.xray.RecordInitEndTime() @@ -299,18 +290,11 @@ func doRuntimeDomainInit(ctx context.Context, execCtx *rapidContext, sbInfoFromI } }() - log.Info("Starting runtime domain") - err := execCtx.supervisor.Start(&supvmodel.StartRequest{ - Domain: RuntimeDomain, - }) - if err != nil { - log.WithError(err).Panic("Failed configuring runtime domain") - } execCtx.runtimeDomainGeneration++ if extensions.AreEnabled() { runtimeExtensions := agents.ListExternalAgentPaths(defaultAgentLocation, - execCtx.supervisor.RuntimeConfig.RootPath) + execCtx.supervisor.RootPath) if err := doInitExtensions(RuntimeDomain, runtimeExtensions, execCtx, sbInfoFromInit.EnvironmentVariables); err != nil { return err } @@ -328,20 +312,17 @@ func doRuntimeDomainInit(ctx context.Context, execCtx *rapidContext, sbInfoFromI // runtime is implicitly subscribed for certain lifecycle events. log.Debug("Preregister runtime") registrationService := execCtx.registrationService - err = registrationService.PreregisterRuntime(runtime) - + err := registrationService.PreregisterRuntime(runtime) if err != nil { return err } bootstrapCmd, bootstrapEnv, bootstrapCwd, bootstrapExtraFiles, err := doRuntimeBootstrap(execCtx, sbInfoFromInit) - if err != nil { return err } runtimeStdoutWriter, runtimeStderrWriter, err := execCtx.logsEgressAPI.GetRuntimeSockets() - if err != nil { return err } @@ -349,13 +330,23 @@ func doRuntimeDomainInit(ctx context.Context, execCtx *rapidContext, sbInfoFromI log.Debug("Start runtime") checkCredentials(execCtx, bootstrapEnv) name := fmt.Sprintf("%s-%d", runtimeProcessName, execCtx.runtimeDomainGeneration) - err = execCtx.supervisor.Exec(&supvmodel.ExecRequest{ - Domain: RuntimeDomain, - Name: name, - Cwd: &bootstrapCwd, - Path: bootstrapCmd[0], - Args: bootstrapCmd[1:], - Env: &bootstrapEnv, + + err = execCtx.supervisor.Exec(context.Background(), &supvmodel.ExecRequest{ + Domain: RuntimeDomain, + Name: name, + Cwd: &bootstrapCwd, + Path: bootstrapCmd[0], + Args: bootstrapCmd[1:], + Env: &bootstrapEnv, + Logging: supvmodel.Logging{ + Managed: supvmodel.ManagedLogging{ + Topic: supvmodel.RuntimeManagedLoggingTopic, + Formats: []supvmodel.ManagedLoggingFormat{ + supvmodel.LineBasedManagedLogging, + supvmodel.MessageBasedManagedLogging, + }, + }, + }, StdoutWriter: runtimeStdoutWriter, StderrWriter: runtimeStderrWriter, ExtraFiles: &bootstrapExtraFiles, @@ -364,25 +355,25 @@ func doRuntimeDomainInit(ctx context.Context, execCtx *rapidContext, sbInfoFromI runtimeDoneStatus := telemetry.RuntimeDoneSuccess defer func() { - sendInitRuntimeDoneLogEvent(execCtx, sbInfoFromInit.SandboxType, runtimeDoneStatus) + sendInitRuntimeDoneLogEvent(execCtx, sbInfoFromInit.SandboxType, runtimeDoneStatus, phase) }() if err != nil { if fatalError, formattedLog, hasError := sbInfoFromInit.RuntimeBootstrap.CachedFatalError(err); hasError { appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.eventsAPI.SendImageErrorLog(formattedLog) + execCtx.eventsAPI.SendImageErrorLog(interop.ImageErrorLogData(formattedLog)) } else { appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidEntrypoint) } - runtimeDoneStatus = telemetry.RuntimeDoneFailure + runtimeDoneStatus = telemetry.RuntimeDoneError return err } execCtx.shutdownContext.createExitedChannel(name) if err := initFlow.AwaitRuntimeRestoreReady(); err != nil { - runtimeDoneStatus = telemetry.RuntimeDoneFailure + runtimeDoneStatus = telemetry.RuntimeDoneError return err } @@ -396,6 +387,7 @@ func doRuntimeDomainInit(ctx context.Context, execCtx *rapidContext, sbInfoFromI return err } if err := initFlow.AwaitAgentsReady(); err != nil { + runtimeDoneStatus = telemetry.RuntimeDoneError return err } } @@ -411,25 +403,34 @@ func doRuntimeDomainInit(ctx context.Context, execCtx *rapidContext, sbInfoFromI return nil } -func doInvoke(ctx context.Context, execCtx *rapidContext, invokeRequest *interop.Invoke, mx *invokeMetrics, sbInfoFromInit interop.SandboxInfoFromInit) error { - execCtx.eventsAPI.SetCurrentRequestID(invokeRequest.ID) +func doInvoke(execCtx *rapidContext, invokeRequest *interop.Invoke, mx *invokeMetrics, sbInfoFromInit interop.SandboxInfoFromInit, requestBuffer *bytes.Buffer) error { + execCtx.eventsAPI.SetCurrentRequestID(interop.RequestID(invokeRequest.ID)) appCtx := execCtx.appCtx xray := execCtx.xray xray.Configure(invokeRequest) + ctx := context.Background() + return xray.CaptureInvokeSegment(ctx, xray.WithErrorCause(ctx, appCtx, func(ctx context.Context) error { + telemetryTracingCtx := xray.BuildTracingCtxForStart() + if !execCtx.initDone { // do inline init if err := xray.CaptureInitSubsegment(ctx, func(ctx context.Context) error { - return doRuntimeDomainInit(ctx, execCtx, sbInfoFromInit) + return doRuntimeDomainInit(execCtx, sbInfoFromInit, interop.LifecyclePhaseInvoke) }); err != nil { + sendInvokeStartLogEvent(execCtx, invokeRequest.ID, telemetryTracingCtx) return err } - } else if sbInfoFromInit.SandboxType != interop.SandboxPreWarmed { + } else if sbInfoFromInit.SandboxType != interop.SandboxPreWarmed && !execCtx.initCachingEnabled { xray.SendInitSubsegmentWithRecordedTimesOnce(ctx) } + xray.SendRestoreSubsegmentWithRecordedTimesOnce(ctx) + + sendInvokeStartLogEvent(execCtx, invokeRequest.ID, telemetryTracingCtx) + invokeFlow := execCtx.invokeFlow log.Debug("Initialize invoke flow barriers") err := invokeFlow.InitializeBarriers() @@ -453,7 +454,7 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, invokeRequest *interop // Invoke if err := xray.CaptureInvokeSubsegment(ctx, xray.WithError(ctx, appCtx, func(ctx context.Context) error { log.Debug("Set renderer for invoke") - renderer := rendering.NewInvokeRenderer(ctx, invokeRequest, xray.TracingHeaderParser()) + renderer := rendering.NewInvokeRenderer(ctx, invokeRequest, requestBuffer, xray.BuildTracingHeader()) defer func() { mx.rendererMetrics = renderer.GetMetrics() }() @@ -473,6 +474,7 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, invokeRequest *interop log.Debug("Release runtime condition") //TODO handle Supervisors listening channel + execCtx.SetRuntimeStartedTime(metering.Monotime()) runtime.Release() log.Debug("Await runtime response") //TODO handle Supervisors listening channel @@ -484,6 +486,7 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, invokeRequest *interop // Runtime overhead if err := xray.CaptureOverheadSubsegment(ctx, func(ctx context.Context) error { log.Debug("Await runtime ready") + execCtx.SetRuntimeOverheadStartedTime(metering.Monotime()) //TODO handle Supervisors listening channel return invokeFlow.AwaitRuntimeReady() }); err != nil { @@ -491,19 +494,21 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, invokeRequest *interop } mx.runtimeReadyTime = metering.Monotime() - runtimeDoneEventData := telemetry.InvokeRuntimeDoneData{ + runtimeDoneEventData := interop.InvokeRuntimeDoneData{ Status: telemetry.RuntimeDoneSuccess, - Metrics: telemetry.GetRuntimeDoneInvokeMetrics(invokeRequest.InvokeReceivedTime, invokeRequest.InvokeResponseMetrics, mx.runtimeReadyTime), + Metrics: telemetry.GetRuntimeDoneInvokeMetrics(execCtx.RuntimeStartedTime, invokeRequest.InvokeResponseMetrics, mx.runtimeReadyTime), InternalMetrics: invokeRequest.InvokeResponseMetrics, - Tracing: telemetry.BuildTracingCtx(model.XRayTracingType, invokeRequest.TraceID, invokeRequest.LambdaSegmentID), - Spans: telemetry.GetRuntimeDoneSpans(invokeRequest.InvokeReceivedTime, invokeRequest.InvokeResponseMetrics), + Tracing: xray.BuildTracingCtxAfterInvokeComplete(), + Spans: execCtx.eventsAPI.GetRuntimeDoneSpans(execCtx.RuntimeStartedTime, invokeRequest.InvokeResponseMetrics, execCtx.RuntimeOverheadStartedTime, mx.runtimeReadyTime), } - if err := execCtx.eventsAPI.SendRuntimeDone(runtimeDoneEventData); err != nil { - log.Errorf("Failed to send RUNDONE: %s", err) + log.Info(runtimeDoneEventData.String()) + if err := execCtx.eventsAPI.SendInvokeRuntimeDone(runtimeDoneEventData); err != nil { + log.Errorf("Failed to send INVOKE RTDONE: %s", err) } // Extensions overhead if execCtx.HasActiveExtensions() { + extensionOverheadStartTime := metering.Monotime() execCtx.interopServer.SendRuntimeReady() log.Debug("Await agents ready") //TODO handle Supervisors listening channel @@ -511,18 +516,21 @@ func doInvoke(ctx context.Context, execCtx *rapidContext, invokeRequest *interop log.Warnf("AwaitAgentsReady() = %s", err) return err } + extensionOverheadEndTime := metering.Monotime() + extensionOverheadMsSpan := interop.Span{ + Name: "extensionOverhead", + Start: telemetry.GetEpochTimeInISO8601FormatFromMonotime(extensionOverheadStartTime), + DurationMs: telemetry.CalculateDuration(extensionOverheadStartTime, extensionOverheadEndTime), + } + if err := execCtx.eventsAPI.SendReportSpan(extensionOverheadMsSpan); err != nil { + log.WithError(err).Error("Failed to create REPORT Span") + } } return nil })) } -func extensionsDisabledByLayer() bool { - _, err := os.Stat(disableExtensionsFile) - log.Infof("extensionsDisabledByLayer(%s) -> %s", disableExtensionsFile, err) - return err == nil -} - // acceptInitRequest is a second initialization phase, performed after receiving START // initialized entities: _HANDLER, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN func (c *rapidContext) acceptInitRequest(initRequest *interop.Init) *interop.Init { @@ -535,15 +543,14 @@ func (c *rapidContext) acceptInitRequest(initRequest *interop.Init) *interop.Ini initRequest.FunctionName, initRequest.FunctionVersion) c.registrationService.SetFunctionMetadata(core.FunctionMetadata{ - FunctionName: initRequest.FunctionName, - FunctionVersion: initRequest.FunctionVersion, - Handler: initRequest.Handler, - RuntimeInfo: initRequest.RuntimeInfo, + AccountID: initRequest.AccountID, + FunctionName: initRequest.FunctionName, + FunctionVersion: initRequest.FunctionVersion, + InstanceMaxMemory: initRequest.InstanceMaxMemory, + Handler: initRequest.Handler, + RuntimeInfo: initRequest.RuntimeInfo, }) - - if extensionsDisabledByLayer() { - extensions.Disable() - } + c.SetLogStreamName(initRequest.LogStreamName) return initRequest } @@ -568,26 +575,21 @@ func (c *rapidContext) acceptInitRequestForInitCaching(initRequest *interop.Init initCachingToken) c.registrationService.SetFunctionMetadata(core.FunctionMetadata{ - FunctionName: initRequest.FunctionName, - FunctionVersion: initRequest.FunctionVersion, - Handler: initRequest.Handler, + AccountID: initRequest.AccountID, + FunctionName: initRequest.FunctionName, + FunctionVersion: initRequest.FunctionVersion, + InstanceMaxMemory: initRequest.InstanceMaxMemory, + Handler: initRequest.Handler, + RuntimeInfo: initRequest.RuntimeInfo, }) + c.SetLogStreamName(initRequest.LogStreamName) c.credentialsService.SetCredentials(initCachingToken, initRequest.AwsKey, initRequest.AwsSecret, initRequest.AwsSession, initRequest.CredentialsExpiry) - if extensionsDisabledByLayer() { - extensions.Disable() - } - return initRequest, nil } -func handleInit(execCtx *rapidContext, initRequest *interop.Init, - initStartedResponse chan<- interop.InitStarted, - initSuccessResponse chan<- interop.InitSuccess, - initFailureResponse chan<- interop.InitFailure) { - ctx := execCtx.signalCtx - +func handleInit(execCtx *rapidContext, initRequest *interop.Init, initSuccessResponse chan<- interop.InitSuccess, initFailureResponse chan<- interop.InitFailure) { if execCtx.initCachingEnabled { var err error if initRequest, err = execCtx.acceptInitRequestForInitCaching(initRequest); err != nil { @@ -600,23 +602,7 @@ func handleInit(execCtx *rapidContext, initRequest *interop.Init, initRequest = execCtx.acceptInitRequest(initRequest) } - initStartedMsg := interop.InitStarted{ - PreLoadTimeNs: execCtx.preLoadTimeNs, - PostLoadTimeNs: execCtx.postLoadTimeNs, - WaitStartTimeNs: execCtx.postLoadTimeNs, - WaitEndTimeNs: metering.Monotime(), - ExtensionsEnabled: extensions.AreEnabled(), - Ack: make(chan struct{}), - } - - initStartedResponse <- initStartedMsg - <-initStartedMsg.Ack - - // Operator domain init happens only once, it's never suppressed, - // and it's terminal in case of failures - if err := doOperatorDomainInit(ctx, execCtx, initRequest.OperatorDomainExtraConfig); err != nil { - // TODO: I believe we need to handle this specially, because we want - // to consider any failure here as terminal + if err := setupEventsWatcher(execCtx); err != nil { handleInitError(execCtx, initRequest.InvokeID, err, initFailureResponse) return } @@ -628,7 +614,7 @@ func handleInit(execCtx *rapidContext, initRequest *interop.Init, SandboxType: initRequest.SandboxType, RuntimeBootstrap: initRequest.Bootstrap, } - if err := doRuntimeDomainInit(ctx, execCtx, sbInfo); err != nil { + if err := doRuntimeDomainInit(execCtx, sbInfo, interop.LifecyclePhaseInit); err != nil { handleInitError(execCtx, initRequest.InvokeID, err, initFailureResponse) return } @@ -649,16 +635,18 @@ func handleInit(execCtx *rapidContext, initRequest *interop.Init, <-initSuccessMsg.Ack } -func handleInvoke(execCtx *rapidContext, invokeRequest *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit) (interop.InvokeSuccess, *interop.InvokeFailure) { - ctx := execCtx.signalCtx +func handleInvoke(execCtx *rapidContext, invokeRequest *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit, requestBuffer *bytes.Buffer, responseSender interop.InvokeResponseSender) (interop.InvokeSuccess, *interop.InvokeFailure) { + appctx.StoreResponseSender(execCtx.appCtx, responseSender) invokeMx := invokeMetrics{} - if err := doInvoke(ctx, execCtx, invokeRequest, &invokeMx, sbInfoFromInit); err != nil { + if err := doInvoke(execCtx, invokeRequest, &invokeMx, sbInfoFromInit, requestBuffer); err != nil { log.WithError(err).WithField("InvokeID", invokeRequest.ID).Error("Invoke failed") invokeFailure := handleInvokeError(execCtx, invokeRequest, &invokeMx, err) + invokeFailure.InvokeResponseMode = invokeRequest.InvokeResponseMode if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { invokeFailure.ResponseMetrics = interop.ResponseMetrics{ + RuntimeResponseLatencyMs: telemetry.CalculateDuration(execCtx.RuntimeStartedTime, invokeRequest.InvokeResponseMetrics.StartReadingResponseMonoTimeMs), RuntimeTimeThrottledMs: invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond), RuntimeProducedBytes: invokeRequest.InvokeResponseMetrics.ProducedBytes, RuntimeOutboundThroughputBps: invokeRequest.InvokeResponseMetrics.OutboundThroughputBps, @@ -683,10 +671,12 @@ func handleInvoke(execCtx *rapidContext, invokeRequest *interop.Invoke, sbInfoFr }, InvokeCompletionTimeNs: invokeCompletionTimeNs, InvokeReceivedTime: invokeRequest.InvokeReceivedTime, + InvokeResponseMode: invokeRequest.InvokeResponseMode, } if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { invokeSuccessMsg.ResponseMetrics = interop.ResponseMetrics{ + RuntimeResponseLatencyMs: telemetry.CalculateDuration(execCtx.RuntimeStartedTime, invokeRequest.InvokeResponseMetrics.StartReadingResponseMonoTimeMs), RuntimeTimeThrottledMs: invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond), RuntimeProducedBytes: invokeRequest.InvokeResponseMetrics.ProducedBytes, RuntimeOutboundThroughputBps: invokeRequest.InvokeResponseMetrics.OutboundThroughputBps, @@ -701,7 +691,7 @@ func handleInvoke(execCtx *rapidContext, invokeRequest *interop.Invoke, sbInfoFr } func reinitialize(execCtx *rapidContext) { - execCtx.appCtx.Delete(appctx.AppCtxInvokeErrorResponseKey) + execCtx.appCtx.Delete(appctx.AppCtxInvokeErrorTraceDataKey) execCtx.appCtx.Delete(appctx.AppCtxRuntimeReleaseKey) execCtx.appCtx.Delete(appctx.AppCtxFirstFatalErrorKey) execCtx.renderingService.SetRenderer(nil) @@ -716,32 +706,46 @@ func reinitialize(execCtx *rapidContext) { } // handle notification of reset -func handleReset(execCtx *rapidContext, resetEvent *interop.Reset, invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) (interop.ResetSuccess, *interop.ResetFailure) { +func handleReset(execCtx *rapidContext, resetEvent *interop.Reset, runtimeStartedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) (interop.ResetSuccess, *interop.ResetFailure) { log.Warnf("Reset initiated: %s", resetEvent.Reason) // Only send RuntimeDone event if we get a reset during an Invoke if resetEvent.Reason == "failure" || resetEvent.Reason == "timeout" { - runtimeDoneEventData := telemetry.InvokeRuntimeDoneData{ - Status: resetEvent.Reason, + var errorType *string + if resetEvent.Reason == "failure" { + firstFatalError, found := appctx.LoadFirstFatalError(execCtx.appCtx) + if !found { + firstFatalError = fatalerror.SandboxFailure + } + stringifiedError := string(firstFatalError) + errorType = &stringifiedError + } + + var status string + if resetEvent.Reason == "timeout" { + status = "timeout" + } else if strings.HasPrefix(*errorType, "Sandbox.") { + status = "failure" + } else { + status = "error" + } + + var runtimeReadyTime int64 = metering.Monotime() + runtimeDoneEventData := interop.InvokeRuntimeDoneData{ + Status: status, InternalMetrics: invokeResponseMetrics, - Metrics: telemetry.GetRuntimeDoneInvokeMetrics(invokeReceivedTime, invokeResponseMetrics, metering.Monotime()), - Tracing: telemetry.BuildTracingCtx(model.XRayTracingType, resetEvent.TraceID, resetEvent.LambdaSegmentID), - Spans: telemetry.GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics), + Metrics: telemetry.GetRuntimeDoneInvokeMetrics(runtimeStartedTime, invokeResponseMetrics, runtimeReadyTime), + Tracing: execCtx.xray.BuildTracingCtxAfterInvokeComplete(), + Spans: execCtx.eventsAPI.GetRuntimeDoneSpans(runtimeStartedTime, invokeResponseMetrics, execCtx.RuntimeOverheadStartedTime, runtimeReadyTime), + ErrorType: errorType, } - if err := execCtx.eventsAPI.SendRuntimeDone(runtimeDoneEventData); err != nil { - log.Errorf("Failed to send RUNDONE: %s", err) + if err := execCtx.eventsAPI.SendInvokeRuntimeDone(runtimeDoneEventData); err != nil { + log.Errorf("Failed to send INVOKE RTDONE: %s", err) } } extensionsResetMs, resetTimeout, _ := execCtx.shutdownContext.shutdown(execCtx, resetEvent.DeadlineNs, resetEvent.Reason) - log.Info("Starting runtime domain") - err := execCtx.supervisor.Start(&supvmodel.StartRequest{ - Domain: RuntimeDomain, - }) - if err != nil { - log.WithError(err).Panic("Failed booting runtime domain") - } execCtx.runtimeDomainGeneration++ // Only used by standalone for more indepth assertions. @@ -751,8 +755,12 @@ func handleReset(execCtx *rapidContext, resetEvent *interop.Reset, invokeReceive fatalErrorType, _ = appctx.LoadFirstFatalError(execCtx.appCtx) } + // TODO: move interop.ResponseMetrics{} to a factory method and initialize it there. + // Initialization is very similar in handleInvoke's invokeFailure.ResponseMetrics and + // invokeSuccessMsg.ResponseMetrics var responseMetrics interop.ResponseMetrics if resetEvent.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(resetEvent.InvokeResponseMetrics) { + responseMetrics.RuntimeResponseLatencyMs = telemetry.CalculateDuration(execCtx.RuntimeStartedTime, resetEvent.InvokeResponseMetrics.StartReadingResponseMonoTimeMs) responseMetrics.RuntimeTimeThrottledMs = resetEvent.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) responseMetrics.RuntimeProducedBytes = resetEvent.InvokeResponseMetrics.ProducedBytes responseMetrics.RuntimeOutboundThroughputBps = resetEvent.InvokeResponseMetrics.OutboundThroughputBps @@ -760,16 +768,18 @@ func handleReset(execCtx *rapidContext, resetEvent *interop.Reset, invokeReceive if resetTimeout { return interop.ResetSuccess{}, &interop.ResetFailure{ - ExtensionsResetMs: extensionsResetMs, - ErrorType: fatalErrorType, - ResponseMetrics: responseMetrics, + ExtensionsResetMs: extensionsResetMs, + ErrorType: fatalErrorType, + ResponseMetrics: responseMetrics, + InvokeResponseMode: resetEvent.InvokeResponseMode, } } return interop.ResetSuccess{ - ExtensionsResetMs: extensionsResetMs, - ErrorType: fatalErrorType, - ResponseMetrics: responseMetrics, + ExtensionsResetMs: extensionsResetMs, + ErrorType: fatalErrorType, + ResponseMetrics: responseMetrics, + InvokeResponseMode: resetEvent.InvokeResponseMode, }, nil } @@ -789,75 +799,199 @@ func handleShutdown(execCtx *rapidContext, shutdownEvent *interop.Shutdown, reas return interop.ShutdownSuccess{ErrorType: fatalErrorType} } -func handleRestore(execCtx *rapidContext, restore *interop.Restore) error { +func handleRestore(execCtx *rapidContext, restore *interop.Restore) (interop.RestoreResult, error) { err := execCtx.credentialsService.UpdateCredentials(restore.AwsKey, restore.AwsSecret, restore.AwsSession, restore.CredentialsExpiry) restoreStatus := telemetry.RuntimeDoneSuccess + restoreResult := interop.RestoreResult{} + defer func() { sendRestoreRuntimeDoneLogEvent(execCtx, restoreStatus) }() if err != nil { - return fmt.Errorf("error when updating credentials: %s", err) + log.Infof("error when updating credentials: %s", err) + return restoreResult, interop.ErrRestoreUpdateCredentials } + renderer := rendering.NewRestoreRenderer() execCtx.renderingService.SetRenderer(renderer) registrationService := execCtx.registrationService runtime := registrationService.GetRuntime() + execCtx.SetLogStreamName(restore.LogStreamName) + // If runtime has not called /restore/next then just return // instead of releasing the Runtime since there is no need to release. // Then the runtime should be released only during Invoke if runtime.GetState() != runtime.RuntimeRestoreReadyState { restoreStatus = telemetry.RuntimeDoneSuccess log.Infof("Runtime is in state: %s just returning", runtime.GetState().Name()) - return nil + + return restoreResult, nil } + deadlineNs := time.Now().Add(time.Duration(restore.RestoreHookTimeoutMs) * time.Millisecond).UnixNano() + + ctx, ctxCancel := context.WithDeadline(context.Background(), time.Unix(0, deadlineNs)) + + defer ctxCancel() + + startTime := metering.Monotime() + runtime.Release() initFlow := execCtx.initFlow - err = initFlow.AwaitRuntimeReady() + err = initFlow.AwaitRuntimeReadyWithDeadline(ctx) + + fatalErrorType, fatalErrorFound := appctx.LoadFirstFatalError(execCtx.appCtx) + + // If there is an error occured when waiting runtime to complete the restore hook execution, + // check if there is any error stored in appctx to get the root cause error type + // Runtime.ExitError is an example to such a scenario + if fatalErrorFound { + err = fmt.Errorf(string(fatalErrorType)) + } if err != nil { - restoreStatus = telemetry.RuntimeDoneFailure - } else { - restoreStatus = telemetry.RuntimeDoneSuccess + restoreStatus = telemetry.RuntimeDoneError } - return err + endTime := metering.Monotime() + restoreDuration := time.Duration(endTime - startTime) + restoreResult.RestoreMs = restoreDuration.Milliseconds() + + return restoreResult, err } -func start(signalCtx context.Context, execCtx *rapidContext) { +func startRuntimeAPI(ctx context.Context, execCtx *rapidContext) { // Start Runtime API Server err := execCtx.server.Listen() if err != nil { log.WithError(err).Panic("Runtime API Server failed to listen") } - go func() { execCtx.server.Serve(signalCtx) }() + execCtx.server.Serve(ctx) // blocking until server exits // Note, most of initialization code should run before blocking to receive START, // code before START runs in parallel with code downloads. } +func getFirstFatalError(execCtx *rapidContext, status string) *string { + if status == telemetry.RuntimeDoneSuccess { + return nil + } + + firstFatalError, found := appctx.LoadFirstFatalError(execCtx.appCtx) + if !found { + // We will set errorType to "Runtime.Unknown" in case of INIT timeout and RESTORE timeout + // This is a trade-off we are willing to make. We will improve this later + firstFatalError = fatalerror.RuntimeUnknown + } + stringifiedError := string(firstFatalError) + return &stringifiedError +} + func sendRestoreRuntimeDoneLogEvent(execCtx *rapidContext, status string) { - if err := execCtx.eventsAPI.SendRestoreRuntimeDone(status); err != nil { - log.Errorf("Failed to send RESTRD: %s", err) + firstFatalError := getFirstFatalError(execCtx, status) + + restoreRuntimeDoneData := interop.RestoreRuntimeDoneData{ + Status: status, + ErrorType: firstFatalError, + } + + if err := execCtx.eventsAPI.SendRestoreRuntimeDone(restoreRuntimeDoneData); err != nil { + log.Errorf("Failed to send RESTORE RTDONE: %s", err) + } +} + +func sendInitStartLogEvent(execCtx *rapidContext, sandboxType interop.SandboxType, phase interop.LifecyclePhase) { + initPhase, err := telemetry.InitPhaseFromLifecyclePhase(phase) + if err != nil { + log.Errorf("failed to convert lifecycle phase into init phase: %s", err) + return + } + + functionMetadata := execCtx.registrationService.GetFunctionMetadata() + initStartData := interop.InitStartData{ + InitializationType: telemetry.InferInitType(execCtx.initCachingEnabled, sandboxType), + RuntimeVersion: functionMetadata.RuntimeInfo.Version, + RuntimeVersionArn: functionMetadata.RuntimeInfo.Arn, + FunctionName: functionMetadata.FunctionName, + FunctionVersion: functionMetadata.FunctionVersion, + // based on https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/resource/semantic_conventions/faas.md + // we're sending the logStream as the instance id + InstanceID: execCtx.logStreamName, + InstanceMaxMemory: functionMetadata.InstanceMaxMemory, + Phase: initPhase, + } + log.Info(initStartData.String()) + + if err := execCtx.eventsAPI.SendInitStart(initStartData); err != nil { + log.Errorf("Failed to send INIT START: %s", err) } } -func sendInitRuntimeDoneLogEvent(execCtx *rapidContext, sandboxType interop.SandboxType, status string) { - initSource := interop.InferTelemetryInitSource(execCtx.initCachingEnabled, sandboxType) +func sendInitRuntimeDoneLogEvent(execCtx *rapidContext, sandboxType interop.SandboxType, status string, phase interop.LifecyclePhase) { + initPhase, err := telemetry.InitPhaseFromLifecyclePhase(phase) + if err != nil { + log.Errorf("failed to convert lifecycle phase into init phase: %s", err) + return + } + + firstFatalError := getFirstFatalError(execCtx, status) + + initRuntimeDoneData := interop.InitRuntimeDoneData{ + InitializationType: telemetry.InferInitType(execCtx.initCachingEnabled, sandboxType), + Status: status, + Phase: initPhase, + ErrorType: firstFatalError, + } + + log.Info(initRuntimeDoneData.String()) + + if err := execCtx.eventsAPI.SendInitRuntimeDone(initRuntimeDoneData); err != nil { + log.Errorf("Failed to send INIT RTDONE: %s", err) + } +} + +func sendInitReportLogEvent( + execCtx *rapidContext, + sandboxType interop.SandboxType, + initStartMonotime int64, + phase interop.LifecyclePhase, +) { + initPhase, err := telemetry.InitPhaseFromLifecyclePhase(phase) + if err != nil { + log.Errorf("failed to convert lifecycle phase into init phase: %s", err) + return + } + + initReportData := interop.InitReportData{ + InitializationType: telemetry.InferInitType(execCtx.initCachingEnabled, sandboxType), + Metrics: interop.InitReportMetrics{ + DurationMs: telemetry.CalculateDuration(initStartMonotime, metering.Monotime()), + }, + Phase: initPhase, + } + log.Info(initReportData.String()) + + if err = execCtx.eventsAPI.SendInitReport(initReportData); err != nil { + log.Errorf("Failed to send INIT REPORT: %s", err) + } +} - runtimeDoneData := &telemetry.InitRuntimeDoneData{ - InitSource: initSource, - Status: status, +func sendInvokeStartLogEvent(execCtx *rapidContext, invokeRequestID string, tracingCtx *interop.TracingCtx) { + invokeStartData := interop.InvokeStartData{ + RequestID: invokeRequestID, + Version: execCtx.registrationService.GetFunctionMetadata().FunctionVersion, + Tracing: tracingCtx, } + log.Info(invokeStartData.String()) - if err := execCtx.eventsAPI.SendInitRuntimeDone(runtimeDoneData); err != nil { - log.Errorf("Failed to send INITRD: %s", err) + if err := execCtx.eventsAPI.SendInvokeStart(invokeStartData); err != nil { + log.Errorf("Failed to send INVOKE START: %s", err) } } diff --git a/lambda/rapid/handlers_test.go b/lambda/rapid/handlers_test.go new file mode 100644 index 0000000..089dbb7 --- /dev/null +++ b/lambda/rapid/handlers_test.go @@ -0,0 +1,341 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapid + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "regexp" + "strconv" + "strings" + "sync" + "testing" + "time" + + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/core" + "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapi" + "go.amzn.com/lambda/rapi/handler" + "go.amzn.com/lambda/rapi/rendering" + "go.amzn.com/lambda/rapidcore/env" + "go.amzn.com/lambda/supervisor/model" + "go.amzn.com/lambda/telemetry" + "go.amzn.com/lambda/testdata" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func BenchmarkChannelsSelect10(b *testing.B) { + c1 := make(chan int) + c2 := make(chan int) + c3 := make(chan int) + c4 := make(chan int) + c5 := make(chan int) + c6 := make(chan int) + c7 := make(chan int) + c8 := make(chan int) + c9 := make(chan int) + c10 := make(chan int) + + for n := 0; n < b.N; n++ { + select { + case <-c1: + case <-c2: + case <-c3: + case <-c4: + case <-c5: + case <-c6: + case <-c7: + case <-c8: + case <-c9: + case <-c10: + default: + } + } +} + +func BenchmarkChannelsSelect2(b *testing.B) { + c1 := make(chan int) + c2 := make(chan int) + + for n := 0; n < b.N; n++ { + select { + case <-c1: + case <-c2: + default: + } + } +} + +func TestGetExtensionNamesWithNoExtensions(t *testing.T) { + rs := core.NewRegistrationService(nil, nil) + + c := &rapidContext{ + registrationService: rs, + } + + assert.Equal(t, "", c.GetExtensionNames()) +} + +func TestGetExtensionNamesWithMultipleExtensions(t *testing.T) { + rs := core.NewRegistrationService(nil, nil) + _, _ = rs.CreateExternalAgent("Example1") + _, _ = rs.CreateInternalAgent("Example2") + _, _ = rs.CreateExternalAgent("Example3") + _, _ = rs.CreateInternalAgent("Example4") + + c := &rapidContext{ + registrationService: rs, + } + + r := regexp.MustCompile(`^(Example\d;){3}(Example\d)$`) + assert.True(t, r.MatchString(c.GetExtensionNames())) +} + +func TestGetExtensionNamesWithTooManyExtensions(t *testing.T) { + rs := core.NewRegistrationService(nil, nil) + for i := 10; i < 60; i++ { + _, _ = rs.CreateExternalAgent("E" + strconv.Itoa(i)) + } + + c := &rapidContext{ + registrationService: rs, + } + + output := c.GetExtensionNames() + + r := regexp.MustCompile(`^(E\d\d;){30}(E\d\d)$`) + assert.LessOrEqual(t, len(output), maxExtensionNamesLength) + assert.True(t, r.MatchString(output)) +} + +func TestGetExtensionNamesWithTooLongExtensionName(t *testing.T) { + rs := core.NewRegistrationService(nil, nil) + for i := 10; i < 60; i++ { + _, _ = rs.CreateExternalAgent(strings.Repeat("E", 130)) + } + + c := &rapidContext{ + registrationService: rs, + } + + assert.Equal(t, "", c.GetExtensionNames()) +} + +// This test confirms our assumption that http client can establish a tcp connection +// to a listening server. +func TestListen(t *testing.T) { + flowTest := testdata.NewFlowTest() + flowTest.ConfigureForInit() + flowTest.ConfigureForInvoke(context.Background(), &interop.Invoke{ID: "ID", DeadlineNs: "1", Payload: strings.NewReader("MyTest")}) + + ctx := context.Background() + telemetryAPIEnabled := true + server := rapi.NewServer("127.0.0.1", 0, flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, telemetryAPIEnabled, flowTest.TelemetrySubscription, flowTest.TelemetrySubscription, flowTest.CredentialsService) + err := server.Listen() + assert.NoError(t, err) + + defer server.Close() + + go func() { + time.Sleep(time.Second) + fmt.Println("Serving...") + server.Serve(ctx) + }() + + done := make(chan struct{}) + + go func() { + fmt.Println("Connecting...") + resp, err1 := http.Get(fmt.Sprintf("http://%s:%d/2018-06-01/runtime/invocation/next", server.Host(), server.Port())) + assert.Nil(t, err1) + + body, err2 := io.ReadAll(resp.Body) + assert.Nil(t, err2) + + assert.Equal(t, "MyTest", string(body)) + + done <- struct{}{} + }() + + <-done +} + +func makeRapidContext(appCtx appctx.ApplicationContext, initFlow core.InitFlowSynchronization, invokeFlow core.InvokeFlowSynchronization, registrationService core.RegistrationService, supervisor *processSupervisor) *rapidContext { + + appctx.StoreInitType(appCtx, true) + appctx.StoreInteropServer(appCtx, MockInteropServer{}) + + renderingService := rendering.NewRenderingService() + + credentialsService := core.NewCredentialsService() + credentialsService.SetCredentials("token", "key", "secret", "session", time.Now()) + + // Runtime state machine + runtime := core.NewRuntime(initFlow, invokeFlow) + + registrationService.PreregisterRuntime(runtime) + runtime.SetState(runtime.RuntimeRestoreReadyState) + + rapidCtx := &rapidContext{ + // Internally initialized configurations + appCtx: appCtx, + initDone: true, + initFlow: initFlow, + invokeFlow: invokeFlow, + registrationService: registrationService, + renderingService: renderingService, + credentialsService: credentialsService, + handlerExecutionMutex: sync.Mutex{}, + shutdownContext: newShutdownContext(), + eventsAPI: &telemetry.NoOpEventsAPI{}, + } + if supervisor != nil { + rapidCtx.supervisor = *supervisor + } + + return rapidCtx +} + +const hookErrorType = "Runtime.RestoreHookUserErrorType" + +func makeRequest(appCtx appctx.ApplicationContext) *http.Request { + errorBody := []byte("My byte array is yours") + + request := appctx.RequestWithAppCtx(httptest.NewRequest("POST", "/", bytes.NewReader(errorBody)), appCtx) + + request.Header.Set("Content-Type", "application/MyBinaryType") + request.Header.Set("Lambda-Runtime-Function-Error-Type", hookErrorType) + + return request +} + +type MockInteropServer struct{} + +func (server MockInteropServer) GetCurrentInvokeID() string { + return "" +} + +func (server MockInteropServer) SendRuntimeReady() error { + return nil +} + +func (server MockInteropServer) SendInitErrorResponse(response *interop.ErrorInvokeResponse) error { + return nil +} + +func TestRestoreErrorAndAwaitRestoreCompletionRaceCondition(t *testing.T) { + appCtx := appctx.NewApplicationContext() + initFlow := core.NewInitFlowSynchronization() + invokeFlow := core.NewInvokeFlowSynchronization() + registrationService := core.NewRegistrationService(initFlow, invokeFlow) + + rapidCtx := makeRapidContext(appCtx, initFlow, invokeFlow, registrationService, nil /* don't set process supervisor */) + + // Runtime state machine + runtime := core.NewRuntime(initFlow, invokeFlow) + registrationService.PreregisterRuntime(runtime) + runtime.SetState(runtime.RuntimeRestoreReadyState) + + restore := &interop.Restore{ + AwsKey: "key", + AwsSecret: "secret", + AwsSession: "session", + CredentialsExpiry: time.Now(), + RestoreHookTimeoutMs: 10 * 1000, + } + + var wg sync.WaitGroup + + wg.Add(1) + + go func() { + defer wg.Done() + _, err := rapidCtx.HandleRestore(restore) + assert.Equal(t, err.Error(), "errRestoreHookUserError") + v, ok := err.(interop.ErrRestoreHookUserError) + assert.True(t, ok) + assert.Equal(t, v.UserError.Type, fatalerror.ErrorType(hookErrorType)) + }() + + responseRecorder := httptest.NewRecorder() + + handler := handler.NewRestoreErrorHandler(registrationService) + + request := makeRequest(appCtx) + + wg.Add(1) + + time.Sleep(1 * time.Second) + runtime.SetState(runtime.RuntimeRestoringState) + + go func() { + defer wg.Done() + handler.ServeHTTP(responseRecorder, request) + }() + + wg.Wait() +} + +type MockedProcessSupervisor struct { + mock.Mock +} + +func (supv *MockedProcessSupervisor) Exec(ctx context.Context, req *model.ExecRequest) error { + args := supv.Called(req) + return args.Error(0) +} + +func (supv *MockedProcessSupervisor) Events(ctx context.Context, req *model.EventsRequest) (<-chan model.Event, error) { + args := supv.Called(req) + err := args.Error(1) + if err != nil { + return nil, err + } + return args.Get(0).(<-chan model.Event), nil +} + +func (supv *MockedProcessSupervisor) Terminate(ctx context.Context, req *model.TerminateRequest) error { + args := supv.Called(req) + return args.Error(0) +} + +func (supv *MockedProcessSupervisor) Kill(ctx context.Context, req *model.KillRequest) error { + args := supv.Called(req) + return args.Error(0) +} + +var _ model.ProcessSupervisor = (*MockedProcessSupervisor)(nil) + +func TestSetupEventWatcherErrorHandling(t *testing.T) { + appCtx := appctx.NewApplicationContext() + initFlow := core.NewInitFlowSynchronization() + invokeFlow := core.NewInvokeFlowSynchronization() + registrationService := core.NewRegistrationService(initFlow, invokeFlow) + mockedProcessSupervisor := &MockedProcessSupervisor{} + mockedProcessSupervisor.On("Events", mock.Anything).Return(nil, fmt.Errorf("events call failed")) + procSupv := &processSupervisor{ProcessSupervisor: mockedProcessSupervisor} + + rapidCtx := makeRapidContext(appCtx, initFlow, invokeFlow, registrationService, procSupv) + + initSuccessResponseChan := make(chan interop.InitSuccess) + initFailureResponseChan := make(chan interop.InitFailure) + init := &interop.Init{EnvironmentVariables: env.NewEnvironment()} + + go assert.NotPanics(t, func() { + rapidCtx.HandleInit(init, initSuccessResponseChan, initFailureResponseChan) + }) + + failure := <-initFailureResponseChan + failure.Ack <- struct{}{} + errorType := interop.InitFailure(failure).ErrorType + assert.Equal(t, fatalerror.SandboxFailure, errorType) +} diff --git a/lambda/rapid/sandbox.go b/lambda/rapid/sandbox.go index 9259514..26eaff0 100644 --- a/lambda/rapid/sandbox.go +++ b/lambda/rapid/sandbox.go @@ -4,22 +4,19 @@ package rapid import ( + "bytes" "context" "fmt" "io" "sync" - "time" "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/core" "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/metering" "go.amzn.com/lambda/rapi" "go.amzn.com/lambda/rapi/rendering" supvmodel "go.amzn.com/lambda/supervisor/model" "go.amzn.com/lambda/telemetry" - - log "github.com/sirupsen/logrus" ) type Sandbox struct { @@ -32,18 +29,31 @@ type Sandbox struct { LogsEgressAPI telemetry.StdLogsEgressAPI RuntimeStdoutWriter io.Writer RuntimeStderrWriter io.Writer - PreLoadTimeNs int64 Handler string - SignalCtx context.Context - EventsAPI telemetry.EventsAPI + EventsAPI interop.EventsAPI InitCachingEnabled bool - Supervisor supvmodel.Supervisor + Supervisor supvmodel.ProcessSupervisor + RuntimeFsRootPath string // path to the root of the domain within the root mnt namespace. Reqired to find extensions RuntimeAPIHost string RuntimeAPIPort int } -// Start is a public version of start() that exports only configurable parameters -func Start(s *Sandbox) (interop.RapidContext, interop.InternalStateGetter, string) { +// Start pings Supervisor, and starts the Runtime API server. It allows the caller to configure: +// - Supervisor implementation: performs container construction & process management +// - Telemetry API and Logs API implementation: handling /logs and /telemetry of Runtime API +// - Events API implementation: handles platform log events emitted by Rapid (e.g. RuntimeDone, InitStart) +// - Logs Egress implementation: handling stdout/stderr logs from extension & runtime processes (TODO: remove & unify with Supervisor) +// - Tracer implementation: handling trace segments generate by platform (TODO: remove & unify with Events API) +// - InteropServer implementation: legacy interface for sending internal protocol messages, today only RuntimeReady remains (TODO: move RuntimeReady outside Core) +// - Feature flags: +// - StandaloneMode: indicates if being called by Rapid Core's standalone HTTP frontend (TODO: remove after unifying error reporting) +// - InitCachingEnabled: indicates if handlers must run Init Caching specific logic +// - TelemetryAPIEnabled: indicates if /telemetry and /logs endpoint HTTP handlers must be mounted +// +// - Contexts & Data: +// - ctx is used to gracefully terminate Runtime API HTTP Server on exit +func Start(ctx context.Context, s *Sandbox) (interop.RapidContext, interop.InternalStateGetter, string) { + // Initialize internal state objects required by Rapid handlers appCtx := appctx.NewApplicationContext() initFlow := core.NewInitFlowSynchronization() invokeFlow := core.NewInvokeFlowSynchronization() @@ -53,26 +63,27 @@ func Start(s *Sandbox) (interop.RapidContext, interop.InternalStateGetter, strin appctx.StoreInitType(appCtx, s.InitCachingEnabled) - server := rapi.NewServer(s.RuntimeAPIHost, s.RuntimeAPIPort, appCtx, registrationService, renderingService, s.EnableTelemetryAPI, s.LogsSubscriptionAPI, s.TelemetrySubscriptionAPI, credentialsService, s.EventsAPI) + server := rapi.NewServer(s.RuntimeAPIHost, s.RuntimeAPIPort, appCtx, registrationService, renderingService, s.EnableTelemetryAPI, s.LogsSubscriptionAPI, s.TelemetrySubscriptionAPI, credentialsService) runtimeAPIAddr := fmt.Sprintf("%s:%d", server.Host(), server.Port()) - postLoadTimeNs := metering.Monotime() - // TODO: pass this directly down to HTTP servers and handlers, instead of using // global state to share the interop server implementation appctx.StoreInteropServer(appCtx, s.InteropServer) execCtx := &rapidContext{ - server: server, - appCtx: appCtx, - postLoadTimeNs: postLoadTimeNs, - initDone: false, - initFlow: initFlow, - invokeFlow: invokeFlow, - registrationService: registrationService, - renderingService: renderingService, - credentialsService: credentialsService, - + // Internally initialized configurations + server: server, + appCtx: appCtx, + initDone: false, + initFlow: initFlow, + invokeFlow: invokeFlow, + registrationService: registrationService, + renderingService: renderingService, + credentialsService: credentialsService, + handlerExecutionMutex: sync.Mutex{}, + shutdownContext: newShutdownContext(), + + // Externally specified configurations (i.e. via SandboxBuilder) telemetryAPIEnabled: s.EnableTelemetryAPI, logsSubscriptionAPI: s.LogsSubscriptionAPI, telemetrySubscriptionAPI: s.TelemetrySubscriptionAPI, @@ -80,77 +91,84 @@ func Start(s *Sandbox) (interop.RapidContext, interop.InternalStateGetter, strin interopServer: s.InteropServer, xray: s.Tracer, standaloneMode: s.StandaloneMode, - preLoadTimeNs: s.PreLoadTimeNs, eventsAPI: s.EventsAPI, initCachingEnabled: s.InitCachingEnabled, - signalCtx: s.SignalCtx, - supervisor: s.Supervisor, - executionMutex: sync.Mutex{}, - shutdownContext: newShutdownContext(), + supervisor: processSupervisor{ + ProcessSupervisor: s.Supervisor, + RootPath: s.RuntimeFsRootPath, + }, + + RuntimeStartedTime: -1, + RuntimeOverheadStartedTime: -1, + InvokeResponseMetrics: nil, } - // We call /ping on Supervisor before starting Rapid, since Rapid - // depends on Supervisor setting up networking dependencies - var startupErr error - for retries := 1; retries <= 5; retries++ { - if startupErr = s.Supervisor.Ping(); startupErr == nil { - break - } - // Retry timeout: 5s, same order-of-mag as test client PING retries - // TODO: revisit retry timeout, identify appropriate value for prod. - time.Sleep(1000 * time.Millisecond) - } - - if startupErr != nil { - log.Panicf("Application ping to Supervisor failed, terminating Rapid Startup: %s", startupErr) - } - - go start(s.SignalCtx, execCtx) + go startRuntimeAPI(ctx, execCtx) return execCtx, registrationService.GetInternalStateDescriptor(appCtx), runtimeAPIAddr } -func (r *rapidContext) HandleInit(init *interop.Init, initStartedResponseChan chan<- interop.InitStarted, initSuccessResponseChan chan<- interop.InitSuccess, initFailureResponseChan chan<- interop.InitFailure) { - r.executionMutex.Lock() - defer r.executionMutex.Unlock() - handleInit(r, init, initStartedResponseChan, initSuccessResponseChan, initFailureResponseChan) +func (r *rapidContext) HandleInit(init *interop.Init, initSuccessResponseChan chan<- interop.InitSuccess, initFailureResponseChan chan<- interop.InitFailure) { + r.handlerExecutionMutex.Lock() + defer r.handlerExecutionMutex.Unlock() + handleInit(r, init, initSuccessResponseChan, initFailureResponseChan) } -func (r *rapidContext) HandleInvoke(invoke *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit) (interop.InvokeSuccess, *interop.InvokeFailure) { - r.executionMutex.Lock() - defer r.executionMutex.Unlock() - // Clear the context used by the last invok - r.appCtx.Delete(appctx.AppCtxInvokeErrorResponseKey) - return handleInvoke(r, invoke, sbInfoFromInit) +func (r *rapidContext) HandleInvoke(invoke *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit, requestBuffer *bytes.Buffer, responseSender interop.InvokeResponseSender) (interop.InvokeSuccess, *interop.InvokeFailure) { + r.handlerExecutionMutex.Lock() + defer r.handlerExecutionMutex.Unlock() + // Clear the context used by the last invoke + r.appCtx.Delete(appctx.AppCtxInvokeErrorTraceDataKey) + return handleInvoke(r, invoke, sbInfoFromInit, requestBuffer, responseSender) } -func (r *rapidContext) HandleReset(reset *interop.Reset, invokeReceivedTime int64, InvokeResponseMetrics *interop.InvokeResponseMetrics) (interop.ResetSuccess, *interop.ResetFailure) { +func (r *rapidContext) HandleReset(reset *interop.Reset) (interop.ResetSuccess, *interop.ResetFailure) { // In the event of a Reset during init/invoke, CancelFlows cancels execution // flows and return with the errResetReceived err - this error is special-cased // and not handled by the init/invoke (unexpected) error handling functions r.registrationService.CancelFlows(errResetReceived) // Wait until invoke error handling has returned before continuing execution - r.executionMutex.Lock() - defer r.executionMutex.Unlock() + r.handlerExecutionMutex.Lock() + defer r.handlerExecutionMutex.Unlock() - // Clear the context used by the last invoke, i.e. error message etc. - r.appCtx.Delete(appctx.AppCtxInvokeErrorResponseKey) - return handleReset(r, reset, invokeReceivedTime, InvokeResponseMetrics) + // Clear the context used by the last invoke + r.appCtx.Delete(appctx.AppCtxInvokeErrorTraceDataKey) + return handleReset(r, reset, r.RuntimeStartedTime, r.InvokeResponseMetrics) } func (r *rapidContext) HandleShutdown(shutdown *interop.Shutdown) interop.ShutdownSuccess { // Wait until invoke error handling has returned before continuing execution - r.executionMutex.Lock() - defer r.executionMutex.Unlock() + r.handlerExecutionMutex.Lock() + defer r.handlerExecutionMutex.Unlock() // Shutdown doesn't cancel flows, so it can block forever return handleShutdown(r, shutdown, standaloneShutdownReason) } -func (r *rapidContext) HandleRestore(restore *interop.Restore) error { +func (r *rapidContext) HandleRestore(restore *interop.Restore) (interop.RestoreResult, error) { return handleRestore(r, restore) } func (r *rapidContext) Clear() { reinitialize(r) } + +func (r *rapidContext) SetRuntimeStartedTime(runtimeStartedTime int64) { + r.RuntimeStartedTime = runtimeStartedTime +} + +func (r *rapidContext) SetRuntimeOverheadStartedTime(runtimeOverheadStartedTime int64) { + r.RuntimeOverheadStartedTime = runtimeOverheadStartedTime +} + +func (r *rapidContext) SetInvokeResponseMetrics(metrics *interop.InvokeResponseMetrics) { + r.InvokeResponseMetrics = metrics +} + +func (r *rapidContext) SetLogStreamName(logStreamName string) { + r.logStreamName = logStreamName +} + +func (r *rapidContext) SetEventsAPI(eventsAPI interop.EventsAPI) { + r.eventsAPI = eventsAPI +} diff --git a/lambda/rapid/shutdown.go b/lambda/rapid/shutdown.go index fe23a9f..05695e3 100644 --- a/lambda/rapid/shutdown.go +++ b/lambda/rapid/shutdown.go @@ -5,6 +5,8 @@ package rapid import ( + "context" + "errors" "fmt" "sync" "time" @@ -21,18 +23,20 @@ import ( const ( // supervisor shutdown and kill operations block until the exit status of the - // interested process has been collected, or until the specified timeotuw - // expires (in which case the operation fails). - // Note that this timeout is mainly relevant when any of the domain + // interested process has been collected, or until the specified deadline expires + // Note that this deadline is mainly relevant when any of the domain // processes are in uninterruptible sleep state (notable examples: syscall - // to read/write a newtorked driver) + // to read/write a networked driver) // // We set a non nil value for these timeouts so that RAPID doesn't block // forever in one of the cases above. supervisorBlockingMaxMillis = 9000 runtimeDeadlineShare = 0.3 + + maxProcessExitWait = 2 * time.Second ) +// TODO: aggregate struct's methods into an interface, so that we can mock in tests type shutdownContext struct { // Adding a mutex around shuttingDown because there may be concurrent reads/writes. // Because the code in shutdown() and the seperate go routine created in setupEventsWatcher() @@ -130,11 +134,15 @@ func (s *shutdownContext) createExitedChannel(name string) { // Blocks until all the processes in the runtime domain generation have exited. // This helps us have a nice sync point on Shutdown where we know for sure that -// all the processes have exited and the state has been cleared. +// all the processes have exited and the state has been cleared. The exception +// to that rule is that if any of the processes don't exit within +// maxProcessExitWait from the beginning of the waiting period, an error is +// returned, in order to prevent it from waiting forever if any of the processes +// cannot be killed. // // It is OK not to hold the lock because we know that this is called only during // shutdown and nobody will start a new process during shutdown -func (s *shutdownContext) clearExitedChannel() { +func (s *shutdownContext) clearExitedChannel() error { s.runtimeDomainExitedMutex.Lock() mapLen := len(s.runtimeDomainExited) channels := make([]chan struct{}, 0, mapLen) @@ -143,26 +151,32 @@ func (s *shutdownContext) clearExitedChannel() { } s.runtimeDomainExitedMutex.Unlock() + exitTimeout := time.After(maxProcessExitWait) for _, v := range channels { - <-v + select { + case <-v: + case <-exitTimeout: + return errors.New("timed out waiting for runtime processes to exit") + } } s.runtimeDomainExitedMutex.Lock() s.runtimeDomainExited = make(map[string]chan struct{}, mapLen) s.runtimeDomainExitedMutex.Unlock() + return nil } func (s *shutdownContext) shutdownRuntime(execCtx *rapidContext, start time.Time, deadline time.Time) { // If runtime is started: - // 1. SIGTERM and wait until timeout - // 2. SIGKILL on timeout + // 1. SIGTERM and wait until deadline + // 2. SIGKILL on deadline log.Debug("Shutting down the runtime.") name := fmt.Sprintf("%s-%d", runtimeProcessName, execCtx.runtimeDomainGeneration) exitedChannel, found := s.getExitedChannel(name) if found { - err := execCtx.supervisor.Terminate(&supvmodel.TerminateRequest{ + err := execCtx.supervisor.Terminate(context.Background(), &supvmodel.TerminateRequest{ Domain: RuntimeDomain, Name: name, }) @@ -172,17 +186,17 @@ func (s *shutdownContext) shutdownRuntime(execCtx *rapidContext, start time.Time log.WithError(err).Warn("Failed sending Termination signal to runtime") } - runtimeTimeout := deadline.Sub(start) - log.Tracef("The runtime timeout is %v.", runtimeTimeout) - runtimeTimer := time.NewTimer(runtimeTimeout) + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + select { - case <-runtimeTimer.C: - log.Warnf("Timeout: The runtime did not exit after %d ms; Killing it.", int64(runtimeTimeout/time.Millisecond)) - supervisorBlockingMaxMillis := uint64(supervisorBlockingMaxMillis) - err = execCtx.supervisor.Kill(&supvmodel.KillRequest{ - Domain: RuntimeDomain, - Name: name, - Timeout: &supervisorBlockingMaxMillis, + case <-ctx.Done(): + log.Warnf("Deadline: The runtime did not exit after deadline %s; Killing it.", deadline) + + err = execCtx.supervisor.Kill(context.Background(), &supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Deadline: time.Now().Add(time.Millisecond * supervisorBlockingMaxMillis), }) if err != nil { @@ -201,8 +215,8 @@ func (s *shutdownContext) shutdownRuntime(execCtx *rapidContext, start time.Time func (s *shutdownContext) shutdownAgents(execCtx *rapidContext, start time.Time, deadline time.Time, reason string) { // For each external agent, if agent is launched: // 1. Send Shutdown event if subscribed for it, else send SIGKILL to process group - // 2. Wait for all Shutdown-subscribed agents to exit with timeout - // 3. Send SIGKILL to process group for Shutdown-subscribed agents on timeout + // 2. Wait for all Shutdown-subscribed agents to exit with deadline + // 3. Send SIGKILL to process group for Shutdown-subscribed agents on deadline log.Debug("Shutting down the agents.") execCtx.renderingService.SetRenderer( @@ -224,7 +238,6 @@ func (s *shutdownContext) shutdownAgents(execCtx *rapidContext, start time.Time, for _, a := range execCtx.registrationService.GetExternalAgents() { name := fmt.Sprintf("extension-%s-%d", a.Name, execCtx.runtimeDomainGeneration) exitedChannel, found := s.getExitedChannel(name) - supervisorBlockingMaxMillis := uint64(supervisorBlockingMaxMillis) if !found { log.Warnf("Agent %s failed to launch, therefore skipping shutting it down.", a) @@ -242,24 +255,25 @@ func (s *shutdownContext) shutdownAgents(execCtx *rapidContext, start time.Time, agent.Release() - agentTimeout := deadline.Sub(start) - var agentTimeoutChan <-chan time.Time + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() if execCtx.standaloneMode { - agentTimeoutChan = time.NewTimer(agentTimeout).C + ctx, cancel = context.WithDeadline(ctx, deadline) + defer cancel() } select { - case <-agentTimeoutChan: - log.Warnf("Timeout: the agent %s did not exit after %d ms; Killing it.", name, int64(agentTimeout/time.Millisecond)) - err := execCtx.supervisor.Kill(&supvmodel.KillRequest{ - Domain: RuntimeDomain, - Name: name, - Timeout: &supervisorBlockingMaxMillis, + case <-ctx.Done(): + log.Warnf("Deadline: the agent %s did not exit after deadline %s; Killing it.", name, deadline) + err := execCtx.supervisor.Kill(context.Background(), &supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Deadline: time.Now().Add(time.Millisecond * supervisorBlockingMaxMillis), }) if err != nil { // We are not reporting the error upstream because we will anyway // shut the domain out at the end of the shutdown sequence - log.WithError(err).Warn("Failed sending Kill signal to runtime") + log.WithError(err).Warn("Failed sending Kill signal to agent") } case <-exitedChannel: } @@ -270,11 +284,14 @@ func (s *shutdownContext) shutdownAgents(execCtx *rapidContext, start time.Time, go func(name string) { defer wg.Done() - execCtx.supervisor.Kill(&supvmodel.KillRequest{ - Domain: RuntimeDomain, - Name: name, - Timeout: &supervisorBlockingMaxMillis, + err := execCtx.supervisor.Kill(context.Background(), &supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Deadline: time.Now().Add(time.Millisecond * supervisorBlockingMaxMillis), }) + if err != nil { + log.WithError(err).Warn("Failed sending Kill signal to agent") + } }(name) } } @@ -295,7 +312,6 @@ func (s *shutdownContext) shutdown(execCtx *rapidContext, deadlineNs int64, reas execCtx.appCtx.Delete(appctx.AppCtxFirstFatalErrorKey) runtimeDomainProfiler := &metering.ExtensionsResetDurationProfiler{} - supervisorBlockingMaxMillis := uint64(supervisorBlockingMaxMillis) // We do not spend any compute time on runtime graceful shutdown if there are no agents if execCtx.registrationService.CountAgents() == 0 { @@ -305,10 +321,10 @@ func (s *shutdownContext) shutdown(execCtx *rapidContext, deadlineNs int64, reas if found { log.Debug("SIGKILLing the runtime as no agents are registered.") - err = execCtx.supervisor.Kill(&supvmodel.KillRequest{ - Domain: RuntimeDomain, - Name: name, - Timeout: &supervisorBlockingMaxMillis, + err = execCtx.supervisor.Kill(context.Background(), &supvmodel.KillRequest{ + Domain: RuntimeDomain, + Name: name, + Deadline: time.Now().Add(time.Millisecond * supervisorBlockingMaxMillis), }) if err != nil { // We are not reporting the error upstream because we will anyway @@ -340,27 +356,13 @@ func (s *shutdownContext) shutdown(execCtx *rapidContext, deadlineNs int64, reas runtimeDomainProfiler.NumAgentsRegisteredForShutdown = len(s.agentsAwaitingExit) } - log.Info("Stopping runtime domain") - err = execCtx.supervisor.Stop(&supvmodel.StopRequest{ - Domain: RuntimeDomain, - Timeout: &supervisorBlockingMaxMillis, - }) - if err != nil { - log.WithError(err).Error("Failed shutting runtime domain down") - } else { - log.Info("Waiting for runtime domain processes termination") - s.clearExitedChannel() - log.Info("Stopping operator domain") - err = execCtx.supervisor.Stop(&supvmodel.StopRequest{ - Domain: OperatorDomain, - Timeout: &supervisorBlockingMaxMillis, - }) - if err != nil { - log.WithError(err).Error("Failed shutting operator domain down") - } + + log.Info("Waiting for runtime domain processes termination") + if err := s.clearExitedChannel(); err != nil { + log.Error(err) } runtimeDomainProfiler.Stop() - extensionsRestMs, timeout := runtimeDomainProfiler.CalculateExtensionsResetMs() - return extensionsRestMs, timeout, err + extensionsResetMs, timeout := runtimeDomainProfiler.CalculateExtensionsResetMs() + return extensionsResetMs, timeout, err } diff --git a/lambda/rapid/start_test.go b/lambda/rapid/start_test.go deleted file mode 100644 index ffb446f..0000000 --- a/lambda/rapid/start_test.go +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapid - -import ( - "context" - "fmt" - "go.amzn.com/lambda/core" - "io" - "net/http" - "regexp" - "strconv" - "strings" - "testing" - "time" - - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapi" - "go.amzn.com/lambda/testdata" - - "github.com/stretchr/testify/assert" -) - -func BenchmarkChannelsSelect10(b *testing.B) { - c1 := make(chan int) - c2 := make(chan int) - c3 := make(chan int) - c4 := make(chan int) - c5 := make(chan int) - c6 := make(chan int) - c7 := make(chan int) - c8 := make(chan int) - c9 := make(chan int) - c10 := make(chan int) - - for n := 0; n < b.N; n++ { - select { - case <-c1: - break - case <-c2: - break - case <-c3: - break - case <-c4: - break - case <-c5: - break - case <-c6: - break - case <-c7: - break - case <-c8: - break - case <-c9: - break - case <-c10: - break - default: - break - } - } -} - -func BenchmarkChannelsSelect2(b *testing.B) { - c1 := make(chan int) - c2 := make(chan int) - - for n := 0; n < b.N; n++ { - select { - case <-c1: - break - case <-c2: - break - default: - break - } - } -} - -func TestGetExtensionNamesWithNoExtensions(t *testing.T) { - rs := core.NewRegistrationService(nil, nil) - - c := &rapidContext{ - registrationService: rs, - } - - assert.Equal(t, "", c.GetExtensionNames()) -} - -func TestGetExtensionNamesWithMultipleExtensions(t *testing.T) { - rs := core.NewRegistrationService(nil, nil) - _, _ = rs.CreateExternalAgent("Example1") - _, _ = rs.CreateInternalAgent("Example2") - _, _ = rs.CreateExternalAgent("Example3") - _, _ = rs.CreateInternalAgent("Example4") - - c := &rapidContext{ - registrationService: rs, - } - - r := regexp.MustCompile(`^(Example\d;){3}(Example\d)$`) - assert.True(t, r.MatchString(c.GetExtensionNames())) -} - -func TestGetExtensionNamesWithTooManyExtensions(t *testing.T) { - rs := core.NewRegistrationService(nil, nil) - for i := 10; i < 60; i++ { - _, _ = rs.CreateExternalAgent("E" + strconv.Itoa(i)) - } - - c := &rapidContext{ - registrationService: rs, - } - - output := c.GetExtensionNames() - - r := regexp.MustCompile(`^(E\d\d;){30}(E\d\d)$`) - assert.LessOrEqual(t, len(output), maxExtensionNamesLength) - assert.True(t, r.MatchString(output)) -} - -func TestGetExtensionNamesWithTooLongExtensionName(t *testing.T) { - rs := core.NewRegistrationService(nil, nil) - for i := 10; i < 60; i++ { - _, _ = rs.CreateExternalAgent(strings.Repeat("E", 130)) - } - - c := &rapidContext{ - registrationService: rs, - } - - assert.Equal(t, "", c.GetExtensionNames()) -} - -// This test confirms our assumption that http client can establish a tcp connection -// to a listening server. -func TestListen(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.ConfigureForInvoke(context.Background(), &interop.Invoke{ID: "ID", DeadlineNs: "1", Payload: strings.NewReader("MyTest")}) - - ctx := context.Background() - telemetryAPIEnabled := true - server := rapi.NewServer("127.0.0.1", 0, flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, telemetryAPIEnabled, flowTest.TelemetrySubscription, flowTest.TelemetrySubscription, flowTest.CredentialsService, flowTest.EventsAPI) - err := server.Listen() - assert.NoError(t, err) - - defer server.Close() - - go func() { - time.Sleep(time.Second) - fmt.Println("Serving...") - server.Serve(ctx) - }() - - done := make(chan struct{}) - - go func() { - fmt.Println("Connecting...") - resp, err1 := http.Get(fmt.Sprintf("http://%s:%d/2018-06-01/runtime/invocation/next", server.Host(), server.Port())) - assert.Nil(t, err1) - - body, err2 := io.ReadAll(resp.Body) - assert.Nil(t, err2) - - assert.Equal(t, "MyTest", string(body)) - - done <- struct{}{} - }() - - <-done -} - -func TestInferSandboxInitTypeOnDemand(t *testing.T) { - initCachingEnabled := false - sandboxType := interop.SandboxClassic - initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) - assert.Equal(t, "on-demand", initSource) -} - -func TestInferSandboxInitTypeProvisionedConcurrency(t *testing.T) { - initCachingEnabled := false - sandboxType := interop.SandboxPreWarmed - initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) - assert.Equal(t, "provisioned-concurrency", initSource) -} - -func TestInferSandboxInitTypeInitCaching(t *testing.T) { - initCachingEnabled := true - sandboxType := interop.SandboxClassic - initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) - assert.Equal(t, "snap-start", initSource) -} - -func TestInferSandboxInitTypeInitCachingWithPC(t *testing.T) { - initCachingEnabled := true - sandboxType := interop.SandboxPreWarmed - initSource := interop.InferTelemetryInitSource(initCachingEnabled, sandboxType) - assert.Equal(t, "snap-start", initSource) -} diff --git a/lambda/rapidcore/bootstrap.go b/lambda/rapidcore/bootstrap.go deleted file mode 100644 index 165f532..0000000 --- a/lambda/rapidcore/bootstrap.go +++ /dev/null @@ -1,205 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapidcore - -import ( - "fmt" - "os" - "path" - "path/filepath" - "strings" - - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" - - log "github.com/sirupsen/logrus" -) - -type LogFormatter func(error) string -type BootstrapError func() (fatalerror.ErrorType, LogFormatter) - -// Bootstrap represents a list of executable bootstrap -// candidates in order of priority and exec metadata -type Bootstrap struct { - runtimeDomainRoot string - orderedLookupPaths []string - validCmd []string - workingDir string - cmdCandidates [][]string - extraFiles []*os.File - bootstrapError BootstrapError -} - -// Validate interface compliance -var _ interop.Bootstrap = (*Bootstrap)(nil) - -// NewBootstrap returns an instance of bootstrap defined by given params -func NewBootstrap(cmdCandidates [][]string, currentWorkingDir string, runtimeDomainRoot string) *Bootstrap { - var orderedLookupBootstrapPaths []string - for _, args := range cmdCandidates { - // Empty args is an error, but we want to detect it later (in Cmd() call) when we are able to report a descriptive error - if len(args) != 0 { - orderedLookupBootstrapPaths = append(orderedLookupBootstrapPaths, args[0]) - } - } - - if currentWorkingDir == "" { - // use the root directory as the default working directory - currentWorkingDir = "/" - } - - if runtimeDomainRoot == "" { - runtimeDomainRoot = "/" - } - - return &Bootstrap{ - orderedLookupPaths: orderedLookupBootstrapPaths, - workingDir: currentWorkingDir, - cmdCandidates: cmdCandidates, - runtimeDomainRoot: runtimeDomainRoot, - } -} - -func NewBootstrapSingleCmd(cmd []string, currentWorkingDir string, runtimeDomainRoot string) *Bootstrap { - if currentWorkingDir == "" { - // use the root directory as the default working directory - currentWorkingDir = "/" - } - if runtimeDomainRoot == "" { - runtimeDomainRoot = "/" - } - - // a single candidate command makes it automatically valid - return &Bootstrap{ - validCmd: cmd, - workingDir: currentWorkingDir, - runtimeDomainRoot: runtimeDomainRoot, - } -} - -// locateBootstrap sets the first occurrence of an -// actual bootstrap, given a list of possible files -func (b *Bootstrap) locateBootstrap() error { - for i, bootstrapCandidate := range b.orderedLookupPaths { - // validate path relatively to the domain's root - candidatPath := path.Join(b.runtimeDomainRoot, bootstrapCandidate) - file, err := os.Stat(candidatPath) - if err != nil { - if !os.IsNotExist(err) { - log.WithError(err).Warnf("Could not validate %s. Ignoring it.", bootstrapCandidate) - } - continue - } - if file.IsDir() { - log.Warnf("%s is a directory. Ignoring it", bootstrapCandidate) - continue - } - b.validCmd = b.cmdCandidates[i] - return nil - } - log.WithField("bootstrapPathsChecked", b.orderedLookupPaths).Warn("Couldn't find valid bootstrap(s)") - return fmt.Errorf("Couldn't find valid bootstrap(s): %s", b.orderedLookupPaths) -} - -// Cmd returns the args of bootstrap, relative to the -// chroot idenfied by `root`, where args[0] -// is the path to executable -func (b *Bootstrap) Cmd() ([]string, error) { - if len(b.validCmd) > 0 { - return b.validCmd, nil - } - - if err := b.locateBootstrap(); err != nil { - return []string{}, err - } - - log.Debug("Located runtime bootstrap", b.validCmd[0]) - return b.validCmd, nil -} - -// Env returns the environment variables available to -// the bootstrap process -func (b *Bootstrap) Env(e interop.EnvironmentVariables) map[string]string { - return e.RuntimeExecEnv() -} - -// Cwd returns the working directory of the bootstrap process -// The path is validated against the chroot identified by `root` -func (b *Bootstrap) Cwd() (string, error) { - if !filepath.IsAbs(b.workingDir) { - return "", fmt.Errorf("the working directory '%s' is invalid, it needs to be an absolute path", b.workingDir) - } - - // evaluate the path relatively to the domain's mnt namespace root - domainPath := path.Join(b.runtimeDomainRoot, b.workingDir) - if _, err := os.Stat(domainPath); os.IsNotExist(err) { - return "", fmt.Errorf("the working directory doesn't exist: %s", domainPath) - } - - return b.workingDir, nil -} - -// SetExtraFiles sets the extra file descriptors apart from 1 & 2 to be passed to runtime -func (b *Bootstrap) SetExtraFiles(extraFiles []*os.File) { - b.extraFiles = extraFiles -} - -// ExtraFiles returns the extra file descriptors apart from 1 & 2 to be passed to runtime -func (b *Bootstrap) ExtraFiles() []*os.File { - return b.extraFiles -} - -// CachedFatalError returns a bootstrap error that occurred during startup and before init -// so that it can be reported back to the customer in a later phase -func (b *Bootstrap) CachedFatalError(err error) (fatalerror.ErrorType, string, bool) { - if b.bootstrapError == nil { - return fatalerror.ErrorType(""), "", false - } - - fatalError, logFunc := b.bootstrapError() - - return fatalError, logFunc(err), true -} - -// SetCachedFatalError sets a cached fatal error that occurred during startup and before init -// so that it can be reported back to the customer in a later phase -func (b *Bootstrap) SetCachedFatalError(bootstrapErrFn BootstrapError) { - b.bootstrapError = bootstrapErrFn -} - -// BootstrapErrInvalidLCISTaskConfig represents an error while parsing LCIS task config -func BootstrapErrInvalidLCISTaskConfig(err error) BootstrapError { - return func() (fatalerror.ErrorType, LogFormatter) { - return fatalerror.InvalidTaskConfig, SupernovaInvalidTaskConfigRepr(err) - } -} - -// BootstrapErrInvalidLCISEntrypoint represents an invalid LCIS entrypoint error -func BootstrapErrInvalidLCISEntrypoint(entrypoint []string, cmd []string, workingdir string) BootstrapError { - return func() (fatalerror.ErrorType, LogFormatter) { - return fatalerror.InvalidEntrypoint, SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) - } -} - -func BootstrapErrInvalidLCISWorkingDir(entrypoint []string, cmd []string, workingdir string) BootstrapError { - return func() (fatalerror.ErrorType, LogFormatter) { - return fatalerror.InvalidWorkingDir, SupernovaLaunchErrorRepr(entrypoint, cmd, workingdir) - } -} - -func SupernovaInvalidTaskConfigRepr(err error) func(error) string { - return func(unused error) string { - return fmt.Sprintf("IMAGE\tInvalid task config: %s", err) - } -} - -func SupernovaLaunchErrorRepr(entrypoint []string, cmd []string, workingDir string) func(error) string { - return func(err error) string { - return fmt.Sprintf("IMAGE\tLaunch error: %s\tEntrypoint: [%s]\tCmd: [%s]\tWorkingDir: [%s]", - err, - strings.Join(entrypoint, ","), - strings.Join(cmd, ","), - workingDir) - } -} diff --git a/lambda/rapidcore/bootstrap_test.go b/lambda/rapidcore/bootstrap_test.go deleted file mode 100644 index b43520d..0000000 --- a/lambda/rapidcore/bootstrap_test.go +++ /dev/null @@ -1,280 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapidcore - -import ( - "os" - "path" - "path/filepath" - "reflect" - "testing" - - "go.amzn.com/lambda/rapidcore/env" - - "github.com/stretchr/testify/assert" -) - -func TestBootstrap(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "lcis-test-invalid-bootstrap") - assert.NoError(t, err) - defer os.RemoveAll(tmpDir) - - tmpFile, err := os.CreateTemp("", "lcis-test-bootstrap") - assert.NoError(t, err) - defer os.Remove(tmpFile.Name()) - - // Setup cmd candidates - nonExistent := []string{"/foo/bar/baz"} - dir := []string{tmpDir, "--arg1", "foo"} - file := []string{tmpFile.Name(), "--arg1 s", "foo"} - cmdCandidates := [][]string{nonExistent, dir, file} - - // Setup working dir - cwd, err := os.Getwd() - assert.NoError(t, err) - - // Setup environment - environment := env.NewEnvironment() - environment.StoreRuntimeAPIEnvironmentVariable("host:port") - environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") - - // Test - b := NewBootstrap(cmdCandidates, cwd, "") - bCwd, err := b.Cwd() - assert.NoError(t, err) - assert.Equal(t, cwd, bCwd) - assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) - - cmd, err := b.Cmd() - assert.NoError(t, err) - assert.Equal(t, file, cmd) -} - -// When running bootstraps in separate mount namespaces -// we want to verify and discover paths relative to -// a root different from "/" -func TestBootstrapChroot(t *testing.T) { - tmpRoot, err := os.MkdirTemp(os.TempDir(), "domain-root") - assert.NoError(t, err) - defer os.RemoveAll(tmpRoot) - tmpDir, err := os.MkdirTemp(tmpRoot, "lcis-test-invalid-bootstrap") - assert.NoError(t, err) - defer os.RemoveAll(tmpDir) - - tmpFile, err := os.CreateTemp(tmpRoot, "lcis-test-bootstrap") - assert.NoError(t, err) - defer os.Remove(tmpFile.Name()) - - // Setup cmd candidates - nonExistent := []string{"/foo/bar/baz"} - baseName := filepath.Base(tmpDir) - dir := []string{"/" + baseName, "--arg1", "foo"} - baseName = filepath.Base(tmpFile.Name()) - file := []string{"/" + baseName, "--arg1 s", "foo"} - cmdCandidates := [][]string{nonExistent, dir, file} - - // Setup working dir - cwd, err := os.MkdirTemp(tmpRoot, "cwd") - assert.NoError(t, err) - defer os.RemoveAll(cwd) - - // Setup environment - environment := env.NewEnvironment() - environment.StoreRuntimeAPIEnvironmentVariable("host:port") - environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") - - // Test - baseName = filepath.Base(cwd) - b := NewBootstrap(cmdCandidates, "/"+baseName, tmpRoot) - bCwd, err := b.Cwd() - assert.NoError(t, err) - assert.Equal(t, cwd, path.Join(tmpRoot, bCwd)) - assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) - - cmd, err := b.Cmd() - assert.NoError(t, err) - assert.Equal(t, file, cmd) -} - -func TestBootstrapEmptyCandidate(t *testing.T) { - // we expect newBootstrap to succeed and bootstrap.Cmd() to fail. - // We want to postpone the failure to be able to propagate error description to slicer and write it to customer log - invalidBootstrapCandidate := []string{} - bs := NewBootstrap([][]string{invalidBootstrapCandidate}, "/", "") - _, err := bs.Cmd() - assert.Error(t, err) -} - -func TestBootstrapChrootNonExistingRoot(t *testing.T) { - invalidBootstrapCandidate := []string{"/bin/bash", "-c"} - bs := NewBootstrap([][]string{invalidBootstrapCandidate}, "/", "/does_not_exist") - _, err := bs.Cmd() - assert.Error(t, err) -} - -func TestBootstrapSingleCmd(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "lcis-test-invalid-bootstrap") - assert.NoError(t, err) - defer os.RemoveAll(tmpDir) - - tmpFile, err := os.CreateTemp("", "lcis-test-bootstrap") - assert.NoError(t, err) - defer os.Remove(tmpFile.Name()) - - // Setup single cmd candidate - file := []string{tmpFile.Name(), "--arg1 s", "foo"} - cmdCandidate := file - - // Setup working dir - cwd, err := os.Getwd() - assert.NoError(t, err) - - // Setup environment - environment := env.NewEnvironment() - environment.StoreRuntimeAPIEnvironmentVariable("host:port") - environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") - - // Test - b := NewBootstrapSingleCmd(cmdCandidate, cwd, "") - bCwd, err := b.Cwd() - assert.NoError(t, err) - assert.Equal(t, cwd, bCwd) - assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) - - cmd, err := b.Cmd() - assert.NoError(t, err) - assert.Equal(t, file, cmd) -} - -func TestBootstrapSingleCmdNonExistingCandidate(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "lcis-test-invalid-bootstrap") - assert.NoError(t, err) - defer os.RemoveAll(tmpDir) - - // Setup inexistent single cmd candidate - file := []string{"/foo/bar", "--arg1 s", "foo"} - cmdCandidate := file - - // Setup working dir - cwd, err := os.Getwd() - assert.NoError(t, err) - - // Setup environment - environment := env.NewEnvironment() - environment.StoreRuntimeAPIEnvironmentVariable("host:port") - environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") - - // Test - b := NewBootstrapSingleCmd(cmdCandidate, cwd, "") - bCwd, err := b.Cwd() - assert.NoError(t, err) - assert.Equal(t, cwd, bCwd) - assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) - - // No validations run against single candidates - cmd, err := b.Cmd() - assert.NoError(t, err) - assert.Equal(t, file, cmd) -} - -// Test our ability to locate bootstrap files in the file system -func TestFindCustomRuntimeIfExists(t *testing.T) { - tmpFile, err := os.CreateTemp(os.TempDir(), "tmp-") - if err != nil { - t.Fatal("Cannot create temporary file", err) - } - defer os.Remove(tmpFile.Name()) - - tmpFile2, err := os.CreateTemp(os.TempDir(), "tmp-") - if err != nil { - t.Fatal("Cannot create temporary file", err) - } - defer os.Remove(tmpFile2.Name()) - - // one bootstrap argument was given and it exists - bootstrap := NewBootstrap([][]string{{tmpFile.Name()}}, "/", "") - cmd, err := bootstrap.Cmd() - assert.NoError(t, err) - assert.Equal(t, []string{tmpFile.Name()}, cmd) - assert.Nil(t, err) - - // two bootstrap arguments given, both exist but first one is returned - bootstrap = NewBootstrap([][]string{{tmpFile.Name()}, {tmpFile2.Name()}}, "/", "") - cmd, err = bootstrap.Cmd() - assert.NoError(t, err) - assert.Equal(t, []string{tmpFile.Name()}, cmd) - assert.Nil(t, err) - - // two bootstrap arguments given, first one does not exist, second exists and is returned - bootstrap = NewBootstrap([][]string{{"mk"}, {tmpFile2.Name()}}, "/", "") - cmd, err = bootstrap.Cmd() - assert.NoError(t, err) - assert.Equal(t, []string{tmpFile2.Name()}, cmd) - assert.Nil(t, err) - - // two bootstrap arguments given, none exists - bootstrap = NewBootstrap([][]string{{"mk"}, {"mk2"}}, "/", "") - cmd, err = bootstrap.Cmd() - assert.EqualError(t, err, "Couldn't find valid bootstrap(s): [mk mk2]") - assert.Equal(t, []string{}, cmd) -} - -func TestCwdIsAbsolute(t *testing.T) { - tmpFile, err := os.CreateTemp(os.TempDir(), "tmp-") - if err != nil { - t.Fatal("Cannot create temporary file", err) - } - defer os.Remove(tmpFile.Name()) - - cmdCandidates := [][]string{{tmpFile.Name()}} - - // no errors when currentWorkingDir is absolute - bootstrap := NewBootstrap(cmdCandidates, "/tmp", "") - cwd, err := bootstrap.Cwd() - assert.Nil(t, err) - assert.Equal(t, "/tmp", cwd) - - bootstrap = NewBootstrap(cmdCandidates, "tmp", "") - _, err = bootstrap.Cwd() - assert.EqualError(t, err, "the working directory 'tmp' is invalid, it needs to be an absolute path") - - bootstrap = NewBootstrap(cmdCandidates, "./", "") - _, err = bootstrap.Cwd() - assert.EqualError(t, err, "the working directory './' is invalid, it needs to be an absolute path") -} - -func TestBootstrapMissingWorkingDirectory(t *testing.T) { - tmpFile, err := os.CreateTemp(os.TempDir(), "cwd-test-bootstrap") - assert.NoError(t, err) - defer os.Remove(tmpFile.Name()) - - tmpDir, err := os.MkdirTemp("", "cwd-test") - assert.NoError(t, err) - defer os.RemoveAll(tmpDir) - - // cwd argument exists - bootstrap := NewBootstrap([][]string{{tmpFile.Name()}}, tmpDir, "") - cwd, err := bootstrap.Cwd() - assert.Equal(t, cwd, tmpDir) - assert.NoError(t, err) - - // cwd argument doesn't exist - bootstrap = NewBootstrap([][]string{{tmpFile.Name()}}, "/foo", "") - _, err = bootstrap.Cwd() - assert.EqualError(t, err, "the working directory doesn't exist: /foo") -} - -func TestDefaultWorkeringDirectory(t *testing.T) { - bootstrap := NewBootstrap([][]string{{}}, "", "") - cwd, err := bootstrap.Cwd() - assert.NoError(t, err) - assert.Equal(t, "/", cwd) -} - -func TestBootstrapSingleCmdDefaultWorkingDir(t *testing.T) { - b := NewBootstrapSingleCmd([]string{}, "", "") - bCwd, err := b.Cwd() - assert.NoError(t, err) - assert.Equal(t, "/", bCwd) -} diff --git a/lambda/rapidcore/env/environment.go b/lambda/rapidcore/env/environment.go index be0584c..fbe0ef2 100644 --- a/lambda/rapidcore/env/environment.go +++ b/lambda/rapidcore/env/environment.go @@ -6,9 +6,7 @@ package env import ( "fmt" "os" - "strconv" "strings" - "syscall" log "github.com/sirupsen/logrus" ) @@ -16,37 +14,24 @@ import ( const runtimeAPIAddressKey = "AWS_LAMBDA_RUNTIME_API" const handlerEnvKey = "_HANDLER" const executionEnvKey = "AWS_EXECUTION_ENV" +const taskRootEnvKey = "LAMBDA_TASK_ROOT" +const runtimeDirEnvKey = "LAMBDA_RUNTIME_DIR" // Environment holds env vars for runtime, agents, and for // internal use, parsed during startup and from START msg type Environment struct { - RAPID map[string]string // env vars req'd internally by RAPID - Platform map[string]string // reserved platform env vars as per Lambda docs - Runtime map[string]string // reserved runtime env vars as per Lambda docs - PlatformUnreserved map[string]string // unreserved platform env vars that customers can override - Credentials map[string]string // reserved env vars for credentials, set on INIT - Customer map[string]string // customer & unreserved platform env vars, set on INIT + Customer map[string]string // customer & unreserved platform env vars, set on INIT + + rapid map[string]string // env vars req'd internally by RAPID + platform map[string]string // reserved platform env vars as per Lambda docs + runtime map[string]string // reserved runtime env vars as per Lambda docs + platformUnreserved map[string]string // unreserved platform env vars that customers can override + credentials map[string]string // reserved env vars for credentials, set on INIT runtimeAPISet bool initEnvVarsSet bool } -// RapidConfig holds config req'd for RAPID's internal -// operation, parsed from internal env vars. -type RapidConfig struct { - SbID string - LogFd int - ShmFd int - CtrlFd int - CnslFd int - DirectInvokeFd int - LambdaTaskRoot string - XrayDaemonAddress string - PreLoadTimeNs int64 - FunctionName string - TelemetryAPIPassphrase string -} - func lookupEnv(keys map[string]bool) map[string]string { res := map[string]string{} for key := range keys { @@ -61,13 +46,13 @@ func lookupEnv(keys map[string]bool) map[string]string { // NewEnvironment parses environment variables into an Environment object func NewEnvironment() *Environment { return &Environment{ - RAPID: lookupEnv(predefinedInternalEnvVarKeys()), - Platform: lookupEnv(predefinedPlatformEnvVarKeys()), - Runtime: lookupEnv(predefinedRuntimeEnvVarKeys()), - PlatformUnreserved: lookupEnv(predefinedPlatformUnreservedEnvVarKeys()), + rapid: lookupEnv(predefinedInternalEnvVarKeys()), + platform: lookupEnv(predefinedPlatformEnvVarKeys()), + runtime: lookupEnv(predefinedRuntimeEnvVarKeys()), + platformUnreserved: lookupEnv(predefinedPlatformUnreservedEnvVarKeys()), - Credentials: map[string]string{}, Customer: map[string]string{}, + credentials: map[string]string{}, runtimeAPISet: false, initEnvVarsSet: false, @@ -77,44 +62,49 @@ func NewEnvironment() *Environment { // StoreRuntimeAPIEnvironmentVariable stores value for AWS_LAMBDA_RUNTIME_API func (e *Environment) StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress string) { - e.Platform[runtimeAPIAddressKey] = runtimeAPIAddress + e.platform[runtimeAPIAddressKey] = runtimeAPIAddress e.runtimeAPISet = true } -// GetHandler turns the current setting for handler -func (e *Environment) GetHandler() string { - return e.Runtime[handlerEnvKey] -} - // SetHandler sets _HANDLER env variable value for Runtime func (e *Environment) SetHandler(handler string) { - e.Runtime[handlerEnvKey] = handler + e.runtime[handlerEnvKey] = handler } // GetExecutionEnv returns the current setting for AWS_EXECUTION_ENV func (e *Environment) GetExecutionEnv() string { - return e.Runtime[executionEnvKey] + return e.runtime[executionEnvKey] } // SetExecutionEnv sets AWS_EXECUTION_ENV variable value for Runtime func (e *Environment) SetExecutionEnv(executionEnv string) { - e.Runtime[executionEnvKey] = executionEnv + e.runtime[executionEnvKey] = executionEnv +} + +// SetTaskRoot sets the LAMBDA_TASK_ROOT environment variable for Runtime +func (e *Environment) SetTaskRoot(taskRoot string) { + e.runtime[taskRootEnvKey] = taskRoot +} + +// SetRuntimeDir sets the LAMBDA_RUNTIME_DIR environment variable for Runtime +func (e *Environment) SetRuntimeDir(runtimeDir string) { + e.runtime[runtimeDirEnvKey] = runtimeDir } // StoreEnvironmentVariablesFromInit sets the environment variables // for credentials & _HANDLER which are received in the START message func (e *Environment) StoreEnvironmentVariablesFromInit(customerEnv map[string]string, handler, awsKey, awsSecret, awsSession, funcName, funcVer string) { - e.Credentials["AWS_ACCESS_KEY_ID"] = awsKey - e.Credentials["AWS_SECRET_ACCESS_KEY"] = awsSecret - e.Credentials["AWS_SESSION_TOKEN"] = awsSession + e.credentials["AWS_ACCESS_KEY_ID"] = awsKey + e.credentials["AWS_SECRET_ACCESS_KEY"] = awsSecret + e.credentials["AWS_SESSION_TOKEN"] = awsSession e.storeNonCredentialEnvironmentVariablesFromInit(customerEnv, handler, funcName, funcVer) } func (e *Environment) StoreEnvironmentVariablesFromInitForInitCaching(host string, port int, customerEnv map[string]string, handler, funcName, funcVer, token string) { - e.Credentials["AWS_CONTAINER_CREDENTIALS_FULL_URI"] = fmt.Sprintf("http://%s:%d/2021-04-23/credentials", host, port) - e.Credentials["AWS_CONTAINER_AUTHORIZATION_TOKEN"] = token + e.credentials["AWS_CONTAINER_CREDENTIALS_FULL_URI"] = fmt.Sprintf("http://%s:%d/2021-04-23/credentials", host, port) + e.credentials["AWS_CONTAINER_AUTHORIZATION_TOKEN"] = token e.storeNonCredentialEnvironmentVariablesFromInit(customerEnv, handler, funcName, funcVer) } @@ -125,11 +115,11 @@ func (e *Environment) storeNonCredentialEnvironmentVariablesFromInit(customerEnv } if funcName != "" { - e.Platform["AWS_LAMBDA_FUNCTION_NAME"] = funcName + e.platform["AWS_LAMBDA_FUNCTION_NAME"] = funcName } if funcVer != "" { - e.Platform["AWS_LAMBDA_FUNCTION_VERSION"] = funcVer + e.platform["AWS_LAMBDA_FUNCTION_VERSION"] = funcVer } e.mergeCustomerEnvironmentVariables(customerEnv) // overrides env vars from CLI options @@ -154,7 +144,7 @@ func (e *Environment) RuntimeExecEnv() map[string]string { log.Fatal("credentials, customer and runtime API address must be set") } - return mapUnion(e.Customer, e.PlatformUnreserved, e.Credentials, e.Runtime, e.Platform) + return mapUnion(e.Customer, e.platformUnreserved, e.credentials, e.runtime, e.platform) } // AgentExecEnv returns the key=value strings of all environment variables @@ -166,74 +156,7 @@ func (e *Environment) AgentExecEnv() map[string]string { excludedKeys := extensionExcludedKeys() excludeCondition := func(key string) bool { return excludedKeys[key] || strings.HasPrefix(key, "_") } - return mapExclude(mapUnion(e.Customer, e.Credentials, e.Platform), excludeCondition) -} - -// RAPIDInternalConfig returns the rapid config parsed from environment vars -func (e *Environment) RAPIDInternalConfig() RapidConfig { - return RapidConfig{ - SbID: e.getStrEnvVarOrDie(e.RAPID, "_LAMBDA_SB_ID"), - LogFd: e.getSocketEnvVarOrDie(e.RAPID, "_LAMBDA_LOG_FD"), - ShmFd: e.getSocketEnvVarOrDie(e.RAPID, "_LAMBDA_SHARED_MEM_FD"), - CtrlFd: e.getSocketEnvVarOrDie(e.RAPID, "_LAMBDA_CONTROL_SOCKET"), - CnslFd: e.getSocketEnvVarOrDie(e.RAPID, "_LAMBDA_CONSOLE_SOCKET"), - DirectInvokeFd: e.getOptionalSocketEnvVar(e.RAPID, "_LAMBDA_DIRECT_INVOKE_SOCKET"), - PreLoadTimeNs: e.getInt64EnvVarOrDie(e.RAPID, "_LAMBDA_RUNTIME_LOAD_TIME"), - LambdaTaskRoot: e.getStrEnvVarOrDie(e.Runtime, "LAMBDA_TASK_ROOT"), - XrayDaemonAddress: e.getStrEnvVarOrDie(e.PlatformUnreserved, "AWS_XRAY_DAEMON_ADDRESS"), - FunctionName: e.getStrEnvVarOrDie(e.Platform, "AWS_LAMBDA_FUNCTION_NAME"), - TelemetryAPIPassphrase: e.RAPID["_LAMBDA_TELEMETRY_API_PASSPHRASE"], // TODO: Die if not set - } -} - -func (e *Environment) getStrEnvVarOrDie(env map[string]string, name string) string { - val, ok := env[name] - if !ok { - log.WithField("name", name).Fatal("Environment variable is not set") - } - return val -} - -func (e *Environment) getInt64EnvVarOrDie(env map[string]string, name string) int64 { - strval := e.getStrEnvVarOrDie(env, name) - val, err := strconv.ParseInt(strval, 10, 64) - if err != nil { - log.WithError(err).WithField("name", name).Fatal("Unable to parse int env var.") - } - return val -} - -func (e *Environment) getIntEnvVarOrDie(env map[string]string, name string) int { - return int(e.getInt64EnvVarOrDie(env, name)) -} - -// getSocketEnvVarOrDie reads and returns an int value of the -// environment variable or dies, when unable to do so. -// It also makes CloseOnExec for this value. -func (e *Environment) getSocketEnvVarOrDie(env map[string]string, name string) int { - sock := e.getIntEnvVarOrDie(env, name) - syscall.CloseOnExec(sock) - return sock -} - -// returns -1 if env variable was not set. Exits if it holds unexpected (non-int) value -func (e *Environment) getOptionalSocketEnvVar(env map[string]string, name string) int { - val, found := env[name] - if !found { - return -1 - } - - sock, err := strconv.Atoi(val) - if err != nil { - log.WithError(err).WithField("name", name).Fatal("Unable to parse socket env var.") - } - - if sock < 0 { - log.WithError(err).WithField("name", name).Fatal("Negative socket descriptor value") - } - - syscall.CloseOnExec(sock) - return sock + return mapExclude(mapUnion(e.Customer, e.credentials, e.platform), excludeCondition) } func mapUnion(maps ...map[string]string) map[string]string { diff --git a/lambda/rapidcore/env/environment_test.go b/lambda/rapidcore/env/environment_test.go index ed3043c..04c0494 100644 --- a/lambda/rapidcore/env/environment_test.go +++ b/lambda/rapidcore/env/environment_test.go @@ -34,7 +34,7 @@ func TestRAPIDInternalConfig(t *testing.T) { os.Setenv("AWS_LAMBDA_FUNCTION_NAME", "a") os.Setenv("_LAMBDA_TELEMETRY_API_PASSPHRASE", "a") os.Setenv("_LAMBDA_DIRECT_INVOKE_SOCKET", "1") - NewEnvironment().RAPIDInternalConfig() + NewRapidConfig(NewEnvironment()) } func TestEnvironmentParsing(t *testing.T) { @@ -59,11 +59,11 @@ func TestEnvironmentParsing(t *testing.T) { env.StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress) env.StoreEnvironmentVariablesFromInit(customerEnv, runtimeEnvVal, credsEnvVal, credsEnvVal, credsEnvVal, platformEnvVal, platformEnvVal) - for _, val := range env.RAPID { + for _, val := range env.rapid { assert.Equal(t, internalEnvVal, val) } - for key, val := range env.Platform { + for key, val := range env.platform { if key == runtimeAPIAddressKey { assert.Equal(t, runtimeAPIAddress, val) } else { @@ -71,16 +71,16 @@ func TestEnvironmentParsing(t *testing.T) { } } - for _, val := range env.Runtime { + for _, val := range env.runtime { assert.Equal(t, runtimeEnvVal, val) } - for key, val := range env.Credentials { + for key, val := range env.credentials { assert.Equal(t, credsEnvVal, val) assert.NotContains(t, env.Customer, key) } - for _, val := range env.PlatformUnreserved { + for _, val := range env.platformUnreserved { assert.Equal(t, customerEnvVal, val) } @@ -94,10 +94,10 @@ func TestEnvironmentParsingUnsetPlatformAndInternalEnvVarKeysAreDeleted(t *testi os.Clearenv() env := NewEnvironment() - assert.Len(t, env.RAPID, 0) - assert.Len(t, env.Platform, 0) - assert.Len(t, env.PlatformUnreserved, 0) - assert.Len(t, env.Credentials, 0) // uninitialized + assert.Len(t, env.rapid, 0) + assert.Len(t, env.platform, 0) + assert.Len(t, env.platformUnreserved, 0) + assert.Len(t, env.credentials, 0) // uninitialized assert.Len(t, env.Customer, 0) // uninitialized } @@ -136,30 +136,30 @@ func TestRuntimeExecEnvironmentVariables(t *testing.T) { rapidEnvVarsSlice := envToSlice(rapidEnvVars) - for key := range env.RAPID { + for key := range env.rapid { assert.NotContains(t, rapidEnvKeys, key) } - for key, val := range env.Runtime { + for key, val := range env.runtime { assert.Contains(t, rapidEnvVarsSlice, key+"="+val) } - for key, val := range env.Platform { + for key, val := range env.platform { assert.Contains(t, rapidEnvVarsSlice, key+"="+val) } - for key, val := range env.PlatformUnreserved { + for key, val := range env.platformUnreserved { assert.Contains(t, rapidEnvVarsSlice, key+"="+val) assert.NotContains(t, env.Customer, key) } - for key, val := range env.Credentials { + for key, val := range env.credentials { assert.Contains(t, rapidEnvVarsSlice, key+"="+val) } for key, val := range env.Customer { assert.Contains(t, rapidEnvVarsSlice, key+"="+val) - assert.NotContains(t, env.PlatformUnreserved, key) + assert.NotContains(t, env.platformUnreserved, key) } } @@ -195,11 +195,11 @@ func TestRuntimeExecEnvironmentVariablesPriority(t *testing.T) { env.StoreEnvironmentVariablesFromCLIOptions(cliOptionsEnv) env.StoreEnvironmentVariablesFromInit(customerEnv, runtimeEnvVal, credsEnvVal, credsEnvVal, credsEnvVal, platformEnvVal, platformEnvVal) - assert.Equal(t, len(predefinedPlatformEnvVarKeys()), len(env.Platform)) - assert.Equal(t, len(predefinedCredentialsEnvVarKeys()), len(env.Credentials)) - assert.Equal(t, len(predefinedPlatformUnreservedEnvVarKeys()), len(env.PlatformUnreserved)) - assert.Equal(t, len(predefinedInternalEnvVarKeys()), len(env.RAPID)) - assert.Equal(t, len(predefinedRuntimeEnvVarKeys()), len(env.Runtime)) + assert.Equal(t, len(predefinedPlatformEnvVarKeys()), len(env.platform)) + assert.Equal(t, len(predefinedCredentialsEnvVarKeys()), len(env.credentials)) + assert.Equal(t, len(predefinedPlatformUnreservedEnvVarKeys()), len(env.platformUnreserved)) + assert.Equal(t, len(predefinedInternalEnvVarKeys()), len(env.rapid)) + assert.Equal(t, len(predefinedRuntimeEnvVarKeys()), len(env.runtime)) rapidEnvVars := envToSlice(env.RuntimeExecEnv()) @@ -266,15 +266,15 @@ func TestAgentExecEnvironmentVariables(t *testing.T) { agentEnvVarsSlice := envToSlice(agentEnvVars) - for key := range env.RAPID { + for key := range env.rapid { assert.NotContains(t, agentEnvKeys, key) } - for key, val := range env.Runtime { + for key, val := range env.runtime { assert.NotContains(t, agentEnvVarsSlice, key+"="+val) } - for key := range env.Platform { + for key := range env.platform { assert.Contains(t, agentEnvKeys, key) } @@ -282,11 +282,11 @@ func TestAgentExecEnvironmentVariables(t *testing.T) { assert.Contains(t, agentEnvKeys, key) } - for key, val := range env.Credentials { + for key, val := range env.credentials { assert.Contains(t, agentEnvVarsSlice, key+"="+val) } - assert.Contains(t, agentEnvVarsSlice, runtimeAPIAddressKey+"="+env.Platform[runtimeAPIAddressKey]) + assert.Contains(t, agentEnvVarsSlice, runtimeAPIAddressKey+"="+env.platform[runtimeAPIAddressKey]) } func TestStoreEnvironmentVariablesFromInitCaching(t *testing.T) { @@ -301,11 +301,11 @@ func TestStoreEnvironmentVariablesFromInitCaching(t *testing.T) { env.StoreEnvironmentVariablesFromInitForInitCaching("samplehost", 1234, customerEnv, handler, funcName, funcVer, token) - assert.Equal(t, fmt.Sprintf("http://%s:%d/2021-04-23/credentials", host, port), env.Credentials["AWS_CONTAINER_CREDENTIALS_FULL_URI"]) - assert.Equal(t, token, env.Credentials["AWS_CONTAINER_AUTHORIZATION_TOKEN"]) - assert.Equal(t, funcName, env.Platform["AWS_LAMBDA_FUNCTION_NAME"]) - assert.Equal(t, funcVer, env.Platform["AWS_LAMBDA_FUNCTION_VERSION"]) - assert.Equal(t, handler, env.Runtime["_HANDLER"]) + assert.Equal(t, fmt.Sprintf("http://%s:%d/2021-04-23/credentials", host, port), env.credentials["AWS_CONTAINER_CREDENTIALS_FULL_URI"]) + assert.Equal(t, token, env.credentials["AWS_CONTAINER_AUTHORIZATION_TOKEN"]) + assert.Equal(t, funcName, env.platform["AWS_LAMBDA_FUNCTION_NAME"]) + assert.Equal(t, funcVer, env.platform["AWS_LAMBDA_FUNCTION_VERSION"]) + assert.Equal(t, handler, env.runtime["_HANDLER"]) } func setAll(keys map[string]bool, value string) { diff --git a/lambda/rapidcore/env/rapidenv.go b/lambda/rapidcore/env/rapidenv.go new file mode 100644 index 0000000..bc1a6ad --- /dev/null +++ b/lambda/rapidcore/env/rapidenv.go @@ -0,0 +1,96 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package env + +import ( + "strconv" + "syscall" + + log "github.com/sirupsen/logrus" +) + +// RapidConfig holds config req'd for RAPID's internal +// operation, parsed from internal env vars. +// It should be build using `NewRapidConfig` to make sure that all the +// internal invariants are respected. +type RapidConfig struct { + SbID string + LogFd int + ShmFd int + CtrlFd int + CnslFd int + DirectInvokeFd int + LambdaTaskRoot string + XrayDaemonAddress string + PreLoadTimeNs int64 + FunctionName string + TelemetryAPIPassphrase string +} + +// Build the `RapidConfig` struct checking all the internal invariants +func NewRapidConfig(e *Environment) RapidConfig { + return RapidConfig{ + SbID: getStrEnvVarOrDie(e.rapid, "_LAMBDA_SB_ID"), + LogFd: getSocketEnvVarOrDie(e.rapid, "_LAMBDA_LOG_FD"), + ShmFd: getSocketEnvVarOrDie(e.rapid, "_LAMBDA_SHARED_MEM_FD"), + CtrlFd: getSocketEnvVarOrDie(e.rapid, "_LAMBDA_CONTROL_SOCKET"), + CnslFd: getSocketEnvVarOrDie(e.rapid, "_LAMBDA_CONSOLE_SOCKET"), + DirectInvokeFd: getOptionalSocketEnvVar(e.rapid, "_LAMBDA_DIRECT_INVOKE_SOCKET"), + PreLoadTimeNs: getInt64EnvVarOrDie(e.rapid, "_LAMBDA_RUNTIME_LOAD_TIME"), + LambdaTaskRoot: getStrEnvVarOrDie(e.runtime, "LAMBDA_TASK_ROOT"), + XrayDaemonAddress: getStrEnvVarOrDie(e.platformUnreserved, "AWS_XRAY_DAEMON_ADDRESS"), + FunctionName: getStrEnvVarOrDie(e.platform, "AWS_LAMBDA_FUNCTION_NAME"), + TelemetryAPIPassphrase: e.rapid["_LAMBDA_TELEMETRY_API_PASSPHRASE"], // TODO: Die if not set + } +} + +func getStrEnvVarOrDie(env map[string]string, name string) string { + val, ok := env[name] + if !ok { + log.WithField("name", name).Fatal("Environment variable is not set") + } + return val +} + +func getInt64EnvVarOrDie(env map[string]string, name string) int64 { + strval := getStrEnvVarOrDie(env, name) + val, err := strconv.ParseInt(strval, 10, 64) + if err != nil { + log.WithError(err).WithField("name", name).Fatal("Unable to parse int env var.") + } + return val +} + +func getIntEnvVarOrDie(env map[string]string, name string) int { + return int(getInt64EnvVarOrDie(env, name)) +} + +// getSocketEnvVarOrDie reads and returns an int value of the +// environment variable or dies, when unable to do so. +// It also makes CloseOnExec for this value. +func getSocketEnvVarOrDie(env map[string]string, name string) int { + sock := getIntEnvVarOrDie(env, name) + syscall.CloseOnExec(sock) + return sock +} + +// returns -1 if env variable was not set. Exits if it holds unexpected (non-int) value +func getOptionalSocketEnvVar(env map[string]string, name string) int { + val, found := env[name] + if !found { + return -1 + } + + sock, err := strconv.Atoi(val) + if err != nil { + log.WithError(err).WithField("name", name).Fatal("Unable to parse socket env var.") + } + + if sock < 0 { + log.WithError(err).WithField("name", name).Fatal("Negative socket descriptor value") + } + + syscall.CloseOnExec(sock) + return sock +} diff --git a/lambda/rapidcore/runtime_release.go b/lambda/rapidcore/runtime_release.go new file mode 100644 index 0000000..3875209 --- /dev/null +++ b/lambda/rapidcore/runtime_release.go @@ -0,0 +1,68 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapidcore + +import ( + "bufio" + "fmt" + "os" + "strings" +) + +type Logging string + +const ( + AmznStdout Logging = "amzn-stdout" + AmznStdoutTLV Logging = "amzn-stdout-tlv" +) + +// RuntimeRelease stores runtime identification data +type RuntimeRelease struct { + Name string + Version string + Logging Logging +} + +const RuntimeReleasePath = "/var/runtime/runtime-release" + +// GetRuntimeRelease reads Runtime identification data from config file and parses it into a struct +func GetRuntimeRelease(path string) (*RuntimeRelease, error) { + pairs, err := ParsePropertiesFile(path) + if err != nil { + return nil, fmt.Errorf("could not parse %s: %w", path, err) + } + + return &RuntimeRelease{pairs["NAME"], pairs["VERSION"], Logging(pairs["LOGGING"])}, nil +} + +// ParsePropertiesFile reads key-value pairs from file in newline-separated list of environment-like +// shell-compatible variable assignments. +// Format: https://www.freedesktop.org/software/systemd/man/os-release.html +// Value quotes are trimmed. Latest write wins for duplicated keys. +func ParsePropertiesFile(path string) (map[string]string, error) { + f, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("could not open %s: %w", path, err) + } + defer f.Close() + + pairs := make(map[string]string) + + s := bufio.NewScanner(f) + for s.Scan() { + if s.Text() == "" || strings.HasPrefix(s.Text(), "#") { + continue + } + k, v, found := strings.Cut(s.Text(), "=") + if !found { + return nil, fmt.Errorf("could not parse key-value pair from a line: %s", s.Text()) + } + pairs[k] = strings.Trim(v, "'\"") + } + if err := s.Err(); err != nil { + return nil, fmt.Errorf("failed to read properties file: %w", err) + } + + return pairs, nil +} diff --git a/lambda/rapidcore/runtime_release_test.go b/lambda/rapidcore/runtime_release_test.go new file mode 100644 index 0000000..7397140 --- /dev/null +++ b/lambda/rapidcore/runtime_release_test.go @@ -0,0 +1,97 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package rapidcore + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetRuntimeRelease(t *testing.T) { + tests := []struct { + name string + content string + want *RuntimeRelease + }{ + { + "simple", + "NAME=foo\nVERSION=bar\nLOGGING=baz\n", + &RuntimeRelease{"foo", "bar", "baz"}, + }, + { + "no trailing new line", + "NAME=foo\nVERSION=bar\nLOGGING=baz", + &RuntimeRelease{"foo", "bar", "baz"}, + }, + { + "nonexistent keys", + "LOGGING=baz\n", + &RuntimeRelease{"", "", "baz"}, + }, + { + "empty value", + "NAME=\nVERSION=\nLOGGING=\n", + &RuntimeRelease{"", "", ""}, + }, + { + "delimiter in value", + "NAME=Foo=Bar\nVERSION=bar\nLOGGING=baz\n", + &RuntimeRelease{"Foo=Bar", "bar", "baz"}, + }, + { + "empty file", + "", + &RuntimeRelease{"", "", ""}, + }, + { + "quotes", + "NAME=\"foo\"\nVERSION='bar'\n", + &RuntimeRelease{"foo", "bar", ""}, + }, + { + "double quotes", + "NAME='\"foo\"'\nVERSION=\"'bar'\"\n", + &RuntimeRelease{"foo", "bar", ""}, + }, + { + "empty lines", // production runtime-release files have empty line in the end of the file + "\nNAME=foo\n\nVERSION=bar\n\nLOGGING=baz\n\n", + &RuntimeRelease{"foo", "bar", "baz"}, + }, + { + "comments", + "# comment 1\nNAME=foo\n# comment 2\nVERSION=bar\n# comment 3\nLOGGING=baz\n# comment 4\n", + &RuntimeRelease{"foo", "bar", "baz"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f, err := os.CreateTemp(os.TempDir(), "runtime-release") + require.NoError(t, err) + _, err = f.WriteString(tt.content) + require.NoError(t, err) + got, err := GetRuntimeRelease(f.Name()) + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestGetRuntimeRelease_NotFound(t *testing.T) { + _, err := GetRuntimeRelease("/sys/not-exists") + assert.Error(t, err) +} + +func TestGetRuntimeRelease_InvalidLine(t *testing.T) { + f, err := os.CreateTemp(os.TempDir(), "runtime-release") + require.NoError(t, err) + _, err = f.WriteString("NAME=foo\nVERSION=bar\nLOGGING=baz\nSOMETHING") + require.NoError(t, err) + _, err = GetRuntimeRelease(f.Name()) + assert.Error(t, err) +} diff --git a/lambda/rapidcore/sandbox_api.go b/lambda/rapidcore/sandbox_api.go index 0c7052e..2e8d713 100644 --- a/lambda/rapidcore/sandbox_api.go +++ b/lambda/rapidcore/sandbox_api.go @@ -4,6 +4,9 @@ package rapidcore import ( + "bytes" + + "go.amzn.com/lambda/extensions" "go.amzn.com/lambda/interop" ) @@ -14,23 +17,26 @@ type SandboxContext struct { rapidCtx interop.RapidContext handler string runtimeAPIAddress string - - InvokeReceivedTime int64 - InvokeResponseMetrics *interop.InvokeResponseMetrics } +// initContext and its methods model the initialization lifecycle +// of the Sandbox, which persist across invocations type initContext struct { - initSuccessChan chan interop.InitSuccess - initFailureChan chan interop.InitFailure - rapidCtx interop.RapidContext - sbInfoFromInit interop.SandboxInfoFromInit // contains data that needs to be persisted from init for suppressed inits during invoke + initSuccessChan chan interop.InitSuccess + initFailureChan chan interop.InitFailure + rapidCtx interop.RapidContext + sbInfoFromInit interop.SandboxInfoFromInit // contains data that needs to be persisted from init for suppressed inits during invoke + invokeRequestBuffer *bytes.Buffer // byte buffer used to store the invoke request rendered to runtime (reused until reset) } +// invokeContext and its methods model the invocation lifecycle type invokeContext struct { - rapidCtx interop.RapidContext - invokeRequestChan chan *interop.Invoke - invokeSuccessChan chan interop.InvokeSuccess - invokeFailureChan chan interop.InvokeFailure + rapidCtx interop.RapidContext + invokeRequestChan chan *interop.Invoke + invokeSuccessChan chan interop.InvokeSuccess + invokeFailureChan chan interop.InvokeFailure + sbInfoFromInit interop.SandboxInfoFromInit // contains data that needs to be persisted from init for suppressed inits during invoke + invokeRequestBuffer *bytes.Buffer // byte buffer used to store the invoke request rendered to runtime (reused until reset) } // Validate interface compliance @@ -38,8 +44,9 @@ var _ interop.SandboxContext = (*SandboxContext)(nil) var _ interop.InitContext = (*initContext)(nil) var _ interop.InvokeContext = (*invokeContext)(nil) -func (s SandboxContext) Init(init *interop.Init, timeoutMs int64) (interop.InitStarted, interop.InitContext) { - initStartedResponseChan := make(chan interop.InitStarted) +// Init starts the runtime domain initialization in a separate goroutine. +// Return value indicates that init request has been accepted and started. +func (s SandboxContext) Init(init *interop.Init, timeoutMs int64) interop.InitContext { initSuccessResponseChan := make(chan interop.InitSuccess) initFailureResponseChan := make(chan interop.InitFailure) @@ -48,49 +55,67 @@ func (s SandboxContext) Init(init *interop.Init, timeoutMs int64) (interop.InitS } init.EnvironmentVariables.StoreRuntimeAPIEnvironmentVariable(s.runtimeAPIAddress) + extensions.DisableViaMagicLayer() - go s.rapidCtx.HandleInit(init, initStartedResponseChan, initSuccessResponseChan, initFailureResponseChan) - initStarted := <-initStartedResponseChan + // We start initialization handling in a separate goroutine so that control can be returned back to + // caller, which can do work (e.g. notifying further upstream that initialization has started), and + // and call initCtx.Wait() to wait async for completion of initialization phase. + go s.rapidCtx.HandleInit(init, initSuccessResponseChan, initFailureResponseChan) sbMetadata := interop.SandboxInfoFromInit{ EnvironmentVariables: init.EnvironmentVariables, SandboxType: init.SandboxType, RuntimeBootstrap: init.Bootstrap, } - return initStarted, newInitContext(s.rapidCtx, sbMetadata, initSuccessResponseChan, initFailureResponseChan) + return newInitContext(s.rapidCtx, sbMetadata, initSuccessResponseChan, initFailureResponseChan) } +// Reset triggers a reset. In case of timeouts, the reset handler cancels all flows which triggers +// ongoing invoke handlers to return before proceeding with invoke +// TODO: move this method to the initialization context, since reset is conceptually on RT domain func (s SandboxContext) Reset(reset *interop.Reset) (interop.ResetSuccess, *interop.ResetFailure) { defer s.rapidCtx.Clear() - return s.rapidCtx.HandleReset(reset, s.InvokeReceivedTime, s.InvokeResponseMetrics) + return s.rapidCtx.HandleReset(reset) } +// Reset triggers a shutdown. This is similar to a reset, except that this is a terminal state +// and no further invokes are allowed func (s SandboxContext) Shutdown(shutdown *interop.Shutdown) interop.ShutdownSuccess { return s.rapidCtx.HandleShutdown(shutdown) } -func (s SandboxContext) Restore(restore *interop.Restore) error { +func (s SandboxContext) Restore(restore *interop.Restore) (interop.RestoreResult, error) { return s.rapidCtx.HandleRestore(restore) } -func (s *SandboxContext) SetInvokeReceivedTime(invokeReceivedTime int64) { - s.InvokeReceivedTime = invokeReceivedTime +func (s *SandboxContext) SetRuntimeStartedTime(runtimeStartedTime int64) { + s.rapidCtx.SetRuntimeStartedTime(runtimeStartedTime) } func (s *SandboxContext) SetInvokeResponseMetrics(metrics *interop.InvokeResponseMetrics) { - s.InvokeResponseMetrics = metrics + s.rapidCtx.SetInvokeResponseMetrics(metrics) } func newInitContext(r interop.RapidContext, sbMetadata interop.SandboxInfoFromInit, initSuccessChan chan interop.InitSuccess, initFailureChan chan interop.InitFailure) initContext { + + // Invocation request buffer is initialized once per initialization + // to reduce memory usage & GC CPU time across invocations + var requestBuffer bytes.Buffer + return initContext{ - initSuccessChan: initSuccessChan, - initFailureChan: initFailureChan, - rapidCtx: r, - sbInfoFromInit: sbMetadata, + initSuccessChan: initSuccessChan, + initFailureChan: initFailureChan, + rapidCtx: r, + sbInfoFromInit: sbMetadata, + invokeRequestBuffer: &requestBuffer, } } +// Wait awaits until initialization phase is complete, i.e. one of: +// - until all runtime domain process call /next +// - any one of the runtime domain processes exit (init failure) +// Timeout handling is managed upstream entirely func (i initContext) Wait() (interop.InitSuccess, *interop.InitFailure) { select { case initSuccess, isOpen := <-i.initSuccessChan: @@ -108,35 +133,44 @@ func (i initContext) Wait() (interop.InitSuccess, *interop.InitFailure) { } } +// Reserve is used to initialize invoke-related state func (i initContext) Reserve() interop.InvokeContext { - invokeRequestChan := make(chan *interop.Invoke) invokeSuccessChan := make(chan interop.InvokeSuccess) invokeFailureChan := make(chan interop.InvokeFailure) + return invokeContext{ + rapidCtx: i.rapidCtx, + invokeRequestChan: invokeRequestChan, + invokeSuccessChan: invokeSuccessChan, + invokeFailureChan: invokeFailureChan, + sbInfoFromInit: i.sbInfoFromInit, + invokeRequestBuffer: i.invokeRequestBuffer, + } +} + +// SendRequest starts the invocation request handling in a separate goroutine, +// i.e. sending the request payload via /next response, +// and waiting for the synchronization points +func (invCtx invokeContext) SendRequest(invoke *interop.Invoke, responseSender interop.InvokeResponseSender) { + // Invoke handling needs to be in a separate goroutine so that control can + // be returned immediately to calling goroutine, which can do work and + // asynchronously call invCtx.Wait() to await completion of the invoke phase go func() { - invoke := <-invokeRequestChan // For suppressed inits, invoke needs the runtime and agent env vars - invokeSuccess, invokeFailure := i.rapidCtx.HandleInvoke(invoke, i.sbInfoFromInit) + invokeSuccess, invokeFailure := invCtx.rapidCtx.HandleInvoke(invoke, invCtx.sbInfoFromInit, invCtx.invokeRequestBuffer, responseSender) if invokeFailure != nil { - invokeFailureChan <- *invokeFailure + invCtx.invokeFailureChan <- *invokeFailure } else { - invokeSuccessChan <- invokeSuccess + invCtx.invokeSuccessChan <- invokeSuccess } }() - - return invokeContext{ - rapidCtx: i.rapidCtx, - invokeRequestChan: invokeRequestChan, - invokeSuccessChan: invokeSuccessChan, - invokeFailureChan: invokeFailureChan, - } -} - -func (invCtx invokeContext) SendRequest(i *interop.Invoke) { - invCtx.invokeRequestChan <- i } +// Wait awaits invoke completion, i.e. one of the following cases: +// - until all runtime domain process call /next +// - until a process exit (that notifies upstream to trigger a reset due to "failure") +// - until a timeout (triggered by a reset from upstream due to "timeout") func (invCtx invokeContext) Wait() (interop.InvokeSuccess, *interop.InvokeFailure) { select { case invokeSuccess := <-invCtx.invokeSuccessChan: diff --git a/lambda/rapidcore/sandbox_builder.go b/lambda/rapidcore/sandbox_builder.go index ce016a0..f51acda 100644 --- a/lambda/rapidcore/sandbox_builder.go +++ b/lambda/rapidcore/sandbox_builder.go @@ -33,7 +33,7 @@ type SandboxBuilder struct { lambdaInvokeAPI LambdaInvokeAPI defaultInteropServer *Server useCustomInteropServer bool - shutdownFuncs []context.CancelFunc + shutdownFuncs []func() handler string } @@ -45,42 +45,45 @@ const ( ) func NewSandboxBuilder() *SandboxBuilder { - defaultInteropServer := NewServer(context.Background()) - signalCtx, cancelSignalCtx := context.WithCancel(context.Background()) + defaultInteropServer := NewServer() + localSv := supervisor.NewLocalSupervisor() b := &SandboxBuilder{ sandbox: &rapid.Sandbox{ - PreLoadTimeNs: 0, // TODO StandaloneMode: true, LogsEgressAPI: &telemetry.NoOpLogsEgressAPI{}, EnableTelemetryAPI: false, Tracer: telemetry.NewNoOpTracer(), - SignalCtx: signalCtx, EventsAPI: &telemetry.NoOpEventsAPI{}, InitCachingEnabled: false, - Supervisor: supervisor.NewLocalSupervisor(), + Supervisor: localSv, + RuntimeFsRootPath: localSv.RootPath, RuntimeAPIHost: "127.0.0.1", RuntimeAPIPort: 9001, }, defaultInteropServer: defaultInteropServer, - shutdownFuncs: []context.CancelFunc{}, + shutdownFuncs: []func(){}, lambdaInvokeAPI: NewEmulatorAPI(defaultInteropServer), } - b.AddShutdownFunc(context.CancelFunc(func() { + b.AddShutdownFunc(func() { log.Info("Shutting down...") defaultInteropServer.Reset("SandboxTerminated", defaultSigtermResetTimeoutMs) - cancelSignalCtx() - })) + }) return b } -func (b *SandboxBuilder) SetSupervisor(supervisor supvmodel.Supervisor) *SandboxBuilder { +func (b *SandboxBuilder) SetSupervisor(supervisor supvmodel.ProcessSupervisor) *SandboxBuilder { b.sandbox.Supervisor = supervisor return b } +func (b *SandboxBuilder) SetRuntimeFsRootPath(rootPath string) *SandboxBuilder { + b.sandbox.RuntimeFsRootPath = rootPath + return b +} + func (b *SandboxBuilder) SetRuntimeAPIAddress(runtimeAPIAddress string) *SandboxBuilder { host, port, err := net.SplitHostPort(runtimeAPIAddress) if err != nil { @@ -105,7 +108,7 @@ func (b *SandboxBuilder) SetInteropServer(interopServer interop.Server) *Sandbox return b } -func (b *SandboxBuilder) SetEventsAPI(eventsAPI telemetry.EventsAPI) *SandboxBuilder { +func (b *SandboxBuilder) SetEventsAPI(eventsAPI interop.EventsAPI) *SandboxBuilder { b.sandbox.EventsAPI = eventsAPI return b } @@ -134,11 +137,6 @@ func (b *SandboxBuilder) SetInitCachingFlag(initCachingEnabled bool) *SandboxBui return b } -func (b *SandboxBuilder) SetPreLoadTimeNs(preLoadTimeNs int64) *SandboxBuilder { - b.sandbox.PreLoadTimeNs = preLoadTimeNs - return b -} - func (b *SandboxBuilder) SetTelemetrySubscription(logsSubscriptionAPI telemetry.SubscriptionAPI, telemetrySubscriptionAPI telemetry.SubscriptionAPI) *SandboxBuilder { b.sandbox.EnableTelemetryAPI = true b.sandbox.LogsSubscriptionAPI = logsSubscriptionAPI @@ -156,7 +154,7 @@ func (b *SandboxBuilder) SetHandler(handler string) *SandboxBuilder { return b } -func (b *SandboxBuilder) AddShutdownFunc(shutdownFunc context.CancelFunc) *SandboxBuilder { +func (b *SandboxBuilder) AddShutdownFunc(shutdownFunc func()) *SandboxBuilder { b.shutdownFuncs = append(b.shutdownFuncs, shutdownFunc) return b } @@ -166,16 +164,20 @@ func (b *SandboxBuilder) Create() (interop.SandboxContext, interop.InternalState b.sandbox.InteropServer = b.defaultInteropServer } - go signalHandler(b.shutdownFuncs) + ctx, cancel := context.WithCancel(context.Background()) + + // cancel is called when handling termination signals as a cancellation + // signal to the Runtime API sever to terminate gracefully + go signalHandler(cancel, b.shutdownFuncs) - rapidCtx, internalStateFn, runtimeAPIAddr := rapid.Start(b.sandbox) + // rapid.Start, among other things, starts the Runtime API server and + // terminates it gracefully if the cxt is canceled + rapidCtx, internalStateFn, runtimeAPIAddr := rapid.Start(ctx, b.sandbox) b.sandboxContext = &SandboxContext{ - rapidCtx: rapidCtx, - handler: b.handler, - runtimeAPIAddress: runtimeAPIAddr, - InvokeReceivedTime: int64(0), - InvokeResponseMetrics: nil, + rapidCtx: rapidCtx, + handler: b.handler, + runtimeAPIAddress: runtimeAPIAddr, } return b.sandboxContext, internalStateFn @@ -205,8 +207,10 @@ func SetInternalLogOutput(w io.Writer) { logging.SetOutput(w) } -// Trap SIGINT and SIGTERM signals and call shutdown function -func signalHandler(shutdownFuncs []context.CancelFunc) { +// Trap SIGINT and SIGTERM signals, call shutdown function, and cancel the +// ctx to terminate gracefully the Runtime API server +func signalHandler(cancel context.CancelFunc, shutdownFuncs []func()) { + defer cancel() sig := make(chan os.Signal, 1) signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) sigReceived := <-sig diff --git a/lambda/rapidcore/sandbox_emulator_api.go b/lambda/rapidcore/sandbox_emulator_api.go index 6737631..4cc2183 100644 --- a/lambda/rapidcore/sandbox_emulator_api.go +++ b/lambda/rapidcore/sandbox_emulator_api.go @@ -31,6 +31,7 @@ func NewEmulatorAPI(s *Server) *EmulatorAPI { // Init method is only used by the Runtime interface emulator func (l *EmulatorAPI) Init(i *interop.Init, timeoutMs int64) { l.server.Init(&interop.Init{ + AccountID: i.AccountID, Handler: i.Handler, AwsKey: i.AwsKey, AwsSecret: i.AwsSecret, diff --git a/lambda/rapidcore/server.go b/lambda/rapidcore/server.go index e652130..e903ebe 100644 --- a/lambda/rapidcore/server.go +++ b/lambda/rapidcore/server.go @@ -33,12 +33,6 @@ const ( resetDefaultTimeoutMs = 2000 ) -const ( - contentTypeHeader = "Content-Type" - errorTypeHeader = "Error-Type" - functionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" -) - type rapidPhase int const ( @@ -84,7 +78,6 @@ type Server struct { initChanOut chan *interop.Init interruptedResponseChan chan *interop.Reset - sendRunningChan chan *interop.InitStarted sendResponseChan chan *interop.InvokeResponseMetrics doneChan chan *interop.Done @@ -107,7 +100,7 @@ type Server struct { initContext interop.InitContext invoker interop.InvokeContext initFailures chan interop.InitFailure - cachedInitErrorResponse *interop.ErrorResponse + cachedInitErrorResponse *interop.ErrorInvokeResponse } // Validate interface compliance @@ -266,7 +259,7 @@ func (s *Server) Release() error { s.reservationCancel() } - s.sandboxContext.SetInvokeReceivedTime(0) + s.sandboxContext.SetRuntimeStartedTime(-1) s.sandboxContext.SetInvokeResponseMetrics(nil) s.invokeCtx = nil return nil @@ -295,7 +288,7 @@ func (s *Server) SetInternalStateGetter(cb interop.InternalStateGetter) { s.InternalStateGetter = cb } -func (s *Server) sendResponseUnsafe(invokeID string, additionalHeaders map[string]string, status int, payload io.Reader, trailers http.Header, request *interop.CancellableRequest, runtimeCalledResponse bool) error { +func (s *Server) sendResponseUnsafe(invokeID string, additionalHeaders map[string]string, payload io.Reader, trailers http.Header, request *interop.CancellableRequest, runtimeCalledResponse bool) error { if s.invokeCtx == nil || invokeID != s.invokeCtx.Token.InvokeID { return interop.ErrInvalidInvokeID } @@ -310,7 +303,7 @@ func (s *Server) sendResponseUnsafe(invokeID string, additionalHeaders map[strin var reportedErr error if s.invokeCtx.Direct { - if err := directinvoke.SendDirectInvokeResponse(additionalHeaders, payload, trailers, s.invokeCtx.ReplyStream, s.interruptedResponseChan, s.sendResponseChan, request, runtimeCalledResponse); err != nil { + if err := directinvoke.SendDirectInvokeResponse(additionalHeaders, payload, trailers, s.invokeCtx.ReplyStream, s.interruptedResponseChan, s.sendResponseChan, request, runtimeCalledResponse, invokeID); err != nil { // TODO: Do we need to drain the reader in case of a large payload and connection reuse? log.Errorf("Failed to write response to %s: %s", invokeID, err) reportedErr = err @@ -328,7 +321,7 @@ func (s *Server) sendResponseUnsafe(invokeID string, additionalHeaders map[strin } startReadingResponseMonoTimeMs := metering.Monotime() - s.invokeCtx.ReplyStream.Header().Add(contentTypeHeader, additionalHeaders[contentTypeHeader]) + s.invokeCtx.ReplyStream.Header().Add(directinvoke.ContentTypeHeader, additionalHeaders[directinvoke.ContentTypeHeader]) written, err := s.invokeCtx.ReplyStream.Write(data) if err != nil { return fmt.Errorf("Failed to write response to %s: %s", invokeID, err) @@ -355,19 +348,19 @@ func (s *Server) sendResponseUnsafe(invokeID string, additionalHeaders map[strin return reportedErr } -func (s *Server) SendResponse(invokeID string, headers map[string]string, reader io.Reader, trailers http.Header, request *interop.CancellableRequest) error { +func (s *Server) SendResponse(invokeID string, resp *interop.StreamableInvokeResponse) error { s.setRuntimeState(runtimeInvokeResponseSent) s.mutex.Lock() defer s.mutex.Unlock() runtimeCalledResponse := true - return s.sendResponseUnsafe(invokeID, headers, http.StatusOK, reader, trailers, request, runtimeCalledResponse) + return s.sendResponseUnsafe(invokeID, resp.Headers, resp.Payload, resp.Trailers, resp.Request, runtimeCalledResponse) } -func (s *Server) SendInitErrorResponse(invokeID string, resp *interop.ErrorResponse) error { - log.Debugf("Sending Init Error Response: %s", resp.ErrorType) +func (s *Server) SendInitErrorResponse(resp *interop.ErrorInvokeResponse) error { + log.Debugf("Sending Init Error Response: %s", resp.FunctionError.Type) if s.getRapidPhase() == phaseInvoking { // This branch occurs during suppressed init - return s.SendErrorResponse(invokeID, resp) + return s.SendErrorResponse(s.GetCurrentInvokeID(), resp) } // Handle an /init/error outside of the invoke phase @@ -376,17 +369,20 @@ func (s *Server) SendInitErrorResponse(invokeID string, resp *interop.ErrorRespo return nil } -func (s *Server) SendErrorResponse(invokeID string, resp *interop.ErrorResponse) error { - log.Debugf("Sending Error Response: %s", resp.ErrorType) +func (s *Server) SendErrorResponse(invokeID string, resp *interop.ErrorInvokeResponse) error { + log.Debugf("Sending Error Response: %s", resp.FunctionError.Type) s.setRuntimeState(runtimeInvokeError) s.mutex.Lock() defer s.mutex.Unlock() - additionalHeaders := map[string]string{contentTypeHeader: resp.ContentType, errorTypeHeader: resp.ErrorType} - if functionResponseMode := resp.FunctionResponseMode; functionResponseMode != "" { - additionalHeaders[functionResponseModeHeader] = functionResponseMode + additionalHeaders := map[string]string{ + directinvoke.ContentTypeHeader: resp.Headers.ContentType, + directinvoke.ErrorTypeHeader: string(resp.FunctionError.Type), + } + if functionResponseMode := resp.Headers.FunctionResponseMode; functionResponseMode != "" { + additionalHeaders[directinvoke.FunctionResponseModeHeader] = functionResponseMode } runtimeCalledResponse := false // we are sending an error here, so runtime called /error or crashed/timeout - return s.sendResponseUnsafe(invokeID, additionalHeaders, http.StatusInternalServerError, bytes.NewReader(resp.Payload), nil, nil, runtimeCalledResponse) + return s.sendResponseUnsafe(invokeID, additionalHeaders, bytes.NewReader(resp.Payload), nil, nil, runtimeCalledResponse) } func (s *Server) Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) { @@ -409,10 +405,21 @@ func (s *Server) Reset(reason string, timeoutMs int64) (*statejson.ResetDescript s.setRuntimeState(runtimeNotStarted) var meta interop.DoneMetadata - if reset.InvokeResponseMetrics != nil { + if reset.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(reset.InvokeResponseMetrics) { meta.RuntimeTimeThrottledMs = reset.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) meta.RuntimeProducedBytes = reset.InvokeResponseMetrics.ProducedBytes meta.RuntimeOutboundThroughputBps = reset.InvokeResponseMetrics.OutboundThroughputBps + meta.MetricsDimensions = interop.DoneMetadataMetricsDimensions{ + InvokeResponseMode: reset.InvokeResponseMode, + } + + // These metrics aren't present in reset struct, therefore we need to get + // them from s.sandboxContext.Reset() response + if resetFailure != nil { + meta.RuntimeResponseLatencyMs = resetFailure.ResponseMetrics.RuntimeResponseLatencyMs + } else { + meta.RuntimeResponseLatencyMs = resetSuccess.ResponseMetrics.RuntimeResponseLatencyMs + } } if resetFailure != nil { @@ -431,15 +438,24 @@ func (s *Server) Reset(reason string, timeoutMs int64) (*statejson.ResetDescript return nil, errors.New(string(done.ErrorType)) } - return &statejson.ResetDescription{ExtensionsResetMs: done.Meta.ExtensionsResetMs}, nil + return &statejson.ResetDescription{ + ExtensionsResetMs: done.Meta.ExtensionsResetMs, + ResponseMetrics: statejson.ResponseMetrics{ + RuntimeResponseLatencyMs: done.Meta.RuntimeResponseLatencyMs, + Dimensions: statejson.ResponseMetricsDimensions{ + InvokeResponseMode: statejson.InvokeResponseMode( + done.Meta.MetricsDimensions.InvokeResponseMode, + ), + }, + }, + }, nil } -func NewServer(ctx context.Context) *Server { +func NewServer() *Server { s := &Server{ initChanOut: make(chan *interop.Init), interruptedResponseChan: make(chan *interop.Reset), - sendRunningChan: make(chan *interop.InitStarted), sendResponseChan: make(chan *interop.InvokeResponseMetrics), doneChan: make(chan *interop.Done), @@ -500,18 +516,15 @@ func (s *Server) Init(i *interop.Init, invokeTimeoutMs int64) error { s.SetInvokeTimeout(time.Duration(invokeTimeoutMs) * time.Millisecond) s.setRapidPhase(phaseInitializing) s.setInitFailuresChan() - initStarted, initCtx := s.sandboxContext.Init(i, invokeTimeoutMs) - initStarted.Ack <- struct{}{} + initCtx := s.sandboxContext.Init(i, invokeTimeoutMs) s.initContext = initCtx go s.awaitInitCompletion() - log.Debugf("Received RUNNING: %v", initStarted) return nil } func (s *Server) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error { - s.sandboxContext.SetInvokeReceivedTime(i.InvokeReceivedTime) invokeID, err := s.setReplyStream(w, direct) if err != nil { return err @@ -536,7 +549,7 @@ func (s *Server) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct boo s.setRuntimeState(runtimeInvokeComplete) return } - s.invoker.SendRequest(i) + s.invoker.SendRequest(i, s) invokeSuccess, invokeFailure := s.invoker.Wait() if invokeFailure != nil { if invokeFailure.ResetReceived { @@ -579,19 +592,19 @@ func (s *Server) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct boo return nil } -func (s *Server) setCachedInitErrorResponse(errResp *interop.ErrorResponse) { +func (s *Server) setCachedInitErrorResponse(errResp *interop.ErrorInvokeResponse) { s.mutex.Lock() defer s.mutex.Unlock() s.cachedInitErrorResponse = errResp } -func (s *Server) getCachedInitErrorResponse() *interop.ErrorResponse { +func (s *Server) getCachedInitErrorResponse() *interop.ErrorInvokeResponse { s.mutex.Lock() defer s.mutex.Unlock() return s.cachedInitErrorResponse } -func (s *Server) trySendDefaultErrorResponse(resp *interop.ErrorResponse) { +func (s *Server) trySendDefaultErrorResponse(resp *interop.ErrorInvokeResponse) { if err := s.SendErrorResponse(s.GetCurrentInvokeID(), resp); err != nil { if err != interop.ErrResponseSent { log.Panicf("Failed to send default error response: %s", err) @@ -658,9 +671,15 @@ func (s *Server) Invoke(responseWriter http.ResponseWriter, invoke *interop.Invo // For init failures, cache the response so they can be checked later // We check if they have not already been set by a call to /init/error by runtime if s.getCachedInitErrorResponse() == nil { - errType, errMsg := string(initCompletionResp.InitErrorType), initCompletionResp.InitErrorMessage.Error() - s.setCachedInitErrorResponse(&interop.ErrorResponse{ErrorType: errType, ErrorMessage: errMsg}) + errType, errMsg := initCompletionResp.InitErrorType, initCompletionResp.InitErrorMessage.Error() + headers := interop.InvokeResponseHeaders{} + fnError := interop.FunctionError{Type: errType, Message: errMsg} + s.setCachedInitErrorResponse(&interop.ErrorInvokeResponse{Headers: headers, FunctionError: fnError, Payload: []byte{}}) } + + // Init failed, so we explicitly shutdown runtime (cleanup unused extensions). + // Because following fast invoke will start new (supressed) Init phase without reset call + s.Shutdown(&interop.Shutdown{DeadlineNs: metering.Monotime() + int64(resetDefaultTimeoutMs*1000*1000)}) } } @@ -759,7 +778,7 @@ func (s *Server) AwaitInitialized() error { return nil } -func (s *Server) AwaitRelease() (*statejson.InternalStateDescription, error) { +func (s *Server) AwaitRelease() (*statejson.ReleaseResponse, error) { defer func() { s.setRapidPhase(phaseIdle) s.setRuntimeState(runtimeInvokeComplete) @@ -776,8 +795,20 @@ func (s *Server) AwaitRelease() (*statejson.InternalStateDescription, error) { return nil, ErrInvokeDoneFailed } + releaseResponse := statejson.ReleaseResponse{ + InternalStateDescription: &doneWithState.State, + ResponseMetrics: statejson.ResponseMetrics{ + RuntimeResponseLatencyMs: doneWithState.Meta.RuntimeResponseLatencyMs, + Dimensions: statejson.ResponseMetricsDimensions{ + InvokeResponseMode: statejson.InvokeResponseMode( + doneWithState.Meta.MetricsDimensions.InvokeResponseMode, + ), + }, + }, + } + s.Release() - return &doneWithState.State, nil + return &releaseResponse, nil case <-s.reservationContext.Done(): return nil, ErrReleaseReservationDone @@ -806,7 +837,7 @@ func (s *Server) InternalState() (*statejson.InternalStateDescription, error) { return &state, nil } -func (s *Server) Restore(restore *interop.Restore) error { +func (s *Server) Restore(restore *interop.Restore) (interop.RestoreResult, error) { return s.sandboxContext.Restore(restore) } @@ -822,10 +853,14 @@ func doneFromInvokeSuccess(successMsg interop.InvokeSuccess) *interop.Done { InvokeCompletionTimeNs: successMsg.InvokeCompletionTimeNs, InvokeReceivedTime: successMsg.InvokeReceivedTime, + RuntimeResponseLatencyMs: successMsg.ResponseMetrics.RuntimeResponseLatencyMs, RuntimeTimeThrottledMs: successMsg.ResponseMetrics.RuntimeTimeThrottledMs, RuntimeProducedBytes: successMsg.ResponseMetrics.RuntimeProducedBytes, RuntimeOutboundThroughputBps: successMsg.ResponseMetrics.RuntimeOutboundThroughputBps, LogsAPIMetrics: successMsg.LogsAPIMetrics, + MetricsDimensions: interop.DoneMetadataMetricsDimensions{ + InvokeResponseMode: successMsg.InvokeResponseMode, + }, }, } } @@ -838,6 +873,7 @@ func doneFailFromInvokeFailure(failureMsg *interop.InvokeFailure) *interop.DoneF NumActiveExtensions: failureMsg.NumActiveExtensions, InvokeReceivedTime: failureMsg.InvokeReceivedTime, + RuntimeResponseLatencyMs: failureMsg.ResponseMetrics.RuntimeResponseLatencyMs, RuntimeTimeThrottledMs: failureMsg.ResponseMetrics.RuntimeTimeThrottledMs, RuntimeProducedBytes: failureMsg.ResponseMetrics.RuntimeProducedBytes, RuntimeOutboundThroughputBps: failureMsg.ResponseMetrics.RuntimeOutboundThroughputBps, @@ -848,6 +884,10 @@ func doneFailFromInvokeFailure(failureMsg *interop.InvokeFailure) *interop.DoneF ExtensionNames: failureMsg.ExtensionNames, LogsAPIMetrics: failureMsg.LogsAPIMetrics, + + MetricsDimensions: interop.DoneMetadataMetricsDimensions{ + InvokeResponseMode: failureMsg.InvokeResponseMode, + }, }, } } diff --git a/lambda/rapidcore/server_test.go b/lambda/rapidcore/server_test.go index 88eea3f..68ac30c 100644 --- a/lambda/rapidcore/server_test.go +++ b/lambda/rapidcore/server_test.go @@ -27,12 +27,6 @@ func waitForChanWithTimeout(channel <-chan error, timeout time.Duration) error { } } -func sendInitStartedResponse(responseChannel chan<- interop.InitStarted, msg interop.InitStarted) { - msg.Ack = make(chan struct{}) - responseChannel <- msg - <-msg.Ack -} - func sendInitSuccessResponse(responseChannel chan<- interop.InitSuccess, msg interop.InitSuccess) { msg.Ack = make(chan struct{}) responseChannel <- msg @@ -46,20 +40,20 @@ func sendInitFailureResponse(responseChannel chan<- interop.InitFailure, msg int } type mockRapidCtx struct { - initHandler func(start chan<- interop.InitStarted, success chan<- interop.InitSuccess, fail chan<- interop.InitFailure) + initHandler func(success chan<- interop.InitSuccess, fail chan<- interop.InitFailure) invokeHandler func() (interop.InvokeSuccess, *interop.InvokeFailure) resetHandler func() (interop.ResetSuccess, *interop.ResetFailure) } -func (r *mockRapidCtx) HandleInit(init *interop.Init, startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - r.initHandler(startResp, successResp, failureResp) +func (r *mockRapidCtx) HandleInit(init *interop.Init, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + r.initHandler(successResp, failureResp) } -func (r *mockRapidCtx) HandleInvoke(invoke *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit) (interop.InvokeSuccess, *interop.InvokeFailure) { +func (r *mockRapidCtx) HandleInvoke(invoke *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit, buf *bytes.Buffer, responseSender interop.InvokeResponseSender) (interop.InvokeSuccess, *interop.InvokeFailure) { return r.invokeHandler() } -func (r *mockRapidCtx) HandleReset(reset *interop.Reset, invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) (interop.ResetSuccess, *interop.ResetFailure) { +func (r *mockRapidCtx) HandleReset(reset *interop.Reset) (interop.ResetSuccess, *interop.ResetFailure) { return r.resetHandler() } @@ -67,25 +61,33 @@ func (r *mockRapidCtx) HandleShutdown(shutdown *interop.Shutdown) interop.Shutdo return interop.ShutdownSuccess{} } -func (r *mockRapidCtx) HandleRestore(restore *interop.Restore) error { - return nil +func (r *mockRapidCtx) HandleRestore(restore *interop.Restore) (interop.RestoreResult, error) { + return interop.RestoreResult{}, nil } func (r *mockRapidCtx) Clear() {} +func (r *mockRapidCtx) SetRuntimeStartedTime(a int64) { +} + +func (r *mockRapidCtx) SetInvokeResponseMetrics(a *interop.InvokeResponseMetrics) { +} + +func (r *mockRapidCtx) SetEventsAPI(e interop.EventsAPI) { +} + func TestReserveDoesNotDeadlockWhenCalledMultipleTimes(t *testing.T) { - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { sendInitSuccessResponse(successResp, interop.InitSuccess{}) } srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ initHandler, func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, - }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + }, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) @@ -112,18 +114,17 @@ func TestReserveDoesNotDeadlockWhenCalledMultipleTimes(t *testing.T) { } func TestInitSuccess(t *testing.T) { - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { sendInitSuccessResponse(successResp, interop.InitSuccess{}) } srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ initHandler, func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, - }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + }, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) @@ -134,13 +135,12 @@ func TestInitSuccess(t *testing.T) { func TestInitErrorBeforeReserve(t *testing.T) { // Rapid thread sending init failure should not be blocked even if reserve hasn't arrived - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) initErrorResponseSent := make(chan error) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) - require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) sendInitFailureResponse(failureResp, interop.InitFailure{}) initErrorResponseSent <- errors.New("initErrorResponseSent") } @@ -148,7 +148,7 @@ func TestInitErrorBeforeReserve(t *testing.T) { initHandler, func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, - }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + }, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) @@ -169,19 +169,18 @@ func TestInitErrorBeforeReserve(t *testing.T) { } func TestInitErrorDuringReserve(t *testing.T) { - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) - require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) sendInitFailureResponse(failureResp, interop.InitFailure{}) } srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ initHandler, func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, - }, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + }, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) resp, err := srv.Reserve("", "", "") @@ -197,24 +196,24 @@ func TestInitErrorDuringReserve(t *testing.T) { } func TestInvokeSuccess(t *testing.T) { - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) releaseRuntimeInit := make(chan struct{}) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { <-releaseRuntimeInit sendInitSuccessResponse(successResp, interop.InitSuccess{}) } invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), map[string]string{"Content-Type": "application/json"}, bytes.NewReader([]byte("response")), nil, nil)) + response := &interop.StreamableInvokeResponse{Headers: map[string]string{"Content-Type": "application/json"}, Payload: bytes.NewReader([]byte("response"))} + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), response)) require.NoError(t, srv.SendRuntimeReady()) return interop.InvokeSuccess{}, nil } resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil } - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) @@ -239,16 +238,16 @@ func TestInvokeSuccess(t *testing.T) { } func TestInvokeError(t *testing.T) { - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { sendInitSuccessResponse(successResp, interop.InitSuccess{}) } invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }"), ContentType: "application/json"})) + headers := interop.InvokeResponseHeaders{ContentType: "application/json"} + require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }"), Headers: headers})) require.NoError(t, srv.SendRuntimeReady()) return interop.InvokeSuccess{}, nil } @@ -257,7 +256,7 @@ func TestInvokeError(t *testing.T) { return interop.ResetSuccess{}, nil } - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) @@ -291,19 +290,19 @@ func TestInvokeWithSuppressedInitSuccess(t *testing.T) { // Reserve() returns ErrInitAlreadyDone, since the server implementation // closes the InitDone channel after the first InitDone message. - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) initErrorCompleted := make(chan error) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) - require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) sendInitFailureResponse(failureResp, interop.InitFailure{}) initErrorCompleted <- errors.New("initErrorSequenceCompleted") } invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), nil, bytes.NewReader([]byte("response")), nil, nil)) + response := &interop.StreamableInvokeResponse{Payload: bytes.NewReader([]byte("response"))} + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), response)) return interop.InvokeSuccess{}, nil } @@ -311,7 +310,7 @@ func TestInvokeWithSuppressedInitSuccess(t *testing.T) { return interop.ResetSuccess{}, nil } - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) @@ -356,27 +355,26 @@ func TestInvokeWithSuppressedInitSuccess(t *testing.T) { func TestInvokeWithSuppressedInitErrorDueToInitError(t *testing.T) { // Tests init/error followed by init/error during suppressed init - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) - require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) sendInitFailureResponse(failureResp, interop.InitFailure{}) } releaseChan := make(chan error) invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) releaseChan <- nil - return interop.InvokeSuccess{}, &interop.InvokeFailure{ErrorType: "A.B", RequestReset: true, DefaultErrorResponse: &interop.ErrorResponse{}} + return interop.InvokeSuccess{}, &interop.InvokeFailure{ErrorType: "A.B", RequestReset: true, DefaultErrorResponse: &interop.ErrorInvokeResponse{}} } resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil } - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) @@ -411,16 +409,15 @@ func TestInvokeWithSuppressedInitErrorDueToInitError(t *testing.T) { func TestInvokeWithSuppressedInitErrorDueToInvokeError(t *testing.T) { // Tests init/error followed by init/error during suppressed init - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) - require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { + require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) sendInitFailureResponse(failureResp, interop.InitFailure{}) } invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - require.NoError(t, srv.SendInitErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorResponse{Payload: []byte("{ 'errorType': 'B.C' }")})) + require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'B.C' }")})) require.NoError(t, srv.SendRuntimeReady()) return interop.InvokeSuccess{}, nil } @@ -429,7 +426,7 @@ func TestInvokeWithSuppressedInitErrorDueToInvokeError(t *testing.T) { return interop.ResetSuccess{}, nil } - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) @@ -461,16 +458,16 @@ func TestInvokeWithSuppressedInitErrorDueToInvokeError(t *testing.T) { } func TestMultipleInvokeSuccess(t *testing.T) { - srv := NewServer(context.Background()) + srv := NewServer() srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - initHandler := func(startResp chan<- interop.InitStarted, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitStartedResponse(startResp, interop.InitStarted{}) + initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { sendInitSuccessResponse(successResp, interop.InitSuccess{}) } i := 0 invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), nil, bytes.NewReader([]byte("response-"+fmt.Sprint(i))), nil, nil)) + response := &interop.StreamableInvokeResponse{Payload: bytes.NewReader([]byte("response-" + fmt.Sprint(i)))} + require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), response)) require.NoError(t, srv.SendRuntimeReady()) i++ return interop.InvokeSuccess{}, nil @@ -480,7 +477,7 @@ func TestMultipleInvokeSuccess(t *testing.T) { return interop.ResetSuccess{}, nil } - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999", 0, &interop.InvokeResponseMetrics{}}) + srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999"}) srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) require.Equal(t, phaseInitializing, srv.getRapidPhase()) @@ -505,11 +502,42 @@ func TestMultipleInvokeSuccess(t *testing.T) { } } +func TestAwaitReleaseOnSuccess(t *testing.T) { + srv := NewServer() + + // mocks + internalStateDescription := statejson.InternalStateDescription{} + srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return internalStateDescription }) + doneWithState := DoneWithState{ + State: internalStateDescription, + Done: &interop.Done{ + Meta: interop.DoneMetadata{ + RuntimeResponseLatencyMs: 12345, + MetricsDimensions: interop.DoneMetadataMetricsDimensions{ + InvokeResponseMode: interop.InvokeResponseModeStreaming, + }, + }, + }, + } + srv.InvokeDoneChan <- doneWithState + srv.reservationContext, srv.reservationCancel = context.WithCancel(context.Background()) + + // under test + responseAwaitRelease, err := srv.AwaitRelease() + + // assertions + require.NoError(t, err) + require.Equal(t, doneWithState.Done.Meta.RuntimeResponseLatencyMs, responseAwaitRelease.ResponseMetrics.RuntimeResponseLatencyMs) + require.Equal(t, string(doneWithState.Done.Meta.MetricsDimensions.InvokeResponseMode), string(responseAwaitRelease.ResponseMetrics.Dimensions.InvokeResponseMode)) + require.Equal(t, &doneWithState.State, responseAwaitRelease.InternalStateDescription) +} + /* Unit tests remaining: - Shutdown behaviour - Reset behaviour during various phases - Runtime / extensions process exit sequences - Invoke() and Init() api tests +- How can we add handleRestore test here? See PlantUML state diagram for potential other uncovered paths through the state machine diff --git a/lambda/rapidcore/standalone/eventLogHandler.go b/lambda/rapidcore/standalone/eventLogHandler.go index 156db99..e5bf7ac 100644 --- a/lambda/rapidcore/standalone/eventLogHandler.go +++ b/lambda/rapidcore/standalone/eventLogHandler.go @@ -8,11 +8,11 @@ import ( "fmt" "net/http" - "go.amzn.com/lambda/rapidcore/telemetry" + "go.amzn.com/lambda/rapidcore/standalone/telemetry" ) -func EventLogHandler(w http.ResponseWriter, r *http.Request, eventLog *telemetry.EventLog) { - bytes, err := json.Marshal(eventLog) +func EventLogHandler(w http.ResponseWriter, r *http.Request, eventsAPI *telemetry.StandaloneEventsAPI) { + bytes, err := json.Marshal(eventsAPI.EventLog()) if err != nil { http.Error(w, fmt.Sprintf("marshalling error: %s", err), http.StatusInternalServerError) return diff --git a/lambda/rapidcore/standalone/executeHandler.go b/lambda/rapidcore/standalone/executeHandler.go index 9bac400..0c7162b 100644 --- a/lambda/rapidcore/standalone/executeHandler.go +++ b/lambda/rapidcore/standalone/executeHandler.go @@ -27,19 +27,21 @@ func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.LambdaInv switch err { // Reserve errors: case rapidcore.ErrAlreadyReserved: - log.Errorf("Failed to reserve: %s", err) + log.WithError(err).Error("Failed to reserve as it is already reserved.") w.WriteHeader(400) case rapidcore.ErrInternalServerError: + log.WithError(err).Error("Failed to reserve from an internal server error.") w.WriteHeader(http.StatusInternalServerError) // Invoke errors: case rapidcore.ErrNotReserved, rapidcore.ErrAlreadyReplied, rapidcore.ErrAlreadyInvocating: - log.Errorf("Failed to set reply stream: %s", err) + log.WithError(err).Error("Failed to invoke from setting the reply stream.") w.WriteHeader(400) case rapidcore.ErrInvokeResponseAlreadyWritten: return case rapidcore.ErrInvokeTimeout, rapidcore.ErrInitResetReceived: + log.WithError(err).Error("Failed to invoke from an invoke timeout.") w.WriteHeader(http.StatusGatewayTimeout) // DONE failures: @@ -50,6 +52,7 @@ func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.LambdaInv return // Reservation canceled errors case rapidcore.ErrReserveReservationDone, rapidcore.ErrInvokeReservationDone, rapidcore.ErrReleaseReservationDone, rapidcore.ErrInitNotStarted: + log.WithError(err).Error("Failed to cancel reservation.") w.WriteHeader(http.StatusGatewayTimeout) } diff --git a/lambda/rapidcore/standalone/invokeHandler.go b/lambda/rapidcore/standalone/invokeHandler.go index 3e9768c..48a3a03 100644 --- a/lambda/rapidcore/standalone/invokeHandler.go +++ b/lambda/rapidcore/standalone/invokeHandler.go @@ -6,6 +6,7 @@ package standalone import ( "fmt" "net/http" + "strconv" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" @@ -22,12 +23,30 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { return } + restoreDurationHeader := r.Header.Get("restore-duration") + restoreStartHeader := r.Header.Get("restore-start-time") + + var restoreDurationNs int64 = 0 + var restoreStartTimeMonotime int64 = 0 + if restoreDurationHeader != "" && restoreStartHeader != "" { + var err1, err2 error + restoreDurationNs, err1 = strconv.ParseInt(restoreDurationHeader, 10, 64) + restoreStartTimeMonotime, err2 = strconv.ParseInt(restoreStartHeader, 10, 64) + if err1 != nil || err2 != nil { + log.Errorf("Failed to parse 'restore-duration' from '%s' and/or 'restore-start-time' from '%s'", restoreDurationHeader, restoreStartHeader) + restoreDurationNs = 0 + restoreStartTimeMonotime = 0 + } + } + invokePayload := &interop.Invoke{ - TraceID: r.Header.Get("X-Amzn-Trace-Id"), - LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), - Payload: r.Body, - DeadlineNs: fmt.Sprintf("%d", metering.Monotime()+tok.FunctionTimeout.Nanoseconds()), - InvokeReceivedTime: metering.Monotime(), + TraceID: r.Header.Get("X-Amzn-Trace-Id"), + LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), + Payload: r.Body, + DeadlineNs: fmt.Sprintf("%d", metering.Monotime()+tok.FunctionTimeout.Nanoseconds()), + InvokeReceivedTime: metering.Monotime(), + RestoreDurationNs: restoreDurationNs, + RestoreStartTimeMonotime: restoreStartTimeMonotime, } if err := s.AwaitInitialized(); err != nil { diff --git a/lambda/rapidcore/standalone/restoreHandler.go b/lambda/rapidcore/standalone/restoreHandler.go index 190b6d8..fdf7a5d 100644 --- a/lambda/rapidcore/standalone/restoreHandler.go +++ b/lambda/rapidcore/standalone/restoreHandler.go @@ -4,7 +4,9 @@ package standalone import ( + "encoding/json" "net/http" + "strconv" "time" log "github.com/sirupsen/logrus" @@ -12,10 +14,11 @@ import ( ) type RestoreBody struct { - AwsKey string `json:"awskey"` - AwsSecret string `json:"awssecret"` - AwsSession string `json:"awssession"` - CredentialsExpiry time.Time `json:"credentialsExpiry"` + AwsKey string `json:"awskey"` + AwsSecret string `json:"awssecret"` + AwsSession string `json:"awssession"` + CredentialsExpiry time.Time `json:"credentialsExpiry"` + RestoreHookTimeoutMs int64 `json:"restoreHookTimeoutMs"` } func RestoreHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { @@ -26,16 +29,30 @@ func RestoreHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { } restore := &interop.Restore{ - AwsKey: restoreRequest.AwsKey, - AwsSecret: restoreRequest.AwsSecret, - AwsSession: restoreRequest.AwsSession, - CredentialsExpiry: restoreRequest.CredentialsExpiry, + AwsKey: restoreRequest.AwsKey, + AwsSecret: restoreRequest.AwsSecret, + AwsSession: restoreRequest.AwsSession, + CredentialsExpiry: restoreRequest.CredentialsExpiry, + RestoreHookTimeoutMs: restoreRequest.RestoreHookTimeoutMs, } - err := s.Restore(restore) + restoreResult, err := s.Restore(restore) + + responseMap := make(map[string]string) + + responseMap["restoreMs"] = strconv.FormatInt(restoreResult.RestoreMs, 10) if err != nil { log.Errorf("Failed to restore: %s", err) + responseMap["restoreError"] = err.Error() w.WriteHeader(http.StatusBadGateway) } + + responseJSON, err := json.Marshal(responseMap) + + if err != nil { + log.Panicf("Cannot marshal the response map for RESTORE, %v", responseMap) + } + + w.Write(responseJSON) } diff --git a/lambda/rapidcore/standalone/router.go b/lambda/rapidcore/standalone/router.go index f1712ea..7957c32 100644 --- a/lambda/rapidcore/standalone/router.go +++ b/lambda/rapidcore/standalone/router.go @@ -10,7 +10,7 @@ import ( "go.amzn.com/lambda/core/statejson" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapidcore" - "go.amzn.com/lambda/rapidcore/telemetry" + "go.amzn.com/lambda/rapidcore/standalone/telemetry" "github.com/go-chi/chi" ) @@ -21,14 +21,14 @@ type InteropServer interface { FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error Reserve(id string, traceID, lambdaSegmentID string) (*rapidcore.ReserveResponse, error) Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) - AwaitRelease() (*statejson.InternalStateDescription, error) + AwaitRelease() (*statejson.ReleaseResponse, error) Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription InternalState() (*statejson.InternalStateDescription, error) CurrentToken() *interop.Token - Restore(restore *interop.Restore) error + Restore(restore *interop.Restore) (interop.RestoreResult, error) } -func NewHTTPRouter(ipcSrv InteropServer, lambdaInvokeAPI rapidcore.LambdaInvokeAPI, eventLog *telemetry.EventLog, shutdownFunc context.CancelFunc, bs interop.Bootstrap) *chi.Mux { +func NewHTTPRouter(ipcSrv InteropServer, lambdaInvokeAPI rapidcore.LambdaInvokeAPI, eventsAPI *telemetry.StandaloneEventsAPI, shutdownFunc context.CancelFunc, bs interop.Bootstrap) *chi.Mux { r := chi.NewRouter() r.Use(standaloneAccessLogDecorator) @@ -43,7 +43,7 @@ func NewHTTPRouter(ipcSrv InteropServer, lambdaInvokeAPI rapidcore.LambdaInvokeA r.Post("/test/shutdown", func(w http.ResponseWriter, r *http.Request) { ShutdownHandler(w, r, ipcSrv, shutdownFunc) }) r.Post("/test/directInvoke/{reservationtoken}", func(w http.ResponseWriter, r *http.Request) { DirectInvokeHandler(w, r, ipcSrv) }) r.Get("/test/internalState", func(w http.ResponseWriter, r *http.Request) { InternalStateHandler(w, r, ipcSrv) }) - r.Get("/test/eventLog", func(w http.ResponseWriter, r *http.Request) { EventLogHandler(w, r, eventLog) }) + r.Get("/test/eventLog", func(w http.ResponseWriter, r *http.Request) { EventLogHandler(w, r, eventsAPI) }) r.Post("/test/restore", func(w http.ResponseWriter, r *http.Request) { RestoreHandler(w, r, ipcSrv) }) return r } diff --git a/lambda/rapidcore/standalone/telemetry/agent_writer.go b/lambda/rapidcore/standalone/telemetry/agent_writer.go new file mode 100644 index 0000000..6ff2581 --- /dev/null +++ b/lambda/rapidcore/standalone/telemetry/agent_writer.go @@ -0,0 +1,30 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "bufio" + "bytes" +) + +type SandboxAgentWriter struct { + eventType string // 'runtime' or 'extension' + eventsAPI *StandaloneEventsAPI +} + +func NewSandboxAgentWriter(api *StandaloneEventsAPI, source string) *SandboxAgentWriter { + return &SandboxAgentWriter{ + eventType: source, + eventsAPI: api, + } +} + +func (w *SandboxAgentWriter) Write(logline []byte) (int, error) { + scanner := bufio.NewScanner(bytes.NewReader(logline)) + scanner.Split(bufio.ScanLines) + for scanner.Scan() { + w.eventsAPI.sendLogEvent(w.eventType, scanner.Text()) + } + return len(logline), nil +} diff --git a/lambda/rapidcore/standalone/telemetry/eventLog.go b/lambda/rapidcore/standalone/telemetry/eventLog.go new file mode 100644 index 0000000..0ab7c44 --- /dev/null +++ b/lambda/rapidcore/standalone/telemetry/eventLog.go @@ -0,0 +1,13 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +type EventLog struct { + Events []SandboxEvent `json:"events,omitempty"` // populated by the StandaloneEventLog object + Traces []TracingEvent `json:"traces,omitempty"` +} + +func NewEventLog() *EventLog { + return &EventLog{} +} diff --git a/lambda/rapidcore/standalone/telemetry/events_api.go b/lambda/rapidcore/standalone/telemetry/events_api.go new file mode 100644 index 0000000..dcac7a3 --- /dev/null +++ b/lambda/rapidcore/standalone/telemetry/events_api.go @@ -0,0 +1,293 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "encoding/json" + "sort" + "sync" + "time" + + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/telemetry" +) + +type EventType = string + +const ( + PlatformInitStart = EventType("platform.initStart") + PlatformInitRuntimeDone = EventType("platform.initRuntimeDone") + PlatformInitReport = EventType("platform.initReport") + PlatformRestoreRuntimeDone = EventType("platform.restoreRuntimeDone") + PlatformStart = EventType("platform.start") + PlatformRuntimeDone = EventType("platform.runtimeDone") + PlatformExtension = EventType("platform.extension") + PlatformEnd = EventType("platform.end") + PlatformReport = EventType("platform.report") + PlatformFault = EventType("platform.fault") +) + +/* +SandboxEvent represents a generic sandbox event. For example: + + { + "time": "2021-03-16T13:10:42.358Z", + "type": "platform.extension", + "platformEvent": { "name": "foo bar", "state": "Ready", "events": ["INVOKE", "SHUTDOWN"]} + } + +Or: + + { + "time": "2021-03-16T13:10:42.358Z", + "type": "extension", + "logMessage": "raw agent console output" + } + +FluxPump produces entries with a single field 'record', containing either an object or a string. +We make the distinction explicit by providing separate fields for the two cases, 'PlatformEvent' and 'LogMessage'. +Either one of the two would be populated, but not both. This makes code cleaner, but requires test client to merge +two fields back, producing a single 'record' entry again -- to match the FluxPump format that tests actually check. +*/ +type SandboxEvent struct { + Time string `json:"time"` + Type EventType `json:"type"` + PlatformEvent map[string]interface{} `json:"platformEvent,omitempty"` + LogMessage string `json:"logMessage,omitempty"` +} + +type tailLogs struct { + Events []SandboxEvent `json:"events,omitempty"` +} + +type StandaloneEventsAPI struct { + lock sync.Mutex + requestID interop.RequestID + eventLog EventLog +} + +func (s *StandaloneEventsAPI) LogTrace(entry TracingEvent) { + s.lock.Lock() + defer s.lock.Unlock() + s.eventLog.Traces = append(s.eventLog.Traces, entry) +} + +func (s *StandaloneEventsAPI) EventLog() *EventLog { + return &s.eventLog +} + +func (s *StandaloneEventsAPI) SetCurrentRequestID(requestID interop.RequestID) { + s.requestID = requestID +} + +func (s *StandaloneEventsAPI) SendInitStart(data interop.InitStartData) error { + record := map[string]interface{}{ + "initializationType": data.InitializationType, + "runtimeVersion": data.RuntimeVersion, + "runtimeArn": data.RuntimeVersionArn, + "runtimeVersionArn": data.RuntimeVersionArn, + "functionArn": data.FunctionArn, + "functionName": data.FunctionName, + "functionVersion": data.FunctionVersion, + "instanceId": data.InstanceID, + "instanceMaxMemory": data.InstanceMaxMemory, + "phase": data.Phase, + } + + s.addTracingToRecord(data.Tracing, record) + + return s.sendPlatformEvent(PlatformInitStart, record) +} + +func (s *StandaloneEventsAPI) SendInitRuntimeDone(data interop.InitRuntimeDoneData) error { + record := map[string]interface{}{ + "initializationType": data.InitializationType, + "status": data.Status, + "phase": data.Phase, + } + + s.addTracingToRecord(data.Tracing, record) + + if data.ErrorType != nil { + record["errorType"] = data.ErrorType + } + + return s.sendPlatformEvent(PlatformInitRuntimeDone, record) +} + +func (s *StandaloneEventsAPI) SendInitReport(data interop.InitReportData) error { + record := map[string]interface{}{ + "initializationType": data.InitializationType, + "metrics": data.Metrics, + "phase": data.Phase, + } + + s.addTracingToRecord(data.Tracing, record) + + return s.sendPlatformEvent(PlatformInitReport, record) +} + +func (s *StandaloneEventsAPI) SendRestoreRuntimeDone(data interop.RestoreRuntimeDoneData) error { + record := map[string]interface{}{"status": data.Status} + + s.addTracingToRecord(data.Tracing, record) + + if data.ErrorType != nil { + record["errorType"] = data.ErrorType + } + + return s.sendPlatformEvent(PlatformRestoreRuntimeDone, record) +} + +func (s *StandaloneEventsAPI) SendInvokeStart(data interop.InvokeStartData) error { + record := map[string]interface{}{ + "version": data.Version, + "requestId": data.RequestID, + } + + s.addTracingToRecord(data.Tracing, record) + + return s.sendPlatformEvent(PlatformStart, record) +} + +func (s *StandaloneEventsAPI) SendInvokeRuntimeDone(data interop.InvokeRuntimeDoneData) error { + record := map[string]interface{}{ + "requestId": s.requestID, + "status": data.Status, + "metrics": data.Metrics, + "internalMetrics": data.InternalMetrics, + "spans": data.Spans, + } + + if data.ErrorType != nil { + record["errorType"] = data.ErrorType + } + + s.addTracingToRecord(data.Tracing, record) + + return s.sendPlatformEvent(PlatformRuntimeDone, record) +} + +func (s *StandaloneEventsAPI) SendExtensionInit(data interop.ExtensionInitData) error { + sort.Strings(data.Subscriptions) + record := map[string]interface{}{ + "name": data.AgentName, + "state": data.State, + "events": data.Subscriptions, + } + if len(data.ErrorType) > 0 { + record["errorType"] = data.ErrorType + } + return s.sendPlatformEvent(PlatformExtension, record) +} + +func (s *StandaloneEventsAPI) SendImageErrorLog(interop.ImageErrorLogData) { + // Called on bootstrap exec errors for OCI error modes, e.g. InvalidEntrypoint etc. +} + +func (s *StandaloneEventsAPI) SendEnd(data interop.EndData) error { + record := map[string]interface{}{ + "requestId": data.RequestID, + } + + return s.sendPlatformEvent(PlatformEnd, record) +} + +func (s *StandaloneEventsAPI) SendReportSpan(interop.Span) error { + return nil +} + +func (s *StandaloneEventsAPI) SendReport(data interop.ReportData) error { + record := map[string]interface{}{ + "requestId": s.requestID, + "status": data.Status, + "metrics": data.Metrics, + "spans": data.Spans, + "tracing": data.Tracing, + } + if data.ErrorType != nil { + record["errorType"] = data.ErrorType + } + + return s.sendPlatformEvent(PlatformReport, record) +} + +func (s *StandaloneEventsAPI) SendFault(data interop.FaultData) error { + record := map[string]interface{}{ + "fault": data.String(), + } + + return s.sendPlatformEvent(PlatformFault, record) +} + +func (s *StandaloneEventsAPI) FetchTailLogs(string) (string, error) { + s.lock.Lock() + defer s.lock.Unlock() + + if len(s.eventLog.Events) == 0 { + return "", nil + } + + logs := tailLogs{Events: s.eventLog.Events} + logsBytes, err := json.Marshal(logs) + if err != nil { + return "", err + } + + s.eventLog.Events = nil + + return string(logsBytes), nil +} + +func (s *StandaloneEventsAPI) GetRuntimeDoneSpans( + runtimeStartedTime int64, + invokeResponseMetrics *interop.InvokeResponseMetrics, + runtimeOverheadStartedTime int64, + runtimeReadyTime int64, +) []interop.Span { + spans := telemetry.GetRuntimeDoneSpans(runtimeStartedTime, invokeResponseMetrics) + return spans +} + +func (s *StandaloneEventsAPI) sendPlatformEvent(eventType string, record map[string]interface{}) error { + e := SandboxEvent{ + Time: time.Now().Format(time.RFC3339), + Type: eventType, + PlatformEvent: record, + } + s.appendEvent(e) + s.logEvent(e) + return nil +} + +func (s *StandaloneEventsAPI) sendLogEvent(eventType, logMessage string) error { + e := SandboxEvent{ + Time: time.Now().Format(time.RFC3339), + Type: eventType, + LogMessage: logMessage, + } + s.appendEvent(e) + s.logEvent(e) + return nil +} + +func (s *StandaloneEventsAPI) appendEvent(event SandboxEvent) { + s.lock.Lock() + defer s.lock.Unlock() + s.eventLog.Events = append(s.eventLog.Events, event) +} + +func (s *StandaloneEventsAPI) logEvent(e SandboxEvent) { + log.WithField("event", e).Info("sandbox event") +} + +func (s *StandaloneEventsAPI) addTracingToRecord(tracingData *interop.TracingCtx, record map[string]interface{}) { + if tracingData != nil { + record["tracing"] = map[string]string{ + "spanId": tracingData.SpanID, + "type": string(tracingData.Type), + "value": tracingData.Value, + } + } +} diff --git a/lambda/rapidcore/standalone/telemetry/logs_egress_api.go b/lambda/rapidcore/standalone/telemetry/logs_egress_api.go new file mode 100644 index 0000000..0f42dd1 --- /dev/null +++ b/lambda/rapidcore/standalone/telemetry/logs_egress_api.go @@ -0,0 +1,26 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import "io" + +type StandaloneLogsEgressAPI struct { + api *StandaloneEventsAPI +} + +func NewStandaloneLogsEgressAPI(api *StandaloneEventsAPI) *StandaloneLogsEgressAPI { + return &StandaloneLogsEgressAPI{ + api: api, + } +} + +func (s *StandaloneLogsEgressAPI) GetExtensionSockets() (io.Writer, io.Writer, error) { + w := NewSandboxAgentWriter(s.api, "extension") + return w, w, nil +} + +func (s *StandaloneLogsEgressAPI) GetRuntimeSockets() (io.Writer, io.Writer, error) { + w := NewSandboxAgentWriter(s.api, "function") + return w, w, nil +} diff --git a/lambda/rapidcore/standalone/telemetry/structured_logger.go b/lambda/rapidcore/standalone/telemetry/structured_logger.go new file mode 100644 index 0000000..8d9382b --- /dev/null +++ b/lambda/rapidcore/standalone/telemetry/structured_logger.go @@ -0,0 +1,21 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "github.com/sirupsen/logrus" + "os" +) + +var log = getLogger() + +func getLogger() *logrus.Logger { + formatter := logrus.JSONFormatter{} + formatter.DisableTimestamp = true + logger := new(logrus.Logger) + logger.Out = os.Stdout + logger.Formatter = &formatter + logger.Level = logrus.InfoLevel + return logger +} diff --git a/lambda/rapidcore/standalone/telemetry/tracer.go b/lambda/rapidcore/standalone/telemetry/tracer.go new file mode 100644 index 0000000..ba7f32d --- /dev/null +++ b/lambda/rapidcore/standalone/telemetry/tracer.go @@ -0,0 +1,216 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package telemetry + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/metering" + "go.amzn.com/lambda/rapi/model" + "go.amzn.com/lambda/telemetry" + + "github.com/sirupsen/logrus" +) + +// InitSubsegmentName provides name attribute for Init subsegment +const InitSubsegmentName = "Initialization" + +// RestoreSubsegmentName provides name attribute for Restore subsegment +const RestoreSubsegmentName = "Restore" + +// InvokeSubsegmentName provides name attribute for Invoke subsegment +const InvokeSubsegmentName = "Invocation" + +// OverheadSubsegmentName provides name attribute for Overhead subsegment +const OverheadSubsegmentName = "Overhead" + +type StandaloneTracer struct { + startFunction func(ctx context.Context, invoke *interop.Invoke, segmentName string, timestamp int64) + endFunction func(ctx context.Context, invoke *interop.Invoke, segmentName string, timestamp int64) + invoke *interop.Invoke + tracingHeader string + rootTraceID string + parent string + sampled string + lineage string + invocationSubsegmentID string + initStartTime int64 + initEndTime int64 + restoreStartTime int64 + restoreEndTime int64 + restorePresent bool +} + +type TracingEvent struct { + Message string `json:"message"` + TraceID string `json:"trace_id"` + SegmentName string `json:"segment_name"` + SegmentID string `json:"segment_id"` + Timestamp int64 `json:"timestamp"` +} + +func (t *StandaloneTracer) Configure(invoke *interop.Invoke) { + t.invoke = invoke + t.tracingHeader = invoke.TraceID + t.invocationSubsegmentID = "" + t.rootTraceID, t.parent, t.sampled, t.lineage = telemetry.ParseTracingHeader(invoke.TraceID) + if invoke.RestoreDurationNs == 0 { + t.restorePresent = false + } else { + t.restorePresent = true + t.restoreStartTime = metering.MonoToEpoch(invoke.RestoreStartTimeMonotime) + t.restoreEndTime = t.restoreStartTime + invoke.RestoreDurationNs + } +} + +func (t *StandaloneTracer) CaptureInvokeSegment(ctx context.Context, criticalFunction func(context.Context) error) error { + return t.withStartAndEnd(ctx, criticalFunction, "STANDALONE_FUNCTION_NAME") +} + +func (t *StandaloneTracer) CaptureInitSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { + return t.withStartAndEnd(ctx, criticalFunction, InitSubsegmentName) +} + +func (t *StandaloneTracer) CaptureInvokeSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { + t.invocationSubsegmentID = InvokeSubsegmentName + return t.withStartAndEnd(ctx, criticalFunction, InvokeSubsegmentName) +} + +func (t *StandaloneTracer) CaptureOverheadSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { + return t.withStartAndEnd(ctx, criticalFunction, OverheadSubsegmentName) +} + +func (t *StandaloneTracer) withStartAndEnd(ctx context.Context, criticalFunction func(context.Context) error, segmentName string) error { + ctx = telemetry.NewTraceContext(ctx, t.rootTraceID, segmentName) + t.startFunction(ctx, t.invoke, segmentName, time.Now().UnixNano()) + err := criticalFunction(ctx) + t.endFunction(ctx, t.invoke, segmentName, time.Now().UnixNano()) + return err +} + +func (t *StandaloneTracer) RecordInitStartTime() { + t.initStartTime = time.Now().UnixNano() +} + +func (t *StandaloneTracer) RecordInitEndTime() { + t.initEndTime = time.Now().UnixNano() + +} + +func (t *StandaloneTracer) sendPrepSubsegment(ctx context.Context, subsegmentName string, startTime int64, endTime int64) { + ctx = telemetry.NewTraceContext(ctx, t.rootTraceID, subsegmentName) + t.startFunction(ctx, t.invoke, subsegmentName, startTime) + t.endFunction(ctx, t.invoke, subsegmentName, endTime) +} + +func (t *StandaloneTracer) SendInitSubsegmentWithRecordedTimesOnce(ctx context.Context) { + t.sendPrepSubsegment(ctx, InitSubsegmentName, t.initStartTime, t.initEndTime) +} +func (t *StandaloneTracer) SendRestoreSubsegmentWithRecordedTimesOnce(ctx context.Context) { + if t.restorePresent { + t.sendPrepSubsegment(ctx, RestoreSubsegmentName, t.restoreStartTime, t.restoreEndTime) + } +} +func (t *StandaloneTracer) MarkError(ctx context.Context) {} +func (t *StandaloneTracer) AttachErrorCause(ctx context.Context, errorCause json.RawMessage) {} + +func (t *StandaloneTracer) WithErrorCause(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error { + return criticalFunction +} +func (t *StandaloneTracer) WithError(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error { + return criticalFunction +} + +func (t *StandaloneTracer) BuildTracingHeader() func(ctx context.Context) string { + // extract root trace ID and parent from context and build the tracing header + return func(ctx context.Context) string { + var parent string + var ok bool + + if parent, ok = ctx.Value(telemetry.DocumentIDKey).(string); !ok || parent == "" { + return t.invoke.TraceID + } + + if t.rootTraceID == "" || t.sampled == "" { + return "" + } + + var tracingHeader = "Root=%s;Parent=%s;Sampled=%s" + + if t.lineage == "" { + return fmt.Sprintf(tracingHeader, t.rootTraceID, parent, t.sampled) + } + + return fmt.Sprintf(tracingHeader+";Lineage=%s", t.rootTraceID, parent, t.sampled, t.lineage) + } +} + +func (t *StandaloneTracer) BuildTracingCtxForStart() *interop.TracingCtx { + if t.rootTraceID == "" || t.sampled != model.XRaySampled { + return nil + } + + return &interop.TracingCtx{ + SpanID: t.parent, + Type: model.XRayTracingType, + Value: telemetry.BuildFullTraceID(t.rootTraceID, t.invoke.LambdaSegmentID, t.sampled), + } +} +func (t *StandaloneTracer) BuildTracingCtxAfterInvokeComplete() *interop.TracingCtx { + if t.rootTraceID == "" || t.sampled != model.XRaySampled || t.invocationSubsegmentID == "" { + return nil + } + + return &interop.TracingCtx{ + SpanID: t.invocationSubsegmentID, + Type: model.XRayTracingType, + Value: t.tracingHeader, + } +} + +func isTracingEnabled(root, parent, sampled string) bool { + return len(root) != 0 && len(parent) != 0 && sampled == "1" +} + +func NewStandaloneTracer(api *StandaloneEventsAPI) *StandaloneTracer { + startCaptureFn := func(ctx context.Context, i *interop.Invoke, segmentName string, timestamp int64) { + root, parent, sampled, _ := telemetry.ParseTracingHeader(i.TraceID) + if isTracingEnabled(root, parent, sampled) { + e := TracingEvent{ + Message: "START", + TraceID: root, + SegmentName: segmentName, + SegmentID: parent, + Timestamp: timestamp / int64(time.Millisecond), + } + api.LogTrace(e) + log.WithFields(logrus.Fields{"trace": e}).Info("sandbox trace") + } + } + + endCaptureFn := func(ctx context.Context, i *interop.Invoke, segmentName string, timestamp int64) { + root, parent, sampled, _ := telemetry.ParseTracingHeader(i.TraceID) + if isTracingEnabled(root, parent, sampled) { + e := TracingEvent{ + Message: "END", + TraceID: root, + SegmentName: "", + SegmentID: parent, + Timestamp: timestamp / int64(time.Millisecond), + } + api.LogTrace(e) + log.WithFields(logrus.Fields{"trace": e}).Info("sandbox trace") + } + } + + return &StandaloneTracer{ + startFunction: startCaptureFn, + endFunction: endCaptureFn, + } +} diff --git a/lambda/rapidcore/standalone/waitUntilReleaseHandler.go b/lambda/rapidcore/standalone/waitUntilReleaseHandler.go index 0a756dd..1caeb8c 100644 --- a/lambda/rapidcore/standalone/waitUntilReleaseHandler.go +++ b/lambda/rapidcore/standalone/waitUntilReleaseHandler.go @@ -10,7 +10,7 @@ import ( ) func WaitUntilReleaseHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { - internalState, err := s.AwaitRelease() + releaseAwait, err := s.AwaitRelease() if err != nil { switch err { case rapidcore.ErrInvokeDoneFailed: @@ -22,10 +22,10 @@ func WaitUntilReleaseHandler(w http.ResponseWriter, r *http.Request, s InteropSe return case rapidcore.ErrInitDoneFailed: w.WriteHeader(DoneFailedHTTPCode) - w.Write(internalState.AsJSON()) + w.Write(releaseAwait.AsJSON()) return } } - w.Write(internalState.AsJSON()) + w.Write(releaseAwait.AsJSON()) } diff --git a/lambda/rapidcore/telemetry/eventLog.go b/lambda/rapidcore/telemetry/eventLog.go deleted file mode 100644 index 2f809fa..0000000 --- a/lambda/rapidcore/telemetry/eventLog.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import ( - "strings" - "sync" - "time" -) - -// TODO: Refactor to represent event structs below as a form of Events API entity - -type XrayEvent struct { - Msg string `json:"msg"` - TraceID string `json:"traceID"` - SegmentName string `json:"segmentName"` - SegmentID string `json:"segmentID"` - Timestamp int64 `json:"timestamp"` -} - -// PlatformLogEvent represents a platform-generated customer log entry -type PlatformLogEvent struct { - Name string `json:"name"` - State string `json:"state"` - ErrorType string `json:"errorType"` - Subscriptions []string `json:"subscriptions"` -} - -// FunctionLogEvent represents a runtime-generated customer log entry -type FunctionLogEvent struct{} - -// ExtensionLogEvent represents an agent-generated customer log entry -type ExtensionLogEvent struct{} - -type EventLog struct { - Events []SandboxEvent `json:"events,omitempty"` // populated by the StandaloneEventLog object - Xray []XrayEvent `json:"xray,omitempty"` - PlatformLog []PlatformLogEvent `json:"platformLogs,omitempty"` - Logs []string `json:"rawLogs,omitempty"` - mutex sync.Mutex -} - -func parseLogString(s string) []string { - elems := strings.Split(s, "\t")[1:] - for i, e := range elems { - elems[i] = strings.Split(e, ": ")[1] - elems[i] = strings.TrimSuffix(elems[i], "\n") - elems[i] = strings.TrimPrefix(elems[i], "[") - elems[i] = strings.TrimSuffix(elems[i], "]") - } - return elems -} - -func (p *EventLog) dispatchLogEvent(logStr string) { - elems := parseLogString(logStr) - if strings.HasPrefix(logStr, "XRAY") { - // format: 'XRAY\tMessage: %s\tTraceID: %s\tSegmentName: %s\tSegmentID: %s' - msg, traceID, segmentName, segmentID := elems[0], elems[1], elems[2], elems[3] - p.Xray = append(p.Xray, XrayEvent{Msg: msg, TraceID: traceID, SegmentName: segmentName, SegmentID: segmentID, Timestamp: time.Now().UnixNano() / int64(time.Millisecond)}) - } -} - -func (p *EventLog) Write(logline []byte) (int, error) { - p.mutex.Lock() - defer p.mutex.Unlock() - - logStr := string(logline) - p.Logs = append(p.Logs, logStr) - - p.dispatchLogEvent(logStr) - - return len(logline), nil -} - -func NewEventLog() *EventLog { - return &EventLog{} -} diff --git a/lambda/rapidcore/telemetry/events_api.go b/lambda/rapidcore/telemetry/events_api.go deleted file mode 100644 index 7a882fd..0000000 --- a/lambda/rapidcore/telemetry/events_api.go +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import ( - "sort" - "time" - - "go.amzn.com/lambda/telemetry" -) - -// EventType indicates the type of SandboxEvent. See full list: -type EventType = string - -const ( - PlatformInitRuntimeDone = EventType("platform.initRuntimeDone") - PlatformRestoreRuntimeDone = EventType("platform.restoreRuntimeDone") - PlatformRuntimeDone = EventType("platform.runtimeDone") - PlatformExtension = EventType("platform.extension") -) - -/* - SandboxEvent represents a generic sandbox event. For example: - {'time': '2021-03-16T13:10:42.358Z', - 'type': 'platform.extension', - 'record': { "name": "foo bar", "state": "Ready", "events": ["INVOKE", "SHUTDOWN"]}} -*/ -type SandboxEvent struct { - Time string `json:"time"` - Type EventType `json:"type"` - Record map[string]interface{} `json:"record"` -} - -type StandaloneEventLog struct { - requestID string - eventLog *EventLog -} - -func (s *StandaloneEventLog) SetCurrentRequestID(requestID string) { - s.requestID = requestID -} - -func (s *StandaloneEventLog) SendInitRuntimeDone(data *telemetry.InitRuntimeDoneData) error { - record := map[string]interface{}{"initializationType": data.InitSource, "status": data.Status} - s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformInitRuntimeDone, record}) - return nil -} - -func (s *StandaloneEventLog) SendRestoreRuntimeDone(status string) error { - record := map[string]interface{}{"status": status} - s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformRestoreRuntimeDone, record}) - return nil -} - -func (s *StandaloneEventLog) SendRuntimeDone(data telemetry.InvokeRuntimeDoneData) error { - // e.g. 'record': {'requestId': '1506eb3053d148f3bb7ec0fabe6f8d91','status': 'success', 'metrics': {...}, 'tracing': {...}} - record := map[string]interface{}{ - "requestId": s.requestID, - "status": data.Status, - "metrics": data.Metrics, - "internalMetrics": data.InternalMetrics, - "spans": data.Spans, - } - - if data.Tracing != nil { - record["tracing"] = map[string]string{ - "spanId": data.Tracing.SpanID, - "type": string(data.Tracing.Type), - "value": data.Tracing.Value, - } - } - - s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformRuntimeDone, record}) - return nil -} - -func (s *StandaloneEventLog) SendExtensionInit(agentName, state, errorType string, subscriptions []string) error { - // e.g. 'record': { "name": "", "state": "", errorType: "", events: [""] } - sort.Strings(subscriptions) - record := map[string]interface{}{"name": agentName, "state": state, "events": subscriptions} - if len(errorType) > 0 { - record["errorType"] = errorType - } - s.eventLog.Events = append(s.eventLog.Events, SandboxEvent{time.Now().Format(time.RFC3339), PlatformExtension, record}) - return nil -} - -func (s *StandaloneEventLog) SendImageErrorLog(logline string) { - // Called on bootstrap exec errors for OCI error modes, e.g. InvalidEntrypoint etc. -} - -func NewStandaloneEventLog(eventLog *EventLog) *StandaloneEventLog { - return &StandaloneEventLog{ - eventLog: eventLog, - } -} diff --git a/lambda/rapidcore/telemetry/xray.go b/lambda/rapidcore/telemetry/xray.go deleted file mode 100644 index d7a6842..0000000 --- a/lambda/rapidcore/telemetry/xray.go +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import ( - "context" - "encoding/json" - "fmt" - "io" - - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/telemetry" -) - -// InitSubsegmentName provides name attribute for Init subsegment -const InitSubsegmentName = "Initialization" - -// InvokeSubsegmentName provides name attribute for Invoke subsegment -const InvokeSubsegmentName = "Invocation" - -// OverheadSubsegmentName provides name attribute for Overhead subsegment -const OverheadSubsegmentName = "Overhead" - -type traceContextKey int - -const ( - traceIDKey traceContextKey = iota - documentIDKey -) - -type StandaloneTracer struct { - startFunction func(ctx context.Context, invoke *interop.Invoke, segmentName string) - endFunction func(ctx context.Context, invoke *interop.Invoke, segmentName string) - functionName string - invoke *interop.Invoke -} - -func (t *StandaloneTracer) Configure(invoke *interop.Invoke) { - - t.invoke = invoke -} - -func (t *StandaloneTracer) CaptureInvokeSegment(ctx context.Context, criticalFunction func(context.Context) error) error { - return t.withStartAndEnd(ctx, criticalFunction, t.functionName) -} - -func (t *StandaloneTracer) CaptureInitSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { - return t.withStartAndEnd(ctx, criticalFunction, InitSubsegmentName) -} - -func (t *StandaloneTracer) CaptureInvokeSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { - return t.withStartAndEnd(ctx, criticalFunction, InvokeSubsegmentName) -} - -func (t *StandaloneTracer) CaptureOverheadSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { - return t.withStartAndEnd(ctx, criticalFunction, OverheadSubsegmentName) -} - -func (t *StandaloneTracer) withStartAndEnd(ctx context.Context, criticalFunction func(context.Context) error, segmentName string) error { - t.startFunction(ctx, t.invoke, segmentName) - err := criticalFunction(ctx) - t.endFunction(ctx, t.invoke, segmentName) - return err -} - -func (t *StandaloneTracer) RecordInitStartTime() {} -func (t *StandaloneTracer) RecordInitEndTime() {} -func (t *StandaloneTracer) SendInitSubsegmentWithRecordedTimesOnce(ctx context.Context) {} -func (t *StandaloneTracer) MarkError(ctx context.Context) {} -func (t *StandaloneTracer) AttachErrorCause(ctx context.Context, errorCause json.RawMessage) {} - -func (t *StandaloneTracer) WithErrorCause(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error { - return criticalFunction -} -func (t *StandaloneTracer) WithError(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error { - return criticalFunction -} -func (t *StandaloneTracer) TracingHeaderParser() func(context.Context, *interop.Invoke) string { - getCustomerTracingHeader := func(ctx context.Context, invoke *interop.Invoke) string { - var root, parent string - var ok bool - - if root, ok = ctx.Value(traceIDKey).(string); !ok { - return invoke.TraceID - } - - if parent, ok = ctx.Value(documentIDKey).(string); !ok { - return invoke.TraceID - } - - return fmt.Sprintf("Root=%s;Parent=%s;Sampled=1", root, parent) - } - - return getCustomerTracingHeader -} - -func isTracingEnabled(root, parent, sampled string) bool { - return len(root) != 0 && len(parent) != 0 && sampled == "1" -} - -func NewStandaloneTracer(eventLog io.Writer, functionName string) *StandaloneTracer { - traceFormat := "XRAY\tMessage: %s\tTraceID: %s\tSegmentName: %s\tSegmentID: %s" - startCaptureFn := func(ctx context.Context, i *interop.Invoke, segmentName string) { - root, parent, sampled := telemetry.ParseTraceID(i.TraceID) - if isTracingEnabled(root, parent, sampled) { - fmt.Fprintf(eventLog, traceFormat, "START", root, segmentName, parent) - } - } - - endCaptureFn := func(ctx context.Context, i *interop.Invoke, segmentName string) { - root, parent, sampled := telemetry.ParseTraceID(i.TraceID) - if isTracingEnabled(root, parent, sampled) { - fmt.Fprintf(eventLog, traceFormat, "END", root, "", parent) - } - } - - return &StandaloneTracer{ - startFunction: startCaptureFn, - endFunction: endCaptureFn, - functionName: functionName, - } -} diff --git a/lambda/supervisor/local_supervisor.go b/lambda/supervisor/local_supervisor.go index 1174089..4405686 100644 --- a/lambda/supervisor/local_supervisor.go +++ b/lambda/supervisor/local_supervisor.go @@ -4,9 +4,11 @@ package supervisor import ( + "context" "errors" "fmt" "os/exec" + "runtime" "sync" "syscall" "time" @@ -27,33 +29,31 @@ type process struct { } type LocalSupervisor struct { - events chan model.Event - processMapLock sync.Mutex - processMap map[string]process + events chan model.Event + processMapLock sync.Mutex + processMap map[string]process + freezeThawCycleStart time.Time + + RootPath string } -func NewLocalSupervisor() model.Supervisor { - return model.Supervisor{ - SupervisorClient: &LocalSupervisor{ - events: make(chan model.Event), - processMap: make(map[string]process), - }, - OperatorConfig: model.DomainConfig{ - RootPath: "/", - }, - RuntimeConfig: model.DomainConfig{ - RootPath: "/", - }, +func NewLocalSupervisor() *LocalSupervisor { + return &LocalSupervisor{ + events: make(chan model.Event), + processMap: make(map[string]process), + RootPath: "/", } } -func (*LocalSupervisor) Start(req *model.StartRequest) error { +func (*LocalSupervisor) Start(ctx context.Context, req *model.StartRequest) error { return nil } -func (*LocalSupervisor) Configure(req *model.ConfigureRequest) error { +func (*LocalSupervisor) Configure(ctx context.Context, req *model.ConfigureRequest) error { return nil } -func (s *LocalSupervisor) Exec(req *model.ExecRequest) error { +func (*LocalSupervisor) Exit(ctx context.Context) {} + +func (s *LocalSupervisor) Exec(ctx context.Context, req *model.ExecRequest) error { if req.Domain != "runtime" { log.Debug("Exec is a no op if domain != runtime") return nil @@ -97,6 +97,9 @@ func (s *LocalSupervisor) Exec(req *model.ExecRequest) error { } s.processMapLock.Unlock() + // The first freeze thaw cycle starts on Exec() at init time + s.freezeThawCycleStart = time.Now() + go func() { err = command.Wait() // close the termination channel to unblock whoever's blocked on @@ -141,11 +144,11 @@ func (s *LocalSupervisor) Exec(req *model.ExecRequest) error { return nil } -func kill(p process, name string, timeout *time.Duration) error { +func kill(p process, name string, deadline time.Time) error { // kill should report success if the process terminated by the time //supervisor receives the request. select { - // ifthis case is selected, the channel is closed, + // if this case is selected, the channel is closed, // which means the process is terminated case <-p.termination: log.Debugf("Process %s already terminated.", name) @@ -154,8 +157,8 @@ func kill(p process, name string, timeout *time.Duration) error { log.Infof("Sending SIGKILL to %s(%d).", name, p.pid) } - if timeout != nil && *timeout <= 0 { - return fmt.Errorf("Timed out while trying to SIGKILL %s", name) + if (time.Since(deadline)) > 0 { + return fmt.Errorf("invalid timeout while killing %s", name) } pgid, err := syscall.Getpgid(p.pid) @@ -167,23 +170,20 @@ func kill(p process, name string, timeout *time.Duration) error { syscall.Kill(p.pid, syscall.SIGKILL) } - // the nil channel blocks forever - var timer <-chan time.Time - if timeout != nil { - timer = time.After(*timeout) - } + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() // block until the (main) process exits // or the timeout fires select { case <-p.termination: return nil - case <-timer: - return fmt.Errorf("Timed out while trying to SIGKILL %s", name) + case <-ctx.Done(): + return fmt.Errorf("timed out while trying to SIGKILL %s", name) } } -func (s *LocalSupervisor) Kill(req *model.KillRequest) error { +func (s *LocalSupervisor) Kill(ctx context.Context, req *model.KillRequest) error { if req.Domain != "runtime" { log.Debug("Kill is a no op if domain != runtime") return nil @@ -198,12 +198,11 @@ func (s *LocalSupervisor) Kill(req *model.KillRequest) error { Message: &msg, } } - timeout := convertTimeout(req.Timeout) - return kill(process, req.Name, timeout) + return kill(process, req.Name, req.Deadline) } -func (s *LocalSupervisor) Terminate(req *model.TerminateRequest) error { +func (s *LocalSupervisor) Terminate(ctx context.Context, req *model.TerminateRequest) error { if req.Domain != "runtime" { log.Debug("Terminate is no op if domain != runtime") return nil @@ -235,12 +234,11 @@ func (s *LocalSupervisor) Terminate(req *model.TerminateRequest) error { return nil } -func (s *LocalSupervisor) Stop(req *model.StopRequest) error { +func (s *LocalSupervisor) Stop(ctx context.Context, req *model.StopRequest) (*model.StopResponse, error) { if req.Domain != "runtime" { log.Debug("Shutdown is no op if domain != runtime") - return nil + return &model.StopResponse{}, nil } - timeout := convertTimeout(req.Timeout) // shut down kills all the processes in the map s.processMapLock.Lock() @@ -253,7 +251,7 @@ func (s *LocalSupervisor) Stop(req *model.StopRequest) error { for name, proc := range s.processMap { go func(n string, p process) { log.Debugf("Killing %s", n) - err := kill(p, n, timeout) + err := kill(p, n, req.Deadline) if err != nil { errors <- err } else { @@ -269,34 +267,37 @@ func (s *LocalSupervisor) Stop(req *model.StopRequest) error { case <-successes: case e := <-errors: if err == nil { - err = fmt.Errorf("Shutdown failed: %s", e.Error()) + err = fmt.Errorf("shutdown failed: %s", e.Error()) } } } s.processMap = make(map[string]process) - return err + return nil, err } -func (*LocalSupervisor) Freeze(req *model.FreezeRequest) error { - return nil + +func (s *LocalSupervisor) Freeze(ctx context.Context, req *model.FreezeRequest) (*model.FreezeResponse, error) { + // We return mocked freeze/thaw cycle metrics to mimic usage metrics in standalone mode + var m runtime.MemStats + runtime.ReadMemStats(&m) + return &model.FreezeResponse{ + CycleDeltaMetrics: model.CycleDeltaMetrics{ + DomainCPURunNs: uint64(time.Since(s.freezeThawCycleStart).Nanoseconds()), + DomainRunNs: uint64(time.Since(s.freezeThawCycleStart).Nanoseconds()), + DomainMaxMemoryUsageBytes: m.Alloc, + MicrovmCPURunNs: uint64(time.Since(s.freezeThawCycleStart).Nanoseconds()), + }, + }, nil } -func (*LocalSupervisor) Thaw(req *model.ThawRequest) error { +func (s *LocalSupervisor) Thaw(ctx context.Context, req *model.ThawRequest) error { + s.freezeThawCycleStart = time.Now() return nil } -func (s *LocalSupervisor) Ping() error { +func (s *LocalSupervisor) Ping(ctx context.Context) error { return nil } -func (s *LocalSupervisor) Events() (<-chan model.Event, error) { +func (s *LocalSupervisor) Events(ctx context.Context, req *model.EventsRequest) (<-chan model.Event, error) { return s.events, nil } - -func convertTimeout(millis *uint64) *time.Duration { - var timeout *time.Duration - if millis != nil { - t := time.Duration(*millis) * time.Millisecond - timeout = &t - } - return timeout -} diff --git a/lambda/supervisor/local_supervisor_test.go b/lambda/supervisor/local_supervisor_test.go index 8b3336b..02a06f6 100644 --- a/lambda/supervisor/local_supervisor_test.go +++ b/lambda/supervisor/local_supervisor_test.go @@ -4,6 +4,7 @@ package supervisor import ( + "context" "errors" "fmt" "syscall" @@ -18,7 +19,7 @@ import ( func TestRuntimeDomainExec(t *testing.T) { supv := NewLocalSupervisor() - err := supv.Exec(&model.ExecRequest{ + err := supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent", Path: "/bin/bash", @@ -29,7 +30,7 @@ func TestRuntimeDomainExec(t *testing.T) { func TestInvalidRuntimeDomainExec(t *testing.T) { supv := NewLocalSupervisor() - err := supv.Exec(&model.ExecRequest{ + err := supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent", Path: "/bin/none", @@ -40,10 +41,14 @@ func TestInvalidRuntimeDomainExec(t *testing.T) { func TestEvents(t *testing.T) { supv := NewLocalSupervisor() - client := supv.SupervisorClient.(*LocalSupervisor) sync := make(chan struct{}) go func() { - evt, ok := <-client.events + eventCh, err := supv.Events(context.Background(), &model.EventsRequest{ + Domain: "runtime", + }) + require.NoError(t, err) + + evt, ok := <-eventCh require.True(t, ok) termination := evt.Event.ProcessTerminated() require.NotNil(t, termination) @@ -52,7 +57,7 @@ func TestEvents(t *testing.T) { sync <- struct{}{} }() - err := supv.Exec(&model.ExecRequest{ + err := supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent", Path: "/bin/bash", @@ -63,8 +68,7 @@ func TestEvents(t *testing.T) { func TestTerminate(t *testing.T) { supv := NewLocalSupervisor() - client := supv.SupervisorClient.(*LocalSupervisor) - err := supv.Exec(&model.ExecRequest{ + err := supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent", Path: "/bin/bash", @@ -72,13 +76,18 @@ func TestTerminate(t *testing.T) { }) require.NoError(t, err) time.Sleep(100 * time.Millisecond) - err = supv.Terminate(&model.TerminateRequest{ + err = supv.Terminate(context.Background(), &model.TerminateRequest{ Domain: "runtime", Name: "agent", }) require.NoError(t, err) // wait for process exit notification - ev := <-client.events + eventCh, err := supv.Events(context.Background(), &model.EventsRequest{ + Domain: "runtime", + }) + require.NoError(t, err) + ev := <-eventCh + require.NotNil(t, ev.Event.ProcessTerminated()) term := *ev.Event.ProcessTerminated() require.Nil(t, term.Exited()) @@ -89,7 +98,7 @@ func TestTerminate(t *testing.T) { // Termiante should not fail if the message is not delivered func TestTerminateExited(t *testing.T) { supv := NewLocalSupervisor() - err := supv.Exec(&model.ExecRequest{ + err := supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent", Path: "/bin/bash", @@ -97,7 +106,7 @@ func TestTerminateExited(t *testing.T) { require.NoError(t, err) // wait a short bit for bash to exit time.Sleep(100 * time.Millisecond) - err = supv.Terminate(&model.TerminateRequest{ + err = supv.Terminate(context.Background(), &model.TerminateRequest{ Domain: "runtime", Name: "agent", }) @@ -106,22 +115,27 @@ func TestTerminateExited(t *testing.T) { func TestKill(t *testing.T) { supv := NewLocalSupervisor() - client := supv.SupervisorClient.(*LocalSupervisor) - err := supv.Exec(&model.ExecRequest{ + err := supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent", Path: "/bin/bash", Args: []string{"-c", "sleep 10s"}, }) require.NoError(t, err) - err = supv.Kill(&model.KillRequest{ - Domain: "runtime", - Name: "agent", + err = supv.Kill(context.Background(), &model.KillRequest{ + Domain: "runtime", + Name: "agent", + Deadline: time.Now().Add(time.Second), }) require.NoError(t, err) timer := time.NewTimer(50 * time.Millisecond) + eventCh, err := supv.Events(context.Background(), &model.EventsRequest{ + Domain: "runtime", + }) + require.NoError(t, err) + select { - case _, ok := <-client.events: + case _, ok := <-eventCh: assert.True(t, ok) case <-timer.C: require.Fail(t, "Process should have exited by the time kill returns") @@ -130,27 +144,32 @@ func TestKill(t *testing.T) { func TestKillExited(t *testing.T) { supv := NewLocalSupervisor() - client := supv.SupervisorClient.(*LocalSupervisor) - err := supv.Exec(&model.ExecRequest{ + err := supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent", Path: "/bin/bash", }) require.NoError(t, err) //wait for natural exit event - <-client.events - err = supv.Kill(&model.KillRequest{ + eventCh, err := supv.Events(context.Background(), &model.EventsRequest{ Domain: "runtime", - Name: "agent", + }) + require.NoError(t, err) + <-eventCh + err = supv.Kill(context.Background(), &model.KillRequest{ + Domain: "runtime", + Name: "agent", + Deadline: time.Now().Add(time.Second), }) require.NoError(t, err, "Kill should succeed for exited processes") } func TestKillUnknown(t *testing.T) { supv := NewLocalSupervisor() - err := supv.Kill(&model.KillRequest{ - Domain: "runtime", - Name: "unknown", + err := supv.Kill(context.Background(), &model.KillRequest{ + Domain: "runtime", + Name: "unknown", + Deadline: time.Now().Add(time.Second), }) require.Error(t, err) var supvError *model.SupervisorError @@ -160,10 +179,9 @@ func TestKillUnknown(t *testing.T) { func TestShutdown(t *testing.T) { supv := NewLocalSupervisor() - client := supv.SupervisorClient.(*LocalSupervisor) log.Debug("hello") // start a bunch of processes, some short running, some longer running - err := supv.Exec(&model.ExecRequest{ + err := supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent-0", Path: "/bin/bash", @@ -171,14 +189,14 @@ func TestShutdown(t *testing.T) { }) require.NoError(t, err) - err = supv.Exec(&model.ExecRequest{ + err = supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent-1", Path: "/bin/bash", }) require.NoError(t, err) - err = supv.Exec(&model.ExecRequest{ + err = supv.Exec(context.Background(), &model.ExecRequest{ Domain: "runtime", Name: "agent-2", Path: "/bin/bash", @@ -186,8 +204,9 @@ func TestShutdown(t *testing.T) { }) require.NoError(t, err) time.Sleep(100 * time.Millisecond) - err = supv.Stop(&model.StopRequest{ - Domain: "runtime", + _, err = supv.Stop(context.Background(), &model.StopRequest{ + Domain: "runtime", + Deadline: time.Now().Add(time.Second), }) require.NoError(t, err) // Shutdown is expected to block untill all processes have exited @@ -198,9 +217,13 @@ func TestShutdown(t *testing.T) { } done := false timer := time.NewTimer(200 * time.Millisecond) + eventCh, err := supv.Events(context.Background(), &model.EventsRequest{ + Domain: "runtime", + }) + require.NoError(t, err) for !done { select { - case ev := <-client.events: + case ev := <-eventCh: data := ev.Event.ProcessTerminated() assert.NotNil(t, data) _, ok := expected[*data.Name] diff --git a/lambda/supervisor/model/model.go b/lambda/supervisor/model/model.go index 384726d..d89ec18 100644 --- a/lambda/supervisor/model/model.go +++ b/lambda/supervisor/model/model.go @@ -4,41 +4,73 @@ package model import ( + "context" "encoding/json" "fmt" "io" "os" "syscall" + "time" ) -type Supervisor struct { - SupervisorClient - OperatorConfig DomainConfig - RuntimeConfig DomainConfig +// Start, Stop and Configure methods are not used in Core anymore. +// Client interface splitted into Launcher and Executer parts for backward compatibility of dependent packages. +type ContainerSupervisor interface { + Start(context.Context, *StartRequest) error + Configure(context.Context, *ConfigureRequest) error + Stop(context.Context, *StopRequest) (*StopResponse, error) + Freeze(context.Context, *FreezeRequest) (*FreezeResponse, error) + Thaw(context.Context, *ThawRequest) error + Exit(context.Context) } -type DomainConfig struct { - // path to the root of the domain within the root mnt namespace - RootPath string +type ProcessSupervisor interface { + Exec(context.Context, *ExecRequest) error + Terminate(context.Context, *TerminateRequest) error + Kill(context.Context, *KillRequest) error + Events(context.Context, *EventsRequest) (<-chan Event, error) } type SupervisorClient interface { - Start(req *StartRequest) error - Configure(req *ConfigureRequest) error - Exec(req *ExecRequest) error - Terminate(req *TerminateRequest) error - Kill(req *KillRequest) error - Stop(req *StopRequest) error - Freeze(req *FreezeRequest) error - Thaw(req *ThawRequest) error - Ping() error - Events() (<-chan Event, error) + ContainerSupervisor + ProcessSupervisor + Ping(ctx context.Context) error } type StartRequest struct { Domain string `json:"domain"` - // name of the cgroup profile to start the domain in - CgroupProfile *string `json:"cgroup_profile,omitempty"` +} + +type Mount struct { + DriveMount DriveMount + BindMount BindMount + MountType MountType +} + +type MountType int + +const ( + _ MountType = iota + MountTypeDrive + MountTypeBind +) + +type CgroupProfileName string + +const ( + Throttled CgroupProfileName = "throttled" + Unthrottled CgroupProfileName = "unthrottled" +) + +func (m *Mount) MarshalJSON() ([]byte, error) { + switch m.MountType { + case MountTypeDrive: + return m.DriveMount.MarshalJSON() + case MountTypeBind: + return m.BindMount.MarshalJSON() + default: + return nil, fmt.Errorf("invalid mount type: %v", m.MountType) + } } // Mount in lockhard::mnt is a Rust enum, an algebraic type, where each case has different set of fields. @@ -66,6 +98,24 @@ func (m *DriveMount) MarshalJSON() ([]byte, error) { }) } +type BindMount struct { + Source string `json:"source,omitempty"` + Destination string `json:"destination,omitempty"` + Options []string `json:"options,omitempty"` +} + +func (m *BindMount) MarshalJSON() ([]byte, error) { + type bindMountAlias BindMount + + return json.Marshal(&struct { + Type string `json:"type,omitempty"` + *bindMountAlias + }{ + Type: "bind", + bindMountAlias: (*bindMountAlias)(m), + }) +} + type Capabilities struct { Ambient []string `json:"ambient,omitempty"` Bounding []string `json:"bounding,omitempty"` @@ -74,10 +124,14 @@ type Capabilities struct { Permitted []string `json:"permitted,omitempty"` } -type CgroupProfile struct { - Name string `json:"name"` - CPUPct *float64 `json:"cpu_pct,omitempty"` - MemMaxBytes *uint64 `json:"mem_max,omitempty"` +type CgroupProfiles struct { + Throttled CgroupProfileConfig `json:"throttled"` + Unthrottled CgroupProfileConfig `json:"unthrottled"` +} + +type CgroupProfileConfig struct { + CPULimit float64 `json:"cpu_limit"` + MemoryLimitBytes uint64 `json:"memory_limit_bytes"` } type ExecUser struct { @@ -88,12 +142,15 @@ type ExecUser struct { type ConfigureRequest struct { // domain to configure Domain string `json:"domain"` - Mounts []DriveMount `json:"mounts,omitempty"` + Mounts []Mount `json:"mounts,omitempty"` Capabilities *Capabilities `json:"capabilities,omitempty"` SeccompFilters []string `json:"seccomp_filters,omitempty"` // list of cgroup profiles available for the domain - // cgroup profiles are set on boot or thaw requests - CgroupProfiles []CgroupProfile `json:"cgroup_profiles,omitempty"` + // cgroup profiles are set on start and thaw request. Start profile + // if configured (as it can vary), thaw profile is always the same (throttled) + CgroupProfiles *CgroupProfiles `json:"cgroup_profiles,omitempty"` + // name of the cgroup profile to enforce at domain start + StartProfile CgroupProfileName `json:"start_profile,omitempty"` // uid and gid of the user the spawned process runs as (w.r.t. the domain user namespace). // If nil, Supervisor will use the ExecUser specified in the domain configuration file ExecUser *ExecUser `json:"exec_user,omitempty"` @@ -101,6 +158,10 @@ type ConfigureRequest struct { AdditionalStartHooks []Hook `json:"additional_start_hooks,omitempty"` } +type EventsRequest struct { + Domain string `json:"domain"` +} + type Event struct { Time uint64 `json:"timestamp_millis"` Event EventData `json:"event"` @@ -188,9 +249,6 @@ type Hook struct { Args []string `json:"args,omitempty"` // Map of ENV variables to set when running the hook Env *map[string]string `json:"envs,omitempty"` - // Maximum time for the hook to run. The hook will be considered failed - // if it takes more than this value (default 10_000) - TimeoutMillis *uint64 `json:"timeout_millis,omitempty"` } type ExecRequest struct { @@ -203,16 +261,38 @@ type ExecRequest struct { Path string `json:"path"` Args []string `json:"args,omitempty"` // If nil, root of the domain - Cwd *string `json:"cwd,omitempty"` - Env *map[string]string `json:"env,omitempty"` - // If not nil, points to the socket that Supervisor - // uses to get the processes stdout and stderr. - LogsSock *string `json:"logs_sock,omitempty"` - StdoutWriter io.Writer `json:"-"` - StderrWriter io.Writer `json:"-"` - ExtraFiles *[]*os.File `json:"-"` + Cwd *string `json:"cwd,omitempty"` + Env *map[string]string `json:"env,omitempty"` + Logging Logging `json:"log_config"` + StdoutWriter io.Writer `json:"-"` + StderrWriter io.Writer `json:"-"` + ExtraFiles *[]*os.File `json:"-"` +} + +// Logging specifies where Supervisor should send Command's logs to +type Logging struct { + Managed ManagedLogging `json:"managed"` } +type ManagedLogging struct { + Topic ManagedLoggingTopic `json:"topic"` + Formats []ManagedLoggingFormat `json:"formats"` +} + +type ManagedLoggingTopic string + +const ( + RuntimeManagedLoggingTopic ManagedLoggingTopic = "runtime" + RtExtensionManagedLoggingTopic ManagedLoggingTopic = "runtime_extension" +) + +type ManagedLoggingFormat string + +const ( + LineBasedManagedLogging ManagedLoggingFormat = "line" + MessageBasedManagedLogging ManagedLoggingFormat = "message" +) + type ErrorKind string const ( @@ -243,27 +323,54 @@ type TerminateRequest struct { // Force terminate a process (SIGKILL) // Block until process is exited or timeout -// If timeout is 0 or nil, block forever +// Deadline needs to be in the future type KillRequest struct { - Name string `json:"name"` - Domain string `json:"domain"` - Timeout *uint64 `json:",omitempty"` + Name string `json:"name"` + Domain string `json:"domain"` + Deadline time.Time `json:"deadline"` } -// Stop the domain. Supervisor will first try to -// cleanly terminate the domain's init process. If unsuccessful, -// within Timeout seconds, it will send SIGKILL. +// Stop the domain. type StopRequest struct { - Domain string `json:"domain"` - Timeout *uint64 `json:",omitempty"` + Domain string `json:"domain"` + Deadline time.Time `json:"deadline"` +} + +type StopResponse struct { + CycleDeltaMetrics CycleDeltaMetrics `json:"cycle_delta_metrics"` } type FreezeRequest struct { Domain string `json:"domain"` } +type FreezeResponse struct { + CycleDeltaMetrics CycleDeltaMetrics `json:"cycle_delta_metrics"` +} + +type MicrovmNetworkInterfaceMetrics struct { + ReceivedBytes uint64 `json:"received_bytes"` + TransmittedBytes uint64 `json:"transmitted_bytes"` +} + +type CycleDeltaMetrics struct { + // CPU time (in nanoseconds) obtained by domain cgroup from cpuacct.usage + // https://www.kernel.org/doc/Documentation/cgroup-v1/cpuacct.txt + DomainCPURunNs uint64 `json:"domain_cpu_run_ns"` + // time (in nanoseconds) for domain cycle + DomainRunNs uint64 `json:"domain_run_ns"` + // CPU delta time for service cgroup + ServiceCPURunNs uint64 `json:"service_cpu_run_ns"` + // Maximum memory used (in bytes) for domain + DomainMaxMemoryUsageBytes uint64 `json:"domain_max_memory_usage_bytes"` + // CPU delta time (in nanoseconds) obtained from /sys/fs/cgroup/cpu,cpuacct/cpuacct.usage + MicrovmCPURunNs uint64 `json:"microvm_cpu_run_ns"` + // Map with network interface name as key and network metrics as a value + MicrovmNetworksBytes map[string]MicrovmNetworkInterfaceMetrics `json:"microvm_network_interfaces"` + // time ( in nanoseconds ) for idle cpu time + InvokeIdleCPURunNs uint64 `json:"idle_cpu_run_ns"` +} + type ThawRequest struct { Domain string `json:"domain"` - // if not nil, changes the cgroup profile of the domain upon thawing. - CgroupProfile *string `json:"cgroup_profile,omitempty"` } diff --git a/lambda/supervisor/model/model_test.go b/lambda/supervisor/model/model_test.go new file mode 100644 index 0000000..ea39580 --- /dev/null +++ b/lambda/supervisor/model/model_test.go @@ -0,0 +1,31 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "encoding/json" + "testing" + "time" +) + +// LockHard accepts deadlines encoded as RFC3339 - we enforce this with a test +func Test_KillDeadlineIsMarshalledIntoRFC3339(t *testing.T) { + deadline, err := time.Parse(time.RFC3339, "2022-12-21T10:00:00Z") + if err != nil { + t.Error(err) + } + k := KillRequest{ + Name: "", + Domain: "", + Deadline: deadline, + } + bytes, err := json.Marshal(k) + if err != nil { + t.Error(err) + } + exepected := `{"name":"","domain":"","deadline":"2022-12-21T10:00:00Z"}` + if string(bytes) != exepected { + t.Errorf("error in marshaling `KillRequest` it does not match the expected string (Expected(%q) != Got(%q))", exepected, string(bytes)) + } +} diff --git a/lambda/rapidcore/telemetry/logsapi/constants.go b/lambda/telemetry/constants.go similarity index 94% rename from lambda/rapidcore/telemetry/logsapi/constants.go rename to lambda/telemetry/constants.go index f54e415..0198660 100644 --- a/lambda/rapidcore/telemetry/logsapi/constants.go +++ b/lambda/telemetry/constants.go @@ -1,17 +1,17 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package logsapi +package telemetry import "errors" -// ErrTelemetryServiceOff returned on attempt to subscribe after telemetry service has been turned off. -var ErrTelemetryServiceOff = errors.New("ErrTelemetryServiceOff") - -// Metrics const ( + // Metrics SubscribeSuccess = "logs_api_subscribe_success" SubscribeClientErr = "logs_api_subscribe_client_err" SubscribeServerErr = "logs_api_subscribe_server_err" NumSubscribers = "logs_api_num_subscribers" ) + +// ErrTelemetryServiceOff returned on attempt to subscribe after telemetry service has been turned off. +var ErrTelemetryServiceOff = errors.New("ErrTelemetryServiceOff") diff --git a/lambda/telemetry/events_api.go b/lambda/telemetry/events_api.go index e7c5c36..371f439 100644 --- a/lambda/telemetry/events_api.go +++ b/lambda/telemetry/events_api.go @@ -4,135 +4,151 @@ package telemetry import ( + "fmt" "time" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" - "go.amzn.com/lambda/rapi/model" ) -type RuntimeDoneInvokeMetrics struct { - ProducedBytes int64 - DurationMs float64 -} - -func GetRuntimeDoneInvokeMetrics(invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics, runtimeDoneTime int64) *RuntimeDoneInvokeMetrics { - if invokeResponseMetrics != nil && invokeResponseMetrics.RuntimeCalledResponse && invokeReceivedTime != 0 { - return &RuntimeDoneInvokeMetrics{ +func GetRuntimeDoneInvokeMetrics(runtimeStartedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics, runtimeDoneTime int64) *interop.RuntimeDoneInvokeMetrics { + // time taken from sending the invoke to the sandbox until the runtime calls GET /next + duration := CalculateDuration(runtimeStartedTime, runtimeDoneTime) + if invokeResponseMetrics != nil && invokeResponseMetrics.RuntimeCalledResponse && runtimeStartedTime != -1 { + return &interop.RuntimeDoneInvokeMetrics{ ProducedBytes: invokeResponseMetrics.ProducedBytes, - // time taken from sending the invoke to the sandbox until the runtime calls GET /next - DurationMs: float64((runtimeDoneTime - invokeReceivedTime) / int64(time.Millisecond)), + DurationMs: duration, } } // when we get a reset before runtime called /response - if invokeReceivedTime != 0 { - return &RuntimeDoneInvokeMetrics{ + if runtimeStartedTime != -1 { + return &interop.RuntimeDoneInvokeMetrics{ ProducedBytes: int64(0), - DurationMs: float64((runtimeDoneTime - invokeReceivedTime) / int64(time.Millisecond)), + DurationMs: duration, } } // We didn't have time to register the invokeReceiveTime, which means we crash/reset very early, // too early for the runtime to actual run. In such case, the runtimeDone event shouldn't be sent // Not returning Nil even in this improbable case guarantees that we will always have some metrics to send to FluxPump - return &RuntimeDoneInvokeMetrics{ + return &interop.RuntimeDoneInvokeMetrics{ ProducedBytes: int64(0), DurationMs: float64(0), } } -type InitRuntimeDoneData struct { - InitSource string - Status string -} - -type InvokeRuntimeDoneData struct { - Status string - Metrics *RuntimeDoneInvokeMetrics - InternalMetrics *interop.InvokeResponseMetrics - Tracing *TracingCtx - Spans []Span -} +const ( + InitInsideInitPhase interop.InitPhase = "init" + InitInsideInvokePhase interop.InitPhase = "invoke" +) -type Span struct { - Name string - Start string - DurationMs float64 +func InitPhaseFromLifecyclePhase(phase interop.LifecyclePhase) (interop.InitPhase, error) { + switch phase { + case interop.LifecyclePhaseInit: + return InitInsideInitPhase, nil + case interop.LifecyclePhaseInvoke: + return InitInsideInvokePhase, nil + default: + return interop.InitPhase(""), fmt.Errorf("unexpected lifecycle phase: %v", phase) + } } -func GetRuntimeDoneSpans(invokeReceivedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) []Span { - if invokeResponseMetrics != nil && invokeResponseMetrics.RuntimeCalledResponse && invokeReceivedTime != 0 { +func GetRuntimeDoneSpans(runtimeStartedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) []interop.Span { + if invokeResponseMetrics != nil && invokeResponseMetrics.RuntimeCalledResponse && runtimeStartedTime != -1 { // time span from when the invoke is received in the sandbox to the moment the runtime calls PUT /response - responseLatencyMsSpan := Span{ + responseLatencyMsSpan := interop.Span{ Name: "responseLatency", - Start: getEpochTimeInISO8601FormatFromMonotime(invokeReceivedTime), - DurationMs: float64((invokeResponseMetrics.StartReadingResponseMonoTimeMs - invokeReceivedTime) / int64(time.Millisecond)), + Start: GetEpochTimeInISO8601FormatFromMonotime(runtimeStartedTime), + DurationMs: CalculateDuration(runtimeStartedTime, invokeResponseMetrics.StartReadingResponseMonoTimeMs), } // time span from when the runtime called PUT /response to the moment the body of the response is fully sent - responseDurationMsSpan := Span{ + responseDurationMsSpan := interop.Span{ Name: "responseDuration", - Start: getEpochTimeInISO8601FormatFromMonotime(invokeResponseMetrics.StartReadingResponseMonoTimeMs), - DurationMs: float64((invokeResponseMetrics.FinishReadingResponseMonoTimeMs - invokeResponseMetrics.StartReadingResponseMonoTimeMs) / int64(time.Millisecond)), + Start: GetEpochTimeInISO8601FormatFromMonotime(invokeResponseMetrics.StartReadingResponseMonoTimeMs), + DurationMs: CalculateDuration(invokeResponseMetrics.StartReadingResponseMonoTimeMs, invokeResponseMetrics.FinishReadingResponseMonoTimeMs), } - return []Span{responseLatencyMsSpan, responseDurationMsSpan} + return []interop.Span{responseLatencyMsSpan, responseDurationMsSpan} } - return []Span{} + return []interop.Span{} } -func getEpochTimeInISO8601FormatFromMonotime(monotime int64) string { - return time.Unix(0, metering.MonoToEpoch(monotime)).Format("2006-01-02T15:04:05.000Z") +// CalculateDuration calculates duration between two moments. +// The result is milliseconds with microsecond precision. +// Two assumptions here: +// 1. the passed values are nanoseconds +// 2. endNs > startNs +func CalculateDuration(startNs, endNs int64) float64 { + microseconds := int64(endNs-startNs) / int64(time.Microsecond) + return float64(microseconds) / 1000 } -type TracingCtx struct { - SpanID string - Type model.TracingType - Value string -} +const ( + InitTypeOnDemand interop.InitType = "on-demand" + InitTypeProvisionedConcurrency interop.InitType = "provisioned-concurrency" + InitTypeInitCaching interop.InitType = "snap-start" +) -func BuildTracingCtx(tracingType model.TracingType, traceID string, lambdaSegmentID string) *TracingCtx { - // it takes current tracing context and change its parent value with the provided lambda segment id - root, currentParent, sample := ParseTraceID(traceID) - if root == "" || sample != model.XRaySampled { - return nil - } +func InferInitType(initCachingEnabled bool, sandboxType interop.SandboxType) interop.InitType { + initSource := InitTypeOnDemand - return &TracingCtx{ - SpanID: currentParent, - Type: tracingType, - Value: BuildFullTraceID(root, lambdaSegmentID, sample), + // ToDo: Unify this selection of SandboxType by using the START message + // after having a roadmap on the combination of INIT modes + if initCachingEnabled { + initSource = InitTypeInitCaching + } else if sandboxType == interop.SandboxPreWarmed { + initSource = InitTypeProvisionedConcurrency } + + return initSource +} + +func GetEpochTimeInISO8601FormatFromMonotime(monotime int64) string { + return time.Unix(0, metering.MonoToEpoch(monotime)).Format("2006-01-02T15:04:05.000Z") } const ( RuntimeDoneSuccess = "success" - RuntimeDoneFailure = "failure" + RuntimeDoneError = "error" ) -type EventsAPI interface { - SetCurrentRequestID(requestID string) - SendInitRuntimeDone(data *InitRuntimeDoneData) error - SendRestoreRuntimeDone(status string) error - SendRuntimeDone(data InvokeRuntimeDoneData) error - SendExtensionInit(agentName, state, errorType string, subscriptions []string) error - SendImageErrorLog(logline string) -} - type NoOpEventsAPI struct{} -func (s *NoOpEventsAPI) SetCurrentRequestID(requestID string) {} -func (s *NoOpEventsAPI) SendInitRuntimeDone(data *InitRuntimeDoneData) error { - return nil -} -func (s *NoOpEventsAPI) SendRestoreRuntimeDone(status string) error { - return nil -} -func (s *NoOpEventsAPI) SendRuntimeDone(data InvokeRuntimeDoneData) error { - return nil -} -func (s *NoOpEventsAPI) SendExtensionInit(agentName, state, errorType string, subscriptions []string) error { - return nil +func (s *NoOpEventsAPI) SetCurrentRequestID(interop.RequestID) {} + +func (s *NoOpEventsAPI) SendInitStart(interop.InitStartData) error { return nil } + +func (s *NoOpEventsAPI) SendInitRuntimeDone(interop.InitRuntimeDoneData) error { return nil } + +func (s *NoOpEventsAPI) SendInitReport(interop.InitReportData) error { return nil } + +func (s *NoOpEventsAPI) SendRestoreRuntimeDone(interop.RestoreRuntimeDoneData) error { return nil } + +func (s *NoOpEventsAPI) SendInvokeStart(interop.InvokeStartData) error { return nil } + +func (s *NoOpEventsAPI) SendInvokeRuntimeDone(interop.InvokeRuntimeDoneData) error { return nil } + +func (s *NoOpEventsAPI) SendExtensionInit(interop.ExtensionInitData) error { return nil } + +func (s *NoOpEventsAPI) SendEnd(interop.EndData) error { return nil } + +func (s *NoOpEventsAPI) SendReportSpan(interop.Span) error { return nil } + +func (s *NoOpEventsAPI) SendReport(interop.ReportData) error { return nil } + +func (s *NoOpEventsAPI) SendFault(interop.FaultData) error { return nil } + +func (s *NoOpEventsAPI) SendImageErrorLog(interop.ImageErrorLogData) {} + +func (s *NoOpEventsAPI) FetchTailLogs(string) (string, error) { return "", nil } + +func (s *NoOpEventsAPI) GetRuntimeDoneSpans( + runtimeStartedTime int64, + invokeResponseMetrics *interop.InvokeResponseMetrics, + runtimeOverheadStartedTime int64, + runtimeReadyTime int64, +) []interop.Span { + return []interop.Span{} } -func (s *NoOpEventsAPI) SendImageErrorLog(logline string) {} diff --git a/lambda/telemetry/events_api_test.go b/lambda/telemetry/events_api_test.go index b943be9..f69e4ea 100644 --- a/lambda/telemetry/events_api_test.go +++ b/lambda/telemetry/events_api_test.go @@ -15,65 +15,66 @@ import ( func TestGetRuntimeDoneInvokeMetrics(t *testing.T) { now := metering.Monotime() - invokeReceivedTime := now + runtimeStartedTime := now invokeResponseMetrics := &interop.InvokeResponseMetrics{ ProducedBytes: int64(100), RuntimeCalledResponse: true, } runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) - expected := &RuntimeDoneInvokeMetrics{ + expected := &interop.RuntimeDoneInvokeMetrics{ ProducedBytes: int64(100), DurationMs: float64(10), } - assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(invokeReceivedTime, invokeResponseMetrics, runtimeDoneTime)) + assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(runtimeStartedTime, invokeResponseMetrics, runtimeDoneTime)) } func TestGetRuntimeDoneInvokeMetricsWhenRuntimeCalledError(t *testing.T) { now := metering.Monotime() - invokeReceivedTime := now + runtimeStartedTime := now invokeResponseMetrics := &interop.InvokeResponseMetrics{ ProducedBytes: int64(100), RuntimeCalledResponse: false, } - runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) + // validating microsecond precision + runtimeDoneTime := now + int64(time.Duration(10)*time.Millisecond+time.Duration(50)*time.Microsecond) - expected := &RuntimeDoneInvokeMetrics{ + expected := &interop.RuntimeDoneInvokeMetrics{ ProducedBytes: int64(0), - DurationMs: float64(10), + DurationMs: float64(10.05), } - assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(invokeReceivedTime, invokeResponseMetrics, runtimeDoneTime)) + assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(runtimeStartedTime, invokeResponseMetrics, runtimeDoneTime)) } -func TestGetRuntimeDoneInvokeMetricsWhenInvokeReceivedTimeIsZero(t *testing.T) { - now := int64(0) // January 1st, 1970 at 00:00:00 UTC - invokeReceivedTime := now +func TestGetRuntimeDoneInvokeMetricsWhenRuntimeStartedTimeIsMinusOne(t *testing.T) { + now := int64(-1) + runtimeStartedTime := now runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) - expected := &RuntimeDoneInvokeMetrics{ + expected := &interop.RuntimeDoneInvokeMetrics{ ProducedBytes: int64(0), DurationMs: float64(0), } - actual := GetRuntimeDoneInvokeMetrics(invokeReceivedTime, nil, runtimeDoneTime) + actual := GetRuntimeDoneInvokeMetrics(runtimeStartedTime, nil, runtimeDoneTime) assert.Equal(t, expected, actual) } func TestGetRuntimeDoneInvokeMetricsWhenInvokeResponseMetricsIsNil(t *testing.T) { now := metering.Monotime() - invokeReceivedTime := now + runtimeStartedTime := now runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) - expected := &RuntimeDoneInvokeMetrics{ + expected := &interop.RuntimeDoneInvokeMetrics{ ProducedBytes: int64(0), DurationMs: float64(10), } - assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(invokeReceivedTime, nil, runtimeDoneTime)) + assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(runtimeStartedTime, nil, runtimeDoneTime)) } func TestGetRuntimeDoneSpans(t *testing.T) { @@ -81,29 +82,29 @@ func TestGetRuntimeDoneSpans(t *testing.T) { startReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(5)) finishReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(7)) - invokeReceivedTime := now + runtimeStartedTime := now invokeResponseMetrics := &interop.InvokeResponseMetrics{ StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, FinishReadingResponseMonoTimeMs: finishReadingResponseMonoTimeMs, RuntimeCalledResponse: true, } - expectedResponseLatencyMsStartTime := getEpochTimeInISO8601FormatFromMonotime(now) - expectedResponseDurationMsStartTime := getEpochTimeInISO8601FormatFromMonotime(startReadingResponseMonoTimeMs) - expected := []Span{ - Span{ + expectedResponseLatencyMsStartTime := GetEpochTimeInISO8601FormatFromMonotime(now) + expectedResponseDurationMsStartTime := GetEpochTimeInISO8601FormatFromMonotime(startReadingResponseMonoTimeMs) + expected := []interop.Span{ + { Name: "responseLatency", Start: expectedResponseLatencyMsStartTime, DurationMs: 5, }, - Span{ + { Name: "responseDuration", Start: expectedResponseDurationMsStartTime, DurationMs: 2, }, } - assert.Equal(t, expected, GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics)) + assert.Equal(t, expected, GetRuntimeDoneSpans(runtimeStartedTime, invokeResponseMetrics)) } func TestGetRuntimeDoneSpansWhenRuntimeCalledError(t *testing.T) { @@ -111,29 +112,101 @@ func TestGetRuntimeDoneSpansWhenRuntimeCalledError(t *testing.T) { startReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(5)) finishReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(7)) - invokeReceivedTime := now + runtimeStartedTime := now invokeResponseMetrics := &interop.InvokeResponseMetrics{ StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, FinishReadingResponseMonoTimeMs: finishReadingResponseMonoTimeMs, RuntimeCalledResponse: false, } - assert.Equal(t, []Span{}, GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics)) + assert.Equal(t, []interop.Span{}, GetRuntimeDoneSpans(runtimeStartedTime, invokeResponseMetrics)) } func TestGetRuntimeDoneSpansWhenInvokeResponseMetricsNil(t *testing.T) { - invokeReceivedTime := metering.Monotime() + runtimeStartedTime := metering.Monotime() - assert.Equal(t, []Span{}, GetRuntimeDoneSpans(invokeReceivedTime, nil)) + assert.Equal(t, []interop.Span{}, GetRuntimeDoneSpans(runtimeStartedTime, nil)) } -func TestGetRuntimeDoneSpansWhenInvokeReceivedTimeIsZero(t *testing.T) { - now := int64(0) // January 1st, 1970 at 00:00:00 UTC - invokeReceivedTime := now +func TestGetRuntimeDoneSpansWhenRuntimeStartedTimeIsMinusOne(t *testing.T) { + now := int64(-1) + runtimeStartedTime := now invokeResponseMetrics := &interop.InvokeResponseMetrics{ StartReadingResponseMonoTimeMs: now + int64(time.Millisecond*time.Duration(5)), FinishReadingResponseMonoTimeMs: now + int64(time.Millisecond*time.Duration(7)), } - assert.Equal(t, []Span{}, GetRuntimeDoneSpans(invokeReceivedTime, invokeResponseMetrics)) + assert.Equal(t, []interop.Span{}, GetRuntimeDoneSpans(runtimeStartedTime, invokeResponseMetrics)) +} + +func TestInferInitType(t *testing.T) { + testCases := map[string]struct { + initCachingEnabled bool + sandboxType interop.SandboxType + expected interop.InitType + }{ + "on demand": { + initCachingEnabled: false, + sandboxType: interop.SandboxClassic, + expected: InitTypeOnDemand, + }, + "pc": { + initCachingEnabled: false, + sandboxType: interop.SandboxPreWarmed, + expected: InitTypeProvisionedConcurrency, + }, + "snap-start for OD": { + initCachingEnabled: true, + sandboxType: interop.SandboxClassic, + expected: InitTypeInitCaching, + }, + "snap-start for PC": { + initCachingEnabled: true, + sandboxType: interop.SandboxPreWarmed, + expected: InitTypeInitCaching, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + initType := InferInitType(tc.initCachingEnabled, tc.sandboxType) + assert.Equal(t, tc.expected, initType) + }) + } +} + +func TestCalculateDuration(t *testing.T) { + testCases := map[string]struct { + start int64 + end int64 + expected float64 + }{ + "milliseconds only": { + start: int64(100 * time.Millisecond), + end: int64(120 * time.Millisecond), + expected: 20, + }, + "with microseconds": { + start: int64(100 * time.Millisecond), + end: int64(210*time.Millisecond + 65*time.Microsecond), + expected: 110.065, + }, + "nanoseconds must be dropped": { + start: int64(100 * time.Millisecond), + end: int64(140*time.Millisecond + 999*time.Nanosecond), + expected: 40, + }, + "microseconds presented, nanoseconds dropped": { + start: int64(100 * time.Millisecond), + end: int64(150*time.Millisecond + 2*time.Microsecond + 999*time.Nanosecond), + expected: 50.002, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + actual := CalculateDuration(tc.start, tc.end) + assert.Equal(t, tc.expected, actual) + }) + } } diff --git a/lambda/telemetry/logs_egress_api.go b/lambda/telemetry/logs_egress_api.go index 7e84fe2..f4da62d 100644 --- a/lambda/telemetry/logs_egress_api.go +++ b/lambda/telemetry/logs_egress_api.go @@ -29,3 +29,5 @@ func (s *NoOpLogsEgressAPI) GetRuntimeSockets() (io.Writer, io.Writer, error) { // os.Stderr can not be used for the stderrWriter because stderr is for internal logging (not customer visible). return os.Stdout, os.Stdout, nil } + +var _ StdLogsEgressAPI = (*NoOpLogsEgressAPI)(nil) diff --git a/lambda/telemetry/logs_subscription_api.go b/lambda/telemetry/logs_subscription_api.go index 6ee9490..2fa39f0 100644 --- a/lambda/telemetry/logs_subscription_api.go +++ b/lambda/telemetry/logs_subscription_api.go @@ -12,7 +12,7 @@ import ( // SubscriptionAPI represents interface that implementations of Telemetry API have to satisfy to be RAPID-compatible type SubscriptionAPI interface { - Subscribe(agentName string, body io.Reader, headers map[string][]string) (resp []byte, status int, respHeaders map[string][]string, err error) + Subscribe(agentName string, body io.Reader, headers map[string][]string, remoteAddr string) (resp []byte, status int, respHeaders map[string][]string, err error) RecordCounterMetric(metricName string, count int) FlushMetrics() interop.TelemetrySubscriptionMetrics Clear() @@ -25,7 +25,7 @@ type SubscriptionAPI interface { type NoOpSubscriptionAPI struct{} // Subscribe writes response to a shared memory -func (m *NoOpSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string) ([]byte, int, map[string][]string, error) { +func (m *NoOpSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string, remoteAddr string) ([]byte, int, map[string][]string, error) { return []byte(`{}`), http.StatusOK, map[string][]string{}, nil } diff --git a/lambda/telemetry/tracer.go b/lambda/telemetry/tracer.go index affca60..889682b 100644 --- a/lambda/telemetry/tracer.go +++ b/lambda/telemetry/tracer.go @@ -17,8 +17,8 @@ import ( type traceContextKey int const ( - traceIDKey traceContextKey = iota - documentIDKey + TraceIDKey traceContextKey = iota + DocumentIDKey ) type Tracer interface { @@ -30,11 +30,14 @@ type Tracer interface { RecordInitStartTime() RecordInitEndTime() SendInitSubsegmentWithRecordedTimesOnce(ctx context.Context) + SendRestoreSubsegmentWithRecordedTimesOnce(ctx context.Context) MarkError(ctx context.Context) AttachErrorCause(ctx context.Context, errorCause json.RawMessage) WithErrorCause(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error WithError(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error - TracingHeaderParser() func(context.Context, *interop.Invoke) string + BuildTracingHeader() func(context.Context) string + BuildTracingCtxForStart() *interop.TracingCtx + BuildTracingCtxAfterInvokeComplete() *interop.TracingCtx } type NoOpTracer struct{} @@ -42,28 +45,25 @@ type NoOpTracer struct{} func (t *NoOpTracer) Configure(invoke *interop.Invoke) {} func (t *NoOpTracer) CaptureInvokeSegment(ctx context.Context, criticalFunction func(context.Context) error) error { - criticalFunction(ctx) - return nil + return criticalFunction(ctx) } func (t *NoOpTracer) CaptureInitSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { - criticalFunction(ctx) - return nil + return criticalFunction(ctx) } func (t *NoOpTracer) CaptureInvokeSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { - criticalFunction(ctx) - return nil + return criticalFunction(ctx) } func (t *NoOpTracer) CaptureOverheadSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { - criticalFunction(ctx) - return nil + return criticalFunction(ctx) } func (t *NoOpTracer) RecordInitStartTime() {} func (t *NoOpTracer) RecordInitEndTime() {} func (t *NoOpTracer) SendInitSubsegmentWithRecordedTimesOnce(ctx context.Context) {} +func (t *NoOpTracer) SendRestoreSubsegmentWithRecordedTimesOnce(ctx context.Context) {} func (t *NoOpTracer) MarkError(ctx context.Context) {} func (t *NoOpTracer) AttachErrorCause(ctx context.Context, errorCause json.RawMessage) {} @@ -73,8 +73,25 @@ func (t *NoOpTracer) WithErrorCause(ctx context.Context, appCtx appctx.Applicati func (t *NoOpTracer) WithError(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error { return criticalFunction } -func (t *NoOpTracer) TracingHeaderParser() func(context.Context, *interop.Invoke) string { - return GetCustomerTracingHeader +func (t *NoOpTracer) BuildTracingHeader() func(context.Context) string { + // extract root trace ID and parent from context and build the tracing header + return func(ctx context.Context) string { + root, _ := ctx.Value(TraceIDKey).(string) + parent, _ := ctx.Value(DocumentIDKey).(string) + + if root != "" && parent != "" { + return fmt.Sprintf("Root=%s;Parent=%s;Sampled=1", root, parent) + } + + return "" + } +} + +func (t *NoOpTracer) BuildTracingCtxForStart() *interop.TracingCtx { + return nil +} +func (t *NoOpTracer) BuildTracingCtxAfterInvokeComplete() *interop.TracingCtx { + return nil } func NewNoOpTracer() *NoOpTracer { @@ -83,49 +100,31 @@ func NewNoOpTracer() *NoOpTracer { // NewTraceContext returns new derived context with trace config set for testing func NewTraceContext(ctx context.Context, root string, parent string) context.Context { - ctxWithRoot := context.WithValue(ctx, traceIDKey, root) - return context.WithValue(ctxWithRoot, documentIDKey, parent) + ctxWithRoot := context.WithValue(ctx, TraceIDKey, root) + return context.WithValue(ctxWithRoot, DocumentIDKey, parent) } -// GetCustomerTracingHeader extracts the trace config from trace context and constructs header -func GetCustomerTracingHeader(ctx context.Context, invoke *interop.Invoke) string { - var root, parent string - var ok bool - - if root, ok = ctx.Value(traceIDKey).(string); !ok { - return invoke.TraceID - } - - if parent, ok = ctx.Value(documentIDKey).(string); !ok { - return invoke.TraceID - } - - return fmt.Sprintf("Root=%s;Parent=%s;Sampled=1", root, parent) - -} - -// ParseTraceID helps client to get TraceID, ParentID, Sampled information from a full trace -func ParseTraceID(fullTraceID string) (rootID, parentID, sample string) { - traceIDInfo := strings.Split(fullTraceID, ";") - for i := 0; i < len(traceIDInfo); i++ { - if len(traceIDInfo[i]) == 0 { - continue - } else { - var key string - var value string - keyValuePair := strings.Split(traceIDInfo[i], "=") - if len(keyValuePair) == 2 { - key = keyValuePair[0] - value = keyValuePair[1] - } - switch key { - case "Root": - rootID = value - case "Parent": - parentID = value - case "Sampled": - sample = value - } +// ParseTracingHeader extracts RootTraceID, ParentID, Sampled, and Lineage from a tracing header. +// Tracing header format is defined here: +// https://docs.aws.amazon.com/xray/latest/devguide/xray-concepts.html#xray-concepts-tracingheader +func ParseTracingHeader(tracingHeader string) (rootID, parentID, sampled, lineage string) { + keyValuePairs := strings.Split(tracingHeader, ";") + for _, pair := range keyValuePairs { + var key, value string + keyValue := strings.Split(pair, "=") + if len(keyValue) == 2 { + key = keyValue[0] + value = keyValue[1] + } + switch key { + case "Root": + rootID = value + case "Parent": + parentID = value + case "Sampled": + sampled = value + case "Lineage": + lineage = value } } return diff --git a/lambda/telemetry/tracer_test.go b/lambda/telemetry/tracer_test.go index c31653f..d67c389 100644 --- a/lambda/telemetry/tracer_test.go +++ b/lambda/telemetry/tracer_test.go @@ -4,35 +4,65 @@ package telemetry import ( + "context" + "fmt" + "strings" "testing" "go.amzn.com/lambda/rapi/model" ) +var BigString = strings.Repeat("a", 255) + var parserTests = []struct { - traceIDIn string - rootIDOut string - parentIDOut string - sampledOut string + tracingHeaderIn string + rootIDOut string + parentIDOut string + sampledOut string + lineageOut string }{ - {"Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=1", "1-5b3cc918-939afd635f8891ba6a9e1df6", "c88d77b0aef840e9", "1"}, - {"Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9", "1-5b3cc918-939afd635f8891ba6a9e1df6", "c88d77b0aef840e9", ""}, - {"1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=1", "", "c88d77b0aef840e9", "1"}, - {"Root=1-5b3cc918-939afd635f8891ba6a9e1df6", "1-5b3cc918-939afd635f8891ba6a9e1df6", "", ""}, + {"Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=1", "1-5b3cc918-939afd635f8891ba6a9e1df6", "c88d77b0aef840e9", "1", ""}, + {"Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9", "1-5b3cc918-939afd635f8891ba6a9e1df6", "c88d77b0aef840e9", "", ""}, + {"1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=1", "", "c88d77b0aef840e9", "1", ""}, + {"Root=1-5b3cc918-939afd635f8891ba6a9e1df6", "1-5b3cc918-939afd635f8891ba6a9e1df6", "", "", ""}, + {"", "", "", "", ""}, + {"abc;;", "", "", "", ""}, + {"abc", "", "", "", ""}, + {"abc;asd", "", "", "", ""}, + {"abc=as;asd=as", "", "", "", ""}, + {"Root=abc", "abc", "", "", ""}, + {"Root=abc;Parent=zxc;Sampled=1", "abc", "zxc", "1", ""}, + {"Root=root;Parent=par", "root", "par", "", ""}, + {"Root=root;Par", "root", "", "", ""}, + {"Root=", "", "", "", ""}, + {";Root=root;;", "root", "", "", ""}, + {"Root=root;Parent=parent;", "root", "parent", "", ""}, + {"Root=;Parent=parent;Sampled=1", "", "parent", "1", ""}, + {"Root=abc;Parent=zxc;Sampled=1;Lineage", "abc", "zxc", "1", ""}, + {"Root=abc;Parent=zxc;Sampled=1;Lineage=", "abc", "zxc", "1", ""}, + {"Root=abc;Parent=zxc;Sampled=1;Lineage=foo:1|bar:65535", "abc", "zxc", "1", "foo:1|bar:65535"}, + {"Root=abc;Parent=zxc;Lineage=foo:1|bar:65535;Sampled=1", "abc", "zxc", "1", "foo:1|bar:65535"}, + {fmt.Sprintf("Root=%s;Parent=%s;Sampled=1;Lineage=%s", BigString, BigString, BigString), BigString, BigString, "1", BigString}, } -func TestParseTraceID(t *testing.T) { +func TestParseTracingHeader(t *testing.T) { for _, tt := range parserTests { - t.Run(tt.traceIDIn, func(t *testing.T) { - rootID, parentID, sampled := ParseTraceID(tt.traceIDIn) + t.Run(tt.tracingHeaderIn, func(t *testing.T) { + rootID, parentID, sampled, lineage := ParseTracingHeader(tt.tracingHeaderIn) if rootID != tt.rootIDOut { - t.Errorf("got %q, wanted %q", rootID, tt.rootIDOut) + t.Errorf("Parsing %q got %q, wanted %q", tt.tracingHeaderIn, rootID, tt.rootIDOut) } if parentID != tt.parentIDOut { - t.Errorf("got %q, wanted %q", rootID, tt.parentIDOut) + t.Errorf("Parsing %q got %q, wanted %q", tt.tracingHeaderIn, parentID, tt.parentIDOut) } if sampled != tt.sampledOut { - t.Errorf("got %q, wanted %q", sampled, tt.sampledOut) + t.Errorf("Parsing %q got %q, wanted %q", tt.tracingHeaderIn, sampled, tt.sampledOut) + } + if lineage != tt.lineageOut { + t.Errorf("Parsing %q got %q, wanted %q", tt.tracingHeaderIn, lineage, tt.lineageOut) + } + if lineage != tt.lineageOut { + t.Errorf("got %q, wanted %q", lineage, tt.lineageOut) } }) } @@ -81,3 +111,45 @@ func TestBuildFullTraceID(t *testing.T) { }) } } + +func TestTracerDoesntSwallowErrorsFromCriticalFunctions(t *testing.T) { + ctx := context.Background() + + testCases := []struct { + name string + tracer Tracer + expectedError error + }{ + { + name: "NoOpTracer-success", + tracer: &NoOpTracer{}, + expectedError: nil, + }, + { + name: "NoOpTracer-fail", + tracer: &NoOpTracer{}, + expectedError: fmt.Errorf("invoke error"), + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + criticalFunction := func(ctx context.Context) error { + return test.expectedError + } + + if err := test.tracer.CaptureInvokeSegment(ctx, criticalFunction); err != test.expectedError { + t.Errorf("CaptureInvokeSegment failed; expected: '%v', but got: '%v'", test.expectedError, err) + } + if err := test.tracer.CaptureInitSubsegment(ctx, criticalFunction); err != test.expectedError { + t.Errorf("CaptureInitSubsegment failed; expected: '%v', but got: '%v'", test.expectedError, err) + } + if err := test.tracer.CaptureInvokeSubsegment(ctx, criticalFunction); err != test.expectedError { + t.Errorf("CaptureInvokeSubsegment failed; expected: '%v', but got: '%v'", test.expectedError, err) + } + if err := test.tracer.CaptureOverheadSubsegment(ctx, criticalFunction); err != test.expectedError { + t.Errorf("CaptureOverheadSubsegment failed; expected: '%v', but got: '%v'", test.expectedError, err) + } + }) + } +} diff --git a/lambda/testdata/flowtesting.go b/lambda/testdata/flowtesting.go index c028d7c..e2c4b49 100644 --- a/lambda/testdata/flowtesting.go +++ b/lambda/testdata/flowtesting.go @@ -4,10 +4,9 @@ package testdata import ( + "bytes" "context" - "io" "io/ioutil" - "net/http" "time" "go.amzn.com/lambda/appctx" @@ -25,15 +24,15 @@ const ( type MockInteropServer struct { Response []byte - ErrorResponse *interop.ErrorResponse + ErrorResponse *interop.ErrorInvokeResponse ResponseContentType string FunctionResponseMode string ActiveInvokeID string } // SendResponse writes response to a shared memory. -func (i *MockInteropServer) SendResponse(invokeID string, headers map[string]string, reader io.Reader, trailers http.Header, request *interop.CancellableRequest) error { - bytes, err := ioutil.ReadAll(reader) +func (i *MockInteropServer) SendResponse(invokeID string, resp *interop.StreamableInvokeResponse) error { + bytes, err := ioutil.ReadAll(resp.Payload) if err != nil { return err } @@ -44,23 +43,23 @@ func (i *MockInteropServer) SendResponse(invokeID string, headers map[string]str } } i.Response = bytes - i.ResponseContentType = headers[contentTypeHeader] - i.FunctionResponseMode = headers[functionResponseModeHeader] + i.ResponseContentType = resp.Headers[contentTypeHeader] + i.FunctionResponseMode = resp.Headers[functionResponseModeHeader] return nil } // SendErrorResponse writes error response to a shared memory and sends GIRD FAULT. -func (i *MockInteropServer) SendErrorResponse(invokeID string, response *interop.ErrorResponse) error { +func (i *MockInteropServer) SendErrorResponse(invokeID string, response *interop.ErrorInvokeResponse) error { i.ErrorResponse = response - i.ResponseContentType = response.ContentType - i.FunctionResponseMode = response.FunctionResponseMode + i.ResponseContentType = response.Headers.ContentType + i.FunctionResponseMode = response.Headers.FunctionResponseMode return nil } // SendInitErrorResponse writes error response during init to a shared memory and sends GIRD FAULT. -func (i *MockInteropServer) SendInitErrorResponse(invokeID string, response *interop.ErrorResponse) error { +func (i *MockInteropServer) SendInitErrorResponse(response *interop.ErrorInvokeResponse) error { i.ErrorResponse = response - i.ResponseContentType = response.ContentType + i.ResponseContentType = response.Headers.ContentType return nil } @@ -81,7 +80,7 @@ type FlowTest struct { InteropServer *MockInteropServer TelemetrySubscription *telemetry.NoOpSubscriptionAPI CredentialsService core.CredentialsService - EventsAPI telemetry.EventsAPI + EventsAPI interop.EventsAPI } // ConfigureForInit initialize synchronization gates and states for init. @@ -93,13 +92,25 @@ func (s *FlowTest) ConfigureForInit() { func (s *FlowTest) ConfigureForInvoke(ctx context.Context, invoke *interop.Invoke) { s.InteropServer.ActiveInvokeID = invoke.ID s.InvokeFlow.InitializeBarriers() - s.RenderingService.SetRenderer(rendering.NewInvokeRenderer(ctx, invoke, telemetry.GetCustomerTracingHeader)) + var buf bytes.Buffer // create default invoke renderer with new request buffer each time + s.ConfigureInvokeRenderer(ctx, invoke, &buf) +} + +// ConfigureInvokeRenderer overrides default invoke renderer to reuse request buffers (for benchmarks), etc. +func (s *FlowTest) ConfigureInvokeRenderer(ctx context.Context, invoke *interop.Invoke, buf *bytes.Buffer) { + s.RenderingService.SetRenderer(rendering.NewInvokeRenderer(ctx, invoke, buf, telemetry.NewNoOpTracer().BuildTracingHeader())) } func (s *FlowTest) ConfigureForRestore() { s.RenderingService.SetRenderer(rendering.NewRestoreRenderer()) } +func (s *FlowTest) ConfigureForRestoring() { + s.RegistrationService.PreregisterRuntime(s.Runtime) + s.Runtime.SetState(s.Runtime.RuntimeRestoringState) + s.RenderingService.SetRenderer(rendering.NewRestoreRenderer()) +} + func (s *FlowTest) ConfigureForInitCaching(token, awsKey, awsSecret, awsSession string) { credentialsExpiration := time.Now().Add(30 * time.Minute) s.CredentialsService.SetCredentials(token, awsKey, awsSecret, awsSession, credentialsExpiration) @@ -118,6 +129,8 @@ func NewFlowTest() *FlowTest { interopServer := &MockInteropServer{} eventsAPI := telemetry.NoOpEventsAPI{} appctx.StoreInteropServer(appCtx, interopServer) + appctx.StoreResponseSender(appCtx, interopServer) + return &FlowTest{ AppCtx: appCtx, InitFlow: initFlow, diff --git a/lambda/testdata/mocktracer/mocktracer.go b/lambda/testdata/mocktracer/mocktracer.go index f6ee9ab..3fb7054 100644 --- a/lambda/testdata/mocktracer/mocktracer.go +++ b/lambda/testdata/mocktracer/mocktracer.go @@ -5,14 +5,15 @@ package mocktracer import ( "context" - "go.amzn.com/lambda/xray" "time" + + "go.amzn.com/lambda/xray" ) // MockStartTime is start time set in Start method var MockStartTime = time.Now().UnixNano() -//MockEndTime is end time set in End method +// MockEndTime is end time set in End method var MockEndTime = time.Now().UnixNano() + 1 // MockTracer is used for unit tests From 2057adac54dbd781b51431f775363ec021648c77 Mon Sep 17 00:00:00 2001 From: Renato Valenzuela Date: Mon, 13 Nov 2023 22:06:06 +0000 Subject: [PATCH 19/41] chore(deps): Upgrade to Go 1.20 --- Makefile | 2 +- go.mod | 6 +++--- go.sum | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/Makefile b/Makefile index c2d5e55..80ccb89 100644 --- a/Makefile +++ b/Makefile @@ -21,7 +21,7 @@ compile-lambda-linux-all: make ARCH=old compile-lambda-linux compile-with-docker: - docker run --env GOPROXY=direct -v $(shell pwd):/LambdaRuntimeLocal -w /LambdaRuntimeLocal golang:1.19 make ARCH=${ARCH} compile-lambda-linux + docker run --env GOPROXY=direct -v $(shell pwd):/LambdaRuntimeLocal -w /LambdaRuntimeLocal golang:1.20 make ARCH=${ARCH} compile-lambda-linux compile-lambda-linux: CGO_ENABLED=0 GOOS=linux GOARCH=${GO_ARCH_${ARCH}} go build -buildvcs=false -ldflags "${RELEASE_BUILD_LINKER_FLAGS}" -o ${DESTINATION_${ARCH}} ./cmd/aws-lambda-rie diff --git a/go.mod b/go.mod index 053c7e0..990a7dd 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module go.amzn.com -go 1.19 +go 1.20 require ( github.com/aws/aws-lambda-go v1.41.0 @@ -16,7 +16,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.0 // indirect - golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.8.0 // indirect + golang.org/x/net v0.18.0 // indirect + golang.org/x/sys v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index d8fb9e9..0ea11d6 100644 --- a/go.sum +++ b/go.sum @@ -22,15 +22,15 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg= +golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= +golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From a252c82e052326ff56a72a3a6507247d791aa769 Mon Sep 17 00:00:00 2001 From: Daniel Fangl Date: Thu, 1 Feb 2024 18:09:10 +0100 Subject: [PATCH 20/41] Adapt new init changes (#30) --- cmd/localstack/awsutil.go | 8 ++-- cmd/localstack/custom_interop.go | 16 +++---- cmd/localstack/main.go | 5 ++- cmd/localstack/simple_bootstrap.go | 69 ++++++++++++++++++++++++++++++ 4 files changed, 85 insertions(+), 13 deletions(-) create mode 100644 cmd/localstack/simple_bootstrap.go diff --git a/cmd/localstack/awsutil.go b/cmd/localstack/awsutil.go index 7a8ba8a..de18378 100644 --- a/cmd/localstack/awsutil.go +++ b/cmd/localstack/awsutil.go @@ -11,7 +11,6 @@ import ( "fmt" log "github.com/sirupsen/logrus" "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapidcore" "go.amzn.com/lambda/rapidcore/env" "golang.org/x/sys/unix" "io" @@ -34,7 +33,7 @@ func isBootstrapFileExist(filePath string) bool { return !os.IsNotExist(err) && !file.IsDir() } -func getBootstrap(args []string) (*rapidcore.Bootstrap, string) { +func getBootstrap(args []string) (interop.Bootstrap, string) { var bootstrapLookupCmd []string var handler string currentWorkingDir := "/var/task" // default value @@ -89,7 +88,7 @@ func getBootstrap(args []string) (*rapidcore.Bootstrap, string) { } } - return rapidcore.NewBootstrapSingleCmd(bootstrapLookupCmd, currentWorkingDir, ""), handler + return NewSimpleBootstrap(bootstrapLookupCmd, currentWorkingDir), handler } func PrintEndReports(invokeId string, initDuration string, memorySize string, invokeStart time.Time, timeoutDuration time.Duration, w io.Writer) { @@ -205,7 +204,7 @@ func getSubFoldersInList(prefix string, pathList []string) (oldFolders []string, return } -func InitHandler(sandbox Sandbox, functionVersion string, timeout int64, bs interop.Bootstrap) (time.Time, time.Time) { +func InitHandler(sandbox Sandbox, functionVersion string, timeout int64, bs interop.Bootstrap, accountId string) (time.Time, time.Time) { additionalFunctionEnvironmentVariables := map[string]string{} // Add default Env Vars if they were not defined. This is a required otherwise 1p Python2.7, Python3.6, and @@ -231,6 +230,7 @@ func InitHandler(sandbox Sandbox, functionVersion string, timeout int64, bs inte AwsKey: os.Getenv("AWS_ACCESS_KEY_ID"), AwsSecret: os.Getenv("AWS_SECRET_ACCESS_KEY"), AwsSession: os.Getenv("AWS_SESSION_TOKEN"), + AccountID: accountId, XRayDaemonAddress: GetenvWithDefault("AWS_XRAY_DAEMON_ADDRESS", "127.0.0.1:2000"), FunctionName: GetenvWithDefault("AWS_LAMBDA_FUNCTION_NAME", "test_function"), FunctionVersion: functionVersion, diff --git a/cmd/localstack/custom_interop.go b/cmd/localstack/custom_interop.go index 3dcde93..2bd3541 100644 --- a/cmd/localstack/custom_interop.go +++ b/cmd/localstack/custom_interop.go @@ -194,23 +194,23 @@ func NewCustomInteropServer(lsOpts *LsOpts, delegate interop.Server, logCollecto return server } -func (c *CustomInteropServer) SendResponse(invokeID string, headers map[string]string, reader io.Reader, trailers http.Header, request *interop.CancellableRequest) error { +func (c *CustomInteropServer) SendResponse(invokeID string, resp *interop.StreamableInvokeResponse) error { log.Traceln("SendResponse called") - return c.delegate.SendResponse(invokeID, headers, reader, trailers, request) + return c.delegate.SendResponse(invokeID, resp) } -func (c *CustomInteropServer) SendErrorResponse(invokeID string, response *interop.ErrorResponse) error { +func (c *CustomInteropServer) SendErrorResponse(invokeID string, resp *interop.ErrorInvokeResponse) error { log.Traceln("SendErrorResponse called") - return c.delegate.SendErrorResponse(invokeID, response) + return c.delegate.SendErrorResponse(invokeID, resp) } // SendInitErrorResponse writes error response during init to a shared memory and sends GIRD FAULT. -func (c *CustomInteropServer) SendInitErrorResponse(invokeID string, response *interop.ErrorResponse) error { +func (c *CustomInteropServer) SendInitErrorResponse(resp *interop.ErrorInvokeResponse) error { log.Traceln("SendInitErrorResponse called") - if err := c.localStackAdapter.SendStatus(Error, response.Payload); err != nil { + if err := c.localStackAdapter.SendStatus(Error, resp.Payload); err != nil { log.Fatalln("Failed to send init error to LocalStack " + err.Error() + ". Exiting.") } - return c.delegate.SendInitErrorResponse(invokeID, response) + return c.delegate.SendInitErrorResponse(resp) } func (c *CustomInteropServer) GetCurrentInvokeID() string { @@ -248,7 +248,7 @@ func (c *CustomInteropServer) Reset(reason string, timeoutMs int64) (*statejson. return c.delegate.Reset(reason, timeoutMs) } -func (c *CustomInteropServer) AwaitRelease() (*statejson.InternalStateDescription, error) { +func (c *CustomInteropServer) AwaitRelease() (*statejson.ReleaseResponse, error) { log.Traceln("AwaitRelease called") return c.delegate.AwaitRelease() } diff --git a/cmd/localstack/main.go b/cmd/localstack/main.go index 08b70d9..f03e3b6 100644 --- a/cmd/localstack/main.go +++ b/cmd/localstack/main.go @@ -16,6 +16,7 @@ type LsOpts struct { InteropPort string RuntimeEndpoint string RuntimeId string + AccountId string InitTracingPort string User string CodeArchives string @@ -41,6 +42,7 @@ func InitLsOpts() *LsOpts { // required RuntimeEndpoint: GetEnvOrDie("LOCALSTACK_RUNTIME_ENDPOINT"), RuntimeId: GetEnvOrDie("LOCALSTACK_RUNTIME_ID"), + AccountId: GetenvWithDefault("LOCALSTACK_FUNCTION_ACCOUNT_ID", "000000000000"), // optional with default InteropPort: GetenvWithDefault("LOCALSTACK_INTEROP_PORT", "9563"), InitTracingPort: GetenvWithDefault("LOCALSTACK_RUNTIME_TRACING_PORT", "9564"), @@ -72,6 +74,7 @@ func UnsetLsEnvs() { "LOCALSTACK_ENABLE_XRAY_TELEMETRY", "LOCALSTACK_INIT_LOG_LEVEL", "LOCALSTACK_POST_INVOKE_WAIT_MS", + "LOCALSTACK_FUNCTION_ACCOUNT_ID", // Docker container ID "HOSTNAME", @@ -230,7 +233,7 @@ func main() { // start runtime init. It is important to start `InitHandler` synchronously because we need to ensure the // notification channels and status fields are properly initialized before `AwaitInitialized` log.Debugln("Starting runtime init.") - InitHandler(sandbox.LambdaInvokeAPI(), GetEnvOrDie("AWS_LAMBDA_FUNCTION_VERSION"), int64(invokeTimeoutSeconds), bootstrap) // TODO: replace this with a custom init + InitHandler(sandbox.LambdaInvokeAPI(), GetEnvOrDie("AWS_LAMBDA_FUNCTION_VERSION"), int64(invokeTimeoutSeconds), bootstrap, lsOpts.AccountId) // TODO: replace this with a custom init log.Debugln("Awaiting initialization of runtime init.") if err := interopServer.delegate.AwaitInitialized(); err != nil { diff --git a/cmd/localstack/simple_bootstrap.go b/cmd/localstack/simple_bootstrap.go new file mode 100644 index 0000000..c9111a2 --- /dev/null +++ b/cmd/localstack/simple_bootstrap.go @@ -0,0 +1,69 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "fmt" + "os" + "path/filepath" + + "go.amzn.com/lambda/fatalerror" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapidcore/env" +) + +// the type implement a simpler version of the Bootstrap +// this is useful in the Standalone Core implementation. +type simpleBootstrap struct { + cmd []string + workingDir string +} + +func NewSimpleBootstrap(cmd []string, currentWorkingDir string) interop.Bootstrap { + if currentWorkingDir == "" { + // use the root directory as the default working directory + currentWorkingDir = "/" + } + + // a single candidate command makes it automatically valid + return &simpleBootstrap{ + cmd: cmd, + workingDir: currentWorkingDir, + } +} + +func (b *simpleBootstrap) Cmd() ([]string, error) { + return b.cmd, nil +} + +// Cwd returns the working directory of the bootstrap process +// The path is validated against the chroot identified by `root` +func (b *simpleBootstrap) Cwd() (string, error) { + if !filepath.IsAbs(b.workingDir) { + return "", fmt.Errorf("the working directory '%s' is invalid, it needs to be an absolute path", b.workingDir) + } + + // evaluate the path relatively to the domain's mnt namespace root + if _, err := os.Stat(b.workingDir); os.IsNotExist(err) { + return "", fmt.Errorf("the working directory doesn't exist: %s", b.workingDir) + } + + return b.workingDir, nil +} + +// Env returns the environment variables available to +// the bootstrap process +func (b *simpleBootstrap) Env(e *env.Environment) map[string]string { + return e.RuntimeExecEnv() +} + +// ExtraFiles returns the extra file descriptors apart from 1 & 2 to be passed to runtime +func (b *simpleBootstrap) ExtraFiles() []*os.File { + return make([]*os.File, 0) +} + +func (b *simpleBootstrap) CachedFatalError(err error) (fatalerror.ErrorType, string, bool) { + // not implemented as it is not needed in Core but we need to fullfil the interface anyway + return fatalerror.ErrorType(""), "", false +} From e06355741daf9abf559eba2182d6210c0496bc44 Mon Sep 17 00:00:00 2001 From: Daniel Fangl Date: Tue, 6 Feb 2024 10:51:37 +0100 Subject: [PATCH 21/41] Add new localstack tracer to return tracing headers set by invoke (#31) --- cmd/localstack/main.go | 4 ++- cmd/localstack/tracer.go | 63 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 cmd/localstack/tracer.go diff --git a/cmd/localstack/main.go b/cmd/localstack/main.go index f03e3b6..02bf247 100644 --- a/cmd/localstack/main.go +++ b/cmd/localstack/main.go @@ -178,6 +178,7 @@ func main() { logCollector := NewLogCollector() localStackLogsEgressApi := NewLocalStackLogsEgressAPI(logCollector) + tracer := NewLocalStackTracer() // build sandbox sandbox := rapidcore. @@ -191,7 +192,8 @@ func main() { }). SetExtensionsFlag(true). SetInitCachingFlag(true). - SetLogsEgressAPI(localStackLogsEgressApi) + SetLogsEgressAPI(localStackLogsEgressApi). + SetTracer(tracer) // xray daemon endpoint := "http://" + lsOpts.LocalstackIP + ":" + lsOpts.EdgePort diff --git a/cmd/localstack/tracer.go b/cmd/localstack/tracer.go new file mode 100644 index 0000000..8506a9a --- /dev/null +++ b/cmd/localstack/tracer.go @@ -0,0 +1,63 @@ +package main + +import ( + "context" + "encoding/json" + "go.amzn.com/lambda/appctx" + "go.amzn.com/lambda/interop" +) + +type LocalStackTracer struct { + invoke *interop.Invoke +} + +func (t *LocalStackTracer) Configure(invoke *interop.Invoke) { + t.invoke = invoke +} + +func (t *LocalStackTracer) CaptureInvokeSegment(ctx context.Context, criticalFunction func(context.Context) error) error { + return criticalFunction(ctx) +} + +func (t *LocalStackTracer) CaptureInitSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { + return criticalFunction(ctx) +} + +func (t *LocalStackTracer) CaptureInvokeSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { + return criticalFunction(ctx) +} + +func (t *LocalStackTracer) CaptureOverheadSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { + return criticalFunction(ctx) +} + +func (t *LocalStackTracer) RecordInitStartTime() {} +func (t *LocalStackTracer) RecordInitEndTime() {} +func (t *LocalStackTracer) SendInitSubsegmentWithRecordedTimesOnce(ctx context.Context) {} +func (t *LocalStackTracer) SendRestoreSubsegmentWithRecordedTimesOnce(ctx context.Context) {} +func (t *LocalStackTracer) MarkError(ctx context.Context) {} +func (t *LocalStackTracer) AttachErrorCause(ctx context.Context, errorCause json.RawMessage) {} + +func (t *LocalStackTracer) WithErrorCause(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error { + return criticalFunction +} +func (t *LocalStackTracer) WithError(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error { + return criticalFunction +} +func (t *LocalStackTracer) BuildTracingHeader() func(context.Context) string { + // extract root trace ID and parent from context and build the tracing header + return func(ctx context.Context) string { + return t.invoke.TraceID + } +} + +func (t *LocalStackTracer) BuildTracingCtxForStart() *interop.TracingCtx { + return nil +} +func (t *LocalStackTracer) BuildTracingCtxAfterInvokeComplete() *interop.TracingCtx { + return nil +} + +func NewLocalStackTracer() *LocalStackTracer { + return &LocalStackTracer{} +} From ba28a02302759867e299350a91822e6eb1c4ad41 Mon Sep 17 00:00:00 2001 From: Daniel Fangl Date: Tue, 6 Feb 2024 10:52:18 +0100 Subject: [PATCH 22/41] Allow manual specification of filewatcher behavior (#29) --- cmd/localstack/awsutil.go | 4 ++-- cmd/localstack/filenotify/filenotify.go | 20 +++++++++++++++++--- cmd/localstack/hotreloading.go | 4 ++-- cmd/localstack/main.go | 4 +++- 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/cmd/localstack/awsutil.go b/cmd/localstack/awsutil.go index de18378..7fa02c4 100644 --- a/cmd/localstack/awsutil.go +++ b/cmd/localstack/awsutil.go @@ -156,14 +156,14 @@ func RunDNSRewriter(opts *LsOpts, ctx context.Context) { log.Debugln("DNS server stopped") } -func RunHotReloadingListener(server *CustomInteropServer, targetPaths []string, ctx context.Context) { +func RunHotReloadingListener(server *CustomInteropServer, targetPaths []string, ctx context.Context, fileWatcherStrategy string) { if len(targetPaths) == 1 && targetPaths[0] == "" { log.Debugln("Hot reloading disabled.") return } defaultDebouncingDuration := 500 * time.Millisecond log.Infoln("Hot reloading enabled, starting filewatcher.", targetPaths) - changeListener, err := NewChangeListener(defaultDebouncingDuration) + changeListener, err := NewChangeListener(defaultDebouncingDuration, fileWatcherStrategy) if err != nil { log.Errorln("Hot reloading disabled due to change listener error.", err) return diff --git a/cmd/localstack/filenotify/filenotify.go b/cmd/localstack/filenotify/filenotify.go index 1c1d708..579a76a 100644 --- a/cmd/localstack/filenotify/filenotify.go +++ b/cmd/localstack/filenotify/filenotify.go @@ -38,14 +38,28 @@ func shouldUseEventWatcher() bool { } // New tries to use a fs-event watcher, and falls back to the poller if there is an error -func New(interval time.Duration) (FileWatcher, error) { +func New(interval time.Duration, fileWatcherStrategy string) (FileWatcher, error) { + if fileWatcherStrategy != "" { + log.Debugln("Forced usage of filewatcher strategy: ", fileWatcherStrategy) + if fileWatcherStrategy == "event" { + if watcher, err := NewEventWatcher(); err == nil { + return watcher, nil + } else { + log.Fatalln("Event based filewatcher is selected, but unable to start. Please try setting the filewatcher to polling. Error: ", err) + } + } else if fileWatcherStrategy == "polling" { + return NewPollingWatcher(interval), nil + } else { + log.Fatalf("Invalid filewatcher strategy %s. Only event and polling are allowed.\n", fileWatcherStrategy) + } + } if shouldUseEventWatcher() { if watcher, err := NewEventWatcher(); err == nil { - log.Debugln("Using event based filewatcher") + log.Debugln("Using event based filewatcher (autodetected)") return watcher, nil } } - log.Debugln("Using polling based filewatcher") + log.Debugln("Using polling based filewatcher (autodetected)") return NewPollingWatcher(interval), nil } diff --git a/cmd/localstack/hotreloading.go b/cmd/localstack/hotreloading.go index 9218465..a0c4467 100644 --- a/cmd/localstack/hotreloading.go +++ b/cmd/localstack/hotreloading.go @@ -16,8 +16,8 @@ type ChangeListener struct { watchedFolders []string } -func NewChangeListener(debouncingInterval time.Duration) (*ChangeListener, error) { - watcher, err := filenotify.New(200 * time.Millisecond) +func NewChangeListener(debouncingInterval time.Duration, fileWatcherStrategy string) (*ChangeListener, error) { + watcher, err := filenotify.New(200*time.Millisecond, fileWatcherStrategy) if err != nil { log.Errorln("Cannot create change listener due to filewatcher error.", err) return nil, err diff --git a/cmd/localstack/main.go b/cmd/localstack/main.go index 02bf247..e6eb026 100644 --- a/cmd/localstack/main.go +++ b/cmd/localstack/main.go @@ -21,6 +21,7 @@ type LsOpts struct { User string CodeArchives string HotReloadingPaths []string + FileWatcherStrategy string EnableDnsServer string LocalstackIP string InitLogLevel string @@ -52,6 +53,7 @@ func InitLsOpts() *LsOpts { // optional or empty CodeArchives: os.Getenv("LOCALSTACK_CODE_ARCHIVES"), HotReloadingPaths: strings.Split(GetenvWithDefault("LOCALSTACK_HOT_RELOADING_PATHS", ""), ","), + FileWatcherStrategy: os.Getenv("LOCALSTACK_FILE_WATCHER_STRATEGY"), EnableDnsServer: os.Getenv("LOCALSTACK_ENABLE_DNS_SERVER"), EnableXRayTelemetry: os.Getenv("LOCALSTACK_ENABLE_XRAY_TELEMETRY"), LocalstackIP: os.Getenv("LOCALSTACK_HOSTNAME"), @@ -230,7 +232,7 @@ func main() { if err != nil { log.Fatalln(err) } - go RunHotReloadingListener(interopServer, lsOpts.HotReloadingPaths, fileWatcherContext) + go RunHotReloadingListener(interopServer, lsOpts.HotReloadingPaths, fileWatcherContext, lsOpts.FileWatcherStrategy) // start runtime init. It is important to start `InitHandler` synchronously because we need to ensure the // notification channels and status fields are properly initialized before `AwaitInitialized` From 235ac2d84c2d084ac4f347609043c22010a4a2f7 Mon Sep 17 00:00:00 2001 From: Daniel Fangl Date: Tue, 6 Feb 2024 11:24:33 +0100 Subject: [PATCH 23/41] Update go version and github actions versions (#32) --- .github/workflows/build.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 518aa03..3ccb66f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -12,19 +12,19 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: - go-version: '~1.18.2' + go-version: '1.20' - name: Build env: RELEASE_BUILD_LINKER_FLAGS: "-s -w" run: make compile-lambda-linux-all - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: aws-lambda-rie path: bin/* From ddc62d9501165b06504058fc96a178d21d1dc2f5 Mon Sep 17 00:00:00 2001 From: Dominik Schubert Date: Tue, 13 Feb 2024 17:04:11 +0100 Subject: [PATCH 24/41] Fix over-sized response handling (#33) --- cmd/localstack/custom_interop.go | 7 +++++-- cmd/localstack/main.go | 11 +++++++++++ debugging/Makefile | 2 +- lambda/core/directinvoke/directinvoke.go | 5 +++-- lambda/interop/model.go | 7 ++++--- lambda/rapi/rendering/rendering.go | 3 ++- 6 files changed, 26 insertions(+), 9 deletions(-) diff --git a/cmd/localstack/custom_interop.go b/cmd/localstack/custom_interop.go index 2bd3541..1941668 100644 --- a/cmd/localstack/custom_interop.go +++ b/cmd/localstack/custom_interop.go @@ -6,6 +6,7 @@ package main import ( "bytes" "encoding/json" + "errors" "fmt" "github.com/go-chi/chi" log "github.com/sirupsen/logrus" @@ -117,8 +118,8 @@ func NewCustomInteropServer(lsOpts *LsOpts, delegate interop.Server, logCollecto timeout := int(server.delegate.GetInvokeTimeout().Seconds()) isErr := false if err != nil { - switch err { - case rapidcore.ErrInvokeTimeout: + switch { + case errors.Is(err, rapidcore.ErrInvokeTimeout): log.Debugf("Got invoke timeout") isErr = true errorResponse := ErrorResponse{ @@ -137,6 +138,8 @@ func NewCustomInteropServer(lsOpts *LsOpts, delegate interop.Server, logCollecto if err != nil { log.Fatalln("unable to write to response") } + case errors.Is(err, rapidcore.ErrInvokeDoneFailed): + // we can actually just continue here, error message is sent below default: log.Fatalln(err) } diff --git a/cmd/localstack/main.go b/cmd/localstack/main.go index e6eb026..e936d78 100644 --- a/cmd/localstack/main.go +++ b/cmd/localstack/main.go @@ -5,6 +5,7 @@ package main import ( "context" log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/interop" "go.amzn.com/lambda/rapidcore" "os" "runtime/debug" @@ -28,6 +29,7 @@ type LsOpts struct { EdgePort string EnableXRayTelemetry string PostInvokeWaitMS string + MaxPayloadSize string } func GetEnvOrDie(env string) string { @@ -50,6 +52,7 @@ func InitLsOpts() *LsOpts { User: GetenvWithDefault("LOCALSTACK_USER", "sbx_user1051"), InitLogLevel: GetenvWithDefault("LOCALSTACK_INIT_LOG_LEVEL", "warn"), EdgePort: GetenvWithDefault("EDGE_PORT", "4566"), + MaxPayloadSize: GetenvWithDefault("LOCALSTACK_MAX_PAYLOAD_SIZE", "6291556"), // optional or empty CodeArchives: os.Getenv("LOCALSTACK_CODE_ARCHIVES"), HotReloadingPaths: strings.Split(GetenvWithDefault("LOCALSTACK_HOT_RELOADING_PATHS", ""), ","), @@ -77,6 +80,7 @@ func UnsetLsEnvs() { "LOCALSTACK_INIT_LOG_LEVEL", "LOCALSTACK_POST_INVOKE_WAIT_MS", "LOCALSTACK_FUNCTION_ACCOUNT_ID", + "LOCALSTACK_MAX_PAYLOAD_SIZE", // Docker container ID "HOSTNAME", @@ -128,6 +132,13 @@ func main() { log.Fatal("Invalid value for LOCALSTACK_INIT_LOG_LEVEL") } + // patch MaxPayloadSize + payloadSize, err := strconv.Atoi(lsOpts.MaxPayloadSize) + if err != nil { + log.Panicln("Please specify a number for LOCALSTACK_MAX_PAYLOAD_SIZE") + } + interop.MaxPayloadSize = payloadSize + // enable dns server dnsServerContext, stopDnsServer := context.WithCancel(context.Background()) go RunDNSRewriter(lsOpts, dnsServerContext) diff --git a/debugging/Makefile b/debugging/Makefile index 9bd3e35..fe3a68f 100644 --- a/debugging/Makefile +++ b/debugging/Makefile @@ -1,5 +1,5 @@ # Golang EOL overview: https://endoflife.date/go -DOCKER_GOLANG_IMAGE ?= golang:1.19 +DOCKER_GOLANG_IMAGE ?= golang:1.20-bullseye # On ARM hosts, use: make ARCH=arm64 build-init # Check host architecture: uname -m diff --git a/lambda/core/directinvoke/directinvoke.go b/lambda/core/directinvoke/directinvoke.go index 3510132..396bd39 100644 --- a/lambda/core/directinvoke/directinvoke.go +++ b/lambda/core/directinvoke/directinvoke.go @@ -1,5 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 +// LOCALSTACK CHANGES 2024-02-13: casting of MaxPayloadSize package directinvoke @@ -51,7 +52,7 @@ var ResetReasonMap = map[string]fatalerror.ErrorType{ "timeout": fatalerror.SandboxTimeout, } -var MaxDirectResponseSize int64 = interop.MaxPayloadSize // this is intentionally not a constant so we can configure it via CLI +var MaxDirectResponseSize = int64(interop.MaxPayloadSize) // this is intentionally not a constant so we can configure it via CLI var ResponseBandwidthRate int64 = interop.ResponseBandwidthRate var ResponseBandwidthBurstSize int64 = interop.ResponseBandwidthBurstSize @@ -104,7 +105,7 @@ func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.T now := metering.Monotime() - MaxDirectResponseSize = interop.MaxPayloadSize + MaxDirectResponseSize = int64(interop.MaxPayloadSize) if maxPayloadSize := r.Header.Get(MaxPayloadSizeHeader); maxPayloadSize != "" { if n, err := strconv.ParseInt(maxPayloadSize, 10, 64); err == nil && n >= -1 { MaxDirectResponseSize = n diff --git a/lambda/interop/model.go b/lambda/interop/model.go index a4bdbf4..ee7bb2a 100644 --- a/lambda/interop/model.go +++ b/lambda/interop/model.go @@ -1,5 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 +// LOCALSTACK CHANGES 2024-02-13: adjust error message for ErrorResponseTooLarge to be in parity with what AWS returns; make MaxPayloadSize adjustable package interop @@ -18,10 +19,10 @@ import ( log "github.com/sirupsen/logrus" ) +var MaxPayloadSize int = 6*1024*1024 + 100 // 6 MiB + 100 bytes + // MaxPayloadSize max event body size declared as LAMBDA_EVENT_BODY_SIZE const ( - MaxPayloadSize = 6*1024*1024 + 100 // 6 MiB + 100 bytes - ResponseBandwidthRate = 2 * 1024 * 1024 // default average rate of 2 MiB/s ResponseBandwidthBurstSize = 6 * 1024 * 1024 // default burst size of 6 MiB @@ -355,7 +356,7 @@ type ErrorResponseTooLargeDI struct { // ErrorResponseTooLarge is returned when response provided by Runtime does not fit into shared memory buffer func (s *ErrorResponseTooLarge) Error() string { - return fmt.Sprintf("Response payload size (%d bytes) exceeded maximum allowed payload size (%d bytes).", s.ResponseSize, s.MaxResponseSize) + return fmt.Sprintf("Response payload size exceeded maximum allowed payload size (%d bytes).", s.MaxResponseSize) } // AsErrorResponse generates ErrorInvokeResponse from ErrorResponseTooLarge diff --git a/lambda/rapi/rendering/rendering.go b/lambda/rapi/rendering/rendering.go index 9a9d77b..08de1e3 100644 --- a/lambda/rapi/rendering/rendering.go +++ b/lambda/rapi/rendering/rendering.go @@ -1,5 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 +// LOCALSTACK CHANGES 2024-02-13: casting of MaxPayloadSize package rendering @@ -174,7 +175,7 @@ func (s *InvokeRenderer) bufferInvokeRequest() error { defer s.requestMutex.Unlock() var err error = nil if s.requestBuffer.Len() == 0 { - reader := io.LimitReader(s.invoke.Payload, interop.MaxPayloadSize) + reader := io.LimitReader(s.invoke.Payload, int64(interop.MaxPayloadSize)) start := time.Now() _, err = s.requestBuffer.ReadFrom(reader) s.metrics = InvokeRendererMetrics{ From d5d750523ae172a282ed5f0f4d1ec56eaa7ee091 Mon Sep 17 00:00:00 2001 From: Frederic Mbea <117131783+mbfreder@users.noreply.github.com> Date: Thu, 7 Mar 2024 14:42:03 -0800 Subject: [PATCH 25/41] test: Add automated integration tests runs on GitHub(#112) Add automated integration tests runs on GitHub --- .github/workflows/integ-tests.yml | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 .github/workflows/integ-tests.yml diff --git a/.github/workflows/integ-tests.yml b/.github/workflows/integ-tests.yml new file mode 100644 index 0000000..cb2f9dc --- /dev/null +++ b/.github/workflows/integ-tests.yml @@ -0,0 +1,21 @@ +name: Run Integration Tests + +on: + pull_request: + branches: + - develop + +jobs: + integ-tests: + runs-on: ubuntu-latest + environment: + name: prod + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + - name: allows us to build arm64 images + run: docker run --rm --privileged multiarch/qemu-user-static --reset -p yes + - name: run integration tests + run: make integ-tests-with-docker \ No newline at end of file From 1a320122c89528af4e707082a28d9ab0d81a12e8 Mon Sep 17 00:00:00 2001 From: Frederic Mbea <117131783+mbfreder@users.noreply.github.com> Date: Thu, 7 Mar 2024 15:15:39 -0800 Subject: [PATCH 26/41] test: Refactored end-to-end tests (#113) * Refactored end-to-end tests to reduce duplicated code. --- .../local_lambda/test_end_to_end.py | 147 ++++++------------ 1 file changed, 46 insertions(+), 101 deletions(-) diff --git a/test/integration/local_lambda/test_end_to_end.py b/test/integration/local_lambda/test_end_to_end.py index c5c3e63..fd7f735 100644 --- a/test/integration/local_lambda/test_end_to_end.py +++ b/test/integration/local_lambda/test_end_to_end.py @@ -72,20 +72,34 @@ def tagged_name(self, name, architecture): def get_tag(self, architecture): return "" if architecture == "" else str(f"-{architecture}") + + def run_command(self, cmd): + Popen(cmd.split(" ")).communicate() + + def sleep_1s(self): + time.sleep(SLEEP_TIME) + + def invoke_function(self, port): + return requests.post( + f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} + ) + + def create_container_and_invoke_function(self, cmd, port): + self.run_command(cmd) + + # sleep 1s to give enough time for the endpoint to be up to curl + self.sleep_1s() + + return self.invoke_function(port) @parameterized.expand([("x86_64", "8000"), ("arm64", "9000"), ("", "9050")]) def test_env_var_with_equal_sign(self, arch, port): image, rie, image_name = self.tagged_name("envvarcheck", arch) cmd = f"docker run --name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.check_env_var_handler" - Popen(cmd.split(" ")).communicate() - - # sleep 1s to give enough time for the endpoint to be up to curl - time.sleep(SLEEP_TIME) - - r = requests.post( - f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} - ) + + r = self.create_container_and_invoke_function(cmd, port) + self.assertEqual(b'"4=4"', r.content) @parameterized.expand([("x86_64", "8001"), ("arm64", "9001"), ("", "9051")]) @@ -94,20 +108,13 @@ def test_two_invokes(self, arch, port): cmd = f"docker run --name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.success_handler" - Popen(cmd.split(" ")).communicate() - - # sleep 1s to give enough time for the endpoint to be up to curl - time.sleep(SLEEP_TIME) - - r = requests.post( - f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} - ) + r = self.create_container_and_invoke_function(cmd, port) + self.assertEqual(b'"My lambda ran succesfully"', r.content) # Make sure we can invoke the function twice - r = requests.post( - f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} - ) + r = self.invoke_function(port) + self.assertEqual(b'"My lambda ran succesfully"', r.content) @parameterized.expand([("x86_64", "8002"), ("arm64", "9002"), ("", "9052")]) @@ -116,14 +123,8 @@ def test_lambda_function_arn_exists(self, arch, port): cmd = f"docker run --name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.assert_lambda_arn_in_context" - Popen(cmd.split(" ")).communicate() - - # sleep 1s to give enough time for the endpoint to be up to curl - time.sleep(SLEEP_TIME) - - r = requests.post( - f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} - ) + r = self.create_container_and_invoke_function(cmd, port) + self.assertEqual(b'"My lambda ran succesfully"', r.content) @parameterized.expand([("x86_64", "8003"), ("arm64", "9003"), ("", "9053")]) @@ -131,14 +132,9 @@ def test_lambda_function_arn_exists_with_defining_custom_name(self, arch, port): image, rie, image_name = self.tagged_name("customname", arch) cmd = f"docker run --name {image} --env AWS_LAMBDA_FUNCTION_NAME=MyCoolName -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.assert_lambda_arn_in_context" - Popen(cmd.split(" ")).communicate() - - # sleep 1s to give enough time for the endpoint to be up to curl - time.sleep(SLEEP_TIME) - - r = requests.post( - f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} - ) + + r = self.create_container_and_invoke_function(cmd, port) + self.assertEqual(b'"My lambda ran succesfully"', r.content) @parameterized.expand([("x86_64", "8004"), ("arm64", "9004"), ("", "9054")]) @@ -147,14 +143,8 @@ def test_timeout_invoke(self, arch, port): cmd = f"docker run --name {image} -d --env AWS_LAMBDA_FUNCTION_TIMEOUT=1 -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.sleep_handler" - Popen(cmd.split(" ")).communicate() - - # sleep 1s to give enough time for the endpoint to be up to curl - time.sleep(SLEEP_TIME) - - r = requests.post( - f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} - ) + r = self.create_container_and_invoke_function(cmd, port) + self.assertEqual(b"Task timed out after 1.00 seconds", r.content) @parameterized.expand([("x86_64", "8005"), ("arm64", "9005"), ("", "9055")]) @@ -163,14 +153,8 @@ def test_exception_returned(self, arch, port): cmd = f"docker run --name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.exception_handler" - Popen(cmd.split(" ")).communicate() - - # sleep 1s to give enough time for the endpoint to be up to curl - time.sleep(SLEEP_TIME) - - r = requests.post( - f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} - ) + r = self.create_container_and_invoke_function(cmd, port) + self.assertEqual( b'{"errorMessage": "Raising an exception", "errorType": "Exception", "stackTrace": [" File \\"/var/task/main.py\\", line 13, in exception_handler\\n raise Exception(\\"Raising an exception\\")\\n"]}', r.content, @@ -182,15 +166,8 @@ def test_context_get_remaining_time_in_three_seconds(self, arch, port): cmd = f"docker run --name {image} -d --env AWS_LAMBDA_FUNCTION_TIMEOUT=3 -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.check_remaining_time_handler" - Popen(cmd.split(' ')).communicate() - - # sleep 1s to give enough time for the endpoint to be up to curl - time.sleep(SLEEP_TIME) - - r = requests.post( - f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} - ) - + r = self.create_container_and_invoke_function(cmd, port) + # Execution time is not decided, 1.0s ~ 3.0s is a good estimation self.assertLess(int(r.content), 3000) self.assertGreater(int(r.content), 1000) @@ -201,15 +178,8 @@ def test_context_get_remaining_time_in_ten_seconds(self, arch, port): cmd = f"docker run --name {image} -d --env AWS_LAMBDA_FUNCTION_TIMEOUT=10 -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.check_remaining_time_handler" - Popen(cmd.split(' ')).communicate() - - # sleep 1s to give enough time for the endpoint to be up to curl - time.sleep(SLEEP_TIME) - - r = requests.post( - f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} - ) - + r = self.create_container_and_invoke_function(cmd, port) + # Execution time is not decided, 8.0s ~ 10.0s is a good estimation self.assertLess(int(r.content), 10000) self.assertGreater(int(r.content), 8000) @@ -220,14 +190,7 @@ def test_context_get_remaining_time_in_default_deadline(self, arch, port): cmd = f"docker run --name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.check_remaining_time_handler" - Popen(cmd.split(' ')).communicate() - - # sleep 1s to give enough time for the endpoint to be up to curl - time.sleep(SLEEP_TIME) - - r = requests.post( - f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} - ) + r = self.create_container_and_invoke_function(cmd, port) # Executation time is not decided, 298.0s ~ 300.0s is a good estimation self.assertLess(int(r.content), 300000) @@ -239,14 +202,8 @@ def test_invoke_with_pre_runtime_api_runtime(self, arch, port): cmd = f"docker run --name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.success_handler" - Popen(cmd.split(" ")).communicate() - - # sleep 1s to give enough time for the endpoint to be up to curl - time.sleep(SLEEP_TIME) - - r = requests.post( - f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} - ) + r = self.create_container_and_invoke_function(cmd, port) + self.assertEqual(b'"My lambda ran succesfully"', r.content) @parameterized.expand([("x86_64", "8010"), ("arm64", "9010"), ("", "9060")]) @@ -255,14 +212,8 @@ def test_function_name_is_overriden(self, arch, port): cmd = f"docker run --name {image} -d --env AWS_LAMBDA_FUNCTION_NAME=MyCoolName -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.assert_env_var_is_overwritten" - Popen(cmd.split(" ")).communicate() - - # sleep 1s to give enough time for the endpoint to be up to curl - time.sleep(SLEEP_TIME) - - r = requests.post( - f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} - ) + r = self.create_container_and_invoke_function(cmd, port) + self.assertEqual(b'"My lambda ran succesfully"', r.content) @parameterized.expand([("x86_64", "8011"), ("arm64", "9011"), ("", "9061")]) @@ -272,14 +223,8 @@ def test_port_override(self, arch, port): # Use port 8081 inside the container instead of 8080 cmd = f"docker run --name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8081 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.success_handler --runtime-interface-emulator-address 0.0.0.0:8081" - Popen(cmd.split(" ")).communicate() - - # sleep 1s to give enough time for the endpoint to be up to curl - time.sleep(SLEEP_TIME) - - r = requests.post( - f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} - ) + r = self.create_container_and_invoke_function(cmd, port) + self.assertEqual(b'"My lambda ran succesfully"', r.content) From 28d666181b3686b093a360ea929ccb4c6f009617 Mon Sep 17 00:00:00 2001 From: Daniel Fangl Date: Tue, 12 Mar 2024 15:48:20 +0100 Subject: [PATCH 27/41] Allow LocalStack specific configuration of chmod on startup, remove DNS logic (#34) --- cmd/localstack/awsutil.go | 17 --------- cmd/localstack/dns.go | 71 ------------------------------------ cmd/localstack/file_utils.go | 29 +++++++++++++++ cmd/localstack/main.go | 30 +++------------ go.mod | 3 -- go.sum | 32 ---------------- 6 files changed, 34 insertions(+), 148 deletions(-) delete mode 100644 cmd/localstack/dns.go diff --git a/cmd/localstack/awsutil.go b/cmd/localstack/awsutil.go index 7fa02c4..c7fcbc4 100644 --- a/cmd/localstack/awsutil.go +++ b/cmd/localstack/awsutil.go @@ -139,23 +139,6 @@ func resetListener(changeChannel <-chan bool, server *CustomInteropServer) { } -func RunDNSRewriter(opts *LsOpts, ctx context.Context) { - if opts.EnableDnsServer != "1" { - log.Debugln("DNS server disabled.") - return - } - dnsForwarder, err := NewDnsForwarder(opts.LocalstackIP) - if err != nil { - log.Errorln("Error creating dns forwarder.") - return - } - defer dnsForwarder.Shutdown() - dnsForwarder.Start() - - <-ctx.Done() - log.Debugln("DNS server stopped") -} - func RunHotReloadingListener(server *CustomInteropServer, targetPaths []string, ctx context.Context, fileWatcherStrategy string) { if len(targetPaths) == 1 && targetPaths[0] == "" { log.Debugln("Hot reloading disabled.") diff --git a/cmd/localstack/dns.go b/cmd/localstack/dns.go deleted file mode 100644 index fc68819..0000000 --- a/cmd/localstack/dns.go +++ /dev/null @@ -1,71 +0,0 @@ -package main - -import ( - "github.com/miekg/dns" - log "github.com/sirupsen/logrus" - "net" -) - -type DNSForwarder struct { - server *dns.Server -} - -type DNSRewriteForwardHandler struct { - upstreamServer string - redirectTo string -} - -func (D DNSRewriteForwardHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - client := dns.Client{ - Net: "udp", - } - response, _, err := client.Exchange(r, D.upstreamServer+":53") - if err != nil { - log.Errorln("Error connecting to upstream: ", err) - return - } - for _, rr := range response.Answer { - switch rr.Header().Rrtype { - case dns.TypeA: - if t, ok := rr.(*dns.A); ok { - if t.A.Equal(net.IPv4(127, 0, 0, 1)) { - log.Debugln("Redirecting answer for ", t.Header().Name, "to ", D.redirectTo) - t.A = net.ParseIP(D.redirectTo) - } - } - } - } - err = w.WriteMsg(response) - if err != nil { - log.Errorln("Error writing response: ", err) - } -} - -func NewDnsForwarder(upstreamServer string) (*DNSForwarder, error) { - forwarder := &DNSForwarder{ - server: &dns.Server{ - Net: "udp", - Handler: DNSRewriteForwardHandler{ - upstreamServer: upstreamServer, - redirectTo: upstreamServer, - }, - }, - } - return forwarder, nil -} - -func (c *DNSForwarder) Start() { - go func() { - err := c.server.ListenAndServe() - if err != nil { - log.Errorln("Error starting DNS server: ", err) - } - }() -} - -func (c *DNSForwarder) Shutdown() { - err := c.server.Shutdown() - if err != nil { - log.Errorln("Error shutting down DNS server: ", err) - } -} diff --git a/cmd/localstack/file_utils.go b/cmd/localstack/file_utils.go index ed65c70..0de9519 100644 --- a/cmd/localstack/file_utils.go +++ b/cmd/localstack/file_utils.go @@ -1,11 +1,40 @@ package main import ( + "encoding/json" + log "github.com/sirupsen/logrus" "io" "os" "path/filepath" + "strconv" ) +type Chmod struct { + Path string `json:"path"` + Mode string `json:"mode"` +} + +// AdaptFilesystemPermissions Adapts the file system permissions to the mode specified in the chmodInfoString parameter +// chmodInfoString should be a json encoded list of `Chmod` structs. +// example: '[{"path": "/opt", "mode": "0755"}]'. The mode string should be an octal representation of the targeted file mode. +func AdaptFilesystemPermissions(chmodInfoString string) error { + var chmodInfo []Chmod + err := json.Unmarshal([]byte(chmodInfoString), &chmodInfo) + if err != nil { + return err + } + for _, chmod := range chmodInfo { + mode, err := strconv.ParseInt(chmod.Mode, 0, 32) + if err != nil { + return err + } + if err := ChmodRecursively(chmod.Path, os.FileMode(mode)); err != nil { + log.Warnf("Could not change file mode recursively of directory %s: %s\n", chmod.Path, err) + } + } + return nil +} + // Inspired by https://stackoverflow.com/questions/73864379/golang-change-permission-os-chmod-and-os-chowm-recursively // but using the more efficient WalkDir API func ChmodRecursively(root string, mode os.FileMode) error { diff --git a/cmd/localstack/main.go b/cmd/localstack/main.go index e936d78..064a174 100644 --- a/cmd/localstack/main.go +++ b/cmd/localstack/main.go @@ -23,7 +23,7 @@ type LsOpts struct { CodeArchives string HotReloadingPaths []string FileWatcherStrategy string - EnableDnsServer string + ChmodPaths string LocalstackIP string InitLogLevel string EdgePort string @@ -57,10 +57,10 @@ func InitLsOpts() *LsOpts { CodeArchives: os.Getenv("LOCALSTACK_CODE_ARCHIVES"), HotReloadingPaths: strings.Split(GetenvWithDefault("LOCALSTACK_HOT_RELOADING_PATHS", ""), ","), FileWatcherStrategy: os.Getenv("LOCALSTACK_FILE_WATCHER_STRATEGY"), - EnableDnsServer: os.Getenv("LOCALSTACK_ENABLE_DNS_SERVER"), EnableXRayTelemetry: os.Getenv("LOCALSTACK_ENABLE_XRAY_TELEMETRY"), LocalstackIP: os.Getenv("LOCALSTACK_HOSTNAME"), PostInvokeWaitMS: os.Getenv("LOCALSTACK_POST_INVOKE_WAIT_MS"), + ChmodPaths: GetenvWithDefault("LOCALSTACK_CHMOD_PATHS", "[]"), } } @@ -75,12 +75,12 @@ func UnsetLsEnvs() { "LOCALSTACK_USER", "LOCALSTACK_CODE_ARCHIVES", "LOCALSTACK_HOT_RELOADING_PATHS", - "LOCALSTACK_ENABLE_DNS_SERVER", "LOCALSTACK_ENABLE_XRAY_TELEMETRY", "LOCALSTACK_INIT_LOG_LEVEL", "LOCALSTACK_POST_INVOKE_WAIT_MS", "LOCALSTACK_FUNCTION_ACCOUNT_ID", "LOCALSTACK_MAX_PAYLOAD_SIZE", + "LOCALSTACK_CHMOD_PATHS", // Docker container ID "HOSTNAME", @@ -139,31 +139,13 @@ func main() { } interop.MaxPayloadSize = payloadSize - // enable dns server - dnsServerContext, stopDnsServer := context.WithCancel(context.Background()) - go RunDNSRewriter(lsOpts, dnsServerContext) - // download code archive if env variable is set if err := DownloadCodeArchives(lsOpts.CodeArchives); err != nil { log.Fatal("Failed to download code archives: " + err.Error()) } - // set file permissions of the tmp directory for better AWS parity - if err := ChmodRecursively("/tmp", 0700); err != nil { - log.Warnln("Could not change file mode recursively of directory /tmp:", err) - } - // set file permissions of the layers directory for better AWS parity - if err := ChmodRecursively("/opt", 0755); err != nil { - log.Warnln("Could not change file mode recursively of directory /opt:", err) - } - // set file permissions of the code directory if at least one layer is present for better AWS parity - // Limitation: hot reloading likely breaks file permission parity for /var/task in combination with layers - // Heuristic for detecting the presence of layers. It might fail for an empty layer or image-based Lambda. - if isDirEmpty, _ := IsDirEmpty("/opt"); !isDirEmpty { - log.Debugln("Detected layer present") - if err := ChmodRecursively("/var/task", 0755); err != nil { - log.Warnln("Could not change file mode recursively of directory /var/task:", err) - } + if err := AdaptFilesystemPermissions(lsOpts.ChmodPaths); err != nil { + log.Warnln("Could not change file mode of code directories:", err) } // parse CLI args @@ -200,8 +182,6 @@ func main() { AddShutdownFunc(func() { log.Debugln("Stopping file watcher") cancelFileWatcher() - log.Debugln("Stopping DNS server") - stopDnsServer() }). SetExtensionsFlag(true). SetInitCachingFlag(true). diff --git a/go.mod b/go.mod index 206b761..992860c 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,6 @@ require ( github.com/go-chi/chi v4.1.2+incompatible github.com/google/uuid v1.3.0 github.com/jessevdk/go-flags v1.5.0 - github.com/miekg/dns v1.1.50 github.com/shirou/gopsutil v2.19.10+incompatible github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.4 @@ -26,10 +25,8 @@ require ( github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.0 // indirect - golang.org/x/mod v0.8.0 // indirect golang.org/x/net v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect - golang.org/x/tools v0.6.0 // indirect gopkg.in/yaml.v2 v2.2.8 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 8cc90e8..0474547 100644 --- a/go.sum +++ b/go.sum @@ -25,8 +25,6 @@ github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9Y github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= -github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA= -github.com/miekg/dns v1.1.50/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -43,53 +41,23 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg= golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= From 42aee21cf8ce4bee56cd083cee2446a1c33ba2a9 Mon Sep 17 00:00:00 2001 From: Renato Valenzuela <37676028+valerena@users.noreply.github.com> Date: Wed, 20 Mar 2024 10:56:46 -0700 Subject: [PATCH 28/41] chore(deps): Update to Go 1.21. Update deps (#116) --- Makefile | 2 +- go.mod | 15 +++++++-------- go.sum | 30 ++++++++++++------------------ 3 files changed, 20 insertions(+), 27 deletions(-) diff --git a/Makefile b/Makefile index 80ccb89..1916dae 100644 --- a/Makefile +++ b/Makefile @@ -21,7 +21,7 @@ compile-lambda-linux-all: make ARCH=old compile-lambda-linux compile-with-docker: - docker run --env GOPROXY=direct -v $(shell pwd):/LambdaRuntimeLocal -w /LambdaRuntimeLocal golang:1.20 make ARCH=${ARCH} compile-lambda-linux + docker run --env GOPROXY=direct -v $(shell pwd):/LambdaRuntimeLocal -w /LambdaRuntimeLocal golang:1.21 make ARCH=${ARCH} compile-lambda-linux compile-lambda-linux: CGO_ENABLED=0 GOOS=linux GOARCH=${GO_ARCH_${ARCH}} go build -buildvcs=false -ldflags "${RELEASE_BUILD_LINKER_FLAGS}" -o ${DESTINATION_${ARCH}} ./cmd/aws-lambda-rie diff --git a/go.mod b/go.mod index 990a7dd..d48dc60 100644 --- a/go.mod +++ b/go.mod @@ -1,22 +1,21 @@ module go.amzn.com -go 1.20 +go 1.21 require ( - github.com/aws/aws-lambda-go v1.41.0 - github.com/go-chi/chi v4.1.2+incompatible - github.com/google/uuid v1.3.0 + github.com/aws/aws-lambda-go v1.46.0 + github.com/go-chi/chi v1.5.5 + github.com/google/uuid v1.6.0 github.com/jessevdk/go-flags v1.5.0 github.com/sirupsen/logrus v1.9.3 - github.com/stretchr/testify v1.8.4 - golang.org/x/sync v0.2.0 + github.com/stretchr/testify v1.9.0 + golang.org/x/sync v0.6.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/objx v0.5.0 // indirect - golang.org/x/net v0.18.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect golang.org/x/sys v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 0ea11d6..8974775 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,12 @@ -github.com/aws/aws-lambda-go v1.41.0 h1:l/5fyVb6Ud9uYd411xdHZzSf2n86TakxzpvIoz7l+3Y= -github.com/aws/aws-lambda-go v1.41.0/go.mod h1:jwFe2KmMsHmffA1X2R09hH6lFzJQxzI8qK17ewzbQMM= +github.com/aws/aws-lambda-go v1.46.0 h1:UWVnvh2h2gecOlFhHQfIPQcD8pL/f7pVCutmFl+oXU8= +github.com/aws/aws-lambda-go v1.46.0/go.mod h1:dpMpZgvWx5vuQJfBt0zqBha60q7Dd7RfgJv23DymV8A= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-chi/chi v4.1.2+incompatible h1:fGFk2Gmi/YKXk0OmGfBh0WgmN3XB8lVnEyNz34tQRec= -github.com/go-chi/chi v4.1.2+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/go-chi/chi v1.5.5 h1:vOB/HbEMt9QqBqErz07QehcOKHaWFtuj87tTDVz2qXE= +github.com/go-chi/chi v1.5.5/go.mod h1:C9JqLr3tIYjDOZpzn+BCuxY8z8vmca43EeMgyZt7irw= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc= github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -14,23 +14,17 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg= -golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= -golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= -golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 4c2c20dfc5b09f6e93bb49a31bb62a6ad628c262 Mon Sep 17 00:00:00 2001 From: Renato Valenzuela <37676028+valerena@users.noreply.github.com> Date: Thu, 28 Mar 2024 12:29:07 -0700 Subject: [PATCH 29/41] chore(deps): Update to Go 1.22 (#117) * Update to Go 1.22 * Update Makefile to run unit tests in container --- Makefile | 13 ++++++++----- go.mod | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 1916dae..6d36ae2 100644 --- a/Makefile +++ b/Makefile @@ -10,10 +10,10 @@ DESTINATION_old:= bin/${BINARY_NAME} DESTINATION_x86_64 := bin/${BINARY_NAME}-x86_64 DESTINATION_arm64 := bin/${BINARY_NAME}-arm64 +run_in_docker = docker run --env GOPROXY=direct -v $(shell pwd):/LambdaRuntimeLocal -w /LambdaRuntimeLocal golang:1.22 $(1) + compile-with-docker-all: - make ARCH=x86_64 compile-with-docker - make ARCH=arm64 compile-with-docker - make ARCH=old compile-with-docker + $(call run_in_docker, make compile-lambda-linux-all) compile-lambda-linux-all: make ARCH=x86_64 compile-lambda-linux @@ -21,11 +21,14 @@ compile-lambda-linux-all: make ARCH=old compile-lambda-linux compile-with-docker: - docker run --env GOPROXY=direct -v $(shell pwd):/LambdaRuntimeLocal -w /LambdaRuntimeLocal golang:1.21 make ARCH=${ARCH} compile-lambda-linux + $(call run_in_docker, make ARCH=${ARCH} compile-lambda-linux) compile-lambda-linux: CGO_ENABLED=0 GOOS=linux GOARCH=${GO_ARCH_${ARCH}} go build -buildvcs=false -ldflags "${RELEASE_BUILD_LINKER_FLAGS}" -o ${DESTINATION_${ARCH}} ./cmd/aws-lambda-rie +tests-with-docker: + $(call run_in_docker, make tests) + tests: go test ./... @@ -33,7 +36,7 @@ integ-tests-and-compile: tests make compile-lambda-linux-all make integ-tests -integ-tests-with-docker: tests +integ-tests-with-docker: tests-with-docker make compile-with-docker-all make integ-tests diff --git a/go.mod b/go.mod index d48dc60..4ee45d7 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module go.amzn.com -go 1.21 +go 1.22 require ( github.com/aws/aws-lambda-go v1.46.0 From fe11d78c4ba308bcee3963b237a91ea8c746d439 Mon Sep 17 00:00:00 2001 From: Roger Zhang Date: Fri, 19 Apr 2024 13:58:03 -0700 Subject: [PATCH 30/41] test: Refactor test cases (#119) Refactor testcases to - Use python3.12 - Respect docker architecture - Run different architecture in parallel GitHub actions --- .github/workflows/integ-tests.yml | 38 ++- Makefile | 30 +- .../local_lambda/test_end_to_end.py | 289 +++++++++--------- test/integration/testdata/Dockerfile-allinone | 3 +- 4 files changed, 211 insertions(+), 149 deletions(-) diff --git a/.github/workflows/integ-tests.yml b/.github/workflows/integ-tests.yml index cb2f9dc..7fddc95 100644 --- a/.github/workflows/integ-tests.yml +++ b/.github/workflows/integ-tests.yml @@ -6,16 +6,44 @@ on: - develop jobs: - integ-tests: + go-tests: runs-on: ubuntu-latest environment: - name: prod + name: integ-tests + steps: + - uses: actions/checkout@v4 + - name: run go tests + run: make tests-with-docker + integ-tests-x86: + runs-on: ubuntu-latest + environment: + name: integ-tests + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + - name: run integration tests + run: make integ-tests-with-docker-x86-64 + integ-tests-arm64: + runs-on: ubuntu-latest + environment: + name: integ-tests + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + - name: run integration tests + run: make integ-tests-with-docker-arm64 + integ-tests-old: + runs-on: ubuntu-latest + environment: + name: integ-tests steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: python-version: '3.11' - - name: allows us to build arm64 images - run: docker run --rm --privileged multiarch/qemu-user-static --reset -p yes - name: run integration tests - run: make integ-tests-with-docker \ No newline at end of file + run: make integ-tests-with-docker-old \ No newline at end of file diff --git a/Makefile b/Makefile index 6d36ae2..f7a714e 100644 --- a/Makefile +++ b/Makefile @@ -39,9 +39,35 @@ integ-tests-and-compile: tests integ-tests-with-docker: tests-with-docker make compile-with-docker-all make integ-tests - -integ-tests: + +prep-python: python3 -m venv .venv .venv/bin/pip install --upgrade pip .venv/bin/pip install requests parameterized + +exec-python-e2e-test: .venv/bin/python3 test/integration/local_lambda/test_end_to_end.py + +integ-tests: + make prep-python + docker run --rm --privileged multiarch/qemu-user-static --reset -p yes + make TEST_ARCH=x86_64 TEST_PORT=8002 exec-python-e2e-test + make TEST_ARCH=arm64 TEST_PORT=9002 exec-python-e2e-test + make TEST_ARCH="" TEST_PORT=9052 exec-python-e2e-test + +integ-tests-with-docker-x86-64: + make ARCH=x86_64 compile-with-docker + make prep-python + make TEST_ARCH=x86_64 TEST_PORT=8002 exec-python-e2e-test + +integ-tests-with-docker-arm64: + make ARCH=arm64 compile-with-docker + make prep-python + docker run --rm --privileged multiarch/qemu-user-static --reset -p yes + make TEST_ARCH=arm64 TEST_PORT=9002 exec-python-e2e-test + +integ-tests-with-docker-old: + make ARCH=old compile-with-docker + make prep-python + make TEST_ARCH="" TEST_PORT=9052 exec-python-e2e-test + \ No newline at end of file diff --git a/test/integration/local_lambda/test_end_to_end.py b/test/integration/local_lambda/test_end_to_end.py index fd7f735..7c5486f 100644 --- a/test/integration/local_lambda/test_end_to_end.py +++ b/test/integration/local_lambda/test_end_to_end.py @@ -5,73 +5,57 @@ from unittest import TestCase, main from pathlib import Path import time - +import os import requests +from contextlib import contextmanager from parameterized import parameterized -SLEEP_TIME = 2 +SLEEP_TIME = 1.5 DEFAULT_1P_ENTRYPOINT = "/lambda-entrypoint.sh" ARCHS = ["x86_64", "arm64", ""] + class TestEndToEnd(TestCase): + ARCH = os.environ.get('TEST_ARCH', "") + PORT = os.environ.get('TEST_PORT', 8002) @classmethod def setUpClass(cls): testdata_path = Path(__file__).resolve().parents[1].joinpath("testdata") dockerfile_path = testdata_path.joinpath("Dockerfile-allinone") - cls.image_name = "aws-lambda-local:testing" cls.path_to_binary = Path().resolve().joinpath("bin") # build image - for arch in ARCHS: - image_name = cls.image_name if arch == "" else f"{cls.image_name}-{arch}" - architecture = arch if arch == "arm64" else "amd64" - build_cmd = [ - "docker", - "build", - "--platform", - f"linux/{architecture}", - "-t", - image_name, - "-f", - str(dockerfile_path), - str(testdata_path), - ] - Popen(build_cmd).communicate() + image_name_base = "aws-lambda-local:testing" + cls.image_name = image_name_base if cls.ARCH == "" else f"{image_name_base}-{cls.ARCH}" + architecture = cls.ARCH if cls.ARCH == "arm64" else "amd64" + docker_arch = cls.ARCH if cls.ARCH == "arm64" else "x86_64" + + build_cmd = [ + "docker", + "build", + "--platform", + f"linux/{architecture}", + "-t", + cls.image_name, + "-f", + str(dockerfile_path), + str(testdata_path), + "--build-arg", + f"IMAGE_ARCH={docker_arch}", + ] + Popen(build_cmd).communicate() @classmethod def tearDownClass(cls): - images_to_delete = [ - "envvarcheck", - "twoinvokes", - "arnexists", - "customname", - "timeout", - "exception", - "remaining_time_in_three_seconds", - "remaining_time_in_ten_seconds", - "remaining_time_in_default_deadline", - "pre-runtime-api", - "assert-overwritten", - "port_override" - ] - - for image in images_to_delete: - for arch in ARCHS: - arch_tag = "" if arch == "" else f"-{arch}" - cmd = f"docker rm -f {image}{arch_tag}" - Popen(cmd.split(" ")).communicate() - - for arch in ARCHS: - arch_tag = "" if arch == "" else f"-{arch}" - Popen(f"docker rmi {cls.image_name}{arch_tag}".split(" ")).communicate() + Popen(f"docker rmi {cls.image_name}".split(" ")).communicate() - def tagged_name(self, name, architecture): - tag = self.get_tag(architecture) - return (name + tag, "aws-lambda-rie" + tag, self.image_name + tag) + def tagged_name(self, name): + tag = self.get_tag() + return (name + tag, "aws-lambda-rie" + tag, self.image_name) - def get_tag(self, architecture): - return "" if architecture == "" else str(f"-{architecture}") + def get_tag(self): + return "" if self.ARCH == "" else str(f"-{self.ARCH}") def run_command(self, cmd): Popen(cmd.split(" ")).communicate() @@ -79,153 +63,176 @@ def run_command(self, cmd): def sleep_1s(self): time.sleep(SLEEP_TIME) - def invoke_function(self, port): + def invoke_function(self): return requests.post( - f"http://localhost:{port}/2015-03-31/functions/function/invocations", json={} + f"http://localhost:{self.PORT}/2015-03-31/functions/function/invocations", json={} ) - def create_container_and_invoke_function(self, cmd, port): - self.run_command(cmd) + @contextmanager + def create_container(self, param, image): + try: + platform = "x86_64" if self.ARCH == "" else self.ARCH + cmd_full = f"docker run --platform linux/{platform} {param}" + self.run_command(cmd_full) - # sleep 1s to give enough time for the endpoint to be up to curl - self.sleep_1s() + # sleep 1s to give enough time for the endpoint to be up to curl + self.sleep_1s() + yield + except Exception as e: + print(f"An error occurred while executing cmd: {cmd_full}. error: {e}") + raise e + finally: + self.run_command(f"docker stop {image}") + self.run_command(f"docker rm -f {image}") - return self.invoke_function(port) - @parameterized.expand([("x86_64", "8000"), ("arm64", "9000"), ("", "9050")]) - def test_env_var_with_equal_sign(self, arch, port): - image, rie, image_name = self.tagged_name("envvarcheck", arch) - - cmd = f"docker run --name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.check_env_var_handler" + def test_env_var_with_equal_sign(self): + image, rie, image_name = self.tagged_name("envvarcheck") + params = f"--name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.check_env_var_handler" - r = self.create_container_and_invoke_function(cmd, port) + with self.create_container(params, image): + r = self.invoke_function() - self.assertEqual(b'"4=4"', r.content) + self.assertEqual(b'"4=4"', r.content) - @parameterized.expand([("x86_64", "8001"), ("arm64", "9001"), ("", "9051")]) - def test_two_invokes(self, arch, port): - image, rie, image_name = self.tagged_name("twoinvokes", arch) - cmd = f"docker run --name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.success_handler" + def test_two_invokes(self): + image, rie, image_name = self.tagged_name("twoinvokes") - r = self.create_container_and_invoke_function(cmd, port) - - self.assertEqual(b'"My lambda ran succesfully"', r.content) + params = f"--name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.success_handler" - # Make sure we can invoke the function twice - r = self.invoke_function(port) + with self.create_container(params, image): + r = self.invoke_function() - self.assertEqual(b'"My lambda ran succesfully"', r.content) + self.assertEqual(b'"My lambda ran succesfully"', r.content) + + # Make sure we can invoke the function twice + r = self.invoke_function() + + self.assertEqual(b'"My lambda ran succesfully"', r.content) + - @parameterized.expand([("x86_64", "8002"), ("arm64", "9002"), ("", "9052")]) - def test_lambda_function_arn_exists(self, arch, port): - image, rie, image_name = self.tagged_name("arnexists", arch) + def test_lambda_function_arn_exists(self): + image, rie, image_name = self.tagged_name("arnexists") - cmd = f"docker run --name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.assert_lambda_arn_in_context" + params = f"--name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.assert_lambda_arn_in_context" - r = self.create_container_and_invoke_function(cmd, port) + with self.create_container(params, image): + r = self.invoke_function() - self.assertEqual(b'"My lambda ran succesfully"', r.content) + self.assertEqual(b'"My lambda ran succesfully"', r.content) + - @parameterized.expand([("x86_64", "8003"), ("arm64", "9003"), ("", "9053")]) - def test_lambda_function_arn_exists_with_defining_custom_name(self, arch, port): - image, rie, image_name = self.tagged_name("customname", arch) + def test_lambda_function_arn_exists_with_defining_custom_name(self): + image, rie, image_name = self.tagged_name("customname") - cmd = f"docker run --name {image} --env AWS_LAMBDA_FUNCTION_NAME=MyCoolName -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.assert_lambda_arn_in_context" + params = f"--name {image} --env AWS_LAMBDA_FUNCTION_NAME=MyCoolName -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.assert_lambda_arn_in_context" - r = self.create_container_and_invoke_function(cmd, port) + with self.create_container(params, image): + r = self.invoke_function() - self.assertEqual(b'"My lambda ran succesfully"', r.content) + self.assertEqual(b'"My lambda ran succesfully"', r.content) - @parameterized.expand([("x86_64", "8004"), ("arm64", "9004"), ("", "9054")]) - def test_timeout_invoke(self, arch, port): - image, rie, image_name = self.tagged_name("timeout", arch) - cmd = f"docker run --name {image} -d --env AWS_LAMBDA_FUNCTION_TIMEOUT=1 -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.sleep_handler" + def test_timeout_invoke(self): + image, rie, image_name = self.tagged_name("timeout") - r = self.create_container_and_invoke_function(cmd, port) + params = f"--name {image} -d --env AWS_LAMBDA_FUNCTION_TIMEOUT=1 -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.sleep_handler" + + with self.create_container(params, image): + r = self.invoke_function() - self.assertEqual(b"Task timed out after 1.00 seconds", r.content) + self.assertEqual(b"Task timed out after 1.00 seconds", r.content) - @parameterized.expand([("x86_64", "8005"), ("arm64", "9005"), ("", "9055")]) - def test_exception_returned(self, arch, port): - image, rie, image_name = self.tagged_name("exception", arch) - cmd = f"docker run --name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.exception_handler" + def test_exception_returned(self): + image, rie, image_name = self.tagged_name("exception") + + params = f"--name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.exception_handler" + + with self.create_container(params, image): + r = self.invoke_function() + + # Except the 3 fields assrted below, there's another field `request_id` included start from python3.12 runtime. + # We should ignore asserting the field `request_id` for it is in a UUID like format and changes everytime + result = r.json() + self.assertEqual(result["errorMessage"], "Raising an exception") + self.assertEqual(result["errorType"], "Exception") + self.assertEqual(result["stackTrace"], [" File \"/var/task/main.py\", line 13, in exception_handler\n raise Exception(\"Raising an exception\")\n"]) - r = self.create_container_and_invoke_function(cmd, port) - - self.assertEqual( - b'{"errorMessage": "Raising an exception", "errorType": "Exception", "stackTrace": [" File \\"/var/task/main.py\\", line 13, in exception_handler\\n raise Exception(\\"Raising an exception\\")\\n"]}', - r.content, - ) - @parameterized.expand([("x86_64", "8006"), ("arm64", "9006"), ("", "9056")]) - def test_context_get_remaining_time_in_three_seconds(self, arch, port): - image, rie, image_name = self.tagged_name("remaining_time_in_three_seconds", arch) + def test_context_get_remaining_time_in_three_seconds(self): + image, rie, image_name = self.tagged_name("remaining_time_in_three_seconds") - cmd = f"docker run --name {image} -d --env AWS_LAMBDA_FUNCTION_TIMEOUT=3 -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.check_remaining_time_handler" + params = f"--name {image} -d --env AWS_LAMBDA_FUNCTION_TIMEOUT=3 -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.check_remaining_time_handler" - r = self.create_container_and_invoke_function(cmd, port) + with self.create_container(params, image): + r = self.invoke_function() - # Execution time is not decided, 1.0s ~ 3.0s is a good estimation - self.assertLess(int(r.content), 3000) - self.assertGreater(int(r.content), 1000) + # Execution time is not decided, 1.0s ~ 3.0s is a good estimation + self.assertLess(int(r.content), 3000) + self.assertGreater(int(r.content), 1000) - @parameterized.expand([("x86_64", "8007"), ("arm64", "9007"), ("", "9057")]) - def test_context_get_remaining_time_in_ten_seconds(self, arch, port): - image, rie, image_name = self.tagged_name("remaining_time_in_ten_seconds", arch) - cmd = f"docker run --name {image} -d --env AWS_LAMBDA_FUNCTION_TIMEOUT=10 -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.check_remaining_time_handler" + def test_context_get_remaining_time_in_ten_seconds(self): + image, rie, image_name = self.tagged_name("remaining_time_in_ten_seconds") - r = self.create_container_and_invoke_function(cmd, port) + params = f"--name {image} -d --env AWS_LAMBDA_FUNCTION_TIMEOUT=10 -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.check_remaining_time_handler" + + with self.create_container(params, image): + r = self.invoke_function() - # Execution time is not decided, 8.0s ~ 10.0s is a good estimation - self.assertLess(int(r.content), 10000) - self.assertGreater(int(r.content), 8000) + # Execution time is not decided, 8.0s ~ 10.0s is a good estimation + self.assertLess(int(r.content), 10000) + self.assertGreater(int(r.content), 8000) + + + def test_context_get_remaining_time_in_default_deadline(self): + image, rie, image_name = self.tagged_name("remaining_time_in_default_deadline") - @parameterized.expand([("x86_64", "8008"), ("arm64", "9008"), ("", "9058")]) - def test_context_get_remaining_time_in_default_deadline(self, arch, port): - image, rie, image_name = self.tagged_name("remaining_time_in_default_deadline", arch) + params = f"--name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.check_remaining_time_handler" - cmd = f"docker run --name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.check_remaining_time_handler" + with self.create_container(params, image): + r = self.invoke_function() - r = self.create_container_and_invoke_function(cmd, port) + # Executation time is not decided, 298.0s ~ 300.0s is a good estimation + self.assertLess(int(r.content), 300000) + self.assertGreater(int(r.content), 298000) - # Executation time is not decided, 298.0s ~ 300.0s is a good estimation - self.assertLess(int(r.content), 300000) - self.assertGreater(int(r.content), 298000) - @parameterized.expand([("x86_64", "8009"), ("arm64", "9009"), ("", "9059")]) - def test_invoke_with_pre_runtime_api_runtime(self, arch, port): - image, rie, image_name = self.tagged_name("pre-runtime-api", arch) + def test_invoke_with_pre_runtime_api_runtime(self): + image, rie, image_name = self.tagged_name("pre-runtime-api") - cmd = f"docker run --name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.success_handler" + params = f"--name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.success_handler" - r = self.create_container_and_invoke_function(cmd, port) + with self.create_container(params, image): + r = self.invoke_function() - self.assertEqual(b'"My lambda ran succesfully"', r.content) + self.assertEqual(b'"My lambda ran succesfully"', r.content) - @parameterized.expand([("x86_64", "8010"), ("arm64", "9010"), ("", "9060")]) - def test_function_name_is_overriden(self, arch, port): - image, rie, image_name = self.tagged_name("assert-overwritten", arch) - cmd = f"docker run --name {image} -d --env AWS_LAMBDA_FUNCTION_NAME=MyCoolName -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.assert_env_var_is_overwritten" + def test_function_name_is_overriden(self): + image, rie, image_name = self.tagged_name("assert-overwritten") - r = self.create_container_and_invoke_function(cmd, port) + params = f"--name {image} -d --env AWS_LAMBDA_FUNCTION_NAME=MyCoolName -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.assert_env_var_is_overwritten" + + with self.create_container(params, image): + r = self.invoke_function() - self.assertEqual(b'"My lambda ran succesfully"', r.content) + self.assertEqual(b'"My lambda ran succesfully"', r.content) + - @parameterized.expand([("x86_64", "8011"), ("arm64", "9011"), ("", "9061")]) - def test_port_override(self, arch, port): - image, rie, image_name = self.tagged_name("port_override", arch) + def test_port_override(self): + image, rie, image_name = self.tagged_name("port_override") # Use port 8081 inside the container instead of 8080 - cmd = f"docker run --name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {port}:8081 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.success_handler --runtime-interface-emulator-address 0.0.0.0:8081" + params = f"--name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8081 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.success_handler --runtime-interface-emulator-address 0.0.0.0:8081" - r = self.create_container_and_invoke_function(cmd, port) + with self.create_container(params, image): + r = self.invoke_function() - self.assertEqual(b'"My lambda ran succesfully"', r.content) + self.assertEqual(b'"My lambda ran succesfully"', r.content) + if __name__ == "__main__": diff --git a/test/integration/testdata/Dockerfile-allinone b/test/integration/testdata/Dockerfile-allinone index b804e5c..1d28406 100644 --- a/test/integration/testdata/Dockerfile-allinone +++ b/test/integration/testdata/Dockerfile-allinone @@ -1,4 +1,5 @@ -FROM public.ecr.aws/lambda/python:3.8 +ARG IMAGE_ARCH +FROM public.ecr.aws/lambda/python:3.12-$IMAGE_ARCH WORKDIR /var/task COPY ./ ./ From ba56ed44080dbd27872fb6bbebc9b9197307f163 Mon Sep 17 00:00:00 2001 From: Marco Cieno Date: Tue, 30 Apr 2024 00:13:20 +0200 Subject: [PATCH 31/41] feat: allow user-defined client context (#110) --- cmd/aws-lambda-rie/handlers.go | 9 ++++++ .../local_lambda/test_end_to_end.py | 30 ++++++++++++++++--- test/integration/testdata/main.py | 4 +++ 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/cmd/aws-lambda-rie/handlers.go b/cmd/aws-lambda-rie/handlers.go index 42032cf..2cca12d 100644 --- a/cmd/aws-lambda-rie/handlers.go +++ b/cmd/aws-lambda-rie/handlers.go @@ -5,6 +5,7 @@ package main import ( "bytes" + "encoding/base64" "fmt" "io/ioutil" "math" @@ -81,6 +82,13 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox, bs i return } + rawClientContext, err := base64.StdEncoding.DecodeString(r.Header.Get("X-Amz-Client-Context")) + if err != nil { + log.Errorf("Failed to decode X-Amz-Client-Context: %s", err) + w.WriteHeader(500) + return + } + initDuration := "" inv := GetenvWithDefault("AWS_LAMBDA_FUNCTION_TIMEOUT", "300") timeoutDuration, _ := time.ParseDuration(inv + "s") @@ -114,6 +122,7 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox, bs i TraceID: r.Header.Get("X-Amzn-Trace-Id"), LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), Payload: bytes.NewReader(bodyBytes), + ClientContext: string(rawClientContext), } fmt.Println("START RequestId: " + invokePayload.ID + " Version: " + functionVersion) diff --git a/test/integration/local_lambda/test_end_to_end.py b/test/integration/local_lambda/test_end_to_end.py index 7c5486f..8e34b77 100644 --- a/test/integration/local_lambda/test_end_to_end.py +++ b/test/integration/local_lambda/test_end_to_end.py @@ -4,6 +4,8 @@ from subprocess import Popen, PIPE from unittest import TestCase, main from pathlib import Path +import base64 +import json import time import os import requests @@ -62,12 +64,14 @@ def run_command(self, cmd): def sleep_1s(self): time.sleep(SLEEP_TIME) - - def invoke_function(self): + + def invoke_function(self, json={}, headers={}): return requests.post( - f"http://localhost:{self.PORT}/2015-03-31/functions/function/invocations", json={} + f"http://localhost:{self.PORT}/2015-03-31/functions/function/invocations", + json=json, + headers=headers, ) - + @contextmanager def create_container(self, param, image): try: @@ -234,6 +238,24 @@ def test_port_override(self): self.assertEqual(b'"My lambda ran succesfully"', r.content) + def test_custom_client_context(self): + image, rie, image_name = self.tagged_name("custom_client_context") + + params = f"--name {image} -d -v {self.path_to_binary}:/local-lambda-runtime-server -p {self.PORT}:8080 --entrypoint /local-lambda-runtime-server/{rie} {image_name} {DEFAULT_1P_ENTRYPOINT} main.custom_client_context_handler" + + with self.create_container(params, image): + r = self.invoke_function(headers={ + "X-Amz-Client-Context": base64.b64encode(json.dumps({ + "custom": { + "foo": "bar", + "baz": 123, + } + }).encode('utf8')).decode('utf8'), + }) + content = json.loads(r.content) + self.assertEqual("bar", content["foo"]) + self.assertEqual(123, content["baz"]) + if __name__ == "__main__": main() diff --git a/test/integration/testdata/main.py b/test/integration/testdata/main.py index b6b527d..9757be8 100644 --- a/test/integration/testdata/main.py +++ b/test/integration/testdata/main.py @@ -41,3 +41,7 @@ def check_remaining_time_handler(event, context): # Wait 1s to see if the remaining time changes time.sleep(1) return context.get_remaining_time_in_millis() + + +def custom_client_context_handler(event, context): + return context.client_context.custom From d37e08c13600eae4deb1329603f01a78363f360e Mon Sep 17 00:00:00 2001 From: seshubaws <116689586+seshubaws@users.noreply.github.com> Date: Tue, 14 May 2024 11:45:41 -0700 Subject: [PATCH 32/41] Added workflow for automated releases (#121) * Added release workflow --- .github/workflows/release.yml | 41 +++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 .github/workflows/release.yml diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..32e878d --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,41 @@ +name: Release + +on: + workflow_dispatch: + inputs: + releaseVersion: + description: "Version to use for the release." + required: true + default: "X.Y" + releaseBody: + description: "Information about the release" + required: true + default: "New release" +jobs: + Release: + environment: Release + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + ref: main + - name: Set up python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + - name: Build + run: make compile-with-docker-all + - name: Run Integ Tests + run: | + make tests-with-docker + make integ-tests + - name: Release + uses: softprops/action-gh-release@v2 + with: + name: Release ${{ github.event.inputs.releaseVersion }} + tag_name: v${{ github.event.inputs.releaseVersion }} + body: ${{ github.event.inputs.releaseBody }} + files: | + bin/aws-lambda-rie + bin/aws-lambda-rie-arm64 + bin/aws-lambda-rie-x86_64 From 9e6041b151436647af596aec7b48c88f15ac360a Mon Sep 17 00:00:00 2001 From: Renato Valenzuela <37676028+valerena@users.noreply.github.com> Date: Mon, 3 Jun 2024 16:23:06 -0700 Subject: [PATCH 33/41] feat: Add automatic vulnerabilities check (#123) * Add automatic vulnerabilities check --- .github/workflows/check-binaries.yml | 78 ++++++++++++++++++++++++++++ Makefile | 5 +- 2 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/check-binaries.yml diff --git a/.github/workflows/check-binaries.yml b/.github/workflows/check-binaries.yml new file mode 100644 index 0000000..bd41ece --- /dev/null +++ b/.github/workflows/check-binaries.yml @@ -0,0 +1,78 @@ +name: Check binaries + +on: + workflow_dispatch: + schedule: + - cron: "0 16 * * 1-5" # min h d Mo DoW / 9am PST M-F + +jobs: + check-for-vulnerabilities: + runs-on: ubuntu-latest + outputs: + report_contents: ${{ steps.save-output.outputs.report_contents }} + steps: + - name: Setup python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: main + - name: Download latest release + uses: robinraju/release-downloader@v1.10 + with: + latest: true + fileName: 'aws-lambda-rie*' + out-file-path: "bin" + - name: Run check for vulnerabilities + id: check-binaries + run: | + make check-binaries + - if: always() && failure() # `always()` to run even if the previous step failed. Failure means that there are vulnerabilities + name: Save content of the vulnerabilities report as GitHub output + id: save-output + run: | + report_csv="$(ls -tr output.cve-bin-*.csv 2>/dev/null | tail -n1)" # last file generated + echo "Vulnerabilities stored in $report_csv" + final_report="${report_csv}.txt" + awk -F',' '{n=split($10, path, "/"); print $2,$3,$4,$5,path[n]}' "$report_csv" | column -t > "$final_report" # make the CSV nicer + echo "report_contents<> "$GITHUB_OUTPUT" + cat "$final_report" >> "$GITHUB_OUTPUT" + echo "EOF" >> "$GITHUB_OUTPUT" + - if: always() && steps.check-binaries.outcome == 'failure' + name: Build new binaries and check vulnerabilities again + id: check-new-version + run: | + mkdir ./bin2 + mv ./bin/* ./bin2 + make compile-with-docker-all + latest_version=$(strings bin/aws-lambda-rie* | grep '^go1\.' | sort | uniq) + echo "latest_version=$latest_version" >> "$GITHUB_OUTPUT" + make check-binaries + - if: always() && steps.check-binaries.outcome == 'failure' + name: Save outputs for the check with the latest build + id: save-new-version + run: | + if [ "${{ steps.check-new-version.outcome }}" == "failure" ]; then + fixed="No" + else + fixed="Yes" + fi + echo "fixed=$fixed" >> "$GITHUB_OUTPUT" + - if: always() && steps.check-binaries.outcome == 'failure' + name: Create GitHub Issue indicating vulnerabilities + id: create-issue + uses: dacbd/create-issue-action@main + with: + token: ${{ github.token }} + title: | + CVEs found in latest RIE release + body: | + ### CVEs found in latest RIE release + ``` + ${{ steps.save-output.outputs.report_contents }} + ``` + + #### Are these resolved by building with the latest patch version of Go (${{ steps.check-new-version.outputs.latest_version }})?: + > **${{ steps.save-new-version.outputs.fixed }}** diff --git a/Makefile b/Makefile index f7a714e..6b66e79 100644 --- a/Makefile +++ b/Makefile @@ -70,4 +70,7 @@ integ-tests-with-docker-old: make ARCH=old compile-with-docker make prep-python make TEST_ARCH="" TEST_PORT=9052 exec-python-e2e-test - \ No newline at end of file + +check-binaries: prep-python + .venv/bin/pip install cve-bin-tool + .venv/bin/python -m cve_bin_tool.cli bin/ -r go -d REDHAT,OSV,GAD,CURL --no-0-cve-report -f csv From 71388dd788b7a5519262391ce73fe6548dbaf86e Mon Sep 17 00:00:00 2001 From: Renato Valenzuela <37676028+valerena@users.noreply.github.com> Date: Wed, 5 Jun 2024 11:51:08 -0700 Subject: [PATCH 34/41] fix: Vulnerability checks: create issue only when checked was done (#125) --- .github/workflows/check-binaries.yml | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/.github/workflows/check-binaries.yml b/.github/workflows/check-binaries.yml index bd41ece..75fa28f 100644 --- a/.github/workflows/check-binaries.yml +++ b/.github/workflows/check-binaries.yml @@ -34,13 +34,17 @@ jobs: id: save-output run: | report_csv="$(ls -tr output.cve-bin-*.csv 2>/dev/null | tail -n1)" # last file generated - echo "Vulnerabilities stored in $report_csv" + if [ -z "$report_csv" ]; then + echo "No file with vulnerabilities. Probably a failure in previous step." + else + echo "Vulnerabilities stored in $report_csv" + fi final_report="${report_csv}.txt" awk -F',' '{n=split($10, path, "/"); print $2,$3,$4,$5,path[n]}' "$report_csv" | column -t > "$final_report" # make the CSV nicer echo "report_contents<> "$GITHUB_OUTPUT" cat "$final_report" >> "$GITHUB_OUTPUT" echo "EOF" >> "$GITHUB_OUTPUT" - - if: always() && steps.check-binaries.outcome == 'failure' + - if: always() && steps.save-output.outputs.report_contents name: Build new binaries and check vulnerabilities again id: check-new-version run: | @@ -50,7 +54,7 @@ jobs: latest_version=$(strings bin/aws-lambda-rie* | grep '^go1\.' | sort | uniq) echo "latest_version=$latest_version" >> "$GITHUB_OUTPUT" make check-binaries - - if: always() && steps.check-binaries.outcome == 'failure' + - if: always() && steps.save-output.outputs.report_contents name: Save outputs for the check with the latest build id: save-new-version run: | @@ -60,7 +64,7 @@ jobs: fixed="Yes" fi echo "fixed=$fixed" >> "$GITHUB_OUTPUT" - - if: always() && steps.check-binaries.outcome == 'failure' + - if: always() && steps.save-output.outputs.report_contents name: Create GitHub Issue indicating vulnerabilities id: create-issue uses: dacbd/create-issue-action@main From bba504e47ac879a6340d73a5b1b388a0968b8a5a Mon Sep 17 00:00:00 2001 From: Daniel Fangl Date: Tue, 20 Aug 2024 12:09:30 +0200 Subject: [PATCH 35/41] Update go version (#35) --- .github/workflows/build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 3ccb66f..4603319 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -17,7 +17,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.20' + go-version: '1.22' - name: Build env: @@ -29,7 +29,7 @@ jobs: name: aws-lambda-rie path: bin/* - name: Release binaries - uses: softprops/action-gh-release@v1 + uses: softprops/action-gh-release@v2 if: startsWith(github.ref, 'refs/tags/') with: files: bin/* From 781cd9a296b10ed44a0223b34092703ddc4a36b6 Mon Sep 17 00:00:00 2001 From: Renato Valenzuela <37676028+valerena@users.noreply.github.com> Date: Wed, 11 Dec 2024 17:16:03 -0800 Subject: [PATCH 36/41] test: Add delay for time-related arm64 tests (#138) The tests on GitHub run on x86 instances (because arm64 instances don't have Docker installed) so when running the arm64 tests, invokes take longer because of the cross-architecture emulation. This caused that some tests that check remaining time in the function were not landing on the correct time range. It's unknown why this delay started manifesting more consistently, we might want to find a better solution in the future. --- .../local_lambda/test_end_to_end.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/test/integration/local_lambda/test_end_to_end.py b/test/integration/local_lambda/test_end_to_end.py index 8e34b77..d564bb1 100644 --- a/test/integration/local_lambda/test_end_to_end.py +++ b/test/integration/local_lambda/test_end_to_end.py @@ -173,9 +173,8 @@ def test_context_get_remaining_time_in_three_seconds(self): with self.create_container(params, image): r = self.invoke_function() - # Execution time is not decided, 1.0s ~ 3.0s is a good estimation - self.assertLess(int(r.content), 3000) - self.assertGreater(int(r.content), 1000) + # Execution time is not decided, but it should be around 2.0s + self.assertAround(int(r.content), 2000) def test_context_get_remaining_time_in_ten_seconds(self): @@ -186,9 +185,8 @@ def test_context_get_remaining_time_in_ten_seconds(self): with self.create_container(params, image): r = self.invoke_function() - # Execution time is not decided, 8.0s ~ 10.0s is a good estimation - self.assertLess(int(r.content), 10000) - self.assertGreater(int(r.content), 8000) + # Execution time is not decided, but it should be around 9.0s + self.assertAround(int(r.content), 9000) def test_context_get_remaining_time_in_default_deadline(self): @@ -199,9 +197,8 @@ def test_context_get_remaining_time_in_default_deadline(self): with self.create_container(params, image): r = self.invoke_function() - # Executation time is not decided, 298.0s ~ 300.0s is a good estimation - self.assertLess(int(r.content), 300000) - self.assertGreater(int(r.content), 298000) + # Execution time is not decided, but it should be around 299.0s + self.assertAround(int(r.content), 299000) def test_invoke_with_pre_runtime_api_runtime(self): @@ -256,6 +253,13 @@ def test_custom_client_context(self): self.assertEqual("bar", content["foo"]) self.assertEqual(123, content["baz"]) + def assertAround(self, number, target): + # Emulating arm64 on x86 causes the invoke to take longer + delay_arm64 = 500 + actual_target = target if self.ARCH != 'arm64' else target - delay_arm64 + + self.assertLess(number, actual_target + 1000) + self.assertGreater(number, actual_target - 1000) if __name__ == "__main__": main() From 0b2b5bebe96aa1e94d9bdb1138f0eff61a30d5fb Mon Sep 17 00:00:00 2001 From: Daniel Fangl Date: Tue, 17 Dec 2024 16:58:45 +0100 Subject: [PATCH 37/41] Upgrade xray dependency and its indirect dependencies (#36) --- go.mod | 8 ++++---- go.sum | 41 +++++++++++++++++++++++++++++++---------- 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/go.mod b/go.mod index fe41c9a..dd00025 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,8 @@ go 1.22 require ( github.com/aws/aws-lambda-go v1.46.0 - github.com/aws/aws-sdk-go v1.44.62 - github.com/aws/aws-xray-daemon v0.0.0-20230202010956-acaf06e9a638 + github.com/aws/aws-sdk-go v1.44.298 + github.com/aws/aws-xray-daemon v0.0.0-20240827235329-2e2596c6bb93 github.com/cihub/seelog v0.0.0-20170130134532-f561c5e57575 github.com/fsnotify/fsnotify v1.6.0 github.com/go-chi/chi v1.5.5 @@ -15,7 +15,7 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.9.0 golang.org/x/sync v0.6.0 - golang.org/x/sys v0.14.0 + golang.org/x/sys v0.18.0 ) require ( @@ -25,7 +25,7 @@ require ( github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect - golang.org/x/net v0.18.0 // indirect + golang.org/x/net v0.23.0 // indirect golang.org/x/text v0.14.0 // indirect gopkg.in/yaml.v2 v2.2.8 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 5b5f27c..7c9340f 100644 --- a/go.sum +++ b/go.sum @@ -2,10 +2,10 @@ github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d h1:G0m3OIz70MZUW github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= github.com/aws/aws-lambda-go v1.46.0 h1:UWVnvh2h2gecOlFhHQfIPQcD8pL/f7pVCutmFl+oXU8= github.com/aws/aws-lambda-go v1.46.0/go.mod h1:dpMpZgvWx5vuQJfBt0zqBha60q7Dd7RfgJv23DymV8A= -github.com/aws/aws-sdk-go v1.44.62 h1:N8qOPnBhl2ZCIFiqyB640Xt5CeX9D8CEVhG/Vj7jGJU= -github.com/aws/aws-sdk-go v1.44.62/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo= -github.com/aws/aws-xray-daemon v0.0.0-20230202010956-acaf06e9a638 h1:G0C87W0m2uyh3uHV24Q60JJx+AyJ3//gJjalvSizXhc= -github.com/aws/aws-xray-daemon v0.0.0-20230202010956-acaf06e9a638/go.mod h1:glwf7zqf0NzGozJscRs0/xC+CpTU4DyMN4V9eXxD2Co= +github.com/aws/aws-sdk-go v1.44.298 h1:5qTxdubgV7PptZJmp/2qDwD2JL187ePL7VOxsSh1i3g= +github.com/aws/aws-sdk-go v1.44.298/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= +github.com/aws/aws-xray-daemon v0.0.0-20240827235329-2e2596c6bb93 h1:1O9QBEGf/IBjdybrzg4QtY+zuikfrLcwUjb0Mq/Hk+U= +github.com/aws/aws-xray-daemon v0.0.0-20240827235329-2e2596c6bb93/go.mod h1:OtEUQKJwoYfOGGQOQ+d/mayEHAC00j6vb51HqpnxiV0= github.com/cihub/seelog v0.0.0-20170130134532-f561c5e57575 h1:kHaBemcxl8o/pQ5VM1c8PVE1PubbNx3mjUr09OqWGCs= github.com/cihub/seelog v0.0.0-20170130134532-f561c5e57575/go.mod h1:9d6lWj8KzO/fd/NrVaLscBKmPigpZpn5YawRPw+e3Yo= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -38,23 +38,44 @@ github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg= -golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= +golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= +golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= -golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= From cf26b4380dec54a0e5b85bc073e686131b74ca8d Mon Sep 17 00:00:00 2001 From: Daniel Fangl Date: Wed, 26 Feb 2025 10:35:20 +0100 Subject: [PATCH 38/41] Upgrade xray dependency and transitive dependencies (#37) --- go.mod | 10 +++++----- go.sum | 10 ++++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index dd00025..d8831ca 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.22 require ( github.com/aws/aws-lambda-go v1.46.0 github.com/aws/aws-sdk-go v1.44.298 - github.com/aws/aws-xray-daemon v0.0.0-20240827235329-2e2596c6bb93 + github.com/aws/aws-xray-daemon v0.0.0-20250212175715-5defe1b8d61b github.com/cihub/seelog v0.0.0-20170130134532-f561c5e57575 github.com/fsnotify/fsnotify v1.6.0 github.com/go-chi/chi v1.5.5 @@ -14,8 +14,8 @@ require ( github.com/shirou/gopsutil v2.19.10+incompatible github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.9.0 - golang.org/x/sync v0.6.0 - golang.org/x/sys v0.18.0 + golang.org/x/sync v0.10.0 + golang.org/x/sys v0.28.0 ) require ( @@ -25,8 +25,8 @@ require ( github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect - golang.org/x/net v0.23.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/net v0.33.0 // indirect + golang.org/x/text v0.21.0 // indirect gopkg.in/yaml.v2 v2.2.8 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 7c9340f..e49a343 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/aws/aws-sdk-go v1.44.298 h1:5qTxdubgV7PptZJmp/2qDwD2JL187ePL7VOxsSh1i github.com/aws/aws-sdk-go v1.44.298/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= github.com/aws/aws-xray-daemon v0.0.0-20240827235329-2e2596c6bb93 h1:1O9QBEGf/IBjdybrzg4QtY+zuikfrLcwUjb0Mq/Hk+U= github.com/aws/aws-xray-daemon v0.0.0-20240827235329-2e2596c6bb93/go.mod h1:OtEUQKJwoYfOGGQOQ+d/mayEHAC00j6vb51HqpnxiV0= +github.com/aws/aws-xray-daemon v0.0.0-20250212175715-5defe1b8d61b h1:hiV1SQDGCUECdYdKRvfBmIZnoCWggTDauTintGTkIFU= +github.com/aws/aws-xray-daemon v0.0.0-20250212175715-5defe1b8d61b/go.mod h1:1tKEa2CqVzCVcMS59532MHzZP5P0hF682qCGpR/Tl1k= github.com/cihub/seelog v0.0.0-20170130134532-f561c5e57575 h1:kHaBemcxl8o/pQ5VM1c8PVE1PubbNx3mjUr09OqWGCs= github.com/cihub/seelog v0.0.0-20170130134532-f561c5e57575/go.mod h1:9d6lWj8KzO/fd/NrVaLscBKmPigpZpn5YawRPw+e3Yo= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -48,10 +50,14 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -63,6 +69,8 @@ golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -72,6 +80,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= From 3a0772eae98d7653006b259e6be9c2a8e5b32d88 Mon Sep 17 00:00:00 2001 From: Roger Zhang Date: Wed, 23 Apr 2025 10:31:40 -0700 Subject: [PATCH 39/41] chore(deps): Update to Go 1.24 (#143) * update to 1.24 * fix changes --- Makefile | 2 +- go.mod | 2 +- lambda/rapi/handler/agentnext.go | 2 +- lambda/rapi/handler/agentregister.go | 2 +- lambda/rapi/handler/runtimelogs.go | 4 ++-- lambda/rapid/handlers.go | 4 ++-- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Makefile b/Makefile index 6b66e79..077cf31 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ DESTINATION_old:= bin/${BINARY_NAME} DESTINATION_x86_64 := bin/${BINARY_NAME}-x86_64 DESTINATION_arm64 := bin/${BINARY_NAME}-arm64 -run_in_docker = docker run --env GOPROXY=direct -v $(shell pwd):/LambdaRuntimeLocal -w /LambdaRuntimeLocal golang:1.22 $(1) +run_in_docker = docker run --env GOPROXY=direct -v $(shell pwd):/LambdaRuntimeLocal -w /LambdaRuntimeLocal golang:1.24 $(1) compile-with-docker-all: $(call run_in_docker, make compile-lambda-linux-all) diff --git a/go.mod b/go.mod index 4ee45d7..5519b8c 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module go.amzn.com -go 1.22 +go 1.24 require ( github.com/aws/aws-lambda-go v1.46.0 diff --git a/lambda/rapi/handler/agentnext.go b/lambda/rapi/handler/agentnext.go index 7ce76f0..ffdd61d 100644 --- a/lambda/rapi/handler/agentnext.go +++ b/lambda/rapi/handler/agentnext.go @@ -48,7 +48,7 @@ func (h *agentNextHandler) ServeHTTP(writer http.ResponseWriter, request *http.R } } else { log.Warnf("Unknown agent %s tried to call /next", agentID.String()) - rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown extension"+agentID.String()) + rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown extension %s", agentID.String()) return } diff --git a/lambda/rapi/handler/agentregister.go b/lambda/rapi/handler/agentregister.go index 8da9e4c..867ad9d 100644 --- a/lambda/rapi/handler/agentregister.go +++ b/lambda/rapi/handler/agentregister.go @@ -77,7 +77,7 @@ func (h *agentRegisterHandler) ServeHTTP(writer http.ResponseWriter, request *ht registerRequest, err := parseRegister(request) if err != nil { - rendering.RenderForbiddenWithTypeMsg(writer, request, errInvalidRequestFormat, err.Error()) + rendering.RenderForbiddenWithTypeMsg(writer, request, errInvalidRequestFormat, "%s", err.Error()) return } diff --git a/lambda/rapi/handler/runtimelogs.go b/lambda/rapi/handler/runtimelogs.go index 6b8a67e..4fd534e 100644 --- a/lambda/rapi/handler/runtimelogs.go +++ b/lambda/rapi/handler/runtimelogs.go @@ -30,7 +30,7 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http log.Errorf("Agent Verification Error: %s", err) switch err := err.(type) { case *ErrAgentIdentifierUnknown: - rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown extension "+err.agentID.String()) + rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown extension %s", err.agentID.String()) h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeClientErr, 1) default: rendering.RenderInternalServerError(writer, request) @@ -55,7 +55,7 @@ func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http switch err { case telemetry.ErrTelemetryServiceOff: rendering.RenderForbiddenWithTypeMsg(writer, request, - h.telemetrySubscription.GetServiceClosedErrorType(), h.telemetrySubscription.GetServiceClosedErrorMessage()) + h.telemetrySubscription.GetServiceClosedErrorType(), "%s", h.telemetrySubscription.GetServiceClosedErrorMessage()) h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeClientErr, 1) default: rendering.RenderInternalServerError(writer, request) diff --git a/lambda/rapid/handlers.go b/lambda/rapid/handlers.go index f379c4c..2e759e9 100644 --- a/lambda/rapid/handlers.go +++ b/lambda/rapid/handlers.go @@ -243,7 +243,7 @@ func (c *rapidContext) watchEvents(events <-chan supvmodel.Event) { if termination.Success() { err = fmt.Errorf("exit code 0") } else { - err = fmt.Errorf(termination.String()) + err = fmt.Errorf("%s", termination.String()) } appctx.StoreFirstFatalError(c.appCtx, fatalerror.AgentCrash) @@ -851,7 +851,7 @@ func handleRestore(execCtx *rapidContext, restore *interop.Restore) (interop.Res // check if there is any error stored in appctx to get the root cause error type // Runtime.ExitError is an example to such a scenario if fatalErrorFound { - err = fmt.Errorf(string(fatalErrorType)) + err = fmt.Errorf("%s", string(fatalErrorType)) } if err != nil { From 10daeb8b4fa00f1a68bf9cbde4bd508ec0601589 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Batuhan=20Apayd=C4=B1n?= Date: Thu, 24 Apr 2025 17:26:22 +0300 Subject: [PATCH 40/41] fix CVE-2025-22872, update to go 1.23 (#38) --- go.mod | 12 +++++++----- go.sum | 26 ++++++++------------------ 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/go.mod b/go.mod index d8831ca..200cb41 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module go.amzn.com -go 1.22 +go 1.23.0 + +toolchain go1.24.1 require ( github.com/aws/aws-lambda-go v1.46.0 @@ -14,8 +16,8 @@ require ( github.com/shirou/gopsutil v2.19.10+incompatible github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.9.0 - golang.org/x/sync v0.10.0 - golang.org/x/sys v0.28.0 + golang.org/x/sync v0.12.0 + golang.org/x/sys v0.31.0 ) require ( @@ -25,8 +27,8 @@ require ( github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect - golang.org/x/net v0.33.0 // indirect - golang.org/x/text v0.21.0 // indirect + golang.org/x/net v0.38.0 // indirect + golang.org/x/text v0.23.0 // indirect gopkg.in/yaml.v2 v2.2.8 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index e49a343..3f2f234 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,6 @@ github.com/aws/aws-lambda-go v1.46.0 h1:UWVnvh2h2gecOlFhHQfIPQcD8pL/f7pVCutmFl+o github.com/aws/aws-lambda-go v1.46.0/go.mod h1:dpMpZgvWx5vuQJfBt0zqBha60q7Dd7RfgJv23DymV8A= github.com/aws/aws-sdk-go v1.44.298 h1:5qTxdubgV7PptZJmp/2qDwD2JL187ePL7VOxsSh1i3g= github.com/aws/aws-sdk-go v1.44.298/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= -github.com/aws/aws-xray-daemon v0.0.0-20240827235329-2e2596c6bb93 h1:1O9QBEGf/IBjdybrzg4QtY+zuikfrLcwUjb0Mq/Hk+U= -github.com/aws/aws-xray-daemon v0.0.0-20240827235329-2e2596c6bb93/go.mod h1:OtEUQKJwoYfOGGQOQ+d/mayEHAC00j6vb51HqpnxiV0= github.com/aws/aws-xray-daemon v0.0.0-20250212175715-5defe1b8d61b h1:hiV1SQDGCUECdYdKRvfBmIZnoCWggTDauTintGTkIFU= github.com/aws/aws-xray-daemon v0.0.0-20250212175715-5defe1b8d61b/go.mod h1:1tKEa2CqVzCVcMS59532MHzZP5P0hF682qCGpR/Tl1k= github.com/cihub/seelog v0.0.0-20170130134532-f561c5e57575 h1:kHaBemcxl8o/pQ5VM1c8PVE1PubbNx3mjUr09OqWGCs= @@ -48,16 +46,12 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= -golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= -golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= -golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= -golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= -golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -67,10 +61,8 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= -golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -78,10 +70,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= -golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= From 9be92e099e99c9973f2e1730d1af6d4ffd769bb2 Mon Sep 17 00:00:00 2001 From: Daniel Fangl Date: Thu, 24 Apr 2025 16:33:15 +0200 Subject: [PATCH 41/41] Pull upstream changes, upgrade go to 1.24, add codeowners (#39) --- .github/workflows/build.yml | 2 +- CODEOWNERS | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 CODEOWNERS diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4603319..919a2d0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -17,7 +17,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.24' - name: Build env: diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000..e912ea1 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1,2 @@ +# Add default reviewers for community PRs +* @dfangl @joe4dev @gregfurman @dominikschubert