Skip to content

Commit 4b14af0

Browse files
committed
spike comments
1 parent 1c287b2 commit 4b14af0

File tree

11 files changed

+334
-361
lines changed

11 files changed

+334
-361
lines changed

codersdk/workspacesdk/connector.go

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,6 @@ type tailnetAPIConnector struct {
5353
coordinateURL string
5454
dialOptions *websocket.DialOptions
5555
conn tailnetConn
56-
agentAckOnce sync.Once
57-
agentAck chan struct{}
5856

5957
connected chan error
6058
isFirst bool
@@ -76,7 +74,6 @@ func runTailnetAPIConnector(
7674
conn: conn,
7775
connected: make(chan error, 1),
7876
closed: make(chan struct{}),
79-
agentAck: make(chan struct{}),
8077
}
8178
tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background())
8279
go tac.manageGracefulTimeout()
@@ -193,17 +190,6 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
193190
}()
194191
coordination := tailnet.NewRemoteCoordination(tac.logger, coord, tac.conn, tac.agentID)
195192
tac.logger.Debug(tac.ctx, "serving coordinator")
196-
go func() {
197-
select {
198-
case <-tac.ctx.Done():
199-
tac.logger.Debug(tac.ctx, "ctx timeout before agent ack")
200-
case <-coordination.AwaitAck():
201-
tac.logger.Debug(tac.ctx, "got agent ack")
202-
tac.agentAckOnce.Do(func() {
203-
close(tac.agentAck)
204-
})
205-
}
206-
}()
207193
select {
208194
case <-tac.ctx.Done():
209195
tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")

codersdk/workspacesdk/connector_internal_test.go

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -89,59 +89,6 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
8989
require.NotNil(t, reqDisc.Disconnect)
9090
}
9191

92-
func TestTailnetAPIConnector_Ack(t *testing.T) {
93-
t.Parallel()
94-
testCtx := testutil.Context(t, testutil.WaitShort)
95-
ctx, cancel := context.WithCancel(testCtx)
96-
defer cancel()
97-
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
98-
agentID := uuid.UUID{0x55}
99-
clientID := uuid.UUID{0x66}
100-
fCoord := tailnettest.NewFakeCoordinator()
101-
var coord tailnet.Coordinator = fCoord
102-
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
103-
coordPtr.Store(&coord)
104-
derpMapCh := make(chan *tailcfg.DERPMap)
105-
defer close(derpMapCh)
106-
svc, err := tailnet.NewClientService(
107-
logger, &coordPtr,
108-
time.Millisecond, func() *tailcfg.DERPMap { return <-derpMapCh },
109-
)
110-
require.NoError(t, err)
111-
112-
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
113-
sws, err := websocket.Accept(w, r, nil)
114-
if !assert.NoError(t, err) {
115-
return
116-
}
117-
ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
118-
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
119-
Name: "client",
120-
ID: clientID,
121-
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
122-
})
123-
assert.NoError(t, err)
124-
}))
125-
126-
fConn := newFakeTailnetConn()
127-
128-
uut := runTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{}, fConn)
129-
130-
call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
131-
reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs)
132-
require.NotNil(t, reqTun.AddTunnel)
133-
134-
_ = testutil.RequireRecvCtx(ctx, t, uut.connected)
135-
136-
// send an ack to the client
137-
testutil.RequireSendCtx(ctx, t, call.Resps, &proto.CoordinateResponse{
138-
TunnelAck: &proto.CoordinateResponse_Ack{Id: agentID[:]},
139-
})
140-
141-
// the agentAck channel should be successfully closed
142-
_ = testutil.RequireRecvCtx(ctx, t, uut.agentAck)
143-
}
144-
14592
type fakeTailnetConn struct{}
14693

