Skip to content

Commit fc6e4b0

Browse files
authored
feat: Support keep-alive messages (#3)
1 parent 04bb837 commit fc6e4b0

File tree

8 files changed

+490
-132
lines changed

8 files changed

+490
-132
lines changed

_examples/ssh-keepalive/keepalive.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package main
2+
3+
import (
4+
"log"
5+
"time"
6+
7+
"github.com/gliderlabs/ssh"
8+
)
9+
10+
var (
11+
keepAliveInterval = 3 * time.Second
12+
keepAliveCountMax = 3
13+
)
14+
15+
func main() {
16+
ssh.Handle(func(s ssh.Session) {
17+
log.Println("new connection")
18+
i := 0
19+
for {
20+
i += 1
21+
log.Println("active seconds:", i)
22+
select {
23+
case <-time.After(time.Second):
24+
continue
25+
case <-s.Context().Done():
26+
log.Println("connection closed")
27+
return
28+
}
29+
}
30+
})
31+
32+
log.Println("starting ssh server on port 2222...")
33+
log.Printf("keep-alive mode is on: %s\n", keepAliveInterval)
34+
server := &ssh.Server{
35+
Addr: ":2222",
36+
ClientAliveInterval: keepAliveInterval,
37+
ClientAliveCountMax: keepAliveCountMax,
38+
}
39+
log.Fatal(server.ListenAndServe())
40+
}

context.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/hex"
66
"net"
77
"sync"
8+
"time"
89

910
gossh "golang.org/x/crypto/ssh"
1011
)
@@ -55,6 +56,8 @@ var (
5556
// ContextKeyPublicKey is a context key for use with Contexts in this package.
5657
// The associated value will be of type PublicKey.
5758
ContextKeyPublicKey = &contextKey{"public-key"}
59+
60+
ContextKeyKeepAlive = &contextKey{"keep-alive"}
5861
)
5962

6063
// Context is a package specific context interface. It exposes connection
@@ -87,6 +90,9 @@ type Context interface {
8790
// Permissions returns the Permissions object used for this connection.
8891
Permissions() *Permissions
8992

93+
// KeepAlive returns the SessionKeepAlive object used for checking the status of a user connection.
94+
KeepAlive() *SessionKeepAlive
95+
9096
// SetValue allows you to easily write new values into the underlying context.
9197
SetValue(key, value interface{})
9298
}
@@ -119,6 +125,11 @@ func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) {
119125
ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr())
120126
}
121127

