4
4
"context"
5
5
"database/sql"
6
6
"encoding/base64"
7
- "runtime"
8
7
"strings"
9
8
"sync/atomic"
10
9
@@ -56,19 +55,15 @@ func (db *dbCrypt) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (
56
55
if err != nil {
57
56
return database.UserLink {}, err
58
57
}
59
- return link , db .decryptFields (func () error {
60
- return db .Store .DeleteUserLinkByLinkedID (ctx , linkedID )
61
- }, & link .OAuthAccessToken , & link .OAuthRefreshToken )
58
+ return link , db .decryptFields (& link .OAuthAccessToken , & link .OAuthRefreshToken )
62
59
}
63
60
64
61
func (db * dbCrypt ) GetUserLinkByUserIDLoginType (ctx context.Context , params database.GetUserLinkByUserIDLoginTypeParams ) (database.UserLink , error ) {
65
62
link , err := db .Store .GetUserLinkByUserIDLoginType (ctx , params )
66
63
if err != nil {
67
64
return database.UserLink {}, err
68
65
}
69
- return link , db .decryptFields (func () error {
70
- return db .Store .DeleteUserLinkByLinkedID (ctx , link .LinkedID )
71
- }, & link .OAuthAccessToken , & link .OAuthRefreshToken )
66
+ return link , db .decryptFields (& link .OAuthAccessToken , & link .OAuthRefreshToken )
72
67
}
73
68
74
69
func (db * dbCrypt ) InsertUserLink (ctx context.Context , params database.InsertUserLinkParams ) (database.UserLink , error ) {
@@ -100,12 +95,7 @@ func (db *dbCrypt) GetGitAuthLink(ctx context.Context, params database.GetGitAut
100
95
if err != nil {
101
96
return database.GitAuthLink {}, err
102
97
}
103
- return link , db .decryptFields (func () error {
104
- return db .Store .DeleteGitAuthLink (ctx , database.DeleteGitAuthLinkParams { // nolint:gosimple
105
- ProviderID : params .ProviderID ,
106
- UserID : params .UserID ,
107
- })
108
- }, & link .OAuthAccessToken , & link .OAuthRefreshToken )
98
+ return link , db .decryptFields (& link .OAuthAccessToken , & link .OAuthRefreshToken )
109
99
}
110
100
111
101
func (db * dbCrypt ) UpdateGitAuthLink (ctx context.Context , params database.UpdateGitAuthLinkParams ) (database.GitAuthLink , error ) {
@@ -140,20 +130,7 @@ func (db *dbCrypt) encryptFields(fields ...*string) error {
140
130
141
131
// decryptFields decrypts the given fields in place.
142
132
// If the value fails to decrypt, sql.ErrNoRows will be returned.
143
- func (db * dbCrypt ) decryptFields (deleteFn func () error , fields ... * string ) error {
144
- doDelete := func (reason string ) error {
145
- err := deleteFn ()
146
- if err != nil {
147
- return xerrors .Errorf ("delete encrypted row: %w" , err )
148
- }
149
- pc , _ , _ , ok := runtime .Caller (2 )
150
- details := runtime .FuncForPC (pc )
151
- if ok && details != nil {
152
- db .Logger .Debug (context .Background (), "deleted row" , slog .F ("reason" , reason ), slog .F ("caller" , details .Name ()))
153
- }
154
- return sql .ErrNoRows
155
- }
156
-
133
+ func (db * dbCrypt ) decryptFields (fields ... * string ) error {
157
134
cipherPtr := db .ExternalTokenCipher .Load ()
158
135
// If no cipher is loaded, then we don't need to encrypt or decrypt anything!
159
136
if cipherPtr == nil {
@@ -163,8 +140,8 @@ func (db *dbCrypt) decryptFields(deleteFn func() error, fields ...*string) error
163
140
}
164
141
if strings .HasPrefix (* field , MagicPrefix ) {
165
142
// If we have a magic prefix but encryption is disabled,
166
- // we should delete the row .
167
- return doDelete ( " encryption disabled" )
143
+ // complain loudly .
144
+ return xerrors . Errorf ( "failed to decrypt field %q: encryption is disabled", * field )
168
145
}
169
146
}
170
147
return nil
@@ -182,13 +159,13 @@ func (db *dbCrypt) decryptFields(deleteFn func() error, fields ...*string) error
182
159
}
183
160
data , err := base64 .StdEncoding .DecodeString ((* field )[len (MagicPrefix ):])
184
161
if err != nil {
185
- // If it's not base64 with the prefix, we should delete the row .
186
- return doDelete ( "stored value was not base64 encoded" )
162
+ // If it's not base64 with the prefix, we should complain loudly .
163
+ return xerrors . Errorf ( "malformed encrypted field %q: %w" , * field , err )
187
164
}
188
165
decrypted , err := cipher .Decrypt (data )
189
166
if err != nil {
190
- // If the encryption key changed, we should delete the row .
191
- return doDelete ( "encryption key changed" )
167
+ // If the encryption key changed, return our special error that unwraps to sql.ErrNoRows .
168
+ return & DecryptFailedError { Inner : err }
192
169
}
193
170
* field = string (decrypted )
194
171
}
0 commit comments