Skip to content

refactor(agent): Move SSH server into agentssh package #7004

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 6, 2023

Conversation

mafredri
Copy link
Member

@mafredri mafredri commented Apr 4, 2023

This PR refactors the coder agent SSH server into a separate package, agentssh.

The motivation is to break out functionality of the agent into a more bite-sized package as well as contain complexity as more parts of #6177 are amended.

Note: In an effort to keep changes minimal, the non-ideal decision was made to keep all env-var processing in the new package. This meant transplanting agentsdk.Manifest and agent token into the agentssh package. Parts of this will be migrated back into the agent package at a later time.

Refs: #6177

@mafredri
Copy link
Member Author

mafredri commented Apr 4, 2023

In an effort to verify my (minimal) changes, I created the following diff by simply deleting/moving some code around in agent/agent.go and comparing to agent/agentssh/agentssh.go (sharing in case it's useful):

--- agent/agent.go	2023-04-04 14:57:21.344043135 +0000
+++ agent/agentssh/agentssh.go	2023-04-04 14:55:18.747492195 +0000
@@ -1,4 +1,4 @@
-package agent
+package agentssh
 
 import (
 	"bufio"
@@ -16,6 +16,8 @@
 	"runtime"
 	"strings"
 	"sync"
+	"sync/atomic"
+	"time"
 
 	"github.com/gliderlabs/ssh"
 	"github.com/pkg/sftp"
@@ -23,7 +25,9 @@
 	"golang.org/x/xerrors"
 
 	"cdr.dev/slog"
+
 	"github.com/coder/coder/agent/usershell"
+	"github.com/coder/coder/codersdk/agentsdk"
 	"github.com/coder/coder/pty"
 )
 
@@ -33,46 +37,65 @@
 	// unlikely to shadow other exit codes, which are typically 1, 2, 3, etc.
 	MagicSessionErrorCode = 229
 
-	// MagicSSHSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection.
+	// MagicSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection.
 	// This is stripped from any commands being executed, and is counted towards connection stats.
-	MagicSSHSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE"
-	// MagicSSHSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
-	MagicSSHSessionTypeVSCode = "vscode"
-	// MagicSSHSessionTypeJetBrains is set in the SSH config by the JetBrains extension to identify itself.
-	MagicSSHSessionTypeJetBrains = "jetbrains"
+	MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE"
+	// MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
+	MagicSessionTypeVSCode = "vscode"
+	// MagicSessionTypeJetBrains is set in the SSH config by the JetBrains extension to identify itself.
+	MagicSessionTypeJetBrains = "jetbrains"
 )
 
-func (a *agent) init(ctx context.Context) {
+type Server struct {
+	serveWg sync.WaitGroup
+	logger  slog.Logger
+
+	srv *ssh.Server
+
+	Env        map[string]string
+	AgentToken func() string
+
+	manifest atomic.Pointer[agentsdk.Manifest]
+
+	connCountVSCode     atomic.Int64
+	connCountJetBrains  atomic.Int64
+	connCountSSHSession atomic.Int64
+}
+
+func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration) (*Server, error) {
 	// Clients' should ignore the host key when connecting.
 	// The agent needs to authenticate with coderd to SSH,
 	// so SSH authentication doesn't improve security.
 	randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048)
 	if err != nil {
-		panic(err)
+		return nil, err
 	}
 	randomSigner, err := gossh.NewSignerFromKey(randomHostKey)
 	if err != nil {
-		panic(err)
+		return nil, err
 	}
 
-	sshLogger := a.logger.Named("ssh-server")
 	forwardHandler := &ssh.ForwardedTCPHandler{}
-	unixForwardHandler := &forwardedUnixHandler{log: a.logger}
+	unixForwardHandler := &forwardedUnixHandler{log: logger}
+
+	s := &Server{
+		logger: logger,
+	}
 
