Skip to content

Commit e9a4817

Browse files
committed
fixing tests
1 parent b4eb230 commit e9a4817

File tree

2 files changed

+19
-162
lines changed

2 files changed

+19
-162
lines changed

tailnet/resume.go

+9-57
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@ package tailnet
33
import (
44
"context"
55
"crypto/rand"
6-
"database/sql"
7-
"encoding/hex"
86
"time"
97

10-
"github.com/go-jose/go-jose/v4"
118
"github.com/go-jose/go-jose/v4/jwt"
129
"github.com/google/uuid"
1310
"golang.org/x/xerrors"
@@ -53,47 +50,6 @@ func GenerateResumeTokenSigningKey() (ResumeTokenSigningKey, error) {
5350
return key, nil
5451
}
5552

56-
type ResumeTokenSigningKeyDatabaseStore interface {
57-
GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error)
58-
UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, key string) error
59-
}
60-
61-
// ResumeTokenSigningKeyFromDatabase retrieves the coordinator resume token
62-
// signing key from the database. If the key is not found, a new key is
63-
// generated and inserted into the database.
64-
func ResumeTokenSigningKeyFromDatabase(ctx context.Context, db ResumeTokenSigningKeyDatabaseStore) (ResumeTokenSigningKey, error) {
65-
var resumeTokenKey ResumeTokenSigningKey
66-
resumeTokenKeyStr, err := db.GetCoordinatorResumeTokenSigningKey(ctx)
67-
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
68-
return resumeTokenKey, xerrors.Errorf("get coordinator resume token key: %w", err)
69-
}
70-
if decoded, err := hex.DecodeString(resumeTokenKeyStr); err != nil || len(decoded) != len(resumeTokenKey) {
71-
newKey, err := GenerateResumeTokenSigningKey()
72-
if err != nil {
73-
return resumeTokenKey, xerrors.Errorf("generate fresh coordinator resume token key: %w", err)
74-
}
75-
76-
resumeTokenKeyStr = hex.EncodeToString(newKey[:])
77-
err = db.UpsertCoordinatorResumeTokenSigningKey(ctx, resumeTokenKeyStr)
78-
if err != nil {
79-
return resumeTokenKey, xerrors.Errorf("insert freshly generated coordinator resume token key to database: %w", err)
80-
}
81-
}
82-
83-
resumeTokenKeyBytes, err := hex.DecodeString(resumeTokenKeyStr)
84-
if err != nil {
85-
return resumeTokenKey, xerrors.Errorf("decode coordinator resume token key from database: %w", err)
86-
}
87-
if len(resumeTokenKeyBytes) != len(resumeTokenKey) {
88-
return resumeTokenKey, xerrors.Errorf("coordinator resume token key in database is not the correct length, expect %d got %d", len(resumeTokenKey), len(resumeTokenKeyBytes))
89-
}
90-
copy(resumeTokenKey[:], resumeTokenKeyBytes)
91-
if resumeTokenKey == [64]byte{} {
92-
return resumeTokenKey, xerrors.Errorf("coordinator resume token key in database is empty")
93-
}
94-
return resumeTokenKey, nil
95-
}
96-
9753
type ResumeTokenKeyProvider struct {
9854
key jwtutils.SigningKeyManager
9955
clock quartz.Clock
@@ -111,19 +67,11 @@ func NewResumeTokenKeyProvider(key jwtutils.SigningKeyManager, clock quartz.Cloc
11167
}
11268
}
11369

