Skip to content

fix(agent/agentssh): allow remote forwarding a socket multiple times #11631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion agent/agentssh/agentssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
}

forwardHandler := &ssh.ForwardedTCPHandler{}
unixForwardHandler := &forwardedUnixHandler{log: logger}
unixForwardHandler := newForwardedUnixHandler(logger)

metrics := newSSHServerMetrics(prometheusRegistry)
s := &Server{
Expand Down
107 changes: 74 additions & 33 deletions agent/agentssh/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package agentssh

import (
"context"
"errors"
"fmt"
"io/fs"
"net"
"os"
"path/filepath"
"sync"
"syscall"

"github.com/gliderlabs/ssh"
gossh "golang.org/x/crypto/ssh"
Expand All @@ -33,22 +36,29 @@ type forwardedStreamLocalPayload struct {
type forwardedUnixHandler struct {
sync.Mutex
log slog.Logger
forwards map[string]net.Listener
forwards map[forwardKey]net.Listener
}

type forwardKey struct {
sessionID string
addr string
}

func newForwardedUnixHandler(log slog.Logger) *forwardedUnixHandler {
return &forwardedUnixHandler{
log: log,
forwards: make(map[forwardKey]net.Listener),
}
}

func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, req *gossh.Request) (bool, []byte) {
h.log.Debug(ctx, "handling SSH unix forward")
h.Lock()
if h.forwards == nil {
h.forwards = make(map[string]net.Listener)
}
h.Unlock()
conn, ok := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
if !ok {
h.log.Warn(ctx, "SSH unix forward request from client with no gossh connection")
return false, nil
}
log := h.log.With(slog.F("remote_addr", conn.RemoteAddr()))
log := h.log.With(slog.F("session_id", ctx.SessionID()), slog.F("remote_addr", conn.RemoteAddr()))

switch req.Type {
case "streamlocal-forward@openssh.com":
Expand All @@ -62,14 +72,22 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
addr := reqPayload.SocketPath
log = log.With(slog.F("socket_path", addr))
log.Debug(ctx, "request begin SSH unix forward")

key := forwardKey{
sessionID: ctx.SessionID(),
addr: addr,
}

h.Lock()
_, ok := h.forwards[addr]
_, ok := h.forwards[key]
h.Unlock()
if ok {
log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)",
slog.F("socket_path", addr),
)
return false, nil
// In cases where `ExitOnForwardFailure=yes` is set, returning false
// here will cause the connection to be closed. To avoid this, and
// to match OpenSSH behavior, we silently ignore the second forward
// request.
log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded on this session, ignoring")
return true, nil
}

// Create socket parent dir if not exists.
Expand All @@ -83,12 +101,20 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
return false, nil
}

ln, err := net.Listen("unix", addr)
// Remove existing socket if it exists. We do not use os.Remove() here
// so that directories are kept. Note that it's possible that we will
// overwrite a regular file here. Both of these behaviors match OpenSSH,
// however, which is why we unlink.
err = unlink(addr)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
log.Warn(ctx, "remove existing socket for SSH unix forward request", slog.Error(err))
return false, nil
}

lc := &net.ListenConfig{}
ln, err := lc.Listen(ctx, "unix", addr)
if err != nil {
log.Warn(ctx, "listen on Unix socket for SSH unix forward request",
slog.F("socket_path", addr),
slog.Error(err),
)
log.Warn(ctx, "listen on Unix socket for SSH unix forward request", slog.Error(err))
return false, nil
}
log.Debug(ctx, "SSH unix forward listening on socket")
Expand All @@ -99,7 +125,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
//
// This is also what the upstream TCP version of this code does.
h.Lock()
h.forwards[addr] = ln
h.forwards[key] = ln
h.Unlock()
log.Debug(ctx, "SSH unix forward added to cache")

