diff --git a/agent/agent.go b/agent/agent.go index c2e2670a41257..f755587e793ec 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -64,6 +64,7 @@ type Options struct { SSHMaxTimeout time.Duration TailnetListenPort uint16 Subsystem codersdk.AgentSubsystem + Addresses []netip.Prefix PrometheusRegistry *prometheus.Registry } @@ -132,6 +133,7 @@ func New(options Options) Agent { connStatsChan: make(chan *agentsdk.Stats, 1), sshMaxTimeout: options.SSHMaxTimeout, subsystem: options.Subsystem, + addresses: options.Addresses, prometheusRegistry: prometheusRegistry, metrics: newAgentMetrics(prometheusRegistry), @@ -177,6 +179,7 @@ type agent struct { lifecycleStates []agentsdk.PostLifecycleRequest network *tailnet.Conn + addresses []netip.Prefix connStatsChan chan *agentsdk.Stats latestStat atomic.Pointer[agentsdk.Stats] @@ -545,6 +548,10 @@ func (a *agent) run(ctx context.Context) error { } a.logger.Info(ctx, "fetched manifest", slog.F("manifest", manifest)) + if manifest.AgentID == uuid.Nil { + return xerrors.New("nil agentID returned by manifest") + } + // Expand the directory and send it back to coderd so external // applications that rely on the directory can use it. // @@ -630,7 +637,7 @@ func (a *agent) run(ctx context.Context) error { network := a.network a.closeMutex.Unlock() if network == nil { - network, err = a.createTailnet(ctx, manifest.DERPMap, manifest.DisableDirectConnections) + network, err = a.createTailnet(ctx, manifest.AgentID, manifest.DERPMap, manifest.DisableDirectConnections) if err != nil { return xerrors.Errorf("create tailnet: %w", err) } @@ -648,6 +655,11 @@ func (a *agent) run(ctx context.Context) error { a.startReportingConnectionStats(ctx) } else { + // Update the wireguard IPs if the agent ID changed. + err := network.SetAddresses(a.wireguardAddresses(manifest.AgentID)) + if err != nil { + a.logger.Error(ctx, "update tailnet addresses", slog.Error(err)) + } // Update the DERP map and allow/disallow direct connections. network.SetDERPMap(manifest.DERPMap) network.SetBlockEndpoints(manifest.DisableDirectConnections) @@ -661,6 +673,20 @@ func (a *agent) run(ctx context.Context) error { return nil } +func (a *agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix { + if len(a.addresses) == 0 { + return []netip.Prefix{ + // This is the IP that should be used primarily. + netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128), + // We also listen on the legacy codersdk.WorkspaceAgentIP. This + // allows for a transition away from wsconncache. 
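tailnet.IPFromUUID above gives each agent a stable address derived from its UUID, so coderd can address many agents over one tailnet instead of reusing the single hardcoded WorkspaceAgentIP. A minimal sketch of deriving such a /128 from a UUID, assuming a unique-local fd00::/8 layout (the real derivation in tailnet.IPFromUUID may differ):

package main

import (
	"fmt"
	"net/netip"

	"github.com/google/uuid"
)

// ipFromUUID derives a stable IPv6 address from an agent ID. The fd00::/8
// prefix byte and byte layout here are illustrative assumptions only.
func ipFromUUID(id uuid.UUID) netip.Addr {
	var b [16]byte
	b[0] = 0xfd        // unique local address range
	copy(b[1:], id[:]) // drops the last UUID byte; purely illustrative
	return netip.AddrFrom16(b)
}

func main() {
	agentID := uuid.New()
	prefix := netip.PrefixFrom(ipFromUUID(agentID), 128)
	fmt.Println("agent", agentID, "->", prefix)
}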
+ netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128), + } + } + + return a.addresses +} + func (a *agent) trackConnGoroutine(fn func()) error { a.closeMutex.Lock() defer a.closeMutex.Unlock() @@ -675,9 +701,9 @@ func (a *agent) trackConnGoroutine(fn func()) error { return nil } -func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap, disableDirectConnections bool) (_ *tailnet.Conn, err error) { +func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *tailcfg.DERPMap, disableDirectConnections bool) (_ *tailnet.Conn, err error) { network, err := tailnet.NewConn(&tailnet.Options{ - Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, + Addresses: a.wireguardAddresses(agentID), DERPMap: derpMap, Logger: a.logger.Named("tailnet"), ListenPort: a.tailnetListenPort, diff --git a/agent/agent_test.go b/agent/agent_test.go index 8ac7eca050af9..e9b1f485f718a 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -35,7 +35,6 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/goleak" "golang.org/x/crypto/ssh" - "golang.org/x/exp/maps" "golang.org/x/exp/slices" "golang.org/x/xerrors" "tailscale.com/net/speedtest" @@ -45,6 +44,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" "github.com/coder/coder/agent/agentssh" + "github.com/coder/coder/agent/agenttest" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk/agentsdk" @@ -67,7 +67,7 @@ func TestAgent_Stats_SSH(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, stats, _, _ := setupAgent(t, &client{}, 0) + conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) @@ -100,7 +100,7 @@ func TestAgent_Stats_ReconnectingPTY(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, stats, _, _ := setupAgent(t, &client{}, 0) + conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) ptyConn, err := conn.ReconnectingPTY(ctx, uuid.New(), 128, 128, "/bin/bash") require.NoError(t, err) @@ -130,7 +130,7 @@ func TestAgent_Stats_Magic(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{}, 0) + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -157,7 +157,7 @@ func TestAgent_Stats_Magic(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, _, stats, _, _ := setupAgent(t, &client{}, 0) + conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -425,20 +425,19 @@ func TestAgent_Session_TTY_MOTD_Update(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled // Allow the blank identifiers. - conn, client, _, _, _ := setupAgent(t, &client{}, 0) + conn, client, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) for _, test := range tests { test := test - + // Set new banner func and wait for the agent to call it to update the + // banner. 
ready := make(chan struct{}, 2) - client.mu.Lock() - client.getServiceBanner = func() (codersdk.ServiceBannerConfig, error) { + client.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) { select { case ready <- struct{}{}: default: } return test.banner, nil - } - client.mu.Unlock() + }) <-ready <-ready // Wait for two updates to ensure the value has propagated. @@ -542,7 +541,7 @@ func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{}, 0) + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -592,7 +591,7 @@ func TestAgent_Session_TTY_HugeOutputIsNotLost(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{}, 0) + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -922,7 +921,7 @@ func TestAgent_SFTP(t *testing.T) { home = "/" + strings.ReplaceAll(home, "\\", "/") } //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{}, 0) + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -954,7 +953,7 @@ func TestAgent_SCP(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{}, 0) + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -1062,16 +1061,15 @@ func TestAgent_StartupScript(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - client := &client{ - t: t, - agentID: uuid.New(), - manifest: agentsdk.Manifest{ + client := agenttest.NewClient(t, + uuid.New(), + agentsdk.Manifest{ StartupScript: command, DERPMap: &tailcfg.DERPMap{}, }, - statsChan: make(chan *agentsdk.Stats), - coordinator: tailnet.NewCoordinator(logger), - } + make(chan *agentsdk.Stats), + tailnet.NewCoordinator(logger), + ) closer := agent.New(agent.Options{ Client: client, Filesystem: afero.NewMemMapFs(), @@ -1082,36 +1080,35 @@ func TestAgent_StartupScript(t *testing.T) { _ = closer.Close() }) assert.Eventually(t, func() bool { - got := client.getLifecycleStates() + got := client.GetLifecycleStates() return len(got) > 0 && got[len(got)-1] == codersdk.WorkspaceAgentLifecycleReady }, testutil.WaitShort, testutil.IntervalMedium) - require.Len(t, client.getStartupLogs(), 1) - require.Equal(t, output, client.getStartupLogs()[0].Output) + require.Len(t, client.GetStartupLogs(), 1) + require.Equal(t, output, client.GetStartupLogs()[0].Output) }) // This ensures that even when coderd sends back that the startup // script has written too many lines it will still succeed! 
t.Run("OverflowsAndSkips", func(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - client := &client{ - t: t, - agentID: uuid.New(), - manifest: agentsdk.Manifest{ + client := agenttest.NewClient(t, + uuid.New(), + agentsdk.Manifest{ StartupScript: command, DERPMap: &tailcfg.DERPMap{}, }, - patchWorkspaceLogs: func() error { - resp := httptest.NewRecorder() - httpapi.Write(context.Background(), resp, http.StatusRequestEntityTooLarge, codersdk.Response{ - Message: "Too many lines!", - }) - res := resp.Result() - defer res.Body.Close() - return codersdk.ReadBodyAsError(res) - }, - statsChan: make(chan *agentsdk.Stats), - coordinator: tailnet.NewCoordinator(logger), + make(chan *agentsdk.Stats, 50), + tailnet.NewCoordinator(logger), + ) + client.PatchWorkspaceLogs = func() error { + resp := httptest.NewRecorder() + httpapi.Write(context.Background(), resp, http.StatusRequestEntityTooLarge, codersdk.Response{ + Message: "Too many lines!", + }) + res := resp.Result() + defer res.Body.Close() + return codersdk.ReadBodyAsError(res) } closer := agent.New(agent.Options{ Client: client, @@ -1123,10 +1120,10 @@ func TestAgent_StartupScript(t *testing.T) { _ = closer.Close() }) assert.Eventually(t, func() bool { - got := client.getLifecycleStates() + got := client.GetLifecycleStates() return len(got) > 0 && got[len(got)-1] == codersdk.WorkspaceAgentLifecycleReady }, testutil.WaitShort, testutil.IntervalMedium) - require.Len(t, client.getStartupLogs(), 0) + require.Len(t, client.GetStartupLogs(), 0) }) } @@ -1138,28 +1135,26 @@ func TestAgent_Metadata(t *testing.T) { t.Run("Once", func(t *testing.T) { t.Parallel() //nolint:dogsled - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - Metadata: []codersdk.WorkspaceAgentMetadataDescription{ - { - Key: "greeting", - Interval: 0, - Script: echoHello, - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + Metadata: []codersdk.WorkspaceAgentMetadataDescription{ + { + Key: "greeting", + Interval: 0, + Script: echoHello, }, }, }, 0) var gotMd map[string]agentsdk.PostMetadataRequest require.Eventually(t, func() bool { - gotMd = client.getMetadata() + gotMd = client.GetMetadata() return len(gotMd) == 1 }, testutil.WaitShort, testutil.IntervalMedium) collectedAt := gotMd["greeting"].CollectedAt require.Never(t, func() bool { - gotMd = client.getMetadata() + gotMd = client.GetMetadata() if len(gotMd) != 1 { panic("unexpected number of metadata") } @@ -1170,22 +1165,20 @@ func TestAgent_Metadata(t *testing.T) { t.Run("Many", func(t *testing.T) { t.Parallel() //nolint:dogsled - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - Metadata: []codersdk.WorkspaceAgentMetadataDescription{ - { - Key: "greeting", - Interval: 1, - Timeout: 100, - Script: echoHello, - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + Metadata: []codersdk.WorkspaceAgentMetadataDescription{ + { + Key: "greeting", + Interval: 1, + Timeout: 100, + Script: echoHello, }, }, }, 0) var gotMd map[string]agentsdk.PostMetadataRequest require.Eventually(t, func() bool { - gotMd = client.getMetadata() + gotMd = client.GetMetadata() return len(gotMd) == 1 }, testutil.WaitShort, testutil.IntervalMedium) @@ -1195,7 +1188,7 @@ func TestAgent_Metadata(t *testing.T) { } if !assert.Eventually(t, func() bool { - gotMd = client.getMetadata() + gotMd = client.GetMetadata() return gotMd["greeting"].CollectedAt.After(collectedAt1) }, testutil.WaitShort, testutil.IntervalMedium) { t.Fatalf("expected metadata 
to be collected again") @@ -1221,29 +1214,27 @@ func TestAgentMetadata_Timing(t *testing.T) { script = "echo hello | tee -a " + greetingPath ) //nolint:dogsled - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - Metadata: []codersdk.WorkspaceAgentMetadataDescription{ - { - Key: "greeting", - Interval: reportInterval, - Script: script, - }, - { - Key: "bad", - Interval: reportInterval, - Script: "exit 1", - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + Metadata: []codersdk.WorkspaceAgentMetadataDescription{ + { + Key: "greeting", + Interval: reportInterval, + Script: script, + }, + { + Key: "bad", + Interval: reportInterval, + Script: "exit 1", }, }, }, 0) require.Eventually(t, func() bool { - return len(client.getMetadata()) == 2 + return len(client.GetMetadata()) == 2 }, testutil.WaitShort, testutil.IntervalMedium) for start := time.Now(); time.Since(start) < testutil.WaitMedium; time.Sleep(testutil.IntervalMedium) { - md := client.getMetadata() + md := client.GetMetadata() require.Len(t, md, 2, "got: %+v", md) require.Equal(t, "hello\n", md["greeting"].Value) @@ -1285,11 +1276,9 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("StartTimeout", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - StartupScript: "sleep 3", - StartupScriptTimeout: time.Nanosecond, - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + StartupScript: "sleep 3", + StartupScriptTimeout: time.Nanosecond, }, 0) want := []codersdk.WorkspaceAgentLifecycle{ @@ -1299,7 +1288,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - got = client.getLifecycleStates() + got = client.GetLifecycleStates() return slices.Contains(got, want[len(want)-1]) }, testutil.WaitShort, testutil.IntervalMedium) @@ -1309,11 +1298,9 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("StartError", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - StartupScript: "false", - StartupScriptTimeout: 30 * time.Second, - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + StartupScript: "false", + StartupScriptTimeout: 30 * time.Second, }, 0) want := []codersdk.WorkspaceAgentLifecycle{ @@ -1323,7 +1310,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - got = client.getLifecycleStates() + got = client.GetLifecycleStates() return slices.Contains(got, want[len(want)-1]) }, testutil.WaitShort, testutil.IntervalMedium) @@ -1333,11 +1320,9 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("Ready", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - StartupScript: "true", - StartupScriptTimeout: 30 * time.Second, - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + StartupScript: "true", + StartupScriptTimeout: 30 * time.Second, }, 0) want := []codersdk.WorkspaceAgentLifecycle{ @@ -1347,7 +1332,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - got = client.getLifecycleStates() + got = client.GetLifecycleStates() return len(got) > 0 && got[len(got)-1] == want[len(want)-1] }, testutil.WaitShort, testutil.IntervalMedium) @@ -1357,15 +1342,13 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("ShuttingDown", func(t *testing.T) { t.Parallel() - _, client, _, _, closer := setupAgent(t, 
&client{ - manifest: agentsdk.Manifest{ - ShutdownScript: "sleep 3", - StartupScriptTimeout: 30 * time.Second, - }, + _, client, _, _, closer := setupAgent(t, agentsdk.Manifest{ + ShutdownScript: "sleep 3", + StartupScriptTimeout: 30 * time.Second, }, 0) assert.Eventually(t, func() bool { - return slices.Contains(client.getLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady) + return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady) }, testutil.WaitShort, testutil.IntervalMedium) // Start close asynchronously so that we an inspect the state. @@ -1387,7 +1370,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - got = client.getLifecycleStates() + got = client.GetLifecycleStates() return slices.Contains(got, want[len(want)-1]) }, testutil.WaitShort, testutil.IntervalMedium) @@ -1397,15 +1380,13 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("ShutdownTimeout", func(t *testing.T) { t.Parallel() - _, client, _, _, closer := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - ShutdownScript: "sleep 3", - ShutdownScriptTimeout: time.Nanosecond, - }, + _, client, _, _, closer := setupAgent(t, agentsdk.Manifest{ + ShutdownScript: "sleep 3", + ShutdownScriptTimeout: time.Nanosecond, }, 0) assert.Eventually(t, func() bool { - return slices.Contains(client.getLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady) + return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady) }, testutil.WaitShort, testutil.IntervalMedium) // Start close asynchronously so that we an inspect the state. @@ -1428,7 +1409,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - got = client.getLifecycleStates() + got = client.GetLifecycleStates() return slices.Contains(got, want[len(want)-1]) }, testutil.WaitShort, testutil.IntervalMedium) @@ -1438,15 +1419,13 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("ShutdownError", func(t *testing.T) { t.Parallel() - _, client, _, _, closer := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - ShutdownScript: "false", - ShutdownScriptTimeout: 30 * time.Second, - }, + _, client, _, _, closer := setupAgent(t, agentsdk.Manifest{ + ShutdownScript: "false", + ShutdownScriptTimeout: 30 * time.Second, }, 0) assert.Eventually(t, func() bool { - return slices.Contains(client.getLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady) + return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady) }, testutil.WaitShort, testutil.IntervalMedium) // Start close asynchronously so that we an inspect the state. 
@@ -1469,7 +1448,7 @@ func TestAgent_Lifecycle(t *testing.T) { var got []codersdk.WorkspaceAgentLifecycle assert.Eventually(t, func() bool { - got = client.getLifecycleStates() + got = client.GetLifecycleStates() return slices.Contains(got, want[len(want)-1]) }, testutil.WaitShort, testutil.IntervalMedium) @@ -1480,17 +1459,18 @@ func TestAgent_Lifecycle(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) expected := "this-is-shutdown" - client := &client{ - t: t, - agentID: uuid.New(), - manifest: agentsdk.Manifest{ - DERPMap: tailnettest.RunDERPAndSTUN(t), + derpMap, _ := tailnettest.RunDERPAndSTUN(t) + + client := agenttest.NewClient(t, + uuid.New(), + agentsdk.Manifest{ + DERPMap: derpMap, StartupScript: "echo 1", ShutdownScript: "echo " + expected, }, - statsChan: make(chan *agentsdk.Stats), - coordinator: tailnet.NewCoordinator(logger), - } + make(chan *agentsdk.Stats, 50), + tailnet.NewCoordinator(logger), + ) fs := afero.NewMemMapFs() agent := agent.New(agent.Options{ @@ -1536,71 +1516,63 @@ func TestAgent_Startup(t *testing.T) { t.Run("EmptyDirectory", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - StartupScript: "true", - StartupScriptTimeout: 30 * time.Second, - Directory: "", - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + StartupScript: "true", + StartupScriptTimeout: 30 * time.Second, + Directory: "", }, 0) assert.Eventually(t, func() bool { - return client.getStartup().Version != "" + return client.GetStartup().Version != "" }, testutil.WaitShort, testutil.IntervalFast) - require.Equal(t, "", client.getStartup().ExpandedDirectory) + require.Equal(t, "", client.GetStartup().ExpandedDirectory) }) t.Run("HomeDirectory", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - StartupScript: "true", - StartupScriptTimeout: 30 * time.Second, - Directory: "~", - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + StartupScript: "true", + StartupScriptTimeout: 30 * time.Second, + Directory: "~", }, 0) assert.Eventually(t, func() bool { - return client.getStartup().Version != "" + return client.GetStartup().Version != "" }, testutil.WaitShort, testutil.IntervalFast) homeDir, err := os.UserHomeDir() require.NoError(t, err) - require.Equal(t, homeDir, client.getStartup().ExpandedDirectory) + require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory) }) t.Run("NotAbsoluteDirectory", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - StartupScript: "true", - StartupScriptTimeout: 30 * time.Second, - Directory: "coder/coder", - }, + _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + StartupScript: "true", + StartupScriptTimeout: 30 * time.Second, + Directory: "coder/coder", }, 0) assert.Eventually(t, func() bool { - return client.getStartup().Version != "" + return client.GetStartup().Version != "" }, testutil.WaitShort, testutil.IntervalFast) homeDir, err := os.UserHomeDir() require.NoError(t, err) - require.Equal(t, filepath.Join(homeDir, "coder/coder"), client.getStartup().ExpandedDirectory) + require.Equal(t, filepath.Join(homeDir, "coder/coder"), client.GetStartup().ExpandedDirectory) }) t.Run("HomeEnvironmentVariable", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - StartupScript: "true", - StartupScriptTimeout: 30 * time.Second, - Directory: "$HOME", - }, + _, 
client, _, _, _ := setupAgent(t, agentsdk.Manifest{ + StartupScript: "true", + StartupScriptTimeout: 30 * time.Second, + Directory: "$HOME", }, 0) assert.Eventually(t, func() bool { - return client.getStartup().Version != "" + return client.GetStartup().Version != "" }, testutil.WaitShort, testutil.IntervalFast) homeDir, err := os.UserHomeDir() require.NoError(t, err) - require.Equal(t, homeDir, client.getStartup().ExpandedDirectory) + require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory) }) } @@ -1617,7 +1589,7 @@ func TestAgent_ReconnectingPTY(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{}, 0) + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) id := uuid.New() netConn, err := conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash") require.NoError(t, err) @@ -1719,7 +1691,7 @@ func TestAgent_Dial(t *testing.T) { }() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{}, 0) + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) require.True(t, conn.AwaitReachable(context.Background())) conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) require.NoError(t, err) @@ -1739,12 +1711,10 @@ func TestAgent_Speedtest(t *testing.T) { t.Skip("This test is relatively flakey because of Tailscale's speedtest code...") ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - derpMap := tailnettest.RunDERPAndSTUN(t) + derpMap, _ := tailnettest.RunDERPAndSTUN(t) //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{ - manifest: agentsdk.Manifest{ - DERPMap: derpMap, - }, + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{ + DERPMap: derpMap, }, 0) defer conn.Close() res, err := conn.Speedtest(ctx, speedtest.Upload, 250*time.Millisecond) @@ -1761,17 +1731,16 @@ func TestAgent_Reconnect(t *testing.T) { defer coordinator.Close() agentID := uuid.New() - statsCh := make(chan *agentsdk.Stats) - derpMap := tailnettest.RunDERPAndSTUN(t) - client := &client{ - t: t, - agentID: agentID, - manifest: agentsdk.Manifest{ + statsCh := make(chan *agentsdk.Stats, 50) + derpMap, _ := tailnettest.RunDERPAndSTUN(t) + client := agenttest.NewClient(t, + agentID, + agentsdk.Manifest{ DERPMap: derpMap, }, - statsChan: statsCh, - coordinator: coordinator, - } + statsCh, + coordinator, + ) initialized := atomic.Int32{} closer := agent.New(agent.Options{ ExchangeToken: func(ctx context.Context) (string, error) { @@ -1786,7 +1755,7 @@ func TestAgent_Reconnect(t *testing.T) { require.Eventually(t, func() bool { return coordinator.Node(agentID) != nil }, testutil.WaitShort, testutil.IntervalFast) - client.lastWorkspaceAgent() + client.LastWorkspaceAgent() require.Eventually(t, func() bool { return initialized.Load() == 2 }, testutil.WaitShort, testutil.IntervalFast) @@ -1798,16 +1767,15 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) { coordinator := tailnet.NewCoordinator(logger) defer coordinator.Close() - client := &client{ - t: t, - agentID: uuid.New(), - manifest: agentsdk.Manifest{ + client := agenttest.NewClient(t, + uuid.New(), + agentsdk.Manifest{ GitAuthConfigs: 1, DERPMap: &tailcfg.DERPMap{}, }, - statsChan: make(chan *agentsdk.Stats), - coordinator: coordinator, - } + make(chan *agentsdk.Stats, 50), + coordinator, + ) filesystem := afero.NewMemMapFs() closer := agent.New(agent.Options{ ExchangeToken: func(ctx context.Context) (string, error) { @@ -1830,7 +1798,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) { func setupSSHCommand(t *testing.T, beforeArgs 
[]string, afterArgs []string) (*ptytest.PTYCmd, pty.Process) { //nolint:dogsled - agentConn, _, _, _, _ := setupAgent(t, &client{}, 0) + agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) waitGroup := sync.WaitGroup{} @@ -1883,12 +1851,11 @@ func setupSSHSession( ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, _, _, fs, _ := setupAgent(t, &client{ - manifest: options, - getServiceBanner: func() (codersdk.ServiceBannerConfig, error) { + conn, _, _, fs, _ := setupAgent(t, options, 0, func(c *agenttest.Client, _ *agent.Options) { + c.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) { return serviceBanner, nil - }, - }, 0) + }) + }) if prepareFS != nil { prepareFS(fs) } @@ -1905,31 +1872,28 @@ func setupSSHSession( return session } -type closeFunc func() error - -func (c closeFunc) Close() error { - return c() -} - -func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func(agent.Options) agent.Options) ( +func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) ( *codersdk.WorkspaceAgentConn, - *client, + *agenttest.Client, <-chan *agentsdk.Stats, afero.Fs, io.Closer, ) { - c.t = t logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - if c.manifest.DERPMap == nil { - c.manifest.DERPMap = tailnettest.RunDERPAndSTUN(t) + if metadata.DERPMap == nil { + metadata.DERPMap, _ = tailnettest.RunDERPAndSTUN(t) } - c.coordinator = tailnet.NewCoordinator(logger) + if metadata.AgentID == uuid.Nil { + metadata.AgentID = uuid.New() + } + coordinator := tailnet.NewCoordinator(logger) t.Cleanup(func() { - _ = c.coordinator.Close() + _ = coordinator.Close() }) - c.agentID = uuid.New() - c.statsChan = make(chan *agentsdk.Stats, 50) + statsCh := make(chan *agentsdk.Stats, 50) fs := afero.NewMemMapFs() + c := agenttest.NewClient(t, metadata.AgentID, metadata, statsCh, coordinator) + options := agent.Options{ Client: c, Filesystem: fs, @@ -1938,7 +1902,7 @@ func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func( } for _, opt := range opts { - options = opt(options) + opt(c, &options) } closer := agent.New(options) @@ -1947,7 +1911,7 @@ func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func( }) conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, - DERPMap: c.manifest.DERPMap, + DERPMap: metadata.DERPMap, Logger: logger.Named("client"), }) require.NoError(t, err) @@ -1961,15 +1925,15 @@ func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func( }) go func() { defer close(serveClientDone) - c.coordinator.ServeClient(serverConn, uuid.New(), c.agentID) + coordinator.ServeClient(serverConn, uuid.New(), metadata.AgentID) }() sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { return conn.UpdateNodes(node, false) }) conn.SetNodeCallback(sendNode) - agentConn := &codersdk.WorkspaceAgentConn{ - Conn: conn, - } + agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ + AgentID: metadata.AgentID, + }) t.Cleanup(func() { _ = agentConn.Close() }) @@ -1980,7 +1944,7 @@ func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func( if !agentConn.AwaitReachable(ctx) { t.Fatal("agent not reachable") } - return agentConn, c, c.statsChan, fs, closer + 
return agentConn, c, statsCh, fs, closer } var dialTestPayload = []byte("dean-was-here123") @@ -2043,146 +2007,6 @@ func testSessionOutput(t *testing.T, session *ssh.Session, expected, unexpected } } -type client struct { - t *testing.T - agentID uuid.UUID - manifest agentsdk.Manifest - metadata map[string]agentsdk.PostMetadataRequest - statsChan chan *agentsdk.Stats - coordinator tailnet.Coordinator - lastWorkspaceAgent func() - patchWorkspaceLogs func() error - getServiceBanner func() (codersdk.ServiceBannerConfig, error) - - mu sync.Mutex // Protects following. - lifecycleStates []codersdk.WorkspaceAgentLifecycle - startup agentsdk.PostStartupRequest - logs []agentsdk.StartupLog -} - -func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) { - return c.manifest, nil -} - -func (c *client) Listen(_ context.Context) (net.Conn, error) { - clientConn, serverConn := net.Pipe() - closed := make(chan struct{}) - c.lastWorkspaceAgent = func() { - _ = serverConn.Close() - _ = clientConn.Close() - <-closed - } - c.t.Cleanup(c.lastWorkspaceAgent) - go func() { - _ = c.coordinator.ServeAgent(serverConn, c.agentID, "") - close(closed) - }() - return clientConn, nil -} - -func (c *client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) { - doneCh := make(chan struct{}) - ctx, cancel := context.WithCancel(ctx) - - go func() { - defer close(doneCh) - - setInterval(500 * time.Millisecond) - for { - select { - case <-ctx.Done(): - return - case stat := <-statsChan: - select { - case c.statsChan <- stat: - case <-ctx.Done(): - return - default: - // We don't want to send old stats. - continue - } - } - } - }() - return closeFunc(func() error { - cancel() - <-doneCh - close(c.statsChan) - return nil - }), nil -} - -func (c *client) getLifecycleStates() []codersdk.WorkspaceAgentLifecycle { - c.mu.Lock() - defer c.mu.Unlock() - return c.lifecycleStates -} - -func (c *client) PostLifecycle(_ context.Context, req agentsdk.PostLifecycleRequest) error { - c.mu.Lock() - defer c.mu.Unlock() - c.lifecycleStates = append(c.lifecycleStates, req.State) - return nil -} - -func (*client) PostAppHealth(_ context.Context, _ agentsdk.PostAppHealthsRequest) error { - return nil -} - -func (c *client) getStartup() agentsdk.PostStartupRequest { - c.mu.Lock() - defer c.mu.Unlock() - return c.startup -} - -func (c *client) getMetadata() map[string]agentsdk.PostMetadataRequest { - c.mu.Lock() - defer c.mu.Unlock() - return maps.Clone(c.metadata) -} - -func (c *client) PostMetadata(_ context.Context, key string, req agentsdk.PostMetadataRequest) error { - c.mu.Lock() - defer c.mu.Unlock() - if c.metadata == nil { - c.metadata = make(map[string]agentsdk.PostMetadataRequest) - } - c.metadata[key] = req - return nil -} - -func (c *client) PostStartup(_ context.Context, startup agentsdk.PostStartupRequest) error { - c.mu.Lock() - defer c.mu.Unlock() - c.startup = startup - return nil -} - -func (c *client) getStartupLogs() []agentsdk.StartupLog { - c.mu.Lock() - defer c.mu.Unlock() - return c.logs -} - -func (c *client) PatchStartupLogs(_ context.Context, logs agentsdk.PatchStartupLogs) error { - c.mu.Lock() - defer c.mu.Unlock() - if c.patchWorkspaceLogs != nil { - return c.patchWorkspaceLogs() - } - c.logs = append(c.logs, logs.Logs...) 
- return nil -} - -func (c *client) GetServiceBanner(_ context.Context) (codersdk.ServiceBannerConfig, error) { - c.mu.Lock() - defer c.mu.Unlock() - if c.getServiceBanner != nil { - return c.getServiceBanner() - } - return codersdk.ServiceBannerConfig{}, nil -} - // tempDirUnixSocket returns a temporary directory that can safely hold unix // sockets (probably). // @@ -2214,9 +2038,8 @@ func TestAgent_Metrics_SSH(t *testing.T) { registry := prometheus.NewRegistry() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, &client{}, 0, func(o agent.Options) agent.Options { + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { o.PrometheusRegistry = registry - return o }) sshClient, err := conn.SSHClient(ctx) diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go new file mode 100644 index 0000000000000..c69ff59eb730b --- /dev/null +++ b/agent/agenttest/client.go @@ -0,0 +1,189 @@ +package agenttest + +import ( + "context" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "golang.org/x/exp/maps" + + "cdr.dev/slog" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/codersdk/agentsdk" + "github.com/coder/coder/tailnet" +) + +func NewClient(t testing.TB, + agentID uuid.UUID, + manifest agentsdk.Manifest, + statsChan chan *agentsdk.Stats, + coordinator tailnet.Coordinator, +) *Client { + if manifest.AgentID == uuid.Nil { + manifest.AgentID = agentID + } + return &Client{ + t: t, + agentID: agentID, + manifest: manifest, + statsChan: statsChan, + coordinator: coordinator, + } +} + +type Client struct { + t testing.TB + agentID uuid.UUID + manifest agentsdk.Manifest + metadata map[string]agentsdk.PostMetadataRequest + statsChan chan *agentsdk.Stats + coordinator tailnet.Coordinator + LastWorkspaceAgent func() + PatchWorkspaceLogs func() error + GetServiceBannerFunc func() (codersdk.ServiceBannerConfig, error) + + mu sync.Mutex // Protects following. + lifecycleStates []codersdk.WorkspaceAgentLifecycle + startup agentsdk.PostStartupRequest + logs []agentsdk.StartupLog +} + +func (c *Client) Manifest(_ context.Context) (agentsdk.Manifest, error) { + return c.manifest, nil +} + +func (c *Client) Listen(_ context.Context) (net.Conn, error) { + clientConn, serverConn := net.Pipe() + closed := make(chan struct{}) + c.LastWorkspaceAgent = func() { + _ = serverConn.Close() + _ = clientConn.Close() + <-closed + } + c.t.Cleanup(c.LastWorkspaceAgent) + go func() { + _ = c.coordinator.ServeAgent(serverConn, c.agentID, "") + close(closed) + }() + return clientConn, nil +} + +func (c *Client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) { + doneCh := make(chan struct{}) + ctx, cancel := context.WithCancel(ctx) + + go func() { + defer close(doneCh) + + setInterval(500 * time.Millisecond) + for { + select { + case <-ctx.Done(): + return + case stat := <-statsChan: + select { + case c.statsChan <- stat: + case <-ctx.Done(): + return + default: + // We don't want to send old stats. 
+ continue + } + } + } + }() + return closeFunc(func() error { + cancel() + <-doneCh + close(c.statsChan) + return nil + }), nil +} + +func (c *Client) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle { + c.mu.Lock() + defer c.mu.Unlock() + return c.lifecycleStates +} + +func (c *Client) PostLifecycle(_ context.Context, req agentsdk.PostLifecycleRequest) error { + c.mu.Lock() + defer c.mu.Unlock() + c.lifecycleStates = append(c.lifecycleStates, req.State) + return nil +} + +func (*Client) PostAppHealth(_ context.Context, _ agentsdk.PostAppHealthsRequest) error { + return nil +} + +func (c *Client) GetStartup() agentsdk.PostStartupRequest { + c.mu.Lock() + defer c.mu.Unlock() + return c.startup +} + +func (c *Client) GetMetadata() map[string]agentsdk.PostMetadataRequest { + c.mu.Lock() + defer c.mu.Unlock() + return maps.Clone(c.metadata) +} + +func (c *Client) PostMetadata(_ context.Context, key string, req agentsdk.PostMetadataRequest) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.metadata == nil { + c.metadata = make(map[string]agentsdk.PostMetadataRequest) + } + c.metadata[key] = req + return nil +} + +func (c *Client) PostStartup(_ context.Context, startup agentsdk.PostStartupRequest) error { + c.mu.Lock() + defer c.mu.Unlock() + c.startup = startup + return nil +} + +func (c *Client) GetStartupLogs() []agentsdk.StartupLog { + c.mu.Lock() + defer c.mu.Unlock() + return c.logs +} + +func (c *Client) PatchStartupLogs(_ context.Context, logs agentsdk.PatchStartupLogs) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.PatchWorkspaceLogs != nil { + return c.PatchWorkspaceLogs() + } + c.logs = append(c.logs, logs.Logs...) + return nil +} + +func (c *Client) SetServiceBannerFunc(f func() (codersdk.ServiceBannerConfig, error)) { + c.mu.Lock() + defer c.mu.Unlock() + + c.GetServiceBannerFunc = f +} + +func (c *Client) GetServiceBanner(_ context.Context) (codersdk.ServiceBannerConfig, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.GetServiceBannerFunc != nil { + return c.GetServiceBannerFunc() + } + return codersdk.ServiceBannerConfig{}, nil +} + +type closeFunc func() error + +func (c closeFunc) Close() error { + return c() +} diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index 328107b39d54b..31970e84477cc 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -5961,6 +5961,9 @@ const docTemplate = `{ "agentsdk.Manifest": { "type": "object", "properties": { + "agent_id": { + "type": "string" + }, "apps": { "type": "array", "items": { @@ -7617,6 +7620,7 @@ const docTemplate = `{ "workspace_actions", "tailnet_ha_coordinator", "convert-to-oidc", + "single_tailnet", "workspace_build_logs_ui" ], "x-enum-varnames": [ @@ -7624,6 +7628,7 @@ const docTemplate = `{ "ExperimentWorkspaceActions", "ExperimentTailnetHACoordinator", "ExperimentConvertToOIDC", + "ExperimentSingleTailnet", "ExperimentWorkspaceBuildLogsUI" ] }, diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 7ea1a1de0633c..841f9c50bbe5f 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -5251,6 +5251,9 @@ "agentsdk.Manifest": { "type": "object", "properties": { + "agent_id": { + "type": "string" + }, "apps": { "type": "array", "items": { @@ -6818,6 +6821,7 @@ "workspace_actions", "tailnet_ha_coordinator", "convert-to-oidc", + "single_tailnet", "workspace_build_logs_ui" ], "x-enum-varnames": [ @@ -6825,6 +6829,7 @@ "ExperimentWorkspaceActions", "ExperimentTailnetHACoordinator", "ExperimentConvertToOIDC", + "ExperimentSingleTailnet", 
"ExperimentWorkspaceBuildLogsUI" ] }, diff --git a/coderd/coderd.go b/coderd/coderd.go index dc14727879f08..7104049187235 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -364,8 +364,23 @@ func New(options *Options) *API { } api.Auditor.Store(&options.Auditor) - api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0) api.TailnetCoordinator.Store(&options.TailnetCoordinator) + if api.Experiments.Enabled(codersdk.ExperimentSingleTailnet) { + api.agentProvider, err = NewServerTailnet(api.ctx, + options.Logger, + options.DERPServer, + options.DERPMap, + &api.TailnetCoordinator, + wsconncache.New(api._dialWorkspaceAgentTailnet, 0), + ) + if err != nil { + panic("failed to setup server tailnet: " + err.Error()) + } + } else { + api.agentProvider = &wsconncache.AgentProvider{ + Cache: wsconncache.New(api._dialWorkspaceAgentTailnet, 0), + } + } api.workspaceAppServer = &workspaceapps.Server{ Logger: options.Logger.Named("workspaceapps"), @@ -377,7 +392,7 @@ func New(options *Options) *API { RealIPConfig: options.RealIPConfig, SignedTokenProvider: api.WorkspaceAppsProvider, - WorkspaceConnCache: api.workspaceAgentCache, + AgentProvider: api.agentProvider, AppSecurityKey: options.AppSecurityKey, DisablePathApps: options.DeploymentValues.DisablePathApps.Value(), @@ -921,10 +936,10 @@ type API struct { derpCloseFunc func() metricsCache *metricscache.Cache - workspaceAgentCache *wsconncache.Cache updateChecker *updatecheck.Checker WorkspaceAppsProvider workspaceapps.SignedTokenProvider workspaceAppServer *workspaceapps.Server + agentProvider workspaceapps.AgentProvider // Experiments contains the list of experiments currently enabled. // This is used to gate features that are not yet ready for production. @@ -951,7 +966,8 @@ func (api *API) Close() error { if coordinator != nil { _ = (*coordinator).Close() } - return api.workspaceAgentCache.Close() + _ = api.agentProvider.Close() + return nil } func compressHandler(h http.Handler) http.Handler { diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index f4bf035311e6a..d073b48824bd4 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -109,6 +109,7 @@ type Options struct { GitAuthConfigs []*gitauth.Config TrialGenerator func(context.Context, string) error TemplateScheduleStore schedule.TemplateScheduleStore + Coordinator tailnet.Coordinator HealthcheckFunc func(ctx context.Context, apiKey string) *healthcheck.Report HealthcheckTimeout time.Duration diff --git a/coderd/prometheusmetrics/prometheusmetrics_test.go b/coderd/prometheusmetrics/prometheusmetrics_test.go index 5b53fcaa047e4..2ece768671280 100644 --- a/coderd/prometheusmetrics/prometheusmetrics_test.go +++ b/coderd/prometheusmetrics/prometheusmetrics_test.go @@ -302,7 +302,7 @@ func TestAgents(t *testing.T) { coordinator := tailnet.NewCoordinator(slogtest.Make(t, nil).Leveled(slog.LevelDebug)) coordinatorPtr := atomic.Pointer[tailnet.Coordinator]{} coordinatorPtr.Store(&coordinator) - derpMap := tailnettest.RunDERPAndSTUN(t) + derpMap, _ := tailnettest.RunDERPAndSTUN(t) agentInactiveDisconnectTimeout := 1 * time.Hour // don't need to focus on this value in tests registry := prometheus.NewRegistry() diff --git a/coderd/tailnet.go b/coderd/tailnet.go new file mode 100644 index 0000000000000..a1559e4efcd52 --- /dev/null +++ b/coderd/tailnet.go @@ -0,0 +1,339 @@ +package coderd + +import ( + "bufio" + "context" + "net" + "net/http" + "net/http/httputil" + "net/netip" + "net/url" + "sync" + "sync/atomic" + "time" + 
+ "github.com/google/uuid" + "golang.org/x/xerrors" + "tailscale.com/derp" + "tailscale.com/tailcfg" + + "cdr.dev/slog" + "github.com/coder/coder/coderd/wsconncache" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/site" + "github.com/coder/coder/tailnet" +) + +var tailnetTransport *http.Transport + +func init() { + var valid bool + tailnetTransport, valid = http.DefaultTransport.(*http.Transport) + if !valid { + panic("dev error: default transport is the wrong type") + } +} + +// NewServerTailnet creates a new tailnet intended for use by coderd. It +// automatically falls back to wsconncache if a legacy agent is encountered. +func NewServerTailnet( + ctx context.Context, + logger slog.Logger, + derpServer *derp.Server, + derpMap *tailcfg.DERPMap, + coord *atomic.Pointer[tailnet.Coordinator], + cache *wsconncache.Cache, +) (*ServerTailnet, error) { + logger = logger.Named("servertailnet") + conn, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, + DERPMap: derpMap, + Logger: logger, + }) + if err != nil { + return nil, xerrors.Errorf("create tailnet conn: %w", err) + } + + serverCtx, cancel := context.WithCancel(ctx) + tn := &ServerTailnet{ + ctx: serverCtx, + cancel: cancel, + logger: logger, + conn: conn, + coord: coord, + cache: cache, + agentNodes: map[uuid.UUID]time.Time{}, + agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{}, + transport: tailnetTransport.Clone(), + } + tn.transport.DialContext = tn.dialContext + tn.transport.MaxIdleConnsPerHost = 10 + tn.transport.MaxIdleConns = 0 + agentConn := (*coord.Load()).ServeMultiAgent(uuid.New()) + tn.agentConn.Store(&agentConn) + + err = tn.getAgentConn().UpdateSelf(conn.Node()) + if err != nil { + tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err)) + } + conn.SetNodeCallback(func(node *tailnet.Node) { + err := tn.getAgentConn().UpdateSelf(node) + if err != nil { + tn.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err)) + } + }) + + // This is set to allow local DERP traffic to be proxied through memory + // instead of needing to hit the external access URL. Don't use the ctx + // given in this callback, it's only valid while connecting. + conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn { + if !region.EmbeddedRelay { + return nil + } + left, right := net.Pipe() + go func() { + defer left.Close() + defer right.Close() + brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right)) + derpServer.Accept(ctx, right, brw, "internal") + }() + return left + }) + + go tn.watchAgentUpdates() + go tn.expireOldAgents() + return tn, nil +} + +func (s *ServerTailnet) expireOldAgents() { + const ( + tick = 5 * time.Minute + cutoff = 30 * time.Minute + ) + + ticker := time.NewTicker(tick) + defer ticker.Stop() + + for { + select { + case <-s.ctx.Done(): + return + case <-ticker.C: + } + + s.nodesMu.Lock() + agentConn := s.getAgentConn() + for agentID, lastConnection := range s.agentNodes { + // If no one has connected since the cutoff and there are no active + // connections, remove the agent. 
+ if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 { + _ = agentConn + // err := agentConn.UnsubscribeAgent(agentID) + // if err != nil { + // s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID)) + // } + // delete(s.agentNodes, agentID) + + // TODO(coadler): actually remove from the netmap, then reenable + // the above + } + } + s.nodesMu.Unlock() + } +} + +func (s *ServerTailnet) watchAgentUpdates() { + for { + conn := s.getAgentConn() + nodes, ok := conn.NextUpdate(s.ctx) + if !ok { + if conn.IsClosed() && s.ctx.Err() == nil { + s.reinitCoordinator() + continue + } + return + } + + err := s.conn.UpdateNodes(nodes, false) + if err != nil { + s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err)) + return + } + } +} + +func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn { + return *s.agentConn.Load() +} + +func (s *ServerTailnet) reinitCoordinator() { + s.nodesMu.Lock() + agentConn := (*s.coord.Load()).ServeMultiAgent(uuid.New()) + s.agentConn.Store(&agentConn) + + // Resubscribe to all of the agents we're tracking. + for agentID := range s.agentNodes { + err := agentConn.SubscribeAgent(agentID) + if err != nil { + s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID)) + } + } + s.nodesMu.Unlock() +} + +type ServerTailnet struct { + ctx context.Context + cancel func() + + logger slog.Logger + conn *tailnet.Conn + coord *atomic.Pointer[tailnet.Coordinator] + agentConn atomic.Pointer[tailnet.MultiAgentConn] + cache *wsconncache.Cache + nodesMu sync.Mutex + // agentNodes is a map of agent tailnetNodes the server wants to keep a + // connection to. It contains the last time the agent was connected to. + agentNodes map[uuid.UUID]time.Time + // agentTockets holds a map of all open connections to an agent. + agentTickets map[uuid.UUID]map[uuid.UUID]struct{} + + transport *http.Transport +} + +func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error) { + proxy := httputil.NewSingleHostReverseProxy(targetURL) + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + site.RenderStaticErrorPage(w, r, site.ErrorPageData{ + Status: http.StatusBadGateway, + Title: "Bad Gateway", + Description: "Failed to proxy request to application: " + err.Error(), + RetryEnabled: true, + DashboardURL: dashboardURL.String(), + }) + } + proxy.Director = s.director(agentID, proxy.Director) + proxy.Transport = s.transport + + return proxy, func() {}, nil +} + +type agentIDKey struct{} + +// director makes sure agentIDKey is set on the context in the reverse proxy. +// This allows the transport to correctly identify which agent to dial to. +func (*ServerTailnet) director(agentID uuid.UUID, prev func(req *http.Request)) func(req *http.Request) { + return func(req *http.Request) { + ctx := context.WithValue(req.Context(), agentIDKey{}, agentID) + *req = *req.WithContext(ctx) + prev(req) + } +} + +func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (net.Conn, error) { + agentID, ok := ctx.Value(agentIDKey{}).(uuid.UUID) + if !ok { + return nil, xerrors.Errorf("no agent id attached") + } + + return s.DialAgentNetConn(ctx, agentID, network, addr) +} + +func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error { + s.nodesMu.Lock() + defer s.nodesMu.Unlock() + + _, ok := s.agentNodes[agentID] + // If we don't have the node, subscribe. 
+ if !ok { + s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID)) + err := s.getAgentConn().SubscribeAgent(agentID) + if err != nil { + return xerrors.Errorf("subscribe agent: %w", err) + } + s.agentTickets[agentID] = map[uuid.UUID]struct{}{} + } + + s.agentNodes[agentID] = time.Now() + return nil +} + +func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) { + var ( + conn *codersdk.WorkspaceAgentConn + ret = func() {} + ) + + if s.getAgentConn().AgentIsLegacy(agentID) { + s.logger.Debug(s.ctx, "acquiring legacy agent", slog.F("agent_id", agentID)) + cconn, release, err := s.cache.Acquire(agentID) + if err != nil { + return nil, nil, xerrors.Errorf("acquire legacy agent conn: %w", err) + } + + conn = cconn.WorkspaceAgentConn + ret = release + } else { + err := s.ensureAgent(agentID) + if err != nil { + return nil, nil, xerrors.Errorf("ensure agent: %w", err) + } + + s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID)) + conn = codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{ + AgentID: agentID, + CloseFunc: func() error { return codersdk.ErrSkipClose }, + }) + } + + // Since we now have an open conn, be careful to close it if we error + // without returning it to the user. + + reachable := conn.AwaitReachable(ctx) + if !reachable { + ret() + conn.Close() + return nil, nil, xerrors.New("agent is unreachable") + } + + return conn, ret, nil +} + +func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, network, addr string) (net.Conn, error) { + conn, release, err := s.AgentConn(ctx, agentID) + if err != nil { + return nil, xerrors.Errorf("acquire agent conn: %w", err) + } + + // Since we now have an open conn, be careful to close it if we error + // without returning it to the user. 
+ + nc, err := conn.DialContext(ctx, network, addr) + if err != nil { + release() + conn.Close() + return nil, xerrors.Errorf("dial context: %w", err) + } + + return &netConnCloser{Conn: nc, close: func() { + release() + conn.Close() + }}, err +} + +type netConnCloser struct { + net.Conn + close func() +} + +func (c *netConnCloser) Close() error { + c.close() + return c.Conn.Close() +} + +func (s *ServerTailnet) Close() error { + s.cancel() + _ = s.cache.Close() + _ = s.conn.Close() + s.transport.CloseIdleConnections() + return nil +} diff --git a/coderd/tailnet_test.go b/coderd/tailnet_test.go new file mode 100644 index 0000000000000..16d597607312c --- /dev/null +++ b/coderd/tailnet_test.go @@ -0,0 +1,207 @@ +package coderd_test + +import ( + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "net/netip" + "net/url" + "sync/atomic" + "testing" + + "github.com/google/uuid" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/agent" + "github.com/coder/coder/agent/agenttest" + "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/wsconncache" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/codersdk/agentsdk" + "github.com/coder/coder/tailnet" + "github.com/coder/coder/tailnet/tailnettest" + "github.com/coder/coder/testutil" +) + +func TestServerTailnet_AgentConn_OK(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + // Connect through the ServerTailnet + agentID, _, serverTailnet := setupAgent(t, nil) + + conn, release, err := serverTailnet.AgentConn(ctx, agentID) + require.NoError(t, err) + defer release() + + assert.True(t, conn.AwaitReachable(ctx)) +} + +func TestServerTailnet_AgentConn_Legacy(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + // Force a connection through wsconncache using the legacy hardcoded ip. + agentID, _, serverTailnet := setupAgent(t, []netip.Prefix{ + netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128), + }) + + conn, release, err := serverTailnet.AgentConn(ctx, agentID) + require.NoError(t, err) + defer release() + + assert.True(t, conn.AwaitReachable(ctx)) +} + +func TestServerTailnet_ReverseProxy_OK(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + // Force a connection through wsconncache using the legacy hardcoded ip. + agentID, _, serverTailnet := setupAgent(t, nil) + + u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort)) + require.NoError(t, err) + + rp, release, err := serverTailnet.ReverseProxy(u, u, agentID) + require.NoError(t, err) + defer release() + + rw := httptest.NewRecorder() + req := httptest.NewRequest( + http.MethodGet, + u.String(), + nil, + ).WithContext(ctx) + + rp.ServeHTTP(rw, req) + res := rw.Result() + defer res.Body.Close() + + assert.Equal(t, http.StatusOK, res.StatusCode) +} + +func TestServerTailnet_ReverseProxy_Legacy(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + // Force a connection through wsconncache using the legacy hardcoded ip. 
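netConnCloser above ties the lifetime of the cache reference to the returned connection: closing the conn also runs the release hook. A small sketch of that wrapper pattern, with releaseConn as an illustrative name:

package main

import (
	"fmt"
	"net"
)

// releaseConn wraps a net.Conn so that closing it also runs a release hook,
// giving the caller a single handle whose Close both frees the underlying
// reference and closes the connection.
type releaseConn struct {
	net.Conn
	release func()
}

func (c *releaseConn) Close() error {
	c.release()
	return c.Conn.Close()
}

func main() {
	client, server := net.Pipe()
	defer server.Close()

	conn := &releaseConn{Conn: client, release: func() {
		fmt.Println("released agent conn reference")
	}}
	_ = conn.Close() // prints the release message, then closes the pipe
}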
+ agentID, _, serverTailnet := setupAgent(t, []netip.Prefix{ + netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128), + }) + + u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort)) + require.NoError(t, err) + + rp, release, err := serverTailnet.ReverseProxy(u, u, agentID) + require.NoError(t, err) + defer release() + + rw := httptest.NewRecorder() + req := httptest.NewRequest( + http.MethodGet, + u.String(), + nil, + ).WithContext(ctx) + + rp.ServeHTTP(rw, req) + res := rw.Result() + defer res.Body.Close() + + assert.Equal(t, http.StatusOK, res.StatusCode) +} + +func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.Agent, *coderd.ServerTailnet) { + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + derpMap, derpServer := tailnettest.RunDERPAndSTUN(t) + manifest := agentsdk.Manifest{ + AgentID: uuid.New(), + DERPMap: derpMap, + } + + var coordPtr atomic.Pointer[tailnet.Coordinator] + coord := tailnet.NewCoordinator(logger) + coordPtr.Store(&coord) + t.Cleanup(func() { + _ = coord.Close() + }) + + c := agenttest.NewClient(t, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord) + + options := agent.Options{ + Client: c, + Filesystem: afero.NewMemMapFs(), + Logger: logger.Named("agent"), + Addresses: agentAddresses, + } + + ag := agent.New(options) + t.Cleanup(func() { + _ = ag.Close() + }) + + // Wait for the agent to connect. + require.Eventually(t, func() bool { + return coord.Node(manifest.AgentID) != nil + }, testutil.WaitShort, testutil.IntervalFast) + + cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { + conn, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, + DERPMap: manifest.DERPMap, + Logger: logger.Named("client"), + }) + require.NoError(t, err) + clientConn, serverConn := net.Pipe() + serveClientDone := make(chan struct{}) + t.Cleanup(func() { + _ = clientConn.Close() + _ = serverConn.Close() + _ = conn.Close() + <-serveClientDone + }) + go func() { + defer close(serveClientDone) + coord.ServeClient(serverConn, uuid.New(), manifest.AgentID) + }() + sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { + return conn.UpdateNodes(node, false) + }) + conn.SetNodeCallback(sendNode) + return codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ + AgentID: manifest.AgentID, + AgentIP: codersdk.WorkspaceAgentIP, + CloseFunc: func() error { return codersdk.ErrSkipClose }, + }), nil + }, 0) + + serverTailnet, err := coderd.NewServerTailnet( + context.Background(), + logger, + derpServer, + manifest.DERPMap, + &coordPtr, + cache, + ) + require.NoError(t, err) + + t.Cleanup(func() { + _ = serverTailnet.Close() + }) + + return manifest.AgentID, ag, serverTailnet +} diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index bfe61b4a180df..c1f2e90c02de9 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -161,6 +161,7 @@ func (api *API) workspaceAgentManifest(rw http.ResponseWriter, r *http.Request) } httpapi.Write(ctx, rw, http.StatusOK, agentsdk.Manifest{ + AgentID: apiAgent.ID, Apps: convertApps(dbApps), DERPMap: api.DERPMap, GitAuthConfigs: len(api.GitAuthConfigs), @@ -654,7 +655,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req return } - agentConn, release, err := api.workspaceAgentCache.Acquire(workspaceAgent.ID) + agentConn, release, err := api.agentProvider.AgentConn(ctx, workspaceAgent.ID) 
if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error dialing workspace agent.", @@ -729,7 +730,9 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req httpapi.Write(ctx, rw, http.StatusOK, portsResponse) } -func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { +// Deprecated: use api.tailnet.AgentConn instead. +// See: https://github.com/coder/coder/issues/8218 +func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { clientConn, serverConn := net.Pipe() conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, @@ -765,14 +768,16 @@ func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspac return nil }) conn.SetNodeCallback(sendNodes) - agentConn := &codersdk.WorkspaceAgentConn{ - Conn: conn, - CloseFunc: func() { + agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ + AgentID: agentID, + AgentIP: codersdk.WorkspaceAgentIP, + CloseFunc: func() error { cancel() _ = clientConn.Close() _ = serverConn.Close() + return nil }, - } + }) go func() { err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID) if err != nil { diff --git a/coderd/workspaceapps/apptest/setup.go b/coderd/workspaceapps/apptest/setup.go index 9432e09c9703d..0f0167f37bf79 100644 --- a/coderd/workspaceapps/apptest/setup.go +++ b/coderd/workspaceapps/apptest/setup.go @@ -399,7 +399,8 @@ func doWithRetries(t require.TestingT, client *codersdk.Client, req *http.Reques return resp, err } -func requestWithRetries(ctx context.Context, t require.TestingT, client *codersdk.Client, method, urlOrPath string, body interface{}, opts ...codersdk.RequestOption) (*http.Response, error) { +func requestWithRetries(ctx context.Context, t testing.TB, client *codersdk.Client, method, urlOrPath string, body interface{}, opts ...codersdk.RequestOption) (*http.Response, error) { + t.Helper() var resp *http.Response var err error require.Eventually(t, func() bool { diff --git a/coderd/workspaceapps/proxy.go b/coderd/workspaceapps/proxy.go index 1d3e8592d7a1c..9b2d9c4bfa297 100644 --- a/coderd/workspaceapps/proxy.go +++ b/coderd/workspaceapps/proxy.go @@ -23,7 +23,6 @@ import ( "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/tracing" "github.com/coder/coder/coderd/util/slice" - "github.com/coder/coder/coderd/wsconncache" "github.com/coder/coder/codersdk" "github.com/coder/coder/site" ) @@ -61,6 +60,22 @@ var nonCanonicalHeaders = map[string]string{ "Sec-Websocket-Version": "Sec-WebSocket-Version", } +type AgentProvider interface { + // ReverseProxy returns an httputil.ReverseProxy for proxying HTTP requests + // to the specified agent. + // + // TODO: after wsconncache is deleted this doesn't need to return an error. + ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error) + + // AgentConn returns a new connection to the specified agent. + // + // TODO: after wsconncache is deleted this doesn't need to return a release + // func. 
+ AgentConn(ctx context.Context, agentID uuid.UUID) (_ *codersdk.WorkspaceAgentConn, release func(), _ error) + + Close() error +} + // Server serves workspace apps endpoints, including: // - Path-based apps // - Subdomain app middleware @@ -83,7 +98,6 @@ type Server struct { RealIPConfig *httpmw.RealIPConfig SignedTokenProvider SignedTokenProvider - WorkspaceConnCache *wsconncache.Cache AppSecurityKey SecurityKey // DisablePathApps disables path-based apps. This is a security feature as path @@ -95,6 +109,8 @@ type Server struct { DisablePathApps bool SecureAuthCookie bool + AgentProvider AgentProvider + websocketWaitMutex sync.Mutex websocketWaitGroup sync.WaitGroup } @@ -106,8 +122,8 @@ func (s *Server) Close() error { s.websocketWaitGroup.Wait() s.websocketWaitMutex.Unlock() - // The caller must close the SignedTokenProvider (if necessary) and the - // wsconncache. + // The caller must close the SignedTokenProvider and the AgentProvider (if + // necessary). return nil } @@ -517,18 +533,7 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT r.URL.Path = path appURL.RawQuery = "" - proxy := httputil.NewSingleHostReverseProxy(appURL) - proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { - site.RenderStaticErrorPage(rw, r, site.ErrorPageData{ - Status: http.StatusBadGateway, - Title: "Bad Gateway", - Description: "Failed to proxy request to application: " + err.Error(), - RetryEnabled: true, - DashboardURL: s.DashboardURL.String(), - }) - } - - conn, release, err := s.WorkspaceConnCache.Acquire(appToken.AgentID) + proxy, release, err := s.AgentProvider.ReverseProxy(appURL, s.DashboardURL, appToken.AgentID) if err != nil { site.RenderStaticErrorPage(rw, r, site.ErrorPageData{ Status: http.StatusBadGateway, @@ -540,7 +545,6 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT return } defer release() - proxy.Transport = conn.HTTPTransport() proxy.ModifyResponse = func(r *http.Response) error { r.Header.Del(httpmw.AccessControlAllowOriginHeader) @@ -658,13 +662,14 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { go httpapi.Heartbeat(ctx, conn) - agentConn, release, err := s.WorkspaceConnCache.Acquire(appToken.AgentID) + agentConn, release, err := s.AgentProvider.AgentConn(ctx, appToken.AgentID) if err != nil { log.Debug(ctx, "dial workspace agent", slog.Error(err)) _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err)) return } defer release() + defer agentConn.Close() log.Debug(ctx, "dialed workspace agent") ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command")) if err != nil { diff --git a/coderd/wsconncache/wsconncache.go b/coderd/wsconncache/wsconncache.go index 19c7f65f9fb74..13d1588384954 100644 --- a/coderd/wsconncache/wsconncache.go +++ b/coderd/wsconncache/wsconncache.go @@ -1,9 +1,12 @@ // Package wsconncache caches workspace agent connections by UUID. +// Deprecated: Use ServerTailnet instead. package wsconncache import ( "context" "net/http" + "net/http/httputil" + "net/url" "sync" "time" @@ -13,13 +16,57 @@ import ( "golang.org/x/xerrors" "github.com/coder/coder/codersdk" + "github.com/coder/coder/site" ) -// New creates a new workspace connection cache that closes -// connections after the inactive timeout provided. 
+type AgentProvider struct { + Cache *Cache +} + +func (a *AgentProvider) AgentConn(_ context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) { + conn, rel, err := a.Cache.Acquire(agentID) + if err != nil { + return nil, nil, xerrors.Errorf("acquire agent connection: %w", err) + } + + return conn.WorkspaceAgentConn, rel, nil +} + +func (a *AgentProvider) ReverseProxy(targetURL *url.URL, dashboardURL *url.URL, agentID uuid.UUID) (*httputil.ReverseProxy, func(), error) { + proxy := httputil.NewSingleHostReverseProxy(targetURL) + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + site.RenderStaticErrorPage(w, r, site.ErrorPageData{ + Status: http.StatusBadGateway, + Title: "Bad Gateway", + Description: "Failed to proxy request to application: " + err.Error(), + RetryEnabled: true, + DashboardURL: dashboardURL.String(), + }) + } + + conn, release, err := a.Cache.Acquire(agentID) + if err != nil { + return nil, nil, xerrors.Errorf("acquire agent connection: %w", err) + } + + proxy.Transport = conn.HTTPTransport() + + return proxy, release, nil +} + +func (a *AgentProvider) Close() error { + return a.Cache.Close() +} + +// New creates a new workspace connection cache that closes connections after +// the inactive timeout provided. +// +// Agent connections are cached due to Wireguard negotiation taking a few +// hundred milliseconds, depending on latency. // -// Agent connections are cached due to WebRTC negotiation -// taking a few hundred milliseconds. +// Deprecated: Use coderd.NewServerTailnet instead. wsconncache is being phased +// out because it creates a unique Tailnet for each agent. +// See: https://github.com/coder/coder/issues/8218 func New(dialer Dialer, inactiveTimeout time.Duration) *Cache { if inactiveTimeout == 0 { inactiveTimeout = 5 * time.Minute diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index 34b92267080e5..276e528313751 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ -157,22 +157,23 @@ func TestCache(t *testing.T) { func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn { t.Helper() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - manifest.DERPMap = tailnettest.RunDERPAndSTUN(t) + manifest.DERPMap, _ = tailnettest.RunDERPAndSTUN(t) coordinator := tailnet.NewCoordinator(logger) t.Cleanup(func() { _ = coordinator.Close() }) - agentID := uuid.New() + manifest.AgentID = uuid.New() closer := agent.New(agent.Options{ Client: &client{ t: t, - agentID: agentID, + agentID: manifest.AgentID, manifest: manifest, coordinator: coordinator, }, Logger: logger.Named("agent"), ReconnectingPTYTimeout: ptyTimeout, + Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, }) t.Cleanup(func() { _ = closer.Close() @@ -189,14 +190,15 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati _ = serverConn.Close() _ = conn.Close() }) - go coordinator.ServeClient(serverConn, uuid.New(), agentID) + go coordinator.ServeClient(serverConn, uuid.New(), manifest.AgentID) sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { return conn.UpdateNodes(node, false) }) conn.SetNodeCallback(sendNode) - agentConn := &codersdk.WorkspaceAgentConn{ - Conn: conn, - } + agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ + AgentID: manifest.AgentID, + AgentIP: 
codersdk.WorkspaceAgentIP, + }) t.Cleanup(func() { _ = agentConn.Close() }) diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index 1e281ef494099..bf150cd84940f 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -84,6 +84,7 @@ func (c *Client) PostMetadata(ctx context.Context, key string, req PostMetadataR } type Manifest struct { + AgentID uuid.UUID `json:"agent_id"` // GitAuthConfigs stores the number of Git configurations // the Coder deployment has. If this number is >0, we // set up special configuration in the workspace. diff --git a/codersdk/deployment.go b/codersdk/deployment.go index 3921963e86f4b..79266441b6dc6 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -1764,6 +1764,12 @@ const ( // oidc. ExperimentConvertToOIDC Experiment = "convert-to-oidc" + // ExperimentSingleTailnet replaces workspace connections inside coderd to + // all use a single tailnet, instead of the previous behavior of creating a + // single tailnet for each agent. + // WARNING: This cannot be enabled when using HA. + ExperimentSingleTailnet Experiment = "single_tailnet" + ExperimentWorkspaceBuildLogsUI Experiment = "workspace_build_logs_ui" // Add new experiments here! // ExperimentExample Experiment = "example" diff --git a/codersdk/workspaceagentconn.go b/codersdk/workspaceagentconn.go index 64bd4fe2f8bfa..6b9b6f0d33f44 100644 --- a/codersdk/workspaceagentconn.go +++ b/codersdk/workspaceagentconn.go @@ -15,6 +15,7 @@ import ( "time" "github.com/google/uuid" + "github.com/hashicorp/go-multierror" "golang.org/x/crypto/ssh" "golang.org/x/xerrors" "tailscale.com/ipn/ipnstate" @@ -27,8 +28,14 @@ import ( // WorkspaceAgentIP is a static IPv6 address with the Tailscale prefix that is used to route // connections from clients to this node. A dynamic address is not required because a Tailnet // client only dials a single agent at a time. +// +// Deprecated: use tailnet.IP() instead. This is kept for backwards +// compatibility with wsconncache. +// See: https://github.com/coder/coder/issues/8218 var WorkspaceAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4") +var ErrSkipClose = xerrors.New("skip tailnet close") + const ( WorkspaceAgentSSHPort = tailnet.WorkspaceAgentSSHPort WorkspaceAgentReconnectingPTYPort = tailnet.WorkspaceAgentReconnectingPTYPort @@ -120,11 +127,38 @@ func init() { } } +// NewWorkspaceAgentConn creates a new WorkspaceAgentConn. `conn` may be unique +// to the WorkspaceAgentConn, or it may be shared in the case of coderd. If the +// conn is shared and closing it is undesirable, you may return ErrSkipClose from +// opts.CloseFunc. This will ensure the underlying conn is not closed. +func NewWorkspaceAgentConn(conn *tailnet.Conn, opts WorkspaceAgentConnOptions) *WorkspaceAgentConn { + return &WorkspaceAgentConn{ + Conn: conn, + opts: opts, + } +} + // WorkspaceAgentConn represents a connection to a workspace agent. // @typescript-ignore WorkspaceAgentConn type WorkspaceAgentConn struct { *tailnet.Conn - CloseFunc func() + opts WorkspaceAgentConnOptions +} + +// @typescript-ignore WorkspaceAgentConnOptions +type WorkspaceAgentConnOptions struct { + AgentID uuid.UUID + AgentIP netip.Addr + CloseFunc func() error +} + +func (c *WorkspaceAgentConn) agentAddress() netip.Addr { + var emptyIP netip.Addr + if cmp := c.opts.AgentIP.Compare(emptyIP); cmp != 0 { + return c.opts.AgentIP + } + + return tailnet.IPFromUUID(c.opts.AgentID) } // AwaitReachable waits for the agent to be reachable.
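// A minimal usage sketch, not part of the diff above: how a caller might wrap a
// shared *tailnet.Conn with the new constructor so that Close() leaves the
// shared conn running. Only NewWorkspaceAgentConn, WorkspaceAgentConnOptions,
// and ErrSkipClose come from this change; the package, function, and variable
// names below are hypothetical.
package example

import (
	"github.com/google/uuid"

	"github.com/coder/coder/codersdk"
	"github.com/coder/coder/tailnet"
)

// wrapSharedConn builds a WorkspaceAgentConn on top of a conn that is owned
// elsewhere (e.g. a shared server tailnet). CloseFunc returns ErrSkipClose, so
// Close() runs the callback but does not close the underlying conn.
func wrapSharedConn(shared *tailnet.Conn, agentID uuid.UUID) *codersdk.WorkspaceAgentConn {
	return codersdk.NewWorkspaceAgentConn(shared, codersdk.WorkspaceAgentConnOptions{
		// AgentIP is left unset, so dials resolve the address from the agent ID
		// via tailnet.IPFromUUID.
		AgentID:   agentID,
		CloseFunc: func() error { return codersdk.ErrSkipClose },
	})
}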
@@ -132,7 +166,7 @@ func (c *WorkspaceAgentConn) AwaitReachable(ctx context.Context) bool { ctx, span := tracing.StartSpan(ctx) defer span.End() - return c.Conn.AwaitReachable(ctx, WorkspaceAgentIP) + return c.Conn.AwaitReachable(ctx, c.agentAddress()) } // Ping pings the agent and returns the round-trip time. @@ -141,13 +175,20 @@ func (c *WorkspaceAgentConn) Ping(ctx context.Context) (time.Duration, bool, *ip ctx, span := tracing.StartSpan(ctx) defer span.End() - return c.Conn.Ping(ctx, WorkspaceAgentIP) + return c.Conn.Ping(ctx, c.agentAddress()) } // Close ends the connection to the workspace agent. func (c *WorkspaceAgentConn) Close() error { - if c.CloseFunc != nil { - c.CloseFunc() + var cerr error + if c.opts.CloseFunc != nil { + cerr = c.opts.CloseFunc() + if xerrors.Is(cerr, ErrSkipClose) { + return nil + } + } + if cerr != nil { + return multierror.Append(cerr, c.Conn.Close()) } return c.Conn.Close() } @@ -176,10 +217,12 @@ type ReconnectingPTYRequest struct { func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() + if !c.AwaitReachable(ctx) { return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentReconnectingPTYPort)) + + conn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), WorkspaceAgentReconnectingPTYPort)) if err != nil { return nil, err } @@ -209,10 +252,12 @@ func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() + if !c.AwaitReachable(ctx) { return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - return c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSSHPort)) + + return c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), WorkspaceAgentSSHPort)) } // SSHClient calls SSH to create a client that uses a weak cipher @@ -220,10 +265,12 @@ func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) { func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() + netConn, err := c.SSH(ctx) if err != nil { return nil, xerrors.Errorf("ssh: %w", err) } + sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{ // SSH host validation isn't helpful, because obtaining a peer // connection already signifies user-intent to dial a workspace. 
@@ -233,6 +280,7 @@ func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) if err != nil { return nil, xerrors.Errorf("ssh conn: %w", err) } + return ssh.NewClient(sshConn, channels, requests), nil } @@ -240,17 +288,21 @@ func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() + if !c.AwaitReachable(ctx) { return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - speedConn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSpeedtestPort)) + + speedConn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), WorkspaceAgentSpeedtestPort)) if err != nil { return nil, xerrors.Errorf("dial speedtest: %w", err) } + results, err := speedtest.RunClientWithConn(direction, duration, speedConn) if err != nil { return nil, xerrors.Errorf("run speedtest: %w", err) } + return results, err } @@ -259,19 +311,23 @@ func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest. func (c *WorkspaceAgentConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() - if network == "unix" { - return nil, xerrors.New("network must be tcp or udp") - } - _, rawPort, _ := net.SplitHostPort(addr) - port, _ := strconv.ParseUint(rawPort, 10, 16) - ipp := netip.AddrPortFrom(WorkspaceAgentIP, uint16(port)) + if !c.AwaitReachable(ctx) { return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - if network == "udp" { + + _, rawPort, _ := net.SplitHostPort(addr) + port, _ := strconv.ParseUint(rawPort, 10, 16) + ipp := netip.AddrPortFrom(c.agentAddress(), uint16(port)) + + switch network { + case "tcp": + return c.Conn.DialContextTCP(ctx, ipp) + case "udp": return c.Conn.DialContextUDP(ctx, ipp) + default: + return nil, xerrors.Errorf("unknown network %q", network) } - return c.Conn.DialContextTCP(ctx, ipp) } type WorkspaceAgentListeningPortsResponse struct { @@ -309,7 +365,8 @@ func (c *WorkspaceAgentConn) ListeningPorts(ctx context.Context) (WorkspaceAgent func (c *WorkspaceAgentConn) apiRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() - host := net.JoinHostPort(WorkspaceAgentIP.String(), strconv.Itoa(WorkspaceAgentHTTPAPIServerPort)) + + host := net.JoinHostPort(c.agentAddress().String(), strconv.Itoa(WorkspaceAgentHTTPAPIServerPort)) url := fmt.Sprintf("http://%s%s", host, path) req, err := http.NewRequestWithContext(ctx, method, url, body) @@ -332,13 +389,14 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client { if network != "tcp" { return nil, xerrors.Errorf("network must be tcp") } + host, port, err := net.SplitHostPort(addr) if err != nil { return nil, xerrors.Errorf("split host port %q: %w", addr, err) } - // Verify that host is TailnetIP and port is - // TailnetStatisticsPort. - if host != WorkspaceAgentIP.String() || port != strconv.Itoa(WorkspaceAgentHTTPAPIServerPort) { + + // Verify that the port is TailnetStatisticsPort. 
+ if port != strconv.Itoa(WorkspaceAgentHTTPAPIServerPort) { return nil, xerrors.Errorf("request %q does not appear to be for http api", addr) } @@ -346,7 +404,12 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client { return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err()) } - conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort)) + ipAddr, err := netip.ParseAddr(host) + if err != nil { + return nil, xerrors.Errorf("parse host addr: %w", err) + } + + conn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(ipAddr, WorkspaceAgentHTTPAPIServerPort)) if err != nil { return nil, xerrors.Errorf("dial http api: %w", err) } diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 208c4511d261a..b76ebba9344f5 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -307,8 +307,8 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti options.Logger.Debug(ctx, "failed to dial", slog.Error(err)) continue } - sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error { - return conn.UpdateNodes(node, false) + sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(nodes []*tailnet.Node) error { + return conn.UpdateNodes(nodes, false) }) conn.SetNodeCallback(sendNode) options.Logger.Debug(ctx, "serving coordinator") @@ -330,13 +330,15 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti return nil, err } - agentConn = &WorkspaceAgentConn{ - Conn: conn, - CloseFunc: func() { + agentConn = NewWorkspaceAgentConn(conn, WorkspaceAgentConnOptions{ + AgentID: agentID, + CloseFunc: func() error { cancel() <-closed + return conn.Close() }, - } + }) + if !agentConn.AwaitReachable(ctx) { _ = agentConn.Close() return nil, xerrors.Errorf("timed out waiting for agent to become reachable: %w", ctx.Err()) diff --git a/docs/api/agents.md b/docs/api/agents.md index 69ff2fbe72318..7dcce5d52e847 100644 --- a/docs/api/agents.md +++ b/docs/api/agents.md @@ -292,6 +292,7 @@ curl -X GET http://coder-server:8080/api/v2/workspaceagents/me/manifest \ ```json { + "agent_id": "string", "apps": [ { "command": "string", diff --git a/docs/api/schemas.md b/docs/api/schemas.md index 2a0861d413573..a042b6d1f6e04 100644 --- a/docs/api/schemas.md +++ b/docs/api/schemas.md @@ -161,6 +161,7 @@ ```json { + "agent_id": "string", "apps": [ { "command": "string", @@ -260,6 +261,7 @@ | Name | Type | Required | Restrictions | Description | | ---------------------------- | ------------------------------------------------------------------------------------------------- | -------- | ------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `agent_id` | string | false | | | | `apps` | array of [codersdk.WorkspaceApp](#codersdkworkspaceapp) | false | | | | `derpmap` | [tailcfg.DERPMap](#tailcfgderpmap) | false | | | | `directory` | string | false | | | @@ -2543,6 +2545,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in | `workspace_actions` | | `tailnet_ha_coordinator` | | `convert-to-oidc` | +| `single_tailnet` | | `workspace_build_logs_ui` | ## codersdk.Feature diff --git a/enterprise/coderd/appearance_test.go b/enterprise/coderd/appearance_test.go index dc6ce99052b60..6f564eaa3a680 100644 --- 
a/enterprise/coderd/appearance_test.go +++ b/enterprise/coderd/appearance_test.go @@ -6,9 +6,8 @@ import ( "net/http" "testing" - "github.com/stretchr/testify/require" - "github.com/google/uuid" + "github.com/stretchr/testify/require" "github.com/coder/coder/cli/clibase" "github.com/coder/coder/coderd/coderdtest" diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index b0d9cfa64032f..889df136710c5 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -17,6 +17,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/coderd/database/pubsub" + "github.com/coder/coder/codersdk" agpl "github.com/coder/coder/tailnet" ) @@ -37,9 +38,12 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err closeFunc: cancelFunc, close: make(chan struct{}), nodes: map[uuid.UUID]*agpl.Node{}, - agentSockets: map[uuid.UUID]*agpl.TrackedConn{}, - agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*agpl.TrackedConn{}, + agentSockets: map[uuid.UUID]agpl.Queue{}, + agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]agpl.Queue{}, agentNameCache: nameCache, + clients: map[uuid.UUID]agpl.Queue{}, + clientsToAgents: map[uuid.UUID]map[uuid.UUID]agpl.Queue{}, + legacyAgents: map[uuid.UUID]struct{}{}, } if err := coord.runPubsub(ctx); err != nil { @@ -49,6 +53,56 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err return coord, nil } +func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { + m := (&agpl.MultiAgent{ + ID: id, + Logger: c.log, + AgentIsLegacyFunc: c.agentIsLegacy, + OnSubscribe: c.clientSubscribeToAgent, + OnNodeUpdate: c.clientNodeUpdate, + OnRemove: c.clientDisconnected, + }).Init() + c.addClient(id, m) + return m +} + +func (c *haCoordinator) addClient(id uuid.UUID, q agpl.Queue) { + c.mutex.Lock() + c.clients[id] = q + c.clientsToAgents[id] = map[uuid.UUID]agpl.Queue{} + c.mutex.Unlock() +} + +func (c *haCoordinator) clientSubscribeToAgent(enq agpl.Queue, agentID uuid.UUID) (*agpl.Node, error) { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.initOrSetAgentConnectionSocketLocked(agentID, enq) + + node := c.nodes[enq.UniqueID()] + if node != nil { + err := c.sendNodeToAgentLocked(agentID, node) + if err != nil { + return nil, xerrors.Errorf("handle client update: %w", err) + } + } + + agentNode, ok := c.nodes[agentID] + // If we have the node locally, give it back to the multiagent. + if ok { + return agentNode, nil + } + + // If we don't have the node locally, notify other coordinators. + err := c.publishClientHello(agentID) + if err != nil { + return nil, xerrors.Errorf("publish client hello: %w", err) + } + + // nolint:nilnil + return nil, nil +} + type haCoordinator struct { id uuid.UUID log slog.Logger @@ -60,14 +114,26 @@ type haCoordinator struct { // nodes maps agent and connection IDs their respective node. nodes map[uuid.UUID]*agpl.Node // agentSockets maps agent IDs to their open websocket. - agentSockets map[uuid.UUID]*agpl.TrackedConn + agentSockets map[uuid.UUID]agpl.Queue // agentToConnectionSockets maps agent IDs to connection IDs of conns that // are subscribed to updates for that agent. - agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]*agpl.TrackedConn + agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]agpl.Queue + + // clients holds a map of all clients connected to the coordinator. This is + // necessary because a client may not be subscribed into any agents. 
+ clients map[uuid.UUID]agpl.Queue + // clientsToAgents is an index of clients to all of their subscribed agents. + clientsToAgents map[uuid.UUID]map[uuid.UUID]agpl.Queue // agentNameCache holds a cache of agent names. If one of them disappears, // it's helpful to have a name cached for debugging. agentNameCache *lru.Cache[uuid.UUID, string] + + // legacyAgents holds a mapping of all agents detected as legacy, meaning + they only listen on codersdk.WorkspaceAgentIP. They aren't compatible + with the new ServerTailnet, so they must be connected through + wsconncache. + legacyAgents map[uuid.UUID]struct{} } // Node returns an in-memory node by ID. @@ -88,47 +154,62 @@ func (c *haCoordinator) agentLogger(agent uuid.UUID) slog.Logger { // ServeClient accepts a WebSocket connection that wants to connect to an agent // with the specified ID. -func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { +func (c *haCoordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - logger := c.clientLogger(id, agent) - - c.mutex.Lock() - connectionSockets, ok := c.agentToConnectionSockets[agent] - if !ok { - connectionSockets = map[uuid.UUID]*agpl.TrackedConn{} - c.agentToConnectionSockets[agent] = connectionSockets - } + logger := c.clientLogger(id, agentID) tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, 0) - // Insert this connection into a map so the agent - // can publish node updates. - connectionSockets[id] = tc + defer tc.Close() - // When a new connection is requested, we update it with the latest - // node of the agent. This allows the connection to establish. - node, ok := c.nodes[agent] - if ok { - err := tc.Enqueue([]*agpl.Node{node}) - c.mutex.Unlock() + c.addClient(id, tc) + defer c.clientDisconnected(id) + + agentNode, err := c.clientSubscribeToAgent(tc, agentID) + if err != nil { + return xerrors.Errorf("subscribe agent: %w", err) + } + + if agentNode != nil { + err := tc.Enqueue([]*agpl.Node{agentNode}) if err != nil { - return xerrors.Errorf("enqueue node: %w", err) + logger.Debug(ctx, "enqueue initial node", slog.Error(err)) } - } else { - c.mutex.Unlock() - err := c.publishClientHello(agent) + } + + go tc.SendUpdates() + + decoder := json.NewDecoder(conn) + // Indefinitely handle messages from the client websocket. + for { + err := c.handleNextClientMessage(id, decoder) if err != nil { - return xerrors.Errorf("publish client hello: %w", err) + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) { + return nil + } + return xerrors.Errorf("handle next client message: %w", err) } } - go tc.SendUpdates() +} - defer func() { - c.mutex.Lock() - defer c.mutex.Unlock() +func (c *haCoordinator) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq agpl.Queue) { + connectionSockets, ok := c.agentToConnectionSockets[agentID] + if !ok { + connectionSockets = map[uuid.UUID]agpl.Queue{} + c.agentToConnectionSockets[agentID] = connectionSockets + } + connectionSockets[enq.UniqueID()] = enq + c.clientsToAgents[enq.UniqueID()][agentID] = c.agentSockets[agentID] +} + +func (c *haCoordinator) clientDisconnected(id uuid.UUID) { + c.mutex.Lock() + defer c.mutex.Unlock() + + for agentID := range c.clientsToAgents[id] { // Clean all traces of this connection from the map.
delete(c.nodes, id) - connectionSockets, ok := c.agentToConnectionSockets[agent] + connectionSockets, ok := c.agentToConnectionSockets[agentID] if !ok { return } @@ -136,51 +217,65 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID if len(connectionSockets) != 0 { return } - delete(c.agentToConnectionSockets, agent) - }() - - decoder := json.NewDecoder(conn) - // Indefinitely handle messages from the client websocket. - for { - err := c.handleNextClientMessage(id, agent, decoder) - if err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) { - return nil - } - return xerrors.Errorf("handle next client message: %w", err) - } + delete(c.agentToConnectionSockets, agentID) } + + delete(c.clients, id) + delete(c.clientsToAgents, id) } -func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error { +func (c *haCoordinator) handleNextClientMessage(id uuid.UUID, decoder *json.Decoder) error { var node agpl.Node err := decoder.Decode(&node) if err != nil { return xerrors.Errorf("read json: %w", err) } + return c.clientNodeUpdate(id, &node) +} + +func (c *haCoordinator) clientNodeUpdate(id uuid.UUID, node *agpl.Node) error { c.mutex.Lock() + defer c.mutex.Unlock() // Update the node of this client in our in-memory map. If an agent entirely // shuts down and reconnects, it needs to be aware of all clients attempting // to establish connections. - c.nodes[id] = &node - // Write the new node from this client to the actively connected agent. - agentSocket, ok := c.agentSockets[agent] + c.nodes[id] = node + for agentID, agentSocket := range c.clientsToAgents[id] { + if agentSocket == nil { + // If we don't own the agent locally, send it over pubsub to a node that + // owns the agent. + err := c.publishNodesToAgent(agentID, []*agpl.Node{node}) + if err != nil { + c.log.Error(context.Background(), "publish node to agent", slog.Error(err), slog.F("agent_id", agentID)) + } + } else { + // Write the new node from this client to the actively connected agent. + err := agentSocket.Enqueue([]*agpl.Node{node}) + if err != nil { + c.log.Error(context.Background(), "enqueue node to agent", slog.Error(err), slog.F("agent_id", agentID)) + } + } + } + + return nil +} + +func (c *haCoordinator) sendNodeToAgentLocked(agentID uuid.UUID, node *agpl.Node) error { + agentSocket, ok := c.agentSockets[agentID] if !ok { - c.mutex.Unlock() // If we don't own the agent locally, send it over pubsub to a node that // owns the agent. - err := c.publishNodesToAgent(agent, []*agpl.Node{&node}) + err := c.publishNodesToAgent(agentID, []*agpl.Node{node}) if err != nil { return xerrors.Errorf("publish node to agent") } return nil } - err = agentSocket.Enqueue([]*agpl.Node{&node}) - c.mutex.Unlock() + err := agentSocket.Enqueue([]*agpl.Node{node}) if err != nil { - return xerrors.Errorf("enqueu nodes: %w", err) + return xerrors.Errorf("enqueue node: %w", err) } return nil } @@ -202,7 +297,7 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err // dead. oldAgentSocket, ok := c.agentSockets[id] if ok { - overwrites = oldAgentSocket.Overwrites + 1 + overwrites = oldAgentSocket.Overwrites() + 1 _ = oldAgentSocket.Close() } // This uniquely identifies a connection that belongs to this goroutine. 
@@ -219,6 +314,9 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err } } c.agentSockets[id] = tc + for clientID := range c.agentToConnectionSockets[id] { + c.clientsToAgents[clientID][id] = tc + } c.mutex.Unlock() go tc.SendUpdates() @@ -234,10 +332,13 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err // Only delete the connection if it's ours. It could have been // overwritten. - if idConn, ok := c.agentSockets[id]; ok && idConn.ID == unique { + if idConn, ok := c.agentSockets[id]; ok && idConn.UniqueID() == unique { delete(c.agentSockets, id) delete(c.nodes, id) } + for clientID := range c.agentToConnectionSockets[id] { + c.clientsToAgents[clientID][id] = nil + } }() decoder := json.NewDecoder(conn) @@ -285,6 +386,13 @@ func (c *haCoordinator) handleClientHello(id uuid.UUID) error { return c.publishAgentToNodes(id, node) } +func (c *haCoordinator) agentIsLegacy(agentID uuid.UUID) bool { + c.mutex.RLock() + _, ok := c.legacyAgents[agentID] + c.mutex.RUnlock() + return ok +} + func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (*agpl.Node, error) { var node agpl.Node err := decoder.Decode(&node) @@ -293,6 +401,11 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) ( } c.mutex.Lock() + // Keep a cache of all legacy agents. + if len(node.Addresses) > 0 && node.Addresses[0].Addr() == codersdk.WorkspaceAgentIP { + c.legacyAgents[id] = struct{}{} + } + oldNode := c.nodes[id] if oldNode != nil { if oldNode.AsOf.After(node.AsOf) { @@ -311,7 +424,9 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) ( for _, connectionSocket := range connectionSockets { _ = connectionSocket.Enqueue([]*agpl.Node{&node}) } + c.mutex.Unlock() + return &node, nil } @@ -334,20 +449,18 @@ func (c *haCoordinator) Close() error { for _, socket := range c.agentSockets { socket := socket go func() { - _ = socket.Close() + _ = socket.CoordinatorClose() wg.Done() }() } - for _, connMap := range c.agentToConnectionSockets { - wg.Add(len(connMap)) - for _, socket := range connMap { - socket := socket - go func() { - _ = socket.Close() - wg.Done() - }() - } + wg.Add(len(c.clients)) + for _, client := range c.clients { + client := client + go func() { + _ = client.CoordinatorClose() + wg.Done() + }() } wg.Wait() @@ -422,13 +535,12 @@ func (c *haCoordinator) runPubsub(ctx context.Context) error { } go func() { for { - var message []byte select { case <-ctx.Done(): return - case message = <-messageQueue: + case message := <-messageQueue: + c.handlePubsubMessage(ctx, message) } - c.handlePubsubMessage(ctx, message) } }() diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 6c3cc73ac9e9d..37c516f5aa65e 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -125,6 +125,11 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store return c, nil } +func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { + _, _ = c, id + panic("not implemented") // TODO: Implement +} + func (*pgCoord) ServeHTTPDebug(w http.ResponseWriter, _ *http.Request) { // TODO(spikecurtis) I'd like to hold off implementing this until after the rest of this is code reviewed. 
w.WriteHeader(http.StatusOK) diff --git a/enterprise/wsproxy/wsproxy.go b/enterprise/wsproxy/wsproxy.go index fce5e0cc7a3b1..ae5da832054e2 100644 --- a/enterprise/wsproxy/wsproxy.go +++ b/enterprise/wsproxy/wsproxy.go @@ -183,9 +183,12 @@ func New(ctx context.Context, opts *Options) (*Server, error) { SecurityKey: secKey, Logger: s.Logger.Named("proxy_token_provider"), }, - WorkspaceConnCache: wsconncache.New(s.DialWorkspaceAgent, 0), - AppSecurityKey: secKey, + AppSecurityKey: secKey, + // TODO: Convert wsproxy to use coderd.ServerTailnet. + AgentProvider: &wsconncache.AgentProvider{ + Cache: wsconncache.New(s.DialWorkspaceAgent, 0), + }, DisablePathApps: opts.DisablePathApps, SecureAuthCookie: opts.SecureAuthCookie, } @@ -273,6 +276,7 @@ func (s *Server) Close() error { tmp, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() _ = s.SDKClient.WorkspaceProxyGoingAway(tmp) + _ = s.AppServer.AgentProvider.Close() return s.AppServer.Close() } diff --git a/scaletest/agentconn/run.go b/scaletest/agentconn/run.go index cefbb660485e5..6593e7c79b52c 100644 --- a/scaletest/agentconn/run.go +++ b/scaletest/agentconn/run.go @@ -6,7 +6,6 @@ import ( "io" "net" "net/http" - "net/netip" "net/url" "strconv" "time" @@ -377,7 +376,10 @@ func agentHTTPClient(conn *codersdk.WorkspaceAgentConn) *http.Client { if err != nil { return nil, xerrors.Errorf("parse port %q: %w", port, err) } - return conn.DialContextTCP(ctx, netip.AddrPortFrom(codersdk.WorkspaceAgentIP, uint16(portUint))) + + // Addr doesn't matter here, besides the port. DialContext will + // automatically choose the right IP to dial. + return conn.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", portUint)) }, }, } diff --git a/scaletest/reconnectingpty/run_test.go b/scaletest/reconnectingpty/run_test.go index f6f70bbf574bf..382a3718436f9 100644 --- a/scaletest/reconnectingpty/run_test.go +++ b/scaletest/reconnectingpty/run_test.go @@ -9,6 +9,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" + "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" "github.com/coder/coder/coderd/coderdtest" @@ -243,7 +244,7 @@ func Test_Runner(t *testing.T) { func setupRunnerTest(t *testing.T) (client *codersdk.Client, agentID uuid.UUID) { t.Helper() - client = coderdtest.New(t, &coderdtest.Options{ + client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ IncludeProvisionerDaemon: true, }) user := coderdtest.CreateFirstUser(t, client) @@ -282,12 +283,16 @@ func setupRunnerTest(t *testing.T) (client *codersdk.Client, agentID uuid.UUID) agentClient.SetSessionToken(authToken) agentCloser := agent.New(agent.Options{ Client: agentClient, - Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("agent"), + Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("agent").Leveled(slog.LevelDebug), }) t.Cleanup(func() { _ = agentCloser.Close() }) resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + require.Eventually(t, func() bool { + t.Log("agent id", resources[0].Agents[0].ID) + return (*api.TailnetCoordinator.Load()).Node(resources[0].Agents[0].ID) != nil + }, testutil.WaitLong, testutil.IntervalMedium, "agent never connected") return client, resources[0].Agents[0].ID } diff --git a/scaletest/workspacetraffic/run_test.go b/scaletest/workspacetraffic/run_test.go index e53d408bcd428..c070a906be228 100644 --- a/scaletest/workspacetraffic/run_test.go +++ b/scaletest/workspacetraffic/run_test.go @@ -68,6 +68,7 @@ func TestRun(t 
*testing.T) { agentCloser := agent.New(agent.Options{ Client: agentClient, }) + ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) t.Cleanup(func() { diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index ee1bcda3736e9..5ba2158c93b36 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1431,12 +1431,14 @@ export const Entitlements: Entitlement[] = [ export type Experiment = | "convert-to-oidc" | "moons" + | "single_tailnet" | "tailnet_ha_coordinator" | "workspace_actions" | "workspace_build_logs_ui" export const Experiments: Experiment[] = [ "convert-to-oidc", "moons", + "single_tailnet", "tailnet_ha_coordinator", "workspace_actions", "workspace_build_logs_ui", diff --git a/tailnet/conn.go b/tailnet/conn.go index 363ccb80ff48c..0534cf961771a 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -139,6 +139,7 @@ func NewConn(options *Options) (conn *Conn, err error) { } }() + IP() dialer := &tsdial.Dialer{ Logf: Logger(options.Logger.Named("tsdial")), } @@ -182,10 +183,17 @@ func NewConn(options *Options) (conn *Conn, err error) { netMap.SelfNode.DiscoKey = magicConn.DiscoPublicKey() netStack, err := netstack.Create( - Logger(options.Logger.Named("netstack")), tunDevice, wireguardEngine, magicConn, dialer, dnsManager) + Logger(options.Logger.Named("netstack")), + tunDevice, + wireguardEngine, + magicConn, + dialer, + dnsManager, + ) if err != nil { return nil, xerrors.Errorf("create netstack: %w", err) } + dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { return netStack.DialContextTCP(ctx, dst) } @@ -203,7 +211,14 @@ func NewConn(options *Options) (conn *Conn, err error) { localIPs, _ := localIPSet.IPSet() logIPSet := netipx.IPSetBuilder{} logIPs, _ := logIPSet.IPSet() - wireguardEngine.SetFilter(filter.New(netMap.PacketFilter, localIPs, logIPs, nil, Logger(options.Logger.Named("packet-filter")))) + wireguardEngine.SetFilter(filter.New( + netMap.PacketFilter, + localIPs, + logIPs, + nil, + Logger(options.Logger.Named("packet-filter")), + )) + dialContext, dialCancel := context.WithCancel(context.Background()) server := &Conn{ blockEndpoints: options.BlockEndpoints, @@ -230,6 +245,7 @@ func NewConn(options *Options) (conn *Conn, err error) { _ = server.Close() } }() + wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) { server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.Error(err)) if err != nil { @@ -251,6 +267,7 @@ func NewConn(options *Options) (conn *Conn, err error) { server.lastMutex.Unlock() server.sendNode() }) + wireguardEngine.SetNetInfoCallback(func(ni *tailcfg.NetInfo) { server.logger.Debug(context.Background(), "netinfo callback", slog.F("netinfo", ni)) server.lastMutex.Lock() @@ -262,6 +279,7 @@ func NewConn(options *Options) (conn *Conn, err error) { server.lastMutex.Unlock() server.sendNode() }) + magicConn.SetDERPForcedWebsocketCallback(func(region int, reason string) { server.logger.Debug(context.Background(), "derp forced websocket", slog.F("region", region), slog.F("reason", reason)) server.lastMutex.Lock() @@ -273,6 +291,7 @@ func NewConn(options *Options) (conn *Conn, err error) { server.lastMutex.Unlock() server.sendNode() }) + netStack.ForwardTCPIn = server.forwardTCP netStack.ForwardTCPSockOpts = server.forwardTCPSockOpts @@ -284,22 +303,30 @@ func NewConn(options *Options) (conn *Conn, err error) { return server, nil } -// IP generates a new IP with a static service prefix. 
-func IP() netip.Addr { - // This is Tailscale's ephemeral service prefix. - // This can be changed easily later-on, because - // all of our nodes are ephemeral. +func maskUUID(uid uuid.UUID) uuid.UUID { + // This is Tailscale's ephemeral service prefix. This can be changed easily + // later-on, because all of our nodes are ephemeral. // fd7a:115c:a1e0 - uid := uuid.New() uid[0] = 0xfd uid[1] = 0x7a uid[2] = 0x11 uid[3] = 0x5c uid[4] = 0xa1 uid[5] = 0xe0 + return uid +} + +// IP generates a random IP with a static service prefix. +func IP() netip.Addr { + uid := maskUUID(uuid.New()) return netip.AddrFrom16(uid) } +// IPFromUUID generates a new IP from a UUID. +func IPFromUUID(uid uuid.UUID) netip.Addr { + return netip.AddrFrom16(maskUUID(uid)) +} + // Conn is an actively listening Wireguard connection. type Conn struct { dialContext context.Context @@ -334,6 +361,29 @@ type Conn struct { trafficStats *connstats.Statistics } +func (c *Conn) SetAddresses(ips []netip.Prefix) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.netMap.Addresses = ips + + netMapCopy := *c.netMap + c.logger.Debug(context.Background(), "updating network map") + c.wireguardEngine.SetNetworkMap(&netMapCopy) + err := c.reconfig() + if err != nil { + return xerrors.Errorf("reconfig: %w", err) + } + + return nil +} + +func (c *Conn) Addresses() []netip.Prefix { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.netMap.Addresses +} + func (c *Conn) SetNodeCallback(callback func(node *Node)) { c.lastMutex.Lock() c.nodeCallback = callback @@ -366,32 +416,6 @@ func (c *Conn) SetDERPRegionDialer(dialer func(ctx context.Context, region *tail c.magicConn.SetDERPRegionDialer(dialer) } -func (c *Conn) RemoveAllPeers() error { - c.mutex.Lock() - defer c.mutex.Unlock() - - c.netMap.Peers = []*tailcfg.Node{} - c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{} - netMapCopy := *c.netMap - c.logger.Debug(context.Background(), "updating network map") - c.wireguardEngine.SetNetworkMap(&netMapCopy) - cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "") - if err != nil { - return xerrors.Errorf("update wireguard config: %w", err) - } - err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{}) - if err != nil { - if c.isClosed() { - return nil - } - if errors.Is(err, wgengine.ErrNoChanges) { - return nil - } - return xerrors.Errorf("reconfig: %w", err) - } - return nil -} - // UpdateNodes connects with a set of peers. This can be constantly updated, // and peers will continually be reconnected as necessary. If replacePeers is // true, all peers will be removed before adding the new ones. @@ -423,6 +447,7 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error { } delete(c.peerMap, peer.ID) } + for _, node := range nodes { // If no preferred DERP is provided, we can't reach the node.
if node.PreferredDERP == 0 { @@ -452,17 +477,29 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error { } c.peerMap[node.ID] = peerNode } + c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap)) for _, peer := range c.peerMap { c.netMap.Peers = append(c.netMap.Peers, peer.Clone()) } + netMapCopy := *c.netMap c.logger.Debug(context.Background(), "updating network map") c.wireguardEngine.SetNetworkMap(&netMapCopy) + err := c.reconfig() + if err != nil { + return xerrors.Errorf("reconfig: %w", err) + } + + return nil +} + +func (c *Conn) reconfig() error { cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "") if err != nil { return xerrors.Errorf("update wireguard config: %w", err) } + err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{}) if err != nil { if c.isClosed() { @@ -473,6 +510,7 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error { } return xerrors.Errorf("reconfig: %w", err) } + return nil } diff --git a/tailnet/conn_test.go b/tailnet/conn_test.go index 2e19379e6df03..0dd0812b94777 100644 --- a/tailnet/conn_test.go +++ b/tailnet/conn_test.go @@ -23,7 +23,7 @@ func TestMain(m *testing.M) { func TestTailnet(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - derpMap := tailnettest.RunDERPAndSTUN(t) + derpMap, _ := tailnettest.RunDERPAndSTUN(t) t.Run("InstantClose", func(t *testing.T) { t.Parallel() conn, err := tailnet.NewConn(&tailnet.Options{ @@ -172,7 +172,7 @@ func TestConn_PreferredDERP(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - derpMap := tailnettest.RunDERPAndSTUN(t) + derpMap, _ := tailnettest.RunDERPAndSTUN(t) conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, Logger: logger.Named("w1"), diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 7ff279a508b46..93cf8c67af56b 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -1,7 +1,6 @@ package tailnet import ( - "bytes" "context" "encoding/json" "errors" @@ -11,17 +10,16 @@ import ( "net/http" "net/netip" "sync" - "sync/atomic" "time" - "cdr.dev/slog" - "github.com/google/uuid" lru "github.com/hashicorp/golang-lru/v2" "golang.org/x/exp/slices" "golang.org/x/xerrors" "tailscale.com/tailcfg" "tailscale.com/types/key" + + "cdr.dev/slog" ) // Coordinator exchanges nodes with agents to establish connections. @@ -44,6 +42,8 @@ type Coordinator interface { ServeAgent(conn net.Conn, id uuid.UUID, name string) error // Close closes the coordinator. Close() error + + ServeMultiAgent(id uuid.UUID) MultiAgentConn } // Node represents a node in the network. @@ -54,10 +54,11 @@ type Node struct { AsOf time.Time `json:"as_of"` // Key is the Wireguard public key of the node. Key key.NodePublic `json:"key"` - // DiscoKey is used for discovery messages over DERP to establish peer-to-peer connections. + // DiscoKey is used for discovery messages over DERP to establish + // peer-to-peer connections. DiscoKey key.DiscoPublic `json:"disco"` - // PreferredDERP is the DERP server that peered connections - // should meet at to establish. + // PreferredDERP is the DERP server that peered connections should meet at + // to establish. PreferredDERP int `json:"preferred_derp"` // DERPLatency is the latency in seconds to each DERP server. 
DERPLatency map[string]float64 `json:"derp_latency"` @@ -68,8 +69,8 @@ type Node struct { DERPForcedWebsocket map[int]string `json:"derp_forced_websockets"` // Addresses are the IP address ranges this connection exposes. Addresses []netip.Prefix `json:"addresses"` - // AllowedIPs specify what addresses can dial the connection. - // We allow all by default. + // AllowedIPs specify what addresses can dial the connection. We allow all + // by default. AllowedIPs []netip.Prefix `json:"allowed_ips"` // Endpoints are ip:port combinations that can be used to establish // peer-to-peer connections. @@ -130,12 +131,33 @@ func NewCoordinator(logger slog.Logger) Coordinator { // ┌──────────────────┐ ┌────────────────────┐ ┌───────────────────┐ ┌──────────────────┐ // │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│ // └──────────────────┘ └────────────────────┘ └───────────────────┘ └──────────────────┘ -// This coordinator is incompatible with multiple Coder -// replicas as all node data is in-memory. +// This coordinator is incompatible with multiple Coder replicas as all node +// data is in-memory. type coordinator struct { core *core } +func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn { + m := (&MultiAgent{ + ID: id, + Logger: c.core.logger, + AgentIsLegacyFunc: c.core.agentIsLegacy, + OnSubscribe: c.core.clientSubscribeToAgent, + OnUnsubscribe: c.core.clientUnsubscribeFromAgent, + OnNodeUpdate: c.core.clientNodeUpdate, + OnRemove: c.core.clientDisconnected, + }).Init() + c.core.addClient(id, m) + return m +} + +func (c *core) addClient(id uuid.UUID, ma Queue) { + c.mutex.Lock() + c.clients[id] = ma + c.clientsToAgents[id] = map[uuid.UUID]Queue{} + c.mutex.Unlock() +} + // core is an in-memory structure of Node and TrackedConn mappings. Its methods may be called from multiple goroutines; // it is protected by a mutex to ensure data stay consistent. type core struct { @@ -146,14 +168,38 @@ type core struct { // nodes maps agent and connection IDs their respective node. nodes map[uuid.UUID]*Node // agentSockets maps agent IDs to their open websocket. - agentSockets map[uuid.UUID]*TrackedConn + agentSockets map[uuid.UUID]Queue // agentToConnectionSockets maps agent IDs to connection IDs of conns that // are subscribed to updates for that agent. - agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]*TrackedConn + agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]Queue + + // clients holds a map of all clients connected to the coordinator. This is + // necessary because a client may not be subscribed into any agents. + clients map[uuid.UUID]Queue + // clientsToAgents is an index of clients to all of their subscribed agents. + clientsToAgents map[uuid.UUID]map[uuid.UUID]Queue // agentNameCache holds a cache of agent names. If one of them disappears, // it's helpful to have a name cached for debugging. agentNameCache *lru.Cache[uuid.UUID, string] + + // legacyAgents holds a mapping of all agents detected as legacy, meaning + // they only listen on codersdk.WorkspaceAgentIP. They aren't compatible + // with the new ServerTailnet, so they must be connected through + // wsconncache. + legacyAgents map[uuid.UUID]struct{} +} + +type Queue interface { + UniqueID() uuid.UUID + Enqueue(n []*Node) error + Name() string + Stats() (start, lastWrite int64) + Overwrites() int64 + // CoordinatorClose is used by the coordinator when closing a Queue. It + // should skip removing itself from the coordinator.
+ CoordinatorClose() error + Close() error } func newCore(logger slog.Logger) *core { @@ -165,128 +211,18 @@ func newCore(logger slog.Logger) *core { return &core{ logger: logger, closed: false, - nodes: make(map[uuid.UUID]*Node), - agentSockets: map[uuid.UUID]*TrackedConn{}, - agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*TrackedConn{}, + nodes: map[uuid.UUID]*Node{}, + agentSockets: map[uuid.UUID]Queue{}, + agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]Queue{}, agentNameCache: nameCache, + legacyAgents: map[uuid.UUID]struct{}{}, + clients: map[uuid.UUID]Queue{}, + clientsToAgents: map[uuid.UUID]map[uuid.UUID]Queue{}, } } var ErrWouldBlock = xerrors.New("would block") -type TrackedConn struct { - ctx context.Context - cancel func() - conn net.Conn - updates chan []*Node - logger slog.Logger - lastData []byte - - // ID is an ephemeral UUID used to uniquely identify the owner of the - // connection. - ID uuid.UUID - - Name string - Start int64 - LastWrite int64 - Overwrites int64 -} - -func (t *TrackedConn) Enqueue(n []*Node) (err error) { - atomic.StoreInt64(&t.LastWrite, time.Now().Unix()) - select { - case t.updates <- n: - return nil - default: - return ErrWouldBlock - } -} - -// Close the connection and cancel the context for reading node updates from the queue -func (t *TrackedConn) Close() error { - t.cancel() - return t.conn.Close() -} - -// WriteTimeout is the amount of time we wait to write a node update to a connection before we declare it hung. -// It is exported so that tests can use it. -const WriteTimeout = time.Second * 5 - -// SendUpdates reads node updates and writes them to the connection. Ends when writes hit an error or context is -// canceled. -func (t *TrackedConn) SendUpdates() { - for { - select { - case <-t.ctx.Done(): - t.logger.Debug(t.ctx, "done sending updates") - return - case nodes := <-t.updates: - data, err := json.Marshal(nodes) - if err != nil { - t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes)) - return - } - if bytes.Equal(t.lastData, data) { - t.logger.Debug(t.ctx, "skipping duplicate update", slog.F("nodes", string(data))) - continue - } - - // Set a deadline so that hung connections don't put back pressure on the system. - // Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. - err = t.conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) - if err != nil { - // often, this is just because the connection is closed/broken, so only log at debug. - t.logger.Debug(t.ctx, "unable to set write deadline", slog.Error(err)) - _ = t.Close() - return - } - _, err = t.conn.Write(data) - if err != nil { - // often, this is just because the connection is closed/broken, so only log at debug. - t.logger.Debug(t.ctx, "could not write nodes to connection", - slog.Error(err), slog.F("nodes", string(data))) - _ = t.Close() - return - } - t.logger.Debug(t.ctx, "wrote nodes", slog.F("nodes", string(data))) - - // nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are - // *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write() - // fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context. - // If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after - // our successful write, it is important that we reset the deadline before it fires. 
- err = t.conn.SetWriteDeadline(time.Time{}) - if err != nil { - // often, this is just because the connection is closed/broken, so only log at debug. - t.logger.Debug(t.ctx, "unable to extend write deadline", slog.Error(err)) - _ = t.Close() - return - } - t.lastData = data - } - } -} - -func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.UUID, logger slog.Logger, overwrites int64) *TrackedConn { - // buffer updates so they don't block, since we hold the - // coordinator mutex while queuing. Node updates don't - // come quickly, so 512 should be plenty for all but - // the most pathological cases. - updates := make(chan []*Node, 512) - now := time.Now().Unix() - return &TrackedConn{ - ctx: ctx, - conn: conn, - cancel: cancel, - updates: updates, - logger: logger, - ID: id, - Start: now, - LastWrite: now, - Overwrites: overwrites, - } -} - // Node returns an in-memory node by ID. // If the node does not exist, nil is returned. func (c *coordinator) Node(id uuid.UUID) *Node { @@ -321,16 +257,29 @@ func (c *core) agentCount() int { // ServeClient accepts a WebSocket connection that wants to connect to an agent // with the specified ID. -func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { +func (c *coordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - logger := c.core.clientLogger(id, agent) + logger := c.core.clientLogger(id, agentID) logger.Debug(ctx, "coordinating client") - tc, err := c.core.initAndTrackClient(ctx, cancel, conn, id, agent) + + tc := NewTrackedConn(ctx, cancel, conn, id, logger, 0) + defer tc.Close() + + c.core.addClient(id, tc) + defer c.core.clientDisconnected(id) + + agentNode, err := c.core.clientSubscribeToAgent(tc, agentID) if err != nil { - return err + return xerrors.Errorf("subscribe agent: %w", err) + } + + if agentNode != nil { + err := tc.Enqueue([]*Node{agentNode}) + if err != nil { + logger.Debug(ctx, "enqueue initial node", slog.Error(err)) + } } - defer c.core.clientDisconnected(id, agent) // On this goroutine, we read updates from the client and publish them. We start a second goroutine // to write updates back to the client. @@ -338,7 +287,7 @@ func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) decoder := json.NewDecoder(conn) for { - err := c.handleNextClientMessage(id, agent, decoder) + err := c.handleNextClientMessage(id, decoder) if err != nil { logger.Debug(ctx, "unable to read client update, connection may be closed", slog.Error(err)) if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) { @@ -353,99 +302,133 @@ func (c *core) clientLogger(id, agent uuid.UUID) slog.Logger { return c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent)) } -// initAndTrackClient creates a TrackedConn for the client, and sends any initial Node updates if we have any. It is -// one function that does two things because it is critical that we hold the mutex for both things, lest we miss some -// updates. 
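Editor's note (illustrative sketch, not part of the diff): after this refactor the coordinator maps hold the Queue interface instead of *TrackedConn, so both the per-WebSocket TrackedConn registered by ServeClient above and the in-process MultiAgent registered by ServeMultiAgent must implement every Queue method. A minimal sketch of compile-time assertions that would make that contract explicit inside package tailnet:

package tailnet

// Illustrative only: compilation fails if either connection type stops
// satisfying Queue, which is what the agentSockets, clients, and
// agentToConnectionSockets maps rely on after this change.
var (
	_ Queue = (*TrackedConn)(nil)
	_ Queue = (*MultiAgent)(nil)
)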
-func (c *core) initAndTrackClient( - ctx context.Context, cancel func(), conn net.Conn, id, agent uuid.UUID, -) ( - *TrackedConn, error, -) { - logger := c.clientLogger(id, agent) - c.mutex.Lock() - defer c.mutex.Unlock() - if c.closed { - return nil, xerrors.New("coordinator is closed") - } - tc := NewTrackedConn(ctx, cancel, conn, id, logger, 0) - - // When a new connection is requested, we update it with the latest - // node of the agent. This allows the connection to establish. - node, ok := c.nodes[agent] - if ok { - err := tc.Enqueue([]*Node{node}) - // this should never error since we're still the only goroutine that - // knows about the TrackedConn. If we hit an error something really - // wrong is happening - if err != nil { - logger.Critical(ctx, "unable to queue initial node", slog.Error(err)) - return nil, err - } - } - - // Insert this connection into a map so the agent - // can publish node updates. - connectionSockets, ok := c.agentToConnectionSockets[agent] +func (c *core) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq Queue) { + connectionSockets, ok := c.agentToConnectionSockets[agentID] if !ok { - connectionSockets = map[uuid.UUID]*TrackedConn{} - c.agentToConnectionSockets[agent] = connectionSockets + connectionSockets = map[uuid.UUID]Queue{} + c.agentToConnectionSockets[agentID] = connectionSockets } - connectionSockets[id] = tc - logger.Debug(ctx, "added tracked connection") - return tc, nil + connectionSockets[enq.UniqueID()] = enq + + c.clientsToAgents[enq.UniqueID()][agentID] = c.agentSockets[agentID] } -func (c *core) clientDisconnected(id, agent uuid.UUID) { - logger := c.clientLogger(id, agent) +func (c *core) clientDisconnected(id uuid.UUID) { + logger := c.clientLogger(id, uuid.Nil) c.mutex.Lock() defer c.mutex.Unlock() // Clean all traces of this connection from the map. 
delete(c.nodes, id) logger.Debug(context.Background(), "deleted client node") - connectionSockets, ok := c.agentToConnectionSockets[agent] - if !ok { - return - } - delete(connectionSockets, id) - logger.Debug(context.Background(), "deleted client connectionSocket from map") - if len(connectionSockets) != 0 { - return + + for agentID := range c.clientsToAgents[id] { + connectionSockets, ok := c.agentToConnectionSockets[agentID] + if !ok { + continue + } + delete(connectionSockets, id) + logger.Debug(context.Background(), "deleted client connectionSocket from map", slog.F("agent_id", agentID)) + + if len(connectionSockets) == 0 { + delete(c.agentToConnectionSockets, agentID) + logger.Debug(context.Background(), "deleted last client connectionSocket from map", slog.F("agent_id", agentID)) + } } - delete(c.agentToConnectionSockets, agent) - logger.Debug(context.Background(), "deleted last client connectionSocket from map") + + delete(c.clients, id) + delete(c.clientsToAgents, id) + logger.Debug(context.Background(), "deleted client agents") } -func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error { - logger := c.core.clientLogger(id, agent) +func (c *coordinator) handleNextClientMessage(id uuid.UUID, decoder *json.Decoder) error { + logger := c.core.clientLogger(id, uuid.Nil) + var node Node err := decoder.Decode(&node) if err != nil { return xerrors.Errorf("read json: %w", err) } + logger.Debug(context.Background(), "got client node update", slog.F("node", node)) - return c.core.clientNodeUpdate(id, agent, &node) + return c.core.clientNodeUpdate(id, &node) } -func (c *core) clientNodeUpdate(id, agent uuid.UUID, node *Node) error { - logger := c.clientLogger(id, agent) +func (c *core) clientNodeUpdate(id uuid.UUID, node *Node) error { c.mutex.Lock() defer c.mutex.Unlock() + // Update the node of this client in our in-memory map. If an agent entirely // shuts down and reconnects, it needs to be aware of all clients attempting // to establish connections. c.nodes[id] = node - agentSocket, ok := c.agentSockets[agent] - if !ok { - logger.Debug(context.Background(), "no agent socket, unable to send node") - return nil + return c.clientNodeUpdateLocked(id, node) +} + +func (c *core) clientNodeUpdateLocked(id uuid.UUID, node *Node) error { + logger := c.clientLogger(id, uuid.Nil) + + agents := []uuid.UUID{} + for agentID, agentSocket := range c.clientsToAgents[id] { + if agentSocket == nil { + logger.Debug(context.Background(), "enqueue node to agent; socket is nil", slog.F("agent_id", agentID)) + continue + } + + err := agentSocket.Enqueue([]*Node{node}) + if err != nil { + logger.Debug(context.Background(), "unable to Enqueue node to agent", slog.Error(err), slog.F("agent_id", agentID)) + continue + } + agents = append(agents, agentID) } - err := agentSocket.Enqueue([]*Node{node}) - if err != nil { - return xerrors.Errorf("Enqueue node: %w", err) + logger.Debug(context.Background(), "enqueued node to agents", slog.F("agent_ids", agents)) + return nil +} + +func (c *core) clientSubscribeToAgent(enq Queue, agentID uuid.UUID) (*Node, error) { + c.mutex.Lock() + defer c.mutex.Unlock() + + logger := c.clientLogger(enq.UniqueID(), agentID) + + c.initOrSetAgentConnectionSocketLocked(agentID, enq) + + node, ok := c.nodes[enq.UniqueID()] + if ok { + // If we have the client node, send it to the agent. If not, it will be + // sent async. 
+ agentSocket, ok := c.agentSockets[agentID] + if !ok { + logger.Debug(context.Background(), "subscribe to agent; socket is nil") + } else { + err := agentSocket.Enqueue([]*Node{node}) + if err != nil { + return nil, xerrors.Errorf("enqueue client to agent: %w", err) + } + } + } else { + logger.Debug(context.Background(), "multiagent node doesn't exist") } - logger.Debug(context.Background(), "enqueued node to agent") + + agentNode, ok := c.nodes[agentID] + if !ok { + // This is ok, once the agent connects the node will be sent over. + logger.Debug(context.Background(), "agent node doesn't exist", slog.F("agent_id", agentID)) + } + + // Send the subscribed agent back to the multi agent. + return agentNode, nil +} + +func (c *core) clientUnsubscribeFromAgent(enq Queue, agentID uuid.UUID) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + delete(c.clientsToAgents[enq.UniqueID()], agentID) + delete(c.agentToConnectionSockets[agentID], enq.UniqueID()) + return nil } @@ -493,11 +476,14 @@ func (c *core) agentDisconnected(id, unique uuid.UUID) { // Only delete the connection if it's ours. It could have been // overwritten. - if idConn, ok := c.agentSockets[id]; ok && idConn.ID == unique { + if idConn, ok := c.agentSockets[id]; ok && idConn.UniqueID() == unique { delete(c.agentSockets, id) delete(c.nodes, id) logger.Debug(context.Background(), "deleted agent socket and node") } + for clientID := range c.agentToConnectionSockets[id] { + c.clientsToAgents[clientID][id] = nil + } } // initAndTrackAgent creates a TrackedConn for the agent, and sends any initial nodes updates if we have any. It is @@ -519,7 +505,7 @@ func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Co // dead. oldAgentSocket, ok := c.agentSockets[id] if ok { - overwrites = oldAgentSocket.Overwrites + 1 + overwrites = oldAgentSocket.Overwrites() + 1 _ = oldAgentSocket.Close() } tc := NewTrackedConn(ctx, cancel, conn, unique, logger, overwrites) @@ -549,6 +535,10 @@ func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Co } c.agentSockets[id] = tc + for clientID := range c.agentToConnectionSockets[id] { + c.clientsToAgents[clientID][id] = tc + } + logger.Debug(ctx, "added agent socket") return tc, nil } @@ -564,11 +554,31 @@ func (c *coordinator) handleNextAgentMessage(id uuid.UUID, decoder *json.Decoder return c.core.agentNodeUpdate(id, &node) } +// This is copied from codersdk because importing it here would cause an import +// cycle. This is just temporary until wsconncache is phased out. +var legacyAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4") + +// This is temporary until we no longer need to detect for agent backwards +// compatibility. +// See: https://github.com/coder/coder/issues/8218 +func (c *core) agentIsLegacy(agentID uuid.UUID) bool { + c.mutex.RLock() + _, ok := c.legacyAgents[agentID] + c.mutex.RUnlock() + return ok +} + func (c *core) agentNodeUpdate(id uuid.UUID, node *Node) error { logger := c.agentLogger(id) c.mutex.Lock() defer c.mutex.Unlock() c.nodes[id] = node + + // Keep a cache of all legacy agents. 
+ if len(node.Addresses) > 0 && node.Addresses[0].Addr() == legacyAgentIP { + c.legacyAgents[id] = struct{}{} + } + connectionSockets, ok := c.agentToConnectionSockets[id] if !ok { logger.Debug(context.Background(), "no client sockets; unable to send node") @@ -588,6 +598,7 @@ func (c *core) agentNodeUpdate(id uuid.UUID, node *Node) error { slog.F("client_id", clientID), slog.Error(err)) } } + return nil } @@ -611,20 +622,18 @@ func (c *core) close() error { for _, socket := range c.agentSockets { socket := socket go func() { - _ = socket.Close() + _ = socket.CoordinatorClose() wg.Done() }() } - for _, connMap := range c.agentToConnectionSockets { - wg.Add(len(connMap)) - for _, socket := range connMap { - socket := socket - go func() { - _ = socket.Close() - wg.Done() - }() - } + wg.Add(len(c.clients)) + for _, client := range c.clients { + client := client + go func() { + _ = client.CoordinatorClose() + wg.Done() + }() } c.mutex.Unlock() @@ -649,8 +658,8 @@ func (c *core) serveHTTPDebug(w http.ResponseWriter, r *http.Request) { } func CoordinatorHTTPDebug( - agentSocketsMap map[uuid.UUID]*TrackedConn, - agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]*TrackedConn, + agentSocketsMap map[uuid.UUID]Queue, + agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]Queue, agentNameCache *lru.Cache[uuid.UUID, string], ) func(w http.ResponseWriter, _ *http.Request) { return func(w http.ResponseWriter, _ *http.Request) { @@ -658,7 +667,7 @@ func CoordinatorHTTPDebug( type idConn struct { id uuid.UUID - conn *TrackedConn + conn Queue } { @@ -671,16 +680,17 @@ func CoordinatorHTTPDebug( } slices.SortFunc(agentSockets, func(a, b idConn) bool { - return a.conn.Name < b.conn.Name + return a.conn.Name() < b.conn.Name() }) for _, agent := range agentSockets { + start, lastWrite := agent.conn.Stats() _, _ = fmt.Fprintf(w, "
  • %s (%s): created %v ago, write %v ago, overwrites %d
  • \n", - agent.conn.Name, + agent.conn.Name(), agent.id.String(), - now.Sub(time.Unix(agent.conn.Start, 0)).Round(time.Second), - now.Sub(time.Unix(agent.conn.LastWrite, 0)).Round(time.Second), - agent.conn.Overwrites, + now.Sub(time.Unix(start, 0)).Round(time.Second), + now.Sub(time.Unix(lastWrite, 0)).Round(time.Second), + agent.conn.Overwrites(), ) if conns := agentToConnectionSocketsMap[agent.id]; len(conns) > 0 { @@ -696,11 +706,12 @@ func CoordinatorHTTPDebug( _, _ = fmt.Fprintln(w, "") @@ -755,11 +766,12 @@ func CoordinatorHTTPDebug( _, _ = fmt.Fprintf(w, "

    connections: total %d

    \n", len(agentConns.conns)) _, _ = fmt.Fprintln(w, "") diff --git a/tailnet/multiagent.go b/tailnet/multiagent.go new file mode 100644 index 0000000000000..13300fdce677a --- /dev/null +++ b/tailnet/multiagent.go @@ -0,0 +1,167 @@ +package tailnet + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog" +) + +type MultiAgentConn interface { + UpdateSelf(node *Node) error + SubscribeAgent(agentID uuid.UUID) error + UnsubscribeAgent(agentID uuid.UUID) error + NextUpdate(ctx context.Context) ([]*Node, bool) + AgentIsLegacy(agentID uuid.UUID) bool + Close() error + IsClosed() bool +} + +type MultiAgent struct { + mu sync.RWMutex + + closed bool + + ID uuid.UUID + Logger slog.Logger + + AgentIsLegacyFunc func(agentID uuid.UUID) bool + OnSubscribe func(enq Queue, agent uuid.UUID) (*Node, error) + OnUnsubscribe func(enq Queue, agent uuid.UUID) error + OnNodeUpdate func(id uuid.UUID, node *Node) error + OnRemove func(id uuid.UUID) + + updates chan []*Node + closeOnce sync.Once + start int64 + lastWrite int64 + // Client nodes normally generate a unique id for each connection so + // overwrites are really not an issue, but is provided for compatibility. + overwrites int64 +} + +func (m *MultiAgent) Init() *MultiAgent { + m.updates = make(chan []*Node, 128) + m.start = time.Now().Unix() + return m +} + +func (m *MultiAgent) UniqueID() uuid.UUID { + return m.ID +} + +func (m *MultiAgent) AgentIsLegacy(agentID uuid.UUID) bool { + return m.AgentIsLegacyFunc(agentID) +} + +var ErrMultiAgentClosed = xerrors.New("multiagent is closed") + +func (m *MultiAgent) UpdateSelf(node *Node) error { + m.mu.RLock() + defer m.mu.RUnlock() + if m.closed { + return ErrMultiAgentClosed + } + + return m.OnNodeUpdate(m.ID, node) +} + +func (m *MultiAgent) SubscribeAgent(agentID uuid.UUID) error { + m.mu.RLock() + defer m.mu.RUnlock() + if m.closed { + return ErrMultiAgentClosed + } + + node, err := m.OnSubscribe(m, agentID) + if err != nil { + return err + } + + if node != nil { + return m.enqueueLocked([]*Node{node}) + } + + return nil +} + +func (m *MultiAgent) UnsubscribeAgent(agentID uuid.UUID) error { + m.mu.RLock() + defer m.mu.RUnlock() + if m.closed { + return ErrMultiAgentClosed + } + + return m.OnUnsubscribe(m, agentID) +} + +func (m *MultiAgent) NextUpdate(ctx context.Context) ([]*Node, bool) { + select { + case <-ctx.Done(): + return nil, false + + case nodes, ok := <-m.updates: + return nodes, ok + } +} + +func (m *MultiAgent) Enqueue(nodes []*Node) error { + m.mu.RLock() + defer m.mu.RUnlock() + + if m.closed { + return nil + } + + return m.enqueueLocked(nodes) +} + +func (m *MultiAgent) enqueueLocked(nodes []*Node) error { + atomic.StoreInt64(&m.lastWrite, time.Now().Unix()) + + select { + case m.updates <- nodes: + return nil + default: + return ErrWouldBlock + } +} + +func (m *MultiAgent) Name() string { + return m.ID.String() +} + +func (m *MultiAgent) Stats() (start int64, lastWrite int64) { + return m.start, atomic.LoadInt64(&m.lastWrite) +} + +func (m *MultiAgent) Overwrites() int64 { + return m.overwrites +} + +func (m *MultiAgent) IsClosed() bool { + m.mu.RLock() + defer m.mu.RUnlock() + return m.closed +} + +func (m *MultiAgent) CoordinatorClose() error { + m.mu.Lock() + if !m.closed { + m.closed = true + close(m.updates) + } + m.mu.Unlock() + return nil +} + +func (m *MultiAgent) Close() error { + _ = m.CoordinatorClose() + m.closeOnce.Do(func() { m.OnRemove(m.ID) }) + return nil +} diff --git 
a/tailnet/tailnettest/tailnettest.go b/tailnet/tailnettest/tailnettest.go index 482c1232e258a..0cb7dbd330ed3 100644 --- a/tailnet/tailnettest/tailnettest.go +++ b/tailnet/tailnettest/tailnettest.go @@ -22,7 +22,7 @@ import ( ) // RunDERPAndSTUN creates a DERP mapping for tests. -func RunDERPAndSTUN(t *testing.T) *tailcfg.DERPMap { +func RunDERPAndSTUN(t *testing.T) (*tailcfg.DERPMap, *derp.Server) { logf := tailnet.Logger(slogtest.Make(t, nil)) d := derp.NewServer(key.NewNode(), logf) server := httptest.NewUnstartedServer(derphttp.Handler(d)) @@ -61,7 +61,7 @@ func RunDERPAndSTUN(t *testing.T) *tailcfg.DERPMap { }, }, }, - } + }, d } // RunDERPOnlyWebSockets creates a DERP mapping for tests that diff --git a/tailnet/tailnettest/tailnettest_test.go b/tailnet/tailnettest/tailnettest_test.go index 6424bd94db0c0..fda818a1cebca 100644 --- a/tailnet/tailnettest/tailnettest_test.go +++ b/tailnet/tailnettest/tailnettest_test.go @@ -14,7 +14,7 @@ func TestMain(m *testing.M) { func TestRunDERPAndSTUN(t *testing.T) { t.Parallel() - _ = tailnettest.RunDERPAndSTUN(t) + _, _ = tailnettest.RunDERPAndSTUN(t) } func TestRunDERPOnlyWebSockets(t *testing.T) { diff --git a/tailnet/trackedconn.go b/tailnet/trackedconn.go new file mode 100644 index 0000000000000..cedd6e37dbc8d --- /dev/null +++ b/tailnet/trackedconn.go @@ -0,0 +1,147 @@ +package tailnet + +import ( + "bytes" + "context" + "encoding/json" + "net" + "sync/atomic" + "time" + + "github.com/google/uuid" + + "cdr.dev/slog" +) + +// WriteTimeout is the amount of time we wait to write a node update to a connection before we declare it hung. +// It is exported so that tests can use it. +const WriteTimeout = time.Second * 5 + +type TrackedConn struct { + ctx context.Context + cancel func() + conn net.Conn + updates chan []*Node + logger slog.Logger + lastData []byte + + // ID is an ephemeral UUID used to uniquely identify the owner of the + // connection. + id uuid.UUID + + name string + start int64 + lastWrite int64 + overwrites int64 +} + +func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.UUID, logger slog.Logger, overwrites int64) *TrackedConn { + // buffer updates so they don't block, since we hold the + // coordinator mutex while queuing. Node updates don't + // come quickly, so 512 should be plenty for all but + // the most pathological cases. + updates := make(chan []*Node, 512) + now := time.Now().Unix() + return &TrackedConn{ + ctx: ctx, + conn: conn, + cancel: cancel, + updates: updates, + logger: logger, + id: id, + start: now, + lastWrite: now, + overwrites: overwrites, + } +} + +func (t *TrackedConn) Enqueue(n []*Node) (err error) { + atomic.StoreInt64(&t.lastWrite, time.Now().Unix()) + select { + case t.updates <- n: + return nil + default: + return ErrWouldBlock + } +} + +func (t *TrackedConn) UniqueID() uuid.UUID { + return t.id +} + +func (t *TrackedConn) Name() string { + return t.name +} + +func (t *TrackedConn) Stats() (start, lastWrite int64) { + return t.start, atomic.LoadInt64(&t.lastWrite) +} + +func (t *TrackedConn) Overwrites() int64 { + return t.overwrites +} + +func (t *TrackedConn) CoordinatorClose() error { + return t.Close() +} + +// Close the connection and cancel the context for reading node updates from the queue +func (t *TrackedConn) Close() error { + t.cancel() + return t.conn.Close() +} + +// SendUpdates reads node updates and writes them to the connection. Ends when writes hit an error or context is +// canceled. 
+func (t *TrackedConn) SendUpdates() { + for { + select { + case <-t.ctx.Done(): + t.logger.Debug(t.ctx, "done sending updates") + return + case nodes := <-t.updates: + data, err := json.Marshal(nodes) + if err != nil { + t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes)) + return + } + if bytes.Equal(t.lastData, data) { + t.logger.Debug(t.ctx, "skipping duplicate update", slog.F("nodes", string(data))) + continue + } + + // Set a deadline so that hung connections don't put back pressure on the system. + // Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. + err = t.conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) + if err != nil { + // often, this is just because the connection is closed/broken, so only log at debug. + t.logger.Debug(t.ctx, "unable to set write deadline", slog.Error(err)) + _ = t.Close() + return + } + _, err = t.conn.Write(data) + if err != nil { + // often, this is just because the connection is closed/broken, so only log at debug. + t.logger.Debug(t.ctx, "could not write nodes to connection", + slog.Error(err), slog.F("nodes", string(data))) + _ = t.Close() + return + } + t.logger.Debug(t.ctx, "wrote nodes", slog.F("nodes", string(data))) + + // nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are + // *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write() + // fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context. + // If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after + // our successful write, it is important that we reset the deadline before it fires. + err = t.conn.SetWriteDeadline(time.Time{}) + if err != nil { + // often, this is just because the connection is closed/broken, so only log at debug. + t.logger.Debug(t.ctx, "unable to extend write deadline", slog.Error(err)) + _ = t.Close() + return + } + t.lastData = data + } + } +}
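Editor's note (illustrative sketch, not part of the diff): an end-to-end example of how in-process code might consume the new multi-agent API. It assumes the value returned by NewCoordinator exposes ServeMultiAgent (otherwise a type assertion on the concrete coordinator would be needed); the agent ID and the empty Node passed to UpdateSelf are placeholders:

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/google/uuid"

	"cdr.dev/slog"
	"github.com/coder/coder/tailnet"
)

func main() {
	coord := tailnet.NewCoordinator(slog.Make())
	agentID := uuid.New() // placeholder: a real workspace agent ID

	ma := coord.ServeMultiAgent(uuid.New())
	defer ma.Close()

	// Subscribe first; if the agent has already published a node, the
	// coordinator enqueues it and it arrives through NextUpdate below.
	if err := ma.SubscribeAgent(agentID); err != nil {
		panic(err)
	}

	// Publish our own node so subscribed agents learn how to reach us. In
	// real use this comes from the local tailnet.Conn, not an empty Node.
	if err := ma.UpdateSelf(&tailnet.Node{}); err != nil {
		panic(err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	for {
		nodes, ok := ma.NextUpdate(ctx)
		if !ok {
			return // context expired or the MultiAgent was closed
		}
		fmt.Printf("update: %d node(s), legacy agent: %v\n", len(nodes), ma.AgentIsLegacy(agentID))
	}
}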