Skip to content

Commit 5690699

Browse files
committed
fix: close MultiAgentConn when coordinator closes
1 parent 619bdd1 commit 5690699

File tree

4 files changed

+222
-154
lines changed

4 files changed

+222
-154
lines changed

enterprise/tailnet/multiagent_test.go

+73-153
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,14 @@ import (
44
"context"
55
"testing"
66

7-
"github.com/google/uuid"
87
"github.com/stretchr/testify/require"
9-
"golang.org/x/exp/slices"
10-
"tailscale.com/types/key"
118

129
"cdr.dev/slog"
1310
"cdr.dev/slog/sloggers/slogtest"
1411
"github.com/coder/coder/v2/coderd/database/dbtestutil"
1512
"github.com/coder/coder/v2/enterprise/tailnet"
1613
agpl "github.com/coder/coder/v2/tailnet"
17-
"github.com/coder/coder/v2/tailnet/proto"
14+
"github.com/coder/coder/v2/tailnet/tailnettest"
1815
"github.com/coder/coder/v2/testutil"
1916
)
2017

@@ -42,25 +39,48 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
4239
defer agent1.close()
4340
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
4441

45-
ma1 := newTestMultiAgent(t, coord1)
46-
defer ma1.close()
42+
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
43+
defer ma1.Close()
4744

48-
ma1.subscribeAgent(agent1.id)
49-
ma1.assertEventuallyHasDERPs(ctx, 5)
45+
ma1.RequireSubscribeAgent(agent1.id)
46+
ma1.RequireEventuallyHasDERPs(ctx, 5)
5047

5148
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
52-
ma1.assertEventuallyHasDERPs(ctx, 1)
49+
ma1.RequireEventuallyHasDERPs(ctx, 1)
5350

54-
ma1.sendNodeWithDERP(3)
51+
ma1.SendNodeWithDERP(3)
5552
assertEventuallyHasDERPs(ctx, t, agent1, 3)
5653

57-
ma1.close()
54+
ma1.Close()
5855
require.NoError(t, agent1.close())
5956

6057
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
6158
assertEventuallyLost(ctx, t, store, agent1.id)
6259
}
6360

61+
func TestPGCoordinator_MultiAgent_CoordClose(t *testing.T) {
62+
t.Parallel()
63+
if !dbtestutil.WillUsePostgres() {
64+
t.Skip("test only with postgres")
65+
}
66+
67+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
68+
store, ps := dbtestutil.NewDB(t)
69+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
70+
defer cancel()
71+
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
72+
require.NoError(t, err)
73+
defer coord1.Close()
74+
75+
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
76+
defer ma1.Close()
77+
78+
err = coord1.Close()
79+
require.NoError(t, err)
80+
81+
ma1.RequireEventuallyClosed(ctx)
82+
}
83+
6484
// TestPGCoordinator_MultiAgent_UnsubscribeRace tests a single coordinator with
6585
// a MultiAgent connecting to one agent. It tries to race a call to Unsubscribe
6686
// with the MultiAgent closing.
@@ -86,20 +106,20 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
86106
defer agent1.close()
87107
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
88108

89-
ma1 := newTestMultiAgent(t, coord1)
90-
defer ma1.close()
109+
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
110+
defer ma1.Close()
91111

92-
ma1.subscribeAgent(agent1.id)
93-
ma1.assertEventuallyHasDERPs(ctx, 5)
112+
ma1.RequireSubscribeAgent(agent1.id)
113+
ma1.RequireEventuallyHasDERPs(ctx, 5)
94114

95115
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
96-
ma1.assertEventuallyHasDERPs(ctx, 1)
116+
ma1.RequireEventuallyHasDERPs(ctx, 1)
97117

98-
ma1.sendNodeWithDERP(3)
118+
ma1.SendNodeWithDERP(3)
99119
assertEventuallyHasDERPs(ctx, t, agent1, 3)
100120

101-
ma1.unsubscribeAgent(agent1.id)
102-
ma1.close()
121+
ma1.RequireUnsubscribeAgent(agent1.id)
122+
ma1.Close()
103123
require.NoError(t, agent1.close())
104124

105125
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@@ -131,35 +151,35 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
131151
defer agent1.close()
132152
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
133153

134-
ma1 := newTestMultiAgent(t, coord1)
135-
defer ma1.close()
154+
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
155+
defer ma1.Close()
136156

137-
ma1.subscribeAgent(agent1.id)
138-
ma1.assertEventuallyHasDERPs(ctx, 5)
157+
ma1.RequireSubscribeAgent(agent1.id)
158+
ma1.RequireEventuallyHasDERPs(ctx, 5)
139159

140160
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
141-
ma1.assertEventuallyHasDERPs(ctx, 1)
161+
ma1.RequireEventuallyHasDERPs(ctx, 1)
142162

