Skip to content

Commit 777dfbe

Browse files
authored
feat(enterprise): add ready for handshake support to pgcoord (#12935)
1 parent 942e902 commit 777dfbe

File tree

10 files changed

+364
-82
lines changed

10 files changed

+364
-82
lines changed

coderd/database/db.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ func (q *sqlQuerier) InTx(function func(Store) error, txOpts *sql.TxOptions) err
103103
// Transaction succeeded.
104104
return nil
105105
}
106-
if err != nil && !IsSerializedError(err) {
106+
if !IsSerializedError(err) {
107107
// We should only retry if the error is a serialization error.
108108
return err
109109
}

enterprise/tailnet/connio.go

+52
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package tailnet
22

33
import (
44
"context"
5+
"fmt"
6+
"slices"
57
"sync"
68
"sync/atomic"
79
"time"
@@ -30,10 +32,13 @@ type connIO struct {
3032
responses chan<- *proto.CoordinateResponse
3133
bindings chan<- binding
3234
tunnels chan<- tunnel
35+
rfhs chan<- readyForHandshake
3336
auth agpl.CoordinateeAuth
3437
mu sync.Mutex
3538
closed bool
3639
disconnected bool
40+
// latest is the most recent, unfiltered snapshot of the mappings we know about
41+
latest []mapping
3742

3843
name string
3944
start int64
@@ -46,6 +51,7 @@ func newConnIO(coordContext context.Context,
4651
logger slog.Logger,
4752
bindings chan<- binding,
4853
tunnels chan<- tunnel,
54+
rfhs chan<- readyForHandshake,
4955
requests <-chan *proto.CoordinateRequest,
5056
responses chan<- *proto.CoordinateResponse,
5157
id uuid.UUID,
@@ -64,6 +70,7 @@ func newConnIO(coordContext context.Context,
6470
responses: responses,
6571
bindings: bindings,
6672
tunnels: tunnels,
73+
rfhs: rfhs,
6774
auth: auth,
6875
name: name,
6976
start: now,
@@ -190,9 +197,54 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
190197
c.disconnected = true
191198
return errDisconnect
192199
}
200+
if req.ReadyForHandshake != nil {
201+
c.logger.Debug(c.peerCtx, "got ready for handshake ", slog.F("rfh", req.ReadyForHandshake))
202+
for _, rfh := range req.ReadyForHandshake {
203+
dst, err := uuid.FromBytes(rfh.Id)
204+
if err != nil {
205+
c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err))
206+
// this shouldn't happen unless there is a client error. Close the connection so the client
207+
// doesn't just happily continue thinking everything is fine.
208+
return err
209+
}
210+
211+
mappings := c.getLatestMapping()
212+
if !slices.ContainsFunc(mappings, func(mapping mapping) bool {
213+
return mapping.peer == dst
214+
}) {
215+
c.logger.Debug(c.peerCtx, "cannot process ready for handshake, src isn't peered with dst",
216+
slog.F("dst", dst.String()),
217+
)
218+
_ = c.Enqueue(&proto.CoordinateResponse{
219+
Error: fmt.Sprintf("you do not share a tunnel with %q", dst.String()),
220+
})
221+
return nil
222+
}
223+
224+
if err := agpl.SendCtx(c.coordCtx, c.rfhs, readyForHandshake{
225+
src: c.id,
226+
dst: dst,
227+
}); err != nil {
228+
c.logger.Debug(c.peerCtx, "failed to send ready for handshake", slog.Error(err))
229+
return err
230+
}
231+
}
232+
}
193233
return nil
194234
}
195235

236+
func (c *connIO) setLatestMapping(latest []mapping) {
237+
c.mu.Lock()
238+
defer c.mu.Unlock()
239+
c.latest = latest
240+
}
241+
242+
func (c *connIO) getLatestMapping() []mapping {
243+
c.mu.Lock()
244+
defer c.mu.Unlock()
245+
return c.latest
246+
}
247+
196248
// UniqueID returns the value of this connIO's id field.
func (c *connIO) UniqueID() uuid.UUID {
	return c.id
}

enterprise/tailnet/handshaker.go

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package tailnet
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sync"
7+
8+
"github.com/google/uuid"
9+
10+
"cdr.dev/slog"
11+
"github.com/coder/coder/v2/coderd/database/pubsub"
12+
)
13+
14+
type readyForHandshake struct {
15+
src uuid.UUID
16+
dst uuid.UUID
17+
}
18+
19+
type handshaker struct {
20+
ctx context.Context
21+
logger slog.Logger
22+
coordinatorID uuid.UUID
23+
pubsub pubsub.Pubsub
24+
updates <-chan readyForHandshake
25+
26+
workerWG sync.WaitGroup
27+
}
28+
29+
func newHandshaker(ctx context.Context,
30+
logger slog.Logger,
31+
id uuid.UUID,
32+
ps pubsub.Pubsub,
33+
updates <-chan readyForHandshake,
34+
startWorkers <-chan struct{},
35+
) *handshaker {
36+
s := &handshaker{
37+
ctx: ctx,
38+
logger: logger,
39+
coordinatorID: id,
40+
pubsub: ps,
41+
updates: updates,
42+
}
43+
// add to the waitgroup immediately to avoid any races waiting for it before
44+
// the workers start.
45+
s.workerWG.Add(numHandshakerWorkers)
46+
go func() {
47+
<-startWorkers
48+
for i := 0; i < numHandshakerWorkers; i++ {
49+
go s.worker()
50+
}
51+
}()
52+
return s
53+
}
54+
55+
func (t *handshaker) worker() {
56+
defer t.workerWG.Done()
57+
58+
for {
59+
select {
60+
case <-t.ctx.Done():
61+
t.logger.Debug(t.ctx, "handshaker worker exiting", slog.Error(t.ctx.Err()))
62+
return
63+
64+
case rfh := <-t.updates:
65+
err := t.pubsub.Publish(eventReadyForHandshake, []byte(fmt.Sprintf(
66+
"%s,%s", rfh.dst.String(), rfh.src.String(),
67+
)))
68+
if err != nil {
69+
t.logger.Error(t.ctx, "publish ready for handshake", slog.Error(err))
70+
}
71+
}
72+
}
73+
}

