diff --git a/agent/agent.go b/agent/agent.go index 165c73598939c..506de80f6ad54 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -1025,16 +1025,32 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m }() var rpty *reconnectingPTY - rawRPTY, ok := a.reconnectingPTYs.Load(msg.ID) + sendConnected := make(chan *reconnectingPTY, 1) + // On store, reserve this ID to prevent multiple concurrent new connections. + waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected) if ok { + close(sendConnected) // Unused. logger.Debug(ctx, "connecting to existing session") - rpty, ok = rawRPTY.(*reconnectingPTY) + c, ok := waitReady.(chan *reconnectingPTY) if !ok { - return xerrors.Errorf("found invalid type in reconnecting pty map: %T", rawRPTY) + return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady) } + rpty, ok = <-c + if !ok || rpty == nil { + return xerrors.Errorf("reconnecting pty closed before connection") + } + c <- rpty // Put it back for the next reconnect. } else { logger.Debug(ctx, "creating new session") + connected := false + defer func() { + if !connected && retErr != nil { + a.reconnectingPTYs.Delete(msg.ID) + close(sendConnected) + } + }() + // Empty command will default to the users shell! cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil) if err != nil { @@ -1055,7 +1071,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m return xerrors.Errorf("start command: %w", err) } - ctx, cancelFunc := context.WithCancel(ctx) + ctx, cancel := context.WithCancel(ctx) rpty = &reconnectingPTY{ activeConns: map[string]net.Conn{ // We have to put the connection in the map instantly otherwise @@ -1064,10 +1080,9 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m }, ptty: ptty, // Timeouts created with an after func can be reset! - timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc), + timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancel), circularBuffer: circularBuffer, } - a.reconnectingPTYs.Store(msg.ID, rpty) // We don't need to separately monitor for the process exiting. // When it exits, our ptty.OutputReader() will return EOF after // reading all process output. @@ -1115,8 +1130,12 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m rpty.Close() a.reconnectingPTYs.Delete(msg.ID) }); err != nil { + _ = process.Kill() + _ = ptty.Close() return xerrors.Errorf("start routine: %w", err) } + connected = true + sendConnected <- rpty } // Resize the PTY to initial height + width. err := rpty.ptty.Resize(msg.Height, msg.Width)