Skip to content

Commit 41e5231

Browse files
authored
chore: Add more dbgen functions (coder#6005)
1 parent 5fe4819 commit 41e5231

File tree

3 files changed

+119
-1
lines changed

3 files changed

+119
-1
lines changed

coderd/database/dbgen/generator.go

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"database/sql"
77
"encoding/hex"
88
"fmt"
9+
"net"
910
"testing"
1011
"time"
1112

@@ -21,6 +22,34 @@ import (
2122
// All methods take in a 'seed' object. Any provided fields in the seed will be
2223
// maintained. Any fields omitted will have sensible defaults generated.
2324

25+
func AuditLog(t *testing.T, db database.Store, seed database.AuditLog) database.AuditLog {
26+
log, err := db.InsertAuditLog(context.Background(), database.InsertAuditLogParams{
27+
ID: takeFirst(seed.ID, uuid.New()),
28+
Time: takeFirst(seed.Time, time.Now()),
29+
UserID: takeFirst(seed.UserID, uuid.New()),
30+
OrganizationID: takeFirst(seed.OrganizationID, uuid.New()),
31+
Ip: pqtype.Inet{
32+
IPNet: takeFirstIP(seed.Ip.IPNet, net.IPNet{}),
33+
Valid: takeFirst(seed.Ip.Valid, false),
34+
},
35+
UserAgent: sql.NullString{
36+
String: takeFirst(seed.UserAgent.String, ""),
37+
Valid: takeFirst(seed.UserAgent.Valid, false),
38+
},
39+
ResourceType: takeFirst(seed.ResourceType, database.ResourceTypeOrganization),
40+
ResourceID: takeFirst(seed.ResourceID, uuid.New()),
41+
ResourceTarget: takeFirst(seed.ResourceTarget, uuid.NewString()),
42+
Action: takeFirst(seed.Action, database.AuditActionCreate),
43+
Diff: takeFirstBytes(seed.Diff, []byte("{}")),
44+
StatusCode: takeFirst(seed.StatusCode, 200),
45+
AdditionalFields: takeFirstBytes(seed.Diff, []byte("{}")),
46+
RequestID: takeFirst(seed.RequestID, uuid.New()),
47+
ResourceIcon: takeFirst(seed.ResourceIcon, ""),
48+
})
49+
require.NoError(t, err, "insert audit log")
50+
return log
51+
}
52+
2453
func Template(t *testing.T, db database.Store, seed database.Template) database.Template {
2554
template, err := db.InsertTemplate(context.Background(), database.InsertTemplateParams{
2655
ID: takeFirst(seed.ID, uuid.New()),
@@ -66,6 +95,47 @@ func APIKey(t *testing.T, db database.Store, seed database.APIKey) (key database
6695
return key, fmt.Sprintf("%s-%s", key.ID, secret)
6796
}
6897

98+
func WorkspaceAgent(t *testing.T, db database.Store, orig database.WorkspaceAgent) database.WorkspaceAgent {
99+
workspace, err := db.InsertWorkspaceAgent(context.Background(), database.InsertWorkspaceAgentParams{
100+
ID: takeFirst(orig.ID, uuid.New()),
101+
CreatedAt: takeFirst(orig.CreatedAt, time.Now()),
102+
UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()),
103+
Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)),
104+
ResourceID: takeFirst(orig.ResourceID, uuid.New()),
105+
AuthToken: takeFirst(orig.AuthToken, uuid.New()),
106+
AuthInstanceID: sql.NullString{
107+
String: takeFirst(orig.AuthInstanceID.String, namesgenerator.GetRandomName(1)),
108+
Valid: takeFirst(orig.AuthInstanceID.Valid, true),
109+
},
110+
Architecture: takeFirst(orig.Architecture, "amd64"),
111+
EnvironmentVariables: pqtype.NullRawMessage{
112+
RawMessage: takeFirstBytes(orig.EnvironmentVariables.RawMessage, []byte("{}")),
113+
Valid: takeFirst(orig.EnvironmentVariables.Valid, false),
114+
},
115+
OperatingSystem: takeFirst(orig.OperatingSystem, "linux"),
116+
StartupScript: sql.NullString{
117+
String: takeFirst(orig.StartupScript.String, ""),
118+
Valid: takeFirst(orig.StartupScript.Valid, false),
119+
},
120+
Directory: takeFirst(orig.Directory, ""),
121+
InstanceMetadata: pqtype.NullRawMessage{
122+
RawMessage: takeFirstBytes(orig.ResourceMetadata.RawMessage, []byte("{}")),
123+
Valid: takeFirst(orig.ResourceMetadata.Valid, false),
124+
},
125+
ResourceMetadata: pqtype.NullRawMessage{
126+
RawMessage: takeFirstBytes(orig.ResourceMetadata.RawMessage, []byte("{}")),
127+
Valid: takeFirst(orig.ResourceMetadata.Valid, false),
128+
},
129+
ConnectionTimeoutSeconds: takeFirst(orig.ConnectionTimeoutSeconds, 3600),
130+
TroubleshootingURL: takeFirst(orig.TroubleshootingURL, "https://example.com"),
131+
MOTDFile: takeFirst(orig.TroubleshootingURL, ""),
132+
LoginBeforeReady: takeFirst(orig.LoginBeforeReady, false),
133+
StartupScriptTimeoutSeconds: takeFirst(orig.StartupScriptTimeoutSeconds, 3600),
134+
})
135+
require.NoError(t, err, "insert workspace agent")
136+
return workspace
137+
}
138+
69139
func Workspace(t *testing.T, db database.Store, orig database.Workspace) database.Workspace {
70140
workspace, err := db.InsertWorkspace(context.Background(), database.InsertWorkspaceParams{
71141
ID: takeFirst(orig.ID, uuid.New()),
@@ -89,7 +159,7 @@ func WorkspaceBuild(t *testing.T, db database.Store, orig database.WorkspaceBuil
89159
UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()),
90160
WorkspaceID: takeFirst(orig.WorkspaceID, uuid.New()),
91161
TemplateVersionID: takeFirst(orig.TemplateVersionID, uuid.New()),
92-
BuildNumber: takeFirst(orig.BuildNumber, 0),
162+
BuildNumber: takeFirst(orig.BuildNumber, 1),
93163
Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart),
94164
InitiatorID: takeFirst(orig.InitiatorID, uuid.New()),
95165
JobID: takeFirst(orig.JobID, uuid.New()),
@@ -140,6 +210,20 @@ func Group(t *testing.T, db database.Store, orig database.Group) database.Group
140210
return group
141211
}
142212

213+
func GroupMember(t *testing.T, db database.Store, orig database.GroupMember) database.GroupMember {
214+
member := database.GroupMember{
215+
UserID: takeFirst(orig.UserID, uuid.New()),
216+
GroupID: takeFirst(orig.GroupID, uuid.New()),
217+
}
218+
//nolint:gosimple
219+
err := db.InsertGroupMember(context.Background(), database.InsertGroupMemberParams{
220+
UserID: member.UserID,
221+
GroupID: member.GroupID,
222+
})
223+
require.NoError(t, err, "insert group member")
224+
return member
225+
}
226+
143227
func ProvisionerJob(t *testing.T, db database.Store, orig database.ProvisionerJob) database.ProvisionerJob {
144228
job, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{
145229
ID: takeFirst(orig.ID, uuid.New()),

coderd/database/dbgen/generator_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ import (
1414
func TestGenerator(t *testing.T) {
1515
t.Parallel()
1616

17+
t.Run("AuditLog", func(t *testing.T) {
18+
t.Parallel()
19+
db := databasefake.New()
20+
_ = dbgen.AuditLog(t, db, database.AuditLog{})
21+
logs := must(db.GetAuditLogsOffset(context.Background(), database.GetAuditLogsOffsetParams{Limit: 1}))
22+
require.Len(t, logs, 1)
23+
})
24+
1725
t.Run("APIKey", func(t *testing.T) {
1826
t.Parallel()
1927
db := databasefake.New()
@@ -56,6 +64,17 @@ func TestGenerator(t *testing.T) {
5664
require.Equal(t, exp, must(db.GetGroupByID(context.Background(), exp.ID)))
5765
})
5866

67+
t.Run("GroupMember", func(t *testing.T) {
68+
t.Parallel()
69+
db := databasefake.New()
70+
g := dbgen.Group(t, db, database.Group{})
71+
u := dbgen.User(t, db, database.User{})
72+
exp := []database.User{u}
73+
dbgen.GroupMember(t, db, database.GroupMember{GroupID: g.ID, UserID: u.ID})
74+
75+
require.Equal(t, exp, must(db.GetGroupMembers(context.Background(), g.ID)))
76+
})
77+
5978
t.Run("Organization", func(t *testing.T) {
6079
t.Parallel()
6180
db := databasefake.New()
@@ -70,6 +89,13 @@ func TestGenerator(t *testing.T) {
7089
require.Equal(t, exp, must(db.GetWorkspaceByID(context.Background(), exp.ID)))
7190
})
7291

92+
t.Run("WorkspaceAgent", func(t *testing.T) {
93+
t.Parallel()
94+
db := databasefake.New()
95+
exp := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{})
96+
require.Equal(t, exp, must(db.GetWorkspaceAgentByID(context.Background(), exp.ID)))
97+
})
98+
7399
t.Run("Template", func(t *testing.T) {
74100
t.Parallel()
75101
db := databasefake.New()

coderd/database/dbgen/take.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
package dbgen
22

3+
import "net"
4+
5+
func takeFirstIP(values ...net.IPNet) net.IPNet {
6+
return takeFirstF(values, func(v net.IPNet) bool {
7+
return len(v.IP) != 0 && len(v.Mask) != 0
8+
})
9+
}
10+
311
// takeFirstBytes implements takeFirst for []byte.
412
// []byte is not a comparable type.
513
func takeFirstBytes(values ...[]byte) []byte {

0 commit comments

Comments
 (0)