diff --git a/agent/agent.go b/agent/agent.go index b1218190bbcb4..548156b3ce89f 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -77,6 +77,7 @@ type Client interface { PostStartup(ctx context.Context, req agentsdk.PostStartupRequest) error PostMetadata(ctx context.Context, key string, req agentsdk.PostMetadataRequest) error PatchStartupLogs(ctx context.Context, req agentsdk.PatchStartupLogs) error + GetServiceBanner(ctx context.Context) (codersdk.ServiceBannerConfig, error) } type Agent interface { @@ -163,7 +164,9 @@ type agent struct { envVars map[string]string // manifest is atomic because values can change after reconnection. - manifest atomic.Pointer[agentsdk.Manifest] + manifest atomic.Pointer[agentsdk.Manifest] + // serviceBanner is atomic because it can change. + serviceBanner atomic.Pointer[codersdk.ServiceBannerConfig] sessionToken atomic.Pointer[string] sshServer *agentssh.Server sshMaxTimeout time.Duration @@ -191,6 +194,7 @@ func (a *agent) init(ctx context.Context) { sshSrv.Env = a.envVars sshSrv.AgentToken = func() string { return *a.sessionToken.Load() } sshSrv.Manifest = &a.manifest + sshSrv.ServiceBanner = &a.serviceBanner a.sshServer = sshSrv go a.runLoop(ctx) @@ -203,6 +207,7 @@ func (a *agent) init(ctx context.Context) { func (a *agent) runLoop(ctx context.Context) { go a.reportLifecycleLoop(ctx) go a.reportMetadataLoop(ctx) + go a.fetchServiceBannerLoop(ctx) for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { a.logger.Info(ctx, "connecting to coderd") @@ -275,14 +280,15 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM return result } -func adjustIntervalForTests(i int64) time.Duration { +// adjustIntervalForTests returns a duration of testInterval milliseconds long +// for tests and interval seconds long otherwise. +func adjustIntervalForTests(interval time.Duration, testInterval time.Duration) time.Duration { // In tests we want to set shorter intervals because engineers are // impatient. - base := time.Second if flag.Lookup("test.v") != nil { - base = time.Millisecond * 100 + return testInterval } - return time.Duration(i) * base + return interval } type metadataResultAndKey struct { @@ -306,7 +312,7 @@ func (t *trySingleflight) Do(key string, fn func()) { } func (a *agent) reportMetadataLoop(ctx context.Context) { - baseInterval := adjustIntervalForTests(1) + baseInterval := adjustIntervalForTests(time.Second, time.Millisecond*100) const metadataLimit = 128 @@ -383,7 +389,9 @@ func (a *agent) reportMetadataLoop(ctx context.Context) { } // The last collected value isn't quite stale yet, so we skip it. if collectedAt.Add( - adjustIntervalForTests(md.Interval), + adjustIntervalForTests( + time.Duration(md.Interval)*time.Second, + time.Duration(md.Interval)*time.Millisecond*100), ).After(time.Now()) { continue } @@ -491,6 +499,30 @@ func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentL } } +// fetchServiceBannerLoop fetches the service banner on an interval. It will +// not be fetched immediately; the expectation is that it is primed elsewhere +// (and must be done before the session actually starts). +func (a *agent) fetchServiceBannerLoop(ctx context.Context) { + ticker := time.NewTicker(adjustIntervalForTests(2*time.Minute, time.Millisecond*100)) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + serviceBanner, err := a.client.GetServiceBanner(ctx) + if err != nil { + if ctx.Err() != nil { + return + } + a.logger.Error(ctx, "failed to update service banner", slog.Error(err)) + continue + } + a.serviceBanner.Store(&serviceBanner) + } + } +} + func (a *agent) run(ctx context.Context) error { // This allows the agent to refresh it's token if necessary. // For instance identity this is required, since the instance @@ -501,6 +533,12 @@ func (a *agent) run(ctx context.Context) error { } a.sessionToken.Store(&sessionToken) + serviceBanner, err := a.client.GetServiceBanner(ctx) + if err != nil { + return xerrors.Errorf("fetch service banner: %w", err) + } + a.serviceBanner.Store(&serviceBanner) + manifest, err := a.client.Manifest(ctx) if err != nil { return xerrors.Errorf("fetch metadata: %w", err) diff --git a/agent/agent_test.go b/agent/agent_test.go index c3320fa1aa570..8b604b84d5294 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -15,6 +15,7 @@ import ( "os/user" "path" "path/filepath" + "regexp" "runtime" "strconv" "strings" @@ -66,7 +67,7 @@ func TestAgent_Stats_SSH(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, _, stats, _, _ := setupAgent(t, &client{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) @@ -99,7 +100,7 @@ func TestAgent_Stats_ReconnectingPTY(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, _, stats, _, _ := setupAgent(t, &client{}, 0) ptyConn, err := conn.ReconnectingPTY(ctx, uuid.New(), 128, 128, "/bin/bash") require.NoError(t, err) @@ -129,7 +130,7 @@ func TestAgent_Stats_Magic(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, _, _, _, _ := setupAgent(t, &client{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -156,7 +157,7 @@ func TestAgent_Stats_Magic(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, _, stats, _, _ := setupAgent(t, &client{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -191,7 +192,7 @@ func TestAgent_Stats_Magic(t *testing.T) { func TestAgent_SessionExec(t *testing.T) { t.Parallel() - session := setupSSHSession(t, agentsdk.Manifest{}) + session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}) command := "echo test" if runtime.GOOS == "windows" { @@ -204,7 +205,7 @@ func TestAgent_SessionExec(t *testing.T) { func TestAgent_GitSSH(t *testing.T) { t.Parallel() - session := setupSSHSession(t, agentsdk.Manifest{}) + session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}) command := "sh -c 'echo $GIT_SSH_COMMAND'" if runtime.GOOS == "windows" { command = "cmd.exe /c echo %GIT_SSH_COMMAND%" @@ -224,7 +225,7 @@ func TestAgent_SessionTTYShell(t *testing.T) { // it seems like it could be either. t.Skip("ConPTY appears to be inconsistent on Windows.") } - session := setupSSHSession(t, agentsdk.Manifest{}) + session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}) command := "sh" if runtime.GOOS == "windows" { command = "cmd.exe" @@ -247,7 +248,7 @@ func TestAgent_SessionTTYShell(t *testing.T) { func TestAgent_SessionTTYExitCode(t *testing.T) { t.Parallel() - session := setupSSHSession(t, agentsdk.Manifest{}) + session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}) command := "areallynotrealcommand" err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{}) require.NoError(t, err) @@ -277,6 +278,7 @@ func TestAgent_Session_TTY_MOTD(t *testing.T) { } wantMOTD := "Welcome to your Coder workspace!" + wantServiceBanner := "Service banner text goes here" tmpdir := t.TempDir() name := filepath.Join(tmpdir, "motd") @@ -286,29 +288,176 @@ func TestAgent_Session_TTY_MOTD(t *testing.T) { // Set HOME so we can ensure no ~/.hushlogin is present. t.Setenv("HOME", tmpdir) - session := setupSSHSession(t, agentsdk.Manifest{ - MOTDFile: name, - }) - err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{}) - require.NoError(t, err) + tests := []struct { + name string + manifest agentsdk.Manifest + banner codersdk.ServiceBannerConfig + expected []string + unexpected []string + expectedRe *regexp.Regexp + }{ + { + name: "WithoutServiceBanner", + manifest: agentsdk.Manifest{MOTDFile: name}, + banner: codersdk.ServiceBannerConfig{}, + expected: []string{wantMOTD}, + unexpected: []string{wantServiceBanner}, + }, + { + name: "WithServiceBanner", + manifest: agentsdk.Manifest{MOTDFile: name}, + banner: codersdk.ServiceBannerConfig{ + Enabled: true, + Message: wantServiceBanner, + }, + expected: []string{wantMOTD, wantServiceBanner}, + }, + { + name: "ServiceBannerDisabled", + manifest: agentsdk.Manifest{MOTDFile: name}, + banner: codersdk.ServiceBannerConfig{ + Enabled: false, + Message: wantServiceBanner, + }, + expected: []string{wantMOTD}, + unexpected: []string{wantServiceBanner}, + }, + { + name: "ServiceBannerOnly", + manifest: agentsdk.Manifest{}, + banner: codersdk.ServiceBannerConfig{ + Enabled: true, + Message: wantServiceBanner, + }, + expected: []string{wantServiceBanner}, + unexpected: []string{wantMOTD}, + }, + { + name: "None", + manifest: agentsdk.Manifest{}, + banner: codersdk.ServiceBannerConfig{}, + unexpected: []string{wantServiceBanner, wantMOTD}, + }, + { + name: "CarriageReturns", + manifest: agentsdk.Manifest{}, + banner: codersdk.ServiceBannerConfig{ + Enabled: true, + Message: "service\n\nbanner\nhere", + }, + expected: []string{"service\r\n\r\nbanner\r\nhere\r\n\r\n"}, + unexpected: []string{}, + }, + { + name: "Trim", + manifest: agentsdk.Manifest{}, + banner: codersdk.ServiceBannerConfig{ + Enabled: true, + Message: "\n\n\n\n\n\nbanner\n\n\n\n\n\n", + }, + expectedRe: regexp.MustCompile("([^\n\r]|^)banner\r\n\r\n[^\r\n]"), + }, + } - ptty := ptytest.New(t) - var stdout bytes.Buffer - session.Stdout = &stdout - session.Stderr = ptty.Output() - session.Stdin = ptty.Input() - err = session.Shell() - require.NoError(t, err) + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + session := setupSSHSession(t, test.manifest, test.banner) + testSessionOutput(t, session, test.expected, test.unexpected, test.expectedRe) + }) + } +} - ptty.WriteLine("exit 0") - err = session.Wait() - require.NoError(t, err) +//nolint:paralleltest // This test sets an environment variable. +func TestAgent_Session_TTY_MOTD_Update(t *testing.T) { + if runtime.GOOS == "windows" { + // This might be our implementation, or ConPTY itself. + // It's difficult to find extensive tests for it, so + // it seems like it could be either. + t.Skip("ConPTY appears to be inconsistent on Windows.") + } + + // Only the banner updates dynamically; the MOTD file does not. + wantServiceBanner := "Service banner text goes here" + + tmpdir := t.TempDir() + + // Set HOME so we can ensure no ~/.hushlogin is present. + t.Setenv("HOME", tmpdir) + + tests := []struct { + banner codersdk.ServiceBannerConfig + expected []string + unexpected []string + }{ + { + banner: codersdk.ServiceBannerConfig{}, + expected: []string{}, + unexpected: []string{wantServiceBanner}, + }, + { + banner: codersdk.ServiceBannerConfig{ + Enabled: true, + Message: wantServiceBanner, + }, + expected: []string{wantServiceBanner}, + }, + { + banner: codersdk.ServiceBannerConfig{ + Enabled: false, + Message: wantServiceBanner, + }, + expected: []string{}, + unexpected: []string{wantServiceBanner}, + }, + { + banner: codersdk.ServiceBannerConfig{ + Enabled: true, + Message: wantServiceBanner, + }, + expected: []string{wantServiceBanner}, + unexpected: []string{}, + }, + { + banner: codersdk.ServiceBannerConfig{}, + unexpected: []string{wantServiceBanner}, + }, + } + + const updateInterval = 100 * time.Millisecond + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + //nolint:dogsled // Allow the blank identifiers. + conn, client, _, _, _ := setupAgent(t, &client{}, 0) + for _, test := range tests { + test := test + // Set new banner func and wait for the agent to call it to update the + // banner. + client.mu.Lock() + client.getServiceBanner = func() (codersdk.ServiceBannerConfig, error) { + return test.banner, nil + } + client.mu.Unlock() + time.Sleep(updateInterval) + + sshClient, err := conn.SSHClient(ctx) + require.NoError(t, err) + t.Cleanup(func() { + _ = sshClient.Close() + }) + session, err := sshClient.NewSession() + require.NoError(t, err) + t.Cleanup(func() { + _ = session.Close() + }) - require.Contains(t, stdout.String(), wantMOTD, "should show motd") + testSessionOutput(t, session, test.expected, test.unexpected, nil) + } } //nolint:paralleltest // This test sets an environment variable. -func TestAgent_Session_TTY_Hushlogin(t *testing.T) { +func TestAgent_Session_TTY_QuietLogin(t *testing.T) { if runtime.GOOS == "windows" { // This might be our implementation, or ConPTY itself. // It's difficult to find extensive tests for it, so @@ -317,40 +466,69 @@ func TestAgent_Session_TTY_Hushlogin(t *testing.T) { } wantNotMOTD := "Welcome to your Coder workspace!" + wantServiceBanner := "Service banner text goes here" tmpdir := t.TempDir() name := filepath.Join(tmpdir, "motd") err := os.WriteFile(name, []byte(wantNotMOTD), 0o600) require.NoError(t, err, "write motd file") - // Create hushlogin to silence motd. - f, err := os.Create(filepath.Join(tmpdir, ".hushlogin")) - require.NoError(t, err, "create .hushlogin file") - err = f.Close() - require.NoError(t, err, "close .hushlogin file") - // Set HOME so we can ensure ~/.hushlogin is present. t.Setenv("HOME", tmpdir) - session := setupSSHSession(t, agentsdk.Manifest{ - MOTDFile: name, + // Neither banner nor MOTD should show if not a login shell. + t.Run("NotLogin", func(t *testing.T) { + wantEcho := "foobar" + session := setupSSHSession(t, agentsdk.Manifest{ + MOTDFile: name, + }, codersdk.ServiceBannerConfig{ + Enabled: true, + Message: wantServiceBanner, + }) + err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{}) + require.NoError(t, err) + + command := "echo " + wantEcho + output, err := session.Output(command) + require.NoError(t, err) + + require.Contains(t, string(output), wantEcho, "should show echo") + require.NotContains(t, string(output), wantNotMOTD, "should not show motd") + require.NotContains(t, string(output), wantServiceBanner, "should not show service banner") }) - err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{}) - require.NoError(t, err) - ptty := ptytest.New(t) - var stdout bytes.Buffer - session.Stdout = &stdout - session.Stderr = ptty.Output() - session.Stdin = ptty.Input() - err = session.Shell() - require.NoError(t, err) + // Only the MOTD should be silenced. + t.Run("Hushlogin", func(t *testing.T) { + // Create hushlogin to silence motd. + f, err := os.Create(filepath.Join(tmpdir, ".hushlogin")) + require.NoError(t, err, "create .hushlogin file") + err = f.Close() + require.NoError(t, err, "close .hushlogin file") + + session := setupSSHSession(t, agentsdk.Manifest{ + MOTDFile: name, + }, codersdk.ServiceBannerConfig{ + Enabled: true, + Message: wantServiceBanner, + }) + err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{}) + require.NoError(t, err) - ptty.WriteLine("exit 0") - err = session.Wait() - require.NoError(t, err) + ptty := ptytest.New(t) + var stdout bytes.Buffer + session.Stdout = &stdout + session.Stderr = ptty.Output() + session.Stdin = ptty.Input() + err = session.Shell() + require.NoError(t, err) - require.NotContains(t, stdout.String(), wantNotMOTD, "should not show motd") + ptty.WriteLine("exit 0") + err = session.Wait() + require.NoError(t, err) + + require.NotContains(t, stdout.String(), wantNotMOTD, "should not show motd") + require.Contains(t, stdout.String(), wantServiceBanner, "should show service banner") + }) } func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) { @@ -362,7 +540,7 @@ func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, _, _, _, _ := setupAgent(t, &client{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -412,7 +590,7 @@ func TestAgent_Session_TTY_HugeOutputIsNotLost(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, _, _, _, _ := setupAgent(t, &client{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -742,7 +920,7 @@ func TestAgent_SFTP(t *testing.T) { home = "/" + strings.ReplaceAll(home, "\\", "/") } //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, _, _, _, _ := setupAgent(t, &client{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -774,7 +952,7 @@ func TestAgent_SCP(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, _, _, _, _ := setupAgent(t, &client{}, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() @@ -797,7 +975,7 @@ func TestAgent_EnvironmentVariables(t *testing.T) { EnvironmentVariables: map[string]string{ key: value, }, - }) + }, codersdk.ServiceBannerConfig{}) command := "sh -c 'echo $" + key + "'" if runtime.GOOS == "windows" { command = "cmd.exe /c echo %" + key + "%" @@ -814,7 +992,7 @@ func TestAgent_EnvironmentVariableExpansion(t *testing.T) { EnvironmentVariables: map[string]string{ key: "$SOMETHINGNOTSET", }, - }) + }, codersdk.ServiceBannerConfig{}) command := "sh -c 'echo $" + key + "'" if runtime.GOOS == "windows" { command = "cmd.exe /c echo %" + key + "%" @@ -837,7 +1015,7 @@ func TestAgent_CoderEnvVars(t *testing.T) { t.Run(key, func(t *testing.T) { t.Parallel() - session := setupSSHSession(t, agentsdk.Manifest{}) + session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}) command := "sh -c 'echo $" + key + "'" if runtime.GOOS == "windows" { command = "cmd.exe /c echo %" + key + "%" @@ -860,7 +1038,7 @@ func TestAgent_SSHConnectionEnvVars(t *testing.T) { t.Run(key, func(t *testing.T) { t.Parallel() - session := setupSSHSession(t, agentsdk.Manifest{}) + session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}) command := "sh -c 'echo $" + key + "'" if runtime.GOOS == "windows" { command = "cmd.exe /c echo %" + key + "%" @@ -958,12 +1136,14 @@ func TestAgent_Metadata(t *testing.T) { t.Run("Once", func(t *testing.T) { t.Parallel() //nolint:dogsled - _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ - Metadata: []codersdk.WorkspaceAgentMetadataDescription{ - { - Key: "greeting", - Interval: 0, - Script: echoHello, + _, client, _, _, _ := setupAgent(t, &client{ + manifest: agentsdk.Manifest{ + Metadata: []codersdk.WorkspaceAgentMetadataDescription{ + { + Key: "greeting", + Interval: 0, + Script: echoHello, + }, }, }, }, 0) @@ -988,13 +1168,15 @@ func TestAgent_Metadata(t *testing.T) { t.Run("Many", func(t *testing.T) { t.Parallel() //nolint:dogsled - _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ - Metadata: []codersdk.WorkspaceAgentMetadataDescription{ - { - Key: "greeting", - Interval: 1, - Timeout: 100, - Script: echoHello, + _, client, _, _, _ := setupAgent(t, &client{ + manifest: agentsdk.Manifest{ + Metadata: []codersdk.WorkspaceAgentMetadataDescription{ + { + Key: "greeting", + Interval: 1, + Timeout: 100, + Script: echoHello, + }, }, }, }, 0) @@ -1037,17 +1219,19 @@ func TestAgentMetadata_Timing(t *testing.T) { script = "echo hello | tee -a " + greetingPath ) //nolint:dogsled - _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ - Metadata: []codersdk.WorkspaceAgentMetadataDescription{ - { - Key: "greeting", - Interval: reportInterval, - Script: script, - }, - { - Key: "bad", - Interval: reportInterval, - Script: "exit 1", + _, client, _, _, _ := setupAgent(t, &client{ + manifest: agentsdk.Manifest{ + Metadata: []codersdk.WorkspaceAgentMetadataDescription{ + { + Key: "greeting", + Interval: reportInterval, + Script: script, + }, + { + Key: "bad", + Interval: reportInterval, + Script: "exit 1", + }, }, }, }, 0) @@ -1099,9 +1283,11 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("StartTimeout", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ - StartupScript: "sleep 3", - StartupScriptTimeout: time.Nanosecond, + _, client, _, _, _ := setupAgent(t, &client{ + manifest: agentsdk.Manifest{ + StartupScript: "sleep 3", + StartupScriptTimeout: time.Nanosecond, + }, }, 0) want := []codersdk.WorkspaceAgentLifecycle{ @@ -1121,9 +1307,11 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("StartError", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ - StartupScript: "false", - StartupScriptTimeout: 30 * time.Second, + _, client, _, _, _ := setupAgent(t, &client{ + manifest: agentsdk.Manifest{ + StartupScript: "false", + StartupScriptTimeout: 30 * time.Second, + }, }, 0) want := []codersdk.WorkspaceAgentLifecycle{ @@ -1143,9 +1331,11 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("Ready", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ - StartupScript: "true", - StartupScriptTimeout: 30 * time.Second, + _, client, _, _, _ := setupAgent(t, &client{ + manifest: agentsdk.Manifest{ + StartupScript: "true", + StartupScriptTimeout: 30 * time.Second, + }, }, 0) want := []codersdk.WorkspaceAgentLifecycle{ @@ -1165,9 +1355,11 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("ShuttingDown", func(t *testing.T) { t.Parallel() - _, client, _, _, closer := setupAgent(t, agentsdk.Manifest{ - ShutdownScript: "sleep 3", - StartupScriptTimeout: 30 * time.Second, + _, client, _, _, closer := setupAgent(t, &client{ + manifest: agentsdk.Manifest{ + ShutdownScript: "sleep 3", + StartupScriptTimeout: 30 * time.Second, + }, }, 0) assert.Eventually(t, func() bool { @@ -1203,9 +1395,11 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("ShutdownTimeout", func(t *testing.T) { t.Parallel() - _, client, _, _, closer := setupAgent(t, agentsdk.Manifest{ - ShutdownScript: "sleep 3", - ShutdownScriptTimeout: time.Nanosecond, + _, client, _, _, closer := setupAgent(t, &client{ + manifest: agentsdk.Manifest{ + ShutdownScript: "sleep 3", + ShutdownScriptTimeout: time.Nanosecond, + }, }, 0) assert.Eventually(t, func() bool { @@ -1242,9 +1436,11 @@ func TestAgent_Lifecycle(t *testing.T) { t.Run("ShutdownError", func(t *testing.T) { t.Parallel() - _, client, _, _, closer := setupAgent(t, agentsdk.Manifest{ - ShutdownScript: "false", - ShutdownScriptTimeout: 30 * time.Second, + _, client, _, _, closer := setupAgent(t, &client{ + manifest: agentsdk.Manifest{ + ShutdownScript: "false", + ShutdownScriptTimeout: 30 * time.Second, + }, }, 0) assert.Eventually(t, func() bool { @@ -1338,10 +1534,12 @@ func TestAgent_Startup(t *testing.T) { t.Run("EmptyDirectory", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ - StartupScript: "true", - StartupScriptTimeout: 30 * time.Second, - Directory: "", + _, client, _, _, _ := setupAgent(t, &client{ + manifest: agentsdk.Manifest{ + StartupScript: "true", + StartupScriptTimeout: 30 * time.Second, + Directory: "", + }, }, 0) assert.Eventually(t, func() bool { return client.getStartup().Version != "" @@ -1352,10 +1550,12 @@ func TestAgent_Startup(t *testing.T) { t.Run("HomeDirectory", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ - StartupScript: "true", - StartupScriptTimeout: 30 * time.Second, - Directory: "~", + _, client, _, _, _ := setupAgent(t, &client{ + manifest: agentsdk.Manifest{ + StartupScript: "true", + StartupScriptTimeout: 30 * time.Second, + Directory: "~", + }, }, 0) assert.Eventually(t, func() bool { return client.getStartup().Version != "" @@ -1368,10 +1568,12 @@ func TestAgent_Startup(t *testing.T) { t.Run("NotAbsoluteDirectory", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ - StartupScript: "true", - StartupScriptTimeout: 30 * time.Second, - Directory: "coder/coder", + _, client, _, _, _ := setupAgent(t, &client{ + manifest: agentsdk.Manifest{ + StartupScript: "true", + StartupScriptTimeout: 30 * time.Second, + Directory: "coder/coder", + }, }, 0) assert.Eventually(t, func() bool { return client.getStartup().Version != "" @@ -1384,10 +1586,12 @@ func TestAgent_Startup(t *testing.T) { t.Run("HomeEnvironmentVariable", func(t *testing.T) { t.Parallel() - _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ - StartupScript: "true", - StartupScriptTimeout: 30 * time.Second, - Directory: "$HOME", + _, client, _, _, _ := setupAgent(t, &client{ + manifest: agentsdk.Manifest{ + StartupScript: "true", + StartupScriptTimeout: 30 * time.Second, + Directory: "$HOME", + }, }, 0) assert.Eventually(t, func() bool { return client.getStartup().Version != "" @@ -1411,7 +1615,7 @@ func TestAgent_ReconnectingPTY(t *testing.T) { defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, _, _, _, _ := setupAgent(t, &client{}, 0) id := uuid.New() netConn, err := conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash") require.NoError(t, err) @@ -1513,7 +1717,7 @@ func TestAgent_Dial(t *testing.T) { }() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + conn, _, _, _, _ := setupAgent(t, &client{}, 0) require.True(t, conn.AwaitReachable(context.Background())) conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) require.NoError(t, err) @@ -1535,8 +1739,10 @@ func TestAgent_Speedtest(t *testing.T) { defer cancel() derpMap := tailnettest.RunDERPAndSTUN(t) //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{ - DERPMap: derpMap, + conn, _, _, _, _ := setupAgent(t, &client{ + manifest: agentsdk.Manifest{ + DERPMap: derpMap, + }, }, 0) defer conn.Close() res, err := conn.Speedtest(ctx, speedtest.Upload, 250*time.Millisecond) @@ -1622,7 +1828,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) { func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*ptytest.PTYCmd, pty.Process) { //nolint:dogsled - agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + agentConn, _, _, _, _ := setupAgent(t, &client{}, 0) listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) waitGroup := sync.WaitGroup{} @@ -1666,11 +1872,20 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*pt return ptytest.Start(t, cmd) } -func setupSSHSession(t *testing.T, options agentsdk.Manifest) *ssh.Session { +func setupSSHSession( + t *testing.T, + options agentsdk.Manifest, + serviceBanner codersdk.ServiceBannerConfig, +) *ssh.Session { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, options, 0) + conn, _, _, _, _ := setupAgent(t, &client{ + manifest: options, + getServiceBanner: func() (codersdk.ServiceBannerConfig, error) { + return serviceBanner, nil + }, + }, 0) sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) t.Cleanup(func() { @@ -1690,32 +1905,25 @@ func (c closeFunc) Close() error { return c() } -func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(agent.Options) agent.Options) ( +func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func(agent.Options) agent.Options) ( *codersdk.WorkspaceAgentConn, *client, <-chan *agentsdk.Stats, afero.Fs, io.Closer, ) { + c.t = t logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - if metadata.DERPMap == nil { - metadata.DERPMap = tailnettest.RunDERPAndSTUN(t) + if c.manifest.DERPMap == nil { + c.manifest.DERPMap = tailnettest.RunDERPAndSTUN(t) } - coordinator := tailnet.NewCoordinator(logger) + c.coordinator = tailnet.NewCoordinator(logger) t.Cleanup(func() { - _ = coordinator.Close() + _ = c.coordinator.Close() }) - agentID := uuid.New() - statsCh := make(chan *agentsdk.Stats, 50) + c.agentID = uuid.New() + c.statsChan = make(chan *agentsdk.Stats, 50) fs := afero.NewMemMapFs() - c := &client{ - t: t, - agentID: agentID, - manifest: metadata, - statsChan: statsCh, - coordinator: coordinator, - } - options := agent.Options{ Client: c, Filesystem: fs, @@ -1733,7 +1941,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati }) conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, - DERPMap: metadata.DERPMap, + DERPMap: c.manifest.DERPMap, Logger: logger.Named("client"), }) require.NoError(t, err) @@ -1747,7 +1955,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati }) go func() { defer close(serveClientDone) - coordinator.ServeClient(serverConn, uuid.New(), agentID) + c.coordinator.ServeClient(serverConn, uuid.New(), c.agentID) }() sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { return conn.UpdateNodes(node, false) @@ -1766,7 +1974,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati if !agentConn.AwaitReachable(ctx) { t.Fatal("agent not reachable") } - return agentConn, c, statsCh, fs, closer + return agentConn, c, c.statsChan, fs, closer } var dialTestPayload = []byte("dean-was-here123") @@ -1800,6 +2008,35 @@ func assertWritePayload(t *testing.T, w io.Writer, payload []byte) { assert.Equal(t, len(payload), n, "payload length does not match") } +func testSessionOutput(t *testing.T, session *ssh.Session, expected, unexpected []string, expectedRe *regexp.Regexp) { + t.Helper() + + err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{}) + require.NoError(t, err) + + ptty := ptytest.New(t) + var stdout bytes.Buffer + session.Stdout = &stdout + session.Stderr = ptty.Output() + session.Stdin = ptty.Input() + err = session.Shell() + require.NoError(t, err) + + ptty.WriteLine("exit 0") + err = session.Wait() + require.NoError(t, err) + + for _, unexpected := range unexpected { + require.NotContains(t, stdout.String(), unexpected, "should not show output") + } + for _, expect := range expected { + require.Contains(t, stdout.String(), expect, "should show output") + } + if expectedRe != nil { + require.Regexp(t, expectedRe, stdout.String()) + } +} + type client struct { t *testing.T agentID uuid.UUID @@ -1809,6 +2046,7 @@ type client struct { coordinator tailnet.Coordinator lastWorkspaceAgent func() patchWorkspaceLogs func() error + getServiceBanner func() (codersdk.ServiceBannerConfig, error) mu sync.Mutex // Protects following. lifecycleStates []codersdk.WorkspaceAgentLifecycle @@ -1930,6 +2168,15 @@ func (c *client) PatchStartupLogs(_ context.Context, logs agentsdk.PatchStartupL return nil } +func (c *client) GetServiceBanner(_ context.Context) (codersdk.ServiceBannerConfig, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.getServiceBanner != nil { + return c.getServiceBanner() + } + return codersdk.ServiceBannerConfig{}, nil +} + // tempDirUnixSocket returns a temporary directory that can safely hold unix // sockets (probably). // @@ -1961,7 +2208,7 @@ func TestAgent_Metrics_SSH(t *testing.T) { registry := prometheus.NewRegistry() //nolint:dogsled - conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(o agent.Options) agent.Options { + conn, _, _, _, _ := setupAgent(t, &client{}, 0, func(o agent.Options) agent.Options { o.PrometheusRegistry = registry return o }) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 62470b9dc11c6..4c3fdcb33aed6 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -29,6 +29,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/agent/usershell" + "github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk/agentsdk" "github.com/coder/coder/pty" ) @@ -63,9 +64,10 @@ type Server struct { srv *ssh.Server x11SocketDir string - Env map[string]string - AgentToken func() string - Manifest *atomic.Pointer[agentsdk.Manifest] + Env map[string]string + AgentToken func() string + Manifest *atomic.Pointer[agentsdk.Manifest] + ServiceBanner *atomic.Pointer[codersdk.ServiceBannerConfig] connCountVSCode atomic.Int64 connCountJetBrains atomic.Int64 @@ -345,6 +347,17 @@ func (s *Server) startPTYSession(session ptySession, magicTypeLabel string, cmd // See https://github.com/coder/coder/issues/3371. session.DisablePTYEmulation() + if isLoginShell(session.RawCommand()) { + serviceBanner := s.ServiceBanner.Load() + if serviceBanner != nil { + err := showServiceBanner(session, serviceBanner) + if err != nil { + s.logger.Error(ctx, "agent failed to show service banner", slog.Error(err)) + s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "service_banner").Add(1) + } + } + } + if !isQuietLogin(session.RawCommand()) { manifest := s.Manifest.Load() if manifest != nil { @@ -743,12 +756,16 @@ func (*Server) Shutdown(_ context.Context) error { return nil } +func isLoginShell(rawCommand string) bool { + return len(rawCommand) == 0 +} + // isQuietLogin checks if the SSH server should perform a quiet login or not. // // https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L816 func isQuietLogin(rawCommand string) bool { // We are always quiet unless this is a login shell. - if len(rawCommand) != 0 { + if !isLoginShell(rawCommand) { return true } @@ -763,6 +780,18 @@ func isQuietLogin(rawCommand string) bool { return err == nil } +// showServiceBanner will write the service banner if enabled and not blank +// along with a blank line for spacing. +func showServiceBanner(session io.Writer, banner *codersdk.ServiceBannerConfig) error { + if banner.Enabled && banner.Message != "" { + // The banner supports Markdown so we might want to parse it but Markdown is + // still fairly readable in its raw form. + message := strings.TrimSpace(banner.Message) + "\n\n" + return writeWithCarriageReturn(strings.NewReader(message), session) + } + return nil +} + // showMOTD will output the message of the day from // the given filename to dest, if the file exists. // @@ -782,19 +811,22 @@ func showMOTD(dest io.Writer, filename string) error { } defer f.Close() - s := bufio.NewScanner(f) + return writeWithCarriageReturn(f, dest) +} + +// writeWithCarriageReturn writes each line with a carriage return to ensure +// that each line starts at the beginning of the terminal. +func writeWithCarriageReturn(src io.Reader, dest io.Writer) error { + s := bufio.NewScanner(src) for s.Scan() { - // Carriage return ensures each line starts - // at the beginning of the terminal. - _, err = fmt.Fprint(dest, s.Text()+"\r\n") + _, err := fmt.Fprint(dest, s.Text()+"\r\n") if err != nil { - return xerrors.Errorf("write MOTD: %w", err) + return xerrors.Errorf("write line: %w", err) } } if err := s.Err(); err != nil { - return xerrors.Errorf("read MOTD: %w", err) + return xerrors.Errorf("read line: %w", err) } - return nil } diff --git a/coderd/coderd.go b/coderd/coderd.go index 32166d180fa4a..313b7d9b8ac1f 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -674,7 +674,10 @@ func New(options *Options) *API { r.Post("/google-instance-identity", api.postWorkspaceAuthGoogleInstanceIdentity) r.Get("/connection", api.workspaceAgentConnectionGeneric) r.Route("/me", func(r chi.Router) { - r.Use(httpmw.ExtractWorkspaceAgent(options.Database)) + r.Use(httpmw.ExtractWorkspaceAgent(httpmw.ExtractWorkspaceAgentConfig{ + DB: options.Database, + Optional: false, + })) r.Get("/manifest", api.workspaceAgentManifest) // This route is deprecated and will be removed in a future release. // New agents will use /me/manifest instead. diff --git a/coderd/httpmw/actor.go b/coderd/httpmw/actor.go index ba0ab1011d73d..7df5294b17c49 100644 --- a/coderd/httpmw/actor.go +++ b/coderd/httpmw/actor.go @@ -35,3 +35,32 @@ func RequireAPIKeyOrWorkspaceProxyAuth() func(http.Handler) http.Handler { }) } } + +// RequireAPIKeyOrWorkspaceAgent is middleware that should be inserted after +// optional ExtractAPIKey and ExtractWorkspaceAgent middlewares to ensure one of +// the two is provided. +// +// If both are provided an error is returned to avoid misuse. +func RequireAPIKeyOrWorkspaceAgent() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, hasAPIKey := APIKeyOptional(r) + _, hasWorkspaceAgent := WorkspaceAgentOptional(r) + + if hasAPIKey && hasWorkspaceAgent { + httpapi.Write(r.Context(), w, http.StatusBadRequest, codersdk.Response{ + Message: "API key and workspace agent token provided, but only one is allowed", + }) + return + } + if !hasAPIKey && !hasWorkspaceAgent { + httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{ + Message: "API key or workspace agent token required, but none provided", + }) + return + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/coderd/httpmw/workspaceagent.go b/coderd/httpmw/workspaceagent.go index 3bfe946c05fa2..f039c6bbf7afb 100644 --- a/coderd/httpmw/workspaceagent.go +++ b/coderd/httpmw/workspaceagent.go @@ -18,40 +18,67 @@ import ( type workspaceAgentContextKey struct{} +func WorkspaceAgentOptional(r *http.Request) (database.WorkspaceAgent, bool) { + user, ok := r.Context().Value(workspaceAgentContextKey{}).(database.WorkspaceAgent) + return user, ok +} + // WorkspaceAgent returns the workspace agent from the ExtractAgent handler. func WorkspaceAgent(r *http.Request) database.WorkspaceAgent { - user, ok := r.Context().Value(workspaceAgentContextKey{}).(database.WorkspaceAgent) + user, ok := WorkspaceAgentOptional(r) if !ok { - panic("developer error: agent middleware not provided") + panic("developer error: agent middleware not provided or was made optional") } return user } +type ExtractWorkspaceAgentConfig struct { + DB database.Store + // Optional indicates whether the middleware should be optional. If true, any + // requests without the a token or with an invalid token will be allowed to + // continue and no workspace agent will be set on the request context. + Optional bool +} + // ExtractWorkspaceAgent requires authentication using a valid agent token. -func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler { +func ExtractWorkspaceAgent(opts ExtractWorkspaceAgentConfig) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() + + // optionalWrite wraps httpapi.Write but runs the next handler if the + // token is optional. + // + // It should be used when the token is not provided or is invalid, but not + // when there are other errors. + optionalWrite := func(code int, response codersdk.Response) { + if opts.Optional { + next.ServeHTTP(rw, r) + return + } + httpapi.Write(ctx, rw, code, response) + } + tokenValue := APITokenFromRequest(r) if tokenValue == "" { - httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{ + optionalWrite(http.StatusUnauthorized, codersdk.Response{ Message: fmt.Sprintf("Cookie %q must be provided.", codersdk.SessionTokenCookie), }) return } token, err := uuid.Parse(tokenValue) if err != nil { - httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{ + optionalWrite(http.StatusUnauthorized, codersdk.Response{ Message: "Workspace agent token invalid.", Detail: fmt.Sprintf("An agent token must be a valid UUIDv4. (len %d)", len(tokenValue)), }) return } //nolint:gocritic // System needs to be able to get workspace agents. - agent, err := db.GetWorkspaceAgentByAuthToken(dbauthz.AsSystemRestricted(ctx), token) + agent, err := opts.DB.GetWorkspaceAgentByAuthToken(dbauthz.AsSystemRestricted(ctx), token) if err != nil { if errors.Is(err, sql.ErrNoRows) { - httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{ + optionalWrite(http.StatusUnauthorized, codersdk.Response{ Message: "Workspace agent not authorized.", Detail: "The agent cannot authenticate until the workspace provision job has been completed. If the job is no longer running, this agent is invalid.", }) @@ -66,7 +93,7 @@ func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler { } //nolint:gocritic // System needs to be able to get workspace agents. - subject, err := getAgentSubject(dbauthz.AsSystemRestricted(ctx), db, agent) + subject, err := getAgentSubject(dbauthz.AsSystemRestricted(ctx), opts.DB, agent) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspace agent.", diff --git a/coderd/httpmw/workspaceagent_test.go b/coderd/httpmw/workspaceagent_test.go index bcf6ee2f7e0e2..5b50aa14b4802 100644 --- a/coderd/httpmw/workspaceagent_test.go +++ b/coderd/httpmw/workspaceagent_test.go @@ -30,7 +30,10 @@ func TestWorkspaceAgent(t *testing.T) { db := dbfake.New() rtr := chi.NewRouter() rtr.Use( - httpmw.ExtractWorkspaceAgent(db), + httpmw.ExtractWorkspaceAgent(httpmw.ExtractWorkspaceAgentConfig{ + DB: db, + Optional: false, + }), ) rtr.Get("/", nil) r := setup(db, uuid.New()) @@ -65,7 +68,10 @@ func TestWorkspaceAgent(t *testing.T) { rtr := chi.NewRouter() rtr.Use( - httpmw.ExtractWorkspaceAgent(db), + httpmw.ExtractWorkspaceAgent(httpmw.ExtractWorkspaceAgentConfig{ + DB: db, + Optional: false, + }), ) rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { _ = httpmw.WorkspaceAgent(r) diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index 6fdecbcf7bf3f..34b92267080e5 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ -257,3 +257,7 @@ func (*client) PostStartup(_ context.Context, _ agentsdk.PostStartupRequest) err func (*client) PatchStartupLogs(_ context.Context, _ agentsdk.PatchStartupLogs) error { return nil } + +func (*client) GetServiceBanner(_ context.Context) (codersdk.ServiceBannerConfig, error) { + return codersdk.ServiceBannerConfig{}, nil +} diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index ac0211cf2d37e..6078cbc8c0b94 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -593,6 +593,24 @@ func (c *Client) PatchStartupLogs(ctx context.Context, req PatchStartupLogs) err return nil } +// GetServiceBanner relays the service banner config. +func (c *Client) GetServiceBanner(ctx context.Context) (codersdk.ServiceBannerConfig, error) { + res, err := c.SDK.Request(ctx, http.MethodGet, "/api/v2/appearance", nil) + if err != nil { + return codersdk.ServiceBannerConfig{}, err + } + defer res.Body.Close() + // If the route does not exist then Enterprise code is not enabled. + if res.StatusCode == http.StatusNotFound { + return codersdk.ServiceBannerConfig{}, nil + } + if res.StatusCode != http.StatusOK { + return codersdk.ServiceBannerConfig{}, codersdk.ReadBodyAsError(res) + } + var cfg codersdk.AppearanceConfig + return cfg.ServiceBanner, json.NewDecoder(res.Body).Decode(&cfg) +} + type GitAuthResponse struct { Username string `json:"username"` Password string `json:"password"` diff --git a/enterprise/coderd/appearance_test.go b/enterprise/coderd/appearance_test.go index b7772acee73c5..e654006414a1f 100644 --- a/enterprise/coderd/appearance_test.go +++ b/enterprise/coderd/appearance_test.go @@ -11,68 +11,136 @@ import ( "github.com/coder/coder/cli/clibase" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" + "github.com/coder/coder/codersdk/agentsdk" "github.com/coder/coder/enterprise/coderd" "github.com/coder/coder/enterprise/coderd/coderdenttest" "github.com/coder/coder/enterprise/coderd/license" + "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/testutil" + "github.com/google/uuid" ) func TestServiceBanners(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - adminClient := coderdenttest.New(t, &coderdenttest.Options{}) - - adminUser := coderdtest.CreateFirstUser(t, adminClient) - - // Even without a license, the banner should return as disabled. - sb, err := adminClient.Appearance(ctx) - require.NoError(t, err) - require.False(t, sb.ServiceBanner.Enabled) - - coderdenttest.AddLicense(t, adminClient, coderdenttest.LicenseOptions{ - Features: license.Features{ - codersdk.FeatureAppearance: 1, - }, + t.Run("User", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + adminClient := coderdenttest.New(t, &coderdenttest.Options{}) + + adminUser := coderdtest.CreateFirstUser(t, adminClient) + + // Even without a license, the banner should return as disabled. + sb, err := adminClient.Appearance(ctx) + require.NoError(t, err) + require.False(t, sb.ServiceBanner.Enabled) + + coderdenttest.AddLicense(t, adminClient, coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureAppearance: 1, + }, + }) + + // Default state + sb, err = adminClient.Appearance(ctx) + require.NoError(t, err) + require.False(t, sb.ServiceBanner.Enabled) + + basicUserClient, _ := coderdtest.CreateAnotherUser(t, adminClient, adminUser.OrganizationID) + + uac := codersdk.UpdateAppearanceConfig{ + ServiceBanner: sb.ServiceBanner, + } + // Regular user should be unable to set the banner + uac.ServiceBanner.Enabled = true + + err = basicUserClient.UpdateAppearance(ctx, uac) + require.Error(t, err) + var sdkError *codersdk.Error + require.True(t, errors.As(err, &sdkError)) + require.Equal(t, http.StatusForbidden, sdkError.StatusCode()) + + // But an admin can + wantBanner := uac + wantBanner.ServiceBanner.Enabled = true + wantBanner.ServiceBanner.Message = "Hey" + wantBanner.ServiceBanner.BackgroundColor = "#00FF00" + err = adminClient.UpdateAppearance(ctx, wantBanner) + require.NoError(t, err) + gotBanner, err := adminClient.Appearance(ctx) + require.NoError(t, err) + gotBanner.SupportLinks = nil // clean "support links" before comparison + require.Equal(t, wantBanner.ServiceBanner, gotBanner.ServiceBanner) + + // But even an admin can't give a bad color + wantBanner.ServiceBanner.BackgroundColor = "#bad color" + err = adminClient.UpdateAppearance(ctx, wantBanner) + require.Error(t, err) }) - // Default state - sb, err = adminClient.Appearance(ctx) - require.NoError(t, err) - require.False(t, sb.ServiceBanner.Enabled) - - basicUserClient, _ := coderdtest.CreateAnotherUser(t, adminClient, adminUser.OrganizationID) - - uac := codersdk.UpdateAppearanceConfig{ - ServiceBanner: sb.ServiceBanner, - } - // Regular user should be unable to set the banner - uac.ServiceBanner.Enabled = true - - err = basicUserClient.UpdateAppearance(ctx, uac) - require.Error(t, err) - var sdkError *codersdk.Error - require.True(t, errors.As(err, &sdkError)) - require.Equal(t, http.StatusForbidden, sdkError.StatusCode()) - - // But an admin can - wantBanner := uac - wantBanner.ServiceBanner.Enabled = true - wantBanner.ServiceBanner.Message = "Hey" - wantBanner.ServiceBanner.BackgroundColor = "#00FF00" - err = adminClient.UpdateAppearance(ctx, wantBanner) - require.NoError(t, err) - gotBanner, err := adminClient.Appearance(ctx) - require.NoError(t, err) - gotBanner.SupportLinks = nil // clean "support links" before comparison - require.Equal(t, wantBanner.ServiceBanner, gotBanner.ServiceBanner) - - // But even an admin can't give a bad color - wantBanner.ServiceBanner.BackgroundColor = "#bad color" - err = adminClient.UpdateAppearance(ctx, wantBanner) - require.Error(t, err) + t.Run("Agent", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + client := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + IncludeProvisionerDaemon: true, + }, + }) + user := coderdtest.CreateFirstUser(t, client) + license := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureAppearance: 1, + }, + }) + cfg := codersdk.UpdateAppearanceConfig{ + ServiceBanner: codersdk.ServiceBannerConfig{ + Enabled: true, + Message: "Hey", + BackgroundColor: "#00FF00", + }, + } + err := client.UpdateAppearance(ctx, cfg) + require.NoError(t, err) + + authToken := uuid.NewString() + agentClient := agentsdk.New(client.URL) + agentClient.SetSessionToken(authToken) + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlan: echo.ProvisionComplete, + ProvisionApply: echo.ProvisionApplyWithAgent(authToken), + }) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) + coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + + banner, err := agentClient.GetServiceBanner(ctx) + require.NoError(t, err) + require.Equal(t, cfg.ServiceBanner, banner) + + // No enterprise means a 404 on the endpoint meaning no banner. + client = coderdtest.New(t, &coderdtest.Options{ + IncludeProvisionerDaemon: true, + }) + agentClient = agentsdk.New(client.URL) + agentClient.SetSessionToken(authToken) + banner, err = agentClient.GetServiceBanner(ctx) + require.NoError(t, err) + require.Equal(t, codersdk.ServiceBannerConfig{}, banner) + + // No license means no banner. + client.DeleteLicense(ctx, license.ID) + banner, err = agentClient.GetServiceBanner(ctx) + require.NoError(t, err) + require.Equal(t, codersdk.ServiceBannerConfig{}, banner) + }) } func TestCustomSupportLinks(t *testing.T) { diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 1df66af44a5e6..117bce226b46a 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -78,6 +78,12 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { OAuth2Configs: oauthConfigs, RedirectToLogin: false, }) + apiKeyMiddlewareOptional := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: options.Database, + OAuth2Configs: oauthConfigs, + RedirectToLogin: false, + Optional: true, + }) deploymentID, err := options.Database.GetDeploymentID(ctx) if err != nil { @@ -192,11 +198,23 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { }) }) r.Route("/appearance", func(r chi.Router) { - r.Use( - apiKeyMiddleware, - ) - r.Get("/", api.appearance) - r.Put("/", api.putAppearance) + r.Group(func(r chi.Router) { + r.Use( + apiKeyMiddlewareOptional, + httpmw.ExtractWorkspaceAgent(httpmw.ExtractWorkspaceAgentConfig{ + DB: options.Database, + Optional: true, + }), + httpmw.RequireAPIKeyOrWorkspaceAgent(), + ) + r.Get("/", api.appearance) + }) + r.Group(func(r chi.Router) { + r.Use( + apiKeyMiddleware, + ) + r.Put("/", api.putAppearance) + }) }) })