@@ -28,14 +28,14 @@ import (
28
28
gossh "golang.org/x/crypto/ssh"
29
29
"golang.org/x/xerrors"
30
30
"inet.af/netaddr"
31
- "tailscale.com/types/key "
31
+ "tailscale.com/tailcfg "
32
32
33
33
"cdr.dev/slog"
34
34
"github.com/coder/coder/agent/usershell"
35
35
"github.com/coder/coder/peer"
36
- "github.com/coder/coder/peer/peerwg"
37
36
"github.com/coder/coder/peerbroker"
38
37
"github.com/coder/coder/pty"
38
+ "github.com/coder/coder/tailnet"
39
39
"github.com/coder/retry"
40
40
)
41
41
@@ -47,30 +47,28 @@ const (
47
47
48
48
type Options struct {
49
49
EnableWireguard bool
50
- UploadWireguardKeys UploadWireguardKeys
51
- ListenWireguardPeers ListenWireguardPeers
50
+ UpdateTailscaleNode UpdateTailscaleNode
51
+ ListenTailscaleNodes ListenTailscaleNodes
52
52
ReconnectingPTYTimeout time.Duration
53
53
EnvironmentVariables map [string ]string
54
54
Logger slog.Logger
55
55
}
56
56
57
57
type Metadata struct {
58
- WireguardAddresses []netaddr.IPPrefix `json:"addresses"`
59
- OwnerEmail string `json:"owner_email"`
60
- OwnerUsername string `json:"owner_username"`
61
- EnvironmentVariables map [string ]string `json:"environment_variables"`
62
- StartupScript string `json:"startup_script"`
63
- Directory string `json:"directory"`
64
- }
65
-
66
- type WireguardPublicKeys struct {
67
- Public key.NodePublic `json:"public"`
68
- Disco key.DiscoPublic `json:"disco"`
58
+ TailscaleAddresses []netaddr.IPPrefix `json:"tailscale_addresses"`
59
+ TailscaleDERPMap * tailcfg.DERPMap `json:"tailscale_derpmap"`
60
+
61
+ OwnerEmail string `json:"owner_email"`
62
+ OwnerUsername string `json:"owner_username"`
63
+ EnvironmentVariables map [string ]string `json:"environment_variables"`
64
+ StartupScript string `json:"startup_script"`
65
+ Directory string `json:"directory"`
69
66
}
70
67
71
68
type Dialer func (ctx context.Context , logger slog.Logger ) (Metadata , * peerbroker.Listener , error )
72
- type UploadWireguardKeys func (ctx context.Context , keys WireguardPublicKeys ) error
73
- type ListenWireguardPeers func (ctx context.Context , logger slog.Logger ) (<- chan peerwg.Handshake , func (), error )
69
+
70
+ type UpdateTailscaleNode func (ctx context.Context , node * tailnet.Node ) error
71
+ type ListenTailscaleNodes func (ctx context.Context , logger slog.Logger ) (<- chan * tailnet.Node , func (), error )
74
72
75
73
func New (dialer Dialer , options * Options ) io.Closer {
76
74
if options == nil {
@@ -88,8 +86,8 @@ func New(dialer Dialer, options *Options) io.Closer {
88
86
closed : make (chan struct {}),
89
87
envVars : options .EnvironmentVariables ,
90
88
enableWireguard : options .EnableWireguard ,
91
- postKeys : options .UploadWireguardKeys ,
92
- listenWireguardPeers : options .ListenWireguardPeers ,
89
+ updateTailscaleNode : options .UpdateTailscaleNode ,
90
+ listenTailscaleNodes : options .ListenTailscaleNodes ,
93
91
}
94
92
server .init (ctx )
95
93
return server
@@ -114,9 +112,9 @@ type agent struct {
114
112
sshServer * ssh.Server
115
113
116
114
enableWireguard bool
117
- network * peerwg. Network
118
- postKeys UploadWireguardKeys
119
- listenWireguardPeers ListenWireguardPeers
115
+ network * tailnet. Server
116
+ updateTailscaleNode UpdateTailscaleNode
117
+ listenTailscaleNodes ListenTailscaleNodes
120
118
}
121
119
122
120
func (a * agent ) run (ctx context.Context ) {
@@ -160,8 +158,9 @@ func (a *agent) run(ctx context.Context) {
160
158
}()
161
159
}
162
160
163
- if a .enableWireguard {
164
- err = a .startWireguard (ctx , metadata .WireguardAddresses )
161
+ // We don't want to reinitialize the network if it already exists.
162
+ if a .enableWireguard && a .network == nil {
163
+ err = a .startWireguard (ctx , metadata .TailscaleAddresses , metadata .TailscaleDERPMap )
165
164
if err != nil {
166
165
a .logger .Error (ctx , "start wireguard" , slog .Error (err ))
167
166
}
@@ -668,6 +667,71 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne
668
667
}
669
668
}
670
669
670
+ func (a * agent ) startWireguard (ctx context.Context , addresses []netaddr.IPPrefix , derpMap * tailcfg.DERPMap ) error {
671
+ var err error
672
+ a .network , err = tailnet .New (& tailnet.Options {
673
+ Addresses : addresses ,
674
+ DERPMap : derpMap ,
675
+ Logger : a .logger .Named ("tailnet" ),
676
+ })
677
+ if err != nil {
678
+ return err
679
+ }
680
+ a .network .SetNodeCallback (func (node * tailnet.Node ) {
681
+ err := a .updateTailscaleNode (ctx , node )
682
+ if err != nil {
683
+ a .logger .Error (ctx , "update tailscale node" , slog .Error (err ))
684
+ }
685
+ })
686
+ go func () {
687
+ for {
688
+ var nodes <- chan * tailnet.Node
689
+ var err error
690
+ var listenClose func ()
691
+ for retrier := retry .New (50 * time .Millisecond , 10 * time .Second ); retrier .Wait (ctx ); {
692
+ nodes , listenClose , err = a .listenTailscaleNodes (ctx , a .logger )
693
+ if err != nil {
694
+ if errors .Is (err , context .Canceled ) {
695
+ return
696
+ }
697
+ a .logger .Warn (ctx , "listen for tailscale nodes" , slog .Error (err ))
698
+ continue
699
+ }
700
+ defer listenClose ()
701
+ a .logger .Info (context .Background (), "listening for tailscale nodes" )
702
+ break
703
+ }
704
+ for {
705
+ var node * tailnet.Node
706
+ select {
707
+ case <- ctx .Done ():
708
+ case node = <- nodes :
709
+ }
710
+ if node == nil {
711
+ // The channel ended!
712
+ break
713
+ }
714
+ a .network .UpdateNodes ([]* tailnet.Node {node })
715
+ }
716
+ }
717
+ }()
718
+
719
+ sshListener , err := a .network .Listen ("tcp" , ":12212" )
720
+ if err != nil {
721
+ return xerrors .Errorf ("listen for ssh: %w" , err )
722
+ }
723
+ go func () {
724
+ for {
725
+ conn , err := sshListener .Accept ()
726
+ if err != nil {
727
+ return
728
+ }
729
+ go a .sshServer .HandleConn (conn )
730
+ }
731
+ }()
732
+ return nil
733
+ }
734
+
671
735
// dialResponse is written to datachannels with protocol "dial" by the agent as
672
736
// the first packet to signify whether the dial succeeded or failed.
673
737
type dialResponse struct {
0 commit comments