diff --git a/agent/immortalstreams/backedpipe/backed_pipe.go b/agent/immortalstreams/backedpipe/backed_pipe.go
new file mode 100644
index 0000000000000..8161b91533233
--- /dev/null
+++ b/agent/immortalstreams/backedpipe/backed_pipe.go
@@ -0,0 +1,303 @@
+package backedpipe
+
+import (
+	"context"
+	"io"
+	"sync"
+
+	"golang.org/x/sync/singleflight"
+	"golang.org/x/xerrors"
+)
+
+const (
+	// Default buffer capacity used by the writer - 64MB
+	DefaultBufferSize = 64 * 1024 * 1024
+)
+
+// ReconnectFunc is called when the BackedPipe needs to establish a new connection.
+// It should:
+// 1. Establish a new connection to the remote side
+// 2. Exchange sequence numbers with the remote side
+// 3. Return the new connection and the remote's current sequence number
+//
+// The writerSeqNum parameter is the local writer's current sequence number,
+// which should be sent to the remote side so it knows where to resume reading from.
+//
+// The returned readerSeqNum should be the remote side's current sequence number,
+// which indicates where the local reader should resume from.
+type ReconnectFunc func(ctx context.Context, writerSeqNum uint64) (conn io.ReadWriteCloser, readerSeqNum uint64, err error)
+
+// BackedPipe provides a reliable bidirectional byte stream over unreliable network connections.
+// It orchestrates a BackedReader and BackedWriter to provide transparent reconnection
+// and data replay capabilities.
+type BackedPipe struct {
+	ctx         context.Context
+	cancel      context.CancelFunc
+	mu          sync.RWMutex
+	reader      *BackedReader
+	writer      *BackedWriter
+	reconnectFn ReconnectFunc
+	conn        io.ReadWriteCloser
+	connected   bool
+	closed      bool
+
+	// Reconnection state
+	reconnecting bool
+
+	// Error channel for receiving connection errors from reader/writer
+	errorChan chan error
+
+	// Connection state notification
+	connectionChanged chan struct{}
+
+	// singleflight group to dedupe concurrent ForceReconnect calls
+	sf singleflight.Group
+}
+
+// NewBackedPipe creates a new BackedPipe with default options and the specified reconnect function.
+// The pipe starts disconnected and must be connected using Connect().
+func NewBackedPipe(ctx context.Context, reconnectFn ReconnectFunc) *BackedPipe {
+	pipeCtx, cancel := context.WithCancel(ctx)
+
+	errorChan := make(chan error, 2) // Buffer for reader and writer errors
+	bp := &BackedPipe{
+		ctx:               pipeCtx,
+		cancel:            cancel,
+		reader:            NewBackedReader(),
+		writer:            NewBackedWriter(DefaultBufferSize, errorChan),
+		reconnectFn:       reconnectFn,
+		errorChan:         errorChan,
+		connectionChanged: make(chan struct{}, 1),
+	}
+
+	// Set up error callback for reader only (writer uses error channel directly)
+	bp.reader.SetErrorCallback(func(err error) {
+		select {
+		case bp.errorChan <- err:
+		case <-bp.ctx.Done():
+		}
+	})
+
+	// Start error handler goroutine
+	go bp.handleErrors()
+
+	return bp
+}
+
+// Connect establishes the initial connection using the reconnect function.
+func (bp *BackedPipe) Connect(_ context.Context) error { // external ctx ignored; internal ctx used
+	bp.mu.Lock()
+	defer bp.mu.Unlock()
+
+	if bp.closed {
+		return xerrors.New("pipe is closed")
+	}
+
+	if bp.connected {
+		return xerrors.New("pipe is already connected")
+	}
+
+	// Use internal context for the actual reconnect operation to ensure
+	// Close() reliably cancels any in-flight attempt.
+	return bp.reconnectLocked()
+}
+
+// Read implements io.Reader by delegating to the BackedReader.
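+// Reads block while the pipe is disconnected rather than returning an error,
+// so callers typically issue them from a dedicated goroutine. A minimal
+// sketch, assuming a hypothetical consume helper:
+//
+//	go func() {
+//		buf := make([]byte, 4096)
+//		for {
+//			n, err := bp.Read(buf)
+//			if err != nil {
+//				return // io.ErrClosedPipe (or io.EOF) once the pipe is closed
+//			}
+//			consume(buf[:n]) // hypothetical consumer
+//		}
+//	}()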
+func (bp *BackedPipe) Read(p []byte) (int, error) {
+	bp.mu.RLock()
+	reader := bp.reader
+	closed := bp.closed
+	bp.mu.RUnlock()
+
+	if closed {
+		return 0, io.ErrClosedPipe
+	}
+
+	return reader.Read(p)
+}
+
+// Write implements io.Writer by delegating to the BackedWriter.
+func (bp *BackedPipe) Write(p []byte) (int, error) {
+	bp.mu.RLock()
+	writer := bp.writer
+	closed := bp.closed
+	bp.mu.RUnlock()
+
+	if closed {
+		return 0, io.ErrClosedPipe
+	}
+
+	return writer.Write(p)
+}
+
+// Close closes the pipe and all underlying connections.
+func (bp *BackedPipe) Close() error {
+	bp.mu.Lock()
+	defer bp.mu.Unlock()
+
+	if bp.closed {
+		return nil
+	}
+
+	bp.closed = true
+	bp.cancel() // Cancel main context
+
+	// Close underlying components
+	var readerErr, writerErr, connErr error
+
+	if bp.reader != nil {
+		readerErr = bp.reader.Close()
+	}
+
+	if bp.writer != nil {
+		writerErr = bp.writer.Close()
+	}
+
+	if bp.conn != nil {
+		connErr = bp.conn.Close()
+		bp.conn = nil
+	}
+
+	bp.connected = false
+	bp.signalConnectionChange()
+
+	// Return first error encountered
+	if readerErr != nil {
+		return readerErr
+	}
+	if writerErr != nil {
+		return writerErr
+	}
+	return connErr
+}
+
+// Connected returns whether the pipe is currently connected.
+func (bp *BackedPipe) Connected() bool {
+	bp.mu.RLock()
+	defer bp.mu.RUnlock()
+	return bp.connected
+}
+
+// signalConnectionChange signals that the connection state has changed.
+func (bp *BackedPipe) signalConnectionChange() {
+	select {
+	case bp.connectionChanged <- struct{}{}:
+	default:
+		// Channel is full, which is fine - we just want to signal that something changed
+	}
+}
+
+// reconnectLocked handles the reconnection logic. Must be called with write lock held.
+func (bp *BackedPipe) reconnectLocked() error {
+	if bp.reconnecting {
+		return xerrors.New("reconnection already in progress")
+	}
+
+	bp.reconnecting = true
+	defer func() {
+		bp.reconnecting = false
+	}()
+
+	// Close existing connection if any
+	if bp.conn != nil {
+		_ = bp.conn.Close()
+		bp.conn = nil
+	}
+
+	bp.connected = false
+	bp.signalConnectionChange()
+
+	// Get current writer sequence number to send to remote
+	writerSeqNum := bp.writer.SequenceNum()
+
+	// Unlock during reconnect attempt to avoid blocking reads/writes
+	bp.mu.Unlock()
+	conn, readerSeqNum, err := bp.reconnectFn(bp.ctx, writerSeqNum)
+	bp.mu.Lock()
+
+	if err != nil {
+		return xerrors.Errorf("reconnect failed: %w", err)
+	}
+
+	// Validate sequence numbers
+	if readerSeqNum > writerSeqNum {
+		_ = conn.Close()
+		return xerrors.Errorf("remote sequence number %d exceeds local sequence %d, cannot replay",
+			readerSeqNum, writerSeqNum)
+	}
+
+	// Reconnect reader and writer
+	seqNum := make(chan uint64, 1)
+	newR := make(chan io.Reader, 1)
+
+	go bp.reader.Reconnect(seqNum, newR)
+
+	// Get sequence number and send new reader
+	<-seqNum
+	newR <- conn
+
+	err = bp.writer.Reconnect(readerSeqNum, conn)
+	if err != nil {
+		_ = conn.Close()
+		return xerrors.Errorf("reconnect writer: %w", err)
+	}
+
+	// Success - update state
+	bp.conn = conn
+	bp.connected = true
+	bp.signalConnectionChange()
+
+	return nil
+}
+
+// handleErrors listens for connection errors from reader/writer and triggers reconnection.
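+// The flow is: the reader (via its error callback) and the writer (directly)
+// deliver failures on errorChan; handleErrors then marks the pipe as
+// disconnected and retries via reconnectLocked using the internal context.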
+func (bp *BackedPipe) handleErrors() {
+	for {
+		select {
+		case <-bp.ctx.Done():
+			return
+		case err := <-bp.errorChan:
+			// Connection error occurred
+			bp.mu.Lock()
+
+			// Skip if already closed or not connected
+			if bp.closed || !bp.connected {
+				bp.mu.Unlock()
+				continue
+			}
+
+			// Mark as disconnected
+			bp.connected = false
+			bp.signalConnectionChange()
+
+			// Try to reconnect using internal context
+			reconnectErr := bp.reconnectLocked()
+			bp.mu.Unlock()
+
+			if reconnectErr != nil {
+				// Reconnection failed; the pipe stays disconnected until
+				// another error arrives or ForceReconnect is called. The
+				// original error is intentionally discarded.
+				_ = err
+			}
+		}
+	}
+}
+
+// ForceReconnect immediately attempts to re-establish the connection.
+// It is useful when the caller knows a new connection can be established
+// and does not want to wait for a read or write to fail first.
+func (bp *BackedPipe) ForceReconnect() error {
+	// Deduplicate concurrent ForceReconnect calls so only one reconnection
+	// attempt runs at a time from this API. Use the pipe's internal context
+	// to ensure Close() cancels any in-flight attempt.
+	_, err, _ := bp.sf.Do("backedpipe-reconnect", func() (interface{}, error) {
+		bp.mu.Lock()
+		defer bp.mu.Unlock()
+
+		if bp.closed {
+			return nil, io.ErrClosedPipe
+		}
+
+		return nil, bp.reconnectLocked()
+	})
+	return err
+}
diff --git a/agent/immortalstreams/backedpipe/backed_pipe_test.go b/agent/immortalstreams/backedpipe/backed_pipe_test.go
new file mode 100644
index 0000000000000..1cd5bd227ebcd
--- /dev/null
+++ b/agent/immortalstreams/backedpipe/backed_pipe_test.go
@@ -0,0 +1,720 @@
+package backedpipe_test
+
+import (
+	"bytes"
+	"context"
+	"io"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+	"golang.org/x/xerrors"
+
+	"github.com/coder/coder/v2/agent/immortalstreams/backedpipe"
+	"github.com/coder/coder/v2/testutil"
+)
+
+// mockConnection implements io.ReadWriteCloser for testing
+type mockConnection struct {
+	mu          sync.Mutex
+	readBuffer  bytes.Buffer
+	writeBuffer bytes.Buffer
+	closed      bool
+	readError   error
+	writeError  error
+	closeError  error
+	readFunc    func([]byte) (int, error)
+	writeFunc   func([]byte) (int, error)
+	seqNum      uint64
+}
+
+func newMockConnection() *mockConnection {
+	return &mockConnection{}
+}
+
+func (mc *mockConnection) Read(p []byte) (int, error) {
+	mc.mu.Lock()
+	defer mc.mu.Unlock()
+
+	if mc.readFunc != nil {
+		return mc.readFunc(p)
+	}
+
+	if mc.readError != nil {
+		return 0, mc.readError
+	}
+
+	return mc.readBuffer.Read(p)
+}
+
+func (mc *mockConnection) Write(p []byte) (int, error) {
+	mc.mu.Lock()
+	defer mc.mu.Unlock()
+
+	if mc.writeFunc != nil {
+		return mc.writeFunc(p)
+	}
+
+	if mc.writeError != nil {
+		return 0, mc.writeError
+	}
+
+	return mc.writeBuffer.Write(p)
+}
+
+func (mc *mockConnection) Close() error {
+	mc.mu.Lock()
+	defer mc.mu.Unlock()
+	mc.closed = true
+	return mc.closeError
+}
+
+func (mc *mockConnection) WriteString(s string) {
+	mc.mu.Lock()
+	defer mc.mu.Unlock()
+	_, _ = mc.readBuffer.WriteString(s)
+}
+
+func (mc *mockConnection) ReadString() string {
+	mc.mu.Lock()
+	defer mc.mu.Unlock()
+	return mc.writeBuffer.String()
+}
+
+func (mc *mockConnection) SetReadError(err error) {
+	mc.mu.Lock()
+	defer mc.mu.Unlock()
+	mc.readError = err
+}
+
+func (mc *mockConnection) SetWriteError(err error) {
+	mc.mu.Lock()
+	defer mc.mu.Unlock()
+	mc.writeError = err
+}
+
+func (mc *mockConnection) Reset() {
+	mc.mu.Lock()
+	defer mc.mu.Unlock()
+	mc.readBuffer.Reset()
+	mc.writeBuffer.Reset()
+	mc.readError = nil
+	mc.writeError = nil
+	mc.closed = false
+}
+
+//
mockReconnectFunc creates a unified reconnect function with all behaviors enabled +func mockReconnectFunc(connections ...*mockConnection) (backedpipe.ReconnectFunc, *int, chan struct{}) { + connectionIndex := 0 + callCount := 0 + signalChan := make(chan struct{}, 1) + + reconnectFn := func(ctx context.Context, writerSeqNum uint64) (io.ReadWriteCloser, uint64, error) { + callCount++ + + if connectionIndex >= len(connections) { + return nil, 0, xerrors.New("no more connections available") + } + + conn := connections[connectionIndex] + connectionIndex++ + + // Signal when reconnection happens + if connectionIndex > 1 { + select { + case signalChan <- struct{}{}: + default: + } + } + + // Determine readerSeqNum based on call count + var readerSeqNum uint64 + switch { + case callCount == 1: + readerSeqNum = 0 + case conn.seqNum != 0: + readerSeqNum = conn.seqNum + default: + readerSeqNum = writerSeqNum + } + + return conn, readerSeqNum, nil + } + + return reconnectFn, &callCount, signalChan +} + +func TestBackedPipe_NewBackedPipe(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reconnectFn, _, _ := mockReconnectFunc(newMockConnection()) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() + require.NotNil(t, bp) + require.False(t, bp.Connected()) +} + +func TestBackedPipe_Connect(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, callCount, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() + + err := bp.Connect(ctx) + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 1, *callCount) +} + +func TestBackedPipe_ConnectAlreadyConnected(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() + + err := bp.Connect(ctx) + require.NoError(t, err) + + // Second connect should fail + err = bp.Connect(ctx) + require.Error(t, err) + require.Contains(t, err.Error(), "already connected") +} + +func TestBackedPipe_ConnectAfterClose(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Close() + require.NoError(t, err) + + err = bp.Connect(ctx) + require.Error(t, err) + require.Contains(t, err.Error(), "closed") +} + +func TestBackedPipe_BasicReadWrite(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() + + err := bp.Connect(ctx) + require.NoError(t, err) + + // Write data + n, err := bp.Write([]byte("hello")) + require.NoError(t, err) + require.Equal(t, 5, n) + + // Simulate data coming back + conn.WriteString("world") + + // Read data + buf := make([]byte, 10) + n, err = bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 5, n) + require.Equal(t, "world", string(buf[:n])) +} + +func TestBackedPipe_WriteBeforeConnect(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() + + // Write before connecting should block + writeComplete := make(chan error, 1) + go func() { + _, err := bp.Write([]byte("hello")) + 
writeComplete <- err + }() + + // Verify write is blocked + select { + case <-writeComplete: + t.Fatal("Write should have blocked when disconnected") + case <-time.After(100 * time.Millisecond): + // Expected - write is blocked + } + + // Connect should unblock the write + err := bp.Connect(ctx) + require.NoError(t, err) + + // Write should now complete + err = testutil.RequireReceive(ctx, t, writeComplete) + require.NoError(t, err) + + // Check that data was replayed to connection + require.Equal(t, "hello", conn.ReadString()) +} + +func TestBackedPipe_ReadBlocksWhenDisconnected(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testCtx := testutil.Context(t, testutil.WaitShort) + reconnectFn, _, _ := mockReconnectFunc(newMockConnection()) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() + + // Start a read that should block + readDone := make(chan struct{}) + readStarted := make(chan struct{}) + var readErr error + + go func() { + defer close(readDone) + close(readStarted) // Signal that we're about to start the read + buf := make([]byte, 10) + _, readErr = bp.Read(buf) + }() + + // Wait for the goroutine to start + testutil.TryReceive(testCtx, t, readStarted) + + // Give a brief moment for the read to actually block + time.Sleep(time.Millisecond) + + // Read should still be blocked + select { + case <-readDone: + t.Fatal("Read should be blocked when disconnected") + default: + // Good, still blocked + } + + // Close should unblock the read + bp.Close() + + testutil.TryReceive(testCtx, t, readDone) + require.Equal(t, io.EOF, readErr) +} + +func TestBackedPipe_Reconnection(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testCtx := testutil.Context(t, testutil.WaitShort) + conn1 := newMockConnection() + conn2 := newMockConnection() + conn2.seqNum = 17 // Remote has received 17 bytes, so replay from sequence 17 + reconnectFn, _, signalChan := mockReconnectFunc(conn1, conn2) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() + + // Initial connect + err := bp.Connect(ctx) + require.NoError(t, err) + + // Write some data before failure + bp.Write([]byte("before disconnect***")) + + // Simulate connection failure + conn1.SetReadError(xerrors.New("connection lost")) + conn1.SetWriteError(xerrors.New("connection lost")) + + // Trigger a write to cause the pipe to notice the failure + _, _ = bp.Write([]byte("trigger failure ")) + + testutil.RequireReceive(testCtx, t, signalChan) + + // Wait for reconnection to complete + require.Eventually(t, func() bool { + return bp.Connected() + }, testutil.WaitShort, testutil.IntervalFast, "pipe should reconnect") + + replayedData := conn2.ReadString() + require.Equal(t, "***trigger failure ", replayedData, "Should replay exactly the data written after sequence 17") + + // Verify that new writes work with the reconnected pipe + _, err = bp.Write([]byte("new data after reconnect")) + require.NoError(t, err) + + // Read all data from the connection (replayed + new data) + allData := conn2.ReadString() + require.Equal(t, "***trigger failure new data after reconnect", allData, "Should have replayed data plus new data") +} + +func TestBackedPipe_Close(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Connect(ctx) + require.NoError(t, err) + + err = bp.Close() + require.NoError(t, err) + require.True(t, conn.closed) + + // Operations 
after close should fail + _, err = bp.Read(make([]byte, 10)) + require.Equal(t, io.ErrClosedPipe, err) + + _, err = bp.Write([]byte("test")) + require.Equal(t, io.ErrClosedPipe, err) +} + +func TestBackedPipe_CloseIdempotent(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + err := bp.Close() + require.NoError(t, err) + + // Second close should be no-op + err = bp.Close() + require.NoError(t, err) +} + +func TestBackedPipe_ConcurrentReadWrite(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() + + err := bp.Connect(ctx) + require.NoError(t, err) + + var wg sync.WaitGroup + numWriters := 3 + writesPerWriter := 10 + + // Fill read buffer with test data first + testData := make([]byte, 1000) + for i := range testData { + testData[i] = 'A' + } + conn.WriteString(string(testData)) + + // Channel to collect all written data + writtenData := make(chan byte, numWriters*writesPerWriter) + + // Start a few readers + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 10) + for j := 0; j < 10; j++ { + bp.Read(buf) + time.Sleep(time.Millisecond) // Small delay to avoid busy waiting + } + }() + } + + // Start writers + for i := 0; i < numWriters; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < writesPerWriter; j++ { + data := []byte{byte(id + '0')} + bp.Write(data) + writtenData <- byte(id + '0') + time.Sleep(time.Millisecond) // Small delay + } + }(i) + } + + // Wait with timeout + done := make(chan struct{}) + go func() { + defer close(done) + wg.Wait() + }() + + testutil.TryReceive(ctx, t, done) + + // Close the channel and collect all written data + close(writtenData) + var allWritten []byte + for b := range writtenData { + allWritten = append(allWritten, b) + } + + // Verify that all written data was received by the connection + // Note: Since this test uses the old mock that returns readerSeqNum = 0, + // all data will be replayed, so we expect to receive all written data + receivedData := conn.ReadString() + require.GreaterOrEqual(t, len(receivedData), len(allWritten), "Connection should have received at least all written data") + + // Check that all written bytes appear in the received data + for _, writtenByte := range allWritten { + require.Contains(t, receivedData, string(writtenByte), "Written byte %c should be present in received data", writtenByte) + } +} + +func TestBackedPipe_ReconnectFunctionFailure(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + failingReconnectFn := func(ctx context.Context, writerSeqNum uint64) (io.ReadWriteCloser, uint64, error) { + return nil, 0, xerrors.New("reconnect failed") + } + + bp := backedpipe.NewBackedPipe(ctx, failingReconnectFn) + defer bp.Close() + + err := bp.Connect(ctx) + require.Error(t, err) + require.Contains(t, err.Error(), "reconnect failed") + require.False(t, bp.Connected()) +} + +func TestBackedPipe_ForceReconnect(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn1 := newMockConnection() + conn2 := newMockConnection() + reconnectFn, callCount, _ := mockReconnectFunc(conn1, conn2) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() + + // Initial connect + err := bp.Connect(ctx) + require.NoError(t, 
err) + require.True(t, bp.Connected()) + require.Equal(t, 1, *callCount) + + // Write some data to the first connection + _, err = bp.Write([]byte("test data")) + require.NoError(t, err) + require.Equal(t, "test data", conn1.ReadString()) + + // Force a reconnection + err = bp.ForceReconnect() + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 2, *callCount) + + // Since the mock now returns the proper sequence number, no data should be replayed + // The new connection should be empty + require.Equal(t, "", conn2.ReadString()) + + // Verify that data can still be written and read after forced reconnection + _, err = bp.Write([]byte("new data")) + require.NoError(t, err) + require.Equal(t, "new data", conn2.ReadString()) + + // Verify that reads work with the new connection + conn2.WriteString("response data") + buf := make([]byte, 20) + n, err := bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 13, n) + require.Equal(t, "response data", string(buf[:n])) +} + +func TestBackedPipe_ForceReconnectWhenClosed(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, _, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + + // Close the pipe first + err := bp.Close() + require.NoError(t, err) + + // Try to force reconnect when closed + err = bp.ForceReconnect() + require.Error(t, err) + require.Equal(t, io.ErrClosedPipe, err) +} + +func TestBackedPipe_ForceReconnectWhenDisconnected(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newMockConnection() + reconnectFn, callCount, _ := mockReconnectFunc(conn) + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() + + // Don't connect initially, just force reconnect + err := bp.ForceReconnect() + require.NoError(t, err) + require.True(t, bp.Connected()) + require.Equal(t, 1, *callCount) + + // Verify we can write and read + _, err = bp.Write([]byte("test")) + require.NoError(t, err) + require.Equal(t, "test", conn.ReadString()) + + conn.WriteString("response") + buf := make([]byte, 10) + n, err := bp.Read(buf) + require.NoError(t, err) + require.Equal(t, 8, n) + require.Equal(t, "response", string(buf[:n])) +} + +func TestBackedPipe_EOFTriggersReconnection(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + // Create connections where we can control when EOF occurs + conn1 := newMockConnection() + conn2 := newMockConnection() + conn2.WriteString("newdata") // Pre-populate conn2 with data + + // Make conn1 return EOF after reading "world" + hasReadData := false + conn1.readFunc = func(p []byte) (int, error) { + // Don't lock here - the Read method already holds the lock + + // First time: return "world" + if !hasReadData && conn1.readBuffer.Len() > 0 { + n, _ := conn1.readBuffer.Read(p) + hasReadData = true + return n, nil + } + // After that: return EOF + return 0, io.EOF + } + conn1.WriteString("world") + + callCount := 0 + reconnectFn := func(ctx context.Context, writerSeqNum uint64) (io.ReadWriteCloser, uint64, error) { + callCount++ + + if callCount == 1 { + return conn1, 0, nil + } + if callCount == 2 { + // Second call is the reconnection after EOF + return conn2, writerSeqNum, nil // conn2 already has the reader sequence at writerSeqNum + } + + return nil, 0, xerrors.New("no more connections") + } + + bp := backedpipe.NewBackedPipe(ctx, reconnectFn) + defer bp.Close() + + // Initial connect + err := bp.Connect(ctx) + require.NoError(t, err) + require.Equal(t, 1, callCount) + 
+	// Write some data
+	_, err = bp.Write([]byte("hello"))
+	require.NoError(t, err)
+
+	buf := make([]byte, 10)
+
+	// First read should succeed
+	n, err := bp.Read(buf)
+	require.NoError(t, err)
+	require.Equal(t, 5, n)
+	require.Equal(t, "world", string(buf[:n]))
+
+	// Next read will encounter EOF and should trigger reconnection
+	// After reconnection, it should read from conn2
+	n, err = bp.Read(buf)
+	require.NoError(t, err)
+	require.Equal(t, 7, n)
+	require.Equal(t, "newdata", string(buf[:n]))
+
+	// Verify reconnection happened
+	require.Equal(t, 2, callCount)
+
+	// Verify the pipe is still connected and functional
+	require.True(t, bp.Connected())
+
+	// Further writes should go to the new connection
+	_, err = bp.Write([]byte("aftereof"))
+	require.NoError(t, err)
+	require.Equal(t, "aftereof", conn2.ReadString())
+}
+
+func BenchmarkBackedPipe_Write(b *testing.B) {
+	ctx := context.Background()
+	conn := newMockConnection()
+	reconnectFn, _, _ := mockReconnectFunc(conn)
+
+	bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
+	bp.Connect(ctx)
+	b.Cleanup(func() {
+		_ = bp.Close()
+	})
+
+	data := make([]byte, 1024) // 1KB writes
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		bp.Write(data)
+	}
+}
+
+func BenchmarkBackedPipe_Read(b *testing.B) {
+	ctx := context.Background()
+	conn := newMockConnection()
+	reconnectFn, _, _ := mockReconnectFunc(conn)
+
+	bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
+	bp.Connect(ctx)
+	b.Cleanup(func() {
+		_ = bp.Close()
+	})
+
+	buf := make([]byte, 1024)
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		// Fill connection with fresh data for each iteration
+		conn.WriteString(string(buf))
+		bp.Read(buf)
+	}
+}
diff --git a/agent/immortalstreams/backedpipe/backed_reader.go b/agent/immortalstreams/backedpipe/backed_reader.go
new file mode 100644
index 0000000000000..4632cafc92e8f
--- /dev/null
+++ b/agent/immortalstreams/backedpipe/backed_reader.go
@@ -0,0 +1,147 @@
+package backedpipe
+
+import (
+	"io"
+	"sync"
+)
+
+// BackedReader wraps an unreliable io.Reader and makes it resilient to disconnections.
+// It tracks sequence numbers for all bytes read and can handle reconnection,
+// blocking reads when disconnected instead of erroring.
+type BackedReader struct {
+	mu          sync.Mutex
+	cond        *sync.Cond
+	reader      io.Reader
+	sequenceNum uint64
+	closed      bool
+
+	// Error callback to notify parent when connection fails
+	onError func(error)
+}
+
+// NewBackedReader creates a new BackedReader. The reader is initially disconnected
+// and must be connected using Reconnect before reads will succeed.
+func NewBackedReader() *BackedReader {
+	br := &BackedReader{}
+	br.cond = sync.NewCond(&br.mu)
+	return br
+}
+
+// Read implements io.Reader. It blocks when disconnected until either:
+// 1. A reconnection is established
+// 2. The reader is closed
+//
+// When connected, it reads from the underlying reader and updates sequence numbers.
+// Connection failures are automatically detected and reported to the higher layer via callback.
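+//
+// Reconnection is coordinated through the channel handshake in Reconnect. A
+// minimal sketch of the caller side, where conn is assumed to be a freshly
+// established connection:
+//
+//	seq := make(chan uint64, 1)
+//	newR := make(chan io.Reader, 1)
+//	go br.Reconnect(seq, newR)
+//	<-seq        // current sequence number, exchanged with the remote side
+//	newR <- conn // or send nil to abort the reconnect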
+func (br *BackedReader) Read(p []byte) (int, error) {
+	br.mu.Lock()
+	defer br.mu.Unlock()
+
+	for {
+		// Step 1: Wait until we have a reader or are closed
+		for br.reader == nil && !br.closed {
+			br.cond.Wait()
+		}
+
+		if br.closed {
+			return 0, io.EOF
+		}
+
+		// Step 2: Perform the read while holding the mutex
+		// This ensures proper synchronization with Reconnect and Close operations
+		n, err := br.reader.Read(p)
+		br.sequenceNum += uint64(n) // #nosec G115 -- n is always >= 0 per io.Reader contract
+
+		if err == nil {
+			return n, nil
+		}
+
+		// Mark reader as disconnected so future reads will wait for reconnection
+		br.reader = nil
+
+		if br.onError != nil {
+			br.onError(err)
+		}
+
+		// If we got some data before the error, return it now
+		if n > 0 {
+			return n, nil
+		}
+	}
+}
+
+// Reconnect coordinates a reconnection with the caller over two channels.
+// The seqNum channel is used to send the current sequence number to the caller.
+// The newR channel is used to receive the new reader from the caller.
+func (br *BackedReader) Reconnect(seqNum chan<- uint64, newR <-chan io.Reader) {
+	// Grab the lock
+	br.mu.Lock()
+	defer br.mu.Unlock()
+
+	if br.closed {
+		// Close the channel to indicate closed state
+		close(seqNum)
+		return
+	}
+
+	// Get the sequence number to send to the other side via seqNum channel
+	seqNum <- br.sequenceNum
+	close(seqNum)
+
+	// Wait for the reconnect to complete, via newR channel, and give us a new io.Reader
+	newReader := <-newR
+
+	// If reconnection fails while we are starting it, the caller sends nil on newR
+	if newReader == nil {
+		// Reconnection failed, keep current state
+		return
+	}
+
+	// Reconnection successful
+	br.reader = newReader
+
+	// Notify any waiting reads via the cond
+	br.cond.Broadcast()
+}
+
+// Close the reader and wake up any blocked reads.
+// After closing, all Read calls will return io.EOF.
+func (br *BackedReader) Close() error {
+	br.mu.Lock()
+	defer br.mu.Unlock()
+
+	if br.closed {
+		return nil
+	}
+
+	br.closed = true
+	br.reader = nil
+
+	// Wake up any blocked reads
+	br.cond.Broadcast()
+
+	return nil
+}
+
+// SetErrorCallback sets the callback function that will be called when
+// the underlying reader fails, including with io.EOF.
+func (br *BackedReader) SetErrorCallback(fn func(error)) {
+	br.mu.Lock()
+	defer br.mu.Unlock()
+	br.onError = fn
+}
+
+// SequenceNum returns the current sequence number (total bytes read).
+func (br *BackedReader) SequenceNum() uint64 {
+	br.mu.Lock()
+	defer br.mu.Unlock()
+	return br.sequenceNum
+}
+
+// Connected returns whether the reader is currently connected.
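+// The result is advisory: the connection can drop immediately after this
+// returns, so it is mainly useful for tests and status reporting.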
+func (br *BackedReader) Connected() bool {
+	br.mu.Lock()
+	defer br.mu.Unlock()
+	return br.reader != nil
+}
diff --git a/agent/immortalstreams/backedpipe/backed_reader_test.go b/agent/immortalstreams/backedpipe/backed_reader_test.go
new file mode 100644
index 0000000000000..25d2038d6d843
--- /dev/null
+++ b/agent/immortalstreams/backedpipe/backed_reader_test.go
@@ -0,0 +1,604 @@
+package backedpipe_test
+
+import (
+	"context"
+	"io"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"golang.org/x/xerrors"
+
+	"github.com/coder/coder/v2/agent/immortalstreams/backedpipe"
+	"github.com/coder/coder/v2/testutil"
+)
+
+// mockReader implements io.Reader with controllable behavior for testing
+type mockReader struct {
+	mu       sync.Mutex
+	data     []byte
+	pos      int
+	err      error
+	readFunc func([]byte) (int, error)
+}
+
+func newMockReader(data string) *mockReader {
+	return &mockReader{data: []byte(data)}
+}
+
+func (mr *mockReader) Read(p []byte) (int, error) {
+	mr.mu.Lock()
+	defer mr.mu.Unlock()
+
+	if mr.readFunc != nil {
+		return mr.readFunc(p)
+	}
+
+	if mr.err != nil {
+		return 0, mr.err
+	}
+
+	if mr.pos >= len(mr.data) {
+		return 0, io.EOF
+	}
+
+	n := copy(p, mr.data[mr.pos:])
+	mr.pos += n
+	return n, nil
+}
+
+func (mr *mockReader) setError(err error) {
+	mr.mu.Lock()
+	defer mr.mu.Unlock()
+	mr.err = err
+}
+
+func TestBackedReader_NewBackedReader(t *testing.T) {
+	t.Parallel()
+
+	br := backedpipe.NewBackedReader()
+	assert.NotNil(t, br)
+	assert.Equal(t, uint64(0), br.SequenceNum())
+	assert.False(t, br.Connected())
+}
+
+func TestBackedReader_BasicReadOperation(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+
+	br := backedpipe.NewBackedReader()
+	reader := newMockReader("hello world")
+
+	// Connect the reader
+	seqNum := make(chan uint64, 1)
+	newR := make(chan io.Reader, 1)
+
+	go br.Reconnect(seqNum, newR)
+
+	// Get sequence number from reader
+	seq := testutil.RequireReceive(ctx, t, seqNum)
+	assert.Equal(t, uint64(0), seq)
+
+	// Send new reader
+	testutil.RequireSend(ctx, t, newR, io.Reader(reader))
+
+	// Read data
+	buf := make([]byte, 5)
+	n, err := br.Read(buf)
+	require.NoError(t, err)
+	assert.Equal(t, 5, n)
+	assert.Equal(t, "hello", string(buf))
+	assert.Equal(t, uint64(5), br.SequenceNum())
+
+	// Read more data
+	n, err = br.Read(buf)
+	require.NoError(t, err)
+	assert.Equal(t, 5, n)
+	assert.Equal(t, " worl", string(buf))
+	assert.Equal(t, uint64(10), br.SequenceNum())
+}
+
+func TestBackedReader_ReadBlocksWhenDisconnected(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+
+	br := backedpipe.NewBackedReader()
+
+	// Start a read operation that should block
+	readDone := make(chan struct{})
+	var readErr error
+	var readBuf []byte
+	var readN int
+
+	go func() {
+		defer close(readDone)
+		buf := make([]byte, 10)
+		readN, readErr = br.Read(buf)
+		readBuf = buf[:readN]
+	}()
+
+	// Give a brief moment for the read to actually start and block on the condition variable
+	time.Sleep(time.Millisecond)
+
+	// Read should still be blocked
+	select {
+	case <-readDone:
+		t.Fatal("Read should be blocked when disconnected")
+	default:
+		// Good, still blocked
+	}
+
+	// Connect and the read should unblock
+	reader := newMockReader("test")
+	seqNum := make(chan uint64, 1)
+	newR := make(chan io.Reader, 1)
+
+	go br.Reconnect(seqNum, newR)
+
+	// Get sequence number and send new reader
+	testutil.RequireReceive(ctx, t, seqNum)
+	testutil.RequireSend(ctx, t,
newR, io.Reader(reader)) + + // Wait for read to complete + testutil.TryReceive(ctx, t, readDone) + assert.NoError(t, readErr) + assert.Equal(t, "test", string(readBuf)) +} + +func TestBackedReader_ReconnectionAfterFailure(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + br := backedpipe.NewBackedReader() + reader1 := newMockReader("first") + + // Initial connection + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send new reader + testutil.RequireReceive(ctx, t, seqNum) + testutil.RequireSend(ctx, t, newR, io.Reader(reader1)) + + // Read some data + buf := make([]byte, 5) + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, "first", string(buf[:n])) + assert.Equal(t, uint64(5), br.SequenceNum()) + + // Set up error callback to verify error notification + errorReceived := make(chan error, 1) + br.SetErrorCallback(func(err error) { + errorReceived <- err + }) + + // Simulate connection failure + reader1.setError(xerrors.New("connection lost")) + + // Start a read that will block due to connection failure + readDone := make(chan error, 1) + go func() { + _, err := br.Read(buf) + readDone <- err + }() + + // Wait for the error to be reported via callback + receivedErr := testutil.RequireReceive(ctx, t, errorReceived) + assert.Error(t, receivedErr) + assert.Contains(t, receivedErr.Error(), "connection lost") + + // Verify read is still blocked + select { + case err := <-readDone: + t.Fatalf("Read should still be blocked, but completed with: %v", err) + default: + // Good, still blocked + } + + // Verify disconnection + assert.False(t, br.Connected()) + + // Reconnect with new reader + reader2 := newMockReader("second") + seqNum2 := make(chan uint64, 1) + newR2 := make(chan io.Reader, 1) + + go br.Reconnect(seqNum2, newR2) + + // Get sequence number and send new reader + seq := testutil.RequireReceive(ctx, t, seqNum2) + assert.Equal(t, uint64(5), seq) // Should return current sequence number + testutil.RequireSend(ctx, t, newR2, io.Reader(reader2)) + + // Wait for read to unblock and succeed with new data + readErr := testutil.RequireReceive(ctx, t, readDone) + assert.NoError(t, readErr) // Should succeed with new reader +} + +func TestBackedReader_Close(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + br := backedpipe.NewBackedReader() + reader := newMockReader("test") + + // Connect + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send new reader + testutil.RequireReceive(ctx, t, seqNum) + testutil.RequireSend(ctx, t, newR, io.Reader(reader)) + + // First, read all available data + buf := make([]byte, 10) + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 4, n) // "test" is 4 bytes + + // Close the reader before EOF triggers reconnection + err = br.Close() + require.NoError(t, err) + + // After close, reads should return EOF + n, err = br.Read(buf) + assert.Equal(t, 0, n) + assert.Equal(t, io.EOF, err) + + // Subsequent reads should return EOF + _, err = br.Read(buf) + assert.Equal(t, io.EOF, err) +} + +func TestBackedReader_CloseIdempotent(t *testing.T) { + t.Parallel() + + br := backedpipe.NewBackedReader() + + err := br.Close() + assert.NoError(t, err) + + // Second close should be no-op + err = br.Close() + assert.NoError(t, err) +} + +func TestBackedReader_ReconnectAfterClose(t *testing.T) { + t.Parallel() + ctx := 
testutil.Context(t, testutil.WaitShort) + + br := backedpipe.NewBackedReader() + + err := br.Close() + require.NoError(t, err) + + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Should get 0 sequence number for closed reader + seq := testutil.TryReceive(ctx, t, seqNum) + assert.Equal(t, uint64(0), seq) +} + +// Helper function to reconnect a reader using channels +func reconnectReader(ctx context.Context, t testing.TB, br *backedpipe.BackedReader, reader io.Reader) { + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send new reader + testutil.RequireReceive(ctx, t, seqNum) + testutil.RequireSend(ctx, t, newR, reader) +} + +func TestBackedReader_SequenceNumberTracking(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + br := backedpipe.NewBackedReader() + reader := newMockReader("0123456789") + + reconnectReader(ctx, t, br, reader) + + // Read in chunks and verify sequence number + buf := make([]byte, 3) + + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 3, n) + assert.Equal(t, uint64(3), br.SequenceNum()) + + n, err = br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 3, n) + assert.Equal(t, uint64(6), br.SequenceNum()) + + n, err = br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 3, n) + assert.Equal(t, uint64(9), br.SequenceNum()) +} + +func TestBackedReader_EOFHandling(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + br := backedpipe.NewBackedReader() + reader := newMockReader("test") + + // Set up error callback to track when EOF triggers disconnection + errorReceived := make(chan error, 1) + br.SetErrorCallback(func(err error) { + errorReceived <- err + }) + + reconnectReader(ctx, t, br, reader) + + // Read all data + buf := make([]byte, 10) + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 4, n) + assert.Equal(t, "test", string(buf[:n])) + + // Next read should encounter EOF, which triggers disconnection + // The read should block waiting for reconnection + readDone := make(chan struct{}) + var readErr error + var readN int + + go func() { + defer close(readDone) + readN, readErr = br.Read(buf) + }() + + // Wait for EOF to be reported via error callback + receivedErr := testutil.RequireReceive(ctx, t, errorReceived) + assert.Equal(t, io.EOF, receivedErr) + + // Reader should be disconnected after EOF + assert.False(t, br.Connected()) + + // Read should still be blocked + select { + case <-readDone: + t.Fatal("Read should be blocked waiting for reconnection after EOF") + default: + // Good, still blocked + } + + // Reconnect with new data + reader2 := newMockReader("more") + reconnectReader(ctx, t, br, reader2) + + // Wait for the blocked read to complete with new data + testutil.TryReceive(ctx, t, readDone) + require.NoError(t, readErr) + assert.Equal(t, 4, readN) + assert.Equal(t, "more", string(buf[:readN])) +} + +func BenchmarkBackedReader_Read(b *testing.B) { + br := backedpipe.NewBackedReader() + buf := make([]byte, 1024) + + // Create a reader that never returns EOF by cycling through data + reader := &mockReader{ + readFunc: func(p []byte) (int, error) { + // Fill buffer with 'x' characters - never EOF + for i := range p { + p[i] = 'x' + } + return len(p), nil + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + reconnectReader(ctx, b, br, reader) + + b.ResetTimer() 
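+	// Each iteration reads from the always-full mock reader, so this mostly
+	// measures BackedReader's locking and sequence-number accounting.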
+ for i := 0; i < b.N; i++ { + br.Read(buf) + } +} + +func TestBackedReader_PartialReads(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + br := backedpipe.NewBackedReader() + + // Create a reader that returns partial reads + reader := &mockReader{ + readFunc: func(p []byte) (int, error) { + // Always return just 1 byte at a time + if len(p) == 0 { + return 0, nil + } + p[0] = 'A' + return 1, nil + }, + } + + reconnectReader(ctx, t, br, reader) + + // Read multiple times + buf := make([]byte, 10) + for i := 0; i < 5; i++ { + n, err := br.Read(buf) + require.NoError(t, err) + assert.Equal(t, 1, n) + assert.Equal(t, byte('A'), buf[0]) + } + + assert.Equal(t, uint64(5), br.SequenceNum()) +} + +func TestBackedReader_CloseWhileBlockedOnUnderlyingReader(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + br := backedpipe.NewBackedReader() + + // Create a reader that blocks on Read calls but can be unblocked + readStarted := make(chan struct{}, 1) + readUnblocked := make(chan struct{}) + blockingReader := &mockReader{ + readFunc: func(p []byte) (int, error) { + select { + case readStarted <- struct{}{}: + default: + } + <-readUnblocked // Block until signaled + // After unblocking, return an error to simulate connection failure + return 0, xerrors.New("connection interrupted") + }, + } + + // Connect the blocking reader + seqNum := make(chan uint64, 1) + newR := make(chan io.Reader, 1) + + go br.Reconnect(seqNum, newR) + + // Get sequence number and send blocking reader + testutil.RequireReceive(ctx, t, seqNum) + testutil.RequireSend(ctx, t, newR, io.Reader(blockingReader)) + + // Start a read that will block on the underlying reader + readDone := make(chan struct{}) + var readErr error + var readN int + + go func() { + defer close(readDone) + buf := make([]byte, 10) + readN, readErr = br.Read(buf) + }() + + // Wait for the read to start and block on the underlying reader + testutil.RequireReceive(ctx, t, readStarted) + + // Give a brief moment for the read to actually block + time.Sleep(time.Millisecond) + + // Verify read is blocked + select { + case <-readDone: + t.Fatal("Read should be blocked on underlying reader") + default: + // Good, still blocked + } + + // Start Close() in a goroutine since it will block until the underlying read completes + closeDone := make(chan error, 1) + go func() { + closeDone <- br.Close() + }() + + // Verify Close() is also blocked waiting for the underlying read + select { + case <-closeDone: + t.Fatal("Close should be blocked until underlying read completes") + case <-time.After(10 * time.Millisecond): + // Good, Close is blocked + } + + // Unblock the underlying reader, which will cause both the read and close to complete + close(readUnblocked) + + // Wait for both the read and close to complete + testutil.TryReceive(ctx, t, readDone) + closeErr := testutil.RequireReceive(ctx, t, closeDone) + require.NoError(t, closeErr) + + // The read should return EOF because Close() was called while it was blocked, + // even though the underlying reader returned an error + assert.Equal(t, 0, readN) + assert.Equal(t, io.EOF, readErr) + + // Subsequent reads should return EOF since the reader is now closed + buf := make([]byte, 10) + n, err := br.Read(buf) + assert.Equal(t, 0, n) + assert.Equal(t, io.EOF, err) +} + +func TestBackedReader_CloseWhileBlockedWaitingForReconnect(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + br := backedpipe.NewBackedReader() + reader1 := 
newMockReader("initial")
+
+	// Initial connection
+	seqNum := make(chan uint64, 1)
+	newR := make(chan io.Reader, 1)
+
+	go br.Reconnect(seqNum, newR)
+
+	// Get sequence number and send initial reader
+	testutil.RequireReceive(ctx, t, seqNum)
+	testutil.RequireSend(ctx, t, newR, io.Reader(reader1))
+
+	// Read initial data
+	buf := make([]byte, 10)
+	n, err := br.Read(buf)
+	require.NoError(t, err)
+	assert.Equal(t, "initial", string(buf[:n]))
+
+	// Set up error callback to track connection failure
+	errorReceived := make(chan error, 1)
+	br.SetErrorCallback(func(err error) {
+		errorReceived <- err
+	})
+
+	// Simulate connection failure
+	reader1.setError(xerrors.New("connection lost"))
+
+	// Start a read that will block waiting for reconnection
+	readDone := make(chan struct{})
+	var readErr error
+	var readN int
+
+	go func() {
+		defer close(readDone)
+		readN, readErr = br.Read(buf)
+	}()
+
+	// Wait for the error to be reported (indicating disconnection)
+	receivedErr := testutil.RequireReceive(ctx, t, errorReceived)
+	assert.Error(t, receivedErr)
+	assert.Contains(t, receivedErr.Error(), "connection lost")
+
+	// Verify read is blocked waiting for reconnection
+	select {
+	case <-readDone:
+		t.Fatal("Read should be blocked waiting for reconnection")
+	default:
+		// Good, still blocked
+	}
+
+	// Verify reader is disconnected
+	assert.False(t, br.Connected())
+
+	// Close the BackedReader while read is blocked waiting for reconnection
+	err = br.Close()
+	require.NoError(t, err)
+
+	// The read should unblock and return EOF
+	testutil.TryReceive(ctx, t, readDone)
+	assert.Equal(t, 0, readN)
+	assert.Equal(t, io.EOF, readErr)
+}
diff --git a/agent/immortalstreams/backedpipe/backed_writer.go b/agent/immortalstreams/backedpipe/backed_writer.go
new file mode 100644
index 0000000000000..707abb16271a8
--- /dev/null
+++ b/agent/immortalstreams/backedpipe/backed_writer.go
@@ -0,0 +1,204 @@
+package backedpipe
+
+import (
+	"io"
+	"sync"
+
+	"golang.org/x/xerrors"
+)
+
+// BackedWriter wraps an unreliable io.Writer and makes it resilient to disconnections.
+// It maintains a ring buffer of recent writes for replay during reconnection.
+type BackedWriter struct {
+	mu          sync.Mutex
+	cond        *sync.Cond
+	writer      io.Writer
+	buffer      *ringBuffer
+	sequenceNum uint64 // total bytes written
+	closed      bool
+
+	// Error channel to notify parent when connection fails
+	errorChan chan<- error
+}
+
+// NewBackedWriter creates a new BackedWriter with the specified buffer capacity.
+// The writer is initially disconnected and will block writes until connected.
+// The errorChan is required and will receive connection errors.
+// Capacity must be > 0.
+func NewBackedWriter(capacity int, errorChan chan<- error) *BackedWriter {
+	if capacity <= 0 {
+		panic("backed writer capacity must be > 0")
+	}
+	if errorChan == nil {
+		panic("error channel cannot be nil")
+	}
+	bw := &BackedWriter{
+		buffer:    newRingBuffer(capacity),
+		errorChan: errorChan,
+	}
+	bw.cond = sync.NewCond(&bw.mu)
+	return bw
+}
+
+// Write implements io.Writer.
+// When connected, it writes to both the ring buffer and the underlying writer.
+// If the underlying write fails, the error is returned and the writer is marked
+// as disconnected; subsequent writes block until reconnection occurs.
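+//
+// A rough sketch of the write/reconnect cycle, where newConn is assumed to be
+// a freshly established connection and remoteSeq the sequence number reported
+// by the remote side:
+//
+//	if _, err := bw.Write(data); err != nil {
+//		// The writer is now disconnected; data is retained in the ring
+//		// buffer and is replayed by the next successful Reconnect.
+//		if rerr := bw.Reconnect(remoteSeq, newConn); rerr != nil {
+//			// replay window evicted, or the replay write itself failed
+//		}
+//	}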
+func (bw *BackedWriter) Write(p []byte) (int, error) {
+	if len(p) == 0 {
+		return 0, nil
+	}
+
+	bw.mu.Lock()
+	defer bw.mu.Unlock()
+
+	if bw.closed {
+		return 0, io.ErrClosedPipe
+	}
+
+	// Block until connected
+	for bw.writer == nil && !bw.closed {
+		bw.cond.Wait()
+	}
+
+	// Check if we were closed while waiting
+	if bw.closed {
+		return 0, io.ErrClosedPipe
+	}
+
+	// Always write to buffer first
+	bw.buffer.Write(p)
+	// Always advance sequence number by the full length
+	bw.sequenceNum += uint64(len(p))
+
+	// Write to underlying writer
+	n, err := bw.writer.Write(p)
+	if err != nil {
+		// Connection failed, mark as disconnected
+		bw.writer = nil
+
+		// Notify parent of error
+		select {
+		case bw.errorChan <- err:
+		default:
+		}
+		return 0, err
+	}
+
+	if n != len(p) {
+		bw.writer = nil
+		err = xerrors.Errorf("short write: %d bytes written, %d expected", n, len(p))
+		select {
+		case bw.errorChan <- err:
+		default:
+		}
+		return 0, err
+	}
+
+	return len(p), nil
+}
+
+// Reconnect replaces the current writer with a new one and replays data from the specified
+// sequence number. If the requested sequence number is no longer in the buffer,
+// returns an error indicating data loss.
+func (bw *BackedWriter) Reconnect(replayFromSeq uint64, newWriter io.Writer) error {
+	bw.mu.Lock()
+	defer bw.mu.Unlock()
+
+	if bw.closed {
+		return xerrors.New("cannot reconnect closed writer")
+	}
+
+	if newWriter == nil {
+		return xerrors.New("new writer cannot be nil")
+	}
+
+	// Check if we can replay from the requested sequence number
+	if replayFromSeq > bw.sequenceNum {
+		return xerrors.Errorf("cannot replay from future sequence %d: current sequence is %d", replayFromSeq, bw.sequenceNum)
+	}
+
+	// Calculate how many bytes we need to replay
+	replayBytes := bw.sequenceNum - replayFromSeq
+
+	var replayData []byte
+	if replayBytes > 0 {
+		// Get the last replayBytes from buffer
+		// If the buffer doesn't have enough data (some was evicted),
+		// ReadLast will return an error
+		var err error
+		// Safe conversion: replayBytes is always non-negative due to the check above
+		// No overflow possible since replayBytes is calculated as sequenceNum - replayFromSeq
+		// and uint64->int conversion is safe for reasonable buffer sizes
+		//nolint:gosec // Safe conversion: replayBytes is calculated from uint64 subtraction
+		replayData, err = bw.buffer.ReadLast(int(replayBytes))
+		if err != nil {
+			return xerrors.Errorf("failed to read replay data: %w", err)
+		}
+	}
+
+	// Clear the current writer first in case replay fails
+	bw.writer = nil
+
+	// Replay data if needed. We keep the writer as nil during replay to ensure
+	// no concurrent writes can happen, then set it only after successful replay.
+	if len(replayData) > 0 {
+		bw.mu.Unlock()
+		n, err := newWriter.Write(replayData)
+		bw.mu.Lock()
+
+		if err != nil {
+			// Reconnect failed, writer remains nil
+			return xerrors.Errorf("replay failed: %w", err)
+		}
+
+		if n != len(replayData) {
+			// Reconnect failed, writer remains nil
+			return xerrors.Errorf("partial replay: wrote %d of %d bytes", n, len(replayData))
+		}
+	}
+
+	// Set new writer only after successful replay. This ensures no concurrent
+	// writes can interfere with the replay operation.
+	bw.writer = newWriter
+
+	// Wake up any operations waiting for connection
+	bw.cond.Broadcast()
+
+	return nil
+}
+
+// Close closes the writer and prevents further writes.
+// After closing, all Write calls will return io.ErrClosedPipe.
+// The error result exists only to satisfy io.Closer; Close always returns nil.
+func (bw *BackedWriter) Close() error {
+	bw.mu.Lock()
+	defer bw.mu.Unlock()
+
+	if bw.closed {
+		return nil
+	}
+
+	bw.closed = true
+	bw.writer = nil
+
+	// Wake up any blocked operations
+	bw.cond.Broadcast()
+
+	return nil
+}
+
+// SequenceNum returns the current sequence number (total bytes written).
+func (bw *BackedWriter) SequenceNum() uint64 {
+	bw.mu.Lock()
+	defer bw.mu.Unlock()
+	return bw.sequenceNum
+}
+
+// Connected returns whether the writer is currently connected.
+func (bw *BackedWriter) Connected() bool {
+	bw.mu.Lock()
+	defer bw.mu.Unlock()
+	return bw.writer != nil
+}
diff --git a/agent/immortalstreams/backedpipe/backed_writer_test.go b/agent/immortalstreams/backedpipe/backed_writer_test.go
new file mode 100644
index 0000000000000..02f48d811f5c6
--- /dev/null
+++ b/agent/immortalstreams/backedpipe/backed_writer_test.go
@@ -0,0 +1,586 @@
+package backedpipe_test
+
+import (
+	"bytes"
+	"io"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+	"golang.org/x/xerrors"
+
+	"github.com/coder/coder/v2/agent/immortalstreams/backedpipe"
+	"github.com/coder/coder/v2/testutil"
+)
+
+// mockWriter implements io.Writer with controllable behavior for testing
+type mockWriter struct {
+	mu         sync.Mutex
+	buffer     bytes.Buffer
+	err        error
+	writeFunc  func([]byte) (int, error)
+	writeCalls int
+}
+
+func newMockWriter() *mockWriter {
+	return &mockWriter{}
+}
+
+// newBackedWriterForTest creates a BackedWriter with a small buffer for testing eviction behavior
+func newBackedWriterForTest(bufferSize int) *backedpipe.BackedWriter {
+	errorChan := make(chan error, 1)
+	return backedpipe.NewBackedWriter(bufferSize, errorChan)
+}
+
+func (mw *mockWriter) Write(p []byte) (int, error) {
+	mw.mu.Lock()
+	defer mw.mu.Unlock()
+
+	mw.writeCalls++
+
+	if mw.writeFunc != nil {
+		return mw.writeFunc(p)
+	}
+
+	if mw.err != nil {
+		return 0, mw.err
+	}
+
+	return mw.buffer.Write(p)
+}
+
+func (mw *mockWriter) Len() int {
+	mw.mu.Lock()
+	defer mw.mu.Unlock()
+	return mw.buffer.Len()
+}
+
+func (mw *mockWriter) Reset() {
+	mw.mu.Lock()
+	defer mw.mu.Unlock()
+	mw.buffer.Reset()
+	mw.writeCalls = 0
+	mw.err = nil
+	mw.writeFunc = nil
+}
+
+func (mw *mockWriter) setError(err error) {
+	mw.mu.Lock()
+	defer mw.mu.Unlock()
+	mw.err = err
+}
+
+func TestBackedWriter_NewBackedWriter(t *testing.T) {
+	t.Parallel()
+
+	errorChan := make(chan error, 1)
+	bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan)
+	require.NotNil(t, bw)
+	require.Equal(t, uint64(0), bw.SequenceNum())
+	require.False(t, bw.Connected())
+}
+
+func TestBackedWriter_WriteBlocksWhenDisconnected(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+
+	errorChan := make(chan error, 1)
+	bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan)
+
+	// Write should block when disconnected
+	writeComplete := make(chan struct{})
+	var writeErr error
+	var n int
+
+	go func() {
+		defer close(writeComplete)
+		n, writeErr = bw.Write([]byte("hello"))
+	}()
+
+	// Verify write is blocked
+	select {
+	case <-writeComplete:
+		t.Fatal("Write should have blocked when disconnected")
+	case <-time.After(50 * time.Millisecond):
+		// Expected - write is blocked
+	}
+
+	// Connect and verify write completes
+	writer := newMockWriter()
+	err := bw.Reconnect(0, writer)
+	require.NoError(t, err)
+
+	// Write should now complete
+	testutil.TryReceive(ctx, t,
writeComplete) + + require.NoError(t, writeErr) + require.Equal(t, 5, n) + require.Equal(t, uint64(5), bw.SequenceNum()) + require.Equal(t, []byte("hello"), writer.buffer.Bytes()) +} + +func TestBackedWriter_WriteToUnderlyingWhenConnected(t *testing.T) { + t.Parallel() + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + writer := newMockWriter() + + // Connect + err := bw.Reconnect(0, writer) + require.NoError(t, err) + require.True(t, bw.Connected()) + + // Write should go to both buffer and underlying writer + n, err := bw.Write([]byte("hello")) + require.NoError(t, err) + require.Equal(t, 5, n) + + // Data should be buffered + + // Check underlying writer + require.Equal(t, []byte("hello"), writer.buffer.Bytes()) +} + +func TestBackedWriter_DisconnectOnWriteFailure(t *testing.T) { + t.Parallel() + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + writer := newMockWriter() + + // Connect + err := bw.Reconnect(0, writer) + require.NoError(t, err) + + // Cause write to fail + writer.setError(xerrors.New("write failed")) + + // Write should fail and disconnect + n, err := bw.Write([]byte("hello")) + require.Error(t, err) // Write should fail + require.Equal(t, 0, n) + require.False(t, bw.Connected()) // Should be disconnected + + // Error should be sent to error channel + select { + case receivedErr := <-errorChan: + require.Contains(t, receivedErr.Error(), "write failed") + default: + t.Fatal("Expected error to be sent to error channel") + } +} + +func TestBackedWriter_ReplayOnReconnect(t *testing.T) { + t.Parallel() + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + + // Connect initially to write some data + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + + // Write some data while connected + _, err = bw.Write([]byte("hello")) + require.NoError(t, err) + _, err = bw.Write([]byte(" world")) + require.NoError(t, err) + + require.Equal(t, uint64(11), bw.SequenceNum()) + + // Disconnect by causing a write failure + writer1.setError(xerrors.New("connection lost")) + _, err = bw.Write([]byte("test")) + require.Error(t, err) + require.False(t, bw.Connected()) + + // Reconnect with new writer and request replay from beginning + writer2 := newMockWriter() + err = bw.Reconnect(0, writer2) + require.NoError(t, err) + + // Should have replayed all data including the failed write that was buffered + require.Equal(t, []byte("hello worldtest"), writer2.buffer.Bytes()) + + // Write new data should go to both + _, err = bw.Write([]byte("!")) + require.NoError(t, err) + require.Equal(t, []byte("hello worldtest!"), writer2.buffer.Bytes()) +} + +func TestBackedWriter_PartialReplay(t *testing.T) { + t.Parallel() + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + + // Connect initially to write some data + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + + // Write some data + _, err = bw.Write([]byte("hello")) + require.NoError(t, err) + _, err = bw.Write([]byte(" world")) + require.NoError(t, err) + _, err = bw.Write([]byte("!")) + require.NoError(t, err) + + // Reconnect with new writer and request replay from middle + writer2 := newMockWriter() + err = bw.Reconnect(5, writer2) // From " world!" 
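+	// Sequence 5 marks the end of "hello", so " world" and "!" are replayed.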
+ require.NoError(t, err) + + // Should have replayed only the requested portion + require.Equal(t, []byte(" world!"), writer2.buffer.Bytes()) +} + +func TestBackedWriter_ReplayFromFutureSequence(t *testing.T) { + t.Parallel() + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + + // Connect initially to write some data + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + + _, err = bw.Write([]byte("hello")) + require.NoError(t, err) + + writer2 := newMockWriter() + err = bw.Reconnect(10, writer2) // Future sequence + require.Error(t, err) + require.Contains(t, err.Error(), "future sequence") +} + +func TestBackedWriter_ReplayDataLoss(t *testing.T) { + t.Parallel() + + bw := newBackedWriterForTest(10) // Small buffer for testing + + // Connect initially to write some data + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + + // Fill buffer beyond capacity to cause eviction + _, err = bw.Write([]byte("0123456789")) // Fills buffer exactly + require.NoError(t, err) + _, err = bw.Write([]byte("abcdef")) // Should evict "012345" + require.NoError(t, err) + + writer2 := newMockWriter() + err = bw.Reconnect(0, writer2) // Try to replay from evicted data + // With the new error handling, this should fail because we can't read all the data + require.Error(t, err) + require.Contains(t, err.Error(), "failed to read replay data") +} + +func TestBackedWriter_BufferEviction(t *testing.T) { + t.Parallel() + + bw := newBackedWriterForTest(5) // Very small buffer for testing + + // Connect initially + writer := newMockWriter() + err := bw.Reconnect(0, writer) + require.NoError(t, err) + + // Write data that will cause eviction + n, err := bw.Write([]byte("abcde")) + require.NoError(t, err) + require.Equal(t, 5, n) + + // Write more to cause eviction + n, err = bw.Write([]byte("fg")) + require.NoError(t, err) + require.Equal(t, 2, n) + + // Buffer should contain "cdefg" (latest data) +} + +func TestBackedWriter_Close(t *testing.T) { + t.Parallel() + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + writer := newMockWriter() + + bw.Reconnect(0, writer) + + err := bw.Close() + require.NoError(t, err) + + // Writes after close should fail + _, err = bw.Write([]byte("test")) + require.Equal(t, io.ErrClosedPipe, err) + + // Reconnect after close should fail + err = bw.Reconnect(0, newMockWriter()) + require.Error(t, err) + require.Contains(t, err.Error(), "closed") +} + +func TestBackedWriter_CloseIdempotent(t *testing.T) { + t.Parallel() + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + + err := bw.Close() + require.NoError(t, err) + + // Second close should be no-op + err = bw.Close() + require.NoError(t, err) +} + +func TestBackedWriter_ConcurrentWrites(t *testing.T) { + t.Parallel() + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + writer := newMockWriter() + bw.Reconnect(0, writer) + + var wg sync.WaitGroup + numWriters := 10 + writesPerWriter := 50 + + for i := 0; i < numWriters; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < writesPerWriter; j++ { + data := []byte{byte(id + '0')} + bw.Write(data) + } + }(i) + } + + wg.Wait() + + // Should have written expected amount to buffer + expectedBytes := uint64(numWriters * writesPerWriter) //nolint:gosec // 
Safe conversion: test constants with small values + require.Equal(t, expectedBytes, bw.SequenceNum()) + // Note: underlying writer may not receive all bytes due to potential disconnections + // during concurrent operations, but the buffer should track all writes + require.True(t, writer.Len() <= int(expectedBytes)) //nolint:gosec // Safe conversion: expectedBytes is calculated from small test values +} + +func TestBackedWriter_ReconnectDuringReplay(t *testing.T) { + t.Parallel() + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + + // Connect initially to write some data + writer1 := newMockWriter() + err := bw.Reconnect(0, writer1) + require.NoError(t, err) + + _, err = bw.Write([]byte("hello world")) + require.NoError(t, err) + + // Create a writer that fails during replay + writer2 := &mockWriter{ + writeFunc: func(p []byte) (int, error) { + return 0, xerrors.New("replay failed") + }, + } + + err = bw.Reconnect(0, writer2) + require.Error(t, err) + require.Contains(t, err.Error(), "replay failed") + require.False(t, bw.Connected()) +} + +func TestBackedWriter_PartialWriteToUnderlying(t *testing.T) { + t.Parallel() + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + + // Create writer that does partial writes + writer := &mockWriter{ + writeFunc: func(p []byte) (int, error) { + if len(p) > 3 { + return 3, nil // Only write first 3 bytes + } + return len(p), nil + }, + } + + bw.Reconnect(0, writer) + + // Write should fail due to partial write + n, err := bw.Write([]byte("hello")) + require.Error(t, err) + require.Equal(t, 0, n) + require.False(t, bw.Connected()) + require.Contains(t, err.Error(), "short write") + + // Error should be sent to error channel + select { + case receivedErr := <-errorChan: + require.Contains(t, receivedErr.Error(), "short write") + default: + t.Fatal("Expected error to be sent to error channel") + } +} + +func TestBackedWriter_WriteUnblocksOnReconnect(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + + // Start a single write that should block + writeResult := make(chan error, 1) + go func() { + _, err := bw.Write([]byte("test")) + writeResult <- err + }() + + // Verify write is blocked + select { + case <-writeResult: + t.Fatal("Write should have blocked when disconnected") + case <-time.After(50 * time.Millisecond): + // Expected - write is blocked + } + + // Connect and verify write completes + writer := newMockWriter() + err := bw.Reconnect(0, writer) + require.NoError(t, err) + + // Write should now complete + err = testutil.RequireReceive(ctx, t, writeResult) + require.NoError(t, err) + + // Write should have been written to the underlying writer + require.Equal(t, "test", writer.buffer.String()) +} + +func TestBackedWriter_CloseUnblocksWaitingWrites(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + + // Start a write that should block + writeComplete := make(chan error, 1) + go func() { + _, err := bw.Write([]byte("test")) + writeComplete <- err + }() + + // Verify write is blocked + select { + case <-writeComplete: + t.Fatal("Write should have blocked when disconnected") + case <-time.After(50 * time.Millisecond): + // Expected - write is 
blocked + } + + // Close the writer + err := bw.Close() + require.NoError(t, err) + + // Write should now complete with error + err = testutil.RequireReceive(ctx, t, writeComplete) + require.Equal(t, io.ErrClosedPipe, err) +} + +func TestBackedWriter_WriteBlocksAfterDisconnection(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + writer := newMockWriter() + + // Connect initially + err := bw.Reconnect(0, writer) + require.NoError(t, err) + + // Write should succeed when connected + _, err = bw.Write([]byte("hello")) + require.NoError(t, err) + + // Cause disconnection + writer.setError(xerrors.New("connection lost")) + _, err = bw.Write([]byte("world")) + require.Error(t, err) + require.False(t, bw.Connected()) + + // Subsequent write should block + writeComplete := make(chan error, 1) + go func() { + _, err := bw.Write([]byte("blocked")) + writeComplete <- err + }() + + // Verify write is blocked + select { + case <-writeComplete: + t.Fatal("Write should have blocked after disconnection") + case <-time.After(50 * time.Millisecond): + // Expected - write is blocked + } + + // Reconnect and verify write completes + writer2 := newMockWriter() + err = bw.Reconnect(5, writer2) // Replay from after "hello" + require.NoError(t, err) + + err = testutil.RequireReceive(ctx, t, writeComplete) + require.NoError(t, err) +} + +func BenchmarkBackedWriter_Write(b *testing.B) { + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) // 64KB buffer + writer := newMockWriter() + bw.Reconnect(0, writer) + + data := bytes.Repeat([]byte("x"), 1024) // 1KB writes + + b.ResetTimer() + for i := 0; i < b.N; i++ { + bw.Write(data) + } +} + +func BenchmarkBackedWriter_Reconnect(b *testing.B) { + errorChan := make(chan error, 1) + bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errorChan) + + // Connect initially to fill buffer with data + initialWriter := newMockWriter() + err := bw.Reconnect(0, initialWriter) + if err != nil { + b.Fatal(err) + } + + // Fill buffer with data + data := bytes.Repeat([]byte("x"), 1024) + for i := 0; i < 32; i++ { + bw.Write(data) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + writer := newMockWriter() + bw.Reconnect(0, writer) + } +} diff --git a/agent/immortalstreams/backedpipe/ring_buffer.go b/agent/immortalstreams/backedpipe/ring_buffer.go new file mode 100644 index 0000000000000..eefec300b306c --- /dev/null +++ b/agent/immortalstreams/backedpipe/ring_buffer.go @@ -0,0 +1,131 @@ +package backedpipe + +import ( + "golang.org/x/xerrors" +) + +// ringBuffer implements an efficient circular buffer with a fixed-size allocation. +// This implementation is not thread-safe and relies on external synchronization. +type ringBuffer struct { + buffer []byte + start int // index of first valid byte + end int // index of last valid byte (-1 when empty) +} + +// newRingBuffer creates a new ring buffer with the specified capacity. +// Capacity must be > 0. +func newRingBuffer(capacity int) *ringBuffer { + if capacity <= 0 { + panic("ring buffer capacity must be > 0") + } + return &ringBuffer{ + buffer: make([]byte, capacity), + end: -1, // -1 indicates empty buffer + } +} + +// Size returns the current number of bytes in the buffer. 
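The replay contract these tests exercise can be summarized from the caller's side as follows. This is an illustrative sketch, not code from this change; `errCh`, `conn1`, `conn2`, `payload`, and `resumeSeq` are hypothetical placeholders:

	// A reconnect at resumeSeq replays the bytes in [resumeSeq, bw.SequenceNum()),
	// so it can only succeed while that range is still resident in the ring buffer.
	bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errCh)
	_ = bw.Reconnect(0, conn1) // initial connection
	_, _ = bw.Write(payload)   // buffered, then forwarded to conn1
	// ...connection drops; subsequent writes block...
	if err := bw.Reconnect(resumeSeq, conn2); err != nil {
		// resumeSeq is either older than the oldest buffered byte (evicted)
		// or beyond bw.SequenceNum() (a future sequence).
	}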
diff --git a/agent/immortalstreams/backedpipe/ring_buffer.go b/agent/immortalstreams/backedpipe/ring_buffer.go
new file mode 100644
index 0000000000000..eefec300b306c
--- /dev/null
+++ b/agent/immortalstreams/backedpipe/ring_buffer.go
@@ -0,0 +1,131 @@
+package backedpipe
+
+import (
+	"golang.org/x/xerrors"
+)
+
+// ringBuffer implements an efficient circular buffer with a fixed-size allocation.
+// This implementation is not thread-safe and relies on external synchronization.
+type ringBuffer struct {
+	buffer []byte
+	start  int // index of first valid byte
+	end    int // index of last valid byte (-1 when empty)
+}
+
+// newRingBuffer creates a new ring buffer with the specified capacity.
+// Capacity must be > 0.
+func newRingBuffer(capacity int) *ringBuffer {
+	if capacity <= 0 {
+		panic("ring buffer capacity must be > 0")
+	}
+	return &ringBuffer{
+		buffer: make([]byte, capacity),
+		end:    -1, // -1 indicates empty buffer
+	}
+}
+
+// Size returns the current number of bytes in the buffer.
+func (rb *ringBuffer) Size() int {
+	if rb.end == -1 {
+		return 0 // Buffer is empty
+	}
+	if rb.start <= rb.end {
+		return rb.end - rb.start + 1
+	}
+	// Buffer wraps around
+	return len(rb.buffer) - rb.start + rb.end + 1
+}
+
+// Write writes data to the ring buffer. If the buffer would overflow,
+// it evicts the oldest data to make room for new data.
+func (rb *ringBuffer) Write(data []byte) {
+	if len(data) == 0 {
+		return
+	}
+
+	capacity := len(rb.buffer)
+
+	// If data is larger than capacity, only keep the last capacity bytes
+	if len(data) > capacity {
+		data = data[len(data)-capacity:]
+		// Clear buffer and write new data
+		rb.start = 0
+		rb.end = -1 // Will be set properly below
+	}
+
+	// Calculate how much we need to evict to fit new data
+	spaceNeeded := len(data)
+	availableSpace := capacity - rb.Size()
+
+	if spaceNeeded > availableSpace {
+		bytesToEvict := spaceNeeded - availableSpace
+		rb.evict(bytesToEvict)
+	}
+
+	// Write the new data immediately after the current end (position 0 when empty)
+	writePos := (rb.end + 1) % capacity
+	if writePos+len(data) <= capacity {
+		// No wrap needed - single copy
+		copy(rb.buffer[writePos:], data)
+		rb.end = (rb.end + len(data)) % capacity
+	} else {
+		// Need to wrap around - two copies
+		firstChunk := capacity - writePos
+		copy(rb.buffer[writePos:], data[:firstChunk])
+		copy(rb.buffer[0:], data[firstChunk:])
+		rb.end = len(data) - firstChunk - 1
+	}
+}
+
+// evict removes the specified number of bytes from the beginning of the buffer.
+func (rb *ringBuffer) evict(count int) {
+	if count >= rb.Size() {
+		// Evict everything
+		rb.start = 0
+		rb.end = -1
+		return
+	}
+
+	rb.start = (rb.start + count) % len(rb.buffer)
+	// Buffer remains non-empty after partial eviction
+}
+
+// ReadLast returns the last n bytes from the buffer.
+// If n is greater than the available data, returns an error.
+// If n is negative, returns an error.
+func (rb *ringBuffer) ReadLast(n int) ([]byte, error) {
+	if n < 0 {
+		return nil, xerrors.New("cannot read negative number of bytes")
+	}
+
+	if n == 0 {
+		return nil, nil
+	}
+
+	size := rb.Size()
+
+	// If requested more than available, return error
+	if n > size {
+		return nil, xerrors.Errorf("requested %d bytes but only %d available", n, size)
+	}
+
+	result := make([]byte, n)
+	capacity := len(rb.buffer)
+
+	// Calculate where to start reading from (n bytes before the end)
+	startOffset := size - n
+	actualStart := (rb.start + startOffset) % capacity
+
+	// Copy the last n bytes
+	if actualStart+n <= capacity {
+		// No wrap needed
+		copy(result, rb.buffer[actualStart:actualStart+n])
+	} else {
+		// Need to wrap around
+		firstChunk := capacity - actualStart
+		copy(result[0:firstChunk], rb.buffer[actualStart:capacity])
+		copy(result[firstChunk:], rb.buffer[0:n-firstChunk])
+	}
+
+	return result, nil
+}
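To make the wrap-around arithmetic in Write and ReadLast concrete, here is a small example traced through the implementation above (capacity 5):

	rb := newRingBuffer(5)
	rb.Write([]byte("abcde")) // start=0, end=4, contents "abcde"
	rb.Write([]byte("fg"))    // evicts "ab" (start=2); "fg" lands at indices 0-1, so end wraps to 1
	last, _ := rb.ReadLast(5) // returns "cdefg"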
diff --git a/agent/immortalstreams/backedpipe/ring_buffer_internal_test.go b/agent/immortalstreams/backedpipe/ring_buffer_internal_test.go
new file mode 100644
index 0000000000000..34fe5fb1cbf6e
--- /dev/null
+++ b/agent/immortalstreams/backedpipe/ring_buffer_internal_test.go
@@ -0,0 +1,264 @@
+package backedpipe
+
+import (
+	"bytes"
+	"os"
+	"runtime"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+	"go.uber.org/goleak"
+
+	"github.com/coder/coder/v2/testutil"
+)
+
+func TestMain(m *testing.M) {
+	if runtime.GOOS == "windows" {
+		// Don't run goleak on windows tests, they're super flaky right now.
+		// See: https://github.com/coder/coder/issues/8954
+		os.Exit(m.Run())
+	}
+	goleak.VerifyTestMain(m, testutil.GoleakOptions...)
+}
+
+func TestRingBuffer_NewRingBuffer(t *testing.T) {
+	t.Parallel()
+
+	rb := newRingBuffer(100)
+	// Test that we can write and read from the buffer
+	rb.Write([]byte("test"))
+
+	data, err := rb.ReadLast(4)
+	require.NoError(t, err)
+	require.Equal(t, []byte("test"), data)
+}
+
+func TestRingBuffer_WriteAndRead(t *testing.T) {
+	t.Parallel()
+
+	rb := newRingBuffer(10)
+
+	// Write some data
+	rb.Write([]byte("hello"))
+
+	// Read last 4 bytes
+	data, err := rb.ReadLast(4)
+	require.NoError(t, err)
+	require.Equal(t, "ello", string(data))
+
+	// Write more data
+	rb.Write([]byte("world"))
+
+	// Read last 5 bytes
+	data, err = rb.ReadLast(5)
+	require.NoError(t, err)
+	require.Equal(t, "world", string(data))
+
+	// Read last 3 bytes
+	data, err = rb.ReadLast(3)
+	require.NoError(t, err)
+	require.Equal(t, "rld", string(data))
+
+	// Read more than available (buffer holds 10 bytes total)
+	_, err = rb.ReadLast(15)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "requested 15 bytes but only")
+}
+
+func TestRingBuffer_OverflowEviction(t *testing.T) {
+	t.Parallel()
+
+	rb := newRingBuffer(5)
+
+	// Fill buffer
+	rb.Write([]byte("abcde"))
+
+	// Overflow should evict oldest data
+	rb.Write([]byte("fg"))
+
+	// Should now contain "cdefg"
+	data, err := rb.ReadLast(5)
+	require.NoError(t, err)
+	require.Equal(t, []byte("cdefg"), data)
+}
+
+func TestRingBuffer_LargeWrite(t *testing.T) {
+	t.Parallel()
+
+	rb := newRingBuffer(5)
+
+	// Write data larger than capacity
+	rb.Write([]byte("abcdefghij"))
+
+	// Should contain last 5 bytes
+	data, err := rb.ReadLast(5)
+	require.NoError(t, err)
+	require.Equal(t, []byte("fghij"), data)
+}
+
+func TestRingBuffer_WrapAround(t *testing.T) {
+	t.Parallel()
+
+	rb := newRingBuffer(5)
+
+	// Fill buffer
+	rb.Write([]byte("abcde"))
+
+	// Write more to cause wrap-around
+	rb.Write([]byte("fgh"))
+
+	// Should contain "defgh"
+	data, err := rb.ReadLast(5)
+	require.NoError(t, err)
+	require.Equal(t, []byte("defgh"), data)
+
+	// Test reading last 3 bytes after wrap
+	data, err = rb.ReadLast(3)
+	require.NoError(t, err)
+	require.Equal(t, []byte("fgh"), data)
+}
+
+func TestRingBuffer_ReadLastEdgeCases(t *testing.T) {
+	t.Parallel()
+
+	rb := newRingBuffer(3)
+
+	// Write some data (5 bytes to a 3-byte buffer, so only the last 3 bytes remain)
+	rb.Write([]byte("hello"))
+
+	// Test reading negative count
+	data, err := rb.ReadLast(-1)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "cannot read negative number of bytes")
+	require.Nil(t, data)
+
+	// Test reading zero bytes
+	data, err = rb.ReadLast(0)
+	require.NoError(t, err)
+	require.Nil(t, data)
+
+	// Test reading more than available (buffer has 3 bytes, try to read 10)
+	_, err = rb.ReadLast(10)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "requested 10 bytes but only 3 available")
+
+	// Test reading exact amount available
+	data, err = rb.ReadLast(3)
+	require.NoError(t, err)
+	require.Equal(t, []byte("llo"), data)
+}
+
+func TestRingBuffer_EmptyWrite(t *testing.T) {
+	t.Parallel()
+
+	rb := newRingBuffer(10)
+
+	// Write empty data
+	rb.Write([]byte{})
+
+	// Buffer should still be empty
+	_, err := rb.ReadLast(5)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "requested 5 bytes but only 0 available")
+}
+
+// TestRingBuffer_ConcurrentAccess was removed: the ring buffer is intentionally
+// not thread-safe, relying on the external synchronization provided by BackedWriter.
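+//
+// For reference, callers are expected to serialize access themselves, roughly
+// like this (an illustrative sketch, not part of this package's API):
+//
+//	var mu sync.Mutex
+//	mu.Lock()
+//	rb.Write(data)
+//	last, err := rb.ReadLast(n)
+//	mu.Unlock()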
+
+func TestRingBuffer_MultipleWrites(t *testing.T) {
+	t.Parallel()
+
+	rb := newRingBuffer(10)
+
+	// Write data in chunks
+	rb.Write([]byte("ab"))
+	rb.Write([]byte("cd"))
+	rb.Write([]byte("ef"))
+
+	data, err := rb.ReadLast(6)
+	require.NoError(t, err)
+	require.Equal(t, []byte("abcdef"), data)
+
+	// Test partial reads
+	data, err = rb.ReadLast(4)
+	require.NoError(t, err)
+	require.Equal(t, []byte("cdef"), data)
+
+	data, err = rb.ReadLast(2)
+	require.NoError(t, err)
+	require.Equal(t, []byte("ef"), data)
+}
+
+func TestRingBuffer_EdgeCaseEviction(t *testing.T) {
+	t.Parallel()
+
+	rb := newRingBuffer(3)
+
+	// Fill the buffer
+	rb.Write([]byte("abc"))
+
+	// Write more to cause eviction
+	rb.Write([]byte("d"))
+
+	// Should now contain "bcd"
+	data, err := rb.ReadLast(3)
+	require.NoError(t, err)
+	require.Equal(t, []byte("bcd"), data)
+}
+
+func TestRingBuffer_ComplexWrapAroundScenario(t *testing.T) {
+	t.Parallel()
+
+	rb := newRingBuffer(8)
+
+	// Fill buffer
+	rb.Write([]byte("12345678"))
+
+	// Evict some and add more to create complex wrap scenario
+	rb.Write([]byte("abcd"))
+	data, err := rb.ReadLast(8)
+	require.NoError(t, err)
+	require.Equal(t, []byte("5678abcd"), data)
+
+	// Add more
+	rb.Write([]byte("xyz"))
+	data, err = rb.ReadLast(8)
+	require.NoError(t, err)
+	require.Equal(t, []byte("8abcdxyz"), data)
+
+	// Test reading various amounts from the end
+	data, err = rb.ReadLast(7)
+	require.NoError(t, err)
+	require.Equal(t, []byte("abcdxyz"), data)
+
+	data, err = rb.ReadLast(4)
+	require.NoError(t, err)
+	require.Equal(t, []byte("dxyz"), data)
+}
+
+// Benchmark tests for performance validation
+func BenchmarkRingBuffer_Write(b *testing.B) {
+	rb := newRingBuffer(64 * 1024 * 1024)   // 64MB for benchmarks
+	data := bytes.Repeat([]byte("x"), 1024) // 1KB writes
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		rb.Write(data)
+	}
+}
+
+func BenchmarkRingBuffer_ReadLast(b *testing.B) {
+	rb := newRingBuffer(64 * 1024 * 1024) // 64MB for benchmarks
+	// Fill buffer with test data
+	for i := 0; i < 64; i++ {
+		rb.Write(bytes.Repeat([]byte("x"), 1024))
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, err := rb.ReadLast((i % 100) + 1)
+		if err != nil {
+			b.Fatal(err)
+		}
+	}
+}
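Taken together, the writer-side replay path presumably reduces to something like the following sketch. This is an assumed reconstruction from the error strings asserted in the tests above; `replayFrom` is a hypothetical helper, not code from this diff:

	// replayFrom re-sends the bytes in [seq, sequenceNum) from the ring buffer
	// to w. BackedWriter.Reconnect would perform the equivalent validation and
	// copy under its own lock.
	func replayFrom(rb *ringBuffer, sequenceNum, seq uint64, w io.Writer) error {
		if seq > sequenceNum {
			return xerrors.Errorf("cannot replay from future sequence %d", seq)
		}
		data, err := rb.ReadLast(int(sequenceNum - seq)) // fails if the range was evicted
		if err != nil {
			return xerrors.Errorf("failed to read replay data: %w", err)
		}
		_, err = w.Write(data)
		return err
	}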