Expand All @@ -115,9 +141,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
c, err := ln.Accept()
if err != nil {
if !xerrors.Is(err, net.ErrClosed) {
log.Warn(ctx, "accept on local Unix socket for SSH unix forward request",
slog.Error(err),
)
log.Warn(ctx, "accept on local Unix socket for SSH unix forward request", slog.Error(err))
}
// closed below
log.Debug(ctx, "SSH unix forward listener closed")
Expand All @@ -131,10 +155,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
go func() {
ch, reqs, err := conn.OpenChannel("forwarded-streamlocal@openssh.com", payload)
if err != nil {
h.log.Warn(ctx, "open SSH unix forward channel to client",
slog.F("socket_path", addr),
slog.Error(err),
)
h.log.Warn(ctx, "open SSH unix forward channel to client", slog.Error(err))
_ = c.Close()
return
}
Expand All @@ -144,12 +165,11 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
}

h.Lock()
ln2, ok := h.forwards[addr]
if ok && ln2 == ln {
delete(h.forwards, addr)
if ln2, ok := h.forwards[key]; ok && ln2 == ln {
delete(h.forwards, key)
}
h.Unlock()
log.Debug(ctx, "SSH unix forward listener removed from cache", slog.F("path", addr))
log.Debug(ctx, "SSH unix forward listener removed from cache")
_ = ln.Close()
}()

Expand All @@ -162,13 +182,22 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com (SSH unix forward) request payload from client", slog.Error(err))
return false, nil
}
log.Debug(ctx, "request to cancel SSH unix forward", slog.F("path", reqPayload.SocketPath))
log.Debug(ctx, "request to cancel SSH unix forward", slog.F("socket_path", reqPayload.SocketPath))

key := forwardKey{
sessionID: ctx.SessionID(),
addr: reqPayload.SocketPath,
}

h.Lock()
ln, ok := h.forwards[reqPayload.SocketPath]
ln, ok := h.forwards[key]
delete(h.forwards, key)
h.Unlock()
if ok {
_ = ln.Close()
if !ok {
log.Warn(ctx, "SSH unix forward not found in cache")
return true, nil
}
_ = ln.Close()
return true, nil

default:
Expand Down Expand Up @@ -209,3 +238,15 @@ func directStreamLocalHandler(_ *ssh.Server, _ *gossh.ServerConn, newChan gossh.

Bicopy(ctx, ch, dconn)
}

// unlink removes files and unlike os.Remove, directories are kept.
func unlink(path string) error {
// Ignore EINTR like os.Remove, see ignoringEINTR in os/file_posix.go
// for more details.
for {
err := syscall.Unlink(path)
if !errors.Is(err, syscall.EINTR) {
return err
}
}
}
118 changes: 115 additions & 3 deletions cli/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
gosshagent "golang.org/x/crypto/ssh/agent"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"

"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/cli/cliui"
Expand Down Expand Up @@ -738,8 +740,8 @@ func TestSSH(t *testing.T) {
defer cancel()

tmpdir := tempDirUnixSocket(t)
agentSock := filepath.Join(tmpdir, "agent.sock")
l, err := net.Listen("unix", agentSock)
localSock := filepath.Join(tmpdir, "local.sock")
l, err := net.Listen("unix", localSock)
require.NoError(t, err)
defer l.Close()
remoteSock := filepath.Join(tmpdir, "remote.sock")
Expand All @@ -748,7 +750,7 @@ func TestSSH(t *testing.T) {
"ssh",
workspace.Name,
"--remote-forward",
fmt.Sprintf("%s:%s", remoteSock, agentSock),
fmt.Sprintf("%s:%s", remoteSock, localSock),
)
clitest.SetupConfig(t, client, root)
pty := ptytest.New(t).Attach(inv)
Expand All @@ -771,6 +773,116 @@ func TestSSH(t *testing.T) {
<-cmdDone
})

// Test that we can forward a local unix socket to a remote unix socket and
// that new SSH sessions take over the socket without closing active socket
// connections.
t.Run("RemoteForwardUnixSocketMultipleSessionsOverwrite", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Test not supported on windows")
}

t.Parallel()

client, workspace, agentToken := setupWorkspaceForAgent(t)

_ = agenttest.New(t, client.URL, agentToken)
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)

