diff --git a/agent/agent.go b/agent/agent.go index 75787b4cfc5e1..55d169cd35ec0 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -380,6 +380,16 @@ func (a *agent) handleSSHSession(session ssh.Session) error { return err } + if ssh.AgentRequested(session) { + l, err := ssh.NewAgentListener() + if err != nil { + return xerrors.Errorf("new agent listener: %w", err) + } + defer l.Close() + go ssh.ForwardAgentConnections(l, session) + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String())) + } + sshPty, windowSize, isPty := session.Pty() if isPty { cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) diff --git a/cli/configssh.go b/cli/configssh.go index ed8f785b178de..f511dd07f47da 100644 --- a/cli/configssh.go +++ b/cli/configssh.go @@ -38,6 +38,11 @@ func configSSH() *cobra.Command { Annotations: workspaceCommand, Use: "config-ssh", Short: "Populate your SSH config with Host entries for all of your workspaces", + Example: ` + - You can use -o (or --ssh-option) so set SSH options to be used for all your + workspaces. + + ` + cliui.Styles.Code.Render("$ coder config-ssh -o ForwardAgent=yes"), RunE: func(cmd *cobra.Command, args []string) error { client, err := createClient(cmd) if err != nil { diff --git a/cli/ssh.go b/cli/ssh.go index 58d3677b579de..85c3d9fd9b7b9 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -15,6 +15,7 @@ import ( "github.com/mattn/go-isatty" "github.com/spf13/cobra" gossh "golang.org/x/crypto/ssh" + gosshagent "golang.org/x/crypto/ssh/agent" "golang.org/x/term" "golang.org/x/xerrors" @@ -32,6 +33,7 @@ func ssh() *cobra.Command { var ( stdio bool shuffle bool + forwardAgent bool wsPollInterval time.Duration ) cmd := &cobra.Command{ @@ -108,6 +110,17 @@ func ssh() *cobra.Command { return err } + if forwardAgent && os.Getenv("SSH_AUTH_SOCK") != "" { + err = gosshagent.ForwardToRemote(sshClient, os.Getenv("SSH_AUTH_SOCK")) + if err != nil { + return xerrors.Errorf("forward agent failed: %w", err) + } + err = gosshagent.RequestAgentForwarding(sshSession) + if err != nil { + return xerrors.Errorf("request agent forwarding failed: %w", err) + } + } + stdoutFile, valid := cmd.OutOrStdout().(*os.File) if valid && isatty.IsTerminal(stdoutFile.Fd()) { state, err := term.MakeRaw(int(os.Stdin.Fd())) @@ -156,8 +169,9 @@ func ssh() *cobra.Command { } cliflag.BoolVarP(cmd.Flags(), &stdio, "stdio", "", "CODER_SSH_STDIO", false, "Specifies whether to emit SSH output over stdin/stdout.") cliflag.BoolVarP(cmd.Flags(), &shuffle, "shuffle", "", "CODER_SSH_SHUFFLE", false, "Specifies whether to choose a random workspace") - cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.") _ = cmd.Flags().MarkHidden("shuffle") + cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK") + cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.") return cmd } diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 172d250ae8302..eae32161953d7 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -1,8 +1,14 @@ package cli_test import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "errors" "io" "net" + "path/filepath" "runtime" "testing" "time" @@ -11,9 +17,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" + gosshagent "golang.org/x/crypto/ssh/agent" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/agent" "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" @@ -23,49 +31,53 @@ import ( "github.com/coder/coder/pty/ptytest" ) +func setupWorkspaceForSSH(t *testing.T) (*codersdk.Client, codersdk.Workspace, string) { + t.Helper() + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) + user := coderdtest.CreateFirstUser(t, client) + agentToken := uuid.NewString() + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionDryRun: echo.ProvisionComplete, + Provision: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "dev", + Type: "google_compute_instance", + Agents: []*proto.Agent{{ + Id: uuid.NewString(), + Auth: &proto.Agent_Token{ + Token: agentToken, + }, + }}, + }}, + }, + }, + }}, + }) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) + + return client, workspace, agentToken +} + func TestSSH(t *testing.T) { - t.Skip("This is causing test flakes. TODO @cian fix this") t.Parallel() t.Run("ImmediateExit", func(t *testing.T) { t.Parallel() - client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) - user := coderdtest.CreateFirstUser(t, client) - agentToken := uuid.NewString() - version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ - Parse: echo.ParseComplete, - ProvisionDryRun: echo.ProvisionComplete, - Provision: []*proto.Provision_Response{{ - Type: &proto.Provision_Response_Complete{ - Complete: &proto.Provision_Complete{ - Resources: []*proto.Resource{{ - Name: "dev", - Type: "google_compute_instance", - Agents: []*proto.Agent{{ - Id: uuid.NewString(), - Auth: &proto.Agent_Token{ - Token: agentToken, - }, - }}, - }}, - }, - }, - }}, - }) - coderdtest.AwaitTemplateVersionJob(t, client, version.ID) - template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) + client, workspace, agentToken := setupWorkspaceForSSH(t) cmd, root := clitest.New(t, "ssh", workspace.Name) clitest.SetupConfig(t, client, root) - doneChan := make(chan struct{}) pty := ptytest.New(t) cmd.SetIn(pty.Input()) cmd.SetErr(pty.Output()) cmd.SetOut(pty.Output()) - go func() { - defer close(doneChan) + cmdDone := tGo(t, func() { err := cmd.Execute() assert.NoError(t, err) - }() + }) pty.ExpectMatch("Waiting") coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) agentClient := codersdk.New(client.URL) @@ -76,39 +88,16 @@ func TestSSH(t *testing.T) { t.Cleanup(func() { _ = agentCloser.Close() }) + // Shells on Mac, Windows, and Linux all exit shells with the "exit" command. pty.WriteLine("exit") - <-doneChan + <-cmdDone }) t.Run("Stdio", func(t *testing.T) { t.Parallel() - client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) - user := coderdtest.CreateFirstUser(t, client) - agentToken := uuid.NewString() - version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ - Parse: echo.ParseComplete, - ProvisionDryRun: echo.ProvisionComplete, - Provision: []*proto.Provision_Response{{ - Type: &proto.Provision_Response_Complete{ - Complete: &proto.Provision_Complete{ - Resources: []*proto.Resource{{ - Name: "dev", - Type: "google_compute_instance", - Agents: []*proto.Agent{{ - Id: uuid.NewString(), - Auth: &proto.Agent_Token{ - Token: agentToken, - }, - }}, - }}, - }, - }, - }}, - }) - coderdtest.AwaitTemplateVersionJob(t, client, version.ID) - template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) - go func() { + client, workspace, agentToken := setupWorkspaceForSSH(t) + + _, _ = tGoContext(t, func(ctx context.Context) { // Run this async so the SSH command has to wait for // the build and agent to connect! coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) @@ -117,25 +106,22 @@ func TestSSH(t *testing.T) { agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{ Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), }) - t.Cleanup(func() { - _ = agentCloser.Close() - }) - }() + <-ctx.Done() + _ = agentCloser.Close() + }) clientOutput, clientInput := io.Pipe() serverOutput, serverInput := io.Pipe() cmd, root := clitest.New(t, "ssh", "--stdio", workspace.Name) clitest.SetupConfig(t, client, root) - doneChan := make(chan struct{}) cmd.SetIn(clientOutput) cmd.SetOut(serverInput) cmd.SetErr(io.Discard) - go func() { - defer close(doneChan) + cmdDone := tGo(t, func() { err := cmd.Execute() assert.NoError(t, err) - }() + }) conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ Reader: serverOutput, @@ -157,8 +143,135 @@ func TestSSH(t *testing.T) { err = sshClient.Close() require.NoError(t, err) _ = clientOutput.Close() - <-doneChan + + <-cmdDone + }) + //nolint:paralleltest // Disabled due to use of t.Setenv. + t.Run("ForwardAgent", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Test not supported on windows") + } + + client, workspace, agentToken := setupWorkspaceForSSH(t) + + _, _ = tGoContext(t, func(ctx context.Context) { + // Run this async so the SSH command has to wait for + // the build and agent to connect! + coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + agentClient := codersdk.New(client.URL) + agentClient.SessionToken = agentToken + agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{ + Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), + }) + <-ctx.Done() + _ = agentCloser.Close() + }) + + // Generate private key. + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + kr := gosshagent.NewKeyring() + kr.Add(gosshagent.AddedKey{ + PrivateKey: privateKey, + }) + + // Start up ssh agent listening on unix socket. + tmpdir := t.TempDir() + agentSock := filepath.Join(tmpdir, "agent.sock") + l, err := net.Listen("unix", agentSock) + require.NoError(t, err) + defer l.Close() + _ = tGo(t, func() { + for { + fd, err := l.Accept() + if err != nil { + if !errors.Is(err, net.ErrClosed) { + t.Logf("accept error: %v", err) + } + return + } + + err = gosshagent.ServeAgent(kr, fd) + if !errors.Is(err, io.EOF) { + assert.NoError(t, err) + } + } + }) + + t.Setenv("SSH_AUTH_SOCK", agentSock) + cmd, root := clitest.New(t, + "ssh", + workspace.Name, + "--forward-agent", + ) + clitest.SetupConfig(t, client, root) + pty := ptytest.New(t) + cmd.SetIn(pty.Input()) + cmd.SetOut(pty.Output()) + cmd.SetErr(io.Discard) + cmdDone := tGo(t, func() { + err := cmd.Execute() + assert.NoError(t, err) + }) + + // Ensure that SSH_AUTH_SOCK is set. + // Linux: /tmp/auth-agent3167016167/listener.sock + // macOS: /var/folders/ng/m1q0wft14hj0t3rtjxrdnzsr0000gn/T/auth-agent3245553419/listener.sock + pty.WriteLine("env") + pty.ExpectMatch("SSH_AUTH_SOCK=") + // Ensure that ssh-add lists our key. + pty.WriteLine("ssh-add -L") + keys, err := kr.List() + require.NoError(t, err) + pty.ExpectMatch(keys[0].String()) + + // And we're done. + pty.WriteLine("exit") + <-cmdDone + }) +} + +// tGoContext runs fn in a goroutine passing a context that will be +// canceled on test completion and wait until fn has finished executing. +// Done and cancel are returned for optionally waiting until completion +// or early cancellation. +// +// NOTE(mafredri): This could be moved to a helper library. +func tGoContext(t *testing.T, fn func(context.Context)) (done <-chan struct{}, cancel context.CancelFunc) { + t.Helper() + + ctx, cancel := context.WithCancel(context.Background()) + doneC := make(chan struct{}) + t.Cleanup(func() { + cancel() + <-done + }) + go func() { + fn(ctx) + close(doneC) + }() + + return doneC, cancel +} + +// tGo runs fn in a goroutine and waits until fn has completed before +// test completion. Done is returned for optionally waiting for fn to +// exit. +// +// NOTE(mafredri): This could be moved to a helper library. +func tGo(t *testing.T, fn func()) (done <-chan struct{}) { + t.Helper() + + doneC := make(chan struct{}) + t.Cleanup(func() { + <-doneC }) + go func() { + fn() + close(doneC) + }() + + return doneC } type stdioConn struct {