Skip to content

Commit 72b6b09

Browse files
committed
Something works
1 parent f8a733c commit 72b6b09

17 files changed

+645
-141
lines changed

agent/server.go

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,15 @@ import (
55
"crypto/rand"
66
"crypto/rsa"
77
"errors"
8-
"fmt"
98
"io"
109
"net"
11-
"os/exec"
10+
"os"
1211
"sync"
1312
"syscall"
1413
"time"
1514

1615
"cdr.dev/slog"
17-
"github.com/coder/coder/console/pty"
16+
"github.com/ActiveState/termtest/conpty"
1817
"github.com/coder/coder/peer"
1918
"github.com/coder/coder/peerbroker"
2019
"github.com/coder/retry"
@@ -72,46 +71,31 @@ func (s *server) init(ctx context.Context) {
7271
sshLogger.Info(ctx, "ssh connection ended", slog.Error(err))
7372
},
7473
Handler: func(session ssh.Session) {
75-
fmt.Printf("WE GOT %q %q\n", session.User(), session.RawCommand())
76-
7774
sshPty, windowSize, isPty := session.Pty()
7875
if isPty {
79-
cmd := exec.CommandContext(ctx, session.Command()[0], session.Command()[1:]...)
80-
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
81-
cmd.SysProcAttr = &syscall.SysProcAttr{
82-
Setsid: true,
83-
Setctty: true,
84-
}
85-
pty, err := pty.New()
86-
if err != nil {
87-
panic(err)
88-
}
89-
err = pty.Resize(uint16(sshPty.Window.Width), uint16(sshPty.Window.Height))
76+
cpty, err := conpty.New(int16(sshPty.Window.Width), int16(sshPty.Window.Height))
9077
if err != nil {
9178
panic(err)
9279
}
93-
cmd.Stdout = pty.OutPipe()
94-
cmd.Stderr = pty.OutPipe()
95-
cmd.Stdin = pty.InPipe()
96-
err = cmd.Start()
80+
_, _, err = cpty.Spawn("C:\\WINDOWS\\System32\\WindowsPowerShell\\v1.0\\powershell.exe", []string{}, &syscall.ProcAttr{
81+
Env: os.Environ(),
82+
})
9783
if err != nil {
9884
panic(err)
9985
}
10086
go func() {
10187
for win := range windowSize {
102-
err := pty.Resize(uint16(win.Width), uint16(win.Height))
88+
err := cpty.Resize(uint16(win.Width), uint16(win.Height))
10389
if err != nil {
10490
panic(err)
10591
}
10692
}
10793
}()
94+
10895
go func() {
109-
io.Copy(pty.Writer(), session)
96+
io.Copy(session, cpty)
11097
}()
111-
fmt.Printf("Got here!\n")
112-
io.Copy(session, pty.Reader())
113-
fmt.Printf("Done!\n")
114-
cmd.Wait()
98+
io.Copy(cpty, session)
11599
}
116100
},
117101
HostSigners: []ssh.Signer{randomSigner},

agent/server_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ func TestAgent(t *testing.T) {
6262
require.NoError(t, err)
6363
session.Stdout = os.Stdout
6464
session.Stderr = os.Stderr
65-
err = session.Run("echo test")
65+
err = session.Run("cmd.exe /k echo test")
6666
require.NoError(t, err)
6767
})
6868
}
69+
70+
// Read + write for input
71+
// Read + write for output

