Skip to content

feat(enterprise): add ready for handshake support to pgcoord #12935

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion coderd/database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func (q *sqlQuerier) InTx(function func(Store) error, txOpts *sql.TxOptions) err
// Transaction succeeded.
return nil
}
if err != nil && !IsSerializedError(err) {
if !IsSerializedError(err) {
// We should only retry if the error is a serialization error.
return err
}
Expand Down
52 changes: 52 additions & 0 deletions enterprise/tailnet/connio.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package tailnet

import (
"context"
"fmt"
"slices"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -30,10 +32,13 @@ type connIO struct {
responses chan<- *proto.CoordinateResponse
bindings chan<- binding
tunnels chan<- tunnel
rfhs chan<- readyForHandshake
auth agpl.CoordinateeAuth
mu sync.Mutex
closed bool
disconnected bool
// latest is the most recent, unfiltered snapshot of the mappings we know about
latest []mapping

name string
start int64
Expand All @@ -46,6 +51,7 @@ func newConnIO(coordContext context.Context,
logger slog.Logger,
bindings chan<- binding,
tunnels chan<- tunnel,
rfhs chan<- readyForHandshake,
requests <-chan *proto.CoordinateRequest,
responses chan<- *proto.CoordinateResponse,
id uuid.UUID,
Expand All @@ -64,6 +70,7 @@ func newConnIO(coordContext context.Context,
responses: responses,
bindings: bindings,
tunnels: tunnels,
rfhs: rfhs,
auth: auth,
name: name,
start: now,
Expand Down Expand Up @@ -190,9 +197,54 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
c.disconnected = true
return errDisconnect
}
if req.ReadyForHandshake != nil {
c.logger.Debug(c.peerCtx, "got ready for handshake ", slog.F("rfh", req.ReadyForHandshake))
for _, rfh := range req.ReadyForHandshake {
dst, err := uuid.FromBytes(rfh.Id)
if err != nil {
c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err))
// this shouldn't happen unless there is a client error. Close the connection so the client
// doesn't just happily continue thinking everything is fine.
return err
}

mappings := c.getLatestMapping()
if !slices.ContainsFunc(mappings, func(mapping mapping) bool {
return mapping.peer == dst
}) {
c.logger.Debug(c.peerCtx, "cannot process ready for handshake, src isn't peered with dst",
slog.F("dst", dst.String()),
)
_ = c.Enqueue(&proto.CoordinateResponse{
Error: fmt.Sprintf("you do not share a tunnel with %q", dst.String()),
})
return nil
}

if err := agpl.SendCtx(c.coordCtx, c.rfhs, readyForHandshake{
src: c.id,
dst: dst,
}); err != nil {
c.logger.Debug(c.peerCtx, "failed to send ready for handshake", slog.Error(err))
return err
}
}
}
return nil
}

func (c *connIO) setLatestMapping(latest []mapping) {
c.mu.Lock()
defer c.mu.Unlock()
c.latest = latest
}

func (c *connIO) getLatestMapping() []mapping {
c.mu.Lock()
defer c.mu.Unlock()
return c.latest
}

func (c *connIO) UniqueID() uuid.UUID {
return c.id
}
Expand Down
73 changes: 73 additions & 0 deletions enterprise/tailnet/handshaker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package tailnet

import (
"context"
"fmt"
"sync"

"github.com/google/uuid"

"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database/pubsub"
)

type readyForHandshake struct {
src uuid.UUID
dst uuid.UUID
}

type handshaker struct {
ctx context.Context
logger slog.Logger
coordinatorID uuid.UUID
pubsub pubsub.Pubsub
updates <-chan readyForHandshake

workerWG sync.WaitGroup
}

func newHandshaker(ctx context.Context,
logger slog.Logger,
id uuid.UUID,
ps pubsub.Pubsub,
updates <-chan readyForHandshake,
startWorkers <-chan struct{},
) *handshaker {
s := &handshaker{
ctx: ctx,
logger: logger,
coordinatorID: id,
pubsub: ps,
updates: updates,
}
// add to the waitgroup immediately to avoid any races waiting for it before
// the workers start.
s.workerWG.Add(numHandshakerWorkers)
go func() {
<-startWorkers
for i := 0; i < numHandshakerWorkers; i++ {
go s.worker()
}
}()
return s
}

func (t *handshaker) worker() {
defer t.workerWG.Done()

for {
select {
case <-t.ctx.Done():
t.logger.Debug(t.ctx, "handshaker worker exiting", slog.Error(t.ctx.Err()))
return

case rfh := <-t.updates:
err := t.pubsub.Publish(eventReadyForHandshake, []byte(fmt.Sprintf(
"%s,%s", rfh.dst.String(), rfh.src.String(),
)))
if err != nil {
t.logger.Error(t.ctx, "publish ready for handshake", slog.Error(err))
}
}
}
}
47 changes: 47 additions & 0 deletions enterprise/tailnet/handshaker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package tailnet_test

import (
"context"
"testing"

"github.com/stretchr/testify/require"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/enterprise/tailnet"
agpltest "github.com/coder/coder/v2/tailnet/test"
"github.com/coder/coder/v2/testutil"
)

func TestPGCoordinator_ReadyForHandshake_OK(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
require.NoError(t, err)
defer coord1.Close()

agpltest.ReadyForHandshakeTest(ctx, t, coord1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice structure & code reuse

}

func TestPGCoordinator_ReadyForHandshake_NoPermission(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
require.NoError(t, err)
defer coord1.Close()

agpltest.ReadyForHandshakeNoPermissionTest(ctx, t, coord1)
}
Loading
Loading