Skip to content

Commit 9b88a68

Browse files
committed
Protect map and state with the same mutex
I moved the conn closes back to the lifecycle, too.
1 parent 56ca7ac commit 9b88a68

File tree

3 files changed

+58
-58
lines changed

3 files changed

+58
-58
lines changed

agent/reconnectingpty/buffered.go

Lines changed: 47 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"errors"
66
"io"
77
"net"
8-
"sync"
98
"time"
109

1110
"github.com/armon/circbuf"
@@ -23,9 +22,6 @@ import (
2322
type bufferedReconnectingPTY struct {
2423
command *pty.Cmd
2524

26-
// mutex protects writing to the circular buffer and connections.
27-
mutex sync.RWMutex
28-
2925
activeConns map[string]net.Conn
3026
circularBuffer *circbuf.Buffer
3127

@@ -100,7 +96,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
10096
break
10197
}
10298
part := buffer[:read]
103-
rpty.mutex.Lock()
99+
rpty.state.cond.L.Lock()
104100
_, err = rpty.circularBuffer.Write(part)
105101
if err != nil {
106102
logger.Error(ctx, "write to circular buffer", slog.Error(err))
@@ -119,7 +115,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
119115
rpty.metrics.WithLabelValues("write").Add(1)
120116
}
121117
}
122-
rpty.mutex.Unlock()
118+
rpty.state.cond.L.Unlock()
123119
}
124120
}()
125121

