diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index c27e59d0afc02..0e1328badd541 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -99,7 +99,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom } forwardHandler := &ssh.ForwardedTCPHandler{} - unixForwardHandler := &forwardedUnixHandler{log: logger} + unixForwardHandler := newForwardedUnixHandler(logger) metrics := newSSHServerMetrics(prometheusRegistry) s := &Server{ diff --git a/agent/agentssh/forward.go b/agent/agentssh/forward.go index ac5e5ac7100f8..adce24c8a9af8 100644 --- a/agent/agentssh/forward.go +++ b/agent/agentssh/forward.go @@ -2,11 +2,14 @@ package agentssh import ( "context" + "errors" "fmt" + "io/fs" "net" "os" "path/filepath" "sync" + "syscall" "github.com/gliderlabs/ssh" gossh "golang.org/x/crypto/ssh" @@ -33,22 +36,29 @@ type forwardedStreamLocalPayload struct { type forwardedUnixHandler struct { sync.Mutex log slog.Logger - forwards map[string]net.Listener + forwards map[forwardKey]net.Listener +} + +type forwardKey struct { + sessionID string + addr string +} + +func newForwardedUnixHandler(log slog.Logger) *forwardedUnixHandler { + return &forwardedUnixHandler{ + log: log, + forwards: make(map[forwardKey]net.Listener), + } } func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, req *gossh.Request) (bool, []byte) { h.log.Debug(ctx, "handling SSH unix forward") - h.Lock() - if h.forwards == nil { - h.forwards = make(map[string]net.Listener) - } - h.Unlock() conn, ok := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn) if !ok { h.log.Warn(ctx, "SSH unix forward request from client with no gossh connection") return false, nil } - log := h.log.With(slog.F("remote_addr", conn.RemoteAddr())) + log := h.log.With(slog.F("session_id", ctx.SessionID()), slog.F("remote_addr", conn.RemoteAddr())) switch req.Type { case "streamlocal-forward@openssh.com": @@ -62,14 +72,22 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, addr := reqPayload.SocketPath log = log.With(slog.F("socket_path", addr)) log.Debug(ctx, "request begin SSH unix forward") + + key := forwardKey{ + sessionID: ctx.SessionID(), + addr: addr, + } + h.Lock() - _, ok := h.forwards[addr] + _, ok := h.forwards[key] h.Unlock() if ok { - log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)", - slog.F("socket_path", addr), - ) - return false, nil + // In cases where `ExitOnForwardFailure=yes` is set, returning false + // here will cause the connection to be closed. To avoid this, and + // to match OpenSSH behavior, we silently ignore the second forward + // request. + log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded on this session, ignoring") + return true, nil } // Create socket parent dir if not exists. @@ -83,12 +101,20 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, return false, nil } - ln, err := net.Listen("unix", addr) + // Remove existing socket if it exists. We do not use os.Remove() here + // so that directories are kept. Note that it's possible that we will + // overwrite a regular file here. Both of these behaviors match OpenSSH, + // however, which is why we unlink. + err = unlink(addr) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + log.Warn(ctx, "remove existing socket for SSH unix forward request", slog.Error(err)) + return false, nil + } + + lc := &net.ListenConfig{} + ln, err := lc.Listen(ctx, "unix", addr) if err != nil { - log.Warn(ctx, "listen on Unix socket for SSH unix forward request", - slog.F("socket_path", addr), - slog.Error(err), - ) + log.Warn(ctx, "listen on Unix socket for SSH unix forward request", slog.Error(err)) return false, nil } log.Debug(ctx, "SSH unix forward listening on socket") @@ -99,7 +125,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, // // This is also what the upstream TCP version of this code does. h.Lock() - h.forwards[addr] = ln + h.forwards[key] = ln h.Unlock() log.Debug(ctx, "SSH unix forward added to cache") @@ -115,9 +141,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, c, err := ln.Accept() if err != nil { if !xerrors.Is(err, net.ErrClosed) { - log.Warn(ctx, "accept on local Unix socket for SSH unix forward request", - slog.Error(err), - ) + log.Warn(ctx, "accept on local Unix socket for SSH unix forward request", slog.Error(err)) } // closed below log.Debug(ctx, "SSH unix forward listener closed") @@ -131,10 +155,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, go func() { ch, reqs, err := conn.OpenChannel("forwarded-streamlocal@openssh.com", payload) if err != nil { - h.log.Warn(ctx, "open SSH unix forward channel to client", - slog.F("socket_path", addr), - slog.Error(err), - ) + h.log.Warn(ctx, "open SSH unix forward channel to client", slog.Error(err)) _ = c.Close() return } @@ -144,12 +165,11 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, } h.Lock() - ln2, ok := h.forwards[addr] - if ok && ln2 == ln { - delete(h.forwards, addr) + if ln2, ok := h.forwards[key]; ok && ln2 == ln { + delete(h.forwards, key) } h.Unlock() - log.Debug(ctx, "SSH unix forward listener removed from cache", slog.F("path", addr)) + log.Debug(ctx, "SSH unix forward listener removed from cache") _ = ln.Close() }() @@ -162,13 +182,22 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com (SSH unix forward) request payload from client", slog.Error(err)) return false, nil } - log.Debug(ctx, "request to cancel SSH unix forward", slog.F("path", reqPayload.SocketPath)) + log.Debug(ctx, "request to cancel SSH unix forward", slog.F("socket_path", reqPayload.SocketPath)) + + key := forwardKey{ + sessionID: ctx.SessionID(), + addr: reqPayload.SocketPath, + } + h.Lock() - ln, ok := h.forwards[reqPayload.SocketPath] + ln, ok := h.forwards[key] + delete(h.forwards, key) h.Unlock() - if ok { - _ = ln.Close() + if !ok { + log.Warn(ctx, "SSH unix forward not found in cache") + return true, nil } + _ = ln.Close() return true, nil default: @@ -209,3 +238,15 @@ func directStreamLocalHandler(_ *ssh.Server, _ *gossh.ServerConn, newChan gossh. Bicopy(ctx, ch, dconn) } + +// unlink removes files and unlike os.Remove, directories are kept. +func unlink(path string) error { + // Ignore EINTR like os.Remove, see ignoringEINTR in os/file_posix.go + // for more details. + for { + err := syscall.Unlink(path) + if !errors.Is(err, syscall.EINTR) { + return err + } + } +} diff --git a/cli/ssh_test.go b/cli/ssh_test.go index faf69d0d98faf..684e8700c1f50 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -26,12 +26,14 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" gosshagent "golang.org/x/crypto/ssh/agent" + "golang.org/x/sync/errgroup" "golang.org/x/xerrors" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/agent" + "github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/agenttest" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/cliui" @@ -738,8 +740,8 @@ func TestSSH(t *testing.T) { defer cancel() tmpdir := tempDirUnixSocket(t) - agentSock := filepath.Join(tmpdir, "agent.sock") - l, err := net.Listen("unix", agentSock) + localSock := filepath.Join(tmpdir, "local.sock") + l, err := net.Listen("unix", localSock) require.NoError(t, err) defer l.Close() remoteSock := filepath.Join(tmpdir, "remote.sock") @@ -748,7 +750,7 @@ func TestSSH(t *testing.T) { "ssh", workspace.Name, "--remote-forward", - fmt.Sprintf("%s:%s", remoteSock, agentSock), + fmt.Sprintf("%s:%s", remoteSock, localSock), ) clitest.SetupConfig(t, client, root) pty := ptytest.New(t).Attach(inv) @@ -771,6 +773,116 @@ func TestSSH(t *testing.T) { <-cmdDone }) + // Test that we can forward a local unix socket to a remote unix socket and + // that new SSH sessions take over the socket without closing active socket + // connections. + t.Run("RemoteForwardUnixSocketMultipleSessionsOverwrite", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Test not supported on windows") + } + + t.Parallel() + + client, workspace, agentToken := setupWorkspaceForAgent(t) + + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + // Wait super super long so this doesn't flake on -race test. + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong*2) + defer cancel() + + tmpdir := tempDirUnixSocket(t) + + localSock := filepath.Join(tmpdir, "local.sock") + l, err := net.Listen("unix", localSock) + require.NoError(t, err) + defer l.Close() + testutil.Go(t, func() { + for { + fd, err := l.Accept() + if err != nil { + if !errors.Is(err, net.ErrClosed) { + assert.NoError(t, err, "listener accept failed") + } + return + } + + testutil.Go(t, func() { + defer fd.Close() + agentssh.Bicopy(ctx, fd, fd) + }) + } + }) + + remoteSock := filepath.Join(tmpdir, "remote.sock") + + var done []func() error + for i := 0; i < 2; i++ { + id := fmt.Sprintf("ssh-%d", i) + inv, root := clitest.New(t, + "ssh", + workspace.Name, + "--remote-forward", + fmt.Sprintf("%s:%s", remoteSock, localSock), + ) + inv.Logger = inv.Logger.Named(id) + clitest.SetupConfig(t, client, root) + pty := ptytest.New(t).Attach(inv) + inv.Stderr = pty.Output() + cmdDone := tGo(t, func() { + err := inv.WithContext(ctx).Run() + assert.NoError(t, err, "ssh command failed: %s", id) + }) + + // Since something was output, it should be safe to write input. + // This could show a prompt or "running startup scripts", so it's + // not indicative of the SSH connection being ready. + _ = pty.Peek(ctx, 1) + + // Ensure the SSH connection is ready by testing the shell + // input/output. + pty.WriteLine("echo ping' 'pong") + pty.ExpectMatchContext(ctx, "ping pong") + + d := &net.Dialer{} + fd, err := d.DialContext(ctx, "unix", remoteSock) + require.NoError(t, err, id) + + // Ping / pong to ensure the socket is working. + _, err = fd.Write([]byte("hello world")) + require.NoError(t, err, id) + + buf := make([]byte, 11) + _, err = fd.Read(buf) + require.NoError(t, err, id) + require.Equal(t, "hello world", string(buf), id) + + done = append(done, func() error { + // Redo ping / pong to ensure that the socket + // connections still work. + _, err := fd.Write([]byte("hello world")) + assert.NoError(t, err, id) + + buf := make([]byte, 11) + _, err = fd.Read(buf) + assert.NoError(t, err, id) + assert.Equal(t, "hello world", string(buf), id) + + pty.WriteLine("exit") + <-cmdDone + return nil + }) + } + + var eg errgroup.Group + for _, d := range done { + eg.Go(d) + } + err = eg.Wait() + require.NoError(t, err) + }) + t.Run("FileLogging", func(t *testing.T) { t.Parallel()