From 9f5135cd803919c44e907750e2249834f026fdf9 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 29 Oct 2024 08:56:03 +0400 Subject: [PATCH] chore: refactor tailnetAPIConnector to use dialer --- codersdk/workspacesdk/connector.go | 333 +++++++------- .../workspacesdk/connector_internal_test.go | 419 +++--------------- codersdk/workspacesdk/connector_test.go | 350 +++++++++++++++ codersdk/workspacesdk/workspacesdk.go | 16 +- tailnet/controllers.go | 2 +- 5 files changed, 594 insertions(+), 526 deletions(-) create mode 100644 codersdk/workspacesdk/connector_test.go diff --git a/codersdk/workspacesdk/connector.go b/codersdk/workspacesdk/connector.go index c50c2b012413a..fd4e028d31866 100644 --- a/codersdk/workspacesdk/connector.go +++ b/codersdk/workspacesdk/connector.go @@ -56,32 +56,28 @@ type tailnetAPIConnector struct { logger slog.Logger - agentID uuid.UUID - coordinateURL string - clock quartz.Clock - dialOptions *websocket.DialOptions - derpCtrl tailnet.DERPController - coordCtrl tailnet.CoordinationController - telCtrl *tailnet.BasicTelemetryController + agentID uuid.UUID + clock quartz.Clock + dialer tailnet.ControlProtocolDialer + derpCtrl tailnet.DERPController + coordCtrl tailnet.CoordinationController + telCtrl *tailnet.BasicTelemetryController + tokenCtrl tailnet.ResumeTokenController - connected chan error - resumeToken *proto.RefreshResumeTokenResponse - isFirst bool - closed chan struct{} + closed chan struct{} } // Create a new tailnetAPIConnector without running it -func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uuid.UUID, coordinateURL string, clock quartz.Clock, dialOptions *websocket.DialOptions) *tailnetAPIConnector { +func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uuid.UUID, dialer tailnet.ControlProtocolDialer, clock quartz.Clock) *tailnetAPIConnector { return &tailnetAPIConnector{ - ctx: ctx, - logger: logger, - agentID: agentID, - coordinateURL: coordinateURL, - clock: clock, - dialOptions: dialOptions, - connected: make(chan error, 1), - closed: make(chan struct{}), - telCtrl: tailnet.NewBasicTelemetryController(logger), + ctx: ctx, + logger: logger, + agentID: agentID, + clock: clock, + dialer: dialer, + closed: make(chan struct{}), + telCtrl: tailnet.NewBasicTelemetryController(logger), + tokenCtrl: tailnet.NewBasicResumeTokenController(logger, clock), } } @@ -105,17 +101,25 @@ func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) { tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background()) go tac.manageGracefulTimeout() go func() { - tac.isFirst = true defer close(tac.closed) // Sadly retry doesn't support quartz.Clock yet so this is not // influenced by the configured clock. for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(tac.ctx); { - tailnetClient, err := tac.dial() + tailnetClients, err := tac.dialer.Dial(tac.ctx, tac.tokenCtrl) if err != nil { + if xerrors.Is(err, context.Canceled) { + continue + } + errF := slog.Error(err) + var sdkErr *codersdk.Error + if xerrors.As(err, &sdkErr) { + errF = slog.Error(sdkErr) + } + tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", errF) continue } tac.logger.Debug(tac.ctx, "obtained tailnet API v2+ client") - tac.runConnectorOnce(tailnetClient) + tac.runConnectorOnce(tailnetClients) tac.logger.Debug(tac.ctx, "tailnet API v2+ connection lost") } }() @@ -127,144 +131,68 @@ var permanentErrorStatuses = []int{ http.StatusNotFound, // returned if user doesn't have permission or agent doesn't exist } -func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) { - tac.logger.Debug(tac.ctx, "dialing Coder tailnet v2+ API") - - u, err := url.Parse(tac.coordinateURL) - if err != nil { - return nil, xerrors.Errorf("parse URL %q: %w", tac.coordinateURL, err) - } - if tac.resumeToken != nil { - q := u.Query() - q.Set("resume_token", tac.resumeToken.Token) - u.RawQuery = q.Encode() - tac.logger.Debug(tac.ctx, "using resume token", slog.F("resume_token", tac.resumeToken)) - } - - coordinateURL := u.String() - tac.logger.Debug(tac.ctx, "using coordinate URL", slog.F("url", coordinateURL)) - - // nolint:bodyclose - ws, res, err := websocket.Dial(tac.ctx, coordinateURL, tac.dialOptions) - if tac.isFirst { - if res != nil && slices.Contains(permanentErrorStatuses, res.StatusCode) { - err = codersdk.ReadBodyAsError(res) - // A bit more human-readable help in the case the API version was rejected - var sdkErr *codersdk.Error - if xerrors.As(err, &sdkErr) { - if sdkErr.Message == AgentAPIMismatchMessage && - sdkErr.StatusCode() == http.StatusBadRequest { - sdkErr.Helper = fmt.Sprintf( - "Ensure your client release version (%s, different than the API version) matches the server release version", - buildinfo.Version()) - } - } - tac.connected <- err - return nil, err - } - tac.isFirst = false - close(tac.connected) - } - if err != nil { - bodyErr := codersdk.ReadBodyAsError(res) - var sdkErr *codersdk.Error - if xerrors.As(bodyErr, &sdkErr) { - for _, v := range sdkErr.Validations { - if v.Field == "resume_token" { - // Unset the resume token for the next attempt - tac.logger.Warn(tac.ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt") - tac.resumeToken = nil - return nil, err - } - } - } - if !errors.Is(err, context.Canceled) { - tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", slog.Error(err), slog.F("sdk_err", sdkErr)) - } - return nil, err - } - client, err := tailnet.NewDRPCClient( - websocket.NetConn(tac.gracefulCtx, ws, websocket.MessageBinary), - tac.logger, - ) - if err != nil { - tac.logger.Debug(tac.ctx, "failed to create DRPCClient", slog.Error(err)) - _ = ws.Close(websocket.StatusInternalError, "") - return nil, err - } - return client, err -} - // runConnectorOnce uses the provided client to coordinate and stream DERP Maps. It is combined // into one function so that a problem with one tears down the other and triggers a retry (if // appropriate). We multiplex both RPCs over the same websocket, so we want them to share the same // fate. -func (tac *tailnetAPIConnector) runConnectorOnce(client proto.DRPCTailnetClient) { +func (tac *tailnetAPIConnector) runConnectorOnce(clients tailnet.ControlProtocolClients) { defer func() { - conn := client.DRPCConn() - closeErr := conn.Close() + closeErr := clients.Closer.Close() if closeErr != nil && !xerrors.Is(closeErr, io.EOF) && !xerrors.Is(closeErr, context.Canceled) && !xerrors.Is(closeErr, context.DeadlineExceeded) { tac.logger.Error(tac.ctx, "error closing DRPC connection", slog.Error(closeErr)) - <-conn.Closed() } }() - tac.telCtrl.New(client) // synchronous, doesn't need a goroutine + tac.telCtrl.New(clients.Telemetry) // synchronous, doesn't need a goroutine refreshTokenCtx, refreshTokenCancel := context.WithCancel(tac.ctx) wg := sync.WaitGroup{} wg.Add(3) go func() { defer wg.Done() - tac.coordinate(client) + tac.coordinate(clients.Coordinator) }() go func() { defer wg.Done() defer refreshTokenCancel() - dErr := tac.derpMap(client) + dErr := tac.derpMap(clients.DERP) if dErr != nil && tac.ctx.Err() == nil { // The main context is still active, meaning that we want the tailnet data plane to stay // up, even though we hit some error getting DERP maps on the control plane. That means // we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just // close the underlying connection. This will trigger a retry of the control plane in // run(). - client.DRPCConn().Close() + clients.Closer.Close() // Note that derpMap() logs it own errors, we don't bother here. } }() go func() { defer wg.Done() - tac.refreshToken(refreshTokenCtx, client) + tac.refreshToken(refreshTokenCtx, clients.ResumeToken) }() wg.Wait() } -func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) { - // we use the gracefulCtx here so that we'll have time to send the graceful disconnect - coord, err := client.Coordinate(tac.gracefulCtx) - if err != nil { - tac.logger.Error(tac.ctx, "failed to connect to Coordinate RPC", slog.Error(err)) - return - } +func (tac *tailnetAPIConnector) coordinate(client tailnet.CoordinatorClient) { defer func() { - cErr := coord.Close() + cErr := client.Close() if cErr != nil { tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr)) } }() - coordination := tac.coordCtrl.New(coord) + coordination := tac.coordCtrl.New(client) tac.logger.Debug(tac.ctx, "serving coordinator") select { case <-tac.ctx.Done(): tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect") crdErr := coordination.Close(tac.gracefulCtx) if crdErr != nil { - tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err)) + tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(crdErr)) } - case err = <-coordination.Wait(): + case err := <-coordination.Wait(): if err != nil && !xerrors.Is(err, io.EOF) && !xerrors.Is(err, context.Canceled) && @@ -274,65 +202,162 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) { } } -func (tac *tailnetAPIConnector) derpMap(client proto.DRPCTailnetClient) error { - s := &tailnet.DERPFromDRPCWrapper{} - var err error - s.Client, err = client.StreamDERPMaps(tac.ctx, &proto.StreamDERPMapsRequest{}) - if err != nil { - return xerrors.Errorf("failed to connect to StreamDERPMaps RPC: %w", err) - } +func (tac *tailnetAPIConnector) derpMap(client tailnet.DERPClient) error { defer func() { - cErr := s.Close() + cErr := client.Close() if cErr != nil { tac.logger.Debug(tac.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr)) } }() - cw := tac.derpCtrl.New(s) - err = <-cw.Wait() - if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) { + cw := tac.derpCtrl.New(client) + select { + case <-tac.ctx.Done(): + cErr := client.Close() + if cErr != nil { + tac.logger.Warn(tac.ctx, "failed to close StreamDERPMaps RPC", slog.Error(cErr)) + } return nil + case err := <-cw.Wait(): + if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) { + return nil + } + if err != nil && !xerrors.Is(err, io.EOF) { + tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err)) + } + return err } - if err != nil && !xerrors.Is(err, io.EOF) { - tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err)) +} + +func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client tailnet.ResumeTokenClient) { + cw := tac.tokenCtrl.New(client) + go func() { + <-ctx.Done() + cErr := cw.Close(tac.ctx) + if cErr != nil { + tac.logger.Error(tac.ctx, "error closing token refresher", slog.Error(cErr)) + } + }() + + err := <-cw.Wait() + if err != nil && !xerrors.Is(err, context.Canceled) && !xerrors.Is(err, context.DeadlineExceeded) { + tac.logger.Error(tac.ctx, "error receiving refresh token", slog.Error(err)) } - return err } -func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client proto.DRPCTailnetClient) { - ticker := tac.clock.NewTicker(15*time.Second, "tailnetAPIConnector", "refreshToken") - defer ticker.Stop() +func (tac *tailnetAPIConnector) SendTelemetryEvent(event *proto.TelemetryEvent) { + tac.telCtrl.SendTelemetryEvent(event) +} + +type WebsocketDialer struct { + logger slog.Logger + dialOptions *websocket.DialOptions + url *url.URL + resumeTokenFailed bool + connected chan error + isFirst bool +} + +func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenController, +) ( + tailnet.ControlProtocolClients, error, +) { + w.logger.Debug(ctx, "dialing Coder tailnet v2+ API") - initialCh := make(chan struct{}, 1) - initialCh <- struct{}{} - defer close(initialCh) - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - case <-initialCh: + u := new(url.URL) + *u = *w.url + if r != nil && !w.resumeTokenFailed { + if token, ok := r.Token(); ok { + q := u.Query() + q.Set("resume_token", token) + u.RawQuery = q.Encode() + w.logger.Debug(ctx, "using resume token on dial") } + } - attemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - res, err := client.RefreshResumeToken(attemptCtx, &proto.RefreshResumeTokenRequest{}) - cancel() - if err != nil { - if ctx.Err() == nil { - tac.logger.Error(tac.ctx, "error refreshing coordinator resume token", slog.Error(err)) + // nolint:bodyclose + ws, res, err := websocket.Dial(ctx, u.String(), w.dialOptions) + if w.isFirst { + if res != nil && slices.Contains(permanentErrorStatuses, res.StatusCode) { + err = codersdk.ReadBodyAsError(res) + // A bit more human-readable help in the case the API version was rejected + var sdkErr *codersdk.Error + if xerrors.As(err, &sdkErr) { + if sdkErr.Message == AgentAPIMismatchMessage && + sdkErr.StatusCode() == http.StatusBadRequest { + sdkErr.Helper = fmt.Sprintf( + "Ensure your client release version (%s, different than the API version) matches the server release version", + buildinfo.Version()) + } } - return + w.connected <- err + return tailnet.ControlProtocolClients{}, err } - tac.logger.Debug(tac.ctx, "refreshed coordinator resume token", slog.F("resume_token", res)) - tac.resumeToken = res - dur := res.RefreshIn.AsDuration() - if dur <= 0 { - // A sensible delay to refresh again. - dur = 30 * time.Minute + w.isFirst = false + close(w.connected) + } + if err != nil { + bodyErr := codersdk.ReadBodyAsError(res) + var sdkErr *codersdk.Error + if xerrors.As(bodyErr, &sdkErr) { + for _, v := range sdkErr.Validations { + if v.Field == "resume_token" { + // Unset the resume token for the next attempt + w.logger.Warn(ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt") + w.resumeTokenFailed = true + return tailnet.ControlProtocolClients{}, err + } + } + } + if !errors.Is(err, context.Canceled) { + w.logger.Error(ctx, "failed to dial tailnet v2+ API", slog.Error(err), slog.F("sdk_err", sdkErr)) } - ticker.Reset(dur, "tailnetAPIConnector", "refreshToken", "reset") + return tailnet.ControlProtocolClients{}, err + } + w.resumeTokenFailed = false + + client, err := tailnet.NewDRPCClient( + websocket.NetConn(context.Background(), ws, websocket.MessageBinary), + w.logger, + ) + if err != nil { + w.logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err)) + _ = ws.Close(websocket.StatusInternalError, "") + return tailnet.ControlProtocolClients{}, err } + coord, err := client.Coordinate(context.Background()) + if err != nil { + w.logger.Debug(ctx, "failed to create Coordinate RPC", slog.Error(err)) + _ = ws.Close(websocket.StatusInternalError, "") + return tailnet.ControlProtocolClients{}, err + } + + derps := &tailnet.DERPFromDRPCWrapper{} + derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{}) + if err != nil { + w.logger.Debug(ctx, "failed to create DERPMap stream", slog.Error(err)) + _ = ws.Close(websocket.StatusInternalError, "") + return tailnet.ControlProtocolClients{}, err + } + + return tailnet.ControlProtocolClients{ + Closer: client.DRPCConn(), + Coordinator: coord, + DERP: derps, + ResumeToken: client, + Telemetry: client, + }, nil } -func (tac *tailnetAPIConnector) SendTelemetryEvent(event *proto.TelemetryEvent) { - tac.telCtrl.SendTelemetryEvent(event) +func (w *WebsocketDialer) Connected() <-chan error { + return w.connected +} + +func NewWebsocketDialer(logger slog.Logger, u *url.URL, opts *websocket.DialOptions) *WebsocketDialer { + return &WebsocketDialer{ + logger: logger, + dialOptions: opts, + url: u, + connected: make(chan error, 1), + isFirst: true, + } } diff --git a/codersdk/workspacesdk/connector_internal_test.go b/codersdk/workspacesdk/connector_internal_test.go index 88b857320cdb1..2d66d105e066d 100644 --- a/codersdk/workspacesdk/connector_internal_test.go +++ b/codersdk/workspacesdk/connector_internal_test.go @@ -3,8 +3,7 @@ package workspacesdk import ( "context" "io" - "net/http" - "net/http/httptest" + "net" "sync/atomic" "testing" "time" @@ -13,16 +12,10 @@ import ( "github.com/hashicorp/yamux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "nhooyr.io/websocket" - "storj.io/drpc" "tailscale.com/tailcfg" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/v2/apiversion" - "github.com/coder/coder/v2/coderd/httpapi" - "github.com/coder/coder/v2/coderd/jwtutils" - "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/tailnet/tailnettest" @@ -63,32 +56,27 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) { }) require.NoError(t, err) - svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - sws, err := websocket.Accept(w, r, nil) - if !assert.NoError(t, err) { - return - } - ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary) - err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{ + dialer := &pipeDialer{ + ctx: testCtx, + logger: logger, + t: t, + svc: svc, + streamID: tailnet.StreamID{ Name: "client", ID: clientID, Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID}, - }) - assert.NoError(t, err) - })) + }, + } fConn := newFakeTailnetConn() - uut := newTailnetAPIConnector(ctx, logger.Named("tac"), agentID, svr.URL, - quartz.NewReal(), &websocket.DialOptions{}) + uut := newTailnetAPIConnector(ctx, logger.Named("tac"), agentID, dialer, quartz.NewReal()) uut.runConnector(fConn) call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls) reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs) require.NotNil(t, reqTun.AddTunnel) - _ = testutil.RequireRecvCtx(ctx, t, uut.connected) - // simulate a problem with DERPMaps by sending nil testutil.RequireSendCtx(ctx, t, derpMapCh, nil) @@ -109,259 +97,6 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) { close(call.Resps) } -func TestTailnetAPIConnector_UplevelVersion(t *testing.T) { - t.Parallel() - ctx := testutil.Context(t, testutil.WaitShort) - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - agentID := uuid.UUID{0x55} - - svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - sVer := apiversion.New(proto.CurrentMajor, proto.CurrentMinor-1) - - // the following matches what Coderd does; - // c.f. coderd/workspaceagents.go: workspaceAgentClientCoordinate - cVer := r.URL.Query().Get("version") - if err := sVer.Validate(cVer); err != nil { - httpapi.Write(ctx, w, http.StatusBadRequest, codersdk.Response{ - Message: AgentAPIMismatchMessage, - Validations: []codersdk.ValidationError{ - {Field: "version", Detail: err.Error()}, - }, - }) - return - } - })) - - fConn := newFakeTailnetConn() - - uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{}) - uut.runConnector(fConn) - - err := testutil.RequireRecvCtx(ctx, t, uut.connected) - var sdkErr *codersdk.Error - require.ErrorAs(t, err, &sdkErr) - require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) - require.Equal(t, AgentAPIMismatchMessage, sdkErr.Message) - require.NotEmpty(t, sdkErr.Helper) -} - -func TestTailnetAPIConnector_ResumeToken(t *testing.T) { - t.Parallel() - ctx := testutil.Context(t, testutil.WaitShort) - logger := slogtest.Make(t, &slogtest.Options{ - IgnoreErrors: true, - }).Leveled(slog.LevelDebug) - agentID := uuid.UUID{0x55} - fCoord := tailnettest.NewFakeCoordinator() - var coord tailnet.Coordinator = fCoord - coordPtr := atomic.Pointer[tailnet.Coordinator]{} - coordPtr.Store(&coord) - derpMapCh := make(chan *tailcfg.DERPMap) - defer close(derpMapCh) - - clock := quartz.NewMock(t) - resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey() - require.NoError(t, err) - mgr := jwtutils.StaticKey{ - ID: "123", - Key: resumeTokenSigningKey[:], - } - resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour) - svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ - Logger: logger, - CoordPtr: &coordPtr, - DERPMapUpdateFrequency: time.Millisecond, - DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh }, - NetworkTelemetryHandler: func([]*proto.TelemetryEvent) {}, - ResumeTokenProvider: resumeTokenProvider, - }) - require.NoError(t, err) - - var ( - websocketConnCh = make(chan *websocket.Conn, 64) - expectResumeToken = "" - ) - svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Accept a resume_token query parameter to use the same peer ID. This - // behavior matches the actual client coordinate route. - var ( - peerID = uuid.New() - resumeToken = r.URL.Query().Get("resume_token") - ) - t.Logf("received resume token: %s", resumeToken) - assert.Equal(t, expectResumeToken, resumeToken) - if resumeToken != "" { - peerID, err = resumeTokenProvider.VerifyResumeToken(ctx, resumeToken) - assert.NoError(t, err, "failed to parse resume token") - if err != nil { - httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{ - Message: CoordinateAPIInvalidResumeToken, - Detail: err.Error(), - Validations: []codersdk.ValidationError{ - {Field: "resume_token", Detail: CoordinateAPIInvalidResumeToken}, - }, - }) - return - } - } - - sws, err := websocket.Accept(w, r, nil) - if !assert.NoError(t, err) { - return - } - testutil.RequireSendCtx(ctx, t, websocketConnCh, sws) - ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary) - err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{ - Name: "client", - ID: peerID, - Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID}, - }) - assert.NoError(t, err) - })) - - fConn := newFakeTailnetConn() - - newTickerTrap := clock.Trap().NewTicker("tailnetAPIConnector", "refreshToken") - tickerResetTrap := clock.Trap().TickerReset("tailnetAPIConnector", "refreshToken", "reset") - defer newTickerTrap.Close() - uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, clock, &websocket.DialOptions{}) - uut.runConnector(fConn) - - // Fetch first token. We don't need to advance the clock since we use a - // channel with a single item to immediately fetch. - newTickerTrap.MustWait(ctx).Release() - // We call ticker.Reset after each token fetch to apply the refresh duration - // requested by the server. - trappedReset := tickerResetTrap.MustWait(ctx) - trappedReset.Release() - require.NotNil(t, uut.resumeToken) - originalResumeToken := uut.resumeToken.Token - - // Fetch second token. - waiter := clock.Advance(trappedReset.Duration) - waiter.MustWait(ctx) - trappedReset = tickerResetTrap.MustWait(ctx) - trappedReset.Release() - require.NotNil(t, uut.resumeToken) - require.NotEqual(t, originalResumeToken, uut.resumeToken.Token) - expectResumeToken = uut.resumeToken.Token - t.Logf("expecting resume token: %s", expectResumeToken) - - // Sever the connection and expect it to reconnect with the resume token. - wsConn := testutil.RequireRecvCtx(ctx, t, websocketConnCh) - _ = wsConn.Close(websocket.StatusGoingAway, "test") - - // Wait for the resume token to be refreshed. - trappedTicker := newTickerTrap.MustWait(ctx) - // Advance the clock slightly to ensure the new JWT is different. - clock.Advance(time.Second).MustWait(ctx) - trappedTicker.Release() - trappedReset = tickerResetTrap.MustWait(ctx) - trappedReset.Release() - - // The resume token should have changed again. - require.NotNil(t, uut.resumeToken) - require.NotEqual(t, expectResumeToken, uut.resumeToken.Token) -} - -func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) { - t.Parallel() - ctx := testutil.Context(t, testutil.WaitShort) - logger := slogtest.Make(t, &slogtest.Options{ - IgnoreErrors: true, - }).Leveled(slog.LevelDebug) - agentID := uuid.UUID{0x55} - fCoord := tailnettest.NewFakeCoordinator() - var coord tailnet.Coordinator = fCoord - coordPtr := atomic.Pointer[tailnet.Coordinator]{} - coordPtr.Store(&coord) - derpMapCh := make(chan *tailcfg.DERPMap) - defer close(derpMapCh) - - clock := quartz.NewMock(t) - resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey() - require.NoError(t, err) - mgr := jwtutils.StaticKey{ - ID: uuid.New().String(), - Key: resumeTokenSigningKey[:], - } - resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour) - svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ - Logger: logger, - CoordPtr: &coordPtr, - DERPMapUpdateFrequency: time.Millisecond, - DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh }, - NetworkTelemetryHandler: func(_ []*proto.TelemetryEvent) {}, - ResumeTokenProvider: resumeTokenProvider, - }) - require.NoError(t, err) - - var ( - websocketConnCh = make(chan *websocket.Conn, 64) - didFail int64 - ) - svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Query().Get("resume_token") != "" { - atomic.AddInt64(&didFail, 1) - httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{ - Message: CoordinateAPIInvalidResumeToken, - Validations: []codersdk.ValidationError{ - {Field: "resume_token", Detail: CoordinateAPIInvalidResumeToken}, - }, - }) - return - } - - sws, err := websocket.Accept(w, r, nil) - if !assert.NoError(t, err) { - return - } - testutil.RequireSendCtx(ctx, t, websocketConnCh, sws) - ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary) - err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{ - Name: "client", - ID: uuid.New(), - Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID}, - }) - assert.NoError(t, err) - })) - - fConn := newFakeTailnetConn() - - newTickerTrap := clock.Trap().NewTicker("tailnetAPIConnector", "refreshToken") - tickerResetTrap := clock.Trap().TickerReset("tailnetAPIConnector", "refreshToken", "reset") - defer newTickerTrap.Close() - uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, clock, &websocket.DialOptions{}) - uut.runConnector(fConn) - - // Wait for the resume token to be fetched for the first time. - newTickerTrap.MustWait(ctx).Release() - trappedReset := tickerResetTrap.MustWait(ctx) - trappedReset.Release() - originalResumeToken := uut.resumeToken.Token - - // Sever the connection and expect it to reconnect with the resume token, - // which should fail and cause the client to be disconnected. The client - // should then reconnect with no resume token. - wsConn := testutil.RequireRecvCtx(ctx, t, websocketConnCh) - _ = wsConn.Close(websocket.StatusGoingAway, "test") - - // Wait for the resume token to be refreshed, which indicates a successful - // reconnect. - trappedTicker := newTickerTrap.MustWait(ctx) - // Since we failed the initial reconnect and we're definitely reconnected - // now, the stored resume token should now be nil. - require.Nil(t, uut.resumeToken) - trappedTicker.Release() - trappedReset = tickerResetTrap.MustWait(ctx) - trappedReset.Release() - require.NotNil(t, uut.resumeToken) - require.NotEqual(t, originalResumeToken, uut.resumeToken.Token) - - // The resume token should have been rejected by the server. - require.EqualValues(t, 1, atomic.LoadInt64(&didFail)) -} - func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) @@ -392,23 +127,21 @@ func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) { }) require.NoError(t, err) - svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - sws, err := websocket.Accept(w, r, nil) - if !assert.NoError(t, err) { - return - } - ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary) - err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{ + dialer := &pipeDialer{ + ctx: ctx, + logger: logger, + t: t, + svc: svc, + streamID: tailnet.StreamID{ Name: "client", ID: clientID, Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID}, - }) - assert.NoError(t, err) - })) + }, + } fConn := newFakeTailnetConn() - uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{}) + uut := newTailnetAPIConnector(ctx, logger, agentID, dialer, quartz.NewReal()) uut.runConnector(fConn) // Coordinate calls happen _after_ telemetry is connected up, so we use this // to ensure telemetry is connected before sending our event @@ -444,82 +177,42 @@ func newFakeTailnetConn() *fakeTailnetConn { return &fakeTailnetConn{} } -type fakeDRPCConn struct{} - -var _ drpc.Conn = &fakeDRPCConn{} - -// Close implements drpc.Conn. -func (*fakeDRPCConn) Close() error { - return nil -} - -// Closed implements drpc.Conn. -func (*fakeDRPCConn) Closed() <-chan struct{} { - return nil -} - -// Invoke implements drpc.Conn. -func (*fakeDRPCConn) Invoke(_ context.Context, _ string, _ drpc.Encoding, _ drpc.Message, _ drpc.Message) error { - return nil -} - -// NewStream implements drpc.Conn. -func (*fakeDRPCConn) NewStream(_ context.Context, _ string, _ drpc.Encoding) (drpc.Stream, error) { - return nil, nil -} - -type fakeDRPCStream struct { - ch chan struct{} -} - -var _ proto.DRPCTailnet_CoordinateClient = &fakeDRPCStream{} - -// Close implements proto.DRPCTailnet_CoordinateClient. -func (f *fakeDRPCStream) Close() error { - close(f.ch) - return nil -} - -// CloseSend implements proto.DRPCTailnet_CoordinateClient. -func (*fakeDRPCStream) CloseSend() error { - return nil -} - -// Context implements proto.DRPCTailnet_CoordinateClient. -func (*fakeDRPCStream) Context() context.Context { - return nil -} - -// MsgRecv implements proto.DRPCTailnet_CoordinateClient. -func (*fakeDRPCStream) MsgRecv(_ drpc.Message, _ drpc.Encoding) error { - return nil -} - -// MsgSend implements proto.DRPCTailnet_CoordinateClient. -func (*fakeDRPCStream) MsgSend(_ drpc.Message, _ drpc.Encoding) error { - return nil -} - -// Recv implements proto.DRPCTailnet_CoordinateClient. -func (f *fakeDRPCStream) Recv() (*proto.CoordinateResponse, error) { - <-f.ch - return &proto.CoordinateResponse{}, nil -} - -// Send implements proto.DRPCTailnet_CoordinateClient. -func (f *fakeDRPCStream) Send(*proto.CoordinateRequest) error { - <-f.ch - return nil -} - -type fakeDRPPCMapStream struct { - fakeDRPCStream -} - -var _ proto.DRPCTailnet_StreamDERPMapsClient = &fakeDRPPCMapStream{} +type pipeDialer struct { + ctx context.Context + logger slog.Logger + t testing.TB + svc *tailnet.ClientService + streamID tailnet.StreamID +} + +func (p *pipeDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) { + s, c := net.Pipe() + go func() { + err := p.svc.ServeConnV2(p.ctx, s, p.streamID) + p.logger.Debug(p.ctx, "piped tailnet service complete", slog.Error(err)) + }() + client, err := tailnet.NewDRPCClient(c, p.logger) + if !assert.NoError(p.t, err) { + _ = c.Close() + return tailnet.ControlProtocolClients{}, err + } + coord, err := client.Coordinate(context.Background()) + if !assert.NoError(p.t, err) { + _ = c.Close() + return tailnet.ControlProtocolClients{}, err + } -// Recv implements proto.DRPCTailnet_StreamDERPMapsClient. -func (f *fakeDRPPCMapStream) Recv() (*proto.DERPMap, error) { - <-f.fakeDRPCStream.ch - return &proto.DERPMap{}, nil + derps := &tailnet.DERPFromDRPCWrapper{} + derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{}) + if !assert.NoError(p.t, err) { + _ = c.Close() + return tailnet.ControlProtocolClients{}, err + } + return tailnet.ControlProtocolClients{ + Closer: client.DRPCConn(), + Coordinator: coord, + DERP: derps, + ResumeToken: client, + Telemetry: client, + }, nil } diff --git a/codersdk/workspacesdk/connector_test.go b/codersdk/workspacesdk/connector_test.go new file mode 100644 index 0000000000000..5247d5c7834da --- /dev/null +++ b/codersdk/workspacesdk/connector_test.go @@ -0,0 +1,350 @@ +package workspacesdk_test + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "nhooyr.io/websocket" + "tailscale.com/tailcfg" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/apiversion" + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" + "github.com/coder/coder/v2/tailnet/tailnettest" + "github.com/coder/coder/v2/testutil" +) + +func TestWebsocketDialer_TokenController(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) + logger := slogtest.Make(t, &slogtest.Options{ + IgnoreErrors: true, + }).Leveled(slog.LevelDebug) + + fTokenProv := newFakeTokenController(ctx, t) + fCoord := tailnettest.NewFakeCoordinator() + var coord tailnet.Coordinator = fCoord + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&coord) + + svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ + Logger: logger, + CoordPtr: &coordPtr, + DERPMapUpdateFrequency: time.Hour, + DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} }, + }) + require.NoError(t, err) + + dialTokens := make(chan string, 1) + wsErr := make(chan error, 1) + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case <-ctx.Done(): + t.Error("timed out sending token") + case dialTokens <- r.URL.Query().Get("resume_token"): + // OK + } + + sws, err := websocket.Accept(w, r, nil) + if !assert.NoError(t, err) { + return + } + wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary) + // streamID can be empty because we don't call RPCs in this test. + wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{}) + })) + defer svr.Close() + svrURL, err := url.Parse(svr.URL) + require.NoError(t, err) + + uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{}) + + clientCh := make(chan tailnet.ControlProtocolClients, 1) + go func() { + clients, err := uut.Dial(ctx, fTokenProv) + assert.NoError(t, err) + clientCh <- clients + }() + + call := testutil.RequireRecvCtx(ctx, t, fTokenProv.tokenCalls) + call <- tokenResponse{"test token", true} + gotToken := <-dialTokens + require.Equal(t, "test token", gotToken) + + clients := testutil.RequireRecvCtx(ctx, t, clientCh) + clients.Closer.Close() + + err = testutil.RequireRecvCtx(ctx, t, wsErr) + require.NoError(t, err) + + clientCh = make(chan tailnet.ControlProtocolClients, 1) + go func() { + clients, err := uut.Dial(ctx, fTokenProv) + assert.NoError(t, err) + clientCh <- clients + }() + + call = testutil.RequireRecvCtx(ctx, t, fTokenProv.tokenCalls) + call <- tokenResponse{"test token", false} + gotToken = <-dialTokens + require.Equal(t, "", gotToken) + + clients = testutil.RequireRecvCtx(ctx, t, clientCh) + clients.Closer.Close() + + err = testutil.RequireRecvCtx(ctx, t, wsErr) + require.NoError(t, err) +} + +func TestWebsocketDialer_NoTokenController(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, &slogtest.Options{ + IgnoreErrors: true, + }).Leveled(slog.LevelDebug) + + fCoord := tailnettest.NewFakeCoordinator() + var coord tailnet.Coordinator = fCoord + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&coord) + + svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ + Logger: logger, + CoordPtr: &coordPtr, + DERPMapUpdateFrequency: time.Hour, + DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} }, + }) + require.NoError(t, err) + + dialTokens := make(chan string, 1) + wsErr := make(chan error, 1) + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case <-ctx.Done(): + t.Error("timed out sending token") + case dialTokens <- r.URL.Query().Get("resume_token"): + // OK + } + + sws, err := websocket.Accept(w, r, nil) + if !assert.NoError(t, err) { + return + } + wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary) + // streamID can be empty because we don't call RPCs in this test. + wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{}) + })) + defer svr.Close() + svrURL, err := url.Parse(svr.URL) + require.NoError(t, err) + + uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{}) + + clientCh := make(chan tailnet.ControlProtocolClients, 1) + go func() { + clients, err := uut.Dial(ctx, nil) + assert.NoError(t, err) + clientCh <- clients + }() + + gotToken := <-dialTokens + require.Equal(t, "", gotToken) + + clients := testutil.RequireRecvCtx(ctx, t, clientCh) + clients.Closer.Close() + + err = testutil.RequireRecvCtx(ctx, t, wsErr) + require.NoError(t, err) +} + +func TestWebsocketDialer_ResumeTokenFailure(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, &slogtest.Options{ + IgnoreErrors: true, + }).Leveled(slog.LevelDebug) + + fTokenProv := newFakeTokenController(ctx, t) + fCoord := tailnettest.NewFakeCoordinator() + var coord tailnet.Coordinator = fCoord + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&coord) + + svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ + Logger: logger, + CoordPtr: &coordPtr, + DERPMapUpdateFrequency: time.Hour, + DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} }, + }) + require.NoError(t, err) + + dialTokens := make(chan string, 1) + wsErr := make(chan error, 1) + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resumeToken := r.URL.Query().Get("resume_token") + select { + case <-ctx.Done(): + t.Error("timed out sending token") + case dialTokens <- resumeToken: + // OK + } + + if resumeToken != "" { + httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{ + Message: workspacesdk.CoordinateAPIInvalidResumeToken, + Validations: []codersdk.ValidationError{ + {Field: "resume_token", Detail: workspacesdk.CoordinateAPIInvalidResumeToken}, + }, + }) + return + } + sws, err := websocket.Accept(w, r, nil) + if !assert.NoError(t, err) { + return + } + wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary) + // streamID can be empty because we don't call RPCs in this test. + wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{}) + })) + defer svr.Close() + svrURL, err := url.Parse(svr.URL) + require.NoError(t, err) + + uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{}) + + errCh := make(chan error, 1) + go func() { + _, err := uut.Dial(ctx, fTokenProv) + errCh <- err + }() + + call := testutil.RequireRecvCtx(ctx, t, fTokenProv.tokenCalls) + call <- tokenResponse{"test token", true} + gotToken := <-dialTokens + require.Equal(t, "test token", gotToken) + + err = testutil.RequireRecvCtx(ctx, t, errCh) + require.Error(t, err) + + // redial should not use the token + clientCh := make(chan tailnet.ControlProtocolClients, 1) + go func() { + clients, err := uut.Dial(ctx, fTokenProv) + assert.NoError(t, err) + clientCh <- clients + }() + gotToken = <-dialTokens + require.Equal(t, "", gotToken) + + clients := testutil.RequireRecvCtx(ctx, t, clientCh) + require.Error(t, err) + clients.Closer.Close() + err = testutil.RequireRecvCtx(ctx, t, wsErr) + require.NoError(t, err) + + // Successful dial should reset to using token again + go func() { + _, err := uut.Dial(ctx, fTokenProv) + errCh <- err + }() + call = testutil.RequireRecvCtx(ctx, t, fTokenProv.tokenCalls) + call <- tokenResponse{"test token", true} + gotToken = <-dialTokens + require.Equal(t, "test token", gotToken) + err = testutil.RequireRecvCtx(ctx, t, errCh) + require.Error(t, err) +} + +func TestWebsocketDialer_UplevelVersion(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sVer := apiversion.New(proto.CurrentMajor, proto.CurrentMinor-1) + + // the following matches what Coderd does; + // c.f. coderd/workspaceagents.go: workspaceAgentClientCoordinate + cVer := r.URL.Query().Get("version") + if err := sVer.Validate(cVer); err != nil { + httpapi.Write(ctx, w, http.StatusBadRequest, codersdk.Response{ + Message: workspacesdk.AgentAPIMismatchMessage, + Validations: []codersdk.ValidationError{ + {Field: "version", Detail: err.Error()}, + }, + }) + return + } + })) + svrURL, err := url.Parse(svr.URL) + require.NoError(t, err) + + uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{}) + + errCh := make(chan error, 1) + go func() { + _, err := uut.Dial(ctx, nil) + errCh <- err + }() + + err = testutil.RequireRecvCtx(ctx, t, errCh) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + require.Equal(t, workspacesdk.AgentAPIMismatchMessage, sdkErr.Message) + require.NotEmpty(t, sdkErr.Helper) +} + +type fakeResumeTokenController struct { + ctx context.Context + t testing.TB + tokenCalls chan chan tokenResponse +} + +func (*fakeResumeTokenController) New(tailnet.ResumeTokenClient) tailnet.CloserWaiter { + panic("not implemented") +} + +func (f *fakeResumeTokenController) Token() (string, bool) { + call := make(chan tokenResponse) + select { + case <-f.ctx.Done(): + f.t.Error("timeout on Token() call") + case f.tokenCalls <- call: + // OK + } + select { + case <-f.ctx.Done(): + f.t.Error("timeout on Token() response") + return "", false + case r := <-call: + return r.token, r.ok + } +} + +var _ tailnet.ResumeTokenController = &fakeResumeTokenController{} + +func newFakeTokenController(ctx context.Context, t testing.TB) *fakeResumeTokenController { + return &fakeResumeTokenController{ + ctx: ctx, + t: t, + tokenCalls: make(chan chan tokenResponse), + } +} + +type tokenResponse struct { + token string + ok bool +} diff --git a/codersdk/workspacesdk/workspacesdk.go b/codersdk/workspacesdk/workspacesdk.go index d0983d81593d0..365a530d438aa 100644 --- a/codersdk/workspacesdk/workspacesdk.go +++ b/codersdk/workspacesdk/workspacesdk.go @@ -228,13 +228,13 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options * q.Add("version", "2.0") coordinateURL.RawQuery = q.Encode() - connector := newTailnetAPIConnector(ctx, options.Logger, agentID, coordinateURL.String(), quartz.NewReal(), - &websocket.DialOptions{ - HTTPClient: c.client.HTTPClient, - HTTPHeader: headers, - // Need to disable compression to avoid a data-race. - CompressionMode: websocket.CompressionDisabled, - }) + dialer := NewWebsocketDialer(options.Logger, coordinateURL, &websocket.DialOptions{ + HTTPClient: c.client.HTTPClient, + HTTPHeader: headers, + // Need to disable compression to avoid a data-race. + CompressionMode: websocket.CompressionDisabled, + }) + connector := newTailnetAPIConnector(ctx, options.Logger, agentID, dialer, quartz.NewReal()) ip := tailnet.TailscaleServicePrefix.RandomAddr() var header http.Header @@ -271,7 +271,7 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options * select { case <-dialCtx.Done(): return nil, xerrors.Errorf("timed out waiting for coordinator and derp map: %w", dialCtx.Err()) - case err = <-connector.connected: + case err = <-dialer.Connected(): if err != nil { options.Logger.Error(ctx, "failed to connect to tailnet v2+ API", slog.Error(err)) return nil, xerrors.Errorf("start connector: %w", err) diff --git a/tailnet/controllers.go b/tailnet/controllers.go index 7a3e23e2e216d..3b032b4f323cf 100644 --- a/tailnet/controllers.go +++ b/tailnet/controllers.go @@ -615,7 +615,7 @@ func newBasicResumeTokenRefresher( errCh: make(chan error, 1), } r.ctx, r.cancel = context.WithCancel(context.Background()) - r.timer = clock.AfterFunc(never, r.refresh) + r.timer = clock.AfterFunc(never, r.refresh, "basicResumeTokenRefresher") go r.refresh() return r }