Skip to content

Commit 5a44a26

Browse files
committed
WIP
1 parent cb00d58 commit 5a44a26

File tree

3 files changed

+115
-94
lines changed

3 files changed

+115
-94
lines changed

agent/immortalstreams/stream.go

Lines changed: 66 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ type Stream struct {
3636

3737
// Reconnection coordination
3838
pendingReconnect *reconnectRequest
39+
// Broadcast event: closed (and replaced with a fresh channel) each time a reconnect request is posted
40+
reconnectRequested chan struct{}
3941

4042
// Disconnection detection
4143
disconnectChan chan struct{}
@@ -60,13 +62,14 @@ type reconnectResponse struct {
6062
// NewStream creates a new immortal stream
6163
func NewStream(id uuid.UUID, name string, port int, logger slog.Logger) *Stream {
6264
stream := &Stream{
63-
id: id,
64-
name: name,
65-
port: port,
66-
createdAt: time.Now(),
67-
logger: logger,
68-
disconnectChan: make(chan struct{}, 1),
69-
shutdownChan: make(chan struct{}, 1),
65+
id: id,
66+
name: name,
67+
port: port,
68+
createdAt: time.Now(),
69+
logger: logger,
70+
disconnectChan: make(chan struct{}, 1),
71+
shutdownChan: make(chan struct{}),
72+
reconnectRequested: make(chan struct{}),
7073
}
7174

7275
// Create a reconnect function that waits for a client connection
@@ -79,8 +82,25 @@ func NewStream(id uuid.UUID, name string, port int, logger slog.Logger) *Stream
7982
writerSeqNum: writerSeqNum,
8083
response: responseChan,
8184
}
85+
// Replace any previous pending request and signal waiters
86+
// Close and recreate the event channel to broadcast the event
87+
close(stream.reconnectRequested)
88+
stream.reconnectRequested = make(chan struct{})
8289
stream.mu.Unlock()
8390

91+
// Fast path: if the stream is already shutting down, abort immediately
92+
select {
93+
case <-stream.shutdownChan:
94+
stream.mu.Lock()
95+
// Clear the pending request since we're aborting
96+
if stream.pendingReconnect != nil {
97+
stream.pendingReconnect = nil
98+
}
99+
stream.mu.Unlock()
100+
return nil, 0, xerrors.New("stream is shutting down")
101+
default:
102+
}
103+
84104
// Wait for response from HandleReconnect or context cancellation
85105
stream.logger.Debug(context.Background(), "reconnect function waiting for response")
86106
select {
@@ -108,14 +128,6 @@ func NewStream(id uuid.UUID, name string, port int, logger slog.Logger) *Stream
108128
// Create BackedPipe with background context
109129
stream.pipe = backedpipe.NewBackedPipe(context.Background(), reconnectFn)
110130

111-
// Immediately initiate a background connection so the BackedPipe
112-
// is provided with an io.ReadWriteCloser as soon as one is available.
113-
go func() {
114-
if err := stream.pipe.ForceReconnect(context.Background()); err != nil {
115-
stream.logger.Debug(context.Background(), "initial backed pipe connect returned", slog.Error(err))
116-
}
117-
}()
118-
119131
return stream
120132
}
121133

@@ -174,55 +186,43 @@ func (s *Stream) HandleReconnect(clientConn io.ReadWriteCloser, readSeqNum uint6
174186
// No pending request - we need to trigger a reconnection
175187
s.logger.Debug(context.Background(), "no pending request, will trigger reconnection")
176188

177-
// Use a channel to coordinate with the reconnect function
178-
readyChan := make(chan struct{})
189+
// Snapshot the current event channel before unlocking; it is closed when the reconnect function posts a pending request
190+
reconnectWait := s.reconnectRequested
179191
connectDone := make(chan error, 1)
180-
181-
// Prepare to intercept the next pending request
182-
interceptConn := clientConn
183-
interceptReadSeq := readSeqNum
184-
185192
s.mu.Unlock()
186193

187-
// Start a goroutine that will wait for the pending request and fulfill it
188-
go func() {
189-
// Signal when we're ready to intercept
190-
close(readyChan)
191-
192-
// Poll for the pending request
193-
for {
194-
s.mu.Lock()
195-
if s.pendingReconnect != nil {
196-
// Found the pending request, fulfill it
197-
s.pendingReconnect.response <- reconnectResponse{
198-
conn: interceptConn,
199-
readSeq: interceptReadSeq,
200-
err: nil,
201-
}
202-
s.pendingReconnect = nil
203-
s.mu.Unlock()
204-
return
205-
}
206-
s.mu.Unlock()
207-
208-
// Small sleep to avoid busy waiting
209-
time.Sleep(1 * time.Millisecond)
210-
}
211-
}()
212-
213-
// Wait for the interceptor to be ready
214-
<-readyChan
215-
216-
// Now trigger the reconnection - this will call our reconnect function
194+
// Trigger the reconnection - this will call our reconnect function
217195
go func() {
218196
s.logger.Debug(context.Background(), "calling ForceReconnect")
219197
err := s.pipe.ForceReconnect(context.Background())
220198
s.logger.Debug(context.Background(), "force reconnect returned", slog.Error(err))
221199
connectDone <- err
222200
}()
223201

224-
// Wait for the connection to complete
225-
err := <-connectDone
202+
// Wait for reconnectFn to signal the request, then fulfill it
203+
var earlyDone bool
204+
var earlyErr error
205+
select {
206+
case <-reconnectWait:
207+
s.mu.Lock()
208+
if s.pendingReconnect != nil {
209+
s.pendingReconnect.response <- reconnectResponse{conn: clientConn, readSeq: readSeqNum, err: nil}
210+
s.pendingReconnect = nil
211+
}
212+
s.mu.Unlock()
213+
case err := <-connectDone:
214+
// Reconnect returned before we got a pending request
215+
earlyDone = true
216+
earlyErr = err
217+
}
218+
219+
// Wait for the connection to complete if we didn't already
220+
var err error
221+
if earlyDone {
222+
err = earlyErr
223+
} else {
224+
err = <-connectDone
225+
}
226226

227227
s.mu.Lock()
228228
defer s.mu.Unlock()
@@ -252,12 +252,13 @@ func (s *Stream) Close() error {
252252
s.closed = true
253253
s.connected = false
254254

255-
// Signal shutdown to any pending reconnect attempts
255+
// Signal shutdown to any pending reconnect attempts and listeners
256+
// Closing the channel wakes all waiters exactly once
256257
select {
257-
case s.shutdownChan <- struct{}{}:
258-
// Signal sent successfully
258+
case <-s.shutdownChan:
259+
// already closed
259260
default:
260-
// Channel is full or already closed, which is fine
261+
close(s.shutdownChan)
261262
}
262263

263264
// Clear any pending reconnect request
@@ -429,10 +430,16 @@ func (s *Stream) handleDisconnect() {
429430

430431
// SignalDisconnect signals that the connection has been lost
431432
func (s *Stream) SignalDisconnect() {
433+
s.mu.RLock()
434+
closed := s.closed
435+
s.mu.RUnlock()
436+
if closed {
437+
return
438+
}
432439
select {
433440
case s.disconnectChan <- struct{}{}:
434441
default:
435-
// Channel is full or closed, ignore
442+
// Channel is full, ignore
436443
}
437444
}
438445

coderd/agentapi/backedpipe/backed_pipe.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,5 +328,8 @@ func (bp *BackedPipe) ForceReconnect(ctx context.Context) error {
328328
return io.ErrClosedPipe
329329
}
330330

331-
return bp.reconnectLocked(ctx)
331+
// Use the pipe's internal context so that Close() reliably cancels any
332+
// in-flight reconnection attempts. An external context here can outlive
333+
// the pipe and cause goroutines to block indefinitely.
334+
return bp.reconnectLocked(bp.ctx)
332335
}

coderd/agentapi/backedpipe/backed_reader.go

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -34,40 +34,51 @@ func NewBackedReader() *BackedReader {
3434
// When connected, it reads from the underlying reader and updates sequence numbers.
3535
// Connection failures are automatically detected and reported to the higher layer via callback.
3636
func (br *BackedReader) Read(p []byte) (int, error) {
37-
br.mu.Lock()
38-
defer br.mu.Unlock()
39-
40-
for {
41-
for br.reader == nil && !br.closed {
42-
br.cond.Wait()
43-
}
44-
45-
// Check if closed
46-
if br.closed {
47-
return 0, io.ErrClosedPipe
48-
}
49-
50-
n, err := br.reader.Read(p)
51-
52-
if err == nil {
53-
br.sequenceNum += uint64(n) // #nosec G115 -- n is always >= 0 per io.Reader contract
54-
return n, nil
55-
}
56-
57-
br.reader = nil
58-
59-
if br.onError != nil {
60-
br.onError(err)
61-
}
62-
63-
// If we got some data before the error, return it
64-
if n > 0 {
65-
br.sequenceNum += uint64(n)
66-
return n, nil
67-
}
68-
69-
// Return to Step 2 (continue the loop)
70-
}
37+
for {
38+
// Step 1: Wait until we have a reader or are closed
39+
br.mu.Lock()
40+
for br.reader == nil && !br.closed {
41+
br.cond.Wait()
42+
}
43+
44+
if br.closed {
45+
br.mu.Unlock()
46+
return 0, io.ErrClosedPipe
47+
}
48+
49+
// Capture the current reader and release the lock while performing
50+
// the potentially blocking I/O operation to avoid deadlocks with Close().
51+
r := br.reader
52+
br.mu.Unlock()
53+
54+
// Step 2: Perform the read without holding the mutex
55+
n, err := r.Read(p)
56+
57+
// Step 3: Reacquire the lock to update state based on the result
58+
br.mu.Lock()
59+
if err == nil {
60+
br.sequenceNum += uint64(n) // #nosec G115 -- n is always >= 0 per io.Reader contract
61+
br.mu.Unlock()
62+
return n, nil
63+
}
64+
65+
// Mark disconnected so future reads will wait for reconnection
66+
br.reader = nil
67+
68+
if br.onError != nil {
69+
br.onError(err)
70+
}
71+
72+
// If we got some data before the error, return it now
73+
if n > 0 {
74+
br.sequenceNum += uint64(n)
75+
br.mu.Unlock()
76+
return n, nil
77+
}
78+
79+
// Otherwise loop and wait for reconnection or close
80+
br.mu.Unlock()
81+
}
7182
}
7283

7384
// Reconnect coordinates reconnection using channels for better synchronization.

0 commit comments

Comments
 (0)