Skip to content

Commit fb953e4

Browse files
johnstcnkylecarbs
andcommitted
feat(coderd): add dbcrypt package
- Adds package enterprise/dbcrypt to implement database encryption/decryption - Adds table dbcrypt_keys and associated queries - Adds columns oauth_access_token_key_id and oauth_refresh_token_key_id to tables git_auth_links and user_links NOTE: This is part 1 of a 2-part PR. This PR focuses mainly on the dbcrypt and database packages. A separate PR will add the required plumbing to integrate this into enterprise/coderd properly. Co-authored-by: Kyle Carberry <kyle@coder.com>
1 parent 76ab22f commit fb953e4

22 files changed

+1811
-72
lines changed

coderd/database/dbauthz/dbauthz.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,13 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI
838838
return q.db.GetAuthorizationUserRoles(ctx, userID)
839839
}
840840

841+
func (q *querier) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) {
842+
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
843+
return nil, err
844+
}
845+
return q.db.GetDBCryptKeys(ctx)
846+
}
847+
841848
func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) {
842849
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
843850
return "", err
@@ -914,6 +921,13 @@ func (q *querier) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLin
914921
return fetch(q.log, q.auth, q.db.GetGitAuthLink)(ctx, arg)
915922
}
916923

924+
func (q *querier) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) {
925+
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
926+
return nil, err
927+
}
928+
return q.db.GetGitAuthLinksByUserID(ctx, userID)
929+
}
930+
917931
func (q *querier) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
918932
return fetch(q.log, q.auth, q.db.GetGitSSHKey)(ctx, userID)
919933
}
@@ -1482,6 +1496,13 @@ func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database
14821496
return q.db.GetUserLinkByUserIDLoginType(ctx, arg)
14831497
}
14841498

1499+
func (q *querier) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) {
1500+
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
1501+
return nil, err
1502+
}
1503+
return q.db.GetUserLinksByUserID(ctx, userID)
1504+
}
1505+
14851506
func (q *querier) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) {
14861507
// This does the filtering in SQL.
14871508
prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type)
@@ -1845,6 +1866,13 @@ func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLo
18451866
return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg)
18461867
}
18471868

1869+
func (q *querier) InsertDBCryptKey(ctx context.Context, arg database.InsertDBCryptKeyParams) error {
1870+
if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil {
1871+
return err
1872+
}
1873+
return q.db.InsertDBCryptKey(ctx, arg)
1874+
}
1875+
18481876
func (q *querier) InsertDERPMeshKey(ctx context.Context, value string) error {
18491877
if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil {
18501878
return err
@@ -2144,6 +2172,13 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis
21442172
return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg)
21452173
}
21462174

2175+
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
2176+
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil {
2177+
return err
2178+
}
2179+
return q.db.RevokeDBCryptKey(ctx, activeKeyDigest)
2180+
}
2181+
21472182
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
21482183
return q.db.TryAcquireLock(ctx, id)
21492184
}

coderd/database/dbfake/dbfake.go

Lines changed: 138 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ import (
3131

3232
var validProxyByHostnameRegex = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)
3333

