Skip to content

Commit 4426bcf

Browse files
committed
improve and reduce boilerplate in tests
1 parent 14a1740 commit 4426bcf

File tree

1 file changed

+62
-153
lines changed

1 file changed

+62
-153
lines changed

coderd/workspaceapps/db_test.go

Lines changed: 62 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"crypto/rand"
66
"database/sql"
7+
"encoding/json"
78
"fmt"
89
"io"
910
"net"
@@ -246,6 +247,9 @@ func Test_ResolveRequest(t *testing.T) {
246247
// Reset audit logs so cleanup check can pass.
247248
auditor.ResetLogs()
248249

250+
assertAuditAgent := auditAsserter[database.WorkspaceAgent](workspace)
251+
assertAuditApp := auditAsserter[database.WorkspaceApp](workspace)
252+
249253
t.Run("OK", func(t *testing.T) {
250254
t.Parallel()
251255

@@ -332,18 +336,8 @@ func Test_ResolveRequest(t *testing.T) {
332336
require.Equal(t, codersdk.SignedAppTokenCookie, cookie.Name)
333337
require.Equal(t, req.BasePath, cookie.Path)
334338

335-
require.True(t, auditor.Contains(t, database.AuditLog{
336-
OrganizationID: workspace.OrganizationID,
337-
Action: database.AuditActionOpen,
338-
ResourceType: audit.ResourceType(appsBySlug[app]),
339-
ResourceID: audit.ResourceID(appsBySlug[app]),
340-
ResourceTarget: audit.ResourceTarget(appsBySlug[app]),
341-
UserID: me.ID,
342-
UserAgent: sql.NullString{Valid: true, String: auditableUA},
343-
Ip: audit.ParseIP(auditableIP),
344-
StatusCode: int32(w.StatusCode), //nolint:gosec
345-
}), "audit log")
346-
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
339+
assertAuditApp(t, rw, r, auditor, appsBySlug[app], me.ID, nil)
340+
require.Len(t, auditor.AuditLogs(), 1, "audit log count")
347341

348342
var parsedToken workspaceapps.SignedToken
349343
err := jwtutils.Verify(ctx, api.AppSigningKeyCache, cookie.Value, &parsedToken)
@@ -421,16 +415,7 @@ func Test_ResolveRequest(t *testing.T) {
421415
require.NotNil(t, token)
422416
require.Zero(t, w.StatusCode)
423417

424-
require.True(t, auditor.Contains(t, database.AuditLog{
425-
OrganizationID: workspace.OrganizationID,
426-
Action: database.AuditActionOpen,
427-
ResourceType: audit.ResourceType(appsBySlug[app]),
428-
ResourceID: audit.ResourceID(appsBySlug[app]),
429-
ResourceTarget: audit.ResourceTarget(appsBySlug[app]),
430-
UserID: secondUser.ID,
431-
Ip: audit.ParseIP(auditableIP),
432-
StatusCode: int32(w.StatusCode), //nolint:gosec
433-
}), "audit log")
418+
assertAuditApp(t, rw, r, auditor, appsBySlug[app], secondUser.ID, nil)
434419
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
435420
}
436421
})
@@ -483,17 +468,8 @@ func Test_ResolveRequest(t *testing.T) {
483468
t.Fatalf("expected 200 (or unset) response code, got %d", rw.Code)
484469
}
485470

486-
require.True(t, auditor.Contains(t, database.AuditLog{
487-
OrganizationID: workspace.OrganizationID,
488-
ResourceType: audit.ResourceType(appsBySlug[app]),
489-
ResourceID: audit.ResourceID(appsBySlug[app]),
490-
ResourceTarget: audit.ResourceTarget(appsBySlug[app]),
491-
UserID: uuid.Nil, // Nil is not verified by Contains, see below.
492-
Ip: audit.ParseIP(auditableIP),
493-
StatusCode: int32(w.StatusCode), //nolint:gosec
494-
}), "audit log")
471+
assertAuditApp(t, rw, r, auditor, appsBySlug[app], uuid.Nil, nil)
495472
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
496-
require.Equal(t, uuid.Nil, auditor.AuditLogs()[0].UserID, "no user ID in audit log")
497473
}
498474
_ = w.Body.Close()
499475
}
@@ -617,15 +593,7 @@ func Test_ResolveRequest(t *testing.T) {
617593
require.Equal(t, token.AgentNameOrID, c.agent)
618594
require.Equal(t, token.WorkspaceID, workspace.ID)
619595
require.Equal(t, token.AgentID, agentID)
620-
require.True(t, auditor.Contains(t, database.AuditLog{
621-
OrganizationID: workspace.OrganizationID,
622-
ResourceType: audit.ResourceType(appsBySlug[token.AppSlugOrPort]),
623-
ResourceID: audit.ResourceID(appsBySlug[token.AppSlugOrPort]),
624-
ResourceTarget: audit.ResourceTarget(appsBySlug[token.AppSlugOrPort]),
625-
UserID: me.ID,
626-
Ip: audit.ParseIP(auditableIP),
627-
StatusCode: int32(w.StatusCode), //nolint:gosec
628-
}), "audit log")
596+
assertAuditApp(t, rw, r, auditor, appsBySlug[token.AppSlugOrPort], me.ID, nil)
629597
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
630598
} else {
631599
require.Nil(t, token)
@@ -710,15 +678,7 @@ func Test_ResolveRequest(t *testing.T) {
710678
require.NoError(t, err)
711679
require.Equal(t, appNameOwner, parsedToken.AppSlugOrPort)
712680

713-
require.True(t, auditor.Contains(t, database.AuditLog{
714-
OrganizationID: workspace.OrganizationID,
715-
ResourceType: audit.ResourceType(appsBySlug[token.AppSlugOrPort]),
716-
ResourceID: audit.ResourceID(appsBySlug[token.AppSlugOrPort]),
717-
ResourceTarget: audit.ResourceTarget(appsBySlug[token.AppSlugOrPort]),
718-
UserID: me.ID,
719-
Ip: audit.ParseIP(auditableIP),
720-
StatusCode: int32(w.StatusCode), //nolint:gosec
721-
}), "audit log")
681+
assertAuditApp(t, rw, r, auditor, appsBySlug[appNameOwner], me.ID, nil)
722682
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
723683
})
724684

