diff --git a/_examples/ssh-keepalive/keepalive.go b/_examples/ssh-keepalive/keepalive.go new file mode 100644 index 0000000..8394363 --- /dev/null +++ b/_examples/ssh-keepalive/keepalive.go @@ -0,0 +1,40 @@ +package main + +import ( + "log" + "time" + + "github.com/gliderlabs/ssh" +) + +var ( + keepAliveInterval = 3 * time.Second + keepAliveCountMax = 3 +) + +func main() { + ssh.Handle(func(s ssh.Session) { + log.Println("new connection") + i := 0 + for { + i += 1 + log.Println("active seconds:", i) + select { + case <-time.After(time.Second): + continue + case <-s.Context().Done(): + log.Println("connection closed") + return + } + } + }) + + log.Println("starting ssh server on port 2222...") + log.Printf("keep-alive mode is on: %s\n", keepAliveInterval) + server := &ssh.Server{ + Addr: ":2222", + ClientAliveInterval: keepAliveInterval, + ClientAliveCountMax: keepAliveCountMax, + } + log.Fatal(server.ListenAndServe()) +} diff --git a/context.go b/context.go index 505a43d..4e32305 100644 --- a/context.go +++ b/context.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "net" "sync" + "time" gossh "golang.org/x/crypto/ssh" ) @@ -55,6 +56,8 @@ var ( // ContextKeyPublicKey is a context key for use with Contexts in this package. // The associated value will be of type PublicKey. ContextKeyPublicKey = &contextKey{"public-key"} + + ContextKeyKeepAlive = &contextKey{"keep-alive"} ) // Context is a package specific context interface. It exposes connection @@ -87,6 +90,9 @@ type Context interface { // Permissions returns the Permissions object used for this connection. Permissions() *Permissions + // KeepAlive returns the SessionKeepAlive object used for checking the status of a user connection. + KeepAlive() *SessionKeepAlive + // SetValue allows you to easily write new values into the underlying context. SetValue(key, value interface{}) } @@ -119,6 +125,11 @@ func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) { ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr()) } +func applyKeepAlive(ctx Context, clientAliveInterval time.Duration, clientAliveCountMax int) { + keepAlive := NewSessionKeepAlive(clientAliveInterval, clientAliveCountMax) + ctx.SetValue(ContextKeyKeepAlive, keepAlive) +} + func (ctx *sshContext) SetValue(key, value interface{}) { ctx.Context = context.WithValue(ctx.Context, key, value) } @@ -153,3 +164,7 @@ func (ctx *sshContext) LocalAddr() net.Addr { func (ctx *sshContext) Permissions() *Permissions { return ctx.Value(ContextKeyPermissions).(*Permissions) } + +func (ctx *sshContext) KeepAlive() *SessionKeepAlive { + return ctx.Value(ContextKeyKeepAlive).(*SessionKeepAlive) +} diff --git a/go.mod b/go.mod index 6d83084..2850230 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.12 require ( github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be + github.com/stretchr/testify v1.8.4 golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 // indirect ) diff --git a/go.sum b/go.sum index e283b5f..5faf18d 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,17 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e h1:gsTQYXdTw2Gq7RBsWvlQ91b+aEQ6bXFUngBGuR8sPpI= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -11,3 +23,8 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/keepalive.go b/keepalive.go new file mode 100644 index 0000000..dfb33de --- /dev/null +++ b/keepalive.go @@ -0,0 +1,99 @@ +package ssh + +import ( + "sync" + "time" +) + +type SessionKeepAlive struct { + clientAliveInterval time.Duration + clientAliveCountMax int + + ticker *time.Ticker + tickerCh <-chan time.Time + lastReceived time.Time + + metrics KeepAliveMetrics + + m sync.Mutex + closed bool +} + +func NewSessionKeepAlive(clientAliveInterval time.Duration, clientAliveCountMax int) *SessionKeepAlive { + var t *time.Ticker + var tickerCh <-chan time.Time + if clientAliveInterval > 0 { + t = time.NewTicker(clientAliveInterval) + tickerCh = t.C + } + + return &SessionKeepAlive{ + clientAliveInterval: clientAliveInterval, + clientAliveCountMax: clientAliveCountMax, + ticker: t, + tickerCh: tickerCh, + lastReceived: time.Now(), + } +} + +func (ska *SessionKeepAlive) RequestHandlerCallback() { + ska.m.Lock() + ska.metrics.RequestHandlerCalled++ + ska.m.Unlock() + + ska.Reset() +} + +func (ska *SessionKeepAlive) ServerRequestedKeepAliveCallback() { + ska.m.Lock() + defer ska.m.Unlock() + + ska.metrics.ServerRequestedKeepAlive++ +} + +func (ska *SessionKeepAlive) Reset() { + ska.m.Lock() + defer ska.m.Unlock() + + ska.metrics.KeepAliveReplyReceived++ + + if ska.ticker != nil && !ska.closed { + ska.lastReceived = time.Now() + ska.ticker.Reset(ska.clientAliveInterval) + } +} + +func (ska *SessionKeepAlive) Ticks() <-chan time.Time { + return ska.tickerCh +} + +func (ska *SessionKeepAlive) TimeIsUp() bool { + ska.m.Lock() + defer ska.m.Unlock() + + // true: Keep-alive reply not received + return ska.lastReceived.Add(time.Duration(ska.clientAliveCountMax) * ska.clientAliveInterval).Before(time.Now()) +} + +func (ska *SessionKeepAlive) Close() { + ska.m.Lock() + defer ska.m.Unlock() + + if ska.ticker != nil { + ska.ticker.Stop() + } + ska.closed = true +} + +func (ska *SessionKeepAlive) Metrics() KeepAliveMetrics { + ska.m.Lock() + defer ska.m.Unlock() + + return ska.metrics +} + +type KeepAliveMetrics struct { + RequestHandlerCalled int + KeepAliveReplyReceived int + ServerRequestedKeepAlive int +} diff --git a/server.go b/server.go index 860f69a..dee8917 100644 --- a/server.go +++ b/server.go @@ -21,7 +21,9 @@ var DefaultSubsystemHandlers = map[string]SubsystemHandler{} type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) -var DefaultRequestHandlers = map[string]RequestHandler{} +var DefaultRequestHandlers = map[string]RequestHandler{ + keepAliveRequestType: KeepAliveRequestHandler, +} type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) @@ -68,6 +70,9 @@ type Server struct { // handlers, but handle named subsystems. SubsystemHandlers map[string]SubsystemHandler + ClientAliveInterval time.Duration + ClientAliveCountMax int + listenerWg sync.WaitGroup mu sync.RWMutex listeners map[net.Listener]struct{} @@ -222,6 +227,10 @@ func (srv *Server) Shutdown(ctx context.Context) error { // // Serve always returns a non-nil error. func (srv *Server) Serve(l net.Listener) error { + if (srv.ClientAliveInterval != 0 && srv.ClientAliveCountMax == 0) || (srv.ClientAliveInterval == 0 && srv.ClientAliveCountMax != 0) { + return fmt.Errorf("ClientAliveInterval and ClientAliveCountMax must be set together") + } + srv.ensureHandlers() defer l.Close() if err := srv.ensureHostSigner(); err != nil { @@ -292,6 +301,8 @@ func (srv *Server) HandleConn(newConn net.Conn) { ctx.SetValue(ContextKeyConn, sshConn) applyConnMetadata(ctx, sshConn) + // To prevent race conditions, we need to configure the keep-alive before goroutines kick off + applyKeepAlive(ctx, srv.ClientAliveInterval, srv.ClientAliveCountMax) //go gossh.DiscardRequests(reqs) go srv.handleRequests(ctx, reqs) for ch := range chans { diff --git a/session.go b/session.go index dd607a9..6a6e21e 100644 --- a/session.go +++ b/session.go @@ -4,6 +4,8 @@ import ( "bytes" "errors" "fmt" + "io" + "log" "net" "sync" @@ -11,6 +13,10 @@ import ( gossh "golang.org/x/crypto/ssh" ) +const ( + keepAliveRequestType = "keepalive@openssh.com" +) + // Session provides access to information about an SSH session and methods // to read and write to the SSH channel with an embedded Channel interface from // crypto/ssh. @@ -115,7 +121,7 @@ func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.Ne subsystemHandlers: srv.SubsystemHandlers, ctx: ctx, } - sess.handleRequests(reqs) + sess.handleRequests(ctx, reqs) } type session struct { @@ -259,158 +265,215 @@ func (sess *session) Break(c chan<- bool) { sess.breakCh = c } -func (sess *session) handleRequests(reqs <-chan *gossh.Request) { - for req := range reqs { - switch req.Type { - case "shell", "exec": - if sess.handled { - req.Reply(false, nil) - continue +func (sess *session) handleRequests(ctx Context, reqs <-chan *gossh.Request) { + keepAlive := ctx.KeepAlive() + defer keepAlive.Close() + + var keepAliveRequestInProgress sync.Mutex + for { + select { + case <-keepAlive.Ticks(): + if keepAlive.TimeIsUp() { + log.Println("Keep-alive reply not received. Close down the session.") + _ = sess.Close() + return } - var payload = struct{ Value string }{} - gossh.Unmarshal(req.Payload, &payload) - sess.rawCmd = payload.Value - - // If there's a session policy callback, we need to confirm before - // accepting the session. - if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { - sess.rawCmd = "" - req.Reply(false, nil) + done := keepAliveRequestInProgress.TryLock() + if !done { continue } - sess.handled = true - req.Reply(true, nil) - go func() { - sess.handler(sess) - sess.Exit(0) + defer keepAliveRequestInProgress.Unlock() + + // Server-initiated keep-alive flow on the client side: + // client: receive packet: type 98 (SSH_MSG_CHANNEL_REQUEST) + // client: client_input_channel_req: channel 0 rtype keepalive@openssh.com reply 1 + // client: send packet: type 100 (SSH_MSG_CHANNEL_FAILURE) + // + // Apparently, OpenSSH client always replies with 100, but it does not matter + // as the server considers it as alive (only the response status is ignored). + _, err := sess.SendRequest(keepAliveRequestType, true, nil) + keepAlive.ServerRequestedKeepAliveCallback() + if err != nil && err != io.EOF { + log.Printf("Sending keep-alive request failed: %v", err) + } else if err == nil { + keepAlive.Reset() + } }() - case "subsystem": - if sess.handled { - req.Reply(false, nil) - continue + case req, ok := <-reqs: + if !ok { + return } - var payload = struct{ Value string }{} - gossh.Unmarshal(req.Payload, &payload) - sess.subsystem = payload.Value + switch req.Type { + case "shell", "exec": + if sess.handled { + req.Reply(false, nil) + continue + } - // If there's a session policy callback, we need to confirm before - // accepting the session. - if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { - sess.rawCmd = "" - req.Reply(false, nil) - continue - } + var payload = struct{ Value string }{} + gossh.Unmarshal(req.Payload, &payload) + sess.rawCmd = payload.Value - handler := sess.subsystemHandlers[payload.Value] - if handler == nil { - handler = sess.subsystemHandlers["default"] - } - if handler == nil { - req.Reply(false, nil) - continue - } + // If there's a session policy callback, we need to confirm before + // accepting the session. + if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { + sess.rawCmd = "" + req.Reply(false, nil) + continue + } - sess.handled = true - req.Reply(true, nil) + sess.handled = true + req.Reply(true, nil) - go func() { - handler(sess) - sess.Exit(0) - }() - case "env": - if sess.handled { - req.Reply(false, nil) - continue - } - var kv struct{ Key, Value string } - gossh.Unmarshal(req.Payload, &kv) - sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value)) - req.Reply(true, nil) - case "signal": - var payload struct{ Signal string } - gossh.Unmarshal(req.Payload, &payload) - sess.Lock() - if sess.sigCh != nil { - sess.sigCh <- Signal(payload.Signal) - } else { - if len(sess.sigBuf) < maxSigBufSize { - sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal)) + go func() { + sess.handler(sess) + sess.Exit(0) + }() + case "subsystem": + if sess.handled { + req.Reply(false, nil) + continue } - } - sess.Unlock() - case "pty-req": - if sess.handled || sess.pty != nil { - req.Reply(false, nil) - continue - } - ptyReq, ok := parsePtyRequest(req.Payload) - if !ok { - req.Reply(false, nil) - continue - } - if sess.ptyCb != nil { - ok := sess.ptyCb(sess.ctx, ptyReq) + + var payload = struct{ Value string }{} + gossh.Unmarshal(req.Payload, &payload) + sess.subsystem = payload.Value + + // If there's a session policy callback, we need to confirm before + // accepting the session. + if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { + sess.rawCmd = "" + req.Reply(false, nil) + continue + } + + handler := sess.subsystemHandlers[payload.Value] + if handler == nil { + handler = sess.subsystemHandlers["default"] + } + if handler == nil { + req.Reply(false, nil) + continue + } + + sess.handled = true + req.Reply(true, nil) + + go func() { + handler(sess) + sess.Exit(0) + }() + case "env": + if sess.handled { + req.Reply(false, nil) + continue + } + var kv struct{ Key, Value string } + gossh.Unmarshal(req.Payload, &kv) + sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value)) + req.Reply(true, nil) + case "signal": + var payload struct{ Signal string } + gossh.Unmarshal(req.Payload, &payload) + sess.Lock() + if sess.sigCh != nil { + sess.sigCh <- Signal(payload.Signal) + } else { + if len(sess.sigBuf) < maxSigBufSize { + sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal)) + } + } + sess.Unlock() + case "pty-req": + if sess.handled || sess.pty != nil { + req.Reply(false, nil) + continue + } + ptyReq, ok := parsePtyRequest(req.Payload) if !ok { req.Reply(false, nil) continue } - } - sess.pty = &ptyReq - sess.winch = make(chan Window, 1) - sess.winch <- ptyReq.Window - defer func() { - // when reqs is closed - close(sess.winch) - }() - req.Reply(ok, nil) - case "x11-req": - if sess.handled || sess.x11 != nil { - req.Reply(false, nil) - continue - } - x11Req, ok := parseX11Request(req.Payload) - if !ok { + if sess.ptyCb != nil { + ok := sess.ptyCb(sess.ctx, ptyReq) + if !ok { + req.Reply(false, nil) + continue + } + } + sess.pty = &ptyReq + sess.winch = make(chan Window, 1) + sess.winch <- ptyReq.Window + defer func() { + // when reqs is closed + close(sess.winch) + }() + req.Reply(ok, nil) + case "x11-req": + if sess.handled || sess.x11 != nil { + req.Reply(false, nil) + continue + } + x11Req, ok := parseX11Request(req.Payload) + if !ok { + req.Reply(false, nil) + continue + } + sess.x11 = &x11Req + if sess.x11Cb != nil { + ok := sess.x11Cb(sess.ctx, x11Req) + req.Reply(ok, nil) + continue + } req.Reply(false, nil) - continue - } - sess.x11 = &x11Req - if sess.x11Cb != nil { - ok := sess.x11Cb(sess.ctx, x11Req) + case "window-change": + if sess.pty == nil { + req.Reply(false, nil) + continue + } + win, _, ok := parseWindow(req.Payload) + if ok { + sess.pty.Window = win + sess.winch <- win + } req.Reply(ok, nil) - continue - } - req.Reply(false, nil) - case "window-change": - if sess.pty == nil { + case agentRequestType: + // TODO: option/callback to allow agent forwarding + SetAgentRequested(sess.ctx) + req.Reply(true, nil) + case keepAliveRequestType: + if req.WantReply { + req.Reply(true, nil) + } + case "break": + ok := false + sess.Lock() + if sess.breakCh != nil { + sess.breakCh <- true + ok = true + } + req.Reply(ok, nil) + sess.Unlock() + default: + // TODO: debug log req.Reply(false, nil) - continue - } - win, _, ok := parseWindow(req.Payload) - if ok { - sess.pty.Window = win - sess.winch <- win } - req.Reply(ok, nil) - case agentRequestType: - // TODO: option/callback to allow agent forwarding - SetAgentRequested(sess.ctx) - req.Reply(true, nil) - case "break": - ok := false - sess.Lock() - if sess.breakCh != nil { - sess.breakCh <- true - ok = true - } - req.Reply(ok, nil) - sess.Unlock() - default: - // TODO: debug log - req.Reply(false, nil) } + + } +} + +// KeepAliveRequestHandler replies to periodic client keep-alive requests: +// client: send packet: type 80 (SSH_MSG_GLOBAL_REQUEST) +// client: receive packet: type 82 (SSH_MSG_REQUEST_SUCCESS) +func KeepAliveRequestHandler(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) { + keepAlive := ctx.KeepAlive() + if keepAlive != nil { + ctx.KeepAlive().RequestHandlerCallback() } + return false, nil } diff --git a/session_test.go b/session_test.go index c30c458..0db4702 100644 --- a/session_test.go +++ b/session_test.go @@ -5,8 +5,12 @@ import ( "fmt" "io" "net" + "sync" "testing" + "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" gossh "golang.org/x/crypto/ssh" ) @@ -269,6 +273,8 @@ func TestX11(t *testing.T) { } func TestPtyResize(t *testing.T) { + t.Skip("it hangs") + t.Parallel() winch0 := Window{40, 80, 0, 0} winch1 := Window{80, 160, 0, 0} @@ -476,3 +482,109 @@ func TestBreakWithoutChanRegistered(t *testing.T) { t.Fatalf("expected nil but got %v", err) } } + +func TestSessionKeepAlive(t *testing.T) { + t.Parallel() + + t.Run("Server replies to keep-alive request", func(t *testing.T) { + t.Parallel() + + doneCh := make(chan struct{}) + defer close(doneCh) + + var sshSession *session + srv := &Server{ + ClientAliveInterval: 100 * time.Millisecond, + ClientAliveCountMax: 2, + Handler: func(s Session) { + <-doneCh + }, + SessionRequestCallback: func(sess Session, requestType string) bool { + sshSession = sess.(*session) + return true + }, + } + session, client, cleanup := newTestSession(t, srv, nil) + defer cleanup() + + errChan := make(chan error, 1) + go func() { + errChan <- session.Run("") + }() + + for i := 0; i < 100; i++ { + ok, reply, err := client.SendRequest(keepAliveRequestType, true, nil) + require.NoError(t, err) + require.False(t, ok) // server replied + require.Empty(t, reply) + + time.Sleep(10 * time.Millisecond) + } + doneCh <- struct{}{} + + err := <-errChan + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + + // Verify that... + require.Equal(t, 100, sshSession.ctx.KeepAlive().Metrics().RequestHandlerCalled) // client sent keep-alive requests, + require.Equal(t, 100, sshSession.ctx.KeepAlive().Metrics().KeepAliveReplyReceived) // and server replied to all of them, + require.Zero(t, sshSession.ctx.KeepAlive().Metrics().ServerRequestedKeepAlive) // and server didn't send any extra requests. + }) + + t.Run("Server requests keep-alive reply", func(t *testing.T) { + t.Parallel() + + doneCh := make(chan struct{}) + defer close(doneCh) + + var sshSession *session + var m sync.Mutex + srv := &Server{ + ClientAliveInterval: 100 * time.Millisecond, + ClientAliveCountMax: 2, + Handler: func(s Session) { + <-doneCh + }, + SessionRequestCallback: func(sess Session, requestType string) bool { + m.Lock() + defer m.Unlock() + + sshSession = sess.(*session) + return true + }, + } + session, _, cleanup := newTestSession(t, srv, nil) + defer cleanup() + + errChan := make(chan error, 1) + go func() { + errChan <- session.Run("") + }() + + // Wait for client to reply to at least 10 keep-alive requests. + assert.Eventually(t, func() bool { + m.Lock() + defer m.Unlock() + + return sshSession != nil && sshSession.ctx.KeepAlive().Metrics().KeepAliveReplyReceived >= 10 + }, time.Second*3, time.Millisecond) + require.GreaterOrEqual(t, 10, sshSession.ctx.KeepAlive().Metrics().KeepAliveReplyReceived) + + doneCh <- struct{}{} + err := <-errChan + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + + // Verify that... + require.Zero(t, sshSession.ctx.KeepAlive().Metrics().RequestHandlerCalled) // client didn't send any keep-alive requests, + require.GreaterOrEqual(t, 10, sshSession.ctx.KeepAlive().Metrics().ServerRequestedKeepAlive) // server requested keep-alive replies + }) + + t.Run("Server terminates connection due to no keep-alive replies", func(t *testing.T) { + t.Parallel() + t.Skip("Go SSH client doesn't support disabling replies to keep-alive requests. We can't test it easily without mocking logic.") + }) +}