Skip to content

Commit 7762a73

Browse files
committed
querier <- subscriber
1 parent 3af1af1 commit 7762a73

File tree

1 file changed

+71
-47
lines changed

1 file changed

+71
-47
lines changed

enterprise/tailnet/pgcoord.go

Lines changed: 71 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,10 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store
110110
logger = logger.Named("pgcoord").With(slog.F("coordinator_id", id))
111111
bCh := make(chan binding)
112112
cCh := make(chan agpl.Queue)
113+
// for communicating subscriptions with the subscriber
113114
sCh := make(chan subscribe)
115+
// for communicating subscriptions with the querier
116+
qsCh := make(chan subscribe)
114117
// signals when first heartbeat has been sent, so it's safe to start binding.
115118
fHB := make(chan struct{})
116119

@@ -123,10 +126,10 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store
123126
binder: newBinder(ctx, logger, id, store, bCh, fHB),
124127
bindings: bCh,
125128
newConnections: cCh,
126-
subscriber: newSubscriber(ctx, logger, id, store, sCh, fHB),
129+
subscriber: newSubscriber(ctx, logger, id, store, sCh, qsCh, fHB),
127130
newSubscriptions: sCh,
128131
id: id,
129-
querier: newQuerier(ctx, logger, id, ps, store, id, cCh, numQuerierWorkers, fHB),
132+
querier: newQuerier(ctx, logger, id, ps, store, id, cCh, qsCh, numQuerierWorkers, fHB),
130133
closed: make(chan struct{}),
131134
}
132135
logger.Info(ctx, "starting coordinator")
@@ -160,11 +163,11 @@ func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
160163
}
161164
if err := sendCtx(c.ctx, c.newSubscriptions, subscribe{
162165
sKey: sKey{clientID: id},
166+
q: enq,
163167
active: false,
164168
}); err != nil {
165169
c.logger.Debug(c.ctx, "parent context expired while withdrawing subscriptions", slog.Error(err))
166170
}
167-
c.querier.cleanupConn(enq)
168171
},
169172
}).Init()
170173

@@ -184,13 +187,12 @@ func (c *pgCoord) addSubscription(q agpl.Queue, agentID uuid.UUID) error {
184187
clientID: q.UniqueID(),
185188
agentID: agentID,
186189
},
190+
q: q,
187191
active: true,
188192
})
189193
if err != nil {
190194
return err
191195
}
192-
193-
c.querier.newClientSubscription(q, agentID)
194196
return nil
195197
}
196198

@@ -200,13 +202,12 @@ func (c *pgCoord) removeSubscription(q agpl.Queue, agentID uuid.UUID) error {
200202
clientID: q.UniqueID(),
201203
agentID: agentID,
202204
},
205+
q: q,
203206
active: false,
204207
})
205208
if err != nil {
206209
return err
207210
}
208-
209-
c.querier.removeClientSubscription(q, agentID)
210211
return nil
211212
}
212213

