diff --git a/agent/agentssh/forward.go b/agent/agentssh/forward.go index 1e3635fd8ff91..ac5e5ac7100f8 100644 --- a/agent/agentssh/forward.go +++ b/agent/agentssh/forward.go @@ -37,6 +37,7 @@ type forwardedUnixHandler struct { } func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, req *gossh.Request) (bool, []byte) { + h.log.Debug(ctx, "handling SSH unix forward") h.Lock() if h.forwards == nil { h.forwards = make(map[string]net.Listener) @@ -47,22 +48,25 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, h.log.Warn(ctx, "SSH unix forward request from client with no gossh connection") return false, nil } + log := h.log.With(slog.F("remote_addr", conn.RemoteAddr())) switch req.Type { case "streamlocal-forward@openssh.com": var reqPayload streamLocalForwardPayload err := gossh.Unmarshal(req.Payload, &reqPayload) if err != nil { - h.log.Warn(ctx, "parse streamlocal-forward@openssh.com request payload from client", slog.Error(err)) + h.log.Warn(ctx, "parse streamlocal-forward@openssh.com request (SSH unix forward) payload from client", slog.Error(err)) return false, nil } addr := reqPayload.SocketPath + log = log.With(slog.F("socket_path", addr)) + log.Debug(ctx, "request begin SSH unix forward") h.Lock() _, ok := h.forwards[addr] h.Unlock() if ok { - h.log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)", + log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)", slog.F("socket_path", addr), ) return false, nil @@ -72,9 +76,8 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, parentDir := filepath.Dir(addr) err = os.MkdirAll(parentDir, 0o700) if err != nil { - h.log.Warn(ctx, "create parent dir for SSH unix forward request", + log.Warn(ctx, "create parent dir for SSH unix forward request", slog.F("parent_dir", parentDir), - slog.F("socket_path", addr), slog.Error(err), ) return false, nil @@ -82,12 +85,13 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, ln, err := net.Listen("unix", addr) if err != nil { - h.log.Warn(ctx, "listen on Unix socket for SSH unix forward request", + log.Warn(ctx, "listen on Unix socket for SSH unix forward request", slog.F("socket_path", addr), slog.Error(err), ) return false, nil } + log.Debug(ctx, "SSH unix forward listening on socket") // The listener needs to successfully start before it can be added to // the map, so we don't have to worry about checking for an existing @@ -97,6 +101,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, h.Lock() h.forwards[addr] = ln h.Unlock() + log.Debug(ctx, "SSH unix forward added to cache") ctx, cancel := context.WithCancel(ctx) go func() { @@ -110,14 +115,15 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, c, err := ln.Accept() if err != nil { if !xerrors.Is(err, net.ErrClosed) { - h.log.Warn(ctx, "accept on local Unix socket for SSH unix forward request", - slog.F("socket_path", addr), + log.Warn(ctx, "accept on local Unix socket for SSH unix forward request", slog.Error(err), ) } // closed below + log.Debug(ctx, "SSH unix forward listener closed") break } + log.Debug(ctx, "accepted SSH unix forward connection") payload := gossh.Marshal(&forwardedStreamLocalPayload{ SocketPath: addr, }) @@ -125,7 +131,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, go func() { ch, reqs, err := conn.OpenChannel("forwarded-streamlocal@openssh.com", payload) if err != nil { - h.log.Warn(ctx, "open SSH channel to forward Unix connection to client", + h.log.Warn(ctx, "open SSH unix forward channel to client", slog.F("socket_path", addr), slog.Error(err), ) @@ -143,6 +149,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, delete(h.forwards, addr) } h.Unlock() + log.Debug(ctx, "SSH unix forward listener removed from cache", slog.F("path", addr)) _ = ln.Close() }() @@ -152,9 +159,10 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, var reqPayload streamLocalForwardPayload err := gossh.Unmarshal(req.Payload, &reqPayload) if err != nil { - h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com request payload from client", slog.Error(err)) + h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com (SSH unix forward) request payload from client", slog.Error(err)) return false, nil } + log.Debug(ctx, "request to cancel SSH unix forward", slog.F("path", reqPayload.SocketPath)) h.Lock() ln, ok := h.forwards[reqPayload.SocketPath] h.Unlock() diff --git a/cli/agent.go b/cli/agent.go index 8a836cd4c3c04..fa62761cc1716 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -8,7 +8,6 @@ import ( "net/http/pprof" "net/url" "os" - "os/signal" "path/filepath" "runtime" "strconv" @@ -144,7 +143,7 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd { // Note that we don't want to handle these signals in the // process that runs as PID 1, that's why we do this after // the reaper forked. - ctx, stopNotify := signal.NotifyContext(ctx, InterruptSignals...) + ctx, stopNotify := inv.SignalNotifyContext(ctx, InterruptSignals...) defer stopNotify() // DumpHandler does signal handling, so we call it after the diff --git a/cli/clibase/cmd.go b/cli/clibase/cmd.go index c3729d2d586cb..a2ca36b2c9142 100644 --- a/cli/clibase/cmd.go +++ b/cli/clibase/cmd.go @@ -7,7 +7,9 @@ import ( "fmt" "io" "os" + "os/signal" "strings" + "testing" "unicode" "github.com/spf13/pflag" @@ -183,6 +185,9 @@ type Invocation struct { Stdout io.Writer Stderr io.Writer Stdin io.Reader + + // testing + signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) } // WithOS returns the invocation as a main package, filling in the invocation's unset @@ -197,6 +202,26 @@ func (inv *Invocation) WithOS() *Invocation { }) } +// WithTestSignalNotifyContext allows overriding the default implementation of SignalNotifyContext. +// This should only be used in testing. +func (inv *Invocation) WithTestSignalNotifyContext( + _ testing.TB, // ensure we only call this from tests + f func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc), +) *Invocation { + return inv.with(func(i *Invocation) { + i.signalNotifyContext = f + }) +} + +// SignalNotifyContext is equivalent to signal.NotifyContext, but supports being overridden in +// tests. +func (inv *Invocation) SignalNotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) { + if inv.signalNotifyContext == nil { + return signal.NotifyContext(parent, signals...) + } + return inv.signalNotifyContext(parent, signals...) +} + func (inv *Invocation) Context() context.Context { if inv.ctx == nil { return context.Background() diff --git a/cli/clitest/signal.go b/cli/clitest/signal.go new file mode 100644 index 0000000000000..2de73a1a01ecd --- /dev/null +++ b/cli/clitest/signal.go @@ -0,0 +1,59 @@ +package clitest + +import ( + "context" + "os" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +type FakeSignalNotifier struct { + sync.Mutex + t *testing.T + ctx context.Context + cancel context.CancelFunc + signals []os.Signal + stopped bool +} + +func NewFakeSignalNotifier(t *testing.T) *FakeSignalNotifier { + fsn := &FakeSignalNotifier{t: t} + return fsn +} + +func (f *FakeSignalNotifier) Stop() { + f.Lock() + defer f.Unlock() + f.stopped = true + if f.cancel == nil { + f.t.Error("stopped before started") + return + } + f.cancel() +} + +func (f *FakeSignalNotifier) NotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) { + f.Lock() + defer f.Unlock() + f.signals = signals + f.ctx, f.cancel = context.WithCancel(parent) + return f.ctx, f.Stop +} + +func (f *FakeSignalNotifier) Notify() { + f.Lock() + defer f.Unlock() + if f.cancel == nil { + f.t.Error("notified before started") + return + } + f.cancel() +} + +func (f *FakeSignalNotifier) AssertStopped() { + f.Lock() + defer f.Unlock() + assert.True(f.t, f.stopped) +} diff --git a/cli/externalauth.go b/cli/externalauth.go index c81795d95d6fc..7230db894ac4e 100644 --- a/cli/externalauth.go +++ b/cli/externalauth.go @@ -2,7 +2,6 @@ package cli import ( "encoding/json" - "os/signal" "golang.org/x/xerrors" @@ -63,7 +62,7 @@ fi Handler: func(inv *clibase.Invocation) error { ctx := inv.Context() - ctx, stop := signal.NotifyContext(ctx, InterruptSignals...) + ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...) defer stop() client, err := r.createAgentClient() diff --git a/cli/gitaskpass.go b/cli/gitaskpass.go index 83ac98094e72e..ddfd05af9d1f9 100644 --- a/cli/gitaskpass.go +++ b/cli/gitaskpass.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "net/http" - "os/signal" "time" "golang.org/x/xerrors" @@ -26,7 +25,7 @@ func (r *RootCmd) gitAskpass() *clibase.Cmd { Handler: func(inv *clibase.Invocation) error { ctx := inv.Context() - ctx, stop := signal.NotifyContext(ctx, InterruptSignals...) + ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...) defer stop() user, host, err := gitauth.ParseAskpass(inv.Args[0]) diff --git a/cli/gitssh.go b/cli/gitssh.go index ea461394c3241..b627b3911b820 100644 --- a/cli/gitssh.go +++ b/cli/gitssh.go @@ -8,7 +8,6 @@ import ( "io" "os" "os/exec" - "os/signal" "path/filepath" "strings" @@ -30,7 +29,7 @@ func (r *RootCmd) gitssh() *clibase.Cmd { // Catch interrupt signals to ensure the temporary private // key file is cleaned up on most cases. - ctx, stop := signal.NotifyContext(ctx, InterruptSignals...) + ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...) defer stop() // Early check so errors are reported immediately. diff --git a/cli/server.go b/cli/server.go index 3baded2363bf6..b09f814068343 100644 --- a/cli/server.go +++ b/cli/server.go @@ -22,7 +22,6 @@ import ( "net/http/pprof" "net/url" "os" - "os/signal" "os/user" "path/filepath" "regexp" @@ -333,7 +332,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. // // To get out of a graceful shutdown, the user can send // SIGQUIT with ctrl+\ or SIGKILL with `kill -9`. - notifyCtx, notifyStop := signal.NotifyContext(ctx, InterruptSignals...) + notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, InterruptSignals...) defer notifyStop() cacheDir := vals.CacheDir.String() @@ -1098,7 +1097,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. logger = logger.Leveled(slog.LevelDebug) } - ctx, cancel := signal.NotifyContext(ctx, InterruptSignals...) + ctx, cancel := inv.SignalNotifyContext(ctx, InterruptSignals...) defer cancel() url, closePg, err := startBuiltinPostgres(ctx, cfg, logger) diff --git a/cli/server_createadminuser.go b/cli/server_createadminuser.go index fa82e4fbcd051..3200177b6dc54 100644 --- a/cli/server_createadminuser.go +++ b/cli/server_createadminuser.go @@ -4,7 +4,6 @@ package cli import ( "fmt" - "os/signal" "sort" "github.com/google/uuid" @@ -48,7 +47,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *clibase.Cmd { logger = logger.Leveled(slog.LevelDebug) } - ctx, cancel := signal.NotifyContext(ctx, InterruptSignals...) + ctx, cancel := inv.SignalNotifyContext(ctx, InterruptSignals...) defer cancel() if newUserDBURL == "" { diff --git a/cli/ssh.go b/cli/ssh.go index dbff0ea52017e..fd0474281959a 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -62,7 +62,15 @@ func (r *RootCmd) ssh() *clibase.Cmd { r.InitClient(client), ), Handler: func(inv *clibase.Invocation) (retErr error) { - ctx, cancel := context.WithCancel(inv.Context()) + // Before dialing the SSH server over TCP, capture Interrupt signals + // so that if we are interrupted, we have a chance to tear down the + // TCP session cleanly before exiting. If we don't, then the TCP + // session can persist for up to 72 hours, since we set a long + // timeout on the Agent side of the connection. In particular, + // OpenSSH sends SIGHUP to terminate a proxy command. + ctx, stop := inv.SignalNotifyContext(inv.Context(), InterruptSignals...) + defer stop() + ctx, cancel := context.WithCancel(ctx) defer cancel() logger := slog.Make() // empty logger @@ -227,8 +235,7 @@ func (r *RootCmd) ssh() *clibase.Cmd { go func() { defer wg.Done() // Ensure stdout copy closes incase stdin is closed - // unexpectedly. Typically we wouldn't worry about - // this since OpenSSH should kill the proxy command. + // unexpectedly. defer rawSSH.Close() _, err := io.Copy(rawSSH, inv.Stdin) diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 9a0f6b6f6e4e2..77eae7d3b05fa 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -14,12 +14,15 @@ import ( "net/http/httptest" "os" "os/exec" + "path" "path/filepath" "runtime" "strings" "testing" "time" + "golang.org/x/xerrors" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -245,6 +248,119 @@ func TestSSH(t *testing.T) { <-cmdDone }) + // Test that we handle OS signals properly while remote forwarding, and don't just leave the TCP + // socket hanging. + t.Run("RemoteForward_Unix_Signal", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("No unix sockets on windows") + } + t.Parallel() + ctx := testutil.Context(t, testutil.WaitSuperLong) + client, workspace, agentToken := setupWorkspaceForAgent(t, nil) + _, _ = tGoContext(t, func(ctx context.Context) { + // Run this async so the SSH command has to wait for + // the build and agent to connect! + _ = agenttest.New(t, client.URL, agentToken) + <-ctx.Done() + }) + + tmpdir := tempDirUnixSocket(t) + localSock := filepath.Join(tmpdir, "local.sock") + l, err := net.Listen("unix", localSock) + require.NoError(t, err) + defer l.Close() + remoteSock := path.Join(tmpdir, "remote.sock") + for i := 0; i < 2; i++ { + t.Logf("connect %d of 2", i+1) + inv, root := clitest.New(t, + "ssh", + workspace.Name, + "--remote-forward", + remoteSock+":"+localSock, + ) + fsn := clitest.NewFakeSignalNotifier(t) + inv = inv.WithTestSignalNotifyContext(t, fsn.NotifyContext) + inv.Stdout = io.Discard + inv.Stderr = io.Discard + + clitest.SetupConfig(t, client, root) + cmdDone := tGo(t, func() { + err := inv.WithContext(ctx).Run() + assert.Error(t, err) + }) + + // accept a single connection + msgs := make(chan string, 1) + go func() { + conn, err := l.Accept() + if !assert.NoError(t, err) { + return + } + msg, err := io.ReadAll(conn) + if !assert.NoError(t, err) { + return + } + msgs <- string(msg) + }() + + // Unfortunately, there is a race in crypto/ssh where it sends the request to forward + // unix sockets before it is prepared to receive the response, meaning that even after + // the socket exists on the file system, the client might not be ready to accept the + // channel. + // + // https://cs.opensource.google/go/x/crypto/+/master:ssh/streamlocal.go;drc=2fc4c88bf43f0ea5ea305eae2b7af24b2cc93287;l=33 + // + // To work around this, we attempt to send messages in a loop until one succeeds + success := make(chan struct{}) + go func() { + var ( + conn net.Conn + err error + ) + for { + time.Sleep(testutil.IntervalMedium) + select { + case <-ctx.Done(): + t.Error("timeout") + return + case <-success: + return + default: + // Ok + } + conn, err = net.Dial("unix", remoteSock) + if err != nil { + t.Logf("dial error: %s", err) + continue + } + _, err = conn.Write([]byte("test")) + if err != nil { + t.Logf("write error: %s", err) + } + err = conn.Close() + if err != nil { + t.Logf("close error: %s", err) + } + } + }() + + msg := testutil.RequireRecvCtx(ctx, t, msgs) + require.Equal(t, "test", msg) + close(success) + fsn.Notify() + <-cmdDone + fsn.AssertStopped() + + // wait for the remote socket to get cleaned up before retrying, + // because cleaning up the socket happens asynchronously, and we + // might connect to an old listener on the agent side. + require.Eventually(t, func() bool { + _, err = os.Stat(remoteSock) + return xerrors.Is(err, os.ErrNotExist) + }, testutil.WaitShort, testutil.IntervalFast) + } + }) + t.Run("StdioExitOnStop", func(t *testing.T) { t.Parallel() if runtime.GOOS == "windows" { diff --git a/enterprise/cli/provisionerdaemons.go b/enterprise/cli/provisionerdaemons.go index 623368487ff68..6583f1b1a8329 100644 --- a/enterprise/cli/provisionerdaemons.go +++ b/enterprise/cli/provisionerdaemons.go @@ -6,7 +6,6 @@ import ( "context" "fmt" "os" - "os/signal" "time" "github.com/google/uuid" @@ -61,7 +60,7 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd { ctx, cancel := context.WithCancel(inv.Context()) defer cancel() - notifyCtx, notifyStop := signal.NotifyContext(ctx, agpl.InterruptSignals...) + notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, agpl.InterruptSignals...) defer notifyStop() tags, err := agpl.ParseProvisionerTags(rawTags) diff --git a/enterprise/cli/proxyserver.go b/enterprise/cli/proxyserver.go index 0d7e92531342f..c245426b51052 100644 --- a/enterprise/cli/proxyserver.go +++ b/enterprise/cli/proxyserver.go @@ -10,7 +10,6 @@ import ( "net" "net/http" "net/http/pprof" - "os/signal" "regexp" rpprof "runtime/pprof" "time" @@ -142,7 +141,7 @@ func (*RootCmd) proxyServer() *clibase.Cmd { // // To get out of a graceful shutdown, the user can send // SIGQUIT with ctrl+\ or SIGKILL with `kill -9`. - notifyCtx, notifyStop := signal.NotifyContext(ctx, cli.InterruptSignals...) + notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, cli.InterruptSignals...) defer notifyStop() // Clean up idle connections at the end, e.g. diff --git a/tailnet/conn.go b/tailnet/conn.go index f9c82bc753c92..c785e7fabbe96 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -788,6 +788,7 @@ func (c *Conn) Close() error { } _ = c.netStack.Close() + c.logger.Debug(context.Background(), "closed netstack") c.dialCancel() _ = c.wireguardMonitor.Close() _ = c.dialer.Close() diff --git a/testutil/ctx.go b/testutil/ctx.go index e23c48da85722..2cc44c5bad8d7 100644 --- a/testutil/ctx.go +++ b/testutil/ctx.go @@ -11,3 +11,14 @@ func Context(t *testing.T, dur time.Duration) context.Context { t.Cleanup(cancel) return ctx } + +func RequireRecvCtx[A any](ctx context.Context, t testing.TB, c <-chan A) (a A) { + t.Helper() + select { + case <-ctx.Done(): + t.Fatal("timeout") + return a + case a = <-c: + return a + } +}