From 50ea251b9f368837c6110fa1467aadccbc92f1b7 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 30 Jan 2023 11:47:00 -0600 Subject: [PATCH 01/21] feat: Add database data generator to make fakedbs easier to populate One example shown in groupparm_test.go. More examples coming --- coderd/database/databasefake/generator.go | 212 ++++++++++++++++++++++ coderd/httpmw/groupparam_test.go | 19 +- 2 files changed, 214 insertions(+), 17 deletions(-) create mode 100644 coderd/database/databasefake/generator.go diff --git a/coderd/database/databasefake/generator.go b/coderd/database/databasefake/generator.go new file mode 100644 index 0000000000000..c9a9c06d317e3 --- /dev/null +++ b/coderd/database/databasefake/generator.go @@ -0,0 +1,212 @@ +package databasefake + +import ( + "context" + "testing" + "time" + + "github.com/coder/coder/coderd/database" + "github.com/google/uuid" + "github.com/moby/moby/pkg/namesgenerator" + "github.com/stretchr/testify/require" +) + +const primaryOrgName = "primary-org" + +type Generator struct { + // names is a map of names to uuids. + names map[string]uuid.UUID + primaryOrg *database.Organization + testT *testing.T + + db database.Store +} + +func NewGenerator(t *testing.T, db database.Store) *Generator { + if _, ok := db.(FakeDatabase); !ok { + panic("Generator db must be a FakeDatabase") + } + return &Generator{ + names: make(map[string]uuid.UUID), + testT: t, + db: db, + } +} + +// PrimaryOrg is to keep all resources in the same default org if not +// specified. +func (g *Generator) PrimaryOrg(ctx context.Context) database.Organization { + if g.primaryOrg == nil { + org := g.Organization(ctx, "primary-org", database.Organization{ + ID: g.Lookup(primaryOrgName), + Name: primaryOrgName, + Description: "This is the default primary organization for all tests", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }) + g.primaryOrg = &org + } + + return *g.primaryOrg +} + +func populate[DBType any](ctx context.Context, g *Generator, name string, seed DBType) DBType { + out := g.Populate(ctx, map[string]interface{}{ + name: seed, + }) + return out[name].(DBType) +} + +func (g *Generator) Group(ctx context.Context, name string, seed database.Group) database.Group { + return populate(ctx, g, name, seed) +} + +func (g *Generator) Organization(ctx context.Context, name string, seed database.Organization) database.Organization { + return populate(ctx, g, name, seed) +} + +func (g *Generator) Workspace(ctx context.Context, name string, seed database.Workspace) database.Workspace { + return populate(ctx, g, name, seed) +} + +func (g *Generator) Template(ctx context.Context, name string, seed database.Template) database.Template { + return populate(ctx, g, name, seed) +} + +func (g *Generator) TemplateVersion(ctx context.Context, name string, seed database.TemplateVersion) database.TemplateVersion { + return populate(ctx, g, name, seed) +} + +func (g *Generator) WorkspaceBuild(ctx context.Context, name string, seed database.WorkspaceBuild) database.WorkspaceBuild { + return populate(ctx, g, name, seed) +} + +func (g *Generator) User(ctx context.Context, name string, seed database.User) database.User { + return populate(ctx, g, name, seed) +} + +func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) map[string]interface{} { + db := g.db + t := g.testT + + for name, v := range seed { + switch orig := v.(type) { + case database.Template: + template, err := db.InsertTemplate(ctx, database.InsertTemplateParams{ + ID: g.Lookup(name), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho), + ActiveVersionID: takeFirst(orig.ActiveVersionID, uuid.New()), + Description: takeFirst(orig.Description, namesgenerator.GetRandomName(1)), + DefaultTTL: takeFirst(orig.DefaultTTL, 3600), + CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), + Icon: takeFirst(orig.Icon, namesgenerator.GetRandomName(1)), + UserACL: orig.UserACL, + GroupACL: orig.GroupACL, + DisplayName: takeFirst(orig.DisplayName, namesgenerator.GetRandomName(1)), + AllowUserCancelWorkspaceJobs: takeFirst(orig.AllowUserCancelWorkspaceJobs, true), + }) + require.NoError(t, err, "insert template") + + seed[name] = template + case database.Workspace: + workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{ + ID: g.Lookup(name), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), + TemplateID: takeFirst(orig.TemplateID, uuid.New()), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + AutostartSchedule: orig.AutostartSchedule, + Ttl: orig.Ttl, + }) + require.NoError(t, err, "insert workspace") + + seed[name] = workspace + case database.WorkspaceBuild: + build, err := db.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ + ID: g.Lookup(name), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + WorkspaceID: takeFirst(orig.WorkspaceID, uuid.New()), + TemplateVersionID: takeFirst(orig.TemplateVersionID, uuid.New()), + BuildNumber: takeFirst(orig.BuildNumber, 0), + Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart), + InitiatorID: takeFirst(orig.InitiatorID, uuid.New()), + JobID: takeFirst(orig.InitiatorID, uuid.New()), + ProvisionerState: []byte{}, + Deadline: time.Now(), + Reason: takeFirst(orig.Reason, database.BuildReasonInitiator), + }) + require.NoError(t, err, "insert workspace build") + + seed[name] = build + case database.User: + user, err := db.InsertUser(ctx, database.InsertUserParams{ + ID: g.Lookup(name), + Email: takeFirst(orig.Email, namesgenerator.GetRandomName(1)), + Username: takeFirst(orig.Username, namesgenerator.GetRandomName(1)), + HashedPassword: []byte{}, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + RBACRoles: []string{}, + LoginType: takeFirst(orig.LoginType, database.LoginTypePassword), + }) + require.NoError(t, err, "insert user") + + seed[name] = user + + case database.Organization: + org, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{ + ID: g.Lookup(name), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + Description: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }) + require.NoError(t, err, "insert organization") + + seed[name] = org + + case database.Group: + org, err := db.InsertGroup(ctx, database.InsertGroupParams{ + ID: g.Lookup(name), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), + AvatarURL: takeFirst(orig.Name, "https://logo.example.com"), + QuotaAllowance: takeFirst(orig.QuotaAllowance, 0), + }) + require.NoError(t, err, "insert organization") + + seed[name] = org + } + } + return seed +} + +func (tc *Generator) Lookup(name string) uuid.UUID { + if tc.names == nil { + tc.names = make(map[string]uuid.UUID) + } + if id, ok := tc.names[name]; ok { + return id + } + id := uuid.New() + tc.names[name] = id + return id +} + +// takeFirst will take the first non-empty value. +func takeFirst[Value comparable](values ...Value) Value { + var empty Value + for _, v := range values { + if v != empty { + return v + } + } + // If all empty, return empty + return empty +} diff --git a/coderd/httpmw/groupparam_test.go b/coderd/httpmw/groupparam_test.go index 70850de4ce9be..de32d58aeb9bb 100644 --- a/coderd/httpmw/groupparam_test.go +++ b/coderd/httpmw/groupparam_test.go @@ -24,23 +24,8 @@ func TestGroupParam(t *testing.T) { ctx, _ := testutil.Context(t) db := databasefake.New() - - orgID := uuid.New() - organization, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{ - ID: orgID, - Name: "banana", - Description: "wowie", - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - }) - require.NoError(t, err) - - group, err := db.InsertGroup(ctx, database.InsertGroupParams{ - ID: uuid.New(), - Name: "yeww", - OrganizationID: organization.ID, - }) - require.NoError(t, err) + gen := databasefake.NewGenerator(t, db) + group := gen.Group(ctx, "group", database.Group{}) return db, group } From 722194bd11194fa1fc6bcfa2f0507f04a9f33c4b Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 30 Jan 2023 13:46:52 -0600 Subject: [PATCH 02/21] Add resources and jobS --- coderd/database/databasefake/generator.go | 106 ++++++++++++++++--- coderd/httpmw/groupparam_test.go | 1 + coderd/httpmw/workspaceresourceparam_test.go | 32 +++--- 3 files changed, 105 insertions(+), 34 deletions(-) diff --git a/coderd/database/databasefake/generator.go b/coderd/database/databasefake/generator.go index c9a9c06d317e3..f1d15e0e08a8c 100644 --- a/coderd/database/databasefake/generator.go +++ b/coderd/database/databasefake/generator.go @@ -2,6 +2,7 @@ package databasefake import ( "context" + "database/sql" "testing" "time" @@ -57,6 +58,14 @@ func populate[DBType any](ctx context.Context, g *Generator, name string, seed D return out[name].(DBType) } +func (g *Generator) WorkspaceResource(ctx context.Context, name string, seed database.WorkspaceResource) database.WorkspaceResource { + return populate(ctx, g, name, seed) +} + +func (g *Generator) Job(ctx context.Context, name string, seed database.ProvisionerJob) database.ProvisionerJob { + return populate(ctx, g, name, seed) +} + func (g *Generator) Group(ctx context.Context, name string, seed database.Group) database.Group { return populate(ctx, g, name, seed) } @@ -94,8 +103,8 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m case database.Template: template, err := db.InsertTemplate(ctx, database.InsertTemplateParams{ ID: g.Lookup(name), - CreatedAt: time.Now(), - UpdatedAt: time.Now(), + CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirstTime(orig.CreatedAt, time.Now()), OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho), @@ -115,8 +124,8 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m case database.Workspace: workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{ ID: g.Lookup(name), - CreatedAt: time.Now(), - UpdatedAt: time.Now(), + CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirstTime(orig.CreatedAt, time.Now()), OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), TemplateID: takeFirst(orig.TemplateID, uuid.New()), Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), @@ -129,16 +138,16 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m case database.WorkspaceBuild: build, err := db.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ ID: g.Lookup(name), - CreatedAt: time.Now(), - UpdatedAt: time.Now(), + CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirstTime(orig.CreatedAt, time.Now()), WorkspaceID: takeFirst(orig.WorkspaceID, uuid.New()), TemplateVersionID: takeFirst(orig.TemplateVersionID, uuid.New()), BuildNumber: takeFirst(orig.BuildNumber, 0), Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart), InitiatorID: takeFirst(orig.InitiatorID, uuid.New()), - JobID: takeFirst(orig.InitiatorID, uuid.New()), - ProvisionerState: []byte{}, - Deadline: time.Now(), + JobID: takeFirst(orig.JobID, uuid.New()), + ProvisionerState: takeFirstBytes(orig.ProvisionerState, []byte{}), + Deadline: takeFirstTime(orig.CreatedAt, time.Now().Add(time.Hour)), Reason: takeFirst(orig.Reason, database.BuildReasonInitiator), }) require.NoError(t, err, "insert workspace build") @@ -149,9 +158,9 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m ID: g.Lookup(name), Email: takeFirst(orig.Email, namesgenerator.GetRandomName(1)), Username: takeFirst(orig.Username, namesgenerator.GetRandomName(1)), - HashedPassword: []byte{}, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), + HashedPassword: takeFirstBytes(orig.HashedPassword, []byte{}), + CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirstTime(orig.CreatedAt, time.Now()), RBACRoles: []string{}, LoginType: takeFirst(orig.LoginType, database.LoginTypePassword), }) @@ -164,8 +173,8 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m ID: g.Lookup(name), Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), Description: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), - CreatedAt: time.Now(), - UpdatedAt: time.Now(), + CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirstTime(orig.CreatedAt, time.Now()), }) require.NoError(t, err, "insert organization") @@ -182,12 +191,55 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m require.NoError(t, err, "insert organization") seed[name] = org + + case database.ProvisionerJob: + job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: g.Lookup(name), + CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), + InitiatorID: takeFirst(orig.InitiatorID, uuid.New()), + Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho), + StorageMethod: takeFirst(orig.StorageMethod, database.ProvisionerStorageMethodFile), + FileID: takeFirst(orig.FileID, uuid.New()), + Type: takeFirst(orig.Type, database.ProvisionerJobTypeWorkspaceBuild), + Input: takeFirstBytes(orig.Input, []byte("{}")), + Tags: orig.Tags, + }) + require.NoError(t, err, "insert job") + + seed[name] = job + + case database.WorkspaceResource: + resource, err := db.InsertWorkspaceResource(ctx, database.InsertWorkspaceResourceParams{ + ID: g.Lookup(name), + CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + JobID: takeFirst(orig.JobID, uuid.New()), + Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart), + // TODO: What type to put here? + Type: takeFirst(orig.Type, ""), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + Hide: takeFirst(orig.Hide, false), + Icon: takeFirst(orig.Name, ""), + InstanceType: sql.NullString{ + String: takeFirst(orig.InstanceType.String, ""), + Valid: takeFirst(orig.InstanceType.Valid, false), + }, + DailyCost: takeFirst(orig.DailyCost, 0), + }) + require.NoError(t, err, "insert resource") + + seed[name] = resource } } return seed } func (tc *Generator) Lookup(name string) uuid.UUID { + if name == "" { + // No name means the caller doesn't care about the ID. + return uuid.New() + } if tc.names == nil { tc.names = make(map[string]uuid.UUID) } @@ -199,14 +251,34 @@ func (tc *Generator) Lookup(name string) uuid.UUID { return id } -// takeFirst will take the first non-empty value. -func takeFirst[Value comparable](values ...Value) Value { +func takeFirstTime(values ...time.Time) time.Time { + return takeFirstF(values, func(v time.Time) bool { + return !v.IsZero() + }) +} + +func takeFirstBytes(values ...[]byte) []byte { + return takeFirstF(values, func(v []byte) bool { + return len(v) != 0 + }) +} + +// takeFirstF takes the first value that returns true +func takeFirstF[Value any](values []Value, take func(v Value) bool) Value { var empty Value for _, v := range values { - if v != empty { + if take(v) { return v } } // If all empty, return empty return empty } + +// takeFirst will take the first non-empty value. +func takeFirst[Value comparable](values ...Value) Value { + var empty Value + return takeFirstF(values, func(v Value) bool { + return v != empty + }) +} diff --git a/coderd/httpmw/groupparam_test.go b/coderd/httpmw/groupparam_test.go index de32d58aeb9bb..409c204e1abd6 100644 --- a/coderd/httpmw/groupparam_test.go +++ b/coderd/httpmw/groupparam_test.go @@ -25,6 +25,7 @@ func TestGroupParam(t *testing.T) { ctx, _ := testutil.Context(t) db := databasefake.New() gen := databasefake.NewGenerator(t, db) + group := gen.Group(ctx, "group", database.Group{}) return db, group diff --git a/coderd/httpmw/workspaceresourceparam_test.go b/coderd/httpmw/workspaceresourceparam_test.go index 8d83160eaf20d..7de9edc54c46c 100644 --- a/coderd/httpmw/workspaceresourceparam_test.go +++ b/coderd/httpmw/workspaceresourceparam_test.go @@ -18,32 +18,30 @@ import ( func TestWorkspaceResourceParam(t *testing.T) { t.Parallel() - setup := func(db database.Store, jobType database.ProvisionerJobType) (*http.Request, database.WorkspaceResource) { + setup := func(t *testing.T, db database.Store, jobType database.ProvisionerJobType) (*http.Request, database.WorkspaceResource) { r := httptest.NewRequest("GET", "/", nil) - job, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ - ID: uuid.New(), + ctx := context.Background() + gen := databasefake.NewGenerator(t, db) + job := gen.Job(ctx, "", database.ProvisionerJob{ Type: jobType, Provisioner: database.ProvisionerTypeEcho, StorageMethod: database.ProvisionerStorageMethodFile, }) - require.NoError(t, err) - workspaceBuild, err := db.InsertWorkspaceBuild(context.Background(), database.InsertWorkspaceBuildParams{ - ID: uuid.New(), + + build := gen.WorkspaceBuild(ctx, "", database.WorkspaceBuild{ JobID: job.ID, Transition: database.WorkspaceTransitionStart, Reason: database.BuildReasonInitiator, }) - require.NoError(t, err) - resource, err := db.InsertWorkspaceResource(context.Background(), database.InsertWorkspaceResourceParams{ - ID: uuid.New(), + + resource := gen.WorkspaceResource(ctx, "", database.WorkspaceResource{ JobID: job.ID, Transition: database.WorkspaceTransitionStart, }) - require.NoError(t, err) - ctx := chi.NewRouteContext() - ctx.URLParams.Add("workspacebuild", workspaceBuild.ID.String()) - r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx)) + routeCtx := chi.NewRouteContext() + routeCtx.URLParams.Add("workspacebuild", build.ID.String()) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx)) return r, resource } @@ -53,7 +51,7 @@ func TestWorkspaceResourceParam(t *testing.T) { rtr := chi.NewRouter() rtr.Use(httpmw.ExtractWorkspaceResourceParam(db)) rtr.Get("/", nil) - r, _ := setup(db, database.ProvisionerJobTypeWorkspaceBuild) + r, _ := setup(t, db, database.ProvisionerJobTypeWorkspaceBuild) rw := httptest.NewRecorder() rtr.ServeHTTP(rw, r) @@ -71,7 +69,7 @@ func TestWorkspaceResourceParam(t *testing.T) { ) rtr.Get("/", nil) - r, _ := setup(db, database.ProvisionerJobTypeWorkspaceBuild) + r, _ := setup(t, db, database.ProvisionerJobTypeWorkspaceBuild) chi.RouteContext(r.Context()).URLParams.Add("workspaceresource", uuid.NewString()) rw := httptest.NewRecorder() rtr.ServeHTTP(rw, r) @@ -93,7 +91,7 @@ func TestWorkspaceResourceParam(t *testing.T) { rw.WriteHeader(http.StatusOK) }) - r, job := setup(db, database.ProvisionerJobTypeTemplateVersionImport) + r, job := setup(t, db, database.ProvisionerJobTypeTemplateVersionImport) chi.RouteContext(r.Context()).URLParams.Add("workspaceresource", job.ID.String()) rw := httptest.NewRecorder() rtr.ServeHTTP(rw, r) @@ -115,7 +113,7 @@ func TestWorkspaceResourceParam(t *testing.T) { rw.WriteHeader(http.StatusOK) }) - r, job := setup(db, database.ProvisionerJobTypeWorkspaceBuild) + r, job := setup(t, db, database.ProvisionerJobTypeWorkspaceBuild) chi.RouteContext(r.Context()).URLParams.Add("workspaceresource", job.ID.String()) rw := httptest.NewRecorder() rtr.ServeHTTP(rw, r) From 38eb88ea34ebda5eac7c92e8b7c3a8bd946be390 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 30 Jan 2023 14:03:31 -0600 Subject: [PATCH 03/21] Support api keys in generator --- coderd/database/databasefake/generator.go | 83 ++++++++++++++++++++--- coderd/httpmw/apikey_test.go | 31 ++++----- coderd/httpmw/ratelimit_test.go | 14 ++-- 3 files changed, 93 insertions(+), 35 deletions(-) diff --git a/coderd/database/databasefake/generator.go b/coderd/database/databasefake/generator.go index f1d15e0e08a8c..62ae347d5c9de 100644 --- a/coderd/database/databasefake/generator.go +++ b/coderd/database/databasefake/generator.go @@ -2,10 +2,16 @@ package databasefake import ( "context" + "crypto/sha256" "database/sql" + "fmt" "testing" "time" + "github.com/coder/coder/cryptorand" + + "github.com/tabbed/pqtype" + "github.com/coder/coder/coderd/database" "github.com/google/uuid" "github.com/moby/moby/pkg/namesgenerator" @@ -52,10 +58,42 @@ func (g *Generator) PrimaryOrg(ctx context.Context) database.Organization { } func populate[DBType any](ctx context.Context, g *Generator, name string, seed DBType) DBType { + if name == "" { + name = g.RandomName() + } + out := g.Populate(ctx, map[string]interface{}{ name: seed, }) - return out[name].(DBType) + v, ok := out[name].(DBType) + if !ok { + panic("developer error, type mismatch") + } + return v +} + +func (g *Generator) RandomName() string { + for { + name := namesgenerator.GetRandomName(0) + if _, ok := g.names[name]; !ok { + return name + } + } +} + +func (g *Generator) APIKey(ctx context.Context, name string, seed database.APIKey) (key database.APIKey, token string) { + if name == "" { + name = g.RandomName() + } + + out := g.Populate(ctx, map[string]interface{}{ + name: seed, + }) + key, keyOk := out[name].(database.APIKey) + secret, secOk := out[name+"_secret"].(string) + require.True(g.testT, keyOk && secOk, "APIKey & secret must be populated with the right type") + + return key, fmt.Sprintf("%s-%s", key.ID, secret) } func (g *Generator) WorkspaceResource(ctx context.Context, name string, seed database.WorkspaceResource) database.WorkspaceResource { @@ -100,11 +138,34 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m for name, v := range seed { switch orig := v.(type) { + case database.APIKey: + id, _ := cryptorand.String(10) + secret, _ := cryptorand.String(22) + hashed := sha256.Sum256([]byte(secret)) + + key, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: takeFirst(orig.ID, id), + LifetimeSeconds: takeFirst(orig.LifetimeSeconds, 3600), + HashedSecret: takeFirstBytes(orig.HashedSecret, hashed[:]), + IPAddress: pqtype.Inet{}, + UserID: takeFirst(orig.UserID, uuid.New()), + LastUsed: takeFirstTime(orig.LastUsed, time.Now()), + ExpiresAt: takeFirstTime(orig.ExpiresAt, time.Now().Add(time.Hour)), + CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), + LoginType: takeFirst(orig.LoginType, database.LoginTypePassword), + Scope: takeFirst(orig.Scope, database.APIKeyScopeAll), + }) + require.NoError(t, err, "insert api key") + + seed[name] = key + // Need to also save the secret + seed[name+"_secret"] = secret case database.Template: template, err := db.InsertTemplate(ctx, database.InsertTemplateParams{ ID: g.Lookup(name), CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho), @@ -125,7 +186,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{ ID: g.Lookup(name), CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), TemplateID: takeFirst(orig.TemplateID, uuid.New()), Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), @@ -139,7 +200,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m build, err := db.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ ID: g.Lookup(name), CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), WorkspaceID: takeFirst(orig.WorkspaceID, uuid.New()), TemplateVersionID: takeFirst(orig.TemplateVersionID, uuid.New()), BuildNumber: takeFirst(orig.BuildNumber, 0), @@ -147,7 +208,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m InitiatorID: takeFirst(orig.InitiatorID, uuid.New()), JobID: takeFirst(orig.JobID, uuid.New()), ProvisionerState: takeFirstBytes(orig.ProvisionerState, []byte{}), - Deadline: takeFirstTime(orig.CreatedAt, time.Now().Add(time.Hour)), + Deadline: takeFirstTime(orig.Deadline, time.Now().Add(time.Hour)), Reason: takeFirst(orig.Reason, database.BuildReasonInitiator), }) require.NoError(t, err, "insert workspace build") @@ -160,7 +221,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m Username: takeFirst(orig.Username, namesgenerator.GetRandomName(1)), HashedPassword: takeFirstBytes(orig.HashedPassword, []byte{}), CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), RBACRoles: []string{}, LoginType: takeFirst(orig.LoginType, database.LoginTypePassword), }) @@ -172,9 +233,9 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m org, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{ ID: g.Lookup(name), Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), - Description: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + Description: takeFirst(orig.Description, namesgenerator.GetRandomName(1)), CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), }) require.NoError(t, err, "insert organization") @@ -185,7 +246,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m ID: g.Lookup(name), Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), - AvatarURL: takeFirst(orig.Name, "https://logo.example.com"), + AvatarURL: takeFirst(orig.AvatarURL, "https://logo.example.com"), QuotaAllowance: takeFirst(orig.QuotaAllowance, 0), }) require.NoError(t, err, "insert organization") @@ -196,7 +257,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: g.Lookup(name), CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), InitiatorID: takeFirst(orig.InitiatorID, uuid.New()), Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho), @@ -220,7 +281,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m Type: takeFirst(orig.Type, ""), Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), Hide: takeFirst(orig.Hide, false), - Icon: takeFirst(orig.Name, ""), + Icon: takeFirst(orig.Icon, ""), InstanceType: sql.NullString{ String: takeFirst(orig.InstanceType.String, ""), Valid: takeFirst(orig.InstanceType.Valid, false), diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index c4266646ed442..b93ffc7d0fde8 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -151,24 +151,21 @@ func TestAPIKey(t *testing.T) { t.Run("InvalidSecret", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) + db = databasefake.New() + gen = databasefake.NewGenerator(t, db) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ctx = context.Background() + user = gen.User(ctx, "", database.User{}) + + // Use a different secret so they don't match! + hashed = sha256.Sum256([]byte("differentsecret")) + _, token = gen.APIKey(ctx, "", database.APIKey{ + UserID: user.ID, + HashedSecret: hashed[:], + }) ) - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) - - // Use a different secret so they don't match! - hashed := sha256.Sum256([]byte("differentsecret")) - _, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - UserID: user.ID, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) + r.Header.Set(codersdk.SessionTokenHeader, token) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, RedirectToLogin: false, diff --git a/coderd/httpmw/ratelimit_test.go b/coderd/httpmw/ratelimit_test.go index 61a6f5d903566..cbc6965b02f2a 100644 --- a/coderd/httpmw/ratelimit_test.go +++ b/coderd/httpmw/ratelimit_test.go @@ -94,9 +94,9 @@ func TestRateLimit(t *testing.T) { ctx := context.Background() db := databasefake.New() - - u := createUser(ctx, t, db) - key := insertAPIKey(ctx, t, db, u.ID) + gen := databasefake.NewGenerator(t, db) + u := gen.User(ctx, "", database.User{}) + _, key := gen.APIKey(ctx, "", database.APIKey{UserID: u.ID}) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ @@ -141,11 +141,11 @@ func TestRateLimit(t *testing.T) { db := databasefake.New() - u := createUser(ctx, t, db, func(u *database.InsertUserParams) { - u.RBACRoles = []string{rbac.RoleOwner()} + gen := databasefake.NewGenerator(t, db) + u := gen.User(ctx, "", database.User{ + RBACRoles: []string{rbac.RoleOwner()}, }) - - key := insertAPIKey(ctx, t, db, u.ID) + _, key := gen.APIKey(ctx, "", database.APIKey{UserID: u.ID}) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ From d4fe6b7d4e22b8e3f27b737581ff72803d895070 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 30 Jan 2023 14:08:47 -0600 Subject: [PATCH 04/21] Begin refactoring api key tests with generator --- coderd/httpmw/apikey_test.go | 124 ++++++++++++++++------------------- 1 file changed, 55 insertions(+), 69 deletions(-) diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index b93ffc7d0fde8..7cb1a31c95fb2 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -151,10 +151,11 @@ func TestAPIKey(t *testing.T) { t.Run("InvalidSecret", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - gen = databasefake.NewGenerator(t, db) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() + db = databasefake.New() + gen = databasefake.NewGenerator(t, db) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ctx = context.Background() user = gen.User(ctx, "", database.User{}) @@ -178,23 +179,20 @@ func TestAPIKey(t *testing.T) { t.Run("Expired", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) + db = databasefake.New() + gen = databasefake.NewGenerator(t, db) + ctx = context.Background() + user = gen.User(ctx, "", database.User{}) + _, token = gen.APIKey(ctx, "", database.APIKey{ + UserID: user.ID, + ExpiresAt: time.Now().Add(time.Hour * -1), + }) + + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) + r.Header.Set(codersdk.SessionTokenHeader, token) - _, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - UserID: user.ID, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, RedirectToLogin: false, @@ -207,24 +205,20 @@ func TestAPIKey(t *testing.T) { t.Run("Valid", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) + db = databasefake.New() + gen = databasefake.NewGenerator(t, db) + ctx = context.Background() + user = gen.User(ctx, "", database.User{}) + sentAPIKey, token = gen.APIKey(ctx, "", database.APIKey{ + UserID: user.ID, + ExpiresAt: database.Now().AddDate(0, 0, 1), + }) + + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) + r.Header.Set(codersdk.SessionTokenHeader, token) - sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - ExpiresAt: database.Now().AddDate(0, 0, 1), - UserID: user.ID, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, RedirectToLogin: false, @@ -239,7 +233,7 @@ func TestAPIKey(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) - gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) require.NoError(t, err) require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) @@ -248,27 +242,23 @@ func TestAPIKey(t *testing.T) { t.Run("ValidWithScope", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) + db = databasefake.New() + gen = databasefake.NewGenerator(t, db) + ctx = context.Background() + user = gen.User(ctx, "", database.User{}) + _, token = gen.APIKey(ctx, "", database.APIKey{ + UserID: user.ID, + ExpiresAt: database.Now().AddDate(0, 0, 1), + Scope: database.APIKeyScopeApplicationConnect, + }) + + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) r.AddCookie(&http.Cookie{ Name: codersdk.SessionTokenCookie, - Value: fmt.Sprintf("%s-%s", id, secret), - }) - - _, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - UserID: user.ID, - HashedSecret: hashed[:], - ExpiresAt: database.Now().AddDate(0, 0, 1), - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeApplicationConnect, + Value: token, }) - require.NoError(t, err) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, @@ -291,26 +281,22 @@ func TestAPIKey(t *testing.T) { t.Run("QueryParameter", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) + db = databasefake.New() + gen = databasefake.NewGenerator(t, db) + ctx = context.Background() + user = gen.User(ctx, "", database.User{}) + _, token = gen.APIKey(ctx, "", database.APIKey{ + UserID: user.ID, + ExpiresAt: database.Now().AddDate(0, 0, 1), + }) + + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) q := r.URL.Query() - q.Add(codersdk.SessionTokenCookie, fmt.Sprintf("%s-%s", id, secret)) + q.Add(codersdk.SessionTokenCookie, token) r.URL.RawQuery = q.Encode() - _, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - ExpiresAt: database.Now().AddDate(0, 0, 1), - UserID: user.ID, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, RedirectToLogin: false, From d423ba2c9a14623b958e64cd4cfb9bfbef5e2a02 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 30 Jan 2023 14:39:49 -0600 Subject: [PATCH 05/21] Refactor api tests to use fake generator --- coderd/database/databasefake/generator.go | 68 +++-- coderd/httpmw/apikey_test.go | 266 ++++++++----------- coderd/httpmw/groupparam_test.go | 2 +- coderd/httpmw/ratelimit_test.go | 8 +- coderd/httpmw/workspaceresourceparam_test.go | 6 +- 5 files changed, 160 insertions(+), 190 deletions(-) diff --git a/coderd/database/databasefake/generator.go b/coderd/database/databasefake/generator.go index 62ae347d5c9de..667e329409477 100644 --- a/coderd/database/databasefake/generator.go +++ b/coderd/database/databasefake/generator.go @@ -44,7 +44,7 @@ func NewGenerator(t *testing.T, db database.Store) *Generator { // specified. func (g *Generator) PrimaryOrg(ctx context.Context) database.Organization { if g.primaryOrg == nil { - org := g.Organization(ctx, "primary-org", database.Organization{ + org := g.Organization(ctx, database.Organization{ ID: g.Lookup(primaryOrgName), Name: primaryOrgName, Description: "This is the default primary organization for all tests", @@ -81,11 +81,8 @@ func (g *Generator) RandomName() string { } } -func (g *Generator) APIKey(ctx context.Context, name string, seed database.APIKey) (key database.APIKey, token string) { - if name == "" { - name = g.RandomName() - } - +func (g *Generator) APIKey(ctx context.Context, seed database.APIKey) (key database.APIKey, token string) { + name := g.RandomName() out := g.Populate(ctx, map[string]interface{}{ name: seed, }) @@ -96,40 +93,44 @@ func (g *Generator) APIKey(ctx context.Context, name string, seed database.APIKe return key, fmt.Sprintf("%s-%s", key.ID, secret) } -func (g *Generator) WorkspaceResource(ctx context.Context, name string, seed database.WorkspaceResource) database.WorkspaceResource { - return populate(ctx, g, name, seed) +func (g *Generator) UserLink(ctx context.Context, seed database.UserLink) database.UserLink { + return populate(ctx, g, "", seed) +} + +func (g *Generator) WorkspaceResource(ctx context.Context, seed database.WorkspaceResource) database.WorkspaceResource { + return populate(ctx, g, "", seed) } -func (g *Generator) Job(ctx context.Context, name string, seed database.ProvisionerJob) database.ProvisionerJob { - return populate(ctx, g, name, seed) +func (g *Generator) Job(ctx context.Context, seed database.ProvisionerJob) database.ProvisionerJob { + return populate(ctx, g, "", seed) } -func (g *Generator) Group(ctx context.Context, name string, seed database.Group) database.Group { - return populate(ctx, g, name, seed) +func (g *Generator) Group(ctx context.Context, seed database.Group) database.Group { + return populate(ctx, g, "", seed) } -func (g *Generator) Organization(ctx context.Context, name string, seed database.Organization) database.Organization { - return populate(ctx, g, name, seed) +func (g *Generator) Organization(ctx context.Context, seed database.Organization) database.Organization { + return populate(ctx, g, "", seed) } -func (g *Generator) Workspace(ctx context.Context, name string, seed database.Workspace) database.Workspace { - return populate(ctx, g, name, seed) +func (g *Generator) Workspace(ctx context.Context, seed database.Workspace) database.Workspace { + return populate(ctx, g, "", seed) } -func (g *Generator) Template(ctx context.Context, name string, seed database.Template) database.Template { - return populate(ctx, g, name, seed) +func (g *Generator) Template(ctx context.Context, seed database.Template) database.Template { + return populate(ctx, g, "", seed) } -func (g *Generator) TemplateVersion(ctx context.Context, name string, seed database.TemplateVersion) database.TemplateVersion { - return populate(ctx, g, name, seed) +func (g *Generator) TemplateVersion(ctx context.Context, seed database.TemplateVersion) database.TemplateVersion { + return populate(ctx, g, "", seed) } -func (g *Generator) WorkspaceBuild(ctx context.Context, name string, seed database.WorkspaceBuild) database.WorkspaceBuild { - return populate(ctx, g, name, seed) +func (g *Generator) WorkspaceBuild(ctx context.Context, seed database.WorkspaceBuild) database.WorkspaceBuild { + return populate(ctx, g, "", seed) } -func (g *Generator) User(ctx context.Context, name string, seed database.User) database.User { - return populate(ctx, g, name, seed) +func (g *Generator) User(ctx context.Context, seed database.User) database.User { + return populate(ctx, g, "", seed) } func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) map[string]interface{} { @@ -144,8 +145,9 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m hashed := sha256.Sum256([]byte(secret)) key, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ - ID: takeFirst(orig.ID, id), - LifetimeSeconds: takeFirst(orig.LifetimeSeconds, 3600), + ID: takeFirst(orig.ID, id), + // 0 defaults to 86400 at the db layer + LifetimeSeconds: takeFirst(orig.LifetimeSeconds, 0), HashedSecret: takeFirstBytes(orig.HashedSecret, hashed[:]), IPAddress: pqtype.Inet{}, UserID: takeFirst(orig.UserID, uuid.New()), @@ -291,6 +293,20 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m require.NoError(t, err, "insert resource") seed[name] = resource + + case database.UserLink: + link, err := db.InsertUserLink(ctx, database.InsertUserLinkParams{ + UserID: takeFirst(orig.UserID, uuid.New()), + LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub), + LinkedID: takeFirst(orig.LinkedID), + OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), + OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), + OAuthExpiry: takeFirstTime(orig.OAuthExpiry, time.Now().Add(time.Hour*24)), + }) + + require.NoError(t, err, "insert link") + + seed[name] = link } } return seed diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index 7cb1a31c95fb2..656cd34f540d3 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -11,7 +11,6 @@ import ( "testing" "time" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -157,11 +156,11 @@ func TestAPIKey(t *testing.T) { rw = httptest.NewRecorder() ctx = context.Background() - user = gen.User(ctx, "", database.User{}) + user = gen.User(ctx, database.User{}) // Use a different secret so they don't match! hashed = sha256.Sum256([]byte("differentsecret")) - _, token = gen.APIKey(ctx, "", database.APIKey{ + _, token = gen.APIKey(ctx, database.APIKey{ UserID: user.ID, HashedSecret: hashed[:], }) @@ -182,8 +181,8 @@ func TestAPIKey(t *testing.T) { db = databasefake.New() gen = databasefake.NewGenerator(t, db) ctx = context.Background() - user = gen.User(ctx, "", database.User{}) - _, token = gen.APIKey(ctx, "", database.APIKey{ + user = gen.User(ctx, database.User{}) + _, token = gen.APIKey(ctx, database.APIKey{ UserID: user.ID, ExpiresAt: time.Now().Add(time.Hour * -1), }) @@ -208,8 +207,8 @@ func TestAPIKey(t *testing.T) { db = databasefake.New() gen = databasefake.NewGenerator(t, db) ctx = context.Background() - user = gen.User(ctx, "", database.User{}) - sentAPIKey, token = gen.APIKey(ctx, "", database.APIKey{ + user = gen.User(ctx, database.User{}) + sentAPIKey, token = gen.APIKey(ctx, database.APIKey{ UserID: user.ID, ExpiresAt: database.Now().AddDate(0, 0, 1), }) @@ -245,8 +244,8 @@ func TestAPIKey(t *testing.T) { db = databasefake.New() gen = databasefake.NewGenerator(t, db) ctx = context.Background() - user = gen.User(ctx, "", database.User{}) - _, token = gen.APIKey(ctx, "", database.APIKey{ + user = gen.User(ctx, database.User{}) + _, token = gen.APIKey(ctx, database.APIKey{ UserID: user.ID, ExpiresAt: database.Now().AddDate(0, 0, 1), Scope: database.APIKeyScopeApplicationConnect, @@ -284,8 +283,8 @@ func TestAPIKey(t *testing.T) { db = databasefake.New() gen = databasefake.NewGenerator(t, db) ctx = context.Background() - user = gen.User(ctx, "", database.User{}) - _, token = gen.APIKey(ctx, "", database.APIKey{ + user = gen.User(ctx, database.User{}) + _, token = gen.APIKey(ctx, database.APIKey{ UserID: user.ID, ExpiresAt: database.Now().AddDate(0, 0, 1), }) @@ -315,25 +314,21 @@ func TestAPIKey(t *testing.T) { t.Run("ValidUpdateLastUsed", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) + db = databasefake.New() + gen = databasefake.NewGenerator(t, db) + ctx = context.Background() + user = gen.User(ctx, database.User{}) + sentAPIKey, token = gen.APIKey(ctx, database.APIKey{ + UserID: user.ID, + LastUsed: database.Now().AddDate(0, 0, -1), + ExpiresAt: database.Now().AddDate(0, 0, 1), + }) + + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) + r.Header.Set(codersdk.SessionTokenHeader, token) - sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - LastUsed: database.Now().AddDate(0, 0, -1), - ExpiresAt: database.Now().AddDate(0, 0, 1), - UserID: user.ID, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, RedirectToLogin: false, @@ -342,7 +337,7 @@ func TestAPIKey(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) - gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) require.NoError(t, err) require.NotEqual(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) @@ -352,25 +347,21 @@ func TestAPIKey(t *testing.T) { t.Run("ValidUpdateExpiry", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) + db = databasefake.New() + gen = databasefake.NewGenerator(t, db) + ctx = context.Background() + user = gen.User(ctx, database.User{}) + sentAPIKey, token = gen.APIKey(ctx, database.APIKey{ + UserID: user.ID, + LastUsed: database.Now(), + ExpiresAt: database.Now().Add(time.Minute), + }) + + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) + r.Header.Set(codersdk.SessionTokenHeader, token) - sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - LastUsed: database.Now(), - ExpiresAt: database.Now().Add(time.Minute), - UserID: user.ID, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, RedirectToLogin: false, @@ -379,7 +370,7 @@ func TestAPIKey(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) - gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) require.NoError(t, err) require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) @@ -389,31 +380,25 @@ func TestAPIKey(t *testing.T) { t.Run("OAuthNotExpired", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) - ) - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) - - sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - LoginType: database.LoginTypeGithub, - LastUsed: database.Now(), - ExpiresAt: database.Now().AddDate(0, 0, 1), - UserID: user.ID, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) + db = databasefake.New() + gen = databasefake.NewGenerator(t, db) + ctx = context.Background() + user = gen.User(ctx, database.User{}) + sentAPIKey, token = gen.APIKey(ctx, database.APIKey{ + UserID: user.ID, + LastUsed: database.Now(), + ExpiresAt: database.Now().AddDate(0, 0, 1), + LoginType: database.LoginTypeGithub, + }) + _ = gen.UserLink(ctx, database.UserLink{ + UserID: user.ID, + LoginType: database.LoginTypeGithub, + }) - _, err = db.InsertUserLink(r.Context(), database.InsertUserLinkParams{ - UserID: user.ID, - LoginType: database.LoginTypeGithub, - }) - require.NoError(t, err) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.Header.Set(codersdk.SessionTokenHeader, token) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, @@ -423,7 +408,7 @@ func TestAPIKey(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) - gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) require.NoError(t, err) require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) @@ -433,33 +418,29 @@ func TestAPIKey(t *testing.T) { t.Run("OAuthRefresh", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) - ) - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) + db = databasefake.New() + gen = databasefake.NewGenerator(t, db) + ctx = context.Background() + user = gen.User(ctx, database.User{}) + sentAPIKey, token = gen.APIKey(ctx, database.APIKey{ + UserID: user.ID, + LastUsed: database.Now(), + ExpiresAt: database.Now().AddDate(0, 0, 1), + LoginType: database.LoginTypeGithub, + }) + _ = gen.UserLink(ctx, database.UserLink{ + UserID: user.ID, + LoginType: database.LoginTypeGithub, + OAuthRefreshToken: "hello", + OAuthExpiry: database.Now().AddDate(0, 0, -1), + }) - sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - LoginType: database.LoginTypeGithub, - LastUsed: database.Now(), - UserID: user.ID, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) - _, err = db.InsertUserLink(r.Context(), database.InsertUserLinkParams{ - UserID: user.ID, - LoginType: database.LoginTypeGithub, - OAuthExpiry: database.Now().AddDate(0, 0, -1), - OAuthRefreshToken: "hello", - }) - require.NoError(t, err) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.Header.Set(codersdk.SessionTokenHeader, token) - token := &oauth2.Token{ + oauthToken := &oauth2.Token{ AccessToken: "wow", RefreshToken: "moo", Expiry: database.Now().AddDate(0, 0, 1), @@ -469,7 +450,7 @@ func TestAPIKey(t *testing.T) { OAuth2Configs: &httpmw.OAuth2Configs{ Github: &oauth2Config{ tokenSource: oauth2TokenSource(func() (*oauth2.Token, error) { - return token, nil + return oauthToken, nil }), }, }, @@ -479,36 +460,32 @@ func TestAPIKey(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) - gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) require.NoError(t, err) require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) - require.Equal(t, token.Expiry, gotAPIKey.ExpiresAt) + require.Equal(t, oauthToken.Expiry, gotAPIKey.ExpiresAt) }) t.Run("RemoteIPUpdates", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) + db = databasefake.New() + gen = databasefake.NewGenerator(t, db) + ctx = context.Background() + user = gen.User(ctx, database.User{}) + sentAPIKey, token = gen.APIKey(ctx, database.APIKey{ + UserID: user.ID, + LastUsed: database.Now().AddDate(0, 0, -1), + ExpiresAt: database.Now().AddDate(0, 0, 1), + }) + + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() ) r.RemoteAddr = "1.1.1.1" - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) + r.Header.Set(codersdk.SessionTokenHeader, token) - _, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - LastUsed: database.Now().AddDate(0, 0, -1), - ExpiresAt: database.Now().AddDate(0, 0, 1), - UserID: user.ID, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, RedirectToLogin: false, @@ -517,7 +494,7 @@ func TestAPIKey(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) - gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) require.NoError(t, err) require.Equal(t, net.ParseIP("1.1.1.1"), gotAPIKey.IPAddress.IPNet.IP) @@ -578,25 +555,21 @@ func TestAPIKey(t *testing.T) { t.Run("Tokens", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - id, secret = randomAPIKeyParts() - hashed = sha256.Sum256([]byte(secret)) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - user = createUser(r.Context(), t, db) - ) - r.Header.Set(codersdk.SessionTokenHeader, fmt.Sprintf("%s-%s", id, secret)) + db = databasefake.New() + gen = databasefake.NewGenerator(t, db) + ctx = context.Background() + user = gen.User(ctx, database.User{}) + sentAPIKey, token = gen.APIKey(ctx, database.APIKey{ + UserID: user.ID, + LastUsed: database.Now(), + ExpiresAt: database.Now().AddDate(0, 0, 1), + LoginType: database.LoginTypeToken, + }) - sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - LoginType: database.LoginTypeToken, - LastUsed: database.Now(), - ExpiresAt: database.Now().AddDate(0, 0, 1), - UserID: user.ID, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.Header.Set(codersdk.SessionTokenHeader, token) httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: db, @@ -606,7 +579,7 @@ func TestAPIKey(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) - gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID) require.NoError(t, err) require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed) @@ -615,25 +588,6 @@ func TestAPIKey(t *testing.T) { }) } -func createUser(ctx context.Context, t *testing.T, db database.Store, opts ...func(u *database.InsertUserParams)) database.User { - insert := database.InsertUserParams{ - ID: uuid.New(), - Email: "email@coder.com", - Username: "username", - HashedPassword: []byte{}, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - RBACRoles: []string{}, - LoginType: database.LoginTypePassword, - } - for _, opt := range opts { - opt(&insert) - } - user, err := db.InsertUser(ctx, insert) - require.NoError(t, err, "create user") - return user -} - type oauth2Config struct { tokenSource oauth2TokenSource } diff --git a/coderd/httpmw/groupparam_test.go b/coderd/httpmw/groupparam_test.go index 409c204e1abd6..b726dfd8e45c0 100644 --- a/coderd/httpmw/groupparam_test.go +++ b/coderd/httpmw/groupparam_test.go @@ -26,7 +26,7 @@ func TestGroupParam(t *testing.T) { db := databasefake.New() gen := databasefake.NewGenerator(t, db) - group := gen.Group(ctx, "group", database.Group{}) + group := gen.Group(ctx, database.Group{}) return db, group } diff --git a/coderd/httpmw/ratelimit_test.go b/coderd/httpmw/ratelimit_test.go index cbc6965b02f2a..dd6d461d7f16c 100644 --- a/coderd/httpmw/ratelimit_test.go +++ b/coderd/httpmw/ratelimit_test.go @@ -95,8 +95,8 @@ func TestRateLimit(t *testing.T) { db := databasefake.New() gen := databasefake.NewGenerator(t, db) - u := gen.User(ctx, "", database.User{}) - _, key := gen.APIKey(ctx, "", database.APIKey{UserID: u.ID}) + u := gen.User(ctx, database.User{}) + _, key := gen.APIKey(ctx, database.APIKey{UserID: u.ID}) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ @@ -142,10 +142,10 @@ func TestRateLimit(t *testing.T) { db := databasefake.New() gen := databasefake.NewGenerator(t, db) - u := gen.User(ctx, "", database.User{ + u := gen.User(ctx, database.User{ RBACRoles: []string{rbac.RoleOwner()}, }) - _, key := gen.APIKey(ctx, "", database.APIKey{UserID: u.ID}) + _, key := gen.APIKey(ctx, database.APIKey{UserID: u.ID}) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ diff --git a/coderd/httpmw/workspaceresourceparam_test.go b/coderd/httpmw/workspaceresourceparam_test.go index 7de9edc54c46c..a858ea6497a56 100644 --- a/coderd/httpmw/workspaceresourceparam_test.go +++ b/coderd/httpmw/workspaceresourceparam_test.go @@ -22,19 +22,19 @@ func TestWorkspaceResourceParam(t *testing.T) { r := httptest.NewRequest("GET", "/", nil) ctx := context.Background() gen := databasefake.NewGenerator(t, db) - job := gen.Job(ctx, "", database.ProvisionerJob{ + job := gen.Job(ctx, database.ProvisionerJob{ Type: jobType, Provisioner: database.ProvisionerTypeEcho, StorageMethod: database.ProvisionerStorageMethodFile, }) - build := gen.WorkspaceBuild(ctx, "", database.WorkspaceBuild{ + build := gen.WorkspaceBuild(ctx, database.WorkspaceBuild{ JobID: job.ID, Transition: database.WorkspaceTransitionStart, Reason: database.BuildReasonInitiator, }) - resource := gen.WorkspaceResource(ctx, "", database.WorkspaceResource{ + resource := gen.WorkspaceResource(ctx, database.WorkspaceResource{ JobID: job.ID, Transition: database.WorkspaceTransitionStart, }) From fac054aa9e6649828ee386dbb0cba94f8a0004d2 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 30 Jan 2023 14:51:10 -0600 Subject: [PATCH 06/21] Refactor workspace apps test --- coderd/workspaceapps_internal_test.go | 40 ++++++++++++--------------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/coderd/workspaceapps_internal_test.go b/coderd/workspaceapps_internal_test.go index 1e298000d9dd5..eb6f785e1dd92 100644 --- a/coderd/workspaceapps_internal_test.go +++ b/coderd/workspaceapps_internal_test.go @@ -16,36 +16,22 @@ import ( func TestAPIKeyEncryption(t *testing.T) { t.Parallel() - generateAPIKey := func(t *testing.T, db database.Store) (keyID, keySecret string, hashedSecret []byte, data encryptedAPIKeyPayload) { - keyID, keySecret, err := GenerateAPIKeyIDSecret() - require.NoError(t, err) + generateAPIKey := func(t *testing.T, db database.Store) (keyID, keyToken string, hashedSecret []byte, data encryptedAPIKeyPayload) { + gen := databasefake.NewGenerator(t, db) + key, token := gen.APIKey(context.Background(), database.APIKey{}) - hashedSecretArray := sha256.Sum256([]byte(keySecret)) data = encryptedAPIKeyPayload{ - APIKey: keyID + "-" + keySecret, + APIKey: token, ExpiresAt: database.Now().Add(24 * time.Hour), } - return keyID, keySecret, hashedSecretArray[:], data - } - insertAPIKey := func(t *testing.T, db database.Store, keyID string, hashedSecret []byte) { - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - _, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ - ID: keyID, - HashedSecret: hashedSecret, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) + return key.ID, token, key.HashedSecret[:], data } t.Run("OK", func(t *testing.T) { t.Parallel() db := databasefake.New() keyID, _, hashedSecret, data := generateAPIKey(t, db) - insertAPIKey(t, db, keyID, hashedSecret) encrypted, err := encryptAPIKey(data) require.NoError(t, err) @@ -66,8 +52,7 @@ func TestAPIKeyEncryption(t *testing.T) { t.Run("Expiry", func(t *testing.T) { t.Parallel() db := databasefake.New() - keyID, _, hashedSecret, data := generateAPIKey(t, db) - insertAPIKey(t, db, keyID, hashedSecret) + _, _, _, data := generateAPIKey(t, db) data.ExpiresAt = database.Now().Add(-1 * time.Hour) encrypted, err := encryptAPIKey(data) @@ -84,9 +69,18 @@ func TestAPIKeyEncryption(t *testing.T) { t.Run("KeyMatches", func(t *testing.T) { t.Parallel() db := databasefake.New() - keyID, _, _, data := generateAPIKey(t, db) + + gen := databasefake.NewGenerator(t, db) hashedSecret := sha256.Sum256([]byte("wrong")) - insertAPIKey(t, db, keyID, hashedSecret[:]) + // Insert a token with a mismatched hashed secret. + _, token := gen.APIKey(context.Background(), database.APIKey{ + HashedSecret: hashedSecret[:], + }) + + data := encryptedAPIKeyPayload{ + APIKey: token, + ExpiresAt: database.Now().Add(24 * time.Hour), + } encrypted, err := encryptAPIKey(data) require.NoError(t, err) From 387c1bd0982417d98cafdf2e27a6512ee55c3fa2 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 30 Jan 2023 14:53:26 -0600 Subject: [PATCH 07/21] Remove unused function --- coderd/database/databasefake/generator.go | 19 +++++++++---------- coderd/httpmw/ratelimit_test.go | 20 -------------------- 2 files changed, 9 insertions(+), 30 deletions(-) diff --git a/coderd/database/databasefake/generator.go b/coderd/database/databasefake/generator.go index 667e329409477..165d993f4230e 100644 --- a/coderd/database/databasefake/generator.go +++ b/coderd/database/databasefake/generator.go @@ -8,14 +8,13 @@ import ( "testing" "time" - "github.com/coder/coder/cryptorand" - - "github.com/tabbed/pqtype" - - "github.com/coder/coder/coderd/database" "github.com/google/uuid" "github.com/moby/moby/pkg/namesgenerator" "github.com/stretchr/testify/require" + "github.com/tabbed/pqtype" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/cryptorand" ) const primaryOrgName = "primary-org" @@ -312,19 +311,19 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m return seed } -func (tc *Generator) Lookup(name string) uuid.UUID { +func (g *Generator) Lookup(name string) uuid.UUID { if name == "" { // No name means the caller doesn't care about the ID. return uuid.New() } - if tc.names == nil { - tc.names = make(map[string]uuid.UUID) + if g.names == nil { + g.names = make(map[string]uuid.UUID) } - if id, ok := tc.names[name]; ok { + if id, ok := g.names[name]; ok { return id } id := uuid.New() - tc.names[name] = id + g.names[name] = id return id } diff --git a/coderd/httpmw/ratelimit_test.go b/coderd/httpmw/ratelimit_test.go index dd6d461d7f16c..ec8b30ad03513 100644 --- a/coderd/httpmw/ratelimit_test.go +++ b/coderd/httpmw/ratelimit_test.go @@ -2,7 +2,6 @@ package httpmw_test import ( "context" - "crypto/sha256" "fmt" "math/rand" "net" @@ -12,7 +11,6 @@ import ( "time" "github.com/go-chi/chi/v5" - "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/coder/coder/coderd/database" @@ -23,24 +21,6 @@ import ( "github.com/coder/coder/testutil" ) -func insertAPIKey(ctx context.Context, t *testing.T, db database.Store, userID uuid.UUID) string { - id, secret := randomAPIKeyParts() - hashed := sha256.Sum256([]byte(secret)) - - _, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ - ID: id, - HashedSecret: hashed[:], - LastUsed: database.Now().AddDate(0, 0, -1), - ExpiresAt: database.Now().AddDate(0, 0, 1), - UserID: userID, - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, - }) - require.NoError(t, err) - - return fmt.Sprintf("%s-%s", id, secret) -} - func randRemoteAddr() string { var b [4]byte // nolint:gosec From 50cba39870c858bb52eaba128b8cec3d74cc3474 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 30 Jan 2023 15:21:16 -0600 Subject: [PATCH 08/21] Refactor provsionerdserver tests --- coderd/database/databasefake/generator.go | 55 +++++- .../provisionerdserver_test.go | 182 +++++++----------- 2 files changed, 121 insertions(+), 116 deletions(-) diff --git a/coderd/database/databasefake/generator.go b/coderd/database/databasefake/generator.go index 165d993f4230e..a56f576c3734b 100644 --- a/coderd/database/databasefake/generator.go +++ b/coderd/database/databasefake/generator.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "database/sql" + "encoding/hex" "fmt" "testing" "time" @@ -92,6 +93,10 @@ func (g *Generator) APIKey(ctx context.Context, seed database.APIKey) (key datab return key, fmt.Sprintf("%s-%s", key.ID, secret) } +func (g *Generator) File(ctx context.Context, seed database.File) database.File { + return populate(ctx, g, "", seed) +} + func (g *Generator) UserLink(ctx context.Context, seed database.UserLink) database.UserLink { return populate(ctx, g, "", seed) } @@ -164,7 +169,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m seed[name+"_secret"] = secret case database.Template: template, err := db.InsertTemplate(ctx, database.InsertTemplateParams{ - ID: g.Lookup(name), + ID: takeFirst(orig.ID, g.Lookup(name)), CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), @@ -182,10 +187,30 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m }) require.NoError(t, err, "insert template") + seed[name] = template + + case database.TemplateVersion: + template, err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ + ID: takeFirst(orig.ID, g.Lookup(name)), + TemplateID: uuid.NullUUID{ + UUID: takeFirst(orig.TemplateID.UUID, uuid.New()), + Valid: takeFirst(orig.TemplateID.Valid, true), + }, + OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), + CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + Readme: takeFirst(orig.Readme, namesgenerator.GetRandomName(1)), + JobID: takeFirst(orig.JobID, uuid.New()), + CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), + }) + require.NoError(t, err, "insert template") + seed[name] = template case database.Workspace: workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{ - ID: g.Lookup(name), + ID: takeFirst(orig.ID, g.Lookup(name)), + OwnerID: takeFirst(orig.OwnerID, uuid.New()), CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), @@ -199,7 +224,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m seed[name] = workspace case database.WorkspaceBuild: build, err := db.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ - ID: g.Lookup(name), + ID: takeFirst(orig.ID, g.Lookup(name)), CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), WorkspaceID: takeFirst(orig.WorkspaceID, uuid.New()), @@ -217,7 +242,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m seed[name] = build case database.User: user, err := db.InsertUser(ctx, database.InsertUserParams{ - ID: g.Lookup(name), + ID: takeFirst(orig.ID, g.Lookup(name)), Email: takeFirst(orig.Email, namesgenerator.GetRandomName(1)), Username: takeFirst(orig.Username, namesgenerator.GetRandomName(1)), HashedPassword: takeFirstBytes(orig.HashedPassword, []byte{}), @@ -232,7 +257,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m case database.Organization: org, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{ - ID: g.Lookup(name), + ID: takeFirst(orig.ID, g.Lookup(name)), Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), Description: takeFirst(orig.Description, namesgenerator.GetRandomName(1)), CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), @@ -244,7 +269,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m case database.Group: org, err := db.InsertGroup(ctx, database.InsertGroupParams{ - ID: g.Lookup(name), + ID: takeFirst(orig.ID, g.Lookup(name)), Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), AvatarURL: takeFirst(orig.AvatarURL, "https://logo.example.com"), @@ -256,7 +281,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m case database.ProvisionerJob: job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ - ID: g.Lookup(name), + ID: takeFirst(orig.ID, g.Lookup(name)), CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), @@ -274,7 +299,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m case database.WorkspaceResource: resource, err := db.InsertWorkspaceResource(ctx, database.InsertWorkspaceResourceParams{ - ID: g.Lookup(name), + ID: takeFirst(orig.ID, g.Lookup(name)), CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), JobID: takeFirst(orig.JobID, uuid.New()), Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart), @@ -293,6 +318,18 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m seed[name] = resource + case database.File: + file, err := db.InsertFile(ctx, database.InsertFileParams{ + ID: takeFirst(orig.ID, g.Lookup(name)), + Hash: takeFirst(orig.Hash, hex.EncodeToString(make([]byte, 32))), + CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), + Mimetype: takeFirst(orig.Mimetype, "application/x-tar"), + Data: takeFirstBytes(orig.Data, []byte{}), + }) + require.NoError(t, err, "insert file") + + seed[name] = file case database.UserLink: link, err := db.InsertUserLink(ctx, database.InsertUserLinkParams{ UserID: takeFirst(orig.UserID, uuid.New()), @@ -306,6 +343,8 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m require.NoError(t, err, "insert link") seed[name] = link + default: + panic(fmt.Sprintf("unknown type %T", orig)) } } return seed diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index f9a4d74782647..f840a051c5e29 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -87,36 +87,39 @@ func TestAcquireJob(t *testing.T) { t.Parallel() srv := setup(t, false) ctx := context.Background() - user, err := srv.Database.InsertUser(context.Background(), database.InsertUserParams{ - ID: uuid.New(), - Username: "testing", - LoginType: database.LoginTypePassword, - }) - require.NoError(t, err) - template, err := srv.Database.InsertTemplate(ctx, database.InsertTemplateParams{ - ID: uuid.New(), + + gen := databasefake.NewGenerator(t, srv.Database) + user := gen.User(ctx, database.User{}) + template := gen.Template(ctx, database.Template{ Name: "template", Provisioner: database.ProvisionerTypeEcho, }) - require.NoError(t, err) - version, err := srv.Database.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ - ID: uuid.New(), + file := gen.File(ctx, database.File{CreatedBy: user.ID}) + versionFile := gen.File(ctx, database.File{CreatedBy: user.ID}) + version := gen.TemplateVersion(ctx, database.TemplateVersion{ TemplateID: uuid.NullUUID{ UUID: template.ID, Valid: true, }, JobID: uuid.New(), }) - require.NoError(t, err) - workspace, err := srv.Database.InsertWorkspace(ctx, database.InsertWorkspaceParams{ - ID: uuid.New(), - OwnerID: user.ID, + // Import version job + _ = gen.Job(ctx, database.ProvisionerJob{ + ID: version.JobID, + InitiatorID: user.ID, + FileID: versionFile.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + Type: database.ProvisionerJobTypeTemplateVersionImport, + Input: must(json.Marshal(provisionerdserver.TemplateVersionImportJob{ + TemplateVersionID: version.ID, + })), + }) + workspace := gen.Workspace(ctx, database.Workspace{ TemplateID: template.ID, - Name: "workspace", + OwnerID: user.ID, }) - require.NoError(t, err) - build, err := srv.Database.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ - ID: uuid.New(), + build := gen.WorkspaceBuild(ctx, database.WorkspaceBuild{ WorkspaceID: workspace.ID, BuildNumber: 1, JobID: uuid.New(), @@ -124,33 +127,17 @@ func TestAcquireJob(t *testing.T) { Transition: database.WorkspaceTransitionStart, Reason: database.BuildReasonInitiator, }) - require.NoError(t, err) - - data, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{ - WorkspaceBuildID: build.ID, - }) - require.NoError(t, err) - - file, err := srv.Database.InsertFile(ctx, database.InsertFileParams{ - ID: uuid.New(), - Hash: "something", - Data: []byte{}, - }) - require.NoError(t, err) - - _, err = srv.Database.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ - ID: build.JobID, - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - OrganizationID: uuid.New(), - InitiatorID: user.ID, - Provisioner: database.ProvisionerTypeEcho, - StorageMethod: database.ProvisionerStorageMethodFile, - FileID: file.ID, - Type: database.ProvisionerJobTypeWorkspaceBuild, - Input: data, + _ = gen.Job(ctx, database.ProvisionerJob{ + ID: build.ID, + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeWorkspaceBuild, + Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{ + WorkspaceBuildID: build.ID, + })), }) - require.NoError(t, err) published := make(chan struct{}) closeSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) { @@ -159,8 +146,17 @@ func TestAcquireJob(t *testing.T) { require.NoError(t, err) defer closeSubscribe() - job, err := srv.AcquireJob(ctx, nil) - require.NoError(t, err) + var job *proto.AcquiredJob + + for { + // Grab jobs until we find the workspace build job. There is also + // an import version job that we need to ignore. + job, err = srv.AcquireJob(ctx, nil) + require.NoError(t, err) + if _, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_); ok { + break + } + } <-published @@ -191,44 +187,23 @@ func TestAcquireJob(t *testing.T) { t.Parallel() srv := setup(t, false) ctx := context.Background() - user, err := srv.Database.InsertUser(ctx, database.InsertUserParams{ - ID: uuid.New(), - Username: "testing", - LoginType: database.LoginTypePassword, - }) - require.NoError(t, err) - version, err := srv.Database.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ - ID: uuid.New(), - }) - require.NoError(t, err) - data, err := json.Marshal(provisionerdserver.TemplateVersionDryRunJob{ - TemplateVersionID: version.ID, - WorkspaceName: "testing", - ParameterValues: []database.ParameterValue{}, - }) - require.NoError(t, err) - - file, err := srv.Database.InsertFile(ctx, database.InsertFileParams{ - ID: uuid.New(), - Hash: "something", - Data: []byte{}, - }) - require.NoError(t, err) - - _, err = srv.Database.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ - ID: uuid.New(), - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - OrganizationID: uuid.New(), - InitiatorID: user.ID, - Provisioner: database.ProvisionerTypeEcho, - StorageMethod: database.ProvisionerStorageMethodFile, - FileID: file.ID, - Type: database.ProvisionerJobTypeTemplateVersionDryRun, - Input: data, + gen := databasefake.NewGenerator(t, srv.Database) + user := gen.User(ctx, database.User{}) + version := gen.TemplateVersion(ctx, database.TemplateVersion{}) + file := gen.File(ctx, database.File{CreatedBy: user.ID}) + _ = gen.Job(ctx, database.ProvisionerJob{ + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + FileID: file.ID, + Type: database.ProvisionerJobTypeTemplateVersionDryRun, + Input: must(json.Marshal(provisionerdserver.TemplateVersionDryRunJob{ + TemplateVersionID: version.ID, + WorkspaceName: "testing", + ParameterValues: []database.ParameterValue{}, + })), }) - require.NoError(t, err) job, err := srv.AcquireJob(ctx, nil) require.NoError(t, err) @@ -252,33 +227,17 @@ func TestAcquireJob(t *testing.T) { t.Parallel() srv := setup(t, false) ctx := context.Background() - user, err := srv.Database.InsertUser(ctx, database.InsertUserParams{ - ID: uuid.New(), - Username: "testing", - LoginType: database.LoginTypePassword, - }) - require.NoError(t, err) - - file, err := srv.Database.InsertFile(ctx, database.InsertFileParams{ - ID: uuid.New(), - Hash: "something", - Data: []byte{}, - }) - require.NoError(t, err) - _, err = srv.Database.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ - ID: uuid.New(), - CreatedAt: database.Now(), - UpdatedAt: database.Now(), - OrganizationID: uuid.New(), - InitiatorID: user.ID, - Provisioner: database.ProvisionerTypeEcho, - StorageMethod: database.ProvisionerStorageMethodFile, - FileID: file.ID, - Type: database.ProvisionerJobTypeTemplateVersionImport, - Input: json.RawMessage{}, + gen := databasefake.NewGenerator(t, srv.Database) + user := gen.User(ctx, database.User{}) + file := gen.File(ctx, database.File{CreatedBy: user.ID}) + _ = gen.Job(ctx, database.ProvisionerJob{ + FileID: file.ID, + InitiatorID: user.ID, + Provisioner: database.ProvisionerTypeEcho, + StorageMethod: database.ProvisionerStorageMethodFile, + Type: database.ProvisionerJobTypeTemplateVersionImport, }) - require.NoError(t, err) job, err := srv.AcquireJob(ctx, nil) require.NoError(t, err) @@ -855,3 +814,10 @@ func setup(t *testing.T, ignoreLogErrors bool) *provisionerdserver.Server { Auditor: mockAuditor(), } } + +func must[T any](value T, err error) T { + if err != nil { + panic(err) + } + return value +} From 396aa4a48abda163c0aadceb369e003ed202ad19 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 30 Jan 2023 15:26:41 -0600 Subject: [PATCH 09/21] Refactor prom metrics test --- coderd/database/databasefake/generator.go | 29 ++++++----- .../prometheusmetrics_test.go | 52 ++++++++----------- 2 files changed, 36 insertions(+), 45 deletions(-) diff --git a/coderd/database/databasefake/generator.go b/coderd/database/databasefake/generator.go index a56f576c3734b..ec028bd569b45 100644 --- a/coderd/database/databasefake/generator.go +++ b/coderd/database/databasefake/generator.go @@ -141,6 +141,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m db := g.db t := g.testT + output := make(map[string]interface{}) for name, v := range seed { switch orig := v.(type) { case database.APIKey: @@ -164,9 +165,9 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m }) require.NoError(t, err, "insert api key") - seed[name] = key + output[name] = key // Need to also save the secret - seed[name+"_secret"] = secret + output[name+"_secret"] = secret case database.Template: template, err := db.InsertTemplate(ctx, database.InsertTemplateParams{ ID: takeFirst(orig.ID, g.Lookup(name)), @@ -187,7 +188,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m }) require.NoError(t, err, "insert template") - seed[name] = template + output[name] = template case database.TemplateVersion: template, err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ @@ -206,7 +207,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m }) require.NoError(t, err, "insert template") - seed[name] = template + output[name] = template case database.Workspace: workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{ ID: takeFirst(orig.ID, g.Lookup(name)), @@ -221,7 +222,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m }) require.NoError(t, err, "insert workspace") - seed[name] = workspace + output[name] = workspace case database.WorkspaceBuild: build, err := db.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ ID: takeFirst(orig.ID, g.Lookup(name)), @@ -239,7 +240,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m }) require.NoError(t, err, "insert workspace build") - seed[name] = build + output[name] = build case database.User: user, err := db.InsertUser(ctx, database.InsertUserParams{ ID: takeFirst(orig.ID, g.Lookup(name)), @@ -253,7 +254,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m }) require.NoError(t, err, "insert user") - seed[name] = user + output[name] = user case database.Organization: org, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{ @@ -265,7 +266,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m }) require.NoError(t, err, "insert organization") - seed[name] = org + output[name] = org case database.Group: org, err := db.InsertGroup(ctx, database.InsertGroupParams{ @@ -277,7 +278,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m }) require.NoError(t, err, "insert organization") - seed[name] = org + output[name] = org case database.ProvisionerJob: job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ @@ -295,7 +296,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m }) require.NoError(t, err, "insert job") - seed[name] = job + output[name] = job case database.WorkspaceResource: resource, err := db.InsertWorkspaceResource(ctx, database.InsertWorkspaceResourceParams{ @@ -316,7 +317,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m }) require.NoError(t, err, "insert resource") - seed[name] = resource + output[name] = resource case database.File: file, err := db.InsertFile(ctx, database.InsertFileParams{ @@ -329,7 +330,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m }) require.NoError(t, err, "insert file") - seed[name] = file + output[name] = file case database.UserLink: link, err := db.InsertUserLink(ctx, database.InsertUserLinkParams{ UserID: takeFirst(orig.UserID, uuid.New()), @@ -342,12 +343,12 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m require.NoError(t, err, "insert link") - seed[name] = link + output[name] = link default: panic(fmt.Sprintf("unknown type %T", orig)) } } - return seed + return output } func (g *Generator) Lookup(name string) uuid.UUID { diff --git a/coderd/prometheusmetrics/prometheusmetrics_test.go b/coderd/prometheusmetrics/prometheusmetrics_test.go index 0ebf7fa09228c..74c382ad22986 100644 --- a/coderd/prometheusmetrics/prometheusmetrics_test.go +++ b/coderd/prometheusmetrics/prometheusmetrics_test.go @@ -23,63 +23,53 @@ func TestActiveUsers(t *testing.T) { for _, tc := range []struct { Name string - Database func() database.Store + Database func(t *testing.T) database.Store Count int }{{ Name: "None", - Database: func() database.Store { + Database: func(t *testing.T) database.Store { return databasefake.New() }, Count: 0, }, { Name: "One", - Database: func() database.Store { + Database: func(t *testing.T) database.Store { db := databasefake.New() - _, _ = db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ - UserID: uuid.New(), - LastUsed: database.Now(), - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, + gen := databasefake.NewGenerator(t, db) + gen.APIKey(context.Background(), database.APIKey{ + LastUsed: database.Now(), }) return db }, Count: 1, }, { Name: "OneWithExpired", - Database: func() database.Store { + Database: func(t *testing.T) database.Store { db := databasefake.New() - _, _ = db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ - UserID: uuid.New(), - LastUsed: database.Now(), - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, + gen := databasefake.NewGenerator(t, db) + + gen.APIKey(context.Background(), database.APIKey{ + LastUsed: database.Now(), }) + // Because this API key hasn't been used in the past hour, this shouldn't // add to the user count. - _, _ = db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ - UserID: uuid.New(), - LastUsed: database.Now().Add(-2 * time.Hour), - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, + gen.APIKey(context.Background(), database.APIKey{ + LastUsed: database.Now().Add(-2 * time.Hour), }) return db }, Count: 1, }, { Name: "Multiple", - Database: func() database.Store { + Database: func(t *testing.T) database.Store { db := databasefake.New() - _, _ = db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ - UserID: uuid.New(), - LastUsed: database.Now(), - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, + gen := databasefake.NewGenerator(t, db) + gen.APIKey(context.Background(), database.APIKey{ + LastUsed: database.Now(), }) - _, _ = db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ - UserID: uuid.New(), - LastUsed: database.Now(), - LoginType: database.LoginTypePassword, - Scope: database.APIKeyScopeAll, + gen.APIKey(context.Background(), database.APIKey{ + LastUsed: database.Now(), }) return db }, @@ -89,7 +79,7 @@ func TestActiveUsers(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { t.Parallel() registry := prometheus.NewRegistry() - cancel, err := prometheusmetrics.ActiveUsers(context.Background(), registry, tc.Database(), time.Millisecond) + cancel, err := prometheusmetrics.ActiveUsers(context.Background(), registry, tc.Database(t), time.Millisecond) require.NoError(t, err) t.Cleanup(cancel) From 564849988a4e7a8e2bf897693e83090672e5f174 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 30 Jan 2023 15:41:10 -0600 Subject: [PATCH 10/21] Add unit tests for generator --- .../database/databasefake/generator_test.go | 98 +++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 coderd/database/databasefake/generator_test.go diff --git a/coderd/database/databasefake/generator_test.go b/coderd/database/databasefake/generator_test.go new file mode 100644 index 0000000000000..ddbd7a7a41c09 --- /dev/null +++ b/coderd/database/databasefake/generator_test.go @@ -0,0 +1,98 @@ +package databasefake_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/databasefake" +) + +func TestGenerator(t *testing.T) { + t.Parallel() + + // Reuse the same database for all tests. + db := databasefake.New() + gen := databasefake.NewGenerator(t, db) + + t.Run("APIKey", func(t *testing.T) { + t.Parallel() + exp, _ := gen.APIKey(context.Background(), database.APIKey{}) + require.Equal(t, exp, must(db.GetAPIKeyByID(context.Background(), exp.ID))) + }) + + t.Run("File", func(t *testing.T) { + t.Parallel() + exp := gen.File(context.Background(), database.File{}) + require.Equal(t, exp, must(db.GetFileByID(context.Background(), exp.ID))) + }) + + t.Run("UserLink", func(t *testing.T) { + t.Parallel() + exp := gen.UserLink(context.Background(), database.UserLink{}) + require.Equal(t, exp, must(db.GetUserLinkByLinkedID(context.Background(), exp.LinkedID))) + }) + + t.Run("WorkspaceResource", func(t *testing.T) { + t.Parallel() + exp := gen.WorkspaceResource(context.Background(), database.WorkspaceResource{}) + require.Equal(t, exp, must(db.GetWorkspaceResourceByID(context.Background(), exp.ID))) + }) + + t.Run("Job", func(t *testing.T) { + t.Parallel() + exp := gen.Job(context.Background(), database.ProvisionerJob{}) + require.Equal(t, exp, must(db.GetProvisionerJobByID(context.Background(), exp.ID))) + }) + + t.Run("Group", func(t *testing.T) { + t.Parallel() + exp := gen.Group(context.Background(), database.Group{}) + require.Equal(t, exp, must(db.GetGroupByID(context.Background(), exp.ID))) + }) + + t.Run("Organization", func(t *testing.T) { + t.Parallel() + exp := gen.Organization(context.Background(), database.Organization{}) + require.Equal(t, exp, must(db.GetOrganizationByID(context.Background(), exp.ID))) + }) + + t.Run("Workspace", func(t *testing.T) { + t.Parallel() + exp := gen.Workspace(context.Background(), database.Workspace{}) + require.Equal(t, exp, must(db.GetWorkspaceByID(context.Background(), exp.ID))) + }) + + t.Run("Template", func(t *testing.T) { + t.Parallel() + exp := gen.Template(context.Background(), database.Template{}) + require.Equal(t, exp, must(db.GetTemplateByID(context.Background(), exp.ID))) + }) + + t.Run("TemplateVersion", func(t *testing.T) { + t.Parallel() + exp := gen.TemplateVersion(context.Background(), database.TemplateVersion{}) + require.Equal(t, exp, must(db.GetTemplateVersionByID(context.Background(), exp.ID))) + }) + + t.Run("WorkspaceBuild", func(t *testing.T) { + t.Parallel() + exp := gen.WorkspaceBuild(context.Background(), database.WorkspaceBuild{}) + require.Equal(t, exp, must(db.GetWorkspaceBuildByID(context.Background(), exp.ID))) + }) + + t.Run("User", func(t *testing.T) { + t.Parallel() + exp := gen.User(context.Background(), database.User{}) + require.Equal(t, exp, must(db.GetUserByID(context.Background(), exp.ID))) + }) +} + +func must[T any](value T, err error) T { + if err != nil { + panic(err) + } + return value +} From 298ca3e7044306a2bbc07cd4f1dc9314f41559d2 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 31 Jan 2023 09:51:07 -0600 Subject: [PATCH 11/21] Remove takeFirstTime --- coderd/database/databasefake/generator.go | 63 +++++++++++------------ 1 file changed, 30 insertions(+), 33 deletions(-) diff --git a/coderd/database/databasefake/generator.go b/coderd/database/databasefake/generator.go index ec028bd569b45..fef081464d86a 100644 --- a/coderd/database/databasefake/generator.go +++ b/coderd/database/databasefake/generator.go @@ -137,6 +137,8 @@ func (g *Generator) User(ctx context.Context, seed database.User) database.User return populate(ctx, g, "", seed) } +// Populate uses `require` which calls `t.FailNow()` and must be called from the +// go routine running the test or benchmark function. func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) map[string]interface{} { db := g.db t := g.testT @@ -156,10 +158,10 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m HashedSecret: takeFirstBytes(orig.HashedSecret, hashed[:]), IPAddress: pqtype.Inet{}, UserID: takeFirst(orig.UserID, uuid.New()), - LastUsed: takeFirstTime(orig.LastUsed, time.Now()), - ExpiresAt: takeFirstTime(orig.ExpiresAt, time.Now().Add(time.Hour)), - CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), + LastUsed: takeFirst(orig.LastUsed, time.Now()), + ExpiresAt: takeFirst(orig.ExpiresAt, time.Now().Add(time.Hour)), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), LoginType: takeFirst(orig.LoginType, database.LoginTypePassword), Scope: takeFirst(orig.Scope, database.APIKeyScopeAll), }) @@ -171,8 +173,8 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m case database.Template: template, err := db.InsertTemplate(ctx, database.InsertTemplateParams{ ID: takeFirst(orig.ID, g.Lookup(name)), - CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho), @@ -198,8 +200,8 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m Valid: takeFirst(orig.TemplateID.Valid, true), }, OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), - CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), Readme: takeFirst(orig.Readme, namesgenerator.GetRandomName(1)), JobID: takeFirst(orig.JobID, uuid.New()), @@ -212,8 +214,8 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{ ID: takeFirst(orig.ID, g.Lookup(name)), OwnerID: takeFirst(orig.OwnerID, uuid.New()), - CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), TemplateID: takeFirst(orig.TemplateID, uuid.New()), Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), @@ -226,8 +228,8 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m case database.WorkspaceBuild: build, err := db.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ ID: takeFirst(orig.ID, g.Lookup(name)), - CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), WorkspaceID: takeFirst(orig.WorkspaceID, uuid.New()), TemplateVersionID: takeFirst(orig.TemplateVersionID, uuid.New()), BuildNumber: takeFirst(orig.BuildNumber, 0), @@ -235,7 +237,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m InitiatorID: takeFirst(orig.InitiatorID, uuid.New()), JobID: takeFirst(orig.JobID, uuid.New()), ProvisionerState: takeFirstBytes(orig.ProvisionerState, []byte{}), - Deadline: takeFirstTime(orig.Deadline, time.Now().Add(time.Hour)), + Deadline: takeFirst(orig.Deadline, time.Now().Add(time.Hour)), Reason: takeFirst(orig.Reason, database.BuildReasonInitiator), }) require.NoError(t, err, "insert workspace build") @@ -247,8 +249,8 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m Email: takeFirst(orig.Email, namesgenerator.GetRandomName(1)), Username: takeFirst(orig.Username, namesgenerator.GetRandomName(1)), HashedPassword: takeFirstBytes(orig.HashedPassword, []byte{}), - CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), RBACRoles: []string{}, LoginType: takeFirst(orig.LoginType, database.LoginTypePassword), }) @@ -261,8 +263,8 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m ID: takeFirst(orig.ID, g.Lookup(name)), Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), Description: takeFirst(orig.Description, namesgenerator.GetRandomName(1)), - CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), }) require.NoError(t, err, "insert organization") @@ -283,8 +285,8 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m case database.ProvisionerJob: job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: takeFirst(orig.ID, g.Lookup(name)), - CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirstTime(orig.UpdatedAt, time.Now()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), InitiatorID: takeFirst(orig.InitiatorID, uuid.New()), Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho), @@ -301,14 +303,13 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m case database.WorkspaceResource: resource, err := db.InsertWorkspaceResource(ctx, database.InsertWorkspaceResourceParams{ ID: takeFirst(orig.ID, g.Lookup(name)), - CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), JobID: takeFirst(orig.JobID, uuid.New()), Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart), - // TODO: What type to put here? - Type: takeFirst(orig.Type, ""), - Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), - Hide: takeFirst(orig.Hide, false), - Icon: takeFirst(orig.Icon, ""), + Type: takeFirst(orig.Type, "fake_resource"), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + Hide: takeFirst(orig.Hide, false), + Icon: takeFirst(orig.Icon, ""), InstanceType: sql.NullString{ String: takeFirst(orig.InstanceType.String, ""), Valid: takeFirst(orig.InstanceType.Valid, false), @@ -323,7 +324,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m file, err := db.InsertFile(ctx, database.InsertFileParams{ ID: takeFirst(orig.ID, g.Lookup(name)), Hash: takeFirst(orig.Hash, hex.EncodeToString(make([]byte, 32))), - CreatedAt: takeFirstTime(orig.CreatedAt, time.Now()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), Mimetype: takeFirst(orig.Mimetype, "application/x-tar"), Data: takeFirstBytes(orig.Data, []byte{}), @@ -338,7 +339,7 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m LinkedID: takeFirst(orig.LinkedID), OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), - OAuthExpiry: takeFirstTime(orig.OAuthExpiry, time.Now().Add(time.Hour*24)), + OAuthExpiry: takeFirst(orig.OAuthExpiry, time.Now().Add(time.Hour*24)), }) require.NoError(t, err, "insert link") @@ -367,12 +368,8 @@ func (g *Generator) Lookup(name string) uuid.UUID { return id } -func takeFirstTime(values ...time.Time) time.Time { - return takeFirstF(values, func(v time.Time) bool { - return !v.IsZero() - }) -} - +// takeFirstBytes implements takeFirst for []byte. +// []byte is not a comparable type. func takeFirstBytes(values ...[]byte) []byte { return takeFirstF(values, func(v []byte) bool { return len(v) != 0 From 3bfd4d6ce6a6f99f0961e4c8110d004093c729b4 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 31 Jan 2023 09:56:03 -0600 Subject: [PATCH 12/21] Add more randomness to names --- coderd/database/databasefake/generator.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/database/databasefake/generator.go b/coderd/database/databasefake/generator.go index fef081464d86a..50536fe25d3f3 100644 --- a/coderd/database/databasefake/generator.go +++ b/coderd/database/databasefake/generator.go @@ -74,7 +74,7 @@ func populate[DBType any](ctx context.Context, g *Generator, name string, seed D func (g *Generator) RandomName() string { for { - name := namesgenerator.GetRandomName(0) + name := namesgenerator.GetRandomName(1) if _, ok := g.names[name]; !ok { return name } From 7fc8ead043344ea7465ce18f3d1f7562f8fd5bf2 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 31 Jan 2023 09:59:21 -0600 Subject: [PATCH 13/21] Use fatal over panic --- coderd/database/databasefake/generator.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/coderd/database/databasefake/generator.go b/coderd/database/databasefake/generator.go index 50536fe25d3f3..1c0e94b356d27 100644 --- a/coderd/database/databasefake/generator.go +++ b/coderd/database/databasefake/generator.go @@ -140,6 +140,7 @@ func (g *Generator) User(ctx context.Context, seed database.User) database.User // Populate uses `require` which calls `t.FailNow()` and must be called from the // go routine running the test or benchmark function. func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) map[string]interface{} { + g.testT.Helper() db := g.db t := g.testT @@ -346,7 +347,8 @@ func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) m output[name] = link default: - panic(fmt.Sprintf("unknown type %T", orig)) + // If you hit this, just add your type to the switch. + t.Fatalf("unknown type '%T' used in fake data generator", orig) } } return output From 5ba346377e28a432899f28f80a4179de606bc648 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 31 Jan 2023 10:06:14 -0600 Subject: [PATCH 14/21] Remove use of panics --- coderd/database/databasefake/generator.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/coderd/database/databasefake/generator.go b/coderd/database/databasefake/generator.go index 1c0e94b356d27..60bf0f643cdc8 100644 --- a/coderd/database/databasefake/generator.go +++ b/coderd/database/databasefake/generator.go @@ -31,7 +31,9 @@ type Generator struct { func NewGenerator(t *testing.T, db database.Store) *Generator { if _, ok := db.(FakeDatabase); !ok { - panic("Generator db must be a FakeDatabase") + // This does not work for postgres databases because of foreign key + // constraints + t.Fatalf("Generator db must be a FakeDatabase") } return &Generator{ names: make(map[string]uuid.UUID), @@ -58,6 +60,7 @@ func (g *Generator) PrimaryOrg(ctx context.Context) database.Organization { } func populate[DBType any](ctx context.Context, g *Generator, name string, seed DBType) DBType { + g.testT.Helper() if name == "" { name = g.RandomName() } @@ -67,7 +70,7 @@ func populate[DBType any](ctx context.Context, g *Generator, name string, seed D }) v, ok := out[name].(DBType) if !ok { - panic("developer error, type mismatch") + g.testT.Fatalf("developer error, type mismatch in data generator") } return v } From bc88e41f4c7b1ecc16ee625041621ef88a37e5a6 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 31 Jan 2023 12:43:53 -0600 Subject: [PATCH 15/21] Remove all state from generator --- coderd/database/databasefake/generator.go | 579 ++++++++---------- .../database/databasefake/generator_test.go | 40 +- 2 files changed, 268 insertions(+), 351 deletions(-) diff --git a/coderd/database/databasefake/generator.go b/coderd/database/databasefake/generator.go index 60bf0f643cdc8..2c3b6eda0045b 100644 --- a/coderd/database/databasefake/generator.go +++ b/coderd/database/databasefake/generator.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "database/sql" "encoding/hex" - "fmt" "testing" "time" @@ -18,359 +17,269 @@ import ( "github.com/coder/coder/cryptorand" ) -const primaryOrgName = "primary-org" - -type Generator struct { - // names is a map of names to uuids. - names map[string]uuid.UUID - primaryOrg *database.Organization - testT *testing.T - - db database.Store +type Supported interface { + database.APIKey | generatedAPIKey | + database.File | + database.UserLink | + database.WorkspaceResource | + database.ProvisionerJob | + database.Group | + database.Organization | + database.Workspace | + database.Template | + database.TemplateVersion | + database.WorkspaceBuild | + database.User } -func NewGenerator(t *testing.T, db database.Store) *Generator { - if _, ok := db.(FakeDatabase); !ok { - // This does not work for postgres databases because of foreign key - // constraints - t.Fatalf("Generator db must be a FakeDatabase") - } - return &Generator{ - names: make(map[string]uuid.UUID), - testT: t, - db: db, - } +type generatedAPIKey struct { + Secret string + Key database.APIKey } -// PrimaryOrg is to keep all resources in the same default org if not -// specified. -func (g *Generator) PrimaryOrg(ctx context.Context) database.Organization { - if g.primaryOrg == nil { - org := g.Organization(ctx, database.Organization{ - ID: g.Lookup(primaryOrgName), - Name: primaryOrgName, - Description: "This is the default primary organization for all tests", - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - }) - g.primaryOrg = &org - } - - return *g.primaryOrg -} - -func populate[DBType any](ctx context.Context, g *Generator, name string, seed DBType) DBType { - g.testT.Helper() - if name == "" { - name = g.RandomName() - } - - out := g.Populate(ctx, map[string]interface{}{ - name: seed, +// GenerateAPIKey is a special case that allows returning the secret for the +// api key. +func GenerateAPIKey(t *testing.T, db database.Store, seed database.APIKey) (key database.APIKey, secret string) { + out := generate(t, db, generatedAPIKey{ + Key: seed, }) - v, ok := out[name].(DBType) + v, ok := out.(generatedAPIKey) if !ok { - g.testT.Fatalf("developer error, type mismatch in data generator") + t.Fatalf("Returned type '%T' doses not match expected '%T'", out, generatedAPIKey{}) } - return v + return v.Key, v.Secret } -func (g *Generator) RandomName() string { - for { - name := namesgenerator.GetRandomName(1) - if _, ok := g.names[name]; !ok { - return name - } +func Generate[Object Supported](t *testing.T, db database.Store, seed Object) Object { + out := generate(t, db, seed) + v, ok := out.(Object) + if !ok { + var empty Object + t.Fatalf("Returned type '%T' doses not match expected '%T'", out, empty) } + return v } -func (g *Generator) APIKey(ctx context.Context, seed database.APIKey) (key database.APIKey, token string) { - name := g.RandomName() - out := g.Populate(ctx, map[string]interface{}{ - name: seed, - }) - key, keyOk := out[name].(database.APIKey) - secret, secOk := out[name+"_secret"].(string) - require.True(g.testT, keyOk && secOk, "APIKey & secret must be populated with the right type") - - return key, fmt.Sprintf("%s-%s", key.ID, secret) -} - -func (g *Generator) File(ctx context.Context, seed database.File) database.File { - return populate(ctx, g, "", seed) -} - -func (g *Generator) UserLink(ctx context.Context, seed database.UserLink) database.UserLink { - return populate(ctx, g, "", seed) -} - -func (g *Generator) WorkspaceResource(ctx context.Context, seed database.WorkspaceResource) database.WorkspaceResource { - return populate(ctx, g, "", seed) -} - -func (g *Generator) Job(ctx context.Context, seed database.ProvisionerJob) database.ProvisionerJob { - return populate(ctx, g, "", seed) -} - -func (g *Generator) Group(ctx context.Context, seed database.Group) database.Group { - return populate(ctx, g, "", seed) -} - -func (g *Generator) Organization(ctx context.Context, seed database.Organization) database.Organization { - return populate(ctx, g, "", seed) -} - -func (g *Generator) Workspace(ctx context.Context, seed database.Workspace) database.Workspace { - return populate(ctx, g, "", seed) -} - -func (g *Generator) Template(ctx context.Context, seed database.Template) database.Template { - return populate(ctx, g, "", seed) -} - -func (g *Generator) TemplateVersion(ctx context.Context, seed database.TemplateVersion) database.TemplateVersion { - return populate(ctx, g, "", seed) -} - -func (g *Generator) WorkspaceBuild(ctx context.Context, seed database.WorkspaceBuild) database.WorkspaceBuild { - return populate(ctx, g, "", seed) -} - -func (g *Generator) User(ctx context.Context, seed database.User) database.User { - return populate(ctx, g, "", seed) -} - -// Populate uses `require` which calls `t.FailNow()` and must be called from the -// go routine running the test or benchmark function. -func (g *Generator) Populate(ctx context.Context, seed map[string]interface{}) map[string]interface{} { - g.testT.Helper() - db := g.db - t := g.testT - - output := make(map[string]interface{}) - for name, v := range seed { - switch orig := v.(type) { - case database.APIKey: - id, _ := cryptorand.String(10) - secret, _ := cryptorand.String(22) - hashed := sha256.Sum256([]byte(secret)) - - key, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ - ID: takeFirst(orig.ID, id), - // 0 defaults to 86400 at the db layer - LifetimeSeconds: takeFirst(orig.LifetimeSeconds, 0), - HashedSecret: takeFirstBytes(orig.HashedSecret, hashed[:]), - IPAddress: pqtype.Inet{}, - UserID: takeFirst(orig.UserID, uuid.New()), - LastUsed: takeFirst(orig.LastUsed, time.Now()), - ExpiresAt: takeFirst(orig.ExpiresAt, time.Now().Add(time.Hour)), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), - LoginType: takeFirst(orig.LoginType, database.LoginTypePassword), - Scope: takeFirst(orig.Scope, database.APIKeyScopeAll), - }) - require.NoError(t, err, "insert api key") - - output[name] = key - // Need to also save the secret - output[name+"_secret"] = secret - case database.Template: - template, err := db.InsertTemplate(ctx, database.InsertTemplateParams{ - ID: takeFirst(orig.ID, g.Lookup(name)), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), - OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), - Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), - Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho), - ActiveVersionID: takeFirst(orig.ActiveVersionID, uuid.New()), - Description: takeFirst(orig.Description, namesgenerator.GetRandomName(1)), - DefaultTTL: takeFirst(orig.DefaultTTL, 3600), - CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), - Icon: takeFirst(orig.Icon, namesgenerator.GetRandomName(1)), - UserACL: orig.UserACL, - GroupACL: orig.GroupACL, - DisplayName: takeFirst(orig.DisplayName, namesgenerator.GetRandomName(1)), - AllowUserCancelWorkspaceJobs: takeFirst(orig.AllowUserCancelWorkspaceJobs, true), - }) - require.NoError(t, err, "insert template") - - output[name] = template - - case database.TemplateVersion: - template, err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ - ID: takeFirst(orig.ID, g.Lookup(name)), - TemplateID: uuid.NullUUID{ - UUID: takeFirst(orig.TemplateID.UUID, uuid.New()), - Valid: takeFirst(orig.TemplateID.Valid, true), - }, - OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), - Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), - Readme: takeFirst(orig.Readme, namesgenerator.GetRandomName(1)), - JobID: takeFirst(orig.JobID, uuid.New()), - CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), - }) - require.NoError(t, err, "insert template") - - output[name] = template - case database.Workspace: - workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{ - ID: takeFirst(orig.ID, g.Lookup(name)), - OwnerID: takeFirst(orig.OwnerID, uuid.New()), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), - OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), - TemplateID: takeFirst(orig.TemplateID, uuid.New()), - Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), - AutostartSchedule: orig.AutostartSchedule, - Ttl: orig.Ttl, - }) - require.NoError(t, err, "insert workspace") - - output[name] = workspace - case database.WorkspaceBuild: - build, err := db.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ - ID: takeFirst(orig.ID, g.Lookup(name)), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), - WorkspaceID: takeFirst(orig.WorkspaceID, uuid.New()), - TemplateVersionID: takeFirst(orig.TemplateVersionID, uuid.New()), - BuildNumber: takeFirst(orig.BuildNumber, 0), - Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart), - InitiatorID: takeFirst(orig.InitiatorID, uuid.New()), - JobID: takeFirst(orig.JobID, uuid.New()), - ProvisionerState: takeFirstBytes(orig.ProvisionerState, []byte{}), - Deadline: takeFirst(orig.Deadline, time.Now().Add(time.Hour)), - Reason: takeFirst(orig.Reason, database.BuildReasonInitiator), - }) - require.NoError(t, err, "insert workspace build") - - output[name] = build - case database.User: - user, err := db.InsertUser(ctx, database.InsertUserParams{ - ID: takeFirst(orig.ID, g.Lookup(name)), - Email: takeFirst(orig.Email, namesgenerator.GetRandomName(1)), - Username: takeFirst(orig.Username, namesgenerator.GetRandomName(1)), - HashedPassword: takeFirstBytes(orig.HashedPassword, []byte{}), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), - RBACRoles: []string{}, - LoginType: takeFirst(orig.LoginType, database.LoginTypePassword), - }) - require.NoError(t, err, "insert user") +func generate(t *testing.T, db database.Store, seed interface{}) interface{} { + t.Helper() - output[name] = user - - case database.Organization: - org, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{ - ID: takeFirst(orig.ID, g.Lookup(name)), - Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), - Description: takeFirst(orig.Description, namesgenerator.GetRandomName(1)), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), - }) - require.NoError(t, err, "insert organization") - - output[name] = org - - case database.Group: - org, err := db.InsertGroup(ctx, database.InsertGroupParams{ - ID: takeFirst(orig.ID, g.Lookup(name)), - Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), - OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), - AvatarURL: takeFirst(orig.AvatarURL, "https://logo.example.com"), - QuotaAllowance: takeFirst(orig.QuotaAllowance, 0), - }) - require.NoError(t, err, "insert organization") - - output[name] = org - - case database.ProvisionerJob: - job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ - ID: takeFirst(orig.ID, g.Lookup(name)), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), - OrganizationID: takeFirst(orig.OrganizationID, g.PrimaryOrg(ctx).ID), - InitiatorID: takeFirst(orig.InitiatorID, uuid.New()), - Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho), - StorageMethod: takeFirst(orig.StorageMethod, database.ProvisionerStorageMethodFile), - FileID: takeFirst(orig.FileID, uuid.New()), - Type: takeFirst(orig.Type, database.ProvisionerJobTypeWorkspaceBuild), - Input: takeFirstBytes(orig.Input, []byte("{}")), - Tags: orig.Tags, - }) - require.NoError(t, err, "insert job") - - output[name] = job - - case database.WorkspaceResource: - resource, err := db.InsertWorkspaceResource(ctx, database.InsertWorkspaceResourceParams{ - ID: takeFirst(orig.ID, g.Lookup(name)), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - JobID: takeFirst(orig.JobID, uuid.New()), - Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart), - Type: takeFirst(orig.Type, "fake_resource"), - Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), - Hide: takeFirst(orig.Hide, false), - Icon: takeFirst(orig.Icon, ""), - InstanceType: sql.NullString{ - String: takeFirst(orig.InstanceType.String, ""), - Valid: takeFirst(orig.InstanceType.Valid, false), - }, - DailyCost: takeFirst(orig.DailyCost, 0), - }) - require.NoError(t, err, "insert resource") - - output[name] = resource - - case database.File: - file, err := db.InsertFile(ctx, database.InsertFileParams{ - ID: takeFirst(orig.ID, g.Lookup(name)), - Hash: takeFirst(orig.Hash, hex.EncodeToString(make([]byte, 32))), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), - Mimetype: takeFirst(orig.Mimetype, "application/x-tar"), - Data: takeFirstBytes(orig.Data, []byte{}), - }) - require.NoError(t, err, "insert file") + if _, ok := db.(FakeDatabase); !ok { + // This does not work for postgres databases because of foreign key + // constraints + t.Fatalf("Generate() db must be a FakeDatabase") + } - output[name] = file - case database.UserLink: - link, err := db.InsertUserLink(ctx, database.InsertUserLinkParams{ - UserID: takeFirst(orig.UserID, uuid.New()), - LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub), - LinkedID: takeFirst(orig.LinkedID), - OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), - OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), - OAuthExpiry: takeFirst(orig.OAuthExpiry, time.Now().Add(time.Hour*24)), - }) + // db fake doesn't use contexts anyway. + ctx := context.Background() + + switch orig := seed.(type) { + case database.APIKey, generatedAPIKey: + // Annoying, but we need a way to return the secret if + // the caller needs it. + var g generatedAPIKey + v, isKey := seed.(database.APIKey) + if isKey { + g = generatedAPIKey{ + Key: v, + } + } else { + var ok bool + g, ok = seed.(generatedAPIKey) + if !ok { + t.Fatalf("type '%T' unsupported", seed) + } + } - require.NoError(t, err, "insert link") + id, _ := cryptorand.String(10) + secret, _ := cryptorand.String(22) + hashed := sha256.Sum256([]byte(secret)) + + key, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: takeFirst(g.Key.ID, id), + // 0 defaults to 86400 at the db layer + LifetimeSeconds: takeFirst(g.Key.LifetimeSeconds, 0), + HashedSecret: takeFirstBytes(g.Key.HashedSecret, hashed[:]), + IPAddress: pqtype.Inet{}, + UserID: takeFirst(g.Key.UserID, uuid.New()), + LastUsed: takeFirst(g.Key.LastUsed, time.Now()), + ExpiresAt: takeFirst(g.Key.ExpiresAt, time.Now().Add(time.Hour)), + CreatedAt: takeFirst(g.Key.CreatedAt, time.Now()), + UpdatedAt: takeFirst(g.Key.UpdatedAt, time.Now()), + LoginType: takeFirst(g.Key.LoginType, database.LoginTypePassword), + Scope: takeFirst(g.Key.Scope, database.APIKeyScopeAll), + }) + require.NoError(t, err, "insert api key") + g.Key = key + g.Secret = secret - output[name] = link - default: - // If you hit this, just add your type to the switch. - t.Fatalf("unknown type '%T' used in fake data generator", orig) + if isKey { + return g.Key } - } - return output -} + return g + case database.Template: + template, err := db.InsertTemplate(ctx, database.InsertTemplateParams{ + ID: takeFirst(orig.ID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho), + ActiveVersionID: takeFirst(orig.ActiveVersionID, uuid.New()), + Description: takeFirst(orig.Description, namesgenerator.GetRandomName(1)), + DefaultTTL: takeFirst(orig.DefaultTTL, 3600), + CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), + Icon: takeFirst(orig.Icon, namesgenerator.GetRandomName(1)), + UserACL: orig.UserACL, + GroupACL: orig.GroupACL, + DisplayName: takeFirst(orig.DisplayName, namesgenerator.GetRandomName(1)), + AllowUserCancelWorkspaceJobs: takeFirst(orig.AllowUserCancelWorkspaceJobs, true), + }) + require.NoError(t, err, "insert template") + return template + case database.TemplateVersion: + version, err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ + ID: takeFirst(orig.ID, uuid.New()), + TemplateID: uuid.NullUUID{ + UUID: takeFirst(orig.TemplateID.UUID, uuid.New()), + Valid: takeFirst(orig.TemplateID.Valid, true), + }, + OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + Readme: takeFirst(orig.Readme, namesgenerator.GetRandomName(1)), + JobID: takeFirst(orig.JobID, uuid.New()), + CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), + }) + require.NoError(t, err, "insert template version") + return version + case database.Workspace: + workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{ + ID: takeFirst(orig.ID, uuid.New()), + OwnerID: takeFirst(orig.OwnerID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), + TemplateID: takeFirst(orig.TemplateID, uuid.New()), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + AutostartSchedule: orig.AutostartSchedule, + Ttl: orig.Ttl, + }) + require.NoError(t, err, "insert workspace") + return workspace + case database.WorkspaceBuild: + build, err := db.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ + ID: takeFirst(orig.ID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + WorkspaceID: takeFirst(orig.WorkspaceID, uuid.New()), + TemplateVersionID: takeFirst(orig.TemplateVersionID, uuid.New()), + BuildNumber: takeFirst(orig.BuildNumber, 0), + Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart), + InitiatorID: takeFirst(orig.InitiatorID, uuid.New()), + JobID: takeFirst(orig.JobID, uuid.New()), + ProvisionerState: takeFirstBytes(orig.ProvisionerState, []byte{}), + Deadline: takeFirst(orig.Deadline, time.Now().Add(time.Hour)), + Reason: takeFirst(orig.Reason, database.BuildReasonInitiator), + }) + require.NoError(t, err, "insert workspace build") + return build + case database.User: + user, err := db.InsertUser(ctx, database.InsertUserParams{ + ID: takeFirst(orig.ID, uuid.New()), + Email: takeFirst(orig.Email, namesgenerator.GetRandomName(1)), + Username: takeFirst(orig.Username, namesgenerator.GetRandomName(1)), + HashedPassword: takeFirstBytes(orig.HashedPassword, []byte{}), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + RBACRoles: []string{}, + LoginType: takeFirst(orig.LoginType, database.LoginTypePassword), + }) + require.NoError(t, err, "insert user") + return user + case database.Organization: + org, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{ + ID: takeFirst(orig.ID, uuid.New()), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + Description: takeFirst(orig.Description, namesgenerator.GetRandomName(1)), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + }) + require.NoError(t, err, "insert organization") + return org + case database.Group: + group, err := db.InsertGroup(ctx, database.InsertGroupParams{ + ID: takeFirst(orig.ID, uuid.New()), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), + AvatarURL: takeFirst(orig.AvatarURL, "https://logo.example.com"), + QuotaAllowance: takeFirst(orig.QuotaAllowance, 0), + }) + require.NoError(t, err, "insert group") + return group + case database.ProvisionerJob: + job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ + ID: takeFirst(orig.ID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), + InitiatorID: takeFirst(orig.InitiatorID, uuid.New()), + Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho), + StorageMethod: takeFirst(orig.StorageMethod, database.ProvisionerStorageMethodFile), + FileID: takeFirst(orig.FileID, uuid.New()), + Type: takeFirst(orig.Type, database.ProvisionerJobTypeWorkspaceBuild), + Input: takeFirstBytes(orig.Input, []byte("{}")), + Tags: orig.Tags, + }) + require.NoError(t, err, "insert job") + return job + case database.WorkspaceResource: + resource, err := db.InsertWorkspaceResource(ctx, database.InsertWorkspaceResourceParams{ + ID: takeFirst(orig.ID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + JobID: takeFirst(orig.JobID, uuid.New()), + Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart), + Type: takeFirst(orig.Type, "fake_resource"), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + Hide: takeFirst(orig.Hide, false), + Icon: takeFirst(orig.Icon, ""), + InstanceType: sql.NullString{ + String: takeFirst(orig.InstanceType.String, ""), + Valid: takeFirst(orig.InstanceType.Valid, false), + }, + DailyCost: takeFirst(orig.DailyCost, 0), + }) + require.NoError(t, err, "insert resource") + return resource + case database.File: + file, err := db.InsertFile(ctx, database.InsertFileParams{ + ID: takeFirst(orig.ID, uuid.New()), + Hash: takeFirst(orig.Hash, hex.EncodeToString(make([]byte, 32))), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), + Mimetype: takeFirst(orig.Mimetype, "application/x-tar"), + Data: takeFirstBytes(orig.Data, []byte{}), + }) + require.NoError(t, err, "insert file") + return file + case database.UserLink: + link, err := db.InsertUserLink(ctx, database.InsertUserLinkParams{ + UserID: takeFirst(orig.UserID, uuid.New()), + LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub), + LinkedID: takeFirst(orig.LinkedID), + OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), + OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), + OAuthExpiry: takeFirst(orig.OAuthExpiry, time.Now().Add(time.Hour*24)), + }) -func (g *Generator) Lookup(name string) uuid.UUID { - if name == "" { - // No name means the caller doesn't care about the ID. - return uuid.New() - } - if g.names == nil { - g.names = make(map[string]uuid.UUID) - } - if id, ok := g.names[name]; ok { - return id + require.NoError(t, err, "insert link") + return link + default: + // If you hit this, just add your type to the switch. + t.Fatalf("unknown type '%T' used in fake data generator", orig) + // This line will never be hit, but the compiler does not know that :/ + return nil } - id := uuid.New() - g.names[name] = id - return id } // takeFirstBytes implements takeFirst for []byte. diff --git a/coderd/database/databasefake/generator_test.go b/coderd/database/databasefake/generator_test.go index ddbd7a7a41c09..05de4b560a999 100644 --- a/coderd/database/databasefake/generator_test.go +++ b/coderd/database/databasefake/generator_test.go @@ -13,79 +13,87 @@ import ( func TestGenerator(t *testing.T) { t.Parallel() - // Reuse the same database for all tests. - db := databasefake.New() - gen := databasefake.NewGenerator(t, db) - t.Run("APIKey", func(t *testing.T) { t.Parallel() - exp, _ := gen.APIKey(context.Background(), database.APIKey{}) + db := databasefake.New() + exp, _ := databasefake.GenerateAPIKey(t, db, database.APIKey{}) require.Equal(t, exp, must(db.GetAPIKeyByID(context.Background(), exp.ID))) }) t.Run("File", func(t *testing.T) { t.Parallel() - exp := gen.File(context.Background(), database.File{}) + db := databasefake.New() + exp := databasefake.Generate(t, db, database.File{}) require.Equal(t, exp, must(db.GetFileByID(context.Background(), exp.ID))) }) t.Run("UserLink", func(t *testing.T) { t.Parallel() - exp := gen.UserLink(context.Background(), database.UserLink{}) + db := databasefake.New() + exp := databasefake.Generate(t, db, database.UserLink{}) require.Equal(t, exp, must(db.GetUserLinkByLinkedID(context.Background(), exp.LinkedID))) }) t.Run("WorkspaceResource", func(t *testing.T) { t.Parallel() - exp := gen.WorkspaceResource(context.Background(), database.WorkspaceResource{}) + db := databasefake.New() + exp := databasefake.Generate(t, db, database.WorkspaceResource{}) require.Equal(t, exp, must(db.GetWorkspaceResourceByID(context.Background(), exp.ID))) }) t.Run("Job", func(t *testing.T) { t.Parallel() - exp := gen.Job(context.Background(), database.ProvisionerJob{}) + db := databasefake.New() + exp := databasefake.Generate(t, db, database.ProvisionerJob{}) require.Equal(t, exp, must(db.GetProvisionerJobByID(context.Background(), exp.ID))) }) t.Run("Group", func(t *testing.T) { t.Parallel() - exp := gen.Group(context.Background(), database.Group{}) + db := databasefake.New() + exp := databasefake.Generate(t, db, database.Group{}) require.Equal(t, exp, must(db.GetGroupByID(context.Background(), exp.ID))) }) t.Run("Organization", func(t *testing.T) { t.Parallel() - exp := gen.Organization(context.Background(), database.Organization{}) + db := databasefake.New() + exp := databasefake.Generate(t, db, database.Organization{}) require.Equal(t, exp, must(db.GetOrganizationByID(context.Background(), exp.ID))) }) t.Run("Workspace", func(t *testing.T) { t.Parallel() - exp := gen.Workspace(context.Background(), database.Workspace{}) + db := databasefake.New() + exp := databasefake.Generate(t, db, database.Workspace{}) require.Equal(t, exp, must(db.GetWorkspaceByID(context.Background(), exp.ID))) }) t.Run("Template", func(t *testing.T) { t.Parallel() - exp := gen.Template(context.Background(), database.Template{}) + db := databasefake.New() + exp := databasefake.Generate(t, db, database.Template{}) require.Equal(t, exp, must(db.GetTemplateByID(context.Background(), exp.ID))) }) t.Run("TemplateVersion", func(t *testing.T) { t.Parallel() - exp := gen.TemplateVersion(context.Background(), database.TemplateVersion{}) + db := databasefake.New() + exp := databasefake.Generate(t, db, database.TemplateVersion{}) require.Equal(t, exp, must(db.GetTemplateVersionByID(context.Background(), exp.ID))) }) t.Run("WorkspaceBuild", func(t *testing.T) { t.Parallel() - exp := gen.WorkspaceBuild(context.Background(), database.WorkspaceBuild{}) + db := databasefake.New() + exp := databasefake.Generate(t, db, database.WorkspaceBuild{}) require.Equal(t, exp, must(db.GetWorkspaceBuildByID(context.Background(), exp.ID))) }) t.Run("User", func(t *testing.T) { t.Parallel() - exp := gen.User(context.Background(), database.User{}) + db := databasefake.New() + exp := databasefake.Generate(t, db, database.User{}) require.Equal(t, exp, must(db.GetUserByID(context.Background(), exp.ID))) }) } From 1113d8f2a8b5929d298d1d793ca1292e87e79f00 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 31 Jan 2023 12:55:53 -0600 Subject: [PATCH 16/21] Move generate functions and no more generics --- .../databasefake/fakegen/generator.go | 221 ++++++++++++++++++ coderd/database/databasefake/fakegen/take.go | 29 +++ coderd/database/databasefake/generator.go | 22 ++ 3 files changed, 272 insertions(+) create mode 100644 coderd/database/databasefake/fakegen/generator.go create mode 100644 coderd/database/databasefake/fakegen/take.go diff --git a/coderd/database/databasefake/fakegen/generator.go b/coderd/database/databasefake/fakegen/generator.go new file mode 100644 index 0000000000000..e3ec2b57ffd0d --- /dev/null +++ b/coderd/database/databasefake/fakegen/generator.go @@ -0,0 +1,221 @@ +package fakegen + +import ( + "context" + "crypto/sha256" + "database/sql" + "encoding/hex" + "testing" + "time" + + "github.com/coder/coder/cryptorand" + "github.com/tabbed/pqtype" + + "github.com/coder/coder/coderd/database" + "github.com/google/uuid" + "github.com/moby/moby/pkg/namesgenerator" + "github.com/stretchr/testify/require" +) + +func Template(t *testing.T, db database.Store, seed database.Template) database.Template { + template, err := db.InsertTemplate(context.Background(), database.InsertTemplateParams{ + ID: takeFirst(seed.ID, uuid.New()), + CreatedAt: takeFirst(seed.CreatedAt, time.Now()), + UpdatedAt: takeFirst(seed.UpdatedAt, time.Now()), + OrganizationID: takeFirst(seed.OrganizationID, uuid.New()), + Name: takeFirst(seed.Name, namesgenerator.GetRandomName(1)), + Provisioner: takeFirst(seed.Provisioner, database.ProvisionerTypeEcho), + ActiveVersionID: takeFirst(seed.ActiveVersionID, uuid.New()), + Description: takeFirst(seed.Description, namesgenerator.GetRandomName(1)), + DefaultTTL: takeFirst(seed.DefaultTTL, 3600), + CreatedBy: takeFirst(seed.CreatedBy, uuid.New()), + Icon: takeFirst(seed.Icon, namesgenerator.GetRandomName(1)), + UserACL: seed.UserACL, + GroupACL: seed.GroupACL, + DisplayName: takeFirst(seed.DisplayName, namesgenerator.GetRandomName(1)), + AllowUserCancelWorkspaceJobs: takeFirst(seed.AllowUserCancelWorkspaceJobs, true), + }) + require.NoError(t, err, "insert template") + return template +} + +func APIKey(t *testing.T, db database.Store, seed database.APIKey) (key database.APIKey, secret string) { + id, _ := cryptorand.String(10) + secret, _ = cryptorand.String(22) + hashed := sha256.Sum256([]byte(secret)) + + key, err := db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ + ID: takeFirst(seed.ID, id), + // 0 defaults to 86400 at the db layer + LifetimeSeconds: takeFirst(seed.LifetimeSeconds, 0), + HashedSecret: takeFirstBytes(seed.HashedSecret, hashed[:]), + IPAddress: pqtype.Inet{}, + UserID: takeFirst(seed.UserID, uuid.New()), + LastUsed: takeFirst(seed.LastUsed, time.Now()), + ExpiresAt: takeFirst(seed.ExpiresAt, time.Now().Add(time.Hour)), + CreatedAt: takeFirst(seed.CreatedAt, time.Now()), + UpdatedAt: takeFirst(seed.UpdatedAt, time.Now()), + LoginType: takeFirst(seed.LoginType, database.LoginTypePassword), + Scope: takeFirst(seed.Scope, database.APIKeyScopeAll), + }) + require.NoError(t, err, "insert api key") + return key, secret +} + +func Workspace(t *testing.T, db database.Store, orig database.Workspace) database.Workspace { + workspace, err := db.InsertWorkspace(context.Background(), database.InsertWorkspaceParams{ + ID: takeFirst(orig.ID, uuid.New()), + OwnerID: takeFirst(orig.OwnerID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), + TemplateID: takeFirst(orig.TemplateID, uuid.New()), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + AutostartSchedule: orig.AutostartSchedule, + Ttl: orig.Ttl, + }) + require.NoError(t, err, "insert workspace") + return workspace +} + +func WorkspaceBuild(t *testing.T, db database.Store, orig database.WorkspaceBuild) database.WorkspaceBuild { + build, err := db.InsertWorkspaceBuild(context.Background(), database.InsertWorkspaceBuildParams{ + ID: takeFirst(orig.ID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + WorkspaceID: takeFirst(orig.WorkspaceID, uuid.New()), + TemplateVersionID: takeFirst(orig.TemplateVersionID, uuid.New()), + BuildNumber: takeFirst(orig.BuildNumber, 0), + Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart), + InitiatorID: takeFirst(orig.InitiatorID, uuid.New()), + JobID: takeFirst(orig.JobID, uuid.New()), + ProvisionerState: takeFirstBytes(orig.ProvisionerState, []byte{}), + Deadline: takeFirst(orig.Deadline, time.Now().Add(time.Hour)), + Reason: takeFirst(orig.Reason, database.BuildReasonInitiator), + }) + require.NoError(t, err, "insert workspace build") + return build +} + +func User(t *testing.T, db database.Store, orig database.User) database.User { + user, err := db.InsertUser(context.Background(), database.InsertUserParams{ + ID: takeFirst(orig.ID, uuid.New()), + Email: takeFirst(orig.Email, namesgenerator.GetRandomName(1)), + Username: takeFirst(orig.Username, namesgenerator.GetRandomName(1)), + HashedPassword: takeFirstBytes(orig.HashedPassword, []byte{}), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + RBACRoles: []string{}, + LoginType: takeFirst(orig.LoginType, database.LoginTypePassword), + }) + require.NoError(t, err, "insert user") + return user +} + +func Organization(t *testing.T, db database.Store, ctx context.Context, orig database.Organization) database.Organization { + org, err := db.InsertOrganization(context.Background(), database.InsertOrganizationParams{ + ID: takeFirst(orig.ID, uuid.New()), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + Description: takeFirst(orig.Description, namesgenerator.GetRandomName(1)), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + }) + require.NoError(t, err, "insert organization") + return org +} + +func Group(t *testing.T, db database.Store, ctx context.Context, orig database.Group) database.Group { + group, err := db.InsertGroup(context.Background(), database.InsertGroupParams{ + ID: takeFirst(orig.ID, uuid.New()), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), + AvatarURL: takeFirst(orig.AvatarURL, "https://logo.example.com"), + QuotaAllowance: takeFirst(orig.QuotaAllowance, 0), + }) + require.NoError(t, err, "insert group") + return group +} + +func ProvisionerJob(t *testing.T, db database.Store, ctx context.Context, orig database.ProvisionerJob) database.ProvisionerJob { + job, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ + ID: takeFirst(orig.ID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), + InitiatorID: takeFirst(orig.InitiatorID, uuid.New()), + Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho), + StorageMethod: takeFirst(orig.StorageMethod, database.ProvisionerStorageMethodFile), + FileID: takeFirst(orig.FileID, uuid.New()), + Type: takeFirst(orig.Type, database.ProvisionerJobTypeWorkspaceBuild), + Input: takeFirstBytes(orig.Input, []byte("{}")), + Tags: orig.Tags, + }) + require.NoError(t, err, "insert job") + return job +} + +func WorkspaceResource(t *testing.T, db database.Store, ctx context.Context, orig database.WorkspaceResource) database.WorkspaceResource { + resource, err := db.InsertWorkspaceResource(context.Background(), database.InsertWorkspaceResourceParams{ + ID: takeFirst(orig.ID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + JobID: takeFirst(orig.JobID, uuid.New()), + Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart), + Type: takeFirst(orig.Type, "fake_resource"), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + Hide: takeFirst(orig.Hide, false), + Icon: takeFirst(orig.Icon, ""), + InstanceType: sql.NullString{ + String: takeFirst(orig.InstanceType.String, ""), + Valid: takeFirst(orig.InstanceType.Valid, false), + }, + DailyCost: takeFirst(orig.DailyCost, 0), + }) + require.NoError(t, err, "insert resource") + return resource +} + +func File(t *testing.T, db database.Store, ctx context.Context, orig database.File) database.File { + file, err := db.InsertFile(context.Background(), database.InsertFileParams{ + ID: takeFirst(orig.ID, uuid.New()), + Hash: takeFirst(orig.Hash, hex.EncodeToString(make([]byte, 32))), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), + Mimetype: takeFirst(orig.Mimetype, "application/x-tar"), + Data: takeFirstBytes(orig.Data, []byte{}), + }) + require.NoError(t, err, "insert file") + return file +} + +func UserLink(t *testing.T, db database.Store, ctx context.Context, orig database.UserLink) database.UserLink { + link, err := db.InsertUserLink(context.Background(), database.InsertUserLinkParams{ + UserID: takeFirst(orig.UserID, uuid.New()), + LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub), + LinkedID: takeFirst(orig.LinkedID), + OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), + OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), + OAuthExpiry: takeFirst(orig.OAuthExpiry, time.Now().Add(time.Hour*24)), + }) + + require.NoError(t, err, "insert link") + return link +} + +func TemplateVersion(t *testing.T, db database.Store, orig database.TemplateVersion) database.TemplateVersion { + version, err := db.InsertTemplateVersion(context.Background(), database.InsertTemplateVersionParams{ + ID: takeFirst(orig.ID, uuid.New()), + TemplateID: uuid.NullUUID{ + UUID: takeFirst(orig.TemplateID.UUID, uuid.New()), + Valid: takeFirst(orig.TemplateID.Valid, true), + }, + OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, time.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), + Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), + Readme: takeFirst(orig.Readme, namesgenerator.GetRandomName(1)), + JobID: takeFirst(orig.JobID, uuid.New()), + CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), + }) + require.NoError(t, err, "insert template version") + return version +} diff --git a/coderd/database/databasefake/fakegen/take.go b/coderd/database/databasefake/fakegen/take.go new file mode 100644 index 0000000000000..317a618be36a7 --- /dev/null +++ b/coderd/database/databasefake/fakegen/take.go @@ -0,0 +1,29 @@ +package fakegen + +// takeFirstBytes implements takeFirst for []byte. +// []byte is not a comparable type. +func takeFirstBytes(values ...[]byte) []byte { + return takeFirstF(values, func(v []byte) bool { + return len(v) != 0 + }) +} + +// takeFirstF takes the first value that returns true +func takeFirstF[Value any](values []Value, take func(v Value) bool) Value { + var empty Value + for _, v := range values { + if take(v) { + return v + } + } + // If all empty, return empty + return empty +} + +// takeFirst will take the first non-empty value. +func takeFirst[Value comparable](values ...Value) Value { + var empty Value + return takeFirstF(values, func(v Value) bool { + return v != empty + }) +} diff --git a/coderd/database/databasefake/generator.go b/coderd/database/databasefake/generator.go index 2c3b6eda0045b..ef850f57f63b0 100644 --- a/coderd/database/databasefake/generator.go +++ b/coderd/database/databasefake/generator.go @@ -60,6 +60,28 @@ func Generate[Object Supported](t *testing.T, db database.Store, seed Object) Ob return v } +func Template(t *testing.T, db database.Store, seed database.Template) database.Template { + template, err := db.InsertTemplate(context.Background(), database.InsertTemplateParams{ + ID: takeFirst(seed.ID, uuid.New()), + CreatedAt: takeFirst(seed.CreatedAt, time.Now()), + UpdatedAt: takeFirst(seed.UpdatedAt, time.Now()), + OrganizationID: takeFirst(seed.OrganizationID, uuid.New()), + Name: takeFirst(seed.Name, namesgenerator.GetRandomName(1)), + Provisioner: takeFirst(seed.Provisioner, database.ProvisionerTypeEcho), + ActiveVersionID: takeFirst(seed.ActiveVersionID, uuid.New()), + Description: takeFirst(seed.Description, namesgenerator.GetRandomName(1)), + DefaultTTL: takeFirst(seed.DefaultTTL, 3600), + CreatedBy: takeFirst(seed.CreatedBy, uuid.New()), + Icon: takeFirst(seed.Icon, namesgenerator.GetRandomName(1)), + UserACL: seed.UserACL, + GroupACL: seed.GroupACL, + DisplayName: takeFirst(seed.DisplayName, namesgenerator.GetRandomName(1)), + AllowUserCancelWorkspaceJobs: takeFirst(seed.AllowUserCancelWorkspaceJobs, true), + }) + require.NoError(t, err, "insert template") + return template +} + func generate(t *testing.T, db database.Store, seed interface{}) interface{} { t.Helper() From 5cd9acd5f950fde104726ee9ada99416939c33a0 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 31 Jan 2023 13:02:36 -0600 Subject: [PATCH 17/21] Move to database gen package --- coderd/database/databasefake/generator.go | 333 ------------------ .../fakegen => databasegen}/generator.go | 14 +- .../generator_test.go | 28 +- .../fakegen => databasegen}/take.go | 2 +- 4 files changed, 23 insertions(+), 354 deletions(-) delete mode 100644 coderd/database/databasefake/generator.go rename coderd/database/{databasefake/fakegen => databasegen}/generator.go (93%) rename coderd/database/{databasefake => databasegen}/generator_test.go (74%) rename coderd/database/{databasefake/fakegen => databasegen}/take.go (97%) diff --git a/coderd/database/databasefake/generator.go b/coderd/database/databasefake/generator.go deleted file mode 100644 index ef850f57f63b0..0000000000000 --- a/coderd/database/databasefake/generator.go +++ /dev/null @@ -1,333 +0,0 @@ -package databasefake - -import ( - "context" - "crypto/sha256" - "database/sql" - "encoding/hex" - "testing" - "time" - - "github.com/google/uuid" - "github.com/moby/moby/pkg/namesgenerator" - "github.com/stretchr/testify/require" - "github.com/tabbed/pqtype" - - "github.com/coder/coder/coderd/database" - "github.com/coder/coder/cryptorand" -) - -type Supported interface { - database.APIKey | generatedAPIKey | - database.File | - database.UserLink | - database.WorkspaceResource | - database.ProvisionerJob | - database.Group | - database.Organization | - database.Workspace | - database.Template | - database.TemplateVersion | - database.WorkspaceBuild | - database.User -} - -type generatedAPIKey struct { - Secret string - Key database.APIKey -} - -// GenerateAPIKey is a special case that allows returning the secret for the -// api key. -func GenerateAPIKey(t *testing.T, db database.Store, seed database.APIKey) (key database.APIKey, secret string) { - out := generate(t, db, generatedAPIKey{ - Key: seed, - }) - v, ok := out.(generatedAPIKey) - if !ok { - t.Fatalf("Returned type '%T' doses not match expected '%T'", out, generatedAPIKey{}) - } - return v.Key, v.Secret -} - -func Generate[Object Supported](t *testing.T, db database.Store, seed Object) Object { - out := generate(t, db, seed) - v, ok := out.(Object) - if !ok { - var empty Object - t.Fatalf("Returned type '%T' doses not match expected '%T'", out, empty) - } - return v -} - -func Template(t *testing.T, db database.Store, seed database.Template) database.Template { - template, err := db.InsertTemplate(context.Background(), database.InsertTemplateParams{ - ID: takeFirst(seed.ID, uuid.New()), - CreatedAt: takeFirst(seed.CreatedAt, time.Now()), - UpdatedAt: takeFirst(seed.UpdatedAt, time.Now()), - OrganizationID: takeFirst(seed.OrganizationID, uuid.New()), - Name: takeFirst(seed.Name, namesgenerator.GetRandomName(1)), - Provisioner: takeFirst(seed.Provisioner, database.ProvisionerTypeEcho), - ActiveVersionID: takeFirst(seed.ActiveVersionID, uuid.New()), - Description: takeFirst(seed.Description, namesgenerator.GetRandomName(1)), - DefaultTTL: takeFirst(seed.DefaultTTL, 3600), - CreatedBy: takeFirst(seed.CreatedBy, uuid.New()), - Icon: takeFirst(seed.Icon, namesgenerator.GetRandomName(1)), - UserACL: seed.UserACL, - GroupACL: seed.GroupACL, - DisplayName: takeFirst(seed.DisplayName, namesgenerator.GetRandomName(1)), - AllowUserCancelWorkspaceJobs: takeFirst(seed.AllowUserCancelWorkspaceJobs, true), - }) - require.NoError(t, err, "insert template") - return template -} - -func generate(t *testing.T, db database.Store, seed interface{}) interface{} { - t.Helper() - - if _, ok := db.(FakeDatabase); !ok { - // This does not work for postgres databases because of foreign key - // constraints - t.Fatalf("Generate() db must be a FakeDatabase") - } - - // db fake doesn't use contexts anyway. - ctx := context.Background() - - switch orig := seed.(type) { - case database.APIKey, generatedAPIKey: - // Annoying, but we need a way to return the secret if - // the caller needs it. - var g generatedAPIKey - v, isKey := seed.(database.APIKey) - if isKey { - g = generatedAPIKey{ - Key: v, - } - } else { - var ok bool - g, ok = seed.(generatedAPIKey) - if !ok { - t.Fatalf("type '%T' unsupported", seed) - } - } - - id, _ := cryptorand.String(10) - secret, _ := cryptorand.String(22) - hashed := sha256.Sum256([]byte(secret)) - - key, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ - ID: takeFirst(g.Key.ID, id), - // 0 defaults to 86400 at the db layer - LifetimeSeconds: takeFirst(g.Key.LifetimeSeconds, 0), - HashedSecret: takeFirstBytes(g.Key.HashedSecret, hashed[:]), - IPAddress: pqtype.Inet{}, - UserID: takeFirst(g.Key.UserID, uuid.New()), - LastUsed: takeFirst(g.Key.LastUsed, time.Now()), - ExpiresAt: takeFirst(g.Key.ExpiresAt, time.Now().Add(time.Hour)), - CreatedAt: takeFirst(g.Key.CreatedAt, time.Now()), - UpdatedAt: takeFirst(g.Key.UpdatedAt, time.Now()), - LoginType: takeFirst(g.Key.LoginType, database.LoginTypePassword), - Scope: takeFirst(g.Key.Scope, database.APIKeyScopeAll), - }) - require.NoError(t, err, "insert api key") - g.Key = key - g.Secret = secret - - if isKey { - return g.Key - } - return g - case database.Template: - template, err := db.InsertTemplate(ctx, database.InsertTemplateParams{ - ID: takeFirst(orig.ID, uuid.New()), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), - OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), - Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), - Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho), - ActiveVersionID: takeFirst(orig.ActiveVersionID, uuid.New()), - Description: takeFirst(orig.Description, namesgenerator.GetRandomName(1)), - DefaultTTL: takeFirst(orig.DefaultTTL, 3600), - CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), - Icon: takeFirst(orig.Icon, namesgenerator.GetRandomName(1)), - UserACL: orig.UserACL, - GroupACL: orig.GroupACL, - DisplayName: takeFirst(orig.DisplayName, namesgenerator.GetRandomName(1)), - AllowUserCancelWorkspaceJobs: takeFirst(orig.AllowUserCancelWorkspaceJobs, true), - }) - require.NoError(t, err, "insert template") - return template - case database.TemplateVersion: - version, err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ - ID: takeFirst(orig.ID, uuid.New()), - TemplateID: uuid.NullUUID{ - UUID: takeFirst(orig.TemplateID.UUID, uuid.New()), - Valid: takeFirst(orig.TemplateID.Valid, true), - }, - OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), - Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), - Readme: takeFirst(orig.Readme, namesgenerator.GetRandomName(1)), - JobID: takeFirst(orig.JobID, uuid.New()), - CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), - }) - require.NoError(t, err, "insert template version") - return version - case database.Workspace: - workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{ - ID: takeFirst(orig.ID, uuid.New()), - OwnerID: takeFirst(orig.OwnerID, uuid.New()), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), - OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), - TemplateID: takeFirst(orig.TemplateID, uuid.New()), - Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), - AutostartSchedule: orig.AutostartSchedule, - Ttl: orig.Ttl, - }) - require.NoError(t, err, "insert workspace") - return workspace - case database.WorkspaceBuild: - build, err := db.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ - ID: takeFirst(orig.ID, uuid.New()), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), - WorkspaceID: takeFirst(orig.WorkspaceID, uuid.New()), - TemplateVersionID: takeFirst(orig.TemplateVersionID, uuid.New()), - BuildNumber: takeFirst(orig.BuildNumber, 0), - Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart), - InitiatorID: takeFirst(orig.InitiatorID, uuid.New()), - JobID: takeFirst(orig.JobID, uuid.New()), - ProvisionerState: takeFirstBytes(orig.ProvisionerState, []byte{}), - Deadline: takeFirst(orig.Deadline, time.Now().Add(time.Hour)), - Reason: takeFirst(orig.Reason, database.BuildReasonInitiator), - }) - require.NoError(t, err, "insert workspace build") - return build - case database.User: - user, err := db.InsertUser(ctx, database.InsertUserParams{ - ID: takeFirst(orig.ID, uuid.New()), - Email: takeFirst(orig.Email, namesgenerator.GetRandomName(1)), - Username: takeFirst(orig.Username, namesgenerator.GetRandomName(1)), - HashedPassword: takeFirstBytes(orig.HashedPassword, []byte{}), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), - RBACRoles: []string{}, - LoginType: takeFirst(orig.LoginType, database.LoginTypePassword), - }) - require.NoError(t, err, "insert user") - return user - case database.Organization: - org, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{ - ID: takeFirst(orig.ID, uuid.New()), - Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), - Description: takeFirst(orig.Description, namesgenerator.GetRandomName(1)), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), - }) - require.NoError(t, err, "insert organization") - return org - case database.Group: - group, err := db.InsertGroup(ctx, database.InsertGroupParams{ - ID: takeFirst(orig.ID, uuid.New()), - Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), - OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), - AvatarURL: takeFirst(orig.AvatarURL, "https://logo.example.com"), - QuotaAllowance: takeFirst(orig.QuotaAllowance, 0), - }) - require.NoError(t, err, "insert group") - return group - case database.ProvisionerJob: - job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ - ID: takeFirst(orig.ID, uuid.New()), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()), - OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), - InitiatorID: takeFirst(orig.InitiatorID, uuid.New()), - Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho), - StorageMethod: takeFirst(orig.StorageMethod, database.ProvisionerStorageMethodFile), - FileID: takeFirst(orig.FileID, uuid.New()), - Type: takeFirst(orig.Type, database.ProvisionerJobTypeWorkspaceBuild), - Input: takeFirstBytes(orig.Input, []byte("{}")), - Tags: orig.Tags, - }) - require.NoError(t, err, "insert job") - return job - case database.WorkspaceResource: - resource, err := db.InsertWorkspaceResource(ctx, database.InsertWorkspaceResourceParams{ - ID: takeFirst(orig.ID, uuid.New()), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - JobID: takeFirst(orig.JobID, uuid.New()), - Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart), - Type: takeFirst(orig.Type, "fake_resource"), - Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), - Hide: takeFirst(orig.Hide, false), - Icon: takeFirst(orig.Icon, ""), - InstanceType: sql.NullString{ - String: takeFirst(orig.InstanceType.String, ""), - Valid: takeFirst(orig.InstanceType.Valid, false), - }, - DailyCost: takeFirst(orig.DailyCost, 0), - }) - require.NoError(t, err, "insert resource") - return resource - case database.File: - file, err := db.InsertFile(ctx, database.InsertFileParams{ - ID: takeFirst(orig.ID, uuid.New()), - Hash: takeFirst(orig.Hash, hex.EncodeToString(make([]byte, 32))), - CreatedAt: takeFirst(orig.CreatedAt, time.Now()), - CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), - Mimetype: takeFirst(orig.Mimetype, "application/x-tar"), - Data: takeFirstBytes(orig.Data, []byte{}), - }) - require.NoError(t, err, "insert file") - return file - case database.UserLink: - link, err := db.InsertUserLink(ctx, database.InsertUserLinkParams{ - UserID: takeFirst(orig.UserID, uuid.New()), - LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub), - LinkedID: takeFirst(orig.LinkedID), - OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), - OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), - OAuthExpiry: takeFirst(orig.OAuthExpiry, time.Now().Add(time.Hour*24)), - }) - - require.NoError(t, err, "insert link") - return link - default: - // If you hit this, just add your type to the switch. - t.Fatalf("unknown type '%T' used in fake data generator", orig) - // This line will never be hit, but the compiler does not know that :/ - return nil - } -} - -// takeFirstBytes implements takeFirst for []byte. -// []byte is not a comparable type. -func takeFirstBytes(values ...[]byte) []byte { - return takeFirstF(values, func(v []byte) bool { - return len(v) != 0 - }) -} - -// takeFirstF takes the first value that returns true -func takeFirstF[Value any](values []Value, take func(v Value) bool) Value { - var empty Value - for _, v := range values { - if take(v) { - return v - } - } - // If all empty, return empty - return empty -} - -// takeFirst will take the first non-empty value. -func takeFirst[Value comparable](values ...Value) Value { - var empty Value - return takeFirstF(values, func(v Value) bool { - return v != empty - }) -} diff --git a/coderd/database/databasefake/fakegen/generator.go b/coderd/database/databasegen/generator.go similarity index 93% rename from coderd/database/databasefake/fakegen/generator.go rename to coderd/database/databasegen/generator.go index e3ec2b57ffd0d..a504728def9ef 100644 --- a/coderd/database/databasefake/fakegen/generator.go +++ b/coderd/database/databasegen/generator.go @@ -1,4 +1,4 @@ -package fakegen +package databasegen import ( "context" @@ -112,7 +112,7 @@ func User(t *testing.T, db database.Store, orig database.User) database.User { return user } -func Organization(t *testing.T, db database.Store, ctx context.Context, orig database.Organization) database.Organization { +func Organization(t *testing.T, db database.Store, orig database.Organization) database.Organization { org, err := db.InsertOrganization(context.Background(), database.InsertOrganizationParams{ ID: takeFirst(orig.ID, uuid.New()), Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), @@ -124,7 +124,7 @@ func Organization(t *testing.T, db database.Store, ctx context.Context, orig dat return org } -func Group(t *testing.T, db database.Store, ctx context.Context, orig database.Group) database.Group { +func Group(t *testing.T, db database.Store, orig database.Group) database.Group { group, err := db.InsertGroup(context.Background(), database.InsertGroupParams{ ID: takeFirst(orig.ID, uuid.New()), Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)), @@ -136,7 +136,7 @@ func Group(t *testing.T, db database.Store, ctx context.Context, orig database.G return group } -func ProvisionerJob(t *testing.T, db database.Store, ctx context.Context, orig database.ProvisionerJob) database.ProvisionerJob { +func ProvisionerJob(t *testing.T, db database.Store, orig database.ProvisionerJob) database.ProvisionerJob { job, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ ID: takeFirst(orig.ID, uuid.New()), CreatedAt: takeFirst(orig.CreatedAt, time.Now()), @@ -154,7 +154,7 @@ func ProvisionerJob(t *testing.T, db database.Store, ctx context.Context, orig d return job } -func WorkspaceResource(t *testing.T, db database.Store, ctx context.Context, orig database.WorkspaceResource) database.WorkspaceResource { +func WorkspaceResource(t *testing.T, db database.Store, orig database.WorkspaceResource) database.WorkspaceResource { resource, err := db.InsertWorkspaceResource(context.Background(), database.InsertWorkspaceResourceParams{ ID: takeFirst(orig.ID, uuid.New()), CreatedAt: takeFirst(orig.CreatedAt, time.Now()), @@ -174,7 +174,7 @@ func WorkspaceResource(t *testing.T, db database.Store, ctx context.Context, ori return resource } -func File(t *testing.T, db database.Store, ctx context.Context, orig database.File) database.File { +func File(t *testing.T, db database.Store, orig database.File) database.File { file, err := db.InsertFile(context.Background(), database.InsertFileParams{ ID: takeFirst(orig.ID, uuid.New()), Hash: takeFirst(orig.Hash, hex.EncodeToString(make([]byte, 32))), @@ -187,7 +187,7 @@ func File(t *testing.T, db database.Store, ctx context.Context, orig database.Fi return file } -func UserLink(t *testing.T, db database.Store, ctx context.Context, orig database.UserLink) database.UserLink { +func UserLink(t *testing.T, db database.Store, orig database.UserLink) database.UserLink { link, err := db.InsertUserLink(context.Background(), database.InsertUserLinkParams{ UserID: takeFirst(orig.UserID, uuid.New()), LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub), diff --git a/coderd/database/databasefake/generator_test.go b/coderd/database/databasegen/generator_test.go similarity index 74% rename from coderd/database/databasefake/generator_test.go rename to coderd/database/databasegen/generator_test.go index 05de4b560a999..1aebf6e2fa4e0 100644 --- a/coderd/database/databasefake/generator_test.go +++ b/coderd/database/databasegen/generator_test.go @@ -1,9 +1,11 @@ -package databasefake_test +package databasegen_test import ( "context" "testing" + "github.com/coder/coder/coderd/database/databasegen" + "github.com/stretchr/testify/require" "github.com/coder/coder/coderd/database" @@ -16,84 +18,84 @@ func TestGenerator(t *testing.T) { t.Run("APIKey", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp, _ := databasefake.GenerateAPIKey(t, db, database.APIKey{}) + exp, _ := databasegen.APIKey(t, db, database.APIKey{}) require.Equal(t, exp, must(db.GetAPIKeyByID(context.Background(), exp.ID))) }) t.Run("File", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasefake.Generate(t, db, database.File{}) + exp := databasegen.File(t, db, database.File{}) require.Equal(t, exp, must(db.GetFileByID(context.Background(), exp.ID))) }) t.Run("UserLink", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasefake.Generate(t, db, database.UserLink{}) + exp := databasegen.UserLink(t, db, database.UserLink{}) require.Equal(t, exp, must(db.GetUserLinkByLinkedID(context.Background(), exp.LinkedID))) }) t.Run("WorkspaceResource", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasefake.Generate(t, db, database.WorkspaceResource{}) + exp := databasegen.WorkspaceResource(t, db, database.WorkspaceResource{}) require.Equal(t, exp, must(db.GetWorkspaceResourceByID(context.Background(), exp.ID))) }) t.Run("Job", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasefake.Generate(t, db, database.ProvisionerJob{}) + exp := databasegen.ProvisionerJob(t, db, database.ProvisionerJob{}) require.Equal(t, exp, must(db.GetProvisionerJobByID(context.Background(), exp.ID))) }) t.Run("Group", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasefake.Generate(t, db, database.Group{}) + exp := databasegen.Group(t, db, database.Group{}) require.Equal(t, exp, must(db.GetGroupByID(context.Background(), exp.ID))) }) t.Run("Organization", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasefake.Generate(t, db, database.Organization{}) + exp := databasegen.Organization(t, db, database.Organization{}) require.Equal(t, exp, must(db.GetOrganizationByID(context.Background(), exp.ID))) }) t.Run("Workspace", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasefake.Generate(t, db, database.Workspace{}) + exp := databasegen.Workspace(t, db, database.Workspace{}) require.Equal(t, exp, must(db.GetWorkspaceByID(context.Background(), exp.ID))) }) t.Run("Template", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasefake.Generate(t, db, database.Template{}) + exp := databasegen.Template(t, db, database.Template{}) require.Equal(t, exp, must(db.GetTemplateByID(context.Background(), exp.ID))) }) t.Run("TemplateVersion", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasefake.Generate(t, db, database.TemplateVersion{}) + exp := databasegen.TemplateVersion(t, db, database.TemplateVersion{}) require.Equal(t, exp, must(db.GetTemplateVersionByID(context.Background(), exp.ID))) }) t.Run("WorkspaceBuild", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasefake.Generate(t, db, database.WorkspaceBuild{}) + exp := databasegen.WorkspaceBuild(t, db, database.WorkspaceBuild{}) require.Equal(t, exp, must(db.GetWorkspaceBuildByID(context.Background(), exp.ID))) }) t.Run("User", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasefake.Generate(t, db, database.User{}) + exp := databasegen.User(t, db, database.User{}) require.Equal(t, exp, must(db.GetUserByID(context.Background(), exp.ID))) }) } diff --git a/coderd/database/databasefake/fakegen/take.go b/coderd/database/databasegen/take.go similarity index 97% rename from coderd/database/databasefake/fakegen/take.go rename to coderd/database/databasegen/take.go index 317a618be36a7..4bd1dec9feb43 100644 --- a/coderd/database/databasefake/fakegen/take.go +++ b/coderd/database/databasegen/take.go @@ -1,4 +1,4 @@ -package fakegen +package databasegen // takeFirstBytes implements takeFirst for []byte. // []byte is not a comparable type. From e7d2b5f453ebefcaa8d84131a7b12ce7d87e188a Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 31 Jan 2023 13:09:10 -0600 Subject: [PATCH 18/21] Fix all refactored tests to use new methods --- coderd/httpmw/apikey_test.go | 77 +++++++------------ coderd/httpmw/groupparam_test.go | 29 +++---- coderd/httpmw/ratelimit_test.go | 17 ++-- coderd/httpmw/workspaceresourceparam_test.go | 9 ++- .../prometheusmetrics_test.go | 15 ++-- .../provisionerdserver_test.go | 37 +++++---- coderd/workspaceapps_internal_test.go | 8 +- 7 files changed, 79 insertions(+), 113 deletions(-) diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index 656cd34f540d3..dcfa568a726c2 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -11,6 +11,8 @@ import ( "testing" "time" + "github.com/coder/coder/coderd/database/databasegen" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -150,17 +152,14 @@ func TestAPIKey(t *testing.T) { t.Run("InvalidSecret", func(t *testing.T) { t.Parallel() var ( - db = databasefake.New() - gen = databasefake.NewGenerator(t, db) - r = httptest.NewRequest("GET", "/", nil) - rw = httptest.NewRecorder() - - ctx = context.Background() - user = gen.User(ctx, database.User{}) + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + user = databasegen.User(t, db, database.User{}) // Use a different secret so they don't match! hashed = sha256.Sum256([]byte("differentsecret")) - _, token = gen.APIKey(ctx, database.APIKey{ + _, token = databasegen.APIKey(t, db, database.APIKey{ UserID: user.ID, HashedSecret: hashed[:], }) @@ -179,10 +178,9 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - gen = databasefake.NewGenerator(t, db) ctx = context.Background() - user = gen.User(ctx, database.User{}) - _, token = gen.APIKey(ctx, database.APIKey{ + user = databasegen.User(t, db, database.User{}) + _, token = databasegen.APIKey(t, db, database.APIKey{ UserID: user.ID, ExpiresAt: time.Now().Add(time.Hour * -1), }) @@ -205,10 +203,8 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - gen = databasefake.NewGenerator(t, db) - ctx = context.Background() - user = gen.User(ctx, database.User{}) - sentAPIKey, token = gen.APIKey(ctx, database.APIKey{ + user = databasegen.User(t, db, database.User{}) + sentAPIKey, token = databasegen.APIKey(t, db, database.APIKey{ UserID: user.ID, ExpiresAt: database.Now().AddDate(0, 0, 1), }) @@ -242,10 +238,8 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - gen = databasefake.NewGenerator(t, db) - ctx = context.Background() - user = gen.User(ctx, database.User{}) - _, token = gen.APIKey(ctx, database.APIKey{ + user = databasegen.User(t, db, database.User{}) + _, token = databasegen.APIKey(t, db, database.APIKey{ UserID: user.ID, ExpiresAt: database.Now().AddDate(0, 0, 1), Scope: database.APIKeyScopeApplicationConnect, @@ -281,10 +275,8 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - gen = databasefake.NewGenerator(t, db) - ctx = context.Background() - user = gen.User(ctx, database.User{}) - _, token = gen.APIKey(ctx, database.APIKey{ + user = databasegen.User(t, db, database.User{}) + _, token = databasegen.APIKey(t, db, database.APIKey{ UserID: user.ID, ExpiresAt: database.Now().AddDate(0, 0, 1), }) @@ -315,10 +307,8 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - gen = databasefake.NewGenerator(t, db) - ctx = context.Background() - user = gen.User(ctx, database.User{}) - sentAPIKey, token = gen.APIKey(ctx, database.APIKey{ + user = databasegen.User(t, db, database.User{}) + sentAPIKey, token = databasegen.APIKey(t, db, database.APIKey{ UserID: user.ID, LastUsed: database.Now().AddDate(0, 0, -1), ExpiresAt: database.Now().AddDate(0, 0, 1), @@ -348,10 +338,8 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - gen = databasefake.NewGenerator(t, db) - ctx = context.Background() - user = gen.User(ctx, database.User{}) - sentAPIKey, token = gen.APIKey(ctx, database.APIKey{ + user = databasegen.User(t, db, database.User{}) + sentAPIKey, token = databasegen.APIKey(t, db, database.APIKey{ UserID: user.ID, LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), @@ -381,16 +369,14 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - gen = databasefake.NewGenerator(t, db) - ctx = context.Background() - user = gen.User(ctx, database.User{}) - sentAPIKey, token = gen.APIKey(ctx, database.APIKey{ + user = databasegen.User(t, db, database.User{}) + sentAPIKey, token = databasegen.APIKey(t, db, database.APIKey{ UserID: user.ID, LastUsed: database.Now(), ExpiresAt: database.Now().AddDate(0, 0, 1), LoginType: database.LoginTypeGithub, }) - _ = gen.UserLink(ctx, database.UserLink{ + _ = databasegen.UserLink(t, db, database.UserLink{ UserID: user.ID, LoginType: database.LoginTypeGithub, }) @@ -419,16 +405,15 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - gen = databasefake.NewGenerator(t, db) ctx = context.Background() - user = gen.User(ctx, database.User{}) - sentAPIKey, token = gen.APIKey(ctx, database.APIKey{ + user = databasegen.User(t, db, database.User{}) + sentAPIKey, token = databasegen.APIKey(t, db, database.APIKey{ UserID: user.ID, LastUsed: database.Now(), ExpiresAt: database.Now().AddDate(0, 0, 1), LoginType: database.LoginTypeGithub, }) - _ = gen.UserLink(ctx, database.UserLink{ + _ = databasegen.UserLink(t, db, database.UserLink{ UserID: user.ID, LoginType: database.LoginTypeGithub, OAuthRefreshToken: "hello", @@ -471,10 +456,8 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - gen = databasefake.NewGenerator(t, db) - ctx = context.Background() - user = gen.User(ctx, database.User{}) - sentAPIKey, token = gen.APIKey(ctx, database.APIKey{ + user = databasegen.User(t, db, database.User{}) + sentAPIKey, token = databasegen.APIKey(t, db, database.APIKey{ UserID: user.ID, LastUsed: database.Now().AddDate(0, 0, -1), ExpiresAt: database.Now().AddDate(0, 0, 1), @@ -556,10 +539,8 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - gen = databasefake.NewGenerator(t, db) - ctx = context.Background() - user = gen.User(ctx, database.User{}) - sentAPIKey, token = gen.APIKey(ctx, database.APIKey{ + user = databasegen.User(t, db, database.User{}) + sentAPIKey, token = databasegen.APIKey(t, db, database.APIKey{ UserID: user.ID, LastUsed: database.Now(), ExpiresAt: database.Now().AddDate(0, 0, 1), diff --git a/coderd/httpmw/groupparam_test.go b/coderd/httpmw/groupparam_test.go index b726dfd8e45c0..e77786605935d 100644 --- a/coderd/httpmw/groupparam_test.go +++ b/coderd/httpmw/groupparam_test.go @@ -6,6 +6,8 @@ import ( "net/http/httptest" "testing" + "github.com/coder/coder/coderd/database/databasegen" + "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/stretchr/testify/require" @@ -13,31 +15,19 @@ import ( "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" "github.com/coder/coder/coderd/httpmw" - "github.com/coder/coder/testutil" ) func TestGroupParam(t *testing.T) { t.Parallel() - setup := func(t *testing.T) (database.Store, database.Group) { - t.Helper() - - ctx, _ := testutil.Context(t) - db := databasefake.New() - gen := databasefake.NewGenerator(t, db) - - group := gen.Group(ctx, database.Group{}) - - return db, group - } - t.Run("OK", func(t *testing.T) { t.Parallel() var ( - db, group = setup(t) - r = httptest.NewRequest("GET", "/", nil) - w = httptest.NewRecorder() + db = databasefake.New() + group = databasegen.Group(t, db, database.Group{}) + r = httptest.NewRequest("GET", "/", nil) + w = httptest.NewRecorder() ) router := chi.NewRouter() @@ -63,9 +53,10 @@ func TestGroupParam(t *testing.T) { t.Parallel() var ( - db, group = setup(t) - r = httptest.NewRequest("GET", "/", nil) - w = httptest.NewRecorder() + db = databasefake.New() + group = databasegen.Group(t, db, database.Group{}) + r = httptest.NewRequest("GET", "/", nil) + w = httptest.NewRecorder() ) router := chi.NewRouter() diff --git a/coderd/httpmw/ratelimit_test.go b/coderd/httpmw/ratelimit_test.go index ec8b30ad03513..4e22065e7408b 100644 --- a/coderd/httpmw/ratelimit_test.go +++ b/coderd/httpmw/ratelimit_test.go @@ -1,7 +1,6 @@ package httpmw_test import ( - "context" "fmt" "math/rand" "net" @@ -10,6 +9,8 @@ import ( "testing" "time" + "github.com/coder/coder/coderd/database/databasegen" + "github.com/go-chi/chi/v5" "github.com/stretchr/testify/require" @@ -71,12 +72,9 @@ func TestRateLimit(t *testing.T) { t.Run("RegularUser", func(t *testing.T) { t.Parallel() - ctx := context.Background() - db := databasefake.New() - gen := databasefake.NewGenerator(t, db) - u := gen.User(ctx, database.User{}) - _, key := gen.APIKey(ctx, database.APIKey{UserID: u.ID}) + u := databasegen.User(t, db, database.User{}) + _, key := databasegen.APIKey(t, db, database.APIKey{UserID: u.ID}) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ @@ -117,15 +115,12 @@ func TestRateLimit(t *testing.T) { t.Run("OwnerBypass", func(t *testing.T) { t.Parallel() - ctx := context.Background() - db := databasefake.New() - gen := databasefake.NewGenerator(t, db) - u := gen.User(ctx, database.User{ + u := databasegen.User(t, db, database.User{ RBACRoles: []string{rbac.RoleOwner()}, }) - _, key := gen.APIKey(ctx, database.APIKey{UserID: u.ID}) + _, key := databasegen.APIKey(t, db, database.APIKey{UserID: u.ID}) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ diff --git a/coderd/httpmw/workspaceresourceparam_test.go b/coderd/httpmw/workspaceresourceparam_test.go index a858ea6497a56..ddde1bfb20b05 100644 --- a/coderd/httpmw/workspaceresourceparam_test.go +++ b/coderd/httpmw/workspaceresourceparam_test.go @@ -6,6 +6,8 @@ import ( "net/http/httptest" "testing" + "github.com/coder/coder/coderd/database/databasegen" + "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/stretchr/testify/require" @@ -21,20 +23,19 @@ func TestWorkspaceResourceParam(t *testing.T) { setup := func(t *testing.T, db database.Store, jobType database.ProvisionerJobType) (*http.Request, database.WorkspaceResource) { r := httptest.NewRequest("GET", "/", nil) ctx := context.Background() - gen := databasefake.NewGenerator(t, db) - job := gen.Job(ctx, database.ProvisionerJob{ + job := databasegen.ProvisionerJob(t, db, database.ProvisionerJob{ Type: jobType, Provisioner: database.ProvisionerTypeEcho, StorageMethod: database.ProvisionerStorageMethodFile, }) - build := gen.WorkspaceBuild(ctx, database.WorkspaceBuild{ + build := databasegen.WorkspaceBuild(t, db, database.WorkspaceBuild{ JobID: job.ID, Transition: database.WorkspaceTransitionStart, Reason: database.BuildReasonInitiator, }) - resource := gen.WorkspaceResource(ctx, database.WorkspaceResource{ + resource := databasegen.WorkspaceResource(t, db, database.WorkspaceResource{ JobID: job.ID, Transition: database.WorkspaceTransitionStart, }) diff --git a/coderd/prometheusmetrics/prometheusmetrics_test.go b/coderd/prometheusmetrics/prometheusmetrics_test.go index 74c382ad22986..675b9d6fc2798 100644 --- a/coderd/prometheusmetrics/prometheusmetrics_test.go +++ b/coderd/prometheusmetrics/prometheusmetrics_test.go @@ -6,6 +6,8 @@ import ( "testing" "time" + "github.com/coder/coder/coderd/database/databasegen" + "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" @@ -35,8 +37,7 @@ func TestActiveUsers(t *testing.T) { Name: "One", Database: func(t *testing.T) database.Store { db := databasefake.New() - gen := databasefake.NewGenerator(t, db) - gen.APIKey(context.Background(), database.APIKey{ + databasegen.APIKey(t, db, database.APIKey{ LastUsed: database.Now(), }) return db @@ -46,15 +47,14 @@ func TestActiveUsers(t *testing.T) { Name: "OneWithExpired", Database: func(t *testing.T) database.Store { db := databasefake.New() - gen := databasefake.NewGenerator(t, db) - gen.APIKey(context.Background(), database.APIKey{ + databasegen.APIKey(t, db, database.APIKey{ LastUsed: database.Now(), }) // Because this API key hasn't been used in the past hour, this shouldn't // add to the user count. - gen.APIKey(context.Background(), database.APIKey{ + databasegen.APIKey(t, db, database.APIKey{ LastUsed: database.Now().Add(-2 * time.Hour), }) return db @@ -64,11 +64,10 @@ func TestActiveUsers(t *testing.T) { Name: "Multiple", Database: func(t *testing.T) database.Store { db := databasefake.New() - gen := databasefake.NewGenerator(t, db) - gen.APIKey(context.Background(), database.APIKey{ + databasegen.APIKey(t, db, database.APIKey{ LastUsed: database.Now(), }) - gen.APIKey(context.Background(), database.APIKey{ + databasegen.APIKey(t, db, database.APIKey{ LastUsed: database.Now(), }) return db diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index f840a051c5e29..c98880241d0ee 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -9,6 +9,8 @@ import ( "testing" "time" + "github.com/coder/coder/coderd/database/databasegen" + "github.com/google/uuid" "github.com/stretchr/testify/require" @@ -88,15 +90,14 @@ func TestAcquireJob(t *testing.T) { srv := setup(t, false) ctx := context.Background() - gen := databasefake.NewGenerator(t, srv.Database) - user := gen.User(ctx, database.User{}) - template := gen.Template(ctx, database.Template{ + user := databasegen.User(t, srv.Database, database.User{}) + template := databasegen.Template(t, srv.Database, database.Template{ Name: "template", Provisioner: database.ProvisionerTypeEcho, }) - file := gen.File(ctx, database.File{CreatedBy: user.ID}) - versionFile := gen.File(ctx, database.File{CreatedBy: user.ID}) - version := gen.TemplateVersion(ctx, database.TemplateVersion{ + file := databasegen.File(t, srv.Database, database.File{CreatedBy: user.ID}) + versionFile := databasegen.File(t, srv.Database, database.File{CreatedBy: user.ID}) + version := databasegen.TemplateVersion(t, srv.Database, database.TemplateVersion{ TemplateID: uuid.NullUUID{ UUID: template.ID, Valid: true, @@ -104,7 +105,7 @@ func TestAcquireJob(t *testing.T) { JobID: uuid.New(), }) // Import version job - _ = gen.Job(ctx, database.ProvisionerJob{ + _ = databasegen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ ID: version.JobID, InitiatorID: user.ID, FileID: versionFile.ID, @@ -115,11 +116,11 @@ func TestAcquireJob(t *testing.T) { TemplateVersionID: version.ID, })), }) - workspace := gen.Workspace(ctx, database.Workspace{ + workspace := databasegen.Workspace(t, srv.Database, database.Workspace{ TemplateID: template.ID, OwnerID: user.ID, }) - build := gen.WorkspaceBuild(ctx, database.WorkspaceBuild{ + build := databasegen.WorkspaceBuild(t, srv.Database, database.WorkspaceBuild{ WorkspaceID: workspace.ID, BuildNumber: 1, JobID: uuid.New(), @@ -127,7 +128,7 @@ func TestAcquireJob(t *testing.T) { Transition: database.WorkspaceTransitionStart, Reason: database.BuildReasonInitiator, }) - _ = gen.Job(ctx, database.ProvisionerJob{ + _ = databasegen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ ID: build.ID, InitiatorID: user.ID, Provisioner: database.ProvisionerTypeEcho, @@ -188,11 +189,10 @@ func TestAcquireJob(t *testing.T) { srv := setup(t, false) ctx := context.Background() - gen := databasefake.NewGenerator(t, srv.Database) - user := gen.User(ctx, database.User{}) - version := gen.TemplateVersion(ctx, database.TemplateVersion{}) - file := gen.File(ctx, database.File{CreatedBy: user.ID}) - _ = gen.Job(ctx, database.ProvisionerJob{ + user := databasegen.User(t, srv.Database, database.User{}) + version := databasegen.TemplateVersion(t, srv.Database, database.TemplateVersion{}) + file := databasegen.File(t, srv.Database, database.File{CreatedBy: user.ID}) + _ = databasegen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ InitiatorID: user.ID, Provisioner: database.ProvisionerTypeEcho, StorageMethod: database.ProvisionerStorageMethodFile, @@ -228,10 +228,9 @@ func TestAcquireJob(t *testing.T) { srv := setup(t, false) ctx := context.Background() - gen := databasefake.NewGenerator(t, srv.Database) - user := gen.User(ctx, database.User{}) - file := gen.File(ctx, database.File{CreatedBy: user.ID}) - _ = gen.Job(ctx, database.ProvisionerJob{ + user := databasegen.User(t, srv.Database, database.User{}) + file := databasegen.File(t, srv.Database, database.File{CreatedBy: user.ID}) + _ = databasegen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ FileID: file.ID, InitiatorID: user.ID, Provisioner: database.ProvisionerTypeEcho, diff --git a/coderd/workspaceapps_internal_test.go b/coderd/workspaceapps_internal_test.go index eb6f785e1dd92..12515a710eacb 100644 --- a/coderd/workspaceapps_internal_test.go +++ b/coderd/workspaceapps_internal_test.go @@ -6,6 +6,8 @@ import ( "testing" "time" + "github.com/coder/coder/coderd/database/databasegen" + "github.com/stretchr/testify/require" "github.com/coder/coder/coderd/database" @@ -17,8 +19,7 @@ func TestAPIKeyEncryption(t *testing.T) { t.Parallel() generateAPIKey := func(t *testing.T, db database.Store) (keyID, keyToken string, hashedSecret []byte, data encryptedAPIKeyPayload) { - gen := databasefake.NewGenerator(t, db) - key, token := gen.APIKey(context.Background(), database.APIKey{}) + key, token := databasegen.APIKey(t, db, database.APIKey{}) data = encryptedAPIKeyPayload{ APIKey: token, @@ -70,10 +71,9 @@ func TestAPIKeyEncryption(t *testing.T) { t.Parallel() db := databasefake.New() - gen := databasefake.NewGenerator(t, db) hashedSecret := sha256.Sum256([]byte("wrong")) // Insert a token with a mismatched hashed secret. - _, token := gen.APIKey(context.Background(), database.APIKey{ + _, token := databasegen.APIKey(t, db, database.APIKey{ HashedSecret: hashedSecret[:], }) From 5e12b783b26a9bbf253137b73371560c669a2543 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 31 Jan 2023 13:35:54 -0600 Subject: [PATCH 19/21] Remove unused ctx --- coderd/httpmw/apikey_test.go | 2 -- coderd/httpmw/workspaceresourceparam_test.go | 1 - 2 files changed, 3 deletions(-) diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index dcfa568a726c2..0b5f9fce0d6a6 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -178,7 +178,6 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - ctx = context.Background() user = databasegen.User(t, db, database.User{}) _, token = databasegen.APIKey(t, db, database.APIKey{ UserID: user.ID, @@ -405,7 +404,6 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - ctx = context.Background() user = databasegen.User(t, db, database.User{}) sentAPIKey, token = databasegen.APIKey(t, db, database.APIKey{ UserID: user.ID, diff --git a/coderd/httpmw/workspaceresourceparam_test.go b/coderd/httpmw/workspaceresourceparam_test.go index ddde1bfb20b05..a7c29c32eb889 100644 --- a/coderd/httpmw/workspaceresourceparam_test.go +++ b/coderd/httpmw/workspaceresourceparam_test.go @@ -22,7 +22,6 @@ func TestWorkspaceResourceParam(t *testing.T) { setup := func(t *testing.T, db database.Store, jobType database.ProvisionerJobType) (*http.Request, database.WorkspaceResource) { r := httptest.NewRequest("GET", "/", nil) - ctx := context.Background() job := databasegen.ProvisionerJob(t, db, database.ProvisionerJob{ Type: jobType, Provisioner: database.ProvisionerTypeEcho, From 144ef40445c9ee17bbce8c15904124a398d2d2b1 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 31 Jan 2023 13:47:28 -0600 Subject: [PATCH 20/21] Add comment about unused fields --- coderd/database/databasegen/generator.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/coderd/database/databasegen/generator.go b/coderd/database/databasegen/generator.go index a504728def9ef..c1fff5999f480 100644 --- a/coderd/database/databasegen/generator.go +++ b/coderd/database/databasegen/generator.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "database/sql" "encoding/hex" + "fmt" "testing" "time" @@ -17,6 +18,9 @@ import ( "github.com/stretchr/testify/require" ) +// All methods take in a 'seed' object. Any provided fields in the seed will be +// maintained. Any fields omitted will have sensible defaults generated. + func Template(t *testing.T, db database.Store, seed database.Template) database.Template { template, err := db.InsertTemplate(context.Background(), database.InsertTemplateParams{ ID: takeFirst(seed.ID, uuid.New()), @@ -39,9 +43,9 @@ func Template(t *testing.T, db database.Store, seed database.Template) database. return template } -func APIKey(t *testing.T, db database.Store, seed database.APIKey) (key database.APIKey, secret string) { +func APIKey(t *testing.T, db database.Store, seed database.APIKey) (key database.APIKey, token string) { id, _ := cryptorand.String(10) - secret, _ = cryptorand.String(22) + secret, _ := cryptorand.String(22) hashed := sha256.Sum256([]byte(secret)) key, err := db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ @@ -59,7 +63,7 @@ func APIKey(t *testing.T, db database.Store, seed database.APIKey) (key database Scope: takeFirst(seed.Scope, database.APIKeyScopeAll), }) require.NoError(t, err, "insert api key") - return key, secret + return key, fmt.Sprintf("%s-%s", key.ID, secret) } func Workspace(t *testing.T, db database.Store, orig database.Workspace) database.Workspace { From 61ca1ca11394723c21dc35bb42d030f9e9664bfb Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 31 Jan 2023 14:55:57 -0600 Subject: [PATCH 21/21] rename to dbgen --- .../{databasegen => dbgen}/generator.go | 2 +- .../{databasegen => dbgen}/generator_test.go | 29 +++++------ .../database/{databasegen => dbgen}/take.go | 2 +- coderd/httpmw/apikey_test.go | 51 +++++++++---------- coderd/httpmw/groupparam_test.go | 7 ++- coderd/httpmw/ratelimit_test.go | 11 ++-- coderd/httpmw/workspaceresourceparam_test.go | 9 ++-- .../prometheusmetrics_test.go | 13 +++-- .../provisionerdserver_test.go | 35 +++++++------ coderd/workspaceapps_internal_test.go | 7 ++- 10 files changed, 79 insertions(+), 87 deletions(-) rename coderd/database/{databasegen => dbgen}/generator.go (99%) rename coderd/database/{databasegen => dbgen}/generator_test.go (74%) rename coderd/database/{databasegen => dbgen}/take.go (97%) diff --git a/coderd/database/databasegen/generator.go b/coderd/database/dbgen/generator.go similarity index 99% rename from coderd/database/databasegen/generator.go rename to coderd/database/dbgen/generator.go index c1fff5999f480..2d3c420fc6784 100644 --- a/coderd/database/databasegen/generator.go +++ b/coderd/database/dbgen/generator.go @@ -1,4 +1,4 @@ -package databasegen +package dbgen import ( "context" diff --git a/coderd/database/databasegen/generator_test.go b/coderd/database/dbgen/generator_test.go similarity index 74% rename from coderd/database/databasegen/generator_test.go rename to coderd/database/dbgen/generator_test.go index 1aebf6e2fa4e0..2266c866dbd09 100644 --- a/coderd/database/databasegen/generator_test.go +++ b/coderd/database/dbgen/generator_test.go @@ -1,15 +1,14 @@ -package databasegen_test +package dbgen_test import ( "context" "testing" - "github.com/coder/coder/coderd/database/databasegen" - "github.com/stretchr/testify/require" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/dbgen" ) func TestGenerator(t *testing.T) { @@ -18,84 +17,84 @@ func TestGenerator(t *testing.T) { t.Run("APIKey", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp, _ := databasegen.APIKey(t, db, database.APIKey{}) + exp, _ := dbgen.APIKey(t, db, database.APIKey{}) require.Equal(t, exp, must(db.GetAPIKeyByID(context.Background(), exp.ID))) }) t.Run("File", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasegen.File(t, db, database.File{}) + exp := dbgen.File(t, db, database.File{}) require.Equal(t, exp, must(db.GetFileByID(context.Background(), exp.ID))) }) t.Run("UserLink", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasegen.UserLink(t, db, database.UserLink{}) + exp := dbgen.UserLink(t, db, database.UserLink{}) require.Equal(t, exp, must(db.GetUserLinkByLinkedID(context.Background(), exp.LinkedID))) }) t.Run("WorkspaceResource", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasegen.WorkspaceResource(t, db, database.WorkspaceResource{}) + exp := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{}) require.Equal(t, exp, must(db.GetWorkspaceResourceByID(context.Background(), exp.ID))) }) t.Run("Job", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasegen.ProvisionerJob(t, db, database.ProvisionerJob{}) + exp := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{}) require.Equal(t, exp, must(db.GetProvisionerJobByID(context.Background(), exp.ID))) }) t.Run("Group", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasegen.Group(t, db, database.Group{}) + exp := dbgen.Group(t, db, database.Group{}) require.Equal(t, exp, must(db.GetGroupByID(context.Background(), exp.ID))) }) t.Run("Organization", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasegen.Organization(t, db, database.Organization{}) + exp := dbgen.Organization(t, db, database.Organization{}) require.Equal(t, exp, must(db.GetOrganizationByID(context.Background(), exp.ID))) }) t.Run("Workspace", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasegen.Workspace(t, db, database.Workspace{}) + exp := dbgen.Workspace(t, db, database.Workspace{}) require.Equal(t, exp, must(db.GetWorkspaceByID(context.Background(), exp.ID))) }) t.Run("Template", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasegen.Template(t, db, database.Template{}) + exp := dbgen.Template(t, db, database.Template{}) require.Equal(t, exp, must(db.GetTemplateByID(context.Background(), exp.ID))) }) t.Run("TemplateVersion", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasegen.TemplateVersion(t, db, database.TemplateVersion{}) + exp := dbgen.TemplateVersion(t, db, database.TemplateVersion{}) require.Equal(t, exp, must(db.GetTemplateVersionByID(context.Background(), exp.ID))) }) t.Run("WorkspaceBuild", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasegen.WorkspaceBuild(t, db, database.WorkspaceBuild{}) + exp := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{}) require.Equal(t, exp, must(db.GetWorkspaceBuildByID(context.Background(), exp.ID))) }) t.Run("User", func(t *testing.T) { t.Parallel() db := databasefake.New() - exp := databasegen.User(t, db, database.User{}) + exp := dbgen.User(t, db, database.User{}) require.Equal(t, exp, must(db.GetUserByID(context.Background(), exp.ID))) }) } diff --git a/coderd/database/databasegen/take.go b/coderd/database/dbgen/take.go similarity index 97% rename from coderd/database/databasegen/take.go rename to coderd/database/dbgen/take.go index 4bd1dec9feb43..717f2c0441cc3 100644 --- a/coderd/database/databasegen/take.go +++ b/coderd/database/dbgen/take.go @@ -1,4 +1,4 @@ -package databasegen +package dbgen // takeFirstBytes implements takeFirst for []byte. // []byte is not a comparable type. diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index 0b5f9fce0d6a6..425999eb9f6f9 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -11,14 +11,13 @@ import ( "testing" "time" - "github.com/coder/coder/coderd/database/databasegen" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/dbgen" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/codersdk" @@ -155,11 +154,11 @@ func TestAPIKey(t *testing.T) { db = databasefake.New() r = httptest.NewRequest("GET", "/", nil) rw = httptest.NewRecorder() - user = databasegen.User(t, db, database.User{}) + user = dbgen.User(t, db, database.User{}) // Use a different secret so they don't match! hashed = sha256.Sum256([]byte("differentsecret")) - _, token = databasegen.APIKey(t, db, database.APIKey{ + _, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, HashedSecret: hashed[:], }) @@ -178,8 +177,8 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - user = databasegen.User(t, db, database.User{}) - _, token = databasegen.APIKey(t, db, database.APIKey{ + user = dbgen.User(t, db, database.User{}) + _, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, ExpiresAt: time.Now().Add(time.Hour * -1), }) @@ -202,8 +201,8 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - user = databasegen.User(t, db, database.User{}) - sentAPIKey, token = databasegen.APIKey(t, db, database.APIKey{ + user = dbgen.User(t, db, database.User{}) + sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, ExpiresAt: database.Now().AddDate(0, 0, 1), }) @@ -237,8 +236,8 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - user = databasegen.User(t, db, database.User{}) - _, token = databasegen.APIKey(t, db, database.APIKey{ + user = dbgen.User(t, db, database.User{}) + _, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, ExpiresAt: database.Now().AddDate(0, 0, 1), Scope: database.APIKeyScopeApplicationConnect, @@ -274,8 +273,8 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - user = databasegen.User(t, db, database.User{}) - _, token = databasegen.APIKey(t, db, database.APIKey{ + user = dbgen.User(t, db, database.User{}) + _, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, ExpiresAt: database.Now().AddDate(0, 0, 1), }) @@ -306,8 +305,8 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - user = databasegen.User(t, db, database.User{}) - sentAPIKey, token = databasegen.APIKey(t, db, database.APIKey{ + user = dbgen.User(t, db, database.User{}) + sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, LastUsed: database.Now().AddDate(0, 0, -1), ExpiresAt: database.Now().AddDate(0, 0, 1), @@ -337,8 +336,8 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - user = databasegen.User(t, db, database.User{}) - sentAPIKey, token = databasegen.APIKey(t, db, database.APIKey{ + user = dbgen.User(t, db, database.User{}) + sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), @@ -368,14 +367,14 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - user = databasegen.User(t, db, database.User{}) - sentAPIKey, token = databasegen.APIKey(t, db, database.APIKey{ + user = dbgen.User(t, db, database.User{}) + sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, LastUsed: database.Now(), ExpiresAt: database.Now().AddDate(0, 0, 1), LoginType: database.LoginTypeGithub, }) - _ = databasegen.UserLink(t, db, database.UserLink{ + _ = dbgen.UserLink(t, db, database.UserLink{ UserID: user.ID, LoginType: database.LoginTypeGithub, }) @@ -404,14 +403,14 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - user = databasegen.User(t, db, database.User{}) - sentAPIKey, token = databasegen.APIKey(t, db, database.APIKey{ + user = dbgen.User(t, db, database.User{}) + sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, LastUsed: database.Now(), ExpiresAt: database.Now().AddDate(0, 0, 1), LoginType: database.LoginTypeGithub, }) - _ = databasegen.UserLink(t, db, database.UserLink{ + _ = dbgen.UserLink(t, db, database.UserLink{ UserID: user.ID, LoginType: database.LoginTypeGithub, OAuthRefreshToken: "hello", @@ -454,8 +453,8 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - user = databasegen.User(t, db, database.User{}) - sentAPIKey, token = databasegen.APIKey(t, db, database.APIKey{ + user = dbgen.User(t, db, database.User{}) + sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, LastUsed: database.Now().AddDate(0, 0, -1), ExpiresAt: database.Now().AddDate(0, 0, 1), @@ -537,8 +536,8 @@ func TestAPIKey(t *testing.T) { t.Parallel() var ( db = databasefake.New() - user = databasegen.User(t, db, database.User{}) - sentAPIKey, token = databasegen.APIKey(t, db, database.APIKey{ + user = dbgen.User(t, db, database.User{}) + sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{ UserID: user.ID, LastUsed: database.Now(), ExpiresAt: database.Now().AddDate(0, 0, 1), diff --git a/coderd/httpmw/groupparam_test.go b/coderd/httpmw/groupparam_test.go index e77786605935d..28038f5d03c3d 100644 --- a/coderd/httpmw/groupparam_test.go +++ b/coderd/httpmw/groupparam_test.go @@ -6,14 +6,13 @@ import ( "net/http/httptest" "testing" - "github.com/coder/coder/coderd/database/databasegen" - "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/dbgen" "github.com/coder/coder/coderd/httpmw" ) @@ -25,7 +24,7 @@ func TestGroupParam(t *testing.T) { var ( db = databasefake.New() - group = databasegen.Group(t, db, database.Group{}) + group = dbgen.Group(t, db, database.Group{}) r = httptest.NewRequest("GET", "/", nil) w = httptest.NewRecorder() ) @@ -54,7 +53,7 @@ func TestGroupParam(t *testing.T) { var ( db = databasefake.New() - group = databasegen.Group(t, db, database.Group{}) + group = dbgen.Group(t, db, database.Group{}) r = httptest.NewRequest("GET", "/", nil) w = httptest.NewRecorder() ) diff --git a/coderd/httpmw/ratelimit_test.go b/coderd/httpmw/ratelimit_test.go index 4e22065e7408b..e004fb3ed3ed0 100644 --- a/coderd/httpmw/ratelimit_test.go +++ b/coderd/httpmw/ratelimit_test.go @@ -9,13 +9,12 @@ import ( "testing" "time" - "github.com/coder/coder/coderd/database/databasegen" - "github.com/go-chi/chi/v5" "github.com/stretchr/testify/require" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/dbgen" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/codersdk" @@ -73,8 +72,8 @@ func TestRateLimit(t *testing.T) { t.Parallel() db := databasefake.New() - u := databasegen.User(t, db, database.User{}) - _, key := databasegen.APIKey(t, db, database.APIKey{UserID: u.ID}) + u := dbgen.User(t, db, database.User{}) + _, key := dbgen.APIKey(t, db, database.APIKey{UserID: u.ID}) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ @@ -117,10 +116,10 @@ func TestRateLimit(t *testing.T) { db := databasefake.New() - u := databasegen.User(t, db, database.User{ + u := dbgen.User(t, db, database.User{ RBACRoles: []string{rbac.RoleOwner()}, }) - _, key := databasegen.APIKey(t, db, database.APIKey{UserID: u.ID}) + _, key := dbgen.APIKey(t, db, database.APIKey{UserID: u.ID}) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ diff --git a/coderd/httpmw/workspaceresourceparam_test.go b/coderd/httpmw/workspaceresourceparam_test.go index a7c29c32eb889..b2f222f21a33c 100644 --- a/coderd/httpmw/workspaceresourceparam_test.go +++ b/coderd/httpmw/workspaceresourceparam_test.go @@ -6,14 +6,13 @@ import ( "net/http/httptest" "testing" - "github.com/coder/coder/coderd/database/databasegen" - "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/dbgen" "github.com/coder/coder/coderd/httpmw" ) @@ -22,19 +21,19 @@ func TestWorkspaceResourceParam(t *testing.T) { setup := func(t *testing.T, db database.Store, jobType database.ProvisionerJobType) (*http.Request, database.WorkspaceResource) { r := httptest.NewRequest("GET", "/", nil) - job := databasegen.ProvisionerJob(t, db, database.ProvisionerJob{ + job := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ Type: jobType, Provisioner: database.ProvisionerTypeEcho, StorageMethod: database.ProvisionerStorageMethodFile, }) - build := databasegen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ JobID: job.ID, Transition: database.WorkspaceTransitionStart, Reason: database.BuildReasonInitiator, }) - resource := databasegen.WorkspaceResource(t, db, database.WorkspaceResource{ + resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{ JobID: job.ID, Transition: database.WorkspaceTransitionStart, }) diff --git a/coderd/prometheusmetrics/prometheusmetrics_test.go b/coderd/prometheusmetrics/prometheusmetrics_test.go index 675b9d6fc2798..424593b6f282d 100644 --- a/coderd/prometheusmetrics/prometheusmetrics_test.go +++ b/coderd/prometheusmetrics/prometheusmetrics_test.go @@ -6,8 +6,6 @@ import ( "testing" "time" - "github.com/coder/coder/coderd/database/databasegen" - "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" @@ -15,6 +13,7 @@ import ( "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/dbgen" "github.com/coder/coder/coderd/prometheusmetrics" "github.com/coder/coder/codersdk" "github.com/coder/coder/testutil" @@ -37,7 +36,7 @@ func TestActiveUsers(t *testing.T) { Name: "One", Database: func(t *testing.T) database.Store { db := databasefake.New() - databasegen.APIKey(t, db, database.APIKey{ + dbgen.APIKey(t, db, database.APIKey{ LastUsed: database.Now(), }) return db @@ -48,13 +47,13 @@ func TestActiveUsers(t *testing.T) { Database: func(t *testing.T) database.Store { db := databasefake.New() - databasegen.APIKey(t, db, database.APIKey{ + dbgen.APIKey(t, db, database.APIKey{ LastUsed: database.Now(), }) // Because this API key hasn't been used in the past hour, this shouldn't // add to the user count. - databasegen.APIKey(t, db, database.APIKey{ + dbgen.APIKey(t, db, database.APIKey{ LastUsed: database.Now().Add(-2 * time.Hour), }) return db @@ -64,10 +63,10 @@ func TestActiveUsers(t *testing.T) { Name: "Multiple", Database: func(t *testing.T) database.Store { db := databasefake.New() - databasegen.APIKey(t, db, database.APIKey{ + dbgen.APIKey(t, db, database.APIKey{ LastUsed: database.Now(), }) - databasegen.APIKey(t, db, database.APIKey{ + dbgen.APIKey(t, db, database.APIKey{ LastUsed: database.Now(), }) return db diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index c98880241d0ee..6032a8497dd98 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -9,8 +9,6 @@ import ( "testing" "time" - "github.com/coder/coder/coderd/database/databasegen" - "github.com/google/uuid" "github.com/stretchr/testify/require" @@ -18,6 +16,7 @@ import ( "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/dbgen" "github.com/coder/coder/coderd/provisionerdserver" "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/codersdk" @@ -90,14 +89,14 @@ func TestAcquireJob(t *testing.T) { srv := setup(t, false) ctx := context.Background() - user := databasegen.User(t, srv.Database, database.User{}) - template := databasegen.Template(t, srv.Database, database.Template{ + user := dbgen.User(t, srv.Database, database.User{}) + template := dbgen.Template(t, srv.Database, database.Template{ Name: "template", Provisioner: database.ProvisionerTypeEcho, }) - file := databasegen.File(t, srv.Database, database.File{CreatedBy: user.ID}) - versionFile := databasegen.File(t, srv.Database, database.File{CreatedBy: user.ID}) - version := databasegen.TemplateVersion(t, srv.Database, database.TemplateVersion{ + file := dbgen.File(t, srv.Database, database.File{CreatedBy: user.ID}) + versionFile := dbgen.File(t, srv.Database, database.File{CreatedBy: user.ID}) + version := dbgen.TemplateVersion(t, srv.Database, database.TemplateVersion{ TemplateID: uuid.NullUUID{ UUID: template.ID, Valid: true, @@ -105,7 +104,7 @@ func TestAcquireJob(t *testing.T) { JobID: uuid.New(), }) // Import version job - _ = databasegen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ + _ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ ID: version.JobID, InitiatorID: user.ID, FileID: versionFile.ID, @@ -116,11 +115,11 @@ func TestAcquireJob(t *testing.T) { TemplateVersionID: version.ID, })), }) - workspace := databasegen.Workspace(t, srv.Database, database.Workspace{ + workspace := dbgen.Workspace(t, srv.Database, database.Workspace{ TemplateID: template.ID, OwnerID: user.ID, }) - build := databasegen.WorkspaceBuild(t, srv.Database, database.WorkspaceBuild{ + build := dbgen.WorkspaceBuild(t, srv.Database, database.WorkspaceBuild{ WorkspaceID: workspace.ID, BuildNumber: 1, JobID: uuid.New(), @@ -128,7 +127,7 @@ func TestAcquireJob(t *testing.T) { Transition: database.WorkspaceTransitionStart, Reason: database.BuildReasonInitiator, }) - _ = databasegen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ + _ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ ID: build.ID, InitiatorID: user.ID, Provisioner: database.ProvisionerTypeEcho, @@ -189,10 +188,10 @@ func TestAcquireJob(t *testing.T) { srv := setup(t, false) ctx := context.Background() - user := databasegen.User(t, srv.Database, database.User{}) - version := databasegen.TemplateVersion(t, srv.Database, database.TemplateVersion{}) - file := databasegen.File(t, srv.Database, database.File{CreatedBy: user.ID}) - _ = databasegen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ + user := dbgen.User(t, srv.Database, database.User{}) + version := dbgen.TemplateVersion(t, srv.Database, database.TemplateVersion{}) + file := dbgen.File(t, srv.Database, database.File{CreatedBy: user.ID}) + _ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ InitiatorID: user.ID, Provisioner: database.ProvisionerTypeEcho, StorageMethod: database.ProvisionerStorageMethodFile, @@ -228,9 +227,9 @@ func TestAcquireJob(t *testing.T) { srv := setup(t, false) ctx := context.Background() - user := databasegen.User(t, srv.Database, database.User{}) - file := databasegen.File(t, srv.Database, database.File{CreatedBy: user.ID}) - _ = databasegen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ + user := dbgen.User(t, srv.Database, database.User{}) + file := dbgen.File(t, srv.Database, database.File{CreatedBy: user.ID}) + _ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{ FileID: file.ID, InitiatorID: user.ID, Provisioner: database.ProvisionerTypeEcho, diff --git a/coderd/workspaceapps_internal_test.go b/coderd/workspaceapps_internal_test.go index 12515a710eacb..f35d904b397af 100644 --- a/coderd/workspaceapps_internal_test.go +++ b/coderd/workspaceapps_internal_test.go @@ -6,12 +6,11 @@ import ( "testing" "time" - "github.com/coder/coder/coderd/database/databasegen" - "github.com/stretchr/testify/require" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/dbgen" "github.com/coder/coder/testutil" ) @@ -19,7 +18,7 @@ func TestAPIKeyEncryption(t *testing.T) { t.Parallel() generateAPIKey := func(t *testing.T, db database.Store) (keyID, keyToken string, hashedSecret []byte, data encryptedAPIKeyPayload) { - key, token := databasegen.APIKey(t, db, database.APIKey{}) + key, token := dbgen.APIKey(t, db, database.APIKey{}) data = encryptedAPIKeyPayload{ APIKey: token, @@ -73,7 +72,7 @@ func TestAPIKeyEncryption(t *testing.T) { hashedSecret := sha256.Sum256([]byte("wrong")) // Insert a token with a mismatched hashed secret. - _, token := databasegen.APIKey(t, db, database.APIKey{ + _, token := dbgen.APIKey(t, db, database.APIKey{ HashedSecret: hashedSecret[:], })