Skip to content

Commit 63bc699

Browse files
committed
feat: changes codersdk to use tailnet v2 for DERPMap updates
1 parent 924c97f commit 63bc699

File tree

1 file changed

+186
-136
lines changed

1 file changed

+186
-136
lines changed

codersdk/workspaceagents.go

+186-136
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import (
1414
"strings"
1515
"time"
1616

17+
"golang.org/x/sync/errgroup"
18+
1719
"github.com/google/uuid"
1820
"golang.org/x/xerrors"
1921
"nhooyr.io/websocket"
@@ -317,142 +319,28 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
317319
q := coordinateURL.Query()
318320
q.Add("version", proto.CurrentVersion.String())
319321
coordinateURL.RawQuery = q.Encode()
320-
closedCoordinator := make(chan struct{})
321-
// Must only ever be used once, send error OR close to avoid
322-
// reassignment race. Buffered so we don't hang in goroutine.
323-
firstCoordinator := make(chan error, 1)
324-
go func() {
325-
defer close(closedCoordinator)
326-
isFirst := true
327-
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
328-
options.Logger.Debug(ctx, "connecting")
329-
// nolint:bodyclose
330-
ws, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
331-
HTTPClient: c.HTTPClient,
332-
HTTPHeader: headers,
333-
// Need to disable compression to avoid a data-race.
334-
CompressionMode: websocket.CompressionDisabled,
335-
})
336-
if isFirst {
337-
if res != nil && res.StatusCode == http.StatusConflict {
338-
firstCoordinator <- ReadBodyAsError(res)
339-
return
340-
}
341-
isFirst = false
342-
close(firstCoordinator)
343-
}
344-
if err != nil {
345-
if errors.Is(err, context.Canceled) {
346-
return
347-
}
348-
options.Logger.Debug(ctx, "failed to dial", slog.Error(err))
349-
continue
350-
}
351-
client, err := tailnet.NewDRPCClient(websocket.NetConn(ctx, ws, websocket.MessageBinary))
352-
if err != nil {
353-
options.Logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err))
354-
_ = ws.Close(websocket.StatusInternalError, "")
355-
continue
356-
}
357-
coordinate, err := client.Coordinate(ctx)
358-
if err != nil {
359-
options.Logger.Debug(ctx, "failed to reach the Coordinate endpoint", slog.Error(err))
360-
_ = ws.Close(websocket.StatusInternalError, "")
361-
continue
362-
}
363-
364-
coordination := tailnet.NewRemoteCoordination(options.Logger, coordinate, conn, agentID)
365-
options.Logger.Debug(ctx, "serving coordinator")
366-
err = <-coordination.Error()
367-
if errors.Is(err, context.Canceled) {
368-
_ = ws.Close(websocket.StatusGoingAway, "")
369-
return
370-
}
371-
if err != nil {
372-
options.Logger.Debug(ctx, "error serving coordinator", slog.Error(err))
373-
_ = ws.Close(websocket.StatusGoingAway, "")
374-
continue
375-
}
376-
_ = ws.Close(websocket.StatusGoingAway, "")
377-
}
378-
}()
379-
380-
derpMapURL, err := c.URL.Parse("/api/v2/derp-map")
381-
if err != nil {
382-
return nil, xerrors.Errorf("parse url: %w", err)
383-
}
384-
closedDerpMap := make(chan struct{})
385-
// Must only ever be used once, send error OR close to avoid
386-
// reassignment race. Buffered so we don't hang in goroutine.
387-
firstDerpMap := make(chan error, 1)
388-
go func() {
389-
defer close(closedDerpMap)
390-
isFirst := true
391-
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
392-
options.Logger.Debug(ctx, "connecting to server for derp map updates")
393-
// nolint:bodyclose
394-
ws, res, err := websocket.Dial(ctx, derpMapURL.String(), &websocket.DialOptions{
395-
HTTPClient: c.HTTPClient,
396-
HTTPHeader: headers,
397-
// Need to disable compression to avoid a data-race.
398-
CompressionMode: websocket.CompressionDisabled,
399-
})
400-
if isFirst {
401-
if res != nil && res.StatusCode == http.StatusConflict {
402-
firstDerpMap <- ReadBodyAsError(res)
403-
return
404-
}
405-
isFirst = false
406-
close(firstDerpMap)
407-
}
408-
if err != nil {
409-
if errors.Is(err, context.Canceled) {
410-
return
411-
}
412-
options.Logger.Debug(ctx, "failed to dial", slog.Error(err))
413-
continue
414-
}
415-
416-
var (
417-
nconn = websocket.NetConn(ctx, ws, websocket.MessageBinary)
418-
dec = json.NewDecoder(nconn)
419-
)
420-
for {
421-
var derpMap tailcfg.DERPMap
422-
err := dec.Decode(&derpMap)
423-
if xerrors.Is(err, context.Canceled) {
424-
_ = ws.Close(websocket.StatusGoingAway, "")
425-
return
426-
}
427-
if err != nil {
428-
options.Logger.Debug(ctx, "failed to decode derp map", slog.Error(err))
429-
_ = ws.Close(websocket.StatusGoingAway, "")
430-
return
431-
}
432-
433-
if !tailnet.CompareDERPMaps(conn.DERPMap(), &derpMap) {
434-
options.Logger.Debug(ctx, "updating derp map due to detected changes")
435-
conn.SetDERPMap(&derpMap)
436-
}
437-
}
438-
}
439-
}()
440322

