diff --git a/cli/portforward.go b/cli/portforward.go index a604aad220a2f..9a13ef6eab6dd 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -55,6 +55,9 @@ func portForward() *cobra.Command { }, ), RunE: func(cmd *cobra.Command, args []string) error { + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + specs, err := parsePortForwards(tcpForwards, udpForwards, unixForwards) if err != nil { return xerrors.Errorf("parse port-forward specs: %w", err) @@ -72,7 +75,7 @@ func portForward() *cobra.Command { return err } - workspace, agent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], false) + workspace, agent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false) if err != nil { return err } @@ -80,13 +83,13 @@ func portForward() *cobra.Command { return xerrors.New("workspace must be in start transition to port-forward") } if workspace.LatestBuild.Job.CompletedAt == nil { - err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) + err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) if err != nil { return err } } - err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{ + err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{ WorkspaceName: workspace.Name, Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { return client.WorkspaceAgent(ctx, agent.ID) @@ -96,7 +99,7 @@ func portForward() *cobra.Command { return xerrors.Errorf("await agent: %w", err) } - conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil) + conn, err := client.DialWorkspaceAgent(ctx, agent.ID, nil) if err != nil { return xerrors.Errorf("dial workspace agent: %w", err) } @@ -104,7 +107,6 @@ func portForward() *cobra.Command { // Start all listeners. var ( - ctx, cancel = context.WithCancel(cmd.Context()) wg = new(sync.WaitGroup) listeners = make([]net.Listener, len(specs)) closeAllListeners = func() { @@ -116,11 +118,11 @@ func portForward() *cobra.Command { } } ) - defer cancel() + defer closeAllListeners() + for i, spec := range specs { l, err := listenAndPortForward(ctx, cmd, conn, wg, spec) if err != nil { - closeAllListeners() return err } listeners[i] = l @@ -129,7 +131,10 @@ func portForward() *cobra.Command { // Wait for the context to be canceled or for a signal and close // all listeners. var closeErr error + wg.Add(1) go func() { + defer wg.Done() + sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) diff --git a/cli/ssh.go b/cli/ssh.go index a5e420be0bbaa..53e6ce88f9c75 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -51,6 +51,9 @@ func ssh() *cobra.Command { Short: "SSH into a workspace", Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + client, err := createClient(cmd) if err != nil { return err @@ -68,14 +71,14 @@ func ssh() *cobra.Command { } } - workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle) + workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], shuffle) if err != nil { return err } // OpenSSH passes stderr directly to the calling TTY. // This is required in "stdio" mode so a connecting indicator can be displayed. - err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{ + err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{ WorkspaceName: workspace.Name, Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { return client.WorkspaceAgent(ctx, workspaceAgent.ID) @@ -85,19 +88,16 @@ func ssh() *cobra.Command { return xerrors.Errorf("await agent: %w", err) } - var ( - sshClient *gossh.Client - sshSession *gossh.Session - ) + var newSSHClient func() (*gossh.Client, error) if !wireguard { - conn, err := client.DialWorkspaceAgent(cmd.Context(), workspaceAgent.ID, nil) + conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil) if err != nil { return err } defer conn.Close() - stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace) + stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace) defer stopPolling() if stdio { @@ -105,6 +105,8 @@ func ssh() *cobra.Command { if err != nil { return err } + defer rawSSH.Close() + go func() { _, _ = io.Copy(cmd.OutOrStdout(), rawSSH) }() @@ -112,15 +114,7 @@ func ssh() *cobra.Command { return nil } - sshClient, err = conn.SSHClient() - if err != nil { - return err - } - - sshSession, err = sshClient.NewSession() - if err != nil { - return err - } + newSSHClient = conn.SSHClient } else { // TODO: more granual control of Tailscale logging. peerwg.Logf = tslogger.Discard @@ -133,8 +127,9 @@ func ssh() *cobra.Command { if err != nil { return xerrors.Errorf("create wireguard network: %w", err) } + defer wgn.Close() - err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{ + err = client.PostWireguardPeer(ctx, workspace.ID, peerwg.Handshake{ Recipient: workspaceAgent.ID, NodePublicKey: wgn.NodePrivateKey.Public(), DiscoPublicKey: wgn.DiscoPublicKey, @@ -155,10 +150,11 @@ func ssh() *cobra.Command { } if stdio { - rawSSH, err := wgn.SSH(cmd.Context(), workspaceAgent.IPv6.IP()) + rawSSH, err := wgn.SSH(ctx, workspaceAgent.IPv6.IP()) if err != nil { return err } + defer rawSSH.Close() go func() { _, _ = io.Copy(cmd.OutOrStdout(), rawSSH) @@ -167,16 +163,29 @@ func ssh() *cobra.Command { return nil } - sshClient, err = wgn.SSHClient(cmd.Context(), workspaceAgent.IPv6.IP()) - if err != nil { - return err + newSSHClient = func() (*gossh.Client, error) { + return wgn.SSHClient(ctx, workspaceAgent.IPv6.IP()) } + } - sshSession, err = sshClient.NewSession() - if err != nil { - return err - } + sshClient, err := newSSHClient() + if err != nil { + return err + } + defer sshClient.Close() + + sshSession, err := sshClient.NewSession() + if err != nil { + return err } + defer sshSession.Close() + + // Ensure context cancellation is propagated to the + // SSH session, e.g. to cancel `Wait()` at the end. + go func() { + <-ctx.Done() + _ = sshSession.Close() + }() if identityAgent == "" { identityAgent = os.Getenv("SSH_AUTH_SOCK") @@ -203,15 +212,18 @@ func ssh() *cobra.Command { _ = term.Restore(int(stdinFile.Fd()), state) }() - windowChange := listenWindowSize(cmd.Context()) + windowChange := listenWindowSize(ctx) go func() { for { select { - case <-cmd.Context().Done(): + case <-ctx.Done(): return case <-windowChange: } - width, height, _ := term.GetSize(int(stdoutFile.Fd())) + width, height, err := term.GetSize(int(stdoutFile.Fd())) + if err != nil { + continue + } _ = sshSession.WindowChange(height, width) } }() @@ -224,13 +236,17 @@ func ssh() *cobra.Command { sshSession.Stdin = cmd.InOrStdin() sshSession.Stdout = cmd.OutOrStdout() - sshSession.Stderr = cmd.OutOrStdout() + sshSession.Stderr = cmd.ErrOrStderr() err = sshSession.Shell() if err != nil { return err } + // Put cancel at the top of the defer stack to initiate + // shutdown of services. + defer cancel() + err = sshSession.Wait() if err != nil { // If the connection drops unexpectedly, we get an ExitMissingError but no other @@ -259,16 +275,14 @@ func ssh() *cobra.Command { // getWorkspaceAgent returns the workspace and agent selected using either the // `[.]` syntax via `in` or picks a random workspace and agent // if `shuffle` is true. -func getWorkspaceAndAgent(cmd *cobra.Command, client *codersdk.Client, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive - ctx := cmd.Context() - +func getWorkspaceAndAgent(ctx context.Context, cmd *cobra.Command, client *codersdk.Client, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive var ( workspace codersdk.Workspace workspaceParts = strings.Split(in, ".") err error ) if shuffle { - workspaces, err := client.Workspaces(cmd.Context(), codersdk.WorkspaceFilter{ + workspaces, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{ Owner: codersdk.Me, }) if err != nil { diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 6978e5da298ba..059df85309bb3 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -229,7 +229,7 @@ func TestSSH(t *testing.T) { pty := ptytest.New(t) cmd.SetIn(pty.Input()) cmd.SetOut(pty.Output()) - cmd.SetErr(io.Discard) + cmd.SetErr(pty.Output()) cmdDone := tGo(t, func() { err := cmd.ExecuteContext(ctx) assert.NoError(t, err) @@ -248,9 +248,6 @@ func TestSSH(t *testing.T) { // And we're done. pty.WriteLine("exit") - // Read output to prevent hang on macOS, see: - // https://github.com/coder/coder/issues/2122 - pty.ExpectMatch("exit") <-cmdDone }) } diff --git a/cli/wireguardtunnel.go b/cli/wireguardtunnel.go index 8d953a846072e..592f8b4069f0d 100644 --- a/cli/wireguardtunnel.go +++ b/cli/wireguardtunnel.go @@ -52,6 +52,9 @@ func wireguardPortForward() *cobra.Command { }, ), RunE: func(cmd *cobra.Command, args []string) error { + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + specs, err := parsePortForwards(tcpForwards, nil, nil) if err != nil { return xerrors.Errorf("parse port-forward specs: %w", err) @@ -69,7 +72,7 @@ func wireguardPortForward() *cobra.Command { return err } - workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], false) + workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false) if err != nil { return err } @@ -77,13 +80,13 @@ func wireguardPortForward() *cobra.Command { return xerrors.New("workspace must be in start transition to port-forward") } if workspace.LatestBuild.Job.CompletedAt == nil { - err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) + err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) if err != nil { return err } } - err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{ + err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{ WorkspaceName: workspace.Name, Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { return client.WorkspaceAgent(ctx, workspaceAgent.ID) @@ -101,8 +104,9 @@ func wireguardPortForward() *cobra.Command { if err != nil { return xerrors.Errorf("create wireguard network: %w", err) } + defer wgn.Close() - err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{ + err = client.PostWireguardPeer(ctx, workspace.ID, peerwg.Handshake{ Recipient: workspaceAgent.ID, NodePublicKey: wgn.NodePrivateKey.Public(), DiscoPublicKey: wgn.DiscoPublicKey, @@ -124,7 +128,6 @@ func wireguardPortForward() *cobra.Command { // Start all listeners. var ( - ctx, cancel = context.WithCancel(cmd.Context()) wg = new(sync.WaitGroup) listeners = make([]net.Listener, len(specs)) closeAllListeners = func() { @@ -136,11 +139,11 @@ func wireguardPortForward() *cobra.Command { } } ) - defer cancel() + defer closeAllListeners() + for i, spec := range specs { l, err := listenAndPortForwardWireguard(ctx, cmd, wgn, wg, spec, workspaceAgent.IPv6.IP()) if err != nil { - closeAllListeners() return err } listeners[i] = l @@ -149,7 +152,10 @@ func wireguardPortForward() *cobra.Command { // Wait for the context to be canceled or for a signal and close // all listeners. var closeErr error + wg.Add(1) go func() { + defer wg.Done() + sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)