Skip to content

Commit 218f4b0

Browse files
committed
fix test
1 parent d702dca commit 218f4b0

File tree

4 files changed

+46
-5
lines changed

4 files changed

+46
-5
lines changed

agent/agent.go

+6
Original file line numberDiff line numberDiff line change
@@ -1503,6 +1503,12 @@ func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connect
15031503

15041504
stats.SessionCountReconnectingPty = a.connCountReconnectingPTY.Load()
15051505

1506+
// if we've seen sessions but currently have no connections we
1507+
// just count the sum of the sessions as connections
1508+
if stats.ConnectionCount == 0 {
1509+
stats.ConnectionCount = stats.SessionCountSsh + stats.SessionCountVscode + stats.SessionCountJetbrains + stats.SessionCountReconnectingPty
1510+
}
1511+
15061512
// Compute the median connection latency!
15071513
a.logger.Debug(ctx, "starting peer latency measurement for stats")
15081514
var wg sync.WaitGroup

agent/agentssh/agentssh.go

+30-4
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ type Server struct {
105105
connCountVSCode atomic.Int64
106106
connCountJetBrains atomic.Int64
107107
connCountSSHSession atomic.Int64
108+
seenVSCode atomic.Bool
109+
seenJetBrains atomic.Bool
110+
seenSSHSession atomic.Bool
108111

109112
metrics *sshServerMetrics
110113
}
@@ -167,7 +170,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
167170
ChannelHandlers: map[string]ssh.ChannelHandler{
168171
"direct-tcpip": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
169172
// Wrapper is designed to find and track JetBrains Gateway connections.
170-
wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, newChan, &s.connCountJetBrains)
173+
wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, newChan, &s.connCountJetBrains, &s.seenJetBrains)
171174
ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx)
172175
},
173176
"direct-streamlocal@openssh.com": directStreamLocalHandler,
@@ -245,10 +248,31 @@ type ConnStats struct {
245248
}
246249

247250
func (s *Server) ConnStats() ConnStats {
251+
// if we have 0 active connections, but we have seen a connection
252+
// since the last time we collected, count it as 1 so that workspace
253+
// activity is properly counted.
254+
sshCount := s.connCountSSHSession.Load()
255+
if sshCount == 0 && s.seenSSHSession.Load() {
256+
sshCount = 1
257+
}
258+
vscode := s.connCountVSCode.Load()
259+
if vscode == 0 && s.seenVSCode.Load() {
260+
vscode = 1
261+
}
262+
jetbrains := s.connCountJetBrains.Load()
263+
if jetbrains == 0 && s.seenJetBrains.Load() {
264+
jetbrains = 1
265+
}
266+
267+
// Reset the seen trackers for the next collection.
268+
s.seenSSHSession.Store(false)
269+
s.seenVSCode.Store(false)
270+
s.seenJetBrains.Store(false)
271+
248272
return ConnStats{
249-
Sessions: s.connCountSSHSession.Load(),
250-
VSCode: s.connCountVSCode.Load(),
251-
JetBrains: s.connCountJetBrains.Load(),
273+
Sessions: sshCount,
274+
VSCode: vscode,
275+
JetBrains: jetbrains,
252276
}
253277
}
254278

@@ -392,12 +416,14 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv
392416
switch magicType {
393417
case MagicSessionTypeVSCode:
394418
s.connCountVSCode.Add(1)
419+
s.seenVSCode.Store(true)
395420
defer s.connCountVSCode.Add(-1)
396421
case MagicSessionTypeJetBrains:
397422
// Do nothing here because JetBrains launches hundreds of ssh sessions.
398423
// We instead track JetBrains in the single persistent tcp forwarding channel.
399424
case "":
400425
s.connCountSSHSession.Add(1)
426+
s.seenSSHSession.Store(true)
401427
defer s.connCountSSHSession.Add(-1)
402428
default:
403429
logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType))

agent/agentssh/jetbrainstrack.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,11 @@ type localForwardChannelData struct {
2727
type JetbrainsChannelWatcher struct {
2828
gossh.NewChannel
2929
jetbrainsCounter *atomic.Int64
30+
jetbrainsSeen *atomic.Bool
3031
logger slog.Logger
3132
}
3233

33-
func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel {
34+
func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel gossh.NewChannel, counter *atomic.Int64, seen *atomic.Bool) gossh.NewChannel {
3435
d := localForwardChannelData{}
3536
if err := gossh.Unmarshal(newChannel.ExtraData(), &d); err != nil {
3637
// If the data fails to unmarshal, do nothing.
@@ -60,6 +61,7 @@ func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel
6061
return &JetbrainsChannelWatcher{
6162
NewChannel: newChannel,
6263
jetbrainsCounter: counter,
64+
jetbrainsSeen: seen,
6365
logger: logger.With(slog.F("destination_port", d.DestPort)),
6466
}
6567
}
@@ -70,6 +72,7 @@ func (w *JetbrainsChannelWatcher) Accept() (gossh.Channel, <-chan *gossh.Request
7072
return c, r, err
7173
}
7274
w.jetbrainsCounter.Add(1)
75+
w.jetbrainsSeen.Store(true)
7376
// nolint: gocritic // JetBrains is a proper noun and should be capitalized
7477
w.logger.Debug(context.Background(), "JetBrains watcher accepted channel")
7578

coderd/activitybump_test.go

+6
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,12 @@ func TestWorkspaceActivityBump(t *testing.T) {
212212
time.Sleep(time.Second * 3)
213213
sshConn, err := conn.SSHClient(ctx)
214214
require.NoError(t, err)
215+
sess, err := sshConn.NewSession()
216+
require.NoError(t, err)
217+
err = sess.Shell()
218+
require.NoError(t, err)
219+
err = sess.Close()
220+
require.NoError(t, err)
215221
_ = sshConn.Close()
216222

217223
assertBumped(true)

0 commit comments

Comments
 (0)