Skip to content

Commit 4a6fc40

Browse files
authored
feat: Add database data generator to make fakedbs easier to populate (#5922)
* feat: Add database data generator to make fakedbs easier to populate
1 parent c162c0f commit 4a6fc40

10 files changed

+640
-472
lines changed

coderd/database/dbgen/generator.go

+225
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
package dbgen
2+
3+
import (
4+
"context"
5+
"crypto/sha256"
6+
"database/sql"
7+
"encoding/hex"
8+
"fmt"
9+
"testing"
10+
"time"
11+
12+
"github.com/coder/coder/cryptorand"
13+
"github.com/tabbed/pqtype"
14+
15+
"github.com/coder/coder/coderd/database"
16+
"github.com/google/uuid"
17+
"github.com/moby/moby/pkg/namesgenerator"
18+
"github.com/stretchr/testify/require"
19+
)
20+
21+
// All methods take in a 'seed' object. Any provided fields in the seed will be
22+
// maintained. Any fields omitted will have sensible defaults generated.
23+
24+
func Template(t *testing.T, db database.Store, seed database.Template) database.Template {
25+
template, err := db.InsertTemplate(context.Background(), database.InsertTemplateParams{
26+
ID: takeFirst(seed.ID, uuid.New()),
27+
CreatedAt: takeFirst(seed.CreatedAt, time.Now()),
28+
UpdatedAt: takeFirst(seed.UpdatedAt, time.Now()),
29+
OrganizationID: takeFirst(seed.OrganizationID, uuid.New()),
30+
Name: takeFirst(seed.Name, namesgenerator.GetRandomName(1)),
31+
Provisioner: takeFirst(seed.Provisioner, database.ProvisionerTypeEcho),
32+
ActiveVersionID: takeFirst(seed.ActiveVersionID, uuid.New()),
33+
Description: takeFirst(seed.Description, namesgenerator.GetRandomName(1)),
34+
DefaultTTL: takeFirst(seed.DefaultTTL, 3600),
35+
CreatedBy: takeFirst(seed.CreatedBy, uuid.New()),
36+
Icon: takeFirst(seed.Icon, namesgenerator.GetRandomName(1)),
37+
UserACL: seed.UserACL,
38+
GroupACL: seed.GroupACL,
39+
DisplayName: takeFirst(seed.DisplayName, namesgenerator.GetRandomName(1)),
40+
AllowUserCancelWorkspaceJobs: takeFirst(seed.AllowUserCancelWorkspaceJobs, true),
41+
})
42+
require.NoError(t, err, "insert template")
43+
return template
44+
}
45+
46+
func APIKey(t *testing.T, db database.Store, seed database.APIKey) (key database.APIKey, token string) {
47+
id, _ := cryptorand.String(10)
48+
secret, _ := cryptorand.String(22)
49+
hashed := sha256.Sum256([]byte(secret))
50+
51+
key, err := db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{
52+
ID: takeFirst(seed.ID, id),
53+
// 0 defaults to 86400 at the db layer
54+
LifetimeSeconds: takeFirst(seed.LifetimeSeconds, 0),
55+
HashedSecret: takeFirstBytes(seed.HashedSecret, hashed[:]),
56+
IPAddress: pqtype.Inet{},
57+
UserID: takeFirst(seed.UserID, uuid.New()),
58+
LastUsed: takeFirst(seed.LastUsed, time.Now()),
59+
ExpiresAt: takeFirst(seed.ExpiresAt, time.Now().Add(time.Hour)),
60+
CreatedAt: takeFirst(seed.CreatedAt, time.Now()),
61+
UpdatedAt: takeFirst(seed.UpdatedAt, time.Now()),
62+
LoginType: takeFirst(seed.LoginType, database.LoginTypePassword),
63+
Scope: takeFirst(seed.Scope, database.APIKeyScopeAll),
64+
})
65+
require.NoError(t, err, "insert api key")
66+
return key, fmt.Sprintf("%s-%s", key.ID, secret)
67+
}
68+
69+
func Workspace(t *testing.T, db database.Store, orig database.Workspace) database.Workspace {
70+
workspace, err := db.InsertWorkspace(context.Background(), database.InsertWorkspaceParams{
71+
ID: takeFirst(orig.ID, uuid.New()),
72+
OwnerID: takeFirst(orig.OwnerID, uuid.New()),
73+
CreatedAt: takeFirst(orig.CreatedAt, time.Now()),
74+
UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()),
75+
OrganizationID: takeFirst(orig.OrganizationID, uuid.New()),
76+
TemplateID: takeFirst(orig.TemplateID, uuid.New()),
77+
Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)),
78+
AutostartSchedule: orig.AutostartSchedule,
79+
Ttl: orig.Ttl,
80+
})
81+
require.NoError(t, err, "insert workspace")
82+
return workspace
83+
}
84+
85+
func WorkspaceBuild(t *testing.T, db database.Store, orig database.WorkspaceBuild) database.WorkspaceBuild {
86+
build, err := db.InsertWorkspaceBuild(context.Background(), database.InsertWorkspaceBuildParams{
87+
ID: takeFirst(orig.ID, uuid.New()),
88+
CreatedAt: takeFirst(orig.CreatedAt, time.Now()),
89+
UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()),
90+
WorkspaceID: takeFirst(orig.WorkspaceID, uuid.New()),
91+
TemplateVersionID: takeFirst(orig.TemplateVersionID, uuid.New()),
92+
BuildNumber: takeFirst(orig.BuildNumber, 0),
93+
Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart),
94+
InitiatorID: takeFirst(orig.InitiatorID, uuid.New()),
95+
JobID: takeFirst(orig.JobID, uuid.New()),
96+
ProvisionerState: takeFirstBytes(orig.ProvisionerState, []byte{}),
97+
Deadline: takeFirst(orig.Deadline, time.Now().Add(time.Hour)),
98+
Reason: takeFirst(orig.Reason, database.BuildReasonInitiator),
99+
})
100+
require.NoError(t, err, "insert workspace build")
101+
return build
102+
}
103+
104+
func User(t *testing.T, db database.Store, orig database.User) database.User {
105+
user, err := db.InsertUser(context.Background(), database.InsertUserParams{
106+
ID: takeFirst(orig.ID, uuid.New()),
107+
Email: takeFirst(orig.Email, namesgenerator.GetRandomName(1)),
108+
Username: takeFirst(orig.Username, namesgenerator.GetRandomName(1)),
109+
HashedPassword: takeFirstBytes(orig.HashedPassword, []byte{}),
110+
CreatedAt: takeFirst(orig.CreatedAt, time.Now()),
111+
UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()),
112+
RBACRoles: []string{},
113+
LoginType: takeFirst(orig.LoginType, database.LoginTypePassword),
114+
})
115+
require.NoError(t, err, "insert user")
116+
return user
117+
}
118+
119+
func Organization(t *testing.T, db database.Store, orig database.Organization) database.Organization {
120+
org, err := db.InsertOrganization(context.Background(), database.InsertOrganizationParams{
121+
ID: takeFirst(orig.ID, uuid.New()),
122+
Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)),
123+
Description: takeFirst(orig.Description, namesgenerator.GetRandomName(1)),
124+
CreatedAt: takeFirst(orig.CreatedAt, time.Now()),
125+
UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()),
126+
})
127+
require.NoError(t, err, "insert organization")
128+
return org
129+
}
130+
131+
func Group(t *testing.T, db database.Store, orig database.Group) database.Group {
132+
group, err := db.InsertGroup(context.Background(), database.InsertGroupParams{
133+
ID: takeFirst(orig.ID, uuid.New()),
134+
Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)),
135+
OrganizationID: takeFirst(orig.OrganizationID, uuid.New()),
136+
AvatarURL: takeFirst(orig.AvatarURL, "https://logo.example.com"),
137+
QuotaAllowance: takeFirst(orig.QuotaAllowance, 0),
138+
})
139+
require.NoError(t, err, "insert group")
140+
return group
141+
}
142+
143+
func ProvisionerJob(t *testing.T, db database.Store, orig database.ProvisionerJob) database.ProvisionerJob {
144+
job, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{
145+
ID: takeFirst(orig.ID, uuid.New()),
146+
CreatedAt: takeFirst(orig.CreatedAt, time.Now()),
147+
UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()),
148+
OrganizationID: takeFirst(orig.OrganizationID, uuid.New()),
149+
InitiatorID: takeFirst(orig.InitiatorID, uuid.New()),
150+
Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho),
151+
StorageMethod: takeFirst(orig.StorageMethod, database.ProvisionerStorageMethodFile),
152+
FileID: takeFirst(orig.FileID, uuid.New()),
153+
Type: takeFirst(orig.Type, database.ProvisionerJobTypeWorkspaceBuild),
154+
Input: takeFirstBytes(orig.Input, []byte("{}")),
155+
Tags: orig.Tags,
156+
})
157+
require.NoError(t, err, "insert job")
158+
return job
159+
}
160+
161+
func WorkspaceResource(t *testing.T, db database.Store, orig database.WorkspaceResource) database.WorkspaceResource {
162+
resource, err := db.InsertWorkspaceResource(context.Background(), database.InsertWorkspaceResourceParams{
163+
ID: takeFirst(orig.ID, uuid.New()),
164+
CreatedAt: takeFirst(orig.CreatedAt, time.Now()),
165+
JobID: takeFirst(orig.JobID, uuid.New()),
166+
Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart),
167+
Type: takeFirst(orig.Type, "fake_resource"),
168+
Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)),
169+
Hide: takeFirst(orig.Hide, false),
170+
Icon: takeFirst(orig.Icon, ""),
171+
InstanceType: sql.NullString{
172+
String: takeFirst(orig.InstanceType.String, ""),
173+
Valid: takeFirst(orig.InstanceType.Valid, false),
174+
},
175+
DailyCost: takeFirst(orig.DailyCost, 0),
176+
})
177+
require.NoError(t, err, "insert resource")
178+
return resource
179+
}
180+
181+
func File(t *testing.T, db database.Store, orig database.File) database.File {
182+
file, err := db.InsertFile(context.Background(), database.InsertFileParams{
183+
ID: takeFirst(orig.ID, uuid.New()),
184+
Hash: takeFirst(orig.Hash, hex.EncodeToString(make([]byte, 32))),
185+
CreatedAt: takeFirst(orig.CreatedAt, time.Now()),
186+
CreatedBy: takeFirst(orig.CreatedBy, uuid.New()),
187+
Mimetype: takeFirst(orig.Mimetype, "application/x-tar"),
188+
Data: takeFirstBytes(orig.Data, []byte{}),
189+
})
190+
require.NoError(t, err, "insert file")
191+
return file
192+
}
193+
194+
func UserLink(t *testing.T, db database.Store, orig database.UserLink) database.UserLink {
195+
link, err := db.InsertUserLink(context.Background(), database.InsertUserLinkParams{
196+
UserID: takeFirst(orig.UserID, uuid.New()),
197+
LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub),
198+
LinkedID: takeFirst(orig.LinkedID),
199+
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
200+
OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
201+
OAuthExpiry: takeFirst(orig.OAuthExpiry, time.Now().Add(time.Hour*24)),
202+
})
203+
204+
require.NoError(t, err, "insert link")
205+
return link
206+
}
207+
208+
func TemplateVersion(t *testing.T, db database.Store, orig database.TemplateVersion) database.TemplateVersion {
209+
version, err := db.InsertTemplateVersion(context.Background(), database.InsertTemplateVersionParams{
210+
ID: takeFirst(orig.ID, uuid.New()),
211+
TemplateID: uuid.NullUUID{
212+
UUID: takeFirst(orig.TemplateID.UUID, uuid.New()),
213+
Valid: takeFirst(orig.TemplateID.Valid, true),
214+
},
215+
OrganizationID: takeFirst(orig.OrganizationID, uuid.New()),
216+
CreatedAt: takeFirst(orig.CreatedAt, time.Now()),
217+
UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()),
218+
Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)),
219+
Readme: takeFirst(orig.Readme, namesgenerator.GetRandomName(1)),
220+
JobID: takeFirst(orig.JobID, uuid.New()),
221+
CreatedBy: takeFirst(orig.CreatedBy, uuid.New()),
222+
})
223+
require.NoError(t, err, "insert template version")
224+
return version
225+
}
+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package dbgen_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
9+
"github.com/coder/coder/coderd/database"
10+
"github.com/coder/coder/coderd/database/databasefake"
11+
"github.com/coder/coder/coderd/database/dbgen"
12+
)
13+
14+
func TestGenerator(t *testing.T) {
15+
t.Parallel()
16+
17+
t.Run("APIKey", func(t *testing.T) {
18+
t.Parallel()
19+
db := databasefake.New()
20+
exp, _ := dbgen.APIKey(t, db, database.APIKey{})
21+
require.Equal(t, exp, must(db.GetAPIKeyByID(context.Background(), exp.ID)))
22+
})
23+
24+
t.Run("File", func(t *testing.T) {
25+
t.Parallel()
26+
db := databasefake.New()
27+
exp := dbgen.File(t, db, database.File{})
28+
require.Equal(t, exp, must(db.GetFileByID(context.Background(), exp.ID)))
29+
})
30+
31+
t.Run("UserLink", func(t *testing.T) {
32+
t.Parallel()
33+
db := databasefake.New()
34+
exp := dbgen.UserLink(t, db, database.UserLink{})
35+
require.Equal(t, exp, must(db.GetUserLinkByLinkedID(context.Background(), exp.LinkedID)))
36+
})
37+
38+
t.Run("WorkspaceResource", func(t *testing.T) {
39+
t.Parallel()
40+
db := databasefake.New()
41+
exp := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{})
42+
require.Equal(t, exp, must(db.GetWorkspaceResourceByID(context.Background(), exp.ID)))
43+
})
44+
45+
t.Run("Job", func(t *testing.T) {
46+
t.Parallel()
47+
db := databasefake.New()
48+
exp := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{})
49+
require.Equal(t, exp, must(db.GetProvisionerJobByID(context.Background(), exp.ID)))
50+
})
51+
52+
t.Run("Group", func(t *testing.T) {
53+
t.Parallel()
54+
db := databasefake.New()
55+
exp := dbgen.Group(t, db, database.Group{})
56+
require.Equal(t, exp, must(db.GetGroupByID(context.Background(), exp.ID)))
57+
})
58+
59+
t.Run("Organization", func(t *testing.T) {
60+
t.Parallel()
61+
db := databasefake.New()
62+
exp := dbgen.Organization(t, db, database.Organization{})
63+
require.Equal(t, exp, must(db.GetOrganizationByID(context.Background(), exp.ID)))
64+
})
65+
66+
t.Run("Workspace", func(t *testing.T) {
67+
t.Parallel()
68+
db := databasefake.New()
69+
exp := dbgen.Workspace(t, db, database.Workspace{})
70+
require.Equal(t, exp, must(db.GetWorkspaceByID(context.Background(), exp.ID)))
71+
})
72+
73+
t.Run("Template", func(t *testing.T) {
74+
t.Parallel()
75+
db := databasefake.New()
76+
exp := dbgen.Template(t, db, database.Template{})
77+
require.Equal(t, exp, must(db.GetTemplateByID(context.Background(), exp.ID)))
78+
})
79+
80+
t.Run("TemplateVersion", func(t *testing.T) {
81+
t.Parallel()
82+
db := databasefake.New()
83+
exp := dbgen.TemplateVersion(t, db, database.TemplateVersion{})
84+
require.Equal(t, exp, must(db.GetTemplateVersionByID(context.Background(), exp.ID)))
85+
})
86+
87+
t.Run("WorkspaceBuild", func(t *testing.T) {
88+
t.Parallel()
89+
db := databasefake.New()
90+
exp := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{})
91+
require.Equal(t, exp, must(db.GetWorkspaceBuildByID(context.Background(), exp.ID)))
92+
})
93+
94+
t.Run("User", func(t *testing.T) {
95+
t.Parallel()
96+
db := databasefake.New()
97+
exp := dbgen.User(t, db, database.User{})
98+
require.Equal(t, exp, must(db.GetUserByID(context.Background(), exp.ID)))
99+
})
100+
}
101+
102+
func must[T any](value T, err error) T {
103+
if err != nil {
104+
panic(err)
105+
}
106+
return value
107+
}

coderd/database/dbgen/take.go

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package dbgen
2+
3+
// takeFirstBytes implements takeFirst for []byte.
4+
// []byte is not a comparable type.
5+
func takeFirstBytes(values ...[]byte) []byte {
6+
return takeFirstF(values, func(v []byte) bool {
7+
return len(v) != 0
8+
})
9+
}
10+
11+
// takeFirstF takes the first value that returns true
12+
func takeFirstF[Value any](values []Value, take func(v Value) bool) Value {
13+
var empty Value
14+
for _, v := range values {
15+
if take(v) {
16+
return v
17+
}
18+
}
19+
// If all empty, return empty
20+
return empty
21+
}
22+
23+
// takeFirst will take the first non-empty value.
24+
func takeFirst[Value comparable](values ...Value) Value {
25+
var empty Value
26+
return takeFirstF(values, func(v Value) bool {
27+
return v != empty
28+
})
29+
}

0 commit comments

Comments
 (0)