Skip to content

Commit 64caaac

Browse files
committed
feat: add tailnet v2 support to wsproxy coordinate endpoint
1 parent 175ebed commit 64caaac

File tree

4 files changed

+99
-19
lines changed

4 files changed

+99
-19
lines changed

enterprise/coderd/coderd.go

+10
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,15 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
128128
}
129129
return api.fetchRegions(ctx)
130130
}
131+
api.tailnetService, err = tailnet.NewClientService(
132+
api.Logger.Named("tailnetclient"),
133+
&api.AGPL.TailnetCoordinator,
134+
api.Options.DERPMapUpdateFrequency,
135+
api.AGPL.DERPMap,
136+
)
137+
if err != nil {
138+
api.Logger.Fatal(api.ctx, "failed to initialize tailnet client service", slog.Error(err))
139+
}
131140

132141
oauthConfigs := &httpmw.OAuth2Configs{
133142
Github: options.GithubOAuth2Config,
@@ -483,6 +492,7 @@ type API struct {
483492
provisionerDaemonAuth *provisionerDaemonAuth
484493

485494
licenseMetricsCollector license.MetricsCollector
495+
tailnetService *tailnet.ClientService
486496
}
487497

488498
func (api *API) Close() error {

enterprise/coderd/workspaceproxycoordinate.go

+20-5
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ import (
99
"github.com/coder/coder/v2/coderd/httpapi"
1010
"github.com/coder/coder/v2/coderd/httpmw"
1111
"github.com/coder/coder/v2/codersdk"
12-
"github.com/coder/coder/v2/enterprise/tailnet"
1312
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
13+
agpl "github.com/coder/coder/v2/tailnet"
1414
)
1515

1616
// @Summary Agent is legacy
@@ -52,6 +52,21 @@ func (api *API) agentIsLegacy(rw http.ResponseWriter, r *http.Request) {
5252
func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request) {
5353
ctx := r.Context()
5454

55+
version := "1.0"
56+
qv := r.URL.Query().Get("version")
57+
if qv != "" {
58+
version = qv
59+
}
60+
if err := agpl.CurrentVersion.Validate(version); err != nil {
61+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
62+
Message: "Unknown or unsupported API version",
63+
Validations: []codersdk.ValidationError{
64+
{Field: "version", Detail: err.Error()},
65+
},
66+
})
67+
return
68+
}
69+
5570
api.AGPL.WebsocketWaitMutex.Lock()
5671
api.AGPL.WebsocketWaitGroup.Add(1)
5772
api.AGPL.WebsocketWaitMutex.Unlock()
@@ -66,14 +81,14 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
6681
return
6782
}
6883

69-
id := uuid.New()
70-
sub := (*api.AGPL.TailnetCoordinator.Load()).ServeMultiAgent(id)
71-
7284
ctx, nc := websocketNetConn(ctx, conn, websocket.MessageText)
7385
defer nc.Close()
7486

75-
err = tailnet.ServeWorkspaceProxy(ctx, nc, sub)
87+
id := uuid.New()
88+
err = api.tailnetService.ServeMultiAgentClient(ctx, version, nc, id)
7689
if err != nil {
7790
_ = conn.Close(websocket.StatusInternalError, err.Error())
91+
} else {
92+
_ = conn.Close(websocket.StatusGoingAway, "")
7893
}
7994
}

enterprise/tailnet/workspaceproxy.go

+51
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,65 @@ import (
66
"encoding/json"
77
"errors"
88
"net"
9+
"sync/atomic"
910
"time"
1011

12+
"github.com/google/uuid"
1113
"golang.org/x/xerrors"
14+
"tailscale.com/tailcfg"
1215

16+
"cdr.dev/slog"
17+
"github.com/coder/coder/v2/coderd/util/apiversion"
1318
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
1419
agpl "github.com/coder/coder/v2/tailnet"
1520
)
1621

