Skip to content

feat: Add support for executing processes with Windows ConPty #311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Feb 17, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Start Windows support
  • Loading branch information
kylecarbs committed Feb 16, 2022
commit f8a733c7e2ef5a5c7bf4c07059e8fa4460ea9117
74 changes: 62 additions & 12 deletions agent/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ import (
"fmt"
"io"
"net"
"os/exec"
"sync"
"syscall"
"time"

"cdr.dev/slog"
"github.com/coder/coder/console/pty"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/retry"
Expand Down Expand Up @@ -50,33 +53,83 @@ type server struct {
}

func (s *server) init(ctx context.Context) {
forwardHandler := &ssh.ForwardedTCPHandler{}
key, err := rsa.GenerateKey(rand.Reader, 2048)
// Clients' should ignore the host key when connecting.
// The agent needs to authenticate with coderd to SSH,
// so SSH authentication doesn't improve security.
randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
panic(err)
}
signer, err := gossh.NewSignerFromKey(key)
randomSigner, err := gossh.NewSignerFromKey(randomHostKey)
if err != nil {
panic(err)
}
sshLogger := s.options.Logger.Named("ssh-server")
forwardHandler := &ssh.ForwardedTCPHandler{}
s.sshServer = &ssh.Server{
ChannelHandlers: ssh.DefaultChannelHandlers,
ConnectionFailedCallback: func(conn net.Conn, err error) {
fmt.Printf("Conn failed: %s\n", err)
sshLogger.Info(ctx, "ssh connection ended", slog.Error(err))
},
Handler: func(s ssh.Session) {
fmt.Printf("WE GOT %q %q\n", s.User(), s.RawCommand())
Handler: func(session ssh.Session) {
fmt.Printf("WE GOT %q %q\n", session.User(), session.RawCommand())

sshPty, windowSize, isPty := session.Pty()
if isPty {
cmd := exec.CommandContext(ctx, session.Command()[0], session.Command()[1:]...)
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
cmd.SysProcAttr = &syscall.SysProcAttr{
Setsid: true,
Setctty: true,
}
pty, err := pty.New()
if err != nil {
panic(err)
}
err = pty.Resize(uint16(sshPty.Window.Width), uint16(sshPty.Window.Height))
if err != nil {
panic(err)
}
cmd.Stdout = pty.OutPipe()
cmd.Stderr = pty.OutPipe()
cmd.Stdin = pty.InPipe()
err = cmd.Start()
if err != nil {
panic(err)
}
go func() {
for win := range windowSize {
err := pty.Resize(uint16(win.Width), uint16(win.Height))
if err != nil {
panic(err)
}
}
}()
go func() {
io.Copy(pty.Writer(), session)
}()
fmt.Printf("Got here!\n")
io.Copy(session, pty.Reader())
fmt.Printf("Done!\n")
cmd.Wait()
}
},
HostSigners: []ssh.Signer{signer},
HostSigners: []ssh.Signer{randomSigner},
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
// Allow local port forwarding all!
sshLogger.Debug(ctx, "local port forward",
slog.F("destination-host", destinationHost),
slog.F("destination-port", destinationPort))
return true
},
PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
return false
return true
},
ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
// Allow revere port forwarding all!
// Allow reverse port forwarding all!
sshLogger.Debug(ctx, "local port forward",
slog.F("bind-host", bindHost),
slog.F("bind-port", bindPort))
return true
},
RequestHandlers: map[string]ssh.RequestHandler{
Expand All @@ -91,9 +144,6 @@ func (s *server) init(ctx context.Context) {
// encrypted. If possible, we'd disable encryption entirely here.
Ciphers: []string{"arcfour"},
},
PublicKeyCallback: func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) {
return &gossh.Permissions{}, nil
},
NoClientAuth: true,
}
},
Expand Down
2 changes: 2 additions & 0 deletions agent/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ func TestAgent(t *testing.T) {
sshClient := ssh.NewClient(sshConn, channels, requests)
session, err := sshClient.NewSession()
require.NoError(t, err)
err = session.RequestPty("xterm-256color", 128, 128, ssh.TerminalModes{})
require.NoError(t, err)
session.Stdout = os.Stdout
session.Stderr = os.Stderr
err = session.Run("echo test")
Expand Down
4 changes: 4 additions & 0 deletions console/conpty/conpty.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ func (c *ConPty) Reader() io.Reader {
return c.outFileOurSide
}

func (c *ConPty) Writer() io.Writer {
return c.inFileOurSide
}

// InPipe returns input pipe of the pseudo terminal
// Note: It is safer to use the Write method to prevent partially-written VT sequences
// from corrupting the terminal
Expand Down
1 change: 1 addition & 0 deletions console/pty/pty.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type Pty interface {
Resize(cols uint16, rows uint16) error
WriteString(str string) (int, error)
Reader() io.Reader
Writer() io.Writer
Close() error
}

Expand Down
4 changes: 4 additions & 0 deletions console/pty/pty_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ func (p *unixPty) Reader() io.Reader {
return p.pty
}

func (p *unixPty) Writer() io.Writer {
return p.pty
}

func (p *unixPty) WriteString(str string) (int, error) {
return p.pty.WriteString(str)
}
Expand Down
4 changes: 4 additions & 0 deletions console/pty/pty_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ func (p *pipePtyVal) Reader() io.Reader {
return p.outFileOurSide
}

func (p *pipePtyVal) Writer() io.Writer {
return p.inFileOurSide
}

func (p *pipePtyVal) WriteString(str string) (int, error) {
return p.inFileOurSide.WriteString(str)
}
Expand Down
113 changes: 113 additions & 0 deletions wintest/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package main

import (
"context"
"os"
"testing"

"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"golang.org/x/sys/windows"
)

func main() {
state, err := MakeOutputRaw(os.Stdout.Fd())
if err != nil {
panic(err)
}
defer Restore(os.Stdout.Fd(), state)

t := &testing.T{}
ctx := context.Background()
client, server := provisionersdk.TransportPipe()
defer client.Close()
defer server.Close()
closer := agent.Server(func(ctx context.Context) (*peerbroker.Listener, error) {
return peerbroker.Listen(server, &peer.ConnOptions{
Logger: slogtest.Make(t, nil),
})
}, &agent.Options{
Logger: slogtest.Make(t, nil),
})
defer closer.Close()
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
stream, err := api.NegotiateConnection(ctx)
require.NoError(t, err)
conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{
Logger: slogtest.Make(t, nil),
})
require.NoError(t, err)
defer conn.Close()
channel, err := conn.Dial(ctx, "example", &peer.ChannelOptions{
Protocol: "ssh",
})
require.NoError(t, err)
sshConn, channels, requests, err := ssh.NewClientConn(channel.NetConn(), "localhost:22", &ssh.ClientConfig{
User: "kyle",
Config: ssh.Config{
Ciphers: []string{"arcfour"},
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
require.NoError(t, err)
sshClient := ssh.NewClient(sshConn, channels, requests)
session, err := sshClient.NewSession()
require.NoError(t, err)
err = session.RequestPty("xterm-256color", 128, 128, ssh.TerminalModes{
ssh.ECHO: 1,
})
require.NoError(t, err)
session.Stdin = os.Stdin
session.Stdout = os.Stdout
session.Stderr = os.Stderr
err = session.Run("bash")
require.NoError(t, err)
}

// State differs per-platform.
type State struct {
mode uint32
}

// makeRaw sets the terminal in raw mode and returns the previous state so it can be restored.
func makeRaw(handle windows.Handle, input bool) (uint32, error) {
var prevState uint32
if err := windows.GetConsoleMode(handle, &prevState); err != nil {
return 0, err
}

var raw uint32
if input {
raw = prevState &^ (windows.ENABLE_ECHO_INPUT | windows.ENABLE_PROCESSED_INPUT | windows.ENABLE_LINE_INPUT | windows.ENABLE_PROCESSED_OUTPUT)
raw |= windows.ENABLE_VIRTUAL_TERMINAL_INPUT
} else {
raw = prevState | windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING
}

if err := windows.SetConsoleMode(handle, raw); err != nil {
return 0, err
}
return prevState, nil
}

// MakeOutputRaw sets an output terminal to raw and enables VT100 processing.
func MakeOutputRaw(handle uintptr) (*State, error) {
prevState, err := makeRaw(windows.Handle(handle), false)
if err != nil {
return nil, err
}

return &State{mode: prevState}, nil
}

// Restore terminal back to original state.
func Restore(handle uintptr, state *State) error {
return windows.SetConsoleMode(windows.Handle(handle), state.mode)
}