Skip to content

Commit 70855de

Browse files
authored
fix: avoid data race in session signal channel register/deregister (#5)
1 parent 9a7e234 commit 70855de

File tree

2 files changed

+148
-25
lines changed

2 files changed

+148
-25
lines changed

session.go

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -127,21 +127,25 @@ func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.Ne
127127
type session struct {
128128
sync.Mutex
129129
gossh.Channel
130-
conn *gossh.ServerConn
131-
handler Handler
132-
subsystemHandlers map[string]SubsystemHandler
133-
handled bool
134-
exited bool
135-
pty *Pty
136-
x11 *X11
137-
winch chan Window
138-
env []string
139-
ptyCb PtyCallback
140-
x11Cb X11Callback
141-
sessReqCb SessionRequestCallback
142-
rawCmd string
143-
subsystem string
144-
ctx Context
130+
conn *gossh.ServerConn
131+
handler Handler
132+
subsystemHandlers map[string]SubsystemHandler
133+
handled bool
134+
exited bool
135+
pty *Pty
136+
x11 *X11
137+
winch chan Window
138+
env []string
139+
ptyCb PtyCallback
140+
x11Cb X11Callback
141+
sessReqCb SessionRequestCallback
142+
rawCmd string
143+
subsystem string
144+
ctx Context
145+
// sigMu protects sigCh and sigBuf, it is made separate from the
146+
// session mutex to reduce the risk of deadlocks while we process
147+
// buffered signals.
148+
sigMu sync.Mutex
145149
sigCh chan<- Signal
146150
sigBuf []Signal
147151
breakCh chan<- bool
@@ -247,16 +251,30 @@ func (sess *session) X11() (X11, bool) {
247251
}
248252

249253
func (sess *session) Signals(c chan<- Signal) {
250-
sess.Lock()
251-
defer sess.Unlock()
254+
sess.sigMu.Lock()
252255
sess.sigCh = c
253-
if len(sess.sigBuf) > 0 {
254-
go func() {
255-
for _, sig := range sess.sigBuf {
256-
sess.sigCh <- sig
257-
}
258-
}()
256+
if len(sess.sigBuf) == 0 || sess.sigCh == nil {
257+
sess.sigMu.Unlock()
258+
return
259259
}
260+
// If we have buffered signals, we need to send them whilst
261+
// holding the signal mutex to avoid race conditions on sigCh
262+
// and sigBuf. We also guarantee that calling Signals(ch)
263+
// followed by Signals(nil) will have depleted the sigBuf when
264+
// the second call returns and that there will be no more
265+
// signals on ch. This is done in a goroutine so we can return
266+
// early and allow the caller to set up processing for the
267+
// channel even after calling Signals(ch).
268+
go func() {
269+
// Here we're relying on the mutex being locked in the outer
270+
// Signals() function, so we simply unlock it when we're done.
271+
defer sess.sigMu.Unlock()
272+
273+
for _, sig := range sess.sigBuf {
274+
sess.sigCh <- sig
275+
}
276+
sess.sigBuf = nil
277+
}()
260278
}
261279

262280
func (sess *session) Break(c chan<- bool) {
@@ -379,15 +397,15 @@ func (sess *session) handleRequests(ctx Context, reqs <-chan *gossh.Request) {
379397
case "signal":
380398
var payload struct{ Signal string }
381399
gossh.Unmarshal(req.Payload, &payload)
382-
sess.Lock()
400+
sess.sigMu.Lock()
383401
if sess.sigCh != nil {
384402
sess.sigCh <- Signal(payload.Signal)
385403
} else {
386404
if len(sess.sigBuf) < maxSigBufSize {
387405
sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal))
388406
}
389407
}
390-
sess.Unlock()
408+
sess.sigMu.Unlock()
391409
case "pty-req":
392410
if sess.handled || sess.pty != nil {
393411
req.Reply(false, nil)

session_test.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,111 @@ func TestSignals(t *testing.T) {
390390
}
391391
}
392392

393+
func TestSignalsRaceDeregisterAndReregister(t *testing.T) {
394+
t.Parallel()
395+
396+
numSignals := 128
397+
398+
// errChan lets us get errors back from the session
399+
errChan := make(chan error, 5)
400+
401+
// doneChan lets us specify that we should exit.
402+
doneChan := make(chan interface{})
403+
404+
// Channels to synchronize the handler and the test.
405+
handlerPreRegister := make(chan struct{})
406+
handlerPostRegister := make(chan struct{})
407+
signalInit := make(chan struct{})
408+
409+
session, _, cleanup := newTestSession(t, &Server{
410+
Handler: func(s Session) {
411+
// Single buffer slot, this is to make sure we don't miss
412+
// signals or send on nil a channel.
413+
signals := make(chan Signal, 1)
414+
415+
<-handlerPreRegister // Wait for initial signal buffering.
416+
417+
// Register signals.
418+
s.Signals(signals)
419+
close(handlerPostRegister) // Trigger post register signaling.
420+
421+
// Process signals so that we can don't see a deadlock.
422+
discarded := 0
423+
discardDone := make(chan struct{})
424+
go func() {
425+
defer close(discardDone)
426+
for range signals {
427+
discarded++
428+
}
429+
}()
430+
// Deregister signals.
431+
s.Signals(nil)
432+
// Close channel to close goroutine and ensure we don't send
433+
// on a closed channel.
434+
close(signals)
435+
<-discardDone
436+
437+
signals = make(chan Signal, 1)
438+
consumeDone := make(chan struct{})
439+
go func() {
440+
defer close(consumeDone)
441+
442+
for i := 0; i < numSignals-discarded; i++ {
443+
select {
444+
case sig := <-signals:
445+
if sig != SIGHUP {
446+
errChan <- fmt.Errorf("expected signal %v but got %v", SIGHUP, sig)
447+
return
448+
}
449+
case <-doneChan:
450+
errChan <- fmt.Errorf("Unexpected done")
451+
return
452+
}
453+
}
454+
}()
455+
456+
// Re-register signals and make sure we don't miss any.
457+
s.Signals(signals)
458+
close(signalInit)
459+
460+
<-consumeDone
461+
},
462+
}, nil)
463+
defer cleanup()
464+
465+
go func() {
466+
// Send 1/4th directly to buffer.
467+
for i := 0; i < numSignals/4; i++ {
468+
session.Signal(gossh.SIGHUP)
469+
}
470+
close(handlerPreRegister)
471+
<-handlerPostRegister
472+
// Send 1/4th to channel or buffer.
473+
for i := 0; i < numSignals/4; i++ {
474+
session.Signal(gossh.SIGHUP)
475+
}
476+
// Send final 1/2 to channel.
477+
<-signalInit
478+
for i := 0; i < numSignals/2; i++ {
479+
session.Signal(gossh.SIGHUP)
480+
}
481+
}()
482+
483+
go func() {
484+
errChan <- session.Run("")
485+
}()
486+
487+
select {
488+
case err := <-errChan:
489+
close(doneChan)
490+
if err != nil {
491+
t.Fatalf("expected nil but got %v", err)
492+
}
493+
case <-time.After(5 * time.Second):
494+
t.Fatalf("timed out waiting for session to exit")
495+
}
496+
}
497+
393498
func TestBreakWithChanRegistered(t *testing.T) {
394499
t.Parallel()
395500

0 commit comments

Comments
 (0)