Commit e5ba586

fix: fix graceful disconnect in DialWorkspaceAgent (#11993)
I noticed in testing that the CLI wasn't correctly sending the disconnect message when it shuts down, so agents were seeing this as a "lost" peer rather than a "disconnected" one. What was happening is that we used a single context for everything from the netconn to the RPCs, and when that context was canceled we failed to send the disconnect message because the context was already canceled. So, this PR splits things into two contexts, with a graceful one set to last up to 1 second longer than the main one.
1 parent bb99cb7 commit e5ba586
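
The pattern behind the fix is small enough to sketch on its own: derive a second "graceful" context from context.Background(), cancel it only once the main context is done and either shutdown has finished or a short grace period has elapsed, and run anything that must outlive cancellation (here, sending the Disconnect message) on that graceful context. The sketch below is illustrative only; newGracefulContext, gracePeriod, and the done channel are hypothetical names, not identifiers from this commit.

// Minimal sketch of the two-context pattern, using illustrative names.
package main

import (
	"context"
	"fmt"
	"time"
)

// gracePeriod mirrors the "up to 1 second longer" behavior described above.
const gracePeriod = time.Second

// newGracefulContext returns a context that stays open after mainCtx is
// canceled, until either done is closed (shutdown finished cleanly) or
// gracePeriod elapses. Work that must survive cancellation, such as sending a
// final Disconnect message, runs on the returned context.
func newGracefulContext(mainCtx context.Context, done <-chan struct{}) context.Context {
	gracefulCtx, cancel := context.WithCancel(context.Background())
	go func() {
		defer cancel()
		<-mainCtx.Done() // wait for the caller to cancel
		select {
		case <-done: // graceful shutdown already completed
		case <-time.After(gracePeriod): // or give up after the grace period
		}
	}()
	return gracefulCtx
}

func main() {
	mainCtx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	gracefulCtx := newGracefulContext(mainCtx, done)

	cancel() // the CLI is shutting down: main context canceled

	// The graceful context remains usable for up to gracePeriod, which is the
	// window in which a final "disconnect" RPC could still be sent.
	select {
	case <-gracefulCtx.Done():
		fmt.Println("graceful context already canceled")
	default:
		fmt.Println("graceful context still open; send Disconnect now")
	}
	close(done) // signal that shutdown finished so the graceful context cancels promptly
}

In the commit itself the same role is played by gracefulCtx/cancelGracefulCtx plus the manageGracefulTimeout goroutine: the websocket net.Conn and the Coordinate RPC are opened on the graceful context, while the DERP map stream and everything else keep using the main one.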

File tree

3 files changed, +190 −39 lines changed


codersdk/workspaceagents.go

+82 −38

@@ -12,10 +12,9 @@ import (
 	"net/netip"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 
-	"golang.org/x/sync/errgroup"
-
 	"github.com/google/uuid"
 	"golang.org/x/xerrors"
 	"nhooyr.io/websocket"
@@ -360,6 +359,15 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
 	return agentConn, nil
 }
 
+// tailnetConn is the subset of the tailnet.Conn methods that tailnetAPIConnector uses. It is
+// included so that we can fake it in testing.
+//
+// @typescript-ignore tailnetConn
+type tailnetConn interface {
+	tailnet.Coordinatee
+	SetDERPMap(derpMap *tailcfg.DERPMap)
+}
+
 // tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to
 //
 // 1) run the Coordinate API and pass node information back and forth
@@ -370,13 +378,20 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
 //
 // @typescript-ignore tailnetAPIConnector
 type tailnetAPIConnector struct {
-	ctx context.Context
+	// We keep track of two contexts: the main context from the caller, and a "graceful" context
+	// that we keep open slightly longer than the main context to give a chance to send the
+	// Disconnect message to the coordinator. That tells the coordinator that we really meant to
+	// disconnect instead of just losing network connectivity.
+	ctx               context.Context
+	gracefulCtx       context.Context
+	cancelGracefulCtx context.CancelFunc
+
 	logger slog.Logger
 
 	agentID       uuid.UUID
 	coordinateURL string
 	dialOptions   *websocket.DialOptions
-	conn          *tailnet.Conn
+	conn          tailnetConn
 
 	connected chan error
 	isFirst   bool
@@ -387,7 +402,7 @@ type tailnetAPIConnector struct {
 func runTailnetAPIConnector(
 	ctx context.Context, logger slog.Logger,
 	agentID uuid.UUID, coordinateURL string, dialOptions *websocket.DialOptions,
-	conn *tailnet.Conn,
+	conn tailnetConn,
 ) *tailnetAPIConnector {
 	tac := &tailnetAPIConnector{
 		ctx:           ctx,
@@ -399,10 +414,23 @@ func runTailnetAPIConnector(
 		connected:     make(chan error, 1),
 		closed:        make(chan struct{}),
 	}
+	tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background())
+	go tac.manageGracefulTimeout()
 	go tac.run()
 	return tac
 }
 
+// manageGracefulTimeout allows the gracefulContext to last 1 second longer than the main context
+// to allow a graceful disconnect.
+func (tac *tailnetAPIConnector) manageGracefulTimeout() {
+	defer tac.cancelGracefulCtx()
+	<-tac.ctx.Done()
+	select {
+	case <-tac.closed:
+	case <-time.After(time.Second):
+	}
+}
+
 func (tac *tailnetAPIConnector) run() {
 	tac.isFirst = true
 	defer close(tac.closed)
@@ -437,7 +465,7 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
 		return nil, err
 	}
 	client, err := tailnet.NewDRPCClient(
-		websocket.NetConn(tac.ctx, ws, websocket.MessageBinary),
+		websocket.NetConn(tac.gracefulCtx, ws, websocket.MessageBinary),
 		tac.logger,
 	)
 	if err != nil {
@@ -464,65 +492,81 @@ func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetClient) {
 			<-conn.Closed()
 		}
 	}()
-	eg, egCtx := errgroup.WithContext(tac.ctx)
-	eg.Go(func() error {
-		return tac.coordinate(egCtx, client)
-	})
-	eg.Go(func() error {
-		return tac.derpMap(egCtx, client)
-	})
-	err := eg.Wait()
-	if err != nil &&
-		!xerrors.Is(err, io.EOF) &&
-		!xerrors.Is(err, context.Canceled) &&
-		!xerrors.Is(err, context.DeadlineExceeded) {
-		tac.logger.Error(tac.ctx, "error while connected to tailnet v2+ API")
-	}
+	wg := sync.WaitGroup{}
+	wg.Add(2)
+	go func() {
+		defer wg.Done()
+		tac.coordinate(client)
+	}()
+	go func() {
+		defer wg.Done()
+		dErr := tac.derpMap(client)
+		if dErr != nil && tac.ctx.Err() == nil {
+			// The main context is still active, meaning that we want the tailnet data plane to stay
+			// up, even though we hit some error getting DERP maps on the control plane. That means
+			// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
+			// close the underlying connection. This will trigger a retry of the control plane in
+			// run().
+			client.DRPCConn().Close()
+			// Note that derpMap() logs it own errors, we don't bother here.
+		}
+	}()
+	wg.Wait()
 }
 
-func (tac *tailnetAPIConnector) coordinate(ctx context.Context, client proto.DRPCTailnetClient) error {
-	coord, err := client.Coordinate(ctx)
+func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
+	// we use the gracefulCtx here so that we'll have time to send the graceful disconnect
+	coord, err := client.Coordinate(tac.gracefulCtx)
 	if err != nil {
-		return xerrors.Errorf("failed to connect to Coordinate RPC: %w", err)
+		tac.logger.Error(tac.ctx, "failed to connect to Coordinate RPC", slog.Error(err))
+		return
 	}
 	defer func() {
 		cErr := coord.Close()
 		if cErr != nil {
-			tac.logger.Debug(ctx, "error closing Coordinate RPC", slog.Error(cErr))
+			tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr))
 		}
 	}()
 	coordination := tailnet.NewRemoteCoordination(tac.logger, coord, tac.conn, tac.agentID)
-	tac.logger.Debug(ctx, "serving coordinator")
-	err = <-coordination.Error()
-	if err != nil &&
-		!xerrors.Is(err, io.EOF) &&
-		!xerrors.Is(err, context.Canceled) &&
-		!xerrors.Is(err, context.DeadlineExceeded) {
-		return xerrors.Errorf("remote coordination error: %w", err)
+	tac.logger.Debug(tac.ctx, "serving coordinator")
+	select {
+	case <-tac.ctx.Done():
+		tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
+		crdErr := coordination.Close()
+		if crdErr != nil {
+			tac.logger.Error(tac.ctx, "failed to close remote coordination", slog.Error(err))
+		}
+	case err = <-coordination.Error():
+		if err != nil &&
+			!xerrors.Is(err, io.EOF) &&
+			!xerrors.Is(err, context.Canceled) &&
+			!xerrors.Is(err, context.DeadlineExceeded) {
+			tac.logger.Error(tac.ctx, "remote coordination error: %w", err)
+		}
 	}
-	return nil
 }
 
-func (tac *tailnetAPIConnector) derpMap(ctx context.Context, client proto.DRPCTailnetClient) error {
-	s, err := client.StreamDERPMaps(ctx, &proto.StreamDERPMapsRequest{})
+func (tac *tailnetAPIConnector) derpMap(client proto.DRPCTailnetClient) error {
+	s, err := client.StreamDERPMaps(tac.ctx, &proto.StreamDERPMapsRequest{})
 	if err != nil {
 		return xerrors.Errorf("failed to connect to StreamDERPMaps RPC: %w", err)
 	}
 	defer func() {
 		cErr := s.Close()
 		if cErr != nil {
-			tac.logger.Debug(ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
+			tac.logger.Debug(tac.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
 		}
 	}()
 	for {
 		dmp, err := s.Recv()
 		if err != nil {
-			if xerrors.Is(err, io.EOF) || xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
+			if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
 				return nil
 			}
-			return xerrors.Errorf("error receiving DERP Map: %w", err)
+			tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err))
+			return err
 		}
-		tac.logger.Debug(ctx, "got new DERP Map", slog.F("derp_map", dmp))
+		tac.logger.Debug(tac.ctx, "got new DERP Map", slog.F("derp_map", dmp))
 		dm := tailnet.DERPMapFromProto(dmp)
 		tac.conn.SetDERPMap(dm)
 	}
+106
@@ -0,0 +1,106 @@
+package codersdk
+
+import (
+	"context"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"nhooyr.io/websocket"
+	"tailscale.com/tailcfg"
+
+	"cdr.dev/slog"
+	"cdr.dev/slog/sloggers/slogtest"
+	"github.com/coder/coder/v2/tailnet"
+	"github.com/coder/coder/v2/tailnet/proto"
+	"github.com/coder/coder/v2/tailnet/tailnettest"
+	"github.com/coder/coder/v2/testutil"
+)
+
+func TestTailnetAPIConnector_Disconnects(t *testing.T) {
+	t.Parallel()
+	testCtx := testutil.Context(t, testutil.WaitShort)
+	ctx, cancel := context.WithCancel(testCtx)
+	logger := slogtest.Make(t, &slogtest.Options{
+		// we get EOF when we simulate a DERPMap error
+		IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs, io.EOF),
+	}).Leveled(slog.LevelDebug)
+	agentID := uuid.UUID{0x55}
+	clientID := uuid.UUID{0x66}
+	fCoord := tailnettest.NewFakeCoordinator()
+	var coord tailnet.Coordinator = fCoord
+	coordPtr := atomic.Pointer[tailnet.Coordinator]{}
+	coordPtr.Store(&coord)
+	derpMapCh := make(chan *tailcfg.DERPMap)
+	defer close(derpMapCh)
+	svc, err := tailnet.NewClientService(
+		logger, &coordPtr,
+		time.Millisecond, func() *tailcfg.DERPMap { return <-derpMapCh },
+	)
+	require.NoError(t, err)
+
+	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		sws, err := websocket.Accept(w, r, nil)
+		if !assert.NoError(t, err) {
+			return
+		}
+		ctx, nc := websocketNetConn(r.Context(), sws, websocket.MessageBinary)
+		err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
+			Name: "client",
+			ID:   clientID,
+			Auth: tailnet.ClientTunnelAuth{AgentID: agentID},
+		})
+		assert.NoError(t, err)
+	}))
+
+	fConn := newFakeTailnetConn()
+
+	uut := runTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{}, fConn)
+
+	call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
+	reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs)
+	require.NotNil(t, reqTun.AddTunnel)
+
+	_ = testutil.RequireRecvCtx(ctx, t, uut.connected)
+
+	// simulate a problem with DERPMaps by sending nil
+	testutil.RequireSendCtx(ctx, t, derpMapCh, nil)
+
+	// this should cause the coordinate call to hang up WITHOUT disconnecting
+	reqNil := testutil.RequireRecvCtx(ctx, t, call.Reqs)
+	require.Nil(t, reqNil)
+
+	// ...and then reconnect
+	call = testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
+	reqTun = testutil.RequireRecvCtx(ctx, t, call.Reqs)
+	require.NotNil(t, reqTun.AddTunnel)
+
+	// canceling the context should trigger the disconnect message
+	cancel()
+	reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
+	require.NotNil(t, reqDisc)
+	require.NotNil(t, reqDisc.Disconnect)
+}
+
+type fakeTailnetConn struct{}
+
+func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
+	// TODO implement me
+	panic("implement me")
+}
+
+func (*fakeTailnetConn) SetAllPeersLost() {}
+
+func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {}
+
+func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {}
+
+func newFakeTailnetConn() *fakeTailnetConn {
+	return &fakeTailnetConn{}
+}

tailnet/coordinator.go

+2 −1

@@ -134,6 +134,7 @@ func (c *remoteCoordination) Close() (retErr error) {
 	if err != nil {
 		return xerrors.Errorf("send disconnect: %w", err)
 	}
+	c.logger.Debug(context.Background(), "sent disconnect")
 	return nil
 }
 
@@ -167,7 +168,7 @@ func (c *remoteCoordination) respLoop() {
 	}
 }
 
-// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinee (usually a
+// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinatee (usually a
 // Conn). If the tunnelTarget is not uuid.Nil, then we add a tunnel to the peer (i.e. we are acting as
 // a client---agents should NOT set this!).
 func NewRemoteCoordination(logger slog.Logger,