@@ -792,18 +752,9 @@ func Test_ResolveRequest(t *testing.T) {
792752
require.Equal(t, req.AppSlugOrPort, token.AppSlugOrPort)
793753
require.Equal(t, "http://127.0.0.1:9090", token.AppURL)
794754

795-
w := rw.Result()
796-
_ = w.Body.Close()
797-
require.Equal(t, http.StatusOK, w.StatusCode)
798-
require.True(t, auditor.Contains(t, database.AuditLog{
799-
OrganizationID: workspace.OrganizationID,
800-
ResourceType: audit.ResourceType(agent),
801-
ResourceID: audit.ResourceID(agent),
802-
ResourceTarget: audit.ResourceTarget(agent),
803-
UserID: me.ID,
804-
Ip: audit.ParseIP(auditableIP),
805-
StatusCode: int32(w.StatusCode), //nolint:gosec
806-
}), "audit log for agent, not app")
755+
assertAuditAgent(t, rw, r, auditor, agent, me.ID, map[string]any{
756+
"slug_or_port": "9090",
757+
})
807758
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
808759
})
809760

@@ -876,17 +827,7 @@ func Test_ResolveRequest(t *testing.T) {
876827
})
877828
require.True(t, ok)
878829
require.Equal(t, req.AppSlugOrPort, token.AppSlugOrPort)
879-
w := rw.Result()
880-
_ = w.Body.Close()
881-
require.True(t, auditor.Contains(t, database.AuditLog{
882-
OrganizationID: workspace.OrganizationID,
883-
ResourceType: audit.ResourceType(appsBySlug[token.AppSlugOrPort]),
884-
ResourceID: audit.ResourceID(appsBySlug[token.AppSlugOrPort]),
885-
ResourceTarget: audit.ResourceTarget(appsBySlug[token.AppSlugOrPort]),
886-
UserID: me.ID,
887-
Ip: audit.ParseIP(auditableIP),
888-
StatusCode: int32(w.StatusCode), //nolint:gosec
889-
}), "audit log")
830+
assertAuditApp(t, rw, r, auditor, appsBySlug[appNameEndsInS], me.ID, nil)
890831
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
891832
})
892833

@@ -923,17 +864,9 @@ func Test_ResolveRequest(t *testing.T) {
923864
require.Equal(t, req.AgentNameOrID, token.Request.AgentNameOrID)
924865
require.Empty(t, token.AppSlugOrPort)
925866
require.Empty(t, token.AppURL)
926-
w := rw.Result()
927-
_ = w.Body.Close()
928-
require.True(t, auditor.Contains(t, database.AuditLog{
929-
OrganizationID: workspace.OrganizationID,
930-
ResourceType: audit.ResourceType(agent),
931-
ResourceID: audit.ResourceID(agent),
932-
ResourceTarget: audit.ResourceTarget(agent),
933-
UserID: me.ID,
934-
Ip: audit.ParseIP(auditableIP),
935-
StatusCode: int32(w.StatusCode), //nolint:gosec
936-
}), "audit log for agent, not app")
867+
assertAuditAgent(t, rw, r, auditor, agent, me.ID, map[string]any{
868+
"slug_or_port": "terminal",
869+
})
937870
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
938871
})
939872

