Skip to content

Commit a81eac1

Browse files
committed
PR comments 4
1 parent 3db51e0 commit a81eac1

File tree

8 files changed

+71
-57
lines changed

8 files changed

+71
-57
lines changed

coderd/coderdtest/coderdtest.go

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -96,25 +96,26 @@ type Options struct {
9696
// AccessURL denotes a custom access URL. By default we use the httptest
9797
// server's URL. Setting this may result in unexpected behavior (especially
9898
// with running agents).
99-
AccessURL *url.URL
100-
AppHostname string
101-
AWSCertificates awsidentity.Certificates
102-
Authorizer rbac.Authorizer
103-
AzureCertificates x509.VerifyOptions
104-
GithubOAuth2Config *coderd.GithubOAuth2Config
105-
RealIPConfig *httpmw.RealIPConfig
106-
OIDCConfig *coderd.OIDCConfig
107-
GoogleTokenValidator *idtoken.Validator
108-
SSHKeygenAlgorithm gitsshkey.Algorithm
109-
AutobuildTicker <-chan time.Time
110-
AutobuildStats chan<- autobuild.Stats
111-
Auditor audit.Auditor
112-
TLSCertificates []tls.Certificate
113-
ExternalAuthConfigs []*externalauth.Config
114-
TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error
115-
RefreshEntitlements func(ctx context.Context) error
116-
TemplateScheduleStore schedule.TemplateScheduleStore
117-
Coordinator tailnet.Coordinator
99+
AccessURL *url.URL
100+
AppHostname string
101+
AWSCertificates awsidentity.Certificates
102+
Authorizer rbac.Authorizer
103+
AzureCertificates x509.VerifyOptions
104+
GithubOAuth2Config *coderd.GithubOAuth2Config
105+
RealIPConfig *httpmw.RealIPConfig
106+
OIDCConfig *coderd.OIDCConfig
107+
GoogleTokenValidator *idtoken.Validator
108+
SSHKeygenAlgorithm gitsshkey.Algorithm
109+
AutobuildTicker <-chan time.Time
110+
AutobuildStats chan<- autobuild.Stats
111+
Auditor audit.Auditor
112+
TLSCertificates []tls.Certificate
113+
ExternalAuthConfigs []*externalauth.Config
114+
TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error
115+
RefreshEntitlements func(ctx context.Context) error
116+
TemplateScheduleStore schedule.TemplateScheduleStore
117+
Coordinator tailnet.Coordinator
118+
CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider
118119

119120
HealthcheckFunc func(ctx context.Context, apiKey string) *healthsdk.HealthcheckReport
120121
HealthcheckTimeout time.Duration
@@ -240,6 +241,9 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
240241
if options.Database == nil {
241242
options.Database, options.Pubsub = dbtestutil.NewDB(t)
242243
}
244+
if options.CoordinatorResumeTokenProvider == nil {
245+
options.CoordinatorResumeTokenProvider = tailnet.NewInsecureTestResumeTokenProvider()
246+
}
243247

244248
if options.NotificationsEnqueuer == nil {
245249
options.NotificationsEnqueuer = new(testutil.FakeNotificationsEnqueuer)
@@ -492,7 +496,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
492496
TailnetCoordinator: options.Coordinator,
493497
BaseDERPMap: derpMap,
494498
DERPMapUpdateFrequency: 150 * time.Millisecond,
495-
CoordinatorResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
499+
CoordinatorResumeTokenProvider: options.CoordinatorResumeTokenProvider,
496500
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
497501
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
498502
DeploymentValues: options.DeploymentValues,

coderd/database/dbauthz/dbauthz.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1367,7 +1367,7 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI
13671367
}
13681368

13691369
func (q *querier) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) {
1370-
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
1370+
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
13711371
return "", err
13721372
}
13731373
return q.db.GetCoordinatorResumeTokenSigningKey(ctx)

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2567,7 +2567,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
25672567
}))
25682568
s.Run("GetCoordinatorResumeTokenSigningKey", s.Subtest(func(db database.Store, check *expects) {
25692569
db.UpsertCoordinatorResumeTokenSigningKey(context.Background(), "foo")
2570-
check.Args().Asserts(rbac.ResourceSystem, policy.ActionUpdate)
2570+
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead)
25712571
}))
25722572
s.Run("InsertMissingGroups", s.Subtest(func(db database.Store, check *expects) {
25732573
check.Args(database.InsertMissingGroupsParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate).Errors(errMatchAny)

coderd/workspaceagents_test.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ import (
4545
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
4646
"github.com/coder/coder/v2/tailnet/tailnettest"
4747
"github.com/coder/coder/v2/testutil"
48+
"github.com/coder/quartz"
4849
)
4950

5051
func TestWorkspaceAgent(t *testing.T) {
@@ -533,11 +534,16 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
533534
t.Parallel()
534535

535536
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
537+
clock := quartz.NewMock(t)
538+
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
539+
require.NoError(t, err)
540+
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
536541
coordinator := &resumeTokenTestFakeCoordinator{
537542
Coordinator: tailnet.NewCoordinator(logger),
538543
}
539544
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
540-
Coordinator: coordinator,
545+
Coordinator: coordinator,
546+
CoordinatorResumeTokenProvider: resumeTokenProvider,
541547
})
542548
defer closer.Close()
543549
user := coderdtest.CreateFirstUser(t, client)
@@ -564,6 +570,7 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
564570

565571
// Connect with a valid resume token, and ensure that the peer ID is set to
566572
// the stored value.
573+
clock.Advance(time.Second)
567574
coordinator.lastPeerID = uuid.Nil
568575
newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken)
569576
require.NoError(t, err)
@@ -572,6 +579,7 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
572579

573580
// Connect with an invalid resume token, and ensure that the request is
574581
// rejected.
582+
clock.Advance(time.Second)
575583
coordinator.lastPeerID = uuid.Nil
576584
_, err = connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "invalid")
577585
require.Error(t, err)

