1
1
package tailnet
2
2
3
3
import (
4
+ "context"
5
+ "io"
6
+ "net"
4
7
"strconv"
5
8
"strings"
9
+ "sync/atomic"
10
+
11
+ "github.com/google/uuid"
12
+ "github.com/hashicorp/yamux"
13
+ "storj.io/drpc/drpcmux"
14
+ "storj.io/drpc/drpcserver"
15
+
16
+ "cdr.dev/slog"
17
+ "github.com/coder/coder/v2/tailnet/proto"
6
18
7
19
"golang.org/x/xerrors"
8
20
)
@@ -15,17 +27,9 @@ const (
15
27
var SupportedMajors = []int {2 , 1 }
16
28
17
29
func ValidateVersion (version string ) error {
18
- parts := strings .Split (version , "." )
19
- if len (parts ) != 2 {
20
- return xerrors .Errorf ("invalid version string: %s" , version )
21
- }
22
- major , err := strconv .Atoi (parts [0 ])
23
- if err != nil {
24
- return xerrors .Errorf ("invalid major version: %s" , version )
25
- }
26
- minor , err := strconv .Atoi (parts [1 ])
30
+ major , minor , err := parseVersion (version )
27
31
if err != nil {
28
- return xerrors . Errorf ( "invalid minor version: %s" , version )
32
+ return err
29
33
}
30
34
if major > CurrentMajor {
31
35
return xerrors .Errorf ("server is at version %d.%d, behind requested version %s" ,
@@ -45,3 +49,186 @@ func ValidateVersion(version string) error {
45
49
}
46
50
return xerrors .Errorf ("version %s is no longer supported" , version )
47
51
}
52
+
53
+ func parseVersion (version string ) (major int , minor int , err error ) {
54
+ parts := strings .Split (version , "." )
55
+ if len (parts ) != 2 {
56
+ return 0 , 0 , xerrors .Errorf ("invalid version string: %s" , version )
57
+ }
58
+ major , err = strconv .Atoi (parts [0 ])
59
+ if err != nil {
60
+ return 0 , 0 , xerrors .Errorf ("invalid major version: %s" , version )
61
+ }
62
+ minor , err = strconv .Atoi (parts [1 ])
63
+ if err != nil {
64
+ return 0 , 0 , xerrors .Errorf ("invalid minor version: %s" , version )
65
+ }
66
+ return major , minor , nil
67
+ }
68
+
69
+ type streamIDContextKey struct {}
70
+
71
+ // StreamID identifies the caller of the CoordinateTailnet RPC. We store this
72
+ // on the context, since the information is extracted at the HTTP layer for
73
+ // remote clients of the API, or set outside tailnet for local clients (e.g.
74
+ // Coderd's single_tailnet)
75
+ type StreamID struct {
76
+ Name string
77
+ ID uuid.UUID
78
+ Auth TunnelAuth
79
+ }
80
+
81
+ func WithStreamID (ctx context.Context , streamID StreamID ) context.Context {
82
+ return context .WithValue (ctx , streamIDContextKey {}, streamID )
83
+ }
84
+
85
+ // ClientService is a tailnet coordination service that accepts a connection and version from a
86
+ // tailnet client, and support versions 1.0 and 2.x of the Tailnet API protocol.
87
+ type ClientService struct {
88
+ logger slog.Logger
89
+ coordPtr * atomic.Pointer [Coordinator ]
90
+ drpc * drpcserver.Server
91
+ }
92
+
93
+ // NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is
94
+ // loaded on each processed connection.
95
+ func NewClientService (logger slog.Logger , coordPtr * atomic.Pointer [Coordinator ]) (* ClientService , error ) {
96
+ s := & ClientService {logger : logger , coordPtr : coordPtr }
97
+ mux := drpcmux .New ()
98
+ drpcService := NewDRPCService (logger , coordPtr )
99
+ err := proto .DRPCRegisterClient (mux , drpcService )
100
+ if err != nil {
101
+ return nil , xerrors .Errorf ("register DRPC service: %w" , err )
102
+ }
103
+ server := drpcserver .NewWithOptions (mux , drpcserver.Options {
104
+ Log : func (err error ) {
105
+ if xerrors .Is (err , io .EOF ) {
106
+ return
107
+ }
108
+ logger .Debug (context .Background (), "drpc server error" , slog .Error (err ))
109
+ },
110
+ })
111
+ s .drpc = server
112
+ return s , nil
113
+ }
114
+
115
+ func (s * ClientService ) ServeClient (ctx context.Context , version string , conn net.Conn , id uuid.UUID , agent uuid.UUID ) error {
116
+ major , _ , err := parseVersion (version )
117
+ if err != nil {
118
+ s .logger .Warn (ctx , "serve client called with unparsable version" , slog .Error (err ))
119
+ return err
120
+ }
121
+ switch major {
122
+ case 1 :
123
+ coord := * (s .coordPtr .Load ())
124
+ return coord .ServeClient (conn , id , agent )
125
+ case 2 :
126
+ config := yamux .DefaultConfig ()
127
+ config .LogOutput = io .Discard
128
+ session , err := yamux .Server (conn , config )
129
+ if err != nil {
130
+ return xerrors .Errorf ("yamux init failed: %w" , err )
131
+ }
132
+ auth := ClientTunnelAuth {AgentID : agent }
133
+ streamID := StreamID {
134
+ Name : "client" ,
135
+ ID : id ,
136
+ Auth : auth ,
137
+ }
138
+ ctx = WithStreamID (ctx , streamID )
139
+ return s .drpc .Serve (ctx , session )
140
+ default :
141
+ s .logger .Warn (ctx , "serve client called with unsupported version" , slog .F ("version" , version ))
142
+ return xerrors .New ("unsupported version" )
143
+ }
144
+ }
145
+
146
+ // DRPCService is the dRPC-based, version 2.x of the tailnet API and implements proto.DRPCClientServer
147
+ type DRPCService struct {
148
+ coordPtr * atomic.Pointer [Coordinator ]
149
+ logger slog.Logger
150
+ }
151
+
152
+ func NewDRPCService (logger slog.Logger , coordPtr * atomic.Pointer [Coordinator ]) * DRPCService {
153
+ return & DRPCService {
154
+ coordPtr : coordPtr ,
155
+ logger : logger ,
156
+ }
157
+ }
158
+
159
+ func (* DRPCService ) StreamDERPMaps (* proto.StreamDERPMapsRequest , proto.DRPCClient_StreamDERPMapsStream ) error {
160
+ // TODO integrate with Dean's PR implementation
161
+ return xerrors .New ("unimplemented" )
162
+ }
163
+
164
+ func (s * DRPCService ) CoordinateTailnet (stream proto.DRPCClient_CoordinateTailnetStream ) error {
165
+ ctx := stream .Context ()
166
+ streamID , ok := ctx .Value (streamIDContextKey {}).(StreamID )
167
+ if ! ok {
168
+ _ = stream .Close ()
169
+ return xerrors .New ("no Stream ID" )
170
+ }
171
+ logger := s .logger .With (slog .F ("peer_id" , streamID ), slog .F ("name" , streamID .Name ))
172
+ logger .Debug (ctx , "starting tailnet Coordinate" )
173
+ coord := * (s .coordPtr .Load ())
174
+ reqs , resps := coord .Coordinate (ctx , streamID .ID , streamID .Name , streamID .Auth )
175
+ c := communicator {
176
+ logger : logger ,
177
+ stream : stream ,
178
+ reqs : reqs ,
179
+ resps : resps ,
180
+ }
181
+ c .communicate ()
182
+ return nil
183
+ }
184
+
185
+ type communicator struct {
186
+ logger slog.Logger
187
+ stream proto.DRPCClient_CoordinateTailnetStream
188
+ reqs chan <- * proto.CoordinateRequest
189
+ resps <- chan * proto.CoordinateResponse
190
+ }
191
+
192
+ func (c communicator ) communicate () {
193
+ go c .loopReq ()
194
+ c .loopResp ()
195
+ }
196
+
197
+ func (c communicator ) loopReq () {
198
+ ctx := c .stream .Context ()
199
+ defer close (c .reqs )
200
+ for {
201
+ req , err := c .stream .Recv ()
202
+ if err != nil {
203
+ c .logger .Debug (ctx , "error receiving requests from DRPC stream" , slog .Error (err ))
204
+ return
205
+ }
206
+ err = SendCtx (ctx , c .reqs , req )
207
+ if err != nil {
208
+ c .logger .Debug (ctx , "context done while sending coordinate request" , slog .Error (ctx .Err ()))
209
+ return
210
+ }
211
+ }
212
+ }
213
+
214
+ func (c communicator ) loopResp () {
215
+ ctx := c .stream .Context ()
216
+ defer func () {
217
+ err := c .stream .Close ()
218
+ if err != nil {
219
+ c .logger .Debug (ctx , "loopResp hit error closing stream" , slog .Error (err ))
220
+ }
221
+ }()
222
+ for {
223
+ resp , err := RecvCtx (ctx , c .resps )
224
+ if err != nil {
225
+ c .logger .Debug (ctx , "loopResp failed to get response" , slog .Error (err ))
226
+ return
227
+ }
228
+ err = c .stream .Send (resp )
229
+ if err != nil {
230
+ c .logger .Debug (ctx , "loopResp failed to send response to DRPC stream" , slog .Error (err ))
231
+ return
232
+ }
233
+ }
234
+ }
0 commit comments