@@ -307,6 +308,8 @@ type sKey struct {
307308

308309
type subscribe struct {
309310
sKey
311+
312+
q agpl.Queue
310313
// whether the subscription should be active. if true, the subscription is
311314
// added. if false, the subscription is removed.
312315
active bool
@@ -318,6 +321,7 @@ type subscriber struct {
318321
coordinatorID uuid.UUID
319322
store database.Store
320323
subscriptions <-chan subscribe
324+
querierCh chan<- subscribe
321325

322326
mu sync.Mutex
323327
// map[clientID]map[agentID]subscribe
@@ -330,6 +334,7 @@ func newSubscriber(ctx context.Context,
330334
id uuid.UUID,
331335
store database.Store,
332336
subscriptions <-chan subscribe,
337+
querierCh chan<- subscribe,
333338
startWorkers <-chan struct{},
334339
) *subscriber {
335340
s := &subscriber{
@@ -338,6 +343,7 @@ func newSubscriber(ctx context.Context,
338343
coordinatorID: id,
339344
store: store,
340345
subscriptions: subscriptions,
346+
querierCh: querierCh,
341347
latest: make(map[uuid.UUID]map[uuid.UUID]subscribe),
342348
workQ: newWorkQ[sKey](ctx),
343349
}
@@ -360,6 +366,7 @@ func (s *subscriber) handleSubscriptions() {
360366
case sub := <-s.subscriptions:
361367
s.storeSubscription(sub)
362368
s.workQ.enqueue(sub.sKey)
369+
s.querierCh <- sub
363370
}
364371
}
365372
}
@@ -784,6 +791,7 @@ type querier struct {
784791
store database.Store
785792

786793
newConnections chan agpl.Queue
794+
subscriptions chan subscribe
787795

788796
workQ *workQ[mKey]
789797

@@ -812,6 +820,7 @@ func newQuerier(ctx context.Context,
812820
store database.Store,
813821
self uuid.UUID,
814822
newConnections chan agpl.Queue,
823+
subscriptions chan subscribe,
815824
numWorkers int,
816825
firstHeartbeat chan struct{},
817826
) *querier {
@@ -823,6 +832,7 @@ func newQuerier(ctx context.Context,
823832
pubsub: ps,
824833
store: store,
825834
newConnections: newConnections,
835+
subscriptions: subscriptions,
826836
workQ: newWorkQ[mKey](ctx),
827837
heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat),
828838
mappers: make(map[mKey]*countedMapper),
@@ -835,7 +845,7 @@ func newQuerier(ctx context.Context,
835845

836846
go func() {
837847
<-firstHeartbeat
838-
go q.handleNewConnections()
848+
go q.handleIncoming()
839849
for i := 0; i < numWorkers; i++ {
840850
go q.worker()
841851
}
@@ -844,11 +854,12 @@ func newQuerier(ctx context.Context,
844854
return q
845855
}
846856

847-
func (q *querier) handleNewConnections() {
857+
func (q *querier) handleIncoming() {
848858
for {
849859
select {
850860
case <-q.ctx.Done():
851861
return
862+
852863
case c := <-q.newConnections:
853864
switch c.Kind() {
854865
case agpl.QueueKindAgent:
@@ -858,6 +869,13 @@ func (q *querier) handleNewConnections() {
858869
default:
859870
panic(fmt.Sprint("unreachable: invalid queue kind ", c.Kind()))
860871
}
872+
873+
case sub := <-q.subscriptions:
874+
if sub.active {
875+
q.newClientSubscription(sub.q, sub.agentID)
876+
} else {
877+
q.removeClientSubscription(sub.q, sub.agentID)
878+
}
861879
}
862880
}
863881
}
@@ -905,6 +923,11 @@ func (q *querier) newClientSubscription(c agpl.Queue, agentID uuid.UUID) {
905923
if _, ok := q.clientSubscriptions[c.UniqueID()]; !ok {
906924
q.clientSubscriptions[c.UniqueID()] = map[uuid.UUID]struct{}{}
907925
}
926+
fmt.Println("CREATEDC SUBSCRIPTION", c.UniqueID(), agentID)
927+
fmt.Println("CREATEDC SUBSCRIPTION", c.UniqueID(), agentID)
928+
fmt.Println("CREATEDC SUBSCRIPTION", c.UniqueID(), agentID)
929+
fmt.Println("CREATEDC SUBSCRIPTION", c.UniqueID(), agentID)
930+
fmt.Println("CREATEDC SUBSCRIPTION", c.UniqueID(), agentID)
908931

909932
mk := mKey{
910933
agent: agentID,
@@ -934,6 +957,12 @@ func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) {
934957
q.mu.Lock()
935958
defer q.mu.Unlock()
936959

960+
// agentID: uuid.Nil indicates that a client is going away. The querier
961+
// handles that in cleanupConn below instead.
962+
if agentID == uuid.Nil {
963+
return
964+
}
965+
937966
mk := mKey{
938967
agent: agentID,
939968
kind: agpl.QueueKindClient,
@@ -948,6 +977,9 @@ func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) {
948977
cm.cancel()
949978
delete(q.mappers, mk)
950979
}
980+
if len(q.clientSubscriptions[c.UniqueID()]) == 0 {
981+
delete(q.clientSubscriptions, c.UniqueID())
982+
}
951983
}
952984

953985
func (q *querier) newClientConn(c agpl.Queue) {
@@ -982,18 +1014,17 @@ func (q *querier) cleanupConn(c agpl.Queue) {
9821014
agent: agentID,
9831015
kind: c.Kind(),
9841016
}
985-
cm, ok := q.mappers[mk]
986-
if ok {
987-
if err := sendCtx(cm.ctx, cm.del, c); err != nil {
988-
continue
989-
}
990-
cm.count--
991-
if cm.count == 0 {
992-
cm.cancel()
993-
delete(q.mappers, mk)
994-
}
1017+
cm := q.mappers[mk]
1018+
if err := sendCtx(cm.ctx, cm.del, c); err != nil {
1019+
continue
1020+
}
1021+
cm.count--
1022+
if cm.count == 0 {
1023+
cm.cancel()
1024+
delete(q.mappers, mk)
9951025
}
9961026
}
1027+
delete(q.clientSubscriptions, c.UniqueID())
9971028

9981029
mk := mKey{
9991030
agent: c.UniqueID(),
@@ -1190,28 +1221,26 @@ func (q *querier) listenClient(_ context.Context, msg []byte, err error) {
11901221
q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err))
11911222
return
11921223
}
1193-
client, agents, err := parseClientUpdate(string(msg))
1224+
client, agent, err := parseClientUpdate(string(msg))
11941225
if err != nil {
11951226
q.logger.Error(q.ctx, "failed to parse client update", slog.F("msg", string(msg)), slog.Error(err))
11961227
return
11971228
}
1198-
logger := q.logger.With(slog.F("client_id", client))
1229+
logger := q.logger.With(slog.F("client_id", client), slog.F("agent_id", agent))
11991230
logger.Debug(q.ctx, "got client update")
1200-
for _, agentID := range agents {
1201-
logger := q.logger.With(slog.F("agent_id", agentID))
1202-
mk := mKey{
1203-
agent: agentID,
1204-
kind: agpl.QueueKindAgent,
1205-
}
1206-
q.mu.Lock()
1207-
_, ok := q.mappers[mk]
1208-
q.mu.Unlock()
1209-
if !ok {
1210-
logger.Debug(q.ctx, "ignoring update because we have no mapper")
1211-
return
1212-
}
1213-
q.workQ.enqueue(mk)
1231+
1232+
mk := mKey{
1233+
agent: agent,
1234+
kind: agpl.QueueKindAgent,
12141235
}
1236+
q.mu.Lock()
1237+
_, ok := q.mappers[mk]
1238+
q.mu.Unlock()
1239+
if !ok {
1240+
logger.Debug(q.ctx, "ignoring update because we have no mapper")
1241+
return
1242+
}
1243+
q.workQ.enqueue(mk)
12151244
}
12161245

12171246
func (q *querier) listenAgent(_ context.Context, msg []byte, err error) {
@@ -1348,27 +1377,22 @@ func (q *querier) getAll(ctx context.Context) (map[uuid.UUID]database.TailnetAge
13481377
return agentsMap, clientsMap, nil
13491378
}
13501379

1351-
func parseClientUpdate(msg string) (client uuid.UUID, agents []uuid.UUID, err error) {
1380+
func parseClientUpdate(msg string) (client, agent uuid.UUID, err error) {
13521381
parts := strings.Split(msg, ",")
13531382
if len(parts) != 2 {
1354-
return uuid.Nil, nil, xerrors.Errorf("expected 2 parts separated by comma")
1383+
return uuid.Nil, uuid.Nil, xerrors.Errorf("expected 2 parts separated by comma")
13551384
}
13561385
client, err = uuid.Parse(parts[0])
13571386
if err != nil {
1358-
return uuid.Nil, nil, xerrors.Errorf("failed to parse client UUID: %w", err)
1387+
return uuid.Nil, uuid.Nil, xerrors.Errorf("failed to parse client UUID: %w", err)
13591388
}
13601389

1361-
agents = []uuid.UUID{}
1362-
for _, agentStr := range parts[1:] {
1363-
agent, err := uuid.Parse(agentStr)
1364-
if err != nil {
1365-
return uuid.Nil, nil, xerrors.Errorf("failed to parse agent UUID: %w", err)
1366-
}
1367-
1368-
agents = append(agents, agent)
1390+
agent, err = uuid.Parse(parts[1])
1391+
if err != nil {
1392+
return uuid.Nil, uuid.Nil, xerrors.Errorf("failed to parse agent UUID: %w", err)
13691393
}
13701394

1371-
return client, agents, nil
1395+
return client, agent, nil
13721396
}
13731397

13741398
func parseUpdateMessage(msg string) (agent uuid.UUID, err error) {

0 commit comments

Comments
 (0)