
feat: modify PG Coordinator to work with new v2 Tailnet API #10573


Merged
merged 1 commit on Nov 20, 2023
Changes from all commits
224 changes: 159 additions & 65 deletions enterprise/tailnet/connio.go
@@ -2,136 +2,230 @@ package tailnet

import (
    "context"
    "io"
    "sync"
    "sync/atomic"
    "time"

    "github.com/google/uuid"
    "golang.org/x/xerrors"

    "cdr.dev/slog"

    agpl "github.com/coder/coder/v2/tailnet"
    "github.com/coder/coder/v2/tailnet/proto"
)

// connIO manages the reading and writing to a connected peer. It reads requests via its requests
// channel, then pushes them onto the bindings or tunnels channel. It receives responses via calls
// to Enqueue and pushes them onto the responses channel.
type connIO struct {
    id uuid.UUID
    // coordCtx is the parent context, that is, the context of the Coordinator
    coordCtx context.Context
    // peerCtx is the context of the connection to our peer
    peerCtx   context.Context
    cancel    context.CancelFunc
    logger    slog.Logger
    requests  <-chan *proto.CoordinateRequest
    responses chan<- *proto.CoordinateResponse
    bindings  chan<- binding
    tunnels   chan<- tunnel
    auth      agpl.TunnelAuth
    mu        sync.Mutex
    closed    bool

    name       string
    start      int64
    lastWrite  int64
    overwrites int64
}

func newConnIO(coordContext context.Context,
    peerCtx context.Context,
    logger slog.Logger,
    bindings chan<- binding,
    tunnels chan<- tunnel,
    requests <-chan *proto.CoordinateRequest,
    responses chan<- *proto.CoordinateResponse,
    id uuid.UUID,
    name string,
    auth agpl.TunnelAuth,
) *connIO {
    peerCtx, cancel := context.WithCancel(peerCtx)
    now := time.Now().Unix()
    c := &connIO{
        id:        id,
        coordCtx:  coordContext,
        peerCtx:   peerCtx,
        cancel:    cancel,
        logger:    logger.With(slog.F("name", name)),
        requests:  requests,
        responses: responses,
        bindings:  bindings,
        tunnels:   tunnels,
        auth:      auth,
        name:      name,
        start:     now,
        lastWrite: now,
    }
    go c.recvLoop()
    c.logger.Info(coordContext, "serving connection")
    return c
}

func (c *connIO) recvLoop() {
    defer func() {
        // withdraw bindings & tunnels when we exit. We need to use the parent context here, since
        // our own context might be canceled, but we still need to withdraw.
        b := binding{
            bKey: bKey(c.UniqueID()),
        }
        if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
            c.logger.Debug(c.coordCtx, "parent context expired while withdrawing bindings", slog.Error(err))
        }
        t := tunnel{
            tKey:   tKey{src: c.UniqueID()},
            active: false,
        }
        if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
            c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err))
        }
    }()
    defer c.Close()
    for {
        req, err := recvCtx(c.peerCtx, c.requests)
        if err != nil {
            if xerrors.Is(err, context.Canceled) ||
                xerrors.Is(err, context.DeadlineExceeded) ||
                xerrors.Is(err, io.EOF) {
                c.logger.Debug(c.coordCtx, "exiting io recvLoop", slog.Error(err))
            } else {
                c.logger.Error(c.coordCtx, "failed to receive request", slog.Error(err))
            }
            return
        }
        if err := c.handleRequest(req); err != nil {
            return
        }
    }
}

func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
    c.logger.Debug(c.peerCtx, "got request")
    if req.UpdateSelf != nil {
        c.logger.Debug(c.peerCtx, "got node update", slog.F("node", req.UpdateSelf))
        b := binding{
            bKey: bKey(c.UniqueID()),
            node: req.UpdateSelf.Node,
        }
        if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
            c.logger.Debug(c.peerCtx, "failed to send binding", slog.Error(err))
            return err
        }
    }
    if req.AddTunnel != nil {
        c.logger.Debug(c.peerCtx, "got add tunnel", slog.F("tunnel", req.AddTunnel))
        dst, err := uuid.FromBytes(req.AddTunnel.Uuid)
        if err != nil {
            c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err))
            // this shouldn't happen unless there is a client error. Close the connection so the client
            // doesn't just happily continue thinking everything is fine.
            return err
        }
        if !c.auth.Authorize(dst) {
            return xerrors.New("unauthorized tunnel")
        }
        t := tunnel{
            tKey: tKey{
                src: c.UniqueID(),
                dst: dst,
            },
            active: true,
        }
        if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
            c.logger.Debug(c.peerCtx, "failed to send add tunnel", slog.Error(err))
            return err
        }
    }
    if req.RemoveTunnel != nil {
        c.logger.Debug(c.peerCtx, "got remove tunnel", slog.F("tunnel", req.RemoveTunnel))
        dst, err := uuid.FromBytes(req.RemoveTunnel.Uuid)
        if err != nil {
            c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err))
            // this shouldn't happen unless there is a client error. Close the connection so the client
            // doesn't just happily continue thinking everything is fine.
            return err
        }
        t := tunnel{
            tKey: tKey{
                src: c.UniqueID(),
                dst: dst,
            },
            active: false,
        }
        if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
            c.logger.Debug(c.peerCtx, "failed to send remove tunnel", slog.Error(err))
            return err
        }
    }
    // TODO: (spikecurtis) support Disconnect
    return nil
}

func (c *connIO) UniqueID() uuid.UUID {
    return c.id
}

func (c *connIO) Enqueue(resp *proto.CoordinateResponse) error {
    atomic.StoreInt64(&c.lastWrite, time.Now().Unix())
    c.mu.Lock()
    closed := c.closed
    c.mu.Unlock()
    if closed {
        return xerrors.New("connIO closed")
    }
    select {
    case <-c.peerCtx.Done():
        return c.peerCtx.Err()
    case c.responses <- resp:
        c.logger.Debug(c.peerCtx, "wrote response")
        return nil
    default:
        return agpl.ErrWouldBlock
    }
}

func (c *connIO) Name() string {
    return c.name
}

func (c *connIO) Stats() (start int64, lastWrite int64) {
    return c.start, atomic.LoadInt64(&c.lastWrite)
}

func (c *connIO) Overwrites() int64 {
    return atomic.LoadInt64(&c.overwrites)
}

// CoordinatorClose is used by the coordinator when closing a Queue. It
// should skip removing itself from the coordinator.
func (c *connIO) CoordinatorClose() error {
    return c.Close()
}

func (c *connIO) Done() <-chan struct{} {
    return c.peerCtx.Done()
}

func (c *connIO) Close() error {
    c.mu.Lock()
    defer c.mu.Unlock()
    if c.closed {
        return nil
    }
    c.cancel()
    c.closed = true
    close(c.responses)
    return nil
}
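For readers of this diff, here is a minimal sketch of the request shapes a peer pushes onto the requests channel that connIO reads from, assuming the generated proto types referenced above (proto.CoordinateRequest_UpdateSelf, proto.CoordinateRequest_Tunnel, proto.Node). The package name, helper name, and the empty Node are illustrative only and not part of this PR.

package example

import (
    "github.com/google/uuid"

    "github.com/coder/coder/v2/tailnet/proto"
)

// exampleRequests builds the two request shapes connIO.handleRequest switches on:
// an UpdateSelf carrying this peer's node, and an AddTunnel toward dst.
func exampleRequests(dst uuid.UUID) []*proto.CoordinateRequest {
    return []*proto.CoordinateRequest{
        {UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{}}},
        {AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: dst[:]}},
    }
}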
30 changes: 12 additions & 18 deletions enterprise/tailnet/multiagent_test.go
@@ -27,11 +27,10 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
        t.Skip("test only with postgres")
    }

    logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
    store, ps := dbtestutil.NewDB(t)
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
    defer cancel()
    coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
    require.NoError(t, err)
    defer coord1.Close()
@@ -75,11 +74,10 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
        t.Skip("test only with postgres")
    }

    logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
    store, ps := dbtestutil.NewDB(t)
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
    defer cancel()
    coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
    require.NoError(t, err)
    defer coord1.Close()
@@ -124,11 +122,10 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
        t.Skip("test only with postgres")
    }

    logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
    store, ps := dbtestutil.NewDB(t)
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
    defer cancel()
    coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
    require.NoError(t, err)
    defer coord1.Close()
@@ -189,11 +186,10 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
        t.Skip("test only with postgres")
    }

    logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
    store, ps := dbtestutil.NewDB(t)
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
    defer cancel()
    coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
    require.NoError(t, err)
    defer coord1.Close()
@@ -243,11 +239,10 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *testing.T) {
        t.Skip("test only with postgres")
    }

    logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
    store, ps := dbtestutil.NewDB(t)
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
    defer cancel()
    coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
    require.NoError(t, err)
    defer coord1.Close()
@@ -299,11 +294,10 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
        t.Skip("test only with postgres")
    }

    logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
    store, ps := dbtestutil.NewDB(t)
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
    defer cancel()
    coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
    require.NoError(t, err)
    defer coord1.Close()