Skip to content

Commit 7a40bf8

Browse files
committed
Run pubsub loop in a queue
1 parent 9878fc5 commit 7a40bf8

File tree

1 file changed

+113
-89
lines changed

1 file changed

+113
-89
lines changed

enterprise/tailnet/coordinator.go

Lines changed: 113 additions & 89 deletions
Original file line number | Diff line number | Diff line change
@@ -21,29 +21,32 @@ import (
2121
// NewCoordinator creates a new high availability coordinator
2222
// that uses PostgreSQL pubsub to exchange handshakes.
2323
func NewCoordinator(logger slog.Logger, pubsub database.Pubsub) (agpl.Coordinator, error) {
24+
ctx, cancelFunc := context.WithCancel(context.Background())
2425
coord := &haCoordinator{
2526
id: uuid.New(),
2627
log: logger,
2728
pubsub: pubsub,
29+
closeFunc: cancelFunc,
2830
close: make(chan struct{}),
2931
nodes: map[uuid.UUID]*agpl.Node{},
3032
agentSockets: map[uuid.UUID]net.Conn{},
3133
agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{},
3234
}
3335

34-
if err := coord.runPubsub(); err != nil {
36+
if err := coord.runPubsub(ctx); err != nil {
3537
return nil, xerrors.Errorf("run coordinator pubsub: %w", err)
3638
}
3739

3840
return coord, nil
3941
}
4042

4143
type haCoordinator struct {
42-
id uuid.UUID
43-
log slog.Logger
44-
mutex sync.RWMutex
45-
pubsub database.Pubsub
46-
close chan struct{}
44+
id uuid.UUID
45+
log slog.Logger
46+
mutex sync.RWMutex
47+
pubsub database.Pubsub
48+
close chan struct{}
49+
closeFunc context.CancelFunc
4750

4851
// nodes maps agent and connection IDs to their respective node.
4952
nodes map[uuid.UUID]*agpl.Node
@@ -303,6 +306,7 @@ func (c *haCoordinator) Close() error {
303306
default:
304307
}
305308
close(c.close)
309+
c.closeFunc()
306310

307311
wg := sync.WaitGroup{}
308312

@@ -384,111 +388,131 @@ func (c *haCoordinator) publishAgentToNodes(id uuid.UUID, node *agpl.Node) error
384388
return nil
385389
}
386390

387-
func (c *haCoordinator) runPubsub() error {
391+
func (c *haCoordinator) runPubsub(ctx context.Context) error {
392+
messageQueue := make(chan []byte, 64)
388393
cancelSub, err := c.pubsub.Subscribe("wireguard_peers", func(ctx context.Context, message []byte) {
389-
sp := bytes.Split(message, []byte("|"))
390-
if len(sp) != 4 {
391-
c.log.Error(ctx, "invalid wireguard peer message", slog.F("msg", string(message)))
394+
select {
395+
case messageQueue <- message:
396+
case <-ctx.Done():
392397
return
393398
}
399+
})
400+
if err != nil {
401+
return xerrors.Errorf("subscribe wireguard peers")
402+
}
403+
go func() {
404+
for {
405+
var message []byte
406+
select {
407+
case <-ctx.Done():
408+
return
409+
case message = <-messageQueue:
410+
}
411+
c.handlePubsubMessage(ctx, message)
412+
}
413+
}()
414+
415+
go func() {
416+
defer cancelSub()
417+
<-c.close
418+
}()
419+
420+
return nil
421+
}
422+
423+
func (c *haCoordinator) handlePubsubMessage(ctx context.Context, message []byte) {
424+
sp := bytes.Split(message, []byte("|"))
425+
if len(sp) != 4 {
426+
c.log.Error(ctx, "invalid wireguard peer message", slog.F("msg", string(message)))
427+
return
428+
}
429+
430+
var (
431+
coordinatorID = sp[0]
432+
eventType = sp[1]
433+
agentID = sp[2]
434+
nodeJSON = sp[3]
435+
)
394436

395-
var (
396-
coordinatorID = sp[0]
397-
eventType = sp[1]
398-
agentID = sp[2]
399-
nodeJSON = sp[3]
400-
)
437+
sender, err := uuid.ParseBytes(coordinatorID)
438+
if err != nil {
439+
c.log.Error(ctx, "invalid sender id", slog.F("id", string(coordinatorID)), slog.F("msg", string(message)))
440+
return
441+
}
401442

402-
sender, err := uuid.ParseBytes(coordinatorID)
443+
// We sent this message!
444+
if sender == c.id {
445+
return
446+
}
447+
448+
switch string(eventType) {
449+
case "callmemaybe":
450+
agentUUID, err := uuid.ParseBytes(agentID)
403451
if err != nil {
404-
c.log.Error(ctx, "invalid sender id", slog.F("id", string(coordinatorID)), slog.F("msg", string(message)))
452+
c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
405453
return
406454
}
407455

408-
// We sent this message!
409-
if sender == c.id {
456+
c.mutex.Lock()
457+
agentSocket, ok := c.agentSockets[agentUUID]
458+
if !ok {
459+
c.mutex.Unlock()
410460
return
411461
}
462+
c.mutex.Unlock()
412463

413-
switch string(eventType) {
414-
case "callmemaybe":
415-
agentUUID, err := uuid.ParseBytes(agentID)
416-
if err != nil {
417-
c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
418-
return
419-
}
420-
421-
c.mutex.Lock()
422-
agentSocket, ok := c.agentSockets[agentUUID]
423-
if !ok {
424-
c.mutex.Unlock()
425-
return
426-
}
427-
c.mutex.Unlock()
428-
429-
// We get a single node over pubsub, so turn into an array.
430-
_, err = agentSocket.Write(nodeJSON)
431-
if err != nil {
432-
if errors.Is(err, io.EOF) {
433-
return
434-
}
435-
c.log.Error(ctx, "send callmemaybe to agent", slog.Error(err))
436-
return
437-
}
438-
case "clienthello":
439-
agentUUID, err := uuid.ParseBytes(agentID)
440-
if err != nil {
441-
c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
464+
// We get a single node over pubsub, so turn into an array.
465+
_, err = agentSocket.Write(nodeJSON)
466+
if err != nil {
467+
if errors.Is(err, io.EOF) {
442468
return
443469
}
470+
c.log.Error(ctx, "send callmemaybe to agent", slog.Error(err))
471+
return
472+
}
473+
case "clienthello":
474+
agentUUID, err := uuid.ParseBytes(agentID)
475+
if err != nil {
476+
c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
477+
return
478+
}
444479

445-
err = c.handleClientHello(agentUUID)
446-
if err != nil {
447-
c.log.Error(ctx, "handle agent request node", slog.Error(err))
448-
return
449-
}
450-
case "agenthello":
451-
agentUUID, err := uuid.ParseBytes(agentID)
452-
if err != nil {
453-
c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
454-
return
455-
}
480+
err = c.handleClientHello(agentUUID)
481+
if err != nil {
482+
c.log.Error(ctx, "handle agent request node", slog.Error(err))
483+
return
484+
}
485+
case "agenthello":
486+
agentUUID, err := uuid.ParseBytes(agentID)
487+
if err != nil {
488+
c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
489+
return
490+
}
456491

457-
nodes := c.nodesSubscribedToAgent(agentUUID)
458-
if len(nodes) > 0 {
459-
err := c.publishNodesToAgent(agentUUID, nodes)
460-
if err != nil {
461-
c.log.Error(ctx, "publish nodes to agent", slog.Error(err))
462-
return
463-
}
464-
}
465-
case "agentupdate":
466-
agentUUID, err := uuid.ParseBytes(agentID)
492+
nodes := c.nodesSubscribedToAgent(agentUUID)
493+
if len(nodes) > 0 {
494+
err := c.publishNodesToAgent(agentUUID, nodes)
467495
if err != nil {
468-
c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
496+
c.log.Error(ctx, "publish nodes to agent", slog.Error(err))
469497
return
470498
}
499+
}
500+
case "agentupdate":
501+
agentUUID, err := uuid.ParseBytes(agentID)
502+
if err != nil {
503+
c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
504+
return
505+
}
471506

472-
decoder := json.NewDecoder(bytes.NewReader(nodeJSON))
473-
_, err = c.handleAgentUpdate(agentUUID, decoder)
474-
if err != nil {
475-
c.log.Error(ctx, "handle agent update", slog.Error(err))
476-
return
477-
}
478-
default:
479-
c.log.Error(ctx, "unknown peer event", slog.F("name", string(eventType)))
507+
decoder := json.NewDecoder(bytes.NewReader(nodeJSON))
508+
_, err = c.handleAgentUpdate(agentUUID, decoder)
509+
if err != nil {
510+
c.log.Error(ctx, "handle agent update", slog.Error(err))
511+
return
480512
}
481-
})
482-
if err != nil {
483-
return xerrors.Errorf("subscribe wireguard peers")
513+
default:
514+
c.log.Error(ctx, "unknown peer event", slog.F("name", string(eventType)))
484515
}
485-
486-
go func() {
487-
defer cancelSub()
488-
<-c.close
489-
}()
490-
491-
return nil
492516
}
493517

494518
// format: <coordinator id>|callmemaybe|<recipient id>|<node json>

0 commit comments

Comments
 (0)