Skip to content

Commit 5801da2

Browse files
johnstcnaslilac
authored andcommitted
chore(agent/agentssh): extract CreateCommandDeps (#16603)
Extracts environment-level dependencies of `agentssh.Server.CreateCommand()` to an interface to allow alternative implementations to be passed in.
1 parent 5aa526f commit 5801da2

File tree

5 files changed

+91
-11
lines changed

5 files changed

+91
-11
lines changed

agent/agent.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM
340340
// if it can guarantee the clocks are synchronized.
341341
CollectedAt: now,
342342
}
343-
cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
343+
cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil, nil)
344344
if err != nil {
345345
result.Error = fmt.Sprintf("create cmd: %+v", err)
346346
return result

agent/agentscripts/agentscripts.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript,
283283
cmdCtx, ctxCancel = context.WithTimeout(ctx, script.Timeout)
284284
defer ctxCancel()
285285
}
286-
cmdPty, err := r.SSHServer.CreateCommand(cmdCtx, script.Script, nil)
286+
cmdPty, err := r.SSHServer.CreateCommand(cmdCtx, script.Script, nil, nil)
287287
if err != nil {
288288
return xerrors.Errorf("%s script: create command: %w", logPath, err)
289289
}

agent/agentssh/agentssh.go

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv
409409
magicTypeLabel := magicTypeMetricLabel(magicType)
410410
sshPty, windowSize, isPty := session.Pty()
411411

412-
cmd, err := s.CreateCommand(ctx, session.RawCommand(), env)
412+
cmd, err := s.CreateCommand(ctx, session.RawCommand(), env, nil)
413413
if err != nil {
414414
ptyLabel := "no"
415415
if isPty {
@@ -670,17 +670,63 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
670670
_ = session.Exit(1)
671671
}
672672

673+
// EnvInfoer encapsulates external information required by CreateCommand.
674+
type EnvInfoer interface {
675+
// CurrentUser returns the current user.
676+
CurrentUser() (*user.User, error)
677+
// Environ returns the environment variables of the current process.
678+
Environ() []string
679+
// UserHomeDir returns the home directory of the current user.
680+
UserHomeDir() (string, error)
681+
// UserShell returns the shell of the given user.
682+
UserShell(username string) (string, error)
683+
}
684+
685+
type systemEnvInfoer struct{}
686+
687+
var defaultEnvInfoer EnvInfoer = &systemEnvInfoer{}
688+
689+
// DefaultEnvInfoer returns a default implementation of
690+
// EnvInfoer. This reads information using the default Go
691+
// implementations.
692+
func DefaultEnvInfoer() EnvInfoer {
693+
return defaultEnvInfoer
694+
}
695+
696+
func (systemEnvInfoer) CurrentUser() (*user.User, error) {
697+
return user.Current()
698+
}
699+
700+
func (systemEnvInfoer) Environ() []string {
701+
return os.Environ()
702+
}
703+
704+
func (systemEnvInfoer) UserHomeDir() (string, error) {
705+
return userHomeDir()
706+
}
707+
708+
func (systemEnvInfoer) UserShell(username string) (string, error) {
709+
return usershell.Get(username)
710+
}
711+
673712
// CreateCommand processes raw command input with OpenSSH-like behavior.
674713
// If the script provided is empty, it will default to the users shell.
675714
// This injects environment variables specified by the user at launch too.
676-
func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*pty.Cmd, error) {
677-
currentUser, err := user.Current()
715+
// The final argument is an interface that allows the caller to provide
716+
// alternative implementations for the dependencies of CreateCommand.
717+
// This is useful when creating a command to be run in a separate environment
718+
// (for example, a Docker container). Pass in nil to use the default.
719+
func (s *Server) CreateCommand(ctx context.Context, script string, env []string, deps EnvInfoer) (*pty.Cmd, error) {
720+
if deps == nil {
721+
deps = DefaultEnvInfoer()
722+
}
723+
currentUser, err := deps.CurrentUser()
678724
if err != nil {
679725
return nil, xerrors.Errorf("get current user: %w", err)
680726
}
681727
username := currentUser.Username
682728

683-
shell, err := usershell.Get(username)
729+
shell, err := deps.UserShell(username)
684730
if err != nil {
685731
return nil, xerrors.Errorf("get user shell: %w", err)
686732
}
@@ -736,13 +782,13 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
736782
_, err = os.Stat(cmd.Dir)
737783
if cmd.Dir == "" || err != nil {
738784
// Default to user home if a directory is not set.
739-
homedir, err := userHomeDir()
785+
homedir, err := deps.UserHomeDir()
740786
if err != nil {
741787
return nil, xerrors.Errorf("get home dir: %w", err)
742788
}
743789
cmd.Dir = homedir
744790
}
745-
cmd.Env = append(os.Environ(), env...)
791+
cmd.Env = append(deps.Environ(), env...)
746792
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username))
747793