143-
ma1.sendNodeWithDERP(3)
163+
ma1.SendNodeWithDERP(3)
144164
assertEventuallyHasDERPs(ctx, t, agent1, 3)
145165

146-
ma1.unsubscribeAgent(agent1.id)
166+
ma1.RequireUnsubscribeAgent(agent1.id)
147167
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
148168

149169
func() {
150170
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
151171
defer cancel()
152-
ma1.sendNodeWithDERP(9)
172+
ma1.SendNodeWithDERP(9)
153173
assertNeverHasDERPs(ctx, t, agent1, 9)
154174
}()
155175
func() {
156176
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
157177
defer cancel()
158178
agent1.sendNode(&agpl.Node{PreferredDERP: 8})
159-
ma1.assertNeverHasDERPs(ctx, 8)
179+
ma1.RequireNeverHasDERPs(ctx, 8)
160180
}()
161181

162-
ma1.close()
182+
ma1.Close()
163183
require.NoError(t, agent1.close())
164184

165185
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@@ -196,19 +216,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
196216
defer agent1.close()
197217
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
198218

199-
ma1 := newTestMultiAgent(t, coord2)
200-
defer ma1.close()
219+
ma1 := tailnettest.NewTestMultiAgent(t, coord2)
220+
defer ma1.Close()
201221

202-
ma1.subscribeAgent(agent1.id)
203-
ma1.assertEventuallyHasDERPs(ctx, 5)
222+
ma1.RequireSubscribeAgent(agent1.id)
223+
ma1.RequireEventuallyHasDERPs(ctx, 5)
204224

205225
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
206-
ma1.assertEventuallyHasDERPs(ctx, 1)
226+
ma1.RequireEventuallyHasDERPs(ctx, 1)
207227

208-
ma1.sendNodeWithDERP(3)
228+
ma1.SendNodeWithDERP(3)
209229
assertEventuallyHasDERPs(ctx, t, agent1, 3)
210230

211-
ma1.close()
231+
ma1.Close()
212232
require.NoError(t, agent1.close())
213233

214234
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@@ -246,19 +266,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
246266
defer agent1.close()
247267
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
248268

249-
ma1 := newTestMultiAgent(t, coord2)
250-
defer ma1.close()
269+
ma1 := tailnettest.NewTestMultiAgent(t, coord2)
270+
defer ma1.Close()
251271

252-
ma1.sendNodeWithDERP(3)
272+
ma1.SendNodeWithDERP(3)
253273

254-
ma1.subscribeAgent(agent1.id)
255-
ma1.assertEventuallyHasDERPs(ctx, 5)
274+
ma1.RequireSubscribeAgent(agent1.id)
275+
ma1.RequireEventuallyHasDERPs(ctx, 5)
256276
assertEventuallyHasDERPs(ctx, t, agent1, 3)
257277

258278
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
259-
ma1.assertEventuallyHasDERPs(ctx, 1)
279+
ma1.RequireEventuallyHasDERPs(ctx, 1)
260280

261-
ma1.close()
281+
ma1.Close()
262282
require.NoError(t, agent1.close())
263283

264284
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@@ -305,129 +325,29 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
305325
defer agent1.close()
306326
agent2.sendNode(&agpl.Node{PreferredDERP: 6})
307327

308-
ma1 := newTestMultiAgent(t, coord3)
309-
defer ma1.close()
328+
ma1 := tailnettest.NewTestMultiAgent(t, coord3)
329+
defer ma1.Close()
310330

311-
ma1.subscribeAgent(agent1.id)
312-
ma1.assertEventuallyHasDERPs(ctx, 5)
331+
ma1.RequireSubscribeAgent(agent1.id)
332+
ma1.RequireEventuallyHasDERPs(ctx, 5)
313333

314334
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
315-
ma1.assertEventuallyHasDERPs(ctx, 1)
335+
ma1.RequireEventuallyHasDERPs(ctx, 1)
316336

317-
ma1.subscribeAgent(agent2.id)
318-
ma1.assertEventuallyHasDERPs(ctx, 6)
337+
ma1.RequireSubscribeAgent(agent2.id)
338+
ma1.RequireEventuallyHasDERPs(ctx, 6)
319339

320340
agent2.sendNode(&agpl.Node{PreferredDERP: 2})
321-
ma1.assertEventuallyHasDERPs(ctx, 2)
341+
ma1.RequireEventuallyHasDERPs(ctx, 2)
322342

323-
ma1.sendNodeWithDERP(3)
343+
ma1.SendNodeWithDERP(3)
324344
assertEventuallyHasDERPs(ctx, t, agent1, 3)
325345
assertEventuallyHasDERPs(ctx, t, agent2, 3)
326346

327-
ma1.close()
347+
ma1.Close()
328348
require.NoError(t, agent1.close())
329349
require.NoError(t, agent2.close())
330350

