From 8b61a4e63d7e4f545b2f2959a49e5c45bd388da8 Mon Sep 17 00:00:00 2001 From: Cory Jacobsen Date: Fri, 21 Jun 2024 11:07:51 -0600 Subject: [PATCH 01/10] bringing linter up to date (#95) --- .github/workflows/lint.yaml | 19 ++++ .github/workflows/test.yaml | 22 ++-- .golangci.yaml | 43 ++++---- csp.go | 3 +- csp_test.go | 5 +- cspbuilder/builder_test.go | 5 + cspbuilder/directive_builder_test.go | 2 + secure_test.go | 144 ++++++++++++++------------- 8 files changed, 137 insertions(+), 106 deletions(-) create mode 100644 .github/workflows/lint.yaml diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 0000000..effbbad --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,19 @@ +on: + push: + branches: + - master + - v1 + pull_request: + branches: + - "**" +name: Linter +jobs: + golangci: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + path: src/github.com/unrolled/secure + - uses: golangci/golangci-lint-action@v4 + with: + working-directory: src/github.com/unrolled/secure diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 00fe2d2..485d6c6 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -6,26 +6,26 @@ on: pull_request: branches: - "**" -name: tests +name: Tests jobs: tests: strategy: matrix: - go-version: [1.18.x, 1.19.x, 1.20.x, 1.21.x] + go-version: [1.18.x, 1.19.x, 1.20.x, 1.21.x, 1.22.x] os: [ubuntu-latest] runs-on: ${{ matrix.os }} steps: - - name: Install Go - uses: actions/setup-go@v2 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} - - name: Checkout code - uses: actions/checkout@v2 - - name: Test - run: make ci + - run: make ci golangci: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - name: golangci-lint - uses: golangci/golangci-lint-action@v2 + - uses: actions/checkout@v4 + with: + path: src/github.com/unrolled/secure + - uses: golangci/golangci-lint-action@v4 + with: + working-directory: src/github.com/7shifts/seven-deploy diff --git a/.golangci.yaml b/.golangci.yaml index 392b4b3..3eab3e6 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -1,36 +1,31 @@ run: - timeout: 5m + timeout: 10m + modules-download-mode: readonly + allow-parallel-runners: true linters: enable-all: true disable: - # Deprecated linters - - varcheck - - exhaustivestruct - - ifshort - - structcheck - - golint - - maligned - - interfacer - - nosnakecase - - deadcode - - scopelint - - rowserrcheck - - sqlclosecheck - - structcheck - - wastedassign - # Ignoring - - lll - - varnamelen - paralleltest - - testpackage - - goerr113 + - gochecknoglobals - exhaustruct - - nestif + - wrapcheck + - tagliatelle + - depguard + - ireturn - funlen - - goconst + - varnamelen + - gomnd + - execinquery + - copyloopvar + - intrange + - gocognit + - lll - cyclop - gocyclo - - gocognit + - testpackage + - err113 + - nestif - maintidx - contextcheck + - perfsprint diff --git a/csp.go b/csp.go index 948358f..4d516c2 100644 --- a/csp.go +++ b/csp.go @@ -12,7 +12,8 @@ type key int const cspNonceKey key = iota -// CSPNonce returns the nonce value associated with the present request. If no nonce has been generated it returns an empty string. +// CSPNonce returns the nonce value associated with the present request. +// If no nonce has been generated it returns an empty string. func CSPNonce(c context.Context) string { if val, ok := c.Value(cspNonceKey).(string); ok { return val diff --git a/csp_test.go b/csp_test.go index edf662f..f09042b 100644 --- a/csp_test.go +++ b/csp_test.go @@ -23,7 +23,10 @@ func TestCSPNonce(t *testing.T) { }{ {Options{ContentSecurityPolicy: csp}, []string{"Content-Security-Policy"}}, {Options{ContentSecurityPolicyReportOnly: csp}, []string{"Content-Security-Policy-Report-Only"}}, - {Options{ContentSecurityPolicy: csp, ContentSecurityPolicyReportOnly: csp}, []string{"Content-Security-Policy", "Content-Security-Policy-Report-Only"}}, + { + Options{ContentSecurityPolicy: csp, ContentSecurityPolicyReportOnly: csp}, + []string{"Content-Security-Policy", "Content-Security-Policy-Report-Only"}, + }, } for _, c := range cases { diff --git a/cspbuilder/builder_test.go b/cspbuilder/builder_test.go index fb4645a..35079d3 100644 --- a/cspbuilder/builder_test.go +++ b/cspbuilder/builder_test.go @@ -42,6 +42,7 @@ func TestContentSecurityPolicyBuilder_Build_SingleDirective(t *testing.T) { tt.directiveName: tt.directiveValues, }, } + got, err := builder.Build() if (err != nil) != tt.wantErr { t.Errorf("ContentSecurityPolicyBuilder.Build() error = %v, wantErr %v", err, tt.wantErr) @@ -92,6 +93,7 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { builder := &Builder{ Directives: tt.directives, } + got, err := builder.Build() if (err != nil) != tt.wantErr { t.Errorf("ContentSecurityPolicyBuilder.Build() error = %v, wantErr %v", err, tt.wantErr) @@ -101,6 +103,7 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { { startsWithDirective := false + for directive := range tt.directives { if strings.HasPrefix(got, directive) { startsWithDirective = true @@ -108,6 +111,7 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { break } } + if !startsWithDirective { t.Errorf("ContentSecurityPolicyBuilder.Build() = '%v', does not start with directive name", got) } @@ -116,6 +120,7 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { if strings.HasSuffix(got, " ") { t.Errorf("ContentSecurityPolicyBuilder.Build() = '%v', ends on whitespace", got) } + if strings.HasSuffix(got, ";") { t.Errorf("ContentSecurityPolicyBuilder.Build() = '%v', ends on semi-colon", got) } diff --git a/cspbuilder/directive_builder_test.go b/cspbuilder/directive_builder_test.go index 7d31133..8cab58d 100644 --- a/cspbuilder/directive_builder_test.go +++ b/cspbuilder/directive_builder_test.go @@ -79,6 +79,7 @@ func TestBuildDirectiveFrameAncestors(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { sb := &strings.Builder{} + err := buildDirectiveFrameAncestors(sb, tt.values) if tt.wantErr && err != nil { return @@ -222,6 +223,7 @@ func TestBuildDirectiveTrustedTypes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { sb := &strings.Builder{} + err := buildDirectiveTrustedTypes(sb, tt.values) if tt.wantErr && err != nil { return diff --git a/secure_test.go b/secure_test.go index 107ce71..23f0e44 100644 --- a/secure_test.go +++ b/secure_test.go @@ -10,8 +10,14 @@ import ( "testing" ) +const ( + httpSchema = "http" + httpsSchema = "https" + exampleHost = "www.example.com" +) + //nolint:gochecknoglobals -var myHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +var myHandler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("bar")) }) @@ -34,7 +40,7 @@ func TestNoAllowHosts(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -49,7 +55,7 @@ func TestGoodSingleAllowHosts(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -64,7 +70,7 @@ func TestBadSingleAllowHosts(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -79,7 +85,7 @@ func TestRegexSingleAllowHosts(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "sub.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -140,7 +146,7 @@ func TestGoodSingleAllowHostsProxyHeaders(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "example-internal" + req.Host = "example-internal-3" req.Header.Set("X-Proxy-Host", "www.example.com") s.Handler(myHandler).ServeHTTP(res, req) @@ -172,7 +178,7 @@ func TestGoodMultipleAllowHosts(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "sub.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -202,7 +208,7 @@ func TestAllowHostsInDevMode(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www3.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -214,7 +220,7 @@ func TestBadHostHandler(t *testing.T) { AllowedHosts: []string{"www.example.com", "sub.example.com"}, }) - badHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + badHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "BadHost", http.StatusInternalServerError) }) @@ -239,8 +245,8 @@ func TestSSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "https" + req.Host = exampleHost + req.URL.Scheme = httpsSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -255,8 +261,8 @@ func TestSSLInDevMode(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -270,8 +276,8 @@ func TestBasicSSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -287,8 +293,8 @@ func TestBasicSSLWithHost(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -300,7 +306,7 @@ func TestBasicSSLWithHostFunc(t *testing.T) { sslHostFunc := (func() SSLHostFunc { return func(host string) string { newHost := "" - if host == "www.example.com" { + if host == exampleHost { newHost = "secure.example.com:8443" } else if host == "www.example.org" { newHost = "secure.example.org" @@ -317,8 +323,8 @@ func TestBasicSSLWithHostFunc(t *testing.T) { // test www.example.com res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -329,7 +335,7 @@ func TestBasicSSLWithHostFunc(t *testing.T) { res = httptest.NewRecorder() req, _ = http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) req.Host = "www.example.org" - req.URL.Scheme = "http" + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -340,7 +346,7 @@ func TestBasicSSLWithHostFunc(t *testing.T) { res = httptest.NewRecorder() req, _ = http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) req.Host = "www.other.com" - req.URL.Scheme = "http" + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -355,8 +361,8 @@ func TestBadProxySSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -373,8 +379,8 @@ func TestCustomProxySSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -391,8 +397,8 @@ func TestCustomProxySSLInDevMode(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "http") s.Handler(myHandler).ServeHTTP(res, req) @@ -410,7 +416,7 @@ func TestCustomProxyAndHostProxyHeadersWithRedirect(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) req.Host = "example-internal" - req.URL.Scheme = "http" + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") req.Header.Add("X-Forwarded-Host", "www.example.com") @@ -429,8 +435,8 @@ func TestCustomProxyAndHostSSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -447,8 +453,8 @@ func TestCustomBadProxyAndHostSSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -467,8 +473,8 @@ func TestCustomBadProxyAndHostSSLWithTempRedirect(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -545,7 +551,7 @@ func TestStsHeaderWithSSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "http" + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -562,7 +568,7 @@ func TestStsHeaderWithSSLForRequestOnly(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "http" + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.HandlerForRequestOnly(myHandler).ServeHTTP(res, req) @@ -579,7 +585,7 @@ func TestStsHeaderInDevMode(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -595,7 +601,7 @@ func TestStsHeaderWithSubdomains(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -611,7 +617,7 @@ func TestStsHeaderWithSubdomainsForRequestOnly(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.HandlerForRequestOnly(myHandler).ServeHTTP(res, req) @@ -627,7 +633,7 @@ func TestStsHeaderWithPreload(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -643,7 +649,7 @@ func TestStsHeaderWithPreloadForRequest(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.HandlerForRequestOnly(myHandler).ServeHTTP(res, req) @@ -660,7 +666,7 @@ func TestStsHeaderWithSubdomainsWithPreload(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -677,7 +683,7 @@ func TestStsHeaderWithSubdomainsWithPreloadForRequestOnly(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.HandlerForRequestOnly(myHandler).ServeHTTP(res, req) @@ -1054,7 +1060,7 @@ func TestIsSSL(t *testing.T) { expect(t, s.isSSL(req), true) req, _ = http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema expect(t, s.isSSL(req), true) req, _ = http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) @@ -1071,8 +1077,8 @@ func TestSSLForceHostWithHTTPS(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "https" + req.Host = exampleHost + req.URL.Scheme = httpsSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -1088,8 +1094,8 @@ func TestSSLForceHostWithHTTP(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "http") s.Handler(myHandler).ServeHTTP(res, req) @@ -1107,8 +1113,8 @@ func TestSSLForceHostWithSSLRedirect(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "https" + req.Host = exampleHost + req.URL.Scheme = httpsSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -1126,8 +1132,8 @@ func TestSSLForceHostTemporaryRedirect(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "https" + req.Host = exampleHost + req.URL.Scheme = httpsSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -1157,8 +1163,8 @@ func TestModifyResponseHeadersWithSSLAndDifferentSSLHost(t *testing.T) { }) req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") res := &http.Response{} @@ -1180,8 +1186,8 @@ func TestModifyResponseHeadersWithSSLAndNoSSLHost(t *testing.T) { }) req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") res := &http.Response{} @@ -1204,8 +1210,8 @@ func TestModifyResponseHeadersWithSSLAndMatchingSSLHost(t *testing.T) { }) req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") res := &http.Response{} @@ -1228,8 +1234,8 @@ func TestModifyResponseHeadersWithSSLAndPortInLocationResponse(t *testing.T) { }) req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") res := &http.Response{} @@ -1252,8 +1258,8 @@ func TestModifyResponseHeadersWithSSLAndPathInLocationResponse(t *testing.T) { }) req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") res := &http.Response{} @@ -1280,7 +1286,7 @@ func TestCustomSecureContextKey(t *testing.T) { var actual *http.Request - hf := func(w http.ResponseWriter, r *http.Request) { + hf := func(_ http.ResponseWriter, r *http.Request) { actual = r } @@ -1306,7 +1312,7 @@ func TestMultipleCustomSecureContextKeys(t *testing.T) { var actual *http.Request - hf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hf := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { actual = r }) @@ -1322,7 +1328,7 @@ func TestMultipleCustomSecureContextKeys(t *testing.T) { func TestAllowRequestFuncTrue(t *testing.T) { s := New(Options{ - AllowRequestFunc: func(r *http.Request) bool { return true }, + AllowRequestFunc: func(_ *http.Request) bool { return true }, }) res := httptest.NewRecorder() @@ -1337,7 +1343,7 @@ func TestAllowRequestFuncTrue(t *testing.T) { func TestAllowRequestFuncFalse(t *testing.T) { s := New(Options{ - AllowRequestFunc: func(r *http.Request) bool { return false }, + AllowRequestFunc: func(_ *http.Request) bool { return false }, }) res := httptest.NewRecorder() @@ -1351,9 +1357,9 @@ func TestAllowRequestFuncFalse(t *testing.T) { func TestBadRequestHandler(t *testing.T) { s := New(Options{ - AllowRequestFunc: func(r *http.Request) bool { return false }, + AllowRequestFunc: func(_ *http.Request) bool { return false }, }) - badRequestFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + badRequestFunc := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "custom error", http.StatusConflict) }) s.SetBadRequestHandler(badRequestFunc) From 3d539f92663c30ffe95cf00b65ff66a4f0330ee4 Mon Sep 17 00:00:00 2001 From: Cory Jacobsen Date: Tue, 25 Jun 2024 10:38:15 -0600 Subject: [PATCH 02/10] fix: sort csp directives per w3 spec (#96) --- .github/workflows/test.yaml | 2 +- cspbuilder/builder.go | 12 +++++++++++- cspbuilder/builder_test.go | 7 +++++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 485d6c6..089c34e 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -28,4 +28,4 @@ jobs: path: src/github.com/unrolled/secure - uses: golangci/golangci-lint-action@v4 with: - working-directory: src/github.com/7shifts/seven-deploy + working-directory: src/github.com/unrolled/secure diff --git a/cspbuilder/builder.go b/cspbuilder/builder.go index 573fa2e..685ab5f 100644 --- a/cspbuilder/builder.go +++ b/cspbuilder/builder.go @@ -1,6 +1,7 @@ package cspbuilder import ( + "sort" "strings" ) @@ -62,7 +63,16 @@ func (builder *Builder) MustBuild() string { func (builder *Builder) Build() (string, error) { var sb strings.Builder - for directive := range builder.Directives { + // Pull the directive keys out. + directiveKeys := []string{} + for key := range builder.Directives { + directiveKeys = append(directiveKeys, key) + } + + // Sort the policies: https://www.w3.org/TR/CSP3/#framework-policy + sort.Strings(directiveKeys) + + for _, directive := range directiveKeys { if sb.Len() > 0 { sb.WriteString("; ") } diff --git a/cspbuilder/builder_test.go b/cspbuilder/builder_test.go index 35079d3..c0c4f16 100644 --- a/cspbuilder/builder_test.go +++ b/cspbuilder/builder_test.go @@ -63,6 +63,7 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { directives map[string]([]string) builder Builder wantParts []string + wantFull string wantErr bool }{ { @@ -86,6 +87,8 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { "trusted-types policy-1 policy-#=_/@.% 'allow-duplicates'", "upgrade-insecure-requests", }, + + wantFull: "default-src 'self' example.com *.example.com; frame-ancestors 'self' http://*.example.com; report-to group1; require-trusted-types-for 'script'; sandbox allow-scripts; trusted-types policy-1 policy-#=_/@.% 'allow-duplicates'; upgrade-insecure-requests", }, } for _, tt := range tests { @@ -101,6 +104,10 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { return } + if got != tt.wantFull { + t.Errorf("ContentSecurityPolicyBuilder.Build() full = %v, but wanted %v", got, tt.wantFull) + } + { startsWithDirective := false From 109136a2c6bbc379906e566213a43a59913b8be7 Mon Sep 17 00:00:00 2001 From: Bas Date: Fri, 27 Sep 2024 22:53:33 +0200 Subject: [PATCH 03/10] Fixed comment to mention the correct HTTP status code used (#98) The comment says a 302 will be used but the code actually uses Go core lib `http.StatusTemporaryRedirect` which is 307 (which i think is the right one to use because that instructs clients to re-send the same HTTP method as in the original request) --- secure.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/secure.go b/secure.go index 0efcc61..9663d2a 100644 --- a/secure.go +++ b/secure.go @@ -65,7 +65,7 @@ type Options struct { SSLRedirect bool // If SSLForceHost is true and SSLHost is set, requests will be forced to use SSLHost even the ones that are already using SSL. Default is false. SSLForceHost bool - // If SSLTemporaryRedirect is true, the a 302 will be used while redirecting. Default is false (301). + // If SSLTemporaryRedirect is true, then a 307 will be used while redirecting. Default is false (301). SSLTemporaryRedirect bool // If STSIncludeSubdomains is set to true, the `includeSubdomains` will be appended to the Strict-Transport-Security header. Default is false. STSIncludeSubdomains bool From 9bedcaa71aec0a1324b6b4070114788fa8b17b72 Mon Sep 17 00:00:00 2001 From: Cory Jacobsen Date: Fri, 27 Sep 2024 14:54:37 -0600 Subject: [PATCH 04/10] chore: fix readme comment (#99) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 4ec82d5..1add1c7 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ s := secure.New(secure.Options{ AllowRequestFunc: nil, // AllowRequestFunc is a custom function type that allows you to determine if the request should proceed or not based on your own custom logic. Default is nil. HostsProxyHeaders: []string{"X-Forwarded-Hosts"}, // HostsProxyHeaders is a set of header keys that may hold a proxied hostname value for the request. SSLRedirect: true, // If SSLRedirect is set to true, then only allow HTTPS requests. Default is false. - SSLTemporaryRedirect: false, // If SSLTemporaryRedirect is true, the a 302 will be used while redirecting. Default is false (301). + SSLTemporaryRedirect: false, // If SSLTemporaryRedirect is true, then a 307 will be used while redirecting. Default is false (301). SSLHost: "ssl.example.com", // SSLHost is the host name that is used to redirect HTTP requests to HTTPS. Default is "", which indicates to use the same host. SSLHostFunc: nil, // SSLHostFunc is a function pointer, the return value of the function is the host name that has same functionality as `SSHost`. Default is nil. If SSLHostFunc is nil, the `SSLHost` option will be used. SSLProxyHeaders: map[string]string{"X-Forwarded-Proto": "https"}, // SSLProxyHeaders is set of header keys with associated values that would indicate a valid HTTPS request. Useful when using Nginx: `map[string]string{"X-Forwarded-Proto": "https"}`. Default is blank map. From dbf54fcd6ecf3a695a663909ea47e2cbd1953f8a Mon Sep 17 00:00:00 2001 From: aerth <6263105+aerth@users.noreply.github.com> Date: Fri, 27 Sep 2024 20:56:11 +0000 Subject: [PATCH 05/10] skip func fields in marshalling Options (json,yaml,toml) (#97) Co-authored-by: aerth --- secure.go | 4 ++-- secure_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/secure.go b/secure.go index 9663d2a..f72842c 100644 --- a/secure.go +++ b/secure.go @@ -98,12 +98,12 @@ type Options struct { // AllowedHostsAreRegex determines, if the provided `AllowedHosts` slice contains valid regular expressions. If this flag is set to true, every request's host will be checked against these expressions. Default is false. AllowedHostsAreRegex bool // AllowRequestFunc is a custom function that allows you to determine if the request should proceed or not based on your own custom logic. Default is nil. - AllowRequestFunc AllowRequestFunc + AllowRequestFunc AllowRequestFunc `json:"-" yaml:"-" toml:"-"` // HostsProxyHeaders is a set of header keys that may hold a proxied hostname value for the request. HostsProxyHeaders []string // SSLHostFunc is a function pointer, the return value of the function is the host name that has same functionality as `SSHost`. Default is nil. // If SSLHostFunc is nil, the `SSLHost` option will be used. - SSLHostFunc *SSLHostFunc + SSLHostFunc *SSLHostFunc `json:"-" yaml:"-" toml:"-"` // SSLProxyHeaders is set of header keys with associated values that would indicate a valid https request. Useful when using Nginx: `map[string]string{"X-Forwarded-Proto": "https"}`. Default is blank map. SSLProxyHeaders map[string]string // STSSeconds is the max-age of the Strict-Transport-Security header. Default is 0, which would NOT include the header. diff --git a/secure_test.go b/secure_test.go index 23f0e44..1a47fe1 100644 --- a/secure_test.go +++ b/secure_test.go @@ -3,6 +3,7 @@ package secure import ( "context" "crypto/tls" + "encoding/json" "net/http" "net/http/httptest" "reflect" @@ -1374,6 +1375,41 @@ func TestBadRequestHandler(t *testing.T) { expect(t, strings.TrimSpace(res.Body.String()), `custom error`) } +func TestMarshal(t *testing.T) { + // func cant be marshalled + var t1 = struct { + A string + F func() + }{} + _, err := json.Marshal(t1) //lint:ignore SA1026 ignore marshal error + if err == nil { + t.Error("expected error got none") + } else if !strings.Contains(err.Error(), "unsupported type: func()") { + t.Error("unexpected error:", err) + } + + // struct field tags omits func + var t2 = struct { + A string + F func() `json:"-"` + }{} + _, err = json.Marshal(t2) + if err != nil { + t.Error("unexpected error:", err) + } + + // Options has struct field tags to omit func fields + var o1 Options + b, err := json.Marshal(o1) + if err != nil { + t.Errorf("unexpected error marshal: %v", err) + } + err = json.Unmarshal(b, &o1) + if err != nil { + t.Errorf("unexpected error unmarshal: %v", err) + } +} + // Test Helper. func expect(t *testing.T, a interface{}, b interface{}) { t.Helper() From a852e7b610a1aab4f303dbf89185d7bb0f345eb8 Mon Sep 17 00:00:00 2001 From: Cory Jacobsen Date: Fri, 27 Sep 2024 15:02:29 -0600 Subject: [PATCH 06/10] lint: fixing linter errs (#100) --- .github/workflows/lint.yaml | 2 +- .golangci.yaml | 1 + secure.go | 4 ++-- secure_test.go | 28 ++++------------------------ 4 files changed, 8 insertions(+), 27 deletions(-) diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index effbbad..008af37 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -14,6 +14,6 @@ jobs: - uses: actions/checkout@v4 with: path: src/github.com/unrolled/secure - - uses: golangci/golangci-lint-action@v4 + - uses: golangci/golangci-lint-action@v6 with: working-directory: src/github.com/unrolled/secure diff --git a/.golangci.yaml b/.golangci.yaml index 3eab3e6..f9bbd25 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -29,3 +29,4 @@ linters: - maintidx - contextcheck - perfsprint + - exportloopref diff --git a/secure.go b/secure.go index f72842c..9b93960 100644 --- a/secure.go +++ b/secure.go @@ -98,12 +98,12 @@ type Options struct { // AllowedHostsAreRegex determines, if the provided `AllowedHosts` slice contains valid regular expressions. If this flag is set to true, every request's host will be checked against these expressions. Default is false. AllowedHostsAreRegex bool // AllowRequestFunc is a custom function that allows you to determine if the request should proceed or not based on your own custom logic. Default is nil. - AllowRequestFunc AllowRequestFunc `json:"-" yaml:"-" toml:"-"` + AllowRequestFunc AllowRequestFunc `json:"-" toml:"-" yaml:"-"` // HostsProxyHeaders is a set of header keys that may hold a proxied hostname value for the request. HostsProxyHeaders []string // SSLHostFunc is a function pointer, the return value of the function is the host name that has same functionality as `SSHost`. Default is nil. // If SSLHostFunc is nil, the `SSLHost` option will be used. - SSLHostFunc *SSLHostFunc `json:"-" yaml:"-" toml:"-"` + SSLHostFunc *SSLHostFunc `json:"-" toml:"-" yaml:"-"` // SSLProxyHeaders is set of header keys with associated values that would indicate a valid https request. Useful when using Nginx: `map[string]string{"X-Forwarded-Proto": "https"}`. Default is blank map. SSLProxyHeaders map[string]string // STSSeconds is the max-age of the Strict-Transport-Security header. Default is 0, which would NOT include the header. diff --git a/secure_test.go b/secure_test.go index 1a47fe1..80d8d2a 100644 --- a/secure_test.go +++ b/secure_test.go @@ -1376,35 +1376,15 @@ func TestBadRequestHandler(t *testing.T) { } func TestMarshal(t *testing.T) { - // func cant be marshalled - var t1 = struct { - A string - F func() - }{} - _, err := json.Marshal(t1) //lint:ignore SA1026 ignore marshal error - if err == nil { - t.Error("expected error got none") - } else if !strings.Contains(err.Error(), "unsupported type: func()") { - t.Error("unexpected error:", err) - } - - // struct field tags omits func - var t2 = struct { - A string - F func() `json:"-"` - }{} - _, err = json.Marshal(t2) - if err != nil { - t.Error("unexpected error:", err) - } - // Options has struct field tags to omit func fields var o1 Options - b, err := json.Marshal(o1) + + b, err := json.Marshal(o1) //nolint:musttag if err != nil { t.Errorf("unexpected error marshal: %v", err) } - err = json.Unmarshal(b, &o1) + + err = json.Unmarshal(b, &o1) //nolint:musttag if err != nil { t.Errorf("unexpected error unmarshal: %v", err) } From 973d4ea020cfdab84551a595557b56af9824d749 Mon Sep 17 00:00:00 2001 From: Cory Jacobsen Date: Fri, 27 Sep 2024 15:04:23 -0600 Subject: [PATCH 07/10] chore: add go 1.23 to test matrix (#101) --- .github/workflows/test.yaml | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 089c34e..d3b9d04 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -11,7 +11,7 @@ jobs: tests: strategy: matrix: - go-version: [1.18.x, 1.19.x, 1.20.x, 1.21.x, 1.22.x] + go-version: [1.18.x, 1.19.x, 1.20.x, 1.21.x, 1.22.x, 1.23.x] os: [ubuntu-latest] runs-on: ${{ matrix.os }} steps: @@ -20,12 +20,3 @@ jobs: with: go-version: ${{ matrix.go-version }} - run: make ci - golangci: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - with: - path: src/github.com/unrolled/secure - - uses: golangci/golangci-lint-action@v4 - with: - working-directory: src/github.com/unrolled/secure From 67655e9b23213c3882d99a9a47e611f33f198ab2 Mon Sep 17 00:00:00 2001 From: Cory Jacobsen Date: Tue, 15 Oct 2024 09:49:31 -0600 Subject: [PATCH 08/10] Update badge references for README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1add1c7..6bdb95c 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Secure [![GoDoc](https://godoc.org/github.com/unrolled/secure?status.svg)](http://godoc.org/github.com/unrolled/secure) [![Test](https://github.com/unrolled/secure/workflows/tests/badge.svg?branch=v1)](https://github.com/unrolled/secure/actions) +# Secure [![GoDoc](https://pkg.go.dev/badge/github.com/unrolled/secure)](http://godoc.org/github.com/unrolled/secure) [![Test](https://github.com/unrolled/secure/actions/workflows/test.yaml/badge.svg)](https://github.com/unrolled/secure/actions) Secure is an HTTP middleware for Go that facilitates some quick security wins. It's a standard net/http [Handler](http://golang.org/pkg/net/http/#Handler), and can be used with many [frameworks](#integration-examples) or directly with Go's net/http package. From c88f91938057576c002f31c1f4c32802f035fac5 Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Tue, 22 Oct 2024 17:38:51 -0400 Subject: [PATCH 09/10] feat: add coep, corp, x-dns-prefetch-control, x-permitted-cross-doman-policies (#102) * feat: add coep, corp, x-dns-prefetch-control, x-permitted-cross-domain-policies headers * updated readme * update XDNSPrefetchControl to be of type string and fix test * add newly added variables in the default section * remove len check --- README.md | 9 +++++- secure.go | 76 +++++++++++++++++++++++++++++++++++++------------- secure_test.go | 68 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 6bdb95c..cf8b92d 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,10 @@ s := secure.New(secure.Options{ FeaturePolicy: "vibrate 'none';", // Deprecated: this header has been renamed to PermissionsPolicy. FeaturePolicy allows the Feature-Policy header with the value to be set with a custom value. Default is "". PermissionsPolicy: "fullscreen=(), geolocation=()", // PermissionsPolicy allows the Permissions-Policy header with the value to be set with a custom value. Default is "". CrossOriginOpenerPolicy: "same-origin", // CrossOriginOpenerPolicy allows the Cross-Origin-Opener-Policy header with the value to be set with a custom value. Default is "". - + CrossOriginEmbedderPolicy: "require-corp", // CrossOriginEmbedderPolicy allows the Cross-Origin-Embedder-Policy header with the value to be set with a custom value. Default is "". + CrossOriginResourcePolicy: "same-origin", // CrossOriginResourcePolicy allows the Cross-Origin-Resource-Policy header with the value to be set with a custom value. Default is "". + XDNSPrefetchControl: "on", // XDNSPrefetchControl allows the X-DNS-Prefetch-Control header to be set via "on" or "off" keyword. Default is "". + XPermittedCrossDomainPolicies: "none", // XPermittedCrossDomainPolicies allows the X-Permitted-Cross-Domain-Policies to be set with a custom value. Default is "". IsDevelopment: true, // This will cause the AllowedHosts, SSLRedirect, and STSSeconds/STSIncludeSubdomains options to be ignored during development. When deploying to production, be sure to set this to false. }) // ... @@ -121,6 +124,10 @@ l := secure.New(secure.Options{ FeaturePolicy: "", PermissionsPolicy: "", CrossOriginOpenerPolicy: "", + CrossOriginEmbedderPolicy: "", + CrossOriginResourcePolicy: "", + XDNSPrefetchControl: "", + XPermittedCrossDomainPolicies: "", IsDevelopment: false, }) ~~~ diff --git a/secure.go b/secure.go index 9b93960..122c43c 100644 --- a/secure.go +++ b/secure.go @@ -11,25 +11,28 @@ import ( type secureCtxKey string const ( - stsHeader = "Strict-Transport-Security" - stsSubdomainString = "; includeSubDomains" - stsPreloadString = "; preload" - frameOptionsHeader = "X-Frame-Options" - frameOptionsValue = "DENY" - contentTypeHeader = "X-Content-Type-Options" - contentTypeValue = "nosniff" - xssProtectionHeader = "X-XSS-Protection" - xssProtectionValue = "1; mode=block" - cspHeader = "Content-Security-Policy" - cspReportOnlyHeader = "Content-Security-Policy-Report-Only" - hpkpHeader = "Public-Key-Pins" - referrerPolicyHeader = "Referrer-Policy" - featurePolicyHeader = "Feature-Policy" - permissionsPolicyHeader = "Permissions-Policy" - coopHeader = "Cross-Origin-Opener-Policy" - - ctxDefaultSecureHeaderKey = secureCtxKey("SecureResponseHeader") - cspNonceSize = 16 + stsHeader = "Strict-Transport-Security" + stsSubdomainString = "; includeSubDomains" + stsPreloadString = "; preload" + frameOptionsHeader = "X-Frame-Options" + frameOptionsValue = "DENY" + contentTypeHeader = "X-Content-Type-Options" + contentTypeValue = "nosniff" + xssProtectionHeader = "X-XSS-Protection" + xssProtectionValue = "1; mode=block" + cspHeader = "Content-Security-Policy" + cspReportOnlyHeader = "Content-Security-Policy-Report-Only" + hpkpHeader = "Public-Key-Pins" + referrerPolicyHeader = "Referrer-Policy" + featurePolicyHeader = "Feature-Policy" + permissionsPolicyHeader = "Permissions-Policy" + coopHeader = "Cross-Origin-Opener-Policy" + coepHeader = "Cross-Origin-Embedder-Policy" + corpHeader = "Cross-Origin-Resource-Policy" + dnsPreFetchControlHeader = "X-DNS-Prefetch-Control" + permittedCrossDomainPolicies = "X-Permitted-Cross-Domain-Policies" + ctxDefaultSecureHeaderKey = secureCtxKey("SecureResponseHeader") + cspNonceSize = 16 ) // SSLHostFunc is a custom function type that can be used to dynamically set the SSL host of a request. @@ -91,6 +94,18 @@ type Options struct { // CrossOriginOpenerPolicy allows you to ensure a top-level document does not share a browsing context group with cross-origin documents. Default is "". // Reference: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cross-Origin-Opener-Policy CrossOriginOpenerPolicy string + // CrossOriginResourcePolicy header blocks others from loading your resources cross-origin in some cases. + // Reference https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cross-Origin-Opener-Policy + CrossOriginResourcePolicy string + // CrossOriginEmbedderPolicy header helps control what resources can be loaded cross-origin. + // Reference https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cross-Origin-Embedder-Policy + CrossOriginEmbedderPolicy string + // XDNSPrefetchControl header helps control DNS prefetching, which can improve user privacy at the expense of performance. + // Reference: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-DNS-Prefetch-Control + XDNSPrefetchControl string + // XPermittedCrossDomainPolicies header tells some clients (mostly Adobe products) your domain's policy for loading cross-domain content. + // Reference: https://owasp.org/www-project-secure-headers/ + XPermittedCrossDomainPolicies string // SSLHost is the host name that is used to redirect http requests to https. Default is "", which indicates to use the same host. SSLHost string // AllowedHosts is a slice of fully qualified domain names that are allowed. Default is an empty slice, which allows any and all host names. @@ -466,6 +481,29 @@ func (s *Secure) processRequest(w http.ResponseWriter, r *http.Request) (http.He responseHeader.Set(coopHeader, s.opt.CrossOriginOpenerPolicy) } + // Cross Origin Resource Policy header. + if len(s.opt.CrossOriginResourcePolicy) > 0 { + responseHeader.Set(corpHeader, s.opt.CrossOriginResourcePolicy) + } + + // Cross-Origin-Embedder-Policy header. + if len(s.opt.CrossOriginEmbedderPolicy) > 0 { + responseHeader.Set(coepHeader, s.opt.CrossOriginEmbedderPolicy) + } + + // X-DNS-Prefetch-Control header. + switch strings.ToLower(s.opt.XDNSPrefetchControl) { + case "on": + responseHeader.Set(dnsPreFetchControlHeader, "on") + case "off": + responseHeader.Set(dnsPreFetchControlHeader, "off") + } + + // X-Permitted-Cross-Domain-Policies header. + if len(s.opt.XPermittedCrossDomainPolicies) > 0 { + responseHeader.Set(permittedCrossDomainPolicies, s.opt.XPermittedCrossDomainPolicies) + } + return responseHeader, r, nil } diff --git a/secure_test.go b/secure_test.go index 80d8d2a..52a5c63 100644 --- a/secure_test.go +++ b/secure_test.go @@ -1046,6 +1046,74 @@ func TestCrossOriginOpenerPolicy(t *testing.T) { expect(t, res.Header().Get("Cross-Origin-Opener-Policy"), "same-origin") } +func TestCrossOriginEmbedderPolicy(t *testing.T) { + s := New(Options{ + CrossOriginEmbedderPolicy: "require-corp", + }) + + res := httptest.NewRecorder() + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) + + s.Handler(myHandler).ServeHTTP(res, req) + + expect(t, res.Code, http.StatusOK) + expect(t, res.Header().Get("Cross-Origin-Embedder-Policy"), "require-corp") +} + +func TestCrossOriginResourcePolicy(t *testing.T) { + s := New(Options{ + CrossOriginResourcePolicy: "same-origin", + }) + + res := httptest.NewRecorder() + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) + + s.Handler(myHandler).ServeHTTP(res, req) + + expect(t, res.Code, http.StatusOK) + expect(t, res.Header().Get("Cross-Origin-Resource-Policy"), "same-origin") +} + +func TestXDNSPreFetchControl(t *testing.T) { + s := New(Options{ + XDNSPrefetchControl: "on", + }) + + res := httptest.NewRecorder() + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) + + s.Handler(myHandler).ServeHTTP(res, req) + + expect(t, res.Code, http.StatusOK) + expect(t, res.Header().Get("X-DNS-Prefetch-Control"), "on") + + k := New(Options{ + XDNSPrefetchControl: "off", + }) + + res = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(context.Background(), http.MethodGet, "/bar", nil) + + k.Handler(myHandler).ServeHTTP(res, req) + + expect(t, res.Code, http.StatusOK) + expect(t, res.Header().Get("X-DNS-Prefetch-Control"), "off") +} + +func TestXPermittedCrossDomainPolicies(t *testing.T) { + s := New(Options{ + XPermittedCrossDomainPolicies: "none", + }) + + res := httptest.NewRecorder() + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) + + s.Handler(myHandler).ServeHTTP(res, req) + + expect(t, res.Code, http.StatusOK) + expect(t, res.Header().Get("X-Permitted-Cross-Domain-Policies"), "none") +} + func TestIsSSL(t *testing.T) { s := New(Options{ SSLProxyHeaders: map[string]string{"X-Forwarded-Proto": "https"}, From ba42bc3377183ef4987b3f718ae3786938f457d6 Mon Sep 17 00:00:00 2001 From: Cory Jacobsen Date: Tue, 22 Oct 2024 15:50:50 -0600 Subject: [PATCH 10/10] chore: ensure the package does not add any headers when no config is provided (#103) --- secure_test.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/secure_test.go b/secure_test.go index 52a5c63..0411d94 100644 --- a/secure_test.go +++ b/secure_test.go @@ -22,6 +22,18 @@ var myHandler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("bar")) }) +func TestNoConfigOnlyOneHeader(t *testing.T) { + s := New() + + res := httptest.NewRecorder() + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com/foo", nil) + + s.Handler(myHandler).ServeHTTP(res, req) + + expect(t, len(res.Header()), 1) + expect(t, strings.Contains(res.Header().Get("Content-Type"), "text/plain"), true) +} + func TestNoConfig(t *testing.T) { s := New()