Skip to content

feat: add resume support to coordinator connections #14234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
PR comments
  • Loading branch information
deansheather committed Aug 13, 2024
commit 904fde2cce886b9ffd057d71bd3628b9c9393b7f
30 changes: 2 additions & 28 deletions cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -805,35 +805,9 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.

// Read the coordinator resume token signing key from the
// database.
resumeTokenKey := [64]byte{}
resumeTokenKeyStr, err := tx.GetCoordinatorResumeTokenSigningKey(ctx)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("get coordinator resume token key: %w", err)
}
if decoded, err := hex.DecodeString(resumeTokenKeyStr); err != nil || len(decoded) != len(resumeTokenKey) {
b := make([]byte, len(resumeTokenKey))
_, err := rand.Read(b)
if err != nil {
return xerrors.Errorf("generate fresh coordinator resume token key: %w", err)
}

resumeTokenKeyStr = hex.EncodeToString(b)
err = tx.UpsertCoordinatorResumeTokenSigningKey(ctx, resumeTokenKeyStr)
if err != nil {
return xerrors.Errorf("insert freshly generated coordinator resume token key to database: %w", err)
}
}

resumeTokenKeyBytes, err := hex.DecodeString(resumeTokenKeyStr)
resumeTokenKey, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, tx)
if err != nil {
return xerrors.Errorf("decode coordinator resume token key from database: %w", err)
}
if len(resumeTokenKeyBytes) != len(resumeTokenKey) {
return xerrors.Errorf("coordinator resume token key in database is not the correct length, expect %d got %d", len(resumeTokenKey), len(resumeTokenKeyBytes))
}
copy(resumeTokenKey[:], resumeTokenKeyBytes)
if resumeTokenKey == [64]byte{} {
return xerrors.Errorf("coordinator resume token key in database is empty")
return xerrors.Errorf("get coordinator resume token key from database: %w", err)
}
options.CoordinatorResumeTokenProvider = tailnet.NewResumeTokenKeyProvider(resumeTokenKey, tailnet.DefaultResumeTokenExpiry)

