@@ -17,6 +17,7 @@ import (
17
17
"github.com/google/uuid"
18
18
"go4.org/netipx"
19
19
"golang.org/x/xerrors"
20
+ "gvisor.dev/gvisor/pkg/tcpip"
20
21
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
21
22
"tailscale.com/hostinfo"
22
23
"tailscale.com/ipn/ipnstate"
@@ -44,6 +45,12 @@ import (
44
45
"github.com/coder/coder/cryptorand"
45
46
)
46
47
48
+ const (
49
+ WorkspaceAgentSSHPort = 1
50
+ WorkspaceAgentReconnectingPTYPort = 2
51
+ WorkspaceAgentSpeedtestPort = 3
52
+ )
53
+
47
54
func init () {
48
55
// Globally disable network namespacing. All networking happens in
49
56
// userspace.
@@ -267,6 +274,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
267
274
server .sendNode ()
268
275
})
269
276
netStack .ForwardTCPIn = server .forwardTCP
277
+ netStack .ForwardTCPSockOpts = server .forwardTCPSockOpts
270
278
271
279
err = netStack .Start (nil )
272
280
if err != nil {
@@ -301,17 +309,16 @@ type Conn struct {
301
309
logger slog.Logger
302
310
blockEndpoints bool
303
311
304
- dialer * tsdial.Dialer
305
- tunDevice * tstun.Wrapper
306
- peerMap map [tailcfg.NodeID ]* tailcfg.Node
307
- netMap * netmap.NetworkMap
308
- netStack * netstack.Impl
309
- magicConn * magicsock.Conn
310
- wireguardMonitor * monitor.Mon
311
- wireguardRouter * router.Config
312
- wireguardEngine wgengine.Engine
313
- listeners map [listenKey ]* listener
314
- forwardTCPCallback func (conn net.Conn , listenerExists bool ) net.Conn
312
+ dialer * tsdial.Dialer
313
+ tunDevice * tstun.Wrapper
314
+ peerMap map [tailcfg.NodeID ]* tailcfg.Node
315
+ netMap * netmap.NetworkMap
316
+ netStack * netstack.Impl
317
+ magicConn * magicsock.Conn
318
+ wireguardMonitor * monitor.Mon
319
+ wireguardRouter * router.Config
320
+ wireguardEngine wgengine.Engine
321
+ listeners map [listenKey ]* listener
315
322
316
323
lastMutex sync.Mutex
317
324
nodeSending bool
@@ -327,17 +334,6 @@ type Conn struct {
327
334
trafficStats * connstats.Statistics
328
335
}
329
336
330
- // SetForwardTCPCallback is called every time a TCP connection is initiated inbound.
331
- // listenerExists is true if a listener is registered for the target port. If there
332
- // isn't one, traffic is forwarded to the local listening port.
333
- //
334
- // This allows wrapping a Conn to track reads and writes.
335
- func (c * Conn ) SetForwardTCPCallback (callback func (conn net.Conn , listenerExists bool ) net.Conn ) {
336
- c .mutex .Lock ()
337
- defer c .mutex .Unlock ()
338
- c .forwardTCPCallback = callback
339
- }
340
-
341
337
func (c * Conn ) SetNodeCallback (callback func (node * Node )) {
342
338
c .lastMutex .Lock ()
343
339
c .nodeCallback = callback
@@ -699,12 +695,11 @@ func (c *Conn) selfNode() *Node {
699
695
// This and below is taken _mostly_ verbatim from Tailscale:
700
696
// https://github.com/tailscale/tailscale/blob/c88bd53b1b7b2fcf7ba302f2e53dd1ce8c32dad4/tsnet/tsnet.go#L459-L494
701
697
702
- // Listen announces only on the Tailscale network.
703
- // It will start the server if it has not been started yet.
698
+ // Listen listens for connections only on the Tailscale network.
704
699
func (c * Conn ) Listen (network , addr string ) (net.Listener , error ) {
705
700
host , port , err := net .SplitHostPort (addr )
706
701
if err != nil {
707
- return nil , xerrors .Errorf ("wgnet : %w" , err )
702
+ return nil , xerrors .Errorf ("tailnet: split host port for listen : %w" , err )
708
703
}
709
704
lk := listenKey {network , host , port }
710
705
ln := & listener {
@@ -725,7 +720,7 @@ func (c *Conn) Listen(network, addr string) (net.Listener, error) {
725
720
}
726
721
if _ , ok := c .listeners [lk ]; ok {
727
722
c .mutex .Unlock ()
728
- return nil , xerrors .Errorf ("wgnet : listener already open for %s, %s" , network , addr )
723
+ return nil , xerrors .Errorf ("tailnet : listener already open for %s, %s" , network , addr )
729
724
}
730
725
c .listeners [lk ] = ln
731
726
c .mutex .Unlock ()
@@ -743,14 +738,12 @@ func (c *Conn) DialContextUDP(ctx context.Context, ipp netip.AddrPort) (*gonet.U
743
738
func (c * Conn ) forwardTCP (conn net.Conn , port uint16 ) {
744
739
c .mutex .Lock ()
745
740
ln , ok := c .listeners [listenKey {"tcp" , "" , fmt .Sprint (port )}]
746
- if c .forwardTCPCallback != nil {
747
- conn = c .forwardTCPCallback (conn , ok )
748
- }
749
741
c .mutex .Unlock ()
750
742
if ! ok {
751
743
c .forwardTCPToLocal (conn , port )
752
744
return
753
745
}
746
+
754
747
t := time .NewTimer (time .Second )
755
748
defer t .Stop ()
756
749
select {
@@ -763,6 +756,18 @@ func (c *Conn) forwardTCP(conn net.Conn, port uint16) {
763
756
_ = conn .Close ()
764
757
}
765
758
759
+ func (* Conn ) forwardTCPSockOpts (port uint16 ) []tcpip.SettableSocketOption {
760
+ opts := []tcpip.SettableSocketOption {}
761
+
762
+ // See: https://github.com/tailscale/tailscale/blob/c7cea825aea39a00aca71ea02bab7266afc03e7c/wgengine/netstack/netstack.go#L888
763
+ if port == WorkspaceAgentSSHPort || port == 22 {
764
+ opt := tcpip .KeepaliveIdleOption (72 * time .Hour )
765
+ opts = append (opts , & opt )
766
+ }
767
+
768
+ return opts
769
+ }
770
+
766
771
func (c * Conn ) forwardTCPToLocal (conn net.Conn , port uint16 ) {
767
772
defer conn .Close ()
768
773
dialAddrStr := net .JoinHostPort ("127.0.0.1" , strconv .Itoa (int (port )))
@@ -842,7 +847,7 @@ func (ln *listener) Accept() (net.Conn, error) {
842
847
select {
843
848
case c = <- ln .conn :
844
849
case <- ln .closed :
845
- return nil , xerrors .Errorf ("wgnet : %w" , net .ErrClosed )
850
+ return nil , xerrors .Errorf ("tailnet : %w" , net .ErrClosed )
846
851
}
847
852
return c , nil
848
853
}
0 commit comments