codersdk/workspacesdk/connector.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,6 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
182182
close(tac.connected)
183183
}
184184
if err != nil {
185-
didLog := false
186185
bodyErr := codersdk.ReadBodyAsError(res)
187186
var sdkErr *codersdk.Error
188187
if xerrors.As(bodyErr, &sdkErr) {
@@ -191,11 +190,11 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
191190
// Unset the resume token for the next attempt
192191
tac.logger.Warn(tac.ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt")
193192
tac.resumeToken = nil
194-
didLog = true
193+
return nil, err
195194
}
196195
}
197196
}
198-
if !didLog && !errors.Is(err, context.Canceled) {
197+
if !errors.Is(err, context.Canceled) {
199198
tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", slog.Error(err), slog.F("sdk_err", sdkErr))
200199
}
201200
return nil, err

codersdk/workspacesdk/connector_internal_test.go

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,10 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) {
160160
derpMapCh := make(chan *tailcfg.DERPMap)
161161
defer close(derpMapCh)
162162

163-
resumeTokenProvider := tailnet.NewInsecureTestResumeTokenProvider()
163+
clock := quartz.NewMock(t)
164+
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
165+
require.NoError(t, err)
166+
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
164167
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
165168
Logger: logger,
166169
CoordPtr: &coordPtr,
@@ -215,18 +218,16 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) {
215218

216219
fConn := newFakeTailnetConn()
217220

218-
clock := quartz.NewMock(t)
219221
newTickerTrap := clock.Trap().NewTicker("tailnetAPIConnector", "refreshToken")
220222
tickerResetTrap := clock.Trap().TickerReset("tailnetAPIConnector", "refreshToken", "reset")
221223
defer newTickerTrap.Close()
222224
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, clock, &websocket.DialOptions{})
223225
uut.runConnector(fConn)
224226

225-
// Fetch first token.
227+
// Fetch first token. We don't need to advance the clock since we use a
228+
// channel with a single item to immediately fetch.
226229
trappedTicker := newTickerTrap.MustWait(ctx)
227230
trappedTicker.Release()
228-
waiter := clock.Advance(trappedTicker.Duration)
229-
waiter.MustWait(ctx)
230231
// We call ticker.Reset after each token fetch to apply the refresh duration
231232
// requested by the server.
232233
trappedReset := tickerResetTrap.MustWait(ctx)
@@ -235,7 +236,7 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) {
235236
originalResumeToken := uut.resumeToken.Token
236237

237238
// Fetch second token.
238-
waiter = clock.Advance(trappedReset.Duration)
239+
waiter := clock.Advance(trappedReset.Duration)
239240
waiter.MustWait(ctx)
240241
trappedReset = tickerResetTrap.MustWait(ctx)
241242
trappedReset.Release()
@@ -275,13 +276,17 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
275276
derpMapCh := make(chan *tailcfg.DERPMap)
276277
defer close(derpMapCh)
277278

279+
clock := quartz.NewMock(t)
280+
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
281+
require.NoError(t, err)
282+
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
278283
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
279284
Logger: logger,
280285
CoordPtr: &coordPtr,
281286
DERPMapUpdateFrequency: time.Millisecond,
282287
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
283288
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
284-
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
289+
ResumeTokenProvider: resumeTokenProvider,
285290
})
286291
require.NoError(t, err)
287292

