diff --git a/coderd/database/dbgen/generator.go b/coderd/database/dbgen/generator.go new file mode 100644 index 0000000000000..2d3c420fc6784 --- /dev/null +++ b/coderd/database/dbgen/generator.go @@ -0,0 +1,225 @@ +package dbgen + +import ( + "context" + "crypto/sha256" + "database/sql" + "encoding/hex" + "fmt" + "testing" + "time" + + "github.com/coder/coder/cryptorand" + "github.com/tabbed/pqtype" + + "github.com/coder/coder/coderd/database" + "github.com/google/uuid" + "github.com/moby/moby/pkg/namesgenerator" + "github.com/stretchr/testify/require" +) + +// All methods take in a 'seed' object. Any provided fields in the seed will be +// maintained. Any fields omitted will have sensible defaults generated. + +func Template(t *testing.T, db database.Store, seed database.Template) database.Template { + template, err := db.InsertTemplate(context.Background(), database.InsertTemplateParams{ + ID: takeFirst(seed.ID, uuid.New()), + CreatedAt: takeFirst(seed.CreatedAt, time.Now()), + UpdatedAt: takeFirst(seed.UpdatedAt, time.Now()), + OrganizationID: takeFirst(seed.OrganizationID, uuid.New()), + Name: takeFirst(seed.Name, namesgenerator.GetRandomName(1)), + Provisioner: takeFirst(seed.Provisioner, database.ProvisionerTypeEcho), + ActiveVersionID: takeFirst(seed.ActiveVersionID, uuid.New()), + Description: takeFirst(seed.Description, namesgenerator.GetRandomName(1)), + DefaultTTL: takeFirst(seed.DefaultTTL, 3600), + CreatedBy: takeFirst(seed.CreatedBy, uuid.New()), + Icon: takeFirst(seed.Icon, namesgenerator.GetRandomName(1)), + UserACL: seed.UserACL, + GroupACL: seed.GroupACL, + DisplayName: takeFirst(seed.DisplayName, namesgenerator.GetRandomName(1)), + AllowUserCancelWorkspaceJobs: takeFirst(seed.AllowUserCancelWorkspaceJobs, true), + }) + require.NoError(t, err, "insert template") + return template +} + +func APIKey(t *testing.T, db database.Store, seed database.APIKey) (key database.APIKey, token string) { + id, _ := cryptorand.String(10) + secret, _ := cryptorand.String(22) + hashed := sha256.Sum256([]byte(secret)) + + key, err := db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ + ID: takeFirst(seed.ID, id), + // 0 defaults to 86400 at the db layer + LifetimeSeconds: takeFirst(seed.LifetimeSeconds, 0), + HashedSecret: takeFirstBytes(seed.HashedSecret, hashed[:]), + IPAddress: pqtype.Inet{}, + UserID: takeFirst(seed.UserID, uuid.New()), + LastUsed: takeFirst(seed.LastUsed, time.Now()), + ExpiresAt: takeFirst(seed.ExpiresAt, time.Now().Add(time.Hour)), + CreatedAt: takeFirst(seed.CreatedAt, time.Now()), + UpdatedAt: takeFirst(seed.UpdatedAt, time.Now()), + LoginType: takeFirst(seed.LoginType, database.LoginTypePassword), + Scope: takeFirst(seed.Scope, database.APIKeyScopeAll), + }) + require.NoError(t, err, "insert api key") + return key, fmt.Sprintf("%s-%s", key.ID, secret) +} + +func Workspace(t *testing.T, db database.Store, orig database.Workspace) database.Workspace { + workspace, err := db.InsertWorkspace(context.Background(), database.InsertWorkspaceParams{ + ID: takeFirst(orig.ID, uuid.New()), + OwnerID: takeFirst(orig.OwnerID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), + TemplateID: takeFirst(orig.TemplateID, uuid.New()), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + AutostartSchedule: orig.AutostartSchedule, + Ttl: orig.Ttl, + }) + require.NoError(t, err, "insert workspace") + return workspace +} + +func WorkspaceBuild(t *testing.T, db database.Store, orig database.WorkspaceBuild) database.WorkspaceBuild { + build, err := db.InsertWorkspaceBuild(context.Background(), database.InsertWorkspaceBuildParams{ + ID: takeFirst(orig.ID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + WorkspaceID: takeFirst(orig.WorkspaceID, uuid.New()), + TemplateVersionID: takeFirst(orig.TemplateVersionID, uuid.New()), + BuildNumber: takeFirst(orig.BuildNumber, 0), + Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart), + InitiatorID: takeFirst(orig.InitiatorID, uuid.New()), + JobID: takeFirst(orig.JobID, uuid.New()), + ProvisionerState: takeFirstBytes(orig.ProvisionerState, []byte{}), + Deadline: takeFirst(orig.Deadline, time.Now().Add(time.Hour)), + Reason: takeFirst(orig.Reason, database.BuildReasonInitiator), + }) + require.NoError(t, err, "insert workspace build") + return build +} + +func User(t *testing.T, db database.Store, orig database.User) database.User { + user, err := db.InsertUser(context.Background(), database.InsertUserParams{ + ID: takeFirst(orig.ID, uuid.New()), + Email: takeFirst(orig.Email, namesgenerator.GetRandomName(1)), + Username: takeFirst(orig.Username, namesgenerator.GetRandomName(1)), + HashedPassword: takeFirstBytes(orig.HashedPassword, []byte{}), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + RBACRoles: []string{}, + LoginType: takeFirst(orig.LoginType, database.LoginTypePassword), + }) + require.NoError(t, err, "insert user") + return user +} + +func Organization(t *testing.T, db database.Store, orig database.Organization) database.Organization { + org, err := db.InsertOrganization(context.Background(), database.InsertOrganizationParams{ + ID: takeFirst(orig.ID, uuid.New()), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + Description: takeFirst(orig.Description, namesgenerator.GetRandomName(1)), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + }) + require.NoError(t, err, "insert organization") + return org +} + +func Group(t *testing.T, db database.Store, orig database.Group) database.Group { + group, err := db.InsertGroup(context.Background(), database.InsertGroupParams{ + ID: takeFirst(orig.ID, uuid.New()), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), + AvatarURL: takeFirst(orig.AvatarURL, "https://logo.example.com"), + QuotaAllowance: takeFirst(orig.QuotaAllowance, 0), + }) + require.NoError(t, err, "insert group") + return group +} + +func ProvisionerJob(t *testing.T, db database.Store, orig database.ProvisionerJob) database.ProvisionerJob { + job, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ + ID: takeFirst(orig.ID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), + InitiatorID: takeFirst(orig.InitiatorID, uuid.New()), + Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho), + StorageMethod: takeFirst(orig.StorageMethod, database.ProvisionerStorageMethodFile), + FileID: takeFirst(orig.FileID, uuid.New()), + Type: takeFirst(orig.Type, database.ProvisionerJobTypeWorkspaceBuild), + Input: takeFirstBytes(orig.Input, []byte("{}")), + Tags: orig.Tags, + }) + require.NoError(t, err, "insert job") + return job +} + +func WorkspaceResource(t *testing.T, db database.Store, orig database.WorkspaceResource) database.WorkspaceResource { + resource, err := db.InsertWorkspaceResource(context.Background(), database.InsertWorkspaceResourceParams{ + ID: takeFirst(orig.ID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + JobID: takeFirst(orig.JobID, uuid.New()), + Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart), + Type: takeFirst(orig.Type, "fake_resource"), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + Hide: takeFirst(orig.Hide, false), + Icon: takeFirst(orig.Icon, ""), + InstanceType: sql.NullString{ + String: takeFirst(orig.InstanceType.String, ""), + Valid: takeFirst(orig.InstanceType.Valid, false), + }, + DailyCost: takeFirst(orig.DailyCost, 0), + }) + require.NoError(t, err, "insert resource") + return resource +} + +func File(t *testing.T, db database.Store, orig database.File) database.File { + file, err := db.InsertFile(context.Background(), database.InsertFileParams{ + ID: takeFirst(orig.ID, uuid.New()), + Hash: takeFirst(orig.Hash, hex.EncodeToString(make([]byte, 32))), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), + Mimetype: takeFirst(orig.Mimetype, "application/x-tar"), + Data: takeFirstBytes(orig.Data, []byte{}), + }) + require.NoError(t, err, "insert file") + return file +} + +func UserLink(t *testing.T, db database.Store, orig database.UserLink) database.UserLink { + link, err := db.InsertUserLink(context.Background(), database.InsertUserLinkParams{ + UserID: takeFirst(orig.UserID, uuid.New()), + LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub), + LinkedID: takeFirst(orig.LinkedID), + OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), + OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), + OAuthExpiry: takeFirst(orig.OAuthExpiry, time.Now().Add(time.Hour*24)), + }) + + require.NoError(t, err, "insert link") + return link +} + +func TemplateVersion(t *testing.T, db database.Store, orig database.TemplateVersion) database.TemplateVersion { + version, err := db.InsertTemplateVersion(context.Background(), database.InsertTemplateVersionParams{ + ID: takeFirst(orig.ID, uuid.New()), + TemplateID: uuid.NullUUID{ + UUID: takeFirst(orig.TemplateID.UUID, uuid.New()), + Valid: takeFirst(orig.TemplateID.Valid, true), + }, + OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + Readme: takeFirst(orig.Readme, namesgenerator.GetRandomName(1)), + JobID: takeFirst(orig.JobID, uuid.New()), + CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), + }) + require.NoError(t, err, "insert template version") + return version +} diff --git a/coderd/database/dbgen/generator_test.go b/coderd/database/dbgen/generator_test.go new file mode 100644 index 0000000000000..2266c866dbd09 --- /dev/null +++ b/coderd/database/dbgen/generator_test.go @@ -0,0 +1,107 @@ +package dbgen_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/dbgen" +) + +func TestGenerator(t *testing.T) { + t.Parallel() + + t.Run("APIKey", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + exp, _ := dbgen.APIKey(t, db, database.APIKey{}) + require.Equal(t, exp, must(db.GetAPIKeyByID(context.Background(), exp.ID))) + }) + + t.Run("File", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + exp := dbgen.File(t, db, database.File{}) + require.Equal(t, exp, must(db.GetFileByID(context.Background(), exp.ID))) + }) + + t.Run("UserLink", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + exp := dbgen.UserLink(t, db, database.UserLink{}) + require.Equal(t, exp, must(db.GetUserLinkByLinkedID(context.Background(), exp.LinkedID))) + }) + + t.Run("WorkspaceResource", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + exp := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{}) + require.Equal(t, exp, must(db.GetWorkspaceResourceByID(context.Background(), exp.ID))) + }) + + t.Run("Job", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + exp := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{}) + require.Equal(t, exp, must(db.GetProvisionerJobByID(context.Background(), exp.ID))) + }) + + t.Run("Group", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + exp := dbgen.Group(t, db, database.Group{}) + require.Equal(t, exp, must(db.GetGroupByID(context.Background(), exp.ID))) + }) + + t.Run("Organization", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + exp := dbgen.Organization(t, db, database.Organization{}) + require.Equal(t, exp, must(db.GetOrganizationByID(context.Background(), exp.ID))) + }) + + t.Run("Workspace", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + exp := dbgen.Workspace(t, db, database.Workspace{}) + require.Equal(t, exp, must(db.GetWorkspaceByID(context.Background(), exp.ID))) + }) + + t.Run("Template", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + exp := dbgen.Template(t, db, database.Template{}) + require.Equal(t, exp, must(db.GetTemplateByID(context.Background(), exp.ID))) + }) + + t.Run("TemplateVersion", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + exp := dbgen.TemplateVersion(t, db, database.TemplateVersion{}) + require.Equal(t, exp, must(db.GetTemplateVersionByID(context.Background(), exp.ID))) + }) + + t.Run("WorkspaceBuild", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + exp := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{}) + require.Equal(t, exp, must(db.GetWorkspaceBuildByID(context.Background(), exp.ID))) + }) + + t.Run("User", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + exp := dbgen.User(t, db, database.User{}) + require.Equal(t, exp, must(db.GetUserByID(context.Background(), exp.ID))) + }) +} + +func must[T any](value T, err error) T { + if err != nil { + panic(err) + } + return value +} diff --git a/coderd/database/dbgen/take.go b/coderd/database/dbgen/take.go new file mode 100644 index 0000000000000..717f2c0441cc3 --- /dev/null +++ b/coderd/database/dbgen/take.go @@ -0,0 +1,29 @@ +package dbgen + +// takeFirstBytes implements takeFirst for []byte. +// []byte is not a comparable type. +func takeFirstBytes(values ...[]byte) []byte { + return takeFirstF(values, func(v []byte) bool { + return len(v) != 0 + }) +} + +// takeFirstF takes the first value that returns true +func takeFirstF[Value any](values []Value, take func(v Value) bool) Value { + var empty Value + for _, v := range values { + if take(v) { + return v + } + } + // If all empty, return empty + return empty +} + +// takeFirst will take the first non-empty value. +func takeFirst[Value comparable](values ...Value) Value { + var empty Value + return takeFirstF(values, func(v Value) bool { + return v != empty + }) +} diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index c4266646ed442..425999eb9f6f9 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -11,13 +11,13 @@ import ( "testing" "time" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/dbgen" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/codersdk" @@ -151,24 +151,19 @@ func TestAPIKey(t *testing.T) { t.Run("InvalidSecret", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + user = dbgen.User(t, db, database.User{}) + + // Use a different secret so they don't match! + hashed = sha256.Sum256([]byte("differentsecret")) + _, token = dbgen.APIKey(t, db, database.APIKey{ + UserID: user.ID, + HashedSecret: hashed[:], + }) ) - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) - - // Use a different secret so they don't match! - hashed := sha256.Sum256([]byte("differentsecret")) - _, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - UserID: user.ID, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) + r.Header.Set(codersdk.SessionTokenHeader, token) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, RedirectToLogin: false, @@ -181,23 +176,18 @@ func TestAPIKey(t *testing.T) { t.Run("Expired", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) + db = databasefake.New() + user = dbgen.User(t, db, database.User{}) + _, token = dbgen.APIKey(t, db, database.APIKey{ + UserID: user.ID, + ExpiresAt: time.Now().Add(time.Hour * -1), + }) + + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) + r.Header.Set(codersdk.SessionTokenHeader, token) - _, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - UserID: user.ID, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, RedirectToLogin: false, @@ -210,24 +200,18 @@ func TestAPIKey(t *testing.T) { t.Run("Valid", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) + db = databasefake.New() + user = dbgen.User(t, db, database.User{}) + sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ + UserID: user.ID, + ExpiresAt: database.Now().AddDate(0, 0, 1), + }) + + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) + r.Header.Set(codersdk.SessionTokenHeader, token) - sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - ExpiresAt: database.Now().AddDate(0, 0, 1), - UserID: user.ID, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, RedirectToLogin: false, @@ -242,7 +226,7 @@ func TestAPIKey(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) - gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) require.NoError(t, err) require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) @@ -251,27 +235,21 @@ func TestAPIKey(t *testing.T) { t.Run("ValidWithScope", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) + db = databasefake.New() + user = dbgen.User(t, db, database.User{}) + _, token = dbgen.APIKey(t, db, database.APIKey{ + UserID: user.ID, + ExpiresAt: database.Now().AddDate(0, 0, 1), + Scope: database.APIKeyScopeApplicationConnect, + }) + + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) r.AddCookie(&http.Cookie{ Name: codersdk.SessionTokenCookie, - Value: fmt.Sprintf("%s-%s", id, secret), - }) - - _, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - UserID: user.ID, - HashedSecret: hashed[:], - ExpiresAt: database.Now().AddDate(0, 0, 1), - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeApplicationConnect, + Value: token, }) - require.NoError(t, err) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, @@ -294,26 +272,20 @@ func TestAPIKey(t *testing.T) { t.Run("QueryParameter", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) + db = databasefake.New() + user = dbgen.User(t, db, database.User{}) + _, token = dbgen.APIKey(t, db, database.APIKey{ + UserID: user.ID, + ExpiresAt: database.Now().AddDate(0, 0, 1), + }) + + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) q := r.URL.Query() - q.Add(codersdk.SessionTokenCookie, fmt.Sprintf("%s-%s", id, secret)) + q.Add(codersdk.SessionTokenCookie, token) r.URL.RawQuery = q.Encode() - _, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - ExpiresAt: database.Now().AddDate(0, 0, 1), - UserID: user.ID, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, RedirectToLogin: false, @@ -332,25 +304,19 @@ func TestAPIKey(t *testing.T) { t.Run("ValidUpdateLastUsed", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) + db = databasefake.New() + user = dbgen.User(t, db, database.User{}) + sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ + UserID: user.ID, + LastUsed: database.Now().AddDate(0, 0, -1), + ExpiresAt: database.Now().AddDate(0, 0, 1), + }) + + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) + r.Header.Set(codersdk.SessionTokenHeader, token) - sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - LastUsed: database.Now().AddDate(0, 0, -1), - ExpiresAt: database.Now().AddDate(0, 0, 1), - UserID: user.ID, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, RedirectToLogin: false, @@ -359,7 +325,7 @@ func TestAPIKey(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) - gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) require.NoError(t, err) require.NotEqual(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) @@ -369,25 +335,19 @@ func TestAPIKey(t *testing.T) { t.Run("ValidUpdateExpiry", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) + db = databasefake.New() + user = dbgen.User(t, db, database.User{}) + sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ + UserID: user.ID, + LastUsed: database.Now(), + ExpiresAt: database.Now().Add(time.Minute), + }) + + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) + r.Header.Set(codersdk.SessionTokenHeader, token) - sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - LastUsed: database.Now(), - ExpiresAt: database.Now().Add(time.Minute), - UserID: user.ID, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, RedirectToLogin: false, @@ -396,7 +356,7 @@ func TestAPIKey(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) - gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) require.NoError(t, err) require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) @@ -406,31 +366,23 @@ func TestAPIKey(t *testing.T) { t.Run("OAuthNotExpired", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) - ) - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) - - sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - LoginType: database.LoginTypeGithub, - LastUsed: database.Now(), - ExpiresAt: database.Now().AddDate(0, 0, 1), - UserID: user.ID, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) + db = databasefake.New() + user = dbgen.User(t, db, database.User{}) + sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ + UserID: user.ID, + LastUsed: database.Now(), + ExpiresAt: database.Now().AddDate(0, 0, 1), + LoginType: database.LoginTypeGithub, + }) + _ = dbgen.UserLink(t, db, database.UserLink{ + UserID: user.ID, + LoginType: database.LoginTypeGithub, + }) - _, err = db.InsertUserLink(r.Context(), database.InsertUserLinkParams{ - UserID: user.ID, - LoginType: database.LoginTypeGithub, - }) - require.NoError(t, err) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.Header.Set(codersdk.SessionTokenHeader, token) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, @@ -440,7 +392,7 @@ func TestAPIKey(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) - gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) require.NoError(t, err) require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) @@ -450,33 +402,27 @@ func TestAPIKey(t *testing.T) { t.Run("OAuthRefresh", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) - ) - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) + db = databasefake.New() + user = dbgen.User(t, db, database.User{}) + sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ + UserID: user.ID, + LastUsed: database.Now(), + ExpiresAt: database.Now().AddDate(0, 0, 1), + LoginType: database.LoginTypeGithub, + }) + _ = dbgen.UserLink(t, db, database.UserLink{ + UserID: user.ID, + LoginType: database.LoginTypeGithub, + OAuthRefreshToken: "hello", + OAuthExpiry: database.Now().AddDate(0, 0, -1), + }) - sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - LoginType: database.LoginTypeGithub, - LastUsed: database.Now(), - UserID: user.ID, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) - _, err = db.InsertUserLink(r.Context(), database.InsertUserLinkParams{ - UserID: user.ID, - LoginType: database.LoginTypeGithub, - OAuthExpiry: database.Now().AddDate(0, 0, -1), - OAuthRefreshToken: "hello", - }) - require.NoError(t, err) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.Header.Set(codersdk.SessionTokenHeader, token) - token := &oauth2.Token{ + oauthToken := &oauth2.Token{ AccessToken: "wow", RefreshToken: "moo", Expiry: database.Now().AddDate(0, 0, 1), @@ -486,7 +432,7 @@ func TestAPIKey(t *testing.T) { OAuth2Configs: &httpmw.OAuth2Configs{ Github: &oauth2Config{ tokenSource: oauth2TokenSource(func() (*oauth2.Token, error) { - return token, nil + return oauthToken, nil }), }, }, @@ -496,36 +442,30 @@ func TestAPIKey(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) - gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) require.NoError(t, err) require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) - require.Equal(t, token.Expiry, gotAPIKey.ExpiresAt) + require.Equal(t, oauthToken.Expiry, gotAPIKey.ExpiresAt) }) t.Run("RemoteIPUpdates", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) + db = databasefake.New() + user = dbgen.User(t, db, database.User{}) + sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ + UserID: user.ID, + LastUsed: database.Now().AddDate(0, 0, -1), + ExpiresAt: database.Now().AddDate(0, 0, 1), + }) + + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) r.RemoteAddr = "1.1.1.1" - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) + r.Header.Set(codersdk.SessionTokenHeader, token) - _, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - LastUsed: database.Now().AddDate(0, 0, -1), - ExpiresAt: database.Now().AddDate(0, 0, 1), - UserID: user.ID, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, RedirectToLogin: false, @@ -534,7 +474,7 @@ func TestAPIKey(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) - gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) require.NoError(t, err) require.Equal(t, net.ParseIP("1.1.1.1"), gotAPIKey.IPAddress.IPNet.IP) @@ -595,25 +535,19 @@ func TestAPIKey(t *testing.T) { t.Run("Tokens", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) - ) - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) + db = databasefake.New() + user = dbgen.User(t, db, database.User{}) + sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ + UserID: user.ID, + LastUsed: database.Now(), + ExpiresAt: database.Now().AddDate(0, 0, 1), + LoginType: database.LoginTypeToken, + }) - sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - LoginType: database.LoginTypeToken, - LastUsed: database.Now(), - ExpiresAt: database.Now().AddDate(0, 0, 1), - UserID: user.ID, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.Header.Set(codersdk.SessionTokenHeader, token) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, @@ -623,7 +557,7 @@ func TestAPIKey(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) - gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) require.NoError(t, err) require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) @@ -632,25 +566,6 @@ func TestAPIKey(t *testing.T) { }) } -func createUser(ctx context.Context, t *testing.T, db database.Store, opts ...func(u *database.InsertUserParams)) database.User { - insert := database.InsertUserParams{ - ID: uuid.New(), - Email: "email@coder.com", - Username: "username", - HashedPassword: []byte{}, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - RBACRoles: []string{}, - LoginType: database.LoginTypePassword, - } - for _, opt := range opts { - opt(&insert) - } - user, err := db.InsertUser(ctx, insert) - require.NoError(t, err, "create user") - return user -} - type oauth2Config struct { tokenSource oauth2TokenSource } diff --git a/coderd/httpmw/groupparam_test.go b/coderd/httpmw/groupparam_test.go index 70850de4ce9be..28038f5d03c3d 100644 --- a/coderd/httpmw/groupparam_test.go +++ b/coderd/httpmw/groupparam_test.go @@ -12,46 +12,21 @@ import ( "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/dbgen" "github.com/coder/coder/coderd/httpmw" - "github.com/coder/coder/testutil" ) func TestGroupParam(t *testing.T) { t.Parallel() - setup := func(t *testing.T) (database.Store, database.Group) { - t.Helper() - - ctx, _ := testutil.Context(t) - db := databasefake.New() - - orgID := uuid.New() - organization, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{ - ID: orgID, - Name: "banana", - Description: "wowie", - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - }) - require.NoError(t, err) - - group, err := db.InsertGroup(ctx, database.InsertGroupParams{ - ID: uuid.New(), - Name: "yeww", - OrganizationID: organization.ID, - }) - require.NoError(t, err) - - return db, group - } - t.Run("OK", func(t *testing.T) { t.Parallel() var ( - db, group = setup(t) - r = httptest.NewRequest("GET", "/", nil) - w = httptest.NewRecorder() + db = databasefake.New() + group = dbgen.Group(t, db, database.Group{}) + r = httptest.NewRequest("GET", "/", nil) + w = httptest.NewRecorder() ) router := chi.NewRouter() @@ -77,9 +52,10 @@ func TestGroupParam(t *testing.T) { t.Parallel() var ( - db, group = setup(t) - r = httptest.NewRequest("GET", "/", nil) - w = httptest.NewRecorder() + db = databasefake.New() + group = dbgen.Group(t, db, database.Group{}) + r = httptest.NewRequest("GET", "/", nil) + w = httptest.NewRecorder() ) router := chi.NewRouter() diff --git a/coderd/httpmw/ratelimit_test.go b/coderd/httpmw/ratelimit_test.go index 61a6f5d903566..e004fb3ed3ed0 100644 --- a/coderd/httpmw/ratelimit_test.go +++ b/coderd/httpmw/ratelimit_test.go @@ -1,8 +1,6 @@ package httpmw_test import ( - "context" - "crypto/sha256" "fmt" "math/rand" "net" @@ -12,35 +10,17 @@ import ( "time" "github.com/go-chi/chi/v5" - "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/dbgen" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/codersdk" "github.com/coder/coder/testutil" ) -func insertAPIKey(ctx context.Context, t *testing.T, db database.Store, userID uuid.UUID) string { - id, secret := randomAPIKeyParts() - hashed := sha256.Sum256([]byte(secret)) - - _, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - LastUsed: database.Now().AddDate(0, 0, -1), - ExpiresAt: database.Now().AddDate(0, 0, 1), - UserID: userID, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) - - return fmt.Sprintf("%s-%s", id, secret) -} - func randRemoteAddr() string { var b [4]byte // nolint:gosec @@ -91,12 +71,9 @@ func TestRateLimit(t *testing.T) { t.Run("RegularUser", func(t *testing.T) { t.Parallel() - ctx := context.Background() - db := databasefake.New() - - u := createUser(ctx, t, db) - key := insertAPIKey(ctx, t, db, u.ID) + u := dbgen.User(t, db, database.User{}) + _, key := dbgen.APIKey(t, db, database.APIKey{UserID: u.ID}) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ @@ -137,15 +114,12 @@ func TestRateLimit(t *testing.T) { t.Run("OwnerBypass", func(t *testing.T) { t.Parallel() - ctx := context.Background() - db := databasefake.New() - u := createUser(ctx, t, db, func(u *database.InsertUserParams) { - u.RBACRoles = []string{rbac.RoleOwner()} + u := dbgen.User(t, db, database.User{ + RBACRoles: []string{rbac.RoleOwner()}, }) - - key := insertAPIKey(ctx, t, db, u.ID) + _, key := dbgen.APIKey(t, db, database.APIKey{UserID: u.ID}) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ diff --git a/coderd/httpmw/workspaceresourceparam_test.go b/coderd/httpmw/workspaceresourceparam_test.go index 8d83160eaf20d..b2f222f21a33c 100644 --- a/coderd/httpmw/workspaceresourceparam_test.go +++ b/coderd/httpmw/workspaceresourceparam_test.go @@ -12,38 +12,35 @@ import ( "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/dbgen" "github.com/coder/coder/coderd/httpmw" ) func TestWorkspaceResourceParam(t *testing.T) { t.Parallel() - setup := func(db database.Store, jobType database.ProvisionerJobType) (*http.Request, database.WorkspaceResource) { + setup := func(t *testing.T, db database.Store, jobType database.ProvisionerJobType) (*http.Request, database.WorkspaceResource) { r := httptest.NewRequest("GET", "/", nil) - job, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ - ID: uuid.New(), + job := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ Type: jobType, Provisioner: database.ProvisionerTypeEcho, StorageMethod: database.ProvisionerStorageMethodFile, }) - require.NoError(t, err) - workspaceBuild, err := db.InsertWorkspaceBuild(context.Background(), database.InsertWorkspaceBuildParams{ - ID: uuid.New(), + + build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ JobID: job.ID, Transition: database.WorkspaceTransitionStart, Reason: database.BuildReasonInitiator, }) - require.NoError(t, err) - resource, err := db.InsertWorkspaceResource(context.Background(), database.InsertWorkspaceResourceParams{ - ID: uuid.New(), + + resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{ JobID: job.ID, Transition: database.WorkspaceTransitionStart, }) - require.NoError(t, err) - ctx := chi.NewRouteContext() - ctx.URLParams.Add("workspacebuild", workspaceBuild.ID.String()) - r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx)) + routeCtx := chi.NewRouteContext() + routeCtx.URLParams.Add("workspacebuild", build.ID.String()) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) return r, resource } @@ -53,7 +50,7 @@ func TestWorkspaceResourceParam(t *testing.T) { rtr := chi.NewRouter() rtr.Use(httpmw.ExtractWorkspaceResourceParam(db)) rtr.Get("/", nil) - r, _ := setup(db, database.ProvisionerJobTypeWorkspaceBuild) + r, _ := setup(t, db, database.ProvisionerJobTypeWorkspaceBuild) rw := httptest.NewRecorder() rtr.ServeHTTP(rw, r) @@ -71,7 +68,7 @@ func TestWorkspaceResourceParam(t *testing.T) { ) rtr.Get("/", nil) - r, _ := setup(db, database.ProvisionerJobTypeWorkspaceBuild) + r, _ := setup(t, db, database.ProvisionerJobTypeWorkspaceBuild) chi.RouteContext(r.Context()).URLParams.Add("workspaceresource", uuid.NewString()) rw := httptest.NewRecorder() rtr.ServeHTTP(rw, r) @@ -93,7 +90,7 @@ func TestWorkspaceResourceParam(t *testing.T) { rw.WriteHeader(http.StatusOK) }) - r, job := setup(db, database.ProvisionerJobTypeTemplateVersionImport) + r, job := setup(t, db, database.ProvisionerJobTypeTemplateVersionImport) chi.RouteContext(r.Context()).URLParams.Add("workspaceresource", job.ID.String()) rw := httptest.NewRecorder() rtr.ServeHTTP(rw, r) @@ -115,7 +112,7 @@ func TestWorkspaceResourceParam(t *testing.T) { rw.WriteHeader(http.StatusOK) }) - r, job := setup(db, database.ProvisionerJobTypeWorkspaceBuild) + r, job := setup(t, db, database.ProvisionerJobTypeWorkspaceBuild) chi.RouteContext(r.Context()).URLParams.Add("workspaceresource", job.ID.String()) rw := httptest.NewRecorder() rtr.ServeHTTP(rw, r) diff --git a/coderd/prometheusmetrics/prometheusmetrics_test.go b/coderd/prometheusmetrics/prometheusmetrics_test.go index 0ebf7fa09228c..424593b6f282d 100644 --- a/coderd/prometheusmetrics/prometheusmetrics_test.go +++ b/coderd/prometheusmetrics/prometheusmetrics_test.go @@ -13,6 +13,7 @@ import ( "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/dbgen" "github.com/coder/coder/coderd/prometheusmetrics" "github.com/coder/coder/codersdk" "github.com/coder/coder/testutil" @@ -23,63 +24,50 @@ func TestActiveUsers(t *testing.T) { for _, tc := range []struct { Name string - Database func() database.Store + Database func(t *testing.T) database.Store Count int }{{ Name: "None", - Database: func() database.Store { + Database: func(t *testing.T) database.Store { return databasefake.New() }, Count: 0, }, { Name: "One", - Database: func() database.Store { + Database: func(t *testing.T) database.Store { db := databasefake.New() - _, _ = db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ - UserID: uuid.New(), - LastUsed: database.Now(), - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, + dbgen.APIKey(t, db, database.APIKey{ + LastUsed: database.Now(), }) return db }, Count: 1, }, { Name: "OneWithExpired", - Database: func() database.Store { + Database: func(t *testing.T) database.Store { db := databasefake.New() - _, _ = db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ - UserID: uuid.New(), - LastUsed: database.Now(), - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, + + dbgen.APIKey(t, db, database.APIKey{ + LastUsed: database.Now(), }) + // Because this API key hasn't been used in the past hour, this shouldn't // add to the user count. - _, _ = db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ - UserID: uuid.New(), - LastUsed: database.Now().Add(-2 * time.Hour), - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, + dbgen.APIKey(t, db, database.APIKey{ + LastUsed: database.Now().Add(-2 * time.Hour), }) return db }, Count: 1, }, { Name: "Multiple", - Database: func() database.Store { + Database: func(t *testing.T) database.Store { db := databasefake.New() - _, _ = db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ - UserID: uuid.New(), - LastUsed: database.Now(), - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, + dbgen.APIKey(t, db, database.APIKey{ + LastUsed: database.Now(), }) - _, _ = db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ - UserID: uuid.New(), - LastUsed: database.Now(), - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, + dbgen.APIKey(t, db, database.APIKey{ + LastUsed: database.Now(), }) return db }, @@ -89,7 +77,7 @@ func TestActiveUsers(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { t.Parallel() registry := prometheus.NewRegistry() - cancel, err := prometheusmetrics.ActiveUsers(context.Background(), registry, tc.Database(), time.Millisecond) + cancel, err := prometheusmetrics.ActiveUsers(context.Background(), registry, tc.Database(t), time.Millisecond) require.NoError(t, err) t.Cleanup(cancel) diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index f9a4d74782647..6032a8497dd98 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -16,6 +16,7 @@ import ( "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/dbgen" "github.com/coder/coder/coderd/provisionerdserver" "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/codersdk" @@ -87,36 +88,38 @@ func TestAcquireJob(t *testing.T) { t.Parallel() srv := setup(t, false) ctx := context.Background() - user, err := srv.Database.InsertUser(context.Background(), database.InsertUserParams{ - ID: uuid.New(), - Username: "testing", - LoginType: database.LoginTypePassword, - }) - require.NoError(t, err) - template, err := srv.Database.InsertTemplate(ctx, database.InsertTemplateParams{ - ID: uuid.New(), + + user := dbgen.User(t, srv.Database, database.User{}) + template := dbgen.Template(t, srv.Database, database.Template{ Name: "template", Provisioner: database.ProvisionerTypeEcho, }) - require.NoError(t, err) - version, err := srv.Database.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ - ID: uuid.New(), + file := dbgen.File(t, srv.Database, database.File{CreatedBy: user.ID}) + versionFile := dbgen.File(t, srv.Database, database.File{CreatedBy: user.ID}) + version := dbgen.TemplateVersion(t, srv.Database, database.TemplateVersion{ TemplateID: uuid.NullUUID{ UUID: template.ID, Valid: true, }, JobID: uuid.New(), }) - require.NoError(t, err) - workspace, err := srv.Database.InsertWorkspace(ctx, database.InsertWorkspaceParams{ - ID: uuid.New(), - OwnerID: user.ID, + // Import version job + _ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ + ID: version.JobID, + InitiatorID: user.ID, + FileID: versionFile.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + Type: database.ProvisionerJobTypeTemplateVersionImport, + Input: must(json.Marshal(provisionerdserver.TemplateVersionImportJob{ + TemplateVersionID: version.ID, + })), + }) + workspace := dbgen.Workspace(t, srv.Database, database.Workspace{ TemplateID: template.ID, - Name: "workspace", + OwnerID: user.ID, }) - require.NoError(t, err) - build, err := srv.Database.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ - ID: uuid.New(), + build := dbgen.WorkspaceBuild(t, srv.Database, database.WorkspaceBuild{ WorkspaceID: workspace.ID, BuildNumber: 1, JobID: uuid.New(), @@ -124,33 +127,17 @@ func TestAcquireJob(t *testing.T) { Transition: database.WorkspaceTransitionStart, Reason: database.BuildReasonInitiator, }) - require.NoError(t, err) - - data, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{ - WorkspaceBuildID: build.ID, - }) - require.NoError(t, err) - - file, err := srv.Database.InsertFile(ctx, database.InsertFileParams{ - ID: uuid.New(), - Hash: "something", - Data: []byte{}, - }) - require.NoError(t, err) - - _, err = srv.Database.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ - ID: build.JobID, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - OrganizationID: uuid.New(), - InitiatorID: user.ID, - Provisioner: database.ProvisionerTypeEcho, - StorageMethod: database.ProvisionerStorageMethodFile, - FileID: file.ID, - Type: database.ProvisionerJobTypeWorkspaceBuild, - Input: data, + _ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ + ID: build.ID, + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeWorkspaceBuild, + Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{ + WorkspaceBuildID: build.ID, + })), }) - require.NoError(t, err) published := make(chan struct{}) closeSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) { @@ -159,8 +146,17 @@ func TestAcquireJob(t *testing.T) { require.NoError(t, err) defer closeSubscribe() - job, err := srv.AcquireJob(ctx, nil) - require.NoError(t, err) + var job *proto.AcquiredJob + + for { + // Grab jobs until we find the workspace build job. There is also + // an import version job that we need to ignore. + job, err = srv.AcquireJob(ctx, nil) + require.NoError(t, err) + if _, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_); ok { + break + } + } <-published @@ -191,44 +187,22 @@ func TestAcquireJob(t *testing.T) { t.Parallel() srv := setup(t, false) ctx := context.Background() - user, err := srv.Database.InsertUser(ctx, database.InsertUserParams{ - ID: uuid.New(), - Username: "testing", - LoginType: database.LoginTypePassword, - }) - require.NoError(t, err) - version, err := srv.Database.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ - ID: uuid.New(), - }) - require.NoError(t, err) - data, err := json.Marshal(provisionerdserver.TemplateVersionDryRunJob{ - TemplateVersionID: version.ID, - WorkspaceName: "testing", - ParameterValues: []database.ParameterValue{}, - }) - require.NoError(t, err) - - file, err := srv.Database.InsertFile(ctx, database.InsertFileParams{ - ID: uuid.New(), - Hash: "something", - Data: []byte{}, - }) - require.NoError(t, err) - - _, err = srv.Database.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ - ID: uuid.New(), - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - OrganizationID: uuid.New(), - InitiatorID: user.ID, - Provisioner: database.ProvisionerTypeEcho, - StorageMethod: database.ProvisionerStorageMethodFile, - FileID: file.ID, - Type: database.ProvisionerJobTypeTemplateVersionDryRun, - Input: data, + user := dbgen.User(t, srv.Database, database.User{}) + version := dbgen.TemplateVersion(t, srv.Database, database.TemplateVersion{}) + file := dbgen.File(t, srv.Database, database.File{CreatedBy: user.ID}) + _ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeTemplateVersionDryRun, + Input: must(json.Marshal(provisionerdserver.TemplateVersionDryRunJob{ + TemplateVersionID: version.ID, + WorkspaceName: "testing", + ParameterValues: []database.ParameterValue{}, + })), }) - require.NoError(t, err) job, err := srv.AcquireJob(ctx, nil) require.NoError(t, err) @@ -252,33 +226,16 @@ func TestAcquireJob(t *testing.T) { t.Parallel() srv := setup(t, false) ctx := context.Background() - user, err := srv.Database.InsertUser(ctx, database.InsertUserParams{ - ID: uuid.New(), - Username: "testing", - LoginType: database.LoginTypePassword, - }) - require.NoError(t, err) - - file, err := srv.Database.InsertFile(ctx, database.InsertFileParams{ - ID: uuid.New(), - Hash: "something", - Data: []byte{}, - }) - require.NoError(t, err) - _, err = srv.Database.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ - ID: uuid.New(), - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - OrganizationID: uuid.New(), - InitiatorID: user.ID, - Provisioner: database.ProvisionerTypeEcho, - StorageMethod: database.ProvisionerStorageMethodFile, - FileID: file.ID, - Type: database.ProvisionerJobTypeTemplateVersionImport, - Input: json.RawMessage{}, + user := dbgen.User(t, srv.Database, database.User{}) + file := dbgen.File(t, srv.Database, database.File{CreatedBy: user.ID}) + _ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ + FileID: file.ID, + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + Type: database.ProvisionerJobTypeTemplateVersionImport, }) - require.NoError(t, err) job, err := srv.AcquireJob(ctx, nil) require.NoError(t, err) @@ -855,3 +812,10 @@ func setup(t *testing.T, ignoreLogErrors bool) *provisionerdserver.Server { Auditor: mockAuditor(), } } + +func must[T any](value T, err error) T { + if err != nil { + panic(err) + } + return value +} diff --git a/coderd/workspaceapps_internal_test.go b/coderd/workspaceapps_internal_test.go index 1e298000d9dd5..f35d904b397af 100644 --- a/coderd/workspaceapps_internal_test.go +++ b/coderd/workspaceapps_internal_test.go @@ -10,42 +10,28 @@ import ( "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/dbgen" "github.com/coder/coder/testutil" ) func TestAPIKeyEncryption(t *testing.T) { t.Parallel() - generateAPIKey := func(t *testing.T, db database.Store) (keyID, keySecret string, hashedSecret []byte, data encryptedAPIKeyPayload) { - keyID, keySecret, err := GenerateAPIKeyIDSecret() - require.NoError(t, err) + generateAPIKey := func(t *testing.T, db database.Store) (keyID, keyToken string, hashedSecret []byte, data encryptedAPIKeyPayload) { + key, token := dbgen.APIKey(t, db, database.APIKey{}) - hashedSecretArray := sha256.Sum256([]byte(keySecret)) data = encryptedAPIKeyPayload{ - APIKey: keyID + "-" + keySecret, + APIKey: token, ExpiresAt: database.Now().Add(24 * time.Hour), } - return keyID, keySecret, hashedSecretArray[:], data - } - insertAPIKey := func(t *testing.T, db database.Store, keyID string, hashedSecret []byte) { - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - _, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ - ID: keyID, - HashedSecret: hashedSecret, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) + return key.ID, token, key.HashedSecret[:], data } t.Run("OK", func(t *testing.T) { t.Parallel() db := databasefake.New() keyID, _, hashedSecret, data := generateAPIKey(t, db) - insertAPIKey(t, db, keyID, hashedSecret) encrypted, err := encryptAPIKey(data) require.NoError(t, err) @@ -66,8 +52,7 @@ func TestAPIKeyEncryption(t *testing.T) { t.Run("Expiry", func(t *testing.T) { t.Parallel() db := databasefake.New() - keyID, _, hashedSecret, data := generateAPIKey(t, db) - insertAPIKey(t, db, keyID, hashedSecret) + _, _, _, data := generateAPIKey(t, db) data.ExpiresAt = database.Now().Add(-1 * time.Hour) encrypted, err := encryptAPIKey(data) @@ -84,9 +69,17 @@ func TestAPIKeyEncryption(t *testing.T) { t.Run("KeyMatches", func(t *testing.T) { t.Parallel() db := databasefake.New() - keyID, _, _, data := generateAPIKey(t, db) + hashedSecret := sha256.Sum256([]byte("wrong")) - insertAPIKey(t, db, keyID, hashedSecret[:]) + // Insert a token with a mismatched hashed secret. + _, token := dbgen.APIKey(t, db, database.APIKey{ + HashedSecret: hashedSecret[:], + }) + + data := encryptedAPIKeyPayload{ + APIKey: token, + ExpiresAt: database.Now().Add(24 * time.Hour), + } encrypted, err := encryptAPIKey(data) require.NoError(t, err)