Skip to content

Commit 7b28cfa

Browse files
committed
chore(agent/agentssh): extract CreateCommandDeps
1 parent 7fd04d4 commit 7b28cfa

File tree

5 files changed

+53
-11
lines changed

5 files changed

+53
-11
lines changed

agent/agent.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM
340340
// if it can guarantee the clocks are synchronized.
341341
CollectedAt: now,
342342
}
343-
cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
343+
cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil, nil)
344344
if err != nil {
345345
result.Error = fmt.Sprintf("create cmd: %+v", err)
346346
return result

agent/agentscripts/agentscripts.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript,
283283
cmdCtx, ctxCancel = context.WithTimeout(ctx, script.Timeout)
284284
defer ctxCancel()
285285
}
286-
cmdPty, err := r.SSHServer.CreateCommand(cmdCtx, script.Script, nil)
286+
cmdPty, err := r.SSHServer.CreateCommand(cmdCtx, script.Script, nil, nil)
287287
if err != nil {
288288
return xerrors.Errorf("%s script: create command: %w", logPath, err)
289289
}

agent/agentssh/agentssh.go

+48-6
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv
409409
magicTypeLabel := magicTypeMetricLabel(magicType)
410410
sshPty, windowSize, isPty := session.Pty()
411411

412-
cmd, err := s.CreateCommand(ctx, session.RawCommand(), env)
412+
cmd, err := s.CreateCommand(ctx, session.RawCommand(), env, nil)
413413
if err != nil {
414414
ptyLabel := "no"
415415
if isPty {
@@ -670,17 +670,59 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
670670
_ = session.Exit(1)
671671
}
672672

673+
// CreateCommandDeps encapsulates external information required by CreateCommand.
674+
type CreateCommandDeps interface {
675+
// CurrentUser returns the current user.
676+
CurrentUser() (*user.User, error)
677+
// Environ returns the environment variables of the current process.
678+
Environ() []string
679+
// UserHomeDir returns the home directory of the current user.
680+
UserHomeDir() (string, error)
681+
// UserShell returns the shell of the given user.
682+
UserShell(username string) (string, error)
683+
}
684+
685+
type systemCreateCommandDeps struct{}
686+
687+
var defaultCreateCommandDeps CreateCommandDeps = &systemCreateCommandDeps{}
688+
689+
// DefaultCreateCommandDeps returns a default implementation of
690+
// CreateCommandDeps. This reads information using the default Go
691+
// implementations.
692+
func DefaultCreateCommandDeps() CreateCommandDeps {
693+
return defaultCreateCommandDeps
694+
}
695+
func (systemCreateCommandDeps) CurrentUser() (*user.User, error) {
696+
return user.Current()
697+
}
698+
func (systemCreateCommandDeps) Environ() []string {
699+
return os.Environ()
700+
}
701+
func (systemCreateCommandDeps) UserHomeDir() (string, error) {
702+
return userHomeDir()
703+
}
704+
func (systemCreateCommandDeps) UserShell(username string) (string, error) {
705+
return usershell.Get(username)
706+
}
707+
673708
// CreateCommand processes raw command input with OpenSSH-like behavior.
674709
// If the script provided is empty, it will default to the users shell.
675710
// This injects environment variables specified by the user at launch too.
676-
func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*pty.Cmd, error) {
677-
currentUser, err := user.Current()
711+
// The final argument is an interface that allows the caller to provide
712+
// alternative implementations for the dependencies of CreateCommand.
713+
// This is useful when creating a command to be run in a separate environment
714+
// (for example, a Docker container). Pass in nil to use the default.
715+
func (s *Server) CreateCommand(ctx context.Context, script string, env []string, deps CreateCommandDeps) (*pty.Cmd, error) {
716+
if deps == nil {
717+
deps = DefaultCreateCommandDeps()
718+
}
719+
currentUser, err := deps.CurrentUser()
678720
if err != nil {
679721
return nil, xerrors.Errorf("get current user: %w", err)
680722
}
681723
username := currentUser.Username
682724

683-
shell, err := usershell.Get(username)
725+
shell, err := deps.UserShell(username)
684726
if err != nil {
685727
return nil, xerrors.Errorf("get user shell: %w", err)
686728
}
@@ -736,13 +778,13 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
736778
_, err = os.Stat(cmd.Dir)
737779
if cmd.Dir == "" || err != nil {
738780
// Default to user home if a directory is not set.
739-
homedir, err := userHomeDir()
781+
homedir, err := deps.UserHomeDir()
740782
if err != nil {
741783
return nil, xerrors.Errorf("get home dir: %w", err)
742784
}
743785
cmd.Dir = homedir
744786
}
745-
cmd.Env = append(os.Environ(), env...)
787+
cmd.Env = append(deps.Environ(), env...)
746788
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username))
747789

748790
// Set SSH connection environment variables (these are also set by OpenSSH

agent/agentssh/agentssh_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func TestNewServer_ExecuteShebang(t *testing.T) {
8787
t.Run("Basic", func(t *testing.T) {
8888
t.Parallel()
8989
cmd, err := s.CreateCommand(ctx, `#!/bin/bash
90-
echo test`, nil)
90+
echo test`, nil, nil)
9191
require.NoError(t, err)
9292
output, err := cmd.AsExec().CombinedOutput()
9393
require.NoError(t, err)
@@ -96,7 +96,7 @@ func TestNewServer_ExecuteShebang(t *testing.T) {
9696
t.Run("Args", func(t *testing.T) {
9797
t.Parallel()
9898
cmd, err := s.CreateCommand(ctx, `#!/usr/bin/env bash
99-
echo test`, nil)
99+
echo test`, nil, nil)
100100
require.NoError(t, err)
101101
output, err := cmd.AsExec().CombinedOutput()
102102
require.NoError(t, err)

agent/reconnectingpty/server.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Co
159159
}()
160160

161161
// Empty command will default to the users shell!
162-
cmd, err := s.commandCreator.CreateCommand(ctx, msg.Command, nil)
162+
cmd, err := s.commandCreator.CreateCommand(ctx, msg.Command, nil, nil)
163163
if err != nil {
164164
s.errorsTotal.WithLabelValues("create_command").Add(1)
165165
return xerrors.Errorf("create command: %w", err)

0 commit comments

Comments
 (0)