Skip to content
Merged
Prev Previous commit
Next Next commit
Fix race the right way
  • Loading branch information
ammario committed Jun 6, 2023
commit ad2c946290107faf01c15d1aa133d0cca903b194
60 changes: 34 additions & 26 deletions cli/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"os/exec"
"path/filepath"
"strings"
"sync"
"time"

"github.com/gen2brain/beeep"
Expand Down Expand Up @@ -77,18 +78,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
if err != nil {
return xerrors.Errorf("error opening %s for logging: %w", logFilePath, err)
}
// HACK: Something was keeping a reference to this file
// after the goroutine ends, leading to the race observed
// here: https://github.com/coder/coder/actions/runs/5178818818/jobs/9331016395.
rd, wr := io.Pipe()
go func() {
_, _ = io.Copy(logFile, rd)
}()
defer func() {
_ = wr.Close()
_ = logFile.Close()
}()
logger = slog.Make(sloghuman.Sink(wr))
logger = slog.Make(sloghuman.Sink(logFile))
if r.verbose {
logger = logger.Leveled(slog.LevelDebug)
}
Expand Down Expand Up @@ -157,15 +147,28 @@ func (r *RootCmd) ssh() *clibase.Cmd {
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
defer stopPolling()

// This WaitGroup solves for a race condition where we were logging
// while closing the log file in in a defer. It probably solves
// others too.
var wg sync.WaitGroup
defer wg.Wait()

if stdio {
rawSSH, err := conn.SSH(ctx)
if err != nil {
return xerrors.Errorf("connect SSH: %w", err)
}
defer rawSSH.Close()
go watchAndClose(ctx, rawSSH.Close, logger, client, workspace)

wg.Add(1)
go func() {
defer wg.Done()
watchAndClose(ctx, rawSSH.Close, logger, client, workspace)
}()

wg.Add(1)
go func() {
defer wg.Done()
// Ensure stdout copy closes incase stdin is closed
// unexpectedly. Typically we wouldn't worry about
// this since OpenSSH should kill the proxy command.
Expand Down Expand Up @@ -198,19 +201,24 @@ func (r *RootCmd) ssh() *clibase.Cmd {
return xerrors.Errorf("ssh session: %w", err)
}
defer sshSession.Close()
go watchAndClose(
ctx,
func() error {
err := sshSession.Close()
logger.Debug(ctx, "session close", slog.Error(err))
err = sshClient.Close()
logger.Debug(ctx, "client close", slog.Error(err))
return nil
},
logger,
client,
workspace,
)

wg.Add(1)
go func() {
defer wg.Done()
watchAndClose(
ctx,
func() error {
err := sshSession.Close()
logger.Debug(ctx, "session close", slog.Error(err))
err = sshClient.Close()
logger.Debug(ctx, "client close", slog.Error(err))
return nil
},
logger,
client,
workspace,
)
}()

if identityAgent == "" {
identityAgent = os.Getenv("SSH_AUTH_SOCK")
Expand Down