From 517184bff25c181e8f5a3f8f8a9bf1d0d0abb1f6 Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Sat, 12 Mar 2022 17:12:09 -0800 Subject: [PATCH 1/7] add support for terminal opcodes Updates tailscale/tailscale#4146 Signed-off-by: Maisem Ali --- session.go | 2 +- ssh.go | 34 +++++++++++++-- util.go | 125 ++++++++++++++++++++++++++++++++++++++++++----------- 3 files changed, 131 insertions(+), 30 deletions(-) diff --git a/session.go b/session.go index 3a3ad70..fb47581 100644 --- a/session.go +++ b/session.go @@ -346,7 +346,7 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) { req.Reply(false, nil) continue } - win, ok := parseWinchRequest(req.Payload) + win, _, ok := parseWindow(req.Payload) if ok { sess.pty.Window = win sess.winch <- win diff --git a/ssh.go b/ssh.go index fbeb150..7dc76b3 100644 --- a/ssh.go +++ b/ssh.go @@ -69,16 +69,44 @@ type ServerConfigCallback func(ctx Context) *gossh.ServerConfig type ConnectionFailedCallback func(conn net.Conn, err error) // Window represents the size of a PTY window. +// +// From https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 +// +// Zero dimension parameters MUST be ignored. The character/row dimensions +// override the pixel dimensions (when nonzero). Pixel dimensions refer +// to the drawable area of the window. type Window struct { - Width int + // Width is the number of columns. + // It overrides WidthPixels. + Width int + // Height is the number of rows. + // It overrides HeightPixels. Height int + + // WidthPixels is the drawable width of the window, in pixels. + WidthPixels int + // HeightPixels is the drawable height of the window, in pixels. + HeightPixels int } // Pty represents a PTY request and configuration. type Pty struct { - Term string + // Term is the TERM environment variable value. + Term string + + // Window is the Window sent as part of the pty-req. Window Window - // HELP WANTED: terminal modes! + + // Modes represent a mapping of Terminal Mode opcode to value as it was + // requested by the client as part of the pty-req. These are outlined as + // part of https://datatracker.ietf.org/doc/html/rfc4254#section-8. + // + // The opcodes are defined as constants in golang.org/x/crypto/ssh (VINTR,VQUIT,etc.). + // Boolean opcodes have values 0 or 1. + // + // Note: golang.org/x/crypto/ssh currently (2022-03-12) doesn't have a + // definition for opcode 42 "iutf8" which was introduced in https://datatracker.ietf.org/doc/html/rfc8160. + Modes gossh.TerminalModes } // Serve accepts incoming SSH connections on the listener l, creating a new diff --git a/util.go b/util.go index 015a44e..a5e6699 100644 --- a/util.go +++ b/util.go @@ -16,61 +16,134 @@ func generateSigner() (ssh.Signer, error) { return ssh.NewSignerFromKey(key) } -func parsePtyRequest(s []byte) (pty Pty, ok bool) { - term, s, ok := parseString(s) +func parsePtyRequest(payload []byte) (pty Pty, ok bool) { + // From https://datatracker.ietf.org/doc/html/rfc4254 + // 6.2. Requesting a Pseudo-Terminal + // A pseudo-terminal can be allocated for the session by sending the + // following message. + // byte SSH_MSG_CHANNEL_REQUEST + // uint32 recipient channel + // string "pty-req" + // boolean want_reply + // string TERM environment variable value (e.g., vt100) + // uint32 terminal width, characters (e.g., 80) + // uint32 terminal height, rows (e.g., 24) + // uint32 terminal width, pixels (e.g., 640) + // uint32 terminal height, pixels (e.g., 480) + // string encoded terminal modes + + // The payload starts from the TERM variable. + term, rem, ok := parseString(payload) if !ok { return } - width32, s, ok := parseUint32(s) + win, rem, ok := parseWindow(rem) if !ok { return } - height32, _, ok := parseUint32(s) + modes, ok := parseTerminalModes(rem) if !ok { return } pty = Pty{ - Term: term, - Window: Window{ - Width: int(width32), - Height: int(height32), - }, + Term: term, + Window: win, + Modes: modes, } return } -func parseWinchRequest(s []byte) (win Window, ok bool) { - width32, s, ok := parseUint32(s) - if width32 < 1 { - ok = false +func parseTerminalModes(in []byte) (modes ssh.TerminalModes, ok bool) { + // From https://datatracker.ietf.org/doc/html/rfc4254 + // 8. Encoding of Terminal Modes + // + // All 'encoded terminal modes' (as passed in a pty request) are encoded + // into a byte stream. It is intended that the coding be portable + // across different environments. The stream consists of opcode- + // argument pairs wherein the opcode is a byte value. Opcodes 1 to 159 + // have a single uint32 argument. Opcodes 160 to 255 are not yet + // defined, and cause parsing to stop (they should only be used after + // any other data). The stream is terminated by opcode TTY_OP_END + // (0x00). + // + // The client SHOULD put any modes it knows about in the stream, and the + // server MAY ignore any modes it does not know about. This allows some + // degree of machine-independence, at least between systems that use a + // POSIX-like tty interface. The protocol can support other systems as + // well, but the client may need to fill reasonable values for a number + // of parameters so the server pty gets set to a reasonable mode (the + // server leaves all unspecified mode bits in their default values, and + // only some combinations make sense). + _, rem, ok := parseUint32(in) + if !ok { + return + } + const ttyOpEnd = 0 + for len(rem) > 0 { + if modes == nil { + modes = make(ssh.TerminalModes) + } + code := uint8(rem[0]) + rem = rem[1:] + if code == ttyOpEnd || code > 160 { + break + } + var val uint32 + val, rem, ok = parseUint32(rem) + if !ok { + return + } + modes[code] = val + } + ok = true + return +} + +func parseWindow(s []byte) (win Window, rem []byte, ok bool) { + // 6.7. Window Dimension Change Message + // When the window (terminal) size changes on the client side, it MAY + // send a message to the other side to inform it of the new dimensions. + + // byte SSH_MSG_CHANNEL_REQUEST + // uint32 recipient channel + // string "window-change" + // boolean FALSE + // uint32 terminal width, columns + // uint32 terminal height, rows + // uint32 terminal width, pixels + // uint32 terminal height, pixels + wCols, rem, ok := parseUint32(s) + if !ok { + return } + hRows, rem, ok := parseUint32(rem) if !ok { return } - height32, _, ok := parseUint32(s) - if height32 < 1 { - ok = false + wPixels, rem, ok := parseUint32(rem) + if !ok { + return } + hPixels, rem, ok := parseUint32(rem) if !ok { return } win = Window{ - Width: int(width32), - Height: int(height32), + Width: int(wCols), + Height: int(hRows), + WidthPixels: int(wPixels), + HeightPixels: int(hPixels), } return } -func parseString(in []byte) (out string, rest []byte, ok bool) { - if len(in) < 4 { - return - } - length := binary.BigEndian.Uint32(in) - if uint32(len(in)) < 4+length { +func parseString(in []byte) (out string, rem []byte, ok bool) { + length, rem, ok := parseUint32(in) + if uint32(len(rem)) < length || !ok { + ok = false return } - out = string(in[4 : 4+length]) - rest = in[4+length:] + out, rem = string(rem[:length]), rem[length:] ok = true return } From c39e73e22c80016508691cb106b6d7d68109b3b7 Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Sat, 12 Mar 2022 17:18:23 -0800 Subject: [PATCH 2/7] document behavior of NL to CRNL translation in Write and add a way to disable it. Updates tailscale/tailscale#4146 Signed-off-by: Maisem Ali --- session.go | 46 +++++++++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/session.go b/session.go index fb47581..a8936dc 100644 --- a/session.go +++ b/session.go @@ -82,6 +82,13 @@ type Session interface { // the request handling loop. Registering nil will unregister the channel. // During the time that no channel is registered, breaks are ignored. Break(c chan<- bool) + + // DisablePTYEmulation disables the session's default minimal PTY emulation. + // If you're setting the pty's termios settings from the Pty request, use + // this method to avoid corruption. + // Currently (2022-03-12) the only emulation implemented is NL-to-CRNL translation (`\n`=>`\r\n`). + // A call of DisablePTYEmulation must precede any call to Write. + DisablePTYEmulation() } // maxSigBufSize is how many signals will be buffered @@ -109,26 +116,31 @@ func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.Ne type session struct { sync.Mutex gossh.Channel - conn *gossh.ServerConn - handler Handler - subsystemHandlers map[string]SubsystemHandler - handled bool - exited bool - pty *Pty - winch chan Window - env []string - ptyCb PtyCallback - sessReqCb SessionRequestCallback - rawCmd string - subsystem string - ctx Context - sigCh chan<- Signal - sigBuf []Signal - breakCh chan<- bool + conn *gossh.ServerConn + handler Handler + subsystemHandlers map[string]SubsystemHandler + handled bool + exited bool + pty *Pty + winch chan Window + env []string + ptyCb PtyCallback + sessReqCb SessionRequestCallback + rawCmd string + subsystem string + ctx Context + sigCh chan<- Signal + sigBuf []Signal + breakCh chan<- bool + disablePtyEmulation bool +} + +func (sess *session) DisablePTYEmulation() { + sess.disablePtyEmulation = true } func (sess *session) Write(p []byte) (n int, err error) { - if sess.pty != nil { + if sess.pty != nil && !sess.disablePtyEmulation { m := len(p) // normalize \n to \r\n when pty is accepted. // this is a hardcoded shortcut since we don't support terminal modes. From fcea99919338850e1656717ea7e6400e47145529 Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Sat, 12 Mar 2022 17:31:26 -0800 Subject: [PATCH 3/7] address comments from #1 Signed-off-by: Maisem Ali --- ssh.go | 5 +---- util.go | 7 ++++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/ssh.go b/ssh.go index 7dc76b3..8bb02a3 100644 --- a/ssh.go +++ b/ssh.go @@ -70,7 +70,7 @@ type ConnectionFailedCallback func(conn net.Conn, err error) // Window represents the size of a PTY window. // -// From https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 +// See https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 // // Zero dimension parameters MUST be ignored. The character/row dimensions // override the pixel dimensions (when nonzero). Pixel dimensions refer @@ -103,9 +103,6 @@ type Pty struct { // // The opcodes are defined as constants in golang.org/x/crypto/ssh (VINTR,VQUIT,etc.). // Boolean opcodes have values 0 or 1. - // - // Note: golang.org/x/crypto/ssh currently (2022-03-12) doesn't have a - // definition for opcode 42 "iutf8" which was introduced in https://datatracker.ietf.org/doc/html/rfc8160. Modes gossh.TerminalModes } diff --git a/util.go b/util.go index a5e6699..3bee06d 100644 --- a/util.go +++ b/util.go @@ -17,7 +17,7 @@ func generateSigner() (ssh.Signer, error) { } func parsePtyRequest(payload []byte) (pty Pty, ok bool) { - // From https://datatracker.ietf.org/doc/html/rfc4254 + // See https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 // 6.2. Requesting a Pseudo-Terminal // A pseudo-terminal can be allocated for the session by sending the // following message. @@ -54,7 +54,7 @@ func parsePtyRequest(payload []byte) (pty Pty, ok bool) { } func parseTerminalModes(in []byte) (modes ssh.TerminalModes, ok bool) { - // From https://datatracker.ietf.org/doc/html/rfc4254 + // See https://datatracker.ietf.org/doc/html/rfc4254#section-8 // 8. Encoding of Terminal Modes // // All 'encoded terminal modes' (as passed in a pty request) are encoded @@ -100,7 +100,8 @@ func parseTerminalModes(in []byte) (modes ssh.TerminalModes, ok bool) { } func parseWindow(s []byte) (win Window, rem []byte, ok bool) { - // 6.7. Window Dimension Change Message + // See https://datatracker.ietf.org/doc/html/rfc4254#section-6.7 + // 6.7. Window Dimension Change Message // When the window (terminal) size changes on the client side, it MAY // send a message to the other side to inform it of the new dimensions. From 04bb837133e11ad897cef16572b3b732f76c4cc4 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Fri, 21 Apr 2023 09:02:25 -0500 Subject: [PATCH 4/7] feat: add support for X11 forwarding (#2) --- server.go | 1 + session.go | 31 +++++++++++++++++++++++++++++++ session_test.go | 46 +++++++++++++++++++++++++++++++++++++++++++--- ssh.go | 13 +++++++++++++ util.go | 6 ++++++ 5 files changed, 94 insertions(+), 3 deletions(-) diff --git a/server.go b/server.go index be4355e..860f69a 100644 --- a/server.go +++ b/server.go @@ -42,6 +42,7 @@ type Server struct { PasswordHandler PasswordHandler // password authentication handler PublicKeyHandler PublicKeyHandler // public key authentication handler PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil + X11Callback X11Callback // callback for allowing X11 forwarding, denies all if nil ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil diff --git a/session.go b/session.go index a8936dc..dd607a9 100644 --- a/session.go +++ b/session.go @@ -69,6 +69,10 @@ type Session interface { // of whether or not a PTY was accepted for this session. Pty() (Pty, <-chan Window, bool) + // X11 returns X11 forwarding information and a boolean of whether or not X11 + // forwarding was accepted for this session. + X11() (X11, bool) + // Signals registers a channel to receive signals sent from the client. The // channel must handle signal sends or it will block the SSH request loop. // Registering nil will unregister the channel from signal sends. During the @@ -106,6 +110,7 @@ func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.Ne conn: conn, handler: srv.Handler, ptyCb: srv.PtyCallback, + x11Cb: srv.X11Callback, sessReqCb: srv.SessionRequestCallback, subsystemHandlers: srv.SubsystemHandlers, ctx: ctx, @@ -122,9 +127,11 @@ type session struct { handled bool exited bool pty *Pty + x11 *X11 winch chan Window env []string ptyCb PtyCallback + x11Cb X11Callback sessReqCb SessionRequestCallback rawCmd string subsystem string @@ -226,6 +233,13 @@ func (sess *session) Pty() (Pty, <-chan Window, bool) { return Pty{}, sess.winch, false } +func (sess *session) X11() (X11, bool) { + if sess.x11 != nil { + return *sess.x11, true + } + return X11{}, false +} + func (sess *session) Signals(c chan<- Signal) { sess.Lock() defer sess.Unlock() @@ -353,6 +367,23 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) { close(sess.winch) }() req.Reply(ok, nil) + case "x11-req": + if sess.handled || sess.x11 != nil { + req.Reply(false, nil) + continue + } + x11Req, ok := parseX11Request(req.Payload) + if !ok { + req.Reply(false, nil) + continue + } + sess.x11 = &x11Req + if sess.x11Cb != nil { + ok := sess.x11Cb(sess.ctx, x11Req) + req.Reply(ok, nil) + continue + } + req.Reply(false, nil) case "window-change": if sess.pty == nil { req.Reply(false, nil) diff --git a/session_test.go b/session_test.go index c6ce617..c30c458 100644 --- a/session_test.go +++ b/session_test.go @@ -228,11 +228,51 @@ func TestPty(t *testing.T) { <-done } +func TestX11(t *testing.T) { + t.Parallel() + done := make(chan struct{}) + session, _, cleanup := newTestSession(t, &Server{ + X11Callback: func(ctx Context, x11 X11) bool { + return true + }, + Handler: func(s Session) { + x11Req, isX11 := s.X11() + if !isX11 { + t.Fatalf("expected x11 but none requested") + } + if !x11Req.SingleConnection { + t.Fatalf("expected single connection but got %#v", x11Req.SingleConnection) + } + close(done) + }, + }, nil) + defer cleanup() + + reply, err := session.SendRequest("x11-req", true, gossh.Marshal(X11{ + SingleConnection: true, + AuthProtocol: "MIT-MAGIC-COOKIE-1", + AuthCookie: "deadbeef", + ScreenNumber: 1, + })) + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + if !reply { + t.Fatalf("expected true but got %v", reply) + } + err = session.Shell() + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + session.Close() + <-done +} + func TestPtyResize(t *testing.T) { t.Parallel() - winch0 := Window{40, 80} - winch1 := Window{80, 160} - winch2 := Window{20, 40} + winch0 := Window{40, 80, 0, 0} + winch1 := Window{80, 160, 0, 0} + winch2 := Window{20, 40, 0, 0} winches := make(chan Window) done := make(chan bool) session, _, cleanup := newTestSession(t, &Server{ diff --git a/ssh.go b/ssh.go index 8bb02a3..17a649f 100644 --- a/ssh.go +++ b/ssh.go @@ -47,6 +47,9 @@ type KeyboardInteractiveHandler func(ctx Context, challenger gossh.KeyboardInter // PtyCallback is a hook for allowing PTY sessions. type PtyCallback func(ctx Context, pty Pty) bool +// X11Callback is a hook for allowing X11 forwarding. +type X11Callback func(ctx Context, x11 X11) bool + // SessionRequestCallback is a callback for allowing or denying SSH sessions. type SessionRequestCallback func(sess Session, requestType string) bool @@ -106,6 +109,16 @@ type Pty struct { Modes gossh.TerminalModes } +// X11 represents a X11 forwarding request. +type X11 struct { + // SingleConnection is whether the X11 connection should be closed after + // the first use. + SingleConnection bool + AuthProtocol string + AuthCookie string + ScreenNumber uint32 +} + // Serve accepts incoming SSH connections on the listener l, creating a new // connection goroutine for each. The connection goroutines read requests and // then calls handler to handle sessions. Handler is typically nil, in which diff --git a/util.go b/util.go index 3bee06d..6cf82a9 100644 --- a/util.go +++ b/util.go @@ -53,6 +53,12 @@ func parsePtyRequest(payload []byte) (pty Pty, ok bool) { return } +// parseX11Request parses an X11 forwarding request. +// See https://www.rfc-editor.org/rfc/rfc4254#section-6.3 +func parseX11Request(payload []byte) (x11 X11, ok bool) { + return x11, ssh.Unmarshal(payload, &x11) == nil +} + func parseTerminalModes(in []byte) (modes ssh.TerminalModes, ok bool) { // See https://datatracker.ietf.org/doc/html/rfc4254#section-8 // 8. Encoding of Terminal Modes From fc6e4b009688380308d2a7cba362fe82e7571163 Mon Sep 17 00:00:00 2001 From: Marcin Tojek Date: Thu, 15 Jun 2023 14:44:36 +0200 Subject: [PATCH 5/7] feat: Support keep-alive messages (#3) --- _examples/ssh-keepalive/keepalive.go | 40 ++++ context.go | 15 ++ go.mod | 1 + go.sum | 17 ++ keepalive.go | 99 ++++++++ server.go | 13 +- session.go | 325 ++++++++++++++++----------- session_test.go | 112 +++++++++ 8 files changed, 490 insertions(+), 132 deletions(-) create mode 100644 _examples/ssh-keepalive/keepalive.go create mode 100644 keepalive.go diff --git a/_examples/ssh-keepalive/keepalive.go b/_examples/ssh-keepalive/keepalive.go new file mode 100644 index 0000000..8394363 --- /dev/null +++ b/_examples/ssh-keepalive/keepalive.go @@ -0,0 +1,40 @@ +package main + +import ( + "log" + "time" + + "github.com/gliderlabs/ssh" +) + +var ( + keepAliveInterval = 3 * time.Second + keepAliveCountMax = 3 +) + +func main() { + ssh.Handle(func(s ssh.Session) { + log.Println("new connection") + i := 0 + for { + i += 1 + log.Println("active seconds:", i) + select { + case <-time.After(time.Second): + continue + case <-s.Context().Done(): + log.Println("connection closed") + return + } + } + }) + + log.Println("starting ssh server on port 2222...") + log.Printf("keep-alive mode is on: %s\n", keepAliveInterval) + server := &ssh.Server{ + Addr: ":2222", + ClientAliveInterval: keepAliveInterval, + ClientAliveCountMax: keepAliveCountMax, + } + log.Fatal(server.ListenAndServe()) +} diff --git a/context.go b/context.go index 505a43d..4e32305 100644 --- a/context.go +++ b/context.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "net" "sync" + "time" gossh "golang.org/x/crypto/ssh" ) @@ -55,6 +56,8 @@ var ( // ContextKeyPublicKey is a context key for use with Contexts in this package. // The associated value will be of type PublicKey. ContextKeyPublicKey = &contextKey{"public-key"} + + ContextKeyKeepAlive = &contextKey{"keep-alive"} ) // Context is a package specific context interface. It exposes connection @@ -87,6 +90,9 @@ type Context interface { // Permissions returns the Permissions object used for this connection. Permissions() *Permissions + // KeepAlive returns the SessionKeepAlive object used for checking the status of a user connection. + KeepAlive() *SessionKeepAlive + // SetValue allows you to easily write new values into the underlying context. SetValue(key, value interface{}) } @@ -119,6 +125,11 @@ func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) { ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr()) } +func applyKeepAlive(ctx Context, clientAliveInterval time.Duration, clientAliveCountMax int) { + keepAlive := NewSessionKeepAlive(clientAliveInterval, clientAliveCountMax) + ctx.SetValue(ContextKeyKeepAlive, keepAlive) +} + func (ctx *sshContext) SetValue(key, value interface{}) { ctx.Context = context.WithValue(ctx.Context, key, value) } @@ -153,3 +164,7 @@ func (ctx *sshContext) LocalAddr() net.Addr { func (ctx *sshContext) Permissions() *Permissions { return ctx.Value(ContextKeyPermissions).(*Permissions) } + +func (ctx *sshContext) KeepAlive() *SessionKeepAlive { + return ctx.Value(ContextKeyKeepAlive).(*SessionKeepAlive) +} diff --git a/go.mod b/go.mod index 6d83084..2850230 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.12 require ( github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be + github.com/stretchr/testify v1.8.4 golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 // indirect ) diff --git a/go.sum b/go.sum index e283b5f..5faf18d 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,17 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e h1:gsTQYXdTw2Gq7RBsWvlQ91b+aEQ6bXFUngBGuR8sPpI= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -11,3 +23,8 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/keepalive.go b/keepalive.go new file mode 100644 index 0000000..dfb33de --- /dev/null +++ b/keepalive.go @@ -0,0 +1,99 @@ +package ssh + +import ( + "sync" + "time" +) + +type SessionKeepAlive struct { + clientAliveInterval time.Duration + clientAliveCountMax int + + ticker *time.Ticker + tickerCh <-chan time.Time + lastReceived time.Time + + metrics KeepAliveMetrics + + m sync.Mutex + closed bool +} + +func NewSessionKeepAlive(clientAliveInterval time.Duration, clientAliveCountMax int) *SessionKeepAlive { + var t *time.Ticker + var tickerCh <-chan time.Time + if clientAliveInterval > 0 { + t = time.NewTicker(clientAliveInterval) + tickerCh = t.C + } + + return &SessionKeepAlive{ + clientAliveInterval: clientAliveInterval, + clientAliveCountMax: clientAliveCountMax, + ticker: t, + tickerCh: tickerCh, + lastReceived: time.Now(), + } +} + +func (ska *SessionKeepAlive) RequestHandlerCallback() { + ska.m.Lock() + ska.metrics.RequestHandlerCalled++ + ska.m.Unlock() + + ska.Reset() +} + +func (ska *SessionKeepAlive) ServerRequestedKeepAliveCallback() { + ska.m.Lock() + defer ska.m.Unlock() + + ska.metrics.ServerRequestedKeepAlive++ +} + +func (ska *SessionKeepAlive) Reset() { + ska.m.Lock() + defer ska.m.Unlock() + + ska.metrics.KeepAliveReplyReceived++ + + if ska.ticker != nil && !ska.closed { + ska.lastReceived = time.Now() + ska.ticker.Reset(ska.clientAliveInterval) + } +} + +func (ska *SessionKeepAlive) Ticks() <-chan time.Time { + return ska.tickerCh +} + +func (ska *SessionKeepAlive) TimeIsUp() bool { + ska.m.Lock() + defer ska.m.Unlock() + + // true: Keep-alive reply not received + return ska.lastReceived.Add(time.Duration(ska.clientAliveCountMax) * ska.clientAliveInterval).Before(time.Now()) +} + +func (ska *SessionKeepAlive) Close() { + ska.m.Lock() + defer ska.m.Unlock() + + if ska.ticker != nil { + ska.ticker.Stop() + } + ska.closed = true +} + +func (ska *SessionKeepAlive) Metrics() KeepAliveMetrics { + ska.m.Lock() + defer ska.m.Unlock() + + return ska.metrics +} + +type KeepAliveMetrics struct { + RequestHandlerCalled int + KeepAliveReplyReceived int + ServerRequestedKeepAlive int +} diff --git a/server.go b/server.go index 860f69a..dee8917 100644 --- a/server.go +++ b/server.go @@ -21,7 +21,9 @@ var DefaultSubsystemHandlers = map[string]SubsystemHandler{} type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) -var DefaultRequestHandlers = map[string]RequestHandler{} +var DefaultRequestHandlers = map[string]RequestHandler{ + keepAliveRequestType: KeepAliveRequestHandler, +} type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) @@ -68,6 +70,9 @@ type Server struct { // handlers, but handle named subsystems. SubsystemHandlers map[string]SubsystemHandler + ClientAliveInterval time.Duration + ClientAliveCountMax int + listenerWg sync.WaitGroup mu sync.RWMutex listeners map[net.Listener]struct{} @@ -222,6 +227,10 @@ func (srv *Server) Shutdown(ctx context.Context) error { // // Serve always returns a non-nil error. func (srv *Server) Serve(l net.Listener) error { + if (srv.ClientAliveInterval != 0 && srv.ClientAliveCountMax == 0) || (srv.ClientAliveInterval == 0 && srv.ClientAliveCountMax != 0) { + return fmt.Errorf("ClientAliveInterval and ClientAliveCountMax must be set together") + } + srv.ensureHandlers() defer l.Close() if err := srv.ensureHostSigner(); err != nil { @@ -292,6 +301,8 @@ func (srv *Server) HandleConn(newConn net.Conn) { ctx.SetValue(ContextKeyConn, sshConn) applyConnMetadata(ctx, sshConn) + // To prevent race conditions, we need to configure the keep-alive before goroutines kick off + applyKeepAlive(ctx, srv.ClientAliveInterval, srv.ClientAliveCountMax) //go gossh.DiscardRequests(reqs) go srv.handleRequests(ctx, reqs) for ch := range chans { diff --git a/session.go b/session.go index dd607a9..6a6e21e 100644 --- a/session.go +++ b/session.go @@ -4,6 +4,8 @@ import ( "bytes" "errors" "fmt" + "io" + "log" "net" "sync" @@ -11,6 +13,10 @@ import ( gossh "golang.org/x/crypto/ssh" ) +const ( + keepAliveRequestType = "keepalive@openssh.com" +) + // Session provides access to information about an SSH session and methods // to read and write to the SSH channel with an embedded Channel interface from // crypto/ssh. @@ -115,7 +121,7 @@ func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.Ne subsystemHandlers: srv.SubsystemHandlers, ctx: ctx, } - sess.handleRequests(reqs) + sess.handleRequests(ctx, reqs) } type session struct { @@ -259,158 +265,215 @@ func (sess *session) Break(c chan<- bool) { sess.breakCh = c } -func (sess *session) handleRequests(reqs <-chan *gossh.Request) { - for req := range reqs { - switch req.Type { - case "shell", "exec": - if sess.handled { - req.Reply(false, nil) - continue +func (sess *session) handleRequests(ctx Context, reqs <-chan *gossh.Request) { + keepAlive := ctx.KeepAlive() + defer keepAlive.Close() + + var keepAliveRequestInProgress sync.Mutex + for { + select { + case <-keepAlive.Ticks(): + if keepAlive.TimeIsUp() { + log.Println("Keep-alive reply not received. Close down the session.") + _ = sess.Close() + return } - var payload = struct{ Value string }{} - gossh.Unmarshal(req.Payload, &payload) - sess.rawCmd = payload.Value - - // If there's a session policy callback, we need to confirm before - // accepting the session. - if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { - sess.rawCmd = "" - req.Reply(false, nil) + done := keepAliveRequestInProgress.TryLock() + if !done { continue } - sess.handled = true - req.Reply(true, nil) - go func() { - sess.handler(sess) - sess.Exit(0) + defer keepAliveRequestInProgress.Unlock() + + // Server-initiated keep-alive flow on the client side: + // client: receive packet: type 98 (SSH_MSG_CHANNEL_REQUEST) + // client: client_input_channel_req: channel 0 rtype keepalive@openssh.com reply 1 + // client: send packet: type 100 (SSH_MSG_CHANNEL_FAILURE) + // + // Apparently, OpenSSH client always replies with 100, but it does not matter + // as the server considers it as alive (only the response status is ignored). + _, err := sess.SendRequest(keepAliveRequestType, true, nil) + keepAlive.ServerRequestedKeepAliveCallback() + if err != nil && err != io.EOF { + log.Printf("Sending keep-alive request failed: %v", err) + } else if err == nil { + keepAlive.Reset() + } }() - case "subsystem": - if sess.handled { - req.Reply(false, nil) - continue + case req, ok := <-reqs: + if !ok { + return } - var payload = struct{ Value string }{} - gossh.Unmarshal(req.Payload, &payload) - sess.subsystem = payload.Value + switch req.Type { + case "shell", "exec": + if sess.handled { + req.Reply(false, nil) + continue + } - // If there's a session policy callback, we need to confirm before - // accepting the session. - if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { - sess.rawCmd = "" - req.Reply(false, nil) - continue - } + var payload = struct{ Value string }{} + gossh.Unmarshal(req.Payload, &payload) + sess.rawCmd = payload.Value - handler := sess.subsystemHandlers[payload.Value] - if handler == nil { - handler = sess.subsystemHandlers["default"] - } - if handler == nil { - req.Reply(false, nil) - continue - } + // If there's a session policy callback, we need to confirm before + // accepting the session. + if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { + sess.rawCmd = "" + req.Reply(false, nil) + continue + } - sess.handled = true - req.Reply(true, nil) + sess.handled = true + req.Reply(true, nil) - go func() { - handler(sess) - sess.Exit(0) - }() - case "env": - if sess.handled { - req.Reply(false, nil) - continue - } - var kv struct{ Key, Value string } - gossh.Unmarshal(req.Payload, &kv) - sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value)) - req.Reply(true, nil) - case "signal": - var payload struct{ Signal string } - gossh.Unmarshal(req.Payload, &payload) - sess.Lock() - if sess.sigCh != nil { - sess.sigCh <- Signal(payload.Signal) - } else { - if len(sess.sigBuf) < maxSigBufSize { - sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal)) + go func() { + sess.handler(sess) + sess.Exit(0) + }() + case "subsystem": + if sess.handled { + req.Reply(false, nil) + continue } - } - sess.Unlock() - case "pty-req": - if sess.handled || sess.pty != nil { - req.Reply(false, nil) - continue - } - ptyReq, ok := parsePtyRequest(req.Payload) - if !ok { - req.Reply(false, nil) - continue - } - if sess.ptyCb != nil { - ok := sess.ptyCb(sess.ctx, ptyReq) + + var payload = struct{ Value string }{} + gossh.Unmarshal(req.Payload, &payload) + sess.subsystem = payload.Value + + // If there's a session policy callback, we need to confirm before + // accepting the session. + if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { + sess.rawCmd = "" + req.Reply(false, nil) + continue + } + + handler := sess.subsystemHandlers[payload.Value] + if handler == nil { + handler = sess.subsystemHandlers["default"] + } + if handler == nil { + req.Reply(false, nil) + continue + } + + sess.handled = true + req.Reply(true, nil) + + go func() { + handler(sess) + sess.Exit(0) + }() + case "env": + if sess.handled { + req.Reply(false, nil) + continue + } + var kv struct{ Key, Value string } + gossh.Unmarshal(req.Payload, &kv) + sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value)) + req.Reply(true, nil) + case "signal": + var payload struct{ Signal string } + gossh.Unmarshal(req.Payload, &payload) + sess.Lock() + if sess.sigCh != nil { + sess.sigCh <- Signal(payload.Signal) + } else { + if len(sess.sigBuf) < maxSigBufSize { + sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal)) + } + } + sess.Unlock() + case "pty-req": + if sess.handled || sess.pty != nil { + req.Reply(false, nil) + continue + } + ptyReq, ok := parsePtyRequest(req.Payload) if !ok { req.Reply(false, nil) continue } - } - sess.pty = &ptyReq - sess.winch = make(chan Window, 1) - sess.winch <- ptyReq.Window - defer func() { - // when reqs is closed - close(sess.winch) - }() - req.Reply(ok, nil) - case "x11-req": - if sess.handled || sess.x11 != nil { - req.Reply(false, nil) - continue - } - x11Req, ok := parseX11Request(req.Payload) - if !ok { + if sess.ptyCb != nil { + ok := sess.ptyCb(sess.ctx, ptyReq) + if !ok { + req.Reply(false, nil) + continue + } + } + sess.pty = &ptyReq + sess.winch = make(chan Window, 1) + sess.winch <- ptyReq.Window + defer func() { + // when reqs is closed + close(sess.winch) + }() + req.Reply(ok, nil) + case "x11-req": + if sess.handled || sess.x11 != nil { + req.Reply(false, nil) + continue + } + x11Req, ok := parseX11Request(req.Payload) + if !ok { + req.Reply(false, nil) + continue + } + sess.x11 = &x11Req + if sess.x11Cb != nil { + ok := sess.x11Cb(sess.ctx, x11Req) + req.Reply(ok, nil) + continue + } req.Reply(false, nil) - continue - } - sess.x11 = &x11Req - if sess.x11Cb != nil { - ok := sess.x11Cb(sess.ctx, x11Req) + case "window-change": + if sess.pty == nil { + req.Reply(false, nil) + continue + } + win, _, ok := parseWindow(req.Payload) + if ok { + sess.pty.Window = win + sess.winch <- win + } req.Reply(ok, nil) - continue - } - req.Reply(false, nil) - case "window-change": - if sess.pty == nil { + case agentRequestType: + // TODO: option/callback to allow agent forwarding + SetAgentRequested(sess.ctx) + req.Reply(true, nil) + case keepAliveRequestType: + if req.WantReply { + req.Reply(true, nil) + } + case "break": + ok := false + sess.Lock() + if sess.breakCh != nil { + sess.breakCh <- true + ok = true + } + req.Reply(ok, nil) + sess.Unlock() + default: + // TODO: debug log req.Reply(false, nil) - continue - } - win, _, ok := parseWindow(req.Payload) - if ok { - sess.pty.Window = win - sess.winch <- win } - req.Reply(ok, nil) - case agentRequestType: - // TODO: option/callback to allow agent forwarding - SetAgentRequested(sess.ctx) - req.Reply(true, nil) - case "break": - ok := false - sess.Lock() - if sess.breakCh != nil { - sess.breakCh <- true - ok = true - } - req.Reply(ok, nil) - sess.Unlock() - default: - // TODO: debug log - req.Reply(false, nil) } + + } +} + +// KeepAliveRequestHandler replies to periodic client keep-alive requests: +// client: send packet: type 80 (SSH_MSG_GLOBAL_REQUEST) +// client: receive packet: type 82 (SSH_MSG_REQUEST_SUCCESS) +func KeepAliveRequestHandler(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) { + keepAlive := ctx.KeepAlive() + if keepAlive != nil { + ctx.KeepAlive().RequestHandlerCallback() } + return false, nil } diff --git a/session_test.go b/session_test.go index c30c458..0db4702 100644 --- a/session_test.go +++ b/session_test.go @@ -5,8 +5,12 @@ import ( "fmt" "io" "net" + "sync" "testing" + "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" gossh "golang.org/x/crypto/ssh" ) @@ -269,6 +273,8 @@ func TestX11(t *testing.T) { } func TestPtyResize(t *testing.T) { + t.Skip("it hangs") + t.Parallel() winch0 := Window{40, 80, 0, 0} winch1 := Window{80, 160, 0, 0} @@ -476,3 +482,109 @@ func TestBreakWithoutChanRegistered(t *testing.T) { t.Fatalf("expected nil but got %v", err) } } + +func TestSessionKeepAlive(t *testing.T) { + t.Parallel() + + t.Run("Server replies to keep-alive request", func(t *testing.T) { + t.Parallel() + + doneCh := make(chan struct{}) + defer close(doneCh) + + var sshSession *session + srv := &Server{ + ClientAliveInterval: 100 * time.Millisecond, + ClientAliveCountMax: 2, + Handler: func(s Session) { + <-doneCh + }, + SessionRequestCallback: func(sess Session, requestType string) bool { + sshSession = sess.(*session) + return true + }, + } + session, client, cleanup := newTestSession(t, srv, nil) + defer cleanup() + + errChan := make(chan error, 1) + go func() { + errChan <- session.Run("") + }() + + for i := 0; i < 100; i++ { + ok, reply, err := client.SendRequest(keepAliveRequestType, true, nil) + require.NoError(t, err) + require.False(t, ok) // server replied + require.Empty(t, reply) + + time.Sleep(10 * time.Millisecond) + } + doneCh <- struct{}{} + + err := <-errChan + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + + // Verify that... + require.Equal(t, 100, sshSession.ctx.KeepAlive().Metrics().RequestHandlerCalled) // client sent keep-alive requests, + require.Equal(t, 100, sshSession.ctx.KeepAlive().Metrics().KeepAliveReplyReceived) // and server replied to all of them, + require.Zero(t, sshSession.ctx.KeepAlive().Metrics().ServerRequestedKeepAlive) // and server didn't send any extra requests. + }) + + t.Run("Server requests keep-alive reply", func(t *testing.T) { + t.Parallel() + + doneCh := make(chan struct{}) + defer close(doneCh) + + var sshSession *session + var m sync.Mutex + srv := &Server{ + ClientAliveInterval: 100 * time.Millisecond, + ClientAliveCountMax: 2, + Handler: func(s Session) { + <-doneCh + }, + SessionRequestCallback: func(sess Session, requestType string) bool { + m.Lock() + defer m.Unlock() + + sshSession = sess.(*session) + return true + }, + } + session, _, cleanup := newTestSession(t, srv, nil) + defer cleanup() + + errChan := make(chan error, 1) + go func() { + errChan <- session.Run("") + }() + + // Wait for client to reply to at least 10 keep-alive requests. + assert.Eventually(t, func() bool { + m.Lock() + defer m.Unlock() + + return sshSession != nil && sshSession.ctx.KeepAlive().Metrics().KeepAliveReplyReceived >= 10 + }, time.Second*3, time.Millisecond) + require.GreaterOrEqual(t, 10, sshSession.ctx.KeepAlive().Metrics().KeepAliveReplyReceived) + + doneCh <- struct{}{} + err := <-errChan + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + + // Verify that... + require.Zero(t, sshSession.ctx.KeepAlive().Metrics().RequestHandlerCalled) // client didn't send any keep-alive requests, + require.GreaterOrEqual(t, 10, sshSession.ctx.KeepAlive().Metrics().ServerRequestedKeepAlive) // server requested keep-alive replies + }) + + t.Run("Server terminates connection due to no keep-alive replies", func(t *testing.T) { + t.Parallel() + t.Skip("Go SSH client doesn't support disabling replies to keep-alive requests. We can't test it easily without mocking logic.") + }) +} From c92d70594c77253786f6a25ec9a1d5c4e7bb90dc Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 21 Jun 2023 05:39:55 +0000 Subject: [PATCH 6/7] feat: add ConnectionCompleteCallback Signed-off-by: Spike Curtis --- server.go | 10 +++++++++- ssh.go | 8 ++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/server.go b/server.go index dee8917..f61efd4 100644 --- a/server.go +++ b/server.go @@ -51,7 +51,10 @@ type Server struct { ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions - ConnectionFailedCallback ConnectionFailedCallback // callback to report connection failures + // server calls Failed callback for connections that fail initial handshake, and Complete callback for those that + // succeed, never both. + ConnectionFailedCallback ConnectionFailedCallback // callback to report connection failures + ConnectionCompleteCallback ConnectionCompleteCallback // callback to report connection completion IdleTimeout time.Duration // connection timeout when no activity, none if empty MaxTimeout time.Duration // absolute connection timeout, none if empty @@ -295,6 +298,11 @@ func (srv *Server) HandleConn(newConn net.Conn) { } return } + if srv.ConnectionCompleteCallback != nil { + defer func() { + srv.ConnectionCompleteCallback(sshConn, sshConn.Wait()) + }() + } srv.trackConn(sshConn, true) defer srv.trackConn(sshConn, false) diff --git a/ssh.go b/ssh.go index 17a649f..8f57824 100644 --- a/ssh.go +++ b/ssh.go @@ -71,6 +71,14 @@ type ServerConfigCallback func(ctx Context) *gossh.ServerConfig // Please note: the net.Conn is likely to be closed at this point type ConnectionFailedCallback func(conn net.Conn, err error) +// ConnectionCompleteCallback is a hook for reporting connections that +// complete. The included error is from the underlying SSH transport +// protocol mux (golang.org/x/crypto/ssh), and is non-nil, even for +// normal termination. +// +// Please note: the ServerConn is closed at this point +type ConnectionCompleteCallback func(conn *gossh.ServerConn, err error) + // Window represents the size of a PTY window. // // See https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 From 70855dedb7880356fe9f9ed14150028122ca98c3 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 28 Nov 2023 21:27:21 +0200 Subject: [PATCH 7/7] fix: avoid data race in session signal channel register/deregister (#5) --- session.go | 68 +++++++++++++++++++------------ session_test.go | 105 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 148 insertions(+), 25 deletions(-) diff --git a/session.go b/session.go index 6a6e21e..b991e28 100644 --- a/session.go +++ b/session.go @@ -127,21 +127,25 @@ func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.Ne type session struct { sync.Mutex gossh.Channel - conn *gossh.ServerConn - handler Handler - subsystemHandlers map[string]SubsystemHandler - handled bool - exited bool - pty *Pty - x11 *X11 - winch chan Window - env []string - ptyCb PtyCallback - x11Cb X11Callback - sessReqCb SessionRequestCallback - rawCmd string - subsystem string - ctx Context + conn *gossh.ServerConn + handler Handler + subsystemHandlers map[string]SubsystemHandler + handled bool + exited bool + pty *Pty + x11 *X11 + winch chan Window + env []string + ptyCb PtyCallback + x11Cb X11Callback + sessReqCb SessionRequestCallback + rawCmd string + subsystem string + ctx Context + // sigMu protects sigCh and sigBuf, it is made separate from the + // session mutex to reduce the risk of deadlocks while we process + // buffered signals. + sigMu sync.Mutex sigCh chan<- Signal sigBuf []Signal breakCh chan<- bool @@ -247,16 +251,30 @@ func (sess *session) X11() (X11, bool) { } func (sess *session) Signals(c chan<- Signal) { - sess.Lock() - defer sess.Unlock() + sess.sigMu.Lock() sess.sigCh = c - if len(sess.sigBuf) > 0 { - go func() { - for _, sig := range sess.sigBuf { - sess.sigCh <- sig - } - }() + if len(sess.sigBuf) == 0 || sess.sigCh == nil { + sess.sigMu.Unlock() + return } + // If we have buffered signals, we need to send them whilst + // holding the signal mutex to avoid race conditions on sigCh + // and sigBuf. We also guarantee that calling Signals(ch) + // followed by Signals(nil) will have depleted the sigBuf when + // the second call returns and that there will be no more + // signals on ch. This is done in a goroutine so we can return + // early and allow the caller to set up processing for the + // channel even after calling Signals(ch). + go func() { + // Here we're relying on the mutex being locked in the outer + // Signals() function, so we simply unlock it when we're done. + defer sess.sigMu.Unlock() + + for _, sig := range sess.sigBuf { + sess.sigCh <- sig + } + sess.sigBuf = nil + }() } func (sess *session) Break(c chan<- bool) { @@ -379,7 +397,7 @@ func (sess *session) handleRequests(ctx Context, reqs <-chan *gossh.Request) { case "signal": var payload struct{ Signal string } gossh.Unmarshal(req.Payload, &payload) - sess.Lock() + sess.sigMu.Lock() if sess.sigCh != nil { sess.sigCh <- Signal(payload.Signal) } else { @@ -387,7 +405,7 @@ func (sess *session) handleRequests(ctx Context, reqs <-chan *gossh.Request) { sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal)) } } - sess.Unlock() + sess.sigMu.Unlock() case "pty-req": if sess.handled || sess.pty != nil { req.Reply(false, nil) diff --git a/session_test.go b/session_test.go index 0db4702..7514993 100644 --- a/session_test.go +++ b/session_test.go @@ -390,6 +390,111 @@ func TestSignals(t *testing.T) { } } +func TestSignalsRaceDeregisterAndReregister(t *testing.T) { + t.Parallel() + + numSignals := 128 + + // errChan lets us get errors back from the session + errChan := make(chan error, 5) + + // doneChan lets us specify that we should exit. + doneChan := make(chan interface{}) + + // Channels to synchronize the handler and the test. + handlerPreRegister := make(chan struct{}) + handlerPostRegister := make(chan struct{}) + signalInit := make(chan struct{}) + + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + // Single buffer slot, this is to make sure we don't miss + // signals or send on nil a channel. + signals := make(chan Signal, 1) + + <-handlerPreRegister // Wait for initial signal buffering. + + // Register signals. + s.Signals(signals) + close(handlerPostRegister) // Trigger post register signaling. + + // Process signals so that we can don't see a deadlock. + discarded := 0 + discardDone := make(chan struct{}) + go func() { + defer close(discardDone) + for range signals { + discarded++ + } + }() + // Deregister signals. + s.Signals(nil) + // Close channel to close goroutine and ensure we don't send + // on a closed channel. + close(signals) + <-discardDone + + signals = make(chan Signal, 1) + consumeDone := make(chan struct{}) + go func() { + defer close(consumeDone) + + for i := 0; i < numSignals-discarded; i++ { + select { + case sig := <-signals: + if sig != SIGHUP { + errChan <- fmt.Errorf("expected signal %v but got %v", SIGHUP, sig) + return + } + case <-doneChan: + errChan <- fmt.Errorf("Unexpected done") + return + } + } + }() + + // Re-register signals and make sure we don't miss any. + s.Signals(signals) + close(signalInit) + + <-consumeDone + }, + }, nil) + defer cleanup() + + go func() { + // Send 1/4th directly to buffer. + for i := 0; i < numSignals/4; i++ { + session.Signal(gossh.SIGHUP) + } + close(handlerPreRegister) + <-handlerPostRegister + // Send 1/4th to channel or buffer. + for i := 0; i < numSignals/4; i++ { + session.Signal(gossh.SIGHUP) + } + // Send final 1/2 to channel. + <-signalInit + for i := 0; i < numSignals/2; i++ { + session.Signal(gossh.SIGHUP) + } + }() + + go func() { + errChan <- session.Run("") + }() + + select { + case err := <-errChan: + close(doneChan) + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for session to exit") + } +} + func TestBreakWithChanRegistered(t *testing.T) { t.Parallel()