Skip to content

Commit 6371570

Browse files
committed
fix: fix graceful disconnect in DialWorkspaceAgent
1 parent 02f29b5 commit 6371570

File tree

3 files changed

+188
-39
lines changed

3 files changed

+188
-39
lines changed

codersdk/workspaceagents.go

+80-38
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@ import (
1212
"net/netip"
1313
"strconv"
1414
"strings"
15+
"sync"
1516
"time"
1617

17-
"golang.org/x/sync/errgroup"
18-
1918
"github.com/google/uuid"
2019
"golang.org/x/xerrors"
2120
"nhooyr.io/websocket"
@@ -360,6 +359,13 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
360359
return agentConn, nil
361360
}
362361

362+
// tailnetConn is the subset of the tailnet.Conn methods that tailnetAPIConnector uses. It is
363+
// included so that we can fake it in testing.
364+
type tailnetConn interface {
365+
tailnet.Coordinatee
366+
SetDERPMap(derpMap *tailcfg.DERPMap)
367+
}
368+
363369
// tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to
364370
//
365371
// 1) run the Coordinate API and pass node information back and forth
@@ -370,13 +376,20 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
370376
//
371377
// @typescript-ignore tailnetAPIConnector
372378
type tailnetAPIConnector struct {
373-
ctx context.Context
379+
// We keep track of two contexts: the main context from the caller, and a "graceful" context
380+
// that we keep open slightly longer than the main context to give a chance to send the
381+
// Disconnect message to the coordinator. That tells the coordinator that we really meant to
382+
// disconnect instead of just losing network connectivity.
383+
ctx context.Context
384+
gracefulCtx context.Context
385+
cancelGracefulCtx context.CancelFunc
386+
374387
logger slog.Logger
375388

376389
agentID uuid.UUID
377390
coordinateURL string
378391
dialOptions *websocket.DialOptions
379-
conn *tailnet.Conn
392+
conn tailnetConn
380393

381394
connected chan error
382395
isFirst bool
@@ -387,7 +400,7 @@ type tailnetAPIConnector struct {
387400
func runTailnetAPIConnector(
388401
ctx context.Context, logger slog.Logger,
389402
agentID uuid.UUID, coordinateURL string, dialOptions *websocket.DialOptions,
390-
conn *tailnet.Conn,
403+
conn tailnetConn,
391404
) *tailnetAPIConnector {
392405
tac := &tailnetAPIConnector{
393406
ctx: ctx,
@@ -399,10 +412,23 @@ func runTailnetAPIConnector(
399412
connected: make(chan error, 1),
400413
closed: make(chan struct{}),
401414
}
415+
tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background())
416+
go tac.manageGracefulTimeout()
402417
go tac.run()
403418
return tac
404419
}
405420

421+
// manageGracefulTimeout allows the gracefulContext to last 1 second longer than the main context
422+
// to allow a graceful disconnect.
423+
func (tac *tailnetAPIConnector) manageGracefulTimeout() {
424+
defer tac.cancelGracefulCtx()
425+
<-tac.ctx.Done()
426+
select {
427+
case <-tac.closed:
428+
case <-time.After(time.Second):
429+
}
430+
}
431+
406432
func (tac *tailnetAPIConnector) run() {
407433
tac.isFirst = true
408434
defer close(tac.closed)
@@ -437,7 +463,7 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
437463
return nil, err
438464
}
439465
client, err := tailnet.NewDRPCClient(
440-
websocket.NetConn(tac.ctx, ws, websocket.MessageBinary),
466+
websocket.NetConn(tac.gracefulCtx, ws, websocket.MessageBinary),
441467
tac.logger,
442468
)
443469
if err != nil {
@@ -464,65 +490,81 @@ func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetCli
464490
<-conn.Closed()
465491
}
466492
}()
467-
eg, egCtx := errgroup.WithContext(tac.ctx)
468-
eg.Go(func() error {
469-
return tac.coordinate(egCtx, client)
470-
})
471-
eg.Go(func() error {
472-
return tac.derpMap(egCtx, client)
473-
})
474-
err := eg.Wait()
475-
if err != nil &&
476-
!xerrors.Is(err, io.EOF) &&
477-
!xerrors.Is(err, context.Canceled) &&
478-
!xerrors.Is(err, context.DeadlineExceeded) {
479-
tac.logger.Error(tac.ctx, "error while connected to tailnet v2+ API")
480-
}
493+
wg := sync.WaitGroup{}
494+
wg.Add(2)
495+
go func() {
496+
defer wg.Done()
497+
tac.coordinate(client)
498+
}()
499+
go func() {
500+
defer wg.Done()
501+
dErr := tac.derpMap(client)
502+
if dErr != nil && tac.ctx.Err() == nil {
503+
// The main context is still active, meaning that we want the tailnet data plane to stay
504+
// up, even though we hit some error getting DERP maps on the control plane. That means
505+
// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
506+
// close the underlying connection. This will trigger a retry of the control plane in
507+
// run().
508+
client.DRPCConn().Close()
509+
// Note that derpMap() logs its own errors, we don't bother here.
510+
}
511+
}()
512+
wg.Wait()
481513
}
482514

