Skip to content

Commit 9878fc5

Browse files
committed
Fix client publish race
1 parent a0e5cab commit 9878fc5

File tree

2 files changed

+61
-10
lines changed

2 files changed

+61
-10
lines changed

enterprise/tailnet/coordinator.go

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ type haCoordinator struct {
5656

5757
// Node returns an in-memory node by ID.
5858
func (c *haCoordinator) Node(id uuid.UUID) *agpl.Node {
59-
c.mutex.RLock()
60-
defer c.mutex.RUnlock()
59+
c.mutex.Lock()
60+
defer c.mutex.Unlock()
6161
node := c.nodes[id]
6262
return node
6363
}
@@ -79,6 +79,11 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID
7979
if err != nil {
8080
return xerrors.Errorf("write nodes: %w", err)
8181
}
82+
} else {
83+
err := c.publishClientHello(agent)
84+
if err != nil {
85+
return xerrors.Errorf("publish client hello: %w", err)
86+
}
8287
}
8388

8489
c.mutex.Lock()
@@ -205,7 +210,7 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
205210

206211
decoder := json.NewDecoder(conn)
207212
for {
208-
node, err := c.hangleAgentUpdate(id, decoder)
213+
node, err := c.handleAgentUpdate(id, decoder)
209214
if err != nil {
210215
if errors.Is(err, io.EOF) {
211216
return nil
@@ -240,7 +245,17 @@ func (c *haCoordinator) nodesSubscribedToAgent(agentID uuid.UUID) []*agpl.Node {
240245
return nodes
241246
}
242247

243-
func (c *haCoordinator) hangleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (*agpl.Node, error) {
248+
func (c *haCoordinator) handleClientHello(id uuid.UUID) error {
249+
c.mutex.Lock()
250+
node, ok := c.nodes[id]
251+
c.mutex.Unlock()
252+
if !ok {
253+
return nil
254+
}
255+
return c.publishAgentToNodes(id, node)
256+
}
257+
258+
func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (*agpl.Node, error) {
244259
var node agpl.Node
245260
err := decoder.Decode(&node)
246261
if err != nil {
@@ -343,6 +358,18 @@ func (c *haCoordinator) publishAgentHello(id uuid.UUID) error {
343358
return nil
344359
}
345360

361+
func (c *haCoordinator) publishClientHello(id uuid.UUID) error {
362+
msg, err := c.formatClientHello(id)
363+
if err != nil {
364+
return xerrors.Errorf("format client hello: %w", err)
365+
}
366+
err = c.pubsub.Publish("wireguard_peers", msg)
367+
if err != nil {
368+
return xerrors.Errorf("publish client hello: %w", err)
369+
}
370+
return nil
371+
}
372+
346373
func (c *haCoordinator) publishAgentToNodes(id uuid.UUID, node *agpl.Node) error {
347374
msg, err := c.formatAgentUpdate(id, node)
348375
if err != nil {
@@ -408,6 +435,18 @@ func (c *haCoordinator) runPubsub() error {
408435
c.log.Error(ctx, "send callmemaybe to agent", slog.Error(err))
409436
return
410437
}
438+
case "clienthello":
439+
agentUUID, err := uuid.ParseBytes(agentID)
440+
if err != nil {
441+
c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
442+
return
443+
}
444+
445+
err = c.handleClientHello(agentUUID)
446+
if err != nil {
447+
c.log.Error(ctx, "handle agent request node", slog.Error(err))
448+
return
449+
}
411450
case "agenthello":
412451
agentUUID, err := uuid.ParseBytes(agentID)
413452
if err != nil {
@@ -431,7 +470,7 @@ func (c *haCoordinator) runPubsub() error {
431470
}
432471

433472
decoder := json.NewDecoder(bytes.NewReader(nodeJSON))
434-
_, err = c.hangleAgentUpdate(agentUUID, decoder)
473+
_, err = c.handleAgentUpdate(agentUUID, decoder)
435474
if err != nil {
436475
c.log.Error(ctx, "handle agent update", slog.Error(err))
437476
return
@@ -478,6 +517,17 @@ func (c *haCoordinator) formatAgentHello(id uuid.UUID) ([]byte, error) {
478517
return buf.Bytes(), nil
479518
}
480519

520+
// format: <coordinator id>|clienthello|<agent id>|
521+
func (c *haCoordinator) formatClientHello(id uuid.UUID) ([]byte, error) {
522+
buf := bytes.Buffer{}
523+
524+
buf.WriteString(c.id.String() + "|")
525+
buf.WriteString("clienthello|")
526+
buf.WriteString(id.String() + "|")
527+
528+
return buf.Bytes(), nil
529+
}
530+
481531
// format: <coordinator id>|agentupdate|<node id>|<node json>
482532
func (c *haCoordinator) formatAgentUpdate(id uuid.UUID, node *agpl.Node) ([]byte, error) {
483533
buf := bytes.Buffer{}

enterprise/tailnet/coordinator_test.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"cdr.dev/slog/sloggers/slogtest"
1212

1313
"github.com/coder/coder/coderd/database"
14+
"github.com/coder/coder/coderd/database/dbtestutil"
1415
"github.com/coder/coder/enterprise/tailnet"
1516
agpl "github.com/coder/coder/tailnet"
1617
"github.com/coder/coder/testutil"
@@ -167,16 +168,12 @@ func TestCoordinatorHA(t *testing.T) {
167168
t.Run("AgentWithClient", func(t *testing.T) {
168169
t.Parallel()
169170

170-
pubsub := database.NewPubsubInMemory()
171+
_, pubsub := dbtestutil.NewDB(t)
171172

172173
coordinator1, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub)
173174
require.NoError(t, err)
174175
defer coordinator1.Close()
175176

176-
coordinator2, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub)
177-
require.NoError(t, err)
178-
defer coordinator2.Close()
179-
180177
agentWS, agentServerWS := net.Pipe()
181178
defer agentWS.Close()
182179
agentNodeChan := make(chan []*agpl.Node)
@@ -196,6 +193,10 @@ func TestCoordinatorHA(t *testing.T) {
196193
return coordinator1.Node(agentID) != nil
197194
}, testutil.WaitShort, testutil.IntervalFast)
198195

196+
coordinator2, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub)
197+
require.NoError(t, err)
198+
defer coordinator2.Close()
199+
199200
clientWS, clientServerWS := net.Pipe()
200201
defer clientWS.Close()
201202
defer clientServerWS.Close()

0 commit comments

Comments
 (0)