@@ -967,15 +900,7 @@ func Test_ResolveRequest(t *testing.T) {
967900
})
968901
require.False(t, ok)
969902
require.Nil(t, token)
970-
w := rw.Result()
971-
_ = w.Body.Close()
972-
require.Equal(t, http.StatusNotFound, w.StatusCode)
973-
require.True(t, auditor.Contains(t, database.AuditLog{
974-
OrganizationID: workspace.OrganizationID,
975-
UserID: secondUser.ID,
976-
Ip: audit.ParseIP(auditableIP),
977-
StatusCode: int32(w.StatusCode), //nolint:gosec
978-
}), "audit log insufficient permissions")
903+
assertAuditApp(t, rw, r, auditor, appsBySlug[appNameOwner], secondUser.ID, nil)
979904
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
980905
})
981906

@@ -1108,12 +1033,7 @@ func Test_ResolveRequest(t *testing.T) {
11081033
w := rw.Result()
11091034
defer w.Body.Close()
11101035
require.Equal(t, http.StatusBadGateway, w.StatusCode)
1111-
require.True(t, auditor.Contains(t, database.AuditLog{
1112-
OrganizationID: workspace.OrganizationID,
1113-
UserID: me.ID,
1114-
Ip: audit.ParseIP(auditableIP),
1115-
StatusCode: int32(w.StatusCode), //nolint:gosec
1116-
}), "audit log unhealthy agent")
1036+
assertAuditApp(t, rw, r, auditor, appsBySlug[appNameAgentUnhealthy], me.ID, nil)
11171037
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
11181038

11191039
body, err := io.ReadAll(w.Body)
@@ -1172,14 +1092,7 @@ func Test_ResolveRequest(t *testing.T) {
11721092
})
11731093
require.True(t, ok, "ResolveRequest failed, should pass even though app is initializing")
11741094
require.NotNil(t, token)
1175-
w := rw.Result()
1176-
_ = w.Body.Close()
1177-
require.True(t, auditor.Contains(t, database.AuditLog{
1178-
OrganizationID: workspace.OrganizationID,
1179-
UserID: me.ID,
1180-
Ip: audit.ParseIP(auditableIP),
1181-
StatusCode: int32(w.StatusCode), //nolint:gosec
1182-
}), "audit log initializing app")
1095+
assertAuditApp(t, rw, r, auditor, appsBySlug[token.AppSlugOrPort], me.ID, nil)
11831096
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
11841097
})
11851098

@@ -1237,14 +1150,7 @@ func Test_ResolveRequest(t *testing.T) {
12371150
})
12381151
require.True(t, ok, "ResolveRequest failed, should pass even though app is unhealthy")
12391152
require.NotNil(t, token)
1240-
w := rw.Result()
1241-
_ = w.Body.Close()
1242-
require.True(t, auditor.Contains(t, database.AuditLog{
1243-
OrganizationID: workspace.OrganizationID,
1244-
UserID: me.ID,
1245-
Ip: audit.ParseIP(auditableIP),
1246-
StatusCode: int32(w.StatusCode), //nolint:gosec
1247-
}), "audit log unhealthy app")
1153+
assertAuditApp(t, rw, r, auditor, appsBySlug[token.AppSlugOrPort], me.ID, nil)
12481154
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
12491155
})
12501156

@@ -1281,18 +1187,7 @@ func Test_ResolveRequest(t *testing.T) {
12811187
AppRequest: req,
12821188
})
12831189
require.True(t, ok)
1284-
w := rw.Result()
1285-
_ = w.Body.Close()
1286-
require.True(t, auditor.Contains(t, database.AuditLog{
1287-
OrganizationID: workspace.OrganizationID,
1288-
Action: database.AuditActionOpen,
1289-
ResourceType: audit.ResourceType(appsBySlug[app]),
1290-
ResourceID: audit.ResourceID(appsBySlug[app]),
1291-
ResourceTarget: audit.ResourceTarget(appsBySlug[app]),
1292-
UserID: me.ID,
1293-
Ip: audit.ParseIP(auditableIP),
1294-
StatusCode: int32(w.StatusCode), //nolint:gosec
1295-
}), "audit log 1")
1190+
assertAuditApp(t, rw, r, auditor, appsBySlug[app], me.ID, nil)
12961191
require.Len(t, auditor.AuditLogs(), 1, "single audit log")
12971192

12981193
// Second request, no audit log because the session is active.
@@ -1310,8 +1205,6 @@ func Test_ResolveRequest(t *testing.T) {
13101205
AppRequest: req,
13111206
})
13121207
require.True(t, ok)
1313-
w = rw.Result()
1314-
_ = w.Body.Close()
13151208
require.Len(t, auditor.AuditLogs(), 1, "single audit log, previous session active")
13161209

