Skip to content

Commit 96c0724

Browse files
committed
fix(agent/agentssh): allow remote forwarding a socket multiple times
Fixes #11198 Fixes coder/customers#407
1 parent 054420b commit 96c0724

File tree

3 files changed

+184
-36
lines changed

3 files changed

+184
-36
lines changed

agent/agentssh/agentssh.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
9999
}
100100

101101
forwardHandler := &ssh.ForwardedTCPHandler{}
102-
unixForwardHandler := &forwardedUnixHandler{log: logger}
102+
unixForwardHandler := newForwardedUnixHandler(logger)
103103

104104
metrics := newSSHServerMetrics(prometheusRegistry)
105105
s := &Server{

agent/agentssh/forward.go

Lines changed: 67 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@ package agentssh
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
7+
"io/fs"
68
"net"
79
"os"
810
"path/filepath"
911
"sync"
12+
"syscall"
1013

1114
"github.com/gliderlabs/ssh"
1215
gossh "golang.org/x/crypto/ssh"
@@ -33,22 +36,29 @@ type forwardedStreamLocalPayload struct {
3336
type forwardedUnixHandler struct {
3437
sync.Mutex
3538
log slog.Logger
36-
forwards map[string]net.Listener
39+
forwards map[forwardKey]net.Listener
40+
}
41+
42+
type forwardKey struct {
43+
sessionID string
44+
addr string
45+
}
46+
47+
func newForwardedUnixHandler(log slog.Logger) *forwardedUnixHandler {
48+
return &forwardedUnixHandler{
49+
log: log,
50+
forwards: make(map[forwardKey]net.Listener),
51+
}
3752
}
3853

3954
func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, req *gossh.Request) (bool, []byte) {
4055
h.log.Debug(ctx, "handling SSH unix forward")
41-
h.Lock()
42-
if h.forwards == nil {
43-
h.forwards = make(map[string]net.Listener)
44-
}
45-
h.Unlock()
4656
conn, ok := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
4757
if !ok {
4858
h.log.Warn(ctx, "SSH unix forward request from client with no gossh connection")
4959
return false, nil
5060
}
51-
log := h.log.With(slog.F("remote_addr", conn.RemoteAddr()))
61+
log := h.log.With(slog.F("session_id", ctx.SessionID()), slog.F("remote_addr", conn.RemoteAddr()))
5262

5363
switch req.Type {
5464
case "streamlocal-forward@openssh.com":
@@ -62,13 +72,17 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
6272
addr := reqPayload.SocketPath
6373
log = log.With(slog.F("socket_path", addr))
6474
log.Debug(ctx, "request begin SSH unix forward")
75+
76+
key := forwardKey{
77+
sessionID: ctx.SessionID(),
78+
addr: addr,
79+
}
80+
6581
h.Lock()
66-
_, ok := h.forwards[addr]
82+
_, ok := h.forwards[key]
6783
h.Unlock()
6884
if ok {
69-
log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)",
70-
slog.F("socket_path", addr),
71-
)
85+
log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded on this session")
7286
return false, nil
7387
}
7488

@@ -83,12 +97,18 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
8397
return false, nil
8498
}
8599

86-
ln, err := net.Listen("unix", addr)
100+
// Remove existing socket if it exists. It's possible we will overwrite
101+
// a regular file here, but this matches the behavior of OpenSSH.
102+
err = unlink(addr)
103+
if err != nil && !errors.Is(err, fs.ErrNotExist) {
104+
log.Warn(ctx, "remove existing socket for SSH unix forward request", slog.Error(err))
105+
return false, nil
106+
}
107+
108+
lc := &net.ListenConfig{}
109+
ln, err := lc.Listen(ctx, "unix", addr)
87110
if err != nil {
88-
log.Warn(ctx, "listen on Unix socket for SSH unix forward request",
89-
slog.F("socket_path", addr),
90-
slog.Error(err),
91-
)
111+
log.Warn(ctx, "listen on Unix socket for SSH unix forward request", slog.Error(err))
92112
return false, nil
93113
}
94114
log.Debug(ctx, "SSH unix forward listening on socket")
@@ -99,7 +119,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
99119
//
100120
// This is also what the upstream TCP version of this code does.
101121
h.Lock()
102-
h.forwards[addr] = ln
122+
h.forwards[key] = ln
103123
h.Unlock()
104124
log.Debug(ctx, "SSH unix forward added to cache")
105125