483-
func (tac *tailnetAPIConnector) coordinate(ctx context.Context, client proto.DRPCTailnetClient) error {
484-
coord, err := client.Coordinate(ctx)
515+
func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
516+
// we use the gracefulCtx here so that we'll have time to send the graceful disconnect
517+
coord, err := client.Coordinate(tac.gracefulCtx)
485518
if err != nil {
486-
return xerrors.Errorf("failed to connect to Coordinate RPC: %w", err)
519+
tac.logger.Error(tac.ctx, "failed to connect to Coordinate RPC", slog.Error(err))
520+
return
487521
}
488522
defer func() {
489523
cErr := coord.Close()
490524
if cErr != nil {
491-
tac.logger.Debug(ctx, "error closing Coordinate RPC", slog.Error(cErr))
525+
tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr))
492526
}
493527
}()
494528
coordination := tailnet.NewRemoteCoordination(tac.logger, coord, tac.conn, tac.agentID)
495-
tac.logger.Debug(ctx, "serving coordinator")
496-
err = <-coordination.Error()
497-
if err != nil &&
498-
!xerrors.Is(err, io.EOF) &&
499-
!xerrors.Is(err, context.Canceled) &&
500-
!xerrors.Is(err, context.DeadlineExceeded) {
501-
return xerrors.Errorf("remote coordination error: %w", err)
529+
tac.logger.Debug(tac.ctx, "serving coordinator")
530+
select {
531+
case <-tac.ctx.Done():
532+
// main context canceled; do graceful disconnect
533+
crdErr := coordination.Close()
534+
if crdErr != nil {
535+
tac.logger.Error(tac.ctx, "failed to close remote coordination", slog.Error(err))
536+
}
537+
case err = <-coordination.Error():
538+
if err != nil &&
539+
!xerrors.Is(err, io.EOF) &&
540+
!xerrors.Is(err, context.Canceled) &&
541+
!xerrors.Is(err, context.DeadlineExceeded) {
542+
tac.logger.Error(tac.ctx, "remote coordination error: %w", err)
543+
}
502544
}
503-
return nil
504545
}
505546

