Skip to content

Commit a6dfa38

Browse files
committed
feat: support graceful disconnect in PGCoordinator
1 parent 23194ad commit a6dfa38

File tree

4 files changed

+318
-118
lines changed

4 files changed

+318
-118
lines changed

enterprise/tailnet/connio.go

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,17 @@ type connIO struct {
2424
// coordCtx is the parent context, that is, the context of the Coordinator
2525
coordCtx context.Context
2626
// peerCtx is the context of the connection to our peer
27-
peerCtx context.Context
28-
cancel context.CancelFunc
29-
logger slog.Logger
30-
requests <-chan *proto.CoordinateRequest
31-
responses chan<- *proto.CoordinateResponse
32-
bindings chan<- binding
33-
tunnels chan<- tunnel
34-
auth agpl.TunnelAuth
35-
mu sync.Mutex
36-
closed bool
27+
peerCtx context.Context
28+
cancel context.CancelFunc
29+
logger slog.Logger
30+
requests <-chan *proto.CoordinateRequest
31+
responses chan<- *proto.CoordinateResponse
32+
bindings chan<- binding
33+
tunnels chan<- tunnel
34+
auth agpl.TunnelAuth
35+
mu sync.Mutex
36+
closed bool
37+
disconnected bool
3738

3839
name string
3940
start int64
@@ -76,20 +77,29 @@ func newConnIO(coordContext context.Context,
7677

7778
func (c *connIO) recvLoop() {
7879
defer func() {
79-
// withdraw bindings & tunnels when we exit. We need to use the parent context here, since
80+
// withdraw bindings & tunnels when we exit. We need to use the coordinator context here, since
8081
// our own context might be canceled, but we still need to withdraw.
8182
b := binding{
8283
bKey: bKey(c.UniqueID()),
84+
kind: proto.CoordinateResponse_PeerUpdate_LOST,
85+
}
86+
if c.disconnected {
87+
b.kind = proto.CoordinateResponse_PeerUpdate_DISCONNECTED
8388
}
8489
if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
8590
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing bindings", slog.Error(err))
8691
}
87-
t := tunnel{
88-
tKey: tKey{src: c.UniqueID()},
89-
active: false,
90-
}
91-
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
92-
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err))
92+
// only remove tunnels on graceful disconnect. If we remove tunnels for lost peers, then
93+
// this will look like a disconnect from the peer's perspective, since we query for active peers
94+
// by using the tunnel as a join in the database
95+
if c.disconnected {
96+
t := tunnel{
97+
tKey: tKey{src: c.UniqueID()},
98+
active: false,
99+
}
100+
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
101+
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err))
102+
}
93103
}
94104
}()
95105
defer c.Close()
@@ -111,13 +121,16 @@ func (c *connIO) recvLoop() {
111121
}
112122
}
113123

124+
var errDisconnect = xerrors.New("graceful disconnect") // sentinel returned by handleRequest when the request carries a Disconnect message
125+
114126
func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
115127
c.logger.Debug(c.peerCtx, "got request")
116128
if req.UpdateSelf != nil {
117129
c.logger.Debug(c.peerCtx, "got node update", slog.F("node", req.UpdateSelf))
118130
b := binding{
119131
bKey: bKey(c.UniqueID()),
120132
node: req.UpdateSelf.Node,
133+
kind: proto.CoordinateResponse_PeerUpdate_NODE,
121134
}
122135
if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
123136
c.logger.Debug(c.peerCtx, "failed to send binding", slog.Error(err))
@@ -169,7 +182,11 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
169182
return err
170183
}
171184
}
172-
// TODO: (spikecurtis) support Disconnect
185+
if req.Disconnect != nil {
186+
c.logger.Debug(c.peerCtx, "graceful disconnect")
187+
c.disconnected = true
188+
return errDisconnect
189+
}
173190
return nil
174191
}
175192

enterprise/tailnet/multiagent_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
5858
require.NoError(t, agent1.close())
5959

6060
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
61-
assertEventuallyNoAgents(ctx, t, store, agent1.id)
61+
assertEventuallyLost(ctx, t, store, agent1.id)
6262
}
6363

6464
// TestPGCoordinator_MultiAgent_UnsubscribeRace tests a single coordinator with
@@ -106,7 +106,7 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
106106
require.NoError(t, agent1.close())
107107

108108
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
109-
assertEventuallyNoAgents(ctx, t, store, agent1.id)
109+
assertEventuallyLost(ctx, t, store, agent1.id)
110110
}
111111

112112
// TestPGCoordinator_MultiAgent_Unsubscribe tests a single coordinator with a
@@ -168,7 +168,7 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
168168
require.NoError(t, agent1.close())
169169

170170
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
171-
assertEventuallyNoAgents(ctx, t, store, agent1.id)
171+
assertEventuallyLost(ctx, t, store, agent1.id)
172172
}
173173

174174
// TestPGCoordinator_MultiAgent_MultiCoordinator tests two coordinators with a
@@ -220,7 +220,7 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
220220
require.NoError(t, agent1.close())
221221

222222
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
223-
assertEventuallyNoAgents(ctx, t, store, agent1.id)
223+
assertEventuallyLost(ctx, t, store, agent1.id)
224224
}
225225

226226
// TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe tests two
@@ -273,7 +273,7 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
273273
require.NoError(t, agent1.close())
274274

275275
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
276-
assertEventuallyNoAgents(ctx, t, store, agent1.id)
276+
assertEventuallyLost(ctx, t, store, agent1.id)
277277
}
278278

279279
// TestPGCoordinator_MultiAgent_TwoAgents tests three coordinators with a
@@ -344,5 +344,5 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
344344
require.NoError(t, agent2.close())
345345

346346
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
347-
assertEventuallyNoAgents(ctx, t, store, agent1.id)
347+
assertEventuallyLost(ctx, t, store, agent1.id)
348348
}

0 commit comments

Comments
 (0)