Skip to content

Commit 1b8c15c

Browse files
committed
feat(agent): add connection reporting for SSH and reconnecting PTY
Updates #15139
1 parent b07b33e commit 1b8c15c

File tree

6 files changed

+307
-32
lines changed

6 files changed

+307
-32
lines changed

agent/agent.go

+111
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"golang.org/x/exp/slices"
2828
"golang.org/x/sync/errgroup"
2929
"golang.org/x/xerrors"
30+
"google.golang.org/protobuf/types/known/timestamppb"
3031
"tailscale.com/net/speedtest"
3132
"tailscale.com/tailcfg"
3233
"tailscale.com/types/netlogtype"
@@ -174,6 +175,7 @@ func New(options Options) Agent {
174175
lifecycleUpdate: make(chan struct{}, 1),
175176
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
176177
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
178+
reportConnectionsUpdate: make(chan struct{}, 1),
177179
ignorePorts: options.IgnorePorts,
178180
portCacheDuration: options.PortCacheDuration,
179181
reportMetadataInterval: options.ReportMetadataInterval,
@@ -247,6 +249,10 @@ type agent struct {
247249
lifecycleStates []agentsdk.PostLifecycleRequest
248250
lifecycleLastReportedIndex int // Keeps track of the last lifecycle state we successfully reported.
249251

252+
reportConnectionsUpdate chan struct{}
253+
reportConnectionsMu sync.Mutex
254+
reportConnections []*proto.ReportConnectionRequest
255+
250256
network *tailnet.Conn
251257
statsReporter *statsReporter
252258
logSender *agentsdk.LogSender
@@ -272,6 +278,25 @@ func (a *agent) init() {
272278
UpdateEnv: a.updateCommandEnv,
273279
WorkingDirectory: func() string { return a.manifest.Load().Directory },
274280
BlockFileTransfer: a.blockFileTransfer,
281+
ReportConnection: func(id uuid.UUID, magicType agentssh.MagicSessionType, ip string) func(code int, reason string) {
282+
a.logger.Info(a.hardCtx, "reporting connection", slog.F("id", id), slog.F("magic_type", magicType), slog.F("ip", ip))
283+
var connectionType proto.Connection_Type
284+
switch magicType {
285+
case agentssh.MagicSessionTypeSSH:
286+
connectionType = proto.Connection_SSH
287+
case agentssh.MagicSessionTypeVSCode:
288+
connectionType = proto.Connection_VSCODE
289+
case agentssh.MagicSessionTypeJetBrains:
290+
connectionType = proto.Connection_JETBRAINS
291+
case agentssh.MagicSessionTypeUnknown:
292+
connectionType = proto.Connection_TYPE_UNSPECIFIED
293+
default:
294+
a.logger.Error(a.hardCtx, "unhandled magic session type when reporting connection", slog.F("magic_type", magicType))
295+
connectionType = proto.Connection_TYPE_UNSPECIFIED
296+
}
297+
298+
return a.reportConnection(id, connectionType, ip)
299+
},
275300
})
276301
if err != nil {
277302
panic(err)
@@ -294,6 +319,9 @@ func (a *agent) init() {
294319
a.reconnectingPTYServer = reconnectingpty.NewServer(
295320
a.logger.Named("reconnecting-pty"),
296321
a.sshServer,
322+
func(id uuid.UUID, ip string) func(code int, reason string) {
323+
return a.reportConnection(id, proto.Connection_RECONNECTING_PTY, ip)
324+
},
297325
a.metrics.connectionsTotal, a.metrics.reconnectingPTYErrors,
298326
a.reconnectingPTYTimeout,
299327
)
@@ -703,6 +731,85 @@ func (a *agent) setLifecycle(state codersdk.WorkspaceAgentLifecycle) {
703731
}
704732
}
705733

734+
// reportConnectionsLoop reports connections to the agent for auditing.
735+
func (a *agent) reportConnectionsLoop(ctx context.Context, aAPI proto.DRPCAgentClient24) error {
736+
for {
737+
select {
738+
case <-a.reportConnectionsUpdate:
739+
case <-ctx.Done():
740+
return ctx.Err()
741+
}
742+
743+
for {
744+
a.reportConnectionsMu.Lock()
745+
if len(a.reportConnections) == 0 {
746+
a.reportConnectionsMu.Unlock()
747+
break
748+
}
749+
payload := a.reportConnections[0]
750+
a.reportConnectionsMu.Unlock()
751+
752+
logger := a.logger.With(slog.F("payload", payload))
753+
logger.Debug(ctx, "reporting connection")
754+
_, err := aAPI.ReportConnection(ctx, payload)
755+
if err != nil {
756+
return xerrors.Errorf("failed to report connection: %w", err)
757+
}
758+
759+
logger.Debug(ctx, "successfully reported connection")
760+
761+
a.reportConnectionsMu.Lock()
762+
a.reportConnections = a.reportConnections[1:]
763+
count := len(a.reportConnections)
764+
a.reportConnectionsMu.Unlock()
765+
766+
if count == 0 {
767+
break
768+
}
769+
}
770+
}
771+
}
772+
773+
func (a *agent) reportConnection(id uuid.UUID, connectionType proto.Connection_Type, ip string) (disconnected func(code int, reason string)) {
774+
a.reportConnectionsMu.Lock()
775+
defer a.reportConnectionsMu.Unlock()
776+
a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{
777+
Connection: &proto.Connection{
778+
Id: id[:],
779+
Action: proto.Connection_CONNECT,
780+
Type: connectionType,
781+
Timestamp: timestamppb.New(time.Now()),
782+
Ip: ip,
783+
StatusCode: 0,
784+
Reason: nil,
785+
},
786+
})
787+
select {
788+
case a.reportConnectionsUpdate <- struct{}{}:
789+
default:
790+
}
791+
792+
return func(code int, reason string) {
793+
a.reportConnectionsMu.Lock()
794+
defer a.reportConnectionsMu.Unlock()
795+
a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{
796+
Connection: &proto.Connection{
797+
Id: id[:],
798+
Action: proto.Connection_DISCONNECT,
799+
Type: connectionType,
800+
Timestamp: timestamppb.New(time.Now()),
801+
Ip: ip,
802+
StatusCode: int32(code), //nolint:gosec
803+
Reason: &reason,
804+
},
805+
})
806+
select {
807+
case a.reportConnectionsUpdate <- struct{}{}:
808+
default:
809+
}
810+
}
811+
}
812+
706813
// fetchServiceBannerLoop fetches the service banner on an interval. It will
707814
// not be fetched immediately; the expectation is that it is primed elsewhere
708815
// (and must be done before the session actually starts).
@@ -813,6 +920,10 @@ func (a *agent) run() (retErr error) {
813920
return resourcesmonitor.Start(ctx)
814921
})
815922

