Skip to content

Commit 2a4be8f

Browse files
committed
feat(agent/agentssh): handle session signals
1 parent d58239b commit 2a4be8f

10 files changed

+308
-9
lines changed

agent/agentssh/agentssh.go

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -311,10 +311,10 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv
311311
if isPty {
312312
return s.startPTYSession(logger, session, magicTypeLabel, cmd, sshPty, windowSize)
313313
}
314-
return s.startNonPTYSession(session, magicTypeLabel, cmd.AsExec())
314+
return s.startNonPTYSession(logger, session, magicTypeLabel, cmd.AsExec())
315315
}
316316

317-
func (s *Server) startNonPTYSession(session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error {
317+
func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error {
318318
s.metrics.sessionsTotal.WithLabelValues(magicTypeLabel, "no").Add(1)
319319

320320
cmd.Stdout = session
@@ -338,6 +338,17 @@ func (s *Server) startNonPTYSession(session ssh.Session, magicTypeLabel string,
338338
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "no", "start_command").Add(1)
339339
return xerrors.Errorf("start: %w", err)
340340
}
341+
sigs := make(chan ssh.Signal, 1)
342+
session.Signals(sigs)
343+
defer func() {
344+
session.Signals(nil)
345+
close(sigs)
346+
}()
347+
go func() {
348+
for sig := range sigs {
349+
s.handleSignal(logger, sig, cmd.Process, magicTypeLabel)
350+
}
351+
}()
341352
return cmd.Wait()
342353
}
343354

@@ -348,6 +359,7 @@ type ptySession interface {
348359
Context() ssh.Context
349360
DisablePTYEmulation()
350361
RawCommand() string
362+
Signals(chan<- ssh.Signal)
351363
}
352364

353365
func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTypeLabel string, cmd *pty.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
@@ -403,13 +415,36 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
403415
}
404416
}
405417
}()
418+
sigs := make(chan ssh.Signal, 1)
419+
session.Signals(sigs)
420+
defer func() {
421+
session.Signals(nil)
422+
close(sigs)
423+
}()
406424
go func() {
407-
for win := range windowSize {
408-
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
409-
// If the pty is closed, then command has exited, no need to log.
410-
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
411-
logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
412-
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "resize").Add(1)
425+
for {
426+
if sigs == nil && windowSize == nil {
427+
return
428+
}
429+
430+
select {
431+
case sig, ok := <-sigs:
432+
if !ok {
433+
sigs = nil
434+
continue
435+
}
436+
s.handleSignal(logger, sig, process, magicTypeLabel)
437+
case win, ok := <-windowSize:
438+
if !ok {
439+
windowSize = nil
440+
continue
441+
}
442+
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
443+
// If the pty is closed, then command has exited, no need to log.
444+
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
445+
logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
446+
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "resize").Add(1)
447+
}
413448
}
414449
}
415450
}()
@@ -452,6 +487,18 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
452487
return nil
453488
}
454489

490+
// handleSignal converts the SSH signal ssig to an os.Signal with osSignalFrom
// and sends it to signaler. Receipt of the signal is logged at info level. If
// delivery fails, the error is logged as a warning and the sessionErrors
// metric is incremented with the "signal" label; nothing is returned to the
// caller.
//
// NOTE(review): the second metric label value is hard-coded to "yes", but this
// method is also called from startNonPTYSession, whose other metrics use "no"
// in that position. Confirm whether that is intended.
func (s *Server) handleSignal(logger slog.Logger, ssig ssh.Signal, signaler interface{ Signal(os.Signal) error }, magicTypeLabel string) {
491+
// No context is passed in; a background context is used for the log calls.
ctx := context.Background()
492+
sig := osSignalFrom(ssig)
493+
// Attach both the SSH signal and the resolved OS signal to every log entry.
logger = logger.With(slog.F("ssh_signal", ssig), slog.F("signal", sig.String()))
494+
logger.Info(ctx, "received signal")
495+
err := signaler.Signal(sig)
496+
if err != nil {
497+
logger.Warn(ctx, "signal failed", slog.Error(err))
498+
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "signal").Add(1)
499+
}
500+
}
501+
455502
func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
456503
s.metrics.sftpConnectionsTotal.Add(1)
457504

