feat: add agent acks to in-memory coordinator #12786

Merged: 14 commits on Apr 10, 2024
feat: add agent acks to in-memory coordinator
When an agent receives a node, it responds with an ACK which is relayed
to the client. After the client receives the ACK, it's allowed to begin
pinging.
coadler committed Apr 10, 2024
commit 1c287b2f54e97c31d66a99f79416f2a11313620b
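
For context, a minimal, self-contained sketch of the client-side gating the description implies: the connector is assumed to expose a channel that it closes once the agent's ACK has been relayed back, and the caller waits on that channel (with a timeout) before it starts pinging. waitForAgentAck and the simulated delay are illustrative names, not the PR's actual code; the real change wires this up through the agentAck channel on tailnetAPIConnector in the diff below.

package main

import (
	"context"
	"fmt"
	"time"
)

// waitForAgentAck blocks until ackCh is closed (the agent's ACK has been
// relayed back), the timeout fires, or ctx is canceled. It reports whether
// the ACK actually arrived, so the caller can start pinging immediately or
// fall back to pinging without confirmation.
func waitForAgentAck(ctx context.Context, ackCh <-chan struct{}, timeout time.Duration) bool {
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case <-ackCh:
		return true
	case <-timer.C:
		return false
	case <-ctx.Done():
		return false
	}
}

func main() {
	ack := make(chan struct{})
	// Simulate the coordinator relaying the agent's ACK shortly after dialing.
	go func() {
		time.Sleep(10 * time.Millisecond)
		close(ack)
	}()
	fmt.Println("acked:", waitForAgentAck(context.Background(), ack, 5*time.Second))
}
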
14 changes: 14 additions & 0 deletions codersdk/workspacesdk/connector.go
@@ -53,6 +53,8 @@ type tailnetAPIConnector struct {
coordinateURL string
dialOptions *websocket.DialOptions
conn tailnetConn
agentAckOnce sync.Once
agentAck chan struct{}

connected chan error
isFirst bool
@@ -74,6 +76,7 @@ func runTailnetAPIConnector(
conn: conn,
connected: make(chan error, 1),
closed: make(chan struct{}),
agentAck: make(chan struct{}),
}
tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background())
go tac.manageGracefulTimeout()
@@ -190,6 +193,17 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
}()
coordination := tailnet.NewRemoteCoordination(tac.logger, coord, tac.conn, tac.agentID)
tac.logger.Debug(tac.ctx, "serving coordinator")
go func() {
select {
case <-tac.ctx.Done():
tac.logger.Debug(tac.ctx, "ctx timeout before agent ack")
case <-coordination.AwaitAck():
tac.logger.Debug(tac.ctx, "got agent ack")
tac.agentAckOnce.Do(func() {
close(tac.agentAck)
})
}
}()
select {
case <-tac.ctx.Done():
tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
@@ -89,6 +89,59 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
require.NotNil(t, reqDisc.Disconnect)
}

func TestTailnetAPIConnector_Ack(t *testing.T) {
t.Parallel()
testCtx := testutil.Context(t, testutil.WaitShort)
ctx, cancel := context.WithCancel(testCtx)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
agentID := uuid.UUID{0x55}
clientID := uuid.UUID{0x66}
fCoord := tailnettest.NewFakeCoordinator()
var coord tailnet.Coordinator = fCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
derpMapCh := make(chan *tailcfg.DERPMap)
defer close(derpMapCh)
svc, err := tailnet.NewClientService(
logger, &coordPtr,
time.Millisecond, func() *tailcfg.DERPMap { return <-derpMapCh },
)
require.NoError(t, err)

svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sws, err := websocket.Accept(w, r, nil)
if !assert.NoError(t, err) {
return
}
ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
Name: "client",
ID: clientID,
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
})
assert.NoError(t, err)
}))

fConn := newFakeTailnetConn()

uut := runTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{}, fConn)

call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs)
require.NotNil(t, reqTun.AddTunnel)

_ = testutil.RequireRecvCtx(ctx, t, uut.connected)

// send an ack to the client
testutil.RequireSendCtx(ctx, t, call.Resps, &proto.CoordinateResponse{
TunnelAck: &proto.CoordinateResponse_Ack{Id: agentID[:]},
})

// the agentAck channel should be successfully closed
_ = testutil.RequireRecvCtx(ctx, t, uut.agentAck)
}

type fakeTailnetConn struct{}

func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
13 changes: 13 additions & 0 deletions codersdk/workspacesdk/workspacesdk.go
@@ -260,6 +260,19 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
options.Logger.Debug(ctx, "connected to tailnet v2+ API")
}

// TODO: uncomment after pgcoord ack's are implemented (upstack pr)
// options.Logger.Debug(ctx, "waiting for agent ack")
// // 5 seconds is chosen because this is the timeout for failed Wireguard
// // handshakes. In the worst case, we wait the same amount of time as a
// // failed handshake.
// timer := time.NewTimer(5 * time.Second)
// select {
// case <-connector.agentAck:
// case <-timer.C:
// options.Logger.Debug(ctx, "timed out waiting for agent ack")
// }
// timer.Stop()

agentConn = NewAgentConn(conn, AgentConnOptions{
AgentID: agentID,
CloseFunc: func() error {
88 changes: 84 additions & 4 deletions tailnet/coordinator.go
@@ -102,6 +102,7 @@ type Coordinatee interface {
}

type Coordination interface {
AwaitAck() <-chan struct{}
io.Closer
Error() <-chan error
}
@@ -111,9 +112,14 @@ type remoteCoordination struct {
closed bool
errChan chan error
coordinatee Coordinatee
tgt uuid.UUID
logger slog.Logger
protocol proto.DRPCTailnet_CoordinateClient
respLoopDone chan struct{}

ackOnce sync.Once
// tgtAck is closed when an ack from tgt is received.
tgtAck chan struct{}
}

func (c *remoteCoordination) Close() (retErr error) {
@@ -161,14 +167,54 @@ func (c *remoteCoordination) respLoop() {
c.sendErr(xerrors.Errorf("read: %w", err))
return
}
err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
if err != nil {
c.sendErr(xerrors.Errorf("update peers: %w", err))
return

if len(resp.GetPeerUpdates()) > 0 {
err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
if err != nil {
c.sendErr(xerrors.Errorf("update peers: %w", err))
return
}

// Only send acks from agents.
if c.tgt == uuid.Nil {
// Send an ack back for all received peers. This could
// potentially be smarter to only send an ACK once per client,
// but there's nothing currently stopping clients from reusing
// IDs.
for _, peer := range resp.GetPeerUpdates() {
Contributor:

I worry this is too superficial --- here we are only acknowledging the fact that we received a peer update, not that it was programmed into wireguard, which is what is actually needed for the handshake to complete.

I guess this is an OK start in that it cuts out any propagation delay from the Coordinator out of the race condition, but it still leaves the race there. I realize that you have yet to add support to the PGCoordinator, which is where we suspect the real problems are, so we will need to test this out and confirm that missed handshakes are substantially reduced. We can embed the ack deeper into tailnet in some later PR if we are still missing handshakes.

Contributor Author:

Yeah, completely eliminating the race would require digging down into the configmaps, which I wasn't keen to do unless necessary. In my testing with the in-memory coordinator I wasn't able to hit the 5s backoff anymore. I suspect pgcoord will actually fare better, considering the extra round-trip latency compared to the in-memory coordinator.

err := c.protocol.Send(&proto.CoordinateRequest{
TunnelAck: &proto.CoordinateRequest_Ack{Id: peer.Id},
})
if err != nil {
c.sendErr(xerrors.Errorf("send: %w", err))
return
}
}
}
}

// If we receive an ack, close the tgtAck channel to notify the waiting
// client.
if ack := resp.GetTunnelAck(); ack != nil {
dstID, err := uuid.FromBytes(ack.Id)
if err != nil {
c.sendErr(xerrors.Errorf("parse ack id: %w", err))
return
}

if c.tgt == dstID {
c.ackOnce.Do(func() {
close(c.tgtAck)
})
}
}
}
}

func (c *remoteCoordination) AwaitAck() <-chan struct{} {
return c.tgtAck
}

// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinatee (usually a
// Conn). If the tunnelTarget is not uuid.Nil, then we add a tunnel to the peer (i.e. we are acting as
// a client---agents should NOT set this!).
@@ -179,9 +225,11 @@ func NewRemoteCoordination(logger slog.Logger,
c := &remoteCoordination{
errChan: make(chan error, 1),
coordinatee: coordinatee,
tgt: tunnelTarget,
logger: logger,
protocol: protocol,
respLoopDone: make(chan struct{}),
tgtAck: make(chan struct{}),
}
if tunnelTarget != uuid.Nil {
c.Lock()
@@ -327,6 +375,13 @@ func (c *inMemoryCoordination) respLoop() {
}
}

func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
// This is only used for tests, so just return a closed channel.
ch := make(chan struct{})
close(ch)
return ch
}

func (c *inMemoryCoordination) Close() error {
c.Lock()
defer c.Unlock()
@@ -658,6 +713,31 @@ func (c *core) handleRequest(p *peer, req *proto.CoordinateRequest) error {
if req.Disconnect != nil {
c.removePeerLocked(p.id, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "graceful disconnect")
}
if ack := req.TunnelAck; ack != nil {
err := c.handleAckLocked(pr, ack)
if err != nil {
return xerrors.Errorf("handle ack: %w", err)
}
}
return nil
}

func (c *core) handleAckLocked(src *peer, ack *proto.CoordinateRequest_Ack) error {
dstID, err := uuid.FromBytes(ack.Id)
if err != nil {
// this shouldn't happen unless there is a client error. Close the connection so the client
// doesn't just happily continue thinking everything is fine.
return xerrors.Errorf("unable to convert bytes to UUID: %w", err)
}

dst, ok := c.peers[dstID]
if ok {
dst.resps <- &proto.CoordinateResponse{
TunnelAck: &proto.CoordinateResponse_Ack{
Id: src.id[:],
},
}
}
return nil
}

73 changes: 73 additions & 0 deletions tailnet/coordinator_test.go
@@ -412,6 +412,24 @@ func TestCoordinator(t *testing.T) {
_ = testutil.RequireRecvCtx(ctx, t, clientErrChan)
_ = testutil.RequireRecvCtx(ctx, t, closeClientChan)
})

t.Run("AgentAck", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
ctx := testutil.Context(t, testutil.WaitShort)

clientID := uuid.New()
agentID := uuid.New()

aReq, _ := coordinator.Coordinate(ctx, agentID, agentID.String(), tailnet.AgentCoordinateeAuth{ID: agentID})
_, cRes := coordinator.Coordinate(ctx, clientID, clientID.String(), tailnet.ClientCoordinateeAuth{AgentID: agentID})

aReq <- &proto.CoordinateRequest{TunnelAck: &proto.CoordinateRequest_Ack{Id: clientID[:]}}
ack := testutil.RequireRecvCtx(ctx, t, cRes)
require.NotNil(t, ack.TunnelAck)
require.Equal(t, agentID[:], ack.TunnelAck.Id)
})
}

// TestCoordinator_AgentUpdateWhileClientConnects tests for regression on
@@ -638,6 +656,61 @@ func TestRemoteCoordination(t *testing.T) {
}
}

func TestRemoteCoordination_Ack(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clientID := uuid.UUID{1}
agentID := uuid.UUID{2}
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
fConn := &fakeCoordinatee{}

reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
Times(1).Return(reqs, resps)

var coord tailnet.Coordinator = mCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
svc, err := tailnet.NewClientService(
logger.Named("svc"), &coordPtr,
time.Hour,
func() *tailcfg.DERPMap { panic("not implemented") },
)
require.NoError(t, err)
sC, cC := net.Pipe()

serveErr := make(chan error, 1)
go func() {
err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, clientID, agentID)
serveErr <- err
}()

client, err := tailnet.NewDRPCClient(cC, logger)
require.NoError(t, err)
protocol, err := client.Coordinate(ctx)
require.NoError(t, err)

uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
defer uut.Close()

testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{
TunnelAck: &proto.CoordinateResponse_Ack{Id: agentID[:]},
})

testutil.RequireRecvCtx(ctx, t, uut.AwaitAck())

require.NoError(t, uut.Close())

select {
case err := <-uut.Error():
require.ErrorContains(t, err, "stream terminated by sending close")
default:
// OK!
}
}

// coordinationTest tests that a coordination behaves correctly
func coordinationTest(
ctx context.Context, t *testing.T,