Expand Down
2 changes: 1 addition & 1 deletion coderd/coderdtest/coderdtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
TailnetCoordinator: options.Coordinator,
BaseDERPMap: derpMap,
DERPMapUpdateFrequency: 150 * time.Millisecond,
CoordinatorResumeTokenProvider: tailnet.InsecureTestResumeTokenProvider,
CoordinatorResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
DeploymentValues: options.DeploymentValues,
Expand Down
7 changes: 5 additions & 2 deletions coderd/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -853,11 +853,14 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
)
if resumeToken != "" {
var err error
peerID, err = api.Options.CoordinatorResumeTokenProvider.ParseResumeToken(resumeToken)
peerID, err = api.Options.CoordinatorResumeTokenProvider.VerifyResumeToken(resumeToken)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{
Message: workspacesdk.CoordinateAPIInvalidResumeToken,
Detail: err.Error(),
Validations: []codersdk.ValidationError{
{Field: "resume_token", Detail: workspacesdk.CoordinateAPIInvalidResumeToken},
},
})
return
}
Expand Down
20 changes: 11 additions & 9 deletions codersdk/workspacesdk/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,20 +178,22 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
close(tac.connected)
}
if err != nil {
if !errors.Is(err, context.Canceled) {
tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", slog.Error(err))
}
if res.StatusCode == http.StatusBadRequest {
err = codersdk.ReadBodyAsError(res)
var sdkErr *codersdk.Error
if xerrors.As(err, &sdkErr) {
if sdkErr.Message == CoordinateAPIInvalidResumeToken {
didLog := false
bodyErr := codersdk.ReadBodyAsError(res)
var sdkErr *codersdk.Error
if xerrors.As(bodyErr, &sdkErr) {
for _, v := range sdkErr.Validations {
if v.Field == "resume_token" {
// Unset the resume token for the next attempt
tac.logger.Debug(tac.ctx, "server replied invalid resume token; unsetting for next connection attempt")
tac.logger.Warn(tac.ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt")
tac.resumeToken.Store(nil)
didLog = true
}
}
}
if !didLog && !errors.Is(err, context.Canceled) {
tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", slog.Error(err), slog.F("sdk_err", sdkErr))
}
return nil, err
}
client, err := tailnet.NewDRPCClient(
Expand Down
22 changes: 14 additions & 8 deletions codersdk/workspacesdk/connector_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
DERPMapUpdateFrequency: time.Millisecond,
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
ResumeTokenProvider: tailnet.InsecureTestResumeTokenProvider,
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)

Expand Down Expand Up @@ -185,12 +185,15 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) {
t.Logf("received resume token: %s", resumeToken)
assert.Equal(t, expectResumeToken, resumeToken)
if resumeToken != "" {
peerID, err = resumeTokenProvider.ParseResumeToken(resumeToken)
peerID, err = resumeTokenProvider.VerifyResumeToken(resumeToken)
assert.NoError(t, err, "failed to parse resume token")
if err != nil {
httpapi.Write(ctx, w, http.StatusBadRequest, codersdk.Response{
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{
Message: CoordinateAPIInvalidResumeToken,
Detail: err.Error(),
Validations: []codersdk.ValidationError{
{Field: "resume_token", Detail: CoordinateAPIInvalidResumeToken},
},
})
return
}
Expand Down Expand Up @@ -253,8 +256,8 @@ func (r resumeTokenProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.Refre
return r.genFn(peerID)
}

// ParseResumeToken implements tailnet.ResumeTokenProvider.
func (r resumeTokenProvider) ParseResumeToken(token string) (uuid.UUID, error) {
// VerifyResumeToken implements tailnet.ResumeTokenProvider.
func (r resumeTokenProvider) VerifyResumeToken(token string) (uuid.UUID, error) {
return r.parseFn(token)
}

Expand Down Expand Up @@ -307,12 +310,15 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
)
t.Logf("received resume token: %s", resumeToken)
if resumeToken != "" {
_, err = resumeTokenProvider.ParseResumeToken(resumeToken)
_, err = resumeTokenProvider.VerifyResumeToken(resumeToken)
assert.Error(t, err, "parse resume token should return an error")
atomic.AddInt64(&didFail, 1)
httpapi.Write(ctx, w, http.StatusBadRequest, codersdk.Response{
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{
Message: CoordinateAPIInvalidResumeToken,
Detail: err.Error(),
Validations: []codersdk.ValidationError{
{Field: "resume_token", Detail: CoordinateAPIInvalidResumeToken},
},
})
return
}
Expand Down Expand Up @@ -385,7 +391,7 @@ func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) {
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {
testutil.RequireSendCtx(ctx, t, eventCh, batch)
},
ResumeTokenProvider: tailnet.InsecureTestResumeTokenProvider,
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)

Expand Down
2 changes: 1 addition & 1 deletion enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func TestDialCoordinator(t *testing.T) {
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
ResumeTokenProvider: agpl.InsecureTestResumeTokenProvider,
ResumeTokenProvider: agpl.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)

Expand Down
4 changes: 2 additions & 2 deletions tailnet/coordinator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ func TestRemoteCoordination(t *testing.T) {
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
ResumeTokenProvider: tailnet.InsecureTestResumeTokenProvider,
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
sC, cC := net.Pipe()
Expand Down Expand Up @@ -682,7 +682,7 @@ func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
ResumeTokenProvider: tailnet.InsecureTestResumeTokenProvider,
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
sC, cC := net.Pipe()
Expand Down
81 changes: 72 additions & 9 deletions tailnet/resume.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package tailnet

import (
"context"
"crypto/rand"
"database/sql"
"encoding/hex"
"encoding/json"
"time"

Expand All @@ -19,22 +23,81 @@ const (
resumeTokenSigningAlgorithm = jose.HS512
)

var InsecureTestResumeTokenProvider ResumeTokenProvider = ResumeTokenKeyProvider{
key: [64]byte{1},
expiry: time.Hour,
// NewInsecureTestResumeTokenProvider returns a ResumeTokenProvider that uses a
// random key with short expiry for testing purposes. If any errors occur while
// generating the key, the function panics.
func NewInsecureTestResumeTokenProvider() ResumeTokenProvider {
key, err := GenerateResumeTokenSigningKey()
if err != nil {
panic(err)
}
return NewResumeTokenKeyProvider(key, time.Hour)
}

type ResumeTokenProvider interface {
GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error)
ParseResumeToken(token string) (uuid.UUID, error)
VerifyResumeToken(token string) (uuid.UUID, error)
}

type ResumeTokenSigningKey [64]byte

func GenerateResumeTokenSigningKey() (ResumeTokenSigningKey, error) {
var key ResumeTokenSigningKey
_, err := rand.Read(key[:])
if err != nil {
return key, xerrors.Errorf("generate random key: %w", err)
}
return key, nil
}

type ResumeTokenSigningKeyDatabaseStore interface {
GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error)
UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, key string) error
}

// ResumeTokenSigningKeyFromDatabase retrieves the coordinator resume token
// signing key from the database. If the key is not found, a new key is
// generated and inserted into the database.
func ResumeTokenSigningKeyFromDatabase(ctx context.Context, db ResumeTokenSigningKeyDatabaseStore) (ResumeTokenSigningKey, error) {
var resumeTokenKey ResumeTokenSigningKey
resumeTokenKeyStr, err := db.GetCoordinatorResumeTokenSigningKey(ctx)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return resumeTokenKey, xerrors.Errorf("get coordinator resume token key: %w", err)
}
if decoded, err := hex.DecodeString(resumeTokenKeyStr); err != nil || len(decoded) != len(resumeTokenKey) {
b := make([]byte, len(resumeTokenKey))
_, err := rand.Read(b)
if err != nil {
return resumeTokenKey, xerrors.Errorf("generate fresh coordinator resume token key: %w", err)
}

resumeTokenKeyStr = hex.EncodeToString(b)
err = db.UpsertCoordinatorResumeTokenSigningKey(ctx, resumeTokenKeyStr)
if err != nil {
return resumeTokenKey, xerrors.Errorf("insert freshly generated coordinator resume token key to database: %w", err)
}
}

resumeTokenKeyBytes, err := hex.DecodeString(resumeTokenKeyStr)
if err != nil {
return resumeTokenKey, xerrors.Errorf("decode coordinator resume token key from database: %w", err)
}
if len(resumeTokenKeyBytes) != len(resumeTokenKey) {
return resumeTokenKey, xerrors.Errorf("coordinator resume token key in database is not the correct length, expect %d got %d", len(resumeTokenKey), len(resumeTokenKeyBytes))
}
copy(resumeTokenKey[:], resumeTokenKeyBytes)
if resumeTokenKey == [64]byte{} {
return resumeTokenKey, xerrors.Errorf("coordinator resume token key in database is empty")
}
return resumeTokenKey, nil
}

type ResumeTokenKeyProvider struct {
key [64]byte
key ResumeTokenSigningKey
expiry time.Duration
}

func NewResumeTokenKeyProvider(key [64]byte, expiry time.Duration) ResumeTokenProvider {
func NewResumeTokenKeyProvider(key ResumeTokenSigningKey, expiry time.Duration) ResumeTokenProvider {
if expiry <= 0 {
expiry = DefaultResumeTokenExpiry
}
Expand All @@ -45,8 +108,8 @@ func NewResumeTokenKeyProvider(key [64]byte, expiry time.Duration) ResumeTokenPr
}

type resumeTokenPayload struct {
PeerID uuid.UUID `json:"peer_id"`
Expiry time.Time `json:"expiry"`
PeerID uuid.UUID `json:"sub"`
Expiry time.Time `json:"exp"`
}

func (p ResumeTokenKeyProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) {
Expand Down Expand Up @@ -87,7 +150,7 @@ func (p ResumeTokenKeyProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.Re
// VerifySignedToken parses a signed workspace app token with the given key and
// returns the payload. If the token is invalid or expired, an error is
// returned.
func (p ResumeTokenKeyProvider) ParseResumeToken(str string) (uuid.UUID, error) {
func (p ResumeTokenKeyProvider) VerifyResumeToken(str string) (uuid.UUID, error) {
object, err := jose.ParseSigned(str)
if err != nil {
return uuid.Nil, xerrors.Errorf("parse JWS: %w", err)
Expand Down
4 changes: 2 additions & 2 deletions tailnet/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestClientService_ServeClient_V2(t *testing.T) {
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {
telemetryEvents <- batch
},
ResumeTokenProvider: tailnet.InsecureTestResumeTokenProvider,
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)

Expand Down Expand Up @@ -145,7 +145,7 @@ func TestClientService_ServeClient_V1(t *testing.T) {
DERPMapUpdateFrequency: 0,
DERPMapFn: nil,
NetworkTelemetryHandler: nil,
ResumeTokenProvider: tailnet.InsecureTestResumeTokenProvider,
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)

Expand Down
2 changes: 1 addition & 1 deletion tailnet/test/integration/integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (o SimpleServerOptions) Router(t *testing.T, logger slog.Logger) *chi.Mux {
}
},
NetworkTelemetryHandler: func(batch []*tailnetproto.TelemetryEvent) {},
ResumeTokenProvider: tailnet.InsecureTestResumeTokenProvider,
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)

Expand Down
Loading