@@ -317,7 +322,6 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
317322

318323
fConn := newFakeTailnetConn()
319324

320-
clock := quartz.NewMock(t)
321325
newTickerTrap := clock.Trap().NewTicker("tailnetAPIConnector", "refreshToken")
322326
tickerResetTrap := clock.Trap().TickerReset("tailnetAPIConnector", "refreshToken", "reset")
323327
defer newTickerTrap.Close()
@@ -327,8 +331,6 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
327331
// Wait for the resume token to be fetched for the first time.
328332
trappedTicker := newTickerTrap.MustWait(ctx)
329333
trappedTicker.Release()
330-
waiter := clock.Advance(trappedTicker.Duration)
331-
waiter.MustWait(ctx)
332334
trappedReset := tickerResetTrap.MustWait(ctx)
333335
trappedReset.Release()
334336
originalResumeToken := uut.resumeToken.Token
@@ -346,9 +348,6 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
346348
// Since we failed the initial reconnect and we're definitely reconnected
347349
// now, the stored resume token should now be nil.
348350
require.Nil(t, uut.resumeToken)
349-
// Continue to the next token fetch.
350-
waiter = clock.Advance(trappedTicker.Duration)
351-
waiter.MustWait(ctx)
352351
trappedReset = tickerResetTrap.MustWait(ctx)
353352
trappedReset.Release()
354353
require.NotNil(t, uut.resumeToken)