748794
// Set SSH connection environment variables (these are also set by OpenSSH

agent/agentssh/agentssh_test.go

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"context"
99
"fmt"
1010
"net"
11+
"os/user"
1112
"runtime"
1213
"strings"
1314
"sync"
@@ -87,7 +88,7 @@ func TestNewServer_ExecuteShebang(t *testing.T) {
8788
t.Run("Basic", func(t *testing.T) {
8889
t.Parallel()
8990
cmd, err := s.CreateCommand(ctx, `#!/bin/bash
90-
echo test`, nil)
91+
echo test`, nil, nil)
9192
require.NoError(t, err)
9293
output, err := cmd.AsExec().CombinedOutput()
9394
require.NoError(t, err)
@@ -96,12 +97,45 @@ func TestNewServer_ExecuteShebang(t *testing.T) {
9697
t.Run("Args", func(t *testing.T) {
9798
t.Parallel()
9899
cmd, err := s.CreateCommand(ctx, `#!/usr/bin/env bash
99-
echo test`, nil)
100+
echo test`, nil, nil)
100101
require.NoError(t, err)
101102
output, err := cmd.AsExec().CombinedOutput()
102103
require.NoError(t, err)
103104
require.Equal(t, "test\n", string(output))
104105
})
106+
t.Run("CustomEnvInfoer", func(t *testing.T) {
107+
t.Parallel()
108+
ei := &fakeEnvInfoer{
109+
CurrentUserFn: func() (u *user.User, err error) {
110+
return nil, assert.AnError
111+
},
112+
}
113+
_, err := s.CreateCommand(ctx, `whatever`, nil, ei)
114+
require.ErrorIs(t, err, assert.AnError)
115+
})
116+
}
117+
118+
type fakeEnvInfoer struct {
119+
CurrentUserFn func() (*user.User, error)
120+
EnvironFn func() []string
121+
UserHomeDirFn func() (string, error)
122+
UserShellFn func(string) (string, error)
123+
}
124+
125+
func (f *fakeEnvInfoer) CurrentUser() (u *user.User, err error) {
126+
return f.CurrentUserFn()
127+
}
128+
129+
func (f *fakeEnvInfoer) Environ() []string {
130+
return f.EnvironFn()
131+
}
132+
133+
func (f *fakeEnvInfoer) UserHomeDir() (string, error) {
134+
return f.UserHomeDirFn()
135+
}
136+
137+
func (f *fakeEnvInfoer) UserShell(u string) (string, error) {
138+
return f.UserShellFn(u)
105139
}
106140

107141
func TestNewServer_CloseActiveConnections(t *testing.T) {

agent/reconnectingpty/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Co
159159
}()
160160

161161
// Empty command will default to the users shell!
162-
cmd, err := s.commandCreator.CreateCommand(ctx, msg.Command, nil)
162+
cmd, err := s.commandCreator.CreateCommand(ctx, msg.Command, nil, nil)
163163
if err != nil {
164164
s.errorsTotal.WithLabelValues("create_command").Add(1)
165165
return xerrors.Errorf("create command: %w", err)

0 commit comments

Comments
 (0)