Skip to content

Commit ce96564

Browse files
committed
feat(enterprise): add ready for handshake support to pgcoord
1 parent e801e87 commit ce96564

File tree

14 files changed

+442
-33
lines changed

14 files changed

+442
-33
lines changed

coderd/database/db.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ func (q *sqlQuerier) InTx(function func(Store) error, txOpts *sql.TxOptions) err
103103
// Transaction succeeded.
104104
return nil
105105
}
106-
if err != nil && !IsSerializedError(err) {
106+
if !IsSerializedError(err) {
107107
// We should only retry if the error is a serialization error.
108108
return err
109109
}

coderd/database/dbauthz/dbauthz.go

+7
Original file line numberDiff line numberDiff line change
@@ -2645,6 +2645,13 @@ func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID
26452645
return q.db.ListWorkspaceAgentPortShares(ctx, workspaceID)
26462646
}
26472647

2648+
func (q *querier) PublishReadyForHandshake(ctx context.Context, arg database.PublishReadyForHandshakeParams) error {
2649+
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
2650+
return err
2651+
}
2652+
return q.db.PublishReadyForHandshake(ctx, arg)
2653+
}
2654+
26482655
func (q *querier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
26492656
template, err := q.db.GetTemplateByID(ctx, templateID)
26502657
if err != nil {

coderd/database/dbauthz/dbauthz_test.go

+5
Original file line numberDiff line numberDiff line change
@@ -1829,6 +1829,11 @@ func (s *MethodTestSuite) TestTailnetFunctions() {
18291829
Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionCreate).
18301830
Errors(dbmem.ErrUnimplemented)
18311831
}))
1832+
s.Run("PublishReadyForHandshake", s.Subtest(func(db database.Store, check *expects) {
1833+
check.Args(database.PublishReadyForHandshakeParams{}).
1834+
Asserts(rbac.ResourceTailnetCoordinator, rbac.ActionUpdate).
1835+
Errors(dbmem.ErrUnimplemented)
1836+
}))
18321837
}
18331838

18341839
func (s *MethodTestSuite) TestDBCrypt() {

coderd/database/dbmem/dbmem.go

+4
Original file line numberDiff line numberDiff line change
@@ -6742,6 +6742,10 @@ func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceI
67426742
return shares, nil
67436743
}
67446744

