Skip to content

Commit bed1bd4

Browse files
committed
Merge branch 'main' into org-management-ui
2 parents 050d553 + 5b9a65e commit bed1bd4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+862
-342
lines changed

agent/agent.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ type Options struct {
9191
ModifiedProcesses chan []*agentproc.Process
9292
// ProcessManagementTick is used for testing process priority management.
9393
ProcessManagementTick <-chan time.Time
94+
BlockFileTransfer bool
9495
}
9596

9697
type Client interface {
@@ -184,6 +185,7 @@ func New(options Options) Agent {
184185
modifiedProcs: options.ModifiedProcesses,
185186
processManagementTick: options.ProcessManagementTick,
186187
logSender: agentsdk.NewLogSender(options.Logger),
188+
blockFileTransfer: options.BlockFileTransfer,
187189

188190
prometheusRegistry: prometheusRegistry,
189191
metrics: newAgentMetrics(prometheusRegistry),
@@ -239,6 +241,7 @@ type agent struct {
239241
sessionToken atomic.Pointer[string]
240242
sshServer *agentssh.Server
241243
sshMaxTimeout time.Duration
244+
blockFileTransfer bool
242245

243246
lifecycleUpdate chan struct{}
244247
lifecycleReported chan codersdk.WorkspaceAgentLifecycle
@@ -277,6 +280,7 @@ func (a *agent) init() {
277280
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
278281
UpdateEnv: a.updateCommandEnv,
279282
WorkingDirectory: func() string { return a.manifest.Load().Directory },
283+
BlockFileTransfer: a.blockFileTransfer,
280284
})
281285
if err != nil {
282286
panic(err)

agent/agent_test.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,99 @@ func TestAgent_SCP(t *testing.T) {
970970
require.NoError(t, err)
971971
}
972972

973+
func TestAgent_FileTransferBlocked(t *testing.T) {
974+
t.Parallel()
975+
976+
assertFileTransferBlocked := func(t *testing.T, errorMessage string) {
977+
// NOTE: Checking content of the error message is flaky. Most likely there is a race condition, which results
978+
// in stopping the client in different phases, and returning different errors:
979+
// - client read the full error message: File transfer has been disabled.
980+
// - client's stream was terminated before reading the error message: EOF
981+
// - client just read the error code (Windows): Process exited with status 65
982+
isErr := strings.Contains(errorMessage, agentssh.BlockedFileTransferErrorMessage) ||
983+
strings.Contains(errorMessage, "EOF") ||
984+
strings.Contains(errorMessage, "Process exited with status 65")
985+
require.True(t, isErr, fmt.Sprintf("Message: "+errorMessage))
986+
}
987+
988+
t.Run("SFTP", func(t *testing.T) {
989+
t.Parallel()
990+
991+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
992+
defer cancel()
993+
994+
//nolint:dogsled
995+
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
996+
o.BlockFileTransfer = true
997+
})
998+
sshClient, err := conn.SSHClient(ctx)
999+
require.NoError(t, err)
1000+
defer sshClient.Close()
1001+
_, err = sftp.NewClient(sshClient)
1002+
require.Error(t, err)
1003+
assertFileTransferBlocked(t, err.Error())
1004+
})
1005+
1006+
t.Run("SCP with go-scp package", func(t *testing.T) {
1007+
t.Parallel()
1008+
1009+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
1010+
defer cancel()
1011+
1012+
//nolint:dogsled
1013+
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
1014+
o.BlockFileTransfer = true
1015+
})
1016+
sshClient, err := conn.SSHClient(ctx)
1017+
require.NoError(t, err)
1018+
defer sshClient.Close()
1019+
scpClient, err := scp.NewClientBySSH(sshClient)
1020+
require.NoError(t, err)
1021+
defer scpClient.Close()
1022+
tempFile := filepath.Join(t.TempDir(), "scp")
1023+
err = scpClient.CopyFile(context.Background(), strings.NewReader("hello world"), tempFile, "0755")
1024+
require.Error(t, err)
1025+
assertFileTransferBlocked(t, err.Error())
1026+
})
1027+
1028+
t.Run("Forbidden commands", func(t *testing.T) {
1029+
t.Parallel()
1030+
1031+
for _, c := range agentssh.BlockedFileTransferCommands {
1032+
t.Run(c, func(t *testing.T) {
1033+
t.Parallel()
1034+
1035+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
1036+
defer cancel()
1037+
1038+
//nolint:dogsled
1039+
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
1040+
o.BlockFileTransfer = true
1041+
})
1042+
sshClient, err := conn.SSHClient(ctx)
1043+
require.NoError(t, err)
1044+
defer sshClient.Close()
1045+
1046+
session, err := sshClient.NewSession()
1047+
require.NoError(t, err)
1048+
defer session.Close()
1049+
1050+
stdout, err := session.StdoutPipe()
1051+
require.NoError(t, err)
1052+
1053+
//nolint:govet // we don't need `c := c` in Go 1.22
1054+
err = session.Start(c)
1055+
require.NoError(t, err)
1056+
defer session.Close()
1057+
1058+
msg, err := io.ReadAll(stdout)
1059+
require.NoError(t, err)
1060+
assertFileTransferBlocked(t, string(msg))
1061+
})
1062+
}
1063+
})
1064+
}
1065+
9731066
func TestAgent_EnvironmentVariables(t *testing.T) {
9741067
t.Parallel()
9751068
key := "EXAMPLE"

agent/agentssh/agentssh.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,16 @@ const (
5252
// MagicProcessCmdlineJetBrains is a string in a process's command line that
5353
// uniquely identifies it as JetBrains software.
5454
MagicProcessCmdlineJetBrains = "idea.vendor.name=JetBrains"
55+
56+
// BlockedFileTransferErrorCode indicates that SSH server restricted the raw command from performing
57+
// the file transfer.
58+
BlockedFileTransferErrorCode = 65 // Error code: host not allowed to connect
59+
BlockedFileTransferErrorMessage = "File transfer has been disabled."
5560
)
5661

62+
// BlockedFileTransferCommands contains a list of restricted file transfer commands.
63+
var BlockedFileTransferCommands = []string{"nc", "rsync", "scp", "sftp"}
64+
5765
// Config sets configuration parameters for the agent SSH server.
5866
type Config struct {
5967
// MaxTimeout sets the absolute connection timeout, none if empty. If set to
@@ -74,6 +82,8 @@ type Config struct {
7482
// X11SocketDir is the directory where X11 sockets are created. Default is
7583
// /tmp/.X11-unix.
7684
X11SocketDir string
85+
// BlockFileTransfer restricts use of file transfer applications.
86+
BlockFileTransfer bool
7787
}
7888

7989
type Server struct {
@@ -272,6 +282,18 @@ func (s *Server) sessionHandler(session ssh.Session) {
272282
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=:%d.0", x11.ScreenNumber))
273283
}
274284

285+
if s.fileTransferBlocked(session) {
286+
s.logger.Warn(ctx, "file transfer blocked", slog.F("session_subsystem", session.Subsystem()), slog.F("raw_command", session.RawCommand()))
287+
288+
if session.Subsystem() == "" { // sftp does not expect error, otherwise it fails with "package too long"
289+
// Response format: <status_code><message body>\n
290+
errorMessage := fmt.Sprintf("\x02%s\n", BlockedFileTransferErrorMessage)
291+
_, _ = session.Write([]byte(errorMessage))
292+
}
293+
_ = session.Exit(BlockedFileTransferErrorCode)
294+
return
295+
}
296+
275297
switch ss := session.Subsystem(); ss {
276298
case "":
277299
case "sftp":
@@ -322,6 +344,37 @@ func (s *Server) sessionHandler(session ssh.Session) {
322344
_ = session.Exit(0)
323345
}
324346

347+
// fileTransferBlocked method checks if the file transfer commands should be blocked.
348+
//
349+
// Warning: consider this mechanism as "Do not trespass" sign, as a violator can still ssh to the host,
350+
// smuggle the `scp` binary, or just manually send files outside with `curl` or `ftp`.
351+
// If a user needs a more sophisticated and battle-proof solution, consider full endpoint security.
352+
func (s *Server) fileTransferBlocked(session ssh.Session) bool {
353+
if !s.config.BlockFileTransfer {
354+
return false // file transfers are permitted
355+
}
356+
// File transfers are restricted.
357+
358+
if session.Subsystem() == "sftp" {
359+
return true
360+
}
361+
362+
cmd := session.Command()
363+
if len(cmd) == 0 {
364+
return false // no command?
365+
}
366+
367+
c := cmd[0]
368+
c = filepath.Base(c) // in case the binary is absolute path, /usr/sbin/scp
369+
370+
for _, cmd := range BlockedFileTransferCommands {
371+
if cmd == c {
372+
return true
373+
}
374+
}
375+
return false
376+
}
377+
325378
func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv []string) (retErr error) {
326379
ctx := session.Context()
327380
env := append(session.Environ(), extraEnv...)

cli/agent.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"cdr.dev/slog/sloggers/slogstackdriver"
2828
"github.com/coder/coder/v2/agent"
2929
"github.com/coder/coder/v2/agent/agentproc"
30+
"github.com/coder/coder/v2/agent/agentssh"
3031
"github.com/coder/coder/v2/agent/reaper"
3132
"github.com/coder/coder/v2/buildinfo"
3233
"github.com/coder/coder/v2/codersdk"
@@ -48,6 +49,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
4849
slogHumanPath string
4950
slogJSONPath string
5051
slogStackdriverPath string
52+
blockFileTransfer bool
5153
)
5254
cmd := &serpent.Command{
5355
Use: "agent",
@@ -314,6 +316,8 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
314316
// Intentionally set this to nil. It's mainly used
315317
// for testing.
316318
ModifiedProcesses: nil,
319+
320+
BlockFileTransfer: blockFileTransfer,
317321
})
318322

319323
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
@@ -417,6 +421,13 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
417421
Default: "",
418422
Value: serpent.StringOf(&slogStackdriverPath),
419423
},
424+
{
425+
Flag: "block-file-transfer",
426+
Default: "false",
427+
Env: "CODER_AGENT_BLOCK_FILE_TRANSFER",
428+
Description: fmt.Sprintf("Block file transfer using known applications: %s.", strings.Join(agentssh.BlockedFileTransferCommands, ",")),
429+
Value: serpent.BoolOf(&blockFileTransfer),
430+
},
420431
}
421432

