diff --git a/agent/agent.go b/agent/agent.go
index 31b4b8959f8df..b25c6217e3c94 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -30,6 +30,7 @@ import (
 	"github.com/spf13/afero"
 	"go.uber.org/atomic"
 	gossh "golang.org/x/crypto/ssh"
+	"golang.org/x/exp/slices"
 	"golang.org/x/xerrors"
 	"tailscale.com/net/speedtest"
 	"tailscale.com/tailcfg"
@@ -90,7 +91,7 @@ func New(options Options) io.Closer {
 		}
 	}
 	ctx, cancelFunc := context.WithCancel(context.Background())
-	server := &agent{
+	a := &agent{
 		reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
 		logger:                 options.Logger,
 		closeCancel:            cancelFunc,
@@ -101,8 +102,8 @@ func New(options Options) io.Closer {
 		filesystem:             options.Filesystem,
 		tempDir:                options.TempDir,
 	}
-	server.init(ctx)
-	return server
+	a.init(ctx)
+	return a
 }
 
 type agent struct {
@@ -300,10 +301,12 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
 		}
 	}()
 	if err = a.trackConnGoroutine(func() {
+		logger := a.logger.Named("reconnecting-pty")
+
 		for {
 			conn, err := reconnectingPTYListener.Accept()
 			if err != nil {
-				a.logger.Debug(ctx, "accept pty failed", slog.Error(err))
+				logger.Debug(ctx, "accept pty failed", slog.Error(err))
 				return
 			}
 			// This cannot use a JSON decoder, since that can
@@ -324,7 +327,9 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
 			if err != nil {
 				continue
 			}
-			go a.handleReconnectingPTY(ctx, msg, conn)
+			go func() {
+				_ = a.handleReconnectingPTY(ctx, logger, msg, conn)
+			}()
 		}
 	}); err != nil {
 		return nil, err
@@ -798,38 +803,56 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
 	return cmd.Wait()
 }
 
-func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.ReconnectingPTYInit, conn net.Conn) {
+func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg codersdk.ReconnectingPTYInit, conn net.Conn) (retErr error) {
 	defer conn.Close()
 
 	connectionID := uuid.NewString()
+	logger = logger.With(slog.F("id", msg.ID), slog.F("connection_id", connectionID))
+
+	defer func() {
+		if err := retErr; err != nil {
+			a.closeMutex.Lock()
+			closed := a.isClosed()
+			a.closeMutex.Unlock()
+
+			// If the agent is closed, we don't want to
+			// log this as an error since it's expected.
+			if closed {
+				logger.Debug(ctx, "session error after agent close", slog.Error(err))
+			} else {
+				logger.Error(ctx, "session error", slog.Error(err))
+			}
+		}
+		logger.Debug(ctx, "session closed")
+	}()
+
 	var rpty *reconnectingPTY
 	rawRPTY, ok := a.reconnectingPTYs.Load(msg.ID)
 	if ok {
+		logger.Debug(ctx, "connecting to existing session")
 		rpty, ok = rawRPTY.(*reconnectingPTY)
 		if !ok {
-			a.logger.Error(ctx, "found invalid type in reconnecting pty map", slog.F("id", msg.ID))
-			return
+			return xerrors.Errorf("found invalid type in reconnecting pty map: %T", rawRPTY)
 		}
 	} else {
+		logger.Debug(ctx, "creating new session")
+
 		// Empty command will default to the users shell!
 		cmd, err := a.createCommand(ctx, msg.Command, nil)
 		if err != nil {
-			a.logger.Error(ctx, "create reconnecting pty command", slog.Error(err))
-			return
+			return xerrors.Errorf("create command: %w", err)
 		}
 		cmd.Env = append(cmd.Env, "TERM=xterm-256color")
 
 		// Default to buffer 64KiB.
 		circularBuffer, err := circbuf.NewBuffer(64 << 10)
 		if err != nil {
-			a.logger.Error(ctx, "create circular buffer", slog.Error(err))
-			return
+			return xerrors.Errorf("create circular buffer: %w", err)
 		}
 
 		ptty, process, err := pty.Start(cmd)
 		if err != nil {
-			a.logger.Error(ctx, "start reconnecting pty command", slog.F("id", msg.ID), slog.Error(err))
-			return
+			return xerrors.Errorf("start command: %w", err)
 		}
 
 		ctx, cancelFunc := context.WithCancel(ctx)
@@ -873,7 +896,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
 				_, err = rpty.circularBuffer.Write(part)
 				rpty.circularBufferMutex.Unlock()
 				if err != nil {
-					a.logger.Error(ctx, "reconnecting pty write buffer", slog.Error(err), slog.F("id", msg.ID))
+					logger.Error(ctx, "write to circular buffer", slog.Error(err))
 					break
 				}
 				rpty.activeConnsMutex.Lock()
@@ -889,23 +912,27 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
 			rpty.Close()
 			a.reconnectingPTYs.Delete(msg.ID)
 		}); err != nil {
-			a.logger.Error(ctx, "start reconnecting pty routine", slog.F("id", msg.ID), slog.Error(err))
-			return
+			return xerrors.Errorf("start routine: %w", err)
 		}
 	}
 	// Resize the PTY to initial height + width.
 	err := rpty.ptty.Resize(msg.Height, msg.Width)
 	if err != nil {
 		// We can continue after this, it's not fatal!
-		a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", msg.ID), slog.Error(err))
+		logger.Error(ctx, "resize", slog.Error(err))
 	}
 	// Write any previously stored data for the TTY.
 	rpty.circularBufferMutex.RLock()
