Skip to content

Commit 5c48787

Browse files
committed
Fix coder ssh to handle SIGHUP from OpenSSH
1 parent be0436a commit 5c48787

File tree

16 files changed

+310
-21
lines changed

16 files changed

+310
-21
lines changed

agent/agentssh/agentssh_test.go

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ package agentssh_test
55
import (
66
"bytes"
77
"context"
8+
"io"
89
"net"
10+
"os"
11+
"path"
912
"runtime"
1013
"strings"
1114
"sync"
@@ -19,11 +22,13 @@ import (
1922
"go.uber.org/goleak"
2023
"golang.org/x/crypto/ssh"
2124

25+
"cdr.dev/slog"
2226
"cdr.dev/slog/sloggers/slogtest"
2327

2428
"github.com/coder/coder/v2/agent/agentssh"
2529
"github.com/coder/coder/v2/codersdk/agentsdk"
2630
"github.com/coder/coder/v2/pty/ptytest"
31+
"github.com/coder/coder/v2/testutil"
2732
)
2833

2934
func TestMain(m *testing.M) {
@@ -159,7 +164,85 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
159164
wg.Wait()
160165
}
161166

167+
func Test_UnixRemoteForward(t *testing.T) {
168+
t.Parallel()
169+
170+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
171+
defer cancel()
172+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
173+
s, err := agentssh.NewServer(ctx, logger.Named("ssh"), prometheus.NewRegistry(), afero.NewOsFs(), 0, "")
174+
require.NoError(t, err)
175+
defer s.Close()
176+
177+
// The assumption is that these are set before serving SSH connections.
178+
s.AgentToken = func() string { return "" }
179+
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
180+
181+
ln, err := net.Listen("tcp", "127.0.0.1:0")
182+
require.NoError(t, err)
183+
184+
serverErr := make(chan error)
185+
go func() {
186+
err := s.Serve(ln)
187+
serverErr <- err
188+
}()
189+
190+
dir := t.TempDir()
191+
remoteSock := path.Join(dir, "test.sock")
192+
193+
// connect, disconnect several times
194+
for i := 0; i < 4; i++ {
195+
client, clientConn := sshClientAndConn(t, ln.Addr().String())
196+
197+
l, err := client.ListenUnix(remoteSock)
198+
require.NoError(t, err)
199+
msgs := make(chan string)
200+
201+
// accept a single connection
202+
go func() {
203+
conn, err := l.Accept()
204+
assert.NoError(t, err)
205+
msg, err := io.ReadAll(conn)
206+
assert.NoError(t, err)
207+
msgs <- string(msg)
208+
}()
209+
210+
// write a single message
211+
go func() {
212+
conn, err := net.Dial("unix", remoteSock)
213+
assert.NoError(t, err)
214+
_, err = conn.Write([]byte("test"))
215+
assert.NoError(t, err)
216+
err = conn.Close()
217+
assert.NoError(t, err)
218+
}()
219+
220+
msg := recvCtx(ctx, t, msgs)
221+
require.Equal(t, "test", msg)
222+
223+
err = l.Close()
224+
require.NoError(t, err)
225+
226+
//err = client.Close()
227+
//require.NoError(t, err)
228+
err = clientConn.Close()
229+
require.NoError(t, err)
230+
}
231+
_, err = os.Stat(remoteSock)
232+
require.ErrorIs(t, err, os.ErrNotExist)
233+
234+
err = s.Close()
235+
require.NoError(t, err)
236+
// don't care whether we get an error
237+
_ = recvCtx(ctx, t, serverErr)
238+
}
239+
162240
func sshClient(t *testing.T, addr string) *ssh.Client {
241+
client, _ := sshClientAndConn(t, addr)
242+
return client
243+
}
244+
245+
func sshClientAndConn(t *testing.T, addr string) (*ssh.Client, net.Conn) {
163246
conn, err := net.Dial("tcp", addr)
164247
require.NoError(t, err)
165248
t.Cleanup(func() {
@@ -177,5 +260,16 @@ func sshClient(t *testing.T, addr string) *ssh.Client {
177260
t.Cleanup(func() {
178261
_ = c.Close()
179262
})
180-
return c
263+
return c, conn
264+
}
265+
266+
func recvCtx[A any](ctx context.Context, t testing.TB, c <-chan A) (a A) {
267+
t.Helper()
268+
select {
269+
case <-ctx.Done():
270+
t.Fatal("timeout")
271+
return
272+
case a = <-c:
273+
return a
274+
}
181275
}

agent/agentssh/forward.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@ type forwardedStreamLocalPayload struct {
2828
Reserved uint32
2929
}
3030

31+
// unixSock represents a local unix socket that forwards connections over an SSH ServerConn
32+
type unixSock struct {
33+
sync.Mutex
34+
path string
35+
listener net.Listener
36+
conn *gossh.ServerConn
37+
closed bool
38+
}
39+
3140
// forwardedUnixHandler is a clone of ssh.ForwardedTCPHandler that does
3241
// streamlocal forwarding (aka. unix forwarding) instead of TCP forwarding.
3342
type forwardedUnixHandler struct {
@@ -37,6 +46,7 @@ type forwardedUnixHandler struct {
3746
}
3847

3948
func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, req *gossh.Request) (bool, []byte) {
49+
h.log.Debug(ctx, "handling unix remote forward")
4050
h.Lock()
4151
if h.forwards == nil {
4252
h.forwards = make(map[string]net.Listener)
@@ -47,6 +57,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
4757
h.log.Warn(ctx, "SSH unix forward request from client with no gossh connection")
4858
return false, nil
4959
}
60+
log := h.log.With(slog.F("remote_addr", conn.RemoteAddr()))
5061

5162
switch req.Type {
5263
case "streamlocal-forward@openssh.com":
@@ -58,6 +69,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
5869
}
5970

6071
addr := reqPayload.SocketPath
72+
log.Debug(ctx, "request begin unix remote-forward", slog.F("path", addr))
6173
h.Lock()
6274
_, ok := h.forwards[addr]
6375
h.Unlock()
@@ -116,8 +128,10 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
116128
)
117129
}
118130
// closed below
131+
log.Debug(ctx, "remote-forward listener closed", slog.F("path", addr))
119132
break
120133
}
134+
log.Debug(ctx, "accepted remote-forward unix connection", slog.F("path", addr))
121135
payload := gossh.Marshal(&forwardedStreamLocalPayload{
122136
SocketPath: addr,
123137
})
@@ -143,6 +157,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
143157
delete(h.forwards, addr)
144158
}
145159
h.Unlock()
160+
log.Debug(ctx, "unix remote-forward listener removed from cache", slog.F("path", addr))
146161
_ = ln.Close()
147162
}()
148163

