
Commit 5c48cb4

feat: modify PG Coordinator to work with new v2 Tailnet API (#10573)
re: #10528 Refactors PG Coordinator to work with the Tailnet v2 API, including wrappers for the existing v1 API. The debug endpoint functions but doesn't return sensible data; that will be addressed in another stacked PR.
1 parent: a8c2518 · commit: 5c48cb4
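For orientation, a rough sketch (not code from this commit) of how the v2 requests that the new connIO handles might be constructed. It is inferred from the fields the diff below reads (UpdateSelf.Node, AddTunnel.Uuid, RemoveTunnel.Uuid); the nested generated type names are assumptions, not confirmed by this diff.

    package tailnet

    import (
    	"github.com/google/uuid"

    	"github.com/coder/coder/v2/tailnet/proto"
    )

    // Hypothetical constructors for v2 CoordinateRequests. The wrapper type
    // names CoordinateRequest_UpdateSelf and CoordinateRequest_Tunnel are
    // assumed names for the generated proto messages.
    func updateSelfReq(node *proto.Node) *proto.CoordinateRequest {
    	return &proto.CoordinateRequest{
    		UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: node},
    	}
    }

    func addTunnelReq(dst uuid.UUID) *proto.CoordinateRequest {
    	// AddTunnel carries the destination peer's UUID as raw bytes, which
    	// connIO.handleRequest converts back with uuid.FromBytes.
    	return &proto.CoordinateRequest{
    		AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: dst[:]},
    	}
    }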

File tree: 8 files changed, +1024 −805 lines


enterprise/tailnet/connio.go

Lines changed: 159 additions & 65 deletions
@@ -2,136 +2,230 @@ package tailnet
 
 import (
 	"context"
-	"encoding/json"
 	"io"
-	"net"
+	"sync"
+	"sync/atomic"
+	"time"
 
 	"github.com/google/uuid"
 	"golang.org/x/xerrors"
-	"nhooyr.io/websocket"
 
 	"cdr.dev/slog"
+
 	agpl "github.com/coder/coder/v2/tailnet"
+	"github.com/coder/coder/v2/tailnet/proto"
 )
 
-// connIO manages the reading and writing to a connected client or agent. Agent connIOs have their client field set to
-// uuid.Nil. It reads node updates via its decoder, then pushes them onto the bindings channel. It receives mappings
-// via its updates TrackedConn, which then writes them.
+// connIO manages the reading and writing to a connected peer. It reads requests via its requests
+// channel, then pushes them onto the bindings or tunnels channel. It receives responses via calls
+// to Enqueue and pushes them onto the responses channel.
 type connIO struct {
-	pCtx     context.Context
-	ctx      context.Context
-	cancel   context.CancelFunc
-	logger   slog.Logger
-	decoder  *json.Decoder
-	updates  *agpl.TrackedConn
-	bindings chan<- binding
+	id uuid.UUID
+	// coordCtx is the parent context, that is, the context of the Coordinator
+	coordCtx context.Context
+	// peerCtx is the context of the connection to our peer
+	peerCtx   context.Context
+	cancel    context.CancelFunc
+	logger    slog.Logger
+	requests  <-chan *proto.CoordinateRequest
+	responses chan<- *proto.CoordinateResponse
+	bindings  chan<- binding
+	tunnels   chan<- tunnel
+	auth      agpl.TunnelAuth
+	mu        sync.Mutex
+	closed    bool
+
+	name       string
+	start      int64
+	lastWrite  int64
+	overwrites int64
 }
 
-func newConnIO(pCtx context.Context,
+func newConnIO(coordContext context.Context,
+	peerCtx context.Context,
 	logger slog.Logger,
 	bindings chan<- binding,
-	conn net.Conn,
+	tunnels chan<- tunnel,
+	requests <-chan *proto.CoordinateRequest,
+	responses chan<- *proto.CoordinateResponse,
 	id uuid.UUID,
 	name string,
-	kind agpl.QueueKind,
+	auth agpl.TunnelAuth,
 ) *connIO {
-	ctx, cancel := context.WithCancel(pCtx)
+	peerCtx, cancel := context.WithCancel(peerCtx)
+	now := time.Now().Unix()
 	c := &connIO{
-		pCtx:     pCtx,
-		ctx:      ctx,
-		cancel:   cancel,
-		logger:   logger,
-		decoder:  json.NewDecoder(conn),
-		updates:  agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, kind),
-		bindings: bindings,
+		id:        id,
+		coordCtx:  coordContext,
+		peerCtx:   peerCtx,
+		cancel:    cancel,
+		logger:    logger.With(slog.F("name", name)),
+		requests:  requests,
+		responses: responses,
+		bindings:  bindings,
+		tunnels:   tunnels,
+		auth:      auth,
+		name:      name,
+		start:     now,
+		lastWrite: now,
 	}
 	go c.recvLoop()
-	go c.updates.SendUpdates()
-	logger.Info(ctx, "serving connection")
+	c.logger.Info(coordContext, "serving connection")
 	return c
 }
 
 func (c *connIO) recvLoop() {
 	defer func() {
-		// withdraw bindings when we exit. We need to use the parent context here, since our own context might be
-		// canceled, but we still need to withdraw bindings.
+		// withdraw bindings & tunnels when we exit. We need to use the parent context here, since
+		// our own context might be canceled, but we still need to withdraw.
 		b := binding{
-			bKey: bKey{
-				id:   c.UniqueID(),
-				kind: c.Kind(),
-			},
+			bKey: bKey(c.UniqueID()),
+		}
+		if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
+			c.logger.Debug(c.coordCtx, "parent context expired while withdrawing bindings", slog.Error(err))
 		}
-		if err := sendCtx(c.pCtx, c.bindings, b); err != nil {
-			c.logger.Debug(c.ctx, "parent context expired while withdrawing bindings", slog.Error(err))
+		t := tunnel{
+			tKey:   tKey{src: c.UniqueID()},
+			active: false,
+		}
+		if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
+			c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err))
 		}
 	}()
-	defer c.cancel()
+	defer c.Close()
 	for {
-		var node agpl.Node
-		err := c.decoder.Decode(&node)
+		req, err := recvCtx(c.peerCtx, c.requests)
 		if err != nil {
-			if xerrors.Is(err, io.EOF) ||
-				xerrors.Is(err, io.ErrClosedPipe) ||
-				xerrors.Is(err, context.Canceled) ||
+			if xerrors.Is(err, context.Canceled) ||
 				xerrors.Is(err, context.DeadlineExceeded) ||
-				websocket.CloseStatus(err) > 0 {
-				c.logger.Debug(c.ctx, "exiting recvLoop", slog.Error(err))
+				xerrors.Is(err, io.EOF) {
+				c.logger.Debug(c.coordCtx, "exiting io recvLoop", slog.Error(err))
 			} else {
-				c.logger.Error(c.ctx, "failed to decode Node update", slog.Error(err))
+				c.logger.Error(c.coordCtx, "failed to receive request", slog.Error(err))
 			}
 			return
 		}
-		c.logger.Debug(c.ctx, "got node update", slog.F("node", node))
+		if err := c.handleRequest(req); err != nil {
+			return
+		}
+	}
+}
+
+func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
+	c.logger.Debug(c.peerCtx, "got request")
+	if req.UpdateSelf != nil {
+		c.logger.Debug(c.peerCtx, "got node update", slog.F("node", req.UpdateSelf))
 		b := binding{
-			bKey: bKey{
-				id:   c.UniqueID(),
-				kind: c.Kind(),
+			bKey: bKey(c.UniqueID()),
+			node: req.UpdateSelf.Node,
+		}
+		if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
+			c.logger.Debug(c.peerCtx, "failed to send binding", slog.Error(err))
+			return err
+		}
+	}
+	if req.AddTunnel != nil {
+		c.logger.Debug(c.peerCtx, "got add tunnel", slog.F("tunnel", req.AddTunnel))
+		dst, err := uuid.FromBytes(req.AddTunnel.Uuid)
+		if err != nil {
+			c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err))
+			// this shouldn't happen unless there is a client error. Close the connection so the client
+			// doesn't just happily continue thinking everything is fine.
+			return err
+		}
+		if !c.auth.Authorize(dst) {
+			return xerrors.New("unauthorized tunnel")
+		}
+		t := tunnel{
+			tKey: tKey{
+				src: c.UniqueID(),
+				dst: dst,
 			},
-			node: &node,
+			active: true,
 		}
-		if err := sendCtx(c.ctx, c.bindings, b); err != nil {
-			c.logger.Debug(c.ctx, "recvLoop ctx expired", slog.Error(err))
-			return
+		if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
+			c.logger.Debug(c.peerCtx, "failed to send add tunnel", slog.Error(err))
+			return err
+		}
+	}
+	if req.RemoveTunnel != nil {
+		c.logger.Debug(c.peerCtx, "got remove tunnel", slog.F("tunnel", req.RemoveTunnel))
+		dst, err := uuid.FromBytes(req.RemoveTunnel.Uuid)
+		if err != nil {
+			c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err))
+			// this shouldn't happen unless there is a client error. Close the connection so the client
+			// doesn't just happily continue thinking everything is fine.
+			return err
+		}
+		t := tunnel{
+			tKey: tKey{
+				src: c.UniqueID(),
+				dst: dst,
+			},
+			active: false,
+		}
+		if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
+			c.logger.Debug(c.peerCtx, "failed to send remove tunnel", slog.Error(err))
+			return err
 		}
 	}
+	// TODO: (spikecurtis) support Disconnect
+	return nil
 }
 
 func (c *connIO) UniqueID() uuid.UUID {
-	return c.updates.UniqueID()
-}
-
-func (c *connIO) Kind() agpl.QueueKind {
-	return c.updates.Kind()
+	return c.id
 }
 
-func (c *connIO) Enqueue(n []*agpl.Node) error {
-	return c.updates.Enqueue(n)
+func (c *connIO) Enqueue(resp *proto.CoordinateResponse) error {
+	atomic.StoreInt64(&c.lastWrite, time.Now().Unix())
+	c.mu.Lock()
+	closed := c.closed
+	c.mu.Unlock()
+	if closed {
+		return xerrors.New("connIO closed")
+	}
+	select {
+	case <-c.peerCtx.Done():
+		return c.peerCtx.Err()
+	case c.responses <- resp:
+		c.logger.Debug(c.peerCtx, "wrote response")
+		return nil
+	default:
+		return agpl.ErrWouldBlock
+	}
 }
 
 func (c *connIO) Name() string {
-	return c.updates.Name()
+	return c.name
 }
 
 func (c *connIO) Stats() (start int64, lastWrite int64) {
-	return c.updates.Stats()
+	return c.start, atomic.LoadInt64(&c.lastWrite)
 }
 
 func (c *connIO) Overwrites() int64 {
-	return c.updates.Overwrites()
+	return atomic.LoadInt64(&c.overwrites)
 }
 
 // CoordinatorClose is used by the coordinator when closing a Queue. It
 // should skip removing itself from the coordinator.
 func (c *connIO) CoordinatorClose() error {
-	c.cancel()
-	return c.updates.CoordinatorClose()
+	return c.Close()
 }
 
 func (c *connIO) Done() <-chan struct{} {
-	return c.ctx.Done()
+	return c.peerCtx.Done()
 }
 
 func (c *connIO) Close() error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if c.closed {
+		return nil
+	}
 	c.cancel()
-	return c.updates.Close()
+	c.closed = true
+	close(c.responses)
+	return nil
 }
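recvLoop and handleRequest above lean on two helpers, sendCtx and recvCtx, that this file doesn't define; they live elsewhere in the PR. Below is a minimal sketch of the behavior the call sites imply — context-aware channel send/receive, with a closed channel surfacing as io.EOF so recvLoop can treat peer disconnect as a clean exit. The exact signatures and bodies are assumptions, not the PR's code.

    package tailnet

    import (
    	"context"
    	"io"
    )

    // sendCtx (assumed): send a on c, or fail if ctx is done first.
    func sendCtx[A any](ctx context.Context, c chan<- A, a A) error {
    	select {
    	case <-ctx.Done():
    		return ctx.Err()
    	case c <- a:
    		return nil
    	}
    }

    // recvCtx (assumed): receive from c, or fail if ctx is done first.
    func recvCtx[A any](ctx context.Context, c <-chan A) (a A, err error) {
    	select {
    	case <-ctx.Done():
    		return a, ctx.Err()
    	case a, ok := <-c:
    		if ok {
    			return a, nil
    		}
    		// a closed requests channel reads as io.EOF, matching the
    		// recvLoop exit check above
    		return a, io.EOF
    	}
    }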

enterprise/tailnet/multiagent_test.go

Lines changed: 12 additions & 18 deletions
@@ -27,11 +27,10 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
 		t.Skip("test only with postgres")
 	}
 
-	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
-	defer cancel()
-
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
 	store, ps := dbtestutil.NewDB(t)
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+	defer cancel()
 	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
 	require.NoError(t, err)
 	defer coord1.Close()
@@ -75,11 +74,10 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
 		t.Skip("test only with postgres")
 	}
 
-	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
-	defer cancel()
-
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
 	store, ps := dbtestutil.NewDB(t)
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
+	defer cancel()
 	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
 	require.NoError(t, err)
 	defer coord1.Close()
@@ -124,11 +122,10 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
 		t.Skip("test only with postgres")
 	}
 
-	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
-	defer cancel()
-
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
 	store, ps := dbtestutil.NewDB(t)
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+	defer cancel()
 	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
 	require.NoError(t, err)
 	defer coord1.Close()
@@ -189,11 +186,10 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
 		t.Skip("test only with postgres")
 	}
 
-	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
-	defer cancel()
-
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
 	store, ps := dbtestutil.NewDB(t)
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
+	defer cancel()
 	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
 	require.NoError(t, err)
 	defer coord1.Close()
@@ -243,11 +239,10 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *testing.T) {
 		t.Skip("test only with postgres")
 	}
 
-	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
-	defer cancel()
-
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
 	store, ps := dbtestutil.NewDB(t)
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
+	defer cancel()
 	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
 	require.NoError(t, err)
 	defer coord1.Close()
@@ -299,11 +294,10 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
 		t.Skip("test only with postgres")
 	}
 
-	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
-	defer cancel()
-
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
 	store, ps := dbtestutil.NewDB(t)
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
+	defer cancel()
 	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
 	require.NoError(t, err)
 	defer coord1.Close()
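Each hunk above makes the same mechanical change: the timeout context is now created after dbtestutil.NewDB(t), so database provisioning time no longer counts against the test's wait budget (the first test also grows its budget from testutil.WaitMedium to testutil.WaitLong). The resulting setup shape, excerpted from the hunks above:

    logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
    store, ps := dbtestutil.NewDB(t) // potentially slow; runs before the timeout clock starts
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
    defer cancel()
    coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
    require.NoError(t, err)
    defer coord1.Close()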
