diff --git a/agent/agent.go b/agent/agent.go index 25e24215d90bb..92ab84a2d0877 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -89,7 +89,6 @@ type Options struct { type Client interface { Manifest(ctx context.Context) (agentsdk.Manifest, error) Listen(ctx context.Context) (drpc.Conn, error) - DERPMapUpdates(ctx context.Context) (<-chan agentsdk.DERPMapUpdate, io.Closer, error) ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsRequest) error @@ -822,10 +821,22 @@ func (a *agent) run(ctx context.Context) error { network.SetBlockEndpoints(manifest.DisableDirectConnections) } + // Listen returns the dRPC connection we use for both Coordinator and DERPMap updates + conn, err := a.client.Listen(ctx) + if err != nil { + return err + } + defer func() { + cErr := conn.Close() + if cErr != nil { + a.logger.Debug(ctx, "error closing drpc connection", slog.Error(err)) + } + }() + eg, egCtx := errgroup.WithContext(ctx) eg.Go(func() error { a.logger.Debug(egCtx, "running tailnet connection coordinator") - err := a.runCoordinator(egCtx, network) + err := a.runCoordinator(egCtx, conn, network) if err != nil { return xerrors.Errorf("run coordinator: %w", err) } @@ -834,7 +845,7 @@ func (a *agent) run(ctx context.Context) error { eg.Go(func() error { a.logger.Debug(egCtx, "running derp map subscriber") - err := a.runDERPMapSubscriber(egCtx, network) + err := a.runDERPMapSubscriber(egCtx, conn, network) if err != nil { return xerrors.Errorf("run derp map subscriber: %w", err) } @@ -1056,21 +1067,8 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t // runCoordinator runs a coordinator and returns whether a reconnect // should occur. -func (a *agent) runCoordinator(ctx context.Context, network *tailnet.Conn) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - conn, err := a.client.Listen(ctx) - if err != nil { - return err - } - defer func() { - cErr := conn.Close() - if cErr != nil { - a.logger.Debug(ctx, "error closing drpc connection", slog.Error(err)) - } - }() - +func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tailnet.Conn) error { + defer a.logger.Debug(ctx, "disconnected from coordination RPC") tClient := tailnetproto.NewDRPCTailnetClient(conn) coordinate, err := tClient.Coordinate(ctx) if err != nil { @@ -1082,7 +1080,7 @@ func (a *agent) runCoordinator(ctx context.Context, network *tailnet.Conn) error a.logger.Debug(ctx, "error closing Coordinate client", slog.Error(err)) } }() - a.logger.Info(ctx, "connected to coordination endpoint") + a.logger.Info(ctx, "connected to coordination RPC") coordination := tailnet.NewRemoteCoordination(a.logger, coordinate, network, uuid.Nil) select { case <-ctx.Done(): @@ -1093,30 +1091,29 @@ func (a *agent) runCoordinator(ctx context.Context, network *tailnet.Conn) error } // runDERPMapSubscriber runs a coordinator and returns if a reconnect should occur. -func (a *agent) runDERPMapSubscriber(ctx context.Context, network *tailnet.Conn) error { +func (a *agent) runDERPMapSubscriber(ctx context.Context, conn drpc.Conn, network *tailnet.Conn) error { + defer a.logger.Debug(ctx, "disconnected from derp map RPC") ctx, cancel := context.WithCancel(ctx) defer cancel() - - updates, closer, err := a.client.DERPMapUpdates(ctx) + tClient := tailnetproto.NewDRPCTailnetClient(conn) + stream, err := tClient.StreamDERPMaps(ctx, &tailnetproto.StreamDERPMapsRequest{}) if err != nil { - return err + return xerrors.Errorf("stream DERP Maps: %w", err) } - defer closer.Close() - - a.logger.Info(ctx, "connected to derp map endpoint") + defer func() { + cErr := stream.Close() + if cErr != nil { + a.logger.Debug(ctx, "error closing DERPMap stream", slog.Error(err)) + } + }() + a.logger.Info(ctx, "connected to derp map RPC") for { - select { - case <-ctx.Done(): - return ctx.Err() - case update := <-updates: - if update.Err != nil { - return update.Err - } - if update.DERPMap != nil && !tailnet.CompareDERPMaps(network.DERPMap(), update.DERPMap) { - a.logger.Info(ctx, "updating derp map due to detected changes") - network.SetDERPMap(update.DERPMap) - } + dmp, err := stream.Recv() + if err != nil { + return xerrors.Errorf("recv DERPMap error: %w", err) } + dm := tailnet.DERPMapFromProto(dmp) + network.SetDERPMap(dm) } } diff --git a/agent/agent_test.go b/agent/agent_test.go index 163c64b78841d..6f3ab55ffd8be 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -1349,6 +1349,7 @@ func TestAgent_Lifecycle(t *testing.T) { make(chan *agentsdk.Stats, 50), tailnet.NewCoordinator(logger), ) + defer client.Close() fs := afero.NewMemMapFs() agent := agent.New(agent.Options{ @@ -1683,6 +1684,10 @@ func TestAgent_UpdatedDERP(t *testing.T) { statsCh, coordinator, ) + t.Cleanup(func() { + t.Log("closing client") + client.Close() + }) uut := agent.New(agent.Options{ Client: client, Filesystem: fs, @@ -1690,6 +1695,7 @@ func TestAgent_UpdatedDERP(t *testing.T) { ReconnectingPTYTimeout: time.Minute, }) t.Cleanup(func() { + t.Log("closing agent") _ = uut.Close() }) @@ -1718,6 +1724,7 @@ func TestAgent_UpdatedDERP(t *testing.T) { if err != nil { t.Logf("error closing in-memory coordination: %s", err.Error()) } + t.Logf("closed coordination %s", name) }) // Force DERP. conn.SetBlockEndpoints(true) @@ -1753,11 +1760,9 @@ func TestAgent_UpdatedDERP(t *testing.T) { } // Push a new DERP map to the agent. - err := client.PushDERPMapUpdate(agentsdk.DERPMapUpdate{ - DERPMap: newDerpMap, - }) + err := client.PushDERPMapUpdate(newDerpMap) require.NoError(t, err) - t.Logf("client Pushed DERPMap update") + t.Logf("pushed DERPMap update to agent") require.Eventually(t, func() bool { conn := uut.TailnetConn() @@ -1826,6 +1831,7 @@ func TestAgent_Reconnect(t *testing.T) { statsCh, coordinator, ) + defer client.Close() initialized := atomic.Int32{} closer := agent.New(agent.Options{ ExchangeToken: func(ctx context.Context) (string, error) { @@ -1862,6 +1868,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) { make(chan *agentsdk.Stats, 50), coordinator, ) + defer client.Close() filesystem := afero.NewMemMapFs() closer := agent.New(agent.Options{ ExchangeToken: func(ctx context.Context) (string, error) { @@ -2039,6 +2046,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati statsCh := make(chan *agentsdk.Stats, 50) fs := afero.NewMemMapFs() c := agenttest.NewClient(t, logger.Named("agent"), metadata.AgentID, metadata, statsCh, coordinator) + t.Cleanup(c.Close) options := agent.Options{ Client: c, diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go index ddea2d749e39c..d7c632e7d452f 100644 --- a/agent/agenttest/client.go +++ b/agent/agenttest/client.go @@ -39,12 +39,12 @@ func NewClient(t testing.TB, coordPtr := atomic.Pointer[tailnet.Coordinator]{} coordPtr.Store(&coordinator) mux := drpcmux.New() + derpMapUpdates := make(chan *tailcfg.DERPMap) drpcService := &tailnet.DRPCService{ - CoordPtr: &coordPtr, - Logger: logger, - // TODO: handle DERPMap too! - DerpMapUpdateFrequency: time.Hour, - DerpMapFn: func() *tailcfg.DERPMap { panic("not implemented") }, + CoordPtr: &coordPtr, + Logger: logger, + DerpMapUpdateFrequency: time.Microsecond, + DerpMapFn: func() *tailcfg.DERPMap { return <-derpMapUpdates }, } err := proto.DRPCRegisterTailnet(mux, drpcService) require.NoError(t, err) @@ -64,7 +64,7 @@ func NewClient(t testing.TB, statsChan: statsChan, coordinator: coordinator, server: server, - derpMapUpdates: make(chan agentsdk.DERPMapUpdate), + derpMapUpdates: derpMapUpdates, } } @@ -85,23 +85,26 @@ type Client struct { lifecycleStates []codersdk.WorkspaceAgentLifecycle startup agentsdk.PostStartupRequest logs []agentsdk.Log - derpMapUpdates chan agentsdk.DERPMapUpdate + derpMapUpdates chan *tailcfg.DERPMap + derpMapOnce sync.Once +} + +func (c *Client) Close() { + c.derpMapOnce.Do(func() { close(c.derpMapUpdates) }) } func (c *Client) Manifest(_ context.Context) (agentsdk.Manifest, error) { return c.manifest, nil } -func (c *Client) Listen(_ context.Context) (drpc.Conn, error) { +func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) { conn, lis := drpcsdk.MemTransportPipe() - closed := make(chan struct{}) c.LastWorkspaceAgent = func() { _ = conn.Close() _ = lis.Close() - <-closed } c.t.Cleanup(c.LastWorkspaceAgent) - serveCtx, cancel := context.WithCancel(context.Background()) + serveCtx, cancel := context.WithCancel(ctx) c.t.Cleanup(cancel) auth := tailnet.AgentTunnelAuth{} streamID := tailnet.StreamID{ @@ -112,7 +115,6 @@ func (c *Client) Listen(_ context.Context) (drpc.Conn, error) { serveCtx = tailnet.WithStreamID(serveCtx, streamID) go func() { _ = c.server.Serve(serveCtx, lis) - close(closed) }() return conn, nil } @@ -235,7 +237,7 @@ func (c *Client) GetServiceBanner(ctx context.Context) (codersdk.ServiceBannerCo return codersdk.ServiceBannerConfig{}, nil } -func (c *Client) PushDERPMapUpdate(update agentsdk.DERPMapUpdate) error { +func (c *Client) PushDERPMapUpdate(update *tailcfg.DERPMap) error { timer := time.NewTimer(testutil.WaitShort) defer timer.Stop() select { @@ -247,14 +249,6 @@ func (c *Client) PushDERPMapUpdate(update agentsdk.DERPMapUpdate) error { return nil } -func (c *Client) DERPMapUpdates(_ context.Context) (<-chan agentsdk.DERPMapUpdate, io.Closer, error) { - closed := make(chan struct{}) - return c.derpMapUpdates, closeFunc(func() error { - close(closed) - return nil - }), nil -} - type closeFunc func() error func (c closeFunc) Close() error { diff --git a/coderd/tailnet_test.go b/coderd/tailnet_test.go index 392bc8d306f49..f3c1876b3c97f 100644 --- a/coderd/tailnet_test.go +++ b/coderd/tailnet_test.go @@ -178,6 +178,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A }) c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord) + t.Cleanup(c.Close) options := agent.Options{ Client: c, diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index 8a66e3ba0364f..ae01bf5785aee 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ -171,13 +171,16 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati _ = coordinator.Close() }) manifest.AgentID = uuid.New() + aC := &client{ + t: t, + agentID: manifest.AgentID, + manifest: manifest, + coordinator: coordinator, + derpMapUpdates: make(chan *tailcfg.DERPMap), + } + t.Cleanup(aC.close) closer := agent.New(agent.Options{ - Client: &client{ - t: t, - agentID: manifest.AgentID, - manifest: manifest, - coordinator: coordinator, - }, + Client: aC, Logger: logger.Named("agent"), ReconnectingPTYTimeout: ptyTimeout, Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, @@ -230,52 +233,37 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati } type client struct { - t *testing.T - agentID uuid.UUID - manifest agentsdk.Manifest - coordinator tailnet.Coordinator -} - -func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) { - return c.manifest, nil + t *testing.T + agentID uuid.UUID + manifest agentsdk.Manifest + coordinator tailnet.Coordinator + closeOnce sync.Once + derpMapUpdates chan *tailcfg.DERPMap } -type closer struct { - closeFunc func() error +func (c *client) close() { + c.closeOnce.Do(func() { close(c.derpMapUpdates) }) } -func (c *closer) Close() error { - return c.closeFunc() -} - -func (*client) DERPMapUpdates(_ context.Context) (<-chan agentsdk.DERPMapUpdate, io.Closer, error) { - closed := make(chan struct{}) - return make(<-chan agentsdk.DERPMapUpdate), &closer{ - closeFunc: func() error { - close(closed) - return nil - }, - }, nil +func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) { + return c.manifest, nil } func (c *client) Listen(_ context.Context) (drpc.Conn, error) { logger := slogtest.Make(c.t, nil).Leveled(slog.LevelDebug).Named("drpc") conn, lis := drpcsdk.MemTransportPipe() - closed := make(chan struct{}) c.t.Cleanup(func() { _ = conn.Close() _ = lis.Close() - <-closed }) coordPtr := atomic.Pointer[tailnet.Coordinator]{} coordPtr.Store(&c.coordinator) mux := drpcmux.New() drpcService := &tailnet.DRPCService{ - CoordPtr: &coordPtr, - Logger: logger, - // TODO: handle DERPMap too! - DerpMapUpdateFrequency: time.Hour, - DerpMapFn: func() *tailcfg.DERPMap { panic("not implemented") }, + CoordPtr: &coordPtr, + Logger: logger, + DerpMapUpdateFrequency: time.Microsecond, + DerpMapFn: func() *tailcfg.DERPMap { return <-c.derpMapUpdates }, } err := proto.DRPCRegisterTailnet(mux, drpcService) if err != nil { @@ -302,7 +290,6 @@ func (c *client) Listen(_ context.Context) (drpc.Conn, error) { serveCtx = tailnet.WithStreamID(serveCtx, streamID) go func() { server.Serve(serveCtx, lis) - close(closed) }() return conn, nil } diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index f1c29c4517441..2e94b6cd42cc8 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -192,96 +192,6 @@ func (c *Client) rewriteDerpMap(derpMap *tailcfg.DERPMap) error { return nil } -type DERPMapUpdate struct { - Err error - DERPMap *tailcfg.DERPMap -} - -// DERPMapUpdates connects to the DERP map updates WebSocket. -func (c *Client) DERPMapUpdates(ctx context.Context) (<-chan DERPMapUpdate, io.Closer, error) { - derpMapURL, err := c.SDK.URL.Parse("/api/v2/derp-map") - if err != nil { - return nil, nil, xerrors.Errorf("parse url: %w", err) - } - jar, err := cookiejar.New(nil) - if err != nil { - return nil, nil, xerrors.Errorf("create cookie jar: %w", err) - } - jar.SetCookies(derpMapURL, []*http.Cookie{{ - Name: codersdk.SessionTokenCookie, - Value: c.SDK.SessionToken(), - }}) - httpClient := &http.Client{ - Jar: jar, - Transport: c.SDK.HTTPClient.Transport, - } - // nolint:bodyclose - conn, res, err := websocket.Dial(ctx, derpMapURL.String(), &websocket.DialOptions{ - HTTPClient: httpClient, - }) - if err != nil { - if res == nil { - return nil, nil, err - } - return nil, nil, codersdk.ReadBodyAsError(res) - } - - ctx, cancelFunc := context.WithCancel(ctx) - ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) - pingClosed := pingWebSocket(ctx, c.SDK.Logger(), conn, "derp map") - - var ( - updates = make(chan DERPMapUpdate) - updatesClosed = make(chan struct{}) - dec = json.NewDecoder(wsNetConn) - ) - go func() { - defer close(updates) - defer close(updatesClosed) - defer cancelFunc() - defer conn.Close(websocket.StatusGoingAway, "DERPMapUpdates closed") - for { - var update DERPMapUpdate - err := dec.Decode(&update.DERPMap) - if err != nil { - update.Err = err - update.DERPMap = nil - } - if update.DERPMap != nil { - err = c.rewriteDerpMap(update.DERPMap) - if err != nil { - update.Err = err - update.DERPMap = nil - } - } - - select { - case updates <- update: - case <-ctx.Done(): - // Unblock the caller if they're waiting for an update. - select { - case updates <- DERPMapUpdate{Err: ctx.Err()}: - default: - } - return - } - if update.Err != nil { - return - } - } - }() - - return updates, &closer{ - closeFunc: func() error { - cancelFunc() - <-pingClosed - _ = conn.Close(websocket.StatusGoingAway, "DERPMapUpdates closed") - <-updatesClosed - return nil - }, - }, nil -} - // Listen connects to the workspace agent API WebSocket // that handles connection negotiation. func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) { @@ -902,11 +812,3 @@ func pingWebSocket(ctx context.Context, logger slog.Logger, conn *websocket.Conn return closed } - -type closer struct { - closeFunc func() error -} - -func (c *closer) Close() error { - return c.closeFunc() -} diff --git a/tailnet/configmaps.go b/tailnet/configmaps.go index 7579140c9f604..4dd307536e0f6 100644 --- a/tailnet/configmaps.go +++ b/tailnet/configmaps.go @@ -2,6 +2,7 @@ package tailnet import ( "context" + "encoding/json" "errors" "fmt" "net/netip" @@ -146,14 +147,14 @@ func (c *configMaps) configLoop() { if c.derpMapDirty { derpMap := c.derpMapLocked() actions = append(actions, func() { - c.logger.Debug(context.Background(), "updating engine DERP map", slog.F("derp_map", derpMap)) + c.logger.Info(context.Background(), "updating engine DERP map", slog.F("derp_map", (*derpMapStringer)(derpMap))) c.engine.SetDERPMap(derpMap) }) } if c.netmapDirty { nm := c.netMapLocked() actions = append(actions, func() { - c.logger.Debug(context.Background(), "updating engine network map", slog.F("network_map", nm)) + c.logger.Info(context.Background(), "updating engine network map", slog.F("network_map", nm)) c.engine.SetNetworkMap(nm) c.reconfig(nm) }) @@ -161,7 +162,7 @@ func (c *configMaps) configLoop() { if c.filterDirty { f := c.filterLocked() actions = append(actions, func() { - c.logger.Debug(context.Background(), "updating engine filter", slog.F("filter", f)) + c.logger.Info(context.Background(), "updating engine filter", slog.F("filter", f)) c.engine.SetFilter(f) }) } @@ -570,3 +571,16 @@ func prefixesDifferent(a, b []netip.Prefix) bool { } return false } + +// derpMapStringer converts a DERPMap into a readable string for logging, since +// it includes pointers that we want to know the contents of, not actual pointer +// address. +type derpMapStringer tailcfg.DERPMap + +func (d *derpMapStringer) String() string { + out, err := json.Marshal((*tailcfg.DERPMap)(d)) + if err != nil { + return fmt.Sprintf("!!!error marshaling DERPMap: %s", err.Error()) + } + return string(out) +} diff --git a/tailnet/service.go b/tailnet/service.go index 02bc50a57113a..3be0abcab6ded 100644 --- a/tailnet/service.go +++ b/tailnet/service.go @@ -132,6 +132,10 @@ func (s *DRPCService) StreamDERPMaps(_ *proto.StreamDERPMapsRequest, stream prot var lastDERPMap *tailcfg.DERPMap for { derpMap := s.DerpMapFn() + if derpMap == nil { + // in testing, we send nil to close the stream. + return io.EOF + } if lastDERPMap == nil || !CompareDERPMaps(lastDERPMap, derpMap) { protoDERPMap := DERPMapToProto(derpMap) err := stream.Send(protoDERPMap)