fix: fix panic while tearing down reconnecting PTY #15615

Merged 1 commit on Nov 22, 2024
agent/agent.go: 155 changes (17 additions & 138 deletions)
@@ -3,12 +3,10 @@ package agent
 import (
 	"bytes"
 	"context"
-	"encoding/binary"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
-	"net"
 	"net/http"
 	"net/netip"
 	"os"
@@ -216,8 +214,8 @@ type agent struct {
 	portCacheDuration time.Duration
 	subsystems        []codersdk.AgentSubsystem

-	reconnectingPTYs       sync.Map
 	reconnectingPTYTimeout time.Duration
+	reconnectingPTYServer  *reconnectingpty.Server

 	// we track 2 contexts and associated cancel functions: "graceful" which is Done when it is time
 	// to start gracefully shutting down and "hard" which is Done when it is time to close
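The comment on these two context fields describes the agent's two-phase shutdown: the "graceful" context is cancelled first so in-flight work can drain, and the "hard" context is cancelled later to force whatever remains closed. A minimal standalone sketch of that pattern (illustration only, not the agent's code):

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	gracefulCtx, gracefulCancel := context.WithCancel(context.Background())
	hardCtx, hardCancel := context.WithCancel(context.Background())

	done := make(chan struct{})
	go func() {
		defer close(done)
		<-gracefulCtx.Done() // phase 1: stop accepting new work, start draining
		select {
		case <-time.After(100 * time.Millisecond): // simulated drain
			fmt.Println("drained cleanly")
		case <-hardCtx.Done(): // phase 2: grace period expired, force close
			fmt.Println("forced close")
		}
	}()

	gracefulCancel() // begin graceful shutdown
	time.Sleep(10 * time.Millisecond)
	hardCancel() // hard deadline
	<-done
}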
@@ -252,8 +250,6 @@ type agent struct {
 	statsReporter *statsReporter
 	logSender     *agentsdk.LogSender

-	connCountReconnectingPTY atomic.Int64
-
 	prometheusRegistry *prometheus.Registry
 	// metrics are prometheus registered metrics that will be collected and
 	// labeled in Coder with the agent + workspace.
@@ -297,6 +293,13 @@ func (a *agent) init() {
 	// Register runner metrics. If the prom registry is nil, the metrics
 	// will not report anywhere.
 	a.scriptRunner.RegisterMetrics(a.prometheusRegistry)
+
+	a.reconnectingPTYServer = reconnectingpty.NewServer(
+		a.logger.Named("reconnecting-pty"),
+		a.sshServer,
+		a.metrics.connectionsTotal, a.metrics.reconnectingPTYErrors,
+		a.reconnectingPTYTimeout,
+	)
 	go a.runLoop()
 }
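This call site spells out everything the extracted server depends on: a named logger, the agent's SSH server (whose CreateCommand the old inline handler called), the connections counter, the per-cause error counter, and the PTY idle timeout. A rough sketch of the shape that implies; the real definition lives in the agent/reconnectingpty package, and the field names here are guesses:

import (
	"sync"
	"sync/atomic"
	"time"

	"github.com/prometheus/client_golang/prometheus"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/agent/agentssh"
)

// Hypothetical reconstruction from the NewServer call above.
type Server struct {
	logger           slog.Logger
	commandCreator   *agentssh.Server       // supplies CreateCommand, as a.sshServer did
	connectionsTotal prometheus.Counter     // bumped once per accepted connection
	errorsTotal      *prometheus.CounterVec // labeled by cause, e.g. "create_command"
	timeout          time.Duration          // formerly a.reconnectingPTYTimeout

	connCount        atomic.Int64 // replaces agent.connCountReconnectingPTY
	reconnectingPTYs sync.Map     // replaces agent.reconnectingPTYs
}

func NewServer(logger slog.Logger, commandCreator *agentssh.Server,
	connectionsTotal prometheus.Counter, errorsTotal *prometheus.CounterVec,
	timeout time.Duration,
) *Server {
	return &Server{
		logger:           logger,
		commandCreator:   commandCreator,
		connectionsTotal: connectionsTotal,
		errorsTotal:      errorsTotal,
		timeout:          timeout,
	}
}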

@@ -1181,55 +1184,12 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
 		}
 	}()
 	if err = a.trackGoroutine(func() {
-		logger := a.logger.Named("reconnecting-pty")
-		var wg sync.WaitGroup
-		for {
-			conn, err := reconnectingPTYListener.Accept()
-			if err != nil {
-				if !a.isClosed() {
-					logger.Debug(ctx, "accept pty failed", slog.Error(err))
-				}
-				break
-			}
-			clog := logger.With(
-				slog.F("remote", conn.RemoteAddr().String()),
-				slog.F("local", conn.LocalAddr().String()))
-			clog.Info(ctx, "accepted conn")
-			wg.Add(1)
-			closed := make(chan struct{})
-			go func() {
-				select {
-				case <-closed:
-				case <-a.hardCtx.Done():
-					_ = conn.Close()
-				}
-				wg.Done()
-			}()
-			go func() {
-				defer close(closed)
-				// This cannot use a JSON decoder, since that can
-				// buffer additional data that is required for the PTY.
-				rawLen := make([]byte, 2)
-				_, err = conn.Read(rawLen)
-				if err != nil {
-					return
-				}
-				length := binary.LittleEndian.Uint16(rawLen)
-				data := make([]byte, length)
-				_, err = conn.Read(data)
-				if err != nil {
-					return
-				}
-				var msg workspacesdk.AgentReconnectingPTYInit
-				err = json.Unmarshal(data, &msg)
-				if err != nil {
-					logger.Warn(ctx, "failed to unmarshal init", slog.F("raw", data))
-					return
-				}
-				_ = a.handleReconnectingPTY(ctx, clog, msg, conn)
-			}()
+		rPTYServeErr := a.reconnectingPTYServer.Serve(a.gracefulCtx, a.hardCtx, reconnectingPTYListener)
+		if rPTYServeErr != nil &&
+			a.gracefulCtx.Err() == nil &&
+			!strings.Contains(rPTYServeErr.Error(), "use of closed network connection") {
+			a.logger.Error(ctx, "error serving reconnecting PTY", slog.Error(rPTYServeErr))
 		}
-		wg.Wait()
 	}); err != nil {
 		return nil, err
 	}
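The accept loop that used to live inline here is now owned by reconnectingpty.Server.Serve, which can pair every accepted connection with its hard-shutdown watcher and wait for all handlers before returning, a plausible cure for the teardown panic in the PR title. A sketch of what Serve presumably does, reassembled from the deleted lines and extending the hypothetical Server type sketched above (the 2-byte little-endian length prefix followed by a JSON AgentReconnectingPTYInit is the old wire format verbatim; handleConn is an invented name, and io.ReadFull replaces the original bare conn.Read, which tolerated short reads):

// Sketch only; the real implementation is in the agent/reconnectingpty package.
func (s *Server) Serve(ctx, hardCtx context.Context, l net.Listener) error {
	var wg sync.WaitGroup
	for {
		conn, err := l.Accept()
		if err != nil {
			wg.Wait() // handlers finish before Serve returns, so teardown cannot race them
			return err
		}
		wg.Add(1)
		closed := make(chan struct{})
		go func() {
			// Force-close the connection on hard shutdown, as the inline loop did.
			select {
			case <-closed:
			case <-hardCtx.Done():
				_ = conn.Close()
			}
			wg.Done()
		}()
		go func() {
			defer close(closed)
			// The deleted comment still applies: a JSON decoder would buffer
			// bytes that belong to the PTY stream, so read an explicit
			// length-prefixed frame instead.
			rawLen := make([]byte, 2)
			if _, err := io.ReadFull(conn, rawLen); err != nil {
				return
			}
			data := make([]byte, binary.LittleEndian.Uint16(rawLen))
			if _, err := io.ReadFull(conn, data); err != nil {
				return
			}
			var msg workspacesdk.AgentReconnectingPTYInit
			if err := json.Unmarshal(data, &msg); err != nil {
				return
			}
			_ = s.handleConn(ctx, conn, msg) // invented per-connection handler
		}()
	}
}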
@@ -1308,9 +1268,9 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
 			_ = server.Close()
 		}()

-		err := server.Serve(apiListener)
-		if err != nil && !xerrors.Is(err, http.ErrServerClosed) && !strings.Contains(err.Error(), "use of closed network connection") {
-			a.logger.Critical(ctx, "serve HTTP API server", slog.Error(err))
+		apiServErr := server.Serve(apiListener)
+		if apiServErr != nil && !xerrors.Is(apiServErr, http.ErrServerClosed) && !strings.Contains(apiServErr.Error(), "use of closed network connection") {
+			a.logger.Critical(ctx, "serve HTTP API server", slog.Error(apiServErr))
 		}
 	}); err != nil {
 		return nil, err
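The renames (apiServErr here, rPTYServeErr above) are not cosmetic: inside these closures a bare err := silently shadows the err that the surrounding if err = a.trackGoroutine(...) pattern assigns, so it is easy to check or log the wrong variable. A minimal, self-contained illustration of the trap:

package main

import (
	"errors"
	"fmt"
)

func serve() error { return errors.New("listener closed") }

func main() {
	var err error
	func() {
		err := serve() // := declares a new err, shadowing the outer one
		_ = err
	}()
	// The outer err never saw the failure; distinct names such as
	// apiServErr make this mistake much harder to write.
	fmt.Println(err) // prints <nil>
}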
@@ -1394,87 +1354,6 @@ func (a *agent) runDERPMapSubscriber(ctx context.Context, tClient tailnetproto.D
 	}
 }

-func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg workspacesdk.AgentReconnectingPTYInit, conn net.Conn) (retErr error) {
-	defer conn.Close()
-	a.metrics.connectionsTotal.Add(1)
-
-	a.connCountReconnectingPTY.Add(1)
-	defer a.connCountReconnectingPTY.Add(-1)
-
-	connectionID := uuid.NewString()
-	connLogger := logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID))
-	connLogger.Debug(ctx, "starting handler")
-
-	defer func() {
-		if err := retErr; err != nil {
-			a.closeMutex.Lock()
-			closed := a.isClosed()
-			a.closeMutex.Unlock()
-
-			// If the agent is closed, we don't want to
-			// log this as an error since it's expected.
-			if closed {
-				connLogger.Info(ctx, "reconnecting pty failed with attach error (agent closed)", slog.Error(err))
-			} else {
-				connLogger.Error(ctx, "reconnecting pty failed with attach error", slog.Error(err))
-			}
-		}
-		connLogger.Info(ctx, "reconnecting pty connection closed")
-	}()
-
-	var rpty reconnectingpty.ReconnectingPTY
-	sendConnected := make(chan reconnectingpty.ReconnectingPTY, 1)
-	// On store, reserve this ID to prevent multiple concurrent new connections.
-	waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
-	if ok {
-		close(sendConnected) // Unused.
-		connLogger.Debug(ctx, "connecting to existing reconnecting pty")
-		c, ok := waitReady.(chan reconnectingpty.ReconnectingPTY)
-		if !ok {
-			return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
-		}
-		rpty, ok = <-c
-		if !ok || rpty == nil {
-			return xerrors.Errorf("reconnecting pty closed before connection")
-		}
-		c <- rpty // Put it back for the next reconnect.
-	} else {
-		connLogger.Debug(ctx, "creating new reconnecting pty")
-
-		connected := false
-		defer func() {
-			if !connected && retErr != nil {
-				a.reconnectingPTYs.Delete(msg.ID)
-				close(sendConnected)
-			}
-		}()
-
-		// Empty command will default to the users shell!
-		cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
-		if err != nil {
-			a.metrics.reconnectingPTYErrors.WithLabelValues("create_command").Add(1)
-			return xerrors.Errorf("create command: %w", err)
-		}
-
-		rpty = reconnectingpty.New(ctx, cmd, &reconnectingpty.Options{
-			Timeout: a.reconnectingPTYTimeout,
-			Metrics: a.metrics.reconnectingPTYErrors,
-		}, logger.With(slog.F("message_id", msg.ID)))
-
-		if err = a.trackGoroutine(func() {
-			rpty.Wait()
-			a.reconnectingPTYs.Delete(msg.ID)
-		}); err != nil {
-			rpty.Close(err)
-			return xerrors.Errorf("start routine: %w", err)
-		}
-
-		connected = true
-		sendConnected <- rpty
-	}
-	return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger)
-}

 // Collect collects additional stats from the agent
 func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats {
 	a.logger.Debug(context.Background(), "computing stats report")
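The deleted handler's most delicate piece is its reservation idiom: LoadOrStore atomically either claims the message ID (creator path) or hands back a channel to wait on (reconnect path), and every waiter puts the PTY back on the channel for the next one. Presumably the extracted server carries the same scheme forward. The idiom in isolation, runnable, with invented names:

package main

import (
	"fmt"
	"sync"
)

type resource struct{ id string }

// getOrCreate returns the resource for id, creating it at most once even
// under concurrent callers; it mirrors the deleted LoadOrStore logic.
func getOrCreate(m *sync.Map, id string, create func() (*resource, error)) (*resource, error) {
	ready := make(chan *resource, 1)
	existing, loaded := m.LoadOrStore(id, ready)
	if loaded {
		// Someone else holds the reservation: wait for them to publish.
		ch, ok := existing.(chan *resource)
		if !ok {
			return nil, fmt.Errorf("invalid type in map: %T", existing)
		}
		r, ok := <-ch
		if !ok || r == nil {
			return nil, fmt.Errorf("closed before connection")
		}
		ch <- r // put it back for the next waiter
		return r, nil
	}
	r, err := create()
	if err != nil {
		m.Delete(id) // roll back the reservation so a later caller may retry
		close(ready) // wake any waiters with a zero value
		return nil, err
	}
	ready <- r
	return r, nil
}

func main() {
	var m sync.Map
	r, err := getOrCreate(&m, "pty-1", func() (*resource, error) {
		return &resource{id: "pty-1"}, nil
	})
	fmt.Println(r.id, err)
}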
@@ -1496,7 +1375,7 @@ func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connect
 	stats.SessionCountVscode = sshStats.VSCode
 	stats.SessionCountJetbrains = sshStats.JetBrains

-	stats.SessionCountReconnectingPty = a.connCountReconnectingPTY.Load()
+	stats.SessionCountReconnectingPty = a.reconnectingPTYServer.ConnCount()

 	// Compute the median connection latency!
 	a.logger.Debug(ctx, "starting peer latency measurement for stats")
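ConnCount replaces the agent-level counter that backed this stat; presumably the server increments an internal atomic around each attach and exposes it roughly like this (continuing the hypothetical Server sketch above):

// Hypothetical accessor mirroring the removed connCountReconnectingPTY.
func (s *Server) ConnCount() int64 {
	return s.connCount.Load()
}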