enterprise/tailnet/handshaker_test.go

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package tailnet_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
9+
"cdr.dev/slog"
10+
"cdr.dev/slog/sloggers/slogtest"
11+
"github.com/coder/coder/v2/coderd/database/dbtestutil"
12+
"github.com/coder/coder/v2/enterprise/tailnet"
13+
agpltest "github.com/coder/coder/v2/tailnet/test"
14+
"github.com/coder/coder/v2/testutil"
15+
)
16+
17+
func TestPGCoordinator_ReadyForHandshake_OK(t *testing.T) {
18+
t.Parallel()
19+
if !dbtestutil.WillUsePostgres() {
20+
t.Skip("test only with postgres")
21+
}
22+
store, ps := dbtestutil.NewDB(t)
23+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
24+
defer cancel()
25+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
26+
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
27+
require.NoError(t, err)
28+
defer coord1.Close()
29+
30+
agpltest.ReadyForHandshakeTest(ctx, t, coord1)
31+
}
32+
33+
func TestPGCoordinator_ReadyForHandshake_NoPermission(t *testing.T) {
34+
t.Parallel()
35+
if !dbtestutil.WillUsePostgres() {
36+
t.Skip("test only with postgres")
37+
}
38+
store, ps := dbtestutil.NewDB(t)
39+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
40+
defer cancel()
41+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
42+
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
43+
require.NoError(t, err)
44+
defer coord1.Close()
45+
46+
agpltest.ReadyForHandshakeNoPermissionTest(ctx, t, coord1)
47+
}

0 commit comments

Comments
 (0)