diff --git a/coderd/httpmw/actor_test.go b/coderd/httpmw/actor_test.go index ef05a8cb3a3d2..30ec5bca4d2e8 100644 --- a/coderd/httpmw/actor_test.go +++ b/coderd/httpmw/actor_test.go @@ -12,7 +12,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" @@ -38,7 +38,7 @@ func TestRequireAPIKeyOrWorkspaceProxyAuth(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) user = dbgen.User(t, db, database.User{}) _, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -75,7 +75,7 @@ func TestRequireAPIKeyOrWorkspaceProxyAuth(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) user = dbgen.User(t, db, database.User{}) _, userToken = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -114,7 +114,7 @@ func TestRequireAPIKeyOrWorkspaceProxyAuth(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) proxy, token = dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{}) r = httptest.NewRequest("GET", "/", nil) diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index 6e2e75ace9825..06ee93422bbf9 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -6,10 +6,8 @@ import ( "encoding/json" "fmt" "io" - "net" "net/http" "net/http/httptest" - "slices" "strings" "sync/atomic" "testing" @@ -18,12 +16,13 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" "golang.org/x/oauth2" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" @@ -83,9 +82,9 @@ func TestAPIKey(t *testing.T) { t.Run("NoCookie", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ DB: db, @@ -99,9 +98,9 @@ func TestAPIKey(t *testing.T) { t.Run("NoCookieRedirects", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ DB: db, @@ -118,9 +117,9 @@ func TestAPIKey(t *testing.T) { t.Run("InvalidFormat", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) r.Header.Set(codersdk.SessionTokenHeader, "test-wow-hello") @@ -136,9 +135,9 @@ func TestAPIKey(t *testing.T) { t.Run("InvalidIDLength", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) r.Header.Set(codersdk.SessionTokenHeader, "test-wow") @@ -154,9 +153,9 @@ func TestAPIKey(t *testing.T) { t.Run("InvalidSecretLength", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) r.Header.Set(codersdk.SessionTokenHeader, "testtestid-wow") @@ -172,7 +171,7 @@ func TestAPIKey(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) id, secret = randomAPIKeyParts() r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() @@ -191,10 +190,10 @@ func TestAPIKey(t *testing.T) { t.Run("UserLinkNotFound", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = dbgen.User(t, db, database.User{ + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + user = dbgen.User(t, db, database.User{ LoginType: database.LoginTypeGithub, }) // Intentionally not inserting any user link @@ -219,10 +218,10 @@ func TestAPIKey(t *testing.T) { t.Run("InvalidSecret", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = dbgen.User(t, db, database.User{}) + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + user = dbgen.User(t, db, database.User{}) // Use a different secret so they don't match! hashed = sha256.Sum256([]byte("differentsecret")) @@ -244,7 +243,7 @@ func TestAPIKey(t *testing.T) { t.Run("Expired", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) user = dbgen.User(t, db, database.User{}) _, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -273,7 +272,7 @@ func TestAPIKey(t *testing.T) { t.Run("Valid", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -309,7 +308,7 @@ func TestAPIKey(t *testing.T) { t.Run("ValidWithScope", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) user = dbgen.User(t, db, database.User{}) _, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -347,7 +346,7 @@ func TestAPIKey(t *testing.T) { t.Run("QueryParameter", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) user = dbgen.User(t, db, database.User{}) _, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -381,7 +380,7 @@ func TestAPIKey(t *testing.T) { t.Run("ValidUpdateLastUsed", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -412,7 +411,7 @@ func TestAPIKey(t *testing.T) { t.Run("ValidUpdateExpiry", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -443,7 +442,7 @@ func TestAPIKey(t *testing.T) { t.Run("NoRefresh", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -475,7 +474,7 @@ func TestAPIKey(t *testing.T) { t.Run("OAuthNotExpired", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -511,7 +510,7 @@ func TestAPIKey(t *testing.T) { t.Run("APIKeyExpiredOAuthExpired", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -561,7 +560,7 @@ func TestAPIKey(t *testing.T) { t.Run("APIKeyExpiredOAuthNotExpired", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -607,7 +606,7 @@ func TestAPIKey(t *testing.T) { t.Run("OAuthRefresh", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -630,7 +629,7 @@ func TestAPIKey(t *testing.T) { oauthToken := &oauth2.Token{ AccessToken: "wow", RefreshToken: "moo", - Expiry: dbtime.Now().AddDate(0, 0, 1), + Expiry: dbtestutil.NowInDefaultTimezone().AddDate(0, 0, 1), } httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ DB: db, @@ -665,7 +664,7 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( ctx = testutil.Context(t, testutil.WaitShort) - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -715,7 +714,7 @@ func TestAPIKey(t *testing.T) { t.Run("RemoteIPUpdates", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -740,15 +739,15 @@ func TestAPIKey(t *testing.T) { gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) require.NoError(t, err) - require.Equal(t, net.ParseIP("1.1.1.1"), gotAPIKey.IPAddress.IPNet.IP) + require.Equal(t, "1.1.1.1", gotAPIKey.IPAddress.IPNet.IP.String()) }) t.Run("RedirectToLogin", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ @@ -767,9 +766,9 @@ func TestAPIKey(t *testing.T) { t.Run("Optional", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() count int64 handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -798,7 +797,7 @@ func TestAPIKey(t *testing.T) { t.Run("Tokens", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) user = dbgen.User(t, db, database.User{}) sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -831,7 +830,7 @@ func TestAPIKey(t *testing.T) { t.Run("MissingConfig", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) user = dbgen.User(t, db, database.User{}) _, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -866,7 +865,7 @@ func TestAPIKey(t *testing.T) { t.Run("CustomRoles", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) org = dbgen.Organization(t, db, database.Organization{}) customRole = dbgen.CustomRole(t, db, database.CustomRole{ Name: "custom-role", @@ -933,7 +932,7 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( roleNotExistsName = "role-not-exists" - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) org = dbgen.Organization(t, db, database.Organization{}) user = dbgen.User(t, db, database.User{ RBACRoles: []string{ diff --git a/coderd/httpmw/chat_test.go b/coderd/httpmw/chat_test.go index a8bad05f33797..3acc2db8b9877 100644 --- a/coderd/httpmw/chat_test.go +++ b/coderd/httpmw/chat_test.go @@ -14,7 +14,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" @@ -40,10 +40,10 @@ func TestExtractChat(t *testing.T) { t.Run("None", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - rw = httptest.NewRecorder() - r, _ = setupAuthentication(db) - rtr = chi.NewRouter() + db, _ = dbtestutil.NewDB(t) + rw = httptest.NewRecorder() + r, _ = setupAuthentication(db) + rtr = chi.NewRouter() ) rtr.Use( httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ @@ -62,10 +62,10 @@ func TestExtractChat(t *testing.T) { t.Run("InvalidUUID", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - rw = httptest.NewRecorder() - r, _ = setupAuthentication(db) - rtr = chi.NewRouter() + db, _ = dbtestutil.NewDB(t) + rw = httptest.NewRecorder() + r, _ = setupAuthentication(db) + rtr = chi.NewRouter() ) chi.RouteContext(r.Context()).URLParams.Add("chat", "not-a-uuid") rtr.Use( @@ -85,10 +85,10 @@ func TestExtractChat(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - rw = httptest.NewRecorder() - r, _ = setupAuthentication(db) - rtr = chi.NewRouter() + db, _ = dbtestutil.NewDB(t) + rw = httptest.NewRecorder() + r, _ = setupAuthentication(db) + rtr = chi.NewRouter() ) chi.RouteContext(r.Context()).URLParams.Add("chat", uuid.NewString()) rtr.Use( @@ -108,7 +108,7 @@ func TestExtractChat(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) rw = httptest.NewRecorder() r, user = setupAuthentication(db) rtr = chi.NewRouter() diff --git a/coderd/httpmw/groupparam_test.go b/coderd/httpmw/groupparam_test.go index a44fbc52df38b..52cfc05a07947 100644 --- a/coderd/httpmw/groupparam_test.go +++ b/coderd/httpmw/groupparam_test.go @@ -12,7 +12,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/httpmw" ) @@ -23,11 +23,12 @@ func TestGroupParam(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - group = dbgen.Group(t, db, database.Group{}) + db, _ = dbtestutil.NewDB(t) r = httptest.NewRequest("GET", "/", nil) w = httptest.NewRecorder() ) + dbtestutil.DisableForeignKeysAndTriggers(t, db) + group := dbgen.Group(t, db, database.Group{}) router := chi.NewRouter() router.Use(httpmw.ExtractGroupParam(db)) @@ -52,11 +53,12 @@ func TestGroupParam(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - group = dbgen.Group(t, db, database.Group{}) + db, _ = dbtestutil.NewDB(t) r = httptest.NewRequest("GET", "/", nil) w = httptest.NewRecorder() ) + dbtestutil.DisableForeignKeysAndTriggers(t, db) + group := dbgen.Group(t, db, database.Group{}) router := chi.NewRouter() router.Use(httpmw.ExtractGroupParam(db)) diff --git a/coderd/httpmw/organizationparam_test.go b/coderd/httpmw/organizationparam_test.go index 68cc314abd26f..72101b89ca8aa 100644 --- a/coderd/httpmw/organizationparam_test.go +++ b/coderd/httpmw/organizationparam_test.go @@ -13,7 +13,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/rbac" @@ -42,10 +42,10 @@ func TestOrganizationParam(t *testing.T) { t.Run("None", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - rw = httptest.NewRecorder() - r, _ = setupAuthentication(db) - rtr = chi.NewRouter() + db, _ = dbtestutil.NewDB(t) + rw = httptest.NewRecorder() + r, _ = setupAuthentication(db) + rtr = chi.NewRouter() ) rtr.Use( httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ @@ -64,10 +64,10 @@ func TestOrganizationParam(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - rw = httptest.NewRecorder() - r, _ = setupAuthentication(db) - rtr = chi.NewRouter() + db, _ = dbtestutil.NewDB(t) + rw = httptest.NewRecorder() + r, _ = setupAuthentication(db) + rtr = chi.NewRouter() ) chi.RouteContext(r.Context()).URLParams.Add("organization", uuid.NewString()) rtr.Use( @@ -87,10 +87,10 @@ func TestOrganizationParam(t *testing.T) { t.Run("InvalidUUID", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - rw = httptest.NewRecorder() - r, _ = setupAuthentication(db) - rtr = chi.NewRouter() + db, _ = dbtestutil.NewDB(t) + rw = httptest.NewRecorder() + r, _ = setupAuthentication(db) + rtr = chi.NewRouter() ) chi.RouteContext(r.Context()).URLParams.Add("organization", "not-a-uuid") rtr.Use( @@ -110,10 +110,10 @@ func TestOrganizationParam(t *testing.T) { t.Run("NotInOrganization", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - rw = httptest.NewRecorder() - r, u = setupAuthentication(db) - rtr = chi.NewRouter() + db, _ = dbtestutil.NewDB(t) + rw = httptest.NewRecorder() + r, u = setupAuthentication(db) + rtr = chi.NewRouter() ) organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{ ID: uuid.New(), @@ -144,7 +144,7 @@ func TestOrganizationParam(t *testing.T) { t.Parallel() var ( ctx = testutil.Context(t, testutil.WaitShort) - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) rw = httptest.NewRecorder() r, user = setupAuthentication(db) rtr = chi.NewRouter() diff --git a/coderd/httpmw/ratelimit_test.go b/coderd/httpmw/ratelimit_test.go index 1dd12da89df1a..51a05940fcbe7 100644 --- a/coderd/httpmw/ratelimit_test.go +++ b/coderd/httpmw/ratelimit_test.go @@ -14,7 +14,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" ) @@ -70,7 +70,7 @@ func TestRateLimit(t *testing.T) { t.Run("RegularUser", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) u := dbgen.User(t, db, database.User{}) _, key := dbgen.APIKey(t, db, database.APIKey{UserID: u.ID}) @@ -113,7 +113,7 @@ func TestRateLimit(t *testing.T) { t.Run("OwnerBypass", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) u := dbgen.User(t, db, database.User{ RBACRoles: []string{codersdk.RoleOwner}, diff --git a/coderd/httpmw/templateparam_test.go b/coderd/httpmw/templateparam_test.go index 18b0b2f584e5f..49a97b5af76ea 100644 --- a/coderd/httpmw/templateparam_test.go +++ b/coderd/httpmw/templateparam_test.go @@ -12,7 +12,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" ) @@ -43,7 +43,7 @@ func TestTemplateParam(t *testing.T) { t.Run("None", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractTemplateParam(db)) rtr.Get("/", nil) @@ -58,7 +58,7 @@ func TestTemplateParam(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractTemplateParam(db)) rtr.Get("/", nil) @@ -75,7 +75,7 @@ func TestTemplateParam(t *testing.T) { t.Run("BadUUID", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractTemplateParam(db)) rtr.Get("/", nil) @@ -92,7 +92,8 @@ func TestTemplateParam(t *testing.T) { t.Run("Template", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) + dbtestutil.DisableForeignKeysAndTriggers(t, db) rtr := chi.NewRouter() rtr.Use( httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ diff --git a/coderd/httpmw/templateversionparam_test.go b/coderd/httpmw/templateversionparam_test.go index 3f67aafbcf191..06594322cacac 100644 --- a/coderd/httpmw/templateversionparam_test.go +++ b/coderd/httpmw/templateversionparam_test.go @@ -12,7 +12,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" ) @@ -21,6 +21,7 @@ func TestTemplateVersionParam(t *testing.T) { t.Parallel() setupAuthentication := func(db database.Store) (*http.Request, database.Template) { + dbtestutil.DisableForeignKeysAndTriggers(nil, db) user := dbgen.User(t, db, database.User{}) _, token := dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -47,7 +48,7 @@ func TestTemplateVersionParam(t *testing.T) { t.Run("None", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractTemplateVersionParam(db)) rtr.Get("/", nil) @@ -62,7 +63,7 @@ func TestTemplateVersionParam(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractTemplateVersionParam(db)) rtr.Get("/", nil) @@ -79,7 +80,7 @@ func TestTemplateVersionParam(t *testing.T) { t.Run("TemplateVersion", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rtr := chi.NewRouter() rtr.Use( httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ diff --git a/coderd/httpmw/userparam_test.go b/coderd/httpmw/userparam_test.go index bda00193e9a24..4c1fdd3458acd 100644 --- a/coderd/httpmw/userparam_test.go +++ b/coderd/httpmw/userparam_test.go @@ -11,7 +11,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" ) @@ -20,9 +20,9 @@ func TestUserParam(t *testing.T) { t.Parallel() setup := func(t *testing.T) (database.Store, *httptest.ResponseRecorder, *http.Request) { var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) user := dbgen.User(t, db, database.User{}) _, token := dbgen.APIKey(t, db, database.APIKey{ diff --git a/coderd/httpmw/workspaceagentparam_test.go b/coderd/httpmw/workspaceagentparam_test.go index 51e55b81e20a7..a9d6130966f5b 100644 --- a/coderd/httpmw/workspaceagentparam_test.go +++ b/coderd/httpmw/workspaceagentparam_test.go @@ -16,7 +16,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" ) @@ -67,7 +67,8 @@ func TestWorkspaceAgentParam(t *testing.T) { t.Run("None", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) + dbtestutil.DisableForeignKeysAndTriggers(t, db) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractWorkspaceBuildParam(db)) rtr.Get("/", nil) @@ -82,7 +83,8 @@ func TestWorkspaceAgentParam(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) + dbtestutil.DisableForeignKeysAndTriggers(t, db) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractWorkspaceAgentParam(db)) rtr.Get("/", nil) @@ -99,7 +101,8 @@ func TestWorkspaceAgentParam(t *testing.T) { t.Run("NotAuthorized", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) + dbtestutil.DisableForeignKeysAndTriggers(t, db) fakeAuthz := (&coderdtest.FakeAuthorizer{}).AlwaysReturn(xerrors.Errorf("constant failure")) dbFail := dbauthz.New(db, fakeAuthz, slog.Make(), coderdtest.AccessControlStorePointer()) @@ -129,7 +132,8 @@ func TestWorkspaceAgentParam(t *testing.T) { t.Run("WorkspaceAgent", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) + dbtestutil.DisableForeignKeysAndTriggers(t, db) rtr := chi.NewRouter() rtr.Use( httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ diff --git a/coderd/httpmw/workspacebuildparam_test.go b/coderd/httpmw/workspacebuildparam_test.go index e4bd4d10dafb2..b2469d07a52a9 100644 --- a/coderd/httpmw/workspacebuildparam_test.go +++ b/coderd/httpmw/workspacebuildparam_test.go @@ -12,7 +12,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" ) @@ -26,8 +26,15 @@ func TestWorkspaceBuildParam(t *testing.T) { _, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, }) + org = dbgen.Organization(t, db, database.Organization{}) + tpl = dbgen.Template(t, db, database.Template{ + OrganizationID: org.ID, + CreatedBy: user.ID, + }) workspace = dbgen.Workspace(t, db, database.WorkspaceTable{ - OwnerID: user.ID, + OwnerID: user.ID, + OrganizationID: org.ID, + TemplateID: tpl.ID, }) ) @@ -43,7 +50,7 @@ func TestWorkspaceBuildParam(t *testing.T) { t.Run("None", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractWorkspaceBuildParam(db)) rtr.Get("/", nil) @@ -58,7 +65,7 @@ func TestWorkspaceBuildParam(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractWorkspaceBuildParam(db)) rtr.Get("/", nil) @@ -75,7 +82,7 @@ func TestWorkspaceBuildParam(t *testing.T) { t.Run("WorkspaceBuild", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rtr := chi.NewRouter() rtr.Use( httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ @@ -91,10 +98,21 @@ func TestWorkspaceBuildParam(t *testing.T) { }) r, workspace := setupAuthentication(db) + tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{ + UUID: workspace.TemplateID, + Valid: true, + }, + OrganizationID: workspace.OrganizationID, + CreatedBy: workspace.OwnerID, + }) + pj := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{}) workspaceBuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ - Transition: database.WorkspaceTransitionStart, - Reason: database.BuildReasonInitiator, - WorkspaceID: workspace.ID, + JobID: pj.ID, + TemplateVersionID: tv.ID, + Transition: database.WorkspaceTransitionStart, + Reason: database.BuildReasonInitiator, + WorkspaceID: workspace.ID, }) chi.RouteContext(r.Context()).URLParams.Add("workspacebuild", workspaceBuild.ID.String()) diff --git a/coderd/httpmw/workspaceparam_test.go b/coderd/httpmw/workspaceparam_test.go index 81f47d135f6ee..33b0c753068f7 100644 --- a/coderd/httpmw/workspaceparam_test.go +++ b/coderd/httpmw/workspaceparam_test.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "encoding/json" "fmt" + "net" "net/http" "net/http/httptest" "testing" @@ -12,12 +13,13 @@ import ( "github.com/go-chi/chi/v5" "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" @@ -46,6 +48,7 @@ func TestWorkspaceParam(t *testing.T) { CreatedAt: dbtime.Now(), UpdatedAt: dbtime.Now(), LoginType: database.LoginTypePassword, + RBACRoles: []string{}, }) require.NoError(t, err) @@ -64,6 +67,13 @@ func TestWorkspaceParam(t *testing.T) { ExpiresAt: dbtime.Now().Add(time.Minute), LoginType: database.LoginTypePassword, Scope: database.APIKeyScopeAll, + IPAddress: pqtype.Inet{ + IPNet: net.IPNet{ + IP: net.IPv4(127, 0, 0, 1), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + Valid: true, + }, }) require.NoError(t, err) @@ -75,7 +85,7 @@ func TestWorkspaceParam(t *testing.T) { t.Run("None", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractWorkspaceParam(db)) rtr.Get("/", nil) @@ -90,7 +100,7 @@ func TestWorkspaceParam(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractWorkspaceParam(db)) rtr.Get("/", nil) @@ -106,7 +116,7 @@ func TestWorkspaceParam(t *testing.T) { t.Run("Found", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rtr := chi.NewRouter() rtr.Use( httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ @@ -120,11 +130,18 @@ func TestWorkspaceParam(t *testing.T) { rw.WriteHeader(http.StatusOK) }) r, user := setup(db) + org := dbgen.Organization(t, db, database.Organization{}) + tpl := dbgen.Template(t, db, database.Template{ + OrganizationID: org.ID, + CreatedBy: user.ID, + }) workspace, err := db.InsertWorkspace(context.Background(), database.InsertWorkspaceParams{ ID: uuid.New(), OwnerID: user.ID, Name: "hello", AutomaticUpdates: database.AutomaticUpdatesNever, + OrganizationID: org.ID, + TemplateID: tpl.ID, }) require.NoError(t, err) chi.RouteContext(r.Context()).URLParams.Add("workspace", workspace.ID.String()) @@ -348,28 +365,45 @@ type setupConfig struct { func setupWorkspaceWithAgents(t testing.TB, cfg setupConfig) (database.Store, *http.Request) { t.Helper() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) var ( user = dbgen.User(t, db, database.User{}) _, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, }) - workspace = dbgen.Workspace(t, db, database.WorkspaceTable{ - OwnerID: user.ID, - Name: cfg.WorkspaceName, + org = dbgen.Organization(t, db, database.Organization{}) + tpl = dbgen.Template(t, db, database.Template{ + OrganizationID: org.ID, + CreatedBy: user.ID, }) - build = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ - WorkspaceID: workspace.ID, - Transition: database.WorkspaceTransitionStart, - Reason: database.BuildReasonInitiator, + workspace = dbgen.Workspace(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + TemplateID: tpl.ID, + Name: cfg.WorkspaceName, }) job = dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{ - ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild, Provisioner: database.ProvisionerTypeEcho, StorageMethod: database.ProvisionerStorageMethodFile, }) + tv = dbgen.TemplateVersion(t, db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{ + UUID: tpl.ID, + Valid: true, + }, + JobID: job.ID, + OrganizationID: org.ID, + CreatedBy: user.ID, + }) + _ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + JobID: job.ID, + WorkspaceID: workspace.ID, + Transition: database.WorkspaceTransitionStart, + Reason: database.BuildReasonInitiator, + TemplateVersionID: tv.ID, + }) ) r := httptest.NewRequest("GET", "/", nil) diff --git a/coderd/httpmw/workspaceproxy_test.go b/coderd/httpmw/workspaceproxy_test.go index b0a028f3caee8..f35b97722ccd4 100644 --- a/coderd/httpmw/workspaceproxy_test.go +++ b/coderd/httpmw/workspaceproxy_test.go @@ -13,7 +13,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" @@ -33,9 +33,9 @@ func TestExtractWorkspaceProxy(t *testing.T) { t.Run("NoHeader", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) httpmw.ExtractWorkspaceProxy(httpmw.ExtractWorkspaceProxyConfig{ DB: db, @@ -48,9 +48,9 @@ func TestExtractWorkspaceProxy(t *testing.T) { t.Run("InvalidFormat", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) r.Header.Set(httpmw.WorkspaceProxyAuthTokenHeader, "test:wow-hello") @@ -65,9 +65,9 @@ func TestExtractWorkspaceProxy(t *testing.T) { t.Run("InvalidID", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) r.Header.Set(httpmw.WorkspaceProxyAuthTokenHeader, "test:wow") @@ -82,9 +82,9 @@ func TestExtractWorkspaceProxy(t *testing.T) { t.Run("InvalidSecretLength", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) r.Header.Set(httpmw.WorkspaceProxyAuthTokenHeader, fmt.Sprintf("%s:%s", uuid.NewString(), "wow")) @@ -99,9 +99,9 @@ func TestExtractWorkspaceProxy(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) secret, err := cryptorand.HexString(64) @@ -119,9 +119,9 @@ func TestExtractWorkspaceProxy(t *testing.T) { t.Run("InvalidSecret", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() proxy, _ = dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{}) ) @@ -142,9 +142,9 @@ func TestExtractWorkspaceProxy(t *testing.T) { t.Run("Valid", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() proxy, secret = dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{}) ) @@ -165,9 +165,9 @@ func TestExtractWorkspaceProxy(t *testing.T) { t.Run("Deleted", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() proxy, secret = dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{}) ) @@ -201,9 +201,9 @@ func TestExtractWorkspaceProxyParam(t *testing.T) { t.Run("OKName", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() proxy, _ = dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{}) ) @@ -225,9 +225,9 @@ func TestExtractWorkspaceProxyParam(t *testing.T) { t.Run("OKID", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() proxy, _ = dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{}) ) @@ -249,9 +249,9 @@ func TestExtractWorkspaceProxyParam(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db, _ = dbtestutil.NewDB(t) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) routeContext := chi.NewRouteContext() @@ -267,7 +267,7 @@ func TestExtractWorkspaceProxyParam(t *testing.T) { t.Run("FetchPrimary", func(t *testing.T) { t.Parallel() var ( - db = dbmem.New() + db, _ = dbtestutil.NewDB(t) r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() deploymentID = uuid.New() diff --git a/coderd/httpmw/workspaceresourceparam_test.go b/coderd/httpmw/workspaceresourceparam_test.go index 9549e8e6d3ecf..f6cb0772d262a 100644 --- a/coderd/httpmw/workspaceresourceparam_test.go +++ b/coderd/httpmw/workspaceresourceparam_test.go @@ -12,7 +12,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbmem" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/httpmw" ) @@ -21,6 +21,7 @@ func TestWorkspaceResourceParam(t *testing.T) { setup := func(t *testing.T, db database.Store, jobType database.ProvisionerJobType) (*http.Request, database.WorkspaceResource) { r := httptest.NewRequest("GET", "/", nil) + dbtestutil.DisableForeignKeysAndTriggers(t, db) job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{ Type: jobType, Provisioner: database.ProvisionerTypeEcho, @@ -46,7 +47,7 @@ func TestWorkspaceResourceParam(t *testing.T) { t.Run("None", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractWorkspaceResourceParam(db)) rtr.Get("/", nil) @@ -61,7 +62,7 @@ func TestWorkspaceResourceParam(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rtr := chi.NewRouter() rtr.Use( httpmw.ExtractWorkspaceResourceParam(db), @@ -80,7 +81,7 @@ func TestWorkspaceResourceParam(t *testing.T) { t.Run("FoundBadJobType", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rtr := chi.NewRouter() rtr.Use( httpmw.ExtractWorkspaceResourceParam(db), @@ -102,7 +103,7 @@ func TestWorkspaceResourceParam(t *testing.T) { t.Run("Found", func(t *testing.T) { t.Parallel() - db := dbmem.New() + db, _ := dbtestutil.NewDB(t) rtr := chi.NewRouter() rtr.Use( httpmw.ExtractWorkspaceResourceParam(db),