-	a.sshServer = &ssh.Server{
+	s.srv = &ssh.Server{
 		ChannelHandlers: map[string]ssh.ChannelHandler{
 			"direct-tcpip":                   ssh.DirectTCPIPHandler,
 			"direct-streamlocal@openssh.com": directStreamLocalHandler,
 			"session":                        ssh.DefaultSessionHandler,
 		},
-		ConnectionFailedCallback: func(conn net.Conn, err error) {
-			sshLogger.Info(ctx, "ssh connection ended", slog.Error(err))
+		ConnectionFailedCallback: func(_ net.Conn, err error) {
+			s.logger.Info(ctx, "ssh connection ended", slog.Error(err))
 		},
-		Handler:     a.sessionHandler,
+		Handler:     s.sessionHandler,
 		HostSigners: []ssh.Signer{randomSigner},
 		LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
 			// Allow local port forwarding all!
-			sshLogger.Debug(ctx, "local port forward",
+			s.logger.Debug(ctx, "local port forward",
 				slog.F("destination-host", destinationHost),
 				slog.F("destination-port", destinationPort))
 			return true
@@ -82,7 +105,7 @@
 		},
 		ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
 			// Allow reverse port forwarding all!
-			sshLogger.Debug(ctx, "local port forward",
+			s.logger.Debug(ctx, "local port forward",
 				slog.F("bind-host", bindHost),
 				slog.F("bind-port", bindPort))
 			return true
@@ -99,24 +122,44 @@
 			}
 		},
 		SubsystemHandlers: map[string]ssh.SubsystemHandler{
-			"sftp": a.sftpHandler,
+			"sftp": s.sftpHandler,
 		},
-		MaxTimeout: a.sshMaxTimeout,
+		MaxTimeout: maxTimeout,
+	}
+
+	return s, nil
+}
+
+// SetManifest sets the manifest used for starting commands.
+func (s *Server) SetManifest(m *agentsdk.Manifest) {
+	s.manifest.Store(m)
+}
+
+type ConnStats struct {
+	Sessions  int64
+	VSCode    int64
+	JetBrains int64
 	}
 
-	go a.runLoop(ctx)
+func (s *Server) ConnStats() ConnStats {
+	return ConnStats{
+		Sessions:  s.connCountSSHSession.Load(),
+		VSCode:    s.connCountVSCode.Load(),
+		JetBrains: s.connCountJetBrains.Load(),
+	}
 }
 
-func (a *agent) sessionHandler(session ssh.Session) {
-	err := a.handleSSHSession(session)
+func (s *Server) sessionHandler(session ssh.Session) {
+	ctx := session.Context()
+	err := s.sessionStart(session)
 	var exitError *exec.ExitError
 	if xerrors.As(err, &exitError) {
-		a.logger.Debug(ctx, "ssh session returned", slog.Error(exitError))
+		s.logger.Debug(ctx, "ssh session returned", slog.Error(exitError))
 		_ = session.Exit(exitError.ExitCode())
 		return
 	}
 	if err != nil {
-		a.logger.Warn(ctx, "ssh session failed", slog.Error(err))
+		s.logger.Warn(ctx, "ssh session failed", slog.Error(err))
 		// This exit code is designed to be unlikely to be confused for a legit exit code
 		// from the process.
 		_ = session.Exit(MagicSessionErrorCode)
@@ -125,32 +168,32 @@
 	_ = session.Exit(0)
 }
 
-func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
+func (s *Server) sessionStart(session ssh.Session) (retErr error) {
 	ctx := session.Context()
 	env := session.Environ()
 	var magicType string
 	for index, kv := range env {
-		if !strings.HasPrefix(kv, MagicSSHSessionTypeEnvironmentVariable) {
+		if !strings.HasPrefix(kv, MagicSessionTypeEnvironmentVariable) {
 			continue
 		}
-		magicType = strings.TrimPrefix(kv, MagicSSHSessionTypeEnvironmentVariable+"=")
+		magicType = strings.TrimPrefix(kv, MagicSessionTypeEnvironmentVariable+"=")
 		env = append(env[:index], env[index+1:]...)
 	}
 	switch magicType {
-	case MagicSSHSessionTypeVSCode:
-		a.connCountVSCode.Add(1)
-		defer a.connCountVSCode.Add(-1)
-	case MagicSSHSessionTypeJetBrains:
-		a.connCountJetBrains.Add(1)
-		defer a.connCountJetBrains.Add(-1)
+	case MagicSessionTypeVSCode:
+		s.connCountVSCode.Add(1)
+		defer s.connCountVSCode.Add(-1)
+	case MagicSessionTypeJetBrains:
+		s.connCountJetBrains.Add(1)
+		defer s.connCountJetBrains.Add(-1)
 	case "":
-		a.connCountSSHSession.Add(1)
-		defer a.connCountSSHSession.Add(-1)
+		s.connCountSSHSession.Add(1)
+		defer s.connCountSSHSession.Add(-1)
 	default:
-		a.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType))
+		s.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType))
 	}
 
-	cmd, err := a.createCommand(ctx, session.RawCommand(), env)
+	cmd, err := s.CreateCommand(ctx, session.RawCommand(), env)
 	if err != nil {
 		return err
 	}
