Skip to content

Commit bca9928

Browse files
committed
Redo tests
1 parent 334b996 commit bca9928

File tree

6 files changed

+115
-129
lines changed

6 files changed

+115
-129
lines changed

agent/agent.go

+8-10
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,8 @@ func New(options Options) io.Closer {
101101
envVars: options.EnvironmentVariables,
102102
coordinatorDialer: options.CoordinatorDialer,
103103
fetchMetadata: options.FetchMetadata,
104-
stats: &Stats{
105-
ProtocolStats: make(map[string]*ProtocolStats),
106-
},
107-
statsReporter: options.StatsReporter,
104+
stats: &Stats{},
105+
statsReporter: options.StatsReporter,
108106
}
109107
server.init(ctx)
110108
return server
@@ -205,7 +203,7 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
205203
if listenerExists {
206204
return conn
207205
}
208-
return &StatsConn{ProtocolStats: &ProtocolStats{}, Conn: conn}
206+
return a.stats.wrapConn(conn)
209207
})
210208
go a.runCoordinator(ctx)
211209

@@ -220,7 +218,7 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
220218
if err != nil {
221219
return
222220
}
223-
a.sshServer.HandleConn(a.stats.wrapConn(conn, ProtocolSSH))
221+
a.sshServer.HandleConn(a.stats.wrapConn(conn))
224222
}
225223
}()
226224
reconnectingPTYListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(tailnetReconnectingPTYPort))
@@ -252,7 +250,7 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
252250
if err != nil {
253251
continue
254252
}
255-
go a.handleReconnectingPTY(ctx, msg, conn)
253+
go a.handleReconnectingPTY(ctx, msg, a.stats.wrapConn(conn))
256254
}
257255
}()
258256
}
@@ -400,7 +398,7 @@ func (a *agent) handlePeerConn(ctx context.Context, peerConn *peer.Conn) {
400398

401399
switch channel.Protocol() {
402400
case ProtocolSSH:
403-
go a.sshServer.HandleConn(a.stats.wrapConn(conn, channel.Protocol()))
401+
go a.sshServer.HandleConn(a.stats.wrapConn(conn))
404402
case ProtocolReconnectingPTY:
405403
rawID := channel.Label()
406404
// The ID format is referenced in conn.go.
@@ -433,9 +431,9 @@ func (a *agent) handlePeerConn(ctx context.Context, peerConn *peer.Conn) {
433431
Height: uint16(height),
434432
Width: uint16(width),
435433
Command: idParts[3],
436-
}, a.stats.wrapConn(conn, channel.Protocol()))
434+
}, a.stats.wrapConn(conn))
437435
case ProtocolDial:
438-
go a.handleDial(ctx, channel.Label(), a.stats.wrapConn(conn, channel.Protocol()))
436+
go a.handleDial(ctx, channel.Label(), a.stats.wrapConn(conn))
439437
default:
440438
a.logger.Warn(ctx, "unhandled protocol from channel",
441439
slog.F("protocol", channel.Protocol()),

agent/agent_test.go

+85-11
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"time"
2020

2121
"golang.org/x/xerrors"
22+
"tailscale.com/tailcfg"
2223

2324
scp "github.com/bramvdbogaerde/go-scp"
2425
"github.com/google/uuid"
@@ -51,6 +52,33 @@ func TestMain(m *testing.M) {
5152

5253
func TestAgent(t *testing.T) {
5354
t.Parallel()
55+
t.Run("Stats", func(t *testing.T) {
56+
for _, tailscale := range []bool{true, false} {
57+
t.Run(fmt.Sprintf("tailscale=%v", tailscale), func(t *testing.T) {
58+
t.Parallel()
59+
60+
var derpMap *tailcfg.DERPMap
61+
if tailscale {
62+
derpMap = tailnettest.RunDERPAndSTUN(t)
63+
}
64+
conn, stats := setupAgent(t, agent.Metadata{
65+
DERPMap: derpMap,
66+
}, 0)
67+
assert.Empty(t, <-stats)
68+
69+
sshClient, err := conn.SSHClient()
70+
require.NoError(t, err)
71+
session, err := sshClient.NewSession()
72+
require.NoError(t, err)
73+
defer session.Close()
74+
75+
assert.EqualValues(t, 1, (<-stats).NumConns)
76+
assert.Greater(t, (<-stats).RxBytes, int64(0))
77+
assert.Greater(t, (<-stats).TxBytes, int64(0))
78+
})
79+
}
80+
})
81+
5482
t.Run("SessionExec", func(t *testing.T) {
5583
t.Parallel()
5684
session := setupSSHSession(t, agent.Metadata{})
@@ -169,7 +197,8 @@ func TestAgent(t *testing.T) {
169197

170198
t.Run("SFTP", func(t *testing.T) {
171199
t.Parallel()
172-
sshClient, err := setupAgent(t, agent.Metadata{}, 0).SSHClient()
200+
conn, _ := setupAgent(t, agent.Metadata{}, 0)
201+
sshClient, err := conn.SSHClient()
173202
require.NoError(t, err)
174203
client, err := sftp.NewClient(sshClient)
175204
require.NoError(t, err)
@@ -184,7 +213,9 @@ func TestAgent(t *testing.T) {
184213

185214
t.Run("SCP", func(t *testing.T) {
186215
t.Parallel()
187-
sshClient, err := setupAgent(t, agent.Metadata{}, 0).SSHClient()
216+
217+
conn, _ := setupAgent(t, agent.Metadata{}, 0)
218+
sshClient, err := conn.SSHClient()
188219
require.NoError(t, err)
189220
scpClient, err := scp.NewClientBySSH(sshClient)
190221
require.NoError(t, err)
@@ -318,7 +349,7 @@ func TestAgent(t *testing.T) {
318349
t.Skip("ConPTY appears to be inconsistent on Windows.")
319350
}
320351

321-
conn := setupAgent(t, agent.Metadata{
352+
conn, _ := setupAgent(t, agent.Metadata{
322353
DERPMap: tailnettest.RunDERPAndSTUN(t),
323354
}, 0)
324355
id := uuid.NewString()
@@ -431,7 +462,7 @@ func TestAgent(t *testing.T) {
431462
}()
432463

433464
// Dial the listener over WebRTC twice and test out of order
434-
conn := setupAgent(t, agent.Metadata{}, 0)
465+
conn, _ := setupAgent(t, agent.Metadata{}, 0)
435466
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
436467
require.NoError(t, err)
437468
defer conn1.Close()
@@ -462,7 +493,7 @@ func TestAgent(t *testing.T) {
462493
})
463494

464495
// Try to dial the non-existent Unix socket over WebRTC
465-
conn := setupAgent(t, agent.Metadata{}, 0)
496+
conn, _ := setupAgent(t, agent.Metadata{}, 0)
466497
netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock"))
467498
require.Error(t, err)
468499
require.ErrorContains(t, err, "remote dial error")
@@ -473,7 +504,7 @@ func TestAgent(t *testing.T) {
473504
t.Run("Tailnet", func(t *testing.T) {
474505
t.Parallel()
475506
derpMap := tailnettest.RunDERPAndSTUN(t)
476-
conn := setupAgent(t, agent.Metadata{
507+
conn, _ := setupAgent(t, agent.Metadata{
477508
DERPMap: derpMap,
478509
}, 0)
479510
defer conn.Close()
@@ -485,7 +516,7 @@ func TestAgent(t *testing.T) {
485516
}
486517

487518
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
488-
agentConn := setupAgent(t, agent.Metadata{}, 0)
519+
agentConn, _ := setupAgent(t, agent.Metadata{}, 0)
489520
listener, err := net.Listen("tcp", "127.0.0.1:0")
490521
require.NoError(t, err)
491522
waitGroup := sync.WaitGroup{}
@@ -523,7 +554,8 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
523554
}
524555

525556
func setupSSHSession(t *testing.T, options agent.Metadata) *ssh.Session {
526-
sshClient, err := setupAgent(t, options, 0).SSHClient()
557+
conn, _ := setupAgent(t, options, 0)
558+
sshClient, err := conn.SSHClient()
527559
require.NoError(t, err)
528560
t.Cleanup(func() {
529561
_ = sshClient.Close()
@@ -533,11 +565,21 @@ func setupSSHSession(t *testing.T, options agent.Metadata) *ssh.Session {
533565
return session
534566
}
535567

536-
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) agent.Conn {
568+
type closeFunc func() error
569+
570+
func (c closeFunc) Close() error {
571+
return c()
572+
}
573+
574+
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) (
575+
agent.Conn,
576+
<-chan *agent.Stats,
577+
) {
537578
client, server := provisionersdk.TransportPipe()
538579
tailscale := metadata.DERPMap != nil
539580
coordinator := tailnet.NewCoordinator()
540581
agentID := uuid.New()
582+
statsCh := make(chan *agent.Stats)
541583
closer := agent.New(agent.Options{
542584
FetchMetadata: func(ctx context.Context) (agent.Metadata, error) {
543585
return metadata, nil
@@ -557,6 +599,38 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
557599
},
558600
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
559601
ReconnectingPTYTimeout: ptyTimeout,
602+
StatsReporter: func(ctx context.Context, log slog.Logger, statsFn func() *agent.Stats) (io.Closer, error) {
603+
doneCh := make(chan struct{})
604+
ctx, cancel := context.WithCancel(ctx)
605+
606+
go func() {
607+
defer close(doneCh)
608+
609+
t := time.NewTicker(time.Millisecond * 100)
610+
defer t.Stop()
611+
for {
612+
select {
613+
case <-ctx.Done():
614+
return
615+
case <-t.C:
616+
}
617+
select {
618+
case statsCh <- statsFn():
619+
case <-ctx.Done():
620+
return
621+
default:
622+
// We don't want to send old stats.
623+
continue
624+
}
625+
}
626+
}()
627+
return closeFunc(func() error {
628+
cancel()
629+
<-doneCh
630+
close(statsCh)
631+
return nil
632+
}), nil
633+
},
560634
})
561635
t.Cleanup(func() {
562636
_ = client.Close()
@@ -586,7 +660,7 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
586660
conn.SetNodeCallback(sendNode)
587661
return &agent.TailnetConn{
588662
Conn: conn,
589-
}
663+
}, statsCh
590664
}
591665
conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{
592666
Logger: slogtest.Make(t, nil),
@@ -599,7 +673,7 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
599673
return &agent.WebRTCConn{
600674
Negotiator: api,
601675
Conn: conn,
602-
}
676+
}, statsCh
603677
}
604678

605679
var dialTestPayload = []byte("dean-was-here123")

agent/stats.go

+14-36
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@ import (
44
"context"
55
"io"
66
"net"
7-
"sync"
87
"sync/atomic"
98

109
"cdr.dev/slog"
1110
)
1211

1312
// StatsConn wraps a net.Conn with statistics.
1413
type StatsConn struct {
15-
*ProtocolStats
14+
*Stats
1615
net.Conn `json:"-"`
1716
}
1817

@@ -30,53 +29,32 @@ func (c *StatsConn) Write(b []byte) (n int, err error) {
3029
return n, err
3130
}
3231

33-
type ProtocolStats struct {
34-
NumConns int64 `json:"num_comms"`
35-
36-
// RxBytes must be read with atomic.
37-
RxBytes int64 `json:"rx_bytes"`
38-
39-
// TxBytes must be read with atomic.
40-
TxBytes int64 `json:"tx_bytes"`
41-
}
42-
4332
var _ net.Conn = new(StatsConn)
4433

4534
// Stats records the Agent's network connection statistics for use in
4635
// user-facing metrics and debugging.
4736
type Stats struct {
48-
sync.RWMutex `json:"-"`
49-
ProtocolStats map[string]*ProtocolStats `json:"conn_stats,omitempty"`
37+
NumConns int64 `json:"num_comms"`
38+
// RxBytes must be read with atomic.
39+
RxBytes int64 `json:"rx_bytes"`
40+
// TxBytes must be read with atomic.
41+
TxBytes int64 `json:"tx_bytes"`
5042
}
5143

5244
func (s *Stats) Copy() *Stats {
53-
s.RLock()
54-
ss := Stats{ProtocolStats: make(map[string]*ProtocolStats, len(s.ProtocolStats))}
55-
for k, cs := range s.ProtocolStats {
56-
ss.ProtocolStats[k] = &ProtocolStats{
57-
NumConns: atomic.LoadInt64(&cs.NumConns),
58-
RxBytes: atomic.LoadInt64(&cs.RxBytes),
59-
TxBytes: atomic.LoadInt64(&cs.TxBytes),
60-
}
45+
return &Stats{
46+
NumConns: atomic.LoadInt64(&s.NumConns),
47+
RxBytes: atomic.LoadInt64(&s.RxBytes),
48+
TxBytes: atomic.LoadInt64(&s.TxBytes),
6149
}
62-
s.RUnlock()
63-
return &ss
6450
}
6551

6652
// wrapConn returns a new connection that records statistics.
67-
func (s *Stats) wrapConn(conn net.Conn, protocol string) net.Conn {
68-
s.Lock()
69-
ps, ok := s.ProtocolStats[protocol]
70-
if !ok {
71-
ps = &ProtocolStats{}
72-
s.ProtocolStats[protocol] = ps
73-
}
74-
s.Unlock()
75-
76-
atomic.AddInt64(&ps.NumConns, 1)
53+
func (s *Stats) wrapConn(conn net.Conn) net.Conn {
54+
atomic.AddInt64(&s.NumConns, 1)
7755
cs := &StatsConn{
78-
ProtocolStats: ps,
79-
Conn: conn,
56+
Stats: s,
57+
Conn: conn,
8058
}
8159

8260
return cs

0 commit comments

Comments
 (0)