@@ -19,6 +19,7 @@ import (
19
19
"time"
20
20
21
21
"golang.org/x/xerrors"
22
+ "tailscale.com/tailcfg"
22
23
23
24
scp "github.com/bramvdbogaerde/go-scp"
24
25
"github.com/google/uuid"
@@ -51,6 +52,33 @@ func TestMain(m *testing.M) {
51
52
52
53
func TestAgent (t * testing.T ) {
53
54
t .Parallel ()
55
+ t .Run ("Stats" , func (t * testing.T ) {
56
+ for _ , tailscale := range []bool {true , false } {
57
+ t .Run (fmt .Sprintf ("tailscale=%v" , tailscale ), func (t * testing.T ) {
58
+ t .Parallel ()
59
+
60
+ var derpMap * tailcfg.DERPMap
61
+ if tailscale {
62
+ derpMap = tailnettest .RunDERPAndSTUN (t )
63
+ }
64
+ conn , stats := setupAgent (t , agent.Metadata {
65
+ DERPMap : derpMap ,
66
+ }, 0 )
67
+ assert .Empty (t , <- stats )
68
+
69
+ sshClient , err := conn .SSHClient ()
70
+ require .NoError (t , err )
71
+ session , err := sshClient .NewSession ()
72
+ require .NoError (t , err )
73
+ defer session .Close ()
74
+
75
+ assert .EqualValues (t , 1 , (<- stats ).NumConns )
76
+ assert .Greater (t , (<- stats ).RxBytes , int64 (0 ))
77
+ assert .Greater (t , (<- stats ).TxBytes , int64 (0 ))
78
+ })
79
+ }
80
+ })
81
+
54
82
t .Run ("SessionExec" , func (t * testing.T ) {
55
83
t .Parallel ()
56
84
session := setupSSHSession (t , agent.Metadata {})
@@ -169,7 +197,8 @@ func TestAgent(t *testing.T) {
169
197
170
198
t .Run ("SFTP" , func (t * testing.T ) {
171
199
t .Parallel ()
172
- sshClient , err := setupAgent (t , agent.Metadata {}, 0 ).SSHClient ()
200
+ conn , _ := setupAgent (t , agent.Metadata {}, 0 )
201
+ sshClient , err := conn .SSHClient ()
173
202
require .NoError (t , err )
174
203
client , err := sftp .NewClient (sshClient )
175
204
require .NoError (t , err )
@@ -184,7 +213,9 @@ func TestAgent(t *testing.T) {
184
213
185
214
t .Run ("SCP" , func (t * testing.T ) {
186
215
t .Parallel ()
187
- sshClient , err := setupAgent (t , agent.Metadata {}, 0 ).SSHClient ()
216
+
217
+ conn , _ := setupAgent (t , agent.Metadata {}, 0 )
218
+ sshClient , err := conn .SSHClient ()
188
219
require .NoError (t , err )
189
220
scpClient , err := scp .NewClientBySSH (sshClient )
190
221
require .NoError (t , err )
@@ -318,7 +349,7 @@ func TestAgent(t *testing.T) {
318
349
t .Skip ("ConPTY appears to be inconsistent on Windows." )
319
350
}
320
351
321
- conn := setupAgent (t , agent.Metadata {
352
+ conn , _ := setupAgent (t , agent.Metadata {
322
353
DERPMap : tailnettest .RunDERPAndSTUN (t ),
323
354
}, 0 )
324
355
id := uuid .NewString ()
@@ -431,7 +462,7 @@ func TestAgent(t *testing.T) {
431
462
}()
432
463
433
464
// Dial the listener over WebRTC twice and test out of order
434
- conn := setupAgent (t , agent.Metadata {}, 0 )
465
+ conn , _ := setupAgent (t , agent.Metadata {}, 0 )
435
466
conn1 , err := conn .DialContext (context .Background (), l .Addr ().Network (), l .Addr ().String ())
436
467
require .NoError (t , err )
437
468
defer conn1 .Close ()
@@ -462,7 +493,7 @@ func TestAgent(t *testing.T) {
462
493
})
463
494
464
495
// Try to dial the non-existent Unix socket over WebRTC
465
- conn := setupAgent (t , agent.Metadata {}, 0 )
496
+ conn , _ := setupAgent (t , agent.Metadata {}, 0 )
466
497
netConn , err := conn .DialContext (context .Background (), "unix" , filepath .Join (tmpDir , "test.sock" ))
467
498
require .Error (t , err )
468
499
require .ErrorContains (t , err , "remote dial error" )
@@ -473,7 +504,7 @@ func TestAgent(t *testing.T) {
473
504
t .Run ("Tailnet" , func (t * testing.T ) {
474
505
t .Parallel ()
475
506
derpMap := tailnettest .RunDERPAndSTUN (t )
476
- conn := setupAgent (t , agent.Metadata {
507
+ conn , _ := setupAgent (t , agent.Metadata {
477
508
DERPMap : derpMap ,
478
509
}, 0 )
479
510
defer conn .Close ()
@@ -485,7 +516,7 @@ func TestAgent(t *testing.T) {
485
516
}
486
517
487
518
func setupSSHCommand (t * testing.T , beforeArgs []string , afterArgs []string ) * exec.Cmd {
488
- agentConn := setupAgent (t , agent.Metadata {}, 0 )
519
+ agentConn , _ := setupAgent (t , agent.Metadata {}, 0 )
489
520
listener , err := net .Listen ("tcp" , "127.0.0.1:0" )
490
521
require .NoError (t , err )
491
522
waitGroup := sync.WaitGroup {}
@@ -523,7 +554,8 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
523
554
}
524
555
525
556
func setupSSHSession (t * testing.T , options agent.Metadata ) * ssh.Session {
526
- sshClient , err := setupAgent (t , options , 0 ).SSHClient ()
557
+ conn , _ := setupAgent (t , options , 0 )
558
+ sshClient , err := conn .SSHClient ()
527
559
require .NoError (t , err )
528
560
t .Cleanup (func () {
529
561
_ = sshClient .Close ()
@@ -533,11 +565,21 @@ func setupSSHSession(t *testing.T, options agent.Metadata) *ssh.Session {
533
565
return session
534
566
}
535
567
536
- func setupAgent (t * testing.T , metadata agent.Metadata , ptyTimeout time.Duration ) agent.Conn {
568
+ type closeFunc func () error
569
+
570
+ func (c closeFunc ) Close () error {
571
+ return c ()
572
+ }
573
+
574
+ func setupAgent (t * testing.T , metadata agent.Metadata , ptyTimeout time.Duration ) (
575
+ agent.Conn ,
576
+ <- chan * agent.Stats ,
577
+ ) {
537
578
client , server := provisionersdk .TransportPipe ()
538
579
tailscale := metadata .DERPMap != nil
539
580
coordinator := tailnet .NewCoordinator ()
540
581
agentID := uuid .New ()
582
+ statsCh := make (chan * agent.Stats )
541
583
closer := agent .New (agent.Options {
542
584
FetchMetadata : func (ctx context.Context ) (agent.Metadata , error ) {
543
585
return metadata , nil
@@ -557,6 +599,38 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
557
599
},
558
600
Logger : slogtest .Make (t , nil ).Leveled (slog .LevelDebug ),
559
601
ReconnectingPTYTimeout : ptyTimeout ,
602
+ StatsReporter : func (ctx context.Context , log slog.Logger , statsFn func () * agent.Stats ) (io.Closer , error ) {
603
+ doneCh := make (chan struct {})
604
+ ctx , cancel := context .WithCancel (ctx )
605
+
606
+ go func () {
607
+ defer close (doneCh )
608
+
609
+ t := time .NewTicker (time .Millisecond * 100 )
610
+ defer t .Stop ()
611
+ for {
612
+ select {
613
+ case <- ctx .Done ():
614
+ return
615
+ case <- t .C :
616
+ }
617
+ select {
618
+ case statsCh <- statsFn ():
619
+ case <- ctx .Done ():
620
+ return
621
+ default :
622
+ // We don't want to send old stats.
623
+ continue
624
+ }
625
+ }
626
+ }()
627
+ return closeFunc (func () error {
628
+ cancel ()
629
+ <- doneCh
630
+ close (statsCh )
631
+ return nil
632
+ }), nil
633
+ },
560
634
})
561
635
t .Cleanup (func () {
562
636
_ = client .Close ()
@@ -586,7 +660,7 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
586
660
conn .SetNodeCallback (sendNode )
587
661
return & agent.TailnetConn {
588
662
Conn : conn ,
589
- }
663
+ }, statsCh
590
664
}
591
665
conn , err := peerbroker .Dial (stream , []webrtc.ICEServer {}, & peer.ConnOptions {
592
666
Logger : slogtest .Make (t , nil ),
@@ -599,7 +673,7 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
599
673
return & agent.WebRTCConn {
600
674
Negotiator : api ,
601
675
Conn : conn ,
602
- }
676
+ }, statsCh
603
677
}
604
678
605
679
var dialTestPayload = []byte ("dean-was-here123" )
0 commit comments