Skip to content

Commit 53b97df

Browse files
committed
feat: set peers lost when disconnected from coordinator
1 parent d476a87 commit 53b97df

File tree

4 files changed

+226
-28
lines changed

4 files changed

+226
-28
lines changed

tailnet/conn.go

+5
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,11 @@ func (c *Conn) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error
356356
return nil
357357
}
358358

359+
// SetAllPeersLost marks all peers lost; typically used when we disconnect from a coordinator.
360+
func (c *Conn) SetAllPeersLost() {
361+
c.configMaps.setAllPeersLost()
362+
}
363+
359364
// NodeAddresses returns the addresses of a node from the NetworkMap.
360365
func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) {
361366
return c.configMaps.nodeAddresses(publicKey)

tailnet/coordinator.go

+44-23
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ type Node struct {
9797
// Conn.
9898
type Coordinatee interface {
9999
UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error
100+
SetAllPeersLost()
100101
SetNodeCallback(func(*Node))
101102
}
102103

@@ -107,20 +108,28 @@ type Coordination interface {
107108

108109
type remoteCoordination struct {
109110
sync.Mutex
110-
closed bool
111-
errChan chan error
112-
coordinatee Coordinatee
113-
logger slog.Logger
114-
protocol proto.DRPCTailnet_CoordinateClient
111+
closed bool
112+
errChan chan error
113+
coordinatee Coordinatee
114+
logger slog.Logger
115+
protocol proto.DRPCTailnet_CoordinateClient
116+
respLoopDone chan struct{}
115117
}
116118

117-
func (c *remoteCoordination) Close() error {
119+
func (c *remoteCoordination) Close() (retErr error) {
118120
c.Lock()
119121
defer c.Unlock()
120122
if c.closed {
121123
return nil
122124
}
123125
c.closed = true
126+
defer func() {
127+
protoErr := c.protocol.Close()
128+
<-c.respLoopDone
129+
if retErr == nil {
130+
retErr = protoErr
131+
}
132+
}()
124133
err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
125134
if err != nil {
126135
return xerrors.Errorf("send disconnect: %w", err)
@@ -140,6 +149,10 @@ func (c *remoteCoordination) sendErr(err error) {
140149
}
141150

142151
func (c *remoteCoordination) respLoop() {
152+
defer func() {
153+
c.coordinatee.SetAllPeersLost()
154+
close(c.respLoopDone)
155+
}()
143156
for {
144157
resp, err := c.protocol.Recv()
145158
if err != nil {
@@ -162,10 +175,11 @@ func NewRemoteCoordination(logger slog.Logger,
162175
tunnelTarget uuid.UUID,
163176
) Coordination {
164177
c := &remoteCoordination{
165-
errChan: make(chan error, 1),
166-
coordinatee: coordinatee,
167-
logger: logger,
168-
protocol: protocol,
178+
errChan: make(chan error, 1),
179+
coordinatee: coordinatee,
180+
logger: logger,
181+
protocol: protocol,
182+
respLoopDone: make(chan struct{}),
169183
}
170184
if tunnelTarget != uuid.Nil {
171185
c.Lock()
@@ -200,14 +214,15 @@ func NewRemoteCoordination(logger slog.Logger,
200214

201215
type inMemoryCoordination struct {
202216
sync.Mutex
203-
ctx context.Context
204-
errChan chan error
205-
closed bool
206-
closedCh chan struct{}
207-
coordinatee Coordinatee
208-
logger slog.Logger
209-
resps <-chan *proto.CoordinateResponse
210-
reqs chan<- *proto.CoordinateRequest
217+
ctx context.Context
218+
errChan chan error
219+
closed bool
220+
closedCh chan struct{}
221+
respLoopDone chan struct{}
222+
coordinatee Coordinatee
223+
logger slog.Logger
224+
resps <-chan *proto.CoordinateResponse
225+
reqs chan<- *proto.CoordinateRequest
211226
}
212227

213228
func (c *inMemoryCoordination) sendErr(err error) {
@@ -238,11 +253,12 @@ func NewInMemoryCoordination(
238253
thisID = clientID
239254
}
240255
c := &inMemoryCoordination{
241-
ctx: ctx,
242-
errChan: make(chan error, 1),
243-
coordinatee: coordinatee,
244-
logger: logger,
245-
closedCh: make(chan struct{}),
256+
ctx: ctx,
257+
errChan: make(chan error, 1),
258+
coordinatee: coordinatee,
259+
logger: logger,
260+
closedCh: make(chan struct{}),
261+
respLoopDone: make(chan struct{}),
246262
}
247263

248264
// use the background context since we will depend exclusively on closing the req channel to
@@ -285,6 +301,10 @@ func NewInMemoryCoordination(
285301
}
286302

287303
func (c *inMemoryCoordination) respLoop() {
304+
defer func() {
305+
c.coordinatee.SetAllPeersLost()
306+
close(c.respLoopDone)
307+
}()
288308
for {
289309
select {
290310
case <-c.closedCh:
@@ -315,6 +335,7 @@ func (c *inMemoryCoordination) Close() error {
315335
defer close(c.reqs)
316336
c.closed = true
317337
close(c.closedCh)
338+
<-c.respLoopDone
318339
select {
319340
case <-c.ctx.Done():
320341
return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())

tailnet/coordinator_test.go

+167-5
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,24 @@ import (
66
"net"
77
"net/http"
88
"net/http/httptest"
9+
"sync"
10+
"sync/atomic"
911
"testing"
1012
"time"
1113

12-
"nhooyr.io/websocket"
13-
14-
"cdr.dev/slog"
15-
"cdr.dev/slog/sloggers/slogtest"
16-
1714
"github.com/google/uuid"
1815
"github.com/stretchr/testify/assert"
1916
"github.com/stretchr/testify/require"
17+
"go.uber.org/mock/gomock"
18+
"nhooyr.io/websocket"
19+
"tailscale.com/tailcfg"
20+
"tailscale.com/types/key"
2021

22+
"cdr.dev/slog"
23+
"cdr.dev/slog/sloggers/slogtest"
2124
"github.com/coder/coder/v2/tailnet"
25+
"github.com/coder/coder/v2/tailnet/proto"
26+
"github.com/coder/coder/v2/tailnet/tailnettest"
2227
"github.com/coder/coder/v2/tailnet/test"
2328
"github.com/coder/coder/v2/testutil"
2429
)
@@ -400,3 +405,160 @@ func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server n
400405
require.True(t, ok)
401406
return client, server
402407
}
408+
409+
func TestInMemoryCoordination(t *testing.T) {
410+
t.Parallel()
411+
ctx := testutil.Context(t, testutil.WaitShort)
412+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
413+
clientID := uuid.UUID{1}
414+
agentID := uuid.UUID{2}
415+
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
416+
fConn := &fakeCoordinatee{}
417+
418+
reqs := make(chan *proto.CoordinateRequest, 100)
419+
resps := make(chan *proto.CoordinateResponse, 100)
420+
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
421+
Times(1).Return(reqs, resps)
422+
423+
uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
424+
defer uut.Close()
425+
426+
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
427+
428+
select {
429+
case err := <-uut.Error():
430+
require.NoError(t, err)
431+
default:
432+
// OK!
433+
}
434+
}
435+
436+
func TestRemoteCoordination(t *testing.T) {
437+
t.Parallel()
438+
ctx := testutil.Context(t, testutil.WaitShort)
439+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
440+
clientID := uuid.UUID{1}
441+
agentID := uuid.UUID{2}
442+
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
443+
fConn := &fakeCoordinatee{}
444+
445+
reqs := make(chan *proto.CoordinateRequest, 100)
446+
resps := make(chan *proto.CoordinateResponse, 100)
447+
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
448+
Times(1).Return(reqs, resps)
449+
450+
var coord tailnet.Coordinator = mCoord
451+
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
452+
coordPtr.Store(&coord)
453+
svc, err := tailnet.NewClientService(
454+
logger.Named("svc"), &coordPtr,
455+
time.Hour,
456+
func() *tailcfg.DERPMap { panic("not implemented") },
457+
)
458+
require.NoError(t, err)
459+
sC, cC := net.Pipe()
460+
461+
serveErr := make(chan error, 1)
462+
go func() {
463+
err := svc.ServeClient(ctx, tailnet.CurrentVersion.String(), sC, clientID, agentID)
464+
serveErr <- err
465+
}()
466+
467+
client, err := tailnet.NewDRPCClient(cC)
468+
require.NoError(t, err)
469+
protocol, err := client.Coordinate(ctx)
470+
require.NoError(t, err)
471+
472+
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
473+
defer uut.Close()
474+
475+
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
476+
477+
select {
478+
case err := <-uut.Error():
479+
require.ErrorContains(t, err, "stream terminated by sending close")
480+
default:
481+
// OK!
482+
}
483+
}
484+
485+
// coordinationTest tests that a coordination behaves correctly
486+
func coordinationTest(
487+
ctx context.Context, t *testing.T,
488+
uut tailnet.Coordination, fConn *fakeCoordinatee,
489+
reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse,
490+
agentID uuid.UUID,
491+
) {
492+
// It should add the tunnel, since we configured as a client
493+
req := testutil.RequireRecvCtx(ctx, t, reqs)
494+
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
495+
496+
// when we call the callback, it should send a node update
497+
require.NotNil(t, fConn.callback)
498+
fConn.callback(&tailnet.Node{PreferredDERP: 1})
499+
500+
req = testutil.RequireRecvCtx(ctx, t, reqs)
501+
require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp())
502+
503+
// When we send a peer update, it should update the coordinatee
504+
nk, err := key.NewNode().Public().MarshalBinary()
505+
require.NoError(t, err)
506+
dk, err := key.NewDisco().Public().MarshalText()
507+
require.NoError(t, err)
508+
updates := []*proto.CoordinateResponse_PeerUpdate{
509+
{
510+
Id: agentID[:],
511+
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
512+
Node: &proto.Node{
513+
Id: 2,
514+
Key: nk,
515+
Disco: string(dk),
516+
},
517+
},
518+
}
519+
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
520+
require.Eventually(t, func() bool {
521+
fConn.Lock()
522+
defer fConn.Unlock()
523+
return len(fConn.updates) > 0
524+
}, testutil.WaitShort, testutil.IntervalFast)
525+
require.Len(t, fConn.updates[0], 1)
526+
require.Equal(t, agentID[:], fConn.updates[0][0].Id)
527+
528+
err = uut.Close()
529+
require.NoError(t, err)
530+
uut.Error()
531+
532+
// When we close, it should gracefully disconnect
533+
req = testutil.RequireRecvCtx(ctx, t, reqs)
534+
require.NotNil(t, req.Disconnect)
535+
536+
// It should set all peers lost on the coordinatee
537+
require.Equal(t, 1, fConn.setAllPeersLostCalls)
538+
}
539+
540+
type fakeCoordinatee struct {
541+
sync.Mutex
542+
callback func(*tailnet.Node)
543+
updates [][]*proto.CoordinateResponse_PeerUpdate
544+
setAllPeersLostCalls int
545+
}
546+
547+
func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
548+
f.Lock()
549+
defer f.Unlock()
550+
f.updates = append(f.updates, updates)
551+
return nil
552+
}
553+
554+
func (f *fakeCoordinatee) SetAllPeersLost() {
555+
f.Lock()
556+
defer f.Unlock()
557+
f.setAllPeersLostCalls++
558+
}
559+
560+
func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
561+
f.Lock()
562+
defer f.Unlock()
563+
f.callback = callback
564+
}

testutil/ctx.go

+10
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,13 @@ func RequireRecvCtx[A any](ctx context.Context, t testing.TB, c <-chan A) (a A)
2222
return a
2323
}
2424
}
25+
26+
// RequireSendCtx sends a on c, blocking until either the send completes or ctx
// is done. If ctx is done first, the test is failed immediately with the
// message "timeout". It is marked as a test helper so failures are attributed
// to the caller's line.
func RequireSendCtx[A any](ctx context.Context, t testing.TB, c chan<- A, a A) {
	t.Helper()
	select {
	case <-ctx.Done():
		t.Fatal("timeout")
	case c <- a:
		// OK!
	}
}

0 commit comments

Comments
 (0)