Skip to content

Commit f1fe2b5

Browse files
authored
feat: add GPG forwarding to coder ssh (#5482)
1 parent 59e919a commit f1fe2b5

12 files changed

+1050
-21
lines changed

agent/agent.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -480,12 +480,16 @@ func (a *agent) init(ctx context.Context) {
480480
if err != nil {
481481
panic(err)
482482
}
483+
483484
sshLogger := a.logger.Named("ssh-server")
484485
forwardHandler := &ssh.ForwardedTCPHandler{}
486+
unixForwardHandler := &forwardedUnixHandler{log: a.logger}
487+
485488
a.sshServer = &ssh.Server{
486489
ChannelHandlers: map[string]ssh.ChannelHandler{
487-
"direct-tcpip": ssh.DirectTCPIPHandler,
488-
"session": ssh.DefaultSessionHandler,
490+
"direct-tcpip": ssh.DirectTCPIPHandler,
491+
"direct-streamlocal@openssh.com": directStreamLocalHandler,
492+
"session": ssh.DefaultSessionHandler,
489493
},
490494
ConnectionFailedCallback: func(conn net.Conn, err error) {
491495
sshLogger.Info(ctx, "ssh connection ended", slog.Error(err))
@@ -525,8 +529,10 @@ func (a *agent) init(ctx context.Context) {
525529
return true
526530
},
527531
RequestHandlers: map[string]ssh.RequestHandler{
528-
"tcpip-forward": forwardHandler.HandleSSHRequest,
529-
"cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
532+
"tcpip-forward": forwardHandler.HandleSSHRequest,
533+
"cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
534+
"streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
535+
"cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
530536
},
531537
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
532538
return &gossh.ServerConfig{

agent/agent_test.go

Lines changed: 248 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ func TestAgent_Session_TTY_Hushlogin(t *testing.T) {
273273
}
274274

275275
//nolint:paralleltest // This test reserves a port.
276-
func TestAgent_LocalForwarding(t *testing.T) {
276+
func TestAgent_TCPLocalForwarding(t *testing.T) {
277277
random, err := net.Listen("tcp", "127.0.0.1:0")
278278
require.NoError(t, err)
279279
_ = random.Close()
@@ -286,24 +286,239 @@ func TestAgent_LocalForwarding(t *testing.T) {
286286
defer local.Close()
287287
tcpAddr, valid = local.Addr().(*net.TCPAddr)
288288
require.True(t, valid)
289-
localPort := tcpAddr.Port
289+
remotePort := tcpAddr.Port
290290
done := make(chan struct{})
291291
go func() {
292292
defer close(done)
293293
conn, err := local.Accept()
294294
if !assert.NoError(t, err) {
295295
return
296296
}
297-
_ = conn.Close()
297+
defer conn.Close()
298+
b := make([]byte, 4)
299+
_, err = conn.Read(b)
300+
if !assert.NoError(t, err) {
301+
return
302+
}
303+
_, err = conn.Write(b)
304+
if !assert.NoError(t, err) {
305+
return
306+
}
298307
}()
299308

300-
err = setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, localPort)}, []string{"echo", "test"}).Start()
309+
cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "10"})
310+
err = cmd.Start()
311+
require.NoError(t, err)
312+
313+
require.Eventually(t, func() bool {
314+
conn, err := net.Dial("tcp", "127.0.0.1:"+strconv.Itoa(randomPort))
315+
if err != nil {
316+
return false
317+
}
318+
defer conn.Close()
319+
_, err = conn.Write([]byte("test"))
320+
if !assert.NoError(t, err) {
321+
return false
322+
}
323+
b := make([]byte, 4)
324+
_, err = conn.Read(b)
325+
if !assert.NoError(t, err) {
326+
return false
327+
}
328+
if !assert.Equal(t, "test", string(b)) {
329+
return false
330+
}
331+
332+
return true
333+
}, testutil.WaitLong, testutil.IntervalSlow)
334+
335+
<-done
336+
337+
_ = cmd.Process.Kill()
338+
}
339+
340+
//nolint:paralleltest // This test reserves a port.
341+
func TestAgent_TCPRemoteForwarding(t *testing.T) {
342+
random, err := net.Listen("tcp", "127.0.0.1:0")
301343
require.NoError(t, err)
344+
_ = random.Close()
345+
tcpAddr, valid := random.Addr().(*net.TCPAddr)
346+
require.True(t, valid)
347+
randomPort := tcpAddr.Port
302348

303-
conn, err := net.Dial("tcp", "127.0.0.1:"+strconv.Itoa(localPort))
349+
l, err := net.Listen("tcp", "127.0.0.1:0")
304350
require.NoError(t, err)
305-
conn.Close()
351+
defer l.Close()
352+
tcpAddr, valid = l.Addr().(*net.TCPAddr)
353+
require.True(t, valid)
354+
localPort := tcpAddr.Port
355+
356+
done := make(chan struct{})
357+
go func() {
358+
defer close(done)
359+
360+
conn, err := l.Accept()
361+
if err != nil {
362+
return
363+
}
364+
defer conn.Close()
365+
b := make([]byte, 4)
366+
_, err = conn.Read(b)
367+
if !assert.NoError(t, err) {
368+
return
369+
}
370+
_, err = conn.Write(b)
371+
if !assert.NoError(t, err) {
372+
return
373+
}
374+
}()
375+
376+
cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "10"})
377+
err = cmd.Start()
378+
require.NoError(t, err)
379+
380+
require.Eventually(t, func() bool {
381+
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", randomPort))
382+
if err != nil {
383+
return false
384+
}
385+
defer conn.Close()
386+
_, err = conn.Write([]byte("test"))
387+
if !assert.NoError(t, err) {
388+
return false
389+
}
390+
b := make([]byte, 4)
391+
_, err = conn.Read(b)
392+
if !assert.NoError(t, err) {
393+
return false
394+
}
395+
if !assert.Equal(t, "test", string(b)) {
396+
return false
397+
}
398+
399+
return true
400+
}, testutil.WaitLong, testutil.IntervalSlow)
401+
306402
<-done
403+
404+
_ = cmd.Process.Kill()
405+
}
406+
407+
func TestAgent_UnixLocalForwarding(t *testing.T) {
408+
t.Parallel()
409+
if runtime.GOOS == "windows" {
410+
t.Skip("unix domain sockets are not fully supported on Windows")
411+
}
412+
413+
tmpdir := tempDirUnixSocket(t)
414+
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
415+
localSocketPath := filepath.Join(tmpdir, "local-socket")
416+
417+
l, err := net.Listen("unix", remoteSocketPath)
418+
require.NoError(t, err)
419+
defer l.Close()
420+
421+
done := make(chan struct{})
422+
go func() {
423+
defer close(done)
424+
425+
conn, err := l.Accept()
426+
if err != nil {
427+
return
428+
}
429+
defer conn.Close()
430+
b := make([]byte, 4)
431+
_, err = conn.Read(b)
432+
if !assert.NoError(t, err) {
433+
return
434+
}
435+
_, err = conn.Write(b)
436+
if !assert.NoError(t, err) {
437+
return
438+
}
439+
}()
440+
441+
cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "10"})
442+
err = cmd.Start()
443+
require.NoError(t, err)
444+
445+
require.Eventually(t, func() bool {
446+
_, err := os.Stat(localSocketPath)
447+
return err == nil
448+
}, testutil.WaitLong, testutil.IntervalFast)
449+
450+
conn, err := net.Dial("unix", localSocketPath)
451+
require.NoError(t, err)
452+
defer conn.Close()
453+
_, err = conn.Write([]byte("test"))
454+
require.NoError(t, err)
455+
b := make([]byte, 4)
456+
_, err = conn.Read(b)
457+
require.NoError(t, err)
458+
require.Equal(t, "test", string(b))
459+
_ = conn.Close()
460+
<-done
461+
462+
_ = cmd.Process.Kill()
463+
}
464+
465+
func TestAgent_UnixRemoteForwarding(t *testing.T) {
466+
t.Parallel()
467+
if runtime.GOOS == "windows" {
468+
t.Skip("unix domain sockets are not fully supported on Windows")
469+
}
470+
471+
tmpdir := tempDirUnixSocket(t)
472+
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
473+
localSocketPath := filepath.Join(tmpdir, "local-socket")
474+
475+
l, err := net.Listen("unix", localSocketPath)
476+
require.NoError(t, err)
477+
defer l.Close()
478+
479+
done := make(chan struct{})
480+
go func() {
481+
defer close(done)
482+
483+
conn, err := l.Accept()
484+
if err != nil {
485+
return
486+
}
487+
defer conn.Close()
488+
b := make([]byte, 4)
489+
_, err = conn.Read(b)
490+
if !assert.NoError(t, err) {
491+
return
492+
}
493+
_, err = conn.Write(b)
494+
if !assert.NoError(t, err) {
495+
return
496+
}
497+
}()
498+
499+
cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "10"})
500+
err = cmd.Start()
501+
require.NoError(t, err)
502+
503+
require.Eventually(t, func() bool {
504+
_, err := os.Stat(remoteSocketPath)
505+
return err == nil
506+
}, testutil.WaitLong, testutil.IntervalFast)
507+
508+
conn, err := net.Dial("unix", remoteSocketPath)
509+
require.NoError(t, err)
510+
defer conn.Close()
511+
_, err = conn.Write([]byte("test"))
512+
require.NoError(t, err)
513+
b := make([]byte, 4)
514+
_, err = conn.Read(b)
515+
require.NoError(t, err)
516+
require.Equal(t, "test", string(b))
517+
_ = conn.Close()
518+
519+
<-done
520+
521+
_ = cmd.Process.Kill()
307522
}
308523

