
feat: set peers lost when disconnected from coordinator #11681


Merged · 1 commit · Jan 22, 2024
5 changes: 5 additions & 0 deletions tailnet/conn.go
@@ -356,6 +356,11 @@ func (c *Conn) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error
return nil
}

// SetAllPeersLost marks all peers lost; typically used when we disconnect from a coordinator.
func (c *Conn) SetAllPeersLost() {
c.configMaps.setAllPeersLost()
}

// NodeAddresses returns the addresses of a node from the NetworkMap.
func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) {
return c.configMaps.nodeAddresses(publicKey)
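A minimal sketch of how a caller might use the new method, assuming the usual imports (`context`, `github.com/coder/coder/v2/tailnet`). The `dial` parameter and the reconnect loop are hypothetical illustrations, not part of this PR; only `Conn.SetAllPeersLost` and `Coordination.Error` are real:

```go
// Illustrative only: a reconnect loop that marks all peers lost whenever
// the coordinator stream fails, before dialing again. dial is a
// hypothetical stand-in for whatever establishes a Coordination.
func reconnectLoop(
	ctx context.Context,
	conn *tailnet.Conn,
	dial func(context.Context) (tailnet.Coordination, error),
) {
	for ctx.Err() == nil {
		coordination, err := dial(ctx)
		if err != nil {
			continue // a real caller would back off before retrying
		}
		select {
		case <-ctx.Done():
		case <-coordination.Error():
			// The stream failed; peer state learned from the old
			// coordinator is now stale, so stop routing to those peers.
		}
		_ = coordination.Close()
		conn.SetAllPeersLost()
	}
}
```

With this PR, the coordination's own response loop already does this via the `Coordinatee` interface (see coordinator.go below); exporting `Conn.SetAllPeersLost` lets other disconnect paths do the same.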
67 changes: 44 additions & 23 deletions tailnet/coordinator.go
@@ -97,6 +97,7 @@ type Node struct {
// Conn.
type Coordinatee interface {
UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error
SetAllPeersLost()
SetNodeCallback(func(*Node))
}

@@ -107,20 +108,28 @@ type Coordination interface {

type remoteCoordination struct {
sync.Mutex
closed bool
errChan chan error
coordinatee Coordinatee
logger slog.Logger
protocol proto.DRPCTailnet_CoordinateClient
closed bool
errChan chan error
coordinatee Coordinatee
logger slog.Logger
protocol proto.DRPCTailnet_CoordinateClient
respLoopDone chan struct{}
}

func (c *remoteCoordination) Close() error {
func (c *remoteCoordination) Close() (retErr error) {
c.Lock()
defer c.Unlock()
if c.closed {
return nil
}
c.closed = true
defer func() {
protoErr := c.protocol.Close()
<-c.respLoopDone
if retErr == nil {
retErr = protoErr
}
}()
err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
if err != nil {
return xerrors.Errorf("send disconnect: %w", err)
@@ -140,6 +149,10 @@ func (c *remoteCoordination) sendErr(err error) {
}

func (c *remoteCoordination) respLoop() {
defer func() {
c.coordinatee.SetAllPeersLost()
close(c.respLoopDone)
}()
for {
resp, err := c.protocol.Recv()
if err != nil {
@@ -162,10 +175,11 @@ func NewRemoteCoordination(logger slog.Logger,
tunnelTarget uuid.UUID,
) Coordination {
c := &remoteCoordination{
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
protocol: protocol,
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
protocol: protocol,
respLoopDone: make(chan struct{}),
}
if tunnelTarget != uuid.Nil {
c.Lock()
@@ -200,14 +214,15 @@

type inMemoryCoordination struct {
sync.Mutex
ctx context.Context
errChan chan error
closed bool
closedCh chan struct{}
coordinatee Coordinatee
logger slog.Logger
resps <-chan *proto.CoordinateResponse
reqs chan<- *proto.CoordinateRequest
ctx context.Context
errChan chan error
closed bool
closedCh chan struct{}
respLoopDone chan struct{}
coordinatee Coordinatee
logger slog.Logger
resps <-chan *proto.CoordinateResponse
reqs chan<- *proto.CoordinateRequest
}

func (c *inMemoryCoordination) sendErr(err error) {
@@ -238,11 +253,12 @@ func NewInMemoryCoordination(
thisID = clientID
}
c := &inMemoryCoordination{
ctx: ctx,
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
closedCh: make(chan struct{}),
ctx: ctx,
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
closedCh: make(chan struct{}),
respLoopDone: make(chan struct{}),
}

// use the background context since we will depend exclusively on closing the req channel to
@@ -285,6 +301,10 @@
}

func (c *inMemoryCoordination) respLoop() {
defer func() {
c.coordinatee.SetAllPeersLost()
close(c.respLoopDone)
}()
for {
select {
case <-c.closedCh:
@@ -315,6 +335,7 @@ func (c *inMemoryCoordination) Close() error {
defer close(c.reqs)
c.closed = true
close(c.closedCh)
<-c.respLoopDone
select {
case <-c.ctx.Done():
return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
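Both coordination types now use the same shutdown handshake: the response loop defers `SetAllPeersLost()` and `close(respLoopDone)`, and `Close()` blocks on `respLoopDone`, so by the time `Close` returns, the peers have been marked lost and the loop goroutine has exited. A self-contained sketch of that pattern under illustrative names (not code from this PR):

```go
package main

import (
	"fmt"
	"time"
)

// loopLifecycle shows the shutdown pattern used above: the loop signals
// completion by closing done, and Close waits on it, so any deferred
// cleanup (here, markPeersLost) has finished before Close returns.
type loopLifecycle struct {
	stop chan struct{}
	done chan struct{}
}

func (l *loopLifecycle) respLoop(markPeersLost func()) {
	defer func() {
		markPeersLost() // runs on every exit path, clean or not
		close(l.done)
	}()
	for {
		select {
		case <-l.stop:
			return
		case <-time.After(10 * time.Millisecond):
			// ... handle a response ...
		}
	}
}

func (l *loopLifecycle) Close() {
	close(l.stop)
	<-l.done // don't return until the loop's deferred cleanup has run
}

func main() {
	l := &loopLifecycle{stop: make(chan struct{}), done: make(chan struct{})}
	go l.respLoop(func() { fmt.Println("all peers marked lost") })
	l.Close()
}
```

Waiting on `respLoopDone` is what makes the new test assertion below (`setAllPeersLostCalls`) race-free: after `Close` returns, the deferred `SetAllPeersLost` is guaranteed to have run.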
172 changes: 167 additions & 5 deletions tailnet/coordinator_test.go
@@ -6,19 +6,24 @@ import (
"net"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"

"nhooyr.io/websocket"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"nhooyr.io/websocket"
"tailscale.com/tailcfg"
"tailscale.com/types/key"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/tailnet/test"
"github.com/coder/coder/v2/testutil"
)
@@ -400,3 +405,160 @@ func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server n
require.True(t, ok)
return client, server
}

func TestInMemoryCoordination(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clientID := uuid.UUID{1}
agentID := uuid.UUID{2}
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
fConn := &fakeCoordinatee{}

reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
Times(1).Return(reqs, resps)

uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
defer uut.Close()

coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)

select {
case err := <-uut.Error():
require.NoError(t, err)
default:
// OK!
}
}

func TestRemoteCoordination(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clientID := uuid.UUID{1}
agentID := uuid.UUID{2}
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
fConn := &fakeCoordinatee{}

reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
Times(1).Return(reqs, resps)

var coord tailnet.Coordinator = mCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
svc, err := tailnet.NewClientService(
logger.Named("svc"), &coordPtr,
time.Hour,
func() *tailcfg.DERPMap { panic("not implemented") },
)
require.NoError(t, err)
sC, cC := net.Pipe()

serveErr := make(chan error, 1)
go func() {
err := svc.ServeClient(ctx, tailnet.CurrentVersion.String(), sC, clientID, agentID)
serveErr <- err
}()

client, err := tailnet.NewDRPCClient(cC)
require.NoError(t, err)
protocol, err := client.Coordinate(ctx)
require.NoError(t, err)

uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
defer uut.Close()

coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)

select {
case err := <-uut.Error():
require.ErrorContains(t, err, "stream terminated by sending close")
default:
// OK!
}
}

// coordinationTest tests that a coordination behaves correctly
func coordinationTest(
ctx context.Context, t *testing.T,
uut tailnet.Coordination, fConn *fakeCoordinatee,
reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse,
agentID uuid.UUID,
) {
// It should add the tunnel, since we configured as a client
req := testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())

// when we call the callback, it should send a node update
require.NotNil(t, fConn.callback)
fConn.callback(&tailnet.Node{PreferredDERP: 1})

req = testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp())

// When we send a peer update, it should update the coordinatee
nk, err := key.NewNode().Public().MarshalBinary()
require.NoError(t, err)
dk, err := key.NewDisco().Public().MarshalText()
require.NoError(t, err)
updates := []*proto.CoordinateResponse_PeerUpdate{
{
Id: agentID[:],
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
Node: &proto.Node{
Id: 2,
Key: nk,
Disco: string(dk),
},
},
}
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
require.Eventually(t, func() bool {
fConn.Lock()
defer fConn.Unlock()
return len(fConn.updates) > 0
}, testutil.WaitShort, testutil.IntervalFast)
require.Len(t, fConn.updates[0], 1)
require.Equal(t, agentID[:], fConn.updates[0][0].Id)

err = uut.Close()
require.NoError(t, err)
uut.Error()

// When we close, it should gracefully disconnect
req = testutil.RequireRecvCtx(ctx, t, reqs)
require.NotNil(t, req.Disconnect)

// It should set all peers lost on the coordinatee
require.Equal(t, 1, fConn.setAllPeersLostCalls)
}

type fakeCoordinatee struct {
sync.Mutex
callback func(*tailnet.Node)
updates [][]*proto.CoordinateResponse_PeerUpdate
setAllPeersLostCalls int
}

func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
f.Lock()
defer f.Unlock()
f.updates = append(f.updates, updates)
return nil
}

func (f *fakeCoordinatee) SetAllPeersLost() {
f.Lock()
defer f.Unlock()
f.setAllPeersLostCalls++
}

func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
f.Lock()
defer f.Unlock()
f.callback = callback
}
10 changes: 10 additions & 0 deletions testutil/ctx.go
@@ -22,3 +22,13 @@ func RequireRecvCtx[A any](ctx context.Context, t testing.TB, c <-chan A) (a A)
return a
}
}

func RequireSendCtx[A any](ctx context.Context, t testing.TB, c chan<- A, a A) {
t.Helper()
select {
case <-ctx.Done():
t.Fatal("timeout")
case c <- a:
// OK!
}
}
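For completeness, a small usage sketch pairing the new helper with the existing `RequireRecvCtx` (illustrative test, not from this PR; assumes the usual `testing`, `testutil`, and `require` imports):

```go
func TestSendRecvHelpers(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	ch := make(chan int, 1)

	// Each helper fails the test with t.Fatal("timeout") instead of
	// deadlocking if ctx expires before the channel operation completes.
	testutil.RequireSendCtx(ctx, t, ch, 42)
	got := testutil.RequireRecvCtx(ctx, t, ch)
	require.Equal(t, 42, got)
}
```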