diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 0000000..008af37 --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,19 @@ +on: + push: + branches: + - master + - v1 + pull_request: + branches: + - "**" +name: Linter +jobs: + golangci: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + path: src/github.com/unrolled/secure + - uses: golangci/golangci-lint-action@v6 + with: + working-directory: src/github.com/unrolled/secure diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 00fe2d2..d3b9d04 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -6,26 +6,17 @@ on: pull_request: branches: - "**" -name: tests +name: Tests jobs: tests: strategy: matrix: - go-version: [1.18.x, 1.19.x, 1.20.x, 1.21.x] + go-version: [1.18.x, 1.19.x, 1.20.x, 1.21.x, 1.22.x, 1.23.x] os: [ubuntu-latest] runs-on: ${{ matrix.os }} steps: - - name: Install Go - uses: actions/setup-go@v2 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} - - name: Checkout code - uses: actions/checkout@v2 - - name: Test - run: make ci - golangci: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - name: golangci-lint - uses: golangci/golangci-lint-action@v2 + - run: make ci diff --git a/.golangci.yaml b/.golangci.yaml index 392b4b3..f9bbd25 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -1,36 +1,32 @@ run: - timeout: 5m + timeout: 10m + modules-download-mode: readonly + allow-parallel-runners: true linters: enable-all: true disable: - # Deprecated linters - - varcheck - - exhaustivestruct - - ifshort - - structcheck - - golint - - maligned - - interfacer - - nosnakecase - - deadcode - - scopelint - - rowserrcheck - - sqlclosecheck - - structcheck - - wastedassign - # Ignoring - - lll - - varnamelen - paralleltest - - testpackage - - goerr113 + - gochecknoglobals - exhaustruct - - nestif + - wrapcheck + - tagliatelle + - depguard + - ireturn - funlen - - goconst + - varnamelen + - gomnd + - execinquery + - copyloopvar + - intrange + - gocognit + - lll - cyclop - gocyclo - - gocognit + - testpackage + - err113 + - nestif - maintidx - contextcheck + - perfsprint + - exportloopref diff --git a/README.md b/README.md index 4ec82d5..1add1c7 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ s := secure.New(secure.Options{ AllowRequestFunc: nil, // AllowRequestFunc is a custom function type that allows you to determine if the request should proceed or not based on your own custom logic. Default is nil. HostsProxyHeaders: []string{"X-Forwarded-Hosts"}, // HostsProxyHeaders is a set of header keys that may hold a proxied hostname value for the request. SSLRedirect: true, // If SSLRedirect is set to true, then only allow HTTPS requests. Default is false. - SSLTemporaryRedirect: false, // If SSLTemporaryRedirect is true, the a 302 will be used while redirecting. Default is false (301). + SSLTemporaryRedirect: false, // If SSLTemporaryRedirect is true, then a 307 will be used while redirecting. Default is false (301). SSLHost: "ssl.example.com", // SSLHost is the host name that is used to redirect HTTP requests to HTTPS. Default is "", which indicates to use the same host. SSLHostFunc: nil, // SSLHostFunc is a function pointer, the return value of the function is the host name that has same functionality as `SSHost`. Default is nil. If SSLHostFunc is nil, the `SSLHost` option will be used. SSLProxyHeaders: map[string]string{"X-Forwarded-Proto": "https"}, // SSLProxyHeaders is set of header keys with associated values that would indicate a valid HTTPS request. Useful when using Nginx: `map[string]string{"X-Forwarded-Proto": "https"}`. Default is blank map. diff --git a/csp.go b/csp.go index 948358f..4d516c2 100644 --- a/csp.go +++ b/csp.go @@ -12,7 +12,8 @@ type key int const cspNonceKey key = iota -// CSPNonce returns the nonce value associated with the present request. If no nonce has been generated it returns an empty string. +// CSPNonce returns the nonce value associated with the present request. +// If no nonce has been generated it returns an empty string. func CSPNonce(c context.Context) string { if val, ok := c.Value(cspNonceKey).(string); ok { return val diff --git a/csp_test.go b/csp_test.go index edf662f..f09042b 100644 --- a/csp_test.go +++ b/csp_test.go @@ -23,7 +23,10 @@ func TestCSPNonce(t *testing.T) { }{ {Options{ContentSecurityPolicy: csp}, []string{"Content-Security-Policy"}}, {Options{ContentSecurityPolicyReportOnly: csp}, []string{"Content-Security-Policy-Report-Only"}}, - {Options{ContentSecurityPolicy: csp, ContentSecurityPolicyReportOnly: csp}, []string{"Content-Security-Policy", "Content-Security-Policy-Report-Only"}}, + { + Options{ContentSecurityPolicy: csp, ContentSecurityPolicyReportOnly: csp}, + []string{"Content-Security-Policy", "Content-Security-Policy-Report-Only"}, + }, } for _, c := range cases { diff --git a/cspbuilder/builder.go b/cspbuilder/builder.go index 573fa2e..685ab5f 100644 --- a/cspbuilder/builder.go +++ b/cspbuilder/builder.go @@ -1,6 +1,7 @@ package cspbuilder import ( + "sort" "strings" ) @@ -62,7 +63,16 @@ func (builder *Builder) MustBuild() string { func (builder *Builder) Build() (string, error) { var sb strings.Builder - for directive := range builder.Directives { + // Pull the directive keys out. + directiveKeys := []string{} + for key := range builder.Directives { + directiveKeys = append(directiveKeys, key) + } + + // Sort the policies: https://www.w3.org/TR/CSP3/#framework-policy + sort.Strings(directiveKeys) + + for _, directive := range directiveKeys { if sb.Len() > 0 { sb.WriteString("; ") } diff --git a/cspbuilder/builder_test.go b/cspbuilder/builder_test.go index fb4645a..c0c4f16 100644 --- a/cspbuilder/builder_test.go +++ b/cspbuilder/builder_test.go @@ -42,6 +42,7 @@ func TestContentSecurityPolicyBuilder_Build_SingleDirective(t *testing.T) { tt.directiveName: tt.directiveValues, }, } + got, err := builder.Build() if (err != nil) != tt.wantErr { t.Errorf("ContentSecurityPolicyBuilder.Build() error = %v, wantErr %v", err, tt.wantErr) @@ -62,6 +63,7 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { directives map[string]([]string) builder Builder wantParts []string + wantFull string wantErr bool }{ { @@ -85,6 +87,8 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { "trusted-types policy-1 policy-#=_/@.% 'allow-duplicates'", "upgrade-insecure-requests", }, + + wantFull: "default-src 'self' example.com *.example.com; frame-ancestors 'self' http://*.example.com; report-to group1; require-trusted-types-for 'script'; sandbox allow-scripts; trusted-types policy-1 policy-#=_/@.% 'allow-duplicates'; upgrade-insecure-requests", }, } for _, tt := range tests { @@ -92,6 +96,7 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { builder := &Builder{ Directives: tt.directives, } + got, err := builder.Build() if (err != nil) != tt.wantErr { t.Errorf("ContentSecurityPolicyBuilder.Build() error = %v, wantErr %v", err, tt.wantErr) @@ -99,8 +104,13 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { return } + if got != tt.wantFull { + t.Errorf("ContentSecurityPolicyBuilder.Build() full = %v, but wanted %v", got, tt.wantFull) + } + { startsWithDirective := false + for directive := range tt.directives { if strings.HasPrefix(got, directive) { startsWithDirective = true @@ -108,6 +118,7 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { break } } + if !startsWithDirective { t.Errorf("ContentSecurityPolicyBuilder.Build() = '%v', does not start with directive name", got) } @@ -116,6 +127,7 @@ func TestContentSecurityPolicyBuilder_Build_MultipleDirectives(t *testing.T) { if strings.HasSuffix(got, " ") { t.Errorf("ContentSecurityPolicyBuilder.Build() = '%v', ends on whitespace", got) } + if strings.HasSuffix(got, ";") { t.Errorf("ContentSecurityPolicyBuilder.Build() = '%v', ends on semi-colon", got) } diff --git a/cspbuilder/directive_builder_test.go b/cspbuilder/directive_builder_test.go index 7d31133..8cab58d 100644 --- a/cspbuilder/directive_builder_test.go +++ b/cspbuilder/directive_builder_test.go @@ -79,6 +79,7 @@ func TestBuildDirectiveFrameAncestors(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { sb := &strings.Builder{} + err := buildDirectiveFrameAncestors(sb, tt.values) if tt.wantErr && err != nil { return @@ -222,6 +223,7 @@ func TestBuildDirectiveTrustedTypes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { sb := &strings.Builder{} + err := buildDirectiveTrustedTypes(sb, tt.values) if tt.wantErr && err != nil { return diff --git a/secure.go b/secure.go index 0efcc61..9b93960 100644 --- a/secure.go +++ b/secure.go @@ -65,7 +65,7 @@ type Options struct { SSLRedirect bool // If SSLForceHost is true and SSLHost is set, requests will be forced to use SSLHost even the ones that are already using SSL. Default is false. SSLForceHost bool - // If SSLTemporaryRedirect is true, the a 302 will be used while redirecting. Default is false (301). + // If SSLTemporaryRedirect is true, then a 307 will be used while redirecting. Default is false (301). SSLTemporaryRedirect bool // If STSIncludeSubdomains is set to true, the `includeSubdomains` will be appended to the Strict-Transport-Security header. Default is false. STSIncludeSubdomains bool @@ -98,12 +98,12 @@ type Options struct { // AllowedHostsAreRegex determines, if the provided `AllowedHosts` slice contains valid regular expressions. If this flag is set to true, every request's host will be checked against these expressions. Default is false. AllowedHostsAreRegex bool // AllowRequestFunc is a custom function that allows you to determine if the request should proceed or not based on your own custom logic. Default is nil. - AllowRequestFunc AllowRequestFunc + AllowRequestFunc AllowRequestFunc `json:"-" toml:"-" yaml:"-"` // HostsProxyHeaders is a set of header keys that may hold a proxied hostname value for the request. HostsProxyHeaders []string // SSLHostFunc is a function pointer, the return value of the function is the host name that has same functionality as `SSHost`. Default is nil. // If SSLHostFunc is nil, the `SSLHost` option will be used. - SSLHostFunc *SSLHostFunc + SSLHostFunc *SSLHostFunc `json:"-" toml:"-" yaml:"-"` // SSLProxyHeaders is set of header keys with associated values that would indicate a valid https request. Useful when using Nginx: `map[string]string{"X-Forwarded-Proto": "https"}`. Default is blank map. SSLProxyHeaders map[string]string // STSSeconds is the max-age of the Strict-Transport-Security header. Default is 0, which would NOT include the header. diff --git a/secure_test.go b/secure_test.go index 107ce71..80d8d2a 100644 --- a/secure_test.go +++ b/secure_test.go @@ -3,6 +3,7 @@ package secure import ( "context" "crypto/tls" + "encoding/json" "net/http" "net/http/httptest" "reflect" @@ -10,8 +11,14 @@ import ( "testing" ) +const ( + httpSchema = "http" + httpsSchema = "https" + exampleHost = "www.example.com" +) + //nolint:gochecknoglobals -var myHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +var myHandler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("bar")) }) @@ -34,7 +41,7 @@ func TestNoAllowHosts(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -49,7 +56,7 @@ func TestGoodSingleAllowHosts(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -64,7 +71,7 @@ func TestBadSingleAllowHosts(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -79,7 +86,7 @@ func TestRegexSingleAllowHosts(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "sub.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -140,7 +147,7 @@ func TestGoodSingleAllowHostsProxyHeaders(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "example-internal" + req.Host = "example-internal-3" req.Header.Set("X-Proxy-Host", "www.example.com") s.Handler(myHandler).ServeHTTP(res, req) @@ -172,7 +179,7 @@ func TestGoodMultipleAllowHosts(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "sub.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -202,7 +209,7 @@ func TestAllowHostsInDevMode(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www3.example.com" + req.Host = exampleHost s.Handler(myHandler).ServeHTTP(res, req) @@ -214,7 +221,7 @@ func TestBadHostHandler(t *testing.T) { AllowedHosts: []string{"www.example.com", "sub.example.com"}, }) - badHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + badHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "BadHost", http.StatusInternalServerError) }) @@ -239,8 +246,8 @@ func TestSSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "https" + req.Host = exampleHost + req.URL.Scheme = httpsSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -255,8 +262,8 @@ func TestSSLInDevMode(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -270,8 +277,8 @@ func TestBasicSSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -287,8 +294,8 @@ func TestBasicSSLWithHost(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -300,7 +307,7 @@ func TestBasicSSLWithHostFunc(t *testing.T) { sslHostFunc := (func() SSLHostFunc { return func(host string) string { newHost := "" - if host == "www.example.com" { + if host == exampleHost { newHost = "secure.example.com:8443" } else if host == "www.example.org" { newHost = "secure.example.org" @@ -317,8 +324,8 @@ func TestBasicSSLWithHostFunc(t *testing.T) { // test www.example.com res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -329,7 +336,7 @@ func TestBasicSSLWithHostFunc(t *testing.T) { res = httptest.NewRecorder() req, _ = http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) req.Host = "www.example.org" - req.URL.Scheme = "http" + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -340,7 +347,7 @@ func TestBasicSSLWithHostFunc(t *testing.T) { res = httptest.NewRecorder() req, _ = http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) req.Host = "www.other.com" - req.URL.Scheme = "http" + req.URL.Scheme = httpSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -355,8 +362,8 @@ func TestBadProxySSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -373,8 +380,8 @@ func TestCustomProxySSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -391,8 +398,8 @@ func TestCustomProxySSLInDevMode(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "http") s.Handler(myHandler).ServeHTTP(res, req) @@ -410,7 +417,7 @@ func TestCustomProxyAndHostProxyHeadersWithRedirect(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) req.Host = "example-internal" - req.URL.Scheme = "http" + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") req.Header.Add("X-Forwarded-Host", "www.example.com") @@ -429,8 +436,8 @@ func TestCustomProxyAndHostSSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -447,8 +454,8 @@ func TestCustomBadProxyAndHostSSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -467,8 +474,8 @@ func TestCustomBadProxyAndHostSSLWithTempRedirect(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -545,7 +552,7 @@ func TestStsHeaderWithSSL(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "http" + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -562,7 +569,7 @@ func TestStsHeaderWithSSLForRequestOnly(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "http" + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") s.HandlerForRequestOnly(myHandler).ServeHTTP(res, req) @@ -579,7 +586,7 @@ func TestStsHeaderInDevMode(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -595,7 +602,7 @@ func TestStsHeaderWithSubdomains(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -611,7 +618,7 @@ func TestStsHeaderWithSubdomainsForRequestOnly(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.HandlerForRequestOnly(myHandler).ServeHTTP(res, req) @@ -627,7 +634,7 @@ func TestStsHeaderWithPreload(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -643,7 +650,7 @@ func TestStsHeaderWithPreloadForRequest(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.HandlerForRequestOnly(myHandler).ServeHTTP(res, req) @@ -660,7 +667,7 @@ func TestStsHeaderWithSubdomainsWithPreload(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.Handler(myHandler).ServeHTTP(res, req) @@ -677,7 +684,7 @@ func TestStsHeaderWithSubdomainsWithPreloadForRequestOnly(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema s.HandlerForRequestOnly(myHandler).ServeHTTP(res, req) @@ -1054,7 +1061,7 @@ func TestIsSSL(t *testing.T) { expect(t, s.isSSL(req), true) req, _ = http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.URL.Scheme = "https" + req.URL.Scheme = httpsSchema expect(t, s.isSSL(req), true) req, _ = http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) @@ -1071,8 +1078,8 @@ func TestSSLForceHostWithHTTPS(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "https" + req.Host = exampleHost + req.URL.Scheme = httpsSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -1088,8 +1095,8 @@ func TestSSLForceHostWithHTTP(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "http") s.Handler(myHandler).ServeHTTP(res, req) @@ -1107,8 +1114,8 @@ func TestSSLForceHostWithSSLRedirect(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "https" + req.Host = exampleHost + req.URL.Scheme = httpsSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -1126,8 +1133,8 @@ func TestSSLForceHostTemporaryRedirect(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "https" + req.Host = exampleHost + req.URL.Scheme = httpsSchema req.Header.Add("X-Forwarded-Proto", "https") s.Handler(myHandler).ServeHTTP(res, req) @@ -1157,8 +1164,8 @@ func TestModifyResponseHeadersWithSSLAndDifferentSSLHost(t *testing.T) { }) req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") res := &http.Response{} @@ -1180,8 +1187,8 @@ func TestModifyResponseHeadersWithSSLAndNoSSLHost(t *testing.T) { }) req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") res := &http.Response{} @@ -1204,8 +1211,8 @@ func TestModifyResponseHeadersWithSSLAndMatchingSSLHost(t *testing.T) { }) req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") res := &http.Response{} @@ -1228,8 +1235,8 @@ func TestModifyResponseHeadersWithSSLAndPortInLocationResponse(t *testing.T) { }) req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") res := &http.Response{} @@ -1252,8 +1259,8 @@ func TestModifyResponseHeadersWithSSLAndPathInLocationResponse(t *testing.T) { }) req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "/foo", nil) - req.Host = "www.example.com" - req.URL.Scheme = "http" + req.Host = exampleHost + req.URL.Scheme = httpSchema req.Header.Add("X-Forwarded-Proto", "https") res := &http.Response{} @@ -1280,7 +1287,7 @@ func TestCustomSecureContextKey(t *testing.T) { var actual *http.Request - hf := func(w http.ResponseWriter, r *http.Request) { + hf := func(_ http.ResponseWriter, r *http.Request) { actual = r } @@ -1306,7 +1313,7 @@ func TestMultipleCustomSecureContextKeys(t *testing.T) { var actual *http.Request - hf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hf := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { actual = r }) @@ -1322,7 +1329,7 @@ func TestMultipleCustomSecureContextKeys(t *testing.T) { func TestAllowRequestFuncTrue(t *testing.T) { s := New(Options{ - AllowRequestFunc: func(r *http.Request) bool { return true }, + AllowRequestFunc: func(_ *http.Request) bool { return true }, }) res := httptest.NewRecorder() @@ -1337,7 +1344,7 @@ func TestAllowRequestFuncTrue(t *testing.T) { func TestAllowRequestFuncFalse(t *testing.T) { s := New(Options{ - AllowRequestFunc: func(r *http.Request) bool { return false }, + AllowRequestFunc: func(_ *http.Request) bool { return false }, }) res := httptest.NewRecorder() @@ -1351,9 +1358,9 @@ func TestAllowRequestFuncFalse(t *testing.T) { func TestBadRequestHandler(t *testing.T) { s := New(Options{ - AllowRequestFunc: func(r *http.Request) bool { return false }, + AllowRequestFunc: func(_ *http.Request) bool { return false }, }) - badRequestFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + badRequestFunc := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "custom error", http.StatusConflict) }) s.SetBadRequestHandler(badRequestFunc) @@ -1368,6 +1375,21 @@ func TestBadRequestHandler(t *testing.T) { expect(t, strings.TrimSpace(res.Body.String()), `custom error`) } +func TestMarshal(t *testing.T) { + // Options has struct field tags to omit func fields + var o1 Options + + b, err := json.Marshal(o1) //nolint:musttag + if err != nil { + t.Errorf("unexpected error marshal: %v", err) + } + + err = json.Unmarshal(b, &o1) //nolint:musttag + if err != nil { + t.Errorf("unexpected error unmarshal: %v", err) + } +} + // Test Helper. func expect(t *testing.T, a interface{}, b interface{}) { t.Helper()