Skip to content

Commit c2871e1

Browse files
authored
fix(cli/ssh): Avoid connection hang when workspace is stopped (#7201)
* fix(cli/ssh): Avoid connection hang when workspace is stopped. Two issues are addressed here: 1. We were not detecting disconnects because we were waiting for Stdin to close (a disconnect would only propagate after entering input and failing to write to the connection). 2. In other scenarios where the connection drop is not detected, we now also watch the workspace status and drop the connection when the workspace reaches the stopped state. Fixes: https://github.com/coder/jetbrains-coder/issues/199 Refs: #6180, #6175
1 parent fff2b1d commit c2871e1

File tree

2 files changed

+187
-8
lines changed

2 files changed

+187
-8
lines changed

cli/ssh.go

Lines changed: 72 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"github.com/coder/coder/coderd/util/ptr"
3131
"github.com/coder/coder/codersdk"
3232
"github.com/coder/coder/cryptorand"
33+
"github.com/coder/retry"
3334
)
3435

3536
var (
@@ -100,17 +101,82 @@ func (r *RootCmd) ssh() *clibase.Cmd {
100101
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
101102
defer stopPolling()
102103

104+
// Ensure connection is closed if the context is canceled or
105+
// the workspace reaches the stopped state.
106+
//
107+
// Watching the stopped state is a work-around for cases
108+
// where the agent is not gracefully shut down and the
109+
// connection is left open. If, for instance, the networking
110+
// is stopped before the agent is shut down, the disconnect
111+
// will usually not propagate.
112+
//
113+
// See: https://github.com/coder/coder/issues/6180
114+
watchAndClose := func(closer func() error) {
115+
// Ensure session is ended on both context cancellation
116+
// and workspace stop.
117+
defer func() {
118+
_ = closer()
119+
}()
120+
121+
startWatchLoop:
122+
for {
123+
// (Re)connect to the coder server and watch workspace events.
124+
var wsWatch <-chan codersdk.Workspace
125+
var err error
126+
for r := retry.New(time.Second, 15*time.Second); r.Wait(ctx); {
127+
wsWatch, err = client.WatchWorkspace(ctx, workspace.ID)
128+
if err == nil {
129+
break
130+
}
131+
if ctx.Err() != nil {
132+
return
133+
}
134+
}
135+
136+
for {
137+
select {
138+
case <-ctx.Done():
139+
return
140+
case w, ok := <-wsWatch:
141+
if !ok {
142+
continue startWatchLoop
143+
}
144+
145+
// Transitioning to stop or delete could mean that
146+
// the agent will still gracefully stop. If a new
147+
// build is starting, there's no reason to wait for
148+
// the agent, it should be long gone.
149+
if workspace.LatestBuild.ID != w.LatestBuild.ID && w.LatestBuild.Transition == codersdk.WorkspaceTransitionStart {
150+
return
151+
}
152+
// Note, we only react to the stopped state here because we
153+
// want to give the agent a chance to gracefully shut down
154+
// during "stopping".
155+
if w.LatestBuild.Status == codersdk.WorkspaceStatusStopped {
156+
return
157+
}
158+
}
159+
}
160+
}
161+
}
162+
103163
if stdio {
104164
rawSSH, err := conn.SSH(ctx)
105165
if err != nil {
106166
return err
107167
}
108168
defer rawSSH.Close()
169+
go watchAndClose(rawSSH.Close)
109170

110171
go func() {
111-
_, _ = io.Copy(inv.Stdout, rawSSH)
172+
// Ensure stdout copy closes in case stdin is closed
173+
// unexpectedly. Typically we wouldn't worry about
174+
// this since OpenSSH should kill the proxy command.
175+
defer rawSSH.Close()
176+
177+
_, _ = io.Copy(rawSSH, inv.Stdin)
112178
}()
113-
_, _ = io.Copy(rawSSH, inv.Stdin)
179+
_, _ = io.Copy(inv.Stdout, rawSSH)
114180
return nil
115181
}
116182

@@ -125,13 +191,11 @@ func (r *RootCmd) ssh() *clibase.Cmd {
125191
return err
126192
}
127193
defer sshSession.Close()
128-
129-
// Ensure context cancellation is propagated to the
130-
// SSH session, e.g. to cancel `Wait()` at the end.
131-
go func() {
132-
<-ctx.Done()
194+
go watchAndClose(func() error {
133195
_ = sshSession.Close()
134-
}()
196+
_ = sshClient.Close()
197+
return nil
198+
})
135199

136200
if identityAgent == "" {
137201
identityAgent = os.Getenv("SSH_AUTH_SOCK")

cli/ssh_test.go

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"github.com/coder/coder/cli/clitest"
3232
"github.com/coder/coder/cli/cliui"
3333
"github.com/coder/coder/coderd/coderdtest"
34+
"github.com/coder/coder/coderd/database"
3435
"github.com/coder/coder/codersdk"
3536
"github.com/coder/coder/codersdk/agentsdk"
3637
"github.com/coder/coder/provisioner/echo"
@@ -143,6 +144,50 @@ func TestSSH(t *testing.T) {
143144
cancel()
144145
<-cmdDone
145146
})
147+
148+
t.Run("ExitOnStop", func(t *testing.T) {
149+
t.Parallel()
150+
if runtime.GOOS == "windows" {
151+
t.Skip("Windows doesn't seem to clean up the process, maybe #7100 will fix it")
152+
}
153+
154+
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
155+
inv, root := clitest.New(t, "ssh", workspace.Name)
156+
clitest.SetupConfig(t, client, root)
157+
pty := ptytest.New(t).Attach(inv)
158+
159+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
160+
defer cancel()
161+
162+
cmdDone := tGo(t, func() {
163+
err := inv.WithContext(ctx).Run()
164+
assert.Error(t, err)
165+
})
166+
pty.ExpectMatch("Waiting")
167+
168+
agentClient := agentsdk.New(client.URL)
169+
agentClient.SetSessionToken(agentToken)
170+
agentCloser := agent.New(agent.Options{
171+
Client: agentClient,
172+
Logger: slogtest.Make(t, nil).Named("agent"),
173+
})
174+
defer func() {
175+
_ = agentCloser.Close()
176+
}()
177+
178+
// Ensure the agent is connected.
179+
pty.WriteLine("echo hell'o'")
180+
pty.ExpectMatchContext(ctx, "hello")
181+
182+
workspace = coderdtest.MustTransitionWorkspace(t, client, workspace.ID, database.WorkspaceTransitionStart, database.WorkspaceTransitionStop)
183+
184+
select {
185+
case <-cmdDone:
186+
case <-ctx.Done():
187+
require.Fail(t, "command did not exit in time")
188+
}
189+
})
190+
146191
t.Run("Stdio", func(t *testing.T) {
147192
t.Parallel()
148193
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
@@ -207,6 +252,76 @@ func TestSSH(t *testing.T) {
207252

208253
<-cmdDone
209254
})
255+
256+
t.Run("StdioExitOnStop", func(t *testing.T) {
257+
t.Parallel()
258+
if runtime.GOOS == "windows" {
259+
t.Skip("Windows doesn't seem to clean up the process, maybe #7100 will fix it")
260+
}
261+
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
262+
_, _ = tGoContext(t, func(ctx context.Context) {
263+
// Run this async so the SSH command has to wait for
264+
// the build and agent to connect!
265+
agentClient := agentsdk.New(client.URL)
266+
agentClient.SetSessionToken(agentToken)
267+
agentCloser := agent.New(agent.Options{
268+
Client: agentClient,
269+
Logger: slogtest.Make(t, nil).Named("agent"),
270+
})
271+
<-ctx.Done()
272+
_ = agentCloser.Close()
273+
})
274+
275+
clientOutput, clientInput := io.Pipe()
276+
serverOutput, serverInput := io.Pipe()
277+
defer func() {
278+
for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} {
279+
_ = c.Close()
280+
}
281+
}()
282+
283+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
284+
defer cancel()
285+
286+
inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
287+
clitest.SetupConfig(t, client, root)
288+
inv.Stdin = clientOutput
289+
inv.Stdout = serverInput
290+
inv.Stderr = io.Discard
291+
cmdDone := tGo(t, func() {
292+
err := inv.WithContext(ctx).Run()
293+
assert.NoError(t, err)
294+
})
295+
296+
conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
297+
Reader: serverOutput,
298+
Writer: clientInput,
299+
}, "", &ssh.ClientConfig{
300+
// #nosec
301+
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
302+
})
303+
require.NoError(t, err)
304+
defer conn.Close()
305+
306+
sshClient := ssh.NewClient(conn, channels, requests)
307+
defer sshClient.Close()
308+
309+
session, err := sshClient.NewSession()
310+
require.NoError(t, err)
311+
defer session.Close()
312+
313+
err = session.Shell()
314+
require.NoError(t, err)
315+
316+
workspace = coderdtest.MustTransitionWorkspace(t, client, workspace.ID, database.WorkspaceTransitionStart, database.WorkspaceTransitionStop)
317+
318+
select {
319+
case <-cmdDone:
320+
case <-ctx.Done():
321+
require.Fail(t, "command did not exit in time")
322+
}
323+
})
324+
210325
t.Run("ForwardAgent", func(t *testing.T) {
211326
if runtime.GOOS == "windows" {
212327
t.Skip("Test not supported on windows")

0 commit comments

Comments
 (0)