331351
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
332352
assertEventuallyLost(ctx, t, store, agent1.id)
333353
}
334-
335-
type testMultiAgent struct {
336-
t testing.TB
337-
id uuid.UUID
338-
a agpl.MultiAgentConn
339-
nodeKey []byte
340-
discoKey string
341-
}
342-
343-
func newTestMultiAgent(t testing.TB, coord agpl.Coordinator) *testMultiAgent {
344-
nk, err := key.NewNode().Public().MarshalBinary()
345-
require.NoError(t, err)
346-
dk, err := key.NewDisco().Public().MarshalText()
347-
require.NoError(t, err)
348-
m := &testMultiAgent{t: t, id: uuid.New(), nodeKey: nk, discoKey: string(dk)}
349-
m.a = coord.ServeMultiAgent(m.id)
350-
return m
351-
}
352-
353-
func (m *testMultiAgent) sendNodeWithDERP(derp int32) {
354-
m.t.Helper()
355-
err := m.a.UpdateSelf(&proto.Node{
356-
Key: m.nodeKey,
357-
Disco: m.discoKey,
358-
PreferredDerp: derp,
359-
})
360-
require.NoError(m.t, err)
361-
}
362-
363-
func (m *testMultiAgent) close() {
364-
m.t.Helper()
365-
err := m.a.Close()
366-
require.NoError(m.t, err)
367-
}
368-
369-
func (m *testMultiAgent) subscribeAgent(id uuid.UUID) {
370-
m.t.Helper()
371-
err := m.a.SubscribeAgent(id)
372-
require.NoError(m.t, err)
373-
}
374-
375-
func (m *testMultiAgent) unsubscribeAgent(id uuid.UUID) {
376-
m.t.Helper()
377-
err := m.a.UnsubscribeAgent(id)
378-
require.NoError(m.t, err)
379-
}
380-
381-
func (m *testMultiAgent) assertEventuallyHasDERPs(ctx context.Context, expected ...int) {
382-
m.t.Helper()
383-
for {
384-
resp, ok := m.a.NextUpdate(ctx)
385-
require.True(m.t, ok)
386-
nodes, err := agpl.OnlyNodeUpdates(resp)
387-
require.NoError(m.t, err)
388-
if len(nodes) != len(expected) {
389-
m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
390-
continue
391-
}
392-
393-
derps := make([]int, 0, len(nodes))
394-
for _, n := range nodes {
395-
derps = append(derps, n.PreferredDERP)
396-
}
397-
for _, e := range expected {
398-
if !slices.Contains(derps, e) {
399-
m.t.Logf("expected DERP %d to be in %v", e, derps)
400-
continue
401-
}
402-
return
403-
}
404-
}
405-
}
406-
407-
func (m *testMultiAgent) assertNeverHasDERPs(ctx context.Context, expected ...int) {
408-
m.t.Helper()
409-
for {
410-
resp, ok := m.a.NextUpdate(ctx)
411-
if !ok {
412-
return
413-
}
414-
nodes, err := agpl.OnlyNodeUpdates(resp)
415-
require.NoError(m.t, err)
416-
if len(nodes) != len(expected) {
417-
m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
418-
continue
419-
}
420-
421-
derps := make([]int, 0, len(nodes))
422-
for _, n := range nodes {
423-
derps = append(derps, n.PreferredDERP)
424-
}
425-
for _, e := range expected {
426-
if !slices.Contains(derps, e) {
427-
m.t.Logf("expected DERP %d to be in %v", e, derps)
428-
continue
429-
}
430-
return
431-
}
432-
}
433-
}

tailnet/coordinator.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -1017,7 +1017,13 @@ func v1ReqLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logge
10171017
}
10181018

10191019
func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q Queue, resps <-chan *proto.CoordinateResponse) {
1020-
defer cancel()
1020+
defer func() {
1021+
cErr := q.Close()
1022+
if cErr != nil {
1023+
logger.Info(ctx, "error closing response Queue", slog.Error(cErr))
1024+
}
1025+
cancel()
1026+
}()
10211027
for {
10221028
resp, err := RecvCtx(ctx, resps)
10231029
if err != nil {

tailnet/coordinator_test.go

+18
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,24 @@ func TestCoordinator_Lost(t *testing.T) {
383383
test.LostTest(ctx, t, coordinator)
384384
}
385385

386+
func TestCoordinator_MultiAgent_CoordClose(t *testing.T) {
387+
t.Parallel()
388+
389+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
390+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
391+
defer cancel()
392+
coord1 := tailnet.NewCoordinator(logger.Named("coord1"))
393+
defer coord1.Close()
394+
395+
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
396+
defer ma1.Close()
397+
398+
err := coord1.Close()
399+
require.NoError(t, err)
400+
401+
ma1.RequireEventuallyClosed(ctx)
402+
}
403+
386404
func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server net.Conn) {
387405
t.Helper()
388406
sc := make(chan net.Conn, 1)

0 commit comments

Comments
 (0)