From 7b28cfa28de6855f4d8d8cfeb1ae3495c7407d4c Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 18 Feb 2025 12:23:40 +0000 Subject: [PATCH 1/4] chore(agent/agentssh): extract CreateCommandDeps --- agent/agent.go | 2 +- agent/agentscripts/agentscripts.go | 2 +- agent/agentssh/agentssh.go | 54 ++++++++++++++++++++++++++---- agent/agentssh/agentssh_test.go | 4 +-- agent/reconnectingpty/server.go | 2 +- 5 files changed, 53 insertions(+), 11 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 28ea524bf3da3..8ff6d68d25f0b 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -340,7 +340,7 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM // if it can guarantee the clocks are synchronized. CollectedAt: now, } - cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil) + cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil, nil) if err != nil { result.Error = fmt.Sprintf("create cmd: %+v", err) return result diff --git a/agent/agentscripts/agentscripts.go b/agent/agentscripts/agentscripts.go index 25ea0ba46fcf3..bd83d71875c73 100644 --- a/agent/agentscripts/agentscripts.go +++ b/agent/agentscripts/agentscripts.go @@ -283,7 +283,7 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript, cmdCtx, ctxCancel = context.WithTimeout(ctx, script.Timeout) defer ctxCancel() } - cmdPty, err := r.SSHServer.CreateCommand(cmdCtx, script.Script, nil) + cmdPty, err := r.SSHServer.CreateCommand(cmdCtx, script.Script, nil, nil) if err != nil { return xerrors.Errorf("%s script: create command: %w", logPath, err) } diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index dae1b73b2de6c..c84053e988d87 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -409,7 +409,7 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv magicTypeLabel := magicTypeMetricLabel(magicType) sshPty, windowSize, isPty := session.Pty() - cmd, err := s.CreateCommand(ctx, session.RawCommand(), env) + cmd, err := s.CreateCommand(ctx, session.RawCommand(), env, nil) if err != nil { ptyLabel := "no" if isPty { @@ -670,17 +670,59 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) { _ = session.Exit(1) } +// CreateCommandDeps encapsulates external information required by CreateCommand. +type CreateCommandDeps interface { + // CurrentUser returns the current user. + CurrentUser() (*user.User, error) + // Environ returns the environment variables of the current process. + Environ() []string + // UserHomeDir returns the home directory of the current user. + UserHomeDir() (string, error) + // UserShell returns the shell of the given user. + UserShell(username string) (string, error) +} + +type systemCreateCommandDeps struct{} + +var defaultCreateCommandDeps CreateCommandDeps = &systemCreateCommandDeps{} + +// DefaultCreateCommandDeps returns a default implementation of +// CreateCommandDeps. This reads information using the default Go +// implementations. +func DefaultCreateCommandDeps() CreateCommandDeps { + return defaultCreateCommandDeps +} +func (systemCreateCommandDeps) CurrentUser() (*user.User, error) { + return user.Current() +} +func (systemCreateCommandDeps) Environ() []string { + return os.Environ() +} +func (systemCreateCommandDeps) UserHomeDir() (string, error) { + return userHomeDir() +} +func (systemCreateCommandDeps) UserShell(username string) (string, error) { + return usershell.Get(username) +} + // CreateCommand processes raw command input with OpenSSH-like behavior. // If the script provided is empty, it will default to the users shell. // This injects environment variables specified by the user at launch too. -func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*pty.Cmd, error) { - currentUser, err := user.Current() +// The final argument is an interface that allows the caller to provide +// alternative implementations for the dependencies of CreateCommand. +// This is useful when creating a command to be run in a separate environment +// (for example, a Docker container). Pass in nil to use the default. +func (s *Server) CreateCommand(ctx context.Context, script string, env []string, deps CreateCommandDeps) (*pty.Cmd, error) { + if deps == nil { + deps = DefaultCreateCommandDeps() + } + currentUser, err := deps.CurrentUser() if err != nil { return nil, xerrors.Errorf("get current user: %w", err) } username := currentUser.Username - shell, err := usershell.Get(username) + shell, err := deps.UserShell(username) if err != nil { return nil, xerrors.Errorf("get user shell: %w", err) } @@ -736,13 +778,13 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string) _, err = os.Stat(cmd.Dir) if cmd.Dir == "" || err != nil { // Default to user home if a directory is not set. - homedir, err := userHomeDir() + homedir, err := deps.UserHomeDir() if err != nil { return nil, xerrors.Errorf("get home dir: %w", err) } cmd.Dir = homedir } - cmd.Env = append(os.Environ(), env...) + cmd.Env = append(deps.Environ(), env...) cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username)) // Set SSH connection environment variables (these are also set by OpenSSH diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index 76321e6e19d85..9cf63b12fe6ee 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -87,7 +87,7 @@ func TestNewServer_ExecuteShebang(t *testing.T) { t.Run("Basic", func(t *testing.T) { t.Parallel() cmd, err := s.CreateCommand(ctx, `#!/bin/bash - echo test`, nil) + echo test`, nil, nil) require.NoError(t, err) output, err := cmd.AsExec().CombinedOutput() require.NoError(t, err) @@ -96,7 +96,7 @@ func TestNewServer_ExecuteShebang(t *testing.T) { t.Run("Args", func(t *testing.T) { t.Parallel() cmd, err := s.CreateCommand(ctx, `#!/usr/bin/env bash - echo test`, nil) + echo test`, nil, nil) require.NoError(t, err) output, err := cmd.AsExec().CombinedOutput() require.NoError(t, err) diff --git a/agent/reconnectingpty/server.go b/agent/reconnectingpty/server.go index d48c7abec9353..465667c616180 100644 --- a/agent/reconnectingpty/server.go +++ b/agent/reconnectingpty/server.go @@ -159,7 +159,7 @@ func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Co }() // Empty command will default to the users shell! - cmd, err := s.commandCreator.CreateCommand(ctx, msg.Command, nil) + cmd, err := s.commandCreator.CreateCommand(ctx, msg.Command, nil, nil) if err != nil { s.errorsTotal.WithLabelValues("create_command").Add(1) return xerrors.Errorf("create command: %w", err) From 251d1d262045901a42df5e3a0f703f9471ad2eb1 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 18 Feb 2025 12:40:37 +0000 Subject: [PATCH 2/4] make fmt --- agent/agentssh/agentssh.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index c84053e988d87..ee20674464f87 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -692,15 +692,19 @@ var defaultCreateCommandDeps CreateCommandDeps = &systemCreateCommandDeps{} func DefaultCreateCommandDeps() CreateCommandDeps { return defaultCreateCommandDeps } + func (systemCreateCommandDeps) CurrentUser() (*user.User, error) { return user.Current() } + func (systemCreateCommandDeps) Environ() []string { return os.Environ() } + func (systemCreateCommandDeps) UserHomeDir() (string, error) { return userHomeDir() } + func (systemCreateCommandDeps) UserShell(username string) (string, error) { return usershell.Get(username) } From 1259518199ec3e7521d607bf425874dc29f074e3 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 18 Feb 2025 20:43:59 +0000 Subject: [PATCH 3/4] EnvInfoer is so Go --- agent/agentssh/agentssh.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index ee20674464f87..d17e9cd761fe6 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -670,8 +670,8 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) { _ = session.Exit(1) } -// CreateCommandDeps encapsulates external information required by CreateCommand. -type CreateCommandDeps interface { +// EnvInfoer encapsulates external information required by CreateCommand. +type EnvInfoer interface { // CurrentUser returns the current user. CurrentUser() (*user.User, error) // Environ returns the environment variables of the current process. @@ -682,30 +682,30 @@ type CreateCommandDeps interface { UserShell(username string) (string, error) } -type systemCreateCommandDeps struct{} +type systemEnvInfoer struct{} -var defaultCreateCommandDeps CreateCommandDeps = &systemCreateCommandDeps{} +var defaultEnvInfoer EnvInfoer = &systemEnvInfoer{} -// DefaultCreateCommandDeps returns a default implementation of -// CreateCommandDeps. This reads information using the default Go +// DefaultEnvInfoer returns a default implementation of +// EnvInfoer. This reads information using the default Go // implementations. -func DefaultCreateCommandDeps() CreateCommandDeps { - return defaultCreateCommandDeps +func DefaultEnvInfoer() EnvInfoer { + return defaultEnvInfoer } -func (systemCreateCommandDeps) CurrentUser() (*user.User, error) { +func (systemEnvInfoer) CurrentUser() (*user.User, error) { return user.Current() } -func (systemCreateCommandDeps) Environ() []string { +func (systemEnvInfoer) Environ() []string { return os.Environ() } -func (systemCreateCommandDeps) UserHomeDir() (string, error) { +func (systemEnvInfoer) UserHomeDir() (string, error) { return userHomeDir() } -func (systemCreateCommandDeps) UserShell(username string) (string, error) { +func (systemEnvInfoer) UserShell(username string) (string, error) { return usershell.Get(username) } @@ -716,9 +716,9 @@ func (systemCreateCommandDeps) UserShell(username string) (string, error) { // alternative implementations for the dependencies of CreateCommand. // This is useful when creating a command to be run in a separate environment // (for example, a Docker container). Pass in nil to use the default. -func (s *Server) CreateCommand(ctx context.Context, script string, env []string, deps CreateCommandDeps) (*pty.Cmd, error) { +func (s *Server) CreateCommand(ctx context.Context, script string, env []string, deps EnvInfoer) (*pty.Cmd, error) { if deps == nil { - deps = DefaultCreateCommandDeps() + deps = DefaultEnvInfoer() } currentUser, err := deps.CurrentUser() if err != nil { From c5f45e48d1f51f44a542f5c627a8cfb7253641fe Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 18 Feb 2025 20:51:13 +0000 Subject: [PATCH 4/4] add test for custom implementation --- agent/agentssh/agentssh_test.go | 34 +++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index 9cf63b12fe6ee..b9cec420e5651 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "net" + "os/user" "runtime" "strings" "sync" @@ -102,6 +103,39 @@ func TestNewServer_ExecuteShebang(t *testing.T) { require.NoError(t, err) require.Equal(t, "test\n", string(output)) }) + t.Run("CustomEnvInfoer", func(t *testing.T) { + t.Parallel() + ei := &fakeEnvInfoer{ + CurrentUserFn: func() (u *user.User, err error) { + return nil, assert.AnError + }, + } + _, err := s.CreateCommand(ctx, `whatever`, nil, ei) + require.ErrorIs(t, err, assert.AnError) + }) +} + +type fakeEnvInfoer struct { + CurrentUserFn func() (*user.User, error) + EnvironFn func() []string + UserHomeDirFn func() (string, error) + UserShellFn func(string) (string, error) +} + +func (f *fakeEnvInfoer) CurrentUser() (u *user.User, err error) { + return f.CurrentUserFn() +} + +func (f *fakeEnvInfoer) Environ() []string { + return f.EnvironFn() +} + +func (f *fakeEnvInfoer) UserHomeDir() (string, error) { + return f.UserHomeDirFn() +} + +func (f *fakeEnvInfoer) UserShell(u string) (string, error) { + return f.UserShellFn(u) } func TestNewServer_CloseActiveConnections(t *testing.T) {