diff --git a/agent/agent.go b/agent/agent.go
index 6d27802d1291a..2d5b9a663202e 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -3,12 +3,10 @@ package agent
 import (
 	"bytes"
 	"context"
-	"encoding/binary"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
-	"net"
 	"net/http"
 	"net/netip"
 	"os"
@@ -216,8 +214,8 @@ type agent struct {
 	portCacheDuration time.Duration
 	subsystems        []codersdk.AgentSubsystem
 
-	reconnectingPTYs       sync.Map
 	reconnectingPTYTimeout time.Duration
+	reconnectingPTYServer  *reconnectingpty.Server
 
 	// we track 2 contexts and associated cancel functions: "graceful" which is Done when it is time
 	// to start gracefully shutting down and "hard" which is Done when it is time to close
@@ -252,8 +250,6 @@ type agent struct {
 	statsReporter *statsReporter
 	logSender     *agentsdk.LogSender
 
-	connCountReconnectingPTY atomic.Int64
-
 	prometheusRegistry *prometheus.Registry
 	// metrics are prometheus registered metrics that will be collected and
 	// labeled in Coder with the agent + workspace.
@@ -297,6 +293,13 @@ func (a *agent) init() {
 	// Register runner metrics. If the prom registry is nil, the metrics
 	// will not report anywhere.
 	a.scriptRunner.RegisterMetrics(a.prometheusRegistry)
+
+	a.reconnectingPTYServer = reconnectingpty.NewServer(
+		a.logger.Named("reconnecting-pty"),
+		a.sshServer,
+		a.metrics.connectionsTotal, a.metrics.reconnectingPTYErrors,
+		a.reconnectingPTYTimeout,
+	)
 	go a.runLoop()
 }
 
@@ -1181,55 +1184,12 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
 		}
 	}()
 	if err = a.trackGoroutine(func() {
-		logger := a.logger.Named("reconnecting-pty")
-		var wg sync.WaitGroup
-		for {
-			conn, err := reconnectingPTYListener.Accept()
-			if err != nil {
-				if !a.isClosed() {
-					logger.Debug(ctx, "accept pty failed", slog.Error(err))
-				}
-				break
-			}
-			clog := logger.With(
-				slog.F("remote", conn.RemoteAddr().String()),
-				slog.F("local", conn.LocalAddr().String()))
-			clog.Info(ctx, "accepted conn")
-			wg.Add(1)
-			closed := make(chan struct{})
-			go func() {
-				select {
-				case <-closed:
-				case <-a.hardCtx.Done():
-					_ = conn.Close()
-				}
-				wg.Done()
-			}()
-			go func() {
-				defer close(closed)
-				// This cannot use a JSON decoder, since that can
-				// buffer additional data that is required for the PTY.
-				rawLen := make([]byte, 2)
-				_, err = conn.Read(rawLen)
-				if err != nil {
-					return
-				}
-				length := binary.LittleEndian.Uint16(rawLen)
-				data := make([]byte, length)
-				_, err = conn.Read(data)
-				if err != nil {
-					return
-				}
-				var msg workspacesdk.AgentReconnectingPTYInit
-				err = json.Unmarshal(data, &msg)
-				if err != nil {
-					logger.Warn(ctx, "failed to unmarshal init", slog.F("raw", data))
-					return
-				}
-				_ = a.handleReconnectingPTY(ctx, clog, msg, conn)
-			}()
+		rPTYServeErr := a.reconnectingPTYServer.Serve(a.gracefulCtx, a.hardCtx, reconnectingPTYListener)
+		if rPTYServeErr != nil &&
+			a.gracefulCtx.Err() == nil &&
+			!strings.Contains(rPTYServeErr.Error(), "use of closed network connection") {
+			a.logger.Error(ctx, "error serving reconnecting PTY", slog.Error(rPTYServeErr))
 		}
-		wg.Wait()
 	}); err != nil {
 		return nil, err
 	}
@@ -1308,9 +1268,9 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
 			_ = server.Close()
 		}()
 
-		err := server.Serve(apiListener)
-		if err != nil && !xerrors.Is(err, http.ErrServerClosed) && !strings.Contains(err.Error(), "use of closed network connection") {
-			a.logger.Critical(ctx, "serve HTTP API server", slog.Error(err))
+		apiServErr := server.Serve(apiListener)
+		if apiServErr != nil && !xerrors.Is(apiServErr, http.ErrServerClosed) && !strings.Contains(apiServErr.Error(), "use of closed network connection") {
+			a.logger.Critical(ctx, "serve HTTP API server", slog.Error(apiServErr))
 		}
 	}); err != nil {
 		return nil, err
@@ -1394,87 +1354,6 @@ func (a *agent) runDERPMapSubscriber(ctx context.Context, tClient tailnetproto.D
 	}
 }
 
-func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg workspacesdk.AgentReconnectingPTYInit, conn net.Conn) (retErr error) {
-	defer conn.Close()
-	a.metrics.connectionsTotal.Add(1)
-
-	a.connCountReconnectingPTY.Add(1)
-	defer a.connCountReconnectingPTY.Add(-1)
-
-	connectionID := uuid.NewString()
-	connLogger := logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID))
-	connLogger.Debug(ctx, "starting handler")
-
-	defer func() {
-		if err := retErr; err != nil {
-			a.closeMutex.Lock()
-			closed := a.isClosed()
-			a.closeMutex.Unlock()
-
-			// If the agent is closed, we don't want to
-			// log this as an error since it's expected.
-			if closed {
-				connLogger.Info(ctx, "reconnecting pty failed with attach error (agent closed)", slog.Error(err))
-			} else {
-				connLogger.Error(ctx, "reconnecting pty failed with attach error", slog.Error(err))
-			}
-		}
-		connLogger.Info(ctx, "reconnecting pty connection closed")
-	}()
-
-	var rpty reconnectingpty.ReconnectingPTY
-	sendConnected := make(chan reconnectingpty.ReconnectingPTY, 1)
-	// On store, reserve this ID to prevent multiple concurrent new connections.
-	waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
-	if ok {
-		close(sendConnected) // Unused.
-		connLogger.Debug(ctx, "connecting to existing reconnecting pty")
-		c, ok := waitReady.(chan reconnectingpty.ReconnectingPTY)
-		if !ok {
-			return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
-		}
-		rpty, ok = <-c
-		if !ok || rpty == nil {
-			return xerrors.Errorf("reconnecting pty closed before connection")
-		}
-		c <- rpty // Put it back for the next reconnect.
-	} else {
-		connLogger.Debug(ctx, "creating new reconnecting pty")
-
-		connected := false
-		defer func() {
-			if !connected && retErr != nil {
-				a.reconnectingPTYs.Delete(msg.ID)
-				close(sendConnected)
-			}
-		}()
-
-		// Empty command will default to the users shell!
-		cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
-		if err != nil {
-			a.metrics.reconnectingPTYErrors.WithLabelValues("create_command").Add(1)
-			return xerrors.Errorf("create command: %w", err)
-		}
-
-		rpty = reconnectingpty.New(ctx, cmd, &reconnectingpty.Options{
-			Timeout: a.reconnectingPTYTimeout,
-			Metrics: a.metrics.reconnectingPTYErrors,
-		}, logger.With(slog.F("message_id", msg.ID)))
-
-		if err = a.trackGoroutine(func() {
-			rpty.Wait()
-			a.reconnectingPTYs.Delete(msg.ID)
-		}); err != nil {
-			rpty.Close(err)
-			return xerrors.Errorf("start routine: %w", err)
-		}
-
-		connected = true
-		sendConnected <- rpty
-	}
-	return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger)
-}
-
 // Collect collects additional stats from the agent
 func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats {
 	a.logger.Debug(context.Background(), "computing stats report")
@@ -1496,7 +1375,7 @@ func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connect
 	stats.SessionCountVscode = sshStats.VSCode
 	stats.SessionCountJetbrains = sshStats.JetBrains
 
-	stats.SessionCountReconnectingPty = a.connCountReconnectingPTY.Load()
+	stats.SessionCountReconnectingPty = a.reconnectingPTYServer.ConnCount()
 
 	// Compute the median connection latency!
 	a.logger.Debug(ctx, "starting peer latency measurement for stats")
diff --git a/agent/reconnectingpty/server.go b/agent/reconnectingpty/server.go
new file mode 100644
index 0000000000000..052a88e52b0b4
--- /dev/null
+++ b/agent/reconnectingpty/server.go
@@ -0,0 +1,191 @@
+package reconnectingpty
+
+import (
+	"context"
+	"encoding/binary"
+	"encoding/json"
+	"net"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/prometheus/client_golang/prometheus"
+	"golang.org/x/xerrors"
+
+	"cdr.dev/slog"
+	"github.com/coder/coder/v2/agent/agentssh"
+	"github.com/coder/coder/v2/codersdk/workspacesdk"
+)
+
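+// Server accepts reconnecting PTY connections on a listener and attaches them to
+// running PTYs, keyed by the ID in the init message; a PTY is created on first use.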
+type Server struct {
+	logger           slog.Logger
+	connectionsTotal prometheus.Counter
+	errorsTotal      *prometheus.CounterVec
+	commandCreator   *agentssh.Server
+	connCount        atomic.Int64
+	reconnectingPTYs sync.Map
+	timeout          time.Duration
+}
+
+// NewServer returns a new ReconnectingPTY server.
+func NewServer(logger slog.Logger, commandCreator *agentssh.Server,
+	connectionsTotal prometheus.Counter, errorsTotal *prometheus.CounterVec,
+	timeout time.Duration,
+) *Server {
+	return &Server{
+		logger:           logger,
+		commandCreator:   commandCreator,
+		connectionsTotal: connectionsTotal,
+		errorsTotal:      errorsTotal,
+		timeout:          timeout,
+	}
+}
+
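+// Serve accepts and handles connections on l until ctx is done or Accept fails; any
+// Accept error is returned. Connections still open when hardCtx is done are closed.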
+func (s *Server) Serve(ctx, hardCtx context.Context, l net.Listener) (retErr error) {
+	var wg sync.WaitGroup
+	for {
+		if ctx.Err() != nil {
+			break
+		}
+		conn, err := l.Accept()
+		if err != nil {
+			s.logger.Debug(ctx, "accept pty failed", slog.Error(err))
+			retErr = err
+			break
+		}
+		clog := s.logger.With(
+			slog.F("remote", conn.RemoteAddr().String()),
+			slog.F("local", conn.LocalAddr().String()))
+		clog.Info(ctx, "accepted conn")
+		wg.Add(1)
+		closed := make(chan struct{})
+		go func() {
+			select {
+			case <-closed:
+			case <-hardCtx.Done():
+				_ = conn.Close()
+			}
+			wg.Done()
+		}()
+		wg.Add(1)
+		go func() {
+			defer close(closed)
+			defer wg.Done()
+			_ = s.handleConn(ctx, clog, conn)
+		}()
+	}
+	wg.Wait()
+	return retErr
+}
+
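+// ConnCount returns the number of reconnecting PTY connections currently being served.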
+func (s *Server) ConnCount() int64 {
+	return s.connCount.Load()
+}
+
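+// handleConn reads the AgentReconnectingPTYInit handshake from conn, then attaches the
+// connection to the reconnecting PTY identified by the message ID, creating it first if
+// it does not already exist.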
+func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Conn) (retErr error) {
+	defer conn.Close()
+	s.connectionsTotal.Add(1)
+	s.connCount.Add(1)
+	defer s.connCount.Add(-1)
+
+	// This cannot use a JSON decoder, since that can
+	// buffer additional data that is required for the PTY.
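+	//
+	// The wire format is a 2-byte little-endian length followed by that many bytes of
+	// JSON-encoded workspacesdk.AgentReconnectingPTYInit. A client-side sketch of the
+	// handshake (illustrative only, not the actual workspacesdk client code):
+	//
+	//	data, _ := json.Marshal(msg) // msg is a workspacesdk.AgentReconnectingPTYInit
+	//	_ = binary.Write(conn, binary.LittleEndian, uint16(len(data)))
+	//	_, _ = conn.Write(data)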
+	rawLen := make([]byte, 2)
+	_, err := conn.Read(rawLen)
+	if err != nil {
+		// logging at info since a single incident isn't too worrying (the client could just have
+		// hung up), but if we get a lot of these we'd want to investigate.
+		logger.Info(ctx, "failed to read AgentReconnectingPTYInit length", slog.Error(err))
+		return nil
+	}
+	length := binary.LittleEndian.Uint16(rawLen)
+	data := make([]byte, length)
+	_, err = conn.Read(data)
+	if err != nil {
+		// logging at info since a single incident isn't too worrying (the client could just have
+		// hung up), but if we get a lot of these we'd want to investigate.
+		logger.Info(ctx, "failed to read AgentReconnectingPTYInit", slog.Error(err))
+		return nil
+	}
+	var msg workspacesdk.AgentReconnectingPTYInit
+	err = json.Unmarshal(data, &msg)
+	if err != nil {
+		logger.Warn(ctx, "failed to unmarshal init", slog.F("raw", data))
+		return nil
+	}
+
+	connectionID := uuid.NewString()
+	connLogger := logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID))
+	connLogger.Debug(ctx, "starting handler")
+
+	defer func() {
+		if err := retErr; err != nil {
+			// If the context is done, we don't want to log this as an error since it's expected.
+			if ctx.Err() != nil {
+				connLogger.Info(ctx, "reconnecting pty failed with attach error (agent closed)", slog.Error(err))
+			} else {
+				connLogger.Error(ctx, "reconnecting pty failed with attach error", slog.Error(err))
+			}
+		}
+		connLogger.Info(ctx, "reconnecting pty connection closed")
+	}()
+
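+	// Each message ID maps to at most one backing PTY. The buffered channel stored in
+	// the map hands the PTY from the connection that creates it to later (or
+	// concurrent) connections that reconnect with the same ID.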
+	var rpty ReconnectingPTY
+	sendConnected := make(chan ReconnectingPTY, 1)
+	// On store, reserve this ID to prevent multiple concurrent new connections.
+	waitReady, ok := s.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
+	if ok {
+		close(sendConnected) // Unused.
+		connLogger.Debug(ctx, "connecting to existing reconnecting pty")
+		c, ok := waitReady.(chan ReconnectingPTY)
+		if !ok {
+			return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
+		}
+		rpty, ok = <-c
+		if !ok || rpty == nil {
+			return xerrors.Errorf("reconnecting pty closed before connection")
+		}
+		c <- rpty // Put it back for the next reconnect.
+	} else {
+		connLogger.Debug(ctx, "creating new reconnecting pty")
+
+		connected := false
+		defer func() {
+			if !connected && retErr != nil {
+				s.reconnectingPTYs.Delete(msg.ID)
+				close(sendConnected)
+			}
+		}()
+
+		// Empty command will default to the user's shell!
+		cmd, err := s.commandCreator.CreateCommand(ctx, msg.Command, nil)
+		if err != nil {
+			s.errorsTotal.WithLabelValues("create_command").Add(1)
+			return xerrors.Errorf("create command: %w", err)
+		}
+
+		rpty = New(ctx, cmd, &Options{
+			Timeout: s.timeout,
+			Metrics: s.errorsTotal,
+		}, logger.With(slog.F("message_id", msg.ID)))
+
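+		// Tie the rpty's lifetime to ctx: if ctx ends before the rpty exits on its
+		// own, close it. The wait goroutine below closes done once the rpty is gone,
+		// stopping this monitor.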
+		done := make(chan struct{})
+		go func() {
+			select {
+			case <-done:
+			case <-ctx.Done():
+				rpty.Close(ctx.Err())
+			}
+		}()
+
+		go func() {
+			defer close(done)
+			rpty.Wait()
+			s.reconnectingPTYs.Delete(msg.ID)
+		}()
+
+		connected = true
+		sendConnected <- rpty
+	}
+	return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger)
+}