@@ -155,6 +170,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
155170
h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com request payload from client", slog.Error(err))
156171
return false, nil
157172
}
173+
log.Debug(ctx, "request to cancel unix remote-forward", slog.F("path", reqPayload.SocketPath))
158174
h.Lock()
159175
ln, ok := h.forwards[reqPayload.SocketPath]
160176
h.Unlock()

cli/agent.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"net/http/pprof"
99
"net/url"
1010
"os"
11-
"os/signal"
1211
"path/filepath"
1312
"runtime"
1413
"strconv"
@@ -144,7 +143,7 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd {
144143
// Note that we don't want to handle these signals in the
145144
// process that runs as PID 1, that's why we do this after
146145
// the reaper forked.
147-
ctx, stopNotify := signal.NotifyContext(ctx, InterruptSignals...)
146+
ctx, stopNotify := inv.SignalNotifyContext(ctx, InterruptSignals...)
148147
defer stopNotify()
149148

150149
// DumpHandler does signal handling, so we call it after the

cli/clibase/cmd.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"io"
99
"os"
10+
"os/signal"
1011
"strings"
1112
"unicode"
1213

@@ -183,6 +184,9 @@ type Invocation struct {
183184
Stdout io.Writer
184185
Stderr io.Writer
185186
Stdin io.Reader
187+
188+
// testing
189+
signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc)
186190
}
187191

188192
// WithOS returns the invocation as a main package, filling in the invocation's unset
@@ -197,6 +201,25 @@ func (inv *Invocation) WithOS() *Invocation {
197201
})
198202
}
199203

