From 1bfbf46e84791ad465db651082db194d0fd082d2 Mon Sep 17 00:00:00 2001 From: defelmnq Date: Fri, 22 Nov 2024 01:31:49 +0100 Subject: [PATCH 1/8] work-on-testing-cases --- coderd/httpmw/provisionerdaemon_test.go | 148 ++++++++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 coderd/httpmw/provisionerdaemon_test.go diff --git a/coderd/httpmw/provisionerdaemon_test.go b/coderd/httpmw/provisionerdaemon_test.go new file mode 100644 index 0000000000000..8c46c91c32622 --- /dev/null +++ b/coderd/httpmw/provisionerdaemon_test.go @@ -0,0 +1,148 @@ +package httpmw_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/codersdk" +) + +func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opts httpmw.ExtractProvisionerAuthConfig + expectedStatusCode int + expectedResponseMessage string + provisionerKey string + provisionerPSK string + }{ + { + name: "NoKeyProvided_Optional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: true, + }, + expectedStatusCode: http.StatusOK, + expectedResponseMessage: "", + }, + { + name: "NoKeyProvided_NotOptional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: false, + }, + expectedStatusCode: http.StatusUnauthorized, + expectedResponseMessage: "provisioner daemon key required", + }, + { + name: "ProvisionerKeyAndPSKProvided_NotOptional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: false, + }, + provisionerKey: "key", + provisionerPSK: "psk", + expectedStatusCode: http.StatusBadRequest, + expectedResponseMessage: "provisioner daemon key and psk provided, but only one is allowed", + }, + { + name: "ProvisionerKeyAndPSKProvided_Optional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: true, + }, + provisionerKey: "key", + expectedStatusCode: http.StatusOK, + expectedResponseMessage: "", + }, + { + name: "InvalidProvisionerKey_NotOptional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: false, + }, + provisionerKey: "invalid", + expectedStatusCode: http.StatusBadRequest, + expectedResponseMessage: "provisioner daemon key invalid", + }, + { + name: "InvalidProvisionerKey_Optional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: true, + }, + provisionerKey: "invalid", + expectedStatusCode: http.StatusOK, + expectedResponseMessage: "", + }, + { + name: "InvalidProvisionerPSK_NotOptional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: false, + PSK: "psk", + }, + provisionerPSK: "invalid", + expectedStatusCode: http.StatusUnauthorized, + expectedResponseMessage: "provisioner daemon psk invalid", + }, + { + name: "InvalidProvisionerPSK_Optional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: true, + PSK: "psk", + }, + provisionerPSK: "invalid", + expectedStatusCode: http.StatusOK, + expectedResponseMessage: "", + }, + { + name: "ValidProvisionerPSK_NotOptional", + opts: httpmw.ExtractProvisionerAuthConfig{ + DB: nil, + Optional: false, + PSK: "ThisIsAValidPSK", + }, + provisionerPSK: "ThisIsAValidPSK", + expectedStatusCode: http.StatusOK, + expectedResponseMessage: "", + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + routeCtx := chi.NewRouteContext() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) + res := httptest.NewRecorder() + + if test.provisionerKey != "" { + r.Header.Set(codersdk.ProvisionerDaemonKey, test.provisionerKey) + } + if test.provisionerPSK != "" { + r.Header.Set(codersdk.ProvisionerDaemonPSK, test.provisionerPSK) + } + + httpmw.ExtractProvisionerDaemonAuthenticated(test.opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(res, r) + + require.Equal(t, test.expectedStatusCode, res.Result().StatusCode) + if test.expectedResponseMessage != "" { + require.Contains(t, res.Body.String(), test.expectedResponseMessage) + } + }) + } + +} From d9ab576a9f79d42c607999751c35afc09efb986b Mon Sep 17 00:00:00 2001 From: defelmnq Date: Fri, 22 Nov 2024 01:51:31 +0100 Subject: [PATCH 2/8] improve testing coverage --- coderd/httpmw/provisionerdaemon_test.go | 74 +++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/coderd/httpmw/provisionerdaemon_test.go b/coderd/httpmw/provisionerdaemon_test.go index 8c46c91c32622..d8704bada6857 100644 --- a/coderd/httpmw/provisionerdaemon_test.go +++ b/coderd/httpmw/provisionerdaemon_test.go @@ -11,6 +11,9 @@ import ( "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" + "github.com/coder/coder/v2/enterprise/coderd/license" + "github.com/coder/coder/v2/testutil" ) func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { @@ -145,4 +148,75 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { }) } + t.Run("ProvisionerKey", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + }, + }, + }) + // nolint:gocritic // test + key, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{ + Name: "dont-TEST-me", + }) + require.NoError(t, err) + + routeCtx := chi.NewRouteContext() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) + res := httptest.NewRecorder() + + r.Header.Set(codersdk.ProvisionerDaemonKey, key.Key) + + httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ + DB: db, + Optional: false, + })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(res, r) + + require.Equal(t, http.StatusOK, res.Result().StatusCode) + }) + + t.Run("ProvisionerKey_NotFound", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + }, + }, + }) + // nolint:gocritic // test + _, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{ + Name: "dont-TEST-me", + }) + require.NoError(t, err) + + routeCtx := chi.NewRouteContext() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) + res := httptest.NewRecorder() + + r.Header.Set(codersdk.ProvisionerDaemonKey, "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4") + + httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ + DB: db, + Optional: false, + })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(res, r) + + require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) + require.Contains(t, res.Body.String(), "provisioner daemon key invalid") + }) + } From e26f5e1e2a4a1f3c11ef57146675b1a17de34961 Mon Sep 17 00:00:00 2001 From: defelmnq Date: Fri, 22 Nov 2024 02:02:30 +0100 Subject: [PATCH 3/8] improve testing coverage --- coderd/httpmw/provisionerdaemon_test.go | 143 ++++++++++++------------ 1 file changed, 70 insertions(+), 73 deletions(-) diff --git a/coderd/httpmw/provisionerdaemon_test.go b/coderd/httpmw/provisionerdaemon_test.go index d8704bada6857..20c4dd25b1e62 100644 --- a/coderd/httpmw/provisionerdaemon_test.go +++ b/coderd/httpmw/provisionerdaemon_test.go @@ -11,9 +11,6 @@ import ( "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" - "github.com/coder/coder/v2/enterprise/coderd/license" - "github.com/coder/coder/v2/testutil" ) func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { @@ -148,75 +145,75 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { }) } - t.Run("ProvisionerKey", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{ - codersdk.FeatureExternalProvisionerDaemons: 1, - }, - }, - }) - // nolint:gocritic // test - key, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{ - Name: "dont-TEST-me", - }) - require.NoError(t, err) - - routeCtx := chi.NewRouteContext() - r := httptest.NewRequest(http.MethodGet, "/", nil) - r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) - res := httptest.NewRecorder() - - r.Header.Set(codersdk.ProvisionerDaemonKey, key.Key) - - httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ - DB: db, - Optional: false, - })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })).ServeHTTP(res, r) - - require.Equal(t, http.StatusOK, res.Result().StatusCode) - }) - - t.Run("ProvisionerKey_NotFound", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{ - codersdk.FeatureExternalProvisionerDaemons: 1, - }, - }, - }) - // nolint:gocritic // test - _, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{ - Name: "dont-TEST-me", - }) - require.NoError(t, err) - - routeCtx := chi.NewRouteContext() - r := httptest.NewRequest(http.MethodGet, "/", nil) - r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) - res := httptest.NewRecorder() - - r.Header.Set(codersdk.ProvisionerDaemonKey, "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4") - - httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ - DB: db, - Optional: false, - })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })).ServeHTTP(res, r) - - require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) - require.Contains(t, res.Body.String(), "provisioner daemon key invalid") - }) + // t.Run("ProvisionerKey", func(t *testing.T) { + // t.Parallel() + + // ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + // defer cancel() + // client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ + // LicenseOptions: &coderdenttest.LicenseOptions{ + // Features: license.Features{ + // codersdk.FeatureExternalProvisionerDaemons: 1, + // }, + // }, + // }) + // // nolint:gocritic // test + // key, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{ + // Name: "dont-TEST-me", + // }) + // require.NoError(t, err) + + // routeCtx := chi.NewRouteContext() + // r := httptest.NewRequest(http.MethodGet, "/", nil) + // r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) + // res := httptest.NewRecorder() + + // r.Header.Set(codersdk.ProvisionerDaemonKey, key.Key) + + // httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ + // DB: db, + // Optional: false, + // })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // w.WriteHeader(http.StatusOK) + // })).ServeHTTP(res, r) + + // require.Equal(t, http.StatusOK, res.Result().StatusCode) + // }) + + // t.Run("ProvisionerKey_NotFound", func(t *testing.T) { + // t.Parallel() + + // ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + // defer cancel() + // client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ + // LicenseOptions: &coderdenttest.LicenseOptions{ + // Features: license.Features{ + // codersdk.FeatureExternalProvisionerDaemons: 1, + // }, + // }, + // }) + // // nolint:gocritic // test + // _, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{ + // Name: "dont-TEST-me", + // }) + // require.NoError(t, err) + + // routeCtx := chi.NewRouteContext() + // r := httptest.NewRequest(http.MethodGet, "/", nil) + // r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) + // res := httptest.NewRecorder() + + // r.Header.Set(codersdk.ProvisionerDaemonKey, "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4") + + // httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ + // DB: db, + // Optional: false, + // })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // w.WriteHeader(http.StatusOK) + // })).ServeHTTP(res, r) + + // require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) + // require.Contains(t, res.Body.String(), "provisioner daemon key invalid") + // }) } From 13d6dd91eb97abbc197d665f847706279c8c1f2a Mon Sep 17 00:00:00 2001 From: defelmnq Date: Fri, 22 Nov 2024 02:08:31 +0100 Subject: [PATCH 4/8] improve testing coverage --- coderd/httpmw/provisionerdaemon_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/coderd/httpmw/provisionerdaemon_test.go b/coderd/httpmw/provisionerdaemon_test.go index 20c4dd25b1e62..e454a8ed76d75 100644 --- a/coderd/httpmw/provisionerdaemon_test.go +++ b/coderd/httpmw/provisionerdaemon_test.go @@ -134,10 +134,11 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { r.Header.Set(codersdk.ProvisionerDaemonPSK, test.provisionerPSK) } - httpmw.ExtractProvisionerDaemonAuthenticated(test.opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + httpmw.ExtractProvisionerDaemonAuthenticated(test.opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })).ServeHTTP(res, r) + //nolint:bodyclose require.Equal(t, test.expectedStatusCode, res.Result().StatusCode) if test.expectedResponseMessage != "" { require.Contains(t, res.Body.String(), test.expectedResponseMessage) @@ -215,5 +216,4 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { // require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) // require.Contains(t, res.Body.String(), "provisioner daemon key invalid") // }) - } From a6afc338d466c6b39d5cc99e1aac896fc0bfbafd Mon Sep 17 00:00:00 2001 From: defelmnq Date: Sat, 23 Nov 2024 04:03:53 +0100 Subject: [PATCH 5/8] move middlewar --- .../coderd}/httpmw/provisionerdaemon_test.go | 143 +++++++++--------- 1 file changed, 73 insertions(+), 70 deletions(-) rename {coderd => enterprise/coderd}/httpmw/provisionerdaemon_test.go (61%) diff --git a/coderd/httpmw/provisionerdaemon_test.go b/enterprise/coderd/httpmw/provisionerdaemon_test.go similarity index 61% rename from coderd/httpmw/provisionerdaemon_test.go rename to enterprise/coderd/httpmw/provisionerdaemon_test.go index e454a8ed76d75..7a10b7092a614 100644 --- a/coderd/httpmw/provisionerdaemon_test.go +++ b/enterprise/coderd/httpmw/provisionerdaemon_test.go @@ -11,6 +11,9 @@ import ( "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" + "github.com/coder/coder/v2/enterprise/coderd/license" + "github.com/coder/coder/v2/testutil" ) func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { @@ -146,74 +149,74 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { }) } - // t.Run("ProvisionerKey", func(t *testing.T) { - // t.Parallel() - - // ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - // defer cancel() - // client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ - // LicenseOptions: &coderdenttest.LicenseOptions{ - // Features: license.Features{ - // codersdk.FeatureExternalProvisionerDaemons: 1, - // }, - // }, - // }) - // // nolint:gocritic // test - // key, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{ - // Name: "dont-TEST-me", - // }) - // require.NoError(t, err) - - // routeCtx := chi.NewRouteContext() - // r := httptest.NewRequest(http.MethodGet, "/", nil) - // r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) - // res := httptest.NewRecorder() - - // r.Header.Set(codersdk.ProvisionerDaemonKey, key.Key) - - // httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ - // DB: db, - // Optional: false, - // })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // w.WriteHeader(http.StatusOK) - // })).ServeHTTP(res, r) - - // require.Equal(t, http.StatusOK, res.Result().StatusCode) - // }) - - // t.Run("ProvisionerKey_NotFound", func(t *testing.T) { - // t.Parallel() - - // ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - // defer cancel() - // client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ - // LicenseOptions: &coderdenttest.LicenseOptions{ - // Features: license.Features{ - // codersdk.FeatureExternalProvisionerDaemons: 1, - // }, - // }, - // }) - // // nolint:gocritic // test - // _, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{ - // Name: "dont-TEST-me", - // }) - // require.NoError(t, err) - - // routeCtx := chi.NewRouteContext() - // r := httptest.NewRequest(http.MethodGet, "/", nil) - // r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) - // res := httptest.NewRecorder() - - // r.Header.Set(codersdk.ProvisionerDaemonKey, "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4") - - // httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ - // DB: db, - // Optional: false, - // })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // w.WriteHeader(http.StatusOK) - // })).ServeHTTP(res, r) - - // require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) - // require.Contains(t, res.Body.String(), "provisioner daemon key invalid") - // }) + t.Run("ProvisionerKey", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + }, + }, + }) + // nolint:gocritic // test + key, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{ + Name: "dont-TEST-me", + }) + require.NoError(t, err) + + routeCtx := chi.NewRouteContext() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) + res := httptest.NewRecorder() + + r.Header.Set(codersdk.ProvisionerDaemonKey, key.Key) + + httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ + DB: db, + Optional: false, + })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(res, r) + + require.Equal(t, http.StatusOK, res.Result().StatusCode) + }) + + t.Run("ProvisionerKey_NotFound", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + }, + }, + }) + // nolint:gocritic // test + _, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{ + Name: "dont-TEST-me", + }) + require.NoError(t, err) + + routeCtx := chi.NewRouteContext() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) + res := httptest.NewRecorder() + + r.Header.Set(codersdk.ProvisionerDaemonKey, "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4") + + httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ + DB: db, + Optional: false, + })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(res, r) + + require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) + require.Contains(t, res.Body.String(), "provisioner daemon key invalid") + }) } From 66bd65e22b142d37703b53cdb99bcebd7089d552 Mon Sep 17 00:00:00 2001 From: defelmnq Date: Sat, 23 Nov 2024 04:18:28 +0100 Subject: [PATCH 6/8] move middlewar --- enterprise/coderd/httpmw/provisionerdaemon_test.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/enterprise/coderd/httpmw/provisionerdaemon_test.go b/enterprise/coderd/httpmw/provisionerdaemon_test.go index 7a10b7092a614..6ea593f3ccc7d 100644 --- a/enterprise/coderd/httpmw/provisionerdaemon_test.go +++ b/enterprise/coderd/httpmw/provisionerdaemon_test.go @@ -177,10 +177,11 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ DB: db, Optional: false, - })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })).ServeHTTP(res, r) + //nolint:bodyclose require.Equal(t, http.StatusOK, res.Result().StatusCode) }) @@ -212,10 +213,11 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ DB: db, Optional: false, - })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })).ServeHTTP(res, r) + //nolint:bodyclose require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) require.Contains(t, res.Body.String(), "provisioner daemon key invalid") }) From 5eafd6d38cf6462fb019ecea660789497dc995ce Mon Sep 17 00:00:00 2001 From: defelmnq Date: Mon, 25 Nov 2024 14:00:14 +0100 Subject: [PATCH 7/8] improve testing coverage --- coderd/httpmw/provisionerdaemon.go | 3 + enterprise/coderd/httpmw/doc.go | 5 + .../coderd/httpmw/provisionerdaemon_test.go | 103 ++++++++++++++---- 3 files changed, 92 insertions(+), 19 deletions(-) create mode 100644 enterprise/coderd/httpmw/doc.go diff --git a/coderd/httpmw/provisionerdaemon.go b/coderd/httpmw/provisionerdaemon.go index b2b4e2c04088e..e8a50ae0fc3b3 100644 --- a/coderd/httpmw/provisionerdaemon.go +++ b/coderd/httpmw/provisionerdaemon.go @@ -25,6 +25,9 @@ type ExtractProvisionerAuthConfig struct { PSK string } +// ExtractProvisionerDaemonAuthenticated authenticates a request as a provisioner daemon. +// If the request is not authenticated, the next handler is called unless Optional is true. +// This function currently is tested inside the enterprise package. func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/enterprise/coderd/httpmw/doc.go b/enterprise/coderd/httpmw/doc.go new file mode 100644 index 0000000000000..ef48f0f6e0498 --- /dev/null +++ b/enterprise/coderd/httpmw/doc.go @@ -0,0 +1,5 @@ +// Package httpmw contains middleware for HTTP handlers. +// Currently, the tested middleware is inside the AGPL package. +// As the middleware also contains enterprise-only logic, tests had to be +// moved here. +package httpmw diff --git a/enterprise/coderd/httpmw/provisionerdaemon_test.go b/enterprise/coderd/httpmw/provisionerdaemon_test.go index 6ea593f3ccc7d..192ba89027929 100644 --- a/enterprise/coderd/httpmw/provisionerdaemon_test.go +++ b/enterprise/coderd/httpmw/provisionerdaemon_test.go @@ -7,8 +7,13 @@ import ( "testing" "github.com/go-chi/chi/v5" + "github.com/google/uuid" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/xerrors" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" @@ -17,6 +22,10 @@ import ( ) func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { + const ( + //nolint:gosec // test key generated by test + functionalKey = "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4" + ) t.Parallel() tests := []struct { @@ -33,8 +42,7 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { DB: nil, Optional: true, }, - expectedStatusCode: http.StatusOK, - expectedResponseMessage: "", + expectedStatusCode: http.StatusOK, }, { name: "NoKeyProvided_NotOptional", @@ -62,9 +70,8 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { DB: nil, Optional: true, }, - provisionerKey: "key", - expectedStatusCode: http.StatusOK, - expectedResponseMessage: "", + provisionerKey: "key", + expectedStatusCode: http.StatusOK, }, { name: "InvalidProvisionerKey_NotOptional", @@ -82,9 +89,7 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { DB: nil, Optional: true, }, - provisionerKey: "invalid", - expectedStatusCode: http.StatusOK, - expectedResponseMessage: "", + provisionerKey: "invalid", }, { name: "InvalidProvisionerPSK_NotOptional", @@ -104,9 +109,8 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { Optional: true, PSK: "psk", }, - provisionerPSK: "invalid", - expectedStatusCode: http.StatusOK, - expectedResponseMessage: "", + provisionerPSK: "invalid", + expectedStatusCode: http.StatusOK, }, { name: "ValidProvisionerPSK_NotOptional", @@ -115,9 +119,8 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { Optional: false, PSK: "ThisIsAValidPSK", }, - provisionerPSK: "ThisIsAValidPSK", - expectedStatusCode: http.StatusOK, - expectedResponseMessage: "", + provisionerPSK: "ThisIsAValidPSK", + expectedStatusCode: http.StatusOK, }, } @@ -152,8 +155,7 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { t.Run("ProvisionerKey", func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() + ctx := testutil.Context(t, testutil.WaitShort) client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ LicenseOptions: &coderdenttest.LicenseOptions{ Features: license.Features{ @@ -188,8 +190,7 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { t.Run("ProvisionerKey_NotFound", func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() + ctx := testutil.Context(t, testutil.WaitShort) client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ LicenseOptions: &coderdenttest.LicenseOptions{ Features: license.Features{ @@ -208,7 +209,9 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) res := httptest.NewRecorder() - r.Header.Set(codersdk.ProvisionerDaemonKey, "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4") + //nolint:gosec // test key generated by test + pkey := "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4" + r.Header.Set(codersdk.ProvisionerDaemonKey, pkey) httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ DB: db, @@ -221,4 +224,66 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) require.Contains(t, res.Body.String(), "provisioner daemon key invalid") }) + + t.Run("ProvisionerKey_CompareFail", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockDB := dbmock.NewMockStore(ctrl) + + gomock.InOrder( + mockDB.EXPECT().GetProvisionerKeyByHashedSecret(gomock.Any(), gomock.Any()).Times(1).Return(database.ProvisionerKey{ + ID: uuid.New(), + HashedSecret: []byte("hashedSecret"), + }, nil), + ) + + routeCtx := chi.NewRouteContext() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) + res := httptest.NewRecorder() + + r.Header.Set(codersdk.ProvisionerDaemonKey, functionalKey) + + httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ + DB: mockDB, + Optional: false, + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(res, r) + + //nolint:bodyclose + require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) + require.Contains(t, res.Body.String(), "provisioner daemon key invalid") + }) + + t.Run("ProvisionerKey_DBError", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockDB := dbmock.NewMockStore(ctrl) + + gomock.InOrder( + mockDB.EXPECT().GetProvisionerKeyByHashedSecret(gomock.Any(), gomock.Any()).Times(1).Return(database.ProvisionerKey{}, xerrors.New("error")), + ) + + routeCtx := chi.NewRouteContext() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) + res := httptest.NewRecorder() + + //nolint:gosec // test key generated by test + r.Header.Set(codersdk.ProvisionerDaemonKey, functionalKey) + + httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ + DB: mockDB, + Optional: false, + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(res, r) + + //nolint:bodyclose + require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) + require.Contains(t, res.Body.String(), "get provisioner daemon key") + }) } From 711eeeefe7ceaf76c68079b049ecba5a8488acfa Mon Sep 17 00:00:00 2001 From: defelmnq Date: Mon, 25 Nov 2024 14:19:39 +0100 Subject: [PATCH 8/8] improve testing coverage --- enterprise/coderd/httpmw/provisionerdaemon_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/enterprise/coderd/httpmw/provisionerdaemon_test.go b/enterprise/coderd/httpmw/provisionerdaemon_test.go index 192ba89027929..84da7f546fa35 100644 --- a/enterprise/coderd/httpmw/provisionerdaemon_test.go +++ b/enterprise/coderd/httpmw/provisionerdaemon_test.go @@ -89,7 +89,8 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { DB: nil, Optional: true, }, - provisionerKey: "invalid", + provisionerKey: "invalid", + expectedStatusCode: http.StatusOK, }, { name: "InvalidProvisionerPSK_NotOptional", @@ -283,7 +284,7 @@ func TestExtractProvisionerDaemonAuthenticated(t *testing.T) { })).ServeHTTP(res, r) //nolint:bodyclose - require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode) + require.Equal(t, http.StatusInternalServerError, res.Result().StatusCode) require.Contains(t, res.Body.String(), "get provisioner daemon key") }) }