309524
func TestAgent_SFTP(t *testing.T) {
@@ -733,7 +948,10 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
733948
args := append(beforeArgs,
734949
"-o", "HostName "+tcpAddr.IP.String(),
735950
"-o", "Port "+strconv.Itoa(tcpAddr.Port),
736-
"-o", "StrictHostKeyChecking=no", "host")
951+
"-o", "StrictHostKeyChecking=no",
952+
"-o", "UserKnownHostsFile=/dev/null",
953+
"host",
954+
)
737955
args = append(args, afterArgs...)
738956
return exec.Command("ssh", args...)
739957
}
@@ -919,3 +1137,26 @@ func (*client) PostWorkspaceAgentAppHealth(_ context.Context, _ codersdk.PostWor
9191137
func (*client) PostWorkspaceAgentVersion(_ context.Context, _ string) error {
9201138
return nil
9211139
}
1140+
1141+
// tempDirUnixSocket returns a temporary directory that can safely hold unix
1142+
// sockets (probably).
1143+
//
1144+
// During tests on darwin we hit the max path length limit for unix sockets
1145+
// pretty easily in the default location, so this function uses /tmp instead to
1146+
// get shorter paths.
1147+
func tempDirUnixSocket(t *testing.T) string {
1148+
t.Helper()
1149+
if runtime.GOOS == "darwin" {
1150+
testName := strings.ReplaceAll(t.Name(), "/", "_")
1151+
dir, err := os.MkdirTemp("/tmp", fmt.Sprintf("coder-test-%s-", testName))
1152+
require.NoError(t, err, "create temp dir for gpg test")
1153+
1154+
t.Cleanup(func() {
1155+
err := os.RemoveAll(dir)
1156+
assert.NoError(t, err, "remove temp dir", dir)
1157+
})
1158+
return dir
1159+
}
1160+
1161+
return t.TempDir()
1162+
}

0 commit comments

Comments
 (0)