@@ -30,6 +30,7 @@ import (
30
30
"github.com/spf13/afero"
31
31
"go.uber.org/atomic"
32
32
gossh "golang.org/x/crypto/ssh"
33
+ "golang.org/x/exp/slices"
33
34
"golang.org/x/xerrors"
34
35
"tailscale.com/net/speedtest"
35
36
"tailscale.com/tailcfg"
@@ -90,7 +91,7 @@ func New(options Options) io.Closer {
90
91
}
91
92
}
92
93
ctx , cancelFunc := context .WithCancel (context .Background ())
93
- server := & agent {
94
+ a := & agent {
94
95
reconnectingPTYTimeout : options .ReconnectingPTYTimeout ,
95
96
logger : options .Logger ,
96
97
closeCancel : cancelFunc ,
@@ -101,8 +102,8 @@ func New(options Options) io.Closer {
101
102
filesystem : options .Filesystem ,
102
103
tempDir : options .TempDir ,
103
104
}
104
- server .init (ctx )
105
- return server
105
+ a .init (ctx )
106
+ return a
106
107
}
107
108
108
109
type agent struct {
@@ -215,8 +216,16 @@ func (a *agent) run(ctx context.Context) error {
215
216
return xerrors .Errorf ("create tailnet: %w" , err )
216
217
}
217
218
a .closeMutex .Lock ()
218
- a .network = network
219
+ // Re-check if agent was closed while initializing the network.
220
+ closed := a .isClosed ()
221
+ if ! closed {
222
+ a .network = network
223
+ }
219
224
a .closeMutex .Unlock ()
225
+ if closed {
226
+ _ = network .Close ()
227
+ return xerrors .New ("agent is closed" )
228
+ }
220
229
} else {
221
230
// Update the DERP map!
222
231
network .SetDERPMap (metadata .DERPMap )
@@ -246,27 +255,20 @@ func (a *agent) trackConnGoroutine(fn func()) error {
246
255
}
247
256
248
257
func (a * agent ) createTailnet (ctx context.Context , derpMap * tailcfg.DERPMap ) (_ * tailnet.Conn , err error ) {
249
- a .closeMutex .Lock ()
250
- if a .isClosed () {
251
- a .closeMutex .Unlock ()
252
- return nil , xerrors .New ("closed" )
253
- }
254
258
network , err := tailnet .NewConn (& tailnet.Options {
255
259
Addresses : []netip.Prefix {netip .PrefixFrom (codersdk .TailnetIP , 128 )},
256
260
DERPMap : derpMap ,
257
261
Logger : a .logger .Named ("tailnet" ),
258
262
EnableTrafficStats : true ,
259
263
})
260
264
if err != nil {
261
- a .closeMutex .Unlock ()
262
265
return nil , xerrors .Errorf ("create tailnet: %w" , err )
263
266
}
264
267
defer func () {
265
268
if err != nil {
266
269
network .Close ()
267
270
}
268
271
}()
269
- a .closeMutex .Unlock ()
270
272
271
273
sshListener , err := network .Listen ("tcp" , ":" + strconv .Itoa (codersdk .TailnetSSHPort ))
272
274
if err != nil {
@@ -299,10 +301,12 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
299
301
}
300
302
}()
301
303
if err = a .trackConnGoroutine (func () {
304
+ logger := a .logger .Named ("reconnecting-pty" )
305
+
302
306
for {
303
307
conn , err := reconnectingPTYListener .Accept ()
304
308
if err != nil {
305
- a . logger .Debug (ctx , "accept pty failed" , slog .Error (err ))
309
+ logger .Debug (ctx , "accept pty failed" , slog .Error (err ))
306
310
return
307
311
}
308
312
// This cannot use a JSON decoder, since that can
@@ -323,7 +327,9 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
323
327
if err != nil {
324
328
continue
325
329
}
326
- go a .handleReconnectingPTY (ctx , msg , conn )
330
+ go func () {
331
+ _ = a .handleReconnectingPTY (ctx , logger , msg , conn )
332
+ }()
327
333
}
328
334
}); err != nil {
329
335
return nil , err
@@ -797,38 +803,56 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
797
803
return cmd .Wait ()
798
804
}
799
805
800
- func (a * agent ) handleReconnectingPTY (ctx context.Context , msg codersdk.ReconnectingPTYInit , conn net.Conn ) {
806
+ func (a * agent ) handleReconnectingPTY (ctx context.Context , logger slog. Logger , msg codersdk.ReconnectingPTYInit , conn net.Conn ) ( retErr error ) {
801
807
defer conn .Close ()
802
808
803
809
connectionID := uuid .NewString ()
810
+ logger = logger .With (slog .F ("id" , msg .ID ), slog .F ("connection_id" , connectionID ))
811
+
812
+ defer func () {
813
+ if err := retErr ; err != nil {
814
+ a .closeMutex .Lock ()
815
+ closed := a .isClosed ()
816
+ a .closeMutex .Unlock ()
817
+
818
+ // If the agent is closed, we don't want to
819
+ // log this as an error since it's expected.
820
+ if closed {
821
+ logger .Debug (ctx , "session error after agent close" , slog .Error (err ))
822
+ } else {
823
+ logger .Error (ctx , "session error" , slog .Error (err ))
824
+ }
825
+ }
826
+ logger .Debug (ctx , "session closed" )
827
+ }()
828
+
804
829
var rpty * reconnectingPTY
805
830
rawRPTY , ok := a .reconnectingPTYs .Load (msg .ID )
806
831
if ok {
832
+ logger .Debug (ctx , "connecting to existing session" )
807
833
rpty , ok = rawRPTY .(* reconnectingPTY )
808
834
if ! ok {
809
- a .logger .Error (ctx , "found invalid type in reconnecting pty map" , slog .F ("id" , msg .ID ))
810
- return
835
+ return xerrors .Errorf ("found invalid type in reconnecting pty map: %T" , rawRPTY )
811
836
}
812
837
} else {
838
+ logger .Debug (ctx , "creating new session" )
839
+
813
840
// Empty command will default to the users shell!
814
841
cmd , err := a .createCommand (ctx , msg .Command , nil )
815
842
if err != nil {
816
- a .logger .Error (ctx , "create reconnecting pty command" , slog .Error (err ))
817
- return
843
+ return xerrors .Errorf ("create command: %w" , err )
818
844
}
819
845
cmd .Env = append (cmd .Env , "TERM=xterm-256color" )
820
846
821
847
// Default to buffer 64KiB.
822
848
circularBuffer , err := circbuf .NewBuffer (64 << 10 )
823
849
if err != nil {
824
- a .logger .Error (ctx , "create circular buffer" , slog .Error (err ))
825
- return
850
+ return xerrors .Errorf ("create circular buffer: %w" , err )
826
851
}
827
852
828
853
ptty , process , err := pty .Start (cmd )
829
854
if err != nil {
830
- a .logger .Error (ctx , "start reconnecting pty command" , slog .F ("id" , msg .ID ), slog .Error (err ))
831
- return
855
+ return xerrors .Errorf ("start command: %w" , err )
832
856
}
833
857
834
858
ctx , cancelFunc := context .WithCancel (ctx )
@@ -872,7 +896,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
872
896
_ , err = rpty .circularBuffer .Write (part )
873
897
rpty .circularBufferMutex .Unlock ()
874
898
if err != nil {
875
- a . logger .Error (ctx , "reconnecting pty write buffer" , slog .Error (err ), slog . F ( "id" , msg . ID ))
899
+ logger .Error (ctx , "write to circular buffer" , slog .Error (err ))
876
900
break
877
901
}
878
902
rpty .activeConnsMutex .Lock ()
@@ -888,23 +912,27 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
888
912
rpty .Close ()
889
913
a .reconnectingPTYs .Delete (msg .ID )
890
914
}); err != nil {
891
- a .logger .Error (ctx , "start reconnecting pty routine" , slog .F ("id" , msg .ID ), slog .Error (err ))
892
- return
915
+ return xerrors .Errorf ("start routine: %w" , err )
893
916
}
894
917
}
895
918
// Resize the PTY to initial height + width.
896
919
err := rpty .ptty .Resize (msg .Height , msg .Width )
897
920
if err != nil {
898
921
// We can continue after this, it's not fatal!
899
- a . logger .Error (ctx , "resize reconnecting pty" , slog . F ( "id" , msg . ID ) , slog .Error (err ))
922
+ logger .Error (ctx , "resize" , slog .Error (err ))
900
923
}
901
924
// Write any previously stored data for the TTY.
902
925
rpty .circularBufferMutex .RLock ()
903
- _ , err = conn . Write (rpty .circularBuffer .Bytes ())
926
+ prevBuf := slices . Clone (rpty .circularBuffer .Bytes ())
904
927
rpty .circularBufferMutex .RUnlock ()
928
+ // Note that there is a small race here between writing buffered
929
+ // data and storing conn in activeConns. This is likely a very minor
930
+ // edge case, but we should look into ways to avoid it. Holding
931
+ // activeConnsMutex would be one option, but holding this mutex
932
+ // while also holding circularBufferMutex seems dangerous.
933
+ _ , err = conn .Write (prevBuf )
905
934
if err != nil {
906
- a .logger .Warn (ctx , "write reconnecting pty buffer" , slog .F ("id" , msg .ID ), slog .Error (err ))
907
- return
935
+ return xerrors .Errorf ("write buffer to conn: %w" , err )
908
936
}
909
937
// Multiple connections to the same TTY are permitted.
910
938
// This could easily be used for terminal sharing, but
@@ -945,16 +973,16 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
945
973
for {
946
974
err = decoder .Decode (& req )
947
975
if xerrors .Is (err , io .EOF ) {
948
- return
976
+ return nil
949
977
}
950
978
if err != nil {
951
- a . logger .Warn (ctx , "reconnecting pty buffer read error" , slog . F ( "id" , msg . ID ) , slog .Error (err ))
952
- return
979
+ logger .Warn (ctx , "read conn" , slog .Error (err ))
980
+ return nil
953
981
}
954
982
_ , err = rpty .ptty .Input ().Write ([]byte (req .Data ))
955
983
if err != nil {
956
- a . logger .Warn (ctx , "write to reconnecting pty" , slog . F ( "id" , msg . ID ) , slog .Error (err ))
957
- return
984
+ logger .Warn (ctx , "write to pty" , slog .Error (err ))
985
+ return nil
958
986
}
959
987
// Check if a resize needs to happen!
960
988
if req .Height == 0 || req .Width == 0 {
@@ -963,7 +991,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
963
991
err = rpty .ptty .Resize (req .Height , req .Width )
964
992
if err != nil {
965
993
// We can continue after this, it's not fatal!
966
- a . logger .Error (ctx , "resize reconnecting pty" , slog . F ( "id" , msg . ID ) , slog .Error (err ))
994
+ logger .Error (ctx , "resize" , slog .Error (err ))
967
995
}
968
996
}
969
997
}
0 commit comments