Skip to content

fix: Improve shutdown procedure of ssh, portforward, wgtunnel cmds #3354

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions cli/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ func portForward() *cobra.Command {
},
),
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()

specs, err := parsePortForwards(tcpForwards, udpForwards, unixForwards)
if err != nil {
return xerrors.Errorf("parse port-forward specs: %w", err)
Expand All @@ -72,21 +75,21 @@ func portForward() *cobra.Command {
return err
}

workspace, agent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], false)
workspace, agent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false)
if err != nil {
return err
}
if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart {
return xerrors.New("workspace must be in start transition to port-forward")
}
if workspace.LatestBuild.Job.CompletedAt == nil {
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
if err != nil {
return err
}
}

err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{
WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, agent.ID)
Expand All @@ -96,15 +99,14 @@ func portForward() *cobra.Command {
return xerrors.Errorf("await agent: %w", err)
}

conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil)
conn, err := client.DialWorkspaceAgent(ctx, agent.ID, nil)
if err != nil {
return xerrors.Errorf("dial workspace agent: %w", err)
}
defer conn.Close()

// Start all listeners.
var (
ctx, cancel = context.WithCancel(cmd.Context())
wg = new(sync.WaitGroup)
listeners = make([]net.Listener, len(specs))
closeAllListeners = func() {
Expand All @@ -116,11 +118,11 @@ func portForward() *cobra.Command {
}
}
)
defer cancel()
defer closeAllListeners()

