diff --git a/agent/agent.go b/agent/agent.go index 28ea524bf3da3..8ff6d68d25f0b 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -340,7 +340,7 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM // if it can guarantee the clocks are synchronized. CollectedAt: now, } - cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil) + cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil, nil) if err != nil { result.Error = fmt.Sprintf("create cmd: %+v", err) return result diff --git a/agent/agentscripts/agentscripts.go b/agent/agentscripts/agentscripts.go index 25ea0ba46fcf3..bd83d71875c73 100644 --- a/agent/agentscripts/agentscripts.go +++ b/agent/agentscripts/agentscripts.go @@ -283,7 +283,7 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript, cmdCtx, ctxCancel = context.WithTimeout(ctx, script.Timeout) defer ctxCancel() } - cmdPty, err := r.SSHServer.CreateCommand(cmdCtx, script.Script, nil) + cmdPty, err := r.SSHServer.CreateCommand(cmdCtx, script.Script, nil, nil) if err != nil { return xerrors.Errorf("%s script: create command: %w", logPath, err) } diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index dae1b73b2de6c..d17e9cd761fe6 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -409,7 +409,7 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv magicTypeLabel := magicTypeMetricLabel(magicType) sshPty, windowSize, isPty := session.Pty() - cmd, err := s.CreateCommand(ctx, session.RawCommand(), env) + cmd, err := s.CreateCommand(ctx, session.RawCommand(), env, nil) if err != nil { ptyLabel := "no" if isPty { @@ -670,17 +670,63 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) { _ = session.Exit(1) } +// EnvInfoer encapsulates external information required by CreateCommand. +type EnvInfoer interface { + // CurrentUser returns the current user. + CurrentUser() (*user.User, error) + // Environ returns the environment variables of the current process. + Environ() []string + // UserHomeDir returns the home directory of the current user. + UserHomeDir() (string, error) + // UserShell returns the shell of the given user. + UserShell(username string) (string, error) +} + +type systemEnvInfoer struct{} + +var defaultEnvInfoer EnvInfoer = &systemEnvInfoer{} + +// DefaultEnvInfoer returns a default implementation of +// EnvInfoer. This reads information using the default Go +// implementations. +func DefaultEnvInfoer() EnvInfoer { + return defaultEnvInfoer +} + +func (systemEnvInfoer) CurrentUser() (*user.User, error) { + return user.Current() +} + +func (systemEnvInfoer) Environ() []string { + return os.Environ() +} + +func (systemEnvInfoer) UserHomeDir() (string, error) { + return userHomeDir() +} + +func (systemEnvInfoer) UserShell(username string) (string, error) { + return usershell.Get(username) +} + // CreateCommand processes raw command input with OpenSSH-like behavior. // If the script provided is empty, it will default to the users shell. // This injects environment variables specified by the user at launch too. -func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*pty.Cmd, error) { - currentUser, err := user.Current() +// The final argument is an interface that allows the caller to provide +// alternative implementations for the dependencies of CreateCommand. +// This is useful when creating a command to be run in a separate environment +// (for example, a Docker container). Pass in nil to use the default. +func (s *Server) CreateCommand(ctx context.Context, script string, env []string, deps EnvInfoer) (*pty.Cmd, error) { + if deps == nil { + deps = DefaultEnvInfoer() + } + currentUser, err := deps.CurrentUser() if err != nil { return nil, xerrors.Errorf("get current user: %w", err) } username := currentUser.Username - shell, err := usershell.Get(username) + shell, err := deps.UserShell(username) if err != nil { return nil, xerrors.Errorf("get user shell: %w", err) } @@ -736,13 +782,13 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string) _, err = os.Stat(cmd.Dir) if cmd.Dir == "" || err != nil { // Default to user home if a directory is not set. - homedir, err := userHomeDir() + homedir, err := deps.UserHomeDir() if err != nil { return nil, xerrors.Errorf("get home dir: %w", err) } cmd.Dir = homedir } - cmd.Env = append(os.Environ(), env...) + cmd.Env = append(deps.Environ(), env...) cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username)) // Set SSH connection environment variables (these are also set by OpenSSH diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index 76321e6e19d85..b9cec420e5651 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "net" + "os/user" "runtime" "strings" "sync" @@ -87,7 +88,7 @@ func TestNewServer_ExecuteShebang(t *testing.T) { t.Run("Basic", func(t *testing.T) { t.Parallel() cmd, err := s.CreateCommand(ctx, `#!/bin/bash - echo test`, nil) + echo test`, nil, nil) require.NoError(t, err) output, err := cmd.AsExec().CombinedOutput() require.NoError(t, err) @@ -96,12 +97,45 @@ func TestNewServer_ExecuteShebang(t *testing.T) { t.Run("Args", func(t *testing.T) { t.Parallel() cmd, err := s.CreateCommand(ctx, `#!/usr/bin/env bash - echo test`, nil) + echo test`, nil, nil) require.NoError(t, err) output, err := cmd.AsExec().CombinedOutput() require.NoError(t, err) require.Equal(t, "test\n", string(output)) }) + t.Run("CustomEnvInfoer", func(t *testing.T) { + t.Parallel() + ei := &fakeEnvInfoer{ + CurrentUserFn: func() (u *user.User, err error) { + return nil, assert.AnError + }, + } + _, err := s.CreateCommand(ctx, `whatever`, nil, ei) + require.ErrorIs(t, err, assert.AnError) + }) +} + +type fakeEnvInfoer struct { + CurrentUserFn func() (*user.User, error) + EnvironFn func() []string + UserHomeDirFn func() (string, error) + UserShellFn func(string) (string, error) +} + +func (f *fakeEnvInfoer) CurrentUser() (u *user.User, err error) { + return f.CurrentUserFn() +} + +func (f *fakeEnvInfoer) Environ() []string { + return f.EnvironFn() +} + +func (f *fakeEnvInfoer) UserHomeDir() (string, error) { + return f.UserHomeDirFn() +} + +func (f *fakeEnvInfoer) UserShell(u string) (string, error) { + return f.UserShellFn(u) } func TestNewServer_CloseActiveConnections(t *testing.T) { diff --git a/agent/reconnectingpty/server.go b/agent/reconnectingpty/server.go index d48c7abec9353..465667c616180 100644 --- a/agent/reconnectingpty/server.go +++ b/agent/reconnectingpty/server.go @@ -159,7 +159,7 @@ func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Co }() // Empty command will default to the users shell! - cmd, err := s.commandCreator.CreateCommand(ctx, msg.Command, nil) + cmd, err := s.commandCreator.CreateCommand(ctx, msg.Command, nil, nil) if err != nil { s.errorsTotal.WithLabelValues("create_command").Add(1) return xerrors.Errorf("create command: %w", err)