From 37adb622eb7519b4cd275c0f2658bf2df96423d6 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 4 Apr 2023 14:34:04 +0000 Subject: [PATCH 1/7] refactor(agent): Move SSH server into agentssh package Refs: #6177 --- agent/agent.go | 568 ++----------------------- agent/agent_test.go | 11 +- agent/agentssh/agentssh.go | 576 ++++++++++++++++++++++++++ agent/agentssh/agentssh_test.go | 136 ++++++ agent/agentssh/bicopy.go | 47 +++ agent/{ssh.go => agentssh/forward.go} | 2 +- cli/portforward.go | 4 +- cli/ssh.go | 4 +- coderd/workspaceagents.go | 4 +- 9 files changed, 804 insertions(+), 548 deletions(-) create mode 100644 agent/agentssh/agentssh.go create mode 100644 agent/agentssh/agentssh_test.go create mode 100644 agent/agentssh/bicopy.go rename agent/{ssh.go => agentssh/forward.go} (99%) diff --git a/agent/agent.go b/agent/agent.go index e22d5c3576123..3906a182139ad 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -4,8 +4,6 @@ import ( "bufio" "bytes" "context" - "crypto/rand" - "crypto/rsa" "encoding/binary" "encoding/json" "errors" @@ -16,11 +14,9 @@ import ( "net/http" "net/netip" "os" - "os/exec" "os/user" "path/filepath" "reflect" - "runtime" "sort" "strconv" "strings" @@ -28,12 +24,9 @@ import ( "time" "github.com/armon/circbuf" - "github.com/gliderlabs/ssh" "github.com/google/uuid" - "github.com/pkg/sftp" "github.com/spf13/afero" "go.uber.org/atomic" - gossh "golang.org/x/crypto/ssh" "golang.org/x/exp/slices" "golang.org/x/xerrors" "tailscale.com/net/speedtest" @@ -41,7 +34,7 @@ import ( "tailscale.com/types/netlogtype" "cdr.dev/slog" - "github.com/coder/coder/agent/usershell" + "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/buildinfo" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/gitauth" @@ -56,19 +49,6 @@ const ( ProtocolReconnectingPTY = "reconnecting-pty" ProtocolSSH = "ssh" ProtocolDial = "dial" - - // MagicSessionErrorCode indicates that something went wrong with the session, rather than the - // command just returning a nonzero exit code, and is chosen as an arbitrary, high number - // unlikely to shadow other exit codes, which are typically 1, 2, 3, etc. - MagicSessionErrorCode = 229 - - // MagicSSHSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection. - // This is stripped from any commands being executed, and is counted towards connection stats. - MagicSSHSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE" - // MagicSSHSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself. - MagicSSHSessionTypeVSCode = "vscode" - // MagicSSHSessionTypeJetBrains is set in the SSH config by the JetBrains extension to identify itself. - MagicSSHSessionTypeJetBrains = "jetbrains" ) type Options struct { @@ -165,7 +145,7 @@ type agent struct { // manifest is atomic because values can change after reconnection. manifest atomic.Pointer[agentsdk.Manifest] sessionToken atomic.Pointer[string] - sshServer *ssh.Server + sshServer *agentssh.Server sshMaxTimeout time.Duration lifecycleUpdate chan struct{} @@ -177,10 +157,19 @@ type agent struct { connStatsChan chan *agentsdk.Stats latestStat atomic.Pointer[agentsdk.Stats] - connCountVSCode atomic.Int64 - connCountJetBrains atomic.Int64 connCountReconnectingPTY atomic.Int64 - connCountSSHSession atomic.Int64 +} + +func (a *agent) init(ctx context.Context) { + sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.sshMaxTimeout) + if err != nil { + panic(err) + } + sshSrv.Env = a.envVars + sshSrv.AgentToken = func() string { return *a.sessionToken.Load() } + a.sshServer = sshSrv + + go a.runLoop(ctx) } // runLoop attempts to start the agent in a retry loop. @@ -223,7 +212,7 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM // if it is certain the clocks are in sync. CollectedAt: time.Now(), } - cmd, err := a.createCommand(ctx, md.Script, nil) + cmd, err := a.sshServer.CreateCommand(ctx, md.Script, nil) if err != nil { result.Error = err.Error() return result @@ -489,6 +478,7 @@ func (a *agent) run(ctx context.Context) error { } oldManifest := a.manifest.Swap(&manifest) + a.sshServer.SetManifest(&manifest) // The startup script should only execute on the first run! if oldManifest == nil { @@ -633,28 +623,7 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ } }() if err = a.trackConnGoroutine(func() { - var wg sync.WaitGroup - for { - conn, err := sshListener.Accept() - if err != nil { - break - } - wg.Add(1) - closed := make(chan struct{}) - go func() { - select { - case <-closed: - case <-a.closed: - _ = conn.Close() - } - wg.Done() - }() - go func() { - defer close(closed) - a.sshServer.HandleConn(conn) - }() - } - wg.Wait() + _ = a.sshServer.Serve(sshListener) }); err != nil { return nil, err } @@ -857,7 +826,7 @@ func (a *agent) runScript(ctx context.Context, lifecycle, script string) error { }() } - cmd, err := a.createCommand(ctx, script, nil) + cmd, err := a.sshServer.CreateCommand(ctx, script, nil) if err != nil { return xerrors.Errorf("create command: %w", err) } @@ -990,394 +959,6 @@ func (a *agent) trackScriptLogs(ctx context.Context, reader io.Reader) (chan str return logsFinished, nil } -func (a *agent) init(ctx context.Context) { - // Clients' should ignore the host key when connecting. - // The agent needs to authenticate with coderd to SSH, - // so SSH authentication doesn't improve security. - randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - panic(err) - } - randomSigner, err := gossh.NewSignerFromKey(randomHostKey) - if err != nil { - panic(err) - } - - sshLogger := a.logger.Named("ssh-server") - forwardHandler := &ssh.ForwardedTCPHandler{} - unixForwardHandler := &forwardedUnixHandler{log: a.logger} - - a.sshServer = &ssh.Server{ - ChannelHandlers: map[string]ssh.ChannelHandler{ - "direct-tcpip": ssh.DirectTCPIPHandler, - "direct-streamlocal@openssh.com": directStreamLocalHandler, - "session": ssh.DefaultSessionHandler, - }, - ConnectionFailedCallback: func(conn net.Conn, err error) { - sshLogger.Info(ctx, "ssh connection ended", slog.Error(err)) - }, - Handler: func(session ssh.Session) { - err := a.handleSSHSession(session) - var exitError *exec.ExitError - if xerrors.As(err, &exitError) { - a.logger.Debug(ctx, "ssh session returned", slog.Error(exitError)) - _ = session.Exit(exitError.ExitCode()) - return - } - if err != nil { - a.logger.Warn(ctx, "ssh session failed", slog.Error(err)) - // This exit code is designed to be unlikely to be confused for a legit exit code - // from the process. - _ = session.Exit(MagicSessionErrorCode) - return - } - _ = session.Exit(0) - }, - HostSigners: []ssh.Signer{randomSigner}, - LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { - // Allow local port forwarding all! - sshLogger.Debug(ctx, "local port forward", - slog.F("destination-host", destinationHost), - slog.F("destination-port", destinationPort)) - return true - }, - PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { - return true - }, - ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { - // Allow reverse port forwarding all! - sshLogger.Debug(ctx, "local port forward", - slog.F("bind-host", bindHost), - slog.F("bind-port", bindPort)) - return true - }, - RequestHandlers: map[string]ssh.RequestHandler{ - "tcpip-forward": forwardHandler.HandleSSHRequest, - "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, - "streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, - "cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, - }, - ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { - return &gossh.ServerConfig{ - NoClientAuth: true, - } - }, - SubsystemHandlers: map[string]ssh.SubsystemHandler{ - "sftp": func(session ssh.Session) { - ctx := session.Context() - - // Typically sftp sessions don't request a TTY, but if they do, - // we must ensure the gliderlabs/ssh CRLF emulation is disabled. - // Otherwise sftp will be broken. This can happen if a user sets - // `RequestTTY force` in their SSH config. - session.DisablePTYEmulation() - - var opts []sftp.ServerOption - // Change current working directory to the users home - // directory so that SFTP connections land there. - homedir, err := userHomeDir() - if err != nil { - sshLogger.Warn(ctx, "get sftp working directory failed, unable to get home dir", slog.Error(err)) - } else { - opts = append(opts, sftp.WithServerWorkingDirectory(homedir)) - } - - server, err := sftp.NewServer(session, opts...) - if err != nil { - sshLogger.Debug(ctx, "initialize sftp server", slog.Error(err)) - return - } - defer server.Close() - - err = server.Serve() - if errors.Is(err, io.EOF) { - // Unless we call `session.Exit(0)` here, the client won't - // receive `exit-status` because `(*sftp.Server).Close()` - // calls `Close()` on the underlying connection (session), - // which actually calls `channel.Close()` because it isn't - // wrapped. This causes sftp clients to receive a non-zero - // exit code. Typically sftp clients don't echo this exit - // code but `scp` on macOS does (when using the default - // SFTP backend). - _ = session.Exit(0) - return - } - sshLogger.Warn(ctx, "sftp server closed with error", slog.Error(err)) - _ = session.Exit(1) - }, - }, - MaxTimeout: a.sshMaxTimeout, - } - - go a.runLoop(ctx) -} - -// createCommand processes raw command input with OpenSSH-like behavior. -// If the script provided is empty, it will default to the users shell. -// This injects environment variables specified by the user at launch too. -func (a *agent) createCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) { - currentUser, err := user.Current() - if err != nil { - return nil, xerrors.Errorf("get current user: %w", err) - } - username := currentUser.Username - - shell, err := usershell.Get(username) - if err != nil { - return nil, xerrors.Errorf("get user shell: %w", err) - } - - manifest := a.manifest.Load() - if manifest == nil { - return nil, xerrors.Errorf("no metadata was provided") - } - - // OpenSSH executes all commands with the users current shell. - // We replicate that behavior for IDE support. - caller := "-c" - if runtime.GOOS == "windows" { - caller = "/c" - } - args := []string{caller, script} - - // gliderlabs/ssh returns a command slice of zero - // when a shell is requested. - if len(script) == 0 { - args = []string{} - if runtime.GOOS != "windows" { - // On Linux and macOS, we should start a login - // shell to consume juicy environment variables! - args = append(args, "-l") - } - } - - cmd := exec.CommandContext(ctx, shell, args...) - cmd.Dir = manifest.Directory - - // If the metadata directory doesn't exist, we run the command - // in the users home directory. - _, err = os.Stat(cmd.Dir) - if cmd.Dir == "" || err != nil { - // Default to user home if a directory is not set. - homedir, err := userHomeDir() - if err != nil { - return nil, xerrors.Errorf("get home dir: %w", err) - } - cmd.Dir = homedir - } - cmd.Env = append(os.Environ(), env...) - executablePath, err := os.Executable() - if err != nil { - return nil, xerrors.Errorf("getting os executable: %w", err) - } - // Set environment variables reliable detection of being inside a - // Coder workspace. - cmd.Env = append(cmd.Env, "CODER=true") - cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username)) - // Git on Windows resolves with UNIX-style paths. - // If using backslashes, it's unable to find the executable. - unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/") - cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath)) - - // Specific Coder subcommands require the agent token exposed! - cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", *a.sessionToken.Load())) - - // Set SSH connection environment variables (these are also set by OpenSSH - // and thus expected to be present by SSH clients). Since the agent does - // networking in-memory, trying to provide accurate values here would be - // nonsensical. For now, we hard code these values so that they're present. - srcAddr, srcPort := "0.0.0.0", "0" - dstAddr, dstPort := "0.0.0.0", "0" - cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CLIENT=%s %s %s", srcAddr, srcPort, dstPort)) - cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", srcAddr, srcPort, dstAddr, dstPort)) - - // This adds the ports dialog to code-server that enables - // proxying a port dynamically. - cmd.Env = append(cmd.Env, fmt.Sprintf("VSCODE_PROXY_URI=%s", manifest.VSCodePortProxyURI)) - - // Hide Coder message on code-server's "Getting Started" page - cmd.Env = append(cmd.Env, "CS_DISABLE_GETTING_STARTED_OVERRIDE=true") - - // Load environment variables passed via the agent. - // These should override all variables we manually specify. - for envKey, value := range manifest.EnvironmentVariables { - // Expanding environment variables allows for customization - // of the $PATH, among other variables. Customers can prepend - // or append to the $PATH, so allowing expand is required! - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, os.ExpandEnv(value))) - } - - // Agent-level environment variables should take over all! - // This is used for setting agent-specific variables like "CODER_AGENT_TOKEN". - for envKey, value := range a.envVars { - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value)) - } - - return cmd, nil -} - -func (a *agent) handleSSHSession(session ssh.Session) (retErr error) { - ctx := session.Context() - env := session.Environ() - var magicType string - for index, kv := range env { - if !strings.HasPrefix(kv, MagicSSHSessionTypeEnvironmentVariable) { - continue - } - magicType = strings.TrimPrefix(kv, MagicSSHSessionTypeEnvironmentVariable+"=") - env = append(env[:index], env[index+1:]...) - } - switch magicType { - case MagicSSHSessionTypeVSCode: - a.connCountVSCode.Add(1) - defer a.connCountVSCode.Add(-1) - case MagicSSHSessionTypeJetBrains: - a.connCountJetBrains.Add(1) - defer a.connCountJetBrains.Add(-1) - case "": - a.connCountSSHSession.Add(1) - defer a.connCountSSHSession.Add(-1) - default: - a.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType)) - } - - cmd, err := a.createCommand(ctx, session.RawCommand(), env) - if err != nil { - return err - } - - if ssh.AgentRequested(session) { - l, err := ssh.NewAgentListener() - if err != nil { - return xerrors.Errorf("new agent listener: %w", err) - } - defer l.Close() - go ssh.ForwardAgentConnections(l, session) - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String())) - } - - sshPty, windowSize, isPty := session.Pty() - if isPty { - // Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL). - // See https://github.com/coder/coder/issues/3371. - session.DisablePTYEmulation() - - if !isQuietLogin(session.RawCommand()) { - manifest := a.manifest.Load() - if manifest != nil { - err = showMOTD(session, manifest.MOTDFile) - if err != nil { - a.logger.Error(ctx, "show MOTD", slog.Error(err)) - } - } else { - a.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") - } - } - - cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) - - // The pty package sets `SSH_TTY` on supported platforms. - ptty, process, err := pty.Start(cmd, pty.WithPTYOption( - pty.WithSSHRequest(sshPty), - pty.WithLogger(slog.Stdlib(ctx, a.logger, slog.LevelInfo)), - )) - if err != nil { - return xerrors.Errorf("start command: %w", err) - } - var wg sync.WaitGroup - defer func() { - defer wg.Wait() - closeErr := ptty.Close() - if closeErr != nil { - a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) - if retErr == nil { - retErr = closeErr - } - } - }() - go func() { - for win := range windowSize { - resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) - // If the pty is closed, then command has exited, no need to log. - if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) { - a.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) - } - } - }() - // We don't add input copy to wait group because - // it won't return until the session is closed. - go func() { - _, _ = io.Copy(ptty.Input(), session) - }() - - // In low parallelism scenarios, the command may exit and we may close - // the pty before the output copy has started. This can result in the - // output being lost. To avoid this, we wait for the output copy to - // start before waiting for the command to exit. This ensures that the - // output copy goroutine will be scheduled before calling close on the - // pty. This shouldn't be needed because of `pty.Dup()` below, but it - // may not be supported on all platforms. - outputCopyStarted := make(chan struct{}) - ptyOutput := func() io.ReadCloser { - defer close(outputCopyStarted) - // Try to dup so we can separate stdin and stdout closure. - // Once the original pty is closed, the dup will return - // input/output error once the buffered data has been read. - stdout, err := ptty.Dup() - if err == nil { - return stdout - } - // If we can't dup, we shouldn't close - // the fd since it's tied to stdin. - return readNopCloser{ptty.Output()} - } - wg.Add(1) - go func() { - // Ensure data is flushed to session on command exit, if we - // close the session too soon, we might lose data. - defer wg.Done() - - stdout := ptyOutput() - defer stdout.Close() - - _, _ = io.Copy(session, stdout) - }() - <-outputCopyStarted - - err = process.Wait() - var exitErr *exec.ExitError - // ExitErrors just mean the command we run returned a non-zero exit code, which is normal - // and not something to be concerned about. But, if it's something else, we should log it. - if err != nil && !xerrors.As(err, &exitErr) { - a.logger.Warn(ctx, "wait error", slog.Error(err)) - } - return err - } - - cmd.Stdout = session - cmd.Stderr = session.Stderr() - // This blocks forever until stdin is received if we don't - // use StdinPipe. It's unknown what causes this. - stdinPipe, err := cmd.StdinPipe() - if err != nil { - return xerrors.Errorf("create stdin pipe: %w", err) - } - go func() { - _, _ = io.Copy(stdinPipe, session) - _ = stdinPipe.Close() - }() - err = cmd.Start() - if err != nil { - return xerrors.Errorf("start: %w", err) - } - return cmd.Wait() -} - -type readNopCloser struct{ io.Reader } - -// Close implements io.Closer. -func (readNopCloser) Close() error { return nil } - func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg codersdk.WorkspaceAgentReconnectingPTYInit, conn net.Conn) (retErr error) { defer conn.Close() @@ -1416,7 +997,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m logger.Debug(ctx, "creating new session") // Empty command will default to the users shell! - cmd, err := a.createCommand(ctx, msg.Command, nil) + cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil) if err != nil { return xerrors.Errorf("create command: %w", err) } @@ -1590,9 +1171,11 @@ func (a *agent) startReportingConnectionStats(ctx context.Context) { } // The count of active sessions. - stats.SessionCountSSH = a.connCountSSHSession.Load() - stats.SessionCountVSCode = a.connCountVSCode.Load() - stats.SessionCountJetBrains = a.connCountJetBrains.Load() + sshStats := a.sshServer.ConnStats() + stats.SessionCountSSH = sshStats.Sessions + stats.SessionCountVSCode = sshStats.VSCode + stats.SessionCountJetBrains = sshStats.JetBrains + stats.SessionCountReconnectingPTY = a.connCountReconnectingPTY.Load() // Compute the median connection latency! @@ -1692,8 +1275,16 @@ func (a *agent) Close() error { } ctx := context.Background() + a.logger.Info(ctx, "shutting down agent") a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShuttingDown) + // Attempt to gracefully shut down all active SSH connections and + // stop accepting new ones. + err := a.sshServer.Shutdown(ctx) + if err != nil { + a.logger.Error(ctx, "ssh server shutdown", slog.Error(err)) + } + lifecycleState := codersdk.WorkspaceAgentLifecycleOff if manifest := a.manifest.Load(); manifest != nil && manifest.ShutdownScript != "" { scriptDone := make(chan error, 1) @@ -1785,101 +1376,6 @@ func (r *reconnectingPTY) Close() { r.timeout.Stop() } -// Bicopy copies all of the data between the two connections and will close them -// after one or both of them are done writing. If the context is canceled, both -// of the connections will be closed. -func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - defer func() { - _ = c1.Close() - _ = c2.Close() - }() - - var wg sync.WaitGroup - copyFunc := func(dst io.WriteCloser, src io.Reader) { - defer func() { - wg.Done() - // If one side of the copy fails, ensure the other one exits as - // well. - cancel() - }() - _, _ = io.Copy(dst, src) - } - - wg.Add(2) - go copyFunc(c1, c2) - go copyFunc(c2, c1) - - // Convert waitgroup to a channel so we can also wait on the context. - done := make(chan struct{}) - go func() { - defer close(done) - wg.Wait() - }() - - select { - case <-ctx.Done(): - case <-done: - } -} - -// isQuietLogin checks if the SSH server should perform a quiet login or not. -// -// https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L816 -func isQuietLogin(rawCommand string) bool { - // We are always quiet unless this is a login shell. - if len(rawCommand) != 0 { - return true - } - - // Best effort, if we can't get the home directory, - // we can't lookup .hushlogin. - homedir, err := userHomeDir() - if err != nil { - return false - } - - _, err = os.Stat(filepath.Join(homedir, ".hushlogin")) - return err == nil -} - -// showMOTD will output the message of the day from -// the given filename to dest, if the file exists. -// -// https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L784 -func showMOTD(dest io.Writer, filename string) error { - if filename == "" { - return nil - } - - f, err := os.Open(filename) - if err != nil { - if xerrors.Is(err, os.ErrNotExist) { - // This is not an error, there simply isn't a MOTD to show. - return nil - } - return xerrors.Errorf("open MOTD: %w", err) - } - defer f.Close() - - s := bufio.NewScanner(f) - for s.Scan() { - // Carriage return ensures each line starts - // at the beginning of the terminal. - _, err = fmt.Fprint(dest, s.Text()+"\r\n") - if err != nil { - return xerrors.Errorf("write MOTD: %w", err) - } - } - if err := s.Err(); err != nil { - return xerrors.Errorf("read MOTD: %w", err) - } - - return nil -} - // userHomeDir returns the home directory of the current user, giving // priority to the $HOME environment variable. func userHomeDir() (string, error) { diff --git a/agent/agent_test.go b/agent/agent_test.go index ec76aa1b0b6b9..8d7d641e1f73d 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -41,6 +41,7 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" + "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk/agentsdk" @@ -131,13 +132,13 @@ func TestAgent_Stats_Magic(t *testing.T) { defer sshClient.Close() session, err := sshClient.NewSession() require.NoError(t, err) - session.Setenv(agent.MagicSSHSessionTypeEnvironmentVariable, agent.MagicSSHSessionTypeVSCode) + session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, agentssh.MagicSessionTypeVSCode) defer session.Close() - command := "sh -c 'echo $" + agent.MagicSSHSessionTypeEnvironmentVariable + "'" + command := "sh -c 'echo $" + agentssh.MagicSessionTypeEnvironmentVariable + "'" expected := "" if runtime.GOOS == "windows" { - expected = "%" + agent.MagicSSHSessionTypeEnvironmentVariable + "%" + expected = "%" + agentssh.MagicSessionTypeEnvironmentVariable + "%" command = "cmd.exe /c echo " + expected } output, err := session.Output(command) @@ -158,7 +159,7 @@ func TestAgent_Stats_Magic(t *testing.T) { defer sshClient.Close() session, err := sshClient.NewSession() require.NoError(t, err) - session.Setenv(agent.MagicSSHSessionTypeEnvironmentVariable, agent.MagicSSHSessionTypeVSCode) + session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, agentssh.MagicSessionTypeVSCode) defer session.Close() stdin, err := session.StdinPipe() require.NoError(t, err) @@ -1595,7 +1596,7 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe } waitGroup.Add(1) go func() { - agent.Bicopy(context.Background(), conn, ssh) + agentssh.Bicopy(context.Background(), conn, ssh) waitGroup.Done() }() } diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go new file mode 100644 index 0000000000000..94e716b260fbe --- /dev/null +++ b/agent/agentssh/agentssh.go @@ -0,0 +1,576 @@ +package agentssh + +import ( + "bufio" + "context" + "crypto/rand" + "crypto/rsa" + "errors" + "fmt" + "io" + "net" + "os" + "os/exec" + "os/user" + "path/filepath" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/gliderlabs/ssh" + "github.com/pkg/sftp" + gossh "golang.org/x/crypto/ssh" + "golang.org/x/xerrors" + + "cdr.dev/slog" + + "github.com/coder/coder/agent/usershell" + "github.com/coder/coder/codersdk/agentsdk" + "github.com/coder/coder/pty" +) + +const ( + // MagicSessionErrorCode indicates that something went wrong with the session, rather than the + // command just returning a nonzero exit code, and is chosen as an arbitrary, high number + // unlikely to shadow other exit codes, which are typically 1, 2, 3, etc. + MagicSessionErrorCode = 229 + + // MagicSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection. + // This is stripped from any commands being executed, and is counted towards connection stats. + MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE" + // MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself. + MagicSessionTypeVSCode = "vscode" + // MagicSessionTypeJetBrains is set in the SSH config by the JetBrains extension to identify itself. + MagicSessionTypeJetBrains = "jetbrains" +) + +type Server struct { + ctx context.Context + cancel context.CancelFunc + serveWg sync.WaitGroup + logger slog.Logger + + srv *ssh.Server + + Env map[string]string + AgentToken func() string + + manifest atomic.Pointer[agentsdk.Manifest] + + connCountVSCode atomic.Int64 + connCountJetBrains atomic.Int64 + connCountSSHSession atomic.Int64 +} + +func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration) (*Server, error) { + // Clients' should ignore the host key when connecting. + // The agent needs to authenticate with coderd to SSH, + // so SSH authentication doesn't improve security. + randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + randomSigner, err := gossh.NewSignerFromKey(randomHostKey) + if err != nil { + return nil, err + } + + forwardHandler := &ssh.ForwardedTCPHandler{} + unixForwardHandler := &forwardedUnixHandler{log: logger} + + sCtx, sCancel := context.WithCancel(context.Background()) + s := &Server{ + ctx: sCtx, + cancel: sCancel, + logger: logger, + } + + s.srv = &ssh.Server{ + ChannelHandlers: map[string]ssh.ChannelHandler{ + "direct-tcpip": ssh.DirectTCPIPHandler, + "direct-streamlocal@openssh.com": directStreamLocalHandler, + "session": ssh.DefaultSessionHandler, + }, + ConnectionFailedCallback: func(_ net.Conn, err error) { + logger.Info(ctx, "ssh connection ended", slog.Error(err)) + }, + Handler: s.sessionHandler, + HostSigners: []ssh.Signer{randomSigner}, + LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { + // Allow local port forwarding all! + logger.Debug(ctx, "local port forward", + slog.F("destination-host", destinationHost), + slog.F("destination-port", destinationPort)) + return true + }, + PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { + return true + }, + ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { + // Allow reverse port forwarding all! + logger.Debug(ctx, "local port forward", + slog.F("bind-host", bindHost), + slog.F("bind-port", bindPort)) + return true + }, + RequestHandlers: map[string]ssh.RequestHandler{ + "tcpip-forward": forwardHandler.HandleSSHRequest, + "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, + "streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, + "cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, + }, + ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { + return &gossh.ServerConfig{ + NoClientAuth: true, + } + }, + SubsystemHandlers: map[string]ssh.SubsystemHandler{ + "sftp": s.sftpHandler, + }, + MaxTimeout: maxTimeout, + } + + return s, nil +} + +// SetManifest sets the manifest used for starting commands. +func (a *Server) SetManifest(m *agentsdk.Manifest) { + a.manifest.Store(m) +} + +type ConnStats struct { + Sessions int64 + VSCode int64 + JetBrains int64 +} + +func (a *Server) ConnStats() ConnStats { + return ConnStats{ + Sessions: a.connCountSSHSession.Load(), + VSCode: a.connCountVSCode.Load(), + JetBrains: a.connCountJetBrains.Load(), + } +} + +func (a *Server) sessionHandler(session ssh.Session) { + ctx := session.Context() + err := a.sessionStart(session) + var exitError *exec.ExitError + if xerrors.As(err, &exitError) { + a.logger.Debug(ctx, "ssh session returned", slog.Error(exitError)) + _ = session.Exit(exitError.ExitCode()) + return + } + if err != nil { + a.logger.Warn(ctx, "ssh session failed", slog.Error(err)) + // This exit code is designed to be unlikely to be confused for a legit exit code + // from the process. + _ = session.Exit(MagicSessionErrorCode) + return + } + _ = session.Exit(0) +} + +func (a *Server) sessionStart(session ssh.Session) (retErr error) { + ctx := session.Context() + env := session.Environ() + var magicType string + for index, kv := range env { + if !strings.HasPrefix(kv, MagicSessionTypeEnvironmentVariable) { + continue + } + magicType = strings.TrimPrefix(kv, MagicSessionTypeEnvironmentVariable+"=") + env = append(env[:index], env[index+1:]...) + } + switch magicType { + case MagicSessionTypeVSCode: + a.connCountVSCode.Add(1) + defer a.connCountVSCode.Add(-1) + case MagicSessionTypeJetBrains: + a.connCountJetBrains.Add(1) + defer a.connCountJetBrains.Add(-1) + case "": + a.connCountSSHSession.Add(1) + defer a.connCountSSHSession.Add(-1) + default: + a.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType)) + } + + cmd, err := a.CreateCommand(ctx, session.RawCommand(), env) + if err != nil { + return err + } + + if ssh.AgentRequested(session) { + l, err := ssh.NewAgentListener() + if err != nil { + return xerrors.Errorf("new agent listener: %w", err) + } + defer l.Close() + go ssh.ForwardAgentConnections(l, session) + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String())) + } + + sshPty, windowSize, isPty := session.Pty() + if isPty { + // Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL). + // See https://github.com/coder/coder/issues/3371. + session.DisablePTYEmulation() + + if !isQuietLogin(session.RawCommand()) { + manifest := a.manifest.Load() + if manifest != nil { + err = showMOTD(session, manifest.MOTDFile) + if err != nil { + a.logger.Error(ctx, "show MOTD", slog.Error(err)) + } + } else { + a.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") + } + } + + cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) + + // The pty package sets `SSH_TTY` on supported platforms. + ptty, process, err := pty.Start(cmd, pty.WithPTYOption( + pty.WithSSHRequest(sshPty), + pty.WithLogger(slog.Stdlib(ctx, a.logger, slog.LevelInfo)), + )) + if err != nil { + return xerrors.Errorf("start command: %w", err) + } + var wg sync.WaitGroup + defer func() { + defer wg.Wait() + closeErr := ptty.Close() + if closeErr != nil { + a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) + if retErr == nil { + retErr = closeErr + } + } + }() + go func() { + for win := range windowSize { + resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) + // If the pty is closed, then command has exited, no need to log. + if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) { + a.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) + } + } + }() + // We don't add input copy to wait group because + // it won't return until the session is closed. + go func() { + _, _ = io.Copy(ptty.Input(), session) + }() + + // In low parallelism scenarios, the command may exit and we may close + // the pty before the output copy has started. This can result in the + // output being lost. To avoid this, we wait for the output copy to + // start before waiting for the command to exit. This ensures that the + // output copy goroutine will be scheduled before calling close on the + // pty. This shouldn't be needed because of `pty.Dup()` below, but it + // may not be supported on all platforms. + outputCopyStarted := make(chan struct{}) + ptyOutput := func() io.ReadCloser { + defer close(outputCopyStarted) + // Try to dup so we can separate stdin and stdout closure. + // Once the original pty is closed, the dup will return + // input/output error once the buffered data has been read. + stdout, err := ptty.Dup() + if err == nil { + return stdout + } + // If we can't dup, we shouldn't close + // the fd since it's tied to stdin. + return readNopCloser{ptty.Output()} + } + wg.Add(1) + go func() { + // Ensure data is flushed to session on command exit, if we + // close the session too soon, we might lose data. + defer wg.Done() + + stdout := ptyOutput() + defer stdout.Close() + + _, _ = io.Copy(session, stdout) + }() + <-outputCopyStarted + + err = process.Wait() + var exitErr *exec.ExitError + // ExitErrors just mean the command we run returned a non-zero exit code, which is normal + // and not something to be concerned about. But, if it's something else, we should log it. + if err != nil && !xerrors.As(err, &exitErr) { + a.logger.Warn(ctx, "wait error", slog.Error(err)) + } + return err + } + + cmd.Stdout = session + cmd.Stderr = session.Stderr() + // This blocks forever until stdin is received if we don't + // use StdinPipe. It's unknown what causes this. + stdinPipe, err := cmd.StdinPipe() + if err != nil { + return xerrors.Errorf("create stdin pipe: %w", err) + } + go func() { + _, _ = io.Copy(stdinPipe, session) + _ = stdinPipe.Close() + }() + err = cmd.Start() + if err != nil { + return xerrors.Errorf("start: %w", err) + } + return cmd.Wait() +} + +type readNopCloser struct{ io.Reader } + +// Close implements io.Closer. +func (readNopCloser) Close() error { return nil } + +func (a *Server) sftpHandler(session ssh.Session) { + ctx := session.Context() + + // Typically sftp sessions don't request a TTY, but if they do, + // we must ensure the gliderlabs/ssh CRLF emulation is disabled. + // Otherwise sftp will be broken. This can happen if a user sets + // `RequestTTY force` in their SSH config. + session.DisablePTYEmulation() + + var opts []sftp.ServerOption + // Change current working directory to the users home + // directory so that SFTP connections land there. + homedir, err := userHomeDir() + if err != nil { + a.logger.Warn(ctx, "get sftp working directory failed, unable to get home dir", slog.Error(err)) + } else { + opts = append(opts, sftp.WithServerWorkingDirectory(homedir)) + } + + server, err := sftp.NewServer(session, opts...) + if err != nil { + a.logger.Debug(ctx, "initialize sftp server", slog.Error(err)) + return + } + defer server.Close() + + err = server.Serve() + if errors.Is(err, io.EOF) { + // Unless we call `session.Exit(0)` here, the client won't + // receive `exit-status` because `(*sftp.Server).Close()` + // calls `Close()` on the underlying connection (session), + // which actually calls `channel.Close()` because it isn't + // wrapped. This causes sftp clients to receive a non-zero + // exit code. Typically sftp clients don't echo this exit + // code but `scp` on macOS does (when using the default + // SFTP backend). + _ = session.Exit(0) + return + } + a.logger.Warn(ctx, "sftp server closed with error", slog.Error(err)) + _ = session.Exit(1) +} + +// CreateCommand processes raw command input with OpenSSH-like behavior. +// If the script provided is empty, it will default to the users shell. +// This injects environment variables specified by the user at launch too. +func (a *Server) CreateCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) { + currentUser, err := user.Current() + if err != nil { + return nil, xerrors.Errorf("get current user: %w", err) + } + username := currentUser.Username + + shell, err := usershell.Get(username) + if err != nil { + return nil, xerrors.Errorf("get user shell: %w", err) + } + + manifest := a.manifest.Load() + if manifest == nil { + return nil, xerrors.Errorf("no metadata was provided") + } + + // OpenSSH executes all commands with the users current shell. + // We replicate that behavior for IDE support. + caller := "-c" + if runtime.GOOS == "windows" { + caller = "/c" + } + args := []string{caller, script} + + // gliderlabs/ssh returns a command slice of zero + // when a shell is requested. + if len(script) == 0 { + args = []string{} + if runtime.GOOS != "windows" { + // On Linux and macOS, we should start a login + // shell to consume juicy environment variables! + args = append(args, "-l") + } + } + + cmd := exec.CommandContext(ctx, shell, args...) + cmd.Dir = manifest.Directory + + // If the metadata directory doesn't exist, we run the command + // in the users home directory. + _, err = os.Stat(cmd.Dir) + if cmd.Dir == "" || err != nil { + // Default to user home if a directory is not set. + homedir, err := userHomeDir() + if err != nil { + return nil, xerrors.Errorf("get home dir: %w", err) + } + cmd.Dir = homedir + } + cmd.Env = append(os.Environ(), env...) + executablePath, err := os.Executable() + if err != nil { + return nil, xerrors.Errorf("getting os executable: %w", err) + } + // Set environment variables reliable detection of being inside a + // Coder workspace. + cmd.Env = append(cmd.Env, "CODER=true") + cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username)) + // Git on Windows resolves with UNIX-style paths. + // If using backslashes, it's unable to find the executable. + unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/") + cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath)) + + // Specific Coder subcommands require the agent token exposed! + cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", a.AgentToken())) + + // Set SSH connection environment variables (these are also set by OpenSSH + // and thus expected to be present by SSH clients). Since the agent does + // networking in-memory, trying to provide accurate values here would be + // nonsensical. For now, we hard code these values so that they're present. + srcAddr, srcPort := "0.0.0.0", "0" + dstAddr, dstPort := "0.0.0.0", "0" + cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CLIENT=%s %s %s", srcAddr, srcPort, dstPort)) + cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", srcAddr, srcPort, dstAddr, dstPort)) + + // This adds the ports dialog to code-server that enables + // proxying a port dynamically. + cmd.Env = append(cmd.Env, fmt.Sprintf("VSCODE_PROXY_URI=%s", manifest.VSCodePortProxyURI)) + + // Hide Coder message on code-server's "Getting Started" page + cmd.Env = append(cmd.Env, "CS_DISABLE_GETTING_STARTED_OVERRIDE=true") + + // Load environment variables passed via the agent. + // These should override all variables we manually specify. + for envKey, value := range manifest.EnvironmentVariables { + // Expanding environment variables allows for customization + // of the $PATH, among other variables. Customers can prepend + // or append to the $PATH, so allowing expand is required! + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, os.ExpandEnv(value))) + } + + // Agent-level environment variables should take over all! + // This is used for setting agent-specific variables like "CODER_AGENT_TOKEN". + for envKey, value := range a.Env { + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value)) + } + + return cmd, nil +} + +func (a *Server) Serve(l net.Listener) error { + a.serveWg.Add(1) + defer a.serveWg.Done() + return a.srv.Serve(l) +} + +func (a *Server) Close() error { + err := a.srv.Close() + a.serveWg.Wait() + return err +} + +// Shutdown gracefully closes all active SSH connections and stops +// accepting new connections. +// +// Shutdown is not implemented. +func (a *Server) Shutdown(ctx context.Context) error { + // TODO(mafredri): Implement shutdown, SIGHUP running commands, etc. + return nil +} + +// isQuietLogin checks if the SSH server should perform a quiet login or not. +// +// https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L816 +func isQuietLogin(rawCommand string) bool { + // We are always quiet unless this is a login shell. + if len(rawCommand) != 0 { + return true + } + + // Best effort, if we can't get the home directory, + // we can't lookup .hushlogin. + homedir, err := userHomeDir() + if err != nil { + return false + } + + _, err = os.Stat(filepath.Join(homedir, ".hushlogin")) + return err == nil +} + +// showMOTD will output the message of the day from +// the given filename to dest, if the file exists. +// +// https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L784 +func showMOTD(dest io.Writer, filename string) error { + if filename == "" { + return nil + } + + f, err := os.Open(filename) + if err != nil { + if xerrors.Is(err, os.ErrNotExist) { + // This is not an error, there simply isn't a MOTD to show. + return nil + } + return xerrors.Errorf("open MOTD: %w", err) + } + defer f.Close() + + s := bufio.NewScanner(f) + for s.Scan() { + // Carriage return ensures each line starts + // at the beginning of the terminal. + _, err = fmt.Fprint(dest, s.Text()+"\r\n") + if err != nil { + return xerrors.Errorf("write MOTD: %w", err) + } + } + if err := s.Err(); err != nil { + return xerrors.Errorf("read MOTD: %w", err) + } + + return nil +} + +// userHomeDir returns the home directory of the current user, giving +// priority to the $HOME environment variable. +func userHomeDir() (string, error) { + // First we check the environment. + homedir, err := os.UserHomeDir() + if err == nil { + return homedir, nil + } + + // As a fallback, we try the user information. + u, err := user.Current() + if err != nil { + return "", xerrors.Errorf("current user: %w", err) + } + return u.HomeDir, nil +} diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go new file mode 100644 index 0000000000000..ecdb0a19eb5d8 --- /dev/null +++ b/agent/agentssh/agentssh_test.go @@ -0,0 +1,136 @@ +package agentssh_test + +import ( + "bytes" + "context" + "net" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "golang.org/x/crypto/ssh" + + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/agent/agentssh" + "github.com/coder/coder/codersdk/agentsdk" + "github.com/coder/coder/pty/ptytest" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestNewServer_ServeClient(t *testing.T) { + t.Parallel() + + ctx := context.Background() + logger := slogtest.Make(t, nil) + s, err := agentssh.NewServer(ctx, logger, 0) + require.NoError(t, err) + + // The assumption is that these are set before serving SSH connections. + s.AgentToken = func() string { return "" } + s.SetManifest(&agentsdk.Manifest{}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + err := s.Serve(ln) + assert.Error(t, err) // Server is closed. + }() + + c := sshClient(t, ln.Addr().String()) + var b bytes.Buffer + sess, err := c.NewSession() + sess.Stdout = &b + require.NoError(t, err) + err = sess.Start("echo hello") + require.NoError(t, err) + + err = sess.Wait() + require.NoError(t, err) + + require.Equal(t, "hello", strings.TrimSpace(b.String())) + + err = s.Close() + require.NoError(t, err) + <-done +} + +func TestNewServer_CloseActiveConnections(t *testing.T) { + t.Parallel() + + ctx := context.Background() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + s, err := agentssh.NewServer(ctx, logger, 0) + require.NoError(t, err) + + // The assumption is that these are set before serving SSH connections. + s.AgentToken = func() string { return "" } + s.SetManifest(&agentsdk.Manifest{}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + err := s.Serve(ln) + assert.Error(t, err) // Server is closed. + }() + + pty := ptytest.New(t) + + doClose := make(chan struct{}) + go func() { + defer wg.Done() + c := sshClient(t, ln.Addr().String()) + sess, err := c.NewSession() + sess.Stdin = pty.Input() + sess.Stdout = pty.Output() + sess.Stderr = pty.Output() + + assert.NoError(t, err) + err = sess.Start("") + assert.NoError(t, err) + + close(doClose) + err = sess.Wait() + assert.Error(t, err) + }() + + <-doClose + err = s.Close() + require.NoError(t, err) + + wg.Wait() +} + +func sshClient(t *testing.T, addr string) *ssh.Client { + conn, err := net.Dial("tcp", addr) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.Close() + }) + + sshConn, channels, requests, err := ssh.NewClientConn(conn, "localhost:22", &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // This is a test. + }) + require.NoError(t, err) + t.Cleanup(func() { + _ = sshConn.Close() + }) + c := ssh.NewClient(sshConn, channels, requests) + t.Cleanup(func() { + _ = c.Close() + }) + return c +} diff --git a/agent/agentssh/bicopy.go b/agent/agentssh/bicopy.go new file mode 100644 index 0000000000000..64cd2a716058c --- /dev/null +++ b/agent/agentssh/bicopy.go @@ -0,0 +1,47 @@ +package agentssh + +import ( + "context" + "io" + "sync" +) + +// Bicopy copies all of the data between the two connections and will close them +// after one or both of them are done writing. If the context is canceled, both +// of the connections will be closed. +func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + defer func() { + _ = c1.Close() + _ = c2.Close() + }() + + var wg sync.WaitGroup + copyFunc := func(dst io.WriteCloser, src io.Reader) { + defer func() { + wg.Done() + // If one side of the copy fails, ensure the other one exits as + // well. + cancel() + }() + _, _ = io.Copy(dst, src) + } + + wg.Add(2) + go copyFunc(c1, c2) + go copyFunc(c2, c1) + + // Convert waitgroup to a channel so we can also wait on the context. + done := make(chan struct{}) + go func() { + defer close(done) + wg.Wait() + }() + + select { + case <-ctx.Done(): + case <-done: + } +} diff --git a/agent/ssh.go b/agent/agentssh/forward.go similarity index 99% rename from agent/ssh.go rename to agent/agentssh/forward.go index 8aa41a1d268ed..1e3635fd8ff91 100644 --- a/agent/ssh.go +++ b/agent/agentssh/forward.go @@ -1,4 +1,4 @@ -package agent +package agentssh import ( "context" diff --git a/cli/portforward.go b/cli/portforward.go index c746216889a55..dad82381bfb5b 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -14,7 +14,7 @@ import ( "github.com/pion/udp" "golang.org/x/xerrors" - "github.com/coder/coder/agent" + "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" @@ -226,7 +226,7 @@ func listenAndPortForward(ctx context.Context, inv *clibase.Invocation, conn *co } defer remoteConn.Close() - agent.Bicopy(ctx, netConn, remoteConn) + agentssh.Bicopy(ctx, netConn, remoteConn) }(netConn) } }(spec) diff --git a/cli/ssh.go b/cli/ssh.go index e9168f6999f6b..e1ebbcd04cfd2 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -23,7 +23,7 @@ import ( "golang.org/x/term" "golang.org/x/xerrors" - "github.com/coder/coder/agent" + "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/autobuild/notify" @@ -574,7 +574,7 @@ func sshForwardRemote(ctx context.Context, stderr io.Writer, sshClient *gossh.Cl } } - agent.Bicopy(ctx, localConn, remoteConn) + agentssh.Bicopy(ctx, localConn, remoteConn) }() } }() diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 6ce14dad7689e..293ab3f0a06d8 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -30,7 +30,7 @@ import ( "tailscale.com/tailcfg" "cdr.dev/slog" - "github.com/coder/coder/agent" + "github.com/coder/coder/agent/agentssh" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/dbauthz" "github.com/coder/coder/coderd/gitauth" @@ -620,7 +620,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { return } defer ptNetConn.Close() - agent.Bicopy(ctx, wsNetConn, ptNetConn) + agentssh.Bicopy(ctx, wsNetConn, ptNetConn) } // @Summary Get listening ports for workspace agent From d5f7a4e4f8705bfc67d02f9b137b1f0622ec6e34 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 4 Apr 2023 14:46:58 +0000 Subject: [PATCH 2/7] Rename receivers --- agent/agentssh/agentssh.go | 84 +++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 94e716b260fbe..b2c421c54f1aa 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -136,8 +136,8 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration } // SetManifest sets the manifest used for starting commands. -func (a *Server) SetManifest(m *agentsdk.Manifest) { - a.manifest.Store(m) +func (s *Server) SetManifest(m *agentsdk.Manifest) { + s.manifest.Store(m) } type ConnStats struct { @@ -146,25 +146,25 @@ type ConnStats struct { JetBrains int64 } -func (a *Server) ConnStats() ConnStats { +func (s *Server) ConnStats() ConnStats { return ConnStats{ - Sessions: a.connCountSSHSession.Load(), - VSCode: a.connCountVSCode.Load(), - JetBrains: a.connCountJetBrains.Load(), + Sessions: s.connCountSSHSession.Load(), + VSCode: s.connCountVSCode.Load(), + JetBrains: s.connCountJetBrains.Load(), } } -func (a *Server) sessionHandler(session ssh.Session) { +func (s *Server) sessionHandler(session ssh.Session) { ctx := session.Context() - err := a.sessionStart(session) + err := s.sessionStart(session) var exitError *exec.ExitError if xerrors.As(err, &exitError) { - a.logger.Debug(ctx, "ssh session returned", slog.Error(exitError)) + s.logger.Debug(ctx, "ssh session returned", slog.Error(exitError)) _ = session.Exit(exitError.ExitCode()) return } if err != nil { - a.logger.Warn(ctx, "ssh session failed", slog.Error(err)) + s.logger.Warn(ctx, "ssh session failed", slog.Error(err)) // This exit code is designed to be unlikely to be confused for a legit exit code // from the process. _ = session.Exit(MagicSessionErrorCode) @@ -173,7 +173,7 @@ func (a *Server) sessionHandler(session ssh.Session) { _ = session.Exit(0) } -func (a *Server) sessionStart(session ssh.Session) (retErr error) { +func (s *Server) sessionStart(session ssh.Session) (retErr error) { ctx := session.Context() env := session.Environ() var magicType string @@ -186,19 +186,19 @@ func (a *Server) sessionStart(session ssh.Session) (retErr error) { } switch magicType { case MagicSessionTypeVSCode: - a.connCountVSCode.Add(1) - defer a.connCountVSCode.Add(-1) + s.connCountVSCode.Add(1) + defer s.connCountVSCode.Add(-1) case MagicSessionTypeJetBrains: - a.connCountJetBrains.Add(1) - defer a.connCountJetBrains.Add(-1) + s.connCountJetBrains.Add(1) + defer s.connCountJetBrains.Add(-1) case "": - a.connCountSSHSession.Add(1) - defer a.connCountSSHSession.Add(-1) + s.connCountSSHSession.Add(1) + defer s.connCountSSHSession.Add(-1) default: - a.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType)) + s.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType)) } - cmd, err := a.CreateCommand(ctx, session.RawCommand(), env) + cmd, err := s.CreateCommand(ctx, session.RawCommand(), env) if err != nil { return err } @@ -220,14 +220,14 @@ func (a *Server) sessionStart(session ssh.Session) (retErr error) { session.DisablePTYEmulation() if !isQuietLogin(session.RawCommand()) { - manifest := a.manifest.Load() + manifest := s.manifest.Load() if manifest != nil { err = showMOTD(session, manifest.MOTDFile) if err != nil { - a.logger.Error(ctx, "show MOTD", slog.Error(err)) + s.logger.Error(ctx, "show MOTD", slog.Error(err)) } } else { - a.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") + s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") } } @@ -236,7 +236,7 @@ func (a *Server) sessionStart(session ssh.Session) (retErr error) { // The pty package sets `SSH_TTY` on supported platforms. ptty, process, err := pty.Start(cmd, pty.WithPTYOption( pty.WithSSHRequest(sshPty), - pty.WithLogger(slog.Stdlib(ctx, a.logger, slog.LevelInfo)), + pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)), )) if err != nil { return xerrors.Errorf("start command: %w", err) @@ -246,7 +246,7 @@ func (a *Server) sessionStart(session ssh.Session) (retErr error) { defer wg.Wait() closeErr := ptty.Close() if closeErr != nil { - a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) + s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) if retErr == nil { retErr = closeErr } @@ -257,7 +257,7 @@ func (a *Server) sessionStart(session ssh.Session) (retErr error) { resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) // If the pty is closed, then command has exited, no need to log. if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) { - a.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) + s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) } } }() @@ -306,7 +306,7 @@ func (a *Server) sessionStart(session ssh.Session) (retErr error) { // ExitErrors just mean the command we run returned a non-zero exit code, which is normal // and not something to be concerned about. But, if it's something else, we should log it. if err != nil && !xerrors.As(err, &exitErr) { - a.logger.Warn(ctx, "wait error", slog.Error(err)) + s.logger.Warn(ctx, "wait error", slog.Error(err)) } return err } @@ -335,7 +335,7 @@ type readNopCloser struct{ io.Reader } // Close implements io.Closer. func (readNopCloser) Close() error { return nil } -func (a *Server) sftpHandler(session ssh.Session) { +func (s *Server) sftpHandler(session ssh.Session) { ctx := session.Context() // Typically sftp sessions don't request a TTY, but if they do, @@ -349,14 +349,14 @@ func (a *Server) sftpHandler(session ssh.Session) { // directory so that SFTP connections land there. homedir, err := userHomeDir() if err != nil { - a.logger.Warn(ctx, "get sftp working directory failed, unable to get home dir", slog.Error(err)) + s.logger.Warn(ctx, "get sftp working directory failed, unable to get home dir", slog.Error(err)) } else { opts = append(opts, sftp.WithServerWorkingDirectory(homedir)) } server, err := sftp.NewServer(session, opts...) if err != nil { - a.logger.Debug(ctx, "initialize sftp server", slog.Error(err)) + s.logger.Debug(ctx, "initialize sftp server", slog.Error(err)) return } defer server.Close() @@ -374,14 +374,14 @@ func (a *Server) sftpHandler(session ssh.Session) { _ = session.Exit(0) return } - a.logger.Warn(ctx, "sftp server closed with error", slog.Error(err)) + s.logger.Warn(ctx, "sftp server closed with error", slog.Error(err)) _ = session.Exit(1) } // CreateCommand processes raw command input with OpenSSH-like behavior. // If the script provided is empty, it will default to the users shell. // This injects environment variables specified by the user at launch too. -func (a *Server) CreateCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) { +func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) { currentUser, err := user.Current() if err != nil { return nil, xerrors.Errorf("get current user: %w", err) @@ -393,7 +393,7 @@ func (a *Server) CreateCommand(ctx context.Context, script string, env []string) return nil, xerrors.Errorf("get user shell: %w", err) } - manifest := a.manifest.Load() + manifest := s.manifest.Load() if manifest == nil { return nil, xerrors.Errorf("no metadata was provided") } @@ -446,7 +446,7 @@ func (a *Server) CreateCommand(ctx context.Context, script string, env []string) cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath)) // Specific Coder subcommands require the agent token exposed! - cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", a.AgentToken())) + cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", s.AgentToken())) // Set SSH connection environment variables (these are also set by OpenSSH // and thus expected to be present by SSH clients). Since the agent does @@ -475,22 +475,22 @@ func (a *Server) CreateCommand(ctx context.Context, script string, env []string) // Agent-level environment variables should take over all! // This is used for setting agent-specific variables like "CODER_AGENT_TOKEN". - for envKey, value := range a.Env { + for envKey, value := range s.Env { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value)) } return cmd, nil } -func (a *Server) Serve(l net.Listener) error { - a.serveWg.Add(1) - defer a.serveWg.Done() - return a.srv.Serve(l) +func (s *Server) Serve(l net.Listener) error { + s.serveWg.Add(1) + defer s.serveWg.Done() + return s.srv.Serve(l) } -func (a *Server) Close() error { - err := a.srv.Close() - a.serveWg.Wait() +func (s *Server) Close() error { + err := s.srv.Close() + s.serveWg.Wait() return err } @@ -498,7 +498,7 @@ func (a *Server) Close() error { // accepting new connections. // // Shutdown is not implemented. -func (a *Server) Shutdown(ctx context.Context) error { +func (*Server) Shutdown(ctx context.Context) error { // TODO(mafredri): Implement shutdown, SIGHUP running commands, etc. return nil } From 3ea3e707ce40f3d9a4c4be0481e748484607f55d Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 4 Apr 2023 14:52:33 +0000 Subject: [PATCH 3/7] Remove unused context --- agent/agentssh/agentssh.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index b2c421c54f1aa..ff273790dffec 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -47,8 +47,6 @@ const ( ) type Server struct { - ctx context.Context - cancel context.CancelFunc serveWg sync.WaitGroup logger slog.Logger @@ -80,10 +78,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration forwardHandler := &ssh.ForwardedTCPHandler{} unixForwardHandler := &forwardedUnixHandler{log: logger} - sCtx, sCancel := context.WithCancel(context.Background()) s := &Server{ - ctx: sCtx, - cancel: sCancel, logger: logger, } From 126813f029765d9bf6606d7cbcaee40a40ad845f Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 4 Apr 2023 14:54:22 +0000 Subject: [PATCH 4/7] Use s logger --- agent/agentssh/agentssh.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index ff273790dffec..d929b1e0293e5 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -89,13 +89,13 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration "session": ssh.DefaultSessionHandler, }, ConnectionFailedCallback: func(_ net.Conn, err error) { - logger.Info(ctx, "ssh connection ended", slog.Error(err)) + s.logger.Info(ctx, "ssh connection ended", slog.Error(err)) }, Handler: s.sessionHandler, HostSigners: []ssh.Signer{randomSigner}, LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { // Allow local port forwarding all! - logger.Debug(ctx, "local port forward", + s.logger.Debug(ctx, "local port forward", slog.F("destination-host", destinationHost), slog.F("destination-port", destinationPort)) return true @@ -105,7 +105,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration }, ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { // Allow reverse port forwarding all! - logger.Debug(ctx, "local port forward", + s.logger.Debug(ctx, "local port forward", slog.F("bind-host", bindHost), slog.F("bind-port", bindPort)) return true From 667d038fe262d3bd21e827569f3534f242cfe9b0 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 4 Apr 2023 15:04:02 +0000 Subject: [PATCH 5/7] Rename unused arg _ --- agent/agentssh/agentssh.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index d929b1e0293e5..a1cad342d4fbe 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -493,7 +493,7 @@ func (s *Server) Close() error { // accepting new connections. // // Shutdown is not implemented. -func (*Server) Shutdown(ctx context.Context) error { +func (*Server) Shutdown(_ context.Context) error { // TODO(mafredri): Implement shutdown, SIGHUP running commands, etc. return nil } From 94d759396ad87cb386f54917ea51cc56e4a6fe0f Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Thu, 6 Apr 2023 09:50:51 +0000 Subject: [PATCH 6/7] Address PR feedback --- agent/agent.go | 2 +- agent/agentssh/agentssh.go | 14 ++++---------- agent/agentssh/agentssh_test.go | 7 +++++-- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 3906a182139ad..f538ef93b4af8 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -167,6 +167,7 @@ func (a *agent) init(ctx context.Context) { } sshSrv.Env = a.envVars sshSrv.AgentToken = func() string { return *a.sessionToken.Load() } + sshSrv.Manifest = &a.manifest a.sshServer = sshSrv go a.runLoop(ctx) @@ -478,7 +479,6 @@ func (a *agent) run(ctx context.Context) error { } oldManifest := a.manifest.Swap(&manifest) - a.sshServer.SetManifest(&manifest) // The startup script should only execute on the first run! if oldManifest == nil { diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index a1cad342d4fbe..c511d0e5f9eab 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -16,11 +16,11 @@ import ( "runtime" "strings" "sync" - "sync/atomic" "time" "github.com/gliderlabs/ssh" "github.com/pkg/sftp" + "go.uber.org/atomic" gossh "golang.org/x/crypto/ssh" "golang.org/x/xerrors" @@ -54,8 +54,7 @@ type Server struct { Env map[string]string AgentToken func() string - - manifest atomic.Pointer[agentsdk.Manifest] + Manifest *atomic.Pointer[agentsdk.Manifest] connCountVSCode atomic.Int64 connCountJetBrains atomic.Int64 @@ -130,11 +129,6 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration return s, nil } -// SetManifest sets the manifest used for starting commands. -func (s *Server) SetManifest(m *agentsdk.Manifest) { - s.manifest.Store(m) -} - type ConnStats struct { Sessions int64 VSCode int64 @@ -215,7 +209,7 @@ func (s *Server) sessionStart(session ssh.Session) (retErr error) { session.DisablePTYEmulation() if !isQuietLogin(session.RawCommand()) { - manifest := s.manifest.Load() + manifest := s.Manifest.Load() if manifest != nil { err = showMOTD(session, manifest.MOTDFile) if err != nil { @@ -388,7 +382,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string) return nil, xerrors.Errorf("get user shell: %w", err) } - manifest := s.manifest.Load() + manifest := s.Manifest.Load() if manifest == nil { return nil, xerrors.Errorf("no metadata was provided") } diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index ecdb0a19eb5d8..684c0e36bbb18 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -1,3 +1,5 @@ +// Package agentssh_test provides tests for basic functinoality of the agentssh +// package, more test coverage can be found in the `agent` and `cli` package(s). package agentssh_test import ( @@ -10,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/atomic" "go.uber.org/goleak" "golang.org/x/crypto/ssh" @@ -34,7 +37,7 @@ func TestNewServer_ServeClient(t *testing.T) { // The assumption is that these are set before serving SSH connections. s.AgentToken = func() string { return "" } - s.SetManifest(&agentsdk.Manifest{}) + s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -74,7 +77,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) { // The assumption is that these are set before serving SSH connections. s.AgentToken = func() string { return "" } - s.SetManifest(&agentsdk.Manifest{}) + s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) From ed63a2bcf048d0fc178908ff06743c4abd823fd0 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Thu, 6 Apr 2023 11:38:04 +0000 Subject: [PATCH 7/7] Improve handling of serve/close --- agent/agentssh/agentssh.go | 128 ++++++++++++++++++++++++++++++++++--- 1 file changed, 120 insertions(+), 8 deletions(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index c511d0e5f9eab..c882380bacf48 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -47,10 +47,16 @@ const ( ) type Server struct { - serveWg sync.WaitGroup - logger slog.Logger + mu sync.RWMutex // Protects following. + listeners map[net.Listener]struct{} + conns map[net.Conn]struct{} + closing chan struct{} + // Wait for goroutines to exit, waited without + // a lock on mu but protected by closing. + wg sync.WaitGroup - srv *ssh.Server + logger slog.Logger + srv *ssh.Server Env map[string]string AgentToken func() string @@ -78,7 +84,9 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration unixForwardHandler := &forwardedUnixHandler{log: logger} s := &Server{ - logger: logger, + listeners: make(map[net.Listener]struct{}), + conns: make(map[net.Conn]struct{}), + logger: logger, } s.srv = &ssh.Server{ @@ -472,14 +480,118 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string) } func (s *Server) Serve(l net.Listener) error { - s.serveWg.Add(1) - defer s.serveWg.Done() - return s.srv.Serve(l) + defer l.Close() + + s.trackListener(l, true) + defer s.trackListener(l, false) + for { + conn, err := l.Accept() + if err != nil { + return err + } + go s.handleConn(l, conn) + } } +func (s *Server) handleConn(l net.Listener, c net.Conn) { + defer c.Close() + + if !s.trackConn(l, c, true) { + // Server is closed or we no longer want + // connections from this listener. + s.logger.Debug(context.Background(), "received connection after server closed") + return + } + defer s.trackConn(l, c, false) + + s.srv.HandleConn(c) +} + +// trackListener registers the listener with the server. If the server is +// closing, the function will block until the server is closed. +// +//nolint:revive +func (s *Server) trackListener(l net.Listener, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + if add { + for s.closing != nil { + closing := s.closing + // Wait until close is complete before + // serving a new listener. + s.mu.Unlock() + <-closing + s.mu.Lock() + } + s.wg.Add(1) + s.listeners[l] = struct{}{} + return + } + s.wg.Done() + delete(s.listeners, l) +} + +// trackConn registers the connection with the server. If the server is +// closed or the listener is closed, the connection is not registered +// and should be closed. +// +//nolint:revive +func (s *Server) trackConn(l net.Listener, c net.Conn, add bool) (ok bool) { + s.mu.Lock() + defer s.mu.Unlock() + if add { + found := false + for ll := range s.listeners { + if l == ll { + found = true + break + } + } + if s.closing != nil || !found { + // Server or listener closed. + return false + } + s.wg.Add(1) + s.conns[c] = struct{}{} + return true + } + s.wg.Done() + delete(s.conns, c) + return true +} + +// Close the server and all active connections. Server can be re-used +// after Close is done. func (s *Server) Close() error { + s.mu.Lock() + + // Guard against multiple calls to Close and + // accepting new connections during close. + if s.closing != nil { + s.mu.Unlock() + return xerrors.New("server is closing") + } + s.closing = make(chan struct{}) + + // Close all active listeners and connections. + for l := range s.listeners { + _ = l.Close() + } + for c := range s.conns { + _ = c.Close() + } + + // Close the underlying SSH server. err := s.srv.Close() - s.serveWg.Wait() + + s.mu.Unlock() + s.wg.Wait() // Wait for all goroutines to exit. + + s.mu.Lock() + close(s.closing) + s.closing = nil + s.mu.Unlock() + return err }