Skip to content

Commit 81577f1

Browse files
authored
feat: Add web terminal with reconnecting TTYs (coder#1186)
* feat: Add web terminal with reconnecting TTYs This adds a web terminal that can reconnect to resume sessions! No more disconnects, and no more bad buffering! * Add xstate service * Add the webpage for accessing a web terminal * Add terminal page tests * Use Ticker instead of Timer * Active Windows mode on Windows
1 parent 23e5636 commit 81577f1

28 files changed

+1448
-39
lines changed

.vscode/settings.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
{
22
"cSpell.words": [
3+
"circbuf",
34
"cliflag",
45
"cliui",
56
"coderd",
@@ -47,6 +48,7 @@
4748
"ptty",
4849
"ptytest",
4950
"retrier",
51+
"rpty",
5052
"sdkproto",
5153
"Signup",
5254
"stretchr",
@@ -60,8 +62,10 @@
6062
"unconvert",
6163
"Untar",
6264
"VMID",
65+
"weblinks",
6366
"webrtc",
6467
"xerrors",
68+
"xstate",
6569
"yamux"
6670
],
6771
"emeraldwalk.runonsave": {

agent/agent.go

Lines changed: 256 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"crypto/rand"
66
"crypto/rsa"
7+
"encoding/json"
78
"errors"
89
"fmt"
910
"io"
@@ -12,10 +13,14 @@ import (
1213
"os/exec"
1314
"os/user"
1415
"runtime"
16+
"strconv"
1517
"strings"
1618
"sync"
1719
"time"
1820

21+
"github.com/armon/circbuf"
22+
"github.com/google/uuid"
23+
1924
gsyslog "github.com/hashicorp/go-syslog"
2025
"go.uber.org/atomic"
2126

@@ -33,6 +38,11 @@ import (
3338
"golang.org/x/xerrors"
3439
)
3540

41+
type Options struct {
42+
ReconnectingPTYTimeout time.Duration
43+
Logger slog.Logger
44+
}
45+
3646
type Metadata struct {
3747
OwnerEmail string `json:"owner_email"`
3848
OwnerUsername string `json:"owner_username"`
@@ -42,13 +52,20 @@ type Metadata struct {
4252

4353
// Dialer is the function the agent is constructed with (see New). Given a
// context and a logger it returns the agent's Metadata and a peerbroker
// listener, or an error.
// NOTE(review): presumably the listener yields incoming peer connections
// that the agent then serves — confirm against the run loop.
type Dialer func(ctx context.Context, logger slog.Logger) (Metadata, *peerbroker.Listener, error)
4454

45-
func New(dialer Dialer, logger slog.Logger) io.Closer {
55+
func New(dialer Dialer, options *Options) io.Closer {
56+
if options == nil {
57+
options = &Options{}
58+
}
59+
if options.ReconnectingPTYTimeout == 0 {
60+
options.ReconnectingPTYTimeout = 5 * time.Minute
61+
}
4662
ctx, cancelFunc := context.WithCancel(context.Background())
4763
server := &agent{
48-
dialer: dialer,
49-
logger: logger,
50-
closeCancel: cancelFunc,
51-
closed: make(chan struct{}),
64+
dialer: dialer,
65+
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
66+
logger: options.Logger,
67+
closeCancel: cancelFunc,
68+
closed: make(chan struct{}),
5269
}
5370
server.init(ctx)
5471
return server
@@ -58,6 +75,9 @@ type agent struct {
5875
dialer Dialer
5976
logger slog.Logger
6077

78+
reconnectingPTYs sync.Map
79+
reconnectingPTYTimeout time.Duration
80+
6181
connCloseWait sync.WaitGroup
6282
closeCancel context.CancelFunc
6383
closeMutex sync.Mutex
@@ -196,6 +216,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
196216
switch channel.Protocol() {
197217
case "ssh":
198218
go a.sshServer.HandleConn(channel.NetConn())
219+
case "reconnecting-pty":
220+
go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn())
199221
default:
200222
a.logger.Warn(ctx, "unhandled protocol from channel",
201223
slog.F("protocol", channel.Protocol()),
@@ -282,22 +304,25 @@ func (a *agent) init(ctx context.Context) {
282304
go a.run(ctx)
283305
}
284306

285-
func (a *agent) handleSSHSession(session ssh.Session) error {
307+
// createCommand processes raw command input with OpenSSH-like behavior.
308+
// If the rawCommand provided is empty, it will default to the user's shell.
309+
// This injects environment variables specified by the user at launch too.
310+
func (a *agent) createCommand(ctx context.Context, rawCommand string, env []string) (*exec.Cmd, error) {
286311
currentUser, err := user.Current()
287312
if err != nil {
288-
return xerrors.Errorf("get current user: %w", err)
313+
return nil, xerrors.Errorf("get current user: %w", err)
289314
}
290315
username := currentUser.Username
291316

292317
shell, err := usershell.Get(username)
293318
if err != nil {
294-
return xerrors.Errorf("get user shell: %w", err)
319+
return nil, xerrors.Errorf("get user shell: %w", err)
295320
}
296321

297322
// gliderlabs/ssh returns a command slice of zero
298323
// when a shell is requested.
299-
command := session.RawCommand()
300-
if len(session.Command()) == 0 {
324+
command := rawCommand
325+
if len(command) == 0 {
301326
command = shell
302327
}
303328

@@ -307,11 +332,11 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
307332
if runtime.GOOS == "windows" {
308333
caller = "/c"
309334
}
310-
cmd := exec.CommandContext(session.Context(), shell, caller, command)
311-
cmd.Env = append(os.Environ(), session.Environ()...)
335+
cmd := exec.CommandContext(ctx, shell, caller, command)
336+
cmd.Env = append(os.Environ(), env...)
312337
executablePath, err := os.Executable()
313338
if err != nil {
314-
return xerrors.Errorf("getting os executable: %w", err)
339+
return nil, xerrors.Errorf("getting os executable: %w", err)
315340
}
316341
// Git on Windows resolves with UNIX-style paths.
317342
// If using backslashes, it's unable to find the executable.
@@ -332,6 +357,14 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
332357
}
333358
}
334359
}
360+
return cmd, nil
361+
}
362+
363+
func (a *agent) handleSSHSession(session ssh.Session) error {
364+
cmd, err := a.createCommand(session.Context(), session.RawCommand(), session.Environ())
365+
if err != nil {
366+
return err
367+
}
335368

336369
sshPty, windowSize, isPty := session.Pty()
337370
if isPty {
@@ -381,6 +414,194 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
381414
return cmd.Wait()
382415
}
383416

417+
func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn net.Conn) {
418+
defer conn.Close()
419+
420+
// The ID format is referenced in conn.go.
421+
// <uuid>:<height>:<width>
422+
idParts := strings.Split(rawID, ":")
423+
if len(idParts) != 3 {
424+
a.logger.Warn(ctx, "client sent invalid id format", slog.F("raw-id", rawID))
425+
return
426+
}
427+
id := idParts[0]
428+
// Enforce a consistent format for IDs.
429+
_, err := uuid.Parse(id)
430+
if err != nil {
431+
a.logger.Warn(ctx, "client sent reconnection token that isn't a uuid", slog.F("id", id), slog.Error(err))
432+
return
433+
}
434+
// Parse the initial terminal dimensions.
435+
height, err := strconv.Atoi(idParts[1])
436+
if err != nil {
437+
a.logger.Warn(ctx, "client sent invalid height", slog.F("id", id), slog.F("height", idParts[1]))
438+
return
439+
}
440+
width, err := strconv.Atoi(idParts[2])
441+
if err != nil {
442+
a.logger.Warn(ctx, "client sent invalid width", slog.F("id", id), slog.F("width", idParts[2]))
443+
return
444+
}
445+
446+
var rpty *reconnectingPTY
447+
rawRPTY, ok := a.reconnectingPTYs.Load(id)
448+
if ok {
449+
rpty, ok = rawRPTY.(*reconnectingPTY)
450+
if !ok {
451+
a.logger.Warn(ctx, "found invalid type in reconnecting pty map", slog.F("id", id))
452+
}
453+
} else {
454+
// Empty command will default to the users shell!
455+
cmd, err := a.createCommand(ctx, "", nil)
456+
if err != nil {
457+
a.logger.Warn(ctx, "create reconnecting pty command", slog.Error(err))
458+
return
459+
}
460+
cmd.Env = append(cmd.Env, "TERM=xterm-256color")
461+
462+
ptty, process, err := pty.Start(cmd)
463+
if err != nil {
464+
a.logger.Warn(ctx, "start reconnecting pty command", slog.F("id", id))
465+
}
466+
467+
// Default to buffer 64KB.
468+
circularBuffer, err := circbuf.NewBuffer(64 * 1024)
469+
if err != nil {
470+
a.logger.Warn(ctx, "create circular buffer", slog.Error(err))
471+
return
472+
}
473+
474+
a.closeMutex.Lock()
475+
a.connCloseWait.Add(1)
476+
a.closeMutex.Unlock()
477+
ctx, cancelFunc := context.WithCancel(ctx)
478+
rpty = &reconnectingPTY{
479+
activeConns: make(map[string]net.Conn),
480+
ptty: ptty,
481+
// Timeouts created with an after func can be reset!
482+
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc),
483+
circularBuffer: circularBuffer,
484+
}
485+
a.reconnectingPTYs.Store(id, rpty)
486+
go func() {
487+
// CommandContext isn't respected for Windows PTYs right now,
488+
// so we need to manually track the lifecycle.
489+
// When the context has been completed either:
490+
// 1. The timeout completed.
491+
// 2. The parent context was canceled.
492+
<-ctx.Done()
493+
_ = process.Kill()
494+
}()
495+
go func() {
496+
// If the process dies randomly, we should
497+
// close the pty.
498+
_, _ = process.Wait()
499+
rpty.Close()
500+
}()
501+
go func() {
502+
buffer := make([]byte, 1024)
503+
for {
504+
read, err := rpty.ptty.Output().Read(buffer)
505+
if err != nil {
506+
// When the PTY is closed, this is triggered.
507+
break
508+
}
509+
part := buffer[:read]
510+
_, err = rpty.circularBuffer.Write(part)
511+
if err != nil {
512+
a.logger.Error(ctx, "reconnecting pty write buffer", slog.Error(err), slog.F("id", id))
513+
break
514+
}
515+
rpty.activeConnsMutex.Lock()
516+
for _, conn := range rpty.activeConns {
517+
_, _ = conn.Write(part)
518+
}
519+
rpty.activeConnsMutex.Unlock()
520+
}
521+
522+
// Cleanup the process, PTY, and delete it's
523+
// ID from memory.
524+
_ = process.Kill()
525+
rpty.Close()
526+
a.reconnectingPTYs.Delete(id)
527+
a.connCloseWait.Done()
528+
}()
529+
}
530+
// Resize the PTY to initial height + width.
531+
err = rpty.ptty.Resize(uint16(height), uint16(width))
532+
if err != nil {
533+
// We can continue after this, it's not fatal!
534+
a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", id), slog.Error(err))
535+
}
536+
// Write any previously stored data for the TTY.
537+
_, err = conn.Write(rpty.circularBuffer.Bytes())
538+
if err != nil {
539+
a.logger.Warn(ctx, "write reconnecting pty buffer", slog.F("id", id), slog.Error(err))
540+
return
541+
}
542+
connectionID := uuid.NewString()
543+
// Multiple connections to the same TTY are permitted.
544+
// This could easily be used for terminal sharing, but
545+
// we do it because it's a nice user experience to
546+
// copy/paste a terminal URL and have it _just work_.
547+
rpty.activeConnsMutex.Lock()
548+
rpty.activeConns[connectionID] = conn
549+
rpty.activeConnsMutex.Unlock()
550+
// Resetting this timeout prevents the PTY from exiting.
551+
rpty.timeout.Reset(a.reconnectingPTYTimeout)
552+
553+
ctx, cancelFunc := context.WithCancel(ctx)
554+
defer cancelFunc()
555+
heartbeat := time.NewTicker(a.reconnectingPTYTimeout / 2)
556+
defer heartbeat.Stop()
557+
go func() {
558+
// Keep updating the activity while this
559+
// connection is alive!
560+
for {
561+
select {
562+
case <-ctx.Done():
563+
return
564+
case <-heartbeat.C:
565+
}
566+
rpty.timeout.Reset(a.reconnectingPTYTimeout)
567+
}
568+
}()
569+
defer func() {
570+
// After this connection ends, remove it from
571+
// the PTYs active connections. If it isn't
572+
// removed, all PTY data will be sent to it.
573+
rpty.activeConnsMutex.Lock()
574+
delete(rpty.activeConns, connectionID)
575+
rpty.activeConnsMutex.Unlock()
576+
}()
577+
decoder := json.NewDecoder(conn)
578+
var req ReconnectingPTYRequest
579+
for {
580+
err = decoder.Decode(&req)
581+
if xerrors.Is(err, io.EOF) {
582+
return
583+
}
584+
if err != nil {
585+
a.logger.Warn(ctx, "reconnecting pty buffer read error", slog.F("id", id), slog.Error(err))
586+
return
587+
}
588+
_, err = rpty.ptty.Input().Write([]byte(req.Data))
589+
if err != nil {
590+
a.logger.Warn(ctx, "write to reconnecting pty", slog.F("id", id), slog.Error(err))
591+
return
592+
}
593+
// Check if a resize needs to happen!
594+
if req.Height == 0 || req.Width == 0 {
595+
continue
596+
}
597+
err = rpty.ptty.Resize(req.Height, req.Width)
598+
if err != nil {
599+
// We can continue after this, it's not fatal!
600+
a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", id), slog.Error(err))
601+
}
602+
}
603+
}
604+
384605
// isClosed returns whether the API is closed or not.
385606
func (a *agent) isClosed() bool {
386607
select {
@@ -403,3 +624,25 @@ func (a *agent) Close() error {
403624
a.connCloseWait.Wait()
404625
return nil
405626
}
627+
628+
type reconnectingPTY struct {
629+
activeConnsMutex sync.Mutex
630+
activeConns map[string]net.Conn
631+
632+
circularBuffer *circbuf.Buffer
633+
timeout *time.Timer
634+
ptty pty.PTY
635+
}
636+
637+
// Close ends all connections to the reconnecting
638+
// PTY and clear the circular buffer.
639+
func (r *reconnectingPTY) Close() {
640+
r.activeConnsMutex.Lock()
641+
defer r.activeConnsMutex.Unlock()
642+
for _, conn := range r.activeConns {
643+
_ = conn.Close()
644+
}
645+
_ = r.ptty.Close()
646+
r.circularBuffer.Reset()
647+
r.timeout.Stop()
648+
}

0 commit comments

Comments
 (0)