Skip to content

Commit fb34512

Browse files
meislerjJacob Meisler
and
Jacob Meisler
authored
Register chan to Session to listen for break requests (gliderlabs#141)
Co-authored-by: Jacob Meisler <meislerj@amazon.com>
1 parent 76cadaa commit fb34512

File tree

2 files changed

+115
-0
lines changed

2 files changed

+115
-0
lines changed

session.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ type Session interface {
7777
// If there are buffered signals when a channel is registered, they will be
7878
// sent in order on the channel immediately after registering.
7979
Signals(c chan<- Signal)
80+
81+
// Break regisers a channel to receive notifications of break requests sent
82+
// from the client. The channel must handle break requests, or it will block
83+
// the request handling loop. Registering nil will unregister the channel.
84+
// During the time that no channel is registered, breaks are ignored.
85+
Break(c chan<- bool)
8086
}
8187

8288
// maxSigBufSize is how many signals will be buffered
@@ -119,6 +125,7 @@ type session struct {
119125
ctx Context
120126
sigCh chan<- Signal
121127
sigBuf []Signal
128+
breakCh chan<- bool
122129
}
123130

124131
func (sess *session) Write(p []byte) (n int, err error) {
@@ -221,6 +228,12 @@ func (sess *session) Signals(c chan<- Signal) {
221228
}
222229
}
223230

231+
func (sess *session) Break(c chan<- bool) {
232+
sess.Lock()
233+
defer sess.Unlock()
234+
sess.breakCh = c
235+
}
236+
224237
func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
225238
for req := range reqs {
226239
switch req.Type {
@@ -344,6 +357,15 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
344357
// TODO: option/callback to allow agent forwarding
345358
SetAgentRequested(sess.ctx)
346359
req.Reply(true, nil)
360+
case "break":
361+
ok := false
362+
sess.Lock()
363+
if sess.breakCh != nil {
364+
sess.breakCh <- true
365+
ok = true
366+
}
367+
req.Reply(ok, nil)
368+
sess.Unlock()
347369
default:
348370
// TODO: debug log
349371
req.Reply(false, nil)

session_test.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,3 +343,96 @@ func TestSignals(t *testing.T) {
343343
t.Fatalf("expected nil but got %v", err)
344344
}
345345
}
346+
347+
func TestBreakWithChanRegistered(t *testing.T) {
348+
t.Parallel()
349+
350+
// errChan lets us get errors back from the session
351+
errChan := make(chan error, 5)
352+
353+
// doneChan lets us specify that we should exit.
354+
doneChan := make(chan interface{})
355+
356+
breakChan := make(chan bool)
357+
358+
readyToReceiveBreak := make(chan bool)
359+
360+
session, _, cleanup := newTestSession(t, &Server{
361+
Handler: func(s Session) {
362+
s.Break(breakChan) // register a break channel with the session
363+
readyToReceiveBreak <- true
364+
365+
select {
366+
case <-breakChan:
367+
io.WriteString(s, "break")
368+
case <-doneChan:
369+
errChan <- fmt.Errorf("Unexpected done")
370+
return
371+
}
372+
},
373+
}, nil)
374+
defer cleanup()
375+
var stdout bytes.Buffer
376+
session.Stdout = &stdout
377+
go func() {
378+
errChan <- session.Run("")
379+
}()
380+
381+
<-readyToReceiveBreak
382+
ok, err := session.SendRequest("break", true, nil)
383+
if err != nil {
384+
t.Fatalf("expected nil but got %v", err)
385+
}
386+
if ok != true {
387+
t.Fatalf("expected true but got %v", ok)
388+
}
389+
390+
err = <-errChan
391+
close(doneChan)
392+
393+
if err != nil {
394+
t.Fatalf("expected nil but got %v", err)
395+
}
396+
if !bytes.Equal(stdout.Bytes(), []byte("break")) {
397+
t.Fatalf("stdout = %#v, expected 'break'", stdout.Bytes())
398+
}
399+
}
400+
401+
func TestBreakWithoutChanRegistered(t *testing.T) {
402+
t.Parallel()
403+
404+
// errChan lets us get errors back from the session
405+
errChan := make(chan error, 5)
406+
407+
// doneChan lets us specify that we should exit.
408+
doneChan := make(chan interface{})
409+
410+
waitUntilAfterBreakSent := make(chan bool)
411+
412+
session, _, cleanup := newTestSession(t, &Server{
413+
Handler: func(s Session) {
414+
<-waitUntilAfterBreakSent
415+
},
416+
}, nil)
417+
defer cleanup()
418+
var stdout bytes.Buffer
419+
session.Stdout = &stdout
420+
go func() {
421+
errChan <- session.Run("")
422+
}()
423+
424+
ok, err := session.SendRequest("break", true, nil)
425+
if err != nil {
426+
t.Fatalf("expected nil but got %v", err)
427+
}
428+
if ok != false {
429+
t.Fatalf("expected false but got %v", ok)
430+
}
431+
waitUntilAfterBreakSent <- true
432+
433+
err = <-errChan
434+
close(doneChan)
435+
if err != nil {
436+
t.Fatalf("expected nil but got %v", err)
437+
}
438+
}

0 commit comments

Comments
 (0)