From ba2dfd250586dac9069a5d846a88f309f258fe75 Mon Sep 17 00:00:00 2001
From: Ethan <39577870+ethanndickson@users.noreply.github.com>
Date: Wed, 30 Apr 2025 15:17:10 +1000
Subject: [PATCH 01/11] feat(cli): use coder connect in `coder ssh --stdio`,
 if available (#17572)

Closes https://github.com/coder/vscode-coder/issues/447
Closes https://github.com/coder/jetbrains-coder/issues/543
Closes https://github.com/coder/coder-jetbrains-toolbox/issues/21

This PR adds Coder Connect support to `coder ssh --stdio`. When connecting to a workspace, if `--force-new-tunnel` is not passed, the CLI will first do a DNS lookup for `<agent>.<workspace>.<owner>.<hostname suffix>`. If an IP address is returned, and it's within the Coder service prefix, the CLI will not create a new tailnet connection to the workspace, and will instead dial the SSH server running on port 22 on the workspace directly over TCP. This allows IDE extensions to use the Coder Connect tunnel without requiring any modifications to the extensions themselves.

Additionally, `using_coder_connect` is added to the `sshNetworkStats` file, which the VS Code extension (and maybe JetBrains?) will be able to read and use to indicate to the user that they are using Coder Connect.

One advantage of this approach is that running `coder ssh --stdio` on an offline workspace with Coder Connect enabled will have the CLI wait for the workspace to build and the agent to connect (and optionally for the startup scripts to finish) before finally connecting using the Coder Connect tunnel.

As a result, `coder ssh --stdio` has the overhead of looking up the workspace and agent, and checking whether they are running. On my device, this meant `coder ssh --stdio <workspace>` was approximately a second slower than just connecting to the workspace directly using `ssh <workspace>.coder` (I would assume anyone serious about their Coder Connect usage would know to just do the latter anyway). To ensure this doesn't come at a significant performance cost, I've also benchmarked this PR.
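For illustration, the detection step boils down to "resolve the Coder Connect hostname, then check whether the answer falls inside the Coder service prefix". Here is a minimal sketch of that check; the prefix value and function shape are placeholders, and the real logic lives in `workspacesdk.ExistsViaCoderConnect`:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/netip"
)

// existsViaCoderConnect sketches the detection step: resolve the Coder
// Connect hostname and report whether any returned address falls within
// the Coder service prefix.
func existsViaCoderConnect(ctx context.Context, host string) (bool, error) {
	// Placeholder prefix for illustration only; the CLI checks against the
	// actual Coder service prefix.
	servicePrefix := netip.MustParsePrefix("fd00::/8")

	addrs, err := net.DefaultResolver.LookupNetIP(ctx, "ip6", host)
	if err != nil {
		var dnsErr *net.DNSError
		// A "not found" answer simply means Coder Connect isn't running.
		if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
			return false, nil
		}
		return false, err
	}
	for _, addr := range addrs {
		if servicePrefix.Contains(addr) {
			return true, nil
		}
	}
	return false, nil
}

func main() {
	ok, err := existsViaCoderConnect(context.Background(), "dev.myworkspace.myuser.coder")
	fmt.Println(ok, err)
}
```

If the lookup returns nothing, or an address outside the prefix, the CLI falls back to creating its own tailnet connection as before.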
# Benchmark

## Methodology

All tests were completed on `dev.coder.com`, where a Linux workspace running in AWS `us-west1` was created. The machine running Coder Desktop (the 'client') was a Windows VM running in the same AWS region and VPC as the workspace.

To test the performance of the SSH connection specifically, a port was forwarded between the client and workspace using:
```
ssh -p 22 -L7001:localhost:7001 <host>
```
where `<host>` was either an alias for an SSH ProxyCommand that called `coder ssh`, or a Coder Connect hostname.

For latency, [`tcping`](https://www.elifulkerson.com/projects/tcping.php) was used against the forwarded port:
```
tcping -n 100 localhost 7001
```

For throughput, [`iperf3`](https://iperf.fr/iperf-download.php) was used:
```
iperf3 -c localhost -p 7001
```
where an `iperf3` server was running on the workspace on port 7001.

## Test Cases

### Testcase 1: `coder ssh` `ProxyCommand` that bicopies from Coder Connect

This case tests the implementation in this PR, such that we can write a config like:
```
Host codercliconnect
    ProxyCommand /path/to/coder ssh --stdio workspace
```
With Coder Connect enabled, `ssh -p 22 -L7001:localhost:7001 codercliconnect` will use the Coder Connect tunnel. The results were as follows:

**Throughput, 10 tests, back to back:**

- Average throughput across all tests: 788.20 Mbits/sec
- Minimum average throughput: 731 Mbits/sec
- Maximum average throughput: 871 Mbits/sec
- Standard Deviation: 38.88 Mbits/sec

**Latency, 100 RTTs:**

- Average: 0.369ms
- Minimum: 0.290ms
- Maximum: 0.473ms

### Testcase 2: `ssh` dialing Coder Connect directly without a `ProxyCommand`

This is what we assume to be the 'best' way to use Coder Connect.

**Throughput, 10 tests, back to back:**

- Average throughput across all tests: 789.50 Mbits/sec
- Minimum average throughput: 708 Mbits/sec
- Maximum average throughput: 839 Mbits/sec
- Standard Deviation: 39.98 Mbits/sec

**Latency, 100 RTTs:**

- Average: 0.369ms
- Minimum: 0.267ms
- Maximum: 0.440ms

### Testcase 3: `coder ssh` `ProxyCommand` that creates its own Tailnet connection in-process

This is what normally happens when you run `coder ssh`:

**Throughput, 10 tests, back to back:**

- Average throughput across all tests: 610.20 Mbits/sec
- Minimum average throughput: 569 Mbits/sec
- Maximum average throughput: 664 Mbits/sec
- Standard Deviation: 27.29 Mbits/sec

**Latency, 100 RTTs:**

- Average: 0.335ms
- Minimum: 0.262ms
- Maximum: 0.452ms

## Analysis

Performing a two-tailed, unpaired t-test against the throughput of testcases 1 and 2, we find a P value of `0.9450`. This suggests the difference between the data sets is not statistically significant: a difference at least this large would be expected about 94.5% of the time even if the two setups performed identically.

## Conclusion

From the t-test, and by comparison to the status quo (regular `coder ssh`, which uses gvisor, and is noticeably slower), I think it's safe to say any impact on throughput or latency from the `ProxyCommand` performing a bicopy against Coder Connect is negligible. Users are very unlikely to run into performance issues as a result of using Coder Connect via `coder ssh`, as implemented in this PR.

Less scientifically, I ran these same tests on my home network with my Sydney workspace, and both throughput and latency were consistent across testcases 1 and 2.
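As a rough cross-check of the analysis above, the p value can be reproduced approximately from the summary statistics alone. This is only a sketch: it assumes ten samples per testcase and uses Welch's formula with gonum's Student's t distribution, whereas the original test presumably used the raw samples, so the result differs slightly from `0.9450`.

```go
package main

import (
	"fmt"
	"math"

	"gonum.org/v1/gonum/stat/distuv"
)

func main() {
	// Throughput summary statistics for testcases 1 and 2 (Mbits/sec).
	m1, s1, n1 := 788.20, 38.88, 10.0
	m2, s2, n2 := 789.50, 39.98, 10.0

	// Welch's unequal-variances t statistic.
	se := math.Sqrt(s1*s1/n1 + s2*s2/n2)
	t := (m1 - m2) / se

	// Welch-Satterthwaite degrees of freedom.
	df := math.Pow(s1*s1/n1+s2*s2/n2, 2) /
		(math.Pow(s1*s1/n1, 2)/(n1-1) + math.Pow(s2*s2/n2, 2)/(n2-1))

	// Two-tailed p value from the Student's t distribution.
	dist := distuv.StudentsT{Mu: 0, Sigma: 1, Nu: df}
	p := 2 * dist.CDF(-math.Abs(t))

	fmt.Printf("t = %.3f, df = %.1f, p = %.3f\n", t, df, p)
}
```

Running this prints roughly t = -0.074 with p around 0.94, consistent with the value reported above.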
(cherry picked from commit 53ba3613b3500c1be8acac999683b5b3cfffde07) --- cli/ssh.go | 132 +++++++++++++++++++++++-- cli/ssh_internal_test.go | 85 ++++++++++++++++ cli/ssh_test.go | 151 ++++++++++++++++++++++------- codersdk/workspacesdk/agentconn.go | 4 +- testutil/rwconn.go | 36 +++++++ 5 files changed, 359 insertions(+), 49 deletions(-) create mode 100644 testutil/rwconn.go diff --git a/cli/ssh.go b/cli/ssh.go index 2025c1691b7d7..f9cc1be14c3b8 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "log" + "net" "net/http" "net/url" "os" @@ -66,6 +67,7 @@ func (r *RootCmd) ssh() *serpent.Command { stdio bool hostPrefix string hostnameSuffix string + forceNewTunnel bool forwardAgent bool forwardGPG bool identityAgent string @@ -85,6 +87,7 @@ func (r *RootCmd) ssh() *serpent.Command { containerUser string ) client := new(codersdk.Client) + wsClient := workspacesdk.New(client) cmd := &serpent.Command{ Annotations: workspaceCommand, Use: "ssh ", @@ -203,14 +206,14 @@ func (r *RootCmd) ssh() *serpent.Command { parsedEnv = append(parsedEnv, [2]string{k, v}) } - deploymentSSHConfig := codersdk.SSHConfigResponse{ + cliConfig := codersdk.SSHConfigResponse{ HostnamePrefix: hostPrefix, HostnameSuffix: hostnameSuffix, } workspace, workspaceAgent, err := findWorkspaceAndAgentByHostname( ctx, inv, client, - inv.Args[0], deploymentSSHConfig, disableAutostart) + inv.Args[0], cliConfig, disableAutostart) if err != nil { return err } @@ -275,10 +278,44 @@ func (r *RootCmd) ssh() *serpent.Command { return err } + // If we're in stdio mode, check to see if we can use Coder Connect. + // We don't support Coder Connect over non-stdio coder ssh yet. + if stdio && !forceNewTunnel { + connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx) + if err != nil { + return xerrors.Errorf("get agent connection info: %w", err) + } + coderConnectHost := fmt.Sprintf("%s.%s.%s.%s", + workspaceAgent.Name, workspace.Name, workspace.OwnerName, connInfo.HostnameSuffix) + exists, _ := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost) + if exists { + defer cancel() + + if networkInfoDir != "" { + if err := writeCoderConnectNetInfo(ctx, networkInfoDir); err != nil { + logger.Error(ctx, "failed to write coder connect net info file", slog.Error(err)) + } + } + + stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace) + defer stopPolling() + + usageAppName := getUsageAppName(usageApp) + if usageAppName != "" { + closeUsage := client.UpdateWorkspaceUsageWithBodyContext(ctx, workspace.ID, codersdk.PostWorkspaceUsageRequest{ + AgentID: workspaceAgent.ID, + AppName: usageAppName, + }) + defer closeUsage() + } + return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack) + } + } + if r.disableDirect { _, _ = fmt.Fprintln(inv.Stderr, "Direct connections disabled.") } - conn, err := workspacesdk.New(client). + conn, err := wsClient. DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{ Logger: logger, BlockEndpoints: r.disableDirect, @@ -660,6 +697,12 @@ func (r *RootCmd) ssh() *serpent.Command { Value: serpent.StringOf(&containerUser), Hidden: true, // Hidden until this features is at least in beta. 
}, + { + Flag: "force-new-tunnel", + Description: "Force the creation of a new tunnel to the workspace, even if the Coder Connect tunnel is available.", + Value: serpent.BoolOf(&forceNewTunnel), + Hidden: true, + }, sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)), } return cmd @@ -1372,12 +1415,13 @@ func setStatsCallback( } type sshNetworkStats struct { - P2P bool `json:"p2p"` - Latency float64 `json:"latency"` - PreferredDERP string `json:"preferred_derp"` - DERPLatency map[string]float64 `json:"derp_latency"` - UploadBytesSec int64 `json:"upload_bytes_sec"` - DownloadBytesSec int64 `json:"download_bytes_sec"` + P2P bool `json:"p2p"` + Latency float64 `json:"latency"` + PreferredDERP string `json:"preferred_derp"` + DERPLatency map[string]float64 `json:"derp_latency"` + UploadBytesSec int64 `json:"upload_bytes_sec"` + DownloadBytesSec int64 `json:"download_bytes_sec"` + UsingCoderConnect bool `json:"using_coder_connect"` } func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) { @@ -1448,6 +1492,76 @@ func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, }, nil } +type coderConnectDialerContextKey struct{} + +type coderConnectDialer interface { + DialContext(ctx context.Context, network, addr string) (net.Conn, error) +} + +func WithTestOnlyCoderConnectDialer(ctx context.Context, dialer coderConnectDialer) context.Context { + return context.WithValue(ctx, coderConnectDialerContextKey{}, dialer) +} + +func testOrDefaultDialer(ctx context.Context) coderConnectDialer { + dialer, ok := ctx.Value(coderConnectDialerContextKey{}).(coderConnectDialer) + if !ok || dialer == nil { + return &net.Dialer{} + } + return dialer +} + +func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack) error { + dialer := testOrDefaultDialer(ctx) + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return xerrors.Errorf("dial coder connect host: %w", err) + } + if err := stack.push("tcp conn", conn); err != nil { + return err + } + + agentssh.Bicopy(ctx, conn, &StdioRwc{ + Reader: stdin, + Writer: stdout, + }) + + return nil +} + +type StdioRwc struct { + io.Reader + io.Writer +} + +func (*StdioRwc) Close() error { + return nil +} + +func writeCoderConnectNetInfo(ctx context.Context, networkInfoDir string) error { + fs, ok := ctx.Value("fs").(afero.Fs) + if !ok { + fs = afero.NewOsFs() + } + // The VS Code extension obtains the PID of the SSH process to + // find the log file associated with a SSH session. + // + // We get the parent PID because it's assumed `ssh` is calling this + // command via the ProxyCommand SSH option. 
+ networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", os.Getppid())) + stats := &sshNetworkStats{ + UsingCoderConnect: true, + } + rawStats, err := json.Marshal(stats) + if err != nil { + return xerrors.Errorf("marshal network stats: %w", err) + } + err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0o600) + if err != nil { + return xerrors.Errorf("write network stats: %w", err) + } + return nil +} + // Converts workspace name input to owner/workspace.agent format // Possible valid input formats: // workspace diff --git a/cli/ssh_internal_test.go b/cli/ssh_internal_test.go index d5e4c049347b2..caee1ec25b710 100644 --- a/cli/ssh_internal_test.go +++ b/cli/ssh_internal_test.go @@ -3,13 +3,17 @@ package cli import ( "context" "fmt" + "io" + "net" "net/url" "sync" "testing" "time" + gliderssh "github.com/gliderlabs/ssh" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" "golang.org/x/xerrors" "cdr.dev/slog" @@ -220,6 +224,87 @@ func TestCloserStack_Timeout(t *testing.T) { testutil.TryReceive(ctx, t, closed) } +func TestCoderConnectStdio(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + stack := newCloserStack(ctx, logger, quartz.NewMock(t)) + + clientOutput, clientInput := io.Pipe() + serverOutput, serverInput := io.Pipe() + defer func() { + for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} { + _ = c.Close() + } + }() + + server := newSSHServer("127.0.0.1:0") + ln, err := net.Listen("tcp", server.server.Addr) + require.NoError(t, err) + + go func() { + _ = server.Serve(ln) + }() + t.Cleanup(func() { + _ = server.Close() + }) + + stdioDone := make(chan struct{}) + go func() { + err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack) + assert.NoError(t, err) + close(stdioDone) + }() + + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ + Reader: serverOutput, + Writer: clientInput, + }, "", &ssh.ClientConfig{ + // #nosec + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + require.NoError(t, err) + defer conn.Close() + + sshClient := ssh.NewClient(conn, channels, requests) + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() + + // We're not connected to a real shell + err = session.Run("") + require.NoError(t, err) + err = sshClient.Close() + require.NoError(t, err) + _ = clientOutput.Close() + + <-stdioDone +} + +type sshServer struct { + server *gliderssh.Server +} + +func newSSHServer(addr string) *sshServer { + return &sshServer{ + server: &gliderssh.Server{ + Addr: addr, + Handler: func(s gliderssh.Session) { + _, _ = io.WriteString(s.Stderr(), "Connected!") + }, + }, + } +} + +func (s *sshServer) Serve(ln net.Listener) error { + return s.server.Serve(ln) +} + +func (s *sshServer) Close() error { + return s.server.Close() +} + type fakeCloser struct { closes *[]*fakeCloser err error diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 2603c81e88cec..5fcb6205d5e45 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -41,6 +41,7 @@ import ( "github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/agenttest" agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/cli" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/coderd/coderdtest" @@ -473,7 +474,7 @@ func TestSSH(t *testing.T) { 
assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -542,7 +543,7 @@ func TestSSH(t *testing.T) { signer, err := agentssh.CoderSigner(keySeed) assert.NoError(t, err) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -605,7 +606,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -773,7 +774,7 @@ func TestSSH(t *testing.T) { // have access to the shell. _ = agenttest.New(t, client.URL, authToken) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: proxyCommandStdoutR, Writer: clientStdinW, }, "", &ssh.ClientConfig{ @@ -835,7 +836,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -894,7 +895,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -1082,7 +1083,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -1741,7 +1742,7 @@ func TestSSH(t *testing.T) { assert.NoError(t, err) }) - conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ Reader: serverOutput, Writer: clientInput, }, "", &ssh.ClientConfig{ @@ -2102,6 +2103,111 @@ func TestSSH_Container(t *testing.T) { }) } +func TestSSH_CoderConnect(t *testing.T) { + t.Parallel() + + t.Run("Enabled", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + fs := afero.NewMemMapFs() + //nolint:revive,staticcheck + ctx = context.WithValue(ctx, "fs", fs) + + client, workspace, agentToken := setupWorkspaceForAgent(t) + inv, root := clitest.New(t, "ssh", workspace.Name, "--network-info-dir", "/net", "--stdio") + clitest.SetupConfig(t, client, root) + _ = ptytest.New(t).Attach(inv) + + ctx = cli.WithTestOnlyCoderConnectDialer(ctx, &fakeCoderConnectDialer{}) + ctx = withCoderConnectRunning(ctx) + + errCh := make(chan error, 1) + tGo(t, func() { + err := inv.WithContext(ctx).Run() + errCh <- err + }) + + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + err := testutil.TryReceive(ctx, t, errCh) + // Our mock dialer will always fail with this error, if it was called + require.ErrorContains(t, err, "dial coder connect host \"dev.myworkspace.myuser.coder:22\" over tcp") + + // The network 
info file should be created since we passed `--stdio` + entries, err := afero.ReadDir(fs, "/net") + require.NoError(t, err) + require.True(t, len(entries) > 0) + }) + + t.Run("Disabled", func(t *testing.T) { + t.Parallel() + client, workspace, agentToken := setupWorkspaceForAgent(t) + + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + clientOutput, clientInput := io.Pipe() + serverOutput, serverInput := io.Pipe() + defer func() { + for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} { + _ = c.Close() + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + inv, root := clitest.New(t, "ssh", "--force-new-tunnel", "--stdio", workspace.Name) + clitest.SetupConfig(t, client, root) + inv.Stdin = clientOutput + inv.Stdout = serverInput + inv.Stderr = io.Discard + + ctx = cli.WithTestOnlyCoderConnectDialer(ctx, &fakeCoderConnectDialer{}) + ctx = withCoderConnectRunning(ctx) + + cmdDone := tGo(t, func() { + err := inv.WithContext(ctx).Run() + // Shouldn't fail to dial the Coder Connect host + // since `--force-new-tunnel` was passed + assert.NoError(t, err) + }) + + conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{ + Reader: serverOutput, + Writer: clientInput, + }, "", &ssh.ClientConfig{ + // #nosec + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + require.NoError(t, err) + defer conn.Close() + + sshClient := ssh.NewClient(conn, channels, requests) + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() + + // Shells on Mac, Windows, and Linux all exit shells with the "exit" command. + err = session.Run("exit") + require.NoError(t, err) + err = sshClient.Close() + require.NoError(t, err) + _ = clientOutput.Close() + + <-cmdDone + }) +} + +type fakeCoderConnectDialer struct{} + +func (*fakeCoderConnectDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, xerrors.Errorf("dial coder connect host %q over %s", addr, network) +} + // tGoContext runs fn in a goroutine passing a context that will be // canceled on test completion and wait until fn has finished executing. // Done and cancel are returned for optionally waiting until completion @@ -2145,35 +2251,6 @@ func tGo(t *testing.T, fn func()) (done <-chan struct{}) { return doneC } -type stdioConn struct { - io.Reader - io.Writer -} - -func (*stdioConn) Close() (err error) { - return nil -} - -func (*stdioConn) LocalAddr() net.Addr { - return nil -} - -func (*stdioConn) RemoteAddr() net.Addr { - return nil -} - -func (*stdioConn) SetDeadline(_ time.Time) error { - return nil -} - -func (*stdioConn) SetReadDeadline(_ time.Time) error { - return nil -} - -func (*stdioConn) SetWriteDeadline(_ time.Time) error { - return nil -} - // tempDirUnixSocket returns a temporary directory that can safely hold unix // sockets (probably). // diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index fa569080f7dd2..97b4268c68780 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -185,14 +185,12 @@ func (c *AgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, return c.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), port)) } -// SSHClient calls SSH to create a client that uses a weak cipher -// to improve throughput. 
+// SSHClient calls SSH to create a client func (c *AgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { return c.SSHClientOnPort(ctx, AgentSSHPort) } // SSHClientOnPort calls SSH to create a client on a specific port -// that uses a weak cipher to improve throughput. func (c *AgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() diff --git a/testutil/rwconn.go b/testutil/rwconn.go new file mode 100644 index 0000000000000..a731e9c3c0ab0 --- /dev/null +++ b/testutil/rwconn.go @@ -0,0 +1,36 @@ +package testutil + +import ( + "io" + "net" + "time" +) + +type ReaderWriterConn struct { + io.Reader + io.Writer +} + +func (*ReaderWriterConn) Close() (err error) { + return nil +} + +func (*ReaderWriterConn) LocalAddr() net.Addr { + return nil +} + +func (*ReaderWriterConn) RemoteAddr() net.Addr { + return nil +} + +func (*ReaderWriterConn) SetDeadline(_ time.Time) error { + return nil +} + +func (*ReaderWriterConn) SetReadDeadline(_ time.Time) error { + return nil +} + +func (*ReaderWriterConn) SetWriteDeadline(_ time.Time) error { + return nil +} From 813e631c1f84897b674ffbb981c4f89570a126c6 Mon Sep 17 00:00:00 2001 From: Ethan <39577870+ethanndickson@users.noreply.github.com> Date: Thu, 1 May 2025 16:53:13 +1000 Subject: [PATCH 02/11] fix: create directory before writing coder connect network info file (#17628) The regular network info file creation code also calls `Mkdirall`. Wasn't picked up in manual testing as I already had the `/net` folder in my VSCode. Wasn't picked up in automated testing because we use an in-memory FS, which for some reason does this implicitly. (cherry picked from commit c7fc7b91ec5cd7f108022ad3c511fa77c94f427e) --- cli/ssh.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cli/ssh.go b/cli/ssh.go index f9cc1be14c3b8..7c5bda073f973 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -1542,6 +1542,10 @@ func writeCoderConnectNetInfo(ctx context.Context, networkInfoDir string) error if !ok { fs = afero.NewOsFs() } + if err := fs.MkdirAll(networkInfoDir, 0o700); err != nil { + return xerrors.Errorf("mkdir: %w", err) + } + // The VS Code extension obtains the PID of the SSH process to // find the log file associated with a SSH session. // From 0d4b15ea081a157b23699e86fe8e5ee636438a8e Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Tue, 6 May 2025 16:00:16 +0200 Subject: [PATCH 03/11] feat: improve coder connect tunnel handling on reconnect (#17598) Closes https://github.com/coder/internal/issues/563 The [Coder Connect tunnel](https://github.com/coder/coder/blob/main/vpn/tunnel.go) receives workspace state from the Coder server over a [dRPC stream.](https://github.com/coder/coder/blob/114ba4593b2a82dfd41cdcb7fd6eb70d866e7b86/tailnet/controllers.go#L1029) When first connecting to this stream, the current state of the user's workspaces is received, with subsequent messages being diffs on top of that state. However, if the client disconnects from this stream, such as when the user's device is suspended, and then reconnects later, no mechanism exists for the tunnel to differentiate that message containing the entire initial state from another diff, and so that state is incorrectly applied as a diff. In practice: - Tunnel connects, receives a workspace update containing all the existing workspaces & agents. - Tunnel loses connection, but isn't completely stopped. 
- All the user's workspaces are restarted, producing a new set of agents. - Tunnel regains connection, and receives a workspace update containing all the existing workspaces & agents. - This initial update is incorrectly applied as a diff, with the Tunnel's state containing both the old & new agents. This PR introduces a solution in which tunnelUpdater, when created, sends a FreshState flag with the WorkspaceUpdate type. This flag is handled in the vpn tunnel in the following fashion: - Preserve existing Agents - Remove current Agents in the tunnel that are not present in the WorkspaceUpdate - Remove unreferenced Workspaces (cherry picked from commit 5f516ed135118ee721021c5722b62f3ea5c8cd2e) --- tailnet/controllers.go | 32 ++- tailnet/controllers_test.go | 9 +- vpn/tunnel.go | 77 +++++-- vpn/tunnel_internal_test.go | 396 ++++++++++++++++++++++++++++++++++++ 4 files changed, 498 insertions(+), 16 deletions(-) diff --git a/tailnet/controllers.go b/tailnet/controllers.go index 2328e19640a4d..b7d4e246a4bee 100644 --- a/tailnet/controllers.go +++ b/tailnet/controllers.go @@ -897,6 +897,21 @@ type Workspace struct { agents map[uuid.UUID]*Agent } +func (w *Workspace) Clone() Workspace { + agents := make(map[uuid.UUID]*Agent, len(w.agents)) + for k, v := range w.agents { + clone := v.Clone() + agents[k] = &clone + } + return Workspace{ + ID: w.ID, + Name: w.Name, + Status: w.Status, + ownerUsername: w.ownerUsername, + agents: agents, + } +} + type DNSNameOptions struct { Suffix string } @@ -1049,6 +1064,7 @@ func (t *tunnelUpdater) recvLoop() { t.logger.Debug(context.Background(), "tunnel updater recvLoop started") defer t.logger.Debug(context.Background(), "tunnel updater recvLoop done") defer close(t.recvLoopDone) + updateKind := Snapshot for { update, err := t.client.Recv() if err != nil { @@ -1061,8 +1077,10 @@ func (t *tunnelUpdater) recvLoop() { } t.logger.Debug(context.Background(), "got workspace update", slog.F("workspace_update", update), + slog.F("update_kind", updateKind), ) - err = t.handleUpdate(update) + err = t.handleUpdate(update, updateKind) + updateKind = Diff if err != nil { t.logger.Critical(context.Background(), "failed to handle workspace Update", slog.Error(err)) cErr := t.client.Close() @@ -1083,14 +1101,23 @@ type WorkspaceUpdate struct { UpsertedAgents []*Agent DeletedWorkspaces []*Workspace DeletedAgents []*Agent + Kind UpdateKind } +type UpdateKind int + +const ( + Diff UpdateKind = iota + Snapshot +) + func (w *WorkspaceUpdate) Clone() WorkspaceUpdate { clone := WorkspaceUpdate{ UpsertedWorkspaces: make([]*Workspace, len(w.UpsertedWorkspaces)), UpsertedAgents: make([]*Agent, len(w.UpsertedAgents)), DeletedWorkspaces: make([]*Workspace, len(w.DeletedWorkspaces)), DeletedAgents: make([]*Agent, len(w.DeletedAgents)), + Kind: w.Kind, } for i, ws := range w.UpsertedWorkspaces { clone.UpsertedWorkspaces[i] = &Workspace{ @@ -1115,7 +1142,7 @@ func (w *WorkspaceUpdate) Clone() WorkspaceUpdate { return clone } -func (t *tunnelUpdater) handleUpdate(update *proto.WorkspaceUpdate) error { +func (t *tunnelUpdater) handleUpdate(update *proto.WorkspaceUpdate, updateKind UpdateKind) error { t.Lock() defer t.Unlock() @@ -1124,6 +1151,7 @@ func (t *tunnelUpdater) handleUpdate(update *proto.WorkspaceUpdate) error { UpsertedAgents: []*Agent{}, DeletedWorkspaces: []*Workspace{}, DeletedAgents: []*Agent{}, + Kind: updateKind, } for _, uw := range update.UpsertedWorkspaces { diff --git a/tailnet/controllers_test.go b/tailnet/controllers_test.go index 67834de462655..bb5b543378ebe 100644 --- 
a/tailnet/controllers_test.go +++ b/tailnet/controllers_test.go @@ -1611,6 +1611,7 @@ func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) { }, DeletedWorkspaces: []*tailnet.Workspace{}, DeletedAgents: []*tailnet.Agent{}, + Kind: tailnet.Snapshot, } // And the callback @@ -1626,6 +1627,9 @@ func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) { slices.SortFunc(recvState.UpsertedAgents, func(a, b *tailnet.Agent) int { return strings.Compare(a.Name, b.Name) }) + // tunnel is still open, so it's a diff + currentState.Kind = tailnet.Diff + require.Equal(t, currentState, recvState) } @@ -1692,14 +1696,17 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) { }, DeletedWorkspaces: []*tailnet.Workspace{}, DeletedAgents: []*tailnet.Agent{}, + Kind: tailnet.Snapshot, } cbUpdate := testutil.TryReceive(ctx, t, fUH.ch) require.Equal(t, initRecvUp, cbUpdate) - // Current state should match initial state, err := updateCtrl.CurrentState() require.NoError(t, err) + // tunnel is still open, so it's a diff + initRecvUp.Kind = tailnet.Diff + require.Equal(t, initRecvUp, state) // Send update that removes w1a1 and adds w1a2 diff --git a/vpn/tunnel.go b/vpn/tunnel.go index 63de203980d14..6c71aecaa0965 100644 --- a/vpn/tunnel.go +++ b/vpn/tunnel.go @@ -88,6 +88,7 @@ func NewTunnel( netLoopDone: make(chan struct{}), uSendCh: s.sendCh, agents: map[uuid.UUID]tailnet.Agent{}, + workspaces: map[uuid.UUID]tailnet.Workspace{}, clock: quartz.NewReal(), }, } @@ -347,7 +348,9 @@ type updater struct { uSendCh chan<- *TunnelMessage // agents contains the agents that are currently connected to the tunnel. agents map[uuid.UUID]tailnet.Agent - conn Conn + // workspaces contains the workspaces to which agents are currently connected via the tunnel. + workspaces map[uuid.UUID]tailnet.Workspace + conn Conn clock quartz.Clock } @@ -397,6 +400,11 @@ func (u *updater) sendUpdateResponse(req *request[*TunnelMessage, *ManagerMessag // createPeerUpdateLocked creates a PeerUpdate message from a workspace update, populating // the network status of the agents. func (u *updater) createPeerUpdateLocked(update tailnet.WorkspaceUpdate) *PeerUpdate { + // if the update is a snapshot, we need to process the full state + if update.Kind == tailnet.Snapshot { + processSnapshotUpdate(&update, u.agents, u.workspaces) + } + out := &PeerUpdate{ UpsertedWorkspaces: make([]*Workspace, len(update.UpsertedWorkspaces)), UpsertedAgents: make([]*Agent, len(update.UpsertedAgents)), @@ -404,7 +412,20 @@ func (u *updater) createPeerUpdateLocked(update tailnet.WorkspaceUpdate) *PeerUp DeletedAgents: make([]*Agent, len(update.DeletedAgents)), } - u.saveUpdateLocked(update) + // save the workspace update to the tunnel's state, such that it can + // be used to populate automated peer updates. 
+ for _, agent := range update.UpsertedAgents { + u.agents[agent.ID] = agent.Clone() + } + for _, agent := range update.DeletedAgents { + delete(u.agents, agent.ID) + } + for _, workspace := range update.UpsertedWorkspaces { + u.workspaces[workspace.ID] = workspace.Clone() + } + for _, workspace := range update.DeletedWorkspaces { + delete(u.workspaces, workspace.ID) + } for i, ws := range update.UpsertedWorkspaces { out.UpsertedWorkspaces[i] = &Workspace{ @@ -413,6 +434,7 @@ func (u *updater) createPeerUpdateLocked(update tailnet.WorkspaceUpdate) *PeerUp Status: Workspace_Status(ws.Status), } } + upsertedAgents := u.convertAgentsLocked(update.UpsertedAgents) out.UpsertedAgents = upsertedAgents for i, ws := range update.DeletedWorkspaces { @@ -472,17 +494,6 @@ func (u *updater) convertAgentsLocked(agents []*tailnet.Agent) []*Agent { return out } -// saveUpdateLocked saves the workspace update to the tunnel's state, such that it can -// be used to populate automated peer updates. -func (u *updater) saveUpdateLocked(update tailnet.WorkspaceUpdate) { - for _, agent := range update.UpsertedAgents { - u.agents[agent.ID] = agent.Clone() - } - for _, agent := range update.DeletedAgents { - delete(u.agents, agent.ID) - } -} - // setConn sets the `conn` and returns false if there's already a connection set. func (u *updater) setConn(conn Conn) bool { u.mu.Lock() @@ -552,6 +563,46 @@ func (u *updater) netStatusLoop() { } } +// processSnapshotUpdate handles the logic when a full state update is received. +// When the tunnel is live, we only receive diffs, but the first packet on any given +// reconnect to the tailnet API is a full state. +// Without this logic we weren't processing deletes for any workspaces or agents deleted +// while the client was disconnected while the computer was asleep. +func processSnapshotUpdate(update *tailnet.WorkspaceUpdate, agents map[uuid.UUID]tailnet.Agent, workspaces map[uuid.UUID]tailnet.Workspace) { + // ignoredWorkspaces is initially populated with the workspaces that are + // in the current update. Later on we populate it with the deleted workspaces too + // so that we don't send duplicate updates. Same applies to ignoredAgents. + ignoredWorkspaces := make(map[uuid.UUID]struct{}, len(update.UpsertedWorkspaces)) + ignoredAgents := make(map[uuid.UUID]struct{}, len(update.UpsertedAgents)) + + for _, workspace := range update.UpsertedWorkspaces { + ignoredWorkspaces[workspace.ID] = struct{}{} + } + for _, agent := range update.UpsertedAgents { + ignoredAgents[agent.ID] = struct{}{} + } + for _, agent := range agents { + if _, present := ignoredAgents[agent.ID]; !present { + // delete any current agents that are not in the new update + update.DeletedAgents = append(update.DeletedAgents, &tailnet.Agent{ + ID: agent.ID, + Name: agent.Name, + WorkspaceID: agent.WorkspaceID, + }) + } + } + for _, workspace := range workspaces { + if _, present := ignoredWorkspaces[workspace.ID]; !present { + update.DeletedWorkspaces = append(update.DeletedWorkspaces, &tailnet.Workspace{ + ID: workspace.ID, + Name: workspace.Name, + Status: workspace.Status, + }) + ignoredWorkspaces[workspace.ID] = struct{}{} + } + } +} + // hostsToIPStrings returns a slice of all unique IP addresses in the values // of the given map. 
func hostsToIPStrings(hosts map[dnsname.FQDN][]netip.Addr) []string { diff --git a/vpn/tunnel_internal_test.go b/vpn/tunnel_internal_test.go index d1d7377361f79..2beba66d7a7d2 100644 --- a/vpn/tunnel_internal_test.go +++ b/vpn/tunnel_internal_test.go @@ -2,6 +2,7 @@ package vpn import ( "context" + "maps" "net" "net/netip" "net/url" @@ -292,6 +293,7 @@ func TestUpdater_createPeerUpdate(t *testing.T) { ctx: ctx, netLoopDone: make(chan struct{}), agents: map[uuid.UUID]tailnet.Agent{}, + workspaces: map[uuid.UUID]tailnet.Workspace{}, conn: newFakeConn(tailnet.WorkspaceUpdate{}, hsTime), } @@ -486,6 +488,212 @@ func TestTunnel_sendAgentUpdate(t *testing.T) { require.Equal(t, hsTime, req.msg.GetPeerUpdate().UpsertedAgents[0].LastHandshake.AsTime()) } +func TestTunnel_sendAgentUpdateReconnect(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + + mClock := quartz.NewMock(t) + + wID1 := uuid.UUID{1} + aID1 := uuid.UUID{2} + aID2 := uuid.UUID{3} + + hsTime := time.Now().Add(-time.Minute).UTC() + + client := newFakeClient(ctx, t) + conn := newFakeConn(tailnet.WorkspaceUpdate{}, hsTime) + + tun, mgr := setupTunnel(t, ctx, client, mClock) + errCh := make(chan error, 1) + var resp *TunnelMessage + go func() { + r, err := mgr.unaryRPC(ctx, &ManagerMessage{ + Msg: &ManagerMessage_Start{ + Start: &StartRequest{ + TunnelFileDescriptor: 2, + CoderUrl: "https://coder.example.com", + ApiToken: "fakeToken", + }, + }, + }) + resp = r + errCh <- err + }() + testutil.RequireSend(ctx, t, client.ch, conn) + err := testutil.TryReceive(ctx, t, errCh) + require.NoError(t, err) + _, ok := resp.Msg.(*TunnelMessage_Start) + require.True(t, ok) + + // Inform the tunnel of the initial state + err = tun.Update(tailnet.WorkspaceUpdate{ + UpsertedWorkspaces: []*tailnet.Workspace{ + { + ID: wID1, Name: "w1", Status: proto.Workspace_STARTING, + }, + }, + UpsertedAgents: []*tailnet.Agent{ + { + ID: aID1, + Name: "agent1", + WorkspaceID: wID1, + Hosts: map[dnsname.FQDN][]netip.Addr{ + "agent1.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")}, + }, + }, + }, + }) + require.NoError(t, err) + req := testutil.TryReceive(ctx, t, mgr.requests) + require.Nil(t, req.msg.Rpc) + require.NotNil(t, req.msg.GetPeerUpdate()) + require.Len(t, req.msg.GetPeerUpdate().UpsertedAgents, 1) + require.Equal(t, aID1[:], req.msg.GetPeerUpdate().UpsertedAgents[0].Id) + + // Upsert a new agent simulating a reconnect + err = tun.Update(tailnet.WorkspaceUpdate{ + UpsertedWorkspaces: []*tailnet.Workspace{ + { + ID: wID1, Name: "w1", Status: proto.Workspace_STARTING, + }, + }, + UpsertedAgents: []*tailnet.Agent{ + { + ID: aID2, + Name: "agent2", + WorkspaceID: wID1, + Hosts: map[dnsname.FQDN][]netip.Addr{ + "agent2.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")}, + }, + }, + }, + Kind: tailnet.Snapshot, + }) + require.NoError(t, err) + + // The new update only contains the new agent + mClock.AdvanceNext() + req = testutil.TryReceive(ctx, t, mgr.requests) + require.Nil(t, req.msg.Rpc) + peerUpdate := req.msg.GetPeerUpdate() + require.NotNil(t, peerUpdate) + require.Len(t, peerUpdate.UpsertedAgents, 1) + require.Len(t, peerUpdate.DeletedAgents, 1) + require.Len(t, peerUpdate.DeletedWorkspaces, 0) + + require.Equal(t, aID2[:], peerUpdate.UpsertedAgents[0].Id) + require.Equal(t, hsTime, peerUpdate.UpsertedAgents[0].LastHandshake.AsTime()) + + require.Equal(t, aID1[:], peerUpdate.DeletedAgents[0].Id) +} + +func TestTunnel_sendAgentUpdateWorkspaceReconnect(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, 
testutil.WaitShort) + + mClock := quartz.NewMock(t) + + wID1 := uuid.UUID{1} + wID2 := uuid.UUID{2} + aID1 := uuid.UUID{3} + aID3 := uuid.UUID{4} + + hsTime := time.Now().Add(-time.Minute).UTC() + + client := newFakeClient(ctx, t) + conn := newFakeConn(tailnet.WorkspaceUpdate{}, hsTime) + + tun, mgr := setupTunnel(t, ctx, client, mClock) + errCh := make(chan error, 1) + var resp *TunnelMessage + go func() { + r, err := mgr.unaryRPC(ctx, &ManagerMessage{ + Msg: &ManagerMessage_Start{ + Start: &StartRequest{ + TunnelFileDescriptor: 2, + CoderUrl: "https://coder.example.com", + ApiToken: "fakeToken", + }, + }, + }) + resp = r + errCh <- err + }() + testutil.RequireSend(ctx, t, client.ch, conn) + err := testutil.TryReceive(ctx, t, errCh) + require.NoError(t, err) + _, ok := resp.Msg.(*TunnelMessage_Start) + require.True(t, ok) + + // Inform the tunnel of the initial state + err = tun.Update(tailnet.WorkspaceUpdate{ + UpsertedWorkspaces: []*tailnet.Workspace{ + { + ID: wID1, Name: "w1", Status: proto.Workspace_STARTING, + }, + }, + UpsertedAgents: []*tailnet.Agent{ + { + ID: aID1, + Name: "agent1", + WorkspaceID: wID1, + Hosts: map[dnsname.FQDN][]netip.Addr{ + "agent1.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")}, + }, + }, + }, + }) + require.NoError(t, err) + req := testutil.TryReceive(ctx, t, mgr.requests) + require.Nil(t, req.msg.Rpc) + require.NotNil(t, req.msg.GetPeerUpdate()) + require.Len(t, req.msg.GetPeerUpdate().UpsertedAgents, 1) + require.Equal(t, aID1[:], req.msg.GetPeerUpdate().UpsertedAgents[0].Id) + + // Upsert a new agent with a new workspace while simulating a reconnect + err = tun.Update(tailnet.WorkspaceUpdate{ + UpsertedWorkspaces: []*tailnet.Workspace{ + { + ID: wID2, Name: "w2", Status: proto.Workspace_STARTING, + }, + }, + UpsertedAgents: []*tailnet.Agent{ + { + ID: aID3, + Name: "agent3", + WorkspaceID: wID2, + Hosts: map[dnsname.FQDN][]netip.Addr{ + "agent3.coder.": {netip.MustParseAddr("fd60:627a:a42b:0101::")}, + }, + }, + }, + Kind: tailnet.Snapshot, + }) + require.NoError(t, err) + + // The new update only contains the new agent + mClock.AdvanceNext() + req = testutil.TryReceive(ctx, t, mgr.requests) + mClock.AdvanceNext() + + require.Nil(t, req.msg.Rpc) + peerUpdate := req.msg.GetPeerUpdate() + require.NotNil(t, peerUpdate) + require.Len(t, peerUpdate.UpsertedWorkspaces, 1) + require.Len(t, peerUpdate.UpsertedAgents, 1) + require.Len(t, peerUpdate.DeletedAgents, 1) + require.Len(t, peerUpdate.DeletedWorkspaces, 1) + + require.Equal(t, wID2[:], peerUpdate.UpsertedWorkspaces[0].Id) + require.Equal(t, aID3[:], peerUpdate.UpsertedAgents[0].Id) + require.Equal(t, hsTime, peerUpdate.UpsertedAgents[0].LastHandshake.AsTime()) + + require.Equal(t, aID1[:], peerUpdate.DeletedAgents[0].Id) + require.Equal(t, wID1[:], peerUpdate.DeletedWorkspaces[0].Id) +} + //nolint:revive // t takes precedence func setupTunnel(t *testing.T, ctx context.Context, client *fakeClient, mClock quartz.Clock) (*Tunnel, *speaker[*ManagerMessage, *TunnelMessage, TunnelMessage]) { mp, tp := net.Pipe() @@ -513,3 +721,191 @@ func setupTunnel(t *testing.T, ctx context.Context, client *fakeClient, mClock q mgr.start() return tun, mgr } + +func TestProcessFreshState(t *testing.T) { + t.Parallel() + + wsID1 := uuid.New() + wsID2 := uuid.New() + wsID3 := uuid.New() + wsID4 := uuid.New() + + agentID1 := uuid.New() + agentID2 := uuid.New() + agentID3 := uuid.New() + agentID4 := uuid.New() + + agent1 := tailnet.Agent{ID: agentID1, Name: "agent1", WorkspaceID: wsID1} + agent2 := tailnet.Agent{ID: 
agentID2, Name: "agent2", WorkspaceID: wsID2} + agent3 := tailnet.Agent{ID: agentID3, Name: "agent3", WorkspaceID: wsID3} + agent4 := tailnet.Agent{ID: agentID4, Name: "agent4", WorkspaceID: wsID1} + + ws1 := tailnet.Workspace{ID: wsID1, Name: "ws1", Status: proto.Workspace_RUNNING} + ws2 := tailnet.Workspace{ID: wsID2, Name: "ws2", Status: proto.Workspace_RUNNING} + ws3 := tailnet.Workspace{ID: wsID3, Name: "ws3", Status: proto.Workspace_RUNNING} + ws4 := tailnet.Workspace{ID: wsID4, Name: "ws4", Status: proto.Workspace_RUNNING} + + initialAgents := map[uuid.UUID]tailnet.Agent{ + agentID1: agent1, + agentID2: agent2, + agentID4: agent4, + } + initialWorkspaces := map[uuid.UUID]tailnet.Workspace{ + wsID1: ws1, + wsID2: ws2, + } + + tests := []struct { + name string + initialAgents map[uuid.UUID]tailnet.Agent + initialWorkspaces map[uuid.UUID]tailnet.Workspace + update *tailnet.WorkspaceUpdate + expectedDelete *tailnet.WorkspaceUpdate // We only care about deletions added by the function + }{ + { + name: "NoChange", + initialAgents: initialAgents, + initialWorkspaces: initialWorkspaces, + update: &tailnet.WorkspaceUpdate{ + Kind: tailnet.Snapshot, + UpsertedWorkspaces: []*tailnet.Workspace{&ws1, &ws2}, + UpsertedAgents: []*tailnet.Agent{&agent1, &agent2, &agent4}, + DeletedWorkspaces: []*tailnet.Workspace{}, + DeletedAgents: []*tailnet.Agent{}, + }, + expectedDelete: &tailnet.WorkspaceUpdate{ // Expect no *additional* deletions + DeletedWorkspaces: []*tailnet.Workspace{}, + DeletedAgents: []*tailnet.Agent{}, + }, + }, + { + name: "AgentAdded", // Agent 3 added in update + initialAgents: initialAgents, + initialWorkspaces: initialWorkspaces, + update: &tailnet.WorkspaceUpdate{ + Kind: tailnet.Snapshot, + UpsertedWorkspaces: []*tailnet.Workspace{&ws1, &ws2, &ws3}, + UpsertedAgents: []*tailnet.Agent{&agent1, &agent2, &agent3, &agent4}, + DeletedWorkspaces: []*tailnet.Workspace{}, + DeletedAgents: []*tailnet.Agent{}, + }, + expectedDelete: &tailnet.WorkspaceUpdate{ + DeletedWorkspaces: []*tailnet.Workspace{}, + DeletedAgents: []*tailnet.Agent{}, + }, + }, + { + name: "AgentRemovedWorkspaceAlsoRemoved", // Agent 2 removed, ws2 also removed + initialAgents: initialAgents, + initialWorkspaces: initialWorkspaces, + update: &tailnet.WorkspaceUpdate{ + Kind: tailnet.Snapshot, + UpsertedWorkspaces: []*tailnet.Workspace{&ws1}, // ws2 not present + UpsertedAgents: []*tailnet.Agent{&agent1, &agent4}, // agent2 not present + DeletedWorkspaces: []*tailnet.Workspace{}, + DeletedAgents: []*tailnet.Agent{}, + }, + expectedDelete: &tailnet.WorkspaceUpdate{ + DeletedWorkspaces: []*tailnet.Workspace{ + {ID: wsID2, Name: "ws2", Status: proto.Workspace_RUNNING}, + }, // Expect ws2 to be deleted + DeletedAgents: []*tailnet.Agent{ // Expect agent2 to be deleted + {ID: agentID2, Name: "agent2", WorkspaceID: wsID2}, + }, + }, + }, + { + name: "AgentRemovedWorkspaceStays", // Agent 4 removed, but ws1 stays (due to agent1) + initialAgents: initialAgents, + initialWorkspaces: initialWorkspaces, + update: &tailnet.WorkspaceUpdate{ + Kind: tailnet.Snapshot, + UpsertedWorkspaces: []*tailnet.Workspace{&ws1, &ws2}, // ws1 still present + UpsertedAgents: []*tailnet.Agent{&agent1, &agent2}, // agent4 not present + DeletedWorkspaces: []*tailnet.Workspace{}, + DeletedAgents: []*tailnet.Agent{}, + }, + expectedDelete: &tailnet.WorkspaceUpdate{ + DeletedWorkspaces: []*tailnet.Workspace{}, // ws1 should NOT be deleted + DeletedAgents: []*tailnet.Agent{ // Expect agent4 to be deleted + {ID: agentID4, Name: "agent4", WorkspaceID: 
wsID1}, + }, + }, + }, + { + name: "InitialAgentsEmpty", + initialAgents: map[uuid.UUID]tailnet.Agent{}, // Start with no agents known + initialWorkspaces: map[uuid.UUID]tailnet.Workspace{}, + update: &tailnet.WorkspaceUpdate{ + Kind: tailnet.Snapshot, + UpsertedWorkspaces: []*tailnet.Workspace{&ws1, &ws2}, + UpsertedAgents: []*tailnet.Agent{&agent1, &agent2}, + DeletedWorkspaces: []*tailnet.Workspace{}, + DeletedAgents: []*tailnet.Agent{}, + }, + expectedDelete: &tailnet.WorkspaceUpdate{ // Expect no deletions added + DeletedWorkspaces: []*tailnet.Workspace{}, + DeletedAgents: []*tailnet.Agent{}, + }, + }, + { + name: "UpdateEmpty", // Snapshot says nothing exists + initialAgents: initialAgents, + initialWorkspaces: initialWorkspaces, + update: &tailnet.WorkspaceUpdate{ + Kind: tailnet.Snapshot, + UpsertedWorkspaces: []*tailnet.Workspace{}, + UpsertedAgents: []*tailnet.Agent{}, + DeletedWorkspaces: []*tailnet.Workspace{}, + DeletedAgents: []*tailnet.Agent{}, + }, + expectedDelete: &tailnet.WorkspaceUpdate{ // Expect all initial agents/workspaces to be deleted + DeletedWorkspaces: []*tailnet.Workspace{ + {ID: wsID1, Name: "ws1", Status: proto.Workspace_RUNNING}, + {ID: wsID2, Name: "ws2", Status: proto.Workspace_RUNNING}, + }, // ws1 and ws2 deleted + DeletedAgents: []*tailnet.Agent{ // agent1, agent2, agent4 deleted + {ID: agentID1, Name: "agent1", WorkspaceID: wsID1}, + {ID: agentID2, Name: "agent2", WorkspaceID: wsID2}, + {ID: agentID4, Name: "agent4", WorkspaceID: wsID1}, + }, + }, + }, + { + name: "WorkspaceWithNoAgents", // Snapshot says nothing exists + initialAgents: initialAgents, + initialWorkspaces: map[uuid.UUID]tailnet.Workspace{wsID1: ws1, wsID2: ws2, wsID4: ws4}, // ws4 has no agents + update: &tailnet.WorkspaceUpdate{ + Kind: tailnet.Snapshot, + UpsertedWorkspaces: []*tailnet.Workspace{&ws1, &ws2}, + UpsertedAgents: []*tailnet.Agent{&agent1, &agent2, &agent4}, + DeletedWorkspaces: []*tailnet.Workspace{}, + DeletedAgents: []*tailnet.Agent{}, + }, + expectedDelete: &tailnet.WorkspaceUpdate{ // Expect all initial agents/workspaces to be deleted + DeletedWorkspaces: []*tailnet.Workspace{ + {ID: wsID4, Name: "ws4", Status: proto.Workspace_RUNNING}, + }, // ws4 should be deleted + DeletedAgents: []*tailnet.Agent{}, + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + agentsCopy := make(map[uuid.UUID]tailnet.Agent) + maps.Copy(agentsCopy, tt.initialAgents) + + workspaceCopy := make(map[uuid.UUID]tailnet.Workspace) + maps.Copy(workspaceCopy, tt.initialWorkspaces) + + processSnapshotUpdate(tt.update, agentsCopy, workspaceCopy) + + require.ElementsMatch(t, tt.expectedDelete.DeletedAgents, tt.update.DeletedAgents, "DeletedAgents mismatch") + require.ElementsMatch(t, tt.expectedDelete.DeletedWorkspaces, tt.update.DeletedWorkspaces, "DeletedWorkspaces mismatch") + }) + } +} From f51e0277017bc296be0432d7c762e4e0e5f4440a Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Wed, 30 Apr 2025 16:26:30 +0200 Subject: [PATCH 04/11] fix: fix prebuild omissions (#17579) Fixes accidental omission from https://github.com/coder/coder/pull/17527 --------- Signed-off-by: Danny Kopping (cherry picked from commit 6936a7b5a25659bc776cec0dd9a8b4b82600d292) --- enterprise/coderd/coderd.go | 2 +- enterprise/coderd/coderd_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index ca3531b60db78..8b473e8168ffa 100644 --- a/enterprise/coderd/coderd.go +++ 
b/enterprise/coderd/coderd.go @@ -1166,5 +1166,5 @@ func (api *API) setupPrebuilds(featureEnabled bool) (agplprebuilds.Reconciliatio reconciler := prebuilds.NewStoreReconciler(api.Database, api.Pubsub, api.DeploymentValues.Prebuilds, api.Logger.Named("prebuilds"), quartz.NewReal(), api.PrometheusRegistry) - return reconciler, prebuilds.EnterpriseClaimer{} + return reconciler, prebuilds.NewEnterpriseClaimer(api.Database) } diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index 4a3c47e56a671..446fce042d70f 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -331,7 +331,7 @@ func TestEntitlements_Prebuilds(t *testing.T) { if tc.expectedEnabled { require.IsType(t, &prebuilds.StoreReconciler{}, *reconciler) - require.IsType(t, prebuilds.EnterpriseClaimer{}, *claimer) + require.IsType(t, &prebuilds.EnterpriseClaimer{}, *claimer) } else { require.Equal(t, &agplprebuilds.DefaultReconciler, reconciler) require.Equal(t, &agplprebuilds.DefaultClaimer, claimer) From 795fba976c86a84197c2017088981e001a96f0b0 Mon Sep 17 00:00:00 2001 From: Yevhenii Shcherbina Date: Thu, 1 May 2025 17:26:30 -0400 Subject: [PATCH 05/11] fix: fix bug with deletion of prebuilt workspaces (#17652) Don't specify the template version for a delete transition, because the prebuilt workspace may have been created using an older template version. If the template version isn't explicitly set, the builder will automatically use the version from the last workspace build - which is the desired behavior. (cherry picked from commit ef11d4f769abf48c9b55624ade7cb07152374df9) --- enterprise/coderd/prebuilds/reconcile.go | 15 ++-- enterprise/coderd/prebuilds/reconcile_test.go | 69 +++++++++++++++++++ 2 files changed, 79 insertions(+), 5 deletions(-) diff --git a/enterprise/coderd/prebuilds/reconcile.go b/enterprise/coderd/prebuilds/reconcile.go index 1b99e46a56680..5639678c1b9db 100644 --- a/enterprise/coderd/prebuilds/reconcile.go +++ b/enterprise/coderd/prebuilds/reconcile.go @@ -549,13 +549,18 @@ func (c *StoreReconciler) provision( builder := wsbuilder.New(workspace, transition). Reason(database.BuildReasonInitiator). Initiator(prebuilds.SystemUserID). - VersionID(template.ActiveVersionID). - MarkPrebuild(). - TemplateVersionPresetID(presetID) + MarkPrebuild() - // We only inject the required params when the prebuild is being created. - // This mirrors the behavior of regular workspace deletion (see cli/delete.go). if transition != database.WorkspaceTransitionDelete { + // We don't specify the version for a delete transition, + // because the prebuilt workspace may have been created using an older template version. + // If the version isn't explicitly set, the builder will automatically use the version + // from the last workspace build — which is the desired behavior. + builder = builder.VersionID(template.ActiveVersionID) + + // We only inject the required params when the prebuild is being created. + // This mirrors the behavior of regular workspace deletion (see cli/delete.go). 
+ builder = builder.TemplateVersionPresetID(presetID) builder = builder.RichParameterValues(params) } diff --git a/enterprise/coderd/prebuilds/reconcile_test.go b/enterprise/coderd/prebuilds/reconcile_test.go index 9783b215f185b..a44a589264c1d 100644 --- a/enterprise/coderd/prebuilds/reconcile_test.go +++ b/enterprise/coderd/prebuilds/reconcile_test.go @@ -554,6 +554,75 @@ func TestInvalidPreset(t *testing.T) { } } +func TestDeletionOfPrebuiltWorkspaceWithInvalidPreset(t *testing.T) { + t.Parallel() + + if !dbtestutil.WillUsePostgres() { + t.Skip("This test requires postgres") + } + + templateDeleted := false + + clock := quartz.NewMock(t) + ctx := testutil.Context(t, testutil.WaitShort) + cfg := codersdk.PrebuildsConfig{} + logger := slogtest.Make( + t, &slogtest.Options{IgnoreErrors: true}, + ).Leveled(slog.LevelDebug) + db, pubSub := dbtestutil.NewDB(t) + controller := prebuilds.NewStoreReconciler(db, pubSub, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry()) + + ownerID := uuid.New() + dbgen.User(t, db, database.User{ + ID: ownerID, + }) + org, template := setupTestDBTemplate(t, db, ownerID, templateDeleted) + templateVersionID := setupTestDBTemplateVersion(ctx, t, clock, db, pubSub, org.ID, ownerID, template.ID) + preset := setupTestDBPreset(t, db, templateVersionID, 1, uuid.New().String()) + prebuiltWorkspace := setupTestDBPrebuild( + t, + clock, + db, + pubSub, + database.WorkspaceTransitionStart, + database.ProvisionerJobStatusSucceeded, + org.ID, + preset, + template.ID, + templateVersionID, + ) + + workspaces, err := db.GetWorkspacesByTemplateID(ctx, template.ID) + require.NoError(t, err) + // make sure we have only one workspace + require.Equal(t, 1, len(workspaces)) + + // Create a new template version and mark it as active. + // This marks the previous template version as inactive. + templateVersionID = setupTestDBTemplateVersion(ctx, t, clock, db, pubSub, org.ID, ownerID, template.ID) + // Add required param, which is not set in preset. + // It means that creating of new prebuilt workspace will fail, but we should be able to clean up old prebuilt workspaces. + dbgen.TemplateVersionParameter(t, db, database.TemplateVersionParameter{ + TemplateVersionID: templateVersionID, + Name: "required-param", + Description: "required param which isn't set in preset", + Type: "bool", + DefaultValue: "", + Required: true, + }) + + // Old prebuilt workspace should be deleted. + require.NoError(t, controller.ReconcileAll(ctx)) + + builds, err := db.GetWorkspaceBuildsByWorkspaceID(ctx, database.GetWorkspaceBuildsByWorkspaceIDParams{ + WorkspaceID: prebuiltWorkspace.ID, + }) + require.NoError(t, err) + // Make sure old prebuild workspace was deleted, despite it contains required parameter which isn't set in preset. + require.Equal(t, 2, len(builds)) + require.Equal(t, database.WorkspaceTransitionDelete, builds[0].Transition) +} + func TestRunLoop(t *testing.T) { t.Parallel() From b4edac1c5c90f1e53938c003190db12c722eb89c Mon Sep 17 00:00:00 2001 From: Yevhenii Shcherbina Date: Thu, 1 May 2025 04:52:23 -0400 Subject: [PATCH 06/11] fix: fix for prebuilds claiming and deletion (#17624) PR contains: - fix for claiming & deleting prebuilds with immutable params - unit test for claiming scenario - unit test for deletion scenario The parameter resolver was failing when deleting/claiming prebuilds because a value for a previously-used parameter was provided to the resolver, but since the value was unchanged (it's coming from the preset) it failed in the resolver. 
The resolver was missing a check to see if the old value != new value; if the values match then there's no mutation of an immutable parameter. --------- Signed-off-by: Danny Kopping (cherry picked from commit 98e5611e164ca889e99fab0aa4c5870c07d181f8) --- codersdk/richparameters.go | 2 +- codersdk/richparameters_test.go | 57 ++++++++++++++++--- enterprise/coderd/prebuilds/claim_test.go | 16 +++++- enterprise/coderd/prebuilds/reconcile_test.go | 19 ++++++- 4 files changed, 81 insertions(+), 13 deletions(-) diff --git a/codersdk/richparameters.go b/codersdk/richparameters.go index 2ddd5d00f6c41..6fc6b8e0c343f 100644 --- a/codersdk/richparameters.go +++ b/codersdk/richparameters.go @@ -164,7 +164,7 @@ type ParameterResolver struct { // resolves the correct value. It returns the value of the parameter, if valid, and an error if invalid. func (r *ParameterResolver) ValidateResolve(p TemplateVersionParameter, v *WorkspaceBuildParameter) (value string, err error) { prevV := r.findLastValue(p) - if !p.Mutable && v != nil && prevV != nil { + if !p.Mutable && v != nil && prevV != nil && v.Value != prevV.Value { return "", xerrors.Errorf("Parameter %q is not mutable, so it can't be updated after creating a workspace.", p.Name) } if p.Required && v == nil && prevV == nil { diff --git a/codersdk/richparameters_test.go b/codersdk/richparameters_test.go index 16365f7c2f416..5635a82beb6c6 100644 --- a/codersdk/richparameters_test.go +++ b/codersdk/richparameters_test.go @@ -1,6 +1,7 @@ package codersdk_test import ( + "fmt" "testing" "github.com/stretchr/testify/require" @@ -121,20 +122,60 @@ func TestParameterResolver_ValidateResolve_NewOverridesOld(t *testing.T) { func TestParameterResolver_ValidateResolve_Immutable(t *testing.T) { t.Parallel() uut := codersdk.ParameterResolver{ - Rich: []codersdk.WorkspaceBuildParameter{{Name: "n", Value: "5"}}, + Rich: []codersdk.WorkspaceBuildParameter{{Name: "n", Value: "old"}}, } p := codersdk.TemplateVersionParameter{ Name: "n", - Type: "number", + Type: "string", Required: true, Mutable: false, } - v, err := uut.ValidateResolve(p, &codersdk.WorkspaceBuildParameter{ - Name: "n", - Value: "6", - }) - require.Error(t, err) - require.Equal(t, "", v) + + cases := []struct { + name string + newValue string + expectedErr string + }{ + { + name: "mutation", + newValue: "new", // "new" != "old" + expectedErr: fmt.Sprintf("Parameter %q is not mutable", p.Name), + }, + { + // Values are case-sensitive. 
+ name: "case change", + newValue: "Old", // "Old" != "old" + expectedErr: fmt.Sprintf("Parameter %q is not mutable", p.Name), + }, + { + name: "default", + newValue: "", // "" != "old" + expectedErr: fmt.Sprintf("Parameter %q is not mutable", p.Name), + }, + { + name: "no change", + newValue: "old", // "old" == "old" + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + v, err := uut.ValidateResolve(p, &codersdk.WorkspaceBuildParameter{ + Name: "n", + Value: tc.newValue, + }) + + if tc.expectedErr == "" { + require.NoError(t, err) + require.Equal(t, tc.newValue, v) + } else { + require.ErrorContains(t, err, tc.expectedErr) + require.Equal(t, "", v) + } + }) + } } func TestRichParameterValidation(t *testing.T) { diff --git a/enterprise/coderd/prebuilds/claim_test.go b/enterprise/coderd/prebuilds/claim_test.go index 145095e6533e7..ad31d2b4eff1b 100644 --- a/enterprise/coderd/prebuilds/claim_test.go +++ b/enterprise/coderd/prebuilds/claim_test.go @@ -329,9 +329,8 @@ func TestClaimPrebuild(t *testing.T) { require.NoError(t, err) stopBuild, err := userClient.CreateWorkspaceBuild(ctx, workspace.ID, codersdk.CreateWorkspaceBuildRequest{ - TemplateVersionID: version.ID, - Transition: codersdk.WorkspaceTransitionStop, - RichParameterValues: wp, + TemplateVersionID: version.ID, + Transition: codersdk.WorkspaceTransitionStop, }) require.NoError(t, err) coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, stopBuild.ID) @@ -369,6 +368,17 @@ func templateWithAgentAndPresetsWithPrebuilds(desiredInstances int32) *echo.Resp }, }, }, + // Make sure immutable params don't break claiming logic + Parameters: []*proto.RichParameter{ + { + Name: "k1", + Description: "immutable param", + Type: "string", + DefaultValue: "", + Required: false, + Mutable: false, + }, + }, Presets: []*proto.Preset{ { Name: "preset-a", diff --git a/enterprise/coderd/prebuilds/reconcile_test.go b/enterprise/coderd/prebuilds/reconcile_test.go index a44a589264c1d..a1732c8391d11 100644 --- a/enterprise/coderd/prebuilds/reconcile_test.go +++ b/enterprise/coderd/prebuilds/reconcile_test.go @@ -970,6 +970,16 @@ func setupTestDBTemplateVersion( ID: templateID, ActiveVersionID: templateVersion.ID, })) + // Make sure immutable params don't break prebuilt workspace deletion logic + dbgen.TemplateVersionParameter(t, db, database.TemplateVersionParameter{ + TemplateVersionID: templateVersion.ID, + Name: "test", + Description: "required & immutable param", + Type: "string", + DefaultValue: "", + Required: true, + Mutable: false, + }) return templateVersion.ID } @@ -1068,7 +1078,7 @@ func setupTestDBWorkspace( OrganizationID: orgID, Error: buildError, }) - dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + workspaceBuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ WorkspaceID: workspace.ID, InitiatorID: initiatorID, TemplateVersionID: templateVersionID, @@ -1077,6 +1087,13 @@ func setupTestDBWorkspace( Transition: transition, CreatedAt: clock.Now(), }) + dbgen.WorkspaceBuildParameters(t, db, []database.WorkspaceBuildParameter{ + { + WorkspaceBuildID: workspaceBuild.ID, + Name: "test", + Value: "test", + }, + }) return workspace } From 3fd9c1af20ed05d5f3913195f290f7806178fc3d Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Fri, 2 May 2025 12:17:01 +0200 Subject: [PATCH 07/11] feat: collect database metrics (#17635) Currently we don't have a way to get insight into Postgres connections being exhausted. 
By using the prometheus' [`DBStats` collector](https://github.com/prometheus/client_golang/blob/main/prometheus/collectors/dbstats_collector.go), we get some insight out-of-the-box. ``` # HELP go_sql_idle_connections The number of idle connections. # TYPE go_sql_idle_connections gauge go_sql_idle_connections{db_name="coder"} 1 # HELP go_sql_in_use_connections The number of connections currently in use. # TYPE go_sql_in_use_connections gauge go_sql_in_use_connections{db_name="coder"} 2 # HELP go_sql_max_idle_closed_total The total number of connections closed due to SetMaxIdleConns. # TYPE go_sql_max_idle_closed_total counter go_sql_max_idle_closed_total{db_name="coder"} 112 # HELP go_sql_max_idle_time_closed_total The total number of connections closed due to SetConnMaxIdleTime. # TYPE go_sql_max_idle_time_closed_total counter go_sql_max_idle_time_closed_total{db_name="coder"} 0 # HELP go_sql_max_lifetime_closed_total The total number of connections closed due to SetConnMaxLifetime. # TYPE go_sql_max_lifetime_closed_total counter go_sql_max_lifetime_closed_total{db_name="coder"} 0 # HELP go_sql_max_open_connections Maximum number of open connections to the database. # TYPE go_sql_max_open_connections gauge go_sql_max_open_connections{db_name="coder"} 10 # HELP go_sql_open_connections The number of established connections both in use and idle. # TYPE go_sql_open_connections gauge go_sql_open_connections{db_name="coder"} 3 # HELP go_sql_wait_count_total The total number of connections waited for. # TYPE go_sql_wait_count_total counter go_sql_wait_count_total{db_name="coder"} 28 # HELP go_sql_wait_duration_seconds_total The total time blocked waiting for a new connection. # TYPE go_sql_wait_duration_seconds_total counter go_sql_wait_duration_seconds_total{db_name="coder"} 0.086936235 ``` `go_sql_wait_count_total` is the metric I'm most interested in gaining, but the others are also very useful. Changing the prefix is easy (`prometheus.WrapRegistererWithPrefix`), but getting rid of the `go_` segment is not quite so easy. I've kept the changeset small for now. **NOTE:** I imported a library to determine the database name from the given conn string. It's [not as simple](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING) as one might hope. The database name is used for the `db_name` label. --------- Signed-off-by: Danny Kopping (cherry picked from commit c27866221822e4112bddb43f8ad102a5589c98ab) --- cli/server.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/cli/server.go b/cli/server.go index 39cfa52571595..580dae369446c 100644 --- a/cli/server.go +++ b/cli/server.go @@ -739,6 +739,15 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. _ = sqlDB.Close() }() + if options.DeploymentValues.Prometheus.Enable { + // At this stage we don't think the database name serves much purpose in these metrics. + // It requires parsing the DSN to determine it, which requires pulling in another dependency + // (i.e. https://github.com/jackc/pgx), but it's rather heavy. + // The conn string (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING) can + // take different forms, which make parsing non-trivial. 
+ options.PrometheusRegistry.MustRegister(collectors.NewDBStatsCollector(sqlDB, "")) + } + options.Database = database.New(sqlDB) ps, err := pubsub.New(ctx, logger.Named("pubsub"), sqlDB, dbURL) if err != nil { From ed37b0295d7d63ef8575cc9af109e9b987045866 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Mon, 5 May 2025 11:54:18 +0200 Subject: [PATCH 08/11] fix: move pubsub publishing out of database transactions to avoid conn exhaustion (#17648) Database transactions hold onto connections, and `pubsub.Publish` tries to acquire a connection of its own. If the latter is called within a transaction, this can lead to connection exhaustion. I plan two follow-ups to this PR: 1. Make connection counts tuneable https://github.com/coder/coder/blob/main/cli/server.go#L2360-L2376 We will then be able to write tests showing how connection exhaustion occurs. 2. Write a linter/ruleguard to prevent `pubsub.Publish` from being called within a transaction. --------- Signed-off-by: Danny Kopping (cherry picked from commit a646478aed63996d5257e425ffd56318eda91457) --- enterprise/coderd/prebuilds/reconcile.go | 61 ++++-- enterprise/coderd/prebuilds/reconcile_test.go | 198 ++++++++++-------- 2 files changed, 156 insertions(+), 103 deletions(-) diff --git a/enterprise/coderd/prebuilds/reconcile.go b/enterprise/coderd/prebuilds/reconcile.go index 5639678c1b9db..c31da695637ba 100644 --- a/enterprise/coderd/prebuilds/reconcile.go +++ b/enterprise/coderd/prebuilds/reconcile.go @@ -40,10 +40,11 @@ type StoreReconciler struct { registerer prometheus.Registerer metrics *MetricsCollector - cancelFn context.CancelCauseFunc - running atomic.Bool - stopped atomic.Bool - done chan struct{} + cancelFn context.CancelCauseFunc + running atomic.Bool + stopped atomic.Bool + done chan struct{} + provisionNotifyCh chan database.ProvisionerJob } var _ prebuilds.ReconciliationOrchestrator = &StoreReconciler{} @@ -56,13 +57,14 @@ func NewStoreReconciler(store database.Store, registerer prometheus.Registerer, ) *StoreReconciler { reconciler := &StoreReconciler{ - store: store, - pubsub: ps, - logger: logger, - cfg: cfg, - clock: clock, - registerer: registerer, - done: make(chan struct{}, 1), + store: store, + pubsub: ps, + logger: logger, + cfg: cfg, + clock: clock, + registerer: registerer, + done: make(chan struct{}, 1), + provisionNotifyCh: make(chan database.ProvisionerJob, 10), } reconciler.metrics = NewMetricsCollector(store, logger, reconciler) @@ -100,6 +102,29 @@ func (c *StoreReconciler) Run(ctx context.Context) { // NOTE: without this atomic bool, Stop might race with Run for the c.cancelFn above. c.running.Store(true) + // Publish provisioning jobs outside of database transactions. + // A connection is held while a database transaction is active; PGPubsub also tries to acquire a new connection on + // Publish, so we can exhaust available connections. + // + // A single worker dequeues from the channel, which should be sufficient. + // If any messages are missed due to congestion or errors, provisionerdserver has a backup polling mechanism which + // will periodically pick up any queued jobs (see poll(time.Duration) in coderd/provisionerdserver/acquirer.go). 
+ go func() { + for { + select { + case <-c.done: + return + case <-ctx.Done(): + return + case job := <-c.provisionNotifyCh: + err := provisionerjobs.PostJob(c.pubsub, job) + if err != nil { + c.logger.Error(ctx, "failed to post provisioner job to pubsub", slog.Error(err)) + } + } + } + }() + for { select { // TODO: implement pubsub listener to allow reconciling a specific template imperatively once it has been changed, @@ -576,10 +601,16 @@ func (c *StoreReconciler) provision( return xerrors.Errorf("provision workspace: %w", err) } - err = provisionerjobs.PostJob(c.pubsub, *provisionerJob) - if err != nil { - // Client probably doesn't care about this error, so just log it. - c.logger.Error(ctx, "failed to post provisioner job to pubsub", slog.Error(err)) + if provisionerJob == nil { + return nil + } + + // Publish provisioner job event outside of transaction. + select { + case c.provisionNotifyCh <- *provisionerJob: + default: // channel full, drop the message; provisioner will pick this job up later with its periodic check, though. + c.logger.Warn(ctx, "provisioner job notification queue full, dropping", + slog.F("job_id", provisionerJob.ID), slog.F("prebuild_id", prebuildID.String())) } c.logger.Info(ctx, "prebuild job scheduled", slog.F("transition", transition), diff --git a/enterprise/coderd/prebuilds/reconcile_test.go b/enterprise/coderd/prebuilds/reconcile_test.go index a1732c8391d11..a1666134a7965 100644 --- a/enterprise/coderd/prebuilds/reconcile_test.go +++ b/enterprise/coderd/prebuilds/reconcile_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/prometheus/client_golang/prometheus" + "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/util/slice" @@ -303,100 +304,106 @@ func TestPrebuildReconciliation(t *testing.T) { for _, prebuildLatestTransition := range tc.prebuildLatestTransitions { for _, prebuildJobStatus := range tc.prebuildJobStatuses { for _, templateDeleted := range tc.templateDeleted { - t.Run(fmt.Sprintf("%s - %s - %s", tc.name, prebuildLatestTransition, prebuildJobStatus), func(t *testing.T) { - t.Parallel() - t.Cleanup(func() { - if t.Failed() { - t.Logf("failed to run test: %s", tc.name) - t.Logf("templateVersionActive: %t", templateVersionActive) - t.Logf("prebuildLatestTransition: %s", prebuildLatestTransition) - t.Logf("prebuildJobStatus: %s", prebuildJobStatus) + for _, useBrokenPubsub := range []bool{true, false} { + t.Run(fmt.Sprintf("%s - %s - %s - pubsub_broken=%v", tc.name, prebuildLatestTransition, prebuildJobStatus, useBrokenPubsub), func(t *testing.T) { + t.Parallel() + t.Cleanup(func() { + if t.Failed() { + t.Logf("failed to run test: %s", tc.name) + t.Logf("templateVersionActive: %t", templateVersionActive) + t.Logf("prebuildLatestTransition: %s", prebuildLatestTransition) + t.Logf("prebuildJobStatus: %s", prebuildJobStatus) + } + }) + clock := quartz.NewMock(t) + ctx := testutil.Context(t, testutil.WaitShort) + cfg := codersdk.PrebuildsConfig{} + logger := slogtest.Make( + t, &slogtest.Options{IgnoreErrors: true}, + ).Leveled(slog.LevelDebug) + db, pubSub := dbtestutil.NewDB(t) + + ownerID := uuid.New() + dbgen.User(t, db, database.User{ + ID: ownerID, + }) + org, template := setupTestDBTemplate(t, db, ownerID, templateDeleted) + templateVersionID := setupTestDBTemplateVersion( + ctx, + t, + clock, + db, + pubSub, + org.ID, + ownerID, + template.ID, + ) + preset := setupTestDBPreset( + t, + db, + templateVersionID, + 1, + uuid.New().String(), + ) + prebuild := setupTestDBPrebuild( + t, + 
clock, + db, + pubSub, + prebuildLatestTransition, + prebuildJobStatus, + org.ID, + preset, + template.ID, + templateVersionID, + ) + + if !templateVersionActive { + // Create a new template version and mark it as active + // This marks the template version that we care about as inactive + setupTestDBTemplateVersion(ctx, t, clock, db, pubSub, org.ID, ownerID, template.ID) } - }) - clock := quartz.NewMock(t) - ctx := testutil.Context(t, testutil.WaitShort) - cfg := codersdk.PrebuildsConfig{} - logger := slogtest.Make( - t, &slogtest.Options{IgnoreErrors: true}, - ).Leveled(slog.LevelDebug) - db, pubSub := dbtestutil.NewDB(t) - controller := prebuilds.NewStoreReconciler(db, pubSub, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry()) - - ownerID := uuid.New() - dbgen.User(t, db, database.User{ - ID: ownerID, - }) - org, template := setupTestDBTemplate(t, db, ownerID, templateDeleted) - templateVersionID := setupTestDBTemplateVersion( - ctx, - t, - clock, - db, - pubSub, - org.ID, - ownerID, - template.ID, - ) - preset := setupTestDBPreset( - t, - db, - templateVersionID, - 1, - uuid.New().String(), - ) - prebuild := setupTestDBPrebuild( - t, - clock, - db, - pubSub, - prebuildLatestTransition, - prebuildJobStatus, - org.ID, - preset, - template.ID, - templateVersionID, - ) - - if !templateVersionActive { - // Create a new template version and mark it as active - // This marks the template version that we care about as inactive - setupTestDBTemplateVersion(ctx, t, clock, db, pubSub, org.ID, ownerID, template.ID) - } - - // Run the reconciliation multiple times to ensure idempotency - // 8 was arbitrary, but large enough to reasonably trust the result - for i := 1; i <= 8; i++ { - require.NoErrorf(t, controller.ReconcileAll(ctx), "failed on iteration %d", i) - - if tc.shouldCreateNewPrebuild != nil { - newPrebuildCount := 0 - workspaces, err := db.GetWorkspacesByTemplateID(ctx, template.ID) - require.NoError(t, err) - for _, workspace := range workspaces { - if workspace.ID != prebuild.ID { - newPrebuildCount++ + + if useBrokenPubsub { + pubSub = &brokenPublisher{Pubsub: pubSub} + } + controller := prebuilds.NewStoreReconciler(db, pubSub, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry()) + + // Run the reconciliation multiple times to ensure idempotency + // 8 was arbitrary, but large enough to reasonably trust the result + for i := 1; i <= 8; i++ { + require.NoErrorf(t, controller.ReconcileAll(ctx), "failed on iteration %d", i) + + if tc.shouldCreateNewPrebuild != nil { + newPrebuildCount := 0 + workspaces, err := db.GetWorkspacesByTemplateID(ctx, template.ID) + require.NoError(t, err) + for _, workspace := range workspaces { + if workspace.ID != prebuild.ID { + newPrebuildCount++ + } } + // This test configures a preset that desires one prebuild. + // In cases where new prebuilds should be created, there should be exactly one. + require.Equal(t, *tc.shouldCreateNewPrebuild, newPrebuildCount == 1) } - // This test configures a preset that desires one prebuild. - // In cases where new prebuilds should be created, there should be exactly one. 
- require.Equal(t, *tc.shouldCreateNewPrebuild, newPrebuildCount == 1) - } - if tc.shouldDeleteOldPrebuild != nil { - builds, err := db.GetWorkspaceBuildsByWorkspaceID(ctx, database.GetWorkspaceBuildsByWorkspaceIDParams{ - WorkspaceID: prebuild.ID, - }) - require.NoError(t, err) - if *tc.shouldDeleteOldPrebuild { - require.Equal(t, 2, len(builds)) - require.Equal(t, database.WorkspaceTransitionDelete, builds[0].Transition) - } else { - require.Equal(t, 1, len(builds)) - require.Equal(t, prebuildLatestTransition, builds[0].Transition) + if tc.shouldDeleteOldPrebuild != nil { + builds, err := db.GetWorkspaceBuildsByWorkspaceID(ctx, database.GetWorkspaceBuildsByWorkspaceIDParams{ + WorkspaceID: prebuild.ID, + }) + require.NoError(t, err) + if *tc.shouldDeleteOldPrebuild { + require.Equal(t, 2, len(builds)) + require.Equal(t, database.WorkspaceTransitionDelete, builds[0].Transition) + } else { + require.Equal(t, 1, len(builds)) + require.Equal(t, prebuildLatestTransition, builds[0].Transition) + } } } - } - }) + }) + } } } } @@ -404,6 +411,21 @@ func TestPrebuildReconciliation(t *testing.T) { } } +// brokenPublisher is used to validate that Publish() calls which always fail do not affect the reconciler's behavior, +// since the messages published are not essential but merely advisory. +type brokenPublisher struct { + pubsub.Pubsub +} + +// Publish deliberately fails. +// I'm explicitly _not_ checking for EventJobPosted (coderd/database/provisionerjobs/provisionerjobs.go) since that +// requires too much knowledge of the underlying implementation. +func (*brokenPublisher) Publish(event string, _ []byte) error { + // Mimick some work being done. + <-time.After(testutil.IntervalFast) + return xerrors.Errorf("failed to publish %q", event) +} + func TestMultiplePresetsPerTemplateVersion(t *testing.T) { t.Parallel() From f3bbbefc8ea442148b619e98151c3edc99a1c6d2 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Tue, 13 May 2025 20:27:41 +0200 Subject: [PATCH 09/11] feat: fetch prebuilds metrics state in background (#17792) `Collect()` is called whenever the `/metrics` endpoint is hit to retrieve metrics. The queries used in prebuilds metrics collection are quite heavy, and we want to avoid having them running concurrently / too often to keep db load down. Here I'm moving towards a background retrieval of the state required to set the metrics, which gets invalidated every interval. Also introduces `coderd_prebuilt_workspaces_metrics_last_updated` which operators can use to determine when these metrics go stale. See https://github.com/coder/coder/pull/17789 as well. 
--------- Signed-off-by: Danny Kopping (cherry picked from commit b2a1de9e2a035e21f3d6d7a93b423bb3cc5671d0) --- .../coderd/prebuilds/metricscollector.go | 106 ++++++++++++++---- .../coderd/prebuilds/metricscollector_test.go | 5 + enterprise/coderd/prebuilds/reconcile.go | 22 +++- 3 files changed, 109 insertions(+), 24 deletions(-) diff --git a/enterprise/coderd/prebuilds/metricscollector.go b/enterprise/coderd/prebuilds/metricscollector.go index 7b55227effffa..76089c025243d 100644 --- a/enterprise/coderd/prebuilds/metricscollector.go +++ b/enterprise/coderd/prebuilds/metricscollector.go @@ -2,14 +2,17 @@ package prebuilds import ( "context" + "fmt" + "sync/atomic" "time" - "cdr.dev/slog" - "github.com/prometheus/client_golang/prometheus" + "golang.org/x/xerrors" + + "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/prebuilds" ) @@ -55,20 +58,34 @@ var ( labels, nil, ) + lastUpdateDesc = prometheus.NewDesc( + "coderd_prebuilt_workspaces_metrics_last_updated", + "The unix timestamp when the metrics related to prebuilt workspaces were last updated; these metrics are cached.", + []string{}, + nil, + ) +) + +const ( + metricsUpdateInterval = time.Second * 15 + metricsUpdateTimeout = time.Second * 10 ) type MetricsCollector struct { database database.Store logger slog.Logger snapshotter prebuilds.StateSnapshotter + + latestState atomic.Pointer[metricsState] } var _ prometheus.Collector = new(MetricsCollector) func NewMetricsCollector(db database.Store, logger slog.Logger, snapshotter prebuilds.StateSnapshotter) *MetricsCollector { + log := logger.Named("prebuilds_metrics_collector") return &MetricsCollector{ database: db, - logger: logger.Named("prebuilds_metrics_collector"), + logger: log, snapshotter: snapshotter, } } @@ -80,38 +97,34 @@ func (*MetricsCollector) Describe(descCh chan<- *prometheus.Desc) { descCh <- desiredPrebuildsDesc descCh <- runningPrebuildsDesc descCh <- eligiblePrebuildsDesc + descCh <- lastUpdateDesc } +// Collect uses the cached state to set configured metrics. +// The state is cached because this function can be called multiple times per second and retrieving the current state +// is an expensive operation. func (mc *MetricsCollector) Collect(metricsCh chan<- prometheus.Metric) { - // nolint:gocritic // We need to set an authz context to read metrics from the db. - ctx, cancel := context.WithTimeout(dbauthz.AsPrebuildsOrchestrator(context.Background()), 10*time.Second) - defer cancel() - prebuildMetrics, err := mc.database.GetPrebuildMetrics(ctx) - if err != nil { - mc.logger.Error(ctx, "failed to get prebuild metrics", slog.Error(err)) + currentState := mc.latestState.Load() // Grab a copy; it's ok if it goes stale during the course of this func. 
+ if currentState == nil { + mc.logger.Warn(context.Background(), "failed to set prebuilds metrics; state not set") + metricsCh <- prometheus.MustNewConstMetric(lastUpdateDesc, prometheus.GaugeValue, 0) return } - for _, metric := range prebuildMetrics { + for _, metric := range currentState.prebuildMetrics { metricsCh <- prometheus.MustNewConstMetric(createdPrebuildsDesc, prometheus.CounterValue, float64(metric.CreatedCount), metric.TemplateName, metric.PresetName, metric.OrganizationName) metricsCh <- prometheus.MustNewConstMetric(failedPrebuildsDesc, prometheus.CounterValue, float64(metric.FailedCount), metric.TemplateName, metric.PresetName, metric.OrganizationName) metricsCh <- prometheus.MustNewConstMetric(claimedPrebuildsDesc, prometheus.CounterValue, float64(metric.ClaimedCount), metric.TemplateName, metric.PresetName, metric.OrganizationName) } - snapshot, err := mc.snapshotter.SnapshotState(ctx, mc.database) - if err != nil { - mc.logger.Error(ctx, "failed to get latest prebuild state", slog.Error(err)) - return - } - - for _, preset := range snapshot.Presets { + for _, preset := range currentState.snapshot.Presets { if !preset.UsingActiveVersion { continue } - presetSnapshot, err := snapshot.FilterByPreset(preset.ID) + presetSnapshot, err := currentState.snapshot.FilterByPreset(preset.ID) if err != nil { - mc.logger.Error(ctx, "failed to filter by preset", slog.Error(err)) + mc.logger.Error(context.Background(), "failed to filter by preset", slog.Error(err)) continue } state := presetSnapshot.CalculateState() @@ -120,4 +133,57 @@ func (mc *MetricsCollector) Collect(metricsCh chan<- prometheus.Metric) { metricsCh <- prometheus.MustNewConstMetric(runningPrebuildsDesc, prometheus.GaugeValue, float64(state.Actual), preset.TemplateName, preset.Name, preset.OrganizationName) metricsCh <- prometheus.MustNewConstMetric(eligiblePrebuildsDesc, prometheus.GaugeValue, float64(state.Eligible), preset.TemplateName, preset.Name, preset.OrganizationName) } + + metricsCh <- prometheus.MustNewConstMetric(lastUpdateDesc, prometheus.GaugeValue, float64(currentState.createdAt.Unix())) +} + +type metricsState struct { + prebuildMetrics []database.GetPrebuildMetricsRow + snapshot *prebuilds.GlobalSnapshot + createdAt time.Time +} + +// BackgroundFetch updates the metrics state every given interval. +func (mc *MetricsCollector) BackgroundFetch(ctx context.Context, updateInterval, updateTimeout time.Duration) { + tick := time.NewTicker(time.Nanosecond) + defer tick.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-tick.C: + // Tick immediately, then set regular interval. + tick.Reset(updateInterval) + + if err := mc.UpdateState(ctx, updateTimeout); err != nil { + mc.logger.Error(ctx, "failed to update prebuilds metrics state", slog.Error(err)) + } + } + } +} + +// UpdateState builds the current metrics state. 
+func (mc *MetricsCollector) UpdateState(ctx context.Context, timeout time.Duration) error { + start := time.Now() + fetchCtx, fetchCancel := context.WithTimeout(ctx, timeout) + defer fetchCancel() + + prebuildMetrics, err := mc.database.GetPrebuildMetrics(fetchCtx) + if err != nil { + return xerrors.Errorf("fetch prebuild metrics: %w", err) + } + + snapshot, err := mc.snapshotter.SnapshotState(fetchCtx, mc.database) + if err != nil { + return xerrors.Errorf("snapshot state: %w", err) + } + mc.logger.Debug(ctx, "fetched prebuilds metrics state", slog.F("duration_secs", fmt.Sprintf("%.2f", time.Since(start).Seconds()))) + + mc.latestState.Store(&metricsState{ + prebuildMetrics: prebuildMetrics, + snapshot: snapshot, + createdAt: dbtime.Now(), + }) + return nil } diff --git a/enterprise/coderd/prebuilds/metricscollector_test.go b/enterprise/coderd/prebuilds/metricscollector_test.go index 859509ced6635..de3f5d017f715 100644 --- a/enterprise/coderd/prebuilds/metricscollector_test.go +++ b/enterprise/coderd/prebuilds/metricscollector_test.go @@ -16,6 +16,7 @@ import ( "github.com/coder/quartz" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" agplprebuilds "github.com/coder/coder/v2/coderd/prebuilds" @@ -248,6 +249,10 @@ func TestMetricsCollector(t *testing.T) { setupTestDBWorkspaceAgent(t, db, workspace.ID, eligible) } + // Force an update to the metrics state to allow the collector to collect fresh metrics. + // nolint:gocritic // Authz context needed to retrieve state. + require.NoError(t, collector.UpdateState(dbauthz.AsPrebuildsOrchestrator(ctx), testutil.WaitLong)) + metricsFamilies, err := registry.Gather() require.NoError(t, err) diff --git a/enterprise/coderd/prebuilds/reconcile.go b/enterprise/coderd/prebuilds/reconcile.go index c31da695637ba..79a8baa337e72 100644 --- a/enterprise/coderd/prebuilds/reconcile.go +++ b/enterprise/coderd/prebuilds/reconcile.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "math" + "sync" "sync/atomic" "time" @@ -67,10 +68,12 @@ func NewStoreReconciler(store database.Store, provisionNotifyCh: make(chan database.ProvisionerJob, 10), } - reconciler.metrics = NewMetricsCollector(store, logger, reconciler) - if err := registerer.Register(reconciler.metrics); err != nil { - // If the registerer fails to register the metrics collector, it's not fatal. - logger.Error(context.Background(), "failed to register prometheus metrics", slog.Error(err)) + if registerer != nil { + reconciler.metrics = NewMetricsCollector(store, logger, reconciler) + if err := registerer.Register(reconciler.metrics); err != nil { + // If the registerer fails to register the metrics collector, it's not fatal. + logger.Error(context.Background(), "failed to register prometheus metrics", slog.Error(err)) + } } return reconciler @@ -87,9 +90,11 @@ func (c *StoreReconciler) Run(ctx context.Context) { slog.F("backoff_interval", c.cfg.ReconciliationBackoffInterval.String()), slog.F("backoff_lookback", c.cfg.ReconciliationBackoffLookback.String())) + var wg sync.WaitGroup ticker := c.clock.NewTicker(reconciliationInterval) defer ticker.Stop() defer func() { + wg.Wait() c.done <- struct{}{} }() @@ -97,6 +102,15 @@ func (c *StoreReconciler) Run(ctx context.Context) { ctx, cancel := context.WithCancelCause(dbauthz.AsPrebuildsOrchestrator(ctx)) c.cancelFn = cancel + // Start updating metrics in the background. 
+ if c.metrics != nil { + wg.Add(1) + go func() { + defer wg.Done() + c.metrics.BackgroundFetch(ctx, metricsUpdateInterval, metricsUpdateTimeout) + }() + } + // Everything is in place, reconciler can now be considered as running. // // NOTE: without this atomic bool, Stop might race with Run for the c.cancelFn above. From 0d5d26892f4d43eed1a3c2226eb3553f4a03ae85 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Wed, 14 May 2025 04:51:01 +1000 Subject: [PATCH 10/11] chore: optimize workspace_latest_builds view query (#17789) Avoids two sequential scans of massive tables (`workspace_builds`, `provisioner_jobs`) and uses index scans instead. This new view largely replicates our already optimized query `GetWorkspaces` to fetch the latest build. The original query and the new query were compared against the dogfood database to ensure they return the exact same data in the exact same order (minus the new `workspaces.deleted = false` filter to improve performance even more). The performance is massively improved even without the `workspaces.deleted = false` filter, but it was added to improve it even more. Note: these query times are probably inflated due to high database load on our dogfood environment that this intends to partially resolve. Before: 2,139ms ([explain](https://explain.dalibo.com/plan/997e4fch241b46e6)) After: 33ms ([explain](https://explain.dalibo.com/plan/c888dc223870f181)) Co-authored-by: Cian Johnston --------- Signed-off-by: Danny Kopping Co-authored-by: Mathias Fredriksson Co-authored-by: Danny Kopping (cherry picked from commit ef745c0c5d3f4db643a9f4fac4503ddde8bb4cac) --- coderd/database/dump.sql | 77 ++++++++++------- ...kspace_latest_builds_optimization.down.sql | 58 +++++++++++++ ...orkspace_latest_builds_optimization.up.sql | 85 +++++++++++++++++++ 3 files changed, 188 insertions(+), 32 deletions(-) create mode 100644 coderd/database/migrations/000323_workspace_latest_builds_optimization.down.sql create mode 100644 coderd/database/migrations/000323_workspace_latest_builds_optimization.up.sql diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 83d998b2b9a3e..bb81b0d5e75b5 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1991,18 +1991,52 @@ CREATE VIEW workspace_build_with_user AS COMMENT ON VIEW workspace_build_with_user IS 'Joins in the username + avatar url of the initiated by user.'; +CREATE TABLE workspaces ( + id uuid NOT NULL, + created_at timestamp with time zone NOT NULL, + updated_at timestamp with time zone NOT NULL, + owner_id uuid NOT NULL, + organization_id uuid NOT NULL, + template_id uuid NOT NULL, + deleted boolean DEFAULT false NOT NULL, + name character varying(64) NOT NULL, + autostart_schedule text, + ttl bigint, + last_used_at timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, + dormant_at timestamp with time zone, + deleting_at timestamp with time zone, + automatic_updates automatic_updates DEFAULT 'never'::automatic_updates NOT NULL, + favorite boolean DEFAULT false NOT NULL, + next_start_at timestamp with time zone +); + +COMMENT ON COLUMN workspaces.favorite IS 'Favorite is true if the workspace owner has favorited the workspace.'; + CREATE VIEW workspace_latest_builds AS - SELECT DISTINCT ON (wb.workspace_id) wb.id, - wb.workspace_id, - wb.template_version_id, - wb.job_id, - wb.template_version_preset_id, - wb.transition, - wb.created_at, - pj.job_status - FROM (workspace_builds wb - JOIN provisioner_jobs pj ON ((wb.job_id = pj.id))) - ORDER BY 
wb.workspace_id, wb.build_number DESC; + SELECT latest_build.id, + latest_build.workspace_id, + latest_build.template_version_id, + latest_build.job_id, + latest_build.template_version_preset_id, + latest_build.transition, + latest_build.created_at, + latest_build.job_status + FROM (workspaces + LEFT JOIN LATERAL ( SELECT workspace_builds.id, + workspace_builds.workspace_id, + workspace_builds.template_version_id, + workspace_builds.job_id, + workspace_builds.template_version_preset_id, + workspace_builds.transition, + workspace_builds.created_at, + provisioner_jobs.job_status + FROM (workspace_builds + JOIN provisioner_jobs ON ((provisioner_jobs.id = workspace_builds.job_id))) + WHERE (workspace_builds.workspace_id = workspaces.id) + ORDER BY workspace_builds.build_number DESC + LIMIT 1) latest_build ON (true)) + WHERE (workspaces.deleted = false) + ORDER BY workspaces.id; CREATE TABLE workspace_modules ( id uuid NOT NULL, @@ -2039,27 +2073,6 @@ CREATE TABLE workspace_resources ( module_path text ); -CREATE TABLE workspaces ( - id uuid NOT NULL, - created_at timestamp with time zone NOT NULL, - updated_at timestamp with time zone NOT NULL, - owner_id uuid NOT NULL, - organization_id uuid NOT NULL, - template_id uuid NOT NULL, - deleted boolean DEFAULT false NOT NULL, - name character varying(64) NOT NULL, - autostart_schedule text, - ttl bigint, - last_used_at timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL, - dormant_at timestamp with time zone, - deleting_at timestamp with time zone, - automatic_updates automatic_updates DEFAULT 'never'::automatic_updates NOT NULL, - favorite boolean DEFAULT false NOT NULL, - next_start_at timestamp with time zone -); - -COMMENT ON COLUMN workspaces.favorite IS 'Favorite is true if the workspace owner has favorited the workspace.'; - CREATE VIEW workspace_prebuilds AS WITH all_prebuilds AS ( SELECT w.id, diff --git a/coderd/database/migrations/000323_workspace_latest_builds_optimization.down.sql b/coderd/database/migrations/000323_workspace_latest_builds_optimization.down.sql new file mode 100644 index 0000000000000..9d9ae7aff4bd9 --- /dev/null +++ b/coderd/database/migrations/000323_workspace_latest_builds_optimization.down.sql @@ -0,0 +1,58 @@ +DROP VIEW workspace_prebuilds; +DROP VIEW workspace_latest_builds; + +-- Revert to previous version from 000314_prebuilds.up.sql +CREATE VIEW workspace_latest_builds AS +SELECT DISTINCT ON (workspace_id) + wb.id, + wb.workspace_id, + wb.template_version_id, + wb.job_id, + wb.template_version_preset_id, + wb.transition, + wb.created_at, + pj.job_status +FROM workspace_builds wb + INNER JOIN provisioner_jobs pj ON wb.job_id = pj.id +ORDER BY wb.workspace_id, wb.build_number DESC; + +-- Recreate the dependent views +CREATE VIEW workspace_prebuilds AS + WITH all_prebuilds AS ( + SELECT w.id, + w.name, + w.template_id, + w.created_at + FROM workspaces w + WHERE (w.owner_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid) + ), workspaces_with_latest_presets AS ( + SELECT DISTINCT ON (workspace_builds.workspace_id) workspace_builds.workspace_id, + workspace_builds.template_version_preset_id + FROM workspace_builds + WHERE (workspace_builds.template_version_preset_id IS NOT NULL) + ORDER BY workspace_builds.workspace_id, workspace_builds.build_number DESC + ), workspaces_with_agents_status AS ( + SELECT w.id AS workspace_id, + bool_and((wa.lifecycle_state = 'ready'::workspace_agent_lifecycle_state)) AS ready + FROM (((workspaces w + JOIN workspace_latest_builds wlb ON 
((wlb.workspace_id = w.id))) + JOIN workspace_resources wr ON ((wr.job_id = wlb.job_id))) + JOIN workspace_agents wa ON ((wa.resource_id = wr.id))) + WHERE (w.owner_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid) + GROUP BY w.id + ), current_presets AS ( + SELECT w.id AS prebuild_id, + wlp.template_version_preset_id + FROM (workspaces w + JOIN workspaces_with_latest_presets wlp ON ((wlp.workspace_id = w.id))) + WHERE (w.owner_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid) + ) + SELECT p.id, + p.name, + p.template_id, + p.created_at, + COALESCE(a.ready, false) AS ready, + cp.template_version_preset_id AS current_preset_id + FROM ((all_prebuilds p + LEFT JOIN workspaces_with_agents_status a ON ((a.workspace_id = p.id))) + JOIN current_presets cp ON ((cp.prebuild_id = p.id))); diff --git a/coderd/database/migrations/000323_workspace_latest_builds_optimization.up.sql b/coderd/database/migrations/000323_workspace_latest_builds_optimization.up.sql new file mode 100644 index 0000000000000..d65e09ef47339 --- /dev/null +++ b/coderd/database/migrations/000323_workspace_latest_builds_optimization.up.sql @@ -0,0 +1,85 @@ +-- Drop the dependent views +DROP VIEW workspace_prebuilds; +-- Previously created in 000314_prebuilds.up.sql +DROP VIEW workspace_latest_builds; + +-- The previous version of this view had two sequential scans on two very large +-- tables. This version optimized it by using index scans (via a lateral join) +-- AND avoiding selecting builds from deleted workspaces. +CREATE VIEW workspace_latest_builds AS +SELECT + latest_build.id, + latest_build.workspace_id, + latest_build.template_version_id, + latest_build.job_id, + latest_build.template_version_preset_id, + latest_build.transition, + latest_build.created_at, + latest_build.job_status +FROM workspaces +LEFT JOIN LATERAL ( + SELECT + workspace_builds.id AS id, + workspace_builds.workspace_id AS workspace_id, + workspace_builds.template_version_id AS template_version_id, + workspace_builds.job_id AS job_id, + workspace_builds.template_version_preset_id AS template_version_preset_id, + workspace_builds.transition AS transition, + workspace_builds.created_at AS created_at, + provisioner_jobs.job_status AS job_status + FROM + workspace_builds + JOIN + provisioner_jobs + ON + provisioner_jobs.id = workspace_builds.job_id + WHERE + workspace_builds.workspace_id = workspaces.id + ORDER BY + build_number DESC + LIMIT + 1 +) latest_build ON TRUE +WHERE workspaces.deleted = false +ORDER BY workspaces.id ASC; + +-- Recreate the dependent views +CREATE VIEW workspace_prebuilds AS + WITH all_prebuilds AS ( + SELECT w.id, + w.name, + w.template_id, + w.created_at + FROM workspaces w + WHERE (w.owner_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid) + ), workspaces_with_latest_presets AS ( + SELECT DISTINCT ON (workspace_builds.workspace_id) workspace_builds.workspace_id, + workspace_builds.template_version_preset_id + FROM workspace_builds + WHERE (workspace_builds.template_version_preset_id IS NOT NULL) + ORDER BY workspace_builds.workspace_id, workspace_builds.build_number DESC + ), workspaces_with_agents_status AS ( + SELECT w.id AS workspace_id, + bool_and((wa.lifecycle_state = 'ready'::workspace_agent_lifecycle_state)) AS ready + FROM (((workspaces w + JOIN workspace_latest_builds wlb ON ((wlb.workspace_id = w.id))) + JOIN workspace_resources wr ON ((wr.job_id = wlb.job_id))) + JOIN workspace_agents wa ON ((wa.resource_id = wr.id))) + WHERE (w.owner_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid) + GROUP BY w.id + ), 
current_presets AS ( + SELECT w.id AS prebuild_id, + wlp.template_version_preset_id + FROM (workspaces w + JOIN workspaces_with_latest_presets wlp ON ((wlp.workspace_id = w.id))) + WHERE (w.owner_id = 'c42fdf75-3097-471c-8c33-fb52454d81c0'::uuid) + ) + SELECT p.id, + p.name, + p.template_id, + p.created_at, + COALESCE(a.ready, false) AS ready, + cp.template_version_preset_id AS current_preset_id + FROM ((all_prebuilds p + LEFT JOIN workspaces_with_agents_status a ON ((a.workspace_id = p.id))) + JOIN current_presets cp ON ((cp.prebuild_id = p.id))); From 27a1aa0fbfb38328b95a8cb26954c316a20b8e72 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Fri, 9 May 2025 17:41:19 +0200 Subject: [PATCH 11/11] chore: upgrade `terraform-provider-coder` & `preview` libs (#17738) The changes in `coder/preview` necessitated the changes in `codersdk/richparameters.go` & `provisioner/terraform/resources.go`. --------- Signed-off-by: Danny Kopping Co-authored-by: Steven Masley (cherry picked from commit 3ee95f14ceac20311aee72da1c317bb70a1f2ab1) --- cli/update_test.go | 2 +- coderd/testdata/parameters/groups/main.tf | 2 +- codersdk/richparameters.go | 46 +++++++---------------- go.mod | 4 +- go.sum | 8 ++-- provisioner/terraform/resources.go | 6 ++- site/src/api/typesGenerated.ts | 1 - 7 files changed, 27 insertions(+), 42 deletions(-) diff --git a/cli/update_test.go b/cli/update_test.go index 413c3d3c37f67..367a8196aa499 100644 --- a/cli/update_test.go +++ b/cli/update_test.go @@ -757,7 +757,7 @@ func TestUpdateValidateRichParameters(t *testing.T) { err := inv.Run() // TODO: improve validation so we catch this problem before it reaches the server // but for now just validate that the server actually catches invalid monotonicity - assert.ErrorContains(t, err, fmt.Sprintf("parameter value must be equal or greater than previous value: %s", tempVal)) + assert.ErrorContains(t, err, "parameter value '1' must be equal or greater than previous value: 2") }() matches := []string{ diff --git a/coderd/testdata/parameters/groups/main.tf b/coderd/testdata/parameters/groups/main.tf index 9356cc2840e91..8bed301f33ce5 100644 --- a/coderd/testdata/parameters/groups/main.tf +++ b/coderd/testdata/parameters/groups/main.tf @@ -12,7 +12,7 @@ data "coder_parameter" "group" { name = "group" default = try(data.coder_workspace_owner.me.groups[0], "") dynamic "option" { - for_each = data.coder_workspace_owner.me.groups + for_each = concat(data.coder_workspace_owner.me.groups, "bloob") content { name = option.value value = option.value diff --git a/codersdk/richparameters.go b/codersdk/richparameters.go index 6fc6b8e0c343f..f00c947715f9d 100644 --- a/codersdk/richparameters.go +++ b/codersdk/richparameters.go @@ -1,9 +1,8 @@ package codersdk import ( - "strconv" - "golang.org/x/xerrors" + "tailscale.com/types/ptr" "github.com/coder/terraform-provider-coder/v2/provider" ) @@ -46,47 +45,31 @@ func ValidateWorkspaceBuildParameter(richParameter TemplateVersionParameter, bui } func validateBuildParameter(richParameter TemplateVersionParameter, buildParameter *WorkspaceBuildParameter, lastBuildParameter *WorkspaceBuildParameter) error { - var value string + var ( + current string + previous *string + ) if buildParameter != nil { - value = buildParameter.Value + current = buildParameter.Value } - if richParameter.Required && value == "" { - return xerrors.Errorf("parameter value is required") + if lastBuildParameter != nil { + previous = ptr.To(lastBuildParameter.Value) } - if value == "" { // parameter is optional, so take the default value - 
value = richParameter.DefaultValue + if richParameter.Required && current == "" { + return xerrors.Errorf("parameter value is required") } - if lastBuildParameter != nil && lastBuildParameter.Value != "" && richParameter.Type == "number" && len(richParameter.ValidationMonotonic) > 0 { - prev, err := strconv.Atoi(lastBuildParameter.Value) - if err != nil { - return xerrors.Errorf("previous parameter value is not a number: %s", lastBuildParameter.Value) - } - - current, err := strconv.Atoi(buildParameter.Value) - if err != nil { - return xerrors.Errorf("current parameter value is not a number: %s", buildParameter.Value) - } - - switch richParameter.ValidationMonotonic { - case MonotonicOrderIncreasing: - if prev > current { - return xerrors.Errorf("parameter value must be equal or greater than previous value: %d", prev) - } - case MonotonicOrderDecreasing: - if prev < current { - return xerrors.Errorf("parameter value must be equal or lower than previous value: %d", prev) - } - } + if current == "" { // parameter is optional, so take the default value + current = richParameter.DefaultValue } if len(richParameter.Options) > 0 { var matched bool for _, opt := range richParameter.Options { - if opt.Value == value { + if opt.Value == current { matched = true break } @@ -95,7 +78,6 @@ func validateBuildParameter(richParameter TemplateVersionParameter, buildParamet if !matched { return xerrors.Errorf("parameter value must match one of options: %s", parameterValuesAsArray(richParameter.Options)) } - return nil } if !validationEnabled(richParameter) { @@ -119,7 +101,7 @@ func validateBuildParameter(richParameter TemplateVersionParameter, buildParamet Error: richParameter.ValidationError, Monotonic: string(richParameter.ValidationMonotonic), } - return validation.Valid(richParameter.Type, value) + return validation.Valid(richParameter.Type, current, previous) } func findBuildParameter(params []WorkspaceBuildParameter, parameterName string) (*WorkspaceBuildParameter, bool) { diff --git a/go.mod b/go.mod index 8ff0ba1fa2376..a2f8bda5c97bb 100644 --- a/go.mod +++ b/go.mod @@ -101,7 +101,7 @@ require ( github.com/coder/quartz v0.1.2 github.com/coder/retry v1.5.1 github.com/coder/serpent v0.10.0 - github.com/coder/terraform-provider-coder/v2 v2.4.0-pre1.0.20250417100258-c86bb5c3ddcd + github.com/coder/terraform-provider-coder/v2 v2.4.0 github.com/coder/websocket v1.8.13 github.com/coder/wgtunnel v0.1.13-0.20240522110300-ade90dfb2da0 github.com/coreos/go-oidc/v3 v3.14.1 @@ -487,7 +487,7 @@ require ( ) require ( - github.com/coder/preview v0.0.1 + github.com/coder/preview v0.0.2-0.20250510235314-66630669ff6f github.com/fsnotify/fsnotify v1.9.0 github.com/kylecarbs/aisdk-go v0.0.5 github.com/mark3labs/mcp-go v0.23.1 diff --git a/go.sum b/go.sum index fc05152d34122..4a2c74f8dc69e 100644 --- a/go.sum +++ b/go.sum @@ -907,8 +907,8 @@ github.com/coder/pq v1.10.5-0.20240813183442-0c420cb5a048 h1:3jzYUlGH7ZELIH4XggX github.com/coder/pq v1.10.5-0.20240813183442-0c420cb5a048/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0 h1:3A0ES21Ke+FxEM8CXx9n47SZOKOpgSE1bbJzlE4qPVs= github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0/go.mod h1:5UuS2Ts+nTToAMeOjNlnHFkPahrtDkmpydBen/3wgZc= -github.com/coder/preview v0.0.1 h1:2X5McKdMOZJILTIDf7qRplXKupT+91qTJBN67XUh5cA= -github.com/coder/preview v0.0.1/go.mod h1:eInDmOdSDF8cxCvapIvYkGRzmzvcvGAFL1HYqcA4g+E= +github.com/coder/preview v0.0.2-0.20250510235314-66630669ff6f 
h1:Q1AcLBpRRR8rf/H+GxcJaZ5Ox4UeIRkDgc6wvvmOJAo= +github.com/coder/preview v0.0.2-0.20250510235314-66630669ff6f/go.mod h1:GfkwIv5gQLpL01qeGU1/YoxoFtt5trzCqnWZLo77clU= github.com/coder/quartz v0.1.2 h1:PVhc9sJimTdKd3VbygXtS4826EOCpB1fXoRlLnCrE+s= github.com/coder/quartz v0.1.2/go.mod h1:vsiCc+AHViMKH2CQpGIpFgdHIEQsxwm8yCscqKmzbRA= github.com/coder/retry v1.5.1 h1:iWu8YnD8YqHs3XwqrqsjoBTAVqT9ml6z9ViJ2wlMiqc= @@ -921,8 +921,8 @@ github.com/coder/tailscale v1.1.1-0.20250422090654-5090e715905e h1:nope/SZfoLB9M github.com/coder/tailscale v1.1.1-0.20250422090654-5090e715905e/go.mod h1:1ggFFdHTRjPRu9Yc1yA7nVHBYB50w9Ce7VIXNqcW6Ko= github.com/coder/terraform-config-inspect v0.0.0-20250107175719-6d06d90c630e h1:JNLPDi2P73laR1oAclY6jWzAbucf70ASAvf5mh2cME0= github.com/coder/terraform-config-inspect v0.0.0-20250107175719-6d06d90c630e/go.mod h1:Gz/z9Hbn+4KSp8A2FBtNszfLSdT2Tn/uAKGuVqqWmDI= -github.com/coder/terraform-provider-coder/v2 v2.4.0-pre1.0.20250417100258-c86bb5c3ddcd h1:FsIG6Fd0YOEK7D0Hl/CJywRA+Y6Gd5RQbSIa2L+/BmE= -github.com/coder/terraform-provider-coder/v2 v2.4.0-pre1.0.20250417100258-c86bb5c3ddcd/go.mod h1:56/KdGYaA+VbwXJbTI8CA57XPfnuTxN8rjxbR34PbZw= +github.com/coder/terraform-provider-coder/v2 v2.4.0 h1:uuFmF03IyahAZLXEukOdmvV9hGfUMJSESD8+G5wkTcM= +github.com/coder/terraform-provider-coder/v2 v2.4.0/go.mod h1:2kaBpn5k9ZWtgKq5k4JbkVZG9DzEqR4mJSmpdshcO+s= github.com/coder/trivy v0.0.0-20250409153844-e6b004bc465a h1:yryP7e+IQUAArlycH4hQrjXQ64eRNbxsV5/wuVXHgME= github.com/coder/trivy v0.0.0-20250409153844-e6b004bc465a/go.mod h1:dDvq9axp3kZsT63gY2Znd1iwzfqDq3kXbQnccIrjRYY= github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= diff --git a/provisioner/terraform/resources.go b/provisioner/terraform/resources.go index da86ab2f3d48e..ce881480ad3aa 100644 --- a/provisioner/terraform/resources.go +++ b/provisioner/terraform/resources.go @@ -749,13 +749,17 @@ func ConvertState(ctx context.Context, modules []*tfjson.StateModule, rawGraph s if err != nil { return nil, xerrors.Errorf("decode map values for coder_parameter.%s: %w", resource.Name, err) } + var defaultVal string + if param.Default != nil { + defaultVal = *param.Default + } protoParam := &proto.RichParameter{ Name: param.Name, DisplayName: param.DisplayName, Description: param.Description, Type: param.Type, Mutable: param.Mutable, - DefaultValue: param.Default, + DefaultValue: defaultVal, Icon: param.Icon, Required: !param.Optional, // #nosec G115 - Safe conversion as parameter order value is expected to be within int32 range diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index d879c09d119b2..afa766af64a84 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1770,7 +1770,6 @@ export interface PreviewParameterValidation { readonly validation_min: number | null; readonly validation_max: number | null; readonly validation_monotonic: string | null; - readonly validation_invalid: boolean | null; } // From codersdk/deployment.go