Skip to content

test(agent): use afero for motd tests to allow parallel execution #8329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentL
// not be fetched immediately; the expectation is that it is primed elsewhere
// (and must be done before the session actually starts).
func (a *agent) fetchServiceBannerLoop(ctx context.Context) {
ticker := time.NewTicker(adjustIntervalForTests(2*time.Minute, time.Millisecond*100))
ticker := time.NewTicker(adjustIntervalForTests(2*time.Minute, time.Millisecond*5))
defer ticker.Stop()
for {
select {
Expand Down
104 changes: 55 additions & 49 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func TestAgent_Stats_Magic(t *testing.T) {

func TestAgent_SessionExec(t *testing.T) {
t.Parallel()
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{})
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)

command := "echo test"
if runtime.GOOS == "windows" {
Expand All @@ -205,7 +205,7 @@ func TestAgent_SessionExec(t *testing.T) {

func TestAgent_GitSSH(t *testing.T) {
t.Parallel()
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{})
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
command := "sh -c 'echo $GIT_SSH_COMMAND'"
if runtime.GOOS == "windows" {
command = "cmd.exe /c echo %GIT_SSH_COMMAND%"
Expand All @@ -225,7 +225,7 @@ func TestAgent_SessionTTYShell(t *testing.T) {
// it seems like it could be either.
t.Skip("ConPTY appears to be inconsistent on Windows.")
}
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{})
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
command := "sh"
if runtime.GOOS == "windows" {
command = "cmd.exe"
Expand All @@ -248,7 +248,7 @@ func TestAgent_SessionTTYShell(t *testing.T) {

func TestAgent_SessionTTYExitCode(t *testing.T) {
t.Parallel()
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{})
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
command := "areallynotrealcommand"
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
require.NoError(t, err)
Expand All @@ -268,25 +268,22 @@ func TestAgent_SessionTTYExitCode(t *testing.T) {
}
}

//nolint:paralleltest // This test sets an environment variable.
func TestAgent_Session_TTY_MOTD(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
// This might be our implementation, or ConPTY itself.
// It's difficult to find extensive tests for it, so
// it seems like it could be either.
t.Skip("ConPTY appears to be inconsistent on Windows.")
}

wantMOTD := "Welcome to your Coder workspace!"
wantServiceBanner := "Service banner text goes here"
u, err := user.Current()
require.NoError(t, err, "get current user")

tmpdir := t.TempDir()
name := filepath.Join(tmpdir, "motd")
err := os.WriteFile(name, []byte(wantMOTD), 0o600)
require.NoError(t, err, "write motd file")
name := filepath.Join(u.HomeDir, "motd")

// Set HOME so we can ensure no ~/.hushlogin is present.
t.Setenv("HOME", tmpdir)
wantMOTD := "Welcome to your Coder workspace!"
wantServiceBanner := "Service banner text goes here"

tests := []struct {
name string
Expand Down Expand Up @@ -362,14 +359,20 @@ func TestAgent_Session_TTY_MOTD(t *testing.T) {
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
session := setupSSHSession(t, test.manifest, test.banner)
t.Parallel()
session := setupSSHSession(t, test.manifest, test.banner, func(fs afero.Fs) {
err := fs.MkdirAll(filepath.Dir(name), 0o700)
require.NoError(t, err)
err = afero.WriteFile(fs, name, []byte(wantMOTD), 0o600)
require.NoError(t, err)
})
testSessionOutput(t, session, test.expected, test.unexpected, test.expectedRe)
})
}
}

//nolint:paralleltest // This test sets an environment variable.
func TestAgent_Session_TTY_MOTD_Update(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
// This might be our implementation, or ConPTY itself.
// It's difficult to find extensive tests for it, so
Expand All @@ -380,11 +383,6 @@ func TestAgent_Session_TTY_MOTD_Update(t *testing.T) {
// Only the banner updates dynamically; the MOTD file does not.
wantServiceBanner := "Service banner text goes here"

tmpdir := t.TempDir()

// Set HOME so we can ensure no ~/.hushlogin is present.
t.Setenv("HOME", tmpdir)

tests := []struct {
banner codersdk.ServiceBannerConfig
expected []string
Expand Down Expand Up @@ -424,22 +422,25 @@ func TestAgent_Session_TTY_MOTD_Update(t *testing.T) {
},
}

const updateInterval = 100 * time.Millisecond

ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled // Allow the blank identifiers.
conn, client, _, _, _ := setupAgent(t, &client{}, 0)
for _, test := range tests {
test := test
// Set new banner func and wait for the agent to call it to update the
// banner.

ready := make(chan struct{}, 2)
client.mu.Lock()
client.getServiceBanner = func() (codersdk.ServiceBannerConfig, error) {
select {
case ready <- struct{}{}:
default:
}
return test.banner, nil
}
client.mu.Unlock()
time.Sleep(updateInterval)
<-ready
<-ready // Wait for two updates to ensure the value has propagated.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review: This is the race fix, if we only wait for one ready, the service banner may not yet be stored in the agents atomic pointer.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh good find! Thank you for fixing!


sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
Expand All @@ -466,50 +467,51 @@ func TestAgent_Session_TTY_QuietLogin(t *testing.T) {
}

wantNotMOTD := "Welcome to your Coder workspace!"
wantServiceBanner := "Service banner text goes here"
wantMaybeServiceBanner := "Service banner text goes here"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 I like the use of maybe; I wish I had thought of that.


tmpdir := t.TempDir()
name := filepath.Join(tmpdir, "motd")
err := os.WriteFile(name, []byte(wantNotMOTD), 0o600)
require.NoError(t, err, "write motd file")
u, err := user.Current()
require.NoError(t, err, "get current user")

// Set HOME so we can ensure ~/.hushlogin is present.
t.Setenv("HOME", tmpdir)
name := filepath.Join(u.HomeDir, "motd")

// Neither banner nor MOTD should show if not a login shell.
t.Run("NotLogin", func(t *testing.T) {
wantEcho := "foobar"
session := setupSSHSession(t, agentsdk.Manifest{
MOTDFile: name,
}, codersdk.ServiceBannerConfig{
Enabled: true,
Message: wantServiceBanner,
Message: wantMaybeServiceBanner,
}, func(fs afero.Fs) {
err := afero.WriteFile(fs, name, []byte(wantNotMOTD), 0o600)
require.NoError(t, err, "write motd file")
})
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
require.NoError(t, err)

wantEcho := "foobar"
command := "echo " + wantEcho
output, err := session.Output(command)
require.NoError(t, err)

require.Contains(t, string(output), wantEcho, "should show echo")
require.NotContains(t, string(output), wantNotMOTD, "should not show motd")
require.NotContains(t, string(output), wantServiceBanner, "should not show service banner")
require.NotContains(t, string(output), wantMaybeServiceBanner, "should not show service banner")
})

// Only the MOTD should be silenced.
// Only the MOTD should be silenced when hushlogin is present.
t.Run("Hushlogin", func(t *testing.T) {
// Create hushlogin to silence motd.
f, err := os.Create(filepath.Join(tmpdir, ".hushlogin"))
require.NoError(t, err, "create .hushlogin file")
err = f.Close()
require.NoError(t, err, "close .hushlogin file")

session := setupSSHSession(t, agentsdk.Manifest{
MOTDFile: name,
}, codersdk.ServiceBannerConfig{
Enabled: true,
Message: wantServiceBanner,
Message: wantMaybeServiceBanner,
}, func(fs afero.Fs) {
err := afero.WriteFile(fs, name, []byte(wantNotMOTD), 0o600)
require.NoError(t, err, "write motd file")

// Create hushlogin to silence motd.
err = afero.WriteFile(fs, name, []byte{}, 0o600)
require.NoError(t, err, "write hushlogin file")
})
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
require.NoError(t, err)
Expand All @@ -527,7 +529,7 @@ func TestAgent_Session_TTY_QuietLogin(t *testing.T) {
require.NoError(t, err)

require.NotContains(t, stdout.String(), wantNotMOTD, "should not show motd")
require.Contains(t, stdout.String(), wantServiceBanner, "should show service banner")
require.Contains(t, stdout.String(), wantMaybeServiceBanner, "should show service banner")
})
}

Expand Down Expand Up @@ -975,7 +977,7 @@ func TestAgent_EnvironmentVariables(t *testing.T) {
EnvironmentVariables: map[string]string{
key: value,
},
}, codersdk.ServiceBannerConfig{})
}, codersdk.ServiceBannerConfig{}, nil)
command := "sh -c 'echo $" + key + "'"
if runtime.GOOS == "windows" {
command = "cmd.exe /c echo %" + key + "%"
Expand All @@ -992,7 +994,7 @@ func TestAgent_EnvironmentVariableExpansion(t *testing.T) {
EnvironmentVariables: map[string]string{
key: "$SOMETHINGNOTSET",
},
}, codersdk.ServiceBannerConfig{})
}, codersdk.ServiceBannerConfig{}, nil)
command := "sh -c 'echo $" + key + "'"
if runtime.GOOS == "windows" {
command = "cmd.exe /c echo %" + key + "%"
Expand All @@ -1015,7 +1017,7 @@ func TestAgent_CoderEnvVars(t *testing.T) {
t.Run(key, func(t *testing.T) {
t.Parallel()

session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{})
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
command := "sh -c 'echo $" + key + "'"
if runtime.GOOS == "windows" {
command = "cmd.exe /c echo %" + key + "%"
Expand All @@ -1038,7 +1040,7 @@ func TestAgent_SSHConnectionEnvVars(t *testing.T) {
t.Run(key, func(t *testing.T) {
t.Parallel()

session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{})
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
command := "sh -c 'echo $" + key + "'"
if runtime.GOOS == "windows" {
command = "cmd.exe /c echo %" + key + "%"
Expand Down Expand Up @@ -1876,16 +1878,20 @@ func setupSSHSession(
t *testing.T,
options agentsdk.Manifest,
serviceBanner codersdk.ServiceBannerConfig,
prepareFS func(fs afero.Fs),
) *ssh.Session {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, &client{
conn, _, _, fs, _ := setupAgent(t, &client{
manifest: options,
getServiceBanner: func() (codersdk.ServiceBannerConfig, error) {
return serviceBanner, nil
},
}, 0)
if prepareFS != nil {
prepareFS(fs)
}
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
t.Cleanup(func() {
Expand Down
6 changes: 3 additions & 3 deletions agent/agentssh/agentssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ func (s *Server) startPTYSession(session ptySession, magicTypeLabel string, cmd
if !isQuietLogin(session.RawCommand()) {
manifest := s.Manifest.Load()
if manifest != nil {
err := showMOTD(session, manifest.MOTDFile)
err := showMOTD(s.fs, session, manifest.MOTDFile)
if err != nil {
s.logger.Error(ctx, "agent failed to show MOTD", slog.Error(err))
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "motd").Add(1)
Expand Down Expand Up @@ -796,12 +796,12 @@ func showServiceBanner(session io.Writer, banner *codersdk.ServiceBannerConfig)
// the given filename to dest, if the file exists.
//
// https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L784
func showMOTD(dest io.Writer, filename string) error {
func showMOTD(fs afero.Fs, dest io.Writer, filename string) error {
if filename == "" {
return nil
}

f, err := os.Open(filename)
f, err := fs.Open(filename)
if err != nil {
if xerrors.Is(err, os.ErrNotExist) {
// This is not an error, there simply isn't a MOTD to show.
Expand Down