Skip to content

feat(coderd): add dbcrypt package #9522

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Sep 6, 2023
35 changes: 35 additions & 0 deletions coderd/database/dbauthz/dbauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,13 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI
return q.db.GetAuthorizationUserRoles(ctx, userID)
}

func (q *querier) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) {
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.GetDBCryptKeys(ctx)
}

func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) {
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
return "", err
Expand Down Expand Up @@ -914,6 +921,13 @@ func (q *querier) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLin
return fetch(q.log, q.auth, q.db.GetGitAuthLink)(ctx, arg)
}

func (q *querier) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) {
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.GetGitAuthLinksByUserID(ctx, userID)
}

func (q *querier) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
return fetch(q.log, q.auth, q.db.GetGitSSHKey)(ctx, userID)
}
Expand Down Expand Up @@ -1482,6 +1496,13 @@ func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database
return q.db.GetUserLinkByUserIDLoginType(ctx, arg)
}

func (q *querier) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) {
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.GetUserLinksByUserID(ctx, userID)
}

func (q *querier) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) {
// This does the filtering in SQL.
prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type)
Expand Down Expand Up @@ -1845,6 +1866,13 @@ func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLo
return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg)
}

func (q *querier) InsertDBCryptKey(ctx context.Context, arg database.InsertDBCryptKeyParams) error {
if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil {
return err
}
return q.db.InsertDBCryptKey(ctx, arg)
}

func (q *querier) InsertDERPMeshKey(ctx context.Context, value string) error {
if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil {
return err
Expand Down Expand Up @@ -2144,6 +2172,13 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis
return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg)
}

func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
}
return q.db.RevokeDBCryptKey(ctx, activeKeyDigest)
}

func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
return q.db.TryAcquireLock(ctx, id)
}
Expand Down
146 changes: 133 additions & 13 deletions coderd/database/dbfake/dbfake.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ import (

var validProxyByHostnameRegex = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)

var errForeignKeyConstraint = &pq.Error{
Code: "23503",
Message: "update or delete on table violates foreign key constraint",
}

var errDuplicateKey = &pq.Error{
Code: "23505",
Message: "duplicate key value violates unique constraint",
Expand All @@ -45,6 +50,7 @@ func New() database.Store {
organizationMembers: make([]database.OrganizationMember, 0),
organizations: make([]database.Organization, 0),
users: make([]database.User, 0),
dbcryptKeys: make([]database.DBCryptKey, 0),
gitAuthLinks: make([]database.GitAuthLink, 0),
groups: make([]database.Group, 0),
groupMembers: make([]database.GroupMember, 0),
Expand Down Expand Up @@ -117,6 +123,7 @@ type data struct {
// New tables
workspaceAgentStats []database.WorkspaceAgentStat
auditLogs []database.AuditLog
dbcryptKeys []database.DBCryptKey
files []database.File
gitAuthLinks []database.GitAuthLink
gitSSHKey []database.GitSSHKey
Expand Down Expand Up @@ -665,6 +672,19 @@ func (q *FakeQuerier) isEveryoneGroup(id uuid.UUID) bool {
return false
}

func (q *FakeQuerier) GetActiveDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
ks := make([]database.DBCryptKey, 0, len(q.dbcryptKeys))
for _, k := range q.dbcryptKeys {
if !k.ActiveKeyDigest.Valid {
continue
}
ks = append([]database.DBCryptKey{}, k)
}
return ks, nil
}

func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error {
return xerrors.New("AcquireLock must only be called within a transaction")
}
Expand Down Expand Up @@ -1151,6 +1171,14 @@ func (q *FakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
}, nil
}

func (q *FakeQuerier) GetDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
ks := make([]database.DBCryptKey, 0)
ks = append(ks, q.dbcryptKeys...)
return ks, nil
}

func (q *FakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
Expand Down Expand Up @@ -1393,6 +1421,18 @@ func (q *FakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthL
return database.GitAuthLink{}, sql.ErrNoRows
}

func (q *FakeQuerier) GetGitAuthLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
gals := make([]database.GitAuthLink, 0)
for _, gal := range q.gitAuthLinks {
if gal.UserID == userID {
gals = append(gals, gal)
}
}
return gals, nil
}

