Skip to content

Commit f400d8a

Browse files
authored
fix: handle SIGHUP from OpenSSH (#10638)
Fixes an issue where remote forwards are not correctly torn down when using OpenSSH with `coder ssh --stdio`. OpenSSH sends a disconnect signal, but then also sends SIGHUP to `coder`. Previously, we just exited when we got SIGHUP, and this raced against properly disconnecting. Fixes coder/customers#327
1 parent be0436a commit f400d8a

File tree

15 files changed

+248
-29
lines changed

15 files changed

+248
-29
lines changed

agent/agentssh/forward.go

+17-9
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type forwardedUnixHandler struct {
3737
}
3838

3939
func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, req *gossh.Request) (bool, []byte) {
40+
h.log.Debug(ctx, "handling SSH unix forward")
4041
h.Lock()
4142
if h.forwards == nil {
4243
h.forwards = make(map[string]net.Listener)
@@ -47,22 +48,25 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
4748
h.log.Warn(ctx, "SSH unix forward request from client with no gossh connection")
4849
return false, nil
4950
}
51+
log := h.log.With(slog.F("remote_addr", conn.RemoteAddr()))
5052

5153
switch req.Type {
5254
case "streamlocal-forward@openssh.com":
5355
var reqPayload streamLocalForwardPayload
5456
err := gossh.Unmarshal(req.Payload, &reqPayload)
5557
if err != nil {
56-
h.log.Warn(ctx, "parse streamlocal-forward@openssh.com request payload from client", slog.Error(err))
58+
h.log.Warn(ctx, "parse streamlocal-forward@openssh.com request (SSH unix forward) payload from client", slog.Error(err))
5759
return false, nil
5860
}
5961

6062
addr := reqPayload.SocketPath
63+
log = log.With(slog.F("socket_path", addr))
64+
log.Debug(ctx, "request begin SSH unix forward")
6165
h.Lock()
6266
_, ok := h.forwards[addr]
6367
h.Unlock()
6468
if ok {
65-
h.log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)",
69+
log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)",
6670
slog.F("socket_path", addr),
6771
)
6872
return false, nil
@@ -72,22 +76,22 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
7276
parentDir := filepath.Dir(addr)
7377
err = os.MkdirAll(parentDir, 0o700)
7478
if err != nil {
75-
h.log.Warn(ctx, "create parent dir for SSH unix forward request",
79+
log.Warn(ctx, "create parent dir for SSH unix forward request",
7680
slog.F("parent_dir", parentDir),
77-
slog.F("socket_path", addr),
7881
slog.Error(err),
7982
)
8083
return false, nil
8184
}
8285

8386
ln, err := net.Listen("unix", addr)
8487
if err != nil {
85-
h.log.Warn(ctx, "listen on Unix socket for SSH unix forward request",
88+
log.Warn(ctx, "listen on Unix socket for SSH unix forward request",
8689
slog.F("socket_path", addr),
8790
slog.Error(err),
8891
)
8992
return false, nil
9093
}
94+
log.Debug(ctx, "SSH unix forward listening on socket")
9195

9296
// The listener needs to successfully start before it can be added to
9397
// the map, so we don't have to worry about checking for an existing
@@ -97,6 +101,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
97101
h.Lock()
98102
h.forwards[addr] = ln
99103
h.Unlock()
104+
log.Debug(ctx, "SSH unix forward added to cache")
100105

101106
ctx, cancel := context.WithCancel(ctx)
102107
go func() {
@@ -110,22 +115,23 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
110115
c, err := ln.Accept()
111116
if err != nil {
112117
if !xerrors.Is(err, net.ErrClosed) {
113-
h.log.Warn(ctx, "accept on local Unix socket for SSH unix forward request",
114-
slog.F("socket_path", addr),
118+
log.Warn(ctx, "accept on local Unix socket for SSH unix forward request",
115119
slog.Error(err),
116120
)
117121
}
118122
// closed below
123+
log.Debug(ctx, "SSH unix forward listener closed")
119124
break
120125
}
126+
log.Debug(ctx, "accepted SSH unix forward connection")
121127
payload := gossh.Marshal(&forwardedStreamLocalPayload{
122128
SocketPath: addr,
123129
})
124130

125131
go func() {
126132
ch, reqs, err := conn.OpenChannel("forwarded-streamlocal@openssh.com", payload)
127133
if err != nil {
128-
h.log.Warn(ctx, "open SSH channel to forward Unix connection to client",
134+
h.log.Warn(ctx, "open SSH unix forward channel to client",
129135
slog.F("socket_path", addr),
130136
slog.Error(err),
131137
)
@@ -143,6 +149,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
143149
delete(h.forwards, addr)
144150
}
145151
h.Unlock()
152+
log.Debug(ctx, "SSH unix forward listener removed from cache", slog.F("path", addr))
146153
_ = ln.Close()
147154
}()
148155

@@ -152,9 +159,10 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
152159
var reqPayload streamLocalForwardPayload
153160
err := gossh.Unmarshal(req.Payload, &reqPayload)
154161
if err != nil {
155-
h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com request payload from client", slog.Error(err))
162+
h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com (SSH unix forward) request payload from client", slog.Error(err))
156163
return false, nil
157164
}
165+
log.Debug(ctx, "request to cancel SSH unix forward", slog.F("path", reqPayload.SocketPath))
158166
h.Lock()
159167
ln, ok := h.forwards[reqPayload.SocketPath]
160168
h.Unlock()

cli/agent.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"net/http/pprof"
99
"net/url"
1010
"os"
11-
"os/signal"
1211
"path/filepath"
1312
"runtime"
1413
"strconv"
@@ -144,7 +143,7 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd {
144143
// Note that we don't want to handle these signals in the
145144
// process that runs as PID 1, that's why we do this after
146145
// the reaper forked.
147-
ctx, stopNotify := signal.NotifyContext(ctx, InterruptSignals...)
146+
ctx, stopNotify := inv.SignalNotifyContext(ctx, InterruptSignals...)
148147
defer stopNotify()
149148

150149
// DumpHandler does signal handling, so we call it after the

cli/clibase/cmd.go

+25
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ import (
77
"fmt"
88
"io"
99
"os"
10+
"os/signal"
1011
"strings"
12+
"testing"
1113
"unicode"
1214

1315
"github.com/spf13/pflag"
@@ -183,6 +185,9 @@ type Invocation struct {
183185
Stdout io.Writer
184186
Stderr io.Writer
185187
Stdin io.Reader
188+
189+
// testing
190+
signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc)
186191
}
187192

188193
// WithOS returns the invocation as a main package, filling in the invocation's unset
@@ -197,6 +202,26 @@ func (inv *Invocation) WithOS() *Invocation {
197202
})
198203
}
199204

205+
// WithTestSignalNotifyContext allows overriding the default implementation of SignalNotifyContext.
206+
// This should only be used in testing.
207+
func (inv *Invocation) WithTestSignalNotifyContext(
208+
_ testing.TB, // ensure we only call this from tests
209+
f func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc),
210+
) *Invocation {
211+
return inv.with(func(i *Invocation) {
212+
i.signalNotifyContext = f
213+
})
214+
}
215+
216+
// SignalNotifyContext is equivalent to signal.NotifyContext, but supports being overridden in
217+
// tests.
218+
func (inv *Invocation) SignalNotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
219+
if inv.signalNotifyContext == nil {
220+
return signal.NotifyContext(parent, signals...)
221+
}
222+
return inv.signalNotifyContext(parent, signals...)
223+
}
224+
200225
func (inv *Invocation) Context() context.Context {
201226
if inv.ctx == nil {
202227
return context.Background()

cli/clitest/signal.go

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package clitest
2+
3+
import (
4+
"context"
5+
"os"
6+
"sync"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
type FakeSignalNotifier struct {
13+
sync.Mutex
14+
t *testing.T
15+
ctx context.Context
16+
cancel context.CancelFunc
17+
signals []os.Signal
18+
stopped bool
19+
}
20+
21+
func NewFakeSignalNotifier(t *testing.T) *FakeSignalNotifier {
22+
fsn := &FakeSignalNotifier{t: t}
23+
return fsn
24+
}
25+
26+
func (f *FakeSignalNotifier) Stop() {
27+
f.Lock()
28+
defer f.Unlock()
29+
f.stopped = true
30+
if f.cancel == nil {
31+
f.t.Error("stopped before started")
32+
return
33+
}
34+
f.cancel()
35+
}
36+
37+
func (f *FakeSignalNotifier) NotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
38+
f.Lock()
39+
defer f.Unlock()
40+
f.signals = signals
41+
f.ctx, f.cancel = context.WithCancel(parent)
42+
return f.ctx, f.Stop
43+
}
44+
45+
func (f *FakeSignalNotifier) Notify() {
46+
f.Lock()
47+
defer f.Unlock()
48+
if f.cancel == nil {
49+
f.t.Error("notified before started")
50+
return
51+
}
52+
f.cancel()
53+
}
54+
55+
func (f *FakeSignalNotifier) AssertStopped() {
56+
f.Lock()
57+
defer f.Unlock()
58+
assert.True(f.t, f.stopped)
59+
}

cli/externalauth.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package cli
22

33
import (
44
"encoding/json"
5-
"os/signal"
65

76
"golang.org/x/xerrors"
87

@@ -63,7 +62,7 @@ fi
6362
Handler: func(inv *clibase.Invocation) error {
6463
ctx := inv.Context()
6564

66-
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
65+
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
6766
defer stop()
6867

6968
client, err := r.createAgentClient()

cli/gitaskpass.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"errors"
55
"fmt"
66
"net/http"
7-
"os/signal"
87
"time"
98

109
"golang.org/x/xerrors"
@@ -26,7 +25,7 @@ func (r *RootCmd) gitAskpass() *clibase.Cmd {
2625
Handler: func(inv *clibase.Invocation) error {
2726
ctx := inv.Context()
2827

29-
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
28+
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
3029
defer stop()
3130

3231
user, host, err := gitauth.ParseAskpass(inv.Args[0])

cli/gitssh.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"io"
99
"os"
1010
"os/exec"
11-
"os/signal"
1211
"path/filepath"
1312
"strings"
1413

@@ -30,7 +29,7 @@ func (r *RootCmd) gitssh() *clibase.Cmd {
3029

3130
// Catch interrupt signals to ensure the temporary private
3231
// key file is cleaned up on most cases.
33-
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
32+
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
3433
defer stop()
3534

3635
// Early check so errors are reported immediately.

cli/server.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"net/http/pprof"
2323
"net/url"
2424
"os"
25-
"os/signal"
2625
"os/user"
2726
"path/filepath"
2827
"regexp"
@@ -333,7 +332,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
333332
//
334333
// To get out of a graceful shutdown, the user can send
335334
// SIGQUIT with ctrl+\ or SIGKILL with `kill -9`.
336-
notifyCtx, notifyStop := signal.NotifyContext(ctx, InterruptSignals...)
335+
notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, InterruptSignals...)
337336
defer notifyStop()
338337

339338
cacheDir := vals.CacheDir.String()
@@ -1098,7 +1097,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
10981097
logger = logger.Leveled(slog.LevelDebug)
10991098
}
11001099

1101-
ctx, cancel := signal.NotifyContext(ctx, InterruptSignals...)
1100+
ctx, cancel := inv.SignalNotifyContext(ctx, InterruptSignals...)
11021101
defer cancel()
11031102

11041103
url, closePg, err := startBuiltinPostgres(ctx, cfg, logger)

cli/server_createadminuser.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ package cli
44

55
import (
66
"fmt"
7-
"os/signal"
87
"sort"
98

109
"github.com/google/uuid"
@@ -48,7 +47,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *clibase.Cmd {
4847
logger = logger.Leveled(slog.LevelDebug)
4948
}
5049

51-
ctx, cancel := signal.NotifyContext(ctx, InterruptSignals...)
50+
ctx, cancel := inv.SignalNotifyContext(ctx, InterruptSignals...)
5251
defer cancel()
5352

5453
if newUserDBURL == "" {

cli/ssh.go

+10-3
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,15 @@ func (r *RootCmd) ssh() *clibase.Cmd {
6262
r.InitClient(client),
6363
),
6464
Handler: func(inv *clibase.Invocation) (retErr error) {
65-
ctx, cancel := context.WithCancel(inv.Context())
65+
// Before dialing the SSH server over TCP, capture Interrupt signals
66+
// so that if we are interrupted, we have a chance to tear down the
67+
// TCP session cleanly before exiting. If we don't, then the TCP
68+
// session can persist for up to 72 hours, since we set a long
69+
// timeout on the Agent side of the connection. In particular,
70+
// OpenSSH sends SIGHUP to terminate a proxy command.
71+
ctx, stop := inv.SignalNotifyContext(inv.Context(), InterruptSignals...)
72+
defer stop()
73+
ctx, cancel := context.WithCancel(ctx)
6674
defer cancel()
6775

6876
logger := slog.Make() // empty logger
@@ -227,8 +235,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
227235
go func() {
228236
defer wg.Done()
229237
// Ensure stdout copy closes incase stdin is closed
230-
// unexpectedly. Typically we wouldn't worry about
231-
// this since OpenSSH should kill the proxy command.
238+
// unexpectedly.
232239
defer rawSSH.Close()
233240

234241
_, err := io.Copy(rawSSH, inv.Stdin)

0 commit comments

Comments
 (0)