Skip to content

Commit a6444f3

Browse files
spikecurtispull[bot]
authored andcommitted
feat: add support for WorkspaceUpdates to WebsocketDialer (#15534)
closes #14730 Adds support for WorkspaceUpdates to the WebsocketDialer. This allows us to dial the new endpoint added in #14847 and connect it up to a `tailnet.Controllers` to connect to all agents over the tailnet. I refactored the fakeWorkspaceUpdatesProvider to a mock and moved it to `tailnettest` so it could be more easily reused. The Mock is a little more full-featured.
1 parent 1c854c5 commit a6444f3

File tree

9 files changed

+305
-78
lines changed

9 files changed

+305
-78
lines changed

Makefile

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,13 @@ DB_GEN_FILES := \
482482
coderd/database/dbauthz/dbauthz.go \
483483
coderd/database/dbmock/dbmock.go
484484

485+
TAILNETTEST_MOCKS := \
486+
tailnet/tailnettest/coordinatormock.go \
487+
tailnet/tailnettest/coordinateemock.go \
488+
tailnet/tailnettest/workspaceupdatesprovidermock.go \
489+
tailnet/tailnettest/subscriptionmock.go
490+
491+
485492
# all gen targets should be added here and to gen/mark-fresh
486493
gen: \
487494
tailnet/proto/tailnet.pb.go \
@@ -506,8 +513,7 @@ gen: \
506513
site/e2e/provisionerGenerated.ts \
507514
site/src/theme/icons.json \
508515
examples/examples.gen.json \
509-
tailnet/tailnettest/coordinatormock.go \
510-
tailnet/tailnettest/coordinateemock.go \
516+
$(TAILNETTEST_MOCKS) \
511517
coderd/database/pubsub/psmock/psmock.go
512518
.PHONY: gen
513519

@@ -536,8 +542,7 @@ gen/mark-fresh:
536542
site/e2e/provisionerGenerated.ts \
537543
site/src/theme/icons.json \
538544
examples/examples.gen.json \
539-
tailnet/tailnettest/coordinatormock.go \
540-
tailnet/tailnettest/coordinateemock.go \
545+
$(TAILNETTEST_MOCKS) \
541546
coderd/database/pubsub/psmock/psmock.go \
542547
"
543548

@@ -570,7 +575,7 @@ coderd/database/dbmock/dbmock.go: coderd/database/db.go coderd/database/querier.
570575
coderd/database/pubsub/psmock/psmock.go: coderd/database/pubsub/pubsub.go
571576
go generate ./coderd/database/pubsub/psmock
572577

573-
tailnet/tailnettest/coordinatormock.go tailnet/tailnettest/coordinateemock.go: tailnet/coordinator.go
578+
$(TAILNETTEST_MOCKS): tailnet/coordinator.go tailnet/service.go
574579
go generate ./tailnet/tailnettest/
575580

576581
tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto

codersdk/workspacesdk/dialer.go

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,26 @@ var permanentErrorStatuses = []int{
2525
}
2626

2727
type WebsocketDialer struct {
28-
logger slog.Logger
29-
dialOptions *websocket.DialOptions
30-
url *url.URL
28+
logger slog.Logger
29+
dialOptions *websocket.DialOptions
30+
url *url.URL
31+
// workspaceUpdatesReq != nil means that the dialer should call the WorkspaceUpdates RPC and
32+
// return the corresponding client
33+
workspaceUpdatesReq *proto.WorkspaceUpdatesRequest
34+
3135
resumeTokenFailed bool
3236
connected chan error
3337
isFirst bool
3438
}
3539

40+
type WebsocketDialerOption func(*WebsocketDialer)
41+
42+
func WithWorkspaceUpdates(req *proto.WorkspaceUpdatesRequest) WebsocketDialerOption {
43+
return func(w *WebsocketDialer) {
44+
w.workspaceUpdatesReq = req
45+
}
46+
}
47+
3648
func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenController,
3749
) (
3850
tailnet.ControlProtocolClients, error,
@@ -41,14 +53,27 @@ func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenControl
4153

4254
u := new(url.URL)
4355
*u = *w.url
56+
q := u.Query()
4457
if r != nil && !w.resumeTokenFailed {
4558
if token, ok := r.Token(); ok {
46-
q := u.Query()
4759
q.Set("resume_token", token)
48-
u.RawQuery = q.Encode()
4960
w.logger.Debug(ctx, "using resume token on dial")
5061
}
5162
}
63+
// The current version includes additions
64+
//
65+
// 2.1 GetAnnouncementBanners on the Agent API (version locked to Tailnet API)
66+
// 2.2 PostTelemetry on the Tailnet API
67+
// 2.3 RefreshResumeToken, WorkspaceUpdates
68+
//
69+
// Resume tokens and telemetry are optional, and fail gracefully. So we use version 2.0 for
70+
// maximum compatibility if we don't need WorkspaceUpdates. If we do, we use 2.3.
71+
if w.workspaceUpdatesReq != nil {
72+
q.Add("version", "2.3")
73+
} else {
74+
q.Add("version", "2.0")
75+
}
76+
u.RawQuery = q.Encode()
5277

5378
// nolint:bodyclose
5479
ws, res, err := websocket.Dial(ctx, u.String(), w.dialOptions)
@@ -115,25 +140,43 @@ func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenControl
115140
return tailnet.ControlProtocolClients{}, err
116141
}
117142