@@ -136,14 +132,29 @@ func (rpty *bufferedReconnectingPTY) lifecycle(ctx context.Context, logger slog.
136132
logger.Debug(ctx, "reconnecting pty ready")
137133
rpty.state.setState(StateReady, nil)
138134

139-
state, reasonErr := rpty.state.waitForStateOrContext(ctx, StateClosing)
135+
state, reasonErr := rpty.state.waitForStateOrContext(ctx, StateClosing, nil)
140136
if state < StateClosing {
141137
// If we have not closed yet then the context is what unblocked us (which
142138
// means the agent is shutting down) so move into the closing phase.
143139
rpty.Close(reasonErr.Error())
144140
}
145141
rpty.timer.Stop()
146142

143+
rpty.state.cond.L.Lock()
144+
// Log these closes only for debugging since the connections or processes
145+
// might have already closed on their own.
146+
for _, conn := range rpty.activeConns {
147+
err := conn.Close()
148+
if err != nil {
149+
logger.Debug(ctx, "closed conn with error", slog.Error(err))
150+
}
151+
}
152+
// Connections get removed once they close but it is possible there is still
153+
// some data that will be written before that happens so clear the map now to
154+
// avoid writing to closed connections.
155+
rpty.activeConns = map[string]net.Conn{}
156+
rpty.state.cond.L.Unlock()
157+
147158
// Log close/kill only for debugging since the process might have already
148159
// closed on its own.
149160
err := rpty.ptty.Close()
@@ -167,65 +178,49 @@ func (rpty *bufferedReconnectingPTY) Attach(ctx context.Context, connID string,
167178
ctx, cancel := context.WithCancel(ctx)
168179
defer cancel()
169180

170-
state, err := rpty.state.waitForStateOrContext(ctx, StateReady)
171-
if state != StateReady {
172-
return xerrors.Errorf("reconnecting pty ready wait: %w", err)
173-
}
181+
// Once we are ready, attach the active connection while we hold the mutex.
182+
_, err := rpty.state.waitForStateOrContext(ctx, StateReady, func(state State, err error) error {
183+
if state != StateReady {
184+
return xerrors.Errorf("reconnecting pty ready wait: %w", err)
185+
}
186+
187+
go heartbeat(ctx, rpty.timer, rpty.timeout)
188+
189+
// Resize the PTY to initial height + width.
190+
err = rpty.ptty.Resize(height, width)
191+
if err != nil {
192+
// We can continue after this, it's not fatal!
193+
logger.Warn(ctx, "reconnecting PTY initial resize failed, but will continue", slog.Error(err))
194+
rpty.metrics.WithLabelValues("resize").Add(1)
195+
}
174196

175-
go heartbeat(ctx, rpty.timer, rpty.timeout)
197+
// Write any previously stored data for the TTY and store the connection for
198+
// future writes.
199+
prevBuf := slices.Clone(rpty.circularBuffer.Bytes())
200+
_, err = conn.Write(prevBuf)
201+
if err != nil {
202+
rpty.metrics.WithLabelValues("write").Add(1)
203+
return xerrors.Errorf("write buffer to conn: %w", err)
204+
}
205+
rpty.activeConns[connID] = conn
176206

177-
err = rpty.doAttach(ctx, connID, conn, height, width, logger)
207+
return nil
208+
})
178209
if err != nil {
179210
return err
180211
}
181212

182-
go func() {
183-
_, _ = rpty.state.waitForStateOrContext(ctx, StateClosing)
184-
rpty.mutex.Lock()
185-
defer rpty.mutex.Unlock()
213+
defer func() {
214+
rpty.state.cond.L.Lock()
215+
defer rpty.state.cond.L.Unlock()
186216
delete(rpty.activeConns, connID)
187-
// Log closes only for debugging since the connection might have already
188-
// closed on its own.
189-
err := conn.Close()
190-
if err != nil {
191-
logger.Debug(ctx, "closed conn with error", slog.Error(err))
192-
}
193217
}()
194218

195219
// Pipe conn -> pty and block. pty -> conn is handled in newBuffered().
196220
readConnLoop(ctx, conn, rpty.ptty, rpty.metrics, logger)
197221
return nil
198222
}
199223

200-
// doAttach adds the connection to the map, replays the buffer, and starts the
201-
// heartbeat. It exists separately only so we can defer the mutex unlock which
202-
// is not possible in Attach since it blocks.
203-
func (rpty *bufferedReconnectingPTY) doAttach(ctx context.Context, connID string, conn net.Conn, height, width uint16, logger slog.Logger) error {
204-
// Ensure we do not write to or close connections while we attach.
205-
rpty.mutex.Lock()
206-
defer rpty.mutex.Unlock()
207-
208-
// Resize the PTY to initial height + width.
209-
err := rpty.ptty.Resize(height, width)
210-
if err != nil {
211-
// We can continue after this, it's not fatal!
212-
logger.Warn(ctx, "reconnecting PTY initial resize failed, but will continue", slog.Error(err))
213-
rpty.metrics.WithLabelValues("resize").Add(1)
214-
}
215-
216-
// Write any previously stored data for the TTY and store the connection for
217-
// future writes.
218-
prevBuf := slices.Clone(rpty.circularBuffer.Bytes())
219-
_, err = conn.Write(prevBuf)
220-
if err != nil {
221-
rpty.metrics.WithLabelValues("write").Add(1)
222-
return xerrors.Errorf("write buffer to conn: %w", err)
223-
}
224-
rpty.activeConns[connID] = conn
225-
226-
return nil
227-
}
228-
229224
func (rpty *bufferedReconnectingPTY) Wait() {
230225
_, _ = rpty.state.waitForState(StateClosing)
231226
}

agent/reconnectingpty/reconnectingpty.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,9 @@ func (s *ptyState) waitForState(state State) (State, error) {
167167
}
168168

169169
// waitForStateOrContext blocks until the state or a greater one is reached or
170-
// the provided context ends.
171-
func (s *ptyState) waitForStateOrContext(ctx context.Context, state State) (State, error) {
170+
// the provided context ends. If fn is non-nil it will be run while the lock is
171+
// held and fn's error will replace waitForStateOrContext's error.
172+
func (s *ptyState) waitForStateOrContext(ctx context.Context, state State, fn func(state State, err error) error) (State, error) {
172173
nevermind := make(chan struct{})
173174
defer close(nevermind)
174175
go func() {
@@ -185,10 +186,14 @@ func (s *ptyState) waitForStateOrContext(ctx context.Context, state State) (Stat
185186
for ctx.Err() == nil && state > s.state {
186187
s.cond.Wait()
187188
}
189+
err := s.error
188190
if ctx.Err() != nil {
189-
return s.state, ctx.Err()
191+
err = ctx.Err()
190192
}
191-
return s.state, s.error
193+
if fn != nil {
194+
return s.state, fn(s.state, err)
195+
}
196+
return s.state, err
192197
}
193198

194199
// readConnLoop reads messages from conn and writes to ptty as needed. Blocks

agent/reconnectingpty/screen.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ func (rpty *screenReconnectingPTY) lifecycle(ctx context.Context, logger slog.Lo
130130
logger.Debug(ctx, "reconnecting pty ready")
131131
rpty.state.setState(StateReady, nil)
132132

133-
state, reasonErr := rpty.state.waitForStateOrContext(ctx, StateClosing)
133+
state, reasonErr := rpty.state.waitForStateOrContext(ctx, StateClosing, nil)
134134
if state < StateClosing {
135135
// If we have not closed yet then the context is what unblocked us (which
136136
// means the agent is shutting down) so move into the closing phase.
@@ -155,7 +155,7 @@ func (rpty *screenReconnectingPTY) Attach(ctx context.Context, _ string, conn ne
155155
ctx, cancel := context.WithCancel(ctx)
156156
defer cancel()
157157

158-
state, err := rpty.state.waitForStateOrContext(ctx, StateReady)
158+
state, err := rpty.state.waitForStateOrContext(ctx, StateReady, nil)
159159
if state != StateReady {
160160
return xerrors.Errorf("reconnecting pty ready wait: %w", err)
161161
}

0 commit comments

Comments
 (0)