tailnet/resume.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,14 @@ func NewResumeTokenKeyProvider(key ResumeTokenSigningKey, clock quartz.Clock, ex
116116

117117
type resumeTokenPayload struct {
118118
PeerID uuid.UUID `json:"sub"`
119-
Expiry time.Time `json:"exp"`
119+
Expiry int64 `json:"exp"`
120120
}
121121

122122
func (p ResumeTokenKeyProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) {
123+
exp := p.clock.Now().Add(p.expiry)
123124
payload := resumeTokenPayload{
124125
PeerID: peerID,
125-
Expiry: p.clock.Now().Add(p.expiry),
126+
Expiry: exp.Unix(),
126127
}
127128
payloadBytes, err := json.Marshal(payload)
128129
if err != nil {
@@ -154,11 +155,11 @@ func (p ResumeTokenKeyProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.Re
154155
return &proto.RefreshResumeTokenResponse{
155156
Token: serialized,
156157
RefreshIn: durationpb.New(p.expiry / 2),
157-
ExpiresAt: timestamppb.New(payload.Expiry),
158+
ExpiresAt: timestamppb.New(exp),
158159
}, nil
159160
}
160161

161-
// VerifySignedToken parses a signed workspace app token with the given key and
162+
// VerifyResumeToken parses a signed tailnet resume token with the given key and
162163
// returns the payload. If the token is invalid or expired, an error is
163164
// returned.
164165
func (p ResumeTokenKeyProvider) VerifyResumeToken(str string) (uuid.UUID, error) {
@@ -186,7 +187,8 @@ func (p ResumeTokenKeyProvider) VerifyResumeToken(str string) (uuid.UUID, error)
186187
if err != nil {
187188
return uuid.Nil, xerrors.Errorf("unmarshal payload: %w", err)
188189
}
189-
if tok.Expiry.Before(p.clock.Now()) {
190+
exp := time.Unix(tok.Expiry, 0)
191+
if exp.Before(p.clock.Now()) {
190192
return uuid.Nil, xerrors.New("signed resume token expired")
191193
}
192194

tailnet/resume_test.go

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package tailnet_test
22

33
import (
4+
"context"
45
"encoding/hex"
56
"testing"
67
"time"
@@ -22,8 +23,8 @@ func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {
2223

2324
assertRandomKey := func(t *testing.T, key tailnet.ResumeTokenSigningKey) {
2425
t.Helper()
25-
assert.NotEqual(t, tailnet.ResumeTokenSigningKey{}, key, "key is empty")
26-
assert.NotEqualValues(t, [64]byte{1}, key, "key is all 1s")
26+
assert.NotEqual(t, tailnet.ResumeTokenSigningKey{}, key, "key should not be empty")
27+
assert.NotEqualValues(t, [64]byte{1}, key, "key should not be all 1s")
2728
}
2829

2930
t.Run("GenerateRetrieve", func(t *testing.T) {
@@ -37,7 +38,7 @@ func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {
3738

3839
key2, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
3940
require.NoError(t, err)
40-
require.Equal(t, key1, key2, "keys are different")
41+
require.Equal(t, key1, key2, "keys should not be different")
4142
})
4243

4344
t.Run("GetError", func(t *testing.T) {
@@ -48,7 +49,6 @@ func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {
4849

4950
ctx := testutil.Context(t, testutil.WaitShort)
5051
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
51-
require.Error(t, err)
5252
require.ErrorIs(t, err, assert.AnError)
5353
})
5454

@@ -61,7 +61,6 @@ func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {
6161

6262
ctx := testutil.Context(t, testutil.WaitShort)
6363
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
64-
require.Error(t, err)
6564
require.ErrorIs(t, err, assert.AnError)
6665
})
6766

@@ -70,12 +69,21 @@ func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {
7069

7170
db := dbmock.NewMockStore(gomock.NewController(t))
7271
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("invalid", nil)
73-
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(nil)
72+
73+
var storedKey tailnet.ResumeTokenSigningKey
74+
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Do(func(_ context.Context, value string) error {
75+
keyBytes, err := hex.DecodeString(value)
76+
require.NoError(t, err)
77+
require.Len(t, keyBytes, len(storedKey))
78+
copy(storedKey[:], keyBytes)
79+
return nil
80+
})
7481

7582
ctx := testutil.Context(t, testutil.WaitShort)
7683
key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
7784
require.NoError(t, err)
7885
assertRandomKey(t, key)
86+
require.Equal(t, storedKey, key, "key should match stored value")
7987
})
8088

8189
t.Run("LengthErrorShouldRegenerate", func(t *testing.T) {
@@ -100,7 +108,6 @@ func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {
100108

101109
ctx := testutil.Context(t, testutil.WaitShort)
102110
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
103-
require.Error(t, err)
104111
require.ErrorContains(t, err, "is empty")
105112
})
106113
}
@@ -115,16 +122,14 @@ func TestResumeTokenKeyProvider(t *testing.T) {
115122
t.Parallel()
116123

117124
id := uuid.New()
118-
now := time.Now()
119125
clock := quartz.NewMock(t)
120-
clock.Set(now)
121126
provider := tailnet.NewResumeTokenKeyProvider(key, clock, tailnet.DefaultResumeTokenExpiry)
122127
token, err := provider.GenerateResumeToken(id)
123128
require.NoError(t, err)
124129
require.NotNil(t, token)
125130
require.NotEmpty(t, token.Token)
126131
require.Equal(t, tailnet.DefaultResumeTokenExpiry/2, token.RefreshIn.AsDuration())
127-
require.WithinDuration(t, now.Add(tailnet.DefaultResumeTokenExpiry), token.ExpiresAt.AsTime(), time.Second)
132+
require.WithinDuration(t, clock.Now().Add(tailnet.DefaultResumeTokenExpiry), token.ExpiresAt.AsTime(), time.Second)
128133

129134
gotID, err := provider.VerifyResumeToken(token.Token)
130135
require.NoError(t, err)
@@ -150,7 +155,6 @@ func TestResumeTokenKeyProvider(t *testing.T) {
150155
_ = clock.Advance(tailnet.DefaultResumeTokenExpiry + time.Second)
151156

152157
_, err = provider.VerifyResumeToken(token.Token)
153-
require.Error(t, err)
154158
require.ErrorContains(t, err, "expired")
155159
})
156160

@@ -159,7 +163,6 @@ func TestResumeTokenKeyProvider(t *testing.T) {
159163

160164
provider := tailnet.NewResumeTokenKeyProvider(key, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
161165
_, err := provider.VerifyResumeToken("invalid")
162-
require.Error(t, err)
163166
require.ErrorContains(t, err, "parse JWS")
164167
})
165168

@@ -175,7 +178,6 @@ func TestResumeTokenKeyProvider(t *testing.T) {
175178

176179
provider := tailnet.NewResumeTokenKeyProvider(key, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
177180
_, err = provider.VerifyResumeToken(token.Token)
178-
require.Error(t, err)
179181
require.ErrorContains(t, err, "verify JWS")
180182
})
181183
}

0 commit comments

Comments
 (0)