143+
var updates tailnet.WorkspaceUpdatesClient
144+
if w.workspaceUpdatesReq != nil {
145+
updates, err = client.WorkspaceUpdates(context.Background(), w.workspaceUpdatesReq)
146+
if err != nil {
147+
w.logger.Debug(ctx, "failed to create WorkspaceUpdates stream", slog.Error(err))
148+
_ = ws.Close(websocket.StatusInternalError, "")
149+
return tailnet.ControlProtocolClients{}, err
150+
}
151+
}
152+
118153
return tailnet.ControlProtocolClients{
119-
Closer: client.DRPCConn(),
120-
Coordinator: coord,
121-
DERP: derps,
122-
ResumeToken: client,
123-
Telemetry: client,
154+
Closer: client.DRPCConn(),
155+
Coordinator: coord,
156+
DERP: derps,
157+
ResumeToken: client,
158+
Telemetry: client,
159+
WorkspaceUpdates: updates,
124160
}, nil
125161
}
126162

127163
func (w *WebsocketDialer) Connected() <-chan error {
128164
return w.connected
129165
}
130166

131-
func NewWebsocketDialer(logger slog.Logger, u *url.URL, opts *websocket.DialOptions) *WebsocketDialer {
132-
return &WebsocketDialer{
167+
func NewWebsocketDialer(
168+
logger slog.Logger, u *url.URL, websocketOptions *websocket.DialOptions,
169+
dialerOptions ...WebsocketDialerOption,
170+
) *WebsocketDialer {
171+
w := &WebsocketDialer{
133172
logger: logger,
134-
dialOptions: opts,
173+
dialOptions: websocketOptions,
135174
url: u,
136175
connected: make(chan error, 1),
137176
isFirst: true,
138177
}
178+
for _, o := range dialerOptions {
179+
o(w)
180+
}
181+
return w
139182
}

codersdk/workspacesdk/dialer_test.go

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ import (
99
"testing"
1010
"time"
1111

12+
"github.com/google/uuid"
1213
"github.com/stretchr/testify/assert"
1314
"github.com/stretchr/testify/require"
15+
"go.uber.org/mock/gomock"
1416
"nhooyr.io/websocket"
1517
"tailscale.com/tailcfg"
1618

@@ -21,7 +23,7 @@ import (
2123
"github.com/coder/coder/v2/codersdk"
2224
"github.com/coder/coder/v2/codersdk/workspacesdk"
2325
"github.com/coder/coder/v2/tailnet"
24-
"github.com/coder/coder/v2/tailnet/proto"
26+
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
2527
"github.com/coder/coder/v2/tailnet/tailnettest"
2628
"github.com/coder/coder/v2/testutil"
2729
)
@@ -102,6 +104,7 @@ func TestWebsocketDialer_TokenController(t *testing.T) {
102104
require.Equal(t, "", gotToken)
103105

104106
clients = testutil.RequireRecvCtx(ctx, t, clientCh)
107+
require.Nil(t, clients.WorkspaceUpdates)
105108
clients.Closer.Close()
106109

107110
err = testutil.RequireRecvCtx(ctx, t, wsErr)
@@ -273,7 +276,7 @@ func TestWebsocketDialer_UplevelVersion(t *testing.T) {
273276
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
274277

275278
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
276-
sVer := apiversion.New(proto.CurrentMajor, proto.CurrentMinor-1)
279+
sVer := apiversion.New(2, 2)
277280

278281
// the following matches what Coderd does;
279282
// c.f. coderd/workspaceagents.go: workspaceAgentClientCoordinate
@@ -291,7 +294,10 @@ func TestWebsocketDialer_UplevelVersion(t *testing.T) {
291294
svrURL, err := url.Parse(svr.URL)
292295
require.NoError(t, err)
293296

294-
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
297+
uut := workspacesdk.NewWebsocketDialer(
298+
logger, svrURL, &websocket.DialOptions{},
299+
workspacesdk.WithWorkspaceUpdates(&tailnetproto.WorkspaceUpdatesRequest{}),
300+
)
295301

296302
errCh := make(chan error, 1)
297303
go func() {
@@ -307,6 +313,84 @@ func TestWebsocketDialer_UplevelVersion(t *testing.T) {
307313
require.NotEmpty(t, sdkErr.Helper)
308314
}
309315

316+
func TestWebsocketDialer_WorkspaceUpdates(t *testing.T) {
317+
t.Parallel()
318+
ctx := testutil.Context(t, testutil.WaitShort)
319+
logger := slogtest.Make(t, &slogtest.Options{
320+
IgnoreErrors: true,
321+
}).Leveled(slog.LevelDebug)
322+
323+
fCoord := tailnettest.NewFakeCoordinator()
324+
var coord tailnet.Coordinator = fCoord
325+
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
326+
coordPtr.Store(&coord)
327+
ctrl := gomock.NewController(t)
328+
mProvider := tailnettest.NewMockWorkspaceUpdatesProvider(ctrl)
329+
330+
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
331+
Logger: logger,
332+
CoordPtr: &coordPtr,
333+
DERPMapUpdateFrequency: time.Hour,
334+
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
335+
WorkspaceUpdatesProvider: mProvider,
336+
})
337+
require.NoError(t, err)
338+
339+
wsErr := make(chan error, 1)
340+
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
341+
// need 2.3 for WorkspaceUpdates RPC
342+
cVer := r.URL.Query().Get("version")
343+
assert.Equal(t, "2.3", cVer)
344+
345+
sws, err := websocket.Accept(w, r, nil)
346+
if !assert.NoError(t, err) {
347+
return
348+
}
349+
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
350+
// streamID can be empty because we don't call RPCs in this test.
351+
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
352+
}))
353+
defer svr.Close()
354+
svrURL, err := url.Parse(svr.URL)
355+
require.NoError(t, err)
356+
357+
userID := uuid.UUID{88}
358+
359+
mSub := tailnettest.NewMockSubscription(ctrl)
360+
updateCh := make(chan *tailnetproto.WorkspaceUpdate, 1)
361+
mProvider.EXPECT().Subscribe(gomock.Any(), userID).Times(1).Return(mSub, nil)
362+
mSub.EXPECT().Updates().MinTimes(1).Return(updateCh)
363+
mSub.EXPECT().Close().Times(1).Return(nil)
364+
365+
uut := workspacesdk.NewWebsocketDialer(
366+
logger, svrURL, &websocket.DialOptions{},
367+
workspacesdk.WithWorkspaceUpdates(&tailnetproto.WorkspaceUpdatesRequest{
368+
WorkspaceOwnerId: userID[:],
369+
}),
370+
)
371+
372+
clients, err := uut.Dial(ctx, nil)
373+
require.NoError(t, err)
374+
require.NotNil(t, clients.WorkspaceUpdates)
375+
376+
wsID := uuid.UUID{99}
377+
expectedUpdate := &tailnetproto.WorkspaceUpdate{
378+
UpsertedWorkspaces: []*tailnetproto.Workspace{
379+
{Id: wsID[:]},
380+
},
381+
}
382+
updateCh <- expectedUpdate
383+
384+
gotUpdate, err := clients.WorkspaceUpdates.Recv()
385+
require.NoError(t, err)
386+
require.Equal(t, wsID[:], gotUpdate.GetUpsertedWorkspaces()[0].GetId())
387+
388+
clients.Closer.Close()
389+
390+
err = testutil.RequireRecvCtx(ctx, t, wsErr)
391+
require.NoError(t, err)
392+
}
393+
310394
type fakeResumeTokenController struct {
311395
ctx context.Context
312396
t testing.TB

codersdk/workspacesdk/workspacesdk.go

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -216,17 +216,6 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
216216
if err != nil {
217217
return nil, xerrors.Errorf("parse url: %w", err)
218218
}
219-
q := coordinateURL.Query()
220-
// The current version includes additions
221-
//
222-
// 2.1 GetAnnouncementBanners on the Agent API (version locked to Tailnet API)
223-
// 2.2 PostTelemetry on the Tailnet API
224-
// 2.3 RefreshResumeToken, WorkspaceUpdates
225-
//
226-
// Since resume tokens and telemetry are optional, and fail gracefully, and we don't use
227-
// WorkspaceUpdates to talk to a single agent, we ask for version 2.0 for maximum compatibility
228-
q.Add("version", "2.0")
229-
coordinateURL.RawQuery = q.Encode()
230219

231220
dialer := NewWebsocketDialer(options.Logger, coordinateURL, &websocket.DialOptions{
232221
HTTPClient: c.client.HTTPClient,

enterprise/wsproxy/wsproxysdk/wsproxysdk.go

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,6 @@ import (
2121
agpl "github.com/coder/coder/v2/tailnet"
2222
)
2323

24-
// TailnetAPIVersion is the version of the Tailnet API we use for wsproxy.
25-
//
26-
// # The current version of the Tailnet API includes additions
27-
//
28-
// 2.1 GetAnnouncementBanners on the Agent API (version locked to Tailnet API)
29-
// 2.2 PostTelemetry on the Tailnet API
30-
// 2.3 RefreshResumeToken, WorkspaceUpdates
31-
//
32-
// Since resume tokens and telemetry are optional, and fail gracefully, and we don't use
33-
// WorkspaceUpdates in the wsproxy, we ask for version 2.0 for maximum compatibility
34-
const TailnetAPIVersion = "2.0"
35-
3624
// Client is a HTTP client for a subset of Coder API routes that external
3725
// proxies need.
3826
type Client struct {
@@ -518,9 +506,6 @@ func (c *Client) TailnetDialer() (*workspacesdk.WebsocketDialer, error) {
518506
if err != nil {
519507
return nil, xerrors.Errorf("parse url: %w", err)
520508
}
521-
q := coordinateURL.Query()
522-
q.Add("version", TailnetAPIVersion)
523-
coordinateURL.RawQuery = q.Encode()
524509
coordinateHeaders := make(http.Header)
525510
tokenHeader := codersdk.SessionTokenHeader
526511
if c.SDKClient.SessionTokenHeader != "" {

0 commit comments

Comments
 (0)