Skip to content

feat: add resume support to coordinator connections #14234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
PR comments 2
  • Loading branch information
deansheather committed Aug 14, 2024
commit 73a5cee7decf517aed130e001f550f57864ddff0
3 changes: 2 additions & 1 deletion cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import (
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/pretty"
"github.com/coder/quartz"
"github.com/coder/retry"
"github.com/coder/serpent"
"github.com/coder/wgtunnel/tunnelsdk"
Expand Down Expand Up @@ -809,7 +810,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
if err != nil {
return xerrors.Errorf("get coordinator resume token key from database: %w", err)
}
options.CoordinatorResumeTokenProvider = tailnet.NewResumeTokenKeyProvider(resumeTokenKey, tailnet.DefaultResumeTokenExpiry)
options.CoordinatorResumeTokenProvider = tailnet.NewResumeTokenKeyProvider(resumeTokenKey, quartz.NewReal(), tailnet.DefaultResumeTokenExpiry)

return nil
}, nil)
Expand Down
21 changes: 8 additions & 13 deletions coderd/workspaceagents_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {

logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)

// We block DERP in this test to ensure that even if there's no direct
// We block direct in this test to ensure that even if there's no direct
// connection, no shenanigans happen with the peer IDs on either side.
dv := coderdtest.DeploymentValues(t)
err := dv.DERP.Config.BlockDirect.Set("true")
Expand Down Expand Up @@ -563,22 +563,17 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
proxyClient := codersdk.New(proxyURL)
proxyClient.SetSessionToken(client.SessionToken())

// Connect from a client.
conn, err := func() (*workspacesdk.AgentConn, error) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() // Connection should remain open even if the dial context is canceled.
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()

return workspacesdk.New(proxyClient).
DialAgent(ctx, agentID, &workspacesdk.DialAgentOptions{
Logger: logger.Named("client"),
})
}()
// Connect from a client.
conn, err := workspacesdk.New(proxyClient).
DialAgent(ctx, agentID, &workspacesdk.DialAgentOptions{
Logger: logger.Named("client"),
})
require.NoError(t, err)
defer conn.Close()

ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()

ok := conn.AwaitReachable(ctx)
require.True(t, ok)
originalAgentPeers := agentCloser.TailnetConn().GetKnownPeerIDs()
Expand Down
56 changes: 7 additions & 49 deletions codersdk/workspacesdk/connector_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)

func init() {
Expand Down Expand Up @@ -159,7 +160,7 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) {
derpMapCh := make(chan *tailcfg.DERPMap)
defer close(derpMapCh)

resumeTokenProvider := tailnet.NewResumeTokenKeyProvider([64]byte{1}, time.Second)
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider([64]byte{1}, quartz.NewReal(), time.Second)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger,
CoordPtr: &coordPtr,
Expand Down Expand Up @@ -198,7 +199,7 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) {
return
}
}
testutil.RequireSendCtx(ctx, t, peerIDCh, peerID)
testutil.AssertSendCtx(ctx, t, peerIDCh, peerID)

sws, err := websocket.Accept(w, r, nil)
if !assert.NoError(t, err) {
Expand Down Expand Up @@ -244,23 +245,6 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) {
require.Equal(t, originalPeerID, testutil.RequireRecvCtx(ctx, t, peerIDCh))
}

// resumeTokenProvider is a function-backed tailnet.ResumeTokenProvider: each
// interface method simply forwards to the matching func field, letting a test
// decide how tokens are generated and verified.
type resumeTokenProvider struct {
	genFn   func(uuid.UUID) (*proto.RefreshResumeTokenResponse, error)
	parseFn func(string) (uuid.UUID, error)
}

// Compile-time check that the value type satisfies the interface.
var _ tailnet.ResumeTokenProvider = resumeTokenProvider{}

// GenerateResumeToken implements tailnet.ResumeTokenProvider by delegating to
// genFn and returning its results unchanged.
func (p resumeTokenProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) {
	resp, err := p.genFn(peerID)
	return resp, err
}

// VerifyResumeToken implements tailnet.ResumeTokenProvider by delegating to
// parseFn and returning its results unchanged.
func (p resumeTokenProvider) VerifyResumeToken(token string) (uuid.UUID, error) {
	id, err := p.parseFn(token)
	return id, err
}

func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
Expand All @@ -275,43 +259,22 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
derpMapCh := make(chan *tailcfg.DERPMap)
defer close(derpMapCh)

resumeTokenProvider := resumeTokenProvider{
genFn: func(uuid.UUID) (*proto.RefreshResumeTokenResponse, error) {
return &proto.RefreshResumeTokenResponse{
Token: uuid.NewString(),
RefreshIn: durationpb.New(time.Minute),
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
}, nil
},
parseFn: func(string) (uuid.UUID, error) {
return uuid.UUID{}, xerrors.New("test error")
},
}
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger,
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Millisecond,
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
ResumeTokenProvider: resumeTokenProvider,
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)

var (
websocketConnCh = make(chan *websocket.Conn, 64)
peerIDCh = make(chan uuid.UUID, 64)
didFail int64
)
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Accept a resume_token query parameter to use the same peer ID.
var (
peerID = uuid.New()
resumeToken = r.URL.Query().Get("resume_token")
)
t.Logf("received resume token: %s", resumeToken)
if resumeToken != "" {
_, err = resumeTokenProvider.VerifyResumeToken(resumeToken)
assert.Error(t, err, "parse resume token should return an error")
if r.URL.Query().Get("resume_token") != "" {
atomic.AddInt64(&didFail, 1)
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{
Message: CoordinateAPIInvalidResumeToken,
Expand All @@ -322,7 +285,6 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
})
return
}
testutil.RequireSendCtx(ctx, t, peerIDCh, peerID)

