Skip to content

Commit f8a733c

Browse files
committed
Start Windows support
1 parent af5e3c2 commit f8a733c

File tree

7 files changed

+190
-12
lines changed

7 files changed

+190
-12
lines changed

agent/server.go

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@ import (
88
"fmt"
99
"io"
1010
"net"
11+
"os/exec"
1112
"sync"
13+
"syscall"
1214
"time"
1315

1416
"cdr.dev/slog"
17+
"github.com/coder/coder/console/pty"
1518
"github.com/coder/coder/peer"
1619
"github.com/coder/coder/peerbroker"
1720
"github.com/coder/retry"
@@ -50,33 +53,83 @@ type server struct {
5053
}
5154

5255
func (s *server) init(ctx context.Context) {
53-
forwardHandler := &ssh.ForwardedTCPHandler{}
54-
key, err := rsa.GenerateKey(rand.Reader, 2048)
56+
// Clients' should ignore the host key when connecting.
57+
// The agent needs to authenticate with coderd to SSH,
58+
// so SSH authentication doesn't improve security.
59+
randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048)
5560
if err != nil {
5661
panic(err)
5762
}
58-
signer, err := gossh.NewSignerFromKey(key)
63+
randomSigner, err := gossh.NewSignerFromKey(randomHostKey)
5964
if err != nil {
6065
panic(err)
6166
}
67+
sshLogger := s.options.Logger.Named("ssh-server")
68+
forwardHandler := &ssh.ForwardedTCPHandler{}
6269
s.sshServer = &ssh.Server{
6370
ChannelHandlers: ssh.DefaultChannelHandlers,
6471
ConnectionFailedCallback: func(conn net.Conn, err error) {
65-
fmt.Printf("Conn failed: %s\n", err)
72+
sshLogger.Info(ctx, "ssh connection ended", slog.Error(err))
6673
},
67-
Handler: func(s ssh.Session) {
68-
fmt.Printf("WE GOT %q %q\n", s.User(), s.RawCommand())
74+
Handler: func(session ssh.Session) {
75+
fmt.Printf("WE GOT %q %q\n", session.User(), session.RawCommand())
76+
77+
sshPty, windowSize, isPty := session.Pty()
78+
if isPty {
79+
cmd := exec.CommandContext(ctx, session.Command()[0], session.Command()[1:]...)
80+
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
81+
cmd.SysProcAttr = &syscall.SysProcAttr{
82+
Setsid: true,
83+
Setctty: true,
84+
}
85+
pty, err := pty.New()
86+
if err != nil {
87+
panic(err)
88+
}
89+
err = pty.Resize(uint16(sshPty.Window.Width), uint16(sshPty.Window.Height))
90+
if err != nil {
91+
panic(err)
92+
}
93+
cmd.Stdout = pty.OutPipe()
94+
cmd.Stderr = pty.OutPipe()
95+
cmd.Stdin = pty.InPipe()
96+
err = cmd.Start()
97+
if err != nil {
98+
panic(err)
99+
}
100+
go func() {
101+
for win := range windowSize {
102+
err := pty.Resize(uint16(win.Width), uint16(win.Height))
103+
if err != nil {
104+
panic(err)
105+
}
106+
}
107+
}()
108+
go func() {
109+
io.Copy(pty.Writer(), session)
110+
}()
111+
fmt.Printf("Got here!\n")
112+
io.Copy(session, pty.Reader())
113+
fmt.Printf("Done!\n")
114+
cmd.Wait()
115+
}
69116
},
70-
HostSigners: []ssh.Signer{signer},
117+
HostSigners: []ssh.Signer{randomSigner},
71118
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
72119
// Allow local port forwarding all!
120+
sshLogger.Debug(ctx, "local port forward",
121+
slog.F("destination-host", destinationHost),
122+
slog.F("destination-port", destinationPort))
73123
return true
74124
},
75125
PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
76-
return false
126+
return true
77127
},
78128
ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
79-
// Allow revere port forwarding all!
129+
// Allow reverse port forwarding all!
130+
sshLogger.Debug(ctx, "local port forward",
131+
slog.F("bind-host", bindHost),
132+
slog.F("bind-port", bindPort))
80133
return true
81134
},
82135
RequestHandlers: map[string]ssh.RequestHandler{
@@ -91,9 +144,6 @@ func (s *server) init(ctx context.Context) {
91144
// encrypted. If possible, we'd disable encryption entirely here.
92145
Ciphers: []string{"arcfour"},
93146
},
94-
PublicKeyCallback: func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) {
95-
return &gossh.Permissions{}, nil
96-
},
97147
NoClientAuth: true,
98148
}
99149
},