func (q *FakeQuerier) GetGitSSHKey(_ context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
Expand Down Expand Up @@ -2833,6 +2873,18 @@ func (q *FakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params dat
return database.UserLink{}, sql.ErrNoRows
}

func (q *FakeQuerier) GetUserLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.UserLink, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
uls := make([]database.UserLink, 0)
for _, ul := range q.userLinks {
if ul.UserID == userID {
uls = append(uls, ul)
}
}
return uls, nil
}

func (q *FakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.GetUsersRow, error) {
if err := validateDatabaseType(params); err != nil {
return nil, err
Expand Down Expand Up @@ -3846,6 +3898,26 @@ func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAudit
return alog, nil
}

func (q *FakeQuerier) InsertDBCryptKey(_ context.Context, arg database.InsertDBCryptKeyParams) error {
err := validateDatabaseType(arg)
if err != nil {
return err
}

for _, key := range q.dbcryptKeys {
if key.Number == arg.Number {
return errDuplicateKey
}
}

q.dbcryptKeys = append(q.dbcryptKeys, database.DBCryptKey{
Number: arg.Number,
ActiveKeyDigest: sql.NullString{String: arg.ActiveKeyDigest, Valid: true},
Test: arg.Test,
})
return nil
}

func (q *FakeQuerier) InsertDERPMeshKey(_ context.Context, id string) error {
q.mutex.Lock()
defer q.mutex.Unlock()
Expand Down Expand Up @@ -3892,13 +3964,15 @@ func (q *FakeQuerier) InsertGitAuthLink(_ context.Context, arg database.InsertGi
defer q.mutex.Unlock()
// nolint:gosimple
gitAuthLink := database.GitAuthLink{
ProviderID: arg.ProviderID,
UserID: arg.UserID,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
OAuthAccessToken: arg.OAuthAccessToken,
OAuthRefreshToken: arg.OAuthRefreshToken,
OAuthExpiry: arg.OAuthExpiry,
ProviderID: arg.ProviderID,
UserID: arg.UserID,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
OAuthAccessToken: arg.OAuthAccessToken,
OAuthAccessTokenKeyID: arg.OAuthAccessTokenKeyID,
OAuthRefreshToken: arg.OAuthRefreshToken,
OAuthRefreshTokenKeyID: arg.OAuthRefreshTokenKeyID,
OAuthExpiry: arg.OAuthExpiry,
}
q.gitAuthLinks = append(q.gitAuthLinks, gitAuthLink)
return gitAuthLink, nil
Expand Down Expand Up @@ -4362,12 +4436,14 @@ func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUser

//nolint:gosimple
link := database.UserLink{
UserID: args.UserID,
LoginType: args.LoginType,
LinkedID: args.LinkedID,
OAuthAccessToken: args.OAuthAccessToken,
OAuthRefreshToken: args.OAuthRefreshToken,
OAuthExpiry: args.OAuthExpiry,
UserID: args.UserID,
LoginType: args.LoginType,
LinkedID: args.LinkedID,
OAuthAccessToken: args.OAuthAccessToken,
OAuthAccessTokenKeyID: args.OAuthAccessTokenKeyID,
OAuthRefreshToken: args.OAuthRefreshToken,
OAuthRefreshTokenKeyID: args.OAuthRefreshTokenKeyID,
OAuthExpiry: args.OAuthExpiry,
}

q.userLinks = append(q.userLinks, link)
Expand Down Expand Up @@ -4793,6 +4869,46 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg
return database.WorkspaceProxy{}, sql.ErrNoRows
}

func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error {
q.mutex.Lock()
defer q.mutex.Unlock()

for i := range q.dbcryptKeys {
key := q.dbcryptKeys[i]

// Is the key already revoked?
if !key.ActiveKeyDigest.Valid {
continue
}

if key.ActiveKeyDigest.String != activeKeyDigest {
continue
}

// Check for foreign key constraints.
for _, ul := range q.userLinks {
if (ul.OAuthAccessTokenKeyID.Valid && ul.OAuthAccessTokenKeyID.String == activeKeyDigest) ||
(ul.OAuthRefreshTokenKeyID.Valid && ul.OAuthRefreshTokenKeyID.String == activeKeyDigest) {
return errForeignKeyConstraint
}
}
for _, gal := range q.gitAuthLinks {
if (gal.OAuthAccessTokenKeyID.Valid && gal.OAuthAccessTokenKeyID.String == activeKeyDigest) ||
(gal.OAuthRefreshTokenKeyID.Valid && gal.OAuthRefreshTokenKeyID.String == activeKeyDigest) {
return errForeignKeyConstraint
}
}

// Revoke the key.
q.dbcryptKeys[i].RevokedAt = sql.NullTime{Time: dbtime.Now(), Valid: true}
q.dbcryptKeys[i].RevokedKeyDigest = sql.NullString{String: key.ActiveKeyDigest.String, Valid: true}
q.dbcryptKeys[i].ActiveKeyDigest = sql.NullString{}
return nil
}

return sql.ErrNoRows
}

func (*FakeQuerier) TryAcquireLock(_ context.Context, _ int64) (bool, error) {
return false, xerrors.New("TryAcquireLock must only be called within a transaction")
}
Expand Down Expand Up @@ -4834,7 +4950,9 @@ func (q *FakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGi
}
gitAuthLink.UpdatedAt = arg.UpdatedAt
gitAuthLink.OAuthAccessToken = arg.OAuthAccessToken
gitAuthLink.OAuthAccessTokenKeyID = arg.OAuthAccessTokenKeyID
gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken
gitAuthLink.OAuthRefreshTokenKeyID = arg.OAuthRefreshTokenKeyID
gitAuthLink.OAuthExpiry = arg.OAuthExpiry
q.gitAuthLinks[index] = gitAuthLink

Expand Down Expand Up @@ -5306,7 +5424,9 @@ func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
for i, link := range q.userLinks {
if link.UserID == params.UserID && link.LoginType == params.LoginType {
link.OAuthAccessToken = params.OAuthAccessToken
link.OAuthAccessTokenKeyID = params.OAuthAccessTokenKeyID
link.OAuthRefreshToken = params.OAuthRefreshToken
link.OAuthRefreshTokenKeyID = params.OAuthRefreshTokenKeyID
link.OAuthExpiry = params.OAuthExpiry

q.userLinks[i] = link
Expand Down
30 changes: 17 additions & 13 deletions coderd/database/dbgen/dbgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,12 +470,14 @@ func File(t testing.TB, db database.Store, orig database.File) database.File {

func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.UserLink {
link, err := db.InsertUserLink(genCtx, database.InsertUserLinkParams{
UserID: takeFirst(orig.UserID, uuid.New()),
LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub),
LinkedID: takeFirst(orig.LinkedID),
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
UserID: takeFirst(orig.UserID, uuid.New()),
LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub),
LinkedID: takeFirst(orig.LinkedID),
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
OAuthAccessTokenKeyID: takeFirst(orig.OAuthAccessTokenKeyID, sql.NullString{}),
OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()),
OAuthRefreshTokenKeyID: takeFirst(orig.OAuthRefreshTokenKeyID, sql.NullString{}),
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
})

require.NoError(t, err, "insert link")
Expand All @@ -484,13 +486,15 @@ func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.

func GitAuthLink(t testing.TB, db database.Store, orig database.GitAuthLink) database.GitAuthLink {
link, err := db.InsertGitAuthLink(genCtx, database.InsertGitAuthLinkParams{
ProviderID: takeFirst(orig.ProviderID, uuid.New().String()),
UserID: takeFirst(orig.UserID, uuid.New()),
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
ProviderID: takeFirst(orig.ProviderID, uuid.New().String()),
UserID: takeFirst(orig.UserID, uuid.New()),
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
OAuthAccessTokenKeyID: takeFirst(orig.OAuthAccessTokenKeyID, sql.NullString{}),
OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()),
OAuthRefreshTokenKeyID: takeFirst(orig.OAuthRefreshTokenKeyID, sql.NullString{}),
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
})

require.NoError(t, err, "insert git auth link")
Expand Down
Loading