@@ -65,6 +65,7 @@ type Options struct {
65
65
WebRTCDialer WebRTCDialer
66
66
FetchMetadata FetchMetadata
67
67
68
+ StatsReporter StatsReporter
68
69
ReconnectingPTYTimeout time.Duration
69
70
EnvironmentVariables map [string ]string
70
71
Logger slog.Logger
@@ -100,6 +101,8 @@ func New(options Options) io.Closer {
100
101
envVars : options .EnvironmentVariables ,
101
102
coordinatorDialer : options .CoordinatorDialer ,
102
103
fetchMetadata : options .FetchMetadata ,
104
+ stats : & Stats {},
105
+ statsReporter : options .StatsReporter ,
103
106
}
104
107
server .init (ctx )
105
108
return server
@@ -125,6 +128,8 @@ type agent struct {
125
128
126
129
network * tailnet.Conn
127
130
coordinatorDialer CoordinatorDialer
131
+ stats * Stats
132
+ statsReporter StatsReporter
128
133
}
129
134
130
135
func (a * agent ) run (ctx context.Context ) {
@@ -194,6 +199,13 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
194
199
a .logger .Critical (ctx , "create tailnet" , slog .Error (err ))
195
200
return
196
201
}
202
+ a .network .SetForwardTCPCallback (func (conn net.Conn , listenerExists bool ) net.Conn {
203
+ if listenerExists {
204
+ // If a listener already exists, we would double-wrap the conn.
205
+ return conn
206
+ }
207
+ return a .stats .wrapConn (conn )
208
+ })
197
209
go a .runCoordinator (ctx )
198
210
199
211
sshListener , err := a .network .Listen ("tcp" , ":" + strconv .Itoa (tailnetSSHPort ))
@@ -207,7 +219,7 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
207
219
if err != nil {
208
220
return
209
221
}
210
- go a .sshServer .HandleConn (conn )
222
+ a .sshServer .HandleConn (a . stats . wrapConn ( conn ) )
211
223
}
212
224
}()
213
225
reconnectingPTYListener , err := a .network .Listen ("tcp" , ":" + strconv .Itoa (tailnetReconnectingPTYPort ))
@@ -219,8 +231,10 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
219
231
for {
220
232
conn , err := reconnectingPTYListener .Accept ()
221
233
if err != nil {
234
+ a .logger .Debug (ctx , "accept pty failed" , slog .Error (err ))
222
235
return
223
236
}
237
+ conn = a .stats .wrapConn (conn )
224
238
// This cannot use a JSON decoder, since that can
225
239
// buffer additional data that is required for the PTY.
226
240
rawLen := make ([]byte , 2 )
@@ -364,17 +378,17 @@ func (a *agent) runStartupScript(ctx context.Context, script string) error {
364
378
return nil
365
379
}
366
380
367
- func (a * agent ) handlePeerConn (ctx context.Context , conn * peer.Conn ) {
381
+ func (a * agent ) handlePeerConn (ctx context.Context , peerConn * peer.Conn ) {
368
382
go func () {
369
383
select {
370
384
case <- a .closed :
371
- case <- conn .Closed ():
385
+ case <- peerConn .Closed ():
372
386
}
373
- _ = conn .Close ()
387
+ _ = peerConn .Close ()
374
388
a .connCloseWait .Done ()
375
389
}()
376
390
for {
377
- channel , err := conn .Accept (ctx )
391
+ channel , err := peerConn .Accept (ctx )
378
392
if err != nil {
379
393
if errors .Is (err , peer .ErrClosed ) || a .isClosed () {
380
394
return
@@ -383,9 +397,11 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
383
397
return
384
398
}
385
399
400
+ conn := channel .NetConn ()
401
+
386
402
switch channel .Protocol () {
387
403
case ProtocolSSH :
388
- go a .sshServer .HandleConn (channel . NetConn ( ))
404
+ go a .sshServer .HandleConn (a . stats . wrapConn ( conn ))
389
405
case ProtocolReconnectingPTY :
390
406
rawID := channel .Label ()
391
407
// The ID format is referenced in conn.go.
@@ -418,9 +434,9 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
418
434
Height : uint16 (height ),
419
435
Width : uint16 (width ),
420
436
Command : idParts [3 ],
421
- }, channel . NetConn ( ))
437
+ }, a . stats . wrapConn ( conn ))
422
438
case ProtocolDial :
423
- go a .handleDial (ctx , channel .Label (), channel . NetConn ( ))
439
+ go a .handleDial (ctx , channel .Label (), a . stats . wrapConn ( conn ))
424
440
default :
425
441
a .logger .Warn (ctx , "unhandled protocol from channel" ,
426
442
slog .F ("protocol" , channel .Protocol ()),
@@ -514,6 +530,21 @@ func (a *agent) init(ctx context.Context) {
514
530
}
515
531
516
532
go a .run (ctx )
533
+ if a .statsReporter != nil {
534
+ cl , err := a .statsReporter (ctx , a .logger , func () * Stats {
535
+ return a .stats .Copy ()
536
+ })
537
+ if err != nil {
538
+ a .logger .Error (ctx , "report stats" , slog .Error (err ))
539
+ return
540
+ }
541
+ a .connCloseWait .Add (1 )
542
+ go func () {
543
+ defer a .connCloseWait .Done ()
544
+ <- a .closed
545
+ cl .Close ()
546
+ }()
547
+ }
517
548
}
518
549
519
550
// createCommand processes raw command input with OpenSSH-like behavior.
0 commit comments