diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index e4b802092f03d..8ddd779d795e9 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -838,6 +838,13 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI return q.db.GetAuthorizationUserRoles(ctx, userID) } +func (q *querier) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetDBCryptKeys(ctx) +} + func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) { if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { return "", err @@ -914,6 +921,13 @@ func (q *querier) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLin return fetch(q.log, q.auth, q.db.GetGitAuthLink)(ctx, arg) } +func (q *querier) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetGitAuthLinksByUserID(ctx, userID) +} + func (q *querier) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) { return fetch(q.log, q.auth, q.db.GetGitSSHKey)(ctx, userID) } @@ -1482,6 +1496,13 @@ func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database return q.db.GetUserLinkByUserIDLoginType(ctx, arg) } +func (q *querier) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetUserLinksByUserID(ctx, userID) +} + func (q *querier) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) { // This does the filtering in SQL. prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type) @@ -1845,6 +1866,13 @@ func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLo return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg) } +func (q *querier) InsertDBCryptKey(ctx context.Context, arg database.InsertDBCryptKeyParams) error { + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { + return err + } + return q.db.InsertDBCryptKey(ctx, arg) +} + func (q *querier) InsertDERPMeshKey(ctx context.Context, value string) error { if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil { return err @@ -2144,6 +2172,13 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg) } +func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { + return err + } + return q.db.RevokeDBCryptKey(ctx, activeKeyDigest) +} + func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) { return q.db.TryAcquireLock(ctx, id) } diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 9ba29ddf6d682..e73578a61a7df 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -31,6 +31,11 @@ import ( var validProxyByHostnameRegex = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`) +var errForeignKeyConstraint = &pq.Error{ + Code: "23503", + Message: "update or delete on table violates foreign key constraint", +} + var errDuplicateKey = &pq.Error{ Code: "23505", Message: "duplicate key value violates unique constraint", @@ -45,6 +50,7 @@ func New() database.Store { organizationMembers: make([]database.OrganizationMember, 0), organizations: make([]database.Organization, 0), users: make([]database.User, 0), + dbcryptKeys: make([]database.DBCryptKey, 0), gitAuthLinks: make([]database.GitAuthLink, 0), groups: make([]database.Group, 0), groupMembers: make([]database.GroupMember, 0), @@ -117,6 +123,7 @@ type data struct { // New tables workspaceAgentStats []database.WorkspaceAgentStat auditLogs []database.AuditLog + dbcryptKeys []database.DBCryptKey files []database.File gitAuthLinks []database.GitAuthLink gitSSHKey []database.GitSSHKey @@ -665,6 +672,19 @@ func (q *FakeQuerier) isEveryoneGroup(id uuid.UUID) bool { return false } +func (q *FakeQuerier) GetActiveDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + ks := make([]database.DBCryptKey, 0, len(q.dbcryptKeys)) + for _, k := range q.dbcryptKeys { + if !k.ActiveKeyDigest.Valid { + continue + } + ks = append([]database.DBCryptKey{}, k) + } + return ks, nil +} + func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error { return xerrors.New("AcquireLock must only be called within a transaction") } @@ -1151,6 +1171,14 @@ func (q *FakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U }, nil } +func (q *FakeQuerier) GetDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + ks := make([]database.DBCryptKey, 0) + ks = append(ks, q.dbcryptKeys...) + return ks, nil +} + func (q *FakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -1393,6 +1421,18 @@ func (q *FakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthL return database.GitAuthLink{}, sql.ErrNoRows } +func (q *FakeQuerier) GetGitAuthLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + gals := make([]database.GitAuthLink, 0) + for _, gal := range q.gitAuthLinks { + if gal.UserID == userID { + gals = append(gals, gal) + } + } + return gals, nil +} + func (q *FakeQuerier) GetGitSSHKey(_ context.Context, userID uuid.UUID) (database.GitSSHKey, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -2833,6 +2873,18 @@ func (q *FakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params dat return database.UserLink{}, sql.ErrNoRows } +func (q *FakeQuerier) GetUserLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.UserLink, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + uls := make([]database.UserLink, 0) + for _, ul := range q.userLinks { + if ul.UserID == userID { + uls = append(uls, ul) + } + } + return uls, nil +} + func (q *FakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.GetUsersRow, error) { if err := validateDatabaseType(params); err != nil { return nil, err @@ -3846,6 +3898,26 @@ func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAudit return alog, nil } +func (q *FakeQuerier) InsertDBCryptKey(_ context.Context, arg database.InsertDBCryptKeyParams) error { + err := validateDatabaseType(arg) + if err != nil { + return err + } + + for _, key := range q.dbcryptKeys { + if key.Number == arg.Number { + return errDuplicateKey + } + } + + q.dbcryptKeys = append(q.dbcryptKeys, database.DBCryptKey{ + Number: arg.Number, + ActiveKeyDigest: sql.NullString{String: arg.ActiveKeyDigest, Valid: true}, + Test: arg.Test, + }) + return nil +} + func (q *FakeQuerier) InsertDERPMeshKey(_ context.Context, id string) error { q.mutex.Lock() defer q.mutex.Unlock() @@ -3892,13 +3964,15 @@ func (q *FakeQuerier) InsertGitAuthLink(_ context.Context, arg database.InsertGi defer q.mutex.Unlock() // nolint:gosimple gitAuthLink := database.GitAuthLink{ - ProviderID: arg.ProviderID, - UserID: arg.UserID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - OAuthAccessToken: arg.OAuthAccessToken, - OAuthRefreshToken: arg.OAuthRefreshToken, - OAuthExpiry: arg.OAuthExpiry, + ProviderID: arg.ProviderID, + UserID: arg.UserID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + OAuthAccessToken: arg.OAuthAccessToken, + OAuthAccessTokenKeyID: arg.OAuthAccessTokenKeyID, + OAuthRefreshToken: arg.OAuthRefreshToken, + OAuthRefreshTokenKeyID: arg.OAuthRefreshTokenKeyID, + OAuthExpiry: arg.OAuthExpiry, } q.gitAuthLinks = append(q.gitAuthLinks, gitAuthLink) return gitAuthLink, nil @@ -4362,12 +4436,14 @@ func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUser //nolint:gosimple link := database.UserLink{ - UserID: args.UserID, - LoginType: args.LoginType, - LinkedID: args.LinkedID, - OAuthAccessToken: args.OAuthAccessToken, - OAuthRefreshToken: args.OAuthRefreshToken, - OAuthExpiry: args.OAuthExpiry, + UserID: args.UserID, + LoginType: args.LoginType, + LinkedID: args.LinkedID, + OAuthAccessToken: args.OAuthAccessToken, + OAuthAccessTokenKeyID: args.OAuthAccessTokenKeyID, + OAuthRefreshToken: args.OAuthRefreshToken, + OAuthRefreshTokenKeyID: args.OAuthRefreshTokenKeyID, + OAuthExpiry: args.OAuthExpiry, } q.userLinks = append(q.userLinks, link) @@ -4793,6 +4869,46 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg return database.WorkspaceProxy{}, sql.ErrNoRows } +func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error { + q.mutex.Lock() + defer q.mutex.Unlock() + + for i := range q.dbcryptKeys { + key := q.dbcryptKeys[i] + + // Is the key already revoked? + if !key.ActiveKeyDigest.Valid { + continue + } + + if key.ActiveKeyDigest.String != activeKeyDigest { + continue + } + + // Check for foreign key constraints. + for _, ul := range q.userLinks { + if (ul.OAuthAccessTokenKeyID.Valid && ul.OAuthAccessTokenKeyID.String == activeKeyDigest) || + (ul.OAuthRefreshTokenKeyID.Valid && ul.OAuthRefreshTokenKeyID.String == activeKeyDigest) { + return errForeignKeyConstraint + } + } + for _, gal := range q.gitAuthLinks { + if (gal.OAuthAccessTokenKeyID.Valid && gal.OAuthAccessTokenKeyID.String == activeKeyDigest) || + (gal.OAuthRefreshTokenKeyID.Valid && gal.OAuthRefreshTokenKeyID.String == activeKeyDigest) { + return errForeignKeyConstraint + } + } + + // Revoke the key. + q.dbcryptKeys[i].RevokedAt = sql.NullTime{Time: dbtime.Now(), Valid: true} + q.dbcryptKeys[i].RevokedKeyDigest = sql.NullString{String: key.ActiveKeyDigest.String, Valid: true} + q.dbcryptKeys[i].ActiveKeyDigest = sql.NullString{} + return nil + } + + return sql.ErrNoRows +} + func (*FakeQuerier) TryAcquireLock(_ context.Context, _ int64) (bool, error) { return false, xerrors.New("TryAcquireLock must only be called within a transaction") } @@ -4834,7 +4950,9 @@ func (q *FakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGi } gitAuthLink.UpdatedAt = arg.UpdatedAt gitAuthLink.OAuthAccessToken = arg.OAuthAccessToken + gitAuthLink.OAuthAccessTokenKeyID = arg.OAuthAccessTokenKeyID gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken + gitAuthLink.OAuthRefreshTokenKeyID = arg.OAuthRefreshTokenKeyID gitAuthLink.OAuthExpiry = arg.OAuthExpiry q.gitAuthLinks[index] = gitAuthLink @@ -5306,7 +5424,9 @@ func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs for i, link := range q.userLinks { if link.UserID == params.UserID && link.LoginType == params.LoginType { link.OAuthAccessToken = params.OAuthAccessToken + link.OAuthAccessTokenKeyID = params.OAuthAccessTokenKeyID link.OAuthRefreshToken = params.OAuthRefreshToken + link.OAuthRefreshTokenKeyID = params.OAuthRefreshTokenKeyID link.OAuthExpiry = params.OAuthExpiry q.userLinks[i] = link diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 12036274eb811..f68a8cebadcc8 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -470,12 +470,14 @@ func File(t testing.TB, db database.Store, orig database.File) database.File { func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.UserLink { link, err := db.InsertUserLink(genCtx, database.InsertUserLinkParams{ - UserID: takeFirst(orig.UserID, uuid.New()), - LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub), - LinkedID: takeFirst(orig.LinkedID), - OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), - OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), - OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)), + UserID: takeFirst(orig.UserID, uuid.New()), + LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub), + LinkedID: takeFirst(orig.LinkedID), + OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), + OAuthAccessTokenKeyID: takeFirst(orig.OAuthAccessTokenKeyID, sql.NullString{}), + OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()), + OAuthRefreshTokenKeyID: takeFirst(orig.OAuthRefreshTokenKeyID, sql.NullString{}), + OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)), }) require.NoError(t, err, "insert link") @@ -484,13 +486,15 @@ func UserLink(t testing.TB, db database.Store, orig database.UserLink) database. func GitAuthLink(t testing.TB, db database.Store, orig database.GitAuthLink) database.GitAuthLink { link, err := db.InsertGitAuthLink(genCtx, database.InsertGitAuthLinkParams{ - ProviderID: takeFirst(orig.ProviderID, uuid.New().String()), - UserID: takeFirst(orig.UserID, uuid.New()), - OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), - OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), - OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)), - CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()), + ProviderID: takeFirst(orig.ProviderID, uuid.New().String()), + UserID: takeFirst(orig.UserID, uuid.New()), + OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), + OAuthAccessTokenKeyID: takeFirst(orig.OAuthAccessTokenKeyID, sql.NullString{}), + OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()), + OAuthRefreshTokenKeyID: takeFirst(orig.OAuthRefreshTokenKeyID, sql.NullString{}), + OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)), + CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()), }) require.NoError(t, err, "insert git auth link") diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 8526eb4da1078..0a02896200f60 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -279,6 +279,13 @@ func (m metricsStore) GetAuthorizationUserRoles(ctx context.Context, userID uuid return row, err } +func (m metricsStore) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) { + start := time.Now() + r0, r1 := m.s.GetDBCryptKeys(ctx) + m.queryLatencies.WithLabelValues("GetDBCryptKeys").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetDERPMeshKey(ctx context.Context) (string, error) { start := time.Now() key, err := m.s.GetDERPMeshKey(ctx) @@ -349,6 +356,13 @@ func (m metricsStore) GetGitAuthLink(ctx context.Context, arg database.GetGitAut return link, err } +func (m metricsStore) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) { + start := time.Now() + r0, r1 := m.s.GetGitAuthLinksByUserID(ctx, userID) + m.queryLatencies.WithLabelValues("GetGitAuthLinksByUserID").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) { start := time.Now() key, err := m.s.GetGitSSHKey(ctx, userID) @@ -774,6 +788,13 @@ func (m metricsStore) GetUserLinkByUserIDLoginType(ctx context.Context, arg data return link, err } +func (m metricsStore) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) { + start := time.Now() + r0, r1 := m.s.GetUserLinksByUserID(ctx, userID) + m.queryLatencies.WithLabelValues("GetUserLinksByUserID").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) { start := time.Now() users, err := m.s.GetUsers(ctx, arg) @@ -1068,6 +1089,13 @@ func (m metricsStore) InsertAuditLog(ctx context.Context, arg database.InsertAud return log, err } +func (m metricsStore) InsertDBCryptKey(ctx context.Context, arg database.InsertDBCryptKeyParams) error { + start := time.Now() + r0 := m.s.InsertDBCryptKey(ctx, arg) + m.queryLatencies.WithLabelValues("InsertDBCryptKey").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) InsertDERPMeshKey(ctx context.Context, value string) error { start := time.Now() err := m.s.InsertDERPMeshKey(ctx, value) @@ -1320,6 +1348,13 @@ func (m metricsStore) RegisterWorkspaceProxy(ctx context.Context, arg database.R return proxy, err } +func (m metricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error { + start := time.Now() + r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest) + m.queryLatencies.WithLabelValues("RevokeDBCryptKey").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) { start := time.Now() ok, err := m.s.TryAcquireLock(ctx, pgTryAdvisoryXactLock) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index b0ae7955a458d..be1f994d81161 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -506,6 +506,21 @@ func (mr *MockStoreMockRecorder) GetAuthorizedWorkspaces(arg0, arg1, arg2 interf return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedWorkspaces", reflect.TypeOf((*MockStore)(nil).GetAuthorizedWorkspaces), arg0, arg1, arg2) } +// GetDBCryptKeys mocks base method. +func (m *MockStore) GetDBCryptKeys(arg0 context.Context) ([]database.DBCryptKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDBCryptKeys", arg0) + ret0, _ := ret[0].([]database.DBCryptKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetDBCryptKeys indicates an expected call of GetDBCryptKeys. +func (mr *MockStoreMockRecorder) GetDBCryptKeys(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDBCryptKeys", reflect.TypeOf((*MockStore)(nil).GetDBCryptKeys), arg0) +} + // GetDERPMeshKey mocks base method. func (m *MockStore) GetDERPMeshKey(arg0 context.Context) (string, error) { m.ctrl.T.Helper() @@ -656,6 +671,21 @@ func (mr *MockStoreMockRecorder) GetGitAuthLink(arg0, arg1 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGitAuthLink", reflect.TypeOf((*MockStore)(nil).GetGitAuthLink), arg0, arg1) } +// GetGitAuthLinksByUserID mocks base method. +func (m *MockStore) GetGitAuthLinksByUserID(arg0 context.Context, arg1 uuid.UUID) ([]database.GitAuthLink, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetGitAuthLinksByUserID", arg0, arg1) + ret0, _ := ret[0].([]database.GitAuthLink) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetGitAuthLinksByUserID indicates an expected call of GetGitAuthLinksByUserID. +func (mr *MockStoreMockRecorder) GetGitAuthLinksByUserID(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGitAuthLinksByUserID", reflect.TypeOf((*MockStore)(nil).GetGitAuthLinksByUserID), arg0, arg1) +} + // GetGitSSHKey mocks base method. func (m *MockStore) GetGitSSHKey(arg0 context.Context, arg1 uuid.UUID) (database.GitSSHKey, error) { m.ctrl.T.Helper() @@ -1601,6 +1631,21 @@ func (mr *MockStoreMockRecorder) GetUserLinkByUserIDLoginType(arg0, arg1 interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserLinkByUserIDLoginType", reflect.TypeOf((*MockStore)(nil).GetUserLinkByUserIDLoginType), arg0, arg1) } +// GetUserLinksByUserID mocks base method. +func (m *MockStore) GetUserLinksByUserID(arg0 context.Context, arg1 uuid.UUID) ([]database.UserLink, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserLinksByUserID", arg0, arg1) + ret0, _ := ret[0].([]database.UserLink) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserLinksByUserID indicates an expected call of GetUserLinksByUserID. +func (mr *MockStoreMockRecorder) GetUserLinksByUserID(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserLinksByUserID", reflect.TypeOf((*MockStore)(nil).GetUserLinksByUserID), arg0, arg1) +} + // GetUsers mocks base method. func (m *MockStore) GetUsers(arg0 context.Context, arg1 database.GetUsersParams) ([]database.GetUsersRow, error) { m.ctrl.T.Helper() @@ -2245,6 +2290,20 @@ func (mr *MockStoreMockRecorder) InsertAuditLog(arg0, arg1 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAuditLog", reflect.TypeOf((*MockStore)(nil).InsertAuditLog), arg0, arg1) } +// InsertDBCryptKey mocks base method. +func (m *MockStore) InsertDBCryptKey(arg0 context.Context, arg1 database.InsertDBCryptKeyParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertDBCryptKey", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// InsertDBCryptKey indicates an expected call of InsertDBCryptKey. +func (mr *MockStoreMockRecorder) InsertDBCryptKey(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertDBCryptKey", reflect.TypeOf((*MockStore)(nil).InsertDBCryptKey), arg0, arg1) +} + // InsertDERPMeshKey mocks base method. func (m *MockStore) InsertDERPMeshKey(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() @@ -2789,6 +2848,20 @@ func (mr *MockStoreMockRecorder) RegisterWorkspaceProxy(arg0, arg1 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterWorkspaceProxy", reflect.TypeOf((*MockStore)(nil).RegisterWorkspaceProxy), arg0, arg1) } +// RevokeDBCryptKey mocks base method. +func (m *MockStore) RevokeDBCryptKey(arg0 context.Context, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RevokeDBCryptKey", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// RevokeDBCryptKey indicates an expected call of RevokeDBCryptKey. +func (mr *MockStoreMockRecorder) RevokeDBCryptKey(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeDBCryptKey", reflect.TypeOf((*MockStore)(nil).RevokeDBCryptKey), arg0, arg1) +} + // TryAcquireLock mocks base method. func (m *MockStore) TryAcquireLock(arg0 context.Context, arg1 int64) (bool, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 48e4e1a8862c1..3ee0ac7e19894 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -275,6 +275,29 @@ CREATE TABLE audit_logs ( resource_icon text NOT NULL ); +CREATE TABLE dbcrypt_keys ( + number integer NOT NULL, + active_key_digest text, + revoked_key_digest text, + created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP, + revoked_at timestamp with time zone, + test text NOT NULL +); + +COMMENT ON TABLE dbcrypt_keys IS 'A table used to store the keys used to encrypt the database.'; + +COMMENT ON COLUMN dbcrypt_keys.number IS 'An integer used to identify the key.'; + +COMMENT ON COLUMN dbcrypt_keys.active_key_digest IS 'If the key is active, the digest of the active key.'; + +COMMENT ON COLUMN dbcrypt_keys.revoked_key_digest IS 'If the key has been revoked, the digest of the revoked key.'; + +COMMENT ON COLUMN dbcrypt_keys.created_at IS 'The time at which the key was created.'; + +COMMENT ON COLUMN dbcrypt_keys.revoked_at IS 'The time at which the key was revoked.'; + +COMMENT ON COLUMN dbcrypt_keys.test IS 'A column used to test the encryption.'; + CREATE TABLE files ( hash character varying(64) NOT NULL, created_at timestamp with time zone NOT NULL, @@ -291,9 +314,15 @@ CREATE TABLE git_auth_links ( updated_at timestamp with time zone NOT NULL, oauth_access_token text NOT NULL, oauth_refresh_token text NOT NULL, - oauth_expiry timestamp with time zone NOT NULL + oauth_expiry timestamp with time zone NOT NULL, + oauth_access_token_key_id text, + oauth_refresh_token_key_id text ); +COMMENT ON COLUMN git_auth_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted'; + +COMMENT ON COLUMN git_auth_links.oauth_refresh_token_key_id IS 'The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted'; + CREATE TABLE gitsshkeys ( user_id uuid NOT NULL, created_at timestamp with time zone NOT NULL, @@ -701,9 +730,15 @@ CREATE TABLE user_links ( linked_id text DEFAULT ''::text NOT NULL, oauth_access_token text DEFAULT ''::text NOT NULL, oauth_refresh_token text DEFAULT ''::text NOT NULL, - oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL + oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, + oauth_access_token_key_id text, + oauth_refresh_token_key_id text ); +COMMENT ON COLUMN user_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted'; + +COMMENT ON COLUMN user_links.oauth_refresh_token_key_id IS 'The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted'; + CREATE TABLE workspace_agent_logs ( agent_id uuid NOT NULL, created_at timestamp with time zone NOT NULL, @@ -1037,6 +1072,15 @@ ALTER TABLE ONLY api_keys ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id); +ALTER TABLE ONLY dbcrypt_keys + ADD CONSTRAINT dbcrypt_keys_active_key_digest_key UNIQUE (active_key_digest); + +ALTER TABLE ONLY dbcrypt_keys + ADD CONSTRAINT dbcrypt_keys_pkey PRIMARY KEY (number); + +ALTER TABLE ONLY dbcrypt_keys + ADD CONSTRAINT dbcrypt_keys_revoked_key_digest_key UNIQUE (revoked_key_digest); + ALTER TABLE ONLY files ADD CONSTRAINT files_hash_created_by_key UNIQUE (hash, created_by); @@ -1249,6 +1293,12 @@ CREATE TRIGGER trigger_update_users AFTER INSERT OR UPDATE ON users FOR EACH ROW ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; +ALTER TABLE ONLY git_auth_links + ADD CONSTRAINT git_auth_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); + +ALTER TABLE ONLY git_auth_links + ADD CONSTRAINT git_auth_links_oauth_refresh_token_key_id_fkey FOREIGN KEY (oauth_refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); + ALTER TABLE ONLY gitsshkeys ADD CONSTRAINT gitsshkeys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id); @@ -1303,6 +1353,12 @@ ALTER TABLE ONLY templates ALTER TABLE ONLY templates ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; +ALTER TABLE ONLY user_links + ADD CONSTRAINT user_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); + +ALTER TABLE ONLY user_links + ADD CONSTRAINT user_links_oauth_refresh_token_key_id_fkey FOREIGN KEY (oauth_refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); + ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; diff --git a/coderd/database/migrations/000154_dbcrypt_key_ids.down.sql b/coderd/database/migrations/000154_dbcrypt_key_ids.down.sql new file mode 100644 index 0000000000000..7dea0a1909227 --- /dev/null +++ b/coderd/database/migrations/000154_dbcrypt_key_ids.down.sql @@ -0,0 +1,43 @@ +BEGIN; + +-- Before dropping this table, we need to check if there exist any +-- foreign key references to it. We do this by checking the following: +-- user_links.oauth_access_token_key_id +-- user_links.oauth_refresh_token_key_id +-- git_auth_links.oauth_access_token_key_id +-- git_auth_links.oauth_refresh_token_key_id +DO $$ +BEGIN +IF EXISTS ( + SELECT * + FROM user_links + WHERE oauth_access_token_key_id IS NOT NULL + OR oauth_refresh_token_key_id IS NOT NULL + ) THEN RAISE EXCEPTION 'Cannot drop dbcrypt_keys table as there are still foreign key references to it from user_links.'; +END IF; + +IF EXISTS ( + SELECT * + FROM git_auth_links + WHERE oauth_access_token_key_id IS NOT NULL + OR oauth_refresh_token_key_id IS NOT NULL + ) THEN RAISE EXCEPTION 'Cannot drop dbcrypt_keys table as there are still foreign key references to it from git_auth_links.'; +END IF; + +END +$$; + + +-- Drop the columns first. +ALTER TABLE git_auth_links + DROP COLUMN IF EXISTS oauth_access_token_key_id, + DROP COLUMN IF EXISTS oauth_refresh_token_key_id; + +ALTER TABLE user_links + DROP COLUMN IF EXISTS oauth_access_token_key_id, + DROP COLUMN IF EXISTS oauth_refresh_token_key_id; + +-- Finally, drop the table. +DROP TABLE IF EXISTS dbcrypt_keys; + +COMMIT; diff --git a/coderd/database/migrations/000154_dbcrypt_key_ids.up.sql b/coderd/database/migrations/000154_dbcrypt_key_ids.up.sql new file mode 100644 index 0000000000000..804c41b82b67f --- /dev/null +++ b/coderd/database/migrations/000154_dbcrypt_key_ids.up.sql @@ -0,0 +1,30 @@ +CREATE TABLE IF NOT EXISTS dbcrypt_keys ( + number int NOT NULL PRIMARY KEY, + active_key_digest text UNIQUE, + revoked_key_digest text UNIQUE, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + revoked_at TIMESTAMP WITH TIME ZONE DEFAULT NULL, + test TEXT NOT NULL +); + +COMMENT ON TABLE dbcrypt_keys IS 'A table used to store the keys used to encrypt the database.'; +COMMENT ON COLUMN dbcrypt_keys.number IS 'An integer used to identify the key.'; +COMMENT ON COLUMN dbcrypt_keys.active_key_digest IS 'If the key is active, the digest of the active key.'; +COMMENT ON COLUMN dbcrypt_keys.revoked_key_digest IS 'If the key has been revoked, the digest of the revoked key.'; +COMMENT ON COLUMN dbcrypt_keys.created_at IS 'The time at which the key was created.'; +COMMENT ON COLUMN dbcrypt_keys.revoked_at IS 'The time at which the key was revoked.'; +COMMENT ON COLUMN dbcrypt_keys.test IS 'A column used to test the encryption.'; + +ALTER TABLE git_auth_links +ADD COLUMN IF NOT EXISTS oauth_access_token_key_id text REFERENCES dbcrypt_keys(active_key_digest), +ADD COLUMN IF NOT EXISTS oauth_refresh_token_key_id text REFERENCES dbcrypt_keys(active_key_digest); + +COMMENT ON COLUMN git_auth_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted'; +COMMENT ON COLUMN git_auth_links.oauth_refresh_token_key_id IS 'The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted'; + +ALTER TABLE user_links +ADD COLUMN IF NOT EXISTS oauth_access_token_key_id text REFERENCES dbcrypt_keys(active_key_digest), +ADD COLUMN IF NOT EXISTS oauth_refresh_token_key_id text REFERENCES dbcrypt_keys(active_key_digest); + +COMMENT ON COLUMN user_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted'; +COMMENT ON COLUMN user_links.oauth_refresh_token_key_id IS 'The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted'; diff --git a/coderd/database/migrations/migrate_test.go b/coderd/database/migrations/migrate_test.go index a138e58bac05f..b512811f2ab18 100644 --- a/coderd/database/migrations/migrate_test.go +++ b/coderd/database/migrations/migrate_test.go @@ -266,6 +266,7 @@ func TestMigrateUpWithFixtures(t *testing.T) { "template_version_parameters", "workspace_build_parameters", "template_version_variables", + "dbcrypt_keys", // having zero rows is a valid state for this table } s := &tableStats{s: make(map[string]int)} diff --git a/coderd/database/models.go b/coderd/database/models.go index cd9ba8f990d2d..4d1852a54114e 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -1591,6 +1591,22 @@ type AuditLog struct { ResourceIcon string `db:"resource_icon" json:"resource_icon"` } +// A table used to store the keys used to encrypt the database. +type DBCryptKey struct { + // An integer used to identify the key. + Number int32 `db:"number" json:"number"` + // If the key is active, the digest of the active key. + ActiveKeyDigest sql.NullString `db:"active_key_digest" json:"active_key_digest"` + // If the key has been revoked, the digest of the revoked key. + RevokedKeyDigest sql.NullString `db:"revoked_key_digest" json:"revoked_key_digest"` + // The time at which the key was created. + CreatedAt sql.NullTime `db:"created_at" json:"created_at"` + // The time at which the key was revoked. + RevokedAt sql.NullTime `db:"revoked_at" json:"revoked_at"` + // A column used to test the encryption. + Test string `db:"test" json:"test"` +} + type File struct { Hash string `db:"hash" json:"hash"` CreatedAt time.Time `db:"created_at" json:"created_at"` @@ -1608,6 +1624,10 @@ type GitAuthLink struct { OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + // The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted + OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"` + // The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted + OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` } type GitSSHKey struct { @@ -1949,6 +1969,10 @@ type UserLink struct { OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + // The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted + OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"` + // The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted + OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` } // Visible fields of users are allowed to be joined with other tables for including context of other resources. diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 520266bd1d25c..cdf4d184544bb 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -58,6 +58,7 @@ type sqlcQuerier interface { // This function returns roles for authorization purposes. Implied member roles // are included. GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error) + GetDBCryptKeys(ctx context.Context) ([]DBCryptKey, error) GetDERPMeshKey(ctx context.Context) (string, error) GetDefaultProxyConfig(ctx context.Context) (GetDefaultProxyConfigRow, error) GetDeploymentDAUs(ctx context.Context, tzOffset int32) ([]GetDeploymentDAUsRow, error) @@ -69,6 +70,7 @@ type sqlcQuerier interface { // Get all templates that use a file. GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]GetFileTemplatesRow, error) GetGitAuthLink(ctx context.Context, arg GetGitAuthLinkParams) (GitAuthLink, error) + GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]GitAuthLink, error) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error) GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error) GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error) @@ -150,6 +152,7 @@ type sqlcQuerier interface { GetUserLatencyInsights(ctx context.Context, arg GetUserLatencyInsightsParams) ([]GetUserLatencyInsightsRow, error) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (UserLink, error) GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error) + GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error) // This will never return deleted users. GetUsers(ctx context.Context, arg GetUsersParams) ([]GetUsersRow, error) // This shouldn't check for deleted, because it's frequently used @@ -206,6 +209,7 @@ type sqlcQuerier interface { // every member of the org. InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error) InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error) + InsertDBCryptKey(ctx context.Context, arg InsertDBCryptKeyParams) error InsertDERPMeshKey(ctx context.Context, value string) error InsertDeploymentID(ctx context.Context, value string) error InsertFile(ctx context.Context, arg InsertFileParams) (File, error) @@ -247,6 +251,7 @@ type sqlcQuerier interface { InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error) InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error) RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error) + RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error // Non blocking lock. Returns true if the lock was acquired, false otherwise. // // This must be called from within a transaction. The lock will be automatically diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index dbbb5d4085e93..4d9bc72a37157 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -636,6 +636,74 @@ func (q *sqlQuerier) InsertAuditLog(ctx context.Context, arg InsertAuditLogParam return i, err } +const getDBCryptKeys = `-- name: GetDBCryptKeys :many +SELECT number, active_key_digest, revoked_key_digest, created_at, revoked_at, test FROM dbcrypt_keys ORDER BY number ASC +` + +func (q *sqlQuerier) GetDBCryptKeys(ctx context.Context) ([]DBCryptKey, error) { + rows, err := q.db.QueryContext(ctx, getDBCryptKeys) + if err != nil { + return nil, err + } + defer rows.Close() + var items []DBCryptKey + for rows.Next() { + var i DBCryptKey + if err := rows.Scan( + &i.Number, + &i.ActiveKeyDigest, + &i.RevokedKeyDigest, + &i.CreatedAt, + &i.RevokedAt, + &i.Test, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const insertDBCryptKey = `-- name: InsertDBCryptKey :exec +INSERT INTO dbcrypt_keys + (number, active_key_digest, created_at, test) +VALUES ($1::int, $2::text, CURRENT_TIMESTAMP, $3::text) +` + +type InsertDBCryptKeyParams struct { + Number int32 `db:"number" json:"number"` + ActiveKeyDigest string `db:"active_key_digest" json:"active_key_digest"` + Test string `db:"test" json:"test"` +} + +func (q *sqlQuerier) InsertDBCryptKey(ctx context.Context, arg InsertDBCryptKeyParams) error { + _, err := q.db.ExecContext(ctx, insertDBCryptKey, arg.Number, arg.ActiveKeyDigest, arg.Test) + return err +} + +const revokeDBCryptKey = `-- name: RevokeDBCryptKey :exec +UPDATE dbcrypt_keys +SET + revoked_key_digest = active_key_digest, + active_key_digest = revoked_key_digest, + revoked_at = CURRENT_TIMESTAMP +WHERE + active_key_digest = $1::text +AND + revoked_key_digest IS NULL +` + +func (q *sqlQuerier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error { + _, err := q.db.ExecContext(ctx, revokeDBCryptKey, activeKeyDigest) + return err +} + const getFileByHashAndCreator = `-- name: GetFileByHashAndCreator :one SELECT hash, created_at, created_by, mimetype, data, id @@ -800,7 +868,7 @@ func (q *sqlQuerier) InsertFile(ctx context.Context, arg InsertFileParams) (File } const getGitAuthLink = `-- name: GetGitAuthLink :one -SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry FROM git_auth_links WHERE provider_id = $1 AND user_id = $2 +SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM git_auth_links WHERE provider_id = $1 AND user_id = $2 ` type GetGitAuthLinkParams struct { @@ -819,10 +887,49 @@ func (q *sqlQuerier) GetGitAuthLink(ctx context.Context, arg GetGitAuthLinkParam &i.OAuthAccessToken, &i.OAuthRefreshToken, &i.OAuthExpiry, + &i.OAuthAccessTokenKeyID, + &i.OAuthRefreshTokenKeyID, ) return i, err } +const getGitAuthLinksByUserID = `-- name: GetGitAuthLinksByUserID :many +SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM git_auth_links WHERE user_id = $1 +` + +func (q *sqlQuerier) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]GitAuthLink, error) { + rows, err := q.db.QueryContext(ctx, getGitAuthLinksByUserID, userID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GitAuthLink + for rows.Next() { + var i GitAuthLink + if err := rows.Scan( + &i.ProviderID, + &i.UserID, + &i.CreatedAt, + &i.UpdatedAt, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthExpiry, + &i.OAuthAccessTokenKeyID, + &i.OAuthRefreshTokenKeyID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const insertGitAuthLink = `-- name: InsertGitAuthLink :one INSERT INTO git_auth_links ( provider_id, @@ -830,7 +937,9 @@ INSERT INTO git_auth_links ( created_at, updated_at, oauth_access_token, + oauth_access_token_key_id, oauth_refresh_token, + oauth_refresh_token_key_id, oauth_expiry ) VALUES ( $1, @@ -839,18 +948,22 @@ INSERT INTO git_auth_links ( $4, $5, $6, - $7 -) RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry + $7, + $8, + $9 +) RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id ` type InsertGitAuthLinkParams struct { - ProviderID string `db:"provider_id" json:"provider_id"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` - OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` - OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + ProviderID string `db:"provider_id" json:"provider_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` } func (q *sqlQuerier) InsertGitAuthLink(ctx context.Context, arg InsertGitAuthLinkParams) (GitAuthLink, error) { @@ -860,7 +973,9 @@ func (q *sqlQuerier) InsertGitAuthLink(ctx context.Context, arg InsertGitAuthLin arg.CreatedAt, arg.UpdatedAt, arg.OAuthAccessToken, + arg.OAuthAccessTokenKeyID, arg.OAuthRefreshToken, + arg.OAuthRefreshTokenKeyID, arg.OAuthExpiry, ) var i GitAuthLink @@ -872,6 +987,8 @@ func (q *sqlQuerier) InsertGitAuthLink(ctx context.Context, arg InsertGitAuthLin &i.OAuthAccessToken, &i.OAuthRefreshToken, &i.OAuthExpiry, + &i.OAuthAccessTokenKeyID, + &i.OAuthRefreshTokenKeyID, ) return i, err } @@ -880,18 +997,22 @@ const updateGitAuthLink = `-- name: UpdateGitAuthLink :one UPDATE git_auth_links SET updated_at = $3, oauth_access_token = $4, - oauth_refresh_token = $5, - oauth_expiry = $6 -WHERE provider_id = $1 AND user_id = $2 RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry + oauth_access_token_key_id = $5, + oauth_refresh_token = $6, + oauth_refresh_token_key_id = $7, + oauth_expiry = $8 +WHERE provider_id = $1 AND user_id = $2 RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id ` type UpdateGitAuthLinkParams struct { - ProviderID string `db:"provider_id" json:"provider_id"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` - OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` - OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + ProviderID string `db:"provider_id" json:"provider_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` } func (q *sqlQuerier) UpdateGitAuthLink(ctx context.Context, arg UpdateGitAuthLinkParams) (GitAuthLink, error) { @@ -900,7 +1021,9 @@ func (q *sqlQuerier) UpdateGitAuthLink(ctx context.Context, arg UpdateGitAuthLin arg.UserID, arg.UpdatedAt, arg.OAuthAccessToken, + arg.OAuthAccessTokenKeyID, arg.OAuthRefreshToken, + arg.OAuthRefreshTokenKeyID, arg.OAuthExpiry, ) var i GitAuthLink @@ -912,6 +1035,8 @@ func (q *sqlQuerier) UpdateGitAuthLink(ctx context.Context, arg UpdateGitAuthLin &i.OAuthAccessToken, &i.OAuthRefreshToken, &i.OAuthExpiry, + &i.OAuthAccessTokenKeyID, + &i.OAuthRefreshTokenKeyID, ) return i, err } @@ -5450,7 +5575,7 @@ func (q *sqlQuerier) InsertTemplateVersionVariable(ctx context.Context, arg Inse const getUserLinkByLinkedID = `-- name: GetUserLinkByLinkedID :one SELECT - user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry + user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM user_links WHERE @@ -5467,13 +5592,15 @@ func (q *sqlQuerier) GetUserLinkByLinkedID(ctx context.Context, linkedID string) &i.OAuthAccessToken, &i.OAuthRefreshToken, &i.OAuthExpiry, + &i.OAuthAccessTokenKeyID, + &i.OAuthRefreshTokenKeyID, ) return i, err } const getUserLinkByUserIDLoginType = `-- name: GetUserLinkByUserIDLoginType :one SELECT - user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry + user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM user_links WHERE @@ -5495,10 +5622,48 @@ func (q *sqlQuerier) GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUs &i.OAuthAccessToken, &i.OAuthRefreshToken, &i.OAuthExpiry, + &i.OAuthAccessTokenKeyID, + &i.OAuthRefreshTokenKeyID, ) return i, err } +const getUserLinksByUserID = `-- name: GetUserLinksByUserID :many +SELECT user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM user_links WHERE user_id = $1 +` + +func (q *sqlQuerier) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error) { + rows, err := q.db.QueryContext(ctx, getUserLinksByUserID, userID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []UserLink + for rows.Next() { + var i UserLink + if err := rows.Scan( + &i.UserID, + &i.LoginType, + &i.LinkedID, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthExpiry, + &i.OAuthAccessTokenKeyID, + &i.OAuthRefreshTokenKeyID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const insertUserLink = `-- name: InsertUserLink :one INSERT INTO user_links ( @@ -5506,20 +5671,24 @@ INSERT INTO login_type, linked_id, oauth_access_token, + oauth_access_token_key_id, oauth_refresh_token, + oauth_refresh_token_key_id, oauth_expiry ) VALUES - ( $1, $2, $3, $4, $5, $6 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry + ( $1, $2, $3, $4, $5, $6, $7, $8 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id ` type InsertUserLinkParams struct { - UserID uuid.UUID `db:"user_id" json:"user_id"` - LoginType LoginType `db:"login_type" json:"login_type"` - LinkedID string `db:"linked_id" json:"linked_id"` - OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` - OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` - OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + LoginType LoginType `db:"login_type" json:"login_type"` + LinkedID string `db:"linked_id" json:"linked_id"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` } func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error) { @@ -5528,7 +5697,9 @@ func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParam arg.LoginType, arg.LinkedID, arg.OAuthAccessToken, + arg.OAuthAccessTokenKeyID, arg.OAuthRefreshToken, + arg.OAuthRefreshTokenKeyID, arg.OAuthExpiry, ) var i UserLink @@ -5539,6 +5710,8 @@ func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParam &i.OAuthAccessToken, &i.OAuthRefreshToken, &i.OAuthExpiry, + &i.OAuthAccessTokenKeyID, + &i.OAuthRefreshTokenKeyID, ) return i, err } @@ -5548,24 +5721,30 @@ UPDATE user_links SET oauth_access_token = $1, - oauth_refresh_token = $2, - oauth_expiry = $3 + oauth_access_token_key_id = $2, + oauth_refresh_token = $3, + oauth_refresh_token_key_id = $4, + oauth_expiry = $5 WHERE - user_id = $4 AND login_type = $5 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry + user_id = $6 AND login_type = $7 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id ` type UpdateUserLinkParams struct { - OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` - OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` - OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - LoginType LoginType `db:"login_type" json:"login_type"` + OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"` + OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"` + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` + OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + LoginType LoginType `db:"login_type" json:"login_type"` } func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParams) (UserLink, error) { row := q.db.QueryRowContext(ctx, updateUserLink, arg.OAuthAccessToken, + arg.OAuthAccessTokenKeyID, arg.OAuthRefreshToken, + arg.OAuthRefreshTokenKeyID, arg.OAuthExpiry, arg.UserID, arg.LoginType, @@ -5578,6 +5757,8 @@ func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParam &i.OAuthAccessToken, &i.OAuthRefreshToken, &i.OAuthExpiry, + &i.OAuthAccessTokenKeyID, + &i.OAuthRefreshTokenKeyID, ) return i, err } @@ -5588,7 +5769,7 @@ UPDATE SET linked_id = $1 WHERE - user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry + user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id ` type UpdateUserLinkedIDParams struct { @@ -5607,6 +5788,8 @@ func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinke &i.OAuthAccessToken, &i.OAuthRefreshToken, &i.OAuthExpiry, + &i.OAuthAccessTokenKeyID, + &i.OAuthRefreshTokenKeyID, ) return i, err } diff --git a/coderd/database/queries/dbcrypt.sql b/coderd/database/queries/dbcrypt.sql new file mode 100644 index 0000000000000..ef1021609d5a7 --- /dev/null +++ b/coderd/database/queries/dbcrypt.sql @@ -0,0 +1,18 @@ +-- name: GetDBCryptKeys :many +SELECT * FROM dbcrypt_keys ORDER BY number ASC; + +-- name: RevokeDBCryptKey :exec +UPDATE dbcrypt_keys +SET + revoked_key_digest = active_key_digest, + active_key_digest = revoked_key_digest, + revoked_at = CURRENT_TIMESTAMP +WHERE + active_key_digest = @active_key_digest::text +AND + revoked_key_digest IS NULL; + +-- name: InsertDBCryptKey :exec +INSERT INTO dbcrypt_keys + (number, active_key_digest, created_at, test) +VALUES (@number::int, @active_key_digest::text, CURRENT_TIMESTAMP, @test::text); diff --git a/coderd/database/queries/gitauth.sql b/coderd/database/queries/gitauth.sql index a35de98a08908..b2ce97dae1404 100644 --- a/coderd/database/queries/gitauth.sql +++ b/coderd/database/queries/gitauth.sql @@ -1,6 +1,9 @@ -- name: GetGitAuthLink :one SELECT * FROM git_auth_links WHERE provider_id = $1 AND user_id = $2; +-- name: GetGitAuthLinksByUserID :many +SELECT * FROM git_auth_links WHERE user_id = $1; + -- name: InsertGitAuthLink :one INSERT INTO git_auth_links ( provider_id, @@ -8,7 +11,9 @@ INSERT INTO git_auth_links ( created_at, updated_at, oauth_access_token, + oauth_access_token_key_id, oauth_refresh_token, + oauth_refresh_token_key_id, oauth_expiry ) VALUES ( $1, @@ -17,13 +22,17 @@ INSERT INTO git_auth_links ( $4, $5, $6, - $7 + $7, + $8, + $9 ) RETURNING *; -- name: UpdateGitAuthLink :one UPDATE git_auth_links SET updated_at = $3, oauth_access_token = $4, - oauth_refresh_token = $5, - oauth_expiry = $6 + oauth_access_token_key_id = $5, + oauth_refresh_token = $6, + oauth_refresh_token_key_id = $7, + oauth_expiry = $8 WHERE provider_id = $1 AND user_id = $2 RETURNING *; diff --git a/coderd/database/queries/user_links.sql b/coderd/database/queries/user_links.sql index 2390cb9782b30..5db3324c676a2 100644 --- a/coderd/database/queries/user_links.sql +++ b/coderd/database/queries/user_links.sql @@ -14,6 +14,9 @@ FROM WHERE user_id = $1 AND login_type = $2; +-- name: GetUserLinksByUserID :many +SELECT * FROM user_links WHERE user_id = $1; + -- name: InsertUserLink :one INSERT INTO user_links ( @@ -21,11 +24,13 @@ INSERT INTO login_type, linked_id, oauth_access_token, + oauth_access_token_key_id, oauth_refresh_token, + oauth_refresh_token_key_id, oauth_expiry ) VALUES - ( $1, $2, $3, $4, $5, $6 ) RETURNING *; + ( $1, $2, $3, $4, $5, $6, $7, $8 ) RETURNING *; -- name: UpdateUserLinkedID :one UPDATE @@ -40,7 +45,9 @@ UPDATE user_links SET oauth_access_token = $1, - oauth_refresh_token = $2, - oauth_expiry = $3 + oauth_access_token_key_id = $2, + oauth_refresh_token = $3, + oauth_refresh_token_key_id = $4, + oauth_expiry = $5 WHERE - user_id = $4 AND login_type = $5 RETURNING *; + user_id = $6 AND login_type = $7 RETURNING *; diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index 73c59c257de31..1bdc972927f6f 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -40,6 +40,7 @@ overrides: api_key_scope_application_connect: APIKeyScopeApplicationConnect avatar_url: AvatarURL created_by_avatar_url: CreatedByAvatarURL + dbcrypt_key: DBCryptKey session_count_vscode: SessionCountVSCode session_count_jetbrains: SessionCountJetBrains session_count_reconnecting_pty: SessionCountReconnectingPTY @@ -47,9 +48,11 @@ overrides: connection_median_latency_ms: ConnectionMedianLatencyMS login_type_oidc: LoginTypeOIDC oauth_access_token: OAuthAccessToken + oauth_access_token_key_id: OAuthAccessTokenKeyID oauth_expiry: OAuthExpiry oauth_id_token: OAuthIDToken oauth_refresh_token: OAuthRefreshToken + oauth_refresh_token_key_id: OAuthRefreshTokenKeyID parameter_type_system_hcl: ParameterTypeSystemHCL userstatus: UserStatus gitsshkey: GitSSHKey diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index 294b4b12d51af..ea0a9a64d3137 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -6,6 +6,8 @@ type UniqueConstraint string // UniqueConstraint enums. const ( + UniqueDbcryptKeysActiveKeyDigestKey UniqueConstraint = "dbcrypt_keys_active_key_digest_key" // ALTER TABLE ONLY dbcrypt_keys ADD CONSTRAINT dbcrypt_keys_active_key_digest_key UNIQUE (active_key_digest); + UniqueDbcryptKeysRevokedKeyDigestKey UniqueConstraint = "dbcrypt_keys_revoked_key_digest_key" // ALTER TABLE ONLY dbcrypt_keys ADD CONSTRAINT dbcrypt_keys_revoked_key_digest_key UNIQUE (revoked_key_digest); UniqueFilesHashCreatedByKey UniqueConstraint = "files_hash_created_by_key" // ALTER TABLE ONLY files ADD CONSTRAINT files_hash_created_by_key UNIQUE (hash, created_by); UniqueGitAuthLinksProviderIDUserIDKey UniqueConstraint = "git_auth_links_provider_id_user_id_key" // ALTER TABLE ONLY git_auth_links ADD CONSTRAINT git_auth_links_provider_id_user_id_key UNIQUE (provider_id, user_id); UniqueGroupMembersUserIDGroupIDKey UniqueConstraint = "group_members_user_id_group_id_key" // ALTER TABLE ONLY group_members ADD CONSTRAINT group_members_user_id_group_id_key UNIQUE (user_id, group_id); diff --git a/enterprise/dbcrypt/cipher.go b/enterprise/dbcrypt/cipher.go new file mode 100644 index 0000000000000..fc6f25ee90955 --- /dev/null +++ b/enterprise/dbcrypt/cipher.go @@ -0,0 +1,98 @@ +package dbcrypt + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "fmt" + "io" + + "golang.org/x/xerrors" +) + +// cipherAES256GCM is the name of the AES-256 cipher. +// This is used to identify the cipher used to encrypt a value. +// It is added to the digest to ensure that if, in the future, +// we add a new cipher type, and a key is re-used, we don't +// accidentally decrypt the wrong values. +// When adding a new cipher type, add a new constant here +// and ensure to add the cipher name to the digest of the new +// cipher type. +const ( + cipherAES256GCM = "aes256gcm" +) + +type Cipher interface { + Encrypt([]byte) ([]byte, error) + Decrypt([]byte) ([]byte, error) + HexDigest() string +} + +// NewCiphers is a convenience function for creating multiple ciphers. +// It currently only supports AES-256-GCM. +func NewCiphers(keys ...[]byte) ([]Cipher, error) { + var cs []Cipher + for _, key := range keys { + c, err := cipherAES256(key) + if err != nil { + return nil, err + } + cs = append(cs, c) + } + return cs, nil +} + +// cipherAES256 returns a new AES-256 cipher. +func cipherAES256(key []byte) (*aes256, error) { + if len(key) != 32 { + return nil, xerrors.Errorf("key must be 32 bytes") + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + aead, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + // We add the cipher name to the digest to ensure that if, in the future, + // we add a new cipher type, and a key is re-used, we don't accidentally + // decrypt the wrong values. + toDigest := []byte(cipherAES256GCM) + toDigest = append(toDigest, key...) + digest := fmt.Sprintf("%x", sha256.Sum256(toDigest))[:7] + return &aes256{aead: aead, digest: digest}, nil +} + +type aes256 struct { + aead cipher.AEAD + // digest is the first 7 bytes of the hex-encoded SHA-256 digest of aead. + digest string +} + +func (a *aes256) Encrypt(plaintext []byte) ([]byte, error) { + nonce := make([]byte, a.aead.NonceSize()) + _, err := io.ReadFull(rand.Reader, nonce) + if err != nil { + return nil, err + } + dst := make([]byte, len(nonce)) + copy(dst, nonce) + return a.aead.Seal(dst, nonce, plaintext, nil), nil +} + +func (a *aes256) Decrypt(ciphertext []byte) ([]byte, error) { + if len(ciphertext) < a.aead.NonceSize() { + return nil, xerrors.Errorf("ciphertext too short") + } + decrypted, err := a.aead.Open(nil, ciphertext[:a.aead.NonceSize()], ciphertext[a.aead.NonceSize():], nil) + if err != nil { + return nil, &DecryptFailedError{Inner: err} + } + return decrypted, nil +} + +func (a *aes256) HexDigest() string { + return a.digest +} diff --git a/enterprise/dbcrypt/cipher_internal_test.go b/enterprise/dbcrypt/cipher_internal_test.go new file mode 100644 index 0000000000000..b6740de17eec6 --- /dev/null +++ b/enterprise/dbcrypt/cipher_internal_test.go @@ -0,0 +1,91 @@ +package dbcrypt + +import ( + "bytes" + "encoding/base64" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCipherAES256(t *testing.T) { + t.Parallel() + + t.Run("ValidInput", func(t *testing.T) { + t.Parallel() + key := bytes.Repeat([]byte{'a'}, 32) + cipher, err := cipherAES256(key) + require.NoError(t, err) + + output, err := cipher.Encrypt([]byte("hello world")) + require.NoError(t, err) + + response, err := cipher.Decrypt(output) + require.NoError(t, err) + require.Equal(t, "hello world", string(response)) + }) + + t.Run("InvalidInput", func(t *testing.T) { + t.Parallel() + key := bytes.Repeat([]byte{'a'}, 32) + cipher, err := cipherAES256(key) + require.NoError(t, err) + _, err = cipher.Decrypt(bytes.Repeat([]byte{'a'}, 100)) + var decryptErr *DecryptFailedError + require.ErrorAs(t, err, &decryptErr) + }) + + t.Run("InvalidKeySize", func(t *testing.T) { + t.Parallel() + + _, err := cipherAES256(bytes.Repeat([]byte{'a'}, 31)) + require.ErrorContains(t, err, "key must be 32 bytes") + }) + + t.Run("TestNonce", func(t *testing.T) { + t.Parallel() + key := bytes.Repeat([]byte{'a'}, 32) + cipher, err := cipherAES256(key) + require.NoError(t, err) + require.Equal(t, "864f702", cipher.HexDigest()) + + encrypted1, err := cipher.Encrypt([]byte("hello world")) + require.NoError(t, err) + encrypted2, err := cipher.Encrypt([]byte("hello world")) + require.NoError(t, err) + require.NotEqual(t, encrypted1, encrypted2, "nonce should be different for each encryption") + + munged := make([]byte, len(encrypted1)) + copy(munged, encrypted1) + munged[0] = munged[0] ^ 0xff + _, err = cipher.Decrypt(munged) + var decryptErr *DecryptFailedError + require.ErrorAs(t, err, &decryptErr, "munging the first byte of the encrypted data should cause decryption to fail") + }) +} + +// This test ensures backwards compatibility. If it breaks, something is very wrong. +func TestCiphersBackwardCompatibility(t *testing.T) { + t.Parallel() + var ( + msg = "hello world" + key = bytes.Repeat([]byte{'a'}, 32) + //nolint: gosec // The below is the base64-encoded result of encrypting the above message with the above key. + encoded = `YhAz+lE2fFeeiVPH9voKN7UV1xSDrgcnC0LmNXmaAk1Yg0kPFO3x` + ) + + cipher, err := cipherAES256(key) + require.NoError(t, err) + + // This is the code that was used to generate the above. + // Note that the output of this code will change every time it is run. + // encrypted, err := cipher.Encrypt([]byte(msg)) + // require.NoError(t, err) + // t.Logf("encoded: %q", base64.StdEncoding.EncodeToString(encrypted)) + + decoded, err := base64.StdEncoding.DecodeString(encoded) + require.NoError(t, err, "the encoded string should be valid base64") + decrypted, err := cipher.Decrypt(decoded) + require.NoError(t, err, "decryption should succeed") + require.Equal(t, msg, string(decrypted), "decrypted message should match original message") +} diff --git a/enterprise/dbcrypt/dbcrypt.go b/enterprise/dbcrypt/dbcrypt.go new file mode 100644 index 0000000000000..b050757f594b4 --- /dev/null +++ b/enterprise/dbcrypt/dbcrypt.go @@ -0,0 +1,374 @@ +package dbcrypt + +import ( + "context" + "database/sql" + "encoding/base64" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + + "github.com/google/uuid" + "golang.org/x/xerrors" +) + +// testValue is the value that is stored in dbcrypt_keys.test. +// This is used to determine if the key is valid. +const testValue = "coder" + +var ( + b64encode = base64.StdEncoding.EncodeToString + b64decode = base64.StdEncoding.DecodeString +) + +// DecryptFailedError is returned when decryption fails. +type DecryptFailedError struct { + Inner error +} + +func (e *DecryptFailedError) Error() string { + return xerrors.Errorf("decrypt failed: %w", e.Inner).Error() +} + +// New creates a database.Store wrapper that encrypts/decrypts values +// stored at rest in the database. +func New(ctx context.Context, db database.Store, ciphers ...Cipher) (database.Store, error) { + cm := make(map[string]Cipher) + for _, c := range ciphers { + cm[c.HexDigest()] = c + } + dbc := &dbCrypt{ + ciphers: cm, + Store: db, + } + if len(ciphers) > 0 { + dbc.primaryCipherDigest = ciphers[0].HexDigest() + } + // nolint: gocritic // This is allowed. + authCtx := dbauthz.AsSystemRestricted(ctx) + if err := dbc.ensureEncryptedWithRetry(authCtx); err != nil { + return nil, xerrors.Errorf("ensure encrypted database fields: %w", err) + } + return dbc, nil +} + +type dbCrypt struct { + // primaryCipherDigest is the digest of the primary cipher used for encrypting data. + primaryCipherDigest string + // ciphers is a map of cipher digests to ciphers. + ciphers map[string]Cipher + database.Store +} + +func (db *dbCrypt) InTx(function func(database.Store) error, txOpts *sql.TxOptions) error { + return db.Store.InTx(func(s database.Store) error { + return function(&dbCrypt{ + primaryCipherDigest: db.primaryCipherDigest, + ciphers: db.ciphers, + Store: s, + }) + }, txOpts) +} + +func (db *dbCrypt) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) { + ks, err := db.Store.GetDBCryptKeys(ctx) + if err != nil { + return nil, err + } + // Decrypt the test field to ensure that the key is valid. + for i := range ks { + if !ks[i].ActiveKeyDigest.Valid { + // Key has been revoked. We can't decrypt the test field, but + // we need to return it so that the caller knows that the key + // has been revoked. + continue + } + if err := db.decryptField(&ks[i].Test, ks[i].ActiveKeyDigest); err != nil { + return nil, err + } + } + return ks, nil +} + +// This does not need any special handling as it does not touch any encrypted fields. +// Explicitly defining this here to avoid confusion. +func (db *dbCrypt) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error { + return db.Store.RevokeDBCryptKey(ctx, activeKeyDigest) +} + +func (db *dbCrypt) InsertDBCryptKey(ctx context.Context, arg database.InsertDBCryptKeyParams) error { + // It's nicer to be able to pass a *sql.NullString to encryptField, but we need to pass a string here. + var digest sql.NullString + err := db.encryptField(&arg.Test, &digest) + if err != nil { + return err + } + arg.ActiveKeyDigest = digest.String + return db.Store.InsertDBCryptKey(ctx, arg) +} + +func (db *dbCrypt) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (database.UserLink, error) { + link, err := db.Store.GetUserLinkByLinkedID(ctx, linkedID) + if err != nil { + return database.UserLink{}, err + } + if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil { + return database.UserLink{}, err + } + if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil { + return database.UserLink{}, err + } + return link, nil +} + +func (db *dbCrypt) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) { + links, err := db.Store.GetUserLinksByUserID(ctx, userID) + if err != nil { + return nil, err + } + for idx := range links { + if err := db.decryptField(&links[idx].OAuthAccessToken, links[idx].OAuthAccessTokenKeyID); err != nil { + return nil, err + } + if err := db.decryptField(&links[idx].OAuthRefreshToken, links[idx].OAuthRefreshTokenKeyID); err != nil { + return nil, err + } + } + return links, nil +} + +func (db *dbCrypt) GetUserLinkByUserIDLoginType(ctx context.Context, params database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) { + link, err := db.Store.GetUserLinkByUserIDLoginType(ctx, params) + if err != nil { + return database.UserLink{}, err + } + if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil { + return database.UserLink{}, err + } + if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil { + return database.UserLink{}, err + } + return link, nil +} + +func (db *dbCrypt) InsertUserLink(ctx context.Context, params database.InsertUserLinkParams) (database.UserLink, error) { + if err := db.encryptField(¶ms.OAuthAccessToken, ¶ms.OAuthAccessTokenKeyID); err != nil { + return database.UserLink{}, err + } + if err := db.encryptField(¶ms.OAuthRefreshToken, ¶ms.OAuthRefreshTokenKeyID); err != nil { + return database.UserLink{}, err + } + link, err := db.Store.InsertUserLink(ctx, params) + if err != nil { + return database.UserLink{}, err + } + if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil { + return database.UserLink{}, err + } + if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil { + return database.UserLink{}, err + } + return link, nil +} + +func (db *dbCrypt) UpdateUserLink(ctx context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) { + if err := db.encryptField(¶ms.OAuthAccessToken, ¶ms.OAuthAccessTokenKeyID); err != nil { + return database.UserLink{}, err + } + if err := db.encryptField(¶ms.OAuthRefreshToken, ¶ms.OAuthRefreshTokenKeyID); err != nil { + return database.UserLink{}, err + } + link, err := db.Store.UpdateUserLink(ctx, params) + if err != nil { + return database.UserLink{}, err + } + if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil { + return database.UserLink{}, err + } + if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil { + return database.UserLink{}, err + } + return link, nil +} + +func (db *dbCrypt) InsertGitAuthLink(ctx context.Context, params database.InsertGitAuthLinkParams) (database.GitAuthLink, error) { + if err := db.encryptField(¶ms.OAuthAccessToken, ¶ms.OAuthAccessTokenKeyID); err != nil { + return database.GitAuthLink{}, err + } + if err := db.encryptField(¶ms.OAuthRefreshToken, ¶ms.OAuthRefreshTokenKeyID); err != nil { + return database.GitAuthLink{}, err + } + link, err := db.Store.InsertGitAuthLink(ctx, params) + if err != nil { + return database.GitAuthLink{}, err + } + if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil { + return database.GitAuthLink{}, err + } + if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil { + return database.GitAuthLink{}, err + } + return link, nil +} + +func (db *dbCrypt) GetGitAuthLink(ctx context.Context, params database.GetGitAuthLinkParams) (database.GitAuthLink, error) { + link, err := db.Store.GetGitAuthLink(ctx, params) + if err != nil { + return database.GitAuthLink{}, err + } + if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil { + return database.GitAuthLink{}, err + } + if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil { + return database.GitAuthLink{}, err + } + return link, nil +} + +func (db *dbCrypt) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) { + links, err := db.Store.GetGitAuthLinksByUserID(ctx, userID) + if err != nil { + return nil, err + } + for idx := range links { + if err := db.decryptField(&links[idx].OAuthAccessToken, links[idx].OAuthAccessTokenKeyID); err != nil { + return nil, err + } + if err := db.decryptField(&links[idx].OAuthRefreshToken, links[idx].OAuthRefreshTokenKeyID); err != nil { + return nil, err + } + } + return links, nil +} + +func (db *dbCrypt) UpdateGitAuthLink(ctx context.Context, params database.UpdateGitAuthLinkParams) (database.GitAuthLink, error) { + if err := db.encryptField(¶ms.OAuthAccessToken, ¶ms.OAuthAccessTokenKeyID); err != nil { + return database.GitAuthLink{}, err + } + if err := db.encryptField(¶ms.OAuthRefreshToken, ¶ms.OAuthRefreshTokenKeyID); err != nil { + return database.GitAuthLink{}, err + } + link, err := db.Store.UpdateGitAuthLink(ctx, params) + if err != nil { + return database.GitAuthLink{}, err + } + if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil { + return database.GitAuthLink{}, err + } + if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil { + return database.GitAuthLink{}, err + } + return link, nil +} + +func (db *dbCrypt) encryptField(field *string, digest *sql.NullString) error { + // If no cipher is loaded, then we can't encrypt anything! + if db.ciphers == nil || db.primaryCipherDigest == "" { + return nil + } + + if field == nil { + return xerrors.Errorf("developer error: encryptField called with nil field") + } + if digest == nil { + return xerrors.Errorf("developer error: encryptField called with nil digest") + } + + encrypted, err := db.ciphers[db.primaryCipherDigest].Encrypt([]byte(*field)) + if err != nil { + return err + } + // Base64 is used to support UTF-8 encoding in PostgreSQL. + *field = b64encode(encrypted) + *digest = sql.NullString{String: db.primaryCipherDigest, Valid: true} + return nil +} + +// decryptFields decrypts the given field using the key with the given digest. +// If the value fails to decrypt, sql.ErrNoRows will be returned. +func (db *dbCrypt) decryptField(field *string, digest sql.NullString) error { + if field == nil { + return xerrors.Errorf("developer error: decryptField called with nil field") + } + + if !digest.Valid || digest.String == "" { + // This field is not encrypted. + return nil + } + + key, ok := db.ciphers[digest.String] + if !ok { + return &DecryptFailedError{ + Inner: xerrors.Errorf("no cipher with digest %q", digest.String), + } + } + + data, err := b64decode(*field) + if err != nil { + // If it's not valid base64, we should complain loudly. + return &DecryptFailedError{ + Inner: xerrors.Errorf("malformed encrypted field %q: %w", *field, err), + } + } + decrypted, err := key.Decrypt(data) + if err != nil { + return &DecryptFailedError{Inner: err} + } + *field = string(decrypted) + return nil +} + +func (db *dbCrypt) ensureEncryptedWithRetry(ctx context.Context) error { + var err error + for i := 0; i < 3; i++ { + err = db.ensureEncrypted(ctx) + if err == nil { + return nil + } + // If we get a serialization error, then we need to retry. + if !database.IsSerializedError(err) { + return err + } + // otherwise, retry + } + // If we get here, then we ran out of retries + return err +} + +func (db *dbCrypt) ensureEncrypted(ctx context.Context) error { + return db.InTx(func(s database.Store) error { + // Attempt to read the encrypted test fields of the currently active keys. + ks, err := s.GetDBCryptKeys(ctx) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + return err + } + + var highestNumber int32 + var activeCipherFound bool + for _, k := range ks { + // If our primary key has been revoked, then we can't do anything. + if k.RevokedKeyDigest.Valid && k.RevokedKeyDigest.String == db.primaryCipherDigest { + return xerrors.Errorf("primary encryption key %q has been revoked", db.primaryCipherDigest) + } + + if k.ActiveKeyDigest.Valid && k.ActiveKeyDigest.String == db.primaryCipherDigest { + activeCipherFound = true + } + + if k.Number > highestNumber { + highestNumber = k.Number + } + } + + if activeCipherFound || len(db.ciphers) == 0 { + return nil + } + + // If we get here, then we have a new key that we need to insert. + return db.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{ + Number: highestNumber + 1, + ActiveKeyDigest: db.primaryCipherDigest, + Test: testValue, + }) + }, &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) +} diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go new file mode 100644 index 0000000000000..1b457373b28f8 --- /dev/null +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -0,0 +1,679 @@ +package dbcrypt + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/base64" + "io" + "testing" + + "github.com/golang/mock/gomock" + "github.com/lib/pq" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtestutil" +) + +func TestUserLinks(t *testing.T) { + t.Parallel() + ctx := context.Background() + + t.Run("InsertUserLink", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + user := dbgen.User(t, crypt, database.User{}) + link := dbgen.UserLink(t, crypt, database.UserLink{ + UserID: user.ID, + OAuthAccessToken: "access", + OAuthRefreshToken: "refresh", + }) + require.Equal(t, "access", link.OAuthAccessToken) + require.Equal(t, "refresh", link.OAuthRefreshToken) + require.Equal(t, ciphers[0].HexDigest(), link.OAuthAccessTokenKeyID.String) + require.Equal(t, ciphers[0].HexDigest(), link.OAuthRefreshTokenKeyID.String) + + rawLink, err := db.GetUserLinkByLinkedID(ctx, link.LinkedID) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access") + requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh") + }) + + t.Run("UpdateUserLink", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + user := dbgen.User(t, crypt, database.User{}) + link := dbgen.UserLink(t, crypt, database.UserLink{ + UserID: user.ID, + }) + + updated, err := crypt.UpdateUserLink(ctx, database.UpdateUserLinkParams{ + OAuthAccessToken: "access", + OAuthRefreshToken: "refresh", + UserID: link.UserID, + LoginType: link.LoginType, + }) + require.NoError(t, err) + require.Equal(t, "access", updated.OAuthAccessToken) + require.Equal(t, "refresh", updated.OAuthRefreshToken) + require.Equal(t, ciphers[0].HexDigest(), link.OAuthAccessTokenKeyID.String) + require.Equal(t, ciphers[0].HexDigest(), link.OAuthRefreshTokenKeyID.String) + + rawLink, err := db.GetUserLinkByLinkedID(ctx, link.LinkedID) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access") + requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh") + }) + + t.Run("GetUserLinkByLinkedID", func(t *testing.T) { + t.Parallel() + t.Run("OK", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + user := dbgen.User(t, crypt, database.User{}) + link := dbgen.UserLink(t, crypt, database.UserLink{ + UserID: user.ID, + OAuthAccessToken: "access", + OAuthRefreshToken: "refresh", + }) + + link, err := crypt.GetUserLinkByLinkedID(ctx, link.LinkedID) + require.NoError(t, err) + require.Equal(t, "access", link.OAuthAccessToken) + require.Equal(t, "refresh", link.OAuthRefreshToken) + require.Equal(t, ciphers[0].HexDigest(), link.OAuthAccessTokenKeyID.String) + require.Equal(t, ciphers[0].HexDigest(), link.OAuthRefreshTokenKeyID.String) + + rawLink, err := db.GetUserLinkByLinkedID(ctx, link.LinkedID) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access") + requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh") + }) + + t.Run("DecryptErr", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + user := dbgen.User(t, db, database.User{}) + link := dbgen.UserLink(t, db, database.UserLink{ + UserID: user.ID, + OAuthAccessToken: fakeBase64RandomData(t, 32), + OAuthRefreshToken: fakeBase64RandomData(t, 32), + OAuthAccessTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true}, + OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true}, + }) + + _, err := crypt.GetUserLinkByLinkedID(ctx, link.LinkedID) + require.Error(t, err, "expected an error") + var derr *DecryptFailedError + require.ErrorAs(t, err, &derr, "expected a decrypt error") + }) + }) + + t.Run("GetUserLinksByUserID", func(t *testing.T) { + t.Parallel() + + t.Run("OK", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + user := dbgen.User(t, crypt, database.User{}) + link := dbgen.UserLink(t, crypt, database.UserLink{ + UserID: user.ID, + OAuthAccessToken: "access", + OAuthRefreshToken: "refresh", + }) + links, err := crypt.GetUserLinksByUserID(ctx, link.UserID) + require.NoError(t, err) + require.Len(t, links, 1) + require.Equal(t, "access", links[0].OAuthAccessToken) + require.Equal(t, "refresh", links[0].OAuthRefreshToken) + require.Equal(t, ciphers[0].HexDigest(), links[0].OAuthAccessTokenKeyID.String) + require.Equal(t, ciphers[0].HexDigest(), links[0].OAuthRefreshTokenKeyID.String) + + rawLinks, err := db.GetUserLinksByUserID(ctx, link.UserID) + require.NoError(t, err) + require.Len(t, rawLinks, 1) + requireEncryptedEquals(t, ciphers[0], rawLinks[0].OAuthAccessToken, "access") + requireEncryptedEquals(t, ciphers[0], rawLinks[0].OAuthRefreshToken, "refresh") + }) + + t.Run("Empty", func(t *testing.T) { + t.Parallel() + _, crypt, _ := setup(t) + user := dbgen.User(t, crypt, database.User{}) + links, err := crypt.GetUserLinksByUserID(ctx, user.ID) + require.NoError(t, err) + require.Empty(t, links) + }) + + t.Run("DecryptErr", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + user := dbgen.User(t, db, database.User{}) + _ = dbgen.UserLink(t, db, database.UserLink{ + UserID: user.ID, + OAuthAccessToken: fakeBase64RandomData(t, 32), + OAuthRefreshToken: fakeBase64RandomData(t, 32), + OAuthAccessTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true}, + OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true}, + }) + _, err := crypt.GetUserLinksByUserID(ctx, user.ID) + require.Error(t, err, "expected an error") + var derr *DecryptFailedError + require.ErrorAs(t, err, &derr, "expected a decrypt error") + }) + }) + + t.Run("GetUserLinkByUserIDLoginType", func(t *testing.T) { + t.Parallel() + t.Run("OK", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + user := dbgen.User(t, crypt, database.User{}) + link := dbgen.UserLink(t, crypt, database.UserLink{ + UserID: user.ID, + OAuthAccessToken: "access", + OAuthRefreshToken: "refresh", + }) + + link, err := crypt.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{ + UserID: link.UserID, + LoginType: link.LoginType, + }) + require.NoError(t, err) + require.Equal(t, "access", link.OAuthAccessToken) + require.Equal(t, "refresh", link.OAuthRefreshToken) + require.Equal(t, ciphers[0].HexDigest(), link.OAuthAccessTokenKeyID.String) + require.Equal(t, ciphers[0].HexDigest(), link.OAuthRefreshTokenKeyID.String) + + rawLink, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{ + UserID: link.UserID, + LoginType: link.LoginType, + }) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access") + requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh") + }) + + t.Run("DecryptErr", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + user := dbgen.User(t, db, database.User{}) + link := dbgen.UserLink(t, db, database.UserLink{ + UserID: user.ID, + OAuthAccessToken: fakeBase64RandomData(t, 32), + OAuthRefreshToken: fakeBase64RandomData(t, 32), + OAuthAccessTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true}, + OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true}, + }) + + _, err := crypt.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{ + UserID: link.UserID, + LoginType: link.LoginType, + }) + require.Error(t, err, "expected an error") + var derr *DecryptFailedError + require.ErrorAs(t, err, &derr, "expected a decrypt error") + }) + }) +} + +func TestGitAuthLinks(t *testing.T) { + t.Parallel() + ctx := context.Background() + + t.Run("InsertGitAuthLink", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{ + OAuthAccessToken: "access", + OAuthRefreshToken: "refresh", + }) + require.Equal(t, "access", link.OAuthAccessToken) + require.Equal(t, "refresh", link.OAuthRefreshToken) + + link, err := db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + }) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], link.OAuthAccessToken, "access") + requireEncryptedEquals(t, ciphers[0], link.OAuthRefreshToken, "refresh") + }) + + t.Run("UpdateGitAuthLink", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{}) + updated, err := crypt.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + OAuthAccessToken: "access", + OAuthRefreshToken: "refresh", + }) + require.NoError(t, err) + require.Equal(t, "access", updated.OAuthAccessToken) + require.Equal(t, "refresh", updated.OAuthRefreshToken) + + link, err = db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + }) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], link.OAuthAccessToken, "access") + requireEncryptedEquals(t, ciphers[0], link.OAuthRefreshToken, "refresh") + }) + + t.Run("GetGitAuthLink", func(t *testing.T) { + t.Run("OK", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{ + OAuthAccessToken: "access", + OAuthRefreshToken: "refresh", + }) + link, err := db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{ + UserID: link.UserID, + ProviderID: link.ProviderID, + }) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], link.OAuthAccessToken, "access") + requireEncryptedEquals(t, ciphers[0], link.OAuthRefreshToken, "refresh") + }) + t.Run("DecryptErr", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + link := dbgen.GitAuthLink(t, db, database.GitAuthLink{ + OAuthAccessToken: fakeBase64RandomData(t, 32), + OAuthRefreshToken: fakeBase64RandomData(t, 32), + OAuthAccessTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true}, + OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true}, + }) + + _, err := crypt.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{ + UserID: link.UserID, + ProviderID: link.ProviderID, + }) + require.Error(t, err, "expected an error") + var derr *DecryptFailedError + require.ErrorAs(t, err, &derr, "expected a decrypt error") + }) + }) + + t.Run("GetGitAuthLinksByUserID", func(t *testing.T) { + t.Parallel() + + t.Run("OK", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + user := dbgen.User(t, crypt, database.User{}) + link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{ + UserID: user.ID, + OAuthAccessToken: "access", + OAuthRefreshToken: "refresh", + }) + links, err := crypt.GetGitAuthLinksByUserID(ctx, link.UserID) + require.NoError(t, err) + require.Len(t, links, 1) + require.Equal(t, "access", links[0].OAuthAccessToken) + require.Equal(t, "refresh", links[0].OAuthRefreshToken) + require.Equal(t, ciphers[0].HexDigest(), links[0].OAuthAccessTokenKeyID.String) + require.Equal(t, ciphers[0].HexDigest(), links[0].OAuthRefreshTokenKeyID.String) + + rawLinks, err := db.GetGitAuthLinksByUserID(ctx, link.UserID) + require.NoError(t, err) + require.Len(t, rawLinks, 1) + requireEncryptedEquals(t, ciphers[0], rawLinks[0].OAuthAccessToken, "access") + requireEncryptedEquals(t, ciphers[0], rawLinks[0].OAuthRefreshToken, "refresh") + }) + + t.Run("DecryptErr", func(t *testing.T) { + db, crypt, ciphers := setup(t) + user := dbgen.User(t, db, database.User{}) + link := dbgen.GitAuthLink(t, db, database.GitAuthLink{ + UserID: user.ID, + OAuthAccessToken: fakeBase64RandomData(t, 32), + OAuthRefreshToken: fakeBase64RandomData(t, 32), + OAuthAccessTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true}, + OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true}, + }) + _, err := crypt.GetGitAuthLinksByUserID(ctx, link.UserID) + require.Error(t, err, "expected an error") + var derr *DecryptFailedError + require.ErrorAs(t, err, &derr, "expected a decrypt error") + }) + }) +} + +func TestNew(t *testing.T) { + t.Parallel() + + t.Run("OK", func(t *testing.T) { + t.Parallel() + // Given: a cipher is loaded + cipher := initCipher(t) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + rawDB, _ := dbtestutil.NewDB(t) + + // Before: no keys should be present + keys, err := rawDB.GetDBCryptKeys(ctx) + require.NoError(t, err, "no error should be returned") + require.Empty(t, keys, "no keys should be present") + + // When: we init the crypt db + _, err = New(ctx, rawDB, cipher) + require.NoError(t, err) + + // Then: a new key is inserted + keys, err = rawDB.GetDBCryptKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 1, "one key should be present") + require.Equal(t, cipher.HexDigest(), keys[0].ActiveKeyDigest.String, "key digest mismatch") + require.Empty(t, keys[0].RevokedKeyDigest.String, "key should not be revoked") + requireEncryptedEquals(t, cipher, keys[0].Test, "coder") + }) + + t.Run("MissingKey", func(t *testing.T) { + t.Parallel() + + // Given: there exist two valid encryption keys + cipher1 := initCipher(t) + cipher2 := initCipher(t) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + rawDB, _ := dbtestutil.NewDB(t) + + // Given: key 1 is already present in the database + err := rawDB.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{ + Number: 1, + ActiveKeyDigest: cipher1.HexDigest(), + Test: fakeBase64RandomData(t, 32), + }) + require.NoError(t, err, "no error should be returned") + keys, err := rawDB.GetDBCryptKeys(ctx) + require.NoError(t, err, "no error should be returned") + require.Len(t, keys, 1, "one key should be present") + + // When: we init the crypt db with no keys + _, err = New(ctx, rawDB) + // Then: we error because we don't know how to decrypt the existing key + require.Error(t, err, "expected an error") + var derr *DecryptFailedError + require.ErrorAs(t, err, &derr, "expected a decrypt error") + + // When: we init the crypt db with key 2 + _, err = New(ctx, rawDB, cipher2) + + // Then: we error because the key is not revoked and we don't know how to decrypt it + require.Error(t, err, "expected an error") + require.ErrorAs(t, err, &derr, "expected a decrypt error") + + // When: the existing key is marked as having been revoked + err = rawDB.RevokeDBCryptKey(ctx, cipher1.HexDigest()) + require.NoError(t, err, "no error should be returned") + + // And: we init the crypt db with key 2 + _, err = New(ctx, rawDB, cipher2) + + // Then: we succeed + require.NoError(t, err) + + // And: key 2 should now be the active key + keys, err = rawDB.GetDBCryptKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 2, "two keys should be present") + require.EqualValues(t, keys[0].Number, 1, "key number mismatch") + require.Empty(t, keys[0].ActiveKeyDigest.String, "key should not be active") + require.Equal(t, cipher1.HexDigest(), keys[0].RevokedKeyDigest.String, "key should be revoked") + + require.EqualValues(t, keys[1].Number, 2, "key number mismatch") + require.Equal(t, cipher2.HexDigest(), keys[1].ActiveKeyDigest.String, "key digest mismatch") + require.Empty(t, keys[1].RevokedKeyDigest.String, "key should not be revoked") + requireEncryptedEquals(t, cipher2, keys[1].Test, "coder") + }) + + t.Run("NoKeys", func(t *testing.T) { + t.Parallel() + // Given: no cipher is loaded + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + rawDB, _ := dbtestutil.NewDB(t) + + keys, err := rawDB.GetDBCryptKeys(ctx) + require.NoError(t, err, "no error should be returned") + require.Empty(t, keys, "no keys should be present") + + // When: we init the crypt db with no ciphers + _, err = New(ctx, rawDB) + + // Then: it should succeed. + require.NoError(t, err, "dbcrypt.New should work with no keys against an unencrypted database") + + // Assert invariant: no keys are inserted + keys, err = rawDB.GetDBCryptKeys(ctx) + require.NoError(t, err, "no error should be returned") + require.Empty(t, keys, "no keys should be present") + + // Insert a key + require.NoError(t, rawDB.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{ + Number: 1, + ActiveKeyDigest: "whatever", + Test: fakeBase64RandomData(t, 32), + })) + + // This should fail as we do not know how to decrypt the key: + _, err = New(ctx, rawDB) + require.Error(t, err) + // Until we revoke the key: + require.NoError(t, rawDB.RevokeDBCryptKey(ctx, "whatever")) + _, err = New(ctx, rawDB) + require.NoError(t, err, "the above should still hold if the key is revoked") + }) + + t.Run("PrimaryRevoked", func(t *testing.T) { + t.Parallel() + // Given: a cipher is loaded + cipher := initCipher(t) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + rawDB, _ := dbtestutil.NewDB(t) + + // And: the cipher is revoked before we init the crypt db + err := rawDB.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{ + Number: 1, + ActiveKeyDigest: cipher.HexDigest(), + Test: fakeBase64RandomData(t, 32), + }) + require.NoError(t, err, "no error should be returned") + err = rawDB.RevokeDBCryptKey(ctx, cipher.HexDigest()) + require.NoError(t, err, "no error should be returned") + + // Then: when we init the crypt db, we error because the key is revoked + _, err = New(ctx, rawDB, cipher) + require.Error(t, err) + require.ErrorContains(t, err, "has been revoked") + }) + + t.Run("Retry", func(t *testing.T) { + t.Parallel() + // Given: a cipher is loaded + cipher := initCipher(t) + ctx, cancel := context.WithCancel(context.Background()) + testVal, err := cipher.Encrypt([]byte("coder")) + key := database.DBCryptKey{ + Number: 1, + ActiveKeyDigest: sql.NullString{String: cipher.HexDigest(), Valid: true}, + Test: b64encode(testVal), + } + require.NoError(t, err) + t.Cleanup(cancel) + + // And: a database that returns an error once when we try to serialize a key + ctrl := gomock.NewController(t) + mockDB := dbmock.NewMockStore(ctrl) + + gomock.InOrder( + // First try: we get a serialization error. + expectInTx(mockDB), + mockDB.EXPECT().GetDBCryptKeys(gomock.Any()).Times(1).Return([]database.DBCryptKey{}, nil), + mockDB.EXPECT().InsertDBCryptKey(gomock.Any(), gomock.Any()).Times(1).Return(&pq.Error{Code: "40001"}), + // Second try: we get the key we wanted to insert initially. + expectInTx(mockDB), + mockDB.EXPECT().GetDBCryptKeys(gomock.Any()).Times(1).Return([]database.DBCryptKey{key}, nil), + ) + + _, err = New(ctx, mockDB, cipher) + require.NoError(t, err) + }) +} + +func TestEncryptDecryptField(t *testing.T) { + t.Parallel() + t.Run("OK", func(t *testing.T) { + t.Parallel() + _, cryptDB, ciphers := setup(t) + field := "coder" + digest := sql.NullString{} + require.NoError(t, cryptDB.encryptField(&field, &digest)) + require.Equal(t, ciphers[0].HexDigest(), digest.String) + requireEncryptedEquals(t, ciphers[0], field, "coder") + require.NoError(t, cryptDB.decryptField(&field, digest)) + require.Equal(t, "coder", field) + }) + + t.Run("NoKeys", func(t *testing.T) { + t.Parallel() + // With no keys, encryption and decryption are both no-ops. + _, cryptDB := setupNoCiphers(t) + field := "coder" + digest := sql.NullString{} + require.NoError(t, cryptDB.encryptField(&field, &digest)) + require.Empty(t, digest.String) + require.Equal(t, "coder", field) + require.NoError(t, cryptDB.decryptField(&field, digest)) + require.Equal(t, "coder", field) + }) + + t.Run("MissingKey", func(t *testing.T) { + t.Parallel() + _, cryptDB, ciphers := setup(t) + field := "coder" + digest := sql.NullString{} + err := cryptDB.encryptField(&field, &digest) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], field, "coder") + require.Equal(t, ciphers[0].HexDigest(), digest.String) + require.True(t, digest.Valid) + + digest = sql.NullString{String: "missing", Valid: true} + var derr *DecryptFailedError + err = cryptDB.decryptField(&field, digest) + require.Error(t, err) + require.ErrorAs(t, err, &derr) + }) + + t.Run("CantEncryptOrDecryptNil", func(t *testing.T) { + t.Parallel() + _, cryptDB, _ := setup(t) + require.ErrorContains(t, cryptDB.encryptField(nil, nil), "developer error") + require.ErrorContains(t, cryptDB.decryptField(nil, sql.NullString{}), "developer error") + }) + + t.Run("EncryptEmptyString", func(t *testing.T) { + t.Parallel() + _, cryptDB, ciphers := setup(t) + field := "" + digest := sql.NullString{} + require.NoError(t, cryptDB.encryptField(&field, &digest)) + requireEncryptedEquals(t, ciphers[0], field, "") + require.Equal(t, ciphers[0].HexDigest(), digest.String) + require.NoError(t, cryptDB.decryptField(&field, digest)) + require.Empty(t, field) + }) + + t.Run("DecryptEmptyString", func(t *testing.T) { + t.Parallel() + _, cryptDB, ciphers := setup(t) + field := "" + digest := sql.NullString{String: ciphers[0].HexDigest(), Valid: true} + err := cryptDB.decryptField(&field, digest) + // Currently this has to fail because the ciphertext must at least + // have a nonce. This may need to be changed depending on future + // ciphers. + require.ErrorContains(t, err, "ciphertext too short") + }) + + t.Run("InvalidBase64", func(t *testing.T) { + t.Parallel() + _, cryptDB, ciphers := setup(t) + field := "not valid base64" + digest := sql.NullString{String: ciphers[0].HexDigest(), Valid: true} + err := cryptDB.decryptField(&field, digest) + require.ErrorContains(t, err, "illegal base64 data") + }) +} + +func expectInTx(mdb *dbmock.MockStore) *gomock.Call { + return mdb.EXPECT().InTx(gomock.Any(), gomock.Any()).Times(1).DoAndReturn( + func(f func(store database.Store) error, _ *sql.TxOptions) error { + return f(mdb) + }, + ) +} + +func requireEncryptedEquals(t *testing.T, c Cipher, value, expected string) { + t.Helper() + data, err := base64.StdEncoding.DecodeString(value) + require.NoError(t, err, "invalid base64") + got, err := c.Decrypt(data) + require.NoError(t, err, "failed to decrypt data") + require.Equal(t, expected, string(got), "decrypted data does not match") +} + +func initCipher(t *testing.T) *aes256 { + t.Helper() + key := make([]byte, 32) // AES-256 key size is 32 bytes + _, err := io.ReadFull(rand.Reader, key) + require.NoError(t, err) + c, err := cipherAES256(key) + require.NoError(t, err) + return c +} + +func setup(t *testing.T) (db database.Store, cryptDB *dbCrypt, cs []Cipher) { + t.Helper() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + rawDB, _ := dbtestutil.NewDB(t) + + cs = append(cs, initCipher(t)) + cdb, err := New(ctx, rawDB, cs...) + require.NoError(t, err) + cryptDB, ok := cdb.(*dbCrypt) + require.True(t, ok) + + return rawDB, cryptDB, cs +} + +func setupNoCiphers(t *testing.T) (db database.Store, cryptodb *dbCrypt) { + t.Helper() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + rawDB, _ := dbtestutil.NewDB(t) + cdb, err := New(ctx, rawDB) + require.NoError(t, err) + cryptDB, ok := cdb.(*dbCrypt) + require.True(t, ok) + return rawDB, cryptDB +} + +func fakeBase64RandomData(t *testing.T, n int) string { + t.Helper() + b := make([]byte, n) + _, err := io.ReadFull(rand.Reader, b) + require.NoError(t, err) + return base64.StdEncoding.EncodeToString(b) +} diff --git a/enterprise/dbcrypt/doc.go b/enterprise/dbcrypt/doc.go new file mode 100644 index 0000000000000..39429f059a200 --- /dev/null +++ b/enterprise/dbcrypt/doc.go @@ -0,0 +1,34 @@ +// Package dbcrypt provides a database.Store wrapper that encrypts/decrypts +// values stored at rest in the database. +// +// Encryption is done using Ciphers, which is an abstraction over a set of +// encryption keys. Each key has a unique identifier, which is used to +// uniquely identify the key whilst maintaining secrecy. +// +// Currently, AES-256-GCM is the only implemented cipher mode. +// The Cipher is currently used to encrypt/decrypt the following fields: +// - database.UserLink.OAuthAccessToken +// - database.UserLink.OAuthRefreshToken +// - database.GitAuthLink.OAuthAccessToken +// - database.GitAuthLink.OAuthRefreshToken +// - database.DBCryptSentinelValue +// +// Multiple ciphers can be provided to support key rotation. The primary cipher +// is used to encrypt and decrypt all data. Secondary ciphers are only used +// for decryption and, as a general rule, should only be active when rotating +// keys. +// +// Encryption keys are stored in the database in the table `dbcrypt_keys`. +// The table has the following schema: +// - number: the key number. This is used to avoid conflicts when rotating keys. +// - created_at: the time the key was created. +// - active_key_digest: the SHA256 digest of the active key. If null, the key has been revoked. +// - revoked_key_digest: the SHA256 digest of the revoked key. If null, the key has not been revoked. +// - revoked_at: the time the key was revoked. If null, the key has not been revoked. +// - test: the encrypted value of the string "coder". This is used to ensure that the key is valid. +// +// Encrypted fields are stored in the database as a base64-encoded string. +// Each encrypted column MUST have a corresponding _key_id column that is a foreign key +// reference to `dbcrypt_keys.active_key_digest`. This ensures that a key cannot be +// revoked until all rows that use that key have been migrated to a new key. +package dbcrypt