Skip to content

Commit e508d9a

Browse files
authored
fix(agent/usershell): check shell on darwin via dscl (#8366)
1 parent de1d04d commit e508d9a

File tree

4 files changed

+70
-31
lines changed

4 files changed

+70
-31
lines changed

agent/usershell/usershell_darwin.go

+20-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,25 @@
11
package usershell
22

3-
import "os"
3+
import (
4+
"os"
5+
"os/exec"
6+
"path/filepath"
7+
"strings"
8+
9+
"golang.org/x/xerrors"
10+
)
411

512
// Get returns the $SHELL environment variable.
6-
func Get(_ string) (string, error) {
7-
return os.Getenv("SHELL"), nil
13+
func Get(username string) (string, error) {
14+
// This command will output "UserShell: /bin/zsh" if successful, we
15+
// can ignore the error since we have fallback behavior.
16+
out, _ := exec.Command("dscl", ".", "-read", filepath.Join("/Users", username), "UserShell").Output()
17+
s, ok := strings.CutPrefix(string(out), "UserShell: ")
18+
if ok {
19+
return strings.TrimSpace(s), nil
20+
}
21+
if s = os.Getenv("SHELL"); s != "" {
22+
return s, nil
23+
}
24+
return "", xerrors.Errorf("shell for user %q not found via dscl or in $SHELL", username)
825
}

agent/usershell/usershell_other.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,8 @@ func Get(username string) (string, error) {
2727
}
2828
return parts[6], nil
2929
}
30-
return "", xerrors.Errorf("user %q not found in /etc/passwd", username)
30+
if s := os.Getenv("SHELL"); s != "" {
31+
return s, nil
32+
}
33+
return "", xerrors.Errorf("shell for user %q not found in /etc/passwd or $SHELL", username)
3134
}

agent/usershell/usershell_other_test.go

-27
This file was deleted.

agent/usershell/usershell_test.go

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package usershell_test
2+
3+
import (
4+
"os/user"
5+
"runtime"
6+
"testing"
7+
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/coder/coder/agent/usershell"
11+
)
12+
13+
//nolint:paralleltest,tparallel // This test sets an environment variable.
14+
func TestGet(t *testing.T) {
15+
if runtime.GOOS == "windows" {
16+
t.SkipNow()
17+
}
18+
19+
t.Run("Fallback", func(t *testing.T) {
20+
t.Setenv("SHELL", "/bin/sh")
21+
22+
t.Run("NonExistentUser", func(t *testing.T) {
23+
shell, err := usershell.Get("notauser")
24+
require.NoError(t, err)
25+
require.Equal(t, "/bin/sh", shell)
26+
})
27+
})
28+
29+
t.Run("NoFallback", func(t *testing.T) {
30+
// Disable env fallback for these tests.
31+
t.Setenv("SHELL", "")
32+
33+
t.Run("NotFound", func(t *testing.T) {
34+
_, err := usershell.Get("notauser")
35+
require.Error(t, err)
36+
})
37+
38+
t.Run("User", func(t *testing.T) {
39+
u, err := user.Current()
40+
require.NoError(t, err)
41+
shell, err := usershell.Get(u.Username)
42+
require.NoError(t, err)
43+
require.NotEmpty(t, shell)
44+
})
45+
})
46+
}

0 commit comments

Comments
 (0)