6745+
// PublishReadyForHandshake is not supported by the in-memory fake: it ignores
// its arguments and always returns ErrUnimplemented.
func (*FakeQuerier) PublishReadyForHandshake(context.Context, database.PublishReadyForHandshakeParams) error {
6746+
return ErrUnimplemented
6747+
}
6748+
67456749
func (q *FakeQuerier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(_ context.Context, templateID uuid.UUID) error {
67466750
err := validateDatabaseType(templateID)
67476751
if err != nil {

coderd/database/dbmetrics/dbmetrics.go

+7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

+14
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/querier.go

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

+17
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/tailnet.sql

+6
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,12 @@ FROM tailnet_tunnels
207207
INNER JOIN tailnet_peers ON tailnet_tunnels.src_id = tailnet_peers.id
208208
WHERE tailnet_tunnels.dst_id = $1;
209209

210+
-- Sends a notification on the 'tailnet_ready_for_handshake' channel. The
-- payload is the "to" argument, a comma, then the "from" argument, both as text.
-- name: PublishReadyForHandshake :exec
211+
SELECT pg_notify(
212+
'tailnet_ready_for_handshake',
213+
format('%s,%s', sqlc.arg('to')::text, sqlc.arg('from')::text)
214+
);
215+
210216
-- For PG Coordinator HTMLDebug
211217

212218
-- name: GetAllTailnetCoordinators :many

enterprise/tailnet/connio.go

+23
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ type connIO struct {
3030
responses chan<- *proto.CoordinateResponse
3131
bindings chan<- binding
3232
tunnels chan<- tunnel
33+
rfhs chan<- readyForHandshake
3334
auth agpl.CoordinateeAuth
3435
mu sync.Mutex
3536
closed bool
@@ -46,6 +47,7 @@ func newConnIO(coordContext context.Context,
4647
logger slog.Logger,
4748
bindings chan<- binding,
4849
tunnels chan<- tunnel,
50+
rfhs chan<- readyForHandshake,
4951
requests <-chan *proto.CoordinateRequest,
5052
responses chan<- *proto.CoordinateResponse,
5153
id uuid.UUID,
@@ -64,6 +66,7 @@ func newConnIO(coordContext context.Context,
6466
responses: responses,
6567
bindings: bindings,
6668
tunnels: tunnels,
69+
rfhs: rfhs,
6770
auth: auth,
6871
name: name,
6972
start: now,
@@ -190,6 +193,26 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
190193
c.disconnected = true
191194
return errDisconnect
192195
}
196+
if req.ReadyForHandshake != nil {
197+
c.logger.Debug(c.peerCtx, "got ready for handshake ", slog.F("rfh", req.ReadyForHandshake))
198+
for _, rfh := range req.ReadyForHandshake {
199+
dst, err := uuid.FromBytes(rfh.Id)
200+
if err != nil {
201+
c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err))
202+
// this shouldn't happen unless there is a client error. Close the connection so the client
203+
// doesn't just happily continue thinking everything is fine.
204+
return err
205+
}
206+
207+
if err := agpl.SendCtx(c.coordCtx, c.rfhs, readyForHandshake{hKey: hKey{
208+
src: c.id,
209+
dst: dst,
210+
}}); err != nil {
211+
c.logger.Debug(c.peerCtx, "failed to send ready for handshake", slog.Error(err))
212+
return err
213+
}
214+
}
215+
}
193216
return nil
194217
}
195218

enterprise/tailnet/handshaker.go

+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package tailnet
2+
3+
import (
4+
"context"
5+
"slices"
6+
"sync"
7+
8+
"github.com/cenkalti/backoff/v4"
9+
"github.com/google/uuid"
10+
11+
"cdr.dev/slog"
12+
"github.com/coder/coder/v2/coderd/database"
13+
)
14+
15+
type readyForHandshake struct {
16+
hKey
17+
}
18+
19+
type hKey struct {
20+
src uuid.UUID
21+
dst uuid.UUID
22+
}
23+
24+
type handshaker struct {
25+
ctx context.Context
26+
logger slog.Logger
27+
coordinatorID uuid.UUID
28+
store database.Store
29+
updates <-chan readyForHandshake
30+
31+
workQ *workQ[hKey]
32+
33+
workerWG sync.WaitGroup
34+
}
35+
36+
func newHandshaker(ctx context.Context,
37+
logger slog.Logger,
38+
id uuid.UUID,
39+
store database.Store,
40+
updates <-chan readyForHandshake,
41+
startWorkers <-chan struct{},
42+
) *handshaker {
43+
s := &handshaker{
44+
ctx: ctx,
45+
logger: logger,
46+
coordinatorID: id,
47+
store: store,
48+
updates: updates,
49+
workQ: newWorkQ[hKey](ctx),
50+
}
51+
go s.handle()
52+
// add to the waitgroup immediately to avoid any races waiting for it before
53+
// the workers start.
54+
s.workerWG.Add(numHandshakerWorkers)
55+
go func() {
56+
<-startWorkers
57+
for i := 0; i < numHandshakerWorkers; i++ {
58+
go s.worker()
59+
}
60+
}()
61+
return s
62+
}
63+
64+
func (t *handshaker) handle() {
65+
for {
66+
select {
67+
case <-t.ctx.Done():
68+
t.logger.Debug(t.ctx, "handshaker exiting", slog.Error(t.ctx.Err()))
69+
return
70+
case rfh := <-t.updates:
71+
t.workQ.enqueue(rfh.hKey)
72+
}
73+
}
74+
}
75+
76+
func (t *handshaker) worker() {
77+
defer t.workerWG.Done()
78+
eb := backoff.NewExponentialBackOff()
79+
eb.MaxElapsedTime = 0 // retry indefinitely
80+
eb.MaxInterval = dbMaxBackoff
81+
bkoff := backoff.WithContext(eb, t.ctx)
82+
for {
83+
hk, err := t.workQ.acquire()
84+
if err != nil {
85+
// context expired
86+
return
87+
}
88+
err = backoff.Retry(func() error {
89+
return t.writeOne(hk)
90+
}, bkoff)
91+
if err != nil {
92+
bkoff.Reset()
93+
}
94+
t.workQ.done(hk)
95+
}
96+
}
97+
98+
func (t *handshaker) writeOne(hk hKey) error {
99+
logger := t.logger.With(
100+
slog.F("src_id", hk.src),
101+
slog.F("dst_id", hk.dst),
102+
)
103+
104+
peers, err := t.store.GetTailnetTunnelPeerIDs(t.ctx, hk.src)
105+
if err != nil {
106+
if !database.IsQueryCanceledError(err) {
107+
logger.Error(t.ctx, "get tunnel peers ids", slog.Error(err))
108+
}
109+
return err
110+
}
111+
112+
if !slices.ContainsFunc(peers, func(peer database.GetTailnetTunnelPeerIDsRow) bool {
113+
return peer.PeerID == hk.dst
114+
}) {
115+
// In the in-memory coordinator we return an error to the client, but
116+
// this isn't really possible here.
117+
logger.Warn(t.ctx, "cannot process ready for handshake, src isn't peered with dst")
118+
return nil
119+
}
120+
121+
err = t.store.PublishReadyForHandshake(t.ctx, database.PublishReadyForHandshakeParams{
122+
To: hk.dst.String(),
123+
From: hk.src.String(),
124+
})
125+
if err != nil {
126+
if !database.IsQueryCanceledError(err) {
127+
logger.Error(t.ctx, "publish ready for handshake", slog.Error(err))
128+
}
129+
return err
130+
}
131+
132+
return nil
133+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package tailnet
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/google/uuid"
8+
"go.uber.org/mock/gomock"
9+
10+
"cdr.dev/slog"
11+
"cdr.dev/slog/sloggers/slogtest"
12+
"github.com/coder/coder/v2/coderd/database"
13+
"github.com/coder/coder/v2/coderd/database/dbmock"
14+
"github.com/coder/coder/v2/testutil"
15+
)
16+
17+
func Test_handshaker_NoPermission(t *testing.T) {
18+
t.Parallel()
19+
20+
ctrl := gomock.NewController(t)
21+
defer ctrl.Finish()
22+
mDB := dbmock.NewMockStore(ctrl)
23+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
24+
defer cancel()
25+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
26+
27+
rfhCh := make(chan readyForHandshake)
28+
ready := make(chan struct{})
29+
close(ready)
30+
31+
srcID, dstID := uuid.New(), uuid.New()
32+
33+
newHandshaker(ctx, logger, uuid.New(), mDB, rfhCh, ready)
34+
35+
called := make(chan struct{})
36+
mDB.EXPECT().GetTailnetTunnelPeerIDs(gomock.Any(), srcID).
37+
DoAndReturn(func(context.Context, uuid.UUID) ([]database.GetTailnetTunnelPeerIDsRow, error) {
38+
close(called)
39+
return []database.GetTailnetTunnelPeerIDsRow{}, nil
40+
})
41+
rfhCh <- readyForHandshake{hKey{src: srcID, dst: dstID}}
42+
<-called
43+
// the handshaker should not attempt to broadcast the rfh. if it does, the
44+
// mock will catch an unmocked call.
45+
}

0 commit comments

Comments
 (0)