Skip to content

fix: handle SIGHUP from OpenSSH #10638

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions agent/agentssh/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type forwardedUnixHandler struct {
}

func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, req *gossh.Request) (bool, []byte) {
h.log.Debug(ctx, "handling SSH unix forward")
h.Lock()
if h.forwards == nil {
h.forwards = make(map[string]net.Listener)
Expand All @@ -47,22 +48,25 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
h.log.Warn(ctx, "SSH unix forward request from client with no gossh connection")
return false, nil
}
log := h.log.With(slog.F("remote_addr", conn.RemoteAddr()))

switch req.Type {
case "streamlocal-forward@openssh.com":
var reqPayload streamLocalForwardPayload
err := gossh.Unmarshal(req.Payload, &reqPayload)
if err != nil {
h.log.Warn(ctx, "parse streamlocal-forward@openssh.com request payload from client", slog.Error(err))
h.log.Warn(ctx, "parse streamlocal-forward@openssh.com request (SSH unix forward) payload from client", slog.Error(err))
return false, nil
}

addr := reqPayload.SocketPath
log = log.With(slog.F("socket_path", addr))
log.Debug(ctx, "request begin SSH unix forward")
h.Lock()
_, ok := h.forwards[addr]
h.Unlock()
if ok {
h.log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)",
log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)",
slog.F("socket_path", addr),
)
return false, nil
Expand All @@ -72,22 +76,22 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
parentDir := filepath.Dir(addr)
err = os.MkdirAll(parentDir, 0o700)
if err != nil {
h.log.Warn(ctx, "create parent dir for SSH unix forward request",
log.Warn(ctx, "create parent dir for SSH unix forward request",
slog.F("parent_dir", parentDir),
slog.F("socket_path", addr),
slog.Error(err),
)
return false, nil
}

ln, err := net.Listen("unix", addr)
if err != nil {
h.log.Warn(ctx, "listen on Unix socket for SSH unix forward request",
log.Warn(ctx, "listen on Unix socket for SSH unix forward request",
slog.F("socket_path", addr),
slog.Error(err),
)
return false, nil
}
log.Debug(ctx, "SSH unix forward listening on socket")

// The listener needs to successfully start before it can be added to
// the map, so we don't have to worry about checking for an existing
Expand All @@ -97,6 +101,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
h.Lock()
h.forwards[addr] = ln
h.Unlock()
log.Debug(ctx, "SSH unix forward added to cache")

ctx, cancel := context.WithCancel(ctx)
go func() {
Expand All @@ -110,22 +115,23 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
c, err := ln.Accept()
if err != nil {
if !xerrors.Is(err, net.ErrClosed) {
h.log.Warn(ctx, "accept on local Unix socket for SSH unix forward request",
slog.F("socket_path", addr),
log.Warn(ctx, "accept on local Unix socket for SSH unix forward request",
slog.Error(err),
)
}
// closed below
log.Debug(ctx, "SSH unix forward listener closed")
break
}
log.Debug(ctx, "accepted SSH unix forward connection")
payload := gossh.Marshal(&forwardedStreamLocalPayload{
SocketPath: addr,
})

go func() {
ch, reqs, err := conn.OpenChannel("forwarded-streamlocal@openssh.com", payload)
if err != nil {
h.log.Warn(ctx, "open SSH channel to forward Unix connection to client",
h.log.Warn(ctx, "open SSH unix forward channel to client",
slog.F("socket_path", addr),
slog.Error(err),
)
Expand All @@ -143,6 +149,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
delete(h.forwards, addr)
}
h.Unlock()
log.Debug(ctx, "SSH unix forward listener removed from cache", slog.F("path", addr))
_ = ln.Close()
}()

Expand All @@ -152,9 +159,10 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
var reqPayload streamLocalForwardPayload
err := gossh.Unmarshal(req.Payload, &reqPayload)
if err != nil {
h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com request payload from client", slog.Error(err))
h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com (SSH unix forward) request payload from client", slog.Error(err))
return false, nil
}
log.Debug(ctx, "request to cancel SSH unix forward", slog.F("path", reqPayload.SocketPath))
h.Lock()
ln, ok := h.forwards[reqPayload.SocketPath]
h.Unlock()
Expand Down
3 changes: 1 addition & 2 deletions cli/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"net/http/pprof"
"net/url"
"os"
"os/signal"
"path/filepath"
"runtime"
"strconv"
Expand Down Expand Up @@ -144,7 +143,7 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd {
// Note that we don't want to handle these signals in the
// process that runs as PID 1, that's why we do this after
// the reaper forked.
ctx, stopNotify := signal.NotifyContext(ctx, InterruptSignals...)
ctx, stopNotify := inv.SignalNotifyContext(ctx, InterruptSignals...)
defer stopNotify()

// DumpHandler does signal handling, so we call it after the
Expand Down
25 changes: 25 additions & 0 deletions cli/clibase/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import (
"fmt"
"io"
"os"
"os/signal"
"strings"
"testing"
"unicode"

"github.com/spf13/pflag"
Expand Down Expand Up @@ -183,6 +185,9 @@ type Invocation struct {
Stdout io.Writer
Stderr io.Writer
Stdin io.Reader

// testing
signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc)
}

// WithOS returns the invocation as a main package, filling in the invocation's unset
Expand All @@ -197,6 +202,26 @@ func (inv *Invocation) WithOS() *Invocation {
})
}

// WithTestSignalNotifyContext allows overriding the default implementation of SignalNotifyContext.
// This should only be used in testing.
func (inv *Invocation) WithTestSignalNotifyContext(
_ testing.TB, // ensure we only call this from tests
f func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc),
) *Invocation {
return inv.with(func(i *Invocation) {
i.signalNotifyContext = f
})
}

// SignalNotifyContext is equivalent to signal.NotifyContext, but supports being overridden in
// tests.
func (inv *Invocation) SignalNotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
if inv.signalNotifyContext == nil {
return signal.NotifyContext(parent, signals...)
}
return inv.signalNotifyContext(parent, signals...)
}

func (inv *Invocation) Context() context.Context {
if inv.ctx == nil {
return context.Background()
Expand Down
59 changes: 59 additions & 0 deletions cli/clitest/signal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package clitest

import (
"context"
"os"
"sync"
"testing"

"github.com/stretchr/testify/assert"
)

type FakeSignalNotifier struct {
sync.Mutex
t *testing.T
ctx context.Context
cancel context.CancelFunc
signals []os.Signal
stopped bool
}

func NewFakeSignalNotifier(t *testing.T) *FakeSignalNotifier {
fsn := &FakeSignalNotifier{t: t}
return fsn
}

func (f *FakeSignalNotifier) Stop() {
f.Lock()
defer f.Unlock()
f.stopped = true
if f.cancel == nil {
f.t.Error("stopped before started")
return
}
f.cancel()
}

func (f *FakeSignalNotifier) NotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
f.Lock()
defer f.Unlock()
f.signals = signals
f.ctx, f.cancel = context.WithCancel(parent)
return f.ctx, f.Stop
}

func (f *FakeSignalNotifier) Notify() {
f.Lock()
defer f.Unlock()
if f.cancel == nil {
f.t.Error("notified before started")
return
}
f.cancel()
}

func (f *FakeSignalNotifier) AssertStopped() {
f.Lock()
defer f.Unlock()
assert.True(f.t, f.stopped)
}
3 changes: 1 addition & 2 deletions cli/externalauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package cli

import (
"encoding/json"
"os/signal"

"golang.org/x/xerrors"

Expand Down Expand Up @@ -63,7 +62,7 @@ fi
Handler: func(inv *clibase.Invocation) error {
ctx := inv.Context()

ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
defer stop()

client, err := r.createAgentClient()
Expand Down
3 changes: 1 addition & 2 deletions cli/gitaskpass.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"net/http"
"os/signal"
"time"

"golang.org/x/xerrors"
Expand All @@ -26,7 +25,7 @@ func (r *RootCmd) gitAskpass() *clibase.Cmd {
Handler: func(inv *clibase.Invocation) error {
ctx := inv.Context()

ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
defer stop()

user, host, err := gitauth.ParseAskpass(inv.Args[0])
Expand Down
3 changes: 1 addition & 2 deletions cli/gitssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"io"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strings"

Expand All @@ -30,7 +29,7 @@ func (r *RootCmd) gitssh() *clibase.Cmd {

// Catch interrupt signals to ensure the temporary private
// key file is cleaned up on most cases.
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
defer stop()

// Early check so errors are reported immediately.
Expand Down
5 changes: 2 additions & 3 deletions cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"net/http/pprof"
"net/url"
"os"
"os/signal"
"os/user"
"path/filepath"
"regexp"
Expand Down Expand Up @@ -333,7 +332,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
//
// To get out of a graceful shutdown, the user can send
// SIGQUIT with ctrl+\ or SIGKILL with `kill -9`.
notifyCtx, notifyStop := signal.NotifyContext(ctx, InterruptSignals...)
notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, InterruptSignals...)
defer notifyStop()

cacheDir := vals.CacheDir.String()
Expand Down Expand Up @@ -1098,7 +1097,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
logger = logger.Leveled(slog.LevelDebug)
}

ctx, cancel := signal.NotifyContext(ctx, InterruptSignals...)
ctx, cancel := inv.SignalNotifyContext(ctx, InterruptSignals...)
defer cancel()

url, closePg, err := startBuiltinPostgres(ctx, cfg, logger)
Expand Down
3 changes: 1 addition & 2 deletions cli/server_createadminuser.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package cli

import (
"fmt"
"os/signal"
"sort"

"github.com/google/uuid"
Expand Down Expand Up @@ -48,7 +47,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *clibase.Cmd {
logger = logger.Leveled(slog.LevelDebug)
}

ctx, cancel := signal.NotifyContext(ctx, InterruptSignals...)
ctx, cancel := inv.SignalNotifyContext(ctx, InterruptSignals...)
defer cancel()

if newUserDBURL == "" {
Expand Down
13 changes: 10 additions & 3 deletions cli/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,15 @@ func (r *RootCmd) ssh() *clibase.Cmd {
r.InitClient(client),
),
Handler: func(inv *clibase.Invocation) (retErr error) {
ctx, cancel := context.WithCancel(inv.Context())
// Before dialing the SSH server over TCP, capture Interrupt signals
// so that if we are interrupted, we have a chance to tear down the
// TCP session cleanly before exiting. If we don't, then the TCP
// session can persist for up to 72 hours, since we set a long
// timeout on the Agent side of the connection. In particular,
// OpenSSH sends SIGHUP to terminate a proxy command.
ctx, stop := inv.SignalNotifyContext(inv.Context(), InterruptSignals...)
defer stop()
ctx, cancel := context.WithCancel(ctx)
defer cancel()

logger := slog.Make() // empty logger
Expand Down Expand Up @@ -227,8 +235,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
go func() {
defer wg.Done()
// Ensure stdout copy closes incase stdin is closed
// unexpectedly. Typically we wouldn't worry about
// this since OpenSSH should kill the proxy command.
// unexpectedly.
defer rawSSH.Close()

_, err := io.Copy(rawSSH, inv.Stdin)
Expand Down
Loading