From 07a1b9f49d79f60e8efbc39a8cf9b4386cd34743 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Thu, 7 Aug 2025 11:16:53 +0000 Subject: [PATCH 01/11] "chore: add backed reader, writer and pipe" --- coderd/agentapi/backedpipe/backed_pipe.go | 332 ++++++++ .../agentapi/backedpipe/backed_pipe_test.go | 727 ++++++++++++++++++ coderd/agentapi/backedpipe/backed_reader.go | 150 ++++ .../agentapi/backedpipe/backed_reader_test.go | 471 ++++++++++++ coderd/agentapi/backedpipe/backed_writer.go | 244 ++++++ .../agentapi/backedpipe/backed_writer_test.go | 411 ++++++++++ coderd/agentapi/backedpipe/ring_buffer.go | 140 ++++ .../backedpipe/ring_buffer_internal_test.go | 162 ++++ .../agentapi/backedpipe/ring_buffer_test.go | 326 ++++++++ testutil/duration.go | 9 +- 10 files changed, 2968 insertions(+), 4 deletions(-) create mode 100644 coderd/agentapi/backedpipe/backed_pipe.go create mode 100644 coderd/agentapi/backedpipe/backed_pipe_test.go create mode 100644 coderd/agentapi/backedpipe/backed_reader.go create mode 100644 coderd/agentapi/backedpipe/backed_reader_test.go create mode 100644 coderd/agentapi/backedpipe/backed_writer.go create mode 100644 coderd/agentapi/backedpipe/backed_writer_test.go create mode 100644 coderd/agentapi/backedpipe/ring_buffer.go create mode 100644 coderd/agentapi/backedpipe/ring_buffer_internal_test.go create mode 100644 coderd/agentapi/backedpipe/ring_buffer_test.go diff --git a/coderd/agentapi/backedpipe/backed_pipe.go b/coderd/agentapi/backedpipe/backed_pipe.go new file mode 100644 index 0000000000000..784b8d55353b1 --- /dev/null +++ b/coderd/agentapi/backedpipe/backed_pipe.go @@ -0,0 +1,332 @@ +package backedpipe + +import ( + "context" + "io" + "sync" + + "golang.org/x/xerrors" +) + +const ( + // DefaultBufferSize is the default buffer size for the BackedWriter (64MB) + DefaultBufferSize = 64 * 1024 * 1024 +) + +// ReconnectFunc is called when the BackedPipe needs to establish a new connection. +// It should: +// 1. Establish a new connection to the remote side +// 2. Exchange sequence numbers with the remote side +// 3. Return the new connection and the remote's current sequence number +// +// The writerSeqNum parameter is the local writer's current sequence number, +// which should be sent to the remote side so it knows where to resume reading from. +// +// The returned readerSeqNum should be the remote side's current sequence number, +// which indicates where the local reader should resume from. +type ReconnectFunc func(ctx context.Context, writerSeqNum uint64) (conn io.ReadWriteCloser, readerSeqNum uint64, err error) + +// BackedPipe provides a reliable bidirectional byte stream over unreliable network connections. +// It orchestrates a BackedReader and BackedWriter to provide transparent reconnection +// and data replay capabilities. +type BackedPipe struct { + ctx context.Context + cancel context.CancelFunc + mu sync.RWMutex + reader *BackedReader + writer *BackedWriter + reconnectFn ReconnectFunc + conn io.ReadWriteCloser + connected bool + closed bool + + // Reconnection state + reconnecting bool + + // Error channel for receiving connection errors from reader/writer + errorChan chan error + + // Connection state notification + connectionChanged chan struct{} +} + +// NewBackedPipe creates a new BackedPipe with default options and the specified reconnect function. +// The pipe starts disconnected and must be connected using Connect(). 
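+//
+// Illustrative usage (a sketch; reconnectFn stands in for the caller's own
+// dial-and-handshake logic):
+//
+//	bp := NewBackedPipe(ctx, reconnectFn)
+//	if err := bp.Connect(ctx); err != nil {
+//		// handle initial connection failure
+//	}
+//	defer bp.Close()
+//	_, _ = bp.Write([]byte("data")) // buffered and replayed across reconnects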
+func NewBackedPipe(ctx context.Context, reconnectFn ReconnectFunc) *BackedPipe { + pipeCtx, cancel := context.WithCancel(ctx) + + bp := &BackedPipe{ + ctx: pipeCtx, + cancel: cancel, + reader: NewBackedReader(), + writer: NewBackedWriterWithCapacity(DefaultBufferSize), // 64MB default buffer + reconnectFn: reconnectFn, + errorChan: make(chan error, 2), // Buffer for reader and writer errors + connectionChanged: make(chan struct{}, 1), + } + + // Set up error callbacks + bp.reader.SetErrorCallback(func(err error) { + select { + case bp.errorChan <- err: + case <-bp.ctx.Done(): + } + }) + + bp.writer.SetErrorCallback(func(err error) { + select { + case bp.errorChan <- err: + case <-bp.ctx.Done(): + } + }) + + // Start error handler goroutine + go bp.handleErrors() + + return bp +} + +// Connect establishes the initial connection using the reconnect function. +func (bp *BackedPipe) Connect(ctx context.Context) error { + bp.mu.Lock() + defer bp.mu.Unlock() + + if bp.closed { + return xerrors.New("pipe is closed") + } + + if bp.connected { + return xerrors.New("pipe is already connected") + } + + return bp.reconnectLocked(ctx) +} + +// Read implements io.Reader by delegating to the BackedReader. +func (bp *BackedPipe) Read(p []byte) (int, error) { + bp.mu.RLock() + reader := bp.reader + closed := bp.closed + bp.mu.RUnlock() + + if closed { + return 0, io.ErrClosedPipe + } + + return reader.Read(p) +} + +// Write implements io.Writer by delegating to the BackedWriter. +func (bp *BackedPipe) Write(p []byte) (int, error) { + bp.mu.RLock() + writer := bp.writer + closed := bp.closed + bp.mu.RUnlock() + + if closed { + return 0, io.ErrClosedPipe + } + + return writer.Write(p) +} + +// Close closes the pipe and all underlying connections. +func (bp *BackedPipe) Close() error { + bp.mu.Lock() + defer bp.mu.Unlock() + + if bp.closed { + return nil + } + + bp.closed = true + bp.cancel() // Cancel main context + + // Close underlying components + var readerErr, writerErr, connErr error + + if bp.reader != nil { + readerErr = bp.reader.Close() + } + + if bp.writer != nil { + writerErr = bp.writer.Close() + } + + if bp.conn != nil { + connErr = bp.conn.Close() + bp.conn = nil + } + + bp.connected = false + bp.signalConnectionChange() + + // Return first error encountered + if readerErr != nil { + return readerErr + } + if writerErr != nil { + return writerErr + } + return connErr +} + +// Connected returns whether the pipe is currently connected. +func (bp *BackedPipe) Connected() bool { + bp.mu.RLock() + defer bp.mu.RUnlock() + return bp.connected +} + +// signalConnectionChange signals that the connection state has changed. +func (bp *BackedPipe) signalConnectionChange() { + select { + case bp.connectionChanged <- struct{}{}: + default: + // Channel is full, which is fine - we just want to signal that something changed + } +} + +// reconnectLocked handles the reconnection logic. Must be called with write lock held. 
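+// Recovery hinges on the sequence-number exchange (a sketch of the steps
+// performed in the body below; names match the fields of BackedPipe):
+//
+//	writerSeq := bp.writer.SequenceNum()                   // bytes written locally so far
+//	conn, readerSeq, err := bp.reconnectFn(ctx, writerSeq) // dial and handshake
+//	// then validate readerSeq and replay everything after it
+//	// from the writer's ring buffer onto conn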
+func (bp *BackedPipe) reconnectLocked(ctx context.Context) error { + if bp.reconnecting { + return xerrors.New("reconnection already in progress") + } + + bp.reconnecting = true + defer func() { + bp.reconnecting = false + }() + + // Close existing connection if any + if bp.conn != nil { + _ = bp.conn.Close() + bp.conn = nil + } + + bp.connected = false + bp.signalConnectionChange() + + // Get current writer sequence number to send to remote + writerSeqNum := bp.writer.SequenceNum() + + // Unlock during reconnect attempt to avoid blocking reads/writes + bp.mu.Unlock() + conn, readerSeqNum, err := bp.reconnectFn(ctx, writerSeqNum) + bp.mu.Lock() + + if err != nil { + return xerrors.Errorf("reconnect failed: %w", err) + } + + // Validate sequence numbers + if readerSeqNum > writerSeqNum { + _ = conn.Close() + return xerrors.Errorf("remote sequence number %d exceeds local sequence %d, cannot replay", + readerSeqNum, writerSeqNum) + } + + // Validate writer can replay from the requested sequence + if !bp.writer.CanReplayFrom(readerSeqNum) { + _ = conn.Close() + // Calculate data loss + currentSeq := bp.writer.SequenceNum() + dataLoss := currentSeq - DefaultBufferSize - readerSeqNum + return xerrors.Errorf("cannot replay from sequence %d (current: %d, data loss: ~%d bytes)", + readerSeqNum, currentSeq, dataLoss) + } + + // Reconnect reader and writer + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go bp.reader.Reconnect(seqNum, newR) + + // Get sequence number and send new reader + <-seqNum + newR <- conn + + err = bp.writer.Reconnect(readerSeqNum, conn) + if err != nil { + _ = conn.Close() + return xerrors.Errorf("reconnect writer: %w", err) + } + + // Success - update state + bp.conn = conn + bp.connected = true + bp.signalConnectionChange() + + return nil +} + +// handleErrors listens for connection errors from reader/writer and triggers reconnection. +func (bp *BackedPipe) handleErrors() { + for { + select { + case <-bp.ctx.Done(): + return + case err := <-bp.errorChan: + // Connection error occurred + bp.mu.Lock() + + // Skip if already closed or not connected + if bp.closed || !bp.connected { + bp.mu.Unlock() + continue + } + + // Mark as disconnected + bp.connected = false + bp.signalConnectionChange() + + // Try to reconnect + reconnectErr := bp.reconnectLocked(bp.ctx) + bp.mu.Unlock() + + if reconnectErr != nil { + // Reconnection failed - log or handle as needed + // For now, we'll just continue and wait for manual reconnection + _ = err // Use the original error + } + } + } +} + +// WaitForConnection blocks until the pipe is connected or the context is canceled. +func (bp *BackedPipe) WaitForConnection(ctx context.Context) error { + for { + bp.mu.RLock() + connected := bp.connected + closed := bp.closed + bp.mu.RUnlock() + + if closed { + return io.ErrClosedPipe + } + + if connected { + return nil + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-bp.connectionChanged: + // Connection state changed, check again + } + } +} + +// ForceReconnect forces a reconnection attempt immediately. +// This can be used to force a reconnection if a new connection is established. 
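+// It is useful when the caller already knows the transport has changed and
+// does not want to wait for the next read or write error to trigger
+// recovery. Sketch:
+//
+//	if err := bp.ForceReconnect(ctx); err != nil {
+//		// the reconnect function could not produce a usable connection
+//	}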
+func (bp *BackedPipe) ForceReconnect(ctx context.Context) error { + bp.mu.Lock() + defer bp.mu.Unlock() + + if bp.closed { + return io.ErrClosedPipe + } + + return bp.reconnectLocked(ctx) +} diff --git a/coderd/agentapi/backedpipe/backed_pipe_test.go b/coderd/agentapi/backedpipe/backed_pipe_test.go new file mode 100644 index 0000000000000..c841112ed07e1 --- /dev/null +++ b/coderd/agentapi/backedpipe/backed_pipe_test.go @@ -0,0 +1,727 @@ +package backedpipe_test + +import ( + "bytes" + "context" + "io" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/agentapi/backedpipe" + "github.com/coder/coder/v2/testutil" +) + +// mockConnection implements io.ReadWriteCloser for testing +type mockConnection struct { + mu sync.Mutex + readBuffer bytes.Buffer + writeBuffer bytes.Buffer + closed bool + readError error + writeError error + closeError error + readFunc func([]byte) (int, error) + writeFunc func([]byte) (int, error) + seqNum uint64 +} + +func newMockConnection() *mockConnection { + return &mockConnection{} +} + +func (mc *mockConnection) Read(p []byte) (int, error) { + mc.mu.Lock() + defer mc.mu.Unlock() + + if mc.readFunc != nil { + return mc.readFunc(p) + } + + if mc.readError != nil { + return 0, mc.readError + } + + return mc.readBuffer.Read(p) +} + +func (mc *mockConnection) Write(p []byte) (int, error) { + mc.mu.Lock() + defer mc.mu.Unlock() + + if mc.writeFunc != nil { + return mc.writeFunc(p) + } + + if mc.writeError != nil { + return 0, mc.writeError + } + + return mc.writeBuffer.Write(p) +} + +func (mc *mockConnection) Close() error { + mc.mu.Lock() + defer mc.mu.Unlock() + mc.closed = true + return mc.closeError +} + +func (mc *mockConnection) WriteString(s string) { + mc.mu.Lock() + defer mc.mu.Unlock() + _, _ = mc.readBuffer.WriteString(s) +} + +func (mc *mockConnection) ReadString() string { + mc.mu.Lock() + defer mc.mu.Unlock() + return mc.writeBuffer.String() +} + +func (mc *mockConnection) SetReadError(err error) { + mc.mu.Lock() + defer mc.mu.Unlock() + mc.readError = err +} + +func (mc *mockConnection) SetWriteError(err error) { + mc.mu.Lock() + defer mc.mu.Unlock() + mc.writeError = err +} + +func (mc *mockConnection) Reset() { + mc.mu.Lock() + defer mc.mu.Unlock() + mc.readBuffer.Reset() + mc.writeBuffer.Reset() + mc.readError = nil + mc.writeError = nil + mc.closed = false +} + +// mockReconnectFunc creates a unified reconnect function with all behaviors enabled +func mockReconnectFunc(connections ...*mockConnection) (backedpipe.ReconnectFunc, *int, chan struct{}) { + connectionIndex := 0 + callCount := 0 + signalChan := make(chan struct{}, 1) + + reconnectFn := func(ctx context.Context, writerSeqNum uint64) (io.ReadWriteCloser, uint64, error) { + callCount++ + + if connectionIndex >= len(connections) { + return nil, 0, xerrors.New("no more connections available") + } + + conn := connections[connectionIndex] + connectionIndex++ + + // Signal when reconnection happens + if connectionIndex > 1 { + select { + case signalChan <- struct{}{}: + default: + } + } + + // Determine readerSeqNum based on call count + var readerSeqNum uint64 + switch { + case callCount == 1: + readerSeqNum = 0 + case conn.seqNum != 0: + readerSeqNum = conn.seqNum + default: + readerSeqNum = writerSeqNum + } + + return conn, readerSeqNum, nil + } + + return reconnectFn, &callCount, signalChan +} + +func TestBackedPipe_NewBackedPipe(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reconnectFn, _, _ := 
mockReconnectFunc(newMockConnection()) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + require.NotNil(t, bp) + require.False(t, bp.Connected()) +} + +func TestBackedPipe_Connect(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, callCount, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Connect(ctx) + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 1, *callCount) +} + +func TestBackedPipe_ConnectAlreadyConnected(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Connect(ctx) + require.NoError(t, err) + + // Second connect should fail + err = bp.Connect(ctx) + require.Error(t, err) + require.Contains(t, err.Error(), "already connected") +} + +func TestBackedPipe_ConnectAfterClose(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Close() + require.NoError(t, err) + + err = bp.Connect(ctx) + require.Error(t, err) + require.Contains(t, err.Error(), "closed") +} + +func TestBackedPipe_BasicReadWrite(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Connect(ctx) + require.NoError(t, err) + + // Write data + n, err := bp.Write([]byte("hello")) + require.NoError(t, err) + require.Equal(t, 5, n) + + // Simulate data coming back + conn.WriteString("world") + + // Read data + buf := make([]byte, 10) + n, err = bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 5, n) + require.Equal(t, "world", string(buf[:n])) +} + +func TestBackedPipe_WriteBeforeConnect(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Write before connecting should succeed (buffered) + n, err := bp.Write([]byte("hello")) + require.NoError(t, err) + require.Equal(t, 5, n) + + // Connect should replay the buffered data + err = bp.Connect(ctx) + require.NoError(t, err) + + // Check that data was replayed to connection + require.Equal(t, "hello", conn.ReadString()) +} + +func TestBackedPipe_ReadBlocksWhenDisconnected(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reconnectFn, _, _ := mockReconnectFunc(newMockConnection()) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Start a read that should block + readDone := make(chan struct{}) + readStarted := make(chan struct{}) + var readErr error + + go func() { + defer close(readDone) + close(readStarted) // Signal that we're about to start the read + buf := make([]byte, 10) + _, readErr = bp.Read(buf) + }() + + // Wait for the goroutine to start + <-readStarted + + // Give a brief moment for the read to actually block + time.Sleep(time.Millisecond) + + // Read should still be blocked + select { + case <-readDone: + t.Fatal("Read should be blocked when disconnected") + default: + // Good, still blocked + } + + // Close should unblock the read + bp.Close() + + select { + case <-readDone: + require.Equal(t, io.ErrClosedPipe, readErr) + case <-time.After(time.Second): + t.Fatal("Read did not unblock after close") 
+ } +} + +func TestBackedPipe_Reconnection(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn1 := newMockConnection() + conn2 := newMockConnection() + conn2.seqNum = 17 // Remote has received 17 bytes, so replay from sequence 17 + reconnectFn, _, signalChan := mockReconnectFunc(conn1, conn2) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Initial connect + err := bp.Connect(ctx) + require.NoError(t, err) + + // Write some data before failure + bp.Write([]byte("before disconnect***")) + + // Simulate connection failure + conn1.SetReadError(xerrors.New("connection lost")) + conn1.SetWriteError(xerrors.New("connection lost")) + + // Trigger a write to cause the pipe to notice the failure + _, _ = bp.Write([]byte("trigger failure ")) + + <-signalChan + + err = bp.WaitForConnection(ctx) + require.NoError(t, err) + + replayedData := conn2.ReadString() + require.Equal(t, "***trigger failure ", replayedData, "Should replay exactly the data written after sequence 17") + + // Verify that new writes work with the reconnected pipe + _, err = bp.Write([]byte("new data after reconnect")) + require.NoError(t, err) + + // Read all data from the connection (replayed + new data) + allData := conn2.ReadString() + require.Equal(t, "***trigger failure new data after reconnect", allData, "Should have replayed data plus new data") +} + +func TestBackedPipe_Close(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Connect(ctx) + require.NoError(t, err) + + err = bp.Close() + require.NoError(t, err) + require.True(t, conn.closed) + + // Operations after close should fail + _, err = bp.Read(make([]byte, 10)) + require.Equal(t, io.ErrClosedPipe, err) + + _, err = bp.Write([]byte("test")) + require.Equal(t, io.ErrClosedPipe, err) +} + +func TestBackedPipe_CloseIdempotent(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Close() + require.NoError(t, err) + + // Second close should be no-op + err = bp.Close() + require.NoError(t, err) +} + +func TestBackedPipe_WaitForConnection(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Should timeout when not connected + // Use a shorter timeout for this test to speed up test runs + timeoutCtx, cancel := context.WithTimeout(ctx, testutil.WaitSuperShort) + defer cancel() + + err := bp.WaitForConnection(timeoutCtx) + require.Equal(t, context.DeadlineExceeded, err) + + // Connect in background after a brief delay + connectionStarted := make(chan struct{}) + go func() { + close(connectionStarted) + // Small delay to ensure WaitForConnection is called first + time.Sleep(time.Millisecond) + bp.Connect(context.Background()) + }() + + // Wait for connection goroutine to start + <-connectionStarted + + // Should succeed once connected + err = bp.WaitForConnection(context.Background()) + require.NoError(t, err) +} + +func TestBackedPipe_ConcurrentReadWrite(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Connect(ctx) + require.NoError(t, err) + + 
var wg sync.WaitGroup + numWriters := 3 + writesPerWriter := 10 + + // Fill read buffer with test data first + testData := make([]byte, 1000) + for i := range testData { + testData[i] = 'A' + } + conn.WriteString(string(testData)) + + // Channel to collect all written data + writtenData := make(chan byte, numWriters*writesPerWriter) + + // Start a few readers + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 10) + for j := 0; j < 10; j++ { + bp.Read(buf) + time.Sleep(time.Millisecond) // Small delay to avoid busy waiting + } + }() + } + + // Start writers + for i := 0; i < numWriters; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < writesPerWriter; j++ { + data := []byte{byte(id + '0')} + bp.Write(data) + writtenData <- byte(id + '0') + time.Sleep(time.Millisecond) // Small delay + } + }(i) + } + + // Wait with timeout + done := make(chan struct{}) + go func() { + defer close(done) + wg.Wait() + }() + + select { + case <-done: + // Success + case <-time.After(5 * time.Second): + t.Fatal("Test timed out") + } + + // Close the channel and collect all written data + close(writtenData) + var allWritten []byte + for b := range writtenData { + allWritten = append(allWritten, b) + } + + // Verify that all written data was received by the connection + // Note: Since this test uses the old mock that returns readerSeqNum = 0, + // all data will be replayed, so we expect to receive all written data + receivedData := conn.ReadString() + require.GreaterOrEqual(t, len(receivedData), len(allWritten), "Connection should have received at least all written data") + + // Check that all written bytes appear in the received data + for _, writtenByte := range allWritten { + require.Contains(t, receivedData, string(writtenByte), "Written byte %c should be present in received data", writtenByte) + } +} + +func TestBackedPipe_ReconnectFunctionFailure(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + failingReconnectFn := func(ctx context.Context, writerSeqNum uint64) (io.ReadWriteCloser, uint64, error) { + return nil, 0, xerrors.New("reconnect failed") + } + + bp := backedpipe.NewBackedPipe(ctx, failingReconnectFn) + + err := bp.Connect(ctx) + require.Error(t, err) + require.Contains(t, err.Error(), "reconnect failed") + require.False(t, bp.Connected()) +} + +func TestBackedPipe_ForceReconnect(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn1 := newMockConnection() + conn2 := newMockConnection() + reconnectFn, callCount, _ := mockReconnectFunc(conn1, conn2) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Initial connect + err := bp.Connect(ctx) + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 1, *callCount) + + // Write some data to the first connection + _, err = bp.Write([]byte("test data")) + require.NoError(t, err) + require.Equal(t, "test data", conn1.ReadString()) + + // Force a reconnection + err = bp.ForceReconnect(ctx) + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 2, *callCount) + + // Since the mock now returns the proper sequence number, no data should be replayed + // The new connection should be empty + require.Equal(t, "", conn2.ReadString()) + + // Verify that data can still be written and read after forced reconnection + _, err = bp.Write([]byte("new data")) + require.NoError(t, err) + require.Equal(t, "new data", conn2.ReadString()) + + // Verify that reads work with the new connection + conn2.WriteString("response 
data") + buf := make([]byte, 20) + n, err := bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 13, n) + require.Equal(t, "response data", string(buf[:n])) +} + +func TestBackedPipe_ForceReconnectWhenClosed(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Close the pipe first + err := bp.Close() + require.NoError(t, err) + + // Try to force reconnect when closed + err = bp.ForceReconnect(ctx) + require.Error(t, err) + require.Equal(t, io.ErrClosedPipe, err) +} + +func TestBackedPipe_ForceReconnectWhenDisconnected(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, callCount, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Don't connect initially, just force reconnect + err := bp.ForceReconnect(ctx) + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 1, *callCount) + + // Verify we can write and read + _, err = bp.Write([]byte("test")) + require.NoError(t, err) + require.Equal(t, "test", conn.ReadString()) + + conn.WriteString("response") + buf := make([]byte, 10) + n, err := bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 8, n) + require.Equal(t, "response", string(buf[:n])) +} + +func TestBackedPipe_EOFTriggersReconnection(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + // Create connections where we can control when EOF occurs + conn1 := newMockConnection() + conn2 := newMockConnection() + conn2.WriteString("newdata") // Pre-populate conn2 with data + + // Make conn1 return EOF after reading "world" + hasReadData := false + conn1.readFunc = func(p []byte) (int, error) { + // Don't lock here - the Read method already holds the lock + + // First time: return "world" + if !hasReadData && conn1.readBuffer.Len() > 0 { + n, _ := conn1.readBuffer.Read(p) + hasReadData = true + return n, nil + } + // After that: return EOF + return 0, io.EOF + } + conn1.WriteString("world") + + callCount := 0 + reconnectFn := func(ctx context.Context, writerSeqNum uint64) (io.ReadWriteCloser, uint64, error) { + callCount++ + + if callCount == 1 { + return conn1, 0, nil + } + if callCount == 2 { + // Second call is the reconnection after EOF + return conn2, writerSeqNum, nil // conn2 already has the reader sequence at writerSeqNum + } + + return nil, 0, xerrors.New("no more connections") + } + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Initial connect + err := bp.Connect(ctx) + require.NoError(t, err) + require.Equal(t, 1, callCount) + + // Write some data + _, err = bp.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 10) + + // First read should succeed + n, err := bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 5, n) + require.Equal(t, "world", string(buf[:n])) + + // Next read will encounter EOF and should trigger reconnection + // After reconnection, it should read from conn2 + n, err = bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 7, n) + require.Equal(t, "newdata", string(buf[:n])) + + // Verify reconnection happened + require.Equal(t, 2, callCount) + + // Verify the pipe is still connected and functional + require.True(t, bp.Connected()) + + // Further writes should go to the new connection + _, err = bp.Write([]byte("aftereof")) + require.NoError(t, err) + require.Equal(t, "aftereof", conn2.ReadString()) +} + +func 
BenchmarkBackedPipe_Write(b *testing.B) { + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + bp.Connect(ctx) + + data := make([]byte, 1024) // 1KB writes + + b.ResetTimer() + for i := 0; i < b.N; i++ { + bp.Write(data) + } +} + +func BenchmarkBackedPipe_Read(b *testing.B) { + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + bp.Connect(ctx) + + buf := make([]byte, 1024) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Fill connection with fresh data for each iteration + conn.WriteString(string(buf)) + bp.Read(buf) + } +} diff --git a/coderd/agentapi/backedpipe/backed_reader.go b/coderd/agentapi/backedpipe/backed_reader.go new file mode 100644 index 0000000000000..7b57986638065 --- /dev/null +++ b/coderd/agentapi/backedpipe/backed_reader.go @@ -0,0 +1,150 @@ +package backedpipe + +import ( + "io" + "sync" +) + +// BackedReader wraps an unreliable io.Reader and makes it resilient to disconnections. +// It tracks sequence numbers for all bytes read and can handle reconnection, +// blocking reads when disconnected instead of erroring. +type BackedReader struct { + mu sync.Mutex + cond *sync.Cond + reader io.Reader + sequenceNum uint64 + closed bool + + // Error callback to notify parent when connection fails + onError func(error) +} + +// NewBackedReader creates a new BackedReader. The reader is initially disconnected +// and must be connected using Reconnect before reads will succeed. +func NewBackedReader() *BackedReader { + br := &BackedReader{} + br.cond = sync.NewCond(&br.mu) + return br +} + +// Read implements io.Reader. It blocks when disconnected until either: +// 1. A reconnection is established +// 2. The reader is closed +// +// When connected, it reads from the underlying reader and updates sequence numbers. +// Connection failures are automatically detected and reported to the higher layer via callback. +func (br *BackedReader) Read(p []byte) (int, error) { + br.mu.Lock() + defer br.mu.Unlock() + + for { + for br.reader == nil && !br.closed { + br.cond.Wait() + } + + // Check if closed + if br.closed { + return 0, io.ErrClosedPipe + } + + br.mu.Unlock() + n, err := br.reader.Read(p) + br.mu.Lock() + + if err == nil { + br.sequenceNum += uint64(n) // #nosec G115 -- n is always >= 0 per io.Reader contract + return n, nil + } + + br.reader = nil + + if br.onError != nil { + br.onError(err) + } + + // If we got some data before the error, return it + if n > 0 { + br.sequenceNum += uint64(n) + return n, nil + } + + // Return to Step 2 (continue the loop) + } +} + +// Reconnect coordinates reconnection using channels for better synchronization. +// The seqNum channel is used to send the current sequence number to the caller. +// The newR channel is used to receive the new reader from the caller. +// This allows for better coordination during the reconnection process. 
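+//
+// Caller-side handshake (illustrative; this mirrors how BackedPipe drives
+// the reader, with conn being the freshly established connection):
+//
+//	seqNum := make(chan uint64, 1)
+//	newR := make(chan io.Reader, 1)
+//	go br.Reconnect(seqNum, newR)
+//	seq := <-seqNum // the reader's current sequence number
+//	newR <- conn    // deliver the new reader, or nil to abort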
+func (br *BackedReader) Reconnect(seqNum chan<- uint64, newR <-chan io.Reader) { + // Grab the lock + br.mu.Lock() + defer br.mu.Unlock() + + if br.closed { + // Send 0 sequence number and close the channel to indicate closed state + seqNum <- 0 + close(seqNum) + return + } + + // Get the sequence number to send to the other side via seqNum channel + seqNum <- br.sequenceNum + close(seqNum) + + // Wait for the reconnect to complete, via newR channel, and give us a new io.Reader + newReader := <-newR + + // If reconnection fails while we are starting it, the caller sends nil on newR + if newReader == nil { + // Reconnection failed, keep current state + return + } + + // Reconnection successful + br.reader = newReader + + // Notify any waiting reads via the cond + br.cond.Broadcast() +} + +// Closes the reader and wakes up any blocked reads. +// After closing, all Read calls will return io.ErrClosedPipe. +func (br *BackedReader) Close() error { + br.mu.Lock() + defer br.mu.Unlock() + + if br.closed { + return nil + } + + br.closed = true + br.reader = nil + + // Wake up any blocked reads + br.cond.Broadcast() + + return nil +} + +// SetErrorCallback sets the callback function that will be called when +// a connection error occurs (excluding EOF). +func (br *BackedReader) SetErrorCallback(fn func(error)) { + br.mu.Lock() + defer br.mu.Unlock() + br.onError = fn +} + +// SequenceNum returns the current sequence number (total bytes read). +func (br *BackedReader) SequenceNum() uint64 { + br.mu.Lock() + defer br.mu.Unlock() + return br.sequenceNum +} + +// Connected returns whether the reader is currently connected. +func (br *BackedReader) Connected() bool { + br.mu.Lock() + defer br.mu.Unlock() + return br.reader != nil +} diff --git a/coderd/agentapi/backedpipe/backed_reader_test.go b/coderd/agentapi/backedpipe/backed_reader_test.go new file mode 100644 index 0000000000000..810abb7c64bd6 --- /dev/null +++ b/coderd/agentapi/backedpipe/backed_reader_test.go @@ -0,0 +1,471 @@ +package backedpipe_test + +import ( + "io" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/agentapi/backedpipe" +) + +// mockReader implements io.Reader with controllable behavior for testing +type mockReader struct { + mu sync.Mutex + data []byte + pos int + err error + readFunc func([]byte) (int, error) +} + +func newMockReader(data string) *mockReader { + return &mockReader{data: []byte(data)} +} + +func (mr *mockReader) Read(p []byte) (int, error) { + mr.mu.Lock() + defer mr.mu.Unlock() + + if mr.readFunc != nil { + return mr.readFunc(p) + } + + if mr.err != nil { + return 0, mr.err + } + + if mr.pos >= len(mr.data) { + return 0, io.EOF + } + + n := copy(p, mr.data[mr.pos:]) + mr.pos += n + return n, nil +} + +func (mr *mockReader) setError(err error) { + mr.mu.Lock() + defer mr.mu.Unlock() + mr.err = err +} + +func TestBackedReader_NewBackedReader(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + assert.NotNil(t, br) + assert.Equal(t, uint64(0), br.SequenceNum()) + assert.False(t, br.Connected()) +} + +func TestBackedReader_BasicReadOperation(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + reader := newMockReader("hello world") + + // Connect the reader + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number from reader + seq := <-seqNum + assert.Equal(t, uint64(0), 
seq) + + // Send new reader + newR <- reader + + // Read data + buf := make([]byte, 5) + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, "hello", string(buf)) + assert.Equal(t, uint64(5), br.SequenceNum()) + + // Read more data + n, err = br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, " worl", string(buf)) + assert.Equal(t, uint64(10), br.SequenceNum()) +} + +func TestBackedReader_ReadBlocksWhenDisconnected(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + + // Start a read operation that should block + readDone := make(chan struct{}) + readStarted := make(chan struct{}) + var readErr error + + go func() { + defer close(readDone) + close(readStarted) // Signal that we're about to start the read + buf := make([]byte, 10) + _, readErr = br.Read(buf) + }() + + // Wait for the goroutine to start + <-readStarted + + // Give a brief moment for the read to actually block on the condition variable + // This is much shorter and more deterministic than the previous approach + time.Sleep(time.Millisecond) + + // Read should still be blocked + select { + case <-readDone: + t.Fatal("Read should be blocked when disconnected") + default: + // Good, still blocked + } + + // Connect and the read should unblock + reader := newMockReader("test") + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send new reader + <-seqNum + newR <- reader + + // Wait for read to complete + select { + case <-readDone: + assert.NoError(t, readErr) + case <-time.After(time.Second): + t.Fatal("Read did not unblock after reconnection") + } +} + +func TestBackedReader_ReconnectionAfterFailure(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + reader1 := newMockReader("first") + + // Initial connection + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send new reader + <-seqNum + newR <- reader1 + + // Read some data + buf := make([]byte, 5) + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, "first", string(buf[:n])) + assert.Equal(t, uint64(5), br.SequenceNum()) + + // Set up error callback to verify error notification + errorReceived := make(chan error, 1) + br.SetErrorCallback(func(err error) { + errorReceived <- err + }) + + // Simulate connection failure + reader1.setError(xerrors.New("connection lost")) + + // Start a read that will block due to connection failure + readDone := make(chan error, 1) + go func() { + _, err := br.Read(buf) + readDone <- err + }() + + // Wait for the error to be reported via callback + select { + case receivedErr := <-errorReceived: + assert.Error(t, receivedErr) + assert.Contains(t, receivedErr.Error(), "connection lost") + case <-time.After(time.Second): + t.Fatal("Error callback was not invoked within timeout") + } + + // Verify disconnection + assert.False(t, br.Connected()) + + // Reconnect with new reader + reader2 := newMockReader("second") + seqNum2 := make(chan uint64, 1) + newR2 := make(chan io.Reader, 1) + + go br.Reconnect(seqNum2, newR2) + + // Get sequence number and send new reader + seq := <-seqNum2 + assert.Equal(t, uint64(5), seq) // Should return current sequence number + newR2 <- reader2 + + // Wait for read to unblock and succeed with new data + select { + case readErr := <-readDone: + assert.NoError(t, readErr) // Should succeed with new reader + case <-time.After(time.Second): + 
t.Fatal("Read did not unblock after reconnection") + } +} + +func TestBackedReader_Close(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + reader := newMockReader("test") + + // Connect + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send new reader + <-seqNum + newR <- reader + + // First, read all available data + buf := make([]byte, 10) + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 4, n) // "test" is 4 bytes + + // Close the reader before EOF triggers reconnection + err = br.Close() + require.NoError(t, err) + + // After close, reads should return ErrClosedPipe + n, err = br.Read(buf) + assert.Equal(t, 0, n) + assert.Equal(t, io.ErrClosedPipe, err) + + // Subsequent reads should return ErrClosedPipe + _, err = br.Read(buf) + assert.Equal(t, io.ErrClosedPipe, err) +} + +func TestBackedReader_CloseIdempotent(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + + err := br.Close() + assert.NoError(t, err) + + // Second close should be no-op + err = br.Close() + assert.NoError(t, err) +} + +func TestBackedReader_ReconnectAfterClose(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + + err := br.Close() + require.NoError(t, err) + + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Should get 0 sequence number for closed reader + seq := <-seqNum + assert.Equal(t, uint64(0), seq) +} + +// Helper function to reconnect a reader using channels +func reconnectReader(br *backedpipe.BackedReader, reader io.Reader) { + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send new reader + <-seqNum + newR <- reader +} + +func TestBackedReader_SequenceNumberTracking(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + reader := newMockReader("0123456789") + + reconnectReader(br, reader) + + // Read in chunks and verify sequence number + buf := make([]byte, 3) + + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 3, n) + assert.Equal(t, uint64(3), br.SequenceNum()) + + n, err = br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 3, n) + assert.Equal(t, uint64(6), br.SequenceNum()) + + n, err = br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 3, n) + assert.Equal(t, uint64(9), br.SequenceNum()) +} + +func TestBackedReader_ConcurrentReads(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + reader := newMockReader(strings.Repeat("a", 1000)) + + reconnectReader(br, reader) + + var wg sync.WaitGroup + numReaders := 5 + readsPerReader := 10 + + for i := 0; i < numReaders; i++ { + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 10) + for j := 0; j < readsPerReader; j++ { + br.Read(buf) + } + }() + } + + wg.Wait() + + // Should have read some data (exact amount depends on scheduling) + assert.True(t, br.SequenceNum() > 0) + assert.True(t, br.SequenceNum() <= 1000) +} + +func TestBackedReader_EOFHandling(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + reader := newMockReader("test") + + // Set up error callback to track when EOF triggers disconnection + errorReceived := make(chan error, 1) + br.SetErrorCallback(func(err error) { + errorReceived <- err + }) + + reconnectReader(br, reader) + + // Read all data + buf := make([]byte, 10) + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 4, n) + 
assert.Equal(t, "test", string(buf[:n])) + + // Next read should encounter EOF, which triggers disconnection + // The read should block waiting for reconnection + readDone := make(chan struct{}) + var readErr error + var readN int + + go func() { + defer close(readDone) + readN, readErr = br.Read(buf) + }() + + // Wait for EOF to be reported via error callback + select { + case receivedErr := <-errorReceived: + assert.Equal(t, io.EOF, receivedErr) + case <-time.After(time.Second): + t.Fatal("EOF was not reported via error callback within timeout") + } + + // Reader should be disconnected after EOF + assert.False(t, br.Connected()) + + // Read should still be blocked + select { + case <-readDone: + t.Fatal("Read should be blocked waiting for reconnection after EOF") + default: + // Good, still blocked + } + + // Reconnect with new data + reader2 := newMockReader("more") + reconnectReader(br, reader2) + + // Wait for the blocked read to complete with new data + select { + case <-readDone: + require.NoError(t, readErr) + assert.Equal(t, 4, readN) + assert.Equal(t, "more", string(buf[:readN])) + case <-time.After(time.Second): + t.Fatal("Read did not unblock after reconnection") + } +} + +func BenchmarkBackedReader_Read(b *testing.B) { + br := backedpipe.NewBackedReader() + buf := make([]byte, 1024) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Create fresh reader with data for each iteration + data := strings.Repeat("x", 1024) // 1KB of data per iteration + reader := newMockReader(data) + reconnectReader(br, reader) + + br.Read(buf) + } +} + +func TestBackedReader_PartialReads(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + + // Create a reader that returns partial reads + reader := &mockReader{ + readFunc: func(p []byte) (int, error) { + // Always return just 1 byte at a time + if len(p) == 0 { + return 0, nil + } + p[0] = 'A' + return 1, nil + }, + } + + reconnectReader(br, reader) + + // Read multiple times + buf := make([]byte, 10) + for i := 0; i < 5; i++ { + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 1, n) + assert.Equal(t, byte('A'), buf[0]) + } + + assert.Equal(t, uint64(5), br.SequenceNum()) +} diff --git a/coderd/agentapi/backedpipe/backed_writer.go b/coderd/agentapi/backedpipe/backed_writer.go new file mode 100644 index 0000000000000..bc72d8bfc7385 --- /dev/null +++ b/coderd/agentapi/backedpipe/backed_writer.go @@ -0,0 +1,244 @@ +package backedpipe + +import ( + "context" + "io" + "sync" + + "golang.org/x/xerrors" +) + +// BackedWriter wraps an unreliable io.Writer and makes it resilient to disconnections. +// It maintains a ring buffer of recent writes for replay during reconnection and +// always writes to the buffer even when disconnected. +type BackedWriter struct { + mu sync.Mutex + cond *sync.Cond + writer io.Writer + buffer *RingBuffer + sequenceNum uint64 // total bytes written + closed bool + + // Error callback to notify parent when connection fails + onError func(error) +} + +// NewBackedWriter creates a new BackedWriter with a 64MB ring buffer. +// The writer is initially disconnected and will buffer writes until connected. +func NewBackedWriter() *BackedWriter { + return NewBackedWriterWithCapacity(64 * 1024 * 1024) +} + +// NewBackedWriterWithCapacity creates a new BackedWriter with the specified buffer capacity. +// The writer is initially disconnected and will buffer writes until connected. 
+func NewBackedWriterWithCapacity(capacity int) *BackedWriter { + bw := &BackedWriter{ + buffer: NewRingBufferWithCapacity(capacity), + } + bw.cond = sync.NewCond(&bw.mu) + return bw +} + +// Write implements io.Writer. It always writes to the ring buffer, even when disconnected. +// When connected, it also writes to the underlying writer. If the underlying write fails, +// the writer is marked as disconnected but the buffer write still succeeds. +func (bw *BackedWriter) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + + bw.mu.Lock() + defer bw.mu.Unlock() + + if bw.closed { + return 0, io.ErrClosedPipe + } + + // Always write to buffer first + written, _ := bw.buffer.Write(p) + //nolint:gosec // Safe conversion: written is always non-negative from buffer.Write + bw.sequenceNum += uint64(written) + + // If connected, also write to underlying writer + if bw.writer != nil { + // Unlock during actual write to avoid blocking other operations + bw.mu.Unlock() + n, err := bw.writer.Write(p) + bw.mu.Lock() + + if n != len(p) { + err = xerrors.Errorf("partial write: wrote %d of %d bytes", n, len(p)) + } + + if err != nil { + // Connection failed, mark as disconnected + bw.writer = nil + + // Notify parent of error if callback is set + if bw.onError != nil { + bw.onError(err) + } + } + } + + return written, nil +} + +// Reconnect replaces the current writer with a new one and replays data from the specified +// sequence number. If the requested sequence number is no longer in the buffer, +// returns an error indicating data loss. +func (bw *BackedWriter) Reconnect(replayFromSeq uint64, newWriter io.Writer) error { + bw.mu.Lock() + defer bw.mu.Unlock() + + if bw.closed { + return xerrors.New("cannot reconnect closed writer") + } + + if newWriter == nil { + return xerrors.New("new writer cannot be nil") + } + + // Check if we can replay from the requested sequence number + if replayFromSeq > bw.sequenceNum { + return xerrors.Errorf("cannot replay from future sequence %d: current sequence is %d", replayFromSeq, bw.sequenceNum) + } + + // Calculate how many bytes we need to replay + replayBytes := bw.sequenceNum - replayFromSeq + + var replayData []byte + if replayBytes > 0 { + // Get the last replayBytes from buffer + // If the buffer doesn't have enough data (some was evicted), + // ReadLast will return an error + var err error + // Safe conversion: replayBytes is always non-negative due to the check above + // No overflow possible since replayBytes is calculated as sequenceNum - replayFromSeq + // and uint64->int conversion is safe for reasonable buffer sizes + //nolint:gosec // Safe conversion: replayBytes is calculated from uint64 subtraction + replayData, err = bw.buffer.ReadLast(int(replayBytes)) + if err != nil { + return xerrors.Errorf("failed to read replay data: %w", err) + } + } + + // Set new writer + bw.writer = newWriter + + // Replay data if needed + if len(replayData) > 0 { + bw.mu.Unlock() + n, err := newWriter.Write(replayData) + bw.mu.Lock() + + if err != nil { + bw.writer = nil + return xerrors.Errorf("replay failed: %w", err) + } + + if n != len(replayData) { + bw.writer = nil + return xerrors.Errorf("partial replay: wrote %d of %d bytes", n, len(replayData)) + } + } + + // Wake up any operations waiting for connection + bw.cond.Broadcast() + + return nil +} + +// Close closes the writer and prevents further writes. +// After closing, all Write calls will return io.ErrClosedPipe. 
+// This code keeps the Close() signature consistent with io.Closer, +// but it never actually returns an error. +func (bw *BackedWriter) Close() error { + bw.mu.Lock() + defer bw.mu.Unlock() + + if bw.closed { + return nil + } + + bw.closed = true + bw.writer = nil + + // Wake up any blocked operations + bw.cond.Broadcast() + + return nil +} + +// SetErrorCallback sets the callback function that will be called when +// a connection error occurs. +func (bw *BackedWriter) SetErrorCallback(fn func(error)) { + bw.mu.Lock() + defer bw.mu.Unlock() + bw.onError = fn +} + +// SequenceNum returns the current sequence number (total bytes written). +func (bw *BackedWriter) SequenceNum() uint64 { + bw.mu.Lock() + defer bw.mu.Unlock() + return bw.sequenceNum +} + +// Connected returns whether the writer is currently connected. +func (bw *BackedWriter) Connected() bool { + bw.mu.Lock() + defer bw.mu.Unlock() + return bw.writer != nil +} + +// CanReplayFrom returns true if the writer can replay data from the given sequence number. +func (bw *BackedWriter) CanReplayFrom(seqNum uint64) bool { + bw.mu.Lock() + defer bw.mu.Unlock() + return seqNum <= bw.sequenceNum && bw.sequenceNum-seqNum <= DefaultBufferSize +} + +// WaitForConnection blocks until the writer is connected or the context is canceled. +func (bw *BackedWriter) WaitForConnection(ctx context.Context) error { + bw.mu.Lock() + defer bw.mu.Unlock() + + return bw.waitForConnectionLocked(ctx) +} + +// waitForConnectionLocked waits for connection with lock held. +func (bw *BackedWriter) waitForConnectionLocked(ctx context.Context) error { + for bw.writer == nil && !bw.closed { + select { + case <-ctx.Done(): + return ctx.Err() + default: + // Use a timeout to avoid infinite waiting + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + bw.cond.Broadcast() + case <-done: + } + }() + + bw.cond.Wait() + close(done) + + // Check context again after waking up + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + } + } + + if bw.closed { + return io.ErrClosedPipe + } + + return nil +} diff --git a/coderd/agentapi/backedpipe/backed_writer_test.go b/coderd/agentapi/backedpipe/backed_writer_test.go new file mode 100644 index 0000000000000..f92a79c6f366b --- /dev/null +++ b/coderd/agentapi/backedpipe/backed_writer_test.go @@ -0,0 +1,411 @@ +package backedpipe_test + +import ( + "bytes" + "context" + "io" + "sync" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/agentapi/backedpipe" + "github.com/coder/coder/v2/testutil" +) + +// mockWriter implements io.Writer with controllable behavior for testing +type mockWriter struct { + mu sync.Mutex + buffer bytes.Buffer + err error + writeFunc func([]byte) (int, error) + writeCalls int +} + +func newMockWriter() *mockWriter { + return &mockWriter{} +} + +// newBackedWriterForTest creates a BackedWriter with a small buffer for testing eviction behavior +func newBackedWriterForTest(bufferSize int) *backedpipe.BackedWriter { + return backedpipe.NewBackedWriterWithCapacity(bufferSize) +} + +func (mw *mockWriter) Write(p []byte) (int, error) { + mw.mu.Lock() + defer mw.mu.Unlock() + + mw.writeCalls++ + + if mw.writeFunc != nil { + return mw.writeFunc(p) + } + + if mw.err != nil { + return 0, mw.err + } + + return mw.buffer.Write(p) +} + +func (mw *mockWriter) Len() int { + mw.mu.Lock() + defer mw.mu.Unlock() + return mw.buffer.Len() +} + +func (mw *mockWriter) Reset() { + mw.mu.Lock() + defer mw.mu.Unlock() + 
mw.buffer.Reset() + mw.writeCalls = 0 + mw.err = nil + mw.writeFunc = nil +} + +func (mw *mockWriter) setError(err error) { + mw.mu.Lock() + defer mw.mu.Unlock() + mw.err = err +} + +func TestBackedWriter_NewBackedWriter(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + require.NotNil(t, bw) + require.Equal(t, uint64(0), bw.SequenceNum()) + require.False(t, bw.Connected()) +} + +func TestBackedWriter_WriteToBufferWhenDisconnected(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + + // Write should succeed even when disconnected + n, err := bw.Write([]byte("hello")) + require.NoError(t, err) + require.Equal(t, 5, n) + require.Equal(t, uint64(5), bw.SequenceNum()) + + // Data should be in buffer +} + +func TestBackedWriter_WriteToUnderlyingWhenConnected(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + writer := newMockWriter() + + // Connect + err := bw.Reconnect(0, writer) + require.NoError(t, err) + require.True(t, bw.Connected()) + + // Write should go to both buffer and underlying writer + n, err := bw.Write([]byte("hello")) + require.NoError(t, err) + require.Equal(t, 5, n) + + // Data should be buffered + + // Check underlying writer + require.Equal(t, []byte("hello"), writer.buffer.Bytes()) +} + +func TestBackedWriter_DisconnectOnWriteFailure(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + writer := newMockWriter() + + // Connect + err := bw.Reconnect(0, writer) + require.NoError(t, err) + + // Cause write to fail + writer.setError(xerrors.New("write failed")) + + // Write should still succeed to buffer but disconnect + n, err := bw.Write([]byte("hello")) + require.NoError(t, err) // Buffer write succeeds + require.Equal(t, 5, n) + require.False(t, bw.Connected()) // Should be disconnected + + // Data should still be in buffer +} + +func TestBackedWriter_ReplayOnReconnect(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + + // Write some data while disconnected + bw.Write([]byte("hello")) + bw.Write([]byte(" world")) + + require.Equal(t, uint64(11), bw.SequenceNum()) + + // Reconnect and request replay from beginning + writer := newMockWriter() + err := bw.Reconnect(0, writer) + require.NoError(t, err) + + // Should have replayed all data + require.Equal(t, []byte("hello world"), writer.buffer.Bytes()) + + // Write new data should go to both + bw.Write([]byte("!")) + require.Equal(t, []byte("hello world!"), writer.buffer.Bytes()) +} + +func TestBackedWriter_PartialReplay(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + + // Write some data + bw.Write([]byte("hello")) + bw.Write([]byte(" world")) + bw.Write([]byte("!")) + + // Reconnect and request replay from middle + writer := newMockWriter() + err := bw.Reconnect(5, writer) // From " world!" 
+ require.NoError(t, err) + + // Should have replayed only the requested portion + require.Equal(t, []byte(" world!"), writer.buffer.Bytes()) +} + +func TestBackedWriter_ReplayFromFutureSequence(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + bw.Write([]byte("hello")) + + writer := newMockWriter() + err := bw.Reconnect(10, writer) // Future sequence + require.Error(t, err) + require.Contains(t, err.Error(), "future sequence") +} + +func TestBackedWriter_ReplayDataLoss(t *testing.T) { + t.Parallel() + + bw := newBackedWriterForTest(10) // Small buffer for testing + + // Fill buffer beyond capacity to cause eviction + bw.Write([]byte("0123456789")) // Fills buffer exactly + bw.Write([]byte("abcdef")) // Should evict "012345" + + writer := newMockWriter() + err := bw.Reconnect(0, writer) // Try to replay from evicted data + // With the new error handling, this should fail because we can't read all the data + require.Error(t, err) + require.Contains(t, err.Error(), "failed to read replay data") +} + +func TestBackedWriter_BufferEviction(t *testing.T) { + t.Parallel() + + bw := newBackedWriterForTest(5) // Very small buffer for testing + + // Write data that will cause eviction + n, err := bw.Write([]byte("abcde")) + require.NoError(t, err) + require.Equal(t, 5, n) + + // Write more to cause eviction + n, err = bw.Write([]byte("fg")) + require.NoError(t, err) + require.Equal(t, 2, n) + + // Buffer should contain "cdefg" (latest data) +} + +func TestBackedWriter_Close(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + writer := newMockWriter() + + bw.Reconnect(0, writer) + + err := bw.Close() + require.NoError(t, err) + + // Writes after close should fail + _, err = bw.Write([]byte("test")) + require.Equal(t, io.ErrClosedPipe, err) + + // Reconnect after close should fail + err = bw.Reconnect(0, newMockWriter()) + require.Error(t, err) + require.Contains(t, err.Error(), "closed") +} + +func TestBackedWriter_CloseIdempotent(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + + err := bw.Close() + require.NoError(t, err) + + // Second close should be no-op + err = bw.Close() + require.NoError(t, err) +} + +func TestBackedWriter_CanReplayFrom(t *testing.T) { + t.Parallel() + + bw := newBackedWriterForTest(10) // Small buffer for testing eviction + + // Empty buffer + require.True(t, bw.CanReplayFrom(0)) + require.False(t, bw.CanReplayFrom(1)) + + // Write some data + bw.Write([]byte("hello")) + require.True(t, bw.CanReplayFrom(0)) + require.True(t, bw.CanReplayFrom(3)) + require.True(t, bw.CanReplayFrom(5)) + require.False(t, bw.CanReplayFrom(6)) + + // Fill buffer and cause eviction + bw.Write([]byte("world!")) + require.True(t, bw.CanReplayFrom(0)) // Can replay from any sequence up to current + require.True(t, bw.CanReplayFrom(bw.SequenceNum())) +} + +func TestBackedWriter_WaitForConnection(t *testing.T) { + t.Parallel() + + bw := backedpipe.NewBackedWriter() + + // Should timeout when not connected + // Use a shorter timeout for this test to speed up test runs + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperShort) + defer cancel() + + err := bw.WaitForConnection(ctx) + require.Equal(t, context.DeadlineExceeded, err) + + // Should succeed immediately when connected + writer := newMockWriter() + bw.Reconnect(0, writer) + + ctx = context.Background() + err = bw.WaitForConnection(ctx) + require.NoError(t, err) +} + +func TestBackedWriter_ConcurrentWrites(t *testing.T) { + t.Parallel() + + bw := 
backedpipe.NewBackedWriter()
+	writer := newMockWriter()
+	bw.Reconnect(0, writer)
+
+	var wg sync.WaitGroup
+	numWriters := 10
+	writesPerWriter := 50
+
+	for i := 0; i < numWriters; i++ {
+		wg.Add(1)
+		go func(id int) {
+			defer wg.Done()
+			for j := 0; j < writesPerWriter; j++ {
+				data := []byte{byte(id + '0')}
+				bw.Write(data)
+			}
+		}(i)
+	}
+
+	wg.Wait()
+
+	// Should have written expected amount to buffer
+	expectedBytes := uint64(numWriters * writesPerWriter) //nolint:gosec // Safe conversion: test constants with small values
+	require.Equal(t, expectedBytes, bw.SequenceNum())
+	// Note: underlying writer may not receive all bytes due to potential disconnections
+	// during concurrent operations, but the buffer should track all writes
+	require.True(t, writer.Len() <= int(expectedBytes)) //nolint:gosec // Safe conversion: expectedBytes is calculated from small test values
+}
+
+func TestBackedWriter_ReconnectDuringReplay(t *testing.T) {
+	t.Parallel()
+
+	bw := backedpipe.NewBackedWriter()
+	bw.Write([]byte("hello world"))
+
+	// Create a writer that fails during replay
+	writer := &mockWriter{
+		writeFunc: func(p []byte) (int, error) {
+			return 0, xerrors.New("replay failed")
+		},
+	}
+
+	err := bw.Reconnect(0, writer)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "replay failed")
+	require.False(t, bw.Connected())
+}
+
+func TestBackedWriter_PartialWriteToUnderlying(t *testing.T) {
+	t.Parallel()
+
+	bw := backedpipe.NewBackedWriter()
+
+	// Create writer that does partial writes
+	writer := &mockWriter{
+		writeFunc: func(p []byte) (int, error) {
+			if len(p) > 3 {
+				return 3, nil // Only write first 3 bytes
+			}
+			return len(p), nil
+		},
+	}
+
+	bw.Reconnect(0, writer)
+
+	// Write should succeed to buffer but disconnect due to partial write
+	n, err := bw.Write([]byte("hello"))
+	require.NoError(t, err)
+	require.Equal(t, 5, n)
+	require.False(t, bw.Connected())
+
+	// The ring buffer still holds all five bytes for replay on the next reconnect
+}
+
+func BenchmarkBackedWriter_Write(b *testing.B) {
+	bw := backedpipe.NewBackedWriter() // 64MB default buffer
+	writer := newMockWriter()
+	bw.Reconnect(0, writer)
+
+	data := bytes.Repeat([]byte("x"), 1024) // 1KB writes
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		bw.Write(data)
+	}
+}
+
+func BenchmarkBackedWriter_Reconnect(b *testing.B) {
+	bw := backedpipe.NewBackedWriter()
+
+	// Fill buffer with data
+	data := bytes.Repeat([]byte("x"), 1024)
+	for i := 0; i < 32; i++ {
+		bw.Write(data)
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		writer := newMockWriter()
+		bw.Reconnect(0, writer)
+	}
+}
diff --git a/coderd/agentapi/backedpipe/ring_buffer.go b/coderd/agentapi/backedpipe/ring_buffer.go
new file mode 100644
index 0000000000000..f092385741e0c
--- /dev/null
+++ b/coderd/agentapi/backedpipe/ring_buffer.go
@@ -0,0 +1,140 @@
+package backedpipe
+
+import (
+	"sync"
+
+	"golang.org/x/xerrors"
+)
+
+// RingBuffer implements an efficient circular buffer with a fixed-size allocation.
+// It supports concurrent access and handles wrap-around seamlessly.
+// The buffer is designed for high-performance scenarios where avoiding
+// dynamic memory allocation during operation is critical.
+type RingBuffer struct {
+	mu     sync.RWMutex
+	buffer []byte
+	start  int // index of first valid byte
+	end    int // index after last valid byte
+	size   int // current number of bytes in buffer
+	cap    int // maximum capacity
+}
+
+// NewRingBuffer creates a new ring buffer with 64MB capacity.
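+// The capacity matches the BackedWriter's default buffer size, so a writer can
+// retain (and later replay) up to a full 64MB of unacknowledged bytes.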
+func NewRingBuffer() *RingBuffer {
+	const capacity = 64 * 1024 * 1024 // 64MB
+	return NewRingBufferWithCapacity(capacity)
+}
+
+// NewRingBufferWithCapacity creates a new ring buffer with the specified capacity.
+// If capacity is <= 0, it defaults to 64MB.
+func NewRingBufferWithCapacity(capacity int) *RingBuffer {
+	if capacity <= 0 {
+		capacity = 64 * 1024 * 1024 // Default to 64MB
+	}
+	return &RingBuffer{
+		buffer: make([]byte, capacity),
+		cap:    capacity,
+	}
+}
+
+// Write writes data to the ring buffer. If the buffer would overflow,
+// it evicts the oldest data to make room for new data.
+// Returns the number of bytes written and the number of bytes evicted.
+func (rb *RingBuffer) Write(data []byte) (written int, evicted int) {
+	if len(data) == 0 {
+		return 0, 0
+	}
+
+	rb.mu.Lock()
+	defer rb.mu.Unlock()
+
+	written = len(data)
+
+	// If data is larger than capacity, only keep the last capacity bytes.
+	// Anything already buffered is dropped as well, so it counts as evicted.
+	if len(data) > rb.cap {
+		trim := len(data) - rb.cap
+		evicted = trim + rb.size
+		data = data[trim:]
+		written = rb.cap
+		// Clear buffer and write new data
+		rb.start = 0
+		rb.end = 0
+		rb.size = 0
+	}
+
+	// Calculate how much we need to evict to fit new data
+	spaceNeeded := len(data)
+	availableSpace := rb.cap - rb.size
+
+	if spaceNeeded > availableSpace {
+		bytesToEvict := spaceNeeded - availableSpace
+		evicted += bytesToEvict
+		rb.evict(bytesToEvict)
+	}
+
+	// Write the data
+	for _, b := range data {
+		rb.buffer[rb.end] = b
+		rb.end = (rb.end + 1) % rb.cap
+		rb.size++
+	}
+
+	return written, evicted
+}
+
+// evict removes the specified number of bytes from the beginning of the buffer.
+// Must be called with lock held.
+func (rb *RingBuffer) evict(count int) {
+	if count >= rb.size {
+		// Evict everything
+		rb.start = 0
+		rb.end = 0
+		rb.size = 0
+		return
+	}
+
+	rb.start = (rb.start + count) % rb.cap
+	rb.size -= count
+}
+
+// ReadLast returns the last n bytes from the buffer.
+// If n is 0 or negative, it returns nil.
+// If the buffer is empty or holds fewer than n bytes, it returns an error
+// rather than silently returning fewer bytes.
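+//
+// A short example of the resulting semantics:
+//
+//	rb := NewRingBufferWithCapacity(8)
+//	rb.Write([]byte("hello"))
+//	last, _ := rb.ReadLast(3) // "llo"
+//	_, err := rb.ReadLast(6)  // error: only 5 bytes are buffered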
+func (rb *RingBuffer) ReadLast(n int) ([]byte, error) {
+	rb.mu.RLock()
+	defer rb.mu.RUnlock()
+
+	if n <= 0 {
+		return nil, nil
+	}
+
+	if rb.size == 0 {
+		return nil, xerrors.New("buffer is empty")
+	}
+
+	// If requested more than available, return error
+	if n > rb.size {
+		return nil, xerrors.Errorf("requested %d bytes but only %d available", n, rb.size)
+	}
+
+	result := make([]byte, n)
+
+	// Calculate where to start reading from (n bytes before the end).
+	// The constructor guarantees cap > 0, so the modulo is always safe.
+	startOffset := rb.size - n
+	actualStart := (rb.start + startOffset) % rb.cap
+
+	// Copy the last n bytes
+	if actualStart+n <= rb.cap {
+		// No wrap needed
+		copy(result, rb.buffer[actualStart:actualStart+n])
+	} else {
+		// Need to wrap around
+		firstChunk := rb.cap - actualStart
+		copy(result[0:firstChunk], rb.buffer[actualStart:rb.cap])
+		copy(result[firstChunk:], rb.buffer[0:n-firstChunk])
+	}
+
+	return result, nil
+}
diff --git a/coderd/agentapi/backedpipe/ring_buffer_internal_test.go b/coderd/agentapi/backedpipe/ring_buffer_internal_test.go
new file mode 100644
index 0000000000000..5a23880774057
--- /dev/null
+++ b/coderd/agentapi/backedpipe/ring_buffer_internal_test.go
@@ -0,0 +1,162 @@
+package backedpipe
+
+import (
+	"fmt"
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestRingBuffer_ClearInternal(t *testing.T) {
+	t.Parallel()
+
+	rb := NewRingBufferWithCapacity(10)
+	rb.Write([]byte("hello"))
+	require.Equal(t, 5, rb.size)
+
+	rb.Clear()
+	require.Equal(t, 0, rb.size)
+	require.Equal(t, "", rb.String())
+}
+
+func TestRingBuffer_Available(t *testing.T) {
+	t.Parallel()
+
+	rb := NewRingBufferWithCapacity(10)
+	require.Equal(t, 10, rb.Available())
+
+	rb.Write([]byte("hello"))
+	require.Equal(t, 5, rb.Available())
+
+	rb.Write([]byte("world"))
+	require.Equal(t, 0, rb.Available())
+}
+
+func TestRingBuffer_StringInternal(t *testing.T) {
+	t.Parallel()
+
+	rb := NewRingBufferWithCapacity(10)
+	require.Equal(t, "", rb.String())
+
+	rb.Write([]byte("hello"))
+	require.Equal(t, "hello", rb.String())
+
+	rb.Write([]byte("world"))
+	require.Equal(t, "helloworld", rb.String())
+}
+
+func TestRingBuffer_StringWithWrapAround(t *testing.T) {
+	t.Parallel()
+
+	rb := NewRingBufferWithCapacity(5)
+	rb.Write([]byte("hello"))
+	require.Equal(t, "hello", rb.String())
+
+	rb.Write([]byte("world"))
+	require.Equal(t, "world", rb.String())
+}
+
+func TestRingBuffer_ConcurrentAccessWithString(t *testing.T) {
+	t.Parallel()
+
+	rb := NewRingBufferWithCapacity(1000)
+	var wg sync.WaitGroup
+
+	// Start multiple goroutines writing
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func(id int) {
+			defer wg.Done()
+			data := fmt.Sprintf("data-%d", id)
+			for j := 0; j < 100; j++ {
+				rb.Write([]byte(data))
+			}
+		}(i)
+	}
+
+	wg.Wait()
+
+	// Verify buffer is still in valid state
+	require.NotEmpty(t, rb.String())
+}
+
+func TestRingBuffer_EdgeCaseEvictionWithString(t *testing.T) {
+	t.Parallel()
+
+	rb := NewRingBufferWithCapacity(3)
+	rb.Write([]byte("hello"))
+	rb.Write([]byte("world"))
+
+	// Capacity is only 3, so each oversized write keeps just its last 3 bytes
+	require.Equal(t, "rld", rb.String())
+
+	// Write more data to cause more eviction
+	rb.Write([]byte("test"))
+	require.Equal(t, "est", rb.String())
+}
+
+// TestRingBuffer_ComplexWrapAroundScenarioWithString tests complex wrap-around with String
+func TestRingBuffer_ComplexWrapAroundScenarioWithString(t *testing.T) {
+	t.Parallel()
+
+	rb := NewRingBufferWithCapacity(5)
+
+	// Fill buffer
+	rb.Write([]byte("abcde"))
+	require.Equal(t, "abcde", rb.String())
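+	// From here on, writes wrap around the 5-byte array; String() must still
+	// return the logical contents in order, oldest byte first.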
+
+	// Write more to cause wrap-around
+	rb.Write([]byte("fgh"))
+	require.Equal(t, "defgh", rb.String())
+
+	// Write even more
+	rb.Write([]byte("ijklmn"))
+	require.Equal(t, "jklmn", rb.String())
+}
+
+// Helper function to get available space (for internal tests only)
+func (rb *RingBuffer) Available() int {
+	rb.mu.RLock()
+	defer rb.mu.RUnlock()
+	return rb.cap - rb.size
+}
+
+// Helper function to clear buffer (for internal tests only)
+func (rb *RingBuffer) Clear() {
+	rb.mu.Lock()
+	defer rb.mu.Unlock()
+
+	rb.start = 0
+	rb.end = 0
+	rb.size = 0
+}
+
+// Helper function to get string representation (for internal tests only)
+func (rb *RingBuffer) String() string {
+	rb.mu.RLock()
+	defer rb.mu.RUnlock()
+
+	if rb.size == 0 {
+		return ""
+	}
+
+	result := make([]byte, rb.size)
+
+	if rb.start+rb.size <= rb.cap {
+		// No wrap needed
+		copy(result, rb.buffer[rb.start:rb.start+rb.size])
+	} else {
+		// Need to wrap around
+		firstChunk := rb.cap - rb.start
+		copy(result[0:firstChunk], rb.buffer[rb.start:rb.cap])
+		copy(result[firstChunk:], rb.buffer[0:rb.size-firstChunk])
+	}
+
+	return string(result)
+}
diff --git a/coderd/agentapi/backedpipe/ring_buffer_test.go b/coderd/agentapi/backedpipe/ring_buffer_test.go
new file mode 100644
index 0000000000000..3febbcb433e5a
--- /dev/null
+++ b/coderd/agentapi/backedpipe/ring_buffer_test.go
@@ -0,0 +1,326 @@
+package backedpipe_test
+
+import (
+	"bytes"
+	"fmt"
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/coder/coder/v2/coderd/agentapi/backedpipe"
+)
+
+func TestRingBuffer_NewRingBuffer(t *testing.T) {
+	t.Parallel()
+
+	rb := backedpipe.NewRingBufferWithCapacity(100)
+	// Test that we can write and read from the buffer
+	written, evicted := rb.Write([]byte("test"))
+	require.Equal(t, 4, written)
+	require.Equal(t, 0, evicted)
+
+	data, err := rb.ReadLast(4)
+	require.NoError(t, err)
+	require.Equal(t, []byte("test"), data)
+}
+
+func TestRingBuffer_WriteAndRead(t *testing.T) {
+	t.Parallel()
+
+	rb := backedpipe.NewRingBufferWithCapacity(10)
+
+	// Write some data
+	rb.Write([]byte("hello"))
+
+	// Read last 4 bytes
+	data, err := rb.ReadLast(4)
+	require.NoError(t, err)
+	require.Equal(t, "ello", string(data))
+
+	// Write more data
+	rb.Write([]byte("world"))
+
+	// Read last 5 bytes
+	data, err = rb.ReadLast(5)
+	require.NoError(t, err)
+	require.Equal(t, "world", string(data))
+
+	// Read last 3 bytes
+	data, err = rb.ReadLast(3)
+	require.NoError(t, err)
+	require.Equal(t, "rld", string(data))
+
+	// Read more than available (the buffer holds 10 bytes in total)
+	_, err = rb.ReadLast(15)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "requested 15 bytes but only")
+}
+
+func TestRingBuffer_OverflowEviction(t *testing.T) {
+	t.Parallel()
+
+	rb := backedpipe.NewRingBufferWithCapacity(5)
+
+	// Fill buffer
+	written, evicted := rb.Write([]byte("abcde"))
+	require.Equal(t, 5, written)
+	require.Equal(t, 0, evicted)
+
+	// Overflow should evict oldest data
+	written, evicted = rb.Write([]byte("fg"))
+	require.Equal(t, 2, written)
+	require.Equal(t, 2, evicted)
+
+	// Should now contain "cdefg"
+	data, err := rb.ReadLast(5)
+	require.NoError(t, err)
+	require.Equal(t, []byte("cdefg"), data)
+}
+
+func TestRingBuffer_LargeWrite(t *testing.T) {
+	t.Parallel()
+
+	rb := backedpipe.NewRingBufferWithCapacity(5)
+
+	// Write data larger than capacity
+	written, evicted := rb.Write([]byte("abcdefghij"))
+	
require.Equal(t, 5, written) + require.Equal(t, 5, evicted) + + // Should contain last 5 bytes + data, err := rb.ReadLast(5) + require.NoError(t, err) + require.Equal(t, []byte("fghij"), data) +} + +func TestRingBuffer_WrapAround(t *testing.T) { + t.Parallel() + + rb := backedpipe.NewRingBufferWithCapacity(5) + + // Fill buffer + rb.Write([]byte("abcde")) + + // Write more to cause wrap-around + rb.Write([]byte("fgh")) + + // Should contain "defgh" + data, err := rb.ReadLast(5) + require.NoError(t, err) + require.Equal(t, []byte("defgh"), data) + + // Test reading last 3 bytes after wrap + data, err = rb.ReadLast(3) + require.NoError(t, err) + require.Equal(t, []byte("fgh"), data) +} + +func TestRingBuffer_ReadLastEdgeCases(t *testing.T) { + t.Parallel() + + rb := backedpipe.NewRingBufferWithCapacity(3) + + // Write some data (5 bytes to a 3-byte buffer, so only last 3 bytes remain) + rb.Write([]byte("hello")) + + // Test reading negative count + data, err := rb.ReadLast(-1) + require.NoError(t, err) + require.Nil(t, data) + + // Test reading zero bytes + data, err = rb.ReadLast(0) + require.NoError(t, err) + require.Nil(t, data) + + // Test reading more than available (buffer has 3 bytes, try to read 10) + _, err = rb.ReadLast(10) + require.Error(t, err) + require.Contains(t, err.Error(), "requested 10 bytes but only 3 available") + + // Test reading exact amount available + data, err = rb.ReadLast(3) + require.NoError(t, err) + require.Equal(t, []byte("llo"), data) +} + +func TestRingBuffer_EmptyWrite(t *testing.T) { + t.Parallel() + + rb := backedpipe.NewRingBufferWithCapacity(10) + + // Write empty data + written, evicted := rb.Write([]byte{}) + require.Equal(t, 0, written) + require.Equal(t, 0, evicted) + + // Buffer should still be empty + _, err := rb.ReadLast(5) + require.Error(t, err) + require.Contains(t, err.Error(), "buffer is empty") +} + +func TestRingBuffer_ConcurrentAccess(t *testing.T) { + t.Parallel() + + rb := backedpipe.NewRingBufferWithCapacity(1000) + var wg sync.WaitGroup + + // Start multiple goroutines writing + for i := 0; i < 10; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + data := []byte(fmt.Sprintf("data-%d", id)) + for j := 0; j < 100; j++ { + rb.Write(data) + } + }(i) + } + + // Start multiple goroutines reading + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + _, err := rb.ReadLast(100) + if err != nil { + // Error is expected if buffer doesn't have enough data + continue + } + } + }() + } + + wg.Wait() + + // Verify buffer is still in valid state + data, err := rb.ReadLast(1000) + require.NoError(t, err) + require.NotNil(t, data) +} + +func TestRingBuffer_MultipleWrites(t *testing.T) { + t.Parallel() + + rb := backedpipe.NewRingBufferWithCapacity(10) + + // Write data in chunks + rb.Write([]byte("ab")) + rb.Write([]byte("cd")) + rb.Write([]byte("ef")) + + data, err := rb.ReadLast(6) + require.NoError(t, err) + require.Equal(t, []byte("abcdef"), data) + + // Test partial reads + data, err = rb.ReadLast(4) + require.NoError(t, err) + require.Equal(t, []byte("cdef"), data) + + data, err = rb.ReadLast(2) + require.NoError(t, err) + require.Equal(t, []byte("ef"), data) +} + +func TestRingBuffer_EdgeCaseEviction(t *testing.T) { + t.Parallel() + + rb := backedpipe.NewRingBufferWithCapacity(3) + + // Write data that will cause eviction + written, evicted := rb.Write([]byte("abc")) + require.Equal(t, 3, written) + require.Equal(t, 0, evicted) + + // Write more to cause eviction + written, evicted = 
rb.Write([]byte("d")) + require.Equal(t, 1, written) + require.Equal(t, 1, evicted) + + // Should now contain "bcd" + data, err := rb.ReadLast(3) + require.NoError(t, err) + require.Equal(t, []byte("bcd"), data) +} + +func TestRingBuffer_ComplexWrapAroundScenario(t *testing.T) { + t.Parallel() + + rb := backedpipe.NewRingBufferWithCapacity(8) + + // Fill buffer + rb.Write([]byte("12345678")) + + // Evict some and add more to create complex wrap scenario + rb.Write([]byte("abcd")) + data, err := rb.ReadLast(8) + require.NoError(t, err) + require.Equal(t, []byte("5678abcd"), data) + + // Add more + rb.Write([]byte("xyz")) + data, err = rb.ReadLast(8) + require.NoError(t, err) + require.Equal(t, []byte("8abcdxyz"), data) + + // Test reading various amounts from the end + data, err = rb.ReadLast(7) + require.NoError(t, err) + require.Equal(t, []byte("abcdxyz"), data) + + data, err = rb.ReadLast(4) + require.NoError(t, err) + require.Equal(t, []byte("dxyz"), data) +} + +// Benchmark tests for performance validation +func BenchmarkRingBuffer_Write(b *testing.B) { + rb := backedpipe.NewRingBuffer() // Use full 64MB for benchmarks + data := bytes.Repeat([]byte("x"), 1024) // 1KB writes + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rb.Write(data) + } +} + +func BenchmarkRingBuffer_ReadLast(b *testing.B) { + rb := backedpipe.NewRingBuffer() // Use full 64MB for benchmarks + // Fill buffer with test data + for i := 0; i < 64; i++ { + rb.Write(bytes.Repeat([]byte("x"), 1024)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := rb.ReadLast((i % 100) + 1) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkRingBuffer_ConcurrentAccess(b *testing.B) { + rb := backedpipe.NewRingBuffer() // Use full 64MB for benchmarks + data := bytes.Repeat([]byte("x"), 100) + + // Pre-fill buffer with enough data + for i := 0; i < 100; i++ { + rb.Write(data) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + rb.Write(data) + _, err := rb.ReadLast(100) // Read only what we know is available + if err != nil { + b.Fatal(err) + } + } + }) +} diff --git a/testutil/duration.go b/testutil/duration.go index a8c35030cdea2..821684f6b0f98 100644 --- a/testutil/duration.go +++ b/testutil/duration.go @@ -7,10 +7,11 @@ import ( // Constants for timing out operations, usable for creating contexts // that timeout or in require.Eventually. const ( - WaitShort = 10 * time.Second - WaitMedium = 15 * time.Second - WaitLong = 25 * time.Second - WaitSuperLong = 60 * time.Second + WaitSuperShort = 100 * time.Millisecond + WaitShort = 10 * time.Second + WaitMedium = 15 * time.Second + WaitLong = 25 * time.Second + WaitSuperLong = 60 * time.Second ) // Constants for delaying repeated operations, e.g. 
in From f77d7bf08007b4c315d4876c9db1a16843bf8386 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Tue, 12 Aug 2025 12:34:52 +0000 Subject: [PATCH 02/11] improvement to backed pipe concurrency and tests --- coderd/agentapi/backedpipe/backed_pipe.go | 43 +++++++++++++------ .../agentapi/backedpipe/backed_pipe_test.go | 25 +++++++++-- coderd/agentapi/backedpipe/backed_reader.go | 25 +++++++---- .../agentapi/backedpipe/ring_buffer_test.go | 13 ++++++ 4 files changed, 81 insertions(+), 25 deletions(-) diff --git a/coderd/agentapi/backedpipe/backed_pipe.go b/coderd/agentapi/backedpipe/backed_pipe.go index 784b8d55353b1..a06dac4a22604 100644 --- a/coderd/agentapi/backedpipe/backed_pipe.go +++ b/coderd/agentapi/backedpipe/backed_pipe.go @@ -4,7 +4,9 @@ import ( "context" "io" "sync" + "time" + "golang.org/x/sync/singleflight" "golang.org/x/xerrors" ) @@ -48,6 +50,9 @@ type BackedPipe struct { // Connection state notification connectionChanged chan struct{} + + // singleflight group to dedupe concurrent ForceReconnect calls + sf singleflight.Group } // NewBackedPipe creates a new BackedPipe with default options and the specified reconnect function. @@ -87,7 +92,7 @@ func NewBackedPipe(ctx context.Context, reconnectFn ReconnectFunc) *BackedPipe { } // Connect establishes the initial connection using the reconnect function. -func (bp *BackedPipe) Connect(ctx context.Context) error { +func (bp *BackedPipe) Connect(_ context.Context) error { // external ctx ignored; internal ctx used bp.mu.Lock() defer bp.mu.Unlock() @@ -99,7 +104,9 @@ func (bp *BackedPipe) Connect(ctx context.Context) error { return xerrors.New("pipe is already connected") } - return bp.reconnectLocked(ctx) + // Use internal context for the actual reconnect operation to ensure + // Close() reliably cancels any in-flight attempt. + return bp.reconnectLocked() } // Read implements io.Reader by delegating to the BackedReader. @@ -188,7 +195,7 @@ func (bp *BackedPipe) signalConnectionChange() { } // reconnectLocked handles the reconnection logic. Must be called with write lock held. -func (bp *BackedPipe) reconnectLocked(ctx context.Context) error { +func (bp *BackedPipe) reconnectLocked() error { if bp.reconnecting { return xerrors.New("reconnection already in progress") } @@ -212,7 +219,7 @@ func (bp *BackedPipe) reconnectLocked(ctx context.Context) error { // Unlock during reconnect attempt to avoid blocking reads/writes bp.mu.Unlock() - conn, readerSeqNum, err := bp.reconnectFn(ctx, writerSeqNum) + conn, readerSeqNum, err := bp.reconnectFn(bp.ctx, writerSeqNum) bp.mu.Lock() if err != nil { @@ -280,8 +287,8 @@ func (bp *BackedPipe) handleErrors() { bp.connected = false bp.signalConnectionChange() - // Try to reconnect - reconnectErr := bp.reconnectLocked(bp.ctx) + // Try to reconnect using internal context + reconnectErr := bp.reconnectLocked() bp.mu.Unlock() if reconnectErr != nil { @@ -314,19 +321,27 @@ func (bp *BackedPipe) WaitForConnection(ctx context.Context) error { return ctx.Err() case <-bp.connectionChanged: // Connection state changed, check again + case <-time.After(10 * time.Millisecond): + // Periodically re-check to avoid missed notifications } } } // ForceReconnect forces a reconnection attempt immediately. // This can be used to force a reconnection if a new connection is established. 
-func (bp *BackedPipe) ForceReconnect(ctx context.Context) error { - bp.mu.Lock() - defer bp.mu.Unlock() - - if bp.closed { - return io.ErrClosedPipe - } +func (bp *BackedPipe) ForceReconnect() error { + // Deduplicate concurrent ForceReconnect calls so only one reconnection + // attempt runs at a time from this API. Use the pipe's internal context + // to ensure Close() cancels any in-flight attempt. + _, err, _ := bp.sf.Do("backedpipe-reconnect", func() (interface{}, error) { + bp.mu.Lock() + defer bp.mu.Unlock() + + if bp.closed { + return nil, io.ErrClosedPipe + } - return bp.reconnectLocked(ctx) + return nil, bp.reconnectLocked() + }) + return err } diff --git a/coderd/agentapi/backedpipe/backed_pipe_test.go b/coderd/agentapi/backedpipe/backed_pipe_test.go index c841112ed07e1..5b345a325fd56 100644 --- a/coderd/agentapi/backedpipe/backed_pipe_test.go +++ b/coderd/agentapi/backedpipe/backed_pipe_test.go @@ -152,6 +152,7 @@ func TestBackedPipe_NewBackedPipe(t *testing.T) { reconnectFn, _, _ := mockReconnectFunc(newMockConnection()) bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() require.NotNil(t, bp) require.False(t, bp.Connected()) } @@ -164,6 +165,7 @@ func TestBackedPipe_Connect(t *testing.T) { reconnectFn, callCount, _ := mockReconnectFunc(conn) bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() err := bp.Connect(ctx) require.NoError(t, err) @@ -179,6 +181,7 @@ func TestBackedPipe_ConnectAlreadyConnected(t *testing.T) { reconnectFn, _, _ := mockReconnectFunc(conn) bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() err := bp.Connect(ctx) require.NoError(t, err) @@ -214,6 +217,7 @@ func TestBackedPipe_BasicReadWrite(t *testing.T) { reconnectFn, _, _ := mockReconnectFunc(conn) bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() err := bp.Connect(ctx) require.NoError(t, err) @@ -242,6 +246,7 @@ func TestBackedPipe_WriteBeforeConnect(t *testing.T) { reconnectFn, _, _ := mockReconnectFunc(conn) bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() // Write before connecting should succeed (buffered) n, err := bp.Write([]byte("hello")) @@ -263,6 +268,7 @@ func TestBackedPipe_ReadBlocksWhenDisconnected(t *testing.T) { reconnectFn, _, _ := mockReconnectFunc(newMockConnection()) bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() // Start a read that should block readDone := make(chan struct{}) @@ -311,6 +317,7 @@ func TestBackedPipe_Reconnection(t *testing.T) { reconnectFn, _, signalChan := mockReconnectFunc(conn1, conn2) bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() // Initial connect err := bp.Connect(ctx) @@ -392,6 +399,7 @@ func TestBackedPipe_WaitForConnection(t *testing.T) { reconnectFn, _, _ := mockReconnectFunc(conn) bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() // Should timeout when not connected // Use a shorter timeout for this test to speed up test runs @@ -426,6 +434,7 @@ func TestBackedPipe_ConcurrentReadWrite(t *testing.T) { reconnectFn, _, _ := mockReconnectFunc(conn) bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() err := bp.Connect(ctx) require.NoError(t, err) @@ -514,6 +523,7 @@ func TestBackedPipe_ReconnectFunctionFailure(t *testing.T) { } bp := backedpipe.NewBackedPipe(ctx, failingReconnectFn) + defer bp.Close() err := bp.Connect(ctx) require.Error(t, err) @@ -530,6 +540,7 @@ func TestBackedPipe_ForceReconnect(t *testing.T) { reconnectFn, callCount, _ := mockReconnectFunc(conn1, conn2) bp := 
backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() // Initial connect err := bp.Connect(ctx) @@ -543,7 +554,7 @@ func TestBackedPipe_ForceReconnect(t *testing.T) { require.Equal(t, "test data", conn1.ReadString()) // Force a reconnection - err = bp.ForceReconnect(ctx) + err = bp.ForceReconnect() require.NoError(t, err) require.True(t, bp.Connected()) require.Equal(t, 2, *callCount) @@ -580,7 +591,7 @@ func TestBackedPipe_ForceReconnectWhenClosed(t *testing.T) { require.NoError(t, err) // Try to force reconnect when closed - err = bp.ForceReconnect(ctx) + err = bp.ForceReconnect() require.Error(t, err) require.Equal(t, io.ErrClosedPipe, err) } @@ -593,9 +604,10 @@ func TestBackedPipe_ForceReconnectWhenDisconnected(t *testing.T) { reconnectFn, callCount, _ := mockReconnectFunc(conn) bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() // Don't connect initially, just force reconnect - err := bp.ForceReconnect(ctx) + err := bp.ForceReconnect() require.NoError(t, err) require.True(t, bp.Connected()) require.Equal(t, 1, *callCount) @@ -655,6 +667,7 @@ func TestBackedPipe_EOFTriggersReconnection(t *testing.T) { } bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() // Initial connect err := bp.Connect(ctx) @@ -699,6 +712,9 @@ func BenchmarkBackedPipe_Write(b *testing.B) { bp := backedpipe.NewBackedPipe(ctx, reconnectFn) bp.Connect(ctx) + b.Cleanup(func() { + _ = bp.Close() + }) data := make([]byte, 1024) // 1KB writes @@ -715,6 +731,9 @@ func BenchmarkBackedPipe_Read(b *testing.B) { bp := backedpipe.NewBackedPipe(ctx, reconnectFn) bp.Connect(ctx) + b.Cleanup(func() { + _ = bp.Close() + }) buf := make([]byte, 1024) diff --git a/coderd/agentapi/backedpipe/backed_reader.go b/coderd/agentapi/backedpipe/backed_reader.go index 7b57986638065..1ea93cd3c39b4 100644 --- a/coderd/agentapi/backedpipe/backed_reader.go +++ b/coderd/agentapi/backedpipe/backed_reader.go @@ -34,41 +34,50 @@ func NewBackedReader() *BackedReader { // When connected, it reads from the underlying reader and updates sequence numbers. // Connection failures are automatically detected and reported to the higher layer via callback. func (br *BackedReader) Read(p []byte) (int, error) { - br.mu.Lock() - defer br.mu.Unlock() - for { + // Step 1: Wait until we have a reader or are closed + br.mu.Lock() for br.reader == nil && !br.closed { br.cond.Wait() } - // Check if closed if br.closed { + br.mu.Unlock() return 0, io.ErrClosedPipe } + // Capture the current reader and release the lock while performing + // the potentially blocking I/O operation to avoid deadlocks with Close(). 
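+	// A concurrent Reconnect may swap br.reader while this read is in flight,
+	// so operate on a local copy rather than re-reading the field.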
+ r := br.reader br.mu.Unlock() - n, err := br.reader.Read(p) - br.mu.Lock() + // Step 2: Perform the read without holding the mutex + n, err := r.Read(p) + + // Step 3: Reacquire the lock to update state based on the result + br.mu.Lock() if err == nil { br.sequenceNum += uint64(n) // #nosec G115 -- n is always >= 0 per io.Reader contract + br.mu.Unlock() return n, nil } + // Mark disconnected so future reads will wait for reconnection br.reader = nil if br.onError != nil { br.onError(err) } - // If we got some data before the error, return it + // If we got some data before the error, return it now if n > 0 { br.sequenceNum += uint64(n) + br.mu.Unlock() return n, nil } - // Return to Step 2 (continue the loop) + // Otherwise loop and wait for reconnection or close + br.mu.Unlock() } } diff --git a/coderd/agentapi/backedpipe/ring_buffer_test.go b/coderd/agentapi/backedpipe/ring_buffer_test.go index 3febbcb433e5a..8bfe6af82ad56 100644 --- a/coderd/agentapi/backedpipe/ring_buffer_test.go +++ b/coderd/agentapi/backedpipe/ring_buffer_test.go @@ -3,14 +3,27 @@ package backedpipe_test import ( "bytes" "fmt" + "os" + "runtime" "sync" "testing" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "github.com/coder/coder/v2/coderd/agentapi/backedpipe" + "github.com/coder/coder/v2/testutil" ) +func TestMain(m *testing.M) { + if runtime.GOOS == "windows" { + // Don't run goleak on windows tests, they're super flaky right now. + // See: https://github.com/coder/coder/issues/8954 + os.Exit(m.Run()) + } + goleak.VerifyTestMain(m, testutil.GoleakOptions...) +} + func TestRingBuffer_NewRingBuffer(t *testing.T) { t.Parallel() From 086fbf846a14c82355ea65de7271a4a8bcf2fd6c Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Tue, 12 Aug 2025 15:33:36 +0000 Subject: [PATCH 03/11] chore: moved backedpipe to be a part of the agent code --- .../immortalstreams}/backedpipe/backed_pipe.go | 0 .../immortalstreams}/backedpipe/backed_pipe_test.go | 2 +- .../immortalstreams}/backedpipe/backed_reader.go | 0 .../immortalstreams}/backedpipe/backed_reader_test.go | 2 +- .../immortalstreams}/backedpipe/backed_writer.go | 0 .../immortalstreams}/backedpipe/backed_writer_test.go | 2 +- .../immortalstreams}/backedpipe/ring_buffer.go | 0 .../immortalstreams}/backedpipe/ring_buffer_internal_test.go | 0 .../immortalstreams}/backedpipe/ring_buffer_test.go | 2 +- 9 files changed, 4 insertions(+), 4 deletions(-) rename {coderd/agentapi => agent/immortalstreams}/backedpipe/backed_pipe.go (100%) rename {coderd/agentapi => agent/immortalstreams}/backedpipe/backed_pipe_test.go (99%) rename {coderd/agentapi => agent/immortalstreams}/backedpipe/backed_reader.go (100%) rename {coderd/agentapi => agent/immortalstreams}/backedpipe/backed_reader_test.go (99%) rename {coderd/agentapi => agent/immortalstreams}/backedpipe/backed_writer.go (100%) rename {coderd/agentapi => agent/immortalstreams}/backedpipe/backed_writer_test.go (99%) rename {coderd/agentapi => agent/immortalstreams}/backedpipe/ring_buffer.go (100%) rename {coderd/agentapi => agent/immortalstreams}/backedpipe/ring_buffer_internal_test.go (100%) rename {coderd/agentapi => agent/immortalstreams}/backedpipe/ring_buffer_test.go (99%) diff --git a/coderd/agentapi/backedpipe/backed_pipe.go b/agent/immortalstreams/backedpipe/backed_pipe.go similarity index 100% rename from coderd/agentapi/backedpipe/backed_pipe.go rename to agent/immortalstreams/backedpipe/backed_pipe.go diff --git a/coderd/agentapi/backedpipe/backed_pipe_test.go 
b/agent/immortalstreams/backedpipe/backed_pipe_test.go similarity index 99% rename from coderd/agentapi/backedpipe/backed_pipe_test.go rename to agent/immortalstreams/backedpipe/backed_pipe_test.go index 5b345a325fd56..be78e8b896be5 100644 --- a/coderd/agentapi/backedpipe/backed_pipe_test.go +++ b/agent/immortalstreams/backedpipe/backed_pipe_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/xerrors" - "github.com/coder/coder/v2/coderd/agentapi/backedpipe" + "github.com/coder/coder/v2/agent/immortalstreams/backedpipe" "github.com/coder/coder/v2/testutil" ) diff --git a/coderd/agentapi/backedpipe/backed_reader.go b/agent/immortalstreams/backedpipe/backed_reader.go similarity index 100% rename from coderd/agentapi/backedpipe/backed_reader.go rename to agent/immortalstreams/backedpipe/backed_reader.go diff --git a/coderd/agentapi/backedpipe/backed_reader_test.go b/agent/immortalstreams/backedpipe/backed_reader_test.go similarity index 99% rename from coderd/agentapi/backedpipe/backed_reader_test.go rename to agent/immortalstreams/backedpipe/backed_reader_test.go index 810abb7c64bd6..a16f1d5ecb4af 100644 --- a/coderd/agentapi/backedpipe/backed_reader_test.go +++ b/agent/immortalstreams/backedpipe/backed_reader_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/xerrors" - "github.com/coder/coder/v2/coderd/agentapi/backedpipe" + "github.com/coder/coder/v2/agent/immortalstreams/backedpipe" ) // mockReader implements io.Reader with controllable behavior for testing diff --git a/coderd/agentapi/backedpipe/backed_writer.go b/agent/immortalstreams/backedpipe/backed_writer.go similarity index 100% rename from coderd/agentapi/backedpipe/backed_writer.go rename to agent/immortalstreams/backedpipe/backed_writer.go diff --git a/coderd/agentapi/backedpipe/backed_writer_test.go b/agent/immortalstreams/backedpipe/backed_writer_test.go similarity index 99% rename from coderd/agentapi/backedpipe/backed_writer_test.go rename to agent/immortalstreams/backedpipe/backed_writer_test.go index f92a79c6f366b..18d662e92ea0b 100644 --- a/coderd/agentapi/backedpipe/backed_writer_test.go +++ b/agent/immortalstreams/backedpipe/backed_writer_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/xerrors" - "github.com/coder/coder/v2/coderd/agentapi/backedpipe" + "github.com/coder/coder/v2/agent/immortalstreams/backedpipe" "github.com/coder/coder/v2/testutil" ) diff --git a/coderd/agentapi/backedpipe/ring_buffer.go b/agent/immortalstreams/backedpipe/ring_buffer.go similarity index 100% rename from coderd/agentapi/backedpipe/ring_buffer.go rename to agent/immortalstreams/backedpipe/ring_buffer.go diff --git a/coderd/agentapi/backedpipe/ring_buffer_internal_test.go b/agent/immortalstreams/backedpipe/ring_buffer_internal_test.go similarity index 100% rename from coderd/agentapi/backedpipe/ring_buffer_internal_test.go rename to agent/immortalstreams/backedpipe/ring_buffer_internal_test.go diff --git a/coderd/agentapi/backedpipe/ring_buffer_test.go b/agent/immortalstreams/backedpipe/ring_buffer_test.go similarity index 99% rename from coderd/agentapi/backedpipe/ring_buffer_test.go rename to agent/immortalstreams/backedpipe/ring_buffer_test.go index 8bfe6af82ad56..0d9b4a12dc947 100644 --- a/coderd/agentapi/backedpipe/ring_buffer_test.go +++ b/agent/immortalstreams/backedpipe/ring_buffer_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/goleak" - 
"github.com/coder/coder/v2/coderd/agentapi/backedpipe" + "github.com/coder/coder/v2/agent/immortalstreams/backedpipe" "github.com/coder/coder/v2/testutil" ) From d82da56dfcd370c96bfd69998cd2dc1f657e34ff Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Wed, 13 Aug 2025 14:51:56 +0000 Subject: [PATCH 04/11] PR feedback implemented --- .../immortalstreams/backedpipe/backed_pipe.go | 54 +-- .../backedpipe/backed_pipe_test.go | 90 ++--- .../backedpipe/backed_reader.go | 36 +- .../backedpipe/backed_reader_test.go | 327 ++++++++++----- .../backedpipe/backed_writer.go | 167 +++----- .../backedpipe/backed_writer_test.go | 375 +++++++++++++----- .../immortalstreams/backedpipe/ring_buffer.go | 133 +++---- .../backedpipe/ring_buffer_internal_test.go | 298 +++++++++----- .../backedpipe/ring_buffer_test.go | 339 ---------------- 9 files changed, 880 insertions(+), 939 deletions(-) delete mode 100644 agent/immortalstreams/backedpipe/ring_buffer_test.go diff --git a/agent/immortalstreams/backedpipe/backed_pipe.go b/agent/immortalstreams/backedpipe/backed_pipe.go index a06dac4a22604..8161b91533233 100644 --- a/agent/immortalstreams/backedpipe/backed_pipe.go +++ b/agent/immortalstreams/backedpipe/backed_pipe.go @@ -4,14 +4,13 @@ import ( "context" "io" "sync" - "time" "golang.org/x/sync/singleflight" "golang.org/x/xerrors" ) const ( - // DefaultBufferSize is the default buffer size for the BackedWriter (64MB) + // Default buffer capacity used by the writer - 64MB DefaultBufferSize = 64 * 1024 * 1024 ) @@ -60,17 +59,18 @@ type BackedPipe struct { func NewBackedPipe(ctx context.Context, reconnectFn ReconnectFunc) *BackedPipe { pipeCtx, cancel := context.WithCancel(ctx) + errorChan := make(chan error, 2) // Buffer for reader and writer errors bp := &BackedPipe{ ctx: pipeCtx, cancel: cancel, reader: NewBackedReader(), - writer: NewBackedWriterWithCapacity(DefaultBufferSize), // 64MB default buffer + writer: NewBackedWriter(DefaultBufferSize, errorChan), reconnectFn: reconnectFn, - errorChan: make(chan error, 2), // Buffer for reader and writer errors + errorChan: errorChan, connectionChanged: make(chan struct{}, 1), } - // Set up error callbacks + // Set up error callback for reader only (writer uses error channel directly) bp.reader.SetErrorCallback(func(err error) { select { case bp.errorChan <- err: @@ -78,13 +78,6 @@ func NewBackedPipe(ctx context.Context, reconnectFn ReconnectFunc) *BackedPipe { } }) - bp.writer.SetErrorCallback(func(err error) { - select { - case bp.errorChan <- err: - case <-bp.ctx.Done(): - } - }) - // Start error handler goroutine go bp.handleErrors() @@ -233,16 +226,6 @@ func (bp *BackedPipe) reconnectLocked() error { readerSeqNum, writerSeqNum) } - // Validate writer can replay from the requested sequence - if !bp.writer.CanReplayFrom(readerSeqNum) { - _ = conn.Close() - // Calculate data loss - currentSeq := bp.writer.SequenceNum() - dataLoss := currentSeq - DefaultBufferSize - readerSeqNum - return xerrors.Errorf("cannot replay from sequence %d (current: %d, data loss: ~%d bytes)", - readerSeqNum, currentSeq, dataLoss) - } - // Reconnect reader and writer seqNum := make(chan uint64, 1) newR := make(chan io.Reader, 1) @@ -300,33 +283,6 @@ func (bp *BackedPipe) handleErrors() { } } -// WaitForConnection blocks until the pipe is connected or the context is canceled. 
-func (bp *BackedPipe) WaitForConnection(ctx context.Context) error { - for { - bp.mu.RLock() - connected := bp.connected - closed := bp.closed - bp.mu.RUnlock() - - if closed { - return io.ErrClosedPipe - } - - if connected { - return nil - } - - select { - case <-ctx.Done(): - return ctx.Err() - case <-bp.connectionChanged: - // Connection state changed, check again - case <-time.After(10 * time.Millisecond): - // Periodically re-check to avoid missed notifications - } - } -} - // ForceReconnect forces a reconnection attempt immediately. // This can be used to force a reconnection if a new connection is established. func (bp *BackedPipe) ForceReconnect() error { diff --git a/agent/immortalstreams/backedpipe/backed_pipe_test.go b/agent/immortalstreams/backedpipe/backed_pipe_test.go index be78e8b896be5..1cd5bd227ebcd 100644 --- a/agent/immortalstreams/backedpipe/backed_pipe_test.go +++ b/agent/immortalstreams/backedpipe/backed_pipe_test.go @@ -240,21 +240,35 @@ func TestBackedPipe_BasicReadWrite(t *testing.T) { func TestBackedPipe_WriteBeforeConnect(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) - ctx := context.Background() conn := newMockConnection() reconnectFn, _, _ := mockReconnectFunc(conn) bp := backedpipe.NewBackedPipe(ctx, reconnectFn) defer bp.Close() - // Write before connecting should succeed (buffered) - n, err := bp.Write([]byte("hello")) + // Write before connecting should block + writeComplete := make(chan error, 1) + go func() { + _, err := bp.Write([]byte("hello")) + writeComplete <- err + }() + + // Verify write is blocked + select { + case <-writeComplete: + t.Fatal("Write should have blocked when disconnected") + case <-time.After(100 * time.Millisecond): + // Expected - write is blocked + } + + // Connect should unblock the write + err := bp.Connect(ctx) require.NoError(t, err) - require.Equal(t, 5, n) - // Connect should replay the buffered data - err = bp.Connect(ctx) + // Write should now complete + err = testutil.RequireReceive(ctx, t, writeComplete) require.NoError(t, err) // Check that data was replayed to connection @@ -265,6 +279,7 @@ func TestBackedPipe_ReadBlocksWhenDisconnected(t *testing.T) { t.Parallel() ctx := context.Background() + testCtx := testutil.Context(t, testutil.WaitShort) reconnectFn, _, _ := mockReconnectFunc(newMockConnection()) bp := backedpipe.NewBackedPipe(ctx, reconnectFn) @@ -283,7 +298,7 @@ func TestBackedPipe_ReadBlocksWhenDisconnected(t *testing.T) { }() // Wait for the goroutine to start - <-readStarted + testutil.TryReceive(testCtx, t, readStarted) // Give a brief moment for the read to actually block time.Sleep(time.Millisecond) @@ -299,18 +314,15 @@ func TestBackedPipe_ReadBlocksWhenDisconnected(t *testing.T) { // Close should unblock the read bp.Close() - select { - case <-readDone: - require.Equal(t, io.ErrClosedPipe, readErr) - case <-time.After(time.Second): - t.Fatal("Read did not unblock after close") - } + testutil.TryReceive(testCtx, t, readDone) + require.Equal(t, io.EOF, readErr) } func TestBackedPipe_Reconnection(t *testing.T) { t.Parallel() ctx := context.Background() + testCtx := testutil.Context(t, testutil.WaitShort) conn1 := newMockConnection() conn2 := newMockConnection() conn2.seqNum = 17 // Remote has received 17 bytes, so replay from sequence 17 @@ -333,10 +345,12 @@ func TestBackedPipe_Reconnection(t *testing.T) { // Trigger a write to cause the pipe to notice the failure _, _ = bp.Write([]byte("trigger failure ")) - <-signalChan + testutil.RequireReceive(testCtx, t, signalChan) 
- err = bp.WaitForConnection(ctx) - require.NoError(t, err) + // Wait for reconnection to complete + require.Eventually(t, func() bool { + return bp.Connected() + }, testutil.WaitShort, testutil.IntervalFast, "pipe should reconnect") replayedData := conn2.ReadString() require.Equal(t, "***trigger failure ", replayedData, "Should replay exactly the data written after sequence 17") @@ -391,45 +405,10 @@ func TestBackedPipe_CloseIdempotent(t *testing.T) { require.NoError(t, err) } -func TestBackedPipe_WaitForConnection(t *testing.T) { - t.Parallel() - - ctx := context.Background() - conn := newMockConnection() - reconnectFn, _, _ := mockReconnectFunc(conn) - - bp := backedpipe.NewBackedPipe(ctx, reconnectFn) - defer bp.Close() - - // Should timeout when not connected - // Use a shorter timeout for this test to speed up test runs - timeoutCtx, cancel := context.WithTimeout(ctx, testutil.WaitSuperShort) - defer cancel() - - err := bp.WaitForConnection(timeoutCtx) - require.Equal(t, context.DeadlineExceeded, err) - - // Connect in background after a brief delay - connectionStarted := make(chan struct{}) - go func() { - close(connectionStarted) - // Small delay to ensure WaitForConnection is called first - time.Sleep(time.Millisecond) - bp.Connect(context.Background()) - }() - - // Wait for connection goroutine to start - <-connectionStarted - - // Should succeed once connected - err = bp.WaitForConnection(context.Background()) - require.NoError(t, err) -} - func TestBackedPipe_ConcurrentReadWrite(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) - ctx := context.Background() conn := newMockConnection() reconnectFn, _, _ := mockReconnectFunc(conn) @@ -487,12 +466,7 @@ func TestBackedPipe_ConcurrentReadWrite(t *testing.T) { wg.Wait() }() - select { - case <-done: - // Success - case <-time.After(5 * time.Second): - t.Fatal("Test timed out") - } + testutil.TryReceive(ctx, t, done) // Close the channel and collect all written data close(writtenData) diff --git a/agent/immortalstreams/backedpipe/backed_reader.go b/agent/immortalstreams/backedpipe/backed_reader.go index 1ea93cd3c39b4..4632cafc92e8f 100644 --- a/agent/immortalstreams/backedpipe/backed_reader.go +++ b/agent/immortalstreams/backedpipe/backed_reader.go @@ -34,35 +34,29 @@ func NewBackedReader() *BackedReader { // When connected, it reads from the underlying reader and updates sequence numbers. // Connection failures are automatically detected and reported to the higher layer via callback. func (br *BackedReader) Read(p []byte) (int, error) { + br.mu.Lock() + defer br.mu.Unlock() + for { // Step 1: Wait until we have a reader or are closed - br.mu.Lock() for br.reader == nil && !br.closed { br.cond.Wait() } if br.closed { - br.mu.Unlock() - return 0, io.ErrClosedPipe + return 0, io.EOF } - // Capture the current reader and release the lock while performing - // the potentially blocking I/O operation to avoid deadlocks with Close(). 
- r := br.reader - br.mu.Unlock() + // Step 2: Perform the read while holding the mutex + // This ensures proper synchronization with Reconnect and Close operations + n, err := br.reader.Read(p) + br.sequenceNum += uint64(n) // #nosec G115 -- n is always >= 0 per io.Reader contract - // Step 2: Perform the read without holding the mutex - n, err := r.Read(p) - - // Step 3: Reacquire the lock to update state based on the result - br.mu.Lock() if err == nil { - br.sequenceNum += uint64(n) // #nosec G115 -- n is always >= 0 per io.Reader contract - br.mu.Unlock() return n, nil } - // Mark disconnected so future reads will wait for reconnection + // Mark reader as disconnected so future reads will wait for reconnection br.reader = nil if br.onError != nil { @@ -71,13 +65,8 @@ func (br *BackedReader) Read(p []byte) (int, error) { // If we got some data before the error, return it now if n > 0 { - br.sequenceNum += uint64(n) - br.mu.Unlock() return n, nil } - - // Otherwise loop and wait for reconnection or close - br.mu.Unlock() } } @@ -91,8 +80,7 @@ func (br *BackedReader) Reconnect(seqNum chan<- uint64, newR <-chan io.Reader) { defer br.mu.Unlock() if br.closed { - // Send 0 sequence number and close the channel to indicate closed state - seqNum <- 0 + // Close the channel to indicate closed state close(seqNum) return } @@ -117,8 +105,8 @@ func (br *BackedReader) Reconnect(seqNum chan<- uint64, newR <-chan io.Reader) { br.cond.Broadcast() } -// Closes the reader and wakes up any blocked reads. -// After closing, all Read calls will return io.ErrClosedPipe. +// Close the reader and wake up any blocked reads. +// After closing, all Read calls will return io.EOF. func (br *BackedReader) Close() error { br.mu.Lock() defer br.mu.Unlock() diff --git a/agent/immortalstreams/backedpipe/backed_reader_test.go b/agent/immortalstreams/backedpipe/backed_reader_test.go index a16f1d5ecb4af..25d2038d6d843 100644 --- a/agent/immortalstreams/backedpipe/backed_reader_test.go +++ b/agent/immortalstreams/backedpipe/backed_reader_test.go @@ -1,8 +1,8 @@ package backedpipe_test import ( + "context" "io" - "strings" "sync" "testing" "time" @@ -12,6 +12,7 @@ import ( "golang.org/x/xerrors" "github.com/coder/coder/v2/agent/immortalstreams/backedpipe" + "github.com/coder/coder/v2/testutil" ) // mockReader implements io.Reader with controllable behavior for testing @@ -65,6 +66,7 @@ func TestBackedReader_NewBackedReader(t *testing.T) { func TestBackedReader_BasicReadOperation(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) br := backedpipe.NewBackedReader() reader := newMockReader("hello world") @@ -76,11 +78,11 @@ func TestBackedReader_BasicReadOperation(t *testing.T) { go br.Reconnect(seqNum, newR) // Get sequence number from reader - seq := <-seqNum + seq := testutil.RequireReceive(ctx, t, seqNum) assert.Equal(t, uint64(0), seq) // Send new reader - newR <- reader + testutil.RequireSend(ctx, t, newR, io.Reader(reader)) // Read data buf := make([]byte, 5) @@ -100,26 +102,24 @@ func TestBackedReader_BasicReadOperation(t *testing.T) { func TestBackedReader_ReadBlocksWhenDisconnected(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) br := backedpipe.NewBackedReader() // Start a read operation that should block readDone := make(chan struct{}) - readStarted := make(chan struct{}) var readErr error + var readBuf []byte + var readN int go func() { defer close(readDone) - close(readStarted) // Signal that we're about to start the read buf := make([]byte, 10) - _, 
readErr = br.Read(buf) + readN, readErr = br.Read(buf) + readBuf = buf[:readN] }() - // Wait for the goroutine to start - <-readStarted - - // Give a brief moment for the read to actually block on the condition variable - // This is much shorter and more deterministic than the previous approach + // Give a brief moment for the read to actually start and block on the condition variable time.Sleep(time.Millisecond) // Read should still be blocked @@ -138,20 +138,18 @@ func TestBackedReader_ReadBlocksWhenDisconnected(t *testing.T) { go br.Reconnect(seqNum, newR) // Get sequence number and send new reader - <-seqNum - newR <- reader + testutil.RequireReceive(ctx, t, seqNum) + testutil.RequireSend(ctx, t, newR, io.Reader(reader)) // Wait for read to complete - select { - case <-readDone: - assert.NoError(t, readErr) - case <-time.After(time.Second): - t.Fatal("Read did not unblock after reconnection") - } + testutil.TryReceive(ctx, t, readDone) + assert.NoError(t, readErr) + assert.Equal(t, "test", string(readBuf)) } func TestBackedReader_ReconnectionAfterFailure(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) br := backedpipe.NewBackedReader() reader1 := newMockReader("first") @@ -163,8 +161,8 @@ func TestBackedReader_ReconnectionAfterFailure(t *testing.T) { go br.Reconnect(seqNum, newR) // Get sequence number and send new reader - <-seqNum - newR <- reader1 + testutil.RequireReceive(ctx, t, seqNum) + testutil.RequireSend(ctx, t, newR, io.Reader(reader1)) // Read some data buf := make([]byte, 5) @@ -190,12 +188,16 @@ func TestBackedReader_ReconnectionAfterFailure(t *testing.T) { }() // Wait for the error to be reported via callback + receivedErr := testutil.RequireReceive(ctx, t, errorReceived) + assert.Error(t, receivedErr) + assert.Contains(t, receivedErr.Error(), "connection lost") + + // Verify read is still blocked select { - case receivedErr := <-errorReceived: - assert.Error(t, receivedErr) - assert.Contains(t, receivedErr.Error(), "connection lost") - case <-time.After(time.Second): - t.Fatal("Error callback was not invoked within timeout") + case err := <-readDone: + t.Fatalf("Read should still be blocked, but completed with: %v", err) + default: + // Good, still blocked } // Verify disconnection @@ -209,21 +211,18 @@ func TestBackedReader_ReconnectionAfterFailure(t *testing.T) { go br.Reconnect(seqNum2, newR2) // Get sequence number and send new reader - seq := <-seqNum2 + seq := testutil.RequireReceive(ctx, t, seqNum2) assert.Equal(t, uint64(5), seq) // Should return current sequence number - newR2 <- reader2 + testutil.RequireSend(ctx, t, newR2, io.Reader(reader2)) // Wait for read to unblock and succeed with new data - select { - case readErr := <-readDone: - assert.NoError(t, readErr) // Should succeed with new reader - case <-time.After(time.Second): - t.Fatal("Read did not unblock after reconnection") - } + readErr := testutil.RequireReceive(ctx, t, readDone) + assert.NoError(t, readErr) // Should succeed with new reader } func TestBackedReader_Close(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) br := backedpipe.NewBackedReader() reader := newMockReader("test") @@ -235,8 +234,8 @@ func TestBackedReader_Close(t *testing.T) { go br.Reconnect(seqNum, newR) // Get sequence number and send new reader - <-seqNum - newR <- reader + testutil.RequireReceive(ctx, t, seqNum) + testutil.RequireSend(ctx, t, newR, io.Reader(reader)) // First, read all available data buf := make([]byte, 10) @@ -248,14 +247,14 @@ func 
TestBackedReader_Close(t *testing.T) { err = br.Close() require.NoError(t, err) - // After close, reads should return ErrClosedPipe + // After close, reads should return EOF n, err = br.Read(buf) assert.Equal(t, 0, n) - assert.Equal(t, io.ErrClosedPipe, err) + assert.Equal(t, io.EOF, err) - // Subsequent reads should return ErrClosedPipe + // Subsequent reads should return EOF _, err = br.Read(buf) - assert.Equal(t, io.ErrClosedPipe, err) + assert.Equal(t, io.EOF, err) } func TestBackedReader_CloseIdempotent(t *testing.T) { @@ -273,6 +272,7 @@ func TestBackedReader_CloseIdempotent(t *testing.T) { func TestBackedReader_ReconnectAfterClose(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) br := backedpipe.NewBackedReader() @@ -285,29 +285,30 @@ func TestBackedReader_ReconnectAfterClose(t *testing.T) { go br.Reconnect(seqNum, newR) // Should get 0 sequence number for closed reader - seq := <-seqNum + seq := testutil.TryReceive(ctx, t, seqNum) assert.Equal(t, uint64(0), seq) } // Helper function to reconnect a reader using channels -func reconnectReader(br *backedpipe.BackedReader, reader io.Reader) { +func reconnectReader(ctx context.Context, t testing.TB, br *backedpipe.BackedReader, reader io.Reader) { seqNum := make(chan uint64, 1) newR := make(chan io.Reader, 1) go br.Reconnect(seqNum, newR) // Get sequence number and send new reader - <-seqNum - newR <- reader + testutil.RequireReceive(ctx, t, seqNum) + testutil.RequireSend(ctx, t, newR, reader) } func TestBackedReader_SequenceNumberTracking(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) br := backedpipe.NewBackedReader() reader := newMockReader("0123456789") - reconnectReader(br, reader) + reconnectReader(ctx, t, br, reader) // Read in chunks and verify sequence number buf := make([]byte, 3) @@ -328,38 +329,9 @@ func TestBackedReader_SequenceNumberTracking(t *testing.T) { assert.Equal(t, uint64(9), br.SequenceNum()) } -func TestBackedReader_ConcurrentReads(t *testing.T) { - t.Parallel() - - br := backedpipe.NewBackedReader() - reader := newMockReader(strings.Repeat("a", 1000)) - - reconnectReader(br, reader) - - var wg sync.WaitGroup - numReaders := 5 - readsPerReader := 10 - - for i := 0; i < numReaders; i++ { - wg.Add(1) - go func() { - defer wg.Done() - buf := make([]byte, 10) - for j := 0; j < readsPerReader; j++ { - br.Read(buf) - } - }() - } - - wg.Wait() - - // Should have read some data (exact amount depends on scheduling) - assert.True(t, br.SequenceNum() > 0) - assert.True(t, br.SequenceNum() <= 1000) -} - func TestBackedReader_EOFHandling(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) br := backedpipe.NewBackedReader() reader := newMockReader("test") @@ -370,7 +342,7 @@ func TestBackedReader_EOFHandling(t *testing.T) { errorReceived <- err }) - reconnectReader(br, reader) + reconnectReader(ctx, t, br, reader) // Read all data buf := make([]byte, 10) @@ -391,12 +363,8 @@ func TestBackedReader_EOFHandling(t *testing.T) { }() // Wait for EOF to be reported via error callback - select { - case receivedErr := <-errorReceived: - assert.Equal(t, io.EOF, receivedErr) - case <-time.After(time.Second): - t.Fatal("EOF was not reported via error callback within timeout") - } + receivedErr := testutil.RequireReceive(ctx, t, errorReceived) + assert.Equal(t, io.EOF, receivedErr) // Reader should be disconnected after EOF assert.False(t, br.Connected()) @@ -411,36 +379,43 @@ func TestBackedReader_EOFHandling(t *testing.T) { // Reconnect with new data 
reader2 := newMockReader("more") - reconnectReader(br, reader2) + reconnectReader(ctx, t, br, reader2) // Wait for the blocked read to complete with new data - select { - case <-readDone: - require.NoError(t, readErr) - assert.Equal(t, 4, readN) - assert.Equal(t, "more", string(buf[:readN])) - case <-time.After(time.Second): - t.Fatal("Read did not unblock after reconnection") - } + testutil.TryReceive(ctx, t, readDone) + require.NoError(t, readErr) + assert.Equal(t, 4, readN) + assert.Equal(t, "more", string(buf[:readN])) } func BenchmarkBackedReader_Read(b *testing.B) { br := backedpipe.NewBackedReader() buf := make([]byte, 1024) + // Create a reader that never returns EOF by cycling through data + reader := &mockReader{ + readFunc: func(p []byte) (int, error) { + // Fill buffer with 'x' characters - never EOF + for i := range p { + p[i] = 'x' + } + return len(p), nil + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + reconnectReader(ctx, b, br, reader) + b.ResetTimer() for i := 0; i < b.N; i++ { - // Create fresh reader with data for each iteration - data := strings.Repeat("x", 1024) // 1KB of data per iteration - reader := newMockReader(data) - reconnectReader(br, reader) - br.Read(buf) } } func TestBackedReader_PartialReads(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) br := backedpipe.NewBackedReader() @@ -456,7 +431,7 @@ func TestBackedReader_PartialReads(t *testing.T) { }, } - reconnectReader(br, reader) + reconnectReader(ctx, t, br, reader) // Read multiple times buf := make([]byte, 10) @@ -469,3 +444,161 @@ func TestBackedReader_PartialReads(t *testing.T) { assert.Equal(t, uint64(5), br.SequenceNum()) } + +func TestBackedReader_CloseWhileBlockedOnUnderlyingReader(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + br := backedpipe.NewBackedReader() + + // Create a reader that blocks on Read calls but can be unblocked + readStarted := make(chan struct{}, 1) + readUnblocked := make(chan struct{}) + blockingReader := &mockReader{ + readFunc: func(p []byte) (int, error) { + select { + case readStarted <- struct{}{}: + default: + } + <-readUnblocked // Block until signaled + // After unblocking, return an error to simulate connection failure + return 0, xerrors.New("connection interrupted") + }, + } + + // Connect the blocking reader + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send blocking reader + testutil.RequireReceive(ctx, t, seqNum) + testutil.RequireSend(ctx, t, newR, io.Reader(blockingReader)) + + // Start a read that will block on the underlying reader + readDone := make(chan struct{}) + var readErr error + var readN int + + go func() { + defer close(readDone) + buf := make([]byte, 10) + readN, readErr = br.Read(buf) + }() + + // Wait for the read to start and block on the underlying reader + testutil.RequireReceive(ctx, t, readStarted) + + // Give a brief moment for the read to actually block + time.Sleep(time.Millisecond) + + // Verify read is blocked + select { + case <-readDone: + t.Fatal("Read should be blocked on underlying reader") + default: + // Good, still blocked + } + + // Start Close() in a goroutine since it will block until the underlying read completes + closeDone := make(chan error, 1) + go func() { + closeDone <- br.Close() + }() + + // Verify Close() is also blocked waiting for the underlying read + select { + case <-closeDone: + t.Fatal("Close should 
be blocked until underlying read completes") + case <-time.After(10 * time.Millisecond): + // Good, Close is blocked + } + + // Unblock the underlying reader, which will cause both the read and close to complete + close(readUnblocked) + + // Wait for both the read and close to complete + testutil.TryReceive(ctx, t, readDone) + closeErr := testutil.RequireReceive(ctx, t, closeDone) + require.NoError(t, closeErr) + + // The read should return EOF because Close() was called while it was blocked, + // even though the underlying reader returned an error + assert.Equal(t, 0, readN) + assert.Equal(t, io.EOF, readErr) + + // Subsequent reads should return EOF since the reader is now closed + buf := make([]byte, 10) + n, err := br.Read(buf) + assert.Equal(t, 0, n) + assert.Equal(t, io.EOF, err) +} + +func TestBackedReader_CloseWhileBlockedWaitingForReconnect(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + br := backedpipe.NewBackedReader() + reader1 := newMockReader("initial") + + // Initial connection + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send initial reader + testutil.RequireReceive(ctx, t, seqNum) + testutil.RequireSend(ctx, t, newR, io.Reader(reader1)) + + // Read initial data + buf := make([]byte, 10) + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, "initial", string(buf[:n])) + + // Set up error callback to track connection failure + errorReceived := make(chan error, 1) + br.SetErrorCallback(func(err error) { + errorReceived <- err + }) + + // Simulate connection failure + reader1.setError(xerrors.New("connection lost")) + + // Start a read that will block waiting for reconnection + readDone := make(chan struct{}) + var readErr error + var readN int + + go func() { + defer close(readDone) + readN, readErr = br.Read(buf) + }() + + // Wait for the error to be reported (indicating disconnection) + receivedErr := testutil.RequireReceive(ctx, t, errorReceived) + assert.Error(t, receivedErr) + assert.Contains(t, receivedErr.Error(), "connection lost") + + // Verify read is blocked waiting for reconnection + select { + case <-readDone: + t.Fatal("Read should be blocked waiting for reconnection") + default: + // Good, still blocked + } + + // Verify reader is disconnected + assert.False(t, br.Connected()) + + // Close the BackedReader while read is blocked waiting for reconnection + err = br.Close() + require.NoError(t, err) + + // The read should unblock and return EOF + testutil.TryReceive(ctx, t, readDone) + assert.Equal(t, 0, readN) + assert.Equal(t, io.EOF, readErr) +} diff --git a/agent/immortalstreams/backedpipe/backed_writer.go b/agent/immortalstreams/backedpipe/backed_writer.go index bc72d8bfc7385..b268f48bc77eb 100644 --- a/agent/immortalstreams/backedpipe/backed_writer.go +++ b/agent/immortalstreams/backedpipe/backed_writer.go @@ -1,7 +1,6 @@ package backedpipe import ( - "context" "io" "sync" @@ -9,39 +8,42 @@ import ( ) // BackedWriter wraps an unreliable io.Writer and makes it resilient to disconnections. -// It maintains a ring buffer of recent writes for replay during reconnection and -// always writes to the buffer even when disconnected. +// It maintains a ring buffer of recent writes for replay during reconnection. 
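+//
+// A minimal usage sketch (illustrative only; conn1 and conn2 stand in for
+// real connections and are not part of this package):
+//
+//	errCh := make(chan error, 1)
+//	bw := NewBackedWriter(DefaultBufferSize, errCh)
+//	_ = bw.Reconnect(0, conn1)       // attach the first connection
+//	_, _ = bw.Write([]byte("hello")) // buffered and sent to conn1
+//	// conn1 drops; a later Reconnect with the peer's sequence number
+//	// replays any buffered bytes it has not yet seen:
+//	_ = bw.Reconnect(0, conn2) // replays "hello" to conn2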
type BackedWriter struct { mu sync.Mutex cond *sync.Cond writer io.Writer - buffer *RingBuffer + buffer *ringBuffer sequenceNum uint64 // total bytes written closed bool - // Error callback to notify parent when connection fails - onError func(error) + // Error channel to notify parent when connection fails + errorChan chan<- error } -// NewBackedWriter creates a new BackedWriter with a 64MB ring buffer. -// The writer is initially disconnected and will buffer writes until connected. -func NewBackedWriter() *BackedWriter { - return NewBackedWriterWithCapacity(64 * 1024 * 1024) -} - -// NewBackedWriterWithCapacity creates a new BackedWriter with the specified buffer capacity. -// The writer is initially disconnected and will buffer writes until connected. -func NewBackedWriterWithCapacity(capacity int) *BackedWriter { +// NewBackedWriter creates a new BackedWriter with the specified buffer capacity. +// The writer is initially disconnected and will block writes until connected. +// The errorChan is required and will receive connection errors. +// Capacity must be > 0. +func NewBackedWriter(capacity int, errorChan chan<- error) *BackedWriter { + if capacity <= 0 { + panic("backed writer capacity must be > 0") + } + if errorChan == nil { + panic("error channel cannot be nil") + } bw := &BackedWriter{ - buffer: NewRingBufferWithCapacity(capacity), + buffer: newRingBuffer(capacity), + errorChan: errorChan, } bw.cond = sync.NewCond(&bw.mu) return bw } -// Write implements io.Writer. It always writes to the ring buffer, even when disconnected. -// When connected, it also writes to the underlying writer. If the underlying write fails, -// the writer is marked as disconnected but the buffer write still succeeds. +// Write implements io.Writer. +// When connected, it writes to both the ring buffer and the underlying writer. +// If the underlying write fails, the writer is marked as disconnected and the write blocks +// until reconnection occurs. 
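+//
+// Callers typically drain the error channel passed to NewBackedWriter to
+// learn about failed writes and decide when to reconnect. A sketch (the
+// logging and reconnect policy shown here are up to the caller):
+//
+//	errCh := make(chan error, 1)
+//	bw := NewBackedWriter(DefaultBufferSize, errCh)
+//	go func() {
+//		for err := range errCh {
+//			log.Printf("write failed: %v; reconnecting", err)
+//		}
+//	}()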
func (bw *BackedWriter) Write(p []byte) (int, error) { if len(p) == 0 { return 0, nil @@ -54,34 +56,47 @@ func (bw *BackedWriter) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe } + // Block until connected + for bw.writer == nil && !bw.closed { + bw.cond.Wait() + } + + // Check if we were closed while waiting + if bw.closed { + return 0, io.ErrClosedPipe + } + // Always write to buffer first - written, _ := bw.buffer.Write(p) - //nolint:gosec // Safe conversion: written is always non-negative from buffer.Write - bw.sequenceNum += uint64(written) + bw.buffer.Write(p) + // Always advance sequence number by the full length + bw.sequenceNum += uint64(len(p)) - // If connected, also write to underlying writer - if bw.writer != nil { - // Unlock during actual write to avoid blocking other operations - bw.mu.Unlock() - n, err := bw.writer.Write(p) - bw.mu.Lock() + // Write to underlying writer + n, err := bw.writer.Write(p) - if n != len(p) { - err = xerrors.Errorf("partial write: wrote %d of %d bytes", n, len(p)) - } + if err != nil { + // Connection failed, mark as disconnected + bw.writer = nil - if err != nil { - // Connection failed, mark as disconnected - bw.writer = nil + // Notify parent of error + select { + case bw.errorChan <- err: + default: + } + return 0, err + } - // Notify parent of error if callback is set - if bw.onError != nil { - bw.onError(err) - } + if n != len(p) { + bw.writer = nil + err = xerrors.Errorf("short write: %d bytes written, %d expected", n, len(p)) + select { + case bw.errorChan <- err: + default: } + return 0, err } - return written, nil + return len(p), nil } // Reconnect replaces the current writer with a new one and replays data from the specified @@ -123,26 +138,31 @@ func (bw *BackedWriter) Reconnect(replayFromSeq uint64, newWriter io.Writer) err } } - // Set new writer - bw.writer = newWriter + // Clear the current writer first in case replay fails + bw.writer = nil - // Replay data if needed + // Replay data if needed. We keep the writer as nil during replay to ensure + // no concurrent writes can happen, then set it only after successful replay. if len(replayData) > 0 { bw.mu.Unlock() n, err := newWriter.Write(replayData) bw.mu.Lock() if err != nil { - bw.writer = nil + // Reconnect failed, writer remains nil return xerrors.Errorf("replay failed: %w", err) } if n != len(replayData) { - bw.writer = nil + // Reconnect failed, writer remains nil return xerrors.Errorf("partial replay: wrote %d of %d bytes", n, len(replayData)) } } + // Set new writer only after successful replay. This ensures no concurrent + // writes can interfere with the replay operation. + bw.writer = newWriter + // Wake up any operations waiting for connection bw.cond.Broadcast() @@ -170,14 +190,6 @@ func (bw *BackedWriter) Close() error { return nil } -// SetErrorCallback sets the callback function that will be called when -// a connection error occurs. -func (bw *BackedWriter) SetErrorCallback(fn func(error)) { - bw.mu.Lock() - defer bw.mu.Unlock() - bw.onError = fn -} - // SequenceNum returns the current sequence number (total bytes written). func (bw *BackedWriter) SequenceNum() uint64 { bw.mu.Lock() @@ -191,54 +203,3 @@ func (bw *BackedWriter) Connected() bool { defer bw.mu.Unlock() return bw.writer != nil } - -// CanReplayFrom returns true if the writer can replay data from the given sequence number. 
-func (bw *BackedWriter) CanReplayFrom(seqNum uint64) bool { - bw.mu.Lock() - defer bw.mu.Unlock() - return seqNum <= bw.sequenceNum && bw.sequenceNum-seqNum <= DefaultBufferSize -} - -// WaitForConnection blocks until the writer is connected or the context is canceled. -func (bw *BackedWriter) WaitForConnection(ctx context.Context) error { - bw.mu.Lock() - defer bw.mu.Unlock() - - return bw.waitForConnectionLocked(ctx) -} - -// waitForConnectionLocked waits for connection with lock held. -func (bw *BackedWriter) waitForConnectionLocked(ctx context.Context) error { - for bw.writer == nil && !bw.closed { - select { - case <-ctx.Done(): - return ctx.Err() - default: - // Use a timeout to avoid infinite waiting - done := make(chan struct{}) - go func() { - select { - case <-ctx.Done(): - bw.cond.Broadcast() - case <-done: - } - }() - - bw.cond.Wait() - close(done) - - // Check context again after waking up - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - } - } - - if bw.closed { - return io.ErrClosedPipe - } - - return nil -} diff --git a/agent/immortalstreams/backedpipe/backed_writer_test.go b/agent/immortalstreams/backedpipe/backed_writer_test.go index 18d662e92ea0b..02f48d811f5c6 100644 --- a/agent/immortalstreams/backedpipe/backed_writer_test.go +++ b/agent/immortalstreams/backedpipe/backed_writer_test.go @@ -2,10 +2,10 @@ package backedpipe_test import ( "bytes" - "context" "io" "sync" "testing" + "time" "github.com/stretchr/testify/require" "golang.org/x/xerrors" @@ -29,7 +29,8 @@ func newMockWriter() *mockWriter { // newBackedWriterForTest creates a BackedWriter with a small buffer for testing eviction behavior func newBackedWriterForTest(bufferSize int) *backedpipe.BackedWriter { - return backedpipe.NewBackedWriterWithCapacity(bufferSize) + errorChan := make(chan error, 1) + return backedpipe.NewBackedWriter(bufferSize, errorChan) } func (mw *mockWriter) Write(p []byte) (int, error) { @@ -73,30 +74,57 @@ func (mw *mockWriter) setError(err error) { func TestBackedWriter_NewBackedWriter(t *testing.T) { t.Parallel() - bw := backedpipe.NewBackedWriter() + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) require.NotNil(t, bw) require.Equal(t, uint64(0), bw.SequenceNum()) require.False(t, bw.Connected()) } -func TestBackedWriter_WriteToBufferWhenDisconnected(t *testing.T) { +func TestBackedWriter_WriteBlocksWhenDisconnected(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + + // Write should block when disconnected + writeComplete := make(chan struct{}) + var writeErr error + var n int + + go func() { + defer close(writeComplete) + n, writeErr = bw.Write([]byte("hello")) + }() + + // Verify write is blocked + select { + case <-writeComplete: + t.Fatal("Write should have blocked when disconnected") + case <-time.After(50 * time.Millisecond): + // Expected - write is blocked + } - bw := backedpipe.NewBackedWriter() - - // Write should succeed even when disconnected - n, err := bw.Write([]byte("hello")) + // Connect and verify write completes + writer := newMockWriter() + err := bw.Reconnect(0, writer) require.NoError(t, err) + + // Write should now complete + testutil.TryReceive(ctx, t, writeComplete) + + require.NoError(t, writeErr) require.Equal(t, 5, n) require.Equal(t, uint64(5), bw.SequenceNum()) - - // Data should be in buffer + require.Equal(t, []byte("hello"), 
writer.buffer.Bytes()) } func TestBackedWriter_WriteToUnderlyingWhenConnected(t *testing.T) { t.Parallel() - bw := backedpipe.NewBackedWriter() + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) writer := newMockWriter() // Connect @@ -118,7 +146,8 @@ func TestBackedWriter_WriteToUnderlyingWhenConnected(t *testing.T) { func TestBackedWriter_DisconnectOnWriteFailure(t *testing.T) { t.Parallel() - bw := backedpipe.NewBackedWriter() + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) writer := newMockWriter() // Connect @@ -128,66 +157,104 @@ func TestBackedWriter_DisconnectOnWriteFailure(t *testing.T) { // Cause write to fail writer.setError(xerrors.New("write failed")) - // Write should still succeed to buffer but disconnect + // Write should fail and disconnect n, err := bw.Write([]byte("hello")) - require.NoError(t, err) // Buffer write succeeds - require.Equal(t, 5, n) + require.Error(t, err) // Write should fail + require.Equal(t, 0, n) require.False(t, bw.Connected()) // Should be disconnected - // Data should still be in buffer + // Error should be sent to error channel + select { + case receivedErr := <-errorChan: + require.Contains(t, receivedErr.Error(), "write failed") + default: + t.Fatal("Expected error to be sent to error channel") + } } func TestBackedWriter_ReplayOnReconnect(t *testing.T) { t.Parallel() - bw := backedpipe.NewBackedWriter() + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) - // Write some data while disconnected - bw.Write([]byte("hello")) - bw.Write([]byte(" world")) + // Connect initially to write some data + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + + // Write some data while connected + _, err = bw.Write([]byte("hello")) + require.NoError(t, err) + _, err = bw.Write([]byte(" world")) + require.NoError(t, err) require.Equal(t, uint64(11), bw.SequenceNum()) - // Reconnect and request replay from beginning - writer := newMockWriter() - err := bw.Reconnect(0, writer) + // Disconnect by causing a write failure + writer1.setError(xerrors.New("connection lost")) + _, err = bw.Write([]byte("test")) + require.Error(t, err) + require.False(t, bw.Connected()) + + // Reconnect with new writer and request replay from beginning + writer2 := newMockWriter() + err = bw.Reconnect(0, writer2) require.NoError(t, err) - // Should have replayed all data - require.Equal(t, []byte("hello world"), writer.buffer.Bytes()) + // Should have replayed all data including the failed write that was buffered + require.Equal(t, []byte("hello worldtest"), writer2.buffer.Bytes()) // Write new data should go to both - bw.Write([]byte("!")) - require.Equal(t, []byte("hello world!"), writer.buffer.Bytes()) + _, err = bw.Write([]byte("!")) + require.NoError(t, err) + require.Equal(t, []byte("hello worldtest!"), writer2.buffer.Bytes()) } func TestBackedWriter_PartialReplay(t *testing.T) { t.Parallel() - bw := backedpipe.NewBackedWriter() + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + + // Connect initially to write some data + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) // Write some data - bw.Write([]byte("hello")) - bw.Write([]byte(" world")) - bw.Write([]byte("!")) + _, err = bw.Write([]byte("hello")) + require.NoError(t, err) + _, err = bw.Write([]byte(" world")) 
+ require.NoError(t, err) + _, err = bw.Write([]byte("!")) + require.NoError(t, err) - // Reconnect and request replay from middle - writer := newMockWriter() - err := bw.Reconnect(5, writer) // From " world!" + // Reconnect with new writer and request replay from middle + writer2 := newMockWriter() + err = bw.Reconnect(5, writer2) // From " world!" require.NoError(t, err) // Should have replayed only the requested portion - require.Equal(t, []byte(" world!"), writer.buffer.Bytes()) + require.Equal(t, []byte(" world!"), writer2.buffer.Bytes()) } func TestBackedWriter_ReplayFromFutureSequence(t *testing.T) { t.Parallel() - bw := backedpipe.NewBackedWriter() - bw.Write([]byte("hello")) + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) - writer := newMockWriter() - err := bw.Reconnect(10, writer) // Future sequence + // Connect initially to write some data + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + + _, err = bw.Write([]byte("hello")) + require.NoError(t, err) + + writer2 := newMockWriter() + err = bw.Reconnect(10, writer2) // Future sequence require.Error(t, err) require.Contains(t, err.Error(), "future sequence") } @@ -197,12 +264,19 @@ func TestBackedWriter_ReplayDataLoss(t *testing.T) { bw := newBackedWriterForTest(10) // Small buffer for testing + // Connect initially to write some data + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + // Fill buffer beyond capacity to cause eviction - bw.Write([]byte("0123456789")) // Fills buffer exactly - bw.Write([]byte("abcdef")) // Should evict "012345" + _, err = bw.Write([]byte("0123456789")) // Fills buffer exactly + require.NoError(t, err) + _, err = bw.Write([]byte("abcdef")) // Should evict "012345" + require.NoError(t, err) - writer := newMockWriter() - err := bw.Reconnect(0, writer) // Try to replay from evicted data + writer2 := newMockWriter() + err = bw.Reconnect(0, writer2) // Try to replay from evicted data // With the new error handling, this should fail because we can't read all the data require.Error(t, err) require.Contains(t, err.Error(), "failed to read replay data") @@ -213,6 +287,11 @@ func TestBackedWriter_BufferEviction(t *testing.T) { bw := newBackedWriterForTest(5) // Very small buffer for testing + // Connect initially + writer := newMockWriter() + err := bw.Reconnect(0, writer) + require.NoError(t, err) + // Write data that will cause eviction n, err := bw.Write([]byte("abcde")) require.NoError(t, err) @@ -229,7 +308,8 @@ func TestBackedWriter_BufferEviction(t *testing.T) { func TestBackedWriter_Close(t *testing.T) { t.Parallel() - bw := backedpipe.NewBackedWriter() + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) writer := newMockWriter() bw.Reconnect(0, writer) @@ -250,7 +330,8 @@ func TestBackedWriter_Close(t *testing.T) { func TestBackedWriter_CloseIdempotent(t *testing.T) { t.Parallel() - bw := backedpipe.NewBackedWriter() + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) err := bw.Close() require.NoError(t, err) @@ -260,54 +341,11 @@ func TestBackedWriter_CloseIdempotent(t *testing.T) { require.NoError(t, err) } -func TestBackedWriter_CanReplayFrom(t *testing.T) { - t.Parallel() - - bw := newBackedWriterForTest(10) // Small buffer for testing eviction - - // Empty buffer - require.True(t, bw.CanReplayFrom(0)) - require.False(t, bw.CanReplayFrom(1)) 
- - // Write some data - bw.Write([]byte("hello")) - require.True(t, bw.CanReplayFrom(0)) - require.True(t, bw.CanReplayFrom(3)) - require.True(t, bw.CanReplayFrom(5)) - require.False(t, bw.CanReplayFrom(6)) - - // Fill buffer and cause eviction - bw.Write([]byte("world!")) - require.True(t, bw.CanReplayFrom(0)) // Can replay from any sequence up to current - require.True(t, bw.CanReplayFrom(bw.SequenceNum())) -} - -func TestBackedWriter_WaitForConnection(t *testing.T) { - t.Parallel() - - bw := backedpipe.NewBackedWriter() - - // Should timeout when not connected - // Use a shorter timeout for this test to speed up test runs - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperShort) - defer cancel() - - err := bw.WaitForConnection(ctx) - require.Equal(t, context.DeadlineExceeded, err) - - // Should succeed immediately when connected - writer := newMockWriter() - bw.Reconnect(0, writer) - - ctx = context.Background() - err = bw.WaitForConnection(ctx) - require.NoError(t, err) -} - func TestBackedWriter_ConcurrentWrites(t *testing.T) { t.Parallel() - bw := backedpipe.NewBackedWriter() + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) writer := newMockWriter() bw.Reconnect(0, writer) @@ -339,17 +377,25 @@ func TestBackedWriter_ConcurrentWrites(t *testing.T) { func TestBackedWriter_ReconnectDuringReplay(t *testing.T) { t.Parallel() - bw := backedpipe.NewBackedWriter() - bw.Write([]byte("hello world")) + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + + // Connect initially to write some data + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + + _, err = bw.Write([]byte("hello world")) + require.NoError(t, err) // Create a writer that fails during replay - writer := &mockWriter{ + writer2 := &mockWriter{ writeFunc: func(p []byte) (int, error) { return 0, xerrors.New("replay failed") }, } - err := bw.Reconnect(0, writer) + err = bw.Reconnect(0, writer2) require.Error(t, err) require.Contains(t, err.Error(), "replay failed") require.False(t, bw.Connected()) @@ -358,7 +404,8 @@ func TestBackedWriter_ReconnectDuringReplay(t *testing.T) { func TestBackedWriter_PartialWriteToUnderlying(t *testing.T) { t.Parallel() - bw := backedpipe.NewBackedWriter() + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) // Create writer that does partial writes writer := &mockWriter{ @@ -372,17 +419,137 @@ func TestBackedWriter_PartialWriteToUnderlying(t *testing.T) { bw.Reconnect(0, writer) - // Write should succeed to buffer but disconnect due to partial write + // Write should fail due to partial write n, err := bw.Write([]byte("hello")) + require.Error(t, err) + require.Equal(t, 0, n) + require.False(t, bw.Connected()) + require.Contains(t, err.Error(), "short write") + + // Error should be sent to error channel + select { + case receivedErr := <-errorChan: + require.Contains(t, receivedErr.Error(), "short write") + default: + t.Fatal("Expected error to be sent to error channel") + } +} + +func TestBackedWriter_WriteUnblocksOnReconnect(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + + // Start a single write that should block + writeResult := make(chan error, 1) + go func() { + _, err := bw.Write([]byte("test")) + writeResult 
<- err
+	}()
+
+	// Verify write is blocked
+	select {
+	case <-writeResult:
+		t.Fatal("Write should have blocked when disconnected")
+	case <-time.After(50 * time.Millisecond):
+		// Expected - write is blocked
+	}
+
+	// Connect and verify write completes
+	writer := newMockWriter()
+	err := bw.Reconnect(0, writer)
+	require.NoError(t, err)
+
+	// Write should now complete
+	err = testutil.RequireReceive(ctx, t, writeResult)
+	require.NoError(t, err)
+
+	// Write should have been written to the underlying writer
+	require.Equal(t, "test", writer.buffer.String())
+}
+
+func TestBackedWriter_CloseUnblocksWaitingWrites(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+
+	errorChan := make(chan error, 1)
+	bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan)
+
+	// Start a write that should block
+	writeComplete := make(chan error, 1)
+	go func() {
+		_, err := bw.Write([]byte("test"))
+		writeComplete <- err
+	}()
+
+	// Verify write is blocked
+	select {
+	case <-writeComplete:
+		t.Fatal("Write should have blocked when disconnected")
+	case <-time.After(50 * time.Millisecond):
+		// Expected - write is blocked
+	}
+
+	// Close the writer
+	err := bw.Close()
+	require.NoError(t, err)
+
+	// Write should now complete with error
+	err = testutil.RequireReceive(ctx, t, writeComplete)
+	require.Equal(t, io.ErrClosedPipe, err)
+}
+
+func TestBackedWriter_WriteBlocksAfterDisconnection(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+
+	errorChan := make(chan error, 1)
+	bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan)
+	writer := newMockWriter()
+
+	// Connect initially
+	err := bw.Reconnect(0, writer)
+	require.NoError(t, err)
+
+	// Write should succeed when connected
+	_, err = bw.Write([]byte("hello"))
+	require.NoError(t, err)
+
+	// Cause disconnection
+	writer.setError(xerrors.New("connection lost"))
+	_, err = bw.Write([]byte("world"))
+	require.Error(t, err)
 	require.False(t, bw.Connected())
 
-	// Buffer should have all data
+	// Subsequent write should block
+	writeComplete := make(chan error, 1)
+	go func() {
+		_, err := bw.Write([]byte("blocked"))
+		writeComplete <- err
+	}()
+
+	// Verify write is blocked
+	select {
+	case <-writeComplete:
+		t.Fatal("Write should have blocked after disconnection")
+	case <-time.After(50 * time.Millisecond):
+		// Expected - write is blocked
+	}
+
+	// Reconnect and verify write completes
+	writer2 := newMockWriter()
+	err = bw.Reconnect(5, writer2) // Replay from after "hello"
+	require.NoError(t, err)
+
+	err = testutil.RequireReceive(ctx, t, writeComplete)
+	require.NoError(t, err)
 }
 
 func BenchmarkBackedWriter_Write(b *testing.B) {
-	bw := backedpipe.NewBackedWriter() // 64KB buffer
+	errorChan := make(chan error, 1)
+	bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) // 64MB default buffer
 	writer := newMockWriter()
 	bw.Reconnect(0, writer)
 
@@ -395,7 +562,15 @@ func BenchmarkBackedWriter_Write(b *testing.B) {
 }
 
 func BenchmarkBackedWriter_Reconnect(b *testing.B) {
-	bw := backedpipe.NewBackedWriter()
+	errorChan := make(chan error, 1)
+	bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan)
+
+	// Connect initially to fill buffer with data
+	initialWriter := newMockWriter()
+	err := bw.Reconnect(0, initialWriter)
+	if err != nil {
+		b.Fatal(err)
+	}
 
 	// Fill buffer with data
 	data := bytes.Repeat([]byte("x"), 1024)
diff --git a/agent/immortalstreams/backedpipe/ring_buffer.go 
b/agent/immortalstreams/backedpipe/ring_buffer.go index f092385741e0c..eefec300b306c 100644 --- a/agent/immortalstreams/backedpipe/ring_buffer.go +++ b/agent/immortalstreams/backedpipe/ring_buffer.go @@ -1,138 +1,129 @@ package backedpipe import ( - "sync" - "golang.org/x/xerrors" ) -// RingBuffer implements an efficient circular buffer with a fixed-size allocation. -// It supports concurrent access and handles wrap-around seamlessly. -// The buffer is designed for high-performance scenarios where avoiding -// dynamic memory allocation during operation is critical. -type RingBuffer struct { - mu sync.RWMutex +// ringBuffer implements an efficient circular buffer with a fixed-size allocation. +// This implementation is not thread-safe and relies on external synchronization. +type ringBuffer struct { buffer []byte start int // index of first valid byte - end int // index after last valid byte - size int // current number of bytes in buffer - cap int // maximum capacity -} - -// NewRingBuffer creates a new ring buffer with 64MB capacity. -func NewRingBuffer() *RingBuffer { - const capacity = 64 * 1024 * 1024 // 64MB - return NewRingBufferWithCapacity(capacity) + end int // index of last valid byte (-1 when empty) } -// NewRingBufferWithCapacity creates a new ring buffer with the specified capacity. -// If capacity is <= 0, it defaults to 64MB. -func NewRingBufferWithCapacity(capacity int) *RingBuffer { +// newRingBuffer creates a new ring buffer with the specified capacity. +// Capacity must be > 0. +func newRingBuffer(capacity int) *ringBuffer { if capacity <= 0 { - capacity = 64 * 1024 * 1024 // Default to 64MB + panic("ring buffer capacity must be > 0") } - return &RingBuffer{ + return &ringBuffer{ buffer: make([]byte, capacity), - cap: capacity, + end: -1, // -1 indicates empty buffer + } +} + +// Size returns the current number of bytes in the buffer. +func (rb *ringBuffer) Size() int { + if rb.end == -1 { + return 0 // Buffer is empty + } + if rb.start <= rb.end { + return rb.end - rb.start + 1 } + // Buffer wraps around + return len(rb.buffer) - rb.start + rb.end + 1 } // Write writes data to the ring buffer. If the buffer would overflow, // it evicts the oldest data to make room for new data. -// Returns the number of bytes written and the number of bytes evicted. 
-func (rb *RingBuffer) Write(data []byte) (written int, evicted int) { +func (rb *ringBuffer) Write(data []byte) { if len(data) == 0 { - return 0, 0 + return } - rb.mu.Lock() - defer rb.mu.Unlock() - - written = len(data) + capacity := len(rb.buffer) // If data is larger than capacity, only keep the last capacity bytes - if len(data) > rb.cap { - evicted = len(data) - rb.cap - data = data[evicted:] - written = rb.cap + if len(data) > capacity { + data = data[len(data)-capacity:] // Clear buffer and write new data rb.start = 0 - rb.end = 0 - rb.size = 0 + rb.end = -1 // Will be set properly below } // Calculate how much we need to evict to fit new data spaceNeeded := len(data) - availableSpace := rb.cap - rb.size + availableSpace := capacity - rb.Size() if spaceNeeded > availableSpace { bytesToEvict := spaceNeeded - availableSpace - evicted += bytesToEvict rb.evict(bytesToEvict) } - // Write the data - for _, b := range data { - rb.buffer[rb.end] = b - rb.end = (rb.end + 1) % rb.cap - rb.size++ + // Buffer has data, write after current end + writePos := (rb.end + 1) % capacity + if writePos+len(data) <= capacity { + // No wrap needed - single copy + copy(rb.buffer[writePos:], data) + rb.end = (rb.end + len(data)) % capacity + } else { + // Need to wrap around - two copies + firstChunk := capacity - writePos + copy(rb.buffer[writePos:], data[:firstChunk]) + copy(rb.buffer[0:], data[firstChunk:]) + rb.end = len(data) - firstChunk - 1 } - - return written, evicted } // evict removes the specified number of bytes from the beginning of the buffer. -// Must be called with lock held. -func (rb *RingBuffer) evict(count int) { - if count >= rb.size { +func (rb *ringBuffer) evict(count int) { + if count >= rb.Size() { // Evict everything rb.start = 0 - rb.end = 0 - rb.size = 0 + rb.end = -1 return } - rb.start = (rb.start + count) % rb.cap - rb.size -= count + rb.start = (rb.start + count) % len(rb.buffer) + // Buffer remains non-empty after partial eviction } // ReadLast returns the last n bytes from the buffer. -// If n is greater than the available data, returns all available data. -// If n is 0 or negative, returns nil. -func (rb *RingBuffer) ReadLast(n int) ([]byte, error) { - rb.mu.RLock() - defer rb.mu.RUnlock() +// If n is greater than the available data, returns an error. +// If n is negative, returns an error. 
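+//
+// A short example of the contract (a sketch, starting from an empty buffer):
+//
+//	rb := newRingBuffer(10)
+//	rb.Write([]byte("hello"))
+//	data, _ := rb.ReadLast(3) // data == []byte("llo")
+//	_, err := rb.ReadLast(6)  // error: only 5 bytes are buffered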
+func (rb *ringBuffer) ReadLast(n int) ([]byte, error) { + if n < 0 { + return nil, xerrors.New("cannot read negative number of bytes") + } - if n <= 0 { + if n == 0 { return nil, nil } - if rb.size == 0 { - return nil, xerrors.New("buffer is empty") - } + size := rb.Size() // If requested more than available, return error - if n > rb.size { - return nil, xerrors.Errorf("requested %d bytes but only %d available", n, rb.size) + if n > size { + return nil, xerrors.Errorf("requested %d bytes but only %d available", n, size) } result := make([]byte, n) + capacity := len(rb.buffer) // Calculate where to start reading from (n bytes before the end) - startOffset := rb.size - n - actualStart := rb.start + startOffset - if rb.cap > 0 { - actualStart %= rb.cap - } + startOffset := size - n + actualStart := (rb.start + startOffset) % capacity // Copy the last n bytes - if actualStart+n <= rb.cap { + if actualStart+n <= capacity { // No wrap needed copy(result, rb.buffer[actualStart:actualStart+n]) } else { // Need to wrap around - firstChunk := rb.cap - actualStart - copy(result[0:firstChunk], rb.buffer[actualStart:rb.cap]) + firstChunk := capacity - actualStart + copy(result[0:firstChunk], rb.buffer[actualStart:capacity]) copy(result[firstChunk:], rb.buffer[0:n-firstChunk]) } diff --git a/agent/immortalstreams/backedpipe/ring_buffer_internal_test.go b/agent/immortalstreams/backedpipe/ring_buffer_internal_test.go index 5a23880774057..34fe5fb1cbf6e 100644 --- a/agent/immortalstreams/backedpipe/ring_buffer_internal_test.go +++ b/agent/immortalstreams/backedpipe/ring_buffer_internal_test.go @@ -1,162 +1,264 @@ package backedpipe import ( - "fmt" - "sync" + "bytes" + "os" + "runtime" "testing" "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "github.com/coder/coder/v2/testutil" ) -func TestRingBuffer_ClearInternal(t *testing.T) { +func TestMain(m *testing.M) { + if runtime.GOOS == "windows" { + // Don't run goleak on windows tests, they're super flaky right now. + // See: https://github.com/coder/coder/issues/8954 + os.Exit(m.Run()) + } + goleak.VerifyTestMain(m, testutil.GoleakOptions...) 
+} + +func TestRingBuffer_NewRingBuffer(t *testing.T) { t.Parallel() - rb := NewRingBufferWithCapacity(10) - rb.Write([]byte("hello")) - require.Equal(t, 5, rb.size) + rb := newRingBuffer(100) + // Test that we can write and read from the buffer + rb.Write([]byte("test")) - rb.Clear() - require.Equal(t, 0, rb.size) - require.Equal(t, "", rb.String()) + data, err := rb.ReadLast(4) + require.NoError(t, err) + require.Equal(t, []byte("test"), data) } -func TestRingBuffer_Available(t *testing.T) { +func TestRingBuffer_WriteAndRead(t *testing.T) { t.Parallel() - rb := NewRingBufferWithCapacity(10) - require.Equal(t, 10, rb.Available()) + rb := newRingBuffer(10) + // Write some data rb.Write([]byte("hello")) - require.Equal(t, 5, rb.Available()) + // Read last 4 bytes + data, err := rb.ReadLast(4) + require.NoError(t, err) + require.Equal(t, "ello", string(data)) + + // Write more data rb.Write([]byte("world")) - require.Equal(t, 0, rb.Available()) + + // Read last 5 bytes + data, err = rb.ReadLast(5) + require.NoError(t, err) + require.Equal(t, "world", string(data)) + + // Read last 3 bytes + data, err = rb.ReadLast(3) + require.NoError(t, err) + require.Equal(t, "rld", string(data)) + + // Read more than available (should be 10 bytes total) + _, err = rb.ReadLast(15) + require.Error(t, err) + require.Contains(t, err.Error(), "requested 15 bytes but only") } -func TestRingBuffer_StringInternal(t *testing.T) { +func TestRingBuffer_OverflowEviction(t *testing.T) { t.Parallel() - rb := NewRingBufferWithCapacity(10) - require.Equal(t, "", rb.String()) + rb := newRingBuffer(5) - rb.Write([]byte("hello")) - require.Equal(t, "hello", rb.String()) + // Fill buffer + rb.Write([]byte("abcde")) - rb.Write([]byte("world")) - require.Equal(t, "helloworld", rb.String()) + // Overflow should evict oldest data + rb.Write([]byte("fg")) + + // Should now contain "cdefg" + data, err := rb.ReadLast(5) + require.NoError(t, err) + require.Equal(t, []byte("cdefg"), data) } -func TestRingBuffer_StringWithWrapAround(t *testing.T) { +func TestRingBuffer_LargeWrite(t *testing.T) { t.Parallel() - rb := NewRingBufferWithCapacity(5) - rb.Write([]byte("hello")) - require.Equal(t, "hello", rb.String()) + rb := newRingBuffer(5) - rb.Write([]byte("world")) - require.Equal(t, "world", rb.String()) + // Write data larger than capacity + rb.Write([]byte("abcdefghij")) + + // Should contain last 5 bytes + data, err := rb.ReadLast(5) + require.NoError(t, err) + require.Equal(t, []byte("fghij"), data) } -func TestRingBuffer_ConcurrentAccessWithString(t *testing.T) { +func TestRingBuffer_WrapAround(t *testing.T) { t.Parallel() - rb := NewRingBufferWithCapacity(1000) - var wg sync.WaitGroup - - // Start multiple goroutines writing - for i := 0; i < 10; i++ { - wg.Add(1) - go func(id int) { - defer wg.Done() - data := fmt.Sprintf("data-%d", id) - for j := 0; j < 100; j++ { - rb.Write([]byte(data)) - } - }(i) - } + rb := newRingBuffer(5) + + // Fill buffer + rb.Write([]byte("abcde")) + + // Write more to cause wrap-around + rb.Write([]byte("fgh")) - wg.Wait() + // Should contain "defgh" + data, err := rb.ReadLast(5) + require.NoError(t, err) + require.Equal(t, []byte("defgh"), data) - // Verify buffer is still in valid state - require.NotEmpty(t, rb.String()) + // Test reading last 3 bytes after wrap + data, err = rb.ReadLast(3) + require.NoError(t, err) + require.Equal(t, []byte("fgh"), data) } -func TestRingBuffer_EdgeCaseEvictionWithString(t *testing.T) { +func TestRingBuffer_ReadLastEdgeCases(t *testing.T) { t.Parallel() - rb := 
NewRingBufferWithCapacity(3) + rb := newRingBuffer(3) + + // Write some data (5 bytes to a 3-byte buffer, so only last 3 bytes remain) rb.Write([]byte("hello")) - rb.Write([]byte("world")) - // Should evict "he" and keep "llo world" - require.Equal(t, "rld", rb.String()) + // Test reading negative count + data, err := rb.ReadLast(-1) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot read negative number of bytes") + require.Nil(t, data) + + // Test reading zero bytes + data, err = rb.ReadLast(0) + require.NoError(t, err) + require.Nil(t, data) + + // Test reading more than available (buffer has 3 bytes, try to read 10) + _, err = rb.ReadLast(10) + require.Error(t, err) + require.Contains(t, err.Error(), "requested 10 bytes but only 3 available") + + // Test reading exact amount available + data, err = rb.ReadLast(3) + require.NoError(t, err) + require.Equal(t, []byte("llo"), data) +} - // Write more data to cause more eviction - rb.Write([]byte("test")) - require.Equal(t, "est", rb.String()) +func TestRingBuffer_EmptyWrite(t *testing.T) { + t.Parallel() + + rb := newRingBuffer(10) + + // Write empty data + rb.Write([]byte{}) + + // Buffer should still be empty + _, err := rb.ReadLast(5) + require.Error(t, err) + require.Contains(t, err.Error(), "requested 5 bytes but only 0 available") } -// TestRingBuffer_ComplexWrapAroundScenarioWithString tests complex wrap-around with String -func TestRingBuffer_ComplexWrapAroundScenarioWithString(t *testing.T) { +// TestRingBuffer_ConcurrentAccess removed - the ring buffer is no longer thread-safe +// by design, as it relies on external synchronization provided by BackedWriter. + +func TestRingBuffer_MultipleWrites(t *testing.T) { t.Parallel() - rb := NewRingBufferWithCapacity(5) + rb := newRingBuffer(10) - // Fill buffer - rb.Write([]byte("abcde")) - require.Equal(t, "abcde", rb.String()) + // Write data in chunks + rb.Write([]byte("ab")) + rb.Write([]byte("cd")) + rb.Write([]byte("ef")) - // Write more to cause wrap-around - rb.Write([]byte("fgh")) - require.Equal(t, "defgh", rb.String()) + data, err := rb.ReadLast(6) + require.NoError(t, err) + require.Equal(t, []byte("abcdef"), data) - // Write even more - rb.Write([]byte("ijklmn")) - require.Equal(t, "jklmn", rb.String()) + // Test partial reads + data, err = rb.ReadLast(4) + require.NoError(t, err) + require.Equal(t, []byte("cdef"), data) + + data, err = rb.ReadLast(2) + require.NoError(t, err) + require.Equal(t, []byte("ef"), data) } -// Helper function to get available space (for internal tests only) -func (rb *RingBuffer) Available() int { - rb.mu.RLock() - defer rb.mu.RUnlock() - return rb.cap - rb.size +func TestRingBuffer_EdgeCaseEviction(t *testing.T) { + t.Parallel() + + rb := newRingBuffer(3) + + // Write data that will cause eviction + rb.Write([]byte("abc")) + + // Write more to cause eviction + rb.Write([]byte("d")) + + // Should now contain "bcd" + data, err := rb.ReadLast(3) + require.NoError(t, err) + require.Equal(t, []byte("bcd"), data) } -// Helper function to clear buffer (for internal tests only) -func (rb *RingBuffer) Clear() { - rb.mu.Lock() - defer rb.mu.Unlock() +func TestRingBuffer_ComplexWrapAroundScenario(t *testing.T) { + t.Parallel() + + rb := newRingBuffer(8) - rb.start = 0 - rb.end = 0 - rb.size = 0 + // Fill buffer + rb.Write([]byte("12345678")) + + // Evict some and add more to create complex wrap scenario + rb.Write([]byte("abcd")) + data, err := rb.ReadLast(8) + require.NoError(t, err) + require.Equal(t, []byte("5678abcd"), data) + + // Add 
more + rb.Write([]byte("xyz")) + data, err = rb.ReadLast(8) + require.NoError(t, err) + require.Equal(t, []byte("8abcdxyz"), data) + + // Test reading various amounts from the end + data, err = rb.ReadLast(7) + require.NoError(t, err) + require.Equal(t, []byte("abcdxyz"), data) + + data, err = rb.ReadLast(4) + require.NoError(t, err) + require.Equal(t, []byte("dxyz"), data) } -// Helper function to get string representation (for internal tests only) -func (rb *RingBuffer) String() string { - rb.mu.RLock() - defer rb.mu.RUnlock() +// Benchmark tests for performance validation +func BenchmarkRingBuffer_Write(b *testing.B) { + rb := newRingBuffer(64 * 1024 * 1024) // 64MB for benchmarks + data := bytes.Repeat([]byte("x"), 1024) // 1KB writes - if rb.size == 0 { - return "" + b.ResetTimer() + for i := 0; i < b.N; i++ { + rb.Write(data) } +} - // readAllInternal equivalent for internal tests - if rb.size == 0 { - return "" +func BenchmarkRingBuffer_ReadLast(b *testing.B) { + rb := newRingBuffer(64 * 1024 * 1024) // 64MB for benchmarks + // Fill buffer with test data + for i := 0; i < 64; i++ { + rb.Write(bytes.Repeat([]byte("x"), 1024)) } - result := make([]byte, rb.size) - - if rb.start+rb.size <= rb.cap { - // No wrap needed - copy(result, rb.buffer[rb.start:rb.start+rb.size]) - } else { - // Need to wrap around - firstChunk := rb.cap - rb.start - copy(result[0:firstChunk], rb.buffer[rb.start:rb.cap]) - copy(result[firstChunk:], rb.buffer[0:rb.size-firstChunk]) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := rb.ReadLast((i % 100) + 1) + if err != nil { + b.Fatal(err) + } } - - return string(result) } diff --git a/agent/immortalstreams/backedpipe/ring_buffer_test.go b/agent/immortalstreams/backedpipe/ring_buffer_test.go deleted file mode 100644 index 0d9b4a12dc947..0000000000000 --- a/agent/immortalstreams/backedpipe/ring_buffer_test.go +++ /dev/null @@ -1,339 +0,0 @@ -package backedpipe_test - -import ( - "bytes" - "fmt" - "os" - "runtime" - "sync" - "testing" - - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - - "github.com/coder/coder/v2/agent/immortalstreams/backedpipe" - "github.com/coder/coder/v2/testutil" -) - -func TestMain(m *testing.M) { - if runtime.GOOS == "windows" { - // Don't run goleak on windows tests, they're super flaky right now. - // See: https://github.com/coder/coder/issues/8954 - os.Exit(m.Run()) - } - goleak.VerifyTestMain(m, testutil.GoleakOptions...) 
-} - -func TestRingBuffer_NewRingBuffer(t *testing.T) { - t.Parallel() - - rb := backedpipe.NewRingBufferWithCapacity(100) - // Test that we can write and read from the buffer - written, evicted := rb.Write([]byte("test")) - require.Equal(t, 4, written) - require.Equal(t, 0, evicted) - - data, err := rb.ReadLast(4) - require.NoError(t, err) - require.Equal(t, []byte("test"), data) -} - -func TestRingBuffer_WriteAndRead(t *testing.T) { - t.Parallel() - - rb := backedpipe.NewRingBufferWithCapacity(10) - - // Write some data - rb.Write([]byte("hello")) - - // Read last 4 bytes - data, err := rb.ReadLast(4) - require.NoError(t, err) - require.Equal(t, "ello", string(data)) - - // Write more data - rb.Write([]byte("world")) - - // Read last 5 bytes - data, err = rb.ReadLast(5) - require.NoError(t, err) - require.Equal(t, "world", string(data)) - - // Read last 3 bytes - data, err = rb.ReadLast(3) - require.NoError(t, err) - require.Equal(t, "rld", string(data)) - - // Read more than available (should be 10 bytes total) - _, err = rb.ReadLast(15) - require.Error(t, err) - require.Contains(t, err.Error(), "requested 15 bytes but only") -} - -func TestRingBuffer_OverflowEviction(t *testing.T) { - t.Parallel() - - rb := backedpipe.NewRingBufferWithCapacity(5) - - // Fill buffer - written, evicted := rb.Write([]byte("abcde")) - require.Equal(t, 5, written) - require.Equal(t, 0, evicted) - - // Overflow should evict oldest data - written, evicted = rb.Write([]byte("fg")) - require.Equal(t, 2, written) - require.Equal(t, 2, evicted) - - // Should now contain "cdefg" - data, err := rb.ReadLast(5) - require.NoError(t, err) - require.Equal(t, []byte("cdefg"), data) -} - -func TestRingBuffer_LargeWrite(t *testing.T) { - t.Parallel() - - rb := backedpipe.NewRingBufferWithCapacity(5) - - // Write data larger than capacity - written, evicted := rb.Write([]byte("abcdefghij")) - require.Equal(t, 5, written) - require.Equal(t, 5, evicted) - - // Should contain last 5 bytes - data, err := rb.ReadLast(5) - require.NoError(t, err) - require.Equal(t, []byte("fghij"), data) -} - -func TestRingBuffer_WrapAround(t *testing.T) { - t.Parallel() - - rb := backedpipe.NewRingBufferWithCapacity(5) - - // Fill buffer - rb.Write([]byte("abcde")) - - // Write more to cause wrap-around - rb.Write([]byte("fgh")) - - // Should contain "defgh" - data, err := rb.ReadLast(5) - require.NoError(t, err) - require.Equal(t, []byte("defgh"), data) - - // Test reading last 3 bytes after wrap - data, err = rb.ReadLast(3) - require.NoError(t, err) - require.Equal(t, []byte("fgh"), data) -} - -func TestRingBuffer_ReadLastEdgeCases(t *testing.T) { - t.Parallel() - - rb := backedpipe.NewRingBufferWithCapacity(3) - - // Write some data (5 bytes to a 3-byte buffer, so only last 3 bytes remain) - rb.Write([]byte("hello")) - - // Test reading negative count - data, err := rb.ReadLast(-1) - require.NoError(t, err) - require.Nil(t, data) - - // Test reading zero bytes - data, err = rb.ReadLast(0) - require.NoError(t, err) - require.Nil(t, data) - - // Test reading more than available (buffer has 3 bytes, try to read 10) - _, err = rb.ReadLast(10) - require.Error(t, err) - require.Contains(t, err.Error(), "requested 10 bytes but only 3 available") - - // Test reading exact amount available - data, err = rb.ReadLast(3) - require.NoError(t, err) - require.Equal(t, []byte("llo"), data) -} - -func TestRingBuffer_EmptyWrite(t *testing.T) { - t.Parallel() - - rb := backedpipe.NewRingBufferWithCapacity(10) - - // Write empty data - written, evicted := 
rb.Write([]byte{}) - require.Equal(t, 0, written) - require.Equal(t, 0, evicted) - - // Buffer should still be empty - _, err := rb.ReadLast(5) - require.Error(t, err) - require.Contains(t, err.Error(), "buffer is empty") -} - -func TestRingBuffer_ConcurrentAccess(t *testing.T) { - t.Parallel() - - rb := backedpipe.NewRingBufferWithCapacity(1000) - var wg sync.WaitGroup - - // Start multiple goroutines writing - for i := 0; i < 10; i++ { - wg.Add(1) - go func(id int) { - defer wg.Done() - data := []byte(fmt.Sprintf("data-%d", id)) - for j := 0; j < 100; j++ { - rb.Write(data) - } - }(i) - } - - // Start multiple goroutines reading - for i := 0; i < 5; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for j := 0; j < 100; j++ { - _, err := rb.ReadLast(100) - if err != nil { - // Error is expected if buffer doesn't have enough data - continue - } - } - }() - } - - wg.Wait() - - // Verify buffer is still in valid state - data, err := rb.ReadLast(1000) - require.NoError(t, err) - require.NotNil(t, data) -} - -func TestRingBuffer_MultipleWrites(t *testing.T) { - t.Parallel() - - rb := backedpipe.NewRingBufferWithCapacity(10) - - // Write data in chunks - rb.Write([]byte("ab")) - rb.Write([]byte("cd")) - rb.Write([]byte("ef")) - - data, err := rb.ReadLast(6) - require.NoError(t, err) - require.Equal(t, []byte("abcdef"), data) - - // Test partial reads - data, err = rb.ReadLast(4) - require.NoError(t, err) - require.Equal(t, []byte("cdef"), data) - - data, err = rb.ReadLast(2) - require.NoError(t, err) - require.Equal(t, []byte("ef"), data) -} - -func TestRingBuffer_EdgeCaseEviction(t *testing.T) { - t.Parallel() - - rb := backedpipe.NewRingBufferWithCapacity(3) - - // Write data that will cause eviction - written, evicted := rb.Write([]byte("abc")) - require.Equal(t, 3, written) - require.Equal(t, 0, evicted) - - // Write more to cause eviction - written, evicted = rb.Write([]byte("d")) - require.Equal(t, 1, written) - require.Equal(t, 1, evicted) - - // Should now contain "bcd" - data, err := rb.ReadLast(3) - require.NoError(t, err) - require.Equal(t, []byte("bcd"), data) -} - -func TestRingBuffer_ComplexWrapAroundScenario(t *testing.T) { - t.Parallel() - - rb := backedpipe.NewRingBufferWithCapacity(8) - - // Fill buffer - rb.Write([]byte("12345678")) - - // Evict some and add more to create complex wrap scenario - rb.Write([]byte("abcd")) - data, err := rb.ReadLast(8) - require.NoError(t, err) - require.Equal(t, []byte("5678abcd"), data) - - // Add more - rb.Write([]byte("xyz")) - data, err = rb.ReadLast(8) - require.NoError(t, err) - require.Equal(t, []byte("8abcdxyz"), data) - - // Test reading various amounts from the end - data, err = rb.ReadLast(7) - require.NoError(t, err) - require.Equal(t, []byte("abcdxyz"), data) - - data, err = rb.ReadLast(4) - require.NoError(t, err) - require.Equal(t, []byte("dxyz"), data) -} - -// Benchmark tests for performance validation -func BenchmarkRingBuffer_Write(b *testing.B) { - rb := backedpipe.NewRingBuffer() // Use full 64MB for benchmarks - data := bytes.Repeat([]byte("x"), 1024) // 1KB writes - - b.ResetTimer() - for i := 0; i < b.N; i++ { - rb.Write(data) - } -} - -func BenchmarkRingBuffer_ReadLast(b *testing.B) { - rb := backedpipe.NewRingBuffer() // Use full 64MB for benchmarks - // Fill buffer with test data - for i := 0; i < 64; i++ { - rb.Write(bytes.Repeat([]byte("x"), 1024)) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, err := rb.ReadLast((i % 100) + 1) - if err != nil { - b.Fatal(err) - } - } -} - -func 
BenchmarkRingBuffer_ConcurrentAccess(b *testing.B) {
-	rb := backedpipe.NewRingBuffer() // Use full 64MB for benchmarks
-	data := bytes.Repeat([]byte("x"), 100)
-
-	// Pre-fill buffer with enough data
-	for i := 0; i < 100; i++ {
-		rb.Write(data)
-	}
-
-	b.ResetTimer()
-	b.RunParallel(func(pb *testing.PB) {
-		for pb.Next() {
-			rb.Write(data)
-			_, err := rb.ReadLast(100) // Read only what we know is available
-			if err != nil {
-				b.Fatal(err)
-			}
-		}
-	})
-}

From a303e68792bf3ae4cee1d0835645d5a605a15941 Mon Sep 17 00:00:00 2001
From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com>
Date: Wed, 13 Aug 2025 14:55:29 +0000
Subject: [PATCH 05/11] fmt

---
 agent/immortalstreams/backedpipe/backed_writer.go | 1 -
 1 file changed, 1 deletion(-)

diff --git a/agent/immortalstreams/backedpipe/backed_writer.go b/agent/immortalstreams/backedpipe/backed_writer.go
index b268f48bc77eb..707abb16271a8 100644
--- a/agent/immortalstreams/backedpipe/backed_writer.go
+++ b/agent/immortalstreams/backedpipe/backed_writer.go
@@ -73,7 +73,6 @@ func (bw *BackedWriter) Write(p []byte) (int, error) {
 
 	// Write to underlying writer
 	n, err := bw.writer.Write(p)
-
 	if err != nil {
 		// Connection failed, mark as disconnected
 		bw.writer = nil

From fd2fd868d6872add95ab1e84c0b3e9d87a4d7f9d Mon Sep 17 00:00:00 2001
From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com>
Date: Wed, 13 Aug 2025 15:07:06 +0000
Subject: [PATCH 06/11] reverted WaitSuperShort - not used anymore

---
 testutil/duration.go | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/testutil/duration.go b/testutil/duration.go
index 821684f6b0f98..a8c35030cdea2 100644
--- a/testutil/duration.go
+++ b/testutil/duration.go
@@ -7,11 +7,10 @@ import (
 // Constants for timing out operations, usable for creating contexts
 // that timeout or in require.Eventually.
 const (
-	WaitSuperShort = 100 * time.Millisecond
-	WaitShort      = 10 * time.Second
-	WaitMedium     = 15 * time.Second
-	WaitLong       = 25 * time.Second
-	WaitSuperLong  = 60 * time.Second
+	WaitShort     = 10 * time.Second
+	WaitMedium    = 15 * time.Second
+	WaitLong      = 25 * time.Second
+	WaitSuperLong = 60 * time.Second
 )
 
 // Constants for delaying repeated operations, e.g. 
in From 68c08b15b61073160bf7c8d119974a85fed7bad8 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Thu, 14 Aug 2025 23:30:51 +0000 Subject: [PATCH 07/11] changed err to EOF instead of ErrClosedPipe and changed backed_writer to block on errors --- .../immortalstreams/backedpipe/backed_pipe.go | 6 +- .../backedpipe/backed_pipe_test.go | 6 +- .../backedpipe/backed_writer.go | 53 +++++-- .../backedpipe/backed_writer_test.go | 132 +++++++++++++++--- 4 files changed, 154 insertions(+), 43 deletions(-) diff --git a/agent/immortalstreams/backedpipe/backed_pipe.go b/agent/immortalstreams/backedpipe/backed_pipe.go index 8161b91533233..8e9b68ee21364 100644 --- a/agent/immortalstreams/backedpipe/backed_pipe.go +++ b/agent/immortalstreams/backedpipe/backed_pipe.go @@ -110,7 +110,7 @@ func (bp *BackedPipe) Read(p []byte) (int, error) { bp.mu.RUnlock() if closed { - return 0, io.ErrClosedPipe + return 0, io.EOF } return reader.Read(p) @@ -124,7 +124,7 @@ func (bp *BackedPipe) Write(p []byte) (int, error) { bp.mu.RUnlock() if closed { - return 0, io.ErrClosedPipe + return 0, io.EOF } return writer.Write(p) @@ -294,7 +294,7 @@ func (bp *BackedPipe) ForceReconnect() error { defer bp.mu.Unlock() if bp.closed { - return nil, io.ErrClosedPipe + return nil, io.EOF } return nil, bp.reconnectLocked() diff --git a/agent/immortalstreams/backedpipe/backed_pipe_test.go b/agent/immortalstreams/backedpipe/backed_pipe_test.go index 1cd5bd227ebcd..30a751342b323 100644 --- a/agent/immortalstreams/backedpipe/backed_pipe_test.go +++ b/agent/immortalstreams/backedpipe/backed_pipe_test.go @@ -382,10 +382,10 @@ func TestBackedPipe_Close(t *testing.T) { // Operations after close should fail _, err = bp.Read(make([]byte, 10)) - require.Equal(t, io.ErrClosedPipe, err) + require.Equal(t, io.EOF, err) _, err = bp.Write([]byte("test")) - require.Equal(t, io.ErrClosedPipe, err) + require.Equal(t, io.EOF, err) } func TestBackedPipe_CloseIdempotent(t *testing.T) { @@ -567,7 +567,7 @@ func TestBackedPipe_ForceReconnectWhenClosed(t *testing.T) { // Try to force reconnect when closed err = bp.ForceReconnect() require.Error(t, err) - require.Equal(t, io.ErrClosedPipe, err) + require.Equal(t, io.EOF, err) } func TestBackedPipe_ForceReconnectWhenDisconnected(t *testing.T) { diff --git a/agent/immortalstreams/backedpipe/backed_writer.go b/agent/immortalstreams/backedpipe/backed_writer.go index 707abb16271a8..3d791aaa5cc0e 100644 --- a/agent/immortalstreams/backedpipe/backed_writer.go +++ b/agent/immortalstreams/backedpipe/backed_writer.go @@ -40,8 +40,21 @@ func NewBackedWriter(capacity int, errorChan chan<- error) *BackedWriter { return bw } +// blockUntilConnectedOrClosed blocks until either a writer is available or the BackedWriter is closed. +// Returns io.EOF if closed while waiting, nil if connected. +func (bw *BackedWriter) blockUntilConnectedOrClosed() error { + for bw.writer == nil && !bw.closed { + bw.cond.Wait() + } + if bw.closed { + return io.EOF + } + return nil +} + // Write implements io.Writer. -// When connected, it writes to both the ring buffer and the underlying writer. +// When connected, it writes to both the ring buffer (to preserve data in case we need to replay it) +// and the underlying writer. // If the underlying write fails, the writer is marked as disconnected and the write blocks // until reconnection occurs. 
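 //
 // A sketch of the resulting behavior (names are illustrative):
 //
 //	n, err := bw.Write(p)
 //	// While disconnected, the call parks on a condition variable; once
 //	// Reconnect replays the buffered bytes it returns n == len(p), and
 //	// after Close it returns io.EOF instead.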
func (bw *BackedWriter) Write(p []byte) (int, error) { @@ -53,25 +66,19 @@ func (bw *BackedWriter) Write(p []byte) (int, error) { defer bw.mu.Unlock() if bw.closed { - return 0, io.ErrClosedPipe + return 0, io.EOF } // Block until connected - for bw.writer == nil && !bw.closed { - bw.cond.Wait() - } - - // Check if we were closed while waiting - if bw.closed { - return 0, io.ErrClosedPipe + if err := bw.blockUntilConnectedOrClosed(); err != nil { + return 0, err } - // Always write to buffer first + // Write to buffer bw.buffer.Write(p) - // Always advance sequence number by the full length bw.sequenceNum += uint64(len(p)) - // Write to underlying writer + // Try to write to underlying writer n, err := bw.writer.Write(p) if err != nil { // Connection failed, mark as disconnected @@ -82,19 +89,35 @@ func (bw *BackedWriter) Write(p []byte) (int, error) { case bw.errorChan <- err: default: } - return 0, err + + // Block until reconnected - reconnection will replay this data + if err := bw.blockUntilConnectedOrClosed(); err != nil { + return 0, err + } + + // Don't retry - reconnection replay handled it + return len(p), nil } if n != len(p) { + // Partial write - treat as failure bw.writer = nil err = xerrors.Errorf("short write: %d bytes written, %d expected", n, len(p)) select { case bw.errorChan <- err: default: } - return 0, err + + // Block until reconnected - reconnection will replay this data + if err := bw.blockUntilConnectedOrClosed(); err != nil { + return 0, err + } + + // Don't retry - reconnection replay handled it + return len(p), nil } + // Write succeeded return len(p), nil } @@ -169,7 +192,7 @@ func (bw *BackedWriter) Reconnect(replayFromSeq uint64, newWriter io.Writer) err } // Close closes the writer and prevents further writes. -// After closing, all Write calls will return io.ErrClosedPipe. +// After closing, all Write calls will return io.EOF. // This code keeps the Close() signature consistent with io.Closer, // but it never actually returns an error. 
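 //
 // Close also wakes any Write calls parked waiting for a reconnection;
 // those writes return io.EOF. For example (sketch):
 //
 //	go func() { _, _ = bw.Write(data) }() // blocks while disconnected
 //	_ = bw.Close()                        // the blocked write returns io.EOF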
func (bw *BackedWriter) Close() error { diff --git a/agent/immortalstreams/backedpipe/backed_writer_test.go b/agent/immortalstreams/backedpipe/backed_writer_test.go index 02f48d811f5c6..01d794eb9efe2 100644 --- a/agent/immortalstreams/backedpipe/backed_writer_test.go +++ b/agent/immortalstreams/backedpipe/backed_writer_test.go @@ -143,8 +143,9 @@ func TestBackedWriter_WriteToUnderlyingWhenConnected(t *testing.T) { require.Equal(t, []byte("hello"), writer.buffer.Bytes()) } -func TestBackedWriter_DisconnectOnWriteFailure(t *testing.T) { +func TestBackedWriter_BlockOnWriteFailure(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) errorChan := make(chan error, 1) bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) @@ -157,11 +158,26 @@ func TestBackedWriter_DisconnectOnWriteFailure(t *testing.T) { // Cause write to fail writer.setError(xerrors.New("write failed")) - // Write should fail and disconnect - n, err := bw.Write([]byte("hello")) - require.Error(t, err) // Write should fail - require.Equal(t, 0, n) - require.False(t, bw.Connected()) // Should be disconnected + // Write should block when underlying writer fails + writeComplete := make(chan struct{}) + var writeErr error + var n int + + go func() { + defer close(writeComplete) + n, writeErr = bw.Write([]byte("hello")) + }() + + // Verify write is blocked + select { + case <-writeComplete: + t.Fatal("Write should have blocked when underlying writer fails") + case <-time.After(50 * time.Millisecond): + // Expected - write is blocked + } + + // Should be disconnected + require.False(t, bw.Connected()) // Error should be sent to error channel select { @@ -170,6 +186,19 @@ func TestBackedWriter_DisconnectOnWriteFailure(t *testing.T) { default: t.Fatal("Expected error to be sent to error channel") } + + // Reconnect with working writer and verify write completes + writer2 := newMockWriter() + err = bw.Reconnect(0, writer2) // Replay from beginning + require.NoError(t, err) + + // Write should now complete + testutil.TryReceive(ctx, t, writeComplete) + + require.NoError(t, writeErr) + require.Equal(t, 5, n) + require.Equal(t, uint64(5), bw.SequenceNum()) + require.Equal(t, []byte("hello"), writer2.buffer.Bytes()) } func TestBackedWriter_ReplayOnReconnect(t *testing.T) { @@ -193,8 +222,25 @@ func TestBackedWriter_ReplayOnReconnect(t *testing.T) { // Disconnect by causing a write failure writer1.setError(xerrors.New("connection lost")) - _, err = bw.Write([]byte("test")) - require.Error(t, err) + + // Write should block when underlying writer fails + writeComplete := make(chan struct{}) + var writeErr error + var n int + + go func() { + defer close(writeComplete) + n, writeErr = bw.Write([]byte("test")) + }() + + // Verify write is blocked + select { + case <-writeComplete: + t.Fatal("Write should have blocked when underlying writer fails") + case <-time.After(50 * time.Millisecond): + // Expected - write is blocked + } + require.False(t, bw.Connected()) // Reconnect with new writer and request replay from beginning @@ -202,6 +248,17 @@ func TestBackedWriter_ReplayOnReconnect(t *testing.T) { err = bw.Reconnect(0, writer2) require.NoError(t, err) + // Write should now complete + select { + case <-writeComplete: + // Expected - write completed + case <-time.After(100 * time.Millisecond): + t.Fatal("Write should have completed after reconnection") + } + + require.NoError(t, writeErr) + require.Equal(t, 4, n) + // Should have replayed all data including the failed write that was buffered 
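The replay assertions in these tests are pure sequence arithmetic: the writer's sequence number counts every byte ever accepted, so the span to resend after a reconnect is the current sequence minus the point the peer acknowledged, served from the tail of the ring buffer. A small illustrative sketch of that accounting (not the package's code):

package main

import "fmt"

// replayCount mirrors the arithmetic behind Reconnect: the sequence number
// counts every byte ever accepted by Write, so the bytes to resend are the
// ones past the peer's acknowledged position, taken from the ring buffer's
// tail.
func replayCount(writerSeq, replayFromSeq uint64) (int, error) {
	if replayFromSeq > writerSeq {
		return 0, fmt.Errorf("cannot replay from future sequence %d (current %d)", replayFromSeq, writerSeq)
	}
	return int(writerSeq - replayFromSeq), nil
}

func main() {
	// "hello world" (11 bytes) succeeded, then "test" (4 bytes) was buffered
	// during the outage: sequence is 15. A peer that acked byte 11 needs 4.
	fmt.Println(replayCount(15, 11)) // 4 <nil>
	fmt.Println(replayCount(15, 20)) // 0, error: future sequence
}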
require.Equal(t, []byte("hello worldtest"), writer2.buffer.Bytes()) @@ -319,7 +376,7 @@ func TestBackedWriter_Close(t *testing.T) { // Writes after close should fail _, err = bw.Write([]byte("test")) - require.Equal(t, io.ErrClosedPipe, err) + require.Equal(t, io.EOF, err) // Reconnect after close should fail err = bw.Reconnect(0, newMockWriter()) @@ -401,8 +458,9 @@ func TestBackedWriter_ReconnectDuringReplay(t *testing.T) { require.False(t, bw.Connected()) } -func TestBackedWriter_PartialWriteToUnderlying(t *testing.T) { +func TestBackedWriter_BlockOnPartialWrite(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) errorChan := make(chan error, 1) bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) @@ -419,12 +477,26 @@ func TestBackedWriter_PartialWriteToUnderlying(t *testing.T) { bw.Reconnect(0, writer) - // Write should fail due to partial write - n, err := bw.Write([]byte("hello")) - require.Error(t, err) - require.Equal(t, 0, n) + // Write should block due to partial write + writeComplete := make(chan struct{}) + var writeErr error + var n int + + go func() { + defer close(writeComplete) + n, writeErr = bw.Write([]byte("hello")) + }() + + // Verify write is blocked + select { + case <-writeComplete: + t.Fatal("Write should have blocked when underlying writer does partial write") + case <-time.After(50 * time.Millisecond): + // Expected - write is blocked + } + + // Should be disconnected require.False(t, bw.Connected()) - require.Contains(t, err.Error(), "short write") // Error should be sent to error channel select { @@ -433,6 +505,19 @@ func TestBackedWriter_PartialWriteToUnderlying(t *testing.T) { default: t.Fatal("Expected error to be sent to error channel") } + + // Reconnect with working writer and verify write completes + writer2 := newMockWriter() + err := bw.Reconnect(0, writer2) // Replay from beginning + require.NoError(t, err) + + // Write should now complete + testutil.TryReceive(ctx, t, writeComplete) + + require.NoError(t, writeErr) + require.Equal(t, 5, n) + require.Equal(t, uint64(5), bw.SequenceNum()) + require.Equal(t, []byte("hello"), writer2.buffer.Bytes()) } func TestBackedWriter_WriteUnblocksOnReconnect(t *testing.T) { @@ -498,7 +583,7 @@ func TestBackedWriter_CloseUnblocksWaitingWrites(t *testing.T) { // Write should now complete with error err = testutil.RequireReceive(ctx, t, writeComplete) - require.Equal(t, io.ErrClosedPipe, err) + require.Equal(t, io.EOF, err) } func TestBackedWriter_WriteBlocksAfterDisconnection(t *testing.T) { @@ -517,16 +602,13 @@ func TestBackedWriter_WriteBlocksAfterDisconnection(t *testing.T) { _, err = bw.Write([]byte("hello")) require.NoError(t, err) - // Cause disconnection + // Cause disconnection - the write should now block instead of returning an error writer.setError(xerrors.New("connection lost")) - _, err = bw.Write([]byte("world")) - require.Error(t, err) - require.False(t, bw.Connected()) - // Subsequent write should block + // This write should block writeComplete := make(chan error, 1) go func() { - _, err := bw.Write([]byte("blocked")) + _, err := bw.Write([]byte("world")) writeComplete <- err }() @@ -538,6 +620,9 @@ func TestBackedWriter_WriteBlocksAfterDisconnection(t *testing.T) { // Expected - write is blocked } + // Should be disconnected + require.False(t, bw.Connected()) + // Reconnect and verify write completes writer2 := newMockWriter() err = bw.Reconnect(5, writer2) // Replay from after "hello" @@ -545,6 +630,9 @@ func 
TestBackedWriter_WriteBlocksAfterDisconnection(t *testing.T) { err = testutil.RequireReceive(ctx, t, writeComplete) require.NoError(t, err) + + // Check that only "world" was written during replay (not duplicated) + require.Equal(t, []byte("world"), writer2.buffer.Bytes()) // Only "world" since we replayed from sequence 5 } func BenchmarkBackedWriter_Write(b *testing.B) { From 568b1a4f96d960fd037e87aca4586690103471c7 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Thu, 14 Aug 2025 23:43:00 +0000 Subject: [PATCH 08/11] added sentinel errors --- .../immortalstreams/backedpipe/backed_pipe.go | 22 +++++++++++++------ .../backedpipe/backed_pipe_test.go | 6 ++--- .../backedpipe/backed_writer.go | 21 +++++++++++++----- .../backedpipe/backed_writer_test.go | 8 +++---- .../immortalstreams/backedpipe/ring_buffer.go | 4 +--- 5 files changed, 38 insertions(+), 23 deletions(-) diff --git a/agent/immortalstreams/backedpipe/backed_pipe.go b/agent/immortalstreams/backedpipe/backed_pipe.go index 8e9b68ee21364..2e5aeeebc6696 100644 --- a/agent/immortalstreams/backedpipe/backed_pipe.go +++ b/agent/immortalstreams/backedpipe/backed_pipe.go @@ -9,6 +9,15 @@ import ( "golang.org/x/xerrors" ) +var ( + ErrPipeClosed = xerrors.New("pipe is closed") + ErrPipeAlreadyConnected = xerrors.New("pipe is already connected") + ErrReconnectionInProgress = xerrors.New("reconnection already in progress") + ErrReconnectFailed = xerrors.New("reconnect failed") + ErrInvalidSequenceNumber = xerrors.New("remote sequence number exceeds local sequence") + ErrReconnectWriterFailed = xerrors.New("reconnect writer failed") +) + const ( // Default buffer capacity used by the writer - 64MB DefaultBufferSize = 64 * 1024 * 1024 @@ -90,11 +99,11 @@ func (bp *BackedPipe) Connect(_ context.Context) error { // external ctx ignored defer bp.mu.Unlock() if bp.closed { - return xerrors.New("pipe is closed") + return ErrPipeClosed } if bp.connected { - return xerrors.New("pipe is already connected") + return ErrPipeAlreadyConnected } // Use internal context for the actual reconnect operation to ensure @@ -190,7 +199,7 @@ func (bp *BackedPipe) signalConnectionChange() { // reconnectLocked handles the reconnection logic. Must be called with write lock held. 
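One trade-off worth noting with the sentinel list above: returning a bare sentinel (as the next hunk does for ErrReconnectFailed) makes errors.Is matching easy but discards the underlying cause, whereas wrapping with %w keeps both. A sketch of the two shapes, using a stand-in sentinel:

package main

import (
	"errors"
	"fmt"

	"golang.org/x/xerrors"
)

// errReconnectFailed is a stand-in for the package-level sentinel.
var errReconnectFailed = xerrors.New("reconnect failed")

// bare returns only the sentinel, matching this patch's approach: callers
// can errors.Is it, but the dial error itself is gone.
func bare(_ error) error { return errReconnectFailed }

// wrapped keeps the cause in the message while still matching the sentinel.
func wrapped(cause error) error {
	return fmt.Errorf("%w: %v", errReconnectFailed, cause)
}

func main() {
	cause := errors.New("dial tcp 10.0.0.1:22: connection refused")

	fmt.Println(errors.Is(bare(cause), errReconnectFailed))    // true (cause lost)
	fmt.Println(errors.Is(wrapped(cause), errReconnectFailed)) // true (cause kept)
	fmt.Println(wrapped(cause))
}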
func (bp *BackedPipe) reconnectLocked() error { if bp.reconnecting { - return xerrors.New("reconnection already in progress") + return ErrReconnectionInProgress } bp.reconnecting = true @@ -216,14 +225,13 @@ func (bp *BackedPipe) reconnectLocked() error { bp.mu.Lock() if err != nil { - return xerrors.Errorf("reconnect failed: %w", err) + return ErrReconnectFailed } // Validate sequence numbers if readerSeqNum > writerSeqNum { _ = conn.Close() - return xerrors.Errorf("remote sequence number %d exceeds local sequence %d, cannot replay", - readerSeqNum, writerSeqNum) + return ErrInvalidSequenceNumber } // Reconnect reader and writer @@ -239,7 +247,7 @@ func (bp *BackedPipe) reconnectLocked() error { err = bp.writer.Reconnect(readerSeqNum, conn) if err != nil { _ = conn.Close() - return xerrors.Errorf("reconnect writer: %w", err) + return ErrReconnectWriterFailed } // Success - update state diff --git a/agent/immortalstreams/backedpipe/backed_pipe_test.go b/agent/immortalstreams/backedpipe/backed_pipe_test.go index 30a751342b323..f9c5bd0adc760 100644 --- a/agent/immortalstreams/backedpipe/backed_pipe_test.go +++ b/agent/immortalstreams/backedpipe/backed_pipe_test.go @@ -189,7 +189,7 @@ func TestBackedPipe_ConnectAlreadyConnected(t *testing.T) { // Second connect should fail err = bp.Connect(ctx) require.Error(t, err) - require.Contains(t, err.Error(), "already connected") + require.ErrorIs(t, err, backedpipe.ErrPipeAlreadyConnected) } func TestBackedPipe_ConnectAfterClose(t *testing.T) { @@ -206,7 +206,7 @@ func TestBackedPipe_ConnectAfterClose(t *testing.T) { err = bp.Connect(ctx) require.Error(t, err) - require.Contains(t, err.Error(), "closed") + require.ErrorIs(t, err, backedpipe.ErrPipeClosed) } func TestBackedPipe_BasicReadWrite(t *testing.T) { @@ -501,7 +501,7 @@ func TestBackedPipe_ReconnectFunctionFailure(t *testing.T) { err := bp.Connect(ctx) require.Error(t, err) - require.Contains(t, err.Error(), "reconnect failed") + require.ErrorIs(t, err, backedpipe.ErrReconnectFailed) require.False(t, bp.Connected()) } diff --git a/agent/immortalstreams/backedpipe/backed_writer.go b/agent/immortalstreams/backedpipe/backed_writer.go index 3d791aaa5cc0e..894aaa4240118 100644 --- a/agent/immortalstreams/backedpipe/backed_writer.go +++ b/agent/immortalstreams/backedpipe/backed_writer.go @@ -7,6 +7,15 @@ import ( "golang.org/x/xerrors" ) +var ( + ErrWriterClosed = xerrors.New("cannot reconnect closed writer") + ErrNilWriter = xerrors.New("new writer cannot be nil") + ErrFutureSequence = xerrors.New("cannot replay from future sequence") + ErrReplayDataUnavailable = xerrors.New("failed to read replay data") + ErrReplayFailed = xerrors.New("replay failed") + ErrPartialReplay = xerrors.New("partial replay") +) + // BackedWriter wraps an unreliable io.Writer and makes it resilient to disconnections. // It maintains a ring buffer of recent writes for replay during reconnection. 
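The ErrInvalidSequenceNumber check a few hunks above encodes the core replay invariant: the remote reader can resume at or behind the local writer's position, never ahead of it, because bytes past writerSeqNum were never produced. A compact sketch of that validation with illustrative names:

package main

import (
	"errors"
	"fmt"
)

// errInvalidSequenceNumber stands in for the pipe's sentinel of the same idea.
var errInvalidSequenceNumber = errors.New("remote sequence number exceeds local sequence")

// validateHandshake mirrors the post-exchange check in reconnectLocked:
// the remote reader may resume at or behind our writer position, never
// ahead of it, because bytes past writerSeq were never produced.
func validateHandshake(writerSeq, readerSeq uint64) error {
	if readerSeq > writerSeq {
		return fmt.Errorf("%w: remote=%d local=%d", errInvalidSequenceNumber, readerSeq, writerSeq)
	}
	return nil
}

func main() {
	fmt.Println(validateHandshake(100, 80))  // <nil>: replay bytes 80..99 to the peer
	fmt.Println(validateHandshake(100, 120)) // error: peer claims bytes we never wrote
}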
type BackedWriter struct { @@ -129,16 +138,16 @@ func (bw *BackedWriter) Reconnect(replayFromSeq uint64, newWriter io.Writer) err defer bw.mu.Unlock() if bw.closed { - return xerrors.New("cannot reconnect closed writer") + return ErrWriterClosed } if newWriter == nil { - return xerrors.New("new writer cannot be nil") + return ErrNilWriter } // Check if we can replay from the requested sequence number if replayFromSeq > bw.sequenceNum { - return xerrors.Errorf("cannot replay from future sequence %d: current sequence is %d", replayFromSeq, bw.sequenceNum) + return ErrFutureSequence } // Calculate how many bytes we need to replay @@ -156,7 +165,7 @@ func (bw *BackedWriter) Reconnect(replayFromSeq uint64, newWriter io.Writer) err //nolint:gosec // Safe conversion: replayBytes is calculated from uint64 subtraction replayData, err = bw.buffer.ReadLast(int(replayBytes)) if err != nil { - return xerrors.Errorf("failed to read replay data: %w", err) + return ErrReplayDataUnavailable } } @@ -172,12 +181,12 @@ func (bw *BackedWriter) Reconnect(replayFromSeq uint64, newWriter io.Writer) err if err != nil { // Reconnect failed, writer remains nil - return xerrors.Errorf("replay failed: %w", err) + return ErrReplayFailed } if n != len(replayData) { // Reconnect failed, writer remains nil - return xerrors.Errorf("partial replay: wrote %d of %d bytes", n, len(replayData)) + return ErrPartialReplay } } diff --git a/agent/immortalstreams/backedpipe/backed_writer_test.go b/agent/immortalstreams/backedpipe/backed_writer_test.go index 01d794eb9efe2..928cea89b361d 100644 --- a/agent/immortalstreams/backedpipe/backed_writer_test.go +++ b/agent/immortalstreams/backedpipe/backed_writer_test.go @@ -313,7 +313,7 @@ func TestBackedWriter_ReplayFromFutureSequence(t *testing.T) { writer2 := newMockWriter() err = bw.Reconnect(10, writer2) // Future sequence require.Error(t, err) - require.Contains(t, err.Error(), "future sequence") + require.ErrorIs(t, err, backedpipe.ErrFutureSequence) } func TestBackedWriter_ReplayDataLoss(t *testing.T) { @@ -336,7 +336,7 @@ func TestBackedWriter_ReplayDataLoss(t *testing.T) { err = bw.Reconnect(0, writer2) // Try to replay from evicted data // With the new error handling, this should fail because we can't read all the data require.Error(t, err) - require.Contains(t, err.Error(), "failed to read replay data") + require.ErrorIs(t, err, backedpipe.ErrReplayDataUnavailable) } func TestBackedWriter_BufferEviction(t *testing.T) { @@ -381,7 +381,7 @@ func TestBackedWriter_Close(t *testing.T) { // Reconnect after close should fail err = bw.Reconnect(0, newMockWriter()) require.Error(t, err) - require.Contains(t, err.Error(), "closed") + require.ErrorIs(t, err, backedpipe.ErrWriterClosed) } func TestBackedWriter_CloseIdempotent(t *testing.T) { @@ -454,7 +454,7 @@ func TestBackedWriter_ReconnectDuringReplay(t *testing.T) { err = bw.Reconnect(0, writer2) require.Error(t, err) - require.Contains(t, err.Error(), "replay failed") + require.ErrorIs(t, err, backedpipe.ErrReplayFailed) require.False(t, bw.Connected()) } diff --git a/agent/immortalstreams/backedpipe/ring_buffer.go b/agent/immortalstreams/backedpipe/ring_buffer.go index eefec300b306c..91fde569afb25 100644 --- a/agent/immortalstreams/backedpipe/ring_buffer.go +++ b/agent/immortalstreams/backedpipe/ring_buffer.go @@ -1,8 +1,6 @@ package backedpipe -import ( - "golang.org/x/xerrors" -) +import "golang.org/x/xerrors" // ringBuffer implements an efficient circular buffer with a fixed-size allocation. 
// This implementation is not thread-safe and relies on external synchronization. From 07c69636f77ab3ab9bfa9d2fcfeb7b92d1b0b74a Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Thu, 14 Aug 2025 23:47:09 +0000 Subject: [PATCH 09/11] fixed eviction tests in writer and moved from assert to require in reader tests --- .../backedpipe/backed_reader_test.go | 103 +++++++++--------- .../backedpipe/backed_writer_test.go | 18 ++- 2 files changed, 68 insertions(+), 53 deletions(-) diff --git a/agent/immortalstreams/backedpipe/backed_reader_test.go b/agent/immortalstreams/backedpipe/backed_reader_test.go index 25d2038d6d843..019511bbcfc64 100644 --- a/agent/immortalstreams/backedpipe/backed_reader_test.go +++ b/agent/immortalstreams/backedpipe/backed_reader_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/xerrors" @@ -59,9 +58,9 @@ func TestBackedReader_NewBackedReader(t *testing.T) { t.Parallel() br := backedpipe.NewBackedReader() - assert.NotNil(t, br) - assert.Equal(t, uint64(0), br.SequenceNum()) - assert.False(t, br.Connected()) + require.NotNil(t, br) + require.Equal(t, uint64(0), br.SequenceNum()) + require.False(t, br.Connected()) } func TestBackedReader_BasicReadOperation(t *testing.T) { @@ -79,7 +78,7 @@ func TestBackedReader_BasicReadOperation(t *testing.T) { // Get sequence number from reader seq := testutil.RequireReceive(ctx, t, seqNum) - assert.Equal(t, uint64(0), seq) + require.Equal(t, uint64(0), seq) // Send new reader testutil.RequireSend(ctx, t, newR, io.Reader(reader)) @@ -88,16 +87,16 @@ func TestBackedReader_BasicReadOperation(t *testing.T) { buf := make([]byte, 5) n, err := br.Read(buf) require.NoError(t, err) - assert.Equal(t, 5, n) - assert.Equal(t, "hello", string(buf)) - assert.Equal(t, uint64(5), br.SequenceNum()) + require.Equal(t, 5, n) + require.Equal(t, "hello", string(buf)) + require.Equal(t, uint64(5), br.SequenceNum()) // Read more data n, err = br.Read(buf) require.NoError(t, err) - assert.Equal(t, 5, n) - assert.Equal(t, " worl", string(buf)) - assert.Equal(t, uint64(10), br.SequenceNum()) + require.Equal(t, 5, n) + require.Equal(t, " worl", string(buf)) + require.Equal(t, uint64(10), br.SequenceNum()) } func TestBackedReader_ReadBlocksWhenDisconnected(t *testing.T) { @@ -143,8 +142,8 @@ func TestBackedReader_ReadBlocksWhenDisconnected(t *testing.T) { // Wait for read to complete testutil.TryReceive(ctx, t, readDone) - assert.NoError(t, readErr) - assert.Equal(t, "test", string(readBuf)) + require.NoError(t, readErr) + require.Equal(t, "test", string(readBuf)) } func TestBackedReader_ReconnectionAfterFailure(t *testing.T) { @@ -168,8 +167,8 @@ func TestBackedReader_ReconnectionAfterFailure(t *testing.T) { buf := make([]byte, 5) n, err := br.Read(buf) require.NoError(t, err) - assert.Equal(t, "first", string(buf[:n])) - assert.Equal(t, uint64(5), br.SequenceNum()) + require.Equal(t, "first", string(buf[:n])) + require.Equal(t, uint64(5), br.SequenceNum()) // Set up error callback to verify error notification errorReceived := make(chan error, 1) @@ -189,8 +188,8 @@ func TestBackedReader_ReconnectionAfterFailure(t *testing.T) { // Wait for the error to be reported via callback receivedErr := testutil.RequireReceive(ctx, t, errorReceived) - assert.Error(t, receivedErr) - assert.Contains(t, receivedErr.Error(), "connection lost") + require.Error(t, receivedErr) + require.Contains(t, receivedErr.Error(), "connection lost") // 
Verify read is still blocked select { @@ -201,7 +200,7 @@ func TestBackedReader_ReconnectionAfterFailure(t *testing.T) { } // Verify disconnection - assert.False(t, br.Connected()) + require.False(t, br.Connected()) // Reconnect with new reader reader2 := newMockReader("second") @@ -212,12 +211,12 @@ func TestBackedReader_ReconnectionAfterFailure(t *testing.T) { // Get sequence number and send new reader seq := testutil.RequireReceive(ctx, t, seqNum2) - assert.Equal(t, uint64(5), seq) // Should return current sequence number + require.Equal(t, uint64(5), seq) // Should return current sequence number testutil.RequireSend(ctx, t, newR2, io.Reader(reader2)) // Wait for read to unblock and succeed with new data readErr := testutil.RequireReceive(ctx, t, readDone) - assert.NoError(t, readErr) // Should succeed with new reader + require.NoError(t, readErr) // Should succeed with new reader } func TestBackedReader_Close(t *testing.T) { @@ -241,7 +240,7 @@ func TestBackedReader_Close(t *testing.T) { buf := make([]byte, 10) n, err := br.Read(buf) require.NoError(t, err) - assert.Equal(t, 4, n) // "test" is 4 bytes + require.Equal(t, 4, n) // "test" is 4 bytes // Close the reader before EOF triggers reconnection err = br.Close() @@ -249,12 +248,12 @@ func TestBackedReader_Close(t *testing.T) { // After close, reads should return EOF n, err = br.Read(buf) - assert.Equal(t, 0, n) - assert.Equal(t, io.EOF, err) + require.Equal(t, 0, n) + require.Equal(t, io.EOF, err) // Subsequent reads should return EOF _, err = br.Read(buf) - assert.Equal(t, io.EOF, err) + require.Equal(t, io.EOF, err) } func TestBackedReader_CloseIdempotent(t *testing.T) { @@ -263,11 +262,11 @@ func TestBackedReader_CloseIdempotent(t *testing.T) { br := backedpipe.NewBackedReader() err := br.Close() - assert.NoError(t, err) + require.NoError(t, err) // Second close should be no-op err = br.Close() - assert.NoError(t, err) + require.NoError(t, err) } func TestBackedReader_ReconnectAfterClose(t *testing.T) { @@ -286,7 +285,7 @@ func TestBackedReader_ReconnectAfterClose(t *testing.T) { // Should get 0 sequence number for closed reader seq := testutil.TryReceive(ctx, t, seqNum) - assert.Equal(t, uint64(0), seq) + require.Equal(t, uint64(0), seq) } // Helper function to reconnect a reader using channels @@ -315,18 +314,18 @@ func TestBackedReader_SequenceNumberTracking(t *testing.T) { n, err := br.Read(buf) require.NoError(t, err) - assert.Equal(t, 3, n) - assert.Equal(t, uint64(3), br.SequenceNum()) + require.Equal(t, 3, n) + require.Equal(t, uint64(3), br.SequenceNum()) n, err = br.Read(buf) require.NoError(t, err) - assert.Equal(t, 3, n) - assert.Equal(t, uint64(6), br.SequenceNum()) + require.Equal(t, 3, n) + require.Equal(t, uint64(6), br.SequenceNum()) n, err = br.Read(buf) require.NoError(t, err) - assert.Equal(t, 3, n) - assert.Equal(t, uint64(9), br.SequenceNum()) + require.Equal(t, 3, n) + require.Equal(t, uint64(9), br.SequenceNum()) } func TestBackedReader_EOFHandling(t *testing.T) { @@ -348,8 +347,8 @@ func TestBackedReader_EOFHandling(t *testing.T) { buf := make([]byte, 10) n, err := br.Read(buf) require.NoError(t, err) - assert.Equal(t, 4, n) - assert.Equal(t, "test", string(buf[:n])) + require.Equal(t, 4, n) + require.Equal(t, "test", string(buf[:n])) // Next read should encounter EOF, which triggers disconnection // The read should block waiting for reconnection @@ -364,10 +363,10 @@ func TestBackedReader_EOFHandling(t *testing.T) { // Wait for EOF to be reported via error callback receivedErr := 
testutil.RequireReceive(ctx, t, errorReceived) - assert.Equal(t, io.EOF, receivedErr) + require.Equal(t, io.EOF, receivedErr) // Reader should be disconnected after EOF - assert.False(t, br.Connected()) + require.False(t, br.Connected()) // Read should still be blocked select { @@ -384,8 +383,8 @@ func TestBackedReader_EOFHandling(t *testing.T) { // Wait for the blocked read to complete with new data testutil.TryReceive(ctx, t, readDone) require.NoError(t, readErr) - assert.Equal(t, 4, readN) - assert.Equal(t, "more", string(buf[:readN])) + require.Equal(t, 4, readN) + require.Equal(t, "more", string(buf[:readN])) } func BenchmarkBackedReader_Read(b *testing.B) { @@ -438,11 +437,11 @@ func TestBackedReader_PartialReads(t *testing.T) { for i := 0; i < 5; i++ { n, err := br.Read(buf) require.NoError(t, err) - assert.Equal(t, 1, n) - assert.Equal(t, byte('A'), buf[0]) + require.Equal(t, 1, n) + require.Equal(t, byte('A'), buf[0]) } - assert.Equal(t, uint64(5), br.SequenceNum()) + require.Equal(t, uint64(5), br.SequenceNum()) } func TestBackedReader_CloseWhileBlockedOnUnderlyingReader(t *testing.T) { @@ -525,14 +524,14 @@ func TestBackedReader_CloseWhileBlockedOnUnderlyingReader(t *testing.T) { // The read should return EOF because Close() was called while it was blocked, // even though the underlying reader returned an error - assert.Equal(t, 0, readN) - assert.Equal(t, io.EOF, readErr) + require.Equal(t, 0, readN) + require.Equal(t, io.EOF, readErr) // Subsequent reads should return EOF since the reader is now closed buf := make([]byte, 10) n, err := br.Read(buf) - assert.Equal(t, 0, n) - assert.Equal(t, io.EOF, err) + require.Equal(t, 0, n) + require.Equal(t, io.EOF, err) } func TestBackedReader_CloseWhileBlockedWaitingForReconnect(t *testing.T) { @@ -556,7 +555,7 @@ func TestBackedReader_CloseWhileBlockedWaitingForReconnect(t *testing.T) { buf := make([]byte, 10) n, err := br.Read(buf) require.NoError(t, err) - assert.Equal(t, "initial", string(buf[:n])) + require.Equal(t, "initial", string(buf[:n])) // Set up error callback to track connection failure errorReceived := make(chan error, 1) @@ -579,8 +578,8 @@ func TestBackedReader_CloseWhileBlockedWaitingForReconnect(t *testing.T) { // Wait for the error to be reported (indicating disconnection) receivedErr := testutil.RequireReceive(ctx, t, errorReceived) - assert.Error(t, receivedErr) - assert.Contains(t, receivedErr.Error(), "connection lost") + require.Error(t, receivedErr) + require.Contains(t, receivedErr.Error(), "connection lost") // Verify read is blocked waiting for reconnection select { @@ -591,7 +590,7 @@ func TestBackedReader_CloseWhileBlockedWaitingForReconnect(t *testing.T) { } // Verify reader is disconnected - assert.False(t, br.Connected()) + require.False(t, br.Connected()) // Close the BackedReader while read is blocked waiting for reconnection err = br.Close() @@ -599,6 +598,6 @@ func TestBackedReader_CloseWhileBlockedWaitingForReconnect(t *testing.T) { // The read should unblock and return EOF testutil.TryReceive(ctx, t, readDone) - assert.Equal(t, 0, readN) - assert.Equal(t, io.EOF, readErr) + require.Equal(t, 0, readN) + require.Equal(t, io.EOF, readErr) } diff --git a/agent/immortalstreams/backedpipe/backed_writer_test.go b/agent/immortalstreams/backedpipe/backed_writer_test.go index 928cea89b361d..0d7d29da195e9 100644 --- a/agent/immortalstreams/backedpipe/backed_writer_test.go +++ b/agent/immortalstreams/backedpipe/backed_writer_test.go @@ -359,7 +359,23 @@ func TestBackedWriter_BufferEviction(t *testing.T) { 
require.NoError(t, err) require.Equal(t, 2, n) - // Buffer should contain "cdefg" (latest data) + // Verify that the buffer contains only the latest data after eviction + // Total sequence number should be 7 (5 + 2) + require.Equal(t, uint64(7), bw.SequenceNum()) + + // Try to reconnect from the beginning - this should fail because + // the early data was evicted from the buffer + writer2 := newMockWriter() + err = bw.Reconnect(0, writer2) + require.Error(t, err) + require.ErrorIs(t, err, backedpipe.ErrReplayDataUnavailable) + + // However, reconnecting from a sequence that's still in the buffer should work + // The buffer should contain the last 5 bytes: "cdefg" + writer3 := newMockWriter() + err = bw.Reconnect(2, writer3) // From sequence 2, should replay "cdefg" + require.NoError(t, err) + require.Equal(t, []byte("cdefg"), writer3.buffer.Bytes()) } func TestBackedWriter_Close(t *testing.T) { From d0ab610b2de21752f074f7d5498c6d63a8a9519d Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Fri, 15 Aug 2025 00:03:19 +0000 Subject: [PATCH 10/11] hold the mutex during reconnection and add test cases verifying close/reconnect race conditions --- .../backedpipe/backed_writer.go | 6 +- .../backedpipe/backed_writer_test.go | 198 ++++++++++++++++++ 2 files changed, 200 insertions(+), 4 deletions(-) diff --git a/agent/immortalstreams/backedpipe/backed_writer.go b/agent/immortalstreams/backedpipe/backed_writer.go index 894aaa4240118..1e30fe06194ad 100644 --- a/agent/immortalstreams/backedpipe/backed_writer.go +++ b/agent/immortalstreams/backedpipe/backed_writer.go @@ -172,12 +172,10 @@ func (bw *BackedWriter) Reconnect(replayFromSeq uint64, newWriter io.Writer) err // Clear the current writer first in case replay fails bw.writer = nil - // Replay data if needed. We keep the writer as nil during replay to ensure - // no concurrent writes can happen, then set it only after successful replay. + // Replay data if needed. We keep the mutex held during replay to ensure + // no concurrent operations can interfere with the reconnection process.
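Keeping the mutex held across the replay write is what makes the race-condition tests below deterministic: Write and Close queue behind the same lock, so they can never observe a half-reconnected writer; the price is that a slow replay stalls them for its full duration. A toy illustration of that serialization (names are illustrative):

package main

import (
	"fmt"
	"sync"
	"time"
)

// guarded is a toy showing what holding the lock across replay buys:
// anything else that needs the mutex (Write, Close) queues behind the
// replay instead of observing a half-reconnected writer.
type guarded struct {
	mu sync.Mutex
}

func (g *guarded) reconnectWithSlowReplay(d time.Duration) {
	g.mu.Lock()
	defer g.mu.Unlock()
	time.Sleep(d) // stands in for newWriter.Write(replayData)
}

func (g *guarded) write() time.Duration {
	start := time.Now()
	g.mu.Lock()
	defer g.mu.Unlock()
	return time.Since(start)
}

func main() {
	g := &guarded{}
	go g.reconnectWithSlowReplay(100 * time.Millisecond)
	time.Sleep(10 * time.Millisecond) // let the reconnect take the lock first

	fmt.Printf("write waited ~%v for the replay\n", g.write().Round(10*time.Millisecond))
}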
if len(replayData) > 0 { - bw.mu.Unlock() n, err := newWriter.Write(replayData) - bw.mu.Lock() if err != nil { // Reconnect failed, writer remains nil diff --git a/agent/immortalstreams/backedpipe/backed_writer_test.go b/agent/immortalstreams/backedpipe/backed_writer_test.go index 0d7d29da195e9..42848422a1ede 100644 --- a/agent/immortalstreams/backedpipe/backed_writer_test.go +++ b/agent/immortalstreams/backedpipe/backed_writer_test.go @@ -651,6 +651,204 @@ func TestBackedWriter_WriteBlocksAfterDisconnection(t *testing.T) { require.Equal(t, []byte("world"), writer2.buffer.Bytes()) // Only "world" since we replayed from sequence 5 } +func TestBackedWriter_ConcurrentWriteAndClose(t *testing.T) { + t.Parallel() + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + writer := newMockWriter() + bw.Reconnect(0, writer) + + // Start a write operation that will be interrupted by close + writeComplete := make(chan struct{}) + var writeErr error + var n int + + go func() { + defer close(writeComplete) + // Write some data that should succeed + n, writeErr = bw.Write([]byte("hello")) + }() + + // Give write a chance to start + time.Sleep(10 * time.Millisecond) + + // Close the writer + closeErr := bw.Close() + require.NoError(t, closeErr) + + // Wait for write to complete + <-writeComplete + + // Write should have either succeeded (if it completed before close) + // or failed with EOF (if close interrupted it) + if writeErr == nil { + require.Equal(t, 5, n) + } else { + require.ErrorIs(t, writeErr, io.EOF) + } + + // Subsequent writes should fail + n, err := bw.Write([]byte("world")) + require.Equal(t, 0, n) + require.ErrorIs(t, err, io.EOF) +} + +func TestBackedWriter_ConcurrentWriteAndReconnect(t *testing.T) { + t.Parallel() + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + + // Initial connection + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + + // Write some initial data + _, err = bw.Write([]byte("initial")) + require.NoError(t, err) + + // Start a write operation that will be blocked by reconnect + writeComplete := make(chan struct{}) + var writeErr error + var n int + + go func() { + defer close(writeComplete) + // This write should be blocked during reconnect + n, writeErr = bw.Write([]byte("blocked")) + }() + + // Give write a chance to start + time.Sleep(10 * time.Millisecond) + + // Start reconnection which will cause the write to wait + writer2 := &mockWriter{ + writeFunc: func(p []byte) (int, error) { + // Simulate slow replay + time.Sleep(50 * time.Millisecond) + return len(p), nil + }, + } + + reconnectErr := bw.Reconnect(0, writer2) + require.NoError(t, reconnectErr) + + // Wait for write to complete + <-writeComplete + + // Write should succeed after reconnection completes + require.NoError(t, writeErr) + require.Equal(t, 7, n) // "blocked" is 7 bytes + + // Verify the writer is connected + require.True(t, bw.Connected()) +} + +func TestBackedWriter_ConcurrentReconnectAndClose(t *testing.T) { + t.Parallel() + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + + // Initial connection and write some data + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + _, err = bw.Write([]byte("test data")) + require.NoError(t, err) + + // Start reconnection with slow replay + reconnectComplete := make(chan struct{}) + var reconnectErr 
error + + go func() { + defer close(reconnectComplete) + writer2 := &mockWriter{ + writeFunc: func(p []byte) (int, error) { + // Simulate slow replay - Close() must wait for this to finish because the mutex is held + time.Sleep(100 * time.Millisecond) + return len(p), nil + }, + } + reconnectErr = bw.Reconnect(0, writer2) + }() + + // Give reconnect a chance to start + time.Sleep(10 * time.Millisecond) + + // Close while reconnection is in progress + closeErr := bw.Close() + require.NoError(t, closeErr) + + // Wait for reconnect to complete + <-reconnectComplete + + // With mutex held during replay, Close() waits for Reconnect() to finish. + // So Reconnect() should succeed, then Close() runs and closes the writer. + require.NoError(t, reconnectErr) + + // Verify writer is closed (Close() ran after Reconnect() completed) + require.False(t, bw.Connected()) +} + +func TestBackedWriter_MultipleWritesDuringReconnect(t *testing.T) { + t.Parallel() + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + + // Initial connection + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + + // Write some initial data + _, err = bw.Write([]byte("initial")) + require.NoError(t, err) + + // Start multiple write operations + numWriters := 5 + var wg sync.WaitGroup + writeResults := make([]error, numWriters) + + for i := 0; i < numWriters; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + data := []byte{byte('A' + id)} + _, writeResults[id] = bw.Write(data) + }(i) + } + + // Give writes a chance to start + time.Sleep(10 * time.Millisecond) + + // Start reconnection with slow replay + writer2 := &mockWriter{ + writeFunc: func(p []byte) (int, error) { + // Simulate slow replay + time.Sleep(50 * time.Millisecond) + return len(p), nil + }, + } + + reconnectErr := bw.Reconnect(0, writer2) + require.NoError(t, reconnectErr) + + // Wait for all writes to complete + wg.Wait() + + // All writes should succeed + for i, err := range writeResults { + require.NoError(t, err, "Write %d should succeed", i) + } + + // Verify the writer is connected + require.True(t, bw.Connected()) +} + func BenchmarkBackedWriter_Write(b *testing.B) { errorChan := make(chan error, 1) bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) // 64MB default buffer From b2188f974aefd1e308ee5521b12bff40fc6646d7 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Fri, 15 Aug 2025 00:22:13 +0000 Subject: [PATCH 11/11] writer's tests cleanup --- .../backedpipe/backed_reader_test.go | 1 + .../backedpipe/backed_writer_test.go | 41 ++----------------- 2 files changed, 5 insertions(+), 37 deletions(-) diff --git a/agent/immortalstreams/backedpipe/backed_reader_test.go b/agent/immortalstreams/backedpipe/backed_reader_test.go index 019511bbcfc64..509a517a36583 100644 --- a/agent/immortalstreams/backedpipe/backed_reader_test.go +++ b/agent/immortalstreams/backedpipe/backed_reader_test.go @@ -217,6 +217,7 @@ func TestBackedReader_ReconnectionAfterFailure(t *testing.T) { // Wait for read to unblock and succeed with new data readErr := testutil.RequireReceive(ctx, t, readDone) require.NoError(t, readErr) // Should succeed with new reader + require.True(t, br.Connected()) } func TestBackedReader_Close(t *testing.T) { diff --git a/agent/immortalstreams/backedpipe/backed_writer_test.go b/agent/immortalstreams/backedpipe/backed_writer_test.go index 42848422a1ede..463d9f32d544f 100644 ---
a/agent/immortalstreams/backedpipe/backed_writer_test.go +++ b/agent/immortalstreams/backedpipe/backed_writer_test.go @@ -138,6 +138,7 @@ func TestBackedWriter_WriteToUnderlyingWhenConnected(t *testing.T) { require.Equal(t, 5, n) // Data should be buffered + require.Equal(t, uint64(5), bw.SequenceNum()) // Check underlying writer require.Equal(t, []byte("hello"), writer.buffer.Bytes()) @@ -158,7 +159,7 @@ func TestBackedWriter_BlockOnWriteFailure(t *testing.T) { // Cause write to fail writer.setError(xerrors.New("write failed")) - // Write should block when underlying writer fails + // Write should block when underlying writer fails, not succeed immediately writeComplete := make(chan struct{}) var writeErr error var n int @@ -376,6 +377,7 @@ func TestBackedWriter_BufferEviction(t *testing.T) { err = bw.Reconnect(2, writer3) // From sequence 2, should replay "cdefg" require.NoError(t, err) require.Equal(t, []byte("cdefg"), writer3.buffer.Bytes()) + require.True(t, bw.Connected()) } func TestBackedWriter_Close(t *testing.T) { @@ -414,39 +416,6 @@ func TestBackedWriter_CloseIdempotent(t *testing.T) { require.NoError(t, err) } -func TestBackedWriter_ConcurrentWrites(t *testing.T) { - t.Parallel() - - errorChan := make(chan error, 1) - bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) - writer := newMockWriter() - bw.Reconnect(0, writer) - - var wg sync.WaitGroup - numWriters := 10 - writesPerWriter := 50 - - for i := 0; i < numWriters; i++ { - wg.Add(1) - go func(id int) { - defer wg.Done() - for j := 0; j < writesPerWriter; j++ { - data := []byte{byte(id + '0')} - bw.Write(data) - } - }(i) - } - - wg.Wait() - - // Should have written expected amount to buffer - expectedBytes := uint64(numWriters * writesPerWriter) //nolint:gosec // Safe conversion: test constants with small values - require.Equal(t, expectedBytes, bw.SequenceNum()) - // Note: underlying writer may not receive all bytes due to potential disconnections - // during concurrent operations, but the buffer should track all writes - require.True(t, writer.Len() <= int(expectedBytes)) //nolint:gosec // Safe conversion: expectedBytes is calculated from small test values -} - func TestBackedWriter_ReconnectDuringReplay(t *testing.T) { t.Parallel() @@ -463,9 +432,7 @@ func TestBackedWriter_ReconnectDuringReplay(t *testing.T) { // Create a writer that fails during replay writer2 := &mockWriter{ - writeFunc: func(p []byte) (int, error) { - return 0, xerrors.New("replay failed") - }, + err: backedpipe.ErrReplayFailed, } err = bw.Reconnect(0, writer2)
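The tests above inject failures through the mockWriter double, configuring it either with a canned err (as in the cleaned-up ReconnectDuringReplay test) or a writeFunc hook. A sketch of the shape those usages imply; the real helper lives in backed_writer_test.go, and the mutex here is an added assumption for goroutine safety:

package backedpipe_test

import (
	"bytes"
	"sync"
)

// mockWriter sketches the failure-injecting test double the tests configure
// with either a canned err or a writeFunc hook; buffer, setError, and Len
// all appear in the assertions above. The mutex is an assumption added for
// goroutine safety, since the tests call Write from multiple goroutines.
type mockWriter struct {
	mu        sync.Mutex
	buffer    bytes.Buffer
	err       error                     // if set, every Write fails with it
	writeFunc func([]byte) (int, error) // if set, overrides Write entirely
}

func newMockWriter() *mockWriter { return &mockWriter{} }

func (mw *mockWriter) Write(p []byte) (int, error) {
	mw.mu.Lock()
	defer mw.mu.Unlock()
	if mw.writeFunc != nil {
		return mw.writeFunc(p)
	}
	if mw.err != nil {
		return 0, mw.err
	}
	return mw.buffer.Write(p)
}

func (mw *mockWriter) setError(err error) {
	mw.mu.Lock()
	defer mw.mu.Unlock()
	mw.err = err
}

func (mw *mockWriter) Len() int {
	mw.mu.Lock()
	defer mw.mu.Unlock()
	return mw.buffer.Len()
}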