
Commit 520b12e

fix: close MultiAgentConn when coordinator closes (#11941)
Fixes an issue where a MultiAgentConn isn't closed properly when the coordinator it is connected to closes. Since servertailnet checks whether the conn is closed before reinitializing, the conn must actually report closed in this case; otherwise servertailnet can get stuck when the coordinator closes (e.g. when we switch from the AGPL coordinator to PGCoordinator after decoding a license).
1 parent 2fd1a72 commit 520b12e
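
For context, the description refers to servertailnet only re-creating its coordinator connection once the current one reports closed. A rough, illustrative sketch of that pattern (hypothetical names, not the actual servertailnet code) shows why a conn that never reports closed leaves the loop stuck:

```go
// Hypothetical sketch of a "reinitialize only when closed" loop. If the conn
// never reports closed after its coordinator shuts down, dial() is never
// called again and the loop spins on a dead connection.
package servertailnetsketch

import (
	"context"
	"time"
)

// agentConn is the minimal surface assumed here; the real MultiAgentConn is
// expected to report closed once its coordinator goes away (this commit's fix).
type agentConn interface {
	IsClosed() bool
}

func reinitLoop(ctx context.Context, current agentConn, dial func() agentConn) {
	for ctx.Err() == nil {
		if current == nil || current.IsClosed() {
			current = dial() // connect to whichever coordinator is now active
			continue
		}
		time.Sleep(time.Second) // conn still looks healthy; check again later
	}
}
```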

File tree

5 files changed: +235 -166 lines changed


enterprise/tailnet/connio.go

+13 -12

```diff
@@ -2,7 +2,6 @@ package tailnet
 
 import (
 	"context"
-	"io"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -104,19 +103,21 @@ func (c *connIO) recvLoop() {
 	}()
 	defer c.Close()
 	for {
-		req, err := agpl.RecvCtx(c.peerCtx, c.requests)
-		if err != nil {
-			if xerrors.Is(err, context.Canceled) ||
-				xerrors.Is(err, context.DeadlineExceeded) ||
-				xerrors.Is(err, io.EOF) {
-				c.logger.Debug(c.coordCtx, "exiting io recvLoop", slog.Error(err))
-			} else {
-				c.logger.Error(c.coordCtx, "failed to receive request", slog.Error(err))
-			}
+		select {
+		case <-c.coordCtx.Done():
+			c.logger.Debug(c.coordCtx, "exiting io recvLoop; coordinator exit")
 			return
-		}
-		if err := c.handleRequest(req); err != nil {
+		case <-c.peerCtx.Done():
+			c.logger.Debug(c.peerCtx, "exiting io recvLoop; peer context canceled")
 			return
+		case req, ok := <-c.requests:
+			if !ok {
+				c.logger.Debug(c.peerCtx, "exiting io recvLoop; requests chan closed")
+				return
+			}
+			if err := c.handleRequest(req); err != nil {
+				return
+			}
 		}
 	}
 }
```
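
The behavioral change in recvLoop: instead of a receive that only watches the peer's context (agpl.RecvCtx(c.peerCtx, c.requests)), the loop now selects over the coordinator context, the peer context, and the requests channel, so a coordinator shutdown also unblocks the loop and lets the deferred c.Close() run. A minimal, runnable sketch of the hang this avoids (hypothetical setup, not the project's code):

```go
// A coordinator that shuts down while the peer stays connected: a receive
// that only watches the peer's context never learns about it.
package main

import (
	"context"
	"fmt"
)

func main() {
	peerCtx := context.Background() // the peer never disconnects
	coordCtx, coordCancel := context.WithCancel(context.Background())
	requests := make(chan int) // nothing will ever be sent

	coordCancel() // the coordinator shuts down

	// The old loop effectively blocked on:
	//   select { case <-peerCtx.Done(): ...; case req := <-requests: ... }
	// which waits forever in this situation. Watching coordCtx as well, as
	// the new recvLoop does, lets the loop return and run its deferred Close.
	select {
	case <-coordCtx.Done():
		fmt.Println("coordinator closed; exit recvLoop and close the conn")
	case <-peerCtx.Done():
		fmt.Println("peer disconnected")
	case req := <-requests:
		fmt.Println("handle request", req)
	}
}
```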

enterprise/tailnet/multiagent_test.go

+73 -153

```diff
@@ -4,17 +4,14 @@ import (
 	"context"
 	"testing"
 
-	"github.com/google/uuid"
 	"github.com/stretchr/testify/require"
-	"golang.org/x/exp/slices"
-	"tailscale.com/types/key"
 
 	"cdr.dev/slog"
 	"cdr.dev/slog/sloggers/slogtest"
 	"github.com/coder/coder/v2/coderd/database/dbtestutil"
 	"github.com/coder/coder/v2/enterprise/tailnet"
 	agpl "github.com/coder/coder/v2/tailnet"
-	"github.com/coder/coder/v2/tailnet/proto"
+	"github.com/coder/coder/v2/tailnet/tailnettest"
 	"github.com/coder/coder/v2/testutil"
 )
 
@@ -42,25 +39,48 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
 	defer agent1.close()
 	agent1.sendNode(&agpl.Node{PreferredDERP: 5})
 
-	ma1 := newTestMultiAgent(t, coord1)
-	defer ma1.close()
+	ma1 := tailnettest.NewTestMultiAgent(t, coord1)
+	defer ma1.Close()
 
-	ma1.subscribeAgent(agent1.id)
-	ma1.assertEventuallyHasDERPs(ctx, 5)
+	ma1.RequireSubscribeAgent(agent1.id)
+	ma1.RequireEventuallyHasDERPs(ctx, 5)
 
 	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
-	ma1.assertEventuallyHasDERPs(ctx, 1)
+	ma1.RequireEventuallyHasDERPs(ctx, 1)
 
-	ma1.sendNodeWithDERP(3)
+	ma1.SendNodeWithDERP(3)
 	assertEventuallyHasDERPs(ctx, t, agent1, 3)
 
-	ma1.close()
+	ma1.Close()
 	require.NoError(t, agent1.close())
 
 	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
 	assertEventuallyLost(ctx, t, store, agent1.id)
 }
 
+func TestPGCoordinator_MultiAgent_CoordClose(t *testing.T) {
+	t.Parallel()
+	if !dbtestutil.WillUsePostgres() {
+		t.Skip("test only with postgres")
+	}
+
+	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
+	store, ps := dbtestutil.NewDB(t)
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+	defer cancel()
+	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
+	require.NoError(t, err)
+	defer coord1.Close()
+
+	ma1 := tailnettest.NewTestMultiAgent(t, coord1)
+	defer ma1.Close()
+
+	err = coord1.Close()
+	require.NoError(t, err)
+
+	ma1.RequireEventuallyClosed(ctx)
+}
+
 // TestPGCoordinator_MultiAgent_UnsubscribeRace tests a single coordinator with
 // a MultiAgent connecting to one agent. It tries to race a call to Unsubscribe
 // with the MultiAgent closing.
@@ -86,20 +106,20 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
 	defer agent1.close()
 	agent1.sendNode(&agpl.Node{PreferredDERP: 5})
 
-	ma1 := newTestMultiAgent(t, coord1)
-	defer ma1.close()
+	ma1 := tailnettest.NewTestMultiAgent(t, coord1)
+	defer ma1.Close()
 
-	ma1.subscribeAgent(agent1.id)
-	ma1.assertEventuallyHasDERPs(ctx, 5)
+	ma1.RequireSubscribeAgent(agent1.id)
+	ma1.RequireEventuallyHasDERPs(ctx, 5)
 
 	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
-	ma1.assertEventuallyHasDERPs(ctx, 1)
+	ma1.RequireEventuallyHasDERPs(ctx, 1)
 
-	ma1.sendNodeWithDERP(3)
+	ma1.SendNodeWithDERP(3)
 	assertEventuallyHasDERPs(ctx, t, agent1, 3)
 
-	ma1.unsubscribeAgent(agent1.id)
-	ma1.close()
+	ma1.RequireUnsubscribeAgent(agent1.id)
+	ma1.Close()
 	require.NoError(t, agent1.close())
 
 	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@@ -131,35 +151,35 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
 	defer agent1.close()
 	agent1.sendNode(&agpl.Node{PreferredDERP: 5})
 
-	ma1 := newTestMultiAgent(t, coord1)
-	defer ma1.close()
+	ma1 := tailnettest.NewTestMultiAgent(t, coord1)
+	defer ma1.Close()
 
-	ma1.subscribeAgent(agent1.id)
-	ma1.assertEventuallyHasDERPs(ctx, 5)
+	ma1.RequireSubscribeAgent(agent1.id)
+	ma1.RequireEventuallyHasDERPs(ctx, 5)
 
 	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
-	ma1.assertEventuallyHasDERPs(ctx, 1)
+	ma1.RequireEventuallyHasDERPs(ctx, 1)
 
-	ma1.sendNodeWithDERP(3)
+	ma1.SendNodeWithDERP(3)
 	assertEventuallyHasDERPs(ctx, t, agent1, 3)
 
-	ma1.unsubscribeAgent(agent1.id)
+	ma1.RequireUnsubscribeAgent(agent1.id)
 	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
 
 	func() {
 		ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
 		defer cancel()
-		ma1.sendNodeWithDERP(9)
+		ma1.SendNodeWithDERP(9)
 		assertNeverHasDERPs(ctx, t, agent1, 9)
 	}()
 	func() {
 		ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
 		defer cancel()
 		agent1.sendNode(&agpl.Node{PreferredDERP: 8})
-		ma1.assertNeverHasDERPs(ctx, 8)
+		ma1.RequireNeverHasDERPs(ctx, 8)
 	}()
 
-	ma1.close()
+	ma1.Close()
 	require.NoError(t, agent1.close())
 
 	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@@ -196,19 +216,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
 	defer agent1.close()
 	agent1.sendNode(&agpl.Node{PreferredDERP: 5})
 
-	ma1 := newTestMultiAgent(t, coord2)
-	defer ma1.close()
+	ma1 := tailnettest.NewTestMultiAgent(t, coord2)
+	defer ma1.Close()
 
-	ma1.subscribeAgent(agent1.id)
-	ma1.assertEventuallyHasDERPs(ctx, 5)
+	ma1.RequireSubscribeAgent(agent1.id)
+	ma1.RequireEventuallyHasDERPs(ctx, 5)
 
 	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
-	ma1.assertEventuallyHasDERPs(ctx, 1)
+	ma1.RequireEventuallyHasDERPs(ctx, 1)
 
-	ma1.sendNodeWithDERP(3)
+	ma1.SendNodeWithDERP(3)
 	assertEventuallyHasDERPs(ctx, t, agent1, 3)
 
-	ma1.close()
+	ma1.Close()
 	require.NoError(t, agent1.close())
 
 	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@@ -246,19 +266,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
 	defer agent1.close()
 	agent1.sendNode(&agpl.Node{PreferredDERP: 5})
 
-	ma1 := newTestMultiAgent(t, coord2)
-	defer ma1.close()
+	ma1 := tailnettest.NewTestMultiAgent(t, coord2)
+	defer ma1.Close()
 
-	ma1.sendNodeWithDERP(3)
+	ma1.SendNodeWithDERP(3)
 
-	ma1.subscribeAgent(agent1.id)
-	ma1.assertEventuallyHasDERPs(ctx, 5)
+	ma1.RequireSubscribeAgent(agent1.id)
+	ma1.RequireEventuallyHasDERPs(ctx, 5)
 	assertEventuallyHasDERPs(ctx, t, agent1, 3)
 
 	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
-	ma1.assertEventuallyHasDERPs(ctx, 1)
+	ma1.RequireEventuallyHasDERPs(ctx, 1)
 
-	ma1.close()
+	ma1.Close()
 	require.NoError(t, agent1.close())
 
 	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@@ -305,129 +325,29 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
 	defer agent1.close()
 	agent2.sendNode(&agpl.Node{PreferredDERP: 6})
 
-	ma1 := newTestMultiAgent(t, coord3)
-	defer ma1.close()
+	ma1 := tailnettest.NewTestMultiAgent(t, coord3)
+	defer ma1.Close()
 
-	ma1.subscribeAgent(agent1.id)
-	ma1.assertEventuallyHasDERPs(ctx, 5)
+	ma1.RequireSubscribeAgent(agent1.id)
+	ma1.RequireEventuallyHasDERPs(ctx, 5)
 
 	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
-	ma1.assertEventuallyHasDERPs(ctx, 1)
+	ma1.RequireEventuallyHasDERPs(ctx, 1)
 
-	ma1.subscribeAgent(agent2.id)
-	ma1.assertEventuallyHasDERPs(ctx, 6)
+	ma1.RequireSubscribeAgent(agent2.id)
+	ma1.RequireEventuallyHasDERPs(ctx, 6)
 
 	agent2.sendNode(&agpl.Node{PreferredDERP: 2})
-	ma1.assertEventuallyHasDERPs(ctx, 2)
+	ma1.RequireEventuallyHasDERPs(ctx, 2)
 
-	ma1.sendNodeWithDERP(3)
+	ma1.SendNodeWithDERP(3)
 	assertEventuallyHasDERPs(ctx, t, agent1, 3)
 	assertEventuallyHasDERPs(ctx, t, agent2, 3)
 
-	ma1.close()
+	ma1.Close()
 	require.NoError(t, agent1.close())
 	require.NoError(t, agent2.close())
 
 	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
 	assertEventuallyLost(ctx, t, store, agent1.id)
 }
-
-type testMultiAgent struct {
-	t        testing.TB
-	id       uuid.UUID
-	a        agpl.MultiAgentConn
-	nodeKey  []byte
-	discoKey string
-}
-
-func newTestMultiAgent(t testing.TB, coord agpl.Coordinator) *testMultiAgent {
-	nk, err := key.NewNode().Public().MarshalBinary()
-	require.NoError(t, err)
-	dk, err := key.NewDisco().Public().MarshalText()
-	require.NoError(t, err)
-	m := &testMultiAgent{t: t, id: uuid.New(), nodeKey: nk, discoKey: string(dk)}
-	m.a = coord.ServeMultiAgent(m.id)
-	return m
-}
-
-func (m *testMultiAgent) sendNodeWithDERP(derp int32) {
-	m.t.Helper()
-	err := m.a.UpdateSelf(&proto.Node{
-		Key:           m.nodeKey,
-		Disco:         m.discoKey,
-		PreferredDerp: derp,
-	})
-	require.NoError(m.t, err)
-}
-
-func (m *testMultiAgent) close() {
-	m.t.Helper()
-	err := m.a.Close()
-	require.NoError(m.t, err)
-}
-
-func (m *testMultiAgent) subscribeAgent(id uuid.UUID) {
-	m.t.Helper()
-	err := m.a.SubscribeAgent(id)
-	require.NoError(m.t, err)
-}
-
-func (m *testMultiAgent) unsubscribeAgent(id uuid.UUID) {
-	m.t.Helper()
-	err := m.a.UnsubscribeAgent(id)
-	require.NoError(m.t, err)
-}
-
-func (m *testMultiAgent) assertEventuallyHasDERPs(ctx context.Context, expected ...int) {
-	m.t.Helper()
-	for {
-		resp, ok := m.a.NextUpdate(ctx)
-		require.True(m.t, ok)
-		nodes, err := agpl.OnlyNodeUpdates(resp)
-		require.NoError(m.t, err)
-		if len(nodes) != len(expected) {
-			m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
-			continue
-		}
-
-		derps := make([]int, 0, len(nodes))
-		for _, n := range nodes {
-			derps = append(derps, n.PreferredDERP)
-		}
-		for _, e := range expected {
-			if !slices.Contains(derps, e) {
-				m.t.Logf("expected DERP %d to be in %v", e, derps)
-				continue
-			}
-			return
-		}
-	}
-}
-
-func (m *testMultiAgent) assertNeverHasDERPs(ctx context.Context, expected ...int) {
-	m.t.Helper()
-	for {
-		resp, ok := m.a.NextUpdate(ctx)
-		if !ok {
-			return
-		}
-		nodes, err := agpl.OnlyNodeUpdates(resp)
-		require.NoError(m.t, err)
-		if len(nodes) != len(expected) {
-			m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
-			continue
-		}
-
-		derps := make([]int, 0, len(nodes))
-		for _, n := range nodes {
-			derps = append(derps, n.PreferredDERP)
-		}
-		for _, e := range expected {
-			if !slices.Contains(derps, e) {
-				m.t.Logf("expected DERP %d to be in %v", e, derps)
-				continue
-			}
-			return
-		}
-	}
-}
```
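
The unexported testMultiAgent helper removed above appears to have been promoted to an exported helper in tailnet/tailnettest (the NewTestMultiAgent, Require*, and Send* calls the tests now use). The new RequireEventuallyClosed assertion used by TestPGCoordinator_MultiAgent_CoordClose is defined there and is not shown in this diff; a hedged sketch of what such a poll-until-closed helper could look like, assuming the underlying MultiAgentConn exposes IsClosed() (names, signature, and poll interval are illustrative, not the actual tailnettest code):

```go
// Illustrative sketch only: a possible shape for a poll-until-closed test
// helper. The real implementation lives in tailnet/tailnettest and may differ;
// IsClosed() is assumed based on the commit description ("servertailnet checks
// whether the conn is closed").
package tailnettestsketch

import (
	"context"
	"testing"
	"time"
)

// closable is the minimal surface this sketch needs from a MultiAgentConn.
type closable interface {
	IsClosed() bool
}

func requireEventuallyClosed(ctx context.Context, t testing.TB, conn closable) {
	t.Helper()
	ticker := time.NewTicker(25 * time.Millisecond) // arbitrary poll interval
	defer ticker.Stop()
	for {
		if conn.IsClosed() {
			return
		}
		select {
		case <-ctx.Done():
			t.Fatal("timed out waiting for multi-agent conn to close")
		case <-ticker.C:
			// not closed yet; poll again
		}
	}
}
```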

tailnet/coordinator.go

+7 -1

```diff
@@ -1017,7 +1017,13 @@ func v1ReqLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logge
 }
 
 func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q Queue, resps <-chan *proto.CoordinateResponse) {
-	defer cancel()
+	defer func() {
+		cErr := q.Close()
+		if cErr != nil {
+			logger.Info(ctx, "error closing response Queue", slog.Error(cErr))
+		}
+		cancel()
+	}()
 	for {
 		resp, err := RecvCtx(ctx, resps)
 		if err != nil {
```
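
On the AGPL side, v1RespLoop previously only canceled its context on exit; it now also closes the Queue backing the connection, so when the coordinator ends the response stream the consumer actually observes a close rather than hanging. A simplified, self-contained sketch of that shutdown ordering (types and names are illustrative, not the real coordinator code):

```go
// Sketch of the shutdown ordering the new defer establishes: close the queue
// first so its consumer (presumably the MultiAgentConn) observes the close,
// then cancel the loop's context.
package coordsketch

import (
	"context"
	"log"
)

type queue interface {
	Close() error
}

func respLoop(ctx context.Context, cancel context.CancelFunc, q queue, resps <-chan string) {
	defer func() {
		if err := q.Close(); err != nil {
			log.Printf("error closing response queue: %v", err)
		}
		cancel()
	}()
	for {
		select {
		case <-ctx.Done():
			return
		case _, ok := <-resps:
			if !ok {
				// The coordinator closed the response stream; returning here
				// triggers the deferred q.Close(), which is what propagates
				// the shutdown to whoever consumes the queue.
				return
			}
			// In the real loop the response is enqueued onto q here.
		}
	}
}
```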
