Skip to content

feat: Support keep-alive messages #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions _examples/ssh-keepalive/keepalive.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package main

import (
"log"
"time"

"github.com/gliderlabs/ssh"
)

var (
keepAliveInterval = 3 * time.Second
keepAliveCountMax = 3
)

func main() {
ssh.Handle(func(s ssh.Session) {
log.Println("new connection")
i := 0
for {
i += 1
log.Println("active seconds:", i)
select {
case <-time.After(time.Second):
continue
case <-s.Context().Done():
log.Println("connection closed")
return
}
}
})

log.Println("starting ssh server on port 2222...")
log.Printf("keep-alive mode is on: %s\n", keepAliveInterval)
server := &ssh.Server{
Addr: ":2222",
ClientAliveInterval: keepAliveInterval,
ClientAliveCountMax: keepAliveCountMax,
}
log.Fatal(server.ListenAndServe())
}
15 changes: 15 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/hex"
"net"
"sync"
"time"

gossh "golang.org/x/crypto/ssh"
)
Expand Down Expand Up @@ -55,6 +56,8 @@ var (
// ContextKeyPublicKey is a context key for use with Contexts in this package.
// The associated value will be of type PublicKey.
ContextKeyPublicKey = &contextKey{"public-key"}

ContextKeyKeepAlive = &contextKey{"keep-alive"}
)

// Context is a package specific context interface. It exposes connection
Expand Down Expand Up @@ -87,6 +90,9 @@ type Context interface {
// Permissions returns the Permissions object used for this connection.
Permissions() *Permissions

// KeepAlive returns the SessionKeepAlive object used for checking the status of a user connection.
KeepAlive() *SessionKeepAlive

// SetValue allows you to easily write new values into the underlying context.
SetValue(key, value interface{})
}
Expand Down Expand Up @@ -119,6 +125,11 @@ func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) {
ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr())
}

func applyKeepAlive(ctx Context, clientAliveInterval time.Duration, clientAliveCountMax int) {
keepAlive := NewSessionKeepAlive(clientAliveInterval, clientAliveCountMax)
ctx.SetValue(ContextKeyKeepAlive, keepAlive)
}

func (ctx *sshContext) SetValue(key, value interface{}) {
ctx.Context = context.WithValue(ctx.Context, key, value)
}
Expand Down Expand Up @@ -153,3 +164,7 @@ func (ctx *sshContext) LocalAddr() net.Addr {
func (ctx *sshContext) Permissions() *Permissions {
return ctx.Value(ContextKeyPermissions).(*Permissions)
}

func (ctx *sshContext) KeepAlive() *SessionKeepAlive {
return ctx.Value(ContextKeyKeepAlive).(*SessionKeepAlive)
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.12

require (
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
github.com/stretchr/testify v1.8.4
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 // indirect
)
17 changes: 17 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e h1:gsTQYXdTw2Gq7RBsWvlQ91b+aEQ6bXFUngBGuR8sPpI=
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
Expand All @@ -11,3 +23,8 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
99 changes: 99 additions & 0 deletions keepalive.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package ssh

import (
"sync"
"time"
)

type SessionKeepAlive struct {
clientAliveInterval time.Duration
clientAliveCountMax int

ticker *time.Ticker
tickerCh <-chan time.Time
lastReceived time.Time

metrics KeepAliveMetrics

m sync.Mutex
closed bool
}

func NewSessionKeepAlive(clientAliveInterval time.Duration, clientAliveCountMax int) *SessionKeepAlive {
var t *time.Ticker
var tickerCh <-chan time.Time
if clientAliveInterval > 0 {
t = time.NewTicker(clientAliveInterval)
tickerCh = t.C
}

return &SessionKeepAlive{
clientAliveInterval: clientAliveInterval,
clientAliveCountMax: clientAliveCountMax,
ticker: t,
tickerCh: tickerCh,
lastReceived: time.Now(),
}
}

func (ska *SessionKeepAlive) RequestHandlerCallback() {
ska.m.Lock()
ska.metrics.RequestHandlerCalled++
ska.m.Unlock()

ska.Reset()
}

func (ska *SessionKeepAlive) ServerRequestedKeepAliveCallback() {
ska.m.Lock()
defer ska.m.Unlock()

ska.metrics.ServerRequestedKeepAlive++
}

func (ska *SessionKeepAlive) Reset() {
ska.m.Lock()
defer ska.m.Unlock()

ska.metrics.KeepAliveReplyReceived++

if ska.ticker != nil && !ska.closed {
ska.lastReceived = time.Now()
ska.ticker.Reset(ska.clientAliveInterval)
}
}

func (ska *SessionKeepAlive) Ticks() <-chan time.Time {
return ska.tickerCh
}

func (ska *SessionKeepAlive) TimeIsUp() bool {
ska.m.Lock()
defer ska.m.Unlock()

// true: Keep-alive reply not received
return ska.lastReceived.Add(time.Duration(ska.clientAliveCountMax) * ska.clientAliveInterval).Before(time.Now())
}

func (ska *SessionKeepAlive) Close() {
ska.m.Lock()
defer ska.m.Unlock()

if ska.ticker != nil {
ska.ticker.Stop()
}
ska.closed = true
}

func (ska *SessionKeepAlive) Metrics() KeepAliveMetrics {
ska.m.Lock()
defer ska.m.Unlock()

return ska.metrics
}

type KeepAliveMetrics struct {
RequestHandlerCalled int
KeepAliveReplyReceived int
ServerRequestedKeepAlive int
}
13 changes: 12 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ var DefaultSubsystemHandlers = map[string]SubsystemHandler{}

type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)

var DefaultRequestHandlers = map[string]RequestHandler{}
var DefaultRequestHandlers = map[string]RequestHandler{
keepAliveRequestType: KeepAliveRequestHandler,
}

type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)

Expand Down Expand Up @@ -68,6 +70,9 @@ type Server struct {
// handlers, but handle named subsystems.
SubsystemHandlers map[string]SubsystemHandler

ClientAliveInterval time.Duration
ClientAliveCountMax int

listenerWg sync.WaitGroup
mu sync.RWMutex
listeners map[net.Listener]struct{}
Expand Down Expand Up @@ -222,6 +227,10 @@ func (srv *Server) Shutdown(ctx context.Context) error {
//
// Serve always returns a non-nil error.
func (srv *Server) Serve(l net.Listener) error {
if (srv.ClientAliveInterval != 0 && srv.ClientAliveCountMax == 0) || (srv.ClientAliveInterval == 0 && srv.ClientAliveCountMax != 0) {
return fmt.Errorf("ClientAliveInterval and ClientAliveCountMax must be set together")
}

srv.ensureHandlers()
defer l.Close()
if err := srv.ensureHostSigner(); err != nil {
Expand Down Expand Up @@ -292,6 +301,8 @@ func (srv *Server) HandleConn(newConn net.Conn) {

ctx.SetValue(ContextKeyConn, sshConn)
applyConnMetadata(ctx, sshConn)
// To prevent race conditions, we need to configure the keep-alive before goroutines kick off
applyKeepAlive(ctx, srv.ClientAliveInterval, srv.ClientAliveCountMax)
//go gossh.DiscardRequests(reqs)
go srv.handleRequests(ctx, reqs)
for ch := range chans {
Expand Down
Loading