Skip to content

Commit 73cfe1a

Browse files
committed
fix: fix graceful disconnect in DialWorkspaceAgent
1 parent f39414c commit 73cfe1a

File tree

3 files changed

+190
-39
lines changed

3 files changed

+190
-39
lines changed

codersdk/workspaceagents.go

+82-38
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@ import (
1212
"net/netip"
1313
"strconv"
1414
"strings"
15+
"sync"
1516
"time"
1617

17-
"golang.org/x/sync/errgroup"
18-
1918
"github.com/google/uuid"
2019
"golang.org/x/xerrors"
2120
"nhooyr.io/websocket"
@@ -360,6 +359,15 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
360359
return agentConn, nil
361360
}
362361

362+
// tailnetConn is the subset of the tailnet.Conn methods that tailnetAPIConnector uses. It is
363+
// included so that we can fake it in testing.
//
// The embedded tailnet.Coordinatee is what coordinate() hands to tailnet.NewRemoteCoordination;
// SetDERPMap is what derpMap() calls with each DERP map received from the StreamDERPMaps RPC.
364+
//
365+
// @typescript-ignore tailnetConn
366+
type tailnetConn interface {
367+
tailnet.Coordinatee
368+
SetDERPMap(derpMap *tailcfg.DERPMap)
369+
}
370+
363371
// tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to
364372
//
365373
// 1) run the Coordinate API and pass node information back and forth
@@ -370,13 +378,20 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
370378
//
371379
// @typescript-ignore tailnetAPIConnector
372380
type tailnetAPIConnector struct {
373-
ctx context.Context
381+
// We keep track of two contexts: the main context from the caller, and a "graceful" context
382+
// that we keep open slightly longer than the main context to give a chance to send the
383+
// Disconnect message to the coordinator. That tells the coordinator that we really meant to
384+
// disconnect instead of just losing network connectivity.
385+
ctx context.Context
386+
gracefulCtx context.Context
387+
cancelGracefulCtx context.CancelFunc
388+
374389
logger slog.Logger
375390

376391
agentID uuid.UUID
377392
coordinateURL string
378393
dialOptions *websocket.DialOptions
379-
conn *tailnet.Conn
394+
conn tailnetConn
380395

381396
connected chan error
382397
isFirst bool
@@ -387,7 +402,7 @@ type tailnetAPIConnector struct {
387402
func runTailnetAPIConnector(
388403
ctx context.Context, logger slog.Logger,
389404
agentID uuid.UUID, coordinateURL string, dialOptions *websocket.DialOptions,
390-
conn *tailnet.Conn,
405+
conn tailnetConn,
391406
) *tailnetAPIConnector {
392407
tac := &tailnetAPIConnector{
393408
ctx: ctx,
@@ -399,10 +414,23 @@ func runTailnetAPIConnector(
399414
connected: make(chan error, 1),
400415
closed: make(chan struct{}),
401416
}
417+
tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background())
418+
go tac.manageGracefulTimeout()
402419
go tac.run()
403420
return tac
404421
}
405422

423+
// manageGracefulTimeout allows the gracefulContext to last 1 second longer than the main context
424+
// to allow a graceful disconnect.
425+
func (tac *tailnetAPIConnector) manageGracefulTimeout() {
426+
defer tac.cancelGracefulCtx()
427+
<-tac.ctx.Done()
428+
select {
429+
case <-tac.closed:
430+
case <-time.After(time.Second):
431+
}
432+
}
433+
406434
func (tac *tailnetAPIConnector) run() {
407435
tac.isFirst = true
408436
defer close(tac.closed)
@@ -437,7 +465,7 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
437465
return nil, err
438466
}
439467
client, err := tailnet.NewDRPCClient(
440-
websocket.NetConn(tac.ctx, ws, websocket.MessageBinary),
468+
websocket.NetConn(tac.gracefulCtx, ws, websocket.MessageBinary),
441469
tac.logger,
442470
)
443471
if err != nil {
@@ -464,65 +492,81 @@ func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetCli
464492
<-conn.Closed()
465493
}
466494
}()
467-
eg, egCtx := errgroup.WithContext(tac.ctx)
468-
eg.Go(func() error {
469-
return tac.coordinate(egCtx, client)
470-
})
471-
eg.Go(func() error {
472-
return tac.derpMap(egCtx, client)
473-
})
474-
err := eg.Wait()
475-
if err != nil &&
476-
!xerrors.Is(err, io.EOF) &&
477-
!xerrors.Is(err, context.Canceled) &&
478-
!xerrors.Is(err, context.DeadlineExceeded) {
479-
tac.logger.Error(tac.ctx, "error while connected to tailnet v2+ API")
480-
}
495+
wg := sync.WaitGroup{}
496+
wg.Add(2)
497+
go func() {
498+
defer wg.Done()
499+
tac.coordinate(client)
500+
}()
501+
go func() {
502+
defer wg.Done()
503+
dErr := tac.derpMap(client)
504+
if dErr != nil && tac.ctx.Err() == nil {
505+
// The main context is still active, meaning that we want the tailnet data plane to stay
506+
// up, even though we hit some error getting DERP maps on the control plane. That means
507+
// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
508+
// close the underlying connection. This will trigger a retry of the control plane in
509+
// run().
510+
client.DRPCConn().Close()
511+
// Note that derpMap() logs its own errors, so we don't bother here.
512+
}
513+
}()
514+
wg.Wait()
481515
}
482516

