From acdd09dc3558a4ce8291ef205150fae2c2b44ff4 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Tue, 26 Nov 2024 02:26:18 +0000 Subject: [PATCH] chore(vpn): upsert agents with their network status --- .golangci.yaml | 7 +- vpn/client.go | 2 + vpn/tunnel.go | 325 +++++++++++++++++++++++++++--------- vpn/tunnel_internal_test.go | 250 ++++++++++++++++++++++++++- 4 files changed, 496 insertions(+), 88 deletions(-) diff --git a/.golangci.yaml b/.golangci.yaml index fd8946319ca1d..aee26ad272f16 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -175,8 +175,6 @@ linters-settings: - name: modifies-value-receiver - name: package-comments - name: range - - name: range-val-address - - name: range-val-in-closure - name: receiver-naming - name: redefines-builtin-id - name: string-of-int @@ -199,6 +197,10 @@ linters-settings: govet: disable: - loopclosure + gosec: + excludes: + # Implicit memory aliasing of items from a range statement (irrelevant as of Go v1.22) + - G601 issues: # Rules listed here: https://github.com/securego/gosec#available-rules @@ -238,7 +240,6 @@ linters: - errname - errorlint - exhaustruct - - exportloopref - forcetypeassert - gocritic # gocyclo is may be useful in the future when we start caring diff --git a/vpn/client.go b/vpn/client.go index 954f27a7aa668..06ccbd1322f5b 100644 --- a/vpn/client.go +++ b/vpn/client.go @@ -11,6 +11,7 @@ import ( "tailscale.com/net/dns" "tailscale.com/wgengine/router" + "github.com/google/uuid" "github.com/tailscale/wireguard-go/tun" "cdr.dev/slog" @@ -23,6 +24,7 @@ import ( type Conn interface { CurrentWorkspaceState() (tailnet.WorkspaceUpdate, error) + GetPeerDiagnostics(peerID uuid.UUID) tailnet.PeerDiagnostics Close() error } diff --git a/vpn/tunnel.go b/vpn/tunnel.go index dae555483cf99..6d6983b03946f 100644 --- a/vpn/tunnel.go +++ b/vpn/tunnel.go @@ -7,23 +7,36 @@ import ( "fmt" "io" "net/http" + "net/netip" "net/url" "reflect" "strconv" "sync" + "time" "unicode" "golang.org/x/xerrors" + "google.golang.org/protobuf/types/known/timestamppb" "tailscale.com/net/dns" + "tailscale.com/util/dnsname" "tailscale.com/wgengine/router" + "github.com/google/uuid" + "cdr.dev/slog" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/tailnet" + "github.com/coder/quartz" ) +// netStatusInterval is the interval at which the tunnel sends network status updates to the manager. +// This is currently only used to keep `last_handshake` up to date. 
+const netStatusInterval = 10 * time.Second + type Tunnel struct { speaker[*TunnelMessage, *ManagerMessage, ManagerMessage] + updater + ctx context.Context requestLoopDone chan struct{} @@ -33,7 +46,6 @@ type Tunnel struct { logs []*TunnelMessage client Client - conn Conn // clientLogger is a separate logger than `logger` when the `UseAsLogger` // option is used, to avoid the tunnel using itself as a sink for it's own @@ -67,6 +79,13 @@ func NewTunnel( clientLogger: logger, requestLoopDone: make(chan struct{}), client: client, + updater: updater{ + ctx: ctx, + netLoopDone: make(chan struct{}), + uSendCh: s.sendCh, + agents: map[uuid.UUID]tailnet.Agent{}, + clock: quartz.NewReal(), + }, } for _, opt := range opts { @@ -74,6 +93,7 @@ func NewTunnel( } t.speaker.start() go t.requestLoop() + go t.netStatusLoop() return t, nil } @@ -81,16 +101,9 @@ func (t *Tunnel) requestLoop() { defer close(t.requestLoopDone) for req := range t.speaker.requests { if req.msg.Rpc != nil && req.msg.Rpc.MsgId != 0 { - resp := t.handleRPC(req.msg, req.msg.Rpc.MsgId) - if err := req.sendReply(resp); err != nil { - t.logger.Debug(t.ctx, "failed to send RPC reply", slog.Error(err)) - } - if _, ok := resp.GetMsg().(*TunnelMessage_Stop); ok { - // TODO: Wait for the reply to be sent before closing the speaker. - // err := t.speaker.Close() - // if err != nil { - // t.logger.Error(t.ctx, "failed to close speaker", slog.Error(err)) - // } + t.handleRPC(req) + if _, ok := req.msg.GetMsg().(*ManagerMessage_Stop); ok { + close(t.sendCh) return } continue @@ -102,20 +115,18 @@ func (t *Tunnel) requestLoop() { } } -// handleRPC handles unary RPCs from the manager. -func (t *Tunnel) handleRPC(req *ManagerMessage, msgID uint64) *TunnelMessage { +// handleRPC handles unary RPCs from the manager, sending a reply back to the manager. +func (t *Tunnel) handleRPC(req *request[*TunnelMessage, *ManagerMessage]) { resp := &TunnelMessage{} - resp.Rpc = &RPC{ResponseTo: msgID} - switch msg := req.GetMsg().(type) { + resp.Rpc = &RPC{ResponseTo: req.msg.Rpc.MsgId} + switch msg := req.msg.GetMsg().(type) { case *ManagerMessage_GetPeerUpdate: - state, err := t.conn.CurrentWorkspaceState() + err := t.updater.sendUpdateResponse(req) if err != nil { - t.logger.Critical(t.ctx, "failed to get current workspace state", slog.Error(err)) + t.logger.Error(t.ctx, "failed to send peer update", slog.Error(err)) } - resp.Msg = &TunnelMessage_PeerUpdate{ - PeerUpdate: convertWorkspaceUpdate(state), - } - return resp + // Reply has already been sent. 
+ return case *ManagerMessage_Start: startReq := msg.Start t.logger.Info(t.ctx, "starting CoderVPN tunnel", @@ -134,7 +145,6 @@ func (t *Tunnel) handleRPC(req *ManagerMessage, msgID uint64) *TunnelMessage { ErrorMessage: errStr, }, } - return resp case *ManagerMessage_Stop: t.logger.Info(t.ctx, "stopping CoderVPN tunnel") err := t.stop(msg.Stop) @@ -151,10 +161,11 @@ func (t *Tunnel) handleRPC(req *ManagerMessage, msgID uint64) *TunnelMessage { ErrorMessage: errStr, }, } - return resp default: t.logger.Warn(t.ctx, "unhandled manager request", slog.F("request", msg)) - return resp + } + if err := req.sendReply(resp); err != nil { + t.logger.Debug(t.ctx, "failed to send RPC reply", slog.Error(err)) } } @@ -176,6 +187,12 @@ func UseAsDNSConfig() TunnelOption { } } +func WithClock(clock quartz.Clock) TunnelOption { + return func(t *Tunnel) { + t.clock = clock + } +} + // ApplyNetworkSettings sends a request to the manager to apply the given network settings func (t *Tunnel) ApplyNetworkSettings(ctx context.Context, ns *NetworkSettingsRequest) error { msg, err := t.speaker.unaryRPC(ctx, &TunnelMessage{ @@ -193,20 +210,6 @@ func (t *Tunnel) ApplyNetworkSettings(ctx context.Context, ns *NetworkSettingsRe return nil } -func (t *Tunnel) Update(update tailnet.WorkspaceUpdate) error { - msg := &TunnelMessage{ - Msg: &TunnelMessage_PeerUpdate{ - PeerUpdate: convertWorkspaceUpdate(update), - }, - } - select { - case <-t.ctx.Done(): - return t.ctx.Err() - case t.sendCh <- msg: - } - return nil -} - func (t *Tunnel) start(req *StartRequest) error { rawURL := req.GetCoderUrl() if rawURL == "" { @@ -225,31 +228,31 @@ func (t *Tunnel) start(req *StartRequest) error { header.Add(h.GetName(), h.GetValue()) } - if t.conn == nil { - t.conn, err = t.client.NewConn( - t.ctx, - svrURL, - apiToken, - &Options{ - Headers: header, - Logger: t.clientLogger, - DNSConfigurator: t.dnsConfigurator, - Router: t.router, - TUNFileDescriptor: ptr.Ref(int(req.GetTunnelFileDescriptor())), - UpdateHandler: t, - }, - ) - } else { + conn, err := t.client.NewConn( + t.ctx, + svrURL, + apiToken, + &Options{ + Headers: header, + Logger: t.clientLogger, + DNSConfigurator: t.dnsConfigurator, + Router: t.router, + TUNFileDescriptor: ptr.Ref(int(req.GetTunnelFileDescriptor())), + UpdateHandler: t, + }, + ) + if err != nil { + return xerrors.Errorf("failed to start connection: %w", err) + } + + if ok := t.updater.setConn(conn); !ok { t.logger.Warn(t.ctx, "asked to start tunnel, but tunnel is already running") } return err } func (t *Tunnel) stop(*StopRequest) error { - if t.conn == nil { - return nil - } - return t.conn.Close() + return t.updater.stop() } var _ slog.Sink = &Tunnel{} @@ -293,13 +296,75 @@ func sinkEntryToPb(e slog.SinkEntry) *Log { return l } -func convertWorkspaceUpdate(update tailnet.WorkspaceUpdate) *PeerUpdate { +// updater is the component of the tunnel responsible for sending workspace +// updates to the manager. +type updater struct { + ctx context.Context + netLoopDone chan struct{} + + mu sync.Mutex + uSendCh chan<- *TunnelMessage + // agents contains the agents that are currently connected to the tunnel. 
+ agents map[uuid.UUID]tailnet.Agent + conn Conn + + clock quartz.Clock +} + +// Update pushes a workspace update to the manager +func (u *updater) Update(update tailnet.WorkspaceUpdate) error { + u.mu.Lock() + defer u.mu.Unlock() + + peerUpdate := u.createPeerUpdateLocked(update) + msg := &TunnelMessage{ + Msg: &TunnelMessage_PeerUpdate{ + PeerUpdate: peerUpdate, + }, + } + select { + case <-u.ctx.Done(): + return u.ctx.Err() + case u.uSendCh <- msg: + } + return nil +} + +// sendUpdateResponse responds to the provided `ManagerMessage_GetPeerUpdate` request +// with the current state of the workspaces. +func (u *updater) sendUpdateResponse(req *request[*TunnelMessage, *ManagerMessage]) error { + u.mu.Lock() + defer u.mu.Unlock() + + state, err := u.conn.CurrentWorkspaceState() + if err != nil { + return xerrors.Errorf("failed to get current workspace state: %w", err) + } + update := u.createPeerUpdateLocked(state) + resp := &TunnelMessage{ + Msg: &TunnelMessage_PeerUpdate{ + PeerUpdate: update, + }, + } + err = req.sendReply(resp) + if err != nil { + return xerrors.Errorf("failed to send RPC reply: %w", err) + } + return nil +} + +// createPeerUpdateLocked creates a PeerUpdate message from a workspace update, populating +// the network status of the agents. +func (u *updater) createPeerUpdateLocked(update tailnet.WorkspaceUpdate) *PeerUpdate { out := &PeerUpdate{ UpsertedWorkspaces: make([]*Workspace, len(update.UpsertedWorkspaces)), UpsertedAgents: make([]*Agent, len(update.UpsertedAgents)), DeletedWorkspaces: make([]*Workspace, len(update.DeletedWorkspaces)), DeletedAgents: make([]*Agent, len(update.DeletedAgents)), } + + u.saveUpdateLocked(update) + for i, ws := range update.UpsertedWorkspaces { out.UpsertedWorkspaces[i] = &Workspace{ Id: tailnet.UUIDToByteSlice(ws.ID), @@ -307,21 +372,8 @@ func convertWorkspaceUpdate(update tailnet.WorkspaceUpdate) *PeerUpdate { Status: Workspace_Status(ws.Status), } } - for i, agent := range update.UpsertedAgents { - fqdn := make([]string, 0, len(agent.Hosts)) - for name := range agent.Hosts { - fqdn = append(fqdn, name.WithTrailingDot()) - } - out.UpsertedAgents[i] = &Agent{ - Id: tailnet.UUIDToByteSlice(agent.ID), - Name: agent.Name, - WorkspaceId: tailnet.UUIDToByteSlice(agent.WorkspaceID), - Fqdn: fqdn, - IpAddrs: []string{tailnet.CoderServicePrefix.AddrFromUUID(agent.ID).String()}, - // TODO: Populate - LastHandshake: nil, - } - } + upsertedAgents := u.convertAgentsLocked(update.UpsertedAgents) + out.UpsertedAgents = upsertedAgents for i, ws := range update.DeletedWorkspaces { out.DeletedWorkspaces[i] = &Workspace{ Id: tailnet.UUIDToByteSlice(ws.ID), @@ -335,18 +387,139 @@ func convertWorkspaceUpdate(update tailnet.WorkspaceUpdate) *PeerUpdate { fqdn = append(fqdn, name.WithTrailingDot()) } out.DeletedAgents[i] = &Agent{ + Id: tailnet.UUIDToByteSlice(agent.ID), + Name: agent.Name, + WorkspaceId: tailnet.UUIDToByteSlice(agent.WorkspaceID), + Fqdn: fqdn, + IpAddrs: hostsToIPStrings(agent.Hosts), + LastHandshake: nil, + } + } + return out +} + +// convertAgentsLocked takes a list of `tailnet.Agent` and converts them to proto agents. +// If there is an active connection, the last handshake time is populated. 
+func (u *updater) convertAgentsLocked(agents []*tailnet.Agent) []*Agent { + out := make([]*Agent, 0, len(agents)) + + for _, agent := range agents { + fqdn := make([]string, 0, len(agent.Hosts)) + for name := range agent.Hosts { + fqdn = append(fqdn, name.WithTrailingDot()) + } + protoAgent := &Agent{ Id: tailnet.UUIDToByteSlice(agent.ID), Name: agent.Name, WorkspaceId: tailnet.UUIDToByteSlice(agent.WorkspaceID), Fqdn: fqdn, - IpAddrs: []string{tailnet.CoderServicePrefix.AddrFromUUID(agent.ID).String()}, - // TODO: Populate - LastHandshake: nil, + IpAddrs: hostsToIPStrings(agent.Hosts), } + if u.conn != nil { + diags := u.conn.GetPeerDiagnostics(agent.ID) + protoAgent.LastHandshake = timestamppb.New(diags.LastWireguardHandshake) + } + out = append(out, protoAgent) } + return out } +// saveUpdateLocked saves the workspace update to the tunnel's state, such that it can +// be used to populate automated peer updates. +func (u *updater) saveUpdateLocked(update tailnet.WorkspaceUpdate) { + for _, agent := range update.UpsertedAgents { + u.agents[agent.ID] = agent.Clone() + } + for _, agent := range update.DeletedAgents { + delete(u.agents, agent.ID) + } +} + +// setConn sets the `conn` and returns false if there's already a connection set. +func (u *updater) setConn(conn Conn) bool { + u.mu.Lock() + defer u.mu.Unlock() + + if u.conn != nil { + return false + } + u.conn = conn + return true +} + +func (u *updater) stop() error { + u.mu.Lock() + defer u.mu.Unlock() + + if u.conn == nil { + return nil + } + err := u.conn.Close() + u.conn = nil + return err +} + +// sendAgentUpdate sends a peer update message to the manager with the current +// state of the agents, including the latest network status. +func (u *updater) sendAgentUpdate() { + u.mu.Lock() + defer u.mu.Unlock() + + agents := make([]*tailnet.Agent, 0, len(u.agents)) + for _, agent := range u.agents { + agents = append(agents, &agent) + } + upsertedAgents := u.convertAgentsLocked(agents) + if len(upsertedAgents) == 0 { + return + } + + msg := &TunnelMessage{ + Msg: &TunnelMessage_PeerUpdate{ + PeerUpdate: &PeerUpdate{ + UpsertedAgents: upsertedAgents, + }, + }, + } + + select { + case <-u.ctx.Done(): + return + case u.uSendCh <- msg: + } +} + +func (u *updater) netStatusLoop() { + ticker := u.clock.NewTicker(netStatusInterval) + defer ticker.Stop() + defer close(u.netLoopDone) + for { + select { + case <-u.ctx.Done(): + return + case <-ticker.C: + u.sendAgentUpdate() + } + } +} + +// hostsToIPStrings returns a slice of all unique IP addresses in the values +// of the given map. 
+func hostsToIPStrings(hosts map[dnsname.FQDN][]netip.Addr) []string { + seen := make(map[netip.Addr]struct{}) + var result []string + for _, inner := range hosts { + for _, elem := range inner { + if _, exists := seen[elem]; !exists { + seen[elem] = struct{}{} + result = append(result, elem.String()) + } + } + } + return result +} + // the following are taken from sloghuman: func formatValue(v interface{}) string { diff --git a/vpn/tunnel_internal_test.go b/vpn/tunnel_internal_test.go index ed5a7e429dccd..8a55205605d7d 100644 --- a/vpn/tunnel_internal_test.go +++ b/vpn/tunnel_internal_test.go @@ -3,15 +3,23 @@ package vpn import ( "context" "net" + "net/netip" "net/url" + "slices" + "strings" "sync" "testing" + "time" "github.com/google/uuid" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" + "tailscale.com/util/dnsname" "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" ) func newFakeClient(ctx context.Context, t *testing.T) *fakeClient { @@ -39,15 +47,17 @@ func (f *fakeClient) NewConn(context.Context, *url.URL, string, *Options) (Conn, } } -func newFakeConn(state tailnet.WorkspaceUpdate) *fakeConn { +func newFakeConn(state tailnet.WorkspaceUpdate, hsTime time.Time) *fakeConn { return &fakeConn{ closed: make(chan struct{}), state: state, + hsTime: hsTime, } } type fakeConn struct { state tailnet.WorkspaceUpdate + hsTime time.Time closed chan struct{} doClose sync.Once } @@ -58,6 +68,12 @@ func (f *fakeConn) CurrentWorkspaceState() (tailnet.WorkspaceUpdate, error) { return f.state, nil } +func (f *fakeConn) GetPeerDiagnostics(uuid.UUID) tailnet.PeerDiagnostics { + return tailnet.PeerDiagnostics{ + LastWireguardHandshake: f.hsTime, + } +} + func (f *fakeConn) Close() error { f.doClose.Do(func() { close(f.closed) @@ -70,9 +86,9 @@ func TestTunnel_StartStop(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) client := newFakeClient(ctx, t) - conn := newFakeConn(tailnet.WorkspaceUpdate{}) + conn := newFakeConn(tailnet.WorkspaceUpdate{}, time.Time{}) - _, mgr := setupTunnel(t, ctx, client) + _, mgr := setupTunnel(t, ctx, client, quartz.NewMock(t)) errCh := make(chan error, 1) var resp *TunnelMessage @@ -136,9 +152,9 @@ func TestTunnel_PeerUpdate(t *testing.T) { ID: wsID2, }, }, - }) + }, time.Time{}) - tun, mgr := setupTunnel(t, ctx, client) + tun, mgr := setupTunnel(t, ctx, client, quartz.NewMock(t)) errCh := make(chan error, 1) var resp *TunnelMessage @@ -161,6 +177,7 @@ func TestTunnel_PeerUpdate(t *testing.T) { _, ok := resp.Msg.(*TunnelMessage_Start) require.True(t, ok) + // When: we inform the tunnel of a WorkspaceUpdate err = tun.Update(tailnet.WorkspaceUpdate{ UpsertedWorkspaces: []*tailnet.Workspace{ { @@ -200,9 +217,9 @@ func TestTunnel_NetworkSettings(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) client := newFakeClient(ctx, t) - conn := newFakeConn(tailnet.WorkspaceUpdate{}) + conn := newFakeConn(tailnet.WorkspaceUpdate{}, time.Time{}) - tun, mgr := setupTunnel(t, ctx, client) + tun, mgr := setupTunnel(t, ctx, client, quartz.NewMock(t)) errCh := make(chan error, 1) var resp *TunnelMessage @@ -251,8 +268,223 @@ func TestTunnel_NetworkSettings(t *testing.T) { require.NoError(t, err) } +func TestUpdater_createPeerUpdate(t *testing.T) { + t.Parallel() + + w1ID := uuid.UUID{1} + w2ID := uuid.UUID{2} + w1a1ID := uuid.UUID{4} + w2a1ID := uuid.UUID{5} + w1a1IP := netip.MustParseAddr("fd60:627a:a42b:0101::") + w2a1IP := 
netip.MustParseAddr("fd60:627a:a42b:0301::") + + ctx := testutil.Context(t, testutil.WaitShort) + + hsTime := time.Now().Add(-time.Minute).UTC() + updater := updater{ + ctx: ctx, + netLoopDone: make(chan struct{}), + agents: map[uuid.UUID]tailnet.Agent{}, + conn: newFakeConn(tailnet.WorkspaceUpdate{}, hsTime), + } + + update := updater.createPeerUpdateLocked(tailnet.WorkspaceUpdate{ + UpsertedWorkspaces: []*tailnet.Workspace{ + {ID: w1ID, Name: "w1", Status: proto.Workspace_STARTING}, + }, + UpsertedAgents: []*tailnet.Agent{ + { + ID: w1a1ID, Name: "w1a1", WorkspaceID: w1ID, + Hosts: map[dnsname.FQDN][]netip.Addr{ + "w1.coder.": {w1a1IP}, + "w1a1.w1.me.coder.": {w1a1IP}, + "w1a1.w1.testy.coder.": {w1a1IP}, + }, + }, + }, + DeletedWorkspaces: []*tailnet.Workspace{ + {ID: w2ID, Name: "w2", Status: proto.Workspace_STOPPED}, + }, + DeletedAgents: []*tailnet.Agent{ + { + ID: w2a1ID, Name: "w2a1", WorkspaceID: w2ID, + Hosts: map[dnsname.FQDN][]netip.Addr{ + "w2.coder.": {w2a1IP}, + "w2a1.w2.me.coder.": {w2a1IP}, + "w2a1.w2.testy.coder.": {w2a1IP}, + }, + }, + }, + }) + require.Len(t, update.UpsertedAgents, 1) + slices.SortFunc(update.UpsertedAgents[0].Fqdn, func(a, b string) int { + return strings.Compare(a, b) + }) + slices.SortFunc(update.DeletedAgents[0].Fqdn, func(a, b string) int { + return strings.Compare(a, b) + }) + require.Equal(t, update, &PeerUpdate{ + UpsertedWorkspaces: []*Workspace{ + {Id: w1ID[:], Name: "w1", Status: Workspace_Status(proto.Workspace_STARTING)}, + }, + UpsertedAgents: []*Agent{ + { + Id: w1a1ID[:], Name: "w1a1", WorkspaceId: w1ID[:], + Fqdn: []string{"w1.coder.", "w1a1.w1.me.coder.", "w1a1.w1.testy.coder."}, + IpAddrs: []string{w1a1IP.String()}, + LastHandshake: timestamppb.New(hsTime), + }, + }, + DeletedWorkspaces: []*Workspace{ + {Id: w2ID[:], Name: "w2", Status: Workspace_Status(proto.Workspace_STOPPED)}, + }, + DeletedAgents: []*Agent{ + { + Id: w2a1ID[:], Name: "w2a1", WorkspaceId: w2ID[:], + Fqdn: []string{"w2.coder.", "w2a1.w2.me.coder.", "w2a1.w2.testy.coder."}, + IpAddrs: []string{w2a1IP.String()}, + LastHandshake: nil, + }, + }, + }) +} + +func TestTunnel_sendAgentUpdate(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + + mClock := quartz.NewMock(t) + + wID1 := uuid.UUID{1} + aID1 := uuid.UUID{2} + aID2 := uuid.UUID{3} + hsTime := time.Now().Add(-time.Minute).UTC() + + client := newFakeClient(ctx, t) + conn := newFakeConn(tailnet.WorkspaceUpdate{}, hsTime) + + tun, mgr := setupTunnel(t, ctx, client, mClock) + errCh := make(chan error, 1) + var resp *TunnelMessage + go func() { + r, err := mgr.unaryRPC(ctx, &ManagerMessage{ + Msg: &ManagerMessage_Start{ + Start: &StartRequest{ + TunnelFileDescriptor: 2, + CoderUrl: "https://coder.example.com", + ApiToken: "fakeToken", + }, + }, + }) + resp = r + errCh <- err + }() + testutil.RequireSendCtx(ctx, t, client.ch, conn) + err := testutil.RequireRecvCtx(ctx, t, errCh) + require.NoError(t, err) + _, ok := resp.Msg.(*TunnelMessage_Start) + require.True(t, ok) + + // Inform the tunnel of the initial state + err = tun.Update(tailnet.WorkspaceUpdate{ + UpsertedWorkspaces: []*tailnet.Workspace{ + { + ID: wID1, Name: "w1", Status: proto.Workspace_STARTING, + }, + }, + UpsertedAgents: []*tailnet.Agent{ + { + ID: aID1, + Name: "agent1", + WorkspaceID: wID1, + Hosts: map[dnsname.FQDN][]netip.Addr{ + "agent1.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")}, + }, + }, + }, + }) + require.NoError(t, err) + req := testutil.RequireRecvCtx(ctx, t, mgr.requests) + require.Nil(t, 
req.msg.Rpc) + require.NotNil(t, req.msg.GetPeerUpdate()) + require.Len(t, req.msg.GetPeerUpdate().UpsertedAgents, 1) + require.Equal(t, aID1[:], req.msg.GetPeerUpdate().UpsertedAgents[0].Id) + + // `sendAgentUpdate` produces the same PeerUpdate message until an agent + // update is received + for range 2 { + mClock.AdvanceNext() + // Then: the tunnel sends a PeerUpdate message of agent upserts, + // with the last handshake and latency set + req = testutil.RequireRecvCtx(ctx, t, mgr.requests) + require.Nil(t, req.msg.Rpc) + require.NotNil(t, req.msg.GetPeerUpdate()) + require.Len(t, req.msg.GetPeerUpdate().UpsertedAgents, 1) + require.Equal(t, aID1[:], req.msg.GetPeerUpdate().UpsertedAgents[0].Id) + require.Equal(t, hsTime, req.msg.GetPeerUpdate().UpsertedAgents[0].LastHandshake.AsTime()) + } + + // Upsert a new agent + err = tun.Update(tailnet.WorkspaceUpdate{ + UpsertedWorkspaces: []*tailnet.Workspace{}, + UpsertedAgents: []*tailnet.Agent{ + { + ID: aID2, + Name: "agent2", + WorkspaceID: wID1, + Hosts: map[dnsname.FQDN][]netip.Addr{ + "agent2.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")}, + }, + }, + }, + }) + require.NoError(t, err) + testutil.RequireRecvCtx(ctx, t, mgr.requests) + + // The new update includes the new agent + mClock.AdvanceNext() + req = testutil.RequireRecvCtx(ctx, t, mgr.requests) + require.Nil(t, req.msg.Rpc) + require.NotNil(t, req.msg.GetPeerUpdate()) + require.Len(t, req.msg.GetPeerUpdate().UpsertedAgents, 2) + slices.SortFunc(req.msg.GetPeerUpdate().UpsertedAgents, func(a, b *Agent) int { + return strings.Compare(a.Name, b.Name) + }) + + require.Equal(t, aID1[:], req.msg.GetPeerUpdate().UpsertedAgents[0].Id) + require.Equal(t, hsTime, req.msg.GetPeerUpdate().UpsertedAgents[0].LastHandshake.AsTime()) + require.Equal(t, aID2[:], req.msg.GetPeerUpdate().UpsertedAgents[1].Id) + require.Equal(t, hsTime, req.msg.GetPeerUpdate().UpsertedAgents[1].LastHandshake.AsTime()) + + // Delete an agent + err = tun.Update(tailnet.WorkspaceUpdate{ + DeletedAgents: []*tailnet.Agent{ + { + ID: aID1, + Name: "agent1", + WorkspaceID: wID1, + Hosts: map[dnsname.FQDN][]netip.Addr{ + "agent1.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")}, + }, + }, + }, + }) + require.NoError(t, err) + testutil.RequireRecvCtx(ctx, t, mgr.requests) + + // The new update doesn't include the deleted agent + mClock.AdvanceNext() + req = testutil.RequireRecvCtx(ctx, t, mgr.requests) + require.Nil(t, req.msg.Rpc) + require.NotNil(t, req.msg.GetPeerUpdate()) + require.Len(t, req.msg.GetPeerUpdate().UpsertedAgents, 1) + require.Equal(t, aID2[:], req.msg.GetPeerUpdate().UpsertedAgents[0].Id) + require.Equal(t, hsTime, req.msg.GetPeerUpdate().UpsertedAgents[0].LastHandshake.AsTime()) +} + //nolint:revive // t takes precedence -func setupTunnel(t *testing.T, ctx context.Context, client *fakeClient) (*Tunnel, *speaker[*ManagerMessage, *TunnelMessage, TunnelMessage]) { +func setupTunnel(t *testing.T, ctx context.Context, client *fakeClient, mClock quartz.Clock) (*Tunnel, *speaker[*ManagerMessage, *TunnelMessage, TunnelMessage]) { mp, tp := net.Pipe() t.Cleanup(func() { _ = mp.Close() }) t.Cleanup(func() { _ = tp.Close() }) @@ -262,7 +494,7 @@ func setupTunnel(t *testing.T, ctx context.Context, client *fakeClient) (*Tunnel var mgr *speaker[*ManagerMessage, *TunnelMessage, TunnelMessage] errCh := make(chan error, 2) go func() { - tunnel, err := NewTunnel(ctx, logger.Named("tunnel"), tp, client) + tunnel, err := NewTunnel(ctx, logger.Named("tunnel"), tp, client, WithClock(mClock)) tun = tunnel errCh <- 
err }()
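
---

Note for context on the periodic-update mechanism this patch introduces: `netStatusLoop` ticks on an injected `quartz.Clock`, which is why the tests build the tunnel with `WithClock(quartz.NewMock(t))` and fire ticks with `AdvanceNext()`. Below is a minimal, self-contained sketch of that same pattern, using only the quartz calls that appear in the diff (`NewTicker`, `Stop`, the `C` channel, `quartz.NewReal()`). The `periodicSender` type, its fields, and the timings are illustrative only and are not part of this change.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/coder/quartz"
)

// periodicSender mirrors the shape of the updater's netStatusLoop: it fires
// a callback on every tick of an injected quartz.Clock, so production code
// can pass quartz.NewReal() while tests pass quartz.NewMock(t) and advance
// time deterministically.
type periodicSender struct {
	clock quartz.Clock
	send  func()
}

func (p *periodicSender) loop(ctx context.Context, interval time.Duration) {
	ticker := p.clock.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			p.send()
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 350*time.Millisecond)
	defer cancel()

	s := &periodicSender{
		clock: quartz.NewReal(),
		send:  func() { fmt.Println("sending agent update") },
	}
	// With the real clock this prints roughly three times before the context
	// expires; a test would swap in a mock clock and advance it explicitly.
	s.loop(ctx, 100*time.Millisecond)
}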