@@ -172,14 +215,14 @@
 		session.DisablePTYEmulation()
 
 		if !isQuietLogin(session.RawCommand()) {
-			manifest := a.manifest.Load()
+			manifest := s.manifest.Load()
 			if manifest != nil {
 				err = showMOTD(session, manifest.MOTDFile)
 				if err != nil {
-					a.logger.Error(ctx, "show MOTD", slog.Error(err))
+					s.logger.Error(ctx, "show MOTD", slog.Error(err))
 				}
 			} else {
-				a.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD")
+				s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD")
 			}
 		}
 
@@ -188,7 +231,7 @@
 		// The pty package sets `SSH_TTY` on supported platforms.
 		ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
 			pty.WithSSHRequest(sshPty),
-			pty.WithLogger(slog.Stdlib(ctx, a.logger, slog.LevelInfo)),
+			pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)),
 		))
 		if err != nil {
 			return xerrors.Errorf("start command: %w", err)
@@ -198,7 +241,7 @@
 			defer wg.Wait()
 			closeErr := ptty.Close()
 			if closeErr != nil {
-				a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
+				s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
 				if retErr == nil {
 					retErr = closeErr
 				}
@@ -209,7 +252,7 @@
 				resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
 				// If the pty is closed, then command has exited, no need to log.
 				if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
-					a.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
+					s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
 				}
 			}
 		}()
@@ -258,7 +301,7 @@
 		// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
 		// and not something to be concerned about.  But, if it's something else, we should log it.
 		if err != nil && !xerrors.As(err, &exitErr) {
-			a.logger.Warn(ctx, "wait error", slog.Error(err))
+			s.logger.Warn(ctx, "wait error", slog.Error(err))
 		}
 		return err
 	}
@@ -287,7 +330,7 @@
 // Close implements io.Closer.
 func (readNopCloser) Close() error { return nil }
 
-func (a *agent) sftpHandler(session ssh.Session) {
+func (s *Server) sftpHandler(session ssh.Session) {
 	ctx := session.Context()
 
 	// Typically sftp sessions don't request a TTY, but if they do,
@@ -301,14 +344,14 @@
 	// directory so that SFTP connections land there.
 	homedir, err := userHomeDir()
 	if err != nil {
-		sshLogger.Warn(ctx, "get sftp working directory failed, unable to get home dir", slog.Error(err))
+		s.logger.Warn(ctx, "get sftp working directory failed, unable to get home dir", slog.Error(err))
 	} else {
 		opts = append(opts, sftp.WithServerWorkingDirectory(homedir))
 	}
 
 	server, err := sftp.NewServer(session, opts...)
 	if err != nil {
-		sshLogger.Debug(ctx, "initialize sftp server", slog.Error(err))
+		s.logger.Debug(ctx, "initialize sftp server", slog.Error(err))
 		return
 	}
 	defer server.Close()
@@ -326,14 +369,14 @@
 		_ = session.Exit(0)
 		return
 	}
-	sshLogger.Warn(ctx, "sftp server closed with error", slog.Error(err))
+	s.logger.Warn(ctx, "sftp server closed with error", slog.Error(err))
 	_ = session.Exit(1)
 }
 
-// createCommand processes raw command input with OpenSSH-like behavior.
+// CreateCommand processes raw command input with OpenSSH-like behavior.
 // If the script provided is empty, it will default to the users shell.
 // This injects environment variables specified by the user at launch too.
-func (a *agent) createCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) {
+func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) {
 	currentUser, err := user.Current()
 	if err != nil {
 		return nil, xerrors.Errorf("get current user: %w", err)
@@ -345,7 +388,7 @@
 		return nil, xerrors.Errorf("get user shell: %w", err)
 	}
 
-	manifest := a.manifest.Load()
+	manifest := s.manifest.Load()
 	if manifest == nil {
 		return nil, xerrors.Errorf("no metadata was provided")
 	}
@@ -398,7 +441,7 @@
 	cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath))
 
 	// Specific Coder subcommands require the agent token exposed!
-	cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", *a.sessionToken.Load()))
+	cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", s.AgentToken()))
 
 	// Set SSH connection environment variables (these are also set by OpenSSH
 	// and thus expected to be present by SSH clients). Since the agent does
@@ -427,13 +470,34 @@
 
 	// Agent-level environment variables should take over all!
 	// This is used for setting agent-specific variables like "CODER_AGENT_TOKEN".