441-
for firstCoordinator != nil || firstDerpMap != nil {
442-
select {
443-
case <-dialCtx.Done():
444-
return nil, xerrors.Errorf("timed out waiting for coordinator and derp map: %w", dialCtx.Err())
445-
case err = <-firstCoordinator:
446-
if err != nil {
447-
return nil, xerrors.Errorf("start coordinator: %w", err)
448-
}
449-
firstCoordinator = nil
450-
case err = <-firstDerpMap:
451-
if err != nil {
452-
return nil, xerrors.Errorf("receive derp map: %w", err)
453-
}
454-
firstDerpMap = nil
323+
connector := runTailnetAPIConnector(ctx, options.Logger,
324+
agentID, coordinateURL.String(),
325+
&websocket.DialOptions{
326+
HTTPClient: c.HTTPClient,
327+
HTTPHeader: headers,
328+
// Need to disable compression to avoid a data-race.
329+
CompressionMode: websocket.CompressionDisabled,
330+
},
331+
conn,
332+
)
333+
options.Logger.Debug(ctx, "running tailnet API v2+ connector")
334+
335+
select {
336+
case <-dialCtx.Done():
337+
return nil, xerrors.Errorf("timed out waiting for coordinator and derp map: %w", dialCtx.Err())
338+
case err = <-connector.connected:
339+
if err != nil {
340+
options.Logger.Error(ctx, "failed to connect to tailnet v2+ API", slog.Error(err))
341+
return nil, xerrors.Errorf("start connector: %w", err)
455342
}
343+
options.Logger.Debug(ctx, "connected to tailnet v2+ API")
456344
}
457345

458346
agentConn = NewWorkspaceAgentConn(conn, WorkspaceAgentConnOptions{
@@ -464,8 +352,7 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
464352
AgentIP: WorkspaceAgentIP,
465353
CloseFunc: func() error {
466354
cancel()
467-
<-closedCoordinator
468-
<-closedDerpMap
355+
<-connector.closed
469356
return conn.Close()
470357
},
471358
})
@@ -478,6 +365,169 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
478365
return agentConn, nil
479366
}
480367

368+
// tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to
369+
//
370+
// 1) run the Coordinate API and pass node information back and forth
371+
// 2) stream DERPMap updates and program the Conn
372+
//
373+
// These functions share the same websocket, and so are combined here so that if we hit a problem
374+
// we tear the whole thing down and start over with a new websocket.
375+
type tailnetAPIConnector struct {
376+
ctx context.Context
377+
logger slog.Logger
378+
379+
agentID uuid.UUID
380+
coordinateURL string
381+
dialOptions *websocket.DialOptions
382+
conn *tailnet.Conn
383+
384+
connected chan error
385+
isFirst bool
386+
closed chan struct{}
387+
}
388+
389+
// runTailnetAPIConnector creates and runs a tailnetAPIConnector
390+
func runTailnetAPIConnector(
391+
ctx context.Context, logger slog.Logger,
392+
agentID uuid.UUID, coordinateURL string, dialOptions *websocket.DialOptions,
393+
conn *tailnet.Conn,
394+
) *tailnetAPIConnector {
395+
tac := &tailnetAPIConnector{
396+
ctx: ctx,
397+
logger: logger,
398+
agentID: agentID,
399+
coordinateURL: coordinateURL,
400+
dialOptions: dialOptions,
401+
conn: conn,
402+
connected: make(chan error, 1),
403+
closed: make(chan struct{}),
404+
}
405+
go tac.run()
406+
return tac
407+
}
408+
409+
func (tac *tailnetAPIConnector) run() {
410+
tac.isFirst = true
411+
defer close(tac.closed)
412+
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(tac.ctx); {
413+
tailnetClient, err := tac.dial()
414+
if err != nil {
415+
continue
416+
}
417+
tac.logger.Debug(tac.ctx, "obtained tailnet API v2+ client")
418+
tac.coordinateAndDERPMap(tailnetClient)
419+
tac.logger.Debug(tac.ctx, "tailnet API v2+ connection lost")
420+
}
421+
}
422+
423+
func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
424+
tac.logger.Debug(tac.ctx, "dialing Coder tailnet v2+ API")
425+
// nolint:bodyclose
426+
ws, res, err := websocket.Dial(tac.ctx, tac.coordinateURL, tac.dialOptions)
427+
if tac.isFirst {
428+
if res != nil && res.StatusCode == http.StatusConflict {
429+
err = ReadBodyAsError(res)
430+
tac.connected <- err
431+
return nil, err
432+
}
433+
tac.isFirst = false
434+
close(tac.connected)
435+
}
436+
if err != nil {
437+
if !errors.Is(err, context.Canceled) {
438+
tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", slog.Error(err))
439+
}
440+
return nil, err
441+
}
442+
client, err := tailnet.NewDRPCClient(websocket.NetConn(tac.ctx, ws, websocket.MessageBinary))
443+
if err != nil {
444+
tac.logger.Debug(tac.ctx, "failed to create DRPCClient", slog.Error(err))
445+
_ = ws.Close(websocket.StatusInternalError, "")
446+
return nil, err
447+
}
448+
return client, err
449+
}
450+
451+
// coordinateAndDERPMap uses the provided client to coordinate and stream DERP Maps. It is combined
452+
// into one function so that a problem with one tears down the other and triggers a retry (if
453+
// appropriate). We multiplex both RPCs over the same websocket, so we want them to share the same
454+
// fate.
455+
func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetClient) {
456+
defer func() {
457+
conn := client.DRPCConn()
458+
closeErr := conn.Close()
459+
if closeErr != nil &&
460+
!xerrors.Is(closeErr, io.EOF) &&
461+
!xerrors.Is(closeErr, context.Canceled) &&
462+
!xerrors.Is(closeErr, context.DeadlineExceeded) {
463+
tac.logger.Error(tac.ctx, "error closing DRPC connection", slog.Error(closeErr))
464+
<-conn.Closed()
465+
}
466+
}()
467+
eg, egCtx := errgroup.WithContext(tac.ctx)
468+
eg.Go(func() error {
469+
return tac.coordinate(egCtx, client)
470+
})
471+
eg.Go(func() error {
472+
return tac.derpMap(egCtx, client)
473+
})
474+
err := eg.Wait()
475+
if err != nil &&
476+
!xerrors.Is(err, io.EOF) &&
477+
!xerrors.Is(err, context.Canceled) &&
478+
!xerrors.Is(err, context.DeadlineExceeded) {
479+
tac.logger.Error(tac.ctx, "error while connected to tailnet v2+ API")
480+
}
481+
}
482+
483+
func (tac *tailnetAPIConnector) coordinate(ctx context.Context, client proto.DRPCTailnetClient) error {
484+
coord, err := client.Coordinate(ctx)
485+
if err != nil {
486+
return xerrors.Errorf("failed to connect to Coordinate RPC: %w", err)
487+
}
488+
defer func() {
489+
cErr := coord.Close()
490+
if cErr != nil {
491+
tac.logger.Debug(ctx, "error closing Coordinate RPC", slog.Error(cErr))
492+
}
493+
}()
494+
coordination := tailnet.NewRemoteCoordination(tac.logger, coord, tac.conn, tac.agentID)
495+
tac.logger.Debug(ctx, "serving coordinator")
496+
err = <-coordination.Error()
497+
if err != nil &&
498+
!xerrors.Is(err, io.EOF) &&
499+
!xerrors.Is(err, context.Canceled) &&
500+
!xerrors.Is(err, context.DeadlineExceeded) {
501+
return xerrors.Errorf("remote coordination error: %w", err)
502+
}
503+
return nil
504+
}
505+
506+
func (tac *tailnetAPIConnector) derpMap(ctx context.Context, client proto.DRPCTailnetClient) error {
507+
s, err := client.StreamDERPMaps(ctx, &proto.StreamDERPMapsRequest{})
508+
if err != nil {
509+
return xerrors.Errorf("failed to connect to StreamDERPMaps RPC: %w", err)
510+
}
511+
defer func() {
512+
cErr := s.Close()
513+
if cErr != nil {
514+
tac.logger.Debug(ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
515+
}
516+
}()
517+
for {
518+
dmp, err := s.Recv()
519+
if err != nil {
520+
if xerrors.Is(err, io.EOF) || xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
521+
return nil
522+
}
523+
return xerrors.Errorf("error receiving DERP Map: %w", err)
524+
}
525+
tac.logger.Debug(ctx, "got new DERP Map", slog.F("derp_map", dmp))
526+
dm := tailnet.DERPMapFromProto(dmp)
527+
tac.conn.SetDERPMap(dm)
528+
}
529+
}
530+
481531
// WatchWorkspaceAgentMetadata watches the metadata of a workspace agent.
482532
// The returned channel will be closed when the context is canceled. Exactly
483533
// one error will be sent on the error channel. The metadata channel is never closed.

0 commit comments

Comments
 (0)