Skip to content

Commit ca14871

Browse files
committed
fix: Improve shutdown procedure of ssh, portforward, wgtunnel cmds
We could turn it into a practice to wrap `cmd.Context()` so that we have more fine-grained control of cancellation. Sometimes in tests we may be running commands with a context that is never canceled. Related to #3221
1 parent 5ae19f0 commit ca14871

File tree

3 files changed

+72
-47
lines changed

3 files changed

+72
-47
lines changed

cli/portforward.go

+12-7
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ func portForward() *cobra.Command {
5555
},
5656
),
5757
RunE: func(cmd *cobra.Command, args []string) error {
58+
ctx, cancel := context.WithCancel(cmd.Context())
59+
defer cancel()
60+
5861
specs, err := parsePortForwards(tcpForwards, udpForwards, unixForwards)
5962
if err != nil {
6063
return xerrors.Errorf("parse port-forward specs: %w", err)
@@ -72,21 +75,21 @@ func portForward() *cobra.Command {
7275
return err
7376
}
7477

75-
workspace, agent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], false)
78+
workspace, agent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false)
7679
if err != nil {
7780
return err
7881
}
7982
if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart {
8083
return xerrors.New("workspace must be in start transition to port-forward")
8184
}
8285
if workspace.LatestBuild.Job.CompletedAt == nil {
83-
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
86+
err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
8487
if err != nil {
8588
return err
8689
}
8790
}
8891

89-
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
92+
err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{
9093
WorkspaceName: workspace.Name,
9194
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
9295
return client.WorkspaceAgent(ctx, agent.ID)
@@ -96,15 +99,14 @@ func portForward() *cobra.Command {
9699
return xerrors.Errorf("await agent: %w", err)
97100
}
98101

99-
conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil)
102+
conn, err := client.DialWorkspaceAgent(ctx, agent.ID, nil)
100103
if err != nil {
101104
return xerrors.Errorf("dial workspace agent: %w", err)
102105
}
103106
defer conn.Close()
104107

