Skip to content

Commit 5732ab8

Browse files
committed
Fix coder ssh to handle SIGHUP from OpenSSH
1 parent be0436a commit 5732ab8

File tree

15 files changed

+206
-20
lines changed

15 files changed

+206
-20
lines changed

agent/agentssh/forward.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type forwardedUnixHandler struct {
3737
}
3838

3939
func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, req *gossh.Request) (bool, []byte) {
40+
h.log.Debug(ctx, "handling unix remote forward")
4041
h.Lock()
4142
if h.forwards == nil {
4243
h.forwards = make(map[string]net.Listener)
@@ -47,6 +48,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
4748
h.log.Warn(ctx, "SSH unix forward request from client with no gossh connection")
4849
return false, nil
4950
}
51+
log := h.log.With(slog.F("remote_addr", conn.RemoteAddr()))
5052

5153
switch req.Type {
5254
case "streamlocal-forward@openssh.com":
@@ -58,6 +60,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
5860
}
5961

6062
addr := reqPayload.SocketPath
63+
log.Debug(ctx, "request begin unix remote-forward", slog.F("path", addr))
6164
h.Lock()
6265
_, ok := h.forwards[addr]
6366
h.Unlock()
@@ -116,8 +119,10 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
116119
)
117120
}
118121
// closed below
122+
log.Debug(ctx, "remote-forward listener closed", slog.F("path", addr))
119123
break
120124
}
125+
log.Debug(ctx, "accepted remote-forward unix connection", slog.F("path", addr))
121126
payload := gossh.Marshal(&forwardedStreamLocalPayload{
122127
SocketPath: addr,
123128
})
@@ -143,6 +148,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
143148
delete(h.forwards, addr)
144149
}
145150
h.Unlock()
151+
log.Debug(ctx, "unix remote-forward listener removed from cache", slog.F("path", addr))
146152
_ = ln.Close()
147153
}()
148154

@@ -155,6 +161,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
155161
h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com request payload from client", slog.Error(err))
156162
return false, nil
157163
}
164+
log.Debug(ctx, "request to cancel unix remote-forward", slog.F("path", reqPayload.SocketPath))
158165
h.Lock()
159166
ln, ok := h.forwards[reqPayload.SocketPath]
160167
h.Unlock()

cli/agent.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"net/http/pprof"
99
"net/url"
1010
"os"
11-
"os/signal"
1211
"path/filepath"
1312
"runtime"
1413
"strconv"
@@ -144,7 +143,7 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd {
144143
// Note that we don't want to handle these signals in the
145144
// process that runs as PID 1, that's why we do this after
146145
// the reaper forked.
147-
ctx, stopNotify := signal.NotifyContext(ctx, InterruptSignals...)
146+
ctx, stopNotify := inv.SignalNotifyContext(ctx, InterruptSignals...)
148147
defer stopNotify()
149148

150149
// DumpHandler does signal handling, so we call it after the

cli/clibase/cmd.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"io"
99
"os"
10+
"os/signal"
1011
"strings"
1112
"unicode"
1213

@@ -183,6 +184,9 @@ type Invocation struct {
183184
Stdout io.Writer
184185
Stderr io.Writer
185186
Stdin io.Reader
187+
188+
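// signalNotifyContext, when non-nil, is called by SignalNotifyContext in place of signal.NotifyContext. It is only set in tests.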
189+
signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc)
186190
}
187191

188192
// WithOS returns the invocation as a main package, filling in the invocation's unset
@@ -197,6 +201,25 @@ func (inv *Invocation) WithOS() *Invocation {
197201
})
198202
}
199203

204+
// WithTestSignalNotifyContext allows overriding the default implementation of SignalNotifyContext.
205+
// This should only be used in testing.
206+
func (inv *Invocation) WithTestSignalNotifyContext(
207+
f func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc),
208+
) *Invocation {
209+
return inv.with(func(i *Invocation) {
210+
i.signalNotifyContext = f
211+
})
212+
}
213+
214+
// SignalNotifyContext is equivalent to signal.NotifyContext, but supports being overridden in
215+
// tests.
216+
func (inv *Invocation) SignalNotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
217+
if inv.signalNotifyContext == nil {
218+
return signal.NotifyContext(parent, signals...)
219+
}
220+
return inv.signalNotifyContext(parent, signals...)
221+
}
222+
200223
func (inv *Invocation) Context() context.Context {
201224
if inv.ctx == nil {
202225
return context.Background()

cli/clitest/signal.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package clitest
2+
3+
import (
4+
"context"
5+
"os"
6+
"sync"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
// FakeSignalNotifier is a test stand-in for signal.NotifyContext. Its
// NotifyContext method has the same signature, and a test delivers the "signal"
// itself by calling Notify.
type FakeSignalNotifier struct {
	// Mutex guards every field below that is written after construction.
	sync.Mutex

	// t receives the errors and assertions raised by the fake.
	t *testing.T

	// ctx and cancel are the context pair created by the most recent
	// NotifyContext call; cancel is nil until NotifyContext has been called.
	ctx    context.Context
	cancel context.CancelFunc

	// signals holds the signals passed to the most recent NotifyContext call.
	signals []os.Signal

	// stopped is set once Stop has been called.
	stopped bool
}
20+
21+
func NewFakeSignalNotifier(t *testing.T) *FakeSignalNotifier {
22+
fsn := &FakeSignalNotifier{t: t}
23+
return fsn
24+
}
25+
26+
func (f *FakeSignalNotifier) Stop() {
27+
f.Lock()
28+
defer f.Unlock()
29+
f.stopped = true
30+
if f.cancel == nil {
31+
f.t.Error("stopped before started")
32+
return
33+
}
34+
f.cancel()
35+
}
36+
37+
func (f *FakeSignalNotifier) NotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
38+
f.Lock()
39+
defer f.Unlock()
40+
f.signals = signals
41+
f.ctx, f.cancel = context.WithCancel(parent)
42+
return f.ctx, f.Stop
43+
}
44+
45+
func (f *FakeSignalNotifier) Notify() {
46+
f.Lock()
47+
defer f.Unlock()
48+
if f.cancel == nil {
49+
f.t.Error("notified before started")
50+
return
51+
}
52+
f.cancel()
53+
}
54+
55+
func (f *FakeSignalNotifier) AssertStopped() {
56+
f.Lock()
57+
defer f.Unlock()
58+
assert.True(f.t, f.stopped)
59+
}

cli/externalauth.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package cli
22

33
import (
44
"encoding/json"
5-
"os/signal"
65

76
"golang.org/x/xerrors"
87

@@ -63,7 +62,7 @@ fi
6362
Handler: func(inv *clibase.Invocation) error {
6463
ctx := inv.Context()
6564

66-
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
65+
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
6766
defer stop()
6867

6968
client, err := r.createAgentClient()

cli/gitaskpass.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"errors"
55
"fmt"
66
"net/http"
7-
"os/signal"
87
"time"
98

109
"golang.org/x/xerrors"
@@ -26,7 +25,7 @@ func (r *RootCmd) gitAskpass() *clibase.Cmd {
2625
Handler: func(inv *clibase.Invocation) error {
2726
ctx := inv.Context()
2827

29-
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
28+
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
3029
defer stop()
3130

3231
user, host, err := gitauth.ParseAskpass(inv.Args[0])

cli/gitssh.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"io"
99
"os"
1010
"os/exec"
11-
"os/signal"
1211
"path/filepath"
1312
"strings"
1413

@@ -30,7 +29,7 @@ func (r *RootCmd) gitssh() *clibase.Cmd {
3029

3130
// Catch interrupt signals to ensure the temporary private
3231
// key file is cleaned up on most cases.
33-
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
32+
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
3433
defer stop()
3534

3635
// Early check so errors are reported immediately.

cli/server.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"net/http/pprof"
2323
"net/url"
2424
"os"
25-
"os/signal"
2625
"os/user"
2726
"path/filepath"
2827
"regexp"
@@ -333,7 +332,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
333332
//
334333
// To get out of a graceful shutdown, the user can send
335334
// SIGQUIT with ctrl+\ or SIGKILL with `kill -9`.
336-
notifyCtx, notifyStop := signal.NotifyContext(ctx, InterruptSignals...)
335+
notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, InterruptSignals...)
337336
defer notifyStop()
338337

339338
cacheDir := vals.CacheDir.String()
@@ -1098,7 +1097,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
10981097
logger = logger.Leveled(slog.LevelDebug)
10991098
}
11001099

1101-
ctx, cancel := signal.NotifyContext(ctx, InterruptSignals...)
1100+
ctx, cancel := inv.SignalNotifyContext(ctx, InterruptSignals...)
11021101
defer cancel()
11031102

11041103
url, closePg, err := startBuiltinPostgres(ctx, cfg, logger)

cli/server_createadminuser.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ package cli
44

55
import (
66
"fmt"
7-
"os/signal"
87
"sort"
98

109
"github.com/google/uuid"
@@ -48,7 +47,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *clibase.Cmd {
4847
logger = logger.Leveled(slog.LevelDebug)
4948
}
5049

51-
ctx, cancel := signal.NotifyContext(ctx, InterruptSignals...)
50+
ctx, cancel := inv.SignalNotifyContext(ctx, InterruptSignals...)
5251
defer cancel()
5352

5453
if newUserDBURL == "" {

cli/ssh.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,15 @@ func (r *RootCmd) ssh() *clibase.Cmd {
6262
r.InitClient(client),
6363
),
6464
Handler: func(inv *clibase.Invocation) (retErr error) {
65-
ctx, cancel := context.WithCancel(inv.Context())
65+
// before dialing the SSH server over TCP, capture Interrupt signals
66+
// so that if we are interrupted, we have a chance to tear down the
67+
// TCP session cleanly before exiting. If we don't, then the TCP
68+
// session can persist for up to 72 hours, since we set a long
69+
// timeout on the Agent side of the connection. In particular,
70+
// OpenSSH sends SIGHUP to terminate a proxy command.
71+
ctx, stop := inv.SignalNotifyContext(inv.Context(), InterruptSignals...)
72+
defer stop()
73+
ctx, cancel := context.WithCancel(ctx)
6674
defer cancel()
6775

6876
logger := slog.Make() // empty logger
@@ -227,8 +235,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
227235
go func() {
228236
defer wg.Done()
229237
// Ensure stdout copy closes in case stdin is closed
230-
// unexpectedly. Typically we wouldn't worry about
231-
// this since OpenSSH should kill the proxy command.
238+
// unexpectedly.
232239
defer rawSSH.Close()
233240

234241
_, err := io.Copy(rawSSH, inv.Stdin)

cli/ssh_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"net/http/httptest"
1515
"os"
1616
"os/exec"
17+
"path"
1718
"path/filepath"
1819
"runtime"
1920
"strings"
@@ -245,6 +246,91 @@ func TestSSH(t *testing.T) {
245246
<-cmdDone
246247
})
247248

249+
// Test that we handle OS signals properly while remote forwarding, and don't just leave the TCP
250+
// socket hanging.
251+
t.Run("RemoteForward_Unix_Signal", func(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("No unix sockets on windows")
	}
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitLong)
	client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
	_, _ = tGoContext(t, func(ctx context.Context) {
		// Run this async so the SSH command has to wait for
		// the build and agent to connect!
		_ = agenttest.New(t, client.URL, agentToken)
		<-ctx.Done()
	})

	tmpdir := tempDirUnixSocket(t)
	localSock := filepath.Join(tmpdir, "local.sock")
	l, err := net.Listen("unix", localSock)
	require.NoError(t, err)
	defer l.Close()
	remoteSock := path.Join(tmpdir, "remote.sock")

	// Connect several times in a row: each run is interrupted through the
	// fake signal notifier and must tear down cleanly so the next run can
	// bind the same remote socket path again.
	for i := 0; i < 3; i++ {
		t.Logf("connect %d of 3", i+1)
		inv, root := clitest.New(t,
			"ssh",
			workspace.Name,
			"--remote-forward",
			remoteSock+":"+localSock,
		)
		fsn := clitest.NewFakeSignalNotifier(t)
		inv = inv.WithTestSignalNotifyContext(fsn.NotifyContext)
		inv.Stdout = io.Discard
		inv.Stderr = io.Discard

		clitest.SetupConfig(t, client, root)
		cmdDone := tGo(t, func() {
			err := inv.WithContext(ctx).Run()
			assert.Error(t, err)
		})

		// accept a single connection
		//
		// The channel is buffered so this goroutine can always finish, even
		// if the receiver below has already given up.
		msgs := make(chan string, 1)
		go func() {
			conn, err := l.Accept()
			if !assert.NoError(t, err) {
				// conn is nil when Accept fails; reading from it would panic.
				return
			}
			defer conn.Close()
			msg, err := io.ReadAll(conn)
			assert.NoError(t, err)
			msgs <- string(msg)
		}()

		// write a single message
		go func() {
			var (
				conn net.Conn
				err  error
			)
			// Retry until the remote socket accepts a connection or the
			// test context expires.
			for {
				select {
				case <-ctx.Done():
					t.Error("timeout")
					return
				default:
					conn, err = net.Dial("unix", remoteSock)
				}
				if err == nil {
					break
				}
			}
			_, err = conn.Write([]byte("test"))
			assert.NoError(t, err)
			err = conn.Close()
			assert.NoError(t, err)
		}()

		msg := testutil.RequireRecvCtx(ctx, t, msgs)
		require.Equal(t, "test", msg)
		fsn.Notify()
		<-cmdDone
		fsn.AssertStopped()
	}
	_, err = os.Stat(remoteSock)
	require.ErrorIs(t, err, os.ErrNotExist, "didn't clean up remote socket")
})
333+
248334
t.Run("StdioExitOnStop", func(t *testing.T) {
249335
t.Parallel()
250336
if runtime.GOOS == "windows" {

0 commit comments

Comments
 (0)