agent/server_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ func TestAgent(t *testing.T) {
5858
sshClient := ssh.NewClient(sshConn, channels, requests)
5959
session, err := sshClient.NewSession()
6060
require.NoError(t, err)
61+
err = session.RequestPty("xterm-256color", 128, 128, ssh.TerminalModes{})
62+
require.NoError(t, err)
6163
session.Stdout = os.Stdout
6264
session.Stderr = os.Stderr
6365
err = session.Run("echo test")

console/conpty/conpty.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ func (c *ConPty) Reader() io.Reader {
6565
return c.outFileOurSide
6666
}
6767

68+
func (c *ConPty) Writer() io.Writer {
69+
return c.inFileOurSide
70+
}
71+
6872
// InPipe returns input pipe of the pseudo terminal
6973
// Note: It is safer to use the Write method to prevent partially-written VT sequences
7074
// from corrupting the terminal

console/pty/pty.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ type Pty interface {
1212
Resize(cols uint16, rows uint16) error
1313
WriteString(str string) (int, error)
1414
Reader() io.Reader
15+
Writer() io.Writer
1516
Close() error
1617
}
1718

console/pty/pty_other.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ func (p *unixPty) Reader() io.Reader {
3838
return p.pty
3939
}
4040

41+
func (p *unixPty) Writer() io.Writer {
42+
return p.pty
43+
}
44+
4145
func (p *unixPty) WriteString(str string) (int, error) {
4246
return p.pty.WriteString(str)
4347
}

console/pty/pty_windows.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ func (p *pipePtyVal) Reader() io.Reader {
6161
return p.outFileOurSide
6262
}
6363

64+
func (p *pipePtyVal) Writer() io.Writer {
65+
return p.inFileOurSide
66+
}
67+
6468
func (p *pipePtyVal) WriteString(str string) (int, error) {
6569
return p.inFileOurSide.WriteString(str)
6670
}

wintest/main.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"os"
6+
"testing"
7+
8+
"cdr.dev/slog/sloggers/slogtest"
9+
"github.com/coder/coder/agent"
10+
"github.com/coder/coder/peer"
11+
"github.com/coder/coder/peerbroker"
12+
"github.com/coder/coder/peerbroker/proto"
13+
"github.com/coder/coder/provisionersdk"
14+
"github.com/pion/webrtc/v3"
15+
"github.com/stretchr/testify/require"
16+
"golang.org/x/crypto/ssh"
17+
"golang.org/x/sys/windows"
18+
)
19+
20+
func main() {
21+
state, err := MakeOutputRaw(os.Stdout.Fd())
22+
if err != nil {
23+
panic(err)
24+
}
25+
defer Restore(os.Stdout.Fd(), state)
26+
27+
t := &testing.T{}
28+
ctx := context.Background()
29+
client, server := provisionersdk.TransportPipe()
30+
defer client.Close()
31+
defer server.Close()
32+
closer := agent.Server(func(ctx context.Context) (*peerbroker.Listener, error) {
33+
return peerbroker.Listen(server, &peer.ConnOptions{
34+
Logger: slogtest.Make(t, nil),
35+
})
36+
}, &agent.Options{
37+
Logger: slogtest.Make(t, nil),
38+
})
39+
defer closer.Close()
40+
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
41+
stream, err := api.NegotiateConnection(ctx)
42+
require.NoError(t, err)
43+
conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{
44+
Logger: slogtest.Make(t, nil),
45+
})
46+
require.NoError(t, err)
47+
defer conn.Close()
48+
channel, err := conn.Dial(ctx, "example", &peer.ChannelOptions{
49+
Protocol: "ssh",
50+
})
51+
require.NoError(t, err)
52+
sshConn, channels, requests, err := ssh.NewClientConn(channel.NetConn(), "localhost:22", &ssh.ClientConfig{
53+
User: "kyle",
54+
Config: ssh.Config{
55+
Ciphers: []string{"arcfour"},
56+
},
57+
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
58+
})
59+
require.NoError(t, err)
60+
sshClient := ssh.NewClient(sshConn, channels, requests)
61+
session, err := sshClient.NewSession()
62+
require.NoError(t, err)
63+
err = session.RequestPty("xterm-256color", 128, 128, ssh.TerminalModes{
64+
ssh.ECHO: 1,
65+
})
66+
require.NoError(t, err)
67+
session.Stdin = os.Stdin
68+
session.Stdout = os.Stdout
69+
session.Stderr = os.Stderr
70+
err = session.Run("bash")
71+
require.NoError(t, err)
72+
}
73+
74+
// State differs per-platform.
75+
type State struct {
76+
mode uint32
77+
}
78+
79+
// makeRaw sets the terminal in raw mode and returns the previous state so it can be restored.
80+
func makeRaw(handle windows.Handle, input bool) (uint32, error) {
81+
var prevState uint32
82+
if err := windows.GetConsoleMode(handle, &prevState); err != nil {
83+
return 0, err
84+
}
85+
86+
var raw uint32
87+
if input {
88+
raw = prevState &^ (windows.ENABLE_ECHO_INPUT | windows.ENABLE_PROCESSED_INPUT | windows.ENABLE_LINE_INPUT | windows.ENABLE_PROCESSED_OUTPUT)
89+
raw |= windows.ENABLE_VIRTUAL_TERMINAL_INPUT
90+
} else {
91+
raw = prevState | windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING
92+
}
93+
94+
if err := windows.SetConsoleMode(handle, raw); err != nil {
95+
return 0, err
96+
}
97+
return prevState, nil
98+
}
99+
100+
// MakeOutputRaw sets an output terminal to raw and enables VT100 processing.
101+
func MakeOutputRaw(handle uintptr) (*State, error) {
102+
prevState, err := makeRaw(windows.Handle(handle), false)
103+
if err != nil {
104+
return nil, err
105+
}
106+
107+
return &State{mode: prevState}, nil
108+
}
109+
110+
// Restore terminal back to original state.
111+
func Restore(handle uintptr, state *State) error {
112+
return windows.SetConsoleMode(windows.Handle(handle), state.mode)
113+
}

0 commit comments

Comments
 (0)