13171210
// Third request, session timed out, new audit log.
@@ -1330,18 +1223,7 @@ func Test_ResolveRequest(t *testing.T) {
13301223
AppRequest: req,
13311224
})
13321225
require.True(t, ok)
1333-
w = rw.Result()
1334-
_ = w.Body.Close()
1335-
require.True(t, auditor.Contains(t, database.AuditLog{
1336-
OrganizationID: workspace.OrganizationID,
1337-
Action: database.AuditActionOpen,
1338-
ResourceType: audit.ResourceType(appsBySlug[app]),
1339-
ResourceID: audit.ResourceID(appsBySlug[app]),
1340-
ResourceTarget: audit.ResourceTarget(appsBySlug[app]),
1341-
UserID: me.ID,
1342-
Ip: audit.ParseIP(auditableIP),
1343-
StatusCode: int32(w.StatusCode), //nolint:gosec
1344-
}), "audit log 2")
1226+
assertAuditApp(t, rw, r, auditor, appsBySlug[app], me.ID, nil)
13451227
require.Len(t, auditor.AuditLogs(), 2, "two audit logs, session timed out")
13461228

13471229
// Fourth request, new IP produces new audit log.
@@ -1360,18 +1242,7 @@ func Test_ResolveRequest(t *testing.T) {
13601242
AppRequest: req,
13611243
})
13621244
require.True(t, ok)
1363-
w = rw.Result()
1364-
_ = w.Body.Close()
1365-
require.True(t, auditor.Contains(t, database.AuditLog{
1366-
OrganizationID: workspace.OrganizationID,
1367-
Action: database.AuditActionOpen,
1368-
ResourceType: audit.ResourceType(appsBySlug[app]),
1369-
ResourceID: audit.ResourceID(appsBySlug[app]),
1370-
ResourceTarget: audit.ResourceTarget(appsBySlug[app]),
1371-
UserID: me.ID,
1372-
Ip: audit.ParseIP(auditableIP),
1373-
StatusCode: int32(w.StatusCode), //nolint:gosec
1374-
}), "audit log 3")
1245+
assertAuditApp(t, rw, r, auditor, appsBySlug[app], me.ID, nil)
13751246
require.Len(t, auditor.AuditLogs(), 3, "three audit logs, new IP")
13761247
}
13771248
})
@@ -1413,3 +1284,41 @@ func signedTokenProviderWithAuditor(t testing.TB, provider workspaceapps.SignedT
14131284
shallowCopy.WorkspaceAppAuditSessionTimeout = sessionTimeout
14141285
return &shallowCopy
14151286
}
1287+
1288+
func auditAsserter[T audit.Auditable](workspace codersdk.Workspace) func(t testing.TB, rr *httptest.ResponseRecorder, r *http.Request, auditor *audit.MockAuditor, auditable T, userID uuid.UUID, additionalFields map[string]any) {
1289+
return func(t testing.TB, rr *httptest.ResponseRecorder, r *http.Request, auditor *audit.MockAuditor, auditable T, userID uuid.UUID, additionalFields map[string]any) {
1290+
t.Helper()
1291+
1292+
resp := rr.Result()
1293+
defer resp.Body.Close()
1294+
1295+
require.True(t, auditor.Contains(t, database.AuditLog{
1296+
OrganizationID: workspace.OrganizationID,
1297+
Action: database.AuditActionOpen,
1298+
ResourceType: audit.ResourceType(auditable),
1299+
ResourceID: audit.ResourceID(auditable),
1300+
ResourceTarget: audit.ResourceTarget(auditable),
1301+
UserID: userID,
1302+
Ip: audit.ParseIP(r.RemoteAddr),
1303+
UserAgent: sql.NullString{Valid: r.UserAgent() != "", String: r.UserAgent()},
1304+
StatusCode: int32(resp.StatusCode), //nolint:gosec
1305+
}), "audit log")
1306+
1307+
// Verify additional fields, assume the last log entry.
1308+
alog := auditor.AuditLogs()[len(auditor.AuditLogs())-1]
1309+
1310+
// Contains does not verify uuid.Nil.
1311+
if userID == uuid.Nil {
1312+
require.Equal(t, uuid.Nil, alog.UserID, "unauthenticated user")
1313+
}
1314+
1315+
add := make(map[string]any)
1316+
if len(alog.AdditionalFields) > 0 {
1317+
err := json.Unmarshal([]byte(alog.AdditionalFields), &add)
1318+
require.NoError(t, err, "audit log unmarhsal additional fields")
1319+
}
1320+
for k, v := range additionalFields {
1321+
require.Equal(t, v, add[k], "audit log additional field %s: additional fields: %v", k, add)
1322+
}
1323+
}
1324+
}

0 commit comments

Comments
 (0)