483-
func (tac *tailnetAPIConnector) coordinate(ctx context.Context, client proto.DRPCTailnetClient) error {
484-
coord, err := client.Coordinate(ctx)
517+
func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
518+
// we use the gracefulCtx here so that we'll have time to send the graceful disconnect
519+
coord, err := client.Coordinate(tac.gracefulCtx)
485520
if err != nil {
486-
return xerrors.Errorf("failed to connect to Coordinate RPC: %w", err)
521+
tac.logger.Error(tac.ctx, "failed to connect to Coordinate RPC", slog.Error(err))
522+
return
487523
}
488524
defer func() {
489525
cErr := coord.Close()
490526
if cErr != nil {
491-
tac.logger.Debug(ctx, "error closing Coordinate RPC", slog.Error(cErr))
527+
tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr))
492528
}
493529
}()
494530
coordination := tailnet.NewRemoteCoordination(tac.logger, coord, tac.conn, tac.agentID)
495-
tac.logger.Debug(ctx, "serving coordinator")
496-
err = <-coordination.Error()
497-
if err != nil &&
498-
!xerrors.Is(err, io.EOF) &&
499-
!xerrors.Is(err, context.Canceled) &&
500-
!xerrors.Is(err, context.DeadlineExceeded) {
501-
return xerrors.Errorf("remote coordination error: %w", err)
531+
tac.logger.Debug(tac.ctx, "serving coordinator")
532+
select {
533+
case <-tac.ctx.Done():
534+
tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
535+
crdErr := coordination.Close()
536+
if crdErr != nil {
537+
tac.logger.Error(tac.ctx, "failed to close remote coordination", slog.Error(err))
538+
}
539+
case err = <-coordination.Error():
540+
if err != nil &&
541+
!xerrors.Is(err, io.EOF) &&
542+
!xerrors.Is(err, context.Canceled) &&
543+
!xerrors.Is(err, context.DeadlineExceeded) {
544+
tac.logger.Error(tac.ctx, "remote coordination error: %w", err)
545+
}
502546
}
503-
return nil
504547
}
505548

