From 8d5701a2ba5ad2b96e604775b41b95d3835f7751 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Wed, 30 Aug 2023 09:01:48 +0000 Subject: [PATCH] feat(coderd): add dbcrypt package This commit builds upon the previous work in #7959: - Moved dbcrypt package to enterprise/dbcrypt - Modified original dbcrypt behaviour to not delete un-decryptable rows. - Added a table dbcrypt_sentinel used to determine database encryption status. - Added support for multiple encryption keys in dbcrypt. NOTE: This is part 1 of a 2-part PR. This PR focuses mainly on the dbcrypt and database packages. A separate PR will add the required plumbing to integrate this into enterprise/coderd properly. Co-authored-by: Kyle Carberry --- coderd/database/dbauthz/dbauthz.go | 28 ++ coderd/database/dbfake/dbfake.go | 42 +++ coderd/database/dbgen/dbgen.go | 4 +- coderd/database/dbmetrics/dbmetrics.go | 28 ++ coderd/database/dbmock/dbmock.go | 59 ++++ coderd/database/dump.sql | 14 + .../000153_dbcrypt_sentinel_value.down.sql | 1 + .../000153_dbcrypt_sentinel_value.up.sql | 8 + coderd/database/migrations/migrate_test.go | 1 + coderd/database/models.go | 8 + coderd/database/querier.go | 4 + coderd/database/queries.sql.go | 89 ++++++ coderd/database/queries/dbcrypt.sql | 5 + coderd/database/queries/gitauth.sql | 4 + coderd/database/queries/user_links.sql | 3 + coderd/database/unique_constraint.go | 1 + enterprise/dbcrypt/cipher.go | 126 ++++++++ enterprise/dbcrypt/cipher_test.go | 144 +++++++++ enterprise/dbcrypt/dbcrypt.go | 296 ++++++++++++++++++ enterprise/dbcrypt/dbcrypt_test.go | 289 +++++++++++++++++ 20 files changed, 1152 insertions(+), 2 deletions(-) create mode 100644 coderd/database/migrations/000153_dbcrypt_sentinel_value.down.sql create mode 100644 coderd/database/migrations/000153_dbcrypt_sentinel_value.up.sql create mode 100644 coderd/database/queries/dbcrypt.sql create mode 100644 enterprise/dbcrypt/cipher.go create mode 100644 enterprise/dbcrypt/cipher_test.go create mode 100644 enterprise/dbcrypt/dbcrypt.go create mode 100644 enterprise/dbcrypt/dbcrypt_test.go diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 9115e9b5ac184..10ec5c1a5c287 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -828,6 +828,13 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI return q.db.GetAuthorizationUserRoles(ctx, userID) } +func (q *querier) GetDBCryptSentinelValue(ctx context.Context) (string, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return "", err + } + return q.db.GetDBCryptSentinelValue(ctx) +} + func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) { if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { return "", err @@ -904,6 +911,13 @@ func (q *querier) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLin return fetch(q.log, q.auth, q.db.GetGitAuthLink)(ctx, arg) } +func (q *querier) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetGitAuthLinksByUserID(ctx, userID) +} + func (q *querier) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) { return fetch(q.log, q.auth, q.db.GetGitSSHKey)(ctx, userID) } @@ -1472,6 +1486,13 @@ func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database return q.db.GetUserLinkByUserIDLoginType(ctx, arg) } +func (q *querier) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { + return nil, err + } + return q.db.GetUserLinksByUserID(ctx, userID) +} + func (q *querier) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) { // This does the filtering in SQL. prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type) @@ -2134,6 +2155,13 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg) } +func (q *querier) SetDBCryptSentinelValue(ctx context.Context, value string) error { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { + return err + } + return q.db.SetDBCryptSentinelValue(ctx, value) +} + func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) { return q.db.TryAcquireLock(ctx, id) } diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index f7db817e65e96..e4f68815488ce 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -44,6 +44,7 @@ func New() database.Store { organizationMembers: make([]database.OrganizationMember, 0), organizations: make([]database.Organization, 0), users: make([]database.User, 0), + dbcryptSentinelValue: nil, gitAuthLinks: make([]database.GitAuthLink, 0), groups: make([]database.Group, 0), groupMembers: make([]database.GroupMember, 0), @@ -116,6 +117,7 @@ type data struct { // New tables workspaceAgentStats []database.WorkspaceAgentStat auditLogs []database.AuditLog + dbcryptSentinelValue *string files []database.File gitAuthLinks []database.GitAuthLink gitSSHKey []database.GitSSHKey @@ -1150,6 +1152,15 @@ func (q *FakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U }, nil } +func (q *FakeQuerier) GetDBCryptSentinelValue(_ context.Context) (string, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + if q.dbcryptSentinelValue == nil { + return "", sql.ErrNoRows + } + return *q.dbcryptSentinelValue, nil +} + func (q *FakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -1392,6 +1403,18 @@ func (q *FakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthL return database.GitAuthLink{}, sql.ErrNoRows } +func (q *FakeQuerier) GetGitAuthLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + gals := make([]database.GitAuthLink, 0) + for _, gal := range q.gitAuthLinks { + if gal.UserID == userID { + gals = append(gals, gal) + } + } + return gals, nil +} + func (q *FakeQuerier) GetGitSSHKey(_ context.Context, userID uuid.UUID) (database.GitSSHKey, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -2832,6 +2855,18 @@ func (q *FakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params dat return database.UserLink{}, sql.ErrNoRows } +func (q *FakeQuerier) GetUserLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.UserLink, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + uls := make([]database.UserLink, 0) + for _, ul := range q.userLinks { + if ul.UserID == userID { + uls = append(uls, ul) + } + } + return uls, nil +} + func (q *FakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.GetUsersRow, error) { if err := validateDatabaseType(params); err != nil { return nil, err @@ -4791,6 +4826,13 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg return database.WorkspaceProxy{}, sql.ErrNoRows } +func (q *FakeQuerier) SetDBCryptSentinelValue(_ context.Context, value string) error { + q.mutex.Lock() + defer q.mutex.Unlock() + q.dbcryptSentinelValue = &value + return nil +} + func (*FakeQuerier) TryAcquireLock(_ context.Context, _ int64) (bool, error) { return false, xerrors.New("TryAcquireLock must only be called within a transaction") } diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 2c3088b9be3b0..cba2e093c7514 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -473,7 +473,7 @@ func UserLink(t testing.TB, db database.Store, orig database.UserLink) database. LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub), LinkedID: takeFirst(orig.LinkedID), OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), - OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), + OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()), OAuthExpiry: takeFirst(orig.OAuthExpiry, database.Now().Add(time.Hour*24)), }) @@ -486,7 +486,7 @@ func GitAuthLink(t testing.TB, db database.Store, orig database.GitAuthLink) dat ProviderID: takeFirst(orig.ProviderID, uuid.New().String()), UserID: takeFirst(orig.UserID, uuid.New()), OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), - OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()), + OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()), OAuthExpiry: takeFirst(orig.OAuthExpiry, database.Now().Add(time.Hour*24)), CreatedAt: takeFirst(orig.CreatedAt, database.Now()), UpdatedAt: takeFirst(orig.UpdatedAt, database.Now()), diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 8526eb4da1078..b94fa6270f5fc 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -279,6 +279,13 @@ func (m metricsStore) GetAuthorizationUserRoles(ctx context.Context, userID uuid return row, err } +func (m metricsStore) GetDBCryptSentinelValue(ctx context.Context) (string, error) { + start := time.Now() + r0, r1 := m.s.GetDBCryptSentinelValue(ctx) + m.queryLatencies.WithLabelValues("GetDBCryptSentinelValue").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetDERPMeshKey(ctx context.Context) (string, error) { start := time.Now() key, err := m.s.GetDERPMeshKey(ctx) @@ -349,6 +356,13 @@ func (m metricsStore) GetGitAuthLink(ctx context.Context, arg database.GetGitAut return link, err } +func (m metricsStore) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) { + start := time.Now() + r0, r1 := m.s.GetGitAuthLinksByUserID(ctx, userID) + m.queryLatencies.WithLabelValues("GetGitAuthLinksByUserID").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) { start := time.Now() key, err := m.s.GetGitSSHKey(ctx, userID) @@ -774,6 +788,13 @@ func (m metricsStore) GetUserLinkByUserIDLoginType(ctx context.Context, arg data return link, err } +func (m metricsStore) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) { + start := time.Now() + r0, r1 := m.s.GetUserLinksByUserID(ctx, userID) + m.queryLatencies.WithLabelValues("GetUserLinksByUserID").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) { start := time.Now() users, err := m.s.GetUsers(ctx, arg) @@ -1320,6 +1341,13 @@ func (m metricsStore) RegisterWorkspaceProxy(ctx context.Context, arg database.R return proxy, err } +func (m metricsStore) SetDBCryptSentinelValue(ctx context.Context, value string) error { + start := time.Now() + r0 := m.s.SetDBCryptSentinelValue(ctx, value) + m.queryLatencies.WithLabelValues("SetDBCryptSentinelValue").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) { start := time.Now() ok, err := m.s.TryAcquireLock(ctx, pgTryAdvisoryXactLock) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index b0ae7955a458d..bb05528036946 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -506,6 +506,21 @@ func (mr *MockStoreMockRecorder) GetAuthorizedWorkspaces(arg0, arg1, arg2 interf return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedWorkspaces", reflect.TypeOf((*MockStore)(nil).GetAuthorizedWorkspaces), arg0, arg1, arg2) } +// GetDBCryptSentinelValue mocks base method. +func (m *MockStore) GetDBCryptSentinelValue(arg0 context.Context) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDBCryptSentinelValue", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetDBCryptSentinelValue indicates an expected call of GetDBCryptSentinelValue. +func (mr *MockStoreMockRecorder) GetDBCryptSentinelValue(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDBCryptSentinelValue", reflect.TypeOf((*MockStore)(nil).GetDBCryptSentinelValue), arg0) +} + // GetDERPMeshKey mocks base method. func (m *MockStore) GetDERPMeshKey(arg0 context.Context) (string, error) { m.ctrl.T.Helper() @@ -656,6 +671,21 @@ func (mr *MockStoreMockRecorder) GetGitAuthLink(arg0, arg1 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGitAuthLink", reflect.TypeOf((*MockStore)(nil).GetGitAuthLink), arg0, arg1) } +// GetGitAuthLinksByUserID mocks base method. +func (m *MockStore) GetGitAuthLinksByUserID(arg0 context.Context, arg1 uuid.UUID) ([]database.GitAuthLink, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetGitAuthLinksByUserID", arg0, arg1) + ret0, _ := ret[0].([]database.GitAuthLink) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetGitAuthLinksByUserID indicates an expected call of GetGitAuthLinksByUserID. +func (mr *MockStoreMockRecorder) GetGitAuthLinksByUserID(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGitAuthLinksByUserID", reflect.TypeOf((*MockStore)(nil).GetGitAuthLinksByUserID), arg0, arg1) +} + // GetGitSSHKey mocks base method. func (m *MockStore) GetGitSSHKey(arg0 context.Context, arg1 uuid.UUID) (database.GitSSHKey, error) { m.ctrl.T.Helper() @@ -1601,6 +1631,21 @@ func (mr *MockStoreMockRecorder) GetUserLinkByUserIDLoginType(arg0, arg1 interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserLinkByUserIDLoginType", reflect.TypeOf((*MockStore)(nil).GetUserLinkByUserIDLoginType), arg0, arg1) } +// GetUserLinksByUserID mocks base method. +func (m *MockStore) GetUserLinksByUserID(arg0 context.Context, arg1 uuid.UUID) ([]database.UserLink, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserLinksByUserID", arg0, arg1) + ret0, _ := ret[0].([]database.UserLink) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserLinksByUserID indicates an expected call of GetUserLinksByUserID. +func (mr *MockStoreMockRecorder) GetUserLinksByUserID(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserLinksByUserID", reflect.TypeOf((*MockStore)(nil).GetUserLinksByUserID), arg0, arg1) +} + // GetUsers mocks base method. func (m *MockStore) GetUsers(arg0 context.Context, arg1 database.GetUsersParams) ([]database.GetUsersRow, error) { m.ctrl.T.Helper() @@ -2789,6 +2834,20 @@ func (mr *MockStoreMockRecorder) RegisterWorkspaceProxy(arg0, arg1 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterWorkspaceProxy", reflect.TypeOf((*MockStore)(nil).RegisterWorkspaceProxy), arg0, arg1) } +// SetDBCryptSentinelValue mocks base method. +func (m *MockStore) SetDBCryptSentinelValue(arg0 context.Context, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDBCryptSentinelValue", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDBCryptSentinelValue indicates an expected call of SetDBCryptSentinelValue. +func (mr *MockStoreMockRecorder) SetDBCryptSentinelValue(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDBCryptSentinelValue", reflect.TypeOf((*MockStore)(nil).SetDBCryptSentinelValue), arg0, arg1) +} + // TryAcquireLock mocks base method. func (m *MockStore) TryAcquireLock(arg0 context.Context, arg1 int64) (bool, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 0c16610c89af8..e266af7a08a7f 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -267,6 +267,17 @@ CREATE TABLE audit_logs ( resource_icon text NOT NULL ); +CREATE TABLE dbcrypt_sentinel ( + only_one integer GENERATED ALWAYS AS (1) STORED, + val text DEFAULT ''::text NOT NULL +); + +COMMENT ON TABLE dbcrypt_sentinel IS 'A table used to determine if the database is encrypted'; + +COMMENT ON COLUMN dbcrypt_sentinel.only_one IS 'Ensures that only one row exists in the table.'; + +COMMENT ON COLUMN dbcrypt_sentinel.val IS 'Used to determine if the database is encrypted.'; + CREATE TABLE files ( hash character varying(64) NOT NULL, created_at timestamp with time zone NOT NULL, @@ -1028,6 +1039,9 @@ ALTER TABLE ONLY api_keys ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id); +ALTER TABLE ONLY dbcrypt_sentinel + ADD CONSTRAINT dbcrypt_sentinel_only_one_key UNIQUE (only_one); + ALTER TABLE ONLY files ADD CONSTRAINT files_hash_created_by_key UNIQUE (hash, created_by); diff --git a/coderd/database/migrations/000153_dbcrypt_sentinel_value.down.sql b/coderd/database/migrations/000153_dbcrypt_sentinel_value.down.sql new file mode 100644 index 0000000000000..615b2c087227b --- /dev/null +++ b/coderd/database/migrations/000153_dbcrypt_sentinel_value.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS dbcrypt_sentinel; diff --git a/coderd/database/migrations/000153_dbcrypt_sentinel_value.up.sql b/coderd/database/migrations/000153_dbcrypt_sentinel_value.up.sql new file mode 100644 index 0000000000000..8c46a02ee1301 --- /dev/null +++ b/coderd/database/migrations/000153_dbcrypt_sentinel_value.up.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS dbcrypt_sentinel ( + only_one integer GENERATED ALWAYS AS (1) STORED UNIQUE, + val text NOT NULL DEFAULT ''::text +); + +COMMENT ON TABLE dbcrypt_sentinel IS 'A table used to determine if the database is encrypted'; +COMMENT ON COLUMN dbcrypt_sentinel.only_one IS 'Ensures that only one row exists in the table.'; +COMMENT ON COLUMN dbcrypt_sentinel.val IS 'Used to determine if the database is encrypted.'; diff --git a/coderd/database/migrations/migrate_test.go b/coderd/database/migrations/migrate_test.go index a138e58bac05f..de3d3995fb369 100644 --- a/coderd/database/migrations/migrate_test.go +++ b/coderd/database/migrations/migrate_test.go @@ -266,6 +266,7 @@ func TestMigrateUpWithFixtures(t *testing.T) { "template_version_parameters", "workspace_build_parameters", "template_version_variables", + "dbcrypt_sentinel", // having zero rows is a valid state for this table } s := &tableStats{s: make(map[string]int)} diff --git a/coderd/database/models.go b/coderd/database/models.go index 22dd6b257997c..d78285f01ddcd 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -1524,6 +1524,14 @@ type AuditLog struct { ResourceIcon string `db:"resource_icon" json:"resource_icon"` } +// A table used to determine if the database is encrypted +type DbcryptSentinel struct { + // Ensures that only one row exists in the table. + OnlyOne sql.NullInt32 `db:"only_one" json:"only_one"` + // Used to determine if the database is encrypted. + Val string `db:"val" json:"val"` +} + type File struct { Hash string `db:"hash" json:"hash"` CreatedAt time.Time `db:"created_at" json:"created_at"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 520266bd1d25c..b4c4469bfe55b 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -58,6 +58,7 @@ type sqlcQuerier interface { // This function returns roles for authorization purposes. Implied member roles // are included. GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error) + GetDBCryptSentinelValue(ctx context.Context) (string, error) GetDERPMeshKey(ctx context.Context) (string, error) GetDefaultProxyConfig(ctx context.Context) (GetDefaultProxyConfigRow, error) GetDeploymentDAUs(ctx context.Context, tzOffset int32) ([]GetDeploymentDAUsRow, error) @@ -69,6 +70,7 @@ type sqlcQuerier interface { // Get all templates that use a file. GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]GetFileTemplatesRow, error) GetGitAuthLink(ctx context.Context, arg GetGitAuthLinkParams) (GitAuthLink, error) + GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]GitAuthLink, error) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error) GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error) GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error) @@ -150,6 +152,7 @@ type sqlcQuerier interface { GetUserLatencyInsights(ctx context.Context, arg GetUserLatencyInsightsParams) ([]GetUserLatencyInsightsRow, error) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (UserLink, error) GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error) + GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error) // This will never return deleted users. GetUsers(ctx context.Context, arg GetUsersParams) ([]GetUsersRow, error) // This shouldn't check for deleted, because it's frequently used @@ -247,6 +250,7 @@ type sqlcQuerier interface { InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error) InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error) RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error) + SetDBCryptSentinelValue(ctx context.Context, val string) error // Non blocking lock. Returns true if the lock was acquired, false otherwise. // // This must be called from within a transaction. The lock will be automatically diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 95f357cb69835..04d8a545325e0 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -636,6 +636,26 @@ func (q *sqlQuerier) InsertAuditLog(ctx context.Context, arg InsertAuditLogParam return i, err } +const getDBCryptSentinelValue = `-- name: GetDBCryptSentinelValue :one +SELECT val FROM dbcrypt_sentinel LIMIT 1 +` + +func (q *sqlQuerier) GetDBCryptSentinelValue(ctx context.Context) (string, error) { + row := q.db.QueryRowContext(ctx, getDBCryptSentinelValue) + var val string + err := row.Scan(&val) + return val, err +} + +const setDBCryptSentinelValue = `-- name: SetDBCryptSentinelValue :exec +INSERT INTO dbcrypt_sentinel (val) VALUES ($1) ON CONFLICT (only_one) DO UPDATE SET val = excluded.val +` + +func (q *sqlQuerier) SetDBCryptSentinelValue(ctx context.Context, val string) error { + _, err := q.db.ExecContext(ctx, setDBCryptSentinelValue, val) + return err +} + const getFileByHashAndCreator = `-- name: GetFileByHashAndCreator :one SELECT hash, created_at, created_by, mimetype, data, id @@ -823,6 +843,41 @@ func (q *sqlQuerier) GetGitAuthLink(ctx context.Context, arg GetGitAuthLinkParam return i, err } +const getGitAuthLinksByUserID = `-- name: GetGitAuthLinksByUserID :many +SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry FROM git_auth_links WHERE user_id = $1 +` + +func (q *sqlQuerier) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]GitAuthLink, error) { + rows, err := q.db.QueryContext(ctx, getGitAuthLinksByUserID, userID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GitAuthLink + for rows.Next() { + var i GitAuthLink + if err := rows.Scan( + &i.ProviderID, + &i.UserID, + &i.CreatedAt, + &i.UpdatedAt, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthExpiry, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const insertGitAuthLink = `-- name: InsertGitAuthLink :one INSERT INTO git_auth_links ( provider_id, @@ -5499,6 +5554,40 @@ func (q *sqlQuerier) GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUs return i, err } +const getUserLinksByUserID = `-- name: GetUserLinksByUserID :many +SELECT user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry FROM user_links WHERE user_id = $1 +` + +func (q *sqlQuerier) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error) { + rows, err := q.db.QueryContext(ctx, getUserLinksByUserID, userID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []UserLink + for rows.Next() { + var i UserLink + if err := rows.Scan( + &i.UserID, + &i.LoginType, + &i.LinkedID, + &i.OAuthAccessToken, + &i.OAuthRefreshToken, + &i.OAuthExpiry, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const insertUserLink = `-- name: InsertUserLink :one INSERT INTO user_links ( diff --git a/coderd/database/queries/dbcrypt.sql b/coderd/database/queries/dbcrypt.sql new file mode 100644 index 0000000000000..780a4bc0952c4 --- /dev/null +++ b/coderd/database/queries/dbcrypt.sql @@ -0,0 +1,5 @@ +-- name: GetDBCryptSentinelValue :one +SELECT val FROM dbcrypt_sentinel LIMIT 1; + +-- name: SetDBCryptSentinelValue :exec +INSERT INTO dbcrypt_sentinel (val) VALUES ($1) ON CONFLICT (only_one) DO UPDATE SET val = excluded.val; diff --git a/coderd/database/queries/gitauth.sql b/coderd/database/queries/gitauth.sql index a35de98a08908..b83b481cf4672 100644 --- a/coderd/database/queries/gitauth.sql +++ b/coderd/database/queries/gitauth.sql @@ -1,6 +1,10 @@ -- name: GetGitAuthLink :one SELECT * FROM git_auth_links WHERE provider_id = $1 AND user_id = $2; +-- name: GetGitAuthLinksByUserID :many +SELECT * FROM git_auth_links WHERE user_id = $1; + + -- name: InsertGitAuthLink :one INSERT INTO git_auth_links ( provider_id, diff --git a/coderd/database/queries/user_links.sql b/coderd/database/queries/user_links.sql index 2390cb9782b30..69cd058b56caf 100644 --- a/coderd/database/queries/user_links.sql +++ b/coderd/database/queries/user_links.sql @@ -14,6 +14,9 @@ FROM WHERE user_id = $1 AND login_type = $2; +-- name: GetUserLinksByUserID :many +SELECT * FROM user_links WHERE user_id = $1; + -- name: InsertUserLink :one INSERT INTO user_links ( diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index 294b4b12d51af..ba238421fcd93 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -6,6 +6,7 @@ type UniqueConstraint string // UniqueConstraint enums. const ( + UniqueDbcryptSentinelOnlyOneKey UniqueConstraint = "dbcrypt_sentinel_only_one_key" // ALTER TABLE ONLY dbcrypt_sentinel ADD CONSTRAINT dbcrypt_sentinel_only_one_key UNIQUE (only_one); UniqueFilesHashCreatedByKey UniqueConstraint = "files_hash_created_by_key" // ALTER TABLE ONLY files ADD CONSTRAINT files_hash_created_by_key UNIQUE (hash, created_by); UniqueGitAuthLinksProviderIDUserIDKey UniqueConstraint = "git_auth_links_provider_id_user_id_key" // ALTER TABLE ONLY git_auth_links ADD CONSTRAINT git_auth_links_provider_id_user_id_key UNIQUE (provider_id, user_id); UniqueGroupMembersUserIDGroupIDKey UniqueConstraint = "group_members_user_id_group_id_key" // ALTER TABLE ONLY group_members ADD CONSTRAINT group_members_user_id_group_id_key UNIQUE (user_id, group_id); diff --git a/enterprise/dbcrypt/cipher.go b/enterprise/dbcrypt/cipher.go new file mode 100644 index 0000000000000..8c7b870e07660 --- /dev/null +++ b/enterprise/dbcrypt/cipher.go @@ -0,0 +1,126 @@ +package dbcrypt + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "fmt" + "io" + + "golang.org/x/xerrors" +) + +type Cipher interface { + Encrypt([]byte) ([]byte, error) + Decrypt([]byte) ([]byte, error) + HexDigest() string +} + +// CipherAES256 returns a new AES-256 cipher. +func CipherAES256(key []byte) (*AES256, error) { + if len(key) != 32 { + return nil, xerrors.Errorf("key must be 32 bytes") + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + aead, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + digest := fmt.Sprintf("%x", sha256.Sum256(key))[:7] + return &AES256{aead: aead, digest: digest}, nil +} + +type AES256 struct { + aead cipher.AEAD + // digest is the first 7 bytes of the hex-encoded SHA-256 digest of aead. + digest string +} + +func (a *AES256) Encrypt(plaintext []byte) ([]byte, error) { + nonce := make([]byte, a.aead.NonceSize()) + _, err := io.ReadFull(rand.Reader, nonce) + if err != nil { + return nil, err + } + dst := make([]byte, len(nonce)) + copy(dst, nonce) + return a.aead.Seal(dst, nonce, plaintext, nil), nil +} + +func (a *AES256) Decrypt(ciphertext []byte) ([]byte, error) { + if len(ciphertext) < a.aead.NonceSize() { + return nil, xerrors.Errorf("ciphertext too short") + } + decrypted, err := a.aead.Open(nil, ciphertext[:a.aead.NonceSize()], ciphertext[a.aead.NonceSize():], nil) + if err != nil { + return nil, &DecryptFailedError{Inner: err} + } + return decrypted, nil +} + +func (a *AES256) HexDigest() string { + return a.digest +} + +type ( + CipherDigest string + Ciphers struct { + primary string + m map[string]Cipher + } +) + +// NewCiphers returns a new Ciphers instance with the given ciphers. +// The first cipher in the list is the primary cipher. Any ciphers after the +// first are considered secondary ciphers and are only used for decryption. +func NewCiphers(cs ...Cipher) *Ciphers { + var primary string + m := make(map[string]Cipher) + for idx, c := range cs { + if _, ok := c.(*Ciphers); ok { + panic("developer error: do not nest Ciphers") + } + m[c.HexDigest()] = c + if idx == 0 { + primary = c.HexDigest() + } + } + return &Ciphers{primary: primary, m: m} +} + +// Encrypt encrypts the given plaintext using the primary cipher and returns the +// ciphertext. The ciphertext is prefixed with the primary cipher's digest. +func (cs Ciphers) Encrypt(plaintext []byte) ([]byte, error) { + c, ok := cs.m[cs.primary] + if !ok { + return nil, xerrors.Errorf("no ciphers configured") + } + prefix := []byte(c.HexDigest() + "-") + encrypted, err := c.Encrypt(plaintext) + if err != nil { + return nil, err + } + return append(prefix, encrypted...), nil +} + +// Decrypt decrypts the given ciphertext using the cipher indicated by the +// ciphertext's prefix. The prefix is the first 7 bytes of the hex-encoded +// SHA-256 digest of the cipher's key. Decryption will fail if the prefix +// does not match any of the configured ciphers. +func (cs Ciphers) Decrypt(ciphertext []byte) ([]byte, error) { + requiredPrefix := string(ciphertext[:7]) + c, ok := cs.m[requiredPrefix] + if !ok { + return nil, xerrors.Errorf("missing required decryption cipher %s", requiredPrefix) + } + return c.Decrypt(ciphertext[8:]) +} + +// HexDigest returns the digest of the primary cipher. +func (cs Ciphers) HexDigest() string { + return cs.primary +} diff --git a/enterprise/dbcrypt/cipher_test.go b/enterprise/dbcrypt/cipher_test.go new file mode 100644 index 0000000000000..b638c93ad2fa4 --- /dev/null +++ b/enterprise/dbcrypt/cipher_test.go @@ -0,0 +1,144 @@ +package dbcrypt_test + +import ( + "bytes" + "encoding/base64" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/enterprise/dbcrypt" +) + +func TestCipherAES256(t *testing.T) { + t.Parallel() + + t.Run("ValidInput", func(t *testing.T) { + t.Parallel() + key := bytes.Repeat([]byte{'a'}, 32) + cipher, err := dbcrypt.CipherAES256(key) + require.NoError(t, err) + + output, err := cipher.Encrypt([]byte("hello world")) + require.NoError(t, err) + + response, err := cipher.Decrypt(output) + require.NoError(t, err) + require.Equal(t, "hello world", string(response)) + }) + + t.Run("InvalidInput", func(t *testing.T) { + t.Parallel() + key := bytes.Repeat([]byte{'a'}, 32) + cipher, err := dbcrypt.CipherAES256(key) + require.NoError(t, err) + _, err = cipher.Decrypt(bytes.Repeat([]byte{'a'}, 100)) + var decryptErr *dbcrypt.DecryptFailedError + require.ErrorAs(t, err, &decryptErr) + }) + + t.Run("InvalidKeySize", func(t *testing.T) { + t.Parallel() + + _, err := dbcrypt.CipherAES256(bytes.Repeat([]byte{'a'}, 31)) + require.ErrorContains(t, err, "key must be 32 bytes") + }) + + t.Run("TestNonce", func(t *testing.T) { + t.Parallel() + key := bytes.Repeat([]byte{'a'}, 32) + cipher, err := dbcrypt.CipherAES256(key) + require.NoError(t, err) + require.Equal(t, "3ba3f5f", cipher.HexDigest()) + + encrypted1, err := cipher.Encrypt([]byte("hello world")) + require.NoError(t, err) + encrypted2, err := cipher.Encrypt([]byte("hello world")) + require.NoError(t, err) + require.NotEqual(t, encrypted1, encrypted2, "nonce should be different for each encryption") + + munged := make([]byte, len(encrypted1)) + copy(munged, encrypted1) + munged[0] = munged[0] ^ 0xff + _, err = cipher.Decrypt(munged) + var decryptErr *dbcrypt.DecryptFailedError + require.ErrorAs(t, err, &decryptErr, "munging the first byte of the encrypted data should cause decryption to fail") + }) +} + +func TestCiphers(t *testing.T) { + t.Parallel() + + // Given: two ciphers + key1 := bytes.Repeat([]byte{'a'}, 32) + key2 := bytes.Repeat([]byte{'b'}, 32) + cipher1, err := dbcrypt.CipherAES256(key1) + require.NoError(t, err) + cipher2, err := dbcrypt.CipherAES256(key2) + require.NoError(t, err) + + ciphers := dbcrypt.NewCiphers(cipher1, cipher2) + + // Then: it should encrypt with the cipher1 + output, err := ciphers.Encrypt([]byte("hello world")) + require.NoError(t, err) + // The first 7 bytes of the output should be the hex digest of cipher1 + require.Equal(t, cipher1.HexDigest(), string(output[:7])) + + // And: it should decrypt successfully + decrypted, err := ciphers.Decrypt(output) + require.NoError(t, err) + require.Equal(t, "hello world", string(decrypted)) + + // Decryption of the above should fail with cipher2 + _, err = cipher2.Decrypt(output) + var decryptErr *dbcrypt.DecryptFailedError + require.ErrorAs(t, err, &decryptErr) + + // Decryption of data encrypted with cipher2 should succeed + output2, err := cipher2.Encrypt([]byte("hello world")) + require.NoError(t, err) + decrypted2, err := ciphers.Decrypt(bytes.Join([][]byte{[]byte(cipher2.HexDigest()), output2}, []byte{'-'})) + require.NoError(t, err) + require.Equal(t, "hello world", string(decrypted2)) + + // Decryption of data encrypted with cipher1 should succeed + output1, err := cipher1.Encrypt([]byte("hello world")) + require.NoError(t, err) + decrypted1, err := ciphers.Decrypt(bytes.Join([][]byte{[]byte(cipher1.HexDigest()), output1}, []byte{'-'})) + require.NoError(t, err) + require.Equal(t, "hello world", string(decrypted1)) + + // Wrapping a Ciphers with itself should panic. + require.PanicsWithValue(t, "developer error: do not nest Ciphers", func() { + _ = dbcrypt.NewCiphers(ciphers) + }) +} + +// This test ensures backwards compatibility. If it breaks, something is very wrong. +func TestCiphersBackwardCompatibility(t *testing.T) { + t.Parallel() + var ( + msg = "hello world" + key = bytes.Repeat([]byte{'a'}, 32) + //nolint: gosec // The below is the base64-encoded result of encrypting the above message with the above key. + encoded = `M2JhM2Y1Zi3r1KSStbmfMBXDzdjVcCrtumdMFsJ4QiYlb3fV1HB8yxg9obHaz5I=` + ) + + // This is the code that was used to generate the above. + // Note that the output of this code will change every time it is run. + // encrypted, err := cs.Encrypt([]byte(msg)) + // require.NoError(t, err) + // t.Logf("encoded: %q", base64.StdEncoding.EncodeToString(encrypted)) + + cipher, err := dbcrypt.CipherAES256(key) + require.NoError(t, err) + require.Equal(t, "3ba3f5f", cipher.HexDigest()) + cs := dbcrypt.NewCiphers(cipher) + + decoded, err := base64.StdEncoding.DecodeString(encoded) + require.NoError(t, err, "the encoded string should be valid base64") + decrypted, err := cs.Decrypt(decoded) + require.NoError(t, err, "decryption should succeed") + require.Equal(t, msg, string(decrypted), "decrypted message should match original message") +} diff --git a/enterprise/dbcrypt/dbcrypt.go b/enterprise/dbcrypt/dbcrypt.go new file mode 100644 index 0000000000000..d197704af7f7d --- /dev/null +++ b/enterprise/dbcrypt/dbcrypt.go @@ -0,0 +1,296 @@ +// Package dbcrypt provides a database.Store wrapper that encrypts/decrypts +// values stored at rest in the database. +// +// Encryption is done using a Cipher. The Cipher is stored in an atomic pointer +// so that it can be rotated as required. +// +// The Cipher is currently used to encrypt/decrypt the following fields: +// - database.UserLink.OAuthAccessToken +// - database.UserLink.OAuthRefreshToken +// - database.GitAuthLink.OAuthAccessToken +// - database.GitAuthLink.OAuthRefreshToken +// - database.DBCryptSentinelValue +// +// Encrypted fields are stored in the following format: +// "dbcrypt-${b64encode(-)}" +// +// The first 7 characters of the cipher's SHA256 digest are used to identify the cipher +// used to encrypt the value. +// +// Multiple ciphers can be provided to support key rotation. The primary cipher is used +// to encrypt and decrypt all data. Secondary ciphers are only used for decryption. +// We currently only use a single secondary cipher. +package dbcrypt + +import ( + "context" + "database/sql" + "encoding/base64" + "errors" + "strings" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + + "github.com/google/uuid" + "golang.org/x/xerrors" +) + +// MagicPrefix is prepended to all encrypted values in the database. +// This is used to determine if a value is encrypted or not. +// If it is encrypted but a key is not provided, an error is returned. +// MagicPrefix will be followed by the first 7 characters of the cipher's +// SHA256 digest, followed by a dash, followed by the base64-encoded +// encrypted value. +const MagicPrefix = "dbcrypt-" + +// sentinelValue is the value that is stored in the database to indicate +// whether encryption is enabled. If not enabled, the value either not +// present, or is the raw string "coder". +// Otherwise, the value must be the encrypted value of the string "coder" +// using the current cipher. +const sentinelValue = "coder" + +var ( + ErrNotEnabled = xerrors.New("encryption is not enabled") + ErrSentinelMismatch = xerrors.New("database is already encrypted under a different key") + b64encode = base64.StdEncoding.EncodeToString + b64decode = base64.StdEncoding.DecodeString +) + +// DecryptFailedError is returned when decryption fails. +// It unwraps to sql.ErrNoRows. +type DecryptFailedError struct { + Inner error +} + +func (e *DecryptFailedError) Error() string { + return xerrors.Errorf("decrypt failed: %w", e.Inner).Error() +} + +func (*DecryptFailedError) Unwrap() error { + return sql.ErrNoRows +} + +// New creates a database.Store wrapper that encrypts/decrypts values +// stored at rest in the database. +func New(ctx context.Context, db database.Store, cs *Ciphers) (database.Store, error) { + if cs == nil { + return nil, xerrors.Errorf("no ciphers configured") + } + dbc := &dbCrypt{ + ciphers: cs, + Store: db, + } + // nolint: gocritic // This is allowed. + if err := ensureEncrypted(dbauthz.AsSystemRestricted(ctx), dbc); err != nil { + return nil, xerrors.Errorf("ensure encrypted database fields: %w", err) + } + return dbc, nil +} + +type dbCrypt struct { + ciphers *Ciphers + database.Store +} + +func (db *dbCrypt) InTx(function func(database.Store) error, txOpts *sql.TxOptions) error { + return db.Store.InTx(func(s database.Store) error { + return function(&dbCrypt{ + ciphers: db.ciphers, + Store: s, + }) + }, txOpts) +} + +func (db *dbCrypt) GetDBCryptSentinelValue(ctx context.Context) (string, error) { + rawValue, err := db.Store.GetDBCryptSentinelValue(ctx) + if err != nil { + return "", err + } + return rawValue, db.decryptFields(&rawValue) +} + +func (db *dbCrypt) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (database.UserLink, error) { + link, err := db.Store.GetUserLinkByLinkedID(ctx, linkedID) + if err != nil { + return database.UserLink{}, err + } + return link, db.decryptFields(&link.OAuthAccessToken, &link.OAuthRefreshToken) +} + +func (db *dbCrypt) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) { + links, err := db.Store.GetUserLinksByUserID(ctx, userID) + if err != nil { + return nil, err + } + for _, link := range links { + if err := db.decryptFields(&link.OAuthAccessToken, &link.OAuthRefreshToken); err != nil { + return nil, err + } + } + return links, nil +} + +func (db *dbCrypt) GetUserLinkByUserIDLoginType(ctx context.Context, params database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) { + link, err := db.Store.GetUserLinkByUserIDLoginType(ctx, params) + if err != nil { + return database.UserLink{}, err + } + return link, db.decryptFields(&link.OAuthAccessToken, &link.OAuthRefreshToken) +} + +func (db *dbCrypt) InsertUserLink(ctx context.Context, params database.InsertUserLinkParams) (database.UserLink, error) { + err := db.encryptFields(¶ms.OAuthAccessToken, ¶ms.OAuthRefreshToken) + if err != nil { + return database.UserLink{}, err + } + link, err := db.Store.InsertUserLink(ctx, params) + if err != nil { + return database.UserLink{}, err + } + return link, db.decryptFields(&link.OAuthAccessToken, &link.OAuthRefreshToken) +} + +func (db *dbCrypt) UpdateUserLink(ctx context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) { + err := db.encryptFields(¶ms.OAuthAccessToken, ¶ms.OAuthRefreshToken) + if err != nil { + return database.UserLink{}, err + } + updated, err := db.Store.UpdateUserLink(ctx, params) + if err != nil { + return database.UserLink{}, err + } + return updated, db.decryptFields(&updated.OAuthAccessToken, &updated.OAuthRefreshToken) +} + +func (db *dbCrypt) InsertGitAuthLink(ctx context.Context, params database.InsertGitAuthLinkParams) (database.GitAuthLink, error) { + err := db.encryptFields(¶ms.OAuthAccessToken, ¶ms.OAuthRefreshToken) + if err != nil { + return database.GitAuthLink{}, err + } + link, err := db.Store.InsertGitAuthLink(ctx, params) + if err != nil { + return database.GitAuthLink{}, err + } + return link, db.decryptFields(&link.OAuthAccessToken, &link.OAuthRefreshToken) +} + +func (db *dbCrypt) GetGitAuthLink(ctx context.Context, params database.GetGitAuthLinkParams) (database.GitAuthLink, error) { + link, err := db.Store.GetGitAuthLink(ctx, params) + if err != nil { + return database.GitAuthLink{}, err + } + return link, db.decryptFields(&link.OAuthAccessToken, &link.OAuthRefreshToken) +} + +func (db *dbCrypt) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) { + links, err := db.Store.GetGitAuthLinksByUserID(ctx, userID) + if err != nil { + return nil, err + } + for _, link := range links { + if err := db.decryptFields(&link.OAuthAccessToken, &link.OAuthRefreshToken); err != nil { + return nil, err + } + } + return links, nil +} + +func (db *dbCrypt) UpdateGitAuthLink(ctx context.Context, params database.UpdateGitAuthLinkParams) (database.GitAuthLink, error) { + err := db.encryptFields(¶ms.OAuthAccessToken, ¶ms.OAuthRefreshToken) + if err != nil { + return database.GitAuthLink{}, err + } + updated, err := db.Store.UpdateGitAuthLink(ctx, params) + if err != nil { + return database.GitAuthLink{}, err + } + return updated, db.decryptFields(&updated.OAuthAccessToken, &updated.OAuthRefreshToken) +} + +func (db *dbCrypt) SetDBCryptSentinelValue(ctx context.Context, value string) error { + err := db.encryptFields(&value) + if err != nil { + return err + } + return db.Store.SetDBCryptSentinelValue(ctx, value) +} + +func (db *dbCrypt) encryptFields(fields ...*string) error { + // If no cipher is loaded, then we can't encrypt anything! + if db.ciphers == nil { + return ErrNotEnabled + } + + for _, field := range fields { + if field == nil { + continue + } + + encrypted, err := db.ciphers.Encrypt([]byte(*field)) + if err != nil { + return err + } + // Base64 is used to support UTF-8 encoding in PostgreSQL. + *field = MagicPrefix + b64encode(encrypted) + } + return nil +} + +// decryptFields decrypts the given fields in place. +// If the value fails to decrypt, sql.ErrNoRows will be returned. +func (db *dbCrypt) decryptFields(fields ...*string) error { + if db.ciphers == nil { + return ErrNotEnabled + } + + for _, field := range fields { + if field == nil { + continue + } + + if len(*field) < 8 || !strings.HasPrefix(*field, MagicPrefix) { + // We do not force decryption of unencrypted rows. This could be damaging + // to the deployment, and admins can always manually purge data. + continue + } + + data, err := b64decode((*field)[8:]) + if err != nil { + // If it's not base64 with the prefix, we should complain loudly. + return &DecryptFailedError{ + Inner: xerrors.Errorf("malformed encrypted field %q: %w", *field, err), + } + } + decrypted, err := db.ciphers.Decrypt(data) + if err != nil { + // If the encryption key changed, return our special error that unwraps to sql.ErrNoRows. + return &DecryptFailedError{Inner: err} + } + *field = string(decrypted) + } + return nil +} + +func ensureEncrypted(ctx context.Context, dbc *dbCrypt) error { + return dbc.InTx(func(s database.Store) error { + val, err := s.GetDBCryptSentinelValue(ctx) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return err + } + } + + if val != "" && val != sentinelValue { + return ErrSentinelMismatch + } + + // Mark the database as officially having been touched by the new cipher. + if err := s.SetDBCryptSentinelValue(ctx, sentinelValue); err != nil { + return xerrors.Errorf("mark database as encrypted: %w", err) + } + + return nil + }, nil) +} diff --git a/enterprise/dbcrypt/dbcrypt_test.go b/enterprise/dbcrypt/dbcrypt_test.go new file mode 100644 index 0000000000000..d5d088aab30d9 --- /dev/null +++ b/enterprise/dbcrypt/dbcrypt_test.go @@ -0,0 +1,289 @@ +package dbcrypt_test + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/base64" + "io" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/enterprise/dbcrypt" +) + +func TestUserLinks(t *testing.T) { + t.Parallel() + ctx := context.Background() + + t.Run("InsertUserLink", func(t *testing.T) { + t.Parallel() + db, crypt, cipher := setup(t) + user := dbgen.User(t, crypt, database.User{}) + link := dbgen.UserLink(t, crypt, database.UserLink{ + UserID: user.ID, + OAuthAccessToken: "access", + OAuthRefreshToken: "refresh", + }) + require.Equal(t, link.OAuthAccessToken, "access") + require.Equal(t, link.OAuthRefreshToken, "refresh") + + link, err := db.GetUserLinkByLinkedID(ctx, link.LinkedID) + require.NoError(t, err) + requireEncryptedEquals(t, cipher, link.OAuthAccessToken, "access") + requireEncryptedEquals(t, cipher, link.OAuthRefreshToken, "refresh") + }) + + t.Run("UpdateUserLink", func(t *testing.T) { + t.Parallel() + db, crypt, cipher := setup(t) + user := dbgen.User(t, crypt, database.User{}) + link := dbgen.UserLink(t, crypt, database.UserLink{ + UserID: user.ID, + }) + updated, err := crypt.UpdateUserLink(ctx, database.UpdateUserLinkParams{ + OAuthAccessToken: "access", + OAuthRefreshToken: "refresh", + UserID: link.UserID, + LoginType: link.LoginType, + }) + require.NoError(t, err) + require.Equal(t, updated.OAuthAccessToken, "access") + require.Equal(t, updated.OAuthRefreshToken, "refresh") + + link, err = db.GetUserLinkByLinkedID(ctx, link.LinkedID) + require.NoError(t, err) + requireEncryptedEquals(t, cipher, link.OAuthAccessToken, "access") + requireEncryptedEquals(t, cipher, link.OAuthRefreshToken, "refresh") + }) + + t.Run("GetUserLinkByLinkedID", func(t *testing.T) { + t.Parallel() + db, crypt, cipher := setup(t) + user := dbgen.User(t, crypt, database.User{}) + link := dbgen.UserLink(t, crypt, database.UserLink{ + UserID: user.ID, + OAuthAccessToken: "access", + OAuthRefreshToken: "refresh", + }) + link, err := db.GetUserLinkByLinkedID(ctx, link.LinkedID) + require.NoError(t, err) + requireEncryptedEquals(t, cipher, link.OAuthAccessToken, "access") + requireEncryptedEquals(t, cipher, link.OAuthRefreshToken, "refresh") + }) + + t.Run("GetUserLinkByUserIDLoginType", func(t *testing.T) { + t.Parallel() + db, crypt, cipher := setup(t) + user := dbgen.User(t, crypt, database.User{}) + link := dbgen.UserLink(t, crypt, database.UserLink{ + UserID: user.ID, + OAuthAccessToken: "access", + OAuthRefreshToken: "refresh", + }) + link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{ + UserID: link.UserID, + LoginType: link.LoginType, + }) + require.NoError(t, err) + requireEncryptedEquals(t, cipher, link.OAuthAccessToken, "access") + requireEncryptedEquals(t, cipher, link.OAuthRefreshToken, "refresh") + }) +} + +func TestGitAuthLinks(t *testing.T) { + t.Parallel() + ctx := context.Background() + + t.Run("InsertGitAuthLink", func(t *testing.T) { + t.Parallel() + db, crypt, cipher := setup(t) + link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{ + OAuthAccessToken: "access", + OAuthRefreshToken: "refresh", + }) + require.Equal(t, link.OAuthAccessToken, "access") + require.Equal(t, link.OAuthRefreshToken, "refresh") + + link, err := db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + }) + require.NoError(t, err) + requireEncryptedEquals(t, cipher, link.OAuthAccessToken, "access") + requireEncryptedEquals(t, cipher, link.OAuthRefreshToken, "refresh") + }) + + t.Run("UpdateGitAuthLink", func(t *testing.T) { + t.Parallel() + db, crypt, cipher := setup(t) + link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{}) + updated, err := crypt.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + OAuthAccessToken: "access", + OAuthRefreshToken: "refresh", + }) + require.NoError(t, err) + require.Equal(t, updated.OAuthAccessToken, "access") + require.Equal(t, updated.OAuthRefreshToken, "refresh") + + link, err = db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + }) + require.NoError(t, err) + requireEncryptedEquals(t, cipher, link.OAuthAccessToken, "access") + requireEncryptedEquals(t, cipher, link.OAuthRefreshToken, "refresh") + }) + + t.Run("GetGitAuthLink", func(t *testing.T) { + t.Parallel() + db, crypt, cipher := setup(t) + link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{ + OAuthAccessToken: "access", + OAuthRefreshToken: "refresh", + }) + link, err := db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{ + UserID: link.UserID, + ProviderID: link.ProviderID, + }) + require.NoError(t, err) + requireEncryptedEquals(t, cipher, link.OAuthAccessToken, "access") + requireEncryptedEquals(t, cipher, link.OAuthRefreshToken, "refresh") + }) +} + +func TestNew(t *testing.T) { + t.Parallel() + + t.Run("OK", func(t *testing.T) { + t.Parallel() + // Given: a cipher is loaded + cipher := dbcrypt.NewCiphers(initCipher(t)) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + rawDB, _ := dbtestutil.NewDB(t) + + // When: we init the crypt db + cryptDB, err := dbcrypt.New(ctx, rawDB, cipher) + require.NoError(t, err) + + // Then: the sentinel value is encrypted + cryptVal, err := cryptDB.GetDBCryptSentinelValue(ctx) + require.NoError(t, err) + require.Equal(t, "coder", cryptVal) + + rawVal, err := rawDB.GetDBCryptSentinelValue(ctx) + require.NoError(t, err) + require.Contains(t, rawVal, dbcrypt.MagicPrefix) + requireEncryptedEquals(t, cipher, rawVal, "coder") + }) + + t.Run("NoCipher", func(t *testing.T) { + t.Parallel() + // Given: no cipher is loaded + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + rawDB, _ := dbtestutil.NewDB(t) + + // When: we init the crypt db + _, err := dbcrypt.New(ctx, rawDB, nil) + + // Then: an error is returned + require.ErrorContains(t, err, "no ciphers configured") + + // And: the sentinel value is not present + _, err = rawDB.GetDBCryptSentinelValue(ctx) + require.ErrorIs(t, err, sql.ErrNoRows) + }) + + t.Run("CipherChanged", func(t *testing.T) { + t.Parallel() + // Given: no cipher is loaded + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + rawDB, _ := dbtestutil.NewDB(t) + + // And: the sentinel value is encrypted with a different cipher + cipher1 := initCipher(t) + field := "coder" + encrypted, err := dbcrypt.NewCiphers(cipher1).Encrypt([]byte(field)) + require.NoError(t, err) + b64encrypted := base64.StdEncoding.EncodeToString(encrypted) + require.NoError(t, rawDB.SetDBCryptSentinelValue(ctx, dbcrypt.MagicPrefix+b64encrypted)) + + // When: we init the crypt db with no access to the old cipher + cipher2 := initCipher(t) + _, err = dbcrypt.New(ctx, rawDB, dbcrypt.NewCiphers(cipher2)) + // Then: a special error is returned + require.ErrorIs(t, err, dbcrypt.ErrSentinelMismatch) + + // And the sentinel value should remain unchanged. For now. + rawVal, err := rawDB.GetDBCryptSentinelValue(ctx) + require.NoError(t, err) + requireEncryptedEquals(t, dbcrypt.NewCiphers(cipher1), rawVal, field) + + // When: we set the secondary cipher + cs := dbcrypt.NewCiphers(cipher2, cipher1) + _, err = dbcrypt.New(ctx, rawDB, cs) + // Then: no error is returned + require.NoError(t, err) + + // And the sentinel value should be re-encrypted with the new value. + rawVal, err = rawDB.GetDBCryptSentinelValue(ctx) + require.NoError(t, err) + requireEncryptedEquals(t, dbcrypt.NewCiphers(cipher2), rawVal, field) + }) +} + +func requireEncryptedEquals(t *testing.T, c dbcrypt.Cipher, value, expected string) { + t.Helper() + require.Greater(t, len(value), 8, "value is not encrypted") + require.Equal(t, dbcrypt.MagicPrefix, value[:8], "missing magic prefix") + data, err := base64.StdEncoding.DecodeString(value[8:]) + require.NoError(t, err, "invalid base64") + require.Greater(t, len(data), 8, "missing cipher digest") + require.Equal(t, c.HexDigest(), string(data[:7]), "cipher digest mismatch") + got, err := c.Decrypt(data) + require.NoError(t, err, "failed to decrypt data") + require.Equal(t, expected, string(got), "decrypted data does not match") +} + +func initCipher(t *testing.T) *dbcrypt.AES256 { + t.Helper() + key := make([]byte, 32) // AES-256 key size is 32 bytes + _, err := io.ReadFull(rand.Reader, key) + require.NoError(t, err) + c, err := dbcrypt.CipherAES256(key) + require.NoError(t, err) + return c +} + +func setup(t *testing.T) (db, cryptodb database.Store, ciphers *dbcrypt.Ciphers) { + t.Helper() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + rawDB, _ := dbtestutil.NewDB(t) + + _, err := rawDB.GetDBCryptSentinelValue(ctx) + require.ErrorIs(t, err, sql.ErrNoRows) + + ciphers = dbcrypt.NewCiphers(initCipher(t)) + cryptDB, err := dbcrypt.New(ctx, rawDB, ciphers) + require.NoError(t, err) + + rawVal, err := rawDB.GetDBCryptSentinelValue(ctx) + require.NoError(t, err) + require.Contains(t, rawVal, dbcrypt.MagicPrefix) + + cryptVal, err := cryptDB.GetDBCryptSentinelValue(ctx) + require.NoError(t, err) + require.Equal(t, "coder", cryptVal) + + return rawDB, cryptDB, ciphers +}