From 72221359a8a815bcc7059d72ca5bbf37a662b025 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Mon, 26 Jun 2023 19:36:12 +0000 Subject: [PATCH 01/19] chore: replace wsconncache with a single tailnet --- agent/agent.go | 15 +- agent/agent_test.go | 312 +++++------------ agent/agenttest/client.go | 169 +++++++++ coderd/coderd.go | 18 +- .../prometheusmetrics_test.go | 2 +- coderd/tailnet.go | 330 ++++++++++++++++++ coderd/tailnet_test.go | 209 +++++++++++ coderd/workspaceagents.go | 21 +- coderd/workspaceapps/apptest/setup.go | 3 +- coderd/workspaceapps/proxy.go | 34 +- coderd/wsconncache/wsconncache.go | 54 ++- coderd/wsconncache/wsconncache_test.go | 16 +- codersdk/workspaceagentconn.go | 151 ++++++-- codersdk/workspaceagents.go | 24 +- enterprise/tailnet/coordinator.go | 57 +++ enterprise/tailnet/pgcoord.go | 10 + enterprise/wsproxy/wsproxy.go | 8 +- scaletest/agentconn/run.go | 6 +- tailnet/conn.go | 59 ++-- tailnet/conn_test.go | 4 +- tailnet/coordinator.go | 75 +++- tailnet/tailnettest/tailnettest.go | 4 +- tailnet/tailnettest/tailnettest_test.go | 2 +- 23 files changed, 1250 insertions(+), 333 deletions(-) create mode 100644 agent/agenttest/client.go create mode 100644 coderd/tailnet.go create mode 100644 coderd/tailnet_test.go diff --git a/agent/agent.go b/agent/agent.go index b1218190bbcb4..ce477d11415b5 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -64,6 +64,7 @@ type Options struct { SSHMaxTimeout time.Duration TailnetListenPort uint16 Subsystem codersdk.AgentSubsystem + Addresses []netip.Prefix PrometheusRegistry *prometheus.Registry } @@ -111,6 +112,16 @@ func New(options Options) Agent { prometheusRegistry = prometheus.NewRegistry() } + if len(options.Addresses) == 0 { + options.Addresses = []netip.Prefix{ + // This is the IP that should be used primarily. + netip.PrefixFrom(tailnet.IP(), 128), + // We also listen on the legacy codersdk.WorkspaceAgentIP. This + // allows for a transition away from wsconncache. 
+ netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128), + } + } + ctx, cancelFunc := context.WithCancel(context.Background()) a := &agent{ tailnetListenPort: options.TailnetListenPort, @@ -131,6 +142,7 @@ func New(options Options) Agent { connStatsChan: make(chan *agentsdk.Stats, 1), sshMaxTimeout: options.SSHMaxTimeout, subsystem: options.Subsystem, + addresses: options.Addresses, prometheusRegistry: prometheusRegistry, metrics: newAgentMetrics(prometheusRegistry), @@ -174,6 +186,7 @@ type agent struct { lifecycleStates []agentsdk.PostLifecycleRequest network *tailnet.Conn + addresses []netip.Prefix connStatsChan chan *agentsdk.Stats latestStat atomic.Pointer[agentsdk.Stats] @@ -639,7 +652,7 @@ func (a *agent) trackConnGoroutine(fn func()) error { func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap, disableDirectConnections bool) (_ *tailnet.Conn, err error) { network, err := tailnet.NewConn(&tailnet.Options{ - Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, + Addresses: a.addresses, DERPMap: derpMap, Logger: a.logger.Named("tailnet"), ListenPort: a.tailnetListenPort, diff --git a/agent/agent_test.go b/agent/agent_test.go index 1a16fa7cab612..32c63ad5d8155 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -34,7 +34,6 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/goleak" "golang.org/x/crypto/ssh" - "golang.org/x/exp/maps" "golang.org/x/xerrors" "tailscale.com/net/speedtest" "tailscale.com/tailcfg" @@ -43,6 +42,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" "github.com/coder/coder/agent/agentssh" + "github.com/coder/coder/agent/agenttest" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk/agentsdk" @@ -881,16 +881,15 @@ func TestAgent_StartupScript(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - client := &client{ - t: t, - agentID: uuid.New(), - manifest: agentsdk.Manifest{ + client := agenttest.NewClient(t, + uuid.New(), + agentsdk.Manifest{ StartupScript: command, DERPMap: &tailcfg.DERPMap{}, }, - statsChan: make(chan *agentsdk.Stats), - coordinator: tailnet.NewCoordinator(logger), - } + make(chan *agentsdk.Stats), + tailnet.NewCoordinator(logger), + ) closer := agent.New(agent.Options{ Client: client, Filesystem: afero.NewMemMapFs(), @@ -901,36 +900,35 @@ func TestAgent_StartupScript(t *testing.T) { _ = closer.Close() }) assert.Eventually(t, func() bool { - got := client.getLifecycleStates() + got := client.GetLifecycleStates() return len(got) > 0 && got[len(got)-1] == codersdk.WorkspaceAgentLifecycleReady }, testutil.WaitShort, testutil.IntervalMedium) - require.Len(t, client.getStartupLogs(), 1) - require.Equal(t, output, client.getStartupLogs()[0].Output) + require.Len(t, client.GetStartupLogs(), 1) + require.Equal(t, output, client.GetStartupLogs()[0].Output) }) // This ensures that even when coderd sends back that the startup // script has written too many lines it will still succeed! 
t.Run("OverflowsAndSkips", func(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - client := &client{ - t: t, - agentID: uuid.New(), - manifest: agentsdk.Manifest{ + client := agenttest.NewClient(t, + uuid.New(), + agentsdk.Manifest{ StartupScript: command, DERPMap: &tailcfg.DERPMap{}, }, - patchWorkspaceLogs: func() error { - resp := httptest.NewRecorder() - httpapi.Write(context.Background(), resp, http.StatusRequestEntityTooLarge, codersdk.Response{ - Message: "Too many lines!", - }) - res := resp.Result() - defer res.Body.Close() - return codersdk.ReadBodyAsError(res) - }, - statsChan: make(chan *agentsdk.Stats), - coordinator: tailnet.NewCoordinator(logger), + make(chan *agentsdk.Stats, 50), + tailnet.NewCoordinator(logger), + ) + client.PatchWorkspaceLogs = func() error { + resp := httptest.NewRecorder() + httpapi.Write(context.Background(), resp, http.StatusRequestEntityTooLarge, codersdk.Response{ + Message: "Too many lines!", + }) + res := resp.Result() + defer res.Body.Close() + return codersdk.ReadBodyAsError(res) } closer := agent.New(agent.Options{ Client: client, @@ -942,10 +940,10 @@ func TestAgent_StartupScript(t *testing.T) { _ = closer.Close() }) assert.Eventually(t, func() bool { - got := client.getLifecycleStates() + got := client.GetLifecycleStates() return len(got) > 0 && got[len(got)-1] == codersdk.WorkspaceAgentLifecycleReady }, testutil.WaitShort, testutil.IntervalMedium) - require.Len(t, client.getStartupLogs(), 0) + require.Len(t, client.GetStartupLogs(), 0) }) } @@ -969,14 +967,14 @@ func TestAgent_Metadata(t *testing.T) { var gotMd map[string]agentsdk.PostMetadataRequest require.Eventually(t, func() bool { - gotMd = client.getMetadata() + gotMd = client.GetMetadata() return len(gotMd) == 1 }, testutil.WaitShort, testutil.IntervalMedium) collectedAt := gotMd["greeting"].CollectedAt require.Never(t, func() bool { - gotMd = client.getMetadata() + gotMd = client.GetMetadata() if len(gotMd) != 1 { panic("unexpected number of metadata") } @@ -1000,7 +998,7 @@ func TestAgent_Metadata(t *testing.T) { var gotMd map[string]agentsdk.PostMetadataRequest require.Eventually(t, func() bool { - gotMd = client.getMetadata() + gotMd = client.GetMetadata() return len(gotMd) == 1 }, testutil.WaitShort, testutil.IntervalMedium) @@ -1010,7 +1008,7 @@ func TestAgent_Metadata(t *testing.T) { } if !assert.Eventually(t, func() bool { - gotMd = client.getMetadata() + gotMd = client.GetMetadata() return gotMd["greeting"].CollectedAt.After(collectedAt1) }, testutil.WaitShort, testutil.IntervalMedium) { t.Fatalf("expected metadata to be collected again") @@ -1052,11 +1050,11 @@ func TestAgentMetadata_Timing(t *testing.T) { }, 0) require.Eventually(t, func() bool { - return len(client.getMetadata()) == 2 + return len(client.GetMetadata()) == 2 }, testutil.WaitShort, testutil.IntervalMedium) for start := time.Now(); time.Since(start) < testutil.WaitMedium; time.Sleep(testutil.IntervalMedium) { - md := client.getMetadata() + md := client.GetMetadata() require.Len(t, md, 2, "got: %+v", md) require.Equal(t, "hello\n", md["greeting"].Value) @@ -1110,7 +1108,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - got = client.getLifecycleStates() + got = client.GetLifecycleStates() return len(got) > 0 && got[len(got)-1] == want[len(want)-1] }, testutil.WaitShort, testutil.IntervalMedium) @@ -1132,7 +1130,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle 
assert.Eventually(t, func() bool { - got = client.getLifecycleStates() + got = client.GetLifecycleStates() return len(got) > 0 && got[len(got)-1] == want[len(want)-1] }, testutil.WaitShort, testutil.IntervalMedium) @@ -1154,7 +1152,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - got = client.getLifecycleStates() + got = client.GetLifecycleStates() return len(got) > 0 && got[len(got)-1] == want[len(want)-1] }, testutil.WaitShort, testutil.IntervalMedium) @@ -1171,7 +1169,7 @@ func TestAgent_Lifecycle(t *testing.T) { var ready []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - ready = client.getLifecycleStates() + ready = client.GetLifecycleStates() return len(ready) > 0 && ready[len(ready)-1] == codersdk.WorkspaceAgentLifecycleReady }, testutil.WaitShort, testutil.IntervalMedium) @@ -1192,7 +1190,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - got = client.getLifecycleStates()[len(ready):] + got = client.GetLifecycleStates()[len(ready):] return len(got) > 0 && got[len(got)-1] == want[len(want)-1] }, testutil.WaitShort, testutil.IntervalMedium) @@ -1209,7 +1207,7 @@ func TestAgent_Lifecycle(t *testing.T) { var ready []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - ready = client.getLifecycleStates() + ready = client.GetLifecycleStates() return len(ready) > 0 && ready[len(ready)-1] == codersdk.WorkspaceAgentLifecycleReady }, testutil.WaitShort, testutil.IntervalMedium) @@ -1231,7 +1229,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - got = client.getLifecycleStates()[len(ready):] + got = client.GetLifecycleStates()[len(ready):] return len(got) > 0 && got[len(got)-1] == want[len(want)-1] }, testutil.WaitShort, testutil.IntervalMedium) @@ -1248,7 +1246,7 @@ func TestAgent_Lifecycle(t *testing.T) { var ready []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - ready = client.getLifecycleStates() + ready = client.GetLifecycleStates() return len(ready) > 0 && ready[len(ready)-1] == codersdk.WorkspaceAgentLifecycleReady }, testutil.WaitShort, testutil.IntervalMedium) @@ -1270,7 +1268,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - got = client.getLifecycleStates()[len(ready):] + got = client.GetLifecycleStates()[len(ready):] return len(got) > 0 && got[len(got)-1] == want[len(want)-1] }, testutil.WaitShort, testutil.IntervalMedium) @@ -1281,17 +1279,18 @@ func TestAgent_Lifecycle(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) expected := "this-is-shutdown" - client := &client{ - t: t, - agentID: uuid.New(), - manifest: agentsdk.Manifest{ - DERPMap: tailnettest.RunDERPAndSTUN(t), + derpMap, _ := tailnettest.RunDERPAndSTUN(t) + + client := agenttest.NewClient(t, + uuid.New(), + agentsdk.Manifest{ + DERPMap: derpMap, StartupScript: "echo 1", ShutdownScript: "echo " + expected, }, - statsChan: make(chan *agentsdk.Stats), - coordinator: tailnet.NewCoordinator(logger), - } + make(chan *agentsdk.Stats, 50), + tailnet.NewCoordinator(logger), + ) fs := afero.NewMemMapFs() agent := agent.New(agent.Options{ @@ -1343,9 +1342,9 @@ func TestAgent_Startup(t *testing.T) { Directory: "", }, 0) assert.Eventually(t, func() bool { - return client.getStartup().Version != "" + return client.GetStartup().Version != 
"" }, testutil.WaitShort, testutil.IntervalFast) - require.Equal(t, "", client.getStartup().ExpandedDirectory) + require.Equal(t, "", client.GetStartup().ExpandedDirectory) }) t.Run("HomeDirectory", func(t *testing.T) { @@ -1357,11 +1356,11 @@ func TestAgent_Startup(t *testing.T) { Directory: "~", }, 0) assert.Eventually(t, func() bool { - return client.getStartup().Version != "" + return client.GetStartup().Version != "" }, testutil.WaitShort, testutil.IntervalFast) homeDir, err := os.UserHomeDir() require.NoError(t, err) - require.Equal(t, homeDir, client.getStartup().ExpandedDirectory) + require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory) }) t.Run("NotAbsoluteDirectory", func(t *testing.T) { @@ -1373,11 +1372,11 @@ func TestAgent_Startup(t *testing.T) { Directory: "coder/coder", }, 0) assert.Eventually(t, func() bool { - return client.getStartup().Version != "" + return client.GetStartup().Version != "" }, testutil.WaitShort, testutil.IntervalFast) homeDir, err := os.UserHomeDir() require.NoError(t, err) - require.Equal(t, filepath.Join(homeDir, "coder/coder"), client.getStartup().ExpandedDirectory) + require.Equal(t, filepath.Join(homeDir, "coder/coder"), client.GetStartup().ExpandedDirectory) }) t.Run("HomeEnvironmentVariable", func(t *testing.T) { @@ -1389,11 +1388,11 @@ func TestAgent_Startup(t *testing.T) { Directory: "$HOME", }, 0) assert.Eventually(t, func() bool { - return client.getStartup().Version != "" + return client.GetStartup().Version != "" }, testutil.WaitShort, testutil.IntervalFast) homeDir, err := os.UserHomeDir() require.NoError(t, err) - require.Equal(t, homeDir, client.getStartup().ExpandedDirectory) + require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory) }) } @@ -1532,7 +1531,7 @@ func TestAgent_Speedtest(t *testing.T) { t.Skip("This test is relatively flakey because of Tailscale's speedtest code...") ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - derpMap := tailnettest.RunDERPAndSTUN(t) + derpMap, _ := tailnettest.RunDERPAndSTUN(t) //nolint:dogsled conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{ DERPMap: derpMap, @@ -1552,17 +1551,16 @@ func TestAgent_Reconnect(t *testing.T) { defer coordinator.Close() agentID := uuid.New() - statsCh := make(chan *agentsdk.Stats) - derpMap := tailnettest.RunDERPAndSTUN(t) - client := &client{ - t: t, - agentID: agentID, - manifest: agentsdk.Manifest{ + statsCh := make(chan *agentsdk.Stats, 50) + derpMap, _ := tailnettest.RunDERPAndSTUN(t) + client := agenttest.NewClient(t, + agentID, + agentsdk.Manifest{ DERPMap: derpMap, }, - statsChan: statsCh, - coordinator: coordinator, - } + statsCh, + coordinator, + ) initialized := atomic.Int32{} closer := agent.New(agent.Options{ ExchangeToken: func(ctx context.Context) (string, error) { @@ -1577,7 +1575,7 @@ func TestAgent_Reconnect(t *testing.T) { require.Eventually(t, func() bool { return coordinator.Node(agentID) != nil }, testutil.WaitShort, testutil.IntervalFast) - client.lastWorkspaceAgent() + client.LastWorkspaceAgent() require.Eventually(t, func() bool { return initialized.Load() == 2 }, testutil.WaitShort, testutil.IntervalFast) @@ -1589,16 +1587,15 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) { coordinator := tailnet.NewCoordinator(logger) defer coordinator.Close() - client := &client{ - t: t, - agentID: uuid.New(), - manifest: agentsdk.Manifest{ + client := agenttest.NewClient(t, + uuid.New(), + agentsdk.Manifest{ GitAuthConfigs: 1, DERPMap: &tailcfg.DERPMap{}, }, - statsChan: make(chan 
*agentsdk.Stats), - coordinator: coordinator, - } + make(chan *agentsdk.Stats, 50), + coordinator, + ) filesystem := afero.NewMemMapFs() closer := agent.New(agent.Options{ ExchangeToken: func(ctx context.Context) (string, error) { @@ -1683,22 +1680,16 @@ func setupSSHSession(t *testing.T, options agentsdk.Manifest) *ssh.Session { return session } -type closeFunc func() error - -func (c closeFunc) Close() error { - return c() -} - func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(agent.Options) agent.Options) ( *codersdk.WorkspaceAgentConn, - *client, + *agenttest.Client, <-chan *agentsdk.Stats, afero.Fs, io.Closer, ) { logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) if metadata.DERPMap == nil { - metadata.DERPMap = tailnettest.RunDERPAndSTUN(t) + metadata.DERPMap, _ = tailnettest.RunDERPAndSTUN(t) } coordinator := tailnet.NewCoordinator(logger) t.Cleanup(func() { @@ -1707,13 +1698,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati agentID := uuid.New() statsCh := make(chan *agentsdk.Stats, 50) fs := afero.NewMemMapFs() - c := &client{ - t: t, - agentID: agentID, - manifest: metadata, - statsChan: statsCh, - coordinator: coordinator, - } + c := agenttest.NewClient(t, agentID, metadata, statsCh, coordinator) options := agent.Options{ Client: c, @@ -1752,9 +1737,16 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati return conn.UpdateNodes(node, false) }) conn.SetNodeCallback(sendNode) - agentConn := &codersdk.WorkspaceAgentConn{ - Conn: conn, - } + agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ + AgentID: agentID, + GetNode: func(agentID uuid.UUID) (*tailnet.Node, error) { + node := coordinator.Node(agentID) + if node == nil { + return nil, xerrors.Errorf("node not found %q", err) + } + return node, nil + }, + }) t.Cleanup(func() { _ = agentConn.Close() }) @@ -1799,136 +1791,6 @@ func assertWritePayload(t *testing.T, w io.Writer, payload []byte) { assert.Equal(t, len(payload), n, "payload length does not match") } -type client struct { - t *testing.T - agentID uuid.UUID - manifest agentsdk.Manifest - metadata map[string]agentsdk.PostMetadataRequest - statsChan chan *agentsdk.Stats - coordinator tailnet.Coordinator - lastWorkspaceAgent func() - patchWorkspaceLogs func() error - - mu sync.Mutex // Protects following. - lifecycleStates []codersdk.WorkspaceAgentLifecycle - startup agentsdk.PostStartupRequest - logs []agentsdk.StartupLog -} - -func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) { - return c.manifest, nil -} - -func (c *client) Listen(_ context.Context) (net.Conn, error) { - clientConn, serverConn := net.Pipe() - closed := make(chan struct{}) - c.lastWorkspaceAgent = func() { - _ = serverConn.Close() - _ = clientConn.Close() - <-closed - } - c.t.Cleanup(c.lastWorkspaceAgent) - go func() { - _ = c.coordinator.ServeAgent(serverConn, c.agentID, "") - close(closed) - }() - return clientConn, nil -} - -func (c *client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) { - doneCh := make(chan struct{}) - ctx, cancel := context.WithCancel(ctx) - - go func() { - defer close(doneCh) - - setInterval(500 * time.Millisecond) - for { - select { - case <-ctx.Done(): - return - case stat := <-statsChan: - select { - case c.statsChan <- stat: - case <-ctx.Done(): - return - default: - // We don't want to send old stats. 
- continue - } - } - } - }() - return closeFunc(func() error { - cancel() - <-doneCh - close(c.statsChan) - return nil - }), nil -} - -func (c *client) getLifecycleStates() []codersdk.WorkspaceAgentLifecycle { - c.mu.Lock() - defer c.mu.Unlock() - return c.lifecycleStates -} - -func (c *client) PostLifecycle(_ context.Context, req agentsdk.PostLifecycleRequest) error { - c.mu.Lock() - defer c.mu.Unlock() - c.lifecycleStates = append(c.lifecycleStates, req.State) - return nil -} - -func (*client) PostAppHealth(_ context.Context, _ agentsdk.PostAppHealthsRequest) error { - return nil -} - -func (c *client) getStartup() agentsdk.PostStartupRequest { - c.mu.Lock() - defer c.mu.Unlock() - return c.startup -} - -func (c *client) getMetadata() map[string]agentsdk.PostMetadataRequest { - c.mu.Lock() - defer c.mu.Unlock() - return maps.Clone(c.metadata) -} - -func (c *client) PostMetadata(_ context.Context, key string, req agentsdk.PostMetadataRequest) error { - c.mu.Lock() - defer c.mu.Unlock() - if c.metadata == nil { - c.metadata = make(map[string]agentsdk.PostMetadataRequest) - } - c.metadata[key] = req - return nil -} - -func (c *client) PostStartup(_ context.Context, startup agentsdk.PostStartupRequest) error { - c.mu.Lock() - defer c.mu.Unlock() - c.startup = startup - return nil -} - -func (c *client) getStartupLogs() []agentsdk.StartupLog { - c.mu.Lock() - defer c.mu.Unlock() - return c.logs -} - -func (c *client) PatchStartupLogs(_ context.Context, logs agentsdk.PatchStartupLogs) error { - c.mu.Lock() - defer c.mu.Unlock() - if c.patchWorkspaceLogs != nil { - return c.patchWorkspaceLogs() - } - c.logs = append(c.logs, logs.Logs...) - return nil -} - // tempDirUnixSocket returns a temporary directory that can safely hold unix // sockets (probably). // diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go new file mode 100644 index 0000000000000..c1daa0bfacc9a --- /dev/null +++ b/agent/agenttest/client.go @@ -0,0 +1,169 @@ +package agenttest + +import ( + "context" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "golang.org/x/exp/maps" + + "cdr.dev/slog" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/codersdk/agentsdk" + "github.com/coder/coder/tailnet" +) + +func NewClient(t testing.TB, + agentID uuid.UUID, + manifest agentsdk.Manifest, + statsChan chan *agentsdk.Stats, + coordinator tailnet.Coordinator, +) *Client { + return &Client{ + t: t, + agentID: agentID, + manifest: manifest, + statsChan: statsChan, + coordinator: coordinator, + } +} + +type Client struct { + t testing.TB + agentID uuid.UUID + manifest agentsdk.Manifest + metadata map[string]agentsdk.PostMetadataRequest + statsChan chan *agentsdk.Stats + coordinator tailnet.Coordinator + LastWorkspaceAgent func() + PatchWorkspaceLogs func() error + + mu sync.Mutex // Protects following. 
+ lifecycleStates []codersdk.WorkspaceAgentLifecycle + startup agentsdk.PostStartupRequest + logs []agentsdk.StartupLog +} + +func (c *Client) Manifest(_ context.Context) (agentsdk.Manifest, error) { + return c.manifest, nil +} + +func (c *Client) Listen(_ context.Context) (net.Conn, error) { + clientConn, serverConn := net.Pipe() + closed := make(chan struct{}) + c.LastWorkspaceAgent = func() { + _ = serverConn.Close() + _ = clientConn.Close() + <-closed + } + c.t.Cleanup(c.LastWorkspaceAgent) + go func() { + _ = c.coordinator.ServeAgent(serverConn, c.agentID, "") + close(closed) + }() + return clientConn, nil +} + +func (c *Client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) { + doneCh := make(chan struct{}) + ctx, cancel := context.WithCancel(ctx) + + go func() { + defer close(doneCh) + + setInterval(500 * time.Millisecond) + for { + select { + case <-ctx.Done(): + return + case stat := <-statsChan: + select { + case c.statsChan <- stat: + case <-ctx.Done(): + return + default: + // We don't want to send old stats. + continue + } + } + } + }() + return closeFunc(func() error { + cancel() + <-doneCh + close(c.statsChan) + return nil + }), nil +} + +func (c *Client) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle { + c.mu.Lock() + defer c.mu.Unlock() + return c.lifecycleStates +} + +func (c *Client) PostLifecycle(_ context.Context, req agentsdk.PostLifecycleRequest) error { + c.mu.Lock() + defer c.mu.Unlock() + c.lifecycleStates = append(c.lifecycleStates, req.State) + return nil +} + +func (*Client) PostAppHealth(_ context.Context, _ agentsdk.PostAppHealthsRequest) error { + return nil +} + +func (c *Client) GetStartup() agentsdk.PostStartupRequest { + c.mu.Lock() + defer c.mu.Unlock() + return c.startup +} + +func (c *Client) GetMetadata() map[string]agentsdk.PostMetadataRequest { + c.mu.Lock() + defer c.mu.Unlock() + return maps.Clone(c.metadata) +} + +func (c *Client) PostMetadata(_ context.Context, key string, req agentsdk.PostMetadataRequest) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.metadata == nil { + c.metadata = make(map[string]agentsdk.PostMetadataRequest) + } + c.metadata[key] = req + return nil +} + +func (c *Client) PostStartup(_ context.Context, startup agentsdk.PostStartupRequest) error { + c.mu.Lock() + defer c.mu.Unlock() + c.startup = startup + return nil +} + +func (c *Client) GetStartupLogs() []agentsdk.StartupLog { + c.mu.Lock() + defer c.mu.Unlock() + return c.logs +} + +func (c *Client) PatchStartupLogs(_ context.Context, logs agentsdk.PatchStartupLogs) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.PatchWorkspaceLogs != nil { + return c.PatchWorkspaceLogs() + } + c.logs = append(c.logs, logs.Logs...) 
+ return nil +} + +type closeFunc func() error + +func (c closeFunc) Close() error { + return c() +} diff --git a/coderd/coderd.go b/coderd/coderd.go index 41ddcf4bbda58..43a77676b2688 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -351,8 +351,17 @@ func New(options *Options) *API { } api.Auditor.Store(&options.Auditor) - api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0) api.TailnetCoordinator.Store(&options.TailnetCoordinator) + api.tailnet, err = NewServerTailnet(api.ctx, + options.Logger, + options.DERPServer, + options.DERPMap, + &api.TailnetCoordinator, + wsconncache.New(api._dialWorkspaceAgentTailnet, 0), + ) + if err != nil { + panic("failed to setup server tailnet: " + err.Error()) + } api.workspaceAppServer = &workspaceapps.Server{ Logger: options.Logger.Named("workspaceapps"), @@ -364,7 +373,7 @@ func New(options *Options) *API { RealIPConfig: options.RealIPConfig, SignedTokenProvider: api.WorkspaceAppsProvider, - WorkspaceConnCache: api.workspaceAgentCache, + AgentProvider: api.tailnet, AppSecurityKey: options.AppSecurityKey, DisablePathApps: options.DeploymentValues.DisablePathApps.Value(), @@ -874,10 +883,10 @@ type API struct { derpCloseFunc func() metricsCache *metricscache.Cache - workspaceAgentCache *wsconncache.Cache updateChecker *updatecheck.Checker WorkspaceAppsProvider workspaceapps.SignedTokenProvider workspaceAppServer *workspaceapps.Server + tailnet *ServerTailnet // Experiments contains the list of experiments currently enabled. // This is used to gate features that are not yet ready for production. @@ -904,7 +913,8 @@ func (api *API) Close() error { if coordinator != nil { _ = (*coordinator).Close() } - return api.workspaceAgentCache.Close() + _ = api.tailnet.Close() + return nil } func compressHandler(h http.Handler) http.Handler { diff --git a/coderd/prometheusmetrics/prometheusmetrics_test.go b/coderd/prometheusmetrics/prometheusmetrics_test.go index 9101288cca570..d05d84cca71e0 100644 --- a/coderd/prometheusmetrics/prometheusmetrics_test.go +++ b/coderd/prometheusmetrics/prometheusmetrics_test.go @@ -302,7 +302,7 @@ func TestAgents(t *testing.T) { coordinator := tailnet.NewCoordinator(slogtest.Make(t, nil).Leveled(slog.LevelDebug)) coordinatorPtr := atomic.Pointer[tailnet.Coordinator]{} coordinatorPtr.Store(&coordinator) - derpMap := tailnettest.RunDERPAndSTUN(t) + derpMap, _ := tailnettest.RunDERPAndSTUN(t) agentInactiveDisconnectTimeout := 1 * time.Hour // don't need to focus on this value in tests registry := prometheus.NewRegistry() diff --git a/coderd/tailnet.go b/coderd/tailnet.go new file mode 100644 index 0000000000000..acfbd4a7fadee --- /dev/null +++ b/coderd/tailnet.go @@ -0,0 +1,330 @@ +package coderd + +import ( + "bufio" + "context" + "net" + "net/http" + "net/http/httputil" + "net/netip" + "net/url" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" + "tailscale.com/derp" + "tailscale.com/tailcfg" + + "cdr.dev/slog" + "github.com/coder/coder/coderd/wsconncache" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/site" + "github.com/coder/coder/tailnet" +) + +var defaultTransport *http.Transport + +func init() { + var valid bool + defaultTransport, valid = http.DefaultTransport.(*http.Transport) + if !valid { + panic("dev error: default transport is the wrong type") + } +} + +// TODO: ServerTailnet does not currently remove stale peers. + +// NewServerTailnet creates a new tailnet intended for use by coderd. 
It +// automatically falls back to wsconncache if a legacy agent is encountered. +func NewServerTailnet( + ctx context.Context, + logger slog.Logger, + derpServer *derp.Server, + derpMap *tailcfg.DERPMap, + coord *atomic.Pointer[tailnet.Coordinator], + cache *wsconncache.Cache, +) (*ServerTailnet, error) { + conn, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, + DERPMap: derpMap, + Logger: logger.Named("tailnet"), + }) + if err != nil { + return nil, xerrors.Errorf("create tailnet conn: %w", err) + } + + tn := &ServerTailnet{ + logger: logger, + conn: conn, + coordinator: coord, + cache: cache, + agentNodes: map[uuid.UUID]*tailnetNode{}, + transport: defaultTransport.Clone(), + } + tn.transport.DialContext = tn.dialContext + tn.transport.MaxIdleConnsPerHost = 10 + tn.transport.MaxIdleConns = 0 + + conn.SetNodeCallback(func(node *tailnet.Node) { + tn.nodesMu.Lock() + ids := make([]uuid.UUID, 0, len(tn.agentNodes)) + for id := range tn.agentNodes { + ids = append(ids, id) + } + tn.nodesMu.Unlock() + + err := (*tn.coordinator.Load()).BroadcastToAgents(ids, node) + if err != nil { + tn.logger.Error(context.Background(), "broadcast server node to agents", slog.Error(err)) + } + }) + + // This is set to allow local DERP traffic to be proxied through memory + // instead of needing to hit the external access URL. Don't use the ctx + // given in this callback, it's only valid while connecting. + conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn { + if !region.EmbeddedRelay { + return nil + } + left, right := net.Pipe() + go func() { + defer left.Close() + defer right.Close() + brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right)) + derpServer.Accept(ctx, right, brw, "internal") + }() + return left + }) + + return tn, nil +} + +type tailnetNode struct { + node *tailnet.Node + lastConnection time.Time + stop func() +} + +type ServerTailnet struct { + logger slog.Logger + conn *tailnet.Conn + coordinator *atomic.Pointer[tailnet.Coordinator] + cache *wsconncache.Cache + nodesMu sync.Mutex + // agentNodes is a map of agent tailnetNodes the server wants to keep a + // connection to. + agentNodes map[uuid.UUID]*tailnetNode + + transport *http.Transport +} + +func (s *ServerTailnet) updateNode(id uuid.UUID, node *tailnet.Node) { + s.nodesMu.Lock() + cached, ok := s.agentNodes[id] + if ok { + cached.node = node + } + s.nodesMu.Unlock() + + if ok { + err := s.conn.UpdateNodes([]*tailnet.Node{node}, false) + if err != nil { + s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err)) + return + } + } +} + +func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (*httputil.ReverseProxy, func(), error) { + proxy := httputil.NewSingleHostReverseProxy(targetURL) + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + site.RenderStaticErrorPage(w, r, site.ErrorPageData{ + Status: http.StatusBadGateway, + Title: "Bad Gateway", + Description: "Failed to proxy request to application: " + err.Error(), + RetryEnabled: true, + DashboardURL: dashboardURL.String(), + }) + } + proxy.Director = s.director(agentID, proxy.Director) + proxy.Transport = s.transport + + return proxy, func() {}, nil +} + +type agentIDKey struct{} + +// director makes sure agentIDKey is set on the context in the reverse proxy. +// This allows the transport to correctly identify which agent to dial to. 
+func (*ServerTailnet) director(agentID uuid.UUID, prev func(req *http.Request)) func(req *http.Request) { + return func(req *http.Request) { + ctx := context.WithValue(req.Context(), agentIDKey{}, agentID) + *req = *req.WithContext(ctx) + prev(req) + } +} + +func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (net.Conn, error) { + agentID, ok := ctx.Value(agentIDKey{}).(uuid.UUID) + if !ok { + return nil, xerrors.Errorf("no agent id attached") + } + + return s.DialAgentNetConn(ctx, agentID, network, addr) +} + +func (s *ServerTailnet) getNode(agentID uuid.UUID) (*tailnet.Node, error) { + s.nodesMu.Lock() + tnode, ok := s.agentNodes[agentID] + // If we don't have the node, fetch it from the coordinator. + if !ok { + coord := *s.coordinator.Load() + node := coord.Node(agentID) + // The coordinator doesn't have the node either. Nothing we can do here. + if node == nil { + s.nodesMu.Unlock() + return nil, xerrors.Errorf("node %q not found", agentID.String()) + } + stop := coord.SubscribeAgent(agentID, s.updateNode) + tnode = &tailnetNode{ + node: node, + lastConnection: time.Now(), + stop: stop, + } + s.agentNodes[agentID] = tnode + + err := coord.BroadcastToAgents([]uuid.UUID{agentID}, s.conn.Node()) + if err != nil { + s.logger.Debug(context.Background(), "broadcast server node to agents", slog.Error(err)) + } + } + s.nodesMu.Unlock() + + if len(tnode.node.Addresses) == 0 { + return nil, xerrors.New("agent has no reachable addresses") + } + + // if we didn't already have the node locally, add it to our tailnet. + if !ok { + err := s.conn.UpdateNodes([]*tailnet.Node{tnode.node}, false) + if err != nil { + return nil, xerrors.Errorf("update nodes: %w", err) + } + } + + return tnode.node, nil +} + +func (s *ServerTailnet) awaitNodeExists(ctx context.Context, id uuid.UUID, timeout time.Duration) (*tailnet.Node, error) { + // Short circuit, if the node already exists, don't spend time setting up + // the ticker and loop. + if node, err := s.getNode(id); err == nil { + return node, nil + } + + var ( + ticker = time.NewTicker(10 * time.Millisecond) + + tries int + node *tailnet.Node + err error + ) + defer ticker.Stop() + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + for { + select { + case <-ctx.Done(): + // return the last error we got from getNode. 
+ return nil, xerrors.Errorf("tries %d, last error: %w", tries, err) + case <-ticker.C: + } + + tries++ + node, err = s.getNode(id) + if err == nil { + return node, nil + } + } +} + +func (*ServerTailnet) nodeIsLegacy(node *tailnet.Node) bool { + return node.Addresses[0].Addr() == codersdk.WorkspaceAgentIP +} + +func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) { + node, err := s.awaitNodeExists(ctx, agentID, 5*time.Second) + if err != nil { + return nil, nil, xerrors.Errorf("get agent node: %w", err) + } + + var ( + conn *codersdk.WorkspaceAgentConn + ret = func() {} + ) + + if s.nodeIsLegacy(node) { + cconn, release, err := s.cache.Acquire(agentID) + if err != nil { + return nil, nil, xerrors.Errorf("acquire legacy agent conn: %w", err) + } + + conn = cconn.WorkspaceAgentConn + ret = release + } else { + conn = codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{ + AgentID: agentID, + GetNode: s.getNode, + CloseFunc: func() error { return codersdk.ErrSkipClose }, + }) + } + + reachable := conn.AwaitReachable(ctx) + if !reachable { + return nil, nil, xerrors.New("agent is unreachable") + } + + return conn, ret, nil +} + +func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, network, addr string) (net.Conn, error) { + conn, release, err := s.AgentConn(ctx, agentID) + if err != nil { + return nil, xerrors.Errorf("acquire agent conn: %w", err) + } + defer release() + + reachable := conn.AwaitReachable(ctx) + if !reachable { + return nil, xerrors.New("agent is unreachable") + } + + node, err := s.getNode(agentID) + if err != nil { + return nil, xerrors.New("get agent node") + } + + _, rawPort, _ := net.SplitHostPort(addr) + port, _ := strconv.ParseUint(rawPort, 10, 16) + ipp := netip.AddrPortFrom(node.Addresses[0].Addr(), uint16(port)) + + if network == "tcp" { + return conn.DialContextTCP(ctx, ipp) + } else if network == "udp" { + return conn.DialContextUDP(ctx, ipp) + } else { + return nil, xerrors.Errorf("unknown network %q", network) + } +} + +func (s *ServerTailnet) Close() error { + _ = s.cache.Close() + _ = s.conn.Close() + s.transport.CloseIdleConnections() + return nil +} diff --git a/coderd/tailnet_test.go b/coderd/tailnet_test.go new file mode 100644 index 0000000000000..09033c8c597eb --- /dev/null +++ b/coderd/tailnet_test.go @@ -0,0 +1,209 @@ +package coderd_test + +import ( + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "net/netip" + "net/url" + "sync/atomic" + "testing" + + "github.com/google/uuid" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/agent" + "github.com/coder/coder/agent/agenttest" + "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/wsconncache" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/codersdk/agentsdk" + "github.com/coder/coder/tailnet" + "github.com/coder/coder/tailnet/tailnettest" + "github.com/coder/coder/testutil" +) + +func TestServerTailnet_AgentConn_OK(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + // Connect through the ServerTailnet + agentID, _, serverTailnet := setupAgent(t, nil) + + conn, release, err := serverTailnet.AgentConn(ctx, agentID) + require.NoError(t, err) + defer release() + + assert.True(t, conn.AwaitReachable(ctx)) +} + 
+func TestServerTailnet_AgentConn_Legacy(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + // Force a connection through wsconncache using the legacy hardcoded ip. + agentID, _, serverTailnet := setupAgent(t, []netip.Prefix{ + netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128), + }) + + conn, release, err := serverTailnet.AgentConn(ctx, agentID) + require.NoError(t, err) + defer release() + + assert.True(t, conn.AwaitReachable(ctx)) +} + +func TestServerTailnet_ReverseProxy_OK(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + // Force a connection through wsconncache using the legacy hardcoded ip. + agentID, _, serverTailnet := setupAgent(t, nil) + + u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort)) + require.NoError(t, err) + + rp, release, err := serverTailnet.ReverseProxy(u, u, agentID) + require.NoError(t, err) + defer release() + + rw := httptest.NewRecorder() + req := httptest.NewRequest( + http.MethodGet, + u.String(), + nil, + ).WithContext(ctx) + + rp.ServeHTTP(rw, req) + res := rw.Result() + defer res.Body.Close() + + assert.Equal(t, http.StatusOK, res.StatusCode) +} + +func TestServerTailnet_ReverseProxy_Legacy(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + // Force a connection through wsconncache using the legacy hardcoded ip. + agentID, _, serverTailnet := setupAgent(t, []netip.Prefix{ + netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128), + }) + + u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort)) + require.NoError(t, err) + + rp, release, err := serverTailnet.ReverseProxy(u, u, agentID) + require.NoError(t, err) + defer release() + + rw := httptest.NewRecorder() + req := httptest.NewRequest( + http.MethodGet, + u.String(), + nil, + ).WithContext(ctx) + + rp.ServeHTTP(rw, req) + res := rw.Result() + defer res.Body.Close() + + assert.Equal(t, http.StatusOK, res.StatusCode) +} + +func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.Agent, *coderd.ServerTailnet) { + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + derpMap, derpServer := tailnettest.RunDERPAndSTUN(t) + manifest := agentsdk.Manifest{ + DERPMap: derpMap, + } + + var coordPtr atomic.Pointer[tailnet.Coordinator] + coordinator := tailnet.NewCoordinator(logger) + coordPtr.Store(&coordinator) + t.Cleanup(func() { + _ = coordinator.Close() + }) + + agentID := uuid.New() + c := agenttest.NewClient(t, agentID, manifest, make(chan *agentsdk.Stats, 50), coordinator) + + options := agent.Options{ + Client: c, + Filesystem: afero.NewMemMapFs(), + Logger: logger.Named("agent"), + Addresses: agentAddresses, + } + + ag := agent.New(options) + t.Cleanup(func() { + _ = ag.Close() + }) + + cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { + conn, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, + DERPMap: manifest.DERPMap, + Logger: logger.Named("client"), + }) + require.NoError(t, err) + clientConn, serverConn := net.Pipe() + serveClientDone := make(chan struct{}) + t.Cleanup(func() { + _ = clientConn.Close() + _ = serverConn.Close() + _ = conn.Close() + <-serveClientDone + }) + go func() { + defer close(serveClientDone) + coordinator.ServeClient(serverConn, uuid.New(), 
agentID) + }() + sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { + return conn.UpdateNodes(node, false) + }) + conn.SetNodeCallback(sendNode) + return codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ + AgentID: agentID, + GetNode: func(agentID uuid.UUID) (*tailnet.Node, error) { + node := coordinator.Node(agentID) + if node == nil { + return nil, xerrors.Errorf("node not found %q", err) + } + return node, nil + }, + CloseFunc: func() error { return codersdk.ErrSkipClose }, + }), nil + }, 0) + + serverTailnet, err := coderd.NewServerTailnet( + context.Background(), + logger.Named("server"), + derpServer, + manifest.DERPMap, + &coordPtr, + cache, + ) + require.NoError(t, err) + + t.Cleanup(func() { + _ = serverTailnet.Close() + }) + + return agentID, ag, serverTailnet +} diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 25c7861ce4ae4..aa10fffb312d3 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -654,7 +654,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req return } - agentConn, release, err := api.workspaceAgentCache.Acquire(workspaceAgent.ID) + agentConn, release, err := api.tailnet.AgentConn(ctx, workspaceAgent.ID) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error dialing workspace agent.", @@ -729,7 +729,8 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req httpapi.Write(ctx, rw, http.StatusOK, portsResponse) } -func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { +// Deprecated: use api.tailnet.AgentConn instead. +func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { clientConn, serverConn := net.Pipe() conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, @@ -765,14 +766,22 @@ func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspac return nil }) conn.SetNodeCallback(sendNodes) - agentConn := &codersdk.WorkspaceAgentConn{ - Conn: conn, - CloseFunc: func() { + agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ + AgentID: agentID, + GetNode: func(agentID uuid.UUID) (*tailnet.Node, error) { + return &tailnet.Node{ + // Since this is a legacy function only used by wsconncache as a + // fallback, we hardcode the node to use the wsconncache IP. 
+ Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, + }, nil + }, + CloseFunc: func() error { cancel() _ = clientConn.Close() _ = serverConn.Close() + return nil }, - } + }) go func() { err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID) if err != nil { diff --git a/coderd/workspaceapps/apptest/setup.go b/coderd/workspaceapps/apptest/setup.go index 9432e09c9703d..0f0167f37bf79 100644 --- a/coderd/workspaceapps/apptest/setup.go +++ b/coderd/workspaceapps/apptest/setup.go @@ -399,7 +399,8 @@ func doWithRetries(t require.TestingT, client *codersdk.Client, req *http.Reques return resp, err } -func requestWithRetries(ctx context.Context, t require.TestingT, client *codersdk.Client, method, urlOrPath string, body interface{}, opts ...codersdk.RequestOption) (*http.Response, error) { +func requestWithRetries(ctx context.Context, t testing.TB, client *codersdk.Client, method, urlOrPath string, body interface{}, opts ...codersdk.RequestOption) (*http.Response, error) { + t.Helper() var resp *http.Response var err error require.Eventually(t, func() bool { diff --git a/coderd/workspaceapps/proxy.go b/coderd/workspaceapps/proxy.go index 1d3e8592d7a1c..05dbed2e4f20a 100644 --- a/coderd/workspaceapps/proxy.go +++ b/coderd/workspaceapps/proxy.go @@ -23,7 +23,6 @@ import ( "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/tracing" "github.com/coder/coder/coderd/util/slice" - "github.com/coder/coder/coderd/wsconncache" "github.com/coder/coder/codersdk" "github.com/coder/coder/site" ) @@ -61,6 +60,15 @@ var nonCanonicalHeaders = map[string]string{ "Sec-Websocket-Version": "Sec-WebSocket-Version", } +type AgentProvider interface { + // TODO: after wsconncache is deleted this doesn't need to return a release + // func. + AgentConn(ctx context.Context, agentID uuid.UUID) (_ *codersdk.WorkspaceAgentConn, release func(), _ error) + // TODO: after wsconncache is deleted this doesn't need to return an error. + ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error) + Close() error +} + // Server serves workspace apps endpoints, including: // - Path-based apps // - Subdomain app middleware @@ -83,7 +91,6 @@ type Server struct { RealIPConfig *httpmw.RealIPConfig SignedTokenProvider SignedTokenProvider - WorkspaceConnCache *wsconncache.Cache AppSecurityKey SecurityKey // DisablePathApps disables path-based apps. This is a security feature as path @@ -95,6 +102,8 @@ type Server struct { DisablePathApps bool SecureAuthCookie bool + AgentProvider AgentProvider + websocketWaitMutex sync.Mutex websocketWaitGroup sync.WaitGroup } @@ -106,8 +115,8 @@ func (s *Server) Close() error { s.websocketWaitGroup.Wait() s.websocketWaitMutex.Unlock() - // The caller must close the SignedTokenProvider (if necessary) and the - // wsconncache. + // The caller must close the SignedTokenProvider and the AgentProvider (if + // necessary). 
return nil } @@ -517,18 +526,7 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT r.URL.Path = path appURL.RawQuery = "" - proxy := httputil.NewSingleHostReverseProxy(appURL) - proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { - site.RenderStaticErrorPage(rw, r, site.ErrorPageData{ - Status: http.StatusBadGateway, - Title: "Bad Gateway", - Description: "Failed to proxy request to application: " + err.Error(), - RetryEnabled: true, - DashboardURL: s.DashboardURL.String(), - }) - } - - conn, release, err := s.WorkspaceConnCache.Acquire(appToken.AgentID) + proxy, release, err := s.AgentProvider.ReverseProxy(appURL, s.DashboardURL, appToken.AgentID) if err != nil { site.RenderStaticErrorPage(rw, r, site.ErrorPageData{ Status: http.StatusBadGateway, @@ -540,7 +538,6 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT return } defer release() - proxy.Transport = conn.HTTPTransport() proxy.ModifyResponse = func(r *http.Response) error { r.Header.Del(httpmw.AccessControlAllowOriginHeader) @@ -658,13 +655,14 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { go httpapi.Heartbeat(ctx, conn) - agentConn, release, err := s.WorkspaceConnCache.Acquire(appToken.AgentID) + agentConn, release, err := s.AgentProvider.AgentConn(ctx, appToken.AgentID) if err != nil { log.Debug(ctx, "dial workspace agent", slog.Error(err)) _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err)) return } defer release() + defer agentConn.Close() log.Debug(ctx, "dialed workspace agent") ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command")) if err != nil { diff --git a/coderd/wsconncache/wsconncache.go b/coderd/wsconncache/wsconncache.go index 19c7f65f9fb74..917ab34a5bf13 100644 --- a/coderd/wsconncache/wsconncache.go +++ b/coderd/wsconncache/wsconncache.go @@ -1,9 +1,12 @@ // Package wsconncache caches workspace agent connections by UUID. +// DEPRECATED package wsconncache import ( "context" "net/http" + "net/http/httputil" + "net/url" "sync" "time" @@ -13,13 +16,56 @@ import ( "golang.org/x/xerrors" "github.com/coder/coder/codersdk" + "github.com/coder/coder/site" ) -// New creates a new workspace connection cache that closes -// connections after the inactive timeout provided. 
+type AgentProvider struct { + Cache *Cache +} + +func (a *AgentProvider) AgentConn(_ context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) { + conn, rel, err := a.Cache.Acquire(agentID) + if err != nil { + return nil, nil, xerrors.Errorf("acquire agent connection: %w", err) + } + + return conn.WorkspaceAgentConn, rel, nil +} + +func (a *AgentProvider) ReverseProxy(targetURL *url.URL, dashboardURL *url.URL, agentID uuid.UUID) (*httputil.ReverseProxy, func(), error) { + proxy := httputil.NewSingleHostReverseProxy(targetURL) + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + site.RenderStaticErrorPage(w, r, site.ErrorPageData{ + Status: http.StatusBadGateway, + Title: "Bad Gateway", + Description: "Failed to proxy request to application: " + err.Error(), + RetryEnabled: true, + DashboardURL: dashboardURL.String(), + }) + } + + conn, release, err := a.Cache.Acquire(agentID) + if err != nil { + return nil, nil, xerrors.Errorf("acquire agent connection: %w", err) + } + + proxy.Transport = conn.HTTPTransport() + + return proxy, release, nil +} + +func (a *AgentProvider) Close() error { + return a.Cache.Close() +} + +// New creates a new workspace connection cache that closes connections after +// the inactive timeout provided. +// +// Agent connections are cached due to Wireguard negotiation taking a few +// hundred milliseconds, depending on latency. // -// Agent connections are cached due to WebRTC negotiation -// taking a few hundred milliseconds. +// Deprecated: Use coderd.NewServerTailnet instead. wsconncache is being phased +// out because it creates a unique Tailnet for each agent. func New(dialer Dialer, inactiveTimeout time.Duration) *Cache { if inactiveTimeout == 0 { inactiveTimeout = 5 * time.Minute diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index 6fdecbcf7bf3f..30a6892446d82 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/atomic" "go.uber.org/goleak" + "golang.org/x/xerrors" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" @@ -157,7 +158,7 @@ func TestCache(t *testing.T) { func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn { t.Helper() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - manifest.DERPMap = tailnettest.RunDERPAndSTUN(t) + manifest.DERPMap, _ = tailnettest.RunDERPAndSTUN(t) coordinator := tailnet.NewCoordinator(logger) t.Cleanup(func() { @@ -194,9 +195,16 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati return conn.UpdateNodes(node, false) }) conn.SetNodeCallback(sendNode) - agentConn := &codersdk.WorkspaceAgentConn{ - Conn: conn, - } + agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ + AgentID: agentID, + GetNode: func(agentID uuid.UUID) (*tailnet.Node, error) { + node := coordinator.Node(agentID) + if node == nil { + return nil, xerrors.Errorf("node not found %q", err) + } + return node, nil + }, + }) t.Cleanup(func() { _ = agentConn.Close() }) diff --git a/codersdk/workspaceagentconn.go b/codersdk/workspaceagentconn.go index 64bd4fe2f8bfa..8bb1b2e41414f 100644 --- a/codersdk/workspaceagentconn.go +++ b/codersdk/workspaceagentconn.go @@ -15,6 +15,7 @@ import ( "time" "github.com/google/uuid" + "github.com/hashicorp/go-multierror" "golang.org/x/crypto/ssh" "golang.org/x/xerrors" 
"tailscale.com/ipn/ipnstate" @@ -27,8 +28,13 @@ import ( // WorkspaceAgentIP is a static IPv6 address with the Tailscale prefix that is used to route // connections from clients to this node. A dynamic address is not required because a Tailnet // client only dials a single agent at a time. +// +// Deprecated: use tailnet.IP() instead. This is kept for backwards +// compatibility with wsconncache. var WorkspaceAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4") +var ErrSkipClose = xerrors.New("skip tailnet close") + const ( WorkspaceAgentSSHPort = tailnet.WorkspaceAgentSSHPort WorkspaceAgentReconnectingPTYPort = tailnet.WorkspaceAgentReconnectingPTYPort @@ -120,11 +126,38 @@ func init() { } } +// NewWorkspaceAgentConn creates a new WorkspaceAgentConn. `conn` may be unique +// to the WorkspaceAgentConn, or it may be shared in the case of coderd. If the +// conn is shared and closing it is undesirable, you may return ErrNoClose from +// opts.CloseFunc. This will ensure the underlying conn is not closed. +func NewWorkspaceAgentConn(conn *tailnet.Conn, opts WorkspaceAgentConnOptions) *WorkspaceAgentConn { + return &WorkspaceAgentConn{ + Conn: conn, + opts: opts, + } +} + // WorkspaceAgentConn represents a connection to a workspace agent. // @typescript-ignore WorkspaceAgentConn type WorkspaceAgentConn struct { *tailnet.Conn - CloseFunc func() + opts WorkspaceAgentConnOptions +} + +// @typescript-ignore WorkspaceAgentConnOptions +type WorkspaceAgentConnOptions struct { + AgentID uuid.UUID + GetNode func(agentID uuid.UUID) (*tailnet.Node, error) + CloseFunc func() error +} + +func (c *WorkspaceAgentConn) getAgentAddress() (netip.Addr, error) { + node, err := c.opts.GetNode(c.opts.AgentID) + if err != nil { + return netip.Addr{}, err + } + + return node.Addresses[0].Addr(), nil } // AwaitReachable waits for the agent to be reachable. @@ -132,7 +165,25 @@ func (c *WorkspaceAgentConn) AwaitReachable(ctx context.Context) bool { ctx, span := tracing.StartSpan(ctx) defer span.End() - return c.Conn.AwaitReachable(ctx, WorkspaceAgentIP) + var ( + addr netip.Addr + err error + ) + for { + addr, err = c.getAgentAddress() + if err == nil { + break + } + + select { + case <-ctx.Done(): + return false + case <-time.After(10 * time.Millisecond): + continue + } + } + + return c.Conn.AwaitReachable(ctx, addr) } // Ping pings the agent and returns the round-trip time. @@ -141,13 +192,25 @@ func (c *WorkspaceAgentConn) Ping(ctx context.Context) (time.Duration, bool, *ip ctx, span := tracing.StartSpan(ctx) defer span.End() - return c.Conn.Ping(ctx, WorkspaceAgentIP) + addr, err := c.getAgentAddress() + if err != nil { + return 0, false, nil, err + } + + return c.Conn.Ping(ctx, addr) } // Close ends the connection to the workspace agent. 
func (c *WorkspaceAgentConn) Close() error { - if c.CloseFunc != nil { - c.CloseFunc() + var cerr error + if c.opts.CloseFunc != nil { + cerr = c.opts.CloseFunc() + if xerrors.Is(cerr, ErrSkipClose) { + return nil + } + } + if cerr != nil { + return multierror.Append(cerr, c.Conn.Close()) } return c.Conn.Close() } @@ -176,10 +239,17 @@ type ReconnectingPTYRequest struct { func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() + if !c.AwaitReachable(ctx) { return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentReconnectingPTYPort)) + + addr, err := c.getAgentAddress() + if err != nil { + return nil, err + } + + conn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(addr, WorkspaceAgentReconnectingPTYPort)) if err != nil { return nil, err } @@ -209,10 +279,17 @@ func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() + if !c.AwaitReachable(ctx) { return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - return c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSSHPort)) + + addr, err := c.getAgentAddress() + if err != nil { + return nil, err + } + + return c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(addr, WorkspaceAgentSSHPort)) } // SSHClient calls SSH to create a client that uses a weak cipher @@ -220,10 +297,12 @@ func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) { func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() + netConn, err := c.SSH(ctx) if err != nil { return nil, xerrors.Errorf("ssh: %w", err) } + sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{ // SSH host validation isn't helpful, because obtaining a peer // connection already signifies user-intent to dial a workspace. @@ -233,6 +312,7 @@ func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) if err != nil { return nil, xerrors.Errorf("ssh conn: %w", err) } + return ssh.NewClient(sshConn, channels, requests), nil } @@ -240,17 +320,26 @@ func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() + if !c.AwaitReachable(ctx) { return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - speedConn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSpeedtestPort)) + + addr, err := c.getAgentAddress() + if err != nil { + return nil, err + } + + speedConn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(addr, WorkspaceAgentSpeedtestPort)) if err != nil { return nil, xerrors.Errorf("dial speedtest: %w", err) } + results, err := speedtest.RunClientWithConn(direction, duration, speedConn) if err != nil { return nil, xerrors.Errorf("run speedtest: %w", err) } + return results, err } @@ -259,19 +348,27 @@ func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest. 
func (c *WorkspaceAgentConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() - if network == "unix" { - return nil, xerrors.New("network must be tcp or udp") - } - _, rawPort, _ := net.SplitHostPort(addr) - port, _ := strconv.ParseUint(rawPort, 10, 16) - ipp := netip.AddrPortFrom(WorkspaceAgentIP, uint16(port)) + if !c.AwaitReachable(ctx) { return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - if network == "udp" { + + agentAddr, err := c.getAgentAddress() + if err != nil { + return nil, err + } + + _, rawPort, _ := net.SplitHostPort(addr) + port, _ := strconv.ParseUint(rawPort, 10, 16) + ipp := netip.AddrPortFrom(agentAddr, uint16(port)) + + if network == "tcp" { + return c.Conn.DialContextTCP(ctx, ipp) + } else if network == "udp" { return c.Conn.DialContextUDP(ctx, ipp) + } else { + return nil, xerrors.Errorf("unknown network %q", network) } - return c.Conn.DialContextTCP(ctx, ipp) } type WorkspaceAgentListeningPortsResponse struct { @@ -309,7 +406,13 @@ func (c *WorkspaceAgentConn) ListeningPorts(ctx context.Context) (WorkspaceAgent func (c *WorkspaceAgentConn) apiRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() - host := net.JoinHostPort(WorkspaceAgentIP.String(), strconv.Itoa(WorkspaceAgentHTTPAPIServerPort)) + + agentAddr, err := c.getAgentAddress() + if err != nil { + return nil, xerrors.Errorf("get agent address: %w", err) + } + + host := net.JoinHostPort(agentAddr.String(), strconv.Itoa(WorkspaceAgentHTTPAPIServerPort)) url := fmt.Sprintf("http://%s%s", host, path) req, err := http.NewRequestWithContext(ctx, method, url, body) @@ -332,13 +435,14 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client { if network != "tcp" { return nil, xerrors.Errorf("network must be tcp") } + host, port, err := net.SplitHostPort(addr) if err != nil { return nil, xerrors.Errorf("split host port %q: %w", addr, err) } - // Verify that host is TailnetIP and port is - // TailnetStatisticsPort. - if host != WorkspaceAgentIP.String() || port != strconv.Itoa(WorkspaceAgentHTTPAPIServerPort) { + + // Verify that the port is TailnetStatisticsPort. 
+ if port != strconv.Itoa(WorkspaceAgentHTTPAPIServerPort) { return nil, xerrors.Errorf("request %q does not appear to be for http api", addr) } @@ -346,7 +450,12 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client { return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort)) + ipAddr, err := netip.ParseAddr(host) + if err != nil { + return nil, xerrors.Errorf("parse host addr: %w", err) + } + + conn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(ipAddr, WorkspaceAgentHTTPAPIServerPort)) if err != nil { return nil, xerrors.Errorf("dial http api: %w", err) } diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index c6fda00cbd95c..9e5ab4b18906a 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -11,6 +11,7 @@ import ( "net/http/cookiejar" "net/netip" "strconv" + "sync/atomic" "time" "github.com/google/uuid" @@ -262,6 +263,7 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti }() closed := make(chan struct{}) first := make(chan error) + var latestNode atomic.Pointer[tailnet.Node] go func() { defer close(closed) isFirst := true @@ -290,6 +292,11 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti continue } sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error { + if len(node) != 1 { + options.Logger.Warn(ctx, "no nodes returned from ServeCoordinator") + return nil + } + latestNode.Store(node[0]) return conn.UpdateNodes(node, false) }) conn.SetNodeCallback(sendNode) @@ -312,13 +319,22 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti return nil, err } - agentConn = &WorkspaceAgentConn{ - Conn: conn, - CloseFunc: func() { + agentConn = NewWorkspaceAgentConn(conn, WorkspaceAgentConnOptions{ + AgentID: agentID, + GetNode: func(agentID uuid.UUID) (*tailnet.Node, error) { + node := latestNode.Load() + if node == nil { + return nil, xerrors.New("node not found") + } + return node, nil + }, + CloseFunc: func() error { cancel() <-closed + return conn.Close() }, - } + }) + if !agentConn.AwaitReachable(ctx) { _ = agentConn.Close() return nil, xerrors.Errorf("timed out waiting for agent to become reachable: %w", ctx.Err()) diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index b0d9cfa64032f..16c2911bbb4fb 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -40,6 +40,7 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err agentSockets: map[uuid.UUID]*agpl.TrackedConn{}, agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*agpl.TrackedConn{}, agentNameCache: nameCache, + agentCallbacks: map[uuid.UUID]map[uuid.UUID]func(uuid.UUID, *agpl.Node){}, } if err := coord.runPubsub(ctx); err != nil { @@ -49,6 +50,47 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err return coord, nil } +func (c *haCoordinator) SubscribeAgent(agentID uuid.UUID, cb func(agentID uuid.UUID, node *agpl.Node)) func() { + c.mutex.Lock() + defer c.mutex.Unlock() + + id := uuid.New() + cbMap, ok := c.agentCallbacks[agentID] + if !ok { + cbMap = map[uuid.UUID]func(uuid.UUID, *agpl.Node){} + c.agentCallbacks[agentID] = cbMap + } + + cbMap[id] = cb + + return func() { + c.mutex.Lock() + defer c.mutex.Unlock() + delete(cbMap, id) + } +} + +func (c 
*haCoordinator) BroadcastToAgents(agents []uuid.UUID, node *agpl.Node) error { + ctx := context.Background() + + for _, id := range agents { + c.mutex.Lock() + agentSocket, ok := c.agentSockets[id] + c.mutex.Unlock() + if !ok { + continue + } + + // Write the new node from this client to the actively connected agent. + err := agentSocket.Enqueue([]*agpl.Node{node}) + if err != nil { + c.log.Debug(ctx, "failed to write to agent", slog.Error(err)) + } + } + + return nil +} + type haCoordinator struct { id uuid.UUID log slog.Logger @@ -68,6 +110,8 @@ type haCoordinator struct { // agentNameCache holds a cache of agent names. If one of them disappears, // it's helpful to have a name cached for debugging. agentNameCache *lru.Cache[uuid.UUID, string] + + agentCallbacks map[uuid.UUID]map[uuid.UUID]func(uuid.UUID, *agpl.Node) } // Node returns an in-memory node by ID. @@ -311,6 +355,19 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) ( for _, connectionSocket := range connectionSockets { _ = connectionSocket.Enqueue([]*agpl.Node{&node}) } + + wg := sync.WaitGroup{} + cbs := c.agentCallbacks[id] + wg.Add(len(cbs)) + for _, cb := range cbs { + cb := cb + go func() { + cb(id, &node) + wg.Done() + }() + } + wg.Wait() + c.mutex.Unlock() return &node, nil } diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 5e714eaca9513..bd52a9358ba9b 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -106,6 +106,16 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store return c, nil } +func (*pgCoord) SubscribeAgent(agentID uuid.UUID, cb func(agentID uuid.UUID, node *agpl.Node)) func() { + _, _ = agentID, cb + panic("not implemented") // TODO: Implement +} + +func (*pgCoord) BroadcastToAgents(agents []uuid.UUID, node *agpl.Node) error { + _, _ = agents, node + panic("not implemented") // TODO: Implement +} + func (*pgCoord) ServeHTTPDebug(w http.ResponseWriter, _ *http.Request) { // TODO(spikecurtis) I'd like to hold off implementing this until after the rest of this is code reviewed. w.WriteHeader(http.StatusOK) diff --git a/enterprise/wsproxy/wsproxy.go b/enterprise/wsproxy/wsproxy.go index fce5e0cc7a3b1..ae5da832054e2 100644 --- a/enterprise/wsproxy/wsproxy.go +++ b/enterprise/wsproxy/wsproxy.go @@ -183,9 +183,12 @@ func New(ctx context.Context, opts *Options) (*Server, error) { SecurityKey: secKey, Logger: s.Logger.Named("proxy_token_provider"), }, - WorkspaceConnCache: wsconncache.New(s.DialWorkspaceAgent, 0), - AppSecurityKey: secKey, + AppSecurityKey: secKey, + // TODO: Convert wsproxy to use coderd.ServerTailnet. 
+ AgentProvider: &wsconncache.AgentProvider{ + Cache: wsconncache.New(s.DialWorkspaceAgent, 0), + }, DisablePathApps: opts.DisablePathApps, SecureAuthCookie: opts.SecureAuthCookie, } @@ -273,6 +276,7 @@ func (s *Server) Close() error { tmp, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() _ = s.SDKClient.WorkspaceProxyGoingAway(tmp) + _ = s.AppServer.AgentProvider.Close() return s.AppServer.Close() } diff --git a/scaletest/agentconn/run.go b/scaletest/agentconn/run.go index 8e65a232ce1b5..8128e83e8ca15 100644 --- a/scaletest/agentconn/run.go +++ b/scaletest/agentconn/run.go @@ -6,7 +6,6 @@ import ( "io" "net" "net/http" - "net/netip" "net/url" "strconv" "time" @@ -377,7 +376,10 @@ func agentHTTPClient(conn *codersdk.WorkspaceAgentConn) *http.Client { if err != nil { return nil, xerrors.Errorf("parse port %q: %w", port, err) } - return conn.DialContextTCP(ctx, netip.AddrPortFrom(codersdk.WorkspaceAgentIP, uint16(portUint))) + + // Addr doesn't matter here, besides the port. DialContext will + // automatically choose the right IP to dial. + return conn.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", portUint)) }, }, } diff --git a/tailnet/conn.go b/tailnet/conn.go index 363ccb80ff48c..c7c852f618945 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -182,10 +182,17 @@ func NewConn(options *Options) (conn *Conn, err error) { netMap.SelfNode.DiscoKey = magicConn.DiscoPublicKey() netStack, err := netstack.Create( - Logger(options.Logger.Named("netstack")), tunDevice, wireguardEngine, magicConn, dialer, dnsManager) + Logger(options.Logger.Named("netstack")), + tunDevice, + wireguardEngine, + magicConn, + dialer, + dnsManager, + ) if err != nil { return nil, xerrors.Errorf("create netstack: %w", err) } + dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { return netStack.DialContextTCP(ctx, dst) } @@ -203,7 +210,14 @@ func NewConn(options *Options) (conn *Conn, err error) { localIPs, _ := localIPSet.IPSet() logIPSet := netipx.IPSetBuilder{} logIPs, _ := logIPSet.IPSet() - wireguardEngine.SetFilter(filter.New(netMap.PacketFilter, localIPs, logIPs, nil, Logger(options.Logger.Named("packet-filter")))) + wireguardEngine.SetFilter(filter.New( + netMap.PacketFilter, + localIPs, + logIPs, + nil, + Logger(options.Logger.Named("packet-filter")), + )) + dialContext, dialCancel := context.WithCancel(context.Background()) server := &Conn{ blockEndpoints: options.BlockEndpoints, @@ -230,6 +244,7 @@ func NewConn(options *Options) (conn *Conn, err error) { _ = server.Close() } }() + wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) { server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.Error(err)) if err != nil { @@ -251,6 +266,7 @@ func NewConn(options *Options) (conn *Conn, err error) { server.lastMutex.Unlock() server.sendNode() }) + wireguardEngine.SetNetInfoCallback(func(ni *tailcfg.NetInfo) { server.logger.Debug(context.Background(), "netinfo callback", slog.F("netinfo", ni)) server.lastMutex.Lock() @@ -262,6 +278,7 @@ func NewConn(options *Options) (conn *Conn, err error) { server.lastMutex.Unlock() server.sendNode() }) + magicConn.SetDERPForcedWebsocketCallback(func(region int, reason string) { server.logger.Debug(context.Background(), "derp forced websocket", slog.F("region", region), slog.F("reason", reason)) server.lastMutex.Lock() @@ -273,6 +290,7 @@ func NewConn(options *Options) (conn *Conn, err error) { server.lastMutex.Unlock() server.sendNode() }) + 
netStack.ForwardTCPIn = server.forwardTCP netStack.ForwardTCPSockOpts = server.forwardTCPSockOpts @@ -334,6 +352,12 @@ type Conn struct { trafficStats *connstats.Statistics } +func (c *Conn) Addresses() []netip.Prefix { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.netMap.Addresses +} + func (c *Conn) SetNodeCallback(callback func(node *Node)) { c.lastMutex.Lock() c.nodeCallback = callback @@ -366,32 +390,6 @@ func (c *Conn) SetDERPRegionDialer(dialer func(ctx context.Context, region *tail c.magicConn.SetDERPRegionDialer(dialer) } -func (c *Conn) RemoveAllPeers() error { - c.mutex.Lock() - defer c.mutex.Unlock() - - c.netMap.Peers = []*tailcfg.Node{} - c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{} - netMapCopy := *c.netMap - c.logger.Debug(context.Background(), "updating network map") - c.wireguardEngine.SetNetworkMap(&netMapCopy) - cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "") - if err != nil { - return xerrors.Errorf("update wireguard config: %w", err) - } - err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{}) - if err != nil { - if c.isClosed() { - return nil - } - if errors.Is(err, wgengine.ErrNoChanges) { - return nil - } - return xerrors.Errorf("reconfig: %w", err) - } - return nil -} - // UpdateNodes connects with a set of peers. This can be constantly updated, // and peers will continually be reconnected as necessary. If replacePeers is // true, all peers will be removed before adding the new ones. @@ -423,6 +421,7 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error { } delete(c.peerMap, peer.ID) } + for _, node := range nodes { // If no preferred DERP is provided, we can't reach the node. if node.PreferredDERP == 0 { @@ -452,10 +451,12 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error { } c.peerMap[node.ID] = peerNode } + c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap)) for _, peer := range c.peerMap { c.netMap.Peers = append(c.netMap.Peers, peer.Clone()) } + netMapCopy := *c.netMap c.logger.Debug(context.Background(), "updating network map") c.wireguardEngine.SetNetworkMap(&netMapCopy) @@ -463,6 +464,7 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error { if err != nil { return xerrors.Errorf("update wireguard config: %w", err) } + err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{}) if err != nil { if c.isClosed() { @@ -473,6 +475,7 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error { } return xerrors.Errorf("reconfig: %w", err) } + return nil } diff --git a/tailnet/conn_test.go b/tailnet/conn_test.go index 2e19379e6df03..0dd0812b94777 100644 --- a/tailnet/conn_test.go +++ b/tailnet/conn_test.go @@ -23,7 +23,7 @@ func TestMain(m *testing.M) { func TestTailnet(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - derpMap := tailnettest.RunDERPAndSTUN(t) + derpMap, _ := tailnettest.RunDERPAndSTUN(t) t.Run("InstantClose", func(t *testing.T) { t.Parallel() conn, err := tailnet.NewConn(&tailnet.Options{ @@ -172,7 +172,7 @@ func TestConn_PreferredDERP(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - derpMap := tailnettest.RunDERPAndSTUN(t) + derpMap, _ := tailnettest.RunDERPAndSTUN(t) conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, Logger: logger.Named("w1"), 
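RunDERPAndSTUN now returns the backing *derp.Server alongside the DERP map, which is why the call sites above gain a second return value. A minimal sketch of a test that keeps the extra handle, assuming a caller wants to manipulate the relay directly; the test name, the IgnoreErrors option, and the mid-test Close of the DERP server are illustrative assumptions, not part of this change:

package tailnet_test

import (
	"net/netip"
	"testing"

	"github.com/stretchr/testify/require"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"
	"github.com/coder/coder/tailnet"
	"github.com/coder/coder/tailnet/tailnettest"
)

func TestConn_DERPServerHandle(t *testing.T) {
	t.Parallel()
	// Stopping the relay mid-test can surface error-level logs, so ignore them.
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)

	// The second return value is the *derp.Server backing the map.
	derpMap, derpServer := tailnettest.RunDERPAndSTUN(t)

	conn, err := tailnet.NewConn(&tailnet.Options{
		Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
		DERPMap:   derpMap,
		Logger:    logger.Named("conn"),
	})
	require.NoError(t, err)
	defer conn.Close()

	// With a direct handle the test can stop the relay to observe how the
	// conn behaves once its DERP home region disappears.
	_ = derpServer.Close()
}
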
diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index ee675ef6665e7..cc4127f02bbfc 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -44,6 +44,9 @@ type Coordinator interface { ServeAgent(conn net.Conn, id uuid.UUID, name string) error // Close closes the coordinator. Close() error + + SubscribeAgent(agentID uuid.UUID, cb func(agentID uuid.UUID, node *Node)) func() + BroadcastToAgents(agents []uuid.UUID, node *Node) error } // Node represents a node in the network. @@ -54,10 +57,11 @@ type Node struct { AsOf time.Time `json:"as_of"` // Key is the Wireguard public key of the node. Key key.NodePublic `json:"key"` - // DiscoKey is used for discovery messages over DERP to establish peer-to-peer connections. + // DiscoKey is used for discovery messages over DERP to establish + // peer-to-peer connections. DiscoKey key.DiscoPublic `json:"disco"` - // PreferredDERP is the DERP server that peered connections - // should meet at to establish. + // PreferredDERP is the DERP server that peered connections should meet at + // to establish. PreferredDERP int `json:"preferred_derp"` // DERPLatency is the latency in seconds to each DERP server. DERPLatency map[string]float64 `json:"derp_latency"` @@ -68,8 +72,8 @@ type Node struct { DERPForcedWebsocket map[int]string `json:"derp_forced_websockets"` // Addresses are the IP address ranges this connection exposes. Addresses []netip.Prefix `json:"addresses"` - // AllowedIPs specify what addresses can dial the connection. - // We allow all by default. + // AllowedIPs specify what addresses can dial the connection. We allow all + // by default. AllowedIPs []netip.Prefix `json:"allowed_ips"` // Endpoints are ip:port combinations that can be used to establish // peer-to-peer connections. @@ -130,8 +134,8 @@ func NewCoordinator(logger slog.Logger) Coordinator { // ┌──────────────────┐ ┌────────────────────┐ ┌───────────────────┐ ┌──────────────────┐ // │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│ // └──────────────────┘ └────────────────────┘ └───────────────────┘ └──────────────────┘ -// This coordinator is incompatible with multiple Coder -// replicas as all node data is in-memory. +// This coordinator is incompatible with multiple Coder replicas as all node +// data is in-memory. type coordinator struct { core *core } @@ -154,6 +158,8 @@ type core struct { // agentNameCache holds a cache of agent names. If one of them disappears, // it's helpful to have a name cached for debugging. 
agentNameCache *lru.Cache[uuid.UUID, string] + + agentCallbacks map[uuid.UUID]map[uuid.UUID]func(uuid.UUID, *Node) } func newCore(logger slog.Logger) *core { @@ -169,6 +175,7 @@ func newCore(logger slog.Logger) *core { agentSockets: map[uuid.UUID]*TrackedConn{}, agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*TrackedConn{}, agentNameCache: nameCache, + agentCallbacks: map[uuid.UUID]map[uuid.UUID]func(uuid.UUID, *Node){}, } } @@ -587,6 +594,19 @@ func (c *core) agentNodeUpdate(id uuid.UUID, node *Node) error { slog.F("client_id", clientID), slog.Error(err)) } } + + wg := sync.WaitGroup{} + cbs := c.agentCallbacks[id] + wg.Add(len(cbs)) + for _, cb := range cbs { + cb := cb + go func() { + cb(id, node) + wg.Done() + }() + } + + wg.Wait() return nil } @@ -767,3 +787,44 @@ func CoordinatorHTTPDebug( } } } + +func (c *coordinator) SubscribeAgent(agentID uuid.UUID, cb func(agentID uuid.UUID, node *Node)) func() { + c.core.mutex.Lock() + defer c.core.mutex.Unlock() + + id := uuid.New() + cbMap, ok := c.core.agentCallbacks[agentID] + if !ok { + cbMap = map[uuid.UUID]func(uuid.UUID, *Node){} + c.core.agentCallbacks[agentID] = cbMap + } + + cbMap[id] = cb + + return func() { + c.core.mutex.Lock() + defer c.core.mutex.Unlock() + delete(cbMap, id) + } +} + +func (c *coordinator) BroadcastToAgents(agents []uuid.UUID, node *Node) error { + ctx := context.Background() + + for _, id := range agents { + c.core.mutex.Lock() + agentSocket, ok := c.core.agentSockets[id] + c.core.mutex.Unlock() + if !ok { + continue + } + + // Write the new node from this client to the actively connected agent. + err := agentSocket.Enqueue([]*Node{node}) + if err != nil { + c.core.logger.Debug(ctx, "failed to write to agent", slog.Error(err)) + } + } + + return nil +} diff --git a/tailnet/tailnettest/tailnettest.go b/tailnet/tailnettest/tailnettest.go index 482c1232e258a..0cb7dbd330ed3 100644 --- a/tailnet/tailnettest/tailnettest.go +++ b/tailnet/tailnettest/tailnettest.go @@ -22,7 +22,7 @@ import ( ) // RunDERPAndSTUN creates a DERP mapping for tests. 
-func RunDERPAndSTUN(t *testing.T) *tailcfg.DERPMap { +func RunDERPAndSTUN(t *testing.T) (*tailcfg.DERPMap, *derp.Server) { logf := tailnet.Logger(slogtest.Make(t, nil)) d := derp.NewServer(key.NewNode(), logf) server := httptest.NewUnstartedServer(derphttp.Handler(d)) @@ -61,7 +61,7 @@ func RunDERPAndSTUN(t *testing.T) *tailcfg.DERPMap { }, }, }, - } + }, d } // RunDERPOnlyWebSockets creates a DERP mapping for tests that diff --git a/tailnet/tailnettest/tailnettest_test.go b/tailnet/tailnettest/tailnettest_test.go index 6424bd94db0c0..fda818a1cebca 100644 --- a/tailnet/tailnettest/tailnettest_test.go +++ b/tailnet/tailnettest/tailnettest_test.go @@ -14,7 +14,7 @@ func TestMain(m *testing.M) { func TestRunDERPAndSTUN(t *testing.T) { t.Parallel() - _ = tailnettest.RunDERPAndSTUN(t) + _, _ = tailnettest.RunDERPAndSTUN(t) } func TestRunDERPOnlyWebSockets(t *testing.T) { From d4181b0497a013f473c62df7e79f32c06172cf68 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Mon, 26 Jun 2023 19:40:58 +0000 Subject: [PATCH 02/19] remove duplicate AwaitReachable --- coderd/tailnet.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/coderd/tailnet.go b/coderd/tailnet.go index acfbd4a7fadee..6cd195cf9a89d 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -298,11 +298,7 @@ func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, return nil, xerrors.Errorf("acquire agent conn: %w", err) } defer release() - - reachable := conn.AwaitReachable(ctx) - if !reachable { - return nil, xerrors.New("agent is unreachable") - } + defer conn.Close() node, err := s.getNode(agentID) if err != nil { From 2a133d1b7cb927d855b58ec5fb7ca9841f41f39b Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Mon, 26 Jun 2023 19:52:05 +0000 Subject: [PATCH 03/19] properly release net.Conn --- coderd/tailnet.go | 26 ++++++++++++++++++++------ coderd/workspaceapps/proxy.go | 11 +++++++++-- go.mod | 1 - go.sum | 1 - 4 files changed, 29 insertions(+), 10 deletions(-) diff --git a/coderd/tailnet.go b/coderd/tailnet.go index 6cd195cf9a89d..2de4573141364 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -138,7 +138,7 @@ func (s *ServerTailnet) updateNode(id uuid.UUID, node *tailnet.Node) { } } -func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (*httputil.ReverseProxy, func(), error) { +func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error) { proxy := httputil.NewSingleHostReverseProxy(targetURL) proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { site.RenderStaticErrorPage(w, r, site.ErrorPageData{ @@ -257,7 +257,7 @@ func (*ServerTailnet) nodeIsLegacy(node *tailnet.Node) bool { return node.Addresses[0].Addr() == codersdk.WorkspaceAgentIP } -func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) { +func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (_ *codersdk.WorkspaceAgentConn, release func(), _ error) { node, err := s.awaitNodeExists(ctx, agentID, 5*time.Second) if err != nil { return nil, nil, xerrors.Errorf("get agent node: %w", err) @@ -297,8 +297,6 @@ func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, if err != nil { return nil, xerrors.Errorf("acquire agent conn: %w", err) } - defer release() - defer conn.Close() node, err := s.getNode(agentID) if err != nil { @@ -309,13 +307,29 @@ func (s *ServerTailnet) 
DialAgentNetConn(ctx context.Context, agentID uuid.UUID, port, _ := strconv.ParseUint(rawPort, 10, 16) ipp := netip.AddrPortFrom(node.Addresses[0].Addr(), uint16(port)) + var nc net.Conn if network == "tcp" { - return conn.DialContextTCP(ctx, ipp) + nc, err = conn.DialContextTCP(ctx, ipp) } else if network == "udp" { - return conn.DialContextUDP(ctx, ipp) + nc, err = conn.DialContextUDP(ctx, ipp) } else { return nil, xerrors.Errorf("unknown network %q", network) } + + return &netConnCloser{Conn: nc, close: func() { + release() + conn.Close() + }}, err +} + +type netConnCloser struct { + net.Conn + close func() +} + +func (c *netConnCloser) Close() error { + c.close() + return c.Conn.Close() } func (s *ServerTailnet) Close() error { diff --git a/coderd/workspaceapps/proxy.go b/coderd/workspaceapps/proxy.go index 05dbed2e4f20a..9b2d9c4bfa297 100644 --- a/coderd/workspaceapps/proxy.go +++ b/coderd/workspaceapps/proxy.go @@ -61,11 +61,18 @@ var nonCanonicalHeaders = map[string]string{ } type AgentProvider interface { + // ReverseProxy returns an httputil.ReverseProxy for proxying HTTP requests + // to the specified agent. + // + // TODO: after wsconncache is deleted this doesn't need to return an error. + ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error) + + // AgentConn returns a new connection to the specified agent. + // // TODO: after wsconncache is deleted this doesn't need to return a release // func. AgentConn(ctx context.Context, agentID uuid.UUID) (_ *codersdk.WorkspaceAgentConn, release func(), _ error) - // TODO: after wsconncache is deleted this doesn't need to return an error. - ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error) + Close() error } diff --git a/go.mod b/go.mod index 2fcfe01954423..8560f3a56645b 100644 --- a/go.mod +++ b/go.mod @@ -218,7 +218,6 @@ require ( github.com/docker/docker v23.0.3+incompatible // indirect github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-units v0.5.0 // indirect - github.com/dustin/go-humanize v1.0.1 github.com/elastic/go-windows v1.0.0 // indirect github.com/fxamacker/cbor/v2 v2.4.0 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect diff --git a/go.sum b/go.sum index bc06322ca8024..49e3f12f43042 100644 --- a/go.sum +++ b/go.sum @@ -238,7 +238,6 @@ github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDD github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= -github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/elastic/go-sysinfo v1.11.0 h1:QW+6BF1oxBoAprH3w2yephF7xLkrrSXj7gl2xC2BM4w= github.com/elastic/go-sysinfo v1.11.0/go.mod h1:6KQb31j0QeWBDF88jIdWSxE8cwoOB9tO4Y4osN7Q70E= github.com/elastic/go-windows v1.0.0 h1:qLURgZFkkrYyTTkvYpsZIgf83AUsdIHfvlJaqaZ7aSY= From aa6bcb6b3e28cd428e9dde9426ccb66687a56b9c Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Mon, 26 Jun 2023 21:28:36 +0000 Subject: [PATCH 04/19] tailnetTransport --- coderd/tailnet.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/coderd/tailnet.go b/coderd/tailnet.go index 2de4573141364..d5846d1488c39 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -25,11 +25,11 @@ import ( "github.com/coder/coder/tailnet" 
) -var defaultTransport *http.Transport +var tailnetTransport *http.Transport func init() { var valid bool - defaultTransport, valid = http.DefaultTransport.(*http.Transport) + tailnetTransport, valid = http.DefaultTransport.(*http.Transport) if !valid { panic("dev error: default transport is the wrong type") } @@ -62,7 +62,7 @@ func NewServerTailnet( coordinator: coord, cache: cache, agentNodes: map[uuid.UUID]*tailnetNode{}, - transport: defaultTransport.Clone(), + transport: tailnetTransport.Clone(), } tn.transport.DialContext = tn.dialContext tn.transport.MaxIdleConnsPerHost = 10 From f9040fc20c8f740170d6988f271a5d626c191601 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Mon, 26 Jun 2023 22:26:31 +0000 Subject: [PATCH 05/19] link to issue in deprecation notice --- coderd/workspaceagents.go | 1 + coderd/wsconncache/wsconncache.go | 3 ++- codersdk/workspaceagentconn.go | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index aa10fffb312d3..0f7dc43b49d74 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -730,6 +730,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req } // Deprecated: use api.tailnet.AgentConn instead. +// See: https://github.com/coder/coder/issues/8218 func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { clientConn, serverConn := net.Pipe() conn, err := tailnet.NewConn(&tailnet.Options{ diff --git a/coderd/wsconncache/wsconncache.go b/coderd/wsconncache/wsconncache.go index 917ab34a5bf13..fcff1deadec1d 100644 --- a/coderd/wsconncache/wsconncache.go +++ b/coderd/wsconncache/wsconncache.go @@ -1,5 +1,5 @@ // Package wsconncache caches workspace agent connections by UUID. -// DEPRECATED +// Deprecated package wsconncache import ( @@ -66,6 +66,7 @@ func (a *AgentProvider) Close() error { // // Deprecated: Use coderd.NewServerTailnet instead. wsconncache is being phased // out because it creates a unique Tailnet for each agent. +// See: https://github.com/coder/coder/issues/8218 func New(dialer Dialer, inactiveTimeout time.Duration) *Cache { if inactiveTimeout == 0 { inactiveTimeout = 5 * time.Minute diff --git a/codersdk/workspaceagentconn.go b/codersdk/workspaceagentconn.go index 8bb1b2e41414f..5065566bae607 100644 --- a/codersdk/workspaceagentconn.go +++ b/codersdk/workspaceagentconn.go @@ -31,6 +31,7 @@ import ( // // Deprecated: use tailnet.IP() instead. This is kept for backwards // compatibility with wsconncache. 
+// See: https://github.com/coder/coder/issues/8218 var WorkspaceAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4") var ErrSkipClose = xerrors.New("skip tailnet close") From a988feff8aa71a0c25dbe00df2b53224663ff02e Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Wed, 28 Jun 2023 02:00:40 +0000 Subject: [PATCH 06/19] address review comments --- agent/agent.go | 25 +++- coderd/apidoc/docs.go | 3 + coderd/apidoc/swagger.json | 3 + coderd/tailnet.go | 106 ++++++++++------- coderd/tailnet_test.go | 2 +- coderd/workspaceagents.go | 1 + coderd/wsconncache/wsconncache.go | 2 +- coderd/wsconncache/wsconncache_test.go | 2 +- codersdk/agentsdk/agentsdk.go | 1 + codersdk/workspaceagentconn.go | 9 +- codersdk/workspaceagents.go | 10 +- docs/api/agents.md | 1 + docs/api/schemas.md | 2 + enterprise/tailnet/coordinator.go | 122 ++++++++++++-------- enterprise/tailnet/pgcoord.go | 9 +- scaletest/workspacetraffic/run_test.go | 1 + tailnet/conn.go | 47 +++++++- tailnet/coordinator.go | 154 +++++++++++++++---------- tailnet/multiagent.go | 119 +++++++++++++++++++ 19 files changed, 441 insertions(+), 178 deletions(-) create mode 100644 tailnet/multiagent.go diff --git a/agent/agent.go b/agent/agent.go index ce477d11415b5..58c8e470ebb93 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -605,7 +605,7 @@ func (a *agent) run(ctx context.Context) error { network := a.network a.closeMutex.Unlock() if network == nil { - network, err = a.createTailnet(ctx, manifest.DERPMap, manifest.DisableDirectConnections) + network, err = a.createTailnet(ctx, manifest.AgentID, manifest.DERPMap, manifest.DisableDirectConnections) if err != nil { return xerrors.Errorf("create tailnet: %w", err) } @@ -623,6 +623,11 @@ func (a *agent) run(ctx context.Context) error { a.startReportingConnectionStats(ctx) } else { + // Update the wireguard IPs if the agent ID changed. + err := network.SetAddresses(a.wireguardAddresses(manifest.AgentID)) + if err != nil { + a.logger.Error(ctx, "update tailnet addresses", slog.Error(err)) + } // Update the DERP map and allow/disallow direct connections. network.SetDERPMap(manifest.DERPMap) network.SetBlockEndpoints(manifest.DisableDirectConnections) @@ -636,6 +641,20 @@ func (a *agent) run(ctx context.Context) error { return nil } +func (a *agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix { + if len(a.addresses) == 0 { + return []netip.Prefix{ + // This is the IP that should be used primarily. + netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128), + // We also listen on the legacy codersdk.WorkspaceAgentIP. This + // allows for a transition away from wsconncache. 
+ netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128), + } + } + + return a.addresses +} + func (a *agent) trackConnGoroutine(fn func()) error { a.closeMutex.Lock() defer a.closeMutex.Unlock() @@ -650,9 +669,9 @@ func (a *agent) trackConnGoroutine(fn func()) error { return nil } -func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap, disableDirectConnections bool) (_ *tailnet.Conn, err error) { +func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *tailcfg.DERPMap, disableDirectConnections bool) (_ *tailnet.Conn, err error) { network, err := tailnet.NewConn(&tailnet.Options{ - Addresses: a.addresses, + Addresses: a.wireguardAddresses(agentID), DERPMap: derpMap, Logger: a.logger.Named("tailnet"), ListenPort: a.tailnetListenPort, diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index 872fd022878cf..e29dabac80c52 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -5772,6 +5772,9 @@ const docTemplate = `{ "agentsdk.Manifest": { "type": "object", "properties": { + "agent_id": { + "type": "string" + }, "apps": { "type": "array", "items": { diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 56db90e9f26e8..c884cd1c6515f 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -5083,6 +5083,9 @@ "agentsdk.Manifest": { "type": "object", "properties": { + "agent_id": { + "type": "string" + }, "apps": { "type": "array", "items": { diff --git a/coderd/tailnet.go b/coderd/tailnet.go index d5846d1488c39..f400da5b3ffe9 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -35,7 +35,7 @@ func init() { } } -// TODO: ServerTailnet does not currently remove stale peers. +// TODO(coadler): ServerTailnet does not currently remove stale peers. // NewServerTailnet creates a new tailnet intended for use by coderd. It // automatically falls back to wsconncache if a legacy agent is encountered. 
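Since ServerTailnet implements the workspaceapps.AgentProvider interface, callers are expected to treat it the same way they treated the wsconncache-backed provider: acquire a conn, use it, and always invoke the returned release func so the provider can clean up or return the connection to its cache. A rough sketch of that contract; the helper name, its placement in package coderd, and the error wrapping are assumptions for illustration, not code from this change:

package coderd

import (
	"context"
	"time"

	"github.com/google/uuid"
	"golang.org/x/xerrors"

	"github.com/coder/coder/coderd/workspaceapps"
)

// pingAgent is a hypothetical helper showing how an AgentProvider (the shared
// ServerTailnet, or wsconncache for legacy agents) is meant to be consumed.
func pingAgent(ctx context.Context, provider workspaceapps.AgentProvider, agentID uuid.UUID) (time.Duration, error) {
	conn, release, err := provider.AgentConn(ctx, agentID)
	if err != nil {
		return 0, xerrors.Errorf("acquire agent conn: %w", err)
	}
	// Always release the conn when done, regardless of which provider
	// produced it.
	defer release()

	latency, _, _, err := conn.Ping(ctx)
	if err != nil {
		return 0, xerrors.Errorf("ping agent: %w", err)
	}
	return latency, nil
}
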
@@ -56,10 +56,17 @@ func NewServerTailnet( return nil, xerrors.Errorf("create tailnet conn: %w", err) } + id := uuid.New() + ma := (*coord.Load()).ServeMultiAgent(id) + + serverCtx, cancel := context.WithCancel(ctx) tn := &ServerTailnet{ + ctx: serverCtx, + cancel: cancel, logger: logger, conn: conn, coordinator: coord, + agentConn: ma, cache: cache, agentNodes: map[uuid.UUID]*tailnetNode{}, transport: tailnetTransport.Clone(), @@ -69,16 +76,9 @@ func NewServerTailnet( tn.transport.MaxIdleConns = 0 conn.SetNodeCallback(func(node *tailnet.Node) { - tn.nodesMu.Lock() - ids := make([]uuid.UUID, 0, len(tn.agentNodes)) - for id := range tn.agentNodes { - ids = append(ids, id) - } - tn.nodesMu.Unlock() - - err := (*tn.coordinator.Load()).BroadcastToAgents(ids, node) + err := tn.agentConn.UpdateSelf(node) if err != nil { - tn.logger.Error(context.Background(), "broadcast server node to agents", slog.Error(err)) + tn.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err)) } }) @@ -99,19 +99,52 @@ func NewServerTailnet( return left }) + go tn.watchAgentUpdates() return tn, nil } +func (s *ServerTailnet) watchAgentUpdates() { + for { + nodes := s.agentConn.NextUpdate(s.ctx) + if nodes == nil { + return + } + + toUpdate := make([]*tailnet.Node, 0) + + s.nodesMu.Lock() + for _, node := range nodes { + cached, ok := s.agentNodes[node.AgentID] + if ok { + cached.node = node.Node + toUpdate = append(toUpdate, node.Node) + } + } + s.nodesMu.Unlock() + + if len(toUpdate) > 0 { + err := s.conn.UpdateNodes(toUpdate, false) + if err != nil { + s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err)) + return + } + } + } +} + type tailnetNode struct { node *tailnet.Node lastConnection time.Time - stop func() } type ServerTailnet struct { + ctx context.Context + cancel func() + logger slog.Logger conn *tailnet.Conn coordinator *atomic.Pointer[tailnet.Coordinator] + agentConn tailnet.MultiAgentConn cache *wsconncache.Cache nodesMu sync.Mutex // agentNodes is a map of agent tailnetNodes the server wants to keep a @@ -121,23 +154,6 @@ type ServerTailnet struct { transport *http.Transport } -func (s *ServerTailnet) updateNode(id uuid.UUID, node *tailnet.Node) { - s.nodesMu.Lock() - cached, ok := s.agentNodes[id] - if ok { - cached.node = node - } - s.nodesMu.Unlock() - - if ok { - err := s.conn.UpdateNodes([]*tailnet.Node{node}, false) - if err != nil { - s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err)) - return - } - } -} - func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error) { proxy := httputil.NewSingleHostReverseProxy(targetURL) proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { @@ -188,18 +204,16 @@ func (s *ServerTailnet) getNode(agentID uuid.UUID) (*tailnet.Node, error) { s.nodesMu.Unlock() return nil, xerrors.Errorf("node %q not found", agentID.String()) } - stop := coord.SubscribeAgent(agentID, s.updateNode) + + err := s.agentConn.SubscribeAgent(agentID, s.conn.Node()) + if err != nil { + return nil, xerrors.Errorf("subscribe agent: %w", err) + } tnode = &tailnetNode{ node: node, lastConnection: time.Now(), - stop: stop, } s.agentNodes[agentID] = tnode - - err := coord.BroadcastToAgents([]uuid.UUID{agentID}, s.conn.Node()) - if err != nil { - s.logger.Debug(context.Background(), "broadcast server node to agents", slog.Error(err)) - } } s.nodesMu.Unlock() @@ -257,7 +271,7 @@ func (*ServerTailnet) 
nodeIsLegacy(node *tailnet.Node) bool { return node.Addresses[0].Addr() == codersdk.WorkspaceAgentIP } -func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (_ *codersdk.WorkspaceAgentConn, release func(), _ error) { +func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) { node, err := s.awaitNodeExists(ctx, agentID, 5*time.Second) if err != nil { return nil, nil, xerrors.Errorf("get agent node: %w", err) @@ -284,8 +298,13 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (_ *co }) } + // Since we now have an open conn, be careful to close it if we error + // without returning it to the user. + reachable := conn.AwaitReachable(ctx) if !reachable { + ret() + conn.Close() return nil, nil, xerrors.New("agent is unreachable") } @@ -298,8 +317,13 @@ func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, return nil, xerrors.Errorf("acquire agent conn: %w", err) } + // Since we now have an open conn, be careful to close it if we error + // without returning it to the user. + node, err := s.getNode(agentID) if err != nil { + release() + conn.Close() return nil, xerrors.New("get agent node") } @@ -308,11 +332,14 @@ func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, ipp := netip.AddrPortFrom(node.Addresses[0].Addr(), uint16(port)) var nc net.Conn - if network == "tcp" { + switch network { + case "tcp": nc, err = conn.DialContextTCP(ctx, ipp) - } else if network == "udp" { + case "udp": nc, err = conn.DialContextUDP(ctx, ipp) - } else { + default: + release() + conn.Close() return nil, xerrors.Errorf("unknown network %q", network) } @@ -333,6 +360,7 @@ func (c *netConnCloser) Close() error { } func (s *ServerTailnet) Close() error { + s.cancel() _ = s.cache.Close() _ = s.conn.Close() s.transport.CloseIdleConnections() diff --git a/coderd/tailnet_test.go b/coderd/tailnet_test.go index 09033c8c597eb..e89023a3d2734 100644 --- a/coderd/tailnet_test.go +++ b/coderd/tailnet_test.go @@ -183,7 +183,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A GetNode: func(agentID uuid.UUID) (*tailnet.Node, error) { node := coordinator.Node(agentID) if node == nil { - return nil, xerrors.Errorf("node not found %q", err) + return nil, xerrors.Errorf("node not found %q", agentID) } return node, nil }, diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 0f7dc43b49d74..c45cbb2bdf5f7 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -161,6 +161,7 @@ func (api *API) workspaceAgentManifest(rw http.ResponseWriter, r *http.Request) } httpapi.Write(ctx, rw, http.StatusOK, agentsdk.Manifest{ + AgentID: apiAgent.ID, Apps: convertApps(dbApps), DERPMap: api.DERPMap, GitAuthConfigs: len(api.GitAuthConfigs), diff --git a/coderd/wsconncache/wsconncache.go b/coderd/wsconncache/wsconncache.go index fcff1deadec1d..13d1588384954 100644 --- a/coderd/wsconncache/wsconncache.go +++ b/coderd/wsconncache/wsconncache.go @@ -1,5 +1,5 @@ // Package wsconncache caches workspace agent connections by UUID. -// Deprecated +// Deprecated: Use ServerTailnet instead. 
package wsconncache import ( diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index 30a6892446d82..b755da025a8f4 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ -200,7 +200,7 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati GetNode: func(agentID uuid.UUID) (*tailnet.Node, error) { node := coordinator.Node(agentID) if node == nil { - return nil, xerrors.Errorf("node not found %q", err) + return nil, xerrors.Errorf("node not found %q", agentID) } return node, nil }, diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index ac0211cf2d37e..ab867ed504877 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -84,6 +84,7 @@ func (c *Client) PostMetadata(ctx context.Context, key string, req PostMetadataR } type Manifest struct { + AgentID uuid.UUID `json:"agent_id"` // GitAuthConfigs stores the number of Git configurations // the Coder deployment has. If this number is >0, we // set up special configuration in the workspace. diff --git a/codersdk/workspaceagentconn.go b/codersdk/workspaceagentconn.go index 5065566bae607..3bdf98c0dee6f 100644 --- a/codersdk/workspaceagentconn.go +++ b/codersdk/workspaceagentconn.go @@ -167,9 +167,12 @@ func (c *WorkspaceAgentConn) AwaitReachable(ctx context.Context) bool { defer span.End() var ( - addr netip.Addr - err error + addr netip.Addr + err error + ticker = time.NewTicker(10 * time.Millisecond) ) + defer ticker.Stop() + for { addr, err = c.getAgentAddress() if err == nil { @@ -179,7 +182,7 @@ func (c *WorkspaceAgentConn) AwaitReachable(ctx context.Context) bool { select { case <-ctx.Done(): return false - case <-time.After(10 * time.Millisecond): + case <-ticker.C: continue } } diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 9e5ab4b18906a..763ea773494a1 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -291,13 +291,13 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti options.Logger.Debug(ctx, "failed to dial", slog.Error(err)) continue } - sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error { - if len(node) != 1 { - options.Logger.Warn(ctx, "no nodes returned from ServeCoordinator") + sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(nodes []*tailnet.Node) error { + if len(nodes) != 1 { + options.Logger.Warn(ctx, "incorrect number of nodes returned from ServeCoordinator", slog.F("len", len(nodes))) return nil } - latestNode.Store(node[0]) - return conn.UpdateNodes(node, false) + latestNode.Store(nodes[0]) + return conn.UpdateNodes(nodes, false) }) conn.SetNodeCallback(sendNode) options.Logger.Debug(ctx, "serving coordinator") diff --git a/docs/api/agents.md b/docs/api/agents.md index b8c73c8ceae95..fadaa72f91ccb 100644 --- a/docs/api/agents.md +++ b/docs/api/agents.md @@ -292,6 +292,7 @@ curl -X GET http://coder-server:8080/api/v2/workspaceagents/me/manifest \ ```json { + "agent_id": "string", "apps": [ { "command": "string", diff --git a/docs/api/schemas.md b/docs/api/schemas.md index f332d03968fb1..17af834fc910a 100644 --- a/docs/api/schemas.md +++ b/docs/api/schemas.md @@ -161,6 +161,7 @@ ```json { + "agent_id": "string", "apps": [ { "command": "string", @@ -260,6 +261,7 @@ | Name | Type | Required | Restrictions | Description | | 
---------------------------- | ------------------------------------------------------------------------------------------------- | -------- | ------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `agent_id` | string | false | | | | `apps` | array of [codersdk.WorkspaceApp](#codersdkworkspaceapp) | false | | | | `derpmap` | [tailcfg.DERPMap](#tailcfgderpmap) | false | | | | `directory` | string | false | | | diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index 16c2911bbb4fb..8071d024b0bad 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -12,11 +12,13 @@ import ( "sync" "github.com/google/uuid" + "github.com/hashicorp/go-multierror" lru "github.com/hashicorp/golang-lru/v2" "golang.org/x/xerrors" "cdr.dev/slog" "github.com/coder/coder/coderd/database/pubsub" + "github.com/coder/coder/codersdk" agpl "github.com/coder/coder/tailnet" ) @@ -40,7 +42,8 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err agentSockets: map[uuid.UUID]*agpl.TrackedConn{}, agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*agpl.TrackedConn{}, agentNameCache: nameCache, - agentCallbacks: map[uuid.UUID]map[uuid.UUID]func(uuid.UUID, *agpl.Node){}, + legacyAgents: map[uuid.UUID]struct{}{}, + multiAgents: map[uuid.UUID]*agpl.MultiAgent{}, } if err := coord.runPubsub(ctx); err != nil { @@ -50,43 +53,43 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err return coord, nil } -func (c *haCoordinator) SubscribeAgent(agentID uuid.UUID, cb func(agentID uuid.UUID, node *agpl.Node)) func() { - c.mutex.Lock() - defer c.mutex.Unlock() - - id := uuid.New() - cbMap, ok := c.agentCallbacks[agentID] - if !ok { - cbMap = map[uuid.UUID]func(uuid.UUID, *agpl.Node){} - c.agentCallbacks[agentID] = cbMap - } - - cbMap[id] = cb - - return func() { - c.mutex.Lock() - defer c.mutex.Unlock() - delete(cbMap, id) - } +func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { + m := (&agpl.MultiAgent{ + ID: id, + Logger: c.log, + AgentIsLegacyFunc: c.agentIsLegacy, + OnNodeUpdate: c.multiAgentUpdate, + OnClose: c.removeMultiAgent, + }).Init() + c.addMultiAgent(m) + return m } -func (c *haCoordinator) BroadcastToAgents(agents []uuid.UUID, node *agpl.Node) error { - ctx := context.Background() +func (c *haCoordinator) addMultiAgent(m *agpl.MultiAgent) { + c.mutex.Lock() + c.multiAgents[m.ID] = m + c.mutex.Unlock() +} - for _, id := range agents { - c.mutex.Lock() - agentSocket, ok := c.agentSockets[id] - c.mutex.Unlock() - if !ok { - continue - } +func (c *haCoordinator) removeMultiAgent(id uuid.UUID) { + c.mutex.Lock() + delete(c.multiAgents, id) + c.mutex.Unlock() +} - // Write the new node from this client to the actively connected agent. - err := agentSocket.Enqueue([]*agpl.Node{node}) +func (c *haCoordinator) multiAgentUpdate(id uuid.UUID, agents []uuid.UUID, node *agpl.Node) error { + var errs *multierror.Error + // This isn't the most efficient, but this coordinator is being deprecated + // soon anyways. + for _, agent := range agents { + err := c.handleClientUpdate(id, agent, node) if err != nil { - c.log.Debug(ctx, "failed to write to agent", slog.Error(err)) + errs = multierror.Append(errs, err) } } + if errs != nil { + return errs + } return nil } @@ -111,7 +114,8 @@ type haCoordinator struct { // it's helpful to have a name cached for debugging. 
agentNameCache *lru.Cache[uuid.UUID, string] - agentCallbacks map[uuid.UUID]map[uuid.UUID]func(uuid.UUID, *agpl.Node) + legacyAgents map[uuid.UUID]struct{} + multiAgents map[uuid.UUID]*agpl.MultiAgent } // Node returns an in-memory node by ID. @@ -203,28 +207,32 @@ func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *js return xerrors.Errorf("read json: %w", err) } + return c.handleClientUpdate(id, agent, &node) +} + +func (c *haCoordinator) handleClientUpdate(id, agent uuid.UUID, node *agpl.Node) error { c.mutex.Lock() // Update the node of this client in our in-memory map. If an agent entirely // shuts down and reconnects, it needs to be aware of all clients attempting // to establish connections. - c.nodes[id] = &node + c.nodes[id] = node + // Write the new node from this client to the actively connected agent. agentSocket, ok := c.agentSockets[agent] - if !ok { c.mutex.Unlock() // If we don't own the agent locally, send it over pubsub to a node that // owns the agent. - err := c.publishNodesToAgent(agent, []*agpl.Node{&node}) + err := c.publishNodesToAgent(agent, []*agpl.Node{node}) if err != nil { return xerrors.Errorf("publish node to agent") } return nil } - err = agentSocket.Enqueue([]*agpl.Node{&node}) + err := agentSocket.Enqueue([]*agpl.Node{node}) c.mutex.Unlock() if err != nil { - return xerrors.Errorf("enqueu nodes: %w", err) + return xerrors.Errorf("enqueue node: %w", err) } return nil } @@ -329,6 +337,13 @@ func (c *haCoordinator) handleClientHello(id uuid.UUID) error { return c.publishAgentToNodes(id, node) } +func (c *haCoordinator) agentIsLegacy(agentID uuid.UUID) bool { + c.mutex.RLock() + _, ok := c.legacyAgents[agentID] + c.mutex.RUnlock() + return ok +} + func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (*agpl.Node, error) { var node agpl.Node err := decoder.Decode(&node) @@ -337,6 +352,11 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) ( } c.mutex.Lock() + // Keep a cache of all legacy agents. + if len(node.Addresses) > 0 && node.Addresses[0].Addr() == codersdk.WorkspaceAgentIP { + c.legacyAgents[id] = struct{}{} + } + oldNode := c.nodes[id] if oldNode != nil { if oldNode.AsOf.After(node.AsOf) { @@ -356,19 +376,13 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) ( _ = connectionSocket.Enqueue([]*agpl.Node{&node}) } - wg := sync.WaitGroup{} - cbs := c.agentCallbacks[id] - wg.Add(len(cbs)) - for _, cb := range cbs { - cb := cb - go func() { - cb(id, &node) - wg.Done() - }() + // Publish the new node to every active multiAgent. 
+ for _, multiAgent := range c.multiAgents { + multiAgent.OnAgentUpdate(id, &node) } - wg.Wait() c.mutex.Unlock() + return &node, nil } @@ -407,6 +421,15 @@ func (c *haCoordinator) Close() error { } } + wg.Add(len(c.multiAgents)) + for _, multiAgent := range c.multiAgents { + multiAgent := multiAgent + go func() { + _ = multiAgent.Close() + wg.Done() + }() + } + wg.Wait() return nil } @@ -479,13 +502,12 @@ func (c *haCoordinator) runPubsub(ctx context.Context) error { } go func() { for { - var message []byte select { case <-ctx.Done(): return - case message = <-messageQueue: + case message := <-messageQueue: + c.handlePubsubMessage(ctx, message) } - c.handlePubsubMessage(ctx, message) } }() diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index bd52a9358ba9b..a17cdbf374078 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -106,13 +106,8 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store return c, nil } -func (*pgCoord) SubscribeAgent(agentID uuid.UUID, cb func(agentID uuid.UUID, node *agpl.Node)) func() { - _, _ = agentID, cb - panic("not implemented") // TODO: Implement -} - -func (*pgCoord) BroadcastToAgents(agents []uuid.UUID, node *agpl.Node) error { - _, _ = agents, node +func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { + _, _ = c, id panic("not implemented") // TODO: Implement } diff --git a/scaletest/workspacetraffic/run_test.go b/scaletest/workspacetraffic/run_test.go index e53d408bcd428..c070a906be228 100644 --- a/scaletest/workspacetraffic/run_test.go +++ b/scaletest/workspacetraffic/run_test.go @@ -68,6 +68,7 @@ func TestRun(t *testing.T) { agentCloser := agent.New(agent.Options{ Client: agentClient, }) + ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) t.Cleanup(func() { diff --git a/tailnet/conn.go b/tailnet/conn.go index c7c852f618945..0534cf961771a 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -139,6 +139,7 @@ func NewConn(options *Options) (conn *Conn, err error) { } }() + IP() dialer := &tsdial.Dialer{ Logf: Logger(options.Logger.Named("tsdial")), } @@ -302,22 +303,30 @@ func NewConn(options *Options) (conn *Conn, err error) { return server, nil } -// IP generates a new IP with a static service prefix. -func IP() netip.Addr { - // This is Tailscale's ephemeral service prefix. - // This can be changed easily later-on, because - // all of our nodes are ephemeral. +func maskUUID(uid uuid.UUID) uuid.UUID { + // This is Tailscale's ephemeral service prefix. This can be changed easily + // later-on, because all of our nodes are ephemeral. // fd7a:115c:a1e0 - uid := uuid.New() uid[0] = 0xfd uid[1] = 0x7a uid[2] = 0x11 uid[3] = 0x5c uid[4] = 0xa1 uid[5] = 0xe0 + return uid +} + +// IP generates a random IP with a static service prefix. +func IP() netip.Addr { + uid := maskUUID(uuid.New()) return netip.AddrFrom16(uid) } +// IP generates a new IP from a UUID. +func IPFromUUID(uid uuid.UUID) netip.Addr { + return netip.AddrFrom16(maskUUID(uid)) +} + // Conn is an actively listening Wireguard connection. 
type Conn struct { dialContext context.Context @@ -352,6 +361,23 @@ type Conn struct { trafficStats *connstats.Statistics } +func (c *Conn) SetAddresses(ips []netip.Prefix) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.netMap.Addresses = ips + + netMapCopy := *c.netMap + c.logger.Debug(context.Background(), "updating network map") + c.wireguardEngine.SetNetworkMap(&netMapCopy) + err := c.reconfig() + if err != nil { + return xerrors.Errorf("reconfig: %w", err) + } + + return nil +} + func (c *Conn) Addresses() []netip.Prefix { c.mutex.Lock() defer c.mutex.Unlock() @@ -460,6 +486,15 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error { netMapCopy := *c.netMap c.logger.Debug(context.Background(), "updating network map") c.wireguardEngine.SetNetworkMap(&netMapCopy) + err := c.reconfig() + if err != nil { + return xerrors.Errorf("reconfig: %w", err) + } + + return nil +} + +func (c *Conn) reconfig() error { cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "") if err != nil { return xerrors.Errorf("update wireguard config: %w", err) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index cc4127f02bbfc..a1f17316308ff 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -14,14 +14,15 @@ import ( "sync/atomic" "time" - "cdr.dev/slog" - "github.com/google/uuid" + "github.com/hashicorp/go-multierror" lru "github.com/hashicorp/golang-lru/v2" "golang.org/x/exp/slices" "golang.org/x/xerrors" "tailscale.com/tailcfg" "tailscale.com/types/key" + + "cdr.dev/slog" ) // Coordinator exchanges nodes with agents to establish connections. @@ -45,8 +46,7 @@ type Coordinator interface { // Close closes the coordinator. Close() error - SubscribeAgent(agentID uuid.UUID, cb func(agentID uuid.UUID, node *Node)) func() - BroadcastToAgents(agents []uuid.UUID, node *Node) error + ServeMultiAgent(id uuid.UUID) MultiAgentConn } // Node represents a node in the network. @@ -140,6 +140,30 @@ type coordinator struct { core *core } +func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn { + m := (&MultiAgent{ + ID: id, + Logger: c.core.logger, + AgentIsLegacyFunc: c.core.agentIsLegacy, + OnNodeUpdate: c.core.multiAgentUpdate, + OnClose: c.core.removeMultiAgent, + }).Init() + c.core.addMultiAgent(m) + return m +} + +func (c *core) addMultiAgent(m *MultiAgent) { + c.mutex.Lock() + c.multiAgents[m.ID] = m + c.mutex.Unlock() +} + +func (c *core) removeMultiAgent(id uuid.UUID) { + c.mutex.Lock() + delete(c.multiAgents, id) + c.mutex.Unlock() +} + // core is an in-memory structure of Node and TrackedConn mappings. Its methods may be called from multiple goroutines; // it is protected by a mutex to ensure data stay consistent. type core struct { @@ -159,7 +183,8 @@ type core struct { // it's helpful to have a name cached for debugging. 
agentNameCache *lru.Cache[uuid.UUID, string] - agentCallbacks map[uuid.UUID]map[uuid.UUID]func(uuid.UUID, *Node) + legacyAgents map[uuid.UUID]struct{} + multiAgents map[uuid.UUID]*MultiAgent } func newCore(logger slog.Logger) *core { @@ -171,11 +196,12 @@ func newCore(logger slog.Logger) *core { return &core{ logger: logger, closed: false, - nodes: make(map[uuid.UUID]*Node), + nodes: map[uuid.UUID]*Node{}, agentSockets: map[uuid.UUID]*TrackedConn{}, agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*TrackedConn{}, agentNameCache: nameCache, - agentCallbacks: map[uuid.UUID]map[uuid.UUID]func(uuid.UUID, *Node){}, + legacyAgents: map[uuid.UUID]struct{}{}, + multiAgents: map[uuid.UUID]*MultiAgent{}, } } @@ -433,9 +459,14 @@ func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json } func (c *core) clientNodeUpdate(id, agent uuid.UUID, node *Node) error { - logger := c.clientLogger(id, agent) c.mutex.Lock() defer c.mutex.Unlock() + + return c.clientNodeUpdateUnlocked(id, agent, node) +} + +func (c *core) clientNodeUpdateUnlocked(id, agent uuid.UUID, node *Node) error { + logger := c.clientLogger(id, agent) // Update the node of this client in our in-memory map. If an agent entirely // shuts down and reconnects, it needs to be aware of all clients attempting // to establish connections. @@ -449,12 +480,30 @@ func (c *core) clientNodeUpdate(id, agent uuid.UUID, node *Node) error { err := agentSocket.Enqueue([]*Node{node}) if err != nil { - return xerrors.Errorf("Enqueue node: %w", err) + return xerrors.Errorf("enqueue node: %w", err) } logger.Debug(context.Background(), "enqueued node to agent") return nil } +func (c *core) multiAgentUpdate(id uuid.UUID, agents []uuid.UUID, node *Node) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + var errs *multierror.Error + for _, aid := range agents { + err := c.clientNodeUpdateUnlocked(id, aid, node) + if err != nil { + errs = multierror.Append(errs, err) + } + } + if errs != nil { + return errs + } + + return nil +} + func (c *core) agentLogger(id uuid.UUID) slog.Logger { return c.logger.With(slog.F("agent_id", id)) } @@ -570,11 +619,36 @@ func (c *coordinator) handleNextAgentMessage(id uuid.UUID, decoder *json.Decoder return c.core.agentNodeUpdate(id, &node) } +// This is copied from codersdk because importing it here would cause an import +// cycle. This is just temporary until wsconncache is phased out. +var legacyAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4") + +// This is temporary until we no longer need to detect for agent backwards +// compatibility. +// See: https://github.com/coder/coder/issues/8218 +func (c *core) agentIsLegacy(agentID uuid.UUID) bool { + c.mutex.RLock() + _, ok := c.legacyAgents[agentID] + c.mutex.RUnlock() + return ok +} + func (c *core) agentNodeUpdate(id uuid.UUID, node *Node) error { logger := c.agentLogger(id) c.mutex.Lock() defer c.mutex.Unlock() c.nodes[id] = node + + // Keep a cache of all legacy agents. + if len(node.Addresses) > 0 && node.Addresses[0].Addr() == legacyAgentIP { + c.legacyAgents[id] = struct{}{} + } + + // Publish the new node to every active multiAgent. 
+ for _, multiAgent := range c.multiAgents { + multiAgent.OnAgentUpdate(id, node) + } + connectionSockets, ok := c.agentToConnectionSockets[id] if !ok { logger.Debug(context.Background(), "no client sockets; unable to send node") @@ -595,18 +669,6 @@ func (c *core) agentNodeUpdate(id uuid.UUID, node *Node) error { } } - wg := sync.WaitGroup{} - cbs := c.agentCallbacks[id] - wg.Add(len(cbs)) - for _, cb := range cbs { - cb := cb - go func() { - cb(id, node) - wg.Done() - }() - } - - wg.Wait() return nil } @@ -646,6 +708,15 @@ func (c *core) close() error { } } + wg.Add(len(c.multiAgents)) + for _, multiAgent := range c.multiAgents { + multiAgent := multiAgent + go func() { + _ = multiAgent.Close() + wg.Done() + }() + } + c.mutex.Unlock() wg.Wait() @@ -787,44 +858,3 @@ func CoordinatorHTTPDebug( } } } - -func (c *coordinator) SubscribeAgent(agentID uuid.UUID, cb func(agentID uuid.UUID, node *Node)) func() { - c.core.mutex.Lock() - defer c.core.mutex.Unlock() - - id := uuid.New() - cbMap, ok := c.core.agentCallbacks[agentID] - if !ok { - cbMap = map[uuid.UUID]func(uuid.UUID, *Node){} - c.core.agentCallbacks[agentID] = cbMap - } - - cbMap[id] = cb - - return func() { - c.core.mutex.Lock() - defer c.core.mutex.Unlock() - delete(cbMap, id) - } -} - -func (c *coordinator) BroadcastToAgents(agents []uuid.UUID, node *Node) error { - ctx := context.Background() - - for _, id := range agents { - c.core.mutex.Lock() - agentSocket, ok := c.core.agentSockets[id] - c.core.mutex.Unlock() - if !ok { - continue - } - - // Write the new node from this client to the actively connected agent. - err := agentSocket.Enqueue([]*Node{node}) - if err != nil { - c.core.logger.Debug(ctx, "failed to write to agent", slog.Error(err)) - } - } - - return nil -} diff --git a/tailnet/multiagent.go b/tailnet/multiagent.go new file mode 100644 index 0000000000000..3e9f39153ebc2 --- /dev/null +++ b/tailnet/multiagent.go @@ -0,0 +1,119 @@ +package tailnet + +import ( + "context" + "sync" + + "github.com/google/uuid" + + "cdr.dev/slog" +) + +type MultiAgentConn interface { + UpdateSelf(node *Node) error + SubscribeAgent(agentID uuid.UUID, node *Node) error + UnsubscribeAgent(agentID uuid.UUID) + NextUpdate(ctx context.Context) []AgentNode + AgentIsLegacy(agentID uuid.UUID) bool + Close() error +} + +type MultiAgent struct { + mu sync.RWMutex + + ID uuid.UUID + Logger slog.Logger + + AgentIsLegacyFunc func(agentID uuid.UUID) bool + OnNodeUpdate func(id uuid.UUID, agents []uuid.UUID, node *Node) error + OnClose func(id uuid.UUID) + + updates chan AgentNode + subscribedAgents map[uuid.UUID]struct{} +} + +type AgentNode struct { + AgentID uuid.UUID + *Node +} + +func (m *MultiAgent) Init() *MultiAgent { + m.updates = make(chan AgentNode, 128) + m.subscribedAgents = map[uuid.UUID]struct{}{} + return m +} + +func (m *MultiAgent) AgentIsLegacy(agentID uuid.UUID) bool { + return m.AgentIsLegacyFunc(agentID) +} + +func (m *MultiAgent) OnAgentUpdate(id uuid.UUID, node *Node) { + m.mu.RLock() + defer m.mu.RUnlock() + + if _, ok := m.subscribedAgents[id]; !ok { + return + } + + select { + case m.updates <- AgentNode{AgentID: id, Node: node}: + default: + m.Logger.Debug(context.Background(), "unable to send node %q to multiagent %q; buffer full", id, m.ID) + } +} + +func (m *MultiAgent) UpdateSelf(node *Node) error { + m.mu.Lock() + agents := make([]uuid.UUID, 0, len(m.subscribedAgents)) + for agent := range m.subscribedAgents { + agents = append(agents, agent) + } + m.mu.Unlock() + + return m.OnNodeUpdate(m.ID, agents, node) +} + +func (m 
*MultiAgent) SubscribeAgent(agentID uuid.UUID, node *Node) error { + m.mu.Lock() + m.subscribedAgents[agentID] = struct{}{} + m.mu.Unlock() + + return m.OnNodeUpdate(m.ID, []uuid.UUID{agentID}, node) +} + +func (m *MultiAgent) UnsubscribeAgent(agentID uuid.UUID) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.subscribedAgents, agentID) +} + +func (m *MultiAgent) NextUpdate(ctx context.Context) []AgentNode { + var nodes []AgentNode + +loop: + // Read all buffered nodes. + for { + select { + case <-ctx.Done(): + return nil + + case node := <-m.updates: + nodes = append(nodes, node) + + default: + break loop + } + } + + return nodes +} + +func (m *MultiAgent) Close() error { + m.mu.Lock() + close(m.updates) + m.mu.Unlock() + + m.OnClose(m.ID) + + return nil +} From 2e201dc44e24c81dffe56b72e057f1f735e561b4 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Wed, 28 Jun 2023 06:46:40 +0000 Subject: [PATCH 07/19] fixup! address review comments --- agent/agent.go | 10 --- agent/agent_test.go | 17 ++--- coderd/tailnet.go | 94 ++++---------------------- coderd/tailnet_test.go | 10 +-- coderd/workspaceagents.go | 8 +-- coderd/wsconncache/wsconncache_test.go | 9 +-- codersdk/workspaceagentconn.go | 83 +++++------------------ codersdk/workspaceagents.go | 7 -- enterprise/tailnet/coordinator.go | 33 ++++++++- tailnet/coordinator.go | 7 +- tailnet/multiagent.go | 3 +- 11 files changed, 76 insertions(+), 205 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 58c8e470ebb93..ccb3e4b59249d 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -112,16 +112,6 @@ func New(options Options) Agent { prometheusRegistry = prometheus.NewRegistry() } - if len(options.Addresses) == 0 { - options.Addresses = []netip.Prefix{ - // This is the IP that should be used primarily. - netip.PrefixFrom(tailnet.IP(), 128), - // We also listen on the legacy codersdk.WorkspaceAgentIP. This - // allows for a transition away from wsconncache. 
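To make the new MultiAgentConn surface concrete, here is a rough consumer-side sketch against the shape it has at this point in the series: SubscribeAgent still takes the caller's node, and NextUpdate returns buffered AgentNode values without blocking when the buffer is empty, so a real caller drives it from a loop or ticker. Everything outside the interface methods is illustrative:

package example // sketch only, not part of this patch

import (
	"context"

	"github.com/google/uuid"
	"golang.org/x/xerrors"

	"github.com/coder/coder/tailnet"
)

// syncAgentOnce subscribes to a single agent and drains whatever node
// updates are currently buffered into the local tailnet conn.
func syncAgentOnce(ctx context.Context, conn *tailnet.Conn, ma tailnet.MultiAgentConn, agentID uuid.UUID) error {
	if err := ma.SubscribeAgent(agentID, conn.Node()); err != nil {
		return xerrors.Errorf("subscribe agent: %w", err)
	}
	nodes := make([]*tailnet.Node, 0)
	for _, update := range ma.NextUpdate(ctx) {
		nodes = append(nodes, update.Node)
	}
	if len(nodes) == 0 {
		return nil
	}
	return conn.UpdateNodes(nodes, false)
}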
- netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128), - } - } - ctx, cancelFunc := context.WithCancel(context.Background()) a := &agent{ tailnetListenPort: options.TailnetListenPort, diff --git a/agent/agent_test.go b/agent/agent_test.go index 32c63ad5d8155..4e794ac87f6eb 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -1691,14 +1691,16 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati if metadata.DERPMap == nil { metadata.DERPMap, _ = tailnettest.RunDERPAndSTUN(t) } + if metadata.AgentID == uuid.Nil { + metadata.AgentID = uuid.New() + } coordinator := tailnet.NewCoordinator(logger) t.Cleanup(func() { _ = coordinator.Close() }) - agentID := uuid.New() statsCh := make(chan *agentsdk.Stats, 50) fs := afero.NewMemMapFs() - c := agenttest.NewClient(t, agentID, metadata, statsCh, coordinator) + c := agenttest.NewClient(t, metadata.AgentID, metadata, statsCh, coordinator) options := agent.Options{ Client: c, @@ -1731,21 +1733,14 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati }) go func() { defer close(serveClientDone) - coordinator.ServeClient(serverConn, uuid.New(), agentID) + coordinator.ServeClient(serverConn, uuid.New(), metadata.AgentID) }() sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { return conn.UpdateNodes(node, false) }) conn.SetNodeCallback(sendNode) agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ - AgentID: agentID, - GetNode: func(agentID uuid.UUID) (*tailnet.Node, error) { - node := coordinator.Node(agentID) - if node == nil { - return nil, xerrors.Errorf("node not found %q", err) - } - return node, nil - }, + AgentID: metadata.AgentID, }) t.Cleanup(func() { _ = agentConn.Close() diff --git a/coderd/tailnet.go b/coderd/tailnet.go index f400da5b3ffe9..e599ba1e6f301 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -114,9 +114,8 @@ func (s *ServerTailnet) watchAgentUpdates() { s.nodesMu.Lock() for _, node := range nodes { - cached, ok := s.agentNodes[node.AgentID] + _, ok := s.agentNodes[node.AgentID] if ok { - cached.node = node.Node toUpdate = append(toUpdate, node.Node) } } @@ -133,7 +132,6 @@ func (s *ServerTailnet) watchAgentUpdates() { } type tailnetNode struct { - node *tailnet.Node lastConnection time.Time } @@ -192,97 +190,32 @@ func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) ( return s.DialAgentNetConn(ctx, agentID, network, addr) } -func (s *ServerTailnet) getNode(agentID uuid.UUID) (*tailnet.Node, error) { +func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error { s.nodesMu.Lock() tnode, ok := s.agentNodes[agentID] - // If we don't have the node, fetch it from the coordinator. + // If we don't have the node, subscribe. if !ok { - coord := *s.coordinator.Load() - node := coord.Node(agentID) - // The coordinator doesn't have the node either. Nothing we can do here. - if node == nil { - s.nodesMu.Unlock() - return nil, xerrors.Errorf("node %q not found", agentID.String()) - } - err := s.agentConn.SubscribeAgent(agentID, s.conn.Node()) if err != nil { - return nil, xerrors.Errorf("subscribe agent: %w", err) + return xerrors.Errorf("subscribe agent: %w", err) } tnode = &tailnetNode{ - node: node, lastConnection: time.Now(), } s.agentNodes[agentID] = tnode } s.nodesMu.Unlock() - if len(tnode.node.Addresses) == 0 { - return nil, xerrors.New("agent has no reachable addresses") - } - - // if we didn't already have the node locally, add it to our tailnet. 
- if !ok { - err := s.conn.UpdateNodes([]*tailnet.Node{tnode.node}, false) - if err != nil { - return nil, xerrors.Errorf("update nodes: %w", err) - } - } - - return tnode.node, nil -} - -func (s *ServerTailnet) awaitNodeExists(ctx context.Context, id uuid.UUID, timeout time.Duration) (*tailnet.Node, error) { - // Short circuit, if the node already exists, don't spend time setting up - // the ticker and loop. - if node, err := s.getNode(id); err == nil { - return node, nil - } - - var ( - ticker = time.NewTicker(10 * time.Millisecond) - - tries int - node *tailnet.Node - err error - ) - defer ticker.Stop() - - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - for { - select { - case <-ctx.Done(): - // return the last error we got from getNode. - return nil, xerrors.Errorf("tries %d, last error: %w", tries, err) - case <-ticker.C: - } - - tries++ - node, err = s.getNode(id) - if err == nil { - return node, nil - } - } -} - -func (*ServerTailnet) nodeIsLegacy(node *tailnet.Node) bool { - return node.Addresses[0].Addr() == codersdk.WorkspaceAgentIP + return nil } func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) { - node, err := s.awaitNodeExists(ctx, agentID, 5*time.Second) - if err != nil { - return nil, nil, xerrors.Errorf("get agent node: %w", err) - } - var ( conn *codersdk.WorkspaceAgentConn ret = func() {} ) - if s.nodeIsLegacy(node) { + if s.agentConn.AgentIsLegacy(agentID) { cconn, release, err := s.cache.Acquire(agentID) if err != nil { return nil, nil, xerrors.Errorf("acquire legacy agent conn: %w", err) @@ -291,9 +224,13 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*code conn = cconn.WorkspaceAgentConn ret = release } else { + err := s.ensureAgent(agentID) + if err != nil { + return nil, nil, xerrors.Errorf("ensure agent: %w", err) + } + conn = codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{ AgentID: agentID, - GetNode: s.getNode, CloseFunc: func() error { return codersdk.ErrSkipClose }, }) } @@ -320,16 +257,9 @@ func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, // Since we now have an open conn, be careful to close it if we error // without returning it to the user. 
- node, err := s.getNode(agentID) - if err != nil { - release() - conn.Close() - return nil, xerrors.New("get agent node") - } - _, rawPort, _ := net.SplitHostPort(addr) port, _ := strconv.ParseUint(rawPort, 10, 16) - ipp := netip.AddrPortFrom(node.Addresses[0].Addr(), uint16(port)) + ipp := netip.AddrPortFrom(tailnet.IPFromUUID(agentID), uint16(port)) var nc net.Conn switch network { diff --git a/coderd/tailnet_test.go b/coderd/tailnet_test.go index e89023a3d2734..a4dd9a0cf78b6 100644 --- a/coderd/tailnet_test.go +++ b/coderd/tailnet_test.go @@ -15,7 +15,6 @@ import ( "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/xerrors" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" @@ -179,14 +178,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A }) conn.SetNodeCallback(sendNode) return codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ - AgentID: agentID, - GetNode: func(agentID uuid.UUID) (*tailnet.Node, error) { - node := coordinator.Node(agentID) - if node == nil { - return nil, xerrors.Errorf("node not found %q", agentID) - } - return node, nil - }, + AgentID: agentID, CloseFunc: func() error { return codersdk.ErrSkipClose }, }), nil }, 0) diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index c45cbb2bdf5f7..074080e8a9801 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -770,13 +770,7 @@ func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspa conn.SetNodeCallback(sendNodes) agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ AgentID: agentID, - GetNode: func(agentID uuid.UUID) (*tailnet.Node, error) { - return &tailnet.Node{ - // Since this is a legacy function only used by wsconncache as a - // fallback, we hardcode the node to use the wsconncache IP. 
- Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, - }, nil - }, + IP: codersdk.WorkspaceAgentIP, CloseFunc: func() error { cancel() _ = clientConn.Close() diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index b755da025a8f4..520077fa6ccb8 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ -20,7 +20,6 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/atomic" "go.uber.org/goleak" - "golang.org/x/xerrors" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" @@ -197,13 +196,7 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati conn.SetNodeCallback(sendNode) agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ AgentID: agentID, - GetNode: func(agentID uuid.UUID) (*tailnet.Node, error) { - node := coordinator.Node(agentID) - if node == nil { - return nil, xerrors.Errorf("node not found %q", agentID) - } - return node, nil - }, + IP: codersdk.WorkspaceAgentIP, }) t.Cleanup(func() { _ = agentConn.Close() diff --git a/codersdk/workspaceagentconn.go b/codersdk/workspaceagentconn.go index 3bdf98c0dee6f..5e097632d44d6 100644 --- a/codersdk/workspaceagentconn.go +++ b/codersdk/workspaceagentconn.go @@ -148,17 +148,16 @@ type WorkspaceAgentConn struct { // @typescript-ignore WorkspaceAgentConnOptions type WorkspaceAgentConnOptions struct { AgentID uuid.UUID - GetNode func(agentID uuid.UUID) (*tailnet.Node, error) + IP netip.Addr CloseFunc func() error } -func (c *WorkspaceAgentConn) getAgentAddress() (netip.Addr, error) { - node, err := c.opts.GetNode(c.opts.AgentID) - if err != nil { - return netip.Addr{}, err +func (c *WorkspaceAgentConn) agentAddress() netip.Addr { + if c.opts.IP.Compare(netip.IPv6Unspecified()) == 0 { + return c.opts.IP } - return node.Addresses[0].Addr(), nil + return tailnet.IPFromUUID(c.opts.AgentID) } // AwaitReachable waits for the agent to be reachable. @@ -166,28 +165,7 @@ func (c *WorkspaceAgentConn) AwaitReachable(ctx context.Context) bool { ctx, span := tracing.StartSpan(ctx) defer span.End() - var ( - addr netip.Addr - err error - ticker = time.NewTicker(10 * time.Millisecond) - ) - defer ticker.Stop() - - for { - addr, err = c.getAgentAddress() - if err == nil { - break - } - - select { - case <-ctx.Done(): - return false - case <-ticker.C: - continue - } - } - - return c.Conn.AwaitReachable(ctx, addr) + return c.Conn.AwaitReachable(ctx, c.agentAddress()) } // Ping pings the agent and returns the round-trip time. @@ -196,12 +174,7 @@ func (c *WorkspaceAgentConn) Ping(ctx context.Context) (time.Duration, bool, *ip ctx, span := tracing.StartSpan(ctx) defer span.End() - addr, err := c.getAgentAddress() - if err != nil { - return 0, false, nil, err - } - - return c.Conn.Ping(ctx, addr) + return c.Conn.Ping(ctx, c.agentAddress()) } // Close ends the connection to the workspace agent. 
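The net effect of the agentAddress change above is that a WorkspaceAgentConn no longer needs a node lookup to know where to dial: either the caller pins the legacy IP explicitly, or the address is derived from the agent's UUID. A sketch of the two addressing modes side by side, assuming the codersdk and tailnet packages from this repo (the IP option shown here is renamed AgentIP later in the series, and the function name is illustrative):

package example // sketch only, not part of this patch

import (
	"github.com/google/uuid"

	"github.com/coder/coder/codersdk"
	"github.com/coder/coder/tailnet"
)

// newConns shows both addressing modes on the same underlying conn.
func newConns(conn *tailnet.Conn, agentID uuid.UUID) (legacy, direct *codersdk.WorkspaceAgentConn) {
	// Legacy path (wsconncache): pin the shared IP explicitly.
	legacy = codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
		AgentID: agentID,
		IP:      codersdk.WorkspaceAgentIP,
	})
	// Direct path: leave the IP unset and the conn dials the address derived
	// from the agent UUID via tailnet.IPFromUUID.
	direct = codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
		AgentID: agentID,
	})
	return legacy, direct
}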
@@ -248,12 +221,7 @@ func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - addr, err := c.getAgentAddress() - if err != nil { - return nil, err - } - - conn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(addr, WorkspaceAgentReconnectingPTYPort)) + conn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), WorkspaceAgentReconnectingPTYPort)) if err != nil { return nil, err } @@ -288,12 +256,7 @@ func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) { return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - addr, err := c.getAgentAddress() - if err != nil { - return nil, err - } - - return c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(addr, WorkspaceAgentSSHPort)) + return c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), WorkspaceAgentSSHPort)) } // SSHClient calls SSH to create a client that uses a weak cipher @@ -329,12 +292,7 @@ func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest. return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - addr, err := c.getAgentAddress() - if err != nil { - return nil, err - } - - speedConn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(addr, WorkspaceAgentSpeedtestPort)) + speedConn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), WorkspaceAgentSpeedtestPort)) if err != nil { return nil, xerrors.Errorf("dial speedtest: %w", err) } @@ -357,20 +315,16 @@ func (c *WorkspaceAgentConn) DialContext(ctx context.Context, network string, ad return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - agentAddr, err := c.getAgentAddress() - if err != nil { - return nil, err - } - _, rawPort, _ := net.SplitHostPort(addr) port, _ := strconv.ParseUint(rawPort, 10, 16) - ipp := netip.AddrPortFrom(agentAddr, uint16(port)) + ipp := netip.AddrPortFrom(c.agentAddress(), uint16(port)) - if network == "tcp" { + switch network { + case "tcp": return c.Conn.DialContextTCP(ctx, ipp) - } else if network == "udp" { + case "udp": return c.Conn.DialContextUDP(ctx, ipp) - } else { + default: return nil, xerrors.Errorf("unknown network %q", network) } } @@ -411,12 +365,7 @@ func (c *WorkspaceAgentConn) apiRequest(ctx context.Context, method, path string ctx, span := tracing.StartSpan(ctx) defer span.End() - agentAddr, err := c.getAgentAddress() - if err != nil { - return nil, xerrors.Errorf("get agent address: %w", err) - } - - host := net.JoinHostPort(agentAddr.String(), strconv.Itoa(WorkspaceAgentHTTPAPIServerPort)) + host := net.JoinHostPort(c.agentAddress().String(), strconv.Itoa(WorkspaceAgentHTTPAPIServerPort)) url := fmt.Sprintf("http://%s%s", host, path) req, err := http.NewRequestWithContext(ctx, method, url, body) diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 763ea773494a1..22d4e7699217e 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -321,13 +321,6 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti agentConn = NewWorkspaceAgentConn(conn, WorkspaceAgentConnOptions{ AgentID: agentID, - GetNode: func(agentID uuid.UUID) (*tailnet.Node, error) { - node := latestNode.Load() - if node == nil { - return nil, xerrors.New("node not found") - } - return node, nil - }, CloseFunc: func() error { cancel() <-closed diff --git a/enterprise/tailnet/coordinator.go 
b/enterprise/tailnet/coordinator.go index 8071d024b0bad..9bbf5bf9aac01 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -58,6 +58,7 @@ func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { ID: id, Logger: c.log, AgentIsLegacyFunc: c.agentIsLegacy, + OnSubscribe: c.multiAgentSubscribe, OnNodeUpdate: c.multiAgentUpdate, OnClose: c.removeMultiAgent, }).Init() @@ -77,6 +78,36 @@ func (c *haCoordinator) removeMultiAgent(id uuid.UUID) { c.mutex.Unlock() } +func (c *haCoordinator) multiAgentSubscribe(id, agent uuid.UUID, node *agpl.Node) error { + c.mutex.Lock() + + agentNode, ok := c.nodes[agent] + // If we have the node locally, publish it immediately to the multiagent. + if ok { + multiAgent, ok := c.multiAgents[id] + if !ok { + return xerrors.Errorf("unknown multi agent %q", id) + } + + c.mutex.Unlock() + multiAgent.OnAgentUpdate(agent, agentNode) + } else { + // If we don't have the node locally, notify other coordinators. + c.mutex.Unlock() + err := c.publishClientHello(agent) + if err != nil { + return xerrors.Errorf("publish client hello: %w", err) + } + } + + err := c.handleClientUpdate(id, agent, node) + if err != nil { + return xerrors.Errorf("handle client update: %w", err) + } + + return nil +} + func (c *haCoordinator) multiAgentUpdate(id uuid.UUID, agents []uuid.UUID, node *agpl.Node) error { var errs *multierror.Error // This isn't the most efficient, but this coordinator is being deprecated @@ -136,7 +167,7 @@ func (c *haCoordinator) agentLogger(agent uuid.UUID) slog.Logger { // ServeClient accepts a WebSocket connection that wants to connect to an agent // with the specified ID. -func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { +func (c *haCoordinator) ServeClient(conn net.Conn, id, agent uuid.UUID) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() logger := c.clientLogger(id, agent) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index a1f17316308ff..69eea4f19e9df 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -145,8 +145,11 @@ func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn { ID: id, Logger: c.core.logger, AgentIsLegacyFunc: c.core.agentIsLegacy, - OnNodeUpdate: c.core.multiAgentUpdate, - OnClose: c.core.removeMultiAgent, + OnSubscribe: func(id, agent uuid.UUID, node *Node) error { + return c.core.multiAgentUpdate(id, []uuid.UUID{agent}, node) + }, + OnNodeUpdate: c.core.multiAgentUpdate, + OnClose: c.core.removeMultiAgent, }).Init() c.core.addMultiAgent(m) return m diff --git a/tailnet/multiagent.go b/tailnet/multiagent.go index 3e9f39153ebc2..511eb785ae46e 100644 --- a/tailnet/multiagent.go +++ b/tailnet/multiagent.go @@ -25,6 +25,7 @@ type MultiAgent struct { Logger slog.Logger AgentIsLegacyFunc func(agentID uuid.UUID) bool + OnSubscribe func(id uuid.UUID, agent uuid.UUID, node *Node) error OnNodeUpdate func(id uuid.UUID, agents []uuid.UUID, node *Node) error OnClose func(id uuid.UUID) @@ -78,7 +79,7 @@ func (m *MultiAgent) SubscribeAgent(agentID uuid.UUID, node *Node) error { m.subscribedAgents[agentID] = struct{}{} m.mu.Unlock() - return m.OnNodeUpdate(m.ID, []uuid.UUID{agentID}, node) + return m.OnSubscribe(m.ID, agentID, node) } func (m *MultiAgent) UnsubscribeAgent(agentID uuid.UUID) { From 311ea2b2e0217ebd50a18321ac318f6e58709315 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Wed, 28 Jun 2023 21:28:10 +0000 Subject: [PATCH 08/19] second round --- agent/agent.go | 4 + 
agent/agenttest/client.go | 3 + coderd/coderd.go | 2 +- coderd/coderdtest/coderdtest.go | 1 + coderd/tailnet.go | 86 +++++------ coderd/tailnet_test.go | 23 +-- coderd/workspaceagents.go | 2 +- coderd/wsconncache/wsconncache_test.go | 11 +- codersdk/workspaceagentconn.go | 7 +- codersdk/workspaceagents.go | 7 - enterprise/tailnet/coordinator.go | 141 +++++++++--------- scaletest/reconnectingpty/run_test.go | 9 +- tailnet/coordinator.go | 190 +++++++++++++++---------- tailnet/multiagent.go | 126 ++++++++++------ 14 files changed, 339 insertions(+), 273 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index ccb3e4b59249d..20f6bd15dc83f 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -510,6 +510,10 @@ func (a *agent) run(ctx context.Context) error { } a.logger.Info(ctx, "fetched manifest", slog.F("manifest", manifest)) + if manifest.AgentID == uuid.Nil { + return xerrors.New("nil agentID returned by manifest") + } + // Expand the directory and send it back to coderd so external // applications that rely on the directory can use it. // diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go index c1daa0bfacc9a..0ca1640e79d2f 100644 --- a/agent/agenttest/client.go +++ b/agent/agenttest/client.go @@ -23,6 +23,9 @@ func NewClient(t testing.TB, statsChan chan *agentsdk.Stats, coordinator tailnet.Coordinator, ) *Client { + if manifest.AgentID == uuid.Nil { + manifest.AgentID = agentID + } return &Client{ t: t, agentID: agentID, diff --git a/coderd/coderd.go b/coderd/coderd.go index 43a77676b2688..32d29145343c4 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -356,7 +356,7 @@ func New(options *Options) *API { options.Logger, options.DERPServer, options.DERPMap, - &api.TailnetCoordinator, + *api.TailnetCoordinator.Load(), wsconncache.New(api._dialWorkspaceAgentTailnet, 0), ) if err != nil { diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index f4bf035311e6a..d073b48824bd4 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -109,6 +109,7 @@ type Options struct { GitAuthConfigs []*gitauth.Config TrialGenerator func(context.Context, string) error TemplateScheduleStore schedule.TemplateScheduleStore + Coordinator tailnet.Coordinator HealthcheckFunc func(ctx context.Context, apiKey string) *healthcheck.Report HealthcheckTimeout time.Duration diff --git a/coderd/tailnet.go b/coderd/tailnet.go index e599ba1e6f301..a2ffe7f4dc0e1 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -8,9 +8,7 @@ import ( "net/http/httputil" "net/netip" "net/url" - "strconv" "sync" - "sync/atomic" "time" "github.com/google/uuid" @@ -44,37 +42,41 @@ func NewServerTailnet( logger slog.Logger, derpServer *derp.Server, derpMap *tailcfg.DERPMap, - coord *atomic.Pointer[tailnet.Coordinator], + coord tailnet.Coordinator, cache *wsconncache.Cache, ) (*ServerTailnet, error) { + logger = logger.Named("servertailnet") conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, DERPMap: derpMap, - Logger: logger.Named("tailnet"), + Logger: logger, }) if err != nil { return nil, xerrors.Errorf("create tailnet conn: %w", err) } id := uuid.New() - ma := (*coord.Load()).ServeMultiAgent(id) + ma := coord.ServeMultiAgent(id) serverCtx, cancel := context.WithCancel(ctx) tn := &ServerTailnet{ - ctx: serverCtx, - cancel: cancel, - logger: logger, - conn: conn, - coordinator: coord, - agentConn: ma, - cache: cache, - agentNodes: map[uuid.UUID]*tailnetNode{}, - transport: tailnetTransport.Clone(), + 
ctx: serverCtx, + cancel: cancel, + logger: logger, + conn: conn, + agentConn: ma, + cache: cache, + agentNodes: map[uuid.UUID]*tailnetNode{}, + transport: tailnetTransport.Clone(), } tn.transport.DialContext = tn.dialContext tn.transport.MaxIdleConnsPerHost = 10 tn.transport.MaxIdleConns = 0 + err = ma.UpdateSelf(conn.Node()) + if err != nil { + tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err)) + } conn.SetNodeCallback(func(node *tailnet.Node) { err := tn.agentConn.UpdateSelf(node) if err != nil { @@ -110,41 +112,28 @@ func (s *ServerTailnet) watchAgentUpdates() { return } - toUpdate := make([]*tailnet.Node, 0) - - s.nodesMu.Lock() - for _, node := range nodes { - _, ok := s.agentNodes[node.AgentID] - if ok { - toUpdate = append(toUpdate, node.Node) - } - } - s.nodesMu.Unlock() - - if len(toUpdate) > 0 { - err := s.conn.UpdateNodes(toUpdate, false) - if err != nil { - s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err)) - return - } + err := s.conn.UpdateNodes(nodes, false) + if err != nil { + s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err)) + return } } } type tailnetNode struct { lastConnection time.Time + close func() } type ServerTailnet struct { ctx context.Context cancel func() - logger slog.Logger - conn *tailnet.Conn - coordinator *atomic.Pointer[tailnet.Coordinator] - agentConn tailnet.MultiAgentConn - cache *wsconncache.Cache - nodesMu sync.Mutex + logger slog.Logger + conn *tailnet.Conn + agentConn tailnet.MultiAgentConn + cache *wsconncache.Cache + nodesMu sync.Mutex // agentNodes is a map of agent tailnetNodes the server wants to keep a // connection to. agentNodes map[uuid.UUID]*tailnetNode @@ -195,14 +184,18 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error { tnode, ok := s.agentNodes[agentID] // If we don't have the node, subscribe. if !ok { - err := s.agentConn.SubscribeAgent(agentID, s.conn.Node()) + s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID)) + closer, err := s.agentConn.SubscribeAgent(agentID) if err != nil { return xerrors.Errorf("subscribe agent: %w", err) } tnode = &tailnetNode{ lastConnection: time.Now(), + close: closer, } s.agentNodes[agentID] = tnode + } else { + tnode.lastConnection = time.Now() } s.nodesMu.Unlock() @@ -216,6 +209,7 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*code ) if s.agentConn.AgentIsLegacy(agentID) { + s.logger.Debug(s.ctx, "acquiring legacy agent", slog.F("agent_id", agentID)) cconn, release, err := s.cache.Acquire(agentID) if err != nil { return nil, nil, xerrors.Errorf("acquire legacy agent conn: %w", err) @@ -229,6 +223,7 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*code return nil, nil, xerrors.Errorf("ensure agent: %w", err) } + s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID)) conn = codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{ AgentID: agentID, CloseFunc: func() error { return codersdk.ErrSkipClose }, @@ -257,20 +252,11 @@ func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, // Since we now have an open conn, be careful to close it if we error // without returning it to the user. 
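For callers, the important part of AgentConn above is the release func: on the legacy path it returns the wsconncache connection to the cache, and on the direct path it is a no-op, so it must always be called. A small sketch, assuming the coderd package from this repo (the function name is illustrative):

package example // sketch only, not part of this patch

import (
	"context"

	"github.com/google/uuid"
	"golang.org/x/xerrors"

	"github.com/coder/coder/coderd"
)

// pingAgent acquires a conn for the agent, pings it, and releases the conn.
func pingAgent(ctx context.Context, tn *coderd.ServerTailnet, agentID uuid.UUID) error {
	conn, release, err := tn.AgentConn(ctx, agentID)
	if err != nil {
		return xerrors.Errorf("agent conn: %w", err)
	}
	defer release()

	_, _, _, err = conn.Ping(ctx)
	if err != nil {
		return xerrors.Errorf("ping agent: %w", err)
	}
	return nil
}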
- _, rawPort, _ := net.SplitHostPort(addr) - port, _ := strconv.ParseUint(rawPort, 10, 16) - ipp := netip.AddrPortFrom(tailnet.IPFromUUID(agentID), uint16(port)) - - var nc net.Conn - switch network { - case "tcp": - nc, err = conn.DialContextTCP(ctx, ipp) - case "udp": - nc, err = conn.DialContextUDP(ctx, ipp) - default: + nc, err := conn.DialContext(ctx, network, addr) + if err != nil { release() conn.Close() - return nil, xerrors.Errorf("unknown network %q", network) + return nil, xerrors.Errorf("dial context: %w", err) } return &netConnCloser{Conn: nc, close: func() { diff --git a/coderd/tailnet_test.go b/coderd/tailnet_test.go index a4dd9a0cf78b6..159dc6b2079e7 100644 --- a/coderd/tailnet_test.go +++ b/coderd/tailnet_test.go @@ -8,7 +8,6 @@ import ( "net/http/httptest" "net/netip" "net/url" - "sync/atomic" "testing" "github.com/google/uuid" @@ -129,18 +128,16 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) derpMap, derpServer := tailnettest.RunDERPAndSTUN(t) manifest := agentsdk.Manifest{ + AgentID: uuid.New(), DERPMap: derpMap, } - var coordPtr atomic.Pointer[tailnet.Coordinator] coordinator := tailnet.NewCoordinator(logger) - coordPtr.Store(&coordinator) t.Cleanup(func() { _ = coordinator.Close() }) - agentID := uuid.New() - c := agenttest.NewClient(t, agentID, manifest, make(chan *agentsdk.Stats, 50), coordinator) + c := agenttest.NewClient(t, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coordinator) options := agent.Options{ Client: c, @@ -154,6 +151,11 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A _ = ag.Close() }) + // Wait for the agent to connect. + require.Eventually(t, func() bool { + return coordinator.Node(manifest.AgentID) != nil + }, testutil.WaitShort, testutil.IntervalFast) + cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, @@ -171,24 +173,25 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A }) go func() { defer close(serveClientDone) - coordinator.ServeClient(serverConn, uuid.New(), agentID) + coordinator.ServeClient(serverConn, uuid.New(), manifest.AgentID) }() sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { return conn.UpdateNodes(node, false) }) conn.SetNodeCallback(sendNode) return codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ - AgentID: agentID, + AgentID: manifest.AgentID, + AgentIP: codersdk.WorkspaceAgentIP, CloseFunc: func() error { return codersdk.ErrSkipClose }, }), nil }, 0) serverTailnet, err := coderd.NewServerTailnet( context.Background(), - logger.Named("server"), + logger, derpServer, manifest.DERPMap, - &coordPtr, + coordinator, cache, ) require.NoError(t, err) @@ -197,5 +200,5 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A _ = serverTailnet.Close() }) - return agentID, ag, serverTailnet + return manifest.AgentID, ag, serverTailnet } diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 074080e8a9801..5775212102b80 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -770,7 +770,7 @@ func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspa conn.SetNodeCallback(sendNodes) agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ AgentID: 
agentID, - IP: codersdk.WorkspaceAgentIP, + AgentIP: codersdk.WorkspaceAgentIP, CloseFunc: func() error { cancel() _ = clientConn.Close() diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index 520077fa6ccb8..f677c52c7ea4f 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ -163,16 +163,17 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati t.Cleanup(func() { _ = coordinator.Close() }) - agentID := uuid.New() + manifest.AgentID = uuid.New() closer := agent.New(agent.Options{ Client: &client{ t: t, - agentID: agentID, + agentID: manifest.AgentID, manifest: manifest, coordinator: coordinator, }, Logger: logger.Named("agent"), ReconnectingPTYTimeout: ptyTimeout, + Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, }) t.Cleanup(func() { _ = closer.Close() @@ -189,14 +190,14 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati _ = serverConn.Close() _ = conn.Close() }) - go coordinator.ServeClient(serverConn, uuid.New(), agentID) + go coordinator.ServeClient(serverConn, uuid.New(), manifest.AgentID) sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { return conn.UpdateNodes(node, false) }) conn.SetNodeCallback(sendNode) agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ - AgentID: agentID, - IP: codersdk.WorkspaceAgentIP, + AgentID: manifest.AgentID, + AgentIP: codersdk.WorkspaceAgentIP, }) t.Cleanup(func() { _ = agentConn.Close() diff --git a/codersdk/workspaceagentconn.go b/codersdk/workspaceagentconn.go index 5e097632d44d6..6b9b6f0d33f44 100644 --- a/codersdk/workspaceagentconn.go +++ b/codersdk/workspaceagentconn.go @@ -148,13 +148,14 @@ type WorkspaceAgentConn struct { // @typescript-ignore WorkspaceAgentConnOptions type WorkspaceAgentConnOptions struct { AgentID uuid.UUID - IP netip.Addr + AgentIP netip.Addr CloseFunc func() error } func (c *WorkspaceAgentConn) agentAddress() netip.Addr { - if c.opts.IP.Compare(netip.IPv6Unspecified()) == 0 { - return c.opts.IP + var emptyIP netip.Addr + if cmp := c.opts.AgentIP.Compare(emptyIP); cmp != 0 { + return c.opts.AgentIP } return tailnet.IPFromUUID(c.opts.AgentID) diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 22d4e7699217e..508b8a9bb8e77 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -11,7 +11,6 @@ import ( "net/http/cookiejar" "net/netip" "strconv" - "sync/atomic" "time" "github.com/google/uuid" @@ -263,7 +262,6 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti }() closed := make(chan struct{}) first := make(chan error) - var latestNode atomic.Pointer[tailnet.Node] go func() { defer close(closed) isFirst := true @@ -292,11 +290,6 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti continue } sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(nodes []*tailnet.Node) error { - if len(nodes) != 1 { - options.Logger.Warn(ctx, "incorrect number of nodes returned from ServeCoordinator", slog.F("len", len(nodes))) - return nil - } - latestNode.Store(nodes[0]) return conn.UpdateNodes(nodes, false) }) conn.SetNodeCallback(sendNode) diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index 9bbf5bf9aac01..8aa2f5579104e 100644 --- a/enterprise/tailnet/coordinator.go +++ 
b/enterprise/tailnet/coordinator.go @@ -39,11 +39,10 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err closeFunc: cancelFunc, close: make(chan struct{}), nodes: map[uuid.UUID]*agpl.Node{}, - agentSockets: map[uuid.UUID]*agpl.TrackedConn{}, - agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*agpl.TrackedConn{}, + agentSockets: map[uuid.UUID]agpl.Enqueueable{}, + agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]agpl.Enqueueable{}, agentNameCache: nameCache, legacyAgents: map[uuid.UUID]struct{}{}, - multiAgents: map[uuid.UUID]*agpl.MultiAgent{}, } if err := coord.runPubsub(ctx); err != nil { @@ -60,52 +59,39 @@ func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { AgentIsLegacyFunc: c.agentIsLegacy, OnSubscribe: c.multiAgentSubscribe, OnNodeUpdate: c.multiAgentUpdate, - OnClose: c.removeMultiAgent, }).Init() - c.addMultiAgent(m) return m } -func (c *haCoordinator) addMultiAgent(m *agpl.MultiAgent) { +func (c *haCoordinator) multiAgentSubscribe(enq agpl.Enqueueable, agentID uuid.UUID) (func(), error) { c.mutex.Lock() - c.multiAgents[m.ID] = m - c.mutex.Unlock() -} - -func (c *haCoordinator) removeMultiAgent(id uuid.UUID) { - c.mutex.Lock() - delete(c.multiAgents, id) - c.mutex.Unlock() -} -func (c *haCoordinator) multiAgentSubscribe(id, agent uuid.UUID, node *agpl.Node) error { - c.mutex.Lock() + node := c.nodes[enq.UniqueID()] - agentNode, ok := c.nodes[agent] + agentNode, ok := c.nodes[agentID] // If we have the node locally, publish it immediately to the multiagent. if ok { - multiAgent, ok := c.multiAgents[id] - if !ok { - return xerrors.Errorf("unknown multi agent %q", id) + err := enq.Enqueue([]*agpl.Node{agentNode}) + if err != nil { + return nil, xerrors.Errorf("enqueue agent on subscribe: %w", err) } - - c.mutex.Unlock() - multiAgent.OnAgentUpdate(agent, agentNode) } else { // If we don't have the node locally, notify other coordinators. c.mutex.Unlock() - err := c.publishClientHello(agent) + err := c.publishClientHello(agentID) if err != nil { - return xerrors.Errorf("publish client hello: %w", err) + return nil, xerrors.Errorf("publish client hello: %w", err) } } - err := c.handleClientUpdate(id, agent, node) - if err != nil { - return xerrors.Errorf("handle client update: %w", err) + if node != nil { + err := c.handleClientUpdate(enq.UniqueID(), agentID, node) + if err != nil { + return nil, xerrors.Errorf("handle client update: %w", err) + } } - return nil + return c.cleanupClientConn(enq.UniqueID(), agentID), nil } func (c *haCoordinator) multiAgentUpdate(id uuid.UUID, agents []uuid.UUID, node *agpl.Node) error { @@ -136,17 +122,16 @@ type haCoordinator struct { // nodes maps agent and connection IDs their respective node. nodes map[uuid.UUID]*agpl.Node // agentSockets maps agent IDs to their open websocket. - agentSockets map[uuid.UUID]*agpl.TrackedConn + agentSockets map[uuid.UUID]agpl.Enqueueable // agentToConnectionSockets maps agent IDs to connection IDs of conns that // are subscribed to updates for that agent. - agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]*agpl.TrackedConn + agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]agpl.Enqueueable // agentNameCache holds a cache of agent names. If one of them disappears, // it's helpful to have a name cached for debugging. agentNameCache *lru.Cache[uuid.UUID, string] legacyAgents map[uuid.UUID]struct{} - multiAgents map[uuid.UUID]*agpl.MultiAgent } // Node returns an in-memory node by ID. 
@@ -173,16 +158,9 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id, agent uuid.UUID) error { logger := c.clientLogger(id, agent) c.mutex.Lock() - connectionSockets, ok := c.agentToConnectionSockets[agent] - if !ok { - connectionSockets = map[uuid.UUID]*agpl.TrackedConn{} - c.agentToConnectionSockets[agent] = connectionSockets - } tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, 0) - // Insert this connection into a map so the agent - // can publish node updates. - connectionSockets[id] = tc + c.initOrSetAgentConnectionSocketLocked(agent, tc) // When a new connection is requested, we update it with the latest // node of the agent. This allows the connection to establish. @@ -202,21 +180,7 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id, agent uuid.UUID) error { } go tc.SendUpdates() - defer func() { - c.mutex.Lock() - defer c.mutex.Unlock() - // Clean all traces of this connection from the map. - delete(c.nodes, id) - connectionSockets, ok := c.agentToConnectionSockets[agent] - if !ok { - return - } - delete(connectionSockets, id) - if len(connectionSockets) != 0 { - return - } - delete(c.agentToConnectionSockets, agent) - }() + defer c.cleanupClientConn(id, agent) decoder := json.NewDecoder(conn) // Indefinitely handle messages from the client websocket. @@ -231,6 +195,33 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id, agent uuid.UUID) error { } } +func (c *haCoordinator) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq agpl.Enqueueable) { + connectionSockets, ok := c.agentToConnectionSockets[agentID] + if !ok { + connectionSockets = map[uuid.UUID]agpl.Enqueueable{} + c.agentToConnectionSockets[agentID] = connectionSockets + } + connectionSockets[enq.UniqueID()] = enq +} + +func (c *haCoordinator) cleanupClientConn(id, agentID uuid.UUID) func() { + return func() { + c.mutex.Lock() + defer c.mutex.Unlock() + // Clean all traces of this connection from the map. + delete(c.nodes, id) + connectionSockets, ok := c.agentToConnectionSockets[agentID] + if !ok { + return + } + delete(connectionSockets, id) + if len(connectionSockets) != 0 { + return + } + delete(c.agentToConnectionSockets, agentID) + } +} + func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error { var node agpl.Node err := decoder.Decode(&node) @@ -268,6 +259,26 @@ func (c *haCoordinator) handleClientUpdate(id, agent uuid.UUID, node *agpl.Node) return nil } +func (c *haCoordinator) handleClientUpdateLocked(id, agent uuid.UUID, node *agpl.Node) error { + agentSocket, ok := c.agentSockets[agent] + if !ok { + c.mutex.Unlock() + // If we don't own the agent locally, send it over pubsub to a node that + // owns the agent. + err := c.publishNodesToAgent(agent, []*agpl.Node{node}) + if err != nil { + return xerrors.Errorf("publish node to agent") + } + return nil + } + err := agentSocket.Enqueue([]*agpl.Node{node}) + c.mutex.Unlock() + if err != nil { + return xerrors.Errorf("enqueue node: %w", err) + } + return nil +} + // ServeAgent accepts a WebSocket connection to an agent that listens to // incoming connections and publishes node updates. func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { @@ -285,7 +296,7 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err // dead. 
oldAgentSocket, ok := c.agentSockets[id] if ok { - overwrites = oldAgentSocket.Overwrites + 1 + overwrites = oldAgentSocket.Overwrites() + 1 _ = oldAgentSocket.Close() } // This uniquely identifies a connection that belongs to this goroutine. @@ -317,7 +328,7 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err // Only delete the connection if it's ours. It could have been // overwritten. - if idConn, ok := c.agentSockets[id]; ok && idConn.ID == unique { + if idConn, ok := c.agentSockets[id]; ok && idConn.UniqueID() == unique { delete(c.agentSockets, id) delete(c.nodes, id) } @@ -407,11 +418,6 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) ( _ = connectionSocket.Enqueue([]*agpl.Node{&node}) } - // Publish the new node to every active multiAgent. - for _, multiAgent := range c.multiAgents { - multiAgent.OnAgentUpdate(id, &node) - } - c.mutex.Unlock() return &node, nil @@ -452,15 +458,6 @@ func (c *haCoordinator) Close() error { } } - wg.Add(len(c.multiAgents)) - for _, multiAgent := range c.multiAgents { - multiAgent := multiAgent - go func() { - _ = multiAgent.Close() - wg.Done() - }() - } - wg.Wait() return nil } diff --git a/scaletest/reconnectingpty/run_test.go b/scaletest/reconnectingpty/run_test.go index f6f70bbf574bf..382a3718436f9 100644 --- a/scaletest/reconnectingpty/run_test.go +++ b/scaletest/reconnectingpty/run_test.go @@ -9,6 +9,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" + "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" "github.com/coder/coder/coderd/coderdtest" @@ -243,7 +244,7 @@ func Test_Runner(t *testing.T) { func setupRunnerTest(t *testing.T) (client *codersdk.Client, agentID uuid.UUID) { t.Helper() - client = coderdtest.New(t, &coderdtest.Options{ + client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ IncludeProvisionerDaemon: true, }) user := coderdtest.CreateFirstUser(t, client) @@ -282,12 +283,16 @@ func setupRunnerTest(t *testing.T) (client *codersdk.Client, agentID uuid.UUID) agentClient.SetSessionToken(authToken) agentCloser := agent.New(agent.Options{ Client: agentClient, - Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("agent"), + Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("agent").Leveled(slog.LevelDebug), }) t.Cleanup(func() { _ = agentCloser.Close() }) resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + require.Eventually(t, func() bool { + t.Log("agent id", resources[0].Agents[0].ID) + return (*api.TailnetCoordinator.Load()).Node(resources[0].Agents[0].ID) != nil + }, testutil.WaitLong, testutil.IntervalMedium, "agent never connected") return client, resources[0].Agents[0].ID } diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 69eea4f19e9df..b25cbe63bc336 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -145,28 +145,13 @@ func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn { ID: id, Logger: c.core.logger, AgentIsLegacyFunc: c.core.agentIsLegacy, - OnSubscribe: func(id, agent uuid.UUID, node *Node) error { - return c.core.multiAgentUpdate(id, []uuid.UUID{agent}, node) - }, - OnNodeUpdate: c.core.multiAgentUpdate, - OnClose: c.core.removeMultiAgent, + OnSubscribe: c.core.multiAgentSubscribe, + OnNodeUpdate: c.core.multiAgentUpdate, + // OnClose: c.core.removeMultiAgent, }).Init() - c.core.addMultiAgent(m) return m } -func (c *core) addMultiAgent(m *MultiAgent) { - c.mutex.Lock() - c.multiAgents[m.ID] = 
m - c.mutex.Unlock() -} - -func (c *core) removeMultiAgent(id uuid.UUID) { - c.mutex.Lock() - delete(c.multiAgents, id) - c.mutex.Unlock() -} - // core is an in-memory structure of Node and TrackedConn mappings. Its methods may be called from multiple goroutines; // it is protected by a mutex to ensure data stay consistent. type core struct { @@ -177,17 +162,25 @@ type core struct { // nodes maps agent and connection IDs their respective node. nodes map[uuid.UUID]*Node // agentSockets maps agent IDs to their open websocket. - agentSockets map[uuid.UUID]*TrackedConn + agentSockets map[uuid.UUID]Enqueueable // agentToConnectionSockets maps agent IDs to connection IDs of conns that // are subscribed to updates for that agent. - agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]*TrackedConn + agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]Enqueueable // agentNameCache holds a cache of agent names. If one of them disappears, // it's helpful to have a name cached for debugging. agentNameCache *lru.Cache[uuid.UUID, string] legacyAgents map[uuid.UUID]struct{} - multiAgents map[uuid.UUID]*MultiAgent +} + +type Enqueueable interface { + UniqueID() uuid.UUID + Enqueue(n []*Node) error + Name() string + Stats() (start, lastWrite int64) + Overwrites() int64 + Close() error } func newCore(logger slog.Logger) *core { @@ -200,11 +193,10 @@ func newCore(logger slog.Logger) *core { logger: logger, closed: false, nodes: map[uuid.UUID]*Node{}, - agentSockets: map[uuid.UUID]*TrackedConn{}, - agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*TrackedConn{}, + agentSockets: map[uuid.UUID]Enqueueable{}, + agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]Enqueueable{}, agentNameCache: nameCache, legacyAgents: map[uuid.UUID]struct{}{}, - multiAgents: map[uuid.UUID]*MultiAgent{}, } } @@ -220,16 +212,16 @@ type TrackedConn struct { // ID is an ephemeral UUID used to uniquely identify the owner of the // connection. - ID uuid.UUID + id uuid.UUID - Name string - Start int64 - LastWrite int64 - Overwrites int64 + name string + start int64 + lastWrite int64 + overwrites int64 } func (t *TrackedConn) Enqueue(n []*Node) (err error) { - atomic.StoreInt64(&t.LastWrite, time.Now().Unix()) + atomic.StoreInt64(&t.lastWrite, time.Now().Unix()) select { case t.updates <- n: return nil @@ -238,6 +230,22 @@ func (t *TrackedConn) Enqueue(n []*Node) (err error) { } } +func (t *TrackedConn) UniqueID() uuid.UUID { + return t.id +} + +func (t *TrackedConn) Name() string { + return t.name +} + +func (t *TrackedConn) Stats() (start, lastWrite int64) { + return t.start, atomic.LoadInt64(&t.lastWrite) +} + +func (t *TrackedConn) Overwrites() int64 { + return t.overwrites +} + // Close the connection and cancel the context for reading node updates from the queue func (t *TrackedConn) Close() error { t.cancel() @@ -315,10 +323,10 @@ func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.U cancel: cancel, updates: updates, logger: logger, - ID: id, - Start: now, - LastWrite: now, - Overwrites: overwrites, + id: id, + start: now, + lastWrite: now, + overwrites: overwrites, } } @@ -420,16 +428,20 @@ func (c *core) initAndTrackClient( // Insert this connection into a map so the agent // can publish node updates. 
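Since the coordinator maps now hold the Enqueueable interface rather than *TrackedConn, anything that can receive batched node updates can be registered, which is exactly how MultiAgent plugs into the same fan-out path as agent websockets. A minimal fake satisfying the interface, the kind of thing a test might use; it is illustrative only and not part of this patch:

package example // sketch only, not part of this patch

import (
	"github.com/google/uuid"
	"golang.org/x/xerrors"

	"github.com/coder/coder/tailnet"
)

// fakeQueue buffers node updates in a channel instead of writing them to a
// websocket the way TrackedConn does.
type fakeQueue struct {
	id    uuid.UUID
	nodes chan []*tailnet.Node
}

func newFakeQueue() *fakeQueue {
	return &fakeQueue{id: uuid.New(), nodes: make(chan []*tailnet.Node, 8)}
}

func (f *fakeQueue) UniqueID() uuid.UUID             { return f.id }
func (f *fakeQueue) Name() string                    { return "fake" }
func (f *fakeQueue) Stats() (start, lastWrite int64) { return 0, 0 }
func (f *fakeQueue) Overwrites() int64               { return 0 }
func (f *fakeQueue) Close() error                    { close(f.nodes); return nil }

func (f *fakeQueue) Enqueue(n []*tailnet.Node) error {
	select {
	case f.nodes <- n:
		return nil
	default:
		return xerrors.New("queue full")
	}
}

var _ tailnet.Enqueueable = (*fakeQueue)(nil)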
- connectionSockets, ok := c.agentToConnectionSockets[agent] - if !ok { - connectionSockets = map[uuid.UUID]*TrackedConn{} - c.agentToConnectionSockets[agent] = connectionSockets - } - connectionSockets[id] = tc + c.initOrSetAgentConnectionSocketLocked(agent, tc) logger.Debug(ctx, "added tracked connection") return tc, nil } +func (c *core) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq Enqueueable) { + connectionSockets, ok := c.agentToConnectionSockets[agentID] + if !ok { + connectionSockets = map[uuid.UUID]Enqueueable{} + c.agentToConnectionSockets[agentID] = connectionSockets + } + connectionSockets[enq.UniqueID()] = enq +} + func (c *core) clientDisconnected(id, agent uuid.UUID) { logger := c.clientLogger(id, agent) c.mutex.Lock() @@ -465,16 +477,17 @@ func (c *core) clientNodeUpdate(id, agent uuid.UUID, node *Node) error { c.mutex.Lock() defer c.mutex.Unlock() - return c.clientNodeUpdateUnlocked(id, agent, node) -} - -func (c *core) clientNodeUpdateUnlocked(id, agent uuid.UUID, node *Node) error { - logger := c.clientLogger(id, agent) // Update the node of this client in our in-memory map. If an agent entirely // shuts down and reconnects, it needs to be aware of all clients attempting // to establish connections. c.nodes[id] = node + return c.clientNodeUpdateLocked(id, agent, node) +} + +func (c *core) clientNodeUpdateLocked(id, agent uuid.UUID, node *Node) error { + logger := c.clientLogger(id, agent) + agentSocket, ok := c.agentSockets[agent] if !ok { logger.Debug(context.Background(), "no agent socket, unable to send node") @@ -489,13 +502,49 @@ func (c *core) clientNodeUpdateUnlocked(id, agent uuid.UUID, node *Node) error { return nil } +func (c *core) multiAgentSubscribe(enq Enqueueable, agentID uuid.UUID) (func(), error) { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.initOrSetAgentConnectionSocketLocked(agentID, enq) + + node, ok := c.nodes[enq.UniqueID()] + if ok { + // If we have the node, send it to the agent. If not, it will be sent + // async. + err := c.clientNodeUpdateLocked(enq.UniqueID(), agentID, node) + if err != nil { + return nil, xerrors.Errorf("send update to agent: %w", err) + } + } else { + c.logger.Debug(context.Background(), "multiagent node doesn't exist", slog.F("multiagent_id", enq.UniqueID())) + } + + closer := func() {} + + agentNode, ok := c.nodes[agentID] + if !ok { + // This is ok, once the agent connects the node will be sent over. + c.logger.Debug(context.Background(), "agent node doesn't exist", slog.F("agent_id", agentID)) + return closer, nil + } + + // Send the subscribed agent back to the multi agent. + err := enq.Enqueue([]*Node{agentNode}) + return closer, err +} + func (c *core) multiAgentUpdate(id uuid.UUID, agents []uuid.UUID, node *Node) error { c.mutex.Lock() defer c.mutex.Unlock() + // Update the node of this client in our in-memory map. If an agent entirely + // shuts down and reconnects, it needs to be aware of all clients attempting + // to establish connections. + c.nodes[id] = node var errs *multierror.Error for _, aid := range agents { - err := c.clientNodeUpdateUnlocked(id, aid, node) + err := c.clientNodeUpdateLocked(id, aid, node) if err != nil { errs = multierror.Append(errs, err) } @@ -551,7 +600,7 @@ func (c *core) agentDisconnected(id, unique uuid.UUID) { // Only delete the connection if it's ours. It could have been // overwritten. 
- if idConn, ok := c.agentSockets[id]; ok && idConn.ID == unique { + if idConn, ok := c.agentSockets[id]; ok && idConn.UniqueID() == unique { delete(c.agentSockets, id) delete(c.nodes, id) logger.Debug(context.Background(), "deleted agent socket and node") @@ -577,7 +626,7 @@ func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Co // dead. oldAgentSocket, ok := c.agentSockets[id] if ok { - overwrites = oldAgentSocket.Overwrites + 1 + overwrites = oldAgentSocket.Overwrites() + 1 _ = oldAgentSocket.Close() } tc := NewTrackedConn(ctx, cancel, conn, unique, logger, overwrites) @@ -647,11 +696,6 @@ func (c *core) agentNodeUpdate(id uuid.UUID, node *Node) error { c.legacyAgents[id] = struct{}{} } - // Publish the new node to every active multiAgent. - for _, multiAgent := range c.multiAgents { - multiAgent.OnAgentUpdate(id, node) - } - connectionSockets, ok := c.agentToConnectionSockets[id] if !ok { logger.Debug(context.Background(), "no client sockets; unable to send node") @@ -711,15 +755,6 @@ func (c *core) close() error { } } - wg.Add(len(c.multiAgents)) - for _, multiAgent := range c.multiAgents { - multiAgent := multiAgent - go func() { - _ = multiAgent.Close() - wg.Done() - }() - } - c.mutex.Unlock() wg.Wait() @@ -742,8 +777,8 @@ func (c *core) serveHTTPDebug(w http.ResponseWriter, r *http.Request) { } func CoordinatorHTTPDebug( - agentSocketsMap map[uuid.UUID]*TrackedConn, - agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]*TrackedConn, + agentSocketsMap map[uuid.UUID]Enqueueable, + agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]Enqueueable, agentNameCache *lru.Cache[uuid.UUID, string], ) func(w http.ResponseWriter, _ *http.Request) { return func(w http.ResponseWriter, _ *http.Request) { @@ -751,7 +786,7 @@ func CoordinatorHTTPDebug( type idConn struct { id uuid.UUID - conn *TrackedConn + conn Enqueueable } { @@ -764,16 +799,17 @@ func CoordinatorHTTPDebug( } slices.SortFunc(agentSockets, func(a, b idConn) bool { - return a.conn.Name < b.conn.Name + return a.conn.Name() < b.conn.Name() }) for _, agent := range agentSockets { + start, lastWrite := agent.conn.Stats() _, _ = fmt.Fprintf(w, "
  • %s (%s): created %v ago, write %v ago, overwrites %d
  • \n", - agent.conn.Name, + agent.conn.Name(), agent.id.String(), - now.Sub(time.Unix(agent.conn.Start, 0)).Round(time.Second), - now.Sub(time.Unix(agent.conn.LastWrite, 0)).Round(time.Second), - agent.conn.Overwrites, + now.Sub(time.Unix(start, 0)).Round(time.Second), + now.Sub(time.Unix(lastWrite, 0)).Round(time.Second), + agent.conn.Overwrites(), ) if conns := agentToConnectionSocketsMap[agent.id]; len(conns) > 0 { @@ -789,11 +825,12 @@ func CoordinatorHTTPDebug( _, _ = fmt.Fprintln(w, "
      ") for _, connSocket := range connSockets { + start, lastWrite := connSocket.conn.Stats() _, _ = fmt.Fprintf(w, "
%s (%s): created %v ago, write %v ago\n", - connSocket.conn.Name, + connSocket.conn.Name(), connSocket.id.String(), - now.Sub(time.Unix(connSocket.conn.Start, 0)).Round(time.Second), - now.Sub(time.Unix(connSocket.conn.LastWrite, 0)).Round(time.Second), + now.Sub(time.Unix(start, 0)).Round(time.Second), + now.Sub(time.Unix(lastWrite, 0)).Round(time.Second), ) } _, _ = fmt.Fprintln(w, "
    ") @@ -848,11 +885,12 @@ func CoordinatorHTTPDebug( _, _ = fmt.Fprintf(w, "

    connections: total %d

    \n", len(agentConns.conns)) _, _ = fmt.Fprintln(w, "
      ") for _, agentConn := range agentConns.conns { + start, lastWrite := agentConn.conn.Stats() _, _ = fmt.Fprintf(w, "
%s (%s): created %v ago, write %v ago\n", - agentConn.conn.Name, + agentConn.conn.Name(), agentConn.id.String(), - now.Sub(time.Unix(agentConn.conn.Start, 0)).Round(time.Second), - now.Sub(time.Unix(agentConn.conn.LastWrite, 0)).Round(time.Second), + now.Sub(time.Unix(start, 0)).Round(time.Second), + now.Sub(time.Unix(lastWrite, 0)).Round(time.Second), ) } _, _ = fmt.Fprintln(w, "
    ") diff --git a/tailnet/multiagent.go b/tailnet/multiagent.go index 511eb785ae46e..f160b094736ae 100644 --- a/tailnet/multiagent.go +++ b/tailnet/multiagent.go @@ -3,6 +3,7 @@ package tailnet import ( "context" "sync" + "sync/atomic" "github.com/google/uuid" @@ -11,9 +12,9 @@ import ( type MultiAgentConn interface { UpdateSelf(node *Node) error - SubscribeAgent(agentID uuid.UUID, node *Node) error + SubscribeAgent(agentID uuid.UUID) (func(), error) UnsubscribeAgent(agentID uuid.UUID) - NextUpdate(ctx context.Context) []AgentNode + NextUpdate(ctx context.Context) []*Node AgentIsLegacy(agentID uuid.UUID) bool Close() error } @@ -25,22 +26,16 @@ type MultiAgent struct { Logger slog.Logger AgentIsLegacyFunc func(agentID uuid.UUID) bool - OnSubscribe func(id uuid.UUID, agent uuid.UUID, node *Node) error + OnSubscribe func(enq Enqueueable, agent uuid.UUID) (close func(), err error) OnNodeUpdate func(id uuid.UUID, agents []uuid.UUID, node *Node) error - OnClose func(id uuid.UUID) - updates chan AgentNode - subscribedAgents map[uuid.UUID]struct{} -} - -type AgentNode struct { - AgentID uuid.UUID - *Node + updates chan []*Node + subscribedAgents map[uuid.UUID]func() } func (m *MultiAgent) Init() *MultiAgent { - m.updates = make(chan AgentNode, 128) - m.subscribedAgents = map[uuid.UUID]struct{}{} + m.updates = make(chan []*Node, 128) + m.subscribedAgents = map[uuid.UUID]func(){} return m } @@ -48,21 +43,6 @@ func (m *MultiAgent) AgentIsLegacy(agentID uuid.UUID) bool { return m.AgentIsLegacyFunc(agentID) } -func (m *MultiAgent) OnAgentUpdate(id uuid.UUID, node *Node) { - m.mu.RLock() - defer m.mu.RUnlock() - - if _, ok := m.subscribedAgents[id]; !ok { - return - } - - select { - case m.updates <- AgentNode{AgentID: id, Node: node}: - default: - m.Logger.Debug(context.Background(), "unable to send node %q to multiagent %q; buffer full", id, m.ID) - } -} - func (m *MultiAgent) UpdateSelf(node *Node) error { m.mu.Lock() agents := make([]uuid.UUID, 0, len(m.subscribedAgents)) @@ -74,47 +54,101 @@ func (m *MultiAgent) UpdateSelf(node *Node) error { return m.OnNodeUpdate(m.ID, agents, node) } -func (m *MultiAgent) SubscribeAgent(agentID uuid.UUID, node *Node) error { +func (m *MultiAgent) SubscribeAgent(agentID uuid.UUID) (func(), error) { m.mu.Lock() - m.subscribedAgents[agentID] = struct{}{} - m.mu.Unlock() + defer m.mu.Unlock() + + if closer, ok := m.subscribedAgents[agentID]; ok { + return closer, nil + } - return m.OnSubscribe(m.ID, agentID, node) + closer, err := m.OnSubscribe(m.enqueuer(agentID), agentID) + if err != nil { + return nil, err + } + m.subscribedAgents[agentID] = closer + return closer, nil } func (m *MultiAgent) UnsubscribeAgent(agentID uuid.UUID) { m.mu.Lock() defer m.mu.Unlock() + + if closer, ok := m.subscribedAgents[agentID]; ok { + closer() + } delete(m.subscribedAgents, agentID) } -func (m *MultiAgent) NextUpdate(ctx context.Context) []AgentNode { - var nodes []AgentNode - -loop: - // Read all buffered nodes. 
+func (m *MultiAgent) NextUpdate(ctx context.Context) []*Node { for { select { case <-ctx.Done(): return nil - case node := <-m.updates: - nodes = append(nodes, node) - - default: - break loop + case nodes := <-m.updates: + return nodes } } +} + +func (m *MultiAgent) enqueuer(agentID uuid.UUID) Enqueueable { + return &multiAgentEnqueuer{ + agentID: agentID, + m: m, + } +} + +type multiAgentEnqueuer struct { + m *MultiAgent + + agentID uuid.UUID + start int64 + lastWrite int64 + overwrites int64 +} + +func (m *multiAgentEnqueuer) UniqueID() uuid.UUID { + return m.m.ID +} + +func (m *multiAgentEnqueuer) Enqueue(nodes []*Node) error { + select { + case m.m.updates <- nodes: + return nil + default: + return ErrWouldBlock + } +} - return nodes +func (m *multiAgentEnqueuer) Name() string { + return "multiagent-" + m.m.ID.String() +} + +func (m *multiAgentEnqueuer) Stats() (start int64, lastWrite int64) { + return m.start, atomic.LoadInt64(&m.lastWrite) +} + +func (m *multiAgentEnqueuer) Overwrites() int64 { + return m.overwrites +} + +func (m *multiAgentEnqueuer) Close() error { + m.m.mu.Lock() + defer m.m.mu.Unlock() + + // Delete without running the closer. If the enqueuer itself gets closed, we + // can assume that the caller is removing it from the coordinator. + delete(m.m.subscribedAgents, m.agentID) + return nil } func (m *MultiAgent) Close() error { m.mu.Lock() close(m.updates) + for _, closer := range m.subscribedAgents { + closer() + } m.mu.Unlock() - - m.OnClose(m.ID) - return nil } From 30aefcb1a6e84fa55ec7a3cd933f9c0b089a72e9 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Wed, 28 Jun 2023 21:39:21 +0000 Subject: [PATCH 09/19] fixup! second round --- enterprise/tailnet/coordinator.go | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index 8aa2f5579104e..f6cf755f9ba2f 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -259,26 +259,6 @@ func (c *haCoordinator) handleClientUpdate(id, agent uuid.UUID, node *agpl.Node) return nil } -func (c *haCoordinator) handleClientUpdateLocked(id, agent uuid.UUID, node *agpl.Node) error { - agentSocket, ok := c.agentSockets[agent] - if !ok { - c.mutex.Unlock() - // If we don't own the agent locally, send it over pubsub to a node that - // owns the agent. - err := c.publishNodesToAgent(agent, []*agpl.Node{node}) - if err != nil { - return xerrors.Errorf("publish node to agent") - } - return nil - } - err := agentSocket.Enqueue([]*agpl.Node{node}) - c.mutex.Unlock() - if err != nil { - return xerrors.Errorf("enqueue node: %w", err) - } - return nil -} - // ServeAgent accepts a WebSocket connection to an agent that listens to // incoming connections and publishes node updates. 
func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { From e55e1468321115e29df777edb55b8b2f7ffe1f5a Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Thu, 29 Jun 2023 01:27:56 +0000 Subject: [PATCH 10/19] support swapping coordinators --- coderd/coderd.go | 2 +- coderd/tailnet.go | 47 ++++++++++++++++++++++++++++++++---------- coderd/tailnet_test.go | 15 ++++++++------ tailnet/coordinator.go | 18 +++++++++++++++- tailnet/multiagent.go | 29 ++++++++++++++++++++++++-- 5 files changed, 90 insertions(+), 21 deletions(-) diff --git a/coderd/coderd.go b/coderd/coderd.go index 32d29145343c4..43a77676b2688 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -356,7 +356,7 @@ func New(options *Options) *API { options.Logger, options.DERPServer, options.DERPMap, - *api.TailnetCoordinator.Load(), + &api.TailnetCoordinator, wsconncache.New(api._dialWorkspaceAgentTailnet, 0), ) if err != nil { diff --git a/coderd/tailnet.go b/coderd/tailnet.go index a2ffe7f4dc0e1..13bc436e1b3df 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -9,6 +9,7 @@ import ( "net/netip" "net/url" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -42,7 +43,7 @@ func NewServerTailnet( logger slog.Logger, derpServer *derp.Server, derpMap *tailcfg.DERPMap, - coord tailnet.Coordinator, + coord *atomic.Pointer[tailnet.Coordinator], cache *wsconncache.Cache, ) (*ServerTailnet, error) { logger = logger.Named("servertailnet") @@ -55,16 +56,13 @@ func NewServerTailnet( return nil, xerrors.Errorf("create tailnet conn: %w", err) } - id := uuid.New() - ma := coord.ServeMultiAgent(id) - serverCtx, cancel := context.WithCancel(ctx) tn := &ServerTailnet{ ctx: serverCtx, cancel: cancel, logger: logger, conn: conn, - agentConn: ma, + coord: coord, cache: cache, agentNodes: map[uuid.UUID]*tailnetNode{}, transport: tailnetTransport.Clone(), @@ -72,13 +70,15 @@ func NewServerTailnet( tn.transport.DialContext = tn.dialContext tn.transport.MaxIdleConnsPerHost = 10 tn.transport.MaxIdleConns = 0 + agentConn := (*coord.Load()).ServeMultiAgent(uuid.New()) + tn.agentConn.Store(&agentConn) - err = ma.UpdateSelf(conn.Node()) + err = tn.getAgentConn().UpdateSelf(conn.Node()) if err != nil { tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err)) } conn.SetNodeCallback(func(node *tailnet.Node) { - err := tn.agentConn.UpdateSelf(node) + err := tn.getAgentConn().UpdateSelf(node) if err != nil { tn.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err)) } @@ -107,8 +107,12 @@ func NewServerTailnet( func (s *ServerTailnet) watchAgentUpdates() { for { - nodes := s.agentConn.NextUpdate(s.ctx) + nodes := s.getAgentConn().NextUpdate(s.ctx) if nodes == nil { + if s.getAgentConn().IsClosed() && s.ctx.Err() == nil { + s.reinitCoordinator() + continue + } return } @@ -120,6 +124,26 @@ func (s *ServerTailnet) watchAgentUpdates() { } } +func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn { + return *s.agentConn.Load() +} + +func (s *ServerTailnet) reinitCoordinator() { + agentConn := (*s.coord.Load()).ServeMultiAgent(uuid.New()) + s.agentConn.Store(&agentConn) + + s.nodesMu.Lock() + // Resubscribe to all of the agents we're tracking. 
+ for agentID, agentNode := range s.agentNodes { + closer, err := agentConn.SubscribeAgent(agentID) + if err != nil { + s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID)) + } + agentNode.close = closer + } + s.nodesMu.Unlock() +} + type tailnetNode struct { lastConnection time.Time close func() @@ -131,7 +155,8 @@ type ServerTailnet struct { logger slog.Logger conn *tailnet.Conn - agentConn tailnet.MultiAgentConn + coord *atomic.Pointer[tailnet.Coordinator] + agentConn atomic.Pointer[tailnet.MultiAgentConn] cache *wsconncache.Cache nodesMu sync.Mutex // agentNodes is a map of agent tailnetNodes the server wants to keep a @@ -185,7 +210,7 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error { // If we don't have the node, subscribe. if !ok { s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID)) - closer, err := s.agentConn.SubscribeAgent(agentID) + closer, err := s.getAgentConn().SubscribeAgent(agentID) if err != nil { return xerrors.Errorf("subscribe agent: %w", err) } @@ -208,7 +233,7 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*code ret = func() {} ) - if s.agentConn.AgentIsLegacy(agentID) { + if s.getAgentConn().AgentIsLegacy(agentID) { s.logger.Debug(s.ctx, "acquiring legacy agent", slog.F("agent_id", agentID)) cconn, release, err := s.cache.Acquire(agentID) if err != nil { diff --git a/coderd/tailnet_test.go b/coderd/tailnet_test.go index 159dc6b2079e7..16d597607312c 100644 --- a/coderd/tailnet_test.go +++ b/coderd/tailnet_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "net/netip" "net/url" + "sync/atomic" "testing" "github.com/google/uuid" @@ -132,12 +133,14 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A DERPMap: derpMap, } - coordinator := tailnet.NewCoordinator(logger) + var coordPtr atomic.Pointer[tailnet.Coordinator] + coord := tailnet.NewCoordinator(logger) + coordPtr.Store(&coord) t.Cleanup(func() { - _ = coordinator.Close() + _ = coord.Close() }) - c := agenttest.NewClient(t, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coordinator) + c := agenttest.NewClient(t, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord) options := agent.Options{ Client: c, @@ -153,7 +156,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A // Wait for the agent to connect. 
require.Eventually(t, func() bool { - return coordinator.Node(manifest.AgentID) != nil + return coord.Node(manifest.AgentID) != nil }, testutil.WaitShort, testutil.IntervalFast) cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { @@ -173,7 +176,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A }) go func() { defer close(serveClientDone) - coordinator.ServeClient(serverConn, uuid.New(), manifest.AgentID) + coord.ServeClient(serverConn, uuid.New(), manifest.AgentID) }() sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { return conn.UpdateNodes(node, false) @@ -191,7 +194,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A logger, derpServer, manifest.DERPMap, - coordinator, + &coordPtr, cache, ) require.NoError(t, err) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index b25cbe63bc336..98bb8a946cfb2 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -147,11 +147,17 @@ func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn { AgentIsLegacyFunc: c.core.agentIsLegacy, OnSubscribe: c.core.multiAgentSubscribe, OnNodeUpdate: c.core.multiAgentUpdate, - // OnClose: c.core.removeMultiAgent, }).Init() + c.core.addMultiAgent(id, m) return m } +func (c *core) addMultiAgent(id uuid.UUID, ma *MultiAgent) { + c.mutex.Lock() + c.multiAgents[id] = ma + c.mutex.Unlock() +} + // core is an in-memory structure of Node and TrackedConn mappings. Its methods may be called from multiple goroutines; // it is protected by a mutex to ensure data stay consistent. type core struct { @@ -172,6 +178,11 @@ type core struct { agentNameCache *lru.Cache[uuid.UUID, string] legacyAgents map[uuid.UUID]struct{} + // multiAgents holds all of the unique multiAgents listening on this + // coordinator. We need to keep track of these separately because we need to + // make sure they're closed on coordinator shutdown. If not, they won't be + // able to reopen another multiAgent on the new coordinator. 
+ multiAgents map[uuid.UUID]*MultiAgent } type Enqueueable interface { @@ -197,6 +208,7 @@ func newCore(logger slog.Logger) *core { agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]Enqueueable{}, agentNameCache: nameCache, legacyAgents: map[uuid.UUID]struct{}{}, + multiAgents: map[uuid.UUID]*MultiAgent{}, } } @@ -755,6 +767,10 @@ func (c *core) close() error { } } + for _, multiAgent := range c.multiAgents { + multiAgent.CoordinatorClose() + } + c.mutex.Unlock() wg.Wait() diff --git a/tailnet/multiagent.go b/tailnet/multiagent.go index f160b094736ae..393b51d49de35 100644 --- a/tailnet/multiagent.go +++ b/tailnet/multiagent.go @@ -17,10 +17,12 @@ type MultiAgentConn interface { NextUpdate(ctx context.Context) []*Node AgentIsLegacy(agentID uuid.UUID) bool Close() error + IsClosed() bool } type MultiAgent struct { - mu sync.RWMutex + mu sync.RWMutex + closed chan struct{} ID uuid.UUID Logger slog.Logger @@ -34,6 +36,7 @@ type MultiAgent struct { } func (m *MultiAgent) Init() *MultiAgent { + m.closed = make(chan struct{}) m.updates = make(chan []*Node, 128) m.subscribedAgents = map[uuid.UUID]func(){} return m @@ -143,12 +146,34 @@ func (m *multiAgentEnqueuer) Close() error { return nil } +func (m *MultiAgent) IsClosed() bool { + select { + case <-m.closed: + return true + default: + return false + } +} + +func (m *MultiAgent) CoordinatorClose() { + m.mu.Lock() + if !m.IsClosed() { + close(m.closed) + close(m.updates) + } + m.mu.Unlock() +} + func (m *MultiAgent) Close() error { m.mu.Lock() + defer m.mu.Unlock() + if m.IsClosed() { + return nil + } + close(m.closed) close(m.updates) for _, closer := range m.subscribedAgents { closer() } - m.mu.Unlock() return nil } From 457470da10d883c7e10ce5ff7937dd1e8e09a450 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Thu, 29 Jun 2023 22:54:52 +0000 Subject: [PATCH 11/19] use multiagents for all clients --- coderd/tailnet.go | 54 ++++- enterprise/tailnet/coordinator.go | 144 +++++++----- tailnet/coordinator.go | 373 +++++++++++------------------- tailnet/multiagent.go | 151 ++++-------- tailnet/trackedconn.go | 146 ++++++++++++ 5 files changed, 457 insertions(+), 411 deletions(-) create mode 100644 tailnet/trackedconn.go diff --git a/coderd/tailnet.go b/coderd/tailnet.go index 13bc436e1b3df..c7b2db7f81e63 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -34,8 +34,6 @@ func init() { } } -// TODO(coadler): ServerTailnet does not currently remove stale peers. - // NewServerTailnet creates a new tailnet intended for use by coderd. It // automatically falls back to wsconncache if a legacy agent is encountered. 
func NewServerTailnet( @@ -102,14 +100,49 @@ func NewServerTailnet( }) go tn.watchAgentUpdates() + go tn.expireOldAgents() return tn, nil } +func (s *ServerTailnet) expireOldAgents() { + const ( + tick = 5 * time.Minute + cutoff = 30 * time.Minute + ) + + ticker := time.NewTicker(tick) + defer ticker.Stop() + + for { + select { + case <-s.ctx.Done(): + return + case <-ticker.C: + } + + s.nodesMu.Lock() + agentConn := s.getAgentConn() + for agentID, node := range s.agentNodes { + if time.Since(node.lastConnection) > cutoff { + err := agentConn.UnsubscribeAgent(agentID) + if err != nil { + s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID)) + } + delete(s.agentNodes, agentID) + + // TODO(coadler): actually remove from the netmap + } + } + s.nodesMu.Unlock() + } +} + func (s *ServerTailnet) watchAgentUpdates() { for { - nodes := s.getAgentConn().NextUpdate(s.ctx) - if nodes == nil { - if s.getAgentConn().IsClosed() && s.ctx.Err() == nil { + conn := s.getAgentConn() + nodes, ok := conn.NextUpdate(s.ctx) + if !ok { + if conn.IsClosed() && s.ctx.Err() == nil { s.reinitCoordinator() continue } @@ -129,24 +162,22 @@ func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn { } func (s *ServerTailnet) reinitCoordinator() { + s.nodesMu.Lock() agentConn := (*s.coord.Load()).ServeMultiAgent(uuid.New()) s.agentConn.Store(&agentConn) - s.nodesMu.Lock() // Resubscribe to all of the agents we're tracking. - for agentID, agentNode := range s.agentNodes { - closer, err := agentConn.SubscribeAgent(agentID) + for agentID := range s.agentNodes { + err := agentConn.SubscribeAgent(agentID) if err != nil { s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID)) } - agentNode.close = closer } s.nodesMu.Unlock() } type tailnetNode struct { lastConnection time.Time - close func() } type ServerTailnet struct { @@ -210,13 +241,12 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error { // If we don't have the node, subscribe. 
if !ok { s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID)) - closer, err := s.getAgentConn().SubscribeAgent(agentID) + err := s.getAgentConn().SubscribeAgent(agentID) if err != nil { return xerrors.Errorf("subscribe agent: %w", err) } tnode = &tailnetNode{ lastConnection: time.Now(), - close: closer, } s.agentNodes[agentID] = tnode } else { diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index f6cf755f9ba2f..3deedeb9d3a7c 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -12,7 +12,6 @@ import ( "sync" "github.com/google/uuid" - "github.com/hashicorp/go-multierror" lru "github.com/hashicorp/golang-lru/v2" "golang.org/x/xerrors" @@ -42,6 +41,8 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err agentSockets: map[uuid.UUID]agpl.Enqueueable{}, agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]agpl.Enqueueable{}, agentNameCache: nameCache, + clients: map[uuid.UUID]agpl.Enqueueable{}, + clientsToAgents: map[uuid.UUID]map[uuid.UUID]struct{}{}, legacyAgents: map[uuid.UUID]struct{}{}, } @@ -57,14 +58,22 @@ func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { ID: id, Logger: c.log, AgentIsLegacyFunc: c.agentIsLegacy, - OnSubscribe: c.multiAgentSubscribe, - OnNodeUpdate: c.multiAgentUpdate, + OnSubscribe: c.clientSubscribeToAgent, + OnNodeUpdate: c.clientNodeUpdate, + OnRemove: c.clientDisconnected, }).Init() + c.mutex.Lock() + c.clients[id] = m + c.clientsToAgents[id] = map[uuid.UUID]struct{}{} + c.mutex.Unlock() return m } -func (c *haCoordinator) multiAgentSubscribe(enq agpl.Enqueueable, agentID uuid.UUID) (func(), error) { +func (c *haCoordinator) clientSubscribeToAgent(enq agpl.Enqueueable, agentID uuid.UUID) error { c.mutex.Lock() + defer c.mutex.Unlock() + + c.initOrSetAgentConnectionSocketLocked(agentID, enq) node := c.nodes[enq.UniqueID()] @@ -73,44 +82,43 @@ func (c *haCoordinator) multiAgentSubscribe(enq agpl.Enqueueable, agentID uuid.U if ok { err := enq.Enqueue([]*agpl.Node{agentNode}) if err != nil { - return nil, xerrors.Errorf("enqueue agent on subscribe: %w", err) + return xerrors.Errorf("enqueue agent on subscribe: %w", err) } } else { // If we don't have the node locally, notify other coordinators. - c.mutex.Unlock() err := c.publishClientHello(agentID) if err != nil { - return nil, xerrors.Errorf("publish client hello: %w", err) + return xerrors.Errorf("publish client hello: %w", err) } } if node != nil { - err := c.handleClientUpdate(enq.UniqueID(), agentID, node) + err := c.sendNodeToAgentLocked(agentID, node) if err != nil { - return nil, xerrors.Errorf("handle client update: %w", err) + return xerrors.Errorf("handle client update: %w", err) } } - return c.cleanupClientConn(enq.UniqueID(), agentID), nil -} - -func (c *haCoordinator) multiAgentUpdate(id uuid.UUID, agents []uuid.UUID, node *agpl.Node) error { - var errs *multierror.Error - // This isn't the most efficient, but this coordinator is being deprecated - // soon anyways. - for _, agent := range agents { - err := c.handleClientUpdate(id, agent, node) - if err != nil { - errs = multierror.Append(errs, err) - } - } - if errs != nil { - return errs - } - return nil } +// func (c *haCoordinator) multiAgentUpdate(id uuid.UUID, agents []uuid.UUID, node *agpl.Node) error { +// var errs *multierror.Error +// // This isn't the most efficient, but this coordinator is being deprecated +// // soon anyways. 
+// for _, agent := range agents { +// err := c.handleClientUpdate(id, agent, node) +// if err != nil { +// errs = multierror.Append(errs, err) +// } +// } +// if errs != nil { +// return errs +// } + +// return nil +// } + type haCoordinator struct { id uuid.UUID log slog.Logger @@ -127,6 +135,9 @@ type haCoordinator struct { // are subscribed to updates for that agent. agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]agpl.Enqueueable + clients map[uuid.UUID]agpl.Enqueueable + clientsToAgents map[uuid.UUID]map[uuid.UUID]struct{} + // agentNameCache holds a cache of agent names. If one of them disappears, // it's helpful to have a name cached for debugging. agentNameCache *lru.Cache[uuid.UUID, string] @@ -152,40 +163,25 @@ func (c *haCoordinator) agentLogger(agent uuid.UUID) slog.Logger { // ServeClient accepts a WebSocket connection that wants to connect to an agent // with the specified ID. -func (c *haCoordinator) ServeClient(conn net.Conn, id, agent uuid.UUID) error { +func (c *haCoordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - logger := c.clientLogger(id, agent) + logger := c.clientLogger(id, agentID) - c.mutex.Lock() - - tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, 0) - c.initOrSetAgentConnectionSocketLocked(agent, tc) + ma := c.ServeMultiAgent(id) + defer ma.Close() - // When a new connection is requested, we update it with the latest - // node of the agent. This allows the connection to establish. - node, ok := c.nodes[agent] - if ok { - err := tc.Enqueue([]*agpl.Node{node}) - c.mutex.Unlock() - if err != nil { - return xerrors.Errorf("enqueue node: %w", err) - } - } else { - c.mutex.Unlock() - err := c.publishClientHello(agent) - if err != nil { - return xerrors.Errorf("publish client hello: %w", err) - } + err := ma.SubscribeAgent(agentID) + if err != nil { + return xerrors.Errorf("subscribe agent: %w", err) } - go tc.SendUpdates() - defer c.cleanupClientConn(id, agent) + go agpl.SendUpdatesToConn(ctx, logger, ma, conn) decoder := json.NewDecoder(conn) // Indefinitely handle messages from the client websocket. for { - err := c.handleNextClientMessage(id, agent, decoder) + err := c.handleNextClientMessage(id, decoder) if err != nil { if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) { return nil @@ -202,12 +198,14 @@ func (c *haCoordinator) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, c.agentToConnectionSockets[agentID] = connectionSockets } connectionSockets[enq.UniqueID()] = enq + c.clientsToAgents[enq.UniqueID()][agentID] = struct{}{} } -func (c *haCoordinator) cleanupClientConn(id, agentID uuid.UUID) func() { - return func() { - c.mutex.Lock() - defer c.mutex.Unlock() +func (c *haCoordinator) clientDisconnected(id uuid.UUID) { + c.mutex.Lock() + defer c.mutex.Unlock() + + for agentID := range c.clientsToAgents[id] { // Clean all traces of this connection from the map. 
delete(c.nodes, id) connectionSockets, ok := c.agentToConnectionSockets[agentID] @@ -220,39 +218,52 @@ func (c *haCoordinator) cleanupClientConn(id, agentID uuid.UUID) func() { } delete(c.agentToConnectionSockets, agentID) } + + delete(c.clients, id) + delete(c.clientsToAgents, id) } -func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error { +func (c *haCoordinator) handleNextClientMessage(id uuid.UUID, decoder *json.Decoder) error { var node agpl.Node err := decoder.Decode(&node) if err != nil { return xerrors.Errorf("read json: %w", err) } - return c.handleClientUpdate(id, agent, &node) + return c.clientNodeUpdate(id, &node) } -func (c *haCoordinator) handleClientUpdate(id, agent uuid.UUID, node *agpl.Node) error { +func (c *haCoordinator) clientNodeUpdate(id uuid.UUID, node *agpl.Node) error { c.mutex.Lock() + defer c.mutex.Unlock() // Update the node of this client in our in-memory map. If an agent entirely // shuts down and reconnects, it needs to be aware of all clients attempting // to establish connections. c.nodes[id] = node - // Write the new node from this client to the actively connected agent. - agentSocket, ok := c.agentSockets[agent] + for agentID := range c.clientsToAgents[id] { + // Write the new node from this client to the actively connected agent. + err := c.sendNodeToAgentLocked(agentID, node) + if err != nil { + c.log.Error(context.Background(), "send node to agent", slog.Error(err), slog.F("agent_id", agentID)) + } + } + + return nil +} + +func (c *haCoordinator) sendNodeToAgentLocked(agentID uuid.UUID, node *agpl.Node) error { + agentSocket, ok := c.agentSockets[agentID] if !ok { - c.mutex.Unlock() // If we don't own the agent locally, send it over pubsub to a node that // owns the agent. - err := c.publishNodesToAgent(agent, []*agpl.Node{node}) + err := c.publishNodesToAgent(agentID, []*agpl.Node{node}) if err != nil { return xerrors.Errorf("publish node to agent") } return nil } err := agentSocket.Enqueue([]*agpl.Node{node}) - c.mutex.Unlock() if err != nil { return xerrors.Errorf("enqueue node: %w", err) } @@ -422,7 +433,7 @@ func (c *haCoordinator) Close() error { for _, socket := range c.agentSockets { socket := socket go func() { - _ = socket.Close() + _ = socket.CoordinatorClose() wg.Done() }() } @@ -432,12 +443,17 @@ func (c *haCoordinator) Close() error { for _, socket := range connMap { socket := socket go func() { - _ = socket.Close() + _ = socket.CoordinatorClose() wg.Done() }() } } + // Ensure clients that have no subscriptions are properly closed. 
+ for _, client := range c.clients { + _ = client.CoordinatorClose() + } + wg.Wait() return nil } diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 98bb8a946cfb2..b7b8fe9c28537 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -11,11 +11,9 @@ import ( "net/http" "net/netip" "sync" - "sync/atomic" "time" "github.com/google/uuid" - "github.com/hashicorp/go-multierror" lru "github.com/hashicorp/golang-lru/v2" "golang.org/x/exp/slices" "golang.org/x/xerrors" @@ -145,8 +143,9 @@ func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn { ID: id, Logger: c.core.logger, AgentIsLegacyFunc: c.core.agentIsLegacy, - OnSubscribe: c.core.multiAgentSubscribe, - OnNodeUpdate: c.core.multiAgentUpdate, + OnSubscribe: c.core.clientSubscribeToAgent, + OnNodeUpdate: c.core.clientNodeUpdate, + OnRemove: c.core.clientDisconnected, }).Init() c.core.addMultiAgent(id, m) return m @@ -154,7 +153,8 @@ func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn { func (c *core) addMultiAgent(id uuid.UUID, ma *MultiAgent) { c.mutex.Lock() - c.multiAgents[id] = ma + c.clients[id] = ma + c.clientsToAgents[id] = map[uuid.UUID]struct{}{} c.mutex.Unlock() } @@ -173,6 +173,9 @@ type core struct { // are subscribed to updates for that agent. agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]Enqueueable + clients map[uuid.UUID]Enqueueable + clientsToAgents map[uuid.UUID]map[uuid.UUID]struct{} + // agentNameCache holds a cache of agent names. If one of them disappears, // it's helpful to have a name cached for debugging. agentNameCache *lru.Cache[uuid.UUID, string] @@ -182,7 +185,7 @@ type core struct { // coordinator. We need to keep track of these separately because we need to // make sure they're closed on coordinator shutdown. If not, they won't be // able to reopen another multiAgent on the new coordinator. - multiAgents map[uuid.UUID]*MultiAgent + // multiAgents map[uuid.UUID]*MultiAgent } type Enqueueable interface { @@ -191,6 +194,7 @@ type Enqueueable interface { Name() string Stats() (start, lastWrite int64) Overwrites() int64 + CoordinatorClose() error Close() error } @@ -208,140 +212,13 @@ func newCore(logger slog.Logger) *core { agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]Enqueueable{}, agentNameCache: nameCache, legacyAgents: map[uuid.UUID]struct{}{}, - multiAgents: map[uuid.UUID]*MultiAgent{}, + clients: map[uuid.UUID]Enqueueable{}, + clientsToAgents: map[uuid.UUID]map[uuid.UUID]struct{}{}, } } var ErrWouldBlock = xerrors.New("would block") -type TrackedConn struct { - ctx context.Context - cancel func() - conn net.Conn - updates chan []*Node - logger slog.Logger - lastData []byte - - // ID is an ephemeral UUID used to uniquely identify the owner of the - // connection. 
- id uuid.UUID - - name string - start int64 - lastWrite int64 - overwrites int64 -} - -func (t *TrackedConn) Enqueue(n []*Node) (err error) { - atomic.StoreInt64(&t.lastWrite, time.Now().Unix()) - select { - case t.updates <- n: - return nil - default: - return ErrWouldBlock - } -} - -func (t *TrackedConn) UniqueID() uuid.UUID { - return t.id -} - -func (t *TrackedConn) Name() string { - return t.name -} - -func (t *TrackedConn) Stats() (start, lastWrite int64) { - return t.start, atomic.LoadInt64(&t.lastWrite) -} - -func (t *TrackedConn) Overwrites() int64 { - return t.overwrites -} - -// Close the connection and cancel the context for reading node updates from the queue -func (t *TrackedConn) Close() error { - t.cancel() - return t.conn.Close() -} - -// WriteTimeout is the amount of time we wait to write a node update to a connection before we declare it hung. -// It is exported so that tests can use it. -const WriteTimeout = time.Second * 5 - -// SendUpdates reads node updates and writes them to the connection. Ends when writes hit an error or context is -// canceled. -func (t *TrackedConn) SendUpdates() { - for { - select { - case <-t.ctx.Done(): - t.logger.Debug(t.ctx, "done sending updates") - return - case nodes := <-t.updates: - data, err := json.Marshal(nodes) - if err != nil { - t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes)) - return - } - if bytes.Equal(t.lastData, data) { - t.logger.Debug(t.ctx, "skipping duplicate update", slog.F("nodes", nodes)) - continue - } - - // Set a deadline so that hung connections don't put back pressure on the system. - // Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. - err = t.conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) - if err != nil { - // often, this is just because the connection is closed/broken, so only log at debug. - t.logger.Debug(t.ctx, "unable to set write deadline", slog.Error(err)) - _ = t.Close() - return - } - _, err = t.conn.Write(data) - if err != nil { - // often, this is just because the connection is closed/broken, so only log at debug. - t.logger.Debug(t.ctx, "could not write nodes to connection", slog.Error(err), slog.F("nodes", nodes)) - _ = t.Close() - return - } - t.logger.Debug(t.ctx, "wrote nodes", slog.F("nodes", nodes)) - - // nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are - // *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write() - // fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context. - // If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after - // our successful write, it is important that we reset the deadline before it fires. - err = t.conn.SetWriteDeadline(time.Time{}) - if err != nil { - // often, this is just because the connection is closed/broken, so only log at debug. - t.logger.Debug(t.ctx, "unable to extend write deadline", slog.Error(err)) - _ = t.Close() - return - } - t.lastData = data - } - } -} - -func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.UUID, logger slog.Logger, overwrites int64) *TrackedConn { - // buffer updates so they don't block, since we hold the - // coordinator mutex while queuing. Node updates don't - // come quickly, so 512 should be plenty for all but - // the most pathological cases. 
- updates := make(chan []*Node, 512) - now := time.Now().Unix() - return &TrackedConn{ - ctx: ctx, - conn: conn, - cancel: cancel, - updates: updates, - logger: logger, - id: id, - start: now, - lastWrite: now, - overwrites: overwrites, - } -} - // Node returns an in-memory node by ID. // If the node does not exist, nil is returned. func (c *coordinator) Node(id uuid.UUID) *Node { @@ -376,24 +253,27 @@ func (c *core) agentCount() int { // ServeClient accepts a WebSocket connection that wants to connect to an agent // with the specified ID. -func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { +func (c *coordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - logger := c.core.clientLogger(id, agent) + logger := c.core.clientLogger(id, agentID) logger.Debug(ctx, "coordinating client") - tc, err := c.core.initAndTrackClient(ctx, cancel, conn, id, agent) + + ma := c.ServeMultiAgent(id) + defer ma.Close() + + err := ma.SubscribeAgent(agentID) if err != nil { - return err + return xerrors.Errorf("subscribe agent: %w", err) } - defer c.core.clientDisconnected(id, agent) // On this goroutine, we read updates from the client and publish them. We start a second goroutine // to write updates back to the client. - go tc.SendUpdates() + go SendUpdatesToConn(ctx, logger, ma, conn) decoder := json.NewDecoder(conn) for { - err := c.handleNextClientMessage(id, agent, decoder) + err := c.handleNextClientMessage(id, decoder) if err != nil { logger.Debug(ctx, "unable to read client update, connection may be closed", slog.Error(err)) if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) { @@ -404,45 +284,64 @@ func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) } } -func (c *core) clientLogger(id, agent uuid.UUID) slog.Logger { - return c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent)) -} +func SendUpdatesToConn(ctx context.Context, logger slog.Logger, ma MultiAgentConn, conn net.Conn) { + defer logger.Debug(ctx, "done sending updates") + defer func() { + _ = ma.Close() + _ = conn.Close() + }() -// initAndTrackClient creates a TrackedConn for the client, and sends any initial Node updates if we have any. It is -// one function that does two things because it is critical that we hold the mutex for both things, lest we miss some -// updates. -func (c *core) initAndTrackClient( - ctx context.Context, cancel func(), conn net.Conn, id, agent uuid.UUID, -) ( - *TrackedConn, error, -) { - logger := c.clientLogger(id, agent) - c.mutex.Lock() - defer c.mutex.Unlock() - if c.closed { - return nil, xerrors.New("coordinator is closed") - } - tc := NewTrackedConn(ctx, cancel, conn, id, logger, 0) + lastData := []byte{} - // When a new connection is requested, we update it with the latest - // node of the agent. This allows the connection to establish. - node, ok := c.nodes[agent] - if ok { - err := tc.Enqueue([]*Node{node}) - // this should never error since we're still the only goroutine that - // knows about the TrackedConn. 
If we hit an error something really - // wrong is happening + for { + nodes, ok := ma.NextUpdate(ctx) + if !ok { + return + } + + data, err := json.Marshal(nodes) if err != nil { - logger.Critical(ctx, "unable to queue initial node", slog.Error(err)) - return nil, err + logger.Error(ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes)) + return + } + if bytes.Equal(lastData, data) { + logger.Debug(ctx, "skipping duplicate update", slog.F("nodes", nodes)) + continue + } + + // Set a deadline so that hung connections don't put back pressure on the system. + // Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. + err = conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) + if err != nil { + // often, this is just because the connection is closed/broken, so only log at debug. + logger.Debug(ctx, "unable to set write deadline", slog.Error(err)) + return + } + _, err = conn.Write(data) + if err != nil { + // often, this is just because the connection is closed/broken, so only log at debug. + logger.Debug(ctx, "could not write nodes to connection", slog.Error(err), slog.F("nodes", nodes)) + return } + logger.Debug(ctx, "wrote nodes", slog.F("nodes", nodes)) + + // nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are + // *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write() + // fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context. + // If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after + // our successful write, it is important that we reset the deadline before it fires. + err = conn.SetWriteDeadline(time.Time{}) + if err != nil { + // often, this is just because the connection is closed/broken, so only log at debug. + logger.Debug(ctx, "unable to extend write deadline", slog.Error(err)) + return + } + lastData = data } +} - // Insert this connection into a map so the agent - // can publish node updates. - c.initOrSetAgentConnectionSocketLocked(agent, tc) - logger.Debug(ctx, "added tracked connection") - return tc, nil +func (c *core) clientLogger(id, agent uuid.UUID) slog.Logger { + return c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent)) } func (c *core) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq Enqueueable) { @@ -452,40 +351,52 @@ func (c *core) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq Enque c.agentToConnectionSockets[agentID] = connectionSockets } connectionSockets[enq.UniqueID()] = enq + + c.clientsToAgents[enq.UniqueID()][agentID] = struct{}{} } -func (c *core) clientDisconnected(id, agent uuid.UUID) { - logger := c.clientLogger(id, agent) +func (c *core) clientDisconnected(id uuid.UUID) { + logger := c.clientLogger(id, uuid.Nil) c.mutex.Lock() defer c.mutex.Unlock() // Clean all traces of this connection from the map. 
delete(c.nodes, id) logger.Debug(context.Background(), "deleted client node") - connectionSockets, ok := c.agentToConnectionSockets[agent] - if !ok { - return - } - delete(connectionSockets, id) - logger.Debug(context.Background(), "deleted client connectionSocket from map") - if len(connectionSockets) != 0 { - return + + for agentID := range c.clientsToAgents[id] { + connectionSockets, ok := c.agentToConnectionSockets[agentID] + if !ok { + return + } + delete(connectionSockets, id) + logger.Debug(context.Background(), "deleted client connectionSocket from map", slog.F("agent_id", agentID)) + + if len(connectionSockets) != 0 { + return + } + delete(c.agentToConnectionSockets, agentID) + logger.Debug(context.Background(), "deleted last client connectionSocket from map", slog.F("agent_id", agentID)) } - delete(c.agentToConnectionSockets, agent) - logger.Debug(context.Background(), "deleted last client connectionSocket from map") + + delete(c.clients, id) + delete(c.clientsToAgents, id) + logger.Debug(context.Background(), "deleted client agents") } -func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error { - logger := c.core.clientLogger(id, agent) +func (c *coordinator) handleNextClientMessage(id uuid.UUID, decoder *json.Decoder) error { + logger := c.core.clientLogger(id, uuid.Nil) + var node Node err := decoder.Decode(&node) if err != nil { return xerrors.Errorf("read json: %w", err) } + logger.Debug(context.Background(), "got client node update", slog.F("node", node)) - return c.core.clientNodeUpdate(id, agent, &node) + return c.core.clientNodeUpdate(id, &node) } -func (c *core) clientNodeUpdate(id, agent uuid.UUID, node *Node) error { +func (c *core) clientNodeUpdate(id uuid.UUID, node *Node) error { c.mutex.Lock() defer c.mutex.Unlock() @@ -494,78 +405,69 @@ func (c *core) clientNodeUpdate(id, agent uuid.UUID, node *Node) error { // to establish connections. 
c.nodes[id] = node - return c.clientNodeUpdateLocked(id, agent, node) + return c.clientNodeUpdateLocked(id, node) } -func (c *core) clientNodeUpdateLocked(id, agent uuid.UUID, node *Node) error { - logger := c.clientLogger(id, agent) +func (c *core) clientNodeUpdateLocked(id uuid.UUID, node *Node) error { + logger := c.clientLogger(id, uuid.Nil) + + agents := []uuid.UUID{} + for agentID := range c.clientsToAgents[id] { + err := c.sendNodeToAgentLocked(agentID, node) + if err != nil { + logger.Debug(context.Background(), "unable to send node to agent", slog.Error(err), slog.F("agent_id", agentID)) + continue + } + agents = append(agents, agentID) + } + + logger.Debug(context.Background(), "enqueued node to agents", slog.F("agent_ids", agents)) + return nil +} - agentSocket, ok := c.agentSockets[agent] +func (c *core) sendNodeToAgentLocked(agentID uuid.UUID, node *Node) error { + agentSocket, ok := c.agentSockets[agentID] if !ok { - logger.Debug(context.Background(), "no agent socket, unable to send node") - return nil + return xerrors.New("no agent socket") } err := agentSocket.Enqueue([]*Node{node}) if err != nil { - return xerrors.Errorf("enqueue node: %w", err) + return xerrors.Errorf("enqueue client to agent: %w", err) } - logger.Debug(context.Background(), "enqueued node to agent") + return nil } -func (c *core) multiAgentSubscribe(enq Enqueueable, agentID uuid.UUID) (func(), error) { +func (c *core) clientSubscribeToAgent(enq Enqueueable, agentID uuid.UUID) error { c.mutex.Lock() defer c.mutex.Unlock() + logger := c.clientLogger(enq.UniqueID(), uuid.Nil) + c.initOrSetAgentConnectionSocketLocked(agentID, enq) node, ok := c.nodes[enq.UniqueID()] if ok { // If we have the node, send it to the agent. If not, it will be sent // async. - err := c.clientNodeUpdateLocked(enq.UniqueID(), agentID, node) + err := c.sendNodeToAgentLocked(agentID, node) if err != nil { - return nil, xerrors.Errorf("send update to agent: %w", err) + logger.Debug(context.Background(), "unable to send node to agent", slog.Error(err), slog.F("agent_id", agentID)) } } else { - c.logger.Debug(context.Background(), "multiagent node doesn't exist", slog.F("multiagent_id", enq.UniqueID())) + logger.Debug(context.Background(), "multiagent node doesn't exist", slog.F("multiagent_id", enq.UniqueID())) } - closer := func() {} - agentNode, ok := c.nodes[agentID] if !ok { // This is ok, once the agent connects the node will be sent over. - c.logger.Debug(context.Background(), "agent node doesn't exist", slog.F("agent_id", agentID)) - return closer, nil + logger.Debug(context.Background(), "agent node doesn't exist", slog.F("agent_id", agentID)) + return nil } // Send the subscribed agent back to the multi agent. - err := enq.Enqueue([]*Node{agentNode}) - return closer, err -} - -func (c *core) multiAgentUpdate(id uuid.UUID, agents []uuid.UUID, node *Node) error { - c.mutex.Lock() - defer c.mutex.Unlock() - // Update the node of this client in our in-memory map. If an agent entirely - // shuts down and reconnects, it needs to be aware of all clients attempting - // to establish connections. 
- c.nodes[id] = node - - var errs *multierror.Error - for _, aid := range agents { - err := c.clientNodeUpdateLocked(id, aid, node) - if err != nil { - errs = multierror.Append(errs, err) - } - } - if errs != nil { - return errs - } - - return nil + return enq.Enqueue([]*Node{agentNode}) } func (c *core) agentLogger(id uuid.UUID) slog.Logger { @@ -751,7 +653,7 @@ func (c *core) close() error { for _, socket := range c.agentSockets { socket := socket go func() { - _ = socket.Close() + _ = socket.CoordinatorClose() wg.Done() }() } @@ -761,14 +663,15 @@ func (c *core) close() error { for _, socket := range connMap { socket := socket go func() { - _ = socket.Close() + _ = socket.CoordinatorClose() wg.Done() }() } } - for _, multiAgent := range c.multiAgents { - multiAgent.CoordinatorClose() + // Ensure clients that have no subscriptions are properly closed. + for _, client := range c.clients { + _ = client.CoordinatorClose() } c.mutex.Unlock() diff --git a/tailnet/multiagent.go b/tailnet/multiagent.go index 393b51d49de35..3b24489d34d5f 100644 --- a/tailnet/multiagent.go +++ b/tailnet/multiagent.go @@ -4,6 +4,7 @@ import ( "context" "sync" "sync/atomic" + "time" "github.com/google/uuid" @@ -12,9 +13,9 @@ import ( type MultiAgentConn interface { UpdateSelf(node *Node) error - SubscribeAgent(agentID uuid.UUID) (func(), error) - UnsubscribeAgent(agentID uuid.UUID) - NextUpdate(ctx context.Context) []*Node + SubscribeAgent(agentID uuid.UUID) error + UnsubscribeAgent(agentID uuid.UUID) error + NextUpdate(ctx context.Context) ([]*Node, bool) AgentIsLegacy(agentID uuid.UUID) bool Close() error IsClosed() bool @@ -22,158 +23,108 @@ type MultiAgentConn interface { type MultiAgent struct { mu sync.RWMutex - closed chan struct{} + closed bool ID uuid.UUID Logger slog.Logger AgentIsLegacyFunc func(agentID uuid.UUID) bool - OnSubscribe func(enq Enqueueable, agent uuid.UUID) (close func(), err error) - OnNodeUpdate func(id uuid.UUID, agents []uuid.UUID, node *Node) error + OnSubscribe func(enq Enqueueable, agent uuid.UUID) error + OnUnsubscribe func(enq Enqueueable, agent uuid.UUID) error + OnNodeUpdate func(id uuid.UUID, node *Node) error + OnRemove func(id uuid.UUID) - updates chan []*Node - subscribedAgents map[uuid.UUID]func() + updates chan []*Node + closeOnce sync.Once + start int64 + lastWrite int64 + overwrites int64 } func (m *MultiAgent) Init() *MultiAgent { - m.closed = make(chan struct{}) m.updates = make(chan []*Node, 128) - m.subscribedAgents = map[uuid.UUID]func(){} + m.start = time.Now().Unix() return m } +func (m *MultiAgent) UniqueID() uuid.UUID { + return m.ID +} + func (m *MultiAgent) AgentIsLegacy(agentID uuid.UUID) bool { return m.AgentIsLegacyFunc(agentID) } func (m *MultiAgent) UpdateSelf(node *Node) error { - m.mu.Lock() - agents := make([]uuid.UUID, 0, len(m.subscribedAgents)) - for agent := range m.subscribedAgents { - agents = append(agents, agent) - } - m.mu.Unlock() - - return m.OnNodeUpdate(m.ID, agents, node) + return m.OnNodeUpdate(m.ID, node) } -func (m *MultiAgent) SubscribeAgent(agentID uuid.UUID) (func(), error) { - m.mu.Lock() - defer m.mu.Unlock() - - if closer, ok := m.subscribedAgents[agentID]; ok { - return closer, nil - } - - closer, err := m.OnSubscribe(m.enqueuer(agentID), agentID) - if err != nil { - return nil, err - } - m.subscribedAgents[agentID] = closer - return closer, nil +func (m *MultiAgent) SubscribeAgent(agentID uuid.UUID) error { + return m.OnSubscribe(m, agentID) } -func (m *MultiAgent) UnsubscribeAgent(agentID uuid.UUID) { - m.mu.Lock() - defer 
m.mu.Unlock() - - if closer, ok := m.subscribedAgents[agentID]; ok { - closer() - } - delete(m.subscribedAgents, agentID) +func (m *MultiAgent) UnsubscribeAgent(agentID uuid.UUID) error { + return m.OnUnsubscribe(m, agentID) } -func (m *MultiAgent) NextUpdate(ctx context.Context) []*Node { - for { - select { - case <-ctx.Done(): - return nil - - case nodes := <-m.updates: - return nodes - } - } -} +func (m *MultiAgent) NextUpdate(ctx context.Context) ([]*Node, bool) { + select { + case <-ctx.Done(): + return nil, false -func (m *MultiAgent) enqueuer(agentID uuid.UUID) Enqueueable { - return &multiAgentEnqueuer{ - agentID: agentID, - m: m, + case nodes := <-m.updates: + return nodes, true } } -type multiAgentEnqueuer struct { - m *MultiAgent +func (m *MultiAgent) Enqueue(nodes []*Node) error { + atomic.StoreInt64(&m.lastWrite, time.Now().Unix()) - agentID uuid.UUID - start int64 - lastWrite int64 - overwrites int64 -} + m.mu.RLock() + defer m.mu.RUnlock() -func (m *multiAgentEnqueuer) UniqueID() uuid.UUID { - return m.m.ID -} + if m.closed { + return nil + } -func (m *multiAgentEnqueuer) Enqueue(nodes []*Node) error { select { - case m.m.updates <- nodes: + case m.updates <- nodes: return nil default: return ErrWouldBlock } } -func (m *multiAgentEnqueuer) Name() string { - return "multiagent-" + m.m.ID.String() +func (m *MultiAgent) Name() string { + return m.ID.String() } -func (m *multiAgentEnqueuer) Stats() (start int64, lastWrite int64) { +func (m *MultiAgent) Stats() (start int64, lastWrite int64) { return m.start, atomic.LoadInt64(&m.lastWrite) } -func (m *multiAgentEnqueuer) Overwrites() int64 { +func (m *MultiAgent) Overwrites() int64 { return m.overwrites } -func (m *multiAgentEnqueuer) Close() error { - m.m.mu.Lock() - defer m.m.mu.Unlock() - - // Delete without running the closer. If the enqueuer itself gets closed, we - // can assume that the caller is removing it from the coordinator. - delete(m.m.subscribedAgents, m.agentID) - return nil -} - func (m *MultiAgent) IsClosed() bool { - select { - case <-m.closed: - return true - default: - return false - } + m.mu.RLock() + defer m.mu.RUnlock() + return m.closed } -func (m *MultiAgent) CoordinatorClose() { +func (m *MultiAgent) CoordinatorClose() error { m.mu.Lock() - if !m.IsClosed() { - close(m.closed) + if !m.closed { + m.closed = true close(m.updates) } m.mu.Unlock() + return nil } func (m *MultiAgent) Close() error { - m.mu.Lock() - defer m.mu.Unlock() - if m.IsClosed() { - return nil - } - close(m.closed) - close(m.updates) - for _, closer := range m.subscribedAgents { - closer() - } + _ = m.CoordinatorClose() + m.closeOnce.Do(func() { m.OnRemove(m.ID) }) return nil } diff --git a/tailnet/trackedconn.go b/tailnet/trackedconn.go new file mode 100644 index 0000000000000..b6459960c7954 --- /dev/null +++ b/tailnet/trackedconn.go @@ -0,0 +1,146 @@ +package tailnet + +import ( + "bytes" + "context" + "encoding/json" + "net" + "sync/atomic" + "time" + + "github.com/google/uuid" + + "cdr.dev/slog" +) + +// WriteTimeout is the amount of time we wait to write a node update to a connection before we declare it hung. +// It is exported so that tests can use it. +const WriteTimeout = time.Second * 5 + +type TrackedConn struct { + ctx context.Context + cancel func() + conn net.Conn + updates chan []*Node + logger slog.Logger + lastData []byte + + // ID is an ephemeral UUID used to uniquely identify the owner of the + // connection. 
+ id uuid.UUID + + name string + start int64 + lastWrite int64 + overwrites int64 +} + +func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.UUID, logger slog.Logger, overwrites int64) *TrackedConn { + // buffer updates so they don't block, since we hold the + // coordinator mutex while queuing. Node updates don't + // come quickly, so 512 should be plenty for all but + // the most pathological cases. + updates := make(chan []*Node, 512) + now := time.Now().Unix() + return &TrackedConn{ + ctx: ctx, + conn: conn, + cancel: cancel, + updates: updates, + logger: logger, + id: id, + start: now, + lastWrite: now, + overwrites: overwrites, + } +} + +func (t *TrackedConn) Enqueue(n []*Node) (err error) { + atomic.StoreInt64(&t.lastWrite, time.Now().Unix()) + select { + case t.updates <- n: + return nil + default: + return ErrWouldBlock + } +} + +func (t *TrackedConn) UniqueID() uuid.UUID { + return t.id +} + +func (t *TrackedConn) Name() string { + return t.name +} + +func (t *TrackedConn) Stats() (start, lastWrite int64) { + return t.start, atomic.LoadInt64(&t.lastWrite) +} + +func (t *TrackedConn) Overwrites() int64 { + return t.overwrites +} + +func (t *TrackedConn) CoordinatorClose() error { + return t.Close() +} + +// Close the connection and cancel the context for reading node updates from the queue +func (t *TrackedConn) Close() error { + t.cancel() + return t.conn.Close() +} + +// SendUpdates reads node updates and writes them to the connection. Ends when writes hit an error or context is +// canceled. +func (t *TrackedConn) SendUpdates() { + for { + select { + case <-t.ctx.Done(): + t.logger.Debug(t.ctx, "done sending updates") + return + case nodes := <-t.updates: + data, err := json.Marshal(nodes) + if err != nil { + t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes)) + return + } + if bytes.Equal(t.lastData, data) { + t.logger.Debug(t.ctx, "skipping duplicate update", slog.F("nodes", nodes)) + continue + } + + // Set a deadline so that hung connections don't put back pressure on the system. + // Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. + err = t.conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) + if err != nil { + // often, this is just because the connection is closed/broken, so only log at debug. + t.logger.Debug(t.ctx, "unable to set write deadline", slog.Error(err)) + _ = t.Close() + return + } + _, err = t.conn.Write(data) + if err != nil { + // often, this is just because the connection is closed/broken, so only log at debug. + t.logger.Debug(t.ctx, "could not write nodes to connection", slog.Error(err), slog.F("nodes", nodes)) + _ = t.Close() + return + } + t.logger.Debug(t.ctx, "wrote nodes", slog.F("nodes", nodes)) + + // nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are + // *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write() + // fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context. + // If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after + // our successful write, it is important that we reset the deadline before it fires. + err = t.conn.SetWriteDeadline(time.Time{}) + if err != nil { + // often, this is just because the connection is closed/broken, so only log at debug. 
+ t.logger.Debug(t.ctx, "unable to extend write deadline", slog.Error(err)) + _ = t.Close() + return + } + t.lastData = data + } + } +} From 11f2805ca76b77624799e691af87131dab079e32 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 30 Jun 2023 02:56:59 +0000 Subject: [PATCH 12/19] fixups --- tailnet/coordinator.go | 31 +++++++++++++------------------ 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index b7b8fe9c28537..046ef9489b987 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -173,19 +173,21 @@ type core struct { // are subscribed to updates for that agent. agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]Enqueueable - clients map[uuid.UUID]Enqueueable + // clients holds a map of all clients connected to the coordinator. This is + // necessary because a client may not be subscribed into any agents. + clients map[uuid.UUID]Enqueueable + // clientsToAgents is an index of clients to all of their subscribed agents. clientsToAgents map[uuid.UUID]map[uuid.UUID]struct{} // agentNameCache holds a cache of agent names. If one of them disappears, // it's helpful to have a name cached for debugging. agentNameCache *lru.Cache[uuid.UUID, string] + // legacyAgents holda a mapping of all agents detected as legacy, meaning + // they only listen on codersdk.WorkspaceAgentIP. They aren't compatible + // with the new ServerTailnet, so they must be connected through + // wsconncache. legacyAgents map[uuid.UUID]struct{} - // multiAgents holds all of the unique multiAgents listening on this - // coordinator. We need to keep track of these separately because we need to - // make sure they're closed on coordinator shutdown. If not, they won't be - // able to reopen another multiAgent on the new coordinator. - // multiAgents map[uuid.UUID]*MultiAgent } type Enqueueable interface { @@ -658,20 +660,13 @@ func (c *core) close() error { }() } - for _, connMap := range c.agentToConnectionSockets { - wg.Add(len(connMap)) - for _, socket := range connMap { - socket := socket - go func() { - _ = socket.CoordinatorClose() - wg.Done() - }() - } - } - // Ensure clients that have no subscriptions are properly closed. for _, client := range c.clients { - _ = client.CoordinatorClose() + client := client + go func() { + _ = client.CoordinatorClose() + wg.Done() + }() } c.mutex.Unlock() From a88be6676a98ab90b69285e8e385f8d7b7b4576a Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 30 Jun 2023 03:11:08 +0000 Subject: [PATCH 13/19] fixup! fixups --- enterprise/tailnet/coordinator.go | 45 ++++++++++--------------------- tailnet/coordinator.go | 2 +- tailnet/multiagent.go | 10 ++++--- 3 files changed, 21 insertions(+), 36 deletions(-) diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index 3deedeb9d3a7c..cfef8d521e619 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -102,23 +102,6 @@ func (c *haCoordinator) clientSubscribeToAgent(enq agpl.Enqueueable, agentID uui return nil } -// func (c *haCoordinator) multiAgentUpdate(id uuid.UUID, agents []uuid.UUID, node *agpl.Node) error { -// var errs *multierror.Error -// // This isn't the most efficient, but this coordinator is being deprecated -// // soon anyways. 
-// 	for _, agent := range agents {
-// 		err := c.handleClientUpdate(id, agent, node)
-// 		if err != nil {
-// 			errs = multierror.Append(errs, err)
-// 		}
-// 	}
-// 	if errs != nil {
-// 		return errs
-// 	}

-// 	return nil
-// }
-
 type haCoordinator struct {
 	id  uuid.UUID
 	log slog.Logger
@@ -135,13 +118,20 @@ type haCoordinator struct {
 	// are subscribed to updates for that agent.
 	agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]agpl.Enqueueable
 
-	clients map[uuid.UUID]agpl.Enqueueable
+	// clients holds a map of all clients connected to the coordinator. This is
+	// necessary because a client may not be subscribed into any agents.
+	clients map[uuid.UUID]agpl.Enqueueable
+	// clientsToAgents is an index of clients to all of their subscribed agents.
 	clientsToAgents map[uuid.UUID]map[uuid.UUID]struct{}
 
	// agentNameCache holds a cache of agent names. If one of them disappears,
 	// it's helpful to have a name cached for debugging.
 	agentNameCache *lru.Cache[uuid.UUID, string]
 
+	// legacyAgents holds a mapping of all agents detected as legacy, meaning
+	// they only listen on codersdk.WorkspaceAgentIP. They aren't compatible
+	// with the new ServerTailnet, so they must be connected through
+	// wsconncache.
 	legacyAgents map[uuid.UUID]struct{}
 }
 
@@ -438,20 +428,13 @@ func (c *haCoordinator) Close() error {
 		}()
 	}
 
-	for _, connMap := range c.agentToConnectionSockets {
-		wg.Add(len(connMap))
-		for _, socket := range connMap {
-			socket := socket
-			go func() {
-				_ = socket.CoordinatorClose()
-				wg.Done()
-			}()
-		}
-	}
-
-	// Ensure clients that have no subscriptions are properly closed.
+	wg.Add(len(c.clients))
 	for _, client := range c.clients {
-		_ = client.CoordinatorClose()
+		client := client
+		go func() {
+			_ = client.CoordinatorClose()
+			wg.Done()
+		}()
 	}
 
 	wg.Wait()

diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go
index 046ef9489b987..61affb4eb02a8 100644
--- a/tailnet/coordinator.go
+++ b/tailnet/coordinator.go
@@ -660,7 +660,7 @@ func (c *core) close() error {
 		}()
 	}
 
-	// Ensure clients that have no subscriptions are properly closed.
+	wg.Add(len(c.clients))
 	for _, client := range c.clients {
 		client := client
 		go func() {

diff --git a/tailnet/multiagent.go b/tailnet/multiagent.go
index 3b24489d34d5f..950b951c227fa 100644
--- a/tailnet/multiagent.go
+++ b/tailnet/multiagent.go
@@ -34,10 +34,12 @@ type MultiAgent struct {
 	OnNodeUpdate func(id uuid.UUID, node *Node) error
 	OnRemove     func(id uuid.UUID)
 
-	updates   chan []*Node
-	closeOnce sync.Once
-	start     int64
-	lastWrite int64
+	updates   chan []*Node
+	closeOnce sync.Once
+	start     int64
+	lastWrite int64
+	// Client nodes normally generate a unique id for each connection so
+	// overwrites are really not an issue, but it is provided for compatibility.
overwrites int64 } From 5e4b631764cf08b34864562e33f360ae78bab63b Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 30 Jun 2023 19:33:28 +0000 Subject: [PATCH 14/19] fixes --- enterprise/tailnet/coordinator.go | 70 +++++++++----- go.mod | 2 +- tailnet/coordinator.go | 153 ++++++++++-------------------- tailnet/multiagent.go | 32 ++++++- 4 files changed, 125 insertions(+), 132 deletions(-) diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index cfef8d521e619..c3d6b2650ce18 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -38,11 +38,11 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err closeFunc: cancelFunc, close: make(chan struct{}), nodes: map[uuid.UUID]*agpl.Node{}, - agentSockets: map[uuid.UUID]agpl.Enqueueable{}, - agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]agpl.Enqueueable{}, + agentSockets: map[uuid.UUID]agpl.Queue{}, + agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]agpl.Queue{}, agentNameCache: nameCache, - clients: map[uuid.UUID]agpl.Enqueueable{}, - clientsToAgents: map[uuid.UUID]map[uuid.UUID]struct{}{}, + clients: map[uuid.UUID]agpl.Queue{}, + clientsToAgents: map[uuid.UUID]map[uuid.UUID]agpl.Queue{}, legacyAgents: map[uuid.UUID]struct{}{}, } @@ -62,14 +62,18 @@ func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { OnNodeUpdate: c.clientNodeUpdate, OnRemove: c.clientDisconnected, }).Init() + c.addClient(id, m) + return m +} + +func (c *haCoordinator) addClient(id uuid.UUID, q agpl.Queue) { c.mutex.Lock() - c.clients[id] = m - c.clientsToAgents[id] = map[uuid.UUID]struct{}{} + c.clients[id] = q + c.clientsToAgents[id] = map[uuid.UUID]agpl.Queue{} c.mutex.Unlock() - return m } -func (c *haCoordinator) clientSubscribeToAgent(enq agpl.Enqueueable, agentID uuid.UUID) error { +func (c *haCoordinator) clientSubscribeToAgent(enq agpl.Queue, agentID uuid.UUID) error { c.mutex.Lock() defer c.mutex.Unlock() @@ -113,16 +117,16 @@ type haCoordinator struct { // nodes maps agent and connection IDs their respective node. nodes map[uuid.UUID]*agpl.Node // agentSockets maps agent IDs to their open websocket. - agentSockets map[uuid.UUID]agpl.Enqueueable + agentSockets map[uuid.UUID]agpl.Queue // agentToConnectionSockets maps agent IDs to connection IDs of conns that // are subscribed to updates for that agent. - agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]agpl.Enqueueable + agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]agpl.Queue // clients holds a map of all clients connected to the coordinator. This is // necessary because a client may not be subscribed into any agents. - clients map[uuid.UUID]agpl.Enqueueable + clients map[uuid.UUID]agpl.Queue // clientsToAgents is an index of clients to all of their subscribed agents. - clientsToAgents map[uuid.UUID]map[uuid.UUID]struct{} + clientsToAgents map[uuid.UUID]map[uuid.UUID]agpl.Queue // agentNameCache holds a cache of agent names. If one of them disappears, // it's helpful to have a name cached for debugging. 
@@ -158,15 +162,18 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error defer cancel() logger := c.clientLogger(id, agentID) - ma := c.ServeMultiAgent(id) - defer ma.Close() + tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, 0) + defer tc.Close() - err := ma.SubscribeAgent(agentID) + c.addClient(id, tc) + defer c.clientDisconnected(id) + + err := c.clientSubscribeToAgent(tc, agentID) if err != nil { return xerrors.Errorf("subscribe agent: %w", err) } - go agpl.SendUpdatesToConn(ctx, logger, ma, conn) + go tc.SendUpdates() decoder := json.NewDecoder(conn) // Indefinitely handle messages from the client websocket. @@ -181,14 +188,14 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error } } -func (c *haCoordinator) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq agpl.Enqueueable) { +func (c *haCoordinator) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq agpl.Queue) { connectionSockets, ok := c.agentToConnectionSockets[agentID] if !ok { - connectionSockets = map[uuid.UUID]agpl.Enqueueable{} + connectionSockets = map[uuid.UUID]agpl.Queue{} c.agentToConnectionSockets[agentID] = connectionSockets } connectionSockets[enq.UniqueID()] = enq - c.clientsToAgents[enq.UniqueID()][agentID] = struct{}{} + c.clientsToAgents[enq.UniqueID()][agentID] = c.agentSockets[agentID] } func (c *haCoordinator) clientDisconnected(id uuid.UUID) { @@ -231,11 +238,20 @@ func (c *haCoordinator) clientNodeUpdate(id uuid.UUID, node *agpl.Node) error { // to establish connections. c.nodes[id] = node - for agentID := range c.clientsToAgents[id] { - // Write the new node from this client to the actively connected agent. - err := c.sendNodeToAgentLocked(agentID, node) - if err != nil { - c.log.Error(context.Background(), "send node to agent", slog.Error(err), slog.F("agent_id", agentID)) + for agentID, agentSocket := range c.clientsToAgents[id] { + if agentSocket == nil { + // If we don't own the agent locally, send it over pubsub to a node that + // owns the agent. + err := c.publishNodesToAgent(agentID, []*agpl.Node{node}) + if err != nil { + c.log.Error(context.Background(), "publish node to agent", slog.Error(err), slog.F("agent_id", agentID)) + } + } else { + // Write the new node from this client to the actively connected agent. 
+ err := agentSocket.Enqueue([]*agpl.Node{node}) + if err != nil { + c.log.Error(context.Background(), "enqueue node to agent", slog.Error(err), slog.F("agent_id", agentID)) + } } } @@ -294,6 +310,9 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err } } c.agentSockets[id] = tc + for clientID := range c.agentToConnectionSockets[id] { + c.clientsToAgents[clientID][id] = tc + } c.mutex.Unlock() go tc.SendUpdates() @@ -313,6 +332,9 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err delete(c.agentSockets, id) delete(c.nodes, id) } + for clientID := range c.agentToConnectionSockets[id] { + c.clientsToAgents[clientID][id] = nil + } }() decoder := json.NewDecoder(conn) diff --git a/go.mod b/go.mod index 8560f3a56645b..8b6e9ce7ae986 100644 --- a/go.mod +++ b/go.mod @@ -126,7 +126,7 @@ require ( github.com/mitchellh/go-wordwrap v1.0.1 github.com/mitchellh/mapstructure v1.5.0 github.com/moby/moby v24.0.1+incompatible - github.com/muesli/reflow v0.3.0 + github.com/muesli/reflow v0.3.0 // indirect github.com/open-policy-agent/opa v0.51.0 github.com/ory/dockertest/v3 v3.10.0 github.com/pion/udp v0.1.2 diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 61affb4eb02a8..70c9e23c152a2 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -1,7 +1,6 @@ package tailnet import ( - "bytes" "context" "encoding/json" "errors" @@ -147,14 +146,14 @@ func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn { OnNodeUpdate: c.core.clientNodeUpdate, OnRemove: c.core.clientDisconnected, }).Init() - c.core.addMultiAgent(id, m) + c.core.addClient(id, m) return m } -func (c *core) addMultiAgent(id uuid.UUID, ma *MultiAgent) { +func (c *core) addClient(id uuid.UUID, ma Queue) { c.mutex.Lock() c.clients[id] = ma - c.clientsToAgents[id] = map[uuid.UUID]struct{}{} + c.clientsToAgents[id] = map[uuid.UUID]Queue{} c.mutex.Unlock() } @@ -168,16 +167,16 @@ type core struct { // nodes maps agent and connection IDs their respective node. nodes map[uuid.UUID]*Node // agentSockets maps agent IDs to their open websocket. - agentSockets map[uuid.UUID]Enqueueable + agentSockets map[uuid.UUID]Queue // agentToConnectionSockets maps agent IDs to connection IDs of conns that // are subscribed to updates for that agent. - agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]Enqueueable + agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]Queue // clients holds a map of all clients connected to the coordinator. This is // necessary because a client may not be subscribed into any agents. - clients map[uuid.UUID]Enqueueable + clients map[uuid.UUID]Queue // clientsToAgents is an index of clients to all of their subscribed agents. - clientsToAgents map[uuid.UUID]map[uuid.UUID]struct{} + clientsToAgents map[uuid.UUID]map[uuid.UUID]Queue // agentNameCache holds a cache of agent names. If one of them disappears, // it's helpful to have a name cached for debugging. 
@@ -190,7 +189,7 @@ type core struct { legacyAgents map[uuid.UUID]struct{} } -type Enqueueable interface { +type Queue interface { UniqueID() uuid.UUID Enqueue(n []*Node) error Name() string @@ -210,12 +209,12 @@ func newCore(logger slog.Logger) *core { logger: logger, closed: false, nodes: map[uuid.UUID]*Node{}, - agentSockets: map[uuid.UUID]Enqueueable{}, - agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]Enqueueable{}, + agentSockets: map[uuid.UUID]Queue{}, + agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]Queue{}, agentNameCache: nameCache, legacyAgents: map[uuid.UUID]struct{}{}, - clients: map[uuid.UUID]Enqueueable{}, - clientsToAgents: map[uuid.UUID]map[uuid.UUID]struct{}{}, + clients: map[uuid.UUID]Queue{}, + clientsToAgents: map[uuid.UUID]map[uuid.UUID]Queue{}, } } @@ -261,17 +260,20 @@ func (c *coordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error { logger := c.core.clientLogger(id, agentID) logger.Debug(ctx, "coordinating client") - ma := c.ServeMultiAgent(id) - defer ma.Close() + tc := NewTrackedConn(ctx, cancel, conn, id, logger, 0) + defer tc.Close() - err := ma.SubscribeAgent(agentID) + c.core.addClient(id, tc) + defer c.core.clientDisconnected(id) + + err := c.core.clientSubscribeToAgent(tc, agentID) if err != nil { return xerrors.Errorf("subscribe agent: %w", err) } // On this goroutine, we read updates from the client and publish them. We start a second goroutine // to write updates back to the client. - go SendUpdatesToConn(ctx, logger, ma, conn) + go tc.SendUpdates() decoder := json.NewDecoder(conn) for { @@ -286,75 +288,19 @@ func (c *coordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error { } } -func SendUpdatesToConn(ctx context.Context, logger slog.Logger, ma MultiAgentConn, conn net.Conn) { - defer logger.Debug(ctx, "done sending updates") - defer func() { - _ = ma.Close() - _ = conn.Close() - }() - - lastData := []byte{} - - for { - nodes, ok := ma.NextUpdate(ctx) - if !ok { - return - } - - data, err := json.Marshal(nodes) - if err != nil { - logger.Error(ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes)) - return - } - if bytes.Equal(lastData, data) { - logger.Debug(ctx, "skipping duplicate update", slog.F("nodes", nodes)) - continue - } - - // Set a deadline so that hung connections don't put back pressure on the system. - // Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. - err = conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) - if err != nil { - // often, this is just because the connection is closed/broken, so only log at debug. - logger.Debug(ctx, "unable to set write deadline", slog.Error(err)) - return - } - _, err = conn.Write(data) - if err != nil { - // often, this is just because the connection is closed/broken, so only log at debug. - logger.Debug(ctx, "could not write nodes to connection", slog.Error(err), slog.F("nodes", nodes)) - return - } - logger.Debug(ctx, "wrote nodes", slog.F("nodes", nodes)) - - // nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are - // *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write() - // fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context. - // If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after - // our successful write, it is important that we reset the deadline before it fires. 
- err = conn.SetWriteDeadline(time.Time{}) - if err != nil { - // often, this is just because the connection is closed/broken, so only log at debug. - logger.Debug(ctx, "unable to extend write deadline", slog.Error(err)) - return - } - lastData = data - } -} - func (c *core) clientLogger(id, agent uuid.UUID) slog.Logger { return c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent)) } -func (c *core) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq Enqueueable) { +func (c *core) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq Queue) { connectionSockets, ok := c.agentToConnectionSockets[agentID] if !ok { - connectionSockets = map[uuid.UUID]Enqueueable{} + connectionSockets = map[uuid.UUID]Queue{} c.agentToConnectionSockets[agentID] = connectionSockets } connectionSockets[enq.UniqueID()] = enq - c.clientsToAgents[enq.UniqueID()][agentID] = struct{}{} + c.clientsToAgents[enq.UniqueID()][agentID] = c.agentSockets[agentID] } func (c *core) clientDisconnected(id uuid.UUID) { @@ -414,10 +360,15 @@ func (c *core) clientNodeUpdateLocked(id uuid.UUID, node *Node) error { logger := c.clientLogger(id, uuid.Nil) agents := []uuid.UUID{} - for agentID := range c.clientsToAgents[id] { - err := c.sendNodeToAgentLocked(agentID, node) + for agentID, agentSocket := range c.clientsToAgents[id] { + if agentSocket == nil { + logger.Debug(context.Background(), "enqueue node to agent; socket is nil", slog.F("agent_id", agentID)) + continue + } + + err := agentSocket.Enqueue([]*Node{node}) if err != nil { - logger.Debug(context.Background(), "unable to send node to agent", slog.Error(err), slog.F("agent_id", agentID)) + logger.Debug(context.Background(), "unable to Enqueue node to agent", slog.Error(err), slog.F("agent_id", agentID)) continue } agents = append(agents, agentID) @@ -427,25 +378,11 @@ func (c *core) clientNodeUpdateLocked(id uuid.UUID, node *Node) error { return nil } -func (c *core) sendNodeToAgentLocked(agentID uuid.UUID, node *Node) error { - agentSocket, ok := c.agentSockets[agentID] - if !ok { - return xerrors.New("no agent socket") - } - - err := agentSocket.Enqueue([]*Node{node}) - if err != nil { - return xerrors.Errorf("enqueue client to agent: %w", err) - } - - return nil -} - -func (c *core) clientSubscribeToAgent(enq Enqueueable, agentID uuid.UUID) error { +func (c *core) clientSubscribeToAgent(enq Queue, agentID uuid.UUID) error { c.mutex.Lock() defer c.mutex.Unlock() - logger := c.clientLogger(enq.UniqueID(), uuid.Nil) + logger := c.clientLogger(enq.UniqueID(), agentID) c.initOrSetAgentConnectionSocketLocked(agentID, enq) @@ -453,12 +390,17 @@ func (c *core) clientSubscribeToAgent(enq Enqueueable, agentID uuid.UUID) error if ok { // If we have the node, send it to the agent. If not, it will be sent // async. 
- err := c.sendNodeToAgentLocked(agentID, node) - if err != nil { - logger.Debug(context.Background(), "unable to send node to agent", slog.Error(err), slog.F("agent_id", agentID)) + agentSocket, ok := c.agentSockets[agentID] + if !ok { + logger.Debug(context.Background(), "subscribe to agent; socket is nil") + } else { + err := agentSocket.Enqueue([]*Node{node}) + if err != nil { + return xerrors.Errorf("enqueue client to agent: %w", err) + } } } else { - logger.Debug(context.Background(), "multiagent node doesn't exist", slog.F("multiagent_id", enq.UniqueID())) + logger.Debug(context.Background(), "multiagent node doesn't exist") } agentNode, ok := c.nodes[agentID] @@ -521,6 +463,9 @@ func (c *core) agentDisconnected(id, unique uuid.UUID) { delete(c.nodes, id) logger.Debug(context.Background(), "deleted agent socket and node") } + for clientID := range c.agentToConnectionSockets[id] { + c.clientsToAgents[clientID][id] = nil + } } // initAndTrackAgent creates a TrackedConn for the agent, and sends any initial nodes updates if we have any. It is @@ -572,6 +517,10 @@ func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Co } c.agentSockets[id] = tc + for clientID := range c.agentToConnectionSockets[id] { + c.clientsToAgents[clientID][id] = tc + } + logger.Debug(ctx, "added agent socket") return tc, nil } @@ -691,8 +640,8 @@ func (c *core) serveHTTPDebug(w http.ResponseWriter, r *http.Request) { } func CoordinatorHTTPDebug( - agentSocketsMap map[uuid.UUID]Enqueueable, - agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]Enqueueable, + agentSocketsMap map[uuid.UUID]Queue, + agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]Queue, agentNameCache *lru.Cache[uuid.UUID, string], ) func(w http.ResponseWriter, _ *http.Request) { return func(w http.ResponseWriter, _ *http.Request) { @@ -700,7 +649,7 @@ func CoordinatorHTTPDebug( type idConn struct { id uuid.UUID - conn Enqueueable + conn Queue } { diff --git a/tailnet/multiagent.go b/tailnet/multiagent.go index 950b951c227fa..ffe320c782ddd 100644 --- a/tailnet/multiagent.go +++ b/tailnet/multiagent.go @@ -7,6 +7,7 @@ import ( "time" "github.com/google/uuid" + "golang.org/x/xerrors" "cdr.dev/slog" ) @@ -22,15 +23,16 @@ type MultiAgentConn interface { } type MultiAgent struct { - mu sync.RWMutex + mu sync.RWMutex + closed bool ID uuid.UUID Logger slog.Logger AgentIsLegacyFunc func(agentID uuid.UUID) bool - OnSubscribe func(enq Enqueueable, agent uuid.UUID) error - OnUnsubscribe func(enq Enqueueable, agent uuid.UUID) error + OnSubscribe func(enq Queue, agent uuid.UUID) error + OnUnsubscribe func(enq Queue, agent uuid.UUID) error OnNodeUpdate func(id uuid.UUID, node *Node) error OnRemove func(id uuid.UUID) @@ -57,15 +59,35 @@ func (m *MultiAgent) AgentIsLegacy(agentID uuid.UUID) bool { return m.AgentIsLegacyFunc(agentID) } +var ErrMultiAgentClosed = xerrors.New("multiagent is closed") + func (m *MultiAgent) UpdateSelf(node *Node) error { + m.mu.RLock() + defer m.mu.RUnlock() + if m.closed { + return ErrMultiAgentClosed + } + return m.OnNodeUpdate(m.ID, node) } func (m *MultiAgent) SubscribeAgent(agentID uuid.UUID) error { + m.mu.RLock() + defer m.mu.RUnlock() + if m.closed { + return ErrMultiAgentClosed + } + return m.OnSubscribe(m, agentID) } func (m *MultiAgent) UnsubscribeAgent(agentID uuid.UUID) error { + m.mu.RLock() + defer m.mu.RUnlock() + if m.closed { + return ErrMultiAgentClosed + } + return m.OnUnsubscribe(m, agentID) } @@ -74,8 +96,8 @@ func (m *MultiAgent) NextUpdate(ctx context.Context) ([]*Node, bool) { 
case <-ctx.Done(): return nil, false - case nodes := <-m.updates: - return nodes, true + case nodes, ok := <-m.updates: + return nodes, ok } } From dd3cc15e79957982179a883ba1c7ff02dc3535b7 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 7 Jul 2023 17:04:18 +0000 Subject: [PATCH 15/19] comment --- tailnet/coordinator.go | 1 + 1 file changed, 1 insertion(+) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 70c9e23c152a2..b807f2e0254fe 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -195,6 +195,7 @@ type Queue interface { Name() string Stats() (start, lastWrite int64) Overwrites() int64 + // CoordinatorClose CoordinatorClose() error Close() error } From 8896ae46103f8a8c5490993bfce7502b6b48043b Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 7 Jul 2023 17:04:34 +0000 Subject: [PATCH 16/19] fixup! comment --- tailnet/coordinator.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index b807f2e0254fe..4334b49317308 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -195,7 +195,8 @@ type Queue interface { Name() string Stats() (start, lastWrite int64) Overwrites() int64 - // CoordinatorClose + // CoordinatorClose is used by the coordinator when closing a Queue. It + // should skip removing itself from the coordinator. CoordinatorClose() error Close() error } From be4db71c4c8be50b65292aca4d21739553f11a28 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 7 Jul 2023 17:11:56 +0000 Subject: [PATCH 17/19] fixup! comment --- enterprise/coderd/appearance_test.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/enterprise/coderd/appearance_test.go b/enterprise/coderd/appearance_test.go index 9fa7ffd863f68..f2ad6be6ff292 100644 --- a/enterprise/coderd/appearance_test.go +++ b/enterprise/coderd/appearance_test.go @@ -9,8 +9,6 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" - "github.com/google/uuid" - "github.com/coder/coder/cli/clibase" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" From 7bfac9be2bafccadadb58439816c5174e10de103 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Tue, 11 Jul 2023 19:38:29 +0000 Subject: [PATCH 18/19] add feature flag --- coderd/apidoc/docs.go | 6 ++++-- coderd/apidoc/swagger.json | 6 ++++-- coderd/coderd.go | 30 ++++++++++++++++++------------ coderd/workspaceagents.go | 2 +- codersdk/deployment.go | 6 ++++++ docs/api/schemas.md | 1 + site/src/api/typesGenerated.ts | 2 ++ 7 files changed, 36 insertions(+), 17 deletions(-) diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index a282e45d4623a..2e373d75a77f7 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -7647,13 +7647,15 @@ const docTemplate = `{ "moons", "workspace_actions", "tailnet_pg_coordinator", - "convert-to-oidc" + "convert-to-oidc", + "single_tailnet" ], "x-enum-varnames": [ "ExperimentMoons", "ExperimentWorkspaceActions", "ExperimentTailnetPGCoordinator", - "ExperimentConvertToOIDC" + "ExperimentConvertToOIDC", + "ExperimentSingleTailnet" ] }, "codersdk.Feature": { diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 2bc936281cdbe..721d9df3cf6a9 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -6843,13 +6843,15 @@ "moons", "workspace_actions", "tailnet_pg_coordinator", - "convert-to-oidc" + "convert-to-oidc", + "single_tailnet" ], "x-enum-varnames": [ "ExperimentMoons", "ExperimentWorkspaceActions", "ExperimentTailnetPGCoordinator", - 
"ExperimentConvertToOIDC" + "ExperimentConvertToOIDC", + "ExperimentSingleTailnet" ] }, "codersdk.Feature": { diff --git a/coderd/coderd.go b/coderd/coderd.go index 59dd56f96c7de..398477568d379 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -364,15 +364,21 @@ func New(options *Options) *API { api.Auditor.Store(&options.Auditor) api.TailnetCoordinator.Store(&options.TailnetCoordinator) - api.tailnet, err = NewServerTailnet(api.ctx, - options.Logger, - options.DERPServer, - options.DERPMap, - &api.TailnetCoordinator, - wsconncache.New(api._dialWorkspaceAgentTailnet, 0), - ) - if err != nil { - panic("failed to setup server tailnet: " + err.Error()) + if api.Experiments.Enabled(codersdk.ExperimentSingleTailnet) { + api.agentProvider, err = NewServerTailnet(api.ctx, + options.Logger, + options.DERPServer, + options.DERPMap, + &api.TailnetCoordinator, + wsconncache.New(api._dialWorkspaceAgentTailnet, 0), + ) + if err != nil { + panic("failed to setup server tailnet: " + err.Error()) + } + } else { + api.agentProvider = &wsconncache.AgentProvider{ + Cache: wsconncache.New(api._dialWorkspaceAgentTailnet, 0), + } } api.workspaceAppServer = &workspaceapps.Server{ @@ -385,7 +391,7 @@ func New(options *Options) *API { RealIPConfig: options.RealIPConfig, SignedTokenProvider: api.WorkspaceAppsProvider, - AgentProvider: api.tailnet, + AgentProvider: api.agentProvider, AppSecurityKey: options.AppSecurityKey, DisablePathApps: options.DeploymentValues.DisablePathApps.Value(), @@ -923,7 +929,7 @@ type API struct { updateChecker *updatecheck.Checker WorkspaceAppsProvider workspaceapps.SignedTokenProvider workspaceAppServer *workspaceapps.Server - tailnet *ServerTailnet + agentProvider workspaceapps.AgentProvider // Experiments contains the list of experiments currently enabled. // This is used to gate features that are not yet ready for production. @@ -950,7 +956,7 @@ func (api *API) Close() error { if coordinator != nil { _ = (*coordinator).Close() } - _ = api.tailnet.Close() + _ = api.agentProvider.Close() return nil } diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 2de3e660996d9..fd7cb0ebe5874 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -655,7 +655,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req return } - agentConn, release, err := api.tailnet.AgentConn(ctx, workspaceAgent.ID) + agentConn, release, err := api.agentProvider.AgentConn(ctx, workspaceAgent.ID) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error dialing workspace agent.", diff --git a/codersdk/deployment.go b/codersdk/deployment.go index 88bc65ea6a5f5..01c45e0c0ae2b 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -1763,6 +1763,12 @@ const ( // oidc. ExperimentConvertToOIDC Experiment = "convert-to-oidc" + // ExperimentSingleTailnet replaces workspace connections inside coderd to + // all use a single tailnet, instead of the previous behavior of creating a + // single tailnet for each agent. + // WARNING: This cannot be enabled when using HA. + ExperimentSingleTailnet Experiment = "single_tailnet" + // Add new experiments here! 
// ExperimentExample Experiment = "example" ) diff --git a/docs/api/schemas.md b/docs/api/schemas.md index 076bf5a631d35..d6a5601b6ded1 100644 --- a/docs/api/schemas.md +++ b/docs/api/schemas.md @@ -2539,6 +2539,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in | `workspace_actions` | | `tailnet_pg_coordinator` | | `convert-to-oidc` | +| `single_tailnet` | ## codersdk.Feature diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 6e0a012a2af4e..e461b9e09f908 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1414,11 +1414,13 @@ export const Entitlements: Entitlement[] = [ export type Experiment = | "convert-to-oidc" | "moons" + | "single_tailnet" | "tailnet_pg_coordinator" | "workspace_actions" export const Experiments: Experiment[] = [ "convert-to-oidc", "moons", + "single_tailnet", "tailnet_pg_coordinator", "workspace_actions", ] From ef440925de7e897716321d329471bec819b2b065 Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Wed, 12 Jul 2023 21:58:26 +0000 Subject: [PATCH 19/19] fixes --- coderd/tailnet.go | 62 +++++++++++++++---------------- enterprise/tailnet/coordinator.go | 42 +++++++++++---------- tailnet/coordinator.go | 40 ++++++++++++++------ tailnet/multiagent.go | 21 +++++++++-- 4 files changed, 99 insertions(+), 66 deletions(-) diff --git a/coderd/tailnet.go b/coderd/tailnet.go index c7b2db7f81e63..a1559e4efcd52 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -56,14 +56,15 @@ func NewServerTailnet( serverCtx, cancel := context.WithCancel(ctx) tn := &ServerTailnet{ - ctx: serverCtx, - cancel: cancel, - logger: logger, - conn: conn, - coord: coord, - cache: cache, - agentNodes: map[uuid.UUID]*tailnetNode{}, - transport: tailnetTransport.Clone(), + ctx: serverCtx, + cancel: cancel, + logger: logger, + conn: conn, + coord: coord, + cache: cache, + agentNodes: map[uuid.UUID]time.Time{}, + agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{}, + transport: tailnetTransport.Clone(), } tn.transport.DialContext = tn.dialContext tn.transport.MaxIdleConnsPerHost = 10 @@ -122,15 +123,19 @@ func (s *ServerTailnet) expireOldAgents() { s.nodesMu.Lock() agentConn := s.getAgentConn() - for agentID, node := range s.agentNodes { - if time.Since(node.lastConnection) > cutoff { - err := agentConn.UnsubscribeAgent(agentID) - if err != nil { - s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID)) - } - delete(s.agentNodes, agentID) - - // TODO(coadler): actually remove from the netmap + for agentID, lastConnection := range s.agentNodes { + // If no one has connected since the cutoff and there are no active + // connections, remove the agent. + if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 { + _ = agentConn + // err := agentConn.UnsubscribeAgent(agentID) + // if err != nil { + // s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID)) + // } + // delete(s.agentNodes, agentID) + + // TODO(coadler): actually remove from the netmap, then reenable + // the above } } s.nodesMu.Unlock() @@ -176,10 +181,6 @@ func (s *ServerTailnet) reinitCoordinator() { s.nodesMu.Unlock() } -type tailnetNode struct { - lastConnection time.Time -} - type ServerTailnet struct { ctx context.Context cancel func() @@ -191,8 +192,10 @@ type ServerTailnet struct { cache *wsconncache.Cache nodesMu sync.Mutex // agentNodes is a map of agent tailnetNodes the server wants to keep a - // connection to. 
-	agentNodes map[uuid.UUID]*tailnetNode
+	// connection to. It contains the last time the agent was connected to.
+	agentNodes map[uuid.UUID]time.Time
+	// agentTickets holds a map of all open connections to an agent.
+	agentTickets map[uuid.UUID]map[uuid.UUID]struct{}
 
 	transport *http.Transport
 }
@@ -237,7 +240,9 @@ func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (
 
 func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
 	s.nodesMu.Lock()
-	tnode, ok := s.agentNodes[agentID]
+	defer s.nodesMu.Unlock()
+
+	_, ok := s.agentNodes[agentID]
 	// If we don't have the node, subscribe.
 	if !ok {
 		s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
@@ -245,15 +250,10 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
 		if err != nil {
 			return xerrors.Errorf("subscribe agent: %w", err)
 		}
-		tnode = &tailnetNode{
-			lastConnection: time.Now(),
-		}
-		s.agentNodes[agentID] = tnode
-	} else {
-		tnode.lastConnection = time.Now()
+		s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
 	}
-	s.nodesMu.Unlock()
+	s.agentNodes[agentID] = time.Now()
 
 	return nil
 }
 
diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go
index c3d6b2650ce18..889df136710c5 100644
--- a/enterprise/tailnet/coordinator.go
+++ b/enterprise/tailnet/coordinator.go
@@ -73,37 +73,34 @@ func (c *haCoordinator) addClient(id uuid.UUID, q agpl.Queue) {
 	c.mutex.Unlock()
 }
 
-func (c *haCoordinator) clientSubscribeToAgent(enq agpl.Queue, agentID uuid.UUID) error {
+func (c *haCoordinator) clientSubscribeToAgent(enq agpl.Queue, agentID uuid.UUID) (*agpl.Node, error) {
 	c.mutex.Lock()
 	defer c.mutex.Unlock()
 
 	c.initOrSetAgentConnectionSocketLocked(agentID, enq)
 
 	node := c.nodes[enq.UniqueID()]
+	if node != nil {
+		err := c.sendNodeToAgentLocked(agentID, node)
+		if err != nil {
+			return nil, xerrors.Errorf("handle client update: %w", err)
+		}
+	}
 
 	agentNode, ok := c.nodes[agentID]
-	// If we have the node locally, publish it immediately to the multiagent.
+	// If we have the node locally, give it back to the multiagent.
 	if ok {
-		err := enq.Enqueue([]*agpl.Node{agentNode})
-		if err != nil {
-			return xerrors.Errorf("enqueue agent on subscribe: %w", err)
-		}
-	} else {
-		// If we don't have the node locally, notify other coordinators.
-		err := c.publishClientHello(agentID)
-		if err != nil {
-			return xerrors.Errorf("publish client hello: %w", err)
-		}
+		return agentNode, nil
 	}
 
-	if node != nil {
-		err := c.sendNodeToAgentLocked(agentID, node)
-		if err != nil {
-			return xerrors.Errorf("handle client update: %w", err)
-		}
+	// If we don't have the node locally, notify other coordinators.
+ err := c.publishClientHello(agentID) + if err != nil { + return nil, xerrors.Errorf("publish client hello: %w", err) } - return nil + // nolint:nilnil + return nil, nil } type haCoordinator struct { @@ -168,11 +165,18 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error c.addClient(id, tc) defer c.clientDisconnected(id) - err := c.clientSubscribeToAgent(tc, agentID) + agentNode, err := c.clientSubscribeToAgent(tc, agentID) if err != nil { return xerrors.Errorf("subscribe agent: %w", err) } + if agentNode != nil { + err := tc.Enqueue([]*agpl.Node{agentNode}) + if err != nil { + logger.Debug(ctx, "enqueue initial node", slog.Error(err)) + } + } + go tc.SendUpdates() decoder := json.NewDecoder(conn) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 4334b49317308..93cf8c67af56b 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -143,6 +143,7 @@ func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn { Logger: c.core.logger, AgentIsLegacyFunc: c.core.agentIsLegacy, OnSubscribe: c.core.clientSubscribeToAgent, + OnUnsubscribe: c.core.clientUnsubscribeFromAgent, OnNodeUpdate: c.core.clientNodeUpdate, OnRemove: c.core.clientDisconnected, }).Init() @@ -268,11 +269,18 @@ func (c *coordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error { c.core.addClient(id, tc) defer c.core.clientDisconnected(id) - err := c.core.clientSubscribeToAgent(tc, agentID) + agentNode, err := c.core.clientSubscribeToAgent(tc, agentID) if err != nil { return xerrors.Errorf("subscribe agent: %w", err) } + if agentNode != nil { + err := tc.Enqueue([]*Node{agentNode}) + if err != nil { + logger.Debug(ctx, "enqueue initial node", slog.Error(err)) + } + } + // On this goroutine, we read updates from the client and publish them. We start a second goroutine // to write updates back to the client. go tc.SendUpdates() @@ -316,16 +324,15 @@ func (c *core) clientDisconnected(id uuid.UUID) { for agentID := range c.clientsToAgents[id] { connectionSockets, ok := c.agentToConnectionSockets[agentID] if !ok { - return + continue } delete(connectionSockets, id) logger.Debug(context.Background(), "deleted client connectionSocket from map", slog.F("agent_id", agentID)) - if len(connectionSockets) != 0 { - return + if len(connectionSockets) == 0 { + delete(c.agentToConnectionSockets, agentID) + logger.Debug(context.Background(), "deleted last client connectionSocket from map", slog.F("agent_id", agentID)) } - delete(c.agentToConnectionSockets, agentID) - logger.Debug(context.Background(), "deleted last client connectionSocket from map", slog.F("agent_id", agentID)) } delete(c.clients, id) @@ -380,7 +387,7 @@ func (c *core) clientNodeUpdateLocked(id uuid.UUID, node *Node) error { return nil } -func (c *core) clientSubscribeToAgent(enq Queue, agentID uuid.UUID) error { +func (c *core) clientSubscribeToAgent(enq Queue, agentID uuid.UUID) (*Node, error) { c.mutex.Lock() defer c.mutex.Unlock() @@ -390,15 +397,15 @@ func (c *core) clientSubscribeToAgent(enq Queue, agentID uuid.UUID) error { node, ok := c.nodes[enq.UniqueID()] if ok { - // If we have the node, send it to the agent. If not, it will be sent - // async. + // If we have the client node, send it to the agent. If not, it will be + // sent async. 
agentSocket, ok := c.agentSockets[agentID] if !ok { logger.Debug(context.Background(), "subscribe to agent; socket is nil") } else { err := agentSocket.Enqueue([]*Node{node}) if err != nil { - return xerrors.Errorf("enqueue client to agent: %w", err) + return nil, xerrors.Errorf("enqueue client to agent: %w", err) } } } else { @@ -409,11 +416,20 @@ func (c *core) clientSubscribeToAgent(enq Queue, agentID uuid.UUID) error { if !ok { // This is ok, once the agent connects the node will be sent over. logger.Debug(context.Background(), "agent node doesn't exist", slog.F("agent_id", agentID)) - return nil } // Send the subscribed agent back to the multi agent. - return enq.Enqueue([]*Node{agentNode}) + return agentNode, nil +} + +func (c *core) clientUnsubscribeFromAgent(enq Queue, agentID uuid.UUID) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + delete(c.clientsToAgents[enq.UniqueID()], agentID) + delete(c.agentToConnectionSockets[agentID], enq.UniqueID()) + + return nil } func (c *core) agentLogger(id uuid.UUID) slog.Logger { diff --git a/tailnet/multiagent.go b/tailnet/multiagent.go index ffe320c782ddd..13300fdce677a 100644 --- a/tailnet/multiagent.go +++ b/tailnet/multiagent.go @@ -31,7 +31,7 @@ type MultiAgent struct { Logger slog.Logger AgentIsLegacyFunc func(agentID uuid.UUID) bool - OnSubscribe func(enq Queue, agent uuid.UUID) error + OnSubscribe func(enq Queue, agent uuid.UUID) (*Node, error) OnUnsubscribe func(enq Queue, agent uuid.UUID) error OnNodeUpdate func(id uuid.UUID, node *Node) error OnRemove func(id uuid.UUID) @@ -78,7 +78,16 @@ func (m *MultiAgent) SubscribeAgent(agentID uuid.UUID) error { return ErrMultiAgentClosed } - return m.OnSubscribe(m, agentID) + node, err := m.OnSubscribe(m, agentID) + if err != nil { + return err + } + + if node != nil { + return m.enqueueLocked([]*Node{node}) + } + + return nil } func (m *MultiAgent) UnsubscribeAgent(agentID uuid.UUID) error { @@ -102,8 +111,6 @@ func (m *MultiAgent) NextUpdate(ctx context.Context) ([]*Node, bool) { } func (m *MultiAgent) Enqueue(nodes []*Node) error { - atomic.StoreInt64(&m.lastWrite, time.Now().Unix()) - m.mu.RLock() defer m.mu.RUnlock() @@ -111,6 +118,12 @@ func (m *MultiAgent) Enqueue(nodes []*Node) error { return nil } + return m.enqueueLocked(nodes) +} + +func (m *MultiAgent) enqueueLocked(nodes []*Node) error { + atomic.StoreInt64(&m.lastWrite, time.Now().Unix()) + select { case m.updates <- nodes: return nil