923+
// Connection reports are part of auditing, so we should keep sending them via
924+
// gracefulShutdownBehaviorRemain.
925+
connMan.startAgentAPI("report connections", gracefulShutdownBehaviorRemain, a.reportConnectionsLoop)
926+
816927
// channels to sync goroutines below
817928
// handle manifest
818929
// |

agent/agent_test.go

+66-8
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
159159
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
160160
defer cancel()
161161
//nolint:dogsled
162-
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
162+
conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
163163
sshClient, err := conn.SSHClient(ctx)
164164
require.NoError(t, err)
165165
defer sshClient.Close()
@@ -189,6 +189,8 @@ func TestAgent_Stats_Magic(t *testing.T) {
189189
_ = stdin.Close()
190190
err = session.Wait()
191191
require.NoError(t, err)
192+
193+
assertConnectionReport(t, agentClient, proto.Connection_VSCODE, 0, "")
192194
})
193195

194196
t.Run("TracksJetBrains", func(t *testing.T) {
@@ -225,7 +227,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
225227
remotePort := sc.Text()
226228

227229
//nolint:dogsled
228-
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
230+
conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
229231
sshClient, err := conn.SSHClient(ctx)
230232
require.NoError(t, err)
231233

@@ -261,6 +263,8 @@ func TestAgent_Stats_Magic(t *testing.T) {
261263
}, testutil.WaitLong, testutil.IntervalFast,
262264
"never saw stats after conn closes",
263265
)
266+
267+
assertConnectionReport(t, agentClient, proto.Connection_JETBRAINS, 0, "")
264268
})
265269
}
266270

@@ -918,7 +922,7 @@ func TestAgent_SFTP(t *testing.T) {
918922
home = "/" + strings.ReplaceAll(home, "\\", "/")
919923
}
920924
//nolint:dogsled
921-
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
925+
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
922926
sshClient, err := conn.SSHClient(ctx)
923927
require.NoError(t, err)
924928
defer sshClient.Close()
@@ -941,6 +945,10 @@ func TestAgent_SFTP(t *testing.T) {
941945
require.NoError(t, err)
942946
_, err = os.Stat(tempFile)
943947
require.NoError(t, err)
948+
949+
// Close the client to trigger disconnect event.
950+
_ = client.Close()
951+
assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "")
944952
}
945953

946954
func TestAgent_SCP(t *testing.T) {
@@ -950,7 +958,7 @@ func TestAgent_SCP(t *testing.T) {
950958
defer cancel()
951959

952960
//nolint:dogsled
953-
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
961+
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
954962
sshClient, err := conn.SSHClient(ctx)
955963
require.NoError(t, err)
956964
defer sshClient.Close()
@@ -963,6 +971,10 @@ func TestAgent_SCP(t *testing.T) {
963971
require.NoError(t, err)
964972
_, err = os.Stat(tempFile)
965973
require.NoError(t, err)
974+
975+
// Close the client to trigger disconnect event.
976+
scpClient.Close()
977+
assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "")
966978
}
967979

968980
func TestAgent_FileTransferBlocked(t *testing.T) {
@@ -987,7 +999,7 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
987999
defer cancel()
9881000

9891001
//nolint:dogsled
990-
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
1002+
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
9911003
o.BlockFileTransfer = true
9921004
})
9931005
sshClient, err := conn.SSHClient(ctx)
@@ -996,6 +1008,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
9961008
_, err = sftp.NewClient(sshClient)
9971009
require.Error(t, err)
9981010
assertFileTransferBlocked(t, err.Error())
1011+
1012+
assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
9991013
})
10001014

