Skip to content

Commit 2e983be

Browse files
committed
Merge branch 'main' into updatetf
2 parents 24f99e6 + 81577f1 commit 2e983be

39 files changed

+1611
-139
lines changed

.vscode/settings.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
{
22
"cSpell.words": [
3+
"circbuf",
34
"cliflag",
45
"cliui",
56
"coderd",
@@ -47,6 +48,7 @@
4748
"ptty",
4849
"ptytest",
4950
"retrier",
51+
"rpty",
5052
"sdkproto",
5153
"Signup",
5254
"stretchr",
@@ -60,8 +62,10 @@
6062
"unconvert",
6163
"Untar",
6264
"VMID",
65+
"weblinks",
6366
"webrtc",
6467
"xerrors",
68+
"xstate",
6569
"yamux"
6670
],
6771
"emeraldwalk.runonsave": {

agent/agent.go

Lines changed: 256 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"crypto/rand"
66
"crypto/rsa"
7+
"encoding/json"
78
"errors"
89
"fmt"
910
"io"
@@ -12,10 +13,14 @@ import (
1213
"os/exec"
1314
"os/user"
1415
"runtime"
16+
"strconv"
1517
"strings"
1618
"sync"
1719
"time"
1820

21+
"github.com/armon/circbuf"
22+
"github.com/google/uuid"
23+
1924
gsyslog "github.com/hashicorp/go-syslog"
2025
"go.uber.org/atomic"
2126

@@ -33,6 +38,11 @@ import (
3338
"golang.org/x/xerrors"
3439
)
3540

41+
type Options struct {
42+
ReconnectingPTYTimeout time.Duration
43+
Logger slog.Logger
44+
}
45+
3646
type Metadata struct {
3747
OwnerEmail string `json:"owner_email"`
3848
OwnerUsername string `json:"owner_username"`
@@ -43,13 +53,20 @@ type Metadata struct {
4353

4454
type Dialer func(ctx context.Context, logger slog.Logger) (Metadata, *peerbroker.Listener, error)
4555

46-
func New(dialer Dialer, logger slog.Logger) io.Closer {
56+
func New(dialer Dialer, options *Options) io.Closer {
57+
if options == nil {
58+
options = &Options{}
59+
}
60+
if options.ReconnectingPTYTimeout == 0 {
61+
options.ReconnectingPTYTimeout = 5 * time.Minute
62+
}
4763
ctx, cancelFunc := context.WithCancel(context.Background())
4864
server := &agent{
49-
dialer: dialer,
50-
logger: logger,
51-
closeCancel: cancelFunc,
52-
closed: make(chan struct{}),
65+
dialer: dialer,
66+
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
67+
logger: options.Logger,
68+
closeCancel: cancelFunc,
69+
closed: make(chan struct{}),
5370
}
5471
server.init(ctx)
5572
return server
@@ -59,6 +76,9 @@ type agent struct {
5976
dialer Dialer
6077
logger slog.Logger
6178

79+
reconnectingPTYs sync.Map
80+
reconnectingPTYTimeout time.Duration
81+
6282
connCloseWait sync.WaitGroup
6383
closeCancel context.CancelFunc
6484
closeMutex sync.Mutex
@@ -199,6 +219,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
199219
switch channel.Protocol() {
200220
case "ssh":
201221
go a.sshServer.HandleConn(channel.NetConn())
222+
case "reconnecting-pty":
223+
go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn())
202224
default:
203225
a.logger.Warn(ctx, "unhandled protocol from channel",
204226
slog.F("protocol", channel.Protocol()),
@@ -285,22 +307,25 @@ func (a *agent) init(ctx context.Context) {
285307
go a.run(ctx)
286308
}
287309

288-
func (a *agent) handleSSHSession(session ssh.Session) error {
310+
// createCommand processes raw command input with OpenSSH-like behavior.
311+
// If the rawCommand provided is empty, it will default to the users shell.
312+
// This injects environment variables specified by the user at launch too.
313+
func (a *agent) createCommand(ctx context.Context, rawCommand string, env []string) (*exec.Cmd, error) {
289314
currentUser, err := user.Current()
290315
if err != nil {
291-
return xerrors.Errorf("get current user: %w", err)
316+
return nil, xerrors.Errorf("get current user: %w", err)
292317
}
293318
username := currentUser.Username
294319

295320
shell, err := usershell.Get(username)
296321
if err != nil {
297-
return xerrors.Errorf("get user shell: %w", err)
322+
return nil, xerrors.Errorf("get user shell: %w", err)
298323
}
299324

300325
// gliderlabs/ssh returns a command slice of zero
301326
// when a shell is requested.
302-
command := session.RawCommand()
303-
if len(session.Command()) == 0 {
327+
command := rawCommand
328+
if len(command) == 0 {
304329
command = shell
305330
}
306331

@@ -310,16 +335,16 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
310335
if runtime.GOOS == "windows" {
311336
caller = "/c"
312337
}
313-
cmd := exec.CommandContext(session.Context(), shell, caller, command)
338+
cmd := exec.CommandContext(ctx, shell, caller, command)
314339
cmd.Dir = a.directory.Load()
315340
if cmd.Dir == "" {
316341
// Default to $HOME if a directory is not set!
317342
cmd.Dir = os.Getenv("HOME")
318343
}
319-
cmd.Env = append(os.Environ(), session.Environ()...)
344+
cmd.Env = append(os.Environ(), env...)
320345
executablePath, err := os.Executable()
321346
if err != nil {
322-
return xerrors.Errorf("getting os executable: %w", err)
347+
return nil, xerrors.Errorf("getting os executable: %w", err)
323348
}
324349
// Git on Windows resolves with UNIX-style paths.
325350
// If using backslashes, it's unable to find the executable.
@@ -340,6 +365,14 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
340365
}
341366
}
342367
}
368+
return cmd, nil
369+
}
370+
371+
func (a *agent) handleSSHSession(session ssh.Session) error {
372+
cmd, err := a.createCommand(session.Context(), session.RawCommand(), session.Environ())
373+
if err != nil {
374+
return err
375+
}
343376

344377
sshPty, windowSize, isPty := session.Pty()
345378
if isPty {
@@ -389,6 +422,194 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
389422
return cmd.Wait()
390423
}
391424

425+
func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn net.Conn) {
426+
defer conn.Close()
427+
428+
// The ID format is referenced in conn.go.
429+
// <uuid>:<height>:<width>
430+
idParts := strings.Split(rawID, ":")
431+
if len(idParts) != 3 {
432+
a.logger.Warn(ctx, "client sent invalid id format", slog.F("raw-id", rawID))
433+
return
434+
}
435+
id := idParts[0]
436+
// Enforce a consistent format for IDs.
437+
_, err := uuid.Parse(id)
438+
if err != nil {
439+
a.logger.Warn(ctx, "client sent reconnection token that isn't a uuid", slog.F("id", id), slog.Error(err))
440+
return
441+
}
442+
// Parse the initial terminal dimensions.
443+
height, err := strconv.Atoi(idParts[1])
444+
if err != nil {
445+
a.logger.Warn(ctx, "client sent invalid height", slog.F("id", id), slog.F("height", idParts[1]))
446+
return
447+
}
448+
width, err := strconv.Atoi(idParts[2])
449+
if err != nil {
450+
a.logger.Warn(ctx, "client sent invalid width", slog.F("id", id), slog.F("width", idParts[2]))
451+
return
452+
}
453+
454+
var rpty *reconnectingPTY
455+
rawRPTY, ok := a.reconnectingPTYs.Load(id)
456+
if ok {
457+
rpty, ok = rawRPTY.(*reconnectingPTY)
458+
if !ok {
459+
a.logger.Warn(ctx, "found invalid type in reconnecting pty map", slog.F("id", id))
460+
}
461+
} else {
462+
// Empty command will default to the users shell!
463+
cmd, err := a.createCommand(ctx, "", nil)
464+
if err != nil {
465+
a.logger.Warn(ctx, "create reconnecting pty command", slog.Error(err))
466+
return
467+
}
468+
cmd.Env = append(cmd.Env, "TERM=xterm-256color")
469+
470+
ptty, process, err := pty.Start(cmd)
471+
if err != nil {
472+
a.logger.Warn(ctx, "start reconnecting pty command", slog.F("id", id))
473+
}
474+
475+
// Default to buffer 64KB.
476+
circularBuffer, err := circbuf.NewBuffer(64 * 1024)
477+
if err != nil {
478+
a.logger.Warn(ctx, "create circular buffer", slog.Error(err))
479+
return
480+
}
481+
482+
a.closeMutex.Lock()
483+
a.connCloseWait.Add(1)
484+
a.closeMutex.Unlock()
485+
ctx, cancelFunc := context.WithCancel(ctx)
486+
rpty = &reconnectingPTY{
487+
activeConns: make(map[string]net.Conn),
488+
ptty: ptty,
489+
// Timeouts created with an after func can be reset!
490+
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc),
491+
circularBuffer: circularBuffer,
492+
}
493+
a.reconnectingPTYs.Store(id, rpty)
494+
go func() {
495+
// CommandContext isn't respected for Windows PTYs right now,
496+
// so we need to manually track the lifecycle.
497+
// When the context has been completed either:
498+
// 1. The timeout completed.
499+
// 2. The parent context was canceled.
500+
<-ctx.Done()
501+
_ = process.Kill()
502+
}()
503+
go func() {
504+
// If the process dies randomly, we should
505+
// close the pty.
506+
_, _ = process.Wait()
507+
rpty.Close()
508+
}()
509+
go func() {
510+
buffer := make([]byte, 1024)
511+
for {
512+
read, err := rpty.ptty.Output().Read(buffer)
513+
if err != nil {
514+
// When the PTY is closed, this is triggered.
515+
break
516+
}
517+
part := buffer[:read]
518+
_, err = rpty.circularBuffer.Write(part)
519+
if err != nil {
520+
a.logger.Error(ctx, "reconnecting pty write buffer", slog.Error(err), slog.F("id", id))
521+
break
522+
}
523+
rpty.activeConnsMutex.Lock()
524+
for _, conn := range rpty.activeConns {
525+
_, _ = conn.Write(part)
526+
}
527+
rpty.activeConnsMutex.Unlock()
528+
}
529+
530+
// Cleanup the process, PTY, and delete it's
531+
// ID from memory.
532+
_ = process.Kill()
533+
rpty.Close()
534+
a.reconnectingPTYs.Delete(id)
535+
a.connCloseWait.Done()
536+
}()
537+
}
538+
// Resize the PTY to initial height + width.
539+
err = rpty.ptty.Resize(uint16(height), uint16(width))
540+
if err != nil {
541+
// We can continue after this, it's not fatal!
542+
a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", id), slog.Error(err))
543+
}
544+
// Write any previously stored data for the TTY.
545+
_, err = conn.Write(rpty.circularBuffer.Bytes())
546+
if err != nil {
547+
a.logger.Warn(ctx, "write reconnecting pty buffer", slog.F("id", id), slog.Error(err))
548+
return
549+
}
550+
connectionID := uuid.NewString()
551+
// Multiple connections to the same TTY are permitted.
552+
// This could easily be used for terminal sharing, but
553+
// we do it because it's a nice user experience to
554+
// copy/paste a terminal URL and have it _just work_.
555+
rpty.activeConnsMutex.Lock()
556+
rpty.activeConns[connectionID] = conn
557+
rpty.activeConnsMutex.Unlock()
558+
// Resetting this timeout prevents the PTY from exiting.
559+
rpty.timeout.Reset(a.reconnectingPTYTimeout)
560+
561+
ctx, cancelFunc := context.WithCancel(ctx)
562+
defer cancelFunc()
563+
heartbeat := time.NewTicker(a.reconnectingPTYTimeout / 2)
564+
defer heartbeat.Stop()
565+
go func() {
566+
// Keep updating the activity while this
567+
// connection is alive!
568+
for {
569+
select {
570+
case <-ctx.Done():
571+
return
572+
case <-heartbeat.C:
573+
}
574+
rpty.timeout.Reset(a.reconnectingPTYTimeout)
575+
}
576+
}()
577+
defer func() {
578+
// After this connection ends, remove it from
579+
// the PTYs active connections. If it isn't
580+
// removed, all PTY data will be sent to it.
581+
rpty.activeConnsMutex.Lock()
582+
delete(rpty.activeConns, connectionID)
583+
rpty.activeConnsMutex.Unlock()
584+
}()
585+
decoder := json.NewDecoder(conn)
586+
var req ReconnectingPTYRequest
587+
for {
588+
err = decoder.Decode(&req)
589+
if xerrors.Is(err, io.EOF) {
590+
return
591+
}
592+
if err != nil {
593+
a.logger.Warn(ctx, "reconnecting pty buffer read error", slog.F("id", id), slog.Error(err))
594+
return
595+
}
596+
_, err = rpty.ptty.Input().Write([]byte(req.Data))
597+
if err != nil {
598+
a.logger.Warn(ctx, "write to reconnecting pty", slog.F("id", id), slog.Error(err))
599+
return
600+
}
601+
// Check if a resize needs to happen!
602+
if req.Height == 0 || req.Width == 0 {
603+
continue
604+
}
605+
err = rpty.ptty.Resize(req.Height, req.Width)
606+
if err != nil {
607+
// We can continue after this, it's not fatal!
608+
a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", id), slog.Error(err))
609+
}
610+
}
611+
}
612+
392613
// isClosed returns whether the API is closed or not.
393614
func (a *agent) isClosed() bool {
394615
select {
@@ -411,3 +632,25 @@ func (a *agent) Close() error {
411632
a.connCloseWait.Wait()
412633
return nil
413634
}
635+
636+
type reconnectingPTY struct {
637+
activeConnsMutex sync.Mutex
638+
activeConns map[string]net.Conn
639+
640+
circularBuffer *circbuf.Buffer
641+
timeout *time.Timer
642+
ptty pty.PTY
643+
}
644+
645+
// Close ends all connections to the reconnecting
646+
// PTY and clear the circular buffer.
647+
func (r *reconnectingPTY) Close() {
648+
r.activeConnsMutex.Lock()
649+
defer r.activeConnsMutex.Unlock()
650+
for _, conn := range r.activeConns {
651+
_ = conn.Close()
652+
}
653+
_ = r.ptty.Close()
654+
r.circularBuffer.Reset()
655+
r.timeout.Stop()
656+
}

0 commit comments

Comments
 (0)