Skip to content

Commit 1f1e8f4

Browse files
committed
feat: add agent acks to in-memory coordinator
When an agent receives a node, it responds with an ACK which is relayed to the client. After the client receives the ACK, it's allowed to begin pinging.
1 parent dc8cf3e commit 1f1e8f4

File tree

7 files changed

+465
-71
lines changed

7 files changed

+465
-71
lines changed

codersdk/workspacesdk/connector.go

+14
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ type tailnetAPIConnector struct {
5353
coordinateURL string
5454
dialOptions *websocket.DialOptions
5555
conn tailnetConn
56+
agentAckOnce sync.Once
57+
agentAck chan struct{}
5658

5759
connected chan error
5860
isFirst bool
@@ -74,6 +76,7 @@ func runTailnetAPIConnector(
7476
conn: conn,
7577
connected: make(chan error, 1),
7678
closed: make(chan struct{}),
79+
agentAck: make(chan struct{}),
7780
}
7881
tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background())
7982
go tac.manageGracefulTimeout()
@@ -190,6 +193,17 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
190193
}()
191194
coordination := tailnet.NewRemoteCoordination(tac.logger, coord, tac.conn, tac.agentID)
192195
tac.logger.Debug(tac.ctx, "serving coordinator")
196+
go func() {
197+
select {
198+
case <-tac.ctx.Done():
199+
tac.logger.Debug(tac.ctx, "ctx timeout before agent ack")
200+
case <-coordination.AwaitAck():
201+
tac.logger.Debug(tac.ctx, "got agent ack")
202+
tac.agentAckOnce.Do(func() {
203+
close(tac.agentAck)
204+
})
205+
}
206+
}()
193207
select {
194208
case <-tac.ctx.Done():
195209
tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")

codersdk/workspacesdk/workspacesdk_internal_test.go renamed to codersdk/workspacesdk/connector_internal_test.go

+53
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,59 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
8989
require.NotNil(t, reqDisc.Disconnect)
9090
}
9191

92+
func TestTailnetAPIConnector_Ack(t *testing.T) {
93+
t.Parallel()
94+
testCtx := testutil.Context(t, testutil.WaitShort)
95+
ctx, cancel := context.WithCancel(testCtx)
96+
defer cancel()
97+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
98+
agentID := uuid.UUID{0x55}
99+
clientID := uuid.UUID{0x66}
100+
fCoord := tailnettest.NewFakeCoordinator()
101+
var coord tailnet.Coordinator = fCoord
102+
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
103+
coordPtr.Store(&coord)
104+
derpMapCh := make(chan *tailcfg.DERPMap)
105+
defer close(derpMapCh)
106+
svc, err := tailnet.NewClientService(
107+
logger, &coordPtr,
108+
time.Millisecond, func() *tailcfg.DERPMap { return <-derpMapCh },
109+
)
110+
require.NoError(t, err)
111+
112+
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
113+
sws, err := websocket.Accept(w, r, nil)
114+
if !assert.NoError(t, err) {
115+
return
116+
}
117+
ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
118+
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
119+
Name: "client",
120+
ID: clientID,
121+
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
122+
})
123+
assert.NoError(t, err)
124+
}))
125+
126+
fConn := newFakeTailnetConn()
127+
128+
uut := runTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{}, fConn)
129+
130+
call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
131+
reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs)
132+
require.NotNil(t, reqTun.AddTunnel)
133+
134+
_ = testutil.RequireRecvCtx(ctx, t, uut.connected)
135+
136+
// send an ack to the client
137+
testutil.RequireSendCtx(ctx, t, call.Resps, &proto.CoordinateResponse{
138+
TunnelAck: &proto.CoordinateResponse_Ack{Id: agentID[:]},
139+
})
140+
141+
// the agentAck channel should be successfully closed
142+
_ = testutil.RequireRecvCtx(ctx, t, uut.agentAck)
143+
}
144+
92145
type fakeTailnetConn struct{}
93146

94147
func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {

codersdk/workspacesdk/workspacesdk.go

+13
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,19 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
260260
options.Logger.Debug(ctx, "connected to tailnet v2+ API")
261261
}
262262

