Skip to content

Commit 47e0e98

Browse files
committed
feat: Add web terminal with reconnecting TTYs
This adds a web terminal that can reconnect to resume sessions! No more disconnects, and no more bad bufferring!
1 parent 603b7da commit 47e0e98

File tree

14 files changed

+535
-23
lines changed

14 files changed

+535
-23
lines changed

.vscode/settings.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
"ptty",
4848
"ptytest",
4949
"retrier",
50+
"rpty",
5051
"sdkproto",
5152
"Signup",
5253
"stretchr",

agent/agent.go

Lines changed: 199 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"crypto/rand"
66
"crypto/rsa"
7+
"encoding/json"
78
"errors"
89
"fmt"
910
"io"
@@ -12,10 +13,14 @@ import (
1213
"os/exec"
1314
"os/user"
1415
"runtime"
16+
"strconv"
1517
"strings"
1618
"sync"
1719
"time"
1820

21+
"github.com/google/uuid"
22+
"github.com/smallnest/ringbuffer"
23+
1924
gsyslog "github.com/hashicorp/go-syslog"
2025
"go.uber.org/atomic"
2126

@@ -33,6 +38,11 @@ import (
3338
"golang.org/x/xerrors"
3439
)
3540

41+
type Options struct {
42+
ReconnectingPTYTimeout time.Duration
43+
Logger slog.Logger
44+
}
45+
3646
type Metadata struct {
3747
OwnerEmail string `json:"owner_email"`
3848
OwnerUsername string `json:"owner_username"`
@@ -42,13 +52,20 @@ type Metadata struct {
4252

4353
type Dialer func(ctx context.Context, logger slog.Logger) (Metadata, *peerbroker.Listener, error)
4454

45-
func New(dialer Dialer, logger slog.Logger) io.Closer {
55+
func New(dialer Dialer, options *Options) io.Closer {
56+
if options == nil {
57+
options = &Options{}
58+
}
59+
if options.ReconnectingPTYTimeout == 0 {
60+
options.ReconnectingPTYTimeout = 5 * time.Minute
61+
}
4662
ctx, cancelFunc := context.WithCancel(context.Background())
4763
server := &agent{
48-
dialer: dialer,
49-
logger: logger,
50-
closeCancel: cancelFunc,
51-
closed: make(chan struct{}),
64+
dialer: dialer,
65+
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
66+
logger: options.Logger,
67+
closeCancel: cancelFunc,
68+
closed: make(chan struct{}),
5269
}
5370
server.init(ctx)
5471
return server
@@ -58,6 +75,9 @@ type agent struct {
5875
dialer Dialer
5976
logger slog.Logger
6077

78+
reconnectingPTYs sync.Map
79+
reconnectingPTYTimeout time.Duration
80+
6181
connCloseWait sync.WaitGroup
6282
closeCancel context.CancelFunc
6383
closeMutex sync.Mutex
@@ -196,6 +216,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
196216
switch channel.Protocol() {
197217
case "ssh":
198218
go a.sshServer.HandleConn(channel.NetConn())
219+
case "reconnecting-pty":
220+
go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn())
199221
default:
200222
a.logger.Warn(ctx, "unhandled protocol from channel",
201223
slog.F("protocol", channel.Protocol()),
@@ -282,22 +304,25 @@ func (a *agent) init(ctx context.Context) {
282304
go a.run(ctx)
283305
}
284306

285-
func (a *agent) handleSSHSession(session ssh.Session) error {
307+
// createCommand processes raw command input with OpenSSH-like behavior.
308+
// If the rawCommand provided is empty, it will default to the users shell.
309+
// This injects environment variables specified by the user at launch too.
310+
func (a *agent) createCommand(ctx context.Context, rawCommand string, env []string) (*exec.Cmd, error) {
286311
currentUser, err := user.Current()
287312
if err != nil {
288-
return xerrors.Errorf("get current user: %w", err)
313+
return nil, xerrors.Errorf("get current user: %w", err)
289314
}
290315
username := currentUser.Username
291316

292317
shell, err := usershell.Get(username)
293318
if err != nil {
294-
return xerrors.Errorf("get user shell: %w", err)
319+
return nil, xerrors.Errorf("get user shell: %w", err)
295320
}
296321

297322
// gliderlabs/ssh returns a command slice of zero
298323
// when a shell is requested.
299-
command := session.RawCommand()
300-
if len(session.Command()) == 0 {
324+
command := rawCommand
325+
if len(command) == 0 {
301326
command = shell
302327
}
303328

@@ -307,11 +332,11 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
307332
if runtime.GOOS == "windows" {
308333
caller = "/c"
309334
}
310-
cmd := exec.CommandContext(session.Context(), shell, caller, command)
311-
cmd.Env = append(os.Environ(), session.Environ()...)
335+
cmd := exec.CommandContext(ctx, shell, caller, command)
336+
cmd.Env = append(os.Environ(), env...)
312337
executablePath, err := os.Executable()
313338
if err != nil {
314-
return xerrors.Errorf("getting os executable: %w", err)
339+
return nil, xerrors.Errorf("getting os executable: %w", err)
315340
}
316341
// Git on Windows resolves with UNIX-style paths.
317342
// If using backslashes, it's unable to find the executable.
@@ -332,6 +357,14 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
332357
}
333358
}
334359
}
360+
return cmd, nil
361+
}
362+
363+
func (a *agent) handleSSHSession(session ssh.Session) error {
364+
cmd, err := a.createCommand(session.Context(), session.RawCommand(), session.Environ())
365+
if err != nil {
366+
return err
367+
}
335368

336369
sshPty, windowSize, isPty := session.Pty()
337370
if isPty {
@@ -381,6 +414,140 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
381414
return cmd.Wait()
382415
}
383416

417+
func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn net.Conn) {
418+
defer conn.Close()
419+
420+
idParts := strings.Split(rawID, ":")
421+
if len(idParts) != 3 {
422+
a.logger.Warn(ctx, "client sent invalid id format", slog.F("raw-id", rawID))
423+
return
424+
}
425+
id := idParts[0]
426+
// Enforce a consistent format for IDs.
427+
_, err := uuid.Parse(id)
428+
if err != nil {
429+
a.logger.Warn(ctx, "client sent reconnection token that isn't a uuid", slog.F("id", id), slog.Error(err))
430+
return
431+
}
432+
height, err := strconv.Atoi(idParts[1])
433+
if err != nil {
434+
a.logger.Warn(ctx, "client sent invalid height", slog.F("id", id), slog.F("height", idParts[1]))
435+
return
436+
}
437+
width, err := strconv.Atoi(idParts[2])
438+
if err != nil {
439+
a.logger.Warn(ctx, "client sent invalid width", slog.F("id", id), slog.F("width", idParts[2]))
440+
return
441+
}
442+
443+
var rpty *reconnectingPTY
444+
rawRPTY, ok := a.reconnectingPTYs.Load(id)
445+
if ok {
446+
rpty, ok = rawRPTY.(*reconnectingPTY)
447+
if !ok {
448+
a.logger.Warn(ctx, "found invalid type in reconnecting pty map", slog.F("id", id))
449+
}
450+
} else {
451+
// Empty command will default to the users shell!
452+
cmd, err := a.createCommand(ctx, "", nil)
453+
if err != nil {
454+
a.logger.Warn(ctx, "create reconnecting pty command", slog.Error(err))
455+
return
456+
}
457+
ptty, _, err := pty.Start(cmd)
458+
if err != nil {
459+
a.logger.Warn(ctx, "start reconnecting pty command", slog.F("id", id))
460+
}
461+
462+
rpty = &reconnectingPTY{
463+
activeConns: make(map[string]net.Conn),
464+
ptty: ptty,
465+
timeout: time.NewTimer(a.reconnectingPTYTimeout),
466+
// Default to buffer 1MB.
467+
ringBuffer: ringbuffer.New(1 << 20),
468+
}
469+
a.reconnectingPTYs.Store(id, rpty)
470+
go func() {
471+
// Close if the inactive timeout occurs, or the context ends.
472+
select {
473+
case <-rpty.timeout.C:
474+
a.logger.Info(ctx, "killing reconnecting pty due to inactivity", slog.F("id", id))
475+
case <-ctx.Done():
476+
}
477+
rpty.Close()
478+
}()
479+
go func() {
480+
buffer := make([]byte, 32*1024)
481+
for {
482+
read, err := rpty.ptty.Output().Read(buffer)
483+
if err != nil {
484+
rpty.Close()
485+
break
486+
}
487+
part := buffer[:read]
488+
_, err = rpty.ringBuffer.Write(part)
489+
if err != nil {
490+
a.logger.Error(ctx, "reconnecting pty write buffer", slog.Error(err), slog.F("id", id))
491+
return
492+
}
493+
rpty.activeConnsMutex.Lock()
494+
for _, conn := range rpty.activeConns {
495+
_, _ = conn.Write(part)
496+
}
497+
rpty.activeConnsMutex.Unlock()
498+
}
499+
// If we break from the loop, the reconnecting PTY ended.
500+
a.reconnectingPTYs.Delete(id)
501+
}()
502+
}
503+
err = rpty.ptty.Resize(uint16(height), uint16(width))
504+
if err != nil {
505+
// We can continue after this, it's not fatal!
506+
a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", id), slog.Error(err))
507+
}
508+
509+
_, err = conn.Write(rpty.ringBuffer.Bytes())
510+
if err != nil {
511+
a.logger.Warn(ctx, "write reconnecting pty buffer", slog.F("id", id), slog.Error(err))
512+
return
513+
}
514+
connectionID := uuid.NewString()
515+
rpty.activeConnsMutex.Lock()
516+
rpty.activeConns[connectionID] = conn
517+
rpty.activeConnsMutex.Unlock()
518+
defer func() {
519+
rpty.activeConnsMutex.Lock()
520+
delete(rpty.activeConns, connectionID)
521+
rpty.activeConnsMutex.Unlock()
522+
}()
523+
decoder := json.NewDecoder(conn)
524+
var req ReconnectingPTYRequest
525+
for {
526+
err = decoder.Decode(&req)
527+
if xerrors.Is(err, io.EOF) {
528+
return
529+
}
530+
if err != nil {
531+
a.logger.Warn(ctx, "reconnecting pty buffer read error", slog.F("id", id), slog.Error(err))
532+
return
533+
}
534+
_, err = rpty.ptty.Input().Write([]byte(req.Data))
535+
if err != nil {
536+
a.logger.Warn(ctx, "write to reconnecting pty", slog.F("id", id), slog.Error(err))
537+
return
538+
}
539+
// Check if a resize needs to happen!
540+
if req.Height == 0 || req.Width == 0 {
541+
continue
542+
}
543+
err = rpty.ptty.Resize(req.Height, req.Width)
544+
if err != nil {
545+
// We can continue after this, it's not fatal!
546+
a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", id), slog.Error(err))
547+
}
548+
}
549+
}
550+
384551
// isClosed returns whether the API is closed or not.
385552
func (a *agent) isClosed() bool {
386553
select {
@@ -403,3 +570,22 @@ func (a *agent) Close() error {
403570
a.connCloseWait.Wait()
404571
return nil
405572
}
573+
574+
type reconnectingPTY struct {
575+
activeConnsMutex sync.Mutex
576+
activeConns map[string]net.Conn
577+
578+
ringBuffer *ringbuffer.RingBuffer
579+
timeout *time.Timer
580+
ptty pty.PTY
581+
}
582+
583+
func (r *reconnectingPTY) Close() {
584+
r.activeConnsMutex.Lock()
585+
defer r.activeConnsMutex.Unlock()
586+
for _, conn := range r.activeConns {
587+
_ = conn.Close()
588+
}
589+
_ = r.ptty.Close()
590+
r.ringBuffer.Reset()
591+
}

agent/agent_test.go

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package agent_test
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"io"
78
"net"
@@ -14,6 +15,7 @@ import (
1415
"testing"
1516
"time"
1617

18+
"github.com/google/uuid"
1719
"github.com/pion/webrtc/v3"
1820
"github.com/pkg/sftp"
1921
"github.com/stretchr/testify/require"
@@ -188,6 +190,44 @@ func TestAgent(t *testing.T) {
188190
}, 15*time.Second, 100*time.Millisecond)
189191
require.Equal(t, content, strings.TrimSpace(gotContent))
190192
})
193+
194+
t.Run("ReconnectingPTY", func(t *testing.T) {
195+
t.Parallel()
196+
conn := setupAgent(t, agent.Metadata{})
197+
id := uuid.NewString()
198+
netConn, err := conn.ReconnectingPTY(id, 100, 100)
199+
require.NoError(t, err)
200+
201+
data, err := json.Marshal(agent.ReconnectingPTYRequest{
202+
Data: "echo test\r\n",
203+
})
204+
require.NoError(t, err)
205+
_, err = netConn.Write(data)
206+
require.NoError(t, err)
207+
208+
findEcho := func() {
209+
for {
210+
read, err := netConn.Read(data)
211+
require.NoError(t, err)
212+
if strings.Contains(string(data[:read]), "test") {
213+
break
214+
}
215+
}
216+
}
217+
218+
// Once for typing the command...
219+
findEcho()
220+
// And another time for the actual output.
221+
findEcho()
222+
223+
_ = netConn.Close()
224+
netConn, err = conn.ReconnectingPTY(id, 100, 100)
225+
require.NoError(t, err)
226+
227+
// Same output again!
228+
findEcho()
229+
findEcho()
230+
})
191231
}
192232

193233
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
@@ -227,12 +267,14 @@ func setupSSHSession(t *testing.T, options agent.Metadata) *ssh.Session {
227267
return session
228268
}
229269

230-
func setupAgent(t *testing.T, options agent.Metadata) *agent.Conn {
270+
func setupAgent(t *testing.T, metadata agent.Metadata) *agent.Conn {
231271
client, server := provisionersdk.TransportPipe()
232272
closer := agent.New(func(ctx context.Context, logger slog.Logger) (agent.Metadata, *peerbroker.Listener, error) {
233273
listener, err := peerbroker.Listen(server, nil)
234-
return options, listener, err
235-
}, slogtest.Make(t, nil).Leveled(slog.LevelDebug))
274+
return metadata, listener, err
275+
}, &agent.Options{
276+
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
277+
})
236278
t.Cleanup(func() {
237279
_ = client.Close()
238280
_ = server.Close()

0 commit comments

Comments
 (0)