sws, err := websocket.Accept(w, r, nil)
if !assert.NoError(t, err) {
Expand All @@ -332,7 +294,7 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
Name: "client",
ID: peerID,
ID: uuid.New(),
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
})
assert.NoError(t, err)
Expand All @@ -353,7 +315,6 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
// Sever the connection and expect it to reconnect with the resume token,
// which should fail and cause the client to be disconnected. The client
// should then reconnect with no resume token.
originalPeerID := testutil.RequireRecvCtx(ctx, t, peerIDCh)
wsConn := testutil.RequireRecvCtx(ctx, t, websocketConnCh)
_ = wsConn.Close(websocket.StatusGoingAway, "test")

Expand All @@ -363,10 +324,7 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
return rt != nil && rt.Token != originalResumeToken.Token
}, testutil.WaitShort, testutil.IntervalFast)

// Peer ID should be different.
require.NotEqual(t, originalPeerID, testutil.RequireRecvCtx(ctx, t, peerIDCh))

// The resume token should have failed to parse.
// The resume token should have been rejected by the server.
require.EqualValues(t, 1, atomic.LoadInt64(&didFail))
}

Expand Down
18 changes: 10 additions & 8 deletions tailnet/resume.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/quartz"
)

const (
Expand All @@ -31,7 +32,7 @@ func NewInsecureTestResumeTokenProvider() ResumeTokenProvider {
if err != nil {
panic(err)
}
return NewResumeTokenKeyProvider(key, time.Hour)
return NewResumeTokenKeyProvider(key, quartz.NewReal(), time.Hour)
}

type ResumeTokenProvider interface {
Expand Down Expand Up @@ -65,13 +66,12 @@ func ResumeTokenSigningKeyFromDatabase(ctx context.Context, db ResumeTokenSignin
return resumeTokenKey, xerrors.Errorf("get coordinator resume token key: %w", err)
}
if decoded, err := hex.DecodeString(resumeTokenKeyStr); err != nil || len(decoded) != len(resumeTokenKey) {
b := make([]byte, len(resumeTokenKey))
_, err := rand.Read(b)
newKey, err := GenerateResumeTokenSigningKey()
if err != nil {
return resumeTokenKey, xerrors.Errorf("generate fresh coordinator resume token key: %w", err)
}

resumeTokenKeyStr = hex.EncodeToString(b)
resumeTokenKeyStr = hex.EncodeToString(newKey[:])
err = db.UpsertCoordinatorResumeTokenSigningKey(ctx, resumeTokenKeyStr)
if err != nil {
return resumeTokenKey, xerrors.Errorf("insert freshly generated coordinator resume token key to database: %w", err)
Expand All @@ -94,15 +94,17 @@ func ResumeTokenSigningKeyFromDatabase(ctx context.Context, db ResumeTokenSignin

// ResumeTokenKeyProvider is the ResumeTokenProvider implementation returned by
// NewResumeTokenKeyProvider.
type ResumeTokenKeyProvider struct {
// key is the signing key handed to NewResumeTokenKeyProvider.
key ResumeTokenSigningKey
// clock supplies the current time when stamping a new token's expiry and when
// checking whether a presented token has expired.
clock quartz.Clock
// expiry is added to the clock's current time to compute a new token's Expiry.
expiry time.Duration
}

func NewResumeTokenKeyProvider(key ResumeTokenSigningKey, expiry time.Duration) ResumeTokenProvider {
// NewResumeTokenKeyProvider returns a ResumeTokenProvider backed by the given
// signing key, reading the current time from clock.
//
// expiry is how long a generated token stays valid. A zero or negative value
// falls back to DefaultResumeTokenExpiry.
func NewResumeTokenKeyProvider(key ResumeTokenSigningKey, clock quartz.Clock, expiry time.Duration) ResumeTokenProvider {
	if expiry <= 0 {
		expiry = DefaultResumeTokenExpiry
	}
	// Store the normalized, caller-supplied expiry. Storing the default here
	// unconditionally would silently discard the argument.
	return ResumeTokenKeyProvider{
		key:    key,
		clock:  clock,
		expiry: expiry,
	}
}
Expand All @@ -115,7 +117,7 @@ type resumeTokenPayload struct {
func (p ResumeTokenKeyProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) {
payload := resumeTokenPayload{
PeerID: peerID,
Expiry: time.Now().Add(p.expiry),
Expiry: p.clock.Now().Add(p.expiry),
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
Expand Down Expand Up @@ -172,8 +174,8 @@ func (p ResumeTokenKeyProvider) VerifyResumeToken(str string) (uuid.UUID, error)
if err != nil {
return uuid.Nil, xerrors.Errorf("unmarshal payload: %w", err)
}
if tok.Expiry.Before(time.Now()) {
return uuid.Nil, xerrors.New("signed app token expired")
if tok.Expiry.Before(p.clock.Now()) {
return uuid.Nil, xerrors.New("signed resume token expired")
}

return tok.PeerID, nil
Expand Down
Loading