
fix: close MultiAgentConn when coordinator closes #11941


Merged
merged 1 commit on Jan 30, 2024
25 changes: 13 additions & 12 deletions enterprise/tailnet/connio.go
@@ -2,7 +2,6 @@ package tailnet

 import (
 	"context"
-	"io"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -104,19 +103,21 @@ func (c *connIO) recvLoop() {
 	}()
 	defer c.Close()
 	for {
-		req, err := agpl.RecvCtx(c.peerCtx, c.requests)
-		if err != nil {
-			if xerrors.Is(err, context.Canceled) ||
-				xerrors.Is(err, context.DeadlineExceeded) ||
-				xerrors.Is(err, io.EOF) {
-				c.logger.Debug(c.coordCtx, "exiting io recvLoop", slog.Error(err))
-			} else {
-				c.logger.Error(c.coordCtx, "failed to receive request", slog.Error(err))
-			}
+		select {
+		case <-c.coordCtx.Done():
+			c.logger.Debug(c.coordCtx, "exiting io recvLoop; coordinator exit")
 			return
-		}
-		if err := c.handleRequest(req); err != nil {
+		case <-c.peerCtx.Done():
+			c.logger.Debug(c.peerCtx, "exiting io recvLoop; peer context canceled")
 			return
+		case req, ok := <-c.requests:
+			if !ok {
+				c.logger.Debug(c.peerCtx, "exiting io recvLoop; requests chan closed")
+				return
+			}
+			if err := c.handleRequest(req); err != nil {
+				return
+			}
 		}
 	}
 }
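For readers skimming the connio.go hunk above: the new recvLoop replaces the blocking RecvCtx call with a select over both contexts and the requests channel, so a coordinator shutdown (coordCtx), a peer disconnect (peerCtx), or a closed requests channel all end the loop and trigger the deferred Close. Below is a minimal, self-contained sketch of that pattern; the connIO type itself is not reproduced, so the function signature, string requests, and print statements are illustrative only.

package main

import (
	"context"
	"fmt"
	"time"
)

// recvLoop mirrors the shape of the updated connIO.recvLoop: it returns when
// the coordinator context ends, the peer context ends, or the requests
// channel is closed, so a coordinator shutdown now tears the loop down too.
func recvLoop(coordCtx, peerCtx context.Context, requests <-chan string) {
	for {
		select {
		case <-coordCtx.Done():
			fmt.Println("exiting recvLoop; coordinator exit")
			return
		case <-peerCtx.Done():
			fmt.Println("exiting recvLoop; peer context canceled")
			return
		case req, ok := <-requests:
			if !ok {
				fmt.Println("exiting recvLoop; requests chan closed")
				return
			}
			fmt.Println("handling request:", req)
		}
	}
}

func main() {
	coordCtx, coordCancel := context.WithCancel(context.Background())
	peerCtx, peerCancel := context.WithCancel(context.Background())
	defer peerCancel()

	requests := make(chan string, 1)
	requests <- "hello"

	done := make(chan struct{})
	go func() {
		defer close(done)
		recvLoop(coordCtx, peerCtx, requests)
	}()

	// Simulate the coordinator closing: the loop exits even though the peer
	// context is still live and the requests channel is still open.
	time.Sleep(10 * time.Millisecond)
	coordCancel()
	<-done
}

The design point is that the loop no longer depends on the peer context alone: before this change, a coordinator that shut down could leave the loop (and the MultiAgentConn behind it) running until the peer went away.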
226 changes: 73 additions & 153 deletions enterprise/tailnet/multiagent_test.go
@@ -4,17 +4,14 @@ import (
"context"
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"tailscale.com/types/key"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/enterprise/tailnet"
agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)

@@ -42,25 +39,48 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5})

ma1 := newTestMultiAgent(t, coord1)
defer ma1.close()
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
defer ma1.Close()

ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
ma1.RequireSubscribeAgent(agent1.id)
ma1.RequireEventuallyHasDERPs(ctx, 5)

agent1.sendNode(&agpl.Node{PreferredDERP: 1})
ma1.assertEventuallyHasDERPs(ctx, 1)
ma1.RequireEventuallyHasDERPs(ctx, 1)

ma1.sendNodeWithDERP(3)
ma1.SendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3)

ma1.close()
ma1.Close()
require.NoError(t, agent1.close())

assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
assertEventuallyLost(ctx, t, store, agent1.id)
}

func TestPGCoordinator_MultiAgent_CoordClose(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
require.NoError(t, err)
defer coord1.Close()

ma1 := tailnettest.NewTestMultiAgent(t, coord1)
defer ma1.Close()

err = coord1.Close()
require.NoError(t, err)

ma1.RequireEventuallyClosed(ctx)
}

// TestPGCoordinator_MultiAgent_UnsubscribeRace tests a single coordinator with
// a MultiAgent connecting to one agent. It tries to race a call to Unsubscribe
// with the MultiAgent closing.
@@ -86,20 +106,20 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5})

ma1 := newTestMultiAgent(t, coord1)
defer ma1.close()
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
defer ma1.Close()

ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
ma1.RequireSubscribeAgent(agent1.id)
ma1.RequireEventuallyHasDERPs(ctx, 5)

agent1.sendNode(&agpl.Node{PreferredDERP: 1})
ma1.assertEventuallyHasDERPs(ctx, 1)
ma1.RequireEventuallyHasDERPs(ctx, 1)

ma1.sendNodeWithDERP(3)
ma1.SendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3)

ma1.unsubscribeAgent(agent1.id)
ma1.close()
ma1.RequireUnsubscribeAgent(agent1.id)
ma1.Close()
require.NoError(t, agent1.close())

assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@@ -131,35 +151,35 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5})

ma1 := newTestMultiAgent(t, coord1)
defer ma1.close()
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
defer ma1.Close()

ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
ma1.RequireSubscribeAgent(agent1.id)
ma1.RequireEventuallyHasDERPs(ctx, 5)

agent1.sendNode(&agpl.Node{PreferredDERP: 1})
ma1.assertEventuallyHasDERPs(ctx, 1)
ma1.RequireEventuallyHasDERPs(ctx, 1)

ma1.sendNodeWithDERP(3)
ma1.SendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3)

ma1.unsubscribeAgent(agent1.id)
ma1.RequireUnsubscribeAgent(agent1.id)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)

func() {
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
defer cancel()
ma1.sendNodeWithDERP(9)
ma1.SendNodeWithDERP(9)
assertNeverHasDERPs(ctx, t, agent1, 9)
}()
func() {
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
defer cancel()
agent1.sendNode(&agpl.Node{PreferredDERP: 8})
ma1.assertNeverHasDERPs(ctx, 8)
ma1.RequireNeverHasDERPs(ctx, 8)
}()

ma1.close()
ma1.Close()
require.NoError(t, agent1.close())

assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@@ -196,19 +216,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5})

ma1 := newTestMultiAgent(t, coord2)
defer ma1.close()
ma1 := tailnettest.NewTestMultiAgent(t, coord2)
defer ma1.Close()

ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
ma1.RequireSubscribeAgent(agent1.id)
ma1.RequireEventuallyHasDERPs(ctx, 5)

agent1.sendNode(&agpl.Node{PreferredDERP: 1})
ma1.assertEventuallyHasDERPs(ctx, 1)
ma1.RequireEventuallyHasDERPs(ctx, 1)

ma1.sendNodeWithDERP(3)
ma1.SendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3)

ma1.close()
ma1.Close()
require.NoError(t, agent1.close())

assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@@ -246,19 +266,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5})

ma1 := newTestMultiAgent(t, coord2)
defer ma1.close()
ma1 := tailnettest.NewTestMultiAgent(t, coord2)
defer ma1.Close()

ma1.sendNodeWithDERP(3)
ma1.SendNodeWithDERP(3)

ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
ma1.RequireSubscribeAgent(agent1.id)
ma1.RequireEventuallyHasDERPs(ctx, 5)
assertEventuallyHasDERPs(ctx, t, agent1, 3)

agent1.sendNode(&agpl.Node{PreferredDERP: 1})
ma1.assertEventuallyHasDERPs(ctx, 1)
ma1.RequireEventuallyHasDERPs(ctx, 1)

ma1.close()
ma1.Close()
require.NoError(t, agent1.close())

assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@@ -305,129 +325,29 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
defer agent1.close()
agent2.sendNode(&agpl.Node{PreferredDERP: 6})

ma1 := newTestMultiAgent(t, coord3)
defer ma1.close()
ma1 := tailnettest.NewTestMultiAgent(t, coord3)
defer ma1.Close()

ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
ma1.RequireSubscribeAgent(agent1.id)
ma1.RequireEventuallyHasDERPs(ctx, 5)

agent1.sendNode(&agpl.Node{PreferredDERP: 1})
ma1.assertEventuallyHasDERPs(ctx, 1)
ma1.RequireEventuallyHasDERPs(ctx, 1)

ma1.subscribeAgent(agent2.id)
ma1.assertEventuallyHasDERPs(ctx, 6)
ma1.RequireSubscribeAgent(agent2.id)
ma1.RequireEventuallyHasDERPs(ctx, 6)

agent2.sendNode(&agpl.Node{PreferredDERP: 2})
ma1.assertEventuallyHasDERPs(ctx, 2)
ma1.RequireEventuallyHasDERPs(ctx, 2)

ma1.sendNodeWithDERP(3)
ma1.SendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3)
assertEventuallyHasDERPs(ctx, t, agent2, 3)

ma1.close()
ma1.Close()
require.NoError(t, agent1.close())
require.NoError(t, agent2.close())

assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
assertEventuallyLost(ctx, t, store, agent1.id)
}

type testMultiAgent struct {
t testing.TB
id uuid.UUID
a agpl.MultiAgentConn
nodeKey []byte
discoKey string
}

func newTestMultiAgent(t testing.TB, coord agpl.Coordinator) *testMultiAgent {
nk, err := key.NewNode().Public().MarshalBinary()
require.NoError(t, err)
dk, err := key.NewDisco().Public().MarshalText()
require.NoError(t, err)
m := &testMultiAgent{t: t, id: uuid.New(), nodeKey: nk, discoKey: string(dk)}
m.a = coord.ServeMultiAgent(m.id)
return m
}

func (m *testMultiAgent) sendNodeWithDERP(derp int32) {
m.t.Helper()
err := m.a.UpdateSelf(&proto.Node{
Key: m.nodeKey,
Disco: m.discoKey,
PreferredDerp: derp,
})
require.NoError(m.t, err)
}

func (m *testMultiAgent) close() {
m.t.Helper()
err := m.a.Close()
require.NoError(m.t, err)
}

func (m *testMultiAgent) subscribeAgent(id uuid.UUID) {
m.t.Helper()
err := m.a.SubscribeAgent(id)
require.NoError(m.t, err)
}

func (m *testMultiAgent) unsubscribeAgent(id uuid.UUID) {
m.t.Helper()
err := m.a.UnsubscribeAgent(id)
require.NoError(m.t, err)
}

func (m *testMultiAgent) assertEventuallyHasDERPs(ctx context.Context, expected ...int) {
m.t.Helper()
for {
resp, ok := m.a.NextUpdate(ctx)
require.True(m.t, ok)
nodes, err := agpl.OnlyNodeUpdates(resp)
require.NoError(m.t, err)
if len(nodes) != len(expected) {
m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
continue
}

derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if !slices.Contains(derps, e) {
m.t.Logf("expected DERP %d to be in %v", e, derps)
continue
}
return
}
}
}

func (m *testMultiAgent) assertNeverHasDERPs(ctx context.Context, expected ...int) {
m.t.Helper()
for {
resp, ok := m.a.NextUpdate(ctx)
if !ok {
return
}
nodes, err := agpl.OnlyNodeUpdates(resp)
require.NoError(m.t, err)
if len(nodes) != len(expected) {
m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
continue
}

derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if !slices.Contains(derps, e) {
m.t.Logf("expected DERP %d to be in %v", e, derps)
continue
}
return
}
}
}
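The local testMultiAgent helper deleted above has been promoted to the tailnettest package as an exported TestMultiAgent, which is why the rewritten tests call RequireSubscribeAgent, RequireEventuallyHasDERPs, and friends. The new TestPGCoordinator_MultiAgent_CoordClose test also calls RequireEventuallyClosed, whose implementation is not included in the diff as captured here. The sketch below is a guess at its shape under two assumptions: that agpl.MultiAgentConn exposes an IsClosed method, and that testutil.IntervalFast is a suitable polling interval; the real helper in tailnet/tailnettest may differ.

// Sketch only. Imports it would need:
//   "context", "testing", "time",
//   agpl "github.com/coder/coder/v2/tailnet",
//   "github.com/coder/coder/v2/testutil"

// requireEventuallyClosed is a hypothetical stand-in for
// tailnettest.(*TestMultiAgent).RequireEventuallyClosed: it polls until the
// MultiAgentConn reports itself closed or the context expires.
func requireEventuallyClosed(ctx context.Context, t testing.TB, conn agpl.MultiAgentConn) {
	t.Helper()
	tkr := time.NewTicker(testutil.IntervalFast)
	defer tkr.Stop()
	for {
		if conn.IsClosed() { // IsClosed is assumed to be part of MultiAgentConn
			return
		}
		select {
		case <-ctx.Done():
			t.Fatal("timeout waiting for MultiAgentConn to close")
		case <-tkr.C:
			// poll again
		}
	}
}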
8 changes: 7 additions & 1 deletion tailnet/coordinator.go
@@ -1017,7 +1017,13 @@ func v1ReqLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logge
 }

 func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q Queue, resps <-chan *proto.CoordinateResponse) {
-	defer cancel()
+	defer func() {
+		cErr := q.Close()
+		if cErr != nil {
+			logger.Info(ctx, "error closing response Queue", slog.Error(cErr))
+		}
+		cancel()
+	}()
 	for {
 		resp, err := RecvCtx(ctx, resps)
 		if err != nil {
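The coordinator.go hunk is the heart of the fix: when v1RespLoop returns (for example because the coordinator is shutting down and its context is canceled), it now closes the Queue it was feeding before canceling the context, so a MultiAgentConn backed by that queue observes the closure instead of blocking forever on its next update. Below is a self-contained sketch of that close-on-exit pattern; the queue type, Enqueue, and Next are hypothetical stand-ins and far simpler than the real Queue interface.

package main

import (
	"context"
	"fmt"
)

// queue is a stand-in for the coordinator's response Queue; names are
// hypothetical and only illustrate the close-propagation pattern.
type queue struct {
	updates chan string
}

func (q *queue) Enqueue(s string) { q.updates <- s }

// Close makes consumers of Next observe shutdown.
func (q *queue) Close() error {
	close(q.updates)
	return nil
}

// Next mirrors the shape of MultiAgentConn.NextUpdate: ok is false once the
// queue has been closed.
func (q *queue) Next() (string, bool) {
	s, ok := <-q.updates
	return s, ok
}

// respLoop mirrors v1RespLoop's structure: whatever makes the loop return,
// the deferred q.Close() runs, so the consumer is never left hanging.
func respLoop(ctx context.Context, q *queue, resps <-chan string) {
	defer func() {
		if err := q.Close(); err != nil {
			fmt.Println("error closing queue:", err)
		}
	}()
	for {
		select {
		case <-ctx.Done():
			return
		case resp, ok := <-resps:
			if !ok {
				return
			}
			q.Enqueue(resp)
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	q := &queue{updates: make(chan string, 8)}
	resps := make(chan string)

	done := make(chan struct{})
	go func() {
		defer close(done)
		respLoop(ctx, q, resps)
	}()

	// Simulate the coordinator shutting down.
	cancel()
	<-done

	// The consumer sees ok == false instead of blocking.
	_, ok := q.Next()
	fmt.Println("queue open after coordinator close:", ok)
}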