
Commit 540d7eb

fix: fix tailnet remoteCoordination to wait for server

1 parent 9dc8e0f commit 540d7eb

7 files changed: +72 -49 lines
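
In short: the tailnet Coordination interface changes from io.Closer to Close(context.Context) error. After sending its Disconnect message, a coordination now waits for the coordinator (server) to hang up, i.e. for its response loop to finish, and only force-closes the underlying stream if the supplied context expires first. Callers therefore bound the graceful disconnect with a context. A minimal sketch of the new calling convention, mirroring the test changes below; the coordination and logger variables and the use of testutil.WaitShort are illustrative, not additions from this commit:

    // Bound the graceful disconnect so Close cannot block forever if the
    // coordinator never hangs up.
    cctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
    defer cancel()
    if err := coordination.Close(cctx); err != nil {
        logger.Warn(cctx, "failed to close remote coordination", slog.Error(err))
    }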

agent/agent.go (+1 -1)

@@ -1357,7 +1357,7 @@ func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tai
     defer close(errCh)
     select {
     case <-ctx.Done():
-        err := coordination.Close()
+        err := coordination.Close(a.hardCtx)
         if err != nil {
             a.logger.Warn(ctx, "failed to close remote coordination", slog.Error(err))
         }
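
A note on the context passed here: this branch only runs once ctx is canceled, so reusing ctx would make Close give up on the graceful disconnect immediately; the agent's longer-lived hard context is passed instead so the Disconnect handshake can still complete. A hypothetical illustration of what an already-canceled context would do (the variables below are not from the codebase):

    // With an already-expired context, Close would typically skip the wait for
    // the server to hang up, log the "context expired" warning, and fall through
    // to force-closing the stream.
    expired, cancel := context.WithCancel(context.Background())
    cancel()
    _ = coordination.Close(expired)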

agent/agent_test.go (+6 -2)

@@ -1896,7 +1896,9 @@ func TestAgent_UpdatedDERP(t *testing.T) {
         coordinator, conn)
     t.Cleanup(func() {
         t.Logf("closing coordination %s", name)
-        err := coordination.Close()
+        cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
+        defer ccancel()
+        err := coordination.Close(cctx)
         if err != nil {
             t.Logf("error closing in-memory coordination: %s", err.Error())
         }
@@ -2384,7 +2386,9 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
         clientID, metadata.AgentID,
         coordinator, conn)
     t.Cleanup(func() {
-        err := coordination.Close()
+        cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
+        defer ccancel()
+        err := coordination.Close(cctx)
         if err != nil {
             t.Logf("error closing in-mem coordination: %s", err.Error())
         }

codersdk/workspacesdk/connector.go (+1 -1)

@@ -277,7 +277,7 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
     select {
     case <-tac.ctx.Done():
         tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
-        crdErr := coordination.Close()
+        crdErr := coordination.Close(tac.gracefulCtx)
         if crdErr != nil {
             tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err))
         }

codersdk/workspacesdk/connector_internal_test.go (+4 -2)

@@ -57,7 +57,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
     derpMapCh := make(chan *tailcfg.DERPMap)
     defer close(derpMapCh)
     svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
-        Logger: logger,
+        Logger: logger.Named("svc"),
         CoordPtr: &coordPtr,
         DERPMapUpdateFrequency: time.Millisecond,
         DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
@@ -82,7 +82,8 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {

     fConn := newFakeTailnetConn()

-    uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
+    uut := newTailnetAPIConnector(ctx, logger.Named("tac"), agentID, svr.URL,
+        quartz.NewReal(), &websocket.DialOptions{})
     uut.runConnector(fConn)

     call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
@@ -108,6 +109,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
     reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
     require.NotNil(t, reqDisc)
     require.NotNil(t, reqDisc.Disconnect)
+    close(call.Resps)
 }

 func TestTailnetAPIConnector_UplevelVersion(t *testing.T) {

tailnet/coordinator.go (+34 -23)

@@ -91,7 +91,7 @@ type Coordinatee interface {
 }

 type Coordination interface {
-    io.Closer
+    Close(context.Context) error
     Error() <-chan error
 }

@@ -106,14 +106,29 @@ type remoteCoordination struct {
     respLoopDone chan struct{}
 }

-func (c *remoteCoordination) Close() (retErr error) {
+// Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and
+// waiting for the server to hang up the coordination. If the provided context expires, we stop
+// waiting for the server and close the coordination stream from our end.
+func (c *remoteCoordination) Close(ctx context.Context) (retErr error) {
     c.Lock()
     defer c.Unlock()
     if c.closed {
         return nil
     }
     c.closed = true
     defer func() {
+        // We shouldn't just close the protocol right away, because the way dRPC streams work is
+        // that if you close them, that could take effect immediately, even before the Disconnect
+        // message is processed. Coordinators are supposed to hang up on us once they get a
+        // Disconnect message, so we should wait around for that until the context expires.
+        select {
+        case <-c.respLoopDone:
+            c.logger.Debug(ctx, "responses closed after disconnect")
+            return
+        case <-ctx.Done():
+            c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close")
+        }
+        // forcefully close the stream
         protoErr := c.protocol.Close()
         <-c.respLoopDone
         if retErr == nil {
@@ -240,7 +255,6 @@ type inMemoryCoordination struct {
     ctx          context.Context
     errChan      chan error
     closed       bool
-    closedCh     chan struct{}
     respLoopDone chan struct{}
     coordinatee  Coordinatee
     logger       slog.Logger
@@ -280,7 +294,6 @@ func NewInMemoryCoordination(
         errChan:      make(chan error, 1),
         coordinatee:  coordinatee,
         logger:       logger,
-        closedCh:     make(chan struct{}),
         respLoopDone: make(chan struct{}),
     }

@@ -328,24 +341,15 @@ func (c *inMemoryCoordination) respLoop() {
     c.coordinatee.SetAllPeersLost()
     close(c.respLoopDone)
     }()
-    for {
-        select {
-        case <-c.closedCh:
-            c.logger.Debug(context.Background(), "in-memory coordination closed")
+    for resp := range c.resps {
+        c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
+        err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
+        if err != nil {
+            c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
             return
-        case resp, ok := <-c.resps:
-            if !ok {
-                c.logger.Debug(context.Background(), "in-memory response channel closed")
-                return
-            }
-            c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
-            err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
-            if err != nil {
-                c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
-                return
-            }
         }
     }
+    c.logger.Debug(context.Background(), "in-memory response channel closed")
 }

 func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
@@ -355,7 +359,10 @@ func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
     return ch
 }

-func (c *inMemoryCoordination) Close() error {
+// Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and
+// waiting for the server to hang up the coordination. If the provided context expires, we stop
+// waiting for the server and close the coordination stream from our end.
+func (c *inMemoryCoordination) Close(ctx context.Context) error {
     c.Lock()
     defer c.Unlock()
     c.logger.Debug(context.Background(), "closing in-memory coordination")
@@ -364,13 +371,17 @@ func (c *inMemoryCoordination) Close() error {
     }
     defer close(c.reqs)
     c.closed = true
-    close(c.closedCh)
-    <-c.respLoopDone
     select {
-    case <-c.ctx.Done():
+    case <-ctx.Done():
         return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
     case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}:
         c.logger.Debug(context.Background(), "sent graceful disconnect in-memory")
+    }
+
+    select {
+    case <-ctx.Done():
+        return xerrors.Errorf("context expired waiting for responses to close: %w", c.ctx.Err())
+    case <-c.respLoopDone:
         return nil
     }
 }
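
The comment in the hunk above is the heart of the change: closing a dRPC stream can take effect before the Disconnect message is processed, so the client should wait for the coordinator to hang up and only force-close once the caller's context runs out. A distilled sketch of that wait-or-force-close shape; this standalone helper is illustrative only, does not exist in the codebase, and assumes the "context" package is imported:

    // waitOrForceClose prefers a server-initiated shutdown (respLoopDone closing)
    // and falls back to force-closing the stream only once ctx expires.
    func waitOrForceClose(ctx context.Context, respLoopDone <-chan struct{}, closeStream func() error) error {
        select {
        case <-respLoopDone:
            // The coordinator hung up after processing the Disconnect; nothing to force.
            return nil
        case <-ctx.Done():
        }
        // Context expired: forcefully close the stream, then wait for the
        // response loop to exit before returning.
        err := closeStream()
        <-respLoopDone
        return err
    }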

tailnet/coordinator_test.go (+23 -19)

@@ -2,6 +2,7 @@ package tailnet_test

 import (
     "context"
+    "io"
     "net"
     "net/netip"
     "sync"
@@ -284,7 +285,7 @@ func TestInMemoryCoordination(t *testing.T) {
     Times(1).Return(reqs, resps)

     uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
-    defer uut.Close()
+    defer uut.Close(ctx)

     coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)

@@ -336,16 +337,13 @@ func TestRemoteCoordination(t *testing.T) {
     require.NoError(t, err)

     uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
-    defer uut.Close()
+    defer uut.Close(ctx)

     coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)

-    select {
-    case err := <-uut.Error():
-        require.ErrorContains(t, err, "stream terminated by sending close")
-    default:
-        // OK!
-    }
+    // Recv loop should be terminated by the server hanging up after Disconnect
+    err = testutil.RequireRecvCtx(ctx, t, uut.Error())
+    require.ErrorIs(t, err, io.EOF)
 }

 func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
@@ -388,7 +386,7 @@ func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
     require.NoError(t, err)

     uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, uuid.UUID{})
-    defer uut.Close()
+    defer uut.Close(ctx)

     nk, err := key.NewNode().Public().MarshalBinary()
     require.NoError(t, err)
@@ -411,14 +409,15 @@ func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
     require.Len(t, rfh.ReadyForHandshake, 1)
     require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id)

-    require.NoError(t, uut.Close())
+    go uut.Close(ctx)
+    dis := testutil.RequireRecvCtx(ctx, t, reqs)
+    require.NotNil(t, dis)
+    require.NotNil(t, dis.Disconnect)
+    close(resps)

-    select {
-    case err := <-uut.Error():
-        require.ErrorContains(t, err, "stream terminated by sending close")
-    default:
-        // OK!
-    }
+    // Recv loop should be terminated by the server hanging up after Disconnect
+    err = testutil.RequireRecvCtx(ctx, t, uut.Error())
+    require.ErrorIs(t, err, io.EOF)
 }

 // coordinationTest tests that a coordination behaves correctly
@@ -464,13 +463,18 @@ func coordinationTest(
     require.Len(t, fConn.updates[0], 1)
     require.Equal(t, agentID[:], fConn.updates[0][0].Id)

-    err = uut.Close()
-    require.NoError(t, err)
-    uut.Error()
+    errCh := make(chan error, 1)
+    go func() {
+        errCh <- uut.Close(ctx)
+    }()

     // When we close, it should gracefully disconnect
     req = testutil.RequireRecvCtx(ctx, t, reqs)
     require.NotNil(t, req.Disconnect)
+    close(resps)
+
+    err = testutil.RequireRecvCtx(ctx, t, errCh)
+    require.NoError(t, err)

     // It should set all peers lost on the coordinatee
     require.Equal(t, 1, fConn.setAllPeersLostCalls)

tailnet/test/integration/integration.go (+3 -1)

@@ -469,7 +469,9 @@ func startClientOptions(t *testing.T, logger slog.Logger, serverURL *url.URL, me

     coordination := tailnet.NewRemoteCoordination(logger, coord, conn, peer.ID)
     t.Cleanup(func() {
-        _ = coordination.Close()
+        cctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
+        defer cancel()
+        _ = coordination.Close(cctx)
     })

     return conn