-	for envKey, value := range a.envVars {
+	for envKey, value := range s.Env {
 		cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value))
 	}
 
 	return cmd, nil
 }
 
+func (s *Server) Serve(l net.Listener) error {
+	s.serveWg.Add(1)
+	defer s.serveWg.Done()
+	return s.srv.Serve(l)
+}
+
+func (s *Server) Close() error {
+	err := s.srv.Close()
+	s.serveWg.Wait()
+	return err
+}
+
+// Shutdown gracefully closes all active SSH connections and stops
+// accepting new connections.
+//
+// Shutdown is not implemented.
+func (*Server) Shutdown(ctx context.Context) error {
+	// TODO(mafredri): Implement shutdown, SIGHUP running commands, etc.
+	return nil
+}
+
 // isQuietLogin checks if the SSH server should perform a quiet login or not.
 //
 // https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L816

func (s *Server) Serve(l net.Listener) error {
s.serveWg.Add(1)
defer s.serveWg.Done()
return s.srv.Serve(l)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review: In an effort to slightly simplify this package, we currently rely on the SSH servers Serve method. This was previously managing the connections manually but this seems to serve the same purpose.

@mafredri mafredri marked this pull request as ready for review April 4, 2023 18:03
@mafredri
Copy link
Member Author

mafredri commented Apr 4, 2023

Tests caught a little problem in the ssh package when using (*ssh.Server).Serve, attempted to fix it in coder/ssh#1.

Copy link
Member

@deansheather deansheather left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. I recently added and merged another call to agent.Bicopy so you'll need to merge/rebase to avoid breaking main

//
// Shutdown is not implemented.
func (*Server) Shutdown(_ context.Context) error {
// TODO(mafredri): Implement shutdown, SIGHUP running commands, etc.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this not what srv.Close does? I would assume it returns errors in the server functions, which would then cause any deferred process.Kill()s or whatever we have to fire off.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently Close abruptly closes connections in a way that 1) leaves clients hanging waiting for timeout and 2) sighup isn't properly propagated to running processes.

Issue 1) is easily solved by tracking sessions (follow-up PR) but that won't help with process closure.

In theory I guess all of this could be handled by Close, but it would be good to allow process closure to take a few seconds (Shutdown API is more suitable for this than blocking in Close). Agent would ultimately first do a Shutdown and finally a Close at the end.

More on this in #6177 (more of a guideline than defining the final implementation).

@mafredri mafredri requested a review from deansheather April 6, 2023 09:57
@mafredri
Copy link
Member Author

mafredri commented Apr 6, 2023

@deansheather I've addressed all the feedback, feel free to take another look.

PS. I went and implemented better connection tracking within agentssh.Server so that close is more robust. This is so we don't rely on the dodgy fix in coder/ssh#1 and is conducive to my follow-up PR which will track sessions.

// closing, the function will block until the server is closed.
//
//nolint:revive
func (s *Server) trackListener(l net.Listener, add bool) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems quite complicated since we only have one listener ever being served from what I can tell

Copy link
Member Author

@mafredri mafredri Apr 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a suggestion for a simplification? I wanted this package to be able to manage it's own state and give guarantees for close/shutdown. This is in part motivated by the current setup of tailnet in the agent, which can re-run if an error is encountered (i.e. after a call to ssh server Serve).

(We also can't rely on the ssh package because it has broken guarantees in this regard.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just have a single listener on the struct instead of a map, but still add to the waitgroup? It seems that it's written in a way where it can be reused after close by calling serve again, but I don't believe we use that anywhere so it seems unnecessary.

Copy link
Member Author

@mafredri mafredri Apr 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that it's written in a way where it can be reused after close by calling serve again

That's actually what happens if createTailnet encounters an error and a new tailnet is set up in the next retry, Serve will be called again.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, could you make it so the createTailnet function recreates the SSH server when it wants to recreate the tailnet?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typically in go structs when Close is called it's dead forever, so this seems to not match what most people would expect

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to do that, perhaps something for a future refactor? For now I'd like to keep the functionality similar to what it was before. And I think a little complexity contained in a package is fine.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IDK, I don't think it's just a little complexity. The s.closed loop took me multiple read throughs to understand what it was trying to do. You should get a second opinion

@mafredri mafredri merged commit 0224426 into main Apr 6, 2023
@mafredri mafredri deleted the mafredri/refactor-agent-sshd branch April 6, 2023 16:39
@github-actions github-actions bot locked and limited conversation to collaborators Apr 6, 2023
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants