Skip to content

Commit a98292b

Browse files
committed
WIP
1 parent 5a44a26 commit a98292b

File tree

4 files changed

+144
-94
lines changed

4 files changed

+144
-94
lines changed

agent/immortalstreams/stream.go

Lines changed: 74 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,16 @@ type Stream struct {
3131
connected bool
3232
closed bool
3333

34+
// Indicates a reconnection attempt is in progress (single-flight)
35+
reconnecting bool
36+
3437
// goroutines manages the copy goroutines
3538
goroutines sync.WaitGroup
3639

3740
// Reconnection coordination
3841
pendingReconnect *reconnectRequest
39-
// Event to signal that a reconnect request is pending
40-
reconnectRequested chan struct{}
42+
// Condition variable to wait for pendingReconnect changes
43+
reconnectCond *sync.Cond
4144

4245
// Disconnection detection
4346
disconnectChan chan struct{}
@@ -62,15 +65,15 @@ type reconnectResponse struct {
6265
// NewStream creates a new immortal stream
6366
func NewStream(id uuid.UUID, name string, port int, logger slog.Logger) *Stream {
6467
stream := &Stream{
65-
id: id,
66-
name: name,
67-
port: port,
68-
createdAt: time.Now(),
69-
logger: logger,
70-
disconnectChan: make(chan struct{}, 1),
71-
shutdownChan: make(chan struct{}),
72-
reconnectRequested: make(chan struct{}),
68+
id: id,
69+
name: name,
70+
port: port,
71+
createdAt: time.Now(),
72+
logger: logger,
73+
disconnectChan: make(chan struct{}, 1),
74+
shutdownChan: make(chan struct{}),
7375
}
76+
stream.reconnectCond = sync.NewCond(&stream.mu)
7477

7578
// Create a reconnect function that waits for a client connection
7679
reconnectFn := func(ctx context.Context, writerSeqNum uint64) (io.ReadWriteCloser, uint64, error) {
@@ -82,10 +85,8 @@ func NewStream(id uuid.UUID, name string, port int, logger slog.Logger) *Stream
8285
writerSeqNum: writerSeqNum,
8386
response: responseChan,
8487
}
85-
// Replace any previous pending request and signal waiters
86-
// Close and recreate the event channel to broadcast the event
87-
close(stream.reconnectRequested)
88-
stream.reconnectRequested = make(chan struct{})
88+
// Signal waiters a reconnect request is pending
89+
stream.reconnectCond.Broadcast()
8990
stream.mu.Unlock()
9091

9192
// Fast path: if the stream is already shutting down, abort immediately
@@ -183,59 +184,77 @@ func (s *Stream) HandleReconnect(clientConn io.ReadWriteCloser, readSeqNum uint6
183184
return nil
184185
}
185186

186-
// No pending request - we need to trigger a reconnection
187+
// No pending request - we need to trigger a reconnection (single-flight)
187188
s.logger.Debug(context.Background(), "no pending request, will trigger reconnection")
188189

189-
// Wait for the reconnect function to post a pending request
190-
reconnectWait := s.reconnectRequested
191-
connectDone := make(chan error, 1)
192-
s.mu.Unlock()
193-
194-
// Trigger the reconnection - this will call our reconnect function
195-
go func() {
196-
s.logger.Debug(context.Background(), "calling ForceReconnect")
197-
err := s.pipe.ForceReconnect(context.Background())
198-
s.logger.Debug(context.Background(), "force reconnect returned", slog.Error(err))
199-
connectDone <- err
200-
}()
190+
for {
191+
// Ensure only one goroutine kicks off ForceReconnect
192+
if !s.reconnecting {
193+
s.reconnecting = true
194+
go func() {
195+
s.logger.Debug(context.Background(), "calling ForceReconnect")
196+
err := s.pipe.ForceReconnect(context.Background())
197+
s.logger.Debug(context.Background(), "force reconnect returned", slog.Error(err))
198+
s.mu.Lock()
199+
s.reconnecting = false
200+
// Notify any waiters in case we need to retry
201+
if s.reconnectCond != nil {
202+
s.reconnectCond.Broadcast()
203+
}
204+
s.mu.Unlock()
205+
}()
206+
}
201207

202-
// Wait for reconnectFn to signal the request, then fulfill it
203-
var earlyDone bool
204-
var earlyErr error
205-
select {
206-
case <-reconnectWait:
207-
s.mu.Lock()
208+
// If a reconnect request is pending, respond and break
208209
if s.pendingReconnect != nil {
209210
s.pendingReconnect.response <- reconnectResponse{conn: clientConn, readSeq: readSeqNum, err: nil}
210211
s.pendingReconnect = nil
212+
break
211213
}
212-
s.mu.Unlock()
213-
case err := <-connectDone:
214-
// Reconnect returned before we got a pending request
215-
earlyDone = true
216-
earlyErr = err
217-
}
218214

219-
// Wait for the connection to complete if we didn't already
220-
var err error
221-
if earlyDone {
222-
err = earlyErr
223-
} else {
224-
err = <-connectDone
215+
// If the stream has been closed, exit
216+
if s.closed {
217+
s.mu.Unlock()
218+
return xerrors.New("stream is closed")
219+
}
220+
221+
// If already connected (another goroutine handled it), we're done
222+
if s.connected {
223+
s.mu.Unlock()
224+
s.logger.Debug(context.Background(), "another goroutine completed reconnection")
225+
return nil
226+
}
227+
228+
// Wait until something changes (pending request posted, reconnect attempt finishes, or close)
229+
s.reconnectCond.Wait()
225230
}
226231

227-
s.mu.Lock()
228-
defer s.mu.Unlock()
232+
s.mu.Unlock()
229233

230-
if err != nil {
234+
// Wait until the pipe reports a connected state
235+
// This ensures the reconnection handshake fully completes regardless of
236+
// which goroutine initiated it.
237+
if err := s.pipe.WaitForConnection(context.Background()); err != nil {
238+
s.mu.Lock()
231239
s.connected = false
240+
// Notify any other waiters to re-check state or exit
241+
if s.reconnectCond != nil {
242+
s.reconnectCond.Broadcast()
243+
}
244+
s.mu.Unlock()
232245
s.logger.Warn(context.Background(), "failed to connect backed pipe", slog.Error(err))
233246
return xerrors.Errorf("failed to establish connection: %w", err)
234247
}
235248

236-
// Success
249+
s.mu.Lock()
237250
s.lastConnectionAt = time.Now()
238251
s.connected = true
252+
// Wake any concurrent HandleReconnect callers waiting for connection
253+
if s.reconnectCond != nil {
254+
s.reconnectCond.Broadcast()
255+
}
256+
s.mu.Unlock()
257+
239258
s.logger.Debug(context.Background(), "client reconnection successful")
240259
return nil
241260
}
@@ -261,6 +280,12 @@ func (s *Stream) Close() error {
261280
close(s.shutdownChan)
262281
}
263282

283+
// Wake any goroutines waiting for a pending reconnect request so they
284+
// observe the closed state and exit promptly.
285+
if s.reconnectCond != nil {
286+
s.reconnectCond.Broadcast()
287+
}
288+
264289
// Clear any pending reconnect request
265290
if s.pendingReconnect != nil {
266291
s.pendingReconnect.response <- reconnectResponse{

agent/immortalstreams/stream_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@ import (
44
"fmt"
55
"io"
66
"net"
7+
"os"
8+
"runtime"
79
"sync"
810
"testing"
911
"time"
1012

1113
"github.com/google/uuid"
14+
"go.uber.org/goleak"
1215

1316
"github.com/stretchr/testify/require"
1417

@@ -17,6 +20,15 @@ import (
1720
"github.com/coder/coder/v2/testutil"
1821
)
1922

23+
func TestMain(m *testing.M) {
24+
if runtime.GOOS == "windows" {
25+
// Don't run goleak on windows tests, they're super flaky right now.
26+
// See: https://github.com/coder/coder/issues/8954
27+
os.Exit(m.Run())
28+
}
29+
goleak.VerifyTestMain(m, testutil.GoleakOptions...)
30+
}
31+
2032
func TestStream_Start(t *testing.T) {
2133
t.Parallel()
2234

coderd/agentapi/backedpipe/backed_reader.go

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -34,51 +34,51 @@ func NewBackedReader() *BackedReader {
3434
// When connected, it reads from the underlying reader and updates sequence numbers.
3535
// Connection failures are automatically detected and reported to the higher layer via callback.
3636
func (br *BackedReader) Read(p []byte) (int, error) {
37-
for {
38-
// Step 1: Wait until we have a reader or are closed
39-
br.mu.Lock()
40-
for br.reader == nil && !br.closed {
41-
br.cond.Wait()
42-
}
43-
44-
if br.closed {
45-
br.mu.Unlock()
46-
return 0, io.ErrClosedPipe
47-
}
48-
49-
// Capture the current reader and release the lock while performing
50-
// the potentially blocking I/O operation to avoid deadlocks with Close().
51-
r := br.reader
52-
br.mu.Unlock()
53-
54-
// Step 2: Perform the read without holding the mutex
55-
n, err := r.Read(p)
56-
57-
// Step 3: Reacquire the lock to update state based on the result
58-
br.mu.Lock()
59-
if err == nil {
60-
br.sequenceNum += uint64(n) // #nosec G115 -- n is always >= 0 per io.Reader contract
61-
br.mu.Unlock()
62-
return n, nil
63-
}
64-
65-
// Mark disconnected so future reads will wait for reconnection
66-
br.reader = nil
67-
68-
if br.onError != nil {
69-
br.onError(err)
70-
}
71-
72-
// If we got some data before the error, return it now
73-
if n > 0 {
74-
br.sequenceNum += uint64(n)
75-
br.mu.Unlock()
76-
return n, nil
77-
}
78-
79-
// Otherwise loop and wait for reconnection or close
80-
br.mu.Unlock()
81-
}
37+
for {
38+
// Step 1: Wait until we have a reader or are closed
39+
br.mu.Lock()
40+
for br.reader == nil && !br.closed {
41+
br.cond.Wait()
42+
}
43+
44+
if br.closed {
45+
br.mu.Unlock()
46+
return 0, io.ErrClosedPipe
47+
}
48+
49+
// Capture the current reader and release the lock while performing
50+
// the potentially blocking I/O operation to avoid deadlocks with Close().
51+
r := br.reader
52+
br.mu.Unlock()
53+
54+
// Step 2: Perform the read without holding the mutex
55+
n, err := r.Read(p)
56+
57+
// Step 3: Reacquire the lock to update state based on the result
58+
br.mu.Lock()
59+
if err == nil {
60+
br.sequenceNum += uint64(n) // #nosec G115 -- n is always >= 0 per io.Reader contract
61+
br.mu.Unlock()
62+
return n, nil
63+
}
64+
65+
// Mark disconnected so future reads will wait for reconnection
66+
br.reader = nil
67+
68+
if br.onError != nil {
69+
br.onError(err)
70+
}
71+
72+
// If we got some data before the error, return it now
73+
if n > 0 {
74+
br.sequenceNum += uint64(n)
75+
br.mu.Unlock()
76+
return n, nil
77+
}
78+
79+
// Otherwise loop and wait for reconnection or close
80+
br.mu.Unlock()
81+
}
8282
}
8383

8484
// Reconnect coordinates reconnection using channels for better synchronization.

coderd/agentapi/backedpipe/ring_buffer_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,27 @@ package backedpipe_test
33
import (
44
"bytes"
55
"fmt"
6+
"os"
7+
"runtime"
68
"sync"
79
"testing"
810

911
"github.com/stretchr/testify/require"
12+
"go.uber.org/goleak"
1013

1114
"github.com/coder/coder/v2/coderd/agentapi/backedpipe"
15+
"github.com/coder/coder/v2/testutil"
1216
)
1317

18+
func TestMain(m *testing.M) {
19+
if runtime.GOOS == "windows" {
20+
// Don't run goleak on windows tests, they're super flaky right now.
21+
// See: https://github.com/coder/coder/issues/8954
22+
os.Exit(m.Run())
23+
}
24+
goleak.VerifyTestMain(m, testutil.GoleakOptions...)
25+
}
26+
1427
func TestRingBuffer_NewRingBuffer(t *testing.T) {
1528
t.Parallel()
1629

0 commit comments

Comments
 (0)