@@ -8,10 +8,13 @@ import (
8
8
"io"
9
9
"testing"
10
10
11
+ "github.com/golang/mock/gomock"
12
+ "github.com/lib/pq"
11
13
"github.com/stretchr/testify/require"
12
14
13
15
"github.com/coder/coder/v2/coderd/database"
14
16
"github.com/coder/coder/v2/coderd/database/dbgen"
17
+ "github.com/coder/coder/v2/coderd/database/dbmock"
15
18
"github.com/coder/coder/v2/coderd/database/dbtestutil"
16
19
)
17
20
@@ -470,6 +473,46 @@ func TestNew(t *testing.T) {
470
473
require .Error (t , err )
471
474
require .ErrorContains (t , err , "has been revoked" )
472
475
})
476
+
477
+ t .Run ("Retry" , func (t * testing.T ) {
478
+ t .Parallel ()
479
+ // Given: a cipher is loaded
480
+ cipher := initCipher (t )
481
+ ctx , cancel := context .WithCancel (context .Background ())
482
+ testVal , err := cipher .Encrypt ([]byte ("coder" ))
483
+ key := database.DBCryptKey {
484
+ Number : 1 ,
485
+ ActiveKeyDigest : sql.NullString {String : cipher .HexDigest (), Valid : true },
486
+ Test : b64encode (testVal ),
487
+ }
488
+ require .NoError (t , err )
489
+ t .Cleanup (cancel )
490
+
491
+ // And: a database that returns an error once when we try to serialize a key
492
+ ctrl := gomock .NewController (t )
493
+ mockDB := dbmock .NewMockStore (ctrl )
494
+
495
+ gomock .InOrder (
496
+ // First try: we get a serialization error.
497
+ expectTx (mockDB ),
498
+ mockDB .EXPECT ().GetDBCryptKeys (gomock .Any ()).Times (1 ).Return ([]database.DBCryptKey {}, nil ),
499
+ mockDB .EXPECT ().InsertDBCryptKey (gomock .Any (), gomock .Any ()).Times (1 ).Return (& pq.Error {Code : "40001" }),
500
+ // Second try: we get the key we wanted to insert initially.
501
+ expectTx (mockDB ),
502
+ mockDB .EXPECT ().GetDBCryptKeys (gomock .Any ()).Times (1 ).Return ([]database.DBCryptKey {key }, nil ),
503
+ )
504
+
505
+ _ , err = New (ctx , mockDB , cipher )
506
+ require .NoError (t , err )
507
+ })
508
+ }
509
+
510
+ func expectTx (mdb * dbmock.MockStore ) * gomock.Call {
511
+ return mdb .EXPECT ().InTx (gomock .Any (), gomock .Any ()).Times (1 ).DoAndReturn (
512
+ func (f func (store database.Store ) error , _ * sql.TxOptions ) error {
513
+ return f (mdb )
514
+ },
515
+ )
473
516
}
474
517
475
518
func requireEncryptedEquals (t * testing.T , c Cipher , value , expected string ) {
@@ -511,3 +554,11 @@ func fakeBase64RandomData(t *testing.T, n int) string {
511
554
require .NoError (t , err )
512
555
return base64 .StdEncoding .EncodeToString (b )
513
556
}
557
+
558
+ func withInTx (mTx * dbmock.MockStore ) {
559
+ mTx .EXPECT ().InTx (gomock .Any (), gomock .Any ()).Times (1 ).DoAndReturn (
560
+ func (f func (store database.Store ) error , _ * sql.TxOptions ) error {
561
+ return f (mTx )
562
+ },
563
+ )
564
+ }
0 commit comments