@@ -115,9 +135,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
115135
c, err := ln.Accept()
116136
if err != nil {
117137
if !xerrors.Is(err, net.ErrClosed) {
118-
log.Warn(ctx, "accept on local Unix socket for SSH unix forward request",
119-
slog.Error(err),
120-
)
138+
log.Warn(ctx, "accept on local Unix socket for SSH unix forward request", slog.Error(err))
121139
}
122140
// closed below
123141
log.Debug(ctx, "SSH unix forward listener closed")
@@ -131,10 +149,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
131149
go func() {
132150
ch, reqs, err := conn.OpenChannel("forwarded-streamlocal@openssh.com", payload)
133151
if err != nil {
134-
h.log.Warn(ctx, "open SSH unix forward channel to client",
135-
slog.F("socket_path", addr),
136-
slog.Error(err),
137-
)
152+
h.log.Warn(ctx, "open SSH unix forward channel to client", slog.Error(err))
138153
_ = c.Close()
139154
return
140155
}
@@ -144,12 +159,11 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
144159
}
145160

146161
h.Lock()
147-
ln2, ok := h.forwards[addr]
148-
if ok && ln2 == ln {
149-
delete(h.forwards, addr)
162+
if ln2, ok := h.forwards[key]; ok && ln2 == ln {
163+
delete(h.forwards, key)
150164
}
151165
h.Unlock()
152-
log.Debug(ctx, "SSH unix forward listener removed from cache", slog.F("path", addr))
166+
log.Debug(ctx, "SSH unix forward listener removed from cache")
153167
_ = ln.Close()
154168
}()
155169

@@ -162,13 +176,23 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
162176
h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com (SSH unix forward) request payload from client", slog.Error(err))
163177
return false, nil
164178
}
165-
log.Debug(ctx, "request to cancel SSH unix forward", slog.F("path", reqPayload.SocketPath))
179+
log.Debug(ctx, "request to cancel SSH unix forward", slog.F("socket_path", reqPayload.SocketPath))
180+
181+
key := forwardKey{
182+
sessionID: ctx.SessionID(),
183+
addr: reqPayload.SocketPath,
184+
}
185+
166186
h.Lock()
167-
ln, ok := h.forwards[reqPayload.SocketPath]
187+
ln, ok := h.forwards[key]
188+
delete(h.forwards, key)
168189
h.Unlock()
169-
if ok {
170-
_ = ln.Close()
190+
if !ok {
191+
log.Warn(ctx, "SSH unix forward not found in cache")
192+
return true, nil
171193
}
194+
log.Debug(ctx, "SSH unix forward listener removed from cache")
195+
_ = ln.Close()
172196
return true, nil
173197

174198
default:
@@ -209,3 +233,14 @@ func directStreamLocalHandler(_ *ssh.Server, _ *gossh.ServerConn, newChan gossh.
209233

210234
Bicopy(ctx, ch, dconn)
211235
}
236+
237+
// unlink removes files only.
238+
func unlink(path string) error {
239+
// From os/file_posix.go:
240+
for {
241+
err := syscall.Unlink(path)
242+
if !errors.Is(err, syscall.EINTR) {
243+
return err
244+
}
245+
}
246+
}

cli/ssh_test.go

Lines changed: 116 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@ import (
2626
"github.com/stretchr/testify/require"
2727
"golang.org/x/crypto/ssh"
2828
gosshagent "golang.org/x/crypto/ssh/agent"
29+
"golang.org/x/sync/errgroup"
2930
"golang.org/x/xerrors"
3031

3132
"cdr.dev/slog"
3233
"cdr.dev/slog/sloggers/slogtest"
3334

3435
"github.com/coder/coder/v2/agent"
36+
"github.com/coder/coder/v2/agent/agentssh"
3537
"github.com/coder/coder/v2/agent/agenttest"
3638
"github.com/coder/coder/v2/cli/clitest"
3739
"github.com/coder/coder/v2/cli/cliui"
@@ -738,8 +740,8 @@ func TestSSH(t *testing.T) {
738740
defer cancel()
739741

740742
tmpdir := tempDirUnixSocket(t)
741-
agentSock := filepath.Join(tmpdir, "agent.sock")
742-
l, err := net.Listen("unix", agentSock)
743+
localSock := filepath.Join(tmpdir, "local.sock")
744+
l, err := net.Listen("unix", localSock)
743745
require.NoError(t, err)
744746
defer l.Close()
745747
remoteSock := filepath.Join(tmpdir, "remote.sock")
@@ -748,7 +750,7 @@ func TestSSH(t *testing.T) {
748750
"ssh",
749751
workspace.Name,
750752
"--remote-forward",
751-
fmt.Sprintf("%s:%s", remoteSock, agentSock),
753+
fmt.Sprintf("%s:%s", remoteSock, localSock),
752754
)
753755
clitest.SetupConfig(t, client, root)
754756
pty := ptytest.New(t).Attach(inv)
@@ -771,6 +773,117 @@ func TestSSH(t *testing.T) {
771773
<-cmdDone
772774
})
773775

776+
// Test that we can forward a local unix socket to a remote unix socket and
777+
// that new SSH sessions take over the socket without closing active socket
778+
// connections.
779+
t.Run("RemoteForwardUnixSocketMultipleSessionsOverwrite", func(t *testing.T) {
780+
if runtime.GOOS == "windows" {
781+
t.Skip("Test not supported on windows")
782+
}
783+
784+
t.Parallel()
785+
786+
client, workspace, agentToken := setupWorkspaceForAgent(t)
787+
788+
_ = agenttest.New(t, client.URL, agentToken)
789+
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
790+
791+
// Wait super super long so this doesn't flake on -race test.
792+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong*2)
793+
defer cancel()
794+
795+
tmpdir := tempDirUnixSocket(t)
796+
797+
localSock := filepath.Join(tmpdir, "local.sock")
798+
l, err := net.Listen("unix", localSock)
799+
require.NoError(t, err)
800+
defer l.Close()
801+
testutil.Go(t, func() {
802+
for {
803+
fd, err := l.Accept()
804+
if err != nil {
805+
if !errors.Is(err, net.ErrClosed) {
806+
assert.NoError(t, err, "listener accept failed")
807+
}
808+
return
809+
}
810+
811+
testutil.Go(t, func() {
812+
defer fd.Close()
813+
agentssh.Bicopy(ctx, fd, fd)
814+
})
815+
}
816+
})
817+
818+
remoteSock := filepath.Join(tmpdir, "remote.sock")
819+
820+
var done []func() error
821+
for i := 0; i < 2; i++ {
822+
id := fmt.Sprintf("ssh-%d", i)
823+
inv, root := clitest.New(t,
824+
"ssh",
825+
workspace.Name,
826+
"--remote-forward",
827+
fmt.Sprintf("%s:%s", remoteSock, localSock),
828+
)
829+
inv.Logger = inv.Logger.Named(id)
830+
clitest.SetupConfig(t, client, root)
831+
pty := ptytest.New(t).Attach(inv)
832+
inv.Stderr = pty.Output()
833+
cmdDone := tGo(t, func() {
834+
err := inv.WithContext(ctx).Run()
835+
assert.NoError(t, err, "ssh command failed: %s", id)
836+
})
837+
838+
// Since something was output, it should be safe to write input.
839+
// This could show a prompt or "running startup scripts", so it's
840+
// not indicative of the SSH connection being ready.
841+
_ = pty.Peek(ctx, 1)
842+
843+
// Ensure the SSH connection is ready by testing the shell
844+
// input/output.
845+
pty.WriteLine("echo ping' 'pong")
846+
pty.ExpectMatchContext(ctx, "ping pong")
847+
848+
d := &net.Dialer{}
849+
fd, err := d.DialContext(ctx, "unix", remoteSock)
850+
require.NoError(t, err, id)
851+
852+
// Send a message to the server.
853+
_, err = fd.Write([]byte("hello world"))
854+
require.NoError(t, err, id)
855+
856+
// Read the response.
857+
buf := make([]byte, 11)
858+
_, err = fd.Read(buf)
859+
require.NoError(t, err, id)
860+
require.Equal(t, "hello world", string(buf), id)
861+
862+
done = append(done, func() error {
863+
// Test that both socket connections still work.
864+
_, err := fd.Write([]byte("hello world"))
865+
require.NoError(t, err, id)
866+
867+
// Read the response.
868+
buf := make([]byte, 11)
869+
_, err = fd.Read(buf)
870+
require.NoError(t, err, id)
871+
require.Equal(t, "hello world", string(buf), id)
872+
873+
pty.WriteLine("exit")
874+
<-cmdDone
875+
return nil
876+
})
877+
}
878+
879+
var eg errgroup.Group
880+
for _, d := range done {
881+
eg.Go(d)
882+
}
883+
err = eg.Wait()
884+
require.NoError(t, err)
885+
})
886+
774887
t.Run("FileLogging", func(t *testing.T) {
775888
t.Parallel()
776889

0 commit comments

Comments
 (0)