Commit 0cbde53

feat: modify PG Coordinator to work with new v2 Tailnet API
1 parent: 504d82c

8 files changed: +1032 −803 lines

enterprise/tailnet/connio.go

+159 −65
@@ -2,136 +2,230 @@ package tailnet
 
 import (
 	"context"
-	"encoding/json"
 	"io"
-	"net"
+	"sync"
+	"sync/atomic"
+	"time"
 
 	"github.com/google/uuid"
 	"golang.org/x/xerrors"
-	"nhooyr.io/websocket"
 
 	"cdr.dev/slog"
+
 	agpl "github.com/coder/coder/v2/tailnet"
+	"github.com/coder/coder/v2/tailnet/proto"
 )
 
-// connIO manages the reading and writing to a connected client or agent. Agent connIOs have their client field set to
-// uuid.Nil. It reads node updates via its decoder, then pushes them onto the bindings channel. It receives mappings
-// via its updates TrackedConn, which then writes them.
+// connIO manages the reading and writing to a connected peer. It reads requests via its requests
+// channel, then pushes them onto the bindings or tunnels channel. It receives responses via calls
+// to Enqueue and pushes them onto the responses channel.
 type connIO struct {
-	pCtx     context.Context
-	ctx      context.Context
-	cancel   context.CancelFunc
-	logger   slog.Logger
-	decoder  *json.Decoder
-	updates  *agpl.TrackedConn
-	bindings chan<- binding
+	id uuid.UUID
+	// coordCtx is the parent context, that is, the context of the Coordinator
+	coordCtx context.Context
+	// peerCtx is the context of the connection to our peer
+	peerCtx   context.Context
+	cancel    context.CancelFunc
+	logger    slog.Logger
+	requests  <-chan *proto.CoordinateRequest
+	responses chan<- *proto.CoordinateResponse
+	bindings  chan<- binding
+	tunnels   chan<- tunnel
+	auth      agpl.TunnelAuth
+	mu        sync.Mutex
+	closed    bool
+
+	name       string
+	start      int64
+	lastWrite  int64
+	overwrites int64
 }
 
-func newConnIO(pCtx context.Context,
+func newConnIO(coordContext context.Context,
+	peerCtx context.Context,
 	logger slog.Logger,
 	bindings chan<- binding,
-	conn net.Conn,
+	tunnels chan<- tunnel,
+	requests <-chan *proto.CoordinateRequest,
+	responses chan<- *proto.CoordinateResponse,
 	id uuid.UUID,
 	name string,
-	kind agpl.QueueKind,
+	auth agpl.TunnelAuth,
 ) *connIO {
-	ctx, cancel := context.WithCancel(pCtx)
+	peerCtx, cancel := context.WithCancel(peerCtx)
+	now := time.Now().Unix()
 	c := &connIO{
-		pCtx:     pCtx,
-		ctx:      ctx,
-		cancel:   cancel,
-		logger:   logger,
-		decoder:  json.NewDecoder(conn),
-		updates:  agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, kind),
-		bindings: bindings,
+		id:        id,
+		coordCtx:  coordContext,
+		peerCtx:   peerCtx,
+		cancel:    cancel,
+		logger:    logger.With(slog.F("name", name)),
+		requests:  requests,
+		responses: responses,
+		bindings:  bindings,
+		tunnels:   tunnels,
+		auth:      auth,
+		name:      name,
+		start:     now,
+		lastWrite: now,
 	}
 	go c.recvLoop()
-	go c.updates.SendUpdates()
-	logger.Info(ctx, "serving connection")
+	c.logger.Info(coordContext, "serving connection")
 	return c
 }
 
 func (c *connIO) recvLoop() {
 	defer func() {
-		// withdraw bindings when we exit. We need to use the parent context here, since our own context might be
-		// canceled, but we still need to withdraw bindings.
+		// withdraw bindings & tunnels when we exit. We need to use the parent context here, since
+		// our own context might be canceled, but we still need to withdraw.
 		b := binding{
-			bKey: bKey{
-				id:   c.UniqueID(),
-				kind: c.Kind(),
-			},
+			bKey: bKey(c.UniqueID()),
+		}
+		if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
+			c.logger.Debug(c.coordCtx, "parent context expired while withdrawing bindings", slog.Error(err))
 		}
-		if err := sendCtx(c.pCtx, c.bindings, b); err != nil {
-			c.logger.Debug(c.ctx, "parent context expired while withdrawing bindings", slog.Error(err))
+		t := tunnel{
+			tKey:   tKey{src: c.UniqueID()},
+			active: false,
+		}
+		if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
+			c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err))
 		}
 	}()
-	defer c.cancel()
+	defer c.Close()
 	for {
-		var node agpl.Node
-		err := c.decoder.Decode(&node)
+		req, err := recvCtx(c.peerCtx, c.requests)
 		if err != nil {
-			if xerrors.Is(err, io.EOF) ||
-				xerrors.Is(err, io.ErrClosedPipe) ||
-				xerrors.Is(err, context.Canceled) ||
+			if xerrors.Is(err, context.Canceled) ||
 				xerrors.Is(err, context.DeadlineExceeded) ||
-				websocket.CloseStatus(err) > 0 {
-				c.logger.Debug(c.ctx, "exiting recvLoop", slog.Error(err))
+				xerrors.Is(err, io.EOF) {
+				c.logger.Debug(c.coordCtx, "exiting io recvLoop", slog.Error(err))
 			} else {
-				c.logger.Error(c.ctx, "failed to decode Node update", slog.Error(err))
+				c.logger.Error(c.coordCtx, "failed to receive request", slog.Error(err))
 			}
 			return
 		}
-		c.logger.Debug(c.ctx, "got node update", slog.F("node", node))
+		if err := c.handleRequest(req); err != nil {
+			return
+		}
+	}
+}
+
+func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
+	c.logger.Debug(c.peerCtx, "got request")
+	if req.UpdateSelf != nil {
+		c.logger.Debug(c.peerCtx, "got node update", slog.F("node", req.UpdateSelf))
 		b := binding{
-			bKey: bKey{
-				id:   c.UniqueID(),
-				kind: c.Kind(),
+			bKey: bKey(c.UniqueID()),
+			node: req.UpdateSelf.Node,
+		}
+		if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
+			c.logger.Debug(c.peerCtx, "failed to send binding, context expired?", slog.Error(err))
+			return err
+		}
+	}
+	if req.AddTunnel != nil {
+		c.logger.Debug(c.peerCtx, "got add tunnel", slog.F("tunnel", req.AddTunnel))
+		dst, err := uuid.FromBytes(req.AddTunnel.Uuid)
+		if err != nil {
+			c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err))
+			// this shouldn't happen unless there is a client error. Close the connection so the client
+			// doesn't just happily continue thinking everything is fine.
+			return err
+		}
+		if !c.auth.Authorize(dst) {
+			return xerrors.New("unauthorized tunnel")
+		}
+		t := tunnel{
+			tKey: tKey{
+				src: c.UniqueID(),
+				dst: dst,
 			},
-			node: &node,
+			active: true,
 		}
-		if err := sendCtx(c.ctx, c.bindings, b); err != nil {
-			c.logger.Debug(c.ctx, "recvLoop ctx expired", slog.Error(err))
-			return
+		if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
+			c.logger.Debug(c.peerCtx, "failed to send add tunnel, context expired?", slog.Error(err))
+			return err
+		}
+	}
+	if req.RemoveTunnel != nil {
+		c.logger.Debug(c.peerCtx, "got remove tunnel", slog.F("tunnel", req.RemoveTunnel))
+		dst, err := uuid.FromBytes(req.RemoveTunnel.Uuid)
+		if err != nil {
+			c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err))
+			// this shouldn't happen unless there is a client error. Close the connection so the client
+			// doesn't just happily continue thinking everything is fine.
+			return err
+		}
+		t := tunnel{
+			tKey: tKey{
+				src: c.UniqueID(),
+				dst: dst,
+			},
+			active: false,
+		}
+		if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
+			c.logger.Debug(c.peerCtx, "failed to send remove tunnel, context expired?", slog.Error(err))
+			return err
 		}
 	}
+	// TODO: (spikecurtis) support Disconnect
+	return nil
 }
 
 func (c *connIO) UniqueID() uuid.UUID {
-	return c.updates.UniqueID()
-}
-
-func (c *connIO) Kind() agpl.QueueKind {
-	return c.updates.Kind()
+	return c.id
 }
 
-func (c *connIO) Enqueue(n []*agpl.Node) error {
-	return c.updates.Enqueue(n)
+func (c *connIO) Enqueue(resp *proto.CoordinateResponse) error {
+	atomic.StoreInt64(&c.lastWrite, time.Now().Unix())
+	c.mu.Lock()
+	closed := c.closed
+	c.mu.Unlock()
+	if closed {
+		return xerrors.New("connIO closed")
+	}
+	select {
+	case <-c.peerCtx.Done():
+		return c.peerCtx.Err()
+	case c.responses <- resp:
+		c.logger.Debug(c.peerCtx, "wrote response")
+		return nil
+	default:
+		return agpl.ErrWouldBlock
+	}
 }
 
 func (c *connIO) Name() string {
-	return c.updates.Name()
+	return c.name
 }
 
 func (c *connIO) Stats() (start int64, lastWrite int64) {
-	return c.updates.Stats()
+	return c.start, atomic.LoadInt64(&c.lastWrite)
 }
 
 func (c *connIO) Overwrites() int64 {
-	return c.updates.Overwrites()
+	return atomic.LoadInt64(&c.overwrites)
 }
 
 // CoordinatorClose is used by the coordinator when closing a Queue. It
 // should skip removing itself from the coordinator.
 func (c *connIO) CoordinatorClose() error {
-	c.cancel()
-	return c.updates.CoordinatorClose()
+	return c.Close()
 }
 
 func (c *connIO) Done() <-chan struct{} {
-	return c.ctx.Done()
+	return c.peerCtx.Done()
 }
 
 func (c *connIO) Close() error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if c.closed {
+		return nil
+	}
 	c.cancel()
-	return c.updates.Close()
+	c.closed = true
+	close(c.responses)
+	return nil
 }
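
The new recvLoop and the withdraw logic both go through two small channel helpers, sendCtx and recvCtx, that this file does not define. Below is a minimal sketch of what they might look like, assuming Go 1.18+ generics and the convention (visible in recvLoop) that a closed channel surfaces as io.EOF; the real helpers live elsewhere in the package and may differ:

	// sendCtx attempts to send a on c, giving up with the context's
	// error if ctx is canceled first. (Hypothetical sketch.)
	func sendCtx[A any](ctx context.Context, c chan<- A, a A) error {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case c <- a:
			return nil
		}
	}

	// recvCtx receives a value from c. It returns io.EOF once c is
	// closed, so callers like recvLoop can treat that as a clean
	// shutdown, and the context's error if ctx is canceled first.
	// (Hypothetical sketch.)
	func recvCtx[A any](ctx context.Context, c <-chan A) (A, error) {
		var zero A
		select {
		case <-ctx.Done():
			return zero, ctx.Err()
		case v, ok := <-c:
			if !ok {
				return zero, io.EOF
			}
			return v, nil
		}
	}

Note that Enqueue is deliberately non-blocking in the other direction: its select falls through to default and returns agpl.ErrWouldBlock when the responses channel is full, so a slow peer surfaces as an error to the coordinator rather than stalling its fanout.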

enterprise/tailnet/multiagent_test.go

+12 −18
@@ -27,11 +27,10 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
 		t.Skip("test only with postgres")
 	}
 
-	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
-	defer cancel()
-
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
 	store, ps := dbtestutil.NewDB(t)
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+	defer cancel()
 	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
 	require.NoError(t, err)
 	defer coord1.Close()
@@ -75,11 +74,10 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
 		t.Skip("test only with postgres")
 	}
 
-	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
-	defer cancel()
-
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
 	store, ps := dbtestutil.NewDB(t)
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
+	defer cancel()
 	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
 	require.NoError(t, err)
 	defer coord1.Close()
@@ -124,11 +122,10 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
 		t.Skip("test only with postgres")
 	}
 
-	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
-	defer cancel()
-
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
 	store, ps := dbtestutil.NewDB(t)
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+	defer cancel()
 	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
 	require.NoError(t, err)
 	defer coord1.Close()
@@ -189,11 +186,10 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
 		t.Skip("test only with postgres")
 	}
 
-	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
-	defer cancel()
-
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
 	store, ps := dbtestutil.NewDB(t)
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
+	defer cancel()
 	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
 	require.NoError(t, err)
 	defer coord1.Close()
@@ -243,11 +239,10 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
 		t.Skip("test only with postgres")
 	}
 
-	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
-	defer cancel()
-
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
 	store, ps := dbtestutil.NewDB(t)
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
+	defer cancel()
 	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
 	require.NoError(t, err)
 	defer coord1.Close()
@@ -299,11 +294,10 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
 		t.Skip("test only with postgres")
 	}
 
-	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
-	defer cancel()
-
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
 	store, ps := dbtestutil.NewDB(t)
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
+	defer cancel()
 	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
 	require.NoError(t, err)
 	defer coord1.Close()
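Every test hunk makes the same mechanical change: the timeout context is now created after dbtestutil.NewDB(t) rather than before it. A plausible reading (the commit does not state a rationale) is that database provisioning should not consume the test's deadline budget. The reordered pattern, sketched with a hypothetical test name:

	func TestExample(t *testing.T) {
		// Expensive fixture setup happens first, outside any deadline.
		logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
		store, ps := dbtestutil.NewDB(t)

		// Only now start the clock, so the full timeout is available to
		// the coordinator logic actually under test (assumed rationale).
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
		defer cancel()

		coord, err := tailnet.NewPGCoord(ctx, logger.Named("coord"), ps, store)
		require.NoError(t, err)
		defer coord.Close()
		// ... exercise coord ...
	}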
