Skip to content

feat: Add web terminal with reconnecting TTYs #1186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"cSpell.words": [
"circbuf",
"cliflag",
"cliui",
"coderd",
Expand Down Expand Up @@ -47,6 +48,7 @@
"ptty",
"ptytest",
"retrier",
"rpty",
"sdkproto",
"Signup",
"stretchr",
Expand All @@ -60,8 +62,10 @@
"unconvert",
"Untar",
"VMID",
"weblinks",
"webrtc",
"xerrors",
"xstate",
"yamux"
],
"emeraldwalk.runonsave": {
Expand Down
269 changes: 256 additions & 13 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"errors"
"fmt"
"io"
Expand All @@ -12,10 +13,14 @@ import (
"os/exec"
"os/user"
"runtime"
"strconv"
"strings"
"sync"
"time"

"github.com/armon/circbuf"
"github.com/google/uuid"

gsyslog "github.com/hashicorp/go-syslog"
"go.uber.org/atomic"

Expand All @@ -33,6 +38,11 @@ import (
"golang.org/x/xerrors"
)

type Options struct {
ReconnectingPTYTimeout time.Duration
Logger slog.Logger
}

type Metadata struct {
OwnerEmail string `json:"owner_email"`
OwnerUsername string `json:"owner_username"`
Expand All @@ -42,13 +52,20 @@ type Metadata struct {

type Dialer func(ctx context.Context, logger slog.Logger) (Metadata, *peerbroker.Listener, error)

func New(dialer Dialer, logger slog.Logger) io.Closer {
func New(dialer Dialer, options *Options) io.Closer {
if options == nil {
options = &Options{}
}
if options.ReconnectingPTYTimeout == 0 {
options.ReconnectingPTYTimeout = 5 * time.Minute
}
ctx, cancelFunc := context.WithCancel(context.Background())
server := &agent{
dialer: dialer,
logger: logger,
closeCancel: cancelFunc,
closed: make(chan struct{}),
dialer: dialer,
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
logger: options.Logger,
closeCancel: cancelFunc,
closed: make(chan struct{}),
}
server.init(ctx)
return server
Expand All @@ -58,6 +75,9 @@ type agent struct {
dialer Dialer
logger slog.Logger

reconnectingPTYs sync.Map
reconnectingPTYTimeout time.Duration

connCloseWait sync.WaitGroup
closeCancel context.CancelFunc
closeMutex sync.Mutex
Expand Down Expand Up @@ -196,6 +216,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
switch channel.Protocol() {
case "ssh":
go a.sshServer.HandleConn(channel.NetConn())
case "reconnecting-pty":
go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn())
default:
a.logger.Warn(ctx, "unhandled protocol from channel",
slog.F("protocol", channel.Protocol()),
Expand Down Expand Up @@ -282,22 +304,25 @@ func (a *agent) init(ctx context.Context) {
go a.run(ctx)
}

func (a *agent) handleSSHSession(session ssh.Session) error {
// createCommand processes raw command input with OpenSSH-like behavior.
// If the rawCommand provided is empty, it will default to the users shell.
// This injects environment variables specified by the user at launch too.
func (a *agent) createCommand(ctx context.Context, rawCommand string, env []string) (*exec.Cmd, error) {
currentUser, err := user.Current()
if err != nil {
return xerrors.Errorf("get current user: %w", err)
return nil, xerrors.Errorf("get current user: %w", err)
}
username := currentUser.Username

shell, err := usershell.Get(username)
if err != nil {
return xerrors.Errorf("get user shell: %w", err)
return nil, xerrors.Errorf("get user shell: %w", err)
}

// gliderlabs/ssh returns a command slice of zero
// when a shell is requested.
command := session.RawCommand()
if len(session.Command()) == 0 {
command := rawCommand
if len(command) == 0 {
command = shell
}

Expand All @@ -307,11 +332,11 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
if runtime.GOOS == "windows" {
caller = "/c"
}
cmd := exec.CommandContext(session.Context(), shell, caller, command)
cmd.Env = append(os.Environ(), session.Environ()...)
cmd := exec.CommandContext(ctx, shell, caller, command)
cmd.Env = append(os.Environ(), env...)
executablePath, err := os.Executable()
if err != nil {
return xerrors.Errorf("getting os executable: %w", err)
return nil, xerrors.Errorf("getting os executable: %w", err)
}
// Git on Windows resolves with UNIX-style paths.
// If using backslashes, it's unable to find the executable.
Expand All @@ -332,6 +357,14 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
}
}
}
return cmd, nil
}

func (a *agent) handleSSHSession(session ssh.Session) error {
cmd, err := a.createCommand(session.Context(), session.RawCommand(), session.Environ())
if err != nil {
return err
}

sshPty, windowSize, isPty := session.Pty()
if isPty {
Expand Down Expand Up @@ -381,6 +414,194 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
return cmd.Wait()
}

func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn net.Conn) {
defer conn.Close()

// The ID format is referenced in conn.go.
// <uuid>:<height>:<width>
idParts := strings.Split(rawID, ":")
if len(idParts) != 3 {
a.logger.Warn(ctx, "client sent invalid id format", slog.F("raw-id", rawID))
return
}
id := idParts[0]
// Enforce a consistent format for IDs.
_, err := uuid.Parse(id)
if err != nil {
a.logger.Warn(ctx, "client sent reconnection token that isn't a uuid", slog.F("id", id), slog.Error(err))
return
}
// Parse the initial terminal dimensions.
height, err := strconv.Atoi(idParts[1])
if err != nil {
a.logger.Warn(ctx, "client sent invalid height", slog.F("id", id), slog.F("height", idParts[1]))
return
}
width, err := strconv.Atoi(idParts[2])
if err != nil {
a.logger.Warn(ctx, "client sent invalid width", slog.F("id", id), slog.F("width", idParts[2]))
return
}

var rpty *reconnectingPTY
rawRPTY, ok := a.reconnectingPTYs.Load(id)
if ok {
rpty, ok = rawRPTY.(*reconnectingPTY)
if !ok {
a.logger.Warn(ctx, "found invalid type in reconnecting pty map", slog.F("id", id))
}
} else {
// Empty command will default to the users shell!
cmd, err := a.createCommand(ctx, "", nil)
if err != nil {
a.logger.Warn(ctx, "create reconnecting pty command", slog.Error(err))
return
}
cmd.Env = append(cmd.Env, "TERM=xterm-256color")

ptty, process, err := pty.Start(cmd)
if err != nil {
a.logger.Warn(ctx, "start reconnecting pty command", slog.F("id", id))
}

// Default to buffer 64KB.
circularBuffer, err := circbuf.NewBuffer(64 * 1024)
if err != nil {
a.logger.Warn(ctx, "create circular buffer", slog.Error(err))
return
}

a.closeMutex.Lock()
a.connCloseWait.Add(1)
a.closeMutex.Unlock()
ctx, cancelFunc := context.WithCancel(ctx)
rpty = &reconnectingPTY{
activeConns: make(map[string]net.Conn),
ptty: ptty,
// Timeouts created with an after func can be reset!
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc),
circularBuffer: circularBuffer,
}
a.reconnectingPTYs.Store(id, rpty)
go func() {
// CommandContext isn't respected for Windows PTYs right now,
// so we need to manually track the lifecycle.
// When the context has been completed either:
// 1. The timeout completed.
// 2. The parent context was canceled.
<-ctx.Done()
_ = process.Kill()
}()
go func() {
// If the process dies randomly, we should
// close the pty.
_, _ = process.Wait()
rpty.Close()
}()
go func() {
buffer := make([]byte, 1024)
for {
read, err := rpty.ptty.Output().Read(buffer)
if err != nil {
// When the PTY is closed, this is triggered.
break
}
part := buffer[:read]
_, err = rpty.circularBuffer.Write(part)
if err != nil {
a.logger.Error(ctx, "reconnecting pty write buffer", slog.Error(err), slog.F("id", id))
break
}
rpty.activeConnsMutex.Lock()
for _, conn := range rpty.activeConns {
_, _ = conn.Write(part)
}
rpty.activeConnsMutex.Unlock()
}

// Cleanup the process, PTY, and delete it's
// ID from memory.
_ = process.Kill()
rpty.Close()
a.reconnectingPTYs.Delete(id)
a.connCloseWait.Done()
}()
}
// Resize the PTY to initial height + width.
err = rpty.ptty.Resize(uint16(height), uint16(width))
if err != nil {
// We can continue after this, it's not fatal!
a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", id), slog.Error(err))
}
// Write any previously stored data for the TTY.
_, err = conn.Write(rpty.circularBuffer.Bytes())
if err != nil {
a.logger.Warn(ctx, "write reconnecting pty buffer", slog.F("id", id), slog.Error(err))
return
}
connectionID := uuid.NewString()
// Multiple connections to the same TTY are permitted.
// This could easily be used for terminal sharing, but
// we do it because it's a nice user experience to
// copy/paste a terminal URL and have it _just work_.
rpty.activeConnsMutex.Lock()
rpty.activeConns[connectionID] = conn
rpty.activeConnsMutex.Unlock()
// Resetting this timeout prevents the PTY from exiting.
rpty.timeout.Reset(a.reconnectingPTYTimeout)

ctx, cancelFunc := context.WithCancel(ctx)
defer cancelFunc()
heartbeat := time.NewTicker(a.reconnectingPTYTimeout / 2)
defer heartbeat.Stop()
go func() {
// Keep updating the activity while this
// connection is alive!
for {
select {
case <-ctx.Done():
return
case <-heartbeat.C:
}
rpty.timeout.Reset(a.reconnectingPTYTimeout)
}
}()
defer func() {
// After this connection ends, remove it from
// the PTYs active connections. If it isn't
// removed, all PTY data will be sent to it.
rpty.activeConnsMutex.Lock()
delete(rpty.activeConns, connectionID)
rpty.activeConnsMutex.Unlock()
}()
decoder := json.NewDecoder(conn)
var req ReconnectingPTYRequest
for {
err = decoder.Decode(&req)
if xerrors.Is(err, io.EOF) {
return
}
if err != nil {
a.logger.Warn(ctx, "reconnecting pty buffer read error", slog.F("id", id), slog.Error(err))
return
}
_, err = rpty.ptty.Input().Write([]byte(req.Data))
if err != nil {
a.logger.Warn(ctx, "write to reconnecting pty", slog.F("id", id), slog.Error(err))
return
}
// Check if a resize needs to happen!
if req.Height == 0 || req.Width == 0 {
continue
}
err = rpty.ptty.Resize(req.Height, req.Width)
if err != nil {
// We can continue after this, it's not fatal!
a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", id), slog.Error(err))
}
}
}

// isClosed returns whether the API is closed or not.
func (a *agent) isClosed() bool {
select {
Expand All @@ -403,3 +624,25 @@ func (a *agent) Close() error {
a.connCloseWait.Wait()
return nil
}

type reconnectingPTY struct {
activeConnsMutex sync.Mutex
activeConns map[string]net.Conn

circularBuffer *circbuf.Buffer
timeout *time.Timer
ptty pty.PTY
}

// Close ends all connections to the reconnecting
// PTY and clear the circular buffer.
func (r *reconnectingPTY) Close() {
r.activeConnsMutex.Lock()
defer r.activeConnsMutex.Unlock()
for _, conn := range r.activeConns {
_ = conn.Close()
}
_ = r.ptty.Close()
r.circularBuffer.Reset()
r.timeout.Stop()
}
Loading