Skip to content

Commit d0acad3

Browse files
committed
feat: implement DERP streaming on tailnet Client API
1 parent fe867d0 commit d0acad3

File tree

3 files changed

+72
-19
lines changed

3 files changed

+72
-19
lines changed

coderd/coderd.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,11 @@ func New(options *Options) *API {
479479
}
480480
}
481481
api.TailnetClientService, err = tailnet.NewClientService(
482-
api.Logger.Named("tailnetclient"), &api.TailnetCoordinator)
482+
api.Logger.Named("tailnetclient"),
483+
&api.TailnetCoordinator,
484+
api.Options.DERPMapUpdateFrequency,
485+
api.DERPMap,
486+
)
483487
if err != nil {
484488
api.Logger.Fatal(api.ctx, "failed to initialize tailnet client service", slog.Error(err))
485489
}

tailnet/service.go

+46-15
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@ import (
77
"strconv"
88
"strings"
99
"sync/atomic"
10+
"time"
1011

1112
"github.com/google/uuid"
1213
"github.com/hashicorp/yamux"
1314
"storj.io/drpc/drpcmux"
1415
"storj.io/drpc/drpcserver"
16+
"tailscale.com/tailcfg"
1517

1618
"cdr.dev/slog"
1719
"github.com/coder/coder/v2/tailnet/proto"
@@ -92,10 +94,22 @@ type ClientService struct {
9294

9395
// NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is
9496
// loaded on each processed connection.
95-
func NewClientService(logger slog.Logger, coordPtr *atomic.Pointer[Coordinator]) (*ClientService, error) {
97+
func NewClientService(
98+
logger slog.Logger,
99+
coordPtr *atomic.Pointer[Coordinator],
100+
derpMapUpdateFrequency time.Duration,
101+
derpMapFn func() *tailcfg.DERPMap,
102+
) (
103+
*ClientService, error,
104+
) {
96105
s := &ClientService{logger: logger, coordPtr: coordPtr}
97106
mux := drpcmux.New()
98-
drpcService := NewDRPCService(logger, coordPtr)
107+
drpcService := &DRPCService{
108+
CoordPtr: coordPtr,
109+
Logger: logger,
110+
DerpMapUpdateFrequency: derpMapUpdateFrequency,
111+
DerpMapFn: derpMapFn,
112+
}
99113
err := proto.DRPCRegisterClient(mux, drpcService)
100114
if err != nil {
101115
return nil, xerrors.Errorf("register DRPC service: %w", err)
@@ -145,20 +159,37 @@ func (s *ClientService) ServeClient(ctx context.Context, version string, conn ne
145159

146160
// DRPCService is the dRPC-based, version 2.x of the tailnet API and implements proto.DRPCClientServer
147161
type DRPCService struct {
148-
coordPtr *atomic.Pointer[Coordinator]
149-
logger slog.Logger
162+
CoordPtr *atomic.Pointer[Coordinator]
163+
Logger slog.Logger
164+
DerpMapUpdateFrequency time.Duration
165+
DerpMapFn func() *tailcfg.DERPMap
150166
}
151167

152-
func NewDRPCService(logger slog.Logger, coordPtr *atomic.Pointer[Coordinator]) *DRPCService {
153-
return &DRPCService{
154-
coordPtr: coordPtr,
155-
logger: logger,
156-
}
157-
}
168+
func (s *DRPCService) StreamDERPMaps(_ *proto.StreamDERPMapsRequest, stream proto.DRPCClient_StreamDERPMapsStream) error {
169+
defer stream.Close()
170+
171+
ticker := time.NewTicker(s.DerpMapUpdateFrequency)
172+
defer ticker.Stop()
158173

159-
func (*DRPCService) StreamDERPMaps(*proto.StreamDERPMapsRequest, proto.DRPCClient_StreamDERPMapsStream) error {
160-
// TODO integrate with Dean's PR implementation
161-
return xerrors.New("unimplemented")
174+
var lastDERPMap *tailcfg.DERPMap
175+
for {
176+
derpMap := s.DerpMapFn()
177+
if lastDERPMap == nil || !CompareDERPMaps(lastDERPMap, derpMap) {
178+
protoDERPMap := DERPMapToProto(derpMap)
179+
err := stream.Send(protoDERPMap)
180+
if err != nil {
181+
return xerrors.Errorf("send derp map: %w", err)
182+
}
183+
lastDERPMap = derpMap
184+
}
185+
186+
ticker.Reset(s.DerpMapUpdateFrequency)
187+
select {
188+
case <-stream.Context().Done():
189+
return nil
190+
case <-ticker.C:
191+
}
192+
}
162193
}
163194

164195
func (s *DRPCService) CoordinateTailnet(stream proto.DRPCClient_CoordinateTailnetStream) error {
@@ -168,9 +199,9 @@ func (s *DRPCService) CoordinateTailnet(stream proto.DRPCClient_CoordinateTailne
168199
_ = stream.Close()
169200
return xerrors.New("no Stream ID")
170201
}
171-
logger := s.logger.With(slog.F("peer_id", streamID), slog.F("name", streamID.Name))
202+
logger := s.Logger.With(slog.F("peer_id", streamID), slog.F("name", streamID.Name))
172203
logger.Debug(ctx, "starting tailnet Coordinate")
173-
coord := *(s.coordPtr.Load())
204+
coord := *(s.CoordPtr.Load())
174205
reqs, resps := coord.Coordinate(ctx, streamID.ID, streamID.Name, streamID.Auth)
175206
c := communicator{
176207
logger: logger,

tailnet/service_test.go

+21-3
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ import (
88
"net/http"
99
"sync/atomic"
1010
"testing"
11+
"time"
1112

1213
"golang.org/x/xerrors"
14+
"tailscale.com/tailcfg"
1315

1416
"github.com/google/uuid"
1517

@@ -94,7 +96,11 @@ func TestClientService_ServeClient_V2(t *testing.T) {
9496
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
9597
coordPtr.Store(&coord)
9698
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
97-
uut, err := tailnet.NewClientService(logger, &coordPtr)
99+
derpMap := &tailcfg.DERPMap{Regions: map[int]*tailcfg.DERPRegion{999: {RegionCode: "test"}}}
100+
uut, err := tailnet.NewClientService(
101+
logger, &coordPtr,
102+
time.Millisecond, func() *tailcfg.DERPMap { return derpMap },
103+
)
98104
require.NoError(t, err)
99105

100106
ctx := testutil.Context(t, testutil.WaitShort)
@@ -112,6 +118,8 @@ func TestClientService_ServeClient_V2(t *testing.T) {
112118

113119
client, err := tailnet.NewDRPCClient(c)
114120
require.NoError(t, err)
121+
122+
// Coordinate
115123
stream, err := client.CoordinateTailnet(ctx)
116124
require.NoError(t, err)
117125
defer stream.Close()
@@ -145,7 +153,17 @@ func TestClientService_ServeClient_V2(t *testing.T) {
145153
err = stream.Close()
146154
require.NoError(t, err)
147155

148-
// stream ^^ is just one RPC; we need to close the Conn to end the session.
156+
// DERP Map
157+
dms, err := client.StreamDERPMaps(ctx, &proto.StreamDERPMapsRequest{})
158+
require.NoError(t, err)
159+
160+
gotDermMap, err := dms.Recv()
161+
require.NoError(t, err)
162+
require.Equal(t, "test", gotDermMap.GetRegions()[999].GetRegionCode())
163+
err = dms.Close()
164+
require.NoError(t, err)
165+
166+
// RPCs closed; we need to close the Conn to end the session.
149167
err = c.Close()
150168
require.NoError(t, err)
151169
err = testutil.RequireRecvCtx(ctx, t, errCh)
@@ -159,7 +177,7 @@ func TestClientService_ServeClient_V1(t *testing.T) {
159177
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
160178
coordPtr.Store(&coord)
161179
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
162-
uut, err := tailnet.NewClientService(logger, &coordPtr)
180+
uut, err := tailnet.NewClientService(logger, &coordPtr, 0, nil)
163181
require.NoError(t, err)
164182

165183
ctx := testutil.Context(t, testutil.WaitShort)

0 commit comments

Comments
 (0)