// Wait super super long so this doesn't flake on -race test.
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong*2)
defer cancel()

tmpdir := tempDirUnixSocket(t)

localSock := filepath.Join(tmpdir, "local.sock")
l, err := net.Listen("unix", localSock)
require.NoError(t, err)
defer l.Close()
testutil.Go(t, func() {
for {
fd, err := l.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) {
assert.NoError(t, err, "listener accept failed")
}
return
}

testutil.Go(t, func() {
defer fd.Close()
agentssh.Bicopy(ctx, fd, fd)
})
}
})

remoteSock := filepath.Join(tmpdir, "remote.sock")

var done []func() error
for i := 0; i < 2; i++ {
id := fmt.Sprintf("ssh-%d", i)
inv, root := clitest.New(t,
"ssh",
workspace.Name,
"--remote-forward",
fmt.Sprintf("%s:%s", remoteSock, localSock),
)
inv.Logger = inv.Logger.Named(id)
clitest.SetupConfig(t, client, root)
pty := ptytest.New(t).Attach(inv)
inv.Stderr = pty.Output()
cmdDone := tGo(t, func() {
err := inv.WithContext(ctx).Run()
assert.NoError(t, err, "ssh command failed: %s", id)
})

// Since something was output, it should be safe to write input.
// This could show a prompt or "running startup scripts", so it's
// not indicative of the SSH connection being ready.
_ = pty.Peek(ctx, 1)

// Ensure the SSH connection is ready by testing the shell
// input/output.
pty.WriteLine("echo ping' 'pong")
pty.ExpectMatchContext(ctx, "ping pong")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using the CLI to test the agent is awkward and this is an example. This test would be easier to write and easier to understand if were in agent_test.go and we could use the go ssh client.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do somewhat agree, but I'm not sure agent_test.go is a much better place since the actual forwarding is implemented here, in cli/remoteforward.go. This test should ensure that the behavior we want works with the actual coder ssh client.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have tests for coder ssh in terms of making a unix socket remote forward and verifying it works. None of that behavior is changing in this PR.

What's changing is on the agent side of the connection. And it's not just coder ssh that talks to that SSH server: it could be OpenSSH or JetBrains Gateway or VSCode (via coder ssh --stdio), or even a custom SSH client written in go using codersdk. Therefore, we need to think about the required behavior in terms of what the SSH server is supposed to do, not just end-to-end with a specific client.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'm not opposed tbh. I think the best place for most of this testing is within the agentssh package itself. I'll create an issue for this since if we go this route, there's some other restructuring that needs to be done as well. Like moving the forwarding logic from cli to agentssh, etc. Otherwise I think we'll end up with "many implementations of an SSH client".

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


d := &net.Dialer{}
fd, err := d.DialContext(ctx, "unix", remoteSock)
require.NoError(t, err, id)

// Ping / pong to ensure the socket is working.
_, err = fd.Write([]byte("hello world"))
require.NoError(t, err, id)

buf := make([]byte, 11)
_, err = fd.Read(buf)
require.NoError(t, err, id)
require.Equal(t, "hello world", string(buf), id)

done = append(done, func() error {
// Redo ping / pong to ensure that the socket
// connections still work.
_, err := fd.Write([]byte("hello world"))
assert.NoError(t, err, id)

buf := make([]byte, 11)
_, err = fd.Read(buf)
assert.NoError(t, err, id)
assert.Equal(t, "hello world", string(buf), id)

pty.WriteLine("exit")
<-cmdDone
return nil
})
}

var eg errgroup.Group
for _, d := range done {
eg.Go(d)
}
err = eg.Wait()
require.NoError(t, err)
})

t.Run("FileLogging", func(t *testing.T) {
t.Parallel()

Expand Down