agent/agentssh/agentssh_internal_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ type testSSHContext struct {
114114
context.Context
115115
}
116116

117+
// Compile-time assertions that the test doubles satisfy the interfaces they
// stand in for.
var (
118+
_ gliderssh.Context = testSSHContext{}
119+
_ ptySession = &testSession{}
120+
)
121+
117122
func newTestSession(ctx context.Context) (toClient *io.PipeReader, fromClient *io.PipeWriter, s ptySession) {
118123
toClient, fromPty := io.Pipe()
119124
toPty, fromClient := io.Pipe()
@@ -144,6 +149,10 @@ func (s *testSession) Write(p []byte) (n int, err error) {
144149
return s.fromPty.Write(p)
145150
}
146151

152+
// Signals is a no-op; it exists so that testSession satisfies ptySession.
// The supplied channel is ignored, so no signals are ever delivered on it.
func (*testSession) Signals(_ chan<- gliderssh.Signal) {
153+
// Not implemented, but will be called.
154+
}
155+
147156
// Lock is not supported by the test context and panics if it is called.
func (testSSHContext) Lock() {
148157
panic("not implemented")
149158
}

agent/agentssh/agentssh_test.go

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
package agentssh_test
44

55
import (
6+
"bufio"
67
"bytes"
78
"context"
9+
"fmt"
810
"net"
911
"runtime"
1012
"strings"
@@ -24,6 +26,7 @@ import (
2426
"github.com/coder/coder/v2/agent/agentssh"
2527
"github.com/coder/coder/v2/codersdk/agentsdk"
2628
"github.com/coder/coder/v2/pty/ptytest"
29+
"github.com/coder/coder/v2/testutil"
2730
)
2831

2932
func TestMain(m *testing.M) {
@@ -57,8 +60,8 @@ func TestNewServer_ServeClient(t *testing.T) {
5760

5861
var b bytes.Buffer
5962
sess, err := c.NewSession()
60-
sess.Stdout = &b
6163
require.NoError(t, err)
64+
sess.Stdout = &b
6265
err = sess.Start("echo hello")
6366
require.NoError(t, err)
6467

@@ -139,6 +142,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
139142
defer wg.Done()
140143
c := sshClient(t, ln.Addr().String())
141144
sess, err := c.NewSession()
145+
require.NoError(t, err)
142146
sess.Stdin = pty.Input()
143147
sess.Stdout = pty.Output()
144148
sess.Stderr = pty.Output()
@@ -159,6 +163,142 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
159163
wg.Wait()
160164
}
161165

166+
// TestNewServer_Signal checks that a SIGINT sent over an SSH session
// interrupts the running command, once for a session without a PTY ("Stdout")
// and once for a session with a PTY requested ("PTY").
//
// Each subtest serves a fresh server on a loopback listener, starts a command
// that prints "hello", sleeps repeatedly and finally prints "bye", sends
// SIGINT after "hello" has been read, and then asserts that "bye" never
// appears on stdout and that Wait returns an error.
func TestNewServer_Signal(t *testing.T) {
167+
t.Parallel()
168+
169+
t.Run("Stdout", func(t *testing.T) {
170+
t.Parallel()
171+
172+
ctx := context.Background()
173+
logger := slogtest.Make(t, nil)
174+
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
175+
require.NoError(t, err)
176+
defer s.Close()
177+
178+
// The assumption is that these are set before serving SSH connections.
179+
s.AgentToken = func() string { return "" }
180+
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
181+
182+
ln, err := net.Listen("tcp", "127.0.0.1:0")
183+
require.NoError(t, err)
184+
185+
// done is closed once Serve has returned.
done := make(chan struct{})
186+
go func() {
187+
defer close(done)
188+
err := s.Serve(ln)
189+
assert.Error(t, err) // Server is closed.
190+
}()
191+
// Close the server and wait for the Serve goroutine before the subtest ends.
defer func() {
192+
err := s.Close()
193+
require.NoError(t, err)
194+
<-done
195+
}()
196+
197+
c := sshClient(t, ln.Addr().String())
198+
199+
sess, err := c.NewSession()
200+
require.NoError(t, err)
201+
r, err := sess.StdoutPipe()
202+
require.NoError(t, err)
203+
204+
// Use many short sleeps instead of one long one: the interrupt signal
205+
// doesn't propagate to the process group, so this lets us exit early.
206+
sleeps := strings.Repeat("sleep 1 && ", int(testutil.WaitMedium.Seconds()))
207+
err = sess.Start(fmt.Sprintf("echo hello && %s echo bye", sleeps))
208+
require.NoError(t, err)
209+
210+
// Read output until "hello" shows the command is running.
sc := bufio.NewScanner(r)
211+
for sc.Scan() {
212+
t.Log(sc.Text())
213+
if strings.Contains(sc.Text(), "hello") {
214+
break
215+
}
216+
}
217+
require.NoError(t, sc.Err())
218+
219+
err = sess.Signal(ssh.SIGINT)
220+
require.NoError(t, err)
221+
222+
// Assumption: the signal propagates and the command exits, closing stdout.
223+
for sc.Scan() {
224+
t.Log(sc.Text())
225+
require.NotContains(t, sc.Text(), "bye")
226+
}
227+
require.NoError(t, sc.Err())
228+
229+
// The interrupted command must not report a clean exit.
err = sess.Wait()
230+
require.Error(t, err)
231+
})
232+
t.Run("PTY", func(t *testing.T) {
233+
t.Parallel()
234+
235+
ctx := context.Background()
236+
logger := slogtest.Make(t, nil)
237+
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
238+
require.NoError(t, err)
239+
defer s.Close()
240+
241+
// The assumption is that these are set before serving SSH connections.
242+
s.AgentToken = func() string { return "" }
243+
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
244+
245+
ln, err := net.Listen("tcp", "127.0.0.1:0")
246+
require.NoError(t, err)
247+
248+
// done is closed once Serve has returned.
done := make(chan struct{})
249+
go func() {
250+
defer close(done)
251+
err := s.Serve(ln)
252+
assert.Error(t, err) // Server is closed.
253+
}()
254+
// Close the server and wait for the Serve goroutine before the subtest ends.
defer func() {
255+
err := s.Close()
256+
require.NoError(t, err)
257+
<-done
258+
}()
259+
260+
c := sshClient(t, ln.Addr().String())
261+
262+
sess, err := c.NewSession()
263+
require.NoError(t, err)
264+
r, err := sess.StdoutPipe()
265+
require.NoError(t, err)
266+
267+
// Note, we request pty but don't use ptytest here because we can't
268+
// easily test for no text before EOF.
269+
err = sess.RequestPty("xterm", 80, 80, nil)
270+
require.NoError(t, err)
271+
272+
// Use many short sleeps instead of one long one: the interrupt signal
273+
// doesn't propagate to the process group, so this lets us exit early.
274+
sleeps := strings.Repeat("sleep 1 && ", int(testutil.WaitMedium.Seconds()))
275+
err = sess.Start(fmt.Sprintf("echo hello && %s echo bye", sleeps))
276+
require.NoError(t, err)
277+
278+
// Read output until "hello" shows the command is running.
sc := bufio.NewScanner(r)
279+
for sc.Scan() {
280+
t.Log(sc.Text())
281+
if strings.Contains(sc.Text(), "hello") {
282+
break
283+
}
284+
}
285+
require.NoError(t, sc.Err())
286+
287+
err = sess.Signal(ssh.SIGINT)
288+
require.NoError(t, err)
289+
290+
// Assumption: the signal propagates and the command exits, closing stdout.
291+
for sc.Scan() {
292+
t.Log(sc.Text())
293+
require.NotContains(t, sc.Text(), "bye")
294+
}
295+
require.NoError(t, sc.Err())
296+
297+
// The interrupted command must not report a clean exit.
err = sess.Wait()
298+
require.Error(t, err)
299+
})
300+
}
301+
162302
func sshClient(t *testing.T, addr string) *ssh.Client {
163303
conn, err := net.Dial("tcp", addr)
164304
require.NoError(t, err)

agent/agentssh/signal_other.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//go:build !windows
2+
3+
package agentssh
4+
5+
import (
6+
"os"
7+
8+
"github.com/gliderlabs/ssh"
9+
"golang.org/x/sys/unix"
10+
)
11+
12+
// osSignalFrom maps an SSH signal to the unix signal of the same name.
// Any SSH signal that is not listed below is mapped to unix.SIGKILL.
func osSignalFrom(sig ssh.Signal) os.Signal {
13+
switch sig {
14+
case ssh.SIGABRT:
15+
return unix.SIGABRT
16+
case ssh.SIGALRM:
17+
return unix.SIGALRM
18+
case ssh.SIGFPE:
19+
return unix.SIGFPE
20+
case ssh.SIGHUP:
21+
return unix.SIGHUP
22+
case ssh.SIGILL:
23+
return unix.SIGILL
24+
case ssh.SIGINT:
25+
return unix.SIGINT
26+
case ssh.SIGKILL:
27+
return unix.SIGKILL
28+
case ssh.SIGPIPE:
29+
return unix.SIGPIPE
30+
case ssh.SIGQUIT:
31+
return unix.SIGQUIT
32+
case ssh.SIGSEGV:
33+
return unix.SIGSEGV
34+
case ssh.SIGTERM:
35+
return unix.SIGTERM
36+
case ssh.SIGUSR1:
37+
return unix.SIGUSR1
38+
case ssh.SIGUSR2:
39+
return unix.SIGUSR2
40+
41+
// Unhandled signal: fall back to SIGKILL.
42+
default:
43+
return unix.SIGKILL
44+
}
45+
}

agent/agentssh/signal_windows.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package agentssh
2+
3+
import (
4+
"os"
5+
6+
"github.com/gliderlabs/ssh"
7+
)
8+
9+
// osSignalFrom maps an SSH signal to an os.Signal on Windows. Only ssh.SIGINT
// is translated (to os.Interrupt); every other SSH signal becomes os.Kill.
func osSignalFrom(sig ssh.Signal) os.Signal {
10+
switch sig {
11+
case ssh.SIGINT:
12+
return os.Interrupt
13+
default:
14+
return os.Kill
15+
}
16+
}

pty/pty.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package pty
33
import (
44
"io"
55
"log"
6+
"os"
67

78
"github.com/gliderlabs/ssh"
89
"golang.org/x/xerrors"
@@ -69,6 +70,11 @@ type Process interface {
6970

7071
// Kill the command process. Returned error is as for os.Process.Kill()
7172
Kill() error
73+
74+
// Signal sends a signal to the command process. On non-Windows systems, the
75+
// returned error is as for os.Process.Signal(); on Windows it's
76+
// as for os.Process.Kill().
77+
Signal(sig os.Signal) error
7278
}
7379

7480
// WithFlags represents a PTY whose flags can be inspected, in particular

pty/pty_other.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,10 @@ func (p *otherProcess) Kill() error {
170170
return p.cmd.Process.Kill()
171171
}
172172

173+
// Signal forwards sig to the wrapped command's process (p.cmd.Process) and
// returns whatever error that call reports.
func (p *otherProcess) Signal(sig os.Signal) error {
174+
return p.cmd.Process.Signal(sig)
175+
}
176+
173177
func (p *otherProcess) waitInternal() {
174178
// The GC can garbage collect the TTY FD before the command
175179
// has finished running. See:

0 commit comments

Comments
 (0)