204+
// WithTestSignalNotifyContext allows overriding the default implementation of SignalNotifyContext.
205+
// This should only be used in testing.
206+
func (inv *Invocation) WithTestSignalNotifyContext(
207+
f func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc),
208+
) *Invocation {
209+
return inv.with(func(i *Invocation) {
210+
i.signalNotifyContext = f
211+
})
212+
}
213+
214+
// SignalNotifyContext is equivalent to signal.NotifyContext, but supports being overridden in
215+
// tests.
216+
func (inv *Invocation) SignalNotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
217+
if inv.signalNotifyContext == nil {
218+
return signal.NotifyContext(parent, signals...)
219+
}
220+
return inv.signalNotifyContext(parent, signals...)
221+
}
222+
200223
func (inv *Invocation) Context() context.Context {
201224
if inv.ctx == nil {
202225
return context.Background()

cli/clitest/signal.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package clitest
2+
3+
import (
4+
"context"
5+
"os"
6+
"sync"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
type FakeSignalNotifier struct {
13+
sync.Mutex
14+
t *testing.T
15+
ctx context.Context
16+
cancel context.CancelFunc
17+
signals []os.Signal
18+
stopped bool
19+
}
20+
21+
func NewFakeSignalNotifier(t *testing.T) *FakeSignalNotifier {
22+
fsn := &FakeSignalNotifier{t: t}
23+
return fsn
24+
}
25+
26+
func (f *FakeSignalNotifier) Stop() {
27+
f.Lock()
28+
defer f.Unlock()
29+
f.stopped = true
30+
if f.cancel == nil {
31+
f.t.Error("stopped before started")
32+
return
33+
}
34+
f.cancel()
35+
}
36+
37+
func (f *FakeSignalNotifier) NotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
38+
f.Lock()
39+
defer f.Unlock()
40+
f.signals = signals
41+
f.ctx, f.cancel = context.WithCancel(parent)
42+
return f.ctx, f.Stop
43+
}
44+
45+
func (f *FakeSignalNotifier) Notify() {
46+
f.Lock()
47+
defer f.Unlock()
48+
if f.cancel == nil {
49+
f.t.Error("notified before started")
50+
return
51+
}
52+
f.cancel()
53+
}
54+
55+
func (f *FakeSignalNotifier) AssertStopped() {
56+
f.Lock()
57+
defer f.Unlock()
58+
assert.True(f.t, f.stopped)
59+
}

cli/externalauth.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package cli
22

33
import (
44
"encoding/json"
5-
"os/signal"
65

76
"golang.org/x/xerrors"
87

@@ -63,7 +62,7 @@ fi
6362
Handler: func(inv *clibase.Invocation) error {
6463
ctx := inv.Context()
6564

66-
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
65+
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
6766
defer stop()
6867

6968
client, err := r.createAgentClient()

cli/gitaskpass.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"errors"
55
"fmt"
66
"net/http"
7-
"os/signal"
87
"time"
98

109
"golang.org/x/xerrors"
@@ -26,7 +25,7 @@ func (r *RootCmd) gitAskpass() *clibase.Cmd {
2625
Handler: func(inv *clibase.Invocation) error {
2726
ctx := inv.Context()
2827

29-
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
28+
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
3029
defer stop()
3130

3231
user, host, err := gitauth.ParseAskpass(inv.Args[0])

cli/gitssh.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"io"
99
"os"
1010
"os/exec"
11-
"os/signal"
1211
"path/filepath"
1312
"strings"
1413

@@ -30,7 +29,7 @@ func (r *RootCmd) gitssh() *clibase.Cmd {
3029

3130
// Catch interrupt signals to ensure the temporary private
3231
// key file is cleaned up on most cases.
33-
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
32+
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
3433
defer stop()
3534

3635
// Early check so errors are reported immediately.

cli/server.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"net/http/pprof"
2323
"net/url"
2424
"os"
25-
"os/signal"
2625
"os/user"
2726
"path/filepath"
2827
"regexp"
@@ -333,7 +332,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
333332
//
334333
// To get out of a graceful shutdown, the user can send
335334
// SIGQUIT with ctrl+\ or SIGKILL with `kill -9`.
336-
notifyCtx, notifyStop := signal.NotifyContext(ctx, InterruptSignals...)
335+
notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, InterruptSignals...)
337336
defer notifyStop()
338337

339338
cacheDir := vals.CacheDir.String()
@@ -1098,7 +1097,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
10981097
logger = logger.Leveled(slog.LevelDebug)
10991098
}
11001099

1101-
ctx, cancel := signal.NotifyContext(ctx, InterruptSignals...)
1100+
ctx, cancel := inv.SignalNotifyContext(ctx, InterruptSignals...)
11021101
defer cancel()
11031102

11041103
url, closePg, err := startBuiltinPostgres(ctx, cfg, logger)

0 commit comments

Comments
 (0)