diff --git a/agent/agent.go b/agent/agent.go index c7a785f8d5da1..5512f04db28ea 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -91,6 +91,7 @@ type Options struct { ModifiedProcesses chan []*agentproc.Process // ProcessManagementTick is used for testing process priority management. ProcessManagementTick <-chan time.Time + BlockFileTransfer bool } type Client interface { @@ -184,6 +185,7 @@ func New(options Options) Agent { modifiedProcs: options.ModifiedProcesses, processManagementTick: options.ProcessManagementTick, logSender: agentsdk.NewLogSender(options.Logger), + blockFileTransfer: options.BlockFileTransfer, prometheusRegistry: prometheusRegistry, metrics: newAgentMetrics(prometheusRegistry), @@ -239,6 +241,7 @@ type agent struct { sessionToken atomic.Pointer[string] sshServer *agentssh.Server sshMaxTimeout time.Duration + blockFileTransfer bool lifecycleUpdate chan struct{} lifecycleReported chan codersdk.WorkspaceAgentLifecycle @@ -277,6 +280,7 @@ func (a *agent) init() { AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() }, UpdateEnv: a.updateCommandEnv, WorkingDirectory: func() string { return a.manifest.Load().Directory }, + BlockFileTransfer: a.blockFileTransfer, }) if err != nil { panic(err) diff --git a/agent/agent_test.go b/agent/agent_test.go index a008a60a2362e..4b0712bcf93c6 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -970,6 +970,99 @@ func TestAgent_SCP(t *testing.T) { require.NoError(t, err) } +func TestAgent_FileTransferBlocked(t *testing.T) { + t.Parallel() + + assertFileTransferBlocked := func(t *testing.T, errorMessage string) { + // NOTE: Checking content of the error message is flaky. Most likely there is a race condition, which results + // in stopping the client in different phases, and returning different errors: + // - client read the full error message: File transfer has been disabled. + // - client's stream was terminated before reading the error message: EOF + // - client just read the error code (Windows): Process exited with status 65 + isErr := strings.Contains(errorMessage, agentssh.BlockedFileTransferErrorMessage) || + strings.Contains(errorMessage, "EOF") || + strings.Contains(errorMessage, "Process exited with status 65") + require.True(t, isErr, fmt.Sprintf("Message: "+errorMessage)) + } + + t.Run("SFTP", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + //nolint:dogsled + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + o.BlockFileTransfer = true + }) + sshClient, err := conn.SSHClient(ctx) + require.NoError(t, err) + defer sshClient.Close() + _, err = sftp.NewClient(sshClient) + require.Error(t, err) + assertFileTransferBlocked(t, err.Error()) + }) + + t.Run("SCP with go-scp package", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + //nolint:dogsled + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + o.BlockFileTransfer = true + }) + sshClient, err := conn.SSHClient(ctx) + require.NoError(t, err) + defer sshClient.Close() + scpClient, err := scp.NewClientBySSH(sshClient) + require.NoError(t, err) + defer scpClient.Close() + tempFile := filepath.Join(t.TempDir(), "scp") + err = scpClient.CopyFile(context.Background(), strings.NewReader("hello world"), tempFile, "0755") + require.Error(t, err) + assertFileTransferBlocked(t, err.Error()) + }) + + t.Run("Forbidden commands", func(t *testing.T) { + t.Parallel() + + for _, c := range agentssh.BlockedFileTransferCommands { + t.Run(c, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + //nolint:dogsled + conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) { + o.BlockFileTransfer = true + }) + sshClient, err := conn.SSHClient(ctx) + require.NoError(t, err) + defer sshClient.Close() + + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() + + stdout, err := session.StdoutPipe() + require.NoError(t, err) + + //nolint:govet // we don't need `c := c` in Go 1.22 + err = session.Start(c) + require.NoError(t, err) + defer session.Close() + + msg, err := io.ReadAll(stdout) + require.NoError(t, err) + assertFileTransferBlocked(t, string(msg)) + }) + } + }) +} + func TestAgent_EnvironmentVariables(t *testing.T) { t.Parallel() key := "EXAMPLE" diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 54e5a3f41223e..5903220975b8c 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -52,8 +52,16 @@ const ( // MagicProcessCmdlineJetBrains is a string in a process's command line that // uniquely identifies it as JetBrains software. MagicProcessCmdlineJetBrains = "idea.vendor.name=JetBrains" + + // BlockedFileTransferErrorCode indicates that SSH server restricted the raw command from performing + // the file transfer. + BlockedFileTransferErrorCode = 65 // Error code: host not allowed to connect + BlockedFileTransferErrorMessage = "File transfer has been disabled." ) +// BlockedFileTransferCommands contains a list of restricted file transfer commands. +var BlockedFileTransferCommands = []string{"nc", "rsync", "scp", "sftp"} + // Config sets configuration parameters for the agent SSH server. type Config struct { // MaxTimeout sets the absolute connection timeout, none if empty. If set to @@ -74,6 +82,8 @@ type Config struct { // X11SocketDir is the directory where X11 sockets are created. Default is // /tmp/.X11-unix. X11SocketDir string + // BlockFileTransfer restricts use of file transfer applications. + BlockFileTransfer bool } type Server struct { @@ -272,6 +282,18 @@ func (s *Server) sessionHandler(session ssh.Session) { extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=:%d.0", x11.ScreenNumber)) } + if s.fileTransferBlocked(session) { + s.logger.Warn(ctx, "file transfer blocked", slog.F("session_subsystem", session.Subsystem()), slog.F("raw_command", session.RawCommand())) + + if session.Subsystem() == "" { // sftp does not expect error, otherwise it fails with "package too long" + // Response format: \n + errorMessage := fmt.Sprintf("\x02%s\n", BlockedFileTransferErrorMessage) + _, _ = session.Write([]byte(errorMessage)) + } + _ = session.Exit(BlockedFileTransferErrorCode) + return + } + switch ss := session.Subsystem(); ss { case "": case "sftp": @@ -322,6 +344,37 @@ func (s *Server) sessionHandler(session ssh.Session) { _ = session.Exit(0) } +// fileTransferBlocked method checks if the file transfer commands should be blocked. +// +// Warning: consider this mechanism as "Do not trespass" sign, as a violator can still ssh to the host, +// smuggle the `scp` binary, or just manually send files outside with `curl` or `ftp`. +// If a user needs a more sophisticated and battle-proof solution, consider full endpoint security. +func (s *Server) fileTransferBlocked(session ssh.Session) bool { + if !s.config.BlockFileTransfer { + return false // file transfers are permitted + } + // File transfers are restricted. + + if session.Subsystem() == "sftp" { + return true + } + + cmd := session.Command() + if len(cmd) == 0 { + return false // no command? + } + + c := cmd[0] + c = filepath.Base(c) // in case the binary is absolute path, /usr/sbin/scp + + for _, cmd := range BlockedFileTransferCommands { + if cmd == c { + return true + } + } + return false +} + func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv []string) (retErr error) { ctx := session.Context() env := append(session.Environ(), extraEnv...) diff --git a/cli/agent.go b/cli/agent.go index 1f91f1c98bb8d..5465aeedd9302 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -27,6 +27,7 @@ import ( "cdr.dev/slog/sloggers/slogstackdriver" "github.com/coder/coder/v2/agent" "github.com/coder/coder/v2/agent/agentproc" + "github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/reaper" "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/codersdk" @@ -48,6 +49,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { slogHumanPath string slogJSONPath string slogStackdriverPath string + blockFileTransfer bool ) cmd := &serpent.Command{ Use: "agent", @@ -314,6 +316,8 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { // Intentionally set this to nil. It's mainly used // for testing. ModifiedProcesses: nil, + + BlockFileTransfer: blockFileTransfer, }) promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger) @@ -417,6 +421,13 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { Default: "", Value: serpent.StringOf(&slogStackdriverPath), }, + { + Flag: "block-file-transfer", + Default: "false", + Env: "CODER_AGENT_BLOCK_FILE_TRANSFER", + Description: fmt.Sprintf("Block file transfer using known applications: %s.", strings.Join(agentssh.BlockedFileTransferCommands, ",")), + Value: serpent.BoolOf(&blockFileTransfer), + }, } return cmd diff --git a/cli/testdata/coder_agent_--help.golden b/cli/testdata/coder_agent_--help.golden index 372395c4ba5fe..d6982fda18e7c 100644 --- a/cli/testdata/coder_agent_--help.golden +++ b/cli/testdata/coder_agent_--help.golden @@ -18,6 +18,9 @@ OPTIONS: --auth string, $CODER_AGENT_AUTH (default: token) Specify the authentication type to use for the agent. + --block-file-transfer bool, $CODER_AGENT_BLOCK_FILE_TRANSFER (default: false) + Block file transfer using known applications: nc,rsync,scp,sftp. + --debug-address string, $CODER_AGENT_DEBUG_ADDRESS (default: 127.0.0.1:2113) The bind address to serve a debug HTTP server.