10011015
t.Run("SCP with go-scp package", func(t *testing.T) {
@@ -1005,7 +1019,7 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
10051019
defer cancel()
10061020

10071021
//nolint:dogsled
1008-
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
1022+
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
10091023
o.BlockFileTransfer = true
10101024
})
10111025
sshClient, err := conn.SSHClient(ctx)
@@ -1018,6 +1032,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
10181032
err = scpClient.CopyFile(context.Background(), strings.NewReader("hello world"), tempFile, "0755")
10191033
require.Error(t, err)
10201034
assertFileTransferBlocked(t, err.Error())
1035+
1036+
assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
10211037
})
10221038

10231039
t.Run("Forbidden commands", func(t *testing.T) {
@@ -1031,7 +1047,7 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
10311047
defer cancel()
10321048

10331049
//nolint:dogsled
1034-
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
1050+
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
10351051
o.BlockFileTransfer = true
10361052
})
10371053
sshClient, err := conn.SSHClient(ctx)
@@ -1053,6 +1069,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
10531069
msg, err := io.ReadAll(stdout)
10541070
require.NoError(t, err)
10551071
assertFileTransferBlocked(t, string(msg))
1072+
1073+
assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
10561074
})
10571075
}
10581076
})
@@ -1661,8 +1679,16 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
16611679
defer cancel()
16621680

16631681
//nolint:dogsled
1664-
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
1682+
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
16651683
id := uuid.New()
1684+
1685+
// Test that the connection is reported. This must be tested in the
1686+
// first connection because assertConnectionReport expects exactly two reports in total.
1687+
netConn0, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc")
1688+
require.NoError(t, err)
1689+
_ = netConn0.Close()
1690+
assertConnectionReport(t, agentClient, proto.Connection_RECONNECTING_PTY, 0, "")
1691+
16661692
// --norc disables executing .bashrc, which is often used to customize the bash prompt
16671693
netConn1, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc")
16681694
require.NoError(t, err)
@@ -2691,3 +2717,35 @@ func requireEcho(t *testing.T, conn net.Conn) {
26912717
require.NoError(t, err)
26922718
require.Equal(t, "test", string(b))
26932719
}
2720+
2721+
func assertConnectionReport(t testing.TB, agentClient *agenttest.Client, connectionType proto.Connection_Type, status int, reason string) {
2722+
t.Helper()
2723+
2724+
var reports []*proto.ReportConnectionRequest
2725+
if !assert.Eventually(t, func() bool {
2726+
reports = agentClient.GetConnectionReports()
2727+
return len(reports) >= 2
2728+
}, testutil.WaitMedium, testutil.IntervalFast, "waiting for 2 connection reports or more; got %d", len(reports)) {
2729+
return
2730+
}
2731+
2732+
assert.Len(t, reports, 2, "want 2 connection reports")
2733+
2734+
assert.Equal(t, proto.Connection_CONNECT, reports[0].GetConnection().GetAction(), "first report should be connect")
2735+
assert.Equal(t, proto.Connection_DISCONNECT, reports[1].GetConnection().GetAction(), "second report should be disconnect")
2736+
assert.Equal(t, connectionType, reports[0].GetConnection().GetType(), "connect type should be %s", connectionType)
2737+
assert.Equal(t, connectionType, reports[1].GetConnection().GetType(), "disconnect type should be %s", connectionType)
2738+
t1 := reports[0].GetConnection().GetTimestamp().AsTime()
2739+
t2 := reports[1].GetConnection().GetTimestamp().AsTime()
2740+
assert.True(t, t1.Before(t2) || t1.Equal(t2), "connect timestamp should be before or equal to disconnect timestamp")
2741+
assert.NotEmpty(t, reports[0].GetConnection().GetIp(), "connect ip should not be empty")
2742+
assert.NotEmpty(t, reports[1].GetConnection().GetIp(), "disconnect ip should not be empty")
2743+
assert.Equal(t, 0, int(reports[0].GetConnection().GetStatusCode()), "connect status code should be 0")
2744+
assert.Equal(t, status, int(reports[1].GetConnection().GetStatusCode()), "disconnect status code should be %d", status)
2745+
assert.Equal(t, "", reports[0].GetConnection().GetReason(), "connect reason should be empty")
2746+
if reason != "" {
2747+
assert.Contains(t, reports[1].GetConnection().GetReason(), reason, "disconnect reason should contain %s", reason)
2748+
} else {
2749+
t.Logf("connection report disconnect reason: %s", reports[1].GetConnection().GetReason())
2750+
}
2751+
}

0 commit comments

Comments
 (0)