@@ -110,7 +110,10 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store
110
110
logger = logger .Named ("pgcoord" ).With (slog .F ("coordinator_id" , id ))
111
111
bCh := make (chan binding )
112
112
cCh := make (chan agpl.Queue )
113
+ // for communicating subscriptions with the subscriber
113
114
sCh := make (chan subscribe )
115
+ // for communicating subscriptions with the querier
116
+ qsCh := make (chan subscribe )
114
117
// signals when first heartbeat has been sent, so it's safe to start binding.
115
118
fHB := make (chan struct {})
116
119
@@ -123,10 +126,10 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store
123
126
binder : newBinder (ctx , logger , id , store , bCh , fHB ),
124
127
bindings : bCh ,
125
128
newConnections : cCh ,
126
- subscriber : newSubscriber (ctx , logger , id , store , sCh , fHB ),
129
+ subscriber : newSubscriber (ctx , logger , id , store , sCh , qsCh , fHB ),
127
130
newSubscriptions : sCh ,
128
131
id : id ,
129
- querier : newQuerier (ctx , logger , id , ps , store , id , cCh , numQuerierWorkers , fHB ),
132
+ querier : newQuerier (ctx , logger , id , ps , store , id , cCh , qsCh , numQuerierWorkers , fHB ),
130
133
closed : make (chan struct {}),
131
134
}
132
135
logger .Info (ctx , "starting coordinator" )
@@ -160,11 +163,11 @@ func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
160
163
}
161
164
if err := sendCtx (c .ctx , c .newSubscriptions , subscribe {
162
165
sKey : sKey {clientID : id },
166
+ q : enq ,
163
167
active : false ,
164
168
}); err != nil {
165
169
c .logger .Debug (c .ctx , "parent context expired while withdrawing subscriptions" , slog .Error (err ))
166
170
}
167
- c .querier .cleanupConn (enq )
168
171
},
169
172
}).Init ()
170
173
@@ -184,13 +187,12 @@ func (c *pgCoord) addSubscription(q agpl.Queue, agentID uuid.UUID) error {
184
187
clientID : q .UniqueID (),
185
188
agentID : agentID ,
186
189
},
190
+ q : q ,
187
191
active : true ,
188
192
})
189
193
if err != nil {
190
194
return err
191
195
}
192
-
193
- c .querier .newClientSubscription (q , agentID )
194
196
return nil
195
197
}
196
198
@@ -200,13 +202,12 @@ func (c *pgCoord) removeSubscription(q agpl.Queue, agentID uuid.UUID) error {
200
202
clientID : q .UniqueID (),
201
203
agentID : agentID ,
202
204
},
205
+ q : q ,
203
206
active : false ,
204
207
})
205
208
if err != nil {
206
209
return err
207
210
}
208
-
209
- c .querier .removeClientSubscription (q , agentID )
210
211
return nil
211
212
}
212
213
@@ -307,6 +308,8 @@ type sKey struct {
307
308
308
309
type subscribe struct {
309
310
sKey
311
+
312
+ q agpl.Queue
310
313
// whether the subscription should be active. if true, the subscription is
311
314
// added. if false, the subscription is removed.
312
315
active bool
@@ -318,6 +321,7 @@ type subscriber struct {
318
321
coordinatorID uuid.UUID
319
322
store database.Store
320
323
subscriptions <- chan subscribe
324
+ querierCh chan <- subscribe
321
325
322
326
mu sync.Mutex
323
327
// map[clientID]map[agentID]subscribe
@@ -330,6 +334,7 @@ func newSubscriber(ctx context.Context,
330
334
id uuid.UUID ,
331
335
store database.Store ,
332
336
subscriptions <- chan subscribe ,
337
+ querierCh chan <- subscribe ,
333
338
startWorkers <- chan struct {},
334
339
) * subscriber {
335
340
s := & subscriber {
@@ -338,6 +343,7 @@ func newSubscriber(ctx context.Context,
338
343
coordinatorID : id ,
339
344
store : store ,
340
345
subscriptions : subscriptions ,
346
+ querierCh : querierCh ,
341
347
latest : make (map [uuid.UUID ]map [uuid.UUID ]subscribe ),
342
348
workQ : newWorkQ [sKey ](ctx ),
343
349
}
@@ -360,6 +366,7 @@ func (s *subscriber) handleSubscriptions() {
360
366
case sub := <- s .subscriptions :
361
367
s .storeSubscription (sub )
362
368
s .workQ .enqueue (sub .sKey )
369
+ s .querierCh <- sub
363
370
}
364
371
}
365
372
}
@@ -784,6 +791,7 @@ type querier struct {
784
791
store database.Store
785
792
786
793
newConnections chan agpl.Queue
794
+ subscriptions chan subscribe
787
795
788
796
workQ * workQ [mKey ]
789
797
@@ -812,6 +820,7 @@ func newQuerier(ctx context.Context,
812
820
store database.Store ,
813
821
self uuid.UUID ,
814
822
newConnections chan agpl.Queue ,
823
+ subscriptions chan subscribe ,
815
824
numWorkers int ,
816
825
firstHeartbeat chan struct {},
817
826
) * querier {
@@ -823,6 +832,7 @@ func newQuerier(ctx context.Context,
823
832
pubsub : ps ,
824
833
store : store ,
825
834
newConnections : newConnections ,
835
+ subscriptions : subscriptions ,
826
836
workQ : newWorkQ [mKey ](ctx ),
827
837
heartbeats : newHeartbeats (ctx , logger , ps , store , self , updates , firstHeartbeat ),
828
838
mappers : make (map [mKey ]* countedMapper ),
@@ -835,7 +845,7 @@ func newQuerier(ctx context.Context,
835
845
836
846
go func () {
837
847
<- firstHeartbeat
838
- go q .handleNewConnections ()
848
+ go q .handleIncoming ()
839
849
for i := 0 ; i < numWorkers ; i ++ {
840
850
go q .worker ()
841
851
}
@@ -844,11 +854,12 @@ func newQuerier(ctx context.Context,
844
854
return q
845
855
}
846
856
847
- func (q * querier ) handleNewConnections () {
857
+ func (q * querier ) handleIncoming () {
848
858
for {
849
859
select {
850
860
case <- q .ctx .Done ():
851
861
return
862
+
852
863
case c := <- q .newConnections :
853
864
switch c .Kind () {
854
865
case agpl .QueueKindAgent :
@@ -858,6 +869,13 @@ func (q *querier) handleNewConnections() {
858
869
default :
859
870
panic (fmt .Sprint ("unreachable: invalid queue kind " , c .Kind ()))
860
871
}
872
+
873
+ case sub := <- q .subscriptions :
874
+ if sub .active {
875
+ q .newClientSubscription (sub .q , sub .agentID )
876
+ } else {
877
+ q .removeClientSubscription (sub .q , sub .agentID )
878
+ }
861
879
}
862
880
}
863
881
}
@@ -905,6 +923,11 @@ func (q *querier) newClientSubscription(c agpl.Queue, agentID uuid.UUID) {
905
923
if _ , ok := q .clientSubscriptions [c .UniqueID ()]; ! ok {
906
924
q .clientSubscriptions [c .UniqueID ()] = map [uuid.UUID ]struct {}{}
907
925
}
926
+ fmt .Println ("CREATEDC SUBSCRIPTION" , c .UniqueID (), agentID )
927
+ fmt .Println ("CREATEDC SUBSCRIPTION" , c .UniqueID (), agentID )
928
+ fmt .Println ("CREATEDC SUBSCRIPTION" , c .UniqueID (), agentID )
929
+ fmt .Println ("CREATEDC SUBSCRIPTION" , c .UniqueID (), agentID )
930
+ fmt .Println ("CREATEDC SUBSCRIPTION" , c .UniqueID (), agentID )
908
931
909
932
mk := mKey {
910
933
agent : agentID ,
@@ -934,6 +957,12 @@ func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) {
934
957
q .mu .Lock ()
935
958
defer q .mu .Unlock ()
936
959
960
+ // agentID: uuid.Nil indicates that a client is going away. The querier
961
+ // handles that in cleanupConn below instead.
962
+ if agentID == uuid .Nil {
963
+ return
964
+ }
965
+
937
966
mk := mKey {
938
967
agent : agentID ,
939
968
kind : agpl .QueueKindClient ,
@@ -948,6 +977,9 @@ func (q *querier) removeClientSubscription(c agpl.Queue, agentID uuid.UUID) {
948
977
cm .cancel ()
949
978
delete (q .mappers , mk )
950
979
}
980
+ if len (q .clientSubscriptions [c .UniqueID ()]) == 0 {
981
+ delete (q .clientSubscriptions , c .UniqueID ())
982
+ }
951
983
}
952
984
953
985
func (q * querier ) newClientConn (c agpl.Queue ) {
@@ -982,18 +1014,17 @@ func (q *querier) cleanupConn(c agpl.Queue) {
982
1014
agent : agentID ,
983
1015
kind : c .Kind (),
984
1016
}
985
- cm , ok := q .mappers [mk ]
986
- if ok {
987
- if err := sendCtx (cm .ctx , cm .del , c ); err != nil {
988
- continue
989
- }
990
- cm .count --
991
- if cm .count == 0 {
992
- cm .cancel ()
993
- delete (q .mappers , mk )
994
- }
1017
+ cm := q .mappers [mk ]
1018
+ if err := sendCtx (cm .ctx , cm .del , c ); err != nil {
1019
+ continue
1020
+ }
1021
+ cm .count --
1022
+ if cm .count == 0 {
1023
+ cm .cancel ()
1024
+ delete (q .mappers , mk )
995
1025
}
996
1026
}
1027
+ delete (q .clientSubscriptions , c .UniqueID ())
997
1028
998
1029
mk := mKey {
999
1030
agent : c .UniqueID (),
@@ -1190,28 +1221,26 @@ func (q *querier) listenClient(_ context.Context, msg []byte, err error) {
1190
1221
q .logger .Warn (q .ctx , "unhandled pubsub error" , slog .Error (err ))
1191
1222
return
1192
1223
}
1193
- client , agents , err := parseClientUpdate (string (msg ))
1224
+ client , agent , err := parseClientUpdate (string (msg ))
1194
1225
if err != nil {
1195
1226
q .logger .Error (q .ctx , "failed to parse client update" , slog .F ("msg" , string (msg )), slog .Error (err ))
1196
1227
return
1197
1228
}
1198
- logger := q .logger .With (slog .F ("client_id" , client ))
1229
+ logger := q .logger .With (slog .F ("client_id" , client ), slog . F ( "agent_id" , agent ) )
1199
1230
logger .Debug (q .ctx , "got client update" )
1200
- for _ , agentID := range agents {
1201
- logger := q .logger .With (slog .F ("agent_id" , agentID ))
1202
- mk := mKey {
1203
- agent : agentID ,
1204
- kind : agpl .QueueKindAgent ,
1205
- }
1206
- q .mu .Lock ()
1207
- _ , ok := q .mappers [mk ]
1208
- q .mu .Unlock ()
1209
- if ! ok {
1210
- logger .Debug (q .ctx , "ignoring update because we have no mapper" )
1211
- return
1212
- }
1213
- q .workQ .enqueue (mk )
1231
+
1232
+ mk := mKey {
1233
+ agent : agent ,
1234
+ kind : agpl .QueueKindAgent ,
1214
1235
}
1236
+ q .mu .Lock ()
1237
+ _ , ok := q .mappers [mk ]
1238
+ q .mu .Unlock ()
1239
+ if ! ok {
1240
+ logger .Debug (q .ctx , "ignoring update because we have no mapper" )
1241
+ return
1242
+ }
1243
+ q .workQ .enqueue (mk )
1215
1244
}
1216
1245
1217
1246
func (q * querier ) listenAgent (_ context.Context , msg []byte , err error ) {
@@ -1348,27 +1377,22 @@ func (q *querier) getAll(ctx context.Context) (map[uuid.UUID]database.TailnetAge
1348
1377
return agentsMap , clientsMap , nil
1349
1378
}
1350
1379
1351
- func parseClientUpdate (msg string ) (client uuid. UUID , agents [] uuid.UUID , err error ) {
1380
+ func parseClientUpdate (msg string ) (client , agent uuid.UUID , err error ) {
1352
1381
parts := strings .Split (msg , "," )
1353
1382
if len (parts ) != 2 {
1354
- return uuid .Nil , nil , xerrors .Errorf ("expected 2 parts separated by comma" )
1383
+ return uuid .Nil , uuid . Nil , xerrors .Errorf ("expected 2 parts separated by comma" )
1355
1384
}
1356
1385
client , err = uuid .Parse (parts [0 ])
1357
1386
if err != nil {
1358
- return uuid .Nil , nil , xerrors .Errorf ("failed to parse client UUID: %w" , err )
1387
+ return uuid .Nil , uuid . Nil , xerrors .Errorf ("failed to parse client UUID: %w" , err )
1359
1388
}
1360
1389
1361
- agents = []uuid.UUID {}
1362
- for _ , agentStr := range parts [1 :] {
1363
- agent , err := uuid .Parse (agentStr )
1364
- if err != nil {
1365
- return uuid .Nil , nil , xerrors .Errorf ("failed to parse agent UUID: %w" , err )
1366
- }
1367
-
1368
- agents = append (agents , agent )
1390
+ agent , err = uuid .Parse (parts [1 ])
1391
+ if err != nil {
1392
+ return uuid .Nil , uuid .Nil , xerrors .Errorf ("failed to parse agent UUID: %w" , err )
1369
1393
}
1370
1394
1371
- return client , agents , nil
1395
+ return client , agent , nil
1372
1396
}
1373
1397
1374
1398
func parseUpdateMessage (msg string ) (agent uuid.UUID , err error ) {
0 commit comments