14794
func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {

codersdk/workspacesdk/workspacesdk.go

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
203203
DERPForceWebSockets: connInfo.DERPForceWebSockets,
204204
Logger: options.Logger,
205205
BlockEndpoints: c.client.DisableDirectConnections || options.BlockEndpoints,
206+
// TODO: enable in upstack PR
207+
// ShouldWaitForHandshake: true,
206208
})
207209
if err != nil {
208210
return nil, xerrors.Errorf("create tailnet: %w", err)
@@ -260,19 +262,6 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
260262
options.Logger.Debug(ctx, "connected to tailnet v2+ API")
261263
}
262264

263-
// TODO: uncomment after pgcoord ack's are implemented (upstack pr)
264-
// options.Logger.Debug(ctx, "waiting for agent ack")
265-
// // 5 seconds is chosen because this is the timeout for failed Wireguard
266-
// // handshakes. In the worst case, we wait the same amount of time as a
267-
// // failed handshake.
268-
// timer := time.NewTimer(5 * time.Second)
269-
// select {
270-
// case <-connector.agentAck:
271-
// case <-timer.C:
272-
// options.Logger.Debug(ctx, "timed out waiting for agent ack")
273-
// }
274-
// timer.Stop()
275-
276265
agentConn = NewAgentConn(conn, AgentConnOptions{
277266
AgentID: agentID,
278267
CloseFunc: func() error {

enterprise/coderd/workspaceproxy.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ func (api *API) workspaceProxyRegister(rw http.ResponseWriter, r *http.Request)
658658
if err != nil {
659659
return xerrors.Errorf("insert replica: %w", err)
660660
}
661-
} else if err != nil {
661+
} else {
662662
return xerrors.Errorf("get replica: %w", err)
663663
}
664664

tailnet/configmaps.go

Lines changed: 71 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,11 @@ type phased struct {
5656

5757
type configMaps struct {
5858
phased
59-
netmapDirty bool
60-
derpMapDirty bool
61-
filterDirty bool
62-
closing bool
59+
netmapDirty bool
60+
derpMapDirty bool
61+
filterDirty bool
62+
closing bool
63+
waitForHandshake bool
6364

6465
engine engineConfigurable
6566
static netmap.NetworkMap
@@ -216,6 +217,9 @@ func (c *configMaps) netMapLocked() *netmap.NetworkMap {
216217
func (c *configMaps) peerConfigLocked() []*tailcfg.Node {
217218
out := make([]*tailcfg.Node, 0, len(c.peers))
218219
for _, p := range c.peers {
220+
if !p.readyForHandshake {
221+
continue
222+
}
219223
n := p.node.Clone()
220224
if c.blockEndpoints {
221225
n.Endpoints = nil
@@ -225,6 +229,12 @@ func (c *configMaps) peerConfigLocked() []*tailcfg.Node {
225229
return out
226230
}
227231

232+
func (c *configMaps) setWaitForHandshake(wait bool) {
233+
c.L.Lock()
234+
defer c.L.Unlock()
235+
c.waitForHandshake = wait
236+
}
237+
228238
// setAddresses sets the addresses belonging to this node to the given slice. It
229239
// triggers configuration of the engine if the addresses have changed.
230240
// c.L MUST NOT be held.
@@ -377,17 +387,9 @@ func (c *configMaps) updatePeerLocked(update *proto.CoordinateResponse_PeerUpdat
377387
return false
378388
}
379389
logger = logger.With(slog.F("key_id", node.Key.ShortString()), slog.F("node", node))
380-
peerStatus, ok := status.Peer[node.Key]
381-
// Starting KeepAlive messages at the initialization of a connection
382-
// causes a race condition. If we send the handshake before the peer has
383-
// our node, we'll have to wait for 5 seconds before trying again.
384-
// Ideally, the first handshake starts when the user first initiates a
385-
// connection to the peer. After a successful connection we enable
386-
// keep alives to persist the connection and keep it from becoming idle.
387-
// SSH connections don't send packets while idle, so we use keep alives
388-
// to avoid random hangs while we set up the connection again after
389-
// inactivity.
390-
node.KeepAlive = ok && peerStatus.Active
390+
// Since we don't send nodes into Tailscale unless we're sure that the
391+
// peer is ready for handshakes, we always enable keepalives.
392+
node.KeepAlive = true
391393
}
392394
switch {
393395
case !ok && update.Kind == proto.CoordinateResponse_PeerUpdate_NODE:
@@ -396,23 +398,46 @@ func (c *configMaps) updatePeerLocked(update *proto.CoordinateResponse_PeerUpdat
396398
if ps, ok := status.Peer[node.Key]; ok {
397399
lastHandshake = ps.LastHandshake
398400
}
399-
c.peers[id] = &peerLifecycle{
400-
peerID: id,
401-
node: node,
402-
lastHandshake: lastHandshake,
403-
lost: false,
401+
lc = &peerLifecycle{
402+
peerID: id,
403+
node: node,
404+
lastHandshake: lastHandshake,
405+
lost: false,
406+
readyForHandshake: !c.waitForHandshake,
404407
}
408+
if c.waitForHandshake {
409+
lc.readyForHandshakeTimer = c.clock.AfterFunc(5*time.Second, func() {
410+
logger.Debug(context.Background(), "ready for handshake timeout")
411+
c.peerReadyForHandshakeTimeout(id)
412+
})
413+
}
414+
c.peers[id] = lc
405415
logger.Debug(context.Background(), "adding new peer")
406-
return true
416+
// since we just got this node, we don't know if it's ready for
417+
// handshakes yet.
418+
return lc.readyForHandshake
407419
case ok && update.Kind == proto.CoordinateResponse_PeerUpdate_NODE:
408420
// update
409421
node.Created = lc.node.Created
410-
dirty = !lc.node.Equal(node)
422+
dirty = !lc.node.Equal(node) && lc.readyForHandshake
411423
lc.node = node
412424
lc.lost = false
413425
lc.resetTimer()
414426
logger.Debug(context.Background(), "node update to existing peer", slog.F("dirty", dirty))
415427
return dirty
428+
case ok && update.Kind == proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE:
429+
wasReady := lc.readyForHandshake
430+
lc.readyForHandshake = true
431+
if lc.readyForHandshakeTimer != nil {
432+
lc.readyForHandshakeTimer.Stop()
433+
}
434+
logger.Debug(context.Background(), "peer ready for handshake")
435+
return !wasReady
436+
case !ok && update.Kind == proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE:
437+
// TODO: should we keep track of ready for handshake messages we get
438+
// from unknown nodes?
439+
logger.Debug(context.Background(), "got peer ready for handshake for unknown peer")
440+
return false
416441
case !ok:
417442
// disconnected or lost, but we don't have the node. No op
418443
logger.Debug(context.Background(), "skipping update for peer we don't recognize")
@@ -550,12 +575,31 @@ func (c *configMaps) fillPeerDiagnostics(d *PeerDiagnostics, peerID uuid.UUID) {
550575
d.LastWireguardHandshake = ps.LastHandshake
551576
}
552577

578+
func (c *configMaps) peerReadyForHandshakeTimeout(peerID uuid.UUID) {
579+
c.L.Lock()
580+
defer c.L.Unlock()
581+
lc, ok := c.peers[peerID]
582+
if !ok {
583+
return
584+
}
585+
if lc.readyForHandshakeTimer != nil {
586+
wasReady := lc.readyForHandshake
587+
lc.readyForHandshakeTimer = nil
588+
lc.readyForHandshake = true
589+
if !wasReady {
590+
c.Broadcast()
591+
}
592+
}
593+
}
594+
553595
type peerLifecycle struct {
554-
peerID uuid.UUID
555-
node *tailcfg.Node
556-
lost bool
557-
lastHandshake time.Time
558-
timer *clock.Timer
596+
peerID uuid.UUID
597+
node *tailcfg.Node
598+
lost bool
599+
lastHandshake time.Time
600+
timer *clock.Timer
601+
readyForHandshake bool
602+
readyForHandshakeTimer *clock.Timer
559603
}
560604

561605
func (l *peerLifecycle) resetTimer() {

tailnet/configmaps_internal_test.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,75 @@ func TestConfigMaps_updatePeers_new(t *testing.T) {
185185
_ = testutil.RequireRecvCtx(ctx, t, done)
186186
}
187187

188+
func TestConfigMaps_updatePeers_new_waitForHandshake(t *testing.T) {
189+
t.Parallel()
190+
ctx := testutil.Context(t, testutil.WaitShort)
191+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
192+
fEng := newFakeEngineConfigurable()
193+
nodePrivateKey := key.NewNode()
194+
nodeID := tailcfg.NodeID(5)
195+
discoKey := key.NewDisco()
196+
uut := newConfigMaps(logger, fEng, nodeID, nodePrivateKey, discoKey.Public())
197+
defer uut.close()
198+
uut.setWaitForHandshake(true)
199+
200+
p1ID := uuid.UUID{1}
201+
p1Node := newTestNode(1)
202+
p1n, err := NodeToProto(p1Node)
203+
require.NoError(t, err)
204+
205+
go func() {
206+
<-fEng.status
207+
fEng.statusDone <- struct{}{}
208+
}()
209+
210+
u1 := []*proto.CoordinateResponse_PeerUpdate{
211+
{
212+
Id: p1ID[:],
213+
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
214+
Node: p1n,
215+
},
216+
}
217+
uut.updatePeers(u1)
218+
219+
// it should not send the peer to the netmap yet
220+
221+
go func() {
222+
<-fEng.status
223+
fEng.statusDone <- struct{}{}
224+
}()
225+
226+
u2 := []*proto.CoordinateResponse_PeerUpdate{
227+
{
228+
Id: p1ID[:],
229+
Kind: proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE,
230+
},
231+
}
232+
uut.updatePeers(u2)
233+
234+
// it should now send the peer to the netmap
235+
236+
nm := testutil.RequireRecvCtx(ctx, t, fEng.setNetworkMap)
237+
r := testutil.RequireRecvCtx(ctx, t, fEng.reconfig)
238+
239+
require.Len(t, nm.Peers, 1)
240+
n1 := getNodeWithID(t, nm.Peers, 1)
241+
require.Equal(t, "127.3.3.40:1", n1.DERP)
242+
require.Equal(t, p1Node.Endpoints, n1.Endpoints)
243+
require.True(t, n1.KeepAlive)
244+
245+
// we rely on nmcfg.WGCfg() to convert the netmap to wireguard config, so just
246+
// require the right number of peers.
247+
require.Len(t, r.wg.Peers, 1)
248+
249+
done := make(chan struct{})
250+
go func() {
251+
defer close(done)
252+
uut.close()
253+
}()
254+
_ = testutil.RequireRecvCtx(ctx, t, done)
255+
}
256+
188257
func TestConfigMaps_updatePeers_same(t *testing.T) {
189258
t.Parallel()
190259
ctx := testutil.Context(t, testutil.WaitShort)

tailnet/conn.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ type Options struct {
8787
// connections, rather than trying `Upgrade: derp` first and potentially
8888
// falling back. This is useful for misbehaving proxies that prevent
8989
// fallback due to odd behavior, like Azure App Proxy.
90-
DERPForceWebSockets bool
91-
90+
DERPForceWebSockets bool
91+
ShouldWaitForHandshake bool
9292
// BlockEndpoints specifies whether P2P endpoints are blocked.
9393
// If so, only DERPs can establish connections.
9494
BlockEndpoints bool
@@ -216,6 +216,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
216216
nodePrivateKey,
217217
magicConn.DiscoPublicKey(),
218218
)
219+
cfgMaps.setWaitForHandshake(options.ShouldWaitForHandshake)
219220
cfgMaps.setAddresses(options.Addresses)
220221
if options.DERPMap != nil {
221222
cfgMaps.setDERPMap(options.DERPMap)

0 commit comments

Comments
 (0)