From 8b61a4e63d7e4f545b2f2959a49e5c45bd388da8 Mon Sep 17 00:00:00 2001 From: Cory Jacobsen Date: Fri, 21 Jun 2024 11:07:51 -0600 Subject: [PATCH 1/2] bringing linter up to date (#95) --- .github/workflows/lint.yaml | 19 ++++ .github/workflows/test.yaml | 22 ++-- .golangci.yaml | 43 ++++---- csp.go | 3 +- csp_test.go | 5 +- cspbuilder/builder_test.go | 5 + cspbuilder/directive_builder_test.go | 2 + secure_test.go | 144 ++++++++++++++------------- 8 files changed, 137 insertions(+), 106 deletions(-) create mode 100644 .github/workflows/lint.yaml diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 0000000..effbbad --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,19 @@ +on: + push: + branches: + - master + - v1 + pull_request: + branches: + - "**" +name: Linter +jobs: + golangci: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + path: src/github.com/unrolled/secure + - uses: golangci/golangci-lint-action@v4 + with: + working-directory: src/github.com/unrolled/secure diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 00fe2d2..485d6c6 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -6,26 +6,26 @@ on: pull_request: branches: - "**" -name: tests +name: Tests jobs: tests: strategy: matrix: - go-version: [1.18.x, 1.19.x, 1.20.x, 1.21.x] + go-version: [1.18.x, 1.19.x, 1.20.x, 1.21.x, 1.22.x] os: [ubuntu-latest] runs-on: ${{ matrix.os }} steps: - - name: Install Go - uses: actions/setup-go@v2 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} - - name: Checkout code - uses: actions/checkout@v2 - - name: Test - run: make ci + - run: make ci golangci: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - name: golangci-lint - uses: golangci/golangci-lint-action@v2 + - uses: actions/checkout@v4 + with: + path: src/github.com/unrolled/secure + - uses: golangci/golangci-lint-action@v4 + with: + working-directory: src/github.com/7shifts/seven-deploy diff --git a/.golangci.yaml b/.golangci.yaml index 392b4b3..3eab3e6 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -1,36 +1,31 @@ run: - timeout: 5m + timeout: 10m + modules-download-mode: readonly + allow-parallel-runners: true linters: enable-all: true disable: - # Deprecated linters - - varcheck - - exhaustivestruct - - ifshort - - structcheck - - golint - - maligned - - interfacer - - nosnakecase - - deadcode - - scopelint - - rowserrcheck - - sqlclosecheck - - structcheck - - wastedassign - # Ignoring - - lll - - varnamelen - paralleltest - - testpackage - - goerr113 + - gochecknoglobals - exhaustruct - - nestif + - wrapcheck + - tagliatelle + - depguard + - ireturn - funlen - - goconst + - varnamelen + - gomnd + - execinquery + - copyloopvar + - intrange + - gocognit + - lll - cyclop - gocyclo - - gocognit + - testpackage + - err113 + - nestif - maintidx - contextcheck + - perfsprint diff --git a/csp.go b/csp.go index 948358f..4d516c2 100644 --- a/csp.go +++ b/csp.go @@ -12,7 +12,8 @@ type key int const cspNonceKey key = iota -// CSPNonce returns the nonce value associated with the present request. If no nonce has been generated it returns an empty string. +// CSPNonce returns the nonce value associated with the present request. +// If no nonce has been generated it returns an empty string. func CSPNonce(c context.Context) string { if val, ok := c.Value(cspNonceKey).(string); ok { return val diff --git a/csp_test.go b/csp_test.go index edf662f..f09042b 100644 --- a/csp_test.go +++ b/csp_test.go @@ -23,7 +23,10 @@ func TestCSPNonce(t *testing.T) { }{ {Options{ContentSecurityPolicy: csp}, []string{"Content-Security-Policy"}}, {Options{ContentSecurityPolicyReportOnly: csp}, []string{"Content-Security-Policy-Report-Only"}}, - {Options{ContentSecurityPolicy: csp, ContentSecurityPolicyReportOnly: csp}, []string{"Content-Security-Policy", "Content-Security-Policy-Report-Only"}}, + { + Options{ContentSecurityPolicy: csp, ContentSecurityPolicyReportOnly: csp}, + []string{"Content-Security-Policy", "Content-Security-Policy-Report-Only"}, + }, } for _, c := range cases { diff --git a/cspbuilder/builder_test.go b/cspbuilder/builder_test.go index fb4645a..35079d3 100644 --- a/cspbuilder/builder_test.go +++ b/cspbuilder/builder_test.go @@ -42,6 +42,7 @@ func TestContentSecurityPolicyBuilder_Build_SingleDirective(t *testing.T) { tt.directiveName: tt.directiveValues, }, } + got, err := builder.Build() if (err != nil) != tt.wantErr { t.Errorf("ContentSecurityPolicyBuilder.Build() error = %v, wantErr %v", err, tt.wantErr) @@ -92,6 +93,7 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { builder := &Builder{ Directives: tt.directives, } + got, err := builder.Build() if (err != nil) != tt.wantErr { t.Errorf("ContentSecurityPolicyBuilder.Build() error = %v, wantErr %v", err, tt.wantErr) @@ -101,6 +103,7 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { { startsWithDirective := false + for directive := range tt.directives { if strings.HasPrefix(got, directive) { startsWithDirective = true @@ -108,6 +111,7 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { break } } + if !startsWithDirective { t.Errorf("ContentSecurityPolicyBuilder.Build() = '%v', does not start with directive name", got) } @@ -116,6 +120,7 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { if strings.HasSuffix(got, " ") { t.Errorf("ContentSecurityPolicyBuilder.Build() = '%v', ends on whitespace", got) } + if strings.HasSuffix(got, ";") { t.Errorf("ContentSecurityPolicyBuilder.Build() = '%v', ends on semi-colon", got) } diff --git a/cspbuilder/directive_builder_test.go b/cspbuilder/directive_builder_test.go index 7d31133..8cab58d 100644 --- a/cspbuilder/directive_builder_test.go +++ b/cspbuilder/directive_builder_test.go @@ -79,6 +79,7 @@ func TestBuildDirectiveFrameAncestors(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { sb := &strings.Builder{} + err := buildDirectiveFrameAncestors(sb, tt.values) if tt.wantErr && err != nil { return @@ -222,6 +223,7 @@ func TestBuildDirectiveTrustedTypes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { sb := &strings.Builder{} + err := buildDirectiveTrustedTypes(sb, tt.values) if tt.wantErr && err != nil { return diff --git a/secure_test.go b/secure_test.go index 107ce71..23f0e44 100644 --- a/secure_test.go +++ b/secure_test.go @@ -10,8 +10,14 @@ import ( "testing" ) +const ( + httpSchema = "http" + httpsSchema = "https" + exampleHost = "www.example.com" +) + //nolint:gochecknoglobals -var myHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +var myHandler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("bar")) }) @@ -34,7 +40,7 @@ func TestNoAllowHosts(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -49,7 +55,7 @@ func TestGoodSingleAllowHosts(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -64,7 +70,7 @@ func TestBadSingleAllowHosts(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -79,7 +85,7 @@ func TestRegexSingleAllowHosts(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "sub.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -140,7 +146,7 @@ func TestGoodSingleAllowHostsProxyHeaders(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "example-internal" + req.Host = "example-internal-3" req.Header.Set("X-Proxy-Host", "www.example.com") s.Handler(myHandler).ServeHTTP(res, req) @@ -172,7 +178,7 @@ func TestGoodMultipleAllowHosts(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "sub.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -202,7 +208,7 @@ func TestAllowHostsInDevMode(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www3.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -214,7 +220,7 @@ func TestBadHostHandler(t *testing.T) { AllowedHosts: []string{"www.example.com", "sub.example.com"}, }) - badHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + badHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "BadHost", http.StatusInternalServerError) }) @@ -239,8 +245,8 @@ func TestSSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "https" + req.Host = exampleHost + req.URL.Scheme = httpsSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -255,8 +261,8 @@ func TestSSLInDevMode(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -270,8 +276,8 @@ func TestBasicSSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -287,8 +293,8 @@ func TestBasicSSLWithHost(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -300,7 +306,7 @@ func TestBasicSSLWithHostFunc(t *testing.T) { sslHostFunc := (func() SSLHostFunc { return func(host string) string { newHost := "" - if host == "www.example.com" { + if host == exampleHost { newHost = "secure.example.com:8443" } else if host == "www.example.org" { newHost = "secure.example.org" @@ -317,8 +323,8 @@ func TestBasicSSLWithHostFunc(t *testing.T) { // test www.example.com res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -329,7 +335,7 @@ func TestBasicSSLWithHostFunc(t *testing.T) { res = httptest.NewRecorder() req, _ = http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) req.Host = "www.example.org" - req.URL.Scheme = "http" + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -340,7 +346,7 @@ func TestBasicSSLWithHostFunc(t *testing.T) { res = httptest.NewRecorder() req, _ = http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) req.Host = "www.other.com" - req.URL.Scheme = "http" + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -355,8 +361,8 @@ func TestBadProxySSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -373,8 +379,8 @@ func TestCustomProxySSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -391,8 +397,8 @@ func TestCustomProxySSLInDevMode(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "http") s.Handler(myHandler).ServeHTTP(res, req) @@ -410,7 +416,7 @@ func TestCustomProxyAndHostProxyHeadersWithRedirect(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) req.Host = "example-internal" - req.URL.Scheme = "http" + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") req.Header.Add("X-Forwarded-Host", "www.example.com") @@ -429,8 +435,8 @@ func TestCustomProxyAndHostSSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -447,8 +453,8 @@ func TestCustomBadProxyAndHostSSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -467,8 +473,8 @@ func TestCustomBadProxyAndHostSSLWithTempRedirect(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -545,7 +551,7 @@ func TestStsHeaderWithSSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "http" + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -562,7 +568,7 @@ func TestStsHeaderWithSSLForRequestOnly(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "http" + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.HandlerForRequestOnly(myHandler).ServeHTTP(res, req) @@ -579,7 +585,7 @@ func TestStsHeaderInDevMode(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -595,7 +601,7 @@ func TestStsHeaderWithSubdomains(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -611,7 +617,7 @@ func TestStsHeaderWithSubdomainsForRequestOnly(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.HandlerForRequestOnly(myHandler).ServeHTTP(res, req) @@ -627,7 +633,7 @@ func TestStsHeaderWithPreload(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -643,7 +649,7 @@ func TestStsHeaderWithPreloadForRequest(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.HandlerForRequestOnly(myHandler).ServeHTTP(res, req) @@ -660,7 +666,7 @@ func TestStsHeaderWithSubdomainsWithPreload(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -677,7 +683,7 @@ func TestStsHeaderWithSubdomainsWithPreloadForRequestOnly(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.HandlerForRequestOnly(myHandler).ServeHTTP(res, req) @@ -1054,7 +1060,7 @@ func TestIsSSL(t *testing.T) { expect(t, s.isSSL(req), true) req, _ = http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema expect(t, s.isSSL(req), true) req, _ = http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) @@ -1071,8 +1077,8 @@ func TestSSLForceHostWithHTTPS(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "https" + req.Host = exampleHost + req.URL.Scheme = httpsSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -1088,8 +1094,8 @@ func TestSSLForceHostWithHTTP(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "http") s.Handler(myHandler).ServeHTTP(res, req) @@ -1107,8 +1113,8 @@ func TestSSLForceHostWithSSLRedirect(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "https" + req.Host = exampleHost + req.URL.Scheme = httpsSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -1126,8 +1132,8 @@ func TestSSLForceHostTemporaryRedirect(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "https" + req.Host = exampleHost + req.URL.Scheme = httpsSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -1157,8 +1163,8 @@ func TestModifyResponseHeadersWithSSLAndDifferentSSLHost(t *testing.T) { }) req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") res := &http.Response{} @@ -1180,8 +1186,8 @@ func TestModifyResponseHeadersWithSSLAndNoSSLHost(t *testing.T) { }) req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") res := &http.Response{} @@ -1204,8 +1210,8 @@ func TestModifyResponseHeadersWithSSLAndMatchingSSLHost(t *testing.T) { }) req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") res := &http.Response{} @@ -1228,8 +1234,8 @@ func TestModifyResponseHeadersWithSSLAndPortInLocationResponse(t *testing.T) { }) req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") res := &http.Response{} @@ -1252,8 +1258,8 @@ func TestModifyResponseHeadersWithSSLAndPathInLocationResponse(t *testing.T) { }) req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") res := &http.Response{} @@ -1280,7 +1286,7 @@ func TestCustomSecureContextKey(t *testing.T) { var actual *http.Request - hf := func(w http.ResponseWriter, r *http.Request) { + hf := func(_ http.ResponseWriter, r *http.Request) { actual = r } @@ -1306,7 +1312,7 @@ func TestMultipleCustomSecureContextKeys(t *testing.T) { var actual *http.Request - hf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hf := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { actual = r }) @@ -1322,7 +1328,7 @@ func TestMultipleCustomSecureContextKeys(t *testing.T) { func TestAllowRequestFuncTrue(t *testing.T) { s := New(Options{ - AllowRequestFunc: func(r *http.Request) bool { return true }, + AllowRequestFunc: func(_ *http.Request) bool { return true }, }) res := httptest.NewRecorder() @@ -1337,7 +1343,7 @@ func TestAllowRequestFuncTrue(t *testing.T) { func TestAllowRequestFuncFalse(t *testing.T) { s := New(Options{ - AllowRequestFunc: func(r *http.Request) bool { return false }, + AllowRequestFunc: func(_ *http.Request) bool { return false }, }) res := httptest.NewRecorder() @@ -1351,9 +1357,9 @@ func TestAllowRequestFuncFalse(t *testing.T) { func TestBadRequestHandler(t *testing.T) { s := New(Options{ - AllowRequestFunc: func(r *http.Request) bool { return false }, + AllowRequestFunc: func(_ *http.Request) bool { return false }, }) - badRequestFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + badRequestFunc := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "custom error", http.StatusConflict) }) s.SetBadRequestHandler(badRequestFunc) From 3d539f92663c30ffe95cf00b65ff66a4f0330ee4 Mon Sep 17 00:00:00 2001 From: Cory Jacobsen Date: Tue, 25 Jun 2024 10:38:15 -0600 Subject: [PATCH 2/2] fix: sort csp directives per w3 spec (#96) --- .github/workflows/test.yaml | 2 +- cspbuilder/builder.go | 12 +++++++++++- cspbuilder/builder_test.go | 7 +++++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 485d6c6..089c34e 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -28,4 +28,4 @@ jobs: path: src/github.com/unrolled/secure - uses: golangci/golangci-lint-action@v4 with: - working-directory: src/github.com/7shifts/seven-deploy + working-directory: src/github.com/unrolled/secure diff --git a/cspbuilder/builder.go b/cspbuilder/builder.go index 573fa2e..685ab5f 100644 --- a/cspbuilder/builder.go +++ b/cspbuilder/builder.go @@ -1,6 +1,7 @@ package cspbuilder import ( + "sort" "strings" ) @@ -62,7 +63,16 @@ func (builder *Builder) MustBuild() string { func (builder *Builder) Build() (string, error) { var sb strings.Builder - for directive := range builder.Directives { + // Pull the directive keys out. + directiveKeys := []string{} + for key := range builder.Directives { + directiveKeys = append(directiveKeys, key) + } + + // Sort the policies: https://www.w3.org/TR/CSP3/#framework-policy + sort.Strings(directiveKeys) + + for _, directive := range directiveKeys { if sb.Len() > 0 { sb.WriteString("; ") } diff --git a/cspbuilder/builder_test.go b/cspbuilder/builder_test.go index 35079d3..c0c4f16 100644 --- a/cspbuilder/builder_test.go +++ b/cspbuilder/builder_test.go @@ -63,6 +63,7 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { directives map[string]([]string) builder Builder wantParts []string + wantFull string wantErr bool }{ { @@ -86,6 +87,8 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { "trusted-types policy-1 policy-#=_/@.% 'allow-duplicates'", "upgrade-insecure-requests", }, + + wantFull: "default-src 'self' example.com *.example.com; frame-ancestors 'self' http://*.example.com; report-to group1; require-trusted-types-for 'script'; sandbox allow-scripts; trusted-types policy-1 policy-#=_/@.% 'allow-duplicates'; upgrade-insecure-requests", }, } for _, tt := range tests { @@ -101,6 +104,10 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { return } + if got != tt.wantFull { + t.Errorf("ContentSecurityPolicyBuilder.Build() full = %v, but wanted %v", got, tt.wantFull) + } + { startsWithDirective := false