diff --git a/agent/agent.go b/agent/agent.go index 5c171d7d513c7..efd57e5db29db 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -161,7 +161,7 @@ type agent struct { } func (a *agent) init(ctx context.Context) { - sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.sshMaxTimeout) + sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.filesystem, a.sshMaxTimeout, "") if err != nil { panic(err) } diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 86e1eb9e36af4..d6b6613ed92de 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -20,6 +20,7 @@ import ( "github.com/gliderlabs/ssh" "github.com/pkg/sftp" + "github.com/spf13/afero" "go.uber.org/atomic" gossh "golang.org/x/crypto/ssh" "golang.org/x/xerrors" @@ -48,6 +49,7 @@ const ( type Server struct { mu sync.RWMutex // Protects following. + fs afero.Fs listeners map[net.Listener]struct{} conns map[net.Conn]struct{} sessions map[ssh.Session]struct{} @@ -56,8 +58,9 @@ type Server struct { // a lock on mu but protected by closing. wg sync.WaitGroup - logger slog.Logger - srv *ssh.Server + logger slog.Logger + srv *ssh.Server + x11SocketDir string Env map[string]string AgentToken func() string @@ -68,7 +71,7 @@ type Server struct { connCountSSHSession atomic.Int64 } -func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration) (*Server, error) { +func NewServer(ctx context.Context, logger slog.Logger, fs afero.Fs, maxTimeout time.Duration, x11SocketDir string) (*Server, error) { // Clients' should ignore the host key when connecting. // The agent needs to authenticate with coderd to SSH, // so SSH authentication doesn't improve security. @@ -80,15 +83,20 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration if err != nil { return nil, err } + if x11SocketDir == "" { + x11SocketDir = filepath.Join(os.TempDir(), ".X11-unix") + } forwardHandler := &ssh.ForwardedTCPHandler{} unixForwardHandler := &forwardedUnixHandler{log: logger} s := &Server{ - listeners: make(map[net.Listener]struct{}), - conns: make(map[net.Conn]struct{}), - sessions: make(map[ssh.Session]struct{}), - logger: logger, + listeners: make(map[net.Listener]struct{}), + fs: fs, + conns: make(map[net.Conn]struct{}), + sessions: make(map[ssh.Session]struct{}), + logger: logger, + x11SocketDir: x11SocketDir, } s.srv = &ssh.Server{ @@ -125,6 +133,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration "streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, "cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, }, + X11Callback: s.x11Callback, ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { return &gossh.ServerConfig{ NoClientAuth: true, @@ -163,6 +172,15 @@ func (s *Server) sessionHandler(session ssh.Session) { ctx := session.Context() + x11, hasX11 := session.X11() + if hasX11 { + handled := s.x11Handler(session.Context(), x11) + if !handled { + _ = session.Exit(1) + return + } + } + switch ss := session.Subsystem(); ss { case "": case "sftp": diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index 684c0e36bbb18..b1675f0029a2c 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -10,6 +10,7 @@ import ( "sync" "testing" + "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/atomic" @@ -32,7 +33,7 @@ func TestNewServer_ServeClient(t *testing.T) { ctx := context.Background() logger := slogtest.Make(t, nil) - s, err := agentssh.NewServer(ctx, logger, 0) + s, err := agentssh.NewServer(ctx, logger, afero.NewMemMapFs(), 0, "") require.NoError(t, err) // The assumption is that these are set before serving SSH connections. @@ -50,6 +51,7 @@ func TestNewServer_ServeClient(t *testing.T) { }() c := sshClient(t, ln.Addr().String()) + var b bytes.Buffer sess, err := c.NewSession() sess.Stdout = &b @@ -72,7 +74,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) { ctx := context.Background() logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - s, err := agentssh.NewServer(ctx, logger, 0) + s, err := agentssh.NewServer(ctx, logger, afero.NewMemMapFs(), 0, "") require.NoError(t, err) // The assumption is that these are set before serving SSH connections. diff --git a/agent/agentssh/x11.go b/agent/agentssh/x11.go new file mode 100644 index 0000000000000..6d50d4a99078f --- /dev/null +++ b/agent/agentssh/x11.go @@ -0,0 +1,190 @@ +package agentssh + +import ( + "context" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "net" + "os" + "path/filepath" + "strconv" + "time" + + "github.com/gliderlabs/ssh" + "github.com/gofrs/flock" + "github.com/spf13/afero" + gossh "golang.org/x/crypto/ssh" + "golang.org/x/xerrors" + + "cdr.dev/slog" +) + +// x11Callback is called when the client requests X11 forwarding. +// It adds an Xauthority entry to the Xauthority file. +func (s *Server) x11Callback(ctx ssh.Context, x11 ssh.X11) bool { + hostname, err := os.Hostname() + if err != nil { + s.logger.Warn(ctx, "failed to get hostname", slog.Error(err)) + return false + } + + err = s.fs.MkdirAll(s.x11SocketDir, 0o700) + if err != nil { + s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.x11SocketDir), slog.Error(err)) + return false + } + + err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(int(x11.ScreenNumber)), x11.AuthProtocol, x11.AuthCookie) + if err != nil { + s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err)) + return false + } + return true +} + +// x11Handler is called when a session has requested X11 forwarding. +// It listens for X11 connections and forwards them to the client. +func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool { + serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn) + if !valid { + s.logger.Warn(ctx, "failed to get server connection") + return false + } + listener, err := net.Listen("unix", filepath.Join(s.x11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber))) + if err != nil { + s.logger.Warn(ctx, "failed to listen for X11", slog.Error(err)) + return false + } + s.trackListener(listener, true) + + go func() { + defer listener.Close() + defer s.trackListener(listener, false) + handledFirstConnection := false + + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } + s.logger.Warn(ctx, "failed to accept X11 connection", slog.Error(err)) + return + } + if x11.SingleConnection && handledFirstConnection { + s.logger.Warn(ctx, "X11 connection rejected because single connection is enabled") + _ = conn.Close() + continue + } + handledFirstConnection = true + + unixConn, ok := conn.(*net.UnixConn) + if !ok { + s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to UnixConn. got: %T", conn)) + return + } + unixAddr, ok := unixConn.LocalAddr().(*net.UnixAddr) + if !ok { + s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to UnixAddr. got: %T", unixConn.LocalAddr())) + return + } + + channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct { + OriginatorAddress string + OriginatorPort uint32 + }{ + OriginatorAddress: unixAddr.Name, + OriginatorPort: 0, + })) + if err != nil { + s.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err)) + return + } + go gossh.DiscardRequests(reqs) + go Bicopy(ctx, conn, channel) + } + }() + return true +} + +// addXauthEntry adds an Xauthority entry to the Xauthority file. +// The Xauthority file is located at ~/.Xauthority. +func addXauthEntry(ctx context.Context, fs afero.Fs, host string, display string, authProtocol string, authCookie string) error { + // Get the Xauthority file path + homeDir, err := os.UserHomeDir() + if err != nil { + return xerrors.Errorf("failed to get user home directory: %w", err) + } + + xauthPath := filepath.Join(homeDir, ".Xauthority") + + lock := flock.New(xauthPath) + defer lock.Close() + ok, err := lock.TryLockContext(ctx, 100*time.Millisecond) + if !ok { + return xerrors.Errorf("failed to lock Xauthority file: %w", err) + } + if err != nil { + return xerrors.Errorf("failed to lock Xauthority file: %w", err) + } + + // Open or create the Xauthority file + file, err := fs.OpenFile(xauthPath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o600) + if err != nil { + return xerrors.Errorf("failed to open Xauthority file: %w", err) + } + defer file.Close() + + // Convert the authCookie from hex string to byte slice + authCookieBytes, err := hex.DecodeString(authCookie) + if err != nil { + return xerrors.Errorf("failed to decode auth cookie: %w", err) + } + + // Write Xauthority entry + family := uint16(0x0100) // FamilyLocal + err = binary.Write(file, binary.BigEndian, family) + if err != nil { + return xerrors.Errorf("failed to write family: %w", err) + } + + err = binary.Write(file, binary.BigEndian, uint16(len(host))) + if err != nil { + return xerrors.Errorf("failed to write host length: %w", err) + } + _, err = file.WriteString(host) + if err != nil { + return xerrors.Errorf("failed to write host: %w", err) + } + + err = binary.Write(file, binary.BigEndian, uint16(len(display))) + if err != nil { + return xerrors.Errorf("failed to write display length: %w", err) + } + _, err = file.WriteString(display) + if err != nil { + return xerrors.Errorf("failed to write display: %w", err) + } + + err = binary.Write(file, binary.BigEndian, uint16(len(authProtocol))) + if err != nil { + return xerrors.Errorf("failed to write auth protocol length: %w", err) + } + _, err = file.WriteString(authProtocol) + if err != nil { + return xerrors.Errorf("failed to write auth protocol: %w", err) + } + + err = binary.Write(file, binary.BigEndian, uint16(len(authCookieBytes))) + if err != nil { + return xerrors.Errorf("failed to write auth cookie length: %w", err) + } + _, err = file.Write(authCookieBytes) + if err != nil { + return xerrors.Errorf("failed to write auth cookie: %w", err) + } + + return nil +} diff --git a/agent/agentssh/x11_test.go b/agent/agentssh/x11_test.go new file mode 100644 index 0000000000000..cd935d326858c --- /dev/null +++ b/agent/agentssh/x11_test.go @@ -0,0 +1,99 @@ +package agentssh_test + +import ( + "context" + "encoding/hex" + "net" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/gliderlabs/ssh" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + gossh "golang.org/x/crypto/ssh" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/agent/agentssh" + "github.com/coder/coder/codersdk/agentsdk" + "github.com/coder/coder/testutil" +) + +func TestServer_X11(t *testing.T) { + t.Parallel() + if runtime.GOOS != "linux" { + t.Skip("X11 forwarding is only supported on Linux") + } + + ctx := context.Background() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + fs := afero.NewOsFs() + dir := t.TempDir() + s, err := agentssh.NewServer(ctx, logger, fs, 0, dir) + require.NoError(t, err) + defer s.Close() + + // The assumption is that these are set before serving SSH connections. + s.AgentToken = func() string { return "" } + s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + err := s.Serve(ln) + assert.Error(t, err) // Server is closed. + }() + + c := sshClient(t, ln.Addr().String()) + + sess, err := c.NewSession() + require.NoError(t, err) + + reply, err := sess.SendRequest("x11-req", true, gossh.Marshal(ssh.X11{ + AuthProtocol: "MIT-MAGIC-COOKIE-1", + AuthCookie: hex.EncodeToString([]byte("cookie")), + ScreenNumber: 0, + })) + require.NoError(t, err) + assert.True(t, reply) + + err = sess.Shell() + require.NoError(t, err) + + x11Chans := c.HandleChannelOpen("x11") + payload := "hello world" + require.Eventually(t, func() bool { + conn, err := net.Dial("unix", filepath.Join(dir, "X0")) + if err == nil { + _, err = conn.Write([]byte(payload)) + assert.NoError(t, err) + _ = conn.Close() + } + return err == nil + }, testutil.WaitShort, testutil.IntervalFast) + + x11 := <-x11Chans + ch, reqs, err := x11.Accept() + require.NoError(t, err) + go gossh.DiscardRequests(reqs) + got := make([]byte, len(payload)) + _, err = ch.Read(got) + require.NoError(t, err) + assert.Equal(t, payload, string(got)) + _ = ch.Close() + _ = s.Close() + <-done + + // Ensure the Xauthority file was written! + home, err := os.UserHomeDir() + require.NoError(t, err) + _, err = fs.Stat(filepath.Join(home, ".Xauthority")) + require.NoError(t, err) +} diff --git a/go.mod b/go.mod index e63edefb10da4..92633b91f56b6 100644 --- a/go.mod +++ b/go.mod @@ -45,7 +45,7 @@ replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20230418202606-ed93 // repo as tailscale.com/tempfork/gliderlabs/ssh, however, we can't replace the // subpath and it includes changes to golang.org/x/crypto/ssh as well which // makes importing it directly a bit messy. -replace github.com/gliderlabs/ssh => github.com/coder/ssh v0.0.0-20220811105153-fcea99919338 +replace github.com/gliderlabs/ssh => github.com/coder/ssh v0.0.0-20230419180646-49c741437b53 // Waiting on https://github.com/imulab/go-scim/pull/95 to merge. replace github.com/imulab/go-scim/pkg/v2 => github.com/coder/go-scim/pkg/v2 v2.0.0-20230221055123-1d63c1222136 diff --git a/go.sum b/go.sum index c3ead3f4c48fe..89b80e00afab5 100644 --- a/go.sum +++ b/go.sum @@ -380,6 +380,10 @@ github.com/coder/retry v1.3.1-0.20230210155434-e90a2e1e091d h1:09JG37IgTB6n3ouX9 github.com/coder/retry v1.3.1-0.20230210155434-e90a2e1e091d/go.mod h1:r+1J5i/989wt6CUeNSuvFKKA9hHuKKPMxdzDbTuvwwk= github.com/coder/ssh v0.0.0-20220811105153-fcea99919338 h1:tN5GKFT68YLVzJoA8AHuiMNJ0qlhoD3pGN3JY9gxSko= github.com/coder/ssh v0.0.0-20220811105153-fcea99919338/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914= +github.com/coder/ssh v0.0.0-20230419175457-0612ba535202 h1:1I/Im5ZUan1Y9ypAr6VuAKQ4NbvEy/frR3cV86pKQk8= +github.com/coder/ssh v0.0.0-20230419175457-0612ba535202/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914= +github.com/coder/ssh v0.0.0-20230419180646-49c741437b53 h1:kaLOp3tlVnbOJIjmAvXuBTgeWWoZZlJJJ4QGeSMjOnA= +github.com/coder/ssh v0.0.0-20230419180646-49c741437b53/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914= github.com/coder/tailscale v1.1.1-0.20230418202606-ed9307cf1b22 h1:bvGOqnI0ITbwOZFQ0SZ4MBw/8LLUEjxmNu57XEujrfQ= github.com/coder/tailscale v1.1.1-0.20230418202606-ed9307cf1b22/go.mod h1:jpg+77g19FpXL43U1VoIqoSg1K/Vh5CVxycGldQ8KhA= github.com/coder/terraform-provider-coder v0.6.23 h1:O2Rcj0umez4DfVdGnKZi63z1Xzxd0IQOn9VQDB8YU8g=