263+
// TODO: uncomment after pgcoord ack's are implemented (upstack pr)
264+
// options.Logger.Debug(ctx, "waiting for agent ack")
265+
// // 5 seconds is chosen because this is the timeout for failed Wireguard
266+
// // handshakes. In the worst case, we wait the same amount of time as a
267+
// // failed handshake.
268+
// timer := time.NewTimer(5 * time.Second)
269+
// select {
270+
// case <-connector.agentAck:
271+
// case <-timer.C:
272+
// options.Logger.Debug(ctx, "timed out waiting for agent ack")
273+
// }
274+
// timer.Stop()
275+
263276
agentConn = NewAgentConn(conn, AgentConnOptions{
264277
AgentID: agentID,
265278
CloseFunc: func() error {

tailnet/coordinator.go

+79-4
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ type Coordinatee interface {
102102
}
103103

104104
type Coordination interface {
105+
AwaitAck() <-chan struct{}
105106
io.Closer
106107
Error() <-chan error
107108
}
@@ -111,9 +112,14 @@ type remoteCoordination struct {
111112
closed bool
112113
errChan chan error
113114
coordinatee Coordinatee
115+
tgt uuid.UUID
114116
logger slog.Logger
115117
protocol proto.DRPCTailnet_CoordinateClient
116118
respLoopDone chan struct{}
119+
120+
ackOnce sync.Once
121+
// tgtAck is closed when an ack from tgt is received.
122+
tgtAck chan struct{}
117123
}
118124

119125
func (c *remoteCoordination) Close() (retErr error) {
@@ -161,14 +167,49 @@ func (c *remoteCoordination) respLoop() {
161167
c.sendErr(xerrors.Errorf("read: %w", err))
162168
return
163169
}
164-
err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
165-
if err != nil {
166-
c.sendErr(xerrors.Errorf("update peers: %w", err))
167-
return
170+
171+
if len(resp.GetPeerUpdates()) > 0 {
172+
err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
173+
if err != nil {
174+
c.sendErr(xerrors.Errorf("update peers: %w", err))
175+
return
176+
}
177+
178+
// only send acks from agents.
179+
if c.tgt == uuid.Nil {
180+
for _, peer := range resp.GetPeerUpdates() {
181+
err := c.protocol.Send(&proto.CoordinateRequest{
182+
TunnelAck: &proto.CoordinateRequest_Ack{Id: peer.Id},
183+
})
184+
if err != nil {
185+
c.sendErr(xerrors.Errorf("send: %w", err))
186+
return
187+
}
188+
}
189+
}
190+
}
191+
192+
// If we receive an ack, check the active waiters and notify them.
193+
if ack := resp.GetTunnelAck(); ack != nil {
194+
dstID, err := uuid.FromBytes(ack.Id)
195+
if err != nil {
196+
c.sendErr(xerrors.Errorf("parse ack id: %w", err))
197+
return
198+
}
199+
200+
if c.tgt == dstID {
201+
c.ackOnce.Do(func() {
202+
close(c.tgtAck)
203+
})
204+
}
168205
}
169206
}
170207
}
171208

209+
func (c *remoteCoordination) AwaitAck() <-chan struct{} {
210+
return c.tgtAck
211+
}
212+
172213
// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinatee (usually a
173214
// Conn). If the tunnelTarget is not uuid.Nil, then we add a tunnel to the peer (i.e. we are acting as
174215
// a client---agents should NOT set this!).
@@ -179,9 +220,11 @@ func NewRemoteCoordination(logger slog.Logger,
179220
c := &remoteCoordination{
180221
errChan: make(chan error, 1),
181222
coordinatee: coordinatee,
223+
tgt: tunnelTarget,
182224
logger: logger,
183225
protocol: protocol,
184226
respLoopDone: make(chan struct{}),
227+
tgtAck: make(chan struct{}),
185228
}
186229
if tunnelTarget != uuid.Nil {
187230
c.Lock()
@@ -327,6 +370,13 @@ func (c *inMemoryCoordination) respLoop() {
327370
}
328371
}
329372

373+
func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
374+
// This is only used for tests, so just return a closed channel.
375+
ch := make(chan struct{})
376+
close(ch)
377+
return ch
378+
}
379+
330380
func (c *inMemoryCoordination) Close() error {
331381
c.Lock()
332382
defer c.Unlock()
@@ -658,6 +708,31 @@ func (c *core) handleRequest(p *peer, req *proto.CoordinateRequest) error {
658708
if req.Disconnect != nil {
659709
c.removePeerLocked(p.id, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "graceful disconnect")
660710
}
711+
if ack := req.TunnelAck; ack != nil {
712+
err := c.handleAckLocked(pr, ack)
713+
if err != nil {
714+
return xerrors.Errorf("handle ack: %w", err)
715+
}
716+
}
717+
return nil
718+
}
719+
720+
func (c *core) handleAckLocked(src *peer, ack *proto.CoordinateRequest_Ack) error {
721+
dstID, err := uuid.FromBytes(ack.Id)
722+
if err != nil {
723+
// this shouldn't happen unless there is a client error. Close the connection so the client
724+
// doesn't just happily continue thinking everything is fine.
725+
return xerrors.Errorf("unable to convert bytes to UUID: %w", err)
726+
}
727+
728+
dst, ok := c.peers[dstID]
729+
if ok {
730+
dst.resps <- &proto.CoordinateResponse{
731+
TunnelAck: &proto.CoordinateResponse_Ack{
732+
Id: src.id[:],
733+
},
734+
}
735+
}
661736
return nil
662737
}
663738

tailnet/coordinator_test.go

+73
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,24 @@ func TestCoordinator(t *testing.T) {
412412
_ = testutil.RequireRecvCtx(ctx, t, clientErrChan)
413413
_ = testutil.RequireRecvCtx(ctx, t, closeClientChan)
414414
})
415+
416+
t.Run("AgentAck", func(t *testing.T) {
417+
t.Parallel()
418+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
419+
coordinator := tailnet.NewCoordinator(logger)
420+
ctx := testutil.Context(t, testutil.WaitShort)
421+
422+
clientID := uuid.New()
423+
agentID := uuid.New()
424+
425+
aReq, _ := coordinator.Coordinate(ctx, agentID, agentID.String(), tailnet.AgentCoordinateeAuth{ID: agentID})
426+
_, cRes := coordinator.Coordinate(ctx, clientID, clientID.String(), tailnet.ClientCoordinateeAuth{AgentID: agentID})
427+
428+
aReq <- &proto.CoordinateRequest{TunnelAck: &proto.CoordinateRequest_Ack{Id: clientID[:]}}
429+
ack := testutil.RequireRecvCtx(ctx, t, cRes)
430+
require.NotNil(t, ack.TunnelAck)
431+
require.Equal(t, agentID[:], ack.TunnelAck.Id)
432+
})
415433
}
416434

417435
// TestCoordinator_AgentUpdateWhileClientConnects tests for regression on
@@ -638,6 +656,61 @@ func TestRemoteCoordination(t *testing.T) {
638656
}
639657
}
640658

659+
func TestRemoteCoordination_Ack(t *testing.T) {
660+
t.Parallel()
661+
ctx := testutil.Context(t, testutil.WaitShort)
662+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
663+
clientID := uuid.UUID{1}
664+
agentID := uuid.UUID{2}
665+
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
666+
fConn := &fakeCoordinatee{}
667+
668+
reqs := make(chan *proto.CoordinateRequest, 100)
669+
resps := make(chan *proto.CoordinateResponse, 100)
670+
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
671+
Times(1).Return(reqs, resps)
672+
673+
var coord tailnet.Coordinator = mCoord
674+
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
675+
coordPtr.Store(&coord)
676+
svc, err := tailnet.NewClientService(
677+
logger.Named("svc"), &coordPtr,
678+
time.Hour,
679+
func() *tailcfg.DERPMap { panic("not implemented") },
680+
)
681+
require.NoError(t, err)
682+
sC, cC := net.Pipe()
683+
684+
serveErr := make(chan error, 1)
685+
go func() {
686+
err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, clientID, agentID)
687+
serveErr <- err
688+
}()
689+
690+
client, err := tailnet.NewDRPCClient(cC, logger)
691+
require.NoError(t, err)
692+
protocol, err := client.Coordinate(ctx)
693+
require.NoError(t, err)
694+
695+
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
696+
defer uut.Close()
697+
698+
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{
699+
TunnelAck: &proto.CoordinateResponse_Ack{Id: agentID[:]},
700+
})
701+
702+
testutil.RequireRecvCtx(ctx, t, uut.AwaitAck())
703+
704+
require.NoError(t, uut.Close())
705+
706+
select {
707+
case err := <-uut.Error():
708+
require.ErrorContains(t, err, "stream terminated by sending close")
709+
default:
710+
// OK!
711+
}
712+
}
713+
641714
// coordinationTest tests that a coordination behaves correctly
642715
func coordinationTest(
643716
ctx context.Context, t *testing.T,

0 commit comments

Comments
 (0)