@@ -31,6 +31,11 @@ import (
31
31
32
32
var validProxyByHostnameRegex = regexp .MustCompile (`^[a-zA-Z0-9._-]+$` )
33
33
34
+ var errForeignKeyConstraint = & pq.Error {
35
+ Code : "23503" ,
36
+ Message : "update or delete on table violates foreign key constraint" ,
37
+ }
38
+
34
39
var errDuplicateKey = & pq.Error {
35
40
Code : "23505" ,
36
41
Message : "duplicate key value violates unique constraint" ,
@@ -45,6 +50,7 @@ func New() database.Store {
45
50
organizationMembers : make ([]database.OrganizationMember , 0 ),
46
51
organizations : make ([]database.Organization , 0 ),
47
52
users : make ([]database.User , 0 ),
53
+ dbcryptKeys : make ([]database.DBCryptKey , 0 ),
48
54
gitAuthLinks : make ([]database.GitAuthLink , 0 ),
49
55
groups : make ([]database.Group , 0 ),
50
56
groupMembers : make ([]database.GroupMember , 0 ),
@@ -117,6 +123,7 @@ type data struct {
117
123
// New tables
118
124
workspaceAgentStats []database.WorkspaceAgentStat
119
125
auditLogs []database.AuditLog
126
+ dbcryptKeys []database.DBCryptKey
120
127
files []database.File
121
128
gitAuthLinks []database.GitAuthLink
122
129
gitSSHKey []database.GitSSHKey
@@ -665,6 +672,39 @@ func (q *FakeQuerier) isEveryoneGroup(id uuid.UUID) bool {
665
672
return false
666
673
}
667
674
675
+ func (q * FakeQuerier ) insertDBCryptKeyNoLock (_ context.Context , arg database.InsertDBCryptKeyParams ) error {
676
+ err := validateDatabaseType (arg )
677
+ if err != nil {
678
+ return err
679
+ }
680
+
681
+ for _ , key := range q .dbcryptKeys {
682
+ if key .Number == arg .Number {
683
+ return errDuplicateKey
684
+ }
685
+ }
686
+
687
+ q .dbcryptKeys = append (q .dbcryptKeys , database.DBCryptKey {
688
+ Number : arg .Number ,
689
+ ActiveKeyDigest : sql.NullString {String : arg .ActiveKeyDigest , Valid : true },
690
+ Test : arg .Test ,
691
+ })
692
+ return nil
693
+ }
694
+
695
+ func (q * FakeQuerier ) GetActiveDBCryptKeys (_ context.Context ) ([]database.DBCryptKey , error ) {
696
+ q .mutex .RLock ()
697
+ defer q .mutex .RUnlock ()
698
+ ks := make ([]database.DBCryptKey , 0 , len (q .dbcryptKeys ))
699
+ for _ , k := range q .dbcryptKeys {
700
+ if ! k .ActiveKeyDigest .Valid {
701
+ continue
702
+ }
703
+ ks = append ([]database.DBCryptKey {}, k )
704
+ }
705
+ return ks , nil
706
+ }
707
+
668
708
func (* FakeQuerier ) AcquireLock (_ context.Context , _ int64 ) error {
669
709
return xerrors .New ("AcquireLock must only be called within a transaction" )
670
710
}
@@ -1151,6 +1191,14 @@ func (q *FakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
1151
1191
}, nil
1152
1192
}
1153
1193
1194
+ func (q * FakeQuerier ) GetDBCryptKeys (_ context.Context ) ([]database.DBCryptKey , error ) {
1195
+ q .mutex .RLock ()
1196
+ defer q .mutex .RUnlock ()
1197
+ ks := make ([]database.DBCryptKey , 0 )
1198
+ ks = append (ks , q .dbcryptKeys ... )
1199
+ return ks , nil
1200
+ }
1201
+
1154
1202
func (q * FakeQuerier ) GetDERPMeshKey (_ context.Context ) (string , error ) {
1155
1203
q .mutex .RLock ()
1156
1204
defer q .mutex .RUnlock ()
@@ -1393,6 +1441,18 @@ func (q *FakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthL
1393
1441
return database.GitAuthLink {}, sql .ErrNoRows
1394
1442
}
1395
1443
1444
+ func (q * FakeQuerier ) GetGitAuthLinksByUserID (_ context.Context , userID uuid.UUID ) ([]database.GitAuthLink , error ) {
1445
+ q .mutex .RLock ()
1446
+ defer q .mutex .RUnlock ()
1447
+ gals := make ([]database.GitAuthLink , 0 )
1448
+ for _ , gal := range q .gitAuthLinks {
1449
+ if gal .UserID == userID {
1450
+ gals = append (gals , gal )
1451
+ }
1452
+ }
1453
+ return gals , nil
1454
+ }
1455
+
1396
1456
func (q * FakeQuerier ) GetGitSSHKey (_ context.Context , userID uuid.UUID ) (database.GitSSHKey , error ) {
1397
1457
q .mutex .RLock ()
1398
1458
defer q .mutex .RUnlock ()
@@ -2833,6 +2893,18 @@ func (q *FakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params dat
2833
2893
return database.UserLink {}, sql .ErrNoRows
2834
2894
}
2835
2895
2896
+ func (q * FakeQuerier ) GetUserLinksByUserID (_ context.Context , userID uuid.UUID ) ([]database.UserLink , error ) {
2897
+ q .mutex .RLock ()
2898
+ defer q .mutex .RUnlock ()
2899
+ uls := make ([]database.UserLink , 0 )
2900
+ for _ , ul := range q .userLinks {
2901
+ if ul .UserID == userID {
2902
+ uls = append (uls , ul )
2903
+ }
2904
+ }
2905
+ return uls , nil
2906
+ }
2907
+
2836
2908
func (q * FakeQuerier ) GetUsers (_ context.Context , params database.GetUsersParams ) ([]database.GetUsersRow , error ) {
2837
2909
if err := validateDatabaseType (params ); err != nil {
2838
2910
return nil , err
@@ -3846,6 +3918,11 @@ func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAudit
3846
3918
return alog , nil
3847
3919
}
3848
3920
3921
+ func (q * FakeQuerier ) InsertDBCryptKey (ctx context.Context , arg database.InsertDBCryptKeyParams ) error {
3922
+ // This only ever gets called inside a transaction, so we need to not lock.
3923
+ return q .insertDBCryptKeyNoLock (ctx , arg )
3924
+ }
3925
+
3849
3926
func (q * FakeQuerier ) InsertDERPMeshKey (_ context.Context , id string ) error {
3850
3927
q .mutex .Lock ()
3851
3928
defer q .mutex .Unlock ()
@@ -3892,13 +3969,15 @@ func (q *FakeQuerier) InsertGitAuthLink(_ context.Context, arg database.InsertGi
3892
3969
defer q .mutex .Unlock ()
3893
3970
// nolint:gosimple
3894
3971
gitAuthLink := database.GitAuthLink {
3895
- ProviderID : arg .ProviderID ,
3896
- UserID : arg .UserID ,
3897
- CreatedAt : arg .CreatedAt ,
3898
- UpdatedAt : arg .UpdatedAt ,
3899
- OAuthAccessToken : arg .OAuthAccessToken ,
3900
- OAuthRefreshToken : arg .OAuthRefreshToken ,
3901
- OAuthExpiry : arg .OAuthExpiry ,
3972
+ ProviderID : arg .ProviderID ,
3973
+ UserID : arg .UserID ,
3974
+ CreatedAt : arg .CreatedAt ,
3975
+ UpdatedAt : arg .UpdatedAt ,
3976
+ OAuthAccessToken : arg .OAuthAccessToken ,
3977
+ OAuthAccessTokenKeyID : arg .OAuthAccessTokenKeyID ,
3978
+ OAuthRefreshToken : arg .OAuthRefreshToken ,
3979
+ OAuthRefreshTokenKeyID : arg .OAuthRefreshTokenKeyID ,
3980
+ OAuthExpiry : arg .OAuthExpiry ,
3902
3981
}
3903
3982
q .gitAuthLinks = append (q .gitAuthLinks , gitAuthLink )
3904
3983
return gitAuthLink , nil
@@ -4362,12 +4441,14 @@ func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUser
4362
4441
4363
4442
//nolint:gosimple
4364
4443
link := database.UserLink {
4365
- UserID : args .UserID ,
4366
- LoginType : args .LoginType ,
4367
- LinkedID : args .LinkedID ,
4368
- OAuthAccessToken : args .OAuthAccessToken ,
4369
- OAuthRefreshToken : args .OAuthRefreshToken ,
4370
- OAuthExpiry : args .OAuthExpiry ,
4444
+ UserID : args .UserID ,
4445
+ LoginType : args .LoginType ,
4446
+ LinkedID : args .LinkedID ,
4447
+ OAuthAccessToken : args .OAuthAccessToken ,
4448
+ OAuthAccessTokenKeyID : args .OAuthAccessTokenKeyID ,
4449
+ OAuthRefreshToken : args .OAuthRefreshToken ,
4450
+ OAuthRefreshTokenKeyID : args .OAuthRefreshTokenKeyID ,
4451
+ OAuthExpiry : args .OAuthExpiry ,
4371
4452
}
4372
4453
4373
4454
q .userLinks = append (q .userLinks , link )
@@ -4793,6 +4874,46 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg
4793
4874
return database.WorkspaceProxy {}, sql .ErrNoRows
4794
4875
}
4795
4876
4877
+ func (q * FakeQuerier ) RevokeDBCryptKey (_ context.Context , activeKeyDigest string ) error {
4878
+ q .mutex .Lock ()
4879
+ defer q .mutex .Unlock ()
4880
+
4881
+ for i := range q .dbcryptKeys {
4882
+ key := q .dbcryptKeys [i ]
4883
+
4884
+ // Is the key already revoked?
4885
+ if ! key .ActiveKeyDigest .Valid {
4886
+ continue
4887
+ }
4888
+
4889
+ if key .ActiveKeyDigest .String != activeKeyDigest {
4890
+ continue
4891
+ }
4892
+
4893
+ // Check for foreign key constraints.
4894
+ for _ , ul := range q .userLinks {
4895
+ if (ul .OAuthAccessTokenKeyID .Valid && ul .OAuthAccessTokenKeyID .String == activeKeyDigest ) ||
4896
+ (ul .OAuthRefreshTokenKeyID .Valid && ul .OAuthRefreshTokenKeyID .String == activeKeyDigest ) {
4897
+ return errForeignKeyConstraint
4898
+ }
4899
+ }
4900
+ for _ , gal := range q .gitAuthLinks {
4901
+ if (gal .OAuthAccessTokenKeyID .Valid && gal .OAuthAccessTokenKeyID .String == activeKeyDigest ) ||
4902
+ (gal .OAuthRefreshTokenKeyID .Valid && gal .OAuthRefreshTokenKeyID .String == activeKeyDigest ) {
4903
+ return errForeignKeyConstraint
4904
+ }
4905
+ }
4906
+
4907
+ // Revoke the key.
4908
+ q .dbcryptKeys [i ].RevokedAt = sql.NullTime {Time : dbtime .Now (), Valid : true }
4909
+ q .dbcryptKeys [i ].RevokedKeyDigest = sql.NullString {String : key .ActiveKeyDigest .String , Valid : true }
4910
+ q .dbcryptKeys [i ].ActiveKeyDigest = sql.NullString {}
4911
+ return nil
4912
+ }
4913
+
4914
+ return sql .ErrNoRows
4915
+ }
4916
+
4796
4917
func (* FakeQuerier ) TryAcquireLock (_ context.Context , _ int64 ) (bool , error ) {
4797
4918
return false , xerrors .New ("TryAcquireLock must only be called within a transaction" )
4798
4919
}
@@ -4834,7 +4955,9 @@ func (q *FakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGi
4834
4955
}
4835
4956
gitAuthLink .UpdatedAt = arg .UpdatedAt
4836
4957
gitAuthLink .OAuthAccessToken = arg .OAuthAccessToken
4958
+ gitAuthLink .OAuthAccessTokenKeyID = arg .OAuthAccessTokenKeyID
4837
4959
gitAuthLink .OAuthRefreshToken = arg .OAuthRefreshToken
4960
+ gitAuthLink .OAuthRefreshTokenKeyID = arg .OAuthRefreshTokenKeyID
4838
4961
gitAuthLink .OAuthExpiry = arg .OAuthExpiry
4839
4962
q .gitAuthLinks [index ] = gitAuthLink
4840
4963
@@ -5306,7 +5429,9 @@ func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
5306
5429
for i , link := range q .userLinks {
5307
5430
if link .UserID == params .UserID && link .LoginType == params .LoginType {
5308
5431
link .OAuthAccessToken = params .OAuthAccessToken
5432
+ link .OAuthAccessTokenKeyID = params .OAuthAccessTokenKeyID
5309
5433
link .OAuthRefreshToken = params .OAuthRefreshToken
5434
+ link .OAuthRefreshTokenKeyID = params .OAuthRefreshTokenKeyID
5310
5435
link .OAuthExpiry = params .OAuthExpiry
5311
5436
5312
5437
q .userLinks [i ] = link
0 commit comments