Skip to content

feat: add resume support to coordinator connections #14234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
PR comments 4
  • Loading branch information
deansheather committed Aug 20, 2024
commit a81eac1e38d5890cef46c851623f9e16fb599ee9
44 changes: 24 additions & 20 deletions coderd/coderdtest/coderdtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,25 +96,26 @@ type Options struct {
// AccessURL denotes a custom access URL. By default we use the httptest
// server's URL. Setting this may result in unexpected behavior (especially
// with running agents).
AccessURL *url.URL
AppHostname string
AWSCertificates awsidentity.Certificates
Authorizer rbac.Authorizer
AzureCertificates x509.VerifyOptions
GithubOAuth2Config *coderd.GithubOAuth2Config
RealIPConfig *httpmw.RealIPConfig
OIDCConfig *coderd.OIDCConfig
GoogleTokenValidator *idtoken.Validator
SSHKeygenAlgorithm gitsshkey.Algorithm
AutobuildTicker <-chan time.Time
AutobuildStats chan<- autobuild.Stats
Auditor audit.Auditor
TLSCertificates []tls.Certificate
ExternalAuthConfigs []*externalauth.Config
TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error
RefreshEntitlements func(ctx context.Context) error
TemplateScheduleStore schedule.TemplateScheduleStore
Coordinator tailnet.Coordinator
AccessURL *url.URL
AppHostname string
AWSCertificates awsidentity.Certificates
Authorizer rbac.Authorizer
AzureCertificates x509.VerifyOptions
GithubOAuth2Config *coderd.GithubOAuth2Config
RealIPConfig *httpmw.RealIPConfig
OIDCConfig *coderd.OIDCConfig
GoogleTokenValidator *idtoken.Validator
SSHKeygenAlgorithm gitsshkey.Algorithm
AutobuildTicker <-chan time.Time
AutobuildStats chan<- autobuild.Stats
Auditor audit.Auditor
TLSCertificates []tls.Certificate
ExternalAuthConfigs []*externalauth.Config
TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error
RefreshEntitlements func(ctx context.Context) error
TemplateScheduleStore schedule.TemplateScheduleStore
Coordinator tailnet.Coordinator
CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider

HealthcheckFunc func(ctx context.Context, apiKey string) *healthsdk.HealthcheckReport
HealthcheckTimeout time.Duration
Expand Down Expand Up @@ -240,6 +241,9 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
if options.Database == nil {
options.Database, options.Pubsub = dbtestutil.NewDB(t)
}
if options.CoordinatorResumeTokenProvider == nil {
options.CoordinatorResumeTokenProvider = tailnet.NewInsecureTestResumeTokenProvider()
}

if options.NotificationsEnqueuer == nil {
options.NotificationsEnqueuer = new(testutil.FakeNotificationsEnqueuer)
Expand Down Expand Up @@ -492,7 +496,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
TailnetCoordinator: options.Coordinator,
BaseDERPMap: derpMap,
DERPMapUpdateFrequency: 150 * time.Millisecond,
CoordinatorResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
CoordinatorResumeTokenProvider: options.CoordinatorResumeTokenProvider,
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
DeploymentValues: options.DeploymentValues,
Expand Down
2 changes: 1 addition & 1 deletion coderd/database/dbauthz/dbauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -1367,7 +1367,7 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI
}

func (q *querier) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return "", err
}
return q.db.GetCoordinatorResumeTokenSigningKey(ctx)
Expand Down
2 changes: 1 addition & 1 deletion coderd/database/dbauthz/dbauthz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2567,7 +2567,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
}))
s.Run("GetCoordinatorResumeTokenSigningKey", s.Subtest(func(db database.Store, check *expects) {
db.UpsertCoordinatorResumeTokenSigningKey(context.Background(), "foo")
check.Args().Asserts(rbac.ResourceSystem, policy.ActionUpdate)
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("InsertMissingGroups", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertMissingGroupsParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate).Errors(errMatchAny)
Expand Down
10 changes: 9 additions & 1 deletion coderd/workspaceagents_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)

func TestWorkspaceAgent(t *testing.T) {
Expand Down Expand Up @@ -533,11 +534,16 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
t.Parallel()

logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clock := quartz.NewMock(t)
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
require.NoError(t, err)
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
coordinator := &resumeTokenTestFakeCoordinator{
Coordinator: tailnet.NewCoordinator(logger),
}
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
Coordinator: coordinator,
Coordinator: coordinator,
CoordinatorResumeTokenProvider: resumeTokenProvider,
})
defer closer.Close()
user := coderdtest.CreateFirstUser(t, client)
Expand All @@ -564,6 +570,7 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {

// Connect with a valid resume token, and ensure that the peer ID is set to
// the stored value.
clock.Advance(time.Second)
coordinator.lastPeerID = uuid.Nil
newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken)
require.NoError(t, err)
Expand All @@ -572,6 +579,7 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {

// Connect with an invalid resume token, and ensure that the request is
// rejected.
clock.Advance(time.Second)
coordinator.lastPeerID = uuid.Nil
_, err = connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "invalid")
require.Error(t, err)
Expand Down
5 changes: 2 additions & 3 deletions codersdk/workspacesdk/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
close(tac.connected)
}
if err != nil {
didLog := false
bodyErr := codersdk.ReadBodyAsError(res)
var sdkErr *codersdk.Error
if xerrors.As(bodyErr, &sdkErr) {
Expand All @@ -191,11 +190,11 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
// Unset the resume token for the next attempt
tac.logger.Warn(tac.ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt")
tac.resumeToken = nil
didLog = true
return nil, err
}
}
}
if !didLog && !errors.Is(err, context.Canceled) {
if !errors.Is(err, context.Canceled) {
tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", slog.Error(err), slog.F("sdk_err", sdkErr))
}
return nil, err
Expand Down
25 changes: 12 additions & 13 deletions codersdk/workspacesdk/connector_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) {
derpMapCh := make(chan *tailcfg.DERPMap)
defer close(derpMapCh)

resumeTokenProvider := tailnet.NewInsecureTestResumeTokenProvider()
clock := quartz.NewMock(t)
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
require.NoError(t, err)
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger,
CoordPtr: &coordPtr,
Expand Down Expand Up @@ -215,18 +218,16 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) {

fConn := newFakeTailnetConn()

clock := quartz.NewMock(t)
newTickerTrap := clock.Trap().NewTicker("tailnetAPIConnector", "refreshToken")
tickerResetTrap := clock.Trap().TickerReset("tailnetAPIConnector", "refreshToken", "reset")
defer newTickerTrap.Close()
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, clock, &websocket.DialOptions{})
uut.runConnector(fConn)