422433
return cmd

cli/server.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ import (
6262
"github.com/coder/coder/v2/cli/config"
6363
"github.com/coder/coder/v2/coderd"
6464
"github.com/coder/coder/v2/coderd/autobuild"
65-
"github.com/coder/coder/v2/coderd/batchstats"
6665
"github.com/coder/coder/v2/coderd/database"
6766
"github.com/coder/coder/v2/coderd/database/awsiamrds"
6867
"github.com/coder/coder/v2/coderd/database/dbmem"
@@ -87,7 +86,7 @@ import (
8786
stringutil "github.com/coder/coder/v2/coderd/util/strings"
8887
"github.com/coder/coder/v2/coderd/workspaceapps"
8988
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
90-
"github.com/coder/coder/v2/coderd/workspaceusage"
89+
"github.com/coder/coder/v2/coderd/workspacestats"
9190
"github.com/coder/coder/v2/codersdk"
9291
"github.com/coder/coder/v2/codersdk/drpc"
9392
"github.com/coder/coder/v2/cryptorand"
@@ -870,9 +869,9 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
870869
options.SwaggerEndpoint = vals.Swagger.Enable.Value()
871870
}
872871

873-
batcher, closeBatcher, err := batchstats.New(ctx,
874-
batchstats.WithLogger(options.Logger.Named("batchstats")),
875-
batchstats.WithStore(options.Database),
872+
batcher, closeBatcher, err := workspacestats.NewBatcher(ctx,
873+
workspacestats.BatcherWithLogger(options.Logger.Named("batchstats")),
874+
workspacestats.BatcherWithStore(options.Database),
876875
)
877876
if err != nil {
878877
return xerrors.Errorf("failed to create agent stats batcher: %w", err)
@@ -977,8 +976,8 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
977976
defer purger.Close()
978977

979978
// Updates workspace usage
980-
tracker := workspaceusage.New(options.Database,
981-
workspaceusage.WithLogger(logger.Named("workspace_usage_tracker")),
979+
tracker := workspacestats.NewTracker(options.Database,
980+
workspacestats.TrackerWithLogger(logger.Named("workspace_usage_tracker")),
982981
)
983982
options.WorkspaceUsageTracker = tracker
984983
defer tracker.Close()

cli/testdata/coder_agent_--help.golden

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ OPTIONS:
1818
--auth string, $CODER_AGENT_AUTH (default: token)
1919
Specify the authentication type to use for the agent.
2020

21+
--block-file-transfer bool, $CODER_AGENT_BLOCK_FILE_TRANSFER (default: false)
22+
Block file transfer using known applications: nc,rsync,scp,sftp.
23+
2124
--debug-address string, $CODER_AGENT_DEBUG_ADDRESS (default: 127.0.0.1:2113)
2225
The bind address to serve a debug HTTP server.
2326

clock/clock.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ type Clock interface {
2626
Now(tags ...string) time.Time
2727
// Since returns the time elapsed since t. It is shorthand for Clock.Now().Sub(t).
2828
Since(t time.Time, tags ...string) time.Duration
29+
// Until returns the duration until t. It is shorthand for t.Sub(Clock.Now()).
30+
Until(t time.Time, tags ...string) time.Duration
2931
}
3032

3133
// Waiter can be waited on for an error.

clock/example_test.go

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func TestExampleTickerFunc(t *testing.T) {
4444
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
4545
defer cancel()
4646

47-
mClock := clock.NewMock()
47+
mClock := clock.NewMock(t)
4848

4949
// Because the ticker is started on a goroutine, we can't immediately start
5050
// advancing the clock, or we will race with the start of the ticker. If we
@@ -76,9 +76,74 @@ func TestExampleTickerFunc(t *testing.T) {
7676
}
7777

7878
// Now that we know the ticker is started, we can advance the time.
79-
mClock.Advance(time.Hour).MustWait(ctx, t)
79+
mClock.Advance(time.Hour).MustWait(ctx)
8080

8181
if tks := tc.Ticks(); tks != 1 {
8282
t.Fatalf("expected 1 got %d ticks", tks)
8383
}
8484
}
85+
86+
type exampleLatencyMeasurer struct {
87+
mu sync.Mutex
88+
lastLatency time.Duration
89+
}
90+
91+
func newExampleLatencyMeasurer(ctx context.Context, clk clock.Clock) *exampleLatencyMeasurer {
92+
m := &exampleLatencyMeasurer{}
93+
clk.TickerFunc(ctx, 10*time.Second, func() error {
94+
start := clk.Now()
95+
// m.doSomething()
96+
latency := clk.Since(start)
97+
m.mu.Lock()
98+
defer m.mu.Unlock()
99+
m.lastLatency = latency
100+
return nil
101+
})
102+
return m
103+
}
104+
105+
func (m *exampleLatencyMeasurer) LastLatency() time.Duration {
106+
m.mu.Lock()
107+
defer m.mu.Unlock()
108+
return m.lastLatency
109+
}
110+
111+
func TestExampleLatencyMeasurer(t *testing.T) {
112+
t.Parallel()
113+
114+
// nolint:gocritic // trying to avoid Coder-specific stuff with an eye toward spinning this out
115+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
116+
defer cancel()
117+
118+
mClock := clock.NewMock(t)
119+
trap := mClock.Trap().Since()
120+
defer trap.Close()
121+
122+
lm := newExampleLatencyMeasurer(ctx, mClock)
123+
124+
w := mClock.Advance(10 * time.Second) // triggers first tick
125+
c := trap.MustWait(ctx) // call to Since()
126+
mClock.Advance(33 * time.Millisecond)
127+
c.Release()
128+
w.MustWait(ctx)
129+
130+
if l := lm.LastLatency(); l != 33*time.Millisecond {
131+
t.Fatalf("expected 33ms got %s", l.String())
132+
}
133+
134+
// Next tick is in 10s - 33ms, but if we don't want to calculate, we can use:
135+
d, w2 := mClock.AdvanceNext()
136+
c = trap.MustWait(ctx)
137+
mClock.Advance(17 * time.Millisecond)
138+
c.Release()
139+
w2.MustWait(ctx)
140+
141+
expectedD := 10*time.Second - 33*time.Millisecond
142+
if d != expectedD {
143+
t.Fatalf("expected %s got %s", expectedD.String(), d.String())
144+
}
145+
146+
if l := lm.LastLatency(); l != 17*time.Millisecond {
147+
t.Fatalf("expected 17ms got %s", l.String())
148+
}
149+
}

0 commit comments

Comments
 (0)