Skip to content

fix: avoid data race in session signal channel register/deregister #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 43 additions & 25 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,21 +127,25 @@ func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.Ne
type session struct {
sync.Mutex
gossh.Channel
conn *gossh.ServerConn
handler Handler
subsystemHandlers map[string]SubsystemHandler
handled bool
exited bool
pty *Pty
x11 *X11
winch chan Window
env []string
ptyCb PtyCallback
x11Cb X11Callback
sessReqCb SessionRequestCallback
rawCmd string
subsystem string
ctx Context
conn *gossh.ServerConn
handler Handler
subsystemHandlers map[string]SubsystemHandler
handled bool
exited bool
pty *Pty
x11 *X11
winch chan Window
env []string
ptyCb PtyCallback
x11Cb X11Callback
sessReqCb SessionRequestCallback
rawCmd string
subsystem string
ctx Context
// sigMu protects sigCh and sigBuf, it is made separate from the
// session mutex to reduce the risk of deadlocks while we process
// buffered signals.
sigMu sync.Mutex
sigCh chan<- Signal
sigBuf []Signal
breakCh chan<- bool
Expand Down Expand Up @@ -247,16 +251,30 @@ func (sess *session) X11() (X11, bool) {
}

func (sess *session) Signals(c chan<- Signal) {
sess.Lock()
defer sess.Unlock()
sess.sigMu.Lock()
sess.sigCh = c
if len(sess.sigBuf) > 0 {
go func() {
for _, sig := range sess.sigBuf {
sess.sigCh <- sig
}
}()
if len(sess.sigBuf) == 0 || sess.sigCh == nil {
sess.sigMu.Unlock()
return
}
// If we have buffered signals, we need to send them whilst
// holding the signal mutex to avoid race conditions on sigCh
// and sigBuf. We also guarantee that calling Signals(ch)
// followed by Signals(nil) will have depleted the sigBuf when
// the second call returns and that there will be no more
// signals on ch. This is done in a goroutine so we can return
// early and allow the caller to set up processing for the
// channel even after calling Signals(ch).
go func() {
// Here we're relying on the mutex being locked in the outer
// Signals() function, so we simply unlock it when we're done.
defer sess.sigMu.Unlock()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is unorthodox enough to warrant a comment; it looked like a bug to me at first.

Copy link
Member Author

@mafredri mafredri Nov 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to captured the motivation in 45772fa (and 773ac57) now, thoughts?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what I meant is the fact that you're calling

defer sess.sigMu.Unlock()

without the corresponding Lock() in the same function looks very strange to me! We often talk about a function "holding" the lock while it runs, but here you "transfer" the lock from the Signals() function to the child goroutine so the parent can return.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I understood that's what you meant. Did you feel the motivation for it was lacking?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As an aside, this is explicitly documented in https://pkg.go.dev/sync#Mutex.Unlock

A locked Mutex is not associated with a particular goroutine. It is allowed for one goroutine to lock a Mutex and then arrange for another goroutine to unlock it.

Not that I'm making any claims about this being standard in any way... I was simply trying to retrofit an edge-cases covered fix into the existing code without too many changes.

Interestingly, there's been a proposal to disallow this, but it didn't gain traction: golang/go#9201

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was actually just about the mechanics --- like

// here we're relying on the fact that the mutex was locked in the outer `Signal()`
// function, so this goroutine/function just has to unlock the mutex when it is done


for _, sig := range sess.sigBuf {
sess.sigCh <- sig
}
sess.sigBuf = nil
}()
}

func (sess *session) Break(c chan<- bool) {
Expand Down Expand Up @@ -379,15 +397,15 @@ func (sess *session) handleRequests(ctx Context, reqs <-chan *gossh.Request) {
case "signal":
var payload struct{ Signal string }
gossh.Unmarshal(req.Payload, &payload)
sess.Lock()
sess.sigMu.Lock()
if sess.sigCh != nil {
sess.sigCh <- Signal(payload.Signal)
} else {
if len(sess.sigBuf) < maxSigBufSize {
sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal))
}
}
sess.Unlock()
sess.sigMu.Unlock()
case "pty-req":
if sess.handled || sess.pty != nil {
req.Reply(false, nil)
Expand Down
105 changes: 105 additions & 0 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,111 @@ func TestSignals(t *testing.T) {
}
}

func TestSignalsRaceDeregisterAndReregister(t *testing.T) {
t.Parallel()

numSignals := 128

// errChan lets us get errors back from the session
errChan := make(chan error, 5)

// doneChan lets us specify that we should exit.
doneChan := make(chan interface{})

// Channels to synchronize the handler and the test.
handlerPreRegister := make(chan struct{})
handlerPostRegister := make(chan struct{})
signalInit := make(chan struct{})

session, _, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
// Single buffer slot, this is to make sure we don't miss
// signals or send on nil a channel.
signals := make(chan Signal, 1)

<-handlerPreRegister // Wait for initial signal buffering.

// Register signals.
s.Signals(signals)
close(handlerPostRegister) // Trigger post register signaling.

// Process signals so that we can don't see a deadlock.
discarded := 0
discardDone := make(chan struct{})
go func() {
defer close(discardDone)
for range signals {
discarded++
}
}()
// Deregister signals.
s.Signals(nil)
// Close channel to close goroutine and ensure we don't send
// on a closed channel.
close(signals)
<-discardDone

signals = make(chan Signal, 1)
consumeDone := make(chan struct{})
go func() {
defer close(consumeDone)

for i := 0; i < numSignals-discarded; i++ {
select {
case sig := <-signals:
if sig != SIGHUP {
errChan <- fmt.Errorf("expected signal %v but got %v", SIGHUP, sig)
return
}
case <-doneChan:
errChan <- fmt.Errorf("Unexpected done")
return
}
}
}()

// Re-register signals and make sure we don't miss any.
s.Signals(signals)
close(signalInit)

<-consumeDone
},
}, nil)
defer cleanup()

go func() {
// Send 1/4th directly to buffer.
for i := 0; i < numSignals/4; i++ {
session.Signal(gossh.SIGHUP)
}
close(handlerPreRegister)
<-handlerPostRegister
// Send 1/4th to channel or buffer.
for i := 0; i < numSignals/4; i++ {
session.Signal(gossh.SIGHUP)
}
// Send final 1/2 to channel.
<-signalInit
for i := 0; i < numSignals/2; i++ {
session.Signal(gossh.SIGHUP)
}
}()

go func() {
errChan <- session.Run("")
}()

select {
case err := <-errChan:
close(doneChan)
if err != nil {
t.Fatalf("expected nil but got %v", err)
}
case <-time.After(5 * time.Second):
t.Fatalf("timed out waiting for session to exit")
}
}

func TestBreakWithChanRegistered(t *testing.T) {
t.Parallel()

Expand Down