105108
// Start all listeners.
106109
var (
107-
ctx, cancel = context.WithCancel(cmd.Context())
108110
wg = new(sync.WaitGroup)
109111
listeners = make([]net.Listener, len(specs))
110112
closeAllListeners = func() {
@@ -116,11 +118,11 @@ func portForward() *cobra.Command {
116118
}
117119
}
118120
)
119-
defer cancel()
121+
defer closeAllListeners()
122+
120123
for i, spec := range specs {
121124
l, err := listenAndPortForward(ctx, cmd, conn, wg, spec)
122125
if err != nil {
123-
closeAllListeners()
124126
return err
125127
}
126128
listeners[i] = l
@@ -129,7 +131,10 @@ func portForward() *cobra.Command {
129131
// Wait for the context to be canceled or for a signal and close
130132
// all listeners.
131133
var closeErr error
134+
wg.Add(1)
132135
go func() {
136+
defer wg.Done()
137+
133138
sigs := make(chan os.Signal, 1)
134139
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
135140

cli/ssh.go

+47-33
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ func ssh() *cobra.Command {
5151
Short: "SSH into a workspace",
5252
Args: cobra.ArbitraryArgs,
5353
RunE: func(cmd *cobra.Command, args []string) error {
54+
ctx, cancel := context.WithCancel(cmd.Context())
55+
defer cancel()
56+
5457
client, err := createClient(cmd)
5558
if err != nil {
5659
return err
@@ -68,14 +71,14 @@ func ssh() *cobra.Command {
6871
}
6972
}
7073

71-
workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle)
74+
workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], shuffle)
7275
if err != nil {
7376
return err
7477
}
7578

7679
// OpenSSH passes stderr directly to the calling TTY.
7780
// This is required in "stdio" mode so a connecting indicator can be displayed.
78-
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
81+
err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{
7982
WorkspaceName: workspace.Name,
8083
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
8184
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
@@ -85,42 +88,33 @@ func ssh() *cobra.Command {
8588
return xerrors.Errorf("await agent: %w", err)
8689
}
8790

88-
var (
89-
sshClient *gossh.Client
90-
sshSession *gossh.Session
91-
)
91+
var newSSHClient func() (*gossh.Client, error)
9292

9393
if !wireguard {
94-
conn, err := client.DialWorkspaceAgent(cmd.Context(), workspaceAgent.ID, nil)
94+
conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
9595
if err != nil {
9696
return err
9797
}
9898
defer conn.Close()
9999

100-
stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace)
100+
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
101101
defer stopPolling()
102102

103103
if stdio {
104104
rawSSH, err := conn.SSH()
105105
if err != nil {
106106
return err
107107
}
108+
defer rawSSH.Close()
109+
108110
go func() {
109111
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
110112
}()
111113
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
112114
return nil
113115
}
114116

115-
sshClient, err = conn.SSHClient()
116-
if err != nil {
117-
return err
118-
}
119-
120-
sshSession, err = sshClient.NewSession()
121-
if err != nil {
122-
return err
123-
}
117+
newSSHClient = conn.SSHClient
124118
} else {
125119
// TODO: more granual control of Tailscale logging.
126120
peerwg.Logf = tslogger.Discard
@@ -133,8 +127,9 @@ func ssh() *cobra.Command {
133127
if err != nil {
134128
return xerrors.Errorf("create wireguard network: %w", err)
135129
}
130+
defer wgn.Close()
136131

137-
err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{
132+
err = client.PostWireguardPeer(ctx, workspace.ID, peerwg.Handshake{
138133
Recipient: workspaceAgent.ID,
139134
NodePublicKey: wgn.NodePrivateKey.Public(),
140135
DiscoPublicKey: wgn.DiscoPublicKey,
@@ -155,10 +150,11 @@ func ssh() *cobra.Command {
155150
}
156151

157152
if stdio {
158-
rawSSH, err := wgn.SSH(cmd.Context(), workspaceAgent.IPv6.IP())
153+
rawSSH, err := wgn.SSH(ctx, workspaceAgent.IPv6.IP())
159154
if err != nil {
160155
return err
161156
}
157+
defer rawSSH.Close()
162158

163159
go func() {
164160
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
@@ -167,16 +163,29 @@ func ssh() *cobra.Command {
167163
return nil
168164
}
169165

170-
sshClient, err = wgn.SSHClient(cmd.Context(), workspaceAgent.IPv6.IP())
171-
if err != nil {
172-
return err
166+
newSSHClient = func() (*gossh.Client, error) {
167+
return wgn.SSHClient(ctx, workspaceAgent.IPv6.IP())
173168
}
169+
}
174170

175-
sshSession, err = sshClient.NewSession()
176-
if err != nil {
177-
return err
178-
}
171+
sshClient, err := newSSHClient()
172+
if err != nil {
173+
return err
174+
}
175+
defer sshClient.Close()
176+
177+
sshSession, err := sshClient.NewSession()
178+
if err != nil {
179+
return err
179180
}
181+
defer sshSession.Close()
182+
183+
// Ensure context cancellation is propagated to the
184+
// SSH session, e.g. to cancel `Wait()` at the end.
185+
go func() {
186+
<-ctx.Done()
187+
_ = sshSession.Close()
188+
}()
180189

181190
if identityAgent == "" {
182191
identityAgent = os.Getenv("SSH_AUTH_SOCK")
@@ -203,15 +212,18 @@ func ssh() *cobra.Command {
203212
_ = term.Restore(int(stdinFile.Fd()), state)
204213
}()
205214

206-
windowChange := listenWindowSize(cmd.Context())
215+
windowChange := listenWindowSize(ctx)
207216
go func() {
208217
for {
209218
select {
210-
case <-cmd.Context().Done():
219+
case <-ctx.Done():
211220
return
212221
case <-windowChange:
213222
}
214-
width, height, _ := term.GetSize(int(stdoutFile.Fd()))
223+
width, height, err := term.GetSize(int(stdoutFile.Fd()))
224+
if err != nil {
225+
continue
226+
}
215227
_ = sshSession.WindowChange(height, width)
216228
}
217229
}()
@@ -231,6 +243,10 @@ func ssh() *cobra.Command {
231243
return err
232244
}
233245

246+
// Put cancel at the top of the defer stack to initiate
247+
// shutdown of services.
248+
defer cancel()
249+
234250
err = sshSession.Wait()
235251
if err != nil {
236252
// If the connection drops unexpectedly, we get an ExitMissingError but no other
@@ -259,16 +275,14 @@ func ssh() *cobra.Command {
259275
// getWorkspaceAgent returns the workspace and agent selected using either the
260276
// `<workspace>[.<agent>]` syntax via `in` or picks a random workspace and agent
261277
// if `shuffle` is true.
262-
func getWorkspaceAndAgent(cmd *cobra.Command, client *codersdk.Client, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive
263-
ctx := cmd.Context()
264-
278+
func getWorkspaceAndAgent(ctx context.Context, cmd *cobra.Command, client *codersdk.Client, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive
265279
var (
266280
workspace codersdk.Workspace
267281
workspaceParts = strings.Split(in, ".")
268282
err error
269283
)
270284
if shuffle {
271-
workspaces, err := client.Workspaces(cmd.Context(), codersdk.WorkspaceFilter{
285+
workspaces, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{
272286
Owner: codersdk.Me,
273287
})
274288
if err != nil {

cli/wireguardtunnel.go

+13-7
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ func wireguardPortForward() *cobra.Command {
5252
},
5353
),
5454
RunE: func(cmd *cobra.Command, args []string) error {
55+
ctx, cancel := context.WithCancel(cmd.Context())
56+
defer cancel()
57+
5558
specs, err := parsePortForwards(tcpForwards, nil, nil)
5659
if err != nil {
5760
return xerrors.Errorf("parse port-forward specs: %w", err)
@@ -69,21 +72,21 @@ func wireguardPortForward() *cobra.Command {
6972
return err
7073
}
7174

72-
workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], false)
75+
workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false)
7376
if err != nil {
7477
return err
7578
}
7679
if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart {
7780
return xerrors.New("workspace must be in start transition to port-forward")
7881
}
7982
if workspace.LatestBuild.Job.CompletedAt == nil {
80-
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
83+
err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
8184
if err != nil {
8285
return err
8386
}
8487
}
8588

86-
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
89+
err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{
8790
WorkspaceName: workspace.Name,
8891
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
8992
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
@@ -101,8 +104,9 @@ func wireguardPortForward() *cobra.Command {
101104
if err != nil {
102105
return xerrors.Errorf("create wireguard network: %w", err)
103106
}
107+
defer wgn.Close()
104108

105-
err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{
109+
err = client.PostWireguardPeer(ctx, workspace.ID, peerwg.Handshake{
106110
Recipient: workspaceAgent.ID,
107111
NodePublicKey: wgn.NodePrivateKey.Public(),
108112
DiscoPublicKey: wgn.DiscoPublicKey,
@@ -124,7 +128,6 @@ func wireguardPortForward() *cobra.Command {
124128

125129
// Start all listeners.
126130
var (
127-
ctx, cancel = context.WithCancel(cmd.Context())
128131
wg = new(sync.WaitGroup)
129132
listeners = make([]net.Listener, len(specs))
130133
closeAllListeners = func() {
@@ -136,11 +139,11 @@ func wireguardPortForward() *cobra.Command {
136139
}
137140
}
138141
)
139-
defer cancel()
142+
defer closeAllListeners()
143+
140144
for i, spec := range specs {
141145
l, err := listenAndPortForwardWireguard(ctx, cmd, wgn, wg, spec, workspaceAgent.IPv6.IP())
142146
if err != nil {
143-
closeAllListeners()
144147
return err
145148
}
146149
listeners[i] = l
@@ -149,7 +152,10 @@ func wireguardPortForward() *cobra.Command {
149152
// Wait for the context to be canceled or for a signal and close
150153
// all listeners.
151154
var closeErr error
155+
wg.Add(1)
152156
go func() {
157+
defer wg.Done()
158+
153159
sigs := make(chan os.Signal, 1)
154160
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
155161

0 commit comments

Comments
 (0)