-	_, err = conn.Write(rpty.circularBuffer.Bytes())
+	prevBuf := slices.Clone(rpty.circularBuffer.Bytes())
 	rpty.circularBufferMutex.RUnlock()
+
+	// Note that there is a small race here between writing buffered
+	// data and storing conn in activeConns. This is likely a very minor
+	// edge case, but we should look into ways to avoid it. Holding
+	// activeConnsMutex would be one option, but holding this mutex
+	// while also holding circularBufferMutex seems dangerous.
+	_, err = conn.Write(prevBuf)
 	if err != nil {
-		a.logger.Warn(ctx, "write reconnecting pty buffer", slog.F("id", msg.ID), slog.Error(err))
-		return
+		return xerrors.Errorf("write buffer to conn: %w", err)
 	}
 	// Multiple connections to the same TTY are permitted.
 	// This could easily be used for terminal sharing, but
@@ -946,16 +973,16 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
 	for {
 		err = decoder.Decode(&req)
 		if xerrors.Is(err, io.EOF) {
-			return
+			return nil
 		}
 		if err != nil {
-			a.logger.Warn(ctx, "reconnecting pty buffer read error", slog.F("id", msg.ID), slog.Error(err))
-			return
+			logger.Warn(ctx, "read conn", slog.Error(err))
+			return nil
 		}
 		_, err = rpty.ptty.Input().Write([]byte(req.Data))
 		if err != nil {
-			a.logger.Warn(ctx, "write to reconnecting pty", slog.F("id", msg.ID), slog.Error(err))
-			return
+			logger.Warn(ctx, "write to pty", slog.Error(err))
+			return nil
 		}
 		// Check if a resize needs to happen!
 		if req.Height == 0 || req.Width == 0 {
@@ -964,7 +991,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
 		err = rpty.ptty.Resize(req.Height, req.Width)
 		if err != nil {
 			// We can continue after this, it's not fatal!
-			a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", msg.ID), slog.Error(err))
+			logger.Error(ctx, "resize", slog.Error(err))
 		}
 	}
 }
diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go
index e489a759c0255..ea272c2a409e2 100644
--- a/coderd/coderdtest/coderdtest.go
+++ b/coderd/coderdtest/coderdtest.go
@@ -552,13 +552,17 @@ func UpdateTemplateVersion(t *testing.T, client *codersdk.Client, organizationID
 func AwaitTemplateVersionJob(t *testing.T, client *codersdk.Client, version uuid.UUID) codersdk.TemplateVersion {
 	t.Helper()
 
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
+	defer cancel()
+
 	t.Logf("waiting for template version job %s", version)
 	var templateVersion codersdk.TemplateVersion
 	require.Eventually(t, func() bool {
 		var err error
-		templateVersion, err = client.TemplateVersion(context.Background(), version)
+		templateVersion, err = client.TemplateVersion(ctx, version)
 		return assert.NoError(t, err) && templateVersion.Job.CompletedAt != nil
 	}, testutil.WaitMedium, testutil.IntervalFast)
+	t.Logf("got template version job %s", version)
 	return templateVersion
 }
 
@@ -566,13 +570,17 @@ func AwaitTemplateVersionJob(t *testing.T, client *codersdk.Client, version uuid
 func AwaitWorkspaceBuildJob(t *testing.T, client *codersdk.Client, build uuid.UUID) codersdk.WorkspaceBuild {
 	t.Helper()
 
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
+	defer cancel()
+
 	t.Logf("waiting for workspace build job %s", build)
 	var workspaceBuild codersdk.WorkspaceBuild
 	require.Eventually(t, func() bool {
 		var err error
-		workspaceBuild, err = client.WorkspaceBuild(context.Background(), build)
+		workspaceBuild, err = client.WorkspaceBuild(ctx, build)
 		return assert.NoError(t, err) && workspaceBuild.Job.CompletedAt != nil
 	}, testutil.WaitShort, testutil.IntervalFast)
+	t.Logf("got workspace build job %s", build)
 	return workspaceBuild
 }
 
@@ -580,11 +588,14 @@ func AwaitWorkspaceBuildJob(t *testing.T, client *codersdk.Client, build uuid.UU
 func AwaitWorkspaceAgents(t *testing.T, client *codersdk.Client, workspaceID uuid.UUID) []codersdk.WorkspaceResource {
 	t.Helper()
 
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+	defer cancel()
+
 	t.Logf("waiting for workspace agents (workspace %s)", workspaceID)
 	var resources []codersdk.WorkspaceResource
 	require.Eventually(t, func() bool {
 		var err error
-		workspace, err := client.Workspace(context.Background(), workspaceID)
+		workspace, err := client.Workspace(ctx, workspaceID)
 		if !assert.NoError(t, err) {
 			return false
 		}
@@ -604,6 +615,7 @@ func AwaitWorkspaceAgents(t *testing.T, client *codersdk.Client, workspaceID uui
 
 		return true
 	}, testutil.WaitLong, testutil.IntervalFast)
+	t.Logf("got workspace agents (workspace %s)", workspaceID)
 	return resources
 }
 
diff --git a/loadtest/reconnectingpty/run_test.go b/loadtest/reconnectingpty/run_test.go
index f9c99864600e0..6d78f72e25823 100644
--- a/loadtest/reconnectingpty/run_test.go
+++ b/loadtest/reconnectingpty/run_test.go
@@ -22,7 +22,6 @@ import (
 
 func Test_Runner(t *testing.T) {
 	t.Parallel()
-	t.Skip("See: https://github.com/coder/coder/issues/5247")
 
 	t.Run("OK", func(t *testing.T) {
 		t.Parallel()