From ca1487127c1c634ae710a1ca823bb2c5fcda575e Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 2 Aug 2022 16:35:31 +0300 Subject: [PATCH 1/3] fix: Improve shutdown procedure of ssh, portforward, wgtunnel cmds We could turn it into a practice to wrap `cmd.Context()` so that we have more fine-grained control of cancellation. Sometimes in tests we may be running commands with a context that is never canceled. Related to #3221 --- cli/portforward.go | 19 ++++++---- cli/ssh.go | 80 +++++++++++++++++++++++++----------------- cli/wireguardtunnel.go | 20 +++++++---- 3 files changed, 72 insertions(+), 47 deletions(-) diff --git a/cli/portforward.go b/cli/portforward.go index a604aad220a2f..9a13ef6eab6dd 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -55,6 +55,9 @@ func portForward() *cobra.Command { }, ), RunE: func(cmd *cobra.Command, args []string) error { + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + specs, err := parsePortForwards(tcpForwards, udpForwards, unixForwards) if err != nil { return xerrors.Errorf("parse port-forward specs: %w", err) @@ -72,7 +75,7 @@ func portForward() *cobra.Command { return err } - workspace, agent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], false) + workspace, agent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false) if err != nil { return err } @@ -80,13 +83,13 @@ func portForward() *cobra.Command { return xerrors.New("workspace must be in start transition to port-forward") } if workspace.LatestBuild.Job.CompletedAt == nil { - err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) + err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) if err != nil { return err } } - err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{ + err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{ WorkspaceName: workspace.Name, Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { return client.WorkspaceAgent(ctx, agent.ID) @@ -96,7 +99,7 @@ func portForward() *cobra.Command { return xerrors.Errorf("await agent: %w", err) } - conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil) + conn, err := client.DialWorkspaceAgent(ctx, agent.ID, nil) if err != nil { return xerrors.Errorf("dial workspace agent: %w", err) } @@ -104,7 +107,6 @@ func portForward() *cobra.Command { // Start all listeners. var ( - ctx, cancel = context.WithCancel(cmd.Context()) wg = new(sync.WaitGroup) listeners = make([]net.Listener, len(specs)) closeAllListeners = func() { @@ -116,11 +118,11 @@ func portForward() *cobra.Command { } } ) - defer cancel() + defer closeAllListeners() + for i, spec := range specs { l, err := listenAndPortForward(ctx, cmd, conn, wg, spec) if err != nil { - closeAllListeners() return err } listeners[i] = l @@ -129,7 +131,10 @@ func portForward() *cobra.Command { // Wait for the context to be canceled or for a signal and close // all listeners. var closeErr error + wg.Add(1) go func() { + defer wg.Done() + sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) diff --git a/cli/ssh.go b/cli/ssh.go index a5e420be0bbaa..f5a1f96c9a964 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -51,6 +51,9 @@ func ssh() *cobra.Command { Short: "SSH into a workspace", Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + client, err := createClient(cmd) if err != nil { return err @@ -68,14 +71,14 @@ func ssh() *cobra.Command { } } - workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle) + workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], shuffle) if err != nil { return err } // OpenSSH passes stderr directly to the calling TTY. // This is required in "stdio" mode so a connecting indicator can be displayed. - err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{ + err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{ WorkspaceName: workspace.Name, Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { return client.WorkspaceAgent(ctx, workspaceAgent.ID) @@ -85,19 +88,16 @@ func ssh() *cobra.Command { return xerrors.Errorf("await agent: %w", err) } - var ( - sshClient *gossh.Client - sshSession *gossh.Session - ) + var newSSHClient func() (*gossh.Client, error) if !wireguard { - conn, err := client.DialWorkspaceAgent(cmd.Context(), workspaceAgent.ID, nil) + conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil) if err != nil { return err } defer conn.Close() - stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace) + stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace) defer stopPolling() if stdio { @@ -105,6 +105,8 @@ func ssh() *cobra.Command { if err != nil { return err } + defer rawSSH.Close() + go func() { _, _ = io.Copy(cmd.OutOrStdout(), rawSSH) }() @@ -112,15 +114,7 @@ func ssh() *cobra.Command { return nil } - sshClient, err = conn.SSHClient() - if err != nil { - return err - } - - sshSession, err = sshClient.NewSession() - if err != nil { - return err - } + newSSHClient = conn.SSHClient } else { // TODO: more granual control of Tailscale logging. peerwg.Logf = tslogger.Discard @@ -133,8 +127,9 @@ func ssh() *cobra.Command { if err != nil { return xerrors.Errorf("create wireguard network: %w", err) } + defer wgn.Close() - err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{ + err = client.PostWireguardPeer(ctx, workspace.ID, peerwg.Handshake{ Recipient: workspaceAgent.ID, NodePublicKey: wgn.NodePrivateKey.Public(), DiscoPublicKey: wgn.DiscoPublicKey, @@ -155,10 +150,11 @@ func ssh() *cobra.Command { } if stdio { - rawSSH, err := wgn.SSH(cmd.Context(), workspaceAgent.IPv6.IP()) + rawSSH, err := wgn.SSH(ctx, workspaceAgent.IPv6.IP()) if err != nil { return err } + defer rawSSH.Close() go func() { _, _ = io.Copy(cmd.OutOrStdout(), rawSSH) @@ -167,16 +163,29 @@ func ssh() *cobra.Command { return nil } - sshClient, err = wgn.SSHClient(cmd.Context(), workspaceAgent.IPv6.IP()) - if err != nil { - return err + newSSHClient = func() (*gossh.Client, error) { + return wgn.SSHClient(ctx, workspaceAgent.IPv6.IP()) } + } - sshSession, err = sshClient.NewSession() - if err != nil { - return err - } + sshClient, err := newSSHClient() + if err != nil { + return err + } + defer sshClient.Close() + + sshSession, err := sshClient.NewSession() + if err != nil { + return err } + defer sshSession.Close() + + // Ensure context cancellation is propagated to the + // SSH session, e.g. to cancel `Wait()` at the end. + go func() { + <-ctx.Done() + _ = sshSession.Close() + }() if identityAgent == "" { identityAgent = os.Getenv("SSH_AUTH_SOCK") @@ -203,15 +212,18 @@ func ssh() *cobra.Command { _ = term.Restore(int(stdinFile.Fd()), state) }() - windowChange := listenWindowSize(cmd.Context()) + windowChange := listenWindowSize(ctx) go func() { for { select { - case <-cmd.Context().Done(): + case <-ctx.Done(): return case <-windowChange: } - width, height, _ := term.GetSize(int(stdoutFile.Fd())) + width, height, err := term.GetSize(int(stdoutFile.Fd())) + if err != nil { + continue + } _ = sshSession.WindowChange(height, width) } }() @@ -231,6 +243,10 @@ func ssh() *cobra.Command { return err } + // Put cancel at the top of the defer stack to initiate + // shutdown of services. + defer cancel() + err = sshSession.Wait() if err != nil { // If the connection drops unexpectedly, we get an ExitMissingError but no other @@ -259,16 +275,14 @@ func ssh() *cobra.Command { // getWorkspaceAgent returns the workspace and agent selected using either the // `[.]` syntax via `in` or picks a random workspace and agent // if `shuffle` is true. -func getWorkspaceAndAgent(cmd *cobra.Command, client *codersdk.Client, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive - ctx := cmd.Context() - +func getWorkspaceAndAgent(ctx context.Context, cmd *cobra.Command, client *codersdk.Client, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive var ( workspace codersdk.Workspace workspaceParts = strings.Split(in, ".") err error ) if shuffle { - workspaces, err := client.Workspaces(cmd.Context(), codersdk.WorkspaceFilter{ + workspaces, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{ Owner: codersdk.Me, }) if err != nil { diff --git a/cli/wireguardtunnel.go b/cli/wireguardtunnel.go index 8d953a846072e..592f8b4069f0d 100644 --- a/cli/wireguardtunnel.go +++ b/cli/wireguardtunnel.go @@ -52,6 +52,9 @@ func wireguardPortForward() *cobra.Command { }, ), RunE: func(cmd *cobra.Command, args []string) error { + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + specs, err := parsePortForwards(tcpForwards, nil, nil) if err != nil { return xerrors.Errorf("parse port-forward specs: %w", err) @@ -69,7 +72,7 @@ func wireguardPortForward() *cobra.Command { return err } - workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], false) + workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false) if err != nil { return err } @@ -77,13 +80,13 @@ func wireguardPortForward() *cobra.Command { return xerrors.New("workspace must be in start transition to port-forward") } if workspace.LatestBuild.Job.CompletedAt == nil { - err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) + err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) if err != nil { return err } } - err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{ + err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{ WorkspaceName: workspace.Name, Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { return client.WorkspaceAgent(ctx, workspaceAgent.ID) @@ -101,8 +104,9 @@ func wireguardPortForward() *cobra.Command { if err != nil { return xerrors.Errorf("create wireguard network: %w", err) } + defer wgn.Close() - err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{ + err = client.PostWireguardPeer(ctx, workspace.ID, peerwg.Handshake{ Recipient: workspaceAgent.ID, NodePublicKey: wgn.NodePrivateKey.Public(), DiscoPublicKey: wgn.DiscoPublicKey, @@ -124,7 +128,6 @@ func wireguardPortForward() *cobra.Command { // Start all listeners. var ( - ctx, cancel = context.WithCancel(cmd.Context()) wg = new(sync.WaitGroup) listeners = make([]net.Listener, len(specs)) closeAllListeners = func() { @@ -136,11 +139,11 @@ func wireguardPortForward() *cobra.Command { } } ) - defer cancel() + defer closeAllListeners() + for i, spec := range specs { l, err := listenAndPortForwardWireguard(ctx, cmd, wgn, wg, spec, workspaceAgent.IPv6.IP()) if err != nil { - closeAllListeners() return err } listeners[i] = l @@ -149,7 +152,10 @@ func wireguardPortForward() *cobra.Command { // Wait for the context to be canceled or for a signal and close // all listeners. var closeErr error + wg.Add(1) go func() { + defer wg.Done() + sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) From ee384a2a427b3f60538e9744a6aec29ba970cd28 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 2 Aug 2022 16:39:16 +0300 Subject: [PATCH 2/3] fix: Set ssh session stderr to stderr --- cli/ssh.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/ssh.go b/cli/ssh.go index f5a1f96c9a964..53e6ce88f9c75 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -236,7 +236,7 @@ func ssh() *cobra.Command { sshSession.Stdin = cmd.InOrStdin() sshSession.Stdout = cmd.OutOrStdout() - sshSession.Stderr = cmd.OutOrStdout() + sshSession.Stderr = cmd.ErrOrStderr() err = sshSession.Shell() if err != nil { From e483eef5980d7cff12b680d1cd162396b9f33fac Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 2 Aug 2022 16:42:33 +0300 Subject: [PATCH 3/3] Set stderr in ssh test to pty output --- cli/ssh_test.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 6978e5da298ba..059df85309bb3 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -229,7 +229,7 @@ func TestSSH(t *testing.T) { pty := ptytest.New(t) cmd.SetIn(pty.Input()) cmd.SetOut(pty.Output()) - cmd.SetErr(io.Discard) + cmd.SetErr(pty.Output()) cmdDone := tGo(t, func() { err := cmd.ExecuteContext(ctx) assert.NoError(t, err) @@ -248,9 +248,6 @@ func TestSSH(t *testing.T) { // And we're done. pty.WriteLine("exit") - // Read output to prevent hang on macOS, see: - // https://github.com/coder/coder/issues/2122 - pty.ExpectMatch("exit") <-cmdDone }) }