128+
func applyKeepAlive(ctx Context, clientAliveInterval time.Duration, clientAliveCountMax int) {
129+
keepAlive := NewSessionKeepAlive(clientAliveInterval, clientAliveCountMax)
130+
ctx.SetValue(ContextKeyKeepAlive, keepAlive)
131+
}
132+
122133
func (ctx *sshContext) SetValue(key, value interface{}) {
123134
ctx.Context = context.WithValue(ctx.Context, key, value)
124135
}
@@ -153,3 +164,7 @@ func (ctx *sshContext) LocalAddr() net.Addr {
153164
func (ctx *sshContext) Permissions() *Permissions {
154165
return ctx.Value(ContextKeyPermissions).(*Permissions)
155166
}
167+
168+
func (ctx *sshContext) KeepAlive() *SessionKeepAlive {
169+
return ctx.Value(ContextKeyKeepAlive).(*SessionKeepAlive)
170+
}

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ go 1.12
44

55
require (
66
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
7+
github.com/stretchr/testify v1.8.4
78
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e
89
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 // indirect
910
)

go.sum

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
11
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
22
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
3+
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
4+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
5+
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
6+
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
7+
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
8+
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
9+
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
10+
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
11+
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
12+
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
13+
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
14+
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
315
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e h1:gsTQYXdTw2Gq7RBsWvlQ91b+aEQ6bXFUngBGuR8sPpI=
416
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
517
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@@ -11,3 +23,8 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4
1123
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
1224
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
1325
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
26+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
27+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
28+
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
29+
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
30+
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

keepalive.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package ssh
2+
3+
import (
4+
"sync"
5+
"time"
6+
)
7+
8+
type SessionKeepAlive struct {
9+
clientAliveInterval time.Duration
10+
clientAliveCountMax int
11+
12+
ticker *time.Ticker
13+
tickerCh <-chan time.Time
14+
lastReceived time.Time
15+
16+
metrics KeepAliveMetrics
17+
18+
m sync.Mutex
19+
closed bool
20+
}
21+
22+
func NewSessionKeepAlive(clientAliveInterval time.Duration, clientAliveCountMax int) *SessionKeepAlive {
23+
var t *time.Ticker
24+
var tickerCh <-chan time.Time
25+
if clientAliveInterval > 0 {
26+
t = time.NewTicker(clientAliveInterval)
27+
tickerCh = t.C
28+
}
29+
30+
return &SessionKeepAlive{
31+
clientAliveInterval: clientAliveInterval,
32+
clientAliveCountMax: clientAliveCountMax,
33+
ticker: t,
34+
tickerCh: tickerCh,
35+
lastReceived: time.Now(),
36+
}
37+
}
38+
39+
func (ska *SessionKeepAlive) RequestHandlerCallback() {
40+
ska.m.Lock()
41+
ska.metrics.RequestHandlerCalled++
42+
ska.m.Unlock()
43+
44+
ska.Reset()
45+
}
46+
47+
func (ska *SessionKeepAlive) ServerRequestedKeepAliveCallback() {
48+
ska.m.Lock()
49+
defer ska.m.Unlock()
50+
51+
ska.metrics.ServerRequestedKeepAlive++
52+
}
53+
54+
func (ska *SessionKeepAlive) Reset() {
55+
ska.m.Lock()
56+
defer ska.m.Unlock()
57+
58+
ska.metrics.KeepAliveReplyReceived++
59+
60+
if ska.ticker != nil && !ska.closed {
61+
ska.lastReceived = time.Now()
62+
ska.ticker.Reset(ska.clientAliveInterval)
63+
}
64+
}
65+
66+
func (ska *SessionKeepAlive) Ticks() <-chan time.Time {
67+
return ska.tickerCh
68+
}
69+
70+
func (ska *SessionKeepAlive) TimeIsUp() bool {
71+
ska.m.Lock()
72+
defer ska.m.Unlock()
73+
74+
// true: Keep-alive reply not received
75+
return ska.lastReceived.Add(time.Duration(ska.clientAliveCountMax) * ska.clientAliveInterval).Before(time.Now())
76+
}
77+
78+
func (ska *SessionKeepAlive) Close() {
79+
ska.m.Lock()
80+
defer ska.m.Unlock()
81+
82+
if ska.ticker != nil {
83+
ska.ticker.Stop()
84+
}
85+
ska.closed = true
86+
}
87+
88+
func (ska *SessionKeepAlive) Metrics() KeepAliveMetrics {
89+
ska.m.Lock()
90+
defer ska.m.Unlock()
91+
92+
return ska.metrics
93+
}
94+
95+
type KeepAliveMetrics struct {
96+
RequestHandlerCalled int
97+
KeepAliveReplyReceived int
98+
ServerRequestedKeepAlive int
99+
}

server.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ var DefaultSubsystemHandlers = map[string]SubsystemHandler{}
2121

2222
type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
2323

24-
var DefaultRequestHandlers = map[string]RequestHandler{}
24+
var DefaultRequestHandlers = map[string]RequestHandler{
25+
keepAliveRequestType: KeepAliveRequestHandler,
26+
}
2527

2628
type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
2729

@@ -68,6 +70,9 @@ type Server struct {
6870
// handlers, but handle named subsystems.
6971
SubsystemHandlers map[string]SubsystemHandler
7072

73+
ClientAliveInterval time.Duration
74+
ClientAliveCountMax int
75+
7176
listenerWg sync.WaitGroup
7277
mu sync.RWMutex
7378
listeners map[net.Listener]struct{}
@@ -222,6 +227,10 @@ func (srv *Server) Shutdown(ctx context.Context) error {
222227
//
223228
// Serve always returns a non-nil error.
224229
func (srv *Server) Serve(l net.Listener) error {
230+
if (srv.ClientAliveInterval != 0 && srv.ClientAliveCountMax == 0) || (srv.ClientAliveInterval == 0 && srv.ClientAliveCountMax != 0) {
231+
return fmt.Errorf("ClientAliveInterval and ClientAliveCountMax must be set together")
232+
}
233+
225234
srv.ensureHandlers()
226235
defer l.Close()
227236
if err := srv.ensureHostSigner(); err != nil {
@@ -292,6 +301,8 @@ func (srv *Server) HandleConn(newConn net.Conn) {
292301

293302
ctx.SetValue(ContextKeyConn, sshConn)
294303
applyConnMetadata(ctx, sshConn)
304+
// To prevent race conditions, we need to configure the keep-alive before goroutines kick off
305+
applyKeepAlive(ctx, srv.ClientAliveInterval, srv.ClientAliveCountMax)
295306
//go gossh.DiscardRequests(reqs)
296307
go srv.handleRequests(ctx, reqs)
297308
for ch := range chans {

0 commit comments

Comments
 (0)