34+
var errForeignKeyConstraint = &pq.Error{
35+
Code: "23503",
36+
Message: "update or delete on table violates foreign key constraint",
37+
}
38+
3439
var errDuplicateKey = &pq.Error{
3540
Code: "23505",
3641
Message: "duplicate key value violates unique constraint",
@@ -45,6 +50,7 @@ func New() database.Store {
4550
organizationMembers: make([]database.OrganizationMember, 0),
4651
organizations: make([]database.Organization, 0),
4752
users: make([]database.User, 0),
53+
dbcryptKeys: make([]database.DBCryptKey, 0),
4854
gitAuthLinks: make([]database.GitAuthLink, 0),
4955
groups: make([]database.Group, 0),
5056
groupMembers: make([]database.GroupMember, 0),
@@ -117,6 +123,7 @@ type data struct {
117123
// New tables
118124
workspaceAgentStats []database.WorkspaceAgentStat
119125
auditLogs []database.AuditLog
126+
dbcryptKeys []database.DBCryptKey
120127
files []database.File
121128
gitAuthLinks []database.GitAuthLink
122129
gitSSHKey []database.GitSSHKey
@@ -665,6 +672,39 @@ func (q *FakeQuerier) isEveryoneGroup(id uuid.UUID) bool {
665672
return false
666673
}
667674

675+
func (q *FakeQuerier) insertDBCryptKeyNoLock(_ context.Context, arg database.InsertDBCryptKeyParams) error {
676+
err := validateDatabaseType(arg)
677+
if err != nil {
678+
return err
679+
}
680+
681+
for _, key := range q.dbcryptKeys {
682+
if key.Number == arg.Number {
683+
return errDuplicateKey
684+
}
685+
}
686+
687+
q.dbcryptKeys = append(q.dbcryptKeys, database.DBCryptKey{
688+
Number: arg.Number,
689+
ActiveKeyDigest: sql.NullString{String: arg.ActiveKeyDigest, Valid: true},
690+
Test: arg.Test,
691+
})
692+
return nil
693+
}
694+
695+
func (q *FakeQuerier) GetActiveDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) {
696+
q.mutex.RLock()
697+
defer q.mutex.RUnlock()
698+
ks := make([]database.DBCryptKey, 0, len(q.dbcryptKeys))
699+
for _, k := range q.dbcryptKeys {
700+
if !k.ActiveKeyDigest.Valid {
701+
continue
702+
}
703+
ks = append([]database.DBCryptKey{}, k)
704+
}
705+
return ks, nil
706+
}
707+
668708
func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error {
669709
return xerrors.New("AcquireLock must only be called within a transaction")
670710
}
@@ -1151,6 +1191,14 @@ func (q *FakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
11511191
}, nil
11521192
}
11531193

1194+
func (q *FakeQuerier) GetDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) {
1195+
q.mutex.RLock()
1196+
defer q.mutex.RUnlock()
1197+
ks := make([]database.DBCryptKey, 0)
1198+
ks = append(ks, q.dbcryptKeys...)
1199+
return ks, nil
1200+
}
1201+
11541202
func (q *FakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) {
11551203
q.mutex.RLock()
11561204
defer q.mutex.RUnlock()
@@ -1393,6 +1441,18 @@ func (q *FakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthL
13931441
return database.GitAuthLink{}, sql.ErrNoRows
13941442
}
13951443

1444+
func (q *FakeQuerier) GetGitAuthLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) {
1445+
q.mutex.RLock()
1446+
defer q.mutex.RUnlock()
1447+
gals := make([]database.GitAuthLink, 0)
1448+
for _, gal := range q.gitAuthLinks {
1449+
if gal.UserID == userID {
1450+
gals = append(gals, gal)
1451+
}
1452+
}
1453+
return gals, nil
1454+
}
1455+
13961456
func (q *FakeQuerier) GetGitSSHKey(_ context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
13971457
q.mutex.RLock()
13981458
defer q.mutex.RUnlock()
@@ -2833,6 +2893,18 @@ func (q *FakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params dat
28332893
return database.UserLink{}, sql.ErrNoRows
28342894
}
28352895

2896+
func (q *FakeQuerier) GetUserLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.UserLink, error) {
2897+
q.mutex.RLock()
2898+
defer q.mutex.RUnlock()
2899+
uls := make([]database.UserLink, 0)
2900+
for _, ul := range q.userLinks {
2901+
if ul.UserID == userID {
2902+
uls = append(uls, ul)
2903+
}
2904+
}
2905+
return uls, nil
2906+
}
2907+
28362908
func (q *FakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.GetUsersRow, error) {
28372909
if err := validateDatabaseType(params); err != nil {
28382910
return nil, err
@@ -3846,6 +3918,11 @@ func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAudit
38463918
return alog, nil
38473919
}
38483920

3921+
func (q *FakeQuerier) InsertDBCryptKey(ctx context.Context, arg database.InsertDBCryptKeyParams) error {
3922+
// This only ever gets called inside a transaction, so we need to not lock.
3923+
return q.insertDBCryptKeyNoLock(ctx, arg)
3924+
}
3925+
38493926
func (q *FakeQuerier) InsertDERPMeshKey(_ context.Context, id string) error {
38503927
q.mutex.Lock()
38513928
defer q.mutex.Unlock()
@@ -3892,13 +3969,15 @@ func (q *FakeQuerier) InsertGitAuthLink(_ context.Context, arg database.InsertGi
38923969
defer q.mutex.Unlock()
38933970
// nolint:gosimple
38943971
gitAuthLink := database.GitAuthLink{
3895-
ProviderID: arg.ProviderID,
3896-
UserID: arg.UserID,
3897-
CreatedAt: arg.CreatedAt,
3898-
UpdatedAt: arg.UpdatedAt,
3899-
OAuthAccessToken: arg.OAuthAccessToken,
3900-
OAuthRefreshToken: arg.OAuthRefreshToken,
3901-
OAuthExpiry: arg.OAuthExpiry,
3972+
ProviderID: arg.ProviderID,
3973+
UserID: arg.UserID,
3974+
CreatedAt: arg.CreatedAt,
3975+
UpdatedAt: arg.UpdatedAt,
3976+
OAuthAccessToken: arg.OAuthAccessToken,
3977+
OAuthAccessTokenKeyID: arg.OAuthAccessTokenKeyID,
3978+
OAuthRefreshToken: arg.OAuthRefreshToken,
3979+
OAuthRefreshTokenKeyID: arg.OAuthRefreshTokenKeyID,
3980+
OAuthExpiry: arg.OAuthExpiry,
39023981
}
39033982
q.gitAuthLinks = append(q.gitAuthLinks, gitAuthLink)
39043983
return gitAuthLink, nil
@@ -4362,12 +4441,14 @@ func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUser
43624441

43634442
//nolint:gosimple
43644443
link := database.UserLink{
4365-
UserID: args.UserID,
4366-
LoginType: args.LoginType,
4367-
LinkedID: args.LinkedID,
4368-
OAuthAccessToken: args.OAuthAccessToken,
4369-
OAuthRefreshToken: args.OAuthRefreshToken,
4370-
OAuthExpiry: args.OAuthExpiry,
4444+
UserID: args.UserID,
4445+
LoginType: args.LoginType,
4446+
LinkedID: args.LinkedID,
4447+
OAuthAccessToken: args.OAuthAccessToken,
4448+
OAuthAccessTokenKeyID: args.OAuthAccessTokenKeyID,
4449+
OAuthRefreshToken: args.OAuthRefreshToken,
4450+
OAuthRefreshTokenKeyID: args.OAuthRefreshTokenKeyID,
4451+
OAuthExpiry: args.OAuthExpiry,
43714452
}
43724453

43734454
q.userLinks = append(q.userLinks, link)
@@ -4793,6 +4874,46 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg
47934874
return database.WorkspaceProxy{}, sql.ErrNoRows
47944875
}
47954876

4877+
func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error {
4878+
q.mutex.Lock()
4879+
defer q.mutex.Unlock()
4880+
4881+
for i := range q.dbcryptKeys {
4882+
key := q.dbcryptKeys[i]
4883+
4884+
// Is the key already revoked?
4885+
if !key.ActiveKeyDigest.Valid {
4886+
continue
4887+
}
4888+
4889+
if key.ActiveKeyDigest.String != activeKeyDigest {
4890+
continue
4891+
}
4892+
4893+
// Check for foreign key constraints.
4894+
for _, ul := range q.userLinks {
4895+
if (ul.OAuthAccessTokenKeyID.Valid && ul.OAuthAccessTokenKeyID.String == activeKeyDigest) ||
4896+
(ul.OAuthRefreshTokenKeyID.Valid && ul.OAuthRefreshTokenKeyID.String == activeKeyDigest) {
4897+
return errForeignKeyConstraint
4898+
}
4899+
}
4900+
for _, gal := range q.gitAuthLinks {
4901+
if (gal.OAuthAccessTokenKeyID.Valid && gal.OAuthAccessTokenKeyID.String == activeKeyDigest) ||
4902+
(gal.OAuthRefreshTokenKeyID.Valid && gal.OAuthRefreshTokenKeyID.String == activeKeyDigest) {
4903+
return errForeignKeyConstraint
4904+
}
4905+
}
4906+
4907+
// Revoke the key.
4908+
q.dbcryptKeys[i].RevokedAt = sql.NullTime{Time: dbtime.Now(), Valid: true}
4909+
q.dbcryptKeys[i].RevokedKeyDigest = sql.NullString{String: key.ActiveKeyDigest.String, Valid: true}
4910+
q.dbcryptKeys[i].ActiveKeyDigest = sql.NullString{}
4911+
return nil
4912+
}
4913+
4914+
return sql.ErrNoRows
4915+
}
4916+
47964917
func (*FakeQuerier) TryAcquireLock(_ context.Context, _ int64) (bool, error) {
47974918
return false, xerrors.New("TryAcquireLock must only be called within a transaction")
47984919
}
@@ -4834,7 +4955,9 @@ func (q *FakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGi
48344955
}
48354956
gitAuthLink.UpdatedAt = arg.UpdatedAt
48364957
gitAuthLink.OAuthAccessToken = arg.OAuthAccessToken
4958+
gitAuthLink.OAuthAccessTokenKeyID = arg.OAuthAccessTokenKeyID
48374959
gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken
4960+
gitAuthLink.OAuthRefreshTokenKeyID = arg.OAuthRefreshTokenKeyID
48384961
gitAuthLink.OAuthExpiry = arg.OAuthExpiry
48394962
q.gitAuthLinks[index] = gitAuthLink
48404963

@@ -5306,7 +5429,9 @@ func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
53065429
for i, link := range q.userLinks {
53075430
if link.UserID == params.UserID && link.LoginType == params.LoginType {
53085431
link.OAuthAccessToken = params.OAuthAccessToken
5432+
link.OAuthAccessTokenKeyID = params.OAuthAccessTokenKeyID
53095433
link.OAuthRefreshToken = params.OAuthRefreshToken
5434+
link.OAuthRefreshTokenKeyID = params.OAuthRefreshTokenKeyID
53105435
link.OAuthExpiry = params.OAuthExpiry
53115436

53125437
q.userLinks[i] = link

coderd/database/dbgen/dbgen.go

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -470,12 +470,14 @@ func File(t testing.TB, db database.Store, orig database.File) database.File {
470470

471471
func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.UserLink {
472472
link, err := db.InsertUserLink(genCtx, database.InsertUserLinkParams{
473-
UserID: takeFirst(orig.UserID, uuid.New()),
474-
LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub),
475-
LinkedID: takeFirst(orig.LinkedID),
476-
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
477-
OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
478-
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
473+
UserID: takeFirst(orig.UserID, uuid.New()),
474+
LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub),
475+
LinkedID: takeFirst(orig.LinkedID),
476+
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
477+
OAuthAccessTokenKeyID: takeFirst(orig.OAuthAccessTokenKeyID, sql.NullString{}),
478+
OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()),
479+
OAuthRefreshTokenKeyID: takeFirst(orig.OAuthRefreshTokenKeyID, sql.NullString{}),
480+
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
479481
})
480482

481483
require.NoError(t, err, "insert link")
@@ -484,13 +486,15 @@ func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.
484486

485487
func GitAuthLink(t testing.TB, db database.Store, orig database.GitAuthLink) database.GitAuthLink {
486488
link, err := db.InsertGitAuthLink(genCtx, database.InsertGitAuthLinkParams{
487-
ProviderID: takeFirst(orig.ProviderID, uuid.New().String()),
488-
UserID: takeFirst(orig.UserID, uuid.New()),
489-
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
490-
OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
491-
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
492-
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
493-
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
489+
ProviderID: takeFirst(orig.ProviderID, uuid.New().String()),
490+
UserID: takeFirst(orig.UserID, uuid.New()),
491+
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
492+
OAuthAccessTokenKeyID: takeFirst(orig.OAuthAccessTokenKeyID, sql.NullString{}),
493+
OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()),
494+
OAuthRefreshTokenKeyID: takeFirst(orig.OAuthRefreshTokenKeyID, sql.NullString{}),
495+
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
496+
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
497+
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
494498
})
495499

496500
require.NoError(t, err, "insert git auth link")

0 commit comments

Comments
 (0)