@@ -409,7 +409,7 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv
409
409
magicTypeLabel := magicTypeMetricLabel (magicType )
410
410
sshPty , windowSize , isPty := session .Pty ()
411
411
412
- cmd , err := s .CreateCommand (ctx , session .RawCommand (), env )
412
+ cmd , err := s .CreateCommand (ctx , session .RawCommand (), env , nil )
413
413
if err != nil {
414
414
ptyLabel := "no"
415
415
if isPty {
@@ -670,17 +670,59 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
670
670
_ = session .Exit (1 )
671
671
}
672
672
673
+ // CreateCommandDeps encapsulates external information required by CreateCommand.
674
+ type CreateCommandDeps interface {
675
+ // CurrentUser returns the current user.
676
+ CurrentUser () (* user.User , error )
677
+ // Environ returns the environment variables of the current process.
678
+ Environ () []string
679
+ // UserHomeDir returns the home directory of the current user.
680
+ UserHomeDir () (string , error )
681
+ // UserShell returns the shell of the given user.
682
+ UserShell (username string ) (string , error )
683
+ }
684
+
685
+ type systemCreateCommandDeps struct {}
686
+
687
+ var defaultCreateCommandDeps CreateCommandDeps = & systemCreateCommandDeps {}
688
+
689
+ // DefaultCreateCommandDeps returns a default implementation of
690
+ // CreateCommandDeps. This reads information using the default Go
691
+ // implementations.
692
+ func DefaultCreateCommandDeps () CreateCommandDeps {
693
+ return defaultCreateCommandDeps
694
+ }
695
+ func (systemCreateCommandDeps ) CurrentUser () (* user.User , error ) {
696
+ return user .Current ()
697
+ }
698
+ func (systemCreateCommandDeps ) Environ () []string {
699
+ return os .Environ ()
700
+ }
701
+ func (systemCreateCommandDeps ) UserHomeDir () (string , error ) {
702
+ return userHomeDir ()
703
+ }
704
+ func (systemCreateCommandDeps ) UserShell (username string ) (string , error ) {
705
+ return usershell .Get (username )
706
+ }
707
+
673
708
// CreateCommand processes raw command input with OpenSSH-like behavior.
674
709
// If the script provided is empty, it will default to the users shell.
675
710
// This injects environment variables specified by the user at launch too.
676
- func (s * Server ) CreateCommand (ctx context.Context , script string , env []string ) (* pty.Cmd , error ) {
677
- currentUser , err := user .Current ()
711
+ // The final argument is an interface that allows the caller to provide
712
+ // alternative implementations for the dependencies of CreateCommand.
713
+ // This is useful when creating a command to be run in a separate environment
714
+ // (for example, a Docker container). Pass in nil to use the default.
715
+ func (s * Server ) CreateCommand (ctx context.Context , script string , env []string , deps CreateCommandDeps ) (* pty.Cmd , error ) {
716
+ if deps == nil {
717
+ deps = DefaultCreateCommandDeps ()
718
+ }
719
+ currentUser , err := deps .CurrentUser ()
678
720
if err != nil {
679
721
return nil , xerrors .Errorf ("get current user: %w" , err )
680
722
}
681
723
username := currentUser .Username
682
724
683
- shell , err := usershell . Get (username )
725
+ shell , err := deps . UserShell (username )
684
726
if err != nil {
685
727
return nil , xerrors .Errorf ("get user shell: %w" , err )
686
728
}
@@ -736,13 +778,13 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
736
778
_ , err = os .Stat (cmd .Dir )
737
779
if cmd .Dir == "" || err != nil {
738
780
// Default to user home if a directory is not set.
739
- homedir , err := userHomeDir ()
781
+ homedir , err := deps . UserHomeDir ()
740
782
if err != nil {
741
783
return nil , xerrors .Errorf ("get home dir: %w" , err )
742
784
}
743
785
cmd .Dir = homedir
744
786
}
745
- cmd .Env = append (os .Environ (), env ... )
787
+ cmd .Env = append (deps .Environ (), env ... )
746
788
cmd .Env = append (cmd .Env , fmt .Sprintf ("USER=%s" , username ))
747
789
748
790
// Set SSH connection environment variables (these are also set by OpenSSH
0 commit comments