Skip to content

Commit f6b35a4

Browse files
committed
fix: avoid data race in session signal channel register/deregister
1 parent 9a7e234 commit f6b35a4

File tree

2 files changed

+119
-10
lines changed

2 files changed

+119
-10
lines changed

session.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ type session struct {
142142
rawCmd string
143143
subsystem string
144144
ctx Context
145+
sigMu sync.Mutex
145146
sigCh chan<- Signal
146147
sigBuf []Signal
147148
breakCh chan<- bool
@@ -247,16 +248,19 @@ func (sess *session) X11() (X11, bool) {
247248
}
248249

249250
func (sess *session) Signals(c chan<- Signal) {
250-
sess.Lock()
251-
defer sess.Unlock()
251+
sess.sigMu.Lock()
252252
sess.sigCh = c
253-
if len(sess.sigBuf) > 0 {
254-
go func() {
255-
for _, sig := range sess.sigBuf {
256-
sess.sigCh <- sig
257-
}
258-
}()
253+
if len(sess.sigBuf) == 0 || sess.sigCh == nil {
254+
sess.sigMu.Unlock()
255+
return
259256
}
257+
go func() {
258+
defer sess.sigMu.Unlock()
259+
for _, sig := range sess.sigBuf {
260+
c <- sig
261+
}
262+
sess.sigBuf = nil
263+
}()
260264
}
261265

262266
func (sess *session) Break(c chan<- bool) {
@@ -379,15 +383,15 @@ func (sess *session) handleRequests(ctx Context, reqs <-chan *gossh.Request) {
379383
case "signal":
380384
var payload struct{ Signal string }
381385
gossh.Unmarshal(req.Payload, &payload)
382-
sess.Lock()
386+
sess.sigMu.Lock()
383387
if sess.sigCh != nil {
384388
sess.sigCh <- Signal(payload.Signal)
385389
} else {
386390
if len(sess.sigBuf) < maxSigBufSize {
387391
sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal))
388392
}
389393
}
390-
sess.Unlock()
394+
sess.sigMu.Unlock()
391395
case "pty-req":
392396
if sess.handled || sess.pty != nil {
393397
req.Reply(false, nil)

session_test.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,111 @@ func TestSignals(t *testing.T) {
390390
}
391391
}
392392

393+
func TestSignalsRaceDeregisterAndReregister(t *testing.T) {
394+
t.Parallel()
395+
396+
numSignals := 128
397+
398+
// errChan lets us get errors back from the session
399+
errChan := make(chan error, 5)
400+
401+
// doneChan lets us specify that we should exit.
402+
doneChan := make(chan interface{})
403+
404+
// Channels to synchronize the handler and the test.
405+
handlerPreRegister := make(chan struct{})
406+
handlerPostRegister := make(chan struct{})
407+
signalInit := make(chan struct{})
408+
409+
session, _, cleanup := newTestSession(t, &Server{
410+
Handler: func(s Session) {
411+
// Single buffer slot, this is to make sure we don't miss
412+
// signals or send on nil a channel.
413+
signals := make(chan Signal, 1)
414+
415+
<-handlerPreRegister // Wait for initial signal buffering.
416+
417+
// Register signals.
418+
s.Signals(signals)
419+
close(handlerPostRegister) // Trigger post register signaling.
420+
421+
// Process signals so that we can don't see a deadlock.
422+
discarded := 0
423+
discardDone := make(chan struct{})
424+
go func() {
425+
defer close(discardDone)
426+
for range signals {
427+
discarded++
428+
}
429+
}()
430+
// Deregister signals.
431+
s.Signals(nil)
432+
// Close channel to close goroutine and ensure we don't send
433+
// on a closed channel.
434+
close(signals)
435+
<-discardDone
436+
437+
signals = make(chan Signal, 1)
438+
consumeDone := make(chan struct{})
439+
go func() {
440+
defer close(consumeDone)
441+
442+
for i := 0; i < numSignals-discarded; i++ {
443+
select {
444+
case sig := <-signals:
445+
if sig != SIGHUP {
446+
errChan <- fmt.Errorf("expected signal %v but got %v", SIGHUP, sig)
447+
return
448+
}
449+
case <-doneChan:
450+
errChan <- fmt.Errorf("Unexpected done")
451+
return
452+
}
453+
}
454+
}()
455+
456+
// Re-register signals and make sure we don't miss any.
457+
s.Signals(signals)
458+
close(signalInit)
459+
460+
<-consumeDone
461+
},
462+
}, nil)
463+
defer cleanup()
464+
465+
go func() {
466+
// Send 1/4th directly to buffer.
467+
for i := 0; i < numSignals/4; i++ {
468+
session.Signal(gossh.SIGHUP)
469+
}
470+
close(handlerPreRegister)
471+
<-handlerPostRegister
472+
// Send 1/4th to channel or buffer.
473+
for i := 0; i < numSignals/4; i++ {
474+
session.Signal(gossh.SIGHUP)
475+
}
476+
// Send final 1/2 to channel.
477+
<-signalInit
478+
for i := 0; i < numSignals/2; i++ {
479+
session.Signal(gossh.SIGHUP)
480+
}
481+
}()
482+
483+
go func() {
484+
errChan <- session.Run("")
485+
}()
486+
487+
select {
488+
case err := <-errChan:
489+
close(doneChan)
490+
if err != nil {
491+
t.Fatalf("expected nil but got %v", err)
492+
}
493+
case <-time.After(5 * time.Second):
494+
t.Fatalf("timed out waiting for session to exit")
495+
}
496+
}
497+
393498
func TestBreakWithChanRegistered(t *testing.T) {
394499
t.Parallel()
395500

0 commit comments

Comments
 (0)