@@ -409,7 +409,7 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv
409
409
magicTypeLabel := magicTypeMetricLabel (magicType )
410
410
sshPty , windowSize , isPty := session .Pty ()
411
411
412
- cmd , err := s .CreateCommand (ctx , session .RawCommand (), env )
412
+ cmd , err := s .CreateCommand (ctx , session .RawCommand (), env , nil )
413
413
if err != nil {
414
414
ptyLabel := "no"
415
415
if isPty {
@@ -670,17 +670,63 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
670
670
_ = session .Exit (1 )
671
671
}
672
672
673
+ // EnvInfoer encapsulates external information required by CreateCommand.
674
+ type EnvInfoer interface {
675
+ // CurrentUser returns the current user.
676
+ CurrentUser () (* user.User , error )
677
+ // Environ returns the environment variables of the current process.
678
+ Environ () []string
679
+ // UserHomeDir returns the home directory of the current user.
680
+ UserHomeDir () (string , error )
681
+ // UserShell returns the shell of the given user.
682
+ UserShell (username string ) (string , error )
683
+ }
684
+
685
+ type systemEnvInfoer struct {}
686
+
687
+ var defaultEnvInfoer EnvInfoer = & systemEnvInfoer {}
688
+
689
+ // DefaultEnvInfoer returns a default implementation of
690
+ // EnvInfoer. This reads information using the default Go
691
+ // implementations.
692
+ func DefaultEnvInfoer () EnvInfoer {
693
+ return defaultEnvInfoer
694
+ }
695
+
696
+ func (systemEnvInfoer ) CurrentUser () (* user.User , error ) {
697
+ return user .Current ()
698
+ }
699
+
700
+ func (systemEnvInfoer ) Environ () []string {
701
+ return os .Environ ()
702
+ }
703
+
704
+ func (systemEnvInfoer ) UserHomeDir () (string , error ) {
705
+ return userHomeDir ()
706
+ }
707
+
708
+ func (systemEnvInfoer ) UserShell (username string ) (string , error ) {
709
+ return usershell .Get (username )
710
+ }
711
+
673
712
// CreateCommand processes raw command input with OpenSSH-like behavior.
674
713
// If the script provided is empty, it will default to the users shell.
675
714
// This injects environment variables specified by the user at launch too.
676
- func (s * Server ) CreateCommand (ctx context.Context , script string , env []string ) (* pty.Cmd , error ) {
677
- currentUser , err := user .Current ()
715
+ // The final argument is an interface that allows the caller to provide
716
+ // alternative implementations for the dependencies of CreateCommand.
717
+ // This is useful when creating a command to be run in a separate environment
718
+ // (for example, a Docker container). Pass in nil to use the default.
719
+ func (s * Server ) CreateCommand (ctx context.Context , script string , env []string , deps EnvInfoer ) (* pty.Cmd , error ) {
720
+ if deps == nil {
721
+ deps = DefaultEnvInfoer ()
722
+ }
723
+ currentUser , err := deps .CurrentUser ()
678
724
if err != nil {
679
725
return nil , xerrors .Errorf ("get current user: %w" , err )
680
726
}
681
727
username := currentUser .Username
682
728
683
- shell , err := usershell . Get (username )
729
+ shell , err := deps . UserShell (username )
684
730
if err != nil {
685
731
return nil , xerrors .Errorf ("get user shell: %w" , err )
686
732
}
@@ -736,13 +782,13 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
736
782
_ , err = os .Stat (cmd .Dir )
737
783
if cmd .Dir == "" || err != nil {
738
784
// Default to user home if a directory is not set.
739
- homedir , err := userHomeDir ()
785
+ homedir , err := deps . UserHomeDir ()
740
786
if err != nil {
741
787
return nil , xerrors .Errorf ("get home dir: %w" , err )
742
788
}
743
789
cmd .Dir = homedir
744
790
}
745
- cmd .Env = append (os .Environ (), env ... )
791
+ cmd .Env = append (deps .Environ (), env ... )
746
792
cmd .Env = append (cmd .Env , fmt .Sprintf ("USER=%s" , username ))
747
793
748
794
// Set SSH connection environment variables (these are also set by OpenSSH
0 commit comments