console/conpty/spawn.go

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
package conpty
2+
3+
import (
4+
"fmt"
5+
"os"
6+
"strings"
7+
"syscall"
8+
"unicode/utf16"
9+
"unsafe"
10+
11+
"golang.org/x/sys/windows"
12+
)
13+
14+
// Spawn spawns a new process attached to the pseudo terminal
15+
func Spawn(conpty *ConPty, argv0 string, argv []string, attr *syscall.ProcAttr) (pid int, handle uintptr, err error) {
16+
startupInfo := &startupInfoEx{}
17+
var attrListSize uint64
18+
startupInfo.startupInfo.Cb = uint32(unsafe.Sizeof(startupInfo))
19+
20+
err = initializeProcThreadAttributeList(0, 1, &attrListSize)
21+
if err != nil {
22+
return 0, 0, fmt.Errorf("could not retrieve list size: %v", err)
23+
}
24+
25+
attributeListBuffer := make([]byte, attrListSize)
26+
startupInfo.lpAttributeList = windows.Handle(unsafe.Pointer(&attributeListBuffer[0]))
27+
28+
err = initializeProcThreadAttributeList(uintptr(startupInfo.lpAttributeList), 1, &attrListSize)
29+
if err != nil {
30+
return 0, 0, fmt.Errorf("failed to initialize proc thread attributes for conpty: %v", err)
31+
}
32+
33+
err = updateProcThreadAttributeList(
34+
startupInfo.lpAttributeList,
35+
procThreadAttributePseudoconsole,
36+
conpty.hpCon,
37+
unsafe.Sizeof(conpty.hpCon))
38+
if err != nil {
39+
return 0, 0, fmt.Errorf("failed to update proc thread attributes attributes for conpty usage: %v", err)
40+
}
41+
42+
if attr == nil {
43+
attr = &syscall.ProcAttr{}
44+
}
45+
46+
if len(attr.Dir) != 0 {
47+
// StartProcess assumes that argv0 is relative to attr.Dir,
48+
// because it implies Chdir(attr.Dir) before executing argv0.
49+
// Windows CreateProcess assumes the opposite: it looks for
50+
// argv0 relative to the current directory, and, only once the new
51+
// process is started, it does Chdir(attr.Dir). We are adjusting
52+
// for that difference here by making argv0 absolute.
53+
var err error
54+
argv0, err = joinExeDirAndFName(attr.Dir, argv0)
55+
if err != nil {
56+
return 0, 0, err
57+
}
58+
}
59+
argv0p, err := windows.UTF16PtrFromString(argv0)
60+
if err != nil {
61+
return 0, 0, err
62+
}
63+
64+
// Windows CreateProcess takes the command line as a single string:
65+
// use attr.CmdLine if set, else build the command line by escaping
66+
// and joining each argument with spaces
67+
cmdline := makeCmdLine(argv)
68+
69+
var argvp *uint16
70+
if len(cmdline) != 0 {
71+
argvp, err = windows.UTF16PtrFromString(cmdline)
72+
if err != nil {
73+
return 0, 0, fmt.Errorf("utf ptr from string: %w", err)
74+
}
75+
}
76+
77+
var dirp *uint16
78+
if len(attr.Dir) != 0 {
79+
dirp, err = windows.UTF16PtrFromString(attr.Dir)
80+
if err != nil {
81+
return 0, 0, fmt.Errorf("utf ptr from string: %w", err)
82+
}
83+
}
84+
85+
startupInfo.startupInfo.Flags = windows.STARTF_USESTDHANDLES
86+
87+
pi := new(windows.ProcessInformation)
88+
89+
flags := uint32(windows.CREATE_UNICODE_ENVIRONMENT) | extendedStartupinfoPresent
90+
91+
var zeroSec windows.SecurityAttributes
92+
pSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1}
93+
tSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1}
94+
95+
// c.startupInfo.startupInfo.Cb = uint32(unsafe.Sizeof(c.startupInfo))
96+
err = windows.CreateProcess(
97+
argv0p,
98+
argvp,
99+
pSec, // process handle not inheritable
100+
tSec, // thread handles not inheritable,
101+
false,
102+
flags,
103+
createEnvBlock(addCriticalEnv(dedupEnvCase(true, attr.Env))),
104+
dirp, // use current directory later: dirp,
105+
&startupInfo.startupInfo,
106+
pi)
107+
108+
if err != nil {
109+
return 0, 0, fmt.Errorf("create process: %w", err)
110+
}
111+
defer windows.CloseHandle(windows.Handle(pi.Thread))
112+
113+
return int(pi.ProcessId), uintptr(pi.Process), nil
114+
}
115+
116+
// makeCmdLine builds a command line out of args by escaping "special"
117+
// characters and joining the arguments with spaces.
118+
func makeCmdLine(args []string) string {
119+
var s string
120+
for _, v := range args {
121+
if s != "" {
122+
s += " "
123+
}
124+
s += windows.EscapeArg(v)
125+
}
126+
return s
127+
}
128+
129+
func isSlash(c uint8) bool {
130+
return c == '\\' || c == '/'
131+
}
132+
133+
func normalizeDir(dir string) (name string, err error) {
134+
ndir, err := syscall.FullPath(dir)
135+
if err != nil {
136+
return "", err
137+
}
138+
if len(ndir) > 2 && isSlash(ndir[0]) && isSlash(ndir[1]) {
139+
// dir cannot have \\server\share\path form
140+
return "", syscall.EINVAL
141+
}
142+
return ndir, nil
143+
}
144+
145+
func volToUpper(ch int) int {
146+
if 'a' <= ch && ch <= 'z' {
147+
ch += 'A' - 'a'
148+
}
149+
return ch
150+
}
151+
152+
func joinExeDirAndFName(dir, p string) (name string, err error) {
153+
if len(p) == 0 {
154+
return "", syscall.EINVAL
155+
}
156+
if len(p) > 2 && isSlash(p[0]) && isSlash(p[1]) {
157+
// \\server\share\path form
158+
return p, nil
159+
}
160+
if len(p) > 1 && p[1] == ':' {
161+
// has drive letter
162+
if len(p) == 2 {
163+
return "", syscall.EINVAL
164+
}
165+
if isSlash(p[2]) {
166+
return p, nil
167+
} else {
168+
d, err := normalizeDir(dir)
169+
if err != nil {
170+
return "", err
171+
}
172+
if volToUpper(int(p[0])) == volToUpper(int(d[0])) {
173+
return syscall.FullPath(d + "\\" + p[2:])
174+
} else {
175+
return syscall.FullPath(p)
176+
}
177+
}
178+
} else {
179+
// no drive letter
180+
d, err := normalizeDir(dir)
181+
if err != nil {
182+
return "", err
183+
}
184+
if isSlash(p[0]) {
185+
return windows.FullPath(d[:2] + p)
186+
} else {
187+
return windows.FullPath(d + "\\" + p)
188+
}
189+
}
190+
}
191+
192+
// createEnvBlock converts an array of environment strings into
193+
// the representation required by CreateProcess: a sequence of NUL
194+
// terminated strings followed by a nil.
195+
// Last bytes are two UCS-2 NULs, or four NUL bytes.
196+
func createEnvBlock(envv []string) *uint16 {
197+
if len(envv) == 0 {
198+
return &utf16.Encode([]rune("\x00\x00"))[0]
199+
}
200+
length := 0
201+
for _, s := range envv {
202+
length += len(s) + 1
203+
}
204+
length += 1
205+
206+
b := make([]byte, length)
207+
i := 0
208+
for _, s := range envv {
209+
l := len(s)
210+
copy(b[i:i+l], []byte(s))
211+
copy(b[i+l:i+l+1], []byte{0})
212+
i = i + l + 1
213+
}
214+
copy(b[i:i+1], []byte{0})
215+
216+
return &utf16.Encode([]rune(string(b)))[0]
217+
}
218+
219+
// dedupEnvCase is dedupEnv with a case option for testing.
220+
// If caseInsensitive is true, the case of keys is ignored.
221+
func dedupEnvCase(caseInsensitive bool, env []string) []string {
222+
out := make([]string, 0, len(env))
223+
saw := make(map[string]int, len(env)) // key => index into out
224+
for _, kv := range env {
225+
eq := strings.Index(kv, "=")
226+
if eq < 0 {
227+
out = append(out, kv)
228+
continue
229+
}
230+
k := kv[:eq]
231+
if caseInsensitive {
232+
k = strings.ToLower(k)
233+
}
234+
if dupIdx, isDup := saw[k]; isDup {
235+
out[dupIdx] = kv
236+
continue
237+
}
238+
saw[k] = len(out)
239+
out = append(out, kv)
240+
}
241+
return out
242+
}
243+
244+
// addCriticalEnv adds any critical environment variables that are required
245+
// (or at least almost always required) on the operating system.
246+
// Currently this is only used for Windows.
247+
func addCriticalEnv(env []string) []string {
248+
for _, kv := range env {
249+
eq := strings.Index(kv, "=")
250+
if eq < 0 {
251+
continue
252+
}
253+
k := kv[:eq]
254+
if strings.EqualFold(k, "SYSTEMROOT") {
255+
// We already have it.
256+
return env
257+
}
258+
}
259+
return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT"))
260+
}

0 commit comments

Comments
 (0)