Skip to content

Commit 2df9a3e

Browse files
authored
fix: fix tailnet remoteCoordination to wait for server (#14666)
Fixes #12560. When gracefully disconnecting from the coordinator, we would send the Disconnect message and then close the dRPC stream. However, closing the dRPC stream can cause the server not to process the Disconnect message, since we use the stream context in a `select` while sending it to the coordinator. This is a product bug uncovered by the flake, and it probably causes graceful disconnect to fail some minority of the time. Instead, the `remoteCoordination` (and `inMemoryCoordination`, for consistency) should send the Disconnect message and then wait for the coordinator to hang up, with the wait bounded by a graceful-disconnect timeout passed in as a context.
1 parent 7ea8a22 commit 2df9a3e

File tree

7 files changed

+72
-49
lines changed

7 files changed

+72
-49
lines changed

agent/agent.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1357,7 +1357,7 @@ func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tai
13571357
defer close(errCh)
13581358
select {
13591359
case <-ctx.Done():
1360-
err := coordination.Close()
1360+
err := coordination.Close(a.hardCtx)
13611361
if err != nil {
13621362
a.logger.Warn(ctx, "failed to close remote coordination", slog.Error(err))
13631363
}

agent/agent_test.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -1896,7 +1896,9 @@ func TestAgent_UpdatedDERP(t *testing.T) {
18961896
coordinator, conn)
18971897
t.Cleanup(func() {
18981898
t.Logf("closing coordination %s", name)
1899-
err := coordination.Close()
1899+
cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
1900+
defer ccancel()
1901+
err := coordination.Close(cctx)
19001902
if err != nil {
19011903
t.Logf("error closing in-memory coordination: %s", err.Error())
19021904
}
@@ -2384,7 +2386,9 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
23842386
clientID, metadata.AgentID,
23852387
coordinator, conn)
23862388
t.Cleanup(func() {
2387-
err := coordination.Close()
2389+
cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
2390+
defer ccancel()
2391+
err := coordination.Close(cctx)
23882392
if err != nil {
23892393
t.Logf("error closing in-mem coordination: %s", err.Error())
23902394
}

codersdk/workspacesdk/connector.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
277277
select {
278278
case <-tac.ctx.Done():
279279
tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
280-
crdErr := coordination.Close()
280+
crdErr := coordination.Close(tac.gracefulCtx)
281281
if crdErr != nil {
282282
tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err))
283283
}

codersdk/workspacesdk/connector_internal_test.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
5757
derpMapCh := make(chan *tailcfg.DERPMap)
5858
defer close(derpMapCh)
5959
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
60-
Logger: logger,
60+
Logger: logger.Named("svc"),
6161
CoordPtr: &coordPtr,
6262
DERPMapUpdateFrequency: time.Millisecond,
6363
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
@@ -82,7 +82,8 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
8282

8383
fConn := newFakeTailnetConn()
8484

85-
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
85+
uut := newTailnetAPIConnector(ctx, logger.Named("tac"), agentID, svr.URL,
86+
quartz.NewReal(), &websocket.DialOptions{})
8687
uut.runConnector(fConn)
8788

8889
call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
@@ -108,6 +109,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
108109
reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
109110
require.NotNil(t, reqDisc)
110111
require.NotNil(t, reqDisc.Disconnect)
112+
close(call.Resps)
111113
}
112114

113115
func TestTailnetAPIConnector_UplevelVersion(t *testing.T) {

tailnet/coordinator.go

+34-23
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ type Coordinatee interface {
9191
}
9292

9393
type Coordination interface {
94-
io.Closer
94+
Close(context.Context) error
9595
Error() <-chan error
9696
}
9797

@@ -106,14 +106,29 @@ type remoteCoordination struct {
106106
respLoopDone chan struct{}
107107
}
108108

109-
func (c *remoteCoordination) Close() (retErr error) {
109+
// Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and
110+
// waiting for the server to hang up the coordination. If the provided context expires, we stop
111+
// waiting for the server and close the coordination stream from our end.
112+
func (c *remoteCoordination) Close(ctx context.Context) (retErr error) {
110113
c.Lock()
111114
defer c.Unlock()
112115
if c.closed {
113116
return nil
114117
}
115118
c.closed = true
116119
defer func() {
120+
// We shouldn't just close the protocol right away, because the way dRPC streams work is
121+
// that if you close them, that could take effect immediately, even before the Disconnect
122+
// message is processed. Coordinators are supposed to hang up on us once they get a
123+
// Disconnect message, so we should wait around for that until the context expires.
124+
select {
125+
case <-c.respLoopDone:
126+
c.logger.Debug(ctx, "responses closed after disconnect")
127+
return
128+
case <-ctx.Done():
129+
c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close")
130+
}
131+
// forcefully close the stream
117132
protoErr := c.protocol.Close()
118133
<-c.respLoopDone
119134
if retErr == nil {
@@ -240,7 +255,6 @@ type inMemoryCoordination struct {
240255
ctx context.Context
241256
errChan chan error
242257
closed bool
243-
closedCh chan struct{}
244258
respLoopDone chan struct{}
245259
coordinatee Coordinatee
246260
logger slog.Logger
@@ -280,7 +294,6 @@ func NewInMemoryCoordination(
280294
errChan: make(chan error, 1),
281295
coordinatee: coordinatee,
282296
logger: logger,
283-
closedCh: make(chan struct{}),
284297
respLoopDone: make(chan struct{}),
285298
}
286299

@@ -328,24 +341,15 @@ func (c *inMemoryCoordination) respLoop() {
328341
c.coordinatee.SetAllPeersLost()
329342
close(c.respLoopDone)
330343
}()
331-
for {
332-
select {
333-
case <-c.closedCh:
334-
c.logger.Debug(context.Background(), "in-memory coordination closed")
344+
for resp := range c.resps {
345+
c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
346+
err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
347+
if err != nil {
348+
c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
335349
return
336-
case resp, ok := <-c.resps:
337-
if !ok {
338-
c.logger.Debug(context.Background(), "in-memory response channel closed")
339-
return
340-
}
341-
c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
342-
err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
343-
if err != nil {
344-
c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
345-
return
346-
}
347350
}
348351
}
352+
c.logger.Debug(context.Background(), "in-memory response channel closed")
349353
}
350354

351355
func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
@@ -355,7 +359,10 @@ func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
355359
return ch
356360
}
357361

358-
func (c *inMemoryCoordination) Close() error {
362+
// Close attempts to gracefully close the inMemoryCoordination by sending a Disconnect message and
363+
// waiting for the server to hang up the coordination. If the provided context expires, we stop
364+
// waiting for the server and close the coordination stream from our end.
365+
func (c *inMemoryCoordination) Close(ctx context.Context) error {
359366
c.Lock()
360367
defer c.Unlock()
361368
c.logger.Debug(context.Background(), "closing in-memory coordination")
@@ -364,13 +371,17 @@ func (c *inMemoryCoordination) Close() error {
364371
}
365372
defer close(c.reqs)
366373
c.closed = true
367-
close(c.closedCh)
368-
<-c.respLoopDone
369374
select {
370-
case <-c.ctx.Done():
375+
case <-ctx.Done():
371376
return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
372377
case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}:
373378
c.logger.Debug(context.Background(), "sent graceful disconnect in-memory")
379+
}
380+
381+
select {
382+
case <-ctx.Done():
383+
return xerrors.Errorf("context expired waiting for responses to close: %w", c.ctx.Err())
384+
case <-c.respLoopDone:
374385
return nil
375386
}
376387
}

tailnet/coordinator_test.go

+23-19
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package tailnet_test
22

33
import (
44
"context"
5+
"io"
56
"net"
67
"net/netip"
78
"sync"
@@ -284,7 +285,7 @@ func TestInMemoryCoordination(t *testing.T) {
284285
Times(1).Return(reqs, resps)
285286

286287
uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
287-
defer uut.Close()
288+
defer uut.Close(ctx)
288289

289290
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
290291

@@ -336,16 +337,13 @@ func TestRemoteCoordination(t *testing.T) {
336337
require.NoError(t, err)
337338

338339
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
339-
defer uut.Close()
340+
defer uut.Close(ctx)
340341

341342
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
342343

343-
select {
344-
case err := <-uut.Error():
345-
require.ErrorContains(t, err, "stream terminated by sending close")
346-
default:
347-
// OK!
348-
}
344+
// Recv loop should be terminated by the server hanging up after Disconnect
345+
err = testutil.RequireRecvCtx(ctx, t, uut.Error())
346+
require.ErrorIs(t, err, io.EOF)
349347
}
350348

351349
func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
@@ -388,7 +386,7 @@ func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
388386
require.NoError(t, err)
389387

390388
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, uuid.UUID{})
391-
defer uut.Close()
389+
defer uut.Close(ctx)
392390

393391
nk, err := key.NewNode().Public().MarshalBinary()
394392
require.NoError(t, err)
@@ -411,14 +409,15 @@ func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
411409
require.Len(t, rfh.ReadyForHandshake, 1)
412410
require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id)
413411

414-
require.NoError(t, uut.Close())
412+
go uut.Close(ctx)
413+
dis := testutil.RequireRecvCtx(ctx, t, reqs)
414+
require.NotNil(t, dis)
415+
require.NotNil(t, dis.Disconnect)
416+
close(resps)
415417

416-
select {
417-
case err := <-uut.Error():
418-
require.ErrorContains(t, err, "stream terminated by sending close")
419-
default:
420-
// OK!
421-
}
418+
// Recv loop should be terminated by the server hanging up after Disconnect
419+
err = testutil.RequireRecvCtx(ctx, t, uut.Error())
420+
require.ErrorIs(t, err, io.EOF)
422421
}
423422

424423
// coordinationTest tests that a coordination behaves correctly
@@ -464,13 +463,18 @@ func coordinationTest(
464463
require.Len(t, fConn.updates[0], 1)
465464
require.Equal(t, agentID[:], fConn.updates[0][0].Id)
466465

467-
err = uut.Close()
468-
require.NoError(t, err)
469-
uut.Error()
466+
errCh := make(chan error, 1)
467+
go func() {
468+
errCh <- uut.Close(ctx)
469+
}()
470470

471471
// When we close, it should gracefully disconnect
472472
req = testutil.RequireRecvCtx(ctx, t, reqs)
473473
require.NotNil(t, req.Disconnect)
474+
close(resps)
475+
476+
err = testutil.RequireRecvCtx(ctx, t, errCh)
477+
require.NoError(t, err)
474478

475479
// It should set all peers lost on the coordinatee
476480
require.Equal(t, 1, fConn.setAllPeersLostCalls)

tailnet/test/integration/integration.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,9 @@ func startClientOptions(t *testing.T, logger slog.Logger, serverURL *url.URL, me
469469

470470
coordination := tailnet.NewRemoteCoordination(logger, coord, conn, peer.ID)
471471
t.Cleanup(func() {
472-
_ = coordination.Close()
472+
cctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
473+
defer cancel()
474+
_ = coordination.Close(cctx)
473475
})
474476

475477
return conn

0 commit comments

Comments
 (0)