22+
type ClientService struct {
23+
*agpl.ClientService
24+
}
25+
26+
// NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is
27+
// loaded on each processed connection.
28+
func NewClientService(
29+
logger slog.Logger,
30+
coordPtr *atomic.Pointer[agpl.Coordinator],
31+
derpMapUpdateFrequency time.Duration,
32+
derpMapFn func() *tailcfg.DERPMap,
33+
) (
34+
*ClientService, error,
35+
) {
36+
s, err := agpl.NewClientService(logger, coordPtr, derpMapUpdateFrequency, derpMapFn)
37+
if err != nil {
38+
return nil, err
39+
}
40+
return &ClientService{ClientService: s}, nil
41+
}
42+
43+
func (s *ClientService) ServeMultiAgentClient(ctx context.Context, version string, conn net.Conn, id uuid.UUID) error {
44+
major, _, err := apiversion.Parse(version)
45+
if err != nil {
46+
s.Logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err))
47+
return err
48+
}
49+
switch major {
50+
case 1:
51+
coord := *(s.CoordPtr.Load())
52+
sub := coord.ServeMultiAgent(id)
53+
return ServeWorkspaceProxy(ctx, conn, sub)
54+
case 2:
55+
auth := agpl.SingleTailnetTunnelAuth{}
56+
streamID := agpl.StreamID{
57+
Name: id.String(),
58+
ID: id,
59+
Auth: auth,
60+
}
61+
return s.ServeConnV2(ctx, conn, streamID)
62+
default:
63+
s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
64+
return xerrors.New("unsupported version")
65+
}
66+
}
67+
1768
func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error {
1869
go func() {
1970
err := forwardNodesToWorkspaceProxy(ctx, conn, ma)

tailnet/service.go

+18-14
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ func WithStreamID(ctx context.Context, streamID StreamID) context.Context {
4646
// ClientService is a tailnet coordination service that accepts a connection and version from a
4747
// tailnet client, and support versions 1.0 and 2.x of the Tailnet API protocol.
4848
type ClientService struct {
49-
logger slog.Logger
50-
coordPtr *atomic.Pointer[Coordinator]
49+
Logger slog.Logger
50+
CoordPtr *atomic.Pointer[Coordinator]
5151
drpc *drpcserver.Server
5252
}
5353

@@ -61,7 +61,7 @@ func NewClientService(
6161
) (
6262
*ClientService, error,
6363
) {
64-
s := &ClientService{logger: logger, coordPtr: coordPtr}
64+
s := &ClientService{Logger: logger, CoordPtr: coordPtr}
6565
mux := drpcmux.New()
6666
drpcService := &DRPCService{
6767
CoordPtr: coordPtr,
@@ -88,34 +88,38 @@ func NewClientService(
8888
func (s *ClientService) ServeClient(ctx context.Context, version string, conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
8989
major, _, err := apiversion.Parse(version)
9090
if err != nil {
91-
s.logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err))
91+
s.Logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err))
9292
return err
9393
}
9494
switch major {
9595
case 1:
96-
coord := *(s.coordPtr.Load())
96+
coord := *(s.CoordPtr.Load())
9797
return coord.ServeClient(conn, id, agent)
9898
case 2:
99-
config := yamux.DefaultConfig()
100-
config.LogOutput = io.Discard
101-
session, err := yamux.Server(conn, config)
102-
if err != nil {
103-
return xerrors.Errorf("yamux init failed: %w", err)
104-
}
10599
auth := ClientTunnelAuth{AgentID: agent}
106100
streamID := StreamID{
107101
Name: "client",
108102
ID: id,
109103
Auth: auth,
110104
}
111-
ctx = WithStreamID(ctx, streamID)
112-
return s.drpc.Serve(ctx, session)
105+
return s.ServeConnV2(ctx, conn, streamID)
113106
default:
114-
s.logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
107+
s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
115108
return xerrors.New("unsupported version")
116109
}
117110
}
118111

112+
func (s ClientService) ServeConnV2(ctx context.Context, conn net.Conn, streamID StreamID) error {
113+
config := yamux.DefaultConfig()
114+
config.LogOutput = io.Discard
115+
session, err := yamux.Server(conn, config)
116+
if err != nil {
117+
return xerrors.Errorf("yamux init failed: %w", err)
118+
}
119+
ctx = WithStreamID(ctx, streamID)
120+
return s.drpc.Serve(ctx, session)
121+
}
122+
119123
// DRPCService is the dRPC-based, version 2.x of the tailnet API and implements proto.DRPCClientServer
120124
type DRPCService struct {
121125
CoordPtr *atomic.Pointer[Coordinator]

0 commit comments

Comments
 (0)