for i, spec := range specs {
l, err := listenAndPortForward(ctx, cmd, conn, wg, spec)
if err != nil {
closeAllListeners()
return err
}
listeners[i] = l
Expand All @@ -129,7 +131,10 @@ func portForward() *cobra.Command {
// Wait for the context to be canceled or for a signal and close
// all listeners.
var closeErr error
wg.Add(1)
go func() {
defer wg.Done()

sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

Expand Down
82 changes: 48 additions & 34 deletions cli/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ func ssh() *cobra.Command {
Short: "SSH into a workspace",
Args: cobra.ArbitraryArgs,
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()

client, err := createClient(cmd)
if err != nil {
return err
Expand All @@ -68,14 +71,14 @@ func ssh() *cobra.Command {
}
}

workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle)
workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], shuffle)
if err != nil {
return err
}

// OpenSSH passes stderr directly to the calling TTY.
// This is required in "stdio" mode so a connecting indicator can be displayed.
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{
WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
Expand All @@ -85,42 +88,33 @@ func ssh() *cobra.Command {
return xerrors.Errorf("await agent: %w", err)
}

var (
sshClient *gossh.Client
sshSession *gossh.Session
)
var newSSHClient func() (*gossh.Client, error)

if !wireguard {
conn, err := client.DialWorkspaceAgent(cmd.Context(), workspaceAgent.ID, nil)
conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
if err != nil {
return err
}
defer conn.Close()

stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace)
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
defer stopPolling()

if stdio {
rawSSH, err := conn.SSH()
if err != nil {
return err
}
defer rawSSH.Close()

go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
}()
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
return nil
}

sshClient, err = conn.SSHClient()
if err != nil {
return err
}

sshSession, err = sshClient.NewSession()
if err != nil {
return err
}
newSSHClient = conn.SSHClient
} else {
// TODO: more granual control of Tailscale logging.
peerwg.Logf = tslogger.Discard
Expand All @@ -133,8 +127,9 @@ func ssh() *cobra.Command {
if err != nil {
return xerrors.Errorf("create wireguard network: %w", err)
}
defer wgn.Close()

err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{
err = client.PostWireguardPeer(ctx, workspace.ID, peerwg.Handshake{
Recipient: workspaceAgent.ID,
NodePublicKey: wgn.NodePrivateKey.Public(),
DiscoPublicKey: wgn.DiscoPublicKey,
Expand All @@ -155,10 +150,11 @@ func ssh() *cobra.Command {
}

if stdio {
rawSSH, err := wgn.SSH(cmd.Context(), workspaceAgent.IPv6.IP())
rawSSH, err := wgn.SSH(ctx, workspaceAgent.IPv6.IP())
if err != nil {
return err
}
defer rawSSH.Close()

go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
Expand All @@ -167,16 +163,29 @@ func ssh() *cobra.Command {
return nil
}

sshClient, err = wgn.SSHClient(cmd.Context(), workspaceAgent.IPv6.IP())
if err != nil {
return err
newSSHClient = func() (*gossh.Client, error) {
return wgn.SSHClient(ctx, workspaceAgent.IPv6.IP())
}
}

sshSession, err = sshClient.NewSession()
if err != nil {
return err
}
sshClient, err := newSSHClient()
if err != nil {
return err
}
defer sshClient.Close()

sshSession, err := sshClient.NewSession()
if err != nil {
return err
}
defer sshSession.Close()

// Ensure context cancellation is propagated to the
// SSH session, e.g. to cancel `Wait()` at the end.
go func() {
<-ctx.Done()
_ = sshSession.Close()
}()

if identityAgent == "" {
identityAgent = os.Getenv("SSH_AUTH_SOCK")
Expand All @@ -203,15 +212,18 @@ func ssh() *cobra.Command {
_ = term.Restore(int(stdinFile.Fd()), state)
}()

windowChange := listenWindowSize(cmd.Context())
windowChange := listenWindowSize(ctx)
go func() {
for {
select {
case <-cmd.Context().Done():
case <-ctx.Done():
return
case <-windowChange:
}
width, height, _ := term.GetSize(int(stdoutFile.Fd()))
width, height, err := term.GetSize(int(stdoutFile.Fd()))
if err != nil {
continue
}
_ = sshSession.WindowChange(height, width)
}
}()
Expand All @@ -224,13 +236,17 @@ func ssh() *cobra.Command {

sshSession.Stdin = cmd.InOrStdin()
sshSession.Stdout = cmd.OutOrStdout()
sshSession.Stderr = cmd.OutOrStdout()
sshSession.Stderr = cmd.ErrOrStderr()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a drive-by change I did. It seemed wrong but perhaps I didn't understand the purpose?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems reasonable to me!


err = sshSession.Shell()
if err != nil {
return err
}

// Put cancel at the top of the defer stack to initiate
// shutdown of services.
defer cancel()

err = sshSession.Wait()
if err != nil {
// If the connection drops unexpectedly, we get an ExitMissingError but no other
Expand Down Expand Up @@ -259,16 +275,14 @@ func ssh() *cobra.Command {
// getWorkspaceAgent returns the workspace and agent selected using either the
// `<workspace>[.<agent>]` syntax via `in` or picks a random workspace and agent
// if `shuffle` is true.
func getWorkspaceAndAgent(cmd *cobra.Command, client *codersdk.Client, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive
ctx := cmd.Context()

func getWorkspaceAndAgent(ctx context.Context, cmd *cobra.Command, client *codersdk.Client, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive
var (
workspace codersdk.Workspace
workspaceParts = strings.Split(in, ".")
err error
)
if shuffle {
workspaces, err := client.Workspaces(cmd.Context(), codersdk.WorkspaceFilter{
workspaces, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{
Owner: codersdk.Me,
})
if err != nil {
Expand Down
5 changes: 1 addition & 4 deletions cli/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func TestSSH(t *testing.T) {
pty := ptytest.New(t)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
cmd.SetErr(io.Discard)
cmd.SetErr(pty.Output())
cmdDone := tGo(t, func() {
err := cmd.ExecuteContext(ctx)
assert.NoError(t, err)
Expand All @@ -248,9 +248,6 @@ func TestSSH(t *testing.T) {

// And we're done.
pty.WriteLine("exit")
// Read output to prevent hang on macOS, see:
// https://github.com/coder/coder/issues/2122
pty.ExpectMatch("exit")
<-cmdDone
})
}
Expand Down
20 changes: 13 additions & 7 deletions cli/wireguardtunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ func wireguardPortForward() *cobra.Command {
},
),
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()

specs, err := parsePortForwards(tcpForwards, nil, nil)
if err != nil {
return xerrors.Errorf("parse port-forward specs: %w", err)
Expand All @@ -69,21 +72,21 @@ func wireguardPortForward() *cobra.Command {
return err
}

workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], false)
workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false)
if err != nil {
return err
}
if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart {
return xerrors.New("workspace must be in start transition to port-forward")
}
if workspace.LatestBuild.Job.CompletedAt == nil {
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
if err != nil {
return err
}
}

err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{
WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
Expand All @@ -101,8 +104,9 @@ func wireguardPortForward() *cobra.Command {
if err != nil {
return xerrors.Errorf("create wireguard network: %w", err)
}
defer wgn.Close()

err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{
err = client.PostWireguardPeer(ctx, workspace.ID, peerwg.Handshake{
Recipient: workspaceAgent.ID,
NodePublicKey: wgn.NodePrivateKey.Public(),
DiscoPublicKey: wgn.DiscoPublicKey,
Expand All @@ -124,7 +128,6 @@ func wireguardPortForward() *cobra.Command {

// Start all listeners.
var (
ctx, cancel = context.WithCancel(cmd.Context())
wg = new(sync.WaitGroup)
listeners = make([]net.Listener, len(specs))
closeAllListeners = func() {
Expand All @@ -136,11 +139,11 @@ func wireguardPortForward() *cobra.Command {
}
}
)
defer cancel()
defer closeAllListeners()

for i, spec := range specs {
l, err := listenAndPortForwardWireguard(ctx, cmd, wgn, wg, spec, workspaceAgent.IPv6.IP())
if err != nil {
closeAllListeners()
return err
}
listeners[i] = l
Expand All @@ -149,7 +152,10 @@ func wireguardPortForward() *cobra.Command {
// Wait for the context to be canceled or for a signal and close
// all listeners.
var closeErr error
wg.Add(1)
go func() {
defer wg.Done()

sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

Expand Down