114-
type resumeTokenPayload struct {
115-
jwt.Claims
116-
PeerID uuid.UUID `json:"sub"`
117-
Expiry int64 `json:"exp"`
118-
}
119-
12070
func (p ResumeTokenKeyProvider) GenerateResumeToken(ctx context.Context, peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) {
12171
exp := p.clock.Now().Add(p.expiry)
122-
payload := resumeTokenPayload{
123-
PeerID: peerID,
124-
Claims: jwt.Claims{
125-
Expiry: jwt.NewNumericDate(exp),
126-
},
72+
payload := jwt.Claims{
73+
Subject: peerID.String(),
74+
Expiry: jwt.NewNumericDate(exp),
12775
}
12876

12977
token, err := jwtutils.Sign(ctx, p.key, payload)
@@ -142,12 +90,16 @@ func (p ResumeTokenKeyProvider) GenerateResumeToken(ctx context.Context, peerID
14290
// returns the payload. If the token is invalid or expired, an error is
14391
// returned.
14492
func (p ResumeTokenKeyProvider) VerifyResumeToken(ctx context.Context, str string) (uuid.UUID, error) {
145-
var tok resumeTokenPayload
93+
var tok jwt.Claims
14694
err := jwtutils.Verify(ctx, p.key, str, &tok, jwtutils.WithVerifyExpected(jwt.Expected{
14795
Time: p.clock.Now(),
14896
}))
14997
if err != nil {
15098
return uuid.Nil, xerrors.Errorf("verify payload: %w", err)
15199
}
152-
return tok.PeerID, nil
100+
parsed, err := uuid.Parse(tok.Subject)
101+
if err != nil {
102+
return uuid.Nil, xerrors.Errorf("parse peerID from token: %w", err)
103+
}
104+
return parsed, nil
153105
}

tailnet/resume_test.go

+10-105
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,20 @@
11
package tailnet_test
22

33
import (
4-
"context"
5-
"encoding/hex"
64
"testing"
75
"time"
86

7+
"github.com/go-jose/go-jose/v4"
8+
"github.com/go-jose/go-jose/v4/jwt"
99
"github.com/google/uuid"
10-
"github.com/stretchr/testify/assert"
1110
"github.com/stretchr/testify/require"
12-
"go.uber.org/mock/gomock"
1311

14-
"github.com/coder/coder/v2/coderd/database/dbmock"
15-
"github.com/coder/coder/v2/coderd/database/dbtestutil"
1612
"github.com/coder/coder/v2/coderd/jwtutils"
1713
"github.com/coder/coder/v2/tailnet"
1814
"github.com/coder/coder/v2/testutil"
1915
"github.com/coder/quartz"
2016
)
2117

22-
func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {
23-
t.Parallel()
24-
25-
assertRandomKey := func(t *testing.T, key tailnet.ResumeTokenSigningKey) {
26-
t.Helper()
27-
assert.NotEqual(t, tailnet.ResumeTokenSigningKey{}, key, "key should not be empty")
28-
assert.NotEqualValues(t, [64]byte{1}, key, "key should not be all 1s")
29-
}
30-
31-
t.Run("GenerateRetrieve", func(t *testing.T) {
32-
t.Parallel()
33-
34-
db, _ := dbtestutil.NewDB(t)
35-
ctx := testutil.Context(t, testutil.WaitShort)
36-
key1, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
37-
require.NoError(t, err)
38-
assertRandomKey(t, key1)
39-
40-
key2, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
41-
require.NoError(t, err)
42-
require.Equal(t, key1, key2, "keys should not be different")
43-
})
44-
45-
t.Run("GetError", func(t *testing.T) {
46-
t.Parallel()
47-
48-
db := dbmock.NewMockStore(gomock.NewController(t))
49-
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("", assert.AnError)
50-
51-
ctx := testutil.Context(t, testutil.WaitShort)
52-
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
53-
require.ErrorIs(t, err, assert.AnError)
54-
})
55-
56-
t.Run("UpsertError", func(t *testing.T) {
57-
t.Parallel()
58-
59-
db := dbmock.NewMockStore(gomock.NewController(t))
60-
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("", nil)
61-
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(assert.AnError)
62-
63-
ctx := testutil.Context(t, testutil.WaitShort)
64-
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
65-
require.ErrorIs(t, err, assert.AnError)
66-
})
67-
68-
t.Run("DecodeErrorShouldRegenerate", func(t *testing.T) {
69-
t.Parallel()
70-
71-
db := dbmock.NewMockStore(gomock.NewController(t))
72-
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("invalid", nil)
73-
74-
var storedKey tailnet.ResumeTokenSigningKey
75-
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Do(func(_ context.Context, value string) error {
76-
keyBytes, err := hex.DecodeString(value)
77-
require.NoError(t, err)
78-
require.Len(t, keyBytes, len(storedKey))
79-
copy(storedKey[:], keyBytes)
80-
return nil
81-
})
82-
83-
ctx := testutil.Context(t, testutil.WaitShort)
84-
key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
85-
require.NoError(t, err)
86-
assertRandomKey(t, key)
87-
require.Equal(t, storedKey, key, "key should match stored value")
88-
})
89-
90-
t.Run("LengthErrorShouldRegenerate", func(t *testing.T) {
91-
t.Parallel()
92-
93-
db := dbmock.NewMockStore(gomock.NewController(t))
94-
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("deadbeef", nil)
95-
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(nil)
96-
97-
ctx := testutil.Context(t, testutil.WaitShort)
98-
key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
99-
require.NoError(t, err)
100-
assertRandomKey(t, key)
101-
})
102-
103-
t.Run("EmptyError", func(t *testing.T) {
104-
t.Parallel()
105-
106-
db := dbmock.NewMockStore(gomock.NewController(t))
107-
emptyKey := hex.EncodeToString(make([]byte, 64))
108-
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return(emptyKey, nil)
109-
110-
ctx := testutil.Context(t, testutil.WaitShort)
111-
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
112-
require.ErrorContains(t, err, "is empty")
113-
})
114-
}
115-
11618
func TestResumeTokenKeyProvider(t *testing.T) {
11719
t.Parallel()
11820

@@ -156,7 +58,7 @@ func TestResumeTokenKeyProvider(t *testing.T) {
15658
_ = clock.Advance(tailnet.DefaultResumeTokenExpiry + time.Second)
15759

15860
_, err = provider.VerifyResumeToken(ctx, token.Token)
159-
require.ErrorContains(t, err, "expired")
61+
require.ErrorIs(t, err, jwt.ErrExpired)
16062
})
16163

16264
t.Run("InvalidToken", func(t *testing.T) {
@@ -175,17 +77,20 @@ func TestResumeTokenKeyProvider(t *testing.T) {
17577
// Generate a resume token with a different key
17678
otherKey, err := tailnet.GenerateResumeTokenSigningKey()
17779
require.NoError(t, err)
178-
otherProvider := tailnet.NewResumeTokenKeyProvider(newKeySigner(otherKey), quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
80+
otherSigner := newKeySigner(otherKey)
81+
otherProvider := tailnet.NewResumeTokenKeyProvider(otherSigner, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
17982
token, err := otherProvider.GenerateResumeToken(ctx, uuid.New())
18083
require.NoError(t, err)
18184

182-
provider := tailnet.NewResumeTokenKeyProvider(newKeySigner(key), quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
85+
signer := newKeySigner(key)
86+
signer.ID = otherSigner.ID
87+
provider := tailnet.NewResumeTokenKeyProvider(signer, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
18388
_, err = provider.VerifyResumeToken(ctx, token.Token)
184-
require.ErrorContains(t, err, "verify JWS")
89+
require.ErrorIs(t, err, jose.ErrCryptoFailure)
18590
})
18691
}
18792

188-
func newKeySigner(key tailnet.ResumeTokenSigningKey) jwtutils.SigningKeyManager {
93+
func newKeySigner(key tailnet.ResumeTokenSigningKey) jwtutils.StaticKeyManager {
18994
return jwtutils.StaticKeyManager{
19095
ID: uuid.New().String(),
19196
Key: key[:],

0 commit comments

Comments
 (0)