// Fetch first token.
// Fetch first token. We don't need to advance the clock because a channel
// holding a single item triggers the first fetch immediately.
trappedTicker := newTickerTrap.MustWait(ctx)
trappedTicker.Release()
waiter := clock.Advance(trappedTicker.Duration)
waiter.MustWait(ctx)
// We call ticker.Reset after each token fetch to apply the refresh duration
// requested by the server.
trappedReset := tickerResetTrap.MustWait(ctx)
Expand All @@ -235,7 +236,7 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) {
originalResumeToken := uut.resumeToken.Token

// Fetch second token.
waiter = clock.Advance(trappedReset.Duration)
waiter := clock.Advance(trappedReset.Duration)
waiter.MustWait(ctx)
trappedReset = tickerResetTrap.MustWait(ctx)
trappedReset.Release()
Expand Down Expand Up @@ -275,13 +276,17 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
derpMapCh := make(chan *tailcfg.DERPMap)
defer close(derpMapCh)

clock := quartz.NewMock(t)
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
require.NoError(t, err)
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger,
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Millisecond,
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
ResumeTokenProvider: resumeTokenProvider,
})
require.NoError(t, err)

Expand Down Expand Up @@ -317,7 +322,6 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {

fConn := newFakeTailnetConn()

clock := quartz.NewMock(t)
newTickerTrap := clock.Trap().NewTicker("tailnetAPIConnector", "refreshToken")
tickerResetTrap := clock.Trap().TickerReset("tailnetAPIConnector", "refreshToken", "reset")
defer newTickerTrap.Close()
Expand All @@ -327,8 +331,6 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
// Wait for the resume token to be fetched for the first time.
trappedTicker := newTickerTrap.MustWait(ctx)
trappedTicker.Release()
waiter := clock.Advance(trappedTicker.Duration)
waiter.MustWait(ctx)
trappedReset := tickerResetTrap.MustWait(ctx)
trappedReset.Release()
originalResumeToken := uut.resumeToken.Token
Expand All @@ -346,9 +348,6 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
// Since we failed the initial reconnect and we're definitely reconnected
// now, the stored resume token should now be nil.
require.Nil(t, uut.resumeToken)
// Continue to the next token fetch.
waiter = clock.Advance(trappedTicker.Duration)
waiter.MustWait(ctx)
trappedReset = tickerResetTrap.MustWait(ctx)
trappedReset.Release()
require.NotNil(t, uut.resumeToken)
Expand Down
12 changes: 7 additions & 5 deletions tailnet/resume.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,14 @@ func NewResumeTokenKeyProvider(key ResumeTokenSigningKey, clock quartz.Clock, ex

type resumeTokenPayload struct {
PeerID uuid.UUID `json:"sub"`
Expiry time.Time `json:"exp"`
Expiry int64 `json:"exp"`
}

func (p ResumeTokenKeyProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) {
exp := p.clock.Now().Add(p.expiry)
payload := resumeTokenPayload{
PeerID: peerID,
Expiry: p.clock.Now().Add(p.expiry),
Expiry: exp.Unix(),
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
Expand Down Expand Up @@ -154,11 +155,11 @@ func (p ResumeTokenKeyProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.Re
return &proto.RefreshResumeTokenResponse{
Token: serialized,
RefreshIn: durationpb.New(p.expiry / 2),
ExpiresAt: timestamppb.New(payload.Expiry),
ExpiresAt: timestamppb.New(exp),
}, nil
}

// VerifySignedToken parses a signed workspace app token with the given key and
// VerifyResumeToken parses a signed tailnet resume token with the given key and
// returns the peer ID from its payload. If the token is invalid or expired, an error is
// returned.
func (p ResumeTokenKeyProvider) VerifyResumeToken(str string) (uuid.UUID, error) {
Expand Down Expand Up @@ -186,7 +187,8 @@ func (p ResumeTokenKeyProvider) VerifyResumeToken(str string) (uuid.UUID, error)
if err != nil {
return uuid.Nil, xerrors.Errorf("unmarshal payload: %w", err)
}
if tok.Expiry.Before(p.clock.Now()) {
exp := time.Unix(tok.Expiry, 0)
if exp.Before(p.clock.Now()) {
return uuid.Nil, xerrors.New("signed resume token expired")
}

Expand Down
28 changes: 15 additions & 13 deletions tailnet/resume_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tailnet_test

import (
"context"
"encoding/hex"
"testing"
"time"
Expand All @@ -22,8 +23,8 @@ func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {

assertRandomKey := func(t *testing.T, key tailnet.ResumeTokenSigningKey) {
t.Helper()
assert.NotEqual(t, tailnet.ResumeTokenSigningKey{}, key, "key is empty")
assert.NotEqualValues(t, [64]byte{1}, key, "key is all 1s")
assert.NotEqual(t, tailnet.ResumeTokenSigningKey{}, key, "key should not be empty")
assert.NotEqualValues(t, [64]byte{1}, key, "key should not be all 1s")
}

t.Run("GenerateRetrieve", func(t *testing.T) {
Expand All @@ -37,7 +38,7 @@ func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {

key2, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.NoError(t, err)
require.Equal(t, key1, key2, "keys are different")
require.Equal(t, key1, key2, "keys should not be different")
})

t.Run("GetError", func(t *testing.T) {
Expand All @@ -48,7 +49,6 @@ func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {

ctx := testutil.Context(t, testutil.WaitShort)
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.Error(t, err)
require.ErrorIs(t, err, assert.AnError)
})

Expand All @@ -61,7 +61,6 @@ func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {

ctx := testutil.Context(t, testutil.WaitShort)
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.Error(t, err)
require.ErrorIs(t, err, assert.AnError)
})

Expand All @@ -70,12 +69,21 @@ func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {

db := dbmock.NewMockStore(gomock.NewController(t))
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("invalid", nil)
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(nil)

var storedKey tailnet.ResumeTokenSigningKey
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Do(func(_ context.Context, value string) error {
keyBytes, err := hex.DecodeString(value)
require.NoError(t, err)
require.Len(t, keyBytes, len(storedKey))
copy(storedKey[:], keyBytes)
return nil
})

ctx := testutil.Context(t, testutil.WaitShort)
key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.NoError(t, err)
assertRandomKey(t, key)
require.Equal(t, storedKey, key, "key should match stored value")
})

t.Run("LengthErrorShouldRegenerate", func(t *testing.T) {
Expand All @@ -100,7 +108,6 @@ func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {

ctx := testutil.Context(t, testutil.WaitShort)
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.Error(t, err)
require.ErrorContains(t, err, "is empty")
})
}
Expand All @@ -115,16 +122,14 @@ func TestResumeTokenKeyProvider(t *testing.T) {
t.Parallel()

id := uuid.New()
now := time.Now()
clock := quartz.NewMock(t)
clock.Set(now)
provider := tailnet.NewResumeTokenKeyProvider(key, clock, tailnet.DefaultResumeTokenExpiry)
token, err := provider.GenerateResumeToken(id)
require.NoError(t, err)
require.NotNil(t, token)
require.NotEmpty(t, token.Token)
require.Equal(t, tailnet.DefaultResumeTokenExpiry/2, token.RefreshIn.AsDuration())
require.WithinDuration(t, now.Add(tailnet.DefaultResumeTokenExpiry), token.ExpiresAt.AsTime(), time.Second)
require.WithinDuration(t, clock.Now().Add(tailnet.DefaultResumeTokenExpiry), token.ExpiresAt.AsTime(), time.Second)

gotID, err := provider.VerifyResumeToken(token.Token)
require.NoError(t, err)
Expand All @@ -150,7 +155,6 @@ func TestResumeTokenKeyProvider(t *testing.T) {
_ = clock.Advance(tailnet.DefaultResumeTokenExpiry + time.Second)

_, err = provider.VerifyResumeToken(token.Token)
require.Error(t, err)
require.ErrorContains(t, err, "expired")
})

Expand All @@ -159,7 +163,6 @@ func TestResumeTokenKeyProvider(t *testing.T) {

provider := tailnet.NewResumeTokenKeyProvider(key, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
_, err := provider.VerifyResumeToken("invalid")
require.Error(t, err)
require.ErrorContains(t, err, "parse JWS")
})

Expand All @@ -175,7 +178,6 @@ func TestResumeTokenKeyProvider(t *testing.T) {

provider := tailnet.NewResumeTokenKeyProvider(key, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
_, err = provider.VerifyResumeToken(token.Token)
require.Error(t, err)
require.ErrorContains(t, err, "verify JWS")
})
}
Loading