506-
func (tac *tailnetAPIConnector) derpMap(ctx context.Context, client proto.DRPCTailnetClient) error {
507-
s, err := client.StreamDERPMaps(ctx, &proto.StreamDERPMapsRequest{})
547+
func (tac *tailnetAPIConnector) derpMap(client proto.DRPCTailnetClient) error {
548+
s, err := client.StreamDERPMaps(tac.ctx, &proto.StreamDERPMapsRequest{})
508549
if err != nil {
509550
return xerrors.Errorf("failed to connect to StreamDERPMaps RPC: %w", err)
510551
}
511552
defer func() {
512553
cErr := s.Close()
513554
if cErr != nil {
514-
tac.logger.Debug(ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
555+
tac.logger.Debug(tac.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
515556
}
516557
}()
517558
for {
518559
dmp, err := s.Recv()
519560
if err != nil {
520-
if xerrors.Is(err, io.EOF) || xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
561+
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
521562
return nil
522563
}
523-
return xerrors.Errorf("error receiving DERP Map: %w", err)
564+
tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err))
565+
return err
524566
}
525-
tac.logger.Debug(ctx, "got new DERP Map", slog.F("derp_map", dmp))
567+
tac.logger.Debug(tac.ctx, "got new DERP Map", slog.F("derp_map", dmp))
526568
dm := tailnet.DERPMapFromProto(dmp)
527569
tac.conn.SetDERPMap(dm)
528570
}
+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package codersdk
2+
3+
import (
4+
"context"
5+
"io"
6+
"net/http"
7+
"net/http/httptest"
8+
"sync/atomic"
9+
"testing"
10+
"time"
11+
12+
"github.com/google/uuid"
13+
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/require"
15+
"nhooyr.io/websocket"
16+
"tailscale.com/tailcfg"
17+
18+
"cdr.dev/slog"
19+
"cdr.dev/slog/sloggers/slogtest"
20+
"github.com/coder/coder/v2/tailnet"
21+
"github.com/coder/coder/v2/tailnet/proto"
22+
"github.com/coder/coder/v2/tailnet/tailnettest"
23+
"github.com/coder/coder/v2/testutil"
24+
)
25+
26+
func TestTailnetAPIConnector_Disconnects(t *testing.T) {
27+
t.Parallel()
28+
testCtx := testutil.Context(t, testutil.WaitShort)
29+
ctx, cancel := context.WithCancel(testCtx)
30+
logger := slogtest.Make(t, &slogtest.Options{
31+
// we get EOF when we simulate a DERPMap error
32+
IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs, io.EOF),
33+
}).Leveled(slog.LevelDebug)
34+
agentID := uuid.UUID{0x55}
35+
clientID := uuid.UUID{0x66}
36+
fCoord := tailnettest.NewFakeCoordinator()
37+
var coord tailnet.Coordinator = fCoord
38+
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
39+
coordPtr.Store(&coord)
40+
derpMapCh := make(chan *tailcfg.DERPMap)
41+
defer close(derpMapCh)
42+
svc, err := tailnet.NewClientService(
43+
logger, &coordPtr,
44+
time.Millisecond, func() *tailcfg.DERPMap { return <-derpMapCh },
45+
)
46+
require.NoError(t, err)
47+
48+
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
49+
sws, err := websocket.Accept(w, r, nil)
50+
if !assert.NoError(t, err) {
51+
return
52+
}
53+
ctx, nc := websocketNetConn(r.Context(), sws, websocket.MessageBinary)
54+
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
55+
Name: "client",
56+
ID: clientID,
57+
Auth: tailnet.ClientTunnelAuth{AgentID: agentID},
58+
})
59+
assert.NoError(t, err)
60+
}))
61+
62+
fConn := newFakeTailnetConn()
63+
64+
uut := runTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{}, fConn)
65+
66+
call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
67+
reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs)
68+
require.NotNil(t, reqTun.AddTunnel)
69+
70+
_ = testutil.RequireRecvCtx(ctx, t, uut.connected)
71+
72+
// simulate a problem with DERPMaps by sending nil
73+
testutil.RequireSendCtx(ctx, t, derpMapCh, nil)
74+
75+
// this should cause the coordinate call to hang up WITHOUT disconnecting
76+
reqNil := testutil.RequireRecvCtx(ctx, t, call.Reqs)
77+
require.Nil(t, reqNil)
78+
79+
// ...and then reconnect
80+
call = testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
81+
reqTun = testutil.RequireRecvCtx(ctx, t, call.Reqs)
82+
require.NotNil(t, reqTun.AddTunnel)
83+
84+
// canceling the context should trigger the disconnect message
85+
cancel()
86+
reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
87+
require.NotNil(t, reqDisc)
88+
require.NotNil(t, reqDisc.Disconnect)
89+
}
90+
91+
type fakeTailnetConn struct{}
92+
93+
func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
94+
// TODO implement me
95+
panic("implement me")
96+
}
97+
98+
func (*fakeTailnetConn) SetAllPeersLost() {}
99+
100+
func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {}
101+
102+
func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {}
103+
104+
func newFakeTailnetConn() *fakeTailnetConn {
105+
return &fakeTailnetConn{}
106+
}

tailnet/coordinator.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ func (c *remoteCoordination) Close() (retErr error) {
134134
if err != nil {
135135
return xerrors.Errorf("send disconnect: %w", err)
136136
}
137+
c.logger.Debug(context.Background(), "sent disconnect")
137138
return nil
138139
}
139140

@@ -167,7 +168,7 @@ func (c *remoteCoordination) respLoop() {
167168
}
168169
}
169170

170-
// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinee (usually a
171+
// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinatee (usually a
171172
// Conn). If the tunnelTarget is not uuid.Nil, then we add a tunnel to the peer (i.e. we are acting as
172173
// a client---agents should NOT set this!).
173174
func NewRemoteCoordination(logger slog.Logger,

0 commit comments

Comments
 (0)