From abd4eddaff1b9a2f083051f4baad119921369d1e Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Mon, 4 Nov 2024 11:26:34 +0400 Subject: [PATCH] chore: refactor tailnetAPIConnector to tailnet.Controller --- codersdk/workspacesdk/connector.go | 363 ------------------ .../workspacesdk/connector_internal_test.go | 218 ----------- codersdk/workspacesdk/dialer.go | 139 +++++++ .../{connector_test.go => dialer_test.go} | 0 codersdk/workspacesdk/workspacesdk.go | 14 +- tailnet/controllers.go | 220 +++++++++++ tailnet/controllers_test.go | 201 ++++++++++ 7 files changed, 570 insertions(+), 585 deletions(-) delete mode 100644 codersdk/workspacesdk/connector.go delete mode 100644 codersdk/workspacesdk/connector_internal_test.go create mode 100644 codersdk/workspacesdk/dialer.go rename codersdk/workspacesdk/{connector_test.go => dialer_test.go} (100%) diff --git a/codersdk/workspacesdk/connector.go b/codersdk/workspacesdk/connector.go deleted file mode 100644 index fd4e028d31866..0000000000000 --- a/codersdk/workspacesdk/connector.go +++ /dev/null @@ -1,363 +0,0 @@ -package workspacesdk - -import ( - "context" - "errors" - "fmt" - "io" - "net/http" - "net/url" - "slices" - "sync" - "time" - - "github.com/google/uuid" - "golang.org/x/xerrors" - "nhooyr.io/websocket" - - "cdr.dev/slog" - "github.com/coder/coder/v2/buildinfo" - "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/tailnet" - "github.com/coder/coder/v2/tailnet/proto" - "github.com/coder/quartz" - "github.com/coder/retry" -) - -var tailnetConnectorGracefulTimeout = time.Second - -// tailnetConn is the subset of the tailnet.Conn methods that tailnetAPIConnector uses. It is -// included so that we can fake it in testing. -// -// @typescript-ignore tailnetConn -type tailnetConn interface { - tailnet.Coordinatee - tailnet.DERPMapSetter -} - -// tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to -// -// 1) run the Coordinate API and pass node information back and forth -// 2) stream DERPMap updates and program the Conn -// 3) Send network telemetry events -// -// These functions share the same websocket, and so are combined here so that if we hit a problem -// we tear the whole thing down and start over with a new websocket. -// -// @typescript-ignore tailnetAPIConnector -type tailnetAPIConnector struct { - // We keep track of two contexts: the main context from the caller, and a "graceful" context - // that we keep open slightly longer than the main context to give a chance to send the - // Disconnect message to the coordinator. That tells the coordinator that we really meant to - // disconnect instead of just losing network connectivity. 
- ctx context.Context - gracefulCtx context.Context - cancelGracefulCtx context.CancelFunc - - logger slog.Logger - - agentID uuid.UUID - clock quartz.Clock - dialer tailnet.ControlProtocolDialer - derpCtrl tailnet.DERPController - coordCtrl tailnet.CoordinationController - telCtrl *tailnet.BasicTelemetryController - tokenCtrl tailnet.ResumeTokenController - - closed chan struct{} -} - -// Create a new tailnetAPIConnector without running it -func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uuid.UUID, dialer tailnet.ControlProtocolDialer, clock quartz.Clock) *tailnetAPIConnector { - return &tailnetAPIConnector{ - ctx: ctx, - logger: logger, - agentID: agentID, - clock: clock, - dialer: dialer, - closed: make(chan struct{}), - telCtrl: tailnet.NewBasicTelemetryController(logger), - tokenCtrl: tailnet.NewBasicResumeTokenController(logger, clock), - } -} - -// manageGracefulTimeout allows the gracefulContext to last 1 second longer than the main context -// to allow a graceful disconnect. -func (tac *tailnetAPIConnector) manageGracefulTimeout() { - defer tac.cancelGracefulCtx() - <-tac.ctx.Done() - timer := tac.clock.NewTimer(tailnetConnectorGracefulTimeout, "tailnetAPIClient", "gracefulTimeout") - defer timer.Stop() - select { - case <-tac.closed: - case <-timer.C: - } -} - -// Runs a tailnetAPIConnector using the provided connection -func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) { - tac.derpCtrl = tailnet.NewBasicDERPController(tac.logger, conn) - tac.coordCtrl = tailnet.NewSingleDestController(tac.logger, conn, tac.agentID) - tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background()) - go tac.manageGracefulTimeout() - go func() { - defer close(tac.closed) - // Sadly retry doesn't support quartz.Clock yet so this is not - // influenced by the configured clock. - for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(tac.ctx); { - tailnetClients, err := tac.dialer.Dial(tac.ctx, tac.tokenCtrl) - if err != nil { - if xerrors.Is(err, context.Canceled) { - continue - } - errF := slog.Error(err) - var sdkErr *codersdk.Error - if xerrors.As(err, &sdkErr) { - errF = slog.Error(sdkErr) - } - tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", errF) - continue - } - tac.logger.Debug(tac.ctx, "obtained tailnet API v2+ client") - tac.runConnectorOnce(tailnetClients) - tac.logger.Debug(tac.ctx, "tailnet API v2+ connection lost") - } - }() -} - -var permanentErrorStatuses = []int{ - http.StatusConflict, // returned if client/agent connections disabled (browser only) - http.StatusBadRequest, // returned if API mismatch - http.StatusNotFound, // returned if user doesn't have permission or agent doesn't exist -} - -// runConnectorOnce uses the provided client to coordinate and stream DERP Maps. It is combined -// into one function so that a problem with one tears down the other and triggers a retry (if -// appropriate). We multiplex both RPCs over the same websocket, so we want them to share the same -// fate. 
-func (tac *tailnetAPIConnector) runConnectorOnce(clients tailnet.ControlProtocolClients) { - defer func() { - closeErr := clients.Closer.Close() - if closeErr != nil && - !xerrors.Is(closeErr, io.EOF) && - !xerrors.Is(closeErr, context.Canceled) && - !xerrors.Is(closeErr, context.DeadlineExceeded) { - tac.logger.Error(tac.ctx, "error closing DRPC connection", slog.Error(closeErr)) - } - }() - - tac.telCtrl.New(clients.Telemetry) // synchronous, doesn't need a goroutine - - refreshTokenCtx, refreshTokenCancel := context.WithCancel(tac.ctx) - wg := sync.WaitGroup{} - wg.Add(3) - go func() { - defer wg.Done() - tac.coordinate(clients.Coordinator) - }() - go func() { - defer wg.Done() - defer refreshTokenCancel() - dErr := tac.derpMap(clients.DERP) - if dErr != nil && tac.ctx.Err() == nil { - // The main context is still active, meaning that we want the tailnet data plane to stay - // up, even though we hit some error getting DERP maps on the control plane. That means - // we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just - // close the underlying connection. This will trigger a retry of the control plane in - // run(). - clients.Closer.Close() - // Note that derpMap() logs it own errors, we don't bother here. - } - }() - go func() { - defer wg.Done() - tac.refreshToken(refreshTokenCtx, clients.ResumeToken) - }() - wg.Wait() -} - -func (tac *tailnetAPIConnector) coordinate(client tailnet.CoordinatorClient) { - defer func() { - cErr := client.Close() - if cErr != nil { - tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr)) - } - }() - coordination := tac.coordCtrl.New(client) - tac.logger.Debug(tac.ctx, "serving coordinator") - select { - case <-tac.ctx.Done(): - tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect") - crdErr := coordination.Close(tac.gracefulCtx) - if crdErr != nil { - tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(crdErr)) - } - case err := <-coordination.Wait(): - if err != nil && - !xerrors.Is(err, io.EOF) && - !xerrors.Is(err, context.Canceled) && - !xerrors.Is(err, context.DeadlineExceeded) { - tac.logger.Error(tac.ctx, "remote coordination error", slog.Error(err)) - } - } -} - -func (tac *tailnetAPIConnector) derpMap(client tailnet.DERPClient) error { - defer func() { - cErr := client.Close() - if cErr != nil { - tac.logger.Debug(tac.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr)) - } - }() - cw := tac.derpCtrl.New(client) - select { - case <-tac.ctx.Done(): - cErr := client.Close() - if cErr != nil { - tac.logger.Warn(tac.ctx, "failed to close StreamDERPMaps RPC", slog.Error(cErr)) - } - return nil - case err := <-cw.Wait(): - if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) { - return nil - } - if err != nil && !xerrors.Is(err, io.EOF) { - tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err)) - } - return err - } -} - -func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client tailnet.ResumeTokenClient) { - cw := tac.tokenCtrl.New(client) - go func() { - <-ctx.Done() - cErr := cw.Close(tac.ctx) - if cErr != nil { - tac.logger.Error(tac.ctx, "error closing token refresher", slog.Error(cErr)) - } - }() - - err := <-cw.Wait() - if err != nil && !xerrors.Is(err, context.Canceled) && !xerrors.Is(err, context.DeadlineExceeded) { - tac.logger.Error(tac.ctx, "error receiving refresh token", slog.Error(err)) - } -} - -func (tac *tailnetAPIConnector) SendTelemetryEvent(event *proto.TelemetryEvent) { - 
tac.telCtrl.SendTelemetryEvent(event) -} - -type WebsocketDialer struct { - logger slog.Logger - dialOptions *websocket.DialOptions - url *url.URL - resumeTokenFailed bool - connected chan error - isFirst bool -} - -func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenController, -) ( - tailnet.ControlProtocolClients, error, -) { - w.logger.Debug(ctx, "dialing Coder tailnet v2+ API") - - u := new(url.URL) - *u = *w.url - if r != nil && !w.resumeTokenFailed { - if token, ok := r.Token(); ok { - q := u.Query() - q.Set("resume_token", token) - u.RawQuery = q.Encode() - w.logger.Debug(ctx, "using resume token on dial") - } - } - - // nolint:bodyclose - ws, res, err := websocket.Dial(ctx, u.String(), w.dialOptions) - if w.isFirst { - if res != nil && slices.Contains(permanentErrorStatuses, res.StatusCode) { - err = codersdk.ReadBodyAsError(res) - // A bit more human-readable help in the case the API version was rejected - var sdkErr *codersdk.Error - if xerrors.As(err, &sdkErr) { - if sdkErr.Message == AgentAPIMismatchMessage && - sdkErr.StatusCode() == http.StatusBadRequest { - sdkErr.Helper = fmt.Sprintf( - "Ensure your client release version (%s, different than the API version) matches the server release version", - buildinfo.Version()) - } - } - w.connected <- err - return tailnet.ControlProtocolClients{}, err - } - w.isFirst = false - close(w.connected) - } - if err != nil { - bodyErr := codersdk.ReadBodyAsError(res) - var sdkErr *codersdk.Error - if xerrors.As(bodyErr, &sdkErr) { - for _, v := range sdkErr.Validations { - if v.Field == "resume_token" { - // Unset the resume token for the next attempt - w.logger.Warn(ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt") - w.resumeTokenFailed = true - return tailnet.ControlProtocolClients{}, err - } - } - } - if !errors.Is(err, context.Canceled) { - w.logger.Error(ctx, "failed to dial tailnet v2+ API", slog.Error(err), slog.F("sdk_err", sdkErr)) - } - return tailnet.ControlProtocolClients{}, err - } - w.resumeTokenFailed = false - - client, err := tailnet.NewDRPCClient( - websocket.NetConn(context.Background(), ws, websocket.MessageBinary), - w.logger, - ) - if err != nil { - w.logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err)) - _ = ws.Close(websocket.StatusInternalError, "") - return tailnet.ControlProtocolClients{}, err - } - coord, err := client.Coordinate(context.Background()) - if err != nil { - w.logger.Debug(ctx, "failed to create Coordinate RPC", slog.Error(err)) - _ = ws.Close(websocket.StatusInternalError, "") - return tailnet.ControlProtocolClients{}, err - } - - derps := &tailnet.DERPFromDRPCWrapper{} - derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{}) - if err != nil { - w.logger.Debug(ctx, "failed to create DERPMap stream", slog.Error(err)) - _ = ws.Close(websocket.StatusInternalError, "") - return tailnet.ControlProtocolClients{}, err - } - - return tailnet.ControlProtocolClients{ - Closer: client.DRPCConn(), - Coordinator: coord, - DERP: derps, - ResumeToken: client, - Telemetry: client, - }, nil -} - -func (w *WebsocketDialer) Connected() <-chan error { - return w.connected -} - -func NewWebsocketDialer(logger slog.Logger, u *url.URL, opts *websocket.DialOptions) *WebsocketDialer { - return &WebsocketDialer{ - logger: logger, - dialOptions: opts, - url: u, - connected: make(chan error, 1), - isFirst: true, - } -} diff --git a/codersdk/workspacesdk/connector_internal_test.go 
b/codersdk/workspacesdk/connector_internal_test.go deleted file mode 100644 index 2d66d105e066d..0000000000000 --- a/codersdk/workspacesdk/connector_internal_test.go +++ /dev/null @@ -1,218 +0,0 @@ -package workspacesdk - -import ( - "context" - "io" - "net" - "sync/atomic" - "testing" - "time" - - "github.com/google/uuid" - "github.com/hashicorp/yamux" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "tailscale.com/tailcfg" - - "cdr.dev/slog" - "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/v2/tailnet" - "github.com/coder/coder/v2/tailnet/proto" - "github.com/coder/coder/v2/tailnet/tailnettest" - "github.com/coder/coder/v2/testutil" - "github.com/coder/quartz" -) - -func init() { - // Give tests a bit more time to timeout. Darwin is particularly slow. - tailnetConnectorGracefulTimeout = 5 * time.Second -} - -func TestTailnetAPIConnector_Disconnects(t *testing.T) { - t.Parallel() - testCtx := testutil.Context(t, testutil.WaitShort) - ctx, cancel := context.WithCancel(testCtx) - logger := slogtest.Make(t, &slogtest.Options{ - IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs, - io.EOF, // we get EOF when we simulate a DERPMap error - yamux.ErrSessionShutdown, // coordination can throw these when DERP error tears down session - ), - }).Leveled(slog.LevelDebug) - agentID := uuid.UUID{0x55} - clientID := uuid.UUID{0x66} - fCoord := tailnettest.NewFakeCoordinator() - var coord tailnet.Coordinator = fCoord - coordPtr := atomic.Pointer[tailnet.Coordinator]{} - coordPtr.Store(&coord) - derpMapCh := make(chan *tailcfg.DERPMap) - defer close(derpMapCh) - svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ - Logger: logger.Named("svc"), - CoordPtr: &coordPtr, - DERPMapUpdateFrequency: time.Millisecond, - DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh }, - NetworkTelemetryHandler: func([]*proto.TelemetryEvent) {}, - ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(), - }) - require.NoError(t, err) - - dialer := &pipeDialer{ - ctx: testCtx, - logger: logger, - t: t, - svc: svc, - streamID: tailnet.StreamID{ - Name: "client", - ID: clientID, - Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID}, - }, - } - - fConn := newFakeTailnetConn() - - uut := newTailnetAPIConnector(ctx, logger.Named("tac"), agentID, dialer, quartz.NewReal()) - uut.runConnector(fConn) - - call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls) - reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs) - require.NotNil(t, reqTun.AddTunnel) - - // simulate a problem with DERPMaps by sending nil - testutil.RequireSendCtx(ctx, t, derpMapCh, nil) - - // this should cause the coordinate call to hang up WITHOUT disconnecting - reqNil := testutil.RequireRecvCtx(ctx, t, call.Reqs) - require.Nil(t, reqNil) - - // ...and then reconnect - call = testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls) - reqTun = testutil.RequireRecvCtx(ctx, t, call.Reqs) - require.NotNil(t, reqTun.AddTunnel) - - // canceling the context should trigger the disconnect message - cancel() - reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs) - require.NotNil(t, reqDisc) - require.NotNil(t, reqDisc.Disconnect) - close(call.Resps) -} - -func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) { - t.Parallel() - ctx := testutil.Context(t, testutil.WaitShort) - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - agentID := uuid.UUID{0x55} - clientID := uuid.UUID{0x66} - fCoord := tailnettest.NewFakeCoordinator() - var coord tailnet.Coordinator = fCoord - 
coordPtr := atomic.Pointer[tailnet.Coordinator]{} - coordPtr.Store(&coord) - derpMapCh := make(chan *tailcfg.DERPMap) - defer close(derpMapCh) - eventCh := make(chan []*proto.TelemetryEvent, 1) - svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ - Logger: logger, - CoordPtr: &coordPtr, - DERPMapUpdateFrequency: time.Millisecond, - DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh }, - NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { - select { - case <-ctx.Done(): - t.Error("timeout sending telemetry event") - case eventCh <- batch: - t.Log("sent telemetry batch") - } - }, - ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(), - }) - require.NoError(t, err) - - dialer := &pipeDialer{ - ctx: ctx, - logger: logger, - t: t, - svc: svc, - streamID: tailnet.StreamID{ - Name: "client", - ID: clientID, - Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID}, - }, - } - - fConn := newFakeTailnetConn() - - uut := newTailnetAPIConnector(ctx, logger, agentID, dialer, quartz.NewReal()) - uut.runConnector(fConn) - // Coordinate calls happen _after_ telemetry is connected up, so we use this - // to ensure telemetry is connected before sending our event - cc := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls) - defer close(cc.Resps) - - uut.SendTelemetryEvent(&proto.TelemetryEvent{ - Id: []byte("test event"), - }) - - testEvents := testutil.RequireRecvCtx(ctx, t, eventCh) - - require.Len(t, testEvents, 1) - require.Equal(t, []byte("test event"), testEvents[0].Id) -} - -type fakeTailnetConn struct{} - -func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error { - // TODO implement me - panic("implement me") -} - -func (*fakeTailnetConn) SetAllPeersLost() {} - -func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {} - -func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {} - -func (*fakeTailnetConn) SetTunnelDestination(uuid.UUID) {} - -func newFakeTailnetConn() *fakeTailnetConn { - return &fakeTailnetConn{} -} - -type pipeDialer struct { - ctx context.Context - logger slog.Logger - t testing.TB - svc *tailnet.ClientService - streamID tailnet.StreamID -} - -func (p *pipeDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) { - s, c := net.Pipe() - go func() { - err := p.svc.ServeConnV2(p.ctx, s, p.streamID) - p.logger.Debug(p.ctx, "piped tailnet service complete", slog.Error(err)) - }() - client, err := tailnet.NewDRPCClient(c, p.logger) - if !assert.NoError(p.t, err) { - _ = c.Close() - return tailnet.ControlProtocolClients{}, err - } - coord, err := client.Coordinate(context.Background()) - if !assert.NoError(p.t, err) { - _ = c.Close() - return tailnet.ControlProtocolClients{}, err - } - - derps := &tailnet.DERPFromDRPCWrapper{} - derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{}) - if !assert.NoError(p.t, err) { - _ = c.Close() - return tailnet.ControlProtocolClients{}, err - } - return tailnet.ControlProtocolClients{ - Closer: client.DRPCConn(), - Coordinator: coord, - DERP: derps, - ResumeToken: client, - Telemetry: client, - }, nil -} diff --git a/codersdk/workspacesdk/dialer.go b/codersdk/workspacesdk/dialer.go new file mode 100644 index 0000000000000..b15c13aa978f9 --- /dev/null +++ b/codersdk/workspacesdk/dialer.go @@ -0,0 +1,139 @@ +package workspacesdk + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "slices" + + "golang.org/x/xerrors" + "nhooyr.io/websocket" + + "cdr.dev/slog" + 
"github.com/coder/coder/v2/buildinfo" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" +) + +var permanentErrorStatuses = []int{ + http.StatusConflict, // returned if client/agent connections disabled (browser only) + http.StatusBadRequest, // returned if API mismatch + http.StatusNotFound, // returned if user doesn't have permission or agent doesn't exist +} + +type WebsocketDialer struct { + logger slog.Logger + dialOptions *websocket.DialOptions + url *url.URL + resumeTokenFailed bool + connected chan error + isFirst bool +} + +func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenController, +) ( + tailnet.ControlProtocolClients, error, +) { + w.logger.Debug(ctx, "dialing Coder tailnet v2+ API") + + u := new(url.URL) + *u = *w.url + if r != nil && !w.resumeTokenFailed { + if token, ok := r.Token(); ok { + q := u.Query() + q.Set("resume_token", token) + u.RawQuery = q.Encode() + w.logger.Debug(ctx, "using resume token on dial") + } + } + + // nolint:bodyclose + ws, res, err := websocket.Dial(ctx, u.String(), w.dialOptions) + if w.isFirst { + if res != nil && slices.Contains(permanentErrorStatuses, res.StatusCode) { + err = codersdk.ReadBodyAsError(res) + // A bit more human-readable help in the case the API version was rejected + var sdkErr *codersdk.Error + if xerrors.As(err, &sdkErr) { + if sdkErr.Message == AgentAPIMismatchMessage && + sdkErr.StatusCode() == http.StatusBadRequest { + sdkErr.Helper = fmt.Sprintf( + "Ensure your client release version (%s, different than the API version) matches the server release version", + buildinfo.Version()) + } + } + w.connected <- err + return tailnet.ControlProtocolClients{}, err + } + w.isFirst = false + close(w.connected) + } + if err != nil { + bodyErr := codersdk.ReadBodyAsError(res) + var sdkErr *codersdk.Error + if xerrors.As(bodyErr, &sdkErr) { + for _, v := range sdkErr.Validations { + if v.Field == "resume_token" { + // Unset the resume token for the next attempt + w.logger.Warn(ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt") + w.resumeTokenFailed = true + return tailnet.ControlProtocolClients{}, err + } + } + } + if !errors.Is(err, context.Canceled) { + w.logger.Error(ctx, "failed to dial tailnet v2+ API", slog.Error(err), slog.F("sdk_err", sdkErr)) + } + return tailnet.ControlProtocolClients{}, err + } + w.resumeTokenFailed = false + + client, err := tailnet.NewDRPCClient( + websocket.NetConn(context.Background(), ws, websocket.MessageBinary), + w.logger, + ) + if err != nil { + w.logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err)) + _ = ws.Close(websocket.StatusInternalError, "") + return tailnet.ControlProtocolClients{}, err + } + coord, err := client.Coordinate(context.Background()) + if err != nil { + w.logger.Debug(ctx, "failed to create Coordinate RPC", slog.Error(err)) + _ = ws.Close(websocket.StatusInternalError, "") + return tailnet.ControlProtocolClients{}, err + } + + derps := &tailnet.DERPFromDRPCWrapper{} + derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{}) + if err != nil { + w.logger.Debug(ctx, "failed to create DERPMap stream", slog.Error(err)) + _ = ws.Close(websocket.StatusInternalError, "") + return tailnet.ControlProtocolClients{}, err + } + + return tailnet.ControlProtocolClients{ + Closer: client.DRPCConn(), + Coordinator: coord, + DERP: derps, + ResumeToken: client, + Telemetry: client, + }, nil +} 
+ +func (w *WebsocketDialer) Connected() <-chan error { + return w.connected +} + +func NewWebsocketDialer(logger slog.Logger, u *url.URL, opts *websocket.DialOptions) *WebsocketDialer { + return &WebsocketDialer{ + logger: logger, + dialOptions: opts, + url: u, + connected: make(chan error, 1), + isFirst: true, + } +} diff --git a/codersdk/workspacesdk/connector_test.go b/codersdk/workspacesdk/dialer_test.go similarity index 100% rename from codersdk/workspacesdk/connector_test.go rename to codersdk/workspacesdk/dialer_test.go diff --git a/codersdk/workspacesdk/workspacesdk.go b/codersdk/workspacesdk/workspacesdk.go index 365a530d438aa..5ce0c06065173 100644 --- a/codersdk/workspacesdk/workspacesdk.go +++ b/codersdk/workspacesdk/workspacesdk.go @@ -234,7 +234,9 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options * // Need to disable compression to avoid a data-race. CompressionMode: websocket.CompressionDisabled, }) - connector := newTailnetAPIConnector(ctx, options.Logger, agentID, dialer, quartz.NewReal()) + clk := quartz.NewReal() + controller := tailnet.NewController(options.Logger, dialer) + controller.ResumeTokenCtrl = tailnet.NewBasicResumeTokenController(options.Logger, clk) ip := tailnet.TailscaleServicePrefix.RandomAddr() var header http.Header @@ -243,7 +245,9 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options * } var telemetrySink tailnet.TelemetrySink if options.EnableTelemetry { - telemetrySink = connector + basicTel := tailnet.NewBasicTelemetryController(options.Logger) + telemetrySink = basicTel + controller.TelemetryCtrl = basicTel } conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)}, @@ -264,7 +268,9 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options * _ = conn.Close() } }() - connector.runConnector(conn) + controller.CoordCtrl = tailnet.NewSingleDestController(options.Logger, conn, agentID) + controller.DERPCtrl = tailnet.NewBasicDERPController(options.Logger, conn) + controller.Run(ctx) options.Logger.Debug(ctx, "running tailnet API v2+ connector") @@ -283,7 +289,7 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options * AgentID: agentID, CloseFunc: func() error { cancel() - <-connector.closed + <-controller.Closed() return conn.Close() }, }) diff --git a/tailnet/controllers.go b/tailnet/controllers.go index 3b032b4f323cf..4a11e0f537e66 100644 --- a/tailnet/controllers.go +++ b/tailnet/controllers.go @@ -16,8 +16,10 @@ import ( "tailscale.com/tailcfg" "cdr.dev/slog" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/quartz" + "github.com/coder/retry" ) // A Controller connects to the tailnet control plane, and then uses the control protocols to @@ -30,6 +32,16 @@ type Controller struct { DERPCtrl DERPController ResumeTokenCtrl ResumeTokenController TelemetryCtrl TelemetryController + + ctx context.Context + gracefulCtx context.Context + cancelGracefulCtx context.CancelFunc + logger slog.Logger + closedCh chan struct{} + + // Testing only + clock quartz.Clock + gracefulTimeout time.Duration } type CloserWaiter interface { @@ -664,3 +676,211 @@ func (r *basicResumeTokenRefresher) refresh() { } r.timer.Reset(dur, "basicResumeTokenRefresher", "refresh") } + +// NewController creates a new Controller without running it +func NewController(logger slog.Logger, dialer ControlProtocolDialer, opts ...ControllerOpt) *Controller { + c := &Controller{ + logger: logger, + 
clock:           quartz.NewReal(),
+		gracefulTimeout: time.Second,
+		Dialer:          dialer,
+		closedCh:        make(chan struct{}),
+	}
+	for _, opt := range opts {
+		opt(c)
+	}
+	return c
+}
+
+type ControllerOpt func(*Controller)
+
+func WithTestClock(clock quartz.Clock) ControllerOpt {
+	return func(c *Controller) {
+		c.clock = clock
+	}
+}
+
+func WithGracefulTimeout(timeout time.Duration) ControllerOpt {
+	return func(c *Controller) {
+		c.gracefulTimeout = timeout
+	}
+}
+
+// manageGracefulTimeout allows the gracefulContext to last longer than the main context
+// to allow a graceful disconnect.
+func (c *Controller) manageGracefulTimeout() {
+	defer c.cancelGracefulCtx()
+	<-c.ctx.Done()
+	timer := c.clock.NewTimer(c.gracefulTimeout, "tailnetAPIClient", "gracefulTimeout")
+	defer timer.Stop()
+	select {
+	case <-c.closedCh:
+	case <-timer.C:
+	}
+}
+
+// Run dials the API and uses it with the provided controllers.
+func (c *Controller) Run(ctx context.Context) {
+	c.ctx = ctx
+	c.gracefulCtx, c.cancelGracefulCtx = context.WithCancel(context.Background())
+	go c.manageGracefulTimeout()
+	go func() {
+		defer close(c.closedCh)
+		// Sadly retry doesn't support quartz.Clock yet so this is not
+		// influenced by the configured clock.
+		for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(c.ctx); {
+			tailnetClients, err := c.Dialer.Dial(c.ctx, c.ResumeTokenCtrl)
+			if err != nil {
+				if xerrors.Is(err, context.Canceled) {
+					continue
+				}
+				errF := slog.Error(err)
+				var sdkErr *codersdk.Error
+				if xerrors.As(err, &sdkErr) {
+					errF = slog.Error(sdkErr)
+				}
+				c.logger.Error(c.ctx, "failed to dial tailnet v2+ API", errF)
+				continue
+			}
+			c.logger.Debug(c.ctx, "obtained tailnet API v2+ client")
+			c.runControllersOnce(tailnetClients)
+			c.logger.Debug(c.ctx, "tailnet API v2+ connection lost")
+		}
+	}()
+}
+
+// runControllersOnce uses the provided clients to call into the controllers once. It is combined
+// into one function so that a problem with one tears down the other and triggers a retry (if
+// appropriate). We typically multiplex all RPCs over the same websocket, so we want them to share
+// the same fate.
+func (c *Controller) runControllersOnce(clients ControlProtocolClients) {
+	defer func() {
+		closeErr := clients.Closer.Close()
+		if closeErr != nil &&
+			!xerrors.Is(closeErr, io.EOF) &&
+			!xerrors.Is(closeErr, context.Canceled) &&
+			!xerrors.Is(closeErr, context.DeadlineExceeded) {
+			c.logger.Error(c.ctx, "error closing DRPC connection", slog.Error(closeErr))
+		}
+	}()
+
+	if c.TelemetryCtrl != nil {
+		c.TelemetryCtrl.New(clients.Telemetry) // synchronous, doesn't need a goroutine
+	}
+
+	wg := sync.WaitGroup{}
+
+	if c.CoordCtrl != nil {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			c.coordinate(clients.Coordinator)
+		}()
+	}
+	if c.DERPCtrl != nil {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			dErr := c.derpMap(clients.DERP)
+			if dErr != nil && c.ctx.Err() == nil {
+				// The main context is still active, meaning that we want the tailnet data plane to stay
+				// up, even though we hit some error getting DERP maps on the control plane. That means
+				// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
+				// close the underlying connection. This will trigger a retry of the control plane in
+				// Run().
+				_ = clients.Closer.Close()
+				// Note that derpMap() logs its own errors; we don't bother here.
+			}
+		}()
+	}
+
+	// Refresh token is a little different, in that we don't want its controller to hold open the
+	// connection on its own. 
So we keep it separate from the other wait group, and cancel its + // context as soon as the other routines exit. + refreshTokenCtx, refreshTokenCancel := context.WithCancel(c.ctx) + refreshTokenDone := make(chan struct{}) + defer func() { + <-refreshTokenDone + }() + defer refreshTokenCancel() + go func() { + defer close(refreshTokenDone) + if c.ResumeTokenCtrl != nil { + c.refreshToken(refreshTokenCtx, clients.ResumeToken) + } + }() + + wg.Wait() +} + +func (c *Controller) coordinate(client CoordinatorClient) { + defer func() { + cErr := client.Close() + if cErr != nil { + c.logger.Debug(c.ctx, "error closing Coordinate RPC", slog.Error(cErr)) + } + }() + coordination := c.CoordCtrl.New(client) + c.logger.Debug(c.ctx, "serving coordinator") + select { + case <-c.ctx.Done(): + c.logger.Debug(c.ctx, "main context canceled; do graceful disconnect") + crdErr := coordination.Close(c.gracefulCtx) + if crdErr != nil { + c.logger.Warn(c.ctx, "failed to close remote coordination", slog.Error(crdErr)) + } + case err := <-coordination.Wait(): + if err != nil && + !xerrors.Is(err, io.EOF) && + !xerrors.Is(err, context.Canceled) && + !xerrors.Is(err, context.DeadlineExceeded) { + c.logger.Error(c.ctx, "remote coordination error", slog.Error(err)) + } + } +} + +func (c *Controller) derpMap(client DERPClient) error { + defer func() { + cErr := client.Close() + if cErr != nil { + c.logger.Debug(c.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr)) + } + }() + cw := c.DERPCtrl.New(client) + select { + case <-c.ctx.Done(): + cErr := client.Close() + if cErr != nil { + c.logger.Warn(c.ctx, "failed to close StreamDERPMaps RPC", slog.Error(cErr)) + } + return nil + case err := <-cw.Wait(): + if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) { + return nil + } + if err != nil && !xerrors.Is(err, io.EOF) { + c.logger.Error(c.ctx, "error receiving DERP Map", slog.Error(err)) + } + return err + } +} + +func (c *Controller) refreshToken(ctx context.Context, client ResumeTokenClient) { + cw := c.ResumeTokenCtrl.New(client) + go func() { + <-ctx.Done() + cErr := cw.Close(c.ctx) + if cErr != nil { + c.logger.Error(c.ctx, "error closing token refresher", slog.Error(cErr)) + } + }() + + err := <-cw.Wait() + if err != nil && !xerrors.Is(err, context.Canceled) && !xerrors.Is(err, context.DeadlineExceeded) { + c.logger.Error(c.ctx, "error receiving refresh token", slog.Error(err)) + } +} + +func (c *Controller) Closed() <-chan struct{} { + return c.closedCh +} diff --git a/tailnet/controllers_test.go b/tailnet/controllers_test.go index d3f88ad23cae3..62f99c0224546 100644 --- a/tailnet/controllers_test.go +++ b/tailnet/controllers_test.go @@ -10,6 +10,8 @@ import ( "time" "github.com/google/uuid" + "github.com/hashicorp/yamux" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "golang.org/x/xerrors" @@ -678,3 +680,202 @@ type fakeResumeTokenCall struct { resp chan *proto.RefreshResumeTokenResponse errCh chan error } + +func TestController_Disconnects(t *testing.T) { + t.Parallel() + testCtx := testutil.Context(t, testutil.WaitShort) + ctx, cancel := context.WithCancel(testCtx) + logger := slogtest.Make(t, &slogtest.Options{ + IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs, + io.EOF, // we get EOF when we simulate a DERPMap error + yamux.ErrSessionShutdown, // coordination can throw these when DERP error tears down session + ), + }).Leveled(slog.LevelDebug) + agentID := uuid.UUID{0x55} + clientID := uuid.UUID{0x66} + fCoord := 
tailnettest.NewFakeCoordinator() + var coord tailnet.Coordinator = fCoord + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&coord) + derpMapCh := make(chan *tailcfg.DERPMap) + defer close(derpMapCh) + svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ + Logger: logger.Named("svc"), + CoordPtr: &coordPtr, + DERPMapUpdateFrequency: time.Millisecond, + DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh }, + NetworkTelemetryHandler: func([]*proto.TelemetryEvent) {}, + ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(), + }) + require.NoError(t, err) + + dialer := &pipeDialer{ + ctx: testCtx, + logger: logger, + t: t, + svc: svc, + streamID: tailnet.StreamID{ + Name: "client", + ID: clientID, + Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID}, + }, + } + + peersLost := make(chan struct{}) + fConn := &fakeTailnetConn{peersLostCh: peersLost} + + uut := tailnet.NewController(logger.Named("tac"), dialer, + // darwin can be slow sometimes. + tailnet.WithGracefulTimeout(5*time.Second)) + uut.CoordCtrl = tailnet.NewAgentCoordinationController(logger.Named("coord_ctrl"), fConn) + uut.DERPCtrl = tailnet.NewBasicDERPController(logger.Named("derp_ctrl"), fConn) + uut.Run(ctx) + + call := testutil.RequireRecvCtx(testCtx, t, fCoord.CoordinateCalls) + + // simulate a problem with DERPMaps by sending nil + testutil.RequireSendCtx(testCtx, t, derpMapCh, nil) + + // this should cause the coordinate call to hang up WITHOUT disconnecting + reqNil := testutil.RequireRecvCtx(testCtx, t, call.Reqs) + require.Nil(t, reqNil) + + // and mark all peers lost + _ = testutil.RequireRecvCtx(testCtx, t, peersLost) + + // ...and then reconnect + call = testutil.RequireRecvCtx(testCtx, t, fCoord.CoordinateCalls) + + // canceling the context should trigger the disconnect message + cancel() + reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs) + require.NotNil(t, reqDisc) + require.NotNil(t, reqDisc.Disconnect) + close(call.Resps) + + _ = testutil.RequireRecvCtx(testCtx, t, peersLost) +} + +func TestController_TelemetrySuccess(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + agentID := uuid.UUID{0x55} + clientID := uuid.UUID{0x66} + fCoord := tailnettest.NewFakeCoordinator() + var coord tailnet.Coordinator = fCoord + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&coord) + derpMapCh := make(chan *tailcfg.DERPMap) + defer close(derpMapCh) + eventCh := make(chan []*proto.TelemetryEvent, 1) + svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ + Logger: logger, + CoordPtr: &coordPtr, + DERPMapUpdateFrequency: time.Millisecond, + DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh }, + NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { + select { + case <-ctx.Done(): + t.Error("timeout sending telemetry event") + case eventCh <- batch: + t.Log("sent telemetry batch") + } + }, + ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(), + }) + require.NoError(t, err) + + dialer := &pipeDialer{ + ctx: ctx, + logger: logger, + t: t, + svc: svc, + streamID: tailnet.StreamID{ + Name: "client", + ID: clientID, + Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID}, + }, + } + + uut := tailnet.NewController(logger, dialer) + uut.CoordCtrl = tailnet.NewAgentCoordinationController(logger, &fakeTailnetConn{}) + tel := tailnet.NewBasicTelemetryController(logger) + uut.TelemetryCtrl = tel + uut.Run(ctx) + // Coordinate calls 
happen _after_ telemetry is connected up, so we use this + // to ensure telemetry is connected before sending our event + cc := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls) + defer close(cc.Resps) + + tel.SendTelemetryEvent(&proto.TelemetryEvent{ + Id: []byte("test event"), + }) + + testEvents := testutil.RequireRecvCtx(ctx, t, eventCh) + + require.Len(t, testEvents, 1) + require.Equal(t, []byte("test event"), testEvents[0].Id) +} + +type fakeTailnetConn struct { + peersLostCh chan struct{} +} + +func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error { + // TODO implement me + panic("implement me") +} + +func (f *fakeTailnetConn) SetAllPeersLost() { + if f.peersLostCh == nil { + return + } + f.peersLostCh <- struct{}{} +} + +func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {} + +func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {} + +func (*fakeTailnetConn) SetTunnelDestination(uuid.UUID) {} + +type pipeDialer struct { + ctx context.Context + logger slog.Logger + t testing.TB + svc *tailnet.ClientService + streamID tailnet.StreamID +} + +func (p *pipeDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) { + s, c := net.Pipe() + go func() { + err := p.svc.ServeConnV2(p.ctx, s, p.streamID) + p.logger.Debug(p.ctx, "piped tailnet service complete", slog.Error(err)) + }() + client, err := tailnet.NewDRPCClient(c, p.logger) + if !assert.NoError(p.t, err) { + _ = c.Close() + return tailnet.ControlProtocolClients{}, err + } + coord, err := client.Coordinate(context.Background()) + if !assert.NoError(p.t, err) { + _ = c.Close() + return tailnet.ControlProtocolClients{}, err + } + + derps := &tailnet.DERPFromDRPCWrapper{} + derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{}) + if !assert.NoError(p.t, err) { + _ = c.Close() + return tailnet.ControlProtocolClients{}, err + } + return tailnet.ControlProtocolClients{ + Closer: client.DRPCConn(), + Coordinator: coord, + DERP: derps, + ResumeToken: client, + Telemetry: client, + }, nil +}
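

For reference, here is a minimal sketch of how a caller wires up the refactored tailnet.Controller, mirroring what DialAgent now does in codersdk/workspacesdk/workspacesdk.go. The runTailnetControl helper, its signature, and the package name are illustrative assumptions; the constructors and methods it calls are the ones this patch introduces or uses.

package example // hypothetical package, for illustration only

import (
	"context"

	"github.com/google/uuid"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/quartz"
)

// runTailnetControl (hypothetical) assembles the per-protocol controllers onto a
// single tailnet.Controller and starts its dial/retry loop.
func runTailnetControl(ctx context.Context, logger slog.Logger, dialer tailnet.ControlProtocolDialer,
	conn *tailnet.Conn, agentID uuid.UUID,
) <-chan struct{} {
	ctrl := tailnet.NewController(logger, dialer)
	ctrl.ResumeTokenCtrl = tailnet.NewBasicResumeTokenController(logger, quartz.NewReal())
	ctrl.CoordCtrl = tailnet.NewSingleDestController(logger, conn, agentID)
	ctrl.DERPCtrl = tailnet.NewBasicDERPController(logger, conn)
	ctrl.TelemetryCtrl = tailnet.NewBasicTelemetryController(logger)
	// Run retries the dial until ctx is canceled; Closed() reports when teardown,
	// including the graceful coordinator disconnect, has finished.
	ctrl.Run(ctx)
	return ctrl.Closed()
}

Any controller field left nil is simply skipped by runControllersOnce, which is how DialAgent wires TelemetryCtrl only when options.EnableTelemetry is set.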