506-
func (tac *tailnetAPIConnector) derpMap(ctx context.Context, client proto.DRPCTailnetClient) error {
507-
s, err := client.StreamDERPMaps(ctx, &proto.StreamDERPMapsRequest{})
549+
func (tac *tailnetAPIConnector) derpMap(client proto.DRPCTailnetClient) error {
550+
s, err := client.StreamDERPMaps(tac.ctx, &proto.StreamDERPMapsRequest{})
508551
if err != nil {
509552
return xerrors.Errorf("failed to connect to StreamDERPMaps RPC: %w", err)
510553
}
511554
defer func() {
512555
cErr := s.Close()
513556
if cErr != nil {
514-
tac.logger.Debug(ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
557+
tac.logger.Debug(tac.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
515558
}
516559
}()
517560
for {
518561
dmp, err := s.Recv()
519562
if err != nil {
520-
if xerrors.Is(err, io.EOF) || xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
563+
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
521564
return nil
522565
}
523-
return xerrors.Errorf("error receiving DERP Map: %w", err)
566+
tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err))
567+
return err
524568
}
525-
tac.logger.Debug(ctx, "got new DERP Map", slog.F("derp_map", dmp))
569+
tac.logger.Debug(tac.ctx, "got new DERP Map", slog.F("derp_map", dmp))
526570
dm := tailnet.DERPMapFromProto(dmp)
527571
tac.conn.SetDERPMap(dm)
528572
}
+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package codersdk
2+
3+
import (
4+
"context"
5+
"io"
6+
"net/http"
7+
"net/http/httptest"
8+
"sync/atomic"
9+
"testing"
10+
"time"
11+
12+
"github.com/google/uuid"
13+
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/require"
15+
"nhooyr.io/websocket"
16+
"tailscale.com/tailcfg"
17+
18+
"cdr.dev/slog"
19+
"cdr.dev/slog/sloggers/slogtest"
20+
"github.com/coder/coder/v2/tailnet"
21+
"github.com/coder/coder/v2/tailnet/proto"
22+
"github.com/coder/coder/v2/tailnet/tailnettest"
23+
"github.com/coder/coder/v2/testutil"
24+
)
25+
26+
func TestTailnetAPIConnector_Disconnects(t *testing.T) {
27+
t.Parallel()
28+
testCtx := testutil.Context(t, testutil.WaitShort)
29+
ctx, cancel := context.WithCancel(testCtx)
30+
logger := slogtest.Make(t, &slogtest.Options{
31+
// we get EOF when we simulate a DERPMap error
32+
IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs, io.EOF),
33+
}).Leveled(slog.LevelDebug)
34+
agentID := uuid.UUID{0x55}
35+
clientID := uuid.UUID{0x66}
36+
fCoord := tailnettest.NewFakeCoordinator()
37+
var coord tailnet.Coordinator = fCoord
38+
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
39+
coordPtr.Store(&coord)
40+
derpMapCh := make(chan *tailcfg.DERPMap)
41+
defer close(derpMapCh)
42+
svc, err := tailnet.NewClientService(
43+
logger, &coordPtr,
44+
time.Millisecond, func() *tailcfg.DERPMap { return <-derpMapCh },
45+
)
46+
require.NoError(t, err)
47+
48+
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
49+
sws, err := websocket.Accept(w, r, nil)
50+
if !assert.NoError(t, err) {
51+
return
52+
}
53+
ctx, nc := websocketNetConn(r.Context(), sws, websocket.MessageBinary)
54+
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
55+
Name: "client",
56+
ID: clientID,
57+
Auth: tailnet.ClientTunnelAuth{AgentID: agentID},
58+
})
59+
assert.NoError(t, err)
60+
}))
61+
62+
fConn := newFakeTailnetConn()
63+
64+
uut := runTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{}, fConn)
65+
66+
call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
67+
reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs)
68+
require.NotNil(t, reqTun.AddTunnel)
69+
70+
_ = testutil.RequireRecvCtx(ctx, t, uut.connected)
71+
72+
// simulate a problem with DERPMaps by sending nil
73+
testutil.RequireSendCtx(ctx, t, derpMapCh, nil)
74+
75+
// this should cause the coordinate call to hang up WITHOUT disconnecting
76+
reqNil := testutil.RequireRecvCtx(ctx, t, call.Reqs)
77+
require.Nil(t, reqNil)
78+
79+
// ...and then reconnect
80+
call = testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
81+
reqTun = testutil.RequireRecvCtx(ctx, t, call.Reqs)
82+
require.NotNil(t, reqTun.AddTunnel)
83+
84+
// canceling the context should trigger the disconnect message
85+
cancel()
86+
reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
87+
require.NotNil(t, reqDisc)
88+
require.NotNil(t, reqDisc.Disconnect)
89+
}
90+
91+
type fakeTailnetConn struct{}
92+
93+
func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
94+
// TODO implement me
95+
panic("implement me")
96+
}
97+
98+
func (*fakeTailnetConn) SetAllPeersLost() {}
99+
100+
func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {}
101+
102+
func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {}
103+
104+
func newFakeTailnetConn() *fakeTailnetConn {
105+
return &fakeTailnetConn{}
106+
}

tailnet/coordinator.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ func (c *remoteCoordination) Close() (retErr error) {
134134
if err != nil {
135135
return xerrors.Errorf("send disconnect: %w", err)
136136
}
137+
c.logger.Debug(context.Background(), "sent disconnect")
137138
return nil
138139
}
139140

@@ -167,7 +168,7 @@ func (c *remoteCoordination) respLoop() {
167168
}
168169
}
169170

170-
// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinee (usually a
171+
// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinatee (usually a
171172
// Conn). If the tunnelTarget is not uuid.Nil, then we add a tunnel to the peer (i.e. we are acting as
172173
// a client---agents should NOT set this!).
173174
func NewRemoteCoordination(logger slog.Logger,

0 commit comments

Comments
 (0)