Skip to content

Commit 4f87ac5

Browse files
committed
feat: implement DERP streaming on tailnet Client API
1 parent fe867d0 commit 4f87ac5

File tree

3 files changed

+77
-18
lines changed

3 files changed

+77
-18
lines changed

coderd/coderd.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,11 @@ func New(options *Options) *API {
479479
}
480480
}
481481
api.TailnetClientService, err = tailnet.NewClientService(
482-
api.Logger.Named("tailnetclient"), &api.TailnetCoordinator)
482+
api.Logger.Named("tailnetclient"),
483+
&api.TailnetCoordinator,
484+
api.Options.DERPMapUpdateFrequency,
485+
api.DERPMap,
486+
)
483487
if err != nil {
484488
api.Logger.Fatal(api.ctx, "failed to initialize tailnet client service", slog.Error(err))
485489
}

tailnet/service.go

+51-14
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@ import (
77
"strconv"
88
"strings"
99
"sync/atomic"
10+
"time"
1011

1112
"github.com/google/uuid"
1213
"github.com/hashicorp/yamux"
14+
"storj.io/drpc"
1315
"storj.io/drpc/drpcmux"
1416
"storj.io/drpc/drpcserver"
17+
"tailscale.com/tailcfg"
1518

1619
"cdr.dev/slog"
1720
"github.com/coder/coder/v2/tailnet/proto"
@@ -92,10 +95,22 @@ type ClientService struct {
9295

9396
// NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is
9497
// loaded on each processed connection.
95-
func NewClientService(logger slog.Logger, coordPtr *atomic.Pointer[Coordinator]) (*ClientService, error) {
98+
func NewClientService(
99+
logger slog.Logger,
100+
coordPtr *atomic.Pointer[Coordinator],
101+
derpMapUpdateFrequency time.Duration,
102+
derpMapFn func() *tailcfg.DERPMap,
103+
) (
104+
*ClientService, error,
105+
) {
96106
s := &ClientService{logger: logger, coordPtr: coordPtr}
97107
mux := drpcmux.New()
98-
drpcService := NewDRPCService(logger, coordPtr)
108+
drpcService := &DRPCService{
109+
CoordPtr: coordPtr,
110+
Logger: logger,
111+
DerpMapUpdateFrequency: derpMapUpdateFrequency,
112+
DerpMapFn: derpMapFn,
113+
}
99114
err := proto.DRPCRegisterClient(mux, drpcService)
100115
if err != nil {
101116
return nil, xerrors.Errorf("register DRPC service: %w", err)
@@ -145,20 +160,42 @@ func (s *ClientService) ServeClient(ctx context.Context, version string, conn ne
145160

146161
// DRPCService is the dRPC-based, version 2.x of the tailnet API and implements proto.DRPCClientServer
147162
type DRPCService struct {
148-
coordPtr *atomic.Pointer[Coordinator]
149-
logger slog.Logger
163+
CoordPtr *atomic.Pointer[Coordinator]
164+
Logger slog.Logger
165+
DerpMapUpdateFrequency time.Duration
166+
DerpMapFn func() *tailcfg.DERPMap
150167
}
151168

152-
func NewDRPCService(logger slog.Logger, coordPtr *atomic.Pointer[Coordinator]) *DRPCService {
153-
return &DRPCService{
154-
coordPtr: coordPtr,
155-
logger: logger,
156-
}
169+
type StreamDERPMapsStream interface {
170+
drpc.Stream
171+
Send(*proto.DERPMap) error
157172
}
158173

159-
func (*DRPCService) StreamDERPMaps(*proto.StreamDERPMapsRequest, proto.DRPCClient_StreamDERPMapsStream) error {
160-
// TODO integrate with Dean's PR implementation
161-
return xerrors.New("unimplemented")
174+
func (s *DRPCService) StreamDERPMaps(_ *proto.StreamDERPMapsRequest, stream proto.DRPCClient_StreamDERPMapsStream) error {
175+
defer stream.Close()
176+
177+
ticker := time.NewTicker(s.DerpMapUpdateFrequency)
178+
defer ticker.Stop()
179+
180+
var lastDERPMap *tailcfg.DERPMap
181+
for {
182+
derpMap := s.DerpMapFn()
183+
if lastDERPMap == nil || !CompareDERPMaps(lastDERPMap, derpMap) {
184+
protoDERPMap := DERPMapToProto(derpMap)
185+
err := stream.Send(protoDERPMap)
186+
if err != nil {
187+
return xerrors.Errorf("send derp map: %w", err)
188+
}
189+
lastDERPMap = derpMap
190+
}
191+
192+
ticker.Reset(s.DerpMapUpdateFrequency)
193+
select {
194+
case <-stream.Context().Done():
195+
return nil
196+
case <-ticker.C:
197+
}
198+
}
162199
}
163200

164201
func (s *DRPCService) CoordinateTailnet(stream proto.DRPCClient_CoordinateTailnetStream) error {
@@ -168,9 +205,9 @@ func (s *DRPCService) CoordinateTailnet(stream proto.DRPCClient_CoordinateTailne
168205
_ = stream.Close()
169206
return xerrors.New("no Stream ID")
170207
}
171-
logger := s.logger.With(slog.F("peer_id", streamID), slog.F("name", streamID.Name))
208+
logger := s.Logger.With(slog.F("peer_id", streamID), slog.F("name", streamID.Name))
172209
logger.Debug(ctx, "starting tailnet Coordinate")
173-
coord := *(s.coordPtr.Load())
210+
coord := *(s.CoordPtr.Load())
174211
reqs, resps := coord.Coordinate(ctx, streamID.ID, streamID.Name, streamID.Auth)
175212
c := communicator{
176213
logger: logger,

tailnet/service_test.go

+21-3
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ import (
88
"net/http"
99
"sync/atomic"
1010
"testing"
11+
"time"
1112

1213
"golang.org/x/xerrors"
14+
"tailscale.com/tailcfg"
1315

1416
"github.com/google/uuid"
1517

@@ -94,7 +96,11 @@ func TestClientService_ServeClient_V2(t *testing.T) {
9496
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
9597
coordPtr.Store(&coord)
9698
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
97-
uut, err := tailnet.NewClientService(logger, &coordPtr)
99+
derpMap := &tailcfg.DERPMap{Regions: map[int]*tailcfg.DERPRegion{999: {RegionCode: "test"}}}
100+
uut, err := tailnet.NewClientService(
101+
logger, &coordPtr,
102+
time.Millisecond, func() *tailcfg.DERPMap { return derpMap },
103+
)
98104
require.NoError(t, err)
99105

100106
ctx := testutil.Context(t, testutil.WaitShort)
@@ -112,6 +118,8 @@ func TestClientService_ServeClient_V2(t *testing.T) {
112118

113119
client, err := tailnet.NewDRPCClient(c)
114120
require.NoError(t, err)
121+
122+
// Coordinate
115123
stream, err := client.CoordinateTailnet(ctx)
116124
require.NoError(t, err)
117125
defer stream.Close()
@@ -145,7 +153,17 @@ func TestClientService_ServeClient_V2(t *testing.T) {
145153
err = stream.Close()
146154
require.NoError(t, err)
147155

148-
// stream ^^ is just one RPC; we need to close the Conn to end the session.
156+
// DERP Map
157+
dms, err := client.StreamDERPMaps(ctx, &proto.StreamDERPMapsRequest{})
158+
require.NoError(t, err)
159+
160+
gotDermMap, err := dms.Recv()
161+
require.NoError(t, err)
162+
require.Equal(t, "test", gotDermMap.GetRegions()[999].GetRegionCode())
163+
err = dms.Close()
164+
require.NoError(t, err)
165+
166+
// RPCs closed; we need to close the Conn to end the session.
149167
err = c.Close()
150168
require.NoError(t, err)
151169
err = testutil.RequireRecvCtx(ctx, t, errCh)
@@ -159,7 +177,7 @@ func TestClientService_ServeClient_V1(t *testing.T) {
159177
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
160178
coordPtr.Store(&coord)
161179
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
162-
uut, err := tailnet.NewClientService(logger, &coordPtr)
180+
uut, err := tailnet.NewClientService(logger, &coordPtr, 0, nil)
163181
require.NoError(t, err)
164182

165183
ctx := testutil.Context(t, testutil.WaitShort)

0 commit comments

Comments
 (0)