@@ -65,6 +65,7 @@ type Options struct {
65
65
WebRTCDialer WebRTCDialer
66
66
FetchMetadata FetchMetadata
67
67
68
+ StatsReporter StatsReporter
68
69
ReconnectingPTYTimeout time.Duration
69
70
EnvironmentVariables map [string ]string
70
71
Logger slog.Logger
@@ -100,6 +101,10 @@ func New(options Options) io.Closer {
100
101
envVars : options .EnvironmentVariables ,
101
102
coordinatorDialer : options .CoordinatorDialer ,
102
103
fetchMetadata : options .FetchMetadata ,
104
+ stats : & Stats {
105
+ ProtocolStats : make (map [string ]* ProtocolStats ),
106
+ },
107
+ statsReporter : options .StatsReporter ,
103
108
}
104
109
server .init (ctx )
105
110
return server
@@ -125,6 +130,8 @@ type agent struct {
125
130
126
131
network * tailnet.Conn
127
132
coordinatorDialer CoordinatorDialer
133
+ stats * Stats
134
+ statsReporter StatsReporter
128
135
}
129
136
130
137
func (a * agent ) run (ctx context.Context ) {
@@ -194,6 +201,12 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
194
201
a .logger .Critical (ctx , "create tailnet" , slog .Error (err ))
195
202
return
196
203
}
204
+ a .network .SetForwardTCPCallback (func (conn net.Conn , listenerExists bool ) net.Conn {
205
+ if listenerExists {
206
+ return conn
207
+ }
208
+ return & ConnStats {ProtocolStats : & ProtocolStats {}, Conn : conn }
209
+ })
197
210
go a .runCoordinator (ctx )
198
211
199
212
sshListener , err := a .network .Listen ("tcp" , ":" + strconv .Itoa (tailnetSSHPort ))
@@ -207,7 +220,7 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
207
220
if err != nil {
208
221
return
209
222
}
210
- go a .sshServer .HandleConn (conn )
223
+ a .sshServer .HandleConn (a . stats . wrapConn ( conn , ProtocolSSH ) )
211
224
}
212
225
}()
213
226
reconnectingPTYListener , err := a .network .Listen ("tcp" , ":" + strconv .Itoa (tailnetReconnectingPTYPort ))
@@ -364,17 +377,17 @@ func (a *agent) runStartupScript(ctx context.Context, script string) error {
364
377
return nil
365
378
}
366
379
367
- func (a * agent ) handlePeerConn (ctx context.Context , conn * peer.Conn ) {
380
+ func (a * agent ) handlePeerConn (ctx context.Context , peerConn * peer.Conn ) {
368
381
go func () {
369
382
select {
370
383
case <- a .closed :
371
- case <- conn .Closed ():
384
+ case <- peerConn .Closed ():
372
385
}
373
- _ = conn .Close ()
386
+ _ = peerConn .Close ()
374
387
a .connCloseWait .Done ()
375
388
}()
376
389
for {
377
- channel , err := conn .Accept (ctx )
390
+ channel , err := peerConn .Accept (ctx )
378
391
if err != nil {
379
392
if errors .Is (err , peer .ErrClosed ) || a .isClosed () {
380
393
return
@@ -383,44 +396,46 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
383
396
return
384
397
}
385
398
399
+ conn := channel .NetConn ()
400
+
386
401
switch channel .Protocol () {
387
402
case ProtocolSSH :
388
- go a .sshServer .HandleConn (channel .NetConn ( ))
403
+ go a .sshServer .HandleConn (a . stats . wrapConn ( conn , channel .Protocol () ))
389
404
case ProtocolReconnectingPTY :
390
405
rawID := channel .Label ()
391
406
// The ID format is referenced in conn.go.
392
407
// <uuid>:<height>:<width>
393
408
idParts := strings .SplitN (rawID , ":" , 4 )
394
409
if len (idParts ) != 4 {
395
410
a .logger .Warn (ctx , "client sent invalid id format" , slog .F ("raw-id" , rawID ))
396
- continue
411
+ return
397
412
}
398
413
id := idParts [0 ]
399
414
// Enforce a consistent format for IDs.
400
415
_ , err := uuid .Parse (id )
401
416
if err != nil {
402
417
a .logger .Warn (ctx , "client sent reconnection token that isn't a uuid" , slog .F ("id" , id ), slog .Error (err ))
403
- continue
418
+ return
404
419
}
405
420
// Parse the initial terminal dimensions.
406
421
height , err := strconv .Atoi (idParts [1 ])
407
422
if err != nil {
408
423
a .logger .Warn (ctx , "client sent invalid height" , slog .F ("id" , id ), slog .F ("height" , idParts [1 ]))
409
- continue
424
+ return
410
425
}
411
426
width , err := strconv .Atoi (idParts [2 ])
412
427
if err != nil {
413
428
a .logger .Warn (ctx , "client sent invalid width" , slog .F ("id" , id ), slog .F ("width" , idParts [2 ]))
414
- continue
429
+ return
415
430
}
416
431
go a .handleReconnectingPTY (ctx , reconnectingPTYInit {
417
432
ID : id ,
418
433
Height : uint16 (height ),
419
434
Width : uint16 (width ),
420
435
Command : idParts [3 ],
421
- }, channel .NetConn ( ))
436
+ }, a . stats . wrapConn ( conn , channel .Protocol () ))
422
437
case ProtocolDial :
423
- go a .handleDial (ctx , channel .Label (), channel .NetConn ( ))
438
+ go a .handleDial (ctx , channel .Label (), a . stats . wrapConn ( conn , channel .Protocol () ))
424
439
default :
425
440
a .logger .Warn (ctx , "unhandled protocol from channel" ,
426
441
slog .F ("protocol" , channel .Protocol ()),
@@ -514,6 +529,21 @@ func (a *agent) init(ctx context.Context) {
514
529
}
515
530
516
531
go a .run (ctx )
532
+ if a .statsReporter != nil {
533
+ cl , err := a .statsReporter (ctx , a .logger , func () * Stats {
534
+ return a .stats .Copy ()
535
+ })
536
+ if err != nil {
537
+ a .logger .Error (ctx , "report stats" , slog .Error (err ))
538
+ return
539
+ }
540
+ a .connCloseWait .Add (1 )
541
+ go func () {
542
+ defer a .connCloseWait .Done ()
543
+ <- a .closed
544
+ cl .Close ()
545
+ }()
546
+ }
517
547
}
518
548
519
549
// createCommand processes raw command input with OpenSSH-like behavior.
0 commit comments